# NOTE: this file was captured from a rendered Hugging Face Space page; the
# original capture included page chrome ("Spaces / Running on CPU Upgrade"),
# commit hashes (920a98d / ed40a9a), a byte count, and a line-number gutter.
# That residue has been replaced by this comment so the module parses cleanly.
"""
In-Context Learning Analysis Service
Analyzes how examples influence model behavior during code generation.
"""
import torch
import numpy as np
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
from .icl_attention_extractor import AttentionExtractor
from .induction_head_detector import InductionHeadDetector, ICLEmergenceAnalysis
import logging
logger = logging.getLogger(__name__)
@dataclass
class ICLExample:
    """Represents an in-context learning example: one input/output
    demonstration pair that is prepended to the test prompt."""
    input: str   # the example's prompt text (name kept for API compatibility, though it shadows the builtin)
    output: str  # the expected completion for `input`
@dataclass
class ICLAnalysisResult:
    """Results from one ICL analysis run (a single generation pass)."""
    shot_count: int                                  # number of examples included in the prompt
    generated_code: str                              # decoded model output (prompt stripped)
    tokens: List[str]                                # generated tokens, decoded one by one
    confidence_scores: List[float]                   # max softmax probability per generated token
    attention_from_examples: Dict[str, List[float]]  # example_id -> attention weights per token
    perplexity: float                                # exp(-mean log-prob) over the generated tokens
    avg_confidence: float                            # mean of confidence_scores (0.0 if none)
    example_influences: Dict[str, float]             # example_id -> overall influence score
    hidden_state_drift: Optional[List[float]] = None # magnitude of hidden state changes vs. baseline
    icl_emergence: Optional[ICLEmergenceAnalysis] = None  # when/how ICL kicks in during generation
