Spaces:
Runtime error
Runtime error
| """ | |
| FastAPI application for phishing URL detection. | |
| """ | |
| from fastapi import FastAPI, HTTPException, status | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field, validator | |
| from typing import Optional | |
| import logging | |
| from model.model import predict_url, load_model, get_meta_features_and_update | |
| from categorization import categorize_phishing_result, RiskCategory, BinaryClassification | |
# Configure process-wide logging: INFO level, timestamped records that carry
# the logger name and level, then grab this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# Create the FastAPI application; interactive API docs are served at /docs
# and /redoc.
app = FastAPI(
    title="Phishing URL Detection API",
    version="1.0.0",
    description=(
        "API for detecting phishing URLs using machine learning. "
        "Analyzes URL features to classify URLs as legitimate or phishing attempts."
    ),
    docs_url="/docs",
    redoc_url="/redoc",
)
# Configure CORS so browser clients served from any origin can call the API.
# In production, replace the wildcard origin with specific origins.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive (Starlette answers credentialed requests by echoing the caller's
# Origin). Nothing visible in this module uses cookies or auth, so confirm
# whether credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Pydantic models for request/response validation
class URLRequest(BaseModel):
    """Request model for URL prediction.

    ``url`` is stripped of surrounding whitespace and rejected when empty by
    ``validate_url``.
    """

    url: str = Field(
        ...,
        description="The URL to analyze for phishing detection",
        example="http://example.com"
    )

    # Without the decorator pydantic never calls this method, so blank or
    # whitespace-only URLs would reach the endpoint unvalidated and unstripped.
    @validator("url")
    def validate_url(cls, v: str) -> str:
        """Validate that URL is not empty.

        Args:
            v: The ``url`` value supplied by the client.

        Returns:
            The URL with leading/trailing whitespace removed.

        Raises:
            ValueError: If the URL is empty or whitespace-only.
        """
        if not v or not v.strip():
            raise ValueError("URL cannot be empty")
        return v.strip()
class PredictionResponse(BaseModel):
    """Payload returned for a URL prediction.

    Every field is required except ``error``, which defaults to ``None``.
    """

    # Core prediction fields.
    url: str = Field(..., description="The analyzed URL")
    prediction: str = Field(
        ...,
        description="Prediction result: 'phishing', 'legitimate', or 'unknown'",
    )
    confidence: float = Field(..., description="Confidence score (0-1)")
    predicted_label: int = Field(
        ...,
        description="Predicted label: 0 (legitimate), 1 (phishing), -1 (unknown)",
    )
    phish_probability: float = Field(
        ...,
        description="Probability of being phishing (0-1)",
    )

    # Risk categorisation fields.
    phish_probability_percent: float = Field(
        ...,
        description="Probability of being phishing (0-100 scale)",
    )
    risk_category: str = Field(
        ...,
        description="Risk category: 'Safe', 'Low', 'Moderate', 'Dangerous', or 'Critical'",
    )
    binary_classification: str = Field(
        ...,
        description="Binary classification: 'Legitimate' or 'Phishing'",
    )

    # Optional failure detail.
    error: Optional[str] = Field(
        None,
        description="Error message if prediction failed",
    )
class HealthResponse(BaseModel):
    """Payload returned by the health-check handlers (``root`` and ``health_check``)."""

    # Short machine-readable state, e.g. "healthy".
    status: str = Field(
        ...,
        description="Service status",
    )
    # Human-readable explanation of the state.
    message: str = Field(
        ...,
        description="Status message",
    )
class UpdateRequest(BaseModel):
    """Request model for online learning update.

    ``true_label`` is restricted to 0 or 1 by ``validate_label``.
    """

    url: str = Field(..., description="The URL that was misclassified")
    true_label: int = Field(..., description="True label: 0 (legitimate) or 1 (phishing)")

    # Without the decorator pydantic never calls this method, so out-of-range
    # labels would only be caught (if at all) inside the endpoint.
    @validator("true_label")
    def validate_label(cls, v: int) -> int:
        """Validate that true_label is 0 or 1.

        Args:
            v: The ``true_label`` value supplied by the client.

        Returns:
            The unchanged label.

        Raises:
            ValueError: If the label is not 0 or 1.
        """
        if v not in [0, 1]:
            raise ValueError("true_label must be 0 (legitimate) or 1 (phishing)")
        return v
class UpdateResponse(BaseModel):
    """Payload returned after an online-learning update attempt.

    ``meta_features`` is optional and defaults to ``None``; all other fields
    are required.
    """

    # Outcome of the update attempt.
    status: str = Field(..., description="Update status")
    message: str = Field(..., description="Update message")

    # Echo of the request that drove the update.
    url: str = Field(..., description="The URL that was updated")
    true_label: int = Field(..., description="The true label used for update")

    # Feature vector fed to the meta model, when available.
    meta_features: Optional[list] = Field(
        None,
        description="Meta features used for update",
    )
# API Endpoints
# The route decorator registers the handler; without it FastAPI never
# exposes this endpoint.
@app.get("/", response_model=HealthResponse)
async def root():
    """
    Root endpoint - Health check.

    Returns:
        HealthResponse: Service status information
    """
    return HealthResponse(
        status="healthy",
        message="Phishing URL Detection API is running"
    )
