#!/usr/bin/env python3
"""
Gohan (CID) Product Recommendation FastAPI App

FastAPI version of the Gohan CID inference engine.
Intended to keep the same request/response behaviour as the Gradio version.
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
import json
import os
import time
from typing import List, Optional, Dict, Any

# Import the existing inference engine. If the module cannot be imported the
# app still starts, but the model endpoints report 503. The import error is
# kept so the startup log can say *why* the engine is unavailable.
_ENGINE_IMPORT_ERROR = None
try:
    from inference_gohan_cid import GohanCIDInferenceEngine
except ImportError as exc:
    GohanCIDInferenceEngine = None
    _ENGINE_IMPORT_ERROR = exc

# Model paths - same as Gradio version
MODEL_PATH = "model/gohan/epoch_009_p50_0.5776.pt"
ENCODERS_DIR = "model/gohan"
PRODUCT_MASTER_PATH = "model/gohan/gohan_pm.csv"

# Number of recommendations returned when no valid topK is supplied.
DEFAULT_TOP_K = 30

# App name for consistent messaging
app_name = "gohan"


# Pydantic models matching the exact API structure
class PredictionRequest(BaseModel):
    # JSON-encoded object holding the company fields (see REQUIRED_FIELDS_EN).
    company_data_json: str
    # Optional override for the number of recommendations; non-positive or
    # missing values fall back to the topK inside the JSON, then the default.
    topK: Optional[int] = None


class CategoryRecommendation(BaseModel):
    category_id: int
    category_name: str
    score: float


class PredictionResponse(BaseModel):
    status: str
    model: str
    recommendations: List[CategoryRecommendation]
    metadata: Dict[str, Any]


# Global state, populated by `lifespan` at startup and read by the handlers.
engine = None
model_files_exist = False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the inference engine at startup and log at shutdown.

    Sets the module globals ``engine`` and ``model_files_exist``. Any load
    failure is logged and leaves ``engine`` as ``None``, so the app still
    serves ``/`` while the model endpoints answer 503.
    """
    global engine, model_files_exist
    print("🚀 Gohan FastAPI is starting. Loading AI model and data...")
    start_time = time.perf_counter()

    # List every required path so a missing one is visible in the logs.
    required_paths = [
        ("MODEL_PATH", MODEL_PATH),
        ("ENCODERS_DIR", ENCODERS_DIR),
        ("PRODUCT_MASTER_PATH", PRODUCT_MASTER_PATH),
    ]
    print("🔍 Checking model files:")
    existence = []
    for label, path in required_paths:
        exists = os.path.exists(path)
        existence.append(exists)
        print(f"   - {label}: {path} (exists: {exists})")
    model_files_exist = all(existence)

    if not model_files_exist:
        print("⚠️ Model files not found. This is a template - add your model files to:")
        print(f"   - {MODEL_PATH}")
        print(f"   - {ENCODERS_DIR}/*.json")
        print(f"   - {PRODUCT_MASTER_PATH}")
        engine = None
    elif GohanCIDInferenceEngine is None:
        print(f"❌ {app_name.title()}CIDInferenceEngine not available: {_ENGINE_IMPORT_ERROR}")
        engine = None
    else:
        try:
            engine = GohanCIDInferenceEngine(
                model_path=MODEL_PATH,
                encoders_dir=ENCODERS_DIR,
                product_master_path=PRODUCT_MASTER_PATH,
            )
            print(f"✅ {app_name.title()} CID model loaded successfully!")
        except Exception as e:  # startup boundary: log and keep serving
            print(f"❌ Failed to load {app_name.title()} CID model: {e}")
            engine = None

    print(f"✅ Startup completed in {time.perf_counter() - start_time:.2f} seconds.")
    yield
    print(f"🔄 {app_name.title()} FastAPI is shutting down.")


# Initialize FastAPI app with lifespan
app = FastAPI(
    title=f"{app_name.title()} Product Recommendation API",
    description=f"FastAPI version of the {app_name.title()} recommendation system - maintains exact same functionality as Gradio version",
    version="2.0.0",
    lifespan=lifespan,
)

# Target input fields (same as Gradio version)
REQUIRED_FIELDS_EN = [
    'INDUSTRY', 'EMPLOYEE_RANGE', 'FRIDGE_RANGE', 'PAYMENT_METHOD',
    'PREFECTURE', 'FIRST_YEAR', 'FIRST_MONTH', 'LAT', 'LONG',
    'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO', 'MODE_TOP_AGE_RANGE_1',
    'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3',
]


