sort-llm-2 / separator_intervention_worker.py
gatmiry's picture
Upload folder using huggingface_hub
21d04bc verified
"""
GPU worker for separator-attention and random-target intervention experiments.

For each random sequence, at each sorted output position except the last:
  1. Check whether the layer-0 attention maximum falls on the separator token.
     If so, apply the standard intervention (UB_STANDARD = 60) at every
     intensity in INTENSITIES and record the result.
  2. Apply the random-target intervention (boosting attention to an unsorted
     number) at every intensity and record the result.
Collects per-number success data across many trials and saves it with np.savez.
"""
import argparse
import json
import os
import re
import sys
import time

import numpy as np
import torch

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'grid-run'))
from model_analysis import GPT, GPTConfig, GPTIntervention
# Bound passed as both unsorted_lb and unsorted_ub in the standard
# (separator-case) intervention.
UB_STANDARD = 60

# Values passed as unsorted_intensity_inc; every intervention site is
# probed once per entry.
INTENSITIES = [
    2.0,
    6.0,
    10.0,
]
def remap_state_dict(sd):
    """Rename checkpoint keys to the parameter names used by ``model_analysis.GPT``.

    Inside every transformer block prefix ``transformer.h.<i>.``, the submodule
    prefix ``attn.`` becomes ``c_attn.`` and ``mlp.`` becomes ``c_fc.``.
    All other keys are copied unchanged. The input dict is not modified.

    Args:
        sd: mapping of parameter name -> value, as stored in the checkpoint.

    Returns:
        A new dict with the renamed keys and the original values.
    """
    new = {}
    for k, v in sd.items():
        # Match any layer index. A fixed index range would leave deeper layers
        # un-renamed, and load_model's load_state_dict(strict=False) would
        # then silently keep those layers at their random initialisation.
        nk = re.sub(r'(transformer\.h\.\d+\.)attn\.', r'\1c_attn.', k)
        nk = re.sub(r'(transformer\.h\.\d+\.)mlp\.', r'\1c_fc.', nk)
        new[nk] = v
    return new
def load_model(ckpt_path, device):
    """Build a ``GPT`` from a training checkpoint, move it to *device*, set eval mode.

    The checkpoint dict must contain ``model_config`` (with ``block_size``,
    ``vocab_size`` and optionally ``use_final_LN``) and ``model_state_dict``.

    Steps applied to the state dict before loading:
      * keys are renamed with ``remap_state_dict``;
      * ``transformer.wpe.weight`` is truncated to ``4 * block_size + 1`` rows
        if it is longer;
      * keys ending in ``.c_attn.bias`` (except those under ``c_attn.c_attn``)
        are dropped -- presumably attention-mask buffers, confirm in
        model_analysis;
      * ``lm_head.weight`` is dropped -- presumably tied to the token
        embedding in GPT, confirm.
    Loading uses ``strict=False``, so missing/unexpected keys are not reported.

    Returns:
        ``(model, config)``.
    """
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model_cfg = checkpoint['model_config']
    # NOTE(review): the "- 1" presumably excludes the separator token from the
    # checkpoint's vocab count -- confirm against the training code.
    config = GPTConfig(
        block_size=model_cfg['block_size'],
        vocab_size=model_cfg['vocab_size'] - 1,
        with_layer_norm=model_cfg.get('use_final_LN', True),
    )
    model = GPT(config)

    state = remap_state_dict(checkpoint['model_state_dict'])

    pos_key = 'transformer.wpe.weight'
    max_positions = 4 * config.block_size + 1
    if pos_key in state and state[pos_key].shape[0] > max_positions:
        state[pos_key] = state[pos_key][:max_positions]

    stale_keys = [
        name for name in state
        if name.endswith('.c_attn.bias') and 'c_attn.c_attn' not in name
    ]
    for name in stale_keys:
        del state[name]
    state.pop('lm_head.weight', None)

    model.load_state_dict(state, strict=False)
    model.to(device)
    model.eval()
    return model, config
