Spaces:
Sleeping
Sleeping
| import { useState, useEffect, useCallback, useMemo } from "react"; | |
| import { HF_ORG } from "../../../config"; | |
| // --------------------------------------------------------------------------- | |
| // Types | |
| // --------------------------------------------------------------------------- | |
/**
 * Props accepted by the heatmap viewer.
 * NOTE(review): the consuming component is outside the visible part of this
 * file, so the per-field notes below describe intent as suggested by the names.
 */
interface HeatmapViewerProps {
  /** Dataset repository identifier (presumably a Hugging Face `org/name` id — confirm against caller). */
  datasetRepo: string;
  /** Optional dataset split name. */
  split?: string;
  /** Zero-argument callback; presumably fired when the viewer is dismissed. */
  onClose: () => void;
}
/**
 * One row of the summary dataset as read from the parquet file.
 * Rows are cast to this shape without runtime validation, so the field types
 * are expectations rather than guarantees.
 */
interface SummaryRow {
  /** Base model identifier; `validateRowLabels` expects a single distinct value per view. */
  model: string;
  hf_model_name: string;
  local_model_name: string;
  /** Evaluation task name, e.g. `hypernym-dogs` or `ifeval-prompt_12`. */
  task: string;
  split: string;

  // Evaluation-time TC flags; `getEvalTCType` checks them in this order.
  self_tc: boolean;
  neg_tc: boolean;
  gpt2_tc: boolean;

  /** `false` rows are labelled "Base"; `true` rows are labelled from `training_config`. */
  finetuned: boolean;
  /** Free-form training configuration string; helpers inspect it by substring/regex. */
  training_config: string;
  /** Expected to be one of the `EvalVariant` values; other values are skipped during aggregation. */
  eval_variant: string;

  // Metric columns; `null` means "no value" and is skipped during aggregation.
  gen_roc: number | null;
  val_roc: number | null;
  val_acc: number | null;
  corr: number | null;
  corr_pos: number | null;
  corr_neg: number | null;

  n_samples: number;
  filename: string;
}
/** Names of the numeric metric columns on `SummaryRow`. */
type Metric =
  | "gen_roc"
  | "val_roc"
  | "val_acc"
  | "corr"
  | "corr_pos"
  | "corr_neg";

/** Evaluation variants that form the columns of the aggregated views. */
type EvalVariant = "raw" | "tc" | "lenorm" | "tc+lenorm";

/** TC flavour of a row, or "none" when no TC flag applies. */
type TCType = "none" | "self" | "neg" | "gpt2";

/** Available chart renderings. */
type ViewMode = "heatmap" | "bar";

/** Domain filter choices: everything, out-of-domain only, or in-domain only. */
type DomainFilter = "all" | "ood" | "in-domain";

/** Keys of the model-comparison presets (labels live in `COMPARISON_PRESETS`). */
type ComparisonPreset =
  | "all"
  | "training-effect"
  | "plus-tcself"
  | "tcself-vs-tcneg"
  | "tcself-vs-tcgpt2";
| // --------------------------------------------------------------------------- | |
| // Constants | |
| // --------------------------------------------------------------------------- | |
/** Base URL of the Hugging Face datasets server, used to look up parquet file URLs. */
const HF_DATASETS_API = "https://datasets-server.huggingface.co";

/** Selectable metrics with their display labels, in menu order. */
const METRICS: { key: Metric; label: string }[] = [
  { key: "gen_roc", label: "Gen ROC" },
  { key: "val_roc", label: "Val ROC" },
  { key: "val_acc", label: "Val Acc" },
  { key: "corr", label: "Correlation" },
  { key: "corr_pos", label: "Corr (pos)" },
  { key: "corr_neg", label: "Corr (neg)" },
];

/** Recognised evaluation variants; rows with any other `eval_variant` are ignored by aggregation. */
const EVAL_VARIANTS: EvalVariant[] = ["raw", "tc", "lenorm", "tc+lenorm"];

/** Selectable TC types with their display labels ("none" is intentionally not listed). */
const TC_TYPES: { key: TCType; label: string }[] = [
  { key: "self", label: "Self TC" },
  { key: "neg", label: "Neg TC" },
  { key: "gpt2", label: "GPT-2 TC" },
];
/** Plotly's default color cycle; bar groups pick a color by index modulo the palette length. */
const BAR_COLORS = [
  "#636EFA",
  "#EF553B",
  "#00CC96",
  "#AB63FA",
  "#FFA15A",
  "#19D3F3",
  "#FF6692",
  "#B6E880",
  "#FF97FF",
  "#FECB52",
];

/**
 * Known task families. `prefix` is the string a task name (or training
 * config fragment) carries when it belongs to the family.
 */
const TASK_FAMILIES = [
  { key: "hypernym", label: "Hypernym", prefix: "hypernym-" },
  { key: "ifeval", label: "IFEval", prefix: "ifeval-" },
  { key: "plausibleqa", label: "PlausibleQA", prefix: "plausibleqa-" },
  { key: "ambigqa", label: "AmbigQA", prefix: "ambigqa-" },
];
/**
 * Decides whether a row's model should be kept.
 *
 * Non-finetuned rows always pass. Finetuned rows must have a training config
 * that mentions `force-same-x`, and — when the config mentions a hypernym or
 * IFEval dataset — it must be the combined one (single-task variants are
 * rejected).
 *
 * @param row Summary row whose `finetuned` and `training_config` are inspected.
 * @returns `true` if the row should be kept.
 */
function isValidFinetunedModel(row: SummaryRow): boolean {
  if (!row.finetuned) return true;

  const config = row.training_config;
  const usesForceSameX = config.includes("force-same-x");
  // Hypernym training is only accepted on the combined "double" dataset.
  const singleTaskHypernym =
    config.includes("hypernym-") && !config.includes("hypernym-concat-bananas-to-dogs-double");
  // IFEval training is only accepted on the concat dataset.
  const singleTaskIfeval = config.includes("ifeval-") && !config.includes("ifeval-concat");

  return usesForceSameX && !singleTaskHypernym && !singleTaskIfeval;
}
/** Comparison presets with their display labels, in menu order. */
const COMPARISON_PRESETS: { key: ComparisonPreset; label: string }[] = [
  { key: "all", label: "All models" },
  { key: "training-effect", label: "Training Effect" },
  { key: "plus-tcself", label: "+ TC-Self" },
  { key: "tcself-vs-tcneg", label: "TC-Self vs TC-Neg" },
  { key: "tcself-vs-tcgpt2", label: "TC-Self vs TC-GPT2" },
];

/** Row labels of the "Training Effect" preset: Base + lo/semi × {Pref, SFT, Comb-v}. */
const TRAINING_EFFECT_ROWS = new Set([
  "Base",
  "Pref-lo",
  "SFT-lo",
  "Comb-v-lo",
  "Pref-semi",
  "SFT-semi",
  "Comb-v-semi",
]);

/** Row labels hidden by default. */
const DEFAULT_HIDDEN_ROWS = new Set([
  "Pref-v-lo",
  "Pref-tcself-v-lo",
  "Pref-tcself-norm-lo",
  "Pref-tcself-norm-v-lo",
  "Comb-tcself-norm-v-lo",
  "SFT-tcself-norm-semi",
  "Comb-tcself-norm-v-semi",
  // PlausibleQA models without lo/semi — not comparable
  "Comb-tcself",
  "Comb-tcself-v",
  "SFT-tcself-v",
]);
/** Row-label segments that denote a TC flag. */
const TC_FLAGS = ["tco", "tcself", "tcneg"];

/**
 * Removes every TC flag segment from a dash-separated row label, yielding the
 * "base" label used to pair rows that differ only by TC flag.
 *
 * @param label Row label such as `Pref-tcself-lo`.
 * @returns The label without TC segments, e.g. `Pref-lo`.
 */
function stripTCFlag(label: string): string {
  const kept: string[] = [];
  for (const segment of label.split("-")) {
    if (TC_FLAGS.indexOf(segment) === -1) kept.push(segment);
  }
  return kept.join("-");
}
/**
 * Tells whether a dash-separated row label contains `flag` as a whole segment
 * (substring matches inside a longer segment do not count).
 */
function hasTCFlag(label: string, flag: string): boolean {
  const segments = label.split("-");
  return segments.some((segment) => segment === flag);
}
/** IFEval prompts 1-21 are OOD (test-only, never trained on). */
const IFEVAL_OOD_MAX = 21;

/** Hypernym subtasks that count as in-domain (fixed split). */
const HYPERNYM_IN_DOMAIN = new Set([
  "bananas", "bazookas", "cabinets", "cars", "chairs", "crows", "diapers", "dogs",
]);

/** Hypernym subtasks that count as out-of-domain (fixed split). */
const HYPERNYM_OOD = new Set([
  "ducklings", "elephants", "guns", "hammers", "helmets", "jackets", "kayaks", "kites", "mirrors",
]);

/** Every hypernym subtask accepted as an evaluation task. */
const HYPERNYM_VALID = new Set([...HYPERNYM_IN_DOMAIN, ...HYPERNYM_OOD]);

/**
 * Tells whether `task` names an evaluation task, as opposed to a training
 * task such as a concat dataset or a bare family name.
 *
 * Rules per family:
 * - `hypernym-<subtask>`: the subtask must be one of the known hypernym subtasks.
 * - `ifeval-…`: must look like `ifeval-prompt_<n>` or `ifeval-prompt-<n>`.
 * - `plausibleqa-<subtask>` / `ambigqa-<subtask>`: the subtask must be non-empty.
 * - Bare family names (`hypernym`, `ifeval`, `plausibleqa`, `ambigqa`) are rejected.
 * - Anything else is accepted.
 *
 * @param task Task name from a summary row.
 * @returns `true` if the task should be treated as an evaluation task.
 */
function isValidEvalTask(task: string): boolean {
  const hypernymPrefix = "hypernym-";
  if (task.startsWith(hypernymPrefix)) {
    return HYPERNYM_VALID.has(task.slice(hypernymPrefix.length));
  }
  if (task.startsWith("ifeval-")) {
    return /^ifeval-prompt[_-]\d+$/.test(task);
  }
  // A trailing dash with nothing after it is not a subtask. The previous
  // `task !== "<family>"` comparison could never fail after the prefix check,
  // so names like "plausibleqa-" slipped through.
  for (const prefix of ["plausibleqa-", "ambigqa-"]) {
    if (task.startsWith(prefix)) return task.length > prefix.length;
  }
  // Bare family names without subtask suffix are training tasks
  const bareFamilies = ["ambigqa", "plausibleqa", "hypernym", "ifeval"];
  return !bareFamilies.includes(task);
}
| // --------------------------------------------------------------------------- | |
| // Color scale: RdYlGn for [0, 1] mapped to percentages | |
| // --------------------------------------------------------------------------- | |
/**
 * Plotly RdYlGn colorscale stops (matches Plotly's built-in).
 * Each entry is `[position in 0..1, [r, g, b]]`, sorted by position.
 */
const RDYLGN_STOPS: [number, [number, number, number]][] = [
  [0.0, [165, 0, 38]],
  [0.1, [215, 48, 39]],
  [0.2, [244, 109, 67]],
  [0.3, [253, 174, 97]],
  [0.4, [254, 224, 139]],
  [0.5, [255, 255, 191]],
  [0.6, [217, 239, 139]],
  [0.7, [166, 217, 106]],
  [0.8, [102, 189, 99]],
  [0.9, [26, 152, 80]],
  [1.0, [0, 104, 55]],
];

/**
 * Maps `t` onto the RdYlGn scale by linear interpolation between the two
 * surrounding stops.
 *
 * @param t Position on the scale; values outside 0..1 are clamped.
 * @returns A CSS `rgb(r, g, b)` string. If no segment matches (e.g. `t` is
 *   NaN), the color of the last stop is returned.
 */
function interpolateRdYlGn(t: number): string {
  const x = Math.min(1, Math.max(0, t));

  // Index of the upper stop of the first segment that contains x.
  const upper = RDYLGN_STOPS.findIndex(
    ([position], idx) => idx > 0 && x >= RDYLGN_STOPS[idx - 1][0] && x <= position,
  );
  if (upper === -1) return `rgb(0, 104, 55)`;

  const [startPos, startColor] = RDYLGN_STOPS[upper - 1];
  const [endPos, endColor] = RDYLGN_STOPS[upper];
  const fraction = (x - startPos) / (endPos - startPos);
  const channel = (k: 0 | 1 | 2): number =>
    Math.round(startColor[k] + fraction * (endColor[k] - startColor[k]));

  return `rgb(${channel(0)}, ${channel(1)}, ${channel(2)})`;
}
/**
 * Background color for a value on the 0..1 scale.
 *
 * @param value Value to color; NaN means "no data".
 * @returns Light gray (`#f3f4f6`) for NaN, otherwise the interpolated RdYlGn color.
 */
function rdYlGn(value: number): string {
  // light gray for no data
  return isNaN(value) ? "#f3f4f6" : interpolateRdYlGn(value);
}
/**
 * Background color for a correlation value: the -1..1 range is rescaled to
 * 0..1 before being passed to the RdYlGn scale.
 *
 * @param value Correlation; NaN means "no data".
 * @returns Light gray (`#f3f4f6`) for NaN, otherwise the scale color.
 */
