# emotion-vector-replication / extract_vectors.py
# Uploaded by rain1955 — "Add extract_vectors.py" (commit 051c56b, verified)
#!/usr/bin/env python3
"""Extract emotion vectors from Gemma4 model using PyTorch hooks on residual stream."""
import json
import os
import warnings
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from collections import defaultdict
warnings.filterwarnings("ignore")
# Paths are resolved relative to this script so it runs from any working directory.
EXP_DIR = os.path.dirname(os.path.abspath(__file__))
# One JSON object per line; load_stories() reads the "emotion" and "text" keys.
STORIES_FILE = os.path.join(EXP_DIR, "emotion_stories.jsonl")
OUTPUT_DIR = os.path.join(EXP_DIR, "results")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Model config - use Gemma4 E4B from HuggingFace
# NOTE(review): "google/gemma-4-E4B-it" — the E4B naming matches the Gemma 3n
# family (google/gemma-3n-E4B-it); confirm this repo id actually exists on the Hub.
MODEL_ID = "google/gemma-4-E4B-it" # Gemma4 E4B (not gated, ~8GB float16)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
START_TOKEN = 50 # Start averaging from token 50 (emotion should be apparent)
def load_stories():
    """Read the stories file and group story texts by their emotion label.

    Returns:
        dict mapping emotion name -> list of story strings.
    """
    grouped = {}
    with open(STORIES_FILE, "r") as fh:
        for raw in fh:
            record = json.loads(raw)
            grouped.setdefault(record["emotion"], []).append(record["text"])
    return grouped
def get_residual_stream_hooks(model):
    """Register a forward hook on every transformer layer to capture the
    residual-stream hidden states.

    Returns:
        (activations, hooks): `activations` is a dict the hooks fill with
        "layer_{i}" -> float32 CPU tensor after each forward pass; `hooks`
        is the list of handles so the caller can remove them later.
    """
    activations = {}

    def _capture(name):
        def _hook(_module, _inputs, output):
            # Layer modules may return (hidden_states, ...) or the tensor itself.
            hidden = output[0] if isinstance(output, tuple) else output
            activations[name] = hidden.detach().cpu().float()
        return _hook

    # Multimodal checkpoints nest the text stack under `language_model`;
    # text-only models expose `model.model.layers` directly.
    base = model.model
    layers = base.language_model.layers if hasattr(base, 'language_model') else base.layers
    hooks = [
        layer.register_forward_hook(_capture(f"layer_{idx}"))
        for idx, layer in enumerate(layers)
    ]
    return activations, hooks
