| """ |
| GPU worker for separator-attention and random-target intervention experiments. |
| For each random sequence, at each sorted output position: |
| 1. Check if layer 0 max attention is on the separator token. |
| If yes, intervene with standard method (ub=60) and record result. |
| 2. Intervene by boosting a random unsorted number's attention and record result. |
| Collects per-number success data across many trials. |
| """ |
| import argparse |
| import json |
| import os |
| import sys |
| import time |
| import numpy as np |
| import torch |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'grid-run')) |
| from model_analysis import GPT, GPTConfig, GPTIntervention |
|
|
| INTENSITIES = [2.0, 6.0, 10.0] |
| UB_STANDARD = 60 |
|
|
|
|
| def remap_state_dict(sd): |
| new = {} |
| for k, v in sd.items(): |
| nk = k |
| for i in range(10): |
| nk = nk.replace(f'transformer.h.{i}.attn.', f'transformer.h.{i}.c_attn.') |
| nk = nk.replace(f'transformer.h.{i}.mlp.', f'transformer.h.{i}.c_fc.') |
| new[nk] = v |
| return new |
|
|
|
|
| def load_model(ckpt_path, device): |
| ckpt = torch.load(ckpt_path, map_location='cpu') |
| mc = ckpt['model_config'] |
| config = GPTConfig(block_size=mc['block_size'], vocab_size=mc['vocab_size'] - 1, |
| with_layer_norm=mc.get('use_final_LN', True)) |
| model = GPT(config) |
| sd = remap_state_dict(ckpt['model_state_dict']) |
| wpe_max = config.block_size * 4 + 1 |
| if 'transformer.wpe.weight' in sd and sd['transformer.wpe.weight'].shape[0] > wpe_max: |
| sd['transformer.wpe.weight'] = sd['transformer.wpe.weight'][:wpe_max] |
| for k in [k for k in sd if k.endswith('.c_attn.bias') and 'c_attn.c_attn' not in k]: |
| del sd[k] |
| if 'lm_head.weight' in sd: |
| del sd['lm_head.weight'] |
| model.load_state_dict(sd, strict=False) |
| model.to(device) |
| model.eval() |
| return model, config |
|
|
|
|
| def get_batch(vocab_size, block_size, device): |
| x = torch.randperm(vocab_size)[:block_size] |
| vals, _ = torch.sort(x) |
| return torch.cat((x, torch.tensor([vocab_size]), vals), dim=0).unsqueeze(0).to(device) |
|
|
|
|
| def try_standard_intervention(model, idx, config, location, intensity): |
| try: |
| im = GPTIntervention(model, idx) |
| im.intervent_attention( |
| attention_layer_num=0, location=location, |
| unsorted_lb=UB_STANDARD, unsorted_ub=UB_STANDARD, |
| unsorted_lb_num=0, unsorted_ub_num=1, |
| unsorted_intensity_inc=intensity, |
| sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0) |
| g, n = im.check_if_still_works() |
| im.revert_attention(0) |
| return g == n |
| except: |
| return None |
|
|
|
|
| def try_random_intervention(model, idx, config, location, intensity): |
| vs = config.vocab_size |
| try: |
| im = GPTIntervention(model, idx) |
| im.intervent_attention( |
| attention_layer_num=0, location=location, |
| unsorted_lb=0, unsorted_ub=vs, |
| unsorted_lb_num=0, unsorted_ub_num=1, |
| unsorted_intensity_inc=intensity, |
| sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0) |
| g, n = im.check_if_still_works() |
| im.revert_attention(0) |
| return g == n |
| except: |
| pass |
| try: |
| im = GPTIntervention(model, idx) |
| im.intervent_attention( |
| attention_layer_num=0, location=location, |
| unsorted_lb=vs, unsorted_ub=0, |
| unsorted_lb_num=1, unsorted_ub_num=0, |
| unsorted_intensity_inc=intensity, |
| sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0) |
| g, n = im.check_if_still_works() |
| im.revert_attention(0) |
| return g == n |
| except: |
| return None |
|
|
|
|
| def run_trials(model, config, device, n_trials): |
| bs = config.block_size |
| vs = config.vocab_size |
| sep_pos = bs |
|
|
| sep_records = [] |
| rand_records = [] |
|
|
| for trial in range(n_trials): |
| idx = get_batch(vs, bs, device) |
|
|
| with torch.no_grad(): |
| logits, _ = model(idx) |
|
|
| attn_layer0 = model.transformer.h[0].c_attn.attn |
|
|
| for p in range(bs - 1): |
| sorted_loc = bs + 1 + p |
| number_val = idx[0, sorted_loc].item() |
| next_num = idx[0, sorted_loc + 1].item() |
|
|
| attn_row = attn_layer0[sorted_loc, :sorted_loc + 1] |
| max_attn_pos = attn_row.argmax().item() |
| attends_to_sep = (max_attn_pos == sep_pos) |
|
|
| for intensity in INTENSITIES: |
| if attends_to_sep: |
| result = try_standard_intervention(model, idx, config, sorted_loc, intensity) |
| if result is not None: |
| sep_records.append((number_val, intensity, int(result))) |
|
|
| result_rand = try_random_intervention(model, idx, config, sorted_loc, intensity) |
| if result_rand is not None: |
| rand_records.append((number_val, intensity, int(result_rand))) |
|
|
| if (trial + 1) % 200 == 0: |
| print(f" Trial {trial+1}/{n_trials}: sep={len(sep_records)}, rand={len(rand_records)}", |
| flush=True) |
|
|
| return np.array(sep_records, dtype=np.int32), np.array(rand_records, dtype=np.int32) |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--ckpt', required=True) |
| parser.add_argument('--gpu', type=int, required=True) |
| parser.add_argument('--trials', type=int, default=1000) |
| parser.add_argument('--out', required=True) |
| args = parser.parse_args() |
|
|
| os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) |
| device = 'cuda' |
|
|
| print(f"GPU {args.gpu}: loading model...", flush=True) |
| t0 = time.time() |
| model, config = load_model(args.ckpt, device) |
| print(f" Loaded in {time.time()-t0:.1f}s", flush=True) |
|
|
| print(f"GPU {args.gpu}: running {args.trials} trials...", flush=True) |
| t0 = time.time() |
| sep_data, rand_data = run_trials(model, config, device, args.trials) |
| elapsed = time.time() - t0 |
|
|
| os.makedirs(os.path.dirname(args.out), exist_ok=True) |
| np.savez(args.out, sep_data=sep_data, rand_data=rand_data) |
| print(f"GPU {args.gpu}: done in {elapsed:.0f}s, " |
| f"sep={len(sep_data)} rand={len(rand_data)}", flush=True) |
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|