function corrColor(value: number): string {
  if (isNaN(value)) return "#f3f4f6";
  const rescaled = (value + 1) / 2;
  return rdYlGn(rescaled);
}
/**
 * Picks a readable text color for a cell whose background encodes `bgValue`.
 *
 * @param bgValue The value shown in the cell (NaN means "no data").
 * @param isCorr Whether `bgValue` is a correlation in -1..1 (rescaled to 0..1) rather than already in 0..1.
 * @returns Gray for missing data, dark text for the light middle of the
 *   scale (strictly between 0.3 and 0.7), white text otherwise.
 */
function textColor(bgValue: number, isCorr: boolean): string {
  if (isNaN(bgValue)) return "#9ca3af";

  const position = isCorr ? (bgValue + 1) / 2 : bgValue;
  const onLightBackground = position > 0.3 && position < 0.7;
  if (onLightBackground) return "#1f2937";
  return "#ffffff";
}
| // --------------------------------------------------------------------------- | |
| // Helpers | |
| // --------------------------------------------------------------------------- | |
/**
 * Derives the evaluation-time TC type from a row's boolean flags.
 * When several flags are set, precedence is self, then neg, then gpt2.
 *
 * @returns The matching TC type, or "none" when no flag is set.
 */
function getEvalTCType(row: SummaryRow): TCType {
  const candidates: [boolean, TCType][] = [
    [row.self_tc, "self"],
    [row.neg_tc, "neg"],
    [row.gpt2_tc, "gpt2"],
  ];
  for (const [isSet, type] of candidates) {
    if (isSet) return type;
  }
  return "none";
}
/**
 * Derives the training-time TC type from substrings of the row's training
 * config. Both dash and underscore spellings are recognised; "online" maps to
 * "gpt2". When several markers are present, precedence is online, self, neg.
 *
 * @returns The matching TC type, or "none" when no marker is found.
 */
function getTrainingTCType(row: SummaryRow): TCType {
  const config = row.training_config;
  const markers: [TCType, string, string][] = [
    ["gpt2", "tc-online", "tc_online"],
    ["self", "tc-self", "tc_self"],
    ["neg", "tc-neg", "tc_neg"],
  ];
  for (const [type, dashed, underscored] of markers) {
    if (config.includes(dashed) || config.includes(underscored)) return type;
  }
  return "none";
}
/**
 * For finetuned models, checks that the evaluation TC type matches the TC
 * type the model was trained with.
 *
 * Rows always pass when the model is not finetuned, or when its training
 * config carries no TC marker — in both cases any evaluation TC is valid.
 *
 * @returns `true` if the row's eval TC is acceptable for its model.
 */
function isMatchedTC(row: SummaryRow): boolean {
  if (!row.finetuned) return true;

  const trainedWith = getTrainingTCType(row);
  return trainedWith === "none" || getEvalTCType(row) === trainedWith;
}
/**
 * Finds the task family whose prefix starts the given task name.
 *
 * @param task Task name, e.g. `hypernym-dogs`.
 * @returns The family key, or `null` when no family prefix matches.
 */
function getTaskFamily(task: string): string | null {
  const family = TASK_FAMILIES.find(({ prefix }) => task.startsWith(prefix));
  return family ? family.key : null;
}
/**
 * Extracts the prompt number from an IFEval task name of the form
 * `ifeval-prompt_42` or `ifeval-prompt-42`.
 *
 * @returns The number parsed in base 10, or `null` if the pattern is absent.
 */
function parseIfevalPromptNum(task: string): number | null {
  const found = /ifeval-prompt[_-](\d+)/.exec(task);
  if (found === null) return null;
  return parseInt(found[1], 10);
}
/**
 * Classifies a task as out-of-domain or in-domain for its family.
 *
 * - plausibleqa / ambigqa: always OOD.
 * - ifeval: OOD when the prompt number is at most `IFEVAL_OOD_MAX`.
 * - hypernym: decided by the fixed OOD / in-domain subtask sets.
 *
 * @param task Task name.
 * @param trainingConfig Not read by the current implementation.
 * @param family Task family key.
 * @returns `true` = OOD, `false` = in-domain, `null` = unknown / not applicable.
 */
function isOodTask(task: string, trainingConfig: string, family: string): boolean | null {
  switch (family) {
    case "plausibleqa":
    case "ambigqa":
      return true; // always OOD
    case "ifeval": {
      const promptNum = parseIfevalPromptNum(task);
      return promptNum === null ? null : promptNum <= IFEVAL_OOD_MAX;
    }
    case "hypernym": {
      const subtask = task.replace("hypernym-", "");
      if (HYPERNYM_OOD.has(subtask)) return true;
      // Anything in neither set is an unknown subtask.
      return HYPERNYM_IN_DOMAIN.has(subtask) ? false : null;
    }
    default:
      return null;
  }
}
/**
 * Finds the first task family (in `TASK_FAMILIES` order) whose prefix occurs
 * anywhere in the training config string.
 *
 * @returns The family key, or `null` when no prefix occurs.
 */
function getTrainingFamily(config: string): string | null {
  const family = TASK_FAMILIES.find(({ prefix }) => config.includes(prefix));
  return family ? family.key : null;
}
/**
 * Runs `pattern` against `text` and parses its first capture group as a float.
 *
 * @returns The parsed number, or `null` when the pattern does not match.
 */
function extractFloat(pattern: RegExp, text: string): number | null {
  const found = text.match(pattern);
  if (!found) return null;
  return parseFloat(found[1]);
}

/**
 * Classifies a training config as "SFT", "Comb" or "Pref" from its loss
 * weights.
 *
 * Weights are read from `pref<number>`, `nllv<number>` and `nllg<number>`
 * fragments; absent weights default to pref=1.0, nllv=0.0, nllg=0.0.
 *
 * - nllv=1 and nllg=1 with pref=0 → "SFT"
 * - nllv=1 and nllg=1 with pref=1 → "Comb"
 * - every other combination       → "Pref"
 */
function parseTrainingMode(config: string): string {
  const weightOf = (pattern: RegExp, fallback: number): number =>
    extractFloat(pattern, config) ?? fallback;

  const prefWeight = weightOf(/(?:^|[_-])pref(\d+(?:\.\d+)?)/, 1.0);
  const nllvWeight = weightOf(/nllv(\d+(?:\.\d+)?)/, 0.0);
  const nllgWeight = weightOf(/nllg(\d+(?:\.\d+)?)/, 0.0);

  const bothNllOn = nllvWeight === 1.0 && nllgWeight === 1.0;
  if (bothNllOn && prefWeight === 0.0) return "SFT";
  if (bothNllOn && prefWeight === 1.0) return "Comb";
  // Pure preference training and every unrecognised mix both map to "Pref".
  return "Pref";
}
/**
 * Builds the dash-separated row label for a training config:
 * `<mode>[-tco][-tcself][-tcneg][-norm][-v][-lo][-semi]`.
 *
 * The mode comes from `parseTrainingMode`; the remaining segments are added
 * when the corresponding marker occurs in the config. TC markers are expected
 * to be mutually exclusive — a warning is logged if more than one is present,
 * but all of them are still emitted.
 *
 * @param config Training configuration string.
 * @returns The row label.
 */
function buildRowLabel(config: string): string {
  const containsAny = (...needles: string[]): boolean =>
    needles.some((needle) => config.includes(needle));

  // TC flags are mutually exclusive
  const tcFlags: [string, boolean][] = [
    ["tco", containsAny("_tc-online_", "-tc-online-")],
    ["tcself", containsAny("_tc-self_", "-tc-self-")],
    ["tcneg", containsAny("_tc-neg_", "-tc-neg-")],
  ];
  const presentTc = tcFlags.filter(([, present]) => present).map(([name]) => name);
  if (presentTc.length > 1) {
    console.warn(`Multiple TC flags in training config (expected at most 1): ${config}`);
  }

  const segments = [parseTrainingMode(config), ...presentTc];
  // Optional independent flags
  if (containsAny("_lenorm_", "-lenorm-")) segments.push("norm");
  if (containsAny("_vallogodds", "-vallogodds")) segments.push("v");
  // Data regime flags
  if (containsAny("labelonly")) segments.push("lo");
  if (containsAny("semi")) segments.push("semi");

  return segments.join("-");
}
/**
 * Row label for a summary row: "Base" for non-finetuned models, otherwise the
 * label built from the row's training config.
 */
function getRowLabel(row: SummaryRow): string {
  return row.finetuned ? buildRowLabel(row.training_config) : "Base";
}
/**
 * Display-friendly form of a row label.
 *
 * The dash-separated label is reordered to `[regime, tc, mode, extras]`,
 * joined with spaces, and the mode "Pref" is shown as "RankAlign".
 * "Base" is returned unchanged.
 *
 * @param label Row label such as `Pref-tcself-v-lo`.
 * @returns e.g. `lo tcself RankAlign v`.
 */
function displayRowLabel(label: string): string {
  if (label === "Base") return "Base";

  const [head, ...flags] = label.split("-");
  const modeName = head === "Pref" ? "RankAlign" : head;

  // Bucket the flags in a single pass; order inside each bucket is preserved.
  const regime: string[] = [];
  const tc: string[] = [];
  const extras: string[] = [];
  for (const flag of flags) {
    if (flag === "lo" || flag === "semi") regime.push(flag);
    else if (flag === "tcself" || flag === "tcneg" || flag === "tco") tc.push(flag);
    else extras.push(flag);
  }

  return regime.concat(tc, [modeName], extras).join(" ");
}
/**
 * JSX version of the display row label, for use in HTML contexts.
 *
 * Tokens are ordered `[regime, tc, mode, extras]` with "Pref" shown as
 * "RankAlign"; TC flags are wrapped in a colored `<span>`. "Base" renders as
 * plain text.
 */
function DisplayRowLabel({ label }: { label: string }) {
  if (label === "Base") return <>Base</>;

  const [head, ...flags] = label.split("-");
  const modeName = head === "Pref" ? "RankAlign" : head;

  const isRegime = (f: string) => f === "lo" || f === "semi";
  const isTc = (f: string) => f === "tcself" || f === "tcneg" || f === "tco";
  const tokens = [
    ...flags.filter(isRegime),
    ...flags.filter(isTc),
    modeName,
    ...flags.filter((f) => !["lo", "semi", "tcself", "tcneg", "tco"].includes(f)),
  ];

  const tcColor: Record<string, string> = { tcself: "#f87171", tcneg: "#fb923c", tco: "#a78bfa" };
  const renderToken = (token: string, index: number) => {
    const color = tcColor[token];
    return (
      <span key={index}>
        {index > 0 && " "}
        {color ? <span style={{ color: tcColor[token] }}>{token}</span> : token}
      </span>
    );
  };

  return <>{tokens.map(renderToken)}</>;
}
/**
 * SVG version of the display row label: renders `<tspan>` children meant to
 * sit inside an SVG `<text>` element, with TC flags colored via `fill`.
 *
 * The display string is built like the HTML label, then shortened to
 * `maxLen` characters (ending in "..") when too long. Coloring is applied to
 * the tokens of the shortened string, so a TC flag cut by truncation is not
 * colored.
 *
 * @param label Row label such as `Pref-tcself-v-lo`.
 * @param maxLen Maximum number of characters to display (default 22).
 */
function SvgRowLabel({ label, maxLen = 22 }: { label: string; maxLen?: number }) {
  if (label === "Base") return <>Base</>;

  const [head, ...flags] = label.split("-");
  const modeName = head === "Pref" ? "RankAlign" : head;

  const regime: string[] = [];
  const tc: string[] = [];
  const extras: string[] = [];
  for (const flag of flags) {
    if (flag === "lo" || flag === "semi") regime.push(flag);
    else if (flag === "tcself" || flag === "tcneg" || flag === "tco") tc.push(flag);
    else extras.push(flag);
  }

  const fullText = [...regime, ...tc, modeName, ...extras].join(" ");
  const shown = fullText.length > maxLen ? fullText.slice(0, maxLen - 2) + ".." : fullText;

  const tcColor: Record<string, string> = { tcself: "#f87171", tcneg: "#fb923c", tco: "#a78bfa" };
  // Re-tokenize the display string to color tc flags
  const pieces = shown.split(" ");

  return (
    <>
      {pieces.map((piece, index) => (
        <tspan key={index} fill={tcColor[piece] || undefined}>
          {index > 0 && " "}{piece}
        </tspan>
      ))}
    </>
  );
}
/**
 * Arithmetic mean of `values`.
 *
 * @returns NaN for an empty array.
 */
function mean(values: number[]): number {
  if (values.length === 0) return NaN;
  let total = 0;
  values.forEach((v) => {
    total += v;
  });
  return total / values.length;
}

/**
 * Standard error of the mean, using the sample variance (n - 1 denominator).
 *
 * @returns 0 when fewer than two values are given.
 */
