# tito/scripts/verify_renderers.py
# Last change by qgallouedec (HF Staff):
#   Add renderers vs TITO comparison section (#1)
#   commit 76b5664
"""
Realistic end-to-end checks against the renderers library.

We exercise three claims:

1. bridge_to_next_turn returns ids that byte-for-byte extend
   prev_prompt_ids + prev_completion_ids.
2. The renderer-driven TITO loop produces the same ids the trainer would compute by
   appending sampled completion_ids verbatim — confirming the "no re-encoding" property.
3. The fragmentation arithmetic behind the "~3x throughput" framing: a clean N-turn
   rollout trains on L = N*T tokens; a fully-fragmented one trains on L*(N+1)/2.

We also probe the parser's behavior on user content that contains the literal string
"<tool_call>" to see whether token-level matching avoids the false positive.
"""
import json, time
from transformers import AutoTokenizer
from renderers import create_renderer
# Model under test, its Hugging Face tokenizer, and the renderer built for that
# tokenizer with renderer="auto". All later sections share these three names.
MODEL = "Qwen/Qwen3-0.6B"
tok = AutoTokenizer.from_pretrained(MODEL)
r = create_renderer(tok, renderer="auto")

# Report which concrete renderer class came back.
renderer_name = type(r).__name__
print(f"# renderer for {MODEL}: {renderer_name}\n")
# ---------------------------------------------------------------------------
# 1. Token-level vs string-level parsing — probe.
# ---------------------------------------------------------------------------
print("## 1. token-level parsing on user content that contains '<tool_call>'")

# Id this probe expects for the `<tool_call>` special token. It is tied to the
# vocabulary of MODEL — NOTE(review): re-check the value if MODEL changes.
TOOL_CALL_SPECIAL_ID = 151657

# Check: does the *tokenizer* promote literal "<tool_call>" to a special id?
literal_ids = tok.encode("<tool_call>", add_special_tokens=False)
# bool(...) so an empty encoding reports False instead of printing "[]".
promoted = bool(literal_ids) and literal_ids[0] == TOOL_CALL_SPECIAL_ID
print(f" tok.encode('<tool_call>') = {literal_ids}")
print(f" tok.encode('<tool_call>')[0] == {TOOL_CALL_SPECIAL_ID} (special id)? {promoted}")

# If yes, the parser would flag it. Try: render a user message that carries the
# literal and run the renderer's response parser over the rendered ids.
user_msg = [{"role": "user", "content": "I want <tool_call> as a literal here."}]
user_ids = r.render_ids(user_msg, add_generation_prompt=False)
parsed_user = r.parse_response(user_ids)
print(f" parse_response(rendered_user_msg) tool_calls: {parsed_user.tool_calls}")

# Only print the conclusion that the measurement above actually supports.
if promoted:
    print(" → the tokenizer promotes the literal to a special id, so the parser flags it")
    print(" (status UNCLOSED_BLOCK because there's no matching </tool_call>).")
    print(" Token-level parsing makes this visible; a string-level parser would see the same thing.")
else:
    print(" → the tokenizer did NOT promote the literal to the special id, so there is no special token for the parser to flag.")
print()
# ---------------------------------------------------------------------------
# 2. Realistic 4-turn rollout. Compare:
# - MITO: apply_chat_template(full_msgs) after each turn,
# - TITO via renderers: incrementally bridge_to_next_turn,
# - Manual TITO: dummy-diff via apply_chat_template (the §5 trick).
# ---------------------------------------------------------------------------
print("## 2. realistic multi-turn rollout: MITO vs renderer-TITO vs manual-TITO")


def _calc_call_turn(expr):
    """Return an assistant message with empty content and one `calc` tool call on *expr*."""
    return {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "type": "function",
                "function": {"name": "calc", "arguments": {"expr": expr}},
            },
        ],
    }


def _tool_turn(result):
    """Return a tool message whose content is *result*."""
    return {"role": "tool", "content": result}


# A 4-turn rollout: user → asst(tool_call) → tool → asst(tool_call) → tool → asst(answer)
seed_user = [{"role": "user", "content": "Compute (2+2)*3."}]
turn1_asst = _calc_call_turn("2+2")
turn1_tool = _tool_turn("4")
turn2_asst = _calc_call_turn("4*3")
turn2_tool = _tool_turn("12")
turn3_asst = {"role": "assistant", "content": "The answer is 12."}

