| |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import re |
| import sys |
| from collections import Counter |
| from pathlib import Path |
| from typing import Sequence |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
# Put the repository root (the parent of this script's directory) at the front
# of sys.path so the project-local imports that follow (`eval`, `flowtext_lab`)
# resolve when this file is run directly. Must execute before those imports.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from eval import build_model_from_ckpt |
| from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model |
| from flowtext_lab.tokenization import BpeTextTokenizer |
|
|
|
|
# Coarse tokenizer used by text_metrics: a run of ASCII letters, a run of
# digits, or any single character that is neither whitespace, an ASCII letter
# nor a digit (punctuation, symbols, non-ASCII letters, ...).
WORD_RE = re.compile(r"[A-Za-z]+|\d+|[^\sA-Za-z\d]")
|
|
|
|
def encode_prompt(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]:
    """Tokenize ``prompt`` without special tokens and keep at most ``max_len`` ids.

    Goes through the wrapped tokenizer object exposed as ``tokenizer.tokenizer``
    and returns a fresh Python list.
    """
    encoding = tokenizer.tokenizer.encode(prompt, add_special_tokens=False)
    token_ids = list(encoding.ids)
    return token_ids[:max_len]
|
|
|
|
def decode_text(tokenizer: BpeTextTokenizer, ids: Sequence[int]) -> str:
    """Decode ``ids`` verbatim: special tokens are kept and EOS does not stop decoding."""
    decode_options = {"stop_at_eos": False, "skip_special_tokens": False}
    return tokenizer.decode(ids, **decode_options)
|
|
|
|
def text_metrics(text: str) -> dict[str, float]:
    """Score ``text`` with cheap surface heuristics (higher ``quality`` is better).

    The score rewards length (saturating at 700 characters) and distinct word
    bigrams. It penalizes repeated token trigrams, dominance of a single word,
    punctuation tokens, digit tokens, ``<|endoftext|>`` markers and U+FFFD
    replacement characters.

    Returns the score together with its component statistics, all as floats.
    ``digit_frac`` feeds the score but is not included in the returned dict.
    """
    tokens = WORD_RE.findall(text)
    alpha_words = [tok.lower() for tok in tokens if re.fullmatch(r"[A-Za-z]+", tok)]
    tok_denom = len(tokens) or 1
    word_denom = len(alpha_words) or 1

    word_counts = Counter(alpha_words)
    if word_counts:
        max_word_frac = word_counts.most_common(1)[0][1] / word_denom
    else:
        # No alphabetic words at all: treat as fully degenerate.
        max_word_frac = 1.0

    trigrams = list(zip(tokens, tokens[1:], tokens[2:]))
    repeated = sum(count - 1 for count in Counter(trigrams).values() if count > 1)
    rep3 = repeated / (len(trigrams) or 1)

    word_pairs = list(zip(alpha_words, alpha_words[1:]))
    distinct2 = len(set(word_pairs)) / len(word_pairs) if word_pairs else 0.0

    punct_frac = sum(1 for tok in tokens if re.fullmatch(r"[,.;:!?]+", tok)) / tok_denom
    digit_frac = sum(1 for tok in tokens if tok.isdigit()) / tok_denom

    eot_count = text.count("<|endoftext|>")
    bad_char_count = text.count("�")

    # Accumulated term by term in the same left-to-right order as a single
    # expression, so the floating-point result is identical.
    quality = min(len(text) / 700.0, 1.0)
    quality += 0.35 * distinct2
    quality -= 2.6 * rep3
    quality -= 1.2 * max_word_frac
    quality -= 0.8 * punct_frac
    quality -= 1.0 * digit_frac
    quality -= 0.2 * eot_count
    quality -= 0.5 * bad_char_count

    return {
        "quality": float(quality),
        "chars": float(len(text)),
        "words": float(len(alpha_words)),
        "rep3": float(rep3),
        "distinct2": float(distinct2),
        "punct_frac": float(punct_frac),
        "max_word_frac": float(max_word_frac),
        "eot_count": float(eot_count),
    }
