| |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
# Put the directory one level above this script's folder at the front of
# sys.path so the project-local imports that follow (`eval`, `flowtext_lab`)
# can be resolved when the file is run directly. Presumably that directory is
# the repository root -- verify against the repo layout.
REPO_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from eval import build_model_from_ckpt |
| from flowtext_lab.bridges import smooth_onehot |
| from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model |
| from flowtext_lab.tokenization import BpeTextTokenizer |
|
|
|
|
@dataclass
class DecodeConfig:
    """Settings for one named flow-map decoding variant.

    The class is deliberately a plain (non-slotted) dataclass: ``main``
    serialises an instance through its ``__dict__``, so the field order below
    is also the key order of the emitted JSON object.
    """

    # Name used to pick this variant from the command line.
    label: str
    # Number of decoding iterations to run.
    steps: int
    # Multiplier applied to the per-step interpolation weight.
    damping: float = 1.0
    # Upper bound on the interpolation weight; a non-positive value disables
    # the cap (see flowmap_gamma).
    max_gamma: float = 1.0
    # Divisor applied to the model logits before the softmax.
    endpoint_temp: float = 1.0
    # "endpoint" takes the final tokens from the last model prediction; any
    # other value takes them from the evolved state.
    final_from: str = "state"
|
|
|
|
def focused_configs(steps: int) -> list[DecodeConfig]:
    """Return the fixed sweep of decode configurations, each using ``steps`` steps.

    The sweep varies the endpoint temperature, the damping factor and the
    gamma cap; ``final_from`` is left at its default for every entry.

    The order of the returned list is significant: ``main`` walks it from the
    front and stops at the requested label, so reordering entries changes how
    many initial states are drawn before the selected one.
    """
    # (label, endpoint_temp, damping, max_gamma), in sweep order.
    sweep = (
        ("flowmap_t1p00_d1p0", 1.00, 1.0, 1.0),
        ("flowmap_t1p10_d1p0", 1.10, 1.0, 1.0),
        ("flowmap_t1p25_d1p0", 1.25, 1.0, 1.0),
        ("flowmap_t1p40_d1p0", 1.40, 1.0, 1.0),
        ("flowmap_t1p60_d1p0", 1.60, 1.0, 1.0),
        ("flowmap_t1p25_d0p7", 1.25, 0.7, 1.0),
        ("flowmap_t1p40_d0p7", 1.40, 0.7, 1.0),
        ("flowmap_t1p60_d0p7", 1.60, 0.7, 1.0),
        ("flowmap_t1p25_g0p5", 1.25, 1.0, 0.5),
        ("flowmap_t1p40_g0p5", 1.40, 1.0, 0.5),
    )
    return [
        DecodeConfig(
            label,
            steps,
            damping=damping,
            max_gamma=gamma_cap,
            endpoint_temp=temperature,
        )
        for label, temperature, damping, gamma_cap in sweep
    ]
|
|
|
|
def flowmap_gamma(step: int, steps: int, damping: float, max_gamma: float, eps: float) -> float:
    """Return the interpolation weight for decoding iteration ``step``.

    With ``s = step / steps`` and ``t_next = (step + 1) / steps`` the undamped
    weight is ``(t_next - s) / (1 - s)``: the fraction of the remaining time
    covered by this step. It is scaled by ``damping`` and, when ``max_gamma``
    is positive, capped at ``max_gamma``.

    Args:
        step: Zero-based iteration index.
        steps: Total number of iterations; values below 1 are treated as 1.
        damping: Multiplier applied to the undamped weight.
        max_gamma: Upper bound on the result; non-positive disables the cap.
        eps: Floor for the ``1 - s`` denominator.

    Returns:
        The (possibly capped) weight as a Python float.
    """
    total = max(steps, 1)  # avoids dividing by zero when steps <= 0
    s = step / total
    t_next = (step + 1) / total
    remaining = max(1.0 - s, eps)
    gamma = float(damping) * ((t_next - s) / remaining)
    if max_gamma > 0:
        return min(gamma, float(max_gamma))
    return gamma
