Perth0603 committed on
Commit
a88bffc
·
verified ·
1 Parent(s): 89dd160

Upload 4 files

Browse files
Files changed (3) hide show
  1. Dockerfile +28 -28
  2. app.py +66 -71
  3. requirements.txt +8 -8
Dockerfile CHANGED
@@ -1,28 +1,28 @@
1
- FROM python:3.10-slim
2
-
3
- ENV PYTHONDONTWRITEBYTECODE=1 \
4
- PYTHONUNBUFFERED=1 \
5
- PIP_NO_CACHE_DIR=1
6
-
7
- WORKDIR /app
8
-
9
- # Writable cache directory for HF/torch
10
- RUN mkdir -p /data/.cache && chmod -R 777 /data
11
- ENV HF_HOME=/data/.cache \
12
- TRANSFORMERS_CACHE=/data/.cache \
13
- TORCH_HOME=/data/.cache
14
-
15
- # System deps (optional but helps with torch wheels)
16
- RUN apt-get update && apt-get install -y --no-install-recommends \
17
- build-essential git && \
18
- rm -rf /var/lib/apt/lists/*
19
-
20
- COPY requirements.txt /app/requirements.txt
21
- RUN pip install -r /app/requirements.txt
22
-
23
- COPY app.py /app/app.py
24
-
25
- EXPOSE 7860
26
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
27
-
28
-
 
1
+ FROM python:3.10-slim
2
+
3
+ ENV PYTHONDONTWRITEBYTECODE=1 \
4
+ PYTHONUNBUFFERED=1 \
5
+ PIP_NO_CACHE_DIR=1
6
+
7
+ WORKDIR /app
8
+
9
+ # Writable cache directory for HF/torch
10
+ RUN mkdir -p /data/.cache && chmod -R 777 /data
11
+ ENV HF_HOME=/data/.cache \
12
+ TRANSFORMERS_CACHE=/data/.cache \
13
+ TORCH_HOME=/data/.cache
14
+
15
+ # System deps (optional but helps with torch wheels)
16
+ RUN apt-get update && apt-get install -y --no-install-recommends \
17
+ build-essential git && \
18
+ rm -rf /var/lib/apt/lists/*
19
+
20
+ COPY requirements.txt /app/requirements.txt
21
+ RUN pip install -r /app/requirements.txt
22
+
23
+ COPY app.py /app/app.py
24
+
25
+ EXPOSE 7860
26
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
27
+
28
+
app.py CHANGED
@@ -1,71 +1,66 @@
1
- import os
2
- os.environ.setdefault("HOME", "/data")
3
- os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
4
- os.environ.setdefault("HF_HOME", "/data/.cache")
5
- os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache")
6
- os.environ.setdefault("TORCH_HOME", "/data/.cache")
7
-
8
- from fastapi import FastAPI
9
- from fastapi.responses import JSONResponse
10
- from pydantic import BaseModel
11
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
- import torch
13
-
14
-
15
- MODEL_ID = os.environ.get("MODEL_ID", "Perth0603/phishing-email-mobilebert")
16
- # Optional temperature to sharpen probabilities (lower <1 sharper, >1 smoother)
17
- TEMP = float(os.environ.get("TEMP", "0.7"))
18
-
19
- # Ensure writable cache directory for HF/torch inside Spaces Docker
20
- CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
21
- os.makedirs(CACHE_DIR, exist_ok=True)
22
-
23
- app = FastAPI(title="Phishing Text Classifier", version="1.0.0")
24
-
25
-
26
- class PredictPayload(BaseModel):
27
- inputs: str
28
-
29
-
30
- # Lazy singletons for model/tokenizer
31
- _tokenizer = None
32
- _model = None
33
-
34
-
35
- def _load_model():
36
- global _tokenizer, _model
37
- if _tokenizer is None or _model is None:
38
- _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
39
- _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
40
- _model.eval()
41
- # Warm-up
42
- with torch.no_grad():
43
- _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
44
-
45
-
46
- @app.get("/")
47
- def root():
48
- return {"status": "ok", "model": MODEL_ID}
49
-
50
-
51
- @app.post("/predict")
52
- def predict(payload: PredictPayload):
53
- try:
54
- _load_model()
55
- with torch.no_grad():
56
- inputs = _tokenizer([payload.inputs], return_tensors="pt", truncation=True, max_length=512)
57
- logits = _model(**inputs).logits
58
- # Raw probs (for reference)
59
- raw_probs = torch.softmax(logits, dim=-1)[0]
60
- # Temperature-scaled probs to make confidence less around 0.5
61
- scaled_probs = torch.softmax(logits / TEMP, dim=-1)[0]
62
- score, idx = torch.max(scaled_probs, dim=0)
63
- except Exception as e:
64
- return JSONResponse(status_code=500, content={"error": str(e)})
65
-
66
- # Map common ids to labels (kept generic; your config also has these)
67
- id2label = {0: "LEGIT", 1: "PHISH"}
68
- label = id2label.get(int(idx), str(int(idx)))
69
- return {"label": label, "score": float(score), "raw_score": float(raw_probs[int(idx)])}
70
-
71
-
 
1
+ import os
2
+ os.environ.setdefault("HOME", "/data")
3
+ os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
4
+ os.environ.setdefault("HF_HOME", "/data/.cache")
5
+ os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache")
6
+ os.environ.setdefault("TORCH_HOME", "/data/.cache")
7
+
8
+ from fastapi import FastAPI
9
+ from fastapi.responses import JSONResponse
10
+ from pydantic import BaseModel
11
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
+ import torch
13
+
14
+
15
+ MODEL_ID = os.environ.get("MODEL_ID", "Perth0603/phishing-email-mobilebert")
16
+
17
+ # Ensure writable cache directory for HF/torch inside Spaces Docker
18
+ CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/data/.cache")
19
+ os.makedirs(CACHE_DIR, exist_ok=True)
20
+
21
+ app = FastAPI(title="Phishing Text Classifier", version="1.0.0")
22
+
23
+
24
+ class PredictPayload(BaseModel):
25
+ inputs: str
26
+
27
+
28
+ # Lazy singletons for model/tokenizer
29
+ _tokenizer = None
30
+ _model = None
31
+
32
+
33
+ def _load_model():
34
+ global _tokenizer, _model
35
+ if _tokenizer is None or _model is None:
36
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
37
+ _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
38
+ _model.eval()
39
+ # Warm-up
40
+ with torch.no_grad():
41
+ _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
42
+
43
+
44
+ @app.get("/")
45
+ def root():
46
+ return {"status": "ok", "model": MODEL_ID}
47
+
48
+
49
+ @app.post("/predict")
50
+ def predict(payload: PredictPayload):
51
+ try:
52
+ _load_model()
53
+ with torch.no_grad():
54
+ inputs = _tokenizer([payload.inputs], return_tensors="pt", truncation=True, max_length=512)
55
+ logits = _model(**inputs).logits
56
+ probs = torch.softmax(logits, dim=-1)[0]
57
+ score, idx = torch.max(probs, dim=0)
58
+ except Exception as e:
59
+ return JSONResponse(status_code=500, content={"error": str(e)})
60
+
61
+ # Map common ids to labels (kept generic; your config also has these)
62
+ id2label = {0: "LEGIT", 1: "PHISH"}
63
+ label = id2label.get(int(idx), str(int(idx)))
64
+ return {"label": label, "score": float(score)}
65
+
66
+
 
 
 
 
 
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- --extra-index-url https://download.pytorch.org/whl/cpu
2
- fastapi==0.115.0
3
- uvicorn==0.30.6
4
- transformers==4.46.3
5
- torch==2.3.1+cpu
6
- accelerate>=0.33.0
7
- safetensors>=0.4.3
8
-
 
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
+ fastapi==0.115.0
3
+ uvicorn==0.30.6
4
+ transformers==4.46.3
5
+ torch==2.3.1+cpu
6
+ accelerate>=0.33.0
7
+ safetensors>=0.4.3
8
+