| """ |
| GPU worker for 1000k-checkpoint analysis. |
| Processes all task types on a single GPU: baseline, ablation, cinclogits, |
| intensity (various ub), asymmetric intensity, hijack, separator/random. |
| """ |
| import argparse |
| import json |
| import os |
| import sys |
| import time |
| import types |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'grid-run')) |
| from model_analysis import GPT, GPTConfig, GPTIntervention |
|
|
|
|
def remap_state_dict(sd):
    """Rename checkpoint keys to the module names used by the analysis GPT.

    For transformer layers 0-9, keys under ``transformer.h.<i>.attn.`` become
    ``transformer.h.<i>.c_attn.`` and keys under ``transformer.h.<i>.mlp.``
    become ``transformer.h.<i>.c_fc.``; every other key passes through
    unchanged. Values are not touched.
    """
    def rename(key):
        for layer in range(10):
            key = key.replace(f'transformer.h.{layer}.attn.', f'transformer.h.{layer}.c_attn.')
            key = key.replace(f'transformer.h.{layer}.mlp.', f'transformer.h.{layer}.c_fc.')
        return key

    return {rename(key): val for key, val in sd.items()}
|
|
|
|
def load_model(ckpt_path, device):
    """Instantiate a GPT from a training checkpoint and prepare it for analysis.

    Loads the checkpoint on CPU, rebuilds a GPTConfig from the stored
    'model_config', remaps and prunes the state dict to match the analysis
    model's module names, and returns (model, config) with the model in
    eval mode on `device`.
    """
    ckpt = torch.load(ckpt_path, map_location='cpu')
    mc = ckpt['model_config']
    # Analysis model uses one fewer vocab entry than training — presumably
    # the trained vocab reserved an extra slot; TODO confirm against trainer.
    vocab_size = mc['vocab_size'] - 1
    block_size = mc['block_size']
    with_layer_norm = mc.get('use_final_LN', True)

    config = GPTConfig(block_size=block_size, vocab_size=vocab_size,
                       with_layer_norm=with_layer_norm)
    model = GPT(config)

    sd = remap_state_dict(ckpt['model_state_dict'])
    # Truncate positional embeddings when the checkpoint was trained with a
    # longer context than the grid-run size (block_size * 4 + 1).
    grid_wpe_size = block_size * 4 + 1
    if 'transformer.wpe.weight' in sd and sd['transformer.wpe.weight'].shape[0] > grid_wpe_size:
        sd['transformer.wpe.weight'] = sd['transformer.wpe.weight'][:grid_wpe_size]
    # Drop '.c_attn.bias' entries that are NOT the inner qkv projection's bias
    # ('c_attn.c_attn.bias') — presumably causal-mask buffers that the
    # analysis model registers itself; verify in model_analysis.
    keys_to_skip = [k for k in sd if k.endswith('.c_attn.bias') and 'c_attn.c_attn' not in k]
    for k in keys_to_skip:
        del sd[k]
    # lm_head is presumably weight-tied to the token embedding in the
    # analysis model, making the stored copy redundant — confirm.
    if 'lm_head.weight' in sd:
        del sd['lm_head.weight']

    # strict=False tolerates the keys deliberately removed above.
    model.load_state_dict(sd, strict=False)
    model.to(device)
    model.eval()
    return model, config
|
|
|
|
def get_batch(vocab_size, block_size, device='cpu'):
    """Sample one sorting task: <unsorted block> <separator> <sorted block>.

    Draws `block_size` distinct tokens from [0, vocab_size), uses the id
    `vocab_size` itself as the separator token, and returns a
    (1, 2 * block_size + 1) tensor on `device`.
    """
    unsorted_part = torch.randperm(vocab_size)[:block_size]
    separator = torch.tensor([vocab_size])
    sorted_part = torch.sort(unsorted_part).values
    sequence = torch.cat((unsorted_part, separator, sorted_part), dim=0)
    return sequence.unsqueeze(0).to(device)
