mahmoudmohammad committed
Commit 83d21fb · verified · 1 parent: 47b6d46

Upload 2 files

Files changed (2):
  1. app.py +105 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,105 @@
+ import gradio as gr
+ import torch
+ import collections
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+ HF_USERNAME = "mahmoudmohammad"
+ CONFIDENCE_THRESHOLD = 0.70
+
+ print("Booting Global Taxonomy Engine...")
+ # --- 1. Permanently Load L1 Model ---
+ l1_repo = f"{HF_USERNAME}/SANAD-L1-Root-Classifier"
+ l1_tokenizer = AutoTokenizer.from_pretrained(l1_repo)
+ l1_model = AutoModelForSequenceClassification.from_pretrained(l1_repo)
+ l1_model.eval()
+
+ # Dynamically extract the available L2 branches from the L1 id2label mapping;
+ # the names match the repo format deployed to the HF Hub.
+ available_branches = list(l1_model.config.id2label.values())
+
+ # --- 2. Smart Memory Manager (LRU Cache) ---
+ # Limits how many L2 models sit in RAM at once to avoid out-of-memory errors.
+ class L2ModelCache:
+     def __init__(self, max_models=3):
+         self.max_models = max_models
+         self.cache = collections.OrderedDict()
+
+     def get_model(self, l1_label):
+         if l1_label in self.cache:
+             self.cache.move_to_end(l1_label)  # a hit refreshes recency
+             return self.cache[l1_label]
+
+         print(f"Loading {l1_label} L2 model into RAM...")
+         repo_id = f"{HF_USERNAME}/SANAD-L2-{l1_label}-Classifier"
+
+         try:
+             tok = AutoTokenizer.from_pretrained(repo_id)
+             mod = AutoModelForSequenceClassification.from_pretrained(repo_id)
+             mod.eval()
+             self.cache[l1_label] = (tok, mod)
+
+             if len(self.cache) > self.max_models:
+                 evicted = self.cache.popitem(last=False)  # least recently used
+                 print(f"Unloaded {evicted[0]} L2 model from RAM to free space.")
+
+             return self.cache[l1_label]
+         except Exception:
+             return None, None  # branch model doesn't exist on the Hub (flattened L1)
+
+ l2_manager = L2ModelCache(max_models=3)
+
+ # --- 3. The 2-Stage Routing Logic ---
+ def classify_news(text):
+     if not text.strip():
+         return "Empty text", "N/A"
+
+     # Stage 1: L1 routing
+     inputs = l1_tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
+     with torch.no_grad():
+         out1 = l1_model(**inputs)
+
+     probs1 = torch.softmax(out1.logits, dim=-1).squeeze()
+     conf1 = probs1.max().item()
+     pred1 = l1_model.config.id2label[probs1.argmax().item()]
+
+     # Cascade: if the root classifier is unsure, stop immediately
+     if conf1 < CONFIDENCE_THRESHOLD:
+         return "Uncertain", f"L1 Drop: {pred1} (Conf: {conf1:.2f})"
+
+     # Attempt Stage 2 (drill-down)
+     l2_tok, l2_mod = l2_manager.get_model(pred1)
+
+     # Branch doesn't exist? (Phase 1D flattening executed correctly)
+     if not l2_mod:
+         return pred1, f"Status: L1 Flat Structure Approved (Conf: {conf1:.2f})"
+
+     # Route through the existing L2 model
+     l2_in = l2_tok(text, return_tensors="pt", truncation=True, max_length=256)
+     with torch.no_grad():
+         out2 = l2_mod(**l2_in)
+
+     probs2 = torch.softmax(out2.logits, dim=-1).squeeze()
+     conf2 = probs2.max().item()
+     pred2 = l2_mod.config.id2label[probs2.argmax().item()]
+
+     # Stage 2 confidence test: fall back safely to the L1 label on failure
+     if conf2 < CONFIDENCE_THRESHOLD:
+         return pred1, f"Status: Sub-Tag Rejected. Dropped to Base Root (L2 Conf: {conf2:.2f})"
+
+     # Full hierarchy completed
+     return f"{pred1} / {pred2}", f"Success: L1({conf1:.2f}) -> L2({conf2:.2f})"
+
+ # --- 4. The Front-End UI ---
+ iface = gr.Interface(
+     fn=classify_news,
+     inputs=gr.Textbox(lines=7, label="Arabic News Text", placeholder="Paste article here..."),
+     outputs=[
+         gr.Textbox(label="Final Category Assignment"),
+         gr.Textbox(label="Confidence / Routing Diagnostics")
+     ],
+     title="Arabic News Hierarchical Categorizer (L1 + L2 Pipeline)",
+     description="Two-stage semantic routing for Arabic news: an L1 root classifier over 8 branches, with branch-specific L2 drill-down where confidence allows.",
+     examples=["سجل فريق ريال مدريد فوزاً كاسحاً في دوري أبطال أوروبا"]
+ )
+
+ iface.launch()
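
A note on the cache in section 2: L2ModelCache gets its LRU behavior entirely from collections.OrderedDict. Below is a minimal standalone sketch of that eviction pattern, using hypothetical branch names and a placeholder payload in place of real (tokenizer, model) pairs:

import collections

cache = collections.OrderedDict()
MAX_ITEMS = 2  # stands in for max_models

for branch in ["Sports", "Politics", "Sports", "Finance"]:  # hypothetical L1 labels
    if branch in cache:
        cache.move_to_end(branch)                   # a hit refreshes recency
    else:
        cache[branch] = f"model-for-{branch}"       # placeholder payload
        if len(cache) > MAX_ITEMS:
            evicted, _ = cache.popitem(last=False)  # least recently used goes first
            print(f"Evicted: {evicted}")            # prints: Evicted: Politics

print(list(cache))                                  # ['Sports', 'Finance']

One consequence of the try/except in get_model is that flattened branches (no L2 repo on the Hub) are never cached, so every request for them triggers a fresh Hub lookup; caching the (None, None) result as well would be a cheap way to avoid that.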
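The routing policy inside classify_news likewise reduces to a small pure decision rule over (label, confidence) pairs. A self-contained sketch of just that rule follows; the function name and the sample labels and confidences are illustrative, not part of the app:

def route(pred1, conf1, pred2=None, conf2=None, threshold=0.70):
    # Stage 1 gate: an unsure root classifier stops the cascade
    if conf1 < threshold:
        return "Uncertain"
    # Flattened branch: no L2 model exists, so the L1 label is final
    if pred2 is None:
        return pred1
    # Stage 2 gate: a weak sub-tag falls back to the root label
    if conf2 < threshold:
        return pred1
    return f"{pred1} / {pred2}"  # full two-level path

assert route("Sports", 0.55) == "Uncertain"
assert route("Sports", 0.91) == "Sports"                    # flattened branch
assert route("Sports", 0.91, "Football", 0.62) == "Sports"  # L2 rejected
assert route("Sports", 0.91, "Football", 0.88) == "Sports / Football"

The bundled Arabic example (roughly: "Real Madrid recorded a sweeping victory in the UEFA Champions League") is presumably meant to exercise the happy path in the last case: a confident sports root followed by a confident sub-tag.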
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers
+ torch
+ scikit-learn
+ pandas
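
One gap worth flagging: app.py imports gradio, which this file does not list. On a Hugging Face Space built with the Gradio SDK the runtime provides it, but a local install from this file alone would not be enough to run the app, so adding a gradio line would close that gap. Conversely, scikit-learn and pandas are not imported by app.py and are presumably used by other scripts in the repo.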