Text Generation
Transformers
emotion-vectors
interpretability
mechanistic-interpretability
replication
gemma4
google
anthropic
valence-arousal
PCA
logit-lens
linear-probe
probing
emotion
functional-emotions
AI-safety
neuroscience
circumplex-model
activation-extraction
residual-stream
Eval Results (legacy)
#!/usr/bin/env python3
"""Extract emotion vectors from Gemma4 model using PyTorch hooks on residual stream."""
import json
import os
import warnings
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel  # NOTE(review): AutoModel appears unused in this file
from collections import defaultdict
warnings.filterwarnings("ignore")
# All paths are resolved relative to this script's own directory.
EXP_DIR = os.path.dirname(os.path.abspath(__file__))
# Input corpus: one JSON object per line with "emotion" and "text" keys (see load_stories).
STORIES_FILE = os.path.join(EXP_DIR, "emotion_stories.jsonl")
OUTPUT_DIR = os.path.join(EXP_DIR, "results")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Model config - use Gemma4 E4B from HuggingFace
MODEL_ID = "google/gemma-4-E4B-it"  # Gemma4 E4B (not gated, ~8GB float16)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
START_TOKEN = 50  # Start averaging from token 50 (emotion should be apparent)
def load_stories():
    """Read STORIES_FILE (JSONL) and group story texts by their emotion label.

    Returns a plain dict mapping emotion name -> list of story strings.
    """
    grouped = {}
    with open(STORIES_FILE, "r") as fh:
        for raw_line in fh:
            record = json.loads(raw_line)
            grouped.setdefault(record["emotion"], []).append(record["text"])
    return grouped
def get_residual_stream_hooks(model):
    """Register a forward hook on every transformer layer of `model`.

    Returns (activations, hooks):
      * activations — dict that fills with {"layer_i": hidden_states} on each
        forward pass; tensors are detached, moved to CPU, and cast to float32.
      * hooks — the handle objects, so callers can .remove() them later.
    """
    activations = {}

    def make_hook(label):
        def hook(module, inputs, output):
            # A layer may return (hidden_states, ...) or a bare tensor.
            hidden = output[0] if isinstance(output, tuple) else output
            activations[label] = hidden.detach().cpu().float()
        return hook

    # Gemma4 multimodal checkpoints nest text layers under
    # model.model.language_model; standard models expose model.model.layers.
    base = model.model
    layers = base.language_model.layers if hasattr(base, 'language_model') else base.layers
    handles = [
        layer.register_forward_hook(make_hook(f"layer_{idx}"))
        for idx, layer in enumerate(layers)
    ]
    return activations, handles
