lta / LTA_openwebtext_dualt /scripts /flowtext_decode_lab.py
JinghuiLuAstronaut's picture
Add files using upload-large-folder tool
0241b9f verified
Raw
History Blame Contribute Delete
22.6 kB
#!/usr/bin/env python3
"""Decode-sweep lab for FlowText OpenWebText checkpoints.
The goal is to debug inference without touching training. We try several
simplex-valid update rules, generate many candidates, and rank them with
anti-collapse diagnostics instead of pure self-likelihood.
Run from the flowtext_standard_bench repository root.
"""
from __future__ import annotations
import argparse
import json
import math
import re
import sys
from collections import Counter
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Iterable, List, Sequence
import torch
import torch.nn.functional as F
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from eval import build_model_from_ckpt
from flowtext_lab.bridges import smooth_onehot
from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model
from flowtext_lab.tokenization import BpeTextTokenizer
WORD_RE = re.compile(r"[A-Za-z]+|\d+|[^\sA-Za-z\d]")
@dataclass
class DecodeConfig:
    """One decode-sweep setting: an update rule plus the knobs it reads.

    ``label`` identifies the configuration in logs and output records and
    ``rule`` selects the state update applied by ``decode_batch`` (one of
    ``"flowmap"``, ``"replace"``, ``"geometric"``, ``"centered_residual"``).
    """

    # --- identification -------------------------------------------------
    label: str
    rule: str

    # --- schedule -------------------------------------------------------
    # Number of decode iterations (one model call per iteration).
    steps: int = 64
    # Forwarded to ``model_time_for_step`` to pick the model time input.
    model_t_mode: str = "flow"

    # --- update-rule knobs ----------------------------------------------
    # Mixing weight used by the "replace", "geometric" and
    # "centered_residual" rules.
    eta: float = 0.5
    # "flowmap" only: multiplier on the base step size ...
    damping: float = 1.0
    # ... and an upper cap on it (the cap is skipped when <= 0).
    max_gamma: float = 1.0
    # Logits are divided by this value before the softmax.
    endpoint_temp: float = 1.0
    # Lower clamp applied to the state before it is renormalized.
    state_floor: float = 1e-8

    # --- output selection -----------------------------------------------
    # "endpoint" returns the last model prediction, "blend" a 50/50 mix of
    # state and prediction; any other value returns the state itself.
    final_from: str = "state"

    # --- exploration / biasing ------------------------------------------
    # Weight of the uniform distribution mixed into the state each step.
    noise_mix: float = 0.0
    # "linear" or "sqrt" decay of ``noise_mix`` over the steps; any other
    # value keeps it constant.
    noise_decay: str = "linear"
    # Additive bias on the EOS logit (applied only when non-zero).
    eos_logit_bias: float = 0.0
def tokenize_for_metrics(text: str) -> list[str]:
    """Return every non-overlapping ``WORD_RE`` match in ``text``, in order."""
    # WORD_RE has no capture groups, so each match's full text is the token.
    return [match.group(0) for match in WORD_RE.finditer(text)]
def repeated_ngram_frac(tokens: Sequence[str], n: int) -> float:
    """Return the fraction of n-gram occurrences that repeat an earlier n-gram.

    Every occurrence of an n-gram beyond its first counts as a repeat; the
    result is ``repeats / number_of_ngrams``. Sequences shorter than ``n``
    yield ``0.0``.
    """
    if len(tokens) < n:
        return 0.0
    # Sliding windows: zip the sequence against itself shifted by 0..n-1.
    shifted = (tokens[offset:] for offset in range(n))
    counts = Counter(zip(*shifted))
    total = sum(counts.values())
    # Each distinct n-gram contributes (occurrences - 1) repeats, so the
    # number of repeats is the total minus the number of distinct n-grams.
    repeats = total - len(counts)
    return repeats / max(total, 1)
