Spaces:
Runtime error
Runtime error
"""Evaluation: factual probes + sampled factual English scoring.

Extracted from train.py (W1 modularization). Semantics unchanged.

Perf optimizations (eval_perf_fix):
- Probe mode: single forward per prompt instead of autoregressive generation.
- Batch decode: all GPU work first, all CPU decoding after.
- Factual probes: one forward per prompt (batch size 1); a single padded
  batch can exceed the cooperative-launch / SM residency cap.
"""
| from __future__ import annotations | |
| import os | |
| import re as _re | |
| import torch | |
| from hydra.config import FACTUAL_SAMPLES, FACTUAL_BATCH, FACTUAL_GEN_TOKENS | |
# Evaluation mode used by run_factual_english. Defaults to probe mode (one
# forward per prompt); set HYDRA_FACTUAL_MODE=gen for the original
# autoregressive generation path. The value is stripped and lower-cased so
# "GEN" or " gen " also select generation mode; anything else means probe.
FACTUAL_MODE = os.environ.get("HYDRA_FACTUAL_MODE", "probe").strip().lower()

# (prompt, acceptable answers) pairs scored by run_factual_english. Both
# scoring modes lower-case the answers before comparing them.
FACTUAL_EVAL = [
    # Hard factual recall — requires specific knowledge memorization
    ("The capital of France is", ["Paris", "paris"]),
    ("Water boils at", ["100", "boiling"]),
    ("The largest planet in our solar system is", ["Jupiter", "jupiter"]),
    # Easier completions — common collocations / patterns the model may pick up
    ("Once upon a", ["time"]),
    ("Hello, my name", ["is", "'s"]),
    ("The cat sat on the", ["mat", "floor", "rug", "table", "couch", "chair", "ground"]),
    ("She opened the door and", ["walked", "saw", "found", "stepped", "looked", "went", "ran"]),
    # Original hard ones kept for completeness
    ("The speed of light is approximately", ["299", "300", "186,000", "light speed"]),
    ("Two plus two equals", ["4", "four"]),
]

# Prompts whose top next-token predictions are printed by run_factual_probes.
_FACTUAL_PROBES = [
    "The capital of France is",
    "Water boils at",
    "The largest planet in our solar system is",
    "The speed of light is approximately",
    "Shakespeare wrote",
]
def run_factual_probes(model, tokenizer, device, autocast_ctx, *, probes=None) -> None:
    """Print the top next-token predictions for a set of factual prompts.

    Each prompt is run through the model on its own (batch size 1). For every
    prompt the three most likely next tokens and their probabilities are
    printed between "--- Factual Probes ---" / "--- End Factual Probes ---"
    markers.

    Args:
        model: callable mapping a [1, T] LongTensor of token ids to logits
            shaped [1, T, vocab] or [1, vocab]. ``model.eval()`` is called and
            training mode is not restored here.
        tokenizer: object with ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        device: device the input ids are created on.
        autocast_ctx: context manager (e.g. an autocast context) entered once
            around all forward passes.
        probes: optional iterable of prompt strings; defaults to
            ``_FACTUAL_PROBES``.
    """
    prompts = _FACTUAL_PROBES if probes is None else probes
    print("\n--- Factual Probes ---")
    model.eval()
    # Process prompts one at a time to avoid the cooperative launch limit
    # (a batched forward with B=len(prompts) can exceed the SM residency cap).
    # The contexts are entered once around the whole loop: a context-manager
    # object such as one built by contextlib.contextmanager is single-use and
    # would fail if re-entered for every prompt.
    with torch.no_grad(), autocast_ctx:
        for prompt_text in prompts:
            x = torch.tensor([tokenizer.encode(prompt_text)], device=device, dtype=torch.long)
            logits = model(x)
            # Accept [1, T, vocab] as well as [1, vocab] logits, like the
            # factual-English scorers in this module do.
            last_logits = logits[0, -1] if logits.dim() == 3 else logits[0]
            probs = torch.softmax(last_logits.float(), dim=-1)
            # Clamp k so tiny vocabularies do not make topk raise.
            top = torch.topk(probs, min(5, probs.shape[-1]))
            # Only the three most likely tokens are reported.
            completions = [tokenizer.decode([idx]) for idx in top.indices[:3].tolist()]
            probs_list = [f"{p:.4f}" for p in top.values[:3].tolist()]
            print(f' "{prompt_text}" -> {completions} (p={probs_list})')
    print("--- End Factual Probes ---\n")
