lta / LTA_openwebtext_dualt /scripts /dirichlet_support_decode_probe.py
JinghuiLuAstronaut's picture
Add files using upload-large-folder tool
6bc5b2c verified
Raw
History Blame Contribute Delete
13.2 kB
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import math
import re
import sys
from collections import Counter
from pathlib import Path
from typing import Sequence
import torch
import torch.nn.functional as F
# Root of the repository: the parent of the directory that contains this file.
# It is put at the front of sys.path so the project-local imports that follow
# (``eval``, ``flowtext_lab``) resolve when this file is executed directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from eval import build_model_from_ckpt
from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model
from flowtext_lab.tokenization import BpeTextTokenizer
# Scoring tokenizer used by text_metrics. It matches, in order of preference:
#   - runs of ASCII letters,
#   - runs of digits (``\d`` also matches Unicode decimal digits),
#   - any single other non-whitespace character (punctuation, symbols, and
#     non-ASCII letters each become a one-character token).
WORD_RE = re.compile(r"[A-Za-z]+|\d+|[^\sA-Za-z\d]")
def encode_prompt(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]:
    """Encode *prompt* without special tokens and keep at most the first *max_len* ids.

    Uses the wrapped ``tokenizer.tokenizer.encode(..., add_special_tokens=False)`` and reads
    ``.ids`` from its result (assumed to be a HuggingFace ``tokenizers`` Encoding — confirm).
    """
    encoding = tokenizer.tokenizer.encode(prompt, add_special_tokens=False)
    token_ids = list(encoding.ids)
    return token_ids[:max_len]
def decode_text(tokenizer: BpeTextTokenizer, ids: Sequence[int]) -> str:
    """Decode *ids* to text, passing ``stop_at_eos=False`` and ``skip_special_tokens=False``.

    The exact meaning of those flags is defined by ``BpeTextTokenizer.decode``; by their
    names, decoding is expected to run past EOS and keep special-token text.
    """
    decode_options = {"stop_at_eos": False, "skip_special_tokens": False}
    return tokenizer.decode(ids, **decode_options)
def text_metrics(text: str) -> dict[str, float]:
    """Compute a cheap heuristic quality score for *text* plus the components behind it.

    Tokens come from ``WORD_RE``. "Words" are the purely ASCII-alphabetic tokens, lower-cased.
    The score rewards:

    * length (saturating at 700 characters);
    * the fraction of distinct word bigrams.

    It penalises:

    * repeated token trigrams;
    * the share of the most frequent word (1.0 when there are no words);
    * punctuation and digit token fractions;
    * literal ``<|endoftext|>`` markers;
    * U+FFFD replacement characters.

    Returns:
        Floats under the keys quality, chars, words, rep3, distinct2, punct_frac,
        max_word_frac and eot_count (digit_frac feeds the score but is not returned).
    """
    tokens = WORD_RE.findall(text)
    words = [tok.lower() for tok in tokens if re.fullmatch(r"[A-Za-z]+", tok)]
    tok_denom = len(tokens) or 1
    word_denom = len(words) or 1

    counts = Counter(words)
    if counts:
        max_word_frac = max(counts.values()) / word_denom
    else:
        max_word_frac = 1.0

    trigrams = list(zip(tokens, tokens[1:], tokens[2:]))
    # sum(count - 1 for repeated trigrams) == total trigrams - distinct trigrams.
    rep3 = (len(trigrams) - len(set(trigrams))) / max(len(trigrams), 1)

    word_pairs = list(zip(words, words[1:]))
    distinct2 = len(set(word_pairs)) / len(word_pairs) if word_pairs else 0.0

    punct_frac = sum(1 for tok in tokens if re.fullmatch(r"[,.;:!?]+", tok)) / tok_denom
    digit_frac = sum(1 for tok in tokens if tok.isdigit()) / tok_denom
    eot_count = text.count("<|endoftext|>")
    replacement_count = text.count("�")

    # Accumulate in the same left-to-right order as a single chained expression.
    quality = min(len(text) / 700.0, 1.0)
    quality += 0.35 * distinct2
    quality -= 2.6 * rep3
    quality -= 1.2 * max_word_frac
    quality -= 0.8 * punct_frac
    quality -= 1.0 * digit_frac
    quality -= 0.2 * eot_count
    quality -= 0.5 * replacement_count

    return {
        "quality": float(quality),
        "chars": float(len(text)),
        "words": float(len(words)),
        "rep3": float(rep3),
        "distinct2": float(distinct2),
        "punct_frac": float(punct_frac),
        "max_word_frac": float(max_word_frac),
        "eot_count": float(eot_count),
    }
