# Page residue that preceded app.py in the source listing: "Spaces: Sleeping Sleeping"
| # app.py | |
| import os | |
| import time | |
| import torch | |
| import lightgbm as lgb | |
| import numpy as np | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, field_validator | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from pydantic import BaseModel, field_validator, ConfigDict # Added ConfigDict | |
# ── Config ─────────────────────────────────────────────────────────────────
# Each setting except MAX_SEQ_LEN can be overridden through the environment.
MODEL_REPO: str = os.getenv("MODEL_REPO", "AliMusaRizvi/phishing_model_for_extention")
LGBM_PATH: str = os.getenv("LGBM_PATH", "/app/lgbm_best.txt")
# Prefer the GPU whenever torch can see one.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
# Token cap handed to the tokenizer (truncation / max_length).
MAX_SEQ_LEN: int = 256

print(f"Device: {DEVICE}")

# ── Global model holders ───────────────────────────────────────────────────
# Filled by `lifespan` at startup with the keys "tokenizer", "bert" and
# "lgbm" (the latter may be None), and emptied again on shutdown.
models: dict = {}
# ── Lifespan: load models once at startup ──────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models on startup and release them on shutdown.

    Fills the module-level ``models`` dict with:

    * ``"tokenizer"`` -- fast tokenizer for ``MODEL_REPO``.
    * ``"bert"`` -- sequence-classification model for ``MODEL_REPO``, moved to
      ``DEVICE`` and switched to eval mode.
    * ``"lgbm"`` -- LightGBM ``Booster`` read from ``LGBM_PATH``, or ``None``
      when the file is missing or fails to load (treated as optional).

    The ``asynccontextmanager`` decorator is required because FastAPI's
    ``lifespan=`` argument expects an async context manager factory.

    Args:
        app: The FastAPI application (not used; required by the lifespan
            signature).

    Raises:
        Exception: Any error raised while loading the tokenizer or the
            ModernBERT model is printed and re-raised, so the function never
            reaches ``yield`` and startup does not complete.
    """
    print(f"Loading ModernBERT tokenizer from {MODEL_REPO}...")
    try:
        models["tokenizer"] = AutoTokenizer.from_pretrained(
            MODEL_REPO,
            use_fast=True,
            trust_remote_code=True,
        )
        print("Tokenizer loaded.")
    except Exception as e:
        print(f"ERROR loading tokenizer: {e}")
        raise

    print(f"Loading ModernBERT model from {MODEL_REPO}...")
    try:
        models["bert"] = AutoModelForSequenceClassification.from_pretrained(
            MODEL_REPO,
            trust_remote_code=True,
            torch_dtype=torch.float32,  # float32 for CPU stability
            low_cpu_mem_usage=True,     # reduces peak RAM during load
        ).to(DEVICE)
        models["bert"].eval()  # inference only: disable dropout etc.
        print("ModernBERT loaded.")
    except Exception as e:
        print(f"ERROR loading ModernBERT: {e}")
        raise

    # LightGBM is optional: any failure downgrades to None instead of
    # aborting startup.
    print(f"Loading LightGBM from {LGBM_PATH}...")
    try:
        if os.path.exists(LGBM_PATH):
            models["lgbm"] = lgb.Booster(model_file=LGBM_PATH)
            print("LightGBM loaded.")
        else:
            print(f"WARNING: LightGBM file not found at {LGBM_PATH}, skipping.")
            models["lgbm"] = None
    except Exception as e:
        print(f"WARNING: Could not load LightGBM: {e}")
        models["lgbm"] = None

    print("All models ready. API is live.")
    try:
        yield
    finally:
        # Release model references on shutdown even if the app exits with an
        # error while serving.
        models.clear()
        print("Models released.")
# ── App ────────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Phishing URL Detector",
    description="Real-time phishing detection API using ModernBERT",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS — required for the browser extension to call this API.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS env var.
# Production can then be tightened to the extension's origin without a code
# change. The default "*" keeps the previous allow-all behaviour.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# ── URL Preprocessing (matches training exactly) ───────────────────────────
def preprocess_url(url: str) -> str:
    """Rewrite a raw URL into the space-separated form the model was trained on.

    Steps (must stay identical to training-time preprocessing):

    1. Coerce to ``str`` and trim surrounding whitespace.
    2. Prepend ``"http://"`` unless the text already starts with
       ``"http://"`` or ``"https://"`` (case-sensitive).
    3. Surround every delimiter with spaces, in a fixed order.
    4. Collapse whitespace runs to single spaces.

    >>> preprocess_url("example.com/a?b=c")
    'http : / / example . com / a ? b = c'
    """
    text = str(url).strip()
    has_scheme = text.startswith("http://") or text.startswith("https://")
    if not has_scheme:
        text = f"http://{text}"
    # Order matters: padding "://" first is what puts a space before the ":"
    # of the scheme; the later "/" pass then splits the two slashes apart.
    delimiters = ("://", "/", "?", "&", "=", ".", "-", "_", "@", "%")
    for token in delimiters:
        text = text.replace(token, " " + token + " ")
    tokens = text.split()
    return " ".join(tokens)
# ── Schemas ────────────────────────────────────────────────────────────────
class URLRequest(BaseModel):
    """Request body for single-URL prediction."""

    url: str

    @field_validator("url")
    @classmethod
    def url_must_not_be_empty(cls, v: str) -> str:
        """Reject blank or over-long URLs and return the stripped value.

        Raises:
            ValueError: If the URL is empty/whitespace-only, or longer than
                2048 characters after stripping.
        """
        stripped = v.strip()
        if not stripped:
            raise ValueError("URL cannot be empty")
        # Measure the stripped URL so surrounding whitespace does not count
        # toward the limit.
        if len(stripped) > 2048:
            raise ValueError("URL exceeds maximum length of 2048 characters")
        return stripped
class PredictionResponse(BaseModel):
    """Response body for a single-URL prediction.

    Field order here is the order of keys in the serialized response.
    """

    # Pydantic v2 reserves the "model_" prefix for its own attributes; an
    # empty protected_namespaces lets the `model_used` field be declared
    # without a namespace-conflict warning.
    model_config = ConfigDict(protected_namespaces=())
    url: str                       # the URL from the request
    label: str                     # "phishing" or "legitimate"
    confidence: float              # larger of the two class probabilities
    phishing_probability: float    # softmax probability of class index 1
    legitimate_probability: float  # softmax probability of class index 0
    inference_time_ms: float       # elapsed time for preprocess + tokenize + forward, in ms
    model_used: str                # name of the model that produced the scores
class BatchURLRequest(BaseModel):
    """Request body for batch prediction (at most 50 URLs)."""

    urls: list[str]

    @field_validator("urls")
    @classmethod
    def limit_batch_size(cls, v: list[str]) -> list[str]:
        """Enforce the batch cap and apply the single-URL checks to each item.

        Each URL is stripped, must be non-empty, and must be at most 2048
        characters. This matches the rules for single-URL requests and bounds
        the tokenization work per batch.

        Returns:
            The list of stripped URLs, in the original order.

        Raises:
            ValueError: If there are more than 50 URLs, or any URL is empty or
                too long.
        """
        if len(v) > 50:
            raise ValueError("Batch size cannot exceed 50 URLs")
        cleaned = []
        for i, raw in enumerate(v):
            stripped = raw.strip()
            if not stripped:
                raise ValueError(f"URL at index {i} cannot be empty")
            if len(stripped) > 2048:
                raise ValueError(
                    f"URL at index {i} exceeds maximum length of 2048 characters"
                )
            cleaned.append(stripped)
        return cleaned
# ── Single Prediction ──────────────────────────────────────────────────────
@app.post("/predict", response_model=PredictionResponse)
def predict(request: URLRequest):
    """Classify a single URL as phishing or legitimate.

    This is a plain ``def`` rather than ``async def``, so FastAPI runs it in a
    worker thread and the blocking forward pass does not stall the event loop.

    Args:
        request: Body holding the URL to score.

    Returns:
        PredictionResponse with the label, probabilities and confidence
        (rounded to 4 decimals), and the latency in ms (rounded to 2).

    Raises:
        HTTPException: 503 if the ModernBERT model has not been loaded.
    """
    if "bert" not in models:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # perf_counter is monotonic and high-resolution, unlike time.time(), so
    # the reported latency cannot jump or go negative on clock adjustments.
    t_start = time.perf_counter()
    processed = preprocess_url(request.url)
    inputs = models["tokenizer"](
        processed,
        truncation=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        outputs = models["bert"](**inputs)
    # Single-item batch: take row 0. Column 0 = legitimate, column 1 = phishing.
    probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
    phishing_prob = float(probs[1])
    legitimate_prob = float(probs[0])
    pred_label = "phishing" if phishing_prob >= 0.5 else "legitimate"
    confidence = max(phishing_prob, legitimate_prob)
    elapsed_ms = (time.perf_counter() - t_start) * 1000

    return PredictionResponse(
        url=request.url,
        label=pred_label,
        confidence=round(confidence, 4),
        phishing_probability=round(phishing_prob, 4),
        legitimate_probability=round(legitimate_prob, 4),
        inference_time_ms=round(elapsed_ms, 2),
        model_used="ModernBERT",
    )
# ── Batch Prediction ───────────────────────────────────────────────────────
@app.post("/predict/batch")
def predict_batch(request: BatchURLRequest):
    """Classify several URLs with one padded forward pass.

    Args:
        request: Body holding the list of URLs to score.

    Returns:
        dict with ``"results"`` (one entry per URL, in request order, with
        ``url``, ``label``, ``confidence``, ``phishing_probability`` and
        ``legitimate_probability``), ``"count"``, and ``"total_time_ms"``.
        An empty request yields an empty ``results`` list.

    Raises:
        HTTPException: 503 if the ModernBERT model has not been loaded.
    """
    if "bert" not in models:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Monotonic clock for latency measurement.
    t_start = time.perf_counter()
    urls = request.urls
    results = []

    # Skip the tokenizer/model entirely when there is nothing to score.
    if urls:
        processed = [preprocess_url(u) for u in urls]
        inputs = models["tokenizer"](
            processed,
            truncation=True,
            padding=True,  # pad to the longest item in this batch
            max_length=MAX_SEQ_LEN,
            return_tensors="pt",
        ).to(DEVICE)

        with torch.no_grad():
            outputs = models["bert"](**inputs)
        probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()

        # Column 0 = legitimate, column 1 = phishing.
        for url, row in zip(urls, probs.tolist()):
            l_prob, p_prob = row[0], row[1]
            results.append({
                "url": url,
                "label": "phishing" if p_prob >= 0.5 else "legitimate",
                "confidence": round(max(p_prob, l_prob), 4),
                "phishing_probability": round(p_prob, 4),
                "legitimate_probability": round(l_prob, 4),
            })

    elapsed_ms = (time.perf_counter() - t_start) * 1000
    return {
        "results": results,
        "count": len(results),
        "total_time_ms": round(elapsed_ms, 2),
    }
# ── Health & Info ──────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Report API readiness and which models are currently loaded.

    ``status`` is ``"ok"`` once the ModernBERT model is present in ``models``
    and ``"loading"`` otherwise. ``lgbm_ready`` is true only when a LightGBM
    booster is actually stored; the lifespan stores ``None`` when it is
    unavailable.
    """
    bert_ready = "bert" in models
    return {
        "status": "ok" if bert_ready else "loading",
        "device": DEVICE,
        "models_loaded": list(models),
        "bert_ready": bert_ready,
        "lgbm_ready": models.get("lgbm") is not None,
    }
@app.get("/")
def root():
    """Landing endpoint listing the docs, health check and prediction routes."""
    return {
        "message": "Phishing Detector API",
        "docs": "/docs",
        "health": "/health",
        "endpoints": ["/predict", "/predict/batch"],
    }