# Source: lta/LTA_openwebtext_dualt/scripts/_tmp_trace_lta_stepcompare_candidate.py
# Uploaded by JinghuiLuAstronaut ("Add files using upload-large-folder tool").
# Commit 0241b9f (verified), 9.72 kB.
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import sys
from dataclasses import dataclass
from pathlib import Path
import torch
import torch.nn.functional as F
# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Prepend the root (only once) so the project-local imports that follow
# (``eval``, ``flowtext_lab``) resolve when this file is run directly.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from eval import build_model_from_ckpt
from flowtext_lab.bridges import smooth_onehot
from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model
from flowtext_lab.tokenization import BpeTextTokenizer
@dataclass
class DecodeConfig:
    """A single flow-map decode setting of the sweep.

    ``label`` and ``steps`` are the leading positional fields; the remaining
    knobs default to an undamped, unclamped, temperature-1 decode that reads
    its final answer from the state.
    """

    # Identifier compared against ``--config_label`` to pick a setting.
    label: str
    # Number of flow-map update steps to run.
    steps: int
    # Multiplier applied to the per-step interpolation weight ``gamma``.
    damping: float = 1.0
    # Upper bound on ``gamma``; a value <= 0 disables the clamp.
    max_gamma: float = 1.0
    # Divisor applied to the model logits before the endpoint softmax.
    endpoint_temp: float = 1.0
    # "endpoint" reads the final tokens from the last endpoint; any other
    # value reads them from the decoded state.
    final_from: str = "state"
def focused_configs(steps: int) -> list[DecodeConfig]:
    """Return the fixed grid of flow-map decode settings, each run for ``steps`` steps.

    The order of the returned list is significant: code that replays the
    sweep's random stream draws one initial batch per entry, in this order.
    """
    # Columns: label, endpoint_temp, damping, max_gamma.
    grid = (
        ("flowmap_t1p00_d1p0", 1.00, 1.0, 1.0),
        ("flowmap_t1p10_d1p0", 1.10, 1.0, 1.0),
        ("flowmap_t1p25_d1p0", 1.25, 1.0, 1.0),
        ("flowmap_t1p40_d1p0", 1.40, 1.0, 1.0),
        ("flowmap_t1p60_d1p0", 1.60, 1.0, 1.0),
        ("flowmap_t1p25_d0p7", 1.25, 0.7, 1.0),
        ("flowmap_t1p40_d0p7", 1.40, 0.7, 1.0),
        ("flowmap_t1p60_d0p7", 1.60, 0.7, 1.0),
        ("flowmap_t1p25_g0p5", 1.25, 1.0, 0.5),
        ("flowmap_t1p40_g0p5", 1.40, 1.0, 0.5),
    )
    return [
        DecodeConfig(label, steps, damping=damping, max_gamma=cap, endpoint_temp=temp)
        for label, temp, damping, cap in grid
    ]
def flowmap_gamma(step: int, steps: int, damping: float, max_gamma: float, eps: float) -> float:
    """Interpolation weight for moving the state toward the endpoint at ``step``.

    With ``t = step / steps`` the undamped weight is ``(t_next - t) / (1 - t)``,
    which equals 1 on the last step. The weight is scaled by ``damping`` and,
    when ``max_gamma > 0``, capped at ``max_gamma``. ``steps`` below 1 is
    treated as 1, and ``eps`` guards the division when ``t`` reaches 1.
    """
    denom = max(steps, 1)
    t_now = step / denom
    t_next = (step + 1) / denom
    remaining = max(1.0 - t_now, eps)
    gamma = float(damping) * ((t_next - t_now) / remaining)
    if max_gamma > 0:
        gamma = min(gamma, float(max_gamma))
    return gamma
