Nick-2x commited on
Commit
b2ed807
·
verified ·
1 Parent(s): 0717aa2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+
6
+ app = FastAPI()
7
+
8
+ # 1. SWAP MODEL ID HERE
9
+ # Option A: dima806/phishing-email-detection (Good for Phishing)
10
+ # Option B: AntiSpamInstitute/spam-detector-bert-MoE-v2.2 (Good for Spam)
11
+ MODEL_ID = "dima806/phishing-email-detection"
12
+
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
14
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
15
+
16
+ class EmailInput(BaseModel):
17
+ text: str
18
+
19
+ @app.post("/predict")
20
+ async def predict_email(data: EmailInput):
21
+ # PRE-PROCESS: Handle very short text manually to avoid "Model Hallucinations"
22
+ if len(data.text.strip().split()) < 3:
23
+ return {"prediction": "legitimate", "confidence": 1.0, "is_phishing": False, "note": "Text too short for analysis"}
24
+
25
+ inputs = tokenizer(data.text, return_tensors="pt", truncation=True, max_length=512)
26
+
27
+ with torch.no_grad():
28
+ outputs = model(**inputs)
29
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
30
+
31
+ probs = predictions[0].tolist()
32
+
33
+ # 2. DYNAMIC LABEL MAPPING
34
+ # This automatically gets labels like 'LABEL_0', 'phishing', etc., from the model config
35
+ confidences = {model.config.id2label[i]: prob for i, prob in enumerate(probs)}
36
+
37
+ # Determine the top result
38
+ max_label = max(confidences.items(), key=lambda x: x[1])
39
+
40
+ return {
41
+ "prediction": max_label[0],
42
+ "confidence": round(max_label[1], 4),
43
+ "all_scores": confidences,
44
+ "is_phishing": "phishing" in max_label[0].lower() or "spam" in max_label[0].lower()
45
+ }
46
+
47
+ if __name__ == "__main__":
48
+ import uvicorn
49
+ uvicorn.run(app, host="0.0.0.0", port=7860)