Spaces:
Running
Running
| 'use client' | |
| import { useState } from 'react' | |
| import type { ModelData } from '@/lib/types' | |
| import { LabLogo } from '@/components/LabLogo' | |
| import { creatorColor } from '@/lib/utils' | |
/** Renders a 0–1 ratio as a whole-number percentage, e.g. 0.42 → "42%". */
const asPercent = (v: number) => `${(v * 100).toFixed(0)}%`

/**
 * Metric rows drawn by ScoreDistribution, top to bottom.
 *
 * - `key`    — which average `getAvg` reads from a model.
 * - `label`  — row heading drawn in the left gutter.
 * - `invert` — when true the axis runs from the row maximum (left) down to 0 (right).
 * - `color`  — per-metric accent colour (not read anywhere in this file).
 * - `fmt`    — formats a value for the tooltip and the "avg" label.
 */
const STRIP_ROWS = [
  { key: 'luc' as const, label: 'Refusal Rate', invert: false, color: '#00C0F3', fmt: asPercent },
  { key: 'rag' as const, label: 'RAG Score', invert: false, color: '#F4333D', fmt: asPercent },
  {
    key: 'fairness' as const,
    label: 'Fairness',
    invert: true,
    color: '#BA2FA2',
    // Fairness values are shown as plain decimals rather than percentages.
    fmt: (v: number) => v.toFixed(3),
  },
]
/**
 * Looks up a model's average score for one strip-chart metric.
 *
 * @param m - the model to read from.
 * @param key - which metric group (`luc`, `rag` or `fairness`) to take `.avg` from.
 * @returns that group's `avg` value; callers skip rows where this is `null`.
 */
function getAvg(m: ModelData, key: 'luc' | 'rag' | 'fairness'): number | null {
  switch (key) {
    case 'luc':
      return m.luc.avg
    case 'rag':
      return m.rag.avg
    default:
      return m.fairness.avg
  }
}
/** Shape of one entry in STRIP_ROWS. */
type StripRow = (typeof STRIP_ROWS)[number]

/** One model's score on one metric, ready to plot. */
type StripPoint = { val: number; creator: string; model: string }

/** Hover state: anchor point of the hovered dot (viewBox units) plus the text the tooltip shows. */
type TooltipState = { x: number; y: number; model: string; val: string; creator: string }

// ---- Chart geometry, all in SVG viewBox units ----
const VW = 800
const ROW_H = 120
const LBL_W = 85 // left gutter reserved for the row labels
const PAD_R = 85
const AXIS_W = VW - LBL_W - PAD_R
const DOT_R = 7
const HOVER_R = 20 // hovered dots grow to this radius to frame the lab logo drawn over them
const VH = STRIP_ROWS.length * ROW_H
const TT_W = 168
const TT_H = 36
const VIOLIN_H = 40 // height of the density curve's peak above the axis line
const VIOLIN_BW = AXIS_W * 0.07 // Gaussian kernel bandwidth, in viewBox x units
const VIOLIN_SAMPLES = 90
const TICKS = [0, 0.25, 0.5, 0.75, 1]

/** Top of a row's scale: the caller-supplied `maxFairness` for fairness, 1 for the percentage rows. */
function rowMax(row: StripRow, maxFairness: number): number {
  return row.key === 'fairness' ? maxFairness : 1
}

/**
 * Horizontal viewBox coordinate of value `v` on `row`'s axis.
 * Values map linearly from 0 (left) to `rawMax` (right); inverted rows run the
 * other way. A non-positive `rawMax` (e.g. a `maxFairness` of 0) maps every
 * value to 0 on the scale, so coordinates stay finite instead of NaN/Infinity.
 */
function valueToX(row: StripRow, v: number, rawMax: number): number {
  const frac = rawMax > 0 ? v / rawMax : 0
  const pct = row.invert ? 1 - frac : frac
  return LBL_W + pct * AXIS_W
}

/**
 * Label for tick `t`, a 0–1 fraction of the axis width.
 * Normal rows show percentages; inverted rows show raw values counting down
 * from `rawMax` at the left to 0 at the right.
 */
function tickLabel(t: number, rawMax: number, invert: boolean): string {
  if (!invert) return `${Math.round(t * 100)}%`
  if (t === 1) return '0'
  return (rawMax * (1 - t)).toFixed(2)
}

/**
 * SVG path for a half violin above the axis line at `yMid`.
 * It is a Gaussian kernel density estimate of the dot x-positions `xs`,
 * sampled across the axis and scaled so the peak sits VIOLIN_H above the line.
 * Returns null when fewer than two points are given or the density is degenerate.
 */
