sentence / app.py
Abelex's picture
Update app.py
183f13d verified
# ===============================
# Final Gradio Demo (FIXED + ALIGNED)
# ===============================
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import os
import re
import json
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -------------------------------------------------
# MODEL CONFIG (MUST MATCH TRAINING)
# -------------------------------------------------
# Hub id used for both the tokenizer and the encoder inside HybridSentenceChuLo.
PRETRAINED = "Davlan/bert-base-multilingual-cased-finetuned-amharic"
# Hub repo from which config.json (label mappings) is downloaded.
# NOTE(review): the repo name suggests an XLM-R checkpoint while PRETRAINED is a
# BERT checkpoint — confirm that the two really belong together.
HF_MODEL_ID = "Abelex/afro-xlmr-large"
CHUNK_SIZE = 512  # tokens per chunk (tokenizer max_length / padding length)
MAX_CHUNKS = 8  # number of chunks per document after padding / truncation
CHUNK_DMODEL = 256  # width of the chunk vectors and of the chunk transformer
DROPOUT = 0.1  # dropout used by the chunk transformer layers
# -------------------------------------------------
# Load config from HF (labels, num_labels)
# -------------------------------------------------
# Resolve the label mappings used to size the classifier head and to turn
# predicted class ids into names. The Hub copy of config.json is preferred;
# a hard-coded mapping is used when it cannot be downloaded or parsed.
try:
    config_path = hf_hub_download(HF_MODEL_ID, "config.json")
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)
    # JSON object keys are always strings; class ids are needed as ints.
    id2label = {int(k): v for k, v in cfg["id2label"].items()}
    # "label2id" and "num_labels" are derived from id2label when the file
    # does not carry them, instead of discarding the whole config.
    label2id = cfg.get("label2id") or {v: k for k, v in id2label.items()}
    num_labels = cfg.get("num_labels", len(id2label))
    print("✓ Loaded label mappings from HF")
except Exception as e:
    # Best-effort by design: the demo must still start offline. The reason is
    # printed so that a silent fallback to the wrong labels can be diagnosed.
    print("⚠ Could not load config.json — using fallback:", repr(e))
    id2label = {
        0: "Politics",
        1: "Economy",
        2: "Sports",
        3: "Technology",
        4: "Health",
        5: "Agriculture",
        6: "accident",
        7: "education",
    }
    label2id = {v: k for k, v in id2label.items()}
    num_labels = len(id2label)
