"""Trace per-position top-1 tokens through a flow-style decode loop.

For one sample of a batch, every decode step records the argmax token and its
probability at each position for three distributions: the state fed to the
model ("input"), the model's temperature-scaled softmax ("endpoint"), and the
state after the decode update ("post"). Results are written as a full TSV, a
TSV restricted to a focus window of steps, an HTML table for that window, and
a small JSON metadata file.
"""
from __future__ import annotations

import argparse
import csv
import html
import json
import sys
from pathlib import Path

import torch
import torch.nn.functional as F

REPO_ROOT = Path(__file__).resolve().parents[1]
SCRIPT_DIR = Path(__file__).resolve().parent
# Make the repository root and this script's directory importable so the
# project-local imports below resolve when the file is run as a script.
for p in (REPO_ROOT, SCRIPT_DIR):
    if str(p) not in sys.path:
        sys.path.insert(0, str(p))

from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model
from flowtext_lab.tokenization import BpeTextTokenizer
from infer_context_compare_from_c128 import build_model, clamp_first_position, temperature
from trace_decode_basin import apply_decode_update


def decode_token(tokenizer: BpeTextTokenizer, tid: int) -> str:
    """Decode a single token id to text with newlines and tabs made visible.

    Special tokens are kept and decoding does not stop at EOS, so every id
    produces some text. Literal newline/tab characters are replaced by the
    two-character sequences ``\\n`` / ``\\t`` so each token stays on one
    TSV/HTML line.
    """
    text = tokenizer.decode([int(tid)], stop_at_eos=False, skip_special_tokens=False)
    return text.replace("\n", "\\n").replace("\t", "\\t")


def cell(token: str, prob: float) -> str:
    """Render one ``<td>`` showing *token* and *prob*, shaded by probability.

    The background is blue with an opacity that grows linearly with the
    probability (clamped to [0, 1]); the text switches from dark to white
    once the background is strong enough to need the contrast.

    NOTE(review): the exact tag/attribute layout was rebuilt because the
    stored copy of this file had its HTML markup stripped; confirm against a
    previously rendered report if one exists.
    """
    alpha = min(max(prob, 0.0), 1.0)
    bg = f"rgba(43, 113, 220, {0.08 + 0.52 * alpha:.3f})"
    color = "#111" if alpha < 0.55 else "#fff"
    return (
        f'<td style="background:{bg};color:{color}">'
        f'<span class="tok">{html.escape(token)}</span><br>'
        f'<span class="prob">{prob:.3f}</span></td>'
    )


def write_html(
    path: Path,
    *,
    title: str,
    focus_steps: list[int],
    rows_by_step: dict[int, list[dict[str, object]]],
) -> None:
    """Write the focus-window table as a standalone HTML page.

    Args:
        path: Destination file; written as UTF-8.
        title: Page title and heading (HTML-escaped here).
        focus_steps: Step numbers to show, one column group per step.
        rows_by_step: Per-step list of row dicts, indexed by position. Each
            row must provide ``input_*``, ``endpoint_*`` and ``post_*``
            ``_token`` / ``_prob`` entries.

    Steps listed in *focus_steps* but absent from *rows_by_step* (for example
    when the focus window extends past the number of decode steps) are
    omitted from the table and listed in a note instead of raising
    ``KeyError``. An empty selection produces a table with no data rows.
    """
    shown_steps = [s for s in focus_steps if s in rows_by_step]
    missing_steps = [s for s in focus_steps if s not in rows_by_step]

    # NOTE(review): page skeleton and CSS were rebuilt (see ``cell``).
    lines = [
        "<!doctype html>",
        f"<html><head><meta charset='utf-8'><title>{html.escape(title)}</title>",
        "<style>"
        "body{font-family:sans-serif;font-size:12px}"
        "table{border-collapse:collapse}"
        "th,td{border:1px solid #ccc;padding:2px 4px;white-space:nowrap;text-align:center}"
        ".tok{font-family:monospace}"
        ".prob{font-size:10px}"
        "</style></head>",
        f"<body><h1>{html.escape(title)}</h1>",
        "<p>Each cell is this position's argmax token and probability. "
        "Color intensity tracks probability.</p>",
    ]
    if missing_steps:
        skipped = ", ".join(str(s) for s in missing_steps)
        lines.append(f"<p>Steps without recorded rows were omitted: {html.escape(skipped)}</p>")
    lines.append("<table>")
    lines.append("<thead><tr><th rowspan='2'>pos</th>")
    for step in shown_steps:
        lines.append(f"<th colspan='3'>step {step}</th>")
    lines.append("</tr><tr>")
    for _ in shown_steps:
        lines.extend(["<th>input</th>", "<th>endpoint</th>", "<th>post</th>"])
    lines.append("</tr></thead><tbody>")

    # ``default=0`` keeps an empty focus window from raising in ``max``.
    max_pos = max((len(rows_by_step[s]) for s in shown_steps), default=0)
    for pos in range(max_pos):
        lines.append(f"<tr><th>{pos}</th>")
        for step in shown_steps:
            step_rows = rows_by_step[step]
            if pos >= len(step_rows):
                # Keep the grid rectangular if a step recorded fewer positions.
                lines.extend(["<td></td>"] * 3)
                continue
            row = step_rows[pos]
            lines.append(cell(str(row["input_token"]), float(row["input_prob"])))
            lines.append(cell(str(row["endpoint_token"]), float(row["endpoint_prob"])))
            lines.append(cell(str(row["post_token"]), float(row["post_prob"])))
        lines.append("</tr>")
    lines.extend(["</tbody></table>", "</body></html>"])
    # Token text can be any Unicode; do not depend on the locale encoding.
    path.write_text("\n".join(lines), encoding="utf-8")


