# geolip-svae-ablations / aggregate_reports.py
# Author: AbstractPhil
# Commit 69f5135 ("Create aggregate_reports.py", verified)
"""
aggregate_results.py
====================
Offline aggregator for Phase 1 ablation results.
Pulls all final_report.json files from HF, assembles the band-assignment
matrix, and produces:
- band_matrix.csv: rows=variants, cols=bands observed
- anomalies.csv: configs where predicted_band != expected_band
- group_summaries.csv: per-group success/failure rates
- uniformity_diagnostic.csv: observed vs uniform sphere CV per config
Run locally after Phase 1 completes (or during, to check progress).
"""
import csv
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional

from huggingface_hub import HfApi, hf_hub_download, list_repo_files
HF_REPO = "AbstractPhil/geolip-svae-ablations"
HF_TOKEN = os.environ.get("HF_TOKEN")
OUTPUT_DIR = Path("./aggregate_output")
OUTPUT_DIR.mkdir(exist_ok=True)
hf_api = HfApi(token=HF_TOKEN)
def fetch_all_reports() -> List[Dict[str, Any]]:
    """Download and parse every ``final_report.json`` in ``HF_REPO``.

    Each file is downloaded separately. A file that fails to download or
    parse is reported on stdout and skipped, so one bad upload does not
    abort the whole aggregation.

    Returns:
        The parsed report dicts, ordered by their repo path.
    """
    files = list_repo_files(repo_id=HF_REPO, token=HF_TOKEN)
    # Sort so repeated runs process (and print) files in a stable order.
    report_files = sorted(f for f in files if f.endswith('final_report.json'))
    print(f"Found {len(report_files)} reports on HF")

    reports: List[Dict[str, Any]] = []
    for rpath in report_files:
        try:
            local_path = hf_hub_download(
                repo_id=HF_REPO,
                filename=rpath,
                token=HF_TOKEN,
            )
            # JSON is UTF-8; don't depend on the platform's locale default.
            with open(local_path, encoding='utf-8') as f:
                reports.append(json.load(f))
        except Exception as e:  # deliberate best-effort: skip and keep going
            print(f"Could not fetch {rpath}: {e}")
    return reports
def write_band_matrix(reports: List[Dict[str, Any]],
                      output_dir: Optional[Path] = None) -> None:
    """Write the main band-assignment matrix to ``band_matrix.csv``.

    Rows are keyed by ``(group, variant, expected_band)``. The columns count
    how many reports in the row were assigned each predicted band. They also
    give the row total and the fraction whose predicted band equals the
    expected band.

    Args:
        reports: Parsed report dicts. Each one must provide
            ``config.group``, ``config.variant``, ``expected_band`` and
            ``predicted_band`` (otherwise ``KeyError`` is raised).
        output_dir: Directory to write into. Defaults to ``OUTPUT_DIR``.
    """
    out_dir = OUTPUT_DIR if output_dir is None else Path(output_dir)

    matrix = defaultdict(lambda: defaultdict(int))
    for r in reports:
        key = (r['config']['group'], r['config']['variant'], r['expected_band'])
        matrix[key][r['predicted_band']] += 1

    output_path = out_dir / "band_matrix.csv"
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['group', 'variant', 'expected_band',
                    'n_LOW', 'n_MID', 'n_HIGH', 'n_UNCLASSIFIED',
                    'total', 'match_rate'])
        for (group, variant, expected), counts in sorted(matrix.items()):
            # Predicted bands outside the four named columns still count
            # toward the total (and so lower the match rate).
            total = sum(counts.values())
            match = counts.get(expected, 0) / total if total else 0.0
            w.writerow([group, variant, expected,
                        counts.get('LOW', 0), counts.get('MID', 0),
                        counts.get('HIGH', 0), counts.get('UNCLASSIFIED', 0),
                        total, f"{match:.2f}"])
    print(f"Wrote {output_path}")