# -------------------------------------------------
# MODEL
# -------------------------------------------------
class HybridSentenceChuLo(nn.Module):
    """Two-level document classifier.

    Every chunk is encoded token-by-token with a pretrained encoder, pooled
    into one vector with a learned attention query, and the sequence of chunk
    vectors is then contextualised by a small transformer before
    classification.
    """

    def __init__(self, pretrained_name, num_labels):
        """Build the encoder, pooling query, chunk transformer and classifier.

        Args:
            pretrained_name: identifier passed to ``AutoModel.from_pretrained``.
            num_labels: size of the classifier output.
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_name)
        encoder_width = self.bert.config.hidden_size
        # A projection is only needed when the encoder width differs from the
        # chunk-level model width.
        if encoder_width != CHUNK_DMODEL:
            self.proj = nn.Linear(encoder_width, CHUNK_DMODEL)
        else:
            self.proj = nn.Identity()
        # Learned query vector used to score tokens for attention pooling.
        self.token_attn_vec = nn.Parameter(torch.randn(CHUNK_DMODEL))
        chunk_layer = nn.TransformerEncoderLayer(
            d_model=CHUNK_DMODEL,
            nhead=8,
            dim_feedforward=4 * CHUNK_DMODEL,
            dropout=DROPOUT,
            batch_first=True,
        )
        self.chunk_transformer = nn.TransformerEncoder(chunk_layer, num_layers=2)
        self.classifier = nn.Sequential(
            nn.LayerNorm(CHUNK_DMODEL),
            nn.Linear(CHUNK_DMODEL, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        """Classify a batch of chunked documents.

        Args:
            input_ids: 3-D tensor (batch, chunks, tokens) of token ids.
            attention_mask: tensor viewed with the same 3-D layout; a chunk
                whose mask sums to zero is treated as padding.

        Returns:
            The classifier output computed from one pooled vector per document.
        """
        batch, n_chunks, n_tokens = input_ids.shape

        # Fold chunks into the batch axis so the encoder sees 2-D inputs.
        token_ids = input_ids.view(batch * n_chunks, n_tokens)
        token_mask = attention_mask.view(batch * n_chunks, n_tokens)
        encoded = self.bert(input_ids=token_ids, attention_mask=token_mask)
        token_feats = self.proj(encoded.last_hidden_state)

        # Attention pooling over tokens; masked tokens get the lowest finite
        # score so they receive (near-)zero weight after the softmax.
        attn_logits = torch.matmul(token_feats, self.token_attn_vec)
        lowest = torch.finfo(attn_logits.dtype).min
        attn_logits = attn_logits.masked_fill(token_mask == 0, lowest)
        attn = torch.softmax(attn_logits, dim=1).unsqueeze(-1)
        pooled = (token_feats * attn).sum(dim=1)
        chunk_vecs = pooled.view(batch, n_chunks, CHUNK_DMODEL)

        # A chunk counts as real when at least one of its tokens is unmasked.
        has_tokens = attention_mask.sum(dim=2) > 0
        padding = ~has_tokens
        contextual = self.chunk_transformer(
            chunk_vecs, src_key_padding_mask=padding
        )

        # Mean over real chunks only; the clamp avoids division by zero.
        keep = has_tokens.unsqueeze(-1).float()
        total = (contextual * keep).sum(dim=1)
        count = keep.sum(dim=1).clamp(min=1e-6)
        return self.classifier(total / count)
# -------------------------------------------------
# Load tokenizer & model
# -------------------------------------------------
# The tokenizer must be the one that belongs to the encoder inside the model.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)

# The classifier is the custom hierarchical architecture defined in this file;
# chulo_predict applies softmax to what the model returns, so a generic
# AutoModel backbone cannot be used in its place.
model = HybridSentenceChuLo(PRETRAINED, num_labels).to(DEVICE)

# NOTE(review): assumes the fine-tuned weights are published in the Hub repo as
# a plain PyTorch state dict named "pytorch_model.bin" — confirm the filename.
weights_path = hf_hub_download(HF_MODEL_ID, "pytorch_model.bin")
# torch.load unpickles the file; weights_only=True limits what may be loaded,
# since the checkpoint comes from a remote repository.
state = torch.load(weights_path, map_location=DEVICE, weights_only=True)

# strict=False tolerates key differences, so report them instead of silently
# running with randomly initialised layers.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys:
    print("⚠ Missing keys in checkpoint:", load_result.missing_keys)
if load_result.unexpected_keys:
    print("⚠ Unexpected keys in checkpoint:", load_result.unexpected_keys)

model.eval()
print("✓ Model ready")
# -------------------------------------------------
# Sentence splitting
# -------------------------------------------------
def split_sentences(text):
    """Split *text* into sentences at whitespace that follows a terminator.

    Terminators are the Ethiopic full stop (U+1362), the Ethiopic semicolon
    (U+1364), ``!`` and ``?``. The terminator stays attached to its sentence.

    Args:
        text: raw document text; ``None`` or an empty string is accepted.

    Returns:
        A list of stripped, non-empty sentence strings (possibly empty).
    """
    if not text:
        return []
    # The Ethiopic punctuation is written as \u escapes (understood by ``re``)
    # so the pattern cannot be corrupted by a wrong file encoding.
    sents = re.split(r"(?<=[\u1362\u1364!?])\s+", text)
    return [s.strip() for s in sents if s.strip()]
# -------------------------------------------------
# EXACT Beginning–Middle–End selection
# -------------------------------------------------
def select_exact_bme(sentences):
    """Pick the first, middle and last sentence of a document.

    Args:
        sentences: sequence of sentences in document order.

    Returns:
        A list of ``(sentence, 1)`` pairs in document order, without
        duplicates when positions coincide; empty for empty input.
    """
    count = len(sentences)
    if count < 1:
        return []
    # A set collapses coinciding positions (e.g. one- or two-sentence input).
    positions = {0, count // 2, count - 1}
    picked = []
    for pos in sorted(positions):
        picked.append((sentences[pos], 1))
    return picked
# -------------------------------------------------
# Encode chunks
# -------------------------------------------------
def encode_sentence_chunks(sentences):
    """Tokenize each sentence into one fixed-length chunk and pad the chunk list.

    Args:
        sentences: iterable of sentence strings.

    Returns:
        A pair ``(input_ids, attention_mask)`` of stacked tensors holding
        exactly ``MAX_CHUNKS`` rows of ``CHUNK_SIZE`` entries each. Missing
        chunks are filled with the tokenizer's pad id and an all-zero mask;
        chunks beyond ``MAX_CHUNKS`` are dropped.
    """
    id_rows = []
    mask_rows = []
    for sentence in sentences:
        encoded = tokenizer(
            sentence,
            truncation=True,
            padding="max_length",
            max_length=CHUNK_SIZE,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a leading batch axis of size 1.
        id_rows.append(encoded["input_ids"][0])
        mask_rows.append(encoded["attention_mask"][0])

    # Fill up with fully padded chunks until MAX_CHUNKS rows exist.
    for _ in range(MAX_CHUNKS - len(id_rows)):
        id_rows.append(torch.full((CHUNK_SIZE,), tokenizer.pad_token_id))
        mask_rows.append(torch.zeros(CHUNK_SIZE, dtype=torch.long))

    input_ids = torch.stack(id_rows[:MAX_CHUNKS])
    attention_mask = torch.stack(mask_rows[:MAX_CHUNKS])
    return input_ids, attention_mask
# -------------------------------------------------
# HTML Highlighting
# -------------------------------------------------
def build_html(all_sents, selected):
    """Render the document as HTML with the selected sentences highlighted.

    Args:
        all_sents: all sentences of the document, in order.
        selected: iterable of ``(sentence, weight)`` pairs; every sentence
            whose text equals a selected sentence is highlighted.

    Returns:
        One HTML string: a ``<div>`` containing one ``<p>`` per sentence.
    """
    # Local import: the file's import block does not provide ``html``.
    from html import escape

    highlight = {s for s, _ in selected}
    parts = ["<div style='font-size:16px; line-height:1.6;'>"]
    for s in all_sents:
        # The text comes straight from the user; escape all HTML
        # metacharacters (including "&"), not only the angle brackets.
        safe = escape(s)
        if s in highlight:
            parts.append(
                f"<p style='background:#c7f7c7; padding:4px;'><b>{safe}</b></p>"
            )
        else:
            parts.append(f"<p>{safe}</p>")
    parts.append("</div>")
    return "".join(parts)
# -------------------------------------------------
# Prediction
# -------------------------------------------------
def chulo_predict(text):
    """Classify a news text and report what the model looked at.

    Args:
        text: raw document text from the UI; may be ``None`` or blank.

    Returns:
        A 3-tuple ``(message, top, html)``: a prediction message, up to five
        ``(label, probability)`` pairs sorted by descending probability, and
        an HTML rendering of the document with the selected sentences
        highlighted. For blank input a hint, an empty list and an empty
        string are returned without running the model.
    """
    # Blank input would produce a batch of padding chunks only, for which the
    # model cannot give a meaningful prediction — stop early instead.
    if text is None or not text.strip():
        return "Please enter some text to classify.", [], ""

    sents = split_sentences(text)
    chosen = select_exact_bme(sents)
    selected = [s for s, _ in chosen]
    chunks, masks = encode_sentence_chunks(selected)

    with torch.no_grad():
        # unsqueeze(0) adds the batch axis the model expects.
        logits = model(
            input_ids=chunks.unsqueeze(0).to(DEVICE),
            attention_mask=masks.unsqueeze(0).to(DEVICE),
        )

    probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
    pred_id = int(np.argmax(probs))
    pred_label = id2label[pred_id]

    ranked = sorted(
        ((id2label[i], float(p)) for i, p in enumerate(probs)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return f"Predicted Label: {pred_label}", ranked[:5], build_html(sents, chosen)
# -------------------------------------------------
# Gradio UI
# -------------------------------------------------
# Single-function Gradio interface: one text input, three outputs matching the
# 3-tuple returned by chulo_predict (message, top probabilities, HTML).
demo = gr.Interface(
    fn=chulo_predict,
    # NOTE(review): the label says "Afanoromo" while the title and the
    # tokenizer are Amharic — confirm which language the demo targets.
    inputs=gr.Textbox(lines=10, label="Enter Afanoromo News Text"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Dataframe(headers=["Label", "Probability"], label="Top Probabilities"),
        gr.HTML(label="Highlighted Document"),
    ],
    title="Sentence‑ChuLo — Amharic News Classifier",
    description="Exact Beginning–Middle–End sentence selection with hierarchical chunk attention.",
)

# Start the server only when executed as a script, so importing this module
# (e.g. for reuse of chulo_predict) does not launch the app.
if __name__ == "__main__":
    demo.launch()