Spaces:
Sleeping
Sleeping
import torch
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ----------------------------------------
# 1. Load from Hugging Face Hub
# ----------------------------------------
# Repository id of the fine-tuned classifier on the Hub.
# <--- EDIT IF NEEDED to point at your own pushed model repo.
HUB_MODEL_ID = "Abelex/amharic-news-bert-multilingual-cased"

# Use the GPU when one is present, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512  # model context window, measured in TOKENS

# Fetch tokenizer and classification model straight from the Hub,
# then put the model on the chosen device in inference mode.
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_ID)
model.to(DEVICE)
model.eval()

# Class-id -> label-name mapping straight from the model config
# (keys arrive as strings in some configs, hence the int() cast).
id2label = {int(idx): name for idx, name in model.config.id2label.items()}
num_labels = len(id2label)
# ----------------------------------------
# Helper: highlight tokens after MAX_LENGTH in red (HTML)
# ----------------------------------------
def highlight_token_overflow(text: str, max_tokens: int = 512, tokenizer_obj=None) -> str:
    """
    Tokenize the input text and render it as HTML, marking overflow in red.

    Tokens at index >= ``max_tokens`` fall outside the model's context
    window and will be truncated at inference time; they are wrapped in
    red spans so the user sees exactly what the model ignores.

    Args:
        text: Raw input text (tokenized WITHOUT truncation, so all
            tokens are visible).
        max_tokens: Size of the model context window, in tokens.
        tokenizer_obj: Optional override — any object with a
            ``tokenize(str) -> list[str]`` method. Defaults to the
            module-level Hub tokenizer.

    Returns:
        An HTML string with one ``<span>`` per token plus a trailing
        note about the token count / truncation.
    """
    if not text.strip():
        return "<i>No text provided.</i>"
    if tokenizer_obj is None:
        # Fall back to the module-level tokenizer loaded from the Hub.
        tokenizer_obj = tokenizer

    # Tokenize without truncation (so we can see ALL tokens).
    tokens = tokenizer_obj.tokenize(text)
    if not tokens:
        return "<i>No tokens produced by tokenizer.</i>"

    spans = []
    for i, tok in enumerate(tokens):
        # Minimal HTML escape. "&" must be replaced first so the
        # entities produced for "<" / ">" are not double-escaped.
        safe_tok = (
            tok.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
        )
        if i >= max_tokens:
            spans.append(f"<span style='color:red;font-weight:bold;'>{safe_tok}</span>")
        else:
            spans.append(f"<span>{safe_tok}</span>")

    html = " ".join(spans)
    if len(tokens) > max_tokens:
        html += (
            f"<br><br>"
            f"<small style='color:red;'>"
            f"Note: Tokens in <b>red</b> are beyond the model context window "
            f"({max_tokens} tokens) and will be truncated."
            f"</small>"
        )
    else:
        # &le; renders as "≤" in the HTML output.
        html += (
            f"<br><br>"
            f"<small>Token count: {len(tokens)} (&le; {max_tokens}, no truncation).</small>"
        )
    return html
# ----------------------------------------
# 2. Prediction
# ----------------------------------------
def predict_amharic_news(text):
    """
    Classify one piece of Amharic news text.

    Args:
        text: Raw input string from the Gradio textbox.

    Returns:
        A 3-tuple matching the three Gradio outputs:
        (prediction string,
         [(label, probability), ...] sorted by probability descending,
         HTML token visualization with truncated tokens in red).
    """
    if not text.strip():
        # Keep the output arity stable even for empty input.
        return "Please enter text.", None, "<i>No text provided.</i>"

    # For actual model inference: truncate to MAX_LENGTH tokens.
    # No padding is needed for a single sequence — the model only
    # sees real tokens, which is faster than padding to max_length
    # and (with the attention mask) gives the same result.
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)
    probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()[0]

    pred_id = int(np.argmax(probs))
    pred_label = id2label.get(pred_id, f"LABEL_{pred_id}")

    # Probability table, most likely class first.
    rows = sorted(
        ((id2label.get(i, f"LABEL_{i}"), float(probs[i])) for i in range(num_labels)),
        key=lambda row: row[1],
        reverse=True,
    )

    # Token-level view: tokens beyond MAX_LENGTH are shown in red.
    token_highlight_html = highlight_token_overflow(text, max_tokens=MAX_LENGTH)

    # Three outputs: prediction, probability table, token visualization.
    return f"Predicted Label: {pred_label}", rows, token_highlight_html
# ----------------------------------------
# 3. Gradio Interface
# ----------------------------------------
demo = gr.Interface(
    fn=predict_amharic_news,
    inputs=gr.Textbox(
        lines=5,
        label="Enter Amharic News Text",
        # Amharic for "Please enter Amharic news text..."
        # (previous text was mojibake from a mis-decoded UTF-8 string).
        placeholder="እባክዎ የአማርኛ ዜና ጽሑፍ ያስገቡ..."
    ),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Dataframe(
            headers=["Label", "Probability"],
            label="Class Probabilities"
        ),
        gr.HTML(label="Tokenizer view (tokens > 512 are red)")
    ],
    title="Amharic News Classifier",
    # NOTE(review): HUB_MODEL_ID points at a multilingual-BERT checkpoint,
    # so the earlier "XLM-RoBERTa" wording was inaccurate.
    description=(
        "Multilingual BERT model loaded directly from Hugging Face Hub (raw text input, no preprocessing). "
        "Below, tokenizer output shows which tokens are beyond the 512-token context window (in red)."
    )
)

demo.launch()