Quentin Gallouédec
Add "Sign in with Hugging Face" OAuth so visitors access their own gated repos
cb002d0
"""Chat Template Inspector β€” a Gradio Space for analyzing chat templates."""
from __future__ import annotations
import ast
import traceback
from typing import Any
import gradio as gr
from transformers import AutoProcessor, AutoTokenizer
from inspector.checks import ALL_CHECKS, DESCRIPTIONS, CheckResult, Context
from inspector.format import format_template
# ─── Tokenizer cache ───────────────────────────────────────────────────────────
#
# `from_pretrained` is slow (download + parse). We load each model once per
# (model_id, token) pair so visitors with different Hub access don't share
# cached objects. We also stash the originals of the four special tokens so an
# empty override field truly resets them.
_CACHE: dict[tuple[str, str], tuple[Any, bool, dict]] = {}

# Chat templates exactly as they came from the Hub, keyed like `_CACHE`.
# `_build_context` overwrites `chat_template` on the shared cached object with
# whatever is in the editor, so the pristine value has to be remembered apart
# from the object; otherwise "Load" would hand back the last edited template.
_ORIGINAL_TEMPLATES: dict[tuple[str, str], Any] = {}


def _cache_key(model_id: str, token: str | None) -> tuple[str, str]:
    """Return the `(model_id, token)` key shared by `_CACHE` and `_ORIGINAL_TEMPLATES`."""
    return (model_id.strip(), token or "")


def _load(model_id: str, token: str | None = None) -> tuple[Any, bool, dict]:
    """Load (or fetch from the cache) the processor/tokenizer for `model_id`.

    Args:
        model_id: Hub repository id; surrounding whitespace is ignored.
        token: Hub access token forwarded to `from_pretrained`, or `None`.

    Returns:
        `(obj, is_processor, originals)` where `obj` is the loaded processor or
        tokenizer, `is_processor` tells whether `obj` wraps a tokenizer in
        `obj.tokenizer`, and `originals` maps the four special-token attribute
        names to the values they had right after loading.

    Raises:
        gr.Error: if `model_id` is empty.
        Exception: whatever `AutoTokenizer.from_pretrained` raises when the
            tokenizer cannot be loaded.
    """
    model_id = model_id.strip()
    if not model_id:
        raise gr.Error("Please enter a model id.")
    key = _cache_key(model_id, token)
    if key in _CACHE:
        return _CACHE[key]

    # Only the processor probe is best-effort. Keeping the tokenizer load out
    # of the `try` avoids attempting a failing tokenizer download twice.
    try:
        proc = AutoProcessor.from_pretrained(model_id, trust_remote_code=False, token=token)
    except Exception:
        proc = None
    is_processor = proc is not None and (
        hasattr(proc, "image_processor") or hasattr(proc, "feature_extractor")
    )
    if is_processor:
        obj = proc
    else:
        obj = AutoTokenizer.from_pretrained(model_id, trust_remote_code=False, token=token)

    tok = obj.tokenizer if is_processor else obj
    originals = {attr: getattr(tok, attr, None) for attr in ("bos_token", "eos_token", "pad_token", "unk_token")}

    # Snapshot the template before anything can overwrite it on the object.
    if is_processor:
        template = getattr(obj, "chat_template", None) or tok.chat_template
    else:
        template = obj.chat_template
    _ORIGINAL_TEMPLATES[key] = template

    _CACHE[key] = (obj, is_processor, originals)
    return _CACHE[key]


