"""Evaluation: factual probes + sampled factual English scoring.

Extracted from train.py (W1 modularization). Semantics unchanged.

Perf optimizations (eval_perf_fix):
- Probe mode: single forward per prompt instead of autoregressive gen
- Batch decode: all GPU work first, all CPU decode after
- Batched factual probes: single padded forward instead of N sequential
  (currently not in effect: run_factual_probes runs one prompt per forward
  to stay under the cooperative launch limit)
"""
from __future__ import annotations
import os
import re as _re
import torch
from hydra.config import FACTUAL_SAMPLES, FACTUAL_BATCH, FACTUAL_GEN_TOKENS, USE_MDLM, MDLM_MASK_ID
from hydra.mdlm_decode import mdlm_next_token_logits
# Scoring mode used by run_factual_english. "probe" (the default) costs one
# forward pass per prompt; export HYDRA_FACTUAL_MODE=gen to select the
# original autoregressive generation path instead.
FACTUAL_MODE = os.getenv("HYDRA_FACTUAL_MODE", "probe")
def _next_token_logits(model, x: torch.Tensor) -> torch.Tensor:
    """Compute next-token logits for ``x``, honouring the MDLM training mode.

    Audit 2026-05-09 issue #16: a model trained with MDLM learns to
    reconstruct masked positions rather than to predict the next token
    autoregressively, so ``model(x)[:, -1, :]`` reads the wrong distribution.
    With ``USE_MDLM`` set, the call is delegated to ``mdlm_next_token_logits``,
    which appends a single MASK slot and returns the prediction at that slot.

    Without MDLM the model is called with ``targets=None``; a 3-D output is
    reduced to its last position, and the result is cast to float.

    Returns a 2D tensor of shape (B, V) in float precision (for the MDLM
    branch this relies on the helper's contract, which is not visible here).
    """
    if not USE_MDLM:
        out = model(x, targets=None)
        # Per-position output: keep only the final position of each row.
        last = out[:, -1, :] if out.dim() == 3 else out
        return last.float()

    cfg = model.config
    resolved_mask_id = MDLM_MASK_ID
    if resolved_mask_id < 0:
        # A negative MDLM_MASK_ID is a sentinel for "use vocab_size - 1"; the
        # mdlm_decode helper validates the final id against the vocab size.
        resolved_mask_id = int(getattr(cfg, "vocab_size", 0)) - 1
    return mdlm_next_token_logits(
        model,
        x,
        mask_id=resolved_mask_id,
        vocab_size=int(cfg.vocab_size),
    )
# (prompt, accepted answers) pairs scored by the factual English evaluation.
# Answers are compared case-insensitively by the scorers in this module.
FACTUAL_EVAL: list[tuple[str, list[str]]] = [
    # --- Hard factual recall: needs memorized, specific knowledge ---
    ("The capital of France is", ["Paris", "paris"]),
    ("Water boils at", ["100", "boiling"]),
    ("The largest planet in our solar system is", ["Jupiter", "jupiter"]),
    # --- Easier completions: common collocations the model may pick up ---
    ("Once upon a", ["time"]),
    ("Hello, my name", ["is", "'s"]),
    (
        "The cat sat on the",
        ["mat", "floor", "rug", "table", "couch", "chair", "ground"],
    ),
    (
        "She opened the door and",
        ["walked", "saw", "found", "stepped", "looked", "went", "ran"],
    ),
    # --- Original hard prompts, kept for completeness ---
    (
        "The speed of light is approximately",
        ["299", "300", "186,000", "light speed"],
    ),
    ("Two plus two equals", ["4", "four"]),
]
# Prompts whose top next-token predictions are printed by run_factual_probes.
_FACTUAL_PROBES: list[str] = [
    "The capital of France is",
    "Water boils at",
    "The largest planet in our solar system is",
    "The speed of light is approximately",
    "Shakespeare wrote",
]
def run_factual_probes(model, tokenizer, device, autocast_ctx) -> None:
    """Print the top-3 next-token predictions for each canonical factual prompt.

    Prompts are run one at a time rather than as one padded batch: a batched
    forward with B=len(probes) can exceed the cooperative launch limit (SM
    residency cap).

    Logits are read through ``_next_token_logits`` so the MDLM contract is
    honoured when MDLM training is active (audit 2026-05-09 issue #16), the
    same way the factual English scorers in this module do it.

    Args:
        model: Model under evaluation. It is switched to eval mode and left
            in that mode on return.
        tokenizer: Object exposing ``encode(text)`` and ``decode(ids)``.
        device: Device on which each prompt tensor is created.
        autocast_ctx: Context manager entered around every forward pass.
    """
    shown = 3  # number of candidates printed per prompt
    print("\n--- Factual Probes ---")
    model.eval()
    for prompt_text in _FACTUAL_PROBES:
        ids = tokenizer.encode(prompt_text)
        x = torch.tensor([ids], device=device, dtype=torch.long)
        with torch.no_grad(), autocast_ctx:
            last_logits = _next_token_logits(model, x)[0]
        # Softmax in float32 regardless of the autocast dtype.
        probs = torch.softmax(last_logits.float(), dim=-1)
        # Only the printed candidates are extracted and decoded; the cap
        # guards against a vocabulary smaller than ``shown``.
        top = torch.topk(probs, min(shown, probs.shape[-1]))
        completions = [tokenizer.decode([tid]) for tid in top.indices.tolist()]
        probs_list = [f"{p:.4f}" for p in top.values.tolist()]
        print(f' "{prompt_text}" -> {completions} (p={probs_list})')
    print("--- End Factual Probes ---\n")