# The finished conversation, in order; it holds the very same message objects.
full_msgs = [*seed_user, turn1_asst, turn1_tool, turn2_asst, turn2_tool, turn3_asst]
# --- MITO baseline: a single apply_chat_template render of the finished conversation ---
mito_ids = tok.apply_chat_template(full_msgs, return_dict=False)
mito_token_count = len(mito_ids)
print(f" MITO (apply_chat_template, full conv): {mito_token_count} tokens")
# --- renderer-TITO: simulate sampling each assistant turn and bridging. ---
def simulate(messages_so_far, asst_message, prev_prompt_ids):
    """Return stand-in 'sampled' completion ids for one assistant turn.

    The conversation up to and including *asst_message* is rendered with the
    module-level renderer ``r`` (no generation prompt), and the first
    ``len(prev_prompt_ids)`` ids of that render are dropped.

    Only the *length* of *prev_prompt_ids* is used: the render is not checked
    to actually start with those ids.
    """
    rendered = r.render_ids(messages_so_far + [asst_message], add_generation_prompt=False)
    prompt_len = len(prev_prompt_ids)
    return rendered[prompt_len:]
# Drive the rollout token-in/token-out: start from the rendered seed prompt,
# append each simulated completion verbatim, and let the renderer's
# bridge_to_next_turn supply the ids that follow it (tool turns + next prompt).
prompt_ids = r.render_ids(seed_user, add_generation_prompt=True)
buffer = list(prompt_ids)
bridge_ok = True  # flips to False if the bridge declines, so a partial buffer is never reported
for asst_msg, after_msgs in [
    (turn1_asst, [turn1_tool]),
    (turn2_asst, [turn2_tool]),
    (turn3_asst, []),
]:
    msgs_before = full_msgs[: full_msgs.index(asst_msg)]
    completion_ids = simulate(msgs_before, asst_msg, prompt_ids)
    buffer.extend(completion_ids)
    if not after_msgs:
        # Final assistant turn: nothing follows, so there is nothing to bridge.
        continue
    next_prompt = r.bridge_to_next_turn(
        previous_prompt_ids=prompt_ids,
        previous_completion_ids=completion_ids,
        new_messages=after_msgs,
    )
    if next_prompt is None:
        print(" renderer-TITO: bridge returned None — would fall back to full re-render")
        bridge_ok = False
        break
    next_ids = list(next_prompt.token_ids)
    prefix = prompt_ids + completion_ids
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if next_ids[: len(prefix)] != prefix:
        raise AssertionError("bridge violated byte-for-byte extension!")
    buffer.extend(next_ids[len(prefix):])
    prompt_ids = next_ids

# A rollout cut short by a failed bridge is incomplete: expose it as None
# (mirroring manual_tito_ids) instead of passing the partial buffer off as a result.
tito_ids = buffer if bridge_ok else None
if tito_ids is not None:
    print(f" renderer-TITO (incremental bridge): {len(tito_ids)} tokens")
# --- Manual-TITO using only apply_chat_template (the §5 dummy-diff trick) ---
def compute_delta(messages_prefix, tool_msgs, tokenizer):
    """Return the ids that *tool_msgs* plus a generation prompt add after *messages_prefix*.

    Renders *messages_prefix* alone, then *messages_prefix* + *tool_msgs* with
    ``add_generation_prompt=True``, both through ``tokenizer.apply_chat_template``.
    If the longer render starts with the shorter one, the trailing ids are
    returned; otherwise the prefix property is violated and the result is None.
    """
    base_ids = tokenizer.apply_chat_template(messages_prefix, return_dict=False)
    extended_ids = tokenizer.apply_chat_template(
        messages_prefix + tool_msgs,
        return_dict=False,
        add_generation_prompt=True,
    )
    cut = len(base_ids)
    head, tail = extended_ids[:cut], extended_ids[cut:]
    return tail if head == base_ids else None
