# lta / LTA_openwebtext_dualt / scripts / eval_lm1b_linear_simplex_genppl.py
# (Hugging Face file-view residue: uploaded by JinghuiLuAstronaut,
#  commit 0badcf2 "Add files using upload-large-folder tool" (verified), 12.2 kB.)
#!/usr/bin/env python3
"""Algebraic simplex-linear GenPPL eval for endpoint models.
This decoder matches the supervised bridge:
p_t = (1 - t) * p0 + t * x1
Inference keeps the sampled p0 fixed and replaces the unknown x1 with the
model's current endpoint prediction:
p_{t_next} = (1 - t_next) * p0 + t_next * a_theta(p_t, t).
There is no Dirichlet/Gamma resampling in the loop.
"""
from __future__ import annotations
import argparse
import json
import math
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
# The repository root is the parent of this script's directory. Prepend it to
# ``sys.path`` (at most once) so the ``flowtext_lab`` imports resolve when this
# file is executed directly as a script.
REPO_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
from flowtext_lab.decode import sample_noise_simplex, state_for_model
from flowtext_lab.genppl import filter_generated_texts, summarize_token_diversity
from flowtext_lab.tokenization import BpeTextTokenizer
from eval_lm1b_c1024_fullycoupled_sde_genppl import (
build_model,
collect_special_token_ids,
filter_endpoint_probs,
score_with_gpt2,
)
def lerp(a: float, b: float, t: float) -> float:
    """Linearly interpolate from ``a`` (at ``t == 0``) to ``b`` (at ``t == 1``).

    All inputs are coerced to ``float``; ``t`` is not clamped, so values
    outside ``[0, 1]`` extrapolate.
    """
    start = float(a)
    span = float(b) - start
    return start + float(t) * span
def project_endpoint(
    logits: torch.Tensor,
    *,
    temp: float,
    projection: str,
    top_k: int,
    top_p: float,
    banned_ids: list[int],
    gumbel_tau: float,
    gumbel_noise_scale: float,
    eps: float,
) -> torch.Tensor:
    """Turn model logits into an endpoint estimate for the simplex bridge.

    The logits are divided by ``temp`` (floored at ``eps``), softmaxed over the
    last dimension and passed through ``filter_endpoint_probs`` with the given
    top-k / top-p / banned-id settings. ``projection`` then selects the result:

    * ``"soft"``           -- the filtered distribution as-is;
    * ``"argmax"``         -- a one-hot vector on the most probable entry;
    * ``"sample"``         -- a one-hot vector on an entry drawn with
      ``torch.multinomial`` from the filtered distribution;
    * ``"gumbel_softmax"`` -- ``softmax((log p + gumbel_noise_scale * g) / tau)``
      with Gumbel noise ``g`` and ``tau`` floored at ``eps``, then floored at
      ``eps`` and renormalised.

    Raises:
        ValueError: for any other ``projection`` value (raised only after the
            softmax/filter work above has run).
    """
    scaled_logits = logits / max(float(temp), eps)
    probs = filter_endpoint_probs(
        F.softmax(scaled_logits, dim=-1),
        top_k=top_k,
        top_p=top_p,
        banned_ids=banned_ids,
        eps=eps,
    )

    if projection == "soft":
        return probs

    if projection in ("argmax", "sample"):
        if projection == "argmax":
            winners = probs.argmax(dim=-1)
        else:
            # multinomial needs a 2-D (rows, vocab) input; fold leading dims and
            # restore them afterwards.
            flat = probs.reshape(-1, probs.size(-1))
            winners = torch.multinomial(flat, 1).view(*probs.shape[:-1])
        one_hot = torch.zeros_like(probs)
        return one_hot.scatter_(-1, winners.unsqueeze(-1), 1.0)

    if projection == "gumbel_softmax":
        # Clamp away from 0/1 so the double log below stays finite.
        uniform = torch.rand_like(probs).clamp_(min=eps, max=1.0 - eps)
        gumbel = -torch.log(-torch.log(uniform))
        # NOTE(review): entries that filter_endpoint_probs may have zeroed get
        # log(eps) here rather than -inf, so Gumbel noise can (rarely) still
        # favour them -- confirm whether that leak matters for top-k/top-p/bans.
        perturbed = probs.clamp_min(eps).log() + float(gumbel_noise_scale) * gumbel
        relaxed = F.softmax(perturbed / max(float(gumbel_tau), eps), dim=-1).clamp_min(eps)
        return relaxed / relaxed.sum(dim=-1, keepdim=True).clamp_min(eps)

    raise ValueError(f"unknown endpoint_projection: {projection}")
