Spaces:
Runtime error
Runtime error
tabito12345678910
Fix numpy dependency: specify exact versions and install order in Dockerfile
4478f12
#!/usr/bin/env python3
"""
Gohan (CID) Product Recommendation FastAPI App

FastAPI version of the Gohan CID inference engine.
This maintains the exact same functionality as the Gradio version.
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
import json
import os
import time
from typing import List, Optional, Dict, Any

# Import the existing inference engine; fall back to None so the app can
# still start (and report 503s) when the engine module is absent.
try:
    from inference_gohan_cid import GohanCIDInferenceEngine
except ImportError:
    GohanCIDInferenceEngine = None

# Model paths - same as Gradio version
MODEL_PATH = "model/gohan/epoch_009_p50_0.5776.pt"
ENCODERS_DIR = "model/gohan"
PRODUCT_MASTER_PATH = "model/gohan/gohan_pm.csv"

# App name for consistent messaging
app_name = "gohan"
# Pydantic models matching the exact API structure
class PredictionRequest(BaseModel):
    """Request body for /predict: a JSON-encoded company record plus optional top-K."""
    # JSON string holding the company features (see REQUIRED_FIELDS_EN).
    company_data_json: str
    # How many recommendations to return; server defaults to 30 when omitted.
    topK: Optional[int] = None
class CategoryRecommendation(BaseModel):
    """One recommended product category with its model score."""
    category_id: int
    category_name: str
    score: float
class PredictionResponse(BaseModel):
    """Envelope returned by /predict: status, model name, ranked categories, metadata."""
    status: str
    model: str
    recommendations: List[CategoryRecommendation]
    metadata: Dict[str, Any]
# Global mutable state, populated by the lifespan startup handler.
engine = None  # inference engine instance once loaded, else None
model_files_exist = False  # True when all expected model artifacts are on disk
# NOTE(review): the @asynccontextmanager decorator appears to have been
# stripped by extraction — it is imported above and FastAPI's lifespan=
# argument requires an async context manager, so it is restored here.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup: verify model artifacts and load the CID engine; shutdown: log exit.

    Sets the module globals `engine` and `model_files_exist` as a side effect.
    """
    global engine, model_files_exist
    print(f"π Gohan FastAPI is starting. Loading AI model and data...")
    start_time = time.time()

    # Check if model files exist (same logic as Gradio version)
    model_files_exist = all([
        os.path.exists(MODEL_PATH),
        os.path.exists(ENCODERS_DIR),
        os.path.exists(PRODUCT_MASTER_PATH)
    ])

    if model_files_exist:
        print(f"π Checking model files:")
        print(f" - MODEL_PATH: {MODEL_PATH} (exists: {os.path.exists(MODEL_PATH)})")
        print(f" - ENCODERS_DIR: {ENCODERS_DIR} (exists: {os.path.exists(ENCODERS_DIR)})")
        print(f" - PRODUCT_MASTER_PATH: {PRODUCT_MASTER_PATH} (exists: {os.path.exists(PRODUCT_MASTER_PATH)})")
        try:
            # Engine import may have failed at module load; guard before use.
            if GohanCIDInferenceEngine:
                engine = GohanCIDInferenceEngine(
                    model_path=MODEL_PATH,
                    encoders_dir=ENCODERS_DIR,
                    product_master_path=PRODUCT_MASTER_PATH
                )
                print(f"β {app_name.title()} CID model loaded successfully!")
            else:
                print(f"β {app_name.title()}CIDInferenceEngine not available")
                engine = None
        except Exception as e:
            # Deliberate best-effort: the app still serves /status with a 503.
            print(f"β Failed to load {app_name.title()} CID model: {e}")
            engine = None
    else:
        print(f"β οΈ Model files not found. This is a template - add your model files to:")
        print(f" - {MODEL_PATH}")
        print(f" - {ENCODERS_DIR}/*.json")
        print(f" - {PRODUCT_MASTER_PATH}")
        engine = None

    print(f"β Startup completed in {time.time() - start_time:.2f} seconds.")
    yield  # application runs while suspended here
    print(f"π {app_name.title()} FastAPI is shutting down.")