# Same rollout as the renderer loop, but driven by the tokenizer alone:
# completions come from slicing apply_chat_template renders, and the ids after
# each completion come from compute_delta (no renderer library involved).
prompt_ids2 = tok.apply_chat_template(seed_user, return_dict=False, add_generation_prompt=True)
buffer2 = list(prompt_ids2)
ok = True
manual_schedule = (
    (turn1_asst, [turn1_tool]),
    (turn2_asst, [turn2_tool]),
    (turn3_asst, []),
)
for asst_msg, after_msgs in manual_schedule:
    history = full_msgs[: full_msgs.index(asst_msg)]
    through_asst = history + [asst_msg]
    # The "sampled" completion is whatever the render through this turn adds
    # beyond the length of the generation-prompt render of the history.
    rendered = tok.apply_chat_template(through_asst, return_dict=False)
    gen_prompt = tok.apply_chat_template(history, return_dict=False, add_generation_prompt=True)
    buffer2 += rendered[len(gen_prompt):]
    if not after_msgs:
        continue
    delta = compute_delta(through_asst, after_msgs, tok)
    if delta is None:
        print(" manual-TITO: prefix property violated — would need a per-family fix")
        ok = False
        break
    buffer2 += delta

# None marks a rollout that could not be completed with the dummy-diff trick.
manual_tito_ids = buffer2 if ok else None
if manual_tito_ids is not None:
    print(f" manual-TITO (dummy-diff with tokenizer):{len(manual_tito_ids)} tokens")
# --- Compare the three id sequences pairwise ---
print()
mito_matches_renderer = mito_ids == tito_ids
print(f" MITO == renderer-TITO byte-for-byte? {mito_matches_renderer}")
if manual_tito_ids is not None:
    # The manual comparisons only make sense when the manual rollout completed.
    mito_matches_manual = mito_ids == manual_tito_ids
    print(f" MITO == manual-TITO byte-for-byte? {mito_matches_manual}")
    renderer_matches_manual = tito_ids == manual_tito_ids
    print(f" renderer-TITO == manual-TITO? {renderer_matches_manual}")
print()
# ---------------------------------------------------------------------------
# 3. Fragmentation arithmetic behind the ~3x throughput framing.
# ---------------------------------------------------------------------------
for line in (
    "## 3. fragmentation cost: clean vs fully-fragmented N-turn rollout",
    " Each turn carries ~T tokens; total clean length L = N*T.",
    " Clean (no fragmentation): 1 sample of length L → trainer cost ∝ L.",
    " Fully fragmented: N samples of length T,2T,…,NT → cost ∝ L*(N+1)/2.",
):
    print(line)
print()
print(" N | clean tokens | fragmented tokens | multiplier")
print(" --+--------------+-------------------+-----------")

T = 200  # tokens per turn; the same for every row of the table
for N in (2, 3, 4, 5, 6, 8, 10):
    L = N * T
    # Fragmented cost is the series T + 2T + … + NT, i.e. T*N*(N+1)/2.
    frag = sum(T * k for k in range(1, N + 1))
    mult = frag / L
    print(f" {N:>1} | {L:>12} | {frag:>17} | {mult:.2f}x")
print()
for line in (
    " → N=5 (typical multi-turn rollout), worst-case fragmentation = 3.00x.",
    " The library's '>3x throughput' framing is the worst-case bound for ~5-turn rollouts.",
    " The actual multiplier depends on the *break rate* per turn boundary.",
):
    print(line)
print()
# ---------------------------------------------------------------------------
# 4. Break-rate microbench. Construct rollouts with arguments that exercise
# boolean canonicalisation, and measure how often MITO and renderer-TITO
# actually diverge.
# ---------------------------------------------------------------------------
print("## 4. break-rate microbench")
import random
random.seed(0)  # fixed seed so the sampled rollout lengths are reproducible


def rollout(turns=4, boolean_arg=False):
    """Build a synthetic tool-calling conversation.

    The result is one user message, then ``turns - 1`` rounds of an assistant
    `probe` tool call (empty content) followed by a tool reply "ok", and a
    closing assistant message "done". Round ``i`` calls the tool with
    ``{"i": i}``; when *boolean_arg* is true a leading ``"dry_run": False``
    entry is added to the arguments.
    """
    conversation = [{"role": "user", "content": f"Run {turns} probes."}]
    for step in range(turns - 1):
        arguments = {"i": step}
        if boolean_arg:
            # Key order matters for any template that serialises the dict.
            arguments = {"dry_run": False, "i": step}
        call = {"type": "function", "function": {"name": "probe", "arguments": arguments}}
        conversation += [
            {"role": "assistant", "content": "", "tool_calls": [call]},
            {"role": "tool", "content": "ok"},
        ]
    conversation.append({"role": "assistant", "content": "done"})
    return conversation
