Perth0603 committed on
Commit
711ac8e
·
verified ·
1 Parent(s): 9dfcd1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -19
app.py CHANGED
@@ -1,13 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def _normalize_label_name(name: str) -> str:
2
  if not isinstance(name, str):
3
  return ""
4
- return name.strip().lower()
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def _resolve_indices_from_config():
7
  # Returns (phish_idx, legit_idx) using model-config names and sensible fallbacks
8
- cfg = getattr(_model, "config", None)
9
- id2label = getattr(cfg, "id2label", {}) if cfg else {}
10
- # Normalize keys to int
 
 
11
  norm = {}
12
  for k, v in id2label.items():
13
  try:
@@ -16,28 +75,103 @@ def _resolve_indices_from_config():
16
  continue
17
  norm[ik] = _normalize_label_name(v)
18
 
19
- # Try to detect via keywords
20
- phish_keywords = {"phish", "phishing", "spam", "scam", "malicious"}
21
- legit_keywords = {"legit", "ham", "safe", "benign", "not phish", "non-phish"}
22
 
23
  phish_idx = None
24
  legit_idx = None
25
  for i, name in norm.items():
26
- if any(kw in name for kw in phish_keywords):
27
- phish_idx = i if phish_idx is None else phish_idx
28
- if any(kw in name for kw in legit_keywords):
29
- legit_idx = i if legit_idx is None else legit_idx
30
-
31
- # Fallback conventions for binary heads
32
- if phish_idx is None or legit_idx is None:
33
- if len(norm) == 2:
34
- # Common convention: 0 = negative(legit), 1 = positive(phish)
 
 
 
 
 
 
 
35
  phish_idx = 1 if phish_idx is None else phish_idx
36
  legit_idx = 0 if legit_idx is None else legit_idx
37
 
38
  return phish_idx, legit_idx
39
 
