Perth0603 committed on
Commit
9dfcd1b
·
verified ·
1 Parent(s): 820a438

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -108
app.py CHANGED
@@ -1,111 +1,43 @@
1
- import os
2
- os.environ.setdefault("HOME", "/data")
3
- os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
4
- os.environ.setdefault("HF_HOME", "/data/.cache")
5
- os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache")
6
- os.environ.setdefault("TORCH_HOME", "/data/.cache")
7
 
8
- from fastapi import FastAPI
9
- from fastapi.responses import JSONResponse
10
- from pydantic import BaseModel
11
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
- import torch
13
-
14
-
15
- MODEL_ID = os.environ.get("MODEL_ID", "Perth0603/phishing-email-mobilebert")
16
-
17
- # Ensure writable cache directory for HF/torch inside Spaces Docker
18
- CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
19
- os.makedirs(CACHE_DIR, exist_ok=True)
20
-
21
- app = FastAPI(title="Phishing Text Classifier", version="1.0.0")
22
-
23
-
24
- class PredictPayload(BaseModel):
25
- inputs: str
26
-
27
-
28
- # Lazy singletons for model/tokenizer
29
- _tokenizer = None
30
- _model = None
31
-
32
-
33
- def _load_model():
34
- global _tokenizer, _model
35
- if _tokenizer is None or _model is None:
36
- _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
37
- _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
38
- _model.eval() # inference mode
39
- # Warm-up
40
- with torch.no_grad():
41
- _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
42
-
43
-
44
- def _id2label():
45
  cfg = getattr(_model, "config", None)
46
- return getattr(cfg, "id2label", {}) if cfg else {}
47
-
48
-
49
- @app.get("/")
50
- def root():
51
- _load_model()
52
- cfg = getattr(_model, "config", None)
53
- return {
54
- "status": "ok",
55
- "model": MODEL_ID,
56
- "num_labels": int(getattr(cfg, "num_labels", 2)) if cfg else 2,
57
- }
58
-
59
-
60
- @app.get("/labels")
61
- def labels():
62
- _load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  cfg = getattr(_model, "config", None)
64
- return {
65
- "id2label": getattr(cfg, "id2label", {}) if cfg else {},
66
- "label2id": getattr(cfg, "label2id", {}) if cfg else {},
67
- }
68
-
69
-
70
- @app.post("/predict")
71
- def predict(payload: PredictPayload):
72
- try:
73
- _load_model()
74
- with torch.no_grad():
75
- inputs = _tokenizer(
76
- [payload.inputs],
77
- return_tensors="pt",
78
- truncation=True,
79
- max_length=512
80
- )
81
- outputs = _model(**inputs)
82
- logits = outputs.logits # [1, num_labels]
83
- logits_list = logits[0].tolist()
84
- pred_idx = int(torch.argmax(logits, dim=-1).item())
85
-
86
- # Keep client-compatible fields but also provide raw outputs
87
- probs_t = torch.softmax(logits, dim=-1)[0]
88
- score = float(probs_t[pred_idx])
89
-
90
- except Exception as e:
91
- return JSONResponse(status_code=500, content={"error": str(e)})
92
-
93
- id2label = _id2label()
94
- # Resolve label from model config (support int or str keys)
95
- pred_label = id2label.get(pred_idx, id2label.get(str(pred_idx), str(pred_idx)))
96
-
97
- # Build per-label probabilities for debugging/verification
98
- probs = {}
99
- for i, p in enumerate(probs_t.tolist()):
100
- probs[id2label.get(i, id2label.get(str(i), str(i)))] = float(p)
101
-
102
- # Backward-compatible keys: "label" and "score"
103
- return {
104
- "label": pred_label, # expected by your client
105
- "score": score, # probability of predicted class (softmax)
106
- "predicted_index": pred_idx, # raw argmax index from logits
107
- "logits": logits_list, # raw model output
108
- "probs": probs, # per-label probabilities
109
- "id2label": id2label,
110
- "label2id": getattr(getattr(_model, "config", None), "label2id", {}),
111
- }
 
1
+ def _normalize_label_name(name: str) -> str:
2
+ if not isinstance(name, str):
3
+ return ""
4
+ return name.strip().lower()
 
 
5
 
6
def _resolve_indices_from_config():
    """Return (phish_idx, legit_idx) resolved from the model config's id2label.

    Label names are normalized and matched against keyword sets; for a binary
    head where only one side (or neither) matches, the missing index is filled
    in without ever colliding with the one already found. Either element may
    be None when the config is missing or has a non-binary, unmatched head.
    """
    cfg = getattr(_model, "config", None)
    id2label = getattr(cfg, "id2label", {}) if cfg else {}
    # Normalize keys to int (HF configs may serialize them as strings).
    norm = {}
    for k, v in id2label.items():
        try:
            ik = int(k)
        except (TypeError, ValueError):
            continue
        norm[ik] = _normalize_label_name(v)

    # Try to detect via keywords
    phish_keywords = {"phish", "phishing", "spam", "scam", "malicious"}
    legit_keywords = {"legit", "ham", "safe", "benign", "not phish", "non-phish"}

    phish_idx = None
    legit_idx = None
    for i, name in norm.items():
        if phish_idx is None and any(kw in name for kw in phish_keywords):
            phish_idx = i
        if legit_idx is None and any(kw in name for kw in legit_keywords):
            legit_idx = i

    # Fallback conventions for binary heads
    if len(norm) == 2 and (phish_idx is None or legit_idx is None):
        indices = sorted(norm)
        if phish_idx is not None and legit_idx is None:
            # Give legit whichever index phish did not claim (avoids the
            # collision the old hard-coded 0/1 fallback could produce).
            legit_idx = indices[0] if phish_idx == indices[1] else indices[1]
        elif legit_idx is not None and phish_idx is None:
            phish_idx = indices[0] if legit_idx == indices[1] else indices[1]
        else:
            # Nothing matched; common convention: 0 = negative(legit), 1 = positive(phish)
            phish_idx, legit_idx = 1, 0

    return phish_idx, legit_idx
39
+
40
def _label_for_index(idx: int) -> str:
    """Map a class index to its config label name.

    Tries the int key first, then its string form; falls back to str(idx)
    when the config has no entry at all.
    """
    config = getattr(_model, "config", None)
    mapping = getattr(config, "id2label", {}) if config else {}
    if idx in mapping:
        return mapping[idx]
    return mapping.get(str(idx), str(idx))