# gohan-api-light / app.py
# tabito12345678910
# Fix numpy dependency: specify exact versions and install order in Dockerfile
# 4478f12
#!/usr/bin/env python3
"""
Gohan (CID) Product Recommendation FastAPI App
FastAPI version of the Gohan CID inference engine
This maintains the exact same functionality as the Gradio version
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
import json
import os
import time
from typing import List, Optional, Dict, Any
# Import the existing inference engine. The import is optional so the API can
# still start (and answer 503) when the module -- or one of ITS dependencies,
# e.g. numpy -- cannot be imported. The error is kept and reported instead of
# being silently discarded, so a broken environment is diagnosable.
ENGINE_IMPORT_ERROR: Optional[ImportError] = None
try:
    from inference_gohan_cid import GohanCIDInferenceEngine
except ImportError as _engine_import_exc:
    GohanCIDInferenceEngine = None
    # ``_engine_import_exc`` is unbound when the except block ends, so keep
    # a separate module-level reference for later inspection.
    ENGINE_IMPORT_ERROR = _engine_import_exc
    print(f"⚠️ Could not import GohanCIDInferenceEngine: {_engine_import_exc!r}")

# Model paths - same as Gradio version. They are relative paths, so they are
# resolved against the process working directory.
MODEL_PATH = "model/gohan/epoch_009_p50_0.5776.pt"
ENCODERS_DIR = "model/gohan"
PRODUCT_MASTER_PATH = "model/gohan/gohan_pm.csv"

# App name for consistent messaging
app_name = "gohan"
# Pydantic models matching the exact API structure
# Request body accepted by /predict and /predict_simple.
class PredictionRequest(BaseModel):
    # JSON-encoded text holding the company's input fields; /predict decodes
    # it with json.loads and checks it against REQUIRED_FIELDS_EN.
    company_data_json: str
    # Optional number of recommendations; /predict uses it only when > 0.
    topK: Optional[int] = None
# A single recommended category inside PredictionResponse.recommendations.
class CategoryRecommendation(BaseModel):
    category_id: int
    category_name: str
    # Score attached by the inference engine (scale/direction not visible
    # here -- TODO confirm against inference_gohan_cid).
    score: float
# Envelope returned by /predict and /predict_simple.
class PredictionResponse(BaseModel):
    status: str  # "success" on the happy path of /predict
    model: str  # model identifier; /predict sends "gohan"
    recommendations: List[CategoryRecommendation]
    # Free-form extras; /predict fills model_version, total_categories and
    # requested_k.
    metadata: Dict[str, Any]
# Module-level state written by ``lifespan`` at startup and read by the
# endpoints: the loaded inference engine (None until/unless loading
# succeeds) and whether every model file was found on disk.
engine, model_files_exist = None, False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the Gohan CID model at startup and announce shutdown.

    FastAPI enters this context manager before serving requests and leaves
    it on shutdown. On entry it sets the module-level ``model_files_exist``
    and ``engine`` globals that the endpoints read. Any failure to load the
    model is printed and leaves ``engine`` as ``None``, so the server still
    starts and the endpoints answer 503 instead of the process crashing.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).
    """
    global engine, model_files_exist
    print("🚀 Gohan FastAPI is starting. Loading AI model and data...")
    # perf_counter is monotonic, so the reported duration cannot go negative
    # or jump if the wall clock is adjusted during startup.
    start_time = time.perf_counter()

    # Check every path exactly once and always report the result: the point
    # of this diagnostic is to show WHICH file is missing, so it must also
    # be printed when the check fails.
    required_paths = {
        "MODEL_PATH": MODEL_PATH,
        "ENCODERS_DIR": ENCODERS_DIR,
        "PRODUCT_MASTER_PATH": PRODUCT_MASTER_PATH,
    }
    print("🔍 Checking model files:")
    path_exists = {}
    for label, path in required_paths.items():
        path_exists[label] = os.path.exists(path)
        print(f" - {label}: {path} (exists: {path_exists[label]})")
    model_files_exist = all(path_exists.values())

    if not model_files_exist:
        print("⚠️ Model files not found. This is a template - add your model files to:")
        print(f" - {MODEL_PATH}")
        print(f" - {ENCODERS_DIR}/*.json")
        print(f" - {PRODUCT_MASTER_PATH}")
        engine = None
    elif GohanCIDInferenceEngine is None:
        # The optional import at module level failed.
        print(f"❌ {app_name.title()}CIDInferenceEngine not available")
        engine = None
    else:
        try:
            engine = GohanCIDInferenceEngine(
                model_path=MODEL_PATH,
                encoders_dir=ENCODERS_DIR,
                product_master_path=PRODUCT_MASTER_PATH,
            )
        except Exception as e:
            # Startup boundary: report and keep serving (endpoints return 503).
            print(f"❌ Failed to load {app_name.title()} CID model: {e}")
            engine = None
        else:
            print(f"✅ {app_name.title()} CID model loaded successfully!")

    print(f"✅ Startup completed in {time.perf_counter() - start_time:.2f} seconds.")
    try:
        yield
    finally:
        # Runs on normal shutdown and also if an exception is thrown into
        # the context manager.
        print(f"🔄 {app_name.title()} FastAPI is shutting down.")
