# Hugging Face Space by mahmoudmohammad - commit 204c4de ("Update app.py", verified).
import collections
import re
import time

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Camel-Tools Preprocessing Libraries
from camel_tools.utils.dediac import dediac_ar
from camel_tools.utils.normalize import normalize_alef_ar
from camel_tools.utils.normalize import normalize_alef_maksura_ar
from camel_tools.utils.normalize import normalize_teh_marbuta_ar
# Hugging Face Hub account that owns the SANAD-L1 / SANAD-L2 classifier repos.
HF_USERNAME: str = "mahmoudmohammad"
# Minimum top-class softmax probability required to accept an L1 or L2
# prediction in classify_news; anything strictly below it is rejected.
CONFIDENCE_THRESHOLD: float = 0.7
# --- 0. Exact Same Preprocessing used in Training Phase ---
def clean_arabic_news(text):
    """Clean and normalize raw Arabic news text before it reaches a classifier.

    The steps below are intended to mirror the training-time preprocessing, so
    their order and regexes must not drift:

    1. strip URLs, HTML tags and @mentions;
    2. collapse every whitespace run to one space and trim the ends;
    3. apply CAMeL Tools dediacritization, then alef, alef-maksura and
       teh-marbuta normalization.

    Args:
        text: Raw input. Any value that is not a ``str`` yields ``""``.

    Returns:
        The cleaned, normalized string.
    """
    if not isinstance(text, str):
        return ""

    # Noise removal, in the original order: URLs, HTML tags, @mentions.
    # NOTE: the "." in "www." is unescaped and matches any character; it is kept
    # verbatim so serving-time cleaning stays identical to training-time cleaning.
    for noise_pattern in (r'http\S+|www.\S+', r'<.*?>', r'@\w+'):
        text = re.sub(noise_pattern, '', text)
    text = re.sub(r'\s+', ' ', text).strip()

    # Orthographic normalization, applied in the original order.
    for normalize_step in (
        dediac_ar,
        normalize_alef_ar,
        normalize_alef_maksura_ar,
        normalize_teh_marbuta_ar,
    ):
        text = normalize_step(text)
    return text
print("Booting Global Taxonomy Engine...")
# --- 1. L1 root classifier: loaded once at import time and reused by every request ---
l1_repo = f"{HF_USERNAME}/SANAD-L1-Root-Classifier"
l1_tokenizer = AutoTokenizer.from_pretrained(l1_repo)
# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode can be chained into a single assignment.
l1_model = AutoModelForSequenceClassification.from_pretrained(l1_repo).eval()
# --- 2. Smart Memory Manager (LRU Cache) ---
class L2ModelCache:
    """LRU cache of L2 (tokenizer, model) pairs keyed by the predicted L1 label.

    The repo for a label is ``{HF_USERNAME}/SANAD-L2-{label}-Classifier``. At most
    ``max_models`` pairs stay loaded; the least recently used pair is dropped when
    that limit is exceeded.

    When a label's repo cannot be loaded, ``get_model`` returns ``(None, None)``.
    The failure is then remembered for ``failure_ttl`` seconds, so labels without
    an L2 model do not trigger a fresh Hub download attempt on every request.
    Use ``failure_ttl=0`` to retry on every call.
    """

    def __init__(self, max_models=3, failure_ttl=300.0):
        self.max_models = max_models
        self.failure_ttl = failure_ttl
        # label -> (tokenizer, model), ordered least -> most recently used.
        self.cache = collections.OrderedDict()
        # label -> time.monotonic() timestamp of the most recent failed load.
        self._failed_at = {}

    def _load(self, l1_label):
        """Load the L2 tokenizer and model for ``l1_label`` and put the model in eval mode."""
        repo_id = f"{HF_USERNAME}/SANAD-L2-{l1_label}-Classifier"
        tok = AutoTokenizer.from_pretrained(repo_id)
        mod = AutoModelForSequenceClassification.from_pretrained(repo_id)
        mod.eval()
        return tok, mod

    def get_model(self, l1_label):
        """Return ``(tokenizer, model)`` for ``l1_label``, or ``(None, None)`` if unavailable.

        A cache hit marks the label as most recently used. A miss loads the pair,
        possibly evicting the least recently used one. Load errors are logged and
        reported as ``(None, None)`` rather than raised; the caller treats that as
        "no L2 model for this label".
        """
        if l1_label in self.cache:
            self.cache.move_to_end(l1_label)
            return self.cache[l1_label]

        failed_at = self._failed_at.get(l1_label)
        if failed_at is not None and time.monotonic() - failed_at < self.failure_ttl:
            # Recently failed: skip another (slow, likely futile) Hub round-trip.
            return None, None

        print(f"Loading {l1_label} L2 model into RAM...")
        try:
            entry = self._load(l1_label)
        except Exception as err:
            # Deliberate best effort: a missing L2 model must not break the request.
            print(f"No L2 model available for {l1_label}: {err!r}")
            self._failed_at[l1_label] = time.monotonic()
            return None, None

        self._failed_at.pop(l1_label, None)
        self.cache[l1_label] = entry
        if len(self.cache) > self.max_models:
            evicted_label, _ = self.cache.popitem(last=False)
            print(f"Unloaded {evicted_label} L2 model from RAM.")
        # Return the local pair: with max_models <= 0 it may already have been evicted.
        return entry


