#!/usr/bin/env python3 from __future__ import annotations import argparse import json import sys from dataclasses import dataclass from pathlib import Path import torch import torch.nn.functional as F REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from eval import build_model_from_ckpt from flowtext_lab.bridges import smooth_onehot from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model from flowtext_lab.tokenization import BpeTextTokenizer @dataclass class DecodeConfig: label: str steps: int damping: float = 1.0 max_gamma: float = 1.0 endpoint_temp: float = 1.0 final_from: str = "state" def focused_configs(steps: int) -> list[DecodeConfig]: return [ DecodeConfig("flowmap_t1p00_d1p0", steps, endpoint_temp=1.00, damping=1.0), DecodeConfig("flowmap_t1p10_d1p0", steps, endpoint_temp=1.10, damping=1.0), DecodeConfig("flowmap_t1p25_d1p0", steps, endpoint_temp=1.25, damping=1.0), DecodeConfig("flowmap_t1p40_d1p0", steps, endpoint_temp=1.40, damping=1.0), DecodeConfig("flowmap_t1p60_d1p0", steps, endpoint_temp=1.60, damping=1.0), DecodeConfig("flowmap_t1p25_d0p7", steps, endpoint_temp=1.25, damping=0.7), DecodeConfig("flowmap_t1p40_d0p7", steps, endpoint_temp=1.40, damping=0.7), DecodeConfig("flowmap_t1p60_d0p7", steps, endpoint_temp=1.60, damping=0.7), DecodeConfig("flowmap_t1p25_g0p5", steps, endpoint_temp=1.25, damping=1.0, max_gamma=0.5), DecodeConfig("flowmap_t1p40_g0p5", steps, endpoint_temp=1.40, damping=1.0, max_gamma=0.5), ] def flowmap_gamma(step: int, steps: int, damping: float, max_gamma: float, eps: float) -> float: s = step / max(steps, 1) t_next = (step + 1) / max(steps, 1) base = (t_next - s) / max(1.0 - s, eps) gamma = float(damping) * base return min(gamma, float(max_gamma)) if max_gamma > 0 else gamma def encode_prompt(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]: core = list(tokenizer.tokenizer.encode(prompt, add_special_tokens=False).ids) bos = tokenizer.bos_id ids = ([bos] if bos is not None and bos >= 0 else []) + core return ids[:max_len] def decode_text(tokenizer: BpeTextTokenizer, ids: list[int]) -> str: return tokenizer.decode(ids, stop_at_eos=False, skip_special_tokens=False) def build_initial_state( tokenizer: BpeTextTokenizer, prompts: list[str], restarts: int, max_len: int, target_prob: float, eps: float, noise_init: str, dirichlet_concentration: float, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str]]: expanded: list[str] = [] prompt_ids: list[list[int]] = [] for prompt in prompts: ids = encode_prompt(tokenizer, prompt, max_len) for _ in range(restarts): expanded.append(prompt) prompt_ids.append(ids) batch = len(prompt_ids) probs = sample_noise_simplex( (batch, max_len), tokenizer.vocab_size, device, eps, noise_mode=noise_init, target_prob=target_prob, noise_sigma=-1.0, dirichlet_concentration=dirichlet_concentration, ).float() attn = torch.ones((batch, max_len), dtype=torch.bool, device=device) lock = torch.zeros((batch, max_len), dtype=torch.bool, device=device) lock_probs = torch.zeros((batch, max_len, tokenizer.vocab_size), dtype=torch.float32, device=device) for row, ids in enumerate(prompt_ids): if not ids: continue ids_t = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0) sp = smooth_onehot(ids_t, tokenizer.vocab_size, target_prob, eps)[0] probs[row, : len(ids)] = sp lock_probs[row, : len(ids)] = sp lock[row, : len(ids)] = True return probs, attn, lock, lock_probs, expanded def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser() p.add_argument("--checkpoint", required=True) p.add_argument("--tokenizer_path", required=True) p.add_argument("--output", required=True) p.add_argument("--prompts", required=True) p.add_argument("--prompt", required=True) p.add_argument("--restarts", type=int, default=20) p.add_argument("--candidate_index", type=int, required=True) p.add_argument("--steps", type=int, required=True) p.add_argument("--config_label", required=True) p.add_argument("--max_len", type=int, default=128) p.add_argument("--seed", type=int, default=20260502) p.add_argument("--target_prob", type=float, default=1.0) p.add_argument("--noise_init", default="dirichlet") p.add_argument("--dirichlet_init_concentration", type=float, default=1.0) p.add_argument("--eps", type=float, default=1e-8) return p.parse_args() @torch.no_grad() def main() -> None: args = parse_args() torch.manual_seed(args.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path) ckpt = torch.load(args.checkpoint, map_location="cpu") model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device) model.eval() prompts = args.prompts.split("|") configs = focused_configs(args.steps) selected_cfg = None init = attn = lock = lock_probs = None expanded: list[str] = [] # Reproduce the decode-sweep RNG stream: every config samples a fresh initial # batch. We consume the same initial batches until the requested config. for cfg in configs: init, attn, lock, lock_probs, expanded = build_initial_state( tokenizer, prompts, args.restarts, args.max_len, args.target_prob, args.eps, args.noise_init, args.dirichlet_init_concentration, device, ) if cfg.label == args.config_label: selected_cfg = cfg break del init, attn, lock, lock_probs if selected_cfg is None or init is None or attn is None or lock is None or lock_probs is None: raise ValueError(f"unknown config_label {args.config_label}") if expanded[args.candidate_index] != args.prompt: raise ValueError( f"candidate prompt mismatch: candidate={args.candidate_index} has {expanded[args.candidate_index]!r}, expected {args.prompt!r}" ) sl = slice(args.candidate_index, args.candidate_index + 1) probs = init[sl].clone() attn = attn[sl] lock = lock[sl] lock_probs = lock_probs[sl] last_endpoint = probs records = [] for step in range(selected_cfg.steps): t = model_time_for_step("flow", step, selected_cfg.steps, 1, device, dtype=torch.float32) logits = model(state_for_model(model, probs, args.eps), t, attn).float() logits = logits / float(selected_cfg.endpoint_temp) endpoint = F.softmax(logits, dim=-1) last_endpoint = endpoint gamma = flowmap_gamma(step, selected_cfg.steps, selected_cfg.damping, selected_cfg.max_gamma, args.eps) new_probs = probs + gamma * (endpoint - probs) new_probs = new_probs.clamp_min(args.eps) new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True).clamp_min(args.eps) probs = torch.where(lock.unsqueeze(-1), lock_probs, new_probs) state_top_prob, state_ids = probs[0].max(dim=-1) endpoint_top_prob, endpoint_ids = endpoint[0].max(dim=-1) records.append( { "step": step, "gamma": gamma, "model_t": float(t.item()), "state_text": decode_text(tokenizer, state_ids.detach().cpu().tolist()), "endpoint_text": decode_text(tokenizer, endpoint_ids.detach().cpu().tolist()), "positions": [ { "pos": pos, "state_token": tokenizer.decode([int(state_ids[pos].item())], stop_at_eos=False, skip_special_tokens=False), "state_id": int(state_ids[pos].item()), "state_top_p": float(state_top_prob[pos].item()), "endpoint_token": tokenizer.decode([int(endpoint_ids[pos].item())], stop_at_eos=False, skip_special_tokens=False), "endpoint_id": int(endpoint_ids[pos].item()), "endpoint_top_p": float(endpoint_top_prob[pos].item()), } for pos in range(args.max_len) ], } ) if selected_cfg.final_from == "endpoint": final_dist = torch.where(lock.unsqueeze(-1), lock_probs, last_endpoint) else: final_dist = probs final_dist = final_dist / final_dist.sum(dim=-1, keepdim=True).clamp_min(args.eps) final_ids = final_dist[0].argmax(dim=-1).detach().cpu().tolist() payload = { "checkpoint": args.checkpoint, "seed": args.seed, "prompts": prompts, "prompt": args.prompt, "restarts": args.restarts, "candidate_index": args.candidate_index, "steps": args.steps, "config": selected_cfg.__dict__, "final_ids": final_ids, "final_text": decode_text(tokenizer, final_ids), "records": records, } out = Path(args.output) out.parent.mkdir(parents=True, exist_ok=True) out.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8") print(json.dumps({"output": str(out), "final": payload["final_text"]}, ensure_ascii=False, indent=2)) if __name__ == "__main__": main()