Spaces:
Sleeping
Sleeping
| # """ | |
| # FastAPI REST API for Product Classification | |
| # """ | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from starlette.requests import Request | |
| from fastapi import FastAPI, HTTPException, status | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from typing import List, Optional | |
| import logging | |
| import time | |
| # from classifier import ProductClassifier | |
| # from config import API_TITLE, API_VERSION, API_DESCRIPTION, validate_files | |
| from .classifier import ProductClassifier | |
| from .config import API_TITLE, API_VERSION, API_DESCRIPTION, validate_files | |
# Configure root logging at import time and create this module's logger.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Fail fast at import time: if validate_files() reports missing model files,
# log them and re-raise so this module (and therefore the app) cannot load.
# NOTE(review): validate_files comes from .config; presumably it checks the
# model artifacts listed there — confirm.
try:
    validate_files()
except FileNotFoundError as missing:
    logger.error(f"❌ Missing files: {missing}")
    raise
else:
    logger.info("✅ All required model files found")
# FastAPI application object; title/version/description come from .config.
app = FastAPI(
    title=API_TITLE,
    version=API_VERSION,
    description=API_DESCRIPTION,
)

# Jinja2 templates used by the HTML UI route.
# NOTE(review): "templates" is a relative path, so it is resolved against the
# process working directory rather than this file's location — confirm the
# server is always launched from the directory that contains it.
templates = Jinja2Templates(directory="templates")
# Add CORS middleware so browser front-ends on other origins can call the API.
# The wildcard origin is deliberately paired with allow_credentials=False.
# With credentials enabled, Starlette answers cookie-bearing requests by
# echoing the caller's Origin back with Allow-Credentials: true, which lets
# ANY website make credentialed cross-origin calls.
# The HTML UI rendered by `root` is presumably served by this same app
# (same origin), so it does not depend on CORS.
# In production, list the real origins explicitly, and only then consider
# re-enabling credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Shared classifier instance. It stays None until startup_event() loads it.
# The health/classify/stats endpoints answer 503 while it is still None.
classifier: "Optional[ProductClassifier]" = None
| # Pydantic models for request/response validation | |
class ProductInput(BaseModel):
    """Input model for single product classification.

    Only ``title`` is required (and must be non-empty); every other field has
    a default, so sparse product records are still accepted.
    NOTE(review): the Optional[...] fields also accept an explicit JSON null,
    which then reaches the classifier as None — confirm it tolerates that.
    """

    id: Optional[str] = Field(default="unknown", description="Product ID")
    title: str = Field(..., description="Product title", min_length=1)
    product_type: Optional[str] = Field(default="", description="Product type/category")
    vendor: Optional[str] = Field(default="", description="Brand or vendor name")
    tags: Optional[List[str]] = Field(default=[], description="Product tags")
    description: Optional[str] = Field(default="", description="Product description")

    # Example payload shown in the generated OpenAPI schema / docs.
    # Pydantic v2 form: the v1-style inner ``class Config`` is deprecated in v2,
    # and ``json_schema_extra`` is a v2 config key.
    model_config = {
        "json_schema_extra": {
            "example": {
                "id": "prod_123",
                "title": "Apple iPhone 15 Pro",
                "product_type": "Smartphone",
                "vendor": "Apple Inc",
                "tags": ["electronics", "phone", "mobile"],
                "description": "Latest flagship smartphone",
            }
        }
    }
class CategoryResult(BaseModel):
    """One ranked category candidate for a classified product."""

    rank: int
    category_id: str
    category_path: str
    confidence_percentage: float
    # Optional scoring diagnostics; omitted values default to None.
    semantic_score: Optional[float] = Field(default=None)
    boost_applied: Optional[float] = Field(default=None)
class ClassificationResponse(BaseModel):
    """Shape of the result for one classified product."""

    product_id: str
    # One of the actions listed in classify_product's docstring
    # (AUTO_APPROVE, QUICK_REVIEW, MANUAL_CATEGORIZATION).
    action: str
    reason: str
    top_category: str
    top_confidence: float
    product_text: str
    alternatives: List[CategoryResult]
    # Set by the single-product endpoint after timing the classifier call.
    processing_time_ms: Optional[float] = Field(default=None)