# ---------------------------------------------------------------------------
# Probe mode: single forward per prompt (Fix D)
# ---------------------------------------------------------------------------
def _run_factual_english_probe(model, tokenizer, max_seq_len: int, *, eval_set=None):
    """Score factual prompts with one forward pass per prompt (fast probe mode).

    Each prompt is encoded on its own and run through the model once. The
    top-20 next-token predictions at the last position are decoded, stripped
    and lower-cased. The prompt counts as a hit if any lower-cased answer is a
    substring of any of those tokens. Checking top-K (rather than argmax) is
    deliberately generous, to stay comparable with gen mode, which samples at
    several temperatures.

    Args:
        model: callable taking a [1, T] LongTensor and ``targets=None`` and
            returning logits shaped [1, T, vocab] or [1, vocab]. Inputs are
            created on the device of the model's first parameter (CUDA if the
            model exposes none). ``model.eval()`` is called and not undone.
        tokenizer: object with ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        max_seq_len: unused in probe mode; kept so both modes share a
            signature.
        eval_set: optional list of ``(prompt, answers)`` pairs; defaults to
            ``FACTUAL_EVAL``.

    Returns:
        ``(score, hits, total)`` with ``score = hits / total`` (0.0 when the
        eval set is empty).
    """
    cases = FACTUAL_EVAL if eval_set is None else eval_set
    # Follow the model's device instead of assuming CUDA.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")
    # Only use CPU/CUDA autocast; any other device keeps the original CUDA
    # autocast setting.
    amp_device_type = device.type if device.type in ("cuda", "cpu") else "cuda"

    print("---")
    print("factual_english_samples: (probe mode)")
    model.eval()
    hits = 0
    with torch.no_grad(), torch.amp.autocast(device_type=amp_device_type, dtype=torch.bfloat16):
        for prompt, answers in cases:
            x = torch.tensor([tokenizer.encode(prompt)], device=device, dtype=torch.long)
            logits = model(x, targets=None)
            # logits shape: [1, seq_len, vocab] or [1, vocab]
            last_logits = logits[0, -1, :] if logits.dim() == 3 else logits[0]
            probs = torch.softmax(last_logits.float(), dim=-1)
            # Generous top-K (K=20) to roughly match multi-sample gen mode.
            top_k = min(20, probs.shape[-1])
            top_ids = torch.topk(probs, top_k).indices.tolist()
            top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in top_ids]
            answers_lower = [a.lower() for a in answers]
            # NOTE(review): substring matching is loose, e.g. answer "is" also
            # matches a "this" token. Kept as-is so scores stay comparable.
            any_hit = any(a in tok for tok in top_tokens for a in answers_lower)
            if any_hit:
                hits += 1
            best_completion = tokenizer.decode([top_ids[0]])
            print(f" prompt: {prompt!r}")
            print(f" output: {(prompt + best_completion).replace(chr(10), ' ')!r}")
            print(f" hit: {any_hit} (probe top-{top_k})")
    total = len(cases)
    score = hits / total if total else 0.0
    print("---")
    print(f"factual_english_score: {score:.4f}")
    print(f"factual_english_hits: {hits}/{total}")
    return score, hits, total
