Daniel Pedrinho committed on
Commit
13a9adc
·
1 Parent(s): c37c122

Model Commit

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ token
api.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Spam Detection API
3
+ Ensemble of RoBERTa-Large + ELECTRA-Large classifiers.
4
+ Run with: uvicorn api:app --reload
5
+ """
6
+
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ import email
13
+ from email import policy as email_policy
14
+ import numpy as np
15
+ import torch
16
+ from fastapi import FastAPI, HTTPException, UploadFile, File
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from pydantic import BaseModel
19
+ from transformers import (
20
+ AutoTokenizer,
21
+ ElectraForSequenceClassification,
22
+ RobertaForSequenceClassification,
23
+ )
24
+
25
# ── Config ────────────────────────────────────────────────────────────────────

# Paths are resolved relative to this file so the API works regardless of the
# current working directory uvicorn is launched from.
BASE_DIR = Path(__file__).parent
MODELS_DIR = BASE_DIR / "models"

# Fine-tuned checkpoint directories (Hugging Face `save_pretrained` layout).
ROBERTA_DIR = MODELS_DIR / "roberta_large_final"
ELECTRA_DIR = MODELS_DIR / "electra_large_final"

# Run inference on GPU when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Upper bound of the "maybe spam" band: probabilities in
# [threshold, MAYBE_SPAM_UPPER) are flagged "maybe spam"; >= is "spam".
MAYBE_SPAM_UPPER = 0.50  # [threshold, MAYBE_SPAM_UPPER) → "maybe spam"
36
+
37
+
38
# ── App ───────────────────────────────────────────────────────────────────────

app = FastAPI(
    title="Spam Detection API",
    description="Ensemble of RoBERTa-Large + ELECTRA-Large for spam/ham classification.",
    version="1.0.0",
)

# CORS fix: browsers send `Origin: scheme://host[:port]` with NO path, so an
# allowed origin containing "/gone-phishing/" can never match and every
# cross-origin request from the frontend would be rejected. The origin must be
# the bare scheme + host.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://pedrinho-dev01.github.io"],
    allow_methods=["*"],
    allow_headers=["*"],
)
52
+
53
+
54
+ # ── Model loading ─────────────────────────────────────────────────────────────
55
+
56
class ModelBundle:
    """A fine-tuned classifier, its tokenizer, and its decision threshold.

    Loads everything from a single directory in Hugging Face
    `save_pretrained` layout that additionally contains a
    `threshold_config.json` with a `recommended_threshold` entry.
    """

    def __init__(self, model_dir: Path, model_class, tokenizer_class=None):
        """Load model, tokenizer, and threshold from *model_dir*.

        Args:
            model_dir: Checkpoint directory.
            model_class: The ``*ForSequenceClassification`` class to load.
            tokenizer_class: Optional tokenizer class; falls back to
                ``AutoTokenizer``. (Bug fix: this parameter was previously
                accepted but silently ignored.)
        """
        self.model_dir = model_dir
        tokenizer_cls = tokenizer_class or AutoTokenizer
        self.tokenizer = tokenizer_cls.from_pretrained(str(model_dir))
        self.model = model_class.from_pretrained(str(model_dir))
        self.model.to(DEVICE)
        self.model.eval()

        # Per-model decision threshold tuned offline and shipped alongside
        # the checkpoint.
        threshold_path = model_dir / "threshold_config.json"
        with open(threshold_path) as f:
            cfg = json.load(f)
        self.threshold: float = cfg["recommended_threshold"]

    @torch.no_grad()
    def predict_proba(self, text: str) -> float:
        """Return P(spam) as a float in [0, 1]."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,  # hard cap at the encoder context window
            max_length=512,
            padding=True,
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        logits = self.model(**inputs).logits  # shape (1, 2)
        proba = torch.softmax(logits, dim=-1)[0, 1].item()  # P(class=1 / spam)
        return proba
83
+
84
+
85
roberta_bundle: Optional[ModelBundle] = None
electra_bundle: Optional[ModelBundle] = None


@app.on_event("startup")
def load_models():
    """Load both fine-tuned models into the module-level globals at startup."""
    global roberta_bundle, electra_bundle
    loaded = {}
    # Data-driven load keeps the two models symmetric; print order matches
    # the insertion order (RoBERTa first, then ELECTRA).
    specs = {
        "RoBERTa": (ROBERTA_DIR, RobertaForSequenceClassification),
        "ELECTRA": (ELECTRA_DIR, ElectraForSequenceClassification),
    }
    for label, (model_dir, model_cls) in specs.items():
        print(f"Loading {label} …")
        loaded[label] = ModelBundle(model_dir, model_cls)
    roberta_bundle = loaded["RoBERTa"]
    electra_bundle = loaded["ELECTRA"]
    print(f"Models loaded on {DEVICE}.")
