// rai-bench/src/components/ScoreDistribution.tsx
// rohanjaggi · fix: logo size on render · d5a5fae · 8.89 kB
'use client'
import { useState } from 'react'
import type { ModelData } from '@/lib/types'
import { LabLogo } from '@/components/LabLogo'
import { creatorColor } from '@/lib/utils'
/** Renders a fraction as a whole-number percentage string, e.g. 0.5 -> "50%". */
const toPercentLabel = (fraction: number): string => (fraction * 100).toFixed(0) + '%'

/** Renders a value with exactly three decimal places, e.g. 0.5 -> "0.500". */
const toThreeDecimals = (value: number): string => value.toFixed(3)

/**
 * One entry per chart row, in top-to-bottom display order.
 *
 * - `key`    selects which metric of a model the row reads (see `getAvg`).
 * - `label`  is the text drawn to the left of the row.
 * - `invert` flips the axis direction: when true, a value of 0 is placed at
 *            the right-hand end of the axis instead of the left.
 * - `color`  is a hex color attached to the metric; nothing in the code
 *            visible in this file reads it.
 * - `fmt`    turns a raw value into the string shown in the tooltip and the
 *            "avg" label.
 */
const STRIP_ROWS = [
  {
    key: 'luc' as const,
    label: 'Refusal Rate',
    invert: false,
    color: '#00C0F3',
    fmt: toPercentLabel,
  },
  {
    key: 'rag' as const,
    label: 'RAG Score',
    invert: false,
    color: '#F4333D',
    fmt: toPercentLabel,
  },
  {
    key: 'fairness' as const,
    label: 'Fairness',
    invert: true,
    color: '#BA2FA2',
    fmt: toThreeDecimals,
  },
]
/**
 * Reads the `avg` field of one metric from a model record.
 *
 * @param m - The model record to read from.
 * @param key - Which metric to read: `'luc'`, `'rag'` or `'fairness'`.
 * @returns The selected metric's `avg`, which the signature allows to be `null`.
 */
function getAvg(m: ModelData, key: 'luc' | 'rag' | 'fairness'): number | null {
  switch (key) {
    case 'luc':
      return m.luc.avg
    case 'rag':
      return m.rag.avg
    default:
      // Any remaining key resolves to the fairness metric.
      return m.fairness.avg
  }
}
/**
 * Strip chart of per-model average scores, with one horizontal row per entry
 * in `STRIP_ROWS`.
 *
 * Every row draws an axis with tick labels, a half-violin density outline
 * (only when the row has more than one value), a dashed marker at the mean of
 * the row's values, and one dot per non-archived model whose value for that
 * metric is not `null`. Hovering a dot enlarges the dot of the same model in
 * every row, dims the other dots, shows an SVG tooltip next to the hovered dot,
 * and overlays the creator's logo on the model's dot in each row.
 *
 * @param models - Models to plot. Entries with a truthy `archived` flag are skipped.
 * @param maxFairness - Value placed at the far (left) end of the inverted
 *   fairness axis. If it is not a finite number greater than zero, 1 is used
 *   instead so that positions never become `NaN` or `Infinity`.
 */
