"""
aggregate_results.py
====================
Offline aggregator for Phase 1 ablation results.
Pulls all final_report.json files from HF, assembles the band-assignment
matrix, and produces:
- band_matrix.csv: rows=variants, cols=bands observed
- anomalies.csv: configs where predicted_band != expected_band
- group_summaries.csv: per-group success/failure rates
- uniformity_diagnostic.csv: observed vs uniform sphere CV per config
Run locally after Phase 1 completes (or during, to check progress).
"""
import os
import json
import csv
from pathlib import Path
from typing import Dict, List, Any
from collections import defaultdict
from huggingface_hub import HfApi, hf_hub_download, list_repo_files
# Hugging Face repo that every final_report.json is pulled from.
HF_REPO = "AbstractPhil/geolip-svae-ablations"
# Auth token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# All CSV outputs are written under this directory. It is created at import
# time; without parents=True its parent directory must already exist.
OUTPUT_DIR = Path("./aggregate_output")
OUTPUT_DIR.mkdir(exist_ok=True)

# Shared API client. NOTE(review): nothing in this file references hf_api —
# the fetch path calls list_repo_files / hf_hub_download directly. Confirm it
# is needed elsewhere before removing.
hf_api = HfApi(token=HF_TOKEN)
def fetch_all_reports() -> List[Dict[str, Any]]:
    """Download and parse every ``final_report.json`` in the HF repo.

    Candidate files are found with ``list_repo_files`` (any path ending in
    ``final_report.json``) and fetched one by one with ``hf_hub_download``,
    which returns a local file path that is then parsed as JSON.

    A report that fails to download or parse is announced on stdout and
    skipped, so a partially-complete sweep can still be aggregated. A failure
    of the initial ``list_repo_files`` call is not caught and propagates.

    Returns:
        The successfully parsed reports, in listing order.
    """
    files = list_repo_files(repo_id=HF_REPO, token=HF_TOKEN)
    report_files = [f for f in files if f.endswith('final_report.json')]
    print(f"Found {len(report_files)} reports on HF")

    reports: List[Dict[str, Any]] = []
    for rpath in report_files:
        try:
            local_path = hf_hub_download(
                repo_id=HF_REPO,
                filename=rpath,
                token=HF_TOKEN,
            )
            # JSON text is UTF-8 (RFC 8259); don't rely on the platform's
            # locale encoding (e.g. cp1252 on Windows) to decode it.
            with open(local_path, encoding='utf-8') as f:
                reports.append(json.load(f))
        except Exception as e:  # deliberate best-effort: one bad file must not abort the run
            print(f"Could not fetch {rpath}: {e}")

    skipped = len(report_files) - len(reports)
    if skipped:
        print(f"Skipped {skipped} unreadable report(s)")
    return reports
def write_band_matrix(reports: List[Dict[str, Any]]) -> None:
    """Write the main band-assignment matrix to ``band_matrix.csv``.

    Rows: one per (group, variant, expected_band) combination.
    Cols: how many runs of that combination were classified into each band
    (LOW / MID / HIGH / UNCLASSIFIED), the total run count, and the fraction
    of runs whose ``predicted_band`` equals ``expected_band``.

    NOTE: a ``predicted_band`` outside those four labels still counts toward
    ``total`` but has no ``n_*`` column of its own.

    Args:
        reports: Parsed report dicts; each must provide ``config.group``,
            ``config.variant``, ``expected_band`` and ``predicted_band``.
    """
    matrix = defaultdict(lambda: defaultdict(int))
    for r in reports:
        key = (r['config']['group'], r['config']['variant'], r['expected_band'])
        matrix[key][r['predicted_band']] += 1

    def _none_last(item):
        # Plain tuple ordering raises TypeError when one key holds None where
        # another holds a string (e.g. expected_band: null). Tag each part
        # with "is None" so None sorts last; other orderings are unchanged.
        return tuple((part is None, "" if part is None else part) for part in item[0])

    output_path = OUTPUT_DIR / "band_matrix.csv"
    # Explicit UTF-8: names are free text and the platform default encoding
    # may not be able to represent every character.
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['group', 'variant', 'expected_band',
                    'n_LOW', 'n_MID', 'n_HIGH', 'n_UNCLASSIFIED',
                    'total', 'match_rate'])
        for (group, variant, expected), counts in sorted(matrix.items(), key=_none_last):
            total = sum(counts.values())
            match = counts.get(expected, 0) / total if total else 0.0
            w.writerow([group, variant, expected,
                        counts.get('LOW', 0), counts.get('MID', 0),
                        counts.get('HIGH', 0), counts.get('UNCLASSIFIED', 0),
                        total, f"{match:.2f}"])
    print(f"Wrote {output_path}")
