# JEPA-demo: src/bulk_eval.py
# Uploaded by ddebree via huggingface_hub ("Upload folder using huggingface_hub"), commit 2bc3168 (verified).
from __future__ import annotations
import argparse
from datetime import datetime
import json
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from .ijepa_localization import IJepaPatchLocalizer
from .obstacle_dataset import DEFAULT_OBSTACLE_DATASET, load_balanced_obstacle_rows, load_obstacle_image, parse_yolo_boxes
from .prototypes import build_class_prototypes, guess_objects_with_prototypes
from .small_head import guess_objects_with_head, train_small_head
def run_bulk_eval(
    dataset_name: str,
    split: str,
    model_name: str,
    eval_samples: int,
    support_samples: int,
    seed: int,
    output: str,
) -> Path:
    """Score prototype and small-head I-JEPA guesses against YOLO labels in bulk.

    ``support_samples`` rows (drawn with ``random_seed=seed + 10_000``) are used
    to build class prototypes and to train the small head. ``eval_samples``
    further rows, none sharing a ``file_name`` with the support rows, are then
    evaluated object by object. One CSV row is written per YOLO box, followed
    by the summary artifacts produced by ``write_summaries``.

    Args:
        dataset_name: Dataset identifier passed to the obstacle-dataset loaders.
        split: Dataset split used for both support and evaluation rows.
        model_name: Model identifier passed to ``IJepaPatchLocalizer``.
        eval_samples: Number of evaluation images requested.
        support_samples: Number of support images requested.
        seed: Base random seed for row sampling.
        output: Output spec resolved by ``resolve_output_path`` ("auto", a
            directory, a CSV path, or a ``{run_id}``/``{timestamp}`` template).

    Returns:
        Path of the per-object CSV that was written.

    Raises:
        ValueError: From ``load_disjoint_eval_rows`` when not enough disjoint
            evaluation rows are available.
    """
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = resolve_output_path(output, run_id)
    localizer = IJepaPatchLocalizer(model_name=model_name)

    support_rows = load_balanced_obstacle_rows(
        dataset_name,
        split,
        support_samples,
        random_seed=seed + 10_000,
    )
    support_files = {row["file_name"] for row in support_rows}
    eval_rows = load_disjoint_eval_rows(dataset_name, split, eval_samples, seed, support_files)

    prototypes = build_class_prototypes(dataset_name, split, support_rows, localizer)
    head = train_small_head(dataset_name, split, support_rows, localizer)

    # Columns that are identical for every object in the run: compute them once
    # instead of once per object. Dict order matches the CSV column order.
    run_columns = {
        "run_id": run_id,
        "support_samples": support_samples,
        "eval_samples": eval_samples,
        "seed": seed,
    }
    has_head = head is not None
    model_columns = {
        "head_train_objects": head.train_objects if has_head else None,
        "head_train_accuracy": head.train_accuracy if has_head else None,
        "head_parameter_count": head.parameter_count if has_head else None,
        "prototype_classes": len(prototypes),
    }

    records = []
    for sample_index, row in enumerate(eval_rows):
        image = load_obstacle_image(dataset_name, row, split)
        boxes = parse_yolo_boxes(row)
        prototype_guesses = guess_objects_with_prototypes(image, boxes, localizer, prototypes)
        head_guesses = guess_objects_with_head(image, boxes, localizer, head)
        prototype_by_object = {guess.object_index: guess for guess in prototype_guesses}
        head_by_object = {guess.object_index: guess for guess in head_guesses}
        # Guesses are keyed by a 1-based object index, matching enumerate(start=1).
        for object_index, box in enumerate(boxes, start=1):
            prototype_guess = prototype_by_object.get(object_index)
            head_guess = head_by_object.get(object_index)
            # Explicit None checks: a missing guess is None; any returned
            # guess object counts as present, whatever its truthiness.
            has_prototype_guess = prototype_guess is not None
            has_head_guess = head_guess is not None
            records.append(
                {
                    "sample": sample_index,
                    **run_columns,
                    "file_name": row["file_name"],
                    "object": object_index,
                    "yolo_label": box.class_name,
                    "prototype_guess": prototype_guess.ijepa_guess if has_prototype_guess else None,
                    "prototype_similarity": prototype_guess.similarity if has_prototype_guess else None,
                    "prototype_agreement": prototype_guess.agreement if has_prototype_guess else None,
                    "head_guess": head_guess.head_guess if has_head_guess else None,
                    "head_confidence": head_guess.confidence if has_head_guess else None,
                    "head_agreement": head_guess.agreement if has_head_guess else None,
                    **model_columns,
                }
            )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(records)
    df.to_csv(output_path, index=False)
    write_summaries(df, output_path)
    return output_path
