"""FastAPI REST API for Product Classification.

Exposes a small HTTP surface around ``ProductClassifier``:

* ``GET  /``               - serves the ``index.html`` web UI
* ``GET  /health``         - liveness / model information
* ``POST /classify``       - classify a single product
* ``POST /classify-batch`` - classify a list of products
* ``GET  /stats``          - static system statistics

This module uses relative imports, so it has to be loaded as part of its
package (for example ``uvicorn <package>.api:app --reload`` or
``python -m <package>.api``).
"""

import logging
import time
from collections import Counter
from typing import List, Optional

from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field
from starlette.requests import Request

from .classifier import ProductClassifier
from .config import API_TITLE, API_VERSION, API_DESCRIPTION, validate_files

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Name reported by /health and /stats. Kept in one place so the two
# endpoints cannot drift apart.
MODEL_NAME = "all-mpnet-base-v2"

# Validate files exist before starting: fail at import time rather than on
# the first request.
try:
    validate_files()
    logger.info("✅ All required model files found")
except FileNotFoundError as e:
    logger.error(f"❌ Missing files: {e}")
    raise

# Create FastAPI app
app = FastAPI(title=API_TITLE, version=API_VERSION, description=API_DESCRIPTION)

# NOTE: the template directory is resolved relative to the process working
# directory, not relative to this file.
templates = Jinja2Templates(directory="templates")

# Add CORS middleware (allows frontend to access API)
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# any site issue credentialed requests. In production, specify the actual
# origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Classifier instance shared by all requests; populated by the startup hook.
classifier: Optional[ProductClassifier] = None


def _require_classifier() -> ProductClassifier:
    """Return the loaded classifier or raise HTTP 503 if it is not ready."""
    if classifier is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Classifier not initialized",
        )
    return classifier


def _model_to_dict(model: BaseModel) -> dict:
    """Convert a pydantic model to a plain dict on both pydantic v1 and v2.

    ``model_dump()`` is the v2 spelling; ``dict()`` is deprecated there but
    is the only option on v1.
    """
    dump = getattr(model, "model_dump", None)
    if callable(dump):
        return dump()
    return model.dict()


# Pydantic models for request/response validation
class ProductInput(BaseModel):
    """Input model for single product classification"""

    id: Optional[str] = Field(default="unknown", description="Product ID")
    title: str = Field(..., description="Product title", min_length=1)
    product_type: Optional[str] = Field(default="", description="Product type/category")
    vendor: Optional[str] = Field(default="", description="Brand or vendor name")
    tags: Optional[List[str]] = Field(default=[], description="Product tags")
    description: Optional[str] = Field(default="", description="Product description")

    class Config:
        json_schema_extra = {
            "example": {
                "id": "prod_123",
                "title": "Apple iPhone 15 Pro",
                "product_type": "Smartphone",
                "vendor": "Apple Inc",
                "tags": ["electronics", "phone", "mobile"],
                "description": "Latest flagship smartphone",
            }
        }


class CategoryResult(BaseModel):
    """Result for a single category match"""

    rank: int
    category_id: str
    category_path: str
    confidence_percentage: float
    semantic_score: Optional[float] = None
    boost_applied: Optional[float] = None


class ClassificationResponse(BaseModel):
    """Response model for classification"""

    product_id: str
    action: str
    reason: str
    top_category: str
    top_confidence: float
    product_text: str
    alternatives: List[CategoryResult]
    processing_time_ms: Optional[float] = None


class BatchProductInput(BaseModel):
    """Input model for batch classification"""

    products: List[ProductInput] = Field(
        ..., description="List of products to classify"
    )
    top_k: int = Field(
        default=5, ge=1, le=20, description="Number of top matches to return"
    )


class HealthResponse(BaseModel):
    """Health check response"""

    status: str
    model: str
    categories_loaded: int
    embedding_dimension: int


# Startup event - load classifier
@app.on_event("startup")
async def startup_event():
    """Load the classifier when API starts"""
    global classifier
    logger.info("🚀 Starting API server...")
    logger.info("Loading Product Classifier...")
    try:
        classifier = ProductClassifier()
    except Exception as e:
        # Log with traceback, then abort startup: the API is useless
        # without a classifier.
        logger.exception(f"❌ Failed to load classifier: {e}")
        raise
    logger.info("✅ Classifier loaded successfully!")