@torch.inference_mode()
def decode_linear_simplex(
    model,
    tokenizer: BpeTextTokenizer,
    *,
    n_samples: int,
    batch_size: int,
    max_len: int,
    steps: int,
    seed: int,
    device: torch.device,
    noise_init: str,
    noise_sigma: float,
    noise_dirichlet_concentration: float,
    endpoint_temp_start: float,
    endpoint_temp_end: float,
    endpoint_projection: str,
    endpoint_top_k: int,
    endpoint_top_p: float,
    ban_special_tokens: bool,
    gumbel_tau_start: float,
    gumbel_tau_end: float,
    gumbel_noise_scale_start: float,
    gumbel_noise_scale_end: float,
    final_from: str,
) -> tuple[list[list[int]], list[str], dict]:
    """Generate sequences with the algebraic linear-simplex decoder.

    Samples are produced in batches of at most ``batch_size`` sequences of
    length ``max_len``. For each batch a noise point ``p0`` is drawn once with
    ``sample_noise_simplex`` and held fixed. At step ``k`` of ``steps`` the
    model is queried at ``t = k / steps`` on the current state ``probs``. Its
    logits are converted to an endpoint estimate by :func:`project_endpoint`,
    and the state is reset to ``(1 - t_next) * p0 + t_next * endpoint``, then
    floored at ``eps`` and renormalised over the last dimension. Endpoint
    temperature, Gumbel tau and Gumbel noise scale are linearly interpolated
    from their ``*_start`` to ``*_end`` values using the current ``t``.

    Final token ids (argmax over the last dimension) come from:

    * ``"blend_0.5"`` -- ``0.5 * probs + 0.5 * last_endpoint``. When
      ``steps >= 1`` the last update has ``t_next == 1``, so this is
      essentially the argmax of the final endpoint.
    * ``"model_t1"`` -- one extra model call at ``t = 1`` on the final state.

    Args:
        model: Callable as ``model(state, t, attn_mask)`` returning logits.
        tokenizer: Provides ``vocab_size`` and ``decode``; also used for
            special-token ids when ``ban_special_tokens`` is set.
        n_samples: Total number of sequences to generate.
        batch_size: Maximum sequences per batch; must be >= 1.
        final_from: ``"blend_0.5"`` or ``"model_t1"``.
        (Remaining keyword arguments configure noise, endpoint projection
        and schedules; they are echoed in the returned cfg.)

    Returns:
        ``(ids, texts, cfg)``: one token-id list per sample, the corresponding
        decoded texts (``stop_at_eos=False, skip_special_tokens=False``), and a
        JSON-serialisable dict describing the decode settings.

    Raises:
        ValueError: if ``batch_size < 1`` (the batching loop could never make
            progress) or ``final_from`` is unknown. Both are checked before
            any sampling work is done.
    """
    # Validate cheap config errors before any (expensive) model calls.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if final_from not in ("blend_0.5", "model_t1"):
        raise ValueError(f"unknown final_from: {final_from}")

    torch.manual_seed(seed)
    eps = 1e-8
    denom = max(steps, 1)
    all_ids: list[list[int]] = []
    all_texts: list[str] = []
    remaining = n_samples
    banned_endpoint_ids = collect_special_token_ids(tokenizer) if ban_special_tokens else []
    while remaining > 0:
        bs = min(batch_size, remaining)
        # Source point of the bridge: sampled once per batch and reused at
        # every step (no resampling inside the loop).
        p0 = sample_noise_simplex(
            (bs, max_len),
            tokenizer.vocab_size,
            device,
            eps,
            noise_mode=noise_init,
            target_prob=1.0,
            noise_sigma=noise_sigma,
            dirichlet_concentration=noise_dirichlet_concentration,
        )
        probs = p0.clone()
        attn = torch.ones((bs, max_len), dtype=torch.bool, device=device)
        last_endpoint = probs  # used by "blend_0.5" if steps == 0
        for step in range(steps):
            cur_t = step / denom
            next_t = (step + 1) / denom
            t = torch.full((bs,), float(cur_t), dtype=torch.float32, device=device)
            logits = model(state_for_model(model, probs, eps), t, attn).float()
            endpoint = project_endpoint(
                logits,
                temp=lerp(endpoint_temp_start, endpoint_temp_end, cur_t),
                projection=endpoint_projection,
                top_k=endpoint_top_k,
                top_p=endpoint_top_p,
                banned_ids=banned_endpoint_ids,
                gumbel_tau=lerp(gumbel_tau_start, gumbel_tau_end, cur_t),
                gumbel_noise_scale=lerp(gumbel_noise_scale_start, gumbel_noise_scale_end, cur_t),
                eps=eps,
            )
            last_endpoint = endpoint
            # Algebraic bridge update: the model's endpoint stands in for x1.
            probs = (1.0 - next_t) * p0 + next_t * endpoint
            probs = probs.clamp_min(eps)
            probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(eps)
        if final_from == "blend_0.5":
            final_probs = 0.5 * probs + 0.5 * last_endpoint
            batch_ids = final_probs.argmax(dim=-1).detach().cpu().tolist()
        else:  # "model_t1" -- the only other value allowed by the check above
            t = torch.ones((bs,), dtype=torch.float32, device=device)
            final_logits = model(state_for_model(model, probs, eps), t, attn).float()
            batch_ids = final_logits.argmax(dim=-1).detach().cpu().tolist()
        all_ids.extend(batch_ids)
        all_texts.extend(
            tokenizer.decode(row, stop_at_eos=False, skip_special_tokens=False) for row in batch_ids
        )
        remaining -= bs
        print(f"[linear] generated {n_samples - remaining}/{n_samples}", flush=True)
    cfg = {
        "decode_rule": "linear_simplex_algebraic",
        "steps": steps,
        # max_len and batch_size both change the sampled output under a fixed
        # seed, so record them for reproducibility.
        "max_len": max_len,
        "batch_size": batch_size,
        "noise_init": noise_init,
        "noise_sigma": noise_sigma,
        "noise_dirichlet_concentration": noise_dirichlet_concentration,
        "endpoint_temp_start": endpoint_temp_start,
        "endpoint_temp_end": endpoint_temp_end,
        "endpoint_projection": endpoint_projection,
        "endpoint_top_k": endpoint_top_k,
        "endpoint_top_p": endpoint_top_p,
        "ban_special_tokens": ban_special_tokens,
        "banned_endpoint_ids": banned_endpoint_ids,
        "gumbel_tau_start": gumbel_tau_start,
        "gumbel_tau_end": gumbel_tau_end,
        "gumbel_noise_scale_start": gumbel_noise_scale_start,
        "gumbel_noise_scale_end": gumbel_noise_scale_end,
        "final_from": final_from,
        "n_samples": n_samples,
        "seed": seed,
    }
    return all_ids, all_texts, cfg
