# ===============================
# Final Gradio Demo (FIXED + ALIGNED)
# ===============================
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import os
import re
import json

from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------------------------
# MODEL CONFIG (MUST MATCH TRAINING)
# -------------------------------------------------
PRETRAINED = "Davlan/bert-base-multilingual-cased-finetuned-amharic"
HF_MODEL_ID = "Abelex/afro-xlmr-large"
CHUNK_SIZE = 512       # max tokens per sentence chunk
MAX_CHUNKS = 8         # max chunks per document
CHUNK_DMODEL = 256     # width of the chunk-level transformer
DROPOUT = 0.1

# -------------------------------------------------
# Load config from HF (labels, num_labels)
# -------------------------------------------------
try:
    config_path = hf_hub_download(HF_MODEL_ID, "config.json")
    with open(config_path) as f:
        cfg = json.load(f)
    id2label = {int(k): v for k, v in cfg["id2label"].items()}
    label2id = cfg["label2id"]
    num_labels = cfg["num_labels"]
    print("✓ Loaded label mappings from HF")
except Exception:
    # Best-effort fallback when the hub is unreachable or config.json is missing.
    print("⚠ Could not load config.json — using fallback")
    id2label = {
        0: "Politics",
        1: "Economy",
        2: "Sports",
        3: "Technology",
        4: "Health",
        5: "Agriculture",
        6: "accident",
        7: "education",
    }
    label2id = {v: k for k, v in id2label.items()}
    num_labels = len(id2label)


