'use client'

import { useState } from 'react'
import type { GuardrailData } from '@/lib/types'
import { LabLogo } from '@/components/LabLogo'
import { guardrailCreatorColor } from '@/lib/utils'

// NOTE(review): in this copy of the file the JSX element tags appear to be
// missing (only attribute fragments and embedded expressions survive), so the
// comments below describe the surviving expressions only. Confirm against the
// real markup before relying on any statement about what is drawn.

// One entry per metric strip. Each entry names the GuardrailData field to read
// (key), a display label, a colour, and a formatter that turns a fraction into
// a whole-number percentage string (value times 100, zero decimals, then %).
const STRIP_ROWS = [
  { key: 'recall' as const, label: 'Recall', color: '#00C0F3', fmt: (v: number) => `${(v * 100).toFixed(0)}%` },
  { key: 'f1' as const, label: 'F1 Score', color: '#F4333D', fmt: (v: number) => `${(v * 100).toFixed(0)}%` },
]

/**
 * Reads one metric from a guardrail record.
 *
 * @param g - the guardrail record to read from
 * @param key - which metric field to read (recall or f1)
 * @returns the field value as-is; the declared type allows null
 */
function getVal(g: GuardrailData, key: 'recall' | 'f1'): number | null {
  return g[key]
}

/**
 * Distribution view over a list of guardrails, with one horizontal strip per
 * entry in STRIP_ROWS. For each strip the code computes the non-null values,
 * their mean, a kernel-density outline path, five tick positions, and one
 * hoverable point per guardrail. A single piece of hover state drives the
 * highlighting, the tooltip and the logo overlay.
 *
 * @param guardrails - records to plot; each supplies guardrail, creator and
 *   the metric fields read through getVal
 */
