# NOTE(review): removed Hugging Face Spaces status banner ("Spaces: Sleeping")
# — scrape artifact at the top of the file, not valid Python source.
"""
In-Context Learning Analysis Service
Analyzes how examples influence model behavior during code generation.
"""
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from .icl_attention_extractor import AttentionExtractor
from .induction_head_detector import ICLEmergenceAnalysis, InductionHeadDetector

# Module-level logger, named after the module per logging convention.
logger = logging.getLogger(__name__)
@dataclass
class ICLExample:
    """Represents an in-context learning example (one input/output pair).

    The @dataclass decorator is required: callers construct instances with
    keyword arguments and read `.input` / `.output` attributes.
    """
    # Prompt text shown to the model before the expected completion.
    input: str
    # Expected completion paired with `input` in the few-shot prompt.
    output: str
@dataclass
class ICLAnalysisResult:
    """Results from ICL analysis of a single generation run.

    The @dataclass decorator is required: ICLAnalyzer.analyze_generation
    constructs this with keyword arguments for every field.
    """
    # Number of in-context examples that were in the prompt.
    shot_count: int
    # Raw generated text (untrimmed model output).
    generated_code: str
    # Generated tokens, decoded one-by-one.
    tokens: List[str]
    # Per-generated-token max softmax probability.
    confidence_scores: List[float]
    # example_id -> attention weights per generated token
    attention_from_examples: Dict[str, List[float]]
    # Perplexity of the generated tokens under the model's own scores.
    perplexity: float
    # Mean of confidence_scores (0.0 when nothing was generated).
    avg_confidence: float
    # example_id -> overall influence score (normalized to sum to 1)
    example_influences: Dict[str, float]
    # Magnitude of hidden state changes vs. a no-example baseline, if tracked.
    hidden_state_drift: Optional[List[float]] = None
    # When/how ICL kicks in; quoted forward ref so the dataclass does not
    # require the detector module at class-creation time.
    icl_emergence: Optional["ICLEmergenceAnalysis"] = None
class ICLAnalyzer:
    """Analyzes in-context learning effects on model behavior.

    Wraps a causal LM + tokenizer and measures how prepended examples shape
    generation: attention flowing back to each example, per-token confidence,
    perplexity, hidden-state drift, and induction-head-based ICL emergence.
    """

    def __init__(self, model: "AutoModelForCausalLM", tokenizer: "AutoTokenizer", adapter=None):
        """
        Args:
            model: Causal LM used for generation.
            tokenizer: Tokenizer matching `model`.
            adapter: Optional adapter forwarded to the extractor/detector.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter
        # Run on whatever device the model's parameters already live on.
        self.device = next(model.parameters()).device
        # Ensure tokenizer has pad_token (needed for Code-Llama)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Initialize attention extractor for real attention data
        self.attention_extractor = AttentionExtractor(model, tokenizer, adapter=adapter)
        # Initialize induction head detector
        self.induction_detector = InductionHeadDetector(model, tokenizer, adapter=adapter)
        # Most recent real attention records from analyze_generation().
        # Fix: initialize explicitly instead of relying on hasattr() probing;
        # None (falsy) makes downstream methods fall back to simulation.
        self.last_attention_data = None
        # Storage for attention patterns
        self.attention_maps = []
        self.hidden_states = []

    def prepare_prompt_with_examples(self, examples: List["ICLExample"], test_prompt: str) -> str:
        """Construct the few-shot prompt: each example as "input\\noutput\\n",
        joined with blank lines, followed by the test prompt.

        Returns `test_prompt` unchanged when `examples` is empty.
        """
        if not examples:
            return test_prompt
        prompt_parts = [f"{example.input}\n{example.output}\n" for example in examples]
        prompt_parts.append(test_prompt)
        return "\n".join(prompt_parts)

    def extract_attention_patterns(self, outputs, input_ids, example_boundaries: List[Tuple[int, int]]) -> Dict[str, List[float]]:
        """Extract attention patterns - real if available, simulated otherwise.

        Args:
            outputs: Generate output (needs `.sequences`, shape (1, total_len)).
            input_ids: Prompt token ids, shape (1, prompt_len).
            example_boundaries: (start, end) token spans of each example.

        Returns:
            Mapping of example id ("1", "2", ...) to one attention weight per
            generated token. Empty when there are no boundaries or no tokens.
        """
        # Prefer real attention captured by the extractor's hooks.
        if getattr(self, 'last_attention_data', None):
            logger.info("Using real attention data from model hooks")
            prompt_length = len(input_ids[0])
            return self.attention_extractor.aggregate_attention_to_examples(
                self.last_attention_data,
                example_boundaries,
                prompt_length
            )
        # Fall back to simulated patterns
        logger.info("Using simulated attention patterns")
        attention_from_examples = {}
        if not example_boundaries:
            return attention_from_examples
        generated_ids = outputs.sequences[0][len(input_ids[0]):]
        num_generated = len(generated_ids)
        if num_generated == 0:
            return attention_from_examples
        # Simulated pattern: later examples get slightly more base weight,
        # decaying exponentially over generated positions, plus small noise.
        for idx, _ in enumerate(example_boundaries):
            example_id = str(idx + 1)
            base_weight = 0.3 + (idx * 0.1) / len(example_boundaries)
            attention_weights = []
            for token_idx in range(num_generated):
                weight = base_weight * np.exp(-token_idx * 0.05)
                weight += np.random.normal(0, 0.02)
                weight = max(0, min(1, weight))  # clamp noise into [0, 1]
                attention_weights.append(weight)
            attention_from_examples[example_id] = attention_weights
        # Normalize across examples so weights sum to 1 at each token position.
        if len(attention_from_examples) > 1:
            for token_idx in range(num_generated):
                total = sum(weights[token_idx] for weights in attention_from_examples.values())
                if total > 0:
                    for example_id in attention_from_examples:
                        attention_from_examples[example_id][token_idx] /= total
        return attention_from_examples

    def calculate_example_influences(self, attention_from_examples: Dict[str, List[float]]) -> Dict[str, float]:
        """Calculate an overall influence score per example, normalized to sum to 1."""
        # If we have real attention data, use the extractor's method.
        if getattr(self, 'last_attention_data', None):
            return self.attention_extractor.calculate_example_influences(attention_from_examples)
        # Otherwise: mean attention weight per example, then normalize.
        influences = {
            example_id: float(np.mean(weights)) if weights else 0.0
            for example_id, weights in attention_from_examples.items()
        }
        total = sum(influences.values())
        if total > 0 and total != 1.0:
            influences = {k: v / total for k, v in influences.items()}
        return influences

    def track_hidden_state_drift(self, base_hidden_states, example_hidden_states) -> List[float]:
        """Return per-position L2 distances between base (no-example) and
        with-example hidden states; [] when either side is None.

        Only the overlapping prefix of the two sequences is compared.
        """
        if base_hidden_states is None or example_hidden_states is None:
            return []
        drift = []
        min_len = min(len(base_hidden_states), len(example_hidden_states))
        for i in range(min_len):
            base = base_hidden_states[i]
            example = example_hidden_states[i]
            if isinstance(base, torch.Tensor):
                base = base.cpu().numpy()
            if isinstance(example, torch.Tensor):
                example = example.cpu().numpy()
            drift.append(float(np.linalg.norm(example - base)))
        return drift

    def analyze_generation(
        self,
        examples: List["ICLExample"],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7,
        base_hidden_states: Optional[Any] = None
    ) -> "ICLAnalysisResult":
        """Generate from a few-shot prompt and analyze how the examples influenced it.

        Args:
            examples: In-context examples to prepend (may be empty).
            test_prompt: The actual task prompt.
            max_length: Total budget (prompt + generation) passed to generate().
            temperature: Sampling temperature; 0 selects greedy decoding.
            base_hidden_states: Optional baseline hidden states for drift tracking.

        Returns:
            ICLAnalysisResult with the generated text and all computed metrics.
        """
        # Prepare prompt
        full_prompt = self.prepare_prompt_with_examples(examples, test_prompt)
        # Tokenize
        inputs = self.tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)).to(self.device)
        # Approximate example boundaries in token space by tokenizing each
        # example independently (NOTE: may drift slightly from the joint
        # tokenization of the full prompt — acceptable for aggregation).
        example_boundaries = []
        if examples:
            current_pos = 0
            for example in examples:
                example_text = f"{example.input}\n{example.output}\n"
                example_tokens = self.tokenizer(example_text, add_special_tokens=False)["input_ids"]
                example_boundaries.append((current_pos, current_pos + len(example_tokens)))
                current_pos += len(example_tokens)
        # First do standard generation to get scores and text
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True,
                output_hidden_states=False
            )
        # Then try to extract real attention data in a second, shorter pass.
        try:
            logger.info("Extracting real attention data")
            # Fix: clamp to >= 1 so a prompt longer than max_length cannot
            # produce a non-positive max_new_tokens (30-token cap kept for
            # performance).
            _, attention_data, _ = self.attention_extractor.extract_attention_with_generation(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max(1, min(30, max_length - len(input_ids[0]))),
                temperature=temperature
            )
            self.last_attention_data = attention_data
            logger.info(f"Successfully extracted {len(attention_data)} attention records")
        except Exception as e:
            # Best-effort: fall back to simulated attention downstream.
            logger.warning(f"Real attention extraction failed: {e}")
            self.last_attention_data = None
        # Extract generated tokens - show raw output, no trimming
        generated_ids = outputs.sequences[0][len(input_ids[0]):]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        tokens = [self.tokenizer.decode([token_id]) for token_id in generated_ids]
        # Per-step confidence: peak softmax probability of each step's scores.
        confidence_scores = []
        if outputs.scores:
            for score in outputs.scores:
                probs = F.softmax(score[0], dim=-1)
                confidence_scores.append(probs.max().item())
        # Perplexity over the tokens actually generated.
        if outputs.scores:
            log_probs = []
            for i, score in enumerate(outputs.scores):
                if i < len(generated_ids):
                    token_id = generated_ids[i]
                    log_probs.append(F.log_softmax(score[0], dim=-1)[token_id].item())
            # Fix: cast to plain float (np.exp returns np.float64) to match
            # the declared float field on ICLAnalysisResult.
            perplexity = float(np.exp(-np.mean(log_probs))) if log_probs else 0.0
        else:
            perplexity = 0.0
        # Extract attention patterns
        attention_from_examples = self.extract_attention_patterns(outputs, input_ids, example_boundaries)
        # Calculate example influences
        example_influences = self.calculate_example_influences(attention_from_examples)
        # Track hidden state drift if base states provided. NOTE(review):
        # generate() is called with output_hidden_states=False, so
        # outputs.hidden_states is normally absent and this stays None.
        hidden_state_drift = None
        if base_hidden_states is not None and hasattr(outputs, 'hidden_states'):
            current_hidden = outputs.hidden_states[-1] if outputs.hidden_states else None
            if current_hidden is not None:
                hidden_state_drift = self.track_hidden_state_drift(base_hidden_states, current_hidden)
        # Analyze ICL emergence if we have real attention data and examples.
        icl_emergence = None
        if self.last_attention_data and len(examples) > 0:
            try:
                icl_emergence = self.induction_detector.analyze_icl_emergence(
                    self.last_attention_data,
                    input_ids,
                    example_boundaries,
                    generated_ids.tolist()  # [] when nothing was generated
                )
                logger.info(f"ICL emergence analysis: detected={icl_emergence.emergence_detected}, "
                            f"token={icl_emergence.emergence_token}, confidence={icl_emergence.confidence:.2f}")
            except Exception as e:
                # Best-effort: emergence analysis is optional metadata.
                logger.warning(f"ICL emergence analysis failed: {e}")
        return ICLAnalysisResult(
            shot_count=len(examples),
            generated_code=generated_text,
            tokens=tokens,
            confidence_scores=confidence_scores,
            attention_from_examples=attention_from_examples,
            perplexity=perplexity,
            # Fix: cast np.mean's np.float64 to plain float.
            avg_confidence=float(np.mean(confidence_scores)) if confidence_scores else 0.0,
            example_influences=example_influences,
            hidden_state_drift=hidden_state_drift,
            icl_emergence=icl_emergence
        )

    def compare_shot_settings(
        self,
        examples: List["ICLExample"],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7
    ) -> Dict[str, "ICLAnalysisResult"]:
        """Compare 0-shot, 1-shot, and few-shot generation on the same prompt.

        Returns a dict with keys 'zero_shot', and — when enough examples are
        supplied — 'one_shot' and 'few_shot'.
        """
        results = {}
        # 0-shot (no examples) establishes the baseline.
        results['zero_shot'] = self.analyze_generation([], test_prompt, max_length, temperature)
        # NOTE(review): this passes the zero-shot *drift list* (always None,
        # since no baseline existed for it) as the hidden-state baseline, so
        # drift tracking is effectively disabled here. Capturing real baseline
        # hidden states would need output_hidden_states=True in generate();
        # left unchanged pending confirmation of intended behavior.
        base_hidden = results['zero_shot'].hidden_state_drift
        # 1-shot (first example only)
        if len(examples) >= 1:
            results['one_shot'] = self.analyze_generation(
                examples[:1], test_prompt, max_length, temperature, base_hidden
            )
        # Few-shot (all examples)
        if len(examples) >= 2:
            results['few_shot'] = self.analyze_generation(
                examples, test_prompt, max_length, temperature, base_hidden
            )
        return results