function stdErr(values: number[]): number {
  const count = values.length;
  if (count < 2) return 0;

  const center = mean(values);
  let squaredDeviations = 0;
  values.forEach((v) => {
    squaredDeviations += (v - center) ** 2;
  });
  const variance = squaredDeviations / (count - 1);
  return Math.sqrt(variance) / Math.sqrt(count);
}
| // --------------------------------------------------------------------------- | |
| // Data fetching — looks up the parquet file URL on the datasets server, then reads that file directly (~1MB) | |
| // --------------------------------------------------------------------------- | |
/**
 * React hook that loads all rows of a dataset's parquet export.
 *
 * Steps: ask the datasets server for the dataset's parquet files, pick the
 * file whose split equals `_split` (falling back to the first file), then
 * read and parse that file with `hyparquet`.
 *
 * Each fetch is tagged with a request id. State is only updated by the most
 * recent request, so a slow earlier request (previous repo/split, an earlier
 * `refetch`, or a fetch still running at unmount) can no longer overwrite
 * newer results.
 *
 * @param repo Dataset repository id passed to the datasets server.
 * @param _split Preferred split name.
 * @returns `rows` (cast to `SummaryRow` without validation), `loading`,
 *   `error` (message or `null`), `progress`, and `refetch` to reload.
 */
function useDatasetRows(repo: string, _split: string) {
  const [rows, setRows] = useState<SummaryRow[]>([]);
  const [loading, setLoading] = useState(true);
  const [error, setError] = useState<string | null>(null);
  const [progress, setProgress] = useState({ loaded: 0, total: 0 });
  // Mutable box with a stable identity across renders (used like a ref, built
  // with useState so no additional React import is needed). It is only
  // mutated from callbacks/effects, never during render.
  const [latestRequest] = useState(() => ({ id: 0 }));

  const fetchAll = useCallback(async () => {
    const requestId = ++latestRequest.id;
    const isStale = () => latestRequest.id !== requestId;

    setLoading(true);
    setError(null);
    setRows([]);
    try {
      // Step 1: get the parquet file URL from HF datasets server
      const metaUrl = `${HF_DATASETS_API}/parquet?dataset=${encodeURIComponent(repo)}`;
      const metaResp = await fetch(metaUrl);
      if (!metaResp.ok) {
        throw new Error(`Failed to get parquet info: ${metaResp.status}`);
      }
      const metaData = await metaResp.json();
      if (isStale()) return;

      const parquetFiles = metaData.parquet_files ?? [];
      const matchingFile =
        parquetFiles.find((f: { split: string }) => f.split === _split) ?? parquetFiles[0];
      if (!matchingFile) {
        throw new Error("No parquet files found for this dataset");
      }
      setProgress({ loaded: 0, total: matchingFile.size ?? 0 });

      // Step 2: download the parquet file directly from HF (static file, no rate limits)
      const parquetUrl: string = matchingFile.url;
      const { asyncBufferFromUrl, parquetRead } = await import("hyparquet");
      const file = await asyncBufferFromUrl({ url: parquetUrl });

      // Step 3: parse all rows
      const allRows: SummaryRow[] = [];
      await parquetRead({
        file,
        rowFormat: "object",
        onComplete: (data: Record<string, unknown>[]) => {
          for (const row of data) {
            allRows.push(row as unknown as SummaryRow);
          }
        },
      });
      if (isStale()) return;

      setProgress({ loaded: allRows.length, total: allRows.length });
      setRows(allRows);
    } catch (err) {
      // Errors from a superseded request are not relevant to the current view.
      if (!isStale()) {
        setError(err instanceof Error ? err.message : String(err));
      }
    } finally {
      // Only the latest request may clear the loading flag; a stale one would
      // hide the spinner while the current request is still running.
      if (!isStale()) {
        setLoading(false);
      }
    }
  }, [repo, _split, latestRequest]);

  useEffect(() => {
    void fetchAll();
    return () => {
      // Invalidate the in-flight request on dependency change or unmount.
      latestRequest.id += 1;
    };
  }, [fetchAll, latestRequest]);

  return { rows, loading, error, progress, refetch: fetchAll };
}
| // --------------------------------------------------------------------------- | |
| // Aggregation | |
| // --------------------------------------------------------------------------- | |
/** Aggregate of one (row label, eval variant) cell. */
interface AggCell {
  /** Mean of the collected metric values. */
  mean: number;
  /** Standard error of the mean (0 when fewer than two values). */
  se: number;
  /** Number of values that went into the aggregate. */
  n: number;
}
/**
 * Checks that every row label stands for exactly one model identity.
 *
 * Two checks run in order:
 * 1. All rows must share one base model (`row.model`).
 * 2. Each label produced by `groupByRow` must map to a single training config
 *    (non-finetuned rows count as the pseudo-config `__base__`).
 *
 * A failure means results of different models would be averaged together
 * under one label.
 *
 * @param rows Rows about to be aggregated.
 * @param groupByRow Function that assigns each row its label.
 * @returns `null` when clean, otherwise a multi-line error message.
 */
function validateRowLabels(rows: SummaryRow[], groupByRow: (r: SummaryRow) => string): string | null {
  // Check 1: all rows must come from the same base model
  const distinctModels = Array.from(new Set(rows.map((r) => r.model)));
  if (distinctModels.length > 1) {
    return [
      `DATA INTEGRITY ERROR: Multiple base models in the same view!\n`,
      `Found ${distinctModels.length} different models being mixed together:\n`,
      distinctModels.map((m) => ` • ${m}`).join("\n"),
      `\n\nThis means results from different models are being averaged. `,
      `Filter by a single model before displaying.`,
    ].join("");
  }

  // Check 2: each row label maps to exactly one training_config
  const configsByLabel = new Map<string, Set<string>>();
  for (const row of rows) {
    const label = groupByRow(row);
    const identity = row.finetuned ? row.training_config : "__base__";
    let seen = configsByLabel.get(label);
    if (seen === undefined) {
      seen = new Set();
      configsByLabel.set(label, seen);
    }
    seen.add(identity);
  }

  const collisions = Array.from(configsByLabel.entries())
    .filter(([, configs]) => configs.size > 1)
    .map(([label, configs]) => `"${label}" → ${configs.size} configs: ${Array.from(configs).join(", ")}`);
  if (collisions.length === 0) return null;

  return [
    `DATA INTEGRITY ERROR: Row label collisions detected!\n`,
    `The following labels map to multiple different model configs `,
    `(their results are being silently averaged):\n\n`,
    collisions.join("\n"),
  ].join("");
}
/** Outcome of `aggregateData`. */
interface AggResult {
  /** Aggregated cells, keyed by row label and then by eval variant. */
  data: Map<string, Map<EvalVariant, AggCell>>;
  /** Message from the row-label validation, or `null` when the input was clean. */
  validationError: string | null;
}
/**
 * Aggregates one metric into a row-label × eval-variant table.
 *
 * Rows are skipped when their `eval_variant` is not a known variant or when
 * the metric value is null, undefined or NaN. Labels and variants appear in
 * the result only if at least one value was collected for them. The label
 * validation is run on the full input and reported alongside the data; it
 * does not stop aggregation.
 *
 * @param rows Rows to aggregate.
 * @param metric Metric column to read from each row.
 * @param groupByRow Function that assigns each row its label.
 * @returns Aggregated cells (mean, standard error, count) plus any validation error.
 */
function aggregateData(
  rows: SummaryRow[],
  metric: Metric,
  groupByRow: (r: SummaryRow) => string,
): AggResult {
  // Validate: each row label must correspond to one model
  const validationError = validateRowLabels(rows, groupByRow);

  // Collect raw values: rowLabel → evalVariant → values[]
  const samples = new Map<string, Map<EvalVariant, number[]>>();
  for (const row of rows) {
    const label = groupByRow(row);
    const variant = row.eval_variant as EvalVariant;
    if (!EVAL_VARIANTS.includes(variant)) continue;

    const value = row[metric];
    if (value == null || isNaN(value as number)) continue;

    let perVariant = samples.get(label);
    if (perVariant === undefined) {
      perVariant = new Map();
      samples.set(label, perVariant);
    }
    const bucket = perVariant.get(variant);
    if (bucket === undefined) perVariant.set(variant, [value as number]);
    else bucket.push(value as number);
  }

  // Reduce every bucket to its summary statistics
  const data = new Map<string, Map<EvalVariant, AggCell>>();
  samples.forEach((perVariant, label) => {
    const cells = new Map<EvalVariant, AggCell>();
    perVariant.forEach((values, variant) => {
      cells.set(variant, { mean: mean(values), se: stdErr(values), n: values.length });
    });
    data.set(label, cells);
  });

  return { data, validationError };
}
| // --------------------------------------------------------------------------- | |
| // Sort row labels: Base first, then by TC group, data regime, and training mode | |
| // --------------------------------------------------------------------------- | |
/**
 * Sort key for a row label: `[baseMarker, tcGroup, regime, mode, remaining]`.
 *
 * - baseMarker: -1 for "Base" (so it sorts first), 0 otherwise
 * - tcGroup: no-tc (0) < tcself (1) < tcneg (2) < tco (3)
 * - regime: lo (0) < semi (1) < other (2)
 * - mode (first segment): Pref (0) < SFT (1) < Comb (2) < other (3)
 * - remaining: all segments except the mode names, re-joined with "-", used as tiebreak
 */
function rowSortKey(label: string): [number, number, number, number, string] {
  if (label === "Base") return [-1, 0, 0, 0, ""];

  const segments = label.split("-");
  const has = (segment: string): boolean => segments.includes(segment);

  const tcGroup = has("tcself") ? 1 : has("tcneg") ? 2 : has("tco") ? 3 : 0;
  const regime = has("lo") ? 0 : has("semi") ? 1 : 2;

  const modeOrder: Record<string, number> = { Pref: 0, SFT: 1, Comb: 2 };
  const modeRank = modeOrder[segments[0] ?? ""] ?? 3;

  const remaining = segments
    .filter((segment) => !["Pref", "SFT", "Comb"].includes(segment))
    .join("-");

  return [0, tcGroup, regime, modeRank, remaining];
}

/**
 * Returns a sorted copy of `labels` (the input array is left untouched),
 * ordered by the numeric parts of `rowSortKey` and then by locale comparison
 * of the tiebreak string.
 */
function sortRowLabels(labels: string[]): string[] {
  // Compute each key once, sort the decorated entries, then unwrap.
  const keyed = labels.map((label) => ({ label, key: rowSortKey(label) }));
  keyed.sort((left, right) => {
    for (let i = 0; i < 4; i++) {
      const delta = (left.key[i] as number) - (right.key[i] as number);
      if (delta !== 0) return delta;
    }
    return left.key[4].localeCompare(right.key[4]);
  });
  return keyed.map((entry) => entry.label);
}
| // --------------------------------------------------------------------------- | |
| // Sub-components | |
| // --------------------------------------------------------------------------- | |
/**
 * Labelled `<select>` for choosing one value out of `options`.
 *
 * @param label Caption shown above the select.
 * @param value Currently selected option key.
 * @param options Selectable entries (`key` is the value, `label` the text).
 * @param onChange Called with the newly selected key.
 */
function Dropdown<T extends string>(props: {
  label: string;
  value: T;
  options: { key: T; label: string }[];
  onChange: (v: T) => void;
}) {
  const { label, value, options, onChange } = props;

  const optionNodes = options.map((opt) => (
    <option key={opt.key} value={opt.key}>
      {opt.label}
    </option>
  ));

  return (
    <div className="flex flex-col gap-1">
      <label className="text-[10px] uppercase tracking-wider text-gray-500 font-medium">{label}</label>
      <select
        value={value}
        // The DOM only hands back a string; the options are built from T keys.
        onChange={(event) => onChange(event.target.value as T)}
        className="bg-gray-800 text-gray-200 text-xs border border-gray-700 rounded px-2 py-1.5 focus:outline-none focus:border-cyan-600"
      >
        {optionNodes}
      </select>
    </div>
  );
}
/**
 * Labelled group of toggle buttons for choosing any subset of `options`.
 *
 * @param label Caption shown above the buttons.
 * @param selected Keys that are currently selected.
 * @param options Selectable entries (`key` is the value, `label` the text).
 * @param onChange Called with a new Set reflecting the toggled key; the
 *   `selected` set passed in is never mutated.
 */
function MultiSelect<T extends string>({
  label,
  selected,
  options,
  onChange,
}: {
  label: string;
  selected: Set<T>;
  options: { key: T; label: string }[];
  onChange: (s: Set<T>) => void;
}) {
  const toggle = (key: T) => {
    const next = new Set(selected);
    if (next.has(key)) next.delete(key);
    else next.add(key);
    onChange(next);
  };
  return (
    <div className="flex flex-col gap-1">
      <label className="text-[10px] uppercase tracking-wider text-gray-500 font-medium">{label}</label>
      <div className="flex flex-wrap gap-1">
        {options.map((o) => {
          const isSelected = selected.has(o.key);
          return (
            <button
              key={o.key}
              // Without an explicit type a <button> defaults to "submit" and
              // would submit any enclosing form when toggled.
              type="button"
              // Expose the toggle state to assistive technology.
              aria-pressed={isSelected}
              onClick={() => toggle(o.key)}
              className={`text-xs px-2 py-1 rounded border transition-colors ${
                isSelected
                  ? "bg-cyan-800/60 text-cyan-200 border-cyan-600/60"
                  : "bg-gray-800 text-gray-500 border-gray-700 hover:text-gray-300"
              }`}
            >
              {o.label}
            </button>
          );
        })}
      </div>
    </div>
  );
}
| // --------------------------------------------------------------------------- | |
| // Heatmap component (pure HTML/CSS) | |
| // --------------------------------------------------------------------------- | |
/**
 * Heatmap rendered as a plain HTML table: one row per row label, one column
 * per eval variant, each cell colored by its mean value.
 *
 * Metrics whose name starts with "corr" are colored on the -1..1 scale and
 * printed with three decimals; all other metrics are colored on the 0..1
 * scale and printed as percentages with one decimal. Cells without data show
 * "-" and a "no data" tooltip.
 *
 * @param data Aggregated cells keyed by row label, then eval variant.
 * @param rowLabels Row labels to render, in display order.
 * @param colLabels Eval variants to render as columns, in display order.
 * @param metric Metric being shown (decides coloring and number format).
 * @param title Heading shown above the table.
 * @param fullConfigs Optional tooltip text per row label; the label itself is used when absent.
 */
