import gradio as gr
import torch
import collections
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# CAMeL Tools preprocessing utilities
from camel_tools.utils.normalize import normalize_alef_maksura_ar
from camel_tools.utils.normalize import normalize_alef_ar
from camel_tools.utils.normalize import normalize_teh_marbuta_ar
from camel_tools.utils.dediac import dediac_ar

HF_USERNAME = "mahmoudmohammad"
CONFIDENCE_THRESHOLD = 0.70


# --- 0. Exact same preprocessing used in the training phase ---
def clean_arabic_news(text):
    if not isinstance(text, str):
        return ""
    # Strip garbage: URLs, HTML tags, @mentions, redundant whitespace
    text = re.sub(r'http\S+|www\.\S+', '', text)
    text = re.sub(r'<.*?>', '', text)
    text = re.sub(r'@\w+', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    # Morphological standardization: strip diacritics, then normalize
    # alef, alef maksura, and teh marbuta variants
    text = dediac_ar(text)
    text = normalize_alef_ar(text)
    text = normalize_alef_maksura_ar(text)
    text = normalize_teh_marbuta_ar(text)
    return text


print("Booting Global Taxonomy Engine...")

# --- 1. Permanently load the L1 (root) model ---
l1_repo = f"{HF_USERNAME}/SANAD-L1-Root-Classifier"
l1_tokenizer = AutoTokenizer.from_pretrained(l1_repo)
l1_model = AutoModelForSequenceClassification.from_pretrained(l1_repo)
l1_model.eval()


# --- 2. Smart memory manager (LRU cache) ---
class L2ModelCache:
    def __init__(self, max_models=3):
        self.max_models = max_models
        self.cache = collections.OrderedDict()

    def get_model(self, l1_label):
        if l1_label in self.cache:
            # Cache hit: mark this model as most recently used
            self.cache.move_to_end(l1_label)
            return self.cache[l1_label]

        print(f"Loading {l1_label} L2 model into RAM...")
        repo_id = f"{HF_USERNAME}/SANAD-L2-{l1_label}-Classifier"
        try:
            tok = AutoTokenizer.from_pretrained(repo_id)
            mod = AutoModelForSequenceClassification.from_pretrained(repo_id)
            mod.eval()
            self.cache[l1_label] = (tok, mod)
            # Evict the least recently used model once the cache is full
            if len(self.cache) > self.max_models:
                evicted = self.cache.popitem(last=False)
                print(f"Unloaded {evicted[0]} L2 model from RAM.")
            return self.cache[l1_label]
        except Exception:
            # No L2 model exists for this category (flat branch of the taxonomy)
            return None, None


l2_manager = L2ModelCache(max_models=3)


# --- 3. The two-stage routing logic ---
def classify_news(text):
    if not text.strip():
        return "Empty text", "N/A"

    # CRITICAL: clean the incoming request exactly as in training
    cleaned_text = clean_arabic_news(text)

    # Stage 1: L1 routing
    inputs = l1_tokenizer(cleaned_text, return_tensors="pt",
                          truncation=True, max_length=256)
    with torch.no_grad():
        out1 = l1_model(**inputs)
    probs1 = torch.softmax(out1.logits, dim=-1).squeeze()
    conf1 = probs1.max().item()
    pred1 = l1_model.config.id2label[probs1.argmax().item()]

    if conf1 < CONFIDENCE_THRESHOLD:
        return "Uncertain", f"L1 Drop: {pred1} (Conf: {conf1:.2f})"

    l2_tok, l2_mod = l2_manager.get_model(pred1)
    if not l2_mod:
        return pred1, f"Status: L1 Flat Structure Approved (Conf: {conf1:.2f})"

    # Stage 2: feed the CLEAN text here as well
    l2_in = l2_tok(cleaned_text, return_tensors="pt",
                   truncation=True, max_length=256)
    with torch.no_grad():
        out2 = l2_mod(**l2_in)
    probs2 = torch.softmax(out2.logits, dim=-1).squeeze()
    conf2 = probs2.max().item()
    pred2 = l2_mod.config.id2label[probs2.argmax().item()]

    if conf2 < CONFIDENCE_THRESHOLD:
        return pred1, f"Status: Sub-Tag Rejected. Dropped to Root (L2 Conf: {conf2:.2f})"

    return f"{pred1} / {pred2}", f"Success: L1({conf1:.2f}) -> L2({conf2:.2f})"


# --- 4. The front-end UI ---
iface = gr.Interface(
    fn=classify_news,
    inputs=gr.Textbox(lines=7, label="Arabic News Text",
                      placeholder="Paste article here..."),
    outputs=[
        gr.Textbox(label="Final Category Assignment"),
        gr.Textbox(label="Confidence Diagnostics"),
    ],
    title="Arabic News Hierarchical Categorizer (L1 + L2 Pipeline)",
    description="This gateway filters, normalizes, and classifies Arabic text through a two-stage model pipeline.",
    examples=["سجل فريق ريال مدريد فوزاً كاسحاً في دوري أبطال أوروبا"],
)

iface.launch()
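
# --- 5. (Optional) Programmatic access ---
# A minimal sketch of how a client could query this app over Gradio's HTTP
# API, assuming it is running locally on the default port (7860) and exposes
# the default "/predict" endpoint; adjust the URL for a deployed Space.
# Requires: pip install gradio_client
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   category, diagnostics = client.predict(
#       "سجل فريق ريال مدريد فوزاً كاسحاً في دوري أبطال أوروبا",
#       api_name="/predict",
#   )
#   print(category, diagnostics)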