Spaces:
Running
Running
| 'use client' | |
| import { useState } from 'react' | |
| import type { GuardrailData } from '@/lib/types' | |
| import { LabLogo } from '@/components/LabLogo' | |
| import { guardrailCreatorColor } from '@/lib/utils' | |
/** Formats a fractional score as a whole-number percentage, e.g. 0.5 -> "50%". */
const formatWholePercent = (v: number) => `${(v * 100).toFixed(0)}%`

/**
 * One entry per horizontal strip in the chart, in render order.
 * `key` selects the metric read from each guardrail, `label` is the row
 * caption, and `fmt` turns a raw value into the text shown to the user.
 */
const STRIP_ROWS = [
  {
    key: 'recall' as const,
    label: 'Recall',
    color: '#00C0F3',
    fmt: formatWholePercent,
  },
  {
    key: 'f1' as const,
    label: 'F1 Score',
    color: '#F4333D',
    fmt: formatWholePercent,
  },
]
/**
 * Reads one metric from a guardrail record.
 *
 * @param g   The guardrail record to read from.
 * @param key Which metric to read: `'recall'` or `'f1'`.
 * @returns The stored value for that metric; `null` when none is recorded.
 */
function getVal(g: GuardrailData, key: 'recall' | 'f1'): number | null {
  const { [key]: metric } = g
  return metric
}
/**
 * Strip chart of guardrail scores, one horizontal row per entry in STRIP_ROWS.
 *
 * Each row draws a 0–100% axis with tick labels, a half-violin outline (a
 * Gaussian kernel density estimate of the row's values, normalised to its own
 * peak), a dashed mean marker, and one dot per guardrail coloured by creator.
 * Hovering a dot stores it in `tooltip`; while a tooltip is set, the matching
 * guardrail is enlarged in every row, all other dots are dimmed, a text
 * tooltip is drawn inside the SVG, and a LabLogo is overlaid on the matching
 * dot of each row. Leaving the SVG clears the tooltip.
 *
 * Rows with no usable values render the axis only (no violin, mean, or dots).
 *
 * @param guardrails Records to plot. Entries whose metric is null or
 *   non-finite are skipped for that row.
 */
