| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import html |
| import json |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
# Directory holding this script, and the directory one level above it.
SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = Path(__file__).resolve().parents[1]

# Prepend both directories to sys.path when they are not already listed
# (presumably so the project-local imports below resolve when this file is run
# directly). SCRIPT_DIR is processed last, so when both are new it ends up
# ahead of REPO_ROOT.
for entry in map(str, (REPO_ROOT, SCRIPT_DIR)):
    if entry in sys.path:
        continue
    sys.path.insert(0, entry)
|
|
| from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model |
| from flowtext_lab.tokenization import BpeTextTokenizer |
| from infer_context_compare_from_c128 import build_model, clamp_first_position, temperature |
| from trace_decode_basin import apply_decode_update |
|
|
|
|
def decode_token(tokenizer: BpeTextTokenizer, tid: int) -> str:
    """Decode a single token id into a one-line, printable string.

    The id is coerced to ``int`` and decoded on its own with
    ``stop_at_eos=False`` and ``skip_special_tokens=False``. Newlines and tabs
    in the decoded text are replaced by the two-character escapes ``\\n`` and
    ``\\t``.
    """
    rendered = tokenizer.decode(
        [int(tid)],
        stop_at_eos=False,
        skip_special_tokens=False,
    )
    # Newline first, then tab - same order as a chained replace.
    for raw, visible in (("\n", "\\n"), ("\t", "\\t")):
        rendered = rendered.replace(raw, visible)
    return rendered
|
|
|
|
def cell(token: str, prob: float) -> str:
    """Return one ``<td>`` element showing ``token`` and its probability.

    The background is a blue whose opacity grows linearly from 0.08 to 0.60 as
    ``prob`` goes from 0 to 1 (values outside that range are clamped for the
    colour only). Text is dark below 0.55 and white otherwise. The token text
    is HTML-escaped; the unclamped probability is shown in the cell and in the
    ``title`` tooltip.
    """
    # Clamp to [0, 1] for colouring; the displayed numbers keep the raw value.
    if prob < 0.0:
        alpha = 0.0
    elif prob > 1.0:
        alpha = 1.0
    else:
        alpha = prob

    opacity = 0.08 + 0.52 * alpha
    bg = f"rgba(43, 113, 220, {opacity:.3f})"
    if alpha < 0.55:
        color = "#111"
    else:
        color = "#fff"

    open_tag = f'<td style="background:{bg};color:{color}" title="p={prob:.4f}">'
    token_span = f'<span class="tok">{html.escape(token)}</span>'
    prob_span = f'<span class="prob">{prob:.3f}</span>'
    return "".join((open_tag, token_span, "<br>", prob_span, "</td>"))
