Spaces:
Sleeping
Sleeping
Quentin Gallouédec
Add "Sign in with Hugging Face" OAuth so visitors access their own gated repos
"""Chat Template Inspector — a Gradio Space for analyzing chat templates."""
from __future__ import annotations

import ast
import re
import threading
import traceback
from typing import Any

import gradio as gr
from transformers import AutoProcessor, AutoTokenizer

from inspector.checks import ALL_CHECKS, DESCRIPTIONS, CheckResult, Context
from inspector.format import format_template
# ─── Tokenizer cache ─────────────────────────────────────────────────────────
#
# `from_pretrained` is slow (download + parse). We load each model once per
# (model_id, token) pair so visitors with different Hub access don't share
# cached objects. We also stash the originals of the four special tokens so an
# empty override field truly resets them.
# Upper bound on cached (model_id, token) entries. Each entry is a fully loaded
# tokenizer/processor, and every signed-in visitor produces a distinct key, so
# an unbounded cache would only ever grow for the lifetime of the process.
# Least-recently-used entries are evicted first.
_CACHE_MAX_ENTRIES = 16
_CACHE: dict[tuple[str, str], tuple[Any, bool, dict]] = {}
# Serializes _CACHE bookkeeping (not the slow loading itself), since Gradio can
# run several event handlers at the same time.
_CACHE_LOCK = threading.Lock()


def _cache_get(key: tuple[str, str]) -> tuple[Any, bool, dict] | None:
    """Return the entry cached under *key* and mark it most recently used, or None."""
    with _CACHE_LOCK:
        entry = _CACHE.pop(key, None)
        if entry is not None:
            # Re-inserting moves the key to the end; dicts keep insertion order,
            # so the first key is always the least recently used one.
            _CACHE[key] = entry
        return entry


def _cache_put(key: tuple[str, str], entry: tuple[Any, bool, dict]) -> None:
    """Store *entry* under *key*, evicting least-recently-used entries beyond the cap."""
    with _CACHE_LOCK:
        _CACHE.pop(key, None)
        _CACHE[key] = entry
        while len(_CACHE) > _CACHE_MAX_ENTRIES:
            del _CACHE[next(iter(_CACHE))]


def _load(model_id: str, token: str | None = None) -> tuple[Any, bool, dict]:
    """Return ``(obj, is_processor, originals)`` for *model_id*, loading it on a cache miss.

    ``obj`` is the processor when ``AutoProcessor`` yields one with an
    ``image_processor`` or ``feature_extractor`` attribute, otherwise the plain
    tokenizer. ``originals`` maps ``bos_token``/``eos_token``/``pad_token``/
    ``unk_token`` to the values the tokenizer was loaded with, so callers can
    undo earlier overrides.

    Args:
        model_id: Hub repo id; surrounding whitespace is ignored.
        token: the visitor's Hub token, or None for anonymous access. Entries
            are cached per (model_id, token) so visitors with different access
            never share objects.

    Raises:
        gr.Error: if *model_id* is blank.
        Exception: whatever ``AutoTokenizer.from_pretrained`` raises when the
            tokenizer cannot be loaded (missing repo, no access, ...).
    """
    model_id = model_id.strip()
    if not model_id:
        raise gr.Error("Please enter a model id.")
    key = (model_id, token or "")
    cached = _cache_get(key)
    if cached is not None:
        return cached

    try:
        proc = AutoProcessor.from_pretrained(model_id, trust_remote_code=False, token=token)
    except Exception:
        # No usable processor for this repo: fall back to a plain tokenizer below.
        proc = None
    is_processor = proc is not None and (
        hasattr(proc, "image_processor") or hasattr(proc, "feature_extractor")
    )
    # Only the tokenizer load sits outside the try, so a failure here is raised
    # once instead of being retried by the processor fallback.
    obj = proc if is_processor else AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=False, token=token,
    )

    tok = obj.tokenizer if is_processor else obj
    originals = {attr: getattr(tok, attr, None) for attr in ("bos_token", "eos_token", "pad_token", "unk_token")}
    entry = (obj, is_processor, originals)
    _cache_put(key, entry)
    return entry
