"""Analyze the continuation training jsonl: compute recovery rate, project when val_bpc will reach the original baseline of 6.16. Usage: # Pull latest log from VM scp -P 36911 root@ssh6.vast.ai:/root/bitnet1/logs/fineweb_edu_v75_75M_continue.jsonl logs/ python3 _analyze_continue_traj.py logs/fineweb_edu_v75_75M_continue.jsonl """ import json, sys, math def main(): path = sys.argv[1] if len(sys.argv) > 1 else 'logs/fineweb_edu_v75_75M_continue.jsonl' target_bpc = 6.16 # original converged best from previous run base_lr = 1e-4 total_steps = 400_000 warmup_steps = 2000 evals = [] trains = [] for line in open(path): try: d = json.loads(line) except Exception: continue ev = d.get('event') if ev == 'eval': evals.append((d['step'], d['val_bpc'])) elif ev == 'train': trains.append((d['step'], d['lr'], d['tok_per_s'])) if not evals: print('no eval events yet') return print(f'=== val_bpc trajectory (target: {target_bpc}) ===') print(f'{"step":>8s} {"val_bpc":>10s} {"Δ":>8s} {"gap":>7s}') prev = None for step, bpc in evals: delta = (bpc - prev[1]) if prev else 0.0 gap = bpc - target_bpc print(f'{step:>8d} {bpc:>10.4f} {delta:>+8.4f} {gap:>+7.3f}') prev = (step, bpc) # Recovery rate from last 3 windows if len(evals) >= 3: recent = evals[-3:] total_drop = recent[0][1] - recent[-1][1] total_steps_window = recent[-1][0] - recent[0][0] rate = total_drop / max(1, total_steps_window) # bpc per step print(f'\nRecent recovery rate: {rate*1000:+.4f} bpc/1K steps ' f'over last {total_steps_window} steps') if rate > 1e-6: current_bpc = evals[-1][1] current_step = evals[-1][0] steps_to_recover = (current_bpc - target_bpc) / rate target_step = current_step + steps_to_recover tps = trains[-1][2] if trains else 45000 seconds_remaining = steps_to_recover * 65536 / tps print(f'At this rate, val_bpc=6.16 reached at step ' f'~{int(target_step):,} (~{seconds_remaining/3600:.1f} h)') else: print('Recovery rate near zero — model has plateaued at this LR.') # LR forecast — detect constant vs cosine schedule from logged data print('\n=== LR forecast ===') if trains: cur_step = trains[-1][0] cur_lr_logged = trains[-1][1] # Sniff for constant LR (last 10 events all the same) recent_lrs = [lr for _, lr, _ in trains[-10:]] constant_mode = len(set(f'{l:.3e}' for l in recent_lrs)) == 1 if constant_mode: tps = trains[-1][2] sec_per_step = 65536 / tps steps_per_hour = 3600 / sec_per_step print(f'Current step {cur_step:,}: constant LR={cur_lr_logged:.2e}') print(f'(constant-LR run; cosine forecast not applicable)') print(f'Throughput: {tps:.0f} tok/s ({sec_per_step*1000:.0f} ms/step, ' f'{steps_per_hour:.0f} steps/h)') else: def lr_at(step): prog = (step - warmup_steps) / max(1, total_steps - warmup_steps) return base_lr * 0.5 * (1 + math.cos(math.pi * prog)) cur_lr_calc = lr_at(cur_step) print(f'Current step {cur_step:,}: logged LR={cur_lr_logged:.2e} ' f'(calc {cur_lr_calc:.2e})') for target_lr in [1e-5, 1e-6, 1e-8]: cos_v = 2 * target_lr / base_lr - 1 cos_v = max(-1.0, min(1.0, cos_v)) prog = math.acos(cos_v) / math.pi step = warmup_steps + prog * (total_steps - warmup_steps) steps_to = step - cur_step tps = trains[-1][2] hours = steps_to * 65536 / tps / 3600 print(f'LR={target_lr:.0e} reached at step {int(step):,} ' f'(+{int(steps_to):,} steps, ~{hours:.1f} h)') if __name__ == '__main__': main()