def dirichlet_mean(endpoint: torch.Tensor, support_t: float, eps: float) -> torch.Tensor:
    """Blend *endpoint* toward the uniform distribution over its last axis.

    Computes ``(1 - support_t) / V + support_t * endpoint`` with ``V = endpoint.size(-1)``,
    floors every entry at *eps*, and rescales each last-axis slice to sum to one
    (the normaliser is also floored at *eps*).
    """
    uniform_mass = (1.0 - support_t) / float(endpoint.size(-1))
    blended = (uniform_mass + support_t * endpoint).clamp_min(eps)
    totals = blended.sum(dim=-1, keepdim=True).clamp_min(eps)
    return blended / totals
def total_concentration(support_t: float, c_min: float, c_max: float) -> float:
    """Interpolate the Dirichlet total concentration geometrically between two bounds.

    Returns ``lo * (hi / lo) ** support_t`` (computed in log space), where
    ``lo = max(c_min, 1e-8)`` and ``hi = max(c_max, lo)``. So ``support_t = 0`` gives *lo*,
    ``support_t = 1`` gives *hi*, and the result never decreases as support_t grows.
    *support_t* is not clamped.
    """
    floor = 1e-8
    lo = max(c_min, floor)
    # Floor the upper bound against the already-floored lower bound so both logs are
    # defined even when c_min and c_max are <= 0 (math.log would raise ValueError).
    hi = max(c_max, lo)
    log_lo = math.log(lo)
    log_hi = math.log(hi)
    return math.exp(log_lo + support_t * (log_hi - log_lo))
def dirichlet_resample(mean: torch.Tensor, support_t: float, c_min: float, c_max: float, eps: float) -> torch.Tensor:
    """Draw one Dirichlet sample per last-axis slice whose expected value is *mean*.

    The total concentration is ``total_concentration(support_t, c_min, c_max)``, so the
    sample is spread less around *mean* as support_t grows. Sampling normalises
    standard-gamma draws (via the private ``torch._standard_gamma`` op). The alphas, the
    draws and the normaliser are each floored at *eps* to avoid zeros.
    """
    concentration = total_concentration(support_t, c_min, c_max)
    # Dirichlet(alpha) == independent Gamma(alpha_i, 1) draws normalised to sum to one.
    gammas = torch._standard_gamma((concentration * mean).clamp_min(eps)).clamp_min(eps)
    return gammas / gammas.sum(dim=-1, keepdim=True).clamp_min(eps)
def schedule_power(step: int, steps: int, power: float) -> float:
    """Map a 0-based *step* out of *steps* to a shaped progress value in [0, 1].

    Progress is ``(step + 1) / max(steps, 1)`` (so the last step reaches 1.0). It is
    raised to *power* and then clamped to [0, 1].
    """
    lower, upper = 0.0, 1.0
    fraction_done = (step + 1) / max(steps, 1)
    shaped = fraction_done ** float(power)
    return float(max(lower, min(upper, shaped)))
def current_anchor(probs: torch.Tensor, mode: str, eps: float) -> torch.Tensor:
    """Derive the anchor distribution for a dual-line update from the current state.

    Modes:
      * ``"state"``: *probs* itself (the same tensor object, not a copy);
      * ``"onehot"``: one-hot of the last-axis argmax, with probs' dtype and device;
      * ``"sqrt_state"``: element-wise sqrt of *probs* (floored at *eps*), renormalised
        over the last axis, which flattens the distribution.

    Raises:
        ValueError: for any other *mode*.
    """
    if mode == "sqrt_state":
        root = probs.clamp_min(eps).sqrt()
        return root / root.sum(dim=-1, keepdim=True).clamp_min(eps)
    if mode == "onehot":
        winners = probs.argmax(dim=-1)
        return F.one_hot(winners, probs.size(-1)).to(dtype=probs.dtype, device=probs.device)
    if mode != "state":
        raise ValueError(f"unknown anchor mode: {mode}")
    return probs
