# sort-llm-2 / gpu_worker.py
# Uploaded by gatmiry via huggingface_hub (commit 21d04bc, verified).
"""
GPU worker for 200k-checkpoint analysis.
Processes analysis tasks on a single GPU, reusing the model for all tasks.
Based on 100k-checkpoints/gpu_worker.py with grid-run intensity values.
"""
import argparse
import json
import os
import re
import sys
import time

import numpy as np
import torch

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'grid-run'))
from model_analysis import GPT, GPTConfig, GPTIntervention
def remap_state_dict(sd):
    """Translate checkpoint parameter names to the grid-run analysis layout.

    The training code names submodules ``transformer.h.<i>.attn`` and
    ``transformer.h.<i>.mlp``; the analysis model (model_analysis.GPT) calls
    them ``c_attn`` and ``c_fc``.  Every key is rewritten accordingly; values
    are passed through untouched.

    Generalized from a fixed ``range(10)`` replace loop to a regex so models
    with 10 or more transformer blocks are remapped correctly too (the old
    loop silently left layer indices >= 10 unrenamed).
    """
    new_sd = {}
    for key, val in sd.items():
        new_key = re.sub(r'(transformer\.h\.\d+\.)attn\.', r'\1c_attn.', key)
        new_key = re.sub(r'(transformer\.h\.\d+\.)mlp\.', r'\1c_fc.', new_key)
        new_sd[new_key] = val
    return new_sd