# ---------------------------------------------------------------------------
# Probe mode: single forward per prompt (Fix D)
# ---------------------------------------------------------------------------
def _run_factual_english_probe(model, tokenizer, max_seq_len: int):
    """Score FACTUAL_EVAL with one forward pass per prompt (fast probe mode).

    Only the prompt is encoded. The next-token distribution is read through
    ``_next_token_logits`` and its top-K tokens (K = 20, capped at the vocab
    size) are decoded one by one, stripped and lower-cased. A prompt counts
    as a hit when any lower-cased answer is a substring of any of those
    decoded tokens. K is deliberately generous to roughly match gen mode,
    which draws several samples at multiple temperatures.

    NOTE(review): because each candidate is a single decoded token, an answer
    that the tokenizer splits across several tokens cannot match here.

    The forward passes run on CUDA under bf16 autocast (both hard-coded).

    Args:
        model: Model under evaluation; switched to eval mode and left there.
        tokenizer: Object exposing ``encode(text)`` and ``decode(ids)``.
        max_seq_len: Unused in probe mode; accepted so the signature matches
            the gen-mode scorer.

    Returns:
        ``(score, hits, total)`` where ``total`` is ``len(FACTUAL_EVAL)`` and
        ``score`` is ``hits / total``.
    """
    top_k_cap = 20
    print("---")
    print("factual_english_samples: (probe mode)")
    model.eval()
    hits = 0
    with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        for prompt, answers in FACTUAL_EVAL:
            prompt_ids = tokenizer.encode(prompt)
            x = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
            # Audit 2026-05-09 #16: route through MDLM contract if active.
            last_logits = _next_token_logits(model, x)[0]
            probs = torch.softmax(last_logits, dim=-1)
            top_k = min(top_k_cap, probs.shape[-1])
            top_ids = torch.topk(probs, top_k).indices.tolist()
            top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in top_ids]
            answers_lower = [a.lower() for a in answers]
            any_hit = any(a in tok for tok in top_tokens for a in answers_lower)
            if any_hit:
                hits += 1
            # The sample line shows the argmax token, whether or not it hit.
            best_completion = tokenizer.decode([top_ids[0]])
            print(f" prompt: {prompt!r}")
            print(f" output: {(prompt + best_completion).replace(chr(10), ' ')!r}")
            print(f" hit: {any_hit} (probe top-{top_k})")
    score = hits / len(FACTUAL_EVAL)
    print("---")
    print(f"factual_english_score: {score:.4f}")
    print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
    return score, hits, len(FACTUAL_EVAL)
