# Source: lta/LTA_openwebtext_dualt/scripts/_tmp_trace_lta_prompt_decode.py
# Uploaded by JinghuiLuAstronaut ("Add files using upload-large-folder tool", commit edff6fa, 6.78 kB).
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
# Repository root: the directory two levels above this (resolved) file.
# It is put at the front of sys.path, if not already present, so that the
# project-local imports below (eval, flowtext_lab, scripts.*) can be found.
REPO_ROOT = Path(__file__).resolve().parent.parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from eval import build_model_from_ckpt
from flowtext_lab.bridges import smooth_onehot
from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model
from flowtext_lab.tokenization import BpeTextTokenizer
from scripts.flowtext_decode_lab import DecodeConfig, decode_text, flowmap_gamma
def parse_args() -> argparse.Namespace:
    """Parse command-line options for a single prompt-conditioned decode trace.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits (status 2, via ``ArgumentParser.error``) when a numeric option lies
    outside the range the decode loop can handle. Without these checks, such
    values surface later as opaque tensor errors or NaN traces. The rejected
    cases are a negative candidate index, a non-positive length,
    temperature or eps, and a negative step count.
    """
    parser = argparse.ArgumentParser(
        description="Trace a flow-map decode of one noise candidate conditioned on a prompt."
    )
    parser.add_argument("--checkpoint", required=True, help="model checkpoint to load")
    parser.add_argument("--tokenizer_path", required=True, help="BPE tokenizer file")
    parser.add_argument("--output", required=True, help="path of the JSON trace to write")
    parser.add_argument("--prompt", required=True, help="prompt text whose tokens are locked")
    parser.add_argument(
        "--candidate_index",
        type=int,
        required=True,
        help="which noise sample to decode (candidate_index + 1 samples are drawn, the last is kept)",
    )
    parser.add_argument("--max_len", type=int, default=128, help="sequence length")
    parser.add_argument("--steps", type=int, default=128, help="number of decode steps")
    parser.add_argument("--seed", type=int, default=20260502, help="torch RNG seed")
    parser.add_argument(
        "--target_prob",
        type=float,
        default=1.0,
        help="target probability passed to the noise sampler and prompt smoothing",
    )
    parser.add_argument("--endpoint_temp", type=float, default=1.4, help="temperature dividing model logits")
    parser.add_argument("--damping", type=float, default=1.0, help="damping passed to flowmap_gamma")
    parser.add_argument("--max_gamma", type=float, default=1.0, help="max step size passed to flowmap_gamma")
    parser.add_argument(
        "--final_from",
        choices=["state", "endpoint", "blend"],
        default="state",
        help="distribution used for the final argmax",
    )
    parser.add_argument("--eps", type=float, default=1e-8, help="numerical floor for clamping/normalization")
    args = parser.parse_args()

    # A negative index would make the noise batch shape empty or negative.
    if args.candidate_index < 0:
        parser.error("--candidate_index must be >= 0")
    if args.max_len < 1:
        parser.error("--max_len must be >= 1")
    if args.steps < 0:
        parser.error("--steps must be >= 0")
    # `not x > 0` also rejects NaN. The temperature divides the logits, and eps
    # floors values that are later log()'d.
    if not args.endpoint_temp > 0:
        parser.error("--endpoint_temp must be > 0")
    if not args.eps > 0:
        parser.error("--eps must be > 0")
    return args
def encode_prefix(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]:
    """Tokenize *prompt* and return at most *max_len* ids, optionally BOS-prefixed.

    The underlying encoder (``tokenizer.tokenizer``) is called with
    ``add_special_tokens=False``. A BOS id is prepended only when
    ``tokenizer.bos_id`` is not ``None`` and is non-negative. The combined list
    is then cut to its first *max_len* entries.
    """
    encoding = tokenizer.tokenizer.encode(prompt, add_special_tokens=False)
    bos_id = tokenizer.bos_id
    if bos_id is None or bos_id < 0:
        prefix: list[int] = []
    else:
        prefix = [bos_id]
    prefix.extend(encoding.ids)
    return prefix[:max_len]
