from __future__ import annotations import argparse from datetime import datetime import json from pathlib import Path import numpy as np import pandas as pd from sklearn.metrics import confusion_matrix from .ijepa_localization import IJepaPatchLocalizer from .obstacle_dataset import DEFAULT_OBSTACLE_DATASET, load_balanced_obstacle_rows, load_obstacle_image, parse_yolo_boxes from .prototypes import build_class_prototypes, guess_objects_with_prototypes from .small_head import guess_objects_with_head, train_small_head def run_bulk_eval( dataset_name: str, split: str, model_name: str, eval_samples: int, support_samples: int, seed: int, output: str, ) -> Path: run_id = datetime.now().strftime("%Y%m%d_%H%M%S") output_path = resolve_output_path(output, run_id) localizer = IJepaPatchLocalizer(model_name=model_name) support_rows = load_balanced_obstacle_rows( dataset_name, split, support_samples, random_seed=seed + 10_000, ) support_files = {row["file_name"] for row in support_rows} eval_rows = load_disjoint_eval_rows(dataset_name, split, eval_samples, seed, support_files) prototypes = build_class_prototypes(dataset_name, split, support_rows, localizer) head = train_small_head(dataset_name, split, support_rows, localizer) records = [] for sample_index, row in enumerate(eval_rows): image = load_obstacle_image(dataset_name, row, split) boxes = parse_yolo_boxes(row) prototype_guesses = guess_objects_with_prototypes(image, boxes, localizer, prototypes) head_guesses = guess_objects_with_head(image, boxes, localizer, head) head_by_object = {guess.object_index: guess for guess in head_guesses} prototype_by_object = {guess.object_index: guess for guess in prototype_guesses} for object_index, box in enumerate(boxes, start=1): prototype_guess = prototype_by_object.get(object_index) head_guess = head_by_object.get(object_index) records.append( { "sample": sample_index, "run_id": run_id, "support_samples": support_samples, "eval_samples": eval_samples, "seed": seed, "file_name": row["file_name"], "object": object_index, "yolo_label": box.class_name, "prototype_guess": prototype_guess.ijepa_guess if prototype_guess else None, "prototype_similarity": prototype_guess.similarity if prototype_guess else None, "prototype_agreement": prototype_guess.agreement if prototype_guess else None, "head_guess": head_guess.head_guess if head_guess else None, "head_confidence": head_guess.confidence if head_guess else None, "head_agreement": head_guess.agreement if head_guess else None, "head_train_objects": head.train_objects if head else None, "head_train_accuracy": head.train_accuracy if head else None, "head_parameter_count": head.parameter_count if head else None, "prototype_classes": len(prototypes), } ) output_path.parent.mkdir(parents=True, exist_ok=True) df = pd.DataFrame(records) df.to_csv(output_path, index=False) write_summaries(df, output_path) return output_path def load_disjoint_eval_rows(dataset_name: str, split: str, eval_samples: int, seed: int, support_files: set[str]): candidates = load_balanced_obstacle_rows( dataset_name, split, eval_samples + len(support_files) + 100, random_seed=seed, ) selected = [] for row in candidates: if row["file_name"] in support_files: continue selected.append(row) if len(selected) >= eval_samples: break if len(selected) < eval_samples: raise ValueError(f"Could only find {len(selected)} disjoint eval rows; requested {eval_samples}.") return selected def write_summaries(df: pd.DataFrame, output_path: Path) -> None: if df.empty: return summary = { "run_id": str(df["run_id"].iloc[0]) if "run_id" in df else None, "support_samples": first_int(df, "support_samples"), "eval_samples": first_int(df, "eval_samples"), "seed": first_int(df, "seed"), "objects": int(len(df)), "classes": int(df["yolo_label"].nunique()), "prototype_accuracy": safe_accuracy(df, "prototype_agreement"), "head_accuracy": safe_accuracy(df, "head_agreement"), "prototype_macro_accuracy": macro_accuracy(df, "prototype_agreement"), "head_macro_accuracy": macro_accuracy(df, "head_agreement"), "head_train_objects": first_int(df, "head_train_objects"), "head_parameter_count": first_int(df, "head_parameter_count"), "prototype_classes": first_int(df, "prototype_classes"), } summary_path = sibling_path(output_path, "_summary.csv") per_class_path = sibling_path(output_path, "_per_class.csv") prototype_confusion_path = sibling_path(output_path, "_prototype_confusion.csv") head_confusion_path = sibling_path(output_path, "_head_confusion.csv") report_path = sibling_path(output_path, "_report.json") pd.DataFrame([summary]).to_csv(summary_path, index=False) per_class = ( df.groupby("yolo_label") .agg( objects=("yolo_label", "size"), prototype_accuracy=("prototype_agreement", safe_mean), head_accuracy=("head_agreement", safe_mean), avg_prototype_similarity=("prototype_similarity", "mean"), avg_head_confidence=("head_confidence", "mean"), ) .reset_index() .sort_values(["head_accuracy", "objects"], ascending=[True, False]) ) per_class.to_csv(per_class_path, index=False) prototype_confusion = confusion_frame(df["yolo_label"], df["prototype_guess"]) head_confusion = confusion_frame(df["yolo_label"], df["head_guess"]) prototype_confusion.to_csv(prototype_confusion_path) head_confusion.to_csv(head_confusion_path) report = { "summary": summary, "per_class": per_class.replace({np.nan: None}).to_dict(orient="records"), "top_prototype_confusions": top_confusions(df, "prototype_guess"), "top_head_confusions": top_confusions(df, "head_guess"), "files": { "objects": str(output_path), "summary": str(summary_path), "per_class": str(per_class_path), "prototype_confusion": str(prototype_confusion_path), "head_confusion": str(head_confusion_path), }, } report_path.write_text(json.dumps(report, indent=2), encoding="utf-8") def sibling_path(output_path: Path, suffix: str) -> Path: names = { "_summary.csv": "summary.csv", "_per_class.csv": "per_class.csv", "_prototype_confusion.csv": "prototype_confusion.csv", "_head_confusion.csv": "head_confusion.csv", "_report.json": "report.json", } return output_path.with_name(names.get(suffix, f"{output_path.stem}{suffix}")) def resolve_output_path(output: str, run_id: str) -> Path: path = Path(output) if output == "auto": return Path("outputs") / f"run_{run_id}" / "objects.csv" if "{timestamp}" in output or "{run_id}" in output: resolved = Path(output.format(timestamp=run_id, run_id=run_id)) if resolved.suffix: return resolved return resolved / "objects.csv" if path.suffix == "": return path / "objects.csv" if path.exists(): return path.with_name(f"{path.stem}_{run_id}{path.suffix}") return path def safe_mean(values) -> float: clean = pd.Series(values).dropna() if clean.empty: return np.nan return float(clean.astype(bool).mean()) def safe_accuracy(df: pd.DataFrame, column: str) -> float | None: clean = df[column].dropna() if clean.empty: return None return float(clean.astype(bool).mean()) def macro_accuracy(df: pd.DataFrame, column: str) -> float | None: clean = df.dropna(subset=[column]) if clean.empty: return None return float(clean.groupby("yolo_label")[column].apply(safe_mean).mean()) def confusion_frame(y_true, y_pred) -> pd.DataFrame: clean = pd.DataFrame({"true": y_true, "pred": y_pred}).dropna() labels = sorted(set(clean["true"]).union(set(clean["pred"]))) matrix = confusion_matrix(clean["true"], clean["pred"], labels=labels) return pd.DataFrame(matrix, index=labels, columns=labels) def top_confusions(df: pd.DataFrame, prediction_column: str, limit: int = 10) -> list[dict]: clean = df.dropna(subset=[prediction_column]) wrong = clean[clean["yolo_label"] != clean[prediction_column]] if wrong.empty: return [] counts = ( wrong.groupby(["yolo_label", prediction_column]) .size() .reset_index(name="count") .sort_values("count", ascending=False) .head(limit) ) return [ {"yolo_label": row["yolo_label"], "predicted": row[prediction_column], "count": int(row["count"])} for _, row in counts.iterrows() ] def first_int(df: pd.DataFrame, column: str) -> int | None: clean = df[column].dropna() if clean.empty: return None return int(clean.iloc[0]) def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Run bulk I-JEPA prototype/head evaluation.") parser.add_argument("--dataset-name", default=DEFAULT_OBSTACLE_DATASET) parser.add_argument("--split", default="train") parser.add_argument("--model-name", default="facebook/ijepa_vith14_1k") parser.add_argument("--eval-samples", type=int, default=50) parser.add_argument( "--support-samples", "--reference-samples", dest="support_samples", type=int, default=80, help="Images used to build class prototypes and train the tiny classifier.", ) parser.add_argument("--seed", type=int, default=7) parser.add_argument( "--output", default="auto", help="Output CSV path or run directory. Supports {timestamp} or {run_id}. Default: outputs/run_/objects.csv", ) return parser def main() -> None: args = build_parser().parse_args() output = run_bulk_eval(**vars(args)) print(f"Saved bulk evaluation: {output}") for suffix in [ "_summary.csv", "_per_class.csv", "_prototype_confusion.csv", "_head_confusion.csv", "_report.json", ]: print(f"Saved: {sibling_path(output, suffix)}") if __name__ == "__main__": main()