import { useState, useEffect, useCallback, useMemo } from "react";
import { HF_ORG } from "../../../config";

// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

/** Props accepted by the viewer component. */
interface HeatmapViewerProps {
  datasetRepo: string;
  split?: string;
  onClose: () => void;
}

/** One row of the summary dataset, as read from the parquet file. */
interface SummaryRow {
  model: string;
  hf_model_name: string;
  local_model_name: string;
  task: string;
  split: string;
  self_tc: boolean;
  neg_tc: boolean;
  gpt2_tc: boolean;
  finetuned: boolean;
  training_config: string;
  eval_variant: string;
  gen_roc: number | null;
  val_roc: number | null;
  val_acc: number | null;
  corr: number | null;
  corr_pos: number | null;
  corr_neg: number | null;
  n_samples: number;
  filename: string;
}

type Metric = "gen_roc" | "val_roc" | "val_acc" | "corr" | "corr_pos" | "corr_neg";
type EvalVariant = "raw" | "tc" | "lenorm" | "tc+lenorm";
type TCType = "none" | "self" | "neg" | "gpt2";
type ViewMode = "heatmap" | "bar";
type DomainFilter = "all" | "ood" | "in-domain";
type ComparisonPreset =
  | "all"
  | "training-effect"
  | "plus-tcself"
  | "tcself-vs-tcneg"
  | "tcself-vs-tcgpt2";

// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------

const HF_DATASETS_API = "https://datasets-server.huggingface.co";

const METRICS: { key: Metric; label: string }[] = [
  { key: "gen_roc", label: "Gen ROC" },
  { key: "val_roc", label: "Val ROC" },
  { key: "val_acc", label: "Val Acc" },
  { key: "corr", label: "Correlation" },
  { key: "corr_pos", label: "Corr (pos)" },
  { key: "corr_neg", label: "Corr (neg)" },
];

const EVAL_VARIANTS: EvalVariant[] = ["raw", "tc", "lenorm", "tc+lenorm"];

const TC_TYPES: { key: TCType; label: string }[] = [
  { key: "self", label: "Self TC" },
  { key: "neg", label: "Neg TC" },
  { key: "gpt2", label: "GPT-2 TC" },
];

// Plotly default color cycle for bar charts
const BAR_COLORS = [
  "#636EFA",
  "#EF553B",
  "#00CC96",
  "#AB63FA",
  "#FFA15A",
  "#19D3F3",
  "#FF6692",
  "#B6E880",
  "#FF97FF",
  "#FECB52",
];

const TASK_FAMILIES = [
  { key: "hypernym", label: "Hypernym", prefix: "hypernym-" },
  { key: "ifeval", label: "IFEval", prefix: "ifeval-" },
  { key: "plausibleqa", label: "PlausibleQA", prefix: "plausibleqa-" },
  { key: "ambigqa", label: "AmbigQA", prefix: "ambigqa-" },
];

/**
 * Decide whether a row's model belongs in the view.
 *
 * Non-finetuned rows always pass. Finetuned models must use force-same-x and
 * the correct combined training dataset; hypernym and IFEval have single-task
 * variants that must be excluded.
 *
 * @param row Summary row whose `finetuned` / `training_config` are inspected.
 * @returns `true` when the row may be shown.
 */
function isValidFinetunedModel(row: SummaryRow): boolean {
  if (!row.finetuned) return true;
  const tc = row.training_config;
  if (!tc.includes("force-same-x")) return false;
  // Hypernym models must be trained on the combined "double" dataset
  if (tc.includes("hypernym-") && !tc.includes("hypernym-concat-bananas-to-dogs-double")) return false;
  // IFEval models must be trained on the concat dataset
  if (tc.includes("ifeval-") && !tc.includes("ifeval-concat")) return false;
  return true;
}

const COMPARISON_PRESETS: { key: ComparisonPreset; label: string }[] = [
  { key: "all", label: "All models" },
  { key: "training-effect", label: "Training Effect" },
  { key: "plus-tcself", label: "+ TC-Self" },
  { key: "tcself-vs-tcneg", label: "TC-Self vs TC-Neg" },
  { key: "tcself-vs-tcgpt2", label: "TC-Self vs TC-GPT2" },
];

// Preset row sets:
// Training Effect: Base + lo/semi × {Pref, SFT, Comb-v}
const TRAINING_EFFECT_ROWS = new Set([
  "Base",
  "Pref-lo",
  "SFT-lo",
  "Comb-v-lo",
  "Pref-semi",
  "SFT-semi",
  "Comb-v-semi",
]);

const DEFAULT_HIDDEN_ROWS = new Set([
  "Pref-v-lo",
  "Pref-tcself-v-lo",
  "Pref-tcself-norm-lo",
  "Pref-tcself-norm-v-lo",
  "Comb-tcself-norm-v-lo",
  "SFT-tcself-norm-semi",
  "Comb-tcself-norm-v-semi",
  // PlausibleQA models without lo/semi — not comparable
  "Comb-tcself",
  "Comb-tcself-v",
  "SFT-tcself-v",
]);

// Strip a TC flag from a row label to get the "base" label for pairing
const TC_FLAGS = ["tco", "tcself", "tcneg"];

/**
 * Remove every TC flag segment (see TC_FLAGS) from a "-"-separated row label.
 *
 * @param label Row label such as "Pref-tcself-v-lo".
 * @returns The label with TC segments dropped, e.g. "Pref-v-lo".
 */
function stripTCFlag(label: string): string {
  const parts = label.split("-");
  return parts.filter((p) => !TC_FLAGS.includes(p)).join("-");
}

/**
 * Test whether `flag` appears as a whole "-"-separated segment of `label`
 * (a substring match inside a longer segment does not count).
 */
function hasTCFlag(label: string, flag: string): boolean {
  return label.split("-").includes(flag);
}

// IFEval prompts 1-21 are OOD (test-only, never trained on)
const IFEVAL_OOD_MAX = 21;

// Hypernym: fixed in-domain / OOD split
const HYPERNYM_IN_DOMAIN = new Set([
  "bananas",
  "bazookas",
  "cabinets",
  "cars",
  "chairs",
  "crows",
  "diapers",
  "dogs",
]);

const HYPERNYM_OOD = new Set([
  "ducklings",
  "elephants",
  "guns",
  "hammers",
  "helmets",
  "jackets",
  "kayaks",
  "kites",
  "mirrors",
]);

const HYPERNYM_VALID = new Set([...HYPERNYM_IN_DOMAIN, ...HYPERNYM_OOD]);

/**
 * Valid eval task patterns — excludes training tasks like concat, bare family
 * names, etc.
 *
 * - `hypernym-<x>`: `<x>` must be one of the known in-domain / OOD subtasks.
 * - `ifeval-...`: must be exactly `ifeval-prompt_<n>` or `ifeval-prompt-<n>`.
 * - `plausibleqa-<x>` / `ambigqa-<x>`: `<x>` must be non-empty.
 * - Bare family names are training tasks and are rejected.
 * - Anything else is accepted.
 *
 * @param task Task identifier from a summary row.
 * @returns `true` when the task should be treated as an evaluation task.
 */
function isValidEvalTask(task: string): boolean {
  if (task.startsWith("hypernym-")) {
    return HYPERNYM_VALID.has(task.slice("hypernym-".length));
  }
  if (task.startsWith("ifeval-")) {
    return /^ifeval-prompt[_-]\d+$/.test(task);
  }
  for (const prefix of ["plausibleqa-", "ambigqa-"]) {
    // Must have a subtask: the previous `task !== "<family>"` comparison could
    // never fail once the prefix (with its trailing dash) had matched, so a
    // bare "plausibleqa-" slipped through.
    if (task.startsWith(prefix)) return task.length > prefix.length;
  }
  // Bare family names without subtask suffix are training tasks
  if (task === "ambigqa" || task === "plausibleqa" || task === "hypernym" || task === "ifeval") {
    return false;
  }
  return true;
}

// ---------------------------------------------------------------------------
// Color scale: RdYlGn for [0, 1] mapped to percentages
// ---------------------------------------------------------------------------

// Plotly RdYlGn colorscale stops (matches Plotly's built-in)
const RDYLGN_STOPS: [number, [number, number, number]][] = [
  [0.0, [165, 0, 38]],
  [0.1, [215, 48, 39]],
  [0.2, [244, 109, 67]],
  [0.3, [253, 174, 97]],
  [0.4, [254, 224, 139]],
  [0.5, [255, 255, 191]],
  [0.6, [217, 239, 139]],
  [0.7, [166, 217, 106]],
  [0.8, [102, 189, 99]],
  [0.9, [26, 152, 80]],
  [1.0, [0, 104, 55]],
];
/**
 * Map `t` onto the RdYlGn scale by linear interpolation between the two
 * neighbouring entries of RDYLGN_STOPS.
 *
 * @param t Position on the scale; values outside [0, 1] are clamped.
 * @returns An `rgb(r, g, b)` string, or the light-gray no-data color for NaN.
 */
