lta / LTA_openwebtext_dualt /scripts /dump_position_top1_trace.py
JinghuiLuAstronaut's picture
Add files using upload-large-folder tool
edff6fa verified
Raw
History Blame Contribute Delete
11.3 kB
from __future__ import annotations
import argparse
import csv
import html
import json
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
# Directory one level above the folder that holds this script
# (parents[0] would be the script's own folder).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Folder that holds this script.
SCRIPT_DIR = Path(__file__).resolve().parent
# Prepend both folders to sys.path unless they are already listed. Presumably
# this is what lets the project-local imports that follow resolve when the file
# is run directly as a script -- TODO confirm against the repository layout.
for p in (REPO_ROOT, SCRIPT_DIR):
    if str(p) not in sys.path:
        sys.path.insert(0, str(p))
from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model
from flowtext_lab.tokenization import BpeTextTokenizer
from infer_context_compare_from_c128 import build_model, clamp_first_position, temperature
from trace_decode_basin import apply_decode_update
def decode_token(tokenizer: BpeTextTokenizer, tid: int) -> str:
    """Decode a single token id into a one-line string.

    The id is decoded on its own, with ``stop_at_eos`` and
    ``skip_special_tokens`` both disabled. Literal newline and tab characters
    in the decoded text are replaced by their two-character backslash escapes
    (backslash-n and backslash-t) so the token stays on one line.
    """
    decoded = tokenizer.decode(
        [int(tid)],
        stop_at_eos=False,
        skip_special_tokens=False,
    )
    # Apply the replacements in the same order as a chained .replace() would.
    for raw, shown in (("\n", "\\n"), ("\t", "\\t")):
        decoded = decoded.replace(raw, shown)
    return decoded
def cell(token: str, prob: float) -> str:
    """Render one ``<td>`` showing a token and its probability.

    The probability, clamped to [0, 1], drives the opacity of a blue
    background (0.08 at zero, 0.60 at one). Text is dark while the clamped
    value is below 0.55 and white otherwise. The unclamped probability is what
    gets printed in the tooltip and under the token; the token text itself is
    HTML-escaped.
    """
    alpha = min(max(prob, 0.0), 1.0)
    opacity = 0.08 + 0.52 * alpha
    if alpha < 0.55:
        text_color = "#111"
    else:
        text_color = "#fff"
    background = f"rgba(43, 113, 220, {opacity:.3f})"
    td_open = f'<td style="background:{background};color:{text_color}" title="p={prob:.4f}">'
    token_span = f'<span class="tok">{html.escape(token)}</span>'
    prob_span = f'<span class="prob">{prob:.3f}</span>'
    return f"{td_open}{token_span}<br>{prob_span}</td>"
