#!/usr/bin/env python3
"""Trace a single flow-map decoding run and dump per-step diagnostics to JSON.

At every decoding step the script records, for every sequence position, the
argmax token of the current state distribution and of the model's endpoint
prediction, their top probabilities, and the entropy of the state.
"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import torch
import torch.nn.functional as F

REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    # Put the repository root (parent of this script's directory) on sys.path
    # so the repo-local imports below resolve.
    sys.path.insert(0, str(REPO_ROOT))

from eval import build_model_from_ckpt  # noqa: E402
from flowtext_lab.bridges import smooth_onehot  # noqa: E402
from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model  # noqa: E402
from flowtext_lab.tokenization import BpeTextTokenizer  # noqa: E402
from scripts.flowtext_decode_lab import DecodeConfig, decode_text, flowmap_gamma  # noqa: E402


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed namespace. Invalid numeric options terminate the program
        through ``ArgumentParser.error`` (exit status 2).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--tokenizer_path", required=True)
    p.add_argument("--output", required=True)
    p.add_argument("--prompt", required=True)
    p.add_argument("--candidate_index", type=int, required=True)
    p.add_argument("--max_len", type=int, default=128)
    p.add_argument("--steps", type=int, default=128)
    p.add_argument("--seed", type=int, default=20260502)
    p.add_argument("--target_prob", type=float, default=1.0)
    p.add_argument("--endpoint_temp", type=float, default=1.4)
    p.add_argument("--damping", type=float, default=1.0)
    p.add_argument("--max_gamma", type=float, default=1.0)
    p.add_argument("--final_from", choices=["state", "endpoint", "blend"], default="state")
    p.add_argument("--eps", type=float, default=1e-8)
    args = p.parse_args(argv)

    # A negative index would request an empty/negative noise batch.
    if args.candidate_index < 0:
        p.error("--candidate_index must be >= 0")
    if args.max_len <= 0:
        p.error("--max_len must be > 0")
    if args.steps < 0:
        p.error("--steps must be >= 0")
    # Logits are divided by this value; zero yields inf/NaN probabilities.
    if args.endpoint_temp <= 0:
        p.error("--endpoint_temp must be > 0")
    # eps is the floor before log() in the entropy; 0 gives 0 * log(0) = NaN.
    if args.eps <= 0:
        p.error("--eps must be > 0")
    return args


def encode_prefix(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]:
    """Tokenize ``prompt`` without special tokens, prepend BOS if defined, truncate.

    Args:
        tokenizer: Wrapper exposing the underlying tokenizer as ``.tokenizer``
            and an optional ``bos_id``.
        prompt: Raw prompt text.
        max_len: Maximum number of ids to return.

    Returns:
        At most ``max_len`` token ids; BOS is included only when ``bos_id`` is
        a non-negative id.
    """
    core = list(tokenizer.tokenizer.encode(prompt, add_special_tokens=False).ids)
    bos = tokenizer.bos_id
    ids = ([bos] if bos is not None and bos >= 0 else []) + core
    return ids[:max_len]


def _initial_state(
    args: argparse.Namespace,
    vocab_size: int,
    device: torch.device,
    prompt_ids: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the starting state plus the prompt lock mask and locked values.

    Dirichlet noise is drawn for ``candidate_index + 1`` rows and only the last
    row is kept (presumably to replay one candidate of a batched decode under
    the same seed -- confirm against the batched caller). Prompt positions are
    overwritten with ``smooth_onehot`` distributions and marked as locked.

    Returns:
        ``(init, lock, lock_probs)`` where ``lock`` is a ``(1, max_len)`` bool
        mask and ``lock_probs`` is ``(1, max_len, vocab_size)`` float32.
    """
    init = sample_noise_simplex(
        (args.candidate_index + 1, args.max_len),
        vocab_size,
        device,
        args.eps,
        noise_mode="dirichlet",
        target_prob=args.target_prob,
        noise_sigma=-1.0,
        dirichlet_concentration=1.0,
    )[-1:].float()
    lock = torch.zeros((1, args.max_len), dtype=torch.bool, device=device)
    lock_probs = torch.zeros((1, args.max_len, vocab_size), dtype=torch.float32, device=device)
    if prompt_ids:
        n = len(prompt_ids)
        ids_t = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
        prompt_dist = smooth_onehot(ids_t, vocab_size, args.target_prob, args.eps)[0]
        init[0, :n] = prompt_dist
        lock_probs[0, :n] = prompt_dist
        lock[0, :n] = True
    return init, lock, lock_probs