|
|
|
|
def write_html(
    path: Path,
    *,
    title: str,
    focus_steps: list[int],
    rows_by_step: dict[int, list[dict[str, object]]],
) -> None:
    """Write the focus-step top-1 table as a standalone HTML page.

    The table has one row per position and, for each step, three columns
    (``input``, ``endpoint``, ``post``), each rendered with ``cell``.

    Args:
        path: Destination file; written as UTF-8 to match the page's declared
            ``charset``.
        title: Page title and heading (HTML-escaped here).
        focus_steps: Step numbers to show, in column order. Steps with no
            entry in ``rows_by_step`` are skipped rather than raising
            ``KeyError``.
        rows_by_step: For each step, one dict per position holding the keys
            ``input_token``, ``input_prob``, ``endpoint_token``,
            ``endpoint_prob``, ``post_token`` and ``post_prob``.

    If no step has rows, a page with an empty table body is written. If a
    step has fewer rows than the longest one, its missing positions are
    rendered as empty cells.
    """
    # Only steps that actually have recorded rows get columns; a focus window
    # extending past the decoded steps must not crash the report.
    shown_steps = [step for step in focus_steps if step in rows_by_step]
    columns = [rows_by_step[step] for step in shown_steps]

    lines = [
        "<!doctype html><html><head><meta charset='utf-8'>",
        f"<title>{html.escape(title)}</title>",
        "<style>",
        "body{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',sans-serif;margin:18px;color:#111}",
        "table{border-collapse:collapse;font-size:11px;line-height:1.15}",
        "th,td{border:1px solid #ddd;padding:3px 5px;min-width:62px;max-width:110px;vertical-align:top;overflow:hidden}",
        "th{position:sticky;top:0;background:#f7f7f7;z-index:2}",
        ".pos{position:sticky;left:0;background:#fff;z-index:1;font-weight:600;min-width:48px}",
        ".tok{font-family:ui-monospace,SFMono-Regular,Menlo,monospace;white-space:pre-wrap}",
        ".prob{font-size:10px;opacity:.75}",
        ".wrap{overflow:auto;max-height:88vh;border:1px solid #ddd}",
        ".phase{font-size:10px;color:#555}",
        "</style></head><body>",
        f"<h1>{html.escape(title)}</h1>",
        "<p>Each cell is this position's argmax token and probability. Color intensity tracks probability.</p>",
        "<div class='wrap'><table>",
        "<thead><tr><th class='pos'>pos</th>",
    ]

    # First header row: one three-column group per step.
    lines.extend(f"<th colspan='3'>step {step}</th>" for step in shown_steps)
    # Second header row: the three phases under each step.
    lines.append("</tr><tr><th class='pos'></th>")
    for _ in shown_steps:
        lines.extend(["<th class='phase'>input</th>", "<th class='phase'>endpoint</th>", "<th class='phase'>post</th>"])
    lines.append("</tr></thead><tbody>")

    # default=0 keeps an empty selection from raising in max().
    max_pos = max(map(len, columns), default=0)
    for pos in range(max_pos):
        lines.append(f"<tr><td class='pos'>{pos}</td>")
        for step_rows in columns:
            if pos >= len(step_rows):
                # Ragged input: keep the grid aligned with blank cells.
                lines.extend(["<td></td>", "<td></td>", "<td></td>"])
                continue
            row = step_rows[pos]
            for phase in ("input", "endpoint", "post"):
                lines.append(cell(str(row[f"{phase}_token"]), float(row[f"{phase}_prob"])))
        lines.append("</tr>")
    lines.append("</tbody></table></div></body></html>")

    # The page declares charset utf-8, so write it as UTF-8 regardless of the
    # platform's default encoding (tokens are frequently non-ASCII).
    path.write_text("\n".join(lines), encoding="utf-8")