def write_anomalies(reports: List[Dict[str, Any]]) -> None:
    """Write configs where predicted_band != expected_band to ``anomalies.csv``.

    These are the interesting ones — either a bug, or a real finding
    (an ablation broke the band structure).

    A report is treated as an anomaly when its ``band_match`` field is
    present and falsy. NOTE(review): reports without a ``band_match`` key are
    treated as matches, not compared via predicted/expected band.

    Args:
        reports: Parsed report dicts. Anomalous ones must carry
            ``config.description``, the two band labels and the numeric
            metrics written below.
    """
    anomalies = [r for r in reports if not r.get('band_match', True)]
    output_path = OUTPUT_DIR / "anomalies.csv"
    # Explicit UTF-8: ``description`` is free text and the platform default
    # encoding may not be able to represent every character.
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['description', 'expected_band', 'predicted_band',
                    'cv_ema_final', 'observed_sphere_cv',
                    'uniform_sphere_cv_prediction', 'band_deviation',
                    'test_mse'])
        for r in anomalies:
            w.writerow([
                r['config']['description'],
                r['expected_band'],
                r['predicted_band'],
                f"{r['cv_ema_final']:.4f}",
                f"{r['observed_sphere_cv']:.4f}",
                f"{r['uniform_sphere_cv_prediction']:.4f}",
                f"{r['band_deviation']:.4f}",
                f"{r['test_mse']:.6f}",
            ])
    print(f"Wrote {output_path} ({len(anomalies)} anomalies)")
def write_group_summaries(reports: List[Dict[str, Any]]) -> None:
    """Write ``group_summaries.csv``: band-match counts and rate per group.

    Each report lands in exactly one bucket of its ``config.group``: matched
    when ``band_match`` is truthy or absent, failed otherwise. Groups are
    written in sorted order.
    """
    # group -> [matched, failed]; total is always matched + failed.
    tally = defaultdict(lambda: [0, 0])
    for report in reports:
        bucket = tally[report['config']['group']]
        bucket[0 if report.get('band_match', True) else 1] += 1

    output_path = OUTPUT_DIR / "group_summaries.csv"
    with open(output_path, 'w', newline='') as handle:
        writer = csv.writer(handle)
        writer.writerow(['group', 'total', 'match', 'failed', 'match_rate'])
        for group in sorted(tally):
            matched, failed = tally[group]
            total = matched + failed
            rate = matched / total if total else 0.0
            writer.writerow([group, total, matched, failed, f"{rate:.2f}"])
    print(f"Wrote {output_path}")
def _format_optional_metric(value: Any, spec: str) -> str:
    """Format *value* with format spec *spec*, or return '' when it is None."""
    return "" if value is None else format(value, spec)
