Spaces:
Runtime error
Runtime error
File size: 4,547 Bytes
fe4863c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import torch
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ----------------------------------------
# 1. Load from Hugging Face Hub
# ----------------------------------------
# Change this to YOUR pushed model repo
HUB_MODEL_ID = "Abelex/Sentence-Chunking-Afri_BERTA_amharic_longtext"
# <--- EDIT IF NEEDED
# Prefer GPU when available; all tensors are moved to this device before inference.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512 # model context window in TOKENS
# Load tokenizer and model directly from HF Hub
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_ID)
model.to(DEVICE)
# Inference-only app: eval() disables dropout/batch-norm training behavior.
model.eval()
# Label mapping from config.
# Keys are coerced to int because id2label loaded from a JSON config may
# arrive with string keys ("0", "1", ...).
id2label = {int(k): v for k, v in model.config.id2label.items()}
num_labels = len(id2label)
# ----------------------------------------
# Helper: highlight tokens after MAX_LENGTH in red (HTML)
# ----------------------------------------
def highlight_token_overflow(text: str, max_tokens: int = 512) -> str:
    """
    Tokenize `text` and render every token as HTML, coloring tokens that
    fall beyond `max_tokens` red, so the user sees exactly which tokens
    lie outside the model's context window and will be truncated.

    Args:
        text: Raw input text (tokenized without truncation).
        max_tokens: Context-window size; tokens at index >= max_tokens
            are highlighted as overflow.

    Returns:
        An HTML string suitable for a `gr.HTML` component.
    """
    if not text.strip():
        return "<i>No text provided.</i>"
    # Tokenize without truncation (so we can see ALL tokens)
    tokens = tokenizer.tokenize(text)
    if len(tokens) == 0:
        return "<i>No tokens produced by tokenizer.</i>"
    spans = []
    for i, tok in enumerate(tokens):
        # Minimal HTML escape. BUGFIX: the original replacements were
        # no-ops ("&" -> "&", "<" -> "<", ">" -> ">"), letting raw markup
        # characters in tokens break (or inject into) the generated HTML.
        # "&" must be escaped first so the "&" in "&lt;"/"&gt;" is not
        # double-escaped.
        safe_tok = (
            tok.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
        )
        if i >= max_tokens:
            spans.append(f"<span style='color:red;font-weight:bold;'>{safe_tok}</span>")
        else:
            spans.append(f"<span>{safe_tok}</span>")
    html = " ".join(spans)
    if len(tokens) > max_tokens:
        html += (
            f"<br><br>"
            f"<small style='color:red;'>"
            f"Note: Tokens in <b>red</b> are beyond the model context window "
            f"({max_tokens} tokens) and will be truncated."
            f"</small>"
        )
    else:
        # BUGFIX: original contained mojibake ("β€") where "≤" was intended.
        html += (
            f"<br><br>"
            f"<small>Token count: {len(tokens)} (≤ {max_tokens}, no truncation).</small>"
        )
    return html
# ----------------------------------------
# 2. Prediction
# ----------------------------------------
def predict_amharic_news(text):
    """
    Classify a single Amharic news text.

    Returns a 3-tuple consumed by the Gradio interface:
      * prediction string ("Predicted Label: ..."),
      * [(label, probability), ...] sorted by probability, highest first,
      * HTML visualizing which tokens exceed the model context window.
    """
    if not text.strip():
        # Keep the 3-output shape even when there is nothing to classify.
        return "Please enter text.", None, "<i>No text provided.</i>"
    # Inference path: pad/truncate to exactly MAX_LENGTH tokens.
    batch = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt"
    )
    batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
    with torch.no_grad():
        logits = model(**batch).logits
    # Single-example batch -> take row 0 of the softmax.
    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    pred_id = int(np.argmax(probs))
    pred_label = id2label.get(pred_id, f"LABEL_{pred_id}")
    # Probability table, highest probability first.
    rows = sorted(
        ((id2label.get(i, f"LABEL_{i}"), float(probs[i])) for i in range(num_labels)),
        key=lambda row: row[1],
        reverse=True,
    )
    # Token view: tokens beyond MAX_LENGTH rendered in red.
    token_highlight_html = highlight_token_overflow(text, max_tokens=MAX_LENGTH)
    return f"Predicted Label: {pred_label}", rows, token_highlight_html
# ----------------------------------------
# 3. Gradio Interface
# ----------------------------------------
# One textbox in, three components out (prediction string, probability
# table, token-overflow HTML) — matching the 3-tuple returned by
# `predict_amharic_news`.
demo = gr.Interface(
    fn=predict_amharic_news,
    inputs=gr.Textbox(
        lines=5,
        label="Enter Amharic News Text",
        # NOTE(review): the placeholder below appears mojibake-garbled in
        # this copy of the file — verify the original Amharic string.
        placeholder="α₯α£αα α¨α ααα αα α½αα α«α΅αα‘..."
    ),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Dataframe(
            headers=["Label", "Probability"],
            label="Class Probabilities"
        ),
        gr.HTML(label="Tokenizer view (tokens > 512 are red)")
    ],
    title="Amharic News Classifier",
    # NOTE(review): description says "XLM-RoBERTa" but HUB_MODEL_ID names an
    # AfriBERTa checkpoint — confirm which architecture is actually deployed.
    description=(
        "XLM-RoBERTa model loaded directly from Hugging Face Hub (raw text input, no preprocessing). "
        "Below, tokenizer output shows which tokens are beyond the 512-token context window (in red)."
    )
)
# Blocking call: starts the local Gradio server.
demo.launch()
|