function HeatmapGrid({
  data,
  rowLabels,
  colLabels,
  metric,
  title,
  fullConfigs,
}: {
  data: Map<string, Map<EvalVariant, AggCell>>;
  rowLabels: string[];
  colLabels: EvalVariant[];
  metric: Metric;
  title: string;
  fullConfigs: Map<string, string>;
}) {
  const isCorr = metric.startsWith("corr");
  const backgroundFor = isCorr ? corrColor : rdYlGn;

  const fmt = (v: number): string => {
    if (isNaN(v)) return "-";
    return isCorr ? v.toFixed(3) : (v * 100).toFixed(1);
  };

  const tooltipFor = (cell: AggCell | undefined, shown: number): string => {
    if (!cell) return "no data";
    // Count and standard error are only informative for more than one sample.
    const detail = cell.n > 1 ? ` (n=${cell.n}, se=${fmt(cell.se)})` : "";
    return `${fmt(shown)}${detail}`;
  };

  const renderCell = (cells: Map<EvalVariant, AggCell> | undefined, col: EvalVariant) => {
    const cell = cells?.get(col);
    const shown = cell?.mean ?? NaN;
    return (
      <td
        key={col}
        className="text-center py-1.5 px-2 font-mono border border-gray-800/50 cursor-default transition-all hover:ring-1 hover:ring-cyan-500/50"
        style={{ backgroundColor: backgroundFor(shown), color: textColor(shown, isCorr) }}
        title={tooltipFor(cell, shown)}
      >
        {fmt(shown)}
      </td>
    );
  };

  const renderRow = (rowLabel: string) => {
    const cells = data.get(rowLabel);
    return (
      <tr key={rowLabel} className="group">
        <td
          className="py-1 px-2 text-gray-300 font-mono truncate max-w-[240px]"
          title={fullConfigs.get(rowLabel) || rowLabel}
        >
          <DisplayRowLabel label={rowLabel} />
        </td>
        {colLabels.map((col) => renderCell(cells, col))}
      </tr>
    );
  };

  return (
    <div className="flex flex-col gap-2">
      <h3 className="text-xs font-medium text-gray-300">{title}</h3>
      <div className="overflow-x-auto">
        <table className="border-collapse text-xs">
          <thead>
            <tr>
              <th className="text-left py-1 px-2 text-gray-500 font-normal min-w-[160px] max-w-[240px]">Model</th>
              {colLabels.map((col) => (
                <th key={col} className="text-center py-1 px-3 text-gray-400 font-medium min-w-[72px]">
                  {col}
                </th>
              ))}
            </tr>
          </thead>
          <tbody>{rowLabels.map(renderRow)}</tbody>
        </table>
      </div>
    </div>
  );
}
| // --------------------------------------------------------------------------- | |
| // Bar chart component (SVG) | |
| // --------------------------------------------------------------------------- | |
| function BarChart({ | |
| data, | |
| rowLabels, | |
| evalVariants, | |
| metric, | |
| title, | |
| }: { | |
| data: Map<string, Map<EvalVariant, AggCell>>; | |
| rowLabels: string[]; | |
| evalVariants: EvalVariant[]; | |
| metric: Metric; | |
| title: string; | |
| }) { | |
| const isCorr = metric.startsWith("corr"); | |
| const subBarWidth = 18; | |
| const subGap = 2; | |
| const groupGap = 16; | |
| const numSub = evalVariants.length; | |
| const groupWidth = numSub * subBarWidth + (numSub - 1) * subGap; | |
| const chartHeight = 200; | |
| const marginTop = 20; | |
| const marginBottom = 80; | |
| const marginLeft = 50; | |
| const legendHeight = 24; | |
| const svgWidth = marginLeft + rowLabels.length * (groupWidth + groupGap) + 20; | |
| const svgHeight = chartHeight + marginTop + marginBottom + legendHeight; | |
| // Collect all values for scale | |
| const allCells: { mean: number; se: number }[] = []; | |
| for (const rl of rowLabels) { | |
| for (const ev of evalVariants) { | |
| const cell = data.get(rl)?.get(ev); | |
| if (cell && !isNaN(cell.mean)) allCells.push(cell); | |
| } | |
| } | |
| if (allCells.length === 0) { | |
| return ( | |
| <div className="flex flex-col gap-2"> | |
| <h3 className="text-xs font-medium text-gray-300">{title}</h3> | |
| <p className="text-xs text-gray-500 italic">No data</p> | |
| </div> | |
| ); | |
| } | |
| let minVal: number, maxVal: number; | |
| const low = Math.min(...allCells.map((c) => c.mean - c.se)); | |
| const high = Math.max(...allCells.map((c) => c.mean + c.se)); | |
| if (isCorr) { | |
| // Data-driven range with padding, clamped to [-1, 1] | |
| minVal = Math.max(-1, low - 0.05); | |
| maxVal = Math.min(1, high + 0.05); | |
| } else { | |
| minVal = Math.max(0, low - 0.05); | |
| maxVal = Math.min(1, high + 0.05); | |
| } | |
| const range = maxVal - minVal || 1; | |
| const yScale = (v: number) => marginTop + chartHeight * (1 - (v - minVal) / range); | |
| // Tick marks | |
| const ticks: number[] = []; | |
| const corrRange = maxVal - minVal; | |
| const step = isCorr ? (corrRange > 0.5 ? 0.2 : 0.1) : 0.1; | |
| for (let t = Math.ceil(minVal / step) * step; t <= maxVal; t += step) { | |
| ticks.push(Math.round(t * 1000) / 1000); | |
| } | |
| // Shade: raw=full, tc=darker, lenorm=lighter, tc+lenorm=darkest | |
| const variantOpacity: Record<string, number> = { raw: 0.85, tc: 0.6, lenorm: 0.7, "tc+lenorm": 0.45 }; | |
| return ( | |
| <div className="flex flex-col gap-2"> | |
| <h3 className="text-xs font-medium text-gray-300">{title}</h3> | |
| <div className="overflow-x-auto"> | |
| <svg width={svgWidth} height={svgHeight} className="block"> | |
| {/* Hatching patterns for tc variants */} | |
| <defs> | |
| <pattern id="hatch-tc" patternUnits="userSpaceOnUse" width="4" height="4" patternTransform="rotate(45)"> | |
| <line x1="0" y1="0" x2="0" y2="4" stroke="rgba(0,0,0,0.3)" strokeWidth="1" /> | |
| </pattern> | |
| <pattern id="hatch-tclenorm" patternUnits="userSpaceOnUse" width="4" height="4" patternTransform="rotate(-45)"> | |
| <line x1="0" y1="0" x2="0" y2="4" stroke="rgba(0,0,0,0.4)" strokeWidth="1.5" /> | |
| </pattern> | |
| </defs> | |
| {/* Grid lines */} | |
| {ticks.map((t) => ( | |
| <g key={t}> | |
| <line x1={marginLeft} x2={svgWidth - 10} y1={yScale(t)} y2={yScale(t)} stroke="#374151" strokeWidth={0.5} /> | |
| <text x={marginLeft - 4} y={yScale(t) + 3} textAnchor="end" className="fill-gray-500 text-[10px]"> | |
| {isCorr ? t.toFixed(1) : (t * 100).toFixed(0)} | |
| </text> | |
| </g> | |
| ))} | |
| {/* Grouped bars */} | |
| {rowLabels.map((rl, i) => { | |
| const groupX = marginLeft + i * (groupWidth + groupGap); | |
| const color = BAR_COLORS[i % BAR_COLORS.length]; | |
| const labelX = groupX + groupWidth / 2; | |
| return ( | |
| <g key={rl}> | |
| {evalVariants.map((ev, j) => { | |
| const cell = data.get(rl)?.get(ev); | |
| const val = cell?.mean ?? NaN; | |
| if (isNaN(val)) return null; | |
| const x = groupX + j * (subBarWidth + subGap); | |
| const barY = yScale(val); | |
| const baseY = yScale(isCorr ? 0 : minVal); | |
| const barH = Math.abs(baseY - barY); | |
| const actualY = Math.min(barY, baseY); | |
| const opacity = variantOpacity[ev] ?? 0.85; | |
| const hasHatch = ev === "tc" || ev === "tc+lenorm"; | |
| const cx = x + subBarWidth / 2; | |
| const se = cell?.se ?? 0; | |
| const seTop = yScale(Math.min(maxVal, val + se)); | |
| const seBot = yScale(Math.max(minVal, val - se)); | |
| return ( | |
| <g key={ev}> | |
| <rect x={x} y={actualY} width={subBarWidth} height={barH} fill={color} rx={1} opacity={opacity} /> | |
| {hasHatch && ( | |
| <rect x={x} y={actualY} width={subBarWidth} height={barH} fill={ev === "tc" ? "url(#hatch-tc)" : "url(#hatch-tclenorm)"} rx={1} /> | |
| )} | |
| {se > 0 && ( | |
| <> | |
| <line x1={cx} x2={cx} y1={seTop} y2={seBot} stroke="#e5e7eb" strokeWidth={1} /> | |
| <line x1={cx - 3} x2={cx + 3} y1={seTop} y2={seTop} stroke="#e5e7eb" strokeWidth={1} /> | |
| <line x1={cx - 3} x2={cx + 3} y1={seBot} y2={seBot} stroke="#e5e7eb" strokeWidth={1} /> | |
| </> | |
| )} | |
| <title>{`${displayRowLabel(rl)} [${ev}]: ${isCorr ? val.toFixed(3) : (val * 100).toFixed(1)} (n=${cell?.n ?? 0}, se=${se.toFixed(4)})`}</title> | |
| </g> | |
| ); | |
| })} | |
| {/* Model label */} | |
| <text | |
| x={labelX} | |
| y={marginTop + chartHeight + 8} | |
| textAnchor="end" | |
| transform={`rotate(-45, ${labelX}, ${marginTop + chartHeight + 8})`} | |
| className="fill-gray-400 text-[9px]" | |
| > | |
| <SvgRowLabel label={rl} maxLen={22} /> | |
| </text> | |
| </g> | |
| ); | |
| })} | |
| {/* Legend */} | |
| {evalVariants.map((ev, j) => { | |
| const lx = marginLeft + j * 80; | |
| const ly = svgHeight - 10; | |
| const opacity = variantOpacity[ev] ?? 0.85; | |
| const hasHatch = ev === "tc" || ev === "tc+lenorm"; | |
| return ( | |
| <g key={ev}> | |
| <rect x={lx} y={ly - 8} width={12} height={8} fill="#636EFA" opacity={opacity} rx={1} /> | |
| {hasHatch && <rect x={lx} y={ly - 8} width={12} height={8} fill={ev === "tc" ? "url(#hatch-tc)" : "url(#hatch-tclenorm)"} rx={1} />} | |
| <text x={lx + 16} y={ly} className="fill-gray-400 text-[9px]">{ev}</text> | |
| </g> | |
| ); | |
| })} | |
| </svg> | |
| </div> | |
| </div> | |
| ); | |
| } | |
| // --------------------------------------------------------------------------- | |
| // Delta comparison view: side-by-side heatmaps + delta | |
| // --------------------------------------------------------------------------- | |
// Diverging color scale for deltas: red (negative) ↔ white (zero) ↔ blue (positive).
// Uses Plotly's RdBu stops (reversed so blue = positive).

/** Absolute delta at which the diverging scale saturates (0.10 = ±10pp). */
const DELTA_SATURATION = 0.1;

/** Stops for negative deltas, ordered from the saturated end to zero. */
const DELTA_NEG: [number, number, number][] = [
  [103, 0, 31], // -max: deep red
  [178, 24, 43],
  [214, 96, 77],
  [244, 165, 130],
  [253, 219, 199],
  [255, 255, 255], // zero: white
];

/** Stops for positive deltas, ordered from zero to the saturated end. */
const DELTA_POS: [number, number, number][] = [
  [255, 255, 255], // zero: white
  [209, 229, 240],
  [146, 197, 222],
  [67, 147, 195],
  [33, 102, 172],
  [5, 48, 97], // +max: deep blue
];

/**
 * Maps a signed delta onto the diverging red/white/blue scale.
 *
 * @param delta - Signed difference; NaN yields a neutral gray placeholder.
 * @param maxDelta - Absolute delta at which the color saturates. Values that
 *   are not strictly positive (including NaN) fall back to DELTA_SATURATION.
 * @returns A CSS color string (`rgb(r, g, b)`, or `#f3f4f6` for NaN).
 */