def parse_args() -> argparse.Namespace:
    """Parse the command line for the linear-simplex GenPPL evaluation.

    ``--checkpoint``, ``--tokenizer_path``, ``--scorer`` and ``--out_dir`` are
    mandatory; every other option has a default. Options are registered in a
    fixed order, which also fixes their order in ``--help``.
    """
    parser = argparse.ArgumentParser(description="Linear-simplex algebraic GenPPL eval")

    for required_flag in ("--checkpoint", "--tokenizer_path", "--scorer", "--out_dir"):
        parser.add_argument(required_flag, required=True)

    # Sizes for sampling and scoring.
    for flag, default in (
        ("--n_samples", 128),
        ("--max_len", 128),
        ("--steps", 128),
        ("--batch_size", 16),
        ("--score_batch", 8),
        ("--score_max_length", 1024),
    ):
        parser.add_argument(flag, type=int, default=default)

    # Initial noise on the simplex.
    parser.add_argument(
        "--noise_init",
        choices=["uniform", "logistic_normal", "dirichlet"],
        default="logistic_normal",
    )
    for flag, default in (
        ("--noise_sigma", 3.0),
        ("--noise_dirichlet_concentration", 1.0),
        ("--endpoint_temp_start", 1.45),
        ("--endpoint_temp_end", 0.8),
    ):
        parser.add_argument(flag, type=float, default=default)

    # Endpoint projection and filtering.
    parser.add_argument(
        "--endpoint_projection",
        choices=["soft", "sample", "argmax", "gumbel_softmax"],
        default="soft",
    )
    parser.add_argument("--endpoint_top_k", type=int, default=0)
    parser.add_argument("--endpoint_top_p", type=float, default=1.0)
    parser.add_argument("--ban_special_tokens", action="store_true")

    # Gumbel-softmax schedules (start -> end over t).
    for flag, default in (
        ("--gumbel_tau_start", 1.0),
        ("--gumbel_tau_end", 0.2),
        ("--gumbel_noise_scale_start", 1.0),
        ("--gumbel_noise_scale_end", 0.0),
    ):
        parser.add_argument(flag, type=float, default=default)

    parser.add_argument("--final_from", choices=["blend_0.5", "model_t1"], default="model_t1")
    parser.add_argument("--seed", type=int, default=20260524)
    return parser.parse_args()