class BatchProductInput(BaseModel):
    """Request body for classifying many products in a single call.

    NOTE(review): ``products`` has no length limit, so one request can carry
    an arbitrarily large batch — consider a max_length if this is public.
    """

    products: List[ProductInput] = Field(
        ...,
        description="List of products to classify",
    )
    # Validated to the inclusive range 1..20.
    top_k: int = Field(
        5,
        ge=1,
        le=20,
        description="Number of top matches to return",
    )
class HealthResponse(BaseModel):
    """Payload reported by the health-check endpoint."""

    status: str  # "healthy" once the classifier is loaded
    model: str  # embedding model name
    categories_loaded: int  # number of category embeddings held by the classifier
    embedding_dimension: int  # second dimension (width) of the embeddings array
| # Startup event - load classifier | |
async def startup_event():
    """Create the shared ProductClassifier once, when the API starts.

    NOTE(review): no ``@app.on_event("startup")`` (or lifespan) registration
    is visible for this function; presumably the decorator exists in the real
    file — confirm. Without it, the classifier-backed endpoints answer 503.

    Raises:
        Exception: anything raised by ProductClassifier() is logged and
            re-raised, aborting startup.
    """
    global classifier

    logger.info("🚀 Starting API server...")
    logger.info("Loading Product Classifier...")

    try:
        loaded = ProductClassifier()
    except Exception as exc:
        logger.error(f"❌ Failed to load classifier: {exc}")
        raise

    classifier = loaded
    logger.info("✅ Classifier loaded successfully!")
