"""
Training utilities for supervised Cortex adapter tuning.

These helpers keep the base model frozen and optimize only the modules managed by
CortexSurgeon. They intentionally mirror benchmark log-likelihood scoring so a
small tuning run optimizes the same multiple-choice objective being evaluated.
"""

from __future__ import annotations

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F

from benchmark.scoring import reset_cortex_state


def continuation_log_likelihood(
    model,
    tokenizer,
    context: str,
    continuation: str,
    device: str,
) -> Optional[torch.Tensor]:
    """Differentiable average continuation log-likelihood."""
    ctx_ids = tokenizer.encode(context, add_special_tokens=False)
    full_ids = tokenizer.encode(context + continuation, add_special_tokens=False)

    cont_start = len(ctx_ids)
    cont_length = len(full_ids) - cont_start
    if cont_start <= 0 or cont_length <= 0:
        return None

    input_ids = torch.tensor([full_ids], device=device)
    # Truncate to the model's context window; continuation tokens that fall
    # outside it are dropped from scoring.
    max_len = getattr(model.config, "max_position_embeddings", 2048)
    if input_ids.shape[1] > max_len:
        input_ids = input_ids[:, :max_len]
        cont_length = min(cont_length, max_len - cont_start)
        if cont_length <= 0:
            return None

    # Reset any stateful Cortex modules so each example is scored independently.
    reset_cortex_state(model, batch_size=input_ids.shape[0])
    outputs = model(input_ids)
    logits = outputs.logits

    # Logits at position i predict token i + 1, so positions
    # [cont_start - 1, cont_start + cont_length - 1) score the continuation.
    shift_logits = logits[0, cont_start - 1 : cont_start + cont_length - 1, :]
    shift_labels = input_ids[0, cont_start : cont_start + cont_length]
    log_probs = F.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1)
    return token_log_probs.mean()


def multiple_choice_loss(
    model,
    tokenizer,
    example: Dict,
    device: str,
) -> Tuple[Optional[torch.Tensor], Optional[int]]:
    """
    Cross-entropy over continuation log-likelihoods.

    Returns:
        (loss, prediction). If an example cannot be scored, both are None.
    """
    scores: List[torch.Tensor] = []
    for continuation in example["continuations"]:
        score = continuation_log_likelihood(
            model, tokenizer, example["context"], continuation, device
        )
        if score is None:
            return None, None
        scores.append(score)

    # Treat the per-choice mean log-likelihoods as logits over the choices.
    logits = torch.stack(scores).unsqueeze(0)
    gold = torch.tensor([example["gold_idx"]], device=device)
    loss = F.cross_entropy(logits, gold)
    pred = int(logits.argmax(dim=-1).item())
    return loss, pred


def cortex_auxiliary_loss(model) -> torch.Tensor:
    """Collect differentiable auxiliary losses exposed by Cortex modules."""
    device = next(model.parameters()).device
    surgeon = getattr(model, "_cortex_surgeon", None)
    if surgeon is None:
        return torch.tensor(0.0, device=device)

    losses: List[torch.Tensor] = []
    for module in surgeon.modules.values():
        # Modules opt in by exposing a differentiable `get_budget_loss`.
        get_budget_loss = getattr(module, "get_budget_loss", None)
        if get_budget_loss is not None:
            losses.append(get_budget_loss())

    if not losses:
        return torch.tensor(0.0, device=device)
    return torch.stack([loss.to(device) for loss in losses]).sum()
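

# The shift-by-one indexing in `continuation_log_likelihood` can be sanity
# checked without torch. The numbers below are illustrative only: with two
# context tokens and two continuation tokens, the logits at positions 1 and 2
# predict the continuation tokens at positions 2 and 3, and the score is the
# mean of their log-probabilities.
if __name__ == "__main__":
    import math

    full_ids = [10, 11, 20, 21]  # two context tokens, two continuation tokens
    cont_start, cont_length = 2, 2
    # Next-token probabilities a hypothetical model assigns at each position.
    probs = {1: {20: 0.5}, 2: {21: 0.25}}
    positions = range(cont_start - 1, cont_start + cont_length - 1)
    labels = full_ids[cont_start : cont_start + cont_length]
    token_log_probs = [math.log(probs[p][t]) for p, t in zip(positions, labels)]
    print(round(sum(token_log_probs) / len(token_log_probs), 4))  # -1.0397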