def write_html(
    path: Path,
    *,
    title: str,
    focus_steps: list[int],
    rows_by_step: dict[int, list[dict[str, object]]],
) -> None:
    """Write a standalone HTML table of per-position top-1 tokens.

    The table has one row per position and, for every rendered step, three
    columns (input / endpoint / post), each produced by ``cell``.

    Args:
        path: Destination file; it is written as UTF-8 to match the
            ``<meta charset>`` declared in the page.
        title: Page title and heading (HTML-escaped).
        focus_steps: Steps to show, in column order. Steps that have no entry
            in ``rows_by_step`` are skipped instead of raising ``KeyError``.
        rows_by_step: Maps a step to its list of row dicts (indexed by
            position). Each row must provide ``input_token``, ``input_prob``,
            ``endpoint_token``, ``endpoint_prob``, ``post_token`` and
            ``post_prob``.
    """
    # Only steps that were actually recorded can be rendered; a focus range
    # that reaches past the number of executed steps must not abort the dump.
    steps = [s for s in focus_steps if s in rows_by_step]

    lines = [
        "<!doctype html><html><head><meta charset='utf-8'>",
        f"<title>{html.escape(title)}</title>",
        "<style>",
        "body{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',sans-serif;margin:18px;color:#111}",
        "table{border-collapse:collapse;font-size:11px;line-height:1.15}",
        "th,td{border:1px solid #ddd;padding:3px 5px;min-width:62px;max-width:110px;vertical-align:top;overflow:hidden}",
        "th{position:sticky;top:0;background:#f7f7f7;z-index:2}",
        ".pos{position:sticky;left:0;background:#fff;z-index:1;font-weight:600;min-width:48px}",
        ".tok{font-family:ui-monospace,SFMono-Regular,Menlo,monospace;white-space:pre-wrap}",
        ".prob{font-size:10px;opacity:.75}",
        ".wrap{overflow:auto;max-height:88vh;border:1px solid #ddd}",
        ".phase{font-size:10px;color:#555}",
        "</style></head><body>",
        f"<h1>{html.escape(title)}</h1>",
        "<p>Each cell is this position's argmax token and probability. Color intensity tracks probability.</p>",
        "<div class='wrap'><table>",
        "<thead><tr><th class='pos'>pos</th>",
    ]
    # First header row: one three-column group per step.
    for step in steps:
        lines.append(f"<th colspan='3'>step {step}</th>")
    # Second header row: the three phases inside each step group.
    lines.append("</tr><tr><th class='pos'></th>")
    for _ in steps:
        lines.extend(["<th class='phase'>input</th>", "<th class='phase'>endpoint</th>", "<th class='phase'>post</th>"])
    lines.append("</tr></thead><tbody>")

    # default=0 keeps an empty step list from raising; the body is then empty.
    max_pos = max((len(rows_by_step[s]) for s in steps), default=0)
    for pos in range(max_pos):
        lines.append(f"<tr><td class='pos'>{pos}</td>")
        for step in steps:
            step_rows = rows_by_step[step]
            if pos >= len(step_rows):
                # A shorter step gets blank cells so the columns stay aligned.
                lines.extend(["<td></td>"] * 3)
                continue
            row = step_rows[pos]
            lines.append(cell(str(row["input_token"]), float(row["input_prob"])))
            lines.append(cell(str(row["endpoint_token"]), float(row["endpoint_prob"])))
            lines.append(cell(str(row["post_token"]), float(row["post_prob"])))
        lines.append("</tr>")
    lines.append("</tbody></table></div></body></html>")
    # Explicit encoding: decoded tokens are often non-ASCII and the locale
    # default may be unable to encode them or disagree with the declared charset.
    path.write_text("\n".join(lines), encoding="utf-8")
