# llm-sort / run_pernumber.py
# gatmiry's picture
# Upload folder using huggingface_hub
# beda614 verified
"""
Per-number intervention experiment for 100k checkpoints.
Fixes a target number (value) from the vocabulary, generates sequences
containing it, finds its sorted-output position, and intervenes there.
5 target numbers spread across vocab: 25, 75, 128, 180, 230.
Intensity values: [1.0, 2.0, 4.0, 6.0, 8.0, 10.0], ub=lb=60.
One worker per GPU, plots assembled on the fly.
"""
import json
import os
import subprocess
import sys
import time
import glob
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# Every path is anchored at this script's directory, so the launcher behaves
# the same regardless of the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
CKPT_DIR, OUTPUT_BASE = (
    os.path.join(SCRIPT_DIR, sub) for sub in ('final_models', 'outputs')
)
# Degree of parallelism: one worker subprocess is launched per GPU index.
NUM_GPUS = 8
# Vocabulary values that are intervened on (one curve per value in each plot).
TARGET_NUMBERS = [25, 75, 128, 180, 230]
# Intervention intensities; within this script they serve as the plots' x ticks.
INTENSITY_VALUES = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0]
def discover_checkpoints(ckpt_dir=None):
    """Find evaluable ``*.pt`` checkpoints and derive per-checkpoint metadata.

    File names are parsed as ``<config>[__<type>].pt``:
    ``<config>`` is split on ``_`` and searched for ``dseed<X>`` / ``iseed<Y>``
    tokens; a ``<type>`` starting with ``ckpt`` is an intermediate checkpoint
    ``ckpt<itr>``, anything else (including no ``__`` part) is treated as the
    final model at iteration 100000. Files whose name contains ``.summary.``
    are skipped.

    Args:
        ckpt_dir: Directory to scan. Defaults to the module-level ``CKPT_DIR``.

    Returns:
        A list, ordered by file path, of dicts with keys ``path``, ``dseed``,
        ``iseed`` (strings, or None when the token is absent), ``itr`` (int),
        ``ckpt_label`` and ``folder_name``.

    Raises:
        ValueError: if a ``ckpt`` type suffix is not followed by an integer.
    """
    if ckpt_dir is None:
        ckpt_dir = CKPT_DIR
    checkpoints = []
    for pt in sorted(glob.glob(os.path.join(ckpt_dir, '*.pt'))):
        bn = os.path.basename(pt)
        if '.summary.' in bn:
            continue
        # Strip only the trailing extension; str.replace('.pt', '') would also
        # delete any '.pt' that happens to occur inside the config string.
        stem = os.path.splitext(bn)[0]
        parts = stem.split('__')
        config_str = parts[0]
        ckpt_type = parts[1] if len(parts) > 1 else 'final'
        dseed = iseed = None
        for tok in config_str.split('_'):
            if tok.startswith('dseed'):
                dseed = tok[len('dseed'):]
            elif tok.startswith('iseed'):
                iseed = tok[len('iseed'):]
        if ckpt_type.startswith('ckpt'):
            itr = int(ckpt_type[len('ckpt'):])
            ckpt_label = f'ckpt{itr}'
        else:
            # Anything that is not an intermediate 'ckpt<N>' is the final model.
            itr = 100000
            ckpt_label = 'final'
        folder_name = f"plots_N256_B16_ds{dseed}_is{iseed}_{ckpt_label}"
        checkpoints.append({
            'path': pt, 'dseed': dseed, 'iseed': iseed,
            'itr': itr, 'ckpt_label': ckpt_label, 'folder_name': folder_name,
        })
    return checkpoints
def make_tasks(checkpoints):
    """Expand checkpoints into one worker task per (target number, layer) pair.

    Tasks are ordered by checkpoint (input order), then by ``TARGET_NUMBERS``
    order, then layer 0 before layer 1. Each task's ``out`` points at
    ``OUTPUT_BASE/tmp_results/<folder_name>/pernumber/intensity_num<N>_layer<L>.npz``.
    """
    def _build(ckpt, num, layer):
        folder = ckpt['folder_name']
        result_dir = os.path.join(OUTPUT_BASE, 'tmp_results', folder, 'pernumber')
        return {
            'ckpt_path': ckpt['path'],
            'folder_name': folder,
            'target_num': num,
            'layer': layer,
            'out': os.path.join(result_dir, f'intensity_num{num}_layer{layer}.npz'),
            'itr': ckpt['itr'],
            'dseed': ckpt['dseed'],
            'iseed': ckpt['iseed'],
            'name': f"{folder}_num{num}_L{layer}",
        }

    return [
        _build(ckpt, num, layer)
        for ckpt in checkpoints
        for num in TARGET_NUMBERS
        for layer in (0, 1)
    ]
