| |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
# Resolve the directory one level above this file's own folder and put it at
# the front of sys.path if it is not listed yet, so the project imports that
# follow (`eval`, `flowtext_lab`, `scripts`) can be found when this file is
# run directly. NOTE(review): assumes that directory is the repository root.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from eval import build_model_from_ckpt |
| from flowtext_lab.bridges import smooth_onehot |
| from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model |
| from flowtext_lab.tokenization import BpeTextTokenizer |
| from scripts.flowtext_decode_lab import DecodeConfig, decode_text, flowmap_gamma |
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate the command-line options of the decode-trace script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets
            ``argparse`` read ``sys.argv[1:]``, which is what the script
            entry point relies on.

    Returns:
        The populated namespace with one attribute per option.

    Raises:
        SystemExit: Raised by ``argparse`` (exit status 2) when a required
            option is missing, a value cannot be converted, or one of the
            range checks below fails.
    """
    parser = argparse.ArgumentParser()

    # Required inputs / outputs.
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer_path", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--prompt", required=True)
    parser.add_argument("--candidate_index", type=int, required=True)

    # Decode settings with defaults.
    parser.add_argument("--max_len", type=int, default=128)
    parser.add_argument("--steps", type=int, default=128)
    parser.add_argument("--seed", type=int, default=20260502)
    parser.add_argument("--target_prob", type=float, default=1.0)
    parser.add_argument("--endpoint_temp", type=float, default=1.4)
    parser.add_argument("--damping", type=float, default=1.0)
    parser.add_argument("--max_gamma", type=float, default=1.0)
    parser.add_argument("--final_from", choices=["state", "endpoint", "blend"], default="state")
    parser.add_argument("--eps", type=float, default=1e-8)

    args = parser.parse_args(argv)

    # Fail early with a usage message instead of an obscure tensor error later:
    # candidate_index + 1 is used as the noise batch size, max_len as a tensor
    # dimension, and the logits are divided by endpoint_temp.
    if args.candidate_index < 0:
        parser.error("--candidate_index must be >= 0")
    if args.max_len < 1:
        parser.error("--max_len must be >= 1")
    if args.endpoint_temp == 0:
        parser.error("--endpoint_temp must be non-zero")
    return args
|
|
|
|
def encode_prefix(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]:
    """Encode *prompt* into token ids, BOS first, capped at *max_len* ids.

    Args:
        tokenizer: Tokenizer wrapper; its ``bos_id`` attribute and its inner
            ``tokenizer.encode(...).ids`` are read.
        prompt: Text to encode. It is encoded without special tokens.
        max_len: Maximum number of ids to return. Values below one yield an
            empty list.

    Returns:
        The BOS id (only when ``bos_id`` is set and non-negative) followed by
        the prompt ids, truncated to at most *max_len* entries.
    """
    token_ids: list[int] = []
    bos = tokenizer.bos_id
    if bos is not None and bos >= 0:
        token_ids.append(bos)
    token_ids.extend(tokenizer.tokenizer.encode(prompt, add_special_tokens=False).ids)
    # Clamp the bound: a negative slice end would count from the back and
    # silently drop trailing tokens instead of returning nothing.
    return token_ids[: max(max_len, 0)]
|
|
|
|
def _trace_positions(
    tokenizer: BpeTextTokenizer,
    state_ids: list[int],
    state_top_p: list[float],
    state_entropy: list[float],
    endpoint_ids: list[int],
    endpoint_top_p: list[float],
) -> list[dict]:
    """Build the per-position entries of one trace record from plain lists."""
    rows = zip(state_ids, state_top_p, state_entropy, endpoint_ids, endpoint_top_p)
    return [
        {
            "pos": pos,
            "state_token": tokenizer.decode([s_id], stop_at_eos=False, skip_special_tokens=False),
            "state_id": s_id,
            "state_top_p": s_p,
            "state_entropy": s_h,
            "endpoint_token": tokenizer.decode([e_id], stop_at_eos=False, skip_special_tokens=False),
            "endpoint_id": e_id,
            "endpoint_top_p": e_p,
        }
        for pos, (s_id, s_p, s_h, e_id, e_p) in enumerate(rows)
    ]


@torch.no_grad()
def main() -> None:
    """Trace one flow-map decode run step by step and dump it as JSON.

    Loads the tokenizer and checkpoint named on the command line, builds an
    initial simplex state from Dirichlet noise with the prompt tokens pinned,
    and runs ``--steps`` flow-map updates. After every update it records, for
    each position, the arg-max token of the current state (with its
    probability and the state's entropy) and the arg-max token of the model's
    endpoint prediction. The trace plus the final decoded text is written to
    ``--output``; a short summary is printed to stdout.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # NOTE(security): torch.load may unpickle arbitrary objects from the file;
    # only point --checkpoint at checkpoints from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device)
    model.eval()

    # Sample candidate_index + 1 noise rows and keep only the last one.
    # NOTE(review): presumably this reproduces the noise row that a batched
    # run with the same seed assigns to this candidate — confirm.
    init = sample_noise_simplex(
        (args.candidate_index + 1, args.max_len),
        tokenizer.vocab_size,
        device,
        args.eps,
        noise_mode="dirichlet",
        target_prob=args.target_prob,
        noise_sigma=-1.0,
        dirichlet_concentration=1.0,
    )[-1:].float()
    attn = torch.ones((1, args.max_len), dtype=torch.bool, device=device)

    # Pin the prompt: locked positions are reset to the smoothed one-hot of
    # their prompt token after every update.
    prompt_ids = encode_prefix(tokenizer, args.prompt, args.max_len)
    lock = torch.zeros((1, args.max_len), dtype=torch.bool, device=device)
    lock_probs = torch.zeros((1, args.max_len, tokenizer.vocab_size), dtype=torch.float32, device=device)
    if prompt_ids:
        n_prompt = len(prompt_ids)
        ids_t = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
        sp = smooth_onehot(ids_t, tokenizer.vocab_size, args.target_prob, args.eps)[0]
        init[0, :n_prompt] = sp
        lock_probs[0, :n_prompt] = sp
        lock[0, :n_prompt] = True
    lock_mask = lock.unsqueeze(-1)  # loop-invariant, broadcast over the vocab axis

    probs = init.clone()
    # Fallback for the "endpoint"/"blend" final modes when no step is executed.
    last_endpoint = probs
    records = []
    cfg = DecodeConfig(
        label="trace",
        rule="flowmap",
        steps=args.steps,
        model_t_mode="flow",
        damping=args.damping,
        max_gamma=args.max_gamma,
        endpoint_temp=args.endpoint_temp,
        final_from=args.final_from,
    )

    for step in range(args.steps):
        t = model_time_for_step(cfg.model_t_mode, step, cfg.steps, 1, device, dtype=torch.float32)
        logits = model(state_for_model(model, probs, args.eps), t, attn).float()
        endpoint = F.softmax(logits / float(cfg.endpoint_temp), dim=-1)
        last_endpoint = endpoint

        # Move the state a fraction gamma towards the predicted endpoint, then
        # re-project onto the simplex and restore the locked prompt positions.
        gamma = flowmap_gamma(step, cfg.steps, cfg.damping, cfg.max_gamma, args.eps)
        new_probs = (probs + gamma * (endpoint - probs)).clamp_min(args.eps)
        new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True).clamp_min(args.eps)
        probs = torch.where(lock_mask, lock_probs, new_probs)

        state_top_prob, state_ids = probs[0].max(dim=-1)
        clamped_state = probs[0].clamp_min(args.eps)
        state_entropy = -(clamped_state * clamped_state.log()).sum(dim=-1)
        endpoint_top_prob, endpoint_ids = endpoint[0].max(dim=-1)

        # Convert each tensor to a Python list once per step rather than
        # calling .item() per position, which forces a device sync every time.
        state_id_list = state_ids.tolist()
        records.append(
            {
                "step": step,
                "gamma": gamma,
                "model_t": float(t.item()),
                # Only the first 64 positions are decoded for the preview text.
                "text_prefix": decode_text(tokenizer, state_id_list[:64]),
                "positions": _trace_positions(
                    tokenizer,
                    state_id_list,
                    state_top_prob.tolist(),
                    state_entropy.tolist(),
                    endpoint_ids.tolist(),
                    endpoint_top_prob.tolist(),
                ),
            }
        )

    # Choose the distribution the final tokens are read from; locked prompt
    # positions always keep their pinned distribution.
    if args.final_from == "endpoint":
        final_dist = torch.where(lock_mask, lock_probs, last_endpoint)
    elif args.final_from == "blend":
        final_dist = torch.where(lock_mask, lock_probs, 0.5 * probs + 0.5 * last_endpoint)
    else:
        final_dist = probs
    final_dist = final_dist / final_dist.sum(dim=-1, keepdim=True).clamp_min(args.eps)
    final_ids = final_dist[0].argmax(dim=-1).detach().cpu().tolist()
    final_text = decode_text(tokenizer, final_ids)

    payload = {
        "checkpoint": args.checkpoint,
        "seed": args.seed,
        "prompt": args.prompt,
        "candidate_index": args.candidate_index,
        "steps": args.steps,
        "endpoint_temp": args.endpoint_temp,
        "damping": args.damping,
        "max_gamma": args.max_gamma,
        "final_from": args.final_from,
        "prompt_ids": prompt_ids,
        "final_ids": final_ids,
        "final_text": final_text,
        "records": records,
    }
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps({"output": str(out), "final_text": final_text}, ensure_ascii=False, indent=2))
|
|
|
|
# Run the trace only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|