def text_metrics(text: str) -> dict:
    """Compute collapse diagnostics and a heuristic quality score for ``text``.

    The text is tokenized with ``tokenize_for_metrics``; purely alphabetic
    tokens (lowercased) are treated as "words". The returned dict holds the
    raw counts, repetition/diversity fractions and a single ``quality``
    value combining them.
    """
    tokens = tokenize_for_metrics(text)
    lowered = [tok.lower() for tok in tokens]
    words = [low for tok, low in zip(tokens, lowered) if re.fullmatch(r"[A-Za-z]+", tok)]
    token_total = max(len(tokens), 1)
    word_total = max(len(words), 1)

    # Lexical concentration / diversity over words.
    word_counts = Counter(words)
    if word_counts:
        max_word_frac = max(word_counts.values()) / word_total
    else:
        # No words at all is scored as the worst case.
        max_word_frac = 1.0
    distinct1 = len(word_counts) / word_total if words else 0.0
    word_pairs = list(zip(words, words[1:]))
    distinct2 = len(set(word_pairs)) / len(word_pairs) if word_pairs else 0.0

    # Token-class fractions and artifact counts.
    digit_frac = sum(1 for tok in tokens if tok.isdigit()) / token_total
    punct_frac = sum(1 for tok in tokens if re.fullmatch(r"[,.;:!?]+", tok)) / token_total
    eos_count = text.count("<|endoftext|>")
    bad_char_count = text.count("�")
    rep3 = repeated_ngram_frac(lowered, 3)
    rep4 = repeated_ngram_frac(lowered, 4)

    # This score is deliberately simple and non-oracle. It rewards length and
    # lexical variety while heavily penalizing classic collapse artifacts.
    # Terms are accumulated in a fixed order so the float result is stable.
    weighted_terms = (
        (0.35, distinct2),
        (0.15, distinct1),
        (-0.30, eos_count),
        (-2.60, rep3),
        (-1.60, rep4),
        (-1.30, digit_frac),
        (-0.65, punct_frac),
        (-1.35, max_word_frac),
        (-0.35, bad_char_count),
    )
    quality = min(len(text) / 700.0, 1.0)
    for weight, value in weighted_terms:
        quality = quality + weight * value

    metrics = {
        "quality": float(quality),
        "chars": len(text),
        "tokens": len(tokens),
        "words": len(words),
        "eos_count": eos_count,
        "bad_char_count": bad_char_count,
        "rep3": float(rep3),
        "rep4": float(rep4),
        "distinct1": float(distinct1),
        "distinct2": float(distinct2),
        "digit_frac": float(digit_frac),
        "punct_frac": float(punct_frac),
        "max_word_frac": float(max_word_frac),
    }
    return metrics
def decode_text(tokenizer: BpeTextTokenizer, ids: Sequence[int]) -> str:
    """Decode ``ids`` to text without stopping at EOS or dropping special tokens."""
    # Keep EOS and other special tokens visible so the metrics can count them.
    decode_options = {"stop_at_eos": False, "skip_special_tokens": False}
    return tokenizer.decode(ids, **decode_options)
def encode_prompt(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]:
    """Encode ``prompt`` with the wrapped tokenizer and keep at most ``max_len`` ids."""
    encoding = tokenizer.tokenizer.encode(prompt)
    token_ids = list(encoding.ids)
    return token_ids[:max_len]