def load_model(ckpt_path, device):
    """Load a training checkpoint into a grid-run analysis GPT.

    Rebuilds a GPTConfig from the checkpoint's stored ``model_config``,
    remaps parameter names, trims/drops tensors whose shapes the analysis
    model does not accept, and returns the model in eval mode on ``device``
    together with its config.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only use on
    # trusted checkpoint files (consider weights_only=True on newer torch).
    ckpt = torch.load(ckpt_path, map_location='cpu')
    mc = ckpt['model_config']
    # Presumably the training vocab includes one extra (separator) token the
    # analysis model excludes — TODO confirm against training config.
    vocab_size = mc['vocab_size'] - 1
    block_size = mc['block_size']
    with_layer_norm = mc.get('use_final_LN', True)
    config = GPTConfig(block_size=block_size, vocab_size=vocab_size,
                       with_layer_norm=with_layer_norm)
    model = GPT(config)
    sd = remap_state_dict(ckpt['model_state_dict'])
    # Crop a larger trained positional-embedding table down to the
    # block_size*4 + 1 rows the grid-run model expects.
    grid_wpe_size = block_size * 4 + 1
    if 'transformer.wpe.weight' in sd and sd['transformer.wpe.weight'].shape[0] > grid_wpe_size:
        sd['transformer.wpe.weight'] = sd['transformer.wpe.weight'][:grid_wpe_size]
    # Drop outer '.c_attn.bias' tensors but keep the nested
    # 'c_attn.c_attn' ones — presumably only the latter exist in the
    # analysis model; verify against model_analysis.GPT.
    keys_to_skip = [k for k in sd if k.endswith('.c_attn.bias') and 'c_attn.c_attn' not in k]
    for k in keys_to_skip:
        del sd[k]
    # The trained lm_head has the larger vocab and would fail to load.
    if 'lm_head.weight' in sd:
        del sd['lm_head.weight']
    # strict=False: tolerate the keys deliberately removed above.
    model.load_state_dict(sd, strict=False)
    model.to(device)
    model.eval()
    return model, config
def get_batch(vocab_size, block_size, device='cpu'):
    """Build one sorting example of shape (1, 2*block_size + 1).

    Layout: ``block_size`` distinct random tokens from [0, vocab_size),
    a separator token equal to ``vocab_size``, then the sorted copy of
    the random prefix.
    """
    # Sampling without replacement via a truncated permutation.
    unsorted = torch.randperm(vocab_size)[:block_size]
    sorted_vals = torch.sort(unsorted).values
    separator = torch.tensor([vocab_size])
    sequence = torch.cat((unsorted, separator, sorted_vals))
    return sequence.unsqueeze(0).to(device)
def compute_cinclogits(model, config, device, attn_layer, num_tries=100):
    """Per-position disagreement rates between logit-level and
    attention-level correctness at ``attn_layer``.

    For each output position in the sorted half, finds the input token
    receiving the highest attention weight from that position and compares
    it to the target token.  Returns two length-``block_size`` arrays,
    averaged over ``num_tries`` random sequences:
      * rate where the logit prediction is correct but the top-attention
        token is wrong,
      * rate where both the logit prediction and the top-attention token
        are wrong.
    """
    bs = config.block_size
    vs = config.vocab_size
    acc_cl = np.zeros(bs)   # logit correct, attention pick wrong
    acc_icl = np.zeros(bs)  # logit wrong, attention pick wrong
    for _ in range(num_tries):
        idx = get_batch(vs, bs, device)
        with torch.no_grad():
            logits, _ = model(idx)
        # Logit correctness over the sorted half of the sequence.
        is_correct = (torch.argmax(logits[0, bs:2*bs, :], dim=1) == idx[0, bs+1:])
        # NOTE(review): assumes the attention module caches its last
        # attention matrix as `.attn` after a forward pass — set by the
        # model_analysis GPT; verify there.
        attn_w = model.transformer.h[attn_layer].c_attn.attn
        for j in range(bs, 2*bs):
            # Token attended to most strongly from position j, scanning the
            # full sequence (unsorted half, separator, sorted half).
            max_s, max_n = float('-inf'), -1
            for k in range(2*bs+1):
                s = attn_w[j, k].item()
                if s > max_s:
                    max_s = s
                    max_n = idx[0, k].item()
            # Did the top-attended token equal the target token?
            sc = (max_n == idx[0, j+1].item())
            pos = j - bs
            lc = is_correct[pos].item()
            if lc and not sc:
                acc_cl[pos] += 1.0
            elif not lc and not sc:
                acc_icl[pos] += 1.0
    return acc_cl / num_tries, acc_icl / num_tries
def compute_intensity(model, config, device, attn_layer, ub=5, lb=None,
                      ub_num=1, lb_num=0, min_valid=200):
    """Success rate of the model under attention interventions of
    increasing intensity at ``attn_layer``.

    For each intensity, repeatedly applies an attention intervention at a
    fixed location and records whether the output is unchanged
    (``check_if_still_works``), collecting up to ``min_valid`` valid
    attempts within at most 2000 rounds per intensity.

    Args:
        ub/lb: unsorted upper/lower bounds passed to the intervention;
            ``lb`` defaults to ``ub`` (symmetric).
        ub_num/lb_num: counts passed through to the intervention.
        min_valid: valid attempts to collect per intensity value.

    Returns:
        (intensities, success_rates, attempt_counts) as numpy arrays.
    """
    if lb is None:
        lb = ub  # symmetric bounds by default
    bs = config.block_size
    vs = config.vocab_size
    location = bs + 5  # fixed probe position past the separator — TODO confirm offset
    intensities = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0]
    rates, counts = [], []
    for intens in intensities:
        attempts, rounds = [], 0
        while len(attempts) < min_valid and rounds < 2000:
            rounds += 1
            idx = get_batch(vs, bs, device)
            try:
                im = GPTIntervention(model, idx)
                im.intervent_attention(
                    attention_layer_num=attn_layer, location=location,
                    unsorted_lb=lb, unsorted_ub=ub,
                    unsorted_lb_num=lb_num, unsorted_ub_num=ub_num,
                    unsorted_intensity_inc=intens,
                    sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0)
                g, n = im.check_if_still_works()
                attempts.append(g == n)
                im.revert_attention(attn_layer)
            except Exception:
                # Intervention can fail on some sampled sequences; skip them.
                # Fix: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit and made the worker unkillable.
                continue
        counts.append(len(attempts))
        rates.append(sum(attempts) / len(attempts) if attempts else 0.0)
    return np.array(intensities), np.array(rates), np.array(counts)
def compute_ablation(model, config, device, skip_layer, num_trials=500):
    """Accuracy statistics with block ``skip_layer``'s attention path ablated.

    Temporarily monkey-patches the block's forward so only the MLP residual
    branch runs (the attention branch is dropped entirely), measures
    accuracy over random sorting sequences, then restores the original
    forward.

    Returns:
        per_pos_acc: fraction of trials each output position was correct.
        full_seq_acc: fraction of trials with every position correct.
        cond_acc: per-position accuracy among trials where all earlier
            positions were correct (0.0 where never eligible).
        cond_eligible: number of trials each position was eligible in.
    """
    bs = config.block_size
    block = model.transformer.h[skip_layer]
    orig_fwd = block.forward
    def skip_attn(x, layer_n=-1):
        # MLP-only residual path; NOTE(review): `c_fc` is the remapped name
        # of this block's MLP module (see remap_state_dict).
        return x + block.c_fc(block.ln_2(x))
    block.forward = skip_attn
    pp = np.zeros(bs)  # per-position correct counts
    fs = 0             # fully-correct sequence count
    cc = np.zeros(bs)  # conditional correct counts
    ce = np.zeros(bs)  # eligibility counts (all earlier positions correct)
    try:
        for _ in range(num_trials):
            idx = get_batch(config.vocab_size, bs, device)
            with torch.no_grad():
                logits, _ = model(idx)
            # Predictions for the sorted half vs. their targets.
            preds = torch.argmax(logits[0, bs:2*bs, :], dim=1)
            targets = idx[0, bs+1:]
            correct = (preds == targets).cpu().numpy()
            pp += correct
            if correct.all():
                fs += 1
            # A position is eligible only while every earlier position in
            # this sequence was predicted correctly.
            ok = True
            for i in range(bs):
                if ok:
                    ce[i] += 1
                    if correct[i]:
                        cc[i] += 1
                    else:
                        ok = False
                else:
                    break
    finally:
        block.forward = orig_fwd  # always restore the patched forward
    return pp / num_trials, fs / num_trials, np.where(ce > 0, cc / ce, 0.0), ce
def compute_baseline(model, config, device, num_trials=500):
    """Accuracy statistics of the unmodified model on random sorting tasks.

    Returns:
        per_pos_acc: fraction of trials each output position was correct.
        full_seq_acc: fraction of trials with every position correct.
        cond_acc: per-position accuracy among trials where all earlier
            positions were correct (0.0 where never eligible).
        cond_eligible: number of trials each position was eligible in.
    """
    bs, vs = config.block_size, config.vocab_size
    per_pos = np.zeros(bs)
    cond_hits = np.zeros(bs)
    cond_n = np.zeros(bs)
    full_correct = 0
    for _ in range(num_trials):
        seq = get_batch(vs, bs, device)
        with torch.no_grad():
            logits, _ = model(seq)
        # Compare predictions on the sorted half against the targets.
        hits = (torch.argmax(logits[0, bs:2 * bs, :], dim=1) == seq[0, bs + 1:]).cpu().numpy()
        per_pos += hits
        full_correct += int(hits.all())
        # Conditional tally: a position is eligible only while every earlier
        # position in this sequence was correct; stop at the first miss.
        for i, ok in enumerate(hits):
            cond_n[i] += 1
            if ok:
                cond_hits[i] += 1
            else:
                break
    cond_acc = np.where(cond_n > 0, cond_hits / cond_n, 0.0)
    return per_pos / num_trials, full_correct / num_trials, cond_acc, cond_n
def process_task(task, model, config, device, itr):
    """Dispatch one analysis task and persist its result as an .npz file.

    Idempotent: returns immediately if the output file already exists.
    Always returns True (including for unrecognized task types, which are
    silently skipped without writing anything).
    """
    kind = task['type']
    target = task['out']
    if os.path.exists(target):
        # Results from a previous (possibly interrupted) run are kept.
        return True
    os.makedirs(os.path.dirname(target), exist_ok=True)
    if kind == 'baseline':
        acc, full, cond, elig = compute_baseline(model, config, device)
        np.savez(target, per_pos_acc=acc, full_seq_acc=full,
                 cond_acc=cond, cond_eligible=elig, itr=itr)
    elif kind == 'ablation':
        acc, full, cond, elig = compute_ablation(model, config, device, task['layer'])
        np.savez(target, per_pos_acc=acc, full_seq_acc=full,
                 cond_acc=cond, cond_eligible=elig,
                 skip_layer=task['layer'], itr=itr)
    elif kind == 'cinclogits':
        c_score, ic_score = compute_cinclogits(model, config, device, task['layer'])
        np.savez(target, clogit_icscore=c_score, iclogit_icscore=ic_score, itr=itr)
    elif kind == 'intensity':
        xs, ys, ns = compute_intensity(
            model, config, device, task['layer'], ub=task['ub'])
        np.savez(target, intensities=xs, success_rates=ys,
                 counts=ns, itr=itr)
    elif kind == 'intensity_asym':
        xs, ys, ns = compute_intensity(
            model, config, device, task['layer'],
            ub=task['unsorted_ub'], lb=task['unsorted_lb'],
            ub_num=task['unsorted_ub_num'], lb_num=task['unsorted_lb_num'])
        np.savez(target, intensities=xs, success_rates=ys,
                 counts=ns, itr=itr)
    return True
def main():
    """Worker entry point.

    Pins this process to one GPU, then runs the tasks from --tasks-file
    sequentially, reusing the loaded model across consecutive tasks that
    share a checkpoint.  Emits one JSON status line per task.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks-file', required=True)
    parser.add_argument('--gpu', type=int, required=True)
    args = parser.parse_args()
    # NOTE(review): torch is imported at module level; masking devices here
    # only works because no CUDA context has been created yet.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    device = 'cuda'
    with open(args.tasks_file) as f:
        task_list = json.load(f)
    print(f"GPU {args.gpu}: {len(task_list)} tasks", flush=True)
    current_model = None
    current_ckpt = None
    config = None
    done = 0
    for task in task_list:
        ckpt_path = task['ckpt_path']
        if ckpt_path != current_ckpt:
            # Load a new model only when the checkpoint changes.
            t0 = time.time()
            current_model, config = load_model(ckpt_path, device)
            current_ckpt = ckpt_path
            print(f" Loaded {os.path.basename(ckpt_path)} ({time.time()-t0:.1f}s)", flush=True)
        # Bug fix: 'itr' was read only when a new checkpoint was loaded, so a
        # later task sharing the checkpoint but carrying a different 'itr'
        # value saved a stale one.  Read it per task instead.
        itr = task.get('itr', 200000)
        t0 = time.time()
        try:
            process_task(task, current_model, config, device, itr)
            dt = time.time() - t0
            done += 1
            print(json.dumps({
                'status': 'done', 'task': task['name'],
                'gpu': args.gpu, 'elapsed': round(dt, 1),
                'progress': f'{done}/{len(task_list)}'
            }), flush=True)
        except Exception as e:
            # Per-task boundary: report the failure and keep the worker alive.
            done += 1
            print(json.dumps({
                'status': 'fail', 'task': task['name'],
                'gpu': args.gpu, 'error': str(e)
            }), flush=True)
# Run the worker only when executed as a script (safe to import).
if __name__ == '__main__':
    main()