Rasel Santillan
Add application file
7a3576b
"""
FastAPI application for phishing URL detection.
Provides a REST API endpoint to predict if a URL is phishing or legitimate.
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator
from typing import Optional
import uvicorn
from model.model import load_model, predict_url
# Application instance. The title, description and version given here are
# the metadata FastAPI is constructed with.
app = FastAPI(
    title="Phishing URL Detection API",
    description="API for detecting phishing URLs using machine learning",
    version="1.0.0",
)
# Add CORS middleware to allow web access.
# Every origin, method and header is accepted. Credentialed requests are
# disabled because wildcard origins must not be combined with
# allow_credentials=True; enable credentials only together with an explicit
# list of trusted origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Loaded model components, shared by the request handlers.
# Stays None until the startup hook assigns the result of load_model().
model_components = None
@app.on_event("startup")
async def startup_event():
    """Populate the module-level model when the application starts.

    Stores the value returned by ``load_model()`` in ``model_components``.
    If anything raises, the failure is printed and the same exception is
    re-raised unchanged.
    """
    global model_components
    try:
        model_components = load_model()
        print("✅ Model loaded successfully on startup")
    except Exception as load_error:
        # Report the failure, then let the original exception propagate.
        print(f"❌ Failed to load model on startup: {load_error}")
        raise
# Request and Response Models
class URLRequest(BaseModel):
    """Request body carrying the URL to classify."""

    url: str = Field(
        ...,
        min_length=1,
        description="The URL to check for phishing",
    )

    @validator('url')
    def validate_url(cls, v):
        """Trim surrounding whitespace and reject a URL that is then empty."""
        trimmed = v.strip()
        if trimmed:
            return trimmed
        raise ValueError('URL cannot be empty')

    class Config:
        # Sample payload attached to the model's schema.
        schema_extra = {"example": {"url": "https://www.google.com"}}
class PredictionResponse(BaseModel):
    """Result of analysing a single URL.

    ``url``, ``prediction`` and ``features_extracted`` are required; every
    other field defaults to ``None``.
    """

    url: str = Field(
        ...,
        description="The URL that was analyzed",
    )
    predicted_label: Optional[int] = Field(
        default=None,
        description="0 for legitimate, 1 for phishing, None if error",
    )
    prediction: str = Field(
        ...,
        description="Human-readable prediction: 'legitimate', 'phishing', 'unknown', or 'error'",
    )
    phish_probability: Optional[float] = Field(
        default=None,
        description="Probability of being phishing (0.0 to 1.0)",
    )
    confidence: Optional[float] = Field(
        default=None,
        description="Confidence percentage of the prediction",
    )
    features_extracted: bool = Field(
        ...,
        description="Whether features were successfully extracted from the URL",
    )
    error: Optional[str] = Field(
        default=None,
        description="Error message if prediction failed",
    )

    class Config:
        # Sample response attached to the model's schema.
        schema_extra = {
            "example": {
                "url": "https://www.google.com",
                "predicted_label": 0,
                "prediction": "legitimate",
                "phish_probability": 0.0234,
                "confidence": 97.66,
                "features_extracted": True,
                "error": None,
            },
        }
# API Endpoints
@app.get("/")
async def root():
    """Return a short description of the API and its endpoints."""
    # Path -> one-line summary of each route advertised to clients.
    endpoint_summaries = {
        "/predict": "POST - Predict if a URL is phishing or legitimate",
        "/health": "GET - Check API health status",
        "/docs": "GET - Interactive API documentation",
    }
    return {
        "message": "Phishing URL Detection API",
        "version": "1.0.0",
        "endpoints": endpoint_summaries,
    }
@app.get("/health")
async def health_check():
    """Report service status and whether the model has been loaded.

    ``status`` is always ``"healthy"``; ``model_loaded`` is true once
    ``model_components`` is no longer ``None``.
    """
    is_model_ready = model_components is not None
    return {"status": "healthy", "model_loaded": is_model_ready}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: URLRequest):
    """
    Predict if a URL is phishing or legitimate.

    Args:
        request: URLRequest containing the URL to analyze

    Returns:
        PredictionResponse built from the mapping returned by ``predict_url``

    Raises:
        HTTPException: 503 if the model has not been loaded yet; 500 if
            ``predict_url`` or building the response raises. The 500 error
            is chained to the original exception so the root cause stays
            visible in tracebacks.
    """
    if model_components is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please try again later."
        )
    try:
        # NOTE(review): predict_url is called synchronously inside this
        # coroutine; if it is slow it will hold up the event loop. Confirm,
        # and consider offloading it to a worker thread.
        result = predict_url(request.url, model_components)
        return PredictionResponse(**result)
    except Exception as e:
        # NOTE(review): the exception text is sent to the client as-is;
        # confirm that this cannot expose internal details.
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        ) from e
# Run the application
if __name__ == "__main__":
    # Script entry point: start uvicorn on all interfaces, port 7860, with
    # auto-reload switched on.
    # NOTE(review): "app:app" presumably refers to this module being named
    # app.py - confirm against the actual filename.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "reload": True,
    }
    uvicorn.run("app:app", **server_options)