def load_from_hub(model_id: str, oauth_token: gr.OAuthToken | None = None):
    """Fetch the Hub chat template and special tokens for `model_id`.

    Args:
        model_id: Hub repository id typed by the visitor.
        oauth_token: the visitor's OAuth token, or `None`; its `.token` is
            forwarded to the Hub so gated repos the visitor can access load.

    Returns:
        A 6-tuple `(template, bos, eos, pad, unk, model_id)`; missing special
        tokens are returned as empty strings.

    Raises:
        gr.Error: if the id is empty, loading fails, or the model has no
            chat template.
    """
    token = oauth_token.token if oauth_token is not None else None
    try:
        _obj, _is_processor, originals = _load(model_id, token)
    except gr.Error:
        # Already a user-facing message (e.g. empty model id); don't re-wrap it.
        raise
    except Exception as e:
        raise gr.Error(f"Failed to load `{model_id}`: {type(e).__name__}: {e}") from e

    # Use the snapshot taken at load time, not the (possibly edited) template
    # currently set on the shared cached object.
    template = _ORIGINAL_TEMPLATES.get(_cache_key(model_id, token))
    if not template:
        raise gr.Error(f"`{model_id}` has no chat_template.")
    return (
        template,
        originals.get("bos_token") or "",
        originals.get("eos_token") or "",
        originals.get("pad_token") or "",
        originals.get("unk_token") or "",
        model_id,
    )
# ─── Checks runner ─────────────────────────────────────────────────────────────
def _build_context(
    model_id: str, template_source: str, bos: str, eos: str, pad: str, unk: str,
    token: str | None = None,
) -> Context:
    """Prepare the cached processor/tokenizer for a check or render run.

    The special tokens are first restored to the values recorded at load time,
    then every non-empty override among `bos`/`eos`/`pad`/`unk` is applied, and
    finally `template_source` is installed as the chat template. The cached
    object itself is modified and handed to the returned `Context`.
    """
    loaded, is_processor, originals = _load(model_id, token)
    tokenizer = loaded.tokenizer if is_processor else loaded

    # Step 1: undo whatever a previous call may have overridden.
    for name in originals:
        setattr(tokenizer, name, originals[name])

    # Step 2: apply the overrides the user actually filled in.
    names = ("bos_token", "eos_token", "pad_token", "unk_token")
    for name, override in zip(names, (bos, eos, pad, unk)):
        if not override:
            continue
        setattr(tokenizer, name, override)

    # Step 3: a processor carries the template on itself and on its tokenizer.
    if is_processor:
        loaded.tokenizer.chat_template = template_source
    loaded.chat_template = template_source

    return Context(
        obj=loaded,
        is_processor=is_processor,
        template_source=template_source,
        model_id=model_id,
    )
# Icon shown in front of each check name, keyed by the result's `status`.
_STATUS_EMOJI = {"pass": "✅", "fail": "❌", "warning": "⚠️", "na": "➖"}


def _format_report(results) -> str:
    """Render check results as Markdown/HTML, grouped by category.

    Each category becomes a `###` heading (in first-seen order). A result with
    any extra information (description, message, details, reproducer) becomes a
    collapsible `<details>` element; a result without any is a single line.

    Args:
        results: iterable of objects exposing `category`, `name`, `status`,
            `message`, `details` and `reproducer`.

    Returns:
        The report as one string; empty when `results` is empty.
    """
    by_cat: dict[str, list] = {}
    for r in results:
        by_cat.setdefault(r.category, []).append(r)

    lines: list[str] = []
    for cat, items in by_cat.items():
        lines.append(f"### {cat}")
        for r in items:
            emoji = _STATUS_EMOJI.get(r.status, "?")
            summary = f"{emoji} <b>{r.name}</b>"

            body_parts: list[str] = []
            description = DESCRIPTIONS.get(r.name, "")
            if description:
                body_parts.append(f"_{description}_")
            if r.message:
                body_parts.append(f"**{r.message}**")
            if r.details:
                body_parts.append("```\n" + r.details + "\n```")
            if r.reproducer:
                body_parts.append("```python\n" + r.reproducer + "\n```")

            if body_parts:
                # Blank lines around the body let Markdown inside <details> render.
                lines.append(f"<details>\n<summary>{summary}</summary>\n")
                lines.append("\n\n".join(body_parts))
                lines.append("\n</details>")
            else:
                lines.append(summary + "<br>")
        lines.append("")
    return "\n".join(lines)