function interpolateRdYlGn(t: number): string {
  // NaN passes through Math.max/Math.min unchanged and fails every range
  // comparison below, so it used to fall out of the loop and be painted with
  // the top-of-scale dark green. Treat it as "no data" instead.
  if (Number.isNaN(t)) return "#f3f4f6";
  const clamped = Math.max(0, Math.min(1, t));
  // Find the two stops to interpolate between
  for (let i = 0; i < RDYLGN_STOPS.length - 1; i++) {
    const [t0, c0] = RDYLGN_STOPS[i];
    const [t1, c1] = RDYLGN_STOPS[i + 1];
    if (clamped >= t0 && clamped <= t1) {
      const f = (clamped - t0) / (t1 - t0);
      const r = Math.round(c0[0] + f * (c1[0] - c0[0]));
      const g = Math.round(c0[1] + f * (c1[1] - c0[1]));
      const b = Math.round(c0[2] + f * (c1[2] - c0[2]));
      return `rgb(${r}, ${g}, ${b})`;
    }
  }
  return `rgb(0, 104, 55)`;
}

/** Background color for a [0, 1] metric value; light gray when there is no data (NaN). */
function rdYlGn(value: number): string {
  if (isNaN(value)) return "#f3f4f6"; // light gray for no data
  return interpolateRdYlGn(value);
}

/** Background color for a correlation in [-1, 1], rescaled onto [0, 1] first. */
function corrColor(value: number): string {
  if (isNaN(value)) return "#f3f4f6";
  return rdYlGn((value + 1) / 2);
}

/**
 * Pick a readable text color for a cell whose background encodes `bgValue`.
 *
 * @param bgValue The metric value behind the background color.
 * @param isCorr  When true, `bgValue` is a correlation and is rescaled from
 *                [-1, 1] to [0, 1] before the threshold check.
 * @returns Gray for NaN, dark text for 0.3 < t < 0.7, white otherwise.
 */
function textColor(bgValue: number, isCorr: boolean): string {
  if (isNaN(bgValue)) return "#9ca3af";
  const t = isCorr ? (bgValue + 1) / 2 : bgValue;
  // Dark text on the light middle, white text on dark extremes
  return t > 0.3 && t < 0.7 ? "#1f2937" : "#ffffff";
}

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/** TC type used at evaluation time; the flags are checked in the order self, neg, gpt2. */
function getEvalTCType(row: SummaryRow): TCType {
  if (row.self_tc) return "self";
  if (row.neg_tc) return "neg";
  if (row.gpt2_tc) return "gpt2";
  return "none";
}

/**
 * TC type a model was trained with, read from substrings of `training_config`.
 * "tc-online"/"tc_online" maps to "gpt2"; checks run in the order online, self, neg.
 */
function getTrainingTCType(row: SummaryRow): TCType {
  const tc = row.training_config;
  if (tc.includes("tc-online") || tc.includes("tc_online")) return "gpt2";
  if (tc.includes("tc-self") || tc.includes("tc_self")) return "self";
  if (tc.includes("tc-neg") || tc.includes("tc_neg")) return "neg";
  return "none";
}

/** For finetuned models, check that eval TC matches training TC */
function isMatchedTC(row: SummaryRow): boolean {
  if (!row.finetuned) return true; // base models: any eval TC is valid
  const trainTC = getTrainingTCType(row);
  if (trainTC === "none") return true; // model not trained with TC: any eval TC is valid
  const evalTC = getEvalTCType(row);
  return evalTC === trainTC;
}

/** Key of the first TASK_FAMILIES entry whose prefix starts `task`, or null. */
function getTaskFamily(task: string): string | null {
  for (const fam of TASK_FAMILIES) {
    if (task.startsWith(fam.prefix)) return fam.key;
  }
  return null;
}

/** Extract the prompt number from an IFEval task name, or null if none is present. */
function parseIfevalPromptNum(task: string): number | null {
  // ifeval-prompt_42 or ifeval-prompt-42
  const m = task.match(/ifeval-prompt[_-](\d+)/);
  return m ? parseInt(m[1], 10) : null;
}

/**
 * Classify a task as out-of-domain for the given family.
 *
 * @param task           Task identifier.
 * @param trainingConfig Not read by the current implementation.
 * @param family         Task family key (see TASK_FAMILIES).
 * @returns true=OOD, false=in-domain, null=unknown/not applicable.
 */
function isOodTask(task: string, trainingConfig: string, family: string): boolean | null {
  if (family === "plausibleqa" || family === "ambigqa") return true; // always OOD
  if (family === "ifeval") {
    const num = parseIfevalPromptNum(task);
    if (num === null) return null;
    return num <= IFEVAL_OOD_MAX;
  }
  if (family === "hypernym") {
    const subtask = task.replace("hypernym-", "");
    if (HYPERNYM_OOD.has(subtask)) return true;
    if (HYPERNYM_IN_DOMAIN.has(subtask)) return false;
    return null; // unknown subtask
  }
  return null;
}

/** Key of the first TASK_FAMILIES entry whose prefix occurs anywhere in `config`, or null. */
function getTrainingFamily(config: string): string | null {
  for (const fam of TASK_FAMILIES) {
    if (config.includes(fam.prefix)) return fam.key;
  }
  return null;
}

/**
 * Parse the first capture group of `pattern` in `text` as a float.
 *
 * @returns The parsed number, or null when the pattern does not match, has no
 *          captured text, or the captured text is not numeric. Returning null
 *          rather than NaN keeps callers' `?? default` fallbacks effective.
 */
function extractFloat(pattern: RegExp, text: string): number | null {
  const captured = text.match(pattern)?.[1];
  if (captured === undefined) return null;
  const parsed = parseFloat(captured);
  return Number.isNaN(parsed) ? null : parsed;
}

/**
 * Derive the training mode label from the loss weights embedded in a config.
 *
 * @returns "SFT" for pref=0, nllv=1, nllg=1; "Comb" for pref=1, nllv=1,
 *          nllg=1; "Pref" for every other combination.
 */
function parseTrainingMode(config: string): string {
  // Parse pref/nllv/nllg weights from training config.
  // Defaults when absent: pref=1.0, nllv=0.0, nllg=0.0
  const pref = extractFloat(/(?:^|[_-])pref(\d+(?:\.\d+)?)/, config) ?? 1.0;
  const nllv = extractFloat(/nllv(\d+(?:\.\d+)?)/, config) ?? 0.0;
  const nllg = extractFloat(/nllg(\d+(?:\.\d+)?)/, config) ?? 0.0;
  const fullNll = nllv === 1.0 && nllg === 1.0;
  if (fullNll && pref === 0.0) return "SFT";
  if (fullNll && pref === 1.0) return "Comb";
  // Pure preference training and every unrecognised weight mix both map here.
  return "Pref";
}

/**
 * Build the "-"-joined row label for a training config:
 * mode, then TC flag, then "norm", "v", "lo", "semi" when present.
 *
 * NOTE(review): TC detection here requires a delimiter on both sides
 * ("_tc-self_" / "-tc-self-"), whereas getTrainingTCType matches the bare
 * substring (and also the "tc_self" spelling). Confirm whether configs can
 * satisfy one check but not the other.
 */
function buildRowLabel(config: string): string {
  const mode = parseTrainingMode(config);
  const flags: string[] = [];
  // TC flags are mutually exclusive
  const hasTco = config.includes("_tc-online_") || config.includes("-tc-online-");
  const hasTcself = config.includes("_tc-self_") || config.includes("-tc-self-");
  const hasTcneg = config.includes("_tc-neg_") || config.includes("-tc-neg-");
  const tcCount = [hasTco, hasTcself, hasTcneg].filter(Boolean).length;
  if (tcCount > 1) {
    console.warn(`Multiple TC flags in training config (expected at most 1): ${config}`);
  }
  if (hasTco) flags.push("tco");
  if (hasTcself) flags.push("tcself");
  if (hasTcneg) flags.push("tcneg");
  // Optional independent flags
  if (config.includes("_lenorm_") || config.includes("-lenorm-")) flags.push("norm");
  if (config.includes("_vallogodds") || config.includes("-vallogodds")) flags.push("v");
  // Data regime flags
  if (config.includes("labelonly")) flags.push("lo");
  if (config.includes("semi")) flags.push("semi");
  const parts = [mode, ...flags];
  return parts.join("-");
}

