| """ |
| aggregate_results.py |
| ==================== |
| Offline aggregator for Phase 1 ablation results. |
| |
| Pulls all final_report.json files from HF, assembles the band-assignment |
| matrix, and produces: |
| - band_matrix.csv: rows=variants, cols=bands observed |
| - anomalies.csv: configs where predicted_band != expected_band |
| - group_summaries.csv: per-group success/failure rates |
| - uniformity_diagnostic.csv: observed vs uniform sphere CV per config |
| |
| Run locally after Phase 1 completes (or during, to check progress). |
| """ |
|
|
| import os |
| import json |
| import csv |
| from pathlib import Path |
| from typing import Dict, List, Any |
| from collections import defaultdict |
|
|
| from huggingface_hub import HfApi, hf_hub_download, list_repo_files |
|
|
|
|
# Hugging Face repository that holds the per-run final_report.json files.
HF_REPO = "AbstractPhil/geolip-svae-ablations"

# Access token taken from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Directory that receives every generated CSV. It is created as a side
# effect of importing/running this module.
OUTPUT_DIR = Path("./aggregate_output")
OUTPUT_DIR.mkdir(exist_ok=True)

# Hub API client built with the token above. It is not referenced by the
# functions visible in this file.
hf_api = HfApi(token=HF_TOKEN)
|
|
|
|
def fetch_all_reports() -> List[Dict[str, Any]]:
    """Download and parse every ``final_report.json`` in the HF repo.

    Lists the files of ``HF_REPO``, keeps the paths that end in
    ``final_report.json``, downloads each one with ``hf_hub_download`` and
    parses it as JSON.

    Returns:
        The parsed reports, in the order the repo listing returned them.
        A report that cannot be downloaded or parsed is skipped with a
        message on stdout, so a single bad file does not abort the run.

    Raises:
        Whatever ``list_repo_files`` raises; only the per-file download and
        parse step is guarded.
    """
    files = list_repo_files(repo_id=HF_REPO, token=HF_TOKEN)
    report_files = [f for f in files if f.endswith('final_report.json')]

    print(f"Found {len(report_files)} reports on HF")

    reports: List[Dict[str, Any]] = []
    for rpath in report_files:
        try:
            local_path = hf_hub_download(
                repo_id=HF_REPO,
                filename=rpath,
                token=HF_TOKEN,
            )
            # JSON text is UTF-8; do not depend on the platform's locale
            # encoding, which would mis-decode non-ASCII descriptions.
            with open(local_path, encoding='utf-8') as f:
                reports.append(json.load(f))
        except Exception as e:
            # Deliberate best-effort boundary: report the failure and keep
            # going so the remaining reports are still aggregated.
            print(f"Could not fetch {rpath}: {e}")

    return reports
|
|
|
|
def write_band_matrix(reports: List[Dict[str, Any]]) -> None:
    """Write the band-assignment matrix to ``band_matrix.csv``.

    One row is written per distinct ``(group, variant, expected_band)``
    triple, in sorted order. Each row holds the number of reports whose
    ``predicted_band`` was LOW / MID / HIGH / UNCLASSIFIED, the row total
    and the share of reports whose predicted band equals the expected one.

    Any predicted band outside those four labels gets its own ``n_<band>``
    column appended after ``match_rate``. Without that, such reports were
    included in ``total`` but visible in no column. When only the four
    standard labels occur the file layout is exactly the nine standard
    columns.

    Args:
        reports: Parsed report dicts. Each must provide
            ``config.group``, ``config.variant``, ``expected_band`` and
            ``predicted_band``.

    Raises:
        KeyError: If a report lacks one of the required keys.
    """
    standard_bands = ['LOW', 'MID', 'HIGH', 'UNCLASSIFIED']
    matrix = defaultdict(lambda: defaultdict(int))

    for r in reports:
        key = (r['config']['group'], r['config']['variant'], r['expected_band'])
        matrix[key][r['predicted_band']] += 1

    # Labels that were observed but have no fixed column. Sorted by their
    # string form so the column order is deterministic.
    observed = {band for counts in matrix.values() for band in counts}
    extra_bands = sorted(observed - set(standard_bands), key=str)

    output_path = OUTPUT_DIR / "band_matrix.csv"
    # Explicit UTF-8 keeps the file identical across platforms.
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['group', 'variant', 'expected_band',
                    'n_LOW', 'n_MID', 'n_HIGH', 'n_UNCLASSIFIED',
                    'total', 'match_rate']
                   + [f"n_{band}" for band in extra_bands])
        for key in sorted(matrix):
            group, variant, expected = key
            counts = matrix[key]
            total = sum(counts.values())
            match = counts.get(expected, 0) / total if total else 0.0
            w.writerow([group, variant, expected]
                       + [counts.get(band, 0) for band in standard_bands]
                       + [total, f"{match:.2f}"]
                       + [counts.get(band, 0) for band in extra_bands])

    print(f"Wrote {output_path}")