def get_batch(vocab_size, block_size, device):
    """Sample one sorting example as a ``(1, L)`` int64 token tensor on *device*.

    Layout: ``block_size`` distinct tokens drawn without replacement from
    ``[0, vocab_size)``, then the separator token ``vocab_size``, then the same
    tokens in ascending order. ``L == 2 * block_size + 1`` when
    ``block_size <= vocab_size``.
    """
    shuffled = torch.randperm(vocab_size)[:block_size]
    ascending = shuffled.sort().values
    separator = torch.full((1,), vocab_size, dtype=torch.long)
    sequence = torch.cat([shuffled, separator, ascending])
    return sequence[None, :].to(device)
def _attempt_layer0_intervention(model, idx, location, intensity,
                                 unsorted_lb, unsorted_ub,
                                 unsorted_lb_num, unsorted_ub_num):
    """Apply one layer-0 attention intervention at *location*, check it, undo it.

    Returns ``g == n`` for the pair returned by ``check_if_still_works()``.
    ``revert_attention(0)`` is called whenever ``intervent_attention`` succeeded,
    even if the check raises, so a failed attempt is not left applied for the
    next call on the same model. Exceptions propagate to the caller.
    """
    im = GPTIntervention(model, idx)
    im.intervent_attention(
        attention_layer_num=0, location=location,
        unsorted_lb=unsorted_lb, unsorted_ub=unsorted_ub,
        unsorted_lb_num=unsorted_lb_num, unsorted_ub_num=unsorted_ub_num,
        unsorted_intensity_inc=intensity,
        sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0)
    try:
        g, n = im.check_if_still_works()
    finally:
        im.revert_attention(0)
    return g == n


def try_standard_intervention(model, idx, config, location, intensity):
    """Run the standard intervention (both unsorted bounds = UB_STANDARD).

    Args:
        model: the GPT being probed.
        idx: token sequence from ``get_batch``.
        config: unused; kept for signature parity with
            ``try_random_intervention``.
        location: sequence position whose layer-0 attention is modified.
        intensity: value passed as ``unsorted_intensity_inc``.

    Returns:
        The ``g == n`` comparison from ``check_if_still_works``, or ``None``
        if the intervention raised an ``Exception``.
    """
    try:
        return _attempt_layer0_intervention(
            model, idx, location, intensity,
            unsorted_lb=UB_STANDARD, unsorted_ub=UB_STANDARD,
            unsorted_lb_num=0, unsorted_ub_num=1)
    except Exception:  # not bare: Ctrl-C / SystemExit must still stop the worker
        return None


def try_random_intervention(model, idx, config, location, intensity):
    """Run the random-target intervention, with one fallback configuration.

    First tries bounds ``(lb=0, ub=vocab_size)`` with one "ub" target. If that
    raises, retries with ``(lb=vocab_size, ub=0)`` and one "lb" target.
    NOTE(review): the exact meaning of lb/ub is defined by GPTIntervention;
    presumably the fallback covers cases where the first configuration has
    no eligible target -- confirm in model_analysis.

    Returns:
        The ``g == n`` comparison from the first configuration that does not
        raise, or ``None`` if both raise an ``Exception``.
    """
    vs = config.vocab_size
    try:
        return _attempt_layer0_intervention(
            model, idx, location, intensity,
            unsorted_lb=0, unsorted_ub=vs,
            unsorted_lb_num=0, unsorted_ub_num=1)
    except Exception:
        pass  # fall through to the alternative configuration
    try:
        return _attempt_layer0_intervention(
            model, idx, location, intensity,
            unsorted_lb=vs, unsorted_ub=0,
            unsorted_lb_num=1, unsorted_ub_num=0)
    except Exception:
        return None