export default function ScoreDistribution({ models, maxFairness }: { models: ModelData[]; maxFairness: number }) {
  type StripRow = (typeof STRIP_ROWS)[number]

  const active = models.filter(m => !m.archived)
  const [tooltip, setTooltip] = useState<{
    x: number; y: number; model: string; val: string; creator: string
  } | null>(null)

  // Layout constants, all expressed in SVG viewBox units.
  const VW = 800
  const ROW_H = 120
  const LBL_W = 85
  const PAD_R = 85
  const AXIS_W = VW - LBL_W - PAD_R
  const DOT_R = 7
  const VH = STRIP_ROWS.length * ROW_H
  const TT_W = 168
  const TT_H = 36
  const VIOLIN_H = 40
  const VIOLIN_BW = AXIS_W * 0.07

  // Dividing by a zero, negative or non-finite maximum would turn every
  // fairness position into NaN/Infinity, so fall back to a unit axis.
  const fairnessMax = Number.isFinite(maxFairness) && maxFairness > 0 ? maxFairness : 1

  /** Value placed at the far end of a row's axis. */
  const axisMax = (row: StripRow): number => (row.key === 'fairness' ? fairnessMax : 1)

  /** Maps a raw value to an x coordinate on the row's axis (shared by dots and logo overlays). */
  const scoreToX = (row: StripRow, v: number): number => {
    const frac = v / axisMax(row)
    return LBL_W + (row.invert ? 1 - frac : frac) * AXIS_W
  }

  // The hovered model is the same for every row, so look it up once.
  const hoveredModel = tooltip ? active.find(m => m.model === tooltip.model) : undefined

  return (
    <div style={{ position: 'relative', width: '100%' }}>
      <svg
        viewBox={`0 0 ${VW} ${VH}`}
        style={{ width: '100%', height: 'auto', display: 'block' }}
        aria-label="Score distribution across all models"
        onMouseLeave={() => setTooltip(null)}
      >
        {STRIP_ROWS.map((row, ri) => {
          const y0 = ri * ROW_H
          const yMid = y0 + ROW_H / 2
          // Keep only models that have a value for this metric.
          const vals = active
            .map(m => ({ val: getAvg(m, row.key), creator: m.creator, model: m.model }))
            .filter(d => d.val !== null) as { val: number; creator: string; model: string }[]
          const rawMax = axisMax(row)
          const toX = (v: number) => scoreToX(row, v)
          // With no values the mean would be 0 / 0 = NaN; skip the marker instead.
          const mean = vals.length > 0
            ? vals.reduce((s, d) => s + d.val, 0) / vals.length
            : null
          const meanX = mean === null ? null : toX(mean)

          return (
            <g key={row.key}>
              {/* Alternating row bg */}
              {ri % 2 === 1 && (
                <rect x={0} y={y0} width={VW} height={ROW_H} fill="var(--bg-0)" fillOpacity={0.5} />
              )}

              {/* Row divider */}
              {ri > 0 && (
                <line x1={0} y1={y0} x2={VW} y2={y0} stroke="var(--border-0)" strokeWidth={1} />
              )}

              {/* Axis line */}
              <line x1={LBL_W} y1={yMid} x2={VW - PAD_R} y2={yMid}
                stroke="var(--border-1)" strokeWidth={1} />

              {/* Half violin: Gaussian kernel density of the dot positions, drawn above the axis */}
              {vals.length > 1 && (() => {
                const N = 90
                const sampleXs = Array.from({ length: N }, (_, i) => LBL_W + (i / (N - 1)) * AXIS_W)
                // Dot centres do not depend on the sample position; compute them once.
                const centers = vals.map(d => toX(d.val))
                const densities = sampleXs.map(px =>
                  centers.reduce((s, c) => s + Math.exp(-0.5 * ((px - c) / VIOLIN_BW) ** 2), 0) /
                  (vals.length * VIOLIN_BW * Math.sqrt(2 * Math.PI))
                )
                const maxD = Math.max(...densities)
                if (maxD <= 0) return null
                // Normalise so the tallest point of the outline is VIOLIN_H above the axis.
                const pts = sampleXs.map((px, i) =>
                  `L${px.toFixed(1)},${(yMid - (densities[i] / maxD) * VIOLIN_H).toFixed(1)}`
                ).join(' ')
                const d = `M${LBL_W},${yMid} ${pts} L${VW - PAD_R},${yMid} Z`
                return (
                  <path
                    key="violin"
                    d={d}
                    fill="var(--text-3)"
                    fillOpacity={0.15}
                    stroke="var(--text-3)"
                    strokeOpacity={0.45}
                    strokeWidth={1.2}
                    strokeLinejoin="round"
                  />
                )
              })()}

              {/* Tick marks + labels */}
              {[0, 0.25, 0.5, 0.75, 1].map(t => {
                const tx = LBL_W + t * AXIS_W
                // Inverted rows count down from rawMax to 0; other rows show percentages.
                const tickLabel = row.invert
                  ? (t === 0 ? rawMax.toFixed(2) : t === 1 ? '0' : (rawMax * (1 - t)).toFixed(2))
                  : `${Math.round(t * 100)}%`
                return (
                  <g key={t}>
                    <line x1={tx} y1={yMid - 5} x2={tx} y2={yMid + 5}
                      stroke="var(--border-2)" strokeWidth={1} />
                    <text x={tx} y={y0 + ROW_H - 8}
                      textAnchor="middle" fontSize={9} fill="var(--text-3)" fontFamily="inherit">
                      {tickLabel}
                    </text>
                  </g>
                )
              })}

              {/* Mean line — omitted when the row has no values */}
              {mean !== null && meanX !== null && (
                <g>
                  <line x1={meanX} y1={y0 + 10} x2={meanX} y2={y0 + ROW_H - 20}
                    stroke="var(--text-3)" strokeOpacity={0.9} strokeWidth={1.2} strokeDasharray="4 3" />
                  <text x={meanX + 4} y={y0 + 20} fontSize={8} fill="var(--text-3)" fillOpacity={0.9} fontFamily="inherit">
                    avg {row.fmt(mean)}
                  </text>
                </g>
              )}

              {/* Dots — render hovered dot last so it stays on top */}
              {[...vals].sort((a, b) => {
                const aH = tooltip?.model === a.model ? 1 : 0
                const bH = tooltip?.model === b.model ? 1 : 0
                return aH - bH
              }).map(d => {
                const isHovered = tooltip?.model === d.model
                const dimmed = tooltip !== null && !isHovered
                const cx = toX(d.val)
                return (
                  <g key={d.model}>
                    <circle
                      cx={cx}
                      cy={yMid}
                      r={isHovered ? 20 : DOT_R}
                      fill={isHovered ? 'var(--bg-1)' : creatorColor(d.creator)}
                      fillOpacity={dimmed ? 0.35 : 1}
                      stroke={isHovered ? creatorColor(d.creator) : 'var(--bg-1)'}
                      strokeWidth={isHovered ? 2.5 : 1.2}
                      style={{
                        cursor: 'pointer',
                        transition: 'r 0.15s, fill-opacity 0.15s',
                        filter: dimmed ? 'saturate(0.15)' : undefined,
                      }}
                      onMouseEnter={() => setTooltip({
                        x: cx,
                        y: yMid,
                        model: d.model,
                        val: row.fmt(d.val),
                        creator: d.creator,
                      })}
                    />
                  </g>
                )
              })}

              {/* Row label */}
              <text x={LBL_W - 16} y={yMid + 5}
                textAnchor="end" fontSize={12} fontWeight="700"
                fill="var(--text-1)" fontFamily="inherit">
                {row.label}
              </text>
            </g>
          )
        })}

        {/* Tooltip — rendered last so it appears above all dots */}
        {tooltip && (() => {
          // Keep the box horizontally inside the viewBox.
          const ttX = Math.max(0, Math.min(tooltip.x - TT_W / 2, VW - TT_W))
          // Prefer placing it above the dot; drop below when it would leave the top edge.
          const above = tooltip.y - DOT_R - TT_H - 8
          const ttY = above < 2 ? tooltip.y + DOT_R + 6 : above
          const cc = creatorColor(tooltip.creator)
          return (
            <g style={{ pointerEvents: 'none' }}>
              <rect x={ttX} y={ttY} width={TT_W} height={TT_H} rx={5}
                fill="var(--bg-1)" stroke="var(--border-2)" strokeWidth={1.2} />
              <rect x={ttX} y={ttY} width={3} height={TT_H} rx={2} fill={cc} />
              <text x={ttX + 10} y={ttY + 13}
                fontSize={10} fontWeight="700" fill="var(--text-0)" fontFamily="inherit">
                {tooltip.model}
              </text>
              <text x={ttX + 10} y={ttY + 27}
                fontSize={9} fill={cc} fontFamily="inherit">
                {tooltip.val}
              </text>
            </g>
          )
        })()}
      </svg>

      {/* Logo overlays — one per row for the hovered model */}
      {tooltip && hoveredModel && STRIP_ROWS.map((row, ri) => {
        const val = getAvg(hoveredModel, row.key)
        if (val === null) return null
        const cx = scoreToX(row, val)
        const yMid = ri * ROW_H + ROW_H / 2
        // Convert viewBox coordinates to percentages of the wrapper, which the
        // SVG fills (width 100%, height auto).
        const pctX = cx / VW * 100
        const pctY = yMid / VH * 100
        return (
          <div key={row.key} style={{
            position: 'absolute',
            left: `${pctX}%`,
            top: `${pctY}%`,
            transform: 'translate(-50%, -50%)',
            width: 40,
            height: 40,
            display: 'flex',
            alignItems: 'center',
            justifyContent: 'center',
            pointerEvents: 'none',
          }}>
            <LabLogo creator={tooltip.creator} size={32} />
          </div>
        )
      })}
    </div>
  )
}