rain1955 committed on
Commit
051c56b
·
verified ·
1 Parent(s): b10ae02

Add extract_vectors.py

Browse files
Files changed (1) hide show
  1. extract_vectors.py +317 -0
extract_vectors.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""Extract emotion vectors from Gemma4 model using PyTorch hooks on residual stream."""

import json
import os
import warnings
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from collections import defaultdict

# Silence library warning chatter (transformers/torch deprecation noise).
warnings.filterwarnings("ignore")

# All paths are resolved relative to this script's own directory.
EXP_DIR = os.path.dirname(os.path.abspath(__file__))
# Input: one JSON object per line with "emotion" and "text" keys.
STORIES_FILE = os.path.join(EXP_DIR, "emotion_stories.jsonl")
OUTPUT_DIR = os.path.join(EXP_DIR, "results")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model config - use Gemma4 E4B from HuggingFace
# NOTE(review): confirm this repo id exists on the Hub — "google/gemma-4-E4B-it"
# does not match the usual Gemma naming; verify against the intended checkpoint.
MODEL_ID = "google/gemma-4-E4B-it"  # Gemma4 E4B (not gated, ~8GB float16)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
START_TOKEN = 50  # Start averaging from token 50 (emotion should be apparent)
23
+
24
+
25
def load_stories(path=None):
    """Load emotion-labelled stories from a JSONL file.

    Each non-blank line must be a JSON object with "emotion" and "text" keys.

    Args:
        path: JSONL file to read; defaults to the module-level STORIES_FILE.

    Returns:
        dict mapping emotion name -> list of story texts.
    """
    if path is None:
        path = STORIES_FILE
    stories = defaultdict(list)
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline in the file).
                continue
            d = json.loads(line)
            stories[d["emotion"]].append(d["text"])
    return dict(stories)
32
+
33
+
34
def get_residual_stream_hooks(model):
    """Register forward hooks that record each decoder layer's output.

    Returns:
        (captured, handles): ``captured`` maps "layer_{i}" to that layer's
        hidden-state tensor (detached, on CPU, float32) from the most recent
        forward pass; ``handles`` are the hook handles, kept for cleanup.
    """
    captured = {}

    def recorder_for(key):
        def recorder(module, inputs, output):
            # Decoder layers may return (hidden_states, ...) tuples.
            hidden = output[0] if isinstance(output, tuple) else output
            captured[key] = hidden.detach().cpu().float()
        return recorder

    # Gemma4 multimodal nests the text decoder under language_model;
    # Gemma3/standard models expose the layers directly at model.model.layers.
    base = model.model
    if hasattr(base, 'language_model'):
        decoder_layers = base.language_model.layers
    else:
        decoder_layers = base.layers

    handles = [
        layer.register_forward_hook(recorder_for(f"layer_{idx}"))
        for idx, layer in enumerate(decoder_layers)
    ]
    return captured, handles
60
+
61
+
62
def extract_activations(model, tokenizer, text, activations_dict, target_layer,
                        *, device=None, start_token=None):
    """Run ``text`` through ``model`` and return the mean residual-stream
    activation at ``target_layer``.

    Tokens before ``start_token`` are excluded from the mean (the emotion is
    assumed to be apparent by then); texts shorter than that use all tokens.

    Args:
        model: causal LM with residual-stream hooks already attached.
        tokenizer: matching tokenizer; input is truncated to 512 tokens.
        text: input string.
        activations_dict: hook-populated mapping "layer_{i}" -> hidden states
            of shape (1, seq_len, hidden_dim); cleared before returning so the
            next forward pass starts from empty.
        target_layer: index of the layer whose activations are averaged.
        device: device for the tokenized inputs; defaults to module DEVICE.
        start_token: first token index included in the mean; defaults to
            module START_TOKEN.

    Returns:
        1-D float32 numpy array of length hidden_dim, or None if the target
        layer produced no captured activation.
    """
    if device is None:
        device = DEVICE
    if start_token is None:
        start_token = START_TOKEN

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)

    with torch.no_grad():
        model(**inputs)

    key = f"layer_{target_layer}"
    if key not in activations_dict:
        # Drop any stale captures even on failure so memory doesn't linger.
        activations_dict.clear()
        return None

    hidden = activations_dict[key]  # (1, seq_len, hidden_dim)
    seq_len = hidden.shape[1]

    if seq_len <= start_token:
        # Short text, use all tokens
        mean_act = hidden[0].mean(dim=0)
    else:
        mean_act = hidden[0, start_token:].mean(dim=0)

    # Clear activations for next run
    activations_dict.clear()

    return mean_act.numpy()
86
+
87
+
88
def compute_emotion_vectors(emotion_activations):
    """Compute per-emotion difference vectors against the pooled mean.

    emotion_vector[e] = mean(activations[e]) - mean(activations pooled over
    all emotions).  The pooled mean weights every sample equally, not every
    emotion equally.

    Returns:
        (emotion_vectors, global_mean): emotion_vectors maps emotion name ->
        1-D array; global_mean is the pooled sample mean.
    """
    pooled = [a for acts in emotion_activations.values() for a in acts]
    global_mean = np.mean(pooled, axis=0)

    emotion_vectors = {
        emotion: np.mean(acts, axis=0) - global_mean
        for emotion, acts in emotion_activations.items()
    }
    return emotion_vectors, global_mean
104
+
105
+
106
def denoise_vectors(emotion_vectors, neutral_activations, variance_threshold=0.5):
    """Remove directions of generic text variation from the emotion vectors.

    PCA is run on activations of emotionally neutral text; the leading
    principal components (enough to explain ``variance_threshold`` of the
    variance) are projected out of every emotion vector.  If no neutral
    activations are supplied, the input mapping is returned unchanged.
    """
    if len(neutral_activations) == 0:
        return emotion_vectors

    centered = np.stack(neutral_activations)
    centered = centered - centered.mean(axis=0)

    # Right singular vectors are the principal directions of neutral text.
    _, singular_values, Vt = np.linalg.svd(centered, full_matrices=False)

    # Number of components needed to reach the variance threshold.
    variances = singular_values ** 2
    cumvar = np.cumsum(variances) / variances.sum()
    n_components = np.searchsorted(cumvar, variance_threshold) + 1
    print(f"Denoising: projecting out {n_components} components (explain {variance_threshold*100}% variance)")

    noise_basis = Vt[:n_components].T  # (hidden_dim, n_components)

    # Subtract each vector's projection onto the noise subspace.
    return {
        emotion: vec - noise_basis @ (noise_basis.T @ vec)
        for emotion, vec in emotion_vectors.items()
    }
133
+
134
+
135
def logit_lens(model, tokenizer, emotion_vectors, top_k=10):
    """Project each emotion vector through the unembedding matrix.

    For every emotion, returns the ``top_k`` tokens the vector pushes up and
    the ``top_k`` it pushes down, each as (token_text, logit) pairs.
    """
    # Prefer the dedicated lm_head; otherwise fall back to the input embedding
    # table (assumes tied input/output embeddings — TODO confirm for this model).
    if hasattr(model, 'lm_head'):
        unembed = model.lm_head.weight
    elif hasattr(model.model, 'language_model'):
        unembed = model.model.language_model.embed_tokens.weight
    else:
        unembed = model.model.embed_tokens.weight
    W = unembed.detach().cpu().float().numpy()  # (vocab, hidden)

    results = {}
    for emotion, vec in emotion_vectors.items():
        logits = W @ vec
        order = np.argsort(logits)  # ascending

        def as_pairs(indices):
            return [(tokenizer.decode([idx]), float(logits[idx])) for idx in indices]

        results[emotion] = {
            "top": as_pairs(order[-top_k:][::-1]),
            "bottom": as_pairs(order[:top_k]),
        }

    return results
158
+
159
+
160
def pca_analysis(emotion_vectors):
    """Project the emotion vectors onto their first two principal components.

    Useful for checking whether a valence/arousal-like structure emerges in
    the 2-D layout.

    Returns:
        dict with the emotion order, PC1/PC2 coordinates, and the fraction of
        variance each of the two components explains.
    """
    emotions = list(emotion_vectors.keys())
    stacked = np.stack([emotion_vectors[name] for name in emotions])

    # Center before SVD so singular vectors are principal components.
    centered = stacked - stacked.mean(axis=0)
    _, S, Vt = np.linalg.svd(centered, full_matrices=False)

    coords = centered @ Vt[:2].T  # (n_emotions, 2)
    var_fraction = (S[:2] ** 2) / (S ** 2).sum()

    return {
        "emotions": emotions,
        "pc1": coords[:, 0].tolist(),
        "pc2": coords[:, 1].tolist(),
        "explained_variance_pc1": float(var_fraction[0]),
        "explained_variance_pc2": float(var_fraction[1]),
    }
183
+
184
+
185
def main():
    """End-to-end pipeline: load stories, extract per-emotion residual-stream
    activations from the model, denoise the difference vectors against neutral
    text, run logit-lens and PCA analyses, and write results to OUTPUT_DIR.
    """
    print("=== Emotion Vector Extraction Experiment ===\n")

    # Load stories
    stories = load_stories()
    print(f"Loaded {sum(len(v) for v in stories.values())} stories across {len(stories)} emotions\n")

    # Load model
    print(f"Loading model {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Gemma4 is multimodal — AutoModelForCausalLM maps to ForConditionalGeneration
    # which is correct, but we need to handle the nested language_model layers
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,  # Gemma4 native dtype
        device_map="auto",
    )
    model.eval()

    # Multimodal checkpoints nest the text decoder under language_model;
    # otherwise the layers sit directly at model.model.layers.
    if hasattr(model.model, 'language_model'):
        num_layers = len(model.model.language_model.layers)
    else:
        num_layers = len(model.model.layers)
    target_layer = int(num_layers * 2 / 3)  # 2/3 depth
    print(f"Model loaded. {num_layers} layers, target layer: {target_layer}\n")

    # Attach hooks
    activations_dict, hooks = get_residual_stream_hooks(model)

    # Extract activations for each emotion
    print("Extracting activations...")
    emotion_activations = defaultdict(list)
    total = sum(len(v) for v in stories.values())
    done = 0

    for emotion, story_list in stories.items():
        for story in story_list:
            # extract_activations returns None when the hook captured nothing;
            # such stories are silently skipped.
            act = extract_activations(model, tokenizer, story, activations_dict, target_layer)
            if act is not None:
                emotion_activations[emotion].append(act)
            done += 1
            if done % 50 == 0:
                print(f" [{done}/{total}]")

    print(f" Extracted activations for {len(emotion_activations)} emotions\n")

    # Extract neutral activations for denoising
    print("Extracting neutral activations for denoising...")
    neutral_texts = [
        "The meeting is scheduled for 3pm tomorrow.",
        "Please find the attached document.",
        "The temperature today is 22 degrees Celsius.",
        "The project deadline has been moved to next Friday.",
        "The store is located on the corner of Main Street.",
        "Chapter 3 discusses the economic implications.",
        "The software update includes several bug fixes.",
        "The report contains data from the past quarter.",
        "The committee will review the proposal next week.",
        "The library opens at 9am on weekdays.",
    ] * 5  # 50 neutral texts

    neutral_activations = []
    for text in neutral_texts:
        act = extract_activations(model, tokenizer, text, activations_dict, target_layer)
        if act is not None:
            neutral_activations.append(act)
    print(f" {len(neutral_activations)} neutral activations collected\n")

    # Compute emotion vectors (per-emotion mean minus pooled mean).
    print("Computing emotion vectors...")
    raw_vectors, global_mean = compute_emotion_vectors(dict(emotion_activations))
    print(f" {len(raw_vectors)} raw emotion vectors computed")

    # Denoise: project out the dominant PCA directions of neutral text.
    print("Denoising...")
    denoised_vectors = denoise_vectors(raw_vectors, neutral_activations)

    # Logit Lens
    print("\nRunning Logit Lens analysis...")
    logit_results = logit_lens(model, tokenizer, denoised_vectors)

    print("\n=== Logit Lens Results ===")
    for emotion in sorted(logit_results.keys()):
        top = logit_results[emotion]["top"][:5]
        bottom = logit_results[emotion]["bottom"][:5]
        top_str = ", ".join([f"{t[0].strip()}({t[1]:.1f})" for t in top])
        bottom_str = ", ".join([f"{t[0].strip()}({t[1]:.1f})" for t in bottom])
        print(f" {emotion:12s} ↑ {top_str}")
        print(f" {' ':12s} ↓ {bottom_str}")

    # PCA analysis
    print("\nRunning PCA analysis...")
    pca_results = pca_analysis(denoised_vectors)
    print(f" PC1 explains {pca_results['explained_variance_pc1']*100:.1f}% variance")
    print(f" PC2 explains {pca_results['explained_variance_pc2']*100:.1f}% variance")

    print("\n=== Emotion Space (PC1 vs PC2) ===")
    for i, emotion in enumerate(pca_results["emotions"]):
        pc1 = pca_results["pc1"][i]
        pc2 = pca_results["pc2"][i]
        print(f" {emotion:12s} PC1={pc1:+.3f} PC2={pc2:+.3f}")

    # Save results (JSON summary + raw vectors).
    results = {
        "model": MODEL_ID,
        "target_layer": target_layer,
        "num_layers": num_layers,
        "num_emotions": len(denoised_vectors),
        "stories_per_emotion": {e: len(v) for e, v in stories.items()},
        "logit_lens": logit_results,
        "pca": pca_results,
    }

    results_file = os.path.join(OUTPUT_DIR, "experiment_results.json")
    with open(results_file, "w") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    # Save raw vectors as numpy
    vectors_file = os.path.join(OUTPUT_DIR, "emotion_vectors.npz")
    np.savez(vectors_file, **{e: v for e, v in denoised_vectors.items()})

    print(f"\nResults saved to {results_file}")
    print(f"Vectors saved to {vectors_file}")

    # Cleanup hooks
    for h in hooks:
        h.remove()

    print("\n=== EXPERIMENT COMPLETE ===")


if __name__ == "__main__":
    main()