| |
| """Decode-sweep lab for FlowText OpenWebText checkpoints. |
| |
| The goal is to debug inference without touching training. We try several |
| simplex-valid update rules, generate many candidates, and rank them with |
| anti-collapse diagnostics instead of pure self-likelihood. |
| |
| Run from the flowtext_standard_bench repository root. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import re |
| import sys |
| from collections import Counter |
| from dataclasses import dataclass, asdict |
| from pathlib import Path |
| from typing import Iterable, List, Sequence |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
# Make the directory one level above this script's folder importable, so the
# project-local imports that follow resolve when the file is run as a script.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from eval import build_model_from_ckpt |
| from flowtext_lab.bridges import smooth_onehot |
| from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model |
| from flowtext_lab.tokenization import BpeTextTokenizer |
|
|
|
|
# Tokenizer pattern for the text metrics: a run of ASCII letters, a run of
# digits, or any single character that is not whitespace, an ASCII letter,
# or a digit.
WORD_RE = re.compile(r"[A-Za-z]+|\d+|[^\sA-Za-z\d]")
|
|
|
|
@dataclass
class DecodeConfig:
    """Settings for one decode run, as consumed by ``decode_batch``."""

    # Name used to identify this configuration in logs and output records.
    label: str
    # Update rule: "flowmap", "replace", "geometric" or "centered_residual";
    # decode_batch raises ValueError for anything else.
    rule: str
    # Number of refinement iterations.
    steps: int = 64
    # Time-schedule mode forwarded to model_time_for_step.
    model_t_mode: str = "flow"
    # Mixing weight toward the model endpoint for the non-flowmap rules.
    eta: float = 0.5
    # Multiplier applied to the flowmap step size.
    damping: float = 1.0
    # Upper bound on the flowmap step size; a value <= 0 disables the cap.
    max_gamma: float = 1.0
    # Logits are divided by this temperature before the softmax.
    endpoint_temp: float = 1.0
    # Lower clamp applied to the state after each update (never below eps).
    state_floor: float = 1e-8
    # What decode_batch returns: "endpoint" (last model prediction), "blend"
    # (equal mix of state and last prediction); any other value returns the state.
    final_from: str = "state"
    # Weight of the uniform distribution mixed into the state; 0 disables it.
    noise_mix: float = 0.0
    # Decay of noise_mix over steps: "linear" or "sqrt"; any other value keeps
    # the weight constant.
    noise_decay: str = "linear"
    # Constant added to the EOS logit at every step (0 leaves logits untouched).
    eos_logit_bias: float = 0.0
|
|
|
|
def tokenize_for_metrics(text: str) -> list[str]:
    """Split *text* into the tokens matched by ``WORD_RE``, in order."""
    return [match.group(0) for match in WORD_RE.finditer(text)]
|
|
|
|
def repeated_ngram_frac(tokens: Sequence[str], n: int) -> float:
    """Return the fraction of n-grams that repeat an earlier n-gram.

    Every occurrence of an n-gram beyond its first counts as one repeat; the
    total is divided by the number of n-grams. Sequences shorter than ``n``
    yield 0.0.
    """
    if len(tokens) < n:
        return 0.0
    # Zipping n progressively shifted views yields the sliding windows.
    shifted_views = [tokens[offset:] for offset in range(n)]
    windows = list(zip(*shifted_views))
    # Occurrences beyond the first of each n-gram == total minus distinct.
    repeats = len(windows) - len(set(windows))
    return repeats / max(len(windows), 1)
|
|
|
|
def text_metrics(text: str) -> dict:
    """Compute anti-collapse diagnostics and a heuristic quality score.

    The score rewards length (saturating at 700 characters) and word-level
    diversity, and penalises end-of-text markers, repeated 3-/4-grams,
    digit and punctuation tokens, dominance of a single word, and
    replacement characters.

    Returns:
        A dict with the scalar ``quality`` plus every raw statistic that
        feeds into it.
    """
    tokens = tokenize_for_metrics(text)
    lowered_tokens = [tok.lower() for tok in tokens]
    alpha_words = [tok.lower() for tok in tokens if re.fullmatch(r"[A-Za-z]+", tok)]
    token_denom = max(len(tokens), 1)
    word_denom = max(len(alpha_words), 1)

    if alpha_words:
        top_count = Counter(alpha_words).most_common(1)[0][1]
        max_word_frac = top_count / word_denom
        distinct1 = len(set(alpha_words)) / word_denom
    else:
        # No alphabetic words at all is scored as worst-case dominance.
        max_word_frac = 1.0
        distinct1 = 0.0

    word_pairs = list(zip(alpha_words, alpha_words[1:]))
    distinct2 = len(set(word_pairs)) / len(word_pairs) if word_pairs else 0.0

    digit_frac = sum(1 for tok in tokens if tok.isdigit()) / token_denom
    punct_hits = sum(1 for tok in tokens if re.fullmatch(r"[,.;:!?]+", tok) is not None)
    punct_frac = punct_hits / token_denom

    eos_count = text.count("<|endoftext|>")
    bad_char_count = text.count("�")
    rep3 = repeated_ngram_frac(lowered_tokens, 3)
    rep4 = repeated_ngram_frac(lowered_tokens, 4)

    # Accumulate in a fixed order so the floating-point result is stable.
    quality = min(len(text) / 700.0, 1.0)
    quality += 0.35 * distinct2
    quality += 0.15 * distinct1
    quality -= 0.30 * eos_count
    quality -= 2.60 * rep3
    quality -= 1.60 * rep4
    quality -= 1.30 * digit_frac
    quality -= 0.65 * punct_frac
    quality -= 1.35 * max_word_frac
    quality -= 0.35 * bad_char_count

    return {
        "quality": float(quality),
        "chars": len(text),
        "tokens": len(tokens),
        "words": len(alpha_words),
        "eos_count": eos_count,
        "bad_char_count": bad_char_count,
        "rep3": float(rep3),
        "rep4": float(rep4),
        "distinct1": float(distinct1),
        "distinct2": float(distinct2),
        "digit_frac": float(digit_frac),
        "punct_frac": float(punct_frac),
        "max_word_frac": float(max_word_frac),
    }
|
|
|
|
def decode_text(tokenizer: BpeTextTokenizer, ids: Sequence[int]) -> str:
    """Decode token ids to text, keeping special tokens and not stopping at EOS."""
    raw_text = tokenizer.decode(ids, skip_special_tokens=False, stop_at_eos=False)
    return raw_text
|
|
|
|
def encode_prompt(tokenizer: BpeTextTokenizer, prompt: str, max_len: int) -> list[int]:
    """Encode *prompt* with the wrapped tokenizer and keep at most ``max_len`` ids."""
    encoding = tokenizer.tokenizer.encode(prompt)
    all_ids = list(encoding.ids)
    return all_ids[:max_len]
|
|
|
|
@torch.no_grad()
def build_initial_state(
    tokenizer: BpeTextTokenizer,
    prompts: list[str],
    restarts: int,
    max_len: int,
    target_prob: float,
    eps: float,
    noise_init: str,
    noise_sigma: float,
    dirichlet_init_concentration: float,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str]]:
    """Create the starting state for a batch of prompt-conditioned decodes.

    Each prompt is repeated ``restarts`` times. Every row starts from noise
    drawn by ``sample_noise_simplex``; the leading positions covered by the
    (truncated) prompt are overwritten with smoothed one-hot distributions
    and marked as locked.

    Returns:
        ``(probs, attn, lock, lock_probs, expanded)`` where ``attn`` is an
        all-True mask, ``lock`` flags prompt positions, ``lock_probs`` holds
        the distributions to restore at locked positions (zeros elsewhere),
        and ``expanded`` lists the prompt string for each row.
    """
    vocab = tokenizer.vocab_size
    expanded: list[str] = []
    prompt_ids: list[list[int]] = []
    for prompt in prompts:
        encoded = encode_prompt(tokenizer, prompt, max_len=max_len)
        # All restarts of one prompt share the same encoded id list.
        expanded.extend([prompt] * restarts)
        prompt_ids.extend([encoded] * restarts)

    grid = (len(prompt_ids), max_len)
    attn = torch.ones(grid, dtype=torch.bool, device=device)
    probs = sample_noise_simplex(
        grid,
        vocab,
        device,
        eps,
        noise_mode=noise_init,
        target_prob=target_prob,
        noise_sigma=noise_sigma,
        dirichlet_concentration=dirichlet_init_concentration,
    )
    lock = torch.zeros(grid, dtype=torch.bool, device=device)
    lock_probs = torch.zeros((*grid, vocab), dtype=torch.float32, device=device)

    for row, ids in enumerate(prompt_ids):
        prefix_len = len(ids)
        if prefix_len == 0:
            continue  # empty prompt: the whole row stays free
        id_tensor = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
        smoothed = smooth_onehot(id_tensor, vocab, target_prob, eps)[0]
        probs[row, :prefix_len] = smoothed
        lock_probs[row, :prefix_len] = smoothed
        lock[row, :prefix_len] = True
    return probs, attn, lock, lock_probs, expanded
|
|
|
|
def flowmap_gamma(step: int, steps: int, damping: float, max_gamma: float, eps: float) -> float:
    """Return the interpolation weight for one flow-map step.

    The base weight is the step's time increment divided by the time
    remaining until 1 (floored at ``eps``), scaled by ``damping``. When
    ``max_gamma`` is positive the result is capped at it; otherwise it is
    returned uncapped.
    """
    n_steps = max(steps, 1)
    t_now = step / n_steps
    t_after = (step + 1) / n_steps
    time_left = max(1.0 - t_now, eps)
    scaled = float(damping) * ((t_after - t_now) / time_left)
    if max_gamma > 0:
        return min(scaled, float(max_gamma))
    return scaled
|
|
|
|
@torch.no_grad()
def decode_batch(
    model,
    init_probs: torch.Tensor,
    attn: torch.Tensor,
    lock: torch.Tensor,
    lock_probs: torch.Tensor,
    cfg: DecodeConfig,
    eps: float,
    eos_id: int | None = None,
) -> torch.Tensor:
    """Iteratively refine a batch of distributions using ``cfg.rule``.

    At each step the model's logits are turned into an "endpoint"
    distribution (after optional temperature and EOS bias), the state is
    moved toward it by the selected rule, optionally mixed with uniform
    noise, floored, renormalised over the last axis, and locked positions
    are reset to ``lock_probs``.

    Args:
        model: Callable invoked as ``model(state, t, attn)`` returning logits.
        init_probs: Starting distributions; the last axis is normalised over
            (presumably ``(batch, seq, vocab)`` -- confirm against caller).
        attn: Mask passed through to the model unchanged.
        lock: Boolean mask of positions that must keep ``lock_probs``.
        lock_probs: Distributions restored at locked positions.
        cfg: Decode settings.
        eps: Numerical floor used for clamps and denominators.
        eos_id: Index whose logit receives ``cfg.eos_logit_bias``; ignored
            when ``None`` or out of range.

    Returns:
        The final state, the last endpoint, or their equal blend, depending
        on ``cfg.final_from`` (endpoint/blend are re-locked and renormalised).

    Raises:
        ValueError: If ``cfg.rule`` is not a known rule (raised inside the
            step loop).
    """
    state = init_probs.float().clone()
    dev = state.device
    locked = lock.unsqueeze(-1)
    floor = max(float(cfg.state_floor), eps)
    endpoint_latest = state

    for k in range(cfg.steps):
        t = model_time_for_step(cfg.model_t_mode, k, cfg.steps, state.size(0), dev, dtype=torch.float32)
        logits = model(state_for_model(model, state, eps), t, attn).float()
        if cfg.endpoint_temp != 1.0:
            logits = logits / float(cfg.endpoint_temp)
        apply_eos_bias = cfg.eos_logit_bias != 0.0 and eos_id is not None and 0 <= eos_id < logits.size(-1)
        if apply_eos_bias:
            logits[..., eos_id] = logits[..., eos_id] + float(cfg.eos_logit_bias)
        endpoint = F.softmax(logits, dim=-1)
        endpoint_latest = endpoint

        rule = cfg.rule
        if rule == "flowmap":
            step_size = flowmap_gamma(k, cfg.steps, cfg.damping, cfg.max_gamma, eps)
            updated = state + step_size * (endpoint - state)
        elif rule == "replace":
            updated = (1.0 - cfg.eta) * state + cfg.eta * endpoint
        elif rule == "geometric":
            # Interpolate in log space, then renormalise via softmax.
            log_mix = (1.0 - cfg.eta) * torch.log(state.clamp_min(eps)) + cfg.eta * torch.log(endpoint.clamp_min(eps))
            updated = F.softmax(log_mix, dim=-1)
        elif rule == "centered_residual":
            # Remove the mean of the residual so the update sums to zero.
            delta = endpoint - state
            delta = delta - delta.mean(dim=-1, keepdim=True)
            updated = state + cfg.eta * delta
        else:
            raise ValueError(f"Unknown decode rule: {cfg.rule}")

        if cfg.noise_mix > 0:
            frac_done = (k + 1) / max(cfg.steps, 1)
            if cfg.noise_decay == "linear":
                leak = cfg.noise_mix * (1.0 - frac_done)
            elif cfg.noise_decay == "sqrt":
                leak = cfg.noise_mix * math.sqrt(max(0.0, 1.0 - frac_done))
            else:
                leak = cfg.noise_mix
            if leak > 0:
                flat = torch.full_like(updated, 1.0 / updated.size(-1))
                updated = (1.0 - leak) * updated + leak * flat

        updated = updated.clamp_min(floor)
        updated = updated / updated.sum(dim=-1, keepdim=True).clamp_min(eps)
        state = torch.where(locked, lock_probs, updated)

    if cfg.final_from == "endpoint":
        result = endpoint_latest
    elif cfg.final_from == "blend":
        result = 0.5 * state + 0.5 * endpoint_latest
    else:
        return state
    # Endpoint-derived outputs have not been re-locked yet; do it here.
    result = torch.where(locked, lock_probs, result)
    return result / result.sum(dim=-1, keepdim=True).clamp_min(eps)
|
|
|
|
@torch.no_grad()
def pseudo_likelihood_scores(
    model,
    tokenizer: BpeTextTokenizer,
    probs: torch.Tensor,
    attn: torch.Tensor,
    lock: torch.Tensor,
    target_prob: float,
    eps: float,
    repeats: int,
    mask_frac: float,
    rerank_t: float,
) -> torch.Tensor:
    """Score each row by the mean log-probability of its own masked tokens.

    The argmax tokens of ``probs`` are smoothed into distributions. For each
    of ``max(1, repeats)`` rounds, a random subset of the attended, unlocked
    positions is replaced by noise, the model is run at time ``rerank_t``,
    and the log-probability it assigns to the original token at each masked
    position is accumulated. Rows with eligible positions always get at
    least one masked position per round.

    Returns:
        One score per row: summed log-probability divided by the number of
        scored positions (denominator floored at 1).
    """
    token_ids = probs.argmax(dim=-1)
    n_rows = token_ids.size(0)
    dev = token_ids.device
    clean = smooth_onehot(token_ids, tokenizer.vocab_size, target_prob, eps)
    scorable = attn & (~lock)
    total_logp = torch.zeros(n_rows, dtype=torch.float32, device=dev)
    n_scored = torch.zeros_like(total_logp)

    for _ in range(max(1, repeats)):
        mask = (torch.rand_like(token_ids.float()) < mask_frac) & scorable
        for r in range(n_rows):
            # Only rows that could be scored but drew an empty mask need a fix-up.
            if not scorable[r].any() or mask[r].any():
                continue
            candidates = torch.nonzero(scorable[r], as_tuple=False).flatten()
            pick = torch.randint(0, candidates.numel(), (1,), device=dev)
            mask[r, candidates[pick]] = True

        noise = sample_noise_simplex(
            (n_rows, token_ids.size(1)),
            tokenizer.vocab_size,
            dev,
            eps,
            noise_mode="logistic_normal",
            target_prob=target_prob,
            noise_sigma=-1.0,
        )
        corrupted = torch.where(mask.unsqueeze(-1), noise, clean)
        # Locked positions always present the caller-supplied distributions.
        corrupted = torch.where(lock.unsqueeze(-1), probs, corrupted)
        t = torch.full((n_rows,), float(rerank_t), dtype=torch.float32, device=dev)
        logits = model(state_for_model(model, corrupted, eps), t, attn).float()
        token_logp = F.log_softmax(logits, dim=-1).gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
        weights = mask.float()
        total_logp += (token_logp * weights).sum(dim=-1)
        n_scored += weights.sum(dim=-1)

    return total_logp / n_scored.clamp_min(1.0)
|
|
|
|
def default_configs(steps: int, config_set: str) -> list[DecodeConfig]:
    """Return the predefined list of decode configurations for *config_set*.

    Known sets: "focused_flowmap", "best_flowmap", "final_projection",
    "eos_sweep" and "broad". Every configuration uses ``steps`` iterations.

    Raises:
        ValueError: If ``config_set`` is not one of the known sets.
    """

    def flowmap(label, temp=1.0, damping=1.0, max_gamma=1.0, **extra):
        # Flow-map rule with explicit temperature, damping and step cap.
        return DecodeConfig(
            label,
            "flowmap",
            steps=steps,
            damping=damping,
            max_gamma=max_gamma,
            endpoint_temp=temp,
            **extra,
        )

    def eta_rule(label, rule, eta, **extra):
        # Any eta-driven rule (replace / geometric / centered_residual).
        return DecodeConfig(label, rule, steps=steps, eta=eta, **extra)

    if config_set == "focused_flowmap":
        return [
            flowmap("flowmap_t1p00_d1p0"),
            flowmap("flowmap_t1p10_d1p0", temp=1.10),
            flowmap("flowmap_t1p25_d1p0", temp=1.25),
            flowmap("flowmap_t1p40_d1p0", temp=1.40),
            flowmap("flowmap_t1p60_d1p0", temp=1.60),
            flowmap("flowmap_t1p25_d0p7", temp=1.25, damping=0.7),
            flowmap("flowmap_t1p40_d0p7", temp=1.40, damping=0.7),
            flowmap("flowmap_t1p60_d0p7", temp=1.60, damping=0.7),
            flowmap("flowmap_t1p25_g0p5", temp=1.25, max_gamma=0.5),
            flowmap("flowmap_t1p40_g0p5", temp=1.40, max_gamma=0.5),
        ]

    if config_set == "best_flowmap":
        return [
            flowmap("flowmap_t1p25_d0p7", temp=1.25, damping=0.7),
            flowmap("flowmap_t1p25_d1p0", temp=1.25),
            flowmap("flowmap_t1p35_d1p0", temp=1.35),
            flowmap("flowmap_t1p40_d1p0", temp=1.40),
        ]

    if config_set == "final_projection":
        # Each base setting is crossed with the three output projections.
        bases = [
            ("flowmap_t1p35", 1.35, 1.0),
            ("flowmap_t1p40", 1.40, 1.0),
            ("flowmap_t1p25_d0p7", 1.25, 0.7),
        ]
        return [
            flowmap(f"{prefix}_{source}", temp=temp, damping=damping, final_from=source)
            for prefix, temp, damping in bases
            for source in ("state", "endpoint", "blend")
        ]

    if config_set == "eos_sweep":
        return [
            flowmap("flowmap_t1p35_eos0", temp=1.35, eos_logit_bias=0.0),
            flowmap("flowmap_t1p35_eos-1", temp=1.35, eos_logit_bias=-1.0),
            flowmap("flowmap_t1p35_eos-2", temp=1.35, eos_logit_bias=-2.0),
            flowmap("flowmap_t1p35_eos-3", temp=1.35, eos_logit_bias=-3.0),
            flowmap("flowmap_t1p40_eos-2", temp=1.40, eos_logit_bias=-2.0),
            flowmap("flowmap_t1p25_d0p7_eos-2", temp=1.25, damping=0.7, eos_logit_bias=-2.0),
        ]

    if config_set == "broad":
        return [
            flowmap("flowmap64", final_from="state"),
            flowmap("flowmap_temp1p25", temp=1.25),
            flowmap("flowmap_temp0p85", temp=0.85),
            eta_rule("replace_eta0p35", "replace", 0.35),
            eta_rule("replace_eta0p50", "replace", 0.50),
            eta_rule("replace_eta0p65", "replace", 0.65),
            eta_rule("replace_eta0p50_temp1p25", "replace", 0.50, endpoint_temp=1.25),
            eta_rule("geometric_eta0p25", "geometric", 0.25),
            eta_rule("geometric_eta0p50", "geometric", 0.50),
            eta_rule("centered_residual_eta0p20", "centered_residual", 0.20),
            eta_rule("replace_eta0p50_floor1e6", "replace", 0.50, state_floor=1e-6),
            eta_rule("replace_eta0p50_leak", "replace", 0.50, noise_mix=0.03, noise_decay="sqrt"),
        ]

    raise ValueError(f"Unknown config_set: {config_set}")
|
|
|
|
def aggregate(rows: list[dict]) -> dict:
    """Average the headline metrics over *rows*.

    Returns a dict mapping ``"mean_<metric>"`` to the arithmetic mean of that
    metric; with no rows every mean is 0.0.
    """
    metric_names = (
        "quality",
        "eos_count",
        "rep3",
        "rep4",
        "distinct1",
        "distinct2",
        "digit_frac",
        "max_word_frac",
    )
    denom = max(len(rows), 1)  # avoid dividing by zero for an empty batch
    means = {}
    for name in metric_names:
        total = sum(float(row[name]) for row in rows)
        means[f"mean_{name}"] = total / denom
    return means
|
|
|
|
def main() -> None:
    """Run the decode sweep from the command line.

    Loads the checkpoint and tokenizer, then for every configuration in the
    chosen config set: builds a fresh initial state, decodes all candidates
    (optionally in chunks of ``--decode_batch_size``), computes text metrics
    and, when ``--score_repeats`` > 0, pseudo-likelihood scores. A summary
    record and the top-k samples per prompt are appended to the JSONL file
    at ``--output`` and the best sample per prompt is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer_path", required=True)
    parser.add_argument("--max_len", type=int, default=128)
    parser.add_argument("--steps", type=int, default=64)
    parser.add_argument("--restarts", type=int, default=64)
    parser.add_argument("--target_prob", type=float, default=0.99)
    parser.add_argument("--eps", type=float, default=1e-8)
    parser.add_argument("--model_t_mode", choices=["linear", "flow", "const0", "const05", "const1", "random"], default="flow")
    parser.add_argument("--noise_init", choices=["uniform", "logistic_normal", "dirichlet"], default="dirichlet")
    parser.add_argument("--noise_sigma", type=float, default=-1.0)
    parser.add_argument("--dirichlet_init_concentration", type=float, default=1.0)
    parser.add_argument("--prompts", default="|The|In the early morning|Scientists have|The company said|A young woman")
    parser.add_argument("--score_repeats", type=int, default=0)
    parser.add_argument("--score_mask_frac", type=float, default=0.5)
    parser.add_argument("--rerank_t", type=float, default=0.5)
    parser.add_argument("--pl_weight", type=float, default=0.0)
    parser.add_argument("--output", default="runs/decode_lab/latest_decode_lab.jsonl")
    parser.add_argument("--config_set", default="broad", choices=["broad", "focused_flowmap", "best_flowmap", "final_projection", "eos_sweep"])
    parser.add_argument("--decode_batch_size", type=int, default=0)
    parser.add_argument("--topk", type=int, default=5)
    parser.add_argument("--seed", type=int, default=20260428)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    model = build_model_from_ckpt(ckpt, tokenizer.vocab_size, args.max_len, device)
    model.eval()

    # "|"-separated list; a leading "|" yields an empty (unconditional) prompt.
    prompts = args.prompts.split("|")

    print(f"[info] device={device} prompts={prompts} restarts={args.restarts} steps={args.steps}")
    print(f"[info] checkpoint={args.checkpoint}")

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    configs = default_configs(args.steps, args.config_set)
    for cfg in configs:
        cfg.model_t_mode = args.model_t_mode
    # Records are dumped with ensure_ascii=False, so the file must be opened
    # as UTF-8 explicitly; the locale default may be unable to encode the
    # generated text (or would yield non-UTF-8 JSONL).
    with out_path.open("w", encoding="utf-8") as f:
        for cfg in configs:
            # The initial state is rebuilt for every config, so each config
            # draws its own noise from the seeded generator.
            init, attn, lock, lock_probs, expanded = build_initial_state(
                tokenizer=tokenizer,
                prompts=prompts,
                restarts=args.restarts,
                max_len=args.max_len,
                target_prob=args.target_prob,
                eps=args.eps,
                noise_init=args.noise_init,
                noise_sigma=args.noise_sigma,
                dirichlet_init_concentration=args.dirichlet_init_concentration,
                device=device,
            )
            if args.decode_batch_size > 0 and init.size(0) > args.decode_batch_size:
                # Decode in slices and park results on the CPU to bound memory.
                decoded_parts = []
                for start in range(0, init.size(0), args.decode_batch_size):
                    end = min(start + args.decode_batch_size, init.size(0))
                    part = decode_batch(
                        model,
                        init[start:end],
                        attn[start:end],
                        lock[start:end],
                        lock_probs[start:end],
                        cfg,
                        args.eps,
                        tokenizer.eos_id,
                    )
                    decoded_parts.append(part.detach().cpu())
                    print(f"[chunk] {cfg.label} decoded {end}/{init.size(0)}", flush=True)
                decoded = torch.cat(decoded_parts, dim=0)
            else:
                decoded = decode_batch(model, init, attn, lock, lock_probs, cfg, args.eps, tokenizer.eos_id)
            ids = decoded.argmax(dim=-1).detach().cpu().tolist()
            texts = [decode_text(tokenizer, row) for row in ids]
            rows = []
            for i, text in enumerate(texts):
                m = text_metrics(text)
                m.update({"candidate": i, "prompt": expanded[i], "text": text})
                rows.append(m)
            if args.score_repeats > 0:
                # Chunked decoding leaves `decoded` on the CPU; move it back
                # next to the model before scoring.
                decoded_for_score = decoded.to(device) if decoded.device != device else decoded
                pl = pseudo_likelihood_scores(
                    model,
                    tokenizer,
                    decoded_for_score,
                    attn,
                    lock,
                    args.target_prob,
                    args.eps,
                    repeats=args.score_repeats,
                    mask_frac=args.score_mask_frac,
                    rerank_t=args.rerank_t,
                ).detach().cpu().tolist()
                for row, score in zip(rows, pl):
                    row["pseudo_logp"] = float(score)
                    row["rank_score"] = float(row["quality"] + args.pl_weight * score)
            else:
                for row in rows:
                    row["pseudo_logp"] = None
                    row["rank_score"] = float(row["quality"])

            summary = {"type": "summary", "config": asdict(cfg), "agg": aggregate(rows)}
            f.write(json.dumps(summary, ensure_ascii=False) + "\n")
            print("\n" + "=" * 96)
            print("[config]", cfg.label, asdict(cfg))
            print("[metrics]", json.dumps(summary["agg"], ensure_ascii=False))
            for prompt in prompts:
                subset = [r for r in rows if r["prompt"] == prompt]
                subset.sort(key=lambda r: r["rank_score"], reverse=True)
                for rank, row in enumerate(subset[: args.topk], 1):
                    rec = {"type": "sample", "config": asdict(cfg), "rank": rank, **row}
                    f.write(json.dumps(rec, ensure_ascii=False) + "\n")
                    if rank <= 1:
                        print(f"\n--- best prompt={prompt!r} rank_score={row['rank_score']:.4f} quality={row['quality']:.4f} ---")
                        print(row["text"])

            # Drop the per-config tensors before the next config allocates.
            del init, attn, lock, lock_probs, decoded
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    print(f"[done] wrote {out_path}")
|
|
|
|
# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|