import gc
import pdb
import sys

import glog
import torch
from tqdm import tqdm


def clean():
    """Free Python garbage and release cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


def show_metrics(hatW, W_orig, H, msg):
    """Log the relative Frobenius error and proxy (Hessian-weighted) error of hatW vs. W_orig."""
    err_frob = (hatW - W_orig).square().sum() / W_orig.square().sum()
    err_proxy = (((hatW - W_orig) @ H) * (hatW - W_orig)).sum() / \
        ((W_orig @ H) * W_orig).sum()
    glog.info(f"{msg} frob error: {err_frob}")
    glog.info(f"{msg} proxy error: {err_proxy}")


def compute_activation_deltas(model, devset, args, device):
    """Run the dev set through each transformer layer and return the stacked
    activations: the input embeddings followed by one entry per layer output."""
    dev_emb = model.model.embed_tokens(devset).cpu()
    # Position ids need an integer index dtype; one row per batch element.
    position_ids = torch.arange(args.ctx_size, dtype=torch.int32)[None, :].to(device) + \
        torch.zeros(args.batch_size, args.ctx_size, dtype=torch.int32).to(device)
    attention_mask = model.model._prepare_decoder_attention_mask(
        torch.ones(args.batch_size, args.ctx_size, dtype=torch.bool),
        (args.batch_size, args.ctx_size), dev_emb[0:args.batch_size, :, :], 0)
    attention_mask = attention_mask.to(device)

    # acts[0] holds the embeddings; acts[i + 1] receives the output of layer i.
    acts = [dev_emb] + [torch.zeros(*dev_emb.shape) for _ in range(len(model.model.layers))]
    for i in tqdm(range(len(model.model.layers)), desc='computing activations'):
        layer = model.model.layers[i].to(device)
        for j in range(args.devset_size // args.batch_size):
            acts[i + 1][args.batch_size * j:args.batch_size * (j + 1)] = layer(
                acts[i][args.batch_size * j:args.batch_size * (j + 1)].to(device),
                position_ids=position_ids,
                attention_mask=attention_mask,
                use_cache=False,
                output_attentions=False)[0].cpu()
        # Move the layer back to CPU and release its GPU memory before the next one.
        model.model.layers[i] = layer.cpu()
        clean()
    return torch.stack(acts)


class ForkedPdb(pdb.Pdb):
    """A Pdb subclass that may be used from a forked multiprocessing child.

    Use as:
        from lib import utils
        utils.ForkedPdb().set_trace()
    """

    def interaction(self, *args, **kwargs):
        _stdin = sys.stdin
        try:
            sys.stdin = open('/dev/stdin')
            pdb.Pdb.interaction(self, *args, **kwargs)
        finally:
            sys.stdin = _stdin


'''
    for i in tqdm(range(args.devset_size // args.batch_size),
                  desc=f"calculating dev emb for layer {idx}"):
        dev_emb[(args.batch_size * i):(args.batch_size * (i + 1)), :, :] = transformer_layer(
            dev_emb[(args.batch_size * i):(args.batch_size * (i + 1)), :, :].to(device),
            position_ids=position_ids,
            attention_mask=attention_mask,
            use_cache=False,
            output_attentions=False)[0].cpu()

    act_error = (dev_emb.to(dtype_) - orig_emb.to(dtype_)).square().sum() / \
        (orig_emb.to(dtype_) - orig_emb.to(dtype_).mean((0, 1))).square().sum()
    glog.info(f"layer {idx} act error: {act_error}")

    glog.info(f"saving activations for layer {idx}")
    torch.save(
        {
            'devset': devset,
            'dev_emb': dev_emb,
            'orig_emb': orig_emb,
            'after_layer': idx,
            'timestamp': str(datetime.datetime.now())
        }, f"{args.save_path}/dev_activations.new.pt")
    if os.path.isfile(f"{args.save_path}/dev_activations.pt"):
        os.remove(f"{args.save_path}/dev_activations.pt")
    os.rename(f"{args.save_path}/dev_activations.new.pt",
              f"{args.save_path}/dev_activations.pt")
    glog.info(f'done processing layer {idx}, total time {time.time() - st_time} seconds')
'''
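
# Usage sketch (illustrative only, not part of the original module). It assumes a
# Llama-style `model`, a tokenized calibration tensor `devset` of shape
# (devset_size, ctx_size), and an `args` namespace providing ctx_size, batch_size,
# and devset_size; `hatW`, `W_orig`, and the proxy Hessian `H` come from the caller.
# compute_activation_deltas returns a tensor of shape
# (num_layers + 1, devset_size, ctx_size, hidden_dim): embeddings first, then each layer's output.
#
#     acts = compute_activation_deltas(model, devset, args, device='cuda')
#     show_metrics(hatW, W_orig, H, msg='layer 0 self_attn.q_proj')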