def encode_prompt(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]:
    """Token ids for ``prompt``, optionally BOS-prefixed, truncated to ``max_len``.

    The raw encoder is called with ``add_special_tokens=False``. The tokenizer's
    ``bos_id`` is prepended only when it is set and non-negative.
    """
    encoding = tokenizer.tokenizer.encode(prompt, add_special_tokens=False)
    ids = list(encoding.ids)  # fresh list, safe to mutate
    bos = tokenizer.bos_id
    if bos is not None and bos >= 0:
        ids.insert(0, bos)
    return ids[:max_len]
def decode_text(tokenizer: BpeTextTokenizer, ids: list[int]) -> str:
    """Decode ``ids`` with ``stop_at_eos=False`` and ``skip_special_tokens=False``.

    Per those flags' names, this presumably neither cuts the text at EOS nor
    drops special tokens, so the output mirrors every id.
    """
    options = {"stop_at_eos": False, "skip_special_tokens": False}
    return tokenizer.decode(ids, **options)
def build_initial_state(
    tokenizer: BpeTextTokenizer,
    prompts: list[str],
    restarts: int,
    max_len: int,
    target_prob: float,
    eps: float,
    noise_init: str,
    dirichlet_concentration: float,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str]]:
    """Sample a noisy initial batch and pin each row's prompt tokens.

    Every prompt is repeated ``restarts`` times in a row, so batch row ``i``
    belongs to ``expanded[i]``. All rows start from ``sample_noise_simplex``
    noise. The positions covered by the encoded prompt are then overwritten
    with ``smooth_onehot`` distributions and flagged in ``lock``.

    Returns:
        ``(probs, attn, lock, lock_probs, expanded)``:
        ``probs`` is the float32 initial state (presumably ``(batch, max_len,
        vocab)``, matching ``lock_probs`` — TODO confirm against
        ``sample_noise_simplex``); ``attn`` is an all-True bool mask of shape
        ``(batch, max_len)``; ``lock`` is a bool mask of the same shape, True on
        prompt positions; ``lock_probs`` is a ``(batch, max_len, vocab)``
        float32 tensor that is zero except at prompt positions; ``expanded``
        lists the prompt text of each row.
    """
    encoded = [encode_prompt(tokenizer, prompt, max_len) for prompt in prompts]
    expanded: list[str] = [prompt for prompt in prompts for _ in range(restarts)]
    prompt_ids: list[list[int]] = [ids for ids in encoded for _ in range(restarts)]

    batch = len(prompt_ids)
    vocab = tokenizer.vocab_size
    # Random draw from the global torch RNG; keep it at this point so callers
    # that replay the RNG stream see the same sequence of samples.
    probs = sample_noise_simplex(
        (batch, max_len),
        vocab,
        device,
        eps,
        noise_mode=noise_init,
        target_prob=target_prob,
        noise_sigma=-1.0,
        dirichlet_concentration=dirichlet_concentration,
    ).float()

    attn = torch.ones((batch, max_len), dtype=torch.bool, device=device)
    lock = torch.zeros_like(attn)
    lock_probs = torch.zeros((batch, max_len, vocab), dtype=torch.float32, device=device)

    for row, ids in enumerate(prompt_ids):
        n_prompt = len(ids)
        if n_prompt == 0:
            continue
        ids_t = torch.tensor([ids], dtype=torch.long, device=device)
        pinned = smooth_onehot(ids_t, vocab, target_prob, eps)[0]
        probs[row, :n_prompt] = pinned
        lock_probs[row, :n_prompt] = pinned
        lock[row, :n_prompt] = True
    return probs, attn, lock, lock_probs, expanded
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for tracing one decode candidate.

    Options are registered in a fixed order (it determines ``--help`` output).
    The required string inputs come first, then the restart count, the
    candidate selection and the sampling and decode knobs.
    """
    parser = argparse.ArgumentParser()
    # Required string inputs: model/tokenizer files, output path, prompt set
    # ("|"-separated) and the single prompt the candidate is expected to have.
    for flag in ("--checkpoint", "--tokenizer_path", "--output", "--prompts", "--prompt"):
        parser.add_argument(flag, required=True)
    parser.add_argument("--restarts", type=int, default=20)
    # Which row of the expanded batch to trace, and the decode step count.
    for flag in ("--candidate_index", "--steps"):
        parser.add_argument(flag, type=int, required=True)
    parser.add_argument("--config_label", required=True)
    parser.add_argument("--max_len", type=int, default=128)
    parser.add_argument("--seed", type=int, default=20260502)
    parser.add_argument("--target_prob", type=float, default=1.0)
    parser.add_argument("--noise_init", default="dirichlet")
    parser.add_argument("--dirichlet_init_concentration", type=float, default=1.0)
    parser.add_argument("--eps", type=float, default=1e-8)
    return parser.parse_args()
def _select_config_state(
    args: argparse.Namespace,
    tokenizer: BpeTextTokenizer,
    prompts: list[str],
    device: torch.device,
) -> tuple[DecodeConfig, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str]]:
    """Replay the sweep's per-config initial batches up to ``args.config_label``.

    Returns the matching config followed by that config's
    ``(init, attn, lock, lock_probs, expanded)`` from ``build_initial_state``.

    Raises:
        ValueError: if no entry of ``focused_configs(args.steps)`` has the label.
    """
    # Reproduce the decode-sweep RNG stream: every config samples a fresh initial
    # batch, so the batches of all configs before the requested one are drawn and
    # discarded to land on the same random state.
    for cfg in focused_configs(args.steps):
        state = build_initial_state(
            tokenizer,
            prompts,
            args.restarts,
            args.max_len,
            args.target_prob,
            args.eps,
            args.noise_init,
            args.dirichlet_init_concentration,
            device,
        )
        if cfg.label == args.config_label:
            return (cfg, *state)
        del state  # free this batch before the next one is sampled
    raise ValueError(f"unknown config_label {args.config_label}")


def _step_record(
    tokenizer: BpeTextTokenizer,
    step: int,
    gamma: float,
    t: torch.Tensor,
    probs: torch.Tensor,
    endpoint: torch.Tensor,
    max_len: int,
) -> dict:
    """Summarize one decode step of batch row 0: argmax texts plus per-position top-1."""
    state_top_p, state_ids = probs[0].max(dim=-1)
    endpoint_top_p, endpoint_ids = endpoint[0].max(dim=-1)
    # Move each summary tensor to the host once; a per-position ``.item()`` call
    # would force a separate device synchronization when running on GPU.
    s_ids = state_ids.detach().cpu().tolist()
    s_p = state_top_p.detach().cpu().tolist()
    e_ids = endpoint_ids.detach().cpu().tolist()
    e_p = endpoint_top_p.detach().cpu().tolist()
    return {
        "step": step,
        "gamma": gamma,
        "model_t": float(t.item()),
        "state_text": decode_text(tokenizer, s_ids),
        "endpoint_text": decode_text(tokenizer, e_ids),
        "positions": [
            {
                "pos": pos,
                "state_token": decode_text(tokenizer, [s_ids[pos]]),
                "state_id": s_ids[pos],
                "state_top_p": s_p[pos],
                "endpoint_token": decode_text(tokenizer, [e_ids[pos]]),
                "endpoint_id": e_ids[pos],
                "endpoint_top_p": e_p[pos],
            }
            for pos in range(max_len)
        ],
    }


def _trace_decode(
    model: torch.nn.Module,
    tokenizer: BpeTextTokenizer,
    cfg: DecodeConfig,
    probs: torch.Tensor,
    attn: torch.Tensor,
    lock: torch.Tensor,
    lock_probs: torch.Tensor,
    eps: float,
    max_len: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, list[dict]]:
    """Run ``cfg.steps`` flow-map updates on ``probs``, recording every step.

    Each step moves the state a fraction ``gamma`` (see ``flowmap_gamma``) toward
    the temperature-scaled model endpoint, floors it at ``eps``, renormalizes it
    and re-pins locked positions to ``lock_probs``.

    Returns:
        ``(final_state, last_endpoint, records)``. With zero steps the
        "last endpoint" is the input state itself.
    """
    last_endpoint = probs
    records: list[dict] = []
    for step in range(cfg.steps):
        t = model_time_for_step("flow", step, cfg.steps, 1, device, dtype=torch.float32)
        logits = model(state_for_model(model, probs, eps), t, attn).float()
        endpoint = F.softmax(logits / float(cfg.endpoint_temp), dim=-1)
        last_endpoint = endpoint
        gamma = flowmap_gamma(step, cfg.steps, cfg.damping, cfg.max_gamma, eps)
        # Step toward the endpoint, floor at eps, renormalize over the last dim.
        new_probs = (probs + gamma * (endpoint - probs)).clamp_min(eps)
        new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True).clamp_min(eps)
        # Locked (prompt) positions never move.
        probs = torch.where(lock.unsqueeze(-1), lock_probs, new_probs)
        records.append(_step_record(tokenizer, step, gamma, t, probs, endpoint, max_len))
    return probs, last_endpoint, records


@torch.no_grad()
def main() -> None:
    """Trace the flow-map decode of one candidate and dump every step to JSON.

    Rebuilds the model from ``--checkpoint``. Replays the initial-state sampling of
    the decode sweep up to ``--config_label`` so that the same random batch is
    reached. Selects row ``--candidate_index`` of that batch (negative indices
    count from the end) and re-runs the decode for it, recording the argmax
    state and endpoint at every step. Writes the payload to ``--output`` and
    prints a short summary.

    Raises:
        ValueError: if ``--config_label`` is not in ``focused_configs``,
            ``--candidate_index`` is outside the expanded batch, or the
            candidate's prompt differs from ``--prompt``.
    """
    args = parse_args()
    # Keep the order seed -> model build -> state sampling: the replay below
    # depends on the position in the global RNG stream.
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device)
    model.eval()
    del ckpt  # the raw checkpoint is not needed once the model is built
    prompts = args.prompts.split("|")

    cfg, init, attn, lock, lock_probs, expanded = _select_config_state(args, tokenizer, prompts, device)

    n_cand = len(expanded)
    idx = args.candidate_index
    if not -n_cand <= idx < n_cand:
        raise ValueError(f"candidate_index {idx} out of range for {n_cand} candidates")
    # Normalize negatives: slice(-1, 0) would otherwise select an empty batch.
    idx %= n_cand
    if expanded[idx] != args.prompt:
        raise ValueError(
            f"candidate prompt mismatch: candidate={args.candidate_index} has {expanded[idx]!r}, expected {args.prompt!r}"
        )

    sl = slice(idx, idx + 1)
    probs = init[sl].clone()
    attn = attn[sl]
    lock = lock[sl]
    # Clone so the full (batch, max_len, vocab) lock_probs can be released.
    lock_probs = lock_probs[sl].clone()
    del init

    probs, last_endpoint, records = _trace_decode(
        model, tokenizer, cfg, probs, attn, lock, lock_probs, args.eps, args.max_len, device
    )

    if cfg.final_from == "endpoint":
        # Read the answer from the last model endpoint, keeping prompt positions pinned.
        final_dist = torch.where(lock.unsqueeze(-1), lock_probs, last_endpoint)
    else:
        final_dist = probs
    final_dist = final_dist / final_dist.sum(dim=-1, keepdim=True).clamp_min(args.eps)
    final_ids = final_dist[0].argmax(dim=-1).detach().cpu().tolist()

    payload = {
        "checkpoint": args.checkpoint,
        "seed": args.seed,
        "prompts": prompts,
        "prompt": args.prompt,
        "restarts": args.restarts,
        "candidate_index": args.candidate_index,
        "steps": args.steps,
        "config": cfg.__dict__,
        "final_ids": final_ids,
        "final_text": decode_text(tokenizer, final_ids),
        "records": records,
    }
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8")
    print(json.dumps({"output": str(out), "final": payload["final_text"]}, ensure_ascii=False, indent=2))
if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0 on success.
    sys.exit(main())