import torch
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ----------------------------------------
# 1. Load from Hugging Face Hub
# ----------------------------------------
# Change this to YOUR pushed model repo
HUB_MODEL_ID = "Abelex/Sentence-Chunking-Afri_BERTA_amharic_longtext"  # <--- EDIT IF NEEDED

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512  # model context window in TOKENS

# Load tokenizer and model directly from HF Hub
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_ID)
model.to(DEVICE)
model.eval()

# Label mapping from config (config keys may be str — normalize to int ids)
id2label = {int(k): v for k, v in model.config.id2label.items()}
num_labels = len(id2label)


# ----------------------------------------
# Helper: highlight tokens after MAX_LENGTH in red (HTML)
# ----------------------------------------
def highlight_token_overflow(text: str, max_tokens: int = 512) -> str:
    """
    Tokenize the input text and generate HTML where tokens beyond
    `max_tokens` are wrapped in red. This shows exactly which tokens are
    outside the model's context window.

    Args:
        text: Raw input text (no preprocessing assumed).
        max_tokens: Context-window size; tokens at index >= max_tokens
            are rendered in red.

    Returns:
        An HTML string of space-separated tokens, followed by a note
        about truncation (or the total token count when none occurs).
    """
    if not text.strip():
        return "No text provided."

    # Tokenize without truncation (so we can see ALL tokens)
    tokens = tokenizer.tokenize(text)
    if len(tokens) == 0:
        return "No tokens produced by tokenizer."

    spans = []
    for i, tok in enumerate(tokens):
        # minimal HTML escape — '&' must be replaced first so it doesn't
        # re-escape the entities produced by the other two replacements
        safe_tok = (
            tok.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
        )
        if i >= max_tokens:
            # Beyond the context window: render in red
            spans.append(f'<span style="color:red">{safe_tok}</span>')
        else:
            spans.append(f"<span>{safe_tok}</span>")

    html = " ".join(spans)

    if len(tokens) > max_tokens:
        html += (
            "<br><br>"
            '<b style="color:red">'
            f"Note: Tokens in red are beyond the model context window "
            f"({max_tokens} tokens) and will be truncated."
            "</b>"
        )
    else:
        html += (
            "<br><br>"
            f"Token count: {len(tokens)} (≤ {max_tokens}, no truncation)."
        )
    return html


# ----------------------------------------
# 2. Prediction
# ----------------------------------------
def predict_amharic_news(text):
    """
    Classify raw Amharic text and return three Gradio outputs:
    (prediction string, [(label, probability), ...] sorted descending,
    HTML token visualization with overflow tokens in red).
    """
    if not text.strip():
        # Also return highlighted version (empty)
        return "Please enter text.", None, "No text provided."

    # For actual model inference: truncate to MAX_LENGTH tokens
    encoded = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

    pred_id = int(np.argmax(probs))
    pred_label = id2label.get(pred_id, f"LABEL_{pred_id}")

    # Prepare probability table, highest probability first
    rows = []
    for i in range(num_labels):
        rows.append((id2label.get(i, f"LABEL_{i}"), float(probs[i])))
    rows = sorted(rows, key=lambda x: x[1], reverse=True)

    # Build HTML showing tokens; tokens >512 in red
    token_highlight_html = highlight_token_overflow(text, max_tokens=MAX_LENGTH)

    # Now we return 3 outputs: prediction, probs table, token visualization
    return f"Predicted Label: {pred_label}", rows, token_highlight_html


# ----------------------------------------
# 3. Gradio Interface
# ----------------------------------------
demo = gr.Interface(
    fn=predict_amharic_news,
    inputs=gr.Textbox(
        lines=5,
        label="Enter Amharic News Text",
        placeholder="እባክዎ የአማርኛ ዜና ጽሑፍ ያስገቡ...",
    ),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Dataframe(
            headers=["Label", "Probability"],
            label="Class Probabilities",
        ),
        gr.HTML(label="Tokenizer view (tokens > 512 are red)"),
    ],
    title="Amharic News Classifier",
    description=(
        "XLM-RoBERTa model loaded directly from Hugging Face Hub (raw text input, no preprocessing). "
        "Below, tokenizer output shows which tokens are beyond the 512-token context window (in red)."
    ),
)

demo.launch()