mahmoudmohammad commited on
Commit
204c4de
·
verified ·
1 Parent(s): cc2b683

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -104
app.py CHANGED
@@ -1,105 +1,121 @@
1
- import gradio as gr
2
- import torch
3
- import collections
4
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
-
6
- HF_USERNAME = "mahmoudmohammad"
7
- CONFIDENCE_THRESHOLD = 0.70
8
-
9
- print("Booting Global Taxonomy Engine...")
10
- # --- 1. Permanently Load L1 Model ---
11
- l1_repo = f"{HF_USERNAME}/SANAD-L1-Root-Classifier"
12
- l1_tokenizer = AutoTokenizer.from_pretrained(l1_repo)
13
- l1_model = AutoModelForSequenceClassification.from_pretrained(l1_repo)
14
- l1_model.eval()
15
-
16
- # Dynamically extract which L2 classes exist directly from the L1 id mappings
17
- # Matches format deployed to HF Hub
18
- available_branches = [label for label in l1_model.config.id2label.values()]
19
-
20
- # --- 2. Smart Memory Manager (LRU Cache) ---
21
- # Limits how many L2 models sit in RAM at once to avoid Out-Of-Memory errors
22
- class L2ModelCache:
23
- def __init__(self, max_models=3):
24
- self.max_models = max_models
25
- self.cache = collections.OrderedDict()
26
-
27
- def get_model(self, l1_label):
28
- if l1_label in self.cache:
29
- self.cache.move_to_end(l1_label)
30
- return self.cache[l1_label]
31
-
32
- print(f"Loading {l1_label} L2 model into RAM...")
33
- repo_id = f"{HF_USERNAME}/SANAD-L2-{l1_label}-Classifier"
34
-
35
- try:
36
- tok = AutoTokenizer.from_pretrained(repo_id)
37
- mod = AutoModelForSequenceClassification.from_pretrained(repo_id)
38
- mod.eval()
39
- self.cache[l1_label] = (tok, mod)
40
-
41
- if len(self.cache) > self.max_models:
42
- evicted = self.cache.popitem(last=False)
43
- print(f"Unloaded {evicted[0]} L2 model from RAM to free space.")
44
-
45
- return self.cache[l1_label]
46
- except Exception:
47
- return None, None # Branch model doesn't exist on hub (Flattened L1)
48
-
49
- l2_manager = L2ModelCache(max_models=3)
50
-
51
- # --- 3. The 2-Stage Routing Logic ---
52
- def classify_news(text):
53
- if not text.strip():
54
- return "Empty text", "N/A"
55
-
56
- # Stage 1: L1 Routing
57
- inputs = l1_tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
58
- with torch.no_grad():
59
- out1 = l1_model(**inputs)
60
-
61
- probs1 = torch.softmax(out1.logits, dim=-1).squeeze()
62
- conf1 = probs1.max().item()
63
- pred1 = l1_model.config.id2label[probs1.argmax().item()]
64
-
65
- # Cascade: If root is unsure, drop instantly
66
- if conf1 < CONFIDENCE_THRESHOLD:
67
- return "Uncertain", f"L1 Drop: {pred1} (Conf: {conf1:.2f})"
68
-
69
- # Attempt Stage 2 (Drilldown)
70
- l2_tok, l2_mod = l2_manager.get_model(pred1)
71
-
72
- # Branch doesn't exist? (Phase 1D Flattening executed correctly)
73
- if not l2_mod:
74
- return pred1, f"Status: L1 Flat Structure Approved (Conf: {conf1:.2f})"
75
-
76
- # Route through existing L2
77
- l2_in = l2_tok(text, return_tensors="pt", truncation=True, max_length=256)
78
- with torch.no_grad():
79
- out2 = l2_mod(**l2_in)
80
-
81
- probs2 = torch.softmax(out2.logits, dim=-1).squeeze()
82
- conf2 = probs2.max().item()
83
- pred2 = l2_mod.config.id2label[probs2.argmax().item()]
84
-
85
- # Confidence test Stage 2 - Drop safely to L1 generalization if fail
86
- if conf2 < CONFIDENCE_THRESHOLD:
87
- return pred1, f"Status: Sub-Tag Rejected. Dropped to Base Root (L2 Conf: {conf2:.2f})"
88
-
89
- # Pure hierarchy completion
90
- return f"{pred1} / {pred2}", f"Success: L1({conf1:.2f}) -> L2({conf2:.2f})"
91
-
92
- # --- 4. The Front-End UI ---
93
- iface = gr.Interface(
94
- fn=classify_news,
95
- inputs=gr.Textbox(lines=7, label="Arabic News Text", placeholder="Paste article here..."),
96
- outputs=[
97
- gr.Textbox(label="Final Category Assignment"),
98
- gr.Textbox(label="Confidence Diagnostics Routing Debugger")
99
- ],
100
- title="Arabic News Hierarchical Categorizer (L1 + L2 Pipeline)",
101
- description="This gateway automates intelligent semantic tracking against 8 Deep Learning architecture branches globally.",
102
- examples=["سجل فريق ريال مدريد فوزاً كاسحاً في دوري أبطال أوروبا"]
103
- )
104
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  iface.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ import collections
4
+ import re
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+
7
+ # Camel-Tools Preprocessing Libraries
8
+ from camel_tools.utils.normalize import normalize_alef_maksura_ar
9
+ from camel_tools.utils.normalize import normalize_alef_ar
10
+ from camel_tools.utils.normalize import normalize_teh_marbuta_ar
11
+ from camel_tools.utils.dediac import dediac_ar
12
+
13
+ HF_USERNAME = "mahmoudmohammad"
14
+ CONFIDENCE_THRESHOLD = 0.70
15
+
16
+ # --- 0. Exact Same Preprocessing used in Training Phase ---
17
+ def clean_arabic_news(text):
18
+ if not isinstance(text, str): return ""
19
+ # Strip garbage characters
20
+ text = re.sub(r'http\S+|www.\S+', '', text)
21
+ text = re.sub(r'<.*?>', '', text)
22
+ text = re.sub(r'@\w+', '', text)
23
+ text = re.sub(r'\s+', ' ', text).strip()
24
+
25
+ # NLP Morphology standardization
26
+ text = dediac_ar(text)
27
+ text = normalize_alef_ar(text)
28
+ text = normalize_alef_maksura_ar(text)
29
+ text = normalize_teh_marbuta_ar(text)
30
+ return text
31
+
32
+ print("Booting Global Taxonomy Engine...")
33
+
34
+ # --- 1. Permanently Load L1 Model ---
35
+ l1_repo = f"{HF_USERNAME}/SANAD-L1-Root-Classifier"
36
+ l1_tokenizer = AutoTokenizer.from_pretrained(l1_repo)
37
+ l1_model = AutoModelForSequenceClassification.from_pretrained(l1_repo)
38
+ l1_model.eval()
39
+
40
+ # --- 2. Smart Memory Manager (LRU Cache) ---
41
+ class L2ModelCache:
42
+ def __init__(self, max_models=3):
43
+ self.max_models = max_models
44
+ self.cache = collections.OrderedDict()
45
+
46
+ def get_model(self, l1_label):
47
+ if l1_label in self.cache:
48
+ self.cache.move_to_end(l1_label)
49
+ return self.cache[l1_label]
50
+
51
+ print(f"Loading {l1_label} L2 model into RAM...")
52
+ repo_id = f"{HF_USERNAME}/SANAD-L2-{l1_label}-Classifier"
53
+
54
+ try:
55
+ tok = AutoTokenizer.from_pretrained(repo_id)
56
+ mod = AutoModelForSequenceClassification.from_pretrained(repo_id)
57
+ mod.eval()
58
+ self.cache[l1_label] = (tok, mod)
59
+
60
+ if len(self.cache) > self.max_models:
61
+ evicted = self.cache.popitem(last=False)
62
+ print(f"Unloaded {evicted[0]} L2 model from RAM.")
63
+ return self.cache[l1_label]
64
+ except Exception:
65
+ return None, None
66
+
67
+ l2_manager = L2ModelCache(max_models=3)
68
+
69
+ # --- 3. The 2-Stage Routing Logic ---
70
+ def classify_news(text):
71
+ if not text.strip():
72
+ return "Empty text", "N/A"
73
+
74
+ # CRITICAL: Clean the incoming API request!
75
+ cleaned_text = clean_arabic_news(text)
76
+
77
+ # Stage 1: L1 Routing
78
+ inputs = l1_tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=256)
79
+ with torch.no_grad():
80
+ out1 = l1_model(**inputs)
81
+
82
+ probs1 = torch.softmax(out1.logits, dim=-1).squeeze()
83
+ conf1 = probs1.max().item()
84
+ pred1 = l1_model.config.id2label[probs1.argmax().item()]
85
+
86
+ if conf1 < CONFIDENCE_THRESHOLD:
87
+ return "Uncertain", f"L1 Drop: {pred1} (Conf: {conf1:.2f})"
88
+
89
+ l2_tok, l2_mod = l2_manager.get_model(pred1)
90
+
91
+ if not l2_mod:
92
+ return pred1, f"Status: L1 Flat Structure Approved (Conf: {conf1:.2f})"
93
+
94
+ # Stage 2: Ensure we feed the CLEAN text here as well
95
+ l2_in = l2_tok(cleaned_text, return_tensors="pt", truncation=True, max_length=256)
96
+ with torch.no_grad():
97
+ out2 = l2_mod(**l2_in)
98
+
99
+ probs2 = torch.softmax(out2.logits, dim=-1).squeeze()
100
+ conf2 = probs2.max().item()
101
+ pred2 = l2_mod.config.id2label[probs2.argmax().item()]
102
+
103
+ if conf2 < CONFIDENCE_THRESHOLD:
104
+ return pred1, f"Status: Sub-Tag Rejected. Dropped to Root (L2 Conf: {conf2:.2f})"
105
+
106
+ return f"{pred1} / {pred2}", f"Success: L1({conf1:.2f}) -> L2({conf2:.2f})"
107
+
108
+ # --- 4. The Front-End UI ---
109
+ iface = gr.Interface(
110
+ fn=classify_news,
111
+ inputs=gr.Textbox(lines=7, label="Arabic News Text", placeholder="Paste article here..."),
112
+ outputs=[
113
+ gr.Textbox(label="Final Category Assignment"),
114
+ gr.Textbox(label="Confidence Diagnostics")
115
+ ],
116
+ title="Arabic News Hierarchical Categorizer (L1 + L2 Pipeline)",
117
+ description="This gateway intelligently filters, normalizes, and classifies Arabic text dynamically.",
118
+ examples=["سجل فريق ريال مدريد فوزاً كاسحاً في دوري أبطال أوروبا"]
119
+ )
120
+
121
  iface.launch()