cortex/benchmark/tuning.py
theapemachine's picture
Enhance benchmark and Cortex modules with new training utilities and improved state management. Update README with example output for Llama-3.2-1B and add training CLI for Cortex module tuning. Refactor scoring functions to reset Cortex state between examples and ensure consistent output. Modify task handling to ensure proper formatting of input data.
0de2901
"""
Training utilities for supervised Cortex adapter tuning.
These helpers keep the base model frozen and optimize only the modules managed by
CortexSurgeon. They intentionally mirror benchmark log-likelihood scoring so a
small tuning run optimizes the same multiple-choice objective being evaluated.
"""
from __future__ import annotations
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from benchmark.scoring import reset_cortex_state
def continuation_log_likelihood(
    model,
    tokenizer,
    context: str,
    continuation: str,
    device: str,
) -> Optional[torch.Tensor]:
    """Average log-probability of *continuation* given *context*, kept differentiable.

    Mirrors the benchmark's multiple-choice scoring: the score is the mean
    per-token log-likelihood of the continuation tokens only.

    Returns:
        A scalar tensor (attached to the autograd graph), or ``None`` when
        the example cannot be scored — empty context/continuation, or the
        continuation lies entirely beyond the model's context window.
    """
    context_ids = tokenizer.encode(context, add_special_tokens=False)
    combined_ids = tokenizer.encode(context + continuation, add_special_tokens=False)
    start = len(context_ids)
    n_cont = len(combined_ids) - start
    if start <= 0 or n_cont <= 0:
        return None

    tokens = torch.tensor([combined_ids], device=device)
    window = getattr(model.config, "max_position_embeddings", 2048)
    if tokens.shape[1] > window:
        # Right-truncate to the model's window; any continuation tail past
        # the window cannot be scored.
        tokens = tokens[:, :window]
        n_cont = min(n_cont, window - start)
        if n_cont <= 0:
            return None

    # Fresh Cortex state so tuning matches per-example benchmark scoring.
    reset_cortex_state(model, batch_size=tokens.shape[0])
    logits = model(tokens).logits

    # Logits at position t predict the token at position t + 1, hence the
    # one-step offset between the logit slice and the label slice.
    pred_logits = logits[0, start - 1 : start + n_cont - 1, :]
    target_ids = tokens[0, start : start + n_cont]
    per_token = (
        F.log_softmax(pred_logits, dim=-1)
        .gather(1, target_ids.unsqueeze(1))
        .squeeze(1)
    )
    return per_token.mean()
def multiple_choice_loss(
    model,
    tokenizer,
    example: Dict,
    device: str,
) -> Tuple[Optional[torch.Tensor], Optional[int]]:
    """
    Cross-entropy loss over the per-continuation log-likelihoods of one example.

    Each candidate continuation is scored against the shared context; the
    resulting scores are treated as class logits against ``gold_idx``.

    Returns:
        (loss, prediction). If any continuation cannot be scored, the whole
        example is skipped and both values are None.
    """
    per_choice: List[torch.Tensor] = []
    for choice in example["continuations"]:
        ll = continuation_log_likelihood(
            model, tokenizer, example["context"], choice, device
        )
        if ll is None:
            # One unscorable choice invalidates the whole example.
            return None, None
        per_choice.append(ll)

    choice_logits = torch.stack(per_choice).unsqueeze(0)  # shape (1, n_choices)
    target = torch.tensor([example["gold_idx"]], device=device)
    loss = F.cross_entropy(choice_logits, target)
    prediction = int(choice_logits.argmax(dim=-1).item())
    return loss, prediction
def cortex_auxiliary_loss(model) -> torch.Tensor:
    """Sum the differentiable auxiliary (budget) losses of Cortex modules.

    Looks for a ``_cortex_surgeon`` attribute on the model and sums the
    ``get_budget_loss()`` output of every managed module that exposes the
    hook. Returns a zero tensor on the model's device when no surgeon is
    attached or no module provides a budget loss.
    """
    target_device = next(model.parameters()).device
    surgeon = getattr(model, "_cortex_surgeon", None)
    if surgeon is None:
        return torch.tensor(0.0, device=target_device)

    collected = [
        module.get_budget_loss()
        for module in surgeon.modules.values()
        if getattr(module, "get_budget_loss", None) is not None
    ]
    if not collected:
        return torch.tensor(0.0, device=target_device)
    return torch.stack([item.to(target_device) for item in collected]).sum()