// rai-bench / src/components/GuardrailDistribution.tsx
// rohanjaggi — "fix: remove guardrails accuracy" (2b01b44), 8.03 kB
'use client'
import { useState } from 'react'
import type { GuardrailData } from '@/lib/types'
import { LabLogo } from '@/components/LabLogo'
import { guardrailCreatorColor } from '@/lib/utils'
/** Renders a fraction in [0, 1] as a whole-number percentage, e.g. 0.456 -> "46%". */
const formatPercent = (v: number) => `${(v * 100).toFixed(0)}%`

/**
 * Metrics plotted by the strip chart, one horizontal row each, top to bottom.
 * `key` names the score field looked up via `getVal`; `fmt` turns a score into
 * the text shown in the mean label and the hover tooltip.
 */
const STRIP_ROWS = [
  { key: 'recall' as const, label: 'Recall', color: '#00C0F3', fmt: formatPercent },
  { key: 'f1' as const, label: 'F1 Score', color: '#F4333D', fmt: formatPercent },
]
/**
 * Looks up one metric score on a guardrail record.
 * The result may be null; the chart skips such entries instead of plotting them.
 */
function getVal(g: GuardrailData, key: 'recall' | 'f1'): number | null {
  switch (key) {
    case 'recall':
      return g.recall
    case 'f1':
      return g.f1
  }
}
/** Hover state: the highlighted guardrail and the tooltip anchor point, in viewBox units. */
type TooltipState = {
  x: number
  y: number
  guardrail: string
  val: string
  creator: string
}

/** One plotted dot in a strip row (guardrails without a score for the metric are omitted). */
type StripPoint = { val: number; creator: string; guardrail: string }

// Layout, in SVG viewBox units. The SVG is stretched to its container's width, so these
// act as proportions rather than pixels. None depend on props, so they live at module scope.
const VW = 800
const ROW_H = 120
const LBL_W = 85 // left gutter that holds the row label
const PAD_R = 85
const AXIS_W = VW - LBL_W - PAD_R
const DOT_R = 7
const HOVER_R = 20 // hovered dot is enlarged so the creator logo overlay can sit on it
const VH = STRIP_ROWS.length * ROW_H
const TT_W = 200
const TT_H = 36
const VIOLIN_H = 40 // peak height of the density curve above the row's axis line
const VIOLIN_BW = AXIS_W * 0.07 // Gaussian kernel bandwidth: 7% of the axis width
const VIOLIN_SAMPLES = 90
const TICKS = [0, 0.25, 0.5, 0.75, 1]

/** Maps a score in [0, 1] onto the shared horizontal axis. */
function toX(v: number): number {
  return LBL_W + v * AXIS_W
}

/** Collects the dots for one metric, skipping guardrails that have no score for it. */
function collectPoints(guardrails: GuardrailData[], key: 'recall' | 'f1'): StripPoint[] {
  return guardrails.flatMap((g): StripPoint[] => {
    const val = getVal(g, key)
    // `== null` also skips `undefined`, e.g. a record whose data omits the field entirely.
    return val == null ? [] : [{ val, creator: g.creator, guardrail: g.guardrail }]
  })
}

/**
 * Builds the outline of a half violin: a Gaussian kernel density estimate of the
 * given x positions, sampled across the axis and drawn upward from `yMid`.
 * Returns null when the density has no positive peak (nothing worth drawing).
 */
function violinPath(xs: number[], yMid: number): string | null {
  const samples = Array.from({ length: VIOLIN_SAMPLES }, (_, i) => {
    const px = toX(i / (VIOLIN_SAMPLES - 1))
    // Unnormalized KDE: the curve is rescaled to its own peak below, so the usual
    // 1 / (n * bw * sqrt(2π)) factor would cancel out anyway.
    const density = xs.reduce((s, x) => s + Math.exp(-0.5 * ((px - x) / VIOLIN_BW) ** 2), 0)
    return { px, density }
  })
  const maxD = Math.max(...samples.map(s => s.density))
  // `!(maxD > 0)` rejects NaN as well as zero.
  if (!(maxD > 0)) return null
  const pts = samples
    .map(({ px, density }) => `L${px.toFixed(1)},${(yMid - (density / maxD) * VIOLIN_H).toFixed(1)}`)
    .join(' ')
  return `M${LBL_W},${yMid} ${pts} L${VW - PAD_R},${yMid} Z`
}

