AmnaHassan committed
Commit 78eadb7 · verified · 1 Parent(s): 85acc7e

Create model.py

Files changed (1)
  1. model.py +70 -0
model.py ADDED
@@ -0,0 +1,70 @@
# model.py
#
# NOTE: the committed snippet starts inside an __init__ and refers to
# self.tokenizer / self.device without defining them. The imports, the class
# wrapper (the name GPT2Explainer), and the top of __init__ below are
# reconstructed assumptions added so the file runs; everything else follows
# the original code.

import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


class GPT2Explainer:
    def __init__(self, model_name="gpt2", device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_name, output_hidden_states=True).to(self.device)
        self.model.eval()

    def generate_text(self, prompt, max_length=50, top_k=10):
        # `max_length` here is the number of new tokens appended to the prompt
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_length=len(inputs["input_ids"][0]) + max_length,
                do_sample=True,
                top_k=top_k,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)

    def _get_hidden_states(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs)
        # hidden_states: tuple of length n_layers + 1, each (batch, seq_len, hidden)
        return out.hidden_states

    def layer_importance(self, prompt, experiment_type="story_continuation"):
        """
        Simple proxy for activation patching: measure how sensitive the model's
        next-token logits are to zeroing the output of each transformer block.
        For each layer:
          - compute the next-token logits with all layers active
          - compute them again with layer `l` zeroed out (its hidden output set to zero)
          - take the L1 difference between the two logit vectors; a larger
            difference means the layer matters more for this prompt

        Returns a list of importance scores (one per transformer block),
        normalized to [0, 1].
        """
        # 1) Baseline logits for the prompt
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
        baseline_logits = out.logits[0, -1, :].cpu().numpy()

        # Number of transformer blocks (hidden_states also includes the embedding layer)
        n_layers = len(out.hidden_states) - 1
        scores = []

        # 2) Re-run the forward pass once per layer, zeroing that layer's output
        #    with a forward hook (returning a value from the hook replaces the output)
        for layer_idx in range(n_layers):
            def hook(module, inp, outp):
                # GPT2Block returns a tuple whose first element is the hidden-state
                # tensor of shape (batch, seq_len, hidden); zero it, keep the rest
                if isinstance(outp, tuple):
                    return (torch.zeros_like(outp[0]),) + outp[1:]
                return torch.zeros_like(outp)

            # Register the hook on transformer.h.{layer_idx}
            handle = self.model.transformer.h[layer_idx].register_forward_hook(hook)
            with torch.no_grad():
                out2 = self.model(**inputs)
            logits2 = out2.logits[0, -1, :].cpu().numpy()
            diff = np.sum(np.abs(baseline_logits - logits2))
            scores.append(float(diff))
            handle.remove()

        # 3) Normalize scores to [0, 1] (guard against a constant score vector)
        arr = np.array(scores)
        if arr.max() > arr.min():
            arr = (arr - arr.min()) / (arr.max() - arr.min())
        else:
            arr = np.zeros_like(arr)

        return arr.tolist()
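
A minimal usage sketch for the file above. The class name GPT2Explainer and the constructor signature are the assumptions made while reconstructing the snippet, not part of the committed file; the method names and parameters come from the code itself.

# usage sketch (illustrative only)
from model import GPT2Explainer

explainer = GPT2Explainer(model_name="gpt2")

# Sample a short continuation for a prompt
print(explainer.generate_text("Once upon a time", max_length=30, top_k=10))

# Score how much zero-ablating each transformer block shifts the next-token logits
scores = explainer.layer_importance("Once upon a time")
for i, s in enumerate(scores):
    print(f"block {i}: {s:.3f}")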