| # Root endpoint | |
| # @app.get("/", tags=["General"]) | |
| # async def root(): | |
| # """Root endpoint - API information""" | |
| # return { | |
| # "message": "Insurance Product Classification API", | |
| # "version": API_VERSION, | |
| # "status": "running", | |
| # "docs": "/docs", | |
| # "health": "/health", | |
| # } | |
async def root(request: Request):
    """Render ``index.html`` from the templates directory as the web UI.

    NOTE(review): no route decorator is visible here; the commented-out
    variant used ``@app.get("/")`` — confirm how this handler is registered.
    Newer Starlette releases also prefer the request-first
    ``TemplateResponse(request, name)`` signature.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
| # Health check endpoint | |
async def health_check():
    """
    Confirm the classifier is loaded and report basic model information.

    Returns:
        dict with ``status``, ``model``, ``categories_loaded`` (length of
        the classifier's embeddings) and ``embedding_dimension`` (its
        second dimension).

    Raises:
        HTTPException: 503 while the classifier has not been initialized.
    """
    if classifier is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Classifier not initialized",
        )

    embeddings = classifier.embeddings
    return {
        "status": "healthy",
        "model": "all-mpnet-base-v2",
        "categories_loaded": len(embeddings),
        "embedding_dimension": embeddings.shape[1],
    }
| # Single product classification | |
async def classify_product(product: ProductInput):
    """
    Classify a single product into insurance categories.

    Returns:
        The classifier's result dict with ``processing_time_ms`` added. It
        contains at least ``action`` and ``top_confidence``; presumably also:
        - action: AUTO_APPROVE, QUICK_REVIEW, or MANUAL_CATEGORIZATION
        - top_category: Best matching category
        - top_confidence: Confidence score (0-100%)
        - alternatives: Top alternative categories

    Raises:
        HTTPException: 503 if the classifier is not loaded yet; 500 wrapping
            any error raised while classifying.
    """
    if classifier is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Classifier not initialized",
        )
    try:
        # perf_counter() is monotonic and high-resolution; time.time() can
        # jump if the wall clock is adjusted, skewing the measured duration.
        start = time.perf_counter()
        result = classifier.classify(product.model_dump())
        elapsed_ms = (time.perf_counter() - start) * 1000
        result["processing_time_ms"] = round(elapsed_ms, 2)

        logger.info(
            f"Classified product '{product.title}' → "
            f"{result['action']} ({result['top_confidence']}%)"
        )
        return result
    except Exception as exc:
        # logger.exception records the traceback, which would otherwise be lost
        # once the error is converted into an HTTPException below.
        logger.exception(f"Classification error: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Classification failed: {str(exc)}",
        ) from exc
| # Batch product classification | |
async def classify_batch(batch: BatchProductInput):
    """
    Classify multiple products at once.

    Useful for bulk processing of product catalogs.

    Returns:
        dict with ``total_products``, ``processing_time_ms`` (whole batch),
        ``action_counts`` (results per action, in first-seen order; results
        without an ``action`` key count as ``"UNKNOWN"``) and ``results``.

    Raises:
        HTTPException: 503 if the classifier is not loaded yet; 500 wrapping
            any error raised while classifying.
    """
    from collections import Counter

    if classifier is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Classifier not initialized",
        )
    try:
        # Monotonic clock: immune to wall-clock adjustments during the batch.
        start = time.perf_counter()
        products_data = [p.model_dump() for p in batch.products]
        results = classifier.classify_batch(products_data, top_k=batch.top_k)
        elapsed_ms = (time.perf_counter() - start) * 1000

        # Plain dict for JSON output; Counter keeps first-seen key order.
        action_counts = dict(
            Counter(result.get("action", "UNKNOWN") for result in results)
        )

        logger.info(
            f"Batch classified {len(products_data)} products in {elapsed_ms:.0f}ms"
        )
        return {
            "total_products": len(products_data),
            "processing_time_ms": round(elapsed_ms, 2),
            "action_counts": action_counts,
            "results": results,
        }
    except Exception as exc:
        # Keep the traceback in the log; the client only receives the message.
        logger.exception(f"Batch classification error: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch classification failed: {str(exc)}",
        ) from exc
| # Get statistics | |
async def get_statistics():
    """
    Get system statistics.

    Reports the category count, embedding width, model name and the decision
    thresholds (as display strings).

    NOTE(review): the model name and threshold strings are hard-coded here,
    not read from the classifier — keep them in sync with its configuration.

    Raises:
        HTTPException: 503 while the classifier has not been initialized.
    """
    if classifier is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Classifier not initialized",
        )

    embeddings = classifier.embeddings
    thresholds = {
        "auto_approve": "≥75%",
        "quick_review": "60-75%",
        "manual": "<60%",
    }
    return {
        "total_categories": len(embeddings),
        "embedding_dimension": embeddings.shape[1],
        "model_name": "all-mpnet-base-v2",
        "thresholds": thresholds,
    }
| # Error handlers | |
| from fastapi.responses import JSONResponse | |
async def not_found_handler(request, exc):
    """Return a JSON 404 body that points callers to the API docs.

    NOTE(review): no ``@app.exception_handler(404)`` registration is visible
    for this function — confirm it is wired up.
    """
    body = {
        "error": "Endpoint not found",
        "message": "Check /docs for available endpoints",
    }
    return JSONResponse(status_code=404, content=body)
async def internal_error_handler(request, exc):
    """Log an unhandled error and return a generic JSON 500 body.

    The client only receives a generic message; the details go to the log.
    NOTE(review): no ``@app.exception_handler(...)`` registration is visible
    for this function — confirm it is wired up.
    """
    # Passing the exception explicitly attaches its traceback to the record,
    # whether or not the handler is invoked inside an active ``except`` block.
    logger.error(f"Internal server error: {exc}", exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "message": "Something went wrong. Check logs for details.",
        },
    )
# Run with: uvicorn <package>.api:app --reload   (or: python -m <package>.api)
if __name__ == "__main__":
    import uvicorn

    # This module uses package-relative imports (``from .classifier ...``), so
    # it only imports as part of its package; a bare "api:app" target fails.
    # Under ``python -m <package>.api``, __spec__.name holds the real dotted
    # module name, which uvicorn (and its reloader) can re-import. When no
    # spec is available, fall back to the previous target.
    app_target = f"{__spec__.name}:app" if __spec__ is not None else "api:app"
    uvicorn.run(app_target, host="0.0.0.0", port=8000, reload=True, log_level="info")