| |
| """Gradio demo for the multilingual token-classification language ID model.""" |
|
|
| from __future__ import annotations |
|
|
| from collections import Counter, defaultdict |
| from functools import lru_cache |
| from typing import Any |
|
|
| import pandas as pd |
| import gradio as gr |
| from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline |
|
|
| from language import ALL_LANGS, LANG_ISO2_TO_ISO3 |
|
|
|
|
# Model identifier handed to AutoModelForTokenClassification.from_pretrained()
# in get_pipeline().
| MODEL_CHECKPOINT = "DerivedFunction/polyglot-tagger-66L-3M" |
# predict() slices its input to this many characters (not tokens) before
# running the pipeline; anything beyond that is ignored.
| MAX_TEXT_CHARS = 512 |
|
|
|
|
@lru_cache(maxsize=1)
def get_pipeline():
    """Return the shared token-classification pipeline.

    The pipeline is built on the first call and memoised by ``lru_cache``,
    so later calls reuse the same object instead of loading the model again.
    The model weights come from ``MODEL_CHECKPOINT``; the tokenizer is loaded
    separately from ``xlm-roberta-base``.
    """
    return pipeline(
        "token-classification",
        model=AutoModelForTokenClassification.from_pretrained(MODEL_CHECKPOINT),
        tokenizer=AutoTokenizer.from_pretrained("xlm-roberta-base"),
        # "simple" makes the pipeline return grouped entities rather than
        # one entry per sub-word token.
        aggregation_strategy="simple",
    )
|
|
|
|
def normalize_label(label: str) -> str:
    """Return *label* lower-cased, minus a leading ``B-`` or ``I-`` tag.

    The prefix check is case-sensitive and happens before lower-casing, so
    ``"b-en"`` is returned unchanged apart from case.
    """
    has_bio_prefix = label.startswith("B-") or label.startswith("I-")
    bare = label[2:] if has_bio_prefix else label
    return bare.lower()
|
|
|
|
# Column order shared by the empty and the populated spans table.
_SPAN_COLUMNS = ["token", "language", "score", "start", "end"]


def _collect_spans(
    entities: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], Counter[str], defaultdict[str, float]]:
    """Flatten pipeline entities into table rows and tally them per language.

    Entities whose normalised label is ``"o"`` are skipped.

    Returns:
        ``(rows, counts, score_totals)`` where ``rows`` holds one dict per kept
        entity, ``counts`` maps label -> number of kept entities and
        ``score_totals`` maps label -> sum of their scores.
    """
    rows: list[dict[str, Any]] = []
    counts: Counter[str] = Counter()
    score_totals: defaultdict[str, float] = defaultdict(float)

    for entity in entities:
        # Grouped output carries "entity_group"; fall back to "entity" and
        # finally to "O" when neither key is present.
        label = normalize_label(entity.get("entity_group", entity.get("entity", "O")))
        if label == "o":
            continue
        score = float(entity.get("score", 0.0))
        counts[label] += 1
        score_totals[label] += score
        rows.append(
            {
                "token": entity.get("word", ""),
                "language": label,
                "score": round(score, 4),
                "start": entity.get("start"),
                "end": entity.get("end"),
            }
        )
    return rows, counts, score_totals


def _render_summary(
    span_count: int,
    token_counts: Counter[str],
    token_scores: defaultdict[str, float],
) -> str:
    """Build the HTML summary card for the tallied language labels."""
    import html  # local import keeps this block self-contained

    if not token_counts:
        return """
        <div class="summary-card">
          <div class="summary-kicker">Prediction</div>
          <div class="summary-main">No language spans detected</div>
          <div class="summary-subtitle">Try a longer sample or a cleaner single-language paragraph.</div>
        </div>
        """

    # NOTE(review): the winner is the label with the most spans, not the most
    # characters covered; Counter.most_common() keeps first-seen order on ties.
    dominant_lang, dominant_count = token_counts.most_common(1)[0]
    avg_score = token_scores[dominant_lang] / dominant_count
    iso3 = LANG_ISO2_TO_ISO3.get(dominant_lang, "n/a")

    # Labels originate from the model configuration, so escape them before
    # interpolating into markup.
    chips = "".join(
        f"<span class='chip'>{html.escape(lang.upper())} <strong>{count}</strong></span>"
        for lang, count in token_counts.most_common(5)
    )
    return f"""
        <div class="summary-card">
          <div class="summary-kicker">Prediction</div>
          <div class="summary-main">{html.escape(dominant_lang.upper())}</div>
          <div class="summary-subtitle">ISO-3: {html.escape(str(iso3))} | analyzed tokens: {span_count}</div>
          <div class="summary-score">Avg confidence: {avg_score:.3f}</div>
          <div class="chip-row">{chips}</div>
        </div>
        """