97
+
98
+
99
+ # ── Schemas ───────────────────────────────────────────────────────────────────
100
+
101
class PredictRequest(BaseModel):
    """Request body for POST /predict."""
    # Raw message text to classify; must be non-empty after stripping.
    text: str
    # Which classifier to use for the final verdict.
    model: str = "ensemble"  # "ensemble" | "roberta" | "electra"
104
+
105
class ModelResult(BaseModel):
    """Per-model outcome included in every /predict response."""
    # P(spam) from this model, rounded to 4 decimals.
    spam_probability: float
    # True when this model's probability >= MAYBE_SPAM_UPPER.
    is_spam: bool
    # This model's tuned decision threshold (from threshold_config.json).
    threshold: float
109
+
110
class PredictResponse(BaseModel):
    """Response body for /predict and /predict/eml."""
    # Echo of the analyzed text.
    text: str
    # "ensemble", "roberta", or "electra" — whichever produced the verdict.
    model_used: str
    # Final verdict flags (mutually exclusive; see classify()).
    is_spam: bool
    maybe_spam: bool
    # Final P(spam), rounded to 4 decimals.
    spam_probability: float
    # Threshold applied for the "maybe spam" lower bound (averaged for ensemble).
    ensemble_threshold: float
    # Upper bound of the "maybe spam" band (module constant MAYBE_SPAM_UPPER).
    maybe_spam_upper_threshold: float
    # Per-model details; always populated by /predict in the visible code.
    roberta: Optional[ModelResult] = None
    electra: Optional[ModelResult] = None
120
+
121
+
122
+ # ── Helpers ───────────────────────────────────────────────────────────────────
123
+
124
def classify(proba: float, threshold: float) -> dict:
    """Return is_spam and maybe_spam flags for a given probability.

    `is_spam` triggers at MAYBE_SPAM_UPPER; `maybe_spam` covers the band
    [threshold, MAYBE_SPAM_UPPER). The two flags are mutually exclusive.
    """
    definitely_spam = proba >= MAYBE_SPAM_UPPER
    suspicious = (not definitely_spam) and proba >= threshold
    return {"is_spam": definitely_spam, "maybe_spam": suspicious}
