Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -126,59 +126,11 @@ _tokenizer = None
|
|
| 126 |
_model = None
|
| 127 |
_device = "cpu"
|
| 128 |
_preprocessor = None
|
| 129 |
-
_LABEL_MAPPING = None
|
| 130 |
|
| 131 |
|
| 132 |
# ============================================================================
|
| 133 |
# HELPER FUNCTIONS
|
| 134 |
# ============================================================================
|
| 135 |
-
def _get_label_mapping():
|
| 136 |
-
"""
|
| 137 |
-
Get complete label mapping.
|
| 138 |
-
If model config is incomplete, use fallback mapping.
|
| 139 |
-
"""
|
| 140 |
-
global _model, _LABEL_MAPPING
|
| 141 |
-
|
| 142 |
-
if _model is None:
|
| 143 |
-
return None
|
| 144 |
-
|
| 145 |
-
id2label = getattr(_model.config, "id2label", {}) or {}
|
| 146 |
-
|
| 147 |
-
# Check if mapping is incomplete (missing label 0)
|
| 148 |
-
num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
|
| 149 |
-
|
| 150 |
-
print(f"DEBUG: num_labels = {num_labels}")
|
| 151 |
-
print(f"DEBUG: id2label from config = {id2label}")
|
| 152 |
-
|
| 153 |
-
# If incomplete, use fallback
|
| 154 |
-
if len(id2label) < num_labels:
|
| 155 |
-
print(f"WARNING: Incomplete label mapping detected!")
|
| 156 |
-
print(f"Expected {num_labels} labels, got {len(id2label)}")
|
| 157 |
-
|
| 158 |
-
# Try to load from labels.json if available
|
| 159 |
-
try:
|
| 160 |
-
import pkg_resources
|
| 161 |
-
model_path = pkg_resources.resource_filename(__name__, 'models')
|
| 162 |
-
labels_path = os.path.join(model_path, 'labels.json')
|
| 163 |
-
if os.path.exists(labels_path):
|
| 164 |
-
with open(labels_path, 'r') as f:
|
| 165 |
-
labels_data = json.load(f)
|
| 166 |
-
id2label = labels_data.get("id2label", {})
|
| 167 |
-
print(f"Loaded labels from labels.json: {id2label}")
|
| 168 |
-
except:
|
| 169 |
-
pass
|
| 170 |
-
|
| 171 |
-
# Final fallback mapping
|
| 172 |
-
if len(id2label) < 2:
|
| 173 |
-
print("Using fallback label mapping: 0=LEGIT, 1=PHISH")
|
| 174 |
-
id2label = {
|
| 175 |
-
"0": "LEGIT",
|
| 176 |
-
"1": "PHISH"
|
| 177 |
-
}
|
| 178 |
-
|
| 179 |
-
return id2label
|
| 180 |
-
|
| 181 |
-
|
| 182 |
def _normalize_label(txt: str) -> str:
|
| 183 |
"""Normalize label text"""
|
| 184 |
t = (str(txt) if txt is not None else "").strip().upper()
|
|
@@ -191,7 +143,7 @@ def _normalize_label(txt: str) -> str:
|
|
| 191 |
|
| 192 |
def _load_model():
|
| 193 |
"""Load model, tokenizer, and preprocessor"""
|
| 194 |
-
global _tokenizer, _model, _device, _preprocessor
|
| 195 |
|
| 196 |
if _tokenizer is None or _model is None:
|
| 197 |
_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -206,9 +158,6 @@ def _load_model():
|
|
| 206 |
_model.eval()
|
| 207 |
_preprocessor = TextPreprocessor()
|
| 208 |
|
| 209 |
-
# Get label mapping
|
| 210 |
-
_LABEL_MAPPING = _get_label_mapping()
|
| 211 |
-
|
| 212 |
# Warm-up
|
| 213 |
with torch.no_grad():
|
| 214 |
_ = _model(
|
|
@@ -217,14 +166,17 @@ def _load_model():
|
|
| 217 |
).logits
|
| 218 |
|
| 219 |
num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
|
|
|
|
|
|
|
| 220 |
print(f"Number of labels: {num_labels}")
|
| 221 |
-
print(f"Label mapping: {
|
| 222 |
print(f"{'='*60}\n")
|
| 223 |
|
| 224 |
|
| 225 |
def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
|
| 226 |
"""
|
| 227 |
-
Predict with
|
|
|
|
| 228 |
"""
|
| 229 |
_load_model()
|
| 230 |
if not texts:
|
|
@@ -250,39 +202,43 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
|
|
| 250 |
logits = _model(**enc).logits
|
| 251 |
probs = torch.softmax(logits, dim=-1)
|
| 252 |
|
| 253 |
-
#
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
label = _LABEL_MAPPING.get(str(i), f"LABEL_{i}")
|
| 258 |
-
labels_by_idx.append(label)
|
| 259 |
-
|
| 260 |
-
print(f"DEBUG: Using labels: {labels_by_idx}")
|
| 261 |
|
| 262 |
outputs: List[Dict] = []
|
| 263 |
for i in range(probs.shape[0]):
|
| 264 |
p = probs[i]
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
#
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
output = {
|
| 277 |
"text": texts[i][:100] + "..." if len(texts[i]) > 100 else texts[i],
|
| 278 |
-
"label":
|
| 279 |
-
"
|
| 280 |
-
"
|
| 281 |
-
"
|
| 282 |
-
"
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
}
|
| 287 |
|
| 288 |
if include_preprocessing and preprocessing_info:
|
|
@@ -305,13 +261,17 @@ def root():
|
|
| 305 |
"status": "ok",
|
| 306 |
"model": MODEL_ID,
|
| 307 |
"device": _device,
|
| 308 |
-
"label_mapping":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
}
|
| 310 |
|
| 311 |
|
| 312 |
@app.get("/debug/labels")
|
| 313 |
def debug_labels():
|
| 314 |
-
"""View
|
| 315 |
_load_model()
|
| 316 |
|
| 317 |
id2label_raw = getattr(_model.config, "id2label", {}) or {}
|
|
@@ -323,9 +283,11 @@ def debug_labels():
|
|
| 323 |
"config_id2label": id2label_raw,
|
| 324 |
"config_label2id": label2id_raw,
|
| 325 |
"config_num_labels": num_labels,
|
| 326 |
-
"
|
| 327 |
-
|
| 328 |
-
|
|
|
|
|
|
|
| 329 |
}
|
| 330 |
|
| 331 |
|
|
|
|
| 126 |
_model = None
|
| 127 |
_device = "cpu"
|
| 128 |
_preprocessor = None
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
# ============================================================================
|
| 132 |
# HELPER FUNCTIONS
|
| 133 |
# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
def _normalize_label(txt: str) -> str:
|
| 135 |
"""Normalize label text"""
|
| 136 |
t = (str(txt) if txt is not None else "").strip().upper()
|
|
|
|
| 143 |
|
| 144 |
def _load_model():
|
| 145 |
"""Load model, tokenizer, and preprocessor"""
|
| 146 |
+
global _tokenizer, _model, _device, _preprocessor
|
| 147 |
|
| 148 |
if _tokenizer is None or _model is None:
|
| 149 |
_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 158 |
_model.eval()
|
| 159 |
_preprocessor = TextPreprocessor()
|
| 160 |
|
|
|
|
|
|
|
|
|
|
| 161 |
# Warm-up
|
| 162 |
with torch.no_grad():
|
| 163 |
_ = _model(
|
|
|
|
| 166 |
).logits
|
| 167 |
|
| 168 |
num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
|
| 169 |
+
id2label = getattr(_model.config, "id2label", {}) or {}
|
| 170 |
+
|
| 171 |
print(f"Number of labels: {num_labels}")
|
| 172 |
+
print(f"Label mapping: {id2label}")
|
| 173 |
print(f"{'='*60}\n")
|
| 174 |
|
| 175 |
|
| 176 |
def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
|
| 177 |
"""
|
| 178 |
+
Predict with CORRECT label indexing.
|
| 179 |
+
Index 0 = LEGIT, Index 1 = PHISH
|
| 180 |
"""
|
| 181 |
_load_model()
|
| 182 |
if not texts:
|
|
|
|
| 202 |
logits = _model(**enc).logits
|
| 203 |
probs = torch.softmax(logits, dim=-1)
|
| 204 |
|
| 205 |
+
# CORRECT LABEL MAPPING
|
| 206 |
+
# Index 0 = LEGIT (probs[i][0])
|
| 207 |
+
# Index 1 = PHISH (probs[i][1])
|
| 208 |
+
labels_by_idx = ["LEGIT", "PHISH"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
outputs: List[Dict] = []
|
| 211 |
for i in range(probs.shape[0]):
|
| 212 |
p = probs[i]
|
| 213 |
+
|
| 214 |
+
# Get probabilities for each class
|
| 215 |
+
prob_legit = float(p[0].item())
|
| 216 |
+
prob_phish = float(p[1].item())
|
| 217 |
+
|
| 218 |
+
# Determine prediction based on which is higher
|
| 219 |
+
if prob_phish > prob_legit:
|
| 220 |
+
predicted_label = "PHISH"
|
| 221 |
+
predicted_idx = 1
|
| 222 |
+
confidence = prob_phish
|
| 223 |
+
else:
|
| 224 |
+
predicted_label = "LEGIT"
|
| 225 |
+
predicted_idx = 0
|
| 226 |
+
confidence = prob_legit
|
| 227 |
|
| 228 |
output = {
|
| 229 |
"text": texts[i][:100] + "..." if len(texts[i]) > 100 else texts[i],
|
| 230 |
+
"label": predicted_label,
|
| 231 |
+
"is_phish": predicted_label == "PHISH",
|
| 232 |
+
"confidence": round(confidence * 100, 2), # Convert to percentage
|
| 233 |
+
"predicted_index": predicted_idx,
|
| 234 |
+
"probs": {
|
| 235 |
+
"LEGIT": round(prob_legit * 100, 2),
|
| 236 |
+
"PHISH": round(prob_phish * 100, 2),
|
| 237 |
+
},
|
| 238 |
+
"raw_probs": {
|
| 239 |
+
"LEGIT (index 0)": round(prob_legit, 4),
|
| 240 |
+
"PHISH (index 1)": round(prob_phish, 4),
|
| 241 |
+
}
|
| 242 |
}
|
| 243 |
|
| 244 |
if include_preprocessing and preprocessing_info:
|
|
|
|
| 261 |
"status": "ok",
|
| 262 |
"model": MODEL_ID,
|
| 263 |
"device": _device,
|
| 264 |
+
"label_mapping": {
|
| 265 |
+
"0": "LEGIT",
|
| 266 |
+
"1": "PHISH"
|
| 267 |
+
},
|
| 268 |
+
"note": "Index 0 = LEGIT (probability%), Index 1 = PHISH (probability%)"
|
| 269 |
}
|
| 270 |
|
| 271 |
|
| 272 |
@app.get("/debug/labels")
|
| 273 |
def debug_labels():
|
| 274 |
+
"""View model configuration"""
|
| 275 |
_load_model()
|
| 276 |
|
| 277 |
id2label_raw = getattr(_model.config, "id2label", {}) or {}
|
|
|
|
| 283 |
"config_id2label": id2label_raw,
|
| 284 |
"config_label2id": label2id_raw,
|
| 285 |
"config_num_labels": num_labels,
|
| 286 |
+
"applied_mapping": {
|
| 287 |
+
"0": "LEGIT",
|
| 288 |
+
"1": "PHISH"
|
| 289 |
+
},
|
| 290 |
+
"device": _device
|
| 291 |
}
|
| 292 |
|
| 293 |
|