def load_from_hub(model_id: str, oauth_token: gr.OAuthToken | None = None):
    """Fetch *model_id*'s chat template and special tokens to populate the editor.

    ``oauth_token`` is the visitor's "Sign in with Hugging Face" token when
    available (supplied by Gradio, not by a UI input), so repos the visitor can
    access, such as gated ones, load too.

    Returns:
        ``(template, bos, eos, pad, unk, model_id)``: the template source and
        the four originally loaded special tokens ("" when unset), in the order
        of the load button's outputs.

    Raises:
        gr.Error: if the id is blank, loading fails, or the repo has no chat template.
    """
    token = oauth_token.token if oauth_token is not None else None
    try:
        obj, is_processor, originals = _load(model_id, token)
    except gr.Error:
        # Already user-facing (e.g. blank model id): don't bury it under "Failed to load".
        raise
    except Exception as e:
        raise gr.Error(f"Failed to load `{model_id}`: {type(e).__name__}: {e}") from e

    if is_processor:
        template = getattr(obj, "chat_template", None) or obj.tokenizer.chat_template
    else:
        template = obj.chat_template
    if isinstance(template, dict):
        # Repos that ship several named templates expose them as {name: source};
        # show the "default" one (or the first, if there is no "default").
        template = template.get("default") or next(iter(template.values()), None)
    if not template:
        raise gr.Error(f"`{model_id}` has no chat_template.")

    return (
        template,
        originals.get("bos_token") or "",
        originals.get("eos_token") or "",
        originals.get("pad_token") or "",
        originals.get("unk_token") or "",
        model_id,
    )
# ─── Checks runner ───────────────────────────────────────────────────────────
def _build_context(
    model_id: str, template_source: str, bos: str, eos: str, pad: str, unk: str,
    token: str | None = None,
) -> Context:
    """Prepare the cached tokenizer/processor for *model_id* with the user's edits.

    Special tokens are first reset to the values captured at load time, then
    each non-empty field overrides its token, so clearing a field restores the
    original. The edited template is installed on the object, and also on its
    inner tokenizer when the object is a processor.

    Note: the object comes from the shared ``_load`` cache and is modified in
    place, so overlapping calls for the same (model_id, token) pair can observe
    each other's edits.
    """
    obj, is_processor, originals = _load(model_id, token)
    target = obj.tokenizer if is_processor else obj

    # Step 1: undo whatever a previous call may have set.
    for name, shipped_value in originals.items():
        setattr(target, name, shipped_value)

    # Step 2: apply only the overrides the user actually filled in.
    overrides = {"bos_token": bos, "eos_token": eos, "pad_token": pad, "unk_token": unk}
    for name, override in overrides.items():
        if override:
            setattr(target, name, override)

    if is_processor:
        obj.tokenizer.chat_template = template_source
    obj.chat_template = template_source

    return Context(
        obj=obj,
        is_processor=is_processor,
        template_source=template_source,
        model_id=model_id,
    )
# Icon shown in front of each check name in the report, keyed by CheckResult.status.
_STATUS_EMOJI = {"pass": "✅", "fail": "❌", "warning": "⚠️", "na": "➖"}
def _code_fence(text: str, lang: str = "") -> str:
    """Wrap *text* in a Markdown code fence that *text* itself cannot close.

    A backtick fence is closed by a line of at least as many backticks as the
    opener, so details or a reproducer containing their own ``` block (e.g. a
    rendered template that embeds one) would otherwise end the fence early and
    spill raw text into the report. The fence is therefore one backtick longer
    than the longest backtick run in *text*, and never shorter than three.
    """
    longest_run = max(map(len, re.findall(r"`+", text)), default=0)
    fence = "`" * max(3, longest_run + 1)
    return f"{fence}{lang}\n{text}\n{fence}"


