"""Analyze the continuation-training JSONL log: compute the recent recovery
rate and project when val_bpc will reach the original baseline of 6.16.

Usage:
    # Pull latest log from VM
    scp -P 36911 root@ssh6.vast.ai:/root/bitnet1/logs/fineweb_edu_v75_75M_continue.jsonl logs/
    python3 _analyze_continue_traj.py logs/fineweb_edu_v75_75M_continue.jsonl
"""
| import json, sys, math |
|
|
|
|
# Tokens processed per training step; used to turn step counts into seconds
# (seconds = steps * tokens_per_step / tok_per_s).
_TOKENS_PER_STEP = 65536

# Throughput (tok/s) assumed when the log has no usable 'train' reading.
_FALLBACK_TOK_PER_S = 45000


def _load_events(path):
    """Parse the JSONL log at *path* and return ``(evals, trains)``.

    ``evals`` is a list of ``(step, val_bpc)`` tuples taken from records whose
    ``event`` is ``'eval'``; ``trains`` is a list of ``(step, lr, tok_per_s)``
    tuples from ``'train'`` records.  Lines that are not valid JSON, are not
    JSON objects, or lack a required field are skipped, so a log copied while
    it is still being written does not abort the analysis.
    """
    evals = []
    trains = []
    # errors='replace' keeps a partially transferred file from raising on a
    # truncated multi-byte sequence; the damaged line then fails JSON parsing
    # and is skipped like any other malformed line.
    with open(path, encoding='utf-8', errors='replace') as fh:
        for line in fh:
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(record, dict):
                continue
            try:
                kind = record.get('event')
                if kind == 'eval':
                    evals.append((record['step'], record['val_bpc']))
                elif kind == 'train':
                    trains.append(
                        (record['step'], record['lr'], record['tok_per_s']))
            except KeyError:
                # Record of a known kind but missing a field: ignore it.
                continue
    return evals, trains


def _throughput(trains):
    """Return ``(tok_per_s, measured)`` for time estimates.

    Uses the most recent strictly positive ``tok_per_s`` in *trains*.  When
    none exists, returns ``(_FALLBACK_TOK_PER_S, False)`` so callers never
    divide by zero.
    """
    for _, _, tps in reversed(trains):
        if isinstance(tps, (int, float)) and tps > 0:
            return tps, True
    return _FALLBACK_TOK_PER_S, False


def _print_trajectory(evals, target_bpc):
    """Print one row per eval: step, val_bpc, change since previous, gap to target."""
    print(f'=== val_bpc trajectory (target: {target_bpc}) ===')
    print(f'{"step":>8s} {"val_bpc":>10s} {"Δ":>8s} {"gap":>7s}')
    prev_bpc = None
    for step, bpc in evals:
        delta = 0.0 if prev_bpc is None else bpc - prev_bpc
        gap = bpc - target_bpc
        print(f'{step:>8d} {bpc:>10.4f} {delta:>+8.4f} {gap:>+7.3f}')
        prev_bpc = bpc


def _print_recovery(evals, trains, target_bpc):
    """Print the recovery rate over the last three evals and project the target step.

    Does nothing when fewer than three evals are available.
    """
    if len(evals) < 3:
        return

    (first_step, first_bpc), _, (last_step, last_bpc) = evals[-3:]
    window = last_step - first_step
    # max(1, ...) guards against evals logged at the same step.
    rate = (first_bpc - last_bpc) / max(1, window)
    print(f'\nRecent recovery rate: {rate*1000:+.4f} bpc/1K steps '
          f'over last {window} steps')

    if last_bpc <= target_bpc:
        # Projecting forward here would yield a step in the past and a
        # negative duration.
        print(f'val_bpc={last_bpc:.4f} is already at or below the target '
              f'of {target_bpc}.')
        return
    if rate <= 1e-6:
        print('Recovery rate near zero — model has plateaued at this LR.')
        return

    steps_to_recover = (last_bpc - target_bpc) / rate
    target_step = last_step + steps_to_recover
    tps, _ = _throughput(trains)
    seconds_remaining = steps_to_recover * _TOKENS_PER_STEP / tps
    print(f'At this rate, val_bpc={target_bpc} reached at step '
          f'~{int(target_step):,} (~{seconds_remaining/3600:.1f} h)')


def _cosine_lr(step, base_lr, total_steps, warmup_steps):
    """Return the half-cosine LR at *step*, decaying from *base_lr* after warmup."""
    prog = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * prog))


def _print_lr_forecast(trains, base_lr, total_steps, warmup_steps):
    """Print the current LR and, for a cosine run, when LR milestones are reached.

    A run is treated as constant-LR when the last (up to) ten logged LRs are
    identical to four significant digits; then only throughput is reported.
    """
    print('\n=== LR forecast ===')
    if not trains:
        return

    cur_step, cur_lr_logged, _ = trains[-1]
    tps, measured = _throughput(trains)

    recent_lrs = {f'{lr:.3e}' for _, lr, _ in trains[-10:]}
    if len(recent_lrs) == 1:
        sec_per_step = _TOKENS_PER_STEP / tps
        steps_per_hour = 3600 / sec_per_step
        note = '' if measured else ' [assumed default throughput]'
        print(f'Current step {cur_step:,}: constant LR={cur_lr_logged:.2e}')
        print('(constant-LR run; cosine forecast not applicable)')
        print(f'Throughput: {tps:.0f} tok/s ({sec_per_step*1000:.0f} ms/step, '
              f'{steps_per_hour:.0f} steps/h){note}')
        return

    cur_lr_calc = _cosine_lr(cur_step, base_lr, total_steps, warmup_steps)
    print(f'Current step {cur_step:,}: logged LR={cur_lr_logged:.2e} '
          f'(calc {cur_lr_calc:.2e})')
    for target_lr in (1e-5, 1e-6, 1e-8):
        # Invert the cosine schedule: lr = base * 0.5 * (1 + cos(pi * prog)).
        cos_v = max(-1.0, min(1.0, 2 * target_lr / base_lr - 1))
        prog = math.acos(cos_v) / math.pi
        step = warmup_steps + prog * (total_steps - warmup_steps)
        steps_to = step - cur_step
        if steps_to <= 0:
            print(f'LR={target_lr:.0e} already passed at step {int(step):,}')
            continue
        hours = steps_to * _TOKENS_PER_STEP / tps / 3600
        print(f'LR={target_lr:.0e} reached at step {int(step):,} '
              f'(+{int(steps_to):,} steps, ~{hours:.1f} h)')


def main(argv=None):
    """Analyze a continuation-training JSONL log and print a progress report.

    Args:
        argv: Optional list of command-line arguments (without the program
            name).  Defaults to ``sys.argv[1:]``.  The first argument, if
            present, is the log path; otherwise a default path is used.

    Prints the val_bpc trajectory, the recent recovery rate with a projection
    of when the target is reached, and a learning-rate forecast.  Prints
    ``'no eval events yet'`` and returns when the log has no eval records.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    path = args[0] if args else 'logs/fineweb_edu_v75_75M_continue.jsonl'

    # Analysis constants; the schedule values are presumably those of the
    # training run being analyzed — verify against the trainer config.
    target_bpc = 6.16
    base_lr = 1e-4
    total_steps = 400_000
    warmup_steps = 2000

    evals, trains = _load_events(path)
    if not evals:
        print('no eval events yet')
        return

    _print_trajectory(evals, target_bpc)
    _print_recovery(evals, trains, target_bpc)
    _print_lr_forecast(trains, base_lr, total_steps, warmup_steps)
|
|
|
|
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|