lta / LTA_openwebtext_dualt /scripts /dump_one_sample_top1_trace.py
JinghuiLuAstronaut's picture
Add files using upload-large-folder tool
a96e98b verified
Raw
History Blame Contribute Delete
7.01 kB
from __future__ import annotations
import argparse
import json
import sys
from collections import Counter
from pathlib import Path
import torch
import torch.nn.functional as F
# Make project code importable when this file is run directly as a script:
# REPO_ROOT is the directory two levels above this file (the one holding
# ``flowtext_lab``), SCRIPT_DIR is the directory containing this file (where
# the sibling helper modules imported below live).
SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = Path(__file__).resolve().parents[1]
# Each root is pushed to the front of sys.path, REPO_ROOT first, so SCRIPT_DIR
# ends up ahead of REPO_ROOT when neither was already present.
for _import_root in (str(REPO_ROOT), str(SCRIPT_DIR)):
    if _import_root not in sys.path:
        sys.path.insert(0, _import_root)
from flowtext_lab.decode import model_time_for_step, sample_noise_simplex, state_for_model
from flowtext_lab.tokenization import BpeTextTokenizer
from infer_context_compare_from_c128 import build_model, temperature
from trace_decode_basin import apply_decode_update
def top1_for_sample(
    ids: torch.Tensor,
    probs: torch.Tensor,
    tokenizer: BpeTextTokenizer,
    sample_idx: int,
) -> dict[str, object]:
    """Summarise the most frequent token id in one row of a batch.

    ``ids[sample_idx]`` is read as a sequence of token ids and
    ``probs[sample_idx]`` as the matching per-position probabilities, indexed
    ``[position, token_id]``.

    Returns a dict, in this key order:
      * ``id`` - the most common id in the row (ties go to the id seen first);
      * ``text`` - that id decoded on its own, with newlines shown as ``\\n``;
      * ``frac`` - fraction of positions holding that id;
      * ``mean_p_all_pos`` - mean probability of that id across all positions;
      * ``mean_p_on_argmax_pos`` - mean probability of that id across the
        positions that hold it (0.0 if there are none);
      * ``mean_max_p`` - mean over positions of the largest probability.
    """
    seq_ids = ids[sample_idx]
    seq_probs = probs[sample_idx]

    token_list = seq_ids.detach().cpu().tolist()
    # Counter.most_common orders equal counts by first occurrence.
    top_id, top_count = Counter(token_list).most_common(1)[0]
    n_positions = len(token_list) or 1  # never divide by zero

    top_id_probs = seq_probs[:, top_id]
    holds_top = seq_ids == top_id
    if holds_top.any():
        p_where_top = top_id_probs[holds_top].mean().item()
    else:
        p_where_top = 0.0

    decoded = tokenizer.decode([int(top_id)], stop_at_eos=False, skip_special_tokens=False)
    return {
        "id": int(top_id),
        "text": decoded.replace("\n", "\\n"),
        "frac": float(top_count / n_positions),
        "mean_p_all_pos": top_id_probs.mean().item(),
        "mean_p_on_argmax_pos": p_where_top,
        "mean_max_p": seq_probs.max(dim=-1).values.mean().item(),
    }
def fmt_cell(item: dict[str, object]) -> str:
    r"""Render one top-1 summary as a Markdown table cell.

    The cell is the token text as an inline code span, followed by
    ``item["frac"]`` as a percentage and ``item["mean_p_on_argmax_pos"]``,
    e.g. ``the`` 12.5% / p=0.431.  An empty token is shown as ``<empty>``.

    The token text is made safe for a GitHub-flavoured Markdown table:
      * ``|`` is written as ``\|`` - GFM splits cells on unescaped pipes even
        inside code spans;
      * carriage returns and newlines are shown as ``\r`` / ``\n`` so they
        cannot end the table row;
      * the code-span fence is made longer than any backtick run inside the
        token (padded with spaces when the token starts or ends with a
        backtick), so the span cannot close early.
    Ordinary tokens render exactly as ``` `token` ```.
    """
    text = str(item["text"]) or "<empty>"
    text = text.translate(str.maketrans({"|": "\\|", "\r": "\\r", "\n": "\\n"}))

    # A fence of length n that does not occur in the text is longer than every
    # backtick run inside it, which is what a CommonMark code span requires.
    fence = "`"
    while fence in text:
        fence += "`"
    if text.startswith("`") or text.endswith("`"):
        # CommonMark strips exactly one padding space from each side.
        text = f" {text} "

    frac_pct = float(item["frac"]) * 100
    mean_p = float(item["mean_p_on_argmax_pos"])
    return f"{fence}{text}{fence} {frac_pct:.1f}% / p={mean_p:.3f}"
