Perth0603 commited on
Commit
42f689a
·
verified ·
1 Parent(s): f5bfc06

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -50
app.py DELETED
@@ -1,50 +0,0 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
- import torch
5
- import os
6
-
7
-
8
- MODEL_ID = os.environ.get("MODEL_ID", "Perth0603/phishing-email-mobilebert")
9
-
10
- app = FastAPI(title="Phishing Text Classifier", version="1.0.0")
11
-
12
-
13
- class PredictPayload(BaseModel):
14
- inputs: str
15
-
16
-
17
- # Lazy singletons for model/tokenizer
18
- _tokenizer = None
19
- _model = None
20
-
21
-
22
- def _load_model():
23
- global _tokenizer, _model
24
- if _tokenizer is None or _model is None:
25
- _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
26
- _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
27
- # Warm-up
28
- with torch.no_grad():
29
- _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
30
-
31
-
32
- @app.get("/")
33
- def root():
34
- return {"status": "ok", "model": MODEL_ID}
35
-
36
-
37
- @app.post("/predict")
38
- def predict(payload: PredictPayload):
39
- _load_model()
40
- with torch.no_grad():
41
- logits = _model(**_tokenizer([payload.inputs], return_tensors="pt")).logits
42
- probs = torch.softmax(logits, dim=-1)[0]
43
- score, idx = torch.max(probs, dim=0)
44
-
45
- # Map common ids to labels (kept generic; your config also has these)
46
- id2label = {0: "LEGIT", 1: "PHISH"}
47
- label = id2label.get(int(idx), str(int(idx)))
48
- return {"label": label, "score": float(score)}
49
-
50
-