|
|
|
|
def write_anomalies(reports: List[Dict[str, Any]]) -> None:
    """Write the reports flagged as band mismatches to ``anomalies.csv``.

    A report is an anomaly when its ``band_match`` value is falsy. Reports
    that have no ``band_match`` key at all are treated as matches and are
    therefore not listed.

    Metric cells are left empty when the metric is missing or not a number,
    so one incomplete report cannot abort the whole aggregation.

    Args:
        reports: Parsed report dicts. Anomalous reports must provide
            ``config.description``, ``expected_band`` and
            ``predicted_band``.

    Raises:
        KeyError: If an anomalous report lacks one of the required keys.
    """
    def _cell(report: Dict[str, Any], key: str, spec: str) -> str:
        # Format a numeric metric; blank for missing / null / non-numeric.
        value = report.get(key)
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return ''
        return format(value, spec)

    anomalies = [r for r in reports if not r.get('band_match', True)]

    output_path = OUTPUT_DIR / "anomalies.csv"
    # Descriptions are free text; write UTF-8 explicitly so non-ASCII
    # characters do not fail under a narrow locale encoding.
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['description', 'expected_band', 'predicted_band',
                    'cv_ema_final', 'observed_sphere_cv',
                    'uniform_sphere_cv_prediction', 'band_deviation',
                    'test_mse'])
        for r in anomalies:
            w.writerow([
                r['config']['description'],
                r['expected_band'],
                r['predicted_band'],
                _cell(r, 'cv_ema_final', '.4f'),
                _cell(r, 'observed_sphere_cv', '.4f'),
                _cell(r, 'uniform_sphere_cv_prediction', '.4f'),
                _cell(r, 'band_deviation', '.4f'),
                _cell(r, 'test_mse', '.6f'),
            ])

    print(f"Wrote {output_path} ({len(anomalies)} anomalies)")
|
|
|
|
def write_group_summaries(reports: List[Dict[str, Any]]) -> None:
    """Write per-group match statistics to ``group_summaries.csv``.

    For every ``config.group`` value the file lists how many reports were
    seen, how many count as a band match (``band_match`` truthy, or absent),
    how many do not, and the match share with two decimals. Groups appear
    in sorted order.
    """
    seen = defaultdict(int)
    matched = defaultdict(int)

    for report in reports:
        name = report['config']['group']
        seen[name] += 1
        # Register the group in both counters; a missing flag counts as a match.
        matched[name] += 1 if report.get('band_match', True) else 0

    rows = []
    for name in sorted(seen):
        total = seen[name]
        hits = matched[name]
        rate = hits / total if total else 0.0
        rows.append([name, total, hits, total - hits, f"{rate:.2f}"])

    output_path = OUTPUT_DIR / "group_summaries.csv"
    with open(output_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['group', 'total', 'match', 'failed', 'match_rate'])
        writer.writerows(rows)

    print(f"Wrote {output_path}")
