"""
Training utilities for supervised Cortex adapter tuning.

These helpers keep the base model frozen and optimize only the modules managed by
CortexSurgeon. They intentionally mirror benchmark log-likelihood scoring so a
small tuning run optimizes the same multiple-choice objective being evaluated.
"""

from __future__ import annotations

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F

from benchmark.scoring import reset_cortex_state
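

# A minimal sketch of the setup the module docstring describes: freeze every
# base-model weight and hand only the CortexSurgeon-managed parameters to the
# optimizer. The `_cortex_surgeon` attribute and its `modules` mapping are the
# same ones `cortex_auxiliary_loss` relies on below; the function name and the
# rest of the wiring are illustrative assumptions, not part of an established
# API.
def trainable_cortex_parameters(model) -> List[torch.nn.Parameter]:
    """Freeze the base model and return only Cortex-managed parameters."""
    for param in model.parameters():
        param.requires_grad_(False)
    surgeon = getattr(model, "_cortex_surgeon", None)
    if surgeon is None:
        return []
    params: List[torch.nn.Parameter] = []
    for module in surgeon.modules.values():
        for param in module.parameters():
            param.requires_grad_(True)
            params.append(param)
    return params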


def continuation_log_likelihood(
    model,
    tokenizer,
    context: str,
    continuation: str,
    device: str,
) -> Optional[torch.Tensor]:
"""Differentiable average continuation log-likelihood."""
ctx_ids = tokenizer.encode(context, add_special_tokens=False)
full_ids = tokenizer.encode(context + continuation, add_special_tokens=False)
cont_start = len(ctx_ids)
cont_length = len(full_ids) - cont_start
if cont_start <= 0 or cont_length <= 0:
return None
input_ids = torch.tensor([full_ids], device=device)
max_len = getattr(model.config, "max_position_embeddings", 2048)
if input_ids.shape[1] > max_len:
input_ids = input_ids[:, :max_len]
cont_length = min(cont_length, max_len - cont_start)
if cont_length <= 0:
return None
reset_cortex_state(model, batch_size=input_ids.shape[0])
outputs = model(input_ids)
logits = outputs.logits
shift_logits = logits[0, cont_start - 1 : cont_start + cont_length - 1, :]
shift_labels = input_ids[0, cont_start : cont_start + cont_length]
log_probs = F.log_softmax(shift_logits, dim=-1)
token_log_probs = log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1)
return token_log_probs.mean()
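
# Worked example of the indexing above (token counts are illustrative): if the
# context encodes to 5 tokens and the continuation to 3, then cont_start == 5
# and the logits at positions 4..6 predict the tokens at positions 5..7. The
# result is the mean of those three token log-probabilities, kept
# differentiable so it can serve directly as a training signal.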


def multiple_choice_loss(
    model,
    tokenizer,
    example: Dict,
    device: str,
) -> Tuple[Optional[torch.Tensor], Optional[int]]:
    """
    Cross-entropy over continuation log-likelihoods.

    Returns:
        (loss, prediction). If an example cannot be scored, both are None.
    """
    scores: List[torch.Tensor] = []
    for continuation in example["continuations"]:
        score = continuation_log_likelihood(
            model, tokenizer, example["context"], continuation, device
        )
        if score is None:
            return None, None
        scores.append(score)
    # Treat the per-continuation log-likelihoods as logits over the choices, so
    # the tuning loss matches the argmax decision rule used at evaluation time.
    logits = torch.stack(scores).unsqueeze(0)
    gold = torch.tensor([example["gold_idx"]], device=device)
    loss = F.cross_entropy(logits, gold)
    pred = int(logits.argmax(dim=-1).item())
    return loss, pred
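
# Expected example schema (the keys match the accesses above; the values here
# are made up for illustration):
#
#     example = {
#         "context": "Q: What color is the sky? A:",
#         "continuations": [" blue", " green", " red"],
#         "gold_idx": 0,
#     }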


def cortex_auxiliary_loss(model) -> torch.Tensor:
"""Collect differentiable auxiliary losses exposed by Cortex modules."""
device = next(model.parameters()).device
surgeon = getattr(model, "_cortex_surgeon", None)
if surgeon is None:
return torch.tensor(0.0, device=device)
losses = []
for module in surgeon.modules.values():
get_budget_loss = getattr(module, "get_budget_loss", None)
if get_budget_loss is not None:
losses.append(get_budget_loss())
if not losses:
return torch.tensor(0.0, device=device)
return torch.stack([loss.to(device) for loss in losses]).sum()
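

# A hedged end-to-end sketch of how these helpers are meant to compose. The
# function name, the `aux_weight` coefficient, and the optimizer wiring are
# assumptions, not part of this module's established API. Typical setup
# (illustrative):
#
#     optimizer = torch.optim.AdamW(trainable_cortex_parameters(model), lr=1e-4)
def tuning_step(
    model,
    tokenizer,
    example: Dict,
    optimizer: torch.optim.Optimizer,
    device: str,
    aux_weight: float = 1.0,
) -> Optional[float]:
    """One supervised step: task loss plus weighted Cortex auxiliary losses."""
    loss, _pred = multiple_choice_loss(model, tokenizer, example, device)
    if loss is None:
        return None
    total = loss + aux_weight * cortex_auxiliary_loss(model)
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    return total.item()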