@torch.no_grad()
def main() -> None:
    """Decode samples from one checkpoint, score them, and write a JSONL report.

    1. Parse CLI args and create ``--out_dir``. This happens first so that an
       unwritable output location fails before any expensive work.
    2. Load the checkpoint on CPU, build the model with ``build_model`` and
       decode ``--n_samples`` sequences with :func:`decode_linear_simplex`.
    3. Free the generator. Then score two text variants with the ``--scorer``
       causal LM via ``score_with_gpt2``:
       - the raw decoded texts;
       - copies with ``[CLS]``/``[SEP]``/``[PAD]``/``<|endoftext|>`` replaced
         by spaces and whitespace collapsed.
       Both lists are filtered with ``filter_generated_texts`` first.
    4. Write ``linear_steps{steps}_samples{n}_scored.jsonl`` into ``--out_dir``:
       one ``summary`` record followed by one ``sample`` record per text.
    """
    import re

    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the output directory up front so a bad --out_dir fails fast.
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"[load] {args.checkpoint}", flush=True)
    # NOTE: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code -- only point --checkpoint at trusted files.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    step = ckpt.get("step")
    print(f"[ckpt] step={step}", flush=True)
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    model = build_model(ckpt, tokenizer, device)
    # Only `step` is needed from here on. Drop our reference to the CPU
    # checkpoint so it can be reclaimed before the scorer is loaded (assumes
    # build_model does not keep its own reference -- TODO confirm).
    del ckpt
    ids, texts, decode_cfg = decode_linear_simplex(
        model,
        tokenizer,
        n_samples=args.n_samples,
        batch_size=args.batch_size,
        max_len=args.max_len,
        steps=args.steps,
        seed=args.seed,
        device=device,
        noise_init=args.noise_init,
        noise_sigma=args.noise_sigma,
        noise_dirichlet_concentration=args.noise_dirichlet_concentration,
        endpoint_temp_start=args.endpoint_temp_start,
        endpoint_temp_end=args.endpoint_temp_end,
        endpoint_projection=args.endpoint_projection,
        endpoint_top_k=args.endpoint_top_k,
        endpoint_top_p=args.endpoint_top_p,
        ban_special_tokens=args.ban_special_tokens,
        gumbel_tau_start=args.gumbel_tau_start,
        gumbel_tau_end=args.gumbel_tau_end,
        gumbel_noise_scale_start=args.gumbel_noise_scale_start,
        gumbel_noise_scale_end=args.gumbel_noise_scale_end,
        final_from=args.final_from,
    )
    # Release the generator before loading the scorer model.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Built once, not per text.
    special_markers = ("[CLS]", "[SEP]", "[PAD]", "<|endoftext|>")
    whitespace_run = re.compile(r"\s+")

    def strip_special(text: str) -> str:
        """Replace special-token markers with spaces and collapse whitespace."""
        for marker in special_markers:
            text = text.replace(marker, " ")
        return whitespace_run.sub(" ", text).strip()

    stripped = [strip_special(t) for t in texts]
    kept_raw, _ = filter_generated_texts(texts, min_chars=1, normalize_whitespace=False, drop_empty=True)
    kept_stripped, _ = filter_generated_texts(stripped, min_chars=1, normalize_whitespace=True, drop_empty=True)
    diversity = summarize_token_diversity(ids).__dict__

    print(f"[score] loading scorer: {args.scorer}", flush=True)
    scorer_tok = AutoTokenizer.from_pretrained(args.scorer)
    # Batched scoring needs a pad token; fall back to EOS when the scorer has none.
    if scorer_tok.pad_token_id is None:
        scorer_tok.pad_token = scorer_tok.eos_token
        scorer_tok.pad_token_id = scorer_tok.eos_token_id
    scorer = AutoModelForCausalLM.from_pretrained(args.scorer).to(device).eval()
    if getattr(scorer.config, "pad_token_id", None) is None:
        scorer.config.pad_token_id = scorer_tok.pad_token_id
    raw_ppl = score_with_gpt2(
        kept_raw, scorer, scorer_tok,
        batch_size=args.score_batch, max_length=args.score_max_length, device=device,
    )
    stripped_ppl = score_with_gpt2(
        kept_stripped, scorer, scorer_tok,
        batch_size=args.score_batch, max_length=args.score_max_length, device=device,
    )
    summary = {
        "type": "summary",
        "checkpoint": args.checkpoint,
        "step": step,
        "decode": decode_cfg,
        "raw_genppl": raw_ppl,
        "stripped_genppl": stripped_ppl,
        "diversity": diversity,
    }
    # NOTE: the file name only encodes steps and n_samples, so runs that differ
    # in other settings (seed, projection, ...) overwrite each other in out_dir.
    out_jsonl = out_dir / f"linear_steps{args.steps}_samples{args.n_samples}_scored.jsonl"
    with out_jsonl.open("w", encoding="utf-8") as f:
        f.write(json.dumps(summary, ensure_ascii=False) + "\n")
        for i, (raw, clean) in enumerate(zip(texts, stripped)):
            f.write(json.dumps({"type": "sample", "index": i, "raw_text": raw, "stripped_text": clean}, ensure_ascii=False) + "\n")
    print("[summary]", json.dumps(summary, ensure_ascii=False, indent=2), flush=True)
    print(f"[done] {out_jsonl}", flush=True)
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script,
    # not when it is imported as a module.
    main()