|
|
|
|
def dirichlet_mean(endpoint: torch.Tensor, support_t: float, eps: float) -> torch.Tensor:
    """Mix ``endpoint`` with the uniform distribution over its last dimension.

    Computes ``(1 - support_t) / V + support_t * endpoint``, where ``V`` is the
    size of the last dimension. Each entry is floored at ``eps`` and the result
    is renormalized to sum to 1 along the last dimension.
    """
    uniform_weight = (1.0 - support_t) / float(endpoint.size(-1))
    blended = (uniform_weight + support_t * endpoint).clamp_min(eps)
    normalizer = blended.sum(dim=-1, keepdim=True).clamp_min(eps)
    return blended / normalizer
|
|
|
|
def total_concentration(support_t: float, c_min: float, c_max: float) -> float:
    """Interpolate the Dirichlet total concentration geometrically.

    Returns ``c_min ** (1 - support_t) * c_max ** support_t``, computed in log
    space. The result is ``c_min`` at ``support_t == 0`` and ``c_max`` at
    ``support_t == 1``.

    Both bounds are floored so the logarithms are always defined and the
    schedule never decreases:

    * ``c_min`` is raised to at least ``1e-8``.
    * ``c_max`` is raised to at least that floored ``c_min``.

    Previously ``c_max`` was compared against the raw ``c_min``. Non-positive
    bounds then raised a math domain error, and a tiny ``c_max`` could make the
    schedule shrink.
    """
    floor = 1e-8
    lo = max(c_min, floor)
    hi = max(c_max, lo)
    log_lo = math.log(lo)
    log_hi = math.log(hi)
    return math.exp(log_lo + support_t * (log_hi - log_lo))
|
|
|
|
def dirichlet_resample(mean: torch.Tensor, support_t: float, c_min: float, c_max: float, eps: float) -> torch.Tensor:
    """Sample a distribution around ``mean`` along the last dimension.

    The concentration vector is ``mean`` scaled by
    ``total_concentration(support_t, c_min, c_max)``. Normalizing independent
    ``Gamma(alpha_i, 1)`` draws over the last dimension gives a Dirichlet(alpha)
    sample, up to the ``eps`` floors applied to alpha, the draws and the sum.
    """
    alpha = mean.mul(total_concentration(support_t, c_min, c_max)).clamp_min(eps)
    # NOTE(review): torch._standard_gamma is a private PyTorch API. It is kept
    # here so the random stream matches existing seeded runs.
    gamma_draws = torch._standard_gamma(alpha).clamp_min(eps)
    normalizer = gamma_draws.sum(dim=-1, keepdim=True).clamp_min(eps)
    return gamma_draws / normalizer
|
|
|
|
def schedule_power(step: int, steps: int, power: float) -> float:
    """Return ``((step + 1) / steps) ** power`` clamped to ``[0, 1]``.

    A ``steps`` value below 1 is treated as 1. The last step of a ``steps``-long
    loop (``step == steps - 1``) therefore maps to 1.0.
    """
    fraction = (step + 1) / max(steps, 1)
    shaped = fraction ** float(power)
    upper_clipped = min(1.0, shaped)
    return float(max(0.0, upper_clipped))
