"""Chat Template Inspector — a Gradio Space for analyzing chat templates.""" from __future__ import annotations import ast import traceback from typing import Any import gradio as gr from transformers import AutoProcessor, AutoTokenizer from inspector.checks import ALL_CHECKS, DESCRIPTIONS, CheckResult, Context from inspector.format import format_template # ─── Tokenizer cache ─────────────────────────────────────────────────────────── # # `from_pretrained` is slow (download + parse). We load each model once per # (model_id, token) pair so visitors with different Hub access don't share # cached objects. We also stash the originals of the four special tokens so an # empty override field truly resets them. _CACHE: dict[tuple[str, str], tuple[Any, bool, dict]] = {} def _load(model_id: str, token: str | None = None) -> tuple[Any, bool, dict]: model_id = model_id.strip() if not model_id: raise gr.Error("Please enter a model id.") key = (model_id, token or "") if key in _CACHE: return _CACHE[key] try: proc = AutoProcessor.from_pretrained(model_id, trust_remote_code=False, token=token) is_processor = hasattr(proc, "image_processor") or hasattr(proc, "feature_extractor") obj = proc if is_processor else AutoTokenizer.from_pretrained(model_id, trust_remote_code=False, token=token) except Exception: obj = AutoTokenizer.from_pretrained(model_id, trust_remote_code=False, token=token) is_processor = False tok = obj.tokenizer if is_processor else obj originals = {attr: getattr(tok, attr, None) for attr in ("bos_token", "eos_token", "pad_token", "unk_token")} _CACHE[key] = (obj, is_processor, originals) return _CACHE[key] def load_from_hub(model_id: str, oauth_token: gr.OAuthToken | None = None): token = oauth_token.token if oauth_token is not None else None try: obj, is_processor, originals = _load(model_id, token) except Exception as e: raise gr.Error(f"Failed to load `{model_id}`: {type(e).__name__}: {e}") template = ( obj.chat_template if not is_processor else getattr(obj, "chat_template", None) or obj.tokenizer.chat_template ) if not template: raise gr.Error(f"`{model_id}` has no chat_template.") return ( template, originals.get("bos_token") or "", originals.get("eos_token") or "", originals.get("pad_token") or "", originals.get("unk_token") or "", model_id, ) # ─── Checks runner ───────────────────────────────────────────────────────────── def _build_context( model_id: str, template_source: str, bos: str, eos: str, pad: str, unk: str, token: str | None = None, ) -> Context: obj, is_processor, originals = _load(model_id, token) tok = obj.tokenizer if is_processor else obj # Reset to original tokens, then apply overrides where the user provided one. for attr, original in originals.items(): setattr(tok, attr, original) for attr, value in (("bos_token", bos), ("eos_token", eos), ("pad_token", pad), ("unk_token", unk)): if value: setattr(tok, attr, value) if is_processor: obj.tokenizer.chat_template = template_source obj.chat_template = template_source return Context(obj=obj, is_processor=is_processor, template_source=template_source, model_id=model_id) _STATUS_EMOJI = {"pass": "✅", "fail": "❌", "warning": "⚠️", "na": "➖"} def _format_report(results) -> str: by_cat: dict[str, list] = {} for r in results: by_cat.setdefault(r.category, []).append(r) lines: list[str] = [] for cat, items in by_cat.items(): lines.append(f"### {cat}") for r in items: emoji = _STATUS_EMOJI.get(r.status, "?") summary = f"{emoji} {r.name}" body_parts: list[str] = [] description = DESCRIPTIONS.get(r.name, "") if description: body_parts.append(f"_{description}_") if r.message: body_parts.append(f"**{r.message}**") if r.details: body_parts.append("```\n" + r.details + "\n```") if r.reproducer: body_parts.append("```python\n" + r.reproducer + "\n```") if body_parts: lines.append(f"
\n{summary}\n") lines.append("\n\n".join(body_parts)) lines.append("\n
") else: lines.append(summary + "
") lines.append("") return "\n".join(lines) _PLAYGROUND_DEFAULT = '''\ [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What's the capital of France?"}, ] ''' def render_playground( model_id: str, template_source: str, bos: str, eos: str, pad: str, unk: str, messages_src: str, add_gen_prompt: bool, oauth_token: gr.OAuthToken | None = None, ) -> str: """Parse the user's `messages` literal and render it with apply_chat_template. Uses `ast.literal_eval` so only Python literals are accepted — no function calls or attribute access — which makes it safe to expose on a public Space. """ if not messages_src.strip(): return "" try: messages = ast.literal_eval(messages_src) except (ValueError, SyntaxError) as e: return f"# parse error: {type(e).__name__}: {e}" if not isinstance(messages, list): return "# expected a list of messages at the top level" token = oauth_token.token if oauth_token is not None else None try: ctx = _build_context(model_id, template_source, bos, eos, pad, unk, token=token) return ctx.obj.apply_chat_template( messages, add_generation_prompt=add_gen_prompt, tokenize=False, ) except Exception as e: return f"# render error: {type(e).__name__}: {e}" def run_inspection( model_id: str, template_source: str, bos: str, eos: str, pad: str, unk: str, oauth_token: gr.OAuthToken | None = None, ): if not model_id.strip() or not template_source.strip(): return "_Enter a model id and a template to inspect._" token = oauth_token.token if oauth_token is not None else None try: ctx = _build_context(model_id, template_source, bos, eos, pad, unk, token=token) except Exception as e: return f"**Failed to build context:** `{type(e).__name__}: {e}`\n\n```\n{traceback.format_exc()}\n```" results: list[CheckResult] = [] for check in ALL_CHECKS: try: results.append(check(ctx)) except Exception as e: results.append(CheckResult( check.__name__, "Internal", "fail", f"Check raised an unexpected exception: {type(e).__name__}: {e}", )) return _format_report(results) # ─── UI ──────────────────────────────────────────────────────────────────────── DEFAULT_MODEL = "Qwen/Qwen3-8B" try: _d_template, _d_bos, _d_eos, _d_pad, _d_unk, _d_id = load_from_hub(DEFAULT_MODEL) except Exception: _d_template = _d_bos = _d_eos = _d_pad = _d_unk = "" _d_id = DEFAULT_MODEL with gr.Blocks(title="Chat Template Inspector", fill_width=True) as demo: with gr.Row(): gr.Markdown("## Chat Template Inspector", elem_id="title") login_btn = gr.LoginButton(size="sm") with gr.Row(): with gr.Column(): with gr.Row(): model_id_box = gr.Textbox( value=_d_id, placeholder="Qwen/Qwen3-8B", scale=5, show_label=False, container=False, ) load_btn = gr.Button("📥 Load", scale=1, variant="secondary") format_btn = gr.Button("✨ Format", scale=1, variant="secondary") template_editor = gr.Code( value=_d_template, language="jinja2", show_label=False, lines=20, max_lines=20, ) with gr.Row(): bos_box = gr.Textbox(value=_d_bos, label="bos_token") eos_box = gr.Textbox(value=_d_eos, label="eos_token") pad_box = gr.Textbox(value=_d_pad, label="pad_token") unk_box = gr.Textbox(value=_d_unk, label="unk_token") with gr.Column(): report_md = gr.Markdown("_Loading…_") gr.Markdown("### Playground") gr.Markdown( "Define a `messages` list (Python literal — only `dict`, `list`, " "`str`, etc., no function calls) and see it rendered with the " "current template." ) playground_input = gr.Code( value=_PLAYGROUND_DEFAULT, language="python", show_label=False, lines=8, max_lines=8, ) playground_gen_prompt = gr.Checkbox(label="add_generation_prompt", value=True) playground_output = gr.Code( value="", language="markdown", show_label=False, lines=8, max_lines=8, interactive=False, ) inspection_inputs = [model_id_box, template_editor, bos_box, eos_box, pad_box, unk_box] playground_inputs = inspection_inputs + [playground_input, playground_gen_prompt] # Initial render on page load. demo.load(run_inspection, inputs=inspection_inputs, outputs=[report_md]) demo.load(render_playground, inputs=playground_inputs, outputs=[playground_output]) # Re-run whenever the template or any special token changes. Gradio cancels # any in-flight call when a new event fires, so rapid edits debounce naturally. for component in (template_editor, bos_box, eos_box, pad_box, unk_box): component.change( run_inspection, inputs=inspection_inputs, outputs=[report_md], show_progress="hidden", ) component.change( render_playground, inputs=playground_inputs, outputs=[playground_output], show_progress="hidden", ) # Playground re-renders on its own input changes too. for component in (playground_input, playground_gen_prompt): component.change( render_playground, inputs=playground_inputs, outputs=[playground_output], show_progress="hidden", ) # Loading from the Hub overwrites the editor + token boxes (which fires # their .change events and triggers a fresh inspection automatically). load_btn.click( load_from_hub, inputs=[model_id_box], outputs=[template_editor, bos_box, eos_box, pad_box, unk_box, model_id_box], ) format_btn.click(format_template, inputs=[template_editor], outputs=[template_editor]) if __name__ == "__main__": demo.launch(theme=gr.themes.Soft())