juand-r's picture
Upload folder using huggingface_hub
0d2fd72 verified
import { useState, useEffect, useCallback, useMemo, useRef } from "react";
import { HF_ORG } from "../../../config";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
// Props for the heatmap viewer.
// NOTE(review): the consuming component is outside this excerpt — confirm how each prop is used.
interface HeatmapViewerProps {
// Hugging Face dataset repo id (presumably forwarded to useDatasetRows as `repo` — confirm)
datasetRepo: string;
// Optional split name (presumably forwarded to useDatasetRows as `_split` — confirm)
split?: string;
// Callback supplied by the parent to close the viewer
onClose: () => void;
}
/**
 * One row of the evaluation summary dataset. useDatasetRows casts parsed
 * parquet rows to this shape without any runtime validation.
 */
interface SummaryRow {
// Base model identifier; validateRowLabels rejects views mixing more than one value
model: string;
hf_model_name: string;
local_model_name: string;
// Eval task id, e.g. "hypernym-dogs" or "ifeval-prompt_3" (see isValidEvalTask)
task: string;
split: string;
// Eval-time TC flags, read by getEvalTCType (checked in self > neg > gpt2 order)
self_tc: boolean;
neg_tc: boolean;
gpt2_tc: boolean;
// false => row is labelled "Base" by getRowLabel
finetuned: boolean;
// Training-run descriptor parsed by buildRowLabel / parseTrainingMode / getTrainingTCType
training_config: string;
// Expected to be one of EVAL_VARIANTS; other values are skipped by aggregateData
eval_variant: string;
// Metric columns; null/NaN values are ignored by aggregateData
gen_roc: number | null;
val_roc: number | null;
val_acc: number | null;
corr: number | null;
corr_pos: number | null;
corr_neg: number | null;
n_samples: number;
filename: string;
}
// Numeric SummaryRow columns that can be plotted (keys of SummaryRow)
type Metric = "gen_roc" | "val_roc" | "val_acc" | "corr" | "corr_pos" | "corr_neg";
// Eval post-processing variants; these become heatmap columns / bar sub-series
type EvalVariant = "raw" | "tc" | "lenorm" | "tc+lenorm";
// TC variant as returned by getEvalTCType / getTrainingTCType
type TCType = "none" | "self" | "neg" | "gpt2";
// Heatmap table vs. grouped bar chart
type ViewMode = "heatmap" | "bar";
// Task filter by OOD status (isOodTask decides per task)
type DomainFilter = "all" | "ood" | "in-domain";
// Keys of COMPARISON_PRESETS
type ComparisonPreset = "all" | "training-effect" | "plus-tcself" | "tcself-vs-tcneg" | "tcself-vs-tcgpt2";
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------
// Base URL of the HF datasets server; useDatasetRows queries its `/parquet` endpoint
const HF_DATASETS_API = "https://datasets-server.huggingface.co";
// Plottable metrics with their display labels
const METRICS: { key: Metric; label: string }[] = [
{ key: "gen_roc", label: "Gen ROC" },
{ key: "val_roc", label: "Val ROC" },
{ key: "val_acc", label: "Val Acc" },
{ key: "corr", label: "Correlation" },
{ key: "corr_pos", label: "Corr (pos)" },
{ key: "corr_neg", label: "Corr (neg)" },
];
// Recognised eval variants; aggregateData drops rows whose eval_variant is not listed here
const EVAL_VARIANTS: EvalVariant[] = ["raw", "tc", "lenorm", "tc+lenorm"];
// Display labels for TC types ("none" intentionally has no entry)
const TC_TYPES: { key: TCType; label: string }[] = [
{ key: "self", label: "Self TC" },
{ key: "neg", label: "Neg TC" },
{ key: "gpt2", label: "GPT-2 TC" },
];
// Plotly default color cycle for bar charts
const BAR_COLORS = [
"#636EFA", "#EF553B", "#00CC96", "#AB63FA", "#FFA15A",
"#19D3F3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52",
];
// Task families. getTaskFamily matches `prefix` at the start of a task id;
// getTrainingFamily matches it anywhere inside a training config string.
const TASK_FAMILIES = [
{ key: "hypernym", label: "Hypernym", prefix: "hypernym-" },
{ key: "ifeval", label: "IFEval", prefix: "ifeval-" },
{ key: "plausibleqa", label: "PlausibleQA", prefix: "plausibleqa-" },
{ key: "ambigqa", label: "AmbigQA", prefix: "ambigqa-" },
];
// Finetuned models must use force-same-x and the correct combined training dataset.
// Hypernym and IFEval have single-task variants that must be excluded.
/**
 * Row filter: base (non-finetuned) rows always pass. A finetuned row passes
 * only if its training config contains "force-same-x" and, for hypernym/IFEval
 * runs, names the combined training dataset rather than a single-task one.
 */
function isValidFinetunedModel(row: SummaryRow): boolean {
  if (!row.finetuned) return true;
  const cfg = row.training_config;
  const usesForceSameX = cfg.includes("force-same-x");
  // Hypernym runs must be trained on the combined "double" dataset
  const singleTaskHypernym =
    cfg.includes("hypernym-") && !cfg.includes("hypernym-concat-bananas-to-dogs-double");
  // IFEval runs must be trained on the concat dataset
  const singleTaskIfeval = cfg.includes("ifeval-") && !cfg.includes("ifeval-concat");
  return usesForceSameX && !singleTaskHypernym && !singleTaskIfeval;
}
// Selectable comparison presets with their display labels
const COMPARISON_PRESETS: { key: ComparisonPreset; label: string }[] = [
{ key: "all", label: "All models" },
{ key: "training-effect", label: "Training Effect" },
{ key: "plus-tcself", label: "+ TC-Self" },
{ key: "tcself-vs-tcneg", label: "TC-Self vs TC-Neg" },
{ key: "tcself-vs-tcgpt2", label: "TC-Self vs TC-GPT2" },
];
// Preset row sets:
// Training Effect: Base + lo/semi × {Pref, SFT, Comb-v}
// Entries are internal row labels as produced by getRowLabel / buildRowLabel.
const TRAINING_EFFECT_ROWS = new Set([
"Base", "Pref-lo", "SFT-lo", "Comb-v-lo", "Pref-semi", "SFT-semi", "Comb-v-semi",
]);
// Internal row labels hidden by default.
// NOTE(review): reasons other than the PlausibleQA note below are not recorded here — confirm.
const DEFAULT_HIDDEN_ROWS = new Set([
"Pref-v-lo",
"Pref-tcself-v-lo",
"Pref-tcself-norm-lo",
"Pref-tcself-norm-v-lo",
"Comb-tcself-norm-v-lo",
"SFT-tcself-norm-semi",
"Comb-tcself-norm-v-semi",
// PlausibleQA models without lo/semi — not comparable
"Comb-tcself",
"Comb-tcself-v",
"SFT-tcself-v",
]);
// Strip a TC flag from a row label to get the "base" label for pairing
const TC_FLAGS = ["tco", "tcself", "tcneg"];
/**
 * Removes every dash-separated segment of `label` that is a TC flag
 * ("tco", "tcself", "tcneg"), keeping the remaining segments in order.
 * Example: "Pref-tcself-lo" -> "Pref-lo".
 */
function stripTCFlag(label: string): string {
  const kept: string[] = [];
  for (const segment of label.split("-")) {
    if (TC_FLAGS.indexOf(segment) === -1) kept.push(segment);
  }
  return kept.join("-");
}
/** True when `flag` occurs as a whole dash-separated segment of `label`. */
function hasTCFlag(label: string, flag: string): boolean {
  return label.split("-").some((segment) => segment === flag);
}
// IFEval prompts 1-21 are OOD (test-only, never trained on)
// Compared with `<=` in isOodTask, so prompt 21 itself is OOD.
const IFEVAL_OOD_MAX = 21;
// Hypernym: fixed in-domain / OOD split
const HYPERNYM_IN_DOMAIN = new Set([
"bananas", "bazookas", "cabinets", "cars", "chairs", "crows", "diapers", "dogs",
]);
const HYPERNYM_OOD = new Set([
"ducklings", "elephants", "guns", "hammers", "helmets", "jackets", "kayaks", "kites", "mirrors",
]);
// Every known hypernym subtask; isValidEvalTask rejects hypernym tasks outside this set
const HYPERNYM_VALID = new Set([...HYPERNYM_IN_DOMAIN, ...HYPERNYM_OOD]);
// Valid eval task patterns — excludes training tasks like concat, bare family names, etc.
/**
 * Decides whether `task` is an evaluation task that should be displayed.
 *
 * - hypernym-*: the subtask must be a known hypernym subtask (HYPERNYM_VALID).
 * - ifeval-*: must be exactly "ifeval-prompt_<N>" or "ifeval-prompt-<N>".
 * - plausibleqa-* / ambigqa-*: must carry a non-empty subtask suffix.
 * - bare family names ("hypernym", "ifeval", "plausibleqa", "ambigqa") are
 *   training tasks and are rejected.
 * - anything else is accepted.
 */
function isValidEvalTask(task: string): boolean {
  if (task.startsWith("hypernym-")) {
    return HYPERNYM_VALID.has(task.slice("hypernym-".length));
  }
  if (task.startsWith("ifeval-")) {
    return /^ifeval-prompt[_-]\d+$/.test(task);
  }
  // These families have free-form subtask names; require the suffix to be non-empty.
  for (const prefix of ["plausibleqa-", "ambigqa-"]) {
    if (task.startsWith(prefix)) return task.length > prefix.length;
  }
  // Bare family names without subtask suffix are training tasks
  if (task === "ambigqa" || task === "plausibleqa" || task === "hypernym" || task === "ifeval") {
    return false;
  }
  return true;
}
// ---------------------------------------------------------------------------
// Color scale: RdYlGn for [0, 1] mapped to percentages
// ---------------------------------------------------------------------------
// Plotly RdYlGn colorscale stops (matches Plotly's built-in)
// Each entry is [position in [0, 1], [r, g, b]], sorted by position;
// interpolateRdYlGn blends linearly between neighbouring entries.
const RDYLGN_STOPS: [number, [number, number, number]][] = [
[0.0, [165, 0, 38]],
[0.1, [215, 48, 39]],
[0.2, [244, 109, 67]],
[0.3, [253, 174, 97]],
[0.4, [254, 224, 139]],
[0.5, [255, 255, 191]],
[0.6, [217, 239, 139]],
[0.7, [166, 217, 106]],
[0.8, [102, 189, 99]],
[0.9, [26, 152, 80]],
[1.0, [0, 104, 55]],
];
/**
 * Maps `t` onto the RdYlGn ramp as an "rgb(r, g, b)" string.
 *
 * `t` is clamped to [0, 1], and the color is linearly interpolated between
 * the two surrounding entries of RDYLGN_STOPS. NaN matches no segment and
 * falls through to the top-of-scale dark green.
 */
function interpolateRdYlGn(t: number): string {
  const x = Math.max(0, Math.min(1, t));
  const lastSegment = RDYLGN_STOPS.length - 2;
  let k = 0;
  while (k <= lastSegment) {
    const [lo, fromRgb] = RDYLGN_STOPS[k];
    const [hi, toRgb] = RDYLGN_STOPS[k + 1];
    if (x >= lo && x <= hi) {
      const frac = (x - lo) / (hi - lo);
      const [r, g, b] = [0, 1, 2].map((ch) => Math.round(fromRgb[ch] + frac * (toRgb[ch] - fromRgb[ch])));
      return `rgb(${r}, ${g}, ${b})`;
    }
    k += 1;
  }
  return `rgb(0, 104, 55)`;
}
/** RdYlGn fill for a value in [0, 1]; NaN renders as the light-gray "no data" color. */
function rdYlGn(value: number): string {
  return isNaN(value) ? "#f3f4f6" : interpolateRdYlGn(value);
}
/** Color for a correlation in [-1, 1]: rescaled to [0, 1] and fed to rdYlGn; NaN is light gray. */
function corrColor(value: number): string {
  const unit = (value + 1) / 2;
  return isNaN(value) ? "#f3f4f6" : rdYlGn(unit);
}
/**
 * Foreground color for a heatmap cell whose background encodes `bgValue`.
 * Correlations are first rescaled from [-1, 1] to [0, 1]. Positions strictly
 * inside (0.3, 0.7) get dark text; everything else gets white. NaN gets gray.
 */
