# app.py
"""Phishing URL detection API.

Serves a sequence-classification model (loaded from ``MODEL_REPO``) behind a
FastAPI app with single and batch prediction endpoints, plus health/info
endpoints. An optional LightGBM booster is loaded from ``LGBM_PATH`` when the
file exists; it is only reported by ``/health`` and is not used for inference
in this module.
"""
import os
import time
from contextlib import asynccontextmanager
from typing import Any

import lightgbm as lgb
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict, field_validator
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# ── Config ────────────────────────────────────────────────────────────────────
MODEL_REPO = os.getenv("MODEL_REPO", "AliMusaRizvi/phishing_model_for_extention")
LGBM_PATH = os.getenv("LGBM_PATH", "/app/lgbm_best.txt")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEQ_LEN = 256
MAX_URL_LEN = 2048
MAX_BATCH_SIZE = 50

# Comma-separated list of allowed CORS origins. Defaults to "*" (any origin);
# set ALLOWED_ORIGINS to your extension origin in production.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

print(f"Device: {DEVICE}")

# ── Global model holders ──────────────────────────────────────────────────────
# Populated by `lifespan` with the keys "tokenizer", "bert" and "lgbm".
models: dict[str, Any] = {}


# ── Lifespan: load models once at startup ────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models on startup, release on shutdown.

    Tokenizer and classifier failures are logged and re-raised, so the app
    does not start without them. LightGBM is best-effort: a missing file or a
    load error leaves ``models["lgbm"]`` set to ``None``.
    """
    # NOTE(review): trust_remote_code=True executes Python code shipped in the
    # model repository. Only point MODEL_REPO at a repository you trust.
    print(f"Loading ModernBERT tokenizer from {MODEL_REPO}...")
    try:
        models["tokenizer"] = AutoTokenizer.from_pretrained(
            MODEL_REPO,
            use_fast=True,
            trust_remote_code=True,
        )
        print("Tokenizer loaded.")
    except Exception as e:
        print(f"ERROR loading tokenizer: {e}")
        raise

    print(f"Loading ModernBERT model from {MODEL_REPO}...")
    try:
        models["bert"] = AutoModelForSequenceClassification.from_pretrained(
            MODEL_REPO,
            trust_remote_code=True,
            torch_dtype=torch.float32,  # float32 for CPU stability
            low_cpu_mem_usage=True,  # reduces peak RAM during load
        ).to(DEVICE)
        models["bert"].eval()
        print("ModernBERT loaded.")
    except Exception as e:
        print(f"ERROR loading ModernBERT: {e}")
        raise

    print(f"Loading LightGBM from {LGBM_PATH}...")
    try:
        if os.path.exists(LGBM_PATH):
            models["lgbm"] = lgb.Booster(model_file=LGBM_PATH)
            print("LightGBM loaded.")
        else:
            print(f"WARNING: LightGBM file not found at {LGBM_PATH}, skipping.")
            models["lgbm"] = None
    except Exception as e:
        print(f"WARNING: Could not load LightGBM: {e}")
        models["lgbm"] = None

    print("All models ready. API is live.")
    try:
        yield
    finally:
        # Cleanup on shutdown; `finally` so it also runs if the app exits
        # with an error.
        models.clear()
        print("Models released.")


# ── App ───────────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Phishing URL Detector",
    description="Real-time phishing detection API using ModernBERT",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS — required for browser extension to call this API
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)


# ── URL Preprocessing (matches training exactly) ──────────────────────────────
def preprocess_url(url: str) -> str:
    """Return *url* as a space-separated token string for the tokenizer.

    Prepends ``http://`` when no http/https scheme is present, surrounds each
    delimiter with spaces, then collapses runs of whitespace.

    The delimiters are applied sequentially, so the ``/`` pass also splits the
    ``://`` emitted by the first pass (``http://a.com`` becomes
    ``http : / / a . com``). This is kept as-is because the output has to
    match what the model was trained on.
    """
    url = str(url).strip()
    if not url.startswith(("http://", "https://")):
        url = "http://" + url
    for delim in ["://", "/", "?", "&", "=", ".", "-", "_", "@", "%"]:
        url = url.replace(delim, f" {delim} ")
    return " ".join(url.split())


# ── Schemas ───────────────────────────────────────────────────────────────────
class URLRequest(BaseModel):
    """Request body for ``/predict``."""

    url: str

    @field_validator("url")
    @classmethod
    def url_must_not_be_empty(cls, v):
        """Reject blank or over-long URLs; return the URL stripped."""
        if not v or not v.strip():
            raise ValueError("URL cannot be empty")
        if len(v) > MAX_URL_LEN:
            raise ValueError(
                f"URL exceeds maximum length of {MAX_URL_LEN} characters"
            )
        return v.strip()


class PredictionResponse(BaseModel):
    """Response body for ``/predict``."""

    # Allows the `model_used` field despite pydantic's protected "model_" prefix.
    model_config = ConfigDict(protected_namespaces=())

    url: str
    label: str
    confidence: float
    phishing_probability: float
    legitimate_probability: float
    inference_time_ms: float
    model_used: str


class BatchURLRequest(BaseModel):
    """Request body for ``/predict/batch``."""

    # NOTE(review): individual URLs are not checked for emptiness or length
    # here, unlike URLRequest. Tightening this would reject requests that are
    # accepted today, so confirm with API clients before changing it.
    urls: list[str]

    @field_validator("urls")
    @classmethod
    def limit_batch_size(cls, v):
        """Reject batches larger than ``MAX_BATCH_SIZE``."""
        if len(v) > MAX_BATCH_SIZE:
            raise ValueError(f"Batch size cannot exceed {MAX_BATCH_SIZE} URLs")
        return v


# ── Inference helpers ─────────────────────────────────────────────────────────
def _require_models() -> None:
    """Raise HTTP 503 unless both the tokenizer and the classifier are loaded."""
    if "bert" not in models or "tokenizer" not in models:
        raise HTTPException(status_code=503, detail="Model not loaded yet")


def _class_probabilities(texts: list[str], *, padding) -> np.ndarray:
    """Tokenize *texts*, run the classifier, and return softmax probabilities.

    Returns one row per input text. *padding* is forwarded to the tokenizer
    unchanged so each endpoint keeps its own padding strategy.
    """
    inputs = models["tokenizer"](
        texts,
        truncation=True,
        padding=padding,
        max_length=MAX_SEQ_LEN,
        return_tensors="pt",
    ).to(DEVICE)
    with torch.no_grad():
        outputs = models["bert"](**inputs)
    return torch.softmax(outputs.logits, dim=-1).cpu().numpy()


def _summarize(row) -> dict:
    """Turn one probability row into label/confidence fields.

    Index 0 is treated as "legitimate" and index 1 as "phishing"; this assumes
    the model's label order — TODO confirm against the training config.
    """
    legitimate_prob = float(row[0])
    phishing_prob = float(row[1])
    return {
        "label": "phishing" if phishing_prob >= 0.5 else "legitimate",
        "confidence": round(max(phishing_prob, legitimate_prob), 4),
        "phishing_probability": round(phishing_prob, 4),
        "legitimate_probability": round(legitimate_prob, 4),
    }


# ── Single Prediction ─────────────────────────────────────────────────────────
@app.post("/predict", response_model=PredictionResponse)
def predict(request: URLRequest):
    """Classify a single URL.

    Raises:
        HTTPException: 503 when the models are not loaded.
    """
    _require_models()

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments.
    t_start = time.perf_counter()
    processed = preprocess_url(request.url)
    probs = _class_probabilities([processed], padding="max_length")[0]
    summary = _summarize(probs)
    elapsed_ms = (time.perf_counter() - t_start) * 1000

    return PredictionResponse(
        url=request.url,
        inference_time_ms=round(elapsed_ms, 2),
        model_used="ModernBERT",
        **summary,
    )


# ── Batch Prediction ──────────────────────────────────────────────────────────
@app.post("/predict/batch")
def predict_batch(request: BatchURLRequest):
    """Classify up to ``MAX_BATCH_SIZE`` URLs in one forward pass.

    An empty batch returns an empty result set instead of being sent to the
    tokenizer.

    Raises:
        HTTPException: 503 when the models are not loaded.
    """
    _require_models()

    if not request.urls:
        return {"results": [], "count": 0, "total_time_ms": 0.0}

    t_start = time.perf_counter()
    processed = [preprocess_url(u) for u in request.urls]
    probs = _class_probabilities(processed, padding=True)
    results = [
        {"url": url, **_summarize(row)}
        for url, row in zip(request.urls, probs)
    ]
    elapsed_ms = (time.perf_counter() - t_start) * 1000

    return {
        "results": results,
        "count": len(results),
        "total_time_ms": round(elapsed_ms, 2),
    }


# ── Health & Info ─────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Report which models are loaded and the device in use."""
    return {
        "status": "ok" if "bert" in models else "loading",
        "device": DEVICE,
        "models_loaded": list(models.keys()),
        "bert_ready": "bert" in models,
        "lgbm_ready": models.get("lgbm") is not None,
    }


@app.get("/")
def root():
    """Return a short index of the API's endpoints."""
    return {
        "message": "Phishing Detector API",
        "docs": "/docs",
        "health": "/health",
        "endpoints": ["/predict", "/predict/batch"],
    }