/** Row label for a summary row: "Base" for non-finetuned rows, otherwise derived from the training config. */
function getRowLabel(row: SummaryRow): string {
  if (!row.finetuned) return "Base";
  return buildRowLabel(row.training_config);
}

// Display-friendly row label: reorder to [regime, tc, mode, extras], spaces, Pref→RankAlign
// (the remainder of this definition sits on the following source line)
function displayRowLabel(label: string): string {
  if (label === "Base") return "Base";
  const parts = label.split("-");
  const mode = parts[0] === "Pref" ?
"RankAlign" : parts[0]; const flags = parts.slice(1); // Extract known flag groups const regime = flags.filter((f) => f === "lo" || f === "semi"); const tc = flags.filter((f) => f === "tcself" || f === "tcneg" || f === "tco"); const extras = flags.filter((f) => !["lo", "semi", "tcself", "tcneg", "tco"].includes(f)); // Order: regime, tc, mode, extras return [...regime, ...tc, mode, ...extras].join(" "); } // JSX version with colored TC flags for use in HTML contexts function DisplayRowLabel({ label }: { label: string }) { if (label === "Base") return <>Base; const parts = label.split("-"); const mode = parts[0] === "Pref" ? "RankAlign" : parts[0]; const flags = parts.slice(1); const regime = flags.filter((f) => f === "lo" || f === "semi"); const tc = flags.filter((f) => f === "tcself" || f === "tcneg" || f === "tco"); const extras = flags.filter((f) => !["lo", "semi", "tcself", "tcneg", "tco"].includes(f)); const tokens = [...regime, ...tc, mode, ...extras]; const tcColor: Record = { tcself: "#f87171", tcneg: "#fb923c", tco: "#a78bfa" }; return ( <> {tokens.map((t, i) => ( {i > 0 && " "} {tcColor[t] ? {t} : t} ))} ); } // SVG version with colored TC flags (uses tspan) function SvgRowLabel({ label, maxLen = 22 }: { label: string; maxLen?: number }) { if (label === "Base") return <>Base; const parts = label.split("-"); const mode = parts[0] === "Pref" ? "RankAlign" : parts[0]; const flags = parts.slice(1); const regime = flags.filter((f) => f === "lo" || f === "semi"); const tc = flags.filter((f) => f === "tcself" || f === "tcneg" || f === "tco"); const extras = flags.filter((f) => !["lo", "semi", "tcself", "tcneg", "tco"].includes(f)); const tokens = [...regime, ...tc, mode, ...extras]; const full = tokens.join(" "); const display = full.length > maxLen ? full.slice(0, maxLen - 2) + ".." 
: full; const tcColor: Record = { tcself: "#f87171", tcneg: "#fb923c", tco: "#a78bfa" }; // Re-tokenize the display string to color tc flags const displayTokens = display.split(" "); return ( <> {displayTokens.map((t, i) => ( {i > 0 && " "}{t} ))} ); } function mean(values: number[]): number { if (values.length === 0) return NaN; return values.reduce((a, b) => a + b, 0) / values.length; } function stdErr(values: number[]): number { if (values.length < 2) return 0; const m = mean(values); const variance = values.reduce((sum, v) => sum + (v - m) ** 2, 0) / (values.length - 1); return Math.sqrt(variance) / Math.sqrt(values.length); } // --------------------------------------------------------------------------- // Data fetching — downloads the parquet file directly (1 request, ~1MB) // --------------------------------------------------------------------------- function useDatasetRows(repo: string, _split: string) { const [rows, setRows] = useState([]); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const [progress, setProgress] = useState({ loaded: 0, total: 0 }); const fetchAll = useCallback(async () => { setLoading(true); setError(null); setRows([]); try { // Step 1: get the parquet file URL from HF datasets server const metaUrl = `${HF_DATASETS_API}/parquet?dataset=${encodeURIComponent(repo)}`; const metaResp = await fetch(metaUrl); if (!metaResp.ok) { throw new Error(`Failed to get parquet info: ${metaResp.status}`); } const metaData = await metaResp.json(); const parquetFiles = metaData.parquet_files ?? []; const matchingFile = parquetFiles.find((f: { split: string }) => f.split === _split) ?? parquetFiles[0]; if (!matchingFile) { throw new Error("No parquet files found for this dataset"); } setProgress({ loaded: 0, total: matchingFile.size ?? 
0 }); // Step 2: download the parquet file directly from HF (static file, no rate limits) const parquetUrl: string = matchingFile.url; const { asyncBufferFromUrl, parquetRead } = await import("hyparquet"); const file = await asyncBufferFromUrl({ url: parquetUrl }); // Step 3: parse all rows const allRows: SummaryRow[] = []; await parquetRead({ file, rowFormat: "object", onComplete: (data: Record[]) => { for (const row of data) { allRows.push(row as unknown as SummaryRow); } }, }); setProgress({ loaded: allRows.length, total: allRows.length }); setRows(allRows); } catch (err) { setError(err instanceof Error ? err.message : String(err)); } finally { setLoading(false); } }, [repo, _split]); useEffect(() => { fetchAll(); }, [fetchAll]); return { rows, loading, error, progress, refetch: fetchAll }; } // --------------------------------------------------------------------------- // Aggregation // --------------------------------------------------------------------------- interface AggCell { mean: number; se: number; n: number; } /** * Validate that each row label maps to exactly one model identity. * Checks BOTH the base model (r.model) AND training_config. * If two different models or configs produce the same row label, * we're silently averaging different models — a critical bug. * * Returns null if clean, or an error message string if collisions found. */ function validateRowLabels(rows: SummaryRow[], groupByRow: (r: SummaryRow) => string): string | null { // Check 1: all rows must come from the same base model const baseModels = new Set(rows.map((r) => r.model)); if (baseModels.size > 1) { return ( `DATA INTEGRITY ERROR: Multiple base models in the same view!\n` + `Found ${baseModels.size} different models being mixed together:\n` + Array.from(baseModels).map((m) => ` • ${m}`).join("\n") + `\n\nThis means results from different models are being averaged. 
` + `Filter by a single model before displaying.` ); } // Check 2: each row label maps to exactly one training_config const labelToConfigs = new Map>(); for (const row of rows) { const label = groupByRow(row); const config = row.finetuned ? row.training_config : "__base__"; if (!labelToConfigs.has(label)) labelToConfigs.set(label, new Set()); labelToConfigs.get(label)!.add(config); } const collisions: string[] = []; for (const [label, configs] of labelToConfigs) { if (configs.size > 1) { collisions.push( `"${label}" → ${configs.size} configs: ${Array.from(configs).join(", ")}` ); } } if (collisions.length > 0) { return ( `DATA INTEGRITY ERROR: Row label collisions detected!\n` + `The following labels map to multiple different model configs ` + `(their results are being silently averaged):\n\n` + collisions.join("\n") ); } return null; } interface AggResult { data: Map>; validationError: string | null; } function aggregateData( rows: SummaryRow[], metric: Metric, groupByRow: (r: SummaryRow) => string, ): AggResult { // Validate: each row label must correspond to one model const validationError = validateRowLabels(rows, groupByRow); // Group: rowLabel → evalVariant → values[] const groups = new Map>(); for (const row of rows) { const rl = groupByRow(row); const ev = row.eval_variant as EvalVariant; if (!EVAL_VARIANTS.includes(ev)) continue; const val = row[metric]; if (val === null || val === undefined || isNaN(val as number)) continue; if (!groups.has(rl)) groups.set(rl, new Map()); const evMap = groups.get(rl)!; if (!evMap.has(ev)) evMap.set(ev, []); evMap.get(ev)!.push(val as number); } // Compute aggregates const result = new Map>(); for (const [rl, evMap] of groups) { const aggMap = new Map(); for (const [ev, vals] of evMap) { aggMap.set(ev, { mean: mean(vals), se: stdErr(vals), n: vals.length }); } result.set(rl, aggMap); } return { data: result, validationError }; } // --------------------------------------------------------------------------- // Sort row 
labels: Base first, then alphabetical // --------------------------------------------------------------------------- // Sort order: Base first, then by [tc group, data regime, mode] // tc group: no-tc (0) < tcself (1) < tcneg (2) < tco (3) // data regime: lo (0) < semi (1) < other (2) // mode: Pref (0) < SFT (1) < Comb (2) < other (3) function rowSortKey(label: string): [number, number, number, number, string] { if (label === "Base") return [-1, 0, 0, 0, ""]; const parts = label.split("-"); // TC group let tcGroup = 0; if (parts.includes("tcself")) tcGroup = 1; else if (parts.includes("tcneg")) tcGroup = 2; else if (parts.includes("tco")) tcGroup = 3; // Data regime let regime = 2; if (parts.includes("lo")) regime = 0; else if (parts.includes("semi")) regime = 1; // Mode (first part) const mode = parts[0] ?? ""; const modeOrder: Record = { Pref: 0, SFT: 1, Comb: 2 }; // Remaining flags for tiebreak const remaining = parts.filter((p) => !["Pref", "SFT", "Comb"].includes(p)).join("-"); return [0, tcGroup, regime, modeOrder[mode] ?? 3, remaining]; } function sortRowLabels(labels: string[]): string[] { return [...labels].sort((a, b) => { const ka = rowSortKey(a); const kb = rowSortKey(b); for (let i = 0; i < 4; i++) { if (ka[i] !== kb[i]) return (ka[i] as number) - (kb[i] as number); } return ka[4].localeCompare(kb[4]); }); } // --------------------------------------------------------------------------- // Sub-components // --------------------------------------------------------------------------- function Dropdown({ label, value, options, onChange, }: { label: string; value: T; options: { key: T; label: string }[]; onChange: (v: T) => void; }) { return (
); } function MultiSelect({ label, selected, options, onChange, }: { label: string; selected: Set; options: { key: T; label: string }[]; onChange: (s: Set) => void; }) { const toggle = (key: T) => { const next = new Set(selected); if (next.has(key)) next.delete(key); else next.add(key); onChange(next); }; return (
{options.map((o) => ( ))}
); } // --------------------------------------------------------------------------- // Heatmap component (pure HTML/CSS) // --------------------------------------------------------------------------- function HeatmapGrid({ data, rowLabels, colLabels, metric, title, fullConfigs, }: { data: Map>; rowLabels: string[]; colLabels: EvalVariant[]; metric: Metric; title: string; fullConfigs: Map; }) { const isCorr = metric.startsWith("corr"); const formatVal = (v: number) => { if (isNaN(v)) return "-"; return isCorr ? v.toFixed(3) : (v * 100).toFixed(1); }; const colorFn = isCorr ? corrColor : rdYlGn; return (

{title}

{colLabels.map((col) => ( ))} {rowLabels.map((rl) => { const evMap = data.get(rl); return ( {colLabels.map((col) => { const cell = evMap?.get(col); const val = cell?.mean ?? NaN; const bg = colorFn(isCorr ? val : val); const fg = textColor(val, isCorr); return ( ); })} ); })}
Model {col}
1 ? ` (n=${cell.n}, se=${formatVal(cell.se)})` : ""}` : "no data"} > {formatVal(val)}
); } // --------------------------------------------------------------------------- // Bar chart component (SVG) // --------------------------------------------------------------------------- function BarChart({ data, rowLabels, evalVariants, metric, title, }: { data: Map>; rowLabels: string[]; evalVariants: EvalVariant[]; metric: Metric; title: string; }) { const isCorr = metric.startsWith("corr"); const subBarWidth = 18; const subGap = 2; const groupGap = 16; const numSub = evalVariants.length; const groupWidth = numSub * subBarWidth + (numSub - 1) * subGap; const chartHeight = 200; const marginTop = 20; const marginBottom = 80; const marginLeft = 50; const legendHeight = 24; const svgWidth = marginLeft + rowLabels.length * (groupWidth + groupGap) + 20; const svgHeight = chartHeight + marginTop + marginBottom + legendHeight; // Collect all values for scale const allCells: { mean: number; se: number }[] = []; for (const rl of rowLabels) { for (const ev of evalVariants) { const cell = data.get(rl)?.get(ev); if (cell && !isNaN(cell.mean)) allCells.push(cell); } } if (allCells.length === 0) { return (

