""" aggregate_results.py ==================== Offline aggregator for Phase 1 ablation results. Pulls all final_report.json files from HF, assembles the band-assignment matrix, and produces: - band_matrix.csv: rows=variants, cols=bands observed - anomalies.csv: configs where predicted_band != expected_band - group_summaries.csv: per-group success/failure rates - uniformity_diagnostic.csv: observed vs uniform sphere CV per config Run locally after Phase 1 completes (or during, to check progress). """ import os import json import csv from pathlib import Path from typing import Dict, List, Any from collections import defaultdict from huggingface_hub import HfApi, hf_hub_download, list_repo_files HF_REPO = "AbstractPhil/geolip-svae-ablations" HF_TOKEN = os.environ.get("HF_TOKEN") OUTPUT_DIR = Path("./aggregate_output") OUTPUT_DIR.mkdir(exist_ok=True) hf_api = HfApi(token=HF_TOKEN) def fetch_all_reports() -> List[Dict[str, Any]]: """Pull every final_report.json from the HF repo.""" reports = [] files = list_repo_files(repo_id=HF_REPO, token=HF_TOKEN) report_files = [f for f in files if f.endswith('final_report.json')] print(f"Found {len(report_files)} reports on HF") for rpath in report_files: try: local_path = hf_hub_download( repo_id=HF_REPO, filename=rpath, token=HF_TOKEN, ) with open(local_path) as f: reports.append(json.load(f)) except Exception as e: print(f"Could not fetch {rpath}: {e}") return reports def write_band_matrix(reports: List[Dict[str, Any]]) -> None: """Write the main band-assignment matrix. Rows: (group, variant, band_expected) Cols: observed_band_counts for each possible band """ matrix = defaultdict(lambda: defaultdict(int)) for r in reports: key = (r['config']['group'], r['config']['variant'], r['expected_band']) matrix[key][r['predicted_band']] += 1 output_path = OUTPUT_DIR / "band_matrix.csv" with open(output_path, 'w', newline='') as f: w = csv.writer(f) w.writerow(['group', 'variant', 'expected_band', 'n_LOW', 'n_MID', 'n_HIGH', 'n_UNCLASSIFIED', 'total', 'match_rate']) for (group, variant, expected), counts in sorted(matrix.items()): total = sum(counts.values()) match = counts.get(expected, 0) / total if total else 0.0 w.writerow([group, variant, expected, counts.get('LOW', 0), counts.get('MID', 0), counts.get('HIGH', 0), counts.get('UNCLASSIFIED', 0), total, f"{match:.2f}"]) print(f"Wrote {output_path}") def write_anomalies(reports: List[Dict[str, Any]]) -> None: """Write configs where predicted_band != expected_band. These are the interesting ones — either a bug, or a real finding (an ablation broke the band structure). """ anomalies = [r for r in reports if not r.get('band_match', True)] output_path = OUTPUT_DIR / "anomalies.csv" with open(output_path, 'w', newline='') as f: w = csv.writer(f) w.writerow(['description', 'expected_band', 'predicted_band', 'cv_ema_final', 'observed_sphere_cv', 'uniform_sphere_cv_prediction', 'band_deviation', 'test_mse']) for r in anomalies: w.writerow([ r['config']['description'], r['expected_band'], r['predicted_band'], f"{r['cv_ema_final']:.4f}", f"{r['observed_sphere_cv']:.4f}", f"{r['uniform_sphere_cv_prediction']:.4f}", f"{r['band_deviation']:.4f}", f"{r['test_mse']:.6f}", ]) print(f"Wrote {output_path} ({len(anomalies)} anomalies)") def write_group_summaries(reports: List[Dict[str, Any]]) -> None: """Per-group success rates.""" groups = defaultdict(lambda: {'match': 0, 'total': 0, 'failed': 0}) for r in reports: g = r['config']['group'] groups[g]['total'] += 1 if r.get('band_match', True): groups[g]['match'] += 1 else: groups[g]['failed'] += 1 output_path = OUTPUT_DIR / "group_summaries.csv" with open(output_path, 'w', newline='') as f: w = csv.writer(f) w.writerow(['group', 'total', 'match', 'failed', 'match_rate']) for g in sorted(groups.keys()): s = groups[g] rate = s['match'] / s['total'] if s['total'] else 0.0 w.writerow([g, s['total'], s['match'], s['failed'], f"{rate:.2f}"]) print(f"Wrote {output_path}") def write_uniformity_diagnostic(reports: List[Dict[str, Any]]) -> None: """Group N — observed vs uniform sphere CV per config. Identifies which configs deviate significantly from uniform-sphere prediction. Large positive deviation = model NOT reaching uniform attractor despite being in-band. """ output_path = OUTPUT_DIR / "uniformity_diagnostic.csv" with open(output_path, 'w', newline='') as f: w = csv.writer(f) w.writerow(['description', 'expected_band', 'observed_sphere_cv', 'uniform_sphere_cv_prediction', 'band_deviation', 'cv_ema_final']) for r in sorted(reports, key=lambda x: abs(x.get('band_deviation', 0)), reverse=True): w.writerow([ r['config']['description'], r['expected_band'], f"{r['observed_sphere_cv']:.4f}", f"{r['uniform_sphere_cv_prediction']:.4f}", f"{r['band_deviation']:+.4f}", f"{r['cv_ema_final']:.4f}", ]) print(f"Wrote {output_path}") def print_summary(reports: List[Dict[str, Any]]) -> None: """Quick text summary for stdout.""" print(f"\n{'='*70}") print(f"ABLATION AGGREGATE SUMMARY — {len(reports)} reports") print(f"{'='*70}") # Expected total from matrix expected_total = len(get_phase1_configs()) print(f"Expected Phase 1 configs: {expected_total}") print(f"Completed so far: {len(reports)} ({100*len(reports)/expected_total:.1f}%)") # Band preservation rate matches = sum(1 for r in reports if r.get('band_match', True)) print(f"\nBand preservation rate: {matches}/{len(reports)} ({100*matches/len(reports):.1f}%)") # Per-group breakdown print(f"\nPer-group preservation:") groups = defaultdict(lambda: [0, 0]) for r in reports: g = r['config']['group'] groups[g][1] += 1 if r.get('band_match', True): groups[g][0] += 1 for g in sorted(groups.keys()): m, t = groups[g] print(f" {g:12s} {m}/{t} ({100*m/t:.0f}%)") if __name__ == '__main__': reports = fetch_all_reports() if not reports: print("No reports found on HF — did Phase 1 start yet?") else: print_summary(reports) write_band_matrix(reports) write_anomalies(reports) write_group_summaries(reports) write_uniformity_diagnostic(reports) print(f"\nAll outputs in {OUTPUT_DIR}/")