|
|
|
|
def compute_baseline(model, config, device, num_trials=500):
    """Accuracy of the unmodified model on `num_trials` random sorting tasks.

    Returns:
        per-position accuracy, shape (block_size,);
        full-sequence accuracy (fraction of trials fully correct);
        conditional accuracy at position i given all earlier positions were
            correct (0.0 where no trial qualified);
        per-position count of trials eligible for the conditional metric.
    """
    bs = config.block_size
    vs = config.vocab_size
    per_pos = np.zeros(bs)
    cond_correct = np.zeros(bs)
    cond_eligible = np.zeros(bs)
    full_seq = 0
    for _ in range(num_trials):
        idx = get_batch(vs, bs, device)
        with torch.no_grad():
            logits, _ = model(idx)
        preds = logits[0, bs:2 * bs, :].argmax(dim=1)
        correct = (preds == idx[0, bs + 1:]).cpu().numpy()
        per_pos += correct
        if correct.all():
            full_seq += 1
        # Streak counters: position i is eligible only while every earlier
        # position in this trial was predicted correctly.
        for i in range(bs):
            cond_eligible[i] += 1
            if not correct[i]:
                break
            cond_correct[i] += 1
    cond_acc = np.where(cond_eligible > 0, cond_correct / cond_eligible, 0.0)
    return per_pos / num_trials, full_seq / num_trials, cond_acc, cond_eligible
|
|
|
|
def compute_ablation(model, config, device, skip_layer, num_trials=500):
    """Same metrics as compute_baseline, but with the attention sublayer of
    block `skip_layer` bypassed (only that block's MLP path runs).

    The block's forward is monkeypatched for the duration of the trials and
    restored in a finally clause so the model is left intact on error.
    """
    bs = config.block_size
    block = model.transformer.h[skip_layer]
    orig_fwd = block.forward

    def skip_attn(x, layer_n=-1):
        # Residual + MLP only; the attention contribution is dropped.
        return x + block.c_fc(block.ln_2(x))

    block.forward = skip_attn

    per_pos = np.zeros(bs)
    cond_correct = np.zeros(bs)
    cond_eligible = np.zeros(bs)
    full_seq = 0
    try:
        for _ in range(num_trials):
            idx = get_batch(config.vocab_size, bs, device)
            with torch.no_grad():
                logits, _ = model(idx)
            preds = logits[0, bs:2 * bs, :].argmax(dim=1)
            correct = (preds == idx[0, bs + 1:]).cpu().numpy()
            per_pos += correct
            if correct.all():
                full_seq += 1
            # Streak counters for the conditional-accuracy metric.
            for i in range(bs):
                cond_eligible[i] += 1
                if not correct[i]:
                    break
                cond_correct[i] += 1
    finally:
        block.forward = orig_fwd
    cond_acc = np.where(cond_eligible > 0, cond_correct / cond_eligible, 0.0)
    return per_pos / num_trials, full_seq / num_trials, cond_acc, cond_eligible