function halfViolinPath(xs: readonly number[], yMid: number): string | null {
  if (xs.length < 2) return null
  const norm = xs.length * VIOLIN_BW * Math.sqrt(2 * Math.PI)
  const samples = Array.from({ length: VIOLIN_SAMPLES }, (_, i) => {
    const px = LBL_W + (i / (VIOLIN_SAMPLES - 1)) * AXIS_W
    const density = xs.reduce((s, x) => s + Math.exp(-0.5 * ((px - x) / VIOLIN_BW) ** 2), 0) / norm
    return { px, density }
  })
  const maxD = Math.max(...samples.map(s => s.density))
  // `!(maxD > 0)` also rejects NaN, which would otherwise end up in the path string.
  if (!(maxD > 0)) return null
  const pts = samples
    .map(({ px, density }) => `L${px.toFixed(1)},${(yMid - (density / maxD) * VIOLIN_H).toFixed(1)}`)
    .join(' ')
  return `M${LBL_W},${yMid} ${pts} L${VW - PAD_R},${yMid} Z`
}

/**
 * Top-left corner of the tooltip box.
 * The box is centred over the hovered dot and clamped to the chart width. It
 * sits above the dot, or flips below it when there is no room above. Offsets
 * use HOVER_R because the hovered dot is drawn at that radius.
 */
function tooltipOrigin(tt: TooltipState): { x: number; y: number } {
  const x = Math.max(0, Math.min(tt.x - TT_W / 2, VW - TT_W))
  const above = tt.y - HOVER_R - TT_H - 8
  const y = above < 2 ? tt.y + HOVER_R + 6 : above
  return { x, y }
}

/** Tooltip box for the hovered dot. It ignores pointer events, so it never takes hover away from the dots beneath. */
function ChartTooltip({ tooltip }: { tooltip: TooltipState }) {
  const { x, y } = tooltipOrigin(tooltip)
  const cc = creatorColor(tooltip.creator)
  return (
    <g style={{ pointerEvents: 'none' }}>
      <rect x={x} y={y} width={TT_W} height={TT_H} rx={5}
        fill="var(--bg-1)" stroke="var(--border-2)" strokeWidth={1.2} />
      <rect x={x} y={y} width={3} height={TT_H} rx={2} fill={cc} />
      <text x={x + 10} y={y + 13}
        fontSize={10} fontWeight="700" fill="var(--text-0)" fontFamily="inherit">
        {tooltip.model}
      </text>
      <text x={x + 10} y={y + 27}
        fontSize={9} fill={cc} fontFamily="inherit">
        {tooltip.val}
      </text>
    </g>
  )
}

/**
 * Strip chart of per-model average scores.
 *
 * Draws one horizontal row per entry in STRIP_ROWS. Each row has a dot per
 * non-archived model, a half-violin density curve (when the row has at least
 * two values), tick labels and a dashed mean line (when the row has any values).
 * Hovering a dot enlarges it, dims the others and shows a tooltip. It also
 * overlays the model's lab logo at that model's position in every row.
 *
 * @param models - all models; entries with a truthy `archived` are skipped.
 * @param maxFairness - upper bound of the inverted fairness axis; the percentage rows use 1.
 */
