Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| import time | |
| from datetime import datetime | |
| from typing import Dict, List, Optional | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import numpy as np | |
| from contextlib import asynccontextmanager | |
# The FileHandler below opens a file inside 'logs', so the directory must exist
# first; exist_ok makes this a no-op on restarts.
os.makedirs('logs', exist_ok=True)

# Root logging: records at INFO and above go both to logs/app.log and to a
# StreamHandler (which writes to sys.stderr by default).
_file_handler = logging.FileHandler('logs/app.log')
_console_handler = logging.StreamHandler()
logging.basicConfig(
    handlers=[_file_handler, _console_handler],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
# Inference state shared by the whole module. load_model() assigns these at
# startup; until then they are None / False.
model = None
tokenizer = None
# Checked before inference and reported by the status endpoints.
model_loaded = False

# Static description of the served model. "loaded_at" is filled in with an
# isoformat() timestamp once loading succeeds.
model_info = dict(
    model_name="songhieng/roberta-phishing-content-detector-5.0",
    loaded_at=None,
    version="5.0",
    framework="transformers",
)
# Request body for scoring a single text; the text must be 1 to 10000 characters.
class PredictionRequest(BaseModel):
    text: str = Field(
        default=...,  # Ellipsis: the field is required
        min_length=1,
        max_length=10000,
        description="Text content to analyze for phishing",
    )
# Response body for a single-text prediction.
# Field order determines key order when the model is serialized.
class PredictionResponse(BaseModel):
    text: str  # the submitted text as echoed back (the handler may shorten it)
    score: float  # phishing score; higher means more likely phishing
    description: str  # risk label plus the score and an explanation of the scale
    processing_time_ms: float  # handler processing time in milliseconds
    timestamp: str  # time the response was built, as a string
# Payload returned by the health-check handler.
# Field order determines key order when the model is serialized.
class HealthResponse(BaseModel):
    status: str  # overall status string, e.g. "healthy" / "unhealthy"
    model_loaded: bool  # whether the classifier is ready for inference
    timestamp: str  # time the check ran, as a string
    uptime_seconds: float  # seconds elapsed since the service's start reference
# Request body for batch scoring: a list of at most 100 texts.
# NOTE(review): unlike PredictionRequest.text, individual items have no
# min/max length constraint here — confirm that is intended.
class BatchPredictionRequest(BaseModel):
    texts: List[str] = Field(
        default=...,  # Ellipsis: the field is required
        max_items=100,
        description="List of texts to analyze",
    )
# Application startup and shutdown events
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Wrap the application's lifetime with startup and shutdown work.

    FastAPI's ``lifespan`` argument expects an async context manager, which the
    ``@asynccontextmanager`` decorator produces from this generator. On startup
    the model is loaded; an exception from ``load_model`` propagates out of
    startup. The shutdown message sits in ``finally`` so it is logged even when
    the application exits because of an exception.

    Args:
        app: The FastAPI application being run (not used by this function).
    """
    # Startup
    logger.info("Starting up the application...")
    await load_model()
    try:
        yield
    finally:
        # Shutdown
        logger.info("Shutting down the application...")
# The ASGI application; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="RoBERTa Phishing Content Detector API",
    version="5.0.0",
    description="MLOps deployment of RoBERTa model for phishing content detection",
)

# Cross-origin policy: any origin, method and header is accepted.
# NOTE(review): Starlette's CORS documentation says wildcard origins/methods/
# headers cannot be combined with credentials — confirm whether
# allow_credentials=True is actually needed for this unauthenticated API.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Wall-clock time when this module was imported; uptime is measured from here.
startup_time = time.time()
async def load_model(model_path: str = "models/roberta-phishing-detector"):
    """Load the tokenizer and classifier from a local directory into module globals.

    On success this sets ``tokenizer``, ``model`` (switched to eval mode),
    ``model_loaded = True`` and ``model_info["loaded_at"]``. The globals are only
    assigned after both artifacts load, so a failure never leaves a tokenizer
    without its model. On failure the globals are reset, the error is logged
    with its traceback, and the exception is re-raised.

    Args:
        model_path: Directory holding a saved transformers tokenizer and model.
            Only local files are read (``local_files_only=True``).

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        Exception: Anything raised by the transformers loaders, re-raised as-is.
    """
    global model, tokenizer, model_loaded
    try:
        logger.info("Loading model and tokenizer...")
        if not os.path.exists(model_path):
            logger.error(f"Model path {model_path} does not exist!")
            raise FileNotFoundError(f"Model not found at {model_path}")
        # Load into locals first so the globals change only if everything succeeds.
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
        loaded_model = AutoModelForSequenceClassification.from_pretrained(model_path, local_files_only=True)
        # Inference only: eval() turns off training-time behaviour such as dropout.
        loaded_model.eval()
    except Exception as e:
        logger.exception(f"Failed to load model: {str(e)}")
        model = None
        tokenizer = None
        model_loaded = False
        raise

    tokenizer = loaded_tokenizer
    model = loaded_model
    model_loaded = True
    model_info["loaded_at"] = datetime.now().isoformat()
    logger.info("Model and tokenizer loaded successfully!")
def predict_phishing(text: str, max_length: int = 512) -> float:
    """Return the model's phishing score for ``text``.

    A higher score (closer to 1) indicates more likely to be phishing;
    a lower score (closer to 0) indicates more likely to be legitimate.

    Args:
        text: Raw text to classify.
        max_length: Maximum number of tokens fed to the model; longer inputs are
            truncated. The effective limit is also capped by the tokenizer's
            ``model_max_length``. The default of 512 matches the position-embedding
            limit of standard RoBERTa checkpoints. The previous hard-coded 4096 let
            long inputs exceed that limit and fail instead of being truncated.
            NOTE(review): confirm against this checkpoint's config.

    Returns:
        Softmax probability of class index 1. NOTE(review): index 1 is assumed
        to be the phishing label — confirm against ``model.config.id2label``.

    Raises:
        HTTPException: 503 if the model is not loaded; 500 if tokenization or
            inference fails.
    """
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # model_max_length can be a huge sentinel when the saved tokenizer config
        # omits it; min() then keeps the caller's bound.
        token_limit = min(max_length, getattr(tokenizer, "model_max_length", max_length))
        inputs = tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=token_limit,
            return_tensors="pt",
        )
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            logits = model(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        # First (only) sequence in the batch, class index 1.
        return float(probabilities[0][1])
    except Exception as e:
        logger.exception(f"Prediction error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
async def log_requests(request: Request, call_next):
    """Log method, URL, status code and handling time for each HTTP request.

    Written as HTTP middleware: ``call_next`` forwards the request down the stack.
    NOTE(review): no middleware registration for this function is visible here —
    confirm it is registered on the app.

    Durations come from ``time.perf_counter()``, a monotonic clock unaffected by
    wall-clock adjustments. If the downstream handler raises, the request is
    still logged (with status "error" and the traceback) before the exception
    propagates.
    """
    start_time = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception:
        process_time = time.perf_counter() - start_time
        logger.exception(
            f"Request: {request.method} {request.url} - "
            f"Status: error - "
            f"Time: {process_time:.4f}s"
        )
        raise
    process_time = time.perf_counter() - start_time
    logger.info(
        f"Request: {request.method} {request.url} - "
        f"Status: {response.status_code} - "
        f"Time: {process_time:.4f}s"
    )
    return response
async def root():
    """Describe the service: name, API version, served model and coarse status.

    The version is read from ``app.version`` so it always matches the version
    the FastAPI app declares. It was previously hard-coded as "1.0.0" and
    disagreed with the app's "5.0.0".
    """
    return {
        "message": "RoBERTa Phishing Content Detector API",
        "version": app.version,
        "model": model_info["model_name"],
        "status": "healthy" if model_loaded else "unhealthy",
    }
async def health_check():
    """Report service health for monitoring: model readiness, time and uptime."""
    ready = model_loaded
    # Keyword order keeps uptime computed before the timestamp, as before.
    return HealthResponse(
        uptime_seconds=time.time() - startup_time,
        timestamp=datetime.now().isoformat(),
        model_loaded=ready,
        status="healthy" if ready else "unhealthy",
    )
async def model_info_endpoint():
    """Expose the model metadata, its load state and the installed PyTorch version."""
    return dict(
        model_info=model_info,
        model_loaded=model_loaded,
        torch_version=torch.__version__,
    )
# Score bands used to turn a phishing score into a risk label:
# (exclusive upper bound, label), checked in order. Scores >= 0.8 fall through
# to _HIGHEST_RISK_LABEL.
_RISK_BANDS = (
    (0.2, "Legitimate (Very Low Risk)"),
    (0.4, "Likely Legitimate (Low Risk)"),
    (0.6, "Uncertain (Medium Risk)"),
    (0.8, "Likely Phishing (High Risk)"),
)
_HIGHEST_RISK_LABEL = "Phishing (Very High Risk)"
# Explanation of the score scale, shared by the single and batch responses.
_SCORE_SCALE_NOTE = (
    "Lower scores (closer to 0) indicate legitimate content, "
    "higher scores (closer to 1) indicate phishing/malicious content"
)


def _classify_score(score: float) -> str:
    """Return the risk label for a phishing ``score`` according to ``_RISK_BANDS``."""
    for upper_bound, label in _RISK_BANDS:
        if score < upper_bound:
            return label
    return _HIGHEST_RISK_LABEL


def _preview(text: str, limit: int) -> str:
    """Return ``text`` unchanged, or its first ``limit`` characters plus "..." if longer."""
    if len(text) > limit:
        return text[:limit] + "..."
    return text


async def predict(request: PredictionRequest):
    """Score one text for phishing and return a labelled ``PredictionResponse``.

    Raises:
        HTTPException: Re-raised unchanged when ``predict_phishing`` raises one
            (e.g. 503 when the model is not loaded). Any other error becomes a 500.
    """
    start_time = time.perf_counter()
    try:
        phishing_score = predict_phishing(request.text)
        classification = _classify_score(phishing_score)
        description = f"{classification}: Score {phishing_score:.4f} - {_SCORE_SCALE_NOTE}"
        processing_time = (time.perf_counter() - start_time) * 1000  # milliseconds
        response = PredictionResponse(
            text=_preview(request.text, 100),
            score=phishing_score,
            description=description,
            processing_time_ms=round(processing_time, 2),
            timestamp=datetime.now().isoformat(),
        )
        logger.info(f"Prediction made: (phishing score: {phishing_score:.4f}, classification: {classification})")
        return response
    except HTTPException:
        # Keep the status chosen upstream (e.g. 503) instead of collapsing it into a 500.
        raise
    except Exception as e:
        logger.exception(f"Prediction endpoint error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


async def predict_batch(request: BatchPredictionRequest):
    """Score each text in the batch in turn and return per-text scores and labels.

    A failure on any single text fails the whole batch.

    Raises:
        HTTPException: Re-raised unchanged when ``predict_phishing`` raises one
            (e.g. 503 when the model is not loaded). Any other error becomes a 500.
    """
    start_time = time.perf_counter()
    try:
        results = []
        for text in request.texts:
            phishing_score = predict_phishing(text)
            results.append({
                "text": _preview(text, 50),
                "score": phishing_score,
                "classification": _classify_score(phishing_score),
            })
        processing_time = (time.perf_counter() - start_time) * 1000  # milliseconds
        return {
            "results": results,
            "total_processed": len(results),
            "processing_time_ms": round(processing_time, 2),
            "timestamp": datetime.now().isoformat(),
            "note": _SCORE_SCALE_NOTE,
        }
    except HTTPException:
        # Keep the status chosen upstream (e.g. 503) instead of collapsing it into a 500.
        raise
    except Exception as e:
        logger.exception(f"Batch prediction error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def metrics():
    """Return basic monitoring data: uptime, model state and metadata, and a timestamp.

    Memory usage is not measured; that field carries a placeholder string.
    """
    elapsed = time.time() - startup_time
    snapshot = dict(
        uptime_seconds=elapsed,
        model_loaded=model_loaded,
        model_info=model_info,
        memory_usage="Not implemented",  # psutil could provide real memory figures
    )
    snapshot["timestamp"] = datetime.now().isoformat()
    return snapshot
if __name__ == "__main__":
    # Pass the app object rather than the "main:app" import string. With reload
    # disabled and a single process, uvicorn accepts either form. The object form
    # does not depend on this file being named main.py, and it does not import the
    # module a second time (a script run is registered as __main__, not main,
    # which would re-run all module-level setup).
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        reload=False,
        log_level="info",
    )