phishing / app.py
AliMusaRizvi's picture
Update app.py
14f3507 verified
# app.py
import os
import time
import torch
import lightgbm as lgb
import numpy as np
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, field_validator
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pydantic import BaseModel, field_validator, ConfigDict # Added ConfigDict
# ── Config ────────────────────────────────────────────────────────────────────
# Hub repository id handed to both `from_pretrained` calls at startup;
# the MODEL_REPO environment variable overrides the default.
MODEL_REPO = os.environ.get("MODEL_REPO", "AliMusaRizvi/phishing_model_for_extention")

# Location of the optional LightGBM model file; the LGBM_PATH environment
# variable overrides the default.
LGBM_PATH = os.environ.get("LGBM_PATH", "/app/lgbm_best.txt")

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Passed as `max_length` to every tokenizer call in this module.
MAX_SEQ_LEN = 256

print(f"Device: {DEVICE}")

# ── Global model holders ──────────────────────────────────────────────────────
# Populated by the lifespan handler and read by the request handlers.
models = {}
# ── Lifespan: load models once at startup ─────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models on startup, release them on shutdown.

    The tokenizer and the ModernBERT classifier are mandatory: a failure to
    load either one is printed and re-raised, which aborts startup. LightGBM
    is optional: a missing file or a load error only prints a warning and
    leaves ``models["lgbm"]`` set to ``None``.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None, once the module-level ``models`` dict has been populated.
    """
    print(f"Loading ModernBERT tokenizer from {MODEL_REPO}...")
    try:
        # NOTE(security): trust_remote_code=True lets Python code shipped in
        # the model repository run in this process. Keep MODEL_REPO pointed
        # at a repository you control.
        models["tokenizer"] = AutoTokenizer.from_pretrained(
            MODEL_REPO,
            use_fast=True,
            trust_remote_code=True
        )
        print("Tokenizer loaded.")
    except Exception as e:
        print(f"ERROR loading tokenizer: {e}")
        raise

    print(f"Loading ModernBERT model from {MODEL_REPO}...")
    try:
        models["bert"] = AutoModelForSequenceClassification.from_pretrained(
            MODEL_REPO,
            trust_remote_code=True,
            torch_dtype=torch.float32,  # float32 for CPU stability
            low_cpu_mem_usage=True      # reduces peak RAM during load
        ).to(DEVICE)
        # Inference only: switch off dropout and other training-time behaviour.
        models["bert"].eval()
        print("ModernBERT loaded.")
    except Exception as e:
        print(f"ERROR loading ModernBERT: {e}")
        raise

    print(f"Loading LightGBM from {LGBM_PATH}...")
    try:
        if os.path.exists(LGBM_PATH):
            models["lgbm"] = lgb.Booster(model_file=LGBM_PATH)
            print("LightGBM loaded.")
        else:
            print(f"WARNING: LightGBM file not found at {LGBM_PATH}, skipping.")
            models["lgbm"] = None
    except Exception as e:
        # Deliberately best-effort: the API still serves ModernBERT results.
        print(f"WARNING: Could not load LightGBM: {e}")
        models["lgbm"] = None

    print("All models ready. API is live.")
    try:
        yield
    finally:
        # Cleanup on shutdown. The finally clause makes sure the models are
        # released even when shutdown is triggered by an exception or a
        # cancellation thrown into this generator at the yield.
        models.clear()
        print("Models released.")
# ── App ───────────────────────────────────────────────────────────────────────
# Application object; the lifespan handler is what fills `models`.
app = FastAPI(
    lifespan=lifespan,
    title="Phishing URL Detector",
    version="1.0.0",
    description="Real-time phishing detection API using ModernBERT",
)

# CORS — required for browser extension to call this API
app.add_middleware(
    CORSMiddleware,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    # Wildcard origin: tighten this to your extension origin in production.
    allow_origins=["*"],
)
# ── URL Preprocessing (matches training exactly) ──────────────────────────────
def preprocess_url(url: str) -> str:
    """Turn a raw URL into the space-separated form fed to the tokenizer.

    Surrounding whitespace is removed, ``http://`` is prepended when the
    value does not already start with ``http://`` or ``https://`` (the check
    is case-sensitive), every delimiter is surrounded by spaces, and runs of
    whitespace are collapsed to a single space.

    Args:
        url: The URL to preprocess; it is coerced with ``str()`` first.

    Returns:
        The delimiter-separated URL, e.g. ``"http : / / example . com"``.
    """
    text = str(url).strip()
    has_scheme = text.startswith("http://") or text.startswith("https://")
    if not has_scheme:
        text = f"http://{text}"

    # The scheme separator is padded first. Its two slashes are padded again
    # by the single-character pass below, so "://" ends up as ": / /".
    text = text.replace("://", " :// ")

    # Pad every single-character delimiter in one pass.
    padding = {ord(ch): f" {ch} " for ch in "/?&=.-_@%"}
    text = text.translate(padding)

    return " ".join(text.split())
# ── Schemas ───────────────────────────────────────────────────────────────────
class URLRequest(BaseModel):
    """Request body carrying the single URL to classify."""

    url: str

    @field_validator("url")
    @classmethod
    def url_must_not_be_empty(cls, v):
        """Reject blank or oversized URLs and return the stripped value.

        Raises:
            ValueError: If the URL is empty or whitespace-only, or if the
                value as received is longer than 2048 characters.
        """
        trimmed = v.strip()
        if trimmed == "":
            raise ValueError("URL cannot be empty")
        # The length cap is applied to the value as received, before stripping.
        if len(v) > 2048:
            raise ValueError("URL exceeds maximum length of 2048 characters")
        return trimmed
class PredictionResponse(BaseModel):
    """Response body returned by the single-URL ``/predict`` endpoint."""

    # Clear pydantic's protected "model_" namespace so the `model_used`
    # field below does not clash with it.
    model_config = ConfigDict(protected_namespaces=())
    # URL taken from the validated request.
    url: str
    # "phishing" or "legitimate".
    label: str
    # The larger of the two class probabilities.
    confidence: float
    phishing_probability: float
    legitimate_probability: float
    # Time spent preprocessing, tokenizing and scoring, in milliseconds.
    inference_time_ms: float
    # Name of the model that produced the prediction.
    model_used: str
class BatchURLRequest(BaseModel):
    """Request body for ``/predict/batch``: a list of at most 50 URLs.

    Every entry is held to the same limits ``URLRequest`` applies to a single
    URL (not blank, at most 2048 characters), so the batch endpoint cannot be
    used to bypass them.
    """

    urls: list[str]

    @field_validator("urls")
    @classmethod
    def limit_batch_size(cls, v):
        """Enforce the batch-size cap and the per-URL limits.

        The list is returned unchanged, so the URLs echoed back in the
        response are exactly the ones the client sent.

        Raises:
            ValueError: If there are more than 50 URLs, or if any URL is
                empty/whitespace-only or longer than 2048 characters.
        """
        if len(v) > 50:
            raise ValueError("Batch size cannot exceed 50 URLs")
        for index, url in enumerate(v):
            if not url.strip():
                raise ValueError(f"URL at index {index} cannot be empty")
            if len(url) > 2048:
                raise ValueError(
                    f"URL at index {index} exceeds maximum length of 2048 characters"
                )
        return v
# ── Single Prediction ─────────────────────────────────────────────────────────
@app.post("/predict", response_model=PredictionResponse)
def predict(request: URLRequest):
    """Classify a single URL as phishing or legitimate with ModernBERT.

    Args:
        request: Validated request body holding the URL.

    Returns:
        A ``PredictionResponse`` with the label, both class probabilities
        (rounded to 4 decimals) and the elapsed time in milliseconds.

    Raises:
        HTTPException: 503 when the ModernBERT model is not in ``models``.
    """
    if "bert" not in models:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # perf_counter is monotonic, so the measured latency cannot be skewed
    # by system clock adjustments the way time.time() can.
    t_start = time.perf_counter()

    processed = preprocess_url(request.url)
    inputs = models["tokenizer"](
        processed,
        truncation=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN,
        return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        outputs = models["bert"](**inputs)

    # Single input: take row 0 of the batch dimension.
    probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()

    # NOTE(review): assumes class index 0 is "legitimate" and index 1 is
    # "phishing" — confirm against the label mapping used in training.
    phishing_prob = float(probs[1])
    legitimate_prob = float(probs[0])
    pred_label = "phishing" if phishing_prob >= 0.5 else "legitimate"
    confidence = max(phishing_prob, legitimate_prob)

    elapsed_ms = (time.perf_counter() - t_start) * 1000

    return PredictionResponse(
        url=request.url,
        label=pred_label,
        confidence=round(confidence, 4),
        phishing_probability=round(phishing_prob, 4),
        legitimate_probability=round(legitimate_prob, 4),
        inference_time_ms=round(elapsed_ms, 2),
        model_used="ModernBERT"
    )
# ── Batch Prediction ──────────────────────────────────────────────────────────
@app.post("/predict/batch")
def predict_batch(request: BatchURLRequest):
    """Classify a batch of URLs in a single forward pass.

    Args:
        request: Validated request body holding the list of URLs.

    Returns:
        A dict with ``results`` (one entry per input URL, in input order,
        each holding the URL, label, confidence and both probabilities),
        ``count`` and ``total_time_ms``. An empty input list yields an empty
        ``results`` list.

    Raises:
        HTTPException: 503 when the ModernBERT model is not in ``models``.
    """
    if "bert" not in models:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # perf_counter is monotonic, so the measured latency cannot be skewed
    # by system clock adjustments the way time.time() can.
    t_start = time.perf_counter()
    results = []

    # Skip the tokenizer and the model entirely for an empty batch instead
    # of handing them zero-length input.
    if request.urls:
        processed = [preprocess_url(u) for u in request.urls]
        inputs = models["tokenizer"](
            processed,
            truncation=True,
            padding=True,
            max_length=MAX_SEQ_LEN,
            return_tensors="pt"
        ).to(DEVICE)

        with torch.no_grad():
            outputs = models["bert"](**inputs)

        probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()

        # NOTE(review): assumes class index 0 is "legitimate" and index 1 is
        # "phishing" — confirm against the label mapping used in training.
        for url, row in zip(request.urls, probs):
            p_prob = float(row[1])
            l_prob = float(row[0])
            results.append({
                "url": url,
                "label": "phishing" if p_prob >= 0.5 else "legitimate",
                "confidence": round(max(p_prob, l_prob), 4),
                "phishing_probability": round(p_prob, 4),
                "legitimate_probability": round(l_prob, 4),
            })

    elapsed_ms = (time.perf_counter() - t_start) * 1000

    return {
        "results": results,
        "count": len(results),
        "total_time_ms": round(elapsed_ms, 2)
    }
# ── Health & Info ─────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Report service readiness and which models are actually available.

    Returns:
        A dict with ``status`` ("ok" once ModernBERT is loaded, otherwise
        "loading"), the torch ``device``, the names in ``models_loaded``,
        and the ``bert_ready`` / ``lgbm_ready`` flags.
    """
    bert_ready = "bert" in models
    return {
        "status": "ok" if bert_ready else "loading",
        "device": DEVICE,
        # The lifespan handler stores None for LightGBM when it could not be
        # loaded; leave such placeholders out so this list agrees with
        # `lgbm_ready`.
        "models_loaded": [name for name, obj in models.items() if obj is not None],
        "bert_ready": bert_ready,
        "lgbm_ready": models.get("lgbm") is not None
    }
@app.get("/")
def root():
    """Return a small index describing the API and where to find its docs."""
    prediction_routes = ["/predict", "/predict/batch"]
    return dict(
        message="Phishing Detector API",
        docs="/docs",
        health="/health",
        endpoints=prediction_routes,
    )