def _trace_step_record(
    tokenizer: BpeTextTokenizer,
    step: int,
    gamma,
    t: torch.Tensor,
    probs: torch.Tensor,
    endpoint: torch.Tensor,
    eps: float,
    prefix_tokens: int = 64,
) -> dict:
    """Summarize one decode step (batch row 0) as a JSON-serializable dict.

    For every position the record holds:
    - the argmax token, its probability and the entropy of the current state
      ``probs``;
    - the argmax token and probability of the model's ``endpoint`` prediction.

    ``text_prefix`` decodes the first ``prefix_tokens`` state-argmax ids.
    ``gamma`` (whatever ``flowmap_gamma`` returned) is stored as-is.
    """
    state = probs[0]
    state_top_prob, state_ids = state.max(dim=-1)
    clamped = state.clamp_min(eps)
    state_entropy = -(clamped * clamped.log()).sum(dim=-1)
    endpoint_top_prob, endpoint_ids = endpoint[0].max(dim=-1)

    # Copy each per-position tensor to the host once. The previous code made
    # several `.item()` calls per position, each forcing a device sync on CUDA.
    # tolist() yields the same Python ints/floats that .item() would.
    state_id_list = state_ids.detach().cpu().tolist()
    state_top_list = state_top_prob.detach().cpu().tolist()
    entropy_list = state_entropy.detach().cpu().tolist()
    endpoint_id_list = endpoint_ids.detach().cpu().tolist()
    endpoint_top_list = endpoint_top_prob.detach().cpu().tolist()

    def token_text(token_id: int) -> str:
        # Decode a single id verbatim (no EOS truncation, specials kept).
        return tokenizer.decode([token_id], stop_at_eos=False, skip_special_tokens=False)

    text_prefix = decode_text(tokenizer, state_id_list[:prefix_tokens])
    positions = [
        {
            "pos": pos,
            "state_token": token_text(s_id),
            "state_id": s_id,
            "state_top_p": float(s_p),
            "state_entropy": float(s_h),
            "endpoint_token": token_text(e_id),
            "endpoint_id": e_id,
            "endpoint_top_p": float(e_p),
        }
        for pos, (s_id, s_p, s_h, e_id, e_p) in enumerate(
            zip(state_id_list, state_top_list, entropy_list, endpoint_id_list, endpoint_top_list)
        )
    ]
    return {
        "step": step,
        "gamma": gamma,
        "model_t": float(t.item()),
        "text_prefix": text_prefix,
        "positions": positions,
    }