def load_disjoint_eval_rows(
    dataset_name: str,
    split: str,
    eval_samples: int,
    seed: int,
    support_files: set[str],
) -> list:
    """Select ``eval_samples`` rows whose ``file_name`` is not in ``support_files``.

    Over-samples candidates from ``load_balanced_obstacle_rows`` (requesting
    ``eval_samples + len(support_files) + 100`` rows with ``random_seed=seed``)
    so that enough rows remain after discarding support-set images.

    Args:
        dataset_name: Dataset identifier passed to the row loader.
        split: Dataset split to sample from.
        eval_samples: Number of disjoint rows to return.
        seed: Random seed for candidate sampling.
        support_files: File names already used by the support set.

    Returns:
        A list of exactly ``eval_samples`` rows, in candidate order.

    Raises:
        ValueError: If ``eval_samples`` is negative, or fewer than
            ``eval_samples`` disjoint rows are available.
    """
    if eval_samples < 0:
        raise ValueError(f"eval_samples must be non-negative; got {eval_samples}.")
    if eval_samples == 0:
        # Nothing requested. Returning here also avoids loading candidates at all.
        return []
    candidates = load_balanced_obstacle_rows(
        dataset_name,
        split,
        eval_samples + len(support_files) + 100,
        random_seed=seed,
    )
    selected = []
    for row in candidates:
        if row["file_name"] in support_files:
            continue
        selected.append(row)
        if len(selected) == eval_samples:
            break
    if len(selected) < eval_samples:
        raise ValueError(f"Could only find {len(selected)} disjoint eval rows; requested {eval_samples}.")
    return selected
def write_summaries(df: pd.DataFrame, output_path: Path) -> None:
    """Write aggregate reports derived from the per-object results in ``df``.

    Artifacts (paths come from ``sibling_path(output_path, suffix)``):
      * ``_summary.csv``: one row of run-level counts and accuracies;
      * ``_per_class.csv``: per YOLO label object count, prototype/head
        accuracy and mean similarity/confidence, sorted by head accuracy
        ascending and then by object count descending;
      * ``_prototype_confusion.csv`` / ``_head_confusion.csv``: confusion
        matrices of YOLO label vs. guess;
      * ``_report.json``: the summary, per-class table, top confusions and
        the list of CSV files written.

    Nothing is written when ``df`` is empty.
    """
    if df.empty:
        return

    artifacts = {
        "summary": sibling_path(output_path, "_summary.csv"),
        "per_class": sibling_path(output_path, "_per_class.csv"),
        "prototype_confusion": sibling_path(output_path, "_prototype_confusion.csv"),
        "head_confusion": sibling_path(output_path, "_head_confusion.csv"),
        "report": sibling_path(output_path, "_report.json"),
    }

    first_run_id = str(df["run_id"].iloc[0]) if "run_id" in df else None
    summary = {
        "run_id": first_run_id,
        "support_samples": first_int(df, "support_samples"),
        "eval_samples": first_int(df, "eval_samples"),
        "seed": first_int(df, "seed"),
        "objects": int(len(df)),
        "classes": int(df["yolo_label"].nunique()),
        "prototype_accuracy": safe_accuracy(df, "prototype_agreement"),
        "head_accuracy": safe_accuracy(df, "head_agreement"),
        "prototype_macro_accuracy": macro_accuracy(df, "prototype_agreement"),
        "head_macro_accuracy": macro_accuracy(df, "head_agreement"),
        "head_train_objects": first_int(df, "head_train_objects"),
        "head_parameter_count": first_int(df, "head_parameter_count"),
        "prototype_classes": first_int(df, "prototype_classes"),
    }
    pd.DataFrame([summary]).to_csv(artifacts["summary"], index=False)

    grouped = df.groupby("yolo_label").agg(
        objects=("yolo_label", "size"),
        prototype_accuracy=("prototype_agreement", safe_mean),
        head_accuracy=("head_agreement", safe_mean),
        avg_prototype_similarity=("prototype_similarity", "mean"),
        avg_head_confidence=("head_confidence", "mean"),
    )
    per_class = grouped.reset_index().sort_values(["head_accuracy", "objects"], ascending=[True, False])
    per_class.to_csv(artifacts["per_class"], index=False)

    # Build both matrices before writing either one.
    confusions = {
        "prototype_confusion": confusion_frame(df["yolo_label"], df["prototype_guess"]),
        "head_confusion": confusion_frame(df["yolo_label"], df["head_guess"]),
    }
    for artifact_name, matrix in confusions.items():
        matrix.to_csv(artifacts[artifact_name])

    # The report lists the per-object CSV and every CSV artifact, but not itself.
    written_files = {"objects": str(output_path)}
    written_files.update(
        (artifact_name, str(artifact_path))
        for artifact_name, artifact_path in artifacts.items()
        if artifact_name != "report"
    )
    report = {
        "summary": summary,
        "per_class": per_class.replace({np.nan: None}).to_dict(orient="records"),
        "top_prototype_confusions": top_confusions(df, "prototype_guess"),
        "top_head_confusions": top_confusions(df, "head_guess"),
        "files": written_files,
    }
    artifacts["report"].write_text(json.dumps(report, indent=2), encoding="utf-8")