export default function ScoreDistribution({ models, maxFairness }: { models: ModelData[]; maxFairness: number }) {
  const active = models.filter(m => !m.archived)
  const [tooltip, setTooltip] = useState<TooltipState | null>(null)
  // The hovered model is the same for every overlay row, so look it up once per render.
  const hoveredModel = tooltip ? active.find(m => m.model === tooltip.model) : undefined

  return (
    <div style={{ position: 'relative', width: '100%' }}>
      <svg
        viewBox={`0 0 ${VW} ${VH}`}
        style={{ width: '100%', height: 'auto', display: 'block' }}
        role="img"
        aria-label="Score distribution across all models"
        onMouseLeave={() => setTooltip(null)}
      >
        {STRIP_ROWS.map((row, ri) => {
          const y0 = ri * ROW_H
          const yMid = y0 + ROW_H / 2
          const rawMax = rowMax(row, maxFairness)
          const toX = (v: number) => valueToX(row, v, rawMax)
          const points: StripPoint[] = active.flatMap(m => {
            const val = getAvg(m, row.key)
            return val === null ? [] : [{ val, creator: m.creator, model: m.model }]
          })
          // An empty row has no mean; 0/0 would put NaN into the SVG attributes.
          const mean = points.length > 0
            ? points.reduce((s, d) => s + d.val, 0) / points.length
            : null
          const violin = halfViolinPath(points.map(d => toX(d.val)), yMid)
          // Draw the hovered dot last so it stays on top of its neighbours.
          const ordered = [...points].sort(
            (a, b) => Number(tooltip?.model === a.model) - Number(tooltip?.model === b.model),
          )
          return (
            <g key={row.key}>
              {/* Alternating row bg */}
              {ri % 2 === 1 && (
                <rect x={0} y={y0} width={VW} height={ROW_H} fill="var(--bg-0)" fillOpacity={0.5} />
              )}
              {/* Row divider */}
              {ri > 0 && (
                <line x1={0} y1={y0} x2={VW} y2={y0} stroke="var(--border-0)" strokeWidth={1} />
              )}
              {/* Axis line */}
              <line x1={LBL_W} y1={yMid} x2={VW - PAD_R} y2={yMid}
                stroke="var(--border-1)" strokeWidth={1} />
              {/* Half violin */}
              {violin !== null && (
                <path
                  d={violin}
                  fill="var(--text-3)"
                  fillOpacity={0.15}
                  stroke="var(--text-3)"
                  strokeOpacity={0.45}
                  strokeWidth={1.2}
                  strokeLinejoin="round"
                />
              )}
              {/* Tick marks + labels */}
              {TICKS.map(t => {
                const tx = LBL_W + t * AXIS_W
                return (
                  <g key={t}>
                    <line x1={tx} y1={yMid - 5} x2={tx} y2={yMid + 5}
                      stroke="var(--border-2)" strokeWidth={1} />
                    <text x={tx} y={y0 + ROW_H - 8}
                      textAnchor="middle" fontSize={9} fill="var(--text-3)" fontFamily="inherit">
                      {tickLabel(t, rawMax, row.invert)}
                    </text>
                  </g>
                )
              })}
              {/* Mean line, only when the row has at least one value */}
              {mean !== null && (
                <g>
                  <line x1={toX(mean)} y1={y0 + 10} x2={toX(mean)} y2={y0 + ROW_H - 20}
                    stroke="var(--text-3)" strokeOpacity={0.9} strokeWidth={1.2} strokeDasharray="4 3" />
                  <text x={toX(mean) + 4} y={y0 + 20} fontSize={8} fill="var(--text-3)" fillOpacity={0.9} fontFamily="inherit">
                    avg {row.fmt(mean)}
                  </text>
                </g>
              )}
              {/* Dots */}
              {ordered.map(d => {
                const isHovered = tooltip?.model === d.model
                const dimmed = tooltip !== null && !isHovered
                const cx = toX(d.val)
                const color = creatorColor(d.creator)
                return (
                  <g key={d.model}>
                    <circle
                      cx={cx}
                      cy={yMid}
                      r={isHovered ? HOVER_R : DOT_R}
                      fill={isHovered ? 'var(--bg-1)' : color}
                      fillOpacity={dimmed ? 0.35 : 1}
                      stroke={isHovered ? color : 'var(--bg-1)'}
                      strokeWidth={isHovered ? 2.5 : 1.2}
                      style={{
                        cursor: 'pointer',
                        transition: 'r 0.15s, fill-opacity 0.15s',
                        filter: dimmed ? 'saturate(0.15)' : undefined,
                      }}
                      onMouseEnter={() => setTooltip({
                        x: cx,
                        y: yMid,
                        model: d.model,
                        val: row.fmt(d.val),
                        creator: d.creator,
                      })}
                    />
                  </g>
                )
              })}
              {/* Row label */}
              <text x={LBL_W - 16} y={yMid + 5}
                textAnchor="end" fontSize={12} fontWeight="700"
                fill="var(--text-1)" fontFamily="inherit">
                {row.label}
              </text>
            </g>
          )
        })}
        {/* Tooltip — rendered last so it appears above all dots */}
        {tooltip && <ChartTooltip tooltip={tooltip} />}
      </svg>
      {/* Logo overlays — one per row for the hovered model, positioned with the same math as the dots.
          NOTE(review): the 40px box is in CSS pixels while the hovered dot is in viewBox units,
          so the two only line up exactly when the chart renders 800px wide. */}
      {hoveredModel && STRIP_ROWS.map((row, ri) => {
        const val = getAvg(hoveredModel, row.key)
        if (val === null) return null
        const cx = valueToX(row, val, rowMax(row, maxFairness))
        const yMid = ri * ROW_H + ROW_H / 2
        return (
          <div key={row.key} style={{
            position: 'absolute',
            left: `${(cx / VW) * 100}%`,
            top: `${(yMid / VH) * 100}%`,
            transform: 'translate(-50%, -50%)',
            width: 40,
            height: 40,
            display: 'flex',
            alignItems: 'center',
            justifyContent: 'center',
            pointerEvents: 'none',
          }}>
            <LabLogo creator={hoveredModel.creator} size={32} />
          </div>
        )
      })}
    </div>
  )
}