def predict(text: str) -> tuple[str, pd.DataFrame, dict[str, Any]]:
    """Run language tagging on *text* and prepare the three UI outputs.

    Args:
        text: User input. ``None``, empty and whitespace-only values are
            treated as "no input". Only the first ``MAX_TEXT_CHARS``
            characters of the stripped text are sent to the model.

    Returns:
        A ``(summary_html, spans, raw)`` tuple:

        * ``summary_html`` - HTML snippet naming the most frequent language
          label, its average score and up to five label chips (or a
          placeholder when there is no input or no labelled span).
        * ``spans`` - DataFrame with columns token/language/score/start/end,
          sorted by ``start`` then ``end`` with missing offsets last.
        * ``raw`` - dict with the checkpoint name, ``len(ALL_LANGS)``, the ten
          most common labels and the unmodified pipeline output; ``{}`` when
          there is no input.
    """
    text = (text or "").strip()
    if not text:
        return (
            "<div class='empty-state'>Paste some text to see the model's language signal.</div>",
            pd.DataFrame(columns=_SPAN_COLUMNS),
            {},
        )

    entities = get_pipeline()(text[:MAX_TEXT_CHARS])
    rows, token_counts, token_scores = _collect_spans(entities)

    spans = pd.DataFrame(rows, columns=_SPAN_COLUMNS)
    if not spans.empty:
        spans = spans.sort_values(by=["start", "end"], na_position="last")

    raw = {
        "model": MODEL_CHECKPOINT,
        "languages_supported": len(ALL_LANGS),
        "top_predictions": token_counts.most_common(10),
        "entities": entities,
    }
    return _render_summary(len(rows), token_counts, token_scores), spans, raw