def extract_activations(model, tokenizer, text, activations_dict, target_layer):
    """Forward `text` through `model` and return the mean activation (numpy)
    captured by the hook at `target_layer`, or None if nothing was captured.

    When the sequence is long enough, tokens before START_TOKEN are excluded
    from the mean so the average reflects the span where the emotion is
    established; short texts fall back to averaging every token.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        model(**encoded)
    captured = activations_dict.get(f"layer_{target_layer}")
    if captured is None:
        return None
    hidden = captured[0]  # drop batch dim -> (seq_len, hidden_dim)
    # Skip the first START_TOKEN positions unless the text is too short.
    start = START_TOKEN if hidden.shape[0] > START_TOKEN else 0
    mean_act = hidden[start:].mean(dim=0)
    activations_dict.clear()  # reset capture dict for the next forward pass
    return mean_act.numpy()
def compute_emotion_vectors(emotion_activations):
    """Build mean-difference emotion vectors.

    For each emotion e: vector[e] = mean(activations[e]) - global_mean, where
    global_mean is taken over every activation of every emotion. Returns
    (emotion_vectors, global_mean).
    """
    pooled = [act for acts in emotion_activations.values() for act in acts]
    global_mean = np.mean(pooled, axis=0)
    emotion_vectors = {
        emotion: np.mean(acts, axis=0) - global_mean
        for emotion, acts in emotion_activations.items()
    }
    return emotion_vectors, global_mean
def denoise_vectors(emotion_vectors, neutral_activations, variance_threshold=0.5):
    """Remove generic (non-emotional) directions from the emotion vectors.

    Runs PCA (via SVD) on the neutral-text activations, takes the leading
    components that together explain at least `variance_threshold` of the
    variance, and projects each emotion vector onto their orthogonal
    complement. Returns the input dict unchanged when no neutral
    activations are provided.
    """
    if len(neutral_activations) == 0:
        return emotion_vectors
    centered = np.stack(neutral_activations)
    centered = centered - centered.mean(axis=0)
    # SVD of the centered neutral matrix; rows of Vt are principal directions.
    _, singular, Vt = np.linalg.svd(centered, full_matrices=False)
    component_var = singular ** 2
    cumulative = np.cumsum(component_var) / component_var.sum()
    n_components = np.searchsorted(cumulative, variance_threshold) + 1
    print(f"Denoising: projecting out {n_components} components (explain {variance_threshold*100}% variance)")
    noise_basis = Vt[:n_components].T  # (hidden_dim, n_components)
    return {
        emotion: vec - noise_basis @ (noise_basis.T @ vec)
        for emotion, vec in emotion_vectors.items()
    }
def logit_lens(model, tokenizer, emotion_vectors, top_k=10):
    """Read each emotion vector through the unembedding matrix.

    Uses model.lm_head when present, otherwise falls back to the (tied)
    input embedding table. Returns, per emotion, the `top_k` tokens with the
    highest and lowest projected logits:
    {emotion: {"top": [(token, logit), ...], "bottom": [...]}}.
    """
    # Select the (vocab, hidden) weight matrix to project through.
    if hasattr(model, 'lm_head'):
        weight = model.lm_head.weight
    elif hasattr(model.model, 'language_model'):
        weight = model.model.language_model.embed_tokens.weight
    else:
        weight = model.model.embed_tokens.weight
    W = weight.detach().cpu().float().numpy()

    results = {}
    for emotion, vec in emotion_vectors.items():
        logits = W @ vec
        order = np.argsort(logits)

        def decode(indices):
            return [(tokenizer.decode([i]), float(logits[i])) for i in indices]

        results[emotion] = {
            "top": decode(order[-top_k:][::-1]),
            "bottom": decode(order[:top_k]),
        }
    return results
def pca_analysis(emotion_vectors):
    """Project the emotion vectors onto their first two principal components.

    Returns a dict with the emotion ordering, per-emotion PC1/PC2
    coordinates, and the fraction of total variance each of the two
    components explains.
    """
    names = list(emotion_vectors.keys())
    data = np.stack([emotion_vectors[name] for name in names])
    data = data - data.mean(axis=0)  # center before SVD
    _, singular, Vt = np.linalg.svd(data, full_matrices=False)
    coords = data @ Vt[:2].T  # (n_emotions, 2)
    var_fraction = (singular[:2] ** 2) / (singular ** 2).sum()
    return {
        "emotions": names,
        "pc1": coords[:, 0].tolist(),
        "pc2": coords[:, 1].tolist(),
        "explained_variance_pc1": float(var_fraction[0]),
        "explained_variance_pc2": float(var_fraction[1]),
    }
def main():
    """Run the full experiment pipeline.

    Steps: load emotion stories, load the Gemma4 model, hook every layer,
    capture per-story activations at 2/3 model depth, build mean-difference
    emotion vectors, denoise them against neutral-text PCA components, run
    logit-lens and PCA analyses, and save JSON + .npz results under results/.
    """
    print("=== Emotion Vector Extraction Experiment ===\n")
    # Load stories
    stories = load_stories()
    print(f"Loaded {sum(len(v) for v in stories.values())} stories across {len(stories)} emotions\n")
    # Load model
    print(f"Loading model {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Gemma4 is multimodal — AutoModelForCausalLM maps to ForConditionalGeneration
    # which is correct, but we need to handle the nested language_model layers
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,  # Gemma4 native dtype
        device_map="auto",
    )
    model.eval()
    # Multimodal checkpoints nest the text stack one attribute deeper.
    if hasattr(model.model, 'language_model'):
        num_layers = len(model.model.language_model.layers)
    else:
        num_layers = len(model.model.layers)
    target_layer = int(num_layers * 2 / 3)  # 2/3 depth
    print(f"Model loaded. {num_layers} layers, target layer: {target_layer}\n")
    # Attach hooks
    activations_dict, hooks = get_residual_stream_hooks(model)
    # Extract activations for each emotion
    print("Extracting activations...")
    emotion_activations = defaultdict(list)
    total = sum(len(v) for v in stories.values())
    done = 0
    for emotion, story_list in stories.items():
        for story in story_list:
            act = extract_activations(model, tokenizer, story, activations_dict, target_layer)
            if act is not None:
                emotion_activations[emotion].append(act)
            done += 1
            if done % 50 == 0:
                print(f" [{done}/{total}]")
    print(f" Extracted activations for {len(emotion_activations)} emotions\n")
    # Extract neutral activations for denoising
    print("Extracting neutral activations for denoising...")
    neutral_texts = [
        "The meeting is scheduled for 3pm tomorrow.",
        "Please find the attached document.",
        "The temperature today is 22 degrees Celsius.",
        "The project deadline has been moved to next Friday.",
        "The store is located on the corner of Main Street.",
        "Chapter 3 discusses the economic implications.",
        "The software update includes several bug fixes.",
        "The report contains data from the past quarter.",
        "The committee will review the proposal next week.",
        "The library opens at 9am on weekdays.",
    ] * 5  # 50 neutral texts
    neutral_activations = []
    for text in neutral_texts:
        act = extract_activations(model, tokenizer, text, activations_dict, target_layer)
        if act is not None:
            neutral_activations.append(act)
    print(f" {len(neutral_activations)} neutral activations collected\n")
    # Compute emotion vectors
    print("Computing emotion vectors...")
    raw_vectors, global_mean = compute_emotion_vectors(dict(emotion_activations))
    print(f" {len(raw_vectors)} raw emotion vectors computed")
    # Denoise
    print("Denoising...")
    denoised_vectors = denoise_vectors(raw_vectors, neutral_activations)
    # Logit Lens
    print("\nRunning Logit Lens analysis...")
    logit_results = logit_lens(model, tokenizer, denoised_vectors)
    print("\n=== Logit Lens Results ===")
    for emotion in sorted(logit_results.keys()):
        top = logit_results[emotion]["top"][:5]
        bottom = logit_results[emotion]["bottom"][:5]
        top_str = ", ".join([f"{t[0].strip()}({t[1]:.1f})" for t in top])
        bottom_str = ", ".join([f"{t[0].strip()}({t[1]:.1f})" for t in bottom])
        print(f" {emotion:12s} ↑ {top_str}")
        print(f" {' ':12s} ↓ {bottom_str}")
    # PCA analysis
    print("\nRunning PCA analysis...")
    pca_results = pca_analysis(denoised_vectors)
    print(f" PC1 explains {pca_results['explained_variance_pc1']*100:.1f}% variance")
    print(f" PC2 explains {pca_results['explained_variance_pc2']*100:.1f}% variance")
    print("\n=== Emotion Space (PC1 vs PC2) ===")
    for i, emotion in enumerate(pca_results["emotions"]):
        pc1 = pca_results["pc1"][i]
        pc2 = pca_results["pc2"][i]
        print(f" {emotion:12s} PC1={pc1:+.3f} PC2={pc2:+.3f}")
    # Save results
    results = {
        "model": MODEL_ID,
        "target_layer": target_layer,
        "num_layers": num_layers,
        "num_emotions": len(denoised_vectors),
        "stories_per_emotion": {e: len(v) for e, v in stories.items()},
        "logit_lens": logit_results,
        "pca": pca_results,
    }
    results_file = os.path.join(OUTPUT_DIR, "experiment_results.json")
    with open(results_file, "w") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    # Save raw vectors as numpy
    vectors_file = os.path.join(OUTPUT_DIR, "emotion_vectors.npz")
    np.savez(vectors_file, **{e: v for e, v in denoised_vectors.items()})
    print(f"\nResults saved to {results_file}")
    print(f"Vectors saved to {vectors_file}")
    # Cleanup hooks
    # NOTE(review): hooks are removed only on the success path — an exception
    # earlier in main() would leave them attached to the model.
    for h in hooks:
        h.remove()
    print("\n=== EXPERIMENT COMPLETE ===")


if __name__ == "__main__":
    main()