|
|
|
|
@torch.inference_mode()
def main() -> None:
    """Trace one sample's per-position argmax token through a decode run.

    Parses the command line, builds the model from ``--checkpoint``, draws the
    initial state with ``sample_noise_simplex`` and applies ``--steps`` decode
    updates. For the sample chosen by ``--sample_idx`` it records, for every
    step and position, the argmax token and its probability in three places:
    the step's input state, the temperature-scaled model endpoint, and the
    state after the decode update (with the first position re-clamped).

    Files written under ``--out_dir`` (all UTF-8), where ``<stem>`` is
    ``<name>_sample<sample_idx>``:
      * ``<stem>_position_top1_full.tsv`` - every step.
      * ``<stem>_position_top1_focus_<start>_<end>.tsv`` and ``.html`` - only
        steps inside ``[--focus_start, --focus_end]``.
      * ``<stem>_position_top1_meta.json`` - the run parameters.

    Step numbers in all outputs are 1-based. The HTML report is skipped (with
    a message on stderr) when the focus window contains no decoded step.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--tokenizer_path", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--name", required=True)
    ap.add_argument("--max_len", type=int, required=True)
    ap.add_argument("--n_samples", type=int, required=True)
    ap.add_argument("--sample_idx", type=int, default=0)
    ap.add_argument("--steps", type=int, default=128)
    ap.add_argument("--decode_rule", default="dirichlet_resample")
    ap.add_argument("--seed", type=int, default=314159)
    ap.add_argument("--pos_extend", default="repeat")
    ap.add_argument("--support_power", type=float, default=1.0)
    ap.add_argument("--semantic_power", type=float, default=1.5)
    ap.add_argument("--early_temp", type=float, default=2.8)
    ap.add_argument("--late_temp", type=float, default=1.45)
    ap.add_argument("--temp_end", type=float, default=0.55)
    ap.add_argument("--temp_power", type=float, default=1.5)
    ap.add_argument("--hybrid_switch", type=float, default=0.5)
    ap.add_argument("--fixed_first_token_id", type=int, default=-1)
    ap.add_argument("--fixed_first_token_text", default="")
    ap.add_argument("--fixed_first_initial_argmax", action="store_true")
    ap.add_argument("--focus_start", type=int, default=40)
    ap.add_argument("--focus_end", type=int, default=60)
    args = ap.parse_args()

    # Fail before loading anything heavy instead of with an IndexError inside
    # the decode loop. Negative indices keep their usual "from the end" meaning.
    if not -args.n_samples <= args.sample_idx < args.n_samples:
        ap.error(f"--sample_idx {args.sample_idx} is out of range for --n_samples {args.n_samples}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. Only point --checkpoint at files from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False, mmap=True)
    model = build_model(ckpt, tokenizer, args.max_len, device, args.pos_extend)
    eps = 1e-8

    # Seed immediately before drawing the initial state.
    torch.manual_seed(args.seed)
    probs = sample_noise_simplex(
        (args.n_samples, args.max_len),
        tokenizer.vocab_size,
        device,
        eps,
        noise_mode="dirichlet",
        target_prob=1.0,
        noise_sigma=-1.0,
        dirichlet_concentration=1.0,
    )

    # Resolve the optional fixed first token: the text form wins over the id
    # form, and only the first token of the encoded text is used.
    fixed_first_token_id: int | None = None
    if args.fixed_first_token_text:
        encoded = tokenizer.encode(args.fixed_first_token_text, add_eos=False, add_special_tokens=False)
        if not encoded:
            raise ValueError(f"fixed_first_token_text encoded to no tokens: {args.fixed_first_token_text!r}")
        fixed_first_token_id = int(encoded[0])
    elif args.fixed_first_token_id >= 0:
        fixed_first_token_id = int(args.fixed_first_token_id)

    # --fixed_first_initial_argmax takes precedence: each sample keeps the
    # argmax of its own initial noise at position 0.
    fixed_first_ids: torch.Tensor | None = None
    if args.fixed_first_initial_argmax:
        fixed_first_ids = probs[:, 0, :].argmax(dim=-1)
    elif fixed_first_token_id is not None:
        fixed_first_ids = torch.full((args.n_samples,), fixed_first_token_id, dtype=torch.long, device=device)
    probs = clamp_first_position(probs, fixed_first_ids)
    attn = torch.ones((args.n_samples, args.max_len), dtype=torch.bool, device=device)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    stem = f"{args.name}_sample{args.sample_idx}"
    focus_tag = f"focus_{args.focus_start}_{args.focus_end}"
    full_tsv = out_dir / f"{stem}_position_top1_full.tsv"
    focus_tsv = out_dir / f"{stem}_position_top1_{focus_tag}.tsv"
    focus_html = out_dir / f"{stem}_position_top1_{focus_tag}.html"

    # Recorded steps are numbered 1..steps; drop any part of the focus window
    # outside that range so the HTML report never asks for a step that was
    # not decoded.
    focus_steps = [
        step_no
        for step_no in range(args.focus_start, args.focus_end + 1)
        if 1 <= step_no <= args.steps
    ]
    rows_by_step: dict[int, list[dict[str, object]]] = {}
    sample = args.sample_idx

    # Explicit UTF-8: decoded tokens are often non-ASCII and the platform
    # default encoding may be unable to represent them.
    with full_tsv.open("w", newline="", encoding="utf-8") as f_full, focus_tsv.open(
        "w", newline="", encoding="utf-8"
    ) as f_focus:
        fieldnames = [
            "step",
            "position",
            "input_token",
            "input_prob",
            "endpoint_token",
            "endpoint_prob",
            "post_token",
            "post_prob",
        ]
        full_writer = csv.DictWriter(f_full, fieldnames=fieldnames, delimiter="\t")
        focus_writer = csv.DictWriter(f_focus, fieldnames=fieldnames, delimiter="\t")
        full_writer.writeheader()
        focus_writer.writeheader()

        for step in range(args.steps):
            step_no = step + 1
            in_focus = args.focus_start <= step_no <= args.focus_end

            prev_probs = probs
            prev_ids = prev_probs.argmax(dim=-1)
            t = model_time_for_step("flow", step, args.steps, args.n_samples, device, dtype=torch.float32)
            temp = temperature(step, args.steps, args.early_temp, args.late_temp, args.temp_end, args.temp_power)
            logits = model(state_for_model(model, prev_probs, eps), t, attn).float()
            endpoint = F.softmax(logits / temp, dim=-1)
            endpoint_ids = endpoint.argmax(dim=-1)
            probs = apply_decode_update(
                decode_rule=args.decode_rule,
                probs=prev_probs,
                endpoint=endpoint,
                step=step,
                steps=args.steps,
                support_power=args.support_power,
                semantic_power=args.semantic_power,
                hybrid_switch=args.hybrid_switch,
                c_min=1.0,
                c_max=1024.0,
                eps=eps,
            )
            probs = clamp_first_position(probs, fixed_first_ids)
            post_ids = probs.argmax(dim=-1)

            # Pull the traced sample's top-1 ids and probabilities to the host
            # as plain Python lists, one bulk conversion per tensor.
            input_ids = prev_ids[sample].tolist()
            endpoint_ids_s = endpoint_ids[sample].tolist()
            post_ids_s = post_ids[sample].tolist()
            input_probs = prev_probs[sample].gather(1, prev_ids[sample].unsqueeze(-1)).squeeze(-1).tolist()
            endpoint_probs = endpoint[sample].gather(1, endpoint_ids[sample].unsqueeze(-1)).squeeze(-1).tolist()
            post_probs = probs[sample].gather(1, post_ids[sample].unsqueeze(-1)).squeeze(-1).tolist()

            step_rows: list[dict[str, object]] = [
                {
                    "step": step_no,
                    "position": pos,
                    "input_token": decode_token(tokenizer, input_ids[pos]),
                    "input_prob": f"{input_probs[pos]:.8f}",
                    "endpoint_token": decode_token(tokenizer, endpoint_ids_s[pos]),
                    "endpoint_prob": f"{endpoint_probs[pos]:.8f}",
                    "post_token": decode_token(tokenizer, post_ids_s[pos]),
                    "post_prob": f"{post_probs[pos]:.8f}",
                }
                for pos in range(args.max_len)
            ]
            full_writer.writerows(step_rows)
            if in_focus:
                focus_writer.writerows(step_rows)
                rows_by_step[step_no] = step_rows

            if step_no % 16 == 0 or step == 0:
                print(f"{args.name} wrote step {step_no}", flush=True)

    if focus_steps:
        write_html(
            focus_html,
            title=f"{args.name} sample {args.sample_idx} position top1 focus {args.focus_start}-{args.focus_end}",
            focus_steps=focus_steps,
            rows_by_step=rows_by_step,
        )
    else:
        print(
            f"{args.name} focus window {args.focus_start}-{args.focus_end} contains no decoded step "
            f"(steps=1..{args.steps}); skipping HTML",
            file=sys.stderr,
            flush=True,
        )

    meta = {
        "checkpoint": args.checkpoint,
        "tokenizer_path": args.tokenizer_path,
        "name": args.name,
        "max_len": args.max_len,
        "n_samples": args.n_samples,
        "sample_idx": args.sample_idx,
        "steps": args.steps,
        "fixed_first_token_id": fixed_first_token_id,
        "fixed_first_token_text": args.fixed_first_token_text,
        "fixed_first_initial_argmax": bool(args.fixed_first_initial_argmax),
        "full_tsv": str(full_tsv),
        "focus_tsv": str(focus_tsv),
    }
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # written as UTF-8 rather than in the platform default encoding.
    (out_dir / f"{stem}_position_top1_meta.json").write_text(
        json.dumps(meta, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    print("WROTE", out_dir)
|
|
|
|
# Run the trace only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|