File size: 5,776 Bytes
e317e25
 
 
 
 
 
 
 
 
 
 
 
 
6707b9b
 
 
 
 
e317e25
4a0d594
 
e317e25
6707b9b
e317e25
 
 
 
 
 
 
 
 
 
 
 
6707b9b
 
 
 
 
 
 
 
 
 
 
e317e25
6707b9b
 
e317e25
4a0d594
 
 
 
 
 
 
 
e317e25
 
6707b9b
 
e317e25
6707b9b
 
e317e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6707b9b
 
e317e25
 
 
 
 
 
6707b9b
 
 
e317e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python3
"""Aggregator for depth-sweep results.

Reads the sweep manifest at /tmp/sweep_depth_manifest.txt, pulls HF Jobs logs
for each job, extracts the [METRICS_JSON] stdout line, and prints a
comparison table of per-layer diagnostics across n_layer values.

Usage:
    export HF_TOKEN=...
    python scripts/sweep_depth_aggregate.py [manifest_path]
"""
from __future__ import annotations

import json
import os
import sys
from pathlib import Path

# Manifest path: first CLI argument if given, else the default sweep location.
MANIFEST = Path(sys.argv[1]) if len(sys.argv) > 1 else Path('/tmp/sweep_depth_manifest.txt')


def fetch_metrics_from_job(job_id: str) -> dict | None:
    """Retrieve a job's stdout from HF Jobs and decode its [METRICS_JSON] payload.

    Returns the dict parsed from the last decodable ``[METRICS_JSON]`` line,
    or ``None`` when the hub client is unavailable, the logs cannot be
    fetched, or no such line was emitted.
    """
    try:
        from huggingface_hub import HfApi  # type: ignore
    except Exception as e:
        print(f'ERROR: huggingface_hub missing: {e}', file=sys.stderr)
        return None

    hub = HfApi(token=os.environ.get('HF_TOKEN'))
    try:
        stream = hub.fetch_job_logs(job_id=job_id)
    except Exception as e:
        print(f'[agg] could not fetch logs for job={job_id}: {e}', file=sys.stderr)
        return None

    marker = '[METRICS_JSON]'
    parsed = None
    for entry in stream:
        # Depending on the hub version, entries are plain strings or
        # JobLogEntry-like objects carrying a .data attribute.
        text = getattr(entry, 'data', None) or str(entry)
        if marker not in text:
            continue
        _, _, tail = text.partition(marker)
        try:
            parsed = json.loads(tail.strip())
        except Exception:
            # Possibly truncated at a line boundary -- a later line may parse.
            continue
    return parsed


def _print_scalars(results: dict[int, dict], sorted_n: list[int]) -> None:
    """Print the table of top-level scalar metrics, one row per metric."""
    print('\n=== Top-level scalars ===')
    hdr = ['metric'] + [f'L={n}' for n in sorted_n]
    print('  '.join(f'{h:>14}' for h in hdr))
    for key in ('val_bpb', 'val_ppl', 'num_params_M', 'total_tokens_M',
                'training_seconds', 'peak_vram_mb', 'sdr_target_active',
                'htm_anomaly', 'engram_hit_rate', 'sdr_active_bits'):
        row = [key]
        for n in sorted_n:
            v = results[n].get(key)
            # Non-numeric (or absent) values render as 'n/a'.
            row.append(f'{v:.4f}' if isinstance(v, (int, float)) else 'n/a')
        print('  '.join(f'{c:>14}' for c in row))


def _print_layer_table(results: dict[int, dict], sorted_n: list[int],
                       max_depth: int, title: str, suffix: str,
                       fmt: str, width: int) -> None:
    """Print one per-layer metric table.

    Rows are layer indices 0..max_depth-1; columns are the n_layer runs.
    ``suffix`` selects the ``layer_{i}_{suffix}`` keys, ``fmt`` is the
    numeric format spec, ``width`` the right-aligned column width.
    Missing/non-numeric values render as a right-aligned dash.
    """
    print(title)
    print('  '.join(['layer'] + [f'L={n:>2}' for n in sorted_n]))
    for li in range(max_depth):
        row = [f'L{li:02d}']
        for n in sorted_n:
            v = results[n].get(f'layer_{li}_{suffix}')
            row.append(format(v, fmt) if isinstance(v, (int, float)) else '-')
        print('  '.join(f'{c:>{width}}' for c in row))


