import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import nn
def get_lm_head(model):
    """Return the output-projection (LM head) module of *model*.

    Probes the two attribute names used by the supported architectures:
    ``lm_head`` (LLaMA-style models) and ``embed_out`` (GPTNeoX / Pythia).

    Raises:
        ValueError: if the model exposes neither attribute.
    """
    # Probe in priority order; first attribute present wins.
    for attr_name in ("lm_head", "embed_out"):
        if hasattr(model, attr_name):
            return getattr(model, attr_name)
    raise ValueError(f"Unsupported model architecture: {type(model)}")
def get_input_embeddings(model):
    """Return the input token-embedding module of *model*.

    Resolution order: the model's own ``get_input_embeddings()`` method if
    present, then the GPTNeoX (Pythia) layout ``model.gpt_neox.embed_in``,
    then the LLaMA layout ``model.model.embed_tokens``.

    Raises:
        ValueError: if none of the known layouts match.
    """
    # Preferred path: the HF-style accessor method.
    if hasattr(model, 'get_input_embeddings'):
        return model.get_input_embeddings()
    # GPTNeoX (Pythia) layout.
    if hasattr(model, 'gpt_neox'):
        return model.gpt_neox.embed_in
    # LLaMA layout: embeddings hang off the inner base model.
    inner = getattr(model, 'model', None)
    if inner is not None and hasattr(inner, 'embed_tokens'):
        return inner.embed_tokens
    raise ValueError(f"Unsupported model architecture: {type(model)}")