|
|
|
|
# Sample inputs for the demo. Each entry is offered as a clickable example
# via gr.Examples, and the first entry pre-fills the input textbox.
| EXAMPLES = [ |
| "This model should recognize English text without much trouble.", |
| "Hola, este ejemplo mezcla palabras en espanol para probar el detector.", |
| "هذا مثال باللغة العربية لاختبار النموذج على فقرة قصيرة.", |
| "Bonjour, ceci est un petit texte en francais pour un test rapide.", |
| "今日は日本語の文章を入力して、モデルの反応を確認します。", |
| "This sentence mixes English and العربية to show mixed-language behavior.", |
| ] |
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...). It defines the colour
# variables and the classes used by the raw HTML this module emits: the hero
# header (.wrap, .hero, .eyebrow, .title, .subtitle), the prediction card
# (.summary-*, .chip-row, .chip), the .empty-state placeholder and the
# .footer-note.
| CSS = """ |
| :root { |
| --bg-1: #06111f; |
| --bg-2: #0b1f33; |
| --card: rgba(10, 20, 33, 0.72); |
| --card-border: rgba(255, 255, 255, 0.12); |
| --text: #f4f7fb; |
| --muted: #b7c3d6; |
| --accent: #7dd3fc; |
| --accent-2: #f59e0b; |
| } |
| body { |
| background: |
| radial-gradient(circle at top left, rgba(125, 211, 252, 0.22), transparent 28%), |
| radial-gradient(circle at top right, rgba(245, 158, 11, 0.16), transparent 24%), |
| linear-gradient(135deg, var(--bg-1), var(--bg-2)); |
| } |
| .wrap { |
| max-width: 1180px; |
| margin: 0 auto; |
| } |
| .hero { |
| padding: 28px 28px 22px; |
| border: 1px solid var(--card-border); |
| border-radius: 24px; |
| background: linear-gradient(180deg, rgba(255,255,255,0.08), rgba(255,255,255,0.03)); |
| box-shadow: 0 24px 80px rgba(0, 0, 0, 0.28); |
| backdrop-filter: blur(14px); |
| } |
| .eyebrow { |
| text-transform: uppercase; |
| letter-spacing: 0.22em; |
| color: var(--accent); |
| font-size: 12px; |
| font-weight: 700; |
| margin-bottom: 8px; |
| } |
| .title { |
| font-size: clamp(32px, 5vw, 56px); |
| line-height: 1.02; |
| margin: 0; |
| color: var(--text); |
| font-weight: 800; |
| } |
| .subtitle { |
| margin-top: 12px; |
| color: var(--muted); |
| font-size: 16px; |
| max-width: 820px; |
| } |
| .summary-card { |
| border: 1px solid var(--card-border); |
| border-radius: 22px; |
| padding: 22px; |
| background: rgba(7, 13, 24, 0.7); |
| color: var(--text); |
| min-height: 220px; |
| } |
| .summary-kicker { |
| color: var(--accent); |
| text-transform: uppercase; |
| letter-spacing: 0.18em; |
| font-size: 11px; |
| font-weight: 700; |
| } |
| .summary-main { |
| font-size: 42px; |
| font-weight: 900; |
| margin-top: 8px; |
| color: white; |
| } |
| .summary-subtitle, .summary-score { |
| color: var(--muted); |
| margin-top: 8px; |
| } |
| .chip-row { |
| display: flex; |
| flex-wrap: wrap; |
| gap: 8px; |
| margin-top: 18px; |
| } |
| .chip { |
| border: 1px solid rgba(125, 211, 252, 0.25); |
| background: rgba(125, 211, 252, 0.08); |
| color: var(--text); |
| padding: 7px 10px; |
| border-radius: 999px; |
| font-size: 13px; |
| } |
| .empty-state { |
| padding: 18px 20px; |
| border-radius: 18px; |
| border: 1px dashed rgba(255,255,255,0.16); |
| color: var(--muted); |
| background: rgba(255,255,255,0.03); |
| } |
| .gradio-container .gr-textbox textarea { |
| font-size: 15px !important; |
| } |
| .footer-note { |
| color: var(--muted); |
| font-size: 13px; |
| margin-top: 8px; |
| } |
| """ |
|
|
|
|
# UI definition: a hero header, an input column (textbox, buttons, examples),
# an output column (summary card, spans table, raw JSON) and a footer note.
with gr.Blocks(title="Polyglot Tagger Studio", css=CSS) as demo:
    gr.HTML(
        """
        <div class="wrap hero">
          <div class="eyebrow">Multilingual Language ID</div>
          <h1 class="title">Polyglot Tagger Studio</h1>
          <div class="subtitle">
            A Gradio demo for the token-classification model behind this repo. Paste a sentence or paragraph,
            and the app will surface the dominant language signal, token-level spans, and raw predictions. Note that this is experimental and does not replace a text classifier: be prepared for unexpected results.
          </div>
        </div>
        """
    )

    with gr.Row():
        # Left column: input and controls.
        with gr.Column(scale=6):
            input_text = gr.Textbox(
                label="Text",
                lines=12,
                placeholder="Paste a sentence or a short paragraph here...",
                value=EXAMPLES[0],
            )
            gr.Markdown(
                "Try a clean sentence for a single-language prediction, or mix languages to see how the model behaves."
            )
            with gr.Row():
                analyze_btn = gr.Button("Analyze", variant="primary")
                clear_btn = gr.Button("Clear")
            gr.Examples(
                examples=[[example] for example in EXAMPLES],
                inputs=input_text,
                label="Examples",
                cache_examples=False,
            )
        # Right column: the three values returned by predict().
        with gr.Column(scale=6):
            summary = gr.HTML()
            spans = gr.Dataframe(
                headers=["token", "language", "score", "start", "end"],
                datatype=["str", "str", "number", "number", "number"],
                label="Token-level spans",
                interactive=False,
                wrap=True,
            )
            raw = gr.JSON(label="Raw output")

    # Both the button and pressing Enter in the textbox run the same
    # prediction; each is exposed under its own API name.
    analyze_btn.click(
        fn=predict,
        inputs=input_text,
        outputs=[summary, spans, raw],
        api_name="analyze",
    )
    input_text.submit(
        fn=predict,
        inputs=input_text,
        outputs=[summary, spans, raw],
        api_name="analyze_text",
    )
    # "Clear" resets the three output components only; the input textbox is
    # not among the outputs and keeps its content.
    clear_btn.click(
        fn=lambda: ("", pd.DataFrame(columns=["token", "language", "score", "start", "end"]), {}),
        inputs=None,
        outputs=[summary, spans, raw],
        api_name="clear",
    )

    # The language count is derived from ALL_LANGS (the same source predict()
    # reports as "languages_supported") instead of a hard-coded number, so the
    # footer cannot drift from the actual label set.
    gr.HTML(
        f"""
        <div class="footer-note">
          Supported model languages: {len(ALL_LANGS)}. The demo uses the local repo checkpoint and the ISO-2 to ISO-3 mapping in language.py.
        </div>
        """
    )
|
|
|
|
if __name__ == "__main__":
    # Turn on request queuing, then start the web server. queue() hands back
    # the same Blocks object, so the two calls can be chained.
    demo.queue().launch()
|
|