Spaces:
Runtime error
Runtime error
File size: 7,397 Bytes
9c48159 45c2088 9c48159 45c2088 9c48159 45c2088 9c48159 45c2088 9af43c4 45c2088 7afe39b 9c48159 7afe39b 45c2088 9c48159 45c2088 9c48159 45c2088 9c48159 45c2088 9c48159 45c2088 9c48159 45c2088 9c48159 45c2088 9c48159 45c2088 9c48159 45c2088 9c48159 45c2088 9c48159 4478f12 f991fae 9c48159 4478f12 9c48159 45c2088 9c48159 45c2088 9c48159 45c2088 9c48159 45c2088 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
#!/usr/bin/env python3
"""
Gohan (CID) Product Recommendation FastAPI App
FastAPI version of the Gohan CID inference engine
This maintains the exact same functionality as the Gradio version
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
import json
import os
import time
from typing import List, Optional, Dict, Any

# Import the existing inference engine
try:
    from inference_gohan_cid import GohanCIDInferenceEngine
except ImportError:
    # Engine module not present (e.g. a template deployment without model
    # code); endpoints will report 503 instead of crashing at import time.
    GohanCIDInferenceEngine = None

# Model paths - same as Gradio version
MODEL_PATH = "model/gohan/epoch_009_p50_0.5776.pt"  # trained checkpoint
ENCODERS_DIR = "model/gohan"  # directory holding the *.json encoders
PRODUCT_MASTER_PATH = "model/gohan/gohan_pm.csv"  # product master table

# App name for consistent messaging
app_name = "gohan"
# Pydantic models matching the exact API structure
class PredictionRequest(BaseModel):
    """Request body accepted by /predict and /predict_simple."""
    # JSON-encoded dict of company features (must contain REQUIRED_FIELDS_EN).
    company_data_json: str
    # Optional cap on returned categories; a positive value overrides any
    # "topK" embedded in company_data_json, otherwise the payload/default wins.
    topK: Optional[int] = None
class CategoryRecommendation(BaseModel):
    """A single recommended product category with its model score."""
    # Numeric category identifier from the product master.
    category_id: int
    # Human-readable category name.
    category_name: str
    # Model confidence/score for this category (higher = more relevant).
    score: float
class PredictionResponse(BaseModel):
    """Envelope returned by the prediction endpoints."""
    # "success" on a completed prediction.
    status: str
    # Model identifier (always "gohan" here).
    model: str
    # Top-K category recommendations, best first.
    recommendations: List[CategoryRecommendation]
    # Auxiliary info: model version, counts, requested K.
    metadata: Dict[str, Any]
# Global inference state; assigned once by the lifespan handler at startup.
engine = None
model_files_exist = False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the CID inference engine at startup and log at shutdown.

    On any failure ``engine`` stays ``None`` so the endpoints degrade to
    503 responses instead of crashing the process.

    NOTE(review): the "π"/"β"/"β οΈ" glyphs in the log strings appear to be
    mojibake'd emoji from the original source — confirm intended characters.
    """
    global engine, model_files_exist
    print(f"π Gohan FastAPI is starting. Loading AI model and data...")
    start_time = time.time()
    # Check if model files exist (same logic as Gradio version)
    model_files_exist = all([
        os.path.exists(MODEL_PATH),
        os.path.exists(ENCODERS_DIR),
        os.path.exists(PRODUCT_MASTER_PATH)
    ])
    if model_files_exist:
        print(f"π Checking model files:")
        print(f" - MODEL_PATH: {MODEL_PATH} (exists: {os.path.exists(MODEL_PATH)})")
        print(f" - ENCODERS_DIR: {ENCODERS_DIR} (exists: {os.path.exists(ENCODERS_DIR)})")
        print(f" - PRODUCT_MASTER_PATH: {PRODUCT_MASTER_PATH} (exists: {os.path.exists(PRODUCT_MASTER_PATH)})")
        try:
            if GohanCIDInferenceEngine:
                engine = GohanCIDInferenceEngine(
                    model_path=MODEL_PATH,
                    encoders_dir=ENCODERS_DIR,
                    product_master_path=PRODUCT_MASTER_PATH
                )
                # Fix: this message was split mid-string-literal in the
                # source (a SyntaxError); rejoined onto one line.
                print(f"β {app_name.title()} CID model loaded successfully!")
            else:
                print(f"β {app_name.title()}CIDInferenceEngine not available")
                engine = None
        except Exception as e:
            # Keep serving with engine=None; /status and /predict report 503.
            print(f"β Failed to load {app_name.title()} CID model: {e}")
            engine = None
    else:
        print(f"β οΈ Model files not found. This is a template - add your model files to:")
        print(f" - {MODEL_PATH}")
        print(f" - {ENCODERS_DIR}/*.json")
        print(f" - {PRODUCT_MASTER_PATH}")
        engine = None
    # Fix: this message was also split mid-string-literal; rejoined.
    print(f"β Startup completed in {time.time() - start_time:.2f} seconds.")
    yield
    print(f"π {app_name.title()} FastAPI is shutting down.")