def customize_forward_pass(model, residual, presence, input_ids, grad_idx, attention_mask, ):
    """Build a closure that runs a customized forward pass of *model*.

    A learnable ``residual`` is added to the input embeddings at positions
    ``grad_idx``, the result is scaled by ``presence`` and (per call) by
    ``alpha``, and the returned ``forward_pass`` closure computes one of
    several losses over the resulting hidden states / logits.

    Args:
        model: causal LM. NOTE(review): the forward below uses ``model.model``
            directly (LLaMA-style base model); GPTNeoX models expose
            ``gpt_neox`` instead — confirm this factory is only used with
            LLaMA-style models.
        residual: tensor added to the embeddings at ``grad_idx``
            (presumably the optimized perturbation — shape must broadcast
            against ``embeds[0, grad_idx, :]``; TODO confirm).
        presence: multiplicative scale applied to ALL positions of the
            input embeddings (not only ``grad_idx`` — see commented-out
            alternative inside ``forward_pass``).
        input_ids: token ids; the code indexes batch dimension 0 only, so a
            batch size of 1 is assumed.
        grad_idx: position index/indices where ``residual`` is injected.
        attention_mask: attention mask forwarded to the model.

    Returns:
        The ``forward_pass`` closure (see its docstring).
    """
    lm_head = get_lm_head(model)
    # NOTE(review): this bypasses the module-level get_input_embeddings()
    # helper defined above; models without a get_input_embeddings method
    # (the GPTNeoX fallback path) would fail here — confirm intent.
    embedding_layer = model.get_input_embeddings()
    # Raw vocab embedding matrix, used as a tied readout in compute_logits.
    vocab_embed = embedding_layer.weight
    # Base embeddings are fixed; gradients flow only through `residual`
    # (added later in build_inputs), so no_grad here is safe.
    with torch.no_grad():
        input_ids_to_dev = input_ids.to(embedding_layer.weight.device)
        base_embeds = embedding_layer(input_ids_to_dev)
        # base_embeds = embedding_layer(input_ids.to(embedding_layer.weight.device))

    def build_inputs():
        # Fresh copy each call so repeated forward passes don't accumulate
        # the residual or the in-place presence/alpha scaling.
        embeds = base_embeds.clone()
        # add residuals at masked positions
        embeds[0, grad_idx, :] += residual
        return embeds

    def compute_logits(hidden, built_input_embeds):
        """
        Modify logit of target token to use updated embedding for prediction
        """
        # Sequence length of the (possibly perturbed) input embeddings.
        L = built_input_embeds.size(1)
        # lm_head = vocab_embed.clone()
        # logits = hidden[0,].to(lm_head.device) @ lm_head.T
        # Tied readout: project hidden states against the input embedding
        # matrix rather than the model's own lm_head.
        logits = hidden[0,].to(lm_head.device) @ vocab_embed.T
        for t in range(0, L-1):
            # For the true next token, replace its logit with the dot product
            # against the PERTURBED embedding at position t+1.
            target_logit = torch.dot(hidden[0,t],built_input_embeds[0,t+1].to(hidden.device))
            target_id = input_ids[0, t+1].item()
            # print(f"logits[{t},{target_id}] before: {logits[t,target_id].item()}, after: {target_logit.item()}")
            logits[t,target_id] = target_logit
        return logits

    def forward_pass(loss_position='all', hidden_norm_as_loss=False, unnormalized_logits=False, projection_probe=None,tie_input_output_embed = False, return_input_embeds = False, alpha=1, target_id=None):
        """Run one perturbed forward pass and return a loss.

        Args:
            loss_position: 'all' for a sequence-wide next-token loss, else an
                int / 0-D tensor position at which the loss is taken.
            hidden_norm_as_loss: if True, loss is the norm of the hidden state
                at ``loss_position`` (via dot with its own normalized,
                detached copy).
            unnormalized_logits: use raw negative target logits instead of
                cross-entropy.
            projection_probe: if given, loss is the projection of the hidden
                state at ``loss_position`` onto this (normalized) probe
                (Hutchinson-style random projection).
            tie_input_output_embed: read out logits through the (perturbed)
                input embeddings via compute_logits instead of lm_head.
            return_input_embeds: additionally return the detached input
                embeddings (non-'all' paths only).
            alpha: extra multiplicative scale on the input embeddings.
            target_id: overrides the gold next token at ``loss_position``.

        Returns:
            (loss, logits, loss_full) when loss_position == 'all';
            otherwise (loss, logits) or (loss, logits, input_embeds).
        """
        embeds = build_inputs()
        input_embeds = embeds
        # input_embeds[0, grad_idx, :] *= presence
        # In-place scaling — note this also mutates `embeds`, which is the
        # same tensor and is later reused as readout_embeds below.
        input_embeds[0, :, :] *= presence
        input_embeds[0, :, :] *= alpha
        # input_normalized = input_embeds[0, grad_idx, :] / input_embeds[0, grad_idx, :].norm(dim=-1, keepdim=True)
        # print("norms: ", input_embeds.norm(dim=-1, keepdim=True))
        # input_embeds[0, grad_idx, :] += presence * input_normalized
        # NOTE(review): model.model is the LLaMA-style base model; this skips
        # the lm_head so we can apply our own readout below.
        out = model.model(inputs_embeds=input_embeds,
                          attention_mask=attention_mask,
                          use_cache=False)
        hidden = out.last_hidden_state  # [1, L, d]
        # print("hidden", hidden.shape)
        if tie_input_output_embed:
            # Readout through the perturbed input embeddings (tied weights).
            readout_embeds = embeds
            logits = compute_logits(hidden, readout_embeds)
        else:
            # lm_head.weight = lm_head.weight.to(hidden.device)
            # lm_head = lm_head.to(hidden.device)
            # logits = hidden[0,] @ lm_head.weight.T
            lm_head_on_device = lm_head.to(hidden.device)
            logits = hidden[0] @ lm_head_on_device.weight.T
        # Gold next tokens: position t predicts input_ids[t+1].
        targets = input_ids[0, 1:].to(logits.device)
        # print("lm_head ", lm_head.weight.shape)
        # print("logits ",logits.shape)
        # print("targets ",targets.shape)
        ### Total energy for anomaly detection
        if loss_position == 'all':
            if unnormalized_logits:
                # Extract logits at target positions and sum
                target_logits = logits[torch.arange(len(targets)), targets]  # [L-1]
                loss_full = -target_logits
                loss = loss_full.mean()  # or .sum()
            else:
                loss = nn.CrossEntropyLoss()(logits[:-1], targets)
                # Per-position losses kept for diagnostics, detached.
                loss_full = nn.CrossEntropyLoss(reduction='none')(logits[:-1], targets).detach()  # shape: (seq_len,)
                # loss = nn.CrossEntropyLoss()(logits[5:-1], targets[5:])
            return loss, logits, loss_full
        if not torch.is_tensor(loss_position):
            loss_position = torch.tensor(loss_position)
        ### random projection for Hutchinson
        if projection_probe is not None:
            # Normalize the probe, then project the hidden state onto it.
            projection_probe = projection_probe / projection_probe.norm(dim=-1, keepdim=True)
            loss_position = loss_position.to(hidden.device)
            # loss = (hidden[0, loss_position, :] * projection_probe.to(hidden.device)).sum()
            loss = (hidden[0, loss_position, :] * projection_probe.to(hidden.device)).sum(dim=-1)
            if return_input_embeds:
                return loss, logits, input_embeds.detach()
            else:
                return loss, logits
        ### Loss at chosen location
        if hidden_norm_as_loss == True:
            ### Temperature scope
            # Dot of the hidden state with its own normalized, detached copy
            # == the hidden-state norm at loss_position (gradient flows only
            # through the un-detached factor).
            hidden_act = hidden[:, loss_position, :].detach()
            hidden_act = hidden_act / hidden_act.norm(dim=-1, keepdim=True)
            # print("hidden_act", hidden_act.shape)
            # print("hidden", hidden.shape)
            loss_position = loss_position.to(hidden.device)
            loss = (hidden[0, loss_position, :] * hidden_act).sum(dim=-1)
            if return_input_embeds:
                return loss, logits, input_embeds.detach()
            else:
                return loss, logits
        else:
            loss_position = loss_position.to(logits.device)
            if target_id is not None:
                # Caller-specified target token overrides the gold token.
                target_chosen = torch.tensor([target_id], device=logits.device, dtype=torch.long)
            else:
                target_chosen = targets[loss_position].unsqueeze(0).to(logits.device)
            if unnormalized_logits:
                # Raw negative logit at the chosen (position, token).
                loss = -logits[loss_position,target_chosen]
            else:
                assert isinstance(loss_position, int) or (
                    torch.is_tensor(loss_position) and loss_position.dim() == 0
                ), "loss_position must be either an integer, a 0D tensor, or str(all)"
                logits_chosen = logits[loss_position, :].unsqueeze(0)  # [1, V]
                loss = nn.CrossEntropyLoss()(logits_chosen, target_chosen)
            if return_input_embeds:
                return loss, logits, input_embeds.detach()
            else:
                return loss, logits
    return forward_pass