# yasai-api-light / app.py
# (from commit 08a1eb5 by tabito12345678910:
#  "Fix numpy dependency: specify exact versions and install order in Dockerfile")
#!/usr/bin/env python3
"""
Yasai (CID) Product Recommendation FastAPI App
FastAPI version of the Yasai CID inference engine
This maintains the exact same functionality as the Gradio version
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
import json
import os
import time
from typing import List, Optional, Dict, Any
# Import the existing inference engine.
# The import is optional so the app can still start (and answer 503 from its
# endpoints) when the inference module or its dependencies are absent.
try:
    from inference_yasai_cid import YasaiCIDInferenceEngine
except ImportError:
    # Sentinel checked in lifespan() before constructing the engine.
    YasaiCIDInferenceEngine = None
# Model paths - same as Gradio version
MODEL_PATH = "model/yasai/epoch_028_p50_0.6911.pt"  # trained model checkpoint (.pt — presumably PyTorch)
ENCODERS_DIR = "model/yasai"  # directory expected to contain the encoder *.json files
PRODUCT_MASTER_PATH = "model/yasai/yasai_pm.csv"  # product master CSV
# App name for consistent messaging (used in log lines and endpoint responses)
app_name = "yasai"
# Pydantic models matching the exact API structure
class PredictionRequest(BaseModel):
    """Request body for /predict and /predict_simple."""
    company_data_json: str  # JSON-encoded object expected to contain the REQUIRED_FIELDS_EN keys
    topK: Optional[int] = None  # optional cap on returned recommendations; predict() defaults to 30
class CategoryRecommendation(BaseModel):
    """One recommended product category with its model score."""
    category_id: int
    category_name: str
    score: float
class PredictionResponse(BaseModel):
    """Envelope returned by /predict: status, model name, ranked results, metadata."""
    status: str
    model: str
    recommendations: List[CategoryRecommendation]
    metadata: Dict[str, Any]
# Global variables
# engine is populated by lifespan() at startup; while it is None the
# endpoints answer 503.
engine = None
# True once all three model artifact paths were found on disk at startup.
model_files_exist = False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    On startup: checks that all model artifacts exist on disk and, if so,
    constructs the global YasaiCIDInferenceEngine. On shutdown: logs only.
    Mutates the module-level ``engine`` and ``model_files_exist`` globals
    that every endpoint reads.
    """
    global engine, model_files_exist
    print(f"πŸš€ Yasai FastAPI is starting. Loading AI model and data...")
    start_time = time.time()
    # Probe each required artifact exactly once and keep the results so the
    # diagnostics below always say which file is missing. (Previously the
    # per-file status lines were printed only when *all* files existed —
    # exactly the case where they carry no information.)
    file_checks = {
        "MODEL_PATH": (MODEL_PATH, os.path.exists(MODEL_PATH)),
        "ENCODERS_DIR": (ENCODERS_DIR, os.path.exists(ENCODERS_DIR)),
        "PRODUCT_MASTER_PATH": (PRODUCT_MASTER_PATH, os.path.exists(PRODUCT_MASTER_PATH)),
    }
    model_files_exist = all(exists for _, exists in file_checks.values())
    print(f"πŸ” Checking model files:")
    for label, (path, exists) in file_checks.items():
        print(f" - {label}: {path} (exists: {exists})")
    if model_files_exist:
        try:
            if YasaiCIDInferenceEngine:
                engine = YasaiCIDInferenceEngine(
                    model_path=MODEL_PATH,
                    encoders_dir=ENCODERS_DIR,
                    product_master_path=PRODUCT_MASTER_PATH
                )
                print(f"βœ… {app_name.title()} CID model loaded successfully!")
            else:
                # Import of inference_yasai_cid failed at module load time.
                print(f"❌ {app_name.title()}CIDInferenceEngine not available")
                engine = None
        except Exception as e:
            # Keep the app serving (endpoints will 503) rather than crashing
            # on a bad or corrupt model file.
            print(f"❌ Failed to load {app_name.title()} CID model: {e}")
            engine = None
    else:
        print(f"⚠️ Model files not found. This is a template - add your model files to:")
        print(f" - {MODEL_PATH}")
        print(f" - {ENCODERS_DIR}/*.json")
        print(f" - {PRODUCT_MASTER_PATH}")
        engine = None
    print(f"βœ… Startup completed in {time.time() - start_time:.2f} seconds.")
    yield
    print(f"πŸ”„ {app_name.title()} FastAPI is shutting down.")