|
|
|
|
def compute_cinclogits(model, config, device, attn_layer, num_tries=100):
    """Compare output correctness against where attention points.

    For each sorted-region query position, finds the key position with the
    highest attention weight in `attn_layer` and checks whether the token at
    that key equals the correct next token.  Accumulates two per-position
    rates over `num_tries` sampled tasks:
        acc_cl:  logits correct but attention argmax incorrect;
        acc_icl: both logits and attention argmax incorrect.
    Returns (acc_cl / num_tries, acc_icl / num_tries), each shape (block_size,).
    """
    bs = config.block_size
    vs = config.vocab_size
    acc_cl = np.zeros(bs)
    acc_icl = np.zeros(bs)
    for _ in range(num_tries):
        idx = get_batch(vs, bs, device)
        with torch.no_grad():
            logits, _ = model(idx)
        # Per-position logit correctness over the sorted half of the sequence.
        is_correct = (torch.argmax(logits[0, bs:2*bs, :], dim=1) == idx[0, bs+1:])
        # NOTE(review): assumes the attention module caches its weights on
        # `.attn` during the forward pass — confirm in model_analysis.
        attn_w = model.transformer.h[attn_layer].c_attn.attn
        for j in range(bs, 2*bs):
            # Manual argmax over all visible keys; strict '>' means the
            # FIRST maximal key wins on ties.
            max_s, max_n = float('-inf'), -1
            for k in range(2*bs+1):
                s = attn_w[j, k].item()
                if s > max_s:
                    max_s = s
                    max_n = idx[0, k].item()
            # "Score correct": attention's top key holds the right next token.
            sc = (max_n == idx[0, j+1].item())
            pos = j - bs
            lc = is_correct[pos].item()
            if lc and not sc:
                acc_cl[pos] += 1.0
            elif not lc and not sc:
                acc_icl[pos] += 1.0
    return acc_cl / num_tries, acc_icl / num_tries
|
|
|
|
def compute_intensity(model, config, device, attn_layer, ub=5, lb=None,
                      ub_num=1, lb_num=0, min_valid=200):
    """Success rate of the model under attention-intensity interventions.

    For each intensity in a fixed sweep, repeatedly samples a task, boosts
    attention (via GPTIntervention) at a fixed sorted-region location toward
    unsorted tokens selected by the (lb, ub) bounds, and records whether the
    model still produces the correct next token.

    Args:
        attn_layer: index of the layer whose attention is intervened on.
        ub, lb: unsorted-token bounds passed to the intervention; `lb`
            defaults to `ub` (the symmetric case).
        ub_num, lb_num: how many tokens at each bound to boost.
        min_valid: stop a sweep point once this many valid attempts were
            collected (capped at 2000 sampling rounds either way).

    Returns:
        (intensities, success_rates, valid_counts) as numpy arrays, one
        entry per sweep point.
    """
    if lb is None:
        lb = ub
    bs = config.block_size
    vs = config.vocab_size
    location = bs + 5  # fixed probe position inside the sorted region
    intensities = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0]
    rates, counts = [], []
    for intens in intensities:
        attempts, rounds = [], 0
        while len(attempts) < min_valid and rounds < 2000:
            rounds += 1
            idx = get_batch(vs, bs, device)
            try:
                im = GPTIntervention(model, idx)
                im.intervent_attention(
                    attention_layer_num=attn_layer, location=location,
                    unsorted_lb=lb, unsorted_ub=ub,
                    unsorted_lb_num=lb_num, unsorted_ub_num=ub_num,
                    unsorted_intensity_inc=intens,
                    sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0)
                g, n = im.check_if_still_works()
                attempts.append(g == n)
                im.revert_attention(attn_layer)
            except Exception:
                # Some samples don't admit this intervention (e.g. no token
                # in the requested bound range) — skip them. Was a bare
                # `except:`, which would also swallow KeyboardInterrupt and
                # make the worker unkillable; catch Exception only.
                continue
        counts.append(len(attempts))
        rates.append(sum(attempts) / len(attempts) if attempts else 0.0)
    return np.array(intensities), np.array(rates), np.array(counts)