export default function GuardrailDistribution({ guardrails }: { guardrails: GuardrailData[] }) {
  // Hover state: position, guardrail name, formatted value and creator of the
  // point under the pointer, or null when nothing is hovered.
  const [tooltip, setTooltip] = useState<{ x: number; y: number; guardrail: string; val: string; creator: string } | null>(null)

  // Layout constants, all in the same coordinate units.
  // presumably VW is the drawing (viewBox) width - TODO confirm against markup
  const VW = 800
  const ROW_H = 120
  // Left offset at which the value axis starts (see toX below).
  const LBL_W = 85
  // Space left free to the right of the axis.
  const PAD_R = 85
  // Horizontal length of the value axis.
  const AXIS_W = VW - LBL_W - PAD_R
  // presumably the point radius; it is used as the tooltip clearance offset
  const DOT_R = 7
  // Total height: one ROW_H band per strip.
  const VH = STRIP_ROWS.length * ROW_H
  // Tooltip box width and height used for the clamping and placement below.
  const TT_W = 200
  const TT_H = 36
  // Peak height of the density outline above the row midline.
  const VIOLIN_H = 40
  // Gaussian kernel bandwidth, expressed in axis units (7 percent of the axis).
  const VIOLIN_BW = AXIS_W * 0.07

  return (
    setTooltip(null)} >
    {STRIP_ROWS.map((row, ri) => {
      // Top edge and vertical centre of this strip.
      const y0 = ri * ROW_H
      const yMid = y0 + ROW_H / 2
      // Keep only guardrails that have a value for this metric; the cast
      // narrows val to number after the null filter.
      const vals = guardrails
        .map(g => ({ val: getVal(g, row.key), creator: g.creator, guardrail: g.guardrail }))
        .filter(d => d.val !== null) as { val: number; creator: string; guardrail: string }[]
      // NOTE(review): when no guardrail has a value for this metric, vals is
      // empty and this evaluates 0 / 0, i.e. NaN, which then flows into meanX
      // and row.fmt(mean).
      const mean = vals.reduce((s, d) => s + d.val, 0) / vals.length
      // Maps a metric value to an x coordinate: 0 lands on LBL_W and 1 lands
      // on LBL_W + AXIS_W. Values outside that range are not clamped.
      const toX = (v: number) => LBL_W + (v) * AXIS_W
      const meanX = toX(mean)
      return (
        {ri % 2 === 1 && ( )} {ri > 0 && ( )}
        {/* Half violin */}
        {vals.length > 1 && (() => {
          // Sample the axis at N evenly spaced x positions, end points included.
          const N = 90
          const sampleXs = Array.from({ length: N }, (_, i) => LBL_W + (i / (N - 1)) * AXIS_W)
          // Gaussian kernel density estimate at each sample position: sum one
          // kernel per value (centred on its x coordinate, bandwidth VIOLIN_BW)
          // and divide by count * bandwidth * sqrt(2 * pi).
          const densities = sampleXs.map(px =>
            vals.reduce((s, d) => s + Math.exp(-0.5 * ((px - toX(d.val)) / VIOLIN_BW) ** 2), 0) / (vals.length * VIOLIN_BW * Math.sqrt(2 * Math.PI))
          )
          const maxD = Math.max(...densities)
          // Nothing to outline if every sampled density is zero.
          if (maxD <= 0) return null
          // Scale so the densest sample sits VIOLIN_H above the midline, then
          // emit one line-to path segment per sample.
          const pts = sampleXs.map((px, i) => `L${px.toFixed(1)},${(yMid - (densities[i] / maxD) * VIOLIN_H).toFixed(1)}` ).join(' ')
          // Closed path: start on the midline at the left end of the axis, run
          // along the curve, drop back to the midline at the right end, close.
          const d = `M${LBL_W},${yMid} ${pts} L${VW - PAD_R},${yMid} Z`
          return ( )
        })()}
        {/* Tick marks */}
        {[0, 0.25, 0.5, 0.75, 1].map(t => {
          // x coordinate of the tick at fraction t of the axis.
          const tx = LBL_W + t * AXIS_W
          return ( {`${Math.round(t * 100)}%`} )
        })}
        {/* Mean line */} avg {row.fmt(mean)}
        {/* Dots */}
        {[...vals].sort((a, b) => {
          // Sort a copy so the hovered guardrail comes last (presumably so it
          // is painted on top - confirm against markup).
          const aH = tooltip?.guardrail === a.guardrail ? 1 : 0
          const bH = tooltip?.guardrail === b.guardrail ? 1 : 0
          return aH - bH
        }).map((d) => {
          // Hover is matched by guardrail name, so the same guardrail is
          // flagged in every strip, not only in the strip being hovered.
          const isHovered = tooltip?.guardrail === d.guardrail
          // True for every other point while some point is hovered.
          const dimmed = tooltip !== null && !isHovered
          const cx = toX(d.val)
          return ( setTooltip({ x: cx, y: yMid, guardrail: d.guardrail, val: row.fmt(d.val), creator: d.creator, })} /> )
        })}
        {/* Row label */} {row.label}
      )
    })}
    {/* Tooltip */}
    {tooltip && (() => {
      // Centre the box on the hovered x, clamped to stay within 0 .. VW - TT_W.
      const ttX = Math.max(0, Math.min(tooltip.x - TT_W / 2, VW - TT_W))
      // Preferred position is above the point; if that would start above
      // y = 2, place the box below the point instead.
      const above = tooltip.y - DOT_R - TT_H - 8
      const ttY = above < 2 ?
        tooltip.y + DOT_R + 6 : above
      // Colour looked up from the creator name by the imported helper.
      const cc = guardrailCreatorColor(tooltip.creator)
      return ( {tooltip.guardrail} {tooltip.val} )
    })()}
    {/* Logo overlays — one per row for the hovered guardrail */}
    {tooltip && STRIP_ROWS.map((row, ri) => {
      // Look the hovered guardrail up by name; skip the strip if the record
      // is missing or has no value for this metric.
      const g = guardrails.find(g => g.guardrail === tooltip.guardrail)
      if (!g) return null
      const val = getVal(g, row.key)
      if (val === null) return null
      // Same mapping as toX, repeated here because toX is local to the strip
      // callback above.
      const cx = LBL_W + val * AXIS_W
      const yMid = ri * ROW_H + ROW_H / 2
      // Position expressed as percentages of the full width and height.
      const pctX = cx / VW * 100
      const pctY = yMid / VH * 100
      return (
      ) })}
  ) }