#!/usr/bin/env python3
"""Extract emotion vectors from Gemma4 model using PyTorch hooks on residual stream."""

import json
import os
import warnings
from collections import defaultdict

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel

warnings.filterwarnings("ignore")

EXP_DIR = os.path.dirname(os.path.abspath(__file__))
STORIES_FILE = os.path.join(EXP_DIR, "emotion_stories.jsonl")
OUTPUT_DIR = os.path.join(EXP_DIR, "results")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model config - use Gemma4 E4B from HuggingFace
MODEL_ID = "google/gemma-4-E4B-it"  # Gemma4 E4B (not gated, ~8GB float16)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
START_TOKEN = 50  # Start averaging from token 50 (emotion should be apparent)


def load_stories(path=STORIES_FILE):
    """Load the stories JSONL file and group story texts by emotion.

    Each non-blank line must be a JSON object with "emotion" and "text" keys.

    Args:
        path: JSONL file to read (defaults to STORIES_FILE).

    Returns:
        dict mapping emotion name -> list of story texts, in file order.
    """
    stories = defaultdict(list)
    # Explicit UTF-8: the platform default encoding may not be UTF-8.
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines in the JSONL file
            d = json.loads(line)
            stories[d["emotion"]].append(d["text"])
    return dict(stories)


def _get_decoder_layers(model):
    """Return the text decoder layer list for multimodal or text-only layouts."""
    # Gemma4 multimodal: text layers at model.model.language_model.layers
    # Gemma3/standard: text layers at model.model.layers
    if hasattr(model.model, "language_model"):
        return model.model.language_model.layers
    return model.model.layers


def get_residual_stream_hooks(model, layer_indices=None):
    """Attach forward hooks that capture decoder-layer outputs (residual stream).

    Args:
        model: the loaded model.
        layer_indices: optional iterable of layer indices to hook. ``None``
            (the default) hooks every layer. Hooking only the layers you read
            avoids copying every layer's activations to the CPU on each
            forward pass.

    Returns:
        (activations, hooks): ``activations`` is a dict that each forward pass
        fills with ``"layer_{i}" -> detached float32 CPU tensor``; ``hooks``
        are the handles to ``.remove()`` when done.
    """
    activations = {}

    def make_hook(name):
        def hook(module, inputs, output):
            # output is typically (hidden_states, ...) or just hidden_states
            hidden = output[0] if isinstance(output, tuple) else output
            activations[name] = hidden.detach().cpu().float()

        return hook

    wanted = None if layer_indices is None else set(layer_indices)
    hooks = []
    for i, layer in enumerate(_get_decoder_layers(model)):
        if wanted is not None and i not in wanted:
            continue
        hooks.append(layer.register_forward_hook(make_hook(f"layer_{i}")))
    return activations, hooks


def extract_activations(model, tokenizer, text, activations_dict, target_layer):
    """Run ``text`` through the model and return the mean activation at ``target_layer``.

    Positions ``START_TOKEN`` onward are averaged. Texts with at most
    ``START_TOKEN`` tokens are averaged over every position except the first.
    When the tokenizer prepends a BOS token (assumed for Gemma tokenizers;
    TODO confirm for this checkpoint), position 0 holds no content from the
    text. Long texts already exclude it via START_TOKEN, so skipping it keeps
    short-text means (e.g. the neutral baseline) comparable.

    Args:
        activations_dict: the dict returned by ``get_residual_stream_hooks``;
            it is always cleared before returning.

    Returns:
        1-D float32 numpy array (hidden_dim,), or ``None`` if the target layer
        was not captured (e.g. it was not hooked).

    Raises:
        ValueError: if the captured layer output is not (batch, seq, hidden).
    """
    # Send inputs to the model's (first) device rather than a fixed constant,
    # which stays correct with device_map="auto".
    device = getattr(model, "device", DEVICE)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
    try:
        with torch.no_grad():
            model(**inputs)
        hidden = activations_dict.get(f"layer_{target_layer}")
    finally:
        # Clear for the next run on every path, including errors.
        activations_dict.clear()

    if hidden is None:
        return None
    if hidden.dim() != 3:
        # NOTE(review): some decoder layers (e.g. ones carrying several parallel
        # hidden-state streams) may return a higher-rank tensor. Averaging
        # below assumes (batch, seq, hidden), so fail loudly instead of
        # averaging the wrong axis.
        raise ValueError(
            f"expected layer output of shape (batch, seq, hidden), got {tuple(hidden.shape)}"
        )

    seq_len = hidden.shape[1]
    if seq_len > START_TOKEN:
        tokens = hidden[0, START_TOKEN:]
    elif seq_len > 1:
        tokens = hidden[0, 1:]  # short text: skip position 0 (see docstring)
    else:
        tokens = hidden[0]
    return tokens.mean(dim=0).numpy()


