| |
| """Algebraic simplex-linear GenPPL eval for endpoint models. |
| |
| This decoder matches the supervised bridge: |
| |
| p_t = (1 - t) * p0 + t * x1 |
| |
| Inference keeps the sampled p0 fixed and replaces the unknown x1 with the |
| model's current endpoint prediction: |
| |
| p_{t_next} = (1 - t_next) * p0 + t_next * a_theta(p_t, t). |
| |
| There is no Dirichlet/Gamma resampling in the loop. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Parent of the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parents[1]

# Put the repository root at the front of the import path (once) so that the
# project-local imports below (e.g. ``flowtext_lab``) resolve when this file
# is executed directly as a script.
if str(REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(REPO_ROOT)]
|
|
| from flowtext_lab.decode import sample_noise_simplex, state_for_model |
| from flowtext_lab.genppl import filter_generated_texts, summarize_token_diversity |
| from flowtext_lab.tokenization import BpeTextTokenizer |
|
|
| from eval_lm1b_c1024_fullycoupled_sde_genppl import ( |
| build_model, |
| collect_special_token_ids, |
| filter_endpoint_probs, |
| score_with_gpt2, |
| ) |
|
|
|
|
def lerp(a: float, b: float, t: float) -> float:
    """Linearly interpolate from ``a`` (at ``t == 0``) to ``b`` (at ``t == 1``).

    Every argument is coerced with ``float()``, so ints are accepted and the
    result is always a Python float. ``t`` is not clamped: values outside
    ``[0, 1]`` extrapolate.
    """
    start = float(a)
    span = float(b) - start
    return start + float(t) * span
|
|
|
|
def project_endpoint(
    logits: torch.Tensor,
    *,
    temp: float,
    projection: str,
    top_k: int,
    top_p: float,
    banned_ids: list[int],
    gumbel_tau: float,
    gumbel_noise_scale: float,
    eps: float,
) -> torch.Tensor:
    """Project endpoint logits onto the simplex for the next bridge update.

    The logits are divided by ``temp`` (floored at ``eps``), softmaxed over the
    last dimension, and passed through ``filter_endpoint_probs`` with the given
    ``top_k`` / ``top_p`` / ``banned_ids``. ``projection`` then selects:

    * ``"soft"``: the filtered distribution, unchanged.
    * ``"argmax"``: a one-hot vector at its most likely entry.
    * ``"sample"``: a one-hot vector at an entry drawn with ``torch.multinomial``.
    * ``"gumbel_softmax"``: ``softmax((log p + gumbel_noise_scale * g) / gumbel_tau)``
      with ``g`` standard Gumbel noise (``gumbel_tau`` floored at ``eps``),
      then floored at ``eps`` and renormalized.

    Raises:
        ValueError: if ``projection`` is not one of the modes above. This is
            checked before any tensor work is done.
    """
    if projection not in ("soft", "argmax", "sample", "gumbel_softmax"):
        raise ValueError(f"unknown endpoint_projection: {projection}")

    endpoint = F.softmax(logits / max(float(temp), eps), dim=-1)
    endpoint = filter_endpoint_probs(
        endpoint,
        top_k=top_k,
        top_p=top_p,
        banned_ids=banned_ids,
        eps=eps,
    )
    if projection == "soft":
        return endpoint

    if projection in ("argmax", "sample"):
        if projection == "argmax":
            ids = endpoint.argmax(dim=-1)
        else:
            # multinomial needs a 2-D (rows, vocab) input; flatten leading dims and restore them.
            flat = endpoint.reshape(-1, endpoint.size(-1))
            ids = torch.multinomial(flat, 1).view(*endpoint.shape[:-1])
        return torch.zeros_like(endpoint).scatter_(-1, ids.unsqueeze(-1), 1.0)

    # projection == "gumbel_softmax"
    u = torch.rand_like(endpoint).clamp_(min=eps, max=1.0 - eps)
    gumbel = -torch.log(-torch.log(u))
    log_p = endpoint.clamp_min(eps).log()
    # Entries that filter_endpoint_probs zeroed out (top-k / top-p / banned ids;
    # assumed to be exactly 0 — TODO confirm against the filter) must stay
    # excluded. Flooring them at log(eps) let Gumbel noise hand them back mass,
    # so mask them to -inf instead. A row with no positive entry keeps its
    # unmasked logits so the softmax below cannot produce NaN.
    allowed = endpoint > 0
    allowed = allowed | ~allowed.any(dim=-1, keepdim=True)
    log_p = log_p.masked_fill(~allowed, float("-inf"))
    z = (log_p + float(gumbel_noise_scale) * gumbel) / max(float(gumbel_tau), eps)
    # Floor + renormalize keeps every entry strictly positive for downstream use.
    y = F.softmax(z, dim=-1).clamp_min(eps)
    return y / y.sum(dim=-1, keepdim=True).clamp_min(eps)
|
|
|
|
@torch.inference_mode()
def decode_linear_simplex(
    model,
    tokenizer: BpeTextTokenizer,
    *,
    n_samples: int,
    batch_size: int,
    max_len: int,
    steps: int,
    seed: int,
    device: torch.device,
    noise_init: str,
    noise_sigma: float,
    noise_dirichlet_concentration: float,
    endpoint_temp_start: float,
    endpoint_temp_end: float,
    endpoint_projection: str,
    endpoint_top_k: int,
    endpoint_top_p: float,
    ban_special_tokens: bool,
    gumbel_tau_start: float,
    gumbel_tau_end: float,
    gumbel_noise_scale_start: float,
    gumbel_noise_scale_end: float,
    final_from: str,
) -> tuple[list[list[int]], list[str], dict]:
    """Generate ``n_samples`` sequences with the algebraic linear-simplex sampler.

    Samples are produced in batches of at most ``batch_size``. For each batch a
    noise point ``p0`` is drawn once with ``sample_noise_simplex`` and kept
    fixed. At step ``k`` of ``steps`` (``t = k / steps``) the model predicts
    endpoint logits from the current state. ``project_endpoint`` projects them,
    with temperature, Gumbel tau and Gumbel noise scale linearly interpolated
    from their ``*_start`` to ``*_end`` values by ``t``. The state is then set
    to ``(1 - t_next) * p0 + t_next * endpoint``, floored at ``eps`` and
    renormalized.

    Final token ids are chosen by ``final_from``:
      * ``"blend_0.5"``: argmax of ``0.5 * state + 0.5 * last projected endpoint``.
      * ``"model_t1"``: argmax of the model logits on the final state at ``t = 1``.

    Returns:
        ``(ids, texts, cfg)``:
          * ``ids``: per-sample token id lists.
          * ``texts``: the ids decoded with ``stop_at_eos=False`` and
            ``skip_special_tokens=False``.
          * ``cfg``: a dict of the decode settings, including the banned ids
            actually used.

    Raises:
        ValueError: for an unknown ``final_from``, or ``batch_size < 1`` when
            ``n_samples > 0`` (both checked before any decoding). An unknown
            ``endpoint_projection`` raises from ``project_endpoint``.
    """
    # Validate cheap arguments before any model work. Previously a bad
    # ``final_from`` only surfaced after a full batch had been decoded, and
    # ``batch_size < 1`` made the batch loop below never terminate.
    if final_from not in ("blend_0.5", "model_t1"):
        raise ValueError(f"unknown final_from: {final_from}")
    if n_samples > 0 and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    torch.manual_seed(seed)
    eps = 1e-8
    all_ids: list[list[int]] = []
    all_texts: list[str] = []
    remaining = n_samples
    banned_endpoint_ids = collect_special_token_ids(tokenizer) if ban_special_tokens else []
    # Time-grid denominator; max(., 1) keeps it well-defined when steps == 0.
    n_grid = max(steps, 1)

    while remaining > 0:
        bs = min(batch_size, remaining)
        # Noise anchor p0: fixed for the whole trajectory of this batch.
        p0 = sample_noise_simplex(
            (bs, max_len),
            tokenizer.vocab_size,
            device,
            eps,
            noise_mode=noise_init,
            target_prob=1.0,
            noise_sigma=noise_sigma,
            dirichlet_concentration=noise_dirichlet_concentration,
        )
        probs = p0.clone()
        attn = torch.ones((bs, max_len), dtype=torch.bool, device=device)
        # Used by "blend_0.5" when the step loop does not run (steps <= 0).
        last_endpoint = probs

        for step in range(steps):
            cur_t = step / n_grid
            next_t = (step + 1) / n_grid
            t = torch.full((bs,), float(cur_t), dtype=torch.float32, device=device)
            logits = model(state_for_model(model, probs, eps), t, attn).float()

            endpoint = project_endpoint(
                logits,
                temp=lerp(endpoint_temp_start, endpoint_temp_end, cur_t),
                projection=endpoint_projection,
                top_k=endpoint_top_k,
                top_p=endpoint_top_p,
                banned_ids=banned_endpoint_ids,
                gumbel_tau=lerp(gumbel_tau_start, gumbel_tau_end, cur_t),
                gumbel_noise_scale=lerp(gumbel_noise_scale_start, gumbel_noise_scale_end, cur_t),
                eps=eps,
            )
            last_endpoint = endpoint
            # Algebraic bridge update: p_{t_next} = (1 - t_next) * p0 + t_next * endpoint,
            # then floor at eps and renormalize so every entry stays strictly positive.
            probs = (1.0 - next_t) * p0 + next_t * endpoint
            probs = probs.clamp_min(eps)
            probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(eps)

        if final_from == "blend_0.5":
            final_probs = 0.5 * probs + 0.5 * last_endpoint
            ids = final_probs.argmax(dim=-1).detach().cpu().tolist()
        else:  # "model_t1" (validated above)
            # NOTE(review): banned_endpoint_ids only filter the projected endpoints
            # during the trajectory; this argmax over raw logits can still select
            # banned (special) tokens even with ban_special_tokens=True.
            t = torch.ones((bs,), dtype=torch.float32, device=device)
            final_logits = model(state_for_model(model, probs, eps), t, attn).float()
            ids = final_logits.argmax(dim=-1).detach().cpu().tolist()

        all_ids.extend(ids)
        all_texts.extend(tokenizer.decode(row, stop_at_eos=False, skip_special_tokens=False) for row in ids)
        remaining -= bs
        print(f"[linear] generated {n_samples - remaining}/{n_samples}", flush=True)

    cfg = {
        "decode_rule": "linear_simplex_algebraic",
        "steps": steps,
        "noise_init": noise_init,
        "noise_sigma": noise_sigma,
        "noise_dirichlet_concentration": noise_dirichlet_concentration,
        "endpoint_temp_start": endpoint_temp_start,
        "endpoint_temp_end": endpoint_temp_end,
        "endpoint_projection": endpoint_projection,
        "endpoint_top_k": endpoint_top_k,
        "endpoint_top_p": endpoint_top_p,
        "ban_special_tokens": ban_special_tokens,
        "banned_endpoint_ids": banned_endpoint_ids,
        "gumbel_tau_start": gumbel_tau_start,
        "gumbel_tau_end": gumbel_tau_end,
        "gumbel_noise_scale_start": gumbel_noise_scale_start,
        "gumbel_noise_scale_end": gumbel_noise_scale_end,
        "final_from": final_from,
        "n_samples": n_samples,
        "seed": seed,
    }
    return all_ids, all_texts, cfg
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define this script's command-line interface and parse ``sys.argv``.

    ``--checkpoint``, ``--tokenizer_path``, ``--scorer`` and ``--out_dir`` are
    required. Every other flag has a default and configures sample counts,
    sequence length, batching, scoring, the initial noise, the endpoint
    projection and its schedules, the final readout, or the RNG seed.
    """
    parser = argparse.ArgumentParser(description="Linear-simplex algebraic GenPPL eval")
    # (flag, add_argument options) in registration order, which is also --help order.
    specs = [
        # Required inputs / outputs.
        ("--checkpoint", dict(required=True)),
        ("--tokenizer_path", dict(required=True)),
        ("--scorer", dict(required=True)),
        ("--out_dir", dict(required=True)),
        # Sizes and batching.
        ("--n_samples", dict(type=int, default=128)),
        ("--max_len", dict(type=int, default=128)),
        ("--steps", dict(type=int, default=128)),
        ("--batch_size", dict(type=int, default=16)),
        ("--score_batch", dict(type=int, default=8)),
        ("--score_max_length", dict(type=int, default=1024)),
        # Initial noise p0.
        ("--noise_init", dict(choices=["uniform", "logistic_normal", "dirichlet"], default="logistic_normal")),
        ("--noise_sigma", dict(type=float, default=3.0)),
        ("--noise_dirichlet_concentration", dict(type=float, default=1.0)),
        # Endpoint projection.
        ("--endpoint_temp_start", dict(type=float, default=1.45)),
        ("--endpoint_temp_end", dict(type=float, default=0.8)),
        ("--endpoint_projection", dict(choices=["soft", "sample", "argmax", "gumbel_softmax"], default="soft")),
        ("--endpoint_top_k", dict(type=int, default=0)),
        ("--endpoint_top_p", dict(type=float, default=1.0)),
        ("--ban_special_tokens", dict(action="store_true")),
        # Gumbel-softmax schedules.
        ("--gumbel_tau_start", dict(type=float, default=1.0)),
        ("--gumbel_tau_end", dict(type=float, default=0.2)),
        ("--gumbel_noise_scale_start", dict(type=float, default=1.0)),
        ("--gumbel_noise_scale_end", dict(type=float, default=0.0)),
        # Final readout and RNG.
        ("--final_from", dict(choices=["blend_0.5", "model_t1"], default="model_t1")),
        ("--seed", dict(type=int, default=20260524)),
    ]
    for flag, options in specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
