#!/usr/bin/env python3 from __future__ import annotations import argparse import json import math import re import sys from collections import Counter from pathlib import Path from typing import Sequence import torch import torch.nn.functional as F REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from eval import build_model_from_ckpt from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model from flowtext_lab.tokenization import BpeTextTokenizer WORD_RE = re.compile(r"[A-Za-z]+|\d+|[^\sA-Za-z\d]") def encode_prompt(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]: return list(tokenizer.tokenizer.encode(prompt, add_special_tokens=False).ids)[:max_len] def decode_text(tokenizer: BpeTextTokenizer, ids: Sequence[int]) -> str: return tokenizer.decode(ids, stop_at_eos=False, skip_special_tokens=False) def text_metrics(text: str) -> dict[str, float]: toks = WORD_RE.findall(text) words = [t.lower() for t in toks if re.fullmatch(r"[A-Za-z]+", t)] n_tok = max(len(toks), 1) n_words = max(len(words), 1) wc = Counter(words) max_word_frac = wc.most_common(1)[0][1] / n_words if wc else 1.0 grams3 = list(zip(toks, toks[1:], toks[2:])) rep3 = sum(v - 1 for v in Counter(grams3).values() if v > 1) / max(len(grams3), 1) bigrams = list(zip(words, words[1:])) distinct2 = len(set(bigrams)) / max(len(bigrams), 1) if bigrams else 0.0 punct_frac = sum(bool(re.fullmatch(r"[,.;:!?]+", t)) for t in toks) / n_tok digit_frac = sum(t.isdigit() for t in toks) / n_tok quality = ( min(len(text) / 700.0, 1.0) + 0.35 * distinct2 - 2.6 * rep3 - 1.2 * max_word_frac - 0.8 * punct_frac - 1.0 * digit_frac - 0.2 * text.count("<|endoftext|>") - 0.5 * text.count("�") ) return { "quality": float(quality), "chars": float(len(text)), "words": float(len(words)), "rep3": float(rep3), "distinct2": float(distinct2), "punct_frac": float(punct_frac), "max_word_frac": float(max_word_frac), "eot_count": float(text.count("<|endoftext|>")), } def dirichlet_mean(endpoint: torch.Tensor, support_t: float, eps: float) -> torch.Tensor: vocab = endpoint.size(-1) mean = (1.0 - support_t) / float(vocab) + support_t * endpoint mean = mean.clamp_min(eps) return mean / mean.sum(dim=-1, keepdim=True).clamp_min(eps) def total_concentration(support_t: float, c_min: float, c_max: float) -> float: log_min = math.log(max(c_min, 1e-8)) log_max = math.log(max(c_max, c_min)) return math.exp(log_min + support_t * (log_max - log_min)) def dirichlet_resample(mean: torch.Tensor, support_t: float, c_min: float, c_max: float, eps: float) -> torch.Tensor: conc = total_concentration(support_t, c_min, c_max) alpha = (mean * conc).clamp_min(eps) sample = torch._standard_gamma(alpha).clamp_min(eps) return sample / sample.sum(dim=-1, keepdim=True).clamp_min(eps) def schedule_power(step: int, steps: int, power: float) -> float: base = (step + 1) / max(steps, 1) return float(max(0.0, min(1.0, base ** float(power)))) def current_anchor(probs: torch.Tensor, mode: str, eps: float) -> torch.Tensor: if mode == "state": return probs if mode == "onehot": ids = probs.argmax(dim=-1) return F.one_hot(ids, probs.size(-1)).to(dtype=probs.dtype, device=probs.device) if mode == "sqrt_state": x = probs.clamp_min(eps).sqrt() return x / x.sum(dim=-1, keepdim=True).clamp_min(eps) raise ValueError(f"unknown anchor mode: {mode}") @torch.no_grad() def build_initial( tokenizer: BpeTextTokenizer, prompts: list[str], restarts: int, max_len: int, eps: float, noise_init: str, target_prob: float, noise_sigma: float, dirichlet_concentration: float, lock_bos: bool, lock_final_eos: bool, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str]]: expanded = [] prompt_ids = [] for prompt in prompts: ids = encode_prompt(tokenizer, prompt, max_len) if lock_bos: ids = [tokenizer.bos_id] + ids ids = ids[:max_len] for _ in range(restarts): expanded.append(prompt) prompt_ids.append(ids) batch = len(prompt_ids) probs = sample_noise_simplex( (batch, max_len), tokenizer.vocab_size, device, eps, noise_mode=noise_init, target_prob=target_prob, noise_sigma=noise_sigma, dirichlet_concentration=dirichlet_concentration, ) lock = torch.zeros((batch, max_len), dtype=torch.bool, device=device) lock_probs = torch.zeros((batch, max_len, tokenizer.vocab_size), dtype=torch.float32, device=device) for row, ids in enumerate(prompt_ids): if not ids: continue ids_t = torch.tensor(ids, dtype=torch.long, device=device) onehot = F.one_hot(ids_t, tokenizer.vocab_size).float() probs[row, : len(ids)] = onehot lock_probs[row, : len(ids)] = onehot lock[row, : len(ids)] = True if lock_final_eos: eos = torch.tensor([tokenizer.eos_id], dtype=torch.long, device=device) eos_prob = F.one_hot(eos, tokenizer.vocab_size).float()[0] probs[:, -1] = eos_prob lock_probs[:, -1] = eos_prob lock[:, -1] = True attn = torch.ones((batch, max_len), dtype=torch.bool, device=device) return probs, lock, lock_probs, attn, expanded @torch.no_grad() def decode_one_config( model, tokenizer, init, lock, lock_probs, attn, args, update: str, final_from: str, temp: float, model_t_mode: str, support_power: float, semantic_power: float, anchor_mode: str, ): probs = init.clone() last_endpoint = probs device = probs.device for step in range(args.steps): model_t = model_time_for_step(model_t_mode, step, args.steps, probs.size(0), device, dtype=torch.float32) logits = model(state_for_model(model, probs, args.eps), model_t, attn).float() / temp endpoint = F.softmax(logits, dim=-1) last_endpoint = endpoint support_t = schedule_power(step, args.steps, support_power) semantic_t = schedule_power(step, args.steps, semantic_power) if update.startswith("dual_line"): anchor = current_anchor(probs, anchor_mode, args.eps) forward_endpoint = (1.0 - semantic_t) * anchor + semantic_t * endpoint forward_endpoint = forward_endpoint / forward_endpoint.sum(dim=-1, keepdim=True).clamp_min(args.eps) else: forward_endpoint = endpoint mean = dirichlet_mean(forward_endpoint, support_t, args.eps) if update == "mean": new_probs = mean elif update == "resample": new_probs = dirichlet_resample(mean, support_t, args.concentration_min, args.concentration_max, args.eps) elif update == "dual_line_mean": new_probs = mean elif update == "dual_line_resample": new_probs = dirichlet_resample(mean, support_t, args.concentration_min, args.concentration_max, args.eps) elif update == "ema_mean": gamma = 1.0 / max(args.steps - step, 1) new_probs = (1.0 - gamma) * probs + gamma * mean new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True).clamp_min(args.eps) else: raise ValueError(update) probs = torch.where(lock.unsqueeze(-1), lock_probs, new_probs) if final_from == "endpoint": out = last_endpoint elif final_from == "blend": out = 0.5 * probs + 0.5 * last_endpoint else: out = probs out = torch.where(lock.unsqueeze(-1), lock_probs, out) return out / out.sum(dim=-1, keepdim=True).clamp_min(args.eps) def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--checkpoint", required=True) ap.add_argument("--tokenizer_path", required=True) ap.add_argument("--output", required=True) ap.add_argument("--max_len", type=int, default=256) ap.add_argument("--steps", type=int, default=256) ap.add_argument("--restarts", type=int, default=4) ap.add_argument("--prompts", nargs="+", default=[""]) ap.add_argument("--noise_init", default="dirichlet") ap.add_argument("--target_prob", type=float, default=0.99) ap.add_argument("--noise_sigma", type=float, default=-1.0) ap.add_argument("--dirichlet_init_concentration", type=float, default=1.0) ap.add_argument("--concentration_min", type=float, default=1.0) ap.add_argument("--concentration_max", type=float, default=1024.0) ap.add_argument("--updates", nargs="+", default=["mean", "ema_mean", "resample"]) ap.add_argument("--finals", nargs="+", default=["state", "endpoint", "blend"]) ap.add_argument("--temps", nargs="+", type=float, default=[1.0, 1.2, 1.35]) ap.add_argument("--model_t_modes", nargs="+", default=["flow", "const05"]) ap.add_argument("--support_powers", nargs="+", type=float, default=[1.0]) ap.add_argument("--semantic_powers", nargs="+", type=float, default=[1.0]) ap.add_argument("--anchor_modes", nargs="+", default=["onehot"]) ap.add_argument("--lock_bos", action="store_true") ap.add_argument("--lock_final_eos", action="store_true") ap.add_argument("--eps", type=float, default=1e-8) ap.add_argument("--seed", type=int, default=1234) args = ap.parse_args() torch.manual_seed(args.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path) ckpt = torch.load(args.checkpoint, map_location=device) model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device) model.eval() init, lock, lock_probs, attn, expanded = build_initial( tokenizer, args.prompts, args.restarts, args.max_len, args.eps, args.noise_init, args.target_prob, args.noise_sigma, args.dirichlet_init_concentration, args.lock_bos, args.lock_final_eos, device, ) configs = [] for update in args.updates: for final_from in args.finals: for temp in args.temps: for model_t_mode in args.model_t_modes: for support_power in args.support_powers: for semantic_power in args.semantic_powers: for anchor_mode in args.anchor_modes: configs.append((update, final_from, temp, model_t_mode, support_power, semantic_power, anchor_mode)) out_path = Path(args.output) out_path.parent.mkdir(parents=True, exist_ok=True) rows = [] with out_path.open("w") as f: for update, final_from, temp, model_t_mode, support_power, semantic_power, anchor_mode in configs: probs = decode_one_config( model, tokenizer, init, lock, lock_probs, attn, args, update, final_from, temp, model_t_mode, support_power, semantic_power, anchor_mode, ) ids = probs.argmax(dim=-1).detach().cpu().tolist() texts = [decode_text(tokenizer, row) for row in ids] mets = [text_metrics(t) for t in texts] mean_q = sum(m["quality"] for m in mets) / len(mets) best_i = max(range(len(texts)), key=lambda i: mets[i]["quality"]) row = { "update": update, "final_from": final_from, "temp": temp, "model_t_mode": model_t_mode, "support_power": support_power, "semantic_power": semantic_power, "anchor_mode": anchor_mode, "mean_quality": mean_q, "best_prompt": expanded[best_i], "best_metrics": mets[best_i], "best_text": texts[best_i], } rows.append(row) print( "\n====", update, final_from, temp, model_t_mode, "support_p", support_power, "semantic_p", semantic_power, "anchor", anchor_mode, "mean_q", round(mean_q, 4), flush=True, ) print(texts[best_i][:1600], flush=True) f.write(json.dumps(row, ensure_ascii=False) + "\n") f.flush() best = max(rows, key=lambda r: r["mean_quality"]) print("\nBEST", json.dumps({k: best[k] for k in best if k != "best_text"}, ensure_ascii=False, indent=2), flush=True) print(best["best_text"], flush=True) if __name__ == "__main__": main()