class ICLAnalyzer:
    """Analyzes in-context learning effects on model behavior.

    Wraps a causal LM and its tokenizer and measures how few-shot prompt
    examples influence generation: per-example attention, token-level
    confidence, perplexity, hidden-state drift, and induction-head based
    ICL-emergence analysis.
    """

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, adapter=None):
        """
        Args:
            model: causal LM used for generation.
            tokenizer: matching tokenizer; a pad token is assigned if missing.
            adapter: optional adapter, passed through to the extractor/detector.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter
        self.device = next(model.parameters()).device
        # Ensure tokenizer has pad_token (needed for Code-Llama).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Helper that captures real attention data via model hooks.
        self.attention_extractor = AttentionExtractor(model, tokenizer, adapter=adapter)
        # Helper that detects induction heads / ICL emergence.
        self.induction_detector = InductionHeadDetector(model, tokenizer, adapter=adapter)
        # Storage for attention patterns.
        self.attention_maps = []
        self.hidden_states = []
        # Fix: define up-front so any read before the first analyze_generation()
        # call sees a well-defined value instead of relying on hasattr() checks.
        self.last_attention_data = None

    def prepare_prompt_with_examples(self, examples: List[ICLExample], test_prompt: str) -> str:
        """Build the full prompt: each example as "input\\noutput\\n", joined
        with the test prompt by newlines. Returns test_prompt unchanged when
        there are no examples."""
        if not examples:
            return test_prompt
        prompt_parts = [f"{example.input}\n{example.output}\n" for example in examples]
        prompt_parts.append(test_prompt)
        return "\n".join(prompt_parts)

    def extract_attention_patterns(self, outputs, input_ids, example_boundaries: List[Tuple[int, int]]) -> Dict[str, List[float]]:
        """Extract per-example attention weights for each generated token.

        Uses real hook-captured attention when a previous extraction succeeded
        (``self.last_attention_data``); otherwise synthesizes a plausible
        decaying-with-noise pattern so downstream consumers still get data.
        """
        # Prefer real attention data captured during generation.
        if getattr(self, 'last_attention_data', None):
            logger.info("Using real attention data from model hooks")
            prompt_length = len(input_ids[0])
            return self.attention_extractor.aggregate_attention_to_examples(
                self.last_attention_data,
                example_boundaries,
                prompt_length
            )
        # Fall back to simulated patterns.
        logger.info("Using simulated attention patterns")
        attention_from_examples: Dict[str, List[float]] = {}
        if not example_boundaries:
            return attention_from_examples
        generated_ids = outputs.sequences[0][len(input_ids[0]):]
        num_generated = len(generated_ids)
        if num_generated == 0:
            return attention_from_examples
        # Simulated pattern: later examples get a slightly higher base weight,
        # attention decays over generated positions, small Gaussian noise is
        # added, and each weight is clipped to [0, 1].
        for idx, (start, end) in enumerate(example_boundaries):
            example_id = str(idx + 1)
            base_weight = 0.3 + (idx * 0.1) / len(example_boundaries)
            attention_weights = []
            for token_idx in range(num_generated):
                weight = base_weight * np.exp(-token_idx * 0.05)
                weight += np.random.normal(0, 0.02)
                weight = max(0, min(1, weight))
                attention_weights.append(weight)
            attention_from_examples[example_id] = attention_weights
        # Normalize across examples so the weights at each generated position
        # sum to 1 (only meaningful with more than one example).
        if len(attention_from_examples) > 1:
            for token_idx in range(num_generated):
                total = sum(weights[token_idx] for weights in attention_from_examples.values())
                if total > 0:
                    for example_id in attention_from_examples:
                        attention_from_examples[example_id][token_idx] /= total
        return attention_from_examples

    def calculate_example_influences(self, attention_from_examples: Dict[str, List[float]]) -> Dict[str, float]:
        """Collapse per-token attention weights into a single influence score
        per example (mean weight, normalized to sum to 1 when possible)."""
        # If we have real attention data, defer to the extractor's method.
        if getattr(self, 'last_attention_data', None):
            return self.attention_extractor.calculate_example_influences(attention_from_examples)
        influences = {
            example_id: float(np.mean(weights)) if weights else 0.0
            for example_id, weights in attention_from_examples.items()
        }
        total = sum(influences.values())
        # Fix: the previous exact-float `total != 1.0` guard was a no-op
        # (dividing by 1.0 changes nothing), so normalize whenever total > 0.
        if total > 0:
            influences = {k: v / total for k, v in influences.items()}
        return influences

    def track_hidden_state_drift(self, base_hidden_states, example_hidden_states) -> List[float]:
        """Return the L2 distance between baseline (no-example) and with-example
        hidden states at each shared position; [] when either input is None."""
        if base_hidden_states is None or example_hidden_states is None:
            return []
        drift = []
        # Compare only the overlapping prefix of the two sequences.
        min_len = min(len(base_hidden_states), len(example_hidden_states))
        for i in range(min_len):
            base = base_hidden_states[i]
            example = example_hidden_states[i]
            if isinstance(base, torch.Tensor):
                base = base.cpu().numpy()
            if isinstance(example, torch.Tensor):
                example = example.cpu().numpy()
            distance = np.linalg.norm(example - base)
            drift.append(float(distance))
        return drift

    def analyze_generation(
        self,
        examples: List[ICLExample],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7,
        base_hidden_states: Optional[Any] = None
    ) -> ICLAnalysisResult:
        """Generate from the prompt built with `examples` and analyze how the
        examples influenced the output.

        Args:
            examples: few-shot demonstrations to prepend (may be empty).
            test_prompt: the prompt to complete.
            max_length: total sequence-length budget passed to generate().
            temperature: sampling temperature; sampling is enabled iff > 0.
            base_hidden_states: optional baseline hidden states for drift.

        Returns:
            ICLAnalysisResult with generation text, confidences, perplexity,
            attention/influence breakdowns and optional emergence analysis.
        """
        # Prepare prompt and tokenize.
        full_prompt = self.prepare_prompt_with_examples(examples, test_prompt)
        inputs = self.tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)).to(self.device)
        # Find example boundaries in token space.
        # NOTE(review): boundaries are computed by re-tokenizing each example in
        # isolation; offsets can drift slightly from the joint tokenization
        # (special tokens, cross-boundary merges) — confirm against extractor.
        example_boundaries = []
        if examples:
            current_pos = 0
            for example in examples:
                example_text = f"{example.input}\n{example.output}\n"
                example_tokens = self.tokenizer(example_text, add_special_tokens=False)["input_ids"]
                example_boundaries.append((current_pos, current_pos + len(example_tokens)))
                current_pos += len(example_tokens)
        # First do standard generation to get scores and text.
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True,
                output_hidden_states=False
            )
        # Then try to extract real attention data (best-effort).
        try:
            logger.info("Extracting real attention data")
            # Fix: clamp max_new_tokens to at least 1 — when the prompt is
            # longer than max_length the old min(30, ...) went non-positive
            # and this extraction always failed.
            _, attention_data, _ = self.attention_extractor.extract_attention_with_generation(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max(1, min(30, max_length - len(input_ids[0]))),  # limit for performance
                temperature=temperature
            )
            self.last_attention_data = attention_data
            logger.info(f"Successfully extracted {len(attention_data)} attention records")
        except Exception as e:
            logger.warning(f"Real attention extraction failed: {e}")
            self.last_attention_data = None
        # Extract generated tokens - show raw output, no trimming.
        generated_ids = outputs.sequences[0][len(input_ids[0]):]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        tokens = [self.tokenizer.decode([token_id]) for token_id in generated_ids]
        # Confidence per token = max softmax probability at that step.
        confidence_scores = []
        if outputs.scores:
            for score in outputs.scores:
                probs = F.softmax(score[0], dim=-1)
                confidence_scores.append(probs.max().item())
        # Perplexity = exp(-mean log-prob of the tokens actually generated).
        if outputs.scores:
            log_probs = []
            for i, score in enumerate(outputs.scores):
                if i < len(generated_ids):
                    token_id = generated_ids[i]
                    log_probs.append(F.log_softmax(score[0], dim=-1)[token_id].item())
            # Fix: cast to plain float (np.exp returns np.float64).
            perplexity = float(np.exp(-np.mean(log_probs))) if log_probs else 0.0
        else:
            perplexity = 0.0
        # Attention patterns and per-example influence scores.
        attention_from_examples = self.extract_attention_patterns(outputs, input_ids, example_boundaries)
        example_influences = self.calculate_example_influences(attention_from_examples)
        # Track hidden state drift if baseline states were provided.
        # NOTE(review): generate() runs with output_hidden_states=False, so
        # outputs.hidden_states is empty/None and this branch is inert today.
        hidden_state_drift = None
        if base_hidden_states is not None and hasattr(outputs, 'hidden_states'):
            current_hidden = outputs.hidden_states[-1] if outputs.hidden_states else None
            if current_hidden is not None:
                hidden_state_drift = self.track_hidden_state_drift(base_hidden_states, current_hidden)
        # Analyze ICL emergence if we have attention data and examples.
        icl_emergence = None
        if self.last_attention_data and len(examples) > 0:
            try:
                icl_emergence = self.induction_detector.analyze_icl_emergence(
                    self.last_attention_data,
                    input_ids,
                    example_boundaries,
                    generated_ids.tolist() if generated_ids.numel() > 0 else []
                )
                logger.info(f"ICL emergence analysis: detected={icl_emergence.emergence_detected}, "
                            f"token={icl_emergence.emergence_token}, confidence={icl_emergence.confidence:.2f}")
            except Exception as e:
                logger.warning(f"ICL emergence analysis failed: {e}")
        return ICLAnalysisResult(
            shot_count=len(examples),
            generated_code=generated_text,
            tokens=tokens,
            confidence_scores=confidence_scores,
            attention_from_examples=attention_from_examples,
            perplexity=perplexity,
            # Fix: cast to plain float (np.mean returns np.float64).
            avg_confidence=float(np.mean(confidence_scores)) if confidence_scores else 0.0,
            example_influences=example_influences,
            hidden_state_drift=hidden_state_drift,
            icl_emergence=icl_emergence
        )

    def compare_shot_settings(
        self,
        examples: List[ICLExample],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7
    ) -> Dict[str, ICLAnalysisResult]:
        """Run 0-shot, 1-shot (first example) and few-shot (all examples)
        generations and return their analyses keyed by setting name."""
        results = {}
        # 0-shot (no examples).
        results['zero_shot'] = self.analyze_generation([], test_prompt, max_length, temperature)
        # NOTE(review): this is the zero-shot *drift list* (always None here,
        # since no baseline was passed), not raw hidden states; kept as-is for
        # behavior compatibility, but the baseline plumbing needs a real fix.
        base_hidden = results['zero_shot'].hidden_state_drift  # Use as baseline
        # 1-shot (first example only).
        if len(examples) >= 1:
            results['one_shot'] = self.analyze_generation(
                examples[:1], test_prompt, max_length, temperature, base_hidden
            )
        # Few-shot (all examples).
        if len(examples) >= 2:
            results['few_shot'] = self.analyze_generation(
                examples, test_prompt, max_length, temperature, base_hidden
            )
        return results