rai-bench / src /components /BarChartSection.tsx
rohanjaggi
fix: layout and css adjustments
40f4a8e
Raw
History Blame Contribute Delete
9.15 kB
'use client'
import { useState, useMemo } from 'react'
import type { ModelData } from '@/lib/types'
import { fmtPct, fmtNum, creatorColor } from '@/lib/utils'
import { LabLogo } from '@/components/LabLogo'
/** Identifiers of the metrics the chart can plot. */
type Metric = 'luc' | 'rag' | 'fairness'

/** Configuration for one metric tab. */
interface MetricTab {
  /** Metric selected when the tab is clicked. */
  key: Metric
  /** Text shown on the tab button. */
  label: string
  /** Caption for the y-axis title while this tab is active. */
  yLabel: string
}

/** Metric tabs, in the order they are displayed. */
const TABS: MetricTab[] = [
  {
    key: 'luc',
    label: 'Refusal Rate',
    yLabel: 'Score (Higher is Better)',
  },
  {
    key: 'rag',
    label: 'RAG Score',
    yLabel: 'Score (Higher is Better)',
  },
  {
    key: 'fairness',
    label: 'Fairness',
    yLabel: 'Wasserstein Distance (Lower is Better)',
  },
]
/**
 * Reads the averaged score for `metric` from a model row.
 *
 * The stored `avg` is returned as-is, so a `null` (no recorded score) is
 * passed straight through to the caller.
 */
function getValue(m: ModelData, metric: Metric): number | null {
  switch (metric) {
    case 'luc':
      return m.luc.avg
    case 'rag':
      return m.rag.avg
    default:
      // 'fairness' — and, defensively, any other key — reads the fairness bucket.
      return m.fairness.avg
  }
}
// SVG layout, expressed in viewBox coordinate units.

// Overall drawing width.
const VW = 1100

// Horizontal bands: left margin (y-axis labels and title), bars, right margin.
const ML = 80 // left margin
const MR = 80 // right margin
const CHART_W = VW - ML - MR // width shared by all bars and gaps

// Vertical bands, top to bottom: margin, bars, rotated model-name labels.
const MT = 14 // top margin
const BAR_H = 300 // bar area height
const LBL_H = 150 // rotated label area
const VH = MT + BAR_H + LBL_H // overall drawing height

// Horizontal space between neighbouring bars.
const GAP = 4

// Gridline positions as fractions of BAR_H: 0, 0.25, 0.5, 0.75, 1.
const GRID = Array.from({ length: 5 }, (_, i) => i / 4)
/**
 * Interactive bar chart ranking non-archived models on a single metric.
 *
 * - Creator chips filter the bars; when no chip is selected every creator is shown.
 * - Metric tabs switch between the entries in `TABS`.
 * - Bars are sorted best-first: descending for the `luc` / `rag` scores and
 *   ascending for `fairness`, whose axis is labelled "Lower is Better".
 *
 * @param models      All model rows. Rows with a truthy `archived` flag are skipped,
 *                    as are rows without a finite value for the selected metric.
 * @param maxFairness Top of the y-axis while the fairness tab is active. The other
 *                    metrics use a fixed 0–1 axis, labelled 0–100. Values beyond the
 *                    axis maximum are clipped to the full bar height; values below 0
 *                    (or a non-positive axis maximum) produce zero-height bars.
 */
