import torch
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ----------------------------------------
# 1. Load from Hugging Face Hub
# ----------------------------------------
# Change this to YOUR pushed model repo if needed
HUB_MODEL_ID = "Abelex/Sentence-Chunking-Afri_BERTA_amharic_longtext"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512 # model context window in TOKENS
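# (512 tokens is the usual context limit for XLM-R-style encoders; anything
# longer must be truncated or chunked before inference.)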
# Load tokenizer and model directly from HF Hub
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_ID)
model.to(DEVICE)
model.eval()
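# eval() disables dropout so repeated predictions on the same input are
# deterministic; torch.no_grad() below skips gradient tracking at inference.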
# Label mapping from config
id2label = {int(k): v for k, v in model.config.id2label.items()}
num_labels = len(id2label)
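# For illustration, id2label typically looks like this (the actual labels come
# from the model's config.json and may differ):
#   {0: "local_news", 1: "sports", 2: "politics", 3: "business"}
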
# ----------------------------------------
# Helper: highlight tokens after MAX_LENGTH in red (HTML)
# ----------------------------------------
def highlight_token_overflow(text: str, max_tokens: int = 512) -> str:
"""
Tokenize the input text and generate HTML where tokens beyond
`max_tokens` are wrapped in red. This shows exactly which tokens
are outside the model's context window.
"""
if not text.strip():
return "No text provided."
# Tokenize without truncation (so we can see ALL tokens)
tokens = tokenizer.tokenize(text)
if len(tokens) == 0:
return "No tokens produced by tokenizer."
    spans = []
    for i, tok in enumerate(tokens):
        # minimal HTML escape so raw tokens render safely
        safe_tok = (
            tok.replace("&", "&amp;")
               .replace("<", "&lt;")
               .replace(">", "&gt;")
        )
        if i >= max_tokens:
            # tokens beyond the context window are wrapped in red
            spans.append(f"<span style='color:red;'>{safe_tok}</span>")
        else:
            spans.append(f"<span>{safe_tok}</span>")
html = " ".join(spans)
if len(tokens) > max_tokens:
html += (
f"
"
f""
f"Note: Tokens in red are beyond the model context window "
f"({max_tokens} tokens) and will be truncated."
f""
)
else:
html += (
f"
"
f"Token count: {len(tokens)} (≤ {max_tokens}, no truncation)."
)
return html
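
# Hypothetical usage sketch for the helper above (run in a REPL once the
# tokenizer is loaded; the sample text and token budget are arbitrary):
#   html = highlight_token_overflow("አዲስ አበባ ዛሬ", max_tokens=3)
#   # -> first 3 tokens in plain spans, the rest wrapped in red spans,
#   #    plus a bold red truncation note appended at the end
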
# ----------------------------------------
# 2. Prediction
# ----------------------------------------
def predict_amharic_news(text):
    if not text.strip():
        # Also return highlighted version (empty)
        return "Please enter text.", None, "No text provided."

    # For actual model inference: truncate to MAX_LENGTH tokens
    encoded = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt"
    )
    encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)
        logits = outputs.logits

    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    pred_id = int(np.argmax(probs))
    pred_label = id2label.get(pred_id, f"LABEL_{pred_id}")

    # Prepare probability table, sorted from most to least likely
    rows = []
    for i in range(num_labels):
        rows.append((id2label.get(i, f"LABEL_{i}"), float(probs[i])))
    rows = sorted(rows, key=lambda x: x[1], reverse=True)

    # Build HTML showing tokens; tokens beyond MAX_LENGTH in red
    token_highlight_html = highlight_token_overflow(text, max_tokens=MAX_LENGTH)

    # Now we return 3 outputs: prediction, probs table, token visualization
    return f"Predicted Label: {pred_label}", rows, token_highlight_html
# ----------------------------------------
# 3. Gradio Interface
# ----------------------------------------
demo = gr.Interface(
    fn=predict_amharic_news,
    inputs=gr.Textbox(
        lines=5,
        label="Enter Amharic News Text",
        placeholder="እባክዎ የአማርኛ ዜና ጽሑፍ ያስገቡ..."  # "Please enter Amharic news text..."
    ),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Dataframe(
            headers=["Label", "Probability"],
            label="Class Probabilities"
        ),
        gr.HTML(label="Tokenizer view (tokens > 512 are red)")
    ],
    title="Amharic News Classifier",
    description=(
        "XLM-RoBERTa model loaded directly from the Hugging Face Hub (raw text input, no preprocessing). "
        "The tokenizer view below highlights tokens beyond the 512-token context window in red."
    )
)

if __name__ == "__main__":
    demo.launch()
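    # Optional: demo.launch(share=True) creates a temporary public URL when
    # running locally (standard Gradio option); on Spaces the default is fine.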