def tito_render(msgs):
    """Render *msgs* incrementally without re-encoding sampled tokens.

    Starts from the render of the first message plus a generation prompt, then
    walks the rest of the conversation. The message at the cursor is treated
    as the 'sampled' completion: its ids are whatever the render through that
    message adds beyond the length of the generation-prompt render of the
    history. Any run of tool messages right after it is appended through the
    module-level renderer's ``bridge_to_next_turn``.

    Returns the accumulated id list, or None when the bridge returns None or
    its ids do not start with prompt + completion.
    """
    token_buffer = list(r.render_ids(msgs[:1], add_generation_prompt=True))
    total = len(msgs)
    cursor = 1
    while cursor < total:
        # Treat the message at the cursor as the 'sampled' completion.
        history = msgs[:cursor]
        gen_prompt = r.render_ids(history, add_generation_prompt=True)
        through_turn = r.render_ids(history + [msgs[cursor]], add_generation_prompt=False)
        sampled = through_turn[len(gen_prompt):]
        token_buffer.extend(sampled)

        # Collect the tool turns that directly follow the completion.
        stop = cursor + 1
        while stop < total and msgs[stop]["role"] == "tool":
            stop += 1
        tool_turns = msgs[cursor + 1:stop]

        if tool_turns:
            bridged = r.bridge_to_next_turn(
                previous_prompt_ids=gen_prompt,
                previous_completion_ids=sampled,
                new_messages=tool_turns,
            )
            if bridged is None:
                return None
            bridged_ids = list(bridged.token_ids)
            expected_head = gen_prompt + sampled
            head_len = len(expected_head)
            if bridged_ids[:head_len] != expected_head:
                return None
            # Keep only what the bridge added after prompt + completion.
            token_buffer.extend(bridged_ids[head_len:])
        cursor = stop
    return token_buffer
# Run the microbench twice — tool arguments without and with a boolean — and
# count the rollouts where the incremental render is missing or differs from
# the one-shot apply_chat_template render.
trials = 100
results = {"no-boolean": [0, 0], "with-boolean": [0, 0]}  # [breaks, trials]
for label, use_bool in (("no-boolean", False), ("with-boolean", True)):
    tally = results[label]
    for _ in range(trials):
        n_turns = random.choice([3, 4, 5])
        msgs = rollout(turns=n_turns, boolean_arg=use_bool)
        mito = tok.apply_chat_template(msgs, return_dict=False)
        tito = tito_render(msgs)
        tally[1] += 1
        diverged = tito is None or tito != mito
        if diverged:
            tally[0] += 1
    breaks, total = tally
    print(f" {MODEL:35s} {label:14s} break rate: {breaks}/{total} = {100*breaks/total:.1f}%")
