#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
# Put the directory one level above this file's own directory on sys.path
# (once, at the front) so the imports that follow resolve from there.
REPO_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))
from eval import build_model_from_ckpt
from flowtext_lab.bridges import smooth_onehot
from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model
from flowtext_lab.tokenization import BpeTextTokenizer
from scripts.flowtext_decode_lab import DecodeConfig, decode_text, flowmap_gamma
def parse_args() -> argparse.Namespace:
    """Parse this script's command-line options from ``sys.argv``.

    Required: ``--checkpoint``, ``--tokenizer_path``, ``--output``,
    ``--prompt`` (strings) and ``--candidate_index`` (int). Every other
    option has a default.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # Required string options.
    for flag in ("--checkpoint", "--tokenizer_path", "--output", "--prompt"):
        parser.add_argument(flag, required=True)
    parser.add_argument("--candidate_index", type=int, required=True)

    # Optional integer options.
    int_options = (("--max_len", 128), ("--steps", 128), ("--seed", 20260502))
    for flag, default in int_options:
        parser.add_argument(flag, type=int, default=default)

    # Optional float options.
    float_options = (
        ("--target_prob", 1.0),
        ("--endpoint_temp", 1.4),
        ("--damping", 1.0),
        ("--max_gamma", 1.0),
    )
    for flag, default in float_options:
        parser.add_argument(flag, type=float, default=default)

    parser.add_argument(
        "--final_from",
        choices=["state", "endpoint", "blend"],
        default="state",
    )
    parser.add_argument("--eps", type=float, default=1e-8)

    return parser.parse_args()
def encode_prefix(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]:
    """Encode ``prompt`` into token ids, BOS first when one is defined.

    Args:
        tokenizer: Wrapper exposing ``.tokenizer.encode(...)`` and ``.bos_id``.
        prompt: Text to encode; encoded with ``add_special_tokens=False``.
        max_len: Upper bound on the returned length (applied as ``[:max_len]``).

    Returns:
        The BOS id (only if it is not ``None`` and is non-negative) followed
        by the prompt's token ids, truncated to ``max_len`` entries.
    """
    encoding = tokenizer.tokenizer.encode(prompt, add_special_tokens=False)
    bos_id = tokenizer.bos_id

    prefix: list[int] = []
    if bos_id is not None and bos_id >= 0:
        prefix.append(bos_id)
    prefix.extend(encoding.ids)

    return prefix[:max_len]
def _trace_positions(
    tokenizer: BpeTextTokenizer,
    state_ids: list[int],
    state_top_p: list[float],
    state_entropy: list[float],
    endpoint_ids: list[int],
    endpoint_top_p: list[float],
) -> list[dict]:
    """Build one trace entry per sequence position from host-side Python lists."""
    rows = zip(state_ids, state_top_p, state_entropy, endpoint_ids, endpoint_top_p)
    return [
        {
            "pos": pos,
            "state_token": tokenizer.decode([s_id], stop_at_eos=False, skip_special_tokens=False),
            "state_id": s_id,
            "state_top_p": s_top,
            "state_entropy": s_ent,
            "endpoint_token": tokenizer.decode([e_id], stop_at_eos=False, skip_special_tokens=False),
            "endpoint_id": e_id,
            "endpoint_top_p": e_top,
        }
        for pos, (s_id, s_top, s_ent, e_id, e_top) in enumerate(rows)
    ]


@torch.no_grad()
def main() -> None:
    """Run one traced decode for a single candidate and write the trace as JSON.

    Loads the tokenizer and checkpoint, samples an initial state, pins the
    encoded prompt positions to a fixed distribution, then for each step moves
    the state toward the model's softmax endpoint by ``gamma``. After every
    step the per-position top token, top probability and entropy of the state,
    plus the top token and probability of the endpoint, are recorded. The
    final ids/text and all records are written to ``--output`` and a short
    summary is printed.

    Raises:
        SystemExit: if ``--candidate_index`` is negative.
    """
    args = parse_args()
    if args.candidate_index < 0:
        # The noise batch below has candidate_index + 1 rows; a negative index
        # would request zero or fewer rows and fail later with an opaque error.
        raise SystemExit("--candidate_index must be >= 0")

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # NOTE(security): torch.load unpickles the file; only pass trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device)
    model.eval()

    # Sample candidate_index + 1 rows and keep only the last one.
    # NOTE(review): presumably this reproduces row `candidate_index` of a
    # batched run with the same seed — confirm against the batched decoder.
    init = sample_noise_simplex(
        (args.candidate_index + 1, args.max_len),
        tokenizer.vocab_size,
        device,
        args.eps,
        noise_mode="dirichlet",
        target_prob=args.target_prob,
        noise_sigma=-1.0,
        dirichlet_concentration=1.0,
    )[-1:].float()
    attn = torch.ones((1, args.max_len), dtype=torch.bool, device=device)

    # Prompt positions are locked: their distribution is restored after every step.
    prompt_ids = encode_prefix(tokenizer, args.prompt, args.max_len)
    lock = torch.zeros((1, args.max_len), dtype=torch.bool, device=device)
    lock_probs = torch.zeros(
        (1, args.max_len, tokenizer.vocab_size), dtype=torch.float32, device=device
    )
    if prompt_ids:
        ids_t = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
        sp = smooth_onehot(ids_t, tokenizer.vocab_size, args.target_prob, args.eps)[0]
        n_prompt = len(prompt_ids)
        init[0, :n_prompt] = sp
        lock_probs[0, :n_prompt] = sp
        lock[0, :n_prompt] = True
    lock_mask = lock.unsqueeze(-1)  # loop-invariant; broadcasts over the vocab axis

    probs = init.clone()
    last_endpoint = probs  # used as the endpoint if the loop runs zero steps
    records = []
    cfg = DecodeConfig(
        label="trace",
        rule="flowmap",
        steps=args.steps,
        model_t_mode="flow",
        damping=args.damping,
        max_gamma=args.max_gamma,
        endpoint_temp=args.endpoint_temp,
        final_from=args.final_from,
    )

    for step in range(args.steps):
        t = model_time_for_step(cfg.model_t_mode, step, cfg.steps, 1, device, dtype=torch.float32)
        logits = model(state_for_model(model, probs, args.eps), t, attn).float()
        logits = logits / float(cfg.endpoint_temp)
        endpoint = F.softmax(logits, dim=-1)
        last_endpoint = endpoint

        # NOTE(review): gamma is stored in the JSON trace as-is, so this assumes
        # flowmap_gamma returns a plain JSON-serializable number — confirm.
        gamma = flowmap_gamma(step, cfg.steps, cfg.damping, cfg.max_gamma, args.eps)
        new_probs = probs + gamma * (endpoint - probs)
        new_probs = new_probs.clamp_min(args.eps)
        new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True).clamp_min(args.eps)
        probs = torch.where(lock_mask, lock_probs, new_probs)

        state_top_prob, state_ids = probs[0].max(dim=-1)
        clamped_state = probs[0].clamp_min(args.eps)
        state_entropy = -(clamped_state * clamped_state.log()).sum(dim=-1)
        endpoint_top_prob, endpoint_ids = endpoint[0].max(dim=-1)

        # Move each per-position tensor to the host once, instead of issuing
        # several `.item()` synchronizations for every position.
        state_id_list = state_ids.detach().cpu().tolist()
        records.append(
            {
                "step": step,
                "gamma": gamma,
                "model_t": float(t.item()),
                "text_prefix": decode_text(tokenizer, state_id_list[:64]),
                "positions": _trace_positions(
                    tokenizer,
                    state_id_list,
                    state_top_prob.detach().cpu().tolist(),
                    state_entropy.detach().cpu().tolist(),
                    endpoint_ids.detach().cpu().tolist(),
                    endpoint_top_prob.detach().cpu().tolist(),
                ),
            }
        )

    # Choose the distribution the final tokens are read from; locked positions
    # always keep their prompt distribution.
    if args.final_from == "endpoint":
        final_dist = torch.where(lock_mask, lock_probs, last_endpoint)
    elif args.final_from == "blend":
        final_dist = torch.where(lock_mask, lock_probs, 0.5 * probs + 0.5 * last_endpoint)
    else:
        final_dist = probs
    final_dist = final_dist / final_dist.sum(dim=-1, keepdim=True).clamp_min(args.eps)
    final_ids = final_dist[0].argmax(dim=-1).detach().cpu().tolist()
    final_text = decode_text(tokenizer, final_ids)

    payload = {
        "checkpoint": args.checkpoint,
        "seed": args.seed,
        "prompt": args.prompt,
        "candidate_index": args.candidate_index,
        "steps": args.steps,
        "endpoint_temp": args.endpoint_temp,
        "damping": args.damping,
        "max_gamma": args.max_gamma,
        "final_from": args.final_from,
        "prompt_ids": prompt_ids,
        "final_ids": final_ids,
        "final_text": final_text,
        "records": records,
    }
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps({"output": str(out), "final_text": final_text}, ensure_ascii=False, indent=2))
# Run the trace only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|