function deltaColor(delta: number, maxDelta: number = DELTA_SATURATION): string {
  if (isNaN(delta)) return "#f3f4f6";
  // Guard against a zero/negative/NaN limit, which would produce a NaN stop index.
  const limit = maxDelta > 0 ? maxDelta : DELTA_SATURATION;
  const t = Math.min(Math.abs(delta) / limit, 1); // 0..1
  const negative = delta < 0;
  const stops = negative ? DELTA_NEG : DELTA_POS;
  const last = stops.length - 1;
  // Position along the stop list; the segment index is capped so that t === 1
  // interpolates fully into the final stop instead of running off the end.
  const pos = t * last;
  const seg = Math.min(Math.floor(pos), last - 1);
  const frac = pos - seg;
  // DELTA_NEG is stored saturated-first, so walk it backwards from white.
  const from = negative ? stops[last - seg] : stops[seg];
  const to = negative ? stops[last - seg - 1] : stops[seg + 1];
  const channel = (k: 0 | 1 | 2): number => Math.round(from[k] + frac * (to[k] - from[k]));
  return `rgb(${channel(0)}, ${channel(1)}, ${channel(2)})`;
}

/**
 * Picks a readable text color for a cell painted with deltaColor.
 *
 * @param delta - Signed difference; NaN yields a muted gray.
 * @param maxDelta - Same saturation limit as passed to deltaColor; values that
 *   are not strictly positive fall back to DELTA_SATURATION.
 * @returns White once the background is past 60% of saturation, dark gray otherwise.
 */
function deltaTextColor(delta: number, maxDelta: number = DELTA_SATURATION): string {
  if (isNaN(delta)) return "#9ca3af";
  const limit = maxDelta > 0 ? maxDelta : DELTA_SATURATION;
  const t = Math.min(Math.abs(delta) / limit, 1);
  return t > 0.6 ? "#ffffff" : "#1f2937";
}
/**
 * Renders three tables side by side for a list of paired row labels:
 * the left-hand values, the right-hand values, and the per-cell delta
 * (right − left).
 *
 * Props:
 * - pairs: [leftRowLabel, rightRowLabel] tuples; one table row per pair.
 * - leftData / rightData: aggregated cells keyed by row label, then eval variant.
 * - colLabels: eval variants to show as columns, in order.
 * - metric: selects formatting (correlation metrics are shown as 3-decimal
 *   values, all others as percentages with 1 decimal) and the color function.
 * - title, leftLabel, rightLabel: headings.
 * - fullConfigs: row label → tooltip text; falls back to the label itself.
 */
function SideBySideDelta({
  pairs,
  leftData,
  rightData,
  colLabels,
  metric,
  title,
  leftLabel,
  rightLabel,
  fullConfigs,
}: {
  pairs: [string, string][]; // [leftRowLabel, rightRowLabel]
  leftData: Map<string, Map<EvalVariant, AggCell>>;
  rightData: Map<string, Map<EvalVariant, AggCell>>;
  colLabels: EvalVariant[];
  metric: Metric;
  title: string;
  leftLabel: string;
  rightLabel: string;
  fullConfigs: Map<string, string>;
}) {
  const isCorr = metric.startsWith("corr");
  const formatVal = (v: number) => {
    if (isNaN(v)) return "-";
    return isCorr ? v.toFixed(3) : (v * 100).toFixed(1);
  };
  const formatDelta = (v: number) => {
    if (isNaN(v)) return "-";
    const pp = isCorr ? v : v * 100;
    const sign = pp > 0 ? "+" : "";
    return isCorr ? `${sign}${pp.toFixed(3)}` : `${sign}${pp.toFixed(1)}`;
  };
  const colorFn = isCorr ? corrColor : rdYlGn;

  // Column headers are identical in all three tables.
  const renderHeaderCells = () =>
    colLabels.map((col) => (
      <th key={col} className="text-center py-1 px-3 text-gray-400 font-medium min-w-[60px]">{col}</th>
    ));

  // Value cells for one row of the left or right table; a missing cell renders as "-".
  const renderValueCells = (evMap: Map<EvalVariant, AggCell> | undefined) =>
    colLabels.map((col) => {
      const val = evMap?.get(col)?.mean ?? NaN;
      return (
        <td
          key={col}
          className="text-center py-1.5 px-2 font-mono border border-gray-800/50"
          style={{ backgroundColor: colorFn(val), color: textColor(val, isCorr) }}
        >
          {formatVal(val)}
        </td>
      );
    });

  return (
    <div className="flex flex-col gap-3">
      <h3 className="text-xs font-medium text-gray-300">{title}</h3>
      <div className="flex gap-6 overflow-x-auto">
        {/* Left heatmap */}
        <div>
          <p className="text-[10px] uppercase tracking-wider text-gray-500 font-medium mb-1">{leftLabel}</p>
          <table className="border-collapse text-xs">
            <thead>
              <tr>
                <th className="text-left py-1 px-2 text-gray-500 font-normal min-w-[120px]">Model</th>
                {renderHeaderCells()}
              </tr>
            </thead>
            <tbody>
              {pairs.map(([leftLabel_]) => (
                <tr key={leftLabel_}>
                  <td className="py-1 px-2 text-gray-300 font-mono text-[11px]" title={fullConfigs.get(leftLabel_) || leftLabel_}><DisplayRowLabel label={leftLabel_} /></td>
                  {renderValueCells(leftData.get(leftLabel_))}
                </tr>
              ))}
            </tbody>
          </table>
        </div>
        {/* Right heatmap */}
        <div>
          <p className="text-[10px] uppercase tracking-wider text-gray-500 font-medium mb-1">{rightLabel}</p>
          <table className="border-collapse text-xs">
            <thead>
              <tr>
                <th className="text-left py-1 px-1.5 text-gray-500 font-normal min-w-[80px]">Model</th>
                {renderHeaderCells()}
              </tr>
            </thead>
            <tbody>
              {pairs.map(([, rightLabel_]) => (
                <tr key={rightLabel_}>
                  <td className="py-1 px-1.5 text-gray-300 font-mono text-[9px] whitespace-nowrap" title={fullConfigs.get(rightLabel_) || rightLabel_}><DisplayRowLabel label={rightLabel_} /></td>
                  {renderValueCells(rightData.get(rightLabel_))}
                </tr>
              ))}
            </tbody>
          </table>
        </div>
        {/* Delta heatmap */}
        <div>
          <p className="text-[10px] uppercase tracking-wider text-gray-500 font-medium mb-1">Delta ({rightLabel} − {leftLabel})</p>
          <table className="border-collapse text-xs">
            <thead>
              <tr>
                <th className="text-left py-1 px-2 text-gray-500 font-normal min-w-[120px]">Model</th>
                {renderHeaderCells()}
              </tr>
            </thead>
            <tbody>
              {pairs.map(([leftLabel_, rightLabel_]) => {
                const leftEvMap = leftData.get(leftLabel_);
                const rightEvMap = rightData.get(rightLabel_);
                return (
                  <tr key={leftLabel_}>
                    {/* Delta rows are labelled with the left-hand row label. */}
                    <td className="py-1 px-2 text-gray-300 font-mono text-[11px]"><DisplayRowLabel label={leftLabel_} /></td>
                    {colLabels.map((col) => {
                      const leftVal = leftEvMap?.get(col)?.mean ?? NaN;
                      const rightVal = rightEvMap?.get(col)?.mean ?? NaN;
                      // A delta is only defined when both sides have a value.
                      const delta = !isNaN(leftVal) && !isNaN(rightVal) ? rightVal - leftVal : NaN;
                      return (
                        <td
                          key={col}
                          className="text-center py-1.5 px-2 font-mono border border-gray-800/50"
                          style={{ backgroundColor: deltaColor(delta), color: deltaTextColor(delta) }}
                        >
                          {formatDelta(delta)}
                        </td>
                      );
                    })}
                  </tr>
                );
              })}
            </tbody>
          </table>
        </div>
      </div>
    </div>
  );
}
| // --------------------------------------------------------------------------- | |
| // Per-task collapsible section | |
| // --------------------------------------------------------------------------- | |
/**
 * Collapsible "Per-task breakdown" panel. When expanded, it renders one group
 * per section and, within each group, one HeatmapGrid per metric that has at
 * least one row in `visibleRows`.
 */
function PerTaskCollapsible({
  sections,
  fullConfigs,
  visibleRows,
}: {
  sections: { label: string; taskCount: number; metrics: { metric: Metric; data: Map<string, Map<EvalVariant, AggCell>> }[] }[];
  fullConfigs: Map<string, string>;
  visibleRows: Set<string>;
}) {
  const [open, setOpen] = useState(false);
  const toggle = () => setOpen((prev) => !prev);
  const arrowClass = `text-xs transition-transform ${open ? "rotate-90" : ""}`;

  // One heatmap for a single metric of a section, or null if no row is visible.
  const renderMetric = ({ metric: metricKey, data }: { metric: Metric; data: Map<string, Map<EvalVariant, AggCell>> }) => {
    const shownRows = sortRowLabels(Array.from(data.keys()).filter((label) => visibleRows.has(label)));
    if (shownRows.length === 0) return null;
    const heading = METRICS.find((entry) => entry.key === metricKey)?.label ?? metricKey;
    // Keep only eval variants that occur in at least one row of this metric's data.
    const variantMaps = Array.from(data.values());
    const columns = EVAL_VARIANTS.filter((ev) => variantMaps.some((perRow) => perRow.has(ev)));
    return (
      <HeatmapGrid
        key={metricKey}
        data={data}
        rowLabels={shownRows}
        colLabels={columns}
        metric={metricKey}
        title={heading}
        fullConfigs={fullConfigs}
      />
    );
  };

  return (
    <div className="border-t border-gray-700 pt-4">
      <button
        onClick={toggle}
        className="flex items-center gap-2 text-sm font-semibold text-gray-300 hover:text-gray-100 transition-colors"
      >
        <span className={arrowClass}>▶</span>
        Per-task breakdown ({sections.length} tasks)
      </button>
      {open ? (
        <div className="mt-4 space-y-10">
          {sections.map((section) => (
            <div key={section.label} className="space-y-4">
              <h3 className="text-sm font-medium text-cyan-300 border-b border-gray-800 pb-1">{section.label}</h3>
              <div className="flex flex-wrap gap-6">{section.metrics.map(renderMetric)}</div>
            </div>
          ))}
        </div>
      ) : null}
    </div>
  );
}
| // --------------------------------------------------------------------------- | |
| // Per-task comparison collapsible (side-by-side + delta per task) | |
| // --------------------------------------------------------------------------- | |
/**
 * Collapsible per-task comparison panel. For every task present in
 * `filteredRows` it splits the task's rows with `leftFilter` / `rightFilter`,
 * aggregates each side per metric, and renders a SideBySideDelta for every
 * metric where at least one pair has data on both sides.
 * Renders nothing when there is at most one task.
 */
