# bitnet-1bitllm / _analyze_continue_traj.py
# Author: hidude562
# Commit message: "1bitllm code (checkpoints to follow)"
# Commit: 4754707 (verified)
"""Analyze the continuation-training JSONL log: compute the recent val_bpc
recovery rate, project when val_bpc will return to the original baseline of
6.16, and forecast the learning-rate schedule.
Usage:
# Pull latest log from VM
scp -P 36911 root@ssh6.vast.ai:/root/bitnet1/logs/fineweb_edu_v75_75M_continue.jsonl logs/
python3 _analyze_continue_traj.py logs/fineweb_edu_v75_75M_continue.jsonl
"""
import json, sys, math
def main(path=None, *, target_bpc=6.16, base_lr=1e-4, total_steps=400_000,
         warmup_steps=2000, tokens_per_step=65536, default_tps=45000):
    """Print a val_bpc recovery report and an LR forecast for a training log.

    The log is JSONL; each line is a JSON object with an ``event`` key.
    ``eval`` events are read for ``step`` and ``val_bpc``; ``train`` events
    for ``step``, ``lr`` and ``tok_per_s``.  Lines that are not JSON objects,
    or events missing one of those fields, are skipped.

    Args:
        path: Log file to read.  When None, uses ``sys.argv[1]`` if given,
            else ``logs/fineweb_edu_v75_75M_continue.jsonl``.
        target_bpc: val_bpc the run is recovering towards (the original
            run's converged best).
        base_lr, total_steps, warmup_steps: Cosine-schedule parameters used
            for the LR forecast when the logged LR is not constant.
        tokens_per_step: Tokens consumed per optimizer step; converts step
            counts into wall-clock time together with the logged throughput.
        default_tps: Throughput (tokens/s) assumed when no positive
            ``tok_per_s`` has been logged in the latest train event.
    """
    if path is None:
        path = sys.argv[1] if len(sys.argv) > 1 else 'logs/fineweb_edu_v75_75M_continue.jsonl'
    evals, trains = _read_log(path)
    if not evals:
        print('no eval events yet')
        return
    tps = _current_tps(trains, default_tps)
    _print_trajectory(evals, target_bpc)
    _print_recovery(evals, target_bpc, tps, tokens_per_step)
    _print_lr_forecast(trains, tps, tokens_per_step,
                       base_lr=base_lr, total_steps=total_steps,
                       warmup_steps=warmup_steps)


def _read_log(path):
    """Parse the JSONL log at *path* into ``(evals, trains)`` lists.

    ``evals`` holds ``(step, val_bpc)`` and ``trains`` holds
    ``(step, lr, tok_per_s)``, both in file order.
    """
    evals = []
    trains = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            try:
                d = json.loads(line)
            except json.JSONDecodeError:
                continue  # blank or partially written line (log may still be growing)
            if not isinstance(d, dict):
                continue
            ev = d.get('event')
            try:
                if ev == 'eval':
                    evals.append((d['step'], d['val_bpc']))
                elif ev == 'train':
                    trains.append((d['step'], d['lr'], d['tok_per_s']))
            except KeyError:
                continue  # event lacks a field this report needs
    return evals, trains


def _current_tps(trains, default_tps):
    """Return the latest logged tokens/s, or *default_tps* if it is missing or non-positive."""
    tps = trains[-1][2] if trains else None
    # A zero/absent throughput would otherwise divide by zero in time estimates.
    return tps if tps and tps > 0 else default_tps


def _print_trajectory(evals, target_bpc):
    """Print one row per eval: step, val_bpc, change since previous eval, gap to target."""
    print(f'=== val_bpc trajectory (target: {target_bpc}) ===')
    print(f'{"step":>8s} {"val_bpc":>10s} {"Δ":>8s} {"gap":>7s}')
    prev_bpc = None
    for step, bpc in evals:
        delta = bpc - prev_bpc if prev_bpc is not None else 0.0
        print(f'{step:>8d} {bpc:>10.4f} {delta:>+8.4f} {bpc - target_bpc:>+7.3f}')
        prev_bpc = bpc


def _print_recovery(evals, target_bpc, tps, tokens_per_step):
    """Print the recovery rate over the last three evals and project when target_bpc is hit.

    Prints nothing when fewer than three evals exist.
    """
    if len(evals) < 3:
        return
    # Last three eval points span two eval intervals.
    (first_step, first_bpc), _, (cur_step, cur_bpc) = evals[-3:]
    window = cur_step - first_step
    rate = (first_bpc - cur_bpc) / max(1, window)  # bpc per step; positive = improving
    print(f'\nRecent recovery rate: {rate*1000:+.4f} bpc/1K steps '
          f'over last {window} steps')
    gap = cur_bpc - target_bpc
    if gap <= 0:
        # Projecting forward would yield a step in the past.
        print(f'val_bpc={cur_bpc:.4f} is already at/below target {target_bpc}.')
    elif rate > 1e-6:
        steps_to_recover = gap / rate
        hours = steps_to_recover * tokens_per_step / tps / 3600
        print(f'At this rate, val_bpc={target_bpc} reached at step '
              f'~{int(cur_step + steps_to_recover):,} (~{hours:.1f} h)')
    else:
        print('Recovery rate near zero — model has plateaued at this LR.')


def _print_lr_forecast(trains, tps, tokens_per_step, *, base_lr, total_steps, warmup_steps):
    """Print the current LR; for a non-constant LR, forecast when a cosine schedule hits 1e-5/1e-6/1e-8."""
    print('\n=== LR forecast ===')
    if not trains:
        return
    cur_step, cur_lr_logged, _ = trains[-1]
    sec_per_step = tokens_per_step / tps
    # Treat the run as constant-LR when the last 10 logged LRs agree to 4 significant digits.
    if len({f'{lr:.3e}' for _, lr, _ in trains[-10:]}) == 1:
        print(f'Current step {cur_step:,}: constant LR={cur_lr_logged:.2e}')
        print('(constant-LR run; cosine forecast not applicable)')
        print(f'Throughput: {tps:.0f} tok/s ({sec_per_step*1000:.0f} ms/step, '
              f'{3600 / sec_per_step:.0f} steps/h)')
        return

    span = max(1, total_steps - warmup_steps)

    def lr_at(step):
        # Half-cosine decay from base_lr to 0 over [warmup_steps, total_steps].
        # The warmup ramp is not modelled, so this is only meaningful after warmup.
        prog = (step - warmup_steps) / span
        return base_lr * 0.5 * (1 + math.cos(math.pi * prog))

    print(f'Current step {cur_step:,}: logged LR={cur_lr_logged:.2e} '
          f'(calc {lr_at(cur_step):.2e})')
    for target_lr in (1e-5, 1e-6, 1e-8):
        # Invert lr_at: cos(pi * prog) = 2 * lr / base_lr - 1, clamped into acos's domain.
        cos_v = max(-1.0, min(1.0, 2 * target_lr / base_lr - 1))
        step = warmup_steps + math.acos(cos_v) / math.pi * span
        steps_to = step - cur_step
        if steps_to <= 0:
            print(f'LR={target_lr:.0e} reached at step {int(step):,} (already passed)')
            continue
        hours = steps_to * sec_per_step / 3600
        print(f'LR={target_lr:.0e} reached at step {int(step):,} '
              f'(+{int(steps_to):,} steps, ~{hours:.1f} h)')
# Script entry point: run the report only when executed directly, not on import.
if __name__ == '__main__':
    main()