File size: 9,960 Bytes
c383594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""Evaluation: factual probes + sampled factual English scoring.

Extracted from train.py (W1 modularization). Semantics unchanged.

Perf optimizations (eval_perf_fix):
- Probe mode: single forward per prompt instead of autoregressive gen
- Batch decode: all GPU work first, all CPU decode after
- Factual probes: one forward per prompt (a single padded batch can exceed the cooperative-launch SM residency cap)
"""

from __future__ import annotations

import os
import re as _re

import torch

from hydra.config import FACTUAL_SAMPLES, FACTUAL_BATCH, FACTUAL_GEN_TOKENS, USE_MDLM, MDLM_MASK_ID
from hydra.mdlm_decode import mdlm_next_token_logits

# Factual-eval strategy, read once at import time. "probe" (the default) does a
# single forward pass per prompt; "gen" selects the original autoregressive
# sampling path. Any other value is treated like "probe" by the dispatcher.
FACTUAL_MODE = os.getenv("HYDRA_FACTUAL_MODE", "probe")


def _next_token_logits(model, x: torch.Tensor) -> torch.Tensor:
    """Return next-token logits for every row of ``x``, honouring MDLM mode.

    Audit 2026-05-09 issue #16: when MDLM training is on, the model was
    trained to reconstruct masked positions, not to autoregressively predict
    the next token. Reading ``model(x)[:, -1, :]`` therefore measures the
    wrong distribution. Route through ``mdlm_next_token_logits``, which
    appends a single MASK slot and returns the prediction at that slot.

    Args:
        model: model under evaluation. In AR mode it is called as
            ``model(x, targets=None)``; in MDLM mode it must expose
            ``model.config.vocab_size``.
        x: token ids, built by the callers in this module as a
            ``(B, T)`` long tensor.

    Returns:
        A 2D tensor of shape (B, V), cast to float32 in both branches so
        downstream softmax/sampling does not run in reduced (autocast)
        precision.
    """
    if USE_MDLM:
        # Read the vocab size once; both the mask-id fallback and the helper
        # need the same value.
        vocab_size = int(model.config.vocab_size)
        # MDLM_MASK_ID < 0 is a sentinel meaning "use the last vocab id"; it
        # is resolved here before calling the helper.
        mask_id = MDLM_MASK_ID if MDLM_MASK_ID >= 0 else vocab_size - 1
        logits = mdlm_next_token_logits(
            model,
            x,
            mask_id=mask_id,
            vocab_size=vocab_size,
        )
        # Match the AR branch's float contract regardless of what dtype the
        # helper produces under autocast.
        return logits.float()
    logits = model(x, targets=None)
    if logits.dim() == 3:
        # (B, T, V) -> logits at the last position only.
        logits = logits[:, -1, :]
    return logits.float()

# Hard factual recall — requires specific knowledge memorization.
_EVAL_HARD_RECALL = [
    ("The capital of France is", ["Paris", "paris"]),
    ("Water boils at", ["100", "boiling"]),
    ("The largest planet in our solar system is", ["Jupiter", "jupiter"]),
]

# Easier completions — common collocations / patterns the model may pick up.
_EVAL_COLLOCATIONS = [
    ("Once upon a", ["time"]),
    ("Hello, my name", ["is", "'s"]),
    ("The cat sat on the", ["mat", "floor", "rug", "table", "couch", "chair", "ground"]),
    ("She opened the door and", ["walked", "saw", "found", "stepped", "looked", "went", "ran"]),
]

# Original hard prompts, kept for completeness.
_EVAL_LEGACY_HARD = [
    ("The speed of light is approximately", ["299", "300", "186,000", "light speed"]),
    ("Two plus two equals", ["4", "four"]),
]

# (prompt, accepted answers) pairs scored by run_factual_english, in order.
# Answers are lower-cased before matching in both probe and gen mode.
FACTUAL_EVAL = _EVAL_HARD_RECALL + _EVAL_COLLOCATIONS + _EVAL_LEGACY_HARD

# Prompts whose top next-token predictions run_factual_probes prints
# (diagnostic only; nothing is scored).
_FACTUAL_PROBES = [
    "The capital of France is",
    "Water boils at",
    "The largest planet in our solar system is",
    "The speed of light is approximately",
    "Shakespeare wrote",
]


def run_factual_probes(model, tokenizer, device, autocast_ctx) -> None:
    """Print the top-3 next-token predictions for each canonical factual prompt.

    Prompts are processed one at a time rather than as one padded batch: a
    batched forward with ``B=len(_FACTUAL_PROBES)`` can exceed the
    cooperative-launch SM residency cap. Logits are read through
    ``_next_token_logits`` so that, when MDLM training is on, the probe reads
    the MASK-slot prediction instead of ``model(x)[:, -1]`` (audit
    2026-05-09 #16), consistent with the scored factual eval.

    Args:
        model: model under evaluation; switched to ``eval()`` mode.
        tokenizer: used via ``encode(prompt)`` and ``decode([token_id])``.
        device: device the prompt ids are placed on.
        autocast_ctx: context manager entered around each forward pass.
    """
    print("\n--- Factual Probes ---")
    model.eval()

    for prompt_text in _FACTUAL_PROBES:
        x = torch.tensor([tokenizer.encode(prompt_text)], device=device)
        with torch.no_grad(), autocast_ctx:
            # Route through the MDLM-aware helper; returns (1, V) float logits.
            next_logits = _next_token_logits(model, x)[0]
        probs = torch.softmax(next_logits, dim=-1)
        # Only the top 3 are reported, so only decode those.
        top = torch.topk(probs, min(3, probs.shape[-1]))
        completions = [tokenizer.decode([tid]) for tid in top.indices.tolist()]
        probs_list = [f"{p:.4f}" for p in top.values.tolist()]
        print(f'  "{prompt_text}" -> {completions} (p={probs_list})')
    print("--- End Factual Probes ---\n")


# ---------------------------------------------------------------------------
# Probe mode: single forward per prompt (Fix D)
# ---------------------------------------------------------------------------

def _run_factual_english_probe(model, tokenizer, max_seq_len: int):
    """Score ``FACTUAL_EVAL`` with one forward pass per prompt (fast default).

    For each ``(prompt, answers)`` pair only the prompt is encoded. The
    next-token distribution is read through ``_next_token_logits``, so MDLM
    models are probed at a MASK slot. The prompt counts as a hit when any
    lower-cased answer is a substring of any of the top-K decoded, stripped,
    lower-cased candidate tokens. K=20 keeps the check roughly as generous as
    gen mode, which samples many continuations at several temperatures.
    Substring matching is deliberately loose (e.g. ``"is"`` also matches a
    ``"this"`` token).

    Args:
        model: model under evaluation; switched to ``eval()`` mode.
        tokenizer: used via ``encode(prompt)`` and ``decode([token_id])``.
        max_seq_len: unused in probe mode; accepted so the signature matches
            ``_run_factual_english_gen``.

    Returns:
        ``(score, hits, total)`` where ``score = hits / total`` (0.0 when
        ``FACTUAL_EVAL`` is empty).
    """
    print("---")
    print("factual_english_samples: (probe mode)")
    model.eval()
    hits = 0
    total = len(FACTUAL_EVAL)

    with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        for prompt, answers in FACTUAL_EVAL:
            x = torch.tensor([tokenizer.encode(prompt)], device="cuda", dtype=torch.long)
            # Audit 2026-05-09 #16: route through MDLM contract if active.
            last_logits = _next_token_logits(model, x)[0]

            probs = torch.softmax(last_logits, dim=-1)
            # Generous check: K=20 candidates to approximate multi-sample gen.
            top_k = min(20, probs.shape[-1])
            top_ids = torch.topk(probs, top_k).indices.tolist()
            top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in top_ids]

            answers_lower = [a.lower() for a in answers]
            any_hit = any(a in tok for tok in top_tokens for a in answers_lower)
            if any_hit:
                hits += 1

            best_completion = tokenizer.decode([top_ids[0]])
            print(f"  prompt: {prompt!r}")
            print(f"  output: {(prompt + best_completion).replace(chr(10), ' ')!r}")
            print(f"  hit:    {any_hit} (probe top-{top_k})")

    score = hits / total if total else 0.0
    print("---")
    print(f"factual_english_score: {score:.4f}")
    print(f"factual_english_hits:  {hits}/{total}")
    return score, hits, total


# ---------------------------------------------------------------------------
# Gen mode: original autoregressive path (Fix F: batch decode)
# ---------------------------------------------------------------------------

# Word tokens for gen-mode hit detection; compiled once instead of per row.
_WORD_RE = _re.compile(r"\b[\w'-]+\b")


def _continuation_hits(continuation: str, answers_lower: list[str]) -> bool:
    """Return True if ``continuation`` contains any of ``answers_lower``.

    Single-word answers (ones ``_WORD_RE`` matches in full, e.g. ``"paris"``
    or ``"100"``) must appear as a whole word, so ``"is"`` does not hit
    inside ``"this"``. Answers that are not a single word token (e.g.
    ``"light speed"``, ``"186,000"``, ``"'s"``) can never appear in the word
    set, so they are matched as case-insensitive substrings instead.
    """
    words = {w.lower() for w in _WORD_RE.findall(continuation)}
    text = continuation.lower()
    return any(
        (a in words) if _WORD_RE.fullmatch(a) else (a in text)
        for a in answers_lower
    )


def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
    """Score ``FACTUAL_EVAL`` by autoregressive sampling (original path).

    For each prompt, ``FACTUAL_SAMPLES`` continuations of up to
    ``FACTUAL_GEN_TOKENS`` tokens are sampled in batches of ``FACTUAL_BATCH``,
    cycling temperatures 0.7/0.9/1.1 per batch. All GPU sampling runs first;
    decoding and matching happen afterwards on the CPU. A prompt is a hit if
    any sample's continuation contains an accepted answer (see
    ``_continuation_hits``).

    Args:
        model: model under evaluation; switched to ``eval()`` mode.
        tokenizer: used via ``encode(prompt)`` and ``decode(token_ids)``.
        max_seq_len: generation stops once a row reaches this many tokens.

    Returns:
        ``(score, hits, total)`` where ``score = hits / total`` (0.0 when
        ``FACTUAL_EVAL`` is empty).
    """
    print("---")
    print("factual_english_samples: (gen mode)")
    model.eval()

    num_samples = FACTUAL_SAMPLES
    # Guard: a non-positive batch would never advance samples_done (infinite loop).
    batch = max(1, FACTUAL_BATCH)
    gen_tokens = FACTUAL_GEN_TOKENS
    temps = [0.7, 0.9, 1.1]
    hits = 0
    total = len(FACTUAL_EVAL)

    with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        for prompt, answers in FACTUAL_EVAL:
            ids = tokenizer.encode(prompt)
            prompt_len = len(ids)
            answers_lower = [a.lower() for a in answers]
            # Collect all generated token sequences on GPU first
            all_rows: list[list[int]] = []
            samples_done = 0
            batch_idx = 0
            while samples_done < num_samples:
                b = min(batch, num_samples - samples_done)
                temp = temps[batch_idx % len(temps)]
                batch_idx += 1
                ctx = torch.tensor([ids] * b, device="cuda", dtype=torch.long)
                for _ in range(gen_tokens):
                    # Check before the forward so a prompt that already fills
                    # the context window is never run past max_seq_len.
                    if ctx.size(1) >= max_seq_len:
                        break
                    # Audit 2026-05-09 #16: route through MDLM contract if active.
                    next_logits = _next_token_logits(model, ctx)
                    probs = torch.softmax(next_logits / temp, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)
                    ctx = torch.cat([ctx, next_id], dim=1)
                # Transfer to CPU in one shot, no per-row sync
                all_rows.extend(ctx.cpu().tolist())
                samples_done += b

            # CPU-side batch decode — no GPU sync between decodes
            any_hit = False
            first_gen = None
            hit_gen = None
            for row in all_rows:
                # Decode only the generated tokens instead of slicing the full
                # decoded text by len(prompt): decode(encode(prompt)) need not
                # reproduce the prompt exactly (e.g. a rendered BOS token).
                continuation = tokenizer.decode(row[prompt_len:]).strip()
                hit = _continuation_hits(continuation, answers_lower)
                # Full decodes are only needed for the (at most two) printed rows.
                if first_gen is None:
                    first_gen = tokenizer.decode(row)
                if hit:
                    any_hit = True
                    if hit_gen is None:
                        hit_gen = tokenizer.decode(row)
            if any_hit:
                hits += 1
            print(f"  prompt: {prompt!r}")
            print(f"  output: {(first_gen or '').replace(chr(10), ' ')!r}")
            print(f"  hit:    {any_hit} (any of {num_samples} samples, temps={temps}, gen={gen_tokens}tok)")
            if hit_gen is not None and hit_gen != first_gen:
                print(f"  hit_sample: {hit_gen.replace(chr(10), ' ')!r}")

    score = hits / total if total else 0.0
    print("---")
    print(f"factual_english_score: {score:.4f}")
    print(f"factual_english_hits:  {hits}/{total}")
    return score, hits, total


# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------

def run_factual_english(model, tokenizer, max_seq_len: int):
    """Score factual English completions with the configured strategy.

    ``FACTUAL_MODE`` (from ``HYDRA_FACTUAL_MODE``) picks the implementation:
    ``"gen"`` runs the autoregressive sampler; any other value, including the
    default ``"probe"``, runs the fast single-forward probe.

    Returns whatever the selected runner returns, i.e. ``(score, hits, total)``.
    """
    runner = (
        _run_factual_english_gen
        if FACTUAL_MODE == "gen"
        else _run_factual_english_probe
    )
    return runner(model, tokenizer, max_seq_len)