@torch.no_grad()
def build_initial(
    tokenizer: BpeTextTokenizer,
    prompts: list[str],
    restarts: int,
    max_len: int,
    eps: float,
    noise_init: str,
    target_prob: float,
    noise_sigma: float,
    dirichlet_concentration: float,
    lock_bos: bool,
    lock_final_eos: bool,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str]]:
    """Build the shared starting batch for decoding.

    Steps:

    1. Every prompt is repeated *restarts* times.
    2. Every row starts from noise drawn by ``sample_noise_simplex``.
    3. The encoded prompt prefix (optionally preceded by BOS, then trimmed to *max_len*)
       is overwritten with one-hot rows and locked.
    4. With *lock_final_eos*, the last position of every row is overwritten with and
       locked to EOS. This also overrides a prompt that reaches that position.

    Returns:
        probs: initial state. Presumably (batch, max_len, vocab); its shape comes from
            sample_noise_simplex.
        lock: (batch, max_len) bool mask of positions that must stay fixed.
        lock_probs: (batch, max_len, vocab) float32 values for locked positions,
            zeros elsewhere.
        attn: (batch, max_len) all-True mask.
        expanded: the source prompt string of each batch row.
    """
    expanded: list[str] = []
    prompt_ids: list[list[int]] = []
    for prompt in prompts:
        prefix = encode_prompt(tokenizer, prompt, max_len)
        if lock_bos:
            prefix = [tokenizer.bos_id] + prefix
        # Prepending BOS can push the prefix one past max_len, so trim afterwards.
        prefix = prefix[:max_len]
        expanded.extend([prompt] * restarts)
        prompt_ids.extend([prefix] * restarts)

    batch = len(prompt_ids)
    vocab = tokenizer.vocab_size
    probs = sample_noise_simplex(
        (batch, max_len),
        vocab,
        device,
        eps,
        noise_mode=noise_init,
        target_prob=target_prob,
        noise_sigma=noise_sigma,
        dirichlet_concentration=dirichlet_concentration,
    )
    lock = torch.zeros((batch, max_len), dtype=torch.bool, device=device)
    lock_probs = torch.zeros((batch, max_len, vocab), dtype=torch.float32, device=device)

    for row, prefix in enumerate(prompt_ids):
        width = len(prefix)
        if width == 0:
            continue
        prefix_onehot = F.one_hot(torch.tensor(prefix, dtype=torch.long, device=device), vocab).float()
        probs[row, :width] = prefix_onehot
        lock_probs[row, :width] = prefix_onehot
        lock[row, :width] = True

    if lock_final_eos:
        eos_index = torch.tensor([tokenizer.eos_id], dtype=torch.long, device=device)
        eos_onehot = F.one_hot(eos_index, vocab).float()[0]
        probs[:, -1] = eos_onehot
        lock_probs[:, -1] = eos_onehot
        lock[:, -1] = True

    attn = torch.ones((batch, max_len), dtype=torch.bool, device=device)
    return probs, lock, lock_probs, attn, expanded