@torch.no_grad()
def build_initial_state(
    tokenizer: BpeTextTokenizer,
    prompts: list[str],
    restarts: int,
    max_len: int,
    target_prob: float,
    eps: float,
    noise_init: str,
    noise_sigma: float,
    dirichlet_init_concentration: float,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str]]:
    """Build the starting state for a batch of ``len(prompts) * restarts`` rows.

    Each prompt is encoded once and repeated ``restarts`` times in adjacent
    rows. The state is drawn from ``sample_noise_simplex`` and the leading
    positions covered by the prompt are overwritten with
    ``smooth_onehot`` of the prompt ids and marked as locked.

    Returns:
        ``(probs, attn, lock, lock_probs, expanded)`` where ``attn`` is an
        all-True ``(batch, max_len)`` mask, ``lock`` flags prompt positions,
        ``lock_probs`` holds the prompt distributions at locked positions
        (zeros elsewhere) and ``expanded`` is the prompt string per row.
    """
    expanded: list[str] = []
    prompt_ids: list[list[int]] = []
    for prompt in prompts:
        encoded = encode_prompt(tokenizer, prompt, max_len=max_len)
        expanded.extend([prompt] * restarts)
        prompt_ids.extend([encoded] * restarts)

    batch = len(prompt_ids)
    vocab = tokenizer.vocab_size
    grid = (batch, max_len)

    attn = torch.ones(grid, dtype=torch.bool, device=device)
    probs = sample_noise_simplex(
        grid,
        vocab,
        device,
        eps,
        noise_mode=noise_init,
        target_prob=target_prob,
        noise_sigma=noise_sigma,
        dirichlet_concentration=dirichlet_init_concentration,
    )
    lock = torch.zeros(grid, dtype=torch.bool, device=device)
    lock_probs = torch.zeros((batch, max_len, vocab), dtype=torch.float32, device=device)

    for row, ids in enumerate(prompt_ids):
        prefix_len = len(ids)
        if prefix_len == 0:
            # Empty prompt: nothing to pin, the row stays pure noise.
            continue
        ids_t = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
        prefix_dist = smooth_onehot(ids_t, vocab, target_prob, eps)[0]
        probs[row, :prefix_len] = prefix_dist
        lock_probs[row, :prefix_len] = prefix_dist
        lock[row, :prefix_len] = True

    return probs, attn, lock, lock_probs, expanded
def flowmap_gamma(step: int, steps: int, damping: float, max_gamma: float, eps: float) -> float:
    """Return the interpolation weight for one "flowmap" step.

    With ``s = step / steps`` and ``t_next = (step + 1) / steps`` the base
    weight is ``(t_next - s) / (1 - s)``, i.e. the fraction of the remaining
    path covered by this step. It is scaled by ``damping`` and, when
    ``max_gamma`` is positive, capped at ``max_gamma``.
    """
    total = max(steps, 1)  # guard against steps == 0
    s = step / total
    t_next = (step + 1) / total
    remaining = max(1.0 - s, eps)  # avoid dividing by zero at s == 1
    gamma = float(damping) * ((t_next - s) / remaining)
    if max_gamma > 0:
        return min(gamma, float(max_gamma))
    return gamma
@torch.no_grad()
def decode_batch(
    model,
    init_probs: torch.Tensor,
    attn: torch.Tensor,
    lock: torch.Tensor,
    lock_probs: torch.Tensor,
    cfg: DecodeConfig,
    eps: float,
    eos_id: int | None = None,
) -> torch.Tensor:
    """Run ``cfg.steps`` iterations of the update rule named by ``cfg.rule``.

    Each iteration queries the model for an endpoint distribution (softmax
    of its logits, after optional temperature and EOS bias), moves the state
    towards it according to the rule, optionally mixes in uniform noise,
    clamps/renormalizes, and restores the locked positions from
    ``lock_probs``.

    Returns the final state, or the last endpoint / a 50-50 blend of both
    when ``cfg.final_from`` is ``"endpoint"`` / ``"blend"``.

    Raises:
        ValueError: if ``cfg.rule`` is not a known rule (raised on the first
            iteration).
    """
    state = init_probs.float().clone()
    device = state.device
    batch = state.size(0)
    locked = lock.unsqueeze(-1)  # broadcast the position mask over the vocab axis
    floor = max(float(cfg.state_floor), eps)
    total_steps = max(cfg.steps, 1)
    last_endpoint = state

    for step in range(cfg.steps):
        t = model_time_for_step(cfg.model_t_mode, step, cfg.steps, batch, device, dtype=torch.float32)
        logits = model(state_for_model(model, state, eps), t, attn).float()
        if cfg.endpoint_temp != 1.0:
            logits = logits / float(cfg.endpoint_temp)
        bias_eos = cfg.eos_logit_bias != 0.0 and eos_id is not None and 0 <= eos_id < logits.size(-1)
        if bias_eos:
            logits[..., eos_id] = logits[..., eos_id] + float(cfg.eos_logit_bias)
        endpoint = F.softmax(logits, dim=-1)
        last_endpoint = endpoint

        rule = cfg.rule
        if rule == "flowmap":
            gamma = flowmap_gamma(step, cfg.steps, cfg.damping, cfg.max_gamma, eps)
            proposal = state + gamma * (endpoint - state)
        elif rule == "replace":
            proposal = (1.0 - cfg.eta) * state + cfg.eta * endpoint
        elif rule == "geometric":
            # Interpolate in log space, then renormalize with a softmax.
            log_state = torch.log(state.clamp_min(eps))
            log_endpoint = torch.log(endpoint.clamp_min(eps))
            log_mix = (1.0 - cfg.eta) * log_state + cfg.eta * log_endpoint
            proposal = F.softmax(log_mix, dim=-1)
        elif rule == "centered_residual":
            # Add a zero-sum probability residual, then project back to simplex.
            residual = endpoint - state
            residual = residual - residual.mean(dim=-1, keepdim=True)
            proposal = state + cfg.eta * residual
        else:
            raise ValueError(f"Unknown decode rule: {cfg.rule}")

        if cfg.noise_mix > 0:
            remaining = 1.0 - (step + 1) / total_steps
            if cfg.noise_decay == "linear":
                lam = cfg.noise_mix * remaining
            elif cfg.noise_decay == "sqrt":
                lam = cfg.noise_mix * math.sqrt(max(0.0, remaining))
            else:
                lam = cfg.noise_mix
            if lam > 0:
                uniform = torch.full_like(proposal, 1.0 / proposal.size(-1))
                proposal = (1.0 - lam) * proposal + lam * uniform

        # Project back onto the simplex, then re-pin the prompt positions.
        proposal = proposal.clamp_min(floor)
        proposal = proposal / proposal.sum(dim=-1, keepdim=True).clamp_min(eps)
        state = torch.where(locked, lock_probs, proposal)

    if cfg.final_from in ("endpoint", "blend"):
        if cfg.final_from == "endpoint":
            out = last_endpoint
        else:
            out = 0.5 * state + 0.5 * last_endpoint
        out = torch.where(locked, lock_probs, out)
        return out / out.sum(dim=-1, keepdim=True).clamp_min(eps)
    return state