# The ASGI application; ``lifespan`` performs model loading on startup.
app = FastAPI(
    lifespan=lifespan,
    version="2.0.0",
    title=app_name.title() + " Product Recommendation API",
    description=(
        "FastAPI version of the " + app_name.title() + " recommendation system"
        " - maintains exact same functionality as Gradio version"
    ),
)
# Keys that every decoded company_data_json payload must contain before it is
# handed to the inference engine (same set as the Gradio version).
REQUIRED_FIELDS_EN = (
    "INDUSTRY EMPLOYEE_RANGE FRIDGE_RANGE PAYMENT_METHOD PREFECTURE "
    "FIRST_YEAR FIRST_MONTH LAT LONG DELIVERY_NUM MEDIAN_GENDER_RATIO "
    "MODE_TOP_AGE_RANGE_1 MODE_TOP_AGE_RANGE_2 MODE_TOP_AGE_RANGE_3"
).split()
@app.get("/")
def root():
    """Describe the API and report the current model state.

    Always answers 200, even when the model is not loaded, so it can serve
    as a liveness probe; use /status for readiness.

    Returns:
        A dict with a greeting, service version, the available endpoints,
        whether the model engine is loaded, and whether the model files exist.
    """
    return {
        "message": f"🍚 {app_name.title()} Product Recommendation API (FastAPI)",
        "status": "running",
        "version": "2.0.0",
        "endpoints": ["/status", "/predict", "/predict_simple"],
        # Explicit None check, matching /status and /predict, so an engine
        # object that happens to evaluate as falsy is still reported as loaded.
        "model_status": "loaded" if engine is not None else "not_loaded",
        "model_files_exist": model_files_exist,
    }
@app.get("/status")
def get_status():
    """Readiness check for the recommendation model.

    Returns:
        The readiness payload with the configured model paths when an engine
        is loaded.

    Raises:
        HTTPException: 503 when no engine is loaded. The detail distinguishes
            "files present but loading failed" from "files missing".
    """
    if engine is not None:
        return {
            "status": "ready",
            "model_loaded": engine is not None,
            "model_files_exist": model_files_exist,
            "model_path": MODEL_PATH,
            "encoders_dir": ENCODERS_DIR,
            "product_master_path": PRODUCT_MASTER_PATH,
        }
    detail = (
        "Model not loaded - check model files"
        if model_files_exist
        else "Model files not found - this is a template. Add your model files to enable predictions."
    )
    raise HTTPException(status_code=503, detail=detail)
# Number of recommendations used when the caller supplies no usable topK.
_DEFAULT_TOP_K = 30