def _format_report(results) -> str:
    """Render check results as a Markdown report grouped by category.

    Categories appear in the order of their first result. Each result becomes a
    collapsible ``<details>`` entry (summary: status icon + bold check name)
    holding its description, message, details and reproducer; a result with
    none of those is rendered as just the summary line.
    """
    by_cat: dict[str, list] = {}
    for r in results:
        by_cat.setdefault(r.category, []).append(r)
    lines: list[str] = []
    for cat, items in by_cat.items():
        lines.append(f"### {cat}")
        for r in items:
            emoji = _STATUS_EMOJI.get(r.status, "?")
            summary = f"{emoji} <b>{r.name}</b>"
            body_parts: list[str] = []
            description = DESCRIPTIONS.get(r.name, "")
            if description:
                body_parts.append(f"_{description}_")
            if r.message:
                body_parts.append(f"**{r.message}**")
            if r.details:
                body_parts.append(_code_fence(r.details))
            if r.reproducer:
                body_parts.append(_code_fence(r.reproducer, "python"))
            if body_parts:
                lines.append(f"<details>\n<summary>{summary}</summary>\n")
                lines.append("\n\n".join(body_parts))
                lines.append("\n</details>")
            else:
                lines.append(summary + "<br>")
        lines.append("")
    return "\n".join(lines)
# Initial text of the playground editor. render_playground parses that text
# with ast.literal_eval, so it must remain a plain Python literal.
_PLAYGROUND_DEFAULT = "\n".join([
    "[",
    '{"role": "system", "content": "You are a helpful assistant."},',
    '{"role": "user", "content": "What\'s the capital of France?"},',
    "]",
    "",
])
def render_playground(
    model_id: str, template_source: str, bos: str, eos: str, pad: str, unk: str,
    messages_src: str, add_gen_prompt: bool,
    oauth_token: gr.OAuthToken | None = None,
) -> str:
    """Parse the user's `messages` literal and render it with apply_chat_template.

    Uses `ast.literal_eval` so only Python literals are accepted — no function
    calls or attribute access — which makes it safe to expose on a public Space.

    Args:
        model_id, template_source, bos, eos, pad, unk: forwarded to
            ``_build_context`` (the template being edited plus special-token
            overrides).
        messages_src: Python-literal source text of the messages list.
        add_gen_prompt: forwarded as ``add_generation_prompt``.
        oauth_token: the visitor's Hugging Face login, if any.

    Returns:
        The rendered prompt text, "" for blank input, or a ``# ...`` message
        describing a parse or render error. Errors are returned rather than
        raised so they show up in the playground output box.
    """
    if not messages_src.strip():
        return ""
    try:
        messages = ast.literal_eval(messages_src)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        # literal_eval raises any of these on malformed input: TypeError for
        # e.g. `{[]: 1}` (unhashable key), MemoryError/RecursionError for
        # pathologically large or deeply nested input.
        return f"# parse error: {type(e).__name__}: {e}"
    if not isinstance(messages, list):
        return "# expected a list of messages at the top level"
    token = oauth_token.token if oauth_token is not None else None
    try:
        ctx = _build_context(model_id, template_source, bos, eos, pad, unk, token=token)
        return ctx.obj.apply_chat_template(
            messages, add_generation_prompt=add_gen_prompt, tokenize=False,
        )
    except Exception as e:
        return f"# render error: {type(e).__name__}: {e}"
def _run_check(check, ctx):
    """Run a single check; an unexpected exception becomes a failing "Internal" result."""
    try:
        return check(ctx)
    except Exception as exc:
        return CheckResult(
            check.__name__, "Internal", "fail",
            f"Check raised an unexpected exception: {type(exc).__name__}: {exc}",
        )


def run_inspection(
    model_id: str, template_source: str, bos: str, eos: str, pad: str, unk: str,
    oauth_token: gr.OAuthToken | None = None,
):
    """Run every check in ``ALL_CHECKS`` against the edited template; return Markdown.

    A blank model id or template yields a short hint instead of a report. If the
    tokenizer/processor cannot be prepared, the error and its traceback are
    returned as the report. A check that raises is reported as a failure rather
    than aborting the remaining checks.
    """
    if not (model_id.strip() and template_source.strip()):
        return "_Enter a model id and a template to inspect._"
    token = None if oauth_token is None else oauth_token.token
    try:
        ctx = _build_context(model_id, template_source, bos, eos, pad, unk, token=token)
    except Exception as exc:
        return (
            f"**Failed to build context:** `{type(exc).__name__}: {exc}`"
            f"\n\n```\n{traceback.format_exc()}\n```"
        )
    return _format_report([_run_check(check, ctx) for check in ALL_CHECKS])
# ─── UI ──────────────────────────────────────────────────────────────────────
DEFAULT_MODEL = "Qwen/Qwen3-8B"