def _top1(dist: torch.Tensor) -> tuple[list[int], list[float]]:
    """Return the argmax index along the last axis and the value at that index.

    *dist* is a 2-D tensor for a single sample (one row per position;
    presumably the last axis is the vocabulary). The results are moved to the
    CPU once and returned as plain Python lists.
    """
    ids = dist.argmax(dim=-1)
    top = dist.gather(1, ids.unsqueeze(-1)).squeeze(-1)
    return ids.cpu().tolist(), top.cpu().tolist()


@torch.inference_mode()
def main() -> None:
    """Run the decode loop for the CLI arguments and write all report files."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--tokenizer_path", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--name", required=True)
    ap.add_argument("--max_len", type=int, required=True)
    ap.add_argument("--n_samples", type=int, required=True)
    ap.add_argument("--sample_idx", type=int, default=0)
    ap.add_argument("--steps", type=int, default=128)
    ap.add_argument("--decode_rule", default="dirichlet_resample")
    ap.add_argument("--seed", type=int, default=314159)
    ap.add_argument("--pos_extend", default="repeat")
    ap.add_argument("--support_power", type=float, default=1.0)
    ap.add_argument("--semantic_power", type=float, default=1.5)
    ap.add_argument("--early_temp", type=float, default=2.8)
    ap.add_argument("--late_temp", type=float, default=1.45)
    ap.add_argument("--temp_end", type=float, default=0.55)
    ap.add_argument("--temp_power", type=float, default=1.5)
    ap.add_argument("--hybrid_switch", type=float, default=0.5)
    ap.add_argument("--fixed_first_token_id", type=int, default=-1)
    ap.add_argument("--fixed_first_token_text", default="")
    ap.add_argument("--fixed_first_initial_argmax", action="store_true")
    ap.add_argument("--focus_start", type=int, default=40)
    ap.add_argument("--focus_end", type=int, default=60)
    args = ap.parse_args()

    # Fail fast on an index that would otherwise raise IndexError only after
    # the checkpoint has been loaded and output files have been created.
    if not -args.n_samples <= args.sample_idx < args.n_samples:
        ap.error(f"--sample_idx {args.sample_idx} is out of range for --n_samples {args.n_samples}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # SECURITY: weights_only=False unpickles arbitrary objects from the
    # checkpoint; only load checkpoints from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False, mmap=True)
    model = build_model(ckpt, tokenizer, args.max_len, device, args.pos_extend)
    eps = 1e-8

    torch.manual_seed(args.seed)
    probs = sample_noise_simplex(
        (args.n_samples, args.max_len),
        tokenizer.vocab_size,
        device,
        eps,
        noise_mode="dirichlet",
        target_prob=1.0,
        noise_sigma=-1.0,
        dirichlet_concentration=1.0,
    )

    # Resolve the optional pinned first token: explicit text wins over an
    # explicit id; a negative id means "not set".
    fixed_first_token_id: int | None = None
    if args.fixed_first_token_text:
        encoded = tokenizer.encode(args.fixed_first_token_text, add_eos=False, add_special_tokens=False)
        if not encoded:
            raise ValueError(f"fixed_first_token_text encoded to no tokens: {args.fixed_first_token_text!r}")
        fixed_first_token_id = int(encoded[0])
    elif args.fixed_first_token_id >= 0:
        fixed_first_token_id = int(args.fixed_first_token_id)

    # Per-sample ids for position 0: either each sample's own initial argmax
    # or the single resolved id broadcast to the batch; None when unset.
    fixed_first_ids: torch.Tensor | None = None
    if args.fixed_first_initial_argmax:
        fixed_first_ids = probs[:, 0, :].argmax(dim=-1)
    elif fixed_first_token_id is not None:
        fixed_first_ids = torch.full((args.n_samples,), fixed_first_token_id, dtype=torch.long, device=device)
    probs = clamp_first_position(probs, fixed_first_ids)
    attn = torch.ones((args.n_samples, args.max_len), dtype=torch.bool, device=device)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    stem = f"{args.name}_sample{args.sample_idx}"
    full_tsv = out_dir / f"{stem}_position_top1_full.tsv"
    focus_tsv = out_dir / f"{stem}_position_top1_focus_{args.focus_start}_{args.focus_end}.tsv"
    focus_steps = list(range(args.focus_start, args.focus_end + 1))
    rows_by_step: dict[int, list[dict[str, object]]] = {}
    fieldnames = [
        "step",
        "position",
        "input_token",
        "input_prob",
        "endpoint_token",
        "endpoint_prob",
        "post_token",
        "post_prob",
    ]
    sample = args.sample_idx

    # Explicit UTF-8: decoded tokens may contain characters the locale
    # encoding cannot represent.
    with full_tsv.open("w", newline="", encoding="utf-8") as f_full, focus_tsv.open(
        "w", newline="", encoding="utf-8"
    ) as f_focus:
        full_writer = csv.DictWriter(f_full, fieldnames=fieldnames, delimiter="\t")
        focus_writer = csv.DictWriter(f_focus, fieldnames=fieldnames, delimiter="\t")
        full_writer.writeheader()
        focus_writer.writeheader()

        for step in range(args.steps):
            step_no = step + 1  # reports use 1-based step numbers
            in_focus = args.focus_start <= step_no <= args.focus_end

            prev_probs = probs
            t = model_time_for_step("flow", step, args.steps, args.n_samples, device, dtype=torch.float32)
            temp = temperature(step, args.steps, args.early_temp, args.late_temp, args.temp_end, args.temp_power)
            logits = model(state_for_model(model, prev_probs, eps), t, attn).float()
            endpoint = F.softmax(logits / temp, dim=-1)
            probs = apply_decode_update(
                decode_rule=args.decode_rule,
                probs=prev_probs,
                endpoint=endpoint,
                step=step,
                steps=args.steps,
                support_power=args.support_power,
                semantic_power=args.semantic_power,
                hybrid_switch=args.hybrid_switch,
                c_min=1.0,
                c_max=1024.0,
                eps=eps,
            )
            probs = clamp_first_position(probs, fixed_first_ids)

            # Only the traced sample is reported; convert it to Python lists
            # once per step instead of once per position.
            input_ids, input_probs = _top1(prev_probs[sample])
            endpoint_ids, endpoint_probs = _top1(endpoint[sample])
            post_ids, post_probs = _top1(probs[sample])

            step_rows: list[dict[str, object]] = []
            for pos in range(args.max_len):
                row = {
                    "step": step_no,
                    "position": pos,
                    "input_token": decode_token(tokenizer, input_ids[pos]),
                    "input_prob": f"{input_probs[pos]:.8f}",
                    "endpoint_token": decode_token(tokenizer, endpoint_ids[pos]),
                    "endpoint_prob": f"{endpoint_probs[pos]:.8f}",
                    "post_token": decode_token(tokenizer, post_ids[pos]),
                    "post_prob": f"{post_probs[pos]:.8f}",
                }
                full_writer.writerow(row)
                if in_focus:
                    focus_writer.writerow(row)
                step_rows.append(row)
            if in_focus:
                rows_by_step[step_no] = step_rows
            if step_no % 16 == 0 or step == 0:
                print(f"{args.name} wrote step {step_no}", flush=True)

    write_html(
        out_dir / f"{stem}_position_top1_focus_{args.focus_start}_{args.focus_end}.html",
        title=f"{args.name} sample {args.sample_idx} position top1 focus {args.focus_start}-{args.focus_end}",
        focus_steps=focus_steps,
        rows_by_step=rows_by_step,
    )
    meta = {
        "checkpoint": args.checkpoint,
        "tokenizer_path": args.tokenizer_path,
        "name": args.name,
        "max_len": args.max_len,
        "n_samples": args.n_samples,
        "sample_idx": args.sample_idx,
        "steps": args.steps,
        "fixed_first_token_id": fixed_first_token_id,
        "fixed_first_token_text": args.fixed_first_token_text,
        "fixed_first_initial_argmax": bool(args.fixed_first_initial_argmax),
        "full_tsv": str(full_tsv),
        "focus_tsv": str(focus_tsv),
    }
    # ensure_ascii=False emits raw Unicode, so the file must be UTF-8.
    (out_dir / f"{stem}_position_top1_meta.json").write_text(
        json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    print("WROTE", out_dir)


if __name__ == "__main__":
    main()