File size: 11,150 Bytes
051c56b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
#!/usr/bin/env python3
"""Extract emotion vectors from Gemma4 model using PyTorch hooks on residual stream."""

import json
import os
import warnings
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from collections import defaultdict

warnings.filterwarnings("ignore")

# Paths are resolved relative to this script's directory.
EXP_DIR = os.path.dirname(os.path.abspath(__file__))
STORIES_FILE = os.path.join(EXP_DIR, "emotion_stories.jsonl")  # one JSON object per line: {"emotion": ..., "text": ...}
OUTPUT_DIR = os.path.join(EXP_DIR, "results")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model config - use Gemma4 E4B from HuggingFace
# NOTE(review): "google/gemma-4-E4B-it" does not look like a published HF repo
# id — the E4B naming matches "google/gemma-3n-E4B-it"; confirm before running.
# Also, Gemma checkpoints are typically gated and require an HF token, so the
# "not gated" claim in the trailing comment should be verified.
MODEL_ID = "google/gemma-4-E4B-it"  # Gemma4 E4B (not gated, ~8GB float16)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
START_TOKEN = 50  # Start averaging from token 50 (emotion should be apparent)


def load_stories():
    """Read the stories JSONL file and group story texts by emotion label.

    Returns a plain dict mapping emotion name -> list of story strings.
    """
    grouped = {}
    with open(STORIES_FILE, "r") as f:
        for raw_line in f:
            record = json.loads(raw_line)
            grouped.setdefault(record["emotion"], []).append(record["text"])
    return grouped


def get_residual_stream_hooks(model):
    """Register a forward hook on every transformer layer of `model`.

    Returns (activations, hooks): `activations` is a dict that each hook
    fills with the layer's hidden-state output under the key "layer_{i}"
    (detached, moved to CPU, cast to float32); `hooks` holds the handles so
    the caller can remove them when done.
    """
    activations = {}

    def _capture(name):
        def _hook(module, inputs, output):
            # A layer may return a tuple whose first element is the hidden
            # states, or the hidden-states tensor directly.
            hidden = output[0] if isinstance(output, tuple) else output
            activations[name] = hidden.detach().cpu().float()
        return _hook

    # Multimodal checkpoints nest the text stack under `language_model`;
    # text-only checkpoints expose `model.model.layers` directly.
    inner = model.model
    if hasattr(inner, 'language_model'):
        layers = inner.language_model.layers
    else:
        layers = inner.layers

    hooks = [
        layer.register_forward_hook(_capture(f"layer_{idx}"))
        for idx, layer in enumerate(layers)
    ]
    return activations, hooks


