| import gradio as gr |
| import torch |
| import collections |
| import re |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
| |
| from camel_tools.utils.normalize import normalize_alef_maksura_ar |
| from camel_tools.utils.normalize import normalize_alef_ar |
| from camel_tools.utils.normalize import normalize_teh_marbuta_ar |
| from camel_tools.utils.dediac import dediac_ar |
|
|
# Hugging Face account hosting the SANAD classifier checkpoints.
HF_USERNAME = "mahmoudmohammad"
# Minimum softmax probability required to accept a prediction at either level.
CONFIDENCE_THRESHOLD = 0.70
|
|
| |
def clean_arabic_news(text):
    """Normalize raw Arabic news text for classification.

    Strips URLs, HTML-like tags, @mentions, collapses whitespace, then
    applies CAMeL Tools orthographic normalization (diacritics removed;
    alef, alef maksura, and teh marbuta variants unified).

    Args:
        text: Raw article text. Non-string input yields "".

    Returns:
        The cleaned, normalized string.
    """
    if not isinstance(text, str):
        return ""

    # Strip noise. NOTE: the dot in "www\." is escaped — the original
    # pattern `www.` matched "www" followed by ANY character.
    text = re.sub(r'http\S+|www\.\S+', '', text)
    text = re.sub(r'<.*?>', '', text)
    text = re.sub(r'@\w+', '', text)
    text = re.sub(r'\s+', ' ', text).strip()

    # Orthographic normalization so inference text matches training vocab.
    text = dediac_ar(text)
    text = normalize_alef_ar(text)
    text = normalize_alef_maksura_ar(text)
    text = normalize_teh_marbuta_ar(text)
    return text
|
|
print("Booting Global Taxonomy Engine...")


# Load the always-resident L1 (root-category) classifier at import time;
# L2 sub-classifiers are loaded lazily via L2ModelCache below.
l1_repo = f"{HF_USERNAME}/SANAD-L1-Root-Classifier"
l1_tokenizer = AutoTokenizer.from_pretrained(l1_repo)
l1_model = AutoModelForSequenceClassification.from_pretrained(l1_repo)
l1_model.eval()  # inference mode: disables dropout / training-only behavior
|
|
| |
class L2ModelCache:
    """LRU cache of per-category L2 (sub-category) classifier models.

    Keeps at most ``max_models`` (tokenizer, model) pairs resident in RAM,
    evicting the least-recently-used pair once the limit is exceeded.
    """

    def __init__(self, max_models=3):
        self.max_models = max_models
        # L1-label -> (tokenizer, model); insertion order tracks recency.
        self.cache = collections.OrderedDict()

    def get_model(self, l1_label):
        """Return (tokenizer, model) for ``l1_label``, loading on demand.

        Returns (None, None) when no L2 model can be loaded for this
        label, signalling the caller to fall back to the L1-only result.
        """
        if l1_label in self.cache:
            self.cache.move_to_end(l1_label)  # mark as most recently used
            return self.cache[l1_label]

        print(f"Loading {l1_label} L2 model into RAM...")
        repo_id = f"{HF_USERNAME}/SANAD-L2-{l1_label}-Classifier"

        try:
            tok = AutoTokenizer.from_pretrained(repo_id)
            mod = AutoModelForSequenceClassification.from_pretrained(repo_id)
            mod.eval()
            self.cache[l1_label] = (tok, mod)

            # Evict the least-recently-used entry once over capacity.
            if len(self.cache) > self.max_models:
                evicted = self.cache.popitem(last=False)
                print(f"Unloaded {evicted[0]} L2 model from RAM.")
            return self.cache[l1_label]
        except Exception as exc:
            # Best-effort: a missing/broken repo must not crash the request,
            # but log the cause instead of swallowing it silently.
            print(f"Failed to load L2 model '{repo_id}': {exc}")
            return None, None
|
|
| l2_manager = L2ModelCache(max_models=3) |
|
|
| |
def _predict(tokenizer, model, text):
    """Run one classifier over ``text``; return (label, confidence)."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    with torch.no_grad():
        out = model(**inputs)
    probs = torch.softmax(out.logits, dim=-1).squeeze()
    conf = probs.max().item()
    label = model.config.id2label[probs.argmax().item()]
    return label, conf


def classify_news(text):
    """Hierarchically classify Arabic news text.

    Runs the resident L1 root classifier, then (if confident) the matching
    L2 sub-classifier fetched from the LRU cache. A prediction below
    CONFIDENCE_THRESHOLD is rejected at each level.

    Args:
        text: Raw article text from the UI (may be None or empty).

    Returns:
        (category, diagnostics) tuple of strings for the Gradio outputs.
    """
    # Gradio can deliver None when the textbox is cleared; guard before
    # calling .strip() (the original raised AttributeError on None).
    if not text or not text.strip():
        return "Empty text", "N/A"

    cleaned_text = clean_arabic_news(text)

    # --- Level 1: root category ---
    pred1, conf1 = _predict(l1_tokenizer, l1_model, cleaned_text)
    if conf1 < CONFIDENCE_THRESHOLD:
        return "Uncertain", f"L1 Drop: {pred1} (Conf: {conf1:.2f})"

    l2_tok, l2_mod = l2_manager.get_model(pred1)
    if l2_mod is None:
        # No L2 model for this root category (flat branch or load failure).
        return pred1, f"Status: L1 Flat Structure Approved (Conf: {conf1:.2f})"

    # --- Level 2: sub-category ---
    pred2, conf2 = _predict(l2_tok, l2_mod, cleaned_text)
    if conf2 < CONFIDENCE_THRESHOLD:
        return pred1, f"Status: Sub-Tag Rejected. Dropped to Root (L2 Conf: {conf2:.2f})"

    return f"{pred1} / {pred2}", f"Success: L1({conf1:.2f}) -> L2({conf2:.2f})"
|
|
| |
# --- Gradio UI wiring: one text input, two text outputs (category + diagnostics) ---
iface = gr.Interface(
    fn=classify_news,
    inputs=gr.Textbox(lines=7, label="Arabic News Text", placeholder="Paste article here..."),
    outputs=[
        gr.Textbox(label="Final Category Assignment"),
        gr.Textbox(label="Confidence Diagnostics")
    ],
    title="Arabic News Hierarchical Categorizer (L1 + L2 Pipeline)",
    description="This gateway intelligently filters, normalizes, and classifies Arabic text dynamically.",
    # Example: "Real Madrid recorded a sweeping victory in the Champions League"
    examples=["سجل فريق ريال مدريد فوزاً كاسحاً في دوري أبطال أوروبا"]
)


iface.launch()