def is_ckpt_done(ckpt):
    """Return True once every (target number, layer) result file for *ckpt* exists."""
    result_dir = os.path.join(OUTPUT_BASE, 'tmp_results', ckpt['folder_name'], 'pernumber')
    return all(
        os.path.exists(os.path.join(result_dir, f'intensity_num{num}_layer{layer}.npz'))
        for num in TARGET_NUMBERS
        for layer in (0, 1)
    )
def assemble_plots_for_ckpt(ckpt):
    """Plot success rate vs. intervention intensity for one checkpoint.

    For each layer (0 and 1), draws one curve per target number from
    ``OUTPUT_BASE/tmp_results/<folder_name>/pernumber/intensity_num<N>_layer<L>.npz``
    (arrays ``intensities`` and ``success_rates``) and saves
    ``OUTPUT_BASE/<folder_name>/pernumber/pernumber_layer<L>.png``.
    Result files that do not exist yet are skipped, so a partially finished
    checkpoint still yields (incomplete) plots.

    Returns:
        The number of figures written (always 2, one per layer).
    """
    folder_name = ckpt['folder_name']
    tmp_dir = os.path.join(OUTPUT_BASE, 'tmp_results', folder_name, 'pernumber')
    plot_dir = os.path.join(OUTPUT_BASE, folder_name, 'pernumber')
    os.makedirs(plot_dir, exist_ok=True)
    tag = (f"N=256 block=16 lr=0.01 iters={ckpt['itr']} "
           f"dseed={ckpt['dseed']} iseed={ckpt['iseed']}")
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
    markers = ['o', 's', '^', 'D', 'v']
    for layer in (0, 1):
        fig, ax = plt.subplots(figsize=(5.5, 4))
        try:
            n_curves = 0
            for i, num in enumerate(TARGET_NUMBERS):
                f = os.path.join(tmp_dir, f'intensity_num{num}_layer{layer}.npz')
                if not os.path.exists(f):
                    continue
                # np.load on an .npz keeps the archive open; the context
                # manager closes it (arrays are read into memory on access).
                with np.load(f) as d:
                    # Cycle styles so more than 5 target numbers cannot IndexError.
                    ax.plot(d['intensities'], d['success_rates'],
                            marker=markers[i % len(markers)], linewidth=1.5,
                            markersize=5, color=colors[i % len(colors)],
                            label=f'number {num}')
                n_curves += 1
            ax.set_xlabel('Intervention Intensity', fontsize=10)
            ax.set_ylabel('Success Probability', fontsize=10)
            title = f'Per-Number Intervention (Layer {layer}, ub=60)\n{tag}'
            ax.set_title(title, fontsize=10, fontweight='bold')
            if n_curves:
                # legend() warns when there are no labelled artists.
                ax.legend(fontsize=8, loc='lower left')
            ax.grid(True, alpha=0.3)
            ax.set_xticks(INTENSITY_VALUES)
            ax.set_ylim(0, 1.05)
            fig.tight_layout()
            fig.savefig(os.path.join(plot_dir, f'pernumber_layer{layer}.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            # Close this specific figure (not the implicit "current" one) even
            # on error, so the long-running monitor loop does not leak figures.
            plt.close(fig)
    return 2
def _assemble_ready(checkpoints, assembled, describe):
    """Plot every finished checkpoint not yet in *assembled*, record it, and log it.

    *describe* maps a folder name to the log line printed after plotting.
    """
    for ckpt in checkpoints:
        fn = ckpt['folder_name']
        if fn in assembled or not is_ckpt_done(ckpt):
            continue
        assemble_plots_for_ckpt(ckpt)
        assembled.add(fn)
        print(describe(fn), flush=True)


def _distribute_tasks(checkpoints, todo):
    """Assign pending tasks to GPUs, keeping each checkpoint's tasks on one GPU.

    Checkpoints are dealt round-robin over ``NUM_GPUS`` in discovery order.
    Each GPU's list is sorted by (checkpoint path, target number, layer), which
    groups a checkpoint's tasks contiguously.
    """
    ckpt_to_gpu = {ckpt['folder_name']: i % NUM_GPUS
                   for i, ckpt in enumerate(checkpoints)}
    gpu_tasks = {g: [] for g in range(NUM_GPUS)}
    for t in todo:
        gpu_tasks[ckpt_to_gpu[t['folder_name']]].append(t)
    for tasks in gpu_tasks.values():
        tasks.sort(key=lambda t: (t['ckpt_path'], t['target_num'], t['layer']))
    return gpu_tasks


def _launch_workers(gpu_tasks):
    """Write a JSON task file per busy GPU and start ``pernumber_worker.py`` on it.

    Returns:
        ``(procs, log_files)``: ``procs`` maps GPU index to its ``Popen``;
        ``log_files`` are this process's handles on the worker logs, which the
        caller is responsible for closing.
    """
    task_dir = os.path.join(OUTPUT_BASE, 'task_files')
    os.makedirs(task_dir, exist_ok=True)
    log_dir = os.path.join(OUTPUT_BASE, 'pernumber_logs')
    os.makedirs(log_dir, exist_ok=True)
    procs = {}
    log_files = []
    for g in range(NUM_GPUS):
        if not gpu_tasks[g]:
            continue
        tf = os.path.join(task_dir, f'pernum_gpu{g}.json')
        with open(tf, 'w') as f:
            json.dump(gpu_tasks[g], f)
        log_file = open(os.path.join(log_dir, f'gpu{g}.log'), 'w')
        log_files.append(log_file)
        procs[g] = subprocess.Popen(
            [sys.executable, os.path.join(SCRIPT_DIR, 'pernumber_worker.py'),
             '--tasks-file', tf, '--gpu', str(g)],
            stdout=log_file, stderr=subprocess.STDOUT, cwd=SCRIPT_DIR)
    return procs, log_files


def main():
    """Run all missing per-number tasks on GPU workers and plot results as they land.

    Result files that already exist are treated as cached and not recomputed.
    Checkpoints whose results are all present are plotted immediately; the
    rest are assigned to one worker subprocess per GPU. The workers are polled
    every 5 s, each checkpoint is plotted as soon as its last result appears,
    and progress/ETA is printed every 10 completed tasks.
    """
    t_start = time.time()
    checkpoints = discover_checkpoints()
    print(f"Found {len(checkpoints)} checkpoints")
    all_tasks = make_tasks(checkpoints)
    todo = [t for t in all_tasks if not os.path.exists(t['out'])]
    cached = len(all_tasks) - len(todo)
    print(f"Total tasks: {len(all_tasks)}, cached: {cached}, to run: {len(todo)}")

    assembled = set()
    _assemble_ready(checkpoints, assembled,
                    lambda fn: f" [PLOTS] {fn} (cached)")
    if not todo:
        print("All done!")
        return

    gpu_tasks = _distribute_tasks(checkpoints, todo)
    print(f"\nDistributed {len(todo)} tasks across {NUM_GPUS} GPUs:")
    for g in range(NUM_GPUS):
        n = len(gpu_tasks[g])
        ckpts = len(set(t['ckpt_path'] for t in gpu_tasks[g])) if n else 0
        print(f" GPU {g}: {n} tasks across {ckpts} checkpoints")

    procs, log_files = _launch_workers(gpu_tasks)
    try:
        t_launch = time.time()
        print(f"\nLaunched {len(procs)} workers. Monitoring...\n", flush=True)
        last_print = 0
        while any(p.poll() is None for p in procs.values()):
            time.sleep(5)
            _assemble_ready(
                checkpoints, assembled,
                lambda fn: f" [PLOTS] {fn}: 2 plots ({time.time() - t_start:.0f}s)")
            done_now = sum(1 for t in all_tasks if os.path.exists(t['out']))
            now = time.time()
            elapsed = now - t_start
            if done_now >= last_print + 10:
                last_print = done_now
                # Only tasks completed since launch count toward the rate;
                # counting cached results made the ETA far too optimistic
                # when resuming a partially finished run.
                ran = done_now - cached
                run_time = now - t_launch
                rate = ran / run_time if run_time > 0 else 0
                eta = (len(all_tasks) - done_now) / rate if rate > 0 else 0
                print(f" [PROGRESS] {done_now}/{len(all_tasks)} tasks, "
                      f"{len(assembled)}/{len(checkpoints)} ckpts plotted "
                      f"({elapsed:.0f}s, ETA ~{eta:.0f}s)", flush=True)
    finally:
        # The workers hold their own descriptors; closing ours releases the
        # handles this process opened in _launch_workers.
        for fh in log_files:
            fh.close()

    _assemble_ready(checkpoints, assembled,
                    lambda fn: f" [PLOTS] {fn}: 2 plots (final)")
    for g, proc in procs.items():
        if proc.returncode != 0:
            print(f" [WARN] GPU {g} exited with code {proc.returncode}", flush=True)
    elapsed = time.time() - t_start
    print(f"\n{'='*60}")
    print(f"ALL DONE — {len(assembled)}/{len(checkpoints)} checkpoints plotted")
    print(f"Elapsed: {elapsed:.0f}s ({elapsed/60:.1f}m)")
    print(f"{'='*60}")


if __name__ == '__main__':
    main()