{title}

No data

); } let minVal: number, maxVal: number; const low = Math.min(...allCells.map((c) => c.mean - c.se)); const high = Math.max(...allCells.map((c) => c.mean + c.se)); if (isCorr) { // Data-driven range with padding, clamped to [-1, 1] minVal = Math.max(-1, low - 0.05); maxVal = Math.min(1, high + 0.05); } else { minVal = Math.max(0, low - 0.05); maxVal = Math.min(1, high + 0.05); } const range = maxVal - minVal || 1; const yScale = (v: number) => marginTop + chartHeight * (1 - (v - minVal) / range); // Tick marks const ticks: number[] = []; const corrRange = maxVal - minVal; const step = isCorr ? (corrRange > 0.5 ? 0.2 : 0.1) : 0.1; for (let t = Math.ceil(minVal / step) * step; t <= maxVal; t += step) { ticks.push(Math.round(t * 1000) / 1000); } // Shade: raw=full, tc=darker, lenorm=lighter, tc+lenorm=darkest const variantOpacity: Record = { raw: 0.85, tc: 0.6, lenorm: 0.7, "tc+lenorm": 0.45 }; return (

{title}

{/* Hatching patterns for tc variants */} {/* Grid lines */} {ticks.map((t) => ( {isCorr ? t.toFixed(1) : (t * 100).toFixed(0)} ))} {/* Grouped bars */} {rowLabels.map((rl, i) => { const groupX = marginLeft + i * (groupWidth + groupGap); const color = BAR_COLORS[i % BAR_COLORS.length]; const labelX = groupX + groupWidth / 2; return ( {evalVariants.map((ev, j) => { const cell = data.get(rl)?.get(ev); const val = cell?.mean ?? NaN; if (isNaN(val)) return null; const x = groupX + j * (subBarWidth + subGap); const barY = yScale(val); const baseY = yScale(isCorr ? 0 : minVal); const barH = Math.abs(baseY - barY); const actualY = Math.min(barY, baseY); const opacity = variantOpacity[ev] ?? 0.85; const hasHatch = ev === "tc" || ev === "tc+lenorm"; const cx = x + subBarWidth / 2; const se = cell?.se ?? 0; const seTop = yScale(Math.min(maxVal, val + se)); const seBot = yScale(Math.max(minVal, val - se)); return ( {hasHatch && ( )} {se > 0 && ( <> )} {`${displayRowLabel(rl)} [${ev}]: ${isCorr ? val.toFixed(3) : (val * 100).toFixed(1)} (n=${cell?.n ?? 0}, se=${se.toFixed(4)})`} ); })} {/* Model label */} ); })} {/* Legend */} {evalVariants.map((ev, j) => { const lx = marginLeft + j * 80; const ly = svgHeight - 10; const opacity = variantOpacity[ev] ?? 0.85; const hasHatch = ev === "tc" || ev === "tc+lenorm"; return ( {hasHatch && } {ev} ); })}
); } // --------------------------------------------------------------------------- // Delta comparison view: side-by-side heatmaps + delta // --------------------------------------------------------------------------- // Diverging color scale for deltas: red (negative) ↔ white (zero) ↔ blue (positive) // Uses Plotly's RdBu stops (reversed so blue = positive) const DELTA_NEG: [number, number, number][] = [ [103, 0, 31], // -max: deep red [178, 24, 43], [214, 96, 77], [244, 165, 130], [253, 219, 199], [255, 255, 255], // zero: white ]; const DELTA_POS: [number, number, number][] = [ [255, 255, 255], // zero: white [209, 229, 240], [146, 197, 222], [67, 147, 195], [33, 102, 172], [5, 48, 97], // +max: deep blue ]; function deltaColor(delta: number): string { if (isNaN(delta)) return "#f3f4f6"; const maxDelta = 0.10; // saturate at ±10pp const t = Math.min(Math.abs(delta) / maxDelta, 1); // 0..1 const stops = delta < 0 ? DELTA_NEG : DELTA_POS; // Interpolate across 6 stops (indices 0..5), t maps to position in stops const pos = t * (stops.length - 1); const i = Math.min(Math.floor(pos), stops.length - 2); const f = pos - i; // For negative, go from white (index 5) toward deep red (index 0) → reverse index const idx = delta < 0 ? (stops.length - 1) - i : i; const idxNext = delta < 0 ? idx - 1 : idx + 1; const c0 = stops[idx]; const c1 = stops[Math.max(0, Math.min(stops.length - 1, idxNext))]; const r = Math.round(c0[0] + f * (c1[0] - c0[0])); const g = Math.round(c0[1] + f * (c1[1] - c0[1])); const b = Math.round(c0[2] + f * (c1[2] - c0[2])); return `rgb(${r}, ${g}, ${b})`; } function deltaTextColor(delta: number): string { if (isNaN(delta)) return "#9ca3af"; const t = Math.min(Math.abs(delta) / 0.10, 1); return t > 0.6 ? 
"#ffffff" : "#1f2937"; } function SideBySideDelta({ pairs, leftData, rightData, colLabels, metric, title, leftLabel, rightLabel, fullConfigs, }: { pairs: [string, string][]; // [leftRowLabel, rightRowLabel] leftData: Map>; rightData: Map>; colLabels: EvalVariant[]; metric: Metric; title: string; leftLabel: string; rightLabel: string; fullConfigs: Map; }) { const isCorr = metric.startsWith("corr"); const formatVal = (v: number) => { if (isNaN(v)) return "-"; return isCorr ? v.toFixed(3) : (v * 100).toFixed(1); }; const formatDelta = (v: number) => { if (isNaN(v)) return "-"; const pp = isCorr ? v : v * 100; const sign = pp > 0 ? "+" : ""; return isCorr ? `${sign}${pp.toFixed(3)}` : `${sign}${pp.toFixed(1)}`; }; const colorFn = isCorr ? corrColor : rdYlGn; // Build pair labels (strip the tc flag difference for a clean display) const pairLabels = pairs.map(([l]) => l); return (

{title}

{/* Left heatmap */}

{leftLabel}

{colLabels.map((col) => ( ))} {pairs.map(([leftLabel_]) => { const evMap = leftData.get(leftLabel_); return ( {colLabels.map((col) => { const cell = evMap?.get(col); const val = cell?.mean ?? NaN; const bg = colorFn(val); const fg = textColor(val, isCorr); return ( ); })} ); })}
Model{col}
{formatVal(val)}
{/* Right heatmap */}

{rightLabel}

{colLabels.map((col) => ( ))} {pairs.map(([, rightLabel_]) => { const evMap = rightData.get(rightLabel_); return ( {colLabels.map((col) => { const cell = evMap?.get(col); const val = cell?.mean ?? NaN; const bg = colorFn(val); const fg = textColor(val, isCorr); return ( ); })} ); })}
Model{col}
{formatVal(val)}
{/* Delta heatmap */}

Delta ({rightLabel} − {leftLabel})

{colLabels.map((col) => ( ))} {pairs.map(([leftLabel_, rightLabel_]) => { const leftEvMap = leftData.get(leftLabel_); const rightEvMap = rightData.get(rightLabel_); return ( {colLabels.map((col) => { const leftVal = leftEvMap?.get(col)?.mean ?? NaN; const rightVal = rightEvMap?.get(col)?.mean ?? NaN; const delta = (!isNaN(leftVal) && !isNaN(rightVal)) ? rightVal - leftVal : NaN; const bg = deltaColor(delta); const fg = deltaTextColor(delta); return ( ); })} ); })}
Model{col}
l === leftLabel_)]} /> {formatDelta(delta)}
); } // --------------------------------------------------------------------------- // Per-task collapsible section // --------------------------------------------------------------------------- function PerTaskCollapsible({ sections, fullConfigs, visibleRows, }: { sections: { label: string; taskCount: number; metrics: { metric: Metric; data: Map> }[] }[]; fullConfigs: Map; visibleRows: Set; }) { const [open, setOpen] = useState(false); return (
{open && (
{sections.map((section) => (

{section.label}

{section.metrics.map(({ metric: m, data }) => { const metricLabel = METRICS.find((x) => x.key === m)?.label ?? m; const metricRowLabels = sortRowLabels(Array.from(data.keys()).filter((l) => visibleRows.has(l))); const metricEvs = EVAL_VARIANTS.filter((ev) => { for (const evMap of data.values()) { if (evMap.has(ev)) return true; } return false; }); if (metricRowLabels.length === 0) return null; return ( ); })}
))}
)}
); } // --------------------------------------------------------------------------- // Per-task comparison collapsible (side-by-side + delta per task) // --------------------------------------------------------------------------- function PerTaskComparisonCollapsible({ filteredRows, pairs, leftLabel, rightLabel, leftFilter, rightFilter, fullConfigs, }: { filteredRows: SummaryRow[]; pairs: [string, string][]; leftLabel: string; rightLabel: string; leftFilter: (r: SummaryRow) => boolean; rightFilter: (r: SummaryRow) => boolean; fullConfigs: Map; }) { const [open, setOpen] = useState(false); const tasks = useMemo(() => { const t = new Set(filteredRows.map((r) => r.task)); return Array.from(t).sort(); }, [filteredRows]); if (tasks.length <= 1) return null; return (
{open && (
{tasks.map((task) => { const taskRows = filteredRows.filter((r) => r.task === task); const leftRows = taskRows.filter(leftFilter); const rightRows = taskRows.filter(rightFilter); return (

{task}

{METRICS.map((m) => { const leftData = aggregateData(leftRows, m.key, getRowLabel).data; const rightData = aggregateData(rightRows, m.key, getRowLabel).data; const existingPairs = pairs.filter(([l, r]) => leftData.has(l) && rightData.has(r)); if (existingPairs.length === 0) return null; const evs = EVAL_VARIANTS.filter((ev) => { for (const evMap of [...leftData.values(), ...rightData.values()]) { if (evMap.has(ev)) return true; } return false; }); return ( ); })}
); })}
)}
); }