@torch.no_grad()
def pseudo_likelihood_scores(
    model,
    tokenizer: BpeTextTokenizer,
    probs: torch.Tensor,
    attn: torch.Tensor,
    lock: torch.Tensor,
    target_prob: float,
    eps: float,
    repeats: int,
    mask_frac: float,
    rerank_t: float,
) -> torch.Tensor:
    """Score each row by the mean log-probability of its own argmax tokens.

    For ``max(1, repeats)`` rounds a random subset of the unlocked, attended
    positions is replaced by noise, the model is evaluated at time
    ``rerank_t``, and the log-probability it assigns to the original argmax
    token at each replaced position is accumulated. The result is the
    per-row sum divided by the number of scored positions (at least 1).
    """
    ids = probs.argmax(dim=-1)
    batch, seq_len = ids.size(0), ids.size(1)
    device = ids.device
    vocab = tokenizer.vocab_size

    clean = smooth_onehot(ids, vocab, target_prob, eps)
    scorable = attn & (~lock)
    locked = lock.unsqueeze(-1)
    totals = torch.zeros(batch, dtype=torch.float32, device=device)
    n_scored = torch.zeros_like(totals)

    for _ in range(max(1, repeats)):
        mask = (torch.rand_like(ids.float()) < mask_frac) & scorable
        for row in range(batch):
            if mask[row].any() or not scorable[row].any():
                continue
            # Nothing was selected in a row that has candidates: force one.
            candidates = torch.nonzero(scorable[row], as_tuple=False).flatten()
            pick = torch.randint(0, candidates.numel(), (1,), device=device)
            mask[row, candidates[pick]] = True

        # NOTE(review): noise_sigma=-1.0 is passed through as-is; its meaning
        # is defined by sample_noise_simplex — confirm there.
        noise = sample_noise_simplex(
            (batch, seq_len),
            vocab,
            device,
            eps,
            noise_mode="logistic_normal",
            target_prob=target_prob,
            noise_sigma=-1.0,
        )
        model_input = torch.where(mask.unsqueeze(-1), noise, clean)
        # Locked positions keep the decoded distribution itself.
        model_input = torch.where(locked, probs, model_input)
        t = torch.full((batch,), float(rerank_t), dtype=torch.float32, device=device)
        logits = model(state_for_model(model, model_input, eps), t, attn).float()
        token_logp = F.log_softmax(logits, dim=-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)

        mask_f = mask.float()
        totals += (token_logp * mask_f).sum(dim=-1)
        n_scored += mask_f.sum(dim=-1)

    return totals / n_scored.clamp_min(1.0)