|
|
|
|
def write_uniformity_diagnostic(reports: List[Dict[str, Any]]) -> None:
    """Write observed vs. uniform-sphere CV per config (Group N diagnostic).

    Rows go to ``uniformity_diagnostic.csv`` ordered by the absolute value
    of ``band_deviation``, largest first, so the configs that deviate most
    from the uniform-sphere prediction are at the top. ``band_deviation``
    is written with an explicit sign.

    A missing or non-numeric metric sorts as a deviation of 0 and is
    written as an empty cell. Previously the sort tolerated a missing
    ``band_deviation`` but the row writer then raised ``KeyError`` on it.

    Args:
        reports: Parsed report dicts. Each must provide
            ``config.description`` and ``expected_band``.

    Raises:
        KeyError: If a report lacks one of the required keys.
    """
    def _is_number(value: Any) -> bool:
        # bool is an int subclass but is never a meaningful metric here.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _cell(report: Dict[str, Any], key: str, spec: str) -> str:
        # Format a numeric metric; blank for missing / null / non-numeric.
        value = report.get(key)
        return format(value, spec) if _is_number(value) else ''

    def _deviation_magnitude(report: Dict[str, Any]) -> float:
        # Sort key: |band_deviation|, or 0 when it is unavailable.
        value = report.get('band_deviation')
        return abs(value) if _is_number(value) else 0

    output_path = OUTPUT_DIR / "uniformity_diagnostic.csv"
    # Descriptions are free text; write UTF-8 explicitly so non-ASCII
    # characters do not fail under a narrow locale encoding.
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['description', 'expected_band',
                    'observed_sphere_cv', 'uniform_sphere_cv_prediction',
                    'band_deviation', 'cv_ema_final'])
        for r in sorted(reports, key=_deviation_magnitude, reverse=True):
            w.writerow([
                r['config']['description'],
                r['expected_band'],
                _cell(r, 'observed_sphere_cv', '.4f'),
                _cell(r, 'uniform_sphere_cv_prediction', '.4f'),
                _cell(r, 'band_deviation', '+.4f'),
                _cell(r, 'cv_ema_final', '.4f'),
            ])

    print(f"Wrote {output_path}")
|
|
|
|
def print_summary(reports: List[Dict[str, Any]]) -> None:
    """Print a short text summary of the aggregated reports to stdout.

    The summary contains a banner with the report count, progress against
    the expected number of Phase 1 configs when that number is available,
    the overall band preservation rate, and the preservation rate for each
    ``config.group``.

    A report counts as preserved when its ``band_match`` value is truthy
    or the key is absent.

    Args:
        reports: Parsed report dicts. May be empty, in which case only the
            banner and progress lines are printed.
    """
    n_reports = len(reports)

    print(f"\n{'='*70}")
    print(f"ABLATION AGGREGATE SUMMARY — {n_reports} reports")
    print(f"{'='*70}")

    # get_phase1_configs is not defined or imported by this module, so a
    # direct call raises NameError before any output file is written.
    # Look it up at call time and degrade gracefully when it is absent.
    config_source = globals().get('get_phase1_configs')
    if callable(config_source):
        expected_total = len(config_source())
        print(f"Expected Phase 1 configs: {expected_total}")
        if expected_total:
            print(f"Completed so far: {n_reports} ({100*n_reports/expected_total:.1f}%)")
        else:
            # An empty config list would otherwise divide by zero.
            print(f"Completed so far: {n_reports}")
    else:
        print("Expected Phase 1 configs: unknown (get_phase1_configs is not available)")
        print(f"Completed so far: {n_reports}")

    if not reports:
        # Nothing to rate; the percentages below would divide by zero.
        print("\nBand preservation rate: n/a (no reports)")
        return

    matches = sum(1 for r in reports if r.get('band_match', True))
    print(f"\nBand preservation rate: {matches}/{n_reports} ({100*matches/n_reports:.1f}%)")

    print("\nPer-group preservation:")
    groups = defaultdict(lambda: [0, 0])  # group -> [matched, total]
    for r in reports:
        g = r['config']['group']
        groups[g][1] += 1
        if r.get('band_match', True):
            groups[g][0] += 1
    for g in sorted(groups):
        m, t = groups[g]
        print(f"  {g:12s} {m}/{t} ({100*m/t:.0f}%)")
|
|
|
|
if __name__ == '__main__':
    # Fetch everything first; every output below is derived from this list.
    reports = fetch_all_reports()

    if reports:
        print_summary(reports)
        # Emit the four CSV outputs in a fixed order.
        for write_output in (
            write_band_matrix,
            write_anomalies,
            write_group_summaries,
            write_uniformity_diagnostic,
        ):
            write_output(reports)
        print(f"\nAll outputs in {OUTPUT_DIR}/")
    else:
        print("No reports found on HF — did Phase 1 start yet?")