# Root endpoint
@app.get("/", response_class=HTMLResponse, tags=["General"])
async def root(request: Request):
    """Serve the web UI"""
    return templates.TemplateResponse("index.html", {"request": request})


# Health check endpoint
@app.get("/health", response_model=HealthResponse, tags=["General"])
async def health_check():
    """
    Health check endpoint

    Returns system status and model information
    """
    clf = _require_classifier()
    return {
        "status": "healthy",
        "model": MODEL_NAME,
        "categories_loaded": len(clf.embeddings),
        "embedding_dimension": clf.embeddings.shape[1],
    }


# Single product classification
@app.post("/classify", response_model=ClassificationResponse, tags=["Classification"])
async def classify_product(product: ProductInput):
    """
    Classify a single product into insurance categories

    Returns:
    - action: AUTO_APPROVE, QUICK_REVIEW, or MANUAL_CATEGORIZATION
    - top_category: Best matching category
    - confidence: Confidence score (0-100%)
    - alternatives: Top alternative categories
    """
    clf = _require_classifier()

    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        started = time.perf_counter()

        # NOTE(review): classify() runs synchronously inside an async
        # handler; if it is CPU-heavy it blocks the event loop. Confirm
        # whether it should be offloaded to a worker thread.
        result = clf.classify(_model_to_dict(product))

        elapsed_ms = (time.perf_counter() - started) * 1000  # Convert to ms
        result["processing_time_ms"] = round(elapsed_ms, 2)

        logger.info(
            f"Classified product '{product.title}' → "
            f"{result['action']} ({result['top_confidence']}%)"
        )
        return result
    except Exception as e:
        logger.exception(f"Classification error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Classification failed: {str(e)}",
        ) from e


# Batch product classification
@app.post("/classify-batch", tags=["Classification"])
async def classify_batch(batch: BatchProductInput):
    """
    Classify multiple products at once

    Useful for bulk processing of product catalogs
    """
    clf = _require_classifier()

    try:
        started = time.perf_counter()

        # Convert to list of dicts
        products_data = [_model_to_dict(p) for p in batch.products]

        # Classify batch
        results = clf.classify_batch(products_data, top_k=batch.top_k)

        elapsed_ms = (time.perf_counter() - started) * 1000

        # Tally how many products landed in each action bucket; results
        # without an "action" key are counted as "UNKNOWN".
        action_counts = dict(
            Counter(item.get("action", "UNKNOWN") for item in results)
        )

        logger.info(
            f"Batch classified {len(products_data)} products in {elapsed_ms:.0f}ms"
        )
        return {
            "total_products": len(products_data),
            "processing_time_ms": round(elapsed_ms, 2),
            "action_counts": action_counts,
            "results": results,
        }
    except Exception as e:
        logger.exception(f"Batch classification error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch classification failed: {str(e)}",
        ) from e


# Get statistics
@app.get("/stats", tags=["General"])
async def get_statistics():
    """
    Get system statistics
    """
    clf = _require_classifier()
    return {
        "total_categories": len(clf.embeddings),
        "embedding_dimension": clf.embeddings.shape[1],
        "model_name": MODEL_NAME,
        "thresholds": {
            "auto_approve": "≥75%",
            "quick_review": "60-75%",
            "manual": "<60%",
        },
    }


# Error handlers
@app.exception_handler(404)
async def not_found_handler(request, exc):
    """Handle 404 errors"""
    return JSONResponse(
        status_code=404,
        content={
            "error": "Endpoint not found",
            "message": "Check /docs for available endpoints",
        },
    )


@app.exception_handler(500)
async def internal_error_handler(request, exc):
    """Handle 500 errors"""
    logger.error(f"Internal server error: {exc}")
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "message": "Something went wrong. Check logs for details.",
        },
    )


# Run with: uvicorn <package>.api:app --reload  (or: python -m <package>.api)
if __name__ == "__main__":
    import uvicorn

    # Under `python -m <package>.api`, __spec__.name is the fully qualified
    # module name, which uvicorn needs in order to re-import the app (reload
    # mode requires an import string). Fall back to the historical "api"
    # when no spec is available.
    module_path = __spec__.name if __spec__ is not None else "api"
    uvicorn.run(
        f"{module_path}:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )