""" Scoring utilities for the Cortex benchmark harness. Two evaluation modes: 1. Log-likelihood scoring: For multiple-choice tasks (HellaSwag, ARC, PIQA, etc.) Computes the average log-probability the model assigns to each continuation. 2. Generation scoring: For free-form generation tasks (passkey retrieval, etc.) Generates text and checks against expected patterns. """ import torch import torch.nn.functional as F from typing import List, Optional, Tuple, Dict import re from cortex.torch_device import resolve_torch_device def reset_cortex_state(model, batch_size: int = 1): """Reset runtime state for injected Cortex modules between independent examples.""" surgeon = getattr(model, "_cortex_surgeon", None) if surgeon is None: return for module in surgeon.modules.values(): module.reset_state(batch_size=batch_size) @torch.no_grad() def log_likelihood_score( model, tokenizer, context: str, continuations: List[str], device: Optional[str] = None, ) -> List[float]: """ Compute normalized log-likelihood for each continuation given a context. For each (context, continuation) pair: 1. Tokenize context + continuation together 2. Run forward pass to get logits 3. Compute average log-prob over the continuation tokens only Args: model: The language model tokenizer: The tokenizer context: The prompt/context string continuations: List of possible continuations to score device: Device to use (default: auto — cuda, then mps, then cpu) Returns: List of normalized log-likelihood scores (higher = model prefers this continuation) """ if device is None: device = resolve_torch_device("auto") scores = [] for cont in continuations: # Tokenize context and full sequence separately to find where continuation starts ctx_ids = tokenizer.encode(context, add_special_tokens=False) full_text = context + cont full_ids = tokenizer.encode(full_text, add_special_tokens=False) # The continuation tokens start after the context tokens cont_start = len(ctx_ids) cont_length = len(full_ids) - cont_start if cont_length <= 0: scores.append(float("-inf")) continue # Forward pass input_ids = torch.tensor([full_ids], device=device) reset_cortex_state(model, batch_size=input_ids.shape[0]) # Truncate if too long for model max_len = getattr(model.config, "max_position_embeddings", 2048) if input_ids.shape[1] > max_len: input_ids = input_ids[:, :max_len] cont_length = min(cont_length, max_len - cont_start) if cont_length <= 0: scores.append(float("-inf")) continue outputs = model(input_ids) logits = outputs.logits # [1, seq_len, vocab_size] # Shift: logits[i] predicts token[i+1] # For continuation tokens at positions [cont_start, cont_start+cont_length), # we need logits at positions [cont_start-1, cont_start+cont_length-1) shift_logits = logits[0, cont_start - 1 : cont_start + cont_length - 1, :] shift_labels = input_ids[0, cont_start : cont_start + cont_length] # Log-probabilities log_probs = F.log_softmax(shift_logits, dim=-1) token_log_probs = log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1) # Normalize by continuation length (average log-prob per token) avg_log_prob = token_log_probs.mean().item() scores.append(avg_log_prob) return scores @torch.no_grad() def generate_and_check( model, tokenizer, prompt: str, expected: str, max_new_tokens: int = 64, device: Optional[str] = None, exact_match: bool = False, ) -> Tuple[bool, str]: """ Generate text and check if the expected answer appears in the output. 
@torch.no_grad()
def generate_and_check(
    model,
    tokenizer,
    prompt: str,
    expected: str,
    max_new_tokens: int = 64,
    device: Optional[str] = None,
    exact_match: bool = False,
) -> Tuple[bool, str]:
    """
    Generate text and check whether the expected answer appears in the output.

    Args:
        model: The language model
        tokenizer: The tokenizer
        prompt: The input prompt
        expected: The expected answer string
        max_new_tokens: Max tokens to generate
        device: Device to use (default: auto, preferring cuda, then mps, then cpu)
        exact_match: If True, requires an exact match; otherwise substring match

    Returns:
        (is_correct, generated_text)
    """
    if device is None:
        device = resolve_torch_device("auto")

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    reset_cortex_state(model, batch_size=inputs["input_ids"].shape[0])

    # Fall back to EOS when the tokenizer defines no pad token
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Greedy decoding; temperature is ignored when do_sample=False
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=pad_token_id,
    )

    # Decode only the newly generated tokens
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    generated = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    if exact_match:
        is_correct = generated.lower() == expected.strip().lower()
    else:
        is_correct = expected.strip().lower() in generated.lower()
    return is_correct, generated


def accuracy_from_loglikelihoods(
    scores_per_example: List[Tuple[List[float], int]],
) -> Dict[str, float]:
    """
    Compute accuracy from log-likelihood scores.

    Args:
        scores_per_example: List of (scores_for_each_choice, correct_index) pairs

    Returns:
        Dict with accuracy and count metrics
    """
    correct = 0
    total = len(scores_per_example)
    for scores, gold_idx in scores_per_example:
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        if predicted == gold_idx:
            correct += 1

    return {
        "accuracy": correct / total if total > 0 else 0.0,
        "correct": correct,
        "total": total,
    }
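
# A minimal end-to-end smoke test (a sketch, not part of the harness API):
# wires log_likelihood_score and accuracy_from_loglikelihoods together on a
# toy two-choice task. "gpt2" is an illustrative checkpoint; any HF causal LM
# should work. device="cpu" keeps the run independent of what "auto" resolves to.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2").eval()

    examples = [
        ("The sky on a clear day is", [" blue.", " green."], 0),
        ("Two plus two equals", [" four.", " five."], 0),
    ]
    results = [
        (log_likelihood_score(lm, tok, ctx, choices, device="cpu"), gold)
        for ctx, choices, gold in examples
    ]
    print(accuracy_from_loglikelihoods(results))  # e.g. {'accuracy': 1.0, ...}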