|
|
|
|
def compute_hijack(model, config, device, n_trials=2000):
    """Hijack intervention on layer 0. Returns array of (current, boosted, predicted, correct).

    For each sorted-region position of each sampled task: captures the clean
    attention score on the "main path" (query -> location of the correct next
    token in the unsorted prefix), then re-runs the model with that score
    value + INTENSITY written onto a randomly chosen distractor key instead,
    and records what the model predicts.  Output is an (N, 4) int32 array of
    (current token, boosted distractor token, predicted token, correct token).
    """
    INTENSITY = 10.0  # boost added on top of the captured main-path score
    bs = config.block_size
    vs = config.vocab_size
    attn_module = model.transformer.h[0].c_attn
    records = []

    for trial in range(n_trials):
        idx = get_batch(vs, bs, device)
        unsorted = idx[0, :bs]
        sorted_part = idx[0, bs + 1: 2 * bs + 1]

        # Clean forward pass to capture the unmodified attention scores.
        # NOTE(review): assumes the module stores pre-softmax scores on
        # `.raw_attn` during forward — confirm in model_analysis.
        with torch.no_grad():
            _, _ = model(idx)
        raw_attn = attn_module.raw_attn.clone()

        for p in range(bs - 1):
            location = bs + 1 + p  # query position in the sorted region
            current_num = sorted_part[p].item()
            correct_next = idx[0, location + 1].item()

            # Where in the unsorted prefix does the correct next token live?
            next_loc_in_unsorted = (unsorted == correct_next).nonzero(as_tuple=True)[0]
            if len(next_loc_in_unsorted) == 0:
                continue
            next_loc = next_loc_in_unsorted[0].item()
            main_attn_val = raw_attn[location, next_loc].item()

            # Candidate distractors: any unsorted token that is not the
            # correct next token.
            candidates = [i for i in range(bs) if unsorted[i].item() != correct_next]
            if not candidates:
                continue

            boost_idx = candidates[torch.randint(len(candidates), (1,)).item()]
            boosted_number = unsorted[boost_idx].item()

            def make_new_forward(loc, bidx, mav):
                # Factory binds loc/bidx/mav by value so the patched forward
                # is independent of the loop variables.
                def new_forward(self_attn, x, layer_n=-1):
                    # Re-implementation of the attention forward with one
                    # score overwritten before the causal mask and softmax.
                    B, T, C = x.size()
                    qkv = self_attn.c_attn(x)
                    q, k, v = qkv.split(self_attn.n_embd, dim=2)
                    q = q.view(B, T, self_attn.n_heads, C // self_attn.n_heads).transpose(1, 2)
                    k = k.view(B, T, self_attn.n_heads, C // self_attn.n_heads).transpose(1, 2)
                    v = v.view(B, T, self_attn.n_heads, C // self_attn.n_heads).transpose(1, 2)
                    # NOTE(review): the 0.1 factor presumably mirrors the
                    # trained model's attention scaling — confirm against
                    # the original attention implementation.
                    attn = q @ k.transpose(-1, -2) * 0.1 / (k.size(-1)) ** 0.5
                    attn[:, :, loc, bidx] = mav + INTENSITY
                    attn = attn.masked_fill(self_attn.bias[:, :, :T, :T] == 0, float('-inf'))
                    attn = F.softmax(attn, dim=-1)
                    y = attn @ v
                    y = y.transpose(1, 2).contiguous().view(B, T, C)
                    y = self_attn.c_proj(y)
                    return y
                return new_forward

            # Swap the patched forward in, run one pass, then restore.
            old_forward = attn_module.forward
            attn_module.forward = types.MethodType(
                make_new_forward(location, boost_idx, main_attn_val), attn_module)

            with torch.no_grad():
                logits, _ = model(idx)
            predicted = torch.argmax(logits, dim=-1)[0, location].item()

            attn_module.forward = old_forward
            records.append((current_num, boosted_number, predicted, correct_next))

    return np.array(records, dtype=np.int32) if records else np.empty((0, 4), dtype=np.int32)
|
|
|
|
def compute_separator_random(model, config, device, n_trials=1000):
    """Separator-attention and random-target intervention on layer 0.

    For each sorted-region position of each sampled task:
      * if the layer-0 attention argmax points at the separator token, run a
        "separator" intervention (boost at the fixed UB_STANDARD bound) and
        record whether the model still works;
      * always run a "random-target" intervention over the full value range,
        retrying once with the mirrored lb/ub form if the first call fails.

    Returns:
        (sep_data, rand_data): int32 arrays of (number_value, intensity,
        success) rows; either may be empty with shape (0, 3).
    """
    INTENSITIES = [2.0, 6.0, 10.0]
    UB_STANDARD = 60  # fixed bound used for the separator-case intervention
    bs = config.block_size
    vs = config.vocab_size
    sep_pos = bs  # the separator sits right after the unsorted block

    sep_records = []
    rand_records = []

    for trial in range(n_trials):
        idx = get_batch(vs, bs, device)
        with torch.no_grad():
            logits, _ = model(idx)
        # NOTE(review): assumes attention weights are cached on `.attn`
        # during the forward pass — confirm in model_analysis.
        attn_layer0 = model.transformer.h[0].c_attn.attn

        for p in range(bs - 1):
            sorted_loc = bs + 1 + p
            number_val = idx[0, sorted_loc].item()

            # Does this position attend most strongly to the separator?
            attn_row = attn_layer0[sorted_loc, :sorted_loc + 1]
            max_attn_pos = attn_row.argmax().item()
            attends_to_sep = (max_attn_pos == sep_pos)

            for intensity in INTENSITIES:
                if attends_to_sep:
                    try:
                        im = GPTIntervention(model, idx)
                        im.intervent_attention(
                            attention_layer_num=0, location=sorted_loc,
                            unsorted_lb=UB_STANDARD, unsorted_ub=UB_STANDARD,
                            unsorted_lb_num=0, unsorted_ub_num=1,
                            unsorted_intensity_inc=intensity,
                            sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0)
                        g, n = im.check_if_still_works()
                        im.revert_attention(0)
                        sep_records.append((number_val, intensity, int(g == n)))
                    except Exception:
                        # Intervention not applicable to this sample — best
                        # effort, skip. Was a bare `except:`, which would
                        # also swallow KeyboardInterrupt; catch Exception.
                        pass

                try:
                    im = GPTIntervention(model, idx)
                    im.intervent_attention(
                        attention_layer_num=0, location=sorted_loc,
                        unsorted_lb=0, unsorted_ub=vs,
                        unsorted_lb_num=0, unsorted_ub_num=1,
                        unsorted_intensity_inc=intensity,
                        sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0)
                    g, n = im.check_if_still_works()
                    im.revert_attention(0)
                    rand_records.append((number_val, intensity, int(g == n)))
                except Exception:
                    # Retry once with the mirrored bound form before giving
                    # up on this (position, intensity) pair.
                    try:
                        im = GPTIntervention(model, idx)
                        im.intervent_attention(
                            attention_layer_num=0, location=sorted_loc,
                            unsorted_lb=vs, unsorted_ub=0,
                            unsorted_lb_num=1, unsorted_ub_num=0,
                            unsorted_intensity_inc=intensity,
                            sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0)
                        g, n = im.check_if_still_works()
                        im.revert_attention(0)
                        rand_records.append((number_val, intensity, int(g == n)))
                    except Exception:
                        pass

    sep = np.array(sep_records, dtype=np.int32) if sep_records else np.empty((0, 3), dtype=np.int32)
    rand = np.array(rand_records, dtype=np.int32) if rand_records else np.empty((0, 3), dtype=np.int32)
    return sep, rand
|
|
|
|
def process_task(task, model, config, device):
    """Dispatch one analysis task dict and write its results as an .npz file.

    Task fields: 'type', 'out' (output path), optional 'itr', plus
    type-specific keys ('layer', 'ub', 'trials', ...).  An already-existing
    output file is treated as done, making re-runs idempotent.

    Returns:
        True on completion (including the already-done case).

    Raises:
        ValueError: for an unrecognized task type.  (Previously an unknown
            type silently returned True without writing any output, which
            was indistinguishable from a never-run task.)
    """
    task_type = task['type']
    out_path = task['out']
    # NOTE(review): np.savez appends '.npz' when the path lacks it, which
    # would defeat this existence check — confirm task['out'] already ends
    # in '.npz' in the task generator.
    if os.path.exists(out_path):
        return True

    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    itr = task.get('itr', 0)

    if task_type == 'baseline':
        pp, fs, ca, ce = compute_baseline(model, config, device)
        np.savez(out_path, per_pos_acc=pp, full_seq_acc=fs,
                 cond_acc=ca, cond_eligible=ce, itr=itr)

    elif task_type == 'ablation':
        pp, fs, ca, ce = compute_ablation(model, config, device, task['layer'])
        np.savez(out_path, per_pos_acc=pp, full_seq_acc=fs,
                 cond_acc=ca, cond_eligible=ce, skip_layer=task['layer'], itr=itr)

    elif task_type == 'cinclogits':
        cl, icl = compute_cinclogits(model, config, device, task['layer'])
        np.savez(out_path, clogit_icscore=cl, iclogit_icscore=icl, itr=itr)

    elif task_type == 'intensity':
        intensities, rates, counts = compute_intensity(
            model, config, device, task['layer'], ub=task['ub'])
        np.savez(out_path, intensities=intensities, success_rates=rates,
                 counts=counts, itr=itr)

    elif task_type == 'intensity_asym':
        intensities, rates, counts = compute_intensity(
            model, config, device, task['layer'],
            ub=task['unsorted_ub'], lb=task['unsorted_lb'],
            ub_num=task['unsorted_ub_num'], lb_num=task['unsorted_lb_num'])
        np.savez(out_path, intensities=intensities, success_rates=rates,
                 counts=counts, itr=itr)

    elif task_type == 'hijack':
        data = compute_hijack(model, config, device, n_trials=task.get('trials', 2000))
        np.savez(out_path, data=data)

    elif task_type == 'separator_random':
        sep, rand = compute_separator_random(model, config, device,
                                             n_trials=task.get('trials', 1000))
        np.savez(out_path, sep_data=sep, rand_data=rand)

    else:
        # Surface bad task specs instead of silently doing nothing; the
        # caller's per-task except block logs this as a 'fail'.
        raise ValueError(f"unknown task type: {task_type!r}")

    return True
|
|
|
|
def main():
    """Worker entry point: pin a GPU, then run every task in the task file."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks-file', required=True)
    parser.add_argument('--gpu', type=int, required=True)
    args = parser.parse_args()

    # Restrict this process to its assigned GPU before any CUDA init.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    device = 'cuda'

    with open(args.tasks_file) as f:
        tasks = json.load(f)

    print(f"GPU {args.gpu}: {len(tasks)} tasks", flush=True)

    model = None
    config = None
    loaded_ckpt = None
    completed = 0

    for task in tasks:
        ckpt_path = task['ckpt_path']
        # Tasks arrive grouped by checkpoint; reload only on change.
        if ckpt_path != loaded_ckpt:
            start = time.time()
            model, config = load_model(ckpt_path, device)
            loaded_ckpt = ckpt_path
            print(f" Loaded {os.path.basename(ckpt_path)} ({time.time()-start:.1f}s)", flush=True)

        start = time.time()
        try:
            process_task(task, model, config, device)
            completed += 1
            print(json.dumps({
                'status': 'done', 'task': task['name'],
                'gpu': args.gpu, 'elapsed': round(time.time() - start, 1),
                'progress': f'{completed}/{len(tasks)}'
            }), flush=True)
        except Exception as exc:
            # One failed task must not kill the whole worker; report and
            # move on.
            completed += 1
            print(json.dumps({
                'status': 'fail', 'task': task['name'],
                'gpu': args.gpu, 'error': str(exc)
            }), flush=True)

    print(f"GPU {args.gpu}: all done ({completed}/{len(tasks)})", flush=True)
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|