# Initialize FastAPI app with lifespan
# (the lifespan handler loads the model before the first request is served).
app = FastAPI(
    title=f"{app_name.title()} Product Recommendation API",
    description=f"FastAPI version of the {app_name.title()} recommendation system - maintains exact same functionality as Gradio version",
    version="2.0.0",
    lifespan=lifespan
)
# Target input fields (same as Gradio version).
# Every one of these keys must be present in the decoded company_data_json
# payload or /predict rejects the request with a 400.
REQUIRED_FIELDS_EN = [
    'INDUSTRY', 'EMPLOYEE_RANGE', 'FRIDGE_RANGE', 'PAYMENT_METHOD', 'PREFECTURE',
    'FIRST_YEAR', 'FIRST_MONTH', 'LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO',
    'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3'
]
@app.get("/")
def root():
    """Landing endpoint: basic service info and current model availability."""
    model_state = "loaded" if engine else "not_loaded"
    info = {
        "message": f"π {app_name.title()} Product Recommendation API (FastAPI)",
        "status": "running",
        "version": "2.0.0",
        "endpoints": ["/status", "/predict", "/predict_simple"],
        "model_status": model_state,
        "model_files_exist": model_files_exist,
    }
    return info
@app.get("/status")
def get_status():
    """Readiness probe: 503 until the model is loaded, else config details."""
    if engine is None:
        # Distinguish "files present but load failed" from "template with no files".
        detail = (
            "Model not loaded - check model files"
            if model_files_exist
            else "Model files not found - this is a template. Add your model files to enable predictions."
        )
        raise HTTPException(status_code=503, detail=detail)
    return {
        "status": "ready",
        "model_loaded": engine is not None,
        "model_files_exist": model_files_exist,
        "model_path": MODEL_PATH,
        "encoders_dir": ENCODERS_DIR,
        "product_master_path": PRODUCT_MASTER_PATH,
    }
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """
    Predict gohan categories for a company (CID-based)
    This is the EXACT same logic as the Gradio version
    """
    try:
        # Refuse early when the engine never came up.
        if engine is None:
            error_msg = (
                "Model not loaded - check model files"
                if model_files_exist
                else "Model files not found - this is a template. Add your model files to enable predictions."
            )
            raise HTTPException(status_code=503, detail=error_msg)
        # Decode the caller-supplied JSON feature payload.
        try:
            payload = json.loads(request.company_data_json)
        except json.JSONDecodeError as e:
            raise HTTPException(status_code=400, detail=f"Invalid JSON format: {str(e)}")
        # A positive topK on the request wins; otherwise keep any value
        # already inside the payload, falling back to 30.
        if request.topK is not None and request.topK > 0:
            payload["topK"] = int(request.topK)
        else:
            payload.setdefault("topK", 30)
        # Every English-named field must be present.
        absent = [field for field in REQUIRED_FIELDS_EN if field not in payload]
        if absent:
            raise HTTPException(status_code=400, detail=f"Missing required fields: {absent}")
        # The engine expects TOTAL_VOLUME; mirror DELIVERY_NUM when missing.
        if 'TOTAL_VOLUME' not in payload and 'DELIVERY_NUM' in payload:
            payload['TOTAL_VOLUME'] = payload['DELIVERY_NUM']
        # Run inference and trim to the requested number of categories.
        recs = engine.predict(payload)
        requested_k = int(payload.get("topK", 30))
        recs = recs[:requested_k]
        return PredictionResponse(
            status="success",
            model="gohan",
            recommendations=recs,
            metadata={
                "model_version": "gohan_cid_v1.0",
                "total_categories": len(recs),
                "requested_k": requested_k,
            },
        )
    except HTTPException:
        # Re-raise intentional HTTP errors untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
@app.post("/predict_simple", response_model=PredictionResponse)
def predict_simple(request: PredictionRequest):
    """Simple endpoint without topK parameter - same as Gradio version"""
    # Delegate straight to the full /predict handler; identical behavior.
    response = predict(request)
    return response
if __name__ == "__main__":
    # Local development entry point; port 7860 is the Spaces convention.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
|