import gradio as gr
import torch
import collections
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# CAMeL Tools Preprocessing Utilities
from camel_tools.utils.normalize import normalize_alef_maksura_ar
from camel_tools.utils.normalize import normalize_alef_ar
from camel_tools.utils.normalize import normalize_teh_marbuta_ar
from camel_tools.utils.dediac import dediac_ar

HF_USERNAME = "mahmoudmohammad"
CONFIDENCE_THRESHOLD = 0.70

# --- 0. The Exact Preprocessing Used in the Training Phase ---
def clean_arabic_news(text):
    if not isinstance(text, str): return ""
    # Strip URLs, HTML tags, @mentions, and redundant whitespace
    text = re.sub(r'http\S+|www\.\S+', '', text)
    text = re.sub(r'<.*?>', '', text)
    text = re.sub(r'@\w+', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    
    # NLP Morphology standardization
    text = dediac_ar(text)
    text = normalize_alef_ar(text)
    text = normalize_alef_maksura_ar(text)
    text = normalize_teh_marbuta_ar(text)
    return text
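
# Illustrative behavior of the cleaning step, assuming standard CAMeL Tools
# normalization (diacritics stripped, alef variants -> ا, final ى -> ي, ة -> ه):
#   clean_arabic_news("خَبَرٌ عاجِلٌ مِنْ أَبوظَبي")  ->  "خبر عاجل من ابوظبي"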

print("Booting Global Taxonomy Engine...")

# --- 1. Load the L1 Root Model (Kept in Memory for the App's Lifetime) ---
l1_repo = f"{HF_USERNAME}/SANAD-L1-Root-Classifier"
l1_tokenizer = AutoTokenizer.from_pretrained(l1_repo)
l1_model = AutoModelForSequenceClassification.from_pretrained(l1_repo)
l1_model.eval()

# --- 2. Smart Memory Manager (LRU Cache) ---
class L2ModelCache:
    def __init__(self, max_models=3):
        self.max_models = max_models
        self.cache = collections.OrderedDict()

    def get_model(self, l1_label):
        if l1_label in self.cache:
            self.cache.move_to_end(l1_label)
            return self.cache[l1_label]
            
        print(f"Loading {l1_label} L2 model into RAM...")
        repo_id = f"{HF_USERNAME}/SANAD-L2-{l1_label}-Classifier"
        
        try:
            tok = AutoTokenizer.from_pretrained(repo_id)
            mod = AutoModelForSequenceClassification.from_pretrained(repo_id)
            mod.eval()
            self.cache[l1_label] = (tok, mod)
            
            if len(self.cache) > self.max_models:
                evicted = self.cache.popitem(last=False)
                print(f"Unloaded {evicted[0]} L2 model from RAM.")
            return self.cache[l1_label]
        except Exception as e:
            # No L2 model published for this root (flat category) or load failed
            print(f"No L2 model available for {l1_label}: {e}")
            return None, None

l2_manager = L2ModelCache(max_models=3)
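
# Illustrative cache behavior (labels here are hypothetical; the real L1 labels
# come from the trained model's id2label config):
#   l2_manager.get_model("Sports")    # miss -> load from the Hub
#   l2_manager.get_model("Politics")  # miss -> load
#   l2_manager.get_model("Economy")   # miss -> load (cache now full)
#   l2_manager.get_model("Tech")      # miss -> load, evicts "Sports" (LRU)
#   l2_manager.get_model("Politics")  # hit  -> marked most recently used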

# --- 3. The 2-Stage Routing Logic ---
def classify_news(text):
    if not text.strip():
        return "Empty text", "N/A"

    # CRITICAL: apply the exact same cleaning used during training
    cleaned_text = clean_arabic_news(text)

    # Stage 1: L1 Routing
    inputs = l1_tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=256)
    with torch.no_grad():
        out1 = l1_model(**inputs)
        
    probs1 = torch.softmax(out1.logits, dim=-1).squeeze()
    conf1 = probs1.max().item()
    pred1 = l1_model.config.id2label[probs1.argmax().item()]
    
    if conf1 < CONFIDENCE_THRESHOLD:
        return "Uncertain", f"L1 Drop: {pred1} (Conf: {conf1:.2f})"
        
    l2_tok, l2_mod = l2_manager.get_model(pred1)
    
    if l2_mod is None:
        return pred1, f"Status: Flat category, no L2 refinement (Conf: {conf1:.2f})"
        
    # Stage 2: Ensure we feed the CLEAN text here as well
    l2_in = l2_tok(cleaned_text, return_tensors="pt", truncation=True, max_length=256)
    with torch.no_grad():
        out2 = l2_mod(**l2_in)
        
    probs2 = torch.softmax(out2.logits, dim=-1).squeeze()
    conf2 = probs2.max().item()
    pred2 = l2_mod.config.id2label[probs2.argmax().item()]
    
    if conf2 < CONFIDENCE_THRESHOLD:
        return pred1, f"Status: Sub-tag rejected, kept root label (L2 Conf: {conf2:.2f})"
         
    return f"{pred1} / {pred2}", f"Success: L1({conf1:.2f}) -> L2({conf2:.2f})"
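
# Possible return shapes (category labels are hypothetical; they depend on the
# trained models' id2label maps):
#   ("Uncertain", "L1 Drop: ...")            - L1 confidence below threshold
#   ("Sports", "Status: Flat category ...")  - no L2 model for this root
#   ("Sports / Football", "Success: ...")    - both stages passed the threshold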

# --- 4. The Front-End UI ---
iface = gr.Interface(
    fn=classify_news,
    inputs=gr.Textbox(lines=7, label="Arabic News Text", placeholder="Paste article here..."),
    outputs=[
        gr.Textbox(label="Final Category Assignment"), 
        gr.Textbox(label="Confidence Diagnostics")
    ],
    title="Arabic News Hierarchical Categorizer (L1 + L2 Pipeline)",
    description="Normalizes Arabic news text, routes it through an L1 root classifier, and refines the result with an on-demand L2 sub-classifier.",
    examples=["سجل فريق ريال مدريد فوزاً كاسحاً في دوري أبطال أوروبا"]
)

iface.launch()