Spaces:
Build error
Build error
File size: 4,497 Bytes
7a3576b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
"""
FastAPI application for phishing URL detection.
Provides a REST API endpoint to predict if a URL is phishing or legitimate.
"""
from typing import Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator

from model.model import load_model, predict_url
# Initialize FastAPI app
app = FastAPI(
    title="Phishing URL Detection API",
    description="API for detecting phishing URLs using machine learning",
    version="1.0.0",
)

# Add CORS middleware to allow web access
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    # Credentials stay disabled while origins are a wildcard. With
    # allow_credentials=True, Starlette answers a wildcard origin list by
    # echoing the caller's Origin header and allowing credentials, so any
    # website could make cookie-bearing cross-origin requests. No endpoint in
    # this module reads cookies or auth headers, so nothing relies on it.
    # Re-enable only together with an explicit allow_origins list.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Whatever load_model() returns; assigned by the startup handler and None
# until that load succeeds. /predict and /health both check it for None.
model_components = None
@app.on_event("startup")
async def startup_event():
    """Populate the module-level ``model_components`` before serving requests.

    Reports the outcome on stdout. Any exception raised by ``load_model`` is
    re-raised after being reported, so the failure propagates to the server's
    startup sequence instead of being swallowed.
    """
    # NOTE(review): newer FastAPI releases deprecate on_event in favour of a
    # lifespan handler passed to FastAPI(...); migrating needs the app
    # constructor changed at the same time.
    global model_components
    try:
        loaded_components = load_model()
        model_components = loaded_components
        print("✅ Model loaded successfully on startup")
    except Exception as exc:
        print(f"❌ Failed to load model on startup: {exc}")
        raise
# Request and Response Models
class URLRequest(BaseModel):
    """Body accepted by ``POST /predict``.

    Attributes:
        url: The URL to classify. Leading/trailing whitespace is removed by
            ``validate_url``; a whitespace-only value is rejected.
    """

    url: str = Field(..., description="The URL to check for phishing", min_length=1)

    @validator('url')
    def validate_url(cls, v):
        """Return ``v`` stripped of surrounding whitespace.

        Raises:
            ValueError: if nothing remains after stripping.
        """
        # min_length=1 alone lets "   " through, hence the explicit check.
        cleaned = v.strip()
        if cleaned:
            return cleaned
        raise ValueError('URL cannot be empty')

    class Config:
        # Example request body for the generated schema.
        # NOTE(review): pydantic v2 renamed this key to json_schema_extra.
        schema_extra = {"example": {"url": "https://www.google.com"}}
class PredictionResponse(BaseModel):
    """Body returned by ``POST /predict``.

    It is built directly from the dict produced by ``predict_url``. Per the
    field descriptions, the optional fields are None when a prediction could
    not be made, and ``error`` then carries the reason.
    """

    url: str = Field(default=..., description="The URL that was analyzed")
    predicted_label: Optional[int] = Field(
        default=None,
        description="0 for legitimate, 1 for phishing, None if error",
    )
    prediction: str = Field(
        default=...,
        description="Human-readable prediction: 'legitimate', 'phishing', 'unknown', or 'error'",
    )
    phish_probability: Optional[float] = Field(
        default=None,
        description="Probability of being phishing (0.0 to 1.0)",
    )
    confidence: Optional[float] = Field(
        default=None,
        description="Confidence percentage of the prediction",
    )
    features_extracted: bool = Field(
        default=...,
        description="Whether features were successfully extracted from the URL",
    )
    error: Optional[str] = Field(
        default=None,
        description="Error message if prediction failed",
    )

    class Config:
        # Sample successful response for the generated schema.
        schema_extra = {
            "example": {
                "url": "https://www.google.com",
                "predicted_label": 0,
                "prediction": "legitimate",
                "phish_probability": 0.0234,
                "confidence": 97.66,
                "features_extracted": True,
                "error": None,
            }
        }
# API Endpoints
@app.get("/")
async def root():
    """Describe the service and list its public endpoints."""
    endpoints = {
        "/predict": "POST - Predict if a URL is phishing or legitimate",
        "/health": "GET - Check API health status",
        "/docs": "GET - Interactive API documentation",
    }
    return {
        "message": "Phishing URL Detection API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
@app.get("/health")
async def health_check():
    """Report liveness plus whether the model has been loaded.

    ``status`` is always "healthy"; callers needing readiness should inspect
    ``model_loaded``.
    """
    loaded = model_components is not None
    return {"status": "healthy", "model_loaded": loaded}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: URLRequest):
    """
    Predict if a URL is phishing or legitimate.

    Args:
        request: URLRequest containing the URL to analyze (whitespace already
            stripped by the request model's validator).

    Returns:
        PredictionResponse built from the dict returned by ``predict_url``.

    Raises:
        HTTPException: 503 if the model is not loaded; 500 if ``predict_url``
            raises or its result does not fit ``PredictionResponse``.
    """
    # Read the global once so the None check and the call use the same object.
    components = model_components
    if components is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please try again later.",
        )
    try:
        # predict_url is a plain synchronous function. Calling it directly in
        # this coroutine would block the event loop and stall every other
        # request until it returns. run_in_threadpool runs it on FastAPI's
        # worker thread pool instead.
        # NOTE(review): assumes predict_url and the loaded model tolerate
        # concurrent calls from several threads; confirm for the model library.
        result = await run_in_threadpool(predict_url, request.url, components)
        return PredictionResponse(**result)
    except Exception as exc:
        # Chain the cause so the original traceback survives in server logs.
        # NOTE(review): the exception text is returned to the client as-is.
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {exc}",
        ) from exc
# Run the application
if __name__ == "__main__":
    import os

    # Defaults are the previously hard-coded values. The environment variables
    # let a deployment change them without editing this file.
    # NOTE(review): "app:app" assumes this module is importable as `app`.
    uvicorn.run(
        "app:app",
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "7860")),
        # Auto-reload is a development aid that runs an extra file-watching
        # process; set RELOAD=0 (or false/no) to turn it off in deployments.
        reload=os.environ.get("RELOAD", "1").strip().lower() not in {"0", "false", "no"},
    )
|