def _parse_company_data(raw: str) -> Dict[str, Any]:
    """Decode ``company_data_json`` and require the result to be a JSON object.

    Args:
        raw: JSON text sent by the client.

    Returns:
        The decoded object as a dict.

    Raises:
        HTTPException: 400 if ``raw`` is not valid JSON, or is valid JSON but
            not an object (array, string, number, ...).
    """
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid JSON format: {str(e)}"
        ) from e
    if not isinstance(data, dict):
        # json.loads happily returns lists/strings/numbers. Previously these
        # crashed on item assignment below and surfaced as a misleading 500.
        raise HTTPException(
            status_code=400,
            detail="Invalid JSON format: company_data_json must be a JSON object"
        )
    return data


def _resolve_top_k(incoming: Dict[str, Any], request_top_k: Optional[int]) -> int:
    """Return the (always positive) number of recommendations to produce.

    A positive ``request_top_k`` takes precedence. Otherwise ``incoming["topK"]``
    is used when it converts to a positive int, else ``_DEFAULT_TOP_K``.

    Raises:
        HTTPException: 400 if ``incoming["topK"]`` cannot be converted to int.
    """
    if request_top_k is not None and request_top_k > 0:
        return int(request_top_k)
    raw = incoming.get("topK", _DEFAULT_TOP_K)
    try:
        top_k = int(raw)
    except (TypeError, ValueError, OverflowError):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid topK value: {raw!r}"
        ) from None
    # A non-positive value used to reach the slice below as e.g. ``[:-5]``
    # and silently drop results. Treat it like a non-positive request-level
    # topK and fall back to the default.
    return top_k if top_k > 0 else _DEFAULT_TOP_K


@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """
    Predict gohan categories for a company (CID-based).

    Decodes ``company_data_json``, resolves ``topK`` (a positive request
    field wins, then a positive ``topK`` inside the JSON, else 30), checks
    that all REQUIRED_FIELDS_EN are present, and returns the engine's top-K
    recommendations.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for malformed JSON,
            a non-object payload, an invalid topK, or missing fields; 500 if
            the engine or response construction fails.
    """
    if engine is None:
        if model_files_exist:
            error_msg = "Model not loaded - check model files"
        else:
            error_msg = "Model files not found - this is a template. Add your model files to enable predictions."
        raise HTTPException(status_code=503, detail=error_msg)

    # Parse and validate input (client errors -> 400).
    incoming = _parse_company_data(request.company_data_json)
    requested_k = _resolve_top_k(incoming, request.topK)
    # Forward the resolved size in the payload given to engine.predict, as
    # the original flow did. Results are also truncated below regardless.
    incoming["topK"] = requested_k

    missing_en = [f for f in REQUIRED_FIELDS_EN if f not in incoming]
    if missing_en:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required fields: {missing_en}"
        )

    # Ensure TOTAL_VOLUME is present for the inference engine
    if 'TOTAL_VOLUME' not in incoming and 'DELIVERY_NUM' in incoming:
        incoming['TOTAL_VOLUME'] = incoming['DELIVERY_NUM']

    # Only the engine call and response building are treated as server errors.
    try:
        recommendations = engine.predict(incoming)
        if len(recommendations) > requested_k:
            recommendations = recommendations[:requested_k]
        return PredictionResponse(
            status="success",
            model="gohan",
            recommendations=recommendations,
            metadata={
                "model_version": "gohan_cid_v1.0",
                "total_categories": len(recommendations),
                "requested_k": requested_k
            }
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {str(e)}"
        ) from e
@app.post("/predict_simple", response_model=PredictionResponse)
def predict_simple(request: PredictionRequest):
    """Alias of /predict kept for parity with the Gradio version.

    Delegates to ``predict`` with the request unchanged, so an optional
    ``topK`` in the request is still honoured.
    """
    response = predict(request)
    return response
if __name__ == "__main__":
    from uvicorn import run as run_server

    # Script entry point: serve the app on every interface at port 7860.
    run_server(app, port=7860, host="0.0.0.0")