def _step_record(
    tokenizer: BpeTextTokenizer,
    step: int,
    gamma: float,
    t: torch.Tensor,
    probs: torch.Tensor,
    endpoint: torch.Tensor,
    eps: float,
) -> dict:
    """Summarise one decoding step for batch row 0.

    Args:
        tokenizer: Used to render individual token ids as text.
        step: Zero-based step index.
        gamma: Interpolation weight used for this step.
        t: Model time tensor for this step (a single value).
        probs: Prompt-locked state after this step's update.
        endpoint: Temperature-scaled softmax prediction for this step.
        eps: Floor applied before ``log`` when computing the state entropy.

    Returns:
        A JSON-serialisable dict with step-level fields and a ``positions``
        list holding one entry per sequence position.
    """
    state = probs[0]
    state_top_prob, state_ids = state.max(dim=-1)
    clamped = state.clamp_min(eps)
    state_entropy = -(clamped * clamped.log()).sum(dim=-1)
    endpoint_top_prob, endpoint_ids = endpoint[0].max(dim=-1)

    # Copy each per-position tensor to the host once; a .item() call per
    # field per position would force a device sync every time.
    state_id_list = state_ids.tolist()
    state_top_list = state_top_prob.tolist()
    entropy_list = state_entropy.tolist()
    endpoint_id_list = endpoint_ids.tolist()
    endpoint_top_list = endpoint_top_prob.tolist()

    def token_text(token_id: int) -> str:
        return tokenizer.decode([token_id], stop_at_eos=False, skip_special_tokens=False)

    positions = [
        {
            "pos": pos,
            "state_token": token_text(sid),
            "state_id": sid,
            "state_top_p": s_top,
            "state_entropy": ent,
            "endpoint_token": token_text(eid),
            "endpoint_id": eid,
            "endpoint_top_p": e_top,
        }
        for pos, (sid, s_top, ent, eid, e_top) in enumerate(
            zip(state_id_list, state_top_list, entropy_list, endpoint_id_list, endpoint_top_list)
        )
    ]
    return {
        "step": step,
        # float() keeps the payload JSON-serialisable even if gamma comes back
        # as a 0-d tensor or numpy scalar; it is a no-op for Python floats.
        "gamma": float(gamma),
        "model_t": float(t.item()),
        "text_prefix": decode_text(tokenizer, state_id_list[:64]),
        "positions": positions,
    }


def _final_distribution(
    final_from: str,
    probs: torch.Tensor,
    last_endpoint: torch.Tensor,
    lock: torch.Tensor,
    lock_probs: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """Choose the distribution to argmax for the final text, renormalised.

    ``"endpoint"`` uses the last endpoint prediction, ``"blend"`` averages it
    with the state, and anything else (``"state"``) uses the state as-is.
    Locked prompt positions always keep their locked values.
    """
    if final_from == "endpoint":
        dist = torch.where(lock.unsqueeze(-1), lock_probs, last_endpoint)
    elif final_from == "blend":
        dist = torch.where(lock.unsqueeze(-1), lock_probs, 0.5 * probs + 0.5 * last_endpoint)
    else:
        dist = probs
    return dist / dist.sum(dim=-1, keepdim=True).clamp_min(eps)


@torch.no_grad()
def main() -> None:
    """Run one traced flow-map decode and write the trace JSON to ``--output``."""
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # NOTE: torch.load unpickles the checkpoint file; only load trusted files.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    # Keep model construction before the noise draw so the seeded RNG stream
    # is consumed in the same order as before.
    model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device)
    model.eval()

    prompt_ids = encode_prefix(tokenizer, args.prompt, args.max_len)
    probs, lock, lock_probs = _initial_state(args, tokenizer.vocab_size, device, prompt_ids)
    attn = torch.ones((1, args.max_len), dtype=torch.bool, device=device)
    last_endpoint = probs
    records = []

    cfg = DecodeConfig(
        label="trace",
        rule="flowmap",
        steps=args.steps,
        model_t_mode="flow",
        damping=args.damping,
        max_gamma=args.max_gamma,
        endpoint_temp=args.endpoint_temp,
        final_from=args.final_from,
    )

    for step in range(args.steps):
        t = model_time_for_step(cfg.model_t_mode, step, cfg.steps, 1, device, dtype=torch.float32)
        logits = model(state_for_model(model, probs, args.eps), t, attn).float()
        endpoint = F.softmax(logits / float(cfg.endpoint_temp), dim=-1)
        last_endpoint = endpoint
        gamma = flowmap_gamma(step, cfg.steps, cfg.damping, cfg.max_gamma, args.eps)
        # Move the state a fraction gamma toward the endpoint prediction, then
        # keep it a valid distribution (floor at eps, renormalise).
        new_probs = probs + gamma * (endpoint - probs)
        new_probs = new_probs.clamp_min(args.eps)
        new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True).clamp_min(args.eps)
        # Prompt positions stay pinned to their locked distributions.
        probs = torch.where(lock.unsqueeze(-1), lock_probs, new_probs)
        records.append(_step_record(tokenizer, step, gamma, t, probs, endpoint, args.eps))

    final_dist = _final_distribution(args.final_from, probs, last_endpoint, lock, lock_probs, args.eps)
    final_ids = final_dist[0].argmax(dim=-1).tolist()
    final_text = decode_text(tokenizer, final_ids)

    payload = {
        "checkpoint": args.checkpoint,
        "tokenizer_path": args.tokenizer_path,
        "seed": args.seed,
        "prompt": args.prompt,
        "candidate_index": args.candidate_index,
        "max_len": args.max_len,
        "steps": args.steps,
        "target_prob": args.target_prob,
        "endpoint_temp": args.endpoint_temp,
        "damping": args.damping,
        "max_gamma": args.max_gamma,
        "final_from": args.final_from,
        "eps": args.eps,
        "prompt_ids": prompt_ids,
        "final_ids": final_ids,
        "final_text": final_text,
        "records": records,
    }
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps({"output": str(out), "final_text": final_text}, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()