def _final_distribution(
    final_from: str,
    probs: torch.Tensor,
    last_endpoint: torch.Tensor,
    lock: torch.Tensor,
    lock_probs: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """Pick the distribution to argmax at the end and renormalize it.

    The choices are:
    - "endpoint": the last model prediction;
    - "blend": a 50/50 mix of the state and the last prediction;
    - anything else: the decode state itself.

    Locked (prompt) positions are forced back to ``lock_probs`` for the
    endpoint/blend choices. The state already carries them.
    """
    if final_from == "endpoint":
        dist = torch.where(lock.unsqueeze(-1), lock_probs, last_endpoint)
    elif final_from == "blend":
        dist = torch.where(lock.unsqueeze(-1), lock_probs, 0.5 * probs + 0.5 * last_endpoint)
    else:
        dist = probs
    return dist / dist.sum(dim=-1, keepdim=True).clamp_min(eps)


@torch.no_grad()
def main() -> None:
    """Run a prompt-conditioned flow-map decode of one noise candidate and dump a trace.

    The prompt tokens are locked at the start of the sequence. The rest is
    iteratively moved toward the model's endpoint prediction for
    ``--steps`` steps. A per-step, per-position trace plus the final argmax
    text is written as JSON to ``--output``. The output path and final text are
    also printed to stdout.
    """
    args = parse_args()
    # Seed before anything that may draw random numbers, so a given
    # (seed, candidate_index) pair is reproducible. Keep the order of the calls
    # below unchanged for the same reason.
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    vocab_size = tokenizer.vocab_size
    # NOTE(review): depending on the torch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects. Only load trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    model = build_model_from_ckpt(ckpt, vocab_size, args.max_len, device)
    del ckpt  # the raw checkpoint is not needed once the model is built
    model.eval()

    # Draw candidate_index + 1 noise rows and keep only the last one. This is
    # presumably so the trace reproduces that candidate from a batched run under
    # the same seed (TODO confirm).
    init = sample_noise_simplex(
        (args.candidate_index + 1, args.max_len),
        vocab_size,
        device,
        args.eps,
        noise_mode="dirichlet",
        target_prob=args.target_prob,
        noise_sigma=-1.0,
        dirichlet_concentration=1.0,
    )[-1:].float()
    attn = torch.ones((1, args.max_len), dtype=torch.bool, device=device)

    # Prompt positions start at, and stay pinned to, their smoothed one-hot
    # distribution.
    prompt_ids = encode_prefix(tokenizer, args.prompt, args.max_len)
    n_prompt = len(prompt_ids)
    lock = torch.zeros((1, args.max_len), dtype=torch.bool, device=device)
    lock_probs = torch.zeros((1, args.max_len, vocab_size), dtype=torch.float32, device=device)
    if prompt_ids:
        ids_t = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
        smoothed = smooth_onehot(ids_t, vocab_size, args.target_prob, args.eps)[0]
        init[0, :n_prompt] = smoothed
        lock_probs[0, :n_prompt] = smoothed
        lock[0, :n_prompt] = True

    probs = init.clone()
    last_endpoint = probs
    records = []
    cfg = DecodeConfig(
        label="trace",
        rule="flowmap",
        steps=args.steps,
        model_t_mode="flow",
        damping=args.damping,
        max_gamma=args.max_gamma,
        endpoint_temp=args.endpoint_temp,
        final_from=args.final_from,
    )
    for step in range(args.steps):
        t = model_time_for_step(cfg.model_t_mode, step, cfg.steps, 1, device, dtype=torch.float32)
        logits = model(state_for_model(model, probs, args.eps), t, attn).float()
        endpoint = F.softmax(logits / float(cfg.endpoint_temp), dim=-1)
        last_endpoint = endpoint
        gamma = flowmap_gamma(step, cfg.steps, cfg.damping, cfg.max_gamma, args.eps)
        # Move a fraction gamma of the way toward the endpoint prediction, then
        # floor at eps and renormalize so each position stays a distribution.
        new_probs = probs + gamma * (endpoint - probs)
        new_probs = new_probs.clamp_min(args.eps)
        new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True).clamp_min(args.eps)
        probs = torch.where(lock.unsqueeze(-1), lock_probs, new_probs)
        records.append(_trace_step_record(tokenizer, step, gamma, t, probs, endpoint, args.eps))

    final_dist = _final_distribution(args.final_from, probs, last_endpoint, lock, lock_probs, args.eps)
    final_ids = final_dist[0].argmax(dim=-1).detach().cpu().tolist()
    final_text = decode_text(tokenizer, final_ids)
    payload = {
        "checkpoint": args.checkpoint,
        "tokenizer_path": args.tokenizer_path,
        "seed": args.seed,
        "prompt": args.prompt,
        "candidate_index": args.candidate_index,
        "max_len": args.max_len,
        "steps": args.steps,
        "target_prob": args.target_prob,
        "endpoint_temp": args.endpoint_temp,
        "damping": args.damping,
        "max_gamma": args.max_gamma,
        "final_from": args.final_from,
        "eps": args.eps,
        "prompt_ids": prompt_ids,
        "final_ids": final_ids,
        "final_text": final_text,
        "records": records,
    }
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps({"output": str(out), "final_text": final_text}, ensure_ascii=False, indent=2))
# Run the trace only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()