def write_anomalies(reports: List[Dict[str, Any]],
                    output_dir: Optional[Path] = None) -> None:
    """Write configs whose predicted band differs from the expected band.

    These are the interesting ones — either a bug, or a real finding
    (an ablation broke the band structure).

    A report's own ``band_match`` flag decides when it is present. When it
    is absent, the match is derived by comparing ``predicted_band`` with
    ``expected_band``, so such reports are not silently counted as matches.

    Args:
        reports: Parsed report dicts. Each anomalous report must provide
            ``config.description``, both band fields and the numeric metrics
            written below.
        output_dir: Directory to write into. Defaults to ``OUTPUT_DIR``.
    """
    out_dir = OUTPUT_DIR if output_dir is None else Path(output_dir)

    def band_matched(r: Dict[str, Any]) -> bool:
        # Prefer the runner's verdict; fall back to comparing the bands.
        if 'band_match' in r:
            return bool(r['band_match'])
        return r.get('predicted_band') == r.get('expected_band')

    anomalies = [r for r in reports if not band_matched(r)]
    output_path = out_dir / "anomalies.csv"
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['description', 'expected_band', 'predicted_band',
                    'cv_ema_final', 'observed_sphere_cv',
                    'uniform_sphere_cv_prediction', 'band_deviation',
                    'test_mse'])
        for r in anomalies:
            w.writerow([
                r['config']['description'],
                r['expected_band'],
                r['predicted_band'],
                f"{r['cv_ema_final']:.4f}",
                f"{r['observed_sphere_cv']:.4f}",
                f"{r['uniform_sphere_cv_prediction']:.4f}",
                f"{r['band_deviation']:.4f}",
                f"{r['test_mse']:.6f}",
            ])
    print(f"Wrote {output_path} ({len(anomalies)} anomalies)")
def write_group_summaries(reports: List[Dict[str, Any]],
                          output_dir: Optional[Path] = None) -> None:
    """Write per-group band-preservation rates to ``group_summaries.csv``.

    For each ``config.group`` the file records the total report count, how
    many preserved their expected band (``match``), how many did not
    (``failed``) and ``match / total``. A report's ``band_match`` flag
    decides when present. Otherwise ``predicted_band`` is compared with
    ``expected_band``.

    Args:
        reports: Parsed report dicts. Each one must provide ``config.group``.
        output_dir: Directory to write into. Defaults to ``OUTPUT_DIR``.
    """
    out_dir = OUTPUT_DIR if output_dir is None else Path(output_dir)

    def band_matched(r: Dict[str, Any]) -> bool:
        # Prefer the runner's verdict; fall back to comparing the bands.
        if 'band_match' in r:
            return bool(r['band_match'])
        return r.get('predicted_band') == r.get('expected_band')

    groups = defaultdict(lambda: {'match': 0, 'total': 0, 'failed': 0})
    for r in reports:
        stats = groups[r['config']['group']]
        stats['total'] += 1
        stats['match' if band_matched(r) else 'failed'] += 1

    output_path = out_dir / "group_summaries.csv"
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['group', 'total', 'match', 'failed', 'match_rate'])
        for g in sorted(groups.keys()):
            s = groups[g]
            rate = s['match'] / s['total'] if s['total'] else 0.0
            w.writerow([g, s['total'], s['match'], s['failed'], f"{rate:.2f}"])
    print(f"Wrote {output_path}")