# Special-token markers removed from decoded text before "stripped" scoring.
_STRIP_MARKERS = ("[CLS]", "[SEP]", "[PAD]", "<|endoftext|>")


def strip_special(t: str) -> str:
    """Remove special-token markers from decoded text and normalize whitespace.

    Each of ``[CLS]``, ``[SEP]``, ``[PAD]`` and ``<|endoftext|>`` is replaced by
    a space. Runs of whitespace are then collapsed to one space and
    leading/trailing whitespace is dropped.
    """
    for marker in _STRIP_MARKERS:
        t = t.replace(marker, " ")
    # str.split() with no argument splits on the same Unicode whitespace that
    # re's ``\s`` matches and discards empty pieces, so this equals
    # re.sub(r"\s+", " ", t).strip() without a per-call regex lookup.
    return " ".join(t.split())


@torch.no_grad()
def main() -> None:
    """Decode samples from a checkpoint, score them, and write a JSONL report.

    1. Load the checkpoint and tokenizer, build the model, and decode
       ``--n_samples`` sequences with ``decode_linear_simplex``.
    2. Free the decoder, then score both the raw texts and the
       special-token-stripped texts with the causal LM named by ``--scorer``.
    3. Write one summary record, then one record per sample, to
       ``<out_dir>/linear_steps{steps}_samples{n_samples}_scored.jsonl``.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"[load] {args.checkpoint}", flush=True)
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects from the file; only point --checkpoint at trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    step = ckpt.get("step")
    print(f"[ckpt] step={step}", flush=True)

    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    model = build_model(ckpt, tokenizer, device)
    # Only ``step`` is needed from the checkpoint after this point. Drop the
    # reference so its CPU copy can be reclaimed before the scorer is loaded.
    del ckpt

    ids, texts, decode_cfg = decode_linear_simplex(
        model,
        tokenizer,
        n_samples=args.n_samples,
        batch_size=args.batch_size,
        max_len=args.max_len,
        steps=args.steps,
        seed=args.seed,
        device=device,
        noise_init=args.noise_init,
        noise_sigma=args.noise_sigma,
        noise_dirichlet_concentration=args.noise_dirichlet_concentration,
        endpoint_temp_start=args.endpoint_temp_start,
        endpoint_temp_end=args.endpoint_temp_end,
        endpoint_projection=args.endpoint_projection,
        endpoint_top_k=args.endpoint_top_k,
        endpoint_top_p=args.endpoint_top_p,
        ban_special_tokens=args.ban_special_tokens,
        gumbel_tau_start=args.gumbel_tau_start,
        gumbel_tau_end=args.gumbel_tau_end,
        gumbel_noise_scale_start=args.gumbel_noise_scale_start,
        gumbel_noise_scale_end=args.gumbel_noise_scale_end,
        final_from=args.final_from,
    )
    # Release the decoder (and cached GPU memory) before loading the scorer.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    stripped = [strip_special(t) for t in texts]
    kept_raw, _ = filter_generated_texts(texts, min_chars=1, normalize_whitespace=False, drop_empty=True)
    kept_stripped, _ = filter_generated_texts(stripped, min_chars=1, normalize_whitespace=True, drop_empty=True)
    diversity = summarize_token_diversity(ids).__dict__

    print(f"[score] loading scorer: {args.scorer}", flush=True)
    scorer_tok = AutoTokenizer.from_pretrained(args.scorer)
    if scorer_tok.pad_token_id is None:
        # Reuse EOS as the pad token when the scorer tokenizer has none
        # (presumably required for batched padding in score_with_gpt2).
        scorer_tok.pad_token = scorer_tok.eos_token
        scorer_tok.pad_token_id = scorer_tok.eos_token_id
    scorer = AutoModelForCausalLM.from_pretrained(args.scorer).to(device).eval()
    if getattr(scorer.config, "pad_token_id", None) is None:
        scorer.config.pad_token_id = scorer_tok.pad_token_id

    score_kwargs = dict(batch_size=args.score_batch, max_length=args.score_max_length, device=device)
    raw_ppl = score_with_gpt2(kept_raw, scorer, scorer_tok, **score_kwargs)
    stripped_ppl = score_with_gpt2(kept_stripped, scorer, scorer_tok, **score_kwargs)
    summary = {
        "type": "summary",
        "checkpoint": args.checkpoint,
        "step": step,
        "decode": decode_cfg,
        "raw_genppl": raw_ppl,
        "stripped_genppl": stripped_ppl,
        "diversity": diversity,
    }

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_jsonl = out_dir / f"linear_steps{args.steps}_samples{args.n_samples}_scored.jsonl"
    with out_jsonl.open("w", encoding="utf-8") as f:
        f.write(json.dumps(summary, ensure_ascii=False) + "\n")
        for i, (raw, clean) in enumerate(zip(texts, stripped)):
            f.write(json.dumps({"type": "sample", "index": i, "raw_text": raw, "stripped_text": clean}, ensure_ascii=False) + "\n")
    print("[summary]", json.dumps(summary, ensure_ascii=False, indent=2), flush=True)
    print(f"[done] {out_jsonl}", flush=True)
|
|
|
|
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|