@torch.inference_mode()
def main() -> None:
    """Run one decode trajectory and dump per-position top-1 traces.

    Starting from a Dirichlet noise state, the model is stepped ``--steps``
    times. For the sample selected by ``--sample_idx`` every step records, per
    position, the argmax token and its probability at three points:

    * ``input``    - the state fed to the model at this step,
    * ``endpoint`` - ``softmax(logits / temperature)`` of the model output,
    * ``post``     - the state after ``apply_decode_update`` and first-position
      clamping.

    Steps are numbered from 1 in all outputs. Files written to ``--out_dir``:
    a TSV with every step, a TSV and an HTML table limited to the inclusive
    ``--focus_start``..``--focus_end`` range, and a JSON file with run metadata.
    All text outputs are UTF-8.

    Raises:
        ValueError: if ``--fixed_first_token_text`` encodes to no tokens.
        SystemExit: (via argparse) if ``--sample_idx`` does not index into
            ``--n_samples``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--tokenizer_path", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--name", required=True)
    ap.add_argument("--max_len", type=int, required=True)
    ap.add_argument("--n_samples", type=int, required=True)
    ap.add_argument("--sample_idx", type=int, default=0)
    ap.add_argument("--steps", type=int, default=128)
    ap.add_argument("--decode_rule", default="dirichlet_resample")
    ap.add_argument("--seed", type=int, default=314159)
    ap.add_argument("--pos_extend", default="repeat")
    ap.add_argument("--support_power", type=float, default=1.0)
    ap.add_argument("--semantic_power", type=float, default=1.5)
    ap.add_argument("--early_temp", type=float, default=2.8)
    ap.add_argument("--late_temp", type=float, default=1.45)
    ap.add_argument("--temp_end", type=float, default=0.55)
    ap.add_argument("--temp_power", type=float, default=1.5)
    ap.add_argument("--hybrid_switch", type=float, default=0.5)
    ap.add_argument("--fixed_first_token_id", type=int, default=-1)
    ap.add_argument("--fixed_first_token_text", default="")
    ap.add_argument("--fixed_first_initial_argmax", action="store_true")
    ap.add_argument("--focus_start", type=int, default=40)
    ap.add_argument("--focus_end", type=int, default=60)
    args = ap.parse_args()

    # Fail before the (slow) checkpoint load rather than with an IndexError in
    # the middle of the decode loop. Negative indices keep working as before.
    if not -args.n_samples <= args.sample_idx < args.n_samples:
        ap.error(f"--sample_idx {args.sample_idx} is out of range for --n_samples {args.n_samples}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only point
    # --checkpoint at files from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False, mmap=True)
    model = build_model(ckpt, tokenizer, args.max_len, device, args.pos_extend)
    eps = 1e-8

    torch.manual_seed(args.seed)
    probs = sample_noise_simplex(
        (args.n_samples, args.max_len),
        tokenizer.vocab_size,
        device,
        eps,
        noise_mode="dirichlet",
        target_prob=1.0,
        noise_sigma=-1.0,
        dirichlet_concentration=1.0,
    )

    # Resolve the optional fixed first token: text takes precedence over an id.
    fixed_first_token_id: int | None = None
    if args.fixed_first_token_text:
        encoded = tokenizer.encode(args.fixed_first_token_text, add_eos=False, add_special_tokens=False)
        if not encoded:
            raise ValueError(f"fixed_first_token_text encoded to no tokens: {args.fixed_first_token_text!r}")
        fixed_first_token_id = int(encoded[0])
    elif args.fixed_first_token_id >= 0:
        fixed_first_token_id = int(args.fixed_first_token_id)

    # Per-sample ids for position 0: either each sample's own initial argmax,
    # or the single fixed token repeated; None when neither option is set.
    fixed_first_ids: torch.Tensor | None = None
    if args.fixed_first_initial_argmax:
        fixed_first_ids = probs[:, 0, :].argmax(dim=-1)
    elif fixed_first_token_id is not None:
        fixed_first_ids = torch.full((args.n_samples,), fixed_first_token_id, dtype=torch.long, device=device)
    probs = clamp_first_position(probs, fixed_first_ids)
    attn = torch.ones((args.n_samples, args.max_len), dtype=torch.bool, device=device)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    stem = f"{args.name}_sample{args.sample_idx}"
    full_tsv = out_dir / f"{stem}_position_top1_full.tsv"
    focus_tsv = out_dir / f"{stem}_position_top1_focus_{args.focus_start}_{args.focus_end}.tsv"
    focus_steps = list(range(args.focus_start, args.focus_end + 1))
    rows_by_step: dict[int, list[dict[str, object]]] = {}

    # Explicit UTF-8: decoded tokens are frequently non-ASCII and the locale
    # default encoding may not be able to represent them.
    with full_tsv.open("w", newline="", encoding="utf-8") as f_full, focus_tsv.open(
        "w", newline="", encoding="utf-8"
    ) as f_focus:
        fieldnames = [
            "step",
            "position",
            "input_token",
            "input_prob",
            "endpoint_token",
            "endpoint_prob",
            "post_token",
            "post_prob",
        ]
        full_writer = csv.DictWriter(f_full, fieldnames=fieldnames, delimiter="\t")
        focus_writer = csv.DictWriter(f_focus, fieldnames=fieldnames, delimiter="\t")
        full_writer.writeheader()
        focus_writer.writeheader()

        for step in range(args.steps):
            prev_probs = probs
            prev_ids = prev_probs.argmax(dim=-1)
            t = model_time_for_step("flow", step, args.steps, args.n_samples, device, dtype=torch.float32)
            temp = temperature(step, args.steps, args.early_temp, args.late_temp, args.temp_end, args.temp_power)
            logits = model(state_for_model(model, prev_probs, eps), t, attn).float()
            endpoint = F.softmax(logits / temp, dim=-1)
            endpoint_ids = endpoint.argmax(dim=-1)
            probs = apply_decode_update(
                decode_rule=args.decode_rule,
                probs=prev_probs,
                endpoint=endpoint,
                step=step,
                steps=args.steps,
                support_power=args.support_power,
                semantic_power=args.semantic_power,
                hybrid_switch=args.hybrid_switch,
                c_min=1.0,
                c_max=1024.0,
                eps=eps,
            )
            probs = clamp_first_position(probs, fixed_first_ids)
            post_ids = probs.argmax(dim=-1)

            # Pull the traced sample to plain Python lists once per step instead
            # of converting one tensor element at a time inside the inner loop.
            s = args.sample_idx
            input_ids = prev_ids[s].tolist()
            endpoint_ids_s = endpoint_ids[s].tolist()
            post_ids_s = post_ids[s].tolist()
            # gather() picks, per position, the probability of the argmax token.
            input_probs = prev_probs[s].gather(1, prev_ids[s].unsqueeze(-1)).squeeze(-1).tolist()
            endpoint_probs = endpoint[s].gather(1, endpoint_ids[s].unsqueeze(-1)).squeeze(-1).tolist()
            post_probs = probs[s].gather(1, post_ids[s].unsqueeze(-1)).squeeze(-1).tolist()

            # Output steps are 1-based; the focus range is inclusive.
            in_focus = args.focus_start <= step + 1 <= args.focus_end
            step_rows: list[dict[str, object]] = []
            for pos in range(args.max_len):
                row = {
                    "step": step + 1,
                    "position": pos,
                    "input_token": decode_token(tokenizer, int(input_ids[pos])),
                    "input_prob": f"{float(input_probs[pos]):.8f}",
                    "endpoint_token": decode_token(tokenizer, int(endpoint_ids_s[pos])),
                    "endpoint_prob": f"{float(endpoint_probs[pos]):.8f}",
                    "post_token": decode_token(tokenizer, int(post_ids_s[pos])),
                    "post_prob": f"{float(post_probs[pos]):.8f}",
                }
                full_writer.writerow(row)
                if in_focus:
                    focus_writer.writerow(row)
                step_rows.append(row)
            if in_focus:
                rows_by_step[step + 1] = step_rows
            if (step + 1) % 16 == 0 or step == 0:
                print(f"{args.name} wrote step {step + 1}", flush=True)

    # The focus range may extend past the steps that actually ran (or start
    # below step 1); render only recorded steps so a finished run is not lost
    # to a KeyError while building the HTML.
    rendered_steps = [fs for fs in focus_steps if fs in rows_by_step]
    if rendered_steps:
        write_html(
            out_dir / f"{stem}_position_top1_focus_{args.focus_start}_{args.focus_end}.html",
            title=f"{args.name} sample {args.sample_idx} position top1 focus {args.focus_start}-{args.focus_end}",
            focus_steps=rendered_steps,
            rows_by_step=rows_by_step,
        )
    else:
        print(
            f"{args.name} no executed step falls in focus range "
            f"{args.focus_start}-{args.focus_end}; HTML not written",
            flush=True,
        )

    meta = {
        "checkpoint": args.checkpoint,
        "tokenizer_path": args.tokenizer_path,
        "name": args.name,
        "max_len": args.max_len,
        "n_samples": args.n_samples,
        "sample_idx": args.sample_idx,
        "steps": args.steps,
        "fixed_first_token_id": fixed_first_token_id,
        "fixed_first_token_text": args.fixed_first_token_text,
        "fixed_first_initial_argmax": bool(args.fixed_first_initial_argmax),
        "full_tsv": str(full_tsv),
        "focus_tsv": str(focus_tsv),
    }
    # ensure_ascii=False can emit non-ASCII text, so the encoding must be fixed.
    (out_dir / f"{stem}_position_top1_meta.json").write_text(
        json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    print("WROTE", out_dir)
# Run the trace dump only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()