129
+
130
+
131
+ # ── Endpoints ─────────────────────────────────────────────────────────────────
132
+
133
+ @app.get("/")
134
+ def root():
135
+ return {"status": "ok", "message": "Spam Detection API is running."}
136
+
137
+
138
+ @app.get("/health")
139
+ def health():
140
+ return {
141
+ "status": "healthy",
142
+ "device": DEVICE,
143
+ "models_loaded": roberta_bundle is not None and electra_bundle is not None,
144
+ }
145
+
146
+
147
+ @app.post("/predict", response_model=PredictResponse)
148
+ def predict(req: PredictRequest):
149
+ if not req.text.strip():
150
+ raise HTTPException(status_code=422, detail="text must not be empty.")
151
+
152
+ model_key = req.model.lower()
153
+ if model_key not in ("ensemble", "roberta", "electra"):
154
+ raise HTTPException(status_code=422, detail="model must be 'ensemble', 'roberta', or 'electra'.")
155
+
156
+ roberta_proba = roberta_bundle.predict_proba(req.text)
157
+ electra_proba = electra_bundle.predict_proba(req.text)
158
+
159
+ roberta_result = ModelResult(
160
+ spam_probability=round(roberta_proba, 4),
161
+ is_spam=roberta_proba >= MAYBE_SPAM_UPPER,
162
+ threshold=roberta_bundle.threshold,
163
+ )
164
+ electra_result = ModelResult(
165
+ spam_probability=round(electra_proba, 4),
166
+ is_spam=electra_proba >= MAYBE_SPAM_UPPER,
167
+ threshold=electra_bundle.threshold,
168
+ )
169
+
170
+ if model_key == "roberta":
171
+ final_proba = roberta_proba
172
+ ensemble_threshold = roberta_bundle.threshold
173
+ elif model_key == "electra":
174
+ final_proba = electra_proba
175
+ ensemble_threshold = electra_bundle.threshold
176
+ else:
177
+ # Ensemble: average the two probabilities, use average threshold
178
+ final_proba = (roberta_proba + electra_proba) / 2
179
+ ensemble_threshold = (roberta_bundle.threshold + electra_bundle.threshold) / 2
180
+
181
+ flags = classify(final_proba, ensemble_threshold)
182
+
183
+ return PredictResponse(
184
+ text=req.text,
185
+ model_used=model_key,
186
+ is_spam=flags["is_spam"],
187
+ maybe_spam=flags["maybe_spam"],
188
+ spam_probability=round(final_proba, 4),
189
+ ensemble_threshold=ensemble_threshold,
190
+ maybe_spam_upper_threshold=MAYBE_SPAM_UPPER,
191
+ roberta=roberta_result,
192
+ electra=electra_result,
193
+ )
194
+
195
+
196
+ @app.post("/predict/batch")
197
+ def predict_batch(texts: list[str], model: str = "ensemble"):
198
+ if len(texts) > 50:
199
+ raise HTTPException(status_code=422, detail="Batch size limit is 50.")
200
+ results = []
201
+ for text in texts:
202
+ req = PredictRequest(text=text, model=model)
203
+ results.append(predict(req))
204
+ return results
205
+
206
+
207
+ # ── EML helper ────────────────────────────────────────────────────────────────
208
+
209
def extract_text_from_eml(raw_bytes: bytes) -> str:
    """Parse a .eml file and return a single string with subject + body text.

    Includes the Subject and From headers (extra classification signal),
    then all non-attachment text/plain parts; when a multipart message has
    no plain-text part at all, falls back to tag-stripped text/html parts.

    Bug fix: the old HTML-fallback condition
    ``not any(p.startswith("Subject") or "plain" in p for p in parts)``
    was True whenever a Subject header had been collected, so HTML-only
    emails produced no body text at all. The fallback now depends only on
    whether any text/plain part was found.
    """
    import html as html_lib
    import re

    msg = email.message_from_bytes(raw_bytes, policy=email_policy.default)

    parts = []

    # Subject line
    subject = msg.get("subject", "")
    if subject:
        parts.append(f"Subject: {subject}")

    # From for extra signal
    from_addr = msg.get("from", "")
    if from_addr:
        parts.append(f"From: {from_addr}")

    if msg.is_multipart():
        plain_parts = []
        html_parts = []
        # Walk MIME parts, collecting text content and skipping attachments.
        for part in msg.walk():
            ct = part.get_content_type()
            cd = str(part.get("Content-Disposition", ""))
            if "attachment" in cd:
                continue
            if ct == "text/plain":
                plain_parts.append(part.get_content())
            elif ct == "text/html":
                html_parts.append(part.get_content())
        if plain_parts:
            parts.extend(plain_parts)
        else:
            # Fallback to HTML only when no plain text exists: strip tags,
            # unescape entities, collapse whitespace.
            for raw_html in html_parts:
                text = re.sub(r"<[^>]+>", " ", raw_html)
                text = html_lib.unescape(text)
                text = re.sub(r"\s+", " ", text).strip()
                parts.append(text)
    else:
        parts.append(msg.get_content())

    return "\n".join(parts).strip()
