from __future__ import annotations import argparse import json import sys from collections import Counter from pathlib import Path import torch import torch.nn.functional as F REPO_ROOT = Path(__file__).resolve().parents[1] SCRIPT_DIR = Path(__file__).resolve().parent for p in (REPO_ROOT, SCRIPT_DIR): if str(p) not in sys.path: sys.path.insert(0, str(p)) from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model from flowtext_lab.tokenization import BpeTextTokenizer from infer_context_compare_from_c128 import build_model, temperature from trace_decode_basin import apply_decode_update def top1_for_sample( ids: torch.Tensor, probs: torch.Tensor, tokenizer: BpeTextTokenizer, sample_idx: int, ) -> dict[str, object]: row = ids[sample_idx].detach().cpu().tolist() total = max(len(row), 1) tid, count = Counter(row).most_common(1)[0] mask = ids[sample_idx] == tid pvals = probs[sample_idx, :, tid] return { "id": int(tid), "text": tokenizer.decode([int(tid)], stop_at_eos=False, skip_special_tokens=False).replace("\n", "\\n"), "frac": float(count / total), "mean_p_all_pos": float(pvals.mean().detach().cpu()), "mean_p_on_argmax_pos": float(pvals[mask].mean().detach().cpu()) if bool(mask.any()) else 0.0, "mean_max_p": float(probs[sample_idx].max(dim=-1).values.mean().detach().cpu()), } def fmt_cell(item: dict[str, object]) -> str: text = str(item["text"]) or "" return f"`{text}` {float(item['frac']) * 100:.1f}% / p={float(item['mean_p_on_argmax_pos']):.3f}" @torch.inference_mode() def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--checkpoint", required=True) ap.add_argument("--tokenizer_path", required=True) ap.add_argument("--out_dir", required=True) ap.add_argument("--name", required=True) ap.add_argument("--max_len", type=int, required=True) ap.add_argument("--n_samples", type=int, required=True) ap.add_argument("--sample_idx", type=int, default=0) ap.add_argument("--steps", type=int, default=128) ap.add_argument("--decode_rule", default="dirichlet_resample") ap.add_argument("--seed", type=int, default=314159) ap.add_argument("--pos_extend", default="repeat") ap.add_argument("--support_power", type=float, default=1.0) ap.add_argument("--semantic_power", type=float, default=1.5) ap.add_argument("--early_temp", type=float, default=2.8) ap.add_argument("--late_temp", type=float, default=1.45) ap.add_argument("--temp_end", type=float, default=0.55) ap.add_argument("--temp_power", type=float, default=1.5) ap.add_argument("--hybrid_switch", type=float, default=0.5) args = ap.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path) ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False, mmap=True) model = build_model(ckpt, tokenizer, args.max_len, device, args.pos_extend) eps = 1e-8 torch.manual_seed(args.seed) probs = sample_noise_simplex( (args.n_samples, args.max_len), tokenizer.vocab_size, device, eps, noise_mode="dirichlet", target_prob=1.0, noise_sigma=-1.0, dirichlet_concentration=1.0, ) attn = torch.ones((args.n_samples, args.max_len), dtype=torch.bool, device=device) out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) rows: list[dict[str, object]] = [] for step in range(args.steps): prev_probs = probs prev_ids = prev_probs.argmax(dim=-1) t = model_time_for_step("flow", step, args.steps, args.n_samples, device, dtype=torch.float32) temp = temperature(step, args.steps, args.early_temp, args.late_temp, args.temp_end, args.temp_power) logits = model(state_for_model(model, prev_probs, eps), t, attn).float() endpoint = F.softmax(logits / temp, dim=-1) endpoint_ids = endpoint.argmax(dim=-1) probs = apply_decode_update( decode_rule=args.decode_rule, probs=prev_probs, endpoint=endpoint, step=step, steps=args.steps, support_power=args.support_power, semantic_power=args.semantic_power, hybrid_switch=args.hybrid_switch, c_min=1.0, c_max=1024.0, eps=eps, ) post_ids = probs.argmax(dim=-1) row = { "step": step + 1, "t": float((step + 1) / args.steps), "input": top1_for_sample(prev_ids, prev_probs, tokenizer, args.sample_idx), "endpoint": top1_for_sample(endpoint_ids, endpoint, tokenizer, args.sample_idx), "post": top1_for_sample(post_ids, probs, tokenizer, args.sample_idx), } rows.append(row) if (step + 1) % 16 == 0 or step == 0: print( args.name, "step", step + 1, "input", row["input"]["text"], row["input"]["frac"], "endpoint", row["endpoint"]["text"], row["endpoint"]["frac"], "post", row["post"]["text"], row["post"]["frac"], flush=True, ) stem = f"{args.name}_sample{args.sample_idx}" (out_dir / f"{stem}_top1_trace.json").write_text(json.dumps(rows, ensure_ascii=False, indent=2)) lines = [ f"# {args.name} sample {args.sample_idx} top1 trace", "", "Cell format: token = fraction of sequence positions whose argmax is that token / mean probability on those positions.", "", "| step | t | input top1 | endpoint top1 | post-update top1 | endpoint mean max-p |", "|---:|---:|---|---|---|---:|", ] for row in rows: endpoint = row["endpoint"] lines.append( f"| {row['step']} | {row['t']:.3f} | {fmt_cell(row['input'])} | " f"{fmt_cell(endpoint)} | {fmt_cell(row['post'])} | {float(endpoint['mean_max_p']):.3f} |" ) (out_dir / f"{stem}_top1_trace.md").write_text("\n".join(lines) + "\n") focus = [row for row in rows if 40 <= int(row["step"]) <= 60] focus_lines = [ f"# {args.name} sample {args.sample_idx} focus steps 40-60", "", "| step | input top1 | endpoint top1 | post-update top1 | endpoint mean max-p |", "|---:|---|---|---|---:|", ] for row in focus: endpoint = row["endpoint"] focus_lines.append( f"| {row['step']} | {fmt_cell(row['input'])} | " f"{fmt_cell(endpoint)} | {fmt_cell(row['post'])} | {float(endpoint['mean_max_p']):.3f} |" ) (out_dir / f"{stem}_focus_40_60.md").write_text("\n".join(focus_lines) + "\n") print("WROTE", out_dir) if __name__ == "__main__": main()