def _parse_args() -> argparse.Namespace:
    """Parse the command line and reject index/interval values that would fail mid-run."""
    ap = argparse.ArgumentParser(
        description="Dump a per-step top-1 token trace for one sample of a decode run."
    )
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--tokenizer_path", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--name", required=True)
    ap.add_argument("--max_len", type=int, required=True)
    ap.add_argument("--n_samples", type=int, required=True)
    ap.add_argument("--sample_idx", type=int, default=0)
    ap.add_argument("--steps", type=int, default=128)
    ap.add_argument("--decode_rule", default="dirichlet_resample")
    ap.add_argument("--seed", type=int, default=314159)
    ap.add_argument("--pos_extend", default="repeat")
    ap.add_argument("--support_power", type=float, default=1.0)
    ap.add_argument("--semantic_power", type=float, default=1.5)
    ap.add_argument("--early_temp", type=float, default=2.8)
    ap.add_argument("--late_temp", type=float, default=1.45)
    ap.add_argument("--temp_end", type=float, default=0.55)
    ap.add_argument("--temp_power", type=float, default=1.5)
    ap.add_argument("--hybrid_switch", type=float, default=0.5)
    # Defaults match the previously hard-coded 40-60 focus window and 16-step log interval.
    ap.add_argument("--focus_start", type=int, default=40, help="first step (1-based) of the focus table")
    ap.add_argument("--focus_end", type=int, default=60, help="last step (1-based, inclusive) of the focus table")
    ap.add_argument("--print_every", type=int, default=16, help="print a progress line every N steps")
    args = ap.parse_args()

    if args.n_samples < 1:
        ap.error("--n_samples must be >= 1")
    if args.max_len < 1:
        ap.error("--max_len must be >= 1")
    # Fail here instead of with an IndexError after model load and the first forward pass.
    if not -args.n_samples <= args.sample_idx < args.n_samples:
        ap.error(f"--sample_idx {args.sample_idx} is out of range for --n_samples {args.n_samples}")
    if args.print_every < 1:
        ap.error("--print_every must be >= 1")
    return args


def _trace_table_lines(rows: list[dict[str, object]], *, with_t: bool) -> list[str]:
    """Render trace rows as Markdown table body lines, optionally including the ``t`` column."""
    lines: list[str] = []
    for row in rows:
        endpoint = row["endpoint"]
        cells = [str(row["step"])]
        if with_t:
            cells.append(f"{row['t']:.3f}")
        cells.extend(
            [
                fmt_cell(row["input"]),
                fmt_cell(endpoint),
                fmt_cell(row["post"]),
                f"{float(endpoint['mean_max_p']):.3f}",
            ]
        )
        lines.append("| " + " | ".join(cells) + " |")
    return lines