# Registered explicitly; an undecorated handler is never routed.
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Always reports healthy; it does not inspect model state.

    Returns:
        HealthResponse: Service status information
    """
    return HealthResponse(
        status="healthy",
        message="Service is operational"
    )
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: URLRequest):
    """
    Predict whether a URL is phishing or legitimate.

    This endpoint:
    1. Validates the input URL
    2. Extracts features from the URL and its webpage
    3. Uses a machine learning model to classify the URL
    4. Returns the prediction with confidence score

    The ``phish_probability`` returned by ``predict_url`` is passed to
    ``categorize_phishing_result``, which adds ``phish_probability_percent``,
    ``risk_category`` and ``binary_classification`` to the response.

    Args:
        request: URLRequest containing the URL to analyze

    Returns:
        PredictionResponse: Prediction result with confidence score

    Raises:
        HTTPException: 400 for invalid input, 500 for server errors
    """
    try:
        logger.info(f"Received prediction request for URL: {request.url}")
        # Defensive re-check; URLRequest's validator normally strips and
        # rejects blank URLs before the handler runs.
        if not request.url or not request.url.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="URL cannot be empty"
            )
        # NOTE(review): predict_url runs synchronously inside an async handler;
        # if it fetches the webpage it blocks the event loop meanwhile. Consider
        # offloading to a thread pool once the model module is confirmed
        # thread-safe (update_model mutates the same model).
        result = predict_url(request.url)
        # Add risk categorization
        risk_category, binary_classification, score_100 = categorize_phishing_result(
            result['phish_probability']
        )
        result['phish_probability_percent'] = score_100
        result['risk_category'] = risk_category.value
        result['binary_classification'] = binary_classification.value
        logger.info(f"Prediction successful: {result['prediction']} | Risk: {risk_category.value} | Classification: {binary_classification.value}")
        return PredictionResponse(**result)
    except HTTPException:
        # Already carries the intended status code. Without this clause the
        # generic handler below would turn the deliberate 400 into a 500.
        raise
    except ValueError as e:
        # Handle validation errors
        logger.error(f"Validation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid input: {str(e)}"
        ) from e
    except FileNotFoundError as e:
        # Handle model file not found
        logger.error(f"Model file not found: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Model file not found. Please ensure the model is properly deployed."
        ) from e
    except Exception as e:
        # Handle all other errors; logger.exception records the traceback.
        logger.exception(f"Prediction error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An error occurred during prediction: {str(e)}"
        ) from e
@app.post("/update", response_model=UpdateResponse)
async def update_model(request: UpdateRequest):
    """
    Update the meta model using online learning with partial_fit.

    This endpoint:
    1. Extracts features from the URL
    2. Generates meta-features using base models
    3. Updates the SGD meta model with partial_fit
    4. Saves the updated model

    A falsy ``updated`` flag from ``get_meta_features_and_update`` produces a
    ``status="failed"`` response (HTTP 200), not an error.

    Args:
        request: UpdateRequest containing URL and true label

    Returns:
        UpdateResponse: Update status and meta features used

    Raises:
        HTTPException: 400 for invalid input, 500 for server errors
    """
    try:
        logger.info(f"Received update request for URL: {request.url} with label: {request.true_label}")
        # Defensive re-checks; UpdateRequest's validator normally rejects bad
        # labels before the handler runs.
        if not request.url or not request.url.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="URL cannot be empty"
            )
        if request.true_label not in [0, 1]:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="true_label must be 0 (legitimate) or 1 (phishing)"
            )
        # Get meta features and update model
        meta_features, updated = get_meta_features_and_update(request.url, request.true_label)
        if not updated:
            logger.warning(f"Failed to update model for URL: {request.url}")
            return UpdateResponse(
                status="failed",
                message="Failed to update model - feature extraction may have failed",
                url=request.url,
                true_label=request.true_label,
                meta_features=None
            )
        logger.info(f"✅ Model updated successfully for URL: {request.url}")
        return UpdateResponse(
            status="success",
            message="Meta model updated successfully with partial_fit",
            url=request.url,
            true_label=request.true_label,
            # presumably a NumPy array (.tolist() is called) — confirm in model.model
            meta_features=meta_features.tolist() if meta_features is not None else None
        )
    except HTTPException:
        # Already carries the intended status code. Without this clause the
        # generic handler below would turn the deliberate 400s into 500s.
        raise
    except ValueError as e:
        logger.error(f"Validation error in update: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid input: {str(e)}"
        ) from e
    except Exception as e:
        # logger.exception records the traceback at ERROR level.
        logger.exception(f"Update error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An error occurred during model update: {str(e)}"
        ) from e
# Startup event
# NOTE(review): on_event is the legacy hook; a lifespan handler passed to
# FastAPI(...) is the newer alternative if the installed version warns.
@app.on_event("startup")
async def startup_event():
    """
    Startup event handler.

    Loads the model on application startup to ensure it's ready. A failure is
    logged with its traceback but does not prevent the application from
    starting.
    """
    try:
        logger.info("Starting up Phishing URL Detection API...")
        # load_model is already imported at module level.
        load_model()  # Pre-load model on startup
        logger.info("✅ Model loaded successfully on startup")
    except Exception as e:
        # Deliberately best-effort: keep serving, but preserve the traceback
        # so the root cause of the failure is visible in the logs.
        logger.exception(f"❌ Failed to load model on startup: {str(e)}")
# Shutdown event
@app.on_event("shutdown")
async def shutdown_event():
    """
    Shutdown event handler.

    Only logs that the API is stopping; it releases no resources.
    """
    logger.info("Shutting down Phishing URL Detection API...")
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with auto-reload enabled.
    import uvicorn

    server_settings = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
        "log_level": "info",
    }
    # The import string assumes this module is importable as ``main`` —
    # confirm against the actual filename.
    uvicorn.run("main:app", **server_settings)