print()
# ---------------------------------------------------------------------------
# 5. Contrast against a model whose template passes the §6 property test.
# ---------------------------------------------------------------------------
print("## 5. contrast: a model whose chat template is already prefix-preserving for tool messages")
# Second model used for the contrast run; only its tokenizer is loaded here
# (no renderer is created for it in this section).
ALT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
tok2 = AutoTokenizer.from_pretrained(ALT_MODEL)
def manual_tito_render(msgs, tokenizer):
    """Pure-tokenizer dummy-diff TITO. Returns None if the property breaks.

    Starts from the render of the first message plus a generation prompt, then
    walks the rest of *msgs*. The message at the cursor is treated as the
    sampled completion: its ids are whatever the render through that message
    adds beyond the length of the generation-prompt render of the history.
    Any run of tool messages right after it is appended as the difference
    between the render through the tool turns (with a generation prompt) and
    the render through the completion.

    Args:
        msgs: list of chat message dicts; each must have a "role" key.
        tokenizer: object exposing ``apply_chat_template(messages,
            return_dict=..., add_generation_prompt=...)`` that returns ids.

    Returns:
        The accumulated id list, or None when a render through the tool turns
        does not start with the render through the preceding completion.
    """
    prompt_ids = tokenizer.apply_chat_template(msgs[:1], return_dict=False, add_generation_prompt=True)
    buffer = list(prompt_ids)
    i = 1
    while i < len(msgs):
        msgs_before = msgs[:i]
        through_turn = msgs_before + [msgs[i]]
        prev_pp = tokenizer.apply_chat_template(msgs_before, return_dict=False, add_generation_prompt=True)
        full = tokenizer.apply_chat_template(through_turn, return_dict=False)
        buffer.extend(full[len(prev_pp):])

        # Gather the tool turns that directly follow the completion.
        j = i + 1
        env_msgs = []
        while j < len(msgs) and msgs[j]["role"] == "tool":
            env_msgs.append(msgs[j])
            j += 1

        if env_msgs:
            # `full` already is the render through the completion, so it serves
            # as the dummy-diff prefix; rendering it a second time is redundant.
            extended = tokenizer.apply_chat_template(
                through_turn + env_msgs,
                return_dict=False,
                add_generation_prompt=True,
            )
            if extended[: len(full)] != full:
                return None
            buffer.extend(extended[len(full):])
        i = j
    return buffer
# Same break-rate measurement as for the main model, but with the
# tokenizer-only dummy-diff render and the alternate model's tokenizer.
alt_trials = 100
alt_breaks = 0
for _ in range(alt_trials):
    msgs = rollout(turns=random.choice([3, 4, 5]), boolean_arg=False)
    mito = tok2.apply_chat_template(msgs, return_dict=False)
    tito = manual_tito_render(msgs, tok2)
    if tito is None or tito != mito:
        alt_breaks += 1
# Compute the percentage from the counts instead of appending a literal ".0%",
# so the line stays correct if alt_trials is ever changed from 100.
print(f" {ALT_MODEL:35s} dummy-diff TITO break rate: {alt_breaks}/{alt_trials} = {100 * alt_breaks / alt_trials:.1f}%")
print()
# ---------------------------------------------------------------------------
# Summary
# ---------------------------------------------------------------------------
# Derive the headline numbers from what the sections above measured, so the
# summary cannot contradict the run that produced it.
qwen3_breaks = sum(counts[0] for counts in results.values())
qwen3_total = sum(counts[1] for counts in results.values())
qwen3_rate = 100 * qwen3_breaks / qwen3_total if qwen3_total else 0.0

if manual_tito_ids is None:
    manual_verdict = "fails immediately"
elif manual_tito_ids != mito_ids:
    manual_verdict = "completes, but diverges from MITO"
else:
    manual_verdict = "matches MITO"

print("## summary")
# Section 2 aborts the script if a bridged prompt fails the prefix check, so
# reaching this line means every bridge that returned ids passed it.
print(" - bridge_to_next_turn byte-for-byte extension property: HOLDS for Qwen3 renderer")
print(f" - MITO vs renderer-TITO on Qwen3-0.6B (stock template): diverge {qwen3_rate:.0f}% of rollouts")
print(f" - Dummy-diff (manual) TITO on Qwen3-0.6B (stock template): {manual_verdict}")
print(f" - Dummy-diff TITO on {ALT_MODEL}: {alt_breaks}/100 breaks")
print()
if qwen3_breaks:
    print(" For a typical 4–5-turn rollout, multiplying by the fragmentation arithmetic in §3:")
    print(" Qwen3 stock template under MITO → ~2.5–3.0x training cost vs renderer or patched TITO.")
    print(" This is consistent with the library's '>3x throughput' framing.")
else:
    print(" No MITO vs renderer-TITO divergence was measured, so the §3 fragmentation penalty does not apply to this run.")
if alt_breaks == 0:
    # A cost multiplier of "0x" would mean free training; no penalty is ~1.0x.
    print(" Models whose template already passes the §6 property test see no penalty (~1.0x) under TITO.")
else:
    print(f" {ALT_MODEL} broke the prefix property in {alt_breaks}/100 rollouts, so it is not penalty-free under TITO.")