"""
GPU worker for 200k-checkpoint analysis.

Processes analysis tasks on a single GPU, reusing the model for all tasks.
Based on 100k-checkpoints/gpu_worker.py with grid-run intensity values.
"""

import argparse
import json
import os
import re
import sys
import time

import numpy as np
import torch

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'grid-run'))
from model_analysis import GPT, GPTConfig, GPTIntervention

# Intervention strengths swept by compute_intensity().
DEFAULT_INTENSITIES = (1.0, 2.0, 4.0, 6.0, 8.0, 10.0)

# Task types understood by process_task().
_TASK_TYPES = frozenset(
    {'baseline', 'ablation', 'cinclogits', 'intensity', 'intensity_asym'}
)

# Checkpoint key fragments that are renamed to the grid-run module names.
_ATTN_KEY = re.compile(r'transformer\.h\.(\d+)\.attn\.')
_MLP_KEY = re.compile(r'transformer\.h\.(\d+)\.mlp\.')


def remap_state_dict(sd):
    """Return a copy of ``sd`` with per-layer module names renamed.

    ``transformer.h.<i>.attn.`` becomes ``transformer.h.<i>.c_attn.`` and
    ``transformer.h.<i>.mlp.`` becomes ``transformer.h.<i>.c_fc.`` for every
    layer index ``<i>`` (not only 0-9). Values are passed through untouched
    and the input mapping is not modified.
    """
    new_sd = {}
    for key, val in sd.items():
        new_key = _ATTN_KEY.sub(r'transformer.h.\1.c_attn.', key)
        new_key = _MLP_KEY.sub(r'transformer.h.\1.c_fc.', new_key)
        new_sd[new_key] = val
    return new_sd


def load_model(ckpt_path, device):
    """Load a checkpoint into a grid-run ``GPT`` model in eval mode.

    Args:
        ckpt_path: Path of a ``torch.save``d dict holding ``model_config``
            and ``model_state_dict``.
        device: Device the model is moved to after loading.

    Returns:
        ``(model, config)`` where ``config`` is the ``GPTConfig`` used to
        build the model.
    """
    ckpt = torch.load(ckpt_path, map_location='cpu')
    mc = ckpt['model_config']
    # NOTE(review): the stored vocab size appears to count the separator
    # token that get_batch() adds on top of config.vocab_size -- confirm.
    vocab_size = mc['vocab_size'] - 1
    block_size = mc['block_size']
    with_layer_norm = mc.get('use_final_LN', True)
    config = GPTConfig(block_size=block_size, vocab_size=vocab_size,
                       with_layer_norm=with_layer_norm)
    model = GPT(config)

    sd = remap_state_dict(ckpt['model_state_dict'])

    # Truncate the positional table when the checkpoint has more rows than
    # the grid-run size used here (block_size * 4 + 1).
    grid_wpe_size = block_size * 4 + 1
    wpe = sd.get('transformer.wpe.weight')
    if wpe is not None and wpe.shape[0] > grid_wpe_size:
        sd['transformer.wpe.weight'] = wpe[:grid_wpe_size]

    # Drop '<layer>.c_attn.bias' entries (but keep 'c_attn.c_attn.bias').
    # NOTE(review): presumably these are attention-mask buffers rather than
    # learned parameters -- confirm against the training code.
    keys_to_skip = [k for k in sd
                    if k.endswith('.c_attn.bias') and 'c_attn.c_attn' not in k]
    for k in keys_to_skip:
        del sd[k]
    sd.pop('lm_head.weight', None)

    # strict=False: the remapped/pruned dict is not expected to match the
    # module's key set exactly.
    model.load_state_dict(sd, strict=False)
    model.to(device)
    model.eval()
    return model, config


def get_batch(vocab_size, block_size, device='cpu'):
    """Build one sorting example of shape ``(1, 2 * block_size + 1)``.

    The sequence is ``block_size`` distinct random tokens drawn from
    ``range(vocab_size)``, then the separator token ``vocab_size``, then
    the same tokens in ascending order.
    """
    x = torch.randperm(vocab_size)[:block_size]
    vals, _ = torch.sort(x)
    seq = torch.cat((x, torch.tensor([vocab_size]), vals), dim=0)
    return seq.unsqueeze(0).to(device)