def _print_dead_layers(results: dict[int, dict], sorted_n: list[int]) -> None:
    """Report layers whose residual contribution (delta_ratio) is below 2%."""
    print('\n=== Dead-layer detection (delta_ratio < 0.02) ===')
    for n in sorted_n:
        r = results[n]
        dead = []
        for li in range(r.get('n_layer', 0)):
            v = r.get(f'layer_{li}_delta_ratio')
            if isinstance(v, (int, float)) and v < 0.02:
                dead.append(li)
        status = 'ALL LIVE' if not dead else f'DEAD LAYERS: {dead}'
        print(f'  n_layer={n:2d}  val_bpb={r.get("val_bpb", float("nan")):.4f}  {status}')


def compare(results: dict[int, dict]) -> None:
    """Pretty-print a comparison of diagnostics across n_layer values.

    ``results`` maps n_layer -> the decoded [METRICS_JSON] dict for that run.
    Prints top-level scalars, four per-layer tables (delta_ratio, grad_norm,
    eff_rank, feat_std), and a dead-layer report. No-op message on empty input.
    """
    if not results:
        print('[agg] no results')
        return
    sorted_n = sorted(results)

    _print_scalars(results, sorted_n)

    # Deepest run determines the row count of every per-layer table.
    max_depth = max(results[n].get('n_layer', 0) for n in sorted_n)
    # (title, metric-key suffix, value format spec, column width)
    panels = (
        ('\n=== Per-layer: delta_ratio (residual contribution) ===',
         'delta_ratio', '.4f', 7),
        ('\n=== Per-layer: grad_norm ===', 'grad_norm', '.2e', 9),
        ('\n=== Per-layer: eff_rank (participation-ratio) ===',
         'eff_rank', '.1f', 7),
        ('\n=== Per-layer: feat_std ===', 'feat_std', '.4f', 7),
    )
    for title, suffix, fmt, width in panels:
        _print_layer_table(results, sorted_n, max_depth, title, suffix, fmt, width)

    _print_dead_layers(results, sorted_n)


def main() -> int:
    """Entry point: parse the manifest, fetch per-job metrics, print tables.

    Returns 0 on success, 2 when the manifest file is absent. Also writes
    the aggregated results to /tmp/sweep_depth_aggregated.json.
    """
    if not MANIFEST.exists():
        print(f'ERROR: manifest not found at {MANIFEST}', file=sys.stderr)
        return 2

    # Manifest format: one header line, then tab-separated "<n_layer>\t<job_id>".
    jobs: dict[int, str] = {}
    for raw in MANIFEST.read_text().splitlines()[1:]:
        fields = raw.strip().split('\t')
        if len(fields) < 2:
            continue
        try:
            depth = int(fields[0])
        except ValueError:
            continue
        jobs[depth] = fields[1]

    print(f'[agg] reading {len(jobs)} jobs from {MANIFEST}')
    results: dict[int, dict] = {}
    for depth, jid in jobs.items():
        print(f'[agg] fetching job={jid} (n_layer={depth}) ...')
        metrics = fetch_metrics_from_job(jid)
        if metrics is None:
            print(f'[agg]   no metrics for n_layer={depth} (job likely still running or failed)')
            continue
        results[depth] = metrics
    compare(results)

    out_path = Path('/tmp/sweep_depth_aggregated.json')
    out_path.write_text(json.dumps(results, indent=2, sort_keys=True))
    print(f'\n[agg] wrote aggregated results to {out_path}')
    return 0


if __name__ == '__main__':
    sys.exit(main())  # sys.exit raises SystemExit, same as the explicit raise