#!/usr/bin/env python3 """Algebraic simplex-linear GenPPL eval for endpoint models. This decoder matches the supervised bridge: p_t = (1 - t) * p0 + t * x1 Inference keeps the sampled p0 fixed and replaces the unknown x1 with the model's current endpoint prediction: p_{t_next} = (1 - t_next) * p0 + t_next * a_theta(p_t, t). There is no Dirichlet/Gamma resampling in the loop. """ from __future__ import annotations import argparse import json import math import sys from pathlib import Path import torch import torch.nn.functional as F from transformers import AutoModelForCausalLM, AutoTokenizer REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from flowtext_lab.decode import sample_noise_simplex, state_for_model from flowtext_lab.genppl import filter_generated_texts, summarize_token_diversity from flowtext_lab.tokenization import BpeTextTokenizer from eval_lm1b_c1024_fullycoupled_sde_genppl import ( build_model, collect_special_token_ids, filter_endpoint_probs, score_with_gpt2, ) def lerp(a: float, b: float, t: float) -> float: return float(a) + float(t) * (float(b) - float(a)) def project_endpoint( logits: torch.Tensor, *, temp: float, projection: str, top_k: int, top_p: float, banned_ids: list[int], gumbel_tau: float, gumbel_noise_scale: float, eps: float, ) -> torch.Tensor: endpoint = F.softmax(logits / max(float(temp), eps), dim=-1) endpoint = filter_endpoint_probs( endpoint, top_k=top_k, top_p=top_p, banned_ids=banned_ids, eps=eps, ) if projection == "soft": return endpoint if projection == "argmax": ids = endpoint.argmax(dim=-1) return torch.zeros_like(endpoint).scatter_(-1, ids.unsqueeze(-1), 1.0) if projection == "sample": ids = torch.multinomial(endpoint.reshape(-1, endpoint.size(-1)), 1).view(*endpoint.shape[:-1]) return torch.zeros_like(endpoint).scatter_(-1, ids.unsqueeze(-1), 1.0) if projection == "gumbel_softmax": u = torch.rand_like(endpoint).clamp_(min=eps, max=1.0 - eps) g = -torch.log(-torch.log(u)) z = (endpoint.clamp_min(eps).log() + float(gumbel_noise_scale) * g) / max(float(gumbel_tau), eps) y = F.softmax(z, dim=-1).clamp_min(eps) return y / y.sum(dim=-1, keepdim=True).clamp_min(eps) raise ValueError(f"unknown endpoint_projection: {projection}") @torch.inference_mode() def decode_linear_simplex( model, tokenizer: BpeTextTokenizer, *, n_samples: int, batch_size: int, max_len: int, steps: int, seed: int, device: torch.device, noise_init: str, noise_sigma: float, noise_dirichlet_concentration: float, endpoint_temp_start: float, endpoint_temp_end: float, endpoint_projection: str, endpoint_top_k: int, endpoint_top_p: float, ban_special_tokens: bool, gumbel_tau_start: float, gumbel_tau_end: float, gumbel_noise_scale_start: float, gumbel_noise_scale_end: float, final_from: str, ) -> tuple[list[list[int]], list[str], dict]: torch.manual_seed(seed) eps = 1e-8 all_ids: list[list[int]] = [] all_texts: list[str] = [] remaining = n_samples banned_endpoint_ids = collect_special_token_ids(tokenizer) if ban_special_tokens else [] while remaining > 0: bs = min(batch_size, remaining) p0 = sample_noise_simplex( (bs, max_len), tokenizer.vocab_size, device, eps, noise_mode=noise_init, target_prob=1.0, noise_sigma=noise_sigma, dirichlet_concentration=noise_dirichlet_concentration, ) probs = p0.clone() attn = torch.ones((bs, max_len), dtype=torch.bool, device=device) last_endpoint = probs for step in range(steps): cur_t = step / max(steps, 1) next_t = (step + 1) / max(steps, 1) t = torch.full((bs,), float(cur_t), dtype=torch.float32, device=device) logits = model(state_for_model(model, probs, eps), t, attn).float() endpoint = project_endpoint( logits, temp=lerp(endpoint_temp_start, endpoint_temp_end, cur_t), projection=endpoint_projection, top_k=endpoint_top_k, top_p=endpoint_top_p, banned_ids=banned_endpoint_ids, gumbel_tau=lerp(gumbel_tau_start, gumbel_tau_end, cur_t), gumbel_noise_scale=lerp(gumbel_noise_scale_start, gumbel_noise_scale_end, cur_t), eps=eps, ) last_endpoint = endpoint probs = (1.0 - next_t) * p0 + next_t * endpoint probs = probs.clamp_min(eps) probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(eps) if final_from == "blend_0.5": final_probs = 0.5 * probs + 0.5 * last_endpoint ids = final_probs.argmax(dim=-1).detach().cpu().tolist() elif final_from == "model_t1": t = torch.ones((bs,), dtype=torch.float32, device=device) final_logits = model(state_for_model(model, probs, eps), t, attn).float() ids = final_logits.argmax(dim=-1).detach().cpu().tolist() else: raise ValueError(f"unknown final_from: {final_from}") all_ids.extend(ids) all_texts.extend(tokenizer.decode(row, stop_at_eos=False, skip_special_tokens=False) for row in ids) remaining -= bs print(f"[linear] generated {n_samples - remaining}/{n_samples}", flush=True) cfg = { "decode_rule": "linear_simplex_algebraic", "steps": steps, "noise_init": noise_init, "noise_sigma": noise_sigma, "noise_dirichlet_concentration": noise_dirichlet_concentration, "endpoint_temp_start": endpoint_temp_start, "endpoint_temp_end": endpoint_temp_end, "endpoint_projection": endpoint_projection, "endpoint_top_k": endpoint_top_k, "endpoint_top_p": endpoint_top_p, "ban_special_tokens": ban_special_tokens, "banned_endpoint_ids": banned_endpoint_ids, "gumbel_tau_start": gumbel_tau_start, "gumbel_tau_end": gumbel_tau_end, "gumbel_noise_scale_start": gumbel_noise_scale_start, "gumbel_noise_scale_end": gumbel_noise_scale_end, "final_from": final_from, "n_samples": n_samples, "seed": seed, } return all_ids, all_texts, cfg def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="Linear-simplex algebraic GenPPL eval") p.add_argument("--checkpoint", required=True) p.add_argument("--tokenizer_path", required=True) p.add_argument("--scorer", required=True) p.add_argument("--out_dir", required=True) p.add_argument("--n_samples", type=int, default=128) p.add_argument("--max_len", type=int, default=128) p.add_argument("--steps", type=int, default=128) p.add_argument("--batch_size", type=int, default=16) p.add_argument("--score_batch", type=int, default=8) p.add_argument("--score_max_length", type=int, default=1024) p.add_argument("--noise_init", choices=["uniform", "logistic_normal", "dirichlet"], default="logistic_normal") p.add_argument("--noise_sigma", type=float, default=3.0) p.add_argument("--noise_dirichlet_concentration", type=float, default=1.0) p.add_argument("--endpoint_temp_start", type=float, default=1.45) p.add_argument("--endpoint_temp_end", type=float, default=0.8) p.add_argument("--endpoint_projection", choices=["soft", "sample", "argmax", "gumbel_softmax"], default="soft") p.add_argument("--endpoint_top_k", type=int, default=0) p.add_argument("--endpoint_top_p", type=float, default=1.0) p.add_argument("--ban_special_tokens", action="store_true") p.add_argument("--gumbel_tau_start", type=float, default=1.0) p.add_argument("--gumbel_tau_end", type=float, default=0.2) p.add_argument("--gumbel_noise_scale_start", type=float, default=1.0) p.add_argument("--gumbel_noise_scale_end", type=float, default=0.0) p.add_argument("--final_from", choices=["blend_0.5", "model_t1"], default="model_t1") p.add_argument("--seed", type=int, default=20260524) return p.parse_args() @torch.no_grad() def main() -> None: args = parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"[load] {args.checkpoint}", flush=True) ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False) step = ckpt.get("step") print(f"[ckpt] step={step}", flush=True) tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path) model = build_model(ckpt, tokenizer, device) ids, texts, decode_cfg = decode_linear_simplex( model, tokenizer, n_samples=args.n_samples, batch_size=args.batch_size, max_len=args.max_len, steps=args.steps, seed=args.seed, device=device, noise_init=args.noise_init, noise_sigma=args.noise_sigma, noise_dirichlet_concentration=args.noise_dirichlet_concentration, endpoint_temp_start=args.endpoint_temp_start, endpoint_temp_end=args.endpoint_temp_end, endpoint_projection=args.endpoint_projection, endpoint_top_k=args.endpoint_top_k, endpoint_top_p=args.endpoint_top_p, ban_special_tokens=args.ban_special_tokens, gumbel_tau_start=args.gumbel_tau_start, gumbel_tau_end=args.gumbel_tau_end, gumbel_noise_scale_start=args.gumbel_noise_scale_start, gumbel_noise_scale_end=args.gumbel_noise_scale_end, final_from=args.final_from, ) del model if torch.cuda.is_available(): torch.cuda.empty_cache() def strip_special(t: str) -> str: import re t = t.replace("[CLS]", " ").replace("[SEP]", " ").replace("[PAD]", " ") t = t.replace("<|endoftext|>", " ") return re.sub(r"\s+", " ", t).strip() stripped = [strip_special(t) for t in texts] kept_raw, _ = filter_generated_texts(texts, min_chars=1, normalize_whitespace=False, drop_empty=True) kept_stripped, _ = filter_generated_texts(stripped, min_chars=1, normalize_whitespace=True, drop_empty=True) diversity = summarize_token_diversity(ids).__dict__ print(f"[score] loading scorer: {args.scorer}", flush=True) scorer_tok = AutoTokenizer.from_pretrained(args.scorer) if scorer_tok.pad_token_id is None: scorer_tok.pad_token = scorer_tok.eos_token scorer_tok.pad_token_id = scorer_tok.eos_token_id scorer = AutoModelForCausalLM.from_pretrained(args.scorer).to(device).eval() if getattr(scorer.config, "pad_token_id", None) is None: scorer.config.pad_token_id = scorer_tok.pad_token_id raw_ppl = score_with_gpt2( kept_raw, scorer, scorer_tok, batch_size=args.score_batch, max_length=args.score_max_length, device=device, ) stripped_ppl = score_with_gpt2( kept_stripped, scorer, scorer_tok, batch_size=args.score_batch, max_length=args.score_max_length, device=device, ) summary = { "type": "summary", "checkpoint": args.checkpoint, "step": step, "decode": decode_cfg, "raw_genppl": raw_ppl, "stripped_genppl": stripped_ppl, "diversity": diversity, } out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) out_jsonl = out_dir / f"linear_steps{args.steps}_samples{args.n_samples}_scored.jsonl" with out_jsonl.open("w", encoding="utf-8") as f: f.write(json.dumps(summary, ensure_ascii=False) + "\n") for i, (raw, clean) in enumerate(zip(texts, stripped)): f.write(json.dumps({"type": "sample", "index": i, "raw_text": raw, "stripped_text": clean}, ensure_ascii=False) + "\n") print("[summary]", json.dumps(summary, ensure_ascii=False, indent=2), flush=True) print(f"[done] {out_jsonl}", flush=True) if __name__ == "__main__": main()