export default function BarChartSection({
  models,
  maxFairness,
}: {
  models: ModelData[]
  maxFairness: number
}) {
  const [metric, setMetric] = useState<Metric>('luc')
  const [activeCreators, setCreators] = useState<Set<string>>(new Set())

  const allCreators = useMemo(
    () => [...new Set(models.filter(m => !m.archived).map(m => m.creator))].sort(),
    [models],
  )

  // Build a fresh Set each time so React registers the state change.
  const toggleCreator = (c: string) =>
    setCreators(prev => {
      const next = new Set(prev)
      if (next.has(c)) next.delete(c)
      else next.add(c)
      return next
    })

  // Visible models paired with their value for the active metric. Resolving the
  // value once here means the render code never needs a `!` assertion.
  const rows = useMemo(() => {
    const out: { model: ModelData; val: number }[] = []
    for (const model of models) {
      if (model.archived) continue
      if (activeCreators.size > 0 && !activeCreators.has(model.creator)) continue
      const val = getValue(model, metric)
      // `!= null` also rejects `undefined`; NaN/Infinity cannot be drawn as a bar.
      if (val != null && Number.isFinite(val)) out.push({ model, val })
    }
    // Best first: fairness is a distance (lower is better), the others are scores.
    return out.sort((a, b) => (metric === 'fairness' ? a.val - b.val : b.val - a.val))
  }, [models, metric, activeCreators])

  const maxVal = metric === 'fairness' ? maxFairness : 1
  const n = rows.length
  // Bars split the plot width with a fixed GAP between neighbours; guard n = 0.
  const bw = n > 0 ? (CHART_W - (n - 1) * GAP) / n : 0

  // Share of the bar area a value fills, clamped to [0, 1] so a negative value or a
  // non-positive axis maximum never yields a negative or NaN <rect> height.
  const barFraction = (val: number) =>
    maxVal > 0 ? Math.min(1, Math.max(0, val / maxVal)) : 0

  // The legend lists only creators that currently have a bar on screen.
  const legendCreators = useMemo(
    () => [...new Set(rows.map(r => r.model.creator))].sort(),
    [rows],
  )

  // Every Metric value has an entry in TABS, so this lookup always succeeds.
  const tab = TABS.find(t => t.key === metric)!

  return (
    <div>
      {/* Filters row */}
      <div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', gap: '0.75rem', marginBottom: '1rem', flexWrap: 'wrap' }}>
        {/* Creator chips (toggle buttons) */}
        <div style={{ display: 'flex', flexWrap: 'wrap', gap: 5 }}>
          {allCreators.map(c => {
            const cc = creatorColor(c)
            const on = activeCreators.has(c)
            return (
              <button
                key={c}
                type="button"
                aria-pressed={on}
                onClick={() => toggleCreator(c)}
                style={{
                  height: 26, padding: '0 10px',
                  border: `1.5px solid ${on ? cc : 'var(--border-1)'}`,
                  borderRadius: 5, fontSize: 10, fontFamily: 'inherit', fontWeight: 700,
                  color: on ? cc : 'var(--text-2)',
                  // `${cc}18` appends an alpha byte; assumes creatorColor returns a
                  // 6-digit hex colour — TODO confirm.
                  background: on ? `${cc}18` : 'var(--bg-0)',
                  cursor: 'pointer',
                  transition: 'all 0.15s',
                  display: 'inline-flex', alignItems: 'center', gap: 5,
                }}
              >
                <span style={{ display: 'inline-flex', alignItems: 'center', lineHeight: 0 }}>
                  <LabLogo creator={c} size={14} />
                </span>
                {c}
              </button>
            )
          })}
        </div>
        {/* Metric tabs (segmented control) */}
        <div style={{ display: 'inline-flex', border: '1.5px solid var(--border-1)', borderRadius: 7, overflow: 'hidden', flexShrink: 0 }}>
          {TABS.map((t, i) => {
            const selected = metric === t.key
            return (
              <button
                key={t.key}
                type="button"
                aria-pressed={selected}
                onClick={() => setMetric(t.key)}
                style={{
                  height: 28, padding: '0 13px',
                  // Longhands only (no `border` shorthand), so no conflicting
                  // shorthand/longhand pair exists. The left edge doubles as the
                  // divider between adjacent tabs.
                  borderTop: 'none', borderRight: 'none', borderBottom: 'none',
                  borderLeft: i > 0 ? '1px solid var(--border-1)' : 'none',
                  outline: 'none',
                  fontSize: 10, fontFamily: 'inherit', fontWeight: 700,
                  letterSpacing: '0.05em', textTransform: 'uppercase',
                  color: selected ? 'white' : 'var(--text-2)',
                  background: selected ? 'var(--accent)' : 'transparent',
                  cursor: 'pointer', whiteSpace: 'nowrap',
                  transition: 'background 0.12s, color 0.12s',
                }}
              >
                {t.label}
              </button>
            )
          })}
        </div>
      </div>
      {/* Chart container */}
      <div style={{
        background: 'var(--bg-1)',
        border: '1.5px solid var(--border-1)',
        borderRadius: 10,
        padding: '1rem 1rem 0',
        boxShadow: '0 1px 4px rgba(0,0,0,0.05)',
      }}>
        <svg
          viewBox={`0 0 ${VW} ${VH}`}
          style={{ width: '100%', height: 'auto', display: 'block' }}
          role="img"
          aria-label={`Bar chart: ${tab.label}`}
        >
          {/* Grid lines + y-axis labels */}
          {GRID.map(lvl => {
            const y = MT + BAR_H - lvl * BAR_H
            const display = metric === 'fairness'
              ? (lvl * maxVal).toFixed(2)
              : `${Math.round(lvl * 100)}`
            return (
              <g key={lvl}>
                <line
                  x1={ML} y1={y} x2={VW - MR} y2={y}
                  stroke={lvl === 0 ? 'var(--border-2)' : 'var(--border-0)'}
                  strokeWidth={lvl === 0 ? 1 : 0.8}
                  strokeDasharray={lvl === 0 || lvl === 1 ? undefined : '4 4'}
                />
                <text
                  x={ML - 5} y={y + 3.5}
                  textAnchor="end"
                  fontSize={9}
                  fill="var(--text-3)"
                  fontFamily="inherit"
                >
                  {display}
                </text>
              </g>
            )
          })}
          {/* Y-axis title, rotated to read bottom-to-top along the left edge */}
          <text
            x={11}
            y={MT + BAR_H / 2}
            textAnchor="middle"
            fontSize={8.5}
            fill="var(--text-3)"
            fontFamily="inherit"
            transform={`rotate(-90, 11, ${MT + BAR_H / 2})`}
          >
            {tab.yLabel}
          </text>
          {/* Empty state: no model survives the filters for this metric */}
          {n === 0 && (
            <text
              x={ML + CHART_W / 2}
              y={MT + BAR_H / 2}
              textAnchor="middle"
              fontSize={12}
              fill="var(--text-3)"
              fontFamily="inherit"
            >
              No models have data for the current selection
            </text>
          )}
          {/* Bars + labels */}
          {rows.map(({ model: m, val }, i) => {
            const barHeight = barFraction(val) * BAR_H
            const x = ML + i * (bw + GAP)
            const y = MT + BAR_H - barHeight
            const cc = creatorColor(m.creator)
            const valLabel = metric === 'fairness'
              ? val.toFixed(3)
              : `${(val * 100).toFixed(0)}%`
            const nameX = x + bw / 2
            const nameY = MT + BAR_H + 5
            return (
              <g key={`${m.rank}-${metric}`}>
                {/* Bar */}
                <rect
                  x={x} y={y}
                  width={bw} height={barHeight}
                  fill={cc} fillOpacity={0.88}
                  rx={2}
                />
                {/* Value above bar, drawn only for bars taller than 18 units */}
                {barHeight > 18 && (
                  <text
                    x={x + bw / 2} y={y - 4}
                    textAnchor="middle"
                    fontSize={8}
                    fill={cc}
                    fontFamily="inherit"
                    fontWeight="bold"
                  >
                    {valLabel}
                  </text>
                )}
                {/* Model name, slanted 45° so it runs down-left into the label band */}
                <text
                  x={nameX}
                  y={nameY}
                  transform={`rotate(-45, ${nameX}, ${nameY})`}
                  textAnchor="end"
                  fontSize={10}
                  fill="var(--text-1)"
                  fontFamily="inherit"
                  fontWeight="600"
                >
                  {m.model}
                </text>
              </g>
            )
          })}
        </svg>
        {/* Creator legend */}
        <div style={{
          display: 'flex',
          flexWrap: 'wrap',
          gap: '8px 20px',
          justifyContent: 'center',
          padding: '12px 8px',
          borderTop: '1px solid var(--border-0)',
          marginTop: 4,
        }}>
          {legendCreators.map(c => (
            <div key={c} style={{ display: 'flex', alignItems: 'center', gap: 6, fontSize: 12, color: 'var(--text-2)', fontWeight: 600 }}>
              <div style={{ width: 11, height: 11, borderRadius: 2, background: creatorColor(c), flexShrink: 0 }} />
              {c}
            </div>
          ))}
        </div>
      </div>
    </div>
  )
}