# Initial contents of the playground editor: a two-message conversation written
# as a Python list literal, the format `render_playground` parses.
_PLAYGROUND_DEFAULT = '''\
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the capital of France?"},
]
'''
def render_playground(
    model_id: str, template_source: str, bos: str, eos: str, pad: str, unk: str,
    messages_src: str, add_gen_prompt: bool,
    oauth_token: gr.OAuthToken | None = None,
) -> str:
    """Parse the user's `messages` literal and render it with apply_chat_template.

    Uses `ast.literal_eval` so only Python literals are accepted (no function
    calls or attribute access), which makes it safe to expose on a public Space.

    Args:
        model_id, template_source, bos, eos, pad, unk: forwarded to
            `_build_context`.
        messages_src: source text of a Python list literal of messages.
        add_gen_prompt: passed as `add_generation_prompt`.
        oauth_token: the visitor's OAuth token, or `None`.

    Returns:
        The rendered prompt, `""` for blank input, or a `# ...`-prefixed
        message describing the parse/render problem. Never raises for bad
        user input.
    """
    if not messages_src.strip():
        return ""
    try:
        messages = ast.literal_eval(messages_src)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        # `literal_eval` is documented to raise any of these on malformed
        # input (e.g. TypeError for an unhashable dict key such as `{[1]: 2}`).
        return f"# parse error: {type(e).__name__}: {e}"
    if not isinstance(messages, list):
        return "# expected a list of messages at the top level"
    token = oauth_token.token if oauth_token is not None else None
    try:
        ctx = _build_context(model_id, template_source, bos, eos, pad, unk, token=token)
        return ctx.obj.apply_chat_template(
            messages, add_generation_prompt=add_gen_prompt, tokenize=False,
        )
    except Exception as e:
        # Template errors are expected while the user is editing; show them
        # in the output pane instead of failing the event.
        return f"# render error: {type(e).__name__}: {e}"
def run_inspection(
    model_id: str, template_source: str, bos: str, eos: str, pad: str, unk: str,
    oauth_token: gr.OAuthToken | None = None,
):
    """Run every check in `ALL_CHECKS` against the template and format a report.

    Returns a hint when the model id or template is blank, a failure message
    with traceback when the context cannot be built, and otherwise the report
    produced by `_format_report`. A check that raises is reported as a failed
    "Internal" result rather than aborting the run.
    """
    if not (model_id.strip() and template_source.strip()):
        return "_Enter a model id and a template to inspect._"

    token = None if oauth_token is None else oauth_token.token
    try:
        ctx = _build_context(model_id, template_source, bos, eos, pad, unk, token=token)
    except Exception as exc:
        headline = f"**Failed to build context:** `{type(exc).__name__}: {exc}`"
        trace = traceback.format_exc()
        return f"{headline}\n\n```\n{trace}\n```"

    def _run_one(check):
        # Isolate each check so one crash cannot hide the other results.
        try:
            return check(ctx)
        except Exception as exc:
            return CheckResult(
                check.__name__, "Internal", "fail",
                f"Check raised an unexpected exception: {type(exc).__name__}: {exc}",
            )

    outcomes: list = [_run_one(check) for check in ALL_CHECKS]
    return _format_report(outcomes)
# ─── UI ────────────────────────────────────────────────────────────────────────
DEFAULT_MODEL = "Qwen/Qwen3-8B"

# Pre-fill the UI with the default model's template and special tokens. If the
# load fails for any reason, fall back to empty fields so the app still starts.
try:
    _d_template, _d_bos, _d_eos, _d_pad, _d_unk, _d_id = load_from_hub(DEFAULT_MODEL)
except Exception:
    _d_template, _d_bos, _d_eos, _d_pad, _d_unk, _d_id = ("",) * 5 + (DEFAULT_MODEL,)
