Rasel Santillan
Squashed clean history
8a9ac80
"""
FastAPI application for phishing URL detection.
"""
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator
from typing import Optional
import logging
from model.model import predict_url, load_model, get_meta_features_and_update
from categorization import categorize_phishing_result, RiskCategory, BinaryClassification
# Configure process-wide logging once, at import time, and create this
# module's logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Create the FastAPI application. Interactive API documentation is served at
# /docs (Swagger UI) and /redoc (ReDoc).
app = FastAPI(
    title="Phishing URL Detection API",
    version="1.0.0",
    description=(
        "API for detecting phishing URLs using machine learning. "
        "Analyzes URL features to classify URLs as legitimate or phishing attempts."
    ),
    docs_url="/docs",
    redoc_url="/redoc",
)
# Register CORS middleware so that browser-based clients on other origins can
# call the API. Every origin, method and header is currently accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# unsafe/ill-defined under the CORS spec; in production replace "*" with an
# explicit list of trusted origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Pydantic models for request/response validation
class URLRequest(BaseModel):
    """Request body for ``/predict``: the single URL to classify."""

    url: str = Field(
        ...,
        description="The URL to analyze for phishing detection",
        example="http://example.com"
    )

    @validator('url')
    def validate_url(cls, v):
        """Reject blank URLs; otherwise return the URL with surrounding whitespace removed."""
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("URL cannot be empty")
class PredictionResponse(BaseModel):
    """
    Response model for URL prediction.

    Returned by ``/predict``. The endpoint builds it from the dict returned by
    ``predict_url`` after adding ``phish_probability_percent``,
    ``risk_category`` and ``binary_classification`` derived from
    ``categorize_phishing_result``.
    """
    url: str = Field(..., description="The analyzed URL")
    prediction: str = Field(..., description="Prediction result: 'phishing', 'legitimate', or 'unknown'")
    confidence: float = Field(..., description="Confidence score (0-1)")
    predicted_label: int = Field(..., description="Predicted label: 0 (legitimate), 1 (phishing), -1 (unknown)")
    # Input to categorize_phishing_result in the /predict endpoint.
    phish_probability: float = Field(..., description="Probability of being phishing (0-1)")
    # Set from the third value returned by categorize_phishing_result.
    phish_probability_percent: float = Field(..., description="Probability of being phishing (0-100 scale)")
    # The ``.value`` of the risk category returned by categorize_phishing_result.
    risk_category: str = Field(..., description="Risk category: 'Safe', 'Low', 'Moderate', 'Dangerous', or 'Critical'")
    # The ``.value`` of the binary classification returned by categorize_phishing_result.
    binary_classification: str = Field(..., description="Binary classification: 'Legitimate' or 'Phishing'")
    # Optional; presumably populated by predict_url on failure — verify against model code.
    error: Optional[str] = Field(None, description="Error message if prediction failed")
class HealthResponse(BaseModel):
    """
    Response model for health check.

    Returned by the ``/`` and ``/health`` endpoints.
    """
    # Both health endpoints in this module set this to "healthy".
    status: str = Field(..., description="Service status")
    # Human-readable text; differs between the "/" and "/health" endpoints.
    message: str = Field(..., description="Status message")
class UpdateRequest(BaseModel):
    """Request body for ``/update``: a URL together with its correct label for online learning."""

    url: str = Field(..., description="The URL that was misclassified")
    true_label: int = Field(..., description="True label: 0 (legitimate) or 1 (phishing)")

    @validator('true_label')
    def validate_label(cls, v):
        """Accept only 0 (legitimate) or 1 (phishing); any other value is rejected."""
        allowed_labels = (0, 1)
        if v in allowed_labels:
            return v
        raise ValueError("true_label must be 0 (legitimate) or 1 (phishing)")
class UpdateResponse(BaseModel):
    """
    Response model for online learning update.

    Returned by ``/update`` both when the update succeeds and when
    ``get_meta_features_and_update`` reports that no update happened.
    """
    # "success" or "failed" as set by the /update endpoint.
    status: str = Field(..., description="Update status")
    message: str = Field(..., description="Update message")
    # Echoes the URL from the request as received (not stripped).
    url: str = Field(..., description="The URL that was updated")
    true_label: int = Field(..., description="The true label used for update")
    # ``meta_features.tolist()`` on success; None when the update failed.
    meta_features: Optional[list] = Field(None, description="Meta features used for update")
# API Endpoints
@app.get("/", response_model=HealthResponse, tags=["Health"])
async def root():
    """
    Health check served at the API root.

    Returns:
        HealthResponse: always reports status "healthy" with a message saying
        the API is running.
    """
    payload = {
        "status": "healthy",
        "message": "Phishing URL Detection API is running",
    }
    return HealthResponse(**payload)
@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """
    Dedicated health-check endpoint.

    Returns:
        HealthResponse: always reports status "healthy" and that the service
        is operational.
    """
    response = HealthResponse(status="healthy", message="Service is operational")
    return response