246
+
247
+
248
+ @app.post("/predict/eml", response_model=PredictResponse)
249
+ async def predict_eml(file: UploadFile = File(...)):
250
+ if not file.filename.endswith(".eml"):
251
+ raise HTTPException(status_code=422, detail="Only .eml files are accepted.")
252
+
253
+ raw = await file.read()
254
+ if len(raw) > 5 * 1024 * 1024: # 5 MB guard
255
+ raise HTTPException(status_code=413, detail="File too large (max 5 MB).")
256
+
257
+ try:
258
+ text = extract_text_from_eml(raw)
259
+ except Exception as e:
260
+ raise HTTPException(status_code=422, detail=f"Failed to parse .eml: {e}")
261
+
262
+ if not text.strip():
263
+ raise HTTPException(status_code=422, detail="Could not extract any text from the .eml file.")
264
+
265
+ analyzed_text = text.strip()
266
+ print("\n=== [EMAIL SCAN] Content analyzed ===")
267
+ print(analyzed_text)
268
+ print("=== [END EMAIL CONTENT] ===\n")
269
+
270
+ # Reuse the existing ensemble prediction logic
271
+ return predict(PredictRequest(text=analyzed_text, model="ensemble"))
models/electra_large_final/.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
models/electra_large_final/config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_cross_attention": false,
3
+ "architectures": [
4
+ "ElectraForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": null,
8
+ "classifier_dropout": null,
9
+ "dtype": "float32",
10
+ "embedding_size": 1024,
11
+ "eos_token_id": null,
12
+ "hidden_act": "gelu",
13
+ "hidden_dropout_prob": 0.1,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 4096,
17
+ "is_decoder": false,
18
+ "layer_norm_eps": 1e-12,
19
+ "max_position_embeddings": 512,
20
+ "model_type": "electra",
21
+ "num_attention_heads": 16,
22
+ "num_hidden_layers": 24,
23
+ "pad_token_id": 0,
24
+ "position_embedding_type": "absolute",
25
+ "summary_activation": "gelu",
26
+ "summary_last_dropout": 0.1,
27
+ "summary_type": "first",
28
+ "summary_use_proj": true,
29
+ "tie_word_embeddings": true,
30
+ "transformers_version": "5.3.0",
31
+ "type_vocab_size": 2,
32
+ "use_cache": false,
33
+ "vocab_size": 30522
34
+ }
models/electra_large_final/threshold_config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "recommended_threshold": 0.35,
3
+ "standard_metrics": {
4
+ "accuracy": 0.9256,
5
+ "f1": 0.9051987767584098,
6
+ "precision": 0.9230769230769231,
7
+ "recall": 0.888
8
+ },
9
+ "custom_metrics": {
10
+ "accuracy": 0.9256,
11
+ "f1": 0.9055837563451776,
12
+ "precision": 0.9195876288659793,
13
+ "recall": 0.892
14
+ }
15
+ }
models/electra_large_final/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/electra_large_final/tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "cls_token": "[CLS]",
4
+ "do_lower_case": true,
5
+ "is_local": false,
6
+ "mask_token": "[MASK]",
7
+ "model_max_length": 512,
8
+ "pad_token": "[PAD]",
9
+ "sep_token": "[SEP]",
10
+ "strip_accents": null,
11
+ "tokenize_chinese_chars": true,
12
+ "tokenizer_class": "BertTokenizer",
13
+ "unk_token": "[UNK]"
14
+ }
models/electra_large_final/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e251fe80c570139a5ddea6518864f1ccf76ef6536208c2d234507ba2c06c2b9
3
+ size 4856
models/roberta_large_final/.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
models/roberta_large_final/config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_cross_attention": false,
3
+ "architectures": [
4
+ "RobertaForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "classifier_dropout": null,
9
+ "dtype": "float32",
10
+ "eos_token_id": 2,
11
+ "hidden_act": "gelu",
12
+ "hidden_dropout_prob": 0.1,
13
+ "hidden_size": 1024,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 4096,
16
+ "is_decoder": false,
17
+ "layer_norm_eps": 1e-05,
18
+ "max_position_embeddings": 514,
19
+ "model_type": "roberta",
20
+ "num_attention_heads": 16,
21
+ "num_hidden_layers": 24,
22
+ "pad_token_id": 1,
23
+ "tie_word_embeddings": true,
24
+ "transformers_version": "5.3.0",
25
+ "type_vocab_size": 1,
26
+ "use_cache": false,
27
+ "vocab_size": 50265
28
+ }
models/roberta_large_final/threshold_config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "recommended_threshold": 0.35,
3
+ "standard_metrics": {
4
+ "accuracy": 0.9352,
5
+ "f1": 0.916923076923077,
6
+ "precision": 0.9410526315789474,
7
+ "recall": 0.894
8
+ },
9
+ "custom_metrics": {
10
+ "accuracy": 0.9336,
11
+ "f1": 0.9150460593654043,
12
+ "precision": 0.9371069182389937,
13
+ "recall": 0.894
14
+ }
15
+ }
models/roberta_large_final/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/roberta_large_final/tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<s>",
5
+ "cls_token": "<s>",
6
+ "eos_token": "</s>",
7
+ "errors": "replace",
8
+ "is_local": false,
9
+ "mask_token": "<mask>",
10
+ "model_max_length": 512,
11
+ "pad_token": "<pad>",
12
+ "sep_token": "</s>",
13
+ "tokenizer_class": "RobertaTokenizer",
14
+ "trim_offsets": true,
15
+ "unk_token": "<unk>"
16
+ }
models/roberta_large_final/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf7746da523087b4c98b10face3adad900b52a4c3ab325a7207442bec1e9eddb
3
+ size 4856