|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import os |
|
|
import re |
|
|
import json |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint name used for both the tokenizer and the backbone encoder.
PRETRAINED = "Davlan/bert-base-multilingual-cased-finetuned-amharic"
# Hub repository the label configuration (config.json) is downloaded from.
HF_MODEL_ID = "Abelex/afro-xlmr-large"

# Token length every sentence chunk is padded/truncated to.
CHUNK_SIZE = 512
# Number of chunk slots fed to the model; unused slots are padding.
MAX_CHUNKS = 8
# Width of the chunk-level representation (transformer d_model).
CHUNK_DMODEL = 256
# Dropout rate of the chunk-level transformer layers.
DROPOUT = 0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Resolve the label vocabulary. The mapping published in the Hub repo's
# config.json is preferred; if it cannot be fetched or parsed, a built-in
# table is used instead so the app can still start.
try:
    config_path = hf_hub_download(HF_MODEL_ID, "config.json")
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    # JSON object keys are strings; class ids are used as ints downstream.
    id2label = {int(k): v for k, v in cfg["id2label"].items()}
    label2id = cfg["label2id"]
    num_labels = cfg["num_labels"]
except Exception as e:
    # Deliberately broad: any failure (network, missing file, missing key)
    # triggers the fallback. The cause is reported so it is not lost.
    print(f"β Could not load config.json β using fallback ({type(e).__name__}: {e})")
    id2label = {
        0: "Politics",
        1: "Economy",
        2: "Sports",
        3: "Technology",
        4: "Health",
        5: "Agriculture",
        6: "accident",
        7: "education",
    }
    label2id = {v: k for k, v in id2label.items()}
    num_labels = len(id2label)