# ---------------------------------------------------------------------------
# Gen mode: original autoregressive path (Fix F: batch decode)
# ---------------------------------------------------------------------------
def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
    """Score FACTUAL_EVAL by sampling autoregressive continuations (gen mode).

    For every prompt, ``FACTUAL_SAMPLES`` continuations of up to
    ``FACTUAL_GEN_TOKENS`` tokens are sampled in batches of ``FACTUAL_BATCH``,
    cycling the temperature per batch through 0.7 / 0.9 / 1.1. All GPU work
    for a prompt finishes first; decoding and matching then happen on the CPU.

    Matching: the decoded continuation is split into words with
    ``\\b[\\w'-]+\\b`` and a sample hits when a lower-cased answer equals one
    of those words. Answers that this regex can never emit as a single word
    (e.g. ``"186,000"``, ``"light speed"``, ``"'s"``) would otherwise be
    unmatchable, so they are matched as a case-insensitive substring of the
    continuation instead.

    The forward passes run on CUDA under bf16 autocast (both hard-coded).

    Args:
        model: Model under evaluation; switched to eval mode and left there.
        tokenizer: Object exposing ``encode(text)`` and ``decode(ids)``.
        max_seq_len: Sampling for a batch stops once the context reaches
            this many tokens.

    Returns:
        ``(score, hits, total)`` where ``total`` is ``len(FACTUAL_EVAL)`` and
        ``score`` is ``hits / total``.
    """
    print("---")
    print("factual_english_samples: (gen mode)")
    model.eval()
    num_samples = FACTUAL_SAMPLES
    batch = FACTUAL_BATCH
    gen_tokens = FACTUAL_GEN_TOKENS
    temps = [0.7, 0.9, 1.1]
    word_re = _re.compile(r"\b[\w'-]+\b")
    hits = 0
    with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        for prompt, answers in FACTUAL_EVAL:
            ids = tokenizer.encode(prompt)
            answers_lower = [a.lower() for a in answers]
            # Single-word answers are compared against the word set; anything
            # the word regex cannot produce whole falls back to substring search.
            word_answers = {a for a in answers_lower if word_re.fullmatch(a)}
            phrase_answers = [
                a for a in answers_lower if a and not word_re.fullmatch(a)
            ]

            # Collect all generated token sequences on GPU first
            all_rows: list[list[int]] = []
            samples_done = 0
            batch_idx = 0
            while samples_done < num_samples:
                b = min(batch, num_samples - samples_done)
                temp = temps[batch_idx % len(temps)]
                batch_idx += 1
                ctx = torch.tensor([ids] * b, device="cuda", dtype=torch.long)
                for _ in range(gen_tokens):
                    # Audit 2026-05-09 #16: route through MDLM contract if active.
                    next_logits = _next_token_logits(model, ctx)
                    probs = torch.softmax(next_logits / temp, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)
                    ctx = torch.cat([ctx, next_id], dim=1)
                    if ctx.size(1) >= max_seq_len:
                        break
                # Transfer to CPU in one shot, no per-row sync
                all_rows.extend(ctx.cpu().tolist())
                samples_done += b

            # CPU-side decode. Only the first sample and the first hitting
            # sample are reported, so decoding stops at the first hit.
            first_gen = None
            hit_gen = None
            for row in all_rows:
                generated = tokenizer.decode(row)
                if first_gen is None:
                    first_gen = generated
                # NOTE(review): slicing by len(prompt) assumes the decoded text
                # starts with the prompt verbatim -- confirm for this tokenizer.
                continuation = generated[len(prompt):].strip()
                words = {w.lower() for w in word_re.findall(continuation)}
                lowered = continuation.lower()
                if not word_answers.isdisjoint(words) or any(
                    phrase in lowered for phrase in phrase_answers
                ):
                    hit_gen = generated
                    break
            any_hit = hit_gen is not None
            if any_hit:
                hits += 1
            print(f" prompt: {prompt!r}")
            print(f" output: {(first_gen or '').replace(chr(10), ' ')!r}")
            print(f" hit: {any_hit} (any of {num_samples} samples, temps={temps}, gen={gen_tokens}tok)")
            if hit_gen is not None and hit_gen != first_gen:
                print(f" hit_sample: {hit_gen.replace(chr(10), ' ')!r}")
    score = hits / len(FACTUAL_EVAL)
    print("---")
    print(f"factual_english_score: {score:.4f}")
    print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
    return score, hits, len(FACTUAL_EVAL)
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
def run_factual_english(model, tokenizer, max_seq_len: int):
    """Run the factual English evaluation in the configured mode.

    ``FACTUAL_MODE == "gen"`` selects the autoregressive scorer; any other
    value selects the fast probe scorer (the default). Set
    HYDRA_FACTUAL_MODE=gen to use the autoregressive path. The chosen
    scorer's return value is passed through unchanged.
    """
    scorer = (
        _run_factual_english_gen
        if FACTUAL_MODE == "gen"
        else _run_factual_english_probe
    )
    return scorer(model, tokenizer, max_seq_len)
|