def sibling_path(output_path: Path, suffix: str) -> Path:
    """Return the path of a summary artifact written alongside ``output_path``.

    For the run-directory layout (the per-object file is named ``objects.csv``,
    as ``resolve_output_path`` produces for "auto" and directory outputs),
    known suffixes map to short fixed names such as ``summary.csv`` in that
    directory. For any other output file name, and for unknown suffixes, the
    artifact is named ``<stem><suffix>``. This keeps differently named outputs
    that share a directory, such as a renamed ``eval_<run_id>.csv`` next to an
    earlier ``eval.csv``, from overwriting each other's summaries.

    Args:
        output_path: Path of the per-object CSV.
        suffix: Artifact suffix, e.g. ``"_summary.csv"``.

    Returns:
        The artifact path, in the same directory as ``output_path``.
    """
    run_dir_names = {
        "_summary.csv": "summary.csv",
        "_per_class.csv": "per_class.csv",
        "_prototype_confusion.csv": "prototype_confusion.csv",
        "_head_confusion.csv": "head_confusion.csv",
        "_report.json": "report.json",
    }
    if output_path.name == "objects.csv" and suffix in run_dir_names:
        return output_path.with_name(run_dir_names[suffix])
    return output_path.with_name(f"{output_path.stem}{suffix}")
def resolve_output_path(output: str, run_id: str) -> Path:
    """Turn the ``--output`` value into the concrete per-object CSV path.

    * ``"auto"`` becomes ``outputs/run_<run_id>/objects.csv``.
    * ``{timestamp}`` / ``{run_id}`` placeholders are substituted with
      ``run_id``. A result without a suffix is treated as a directory.
    * Any other path without a suffix is treated as a directory and gets
      ``objects.csv`` appended.
    * A non-templated file path that already exists gets ``_<run_id>`` added
      to its stem, so the earlier file is not overwritten.
    """
    if output == "auto":
        return Path("outputs", f"run_{run_id}", "objects.csv")

    templated = any(token in output for token in ("{timestamp}", "{run_id}"))
    target = Path(output.format(timestamp=run_id, run_id=run_id)) if templated else Path(output)

    if not target.suffix:
        return target / "objects.csv"
    if not templated and target.exists():
        return target.with_name(f"{target.stem}_{run_id}{target.suffix}")
    return target
def safe_mean(values) -> float:
    """Return the fraction of non-missing ``values`` that are truthy, or NaN if none remain."""
    present = pd.Series(values).dropna()
    return np.nan if present.empty else float(present.astype(bool).mean())