@torch.no_grad()
def decode_one_config(
    model,
    tokenizer,
    init,
    lock,
    lock_probs,
    attn,
    args,
    update: str,
    final_from: str,
    temp: float,
    model_t_mode: str,
    support_power: float,
    semantic_power: float,
    anchor_mode: str,
):
    """Decode one configuration for ``args.steps`` steps and return per-position distributions.

    Each step calls ``model(state_for_model(model, probs, eps), model_t, attn)`` and divides
    the logits by *temp*. Their softmax is the ``endpoint``. The state then moves toward it:

    * ``dual_line_*`` updates first mix the endpoint with ``current_anchor(probs, anchor_mode)``,
      with weight ``semantic_t`` on the endpoint, and renormalise.
    * The (possibly mixed) endpoint goes through ``dirichlet_mean``, with weight
      ``support_t`` on it and the rest on uniform.
    * The update then uses that mean as follows:
      - ``mean`` / ``dual_line_mean`` take the mean directly;
      - ``resample`` / ``dual_line_resample`` draw a Dirichlet sample around it;
      - ``ema_mean`` moves a fraction ``1 / (steps - step)`` of the way from the current
        state to it.

    Locked positions are reset to *lock_probs* after every step and in the output.
    *final_from* picks the output:

    * ``"state"``: final state;
    * ``"endpoint"``: last model prediction, or the initial state when ``steps == 0``;
    * ``"blend"``: the average of the two.

    *tokenizer* is accepted for interface compatibility but is not used.

    Raises:
        ValueError: for an unknown *update* or *final_from*. Both are checked before any
            model call, so a typo fails immediately instead of mislabeling results.
    """
    valid_updates = ("mean", "resample", "dual_line_mean", "dual_line_resample", "ema_mean")
    if update not in valid_updates:
        raise ValueError(update)
    if final_from not in ("state", "endpoint", "blend"):
        raise ValueError(f"unknown final_from: {final_from}")

    mix_with_anchor = update.startswith("dual_line")
    probs = init.clone()
    last_endpoint = probs
    device = probs.device
    for step in range(args.steps):
        model_t = model_time_for_step(model_t_mode, step, args.steps, probs.size(0), device, dtype=torch.float32)
        logits = model(state_for_model(model, probs, args.eps), model_t, attn).float() / temp
        endpoint = F.softmax(logits, dim=-1)
        last_endpoint = endpoint
        support_t = schedule_power(step, args.steps, support_power)
        semantic_t = schedule_power(step, args.steps, semantic_power)
        if mix_with_anchor:
            anchor = current_anchor(probs, anchor_mode, args.eps)
            forward_endpoint = (1.0 - semantic_t) * anchor + semantic_t * endpoint
            forward_endpoint = forward_endpoint / forward_endpoint.sum(dim=-1, keepdim=True).clamp_min(args.eps)
        else:
            forward_endpoint = endpoint
        mean = dirichlet_mean(forward_endpoint, support_t, args.eps)
        if update in ("resample", "dual_line_resample"):
            new_probs = dirichlet_resample(mean, support_t, args.concentration_min, args.concentration_max, args.eps)
        elif update == "ema_mean":
            # Step size 1 / (remaining steps); on the final step gamma == 1, so the state
            # becomes the (renormalised) mean.
            gamma = 1.0 / max(args.steps - step, 1)
            new_probs = (1.0 - gamma) * probs + gamma * mean
            new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True).clamp_min(args.eps)
        else:  # "mean" or "dual_line_mean"
            new_probs = mean
        probs = torch.where(lock.unsqueeze(-1), lock_probs, new_probs)

    if final_from == "endpoint":
        out = last_endpoint
    elif final_from == "blend":
        out = 0.5 * probs + 0.5 * last_endpoint
    else:  # "state"
        out = probs
    out = torch.where(lock.unsqueeze(-1), lock_probs, out)
    return out / out.sum(dim=-1, keepdim=True).clamp_min(args.eps)