// ---------------------------------------------------------------------------
// Main component
// ---------------------------------------------------------------------------

/**
 * Viewer for a summary dataset: loads rows through `useDatasetRows`, narrows
 * them with the filter state below, aggregates them with `aggregateData`, and
 * renders either heatmaps or bar charts depending on `viewMode`.
 *
 * NOTE(review): in this copy of the file everything between angle brackets
 * appears to be missing — the JSX elements in the render section and generic
 * type arguments (e.g. `useState | null>(null)`, `Map> }[]`). The code is
 * reproduced token-for-token; restore the markup/generics from the original
 * before relying on it.
 *
 * @param datasetRepo Repository name. Used as-is when it contains "/",
 *   otherwise prefixed with `HF_ORG`.
 * @param split Split name forwarded to `useDatasetRows`; defaults to "train".
 *   This is separate from the `selectedSplit` filter state, which filters on
 *   each row's own `split` column.
 * @param onClose Not referenced in the visible code — presumably wired to a
 *   close control in the missing markup; TODO confirm.
 */
export default function HeatmapViewer({
  datasetRepo,
  split: _split = "train",
  onClose,
}: HeatmapViewerProps) {
  // Bare repo names are qualified with the org; names containing "/" are kept.
  const fullRepo = datasetRepo.includes("/") ? datasetRepo : `${HF_ORG}/${datasetRepo}`;
  // Last path segment of the repo name; shown in the header title.
  const shortName = datasetRepo.split("/").pop() ?? datasetRepo;

  // `useDatasetRows` is defined elsewhere in this file.
  const { rows, loading, error, progress, refetch } = useDatasetRows(fullRepo, _split);

  // --- Filter state ---
  // "__first__" is a sentinel resolved to a concrete model by `resolvedModel` below.
  const [selectedModel, setSelectedModel] = useState("__first__");
  const [selectedFamily, setSelectedFamily] = useState("hypernym");
  // "__all__" means "do not restrict to a single task".
  const [selectedTask, setSelectedTask] = useState("__all__");
  // NOTE(review): unlike the model selection, this default is never reconciled
  // with `availableSplits`; if no row has split "test", `filteredRows` is empty
  // until the user picks another split — confirm that is intended.
  const [selectedSplit, setSelectedSplit] = useState("test");
  const [selectedTCType, setSelectedTCType] = useState("self");
  const [selectedMetric, setSelectedMetric] = useState("gen_roc");
  const [selectedDomain, setSelectedDomain] = useState("all");
  const [viewMode, setViewMode] = useState("heatmap");
  const [barVariant, setBarVariant] = useState("raw");
  const [comparisonPreset, setComparisonPreset] = useState("all");
  // Row visibility: null means "show all", otherwise explicit set from preset or manual toggle
  // NOTE(review): the generic argument on this `useState` looks truncated in this copy.
  const [visibleRows, setVisibleRows] = useState | null>(null);

  // --- Derived: available splits, tasks, families ---

  // --- Available base models (model column = base model identity, same for all finetuned variants) ---
  // Distinct `model` values, sorted with the default string ordering.
  const availableModels = useMemo(() => {
    const s = new Set(rows.map((r) => r.model));
    return Array.from(s).sort();
  }, [rows]);

  // Display-friendly name: "v6-google_gemma-2-2b" → "gemma-2-2b"
  // Strips a leading "v<digits>-<text without underscores>_" prefix; names that
  // do not match the pattern are returned unchanged.
  const modelDisplayName = useCallback((m: string) => {
    return m.replace(/^v\d+-[^_]+_/, "");
  }, []);

  // Resolve __first__ to actual first model once data loads
  // Also falls back to the first available model when the selected one is not
  // in the list, and to "" when there are no models at all.
  const resolvedModel = useMemo(() => {
    if (selectedModel === "__first__" && availableModels.length > 0) return availableModels[0];
    if (availableModels.includes(selectedModel)) return selectedModel;
    return availableModels[0] ?? "";
  }, [selectedModel, availableModels]);

  // Distinct `split` values across all loaded rows (not narrowed by any filter).
  const availableSplits = useMemo(() => {
    const s = new Set(rows.map((r) => r.split));
    return Array.from(s).sort();
  }, [rows]);

  // Valid eval tasks, narrowed to the selected family's prefix when the family
  // key is found in TASK_FAMILIES.
  const availableTasks = useMemo(() => {
    let filtered = rows.filter((r) => isValidEvalTask(r.task));
    const fam = TASK_FAMILIES.find((f) => f.key === selectedFamily);
    if (fam) filtered = filtered.filter((r) => r.task.startsWith(fam.prefix));
    const tasks = new Set(filtered.map((r) => r.task));
    return Array.from(tasks).sort();
  }, [rows, selectedFamily]);

  // Reset task selection when family changes
  useEffect(() => {
    setSelectedTask("__all__");
  }, [selectedFamily]);

  // --- Filtering pipeline ---
  // Filters are applied in sequence; every later derived value in the default
  // view is computed from this result.
  const filteredRows = useMemo(() => {
    let result = rows;

    // Model filter — critical: never mix different base models
    // (skipped only when `resolvedModel` is "", i.e. no models are available)
    if (resolvedModel) {
      result = result.filter((r) => r.model === resolvedModel);
    }

    // Exclude non-eval tasks (concat, bare family names, etc.)
    result = result.filter((r) => isValidEvalTask(r.task));

    // Only show finetuned models trained on valid combined datasets (Setting U)
    result = result.filter((r) => isValidFinetunedModel(r));

    // Split filter
    result = result.filter((r) => r.split === selectedSplit);

    // TC type filter (single-select: each TC type is a different eval condition)
    result = result.filter((r) => getEvalTCType(r) === selectedTCType);

    // For finetuned models trained with a specific TC, only show rows where eval TC matches training TC
    result = result.filter((r) => isMatchedTC(r));

    // Family filter — also exclude finetuned models trained on a different family
    const fam = TASK_FAMILIES.find((f) => f.key === selectedFamily);
    if (fam) {
      result = result.filter((r) => r.task.startsWith(fam.prefix));
      result = result.filter((r) => {
        if (!r.finetuned) return true; // base models always shown
        const trainFam = getTrainingFamily(r.training_config);
        // A null training family is treated as compatible with any selection.
        return trainFam === null || trainFam === selectedFamily;
      });
    }

    // Specific task filter
    if (selectedTask !== "__all__") {
      result = result.filter((r) => r.task === selectedTask);
    }

    // Domain filter (OOD vs in-domain)
    // Rows with no task family, or for which `isOodTask` returns null, are kept
    // regardless of the selected domain.
    if (selectedDomain !== "all") {
      result = result.filter((r) => {
        const family = getTaskFamily(r.task);
        if (!family) return true;
        const ood = isOodTask(r.task, r.training_config, family);
        if (ood === null) return true;
        return selectedDomain === "ood" ? ood : !ood;
      });
    }

    return result;
  }, [rows, resolvedModel, selectedSplit, selectedTCType, selectedFamily, selectedTask, selectedDomain]);

  // --- Build row label → full config mapping ---
  // The first row seen for a label decides the stored config string.
  const fullConfigs = useMemo(() => {
    const map = new Map();
    for (const r of filteredRows) {
      const label = getRowLabel(r);
      if (!map.has(label)) {
        map.set(label, r.finetuned ? r.training_config : "Base (not finetuned)");
      }
    }
    return map;
  }, [filteredRows]);

  // --- All available row labels (before visibility filter) ---
  const allRowLabels = useMemo(() => {
    const labels = new Set();
    for (const r of filteredRows) labels.add(getRowLabel(r));
    return sortRowLabels(Array.from(labels));
  }, [filteredRows]);

  // Effective visible rows: use explicit selection, preset, or show all
  // Every branch falls back to the full set of available labels when its own
  // selection would be empty, so the result is empty only when no labels exist.
  const effectiveVisibleRows = useMemo(() => {
    const available = new Set(allRowLabels);

    // If a comparison preset restricts rows, apply that
    // (only while there is no manual selection: `visibleRows === null`)
    if (comparisonPreset === "training-effect" && visibleRows === null) {
      const effective = new Set();
      for (const label of available) {
        if (TRAINING_EFFECT_ROWS.has(label)) effective.add(label);
      }
      return effective.size > 0 ? effective : available;
    }

    // Manual selection
    // Intersected with the currently available labels so stale entries are dropped.
    if (visibleRows !== null) {
      const effective = new Set();
      for (const label of visibleRows) {
        if (available.has(label)) effective.add(label);
      }
      return effective.size > 0 ? effective : available;
    }

    // Default: show all except hidden-by-default rows
    const effective = new Set();
    for (const label of available) {
      if (!DEFAULT_HIDDEN_ROWS.has(label)) effective.add(label);
    }
    return effective.size > 0 ? effective : available;
  }, [visibleRows, allRowLabels, comparisonPreset]);

  // --- Aggregation ---
  // `aggregateData` (defined elsewhere) returns both the aggregated `data` map
  // and a `validationError`; the latter gates the whole visualization area below.
  const aggResult = useMemo(() => {
    return aggregateData(filteredRows, selectedMetric, getRowLabel);
  }, [filteredRows, selectedMetric]);
  const aggData = aggResult.data;
  const validationError = aggResult.validationError;

  // Aggregated row labels restricted to the effective visible set, in display order.
  const rowLabels = useMemo(() => {
    return sortRowLabels(Array.from(aggData.keys()).filter((l) => effectiveVisibleRows.has(l)));
  }, [aggData, effectiveVisibleRows]);

  // Available eval variants (only those with data)
  // Returned in the canonical EVAL_VARIANTS order.
  const availableEvalVariants = useMemo(() => {
    const evs = new Set();
    for (const evMap of aggData.values()) {
      for (const ev of evMap.keys()) evs.add(ev);
    }
    return EVAL_VARIANTS.filter((ev) => evs.has(ev));
  }, [aggData]);

  // --- Stats ---
  // `models` counts distinct training configs, with every non-finetuned row
  // collapsed into the single key "base".
  const stats = useMemo(() => {
    const taskCount = new Set(filteredRows.map((r) => r.task)).size;
    const modelCount = new Set(filteredRows.map((r) => r.finetuned ? r.training_config : "base")).size;
    return { tasks: taskCount, models: modelCount, rows: filteredRows.length };
  }, [filteredRows]);

  // --- Domain-split rows for aggregated heatmaps ---
  // Null unless the family is ifeval/hypernym, all tasks are shown, and the
  // domain filter is "all". Rows without a task family, or for which
  // `isOodTask` returns null, end up in neither group.
  const domainSplitRows = useMemo(() => {
    const hasDomainSplit = (selectedFamily === "ifeval" || selectedFamily === "hypernym") && selectedTask === "__all__";
    if (!hasDomainSplit || selectedDomain !== "all") return null;

    const oodRows = filteredRows.filter((r) => {
      const family = getTaskFamily(r.task);
      if (!family) return false;
      return isOodTask(r.task, r.training_config, family) === true;
    });
    const idRows = filteredRows.filter((r) => {
      const family = getTaskFamily(r.task);
      if (!family) return false;
      return isOodTask(r.task, r.training_config, family) === false;
    });

    const oodTaskCount = new Set(oodRows.map((r) => r.task)).size;
    const idTaskCount = new Set(idRows.map((r) => r.task)).size;

    return { oodRows, idRows, oodTaskCount, idTaskCount };
  }, [filteredRows, selectedFamily, selectedTask, selectedDomain]);

  // --- Multi-metric heatmaps (one set per domain group when applicable) ---
  // One section holds the aggregated data for every entry of METRICS.
  type HeatmapSection = {
    label: string;
    taskCount: number;
    metrics: { metric: Metric; data: Map> }[];
  };

  // Empty outside heatmap mode, so no aggregation work is done for the bar view.
  const heatmapSections = useMemo((): HeatmapSection[] => {
    if (viewMode !== "heatmap") return [];

    const buildSection = (sectionRows: SummaryRow[], label: string, taskCount: number): HeatmapSection => {
      const metrics: { metric: Metric; data: Map> }[] = [];
      for (const m of METRICS) {
        metrics.push({ metric: m.key, data: aggregateData(sectionRows, m.key, getRowLabel).data });
      }
      return { label, taskCount, metrics };
    };

    // With a domain split, emit up to two sections and skip any empty group.
    if (domainSplitRows) {
      const sections: HeatmapSection[] = [];
      if (domainSplitRows.oodRows.length > 0) {
        sections.push(buildSection(domainSplitRows.oodRows, "Out-of-domain", domainSplitRows.oodTaskCount));
      }
      if (domainSplitRows.idRows.length > 0) {
        sections.push(buildSection(domainSplitRows.idRows, "In-domain", domainSplitRows.idTaskCount));
      }
      return sections;
    }

    // For families without domain splits but with multiple tasks, still show aggregate
    return [buildSection(filteredRows, "All tasks", stats.tasks)];
  }, [filteredRows, viewMode, domainSplitRows, stats.tasks]);

  // --- Per-task heatmaps (when viewing "All tasks" in a family) ---
  // One section per task, in sorted task-name order.
  const perTaskSections = useMemo((): HeatmapSection[] => {
    if (viewMode !== "heatmap") return [];
    if (selectedTask !== "__all__") return []; // single task already shown in main sections
    const tasks = new Set(filteredRows.map((r) => r.task));
    if (tasks.size <= 1) return []; // no point showing per-task if only one

    const sorted = Array.from(tasks).sort();
    return sorted.map((task) => {
      const taskRows = filteredRows.filter((r) => r.task === task);
      const metrics: { metric: Metric; data: Map> }[] = [];
      for (const m of METRICS) {
        metrics.push({ metric: m.key, data: aggregateData(taskRows, m.key, getRowLabel).data });
      }
      return { label: task, taskCount: 1, metrics };
    });
  }, [filteredRows, viewMode, selectedTask]);

  // --- Comparison preset data ---
  // For delta comparisons, we need rows from multiple TC types simultaneously
  // NOTE(review): this repeats the `filteredRows` pipeline step for step, minus
  // the `getEvalTCType(r) === selectedTCType` filter; any change to one pipeline
  // has to be mirrored in the other by hand.
  const comparisonBaseRows = useMemo(() => {
    // Same as filteredRows but without TC type filter
    let result = rows;
    // Model filter — must match filteredRows
    if (resolvedModel) {
      result = result.filter((r) => r.model === resolvedModel);
    }
    result = result.filter((r) => isValidEvalTask(r.task));
    result = result.filter((r) => isValidFinetunedModel(r));
    result = result.filter((r) => r.split === selectedSplit);
    // For finetuned models trained with a specific TC, only show rows where eval TC matches training TC
    result = result.filter((r) => isMatchedTC(r));
    const fam = TASK_FAMILIES.find((f) => f.key === selectedFamily);
    if (fam) {
      result = result.filter((r) => r.task.startsWith(fam.prefix));
      result = result.filter((r) => {
        if (!r.finetuned) return true;
        const trainFam = getTrainingFamily(r.training_config);
        return trainFam === null || trainFam === selectedFamily;
      });
    }
    if (selectedTask !== "__all__") {
      result = result.filter((r) => r.task === selectedTask);
    }
    if (selectedDomain !== "all") {
      result = result.filter((r) => {
        const family = getTaskFamily(r.task);
        if (!family) return true;
        const ood = isOodTask(r.task, r.training_config, family);
        if (ood === null) return true;
        return selectedDomain === "ood" ? ood : !ood;
      });
    }
    return result;
  }, [rows, resolvedModel, selectedSplit, selectedFamily, selectedTask, selectedDomain]);

  // Build aggregated data for rows with a specific training TC flag (in the row label)
  // Note: this is the TRAINING tc flag (in the model name), NOT the eval TC type
  // Operates on `filteredRows`, so the eval TC type filter is still in effect here.
  const aggByTrainingTC = useCallback((trainingTCFlag: string | null, metric: Metric) => {
    const subset = filteredRows.filter((r) => {
      const label = getRowLabel(r);
      if (trainingTCFlag === null) {
        // No training TC flag — exclude rows that have any TC flag
        return !TC_FLAGS.some((f) => hasTCFlag(label, f));
      }
      return hasTCFlag(label, trainingTCFlag);
    });
    return aggregateData(subset, metric, getRowLabel).data;
  }, [filteredRows]);

  // Dynamic pairing: find rows that differ only by a TC flag
  // `leftFlag === null` selects labels that carry none of TC_FLAGS. The "Base"
  // label is never paired here. The right-flag test runs first, so a label
  // carrying both flags is classified as a right label.
  const buildPairs = useCallback((leftFlag: string | null, rightFlag: string): [string, string][] => {
    const leftLabels = new Set();
    const rightLabels = new Set();

    for (const r of filteredRows) {
      const label = getRowLabel(r);
      if (label === "Base") continue;
      if (rightFlag && hasTCFlag(label, rightFlag)) {
        rightLabels.add(label);
      } else if (leftFlag === null && !TC_FLAGS.some((f) => hasTCFlag(label, f))) {
        leftLabels.add(label);
      } else if (leftFlag && hasTCFlag(label, leftFlag)) {
        leftLabels.add(label);
      }
    }

    // Match: strip TC flag from right label, see if it matches a left label
    const pairs: [string, string][] = [];
    for (const rl of rightLabels) {
      const stripped = stripTCFlag(rl);
      if (leftLabels.has(stripped)) {
        pairs.push([stripped, rl]);
      }
    }
    // Order pairs by the display order of their left label. The non-null
    // assertion holds because every `l` was taken from `pairs` itself.
    return sortRowLabels(pairs.map(([l]) => l)).map((l) => {
      const r = pairs.find(([left]) => left === l)![1];
      return [l, r] as [string, string];
    });
  }, [filteredRows]);

  // + TC-Self pairs: without-tc ↔ tcself
  // Only computed while the matching preset is active.
  const tcSelfPairs = useMemo((): [string, string][] => {
    if (comparisonPreset !== "plus-tcself") return [];
    return buildPairs(null, "tcself");
  }, [comparisonPreset, buildPairs]);

  // TC-Self vs TC-Neg: for each model, compare train-tcself+eval-self vs train-tcneg+eval-neg
  // Base model: same "Base" label, just different eval TC type
  // Finetuned: pair by stripping tcself/tcneg from training label
  const aggByMatchedTC = useCallback((evalTC: TCType, metric: Metric) => {
    // From comparisonBaseRows (no eval TC filter), get rows matching this eval TC type
    const subset = comparisonBaseRows.filter((r) => getEvalTCType(r) === evalTC);
    return aggregateData(subset, metric, getRowLabel).data;
  }, [comparisonBaseRows]);

  const tcSelfVsNegPairs = useMemo((): [string, string][] => {
    if (comparisonPreset !== "tcself-vs-tcneg") return [];
    // Left: models evaluated with self-TC (base + models trained with tcself)
    const selfRows = comparisonBaseRows.filter((r) => getEvalTCType(r) === "self");
    const selfLabels = new Set(selfRows.map(getRowLabel));
    // Right: models evaluated with neg-TC (base + models trained with tcneg)
    const negRows = comparisonBaseRows.filter((r) => getEvalTCType(r) === "neg");
    const negLabels = new Set(negRows.map(getRowLabel));

    const pairs: [string, string][] = [];
    // Base: appears in both with same label
    if (selfLabels.has("Base") && negLabels.has("Base")) {
      pairs.push(["Base", "Base"]);
    }
    // Finetuned: match tcself label ↔ tcneg label (strip tc flag to find pair)
    for (const sl of selfLabels) {
      if (sl === "Base") continue;
      if (!hasTCFlag(sl, "tcself")) continue;
      // String-pattern `replace` swaps only the first "tcself" occurrence.
      const negVersion = sl.replace("tcself", "tcneg");
      if (negLabels.has(negVersion)) {
        pairs.push([sl, negVersion]);
      }
    }
    // Reorder by the display order of the left label; `find` cannot miss
    // because every `l` comes from `pairs`.
    return sortRowLabels(pairs.map(([l]) => l)).map((l) => {
      const p = pairs.find(([left]) => left === l)!;
      return p;
    });
  }, [comparisonPreset, comparisonBaseRows]);

  // TC-Self vs TC-GPT2: pair tcself-trained (eval=self) with tco-trained (eval=gpt2)
  // e.g. "Comb-tcself-lo" ↔ "Comb-tco-lo"
  // Same construction as the TC-Neg pairing above, with "tco" as the right-hand flag.
  const tcSelfVsGpt2Pairs = useMemo((): [string, string][] => {
    if (comparisonPreset !== "tcself-vs-tcgpt2") return [];
    const selfRows = comparisonBaseRows.filter((r) => getEvalTCType(r) === "self");
    const selfLabels = new Set(selfRows.map(getRowLabel));
    const gpt2Rows = comparisonBaseRows.filter((r) => getEvalTCType(r) === "gpt2");
    const gpt2Labels = new Set(gpt2Rows.map(getRowLabel));

    const pairs: [string, string][] = [];
    // Base: appears in both with same label
    if (selfLabels.has("Base") && gpt2Labels.has("Base")) {
      pairs.push(["Base", "Base"]);
    }
    // Finetuned: match tcself label ↔ tco label (swap tc flag to find pair)
    for (const sl of selfLabels) {
      if (sl === "Base") continue;
      if (!hasTCFlag(sl, "tcself")) continue;
      const tcoVersion = sl.replace("tcself", "tco");
      if (gpt2Labels.has(tcoVersion)) {
        pairs.push([sl, tcoVersion]);
      }
    }
    return sortRowLabels(pairs.map(([l]) => l)).map((l) => {
      const p = pairs.find(([left]) => left === l)!;
      return p;
    });
  }, [comparisonPreset, comparisonBaseRows]);

  // Family options for dropdown
  // Families with no matching task are omitted; the label carries the distinct task count.
  // NOTE(review): the count matches on prefix only and does not apply
  // `isValidEvalTask`, whereas `availableTasks` does — the number shown here can
  // therefore be larger than the "All tasks (N)" count for the same family.
  const familyOptions = useMemo((): { key: string; label: string }[] => {
    const options: { key: string; label: string }[] = [];
    for (const fam of TASK_FAMILIES) {
      const count = new Set(rows.filter((r) => r.task.startsWith(fam.prefix)).map((r) => r.task)).size;
      if (count > 0) options.push({ key: fam.key, label: `${fam.label} (${count})` });
    }
    return options;
  }, [rows]);

  // Task options for dropdown
  // The "__all__" entry always comes first.
  const taskOptions = useMemo(() => {
    const options = [{ key: "__all__", label: `All tasks (${availableTasks.length})` }];
    for (const t of availableTasks) {
      options.push({ key: t, label: t });
    }
    return options;
  }, [availableTasks]);

  // Domain filter visibility
  // Same family condition that `domainSplitRows` uses.
  const showDomainFilter = selectedFamily === "ifeval" || selectedFamily === "hypernym";

  // NOTE(review): the markup below is left exactly as found; its element tags
  // are missing in this copy, so only text nodes and `{…}` expressions remain.
  // NOTE(review): in the bar branch at the end, the visible variant filter is
  // hard-coded to "raw"/"tc" and does not reference the `barVariant` state —
  // confirm against the original markup.
  return (
{/* Header */}
{shortName} heatmap {!loading && ( {rows.length.toLocaleString()} rows loaded )}
{/* Loading */} {loading && (

Loading... {progress.loaded.toLocaleString()} / {progress.total.toLocaleString()} rows

)} {/* Error */} {!loading && error && (

Failed to load dataset

{error}

)} {/* Main content */} {!loading && !error && rows.length > 0 && (
{/* Controls sidebar */}
{/* View mode */}
{availableModels.length > 1 && ( ({ key: m, label: modelDisplayName(m) }))} onChange={setSelectedModel} /> )} ({ key: s, label: s }))} onChange={setSelectedSplit} /> {showDomainFilter && ( )} {/* Row visibility */} {allRowLabels.length > 0 && (
{allRowLabels.map((label) => ( ))}
)} {/* Stats */}