def compute_emotion_vectors(emotion_activations):
    """Compute per-emotion difference-of-means vectors.

    emotion_vector[e] = mean(activations[e]) - mean(activations[all])

    The global mean pools every individual activation, so emotions with more
    stories weigh more in it. Emotions with no activations are skipped.

    Returns:
        (emotion_vectors, global_mean)

    Raises:
        ValueError: if no activations were provided at all.
    """
    all_acts = [act for acts in emotion_activations.values() for act in acts]
    if not all_acts:
        raise ValueError("no activations to compute emotion vectors from")
    global_mean = np.mean(all_acts, axis=0)

    emotion_vectors = {
        emotion: np.mean(acts, axis=0) - global_mean
        for emotion, acts in emotion_activations.items()
        if len(acts) > 0
    }
    return emotion_vectors, global_mean


def denoise_vectors(emotion_vectors, neutral_activations, variance_threshold=0.5):
    """Project out the top principal components of neutral-text activations.

    Removes the fewest leading principal components whose cumulative explained
    variance reaches ``variance_threshold``. Their span is subtracted from
    every emotion vector.

    Returns ``emotion_vectors`` unchanged when there is no neutral variance to
    remove (no neutral activations, or all of them identical).
    """
    if len(neutral_activations) == 0:
        return emotion_vectors

    neutral_matrix = np.stack(neutral_activations)
    neutral_centered = neutral_matrix - neutral_matrix.mean(axis=0)
    _, S, Vt = np.linalg.svd(neutral_centered, full_matrices=False)

    total_var = (S ** 2).sum()
    if total_var <= 0:
        # Fewer than two distinct neutral activations: nothing to project out,
        # and the variance ratio below would be 0/0.
        return emotion_vectors

    cumvar = np.cumsum(S ** 2) / total_var
    # First index whose cumulative ratio reaches the threshold, +1 for a count.
    # Clamp in case rounding keeps cumvar[-1] just below the threshold.
    n_components = min(int(np.searchsorted(cumvar, variance_threshold)) + 1, len(S))
    print(f"Denoising: projecting out {n_components} components (explain {variance_threshold*100}% variance)")

    V_noise = Vt[:n_components].T  # (hidden_dim, n_components)
    return {
        emotion: vec - V_noise @ (V_noise.T @ vec)
        for emotion, vec in emotion_vectors.items()
    }


def logit_lens(model, tokenizer, emotion_vectors, top_k=10):
    """Project emotion vectors through the unembedding to see associated tokens.

    Uses ``lm_head`` when present, otherwise the input embedding matrix.

    NOTE(review): raw directions are multiplied by the unembedding without the
    model's final norm, so this is a rough lens. Confirm whether the final
    norm's scale should be applied for this architecture.

    Returns:
        dict emotion -> {"top": [(token, logit), ...], "bottom": [...]}, each
        list at most ``top_k`` long (highest first / lowest first).
    """
    if hasattr(model, "lm_head"):
        W = model.lm_head.weight
    elif hasattr(model.model, "language_model"):
        W = model.model.language_model.embed_tokens.weight
    else:
        W = model.model.embed_tokens.weight
    W = W.detach().cpu().float().numpy()  # (vocab, hidden)

    k = max(int(top_k), 0)  # top_k=0 must give empty lists, not the whole vocab
    results = {}
    for emotion, vec in emotion_vectors.items():
        logits = W @ vec
        order = np.argsort(logits)  # ascending; sort once for both ends
        top_indices = order[::-1][:k]
        bottom_indices = order[:k]
        results[emotion] = {
            "top": [(tokenizer.decode([int(i)]), float(logits[i])) for i in top_indices],
            "bottom": [(tokenizer.decode([int(i)]), float(logits[i])) for i in bottom_indices],
        }
    return results


def pca_analysis(emotion_vectors):
    """PCA on emotion vectors to find valence/arousal structure.

    Returns the projections of the mean-centered vectors on the first two
    principal components and each component's explained-variance ratio.

    Raises:
        ValueError: if fewer than two emotion vectors are given.
    """
    emotions = list(emotion_vectors.keys())
    if len(emotions) < 2:
        raise ValueError("PCA analysis needs at least two emotion vectors")

    matrix = np.stack([emotion_vectors[e] for e in emotions])
    matrix_centered = matrix - matrix.mean(axis=0)
    _, S, Vt = np.linalg.svd(matrix_centered, full_matrices=False)

    projections = matrix_centered @ Vt[:2].T  # (n_emotions, 2)
    total_var = (S ** 2).sum()
    if total_var > 0:
        explained_variance = (S[:2] ** 2) / total_var
    else:
        explained_variance = np.zeros(2)  # all vectors identical: nothing explained

    return {
        "emotions": emotions,
        "pc1": projections[:, 0].tolist(),
        "pc2": projections[:, 1].tolist(),
        "explained_variance_pc1": float(explained_variance[0]),
        "explained_variance_pc2": float(explained_variance[1]),
    }


