"""
Scoring utilities for the Cortex benchmark harness.
Two evaluation modes:
1. Log-likelihood scoring: For multiple-choice tasks (HellaSwag, ARC, PIQA, etc.)
Computes the average log-probability the model assigns to each continuation.
2. Generation scoring: For free-form generation tasks (passkey retrieval, etc.)
Generates text and checks against expected patterns.
"""

import torch
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple

from cortex.torch_device import resolve_torch_device


def reset_cortex_state(model, batch_size: int = 1):
"""Reset runtime state for injected Cortex modules between independent examples."""
surgeon = getattr(model, "_cortex_surgeon", None)
if surgeon is None:
return
for module in surgeon.modules.values():
module.reset_state(batch_size=batch_size)


@torch.no_grad()
def log_likelihood_score(
model,
tokenizer,
context: str,
continuations: List[str],
device: Optional[str] = None,
) -> List[float]:
"""
Compute normalized log-likelihood for each continuation given a context.
For each (context, continuation) pair:
1. Tokenize context + continuation together
2. Run forward pass to get logits
3. Compute average log-prob over the continuation tokens only
Args:
model: The language model
tokenizer: The tokenizer
context: The prompt/context string
continuations: List of possible continuations to score
device: Device to use (default: auto — cuda, then mps, then cpu)
Returns:
List of normalized log-likelihood scores (higher = model prefers this continuation)
"""
if device is None:
device = resolve_torch_device("auto")
scores = []
for cont in continuations:
# Tokenize context and full sequence separately to find where continuation starts
ctx_ids = tokenizer.encode(context, add_special_tokens=False)
full_text = context + cont
full_ids = tokenizer.encode(full_text, add_special_tokens=False)
# The continuation tokens start after the context tokens
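        # (This assumes the tokenizer does not merge tokens across the boundary,
        # i.e. full_ids[:len(ctx_ids)] == ctx_ids. That holds for typical BPE
        # tokenizers when the continuation starts with whitespace, but can be
        # off by a token otherwise.)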
cont_start = len(ctx_ids)
cont_length = len(full_ids) - cont_start
if cont_length <= 0:
scores.append(float("-inf"))
continue
# Forward pass
input_ids = torch.tensor([full_ids], device=device)
reset_cortex_state(model, batch_size=input_ids.shape[0])
        # Truncate if the sequence exceeds the model's context window. Right-
        # truncation can clip the continuation; in that case only the surviving
        # prefix of the continuation is scored.
max_len = getattr(model.config, "max_position_embeddings", 2048)
if input_ids.shape[1] > max_len:
input_ids = input_ids[:, :max_len]
cont_length = min(cont_length, max_len - cont_start)
if cont_length <= 0:
scores.append(float("-inf"))
continue
outputs = model(input_ids)
logits = outputs.logits # [1, seq_len, vocab_size]
# Shift: logits[i] predicts token[i+1]
# For continuation tokens at positions [cont_start, cont_start+cont_length),
# we need logits at positions [cont_start-1, cont_start+cont_length-1)
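        # Worked example: with 3 context tokens and a 2-token continuation
        # (positions 3-4), logits at positions 2 and 3 predict tokens 3 and 4,
        # so we slice logits[0, 2:4] and compare against input_ids[0, 3:5].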
shift_logits = logits[0, cont_start - 1 : cont_start + cont_length - 1, :]
shift_labels = input_ids[0, cont_start : cont_start + cont_length]
# Log-probabilities
log_probs = F.log_softmax(shift_logits, dim=-1)
token_log_probs = log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1)
# Normalize by continuation length (average log-prob per token)
avg_log_prob = token_log_probs.mean().item()
scores.append(avg_log_prob)
return scores


@torch.no_grad()
def generate_and_check(
model,
tokenizer,
prompt: str,
expected: str,
max_new_tokens: int = 64,
device: Optional[str] = None,
exact_match: bool = False,
) -> Tuple[bool, str]:
"""
Generate text and check if the expected answer appears in the output.
Args:
model: The language model
tokenizer: The tokenizer
prompt: The input prompt
expected: The expected answer string
max_new_tokens: Max tokens to generate
device: Device (default: auto — cuda, then mps, then cpu)
exact_match: If True, requires exact match; otherwise substring match
Returns:
(is_correct, generated_text)
"""
if device is None:
device = resolve_torch_device("auto")
inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
reset_cortex_state(model, batch_size=inputs["input_ids"].shape[0])
    # Fall back to EOS when the tokenizer defines no pad token (common for
    # decoder-only models).
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
pad_token_id = tokenizer.eos_token_id
output_ids = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
pad_token_id=pad_token_id,
)
# Decode only the new tokens
new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
generated = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
if exact_match:
is_correct = generated.strip().lower() == expected.strip().lower()
else:
is_correct = expected.strip().lower() in generated.lower()
return is_correct, generated


def accuracy_from_loglikelihoods(
scores_per_example: List[Tuple[List[float], int]],
) -> Dict[str, float]:
"""
Compute accuracy from log-likelihood scores.
Args:
scores_per_example: List of (scores_for_each_choice, correct_index)
Returns:
Dict with accuracy and count metrics
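
    Example:
        >>> accuracy_from_loglikelihoods([([-1.0, -2.0], 0), ([-3.0, -1.5], 0)])
        {'accuracy': 0.5, 'correct': 1, 'total': 2}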
"""
correct = 0
total = len(scores_per_example)
for scores, gold_idx in scores_per_example:
predicted = max(range(len(scores)), key=lambda i: scores[i])
if predicted == gold_idx:
correct += 1
return {
"accuracy": correct / total if total > 0 else 0.0,
"correct": correct,
"total": total,
}
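

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the harness): exercises both
    # scoring modes with a small Hugging Face checkpoint. Assumes the
    # `transformers` package is installed; any causal LM (here "gpt2") works.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = resolve_torch_device("auto")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

    # Log-likelihood mode: the model should prefer the factual continuation.
    scores = log_likelihood_score(
        model,
        tokenizer,
        "The capital of France is",
        [" Paris", " London"],
        device=device,
    )
    print("log-likelihoods:", scores)
    print(accuracy_from_loglikelihoods([(scores, 0)]))

    # Generation mode: greedy decode, then substring-match the expected answer.
    ok, text = generate_and_check(
        model,
        tokenizer,
        "Question: What is 2 + 2?\nAnswer:",
        "4",
        max_new_tokens=8,
        device=device,
    )
    print("generation correct:", ok, "| output:", text)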