"""
GPU worker for 1000k-checkpoint analysis.

Processes all task types on a single GPU:
baseline, ablation, cinclogits, intensity (various ub),
asymmetric intensity, hijack, separator/random.
"""

import argparse
import json
import os
import re
import sys
import time
import types

import numpy as np
import torch
import torch.nn.functional as F

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'grid-run'))
from model_analysis import GPT, GPTConfig, GPTIntervention

# Matches per-layer sub-module prefixes for any layer index (not just 0-9).
_ATTN_KEY_RE = re.compile(r'(transformer\.h\.\d+\.)attn\.')
_MLP_KEY_RE = re.compile(r'(transformer\.h\.\d+\.)mlp\.')


def remap_state_dict(sd):
    """Return a copy of *sd* with per-layer key prefixes renamed.

    ``transformer.h.<i>.attn.`` becomes ``transformer.h.<i>.c_attn.`` and
    ``transformer.h.<i>.mlp.`` becomes ``transformer.h.<i>.c_fc.`` for every
    layer index ``<i>``. Values are passed through untouched.
    """
    new_sd = {}
    for key, val in sd.items():
        new_key = _ATTN_KEY_RE.sub(r'\1c_attn.', key)
        new_key = _MLP_KEY_RE.sub(r'\1c_fc.', new_key)
        new_sd[new_key] = val
    return new_sd