def write_uniformity_diagnostic(reports: List[Dict[str, Any]],
                                output_dir: Optional[Path] = None) -> None:
    """Group N — observed vs uniform sphere CV per config.

    Identifies which configs deviate significantly from uniform-sphere
    prediction. Large positive deviation = model NOT reaching uniform
    attractor despite being in-band.

    Rows are sorted by ``|band_deviation|``, largest first. A report with a
    missing (or None) ``band_deviation`` sorts as 0. Missing or None numeric
    fields are written as empty cells rather than aborting the whole file.

    Args:
        reports: Parsed report dicts. Each one must provide
            ``config.description`` and ``expected_band``.
        output_dir: Directory to write into. Defaults to ``OUTPUT_DIR``.
    """
    out_dir = OUTPUT_DIR if output_dir is None else Path(output_dir)

    def fmt(r: Dict[str, Any], key: str, spec: str) -> str:
        # Blank cell for absent metrics instead of a KeyError mid-write.
        value = r.get(key)
        return '' if value is None else format(value, spec)

    def deviation_magnitude(r: Dict[str, Any]) -> float:
        value = r.get('band_deviation')
        return 0 if value is None else abs(value)

    output_path = out_dir / "uniformity_diagnostic.csv"
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['description', 'expected_band',
                    'observed_sphere_cv', 'uniform_sphere_cv_prediction',
                    'band_deviation', 'cv_ema_final'])
        for r in sorted(reports, key=deviation_magnitude, reverse=True):
            w.writerow([
                r['config']['description'],
                r['expected_band'],
                fmt(r, 'observed_sphere_cv', '.4f'),
                fmt(r, 'uniform_sphere_cv_prediction', '.4f'),
                fmt(r, 'band_deviation', '+.4f'),
                fmt(r, 'cv_ema_final', '.4f'),
            ])
    print(f"Wrote {output_path}")
def print_summary(reports: List[Dict[str, Any]],
                  expected_total: Optional[int] = None) -> None:
    """Print a quick text summary of the aggregate to stdout.

    Args:
        reports: Parsed report dicts. Each one must provide ``config.group``.
        expected_total: Number of Phase 1 configs expected. When None, the
            function calls ``get_phase1_configs()`` if a callable with that
            name is present in this module's globals. That function is not
            defined or imported in this file; presumably it comes from the
            ablation runner (verify). If neither is available, progress is
            reported as unknown instead of raising ``NameError``.

    A report's ``band_match`` flag decides preservation when present.
    Otherwise ``predicted_band`` is compared with ``expected_band``.
    """
    def band_matched(r: Dict[str, Any]) -> bool:
        if 'band_match' in r:
            return bool(r['band_match'])
        return r.get('predicted_band') == r.get('expected_band')

    n_reports = len(reports)
    print(f"\n{'='*70}")
    print(f"ABLATION AGGREGATE SUMMARY — {n_reports} reports")
    print(f"{'='*70}")

    # Expected total from the config matrix, when it can be resolved.
    if expected_total is None:
        config_fn = globals().get('get_phase1_configs')
        if callable(config_fn):
            expected_total = len(config_fn())
    if expected_total is None:
        print("Expected Phase 1 configs: unknown (get_phase1_configs not available)")
    else:
        print(f"Expected Phase 1 configs: {expected_total}")
        progress = 100 * n_reports / expected_total if expected_total else 0.0
        print(f"Completed so far: {n_reports} ({progress:.1f}%)")

    # Band preservation rate
    matches = sum(1 for r in reports if band_matched(r))
    rate = 100 * matches / n_reports if n_reports else 0.0
    print(f"\nBand preservation rate: {matches}/{n_reports} ({rate:.1f}%)")

    # Per-group breakdown
    print("\nPer-group preservation:")
    groups = defaultdict(lambda: [0, 0])  # group -> [matched, total]
    for r in reports:
        counts = groups[r['config']['group']]
        counts[1] += 1
        if band_matched(r):
            counts[0] += 1
    for g in sorted(groups.keys()):
        m, t = groups[g]
        # t >= 1 for every group that appears in `groups`.
        print(f"  {g!s:12} {m}/{t} ({100*m/t:.0f}%)")
if __name__ == '__main__':
    # Fetch once, then print the summary and emit every CSV from the same data.
    reports = fetch_all_reports()
    if reports:
        print_summary(reports)
        for write_output in (write_band_matrix, write_anomalies,
                             write_group_summaries, write_uniformity_diagnostic):
            write_output(reports)
        print(f"\nAll outputs in {OUTPUT_DIR}/")
    else:
        print("No reports found on HF — did Phase 1 start yet?")