def main() -> None:
    """Command-line entry point: sweep decoding configurations and log sample quality.

    The function:

    1. Loads the tokenizer and checkpoint.
    2. Builds one shared initial batch (each prompt repeated ``--restarts`` times).
    3. Decodes that batch once for every combination of ``--updates``, ``--finals``,
       ``--temps``, ``--model_t_modes``, ``--support_powers``, ``--semantic_powers`` and
       ``--anchor_modes``.
    4. For each configuration, prints its mean heuristic quality and best sample and
       appends a JSON line to ``--output`` (UTF-8).
    5. Prints the configuration with the highest mean quality at the end.
    """
    import itertools

    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--tokenizer_path", required=True)
    ap.add_argument("--output", required=True)
    ap.add_argument("--max_len", type=int, default=256)
    ap.add_argument("--steps", type=int, default=256)
    ap.add_argument("--restarts", type=int, default=4)
    ap.add_argument("--prompts", nargs="+", default=[""])
    ap.add_argument("--noise_init", default="dirichlet")
    ap.add_argument("--target_prob", type=float, default=0.99)
    ap.add_argument("--noise_sigma", type=float, default=-1.0)
    ap.add_argument("--dirichlet_init_concentration", type=float, default=1.0)
    ap.add_argument("--concentration_min", type=float, default=1.0)
    ap.add_argument("--concentration_max", type=float, default=1024.0)
    ap.add_argument("--updates", nargs="+", default=["mean", "ema_mean", "resample"])
    ap.add_argument("--finals", nargs="+", default=["state", "endpoint", "blend"])
    ap.add_argument("--temps", nargs="+", type=float, default=[1.0, 1.2, 1.35])
    ap.add_argument("--model_t_modes", nargs="+", default=["flow", "const05"])
    ap.add_argument("--support_powers", nargs="+", type=float, default=[1.0])
    ap.add_argument("--semantic_powers", nargs="+", type=float, default=[1.0])
    ap.add_argument("--anchor_modes", nargs="+", default=["onehot"])
    ap.add_argument("--lock_bos", action="store_true")
    ap.add_argument("--lock_final_eos", action="store_true")
    ap.add_argument("--eps", type=float, default=1e-8)
    ap.add_argument("--seed", type=int, default=1234)
    args = ap.parse_args()
    # Reject inputs that would otherwise fail late or silently:
    # - restarts < 1 gives an empty batch, so mean_q divides by zero;
    # - a zero temperature turns the logits into inf and the softmax into NaN.
    if args.restarts < 1:
        ap.error("--restarts must be >= 1")
    if any(t == 0.0 for t in args.temps):
        ap.error("--temps must be non-zero")

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    ckpt = torch.load(args.checkpoint, map_location=device)
    model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device)
    model.eval()
    init, lock, lock_probs, attn, expanded = build_initial(
        tokenizer,
        args.prompts,
        args.restarts,
        args.max_len,
        args.eps,
        args.noise_init,
        args.target_prob,
        args.noise_sigma,
        args.dirichlet_init_concentration,
        args.lock_bos,
        args.lock_final_eos,
        device,
    )
    # Same order as the equivalent nested loops: --anchor_modes varies fastest.
    configs = list(
        itertools.product(
            args.updates,
            args.finals,
            args.temps,
            args.model_t_modes,
            args.support_powers,
            args.semantic_powers,
            args.anchor_modes,
        )
    )
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    rows = []
    # Explicit UTF-8: rows are dumped with ensure_ascii=False and the decoded text may hold
    # arbitrary Unicode (e.g. U+FFFD), which the locale default encoding may not support.
    with out_path.open("w", encoding="utf-8") as f:
        for update, final_from, temp, model_t_mode, support_power, semantic_power, anchor_mode in configs:
            probs = decode_one_config(
                model,
                tokenizer,
                init,
                lock,
                lock_probs,
                attn,
                args,
                update,
                final_from,
                temp,
                model_t_mode,
                support_power,
                semantic_power,
                anchor_mode,
            )
            ids = probs.argmax(dim=-1).detach().cpu().tolist()
            texts = [decode_text(tokenizer, row) for row in ids]
            mets = [text_metrics(t) for t in texts]
            mean_q = sum(m["quality"] for m in mets) / len(mets)
            best_i = max(range(len(texts)), key=lambda i: mets[i]["quality"])
            row = {
                "update": update,
                "final_from": final_from,
                "temp": temp,
                "model_t_mode": model_t_mode,
                "support_power": support_power,
                "semantic_power": semantic_power,
                "anchor_mode": anchor_mode,
                "mean_quality": mean_q,
                "best_prompt": expanded[best_i],
                "best_metrics": mets[best_i],
                "best_text": texts[best_i],
            }
            rows.append(row)
            print(
                "\n====",
                update,
                final_from,
                temp,
                model_t_mode,
                "support_p",
                support_power,
                "semantic_p",
                semantic_power,
                "anchor",
                anchor_mode,
                "mean_q",
                round(mean_q, 4),
                flush=True,
            )
            print(texts[best_i][:1600], flush=True)
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
            f.flush()
    best = max(rows, key=lambda r: r["mean_quality"])
    print("\nBEST", json.dumps({k: best[k] for k in best if k != "best_text"}, ensure_ascii=False, indent=2), flush=True)
    print(best["best_text"], flush=True)


if __name__ == "__main__":
    main()