# ---------------------------------------------------------------------------
# Gen mode: original autoregressive path (Fix F: batch decode)
# ---------------------------------------------------------------------------
def _run_factual_english_gen(
    model,
    tokenizer,
    max_seq_len: int,
    *,
    eval_set=None,
    num_samples=None,
    batch=None,
    gen_tokens=None,
):
    """Score factual prompts by sampling continuations autoregressively.

    For each ``(prompt, answers)`` pair, ``num_samples`` continuations of up
    to ``gen_tokens`` tokens are sampled in batches of ``batch``. Successive
    batches cycle through temperatures 0.7 / 0.9 / 1.1. A batch stops
    generating early once its sequences reach ``max_seq_len`` tokens. All
    sampling runs first; decoding and matching happen afterwards on the CPU.

    A prompt is a hit if any sample's continuation contains an answer
    (case-insensitive):

    * single-word answers (``"paris"``, ``"4"``) must equal a whole word of
      the continuation;
    * answers with spaces or punctuation (``"light speed"``, ``"186,000"``,
      ``"'s"``) are matched as a phrase that is not glued to surrounding word
      characters. A set of extracted words can never contain them.

    Args:
        model: callable taking a [B, T] LongTensor and ``targets=None`` and
            returning logits shaped [B, T, vocab] or [B, vocab]. Inputs are
            created on the device of the model's first parameter (CUDA if the
            model exposes none). ``model.eval()`` is called and not undone.
        tokenizer: object with ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        max_seq_len: sequence length at which generation for a batch stops.
        eval_set: optional ``(prompt, answers)`` pairs; defaults to
            ``FACTUAL_EVAL``.
        num_samples, batch, gen_tokens: optional overrides for
            ``FACTUAL_SAMPLES``, ``FACTUAL_BATCH`` and ``FACTUAL_GEN_TOKENS``.

    Returns:
        ``(score, hits, total)`` with ``score = hits / total`` (0.0 when the
        eval set is empty).
    """
    cases = FACTUAL_EVAL if eval_set is None else eval_set
    num_samples = FACTUAL_SAMPLES if num_samples is None else num_samples
    # A batch size below 1 would never advance samples_done (infinite loop).
    batch = max(1, FACTUAL_BATCH if batch is None else batch)
    gen_tokens = FACTUAL_GEN_TOKENS if gen_tokens is None else gen_tokens
    temps = [0.7, 0.9, 1.1]
    word_re = _re.compile(r"\b[\w'-]+\b")

    def _sample_hits(continuation, answers_lower):
        """Return True if *continuation* contains any lower-cased answer."""
        words = {w.lower() for w in word_re.findall(continuation)}
        lowered = continuation.lower()
        for ans in answers_lower:
            if word_re.fullmatch(ans):
                # Plain single-word answer: whole-word membership.
                if ans in words:
                    return True
            elif _re.search(rf"(?<!\w){_re.escape(ans)}(?!\w)", lowered):
                # Multi-word / punctuated answer: phrase match.
                return True
        return False

    # Follow the model's device instead of assuming CUDA.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")
    # Only use CPU/CUDA autocast; any other device keeps the original CUDA
    # autocast setting.
    amp_device_type = device.type if device.type in ("cuda", "cpu") else "cuda"

    print("---")
    print("factual_english_samples: (gen mode)")
    model.eval()
    hits = 0
    with torch.no_grad(), torch.amp.autocast(device_type=amp_device_type, dtype=torch.bfloat16):
        for prompt, answers in cases:
            ids = tokenizer.encode(prompt)
            prompt_len = len(ids)
            answers_lower = [a.lower() for a in answers]
            # Collect all generated token sequences first (no per-row sync).
            all_rows: list[list[int]] = []
            samples_done = 0
            batch_idx = 0
            while samples_done < num_samples:
                b = min(batch, num_samples - samples_done)
                temp = temps[batch_idx % len(temps)]
                batch_idx += 1
                ctx = torch.tensor([ids] * b, device=device, dtype=torch.long)
                for _ in range(gen_tokens):
                    logits = model(ctx, targets=None)
                    next_logits = logits[:, -1, :] if logits.dim() == 3 else logits
                    probs = torch.softmax(next_logits.float() / temp, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)
                    ctx = torch.cat([ctx, next_id], dim=1)
                    if ctx.size(1) >= max_seq_len:
                        break
                # One device->host transfer per batch.
                all_rows.extend(ctx.cpu().tolist())
                samples_done += b
            # CPU-side matching. The continuation is decoded from the generated
            # ids alone rather than by slicing len(prompt) characters off the
            # full decode, which assumed decode(encode(prompt)) == prompt.
            any_hit = False
            first_gen = None
            hit_gen = None
            for row in all_rows:
                continuation = tokenizer.decode(row[prompt_len:]).strip()
                hit = _sample_hits(continuation, answers_lower)
                if first_gen is None:
                    first_gen = tokenizer.decode(row)
                if hit:
                    any_hit = True
                    hit_gen = tokenizer.decode(row)
                    # Later samples cannot change the result or the printout.
                    break
            if any_hit:
                hits += 1
            print(f" prompt: {prompt!r}")
            print(f" output: {(first_gen or '').replace(chr(10), ' ')!r}")
            print(f" hit: {any_hit} (any of {num_samples} samples, temps={temps}, gen={gen_tokens}tok)")
            if hit_gen is not None and hit_gen != first_gen:
                print(f" hit_sample: {hit_gen.replace(chr(10), ' ')!r}")
    total = len(cases)
    score = hits / total if total else 0.0
    print("---")
    print(f"factual_english_score: {score:.4f}")
    print(f"factual_english_hits: {hits}/{total}")
    return score, hits, total
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
def run_factual_english(model, tokenizer, max_seq_len: int):
    """Score the factual-English prompts using the mode chosen by FACTUAL_MODE.

    FACTUAL_MODE == "gen" (HYDRA_FACTUAL_MODE=gen) selects the autoregressive
    sampling scorer; every other value (the default is "probe") selects the
    fast single-forward probe scorer.

    Returns:
        Whatever the selected scorer returns: ``(score, hits, total)``.
    """
    scorer = (
        _run_factual_english_gen
        if FACTUAL_MODE == "gen"
        else _run_factual_english_probe
    )
    return scorer(model, tokenizer, max_seq_len)