""" Training utilities for supervised Cortex adapter tuning. These helpers keep the base model frozen and optimize only the modules managed by CortexSurgeon. They intentionally mirror benchmark log-likelihood scoring so a small tuning run optimizes the same multiple-choice objective being evaluated. """ from __future__ import annotations from typing import Dict, List, Optional, Tuple import torch import torch.nn.functional as F from benchmark.scoring import reset_cortex_state def continuation_log_likelihood( model, tokenizer, context: str, continuation: str, device: str, ) -> Optional[torch.Tensor]: """Differentiable average continuation log-likelihood.""" ctx_ids = tokenizer.encode(context, add_special_tokens=False) full_ids = tokenizer.encode(context + continuation, add_special_tokens=False) cont_start = len(ctx_ids) cont_length = len(full_ids) - cont_start if cont_start <= 0 or cont_length <= 0: return None input_ids = torch.tensor([full_ids], device=device) max_len = getattr(model.config, "max_position_embeddings", 2048) if input_ids.shape[1] > max_len: input_ids = input_ids[:, :max_len] cont_length = min(cont_length, max_len - cont_start) if cont_length <= 0: return None reset_cortex_state(model, batch_size=input_ids.shape[0]) outputs = model(input_ids) logits = outputs.logits shift_logits = logits[0, cont_start - 1 : cont_start + cont_length - 1, :] shift_labels = input_ids[0, cont_start : cont_start + cont_length] log_probs = F.log_softmax(shift_logits, dim=-1) token_log_probs = log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1) return token_log_probs.mean() def multiple_choice_loss( model, tokenizer, example: Dict, device: str, ) -> Tuple[Optional[torch.Tensor], Optional[int]]: """ Cross-entropy over continuation log-likelihoods. Returns: (loss, prediction). If an example cannot be scored, both are None. """ scores: List[torch.Tensor] = [] for continuation in example["continuations"]: score = continuation_log_likelihood( model, tokenizer, example["context"], continuation, device ) if score is None: return None, None scores.append(score) logits = torch.stack(scores).unsqueeze(0) gold = torch.tensor([example["gold_idx"]], device=device) loss = F.cross_entropy(logits, gold) pred = int(logits.argmax(dim=-1).item()) return loss, pred def cortex_auxiliary_loss(model) -> torch.Tensor: """Collect differentiable auxiliary losses exposed by Cortex modules.""" device = next(model.parameters()).device surgeon = getattr(model, "_cortex_surgeon", None) if surgeon is None: return torch.tensor(0.0, device=device) losses = [] for module in surgeon.modules.values(): get_budget_loss = getattr(module, "get_budget_loss", None) if get_budget_loss is not None: losses.append(get_budget_loss()) if not losses: return torch.tensor(0.0, device=device) return torch.stack([loss.to(device) for loss in losses]).sum()