# Initialize FastAPI app with lifespan
app = FastAPI(
    title=f"{app_name.title()} Product Recommendation API",
    description=f"FastAPI version of the {app_name.title()} recommendation system - maintains exact same functionality as Gradio version",
    version="2.0.0",
    lifespan=lifespan  # loads the model at startup — see lifespan() above
)
# Target input fields (same as Gradio version)
# Every key must be present in the decoded company_data_json payload or
# /predict responds 400 listing the missing fields.
REQUIRED_FIELDS_EN = [
    'INDUSTRY', 'EMPLOYEE_RANGE', 'FRIDGE_RANGE', 'PAYMENT_METHOD', 'PREFECTURE',
    'FIRST_YEAR', 'FIRST_MONTH', 'LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO',
    'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3'
]
@app.get("/")
def root():
    """Landing endpoint: static service metadata plus live model status."""
    payload = {
        "message": f"🍚 {app_name.title()} Product Recommendation API (FastAPI)",
        "status": "running",
        "version": "2.0.0",
        "endpoints": ["/status", "/predict", "/predict_simple"],
    }
    # Model state is read from the module globals set during startup.
    payload["model_status"] = "loaded" if engine else "not_loaded"
    payload["model_files_exist"] = model_files_exist
    return payload
@app.get("/status")
def get_status():
    """Health check: 503 while the model is unavailable, otherwise the
    configured model paths and readiness flags."""
    if engine is None:
        # Distinguish "files exist but load failed" from "template without files".
        detail = (
            "Model not loaded - check model files"
            if model_files_exist
            else "Model files not found - this is a template. Add your model files to enable predictions."
        )
        raise HTTPException(status_code=503, detail=detail)
    return {
        "status": "ready",
        "model_loaded": engine is not None,
        "model_files_exist": model_files_exist,
        "model_path": MODEL_PATH,
        "encoders_dir": ENCODERS_DIR,
        "product_master_path": PRODUCT_MASTER_PATH,
    }
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """
    Predict yasai categories for a company (CID-based)
    This is the EXACT same logic as the Gradio version

    Raises:
        HTTPException 503: model not loaded / model files missing.
        HTTPException 400: malformed JSON, non-object JSON, or missing fields.
        HTTPException 500: unexpected failure inside the inference engine.
    """
    try:
        # Refuse early when the engine never loaded (see lifespan()).
        if engine is None:
            if model_files_exist:
                error_msg = "Model not loaded - check model files"
            else:
                error_msg = "Model files not found - this is a template. Add your model files to enable predictions."
            raise HTTPException(
                status_code=503,
                detail=error_msg
            )
        # Parse input
        try:
            incoming = json.loads(request.company_data_json)
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid JSON format: {str(e)}"
            )
        # Valid JSON that is not an object (e.g. a list or a bare string) used
        # to fall through to the subscripting below and surface as a 500; it
        # is a client error, so report it as a 400 instead.
        if not isinstance(incoming, dict):
            raise HTTPException(
                status_code=400,
                detail="Invalid input: company_data_json must be a JSON object"
            )
        print(f"πŸ” Received data: {incoming}")
        print(f"🎯 topK from request: {request.topK}")
        # topK handling: an explicit positive topK in the request wins;
        # otherwise any topK already inside the payload is kept, defaulting
        # to 30 when absent.
        if request.topK is not None and request.topK > 0:
            incoming["topK"] = int(request.topK)
        else:
            incoming.setdefault("topK", 30)
        print(f"🎯 Final topK: {incoming.get('topK')}")
        # Validate English field presence
        missing_en = [f for f in REQUIRED_FIELDS_EN if f not in incoming]
        if missing_en:
            print(f"❌ Missing required fields: {missing_en}")
            raise HTTPException(
                status_code=400,
                detail=f"Missing required fields: {missing_en}"
            )
        print(f"βœ… All required fields present")
        # Ensure TOTAL_VOLUME is present for the inference engine
        if 'TOTAL_VOLUME' not in incoming and 'DELIVERY_NUM' in incoming:
            incoming['TOTAL_VOLUME'] = incoming['DELIVERY_NUM']
            print(f"πŸ”§ Mapped DELIVERY_NUM to TOTAL_VOLUME: {incoming['TOTAL_VOLUME']}")
        print(f"πŸ”§ Data for inference: {incoming}")
        # Predict
        try:
            recommendations = engine.predict(incoming)
            print(f"βœ… Prediction successful, got {len(recommendations)} recommendations")
        except Exception as e:
            print(f"❌ Prediction failed: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"Prediction error: {str(e)}"
            )
        # Trim to the requested number of results.
        requested_k = int(incoming.get("topK", 30))
        if len(recommendations) > requested_k:
            recommendations = recommendations[:requested_k]
        return PredictionResponse(
            status="success",
            model="yasai",
            recommendations=recommendations,
            metadata={
                "model_version": "yasai_cid_v1.0",
                "total_categories": len(recommendations),
                "requested_k": requested_k
            }
        )
    except HTTPException:
        # Re-raise FastAPI errors unchanged so their status codes survive.
        raise
    except Exception as e:
        # Last-resort guard: anything unexpected becomes a 500.
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {str(e)}"
        )
@app.post("/predict_simple", response_model=PredictionResponse)
def predict_simple(request: PredictionRequest):
    """Simple endpoint without topK parameter - same as Gradio version"""
    # Pure delegation to /predict; kept as a separate route for API parity
    # with the Gradio version.
    return predict(request)
if __name__ == "__main__":
    # Run the app directly with uvicorn when executed as a script.
    # Port 7860 matches the Gradio default — presumably for drop-in parity
    # with the Gradio deployment; confirm before changing.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)