import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_lm_head(model):
    """Return the output (unembedding) layer for supported architectures."""
    if hasattr(model, 'lm_head'):  # LLaMA models
        return model.lm_head
    elif hasattr(model, 'embed_out'):  # GPTNeoX (Pythia) models
        return model.embed_out
    else:
        raise ValueError(f"Unsupported model architecture: {type(model)}")


def get_input_embeddings(model):
    """Return the input embedding layer for supported architectures."""
    if hasattr(model, 'get_input_embeddings'):
        return model.get_input_embeddings()
    elif hasattr(model, 'gpt_neox'):  # GPTNeoX (Pythia) models
        return model.gpt_neox.embed_in
    elif hasattr(model, 'model') and hasattr(model.model, 'embed_tokens'):  # LLaMA models
        return model.model.embed_tokens
    else:
        raise ValueError(f"Unsupported model architecture: {type(model)}")


def customize_forward_pass(model, residual, presence, input_ids, grad_idx, attention_mask):
    """Build a closure that runs the model on perturbed input embeddings.

    A learnable `residual` is added to the embedding at position `grad_idx`,
    and `presence` scales the input embeddings, so gradients of the returned
    loss can be taken with respect to both.
    """
    lm_head = get_lm_head(model)
    embedding_layer = get_input_embeddings(model)
    vocab_embed = embedding_layer.weight
    with torch.no_grad():
        input_ids_to_dev = input_ids.to(embedding_layer.weight.device)
        base_embeds = embedding_layer(input_ids_to_dev)

    def build_inputs():
        embeds = base_embeds.clone()
        # Add the residual at the masked position(s).
        embeds[0, grad_idx, :] += residual
        return embeds

    def compute_logits(hidden, built_input_embeds):
        """Tie input and output embeddings: replace each target token's logit
        with a dot product against the *updated* input embedding."""
        L = built_input_embeds.size(1)
        logits = hidden[0].to(vocab_embed.device) @ vocab_embed.T  # [L, V]
        for t in range(L - 1):
            target_logit = torch.dot(hidden[0, t], built_input_embeds[0, t + 1].to(hidden.device))
            target_id = input_ids[0, t + 1].item()
            logits[t, target_id] = target_logit.to(logits.device)
        return logits

    def forward_pass(loss_position='all', hidden_norm_as_loss=False,
                     unnormalized_logits=False, projection_probe=None,
                     tie_input_output_embed=False, return_input_embeds=False,
                     alpha=1, target_id=None):
        embeds = build_inputs()
        input_embeds = embeds
        # Scale every position by `presence` and the global factor `alpha`
        # (a variant is to scale only the positions in `grad_idx`, or to add
        # `presence` along the normalized embedding direction instead).
        input_embeds[0, :, :] *= presence
        input_embeds[0, :, :] *= alpha

        # Run the base transformer (`model.model` for LLaMA, `model.gpt_neox` for Pythia).
        base_model = model.model if hasattr(model, 'model') else model.gpt_neox
        out = base_model(inputs_embeds=input_embeds,
                         attention_mask=attention_mask,
                         use_cache=False)
        hidden = out.last_hidden_state  # [1, L, d]

        if tie_input_output_embed:
            logits = compute_logits(hidden, embeds)
        else:
            lm_head_on_device = lm_head.to(hidden.device)
            logits = hidden[0] @ lm_head_on_device.weight.T  # [L, V]
        targets = input_ids[0, 1:].to(logits.device)

        ### Total energy for anomaly detection
        if loss_position == 'all':
            if unnormalized_logits:
                # Negative unnormalized logit at each target position.
                target_logits = logits[torch.arange(len(targets)), targets]  # [L-1]
                loss_full = -target_logits
                loss = loss_full.mean()  # or .sum()
            else:
                loss = nn.CrossEntropyLoss()(logits[:-1], targets)
                loss_full = nn.CrossEntropyLoss(reduction='none')(logits[:-1], targets).detach()  # [L-1]
                # Alternative that skips the first positions:
                # loss = nn.CrossEntropyLoss()(logits[5:-1], targets[5:])
            return loss, logits, loss_full

        if not torch.is_tensor(loss_position):
            loss_position = torch.tensor(loss_position)

        ### Random projection for Hutchinson trace estimation
        if projection_probe is not None:
            projection_probe = projection_probe / projection_probe.norm(dim=-1, keepdim=True)
            loss_position = loss_position.to(hidden.device)
            loss = (hidden[0, loss_position, :] * projection_probe.to(hidden.device)).sum(dim=-1)
            if return_input_embeds:
                return loss, logits, input_embeds.detach()
            return loss, logits

        ### Loss at the chosen position
        if hidden_norm_as_loss:
            # Project the hidden state onto its own (detached) direction:
            # this equals its norm while keeping the direction out of the graph.
            loss_position = loss_position.to(hidden.device)
            hidden_act = hidden[:, loss_position, :].detach()
            hidden_act = hidden_act / hidden_act.norm(dim=-1, keepdim=True)
            loss = (hidden[0, loss_position, :] * hidden_act).sum(dim=-1)
            if return_input_embeds:
                return loss, logits, input_embeds.detach()
            return loss, logits
        else:
            loss_position = loss_position.to(logits.device)
            if target_id is not None:
                target_chosen = torch.tensor([target_id], device=logits.device, dtype=torch.long)
            else:
                target_chosen = targets[loss_position].unsqueeze(0).to(logits.device)
            if unnormalized_logits:
                loss = -logits[loss_position, target_chosen]
            else:
                assert isinstance(loss_position, int) or (
                    torch.is_tensor(loss_position) and loss_position.dim() == 0
                ), "loss_position must be an integer, a 0-d tensor, or the string 'all'"
                logits_chosen = logits[loss_position, :].unsqueeze(0)  # [1, V]
                loss = nn.CrossEntropyLoss()(logits_chosen, target_chosen)
            if return_input_embeds:
                return loss, logits, input_embeds.detach()
            return loss, logits

    return forward_pass
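

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original code): builds the
# forward-pass closure and differentiates the sequence loss with respect to
# the injected residual. The checkpoint name, prompt, and `grad_idx` value
# below are assumptions chosen only for this example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_name = "EleutherAI/pythia-70m"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    enc = tokenizer("The quick brown fox jumps", return_tensors="pt")
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    hidden_dim = get_input_embeddings(model).weight.size(1)
    grad_idx = 1                            # position receiving the residual (assumed)
    residual = torch.zeros(hidden_dim, requires_grad=True)
    presence = torch.ones(1)                # global scale on the input embeddings

    forward_pass = customize_forward_pass(
        model, residual, presence, input_ids, grad_idx, attention_mask
    )

    # Total per-position cross-entropy over the whole sequence.
    loss, logits, loss_full = forward_pass(loss_position='all')
    loss.backward()
    print("per-position losses:", loss_full)
    print("||d(loss)/d(residual)||:", residual.grad.norm().item())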