function textColor(bgValue: number, isCorr: boolean): string {
  if (isNaN(bgValue)) return "#9ca3af";
  const position = isCorr ? (bgValue + 1) / 2 : bgValue;
  if (position <= 0.3 || position >= 0.7) return "#ffffff";
  return "#1f2937";
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/**
 * The TC variant applied at evaluation time, read from the row's boolean
 * flags. If several are set, precedence is self > neg > gpt2.
 */
function getEvalTCType(row: SummaryRow): TCType {
  const byPrecedence: [boolean, TCType][] = [
    [row.self_tc, "self"],
    [row.neg_tc, "neg"],
    [row.gpt2_tc, "gpt2"],
  ];
  for (const [isSet, kind] of byPrecedence) {
    if (isSet) return kind;
  }
  return "none";
}
/**
 * The TC variant used during training, inferred from substrings of
 * `training_config` (dash or underscore spelling). Precedence is
 * online (reported as "gpt2") > self > neg.
 */
function getTrainingTCType(row: SummaryRow): TCType {
  const cfg = row.training_config;
  const rules: [string, string, TCType][] = [
    ["tc-online", "tc_online", "gpt2"],
    ["tc-self", "tc_self", "self"],
    ["tc-neg", "tc_neg", "neg"],
  ];
  const hit = rules.find(([dashed, underscored]) => cfg.includes(dashed) || cfg.includes(underscored));
  return hit ? hit[2] : "none";
}
/**
 * For finetuned models, check that eval TC matches training TC.
 * Base rows, and finetuned rows trained without TC, accept any eval TC.
 */
function isMatchedTC(row: SummaryRow): boolean {
  if (!row.finetuned) return true;
  const trainTC = getTrainingTCType(row);
  return trainTC === "none" || getEvalTCType(row) === trainTC;
}
/** Key of the first TASK_FAMILIES entry whose prefix starts `task`, or null. */
function getTaskFamily(task: string): string | null {
  const family = TASK_FAMILIES.find(({ prefix }) => task.startsWith(prefix));
  return family === undefined ? null : family.key;
}
/** Extracts N from "ifeval-prompt_N" or "ifeval-prompt-N" (anywhere in `task`); null if absent. */
function parseIfevalPromptNum(task: string): number | null {
  const found = /ifeval-prompt[_-](\d+)/.exec(task);
  if (found === null) return null;
  return parseInt(found[1], 10);
}
/**
 * Classifies a task as OOD (true), in-domain (false), or unknown (null).
 *
 * - plausibleqa / ambigqa: always OOD.
 * - ifeval: OOD when the prompt number is <= IFEVAL_OOD_MAX; null if unparseable.
 * - hypernym: looked up in the fixed HYPERNYM_OOD / HYPERNYM_IN_DOMAIN sets.
 * - any other family: null.
 *
 * NOTE(review): `trainingConfig` is currently unused — the split is fixed per family.
 */
function isOodTask(task: string, trainingConfig: string, family: string): boolean | null {
  switch (family) {
    case "plausibleqa":
    case "ambigqa":
      return true;
    case "ifeval": {
      const promptNum = parseIfevalPromptNum(task);
      return promptNum === null ? null : promptNum <= IFEVAL_OOD_MAX;
    }
    case "hypernym": {
      const subtask = task.replace("hypernym-", "");
      if (HYPERNYM_OOD.has(subtask)) return true;
      return HYPERNYM_IN_DOMAIN.has(subtask) ? false : null;
    }
    default:
      return null;
  }
}
/** Key of the first TASK_FAMILIES entry whose prefix occurs anywhere in `config`, or null. */
function getTrainingFamily(config: string): string | null {
  const family = TASK_FAMILIES.find(({ prefix }) => config.includes(prefix));
  return family === undefined ? null : family.key;
}
/** parseFloat of capture group 1 of `pattern` matched against `text`; null when there is no match. */
function extractFloat(pattern: RegExp, text: string): number | null {
  const found = text.match(pattern);
  if (!found) return null;
  return parseFloat(found[1]);
}
/**
 * Classifies a training config as "SFT", "Comb", or "Pref" from its
 * pref/nllv/nllg weights. Missing weights default to pref=1, nllv=0, nllg=0.
 *
 * - SFT:  pref=0, nllv=1, nllg=1
 * - Comb: pref=1, nllv=1, nllg=1
 * - Pref: everything else, including pure preference training and any
 *         unrecognised weight mix.
 */
function parseTrainingMode(config: string): string {
  const pref = extractFloat(/(?:^|[_-])pref(\d+(?:\.\d+)?)/, config) ?? 1.0;
  const nllv = extractFloat(/nllv(\d+(?:\.\d+)?)/, config) ?? 0.0;
  const nllg = extractFloat(/nllg(\d+(?:\.\d+)?)/, config) ?? 0.0;
  // SFT and Comb both need the full NLL terms; they differ only in the preference weight
  const fullNll = nllv === 1.0 && nllg === 1.0;
  if (fullNll && pref === 0.0) return "SFT";
  if (fullNll && pref === 1.0) return "Comb";
  return "Pref";
}
/**
 * Builds the internal row label for a finetuned config:
 * `<mode>[-tco|-tcself|-tcneg][-norm][-v][-lo][-semi]`, in that order.
 * Logs a warning if more than one TC flag is present; TC flags are expected
 * to be mutually exclusive.
 */
function buildRowLabel(config: string): string {
  const mode = parseTrainingMode(config);
  // Token surrounded by "_" on both sides or by "-" on both sides
  const delimited = (token: string) => config.includes(`_${token}_`) || config.includes(`-${token}-`);

  const tcCandidates: [boolean, string][] = [
    [delimited("tc-online"), "tco"],
    [delimited("tc-self"), "tcself"],
    [delimited("tc-neg"), "tcneg"],
  ];
  const activeTc = tcCandidates.filter(([present]) => present).map(([, flag]) => flag);
  if (activeTc.length > 1) {
    console.warn(`Multiple TC flags in training config (expected at most 1): ${config}`);
  }

  const flags = [...activeTc];
  // Optional independent flags
  if (delimited("lenorm")) flags.push("norm");
  if (config.includes("_vallogodds") || config.includes("-vallogodds")) flags.push("v");
  // Data regime flags
  if (config.includes("labelonly")) flags.push("lo");
  if (config.includes("semi")) flags.push("semi");

  return [mode, ...flags].join("-");
}
/** Internal row label: "Base" for non-finetuned rows, otherwise derived from the training config. */
function getRowLabel(row: SummaryRow): string {
  return row.finetuned ? buildRowLabel(row.training_config) : "Base";
}
// Display-friendly row label: reorder to [regime, tc, mode, extras], spaces, Pref→RankAlign
/**
 * Converts an internal label such as "Pref-tcself-norm-lo" into display text
 * ("lo tcself RankAlign norm"). Within each group, flags keep their original
 * relative order. "Base" is returned unchanged.
 */
function displayRowLabel(label: string): string {
  if (label === "Base") return "Base";
  const [head, ...flags] = label.split("-");
  const regimeFlags = ["lo", "semi"];
  const tcFlags = ["tcself", "tcneg", "tco"];
  const regime = flags.filter((f) => regimeFlags.includes(f));
  const tc = flags.filter((f) => tcFlags.includes(f));
  const extras = flags.filter((f) => !regimeFlags.includes(f) && !tcFlags.includes(f));
  const mode = head === "Pref" ? "RankAlign" : head;
  return [...regime, ...tc, mode, ...extras].join(" ");
}
// JSX version with colored TC flags for use in HTML contexts
/**
 * Renders a row label with the same text and token order as
 * `displayRowLabel`, coloring the TC flag tokens (tcself / tcneg / tco).
 * The tokens come from `displayRowLabel` itself, so the HTML and plain-string
 * renderings cannot drift apart.
 */
function DisplayRowLabel({ label }: { label: string }) {
  if (label === "Base") return <>Base</>;
  const tokens = displayRowLabel(label).split(" ");
  const tcColor: Record<string, string> = { tcself: "#f87171", tcneg: "#fb923c", tco: "#a78bfa" };
  return (
    <>
      {tokens.map((t, i) => (
        <span key={i}>
          {i > 0 && " "}
          {tcColor[t] ? <span style={{ color: tcColor[t] }}>{t}</span> : t}
        </span>
      ))}
    </>
  );
}
// SVG version with colored TC flags (uses tspan)
/**
 * SVG `<tspan>` rendering of `displayRowLabel(label)`. The text is truncated
 * to `maxLen` characters (the last two replaced by ".."), and surviving whole
 * TC flag tokens are colored. The text comes from `displayRowLabel`, so it
 * matches the HTML and plain-string renderings.
 */
function SvgRowLabel({ label, maxLen = 22 }: { label: string; maxLen?: number }) {
  if (label === "Base") return <>Base</>;
  const full = displayRowLabel(label);
  const display = full.length > maxLen ? full.slice(0, maxLen - 2) + ".." : full;
  const tcColor: Record<string, string> = { tcself: "#f87171", tcneg: "#fb923c", tco: "#a78bfa" };
  // Re-tokenize the (possibly truncated) display string to color tc flags
  const displayTokens = display.split(" ");
  return (
    <>
      {displayTokens.map((t, i) => (
        <tspan key={i} fill={tcColor[t] || undefined}>
          {i > 0 && " "}{t}
        </tspan>
      ))}
    </>
  );
}
/** Arithmetic mean of `values`; NaN for an empty array. */
function mean(values: number[]): number {
  if (!values.length) return NaN;
  let total = 0;
  for (const v of values) total += v;
  return total / values.length;
}
/**
 * Standard error of the mean: sample standard deviation (n - 1 denominator)
 * divided by sqrt(n). Returns 0 when fewer than two values are given.
 */
function stdErr(values: number[]): number {
  const n = values.length;
  if (n < 2) return 0;
  const avg = mean(values);
  let sumSq = 0;
  for (const v of values) sumSq += (v - avg) ** 2;
  return Math.sqrt(sumSq / (n - 1)) / Math.sqrt(n);
}
// ---------------------------------------------------------------------------
// Data fetching — resolves the parquet export via the datasets server, then downloads it directly (~1MB)
// ---------------------------------------------------------------------------
/**
 * Loads every row of the summary dataset `repo` for split `_split`.
 *
 * Resolves the parquet export through the datasets-server `/parquet`
 * endpoint, then reads it with hyparquet (dynamically imported). If no
 * parquet file matches `_split`, the split of the first listed file is used.
 * A split that is sharded across several parquet files is read in full, in
 * listing order.
 *
 * Only the most recent fetch may update state. If `repo`/`_split` change,
 * `refetch` is called, or the component unmounts while a request is in
 * flight, that request's rows, errors and loading flag are discarded.
 *
 * @returns `rows`, `loading`, `error` (message or null), `progress` (while
 *   downloading, `total` is the byte size reported by the server; after
 *   parsing, both fields hold the row count), and `refetch`.
 */
function useDatasetRows(repo: string, _split: string) {
  // Fields of a datasets-server `/parquet` entry that this hook reads
  type ParquetFileInfo = { config?: string; split: string; url: string; size?: number };

  const [rows, setRows] = useState<SummaryRow[]>([]);
  const [loading, setLoading] = useState(true);
  const [error, setError] = useState<string | null>(null);
  const [progress, setProgress] = useState({ loaded: 0, total: 0 });
  // Id of the most recently started fetch; any other fetch is stale.
  const latestRequestId = useRef(0);

  const fetchAll = useCallback(async () => {
    const requestId = ++latestRequestId.current;
    const isStale = () => requestId !== latestRequestId.current;

    setLoading(true);
    setError(null);
    setRows([]);
    try {
      // Step 1: get the parquet file URL(s) from HF datasets server
      const metaUrl = `${HF_DATASETS_API}/parquet?dataset=${encodeURIComponent(repo)}`;
      const metaResp = await fetch(metaUrl);
      if (!metaResp.ok) {
        throw new Error(`Failed to get parquet info: ${metaResp.status}`);
      }
      const metaData = await metaResp.json();
      const parquetFiles: ParquetFileInfo[] = metaData.parquet_files ?? [];
      const matchingFile = parquetFiles.find((f) => f.split === _split) ?? parquetFiles[0];
      if (!matchingFile) {
        throw new Error("No parquet files found for this dataset");
      }
      // A split may be sharded into several parquet files; read every shard
      // of the chosen (config, split), not just the first one.
      const shards = parquetFiles.filter(
        (f) => f.config === matchingFile.config && f.split === matchingFile.split,
      );
      if (isStale()) return;
      setProgress({ loaded: 0, total: shards.reduce((sum, f) => sum + (f.size ?? 0), 0) });

      // Step 2: download the parquet file(s) directly from HF (static files, no rate limits)
      const { asyncBufferFromUrl, parquetRead } = await import("hyparquet");

      // Step 3: parse all rows. There is no schema validation: rows are cast to SummaryRow as-is.
      const allRows: SummaryRow[] = [];
      for (const shard of shards) {
        const file = await asyncBufferFromUrl({ url: shard.url });
        await parquetRead({
          file,
          rowFormat: "object",
          onComplete: (data: Record<string, unknown>[]) => {
            for (const row of data) {
              allRows.push(row as unknown as SummaryRow);
            }
          },
        });
        if (isStale()) return;
      }
      setProgress({ loaded: allRows.length, total: allRows.length });
      setRows(allRows);
    } catch (err) {
      if (isStale()) return;
      setError(err instanceof Error ? err.message : String(err));
    } finally {
      // A newer request owns the loading flag; leave it alone if this one is stale
      if (!isStale()) setLoading(false);
    }
  }, [repo, _split]);

  useEffect(() => {
    void fetchAll();
    // Invalidate the in-flight request when deps change or the component unmounts
    return () => {
      latestRequestId.current += 1;
    };
  }, [fetchAll]);

  return { rows, loading, error, progress, refetch: fetchAll };
}
// ---------------------------------------------------------------------------
// Aggregation
// ---------------------------------------------------------------------------
// Aggregate of one (row label, eval variant) cell, as built by aggregateData
interface AggCell {
// Mean of the metric values in the cell
mean: number;
// Standard error of the mean (0 when fewer than two values; see stdErr)
se: number;
// Number of values aggregated into the cell
n: number;
}
/**
 * Validate that each row label maps to exactly one model identity.
 * Checks BOTH the base model (r.model) AND training_config.
 * If two different models or configs produce the same row label,
 * we're silently averaging different models — a critical bug.
 *
 * Base (non-finetuned) rows count as the single config "__base__".
 *
 * Returns null if clean, or an error message string if collisions found.
 */
function validateRowLabels(rows: SummaryRow[], groupByRow: (r: SummaryRow) => string): string | null {
  // Check 1: every row must come from the same base model
  const baseModels = new Set<string>();
  for (const r of rows) baseModels.add(r.model);
  if (baseModels.size > 1) {
    const listing = [...baseModels].map((m) => ` • ${m}`).join("\n");
    return (
      `DATA INTEGRITY ERROR: Multiple base models in the same view!\n` +
      `Found ${baseModels.size} different models being mixed together:\n` +
      listing +
      `\n\nThis means results from different models are being averaged. ` +
      `Filter by a single model before displaying.`
    );
  }

  // Check 2: every row label must map to exactly one training_config
  const configsByLabel = new Map<string, Set<string>>();
  rows.forEach((row) => {
    const label = groupByRow(row);
    const config = row.finetuned ? row.training_config : "__base__";
    let bucket = configsByLabel.get(label);
    if (bucket === undefined) {
      bucket = new Set<string>();
      configsByLabel.set(label, bucket);
    }
    bucket.add(config);
  });

  const collisions = [...configsByLabel]
    .filter(([, configs]) => configs.size > 1)
    .map(([label, configs]) => `"${label}" → ${configs.size} configs: ${[...configs].join(", ")}`);
  if (collisions.length === 0) return null;

  return (
    `DATA INTEGRITY ERROR: Row label collisions detected!\n` +
    `The following labels map to multiple different model configs ` +
    `(their results are being silently averaged):\n\n` +
    collisions.join("\n")
  );
}
// Output of aggregateData
interface AggResult {
// rowLabel → evalVariant → aggregated cell (only combinations with at least one value are present)
data: Map<string, Map<EvalVariant, AggCell>>;
// Message from validateRowLabels, or null when the labels are consistent
validationError: string | null;
}
/**
 * Groups rows by `groupByRow` label and eval variant, then summarises
 * `metric` per group as mean / standard error / count.
 *
 * Rows whose eval_variant is not in EVAL_VARIANTS, or whose metric value is
 * null/undefined/NaN, are skipped. Label validation (validateRowLabels) runs
 * first; its result is returned alongside the data, and aggregation is not
 * blocked by it.
 */
function aggregateData(
  rows: SummaryRow[],
  metric: Metric,
  groupByRow: (r: SummaryRow) => string,
): AggResult {
  const validationError = validateRowLabels(rows, groupByRow);

  // rowLabel → evalVariant → raw metric values
  const groups = new Map<string, Map<EvalVariant, number[]>>();
  rows.forEach((row) => {
    const rowLabel = groupByRow(row);
    const variant = row.eval_variant as EvalVariant;
    if (!EVAL_VARIANTS.includes(variant)) return;
    const raw = row[metric];
    if (raw === null || raw === undefined || isNaN(raw as number)) return;
    const byVariant = groups.get(rowLabel) ?? new Map<EvalVariant, number[]>();
    groups.set(rowLabel, byVariant);
    const bucket = byVariant.get(variant) ?? [];
    byVariant.set(variant, bucket);
    bucket.push(raw as number);
  });

  const data = new Map<string, Map<EvalVariant, AggCell>>();
  groups.forEach((byVariant, rowLabel) => {
    const cells = new Map<EvalVariant, AggCell>();
    byVariant.forEach((vals, variant) => {
      cells.set(variant, { mean: mean(vals), se: stdErr(vals), n: vals.length });
    });
    data.set(rowLabel, cells);
  });
  return { data, validationError };
}
// ---------------------------------------------------------------------------
// Sort row labels: Base first, then grouped by TC flag, data regime, and mode
// ---------------------------------------------------------------------------
// Sort order: Base first, then by [tc group, data regime, mode]
// tc group: no-tc (0) < tcself (1) < tcneg (2) < tco (3)
// data regime: lo (0) < semi (1) < other (2)
// mode: Pref (0) < SFT (1) < Comb (2) < other (3)
/**
 * Sort key for an internal row label:
 * [isBase ? -1 : 0, tcGroup, regime, modeRank, remaining-flags tiebreak].
 */
function rowSortKey(label: string): [number, number, number, number, string] {
  if (label === "Base") return [-1, 0, 0, 0, ""];
  const parts = label.split("-");
  // Position in the priority list + 1, so "no TC flag" maps to 0
  const tcGroup = ["tcself", "tcneg", "tco"].findIndex((flag) => parts.includes(flag)) + 1;
  const regimeIdx = ["lo", "semi"].findIndex((flag) => parts.includes(flag));
  const regime = regimeIdx === -1 ? 2 : regimeIdx;
  const modeOrder: Record<string, number> = { Pref: 0, SFT: 1, Comb: 2 };
  const modeRank = modeOrder[parts[0] ?? ""] ?? 3;
  // Everything except the mode name, for a deterministic tiebreak
  const remaining = parts.filter((p) => p !== "Pref" && p !== "SFT" && p !== "Comb").join("-");
  return [0, tcGroup, regime, modeRank, remaining];
}
/**
 * Returns a new array of labels ordered by rowSortKey. The four numeric
 * fields are compared in order, then the string tiebreak via localeCompare.
 * The input array is not modified.
 */
function sortRowLabels(labels: string[]): string[] {
  const decorated = labels.map((label) => ({ label, key: rowSortKey(label) }));
  decorated.sort(({ key: ka }, { key: kb }) => {
    const firstDiff = [0, 1, 2, 3].find((i) => ka[i] !== kb[i]);
    if (firstDiff === undefined) return ka[4].localeCompare(kb[4]);
    return (ka[firstDiff] as number) - (kb[firstDiff] as number);
  });
  return decorated.map(({ label }) => label);
}
// ---------------------------------------------------------------------------
// Sub-components
// ---------------------------------------------------------------------------
/**
 * Compact labelled `<select>` for a single choice among `options`.
 * Emits the selected option's key through `onChange`.
 */
function Dropdown<T extends string>({
  label,
  value,
  options,
  onChange,
}: {
  label: string;
  value: T;
  options: { key: T; label: string }[];
  onChange: (v: T) => void;
}) {
  const optionElements = options.map(({ key, label: optionLabel }) => (
    <option key={key} value={key}>
      {optionLabel}
    </option>
  ));
  return (
    <div className="flex flex-col gap-1">
      <label className="text-[10px] uppercase tracking-wider text-gray-500 font-medium">{label}</label>
      <select
        value={value}
        onChange={(event) => onChange(event.target.value as T)}
        className="bg-gray-800 text-gray-200 text-xs border border-gray-700 rounded px-2 py-1.5 focus:outline-none focus:border-cyan-600"
      >
        {optionElements}
      </select>
    </div>
  );
}
/**
 * Row of toggle buttons for choosing any subset of `options`.
 *
 * Toggling never mutates `selected`; a fresh Set is passed to `onChange`.
 * The buttons are explicitly `type="button"` so they never submit an
 * enclosing form. Each exposes its on/off state via `aria-pressed`, and the
 * group is named by `label` for assistive technology.
 */
function MultiSelect<T extends string>({
  label,
  selected,
  options,
  onChange,
}: {
  label: string;
  selected: Set<T>;
  options: { key: T; label: string }[];
  onChange: (s: Set<T>) => void;
}) {
  const toggle = (key: T) => {
    const next = new Set(selected);
    if (next.has(key)) next.delete(key);
    else next.add(key);
    onChange(next);
  };
  return (
    <div className="flex flex-col gap-1">
      <label className="text-[10px] uppercase tracking-wider text-gray-500 font-medium">{label}</label>
      <div className="flex flex-wrap gap-1" role="group" aria-label={label}>
        {options.map((o) => {
          const active = selected.has(o.key);
          return (
            <button
              key={o.key}
              type="button"
              aria-pressed={active}
              onClick={() => toggle(o.key)}
              className={`text-xs px-2 py-1 rounded border transition-colors ${
                active
                  ? "bg-cyan-800/60 text-cyan-200 border-cyan-600/60"
                  : "bg-gray-800 text-gray-500 border-gray-700 hover:text-gray-300"
              }`}
            >
              {o.label}
            </button>
          );
        })}
      </div>
    </div>
  );
}
// ---------------------------------------------------------------------------
// Heatmap component (pure HTML/CSS)
// ---------------------------------------------------------------------------
/**
 * HTML-table heatmap with one row per row label and one column per eval variant.
 *
 * Each cell shows the aggregated mean. Correlation metrics (names starting
 * with "corr") are shown raw to 3 dp and colored via corrColor; other metrics
 * are shown as percentages to 1 dp and colored via rdYlGn. A cell's tooltip
 * adds n and SE when more than one value was aggregated. Hovering a row label
 * shows its full config from `fullConfigs`, falling back to the label.
 */
function HeatmapGrid({
  data,
  rowLabels,
  colLabels,
  metric,
  title,
  fullConfigs,
}: {
  data: Map<string, Map<EvalVariant, AggCell>>;
  rowLabels: string[];
  colLabels: EvalVariant[];
  metric: Metric;
  title: string;
  fullConfigs: Map<string, string>;
}) {
  const isCorr = metric.startsWith("corr");
  const formatVal = (v: number) => {
    if (isNaN(v)) return "-";
    return isCorr ? v.toFixed(3) : (v * 100).toFixed(1);
  };
  // corrColor rescales [-1, 1] internally, so both color functions take the raw mean
  const colorFn = isCorr ? corrColor : rdYlGn;

  return (
    <div className="flex flex-col gap-2">
      <h3 className="text-xs font-medium text-gray-300">{title}</h3>
      <div className="overflow-x-auto">
        <table className="border-collapse text-xs">
          <thead>
            <tr>
              <th className="text-left py-1 px-2 text-gray-500 font-normal min-w-[160px] max-w-[240px]">Model</th>
              {colLabels.map((col) => (
                <th key={col} className="text-center py-1 px-3 text-gray-400 font-medium min-w-[72px]">
                  {col}
                </th>
              ))}
            </tr>
          </thead>
          <tbody>
            {rowLabels.map((rl) => {
              const evMap = data.get(rl);
              return (
                <tr key={rl} className="group">
                  <td
                    className="py-1 px-2 text-gray-300 font-mono truncate max-w-[240px]"
                    title={fullConfigs.get(rl) || rl}
                  >
                    <DisplayRowLabel label={rl} />
                  </td>
                  {colLabels.map((col) => {
                    const cell = evMap?.get(col);
                    const val = cell?.mean ?? NaN;
                    const bg = colorFn(val);
                    const fg = textColor(val, isCorr);
                    return (
                      <td
                        key={col}
                        className="text-center py-1.5 px-2 font-mono border border-gray-800/50 cursor-default transition-all hover:ring-1 hover:ring-cyan-500/50"
                        style={{ backgroundColor: bg, color: fg }}
                        title={cell ? `${formatVal(val)}${cell.n > 1 ? ` (n=${cell.n}, se=${formatVal(cell.se)})` : ""}` : "no data"}
                      >
                        {formatVal(val)}
                      </td>
                    );
                  })}
                </tr>
              );
            })}
          </tbody>
        </table>
      </div>
    </div>
  );
}
// ---------------------------------------------------------------------------
// Bar chart component (SVG)
// ---------------------------------------------------------------------------
/**
 * Grouped SVG bar chart: one group per row label, one bar per eval variant.
 *
 * Bars show the aggregated mean with ±1 SE whiskers. Correlation metrics
 * (names starting with "corr") use a raw axis; other metrics use a
 * percentage axis. The y-range is data-driven (mean ± se, padded by 0.05) and
 * clamped to [-1, 1] for correlations or [0, 1] otherwise.
 *
 * Baselines: correlation bars grow from 0 and other bars from the axis floor.
 * The baseline is clamped into the visible range, so bars stay inside the
 * plot area even when the zoomed correlation range does not contain 0.
 */
function BarChart({
  data,
  rowLabels,
  evalVariants,
  metric,
  title,
}: {
  data: Map<string, Map<EvalVariant, AggCell>>;
  rowLabels: string[];
  evalVariants: EvalVariant[];
  metric: Metric;
  title: string;
}) {
  const isCorr = metric.startsWith("corr");
  const subBarWidth = 18;
  const subGap = 2;
  const groupGap = 16;
  const numSub = evalVariants.length;
  const groupWidth = numSub * subBarWidth + (numSub - 1) * subGap;
  const chartHeight = 200;
  const marginTop = 20;
  const marginBottom = 80;
  const marginLeft = 50;
  const legendHeight = 24;
  const svgWidth = marginLeft + rowLabels.length * (groupWidth + groupGap) + 20;
  const svgHeight = chartHeight + marginTop + marginBottom + legendHeight;

  // Collect all values for scale
  const allCells: { mean: number; se: number }[] = [];
  for (const rl of rowLabels) {
    for (const ev of evalVariants) {
      const cell = data.get(rl)?.get(ev);
      if (cell && !isNaN(cell.mean)) allCells.push(cell);
    }
  }

  if (allCells.length === 0) {
    return (
      <div className="flex flex-col gap-2">
        <h3 className="text-xs font-medium text-gray-300">{title}</h3>
        <p className="text-xs text-gray-500 italic">No data</p>
      </div>
    );
  }

  const low = Math.min(...allCells.map((c) => c.mean - c.se));
  const high = Math.max(...allCells.map((c) => c.mean + c.se));
  // Data-driven range with padding, clamped to the metric's natural bounds
  const minVal = Math.max(isCorr ? -1 : 0, low - 0.05);
  const maxVal = Math.min(1, high + 0.05);

  const range = maxVal - minVal || 1;
  const yScale = (v: number) => marginTop + chartHeight * (1 - (v - minVal) / range);

  // Keep the bar baseline inside [minVal, maxVal]. Without the clamp, an
  // all-positive correlation range (minVal > 0) puts yScale(0) below the plot
  // area, and bars spill over the model labels and legend.
  const baseline = isCorr ? Math.min(maxVal, Math.max(minVal, 0)) : minVal;
  const baseY = yScale(baseline);

  // Tick marks. Step through an integer tick index rather than accumulating
  // `t += step`, so float drift cannot push the last tick past maxVal.
  const step = isCorr ? (maxVal - minVal > 0.5 ? 0.2 : 0.1) : 0.1;
  const firstTick = Math.ceil(minVal / step - 1e-9);
  const lastTick = Math.floor(maxVal / step + 1e-9);
  const ticks: number[] = [];
  for (let k = firstTick; k <= lastTick; k++) {
    ticks.push(Math.round(k * step * 1000) / 1000);
  }

  // Shade: raw=full, tc=darker, lenorm=lighter, tc+lenorm=darkest
  const variantOpacity: Record<string, number> = { raw: 0.85, tc: 0.6, lenorm: 0.7, "tc+lenorm": 0.45 };

  return (
    <div className="flex flex-col gap-2">
      <h3 className="text-xs font-medium text-gray-300">{title}</h3>
      <div className="overflow-x-auto">
        <svg width={svgWidth} height={svgHeight} className="block">
          {/* Hatching patterns for tc variants */}
          <defs>
            <pattern id="hatch-tc" patternUnits="userSpaceOnUse" width="4" height="4" patternTransform="rotate(45)">
              <line x1="0" y1="0" x2="0" y2="4" stroke="rgba(0,0,0,0.3)" strokeWidth="1" />
            </pattern>
            <pattern id="hatch-tclenorm" patternUnits="userSpaceOnUse" width="4" height="4" patternTransform="rotate(-45)">
              <line x1="0" y1="0" x2="0" y2="4" stroke="rgba(0,0,0,0.4)" strokeWidth="1.5" />
            </pattern>
          </defs>
          {/* Grid lines */}
          {ticks.map((t) => (
            <g key={t}>
              <line x1={marginLeft} x2={svgWidth - 10} y1={yScale(t)} y2={yScale(t)} stroke="#374151" strokeWidth={0.5} />
              <text x={marginLeft - 4} y={yScale(t) + 3} textAnchor="end" className="fill-gray-500 text-[10px]">
                {isCorr ? t.toFixed(1) : (t * 100).toFixed(0)}
              </text>
            </g>
          ))}
          {/* Grouped bars */}
          {rowLabels.map((rl, i) => {
            const groupX = marginLeft + i * (groupWidth + groupGap);
            const color = BAR_COLORS[i % BAR_COLORS.length];
            const labelX = groupX + groupWidth / 2;
            return (
              <g key={rl}>
                {evalVariants.map((ev, j) => {
                  const cell = data.get(rl)?.get(ev);
                  const val = cell?.mean ?? NaN;
                  if (isNaN(val)) return null;
                  const x = groupX + j * (subBarWidth + subGap);
                  const barY = yScale(val);
                  const barH = Math.abs(baseY - barY);
                  const actualY = Math.min(barY, baseY);
                  const opacity = variantOpacity[ev] ?? 0.85;
                  const hasHatch = ev === "tc" || ev === "tc+lenorm";
                  const cx = x + subBarWidth / 2;
                  const se = cell?.se ?? 0;
                  const seTop = yScale(Math.min(maxVal, val + se));
                  const seBot = yScale(Math.max(minVal, val - se));
                  return (
                    <g key={ev}>
                      <rect x={x} y={actualY} width={subBarWidth} height={barH} fill={color} rx={1} opacity={opacity} />
                      {hasHatch && (
                        <rect x={x} y={actualY} width={subBarWidth} height={barH} fill={ev === "tc" ? "url(#hatch-tc)" : "url(#hatch-tclenorm)"} rx={1} />
                      )}
                      {se > 0 && (
                        <>
                          <line x1={cx} x2={cx} y1={seTop} y2={seBot} stroke="#e5e7eb" strokeWidth={1} />
                          <line x1={cx - 3} x2={cx + 3} y1={seTop} y2={seTop} stroke="#e5e7eb" strokeWidth={1} />
                          <line x1={cx - 3} x2={cx + 3} y1={seBot} y2={seBot} stroke="#e5e7eb" strokeWidth={1} />
                        </>
                      )}
                      <title>{`${displayRowLabel(rl)} [${ev}]: ${isCorr ? val.toFixed(3) : (val * 100).toFixed(1)} (n=${cell?.n ?? 0}, se=${se.toFixed(4)})`}</title>
                    </g>
                  );
                })}
                {/* Model label */}
                <text
                  x={labelX}
                  y={marginTop + chartHeight + 8}
                  textAnchor="end"
                  transform={`rotate(-45, ${labelX}, ${marginTop + chartHeight + 8})`}
                  className="fill-gray-400 text-[9px]"
                >
                  <SvgRowLabel label={rl} maxLen={22} />
                </text>
              </g>
            );
          })}
          {/* Legend */}
          {evalVariants.map((ev, j) => {
            const lx = marginLeft + j * 80;
            const ly = svgHeight - 10;
            const opacity = variantOpacity[ev] ?? 0.85;
            const hasHatch = ev === "tc" || ev === "tc+lenorm";
            return (
              <g key={ev}>
                <rect x={lx} y={ly - 8} width={12} height={8} fill="#636EFA" opacity={opacity} rx={1} />
                {hasHatch && <rect x={lx} y={ly - 8} width={12} height={8} fill={ev === "tc" ? "url(#hatch-tc)" : "url(#hatch-tclenorm)"} rx={1} />}
                <text x={lx + 16} y={ly} className="fill-gray-400 text-[9px]">{ev}</text>
              </g>
            );
          })}
        </svg>
      </div>
    </div>
  );
}
// ---------------------------------------------------------------------------
// Delta comparison view: side-by-side heatmaps + delta
// ---------------------------------------------------------------------------
// Diverging color scale for deltas: red (negative) ↔ white (zero) ↔ blue (positive)
// Uses Plotly's RdBu stops (reversed so blue = positive)
// Both ramps have 6 evenly spaced [r, g, b] stops, which deltaColor relies on.
// DELTA_NEG runs deep red → white; deltaColor walks it from the white end.
const DELTA_NEG: [number, number, number][] = [
[103, 0, 31], // -max: deep red
[178, 24, 43],
[214, 96, 77],
[244, 165, 130],
[253, 219, 199],
[255, 255, 255], // zero: white
];
// DELTA_POS runs white → deep blue
const DELTA_POS: [number, number, number][] = [
[255, 255, 255], // zero: white
[209, 229, 240],
[146, 197, 222],
[67, 147, 195],
[33, 102, 172],
[5, 48, 97], // +max: deep blue
];
/**
 * Diverging background color for a difference value.
 *
 * |delta| is scaled by 0.10 (saturating at ±0.10, i.e. ±10pp for
 * percentage metrics). The result is interpolated outward from white along
 * DELTA_POS (delta ≥ 0) or DELTA_NEG (delta < 0). NaN renders light gray.
 */
function deltaColor(delta: number): string {
  if (isNaN(delta)) return "#f3f4f6";
  const maxDelta = 0.10; // saturate at ±10pp
  const t = Math.min(Math.abs(delta) / maxDelta, 1);
  // Orient both ramps so index 0 is white and the last index is the saturated color
  const ramp = delta < 0 ? [...DELTA_NEG].reverse() : DELTA_POS;
  const pos = t * (ramp.length - 1);
  const seg = Math.min(Math.floor(pos), ramp.length - 2);
  const frac = pos - seg;
  const from = ramp[seg];
  const to = ramp[seg + 1];
  const mix = (ch: number) => Math.round(from[ch] + frac * (to[ch] - from[ch]));
  return `rgb(${mix(0)}, ${mix(1)}, ${mix(2)})`;
}
/**
 * Text color for a deltaColor cell: white once |delta| exceeds 60% of the
 * 0.10 saturation point, dark otherwise. NaN gets gray.
 */
function deltaTextColor(delta: number): string {
  if (isNaN(delta)) return "#9ca3af";
  const intensity = Math.min(Math.abs(delta) / 0.10, 1);
  if (intensity <= 0.6) return "#1f2937";
  return "#ffffff";
}
/**
 * Three side-by-side tables for one metric: the left configuration's values,
 * the right configuration's values, and their per-cell delta (right − left).
 *
 * Each entry of `pairs` is one row across all three tables, `[leftRowLabel, rightRowLabel]`,
 * looked up in `leftData` and `rightData` respectively. Columns are `colLabels`.
 * A cell with no data renders as "-". A delta cell is "-" unless both sides have
 * a value for that column.
 * Correlation metrics (`corr*`) are shown raw to 3 decimals; other metrics are
 * shown ×100 to 1 decimal.
 */
function SideBySideDelta({
  pairs,
  leftData,
  rightData,
  colLabels,
  metric,
  title,
  leftLabel,
  rightLabel,
  fullConfigs,
}: {
  pairs: [string, string][]; // [leftRowLabel, rightRowLabel]
  leftData: Map<string, Map<EvalVariant, AggCell>>;
  rightData: Map<string, Map<EvalVariant, AggCell>>;
  colLabels: EvalVariant[];
  metric: Metric;
  title: string;
  leftLabel: string;
  rightLabel: string;
  fullConfigs: Map<string, string>;
}) {
  const isCorr = metric.startsWith("corr");
  const digits = isCorr ? 3 : 1;
  const scale = isCorr ? 1 : 100;

  const formatVal = (v: number) => (isNaN(v) ? "-" : (v * scale).toFixed(digits));
  const formatDelta = (v: number) => {
    if (isNaN(v)) return "-";
    const text = (v * scale).toFixed(digits);
    // A delta that rounds to zero prints as plain "0.0", not "+0.0" / "-0.0".
    if (Number(text) === 0) return (0).toFixed(digits);
    return v > 0 ? `+${text}` : text;
  };
  const colorFn = isCorr ? corrColor : rdYlGn;

  // Key rows on both labels so pairs sharing one side cannot collide.
  const pairKey = ([l, r]: [string, string]) => `${l}::${r}`;
  const configTitle = (label: string) => fullConfigs.get(label) || label;
  const meanOf = (evMap: Map<EvalVariant, AggCell> | undefined, col: EvalVariant) =>
    evMap?.get(col)?.mean ?? NaN;

  const renderColumnHeaders = () =>
    colLabels.map((col) => (
      <th key={col} className="text-center py-1 px-3 text-gray-400 font-medium min-w-[60px]">{col}</th>
    ));
  const renderValueCell = (col: EvalVariant, val: number) => (
    <td
      key={col}
      className="text-center py-1.5 px-2 font-mono border border-gray-800/50"
      style={{ backgroundColor: colorFn(val), color: textColor(val, isCorr) }}
    >
      {formatVal(val)}
    </td>
  );

  return (
    <div className="flex flex-col gap-3">
      <h3 className="text-xs font-medium text-gray-300">{title}</h3>
      <div className="flex gap-6 overflow-x-auto">
        {/* Left heatmap */}
        <div>
          <p className="text-[10px] uppercase tracking-wider text-gray-500 font-medium mb-1">{leftLabel}</p>
          <table className="border-collapse text-xs">
            <thead>
              <tr>
                <th className="text-left py-1 px-2 text-gray-500 font-normal min-w-[120px]">Model</th>
                {renderColumnHeaders()}
              </tr>
            </thead>
            <tbody>
              {pairs.map((pair) => {
                const [leftRow] = pair;
                const evMap = leftData.get(leftRow);
                return (
                  <tr key={pairKey(pair)}>
                    <td className="py-1 px-2 text-gray-300 font-mono text-[11px]" title={configTitle(leftRow)}>
                      <DisplayRowLabel label={leftRow} />
                    </td>
                    {colLabels.map((col) => renderValueCell(col, meanOf(evMap, col)))}
                  </tr>
                );
              })}
            </tbody>
          </table>
        </div>
        {/* Right heatmap */}
        <div>
          <p className="text-[10px] uppercase tracking-wider text-gray-500 font-medium mb-1">{rightLabel}</p>
          <table className="border-collapse text-xs">
            <thead>
              <tr>
                <th className="text-left py-1 px-1.5 text-gray-500 font-normal min-w-[80px]">Model</th>
                {renderColumnHeaders()}
              </tr>
            </thead>
            <tbody>
              {pairs.map((pair) => {
                const [, rightRow] = pair;
                const evMap = rightData.get(rightRow);
                return (
                  <tr key={pairKey(pair)}>
                    <td className="py-1 px-1.5 text-gray-300 font-mono text-[9px] whitespace-nowrap" title={configTitle(rightRow)}>
                      <DisplayRowLabel label={rightRow} />
                    </td>
                    {colLabels.map((col) => renderValueCell(col, meanOf(evMap, col)))}
                  </tr>
                );
              })}
            </tbody>
          </table>
        </div>
        {/* Delta heatmap (rows labelled by the left-hand configuration) */}
        <div>
          <p className="text-[10px] uppercase tracking-wider text-gray-500 font-medium mb-1">Delta ({rightLabel} − {leftLabel})</p>
          <table className="border-collapse text-xs">
            <thead>
              <tr>
                <th className="text-left py-1 px-2 text-gray-500 font-normal min-w-[120px]">Model</th>
                {renderColumnHeaders()}
              </tr>
            </thead>
            <tbody>
              {pairs.map((pair) => {
                const [leftRow, rightRow] = pair;
                const leftEvMap = leftData.get(leftRow);
                const rightEvMap = rightData.get(rightRow);
                return (
                  <tr key={pairKey(pair)}>
                    <td className="py-1 px-2 text-gray-300 font-mono text-[11px]" title={configTitle(leftRow)}>
                      <DisplayRowLabel label={leftRow} />
                    </td>
                    {colLabels.map((col) => {
                      const leftVal = meanOf(leftEvMap, col);
                      const rightVal = meanOf(rightEvMap, col);
                      // NaN (rendered "-") unless both sides have data for this column.
                      const delta = !isNaN(leftVal) && !isNaN(rightVal) ? rightVal - leftVal : NaN;
                      return (
                        <td
                          key={col}
                          className="text-center py-1.5 px-2 font-mono border border-gray-800/50"
                          style={{ backgroundColor: deltaColor(delta), color: deltaTextColor(delta) }}
                        >
                          {formatDelta(delta)}
                        </td>
                      );
                    })}
                  </tr>
                );
              })}
            </tbody>
          </table>
        </div>
      </div>
    </div>
  );
}
// ---------------------------------------------------------------------------
// Per-task collapsible section
// ---------------------------------------------------------------------------
/**
 * Collapsible "per-task breakdown": when expanded, renders one HeatmapGrid per
 * metric for each entry in `sections`, showing only row labels in `visibleRows`.
 * A metric with no visible rows renders nothing. Collapsed by default.
 */
function PerTaskCollapsible({
  sections,
  fullConfigs,
  visibleRows,
}: {
  sections: { label: string; taskCount: number; metrics: { metric: Metric; data: Map<string, Map<EvalVariant, AggCell>> }[] }[];
  fullConfigs: Map<string, string>;
  visibleRows: Set<string>;
}) {
  const [open, setOpen] = useState(false);
  return (
    <div className="border-t border-gray-700 pt-4">
      <button
        type="button"
        onClick={() => setOpen((o) => !o)}
        aria-expanded={open}
        className="flex items-center gap-2 text-sm font-semibold text-gray-300 hover:text-gray-100 transition-colors"
      >
        {/* Disclosure chevron: points right when collapsed, rotated down when open.
            It needs visible content or the rotation toggle has no effect. */}
        <span aria-hidden="true" className={`text-xs transition-transform ${open ? "rotate-90" : ""}`}>▶</span>
        Per-task breakdown ({sections.length} tasks)
      </button>
      {open && (
        <div className="mt-4 space-y-10">
          {sections.map((section) => (
            <div key={section.label} className="space-y-4">
              <h3 className="text-sm font-medium text-cyan-300 border-b border-gray-800 pb-1">{section.label}</h3>
              <div className="flex flex-wrap gap-6">
                {section.metrics.map(({ metric: m, data }) => {
                  const metricRowLabels = sortRowLabels(Array.from(data.keys()).filter((l) => visibleRows.has(l)));
                  if (metricRowLabels.length === 0) return null;
                  const metricLabel = METRICS.find((x) => x.key === m)?.label ?? m;
                  // Columns: eval variants that have data for any row (visible or not), in canonical order.
                  const metricEvs = EVAL_VARIANTS.filter((ev) => {
                    for (const evMap of data.values()) {
                      if (evMap.has(ev)) return true;
                    }
                    return false;
                  });
                  return (
                    <HeatmapGrid
                      key={m}
                      data={data}
                      rowLabels={metricRowLabels}
                      colLabels={metricEvs}
                      metric={m}
                      title={metricLabel}
                      fullConfigs={fullConfigs}
                    />
                  );
                })}
              </div>
            </div>
          ))}
        </div>
      )}
    </div>
  );
}
// ---------------------------------------------------------------------------
// Per-task comparison collapsible (side-by-side + delta per task)
// ---------------------------------------------------------------------------
/**
 * Collapsible per-task breakdown for a comparison preset.
 *
 * For every task present in `filteredRows`, renders one SideBySideDelta per metric.
 * Each one compares rows selected by `leftFilter` against rows selected by
 * `rightFilter`, restricted to the entries of `pairs` that have data on both sides
 * for that task. Renders nothing when the rows span at most one task.
 * Aggregation only happens while the section is expanded.
 */
function PerTaskComparisonCollapsible({
  filteredRows,
  pairs,
  leftLabel,
  rightLabel,
  leftFilter,
  rightFilter,
  fullConfigs,
}: {
  filteredRows: SummaryRow[];
  pairs: [string, string][];
  leftLabel: string;
  rightLabel: string;
  leftFilter: (r: SummaryRow) => boolean;
  rightFilter: (r: SummaryRow) => boolean;
  fullConfigs: Map<string, string>;
}) {
  const [open, setOpen] = useState(false);
  // Group rows by task in a single pass (row order preserved within each task),
  // instead of re-filtering every row once per task on each render.
  const rowsByTask = useMemo(() => {
    const groups = new Map<string, SummaryRow[]>();
    for (const r of filteredRows) {
      const bucket = groups.get(r.task);
      if (bucket) bucket.push(r);
      else groups.set(r.task, [r]);
    }
    return groups;
  }, [filteredRows]);
  const tasks = useMemo(() => Array.from(rowsByTask.keys()).sort(), [rowsByTask]);
  if (tasks.length <= 1) return null;
  return (
    <div className="border-t border-gray-700 pt-4">
      <button
        type="button"
        onClick={() => setOpen((o) => !o)}
        aria-expanded={open}
        className="flex items-center gap-2 text-sm font-semibold text-gray-300 hover:text-gray-100 transition-colors"
      >
        {/* Disclosure chevron: points right when collapsed, rotated down when open.
            It needs visible content or the rotation toggle has no effect. */}
        <span aria-hidden="true" className={`text-xs transition-transform ${open ? "rotate-90" : ""}`}>▶</span>
        Per-task breakdown ({tasks.length} tasks)
      </button>
      {open && (
        <div className="mt-4 space-y-10">
          {tasks.map((task) => {
            const taskRows = rowsByTask.get(task) ?? [];
            const leftRows = taskRows.filter(leftFilter);
            const rightRows = taskRows.filter(rightFilter);
            return (
              <div key={task} className="space-y-4">
                <h3 className="text-sm font-medium text-cyan-300 border-b border-gray-800 pb-1">{task}</h3>
                <div className="space-y-6">
                  {METRICS.map((m) => {
                    const leftData = aggregateData(leftRows, m.key, getRowLabel).data;
                    const rightData = aggregateData(rightRows, m.key, getRowLabel).data;
                    // Only pairs with data on both sides for this task/metric.
                    const existingPairs = pairs.filter(([l, r]) => leftData.has(l) && rightData.has(r));
                    if (existingPairs.length === 0) return null;
                    // Columns: eval variants present on either side, in canonical order.
                    const evs = EVAL_VARIANTS.filter((ev) => {
                      for (const evMap of [...leftData.values(), ...rightData.values()]) {
                        if (evMap.has(ev)) return true;
                      }
                      return false;
                    });
                    return (
                      <SideBySideDelta
                        key={m.key}
                        pairs={existingPairs}
                        leftData={leftData}
                        rightData={rightData}
                        colLabels={evs}
                        metric={m.key}
                        title={m.label}
                        leftLabel={leftLabel}
                        rightLabel={rightLabel}
                        fullConfigs={fullConfigs}
                      />
                    );
                  })}
                </div>
              </div>
            );
          })}
        </div>
      )}
    </div>
  );
}
// ---------------------------------------------------------------------------
// Main component
// ---------------------------------------------------------------------------
export default function HeatmapViewer({
datasetRepo,
split: _split = "train",
onClose,
}: HeatmapViewerProps) {
const fullRepo = datasetRepo.includes("/") ? datasetRepo : `${HF_ORG}/${datasetRepo}`;
const shortName = datasetRepo.split("/").pop() ?? datasetRepo;
const { rows, loading, error, progress, refetch } = useDatasetRows(fullRepo, _split);
// --- Filter state ---
const [selectedModel, setSelectedModel] = useState<string>("__first__");
const [selectedFamily, setSelectedFamily] = useState<string>("hypernym");
const [selectedTask, setSelectedTask] = useState<string>("__all__");
const [selectedSplit, setSelectedSplit] = useState<string>("test");
const [selectedTCType, setSelectedTCType] = useState<TCType>("self");
const [selectedMetric, setSelectedMetric] = useState<Metric>("gen_roc");
const [selectedDomain, setSelectedDomain] = useState<DomainFilter>("all");
const [viewMode, setViewMode] = useState<ViewMode>("heatmap");
const [barVariant, setBarVariant] = useState<EvalVariant>("raw");
const [comparisonPreset, setComparisonPreset] = useState<ComparisonPreset>("all");
// Row visibility: null means "show all", otherwise explicit set from preset or manual toggle
const [visibleRows, setVisibleRows] = useState<Set<string> | null>(null);
// --- Derived: available splits, tasks, families ---
// --- Available base models (model column = base model identity, same for all finetuned variants) ---
const availableModels = useMemo(() => {
const s = new Set(rows.map((r) => r.model));
return Array.from(s).sort();
}, [rows]);
// Display-friendly name: "v6-google_gemma-2-2b" → "gemma-2-2b"
const modelDisplayName = useCallback((m: string) => {
return m.replace(/^v\d+-[^_]+_/, "");
}, []);
// Resolve __first__ to actual first model once data loads
const resolvedModel = useMemo(() => {
if (selectedModel === "__first__" && availableModels.length > 0) return availableModels[0];
if (availableModels.includes(selectedModel)) return selectedModel;
return availableModels[0] ?? "";
}, [selectedModel, availableModels]);
const availableSplits = useMemo(() => {
const s = new Set(rows.map((r) => r.split));
return Array.from(s).sort();
}, [rows]);
const availableTasks = useMemo(() => {
let filtered = rows.filter((r) => isValidEvalTask(r.task));
const fam = TASK_FAMILIES.find((f) => f.key === selectedFamily);
if (fam) filtered = filtered.filter((r) => r.task.startsWith(fam.prefix));
const tasks = new Set(filtered.map((r) => r.task));
return Array.from(tasks).sort();
}, [rows, selectedFamily]);
// Reset task selection when family changes
useEffect(() => {
setSelectedTask("__all__");
}, [selectedFamily]);
// --- Filtering pipeline ---
const filteredRows = useMemo(() => {
let result = rows;
// Model filter — critical: never mix different base models
if (resolvedModel) {
result = result.filter((r) => r.model === resolvedModel);
}
// Exclude non-eval tasks (concat, bare family names, etc.)
result = result.filter((r) => isValidEvalTask(r.task));
// Only show finetuned models trained on valid combined datasets (Setting U)
result = result.filter((r) => isValidFinetunedModel(r));
// Split filter
result = result.filter((r) => r.split === selectedSplit);
// TC type filter (single-select: each TC type is a different eval condition)
result = result.filter((r) => getEvalTCType(r) === selectedTCType);
// For finetuned models trained with a specific TC, only show rows where eval TC matches training TC
result = result.filter((r) => isMatchedTC(r));
// Family filter — also exclude finetuned models trained on a different family
const fam = TASK_FAMILIES.find((f) => f.key === selectedFamily);
if (fam) {
result = result.filter((r) => r.task.startsWith(fam.prefix));
result = result.filter((r) => {
if (!r.finetuned) return true; // base models always shown
const trainFam = getTrainingFamily(r.training_config);
return trainFam === null || trainFam === selectedFamily;
});
}
// Specific task filter
if (selectedTask !== "__all__") {
result = result.filter((r) => r.task === selectedTask);
}
// Domain filter (OOD vs in-domain)
if (selectedDomain !== "all") {
result = result.filter((r) => {
const family = getTaskFamily(r.task);
if (!family) return true;
const ood = isOodTask(r.task, r.training_config, family);
if (ood === null) return true;
return selectedDomain === "ood" ? ood : !ood;
});
}
return result;
}, [rows, resolvedModel, selectedSplit, selectedTCType, selectedFamily, selectedTask, selectedDomain]);
// --- Build row label → full config mapping ---
const fullConfigs = useMemo(() => {
const map = new Map<string, string>();
for (const r of filteredRows) {
const label = getRowLabel(r);
if (!map.has(label)) {
map.set(label, r.finetuned ? r.training_config : "Base (not finetuned)");
}
}
return map;
}, [filteredRows]);
// --- All available row labels (before visibility filter) ---
const allRowLabels = useMemo(() => {
const labels = new Set<string>();
for (const r of filteredRows) labels.add(getRowLabel(r));
return sortRowLabels(Array.from(labels));
}, [filteredRows]);
// Effective visible rows: use explicit selection, preset, or show all
const effectiveVisibleRows = useMemo(() => {
const available = new Set(allRowLabels);
// If a comparison preset restricts rows, apply that
if (comparisonPreset === "training-effect" && visibleRows === null) {
const effective = new Set<string>();
for (const label of available) {
if (TRAINING_EFFECT_ROWS.has(label)) effective.add(label);
}
return effective.size > 0 ? effective : available;
}
// Manual selection
if (visibleRows !== null) {
const effective = new Set<string>();
for (const label of visibleRows) {
if (available.has(label)) effective.add(label);
}
return effective.size > 0 ? effective : available;
}
// Default: show all except hidden-by-default rows
const effective = new Set<string>();
for (const label of available) {
if (!DEFAULT_HIDDEN_ROWS.has(label)) effective.add(label);
}
return effective.size > 0 ? effective : available;
}, [visibleRows, allRowLabels, comparisonPreset]);
// --- Aggregation ---
const aggResult = useMemo(() => {
return aggregateData(filteredRows, selectedMetric, getRowLabel);
}, [filteredRows, selectedMetric]);
const aggData = aggResult.data;
const validationError = aggResult.validationError;
const rowLabels = useMemo(() => {
return sortRowLabels(Array.from(aggData.keys()).filter((l) => effectiveVisibleRows.has(l)));
}, [aggData, effectiveVisibleRows]);
// Available eval variants (only those with data)
const availableEvalVariants = useMemo(() => {
const evs = new Set<EvalVariant>();
for (const evMap of aggData.values()) {
for (const ev of evMap.keys()) evs.add(ev);
}
return EVAL_VARIANTS.filter((ev) => evs.has(ev));
}, [aggData]);
// --- Stats ---
const stats = useMemo(() => {
const taskCount = new Set(filteredRows.map((r) => r.task)).size;
const modelCount = new Set(filteredRows.map((r) => r.finetuned ? r.training_config : "base")).size;
return { tasks: taskCount, models: modelCount, rows: filteredRows.length };
}, [filteredRows]);
// --- Domain-split rows for aggregated heatmaps ---
const domainSplitRows = useMemo(() => {
const hasDomainSplit = (selectedFamily === "ifeval" || selectedFamily === "hypernym") && selectedTask === "__all__";
if (!hasDomainSplit || selectedDomain !== "all") return null;
const oodRows = filteredRows.filter((r) => {
const family = getTaskFamily(r.task);
if (!family) return false;
return isOodTask(r.task, r.training_config, family) === true;
});
const idRows = filteredRows.filter((r) => {
const family = getTaskFamily(r.task);
if (!family) return false;
return isOodTask(r.task, r.training_config, family) === false;
});
const oodTaskCount = new Set(oodRows.map((r) => r.task)).size;
const idTaskCount = new Set(idRows.map((r) => r.task)).size;
return { oodRows, idRows, oodTaskCount, idTaskCount };
}, [filteredRows, selectedFamily, selectedTask, selectedDomain]);
// --- Multi-metric heatmaps (one set per domain group when applicable) ---
type HeatmapSection = {
label: string;
taskCount: number;
metrics: { metric: Metric; data: Map<string, Map<EvalVariant, AggCell>> }[];
};
const heatmapSections = useMemo((): HeatmapSection[] => {
if (viewMode !== "heatmap") return [];
const buildSection = (sectionRows: SummaryRow[], label: string, taskCount: number): HeatmapSection => {
const metrics: { metric: Metric; data: Map<string, Map<EvalVariant, AggCell>> }[] = [];
for (const m of METRICS) {
metrics.push({ metric: m.key, data: aggregateData(sectionRows, m.key, getRowLabel).data });
}
return { label, taskCount, metrics };
};
if (domainSplitRows) {
const sections: HeatmapSection[] = [];
if (domainSplitRows.oodRows.length > 0) {
sections.push(buildSection(domainSplitRows.oodRows, "Out-of-domain", domainSplitRows.oodTaskCount));
}
if (domainSplitRows.idRows.length > 0) {
sections.push(buildSection(domainSplitRows.idRows, "In-domain", domainSplitRows.idTaskCount));
}
return sections;
}
// For families without domain splits but with multiple tasks, still show aggregate
return [buildSection(filteredRows, "All tasks", stats.tasks)];
}, [filteredRows, viewMode, domainSplitRows, stats.tasks]);
// --- Per-task heatmaps (when viewing "All tasks" in a family) ---
const perTaskSections = useMemo((): HeatmapSection[] => {
if (viewMode !== "heatmap") return [];
if (selectedTask !== "__all__") return []; // single task already shown in main sections
const tasks = new Set(filteredRows.map((r) => r.task));
if (tasks.size <= 1) return []; // no point showing per-task if only one
const sorted = Array.from(tasks).sort();
return sorted.map((task) => {
const taskRows = filteredRows.filter((r) => r.task === task);
const metrics: { metric: Metric; data: Map<string, Map<EvalVariant, AggCell>> }[] = [];
for (const m of METRICS) {
metrics.push({ metric: m.key, data: aggregateData(taskRows, m.key, getRowLabel).data });
}
return { label: task, taskCount: 1, metrics };
});
}, [filteredRows, viewMode, selectedTask]);
// --- Comparison preset data ---
// For delta comparisons, we need rows from multiple TC types simultaneously
const comparisonBaseRows = useMemo(() => {
// Same as filteredRows but without TC type filter
let result = rows;
// Model filter — must match filteredRows
if (resolvedModel) {
result = result.filter((r) => r.model === resolvedModel);
}
result = result.filter((r) => isValidEvalTask(r.task));
result = result.filter((r) => isValidFinetunedModel(r));
result = result.filter((r) => r.split === selectedSplit);
// For finetuned models trained with a specific TC, only show rows where eval TC matches training TC
result = result.filter((r) => isMatchedTC(r));
const fam = TASK_FAMILIES.find((f) => f.key === selectedFamily);
if (fam) {
result = result.filter((r) => r.task.startsWith(fam.prefix));
result = result.filter((r) => {
if (!r.finetuned) return true;
const trainFam = getTrainingFamily(r.training_config);
return trainFam === null || trainFam === selectedFamily;
});
}
if (selectedTask !== "__all__") {
result = result.filter((r) => r.task === selectedTask);
}
if (selectedDomain !== "all") {
result = result.filter((r) => {
const family = getTaskFamily(r.task);
if (!family) return true;
const ood = isOodTask(r.task, r.training_config, family);
if (ood === null) return true;
return selectedDomain === "ood" ? ood : !ood;
});
}
return result;
}, [rows, resolvedModel, selectedSplit, selectedFamily, selectedTask, selectedDomain]);
// Build aggregated data for rows with a specific training TC flag (in the row label)
// Note: this is the TRAINING tc flag (in the model name), NOT the eval TC type
const aggByTrainingTC = useCallback((trainingTCFlag: string | null, metric: Metric) => {
const subset = filteredRows.filter((r) => {
const label = getRowLabel(r);
if (trainingTCFlag === null) {
// No training TC flag — exclude rows that have any TC flag
return !TC_FLAGS.some((f) => hasTCFlag(label, f));
}
return hasTCFlag(label, trainingTCFlag);
});
return aggregateData(subset, metric, getRowLabel).data;
}, [filteredRows]);
// Dynamic pairing: find rows that differ only by a TC flag
const buildPairs = useCallback((leftFlag: string | null, rightFlag: string): [string, string][] => {
const leftLabels = new Set<string>();
const rightLabels = new Set<string>();
for (const r of filteredRows) {
const label = getRowLabel(r);
if (label === "Base") continue;
if (rightFlag && hasTCFlag(label, rightFlag)) {
rightLabels.add(label);
} else if (leftFlag === null && !TC_FLAGS.some((f) => hasTCFlag(label, f))) {
leftLabels.add(label);
} else if (leftFlag && hasTCFlag(label, leftFlag)) {
leftLabels.add(label);
}
}
// Match: strip TC flag from right label, see if it matches a left label
const pairs: [string, string][] = [];
for (const rl of rightLabels) {
const stripped = stripTCFlag(rl);
if (leftLabels.has(stripped)) {
pairs.push([stripped, rl]);
}
}
return sortRowLabels(pairs.map(([l]) => l)).map((l) => {
const r = pairs.find(([left]) => left === l)![1];
return [l, r] as [string, string];
});
}, [filteredRows]);
// + TC-Self pairs: without-tc ↔ tcself
const tcSelfPairs = useMemo((): [string, string][] => {
if (comparisonPreset !== "plus-tcself") return [];
return buildPairs(null, "tcself");
}, [comparisonPreset, buildPairs]);
// TC-Self vs TC-Neg: for each model, compare train-tcself+eval-self vs train-tcneg+eval-neg
// Base model: same "Base" label, just different eval TC type
// Finetuned: pair by stripping tcself/tcneg from training label
const aggByMatchedTC = useCallback((evalTC: TCType, metric: Metric) => {
// From comparisonBaseRows (no eval TC filter), get rows matching this eval TC type
const subset = comparisonBaseRows.filter((r) => getEvalTCType(r) === evalTC);
return aggregateData(subset, metric, getRowLabel).data;
}, [comparisonBaseRows]);
const tcSelfVsNegPairs = useMemo((): [string, string][] => {
if (comparisonPreset !== "tcself-vs-tcneg") return [];
// Left: models evaluated with self-TC (base + models trained with tcself)
const selfRows = comparisonBaseRows.filter((r) => getEvalTCType(r) === "self");
const selfLabels = new Set(selfRows.map(getRowLabel));
// Right: models evaluated with neg-TC (base + models trained with tcneg)
const negRows = comparisonBaseRows.filter((r) => getEvalTCType(r) === "neg");
const negLabels = new Set(negRows.map(getRowLabel));
const pairs: [string, string][] = [];
// Base: appears in both with same label
if (selfLabels.has("Base") && negLabels.has("Base")) {
pairs.push(["Base", "Base"]);
}
// Finetuned: match tcself label ↔ tcneg label (strip tc flag to find pair)
for (const sl of selfLabels) {
if (sl === "Base") continue;
if (!hasTCFlag(sl, "tcself")) continue;
const negVersion = sl.replace("tcself", "tcneg");
if (negLabels.has(negVersion)) {
pairs.push([sl, negVersion]);
}
}
return sortRowLabels(pairs.map(([l]) => l)).map((l) => {
const p = pairs.find(([left]) => left === l)!;
return p;
});
}, [comparisonPreset, comparisonBaseRows]);
// TC-Self vs TC-GPT2: pair tcself-trained (eval=self) with tco-trained (eval=gpt2)
// e.g. "Comb-tcself-lo" ↔ "Comb-tco-lo"
const tcSelfVsGpt2Pairs = useMemo((): [string, string][] => {
if (comparisonPreset !== "tcself-vs-tcgpt2") return [];
const selfRows = comparisonBaseRows.filter((r) => getEvalTCType(r) === "self");
const selfLabels = new Set(selfRows.map(getRowLabel));
const gpt2Rows = comparisonBaseRows.filter((r) => getEvalTCType(r) === "gpt2");
const gpt2Labels = new Set(gpt2Rows.map(getRowLabel));
const pairs: [string, string][] = [];
// Base: appears in both with same label
if (selfLabels.has("Base") && gpt2Labels.has("Base")) {
pairs.push(["Base", "Base"]);
}
// Finetuned: match tcself label ↔ tco label (swap tc flag to find pair)
for (const sl of selfLabels) {
if (sl === "Base") continue;
if (!hasTCFlag(sl, "tcself")) continue;
const tcoVersion = sl.replace("tcself", "tco");
if (gpt2Labels.has(tcoVersion)) {
pairs.push([sl, tcoVersion]);
}
}
return sortRowLabels(pairs.map(([l]) => l)).map((l) => {
const p = pairs.find(([left]) => left === l)!;
return p;
});
}, [comparisonPreset, comparisonBaseRows]);
// Family options for dropdown
const familyOptions = useMemo((): { key: string; label: string }[] => {
const options: { key: string; label: string }[] = [];
for (const fam of TASK_FAMILIES) {
const count = new Set(rows.filter((r) => r.task.startsWith(fam.prefix)).map((r) => r.task)).size;
if (count > 0) options.push({ key: fam.key, label: `${fam.label} (${count})` });
}
return options;
}, [rows]);
// Task options for dropdown
const taskOptions = useMemo(() => {
const options = [{ key: "__all__", label: `All tasks (${availableTasks.length})` }];
for (const t of availableTasks) {
options.push({ key: t, label: t });
}
return options;
}, [availableTasks]);
// Domain filter visibility
const showDomainFilter = selectedFamily === "ifeval" || selectedFamily === "hypernym";
return (
<div className="fixed inset-0 z-50 flex flex-col bg-gray-950">
{/* Header */}
<div className="flex items-center justify-between px-5 py-3 border-b border-gray-800 flex-shrink-0 bg-gray-900">
<div className="flex items-center gap-3 min-w-0">
<span className="text-sm font-semibold text-gray-200 truncate">{shortName}</span>
<span className="text-xs text-gray-600 border border-gray-700 px-1.5 py-0.5 rounded">heatmap</span>
{!loading && (
<span className="text-xs text-gray-500">
{rows.length.toLocaleString()} rows loaded
</span>
)}
</div>
<div className="flex items-center gap-2 flex-shrink-0">
<a
href={`https://huggingface.co/datasets/${fullRepo}`}
target="_blank"
rel="noopener noreferrer"
className="text-xs text-gray-500 hover:text-cyan-400 transition-colors px-2 py-1 border border-gray-700 rounded"
>
HF
</a>
<button
onClick={onClose}
className="text-gray-400 hover:text-gray-200 transition-colors p-1 rounded hover:bg-gray-700"
aria-label="Close viewer"
>
<svg xmlns="http://www.w3.org/2000/svg" className="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" strokeWidth={2}>
<path strokeLinecap="round" strokeLinejoin="round" d="M6 18L18 6M6 6l12 12" />
</svg>
</button>
</div>
</div>
{/* Loading */}
{loading && (
<div className="flex-1 flex items-center justify-center">
<div className="flex flex-col items-center gap-3">
<div className="w-6 h-6 border-2 border-cyan-500 border-t-transparent rounded-full animate-spin" />
<p className="text-sm text-gray-400">
Loading... {progress.loaded.toLocaleString()} / {progress.total.toLocaleString()} rows
</p>
</div>
</div>
)}
{/* Error */}
{!loading && error && (
<div className="flex-1 flex items-center justify-center">
<div className="max-w-lg text-center space-y-3">
<p className="text-sm font-medium text-red-400">Failed to load dataset</p>
<p className="text-xs text-gray-500 font-mono break-words bg-gray-800 rounded p-3">{error}</p>
<button
onClick={refetch}
className="text-xs text-cyan-400 hover:text-cyan-300 border border-cyan-700/50 px-3 py-1 rounded transition-colors"
>
Retry
</button>
</div>
</div>
)}
{/* Main content */}
{!loading && !error && rows.length > 0 && (
<div className="flex-1 flex overflow-hidden">
{/* Controls sidebar */}
<div className="w-64 flex-shrink-0 border-r border-gray-800 bg-gray-900/50 overflow-y-auto p-4 space-y-4">
{/* View mode */}
<div className="flex gap-1">
<button
onClick={() => setViewMode("heatmap")}
className={`flex-1 text-xs py-1.5 rounded border transition-colors ${
viewMode === "heatmap"
? "bg-cyan-800/60 text-cyan-200 border-cyan-600/60"
: "bg-gray-800 text-gray-500 border-gray-700"
}`}
>
Heatmap
</button>
<button
onClick={() => setViewMode("bar")}
className={`flex-1 text-xs py-1.5 rounded border transition-colors ${
viewMode === "bar"
? "bg-cyan-800/60 text-cyan-200 border-cyan-600/60"
: "bg-gray-800 text-gray-500 border-gray-700"
}`}
>
Bar Plot
</button>
</div>
{availableModels.length > 1 && (
<Dropdown
label="Base Model"
value={resolvedModel}
options={availableModels.map((m) => ({ key: m, label: modelDisplayName(m) }))}
onChange={setSelectedModel}
/>
)}
<Dropdown
label="Task Family"
value={selectedFamily}
options={familyOptions}
onChange={setSelectedFamily}
/>
<Dropdown
label="Task"
value={selectedTask}
options={taskOptions}
onChange={setSelectedTask}
/>
<Dropdown
label="Split"
value={selectedSplit}
options={availableSplits.map((s) => ({ key: s, label: s }))}
onChange={setSelectedSplit}
/>
<Dropdown
label="Eval TC Type"
value={selectedTCType}
options={TC_TYPES}
onChange={setSelectedTCType}
/>
{showDomainFilter && (
<Dropdown
label="Domain"
value={selectedDomain}
options={[
{ key: "all" as DomainFilter, label: "All" },
{ key: "ood" as DomainFilter, label: "Out-of-domain" },
{ key: "in-domain" as DomainFilter, label: "In-domain" },
]}
onChange={setSelectedDomain}
/>
)}
<Dropdown
label="Comparison"
value={comparisonPreset}
options={COMPARISON_PRESETS}
onChange={setComparisonPreset}
/>
{/* Row visibility */}
{allRowLabels.length > 0 && (
<div className="flex flex-col gap-1">
<div className="flex items-center justify-between">
<label className="text-[10px] uppercase tracking-wider text-gray-500 font-medium">Visible Models</label>
<div className="flex gap-1">
<button
onClick={() => setVisibleRows(new Set(allRowLabels))}
className="text-[9px] text-gray-500 hover:text-gray-300 px-1"
>
all
</button>
<button
onClick={() => setVisibleRows(null)}
className="text-[9px] text-gray-500 hover:text-gray-300 px-1"
>
default
</button>
</div>
</div>
<div className="flex flex-col gap-0.5 max-h-48 overflow-y-auto">
{allRowLabels.map((label) => (
<label key={label} className="flex items-center gap-1.5 text-xs cursor-pointer hover:bg-gray-800/50 px-1 py-0.5 rounded">
<input
type="checkbox"
checked={effectiveVisibleRows.has(label)}
onChange={() => {
const next = new Set(effectiveVisibleRows);
if (next.has(label)) next.delete(label);
else next.add(label);
setVisibleRows(next);
}}
className="rounded border-gray-600 bg-gray-800 text-cyan-500 focus:ring-0 focus:ring-offset-0 h-3 w-3"
/>
<span className={effectiveVisibleRows.has(label) ? "text-gray-200" : "text-gray-500"}><DisplayRowLabel label={label} /></span>
</label>
))}
</div>
</div>
)}
{/* Stats */}
<div className="pt-2 border-t border-gray-800 space-y-1">
<p className="text-[10px] uppercase tracking-wider text-gray-500 font-medium">Filtered data</p>
<p className="text-xs text-gray-400">{stats.tasks} tasks, {stats.models} configs</p>
<p className="text-xs text-gray-400">{stats.rows.toLocaleString()} rows</p>
{selectedTask === "__all__" && stats.tasks > 1 && (
<p className="text-[10px] text-gray-600 italic">Averaging across {stats.tasks} tasks</p>
)}
</div>
</div>
{/* Visualization area */}
<div className="flex-1 overflow-auto p-6 space-y-8">
{validationError ? (
<div className="flex items-center justify-center h-full">
<div className="max-w-xl bg-red-950 border-2 border-red-500 rounded-lg p-6 space-y-3">
<div className="flex items-center gap-2">
<span className="text-2xl">!!!</span>
<h2 className="text-lg font-bold text-red-300">Data Integrity Violation</h2>
</div>
<pre className="text-sm text-red-200 whitespace-pre-wrap font-mono leading-relaxed">{validationError}</pre>
<p className="text-xs text-red-400 mt-4">
All visualizations are blocked until this is resolved. If you see this, the data pipeline has a bug
— different models or training configs are collapsing into the same row label.
</p>
</div>
</div>
) : filteredRows.length === 0 ? (
<div className="flex items-center justify-center h-full">
<p className="text-sm text-gray-500 italic">No data matches the current filters.</p>
</div>
) : viewMode === "heatmap" ? (
<>
{/* Aggregated sections (domain-split or single aggregate) — hidden when comparison preset is active */}
{(comparisonPreset === "all" || comparisonPreset === "training-effect") && heatmapSections.map((section, sIdx) => (
<div key={`agg-${sIdx}`} className="space-y-6">
{section.label && (
<h2 className="text-sm font-semibold text-gray-200 border-b border-gray-700 pb-2">
{section.label}
{section.taskCount > 1 && (
<span className="text-gray-500 font-normal ml-2">(mean over {section.taskCount} tasks)</span>
)}
</h2>
)}
<div className="flex flex-wrap gap-6">
{section.metrics.map(({ metric: m, data }) => {
const metricLabel = METRICS.find((x) => x.key === m)?.label ?? m;
const metricRowLabels = sortRowLabels(Array.from(data.keys()).filter((l) => effectiveVisibleRows.has(l)));
const metricEvs = EVAL_VARIANTS.filter((ev) => {
for (const evMap of data.values()) {
if (evMap.has(ev)) return true;
}
return false;
});
if (metricRowLabels.length === 0) return null;
return (
<HeatmapGrid
key={m}
data={data}
rowLabels={metricRowLabels}
colLabels={metricEvs}
metric={m}
title={metricLabel}
fullConfigs={fullConfigs}
/>
);
})}
</div>
</div>
))}
{/* Comparison: + TC-Self (side-by-side + delta) */}
{comparisonPreset === "plus-tcself" && tcSelfPairs.length === 0 && (
<p className="text-sm text-gray-500 italic">No matched pairs found for + TC-Self comparison. Check that both non-TC and tcself models exist for this family.</p>
)}
{comparisonPreset === "plus-tcself" && tcSelfPairs.length > 0 && METRICS.map((m) => {
const leftData = aggByTrainingTC(null, m.key);
const rightData = aggByTrainingTC("tcself", m.key);
const existingPairs = tcSelfPairs.filter(([l, r]) => leftData.has(l) && rightData.has(r));
if (existingPairs.length === 0) return null;
const evs = EVAL_VARIANTS.filter((ev) => {
for (const evMap of [...leftData.values(), ...rightData.values()]) {
if (evMap.has(ev)) return true;
}
return false;
});
return (
<SideBySideDelta
key={`tcself-${m.key}`}
pairs={existingPairs}
leftData={leftData}
rightData={rightData}
colLabels={evs}
metric={m.key}
title={m.label}
leftLabel="Without TC"
rightLabel="+ TC-Self"
fullConfigs={fullConfigs}
/>
);
})}
{/* Comparison: TC-Self vs TC-Neg (side-by-side + delta) */}
{comparisonPreset === "tcself-vs-tcneg" && tcSelfVsNegPairs.length === 0 && (
<p className="text-sm text-gray-500 italic">No matched pairs found for TC-Self vs TC-Neg comparison. Check that both tcself and tcneg models exist for this family.</p>
)}
{comparisonPreset === "tcself-vs-tcneg" && tcSelfVsNegPairs.length > 0 && METRICS.map((m) => {
const leftData = aggByMatchedTC("self", m.key);
const rightData = aggByMatchedTC("neg", m.key);
const existingPairs = tcSelfVsNegPairs.filter(([l, r]) => leftData.has(l) && rightData.has(r));
if (existingPairs.length === 0) return null;
const evs = EVAL_VARIANTS.filter((ev) => {
for (const evMap of [...leftData.values(), ...rightData.values()]) {
if (evMap.has(ev)) return true;
}
return false;
});
return (
<SideBySideDelta
key={`tcneg-${m.key}`}
pairs={existingPairs}
leftData={leftData}
rightData={rightData}
colLabels={evs}
metric={m.key}
title={m.label}
leftLabel="TC-Self"
rightLabel="TC-Neg"
fullConfigs={fullConfigs}
/>
);
})}
{/* Comparison: TC-Self vs TC-GPT2 (side-by-side + delta) */}
{comparisonPreset === "tcself-vs-tcgpt2" && tcSelfVsGpt2Pairs.length === 0 && (
<p className="text-sm text-gray-500 italic">No matched pairs found for TC-Self vs TC-GPT2 comparison.</p>
)}
{comparisonPreset === "tcself-vs-tcgpt2" && tcSelfVsGpt2Pairs.length > 0 && METRICS.map((m) => {
const leftData = aggByMatchedTC("self", m.key);
const rightData = aggByMatchedTC("gpt2", m.key);
const existingPairs = tcSelfVsGpt2Pairs.filter(([l, r]) => leftData.has(l) && rightData.has(r));
if (existingPairs.length === 0) return null;
const evs = EVAL_VARIANTS.filter((ev) => {
for (const evMap of [...leftData.values(), ...rightData.values()]) {
if (evMap.has(ev)) return true;
}
return false;
});
return (
<SideBySideDelta
key={`tcgpt2-${m.key}`}
pairs={existingPairs}
leftData={leftData}
rightData={rightData}
colLabels={evs}
metric={m.key}
title={m.label}
leftLabel="TC-Self"
rightLabel="TC-GPT2"
fullConfigs={fullConfigs}
/>
);
})}
{/* Per-task breakdown — regular heatmap for all/training-effect */}
{(comparisonPreset === "all" || comparisonPreset === "training-effect") && perTaskSections.length > 0 && (
<PerTaskCollapsible sections={perTaskSections} fullConfigs={fullConfigs} visibleRows={effectiveVisibleRows} />
)}
{/* Per-task breakdown — side-by-side + delta for + TC-Self */}
{comparisonPreset === "plus-tcself" && selectedTask === "__all__" && (
<PerTaskComparisonCollapsible
filteredRows={filteredRows}
pairs={tcSelfPairs}
leftLabel="Without TC"
rightLabel="+ TC-Self"
leftFilter={(r) => {
const label = getRowLabel(r);
return !TC_FLAGS.some((f) => hasTCFlag(label, f));
}}
rightFilter={(r) => hasTCFlag(getRowLabel(r), "tcself")}
fullConfigs={fullConfigs}
/>
)}
{/* Per-task breakdown — side-by-side + delta for TC-Self vs TC-Neg */}
{comparisonPreset === "tcself-vs-tcneg" && selectedTask === "__all__" && (
<PerTaskComparisonCollapsible
filteredRows={comparisonBaseRows}
pairs={tcSelfVsNegPairs}
leftLabel="TC-Self (eval)"
rightLabel="TC-Neg (eval)"
leftFilter={(r) => getEvalTCType(r) === "self"}
rightFilter={(r) => getEvalTCType(r) === "neg"}
fullConfigs={fullConfigs}
/>
)}
{/* Per-task breakdown — side-by-side + delta for TC-Self vs TC-GPT2 */}
{comparisonPreset === "tcself-vs-tcgpt2" && selectedTask === "__all__" && (
<PerTaskComparisonCollapsible
filteredRows={comparisonBaseRows}
pairs={tcSelfVsGpt2Pairs}
leftLabel="TC-Self (eval)"
rightLabel="TC-GPT2 (eval)"
leftFilter={(r) => getEvalTCType(r) === "self"}
rightFilter={(r) => getEvalTCType(r) === "gpt2"}
fullConfigs={fullConfigs}
/>
)}
</>
) : (
<div className="space-y-8">
{METRICS.map((m) => {
const metricAgg = aggregateData(filteredRows, m.key, getRowLabel).data;
const metricRowLabels = sortRowLabels(Array.from(metricAgg.keys()).filter((l) => effectiveVisibleRows.has(l)));
if (metricRowLabels.length === 0) return null;
return (
<BarChart
key={m.key}
data={metricAgg}
rowLabels={metricRowLabels}
evalVariants={availableEvalVariants.filter((ev) => ev === "raw" || ev === "tc")}
metric={m.key}
title={`${m.label}${selectedTask === "__all__" && stats.tasks > 1 ? ` (mean ± SE over ${stats.tasks} tasks)` : ""}`}
/>
);
})}
</div>
)}
</div>
</div>
)}
</div>
);
}