def default_configs(steps: int, config_set: str) -> list[DecodeConfig]:
    """Return the predefined list of decode configurations named ``config_set``.

    Known sets: ``"focused_flowmap"``, ``"best_flowmap"``,
    ``"final_projection"``, ``"eos_sweep"`` and ``"broad"``. Every
    configuration uses the given ``steps``.

    Raises:
        ValueError: if ``config_set`` is not one of the known names.
    """

    def flow(label: str, *, damping: float = 1.0, max_gamma: float = 1.0, **extra) -> DecodeConfig:
        # All flowmap entries pin damping/max_gamma explicitly.
        return DecodeConfig(label, "flowmap", steps=steps, damping=damping, max_gamma=max_gamma, **extra)

    def other(label: str, rule: str, **extra) -> DecodeConfig:
        return DecodeConfig(label, rule, steps=steps, **extra)

    if config_set == "focused_flowmap":
        return [
            flow("flowmap_t1p00_d1p0"),
            flow("flowmap_t1p10_d1p0", endpoint_temp=1.10),
            flow("flowmap_t1p25_d1p0", endpoint_temp=1.25),
            flow("flowmap_t1p40_d1p0", endpoint_temp=1.40),
            flow("flowmap_t1p60_d1p0", endpoint_temp=1.60),
            flow("flowmap_t1p25_d0p7", damping=0.7, endpoint_temp=1.25),
            flow("flowmap_t1p40_d0p7", damping=0.7, endpoint_temp=1.40),
            flow("flowmap_t1p60_d0p7", damping=0.7, endpoint_temp=1.60),
            flow("flowmap_t1p25_g0p5", max_gamma=0.5, endpoint_temp=1.25),
            flow("flowmap_t1p40_g0p5", max_gamma=0.5, endpoint_temp=1.40),
        ]

    if config_set == "best_flowmap":
        return [
            flow("flowmap_t1p25_d0p7", damping=0.7, endpoint_temp=1.25),
            flow("flowmap_t1p25_d1p0", endpoint_temp=1.25),
            flow("flowmap_t1p35_d1p0", endpoint_temp=1.35),
            flow("flowmap_t1p40_d1p0", endpoint_temp=1.40),
        ]

    if config_set == "final_projection":
        return [
            flow("flowmap_t1p35_state", endpoint_temp=1.35, final_from="state"),
            flow("flowmap_t1p35_endpoint", endpoint_temp=1.35, final_from="endpoint"),
            flow("flowmap_t1p35_blend", endpoint_temp=1.35, final_from="blend"),
            flow("flowmap_t1p40_state", endpoint_temp=1.40, final_from="state"),
            flow("flowmap_t1p40_endpoint", endpoint_temp=1.40, final_from="endpoint"),
            flow("flowmap_t1p40_blend", endpoint_temp=1.40, final_from="blend"),
            flow("flowmap_t1p25_d0p7_state", damping=0.7, endpoint_temp=1.25, final_from="state"),
            flow("flowmap_t1p25_d0p7_endpoint", damping=0.7, endpoint_temp=1.25, final_from="endpoint"),
            flow("flowmap_t1p25_d0p7_blend", damping=0.7, endpoint_temp=1.25, final_from="blend"),
        ]

    if config_set == "eos_sweep":
        return [
            flow("flowmap_t1p35_eos0", endpoint_temp=1.35, eos_logit_bias=0.0),
            flow("flowmap_t1p35_eos-1", endpoint_temp=1.35, eos_logit_bias=-1.0),
            flow("flowmap_t1p35_eos-2", endpoint_temp=1.35, eos_logit_bias=-2.0),
            flow("flowmap_t1p35_eos-3", endpoint_temp=1.35, eos_logit_bias=-3.0),
            flow("flowmap_t1p40_eos-2", endpoint_temp=1.40, eos_logit_bias=-2.0),
            flow("flowmap_t1p25_d0p7_eos-2", damping=0.7, endpoint_temp=1.25, eos_logit_bias=-2.0),
        ]

    if config_set == "broad":
        return [
            flow("flowmap64", final_from="state"),
            flow("flowmap_temp1p25", endpoint_temp=1.25),
            flow("flowmap_temp0p85", endpoint_temp=0.85),
            other("replace_eta0p35", "replace", eta=0.35),
            other("replace_eta0p50", "replace", eta=0.50),
            other("replace_eta0p65", "replace", eta=0.65),
            other("replace_eta0p50_temp1p25", "replace", eta=0.50, endpoint_temp=1.25),
            other("geometric_eta0p25", "geometric", eta=0.25),
            other("geometric_eta0p50", "geometric", eta=0.50),
            other("centered_residual_eta0p20", "centered_residual", eta=0.20),
            other("replace_eta0p50_floor1e6", "replace", eta=0.50, state_floor=1e-6),
            other("replace_eta0p50_leak", "replace", eta=0.50, noise_mix=0.03, noise_decay="sqrt"),
        ]

    raise ValueError(f"Unknown config_set: {config_set}")