# Pre-fill the editor with DEFAULT_MODEL once, at import time and without a
# token. Any failure here (e.g. the Hub is unreachable, or load_from_hub raises
# gr.Error) is non-fatal: the UI simply starts with empty fields.
try:
    (_d_template, _d_bos, _d_eos, _d_pad, _d_unk, _d_id) = load_from_hub(DEFAULT_MODEL)
except Exception:
    _d_id = DEFAULT_MODEL
    _d_template, _d_bos, _d_eos, _d_pad, _d_unk = "", "", "", "", ""
with gr.Blocks(title="Chat Template Inspector", fill_width=True) as demo:
    # Header: title plus the "Sign in with Hugging Face" button. The handlers
    # below receive the visitor's token through their `oauth_token` parameter.
    with gr.Row():
        gr.Markdown("## Chat Template Inspector", elem_id="title")
        login_btn = gr.LoginButton(size="sm")

    with gr.Row():
        # Left: model picker, template editor and special-token overrides.
        with gr.Column():
            with gr.Row():
                model_id_box = gr.Textbox(
                    value=_d_id, placeholder="Qwen/Qwen3-8B",
                    scale=5, show_label=False, container=False,
                )
                load_btn = gr.Button("📥 Load", scale=1, variant="secondary")
                format_btn = gr.Button("✨ Format", scale=1, variant="secondary")
            template_editor = gr.Code(
                value=_d_template, language="jinja2", show_label=False,
                lines=20, max_lines=20,
            )
            with gr.Row():
                bos_box = gr.Textbox(value=_d_bos, label="bos_token")
                eos_box = gr.Textbox(value=_d_eos, label="eos_token")
                pad_box = gr.Textbox(value=_d_pad, label="pad_token")
                unk_box = gr.Textbox(value=_d_unk, label="unk_token")

        # Right: inspection report and the rendering playground.
        with gr.Column():
            report_md = gr.Markdown("_Loading…_")
            gr.Markdown("### Playground")
            gr.Markdown(
                "Define a `messages` list (Python literal — only `dict`, `list`, "
                "`str`, etc., no function calls) and see it rendered with the "
                "current template."
            )
            playground_input = gr.Code(
                value=_PLAYGROUND_DEFAULT, language="python", show_label=False,
                lines=8, max_lines=8,
            )
            playground_gen_prompt = gr.Checkbox(label="add_generation_prompt", value=True)
            playground_output = gr.Code(
                value="", language="markdown", show_label=False,
                lines=8, max_lines=8, interactive=False,
            )

    # Component order must match the positional parameters of run_inspection
    # and render_playground (oauth_token is not a UI input).
    inspection_inputs = [model_id_box, template_editor, bos_box, eos_box, pad_box, unk_box]
    playground_inputs = inspection_inputs + [playground_input, playground_gen_prompt]

    # Initial render on page load.
    demo.load(run_inspection, inputs=inspection_inputs, outputs=[report_md])
    demo.load(render_playground, inputs=playground_inputs, outputs=[playground_output])

    # Re-run whenever the template or any special token changes. Gradio does not
    # cancel an in-flight call. With trigger_mode="always_last" (set explicitly
    # rather than relying on the default), triggers that arrive while a run is
    # pending collapse into one follow-up run after it finishes, so rapid edits
    # don't queue one run per keystroke.
    for component in (template_editor, bos_box, eos_box, pad_box, unk_box):
        component.change(
            run_inspection, inputs=inspection_inputs, outputs=[report_md],
            show_progress="hidden", trigger_mode="always_last",
        )
        component.change(
            render_playground, inputs=playground_inputs, outputs=[playground_output],
            show_progress="hidden", trigger_mode="always_last",
        )

    # Playground re-renders on its own input changes too.
    for component in (playground_input, playground_gen_prompt):
        component.change(
            render_playground, inputs=playground_inputs, outputs=[playground_output],
            show_progress="hidden", trigger_mode="always_last",
        )

    # Loading from the Hub overwrites the editor + token boxes (which fires
    # their .change events and triggers a fresh inspection automatically).
    load_btn.click(
        load_from_hub,
        inputs=[model_id_box],
        outputs=[template_editor, bos_box, eos_box, pad_box, unk_box, model_id_box],
    )
    format_btn.click(format_template, inputs=[template_editor], outputs=[template_editor])


if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft())