#!/usr/bin/env python3 """Decode-sweep lab for FlowText OpenWebText checkpoints. The goal is to debug inference without touching training. We try several simplex-valid update rules, generate many candidates, and rank them with anti-collapse diagnostics instead of pure self-likelihood. Run from the flowtext_standard_bench repository root. """ from __future__ import annotations import argparse import json import math import re import sys from collections import Counter from dataclasses import dataclass, asdict from pathlib import Path from typing import Iterable, List, Sequence import torch import torch.nn.functional as F REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from eval import build_model_from_ckpt from flowtext_lab.bridges import smooth_onehot from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model from flowtext_lab.tokenization import BpeTextTokenizer WORD_RE = re.compile(r"[A-Za-z]+|\d+|[^\sA-Za-z\d]") @dataclass class DecodeConfig: label: str rule: str steps: int = 64 model_t_mode: str = "flow" eta: float = 0.5 damping: float = 1.0 max_gamma: float = 1.0 endpoint_temp: float = 1.0 state_floor: float = 1e-8 final_from: str = "state" noise_mix: float = 0.0 noise_decay: str = "linear" eos_logit_bias: float = 0.0 def tokenize_for_metrics(text: str) -> list[str]: return WORD_RE.findall(text) def repeated_ngram_frac(tokens: Sequence[str], n: int) -> float: if len(tokens) < n: return 0.0 grams = list(zip(*[tokens[i:] for i in range(n)])) counts = Counter(grams) return sum(v - 1 for v in counts.values() if v > 1) / max(len(grams), 1) def text_metrics(text: str) -> dict: toks = tokenize_for_metrics(text) words = [t.lower() for t in toks if re.fullmatch(r"[A-Za-z]+", t)] n_tok = max(len(toks), 1) n_words = max(len(words), 1) word_counts = Counter(words) max_word_frac = word_counts.most_common(1)[0][1] / n_words if word_counts else 1.0 distinct1 = len(set(words)) / n_words if words else 0.0 bigrams = list(zip(words, words[1:])) distinct2 = len(set(bigrams)) / max(len(bigrams), 1) if bigrams else 0.0 digit_frac = sum(t.isdigit() for t in toks) / n_tok punct_frac = sum(bool(re.fullmatch(r"[,.;:!?]+", t)) for t in toks) / n_tok eos_count = text.count("<|endoftext|>") bad_char_count = text.count("�") rep3 = repeated_ngram_frac([t.lower() for t in toks], 3) rep4 = repeated_ngram_frac([t.lower() for t in toks], 4) # This score is deliberately simple and non-oracle. It rewards length and # lexical variety while heavily penalizing classic collapse artifacts. quality = ( min(len(text) / 700.0, 1.0) + 0.35 * distinct2 + 0.15 * distinct1 - 0.30 * eos_count - 2.60 * rep3 - 1.60 * rep4 - 1.30 * digit_frac - 0.65 * punct_frac - 1.35 * max_word_frac - 0.35 * bad_char_count ) return { "quality": float(quality), "chars": len(text), "tokens": len(toks), "words": len(words), "eos_count": eos_count, "bad_char_count": bad_char_count, "rep3": float(rep3), "rep4": float(rep4), "distinct1": float(distinct1), "distinct2": float(distinct2), "digit_frac": float(digit_frac), "punct_frac": float(punct_frac), "max_word_frac": float(max_word_frac), } def decode_text(tokenizer: BpeTextTokenizer, ids: Sequence[int]) -> str: return tokenizer.decode(ids, stop_at_eos=False, skip_special_tokens=False) def encode_prompt(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]: return list(tokenizer.tokenizer.encode(prompt).ids)[:max_len] @torch.no_grad() def build_initial_state( tokenizer: BpeTextTokenizer, prompts: list[str], restarts: int, max_len: int, target_prob: float, eps: float, noise_init: str, noise_sigma: float, dirichlet_init_concentration: float, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str]]: expanded: list[str] = [] prompt_ids: list[list[int]] = [] for prompt in prompts: ids = encode_prompt(tokenizer, prompt, max_len=max_len) for _ in range(restarts): expanded.append(prompt) prompt_ids.append(ids) batch = len(prompt_ids) attn = torch.ones((batch, max_len), dtype=torch.bool, device=device) probs = sample_noise_simplex( (batch, max_len), tokenizer.vocab_size, device, eps, noise_mode=noise_init, target_prob=target_prob, noise_sigma=noise_sigma, dirichlet_concentration=dirichlet_init_concentration, ) lock = torch.zeros((batch, max_len), dtype=torch.bool, device=device) lock_probs = torch.zeros((batch, max_len, tokenizer.vocab_size), dtype=torch.float32, device=device) for row, ids in enumerate(prompt_ids): if not ids: continue ids_t = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0) sp = smooth_onehot(ids_t, tokenizer.vocab_size, target_prob, eps)[0] probs[row, : len(ids)] = sp lock_probs[row, : len(ids)] = sp lock[row, : len(ids)] = True return probs, attn, lock, lock_probs, expanded def flowmap_gamma(step: int, steps: int, damping: float, max_gamma: float, eps: float) -> float: s = step / max(steps, 1) t_next = (step + 1) / max(steps, 1) base_gamma = (t_next - s) / max(1.0 - s, eps) gamma = float(damping) * base_gamma return min(gamma, float(max_gamma)) if max_gamma > 0 else gamma @torch.no_grad() def decode_batch( model, init_probs: torch.Tensor, attn: torch.Tensor, lock: torch.Tensor, lock_probs: torch.Tensor, cfg: DecodeConfig, eps: float, eos_id: int | None = None, ) -> torch.Tensor: probs = init_probs.float().clone() device = probs.device last_endpoint = probs for step in range(cfg.steps): t = model_time_for_step(cfg.model_t_mode, step, cfg.steps, probs.size(0), device, dtype=torch.float32) logits = model(state_for_model(model, probs, eps), t, attn).float() if cfg.endpoint_temp != 1.0: logits = logits / float(cfg.endpoint_temp) if cfg.eos_logit_bias != 0.0 and eos_id is not None and 0 <= eos_id < logits.size(-1): logits[..., eos_id] = logits[..., eos_id] + float(cfg.eos_logit_bias) endpoint = F.softmax(logits, dim=-1) last_endpoint = endpoint if cfg.rule == "flowmap": gamma = flowmap_gamma(step, cfg.steps, cfg.damping, cfg.max_gamma, eps) new_probs = probs + gamma * (endpoint - probs) elif cfg.rule == "replace": new_probs = (1.0 - cfg.eta) * probs + cfg.eta * endpoint elif cfg.rule == "geometric": log_mix = (1.0 - cfg.eta) * torch.log(probs.clamp_min(eps)) + cfg.eta * torch.log(endpoint.clamp_min(eps)) new_probs = F.softmax(log_mix, dim=-1) elif cfg.rule == "centered_residual": # Add a zero-sum probability residual, then project back to simplex. residual = endpoint - probs residual = residual - residual.mean(dim=-1, keepdim=True) new_probs = probs + cfg.eta * residual else: raise ValueError(f"Unknown decode rule: {cfg.rule}") if cfg.noise_mix > 0: if cfg.noise_decay == "linear": lam = cfg.noise_mix * (1.0 - (step + 1) / max(cfg.steps, 1)) elif cfg.noise_decay == "sqrt": lam = cfg.noise_mix * math.sqrt(max(0.0, 1.0 - (step + 1) / max(cfg.steps, 1))) else: lam = cfg.noise_mix if lam > 0: uniform = torch.full_like(new_probs, 1.0 / new_probs.size(-1)) new_probs = (1.0 - lam) * new_probs + lam * uniform new_probs = new_probs.clamp_min(max(float(cfg.state_floor), eps)) new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True).clamp_min(eps) new_probs = torch.where(lock.unsqueeze(-1), lock_probs, new_probs) probs = new_probs if cfg.final_from == "endpoint": out = last_endpoint out = torch.where(lock.unsqueeze(-1), lock_probs, out) return out / out.sum(dim=-1, keepdim=True).clamp_min(eps) if cfg.final_from == "blend": out = 0.5 * probs + 0.5 * last_endpoint out = torch.where(lock.unsqueeze(-1), lock_probs, out) return out / out.sum(dim=-1, keepdim=True).clamp_min(eps) return probs @torch.no_grad() def pseudo_likelihood_scores( model, tokenizer: BpeTextTokenizer, probs: torch.Tensor, attn: torch.Tensor, lock: torch.Tensor, target_prob: float, eps: float, repeats: int, mask_frac: float, rerank_t: float, ) -> torch.Tensor: ids = probs.argmax(dim=-1) endpoint = smooth_onehot(ids, tokenizer.vocab_size, target_prob, eps) eligible = attn & (~lock) scores = torch.zeros(ids.size(0), dtype=torch.float32, device=ids.device) counts = torch.zeros_like(scores) for _ in range(max(1, repeats)): score_mask = (torch.rand_like(ids.float()) < mask_frac) & eligible for row in range(ids.size(0)): if eligible[row].any() and not score_mask[row].any(): choices = torch.nonzero(eligible[row], as_tuple=False).flatten() score_mask[row, choices[torch.randint(0, choices.numel(), (1,), device=ids.device)]] = True noise = sample_noise_simplex( (ids.size(0), ids.size(1)), tokenizer.vocab_size, ids.device, eps, noise_mode="logistic_normal", target_prob=target_prob, noise_sigma=-1.0, ) inp = torch.where(score_mask.unsqueeze(-1), noise, endpoint) inp = torch.where(lock.unsqueeze(-1), probs, inp) t = torch.full((ids.size(0),), float(rerank_t), dtype=torch.float32, device=ids.device) logits = model(state_for_model(model, inp, eps), t, attn).float() logp = F.log_softmax(logits, dim=-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1) scores += (logp * score_mask.float()).sum(dim=-1) counts += score_mask.float().sum(dim=-1) return scores / counts.clamp_min(1.0) def default_configs(steps: int, config_set: str) -> list[DecodeConfig]: if config_set == "focused_flowmap": return [ DecodeConfig("flowmap_t1p00_d1p0", "flowmap", steps=steps, damping=1.0, max_gamma=1.0), DecodeConfig("flowmap_t1p10_d1p0", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.10), DecodeConfig("flowmap_t1p25_d1p0", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.25), DecodeConfig("flowmap_t1p40_d1p0", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.40), DecodeConfig("flowmap_t1p60_d1p0", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.60), DecodeConfig("flowmap_t1p25_d0p7", "flowmap", steps=steps, damping=0.7, max_gamma=1.0, endpoint_temp=1.25), DecodeConfig("flowmap_t1p40_d0p7", "flowmap", steps=steps, damping=0.7, max_gamma=1.0, endpoint_temp=1.40), DecodeConfig("flowmap_t1p60_d0p7", "flowmap", steps=steps, damping=0.7, max_gamma=1.0, endpoint_temp=1.60), DecodeConfig("flowmap_t1p25_g0p5", "flowmap", steps=steps, damping=1.0, max_gamma=0.5, endpoint_temp=1.25), DecodeConfig("flowmap_t1p40_g0p5", "flowmap", steps=steps, damping=1.0, max_gamma=0.5, endpoint_temp=1.40), ] if config_set == "best_flowmap": return [ DecodeConfig("flowmap_t1p25_d0p7", "flowmap", steps=steps, damping=0.7, max_gamma=1.0, endpoint_temp=1.25), DecodeConfig("flowmap_t1p25_d1p0", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.25), DecodeConfig("flowmap_t1p35_d1p0", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.35), DecodeConfig("flowmap_t1p40_d1p0", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.40), ] if config_set == "final_projection": return [ DecodeConfig("flowmap_t1p35_state", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.35, final_from="state"), DecodeConfig("flowmap_t1p35_endpoint", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.35, final_from="endpoint"), DecodeConfig("flowmap_t1p35_blend", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.35, final_from="blend"), DecodeConfig("flowmap_t1p40_state", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.40, final_from="state"), DecodeConfig("flowmap_t1p40_endpoint", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.40, final_from="endpoint"), DecodeConfig("flowmap_t1p40_blend", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.40, final_from="blend"), DecodeConfig("flowmap_t1p25_d0p7_state", "flowmap", steps=steps, damping=0.7, max_gamma=1.0, endpoint_temp=1.25, final_from="state"), DecodeConfig("flowmap_t1p25_d0p7_endpoint", "flowmap", steps=steps, damping=0.7, max_gamma=1.0, endpoint_temp=1.25, final_from="endpoint"), DecodeConfig("flowmap_t1p25_d0p7_blend", "flowmap", steps=steps, damping=0.7, max_gamma=1.0, endpoint_temp=1.25, final_from="blend"), ] if config_set == "eos_sweep": return [ DecodeConfig("flowmap_t1p35_eos0", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.35, eos_logit_bias=0.0), DecodeConfig("flowmap_t1p35_eos-1", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.35, eos_logit_bias=-1.0), DecodeConfig("flowmap_t1p35_eos-2", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.35, eos_logit_bias=-2.0), DecodeConfig("flowmap_t1p35_eos-3", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.35, eos_logit_bias=-3.0), DecodeConfig("flowmap_t1p40_eos-2", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.40, eos_logit_bias=-2.0), DecodeConfig("flowmap_t1p25_d0p7_eos-2", "flowmap", steps=steps, damping=0.7, max_gamma=1.0, endpoint_temp=1.25, eos_logit_bias=-2.0), ] if config_set != "broad": raise ValueError(f"Unknown config_set: {config_set}") return [ DecodeConfig("flowmap64", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, final_from="state"), DecodeConfig("flowmap_temp1p25", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=1.25), DecodeConfig("flowmap_temp0p85", "flowmap", steps=steps, damping=1.0, max_gamma=1.0, endpoint_temp=0.85), DecodeConfig("replace_eta0p35", "replace", steps=steps, eta=0.35), DecodeConfig("replace_eta0p50", "replace", steps=steps, eta=0.50), DecodeConfig("replace_eta0p65", "replace", steps=steps, eta=0.65), DecodeConfig("replace_eta0p50_temp1p25", "replace", steps=steps, eta=0.50, endpoint_temp=1.25), DecodeConfig("geometric_eta0p25", "geometric", steps=steps, eta=0.25), DecodeConfig("geometric_eta0p50", "geometric", steps=steps, eta=0.50), DecodeConfig("centered_residual_eta0p20", "centered_residual", steps=steps, eta=0.20), DecodeConfig("replace_eta0p50_floor1e6", "replace", steps=steps, eta=0.50, state_floor=1e-6), DecodeConfig("replace_eta0p50_leak", "replace", steps=steps, eta=0.50, noise_mix=0.03, noise_decay="sqrt"), ] def aggregate(rows: list[dict]) -> dict: keys = ["quality", "eos_count", "rep3", "rep4", "distinct1", "distinct2", "digit_frac", "max_word_frac"] return {f"mean_{k}": sum(float(r[k]) for r in rows) / max(len(rows), 1) for k in keys} def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", required=True) parser.add_argument("--tokenizer_path", required=True) parser.add_argument("--max_len", type=int, default=128) parser.add_argument("--steps", type=int, default=64) parser.add_argument("--restarts", type=int, default=64) parser.add_argument("--target_prob", type=float, default=0.99) parser.add_argument("--eps", type=float, default=1e-8) parser.add_argument("--model_t_mode", choices=["linear", "flow", "const0", "const05", "const1", "random"], default="flow") parser.add_argument("--noise_init", choices=["uniform", "logistic_normal", "dirichlet"], default="dirichlet") parser.add_argument("--noise_sigma", type=float, default=-1.0) parser.add_argument("--dirichlet_init_concentration", type=float, default=1.0) parser.add_argument("--prompts", default="|The|In the early morning|Scientists have|The company said|A young woman") parser.add_argument("--score_repeats", type=int, default=0) parser.add_argument("--score_mask_frac", type=float, default=0.5) parser.add_argument("--rerank_t", type=float, default=0.5) parser.add_argument("--pl_weight", type=float, default=0.0) parser.add_argument("--output", default="runs/decode_lab/latest_decode_lab.jsonl") parser.add_argument("--config_set", default="broad", choices=["broad", "focused_flowmap", "best_flowmap", "final_projection", "eos_sweep"]) parser.add_argument("--decode_batch_size", type=int, default=0) parser.add_argument("--topk", type=int, default=5) parser.add_argument("--seed", type=int, default=20260428) args = parser.parse_args() torch.manual_seed(args.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path) ckpt = torch.load(args.checkpoint, map_location="cpu") model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device) model.eval() prompts = args.prompts.split("|") # Keep the first empty prompt: it is unconditional generation. print(f"[info] device={device} prompts={prompts} restarts={args.restarts} steps={args.steps}") print(f"[info] checkpoint={args.checkpoint}") out_path = Path(args.output) out_path.parent.mkdir(parents=True, exist_ok=True) configs = default_configs(args.steps, args.config_set) for cfg in configs: cfg.model_t_mode = args.model_t_mode with out_path.open("w") as f: for cfg in configs: init, attn, lock, lock_probs, expanded = build_initial_state( tokenizer=tokenizer, prompts=prompts, restarts=args.restarts, max_len=args.max_len, target_prob=args.target_prob, eps=args.eps, noise_init=args.noise_init, noise_sigma=args.noise_sigma, dirichlet_init_concentration=args.dirichlet_init_concentration, device=device, ) if args.decode_batch_size > 0 and init.size(0) > args.decode_batch_size: decoded_parts = [] for start in range(0, init.size(0), args.decode_batch_size): end = min(start + args.decode_batch_size, init.size(0)) part = decode_batch( model, init[start:end], attn[start:end], lock[start:end], lock_probs[start:end], cfg, args.eps, tokenizer.eos_id, ) decoded_parts.append(part.detach().cpu()) print(f"[chunk] {cfg.label} decoded {end}/{init.size(0)}", flush=True) decoded = torch.cat(decoded_parts, dim=0) else: decoded = decode_batch(model, init, attn, lock, lock_probs, cfg, args.eps, tokenizer.eos_id) ids = decoded.argmax(dim=-1).detach().cpu().tolist() texts = [decode_text(tokenizer, row) for row in ids] rows = [] for i, text in enumerate(texts): m = text_metrics(text) m.update({"candidate": i, "prompt": expanded[i], "text": text}) rows.append(m) if args.score_repeats > 0: decoded_for_score = decoded.to(device) if decoded.device != device else decoded pl = pseudo_likelihood_scores( model, tokenizer, decoded_for_score, attn, lock, args.target_prob, args.eps, repeats=args.score_repeats, mask_frac=args.score_mask_frac, rerank_t=args.rerank_t, ).detach().cpu().tolist() for row, score in zip(rows, pl): row["pseudo_logp"] = float(score) row["rank_score"] = float(row["quality"] + args.pl_weight * score) else: for row in rows: row["pseudo_logp"] = None row["rank_score"] = float(row["quality"]) summary = {"type": "summary", "config": asdict(cfg), "agg": aggregate(rows)} f.write(json.dumps(summary, ensure_ascii=False) + "\n") print("\n" + "=" * 96) print("[config]", cfg.label, asdict(cfg)) print("[metrics]", json.dumps(summary["agg"], ensure_ascii=False)) for prompt in prompts: subset = [r for r in rows if r["prompt"] == prompt] subset.sort(key=lambda r: r["rank_score"], reverse=True) for rank, row in enumerate(subset[: args.topk], 1): rec = {"type": "sample", "config": asdict(cfg), "rank": rank, **row} f.write(json.dumps(rec, ensure_ascii=False) + "\n") if rank <= 1: print(f"\n--- best prompt={prompt!r} rank_score={row['rank_score']:.4f} quality={row['quality']:.4f} ---") print(row["text"]) del init, attn, lock, lock_probs, decoded if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"[done] wrote {out_path}") if __name__ == "__main__": main()