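# NOTE: the six lines below are web file-viewer residue captured with this
# file; they are not Python source and are kept only as comments.
# lewiswatson's picture
# Disable Gradio SSR for private Space
# a95ea14 verified
# Raw
# History Blame Contribute Delete
# 13.4 kB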
from __future__ import annotations
from functools import lru_cache
from html import escape
from typing import Any
import gradio as gr
from transformers import AutoTokenizer
# Model id pre-filled in the "Tokeniser model" textbox and handed to
# AutoTokenizer.from_pretrained() by load_tokeniser().
DEFAULT_MODEL_ID = "Qwen/Qwen3.5-2B"
# Marker strings that are counted as single "added" tokens by default.
# They are matched literally against the input text before the base
# tokeniser sees the remaining text.
DEFAULT_ADDED_TOKENS = [
    "<|graph|>",
    "<|nodes|>",
    "<|node|>",
    "<|id|>",
    "<|label|>",
    "<|bbox|>",
    "<|conf|>",
    "<|attrs|>",
    "<|edges|>",
    "<|edge|>",
    "<|src|>",
    "<|pred|>",
    "<|tgt|>",
    "<|end_graph|>",
]
# One token per line: the textual form shown in the "Added tokens" textbox
# and parsed back by parse_added_tokens().
DEFAULT_ADDED_TOKEN_TEXT = "\n".join(DEFAULT_ADDED_TOKENS)
# Example input pre-filled in the "Input text" textbox; it uses the default
# marker tokens above.
SAMPLE_TEXT = "<|graph|><|nodes|><|node|><|id|>person1<|label|>man<|bbox|>0.0 0.0 0.5 1.0<|conf|>0.9<|attrs|>appearance=white t-shirt<|attrs|>size=medium<|node|><|id|>person2<|label|>woman<|bbox|>0.45 0.0 1.0 1.0<|conf|>0.9<|attrs|>appearance=patterned sleeveless top<|attrs|>size=medium<|node|><|id|>table1<|label|>table<|bbox|>0.0 0.7 1.0 1.0<|conf|>0.8<|attrs|>appearance=brown wood<|attrs|>size=large<|edges|><|edge|><|src|>person1<|pred|>next_to<|tgt|>person2<|edge|><|src|>person1<|pred|>sitting_on<|tgt|>table1<|edge|><|src|>person2<|pred|>sitting_on<|tgt|>table1<|end_graph|>"
# Hard upper bound on how many tokens analyze() will render; also the
# maximum of the "Rendered token limit" slider.
MAX_DETAIL_ROWS = 12000
@lru_cache(maxsize=4)
def _load_tokeniser_cached(model_id: str) -> Any:
    """Load the fast tokeniser for an already-normalised model id (memoised)."""
    return AutoTokenizer.from_pretrained(model_id, use_fast=True)


def load_tokeniser(model_id: str) -> Any:
    """Return the tokeniser for ``model_id``, reusing a small in-process cache.

    The id is stripped of surrounding whitespace *before* the cache lookup, so
    ``"org/model"`` and ``" org/model "`` share one cache entry instead of
    loading the same tokeniser twice and evicting other entries.

    Args:
        model_id: Model identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokeniser object returned by ``AutoTokenizer.from_pretrained``.

    Raises:
        ValueError: If ``model_id`` is empty or whitespace-only.
    """
    normalised = model_id.strip()
    if not normalised:
        raise ValueError("Tokeniser model id is required.")
    return _load_tokeniser_cached(normalised)


# Keep the lru_cache management helpers reachable under the public name so
# existing callers of load_tokeniser.cache_info()/cache_clear() keep working.
load_tokeniser.cache_info = _load_tokeniser_cached.cache_info
load_tokeniser.cache_clear = _load_tokeniser_cached.cache_clear
def parse_added_tokens(raw_tokens: str) -> list[str]:
    """Parse newline-separated token text into a clean token list.

    Each line is stripped of surrounding whitespace; blank lines are dropped
    and repeated tokens are kept only at their first occurrence.

    Args:
        raw_tokens: Text with one candidate token per line.

    Returns:
        Unique, non-empty tokens in first-seen order.
    """
    stripped_lines = (line.strip() for line in raw_tokens.splitlines())
    # dict keys keep insertion order, so fromkeys() removes duplicates while
    # preserving the order in which tokens first appeared.
    return list(dict.fromkeys(entry for entry in stripped_lines if entry))
def split_by_added_tokens(text: str, added_tokens: list[str]) -> list[tuple[str, str]]:
    """Split ``text`` into alternating runs of base text and added tokens.

    The text is scanned left to right. At each position the longest added
    token that starts there wins; everything between matches is returned as a
    single ``"base"`` piece.

    Args:
        text: Input text to split.
        added_tokens: Literal token strings to recognise. Empty strings are
            ignored.

    Returns:
        ``(kind, value)`` pairs in text order, where ``kind`` is ``"added"``
        for a matched token and ``"base"`` for any other run of text.
        Concatenating the values reproduces ``text``. An empty ``text`` yields
        an empty list.
    """
    if not text:
        return []

    # An empty token "matches" everywhere without consuming any input, which
    # would stall the scan forever, so it is dropped up front.
    candidates = sorted((token for token in added_tokens if token), key=len, reverse=True)
    if not candidates:
        return [("base", text)]

    pieces: list[tuple[str, str]] = []
    text_length = len(text)
    base_start = 0  # start of the base run that has not been emitted yet
    index = 0
    while index < text_length:
        # Candidates are ordered longest first, so the first hit is the
        # longest token starting at this position.
        match = next(
            (token for token in candidates if text.startswith(token, index)),
            None,
        )
        if match is None:
            index += 1
            continue
        if base_start < index:
            pieces.append(("base", text[base_start:index]))
        pieces.append(("added", match))
        index += len(match)
        base_start = index

    if base_start < text_length:
        pieces.append(("base", text[base_start:]))
    return pieces
def tokenise_base_piece(tokeniser: Any, text: str) -> list[dict[str, str]]:
    """Tokenise a run of plain text with the base tokeniser.

    Args:
        tokeniser: Object that is callable with ``(text,
            add_special_tokens=..., return_offsets_mapping=...)`` and that
            provides ``encode``, ``convert_ids_to_tokens`` and ``decode``.
        text: Text containing no added tokens.

    Returns:
        One record per token id with the keys ``"kind"`` (always ``"base"``),
        ``"id"`` (the id as a string), ``"raw"`` (the tokeniser's own token
        string) and ``"display"`` (the matching slice of ``text`` when offsets
        are available, otherwise the decoded token). Empty ``text`` yields an
        empty list.
    """
    if not text:
        return []

    offsets: list[tuple[int, int]] = []
    try:
        encoding = tokeniser(
            text,
            add_special_tokens=False,
            return_offsets_mapping=True,
        )
        input_ids = encoding["input_ids"]
        offsets = encoding.get("offset_mapping") or []
    except Exception:
        # If the offset-aware call fails for any reason, fall back to plain
        # ids; the display text is then recovered by decoding each id below.
        input_ids = tokeniser.encode(text, add_special_tokens=False)

    # Unwrap results that came back as a batch of one.
    if input_ids and isinstance(input_ids[0], list):
        input_ids = input_ids[0]
    offsets_are_batched = (
        bool(offsets)
        and isinstance(offsets[0], list)
        and bool(offsets[0])
        and isinstance(offsets[0][0], (list, tuple))
    )
    if offsets_are_batched:
        offsets = offsets[0]

    token_strings = tokeniser.convert_ids_to_tokens(input_ids)
    string_count = len(token_strings)
    offset_count = len(offsets)

    records: list[dict[str, str]] = []
    for position, token_id in enumerate(input_ids):
        surface = ""
        if position < offset_count:
            start, end = offsets[position]
            if start < end:
                surface = text[start:end]

        if not surface:
            # No usable offset span: show what the single id decodes to, or
            # the raw token string when decoding itself fails.
            try:
                surface = tokeniser.decode(
                    [token_id],
                    skip_special_tokens=False,
                    clean_up_tokenization_spaces=False,
                )
            except Exception:
                surface = token_strings[position] if position < string_count else ""

        raw = token_strings[position] if position < string_count else surface
        records.append(
            {
                "kind": "base",
                "id": str(token_id),
                "raw": raw,
                "display": surface,
            }
        )
    return records
def tokenise_with_added_tokens(
    text: str,
    added_tokens: list[str],
    tokeniser: Any,
) -> list[dict[str, str]]:
    """Tokenise ``text``, counting each added token as exactly one token.

    The text is first split with ``split_by_added_tokens``; added-token pieces
    become single records and every other piece is passed to
    ``tokenise_base_piece``.

    Args:
        text: Input text to tokenise.
        added_tokens: Literal strings to treat as single tokens.
        tokeniser: Base tokeniser used for the text between added tokens.

    Returns:
        Token records in text order. Added tokens carry ``kind == "added"``
        and the placeholder id ``"added"``.
    """
    records: list[dict[str, str]] = []
    for piece_kind, piece_text in split_by_added_tokens(text, added_tokens):
        if piece_kind != "added":
            records += tokenise_base_piece(tokeniser, piece_text)
            continue
        records.append(
            {
                "kind": "added",
                "id": "added",
                "raw": piece_text,
                "display": piece_text,
            }
        )
    return records
# Single-pass translation table: backslash and control characters become
# visible escape sequences, and a space becomes a middle dot.
_PRINTABLE_TABLE = str.maketrans(
    {
        "\\": "\\\\",
        "\n": "\\n",
        "\r": "\\r",
        "\t": "\\t",
        " ": "·",
    }
)


def printable(value: str) -> str:
    """Return ``value`` with whitespace and backslashes made visible.

    Backslashes are doubled, newline / carriage return / tab are shown as
    ``\\n`` / ``\\r`` / ``\\t``, and each space is shown as ``·``.
    """
    return value.translate(_PRINTABLE_TABLE)
def render_metrics(total: int, base: int, added: int, chars: int) -> str:
    """Render the four summary counters as an HTML metrics grid.

    Args:
        total: Total number of tokens.
        base: Number of tokens produced by the base tokeniser.
        added: Number of added tokens.
        chars: Number of characters in the input text.

    Returns:
        HTML for a ``metrics-grid`` container holding one ``metric-card`` per
        counter, each value formatted with thousands separators.
    """
    cards = (
        ("metric-card primary", "Total tokens", total),
        ("metric-card", "Base tokeniser", base),
        ("metric-card", "Added tokens", added),
        ("metric-card", "Characters", chars),
    )
    card_lines = [
        f'<div class="{css_class}"><span>{label}</span><strong>{value:,}</strong></div>'
        for css_class, label, value in cards
    ]
    # The leading and trailing empty entries reproduce the newline that
    # opens and closes the markup.
    return "\n".join(["", '<div class="metrics-grid">', *card_lines, "</div>", ""])
def render_added_token_chips(added_tokens: list[str]) -> str:
    """Render the configured added tokens as a row of HTML chips.

    Args:
        added_tokens: Tokens to display; each one is HTML-escaped.

    Returns:
        A ``chip-wrap`` container with one ``added-chip`` span per token, or
        an ``empty-note`` placeholder when the list is empty.
    """
    if added_tokens:
        chip_spans = [
            f'<span class="added-chip">{escape(token)}</span>' for token in added_tokens
        ]
        return '<div class="chip-wrap">' + "".join(chip_spans) + "</div>"
    return '<div class="empty-note">No added tokens configured.</div>'
def render_highlights(tokens: list[dict[str, str]], max_rendered: int) -> str:
    """Render tokens as colour-coded HTML spans with hover tooltips.

    Args:
        tokens: Token records with ``"kind"``, ``"id"``, ``"raw"`` and
            ``"display"`` keys.
        max_rendered: Maximum number of leading tokens to render.

    Returns:
        A note line followed by a ``token-output`` container, or an
        ``empty-note`` placeholder when there are no tokens. Added tokens get
        the ``added`` class; base tokens cycle through ``base-0``..``base-2``
        according to their position in the rendered list.
    """
    if not tokens:
        return '<div class="empty-note">No tokens to show.</div>'

    visible = tokens[:max_rendered]
    rendered: list[str] = []
    for position, entry in enumerate(visible):
        is_added = entry["kind"] == "added"
        css_class = "token added" if is_added else f"token base base-{position % 3}"
        kind_label = "added" if is_added else "base"
        shown_text = entry["display"]
        tooltip = escape(
            f"#{position + 1} | {kind_label} | id: {entry['id']} | {printable(entry['raw'])}"
        )
        # Tokens with no display text still get a non-breaking space as
        # their span content.
        inner = escape(shown_text) if shown_text else "&nbsp;"
        rendered.append(f'<span class="{css_class}" title="{tooltip}">{inner}</span>')

    shown_count, total_count = len(visible), len(tokens)
    if total_count > shown_count:
        note = f"Showing first {shown_count:,} of {total_count:,} tokens"
    else:
        note = f"{total_count:,} tokens shown"

    joined = "".join(rendered)
    return (
        f'\n<div class="token-note">{note}</div>'
        f'\n<div class="token-output">{joined}</div>\n'
    )
def build_rows(tokens: list[dict[str, str]], max_rendered: int) -> list[list[str]]:
    """Build the rows of the token details table.

    Args:
        tokens: Token records with ``"kind"``, ``"id"``, ``"raw"`` and
            ``"display"`` keys.
        max_rendered: Maximum number of leading tokens to include.

    Returns:
        One ``[number, kind, id, text]`` row per included token, where
        ``number`` is the 1-based position as a string and ``text`` is the
        printable form of the display text (or of the raw token when the
        display text is empty).
    """
    return [
        [
            str(number),
            entry["kind"],
            entry["id"],
            printable(entry["display"] or entry["raw"]),
        ]
        for number, entry in enumerate(tokens[:max_rendered], start=1)
    ]
def analyze(
    text: str,
    added_token_text: str,
    model_id: str,
    max_rendered: int,
) -> tuple[str, str, str, list[list[str]], str]:
    """Tokenise the input and build every output shown in the UI.

    Args:
        text: Input text to analyse; a falsy value is treated as ``""``.
        added_token_text: Newline-separated added tokens.
        model_id: Tokeniser model id to load.
        max_rendered: Requested render limit; a falsy value means 2000 and
            the result is clamped to ``1..MAX_DETAIL_ROWS``.

    Returns:
        ``(metrics_html, added_chips_html, highlights_html, table_rows,
        status_html)``. If loading the tokeniser or tokenising fails, the
        metrics show zero tokens, the table is empty and the error message
        fills both the highlights and the status slots.
    """
    source_text = text or ""
    added_tokens = parse_added_tokens(added_token_text or "")
    requested_limit = int(max_rendered or 2000)
    max_rendered = max(1, min(requested_limit, MAX_DETAIL_ROWS))

    try:
        tokens = tokenise_with_added_tokens(
            source_text,
            added_tokens,
            load_tokeniser(model_id),
        )
    except Exception as exc:
        # UI boundary: report any failure in the page instead of raising.
        error_html = f'<div class="status error">Tokeniser error: {escape(str(exc))}</div>'
        return (
            render_metrics(0, 0, 0, len(source_text)),
            render_added_token_chips(added_tokens),
            error_html,
            [],
            error_html,
        )

    added_count = sum(token["kind"] == "added" for token in tokens)
    total_count = len(tokens)
    ready_html = f'<div class="status ready">Loaded {escape(model_id.strip())}</div>'
    return (
        render_metrics(total_count, total_count - added_count, added_count, len(source_text)),
        render_added_token_chips(added_tokens),
        render_highlights(tokens, max_rendered),
        build_rows(tokens, max_rendered),
        ready_html,
    )
# Stylesheet for the app: page width and title, plus the classes emitted by
# the render_* helpers (metrics-grid / metric-card, chip-wrap / added-chip,
# token-note / token-output / token, status, empty-note). It is handed to
# demo.launch(css=...) in the entry point.
CSS = """
.gradio-container {
max-width: 1440px !important;
}
.app-title h1 {
font-size: 1.7rem;
letter-spacing: 0;
margin-bottom: 0.35rem;
}
.app-title p {
color: #53636d;
margin-top: 0;
}
.metrics-grid {
display: grid;
grid-template-columns: repeat(4, minmax(130px, 1fr));
gap: 10px;
}
.metric-card {
border: 1px solid #d6dee2;
border-radius: 8px;
background: #fff;
padding: 12px 14px;
}
.metric-card.primary {
border-color: #8ccbd2;
}
.metric-card span {
display: block;
color: #53636d;
font-size: 0.84rem;
font-weight: 700;
}
.metric-card strong {
display: block;
color: #172026;
font-size: 2rem;
line-height: 1;
margin-top: 6px;
}
.metric-card.primary strong {
color: #006873;
}
.chip-wrap {
display: flex;
flex-wrap: wrap;
gap: 7px;
}
.added-chip,
.token {
border-radius: 5px;
font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, "Liberation Mono", monospace;
}
.added-chip {
border: 1px solid #d78c00;
background: #ffe8b3;
color: #4f3200;
padding: 4px 7px;
font-size: 0.86rem;
}
.token-note {
color: #53636d;
font-size: 0.86rem;
margin-bottom: 8px;
text-align: right;
}
.token-output {
min-height: 300px;
max-height: 620px;
overflow: auto;
border: 1px solid #d6dee2;
border-radius: 8px;
background: #fff;
padding: 14px;
line-height: 1.9;
white-space: pre-wrap;
overflow-wrap: anywhere;
}
.token {
display: inline;
border: 1px solid #cbd7db;
margin: 0 1px 2px 0;
padding: 2px 3px;
cursor: default;
white-space: pre-wrap;
}
.token.added {
border-color: #d78c00;
background: #ffe8b3;
color: #4f3200;
font-weight: 750;
}
.token.base-0 {
background: #e4f4ef;
color: #14231e;
}
.token.base-1 {
background: #e9eefb;
color: #182139;
}
.token.base-2 {
background: #f2e8f7;
color: #2e2135;
}
.status {
border-radius: 999px;
display: inline-block;
padding: 6px 11px;
font-size: 0.88rem;
border: 1px solid #d6dee2;
}
.status.ready {
border-color: #9bd2aa;
color: #137333;
}
.status.error {
border-color: #f1a6a0;
color: #b42318;
border-radius: 8px;
max-width: 100%;
overflow-wrap: anywhere;
}
.empty-note {
color: #53636d;
}
@media (max-width: 820px) {
.metrics-grid {
grid-template-columns: 1fr 1fr;
}
}
@media (max-width: 520px) {
.metrics-grid {
grid-template-columns: 1fr;
}
}
"""
# UI definition. Components are created in display order, so statement order
# here is the page layout.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been stripped — confirm which components belong to which
# Row/Column against the deployed layout.
with gr.Blocks(title="Custom Added Token Counter") as demo:
    # Page heading; the "app-title" class is styled in CSS.
    gr.Markdown(
        """
# Custom added token counter
Count raw tokeniser pieces with editable added tokens.
""",
        elem_classes=["app-title"],
    )
    # Settings row: which tokeniser to load and how many tokens to render.
    with gr.Row():
        model_id = gr.Textbox(
            value=DEFAULT_MODEL_ID,
            label="Tokeniser model",
            scale=4,
            container=True,
        )
        max_rendered = gr.Slider(
            minimum=100,
            maximum=MAX_DETAIL_ROWS,
            value=2000,
            step=100,
            label="Rendered token limit",
            scale=2,
        )
    # Editing row: the text to analyse next to the added-token list
    # (one token per line).
    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                value=SAMPLE_TEXT,
                label="Input text",
                lines=14,
                max_lines=28,
                autoscroll=False,
            )
        with gr.Column(scale=2):
            added_tokens = gr.Textbox(
                value=DEFAULT_ADDED_TOKEN_TEXT,
                label="Added tokens",
                lines=14,
                max_lines=28,
                autoscroll=False,
            )
    with gr.Row():
        count_button = gr.Button("Count tokens", variant="primary")
        clear_button = gr.Button("Clear text")
    # Output components, filled by analyze().
    status_html = gr.HTML()
    metrics_html = gr.HTML()
    added_token_html = gr.HTML(label="Active added tokens")
    highlighted_html = gr.HTML(label="Highlighted tokens")
    token_table = gr.Dataframe(
        headers=["#", "type", "id", "token"],
        datatype=["str", "str", "str", "str"],
        interactive=False,
        wrap=True,
        label="Token details",
    )
    # Input order matches analyze()'s parameters; output order matches the
    # tuple it returns.
    analyze_inputs = [input_text, added_tokens, model_id, max_rendered]
    analyze_outputs = [
        metrics_html,
        added_token_html,
        highlighted_html,
        token_table,
        status_html,
    ]
    # Run the analysis on the Blocks load event, on the button, and whenever
    # one of the textboxes is submitted.
    demo.load(analyze, inputs=analyze_inputs, outputs=analyze_outputs)
    count_button.click(analyze, inputs=analyze_inputs, outputs=analyze_outputs)
    input_text.submit(analyze, inputs=analyze_inputs, outputs=analyze_outputs)
    model_id.submit(analyze, inputs=analyze_inputs, outputs=analyze_outputs)
    added_tokens.submit(analyze, inputs=analyze_inputs, outputs=analyze_outputs)
    # Clearing empties the input box first, then re-runs the analysis so the
    # outputs reflect the now-empty text.
    clear_button.click(lambda: "", outputs=input_text).then(
        analyze,
        inputs=analyze_inputs,
        outputs=analyze_outputs,
    )
if __name__ == "__main__":
    # NOTE(review): css is passed to launch() rather than gr.Blocks(); this
    # presumes a Gradio version whose launch() accepts css — confirm against
    # the pinned version. ssr_mode=False presumably turns off server-side
    # rendering — verify against the Gradio docs.
    demo.launch(css=CSS, ssr_mode=False)