# feather-runtime / overlay / scripts / sweep_depth_aggregate.py
# Author: Jackoatmon — "Update Feather h200 training runtime image" (commit e317e25, verified)
#!/usr/bin/env python3
"""Aggregator for depth-sweep results.
Reads the sweep manifest at /tmp/sweep_depth_manifest.txt, pulls HF Jobs logs
for each job, extracts the [METRICS_JSON] stdout line, and prints a
comparison table of per-layer diagnostics across n_layer values.
Usage:
export HF_TOKEN=...
python scripts/sweep_depth_aggregate.py [manifest_path]
"""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
# Manifest path: first CLI argument if given, else the default location the
# sweep launcher writes to (one "<n_layer>\t<job_id>" row per job, plus header).
MANIFEST = Path(sys.argv[1] if len(sys.argv) > 1 else '/tmp/sweep_depth_manifest.txt')
def fetch_metrics_from_job(job_id: str) -> dict | None:
    """Pull the stdout log of an HF Job and return its last ``[METRICS_JSON]`` payload.

    Returns ``None`` when huggingface_hub is not importable, the log fetch
    fails, or no parsable metrics line appears in the stream.
    """
    try:
        from huggingface_hub import HfApi  # type: ignore
    except Exception as e:
        print(f'ERROR: huggingface_hub missing: {e}', file=sys.stderr)
        return None
    api = HfApi(token=os.environ.get('HF_TOKEN'))
    try:
        stream = api.fetch_job_logs(job_id=job_id)
    except Exception as e:
        print(f'[agg] could not fetch logs for job={job_id}: {e}', file=sys.stderr)
        return None
    marker = '[METRICS_JSON]'
    metrics: dict | None = None
    for entry in stream:
        # Depending on the hub version the stream yields plain strings or
        # JobLogEntry-like objects carrying the text in a .data attribute.
        text = getattr(entry, 'data', None) or str(entry)
        if marker not in text:
            continue
        _, _, payload = text.partition(marker)
        try:
            metrics = json.loads(payload.strip())
        except Exception:
            # Possibly truncated at a line boundary — keep scanning for a
            # later, complete metrics line.
            continue
    return metrics
def _print_layer_table(results: dict[int, dict], sorted_n: list[int],
                       max_depth: int, title: str, suffix: str,
                       fmt: str, width: int) -> None:
    """Print one per-layer table: rows are layer indices, columns are runs.

    *suffix* selects the metric via the ``layer_{i}_{suffix}`` key; *fmt* is
    a format spec (e.g. ``'.4f'``); *width* is the column width. Missing or
    non-numeric values render as ``' -'``.
    """
    print(f'\n=== Per-layer: {title} ===')
    print(' '.join(['layer'] + [f'L={n:>2}' for n in sorted_n]))
    for li in range(max_depth):
        row = [f'L{li:02d}']
        for n in sorted_n:
            v = results[n].get(f'layer_{li}_{suffix}')
            row.append(format(v, fmt) if isinstance(v, (int, float)) else ' -')
        print(' '.join(f'{c:>{width}}' for c in row))


def compare(results: dict[int, dict]) -> None:
    """Pretty-print a comparison of sweep metrics across n_layer values.

    Parameters
    ----------
    results:
        Maps ``n_layer`` -> flat metrics dict as emitted on a job's
        ``[METRICS_JSON]`` line (top-level scalars plus per-layer
        ``layer_{i}_{metric}`` keys). Prints to stdout; returns ``None``.
    """
    if not results:
        print('[agg] no results')
        return
    sorted_n = sorted(results)
    # --- Top-level scalars -------------------------------------------------
    print('\n=== Top-level scalars ===')
    hdr = ['metric'] + [f'L={n}' for n in sorted_n]
    print(' '.join(f'{h:>14}' for h in hdr))
    for key in ('val_bpb', 'val_ppl', 'num_params_M', 'total_tokens_M',
                'training_seconds', 'peak_vram_mb', 'sdr_target_active',
                'htm_anomaly', 'engram_hit_rate', 'sdr_active_bits'):
        row = [key]
        for n in sorted_n:
            v = results[n].get(key)
            row.append(f'{v:.4f}' if isinstance(v, (int, float)) else 'n/a')
        print(' '.join(f'{c:>14}' for c in row))
    # --- Per-layer panels (one table per metric, shared layout) ------------
    max_depth = max(results[n].get('n_layer', 0) for n in sorted_n)
    _print_layer_table(results, sorted_n, max_depth,
                       'delta_ratio (residual contribution)', 'delta_ratio', '.4f', 7)
    _print_layer_table(results, sorted_n, max_depth,
                       'grad_norm', 'grad_norm', '.2e', 9)
    _print_layer_table(results, sorted_n, max_depth,
                       'eff_rank (participation-ratio)', 'eff_rank', '.1f', 7)
    _print_layer_table(results, sorted_n, max_depth,
                       'feat_std', 'feat_std', '.4f', 7)
    # --- Dead-layer detection ----------------------------------------------
    print('\n=== Dead-layer detection (delta_ratio < 0.02) ===')
    for n in sorted_n:
        r = results[n]
        dead = []
        for li in range(r.get('n_layer', 0)):
            v = r.get(f'layer_{li}_delta_ratio')
            # A layer whose residual contribution is near zero is "dead".
            if isinstance(v, (int, float)) and v < 0.02:
                dead.append(li)
        status = 'ALL LIVE' if not dead else f'DEAD LAYERS: {dead}'
        print(f' n_layer={n:2d} val_bpb={r.get("val_bpb", float("nan")):.4f} {status}')
def main() -> int:
    """Aggregate sweep metrics listed in the manifest and print a report.

    Reads the tab-separated manifest (``n_layer\\tjob_id`` rows after a
    header), fetches each job's metrics, prints the comparison, and dumps
    the aggregate to /tmp. Returns 0 on success, 2 if the manifest is
    missing.
    """
    if not MANIFEST.exists():
        print(f'ERROR: manifest not found at {MANIFEST}', file=sys.stderr)
        return 2
    manifest_rows = MANIFEST.read_text().splitlines()[1:]  # first row is the header
    jobs: dict[int, str] = {}
    for raw in manifest_rows:
        fields = raw.strip().split('\t')
        if len(fields) < 2:
            continue
        try:
            jobs[int(fields[0])] = fields[1]
        except ValueError:
            # First column was not an integer (stray text) — skip the row.
            continue
    print(f'[agg] reading {len(jobs)} jobs from {MANIFEST}')
    results: dict[int, dict] = {}
    for n, jid in jobs.items():
        print(f'[agg] fetching job={jid} (n_layer={n}) ...')
        metrics = fetch_metrics_from_job(jid)
        if metrics is None:
            print(f'[agg] no metrics for n_layer={n} (job likely still running or failed)')
            continue
        results[n] = metrics
    compare(results)
    out_path = Path('/tmp/sweep_depth_aggregated.json')
    out_path.write_text(json.dumps(results, indent=2, sort_keys=True))
    print(f'\n[agg] wrote aggregated results to {out_path}')
    return 0
# Script entry point: exit status propagates main()'s return code (0 or 2).
if __name__ == '__main__':
    raise SystemExit(main())