| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from collections import Counter |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
# Make the project-local imports that follow resolvable when this file is
# run directly as a script: both the directory one level above this script's
# directory and the script's own directory are put on sys.path.
SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = Path(__file__).resolve().parents[1]
# Each entry is inserted at position 0, so when neither is present yet
# SCRIPT_DIR ends up ahead of REPO_ROOT in the search order.
for p in (REPO_ROOT, SCRIPT_DIR):
    if str(p) in sys.path:
        continue
    sys.path.insert(0, str(p))
|
|
| from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model |
| from flowtext_lab.tokenization import BpeTextTokenizer |
| from infer_context_compare_from_c128 import build_model, temperature |
| from trace_decode_basin import apply_decode_update |
|
|
|
|
def top1_for_sample(
    ids: torch.Tensor,
    probs: torch.Tensor,
    tokenizer: BpeTextTokenizer,
    sample_idx: int,
) -> dict[str, object]:
    """Summarise the most frequent token id within one sample.

    Args:
        ids: Token ids, indexed as ``ids[sample_idx]`` to get one id per
            sequence position (callers in this file pass
            ``probs.argmax(dim=-1)``).
        probs: Per-position distributions, indexed as
            ``probs[sample_idx, position, token_id]``.
        tokenizer: Only used to decode the winning id into display text.
        sample_idx: Index of the sample to summarise.

    Returns:
        A dict with:
          * ``id`` - the most frequent id in the sample (``-1`` if the sample
            has no positions),
          * ``text`` - its decoded text with newlines shown as ``\\n``,
          * ``frac`` - share of positions whose id equals ``id``,
          * ``mean_p_all_pos`` - mean of ``probs[..., id]`` over all positions,
          * ``mean_p_on_argmax_pos`` - the same mean restricted to positions
            whose id equals ``id``,
          * ``mean_max_p`` - mean over positions of the per-position maximum
            probability.
    """
    sample_ids = ids[sample_idx]
    row = sample_ids.detach().cpu().tolist()

    # A zero-length sequence has no "most common" token; Counter.most_common
    # would return [] and indexing it would raise IndexError. Return a
    # neutral summary instead so callers can still format/serialise it.
    if not row:
        return {
            "id": -1,
            "text": "",
            "frac": 0.0,
            "mean_p_all_pos": 0.0,
            "mean_p_on_argmax_pos": 0.0,
            "mean_max_p": 0.0,
        }

    # Ties are resolved by Counter in first-seen order.
    tid, count = Counter(row).most_common(1)[0]
    mask = sample_ids == tid
    pvals = probs[sample_idx, :, tid]

    text = tokenizer.decode([int(tid)], stop_at_eos=False, skip_special_tokens=False)
    on_argmax = float(pvals[mask].mean().detach().cpu()) if bool(mask.any()) else 0.0
    return {
        "id": int(tid),
        # Keep the text on a single line for table/console output.
        "text": text.replace("\n", "\\n"),
        "frac": float(count / len(row)),
        "mean_p_all_pos": float(pvals.mean().detach().cpu()),
        "mean_p_on_argmax_pos": on_argmax,
        "mean_max_p": float(probs[sample_idx].max(dim=-1).values.mean().detach().cpu()),
    }
|
|
|
|
def fmt_cell(item: dict[str, object]) -> str:
    """Render one top-1 summary as the content of a Markdown table cell.

    Args:
        item: Mapping with at least ``text``, ``frac`` and
            ``mean_p_on_argmax_pos`` (the shape returned by
            ``top1_for_sample``).

    Returns:
        ``"`<text>` <frac as percent>% / p=<mean_p_on_argmax_pos>"``; an empty
        text is shown as ``<empty>``.
    """
    text = str(item["text"]) or "<empty>"
    # An unescaped pipe ends a table cell even inside a code span, which would
    # shift every following column of the row; escape it so tokens such as
    # "|" stay inside their cell.
    text = text.replace("|", "\\|")
    frac_pct = float(item["frac"]) * 100
    mean_p = float(item["mean_p_on_argmax_pos"])
    return f"`{text}` {frac_pct:.1f}% / p={mean_p:.3f}"
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse the command line and reject a --sample_idx outside the batch."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--tokenizer_path", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--name", required=True)
    ap.add_argument("--max_len", type=int, required=True)
    ap.add_argument("--n_samples", type=int, required=True)
    ap.add_argument("--sample_idx", type=int, default=0)
    ap.add_argument("--steps", type=int, default=128)
    ap.add_argument("--decode_rule", default="dirichlet_resample")
    ap.add_argument("--seed", type=int, default=314159)
    ap.add_argument("--pos_extend", default="repeat")
    ap.add_argument("--support_power", type=float, default=1.0)
    ap.add_argument("--semantic_power", type=float, default=1.5)
    ap.add_argument("--early_temp", type=float, default=2.8)
    ap.add_argument("--late_temp", type=float, default=1.45)
    ap.add_argument("--temp_end", type=float, default=0.55)
    ap.add_argument("--temp_power", type=float, default=1.5)
    ap.add_argument("--hybrid_switch", type=float, default=0.5)
    args = ap.parse_args()

    # Fail before the checkpoint is loaded rather than with an IndexError in
    # the middle of decoding. Negative indices stay valid because tensor
    # indexing accepts them.
    if not -args.n_samples <= args.sample_idx < args.n_samples:
        ap.error(
            f"--sample_idx {args.sample_idx} is out of range for --n_samples {args.n_samples}"
        )
    return args


def _markdown_table(
    title: str,
    rows: list[dict[str, object]],
    *,
    include_t: bool,
    note: str | None = None,
) -> str:
    """Build one Markdown report (title, optional note, table) from trace rows."""
    lines = [title, ""]
    if note is not None:
        lines += [note, ""]
    if include_t:
        lines += [
            "| step | t | input top1 | endpoint top1 | post-update top1 | endpoint mean max-p |",
            "|---:|---:|---|---|---|---:|",
        ]
    else:
        lines += [
            "| step | input top1 | endpoint top1 | post-update top1 | endpoint mean max-p |",
            "|---:|---|---|---|---:|",
        ]
    for row in rows:
        endpoint = row["endpoint"]
        t_cell = f" {row['t']:.3f} |" if include_t else ""
        lines.append(
            f"| {row['step']} |{t_cell} {fmt_cell(row['input'])} | "
            f"{fmt_cell(endpoint)} | {fmt_cell(row['post'])} | {float(endpoint['mean_max_p']):.3f} |"
        )
    return "\n".join(lines) + "\n"


@torch.inference_mode()
def main() -> None:
    """Run a decode trace and write top-1 token summaries for one sample.

    Loads the tokenizer and checkpoint, builds the model, draws the initial
    state with ``sample_noise_simplex`` and applies ``--steps`` decode
    updates. For every step the most frequent argmax token of the selected
    sample is recorded for the model input, the softmax endpoint and the
    post-update state. Progress is printed on step 1 and every 16th step.

    Outputs written to ``--out_dir`` (all UTF-8):
      * ``<name>_sample<idx>_top1_trace.json`` - every recorded row,
      * ``<name>_sample<idx>_top1_trace.md`` - the same rows as a table,
      * ``<name>_sample<idx>_focus_40_60.md`` - table of steps 40..60 only.
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only point --checkpoint at files from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False, mmap=True)
    model = build_model(ckpt, tokenizer, args.max_len, device, args.pos_extend)
    eps = 1e-8
    torch.manual_seed(args.seed)
    probs = sample_noise_simplex(
        (args.n_samples, args.max_len),
        tokenizer.vocab_size,
        device,
        eps,
        noise_mode="dirichlet",
        target_prob=1.0,
        noise_sigma=-1.0,
        dirichlet_concentration=1.0,
    )
    attn = torch.ones((args.n_samples, args.max_len), dtype=torch.bool, device=device)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    rows: list[dict[str, object]] = []

    for step in range(args.steps):
        prev_probs = probs
        prev_ids = prev_probs.argmax(dim=-1)
        t = model_time_for_step("flow", step, args.steps, args.n_samples, device, dtype=torch.float32)
        temp = temperature(step, args.steps, args.early_temp, args.late_temp, args.temp_end, args.temp_power)
        logits = model(state_for_model(model, prev_probs, eps), t, attn).float()
        endpoint = F.softmax(logits / temp, dim=-1)
        endpoint_ids = endpoint.argmax(dim=-1)
        probs = apply_decode_update(
            decode_rule=args.decode_rule,
            probs=prev_probs,
            endpoint=endpoint,
            step=step,
            steps=args.steps,
            support_power=args.support_power,
            semantic_power=args.semantic_power,
            hybrid_switch=args.hybrid_switch,
            c_min=1.0,
            c_max=1024.0,
            eps=eps,
        )
        post_ids = probs.argmax(dim=-1)
        # Steps are reported 1-based; "t" is the fraction of steps completed.
        row = {
            "step": step + 1,
            "t": float((step + 1) / args.steps),
            "input": top1_for_sample(prev_ids, prev_probs, tokenizer, args.sample_idx),
            "endpoint": top1_for_sample(endpoint_ids, endpoint, tokenizer, args.sample_idx),
            "post": top1_for_sample(post_ids, probs, tokenizer, args.sample_idx),
        }
        rows.append(row)
        if (step + 1) % 16 == 0 or step == 0:
            print(
                args.name,
                "step",
                step + 1,
                "input",
                row["input"]["text"],
                row["input"]["frac"],
                "endpoint",
                row["endpoint"]["text"],
                row["endpoint"]["frac"],
                "post",
                row["post"]["text"],
                row["post"]["frac"],
                flush=True,
            )

    stem = f"{args.name}_sample{args.sample_idx}"
    # The JSON keeps non-ASCII token text verbatim (ensure_ascii=False), so
    # the encoding is pinned instead of relying on the platform default.
    (out_dir / f"{stem}_top1_trace.json").write_text(
        json.dumps(rows, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )

    full_report = _markdown_table(
        f"# {args.name} sample {args.sample_idx} top1 trace",
        rows,
        include_t=True,
        note=(
            "Cell format: token = fraction of sequence positions whose argmax is that token"
            " / mean probability on those positions."
        ),
    )
    (out_dir / f"{stem}_top1_trace.md").write_text(full_report, encoding="utf-8")

    focus = [row for row in rows if 40 <= int(row["step"]) <= 60]
    focus_report = _markdown_table(
        f"# {args.name} sample {args.sample_idx} focus steps 40-60",
        focus,
        include_t=False,
    )
    (out_dir / f"{stem}_focus_40_60.md").write_text(focus_report, encoding="utf-8")
    print("WROTE", out_dir)
|
|
|
|
# Run the trace only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|