Enhance benchmark and Cortex modules with new training utilities and improved state management. Update README with example output for Llama-3.2-1B and add training CLI for Cortex module tuning. Refactor scoring functions to reset Cortex state between examples and ensure consistent output. Modify task handling to ensure proper formatting of input data.
| """ | |
| Scoring utilities for the Cortex benchmark harness. | |
| Two evaluation modes: | |
| 1. Log-likelihood scoring: For multiple-choice tasks (HellaSwag, ARC, PIQA, etc.) | |
| Computes the average log-probability the model assigns to each continuation. | |
| 2. Generation scoring: For free-form generation tasks (passkey retrieval, etc.) | |
| Generates text and checks against expected patterns. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| from typing import List, Optional, Tuple, Dict | |
| import re | |
| from cortex.torch_device import resolve_torch_device | |
def reset_cortex_state(model, batch_size: int = 1):
    """Clear the runtime state of any injected Cortex modules so that
    independent evaluation examples do not leak state into each other."""
    cortex_surgeon = getattr(model, "_cortex_surgeon", None)
    if cortex_surgeon is not None:
        for cortex_module in cortex_surgeon.modules.values():
            cortex_module.reset_state(batch_size=batch_size)
def log_likelihood_score(
    model,
    tokenizer,
    context: str,
    continuations: List[str],
    device: Optional[str] = None,
) -> List[float]:
    """
    Compute the length-normalized log-likelihood of each continuation given a context.

    For each (context, continuation) pair:
      1. Tokenize the context and the full ``context + continuation`` string.
      2. Run a single forward pass to obtain logits.
      3. Average the log-probabilities over the continuation tokens only.

    Args:
        model: Causal language model; a call returns an object with ``.logits``
            of shape [batch, seq_len, vocab_size].
        tokenizer: Tokenizer exposing ``encode``.
        context: The prompt/context string.
        continuations: Candidate continuations to score.
        device: Device to use (default: auto — cuda, then mps, then cpu).

    Returns:
        One normalized log-likelihood per continuation (higher = model prefers
        it). A continuation that tokenizes to zero tokens, or that is entirely
        cut off by the model's context limit, scores ``-inf``.
    """
    if device is None:
        device = resolve_torch_device("auto")
    scores = []
    for cont in continuations:
        # Tokenize context and full sequence separately to locate where the
        # continuation starts. NOTE(review): this assumes prefix-stable
        # tokenization (encode(context) is a prefix of encode(context + cont));
        # true for typical BPE tokenizers but worth confirming for new ones.
        ctx_ids = tokenizer.encode(context, add_special_tokens=False)
        full_text = context + cont
        full_ids = tokenizer.encode(full_text, add_special_tokens=False)
        # The continuation tokens start after the context tokens.
        cont_start = len(ctx_ids)
        cont_length = len(full_ids) - cont_start
        if cont_length <= 0:
            scores.append(float("-inf"))
            continue
        input_ids = torch.tensor([full_ids], device=device)
        # Each example is independent; clear any Cortex recurrent state.
        reset_cortex_state(model, batch_size=input_ids.shape[0])
        # Truncate sequences that exceed the model's positional capacity.
        max_len = getattr(model.config, "max_position_embeddings", 2048)
        if input_ids.shape[1] > max_len:
            input_ids = input_ids[:, :max_len]
            cont_length = min(cont_length, max_len - cont_start)
            if cont_length <= 0:
                scores.append(float("-inf"))
                continue
        # Inference only: disable autograd so no graph is built and no
        # activations are retained for the (potentially long) sequence.
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs.logits  # [1, seq_len, vocab_size]
        # Shift: logits at position i predict token i+1. The continuation
        # tokens at [cont_start, cont_start + cont_length) are therefore
        # predicted by logits at [cont_start - 1, cont_start + cont_length - 1).
        shift_logits = logits[0, cont_start - 1 : cont_start + cont_length - 1, :]
        shift_labels = input_ids[0, cont_start : cont_start + cont_length]
        # Per-token log-probabilities of the realized continuation tokens.
        log_probs = F.log_softmax(shift_logits, dim=-1)
        token_log_probs = log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1)
        # Length-normalize so short and long continuations are comparable.
        avg_log_prob = token_log_probs.mean().item()
        scores.append(avg_log_prob)
    return scores
def generate_and_check(
    model,
    tokenizer,
    prompt: str,
    expected: str,
    max_new_tokens: int = 64,
    device: Optional[str] = None,
    exact_match: bool = False,
) -> Tuple[bool, str]:
    """
    Greedily generate a completion and check it against the expected answer.

    Args:
        model: Causal language model supporting ``generate``.
        tokenizer: The tokenizer (callable, returning tensors).
        prompt: The input prompt.
        expected: The expected answer string.
        max_new_tokens: Max tokens to generate.
        device: Device (default: auto — cuda, then mps, then cpu).
        exact_match: If True, requires an exact case-insensitive match;
            otherwise a case-insensitive substring match.

    Returns:
        (is_correct, generated_text) where generated_text contains only the
        newly generated tokens, stripped of special tokens and whitespace.
    """
    if device is None:
        device = resolve_torch_device("auto")
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    # Each example is independent; clear any Cortex recurrent state.
    reset_cortex_state(model, batch_size=inputs["input_ids"].shape[0])
    # Fall back to EOS for tokenizers without a dedicated pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    # Greedy decoding. Deliberately no `temperature` kwarg: it is ignored
    # when do_sample=False, and recent transformers versions warn about it.
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=pad_token_id,
    )
    # Decode only the newly generated tokens (skip the echoed prompt).
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    generated = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    if exact_match:
        is_correct = generated.strip().lower() == expected.strip().lower()
    else:
        is_correct = expected.strip().lower() in generated.lower()
    return is_correct, generated
def accuracy_from_loglikelihoods(
    scores_per_example: List[Tuple[List[float], int]],
) -> Dict[str, float]:
    """
    Summarize multiple-choice accuracy from per-choice log-likelihood scores.

    Args:
        scores_per_example: One (scores_for_each_choice, correct_index) pair
            per example; the prediction is the argmax over the choice scores
            (ties resolve to the first maximal index).

    Returns:
        Dict with "accuracy" (0.0 when there are no examples), "correct",
        and "total".
    """
    total = len(scores_per_example)
    # list.index(max(...)) picks the first maximal score, matching argmax
    # with first-index tie-breaking.
    correct = sum(
        1
        for choice_scores, gold_idx in scores_per_example
        if choice_scores.index(max(choice_scores)) == gold_idx
    )
    return {
        "accuracy": correct / total if total else 0.0,
        "correct": correct,
        "total": total,
    }