def safe_accuracy(df: pd.DataFrame, column: str) -> float | None:
clean = df[column].dropna()
if clean.empty:
return None
return float(clean.astype(bool).mean())
def macro_accuracy(df: pd.DataFrame, column: str) -> float | None:
    """Average the per-``yolo_label`` agreement rate so every class weighs equally.

    Rows missing ``column`` are dropped first. Returns None when none remain.
    """
    scored = df.dropna(subset=[column])
    if scored.empty:
        return None
    per_label_rate = scored.groupby("yolo_label")[column].apply(safe_mean)
    return float(per_label_rate.mean())
def confusion_frame(y_true, y_pred) -> pd.DataFrame:
    """Build a labelled confusion matrix (rows: true label, columns: prediction).

    Pairs where either value is missing are dropped. The label set is the
    sorted union of the remaining true and predicted labels.

    Returns:
        A square DataFrame of counts. It is empty when no complete pairs
        remain, e.g. when every guess for a method is None.
    """
    clean = pd.DataFrame({"true": y_true, "pred": y_pred}).dropna()
    if clean.empty:
        # sklearn's confusion_matrix raises ValueError for an empty label list,
        # which would abort the whole summary step.
        return pd.DataFrame()
    labels = sorted(set(clean["true"]).union(clean["pred"]))
    matrix = confusion_matrix(clean["true"], clean["pred"], labels=labels)
    return pd.DataFrame(matrix, index=labels, columns=labels)
def top_confusions(df: pd.DataFrame, prediction_column: str, limit: int = 10) -> list[dict]:
    """List the most frequent (YOLO label, wrong prediction) pairs.

    Rows without a value in ``prediction_column`` are ignored. Returns at most
    ``limit`` dicts with keys ``yolo_label``, ``predicted`` and ``count``,
    ordered by descending count. Returns an empty list when every remaining
    prediction matches its label.
    """
    predicted = df.dropna(subset=[prediction_column])
    mismatch_mask = predicted["yolo_label"] != predicted[prediction_column]
    mismatched = predicted[mismatch_mask]
    if mismatched.empty:
        return []
    pair_counts = mismatched.groupby(["yolo_label", prediction_column]).size().reset_index(name="count")
    ranked = pair_counts.sort_values("count", ascending=False).head(limit)
    return [
        {"yolo_label": label, "predicted": guess, "count": int(count)}
        for label, guess, count in zip(ranked["yolo_label"], ranked[prediction_column], ranked["count"])
    ]
def first_int(df: pd.DataFrame, column: str) -> int | None:
clean = df[column].dropna()
if clean.empty:
return None
return int(clean.iloc[0])
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the bulk evaluation script.

    The destination names match ``run_bulk_eval``'s keyword arguments, so
    ``vars(parser.parse_args())`` can be passed to it directly.
    ``--reference-samples`` is an alias of ``--support-samples``.
    """
    parser = argparse.ArgumentParser(description="Run bulk I-JEPA prototype/head evaluation.")
    add = parser.add_argument

    add("--dataset-name", default=DEFAULT_OBSTACLE_DATASET)
    add("--split", default="train")
    add("--model-name", default="facebook/ijepa_vith14_1k")
    add("--eval-samples", type=int, default=50)
    add(
        "--support-samples",
        "--reference-samples",
        dest="support_samples",
        type=int,
        default=80,
        help="Images used to build class prototypes and train the tiny classifier.",
    )
    add("--seed", type=int, default=7)
    add(
        "--output",
        default="auto",
        help=(
            "Output CSV path or run directory. Supports {timestamp} or {run_id}. "
            "Default: outputs/run_<timestamp>/objects.csv"
        ),
    )
    return parser
def main() -> None:
    """Command-line entry point: run the evaluation and list the files written."""
    args = build_parser().parse_args()
    output = run_bulk_eval(**vars(args))
    print(f"Saved bulk evaluation: {output}")
    for suffix in (
        "_summary.csv",
        "_per_class.csv",
        "_prototype_confusion.csv",
        "_head_confusion.csv",
        "_report.json",
    ):
        artifact = sibling_path(output, suffix)
        # write_summaries returns without writing anything when no objects
        # were evaluated, so only report artifacts that actually exist.
        if artifact.exists():
            print(f"Saved: {artifact}")


if __name__ == "__main__":
    main()