import gc
import torch
from tqdm import tqdm
import pdb
import sys
import glog
|
|
|
|
def clean():
    """Release Python and CUDA memory between heavy steps."""
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
def show_metrics(hatW, W_orig, H, msg):
    """Log the relative Frobenius error and the relative proxy error
    tr((hatW - W) H (hatW - W)^T) / tr(W H W^T) of a quantized weight."""
    err_frob = (hatW - W_orig).square().sum() / W_orig.square().sum()
    err_proxy = (((hatW - W_orig) @ H) * (hatW - W_orig)).sum() / \
        ((W_orig @ H) * W_orig).sum()
    glog.info(f"{msg} frob error: {err_frob}")
    glog.info(f"{msg} proxy error: {err_proxy}")
|
|
|
|
def compute_activation_deltas(model, devset, args, device):
    """Run the devset through the model one transformer block at a time,
    keeping a single block on `device`, and return the activations after
    the embedding and after every layer."""
    dev_emb = model.model.embed_tokens(devset).cpu()
    # position ids must be an integer type usable for indexing; int16 is not
    # accepted and would overflow for long contexts
    position_ids = torch.arange(args.ctx_size, dtype=torch.long)[None, :].to(device) + \
        torch.zeros(args.batch_size, args.ctx_size, dtype=torch.long).to(device)
    attention_mask = model.model._prepare_decoder_attention_mask(
        torch.ones(args.batch_size, args.ctx_size, dtype=torch.bool),
        (args.batch_size, args.ctx_size),
        dev_emb[0:args.batch_size, :, :],
        0)
    attention_mask = attention_mask.to(device)

    # acts[0] holds the embedding output; acts[i + 1] will hold the output of
    # layer i. zeros_like keeps the dtype of the embeddings.
    acts = [dev_emb] + [torch.zeros_like(dev_emb) for _ in range(len(model.model.layers))]

    for i in tqdm(range(len(model.model.layers)), desc='computing activations'):
        layer = model.model.layers[i].to(device)
        for j in range(args.devset_size // args.batch_size):
            acts[i + 1][args.batch_size * j : args.batch_size * (j + 1)] = layer(
                acts[i][args.batch_size * j : args.batch_size * (j + 1)].to(device),
                position_ids=position_ids,
                attention_mask=attention_mask,
                use_cache=False,
                output_attentions=False)[0].cpu()
        # move the layer back off the GPU before loading the next one
        layer.cpu()
        clean()

    return torch.stack(acts)
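
# Hedged usage sketch for compute_activation_deltas (illustrative only; the
# checkpoint name, the `args` fields, and `devset` construction are
# assumptions, not part of this module):
#
#   from types import SimpleNamespace
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
#   args = SimpleNamespace(ctx_size=2048, batch_size=2, devset_size=64)
#   devset = ...  # LongTensor of shape (devset_size, ctx_size) of token ids
#   acts = compute_activation_deltas(model, devset, args, 'cuda:0')
#   # acts[0] is the embedding output; acts[i + 1] is the output of layer i.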
|
|
class ForkedPdb(pdb.Pdb):
    """A pdb subclass usable from a forked multiprocessing child.

    Usage:
        from lib import utils
        utils.ForkedPdb().set_trace()
    """

    def interaction(self, *args, **kwargs):
        _stdin = sys.stdin
        try:
            # reattach the child process to the controlling terminal's stdin
            sys.stdin = open('/dev/stdin')
            pdb.Pdb.interaction(self, *args, **kwargs)
        finally:
            sys.stdin = _stdin
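
# Hedged example: breaking into ForkedPdb from a torch.multiprocessing worker.
# `worker_fn` is a hypothetical function, not part of this module:
#
#   import torch.multiprocessing as mp
#
#   def worker_fn(rank):
#       from lib import utils
#       utils.ForkedPdb().set_trace()  # interactive pdb inside the child
#
#   mp.spawn(worker_fn, nprocs=2)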
|
|
|
|
# Archived per-layer activation caching loop, kept for reference:
'''
for i in tqdm(range(args.devset_size // args.batch_size),
              desc=f"calculating dev emb for layer {idx}"):
    dev_emb[(args.batch_size * i):(args.batch_size * (i + 1)), :, :] = transformer_layer(
        dev_emb[(args.batch_size * i):(args.batch_size * (i + 1)), :, :].to(device),
        position_ids=position_ids,
        attention_mask=attention_mask,
        use_cache=False,
        output_attentions=False)[0].cpu()

act_error = (dev_emb.to(dtype_) - orig_emb.to(dtype_)).square().sum() / \
    (orig_emb.to(dtype_) - orig_emb.to(dtype_).mean((0, 1))).square().sum()
glog.info(f"layer {idx} act error: {act_error}")

glog.info(f"saving activations for layer {idx}")

# write to a temp file first so an interrupted save cannot corrupt the cache
torch.save(
    {
        'devset': devset,
        'dev_emb': dev_emb,
        'orig_emb': orig_emb,
        'after_layer': idx,
        'timestamp': str(datetime.datetime.now())
    }, f"{args.save_path}/dev_activations.new.pt")
if os.path.isfile(f"{args.save_path}/dev_activations.pt"):
    os.remove(f"{args.save_path}/dev_activations.pt")
os.rename(f"{args.save_path}/dev_activations.new.pt", f"{args.save_path}/dev_activations.pt")

glog.info(f'done processing layer {idx}, total time {time.time() - st_time} seconds')
'''
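
if __name__ == '__main__':
    # Minimal smoke test for show_metrics on synthetic data (a sketch added
    # for illustration; shapes and values are arbitrary). With H set to the
    # identity, the proxy error tr(dW H dW^T) / tr(W H W^T) reduces to the
    # Frobenius error, so the two logged values should match.
    torch.manual_seed(0)
    W = torch.randn(64, 64)
    hatW = W + 0.01 * torch.randn_like(W)
    H = torch.eye(64)
    show_metrics(hatW, W, H, 'synthetic')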
|
|