# Build the Gradio UI and wire its events.
# NOTE(review): indentation is missing in this copy of the file, so which
# components sit inside which Row/Column cannot be read off the lines below;
# restore the nesting from version control before running.
# NOTE(review): several string literals here (button labels, help text) look
# mis-decoded; verify the file is stored as UTF-8.
with gr.Blocks(title="Chat Template Inspector", fill_width=True) as demo:
with gr.Row():
gr.Markdown("## Chat Template Inspector", elem_id="title")
# Sign-in button; presumably what supplies the `oauth_token` the handlers
# accept (confirm against the Gradio OAuth docs).
login_btn = gr.LoginButton(size="sm")
with gr.Row():
with gr.Column():
with gr.Row():
# Model id input plus the Load / Format actions.
model_id_box = gr.Textbox(
value=_d_id, placeholder="Qwen/Qwen3-8B",
scale=5, show_label=False, container=False,
)
load_btn = gr.Button("πŸ“₯ Load", scale=1, variant="secondary")
format_btn = gr.Button("✨ Format", scale=1, variant="secondary")
# Editable chat template, pre-filled with the default model's template.
template_editor = gr.Code(
value=_d_template, language="jinja2", show_label=False,
lines=20, max_lines=20,
)
with gr.Row():
# Special-token override fields, pre-filled with the default model's tokens.
bos_box = gr.Textbox(value=_d_bos, label="bos_token")
eos_box = gr.Textbox(value=_d_eos, label="eos_token")
pad_box = gr.Textbox(value=_d_pad, label="pad_token")
unk_box = gr.Textbox(value=_d_unk, label="unk_token")
with gr.Column():
# Inspection report; shows a placeholder until the first run completes.
report_md = gr.Markdown("_Loading…_")
gr.Markdown("### Playground")
gr.Markdown(
"Define a `messages` list (Python literal β€” only `dict`, `list`, "
"`str`, etc., no function calls) and see it rendered with the "
"current template."
)
playground_input = gr.Code(
value=_PLAYGROUND_DEFAULT, language="python", show_label=False,
lines=8, max_lines=8,
)
playground_gen_prompt = gr.Checkbox(label="add_generation_prompt", value=True)
# Read-only pane that receives the output of `render_playground`.
playground_output = gr.Code(
value="", language="markdown", show_label=False,
lines=8, max_lines=8, interactive=False,
)
# Input lists, in the positional order of `run_inspection` and
# `render_playground`. The handlers' `oauth_token` parameter is not listed;
# presumably Gradio injects it from the `gr.OAuthToken` type hint (confirm).
inspection_inputs = [model_id_box, template_editor, bos_box, eos_box, pad_box, unk_box]
playground_inputs = inspection_inputs + [playground_input, playground_gen_prompt]
# Initial render on page load.
demo.load(run_inspection, inputs=inspection_inputs, outputs=[report_md])
demo.load(render_playground, inputs=playground_inputs, outputs=[playground_output])
# Re-run both the inspection and the playground whenever the template or any
# special token changes.
# NOTE(review): an earlier comment claimed Gradio cancels any in-flight call
# when a new event fires. No `cancels=` or `trigger_mode=` is passed here, so
# this relies on Gradio's default for `.change` events; confirm in the docs.
for component in (template_editor, bos_box, eos_box, pad_box, unk_box):
component.change(
run_inspection, inputs=inspection_inputs, outputs=[report_md],
show_progress="hidden",
)
component.change(
render_playground, inputs=playground_inputs, outputs=[playground_output],
show_progress="hidden",
)
# Playground re-renders on its own input changes too.
for component in (playground_input, playground_gen_prompt):
component.change(
render_playground, inputs=playground_inputs, outputs=[playground_output],
show_progress="hidden",
)
# Loading from the Hub overwrites the editor + token boxes (which fires
# their .change events and triggers a fresh inspection automatically).
# The six outputs line up with the 6-tuple returned by `load_from_hub`.
load_btn.click(
load_from_hub,
inputs=[model_id_box],
outputs=[template_editor, bos_box, eos_box, pad_box, unk_box, model_id_box],
)
# Format rewrites the editor contents in place via `format_template`.
format_btn.click(format_template, inputs=[template_editor], outputs=[template_editor])
# Start the app with the Soft theme when executed as a script.
if __name__ == "__main__":
demo.launch(theme=gr.themes.Soft())