def extract_activations(model, tokenizer, text, activations_dict, target_layer):
    """Forward `text` through `model` and return the mean activation captured
    at `target_layer`, or None if the hook stored nothing for that layer.

    The first START_TOKEN tokens are skipped (when the sequence is long
    enough) so the average reflects the later, emotionally-loaded part of
    the story. `activations_dict` is emptied afterwards so the next call
    starts clean.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)

    with torch.no_grad():
        model(**encoded)

    hidden = activations_dict.get(f"layer_{target_layer}")
    if hidden is None:
        return None

    # hidden: (1, seq_len, hidden_dim) as stored by the forward hook.
    tokens = hidden[0]
    if tokens.shape[0] > START_TOKEN:
        mean_act = tokens[START_TOKEN:].mean(dim=0)
    else:
        # Too short to drop a prefix — average every token instead.
        mean_act = tokens.mean(dim=0)

    # Reset captured state before the next forward pass.
    activations_dict.clear()

    return mean_act.numpy()


def compute_emotion_vectors(emotion_activations):
    """Build per-emotion direction vectors relative to the grand mean.

    emotion_vector[e] = mean(activations[e]) - mean(activations[all])

    The grand mean is taken over every individual activation sample (so
    emotions with more samples weigh more). Returns (emotion_vectors,
    global_mean).
    """
    pooled = [act for samples in emotion_activations.values() for act in samples]
    global_mean = np.mean(pooled, axis=0)

    emotion_vectors = {
        emotion: np.mean(samples, axis=0) - global_mean
        for emotion, samples in emotion_activations.items()
    }
    return emotion_vectors, global_mean


def denoise_vectors(emotion_vectors, neutral_activations, variance_threshold=0.5):
    """Project neutral-text principal components out of each emotion vector.

    The leading right-singular vectors of the (centered) neutral activations
    — enough to explain `variance_threshold` of their variance — are treated
    as noise directions and removed from every vector. If there are no
    neutral activations, the input dict is returned unchanged.
    """
    if len(neutral_activations) == 0:
        return emotion_vectors

    neutral = np.stack(neutral_activations)
    centered = neutral - neutral.mean(axis=0)

    # SVD of the centered neutral matrix; Vt rows are noise directions.
    _, singular, Vt = np.linalg.svd(centered, full_matrices=False)

    # Smallest number of leading components reaching the variance threshold.
    ratios = np.cumsum(singular ** 2) / (singular ** 2).sum()
    n_components = np.searchsorted(ratios, variance_threshold) + 1
    print(f"Denoising: projecting out {n_components} components (explain {variance_threshold*100}% variance)")

    noise_basis = Vt[:n_components].T  # (hidden_dim, n_components)

    return {
        emotion: vec - noise_basis @ (noise_basis.T @ vec)
        for emotion, vec in emotion_vectors.items()
    }


def logit_lens(model, tokenizer, emotion_vectors, top_k=10):
    """Score each emotion vector against the vocabulary via the unembedding.

    Prefers `model.lm_head`; otherwise falls back to the (often weight-tied)
    input embedding matrix, handling the nested multimodal layout. Returns,
    per emotion, the `top_k` most- and least-aligned tokens as
    (token_string, logit) pairs under the keys "top" and "bottom".
    """
    if hasattr(model, 'lm_head'):
        weight = model.lm_head.weight
    elif hasattr(model.model, 'language_model'):
        weight = model.model.language_model.embed_tokens.weight
    else:
        weight = model.model.embed_tokens.weight
    W = weight.detach().cpu().float().numpy()  # (vocab, hidden)

    results = {}
    for emotion, vec in emotion_vectors.items():
        logits = W @ vec
        order = np.argsort(logits)
        highest = order[-top_k:][::-1]   # descending by logit
        lowest = order[:top_k]           # ascending by logit

        results[emotion] = {
            "top": [(tokenizer.decode([i]), float(logits[i])) for i in highest],
            "bottom": [(tokenizer.decode([i]), float(logits[i])) for i in lowest],
        }

    return results


def pca_analysis(emotion_vectors):
    """Project the emotion vectors onto their first two principal components.

    Useful for spotting low-dimensional structure (e.g. valence/arousal
    style axes). Returns the emotion order, PC1/PC2 coordinates, and the
    fraction of total variance each of the two components explains.
    """
    emotions = list(emotion_vectors.keys())
    matrix = np.stack([emotion_vectors[e] for e in emotions])
    centered = matrix - matrix.mean(axis=0)

    # SVD of the centered matrix; Vt rows are the principal directions.
    _, singular, Vt = np.linalg.svd(centered, full_matrices=False)

    coords = centered @ Vt[:2].T  # (n_emotions, 2)
    var_ratio = (singular[:2] ** 2) / (singular ** 2).sum()

    return {
        "emotions": emotions,
        "pc1": coords[:, 0].tolist(),
        "pc2": coords[:, 1].tolist(),
        "explained_variance_pc1": float(var_ratio[0]),
        "explained_variance_pc2": float(var_ratio[1]),
    }


def main():
    """Run the full experiment: load stories, capture residual-stream
    activations at a target layer, derive denoised emotion vectors, and
    report logit-lens and PCA analyses, saving results to OUTPUT_DIR."""
    print("=== Emotion Vector Extraction Experiment ===\n")

    # Load stories grouped by emotion label from the JSONL file.
    stories = load_stories()
    print(f"Loaded {sum(len(v) for v in stories.values())} stories across {len(stories)} emotions\n")

    # Load model
    print(f"Loading model {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Gemma4 is multimodal — AutoModelForCausalLM maps to ForConditionalGeneration
    # which is correct, but we need to handle the nested language_model layers
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,  # Gemma4 native dtype
        device_map="auto",
    )
    model.eval()

    # Multimodal checkpoints nest the text layers under `language_model`;
    # text-only checkpoints expose `model.model.layers` directly.
    if hasattr(model.model, 'language_model'):
        num_layers = len(model.model.language_model.layers)
    else:
        num_layers = len(model.model.layers)
    target_layer = int(num_layers * 2 / 3)  # 2/3 depth
    print(f"Model loaded. {num_layers} layers, target layer: {target_layer}\n")

    # Attach forward hooks that record each layer's hidden-state output.
    activations_dict, hooks = get_residual_stream_hooks(model)

    # Extract one mean activation per story, grouped by emotion.
    print("Extracting activations...")
    emotion_activations = defaultdict(list)
    total = sum(len(v) for v in stories.values())
    done = 0

    for emotion, story_list in stories.items():
        for story in story_list:
            act = extract_activations(model, tokenizer, story, activations_dict, target_layer)
            if act is not None:
                emotion_activations[emotion].append(act)
            done += 1
            if done % 50 == 0:
                print(f"  [{done}/{total}]")

    print(f"  Extracted activations for {len(emotion_activations)} emotions\n")

    # Collect activations for emotionally neutral text; their principal
    # components serve as "noise" directions to project out later.
    print("Extracting neutral activations for denoising...")
    neutral_texts = [
        "The meeting is scheduled for 3pm tomorrow.",
        "Please find the attached document.",
        "The temperature today is 22 degrees Celsius.",
        "The project deadline has been moved to next Friday.",
        "The store is located on the corner of Main Street.",
        "Chapter 3 discusses the economic implications.",
        "The software update includes several bug fixes.",
        "The report contains data from the past quarter.",
        "The committee will review the proposal next week.",
        "The library opens at 9am on weekdays.",
    ] * 5  # 50 neutral texts

    neutral_activations = []
    for text in neutral_texts:
        act = extract_activations(model, tokenizer, text, activations_dict, target_layer)
        if act is not None:
            neutral_activations.append(act)
    print(f"  {len(neutral_activations)} neutral activations collected\n")

    # Emotion vector = per-emotion mean activation minus the grand mean.
    print("Computing emotion vectors...")
    raw_vectors, global_mean = compute_emotion_vectors(dict(emotion_activations))
    print(f"  {len(raw_vectors)} raw emotion vectors computed")

    # Remove neutral-text variance directions from the raw vectors.
    print("Denoising...")
    denoised_vectors = denoise_vectors(raw_vectors, neutral_activations)

    # Project vectors through the unembedding to find associated tokens.
    print("\nRunning Logit Lens analysis...")
    logit_results = logit_lens(model, tokenizer, denoised_vectors)

    print("\n=== Logit Lens Results ===")
    for emotion in sorted(logit_results.keys()):
        top = logit_results[emotion]["top"][:5]
        bottom = logit_results[emotion]["bottom"][:5]
        top_str = ", ".join([f"{t[0].strip()}({t[1]:.1f})" for t in top])
        bottom_str = ", ".join([f"{t[0].strip()}({t[1]:.1f})" for t in bottom])
        print(f"  {emotion:12s}{top_str}")
        print(f"  {' ':12s}{bottom_str}")

    # PCA over the emotion vectors themselves (one row per emotion).
    print("\nRunning PCA analysis...")
    pca_results = pca_analysis(denoised_vectors)
    print(f"  PC1 explains {pca_results['explained_variance_pc1']*100:.1f}% variance")
    print(f"  PC2 explains {pca_results['explained_variance_pc2']*100:.1f}% variance")

    print("\n=== Emotion Space (PC1 vs PC2) ===")
    for i, emotion in enumerate(pca_results["emotions"]):
        pc1 = pca_results["pc1"][i]
        pc2 = pca_results["pc2"][i]
        print(f"  {emotion:12s}  PC1={pc1:+.3f}  PC2={pc2:+.3f}")

    # Persist a JSON summary of the run configuration and analyses.
    results = {
        "model": MODEL_ID,
        "target_layer": target_layer,
        "num_layers": num_layers,
        "num_emotions": len(denoised_vectors),
        "stories_per_emotion": {e: len(v) for e, v in stories.items()},
        "logit_lens": logit_results,
        "pca": pca_results,
    }

    results_file = os.path.join(OUTPUT_DIR, "experiment_results.json")
    with open(results_file, "w") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    # Save the denoised vectors themselves, one named array per emotion.
    vectors_file = os.path.join(OUTPUT_DIR, "emotion_vectors.npz")
    np.savez(vectors_file, **{e: v for e, v in denoised_vectors.items()})

    print(f"\nResults saved to {results_file}")
    print(f"Vectors saved to {vectors_file}")

    # Cleanup hooks
    for h in hooks:
        h.remove()

    print("\n=== EXPERIMENT COMPLETE ===")


# Script entry point.
if __name__ == "__main__":
    main()