def load_model(ckpt_path, device):
    """Build a ``GPT`` from the checkpoint at *ckpt_path* and move it to *device*.

    Returns:
        ``(model, config)`` with the model in eval mode.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    mc = ckpt['model_config']
    # get_batch() uses token id == config.vocab_size as the separator, so the
    # checkpoint's vocab presumably counts that extra token -- TODO confirm.
    vocab_size = mc['vocab_size'] - 1
    block_size = mc['block_size']
    with_layer_norm = mc.get('use_final_LN', True)
    config = GPTConfig(block_size=block_size, vocab_size=vocab_size,
                       with_layer_norm=with_layer_norm)
    model = GPT(config)

    sd = remap_state_dict(ckpt['model_state_dict'])

    # Truncate a longer positional-embedding table to the size used here.
    grid_wpe_size = block_size * 4 + 1
    wpe_key = 'transformer.wpe.weight'
    if wpe_key in sd and sd[wpe_key].shape[0] > grid_wpe_size:
        sd[wpe_key] = sd[wpe_key][:grid_wpe_size]

    # Drop '<layer>.c_attn.bias' entries (but keep '<layer>.c_attn.c_attn.bias').
    keys_to_skip = [k for k in sd
                    if k.endswith('.c_attn.bias') and 'c_attn.c_attn' not in k]
    for k in keys_to_skip:
        del sd[k]
    sd.pop('lm_head.weight', None)

    # strict=False: the keys removed above are intentionally absent.
    model.load_state_dict(sd, strict=False)
    model.to(device)
    model.eval()
    return model, config


def get_batch(vocab_size, block_size, device='cpu'):
    """Return one ``(1, 2 * block_size + 1)`` sample: unsorted, separator, sorted.

    The unsorted part is ``block_size`` distinct values drawn from
    ``range(vocab_size)``; the separator token id is ``vocab_size``.
    """
    x = torch.randperm(vocab_size)[:block_size]
    vals, _ = torch.sort(x)
    return torch.cat((x, torch.tensor([vocab_size]), vals), dim=0).unsqueeze(0).to(device)


def _evaluate_sorting(model, config, device, num_trials):
    """Run *num_trials* random samples and accumulate accuracy statistics.

    Returns:
        ``(per_pos_acc, full_seq_acc, cond_acc, cond_eligible)`` where
        ``cond_eligible[i]`` counts trials in which every position before ``i``
        was predicted correctly and ``cond_acc[i]`` is the accuracy at ``i``
        over those trials (0.0 where no trial was eligible).
    """
    bs = config.block_size
    vs = config.vocab_size
    per_pos = np.zeros(bs)
    full_seq = 0
    cond_correct = np.zeros(bs)
    cond_eligible = np.zeros(bs)

    for _ in range(num_trials):
        idx = get_batch(vs, bs, device)
        with torch.no_grad():
            logits, _ = model(idx)
        preds = torch.argmax(logits[0, bs:2 * bs, :], dim=1)
        targets = idx[0, bs + 1:]
        correct = (preds == targets).cpu().numpy()
        per_pos += correct

        # Length of the leading run of correct predictions.
        wrong = np.flatnonzero(~correct)
        n_ok = int(wrong[0]) if wrong.size else bs
        if n_ok == bs:
            full_seq += 1
        cond_correct[:n_ok] += 1
        # The first wrong position (if any) was still eligible.
        cond_eligible[:n_ok + 1] += 1

    # Divide only where eligible to avoid 0/0 warnings.
    cond_acc = np.divide(cond_correct, cond_eligible,
                         out=np.zeros_like(cond_correct), where=cond_eligible > 0)
    return per_pos / num_trials, full_seq / num_trials, cond_acc, cond_eligible


def compute_baseline(model, config, device, num_trials=500):
    """Accuracy statistics of the unmodified model (see ``_evaluate_sorting``)."""
    return _evaluate_sorting(model, config, device, num_trials)


def compute_ablation(model, config, device, skip_layer, num_trials=500):
    """Accuracy statistics with block *skip_layer* reduced to its ``c_fc`` path.

    The block's ``forward`` is temporarily replaced by
    ``x + c_fc(ln_2(x))`` and always restored afterwards.
    """
    block = model.transformer.h[skip_layer]
    orig_fwd = block.forward

    def skip_attn(x, layer_n=-1):
        return x + block.c_fc(block.ln_2(x))

    block.forward = skip_attn
    try:
        return _evaluate_sorting(model, config, device, num_trials)
    finally:
        block.forward = orig_fwd


def compute_cinclogits(model, config, device, attn_layer, num_tries=100):
    """Per-position rate of "attention argmax is wrong", split by logit outcome.

    For each output position the key position with the highest attention
    weight in layer *attn_layer* is looked up; the "score" is correct when the
    token there equals the next target token.

    Returns:
        ``(clogit_icscore, iclogit_icscore)``: fraction of tries with an
        incorrect score and a correct / incorrect logit prediction respectively.
    """
    bs = config.block_size
    vs = config.vocab_size
    acc_cl = np.zeros(bs)
    acc_icl = np.zeros(bs)

    for _ in range(num_tries):
        idx = get_batch(vs, bs, device)
        with torch.no_grad():
            logits, _ = model(idx)
        is_correct = (torch.argmax(logits[0, bs:2 * bs, :], dim=1) == idx[0, bs + 1:]).tolist()
        attn_w = model.transformer.h[attn_layer].c_attn.attn

        # One host transfer instead of an .item() sync per matrix element.
        tokens = idx[0].tolist()
        attn_rows = attn_w[bs:2 * bs, :2 * bs + 1].tolist()

        for pos, row in enumerate(attn_rows):
            j = bs + pos
            # Strict '>' keeps the first maximum, as in a left-to-right scan.
            max_s, max_n = float('-inf'), -1
            for k, s in enumerate(row):
                if s > max_s:
                    max_s = s
                    max_n = tokens[k]
            sc = (max_n == tokens[j + 1])
            lc = is_correct[pos]
            if lc and not sc:
                acc_cl[pos] += 1.0
            elif not lc and not sc:
                acc_icl[pos] += 1.0

    return acc_cl / num_tries, acc_icl / num_tries


def _intervention_still_works(model, idx, attn_layer, location, *,
                              lb, ub, lb_num, ub_num, intensity):
    """Apply one unsorted-side attention intervention and report the outcome.

    Returns ``g == n`` for the pair returned by ``check_if_still_works``.
    Propagates any exception raised by ``GPTIntervention``.
    """
    im = GPTIntervention(model, idx)
    im.intervent_attention(
        attention_layer_num=attn_layer, location=location,
        unsorted_lb=lb, unsorted_ub=ub,
        unsorted_lb_num=lb_num, unsorted_ub_num=ub_num,
        unsorted_intensity_inc=intensity,
        sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0)
    try:
        g, n = im.check_if_still_works()
    finally:
        # NOTE(review): assumes revert_attention undoes intervent_attention on
        # the shared model, so it must run even if the check fails -- confirm.
        im.revert_attention(attn_layer)
    return g == n


def compute_intensity(model, config, device, attn_layer, ub=5, lb=None,
                      ub_num=1, lb_num=0, min_valid=200):
    """Success rate of an attention intervention at several intensities.

    For each intensity, random samples are drawn until *min_valid*
    interventions completed without error or 2000 samples were tried.
    *lb* defaults to *ub*.

    Returns:
        ``(intensities, success_rates, counts)`` as NumPy arrays; ``counts``
        holds the number of completed interventions per intensity.
    """
    if lb is None:
        lb = ub
    bs = config.block_size
    vs = config.vocab_size
    location = bs + 5
    intensities = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0]
    max_rounds = 2000
    rates, counts = [], []

    for intens in intensities:
        attempts, rounds = [], 0
        while len(attempts) < min_valid and rounds < max_rounds:
            rounds += 1
            idx = get_batch(vs, bs, device)
            try:
                attempts.append(_intervention_still_works(
                    model, idx, attn_layer, location,
                    lb=lb, ub=ub, lb_num=lb_num, ub_num=ub_num, intensity=intens))
            except Exception:
                # Best effort: a sample the intervention cannot handle is
                # skipped (presumably no candidate in range -- TODO confirm).
                continue
        counts.append(len(attempts))
        rates.append(sum(attempts) / len(attempts) if attempts else 0.0)

    return np.array(intensities), np.array(rates), np.array(counts)


def _make_hijack_forward(loc, bidx, score):
    """Build an attention ``forward`` that pins raw score ``[loc, bidx]`` to *score*.

    The returned function is meant to be bound to the attention module with
    ``types.MethodType``; it overwrites one pre-softmax entry in every head
    before masking and softmax.
    """
    def new_forward(self_attn, x, layer_n=-1):
        B, T, C = x.size()
        n_heads = self_attn.n_heads
        qkv = self_attn.c_attn(x)
        q, k, v = qkv.split(self_attn.n_embd, dim=2)
        q = q.view(B, T, n_heads, C // n_heads).transpose(1, 2)
        k = k.view(B, T, n_heads, C // n_heads).transpose(1, 2)
        v = v.view(B, T, n_heads, C // n_heads).transpose(1, 2)
        attn = q @ k.transpose(-1, -2) * 0.1 / (k.size(-1)) ** 0.5
        attn[:, :, loc, bidx] = score
        attn = attn.masked_fill(self_attn.bias[:, :, :T, :T] == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        y = attn @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self_attn.c_proj(y)
        return y
    return new_forward


def compute_hijack(model, config, device, n_trials=2000):
    """Hijack intervention on layer 0.

    For every sorted position, the raw attention score towards one randomly
    chosen *wrong* unsorted token is set to the score of the correct next
    token plus a fixed boost, and the model's prediction is recorded.

    Returns array of (current, boosted, predicted, correct), dtype int32,
    shape ``(N, 4)``.
    """
    INTENSITY = 10.0
    bs = config.block_size
    vs = config.vocab_size
    attn_module = model.transformer.h[0].c_attn
    records = []

    for _ in range(n_trials):
        idx = get_batch(vs, bs, device)
        with torch.no_grad():
            model(idx)
        raw_attn = attn_module.raw_attn.clone()

        # Host copies avoid an .item() device sync per token comparison.
        tokens = idx[0].tolist()
        unsorted_vals = tokens[:bs]

        for p in range(bs - 1):
            location = bs + 1 + p
            current_num = tokens[location]
            correct_next = tokens[location + 1]
            if correct_next not in unsorted_vals:
                continue
            next_loc = unsorted_vals.index(correct_next)
            main_attn_val = raw_attn[location, next_loc].item()

            candidates = [i for i, v in enumerate(unsorted_vals) if v != correct_next]
            if not candidates:
                continue
            boost_idx = candidates[torch.randint(len(candidates), (1,)).item()]
            boosted_number = unsorted_vals[boost_idx]

            old_forward = attn_module.forward
            attn_module.forward = types.MethodType(
                _make_hijack_forward(location, boost_idx, main_attn_val + INTENSITY),
                attn_module)
            try:
                with torch.no_grad():
                    logits, _ = model(idx)
            finally:
                # Never leave the patched forward installed on the shared model.
                attn_module.forward = old_forward
            predicted = torch.argmax(logits, dim=-1)[0, location].item()

            records.append((current_num, boosted_number, predicted, correct_next))

    if not records:
        return np.empty((0, 4), dtype=np.int32)
    return np.array(records, dtype=np.int32)


def compute_separator_random(model, config, device, n_trials=1000):
    """Separator-attention and random-target intervention on layer 0.

    Returns:
        ``(sep, rand)``: int32 arrays of shape ``(N, 3)`` with rows
        ``(number, intensity, still_works)``. ``sep`` only contains positions
        whose strongest layer-0 attention is on the separator.
    """
    INTENSITIES = [2.0, 6.0, 10.0]
    UB_STANDARD = 60
    bs = config.block_size
    vs = config.vocab_size
    sep_pos = bs
    sep_records = []
    rand_records = []

    sep_bounds = dict(lb=UB_STANDARD, ub=UB_STANDARD, lb_num=0, ub_num=1)
    # Random target: first look above the current number, then below.
    rand_bounds = (dict(lb=0, ub=vs, lb_num=0, ub_num=1),
                   dict(lb=vs, ub=0, lb_num=1, ub_num=0))

    for _ in range(n_trials):
        idx = get_batch(vs, bs, device)
        with torch.no_grad():
            model(idx)
        attn_layer0 = model.transformer.h[0].c_attn.attn

        for p in range(bs - 1):
            sorted_loc = bs + 1 + p
            number_val = idx[0, sorted_loc].item()
            attn_row = attn_layer0[sorted_loc, :sorted_loc + 1]
            attends_to_sep = (attn_row.argmax().item() == sep_pos)

            for intensity in INTENSITIES:
                if attends_to_sep:
                    try:
                        ok = int(_intervention_still_works(
                            model, idx, 0, sorted_loc, intensity=intensity, **sep_bounds))
                    except Exception:
                        pass  # best effort: skip samples the intervention rejects
                    else:
                        sep_records.append((number_val, intensity, ok))

                # NOTE(review): the random-target branch runs for every
                # position, not only separator-attending ones -- confirm.
                for bounds in rand_bounds:
                    try:
                        ok = int(_intervention_still_works(
                            model, idx, 0, sorted_loc, intensity=intensity, **bounds))
                    except Exception:
                        continue  # try the fallback bounds, then give up
                    rand_records.append((number_val, intensity, ok))
                    break

    sep = (np.array(sep_records, dtype=np.int32) if sep_records
           else np.empty((0, 3), dtype=np.int32))
    rand = (np.array(rand_records, dtype=np.int32) if rand_records
            else np.empty((0, 3), dtype=np.int32))
    return sep, rand


def _npz_path(out_path):
    """Return the file name ``np.savez`` writes for *out_path* ('.npz' appended if absent)."""
    return out_path if out_path.endswith('.npz') else out_path + '.npz'


def _save_npz_atomic(out_path, **arrays):
    """Write *arrays* as an ``.npz`` archive without exposing a partial file.

    The archive is written to a temporary sibling and renamed into place, so
    an interrupted run never leaves a truncated result that the "already
    done" check in ``process_task`` would accept.
    """
    final_path = _npz_path(out_path)
    tmp_path = f'{final_path}.tmp.{os.getpid()}'
    try:
        with open(tmp_path, 'wb') as fh:
            np.savez(fh, **arrays)
        os.replace(tmp_path, final_path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def process_task(task, model, config, device):
    """Run one task dict and save its result to ``task['out']``.

    The task is skipped when its output already exists.

    Returns:
        ``True`` when the output exists or was written.

    Raises:
        ValueError: if ``task['type']`` is not a known task type.
        KeyError: if a key required by the task type is missing.
    """
    task_type = task['type']
    out_path = task['out']
    if os.path.exists(out_path) or os.path.exists(_npz_path(out_path)):
        return True
    out_dir = os.path.dirname(out_path)
    if out_dir:  # os.makedirs('') raises for a bare file name
        os.makedirs(out_dir, exist_ok=True)

    itr = task.get('itr', 0)

    if task_type == 'baseline':
        pp, fs, ca, ce = compute_baseline(model, config, device)
        _save_npz_atomic(out_path, per_pos_acc=pp, full_seq_acc=fs,
                         cond_acc=ca, cond_eligible=ce, itr=itr)

    elif task_type == 'ablation':
        pp, fs, ca, ce = compute_ablation(model, config, device, task['layer'])
        _save_npz_atomic(out_path, per_pos_acc=pp, full_seq_acc=fs,
                         cond_acc=ca, cond_eligible=ce,
                         skip_layer=task['layer'], itr=itr)

    elif task_type == 'cinclogits':
        cl, icl = compute_cinclogits(model, config, device, task['layer'])
        _save_npz_atomic(out_path, clogit_icscore=cl, iclogit_icscore=icl, itr=itr)

    elif task_type == 'intensity':
        intensities, rates, counts = compute_intensity(
            model, config, device, task['layer'], ub=task['ub'])
        _save_npz_atomic(out_path, intensities=intensities,
                         success_rates=rates, counts=counts, itr=itr)

    elif task_type == 'intensity_asym':
        intensities, rates, counts = compute_intensity(
            model, config, device, task['layer'],
            ub=task['unsorted_ub'], lb=task['unsorted_lb'],
            ub_num=task['unsorted_ub_num'], lb_num=task['unsorted_lb_num'])
        _save_npz_atomic(out_path, intensities=intensities,
                         success_rates=rates, counts=counts, itr=itr)

    elif task_type == 'hijack':
        data = compute_hijack(model, config, device, n_trials=task.get('trials', 2000))
        _save_npz_atomic(out_path, data=data)

    elif task_type == 'separator_random':
        sep, rand = compute_separator_random(
            model, config, device, n_trials=task.get('trials', 1000))
        _save_npz_atomic(out_path, sep_data=sep, rand_data=rand)

    else:
        # Previously fell through and was reported as success with no output.
        raise ValueError(f'unknown task type: {task_type!r}')

    return True


def main():
    """Run every task in ``--tasks-file`` on the GPU selected by ``--gpu``.

    Prints one JSON status line per task; a failing task (including a failed
    checkpoint load) is reported and does not stop the remaining tasks.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks-file', required=True)
    parser.add_argument('--gpu', type=int, required=True)
    args = parser.parse_args()

    # Must be set before the first CUDA call in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    device = 'cuda'

    with open(args.tasks_file) as f:
        task_list = json.load(f)

    print(f"GPU {args.gpu}: {len(task_list)} tasks", flush=True)

    current_model = None
    current_config = None
    current_ckpt = None
    done = 0

    for task in task_list:
        t0 = time.time()
        try:
            ckpt_path = task['ckpt_path']
            if ckpt_path != current_ckpt:
                # Drop the previous model first so it can be freed before the
                # next one is materialised on the device.
                current_model = current_config = current_ckpt = None
                current_model, current_config = load_model(ckpt_path, device)
                current_ckpt = ckpt_path
                print(f" Loaded {os.path.basename(ckpt_path)} ({time.time()-t0:.1f}s)",
                      flush=True)
                t0 = time.time()  # reported elapsed time excludes loading

            process_task(task, current_model, current_config, device)
            dt = time.time() - t0
            done += 1
            print(json.dumps({
                'status': 'done', 'task': task.get('name'),
                'gpu': args.gpu, 'elapsed': round(dt, 1),
                'progress': f'{done}/{len(task_list)}'
            }), flush=True)
        except Exception as e:
            # Top-level boundary: report and move on to the next task.
            done += 1
            print(json.dumps({
                'status': 'fail', 'task': task.get('name'),
                'gpu': args.gpu, 'error': str(e)
            }), flush=True)

    print(f"GPU {args.gpu}: all done ({done}/{len(task_list)})", flush=True)


if __name__ == '__main__':
    main()