|
|
|
|
def encode_prompt(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]:
    """Tokenize ``prompt`` and return at most ``max_len`` token ids.

    The text is encoded through the wrapped ``tokenizer.tokenizer`` with
    special tokens disabled. The tokenizer's BOS id is prepended when it is
    set (not ``None``) and non-negative. Truncation happens after the BOS id
    is added, so the BOS id counts toward ``max_len``.
    """
    encoding = tokenizer.tokenizer.encode(prompt, add_special_tokens=False)
    body_ids = list(encoding.ids)

    token_ids: list[int] = []
    bos = tokenizer.bos_id
    if bos is not None and bos >= 0:
        token_ids.append(bos)
    token_ids.extend(body_ids)
    return token_ids[:max_len]
|
|
|
|
def decode_text(tokenizer: BpeTextTokenizer, ids: list[int]) -> str:
    """Decode ``ids`` to text verbatim.

    The tokenizer is asked not to stop at EOS and not to skip special tokens,
    so every id in ``ids`` contributes to the returned string.
    """
    verbatim = {"stop_at_eos": False, "skip_special_tokens": False}
    return tokenizer.decode(ids, **verbatim)
|
|
|
|
def build_initial_state(
    tokenizer: BpeTextTokenizer,
    prompts: list[str],
    restarts: int,
    max_len: int,
    target_prob: float,
    eps: float,
    noise_init: str,
    dirichlet_concentration: float,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str]]:
    """Create the starting state for every (prompt, restart) candidate.

    Each prompt is encoded once and repeated ``restarts`` times, giving a
    batch of ``len(prompts) * restarts`` rows laid out prompt-major. All rows
    start from one ``sample_noise_simplex`` draw; the leading positions that
    hold prompt tokens are then overwritten with the smoothed one-hot of
    those tokens and marked as locked.

    Returns:
        A tuple ``(probs, attn, lock, lock_probs, expanded)``:

        * ``probs`` -- float state; noise with the prompt prefix written in
          (presumably ``(batch, max_len, vocab)`` -- confirm against
          ``sample_noise_simplex``).
        * ``attn`` -- all-True bool mask of shape ``(batch, max_len)``.
        * ``lock`` -- bool mask of shape ``(batch, max_len)``, True on the
          prompt positions.
        * ``lock_probs`` -- float32 ``(batch, max_len, vocab)``; the prompt
          distributions on locked positions, zeros elsewhere.
        * ``expanded`` -- the prompt text belonging to each batch row.
    """
    expanded: list[str] = []
    prompt_ids: list[list[int]] = []
    for text in prompts:
        encoded = encode_prompt(tokenizer, text, max_len)
        # Every restart of a prompt shares the same encoded id list.
        expanded.extend([text] * restarts)
        prompt_ids.extend([encoded] * restarts)

    batch = len(prompt_ids)
    vocab = tokenizer.vocab_size
    grid = (batch, max_len)

    noise = sample_noise_simplex(
        grid,
        vocab,
        device,
        eps,
        noise_mode=noise_init,
        target_prob=target_prob,
        noise_sigma=-1.0,
        dirichlet_concentration=dirichlet_concentration,
    )
    probs = noise.float()
    attn = torch.ones(grid, dtype=torch.bool, device=device)
    lock = torch.zeros(grid, dtype=torch.bool, device=device)
    lock_probs = torch.zeros((batch, max_len, vocab), dtype=torch.float32, device=device)

    for row, ids in enumerate(prompt_ids):
        prefix_len = len(ids)
        if prefix_len == 0:
            # Nothing to pin for an empty prompt; the row stays pure noise.
            continue
        ids_batch = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
        prefix_dist = smooth_onehot(ids_batch, vocab, target_prob, eps)[0]
        probs[row, :prefix_len] = prefix_dist
        lock_probs[row, :prefix_len] = prefix_dist
        lock[row, :prefix_len] = True

    return probs, attn, lock, lock_probs, expanded
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        The populated namespace. argparse exits the process (``SystemExit``)
        when a required option is missing or a value fails type conversion.
    """
    # (flag, add_argument keyword options), registered in exactly this order.
    option_table = (
        ("--checkpoint", {"required": True}),
        ("--tokenizer_path", {"required": True}),
        ("--output", {"required": True}),
        # All prompts of the run, separated by "|" (split in main).
        ("--prompts", {"required": True}),
        # The single prompt the selected candidate is expected to belong to.
        ("--prompt", {"required": True}),
        ("--restarts", {"type": int, "default": 20}),
        ("--candidate_index", {"type": int, "required": True}),
        ("--steps", {"type": int, "required": True}),
        ("--config_label", {"required": True}),
        ("--max_len", {"type": int, "default": 128}),
        ("--seed", {"type": int, "default": 20260502}),
        ("--target_prob", {"type": float, "default": 1.0}),
        ("--noise_init", {"default": "dirichlet"}),
        ("--dirichlet_init_concentration", {"type": float, "default": 1.0}),
        ("--eps", {"type": float, "default": 1e-8}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
def _resolve_candidate_index(index: int, total: int) -> int:
    """Return ``index`` as a non-negative offset into ``total`` candidates.

    Negative values count from the end, as in normal sequence indexing.

    Raises:
        IndexError: if ``index`` does not address one of the candidates.
    """
    resolved = index + total if index < 0 else index
    if not 0 <= resolved < total:
        raise IndexError(f"candidate_index {index} is out of range for {total} candidates")
    return resolved


def _position_records(
    tokenizer: BpeTextTokenizer,
    state_ids: list[int],
    state_top: list[float],
    endpoint_ids: list[int],
    endpoint_top: list[float],
) -> list[dict[str, object]]:
    """Build the per-position entries of one step record from plain Python lists."""
    per_position = zip(state_ids, state_top, endpoint_ids, endpoint_top)
    return [
        {
            "pos": pos,
            "state_token": decode_text(tokenizer, [state_id]),
            "state_id": state_id,
            "state_top_p": state_p,
            "endpoint_token": decode_text(tokenizer, [endpoint_id]),
            "endpoint_id": endpoint_id,
            "endpoint_top_p": endpoint_p,
        }
        for pos, (state_id, state_p, endpoint_id, endpoint_p) in enumerate(per_position)
    ]


@torch.no_grad()
def main() -> None:
    """Trace the flow-map decode of one candidate and write it to a JSON file.

    Steps:
        1. Parse the CLI and resolve ``--config_label`` against
           ``focused_configs``.
        2. Seed torch, load the tokenizer and checkpoint, build the model.
        3. Rebuild the initial state once per config up to and including the
           selected one, keeping only the last.
        4. Slice out the requested candidate and run the decode loop,
           recording the top token and probability per position at each step.
        5. Write the trace to ``--output`` and print a short summary.

    Raises:
        ValueError: unknown ``--config_label``, or the candidate's prompt
            differs from ``--prompt``.
        IndexError: ``--candidate_index`` does not address a candidate.
    """
    args = parse_args()

    # Resolve the label before any expensive work so a typo fails at once.
    configs = focused_configs(args.steps)
    selected_pos = next(
        (pos for pos, cfg in enumerate(configs) if cfg.label == args.config_label),
        None,
    )
    if selected_pos is None:
        raise ValueError(f"unknown config_label {args.config_label}")
    selected_cfg = configs[selected_pos]

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # NOTE(security): torch.load unpickles the file unless weights-only
    # loading is in effect; only point --checkpoint at trusted files.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device)
    model.eval()

    prompts = args.prompts.split("|")
    state_args = (
        tokenizer,
        prompts,
        args.restarts,
        args.max_len,
        args.target_prob,
        args.eps,
        args.noise_init,
        args.dirichlet_init_concentration,
        device,
    )
    # One initial state is drawn per config that precedes the selected one and
    # then discarded. Presumably this keeps the random stream aligned with a
    # sweep that drew a fresh state for every config in order -- confirm
    # against the script that produced the candidates.
    for _ in range(selected_pos):
        build_initial_state(*state_args)
    init, attn, lock, lock_probs, expanded = build_initial_state(*state_args)

    # A raw slice(-1, 0) would silently select an empty batch, so normalise
    # and range-check the index before it is used for slicing.
    candidate = _resolve_candidate_index(args.candidate_index, len(expanded))
    if expanded[candidate] != args.prompt:
        raise ValueError(
            f"candidate prompt mismatch: candidate={args.candidate_index} has {expanded[candidate]!r}, expected {args.prompt!r}"
        )

    # Keep a batch dimension of 1 for the model call.
    row = slice(candidate, candidate + 1)
    probs = init[row].clone()
    attn = attn[row]
    lock = lock[row]
    lock_probs = lock_probs[row]
    locked = lock.unsqueeze(-1)

    last_endpoint = probs
    records = []
    for step in range(selected_cfg.steps):
        t = model_time_for_step("flow", step, selected_cfg.steps, 1, device, dtype=torch.float32)
        logits = model(state_for_model(model, probs, args.eps), t, attn).float()
        endpoint = F.softmax(logits / float(selected_cfg.endpoint_temp), dim=-1)
        last_endpoint = endpoint

        # Move the state toward the predicted endpoint, then clamp and
        # renormalise so every position remains a probability vector.
        gamma = flowmap_gamma(step, selected_cfg.steps, selected_cfg.damping, selected_cfg.max_gamma, args.eps)
        new_probs = (probs + gamma * (endpoint - probs)).clamp_min(args.eps)
        new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True).clamp_min(args.eps)
        # Locked (prompt) positions always keep their fixed distribution.
        probs = torch.where(locked, lock_probs, new_probs)

        state_top_prob, state_ids = probs[0].max(dim=-1)
        endpoint_top_prob, endpoint_ids = endpoint[0].max(dim=-1)
        # One host transfer per tensor rather than one .item() per position.
        state_id_list = state_ids.cpu().tolist()
        endpoint_id_list = endpoint_ids.cpu().tolist()
        records.append(
            {
                "step": step,
                "gamma": gamma,
                "model_t": float(t.item()),
                "state_text": decode_text(tokenizer, state_id_list),
                "endpoint_text": decode_text(tokenizer, endpoint_id_list),
                "positions": _position_records(
                    tokenizer,
                    state_id_list,
                    state_top_prob.cpu().tolist(),
                    endpoint_id_list,
                    endpoint_top_prob.cpu().tolist(),
                ),
            }
        )

    if selected_cfg.final_from == "endpoint":
        final_dist = torch.where(locked, lock_probs, last_endpoint)
    else:
        final_dist = probs
    final_dist = final_dist / final_dist.sum(dim=-1, keepdim=True).clamp_min(args.eps)
    final_ids = final_dist[0].argmax(dim=-1).cpu().tolist()

    payload = {
        "checkpoint": args.checkpoint,
        "seed": args.seed,
        "prompts": prompts,
        "prompt": args.prompt,
        "restarts": args.restarts,
        "candidate_index": args.candidate_index,
        "steps": args.steps,
        "config": vars(selected_cfg),
        "final_ids": final_ids,
        "final_text": decode_text(tokenizer, final_ids),
        "records": records,
    }
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8")
    print(json.dumps({"output": str(out), "final": payload["final_text"]}, ensure_ascii=False, indent=2))
|
|
|
|
# Script entry point: run the trace only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|