def _write_reports(out_dir: Path, args: argparse.Namespace, rows: list[dict[str, object]]) -> None:
    """Write the JSON trace, the full Markdown table and the focus-window Markdown table."""
    stem = f"{args.name}_sample{args.sample_idx}"
    # Explicit UTF-8: ensure_ascii=False plus decoded token text (which may contain
    # U+FFFD or other non-ASCII characters) would fail under a non-UTF-8 locale.
    (out_dir / f"{stem}_top1_trace.json").write_text(
        json.dumps(rows, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    lines = [
        f"# {args.name} sample {args.sample_idx} top1 trace",
        "",
        "Cell format: token = fraction of sequence positions whose argmax is that token / mean probability on those positions.",
        "",
        "| step | t | input top1 | endpoint top1 | post-update top1 | endpoint mean max-p |",
        "|---:|---:|---|---|---|---:|",
        *_trace_table_lines(rows, with_t=True),
    ]
    (out_dir / f"{stem}_top1_trace.md").write_text("\n".join(lines) + "\n", encoding="utf-8")

    lo, hi = args.focus_start, args.focus_end
    focus = [row for row in rows if lo <= int(row["step"]) <= hi]
    focus_lines = [
        f"# {args.name} sample {args.sample_idx} focus steps {lo}-{hi}",
        "",
        "| step | input top1 | endpoint top1 | post-update top1 | endpoint mean max-p |",
        "|---:|---|---|---|---:|",
        *_trace_table_lines(focus, with_t=False),
    ]
    (out_dir / f"{stem}_focus_{lo}_{hi}.md").write_text("\n".join(focus_lines) + "\n", encoding="utf-8")


@torch.inference_mode()
def main() -> None:
    """Trace the top-1 token of one sample across every decode step.

    Starts a batch from Dirichlet noise on the simplex and runs ``--steps``
    decode steps. At each step it records, for sample ``--sample_idx``, a
    ``top1_for_sample`` summary of three distributions:
      * the input state;
      * the model's temperature-scaled softmax (the "endpoint");
      * the state after ``apply_decode_update``.

    Writes to ``--out_dir``:
      * ``{name}_sample{idx}_top1_trace.json``;
      * ``{name}_sample{idx}_top1_trace.md``;
      * ``{name}_sample{idx}_focus_{start}_{end}.md``.
    """
    args = _parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    # SECURITY: weights_only=False unpickles arbitrary Python objects; only pass trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False, mmap=True)
    model = build_model(ckpt, tokenizer, args.max_len, device, args.pos_extend)
    eps = 1e-8  # passed to the noise sampler, state_for_model and the decode update
    torch.manual_seed(args.seed)
    probs = sample_noise_simplex(
        (args.n_samples, args.max_len),
        tokenizer.vocab_size,
        device,
        eps,
        noise_mode="dirichlet",
        target_prob=1.0,
        noise_sigma=-1.0,
        dirichlet_concentration=1.0,
    )
    attn = torch.ones((args.n_samples, args.max_len), dtype=torch.bool, device=device)
    # Create the output directory up front so a bad path fails before the decode loop.
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    rows: list[dict[str, object]] = []
    for step in range(args.steps):
        prev_probs = probs
        prev_ids = prev_probs.argmax(dim=-1)
        t = model_time_for_step("flow", step, args.steps, args.n_samples, device, dtype=torch.float32)
        temp = temperature(step, args.steps, args.early_temp, args.late_temp, args.temp_end, args.temp_power)
        logits = model(state_for_model(model, prev_probs, eps), t, attn).float()
        endpoint = F.softmax(logits / temp, dim=-1)
        endpoint_ids = endpoint.argmax(dim=-1)
        probs = apply_decode_update(
            decode_rule=args.decode_rule,
            probs=prev_probs,
            endpoint=endpoint,
            step=step,
            steps=args.steps,
            support_power=args.support_power,
            semantic_power=args.semantic_power,
            hybrid_switch=args.hybrid_switch,
            c_min=1.0,
            c_max=1024.0,
            eps=eps,
        )
        post_ids = probs.argmax(dim=-1)
        row = {
            "step": step + 1,
            # Logged as (step + 1) / steps; NOTE(review): confirm this matches
            # the convention of model_time_for_step used for the model input.
            "t": float((step + 1) / args.steps),
            "input": top1_for_sample(prev_ids, prev_probs, tokenizer, args.sample_idx),
            "endpoint": top1_for_sample(endpoint_ids, endpoint, tokenizer, args.sample_idx),
            "post": top1_for_sample(post_ids, probs, tokenizer, args.sample_idx),
        }
        rows.append(row)
        if (step + 1) % args.print_every == 0 or step == 0:
            print(
                args.name,
                "step",
                step + 1,
                "input",
                row["input"]["text"],
                row["input"]["frac"],
                "endpoint",
                row["endpoint"]["text"],
                row["endpoint"]["frac"],
                "post",
                row["post"]["text"],
                row["post"]["frac"],
                flush=True,
            )

    _write_reports(out_dir, args, rows)
    print("WROTE", out_dir)


if __name__ == "__main__":
    main()