def run_trials(model, config, device, n_trials):
    """Probe *n_trials* random sequences and collect intervention outcomes.

    For each trial a sequence is drawn with ``get_batch`` and the model is run
    once, after which ``model.transformer.h[0].c_attn.attn`` is read. Then, for
    every sorted position ``bs+1 .. 2*bs-1`` (the final sorted token is not
    probed) and every intensity in ``INTENSITIES``:
      * if that position's layer-0 attention argmax is the separator, the
        standard intervention is attempted (``sep`` records);
      * the random-target intervention is attempted (``rand`` records).
    Attempts returning ``None`` (raised) are not recorded.

    Args:
        model: GPT model being probed.
        config: provides ``block_size`` and ``vocab_size``.
        device: device the sequences are created on.
        n_trials: number of random sequences.

    Returns:
        ``(sep_data, rand_data)``: int32 arrays of shape ``(k, 3)`` with rows
        ``(token_value, intensity, success)``. Intensities are truncated to
        int by the int32 cast. Shape is ``(0, 3)`` when nothing was recorded.
    """
    bs = config.block_size
    vs = config.vocab_size
    sep_pos = bs  # get_batch places the separator right after the bs unsorted tokens
    sep_records = []
    rand_records = []
    for trial in range(n_trials):
        idx = get_batch(vs, bs, device)
        with torch.no_grad():
            _logits, _ = model(idx)  # run only to populate the cached attention
            # Snapshot the map before any intervention on this sequence, so the
            # separator check for every position sees the clean forward pass
            # even if an intervention rewrites the cached attention.
            # NOTE(review): indexing below assumes `attn` is a 2-D (T, T) map
            # for this single sequence -- confirm in model_analysis.
            attn_layer0 = model.transformer.h[0].c_attn.attn.detach().clone()
        for p in range(bs - 1):
            sorted_loc = bs + 1 + p
            number_val = idx[0, sorted_loc].item()
            # Row of query position sorted_loc, restricted to keys 0..sorted_loc.
            attn_row = attn_layer0[sorted_loc, :sorted_loc + 1]
            attends_to_sep = attn_row.argmax().item() == sep_pos
            for intensity in INTENSITIES:
                if attends_to_sep:
                    result = try_standard_intervention(model, idx, config, sorted_loc, intensity)
                    if result is not None:
                        sep_records.append((number_val, intensity, int(result)))
                result_rand = try_random_intervention(model, idx, config, sorted_loc, intensity)
                if result_rand is not None:
                    rand_records.append((number_val, intensity, int(result_rand)))
        if (trial + 1) % 200 == 0:
            print(f" Trial {trial+1}/{n_trials}: sep={len(sep_records)}, rand={len(rand_records)}",
                  flush=True)
    # reshape keeps a 2-D (0, 3) shape when a list is empty (np.array([]) is (0,)).
    sep_data = np.array(sep_records, dtype=np.int32).reshape(-1, 3)
    rand_data = np.array(rand_records, dtype=np.int32).reshape(-1, 3)
    return sep_data, rand_data
def main():
    """Command-line entry point: load a checkpoint, run the trials, save results.

    Arguments:
        --ckpt    checkpoint path passed to ``load_model``.
        --gpu     GPU index, exported as ``CUDA_VISIBLE_DEVICES``.
        --trials  number of random sequences (default 1000).
        --out     ``np.savez`` target; receives arrays ``sep_data`` and
                  ``rand_data``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', required=True)
    parser.add_argument('--gpu', type=int, required=True)
    parser.add_argument('--trials', type=int, default=1000)
    parser.add_argument('--out', required=True)
    args = parser.parse_args()

    # Takes effect only if CUDA has not been initialised yet in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    device = 'cuda'

    # Create the output directory before the long trial loop so a bad --out
    # path fails fast. A bare filename has an empty dirname, which
    # os.makedirs would reject with FileNotFoundError.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print(f"GPU {args.gpu}: loading model...", flush=True)
    t0 = time.time()
    model, config = load_model(args.ckpt, device)
    print(f" Loaded in {time.time()-t0:.1f}s", flush=True)

    print(f"GPU {args.gpu}: running {args.trials} trials...", flush=True)
    t0 = time.time()
    sep_data, rand_data = run_trials(model, config, device, args.trials)
    elapsed = time.time() - t0

    np.savez(args.out, sep_data=sep_data, rand_data=rand_data)
    print(f"GPU {args.gpu}: done in {elapsed:.0f}s, "
          f"sep={len(sep_data)} rand={len(rand_data)}", flush=True)


if __name__ == '__main__':
    main()