function PerTaskComparisonCollapsible({
  filteredRows,
  pairs,
  leftLabel,
  rightLabel,
  leftFilter,
  rightFilter,
  fullConfigs,
}: {
  filteredRows: SummaryRow[];
  pairs: [string, string][];
  leftLabel: string;
  rightLabel: string;
  leftFilter: (r: SummaryRow) => boolean;
  rightFilter: (r: SummaryRow) => boolean;
  fullConfigs: Map<string, string>;
}) {
  const [open, setOpen] = useState(false);
  // Distinct task names, sorted.
  const tasks = useMemo(
    () => Array.from(new Set(filteredRows.map((row) => row.task))).sort(),
    [filteredRows],
  );

  if (tasks.length <= 1) return null;

  // All metric comparisons for a single task.
  const renderTask = (task: string) => {
    const rowsForTask = filteredRows.filter((row) => row.task === task);
    const leftRows = rowsForTask.filter(leftFilter);
    const rightRows = rowsForTask.filter(rightFilter);
    return (
      <div key={task} className="space-y-4">
        <h3 className="text-sm font-medium text-cyan-300 border-b border-gray-800 pb-1">{task}</h3>
        <div className="space-y-6">
          {METRICS.map((entry) => {
            const leftData = aggregateData(leftRows, entry.key, getRowLabel).data;
            const rightData = aggregateData(rightRows, entry.key, getRowLabel).data;
            // Keep only pairs that have data on both sides for this task and metric.
            const presentPairs = pairs.filter(([l, r]) => leftData.has(l) && rightData.has(r));
            if (presentPairs.length === 0) return null;
            // Columns: eval variants that appear on either side.
            const variantMaps = [...leftData.values(), ...rightData.values()];
            const columns = EVAL_VARIANTS.filter((ev) => variantMaps.some((perRow) => perRow.has(ev)));
            return (
              <SideBySideDelta
                key={entry.key}
                pairs={presentPairs}
                leftData={leftData}
                rightData={rightData}
                colLabels={columns}
                metric={entry.key}
                title={entry.label}
                leftLabel={leftLabel}
                rightLabel={rightLabel}
                fullConfigs={fullConfigs}
              />
            );
          })}
        </div>
      </div>
    );
  };

  return (
    <div className="border-t border-gray-700 pt-4">
      <button
        onClick={() => setOpen((prev) => !prev)}
        className="flex items-center gap-2 text-sm font-semibold text-gray-300 hover:text-gray-100 transition-colors"
      >
        <span className={`text-xs transition-transform ${open ? "rotate-90" : ""}`}>▶</span>
        Per-task breakdown ({tasks.length} tasks)
      </button>
      {open ? <div className="mt-4 space-y-10">{tasks.map(renderTask)}</div> : null}
    </div>
  );
}
| // --------------------------------------------------------------------------- | |
| // Main component | |
| // --------------------------------------------------------------------------- | |
| export default function HeatmapViewer({ | |
| datasetRepo, | |
| split: _split = "train", | |
| onClose, | |
| }: HeatmapViewerProps) { | |
| const fullRepo = datasetRepo.includes("/") ? datasetRepo : `${HF_ORG}/${datasetRepo}`; | |
| const shortName = datasetRepo.split("/").pop() ?? datasetRepo; | |
| const { rows, loading, error, progress, refetch } = useDatasetRows(fullRepo, _split); | |
| // --- Filter state --- | |
| const [selectedModel, setSelectedModel] = useState<string>("__first__"); | |
| const [selectedFamily, setSelectedFamily] = useState<string>("hypernym"); | |
| const [selectedTask, setSelectedTask] = useState<string>("__all__"); | |
| const [selectedSplit, setSelectedSplit] = useState<string>("test"); | |
| const [selectedTCType, setSelectedTCType] = useState<TCType>("self"); | |
| const [selectedMetric, setSelectedMetric] = useState<Metric>("gen_roc"); | |
| const [selectedDomain, setSelectedDomain] = useState<DomainFilter>("all"); | |
| const [viewMode, setViewMode] = useState<ViewMode>("heatmap"); | |
| const [barVariant, setBarVariant] = useState<EvalVariant>("raw"); | |
| const [comparisonPreset, setComparisonPreset] = useState<ComparisonPreset>("all"); | |
| // Row visibility: null means "show all", otherwise explicit set from preset or manual toggle | |
| const [visibleRows, setVisibleRows] = useState<Set<string> | null>(null); | |
| // --- Derived: available splits, tasks, families --- | |
| // --- Available base models (model column = base model identity, same for all finetuned variants) --- | |
| const availableModels = useMemo(() => { | |
| const s = new Set(rows.map((r) => r.model)); | |
| return Array.from(s).sort(); | |
| }, [rows]); | |
| // Display-friendly name: "v6-google_gemma-2-2b" → "gemma-2-2b" | |
| const modelDisplayName = useCallback((m: string) => { | |
| return m.replace(/^v\d+-[^_]+_/, ""); | |
| }, []); | |
| // Resolve __first__ to actual first model once data loads | |
| const resolvedModel = useMemo(() => { | |
| if (selectedModel === "__first__" && availableModels.length > 0) return availableModels[0]; | |
| if (availableModels.includes(selectedModel)) return selectedModel; | |
| return availableModels[0] ?? ""; | |
| }, [selectedModel, availableModels]); | |
| const availableSplits = useMemo(() => { | |
| const s = new Set(rows.map((r) => r.split)); | |
| return Array.from(s).sort(); | |
| }, [rows]); | |
| const availableTasks = useMemo(() => { | |
| let filtered = rows.filter((r) => isValidEvalTask(r.task)); | |
| const fam = TASK_FAMILIES.find((f) => f.key === selectedFamily); | |
| if (fam) filtered = filtered.filter((r) => r.task.startsWith(fam.prefix)); | |
| const tasks = new Set(filtered.map((r) => r.task)); | |
| return Array.from(tasks).sort(); | |
| }, [rows, selectedFamily]); | |
| // Reset task selection when family changes | |
| useEffect(() => { | |
| setSelectedTask("__all__"); | |
| }, [selectedFamily]); | |
| // --- Filtering pipeline --- | |
| const filteredRows = useMemo(() => { | |
| let result = rows; | |
| // Model filter — critical: never mix different base models | |
| if (resolvedModel) { | |
| result = result.filter((r) => r.model === resolvedModel); | |
| } | |
| // Exclude non-eval tasks (concat, bare family names, etc.) | |
| result = result.filter((r) => isValidEvalTask(r.task)); | |
| // Only show finetuned models trained on valid combined datasets (Setting U) | |
| result = result.filter((r) => isValidFinetunedModel(r)); | |
| // Split filter | |
| result = result.filter((r) => r.split === selectedSplit); | |
| // TC type filter (single-select: each TC type is a different eval condition) | |
| result = result.filter((r) => getEvalTCType(r) === selectedTCType); | |
| // For finetuned models trained with a specific TC, only show rows where eval TC matches training TC | |
| result = result.filter((r) => isMatchedTC(r)); | |
| // Family filter — also exclude finetuned models trained on a different family | |
| const fam = TASK_FAMILIES.find((f) => f.key === selectedFamily); | |
| if (fam) { | |
| result = result.filter((r) => r.task.startsWith(fam.prefix)); | |
| result = result.filter((r) => { | |
| if (!r.finetuned) return true; // base models always shown | |
| const trainFam = getTrainingFamily(r.training_config); | |
| return trainFam === null || trainFam === selectedFamily; | |
| }); | |
| } | |
| // Specific task filter | |
| if (selectedTask !== "__all__") { | |
| result = result.filter((r) => r.task === selectedTask); | |
| } | |
| // Domain filter (OOD vs in-domain) | |
| if (selectedDomain !== "all") { | |
| result = result.filter((r) => { | |
| const family = getTaskFamily(r.task); | |
| if (!family) return true; | |
| const ood = isOodTask(r.task, r.training_config, family); | |
| if (ood === null) return true; | |
| return selectedDomain === "ood" ? ood : !ood; | |
| }); | |
| } | |
| return result; | |
| }, [rows, resolvedModel, selectedSplit, selectedTCType, selectedFamily, selectedTask, selectedDomain]); | |
| // --- Build row label → full config mapping --- | |
| const fullConfigs = useMemo(() => { | |
| const map = new Map<string, string>(); | |
| for (const r of filteredRows) { | |
| const label = getRowLabel(r); | |
| if (!map.has(label)) { | |
| map.set(label, r.finetuned ? r.training_config : "Base (not finetuned)"); | |
| } | |
| } | |
| return map; | |
| }, [filteredRows]); | |
| // --- All available row labels (before visibility filter) --- | |
| const allRowLabels = useMemo(() => { | |
| const labels = new Set<string>(); | |
| for (const r of filteredRows) labels.add(getRowLabel(r)); | |
| return sortRowLabels(Array.from(labels)); | |
| }, [filteredRows]); | |
| // Effective visible rows: use explicit selection, preset, or show all | |
| const effectiveVisibleRows = useMemo(() => { | |
| const available = new Set(allRowLabels); | |
| // If a comparison preset restricts rows, apply that | |
| if (comparisonPreset === "training-effect" && visibleRows === null) { | |
| const effective = new Set<string>(); | |
| for (const label of available) { | |
| if (TRAINING_EFFECT_ROWS.has(label)) effective.add(label); | |
| } | |
| return effective.size > 0 ? effective : available; | |
| } | |
| // Manual selection | |
| if (visibleRows !== null) { | |
| const effective = new Set<string>(); | |
| for (const label of visibleRows) { | |
| if (available.has(label)) effective.add(label); | |
| } | |
| return effective.size > 0 ? effective : available; | |
| } | |
| // Default: show all except hidden-by-default rows | |
| const effective = new Set<string>(); | |
| for (const label of available) { | |
| if (!DEFAULT_HIDDEN_ROWS.has(label)) effective.add(label); | |
| } | |
| return effective.size > 0 ? effective : available; | |
| }, [visibleRows, allRowLabels, comparisonPreset]); | |
| // --- Aggregation --- | |
| const aggResult = useMemo(() => { | |
| return aggregateData(filteredRows, selectedMetric, getRowLabel); | |
| }, [filteredRows, selectedMetric]); | |
| const aggData = aggResult.data; | |
| const validationError = aggResult.validationError; | |
| const rowLabels = useMemo(() => { | |
| return sortRowLabels(Array.from(aggData.keys()).filter((l) => effectiveVisibleRows.has(l))); | |
| }, [aggData, effectiveVisibleRows]); | |
| // Available eval variants (only those with data) | |
| const availableEvalVariants = useMemo(() => { | |
| const evs = new Set<EvalVariant>(); | |
| for (const evMap of aggData.values()) { | |
| for (const ev of evMap.keys()) evs.add(ev); | |
| } | |
| return EVAL_VARIANTS.filter((ev) => evs.has(ev)); | |
| }, [aggData]); | |
| // --- Stats --- | |
| const stats = useMemo(() => { | |
| const taskCount = new Set(filteredRows.map((r) => r.task)).size; | |
| const modelCount = new Set(filteredRows.map((r) => r.finetuned ? r.training_config : "base")).size; | |
| return { tasks: taskCount, models: modelCount, rows: filteredRows.length }; | |
| }, [filteredRows]); | |
| // --- Domain-split rows for aggregated heatmaps --- | |
| const domainSplitRows = useMemo(() => { | |
| const hasDomainSplit = (selectedFamily === "ifeval" || selectedFamily === "hypernym") && selectedTask === "__all__"; | |
| if (!hasDomainSplit || selectedDomain !== "all") return null; | |
| const oodRows = filteredRows.filter((r) => { | |
| const family = getTaskFamily(r.task); | |
| if (!family) return false; | |
| return isOodTask(r.task, r.training_config, family) === true; | |
| }); | |
| const idRows = filteredRows.filter((r) => { | |
| const family = getTaskFamily(r.task); | |
| if (!family) return false; | |
| return isOodTask(r.task, r.training_config, family) === false; | |
| }); | |
| const oodTaskCount = new Set(oodRows.map((r) => r.task)).size; | |
| const idTaskCount = new Set(idRows.map((r) => r.task)).size; | |
| return { oodRows, idRows, oodTaskCount, idTaskCount }; | |
| }, [filteredRows, selectedFamily, selectedTask, selectedDomain]); | |
| // --- Multi-metric heatmaps (one set per domain group when applicable) --- | |
| type HeatmapSection = { | |
| label: string; | |
| taskCount: number; | |
| metrics: { metric: Metric; data: Map<string, Map<EvalVariant, AggCell>> }[]; | |
| }; | |
| const heatmapSections = useMemo((): HeatmapSection[] => { | |
| if (viewMode !== "heatmap") return []; | |
| const buildSection = (sectionRows: SummaryRow[], label: string, taskCount: number): HeatmapSection => { | |
| const metrics: { metric: Metric; data: Map<string, Map<EvalVariant, AggCell>> }[] = []; | |
| for (const m of METRICS) { | |
| metrics.push({ metric: m.key, data: aggregateData(sectionRows, m.key, getRowLabel).data }); | |
| } | |
| return { label, taskCount, metrics }; | |
| }; | |
| if (domainSplitRows) { | |
| const sections: HeatmapSection[] = []; | |
| if (domainSplitRows.oodRows.length > 0) { | |
| sections.push(buildSection(domainSplitRows.oodRows, "Out-of-domain", domainSplitRows.oodTaskCount)); | |
| } | |
| if (domainSplitRows.idRows.length > 0) { | |
| sections.push(buildSection(domainSplitRows.idRows, "In-domain", domainSplitRows.idTaskCount)); | |
| } | |
| return sections; | |
| } | |
| // For families without domain splits but with multiple tasks, still show aggregate | |
| return [buildSection(filteredRows, "All tasks", stats.tasks)]; | |
| }, [filteredRows, viewMode, domainSplitRows, stats.tasks]); | |
| // --- Per-task heatmaps (when viewing "All tasks" in a family) --- | |
| const perTaskSections = useMemo((): HeatmapSection[] => { | |
| if (viewMode !== "heatmap") return []; | |
| if (selectedTask !== "__all__") return []; // single task already shown in main sections | |
| const tasks = new Set(filteredRows.map((r) => r.task)); | |
| if (tasks.size <= 1) return []; // no point showing per-task if only one | |
| const sorted = Array.from(tasks).sort(); | |
| return sorted.map((task) => { | |
| const taskRows = filteredRows.filter((r) => r.task === task); | |
| const metrics: { metric: Metric; data: Map<string, Map<EvalVariant, AggCell>> }[] = []; | |
| for (const m of METRICS) { | |
| metrics.push({ metric: m.key, data: aggregateData(taskRows, m.key, getRowLabel).data }); | |
| } | |
| return { label: task, taskCount: 1, metrics }; | |
| }); | |
| }, [filteredRows, viewMode, selectedTask]); | |
| // --- Comparison preset data --- | |
| // For delta comparisons, we need rows from multiple TC types simultaneously | |
| const comparisonBaseRows = useMemo(() => { | |
| // Same as filteredRows but without TC type filter | |
| let result = rows; | |
| // Model filter — must match filteredRows | |
| if (resolvedModel) { | |
| result = result.filter((r) => r.model === resolvedModel); | |
| } | |
| result = result.filter((r) => isValidEvalTask(r.task)); | |
| result = result.filter((r) => isValidFinetunedModel(r)); | |
| result = result.filter((r) => r.split === selectedSplit); | |
| // For finetuned models trained with a specific TC, only show rows where eval TC matches training TC | |
| result = result.filter((r) => isMatchedTC(r)); | |
| const fam = TASK_FAMILIES.find((f) => f.key === selectedFamily); | |
| if (fam) { | |
| result = result.filter((r) => r.task.startsWith(fam.prefix)); | |
| result = result.filter((r) => { | |
| if (!r.finetuned) return true; | |
| const trainFam = getTrainingFamily(r.training_config); | |
| return trainFam === null || trainFam === selectedFamily; | |
| }); | |
| } | |
| if (selectedTask !== "__all__") { | |
| result = result.filter((r) => r.task === selectedTask); | |
| } | |
| if (selectedDomain !== "all") { | |
| result = result.filter((r) => { | |
| const family = getTaskFamily(r.task); | |
| if (!family) return true; | |
| const ood = isOodTask(r.task, r.training_config, family); | |
| if (ood === null) return true; | |
| return selectedDomain === "ood" ? ood : !ood; | |
| }); | |
| } | |
| return result; | |
| }, [rows, resolvedModel, selectedSplit, selectedFamily, selectedTask, selectedDomain]); | |
| // Build aggregated data for rows with a specific training TC flag (in the row label) | |
| // Note: this is the TRAINING tc flag (in the model name), NOT the eval TC type | |
| const aggByTrainingTC = useCallback((trainingTCFlag: string | null, metric: Metric) => { | |
| const subset = filteredRows.filter((r) => { | |
| const label = getRowLabel(r); | |
| if (trainingTCFlag === null) { | |
| // No training TC flag — exclude rows that have any TC flag | |
| return !TC_FLAGS.some((f) => hasTCFlag(label, f)); | |
| } | |
| return hasTCFlag(label, trainingTCFlag); | |
| }); | |
| return aggregateData(subset, metric, getRowLabel).data; | |
| }, [filteredRows]); | |
| // Dynamic pairing: find rows that differ only by a TC flag | |
| const buildPairs = useCallback((leftFlag: string | null, rightFlag: string): [string, string][] => { | |
| const leftLabels = new Set<string>(); | |
| const rightLabels = new Set<string>(); | |
| for (const r of filteredRows) { | |
| const label = getRowLabel(r); | |
| if (label === "Base") continue; | |
| if (rightFlag && hasTCFlag(label, rightFlag)) { | |
| rightLabels.add(label); | |
| } else if (leftFlag === null && !TC_FLAGS.some((f) => hasTCFlag(label, f))) { | |
| leftLabels.add(label); | |
| } else if (leftFlag && hasTCFlag(label, leftFlag)) { | |
| leftLabels.add(label); | |
| } | |
| } | |
| // Match: strip TC flag from right label, see if it matches a left label | |
| const pairs: [string, string][] = []; | |
| for (const rl of rightLabels) { | |
| const stripped = stripTCFlag(rl); | |
| if (leftLabels.has(stripped)) { | |
| pairs.push([stripped, rl]); | |
| } | |
| } | |
| return sortRowLabels(pairs.map(([l]) => l)).map((l) => { | |
| const r = pairs.find(([left]) => left === l)![1]; | |
| return [l, r] as [string, string]; | |
| }); | |
| }, [filteredRows]); | |
| // + TC-Self pairs: without-tc ↔ tcself | |
| const tcSelfPairs = useMemo((): [string, string][] => { | |
| if (comparisonPreset !== "plus-tcself") return []; | |
| return buildPairs(null, "tcself"); | |
| }, [comparisonPreset, buildPairs]); | |
| // TC-Self vs TC-Neg: for each model, compare train-tcself+eval-self vs train-tcneg+eval-neg | |
| // Base model: same "Base" label, just different eval TC type | |
| // Finetuned: pair by stripping tcself/tcneg from training label | |
| const aggByMatchedTC = useCallback((evalTC: TCType, metric: Metric) => { | |
| // From comparisonBaseRows (no eval TC filter), get rows matching this eval TC type | |
| const subset = comparisonBaseRows.filter((r) => getEvalTCType(r) === evalTC); | |
| return aggregateData(subset, metric, getRowLabel).data; | |
| }, [comparisonBaseRows]); | |
| const tcSelfVsNegPairs = useMemo((): [string, string][] => { | |
| if (comparisonPreset !== "tcself-vs-tcneg") return []; | |
| // Left: models evaluated with self-TC (base + models trained with tcself) | |
| const selfRows = comparisonBaseRows.filter((r) => getEvalTCType(r) === "self"); | |
| const selfLabels = new Set(selfRows.map(getRowLabel)); | |
| // Right: models evaluated with neg-TC (base + models trained with tcneg) | |
| const negRows = comparisonBaseRows.filter((r) => getEvalTCType(r) === "neg"); | |
| const negLabels = new Set(negRows.map(getRowLabel)); | |
| const pairs: [string, string][] = []; | |
| // Base: appears in both with same label | |
| if (selfLabels.has("Base") && negLabels.has("Base")) { | |
| pairs.push(["Base", "Base"]); | |
| } | |
| // Finetuned: match tcself label ↔ tcneg label (strip tc flag to find pair) | |
| for (const sl of selfLabels) { | |
| if (sl === "Base") continue; | |
| if (!hasTCFlag(sl, "tcself")) continue; | |
| const negVersion = sl.replace("tcself", "tcneg"); | |
| if (negLabels.has(negVersion)) { | |
| pairs.push([sl, negVersion]); | |
| } | |
| } | |
| return sortRowLabels(pairs.map(([l]) => l)).map((l) => { | |
| const p = pairs.find(([left]) => left === l)!; | |
| return p; | |
| }); | |
| }, [comparisonPreset, comparisonBaseRows]); | |
// Ordered [left, right] row-label pairs for the "tcself-vs-tcgpt2" preset.
// Left labels come from rows evaluated with self-TC, right labels from rows
// evaluated with gpt2-TC; the training flag counterpart of "tcself" is "tco",
// e.g. "Comb-tcself-lo" ↔ "Comb-tco-lo". Returns [] when another preset is active.
const tcSelfVsGpt2Pairs = useMemo((): [string, string][] => {
  if (comparisonPreset !== "tcself-vs-tcgpt2") return [];

  // Distinct row labels among the rows evaluated with the given TC type.
  const labelsEvaluatedWith = (evalTC: TCType): Set<string> =>
    new Set(comparisonBaseRows.filter((r) => getEvalTCType(r) === evalTC).map(getRowLabel));

  const selfLabels = labelsEvaluatedWith("self");
  const gpt2Labels = labelsEvaluatedWith("gpt2");

  // Keyed by the left label so the sorted labels can be mapped back to their
  // pair with a direct lookup (no linear scan, no non-null assertion).
  const pairByLeft = new Map<string, [string, string]>();

  // Base: appears on both sides under the same label.
  if (selfLabels.has("Base") && gpt2Labels.has("Base")) {
    pairByLeft.set("Base", ["Base", "Base"]);
  }

  // Finetuned: a label carrying the tcself flag is paired with the label
  // obtained by swapping "tcself" for "tco", if that label exists on the right.
  for (const sl of selfLabels) {
    if (sl === "Base" || !hasTCFlag(sl, "tcself")) continue;
    const tcoVersion = sl.replace("tcself", "tco");
    if (gpt2Labels.has(tcoVersion)) {
      pairByLeft.set(sl, [sl, tcoVersion]);
    }
  }

  // Display order follows sortRowLabels applied to the left labels. A label
  // without a registered pair is skipped instead of yielding `undefined`.
  const ordered: [string, string][] = [];
  for (const left of sortRowLabels(Array.from(pairByLeft.keys()))) {
    const pair = pairByLeft.get(left);
    if (pair !== undefined) ordered.push(pair);
  }
  return ordered;
}, [comparisonPreset, comparisonBaseRows]);
// Options for the "Task Family" dropdown: one entry per TASK_FAMILIES item
// that has at least one task in the loaded rows. The label carries the number
// of distinct task names whose name starts with the family prefix.
const familyOptions = useMemo((): { key: string; label: string }[] => {
  const result: { key: string; label: string }[] = [];
  for (const { key, label, prefix } of TASK_FAMILIES) {
    // Collect distinct task names belonging to this family.
    const tasksInFamily = new Set<string>();
    for (const row of rows) {
      if (row.task.startsWith(prefix)) tasksInFamily.add(row.task);
    }
    // Families with no matching task are left out of the dropdown.
    if (tasksInFamily.size === 0) continue;
    result.push({ key, label: `${label} (${tasksInFamily.size})` });
  }
  return result;
}, [rows]);
// Options for the "Task" dropdown: a leading "__all__" entry showing the
// number of available tasks, followed by one entry per available task
// (the task name is used as both key and label).
const taskOptions = useMemo(
  () => [
    { key: "__all__", label: `All tasks (${availableTasks.length})` },
    ...availableTasks.map((task) => ({ key: task, label: task })),
  ],
  [availableTasks],
);
// The "Domain" dropdown is only rendered for the "ifeval" and "hypernym" families.
const showDomainFilter = selectedFamily === "hypernym" || selectedFamily === "ifeval";
// ---------------------------------------------------------------------------
// Render
// ---------------------------------------------------------------------------

// Shape of one side of a comparison, as returned by the aggregators.
type AggregatedData = ReturnType<typeof aggByMatchedTC>;

// Eval variants (in EVAL_VARIANTS order) that occur on at least one row label
// of at least one of the given sides.
const evalVariantsPresentIn = (...sides: AggregatedData[]) =>
  EVAL_VARIANTS.filter((ev) =>
    sides.some((side) => {
      for (const evMap of side.values()) {
        if (evMap.has(ev)) return true;
      }
      return false;
    }),
  );

// Renders one SideBySideDelta per metric for a comparison preset. For each
// metric both sides are aggregated, the pairs are restricted to those present
// on both sides, and metrics with no remaining pair render nothing.
const renderComparisonMetrics = (spec: {
  keyPrefix: string;
  pairs: [string, string][];
  leftLabel: string;
  rightLabel: string;
  aggregateLeft: (metric: Metric) => AggregatedData;
  aggregateRight: (metric: Metric) => AggregatedData;
}) =>
  METRICS.map((m) => {
    const leftData = spec.aggregateLeft(m.key);
    const rightData = spec.aggregateRight(m.key);
    const existingPairs = spec.pairs.filter(([l, r]) => leftData.has(l) && rightData.has(r));
    if (existingPairs.length === 0) return null;
    return (
      <SideBySideDelta
        key={`${spec.keyPrefix}-${m.key}`}
        pairs={existingPairs}
        leftData={leftData}
        rightData={rightData}
        colLabels={evalVariantsPresentIn(leftData, rightData)}
        metric={m.key}
        title={m.label}
        leftLabel={spec.leftLabel}
        rightLabel={spec.rightLabel}
        fullConfigs={fullConfigs}
      />
    );
  });

// Class list of a view-mode toggle button; highlighted when `mode` is active.
const viewToggleClass = (mode: ViewMode) =>
  `flex-1 text-xs py-1.5 rounded border transition-colors ${
    viewMode === mode
      ? "bg-cyan-800/60 text-cyan-200 border-cyan-600/60"
      : "bg-gray-800 text-gray-500 border-gray-700"
  }`;

// The aggregated heatmaps and the regular per-task breakdown are only shown
// for these two presets; the other presets render side-by-side comparisons.
const showsPlainHeatmaps = comparisonPreset === "all" || comparisonPreset === "training-effect";

return (
  <div className="fixed inset-0 z-50 flex flex-col bg-gray-950">
    {/* Header: dataset name, row count, link to the dataset on HF, close button */}
    <div className="flex items-center justify-between px-5 py-3 border-b border-gray-800 flex-shrink-0 bg-gray-900">
      <div className="flex items-center gap-3 min-w-0">
        <span className="text-sm font-semibold text-gray-200 truncate">{shortName}</span>
        <span className="text-xs text-gray-600 border border-gray-700 px-1.5 py-0.5 rounded">heatmap</span>
        {!loading && (
          <span className="text-xs text-gray-500">
            {rows.length.toLocaleString()} rows loaded
          </span>
        )}
      </div>
      <div className="flex items-center gap-2 flex-shrink-0">
        <a
          href={`https://huggingface.co/datasets/${fullRepo}`}
          target="_blank"
          rel="noopener noreferrer"
          className="text-xs text-gray-500 hover:text-cyan-400 transition-colors px-2 py-1 border border-gray-700 rounded"
        >
          HF
        </a>
        <button
          onClick={onClose}
          className="text-gray-400 hover:text-gray-200 transition-colors p-1 rounded hover:bg-gray-700"
          aria-label="Close viewer"
        >
          <svg xmlns="http://www.w3.org/2000/svg" className="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" strokeWidth={2}>
            <path strokeLinecap="round" strokeLinejoin="round" d="M6 18L18 6M6 6l12 12" />
          </svg>
        </button>
      </div>
    </div>

    {/* Loading: spinner with row progress */}
    {loading && (
      <div className="flex-1 flex items-center justify-center">
        <div className="flex flex-col items-center gap-3">
          <div className="w-6 h-6 border-2 border-cyan-500 border-t-transparent rounded-full animate-spin" />
          <p className="text-sm text-gray-400">
            Loading... {progress.loaded.toLocaleString()} / {progress.total.toLocaleString()} rows
          </p>
        </div>
      </div>
    )}

    {/* Error: message plus retry */}
    {!loading && error && (
      <div className="flex-1 flex items-center justify-center">
        <div className="max-w-lg text-center space-y-3">
          <p className="text-sm font-medium text-red-400">Failed to load dataset</p>
          <p className="text-xs text-gray-500 font-mono break-words bg-gray-800 rounded p-3">{error}</p>
          <button
            onClick={refetch}
            className="text-xs text-cyan-400 hover:text-cyan-300 border border-cyan-700/50 px-3 py-1 rounded transition-colors"
          >
            Retry
          </button>
        </div>
      </div>
    )}

    {/* Empty dataset: without this branch the viewer body stays blank when
        loading finished without an error but produced no rows */}
    {!loading && !error && rows.length === 0 && (
      <div className="flex-1 flex items-center justify-center">
        <div className="max-w-lg text-center space-y-3">
          <p className="text-sm text-gray-500 italic">No rows were loaded from this dataset.</p>
          <button
            onClick={refetch}
            className="text-xs text-cyan-400 hover:text-cyan-300 border border-cyan-700/50 px-3 py-1 rounded transition-colors"
          >
            Retry
          </button>
        </div>
      </div>
    )}

    {/* Main content */}
    {!loading && !error && rows.length > 0 && (
      <div className="flex-1 flex overflow-hidden">
        {/* Controls sidebar */}
        <div className="w-64 flex-shrink-0 border-r border-gray-800 bg-gray-900/50 overflow-y-auto p-4 space-y-4">
          {/* View mode */}
          <div className="flex gap-1">
            <button onClick={() => setViewMode("heatmap")} className={viewToggleClass("heatmap")}>
              Heatmap
            </button>
            <button onClick={() => setViewMode("bar")} className={viewToggleClass("bar")}>
              Bar Plot
            </button>
          </div>

          {/* The base-model selector is only useful when there is a choice */}
          {availableModels.length > 1 && (
            <Dropdown
              label="Base Model"
              value={resolvedModel}
              options={availableModels.map((m) => ({ key: m, label: modelDisplayName(m) }))}
              onChange={setSelectedModel}
            />
          )}
          <Dropdown
            label="Task Family"
            value={selectedFamily}
            options={familyOptions}
            onChange={setSelectedFamily}
          />
          <Dropdown
            label="Task"
            value={selectedTask}
            options={taskOptions}
            onChange={setSelectedTask}
          />
          <Dropdown
            label="Split"
            value={selectedSplit}
            options={availableSplits.map((s) => ({ key: s, label: s }))}
            onChange={setSelectedSplit}
          />
          <Dropdown
            label="Eval TC Type"
            value={selectedTCType}
            options={TC_TYPES}
            onChange={setSelectedTCType}
          />
          {showDomainFilter && (
            <Dropdown
              label="Domain"
              value={selectedDomain}
              options={[
                { key: "all" as DomainFilter, label: "All" },
                { key: "ood" as DomainFilter, label: "Out-of-domain" },
                { key: "in-domain" as DomainFilter, label: "In-domain" },
              ]}
              onChange={setSelectedDomain}
            />
          )}
          <Dropdown
            label="Comparison"
            value={comparisonPreset}
            options={COMPARISON_PRESETS}
            onChange={setComparisonPreset}
          />

          {/* Row visibility: "all" shows every label, "default" resets the override to null */}
          {allRowLabels.length > 0 && (
            <div className="flex flex-col gap-1">
              <div className="flex items-center justify-between">
                <label className="text-[10px] uppercase tracking-wider text-gray-500 font-medium">Visible Models</label>
                <div className="flex gap-1">
                  <button
                    onClick={() => setVisibleRows(new Set(allRowLabels))}
                    className="text-[9px] text-gray-500 hover:text-gray-300 px-1"
                  >
                    all
                  </button>
                  <button
                    onClick={() => setVisibleRows(null)}
                    className="text-[9px] text-gray-500 hover:text-gray-300 px-1"
                  >
                    default
                  </button>
                </div>
              </div>
              <div className="flex flex-col gap-0.5 max-h-48 overflow-y-auto">
                {allRowLabels.map((label) => {
                  const isVisible = effectiveVisibleRows.has(label);
                  return (
                    <label key={label} className="flex items-center gap-1.5 text-xs cursor-pointer hover:bg-gray-800/50 px-1 py-0.5 rounded">
                      <input
                        type="checkbox"
                        checked={isVisible}
                        onChange={() => {
                          // Toggle this label on a copy of the effective set,
                          // which then becomes the explicit override.
                          const next = new Set(effectiveVisibleRows);
                          if (next.has(label)) next.delete(label);
                          else next.add(label);
                          setVisibleRows(next);
                        }}
                        className="rounded border-gray-600 bg-gray-800 text-cyan-500 focus:ring-0 focus:ring-offset-0 h-3 w-3"
                      />
                      <span className={isVisible ? "text-gray-200" : "text-gray-500"}><DisplayRowLabel label={label} /></span>
                    </label>
                  );
                })}
              </div>
            </div>
          )}

          {/* Stats */}
          <div className="pt-2 border-t border-gray-800 space-y-1">
            <p className="text-[10px] uppercase tracking-wider text-gray-500 font-medium">Filtered data</p>
            <p className="text-xs text-gray-400">{stats.tasks} tasks, {stats.models} configs</p>
            <p className="text-xs text-gray-400">{stats.rows.toLocaleString()} rows</p>
            {selectedTask === "__all__" && stats.tasks > 1 && (
              <p className="text-[10px] text-gray-600 italic">Averaging across {stats.tasks} tasks</p>
            )}
          </div>
        </div>

        {/* Visualization area */}
        <div className="flex-1 overflow-auto p-6 space-y-8">
          {validationError ? (
            // A validation error blocks every visualization.
            <div className="flex items-center justify-center h-full">
              <div className="max-w-xl bg-red-950 border-2 border-red-500 rounded-lg p-6 space-y-3">
                <div className="flex items-center gap-2">
                  <span className="text-2xl">!!!</span>
                  <h2 className="text-lg font-bold text-red-300">Data Integrity Violation</h2>
                </div>
                <pre className="text-sm text-red-200 whitespace-pre-wrap font-mono leading-relaxed">{validationError}</pre>
                <p className="text-xs text-red-400 mt-4">
                  All visualizations are blocked until this is resolved. If you see this, the data pipeline has a bug
                  — different models or training configs are collapsing into the same row label.
                </p>
              </div>
            </div>
          ) : filteredRows.length === 0 ? (
            <div className="flex items-center justify-center h-full">
              <p className="text-sm text-gray-500 italic">No data matches the current filters.</p>
            </div>
          ) : viewMode === "heatmap" ? (
            <>
              {/* Aggregated sections (domain-split or single aggregate) — hidden when a side-by-side comparison preset is active */}
              {showsPlainHeatmaps && heatmapSections.map((section, sIdx) => (
                <div key={`agg-${sIdx}`} className="space-y-6">
                  {section.label && (
                    <h2 className="text-sm font-semibold text-gray-200 border-b border-gray-700 pb-2">
                      {section.label}
                      {section.taskCount > 1 && (
                        <span className="text-gray-500 font-normal ml-2">(mean over {section.taskCount} tasks)</span>
                      )}
                    </h2>
                  )}
                  <div className="flex flex-wrap gap-6">
                    {section.metrics.map(({ metric: m, data }) => {
                      const metricLabel = METRICS.find((x) => x.key === m)?.label ?? m;
                      // Only rows the user left visible, in display order.
                      const metricRowLabels = sortRowLabels(Array.from(data.keys()).filter((l) => effectiveVisibleRows.has(l)));
                      // Columns: eval variants present on at least one row.
                      const metricEvs = EVAL_VARIANTS.filter((ev) => {
                        for (const evMap of data.values()) {
                          if (evMap.has(ev)) return true;
                        }
                        return false;
                      });
                      if (metricRowLabels.length === 0) return null;
                      return (
                        <HeatmapGrid
                          key={m}
                          data={data}
                          rowLabels={metricRowLabels}
                          colLabels={metricEvs}
                          metric={m}
                          title={metricLabel}
                          fullConfigs={fullConfigs}
                        />
                      );
                    })}
                  </div>
                </div>
              ))}

              {/* Comparison: + TC-Self (side-by-side + delta) */}
              {comparisonPreset === "plus-tcself" && tcSelfPairs.length === 0 && (
                <p className="text-sm text-gray-500 italic">No matched pairs found for + TC-Self comparison. Check that both non-TC and tcself models exist for this family.</p>
              )}
              {comparisonPreset === "plus-tcself" && tcSelfPairs.length > 0 && renderComparisonMetrics({
                keyPrefix: "tcself",
                pairs: tcSelfPairs,
                leftLabel: "Without TC",
                rightLabel: "+ TC-Self",
                aggregateLeft: (metric) => aggByTrainingTC(null, metric),
                aggregateRight: (metric) => aggByTrainingTC("tcself", metric),
              })}

              {/* Comparison: TC-Self vs TC-Neg (side-by-side + delta) */}
              {comparisonPreset === "tcself-vs-tcneg" && tcSelfVsNegPairs.length === 0 && (
                <p className="text-sm text-gray-500 italic">No matched pairs found for TC-Self vs TC-Neg comparison. Check that both tcself and tcneg models exist for this family.</p>
              )}
              {comparisonPreset === "tcself-vs-tcneg" && tcSelfVsNegPairs.length > 0 && renderComparisonMetrics({
                keyPrefix: "tcneg",
                pairs: tcSelfVsNegPairs,
                leftLabel: "TC-Self",
                rightLabel: "TC-Neg",
                aggregateLeft: (metric) => aggByMatchedTC("self", metric),
                aggregateRight: (metric) => aggByMatchedTC("neg", metric),
              })}

              {/* Comparison: TC-Self vs TC-GPT2 (side-by-side + delta) */}
              {comparisonPreset === "tcself-vs-tcgpt2" && tcSelfVsGpt2Pairs.length === 0 && (
                <p className="text-sm text-gray-500 italic">No matched pairs found for TC-Self vs TC-GPT2 comparison.</p>
              )}
              {comparisonPreset === "tcself-vs-tcgpt2" && tcSelfVsGpt2Pairs.length > 0 && renderComparisonMetrics({
                keyPrefix: "tcgpt2",
                pairs: tcSelfVsGpt2Pairs,
                leftLabel: "TC-Self",
                rightLabel: "TC-GPT2",
                aggregateLeft: (metric) => aggByMatchedTC("self", metric),
                aggregateRight: (metric) => aggByMatchedTC("gpt2", metric),
              })}

              {/* Per-task breakdown — regular heatmap for all/training-effect */}
              {showsPlainHeatmaps && perTaskSections.length > 0 && (
                <PerTaskCollapsible sections={perTaskSections} fullConfigs={fullConfigs} visibleRows={effectiveVisibleRows} />
              )}

              {/* Per-task breakdown — side-by-side + delta for + TC-Self */}
              {comparisonPreset === "plus-tcself" && selectedTask === "__all__" && (
                <PerTaskComparisonCollapsible
                  filteredRows={filteredRows}
                  pairs={tcSelfPairs}
                  leftLabel="Without TC"
                  rightLabel="+ TC-Self"
                  leftFilter={(r) => {
                    // Left side: rows whose label carries none of the TC flags.
                    const label = getRowLabel(r);
                    return !TC_FLAGS.some((f) => hasTCFlag(label, f));
                  }}
                  rightFilter={(r) => hasTCFlag(getRowLabel(r), "tcself")}
                  fullConfigs={fullConfigs}
                />
              )}

              {/* Per-task breakdown — side-by-side + delta for TC-Self vs TC-Neg */}
              {comparisonPreset === "tcself-vs-tcneg" && selectedTask === "__all__" && (
                <PerTaskComparisonCollapsible
                  filteredRows={comparisonBaseRows}
                  pairs={tcSelfVsNegPairs}
                  leftLabel="TC-Self (eval)"
                  rightLabel="TC-Neg (eval)"
                  leftFilter={(r) => getEvalTCType(r) === "self"}
                  rightFilter={(r) => getEvalTCType(r) === "neg"}
                  fullConfigs={fullConfigs}
                />
              )}

              {/* Per-task breakdown — side-by-side + delta for TC-Self vs TC-GPT2 */}
              {comparisonPreset === "tcself-vs-tcgpt2" && selectedTask === "__all__" && (
                <PerTaskComparisonCollapsible
                  filteredRows={comparisonBaseRows}
                  pairs={tcSelfVsGpt2Pairs}
                  leftLabel="TC-Self (eval)"
                  rightLabel="TC-GPT2 (eval)"
                  leftFilter={(r) => getEvalTCType(r) === "self"}
                  rightFilter={(r) => getEvalTCType(r) === "gpt2"}
                  fullConfigs={fullConfigs}
                />
              )}
            </>
          ) : (
            // Bar view: one chart per metric, restricted to the "raw" and "tc" eval variants.
            <div className="space-y-8">
              {METRICS.map((m) => {
                const metricAgg = aggregateData(filteredRows, m.key, getRowLabel).data;
                const metricRowLabels = sortRowLabels(Array.from(metricAgg.keys()).filter((l) => effectiveVisibleRows.has(l)));
                if (metricRowLabels.length === 0) return null;
                return (
                  <BarChart
                    key={m.key}
                    data={metricAgg}
                    rowLabels={metricRowLabels}
                    evalVariants={availableEvalVariants.filter((ev) => ev === "raw" || ev === "tc")}
                    metric={m.key}
                    title={`${m.label}${selectedTask === "__all__" && stats.tasks > 1 ? ` (mean ± SE over ${stats.tasks} tasks)` : ""}`}
                  />
                );
              })}
            </div>
          )}
        </div>
      </div>
    )}
  </div>
);
| } | |