def compute_cinclogits(model, config, device, attn_layer, num_tries=100):
    """Relate logit correctness to the most-attended token per position.

    For every output position of ``num_tries`` random examples, the token
    with the highest attention score in ``attn_layer`` (first one wins on
    ties) is compared with the target token.

    Returns:
        ``(cl, icl)``: two arrays of length ``block_size`` holding the
        fraction of trials where the most-attended token was NOT the target
        while the logit prediction was correct (``cl``) or incorrect
        (``icl``).
    """
    bs = config.block_size
    vs = config.vocab_size
    seq_len = 2 * bs + 1
    acc_cl = np.zeros(bs)
    acc_icl = np.zeros(bs)

    for _ in range(num_tries):
        idx = get_batch(vs, bs, device)
        with torch.no_grad():
            logits, _ = model(idx)
        preds = torch.argmax(logits[0, bs:2 * bs, :], dim=1)
        # Pull token ids and correctness flags to the host once per trial
        # instead of once per inner-loop iteration.
        logit_correct = (preds == idx[0, bs + 1:]).tolist()
        tokens = idx[0].tolist()
        # NOTE(review): assumes the forward pass leaves this layer's
        # attention scores in `.c_attn.attn`, indexable as [query, key].
        attn_w = model.transformer.h[attn_layer].c_attn.attn

        for j in range(bs, 2 * bs):
            max_s, max_n = float('-inf'), -1
            for k in range(seq_len):
                s = attn_w[j, k].item()
                if s > max_s:
                    max_s = s
                    max_n = tokens[k]
            score_correct = (max_n == tokens[j + 1])
            if score_correct:
                continue
            pos = j - bs
            if logit_correct[pos]:
                acc_cl[pos] += 1.0
            else:
                acc_icl[pos] += 1.0

    return acc_cl / num_tries, acc_icl / num_tries


def compute_intensity(model, config, device, attn_layer, ub=5, lb=None,
                      ub_num=1, lb_num=0, min_valid=200, max_rounds=2000):
    """Measure how often the model survives attention interventions.

    For each strength in ``DEFAULT_INTENSITIES`` random examples are drawn
    until ``min_valid`` interventions have been evaluated or ``max_rounds``
    examples have been tried. Examples for which the intervention raises
    are skipped and do not count as attempts.

    Args:
        attn_layer: Index of the attention layer to intervene on.
        ub, lb: Passed as ``unsorted_ub`` / ``unsorted_lb``; ``lb`` defaults
            to ``ub``.
        ub_num, lb_num: Passed as ``unsorted_ub_num`` / ``unsorted_lb_num``.
        min_valid: Number of evaluated interventions wanted per strength.
        max_rounds: Upper bound on examples drawn per strength.

    Returns:
        ``(intensities, success_rates, counts)`` as NumPy arrays; the rate
        is 0.0 for a strength with no evaluated intervention.
    """
    if lb is None:
        lb = ub
    bs = config.block_size
    vs = config.vocab_size
    location = bs + 5  # fixed intervention position, as in the original
    intensities = list(DEFAULT_INTENSITIES)

    rates, counts = [], []
    for intens in intensities:
        attempts, rounds = [], 0
        while len(attempts) < min_valid and rounds < max_rounds:
            rounds += 1
            idx = get_batch(vs, bs, device)
            try:
                im = GPTIntervention(model, idx)
                im.intervent_attention(
                    attention_layer_num=attn_layer,
                    location=location,
                    unsorted_lb=lb,
                    unsorted_ub=ub,
                    unsorted_lb_num=lb_num,
                    unsorted_ub_num=ub_num,
                    unsorted_intensity_inc=intens,
                    sorted_lb=0,
                    sorted_num=0,
                    sorted_intensity_inc=0.0)
                try:
                    g, n = im.check_if_still_works()
                    attempts.append(g == n)
                finally:
                    # Undo the intervention even when the check raised, so
                    # a failed sample cannot affect the following ones.
                    im.revert_attention(attn_layer)
            except Exception:
                # Best effort: an example the intervention cannot handle is
                # skipped. KeyboardInterrupt/SystemExit still propagate.
                continue
        counts.append(len(attempts))
        rates.append(sum(attempts) / len(attempts) if attempts else 0.0)

    return np.array(intensities), np.array(rates), np.array(counts)