def main():
    print("=== Emotion Vector Extraction Experiment ===\n")

    # Load stories
    stories = load_stories()
    total = sum(len(v) for v in stories.values())
    print(f"Loaded {total} stories across {len(stories)} emotions\n")

    # Load model
    print(f"Loading model {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Gemma4 is multimodal — AutoModelForCausalLM maps to ForConditionalGeneration
    # which is correct, but we need to handle the nested language_model layers
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,  # Gemma4 native dtype
        device_map="auto",
    )
    model.eval()

    num_layers = len(_get_decoder_layers(model))
    target_layer = int(num_layers * 2 / 3)  # 2/3 depth
    print(f"Model loaded. {num_layers} layers, target layer: {target_layer}\n")

    # Only the target layer is ever read, so only hook that one.
    activations_dict, hooks = get_residual_stream_hooks(model, layer_indices=[target_layer])
    try:
        # Extract activations for each emotion
        print("Extracting activations...")
        emotion_activations = defaultdict(list)
        done = 0
        for emotion, story_list in stories.items():
            for story in story_list:
                act = extract_activations(model, tokenizer, story, activations_dict, target_layer)
                if act is not None:
                    emotion_activations[emotion].append(act)
                done += 1
                if done % 50 == 0:
                    print(f" [{done}/{total}]")
        print(f" Extracted activations for {len(emotion_activations)} emotions\n")

        # Extract neutral activations for denoising. Each text is run once:
        # repeating identical texts yields identical activations, which changes
        # neither the SVD directions nor the variance ratios used for denoising.
        print("Extracting neutral activations for denoising...")
        neutral_texts = [
            "The meeting is scheduled for 3pm tomorrow.",
            "Please find the attached document.",
            "The temperature today is 22 degrees Celsius.",
            "The project deadline has been moved to next Friday.",
            "The store is located on the corner of Main Street.",
            "Chapter 3 discusses the economic implications.",
            "The software update includes several bug fixes.",
            "The report contains data from the past quarter.",
            "The committee will review the proposal next week.",
            "The library opens at 9am on weekdays.",
        ]
        neutral_activations = []
        for text in neutral_texts:
            act = extract_activations(model, tokenizer, text, activations_dict, target_layer)
            if act is not None:
                neutral_activations.append(act)
        print(f" {len(neutral_activations)} neutral activations collected\n")
    finally:
        # Remove hooks even if extraction fails; no further forward passes run.
        for h in hooks:
            h.remove()

    # Compute emotion vectors
    print("Computing emotion vectors...")
    raw_vectors, global_mean = compute_emotion_vectors(dict(emotion_activations))
    print(f" {len(raw_vectors)} raw emotion vectors computed")

    # Denoise
    print("Denoising...")
    denoised_vectors = denoise_vectors(raw_vectors, neutral_activations)

    # Logit Lens
    print("\nRunning Logit Lens analysis...")
    logit_results = logit_lens(model, tokenizer, denoised_vectors)
    print("\n=== Logit Lens Results ===")
    for emotion in sorted(logit_results.keys()):
        top = logit_results[emotion]["top"][:5]
        bottom = logit_results[emotion]["bottom"][:5]
        top_str = ", ".join([f"{t[0].strip()}({t[1]:.1f})" for t in top])
        bottom_str = ", ".join([f"{t[0].strip()}({t[1]:.1f})" for t in bottom])
        print(f" {emotion:12s} ↑ {top_str}")
        print(f" {' ':12s} ↓ {bottom_str}")

    # PCA analysis
    print("\nRunning PCA analysis...")
    pca_results = pca_analysis(denoised_vectors)
    print(f" PC1 explains {pca_results['explained_variance_pc1']*100:.1f}% variance")
    print(f" PC2 explains {pca_results['explained_variance_pc2']*100:.1f}% variance")
    print("\n=== Emotion Space (PC1 vs PC2) ===")
    for emotion, pc1, pc2 in zip(pca_results["emotions"], pca_results["pc1"], pca_results["pc2"]):
        print(f" {emotion:12s} PC1={pc1:+.3f} PC2={pc2:+.3f}")

    # Save results
    results = {
        "model": MODEL_ID,
        "target_layer": target_layer,
        "num_layers": num_layers,
        "num_emotions": len(denoised_vectors),
        "stories_per_emotion": {e: len(v) for e, v in stories.items()},
        "logit_lens": logit_results,
        "pca": pca_results,
    }
    results_file = os.path.join(OUTPUT_DIR, "experiment_results.json")
    # ensure_ascii=False writes raw non-ASCII tokens, so the file must be UTF-8
    # regardless of the platform's default encoding.
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    # Save raw vectors as numpy
    vectors_file = os.path.join(OUTPUT_DIR, "emotion_vectors.npz")
    np.savez(vectors_file, **denoised_vectors)
    print(f"\nResults saved to {results_file}")
    print(f"Vectors saved to {vectors_file}")

    print("\n=== EXPERIMENT COMPLETE ===")


if __name__ == "__main__":
    main()