def _model_unavailable_detail() -> str:
    """Return the 503 detail message explaining why no engine is loaded."""
    if model_files_exist:
        return "Model not loaded - check model files"
    return "Model files not found - this is a template. Add your model files to enable predictions."


def _resolve_top_k(incoming: Dict[str, Any], request_top_k: Optional[int]) -> int:
    """Return the number of recommendations to produce.

    A positive request-level ``topK`` wins. Otherwise the ``topK`` key of the
    company JSON is used, falling back to ``DEFAULT_TOP_K`` when it is absent
    or not positive.

    Raises:
        HTTPException: 400 if the JSON ``topK`` cannot be converted to int.
    """
    if request_top_k is not None and request_top_k > 0:
        return int(request_top_k)
    raw = incoming.get("topK", DEFAULT_TOP_K)
    try:
        top_k = int(raw)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: json.loads accepts "Infinity", and int(inf) overflows.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid topK value: {raw!r}",
        ) from None
    # A non-positive value would make the later slice return nothing (0) or
    # silently drop items from the end (e.g. [:-5]); treat it as missing,
    # matching how a non-positive request-level topK is handled.
    return top_k if top_k > 0 else DEFAULT_TOP_K


@app.get("/")
def root():
    """Describe the API and whether the model is loaded."""
    return {
        "message": f"🍚 {app_name.title()} Product Recommendation API (FastAPI)",
        "status": "running",
        "version": "2.0.0",
        "endpoints": ["/status", "/predict", "/predict_simple"],
        "model_status": "loaded" if engine else "not_loaded",
        "model_files_exist": model_files_exist,
    }


@app.get("/status")
def get_status():
    """Report readiness.

    Raises:
        HTTPException: 503 when the inference engine is not loaded.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail=_model_unavailable_detail())
    return {
        "status": "ready",
        "model_loaded": engine is not None,
        "model_files_exist": model_files_exist,
        "model_path": MODEL_PATH,
        "encoders_dir": ENCODERS_DIR,
        "product_master_path": PRODUCT_MASTER_PATH,
    }


@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """Predict gohan categories for a company (CID-based).

    ``request.company_data_json`` must decode to a JSON object containing
    every field in ``REQUIRED_FIELDS_EN``.

    Raises:
        HTTPException: 503 if the engine is not loaded; 400 for malformed
            JSON, a non-object payload, an invalid topK or missing fields;
            500 if the engine (or response validation) fails.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail=_model_unavailable_detail())

    # Parse input
    try:
        incoming = json.loads(request.company_data_json)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid JSON format: {str(e)}",
        ) from e
    # A list/number/string would otherwise fail later with a TypeError (500).
    if not isinstance(incoming, dict):
        raise HTTPException(
            status_code=400,
            detail="company_data_json must encode a JSON object",
        )

    # topK handling: store the normalised value so the engine sees it too.
    requested_k = _resolve_top_k(incoming, request.topK)
    incoming["topK"] = requested_k

    # Validate English field presence
    missing_en = [f for f in REQUIRED_FIELDS_EN if f not in incoming]
    if missing_en:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required fields: {missing_en}",
        )

    # Ensure TOTAL_VOLUME is present for the inference engine; DELIVERY_NUM is
    # guaranteed to exist by the required-field check above.
    incoming.setdefault('TOTAL_VOLUME', incoming['DELIVERY_NUM'])

    try:
        recommendations = engine.predict(incoming)
        if len(recommendations) > requested_k:
            recommendations = recommendations[:requested_k]
        return PredictionResponse(
            status="success",
            model="gohan",
            recommendations=recommendations,
            metadata={
                "model_version": "gohan_cid_v1.0",
                "total_categories": len(recommendations),
                "requested_k": requested_k,
            },
        )
    except HTTPException:
        raise
    except Exception as e:  # API boundary: surface engine failures as 500
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {str(e)}",
        ) from e


@app.post("/predict_simple", response_model=PredictionResponse)
def predict_simple(request: PredictionRequest):
    """Alias of ``/predict`` kept for Gradio-version compatibility.

    Accepts the same body; an optional ``topK`` is honoured exactly as in
    ``/predict``.
    """
    return predict(request)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)