def write_uniformity_diagnostic(reports: List[Dict[str, Any]]) -> None:
    """Group N — observed vs uniform sphere CV per config.

    Writes ``uniformity_diagnostic.csv`` with one row per report, ordered by
    ``|band_deviation|`` descending, so configs that deviate most from the
    uniform-sphere prediction come first. Large positive deviation = model
    NOT reaching uniform attractor despite being in-band.

    Reports may lack the uniformity metrics; the sort key already tolerated a
    missing ``band_deviation``. Missing or ``None`` metric cells are now
    written empty instead of raising ``KeyError``/``TypeError``, and such
    rows sort as deviation 0.

    Args:
        reports: Parsed report dicts; each must provide
            ``config.description`` and ``expected_band``.
    """
    output_path = OUTPUT_DIR / "uniformity_diagnostic.csv"
    ordered = sorted(
        reports,
        key=lambda r: abs(r.get('band_deviation') or 0),
        reverse=True,
    )
    # Explicit UTF-8: ``description`` is free text.
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['description', 'expected_band',
                    'observed_sphere_cv', 'uniform_sphere_cv_prediction',
                    'band_deviation', 'cv_ema_final'])
        for r in ordered:
            w.writerow([
                r['config']['description'],
                r['expected_band'],
                _format_optional_metric(r.get('observed_sphere_cv'), '.4f'),
                _format_optional_metric(r.get('uniform_sphere_cv_prediction'), '.4f'),
                # Signed so the direction of the deviation is visible.
                _format_optional_metric(r.get('band_deviation'), '+.4f'),
                _format_optional_metric(r.get('cv_ema_final'), '.4f'),
            ])
    print(f"Wrote {output_path}")
def print_summary(reports: List[Dict[str, Any]]) -> None:
    """Print a quick text summary of the aggregated reports to stdout.

    Prints the report count, and progress against the expected Phase 1
    config count when ``get_phase1_configs`` is available. It then prints the
    overall band-preservation rate and a per-group breakdown. A report counts
    as band-preserving when ``band_match`` is truthy or absent.

    Safe for an empty ``reports`` list (rates print as 0.0%).
    """
    n_reports = len(reports)
    print(f"\n{'='*70}")
    print(f"ABLATION AGGREGATE SUMMARY — {n_reports} reports")
    print(f"{'='*70}")

    # Expected total from the Phase 1 config matrix. get_phase1_configs is
    # neither defined nor imported in this file, so calling it unguarded
    # raised NameError and aborted the run before any CSV was written. Use it
    # only if it has been made available in this module's namespace.
    get_configs = globals().get('get_phase1_configs')
    if get_configs is not None:
        expected_total = len(get_configs())
        pct_done = 100 * n_reports / expected_total if expected_total else 0.0
        print(f"Expected Phase 1 configs: {expected_total}")
        print(f"Completed so far: {n_reports} ({pct_done:.1f}%)")
    else:
        print("Expected Phase 1 configs: unknown (get_phase1_configs not available)")
        print(f"Completed so far: {n_reports}")

    # Band preservation rate (guarded against an empty report list).
    matches = sum(1 for r in reports if r.get('band_match', True))
    match_pct = 100 * matches / n_reports if n_reports else 0.0
    print(f"\nBand preservation rate: {matches}/{n_reports} ({match_pct:.1f}%)")

    # Per-group breakdown: groups[g] = [matched, total].
    print("\nPer-group preservation:")
    groups = defaultdict(lambda: [0, 0])
    for r in reports:
        g = r['config']['group']
        groups[g][1] += 1
        if r.get('band_match', True):
            groups[g][0] += 1
    for g in sorted(groups):
        m, t = groups[g]
        # t >= 1 always: a group entry only exists once a report was counted.
        # `!s` keeps the padding working for non-str group names too.
        print(f" {g!s:12} {m}/{t} ({100*m/t:.0f}%)")
if __name__ == '__main__':
    # Script entry: fetch everything, print the summary, then emit each CSV.
    reports = fetch_all_reports()
    if reports:
        print_summary(reports)
        for writer in (write_band_matrix, write_anomalies,
                       write_group_summaries, write_uniformity_diagnostic):
            writer(reports)
        print(f"\nAll outputs in {OUTPUT_DIR}/")
    else:
        print("No reports found on HF — did Phase 1 start yet?")