Filtered data

{stats.tasks} tasks, {stats.models} configs

{stats.rows.toLocaleString()} rows

{selectedTask === "__all__" && stats.tasks > 1 && (

Averaging across {stats.tasks} tasks

)}
{/* Visualization area */}
{validationError ? (
!!!

Data Integrity Violation

{validationError}

All visualizations are blocked until this is resolved. If you see this, the data pipeline has a bug — different models or training configs are collapsing into the same row label.

) : filteredRows.length === 0 ? (

No data matches the current filters.

) : viewMode === "heatmap" ? ( <> {/* Aggregated sections (domain-split or single aggregate) — hidden when comparison preset is active */} {(comparisonPreset === "all" || comparisonPreset === "training-effect") && heatmapSections.map((section, sIdx) => (
{section.label && (

{section.label} {section.taskCount > 1 && ( (mean over {section.taskCount} tasks) )}

)}
{section.metrics.map(({ metric: m, data }) => { const metricLabel = METRICS.find((x) => x.key === m)?.label ?? m; const metricRowLabels = sortRowLabels(Array.from(data.keys()).filter((l) => effectiveVisibleRows.has(l))); const metricEvs = EVAL_VARIANTS.filter((ev) => { for (const evMap of data.values()) { if (evMap.has(ev)) return true; } return false; }); if (metricRowLabels.length === 0) return null; return ( ); })}
))} {/* Comparison: + TC-Self (side-by-side + delta) */} {comparisonPreset === "plus-tcself" && tcSelfPairs.length === 0 && (

No matched pairs found for + TC-Self comparison. Check that both non-TC and tcself models exist for this family.

)} {comparisonPreset === "plus-tcself" && tcSelfPairs.length > 0 && METRICS.map((m) => { const leftData = aggByTrainingTC(null, m.key); const rightData = aggByTrainingTC("tcself", m.key); const existingPairs = tcSelfPairs.filter(([l, r]) => leftData.has(l) && rightData.has(r)); if (existingPairs.length === 0) return null; const evs = EVAL_VARIANTS.filter((ev) => { for (const evMap of [...leftData.values(), ...rightData.values()]) { if (evMap.has(ev)) return true; } return false; }); return ( ); })} {/* Comparison: TC-Self vs TC-Neg (side-by-side + delta) */} {comparisonPreset === "tcself-vs-tcneg" && tcSelfVsNegPairs.length === 0 && (

No matched pairs found for TC-Self vs TC-Neg comparison. Check that both tcself and tcneg models exist for this family.

)} {comparisonPreset === "tcself-vs-tcneg" && tcSelfVsNegPairs.length > 0 && METRICS.map((m) => { const leftData = aggByMatchedTC("self", m.key); const rightData = aggByMatchedTC("neg", m.key); const existingPairs = tcSelfVsNegPairs.filter(([l, r]) => leftData.has(l) && rightData.has(r)); if (existingPairs.length === 0) return null; const evs = EVAL_VARIANTS.filter((ev) => { for (const evMap of [...leftData.values(), ...rightData.values()]) { if (evMap.has(ev)) return true; } return false; }); return ( ); })} {/* Comparison: TC-Self vs TC-GPT2 (side-by-side + delta) */} {comparisonPreset === "tcself-vs-tcgpt2" && tcSelfVsGpt2Pairs.length === 0 && (

No matched pairs found for TC-Self vs TC-GPT2 comparison.

)} {comparisonPreset === "tcself-vs-tcgpt2" && tcSelfVsGpt2Pairs.length > 0 && METRICS.map((m) => { const leftData = aggByMatchedTC("self", m.key); const rightData = aggByMatchedTC("gpt2", m.key); const existingPairs = tcSelfVsGpt2Pairs.filter(([l, r]) => leftData.has(l) && rightData.has(r)); if (existingPairs.length === 0) return null; const evs = EVAL_VARIANTS.filter((ev) => { for (const evMap of [...leftData.values(), ...rightData.values()]) { if (evMap.has(ev)) return true; } return false; }); return ( ); })} {/* Per-task breakdown — regular heatmap for all/training-effect */} {(comparisonPreset === "all" || comparisonPreset === "training-effect") && perTaskSections.length > 0 && ( )} {/* Per-task breakdown — side-by-side + delta for + TC-Self */} {comparisonPreset === "plus-tcself" && selectedTask === "__all__" && ( { const label = getRowLabel(r); return !TC_FLAGS.some((f) => hasTCFlag(label, f)); }} rightFilter={(r) => hasTCFlag(getRowLabel(r), "tcself")} fullConfigs={fullConfigs} /> )} {/* Per-task breakdown — side-by-side + delta for TC-Self vs TC-Neg */} {comparisonPreset === "tcself-vs-tcneg" && selectedTask === "__all__" && ( getEvalTCType(r) === "self"} rightFilter={(r) => getEvalTCType(r) === "neg"} fullConfigs={fullConfigs} /> )} {/* Per-task breakdown — side-by-side + delta for TC-Self vs TC-GPT2 */} {comparisonPreset === "tcself-vs-tcgpt2" && selectedTask === "__all__" && ( getEvalTCType(r) === "self"} rightFilter={(r) => getEvalTCType(r) === "gpt2"} fullConfigs={fullConfigs} /> )} ) : (
{METRICS.map((m) => { const metricAgg = aggregateData(filteredRows, m.key, getRowLabel).data; const metricRowLabels = sortRowLabels(Array.from(metricAgg.keys()).filter((l) => effectiveVisibleRows.has(l))); if (metricRowLabels.length === 0) return null; return ( ev === "raw" || ev === "tc")} metric={m.key} title={`${m.label}${selectedTask === "__all__" && stats.tasks > 1 ? ` (mean ± SE over ${stats.tasks} tasks)` : ""}`} /> ); })}
)}
)}
); }