#!/usr/bin/env python3
"""Aggregator for depth-sweep results.
Reads the sweep manifest at /tmp/sweep_depth_manifest.txt, pulls HF Jobs logs
for each job, extracts the [METRICS_JSON] stdout line, and prints a
comparison table of per-layer diagnostics across n_layer values.
Usage:
export HF_TOKEN=...
python scripts/sweep_depth_aggregate.py [manifest_path]
"""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
# Manifest path: first CLI argument, else the default written by the sweep launcher.
MANIFEST = Path(sys.argv[1] if len(sys.argv) > 1 else '/tmp/sweep_depth_manifest.txt')
def fetch_metrics_from_job(job_id: str) -> dict | None:
"""Fetch HF Job stdout and parse the [METRICS_JSON] line."""
try:
from huggingface_hub import HfApi # type: ignore
except Exception as e:
print(f'ERROR: huggingface_hub missing: {e}', file=sys.stderr)
return None
api = HfApi(token=os.environ.get('HF_TOKEN'))
try:
logs_stream = api.fetch_job_logs(job_id=job_id)
except Exception as e:
print(f'[agg] could not fetch logs for job={job_id}: {e}', file=sys.stderr)
return None
last_json = None
for line in logs_stream:
# HfApi returns strings or JobLogEntry-like objects depending on version.
text = getattr(line, 'data', None) or str(line)
if '[METRICS_JSON]' in text:
payload = text.split('[METRICS_JSON]', 1)[1].strip()
try:
last_json = json.loads(payload)
except Exception:
# Might be truncated on a line boundary — keep looking.
pass
return last_json
def compare(results: dict[int, dict]) -> None:
    """Pretty-print a comparison of sweep results across n_layer values.

    Sections: top-level scalar metrics, one per-layer table per diagnostic
    (delta_ratio, grad_norm, eff_rank, feat_std), and a dead-layer report
    flagging layers whose delta_ratio falls below 0.02.

    Args:
        results: mapping of n_layer -> metrics dict as emitted on the
            job's [METRICS_JSON] line (flat keys, per-layer keys named
            'layer_{i}_{metric}').
    """
    if not results:
        print('[agg] no results')
        return
    sorted_n = sorted(results.keys())

    # Top-level scalars
    print('\n=== Top-level scalars ===')
    hdr = ['metric'] + [f'L={n}' for n in sorted_n]
    print(' '.join(f'{h:>14}' for h in hdr))
    for key in ('val_bpb', 'val_ppl', 'num_params_M', 'total_tokens_M',
                'training_seconds', 'peak_vram_mb', 'sdr_target_active',
                'htm_anomaly', 'engram_hit_rate', 'sdr_active_bits'):
        row = [key]
        for n in sorted_n:
            # Missing keys surface as 'n/a' rather than crashing the table.
            v = results[n].get(key)
            row.append(f'{v:.4f}' if isinstance(v, (int, float)) else 'n/a')
        print(' '.join(f'{c:>14}' for c in row))

    # Per-layer panels — one table per metric, all rendered by one helper
    # (the four tables differ only in title, key suffix, format, and width).
    max_depth = max(results[n].get('n_layer', 0) for n in sorted_n)
    _print_layer_table(results, sorted_n, max_depth,
                       '\n=== Per-layer: delta_ratio (residual contribution) ===',
                       'delta_ratio', '.4f', 7)
    _print_layer_table(results, sorted_n, max_depth,
                       '\n=== Per-layer: grad_norm ===',
                       'grad_norm', '.2e', 9)
    _print_layer_table(results, sorted_n, max_depth,
                       '\n=== Per-layer: eff_rank (participation-ratio) ===',
                       'eff_rank', '.1f', 7)
    _print_layer_table(results, sorted_n, max_depth,
                       '\n=== Per-layer: feat_std ===',
                       'feat_std', '.4f', 7)

    # Dead-layer detection
    print('\n=== Dead-layer detection (delta_ratio < 0.02) ===')
    for n in sorted_n:
        r = results[n]
        dead = []
        for li in range(r.get('n_layer', 0)):
            v = r.get(f'layer_{li}_delta_ratio')
            if isinstance(v, (int, float)) and v < 0.02:
                dead.append(li)
        status = 'ALL LIVE' if not dead else f'DEAD LAYERS: {dead}'
        print(f' n_layer={n:2d} val_bpb={r.get("val_bpb", float("nan")):.4f} {status}')


def _print_layer_table(results: dict[int, dict], sorted_n: list[int],
                       max_depth: int, title: str, suffix: str,
                       fmt: str, width: int) -> None:
    """Print one per-layer table: rows are layer indices, columns the sweep runs.

    `suffix` selects the metric ('layer_{i}_{suffix}'), `fmt` is the format
    spec for numeric cells, `width` the right-aligned column width.
    Non-numeric / missing cells render as ' -'.
    """
    print(title)
    print(' '.join(['layer'] + [f'L={n:>2}' for n in sorted_n]))
    for li in range(max_depth):
        row = [f'L{li:02d}']
        for n in sorted_n:
            v = results[n].get(f'layer_{li}_{suffix}')
            row.append(format(v, fmt) if isinstance(v, (int, float)) else ' -')
        print(' '.join(f'{c:>{width}}' for c in row))
def main() -> int:
    """Entry point: parse the sweep manifest, pull per-job metrics, and report.

    Returns 0 on success, 2 when the manifest file is absent.
    """
    if not MANIFEST.exists():
        print(f'ERROR: manifest not found at {MANIFEST}', file=sys.stderr)
        return 2

    # Manifest rows are tab-separated "<n_layer>\t<job_id>"; row 0 is a header.
    jobs: dict[int, str] = {}
    for raw in MANIFEST.read_text().splitlines()[1:]:
        fields = raw.strip().split('\t')
        if len(fields) < 2:
            continue
        try:
            jobs[int(fields[0])] = fields[1]
        except ValueError:
            # Malformed depth column — skip the row.
            continue

    print(f'[agg] reading {len(jobs)} jobs from {MANIFEST}')
    results: dict[int, dict] = {}
    for depth, job_id in jobs.items():
        print(f'[agg] fetching job={job_id} (n_layer={depth}) ...')
        metrics = fetch_metrics_from_job(job_id)
        if metrics is None:
            print(f'[agg] no metrics for n_layer={depth} (job likely still running or failed)')
            continue
        results[depth] = metrics

    compare(results)
    out_path = Path('/tmp/sweep_depth_aggregated.json')
    out_path.write_text(json.dumps(results, indent=2, sort_keys=True))
    print(f'\n[agg] wrote aggregated results to {out_path}')
    return 0


if __name__ == '__main__':
    raise SystemExit(main())