@app.post("/predict", response_model=PredictionResponse, tags=["Prediction"])
async def predict(request: URLRequest):
    """
    Predict whether a URL is phishing or legitimate.

    This endpoint:
    1. Validates the input URL
    2. Extracts features from the URL and its webpage
    3. Uses a machine learning model to classify the URL
    4. Adds a 0-100 score, a risk category and a binary classification
       derived from ``phish_probability`` via ``categorize_phishing_result``
    5. Returns the prediction with confidence score

    Args:
        request: URLRequest containing the URL to analyze

    Returns:
        PredictionResponse: Prediction result with confidence score and risk
        categorization

    Raises:
        HTTPException: 400 for invalid input, 500 for server errors
    """
    logger.info("Received prediction request for URL: %s", request.url)

    # URLRequest's validator already strips and rejects empty URLs, so this is
    # a defensive re-check. It sits outside the try block so the 400 is not
    # caught by the generic ``except Exception`` below and re-reported as a 500.
    if not request.url:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="URL cannot be empty"
        )

    try:
        # NOTE(review): predict_url is a synchronous call inside an async
        # endpoint, so it blocks the event loop while it runs. Offloading it to
        # a thread pool would require the model code to be thread-safe (the
        # model appears to be mutated by /update) — confirm before changing.
        result = predict_url(request.url)

        # Add risk categorization
        risk_category, binary_classification, score_100 = categorize_phishing_result(
            result['phish_probability']
        )
        result['phish_probability_percent'] = score_100
        result['risk_category'] = risk_category.value
        result['binary_classification'] = binary_classification.value

        logger.info(
            "Prediction successful: %s | Risk: %s | Classification: %s",
            result['prediction'], risk_category.value, binary_classification.value,
        )
        return PredictionResponse(**result)
    except HTTPException:
        # Already carries the intended status code; don't convert it to a 500.
        raise
    except ValueError as e:
        # Handle validation errors
        logger.error("Validation error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid input: {str(e)}"
        ) from e
    except FileNotFoundError as e:
        # Handle model file not found
        logger.error("Model file not found: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Model file not found. Please ensure the model is properly deployed."
        ) from e
    except Exception as e:
        # Handle all other errors.
        # NOTE(review): str(e) is echoed to the client, which may leak internals.
        logger.error("Prediction error: %s", e, exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An error occurred during prediction: {str(e)}"
        ) from e
@app.post("/update", response_model=UpdateResponse, tags=["Update"])
async def update_model(request: UpdateRequest):
    """
    Update the meta model using online learning with partial_fit.

    This endpoint:
    1. Extracts features from the URL
    2. Generates meta-features using base models
    3. Updates the SGD meta model with partial_fit
    4. Saves the updated model

    Args:
        request: UpdateRequest containing URL and true label

    Returns:
        UpdateResponse: Update status and meta features used. ``status`` is
        "failed" (with ``meta_features=None``) when
        ``get_meta_features_and_update`` reports that no update happened.

    Raises:
        HTTPException: 400 for invalid input, 500 for server errors
    """
    logger.info(
        "Received update request for URL: %s with label: %s",
        request.url, request.true_label,
    )

    # Validate inputs before entering the try block. An HTTPException raised
    # inside it would be caught by ``except Exception`` and returned as a 500.
    # UpdateRequest does not validate the URL, so a blank URL does reach here.
    if not request.url or not request.url.strip():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="URL cannot be empty"
        )
    # Also enforced by UpdateRequest.validate_label; kept as a defensive check.
    if request.true_label not in [0, 1]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="true_label must be 0 (legitimate) or 1 (phishing)"
        )

    try:
        # NOTE(review): synchronous (feature extraction + partial_fit + save)
        # inside an async endpoint; this blocks the event loop while it runs.
        meta_features, updated = get_meta_features_and_update(request.url, request.true_label)

        if not updated:
            logger.warning("Failed to update model for URL: %s", request.url)
            return UpdateResponse(
                status="failed",
                message="Failed to update model - feature extraction may have failed",
                url=request.url,
                true_label=request.true_label,
                meta_features=None
            )

        logger.info("✅ Model updated successfully for URL: %s", request.url)
        return UpdateResponse(
            status="success",
            message="Meta model updated successfully with partial_fit",
            url=request.url,
            true_label=request.true_label,
            meta_features=meta_features.tolist() if meta_features is not None else None
        )
    except ValueError as e:
        logger.error("Validation error in update: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid input: {str(e)}"
        ) from e
    except Exception as e:
        # NOTE(review): str(e) is echoed to the client, which may leak internals.
        logger.error("Update error: %s", e, exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An error occurred during model update: {str(e)}"
        ) from e
# Startup event
@app.on_event("startup")
async def startup_event():
    """
    Startup event handler.

    Loads the model on application startup so it is ready before requests
    arrive. A failure is logged with its traceback but does not prevent the
    application from starting.
    """
    try:
        logger.info("Starting up Phishing URL Detection API...")
        # load_model is already imported at module level; no local re-import needed.
        load_model()  # Pre-load model on startup
        logger.info("✅ Model loaded successfully on startup")
    except Exception as e:
        # Don't prevent startup, but log the error together with its traceback
        # so the cause of the load failure can be diagnosed.
        logger.error(f"❌ Failed to load model on startup: {str(e)}", exc_info=True)
# Shutdown event
@app.on_event("shutdown")
async def shutdown_event():
    """Log that the API is shutting down; this handler releases no resources."""
    farewell = "Shutting down Phishing URL Detection API..."
    logger.info(farewell)
if __name__ == "__main__":
    # Development entry point: serve this module's `app` with auto-reload.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("main:app", **server_options)