def extract_activations(model, tokenizer, text, activations_dict, target_layer):
    """Forward `text` through the model and return the mean residual-stream
    activation at `target_layer` as a float32 numpy vector.

    Tokens before START_TOKEN are skipped when the sequence is long enough, so
    the average reflects the part of the story where the emotion has surfaced;
    short texts fall back to averaging every token. Returns None when no
    activation was captured for the requested layer.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        model(**encoded)

    hidden = activations_dict.get(f"layer_{target_layer}")
    if hidden is None:
        return None

    # hidden: (1, seq_len, hidden_dim) — average over the token axis.
    tokens = hidden[0]
    start = START_TOKEN if tokens.shape[0] > START_TOKEN else 0
    mean_act = tokens[start:].mean(dim=0)

    # Reset the shared hook dict so the next forward pass starts clean.
    activations_dict.clear()
    return mean_act.numpy()
def compute_emotion_vectors(emotion_activations):
    """Build one direction vector per emotion.

    emotion_vector[e] = mean(activations[e]) - mean(activations[all]),
    i.e. how that emotion's average residual state deviates from the grand
    mean over every collected activation.

    Returns:
        (emotion_vectors, global_mean)
    """
    # Grand mean over every activation, regardless of emotion label.
    pooled = [act for acts in emotion_activations.values() for act in acts]
    global_mean = np.mean(pooled, axis=0)

    emotion_vectors = {
        emotion: np.mean(acts, axis=0) - global_mean
        for emotion, acts in emotion_activations.items()
    }
    return emotion_vectors, global_mean
def denoise_vectors(emotion_vectors, neutral_activations, variance_threshold=0.5):
    """Project generic "neutral text" directions out of each emotion vector.

    The leading principal components of the neutral activations — enough to
    explain `variance_threshold` of their variance — are treated as noise
    and removed from every emotion vector. With no neutral activations the
    input dict is returned untouched.
    """
    if len(neutral_activations) == 0:
        return emotion_vectors

    neutral = np.stack(neutral_activations)
    centered = neutral - neutral.mean(axis=0)
    _, singular, Vt = np.linalg.svd(centered, full_matrices=False)

    # Leading components needed to reach the variance threshold.
    explained = np.cumsum(singular ** 2) / (singular ** 2).sum()
    n_components = np.searchsorted(explained, variance_threshold) + 1
    print(f"Denoising: projecting out {n_components} components (explain {variance_threshold*100}% variance)")

    basis = Vt[:n_components].T  # (hidden_dim, n_components)
    return {
        emotion: vec - basis @ (basis.T @ vec)
        for emotion, vec in emotion_vectors.items()
    }
def logit_lens(model, tokenizer, emotion_vectors, top_k=10):
    """Project each emotion vector through the unembedding matrix and report
    the `top_k` most and least promoted vocabulary tokens.

    Returns:
        dict: emotion -> {"top": [(token, logit), ...], "bottom": [...]},
        both lists sorted from most extreme logit inwards.
    """
    # Prefer the real unembedding head; fall back to the input embedding
    # table (checkpoints without a separate lm_head).
    if hasattr(model, 'lm_head'):
        unembed = model.lm_head.weight
    elif hasattr(model.model, 'language_model'):
        unembed = model.model.language_model.embed_tokens.weight
    else:
        unembed = model.model.embed_tokens.weight
    W = unembed.detach().cpu().float().numpy()  # (vocab, hidden)

    results = {}
    for emotion, vec in emotion_vectors.items():
        logits = W @ vec
        order = np.argsort(logits)  # ascending

        def _decode(indices):
            return [(tokenizer.decode([i]), float(logits[i])) for i in indices]

        results[emotion] = {
            "top": _decode(order[-top_k:][::-1]),
            "bottom": _decode(order[:top_k]),
        }
    return results
def pca_analysis(emotion_vectors):
    """Run a 2-component PCA over the emotion vectors.

    Returns a dict with the emotion ordering, each emotion's coordinates on
    the first two principal components, and the variance fraction each
    component explains — useful for checking a valence/arousal-like layout.
    """
    emotions = list(emotion_vectors.keys())
    matrix = np.stack([emotion_vectors[name] for name in emotions])
    centered = matrix - matrix.mean(axis=0)

    _, singular, Vt = np.linalg.svd(centered, full_matrices=False)
    coords = centered @ Vt[:2].T  # (n_emotions, 2)
    var_fraction = (singular[:2] ** 2) / (singular ** 2).sum()

    return {
        "emotions": emotions,
        "pc1": coords[:, 0].tolist(),
        "pc2": coords[:, 1].tolist(),
        "explained_variance_pc1": float(var_fraction[0]),
        "explained_variance_pc2": float(var_fraction[1]),
    }
def main():
    """Run the full experiment end to end.

    Steps: load emotion-labelled stories, load the model, hook every
    transformer layer, collect a mean activation per story at the target
    layer, build mean-difference emotion vectors, denoise them against
    neutral text, then run logit-lens and PCA analyses and save results
    under OUTPUT_DIR.
    """
    print("=== Emotion Vector Extraction Experiment ===\n")
    # Load stories
    stories = load_stories()
    print(f"Loaded {sum(len(v) for v in stories.values())} stories across {len(stories)} emotions\n")
    # Load model
    print(f"Loading model {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Gemma4 is multimodal — AutoModelForCausalLM maps to ForConditionalGeneration
    # which is correct, but we need to handle the nested language_model layers
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16, # Gemma4 native dtype
        device_map="auto",
    )
    model.eval()
    # Layer count lives in a different place for multimodal vs text-only models.
    if hasattr(model.model, 'language_model'):
        num_layers = len(model.model.language_model.layers)
    else:
        num_layers = len(model.model.layers)
    target_layer = int(num_layers * 2 / 3) # 2/3 depth
    print(f"Model loaded. {num_layers} layers, target layer: {target_layer}\n")
    # Attach hooks
    # NOTE(review): hooks copy *every* layer's hidden states to CPU on each
    # forward pass even though only `target_layer` is read afterwards —
    # correct but wasteful; hooking just the target layer would be cheaper.
    activations_dict, hooks = get_residual_stream_hooks(model)
    # Extract activations for each emotion
    print("Extracting activations...")
    emotion_activations = defaultdict(list)
    total = sum(len(v) for v in stories.values())
    done = 0
    for emotion, story_list in stories.items():
        for story in story_list:
            act = extract_activations(model, tokenizer, story, activations_dict, target_layer)
            if act is not None:
                emotion_activations[emotion].append(act)
            done += 1
            if done % 50 == 0:
                print(f" [{done}/{total}]")
    print(f" Extracted activations for {len(emotion_activations)} emotions\n")
    # Extract neutral activations for denoising
    print("Extracting neutral activations for denoising...")
    # Emotionally flat filler texts; their principal components approximate
    # the "generic prose" directions that denoise_vectors() projects out.
    neutral_texts = [
        "The meeting is scheduled for 3pm tomorrow.",
        "Please find the attached document.",
        "The temperature today is 22 degrees Celsius.",
        "The project deadline has been moved to next Friday.",
        "The store is located on the corner of Main Street.",
        "Chapter 3 discusses the economic implications.",
        "The software update includes several bug fixes.",
        "The report contains data from the past quarter.",
        "The committee will review the proposal next week.",
        "The library opens at 9am on weekdays.",
    ] * 5 # 50 neutral texts
    neutral_activations = []
    for text in neutral_texts:
        act = extract_activations(model, tokenizer, text, activations_dict, target_layer)
        if act is not None:
            neutral_activations.append(act)
    print(f" {len(neutral_activations)} neutral activations collected\n")
    # Compute emotion vectors
    print("Computing emotion vectors...")
    raw_vectors, global_mean = compute_emotion_vectors(dict(emotion_activations))
    print(f" {len(raw_vectors)} raw emotion vectors computed")
    # Denoise
    print("Denoising...")
    denoised_vectors = denoise_vectors(raw_vectors, neutral_activations)
    # Logit Lens
    print("\nRunning Logit Lens analysis...")
    logit_results = logit_lens(model, tokenizer, denoised_vectors)
    print("\n=== Logit Lens Results ===")
    for emotion in sorted(logit_results.keys()):
        # Show only the 5 most extreme tokens per side for readability.
        top = logit_results[emotion]["top"][:5]
        bottom = logit_results[emotion]["bottom"][:5]
        top_str = ", ".join([f"{t[0].strip()}({t[1]:.1f})" for t in top])
        bottom_str = ", ".join([f"{t[0].strip()}({t[1]:.1f})" for t in bottom])
        print(f" {emotion:12s}{top_str}")
        print(f" {' ':12s}{bottom_str}")
    # PCA analysis
    print("\nRunning PCA analysis...")
    pca_results = pca_analysis(denoised_vectors)
    print(f" PC1 explains {pca_results['explained_variance_pc1']*100:.1f}% variance")
    print(f" PC2 explains {pca_results['explained_variance_pc2']*100:.1f}% variance")
    print("\n=== Emotion Space (PC1 vs PC2) ===")
    for i, emotion in enumerate(pca_results["emotions"]):
        pc1 = pca_results["pc1"][i]
        pc2 = pca_results["pc2"][i]
        print(f" {emotion:12s} PC1={pc1:+.3f} PC2={pc2:+.3f}")
    # Save results
    # (logit-lens tuples serialize as JSON arrays, which is fine here)
    results = {
        "model": MODEL_ID,
        "target_layer": target_layer,
        "num_layers": num_layers,
        "num_emotions": len(denoised_vectors),
        "stories_per_emotion": {e: len(v) for e, v in stories.items()},
        "logit_lens": logit_results,
        "pca": pca_results,
    }
    results_file = os.path.join(OUTPUT_DIR, "experiment_results.json")
    with open(results_file, "w") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    # Save raw vectors as numpy
    vectors_file = os.path.join(OUTPUT_DIR, "emotion_vectors.npz")
    np.savez(vectors_file, **{e: v for e, v in denoised_vectors.items()})
    print(f"\nResults saved to {results_file}")
    print(f"Vectors saved to {vectors_file}")
    # Cleanup hooks
    for h in hooks:
        h.remove()
    print("\n=== EXPERIMENT COMPLETE ===")
# Script entry point: run the full experiment when executed directly.
if __name__ == "__main__":
    main()