export default function GuardrailDistribution({ guardrails }: { guardrails: GuardrailData[] }) {
  const [tooltip, setTooltip] = useState<{
    x: number; y: number; guardrail: string; val: string; creator: string
  } | null>(null)

  // Layout constants, all in SVG viewBox units.
  const VW = 800
  const ROW_H = 120
  const LBL_W = 85
  const PAD_R = 85
  const AXIS_W = VW - LBL_W - PAD_R
  const DOT_R = 7
  const HOVER_R = 20
  const VH = STRIP_ROWS.length * ROW_H
  const TT_W = 200
  const TT_H = 36
  const VIOLIN_H = 40
  const VIOLIN_BW = AXIS_W * 0.07

  // Maps a score onto the axis. Assumes scores are fractions in [0, 1];
  // anything outside that range would be drawn off the axis — TODO confirm
  // against the data source.
  const toX = (v: number) => LBL_W + v * AXIS_W

  // Resolved once per render rather than once per row in the overlay loop.
  const hovered = tooltip
    ? guardrails.find(g => g.guardrail === tooltip.guardrail)
    : undefined

  return (
    <div style={{ position: 'relative', width: '100%' }}>
      <svg
        viewBox={`0 0 ${VW} ${VH}`}
        style={{ width: '100%', height: 'auto', display: 'block' }}
        aria-label="Guardrail score distribution"
        onMouseLeave={() => setTooltip(null)}
      >
        {STRIP_ROWS.map((row, ri) => {
          const y0 = ri * ROW_H
          const yMid = y0 + ROW_H / 2

          // Keep only plottable values: null means "no score", and NaN/Infinity
          // would otherwise leak into cx and the density sum.
          const vals: { val: number; creator: string; guardrail: string }[] = []
          for (const g of guardrails) {
            const val = getVal(g, row.key)
            if (val !== null && Number.isFinite(val)) {
              vals.push({ val, creator: g.creator, guardrail: g.guardrail })
            }
          }

          // null when the row is empty, so no "avg NaN%" marker is rendered.
          const mean = vals.length > 0
            ? vals.reduce((s, d) => s + d.val, 0) / vals.length
            : null

          return (
            <g key={row.key}>
              {/* Alternating row background */}
              {ri % 2 === 1 && (
                <rect x={0} y={y0} width={VW} height={ROW_H} fill="var(--bg-0)" fillOpacity={0.5} />
              )}
              {/* Separator above every row except the first */}
              {ri > 0 && (
                <line x1={0} y1={y0} x2={VW} y2={y0} stroke="var(--border-0)" strokeWidth={1} />
              )}
              {/* Axis baseline */}
              <line x1={LBL_W} y1={yMid} x2={VW - PAD_R} y2={yMid}
                stroke="var(--border-1)" strokeWidth={1} />

              {/* Half violin */}
              {vals.length > 1 && (() => {
                // Sample the Gaussian KDE at N evenly spaced x positions across the axis.
                const N = 90
                const sampleXs = Array.from({ length: N }, (_, i) => LBL_W + (i / (N - 1)) * AXIS_W)
                const densities = sampleXs.map(px =>
                  vals.reduce((s, d) => s + Math.exp(-0.5 * ((px - toX(d.val)) / VIOLIN_BW) ** 2), 0) /
                  (vals.length * VIOLIN_BW * Math.sqrt(2 * Math.PI))
                )
                const maxD = Math.max(...densities)
                if (maxD <= 0) return null
                // Scale so the densest point rises exactly VIOLIN_H above the baseline.
                const pts = sampleXs.map((px, i) =>
                  `L${px.toFixed(1)},${(yMid - (densities[i] / maxD) * VIOLIN_H).toFixed(1)}`
                ).join(' ')
                const d = `M${LBL_W},${yMid} ${pts} L${VW - PAD_R},${yMid} Z`
                return (
                  <path
                    key="violin"
                    d={d}
                    fill="var(--text-3)"
                    fillOpacity={0.15}
                    stroke="var(--text-3)"
                    strokeOpacity={0.45}
                    strokeWidth={1.2}
                    strokeLinejoin="round"
                  />
                )
              })()}

              {/* Tick marks */}
              {[0, 0.25, 0.5, 0.75, 1].map(t => {
                const tx = toX(t)
                return (
                  <g key={t}>
                    <line x1={tx} y1={yMid - 5} x2={tx} y2={yMid + 5}
                      stroke="var(--border-2)" strokeWidth={1} />
                    <text x={tx} y={y0 + ROW_H - 8}
                      textAnchor="middle" fontSize={9} fill="var(--text-3)" fontFamily="inherit">
                      {`${Math.round(t * 100)}%`}
                    </text>
                  </g>
                )
              })}

              {/* Mean line — omitted for rows without data */}
              {mean !== null && (
                <g>
                  <line x1={toX(mean)} y1={y0 + 10} x2={toX(mean)} y2={y0 + ROW_H - 20}
                    stroke="var(--text-3)" strokeOpacity={0.9} strokeWidth={1.2} strokeDasharray="4 3" />
                  <text x={toX(mean) + 4} y={y0 + 20} fontSize={8} fill="var(--text-3)" fillOpacity={0.9} fontFamily="inherit">
                    avg {row.fmt(mean)}
                  </text>
                </g>
              )}

              {/* Dots — the hovered guardrail is sorted last so it paints on top */}
              {[...vals].sort((a, b) => {
                const aH = tooltip?.guardrail === a.guardrail ? 1 : 0
                const bH = tooltip?.guardrail === b.guardrail ? 1 : 0
                return aH - bH
              }).map((d) => {
                const isHovered = tooltip?.guardrail === d.guardrail
                const dimmed = tooltip !== null && !isHovered
                const cx = toX(d.val)
                return (
                  <g key={d.guardrail}>
                    <circle
                      cx={cx}
                      cy={yMid}
                      r={isHovered ? HOVER_R : DOT_R}
                      fill={isHovered ? 'var(--bg-1)' : guardrailCreatorColor(d.creator)}
                      fillOpacity={dimmed ? 0.35 : 1}
                      stroke={isHovered ? guardrailCreatorColor(d.creator) : 'var(--bg-1)'}
                      strokeWidth={isHovered ? 2.5 : 1.2}
                      style={{
                        cursor: 'pointer',
                        transition: 'r 0.15s, fill-opacity 0.15s',
                        filter: dimmed ? 'saturate(0.15)' : undefined,
                      }}
                      onMouseEnter={() => setTooltip({
                        x: cx,
                        y: yMid,
                        guardrail: d.guardrail,
                        val: row.fmt(d.val),
                        creator: d.creator,
                      })}
                    />
                  </g>
                )
              })}

              {/* Row label */}
              <text x={LBL_W - 16} y={yMid + 5}
                textAnchor="end" fontSize={12} fontWeight="700"
                fill="var(--text-1)" fontFamily="inherit">
                {row.label}
              </text>
            </g>
          )
        })}

        {/* Tooltip */}
        {tooltip && (() => {
          // Centre on the dot, clamped horizontally to the viewBox.
          const ttX = Math.max(0, Math.min(tooltip.x - TT_W / 2, VW - TT_W))
          // Prefer above the dot; fall back to below when it would clip the top edge.
          const above = tooltip.y - DOT_R - TT_H - 8
          const ttY = above < 2 ? tooltip.y + DOT_R + 6 : above
          const cc = guardrailCreatorColor(tooltip.creator)
          return (
            <g style={{ pointerEvents: 'none' }}>
              <rect x={ttX} y={ttY} width={TT_W} height={TT_H} rx={5}
                fill="var(--bg-1)" stroke="var(--border-2)" strokeWidth={1.2} />
              <rect x={ttX} y={ttY} width={3} height={TT_H} rx={2} fill={cc} />
              <text x={ttX + 10} y={ttY + 13}
                fontSize={10} fontWeight="700" fill="var(--text-0)" fontFamily="inherit">
                {tooltip.guardrail}
              </text>
              <text x={ttX + 10} y={ttY + 27}
                fontSize={9} fill={cc} fontFamily="inherit">
                {tooltip.val}
              </text>
            </g>
          )
        })()}
      </svg>

      {/* Logo overlays — one per row for the hovered guardrail. Positioned in
          percentages of the viewBox so they track the SVG as it scales. */}
      {tooltip && hovered && STRIP_ROWS.map((row, ri) => {
        const val = getVal(hovered, row.key)
        // Same plottability rule as the dots: no dot in this row, no logo.
        if (val === null || !Number.isFinite(val)) return null
        const pctX = toX(val) / VW * 100
        const pctY = (ri * ROW_H + ROW_H / 2) / VH * 100
        return (
          <div key={row.key} style={{
            position: 'absolute',
            left: `${pctX}%`,
            top: `${pctY}%`,
            transform: 'translate(-50%, -50%)',
            width: 40,
            height: 40,
            display: 'flex',
            alignItems: 'center',
            justifyContent: 'center',
            pointerEvents: 'none',
          }}>
            <LabLogo creator={tooltip.creator} size={32} />
          </div>
        )
      })}
    </div>
  )
}