|
|
|
|
def current_anchor(probs: torch.Tensor, mode: str, eps: float) -> torch.Tensor:
    """Build the anchor distribution that ``dual_line_*`` updates blend toward.

    Modes:
        ``"state"``: ``probs`` itself (the same tensor object).
        ``"onehot"``: one-hot of the argmax over the last dimension, with the
            same dtype and device as ``probs``.
        ``"sqrt_state"``: elementwise sqrt of ``probs`` (floored at ``eps``),
            renormalized over the last dimension.

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode == "onehot":
        winners = probs.argmax(dim=-1)
        return F.one_hot(winners, probs.size(-1)).to(dtype=probs.dtype, device=probs.device)
    if mode == "sqrt_state":
        root = probs.clamp_min(eps).sqrt()
        return root / root.sum(dim=-1, keepdim=True).clamp_min(eps)
    if mode != "state":
        raise ValueError(f"unknown anchor mode: {mode}")
    return probs
|
|
|
|
@torch.no_grad()
def build_initial(
    tokenizer: BpeTextTokenizer,
    prompts: list[str],
    restarts: int,
    max_len: int,
    eps: float,
    noise_init: str,
    target_prob: float,
    noise_sigma: float,
    dirichlet_concentration: float,
    lock_bos: bool,
    lock_final_eos: bool,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str]]:
    """Build the noisy starting state and lock mask for a batch of decodes.

    Each prompt is repeated ``restarts`` times, giving one batch row per repeat
    in prompt order. Prompt ids come from ``encode_prompt``. When ``lock_bos``
    is set they are prefixed with ``tokenizer.bos_id``, and they are always
    truncated to ``max_len``.

    Every position starts from ``sample_noise_simplex`` noise. Prompt positions
    are then overwritten with one-hot rows and marked in ``lock``/``lock_probs``.
    With ``lock_final_eos``, the last position of every row is overwritten and
    locked to ``tokenizer.eos_id``. That replaces a prompt token if the prompt
    fills all ``max_len`` slots.

    Returns:
        ``(probs, lock, lock_probs, attn, expanded)``:

        * ``probs``: the initial state, indexed here as (batch, max_len, vocab).
        * ``lock``: a (batch, max_len) bool mask of fixed positions.
        * ``lock_probs``: a (batch, max_len, vocab) float32 tensor holding the
          one-hot values for locked positions.
        * ``attn``: an all-True (batch, max_len) bool mask.
        * ``expanded``: the prompt string for each batch row.
    """
    vocab = tokenizer.vocab_size

    def prefix_for(text: str) -> list[int]:
        # Token ids to lock at the start of a row for this prompt.
        token_ids = encode_prompt(tokenizer, text, max_len)
        if lock_bos:
            token_ids = [tokenizer.bos_id] + token_ids
        return token_ids[:max_len]

    expanded: list[str] = []
    prompt_ids: list[list[int]] = []
    for text in prompts:
        prefix = prefix_for(text)
        expanded.extend([text] * restarts)
        prompt_ids.extend([prefix] * restarts)

    batch = len(prompt_ids)
    probs = sample_noise_simplex(
        (batch, max_len),
        vocab,
        device,
        eps,
        noise_mode=noise_init,
        target_prob=target_prob,
        noise_sigma=noise_sigma,
        dirichlet_concentration=dirichlet_concentration,
    )
    lock = torch.zeros((batch, max_len), dtype=torch.bool, device=device)
    lock_probs = torch.zeros((batch, max_len, vocab), dtype=torch.float32, device=device)

    for row, prefix in enumerate(prompt_ids):
        n_prefix = len(prefix)
        if n_prefix == 0:
            continue
        prefix_t = torch.tensor(prefix, dtype=torch.long, device=device)
        prefix_onehot = F.one_hot(prefix_t, vocab).float()
        probs[row, :n_prefix] = prefix_onehot
        lock_probs[row, :n_prefix] = prefix_onehot
        lock[row, :n_prefix] = True

    if lock_final_eos:
        eos_t = torch.tensor([tokenizer.eos_id], dtype=torch.long, device=device)
        eos_onehot = F.one_hot(eos_t, vocab).float()[0]
        for target in (probs, lock_probs):
            target[:, -1] = eos_onehot
        lock[:, -1] = True

    attn = torch.ones((batch, max_len), dtype=torch.bool, device=device)
    return probs, lock, lock_probs, attn, expanded
|
|
|
|
@torch.no_grad()
def decode_one_config(
    model,
    tokenizer,
    init,
    lock,
    lock_probs,
    attn,
    args,
    update: str,
    final_from: str,
    temp: float,
    model_t_mode: str,
    support_power: float,
    semantic_power: float,
    anchor_mode: str,
):
    """Run one decoding configuration starting from ``init``; return final distributions.

    Each of the ``args.steps`` iterations works as follows:

    1. Query the model for an endpoint distribution (softmax of logits / ``temp``).
    2. For ``dual_line_*`` updates, blend that endpoint with an anchor built
       from the current state by ``current_anchor``.
    3. Map the result to a mean with ``dirichlet_mean``.
    4. Apply ``update``:

       * ``"mean"`` / ``"dual_line_mean"``: take the mean directly.
       * ``"resample"`` / ``"dual_line_resample"``: draw via ``dirichlet_resample``.
       * ``"ema_mean"``: move toward the mean with weight ``1 / (steps - step)``.

    Positions where ``lock`` is True are reset to ``lock_probs`` after every
    step and again in the output.

    ``final_from`` selects the output:

    * ``"state"``: the last state.
    * ``"endpoint"``: the last model endpoint (``init`` when ``args.steps == 0``).
    * ``"blend"``: their average.

    The output is renormalized over the last dimension. The function reads
    ``args.steps``, ``args.eps``, ``args.concentration_min`` and
    ``args.concentration_max``. ``tokenizer`` is unused and is kept for the
    call signature.

    Raises:
        ValueError: for an unknown ``update`` or ``final_from``. Both are
            checked before any model call. Previously an unknown
            ``final_from`` silently fell back to ``"state"``, which mislabelled
            sweep results.
    """
    if update not in {"mean", "resample", "dual_line_mean", "dual_line_resample", "ema_mean"}:
        raise ValueError(update)
    if final_from not in {"state", "endpoint", "blend"}:
        raise ValueError(f"unknown final_from: {final_from}")

    eps = args.eps
    probs = init.clone()
    last_endpoint = probs
    device = probs.device
    for step in range(args.steps):
        model_t = model_time_for_step(model_t_mode, step, args.steps, probs.size(0), device, dtype=torch.float32)
        logits = model(state_for_model(model, probs, eps), model_t, attn).float() / temp
        endpoint = F.softmax(logits, dim=-1)
        last_endpoint = endpoint
        support_t = schedule_power(step, args.steps, support_power)

        if update.startswith("dual_line"):
            # Interpolate from the current-state anchor toward the model endpoint.
            semantic_t = schedule_power(step, args.steps, semantic_power)
            anchor = current_anchor(probs, anchor_mode, eps)
            target = (1.0 - semantic_t) * anchor + semantic_t * endpoint
            target = target / target.sum(dim=-1, keepdim=True).clamp_min(eps)
        else:
            target = endpoint
        mean = dirichlet_mean(target, support_t, eps)

        if update in ("mean", "dual_line_mean"):
            new_probs = mean
        elif update in ("resample", "dual_line_resample"):
            new_probs = dirichlet_resample(mean, support_t, args.concentration_min, args.concentration_max, eps)
        else:  # "ema_mean" (validated above)
            # Weight grows to 1 on the final step, so the last state equals the mean.
            gamma = 1.0 / max(args.steps - step, 1)
            new_probs = (1.0 - gamma) * probs + gamma * mean
            new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True).clamp_min(eps)

        probs = torch.where(lock.unsqueeze(-1), lock_probs, new_probs)

    if final_from == "endpoint":
        out = last_endpoint
    elif final_from == "blend":
        out = 0.5 * probs + 0.5 * last_endpoint
    else:  # "state"
        out = probs
    out = torch.where(lock.unsqueeze(-1), lock_probs, out)
    return out / out.sum(dim=-1, keepdim=True).clamp_min(eps)
|
|
|
|
def main() -> None:
    """Command-line entry point: sweep decoding configurations for one checkpoint.

    The function loads the tokenizer and model and builds one shared initial
    state with ``build_initial``. It then runs ``decode_one_config`` for every
    combination of ``--updates`` x ``--finals`` x ``--temps`` x
    ``--model_t_modes`` x ``--support_powers`` x ``--semantic_powers`` x
    ``--anchor_modes``.

    For each configuration it scores the argmax-decoded texts with
    ``text_metrics``, prints the best sample and appends one JSON line to
    ``--output``. At the end it prints the configuration with the highest
    mean quality.
    """
    import itertools  # local import: only needed to enumerate the sweep grid

    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--tokenizer_path", required=True)
    ap.add_argument("--output", required=True)
    ap.add_argument("--max_len", type=int, default=256)
    ap.add_argument("--steps", type=int, default=256)
    ap.add_argument("--restarts", type=int, default=4)
    ap.add_argument("--prompts", nargs="+", default=[""])
    ap.add_argument("--noise_init", default="dirichlet")
    ap.add_argument("--target_prob", type=float, default=0.99)
    ap.add_argument("--noise_sigma", type=float, default=-1.0)
    ap.add_argument("--dirichlet_init_concentration", type=float, default=1.0)
    ap.add_argument("--concentration_min", type=float, default=1.0)
    ap.add_argument("--concentration_max", type=float, default=1024.0)
    ap.add_argument("--updates", nargs="+", default=["mean", "ema_mean", "resample"])
    ap.add_argument("--finals", nargs="+", default=["state", "endpoint", "blend"])
    ap.add_argument("--temps", nargs="+", type=float, default=[1.0, 1.2, 1.35])
    ap.add_argument("--model_t_modes", nargs="+", default=["flow", "const05"])
    ap.add_argument("--support_powers", nargs="+", type=float, default=[1.0])
    ap.add_argument("--semantic_powers", nargs="+", type=float, default=[1.0])
    ap.add_argument("--anchor_modes", nargs="+", default=["onehot"])
    ap.add_argument("--lock_bos", action="store_true")
    ap.add_argument("--lock_final_eos", action="store_true")
    ap.add_argument("--eps", type=float, default=1e-8)
    ap.add_argument("--seed", type=int, default=1234)
    args = ap.parse_args()

    # Fail fast on values that would otherwise break deep inside the sweep:
    # - restarts < 1 gives an empty batch, so the mean quality divides by
    #   zero and max() gets an empty range;
    # - max_len < 1 leaves no positions to decode;
    # - temp <= 0 divides the logits by zero or flips their sign.
    if args.restarts < 1:
        ap.error("--restarts must be >= 1")
    if args.max_len < 1:
        ap.error("--max_len must be >= 1")
    if any(t <= 0 for t in args.temps):
        ap.error("--temps must all be > 0")

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # NOTE(review): torch.load unpickles the checkpoint; only pass trusted files.
    ckpt = torch.load(args.checkpoint, map_location=device)
    model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device)
    model.eval()
    # Built once; every configuration below starts from this same initial state.
    init, lock, lock_probs, attn, expanded = build_initial(
        tokenizer,
        args.prompts,
        args.restarts,
        args.max_len,
        args.eps,
        args.noise_init,
        args.target_prob,
        args.noise_sigma,
        args.dirichlet_init_concentration,
        args.lock_bos,
        args.lock_final_eos,
        device,
    )

    # Cartesian product in the same order as nested loops (last option varies fastest).
    configs = list(
        itertools.product(
            args.updates,
            args.finals,
            args.temps,
            args.model_t_modes,
            args.support_powers,
            args.semantic_powers,
            args.anchor_modes,
        )
    )

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    rows = []
    # ensure_ascii=False emits raw non-ASCII text, so pin UTF-8 rather than
    # relying on the platform default encoding.
    with out_path.open("w", encoding="utf-8") as f:
        for update, final_from, temp, model_t_mode, support_power, semantic_power, anchor_mode in configs:
            probs = decode_one_config(
                model,
                tokenizer,
                init,
                lock,
                lock_probs,
                attn,
                args,
                update,
                final_from,
                temp,
                model_t_mode,
                support_power,
                semantic_power,
                anchor_mode,
            )
            ids = probs.argmax(dim=-1).detach().cpu().tolist()
            texts = [decode_text(tokenizer, row) for row in ids]
            mets = [text_metrics(t) for t in texts]
            mean_q = sum(m["quality"] for m in mets) / len(mets)
            best_i = max(range(len(texts)), key=lambda i: mets[i]["quality"])
            row = {
                "update": update,
                "final_from": final_from,
                "temp": temp,
                "model_t_mode": model_t_mode,
                "support_power": support_power,
                "semantic_power": semantic_power,
                "anchor_mode": anchor_mode,
                "mean_quality": mean_q,
                "best_prompt": expanded[best_i],
                "best_metrics": mets[best_i],
                "best_text": texts[best_i],
            }
            rows.append(row)
            print(
                "\n====",
                update,
                final_from,
                temp,
                model_t_mode,
                "support_p",
                support_power,
                "semantic_p",
                semantic_power,
                "anchor",
                anchor_mode,
                "mean_q",
                round(mean_q, 4),
                flush=True,
            )
            print(texts[best_i][:1600], flush=True)
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
            f.flush()

    best = max(rows, key=lambda r: r["mean_quality"])
    print("\nBEST", json.dumps({k: best[k] for k in best if k != "best_text"}, ensure_ascii=False, indent=2), flush=True)
    print(best["best_text"], flush=True)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|