else:
    print("β Loaded label mappings from HF")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HybridSentenceChuLo(nn.Module):
    """Two-level document classifier.

    Each chunk is encoded by a pretrained backbone and collapsed to a single
    vector with learned attention pooling over its tokens. The chunk vectors
    are then contextualised by a small transformer encoder, averaged over the
    non-padding chunks, and passed to a linear classification head.
    """

    def __init__(self, pretrained_name, num_labels):
        """Build the backbone, pooling parameters, chunk encoder and head.

        Args:
            pretrained_name: checkpoint name handed to ``AutoModel.from_pretrained``.
            num_labels: size of the classifier output.
        """
        super().__init__()

        # Backbone that yields per-token hidden states.
        self.bert = AutoModel.from_pretrained(pretrained_name)
        backbone_width = self.bert.config.hidden_size

        # Bring backbone states to CHUNK_DMODEL unless the widths already agree.
        if backbone_width == CHUNK_DMODEL:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(backbone_width, CHUNK_DMODEL)

        # Learned vector that scores every token for attention pooling.
        self.token_attn_vec = nn.Parameter(torch.randn(CHUNK_DMODEL))

        # Transformer that mixes information across chunk vectors.
        chunk_layer = nn.TransformerEncoderLayer(
            d_model=CHUNK_DMODEL,
            nhead=8,
            dim_feedforward=4 * CHUNK_DMODEL,
            dropout=DROPOUT,
            batch_first=True,
        )
        self.chunk_transformer = nn.TransformerEncoder(chunk_layer, num_layers=2)

        self.classifier = nn.Sequential(
            nn.LayerNorm(CHUNK_DMODEL),
            nn.Linear(CHUNK_DMODEL, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of chunked documents.

        Args:
            input_ids: 3-D tensor indexed as (batch, chunk, token).
            attention_mask: tensor viewed with the same layout; zero marks
                padding tokens, and a chunk whose mask sums to zero is treated
                as a padding chunk.

        Returns:
            The classifier output, one row of logits per document.
        """
        batch, n_chunks, n_tokens = input_ids.size()

        # Fold the chunk axis into the batch axis so the backbone sees
        # ordinary 2-D inputs.
        ids_2d = input_ids.view(batch * n_chunks, n_tokens)
        mask_2d = attention_mask.view(batch * n_chunks, n_tokens)

        backbone_out = self.bert(input_ids=ids_2d, attention_mask=mask_2d)
        token_states = self.proj(backbone_out.last_hidden_state)

        # Attention pooling: score tokens, push padding to the dtype minimum
        # so it receives negligible weight, then normalise along the tokens.
        token_scores = torch.matmul(token_states, self.token_attn_vec)
        lowest = torch.finfo(token_scores.dtype).min
        token_scores = token_scores.masked_fill(mask_2d == 0, lowest)
        token_weights = torch.softmax(token_scores, dim=1).unsqueeze(-1)

        pooled = (token_states * token_weights).sum(dim=1)
        chunk_vecs = pooled.view(batch, n_chunks, CHUNK_DMODEL)

        # A chunk with no attended token is padding at the chunk level.
        has_tokens = attention_mask.sum(dim=2) > 0
        padding_chunks = ~has_tokens

        encoded = self.chunk_transformer(
            chunk_vecs, src_key_padding_mask=padding_chunks
        )

        # Mean over real chunks only; the clamp avoids division by zero.
        keep = (~padding_chunks).unsqueeze(-1).float()
        summed = (encoded * keep).sum(dim=1)
        n_real = keep.sum(dim=1).clamp(min=1e-6)

        return self.classifier(summed / n_real)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tokenizer comes from the same checkpoint as the backbone encoder.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)

# Build the classifier. It must stay a HybridSentenceChuLo instance:
# chulo_predict applies softmax to its output, which requires class logits.
model = HybridSentenceChuLo(PRETRAINED, num_labels).to(DEVICE)

# Fetch the fine-tuned weights from the Hub repo.
# NOTE(review): the checkpoint filename is an assumption - confirm the actual
# file name published in the HF_MODEL_ID repository.
WEIGHTS_FILENAME = "pytorch_model.bin"
weights_path = hf_hub_download(HF_MODEL_ID, WEIGHTS_FILENAME)

# NOTE: torch.load unpickles the file; only point this at a trusted repo.
state = torch.load(weights_path, map_location=DEVICE)

# strict=False tolerates key mismatches, so report them instead of silently
# running with partially initialised weights.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"State dict mismatch: missing={list(load_result.missing_keys)} "
        f"unexpected={list(load_result.unexpected_keys)}"
    )

model.eval()

print("β Model ready")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sentence boundary: whitespace that follows an Ethiopic full stop (U+1362),
# an Ethiopic semicolon (U+1364), '!' or '?'. The Ethiopic marks are written
# as escapes so the pattern survives a re-encoding of this file.
_SENTENCE_BOUNDARY = re.compile("(?<=[\u1362\u1364!?])\\s+")


def split_sentences(text):
    """Split *text* into sentences at terminal punctuation.

    The terminator stays attached to the sentence it ends. Surrounding
    whitespace is stripped and empty pieces are dropped.

    Args:
        text: raw document text; ``None`` or an empty string yields ``[]``.

    Returns:
        A list of non-empty sentence strings in document order.
    """
    if not text:
        return []
    pieces = (piece.strip() for piece in _SENTENCE_BOUNDARY.split(text))
    return [piece for piece in pieces if piece]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def select_exact_bme(sentences):
    """Pick the beginning, middle and end sentences of a document.

    The positions are ``0``, ``len // 2`` and ``len - 1``; coinciding
    positions (documents of one or two sentences) are taken once.

    Args:
        sentences: sequence of sentence strings.

    Returns:
        A list of ``(sentence, 1)`` tuples in document order, or ``[]`` when
        there are no sentences. The second element is always the constant 1.
    """
    total = len(sentences)
    if total == 0:
        return []

    # A set collapses positions that coincide for short documents.
    positions = {0, total // 2, total - 1}

    picked = []
    for pos in sorted(positions):
        picked.append((sentences[pos], 1))
    return picked
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def encode_sentence_chunks(sentences):
    """Tokenize sentences into a fixed-size stack of chunks.

    Every sentence becomes one chunk, padded or truncated to CHUNK_SIZE
    tokens. The stack is filled up to MAX_CHUNKS with all-padding chunks
    whose attention mask is zero, and cut to MAX_CHUNKS if longer.

    Args:
        sentences: iterable of sentence strings.

    Returns:
        A pair ``(input_ids, attention_mask)``, each a stack of MAX_CHUNKS
        rows of length CHUNK_SIZE.
    """
    encoded = [
        tokenizer(
            sentence,
            max_length=CHUNK_SIZE,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        for sentence in sentences
    ]
    # The tokenizer returns a batch of one; keep only its single row.
    id_rows = [item["input_ids"][0] for item in encoded]
    mask_rows = [item["attention_mask"][0] for item in encoded]

    # Fill the remaining slots with padding-only chunks.
    for _ in range(MAX_CHUNKS - len(id_rows)):
        id_rows.append(torch.full((CHUNK_SIZE,), tokenizer.pad_token_id))
        mask_rows.append(torch.zeros(CHUNK_SIZE, dtype=torch.long))

    stacked_ids = torch.stack(id_rows[:MAX_CHUNKS])
    stacked_masks = torch.stack(mask_rows[:MAX_CHUNKS])
    return stacked_ids, stacked_masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_html(all_sents, selected):
    """Render the document as HTML with the selected sentences highlighted.

    Sentence text comes from the user, so it is HTML-escaped before being
    placed in the markup.

    Args:
        all_sents: every sentence of the document, in order.
        selected: iterable of ``(sentence, _)`` pairs; a sentence whose text
            equals one of these is highlighted.

    Returns:
        An HTML string: a ``<div>`` with one ``<p>`` per sentence.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    highlight = {s for s, _ in selected}

    parts = ["<div style='font-size:16px; line-height:1.6;'>"]
    for s in all_sents:
        safe = escape(s)
        if s in highlight:
            parts.append(
                f"<p style='background:#c7f7c7; padding:4px;'><b>{safe}</b></p>"
            )
        else:
            parts.append(f"<p>{safe}</p>")
    parts.append("</div>")

    return "".join(parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def chulo_predict(text):
    """Classify a news text and return the outputs shown in the UI.

    Args:
        text: raw document text from the input box.

    Returns:
        A 3-tuple ``(prediction, top_probabilities, html)``:
        the predicted-label message, up to five ``(label, probability)``
        pairs sorted by descending probability, and the document rendered
        as HTML with the selected sentences highlighted. For empty or
        whitespace-only input no model call is made and a placeholder
        message with an empty table and empty HTML is returned.
    """
    # With no sentences every chunk would be padding and masked out in the
    # model, so there is nothing meaningful to classify.
    if not text or not text.strip():
        return "No text provided.", [], ""

    sents = split_sentences(text)
    chosen = select_exact_bme(sents)
    chunk_ids, chunk_masks = encode_sentence_chunks([s for s, _ in chosen])

    # Add a batch axis of size one and run inference without gradients.
    with torch.no_grad():
        logits = model(
            input_ids=chunk_ids.unsqueeze(0).to(DEVICE),
            attention_mask=chunk_masks.unsqueeze(0).to(DEVICE),
        )
    probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    pred_label = id2label[int(np.argmax(probs))]

    ranked = sorted(
        ((id2label[i], float(p)) for i, p in enumerate(probs)),
        key=lambda pair: pair[1],
        reverse=True,
    )

    return f"Predicted Label: {pred_label}", ranked[:5], build_html(sents, chosen)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: one text box in; predicted label, probability table and the
# highlighted document out (same order as chulo_predict's return tuple).
demo = gr.Interface(
    fn=chulo_predict,
    title="SentenceβChuLo β Amharic News Classifier",
    description="Exact BeginningβMiddleβEnd sentence selection with hierarchical chunk attention.",
    inputs=gr.Textbox(lines=10, label="Enter Afanoromo News Text"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Dataframe(
            headers=["Label", "Probability"],
            label="Top Probabilities",
        ),
        gr.HTML(label="Highlighted Document"),
    ],
)

# Start the web server for the interface.
demo.launch()
|
|
|