/**
 * Strip chart with one row per metric in `STRIP_ROWS`. Each row shows every guardrail
 * as a dot colored by creator, a half-violin density, a dashed mean line, and 0–100%
 * ticks. Hovering a dot highlights that guardrail in every row, shows a tooltip in the
 * hovered row, and overlays the creator's logo on each of its dots.
 */
export default function GuardrailDistribution({ guardrails }: { guardrails: GuardrailData[] }) {
  const [tooltip, setTooltip] = useState<TooltipState | null>(null)

  // The hovered guardrail's record, looked up once for the per-row logo overlays.
  const hovered = tooltip ? guardrails.find(g => g.guardrail === tooltip.guardrail) : undefined

  return (
    <div style={{ position: 'relative', width: '100%' }}>
      <svg
        viewBox={`0 0 ${VW} ${VH}`}
        style={{ width: '100%', height: 'auto', display: 'block' }}
        aria-label="Guardrail score distribution"
        onMouseLeave={() => setTooltip(null)}
      >
        {STRIP_ROWS.map((row, ri) => {
          const y0 = ri * ROW_H
          const yMid = y0 + ROW_H / 2
          const vals = collectPoints(guardrails, row.key)
          // A metric can be missing for every guardrail; skip the mean rather than plot NaN.
          const mean = vals.length > 0 ? vals.reduce((s, d) => s + d.val, 0) / vals.length : null
          const violin = vals.length > 1 ? violinPath(vals.map(d => toX(d.val)), yMid) : null
          // SVG paints in document order, so the hovered dot goes last to render on top.
          const ordered = tooltip
            ? [
                ...vals.filter(d => d.guardrail !== tooltip.guardrail),
                ...vals.filter(d => d.guardrail === tooltip.guardrail),
              ]
            : vals
          return (
            <g key={row.key}>
              {ri % 2 === 1 && (
                <rect x={0} y={y0} width={VW} height={ROW_H} fill="var(--bg-0)" fillOpacity={0.5} />
              )}
              {ri > 0 && (
                <line x1={0} y1={y0} x2={VW} y2={y0} stroke="var(--border-0)" strokeWidth={1} />
              )}
              <line x1={LBL_W} y1={yMid} x2={VW - PAD_R} y2={yMid}
                stroke="var(--border-1)" strokeWidth={1} />

              {/* Half violin: density of this row's scores */}
              {violin !== null && (
                <path
                  d={violin}
                  fill="var(--text-3)"
                  fillOpacity={0.15}
                  stroke="var(--text-3)"
                  strokeOpacity={0.45}
                  strokeWidth={1.2}
                  strokeLinejoin="round"
                />
              )}

              {/* Tick marks, with percentage labels along the bottom of the row */}
              {TICKS.map(t => {
                const tx = toX(t)
                return (
                  <g key={t}>
                    <line x1={tx} y1={yMid - 5} x2={tx} y2={yMid + 5}
                      stroke="var(--border-2)" strokeWidth={1} />
                    <text x={tx} y={y0 + ROW_H - 8}
                      textAnchor="middle" fontSize={9} fill="var(--text-3)" fontFamily="inherit">
                      {`${Math.round(t * 100)}%`}
                    </text>
                  </g>
                )
              })}

              {/* Mean line and label (omitted when the row has no scores) */}
              {mean !== null && (
                <g>
                  <line x1={toX(mean)} y1={y0 + 10} x2={toX(mean)} y2={y0 + ROW_H - 20}
                    stroke="var(--text-3)" strokeOpacity={0.9} strokeWidth={1.2} strokeDasharray="4 3" />
                  <text x={toX(mean) + 4} y={y0 + 20} fontSize={8} fill="var(--text-3)" fillOpacity={0.9} fontFamily="inherit">
                    avg {row.fmt(mean)}
                  </text>
                </g>
              )}

              {/* Dots: one per guardrail with a score for this metric */}
              {ordered.map(d => {
                const isHovered = tooltip?.guardrail === d.guardrail
                const dimmed = tooltip !== null && !isHovered
                const cx = toX(d.val)
                const color = guardrailCreatorColor(d.creator)
                return (
                  <circle
                    key={d.guardrail}
                    cx={cx}
                    cy={yMid}
                    r={isHovered ? HOVER_R : DOT_R}
                    fill={isHovered ? 'var(--bg-1)' : color}
                    fillOpacity={dimmed ? 0.35 : 1}
                    stroke={isHovered ? color : 'var(--bg-1)'}
                    strokeWidth={isHovered ? 2.5 : 1.2}
                    style={{
                      cursor: 'pointer',
                      transition: 'r 0.15s, fill-opacity 0.15s',
                      filter: dimmed ? 'saturate(0.15)' : undefined,
                    }}
                    onMouseEnter={() => setTooltip({
                      x: cx,
                      y: yMid,
                      guardrail: d.guardrail,
                      val: row.fmt(d.val),
                      creator: d.creator,
                    })}
                  />
                )
              })}

              {/* Row label */}
              <text x={LBL_W - 16} y={yMid + 5}
                textAnchor="end" fontSize={12} fontWeight="700"
                fill="var(--text-1)" fontFamily="inherit">
                {row.label}
              </text>
            </g>
          )
        })}

        {/* Tooltip: horizontally clamped to the chart; placed above the dot unless that clips the top edge */}
        {tooltip && (() => {
          const ttX = Math.max(0, Math.min(tooltip.x - TT_W / 2, VW - TT_W))
          const above = tooltip.y - DOT_R - TT_H - 8
          const ttY = above < 2 ? tooltip.y + DOT_R + 6 : above
          const cc = guardrailCreatorColor(tooltip.creator)
          return (
            <g style={{ pointerEvents: 'none' }}>
              <rect x={ttX} y={ttY} width={TT_W} height={TT_H} rx={5}
                fill="var(--bg-1)" stroke="var(--border-2)" strokeWidth={1.2} />
              <rect x={ttX} y={ttY} width={3} height={TT_H} rx={2} fill={cc} />
              <text x={ttX + 10} y={ttY + 13}
                fontSize={10} fontWeight="700" fill="var(--text-0)" fontFamily="inherit">
                {tooltip.guardrail}
              </text>
              <text x={ttX + 10} y={ttY + 27}
                fontSize={9} fill={cc} fontFamily="inherit">
                {tooltip.val}
              </text>
            </g>
          )
        })()}
      </svg>

      {/* Logo overlays: one per row for the hovered guardrail, positioned as percentages of the SVG box */}
      {tooltip && hovered && STRIP_ROWS.map((row, ri) => {
        const val = getVal(hovered, row.key)
        if (val == null) return null
        const pctX = (toX(val) / VW) * 100
        const pctY = ((ri * ROW_H + ROW_H / 2) / VH) * 100
        return (
          <div key={row.key} style={{
            position: 'absolute',
            left: `${pctX}%`,
            top: `${pctY}%`,
            transform: 'translate(-50%, -50%)',
            width: 40,
            height: 40,
            display: 'flex',
            alignItems: 'center',
            justifyContent: 'center',
            pointerEvents: 'none',
          }}>
            <LabLogo creator={tooltip.creator} size={32} />
          </div>
        )
      })}
    </div>
  )
}