# api/backend/icl_service.py
# Author: gary-boon — "Add backend support for ICL emergence analysis" (commit 920a98d)
"""
In-Context Learning Analysis Service
Analyzes how examples influence model behavior during code generation.
"""
import torch
import numpy as np
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
from .icl_attention_extractor import AttentionExtractor
from .induction_head_detector import InductionHeadDetector, ICLEmergenceAnalysis
import logging
logger = logging.getLogger(__name__)
@dataclass
class ICLExample:
    """Represents a single in-context learning demonstration (input/output pair)."""
    input: str   # the example prompt text (NB: field name shadows the `input` builtin)
    output: str  # the completion the model should learn to produce for `input`
@dataclass
class ICLAnalysisResult:
    """Results from analyzing one generation run with a given number of examples."""
    shot_count: int                # number of in-context examples used (0 = zero-shot)
    generated_code: str            # decoded continuation, raw (no trimming)
    tokens: List[str]              # per-token decoded strings of the continuation
    confidence_scores: List[float] # top-1 probability at each generation step
    attention_from_examples: Dict[str, List[float]]  # example_id -> attention weights per generated token
    perplexity: float              # perplexity of the generated tokens (0.0 when scores unavailable)
    avg_confidence: float          # mean of confidence_scores (0.0 when empty)
    example_influences: Dict[str, float]  # example_id -> overall influence score (normalized)
    hidden_state_drift: Optional[List[float]] = None  # magnitude of hidden state changes vs. baseline
    icl_emergence: Optional[ICLEmergenceAnalysis] = None  # when/how ICL kicks in (induction-head analysis)