# -------------------------------------------------
# MODEL
# -------------------------------------------------
class HybridSentenceChuLo(nn.Module):
    """Hierarchical document classifier.

    A pretrained BERT encodes each sentence chunk; attention-pooling collapses
    tokens into one vector per chunk; a small transformer mixes chunk vectors;
    mean-pooling over valid chunks yields the document vector that is classified.
    """

    def __init__(self, pretrained_name, num_labels):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_name)
        hidden = self.bert.config.hidden_size
        # Project BERT's hidden size down to the chunk transformer width.
        self.proj = (
            nn.Linear(hidden, CHUNK_DMODEL) if hidden != CHUNK_DMODEL else nn.Identity()
        )
        # Learned query vector for attention-pooling tokens within a chunk.
        self.token_attn_vec = nn.Parameter(torch.randn(CHUNK_DMODEL))
        enc_layer = nn.TransformerEncoderLayer(
            d_model=CHUNK_DMODEL,
            nhead=8,
            dim_feedforward=4 * CHUNK_DMODEL,
            batch_first=True,
            dropout=DROPOUT,
        )
        self.chunk_transformer = nn.TransformerEncoder(enc_layer, num_layers=2)
        self.classifier = nn.Sequential(
            nn.LayerNorm(CHUNK_DMODEL),
            nn.Linear(CHUNK_DMODEL, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        # input_ids / attention_mask: (batch, chunks, tokens)
        B, C, T = input_ids.size()
        flat_ids = input_ids.view(B * C, T)
        flat_mask = attention_mask.view(B * C, T)
        out = self.bert(input_ids=flat_ids, attention_mask=flat_mask)
        h = self.proj(out.last_hidden_state)  # (B*C, T, CHUNK_DMODEL)
        # Attention-pool tokens within each chunk; padding positions are
        # masked to dtype-min so softmax assigns them ~zero weight.
        scores = torch.matmul(h, self.token_attn_vec)
        scores = scores.masked_fill(flat_mask == 0, torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)
        chunk_vecs = (h * weights).sum(dim=1).view(B, C, CHUNK_DMODEL)
        # A chunk is real iff it contains at least one non-padding token.
        chunk_mask = (attention_mask.sum(dim=2) > 0)
        key_padding_mask = ~chunk_mask
        enc = self.chunk_transformer(chunk_vecs, src_key_padding_mask=key_padding_mask)
        # Mean-pool over valid chunks only (clamp avoids division by zero).
        valid = (~key_padding_mask).unsqueeze(-1).float()
        doc_vec = (enc * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1e-6)
        return self.classifier(doc_vec)


# -------------------------------------------------
# Load tokenizer & model
# -------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
model = HybridSentenceChuLo(PRETRAINED, num_labels).to(DEVICE)

# FIX: the original overwrote `model` with a plain AutoModel (discarding the
# hybrid architecture above) and then called load_state_dict with an undefined
# name `state` (NameError). Instead, download the fine-tuned weights from the
# hub and load them into the hybrid model; best-effort, like the config load.
try:
    weights_path = hf_hub_download(HF_MODEL_ID, "pytorch_model.bin")
    state = torch.load(weights_path, map_location=DEVICE)
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        print(
            f"⚠ Partially loaded weights "
            f"(missing={len(missing)}, unexpected={len(unexpected)})"
        )
except Exception as e:
    print(f"⚠ Could not load fine-tuned weights: {e}")
model.eval()
print("✓ Model ready")


# -------------------------------------------------
# Sentence splitting
# -------------------------------------------------
def split_sentences(text):
    """Split on Ethiopic (። ፤) and Latin (! ?) sentence terminators; drop empties."""
    sents = re.split(r"(?<=[።፤!?])\s+", text)
    return [s.strip() for s in sents if s.strip()]


# -------------------------------------------------
# EXACT Beginning–Middle–End selection
# -------------------------------------------------
def select_exact_bme(sentences):
    """Return the first, middle, and last sentences as (sentence, 1) pairs.

    Indices are deduplicated (short inputs) and kept in document order.
    """
    n = len(sentences)
    if n == 0:
        return []
    idxs = sorted(set([0, n // 2, n - 1]))
    return [(sentences[i], 1) for i in idxs]
# -------------------------------------------------
# Encode chunks
# -------------------------------------------------
def encode_sentence_chunks(sentences):
    """Tokenize each sentence into one fixed-size chunk and pad to MAX_CHUNKS.

    Returns (input_ids, attention_mask), each of shape (MAX_CHUNKS, CHUNK_SIZE).
    Uses the module-level `tokenizer`, CHUNK_SIZE, and MAX_CHUNKS.
    """
    chunks, masks = [], []
    for s in sentences:
        enc = tokenizer(
            s,
            max_length=CHUNK_SIZE,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        chunks.append(enc["input_ids"][0])
        masks.append(enc["attention_mask"][0])
    # Pad the chunk list with all-PAD chunks. dtype=torch.long is explicit so
    # the padding tensors always stack cleanly with the tokenizer's int64 output.
    while len(chunks) < MAX_CHUNKS:
        chunks.append(torch.full((CHUNK_SIZE,), tokenizer.pad_token_id, dtype=torch.long))
        masks.append(torch.zeros(CHUNK_SIZE, dtype=torch.long))
    return torch.stack(chunks[:MAX_CHUNKS]), torch.stack(masks[:MAX_CHUNKS])


# -------------------------------------------------
# HTML Highlighting
# -------------------------------------------------
def build_html(all_sents, selected):
    """Render all sentences as HTML, wrapping the selected ones in <mark>.

    NOTE(review): the original markup literals were garbled in this file
    (unterminated strings). This reconstruction preserves the visible
    structure — escape each sentence, test membership in the selected set,
    append with a <br> separator, and close a wrapping <div> on return —
    but the exact original styling could not be recovered; confirm against
    the deployed demo if pixel-exact output matters.
    """
    import html as _html  # aliased: `html` is the accumulator variable below

    highlight = {s for s, _ in selected}
    html = "<div style='line-height:1.8'>"
    for s in all_sents:
        safe = _html.escape(s)
        if s in highlight:
            html += f"<mark>{safe}</mark><br>"
        else:
            html += f"{safe}<br>"
    return html + "</div>"