40
- def _label_for_index(idx: int) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  cfg = getattr(_model, "config", None)
42
- id2label = getattr(cfg, "id2label", {}) if cfg else {}
43
- return id2label.get(idx, id2label.get(str(idx), str(idx)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Point every cache location (HOME, HF, torch) at /data, the only writable
# volume inside a Hugging Face Spaces Docker container. These must be set
# before the transformers/torch imports below read the variables.
import os
os.environ.setdefault("HOME", "/data")
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache")
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache")
os.environ.setdefault("TORCH_HOME", "/data/.cache")

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch


# Hub repository to serve; overridable via env without a code change.
MODEL_ID = os.environ.get("MODEL_ID", "Perth0603/phishing-email-mobilebert")

# Ensure writable cache directory for HF/torch inside Spaces Docker
CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# Decision threshold for PHISH probability (predict() compares phish_score
# against this; 0.5 by default).
PHISH_THRESHOLD = float(os.environ.get("PHISH_THRESHOLD", "0.5"))

app = FastAPI(title="Phishing Text Classifier", version="1.0.0")
25
+
26
+
27
class PredictPayload(BaseModel):
    # Raw text to classify; field named "inputs" to mirror the HF
    # Inference-API request shape.
    inputs: str


# Lazy singletons for model/tokenizer, populated on first request by
# _load_model() so container startup stays fast.
_tokenizer = None
_model = None
34
+
35
+
36
def _load_model():
    """Lazily build the tokenizer/model singletons on first use.

    Loads from MODEL_ID into CACHE_DIR, switches the model to eval mode,
    and runs one tiny forward pass so the first real request does not pay
    the warm-up cost. No-op when both singletons already exist.
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return  # already initialized
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
    _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
    _model.eval()  # inference mode
    # Warm-up forward pass (result intentionally discarded).
    with torch.no_grad():
        _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
45
+
46
+
47
  def _normalize_label_name(name: str) -> str:
48
  if not isinstance(name, str):
49
  return ""
50
+ return name.strip().lower().replace("_", " ")
51
+
52
+
53
def _id2label_map():
    """Return the model config's id→label mapping, or {} when unavailable."""
    config = getattr(_model, "config", None)
    if not config:
        return {}
    return getattr(config, "id2label", {})
56
+
57
+
58
def _label_for_index(idx: int) -> str:
    """Human-readable name for class index *idx*.

    Tries the int key first, then the string key (HF configs loaded from
    JSON often carry string keys), and falls back to str(idx).
    """
    mapping = _id2label_map()
    if idx in mapping:
        return mapping[idx]
    return mapping.get(str(idx), str(idx))
61
+
62
 
63
def _resolve_indices_from_config():
    """Return (phish_idx, legit_idx) using model-config names and fallbacks.

    Resolution order:
      1. keyword match on normalized label names — legit keywords are tested
         FIRST, because names like "not phish" / "non phish" contain "phish"
         as a substring and the original phish-first check wrongly claimed
         the legit index as the phish index;
      2. for binary heads with uninformative or missing labels, the common
         convention 0 = negative(legit), 1 = positive(phish).

    Either index may remain None for non-binary heads with no keyword match.
    """
    id2label = _id2label_map()
    if not isinstance(id2label, dict):
        id2label = {}

    # Normalize to int keys when possible; skip keys that aren't int-coercible.
    norm = {}
    for k, v in id2label.items():
        try:
            ik = int(k)
        except (TypeError, ValueError):
            continue
        norm[ik] = _normalize_label_name(v)

    phish_keywords = {"phish", "phishing", "spam", "scam", "malicious", "fraud"}
    legit_keywords = {"legit", "ham", "safe", "benign", "not phish", "non phish", "clean"}

    phish_idx = None
    legit_idx = None
    for i, name in norm.items():
        # Legit first: "not phish" must not be captured by the "phish" substring.
        if any(kw in name for kw in legit_keywords):
            if legit_idx is None:
                legit_idx = i
        elif any(kw in name for kw in phish_keywords):
            if phish_idx is None:
                phish_idx = i

    # Fallback for common binary convention when labels aren't informative
    if (phish_idx is None or legit_idx is None) and len(norm) == 2:
        # Many binary heads: 0 = negative(legit), 1 = positive(phish)
        phish_idx = 1 if phish_idx is None else phish_idx
        legit_idx = 0 if legit_idx is None else legit_idx

    # If id2label was empty but model is binary, still fallback to (1, 0)
    if not norm:
        cfg = getattr(_model, "config", None)
        num_labels = int(getattr(cfg, "num_labels", 2)) if cfg else 2
        if num_labels == 2:
            phish_idx = 1 if phish_idx is None else phish_idx
            legit_idx = 0 if legit_idx is None else legit_idx

    return phish_idx, legit_idx
104
 
105
+
106
def _probs_dict(probs_list):
    """Map each class index's probability to its human-readable label name."""
    return {_label_for_index(i): float(p) for i, p in enumerate(probs_list)}
111
+
112
+
113
@app.get("/")
def root():
    """Health endpoint: ensures the model is loaded and reports metadata."""
    _load_model()
    config = getattr(_model, "config", None)
    num_labels = int(getattr(config, "num_labels", 2)) if config else 2
    return {
        "status": "ok",
        "model": MODEL_ID,
        "num_labels": num_labels,
    }
122
+
123
+
124
@app.get("/labels")
def labels():
    """Expose the raw id2label/label2id mappings from the model config."""
    _load_model()
    config = getattr(_model, "config", None)
    result = {"id2label": {}, "label2id": {}}
    if config:
        result["id2label"] = getattr(config, "id2label", {})
        result["label2id"] = getattr(config, "label2id", {})
    return result
132
+
133
+
134
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify one text as PHISH/LEGIT.

    Tokenizes the input (truncated to 512 tokens), runs the model, and
    thresholds the PHISH-class probability against PHISH_THRESHOLD. The
    response carries the label, the PHISH score, argmax index, raw logits,
    per-label probabilities and the resolved indices. Any failure is
    returned as a JSON 500 body with the error message.
    """
    try:
        _load_model()
        encoded = _tokenizer(
            [payload.inputs],
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            logits = _model(**encoded).logits  # [1, num_labels]
        probabilities = torch.softmax(logits, dim=-1)[0]  # [num_labels]
        probs_list = probabilities.tolist()
        argmax_idx = int(torch.argmax(probabilities).item())

        phish_idx, legit_idx = _resolve_indices_from_config()

        # PHISH probability: resolved index when usable, else argmax prob.
        if phish_idx is not None and 0 <= phish_idx < len(probs_list):
            phish_score = float(probs_list[phish_idx])
        else:
            phish_score = float(probs_list[argmax_idx])

        return {
            "label": "PHISH" if phish_score >= PHISH_THRESHOLD else "LEGIT",  # client-compatible
            "score": phish_score,                 # probability of PHISH class
            "predicted_index": argmax_idx,        # argmax over probs
            "logits": logits[0].tolist(),         # raw logits
            "probs": _probs_dict(probs_list),     # per-label probs
            "id2label": _id2label_map(),
            "phish_idx": phish_idx,
            "legit_idx": legit_idx,
            "threshold": PHISH_THRESHOLD,
        }

    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})