def _sequence_accuracy(model, config, device, num_trials):
    """Accumulate sorting-accuracy statistics over random examples.

    Returns:
        ``(per_pos_acc, full_seq_acc, cond_acc, cond_eligible)`` where
        ``cond_eligible[i]`` counts trials in which all positions before
        ``i`` were correct and ``cond_acc[i]`` is the accuracy at ``i``
        over those trials (0.0 where no trial was eligible).
    """
    bs = config.block_size
    vs = config.vocab_size
    per_pos = np.zeros(bs)
    full_seq = 0
    cond_correct = np.zeros(bs)
    cond_eligible = np.zeros(bs)

    for _ in range(num_trials):
        idx = get_batch(vs, bs, device)
        with torch.no_grad():
            logits, _ = model(idx)
        preds = torch.argmax(logits[0, bs:2 * bs, :], dim=1)
        targets = idx[0, bs + 1:]
        correct = (preds == targets).cpu().numpy()

        per_pos += correct
        wrong = np.flatnonzero(~correct)
        if wrong.size == 0:
            full_seq += 1
            first_wrong = bs
        else:
            first_wrong = int(wrong[0])
        # Positions up to and including the first mistake were reached with
        # a fully correct prefix; positions before it were also correct.
        cond_eligible[:first_wrong + 1] += 1
        cond_correct[:first_wrong] += 1

    # Masked division: same values as np.where(ce > 0, cc / ce, 0.0) but
    # without evaluating 0/0 (and its RuntimeWarning).
    cond_acc = np.divide(cond_correct, cond_eligible,
                         out=np.zeros_like(cond_correct),
                         where=cond_eligible > 0)
    return per_pos / num_trials, full_seq / num_trials, cond_acc, cond_eligible


def compute_ablation(model, config, device, skip_layer, num_trials=500):
    """Accuracy statistics with the attention of one block bypassed.

    The forward of ``model.transformer.h[skip_layer]`` is temporarily
    replaced by ``x + c_fc(ln_2(x))`` and restored afterwards, also when an
    error occurs.

    Returns:
        Same 4-tuple as ``compute_baseline``.
    """
    block = model.transformer.h[skip_layer]
    orig_fwd = block.forward

    def skip_attn(x, layer_n=-1):
        return x + block.c_fc(block.ln_2(x))

    block.forward = skip_attn
    try:
        return _sequence_accuracy(model, config, device, num_trials)
    finally:
        block.forward = orig_fwd


def compute_baseline(model, config, device, num_trials=500):
    """Accuracy statistics of the unmodified model.

    Returns:
        ``(per_pos_acc, full_seq_acc, cond_acc, cond_eligible)``; see
        ``_sequence_accuracy``.
    """
    return _sequence_accuracy(model, config, device, num_trials)


def _npz_path(out_path):
    """Return the file name np.savez would write for ``out_path``."""
    return out_path if out_path.endswith('.npz') else out_path + '.npz'


