"""
Scoring utilities for the Cortex benchmark harness.

Two evaluation modes:
  1. Log-likelihood scoring: For multiple-choice tasks (HellaSwag, ARC, PIQA, etc.)
     Computes the average log-probability the model assigns to each continuation.
     
  2. Generation scoring: For free-form generation tasks (passkey retrieval, etc.)
     Generates text and checks the output against the expected answer
     (exact or substring match).
"""

import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Dict

from cortex.torch_device import resolve_torch_device


def reset_cortex_state(model, batch_size: int = 1):
    """Reset runtime state for injected Cortex modules between independent examples."""
    surgeon = getattr(model, "_cortex_surgeon", None)
    if surgeon is None:
        return
    for module in surgeon.modules.values():
        module.reset_state(batch_size=batch_size)


@torch.no_grad()
def log_likelihood_score(
    model,
    tokenizer,
    context: str,
    continuations: List[str],
    device: Optional[str] = None,
) -> List[float]:
    """
    Compute normalized log-likelihood for each continuation given a context.
    
    For each (context, continuation) pair:
      1. Tokenize context + continuation together
      2. Run forward pass to get logits
      3. Compute average log-prob over the continuation tokens only
    
    Args:
        model: The language model
        tokenizer: The tokenizer
        context: The prompt/context string
        continuations: List of possible continuations to score
        device: Device to use (default: auto — cuda, then mps, then cpu)
        
    Returns:
        List of normalized log-likelihood scores (higher = model prefers this continuation)
    """
    if device is None:
        device = resolve_torch_device("auto")
    scores = []
    
    for cont in continuations:
        # Tokenize context and full sequence separately to find where continuation starts
        ctx_ids = tokenizer.encode(context, add_special_tokens=False)
        full_text = context + cont
        full_ids = tokenizer.encode(full_text, add_special_tokens=False)
        
        # The continuation tokens start after the context tokens
        cont_start = len(ctx_ids)
        cont_length = len(full_ids) - cont_start
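        # NOTE (assumption): this treats the context's tokens as an exact prefix
        # of tokenize(context + cont). BPE tokenizers can merge across the
        # boundary (e.g. when `cont` does not start with a space), which would
        # shift the split point. It also assumes a non-empty context, since
        # scoring the first continuation token needs a preceding position.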
        
        if cont_length <= 0:
            scores.append(float("-inf"))
            continue
        
        # Forward pass
        input_ids = torch.tensor([full_ids], device=device)
        reset_cortex_state(model, batch_size=input_ids.shape[0])
        
        # Truncate if too long for model
        max_len = getattr(model.config, "max_position_embeddings", 2048)
        if input_ids.shape[1] > max_len:
            input_ids = input_ids[:, :max_len]
            cont_length = min(cont_length, max_len - cont_start)
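            # e.g. with max_len=2048, cont_start=2040, cont_length=20, only the
            # first 8 continuation tokens survive the right-truncation.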
            if cont_length <= 0:
                scores.append(float("-inf"))
                continue
        
        outputs = model(input_ids)
        logits = outputs.logits  # [1, seq_len, vocab_size]
        
        # Shift: logits[i] predicts token[i+1]
        # For continuation tokens at positions [cont_start, cont_start+cont_length),
        # we need logits at positions [cont_start-1, cont_start+cont_length-1)
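        # Example: ctx tokens [A, B] (cont_start=2) and cont tokens [C, D]
        # (cont_length=2): C is predicted by the logits at position 1, D by the
        # logits at position 2, i.e. logits[1:3] scored against labels [C, D].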
        shift_logits = logits[0, cont_start - 1 : cont_start + cont_length - 1, :]
        shift_labels = input_ids[0, cont_start : cont_start + cont_length]
        
        # Log-probabilities
        log_probs = F.log_softmax(shift_logits, dim=-1)
        token_log_probs = log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1)
        
        # Normalize by continuation length (average log-prob per token)
        avg_log_prob = token_log_probs.mean().item()
        scores.append(avg_log_prob)
    
    return scores


@torch.no_grad()
def generate_and_check(
    model,
    tokenizer,
    prompt: str,
    expected: str,
    max_new_tokens: int = 64,
    device: Optional[str] = None,
    exact_match: bool = False,
) -> Tuple[bool, str]:
    """
    Generate text and check if the expected answer appears in the output.
    
    Args:
        model: The language model
        tokenizer: The tokenizer
        prompt: The input prompt
        expected: The expected answer string
        max_new_tokens: Max tokens to generate
        device: Device (default: auto — cuda, then mps, then cpu)
        exact_match: If True, requires exact match; otherwise substring match
        
    Returns:
        (is_correct, generated_text)
    """
    if device is None:
        device = resolve_torch_device("auto")
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    reset_cortex_state(model, batch_size=inputs["input_ids"].shape[0])
    
    # Pad token
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding; temperature is ignored when not sampling
        pad_token_id=pad_token_id,
    )
    
    # Decode only the new tokens
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    generated = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    
    # Case-insensitive comparison: e.g. expected "Paris" matches a generation
    # of "the answer is paris" in substring mode, but not under exact_match.
    if exact_match:
        is_correct = generated.strip().lower() == expected.strip().lower()
    else:
        is_correct = expected.strip().lower() in generated.lower()
    
    return is_correct, generated


def accuracy_from_loglikelihoods(
    scores_per_example: List[Tuple[List[float], int]],
) -> Dict[str, float]:
    """
    Compute accuracy from log-likelihood scores.
    
    Args:
        scores_per_example: List of (scores_for_each_choice, correct_index)
        
    Returns:
        Dict with accuracy and count metrics
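
    Example:
        >>> accuracy_from_loglikelihoods([([-1.2, -0.4], 1), ([-0.3, -2.0], 1)])
        {'accuracy': 0.5, 'correct': 1, 'total': 2}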
    """
    correct = 0
    total = len(scores_per_example)
    
    for scores, gold_idx in scores_per_example:
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        if predicted == gold_idx:
            correct += 1
    
    return {
        "accuracy": correct / total if total > 0 else 0.0,
        "correct": correct,
        "total": total,
    }
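

if __name__ == "__main__":
    # Minimal smoke test / usage sketch, assuming `transformers` is installed.
    # The public `gpt2` checkpoint is used purely for illustration; any causal
    # LM (with or without injected Cortex modules) works the same way.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

    # Multiple-choice scoring: higher (less negative) score = preferred choice.
    scores = log_likelihood_score(
        model,
        tokenizer,
        context="The capital of France is",
        continuations=[" Paris", " London"],
        device="cpu",
    )
    print("log-likelihoods:", scores)

    # Free-form generation with substring matching.
    correct, text = generate_and_check(
        model,
        tokenizer,
        prompt="Q: What is 2 + 2?\nA:",
        expected="4",
        max_new_tokens=8,
        device="cpu",
    )
    print("generation correct:", correct, "| output:", repr(text))

    # Aggregate per-example choice scores into an accuracy metric.
    print("metrics:", accuracy_from_loglikelihoods([(scores, 0)]))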