class ICLAnalyzer:
    """Analyzes in-context learning effects on model behavior.

    Compares generations produced with and without in-context examples,
    attributes attention from generated tokens back to the example spans
    (real hook-extracted attention when available, simulated otherwise),
    and detects when ICL "kicks in" via induction-head analysis.
    """

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # Keep all tensors on whatever device the model already lives on.
        self.device = next(model.parameters()).device
        # Extracts real attention weights via hooks during generation.
        self.attention_extractor = AttentionExtractor(model, tokenizer)
        # Detects induction-head activity / ICL emergence.
        self.induction_detector = InductionHeadDetector(model, tokenizer)
        # Most recent hook-extracted attention records; None until the first
        # successful extraction. Downstream methods fall back to simulated
        # attention patterns whenever this is falsy. (Previously this attribute
        # only existed after analyze_generation ran, forcing hasattr checks.)
        self.last_attention_data = None
        # Storage for attention patterns (kept for API compatibility).
        self.attention_maps = []
        self.hidden_states = []

    def prepare_prompt_with_examples(self, examples: List[ICLExample], test_prompt: str) -> str:
        """Construct the full prompt: each example rendered as "input\\noutput\\n",
        blank-line separated, followed by the test prompt. Returns test_prompt
        unchanged when there are no examples."""
        if not examples:
            return test_prompt
        prompt_parts = [f"{example.input}\n{example.output}\n" for example in examples]
        prompt_parts.append(test_prompt)
        return "\n".join(prompt_parts)

    def extract_attention_patterns(self, outputs, input_ids, example_boundaries: List[Tuple[int, int]]) -> Dict[str, List[float]]:
        """Return per-example attention weights for each generated token.

        Uses real hook-extracted attention when available; otherwise returns a
        simulated, cross-example-normalized pattern (exponential decay over
        generated positions, later examples weighted slightly higher).
        """
        # Prefer real attention data captured by the extractor hooks.
        if getattr(self, 'last_attention_data', None):
            logger.info("Using real attention data from model hooks")
            prompt_length = len(input_ids[0])
            return self.attention_extractor.aggregate_attention_to_examples(
                self.last_attention_data,
                example_boundaries,
                prompt_length
            )
        # Fall back to simulated patterns
        logger.info("Using simulated attention patterns")
        attention_from_examples: Dict[str, List[float]] = {}
        if not example_boundaries:
            return attention_from_examples
        generated_ids = outputs.sequences[0][len(input_ids[0]):]
        num_generated = len(generated_ids)
        if num_generated == 0:
            return attention_from_examples
        # Simulated pattern: per-example base weight (later examples slightly
        # stronger) decaying with generated position, plus small Gaussian noise,
        # clamped into [0, 1].
        for idx, (start, end) in enumerate(example_boundaries):
            example_id = str(idx + 1)
            base_weight = 0.3 + (idx * 0.1) / len(example_boundaries)
            weights = []
            for token_idx in range(num_generated):
                w = base_weight * np.exp(-token_idx * 0.05)
                w += np.random.normal(0, 0.02)
                weights.append(max(0, min(1, w)))
            attention_from_examples[example_id] = weights
        # Normalize across examples so weights sum to 1 at each position.
        if len(attention_from_examples) > 1:
            for token_idx in range(num_generated):
                total = sum(ws[token_idx] for ws in attention_from_examples.values())
                if total > 0:
                    for example_id in attention_from_examples:
                        attention_from_examples[example_id][token_idx] /= total
        return attention_from_examples

    def calculate_example_influences(self, attention_from_examples: Dict[str, List[float]]) -> Dict[str, float]:
        """Collapse per-token attention into one influence score per example,
        normalized into a distribution when any influence is nonzero."""
        # With real attention data, delegate to the extractor's own method.
        if getattr(self, 'last_attention_data', None):
            return self.attention_extractor.calculate_example_influences(attention_from_examples)
        # Otherwise: mean attention weight per example, then normalize.
        influences = {
            example_id: (float(np.mean(weights)) if weights else 0.0)
            for example_id, weights in attention_from_examples.items()
        }
        total = sum(influences.values())
        # Dividing by 1.0 is a no-op, so the old exact `total != 1.0` float
        # comparison was an unnecessary (and fragile) guard.
        if total > 0:
            influences = {k: v / total for k, v in influences.items()}
        return influences

    def track_hidden_state_drift(self, base_hidden_states, example_hidden_states) -> List[float]:
        """Return the L2 distance between base (no-example) and with-example
        hidden states at each shared position; [] when either side is missing."""
        if base_hidden_states is None or example_hidden_states is None:
            return []
        drift = []
        min_len = min(len(base_hidden_states), len(example_hidden_states))
        for i in range(min_len):
            base = base_hidden_states[i]
            example = example_hidden_states[i]
            # .detach() before .numpy() so this also works outside no_grad().
            if isinstance(base, torch.Tensor):
                base = base.detach().cpu().numpy()
            if isinstance(example, torch.Tensor):
                example = example.detach().cpu().numpy()
            drift.append(float(np.linalg.norm(example - base)))
        return drift

    def _find_example_boundaries(self, examples: List[ICLExample]) -> List[Tuple[int, int]]:
        """Approximate (start, end) token positions of each example in the prompt.

        Re-tokenizes each example independently, so positions can drift by a few
        tokens (BOS/special tokens, merges across boundaries) — treat the spans
        as approximate, not exact offsets into the full prompt.
        """
        boundaries: List[Tuple[int, int]] = []
        current_pos = 0
        for example in examples:
            example_text = f"{example.input}\n{example.output}\n"
            n_tokens = len(self.tokenizer(example_text, add_special_tokens=False)["input_ids"])
            boundaries.append((current_pos, current_pos + n_tokens))
            current_pos += n_tokens
        return boundaries

    def _confidence_scores(self, scores) -> List[float]:
        """Top-1 probability at each generation step ([] when scores absent)."""
        if not scores:
            return []
        return [F.softmax(step[0], dim=-1).max().item() for step in scores]

    def _perplexity(self, scores, generated_ids) -> float:
        """Perplexity of the generated tokens under the per-step score
        distributions; 0.0 when no scores are available."""
        if not scores:
            return 0.0
        log_probs = []
        for i, step in enumerate(scores):
            if i < len(generated_ids):
                log_probs.append(F.log_softmax(step[0], dim=-1)[generated_ids[i]].item())
        # Cast so callers get a plain float rather than np.float64.
        return float(np.exp(-np.mean(log_probs))) if log_probs else 0.0

    def analyze_generation(
        self,
        examples: List[ICLExample],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7,
        base_hidden_states: Optional[Any] = None
    ) -> ICLAnalysisResult:
        """Generate from `test_prompt` preceded by `examples` and analyze how the
        examples influenced the output.

        Returns an ICLAnalysisResult with attention attribution, per-step
        confidence, perplexity, and (best-effort) hidden-state drift and
        ICL-emergence analysis.
        """
        full_prompt = self.prepare_prompt_with_examples(examples, test_prompt)
        inputs = self.tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)).to(self.device)
        # Approximate example boundaries in token space.
        example_boundaries = self._find_example_boundaries(examples) if examples else []
        # Fall back to EOS when the tokenizer has no pad token (common for
        # causal LMs) so generate() does not receive pad_token_id=None.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        # First do standard generation to get scores and text.
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=pad_id,
                return_dict_in_generate=True,
                output_scores=True,
                output_hidden_states=False
            )
        # Second, hook-instrumented pass to capture real attention (best effort).
        try:
            logger.info("Extracting real attention data")
            # Clamp to >= 1 so an over-long prompt cannot make max_new_tokens
            # negative (which previously guaranteed failure); cap at 30 tokens
            # for performance.
            new_token_budget = max(1, min(30, max_length - len(input_ids[0])))
            _, attention_data, _ = self.attention_extractor.extract_attention_with_generation(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=new_token_budget,
                temperature=temperature
            )
            self.last_attention_data = attention_data
            logger.info(f"Successfully extracted {len(attention_data)} attention records")
        except Exception as e:
            # Deliberate best-effort: downstream code falls back to simulated
            # attention patterns when this is None.
            logger.warning(f"Real attention extraction failed: {e}")
            self.last_attention_data = None
        # Extract generated tokens - show raw output, no trimming.
        generated_ids = outputs.sequences[0][len(input_ids[0]):]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        tokens = [self.tokenizer.decode([token_id]) for token_id in generated_ids]
        confidence_scores = self._confidence_scores(outputs.scores)
        perplexity = self._perplexity(outputs.scores, generated_ids)
        attention_from_examples = self.extract_attention_patterns(outputs, input_ids, example_boundaries)
        example_influences = self.calculate_example_influences(attention_from_examples)
        # Hidden-state drift vs. a caller-provided baseline. NOTE(review): with
        # output_hidden_states=False above, outputs.hidden_states is never
        # populated, so this branch is currently dormant — confirm intent.
        hidden_state_drift = None
        if base_hidden_states is not None and hasattr(outputs, 'hidden_states'):
            current_hidden = outputs.hidden_states[-1] if outputs.hidden_states else None
            if current_hidden is not None:
                hidden_state_drift = self.track_hidden_state_drift(base_hidden_states, current_hidden)
        # ICL-emergence analysis needs real attention data and >= 1 example.
        icl_emergence = None
        if self.last_attention_data and len(examples) > 0:
            try:
                icl_emergence = self.induction_detector.analyze_icl_emergence(
                    self.last_attention_data,
                    input_ids,
                    example_boundaries,
                    generated_ids.tolist() if generated_ids.numel() > 0 else []
                )
                logger.info(f"ICL emergence analysis: detected={icl_emergence.emergence_detected}, "
                            f"token={icl_emergence.emergence_token}, confidence={icl_emergence.confidence:.2f}")
            except Exception as e:
                logger.warning(f"ICL emergence analysis failed: {e}")
        return ICLAnalysisResult(
            shot_count=len(examples),
            generated_code=generated_text,
            tokens=tokens,
            confidence_scores=confidence_scores,
            attention_from_examples=attention_from_examples,
            perplexity=perplexity,
            # Cast so the dataclass holds a plain float, not np.float64.
            avg_confidence=float(np.mean(confidence_scores)) if confidence_scores else 0.0,
            example_influences=example_influences,
            hidden_state_drift=hidden_state_drift,
            icl_emergence=icl_emergence
        )

    def compare_shot_settings(
        self,
        examples: List[ICLExample],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7
    ) -> Dict[str, ICLAnalysisResult]:
        """Run 0-shot, 1-shot (first example only), and few-shot (all examples)
        analyses over the same test prompt; returns results keyed by
        'zero_shot' / 'one_shot' / 'few_shot'."""
        results = {}
        # 0-shot (no examples)
        results['zero_shot'] = self.analyze_generation([], test_prompt, max_length, temperature)
        # NOTE(review): this passes the zero-shot *drift* list (currently always
        # None, since hidden states are never collected) as the hidden-state
        # baseline. Kept as-is for compatibility, but it probably should be the
        # zero-shot hidden states themselves — confirm before relying on drift.
        base_hidden = results['zero_shot'].hidden_state_drift
        # 1-shot (first example only)
        if len(examples) >= 1:
            results['one_shot'] = self.analyze_generation(
                examples[:1], test_prompt, max_length, temperature, base_hidden
            )
        # Few-shot (all examples)
        if len(examples) >= 2:
            results['few_shot'] = self.analyze_generation(
                examples, test_prompt, max_length, temperature, base_hidden
            )
        return results