l2_manager = L2ModelCache(max_models=3)
# --- 3. The 2-Stage Routing Logic ---
def _predict(tokenizer, model, text, max_length=256):
    """Run one sequence classifier on ``text`` and return ``(top_label, top_probability)``."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Single-example batch: take row 0 of the (1, num_labels) probability matrix.
    probs = torch.softmax(logits, dim=-1)[0]
    top_idx = int(probs.argmax())
    return model.config.id2label[top_idx], probs[top_idx].item()


def classify_news(text):
    """Classify Arabic news text with the two-stage L1 (root) -> L2 (sub-category) pipeline.

    Args:
        text: Raw article text from the UI/API. ``None``, non-strings and blank
            strings are treated as empty input.

    Returns:
        A ``(category, diagnostics)`` pair of strings:
        - ``("Empty text", "N/A")`` when there is nothing to classify, either
          before or after cleaning;
        - ``("Uncertain", ...)`` when the L1 confidence is below CONFIDENCE_THRESHOLD;
        - ``(l1_label, ...)`` when no L2 model is available for the L1 label, or
          when the L2 confidence is below the threshold;
        - ``("l1_label / l2_label", ...)`` when both stages are confident.
    """
    if not isinstance(text, str) or not text.strip():
        return "Empty text", "N/A"

    # Both models expect text cleaned exactly as during training.
    cleaned_text = clean_arabic_news(text)
    if not cleaned_text:
        # The input was only URLs / HTML / @mentions; classifying "" would give
        # an arbitrary label.
        return "Empty text", "N/A"

    # Stage 1: root category.
    pred1, conf1 = _predict(l1_tokenizer, l1_model, cleaned_text)
    if conf1 < CONFIDENCE_THRESHOLD:
        return "Uncertain", f"L1 Drop: {pred1} (Conf: {conf1:.2f})"

    l2_tok, l2_mod = l2_manager.get_model(pred1)
    if l2_mod is None:
        return pred1, f"Status: L1 Flat Structure Approved (Conf: {conf1:.2f})"

    # Stage 2: sub-category, on the same cleaned text.
    pred2, conf2 = _predict(l2_tok, l2_mod, cleaned_text)
    if conf2 < CONFIDENCE_THRESHOLD:
        return pred1, f"Status: Sub-Tag Rejected. Dropped to Root (L2 Conf: {conf2:.2f})"
    return f"{pred1} / {pred2}", f"Success: L1({conf1:.2f}) -> L2({conf2:.2f})"
# --- 4. The Front-End UI ---
# Gradio components, created in the same order as before: one multi-line input,
# then two text outputs that receive classify_news's (category, diagnostics) pair.
_news_input = gr.Textbox(lines=7, label="Arabic News Text", placeholder="Paste article here...")
_category_output = gr.Textbox(label="Final Category Assignment")
_diagnostics_output = gr.Textbox(label="Confidence Diagnostics")

iface = gr.Interface(
    fn=classify_news,
    inputs=_news_input,
    outputs=[_category_output, _diagnostics_output],
    title="Arabic News Hierarchical Categorizer (L1 + L2 Pipeline)",
    description="This gateway intelligently filters, normalizes, and classifies Arabic text dynamically.",
    examples=["سجل فريق ريال مدريد فوزاً كاسحاً في دوري أبطال أوروبا"],
)
iface.launch()