# Initialize FastAPI app with lifespan
app = FastAPI(
    title=f"{app_name.title()} Product Recommendation API",
    description=f"FastAPI version of the {app_name.title()} recommendation system - maintains exact same functionality as Gradio version",
    version="2.0.0",
    lifespan=lifespan
)
# Target input fields (same as Gradio version) — every field must be present
# in the incoming company JSON or /predict responds 400.
REQUIRED_FIELDS_EN = [
    'INDUSTRY', 'EMPLOYEE_RANGE', 'FRIDGE_RANGE', 'PAYMENT_METHOD', 'PREFECTURE',
    'FIRST_YEAR', 'FIRST_MONTH', 'LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO',
    'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3'
]
# NOTE(review): route decorator reconstructed — @-prefixed lines appear to
# have been stripped by extraction; verify the path against the original.
@app.get("/")
def root():
    """Service banner: running status, model availability, and known endpoints."""
    return {
        "message": f"π {app_name.title()} Product Recommendation API (FastAPI)",
        "status": "running",
        "version": "2.0.0",
        "endpoints": ["/status", "/predict", "/predict_simple"],
        "model_status": "loaded" if engine else "not_loaded",
        "model_files_exist": model_files_exist
    }
# NOTE(review): route decorator reconstructed — @-prefixed lines appear to
# have been stripped by extraction; path taken from the endpoints list in root().
@app.get("/status")
def get_status():
    """Readiness probe: 503 until the model is loaded, otherwise config details.

    Raises:
        HTTPException(503): when the engine is not loaded, with a message
            distinguishing "files present but load failed" from "files missing".
    """
    if engine is None:
        if model_files_exist:
            raise HTTPException(
                status_code=503,
                detail="Model not loaded - check model files"
            )
        else:
            raise HTTPException(
                status_code=503,
                detail="Model files not found - this is a template. Add your model files to enable predictions."
            )
    return {
        "status": "ready",
        "model_loaded": engine is not None,
        "model_files_exist": model_files_exist,
        "model_path": MODEL_PATH,
        "encoders_dir": ENCODERS_DIR,
        "product_master_path": PRODUCT_MASTER_PATH
    }
# NOTE(review): route decorator reconstructed — @-prefixed lines appear to
# have been stripped by extraction; path taken from the endpoints list in root().
@app.post("/predict")
def predict(request: PredictionRequest):
    """
    Predict gohan categories for a company (CID-based).

    This is the EXACT same logic as the Gradio version.

    Raises:
        HTTPException(503): engine not loaded.
        HTTPException(400): malformed JSON or missing required fields.
        HTTPException(500): any unexpected failure during inference.
    """
    try:
        if engine is None:
            if model_files_exist:
                error_msg = "Model not loaded - check model files"
            else:
                error_msg = "Model files not found - this is a template. Add your model files to enable predictions."
            raise HTTPException(
                status_code=503,
                detail=error_msg
            )

        # Parse input
        try:
            incoming = json.loads(request.company_data_json)
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid JSON format: {str(e)}"
            )

        # topK handling: an explicit positive topK wins; otherwise keep any
        # topK already in the payload, defaulting to 30.
        if request.topK is not None and request.topK > 0:
            incoming["topK"] = int(request.topK)
        else:
            incoming.setdefault("topK", 30)

        # Validate English field presence
        missing_en = [f for f in REQUIRED_FIELDS_EN if f not in incoming]
        if missing_en:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required fields: {missing_en}"
            )

        # Ensure TOTAL_VOLUME is present for the inference engine
        if 'TOTAL_VOLUME' not in incoming and 'DELIVERY_NUM' in incoming:
            incoming['TOTAL_VOLUME'] = incoming['DELIVERY_NUM']

        # Predict, then truncate to the requested K.
        recommendations = engine.predict(incoming)
        requested_k = int(incoming.get("topK", 30))
        if len(recommendations) > requested_k:
            recommendations = recommendations[:requested_k]

        return PredictionResponse(
            status="success",
            model="gohan",
            recommendations=recommendations,
            metadata={
                "model_version": "gohan_cid_v1.0",
                "total_categories": len(recommendations),
                "requested_k": requested_k
            }
        )
    except HTTPException:
        # Re-raise intentional HTTP errors unchanged.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {str(e)}"
        )
# NOTE(review): route decorator reconstructed — @-prefixed lines appear to
# have been stripped by extraction; path taken from the endpoints list in root().
@app.post("/predict_simple")
def predict_simple(request: PredictionRequest):
    """Simple endpoint without topK parameter - same as Gradio version"""
    return predict(request)
if __name__ == "__main__":
    # Local/dev entry point; port 7860 matches the Hugging Face Spaces default.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)