#!/usr/bin/env python3
"""Aggregator for depth-sweep results.

Reads the sweep manifest at /tmp/sweep_depth_manifest.txt, pulls HF Jobs logs
for each job, extracts the [METRICS_JSON] stdout line, and prints a
comparison table of per-layer diagnostics across n_layer values.

Usage:
    export HF_TOKEN=...
    python scripts/sweep_depth_aggregate.py [manifest_path]
"""
from __future__ import annotations

import json
import os
import sys
from pathlib import Path

MANIFEST = Path(sys.argv[1] if len(sys.argv) > 1 else '/tmp/sweep_depth_manifest.txt')
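
# The manifest is expected to be a tab-separated file with one header line,
# matching what main() parses below. A sketch of the assumed layout (the job
# ids shown are hypothetical placeholders):
#
#   n_layer	job_id
#   4	<job-id-for-depth-4>
#   8	<job-id-for-depth-8>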

def fetch_metrics_from_job(job_id: str) -> dict | None:
    """Fetch HF Job stdout and parse the last [METRICS_JSON] line found."""
    try:
        from huggingface_hub import HfApi  # type: ignore
    except Exception as e:
        print(f'ERROR: huggingface_hub missing: {e}', file=sys.stderr)
        return None
    api = HfApi(token=os.environ.get('HF_TOKEN'))
    try:
        logs_stream = api.fetch_job_logs(job_id=job_id)
    except Exception as e:
        print(f'[agg] could not fetch logs for job={job_id}: {e}', file=sys.stderr)
        return None
    last_json = None
    for line in logs_stream:
        # fetch_job_logs yields plain strings or JobLogEntry-like objects
        # depending on the huggingface_hub version; normalize to text.
        text = getattr(line, 'data', None) or str(line)
        if '[METRICS_JSON]' in text:
            payload = text.split('[METRICS_JSON]', 1)[1].strip()
            try:
                last_json = json.loads(payload)
            except Exception:
                # Might be truncated on a line boundary; keep looking.
                pass
    return last_json
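
# For reference, each training job is expected to emit a single stdout line of
# the form '[METRICS_JSON] {...}' that the parser above picks up. A minimal
# sketch of the emitting side (key names follow the tables printed below;
# variable names and values are hypothetical):
#
#   metrics = {'n_layer': n_layer, 'val_bpb': val_bpb,
#              **{f'layer_{i}_delta_ratio': r for i, r in enumerate(delta_ratios)}}
#   print('[METRICS_JSON] ' + json.dumps(metrics), flush=True)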

def compare(results: dict[int, dict]) -> None:
    """Pretty-print comparison tables across n_layer values."""
    if not results:
        print('[agg] no results')
        return
    sorted_n = sorted(results.keys())

    # Top-level scalars
    print('\n=== Top-level scalars ===')
    hdr = ['metric'] + [f'L={n}' for n in sorted_n]
    print(' '.join(f'{h:>14}' for h in hdr))
    for key in ('val_bpb', 'val_ppl', 'num_params_M', 'total_tokens_M',
                'training_seconds', 'peak_vram_mb', 'sdr_target_active',
                'htm_anomaly', 'engram_hit_rate', 'sdr_active_bits'):
        row = [key]
        for n in sorted_n:
            v = results[n].get(key)
            row.append(f'{v:.4f}' if isinstance(v, (int, float)) else 'n/a')
        print(' '.join(f'{c:>14}' for c in row))

    # Per-layer panels: one table per metric, all sharing the same layout.
    max_depth = max(results[n].get('n_layer', 0) for n in sorted_n)

    def layer_table(title: str, suffix: str, fmt: str, width: int) -> None:
        print(f'\n=== Per-layer: {title} ===')
        print(' '.join(f'{h:>{width}}' for h in ['layer'] + [f'L={n}' for n in sorted_n]))
        for li in range(max_depth):
            row = [f'L{li:02d}']
            for n in sorted_n:
                v = results[n].get(f'layer_{li}_{suffix}')
                row.append(format(v, fmt) if isinstance(v, (int, float)) else '-')
            print(' '.join(f'{c:>{width}}' for c in row))

    layer_table('delta_ratio (residual contribution)', 'delta_ratio', '.4f', 7)
    layer_table('grad_norm', 'grad_norm', '.2e', 9)
    layer_table('eff_rank (participation-ratio)', 'eff_rank', '.1f', 7)
    layer_table('feat_std', 'feat_std', '.4f', 7)

    # Dead-layer detection
    print('\n=== Dead-layer detection (delta_ratio < 0.02) ===')
    for n in sorted_n:
        r = results[n]
        n_layer = r.get('n_layer', 0)
        dead = []
        for li in range(n_layer):
            v = r.get(f'layer_{li}_delta_ratio')
            if isinstance(v, (int, float)) and v < 0.02:
                dead.append(li)
        status = 'ALL LIVE' if not dead else f'DEAD LAYERS: {dead}'
        print(f'  n_layer={n:2d}  val_bpb={r.get("val_bpb", float("nan")):.4f}  {status}')

def main() -> int:
    if not MANIFEST.exists():
        print(f'ERROR: manifest not found at {MANIFEST}', file=sys.stderr)
        return 2
    lines = MANIFEST.read_text().splitlines()[1:]  # skip header line
    jobs: dict[int, str] = {}
    for ln in lines:
        parts = ln.strip().split('\t')
        if len(parts) < 2:
            continue
        try:
            n_layer = int(parts[0])
        except ValueError:
            continue
        jobs[n_layer] = parts[1]
    print(f'[agg] reading {len(jobs)} jobs from {MANIFEST}')
    results: dict[int, dict] = {}
    for n, jid in sorted(jobs.items()):
        print(f'[agg] fetching job={jid} (n_layer={n}) ...')
        m = fetch_metrics_from_job(jid)
        if m is None:
            print(f'[agg] no metrics for n_layer={n} (job likely still running or failed)')
            continue
        results[n] = m
    compare(results)
    out_path = Path('/tmp/sweep_depth_aggregated.json')
    out_path.write_text(json.dumps(results, indent=2, sort_keys=True))
    print(f'\n[agg] wrote aggregated results to {out_path}')
    return 0


if __name__ == '__main__':
    raise SystemExit(main())
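
# The aggregated JSON can be reloaded later for further analysis (path matches
# out_path above; note json.dumps stringifies the integer n_layer keys):
#
#   import json
#   results = json.load(open('/tmp/sweep_depth_aggregated.json'))
#   print(sorted(results, key=int))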