def _save_npz(out_path, **arrays):
    """Write ``arrays`` as an .npz archive via a temp file and rename.

    The rename makes the final file appear only once it is complete, so an
    interrupted run cannot leave a truncated result that later runs skip.
    """
    final_path = _npz_path(out_path)
    tmp_path = final_path + '.tmp'
    try:
        with open(tmp_path, 'wb') as fh:
            np.savez(fh, **arrays)
        os.replace(tmp_path, final_path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def process_task(task, model, config, device, itr):
    """Run one analysis task and store its result as an .npz archive.

    Args:
        task: Dict with at least ``type`` and ``out`` plus the type-specific
            keys (``layer``, ``ub``, ``unsorted_*``).
        model, config, device: Model under analysis and where it lives.
        itr: Training iteration recorded alongside the results.

    Returns:
        True, also when the output already exists and the task is skipped.

    Raises:
        ValueError: If ``task['type']`` is not a known task type.
    """
    task_type = task['type']
    out_path = task['out']
    # np.savez appends '.npz' when missing, so look for both spellings.
    if os.path.exists(out_path) or os.path.exists(_npz_path(out_path)):
        return True
    if task_type not in _TASK_TYPES:
        raise ValueError(f'unknown task type: {task_type!r}')

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if task_type == 'baseline':
        pp, fs, ca, ce = compute_baseline(model, config, device)
        _save_npz(out_path, per_pos_acc=pp, full_seq_acc=fs,
                  cond_acc=ca, cond_eligible=ce, itr=itr)
    elif task_type == 'ablation':
        pp, fs, ca, ce = compute_ablation(model, config, device, task['layer'])
        _save_npz(out_path, per_pos_acc=pp, full_seq_acc=fs,
                  cond_acc=ca, cond_eligible=ce,
                  skip_layer=task['layer'], itr=itr)
    elif task_type == 'cinclogits':
        cl, icl = compute_cinclogits(model, config, device, task['layer'])
        _save_npz(out_path, clogit_icscore=cl, iclogit_icscore=icl, itr=itr)
    elif task_type == 'intensity':
        intensities, rates, counts = compute_intensity(
            model, config, device, task['layer'], ub=task['ub'])
        _save_npz(out_path, intensities=intensities,
                  success_rates=rates, counts=counts, itr=itr)
    else:  # 'intensity_asym'
        intensities, rates, counts = compute_intensity(
            model, config, device, task['layer'],
            ub=task['unsorted_ub'], lb=task['unsorted_lb'],
            ub_num=task['unsorted_ub_num'], lb_num=task['unsorted_lb_num'])
        _save_npz(out_path, intensities=intensities,
                  success_rates=rates, counts=counts, itr=itr)

    return True


def main():
    """Run every task of a JSON task file on one GPU.

    One JSON status line is printed per task. A failing task (including a
    checkpoint that cannot be loaded) is reported and the worker moves on
    to the next task.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks-file', required=True)
    parser.add_argument('--gpu', type=int, required=True)
    args = parser.parse_args()

    # Takes effect because no CUDA call has been made yet at this point.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    device = 'cuda'

    with open(args.tasks_file) as f:
        task_list = json.load(f)

    print(f"GPU {args.gpu}: {len(task_list)} tasks", flush=True)

    current_model = None
    current_config = None
    current_ckpt = None
    total = len(task_list)
    done = 0

    for task in task_list:
        task_name = task.get('name', '<unnamed>')
        t0 = time.time()
        try:
            ckpt_path = task['ckpt_path']
            if ckpt_path != current_ckpt:
                # Drop the previous model first, and forget the cached path
                # so a failed load is retried by the next task.
                current_model = current_config = current_ckpt = None
                load_t0 = time.time()
                current_model, current_config = load_model(ckpt_path, device)
                current_ckpt = ckpt_path
                print(f" Loaded {os.path.basename(ckpt_path)} "
                      f"({time.time()-load_t0:.1f}s)", flush=True)
                t0 = time.time()  # elapsed excludes checkpoint loading

            itr = task.get('itr', 200000)
            process_task(task, current_model, current_config, device, itr)
            dt = time.time() - t0
            done += 1
            print(json.dumps({
                'status': 'done',
                'task': task_name,
                'gpu': args.gpu,
                'elapsed': round(dt, 1),
                'progress': f'{done}/{total}'
            }), flush=True)
        except Exception as e:
            done += 1
            print(json.dumps({
                'status': 'fail',
                'task': task_name,
                'gpu': args.gpu,
                'error': str(e),
                'error_type': type(e).__name__
            }), flush=True)


if __name__ == '__main__':
    main()