def aggregate(rows: list[dict]) -> dict:
    """Average selected metrics over ``rows``.

    Returns a dict mapping ``"mean_<metric>"`` to the arithmetic mean of
    that metric across all rows; an empty ``rows`` yields zeros.
    """
    metric_names = (
        "quality",
        "eos_count",
        "rep3",
        "rep4",
        "distinct1",
        "distinct2",
        "digit_frac",
        "max_word_frac",
    )
    denom = max(len(rows), 1)  # avoid dividing by zero for an empty batch
    means = {}
    for name in metric_names:
        means[f"mean_{name}"] = sum(float(row[name]) for row in rows) / denom
    return means
def main() -> None:
    """CLI entry point: decode every config in a set and write ranked samples.

    For each decode configuration the script builds a fresh initial state
    for all prompts/restarts, decodes it (optionally in chunks of
    ``--decode_batch_size`` rows), computes text metrics per candidate,
    optionally adds a pseudo-likelihood term to the ranking score, and
    appends one summary record plus the top-k samples per prompt to the
    JSONL file given by ``--output``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer_path", required=True)
    parser.add_argument("--max_len", type=int, default=128)
    parser.add_argument("--steps", type=int, default=64)
    parser.add_argument("--restarts", type=int, default=64)
    parser.add_argument("--target_prob", type=float, default=0.99)
    parser.add_argument("--eps", type=float, default=1e-8)
    parser.add_argument("--model_t_mode", choices=["linear", "flow", "const0", "const05", "const1", "random"], default="flow")
    parser.add_argument("--noise_init", choices=["uniform", "logistic_normal", "dirichlet"], default="dirichlet")
    parser.add_argument("--noise_sigma", type=float, default=-1.0)
    parser.add_argument("--dirichlet_init_concentration", type=float, default=1.0)
    parser.add_argument("--prompts", default="|The|In the early morning|Scientists have|The company said|A young woman")
    parser.add_argument("--score_repeats", type=int, default=0)
    parser.add_argument("--score_mask_frac", type=float, default=0.5)
    parser.add_argument("--rerank_t", type=float, default=0.5)
    parser.add_argument("--pl_weight", type=float, default=0.0)
    parser.add_argument("--output", default="runs/decode_lab/latest_decode_lab.jsonl")
    parser.add_argument("--config_set", default="broad", choices=["broad", "focused_flowmap", "best_flowmap", "final_projection", "eos_sweep"])
    parser.add_argument("--decode_batch_size", type=int, default=0)
    parser.add_argument("--topk", type=int, default=5)
    parser.add_argument("--seed", type=int, default=20260428)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # NOTE(security): torch.load may unpickle arbitrary objects from the
    # checkpoint file; only point --checkpoint at files you trust.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device)
    model.eval()

    # Keep the first empty prompt: it is unconditional generation.
    prompts = args.prompts.split("|")
    print(f"[info] device={device} prompts={prompts} restarts={args.restarts} steps={args.steps}")
    print(f"[info] checkpoint={args.checkpoint}")

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    configs = default_configs(args.steps, args.config_set)
    for cfg in configs:
        # The CLI time mode overrides whatever the preset carries.
        cfg.model_t_mode = args.model_t_mode

    # Records are written with ensure_ascii=False, so the file must be opened
    # as UTF-8 explicitly; the platform default encoding may be unable to
    # represent the generated text.
    with out_path.open("w", encoding="utf-8") as f:
        for cfg in configs:
            init, attn, lock, lock_probs, expanded = build_initial_state(
                tokenizer=tokenizer,
                prompts=prompts,
                restarts=args.restarts,
                max_len=args.max_len,
                target_prob=args.target_prob,
                eps=args.eps,
                noise_init=args.noise_init,
                noise_sigma=args.noise_sigma,
                dirichlet_init_concentration=args.dirichlet_init_concentration,
                device=device,
            )

            if args.decode_batch_size > 0 and init.size(0) > args.decode_batch_size:
                # Decode in row chunks and collect the results on the CPU.
                decoded_parts = []
                for start in range(0, init.size(0), args.decode_batch_size):
                    end = min(start + args.decode_batch_size, init.size(0))
                    part = decode_batch(
                        model,
                        init[start:end],
                        attn[start:end],
                        lock[start:end],
                        lock_probs[start:end],
                        cfg,
                        args.eps,
                        tokenizer.eos_id,
                    )
                    decoded_parts.append(part.detach().cpu())
                    print(f"[chunk] {cfg.label} decoded {end}/{init.size(0)}", flush=True)
                decoded = torch.cat(decoded_parts, dim=0)
            else:
                decoded = decode_batch(model, init, attn, lock, lock_probs, cfg, args.eps, tokenizer.eos_id)

            ids = decoded.argmax(dim=-1).detach().cpu().tolist()
            texts = [decode_text(tokenizer, row) for row in ids]
            rows = []
            for i, text in enumerate(texts):
                m = text_metrics(text)
                m.update({"candidate": i, "prompt": expanded[i], "text": text})
                rows.append(m)

            if args.score_repeats > 0:
                # NOTE: scoring runs on the full batch at once, even when
                # decoding was chunked.
                decoded_for_score = decoded.to(device) if decoded.device != device else decoded
                pl = pseudo_likelihood_scores(
                    model,
                    tokenizer,
                    decoded_for_score,
                    attn,
                    lock,
                    args.target_prob,
                    args.eps,
                    repeats=args.score_repeats,
                    mask_frac=args.score_mask_frac,
                    rerank_t=args.rerank_t,
                ).detach().cpu().tolist()
                for row, score in zip(rows, pl):
                    row["pseudo_logp"] = float(score)
                    row["rank_score"] = float(row["quality"] + args.pl_weight * score)
            else:
                for row in rows:
                    row["pseudo_logp"] = None
                    row["rank_score"] = float(row["quality"])

            summary = {"type": "summary", "config": asdict(cfg), "agg": aggregate(rows)}
            f.write(json.dumps(summary, ensure_ascii=False) + "\n")
            print("\n" + "=" * 96)
            print("[config]", cfg.label, asdict(cfg))
            print("[metrics]", json.dumps(summary["agg"], ensure_ascii=False))

            for prompt in prompts:
                subset = [r for r in rows if r["prompt"] == prompt]
                subset.sort(key=lambda r: r["rank_score"], reverse=True)
                for rank, row in enumerate(subset[: args.topk], 1):
                    rec = {"type": "sample", "config": asdict(cfg), "rank": rank, **row}
                    f.write(json.dumps(rec, ensure_ascii=False) + "\n")
                    if rank <= 1:
                        print(f"\n--- best prompt={prompt!r} rank_score={row['rank_score']:.4f} quality={row['quality']:.4f} ---")
                        print(row["text"])

            # Drop the large per-config tensors before the next config
            # allocates its own.
            del init, attn, lock, lock_probs, decoded
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    print(f"[done] wrote {out_path}")


if __name__ == "__main__":
    main()