File size: 7,397 Bytes
9c48159
 
45c2088
 
 
9c48159
 
45c2088
 
 
9c48159
 
45c2088
 
9c48159
45c2088
 
9af43c4
45c2088
 
 
 
7afe39b
9c48159
 
 
7afe39b
 
 
45c2088
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c48159
45c2088
 
 
 
 
 
 
 
 
 
 
 
 
9c48159
45c2088
9c48159
 
 
 
 
 
45c2088
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c48159
 
45c2088
9c48159
 
 
 
 
 
 
 
45c2088
 
 
 
9c48159
 
 
45c2088
9c48159
45c2088
 
 
 
9c48159
 
45c2088
 
9c48159
 
 
 
 
 
45c2088
 
 
 
9c48159
4478f12
 
 
f991fae
9c48159
4478f12
9c48159
 
 
 
45c2088
 
 
 
 
9c48159
 
 
 
45c2088
9c48159
45c2088
 
 
 
 
 
 
 
 
 
 
 
9c48159
 
45c2088
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
#!/usr/bin/env python3
"""
Gohan (CID) Product Recommendation FastAPI App
FastAPI version of the Gohan CID inference engine
This maintains the exact same functionality as the Gradio version
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
import json
import os
import time
from typing import List, Optional, Dict, Any

# Import the existing inference engine
try:
    from inference_gohan_cid import GohanCIDInferenceEngine
except ImportError as import_error:
    # Keep the app bootable without the engine module, but report *why* the
    # import failed: a missing transitive dependency raises ImportError too and
    # would otherwise be indistinguishable from a missing module.
    print(f"⚠️  Could not import GohanCIDInferenceEngine: {import_error}")
    GohanCIDInferenceEngine = None

# Model paths - same as Gradio version
# These are relative paths, so they resolve against the process working directory.
# lifespan() checks all three for existence and passes them to GohanCIDInferenceEngine.
MODEL_PATH = "model/gohan/epoch_009_p50_0.5776.pt"
ENCODERS_DIR = "model/gohan"
PRODUCT_MASTER_PATH = "model/gohan/gohan_pm.csv"

# App name for consistent messaging
# Rendered through app_name.title() in log messages and in the API title/description.
app_name = "gohan"

# Pydantic models matching the exact API structure
class PredictionRequest(BaseModel):
    # Request body accepted by the /predict and /predict_simple endpoints.
    # (Documented with comments rather than a docstring so the generated schema
    # is left untouched.)

    # JSON-encoded string, not a nested object; predict() decodes it with json.loads.
    company_data_json: str
    # Optional number of results. predict() honors it only when it is > 0;
    # otherwise a "topK" key inside company_data_json, or 30, is used.
    topK: Optional[int] = None

class CategoryRecommendation(BaseModel):
    # One entry of PredictionResponse.recommendations.
    # NOTE(review): presumably mirrors the items returned by engine.predict();
    # confirm against the inference engine's output format.
    category_id: int
    category_name: str
    score: float

class PredictionResponse(BaseModel):
    # Response body of /predict and /predict_simple (declared as their response_model).
    status: str  # predict() sets this to "success"
    model: str  # predict() sets this to "gohan"
    recommendations: List[CategoryRecommendation]  # truncated by predict() to the requested topK
    metadata: Dict[str, Any]  # predict() fills model_version, total_categories, requested_k

# Global variables
# Both are assigned by lifespan() during startup and read by the endpoints.
# Loaded GohanCIDInferenceEngine instance, or None when loading failed or was skipped.
engine = None
# True only when MODEL_PATH, ENCODERS_DIR and PRODUCT_MASTER_PATH all exist on disk.
model_files_exist = False

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the inference engine on startup and log on shutdown.

    Sets the module globals ``model_files_exist`` (True only when every
    required path exists) and ``engine`` (the loaded
    ``GohanCIDInferenceEngine``, or ``None`` when files are missing, the
    engine class could not be imported, or construction raised).

    Startup never raises because of a model problem: the app still boots and
    the endpoints answer 503 until a model is available.

    Args:
        app: The FastAPI application being started (unused here).
    """
    global engine, model_files_exist

    print("🚀 Gohan FastAPI is starting. Loading AI model and data...")
    start_time = time.perf_counter()  # monotonic: safe for measuring durations

    required_paths = {
        "MODEL_PATH": MODEL_PATH,
        "ENCODERS_DIR": ENCODERS_DIR,
        "PRODUCT_MASTER_PATH": PRODUCT_MASTER_PATH,
    }
    path_found = {
        label: os.path.exists(path) for label, path in required_paths.items()
    }
    model_files_exist = all(path_found.values())

    # Always report per-path status so a missing file is identifiable from the
    # log, instead of only echoing "exists: True" when everything is present.
    print("🔍 Checking model files:")
    for label, path in required_paths.items():
        print(f"   - {label}: {path} (exists: {path_found[label]})")

    engine = None
    if not model_files_exist:
        print("⚠️  Model files not found. This is a template - add your model files to:")
        print(f"   - {MODEL_PATH}")
        print(f"   - {ENCODERS_DIR}/*.json")
        print(f"   - {PRODUCT_MASTER_PATH}")
    elif not GohanCIDInferenceEngine:
        # The guarded import at module top left the class as None.
        print(f"❌ {app_name.title()}CIDInferenceEngine not available")
    else:
        try:
            engine = GohanCIDInferenceEngine(
                model_path=MODEL_PATH,
                encoders_dir=ENCODERS_DIR,
                product_master_path=PRODUCT_MASTER_PATH
            )
            print(f"✅ {app_name.title()} CID model loaded successfully!")
        except Exception as e:
            # Deliberately broad: a broken model must not prevent the app from
            # starting; endpoints report 503 while engine is None.
            print(f"❌ Failed to load {app_name.title()} CID model: {e}")
            engine = None

    print(f"✅ Startup completed in {time.perf_counter() - start_time:.2f} seconds.")
    yield

    print(f"🔄 {app_name.title()} FastAPI is shutting down.")

# Initialize FastAPI app with lifespan
# The lifespan handler attempts to load the model before the app starts serving
# and logs a message on shutdown.
app = FastAPI(
    title=f"{app_name.title()} Product Recommendation API",
    description=f"FastAPI version of the {app_name.title()} recommendation system - maintains exact same functionality as Gradio version",
    version="2.0.0",
    lifespan=lifespan
)

# Target input fields (same as Gradio version)
# Keys that must be present in the decoded company_data_json object; predict()
# answers 400 and lists any that are missing. Only presence is checked here,
# not value types or ranges.
REQUIRED_FIELDS_EN = [
    'INDUSTRY', 'EMPLOYEE_RANGE', 'FRIDGE_RANGE', 'PAYMENT_METHOD', 'PREFECTURE',
    'FIRST_YEAR', 'FIRST_MONTH', 'LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO',
    'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3'
]

@app.get("/")
def root():
    """Return a service banner with version, available routes and model state."""
    banner = f"🍚 {app_name.title()} Product Recommendation API (FastAPI)"
    # Truthiness of the engine object decides the reported state.
    if engine:
        load_state = "loaded"
    else:
        load_state = "not_loaded"

    return dict(
        message=banner,
        status="running",
        version="2.0.0",
        endpoints=["/status", "/predict", "/predict_simple"],
        model_status=load_state,
        model_files_exist=model_files_exist,
    )

@app.get("/status")
def get_status():
    """Report readiness of the model.

    Returns the configured paths and load flags when the engine is available.

    Raises:
        HTTPException: 503 when no engine is loaded; the detail distinguishes
            "files present but load failed" from "files missing".
    """
    if engine is not None:
        return dict(
            status="ready",
            model_loaded=engine is not None,
            model_files_exist=model_files_exist,
            model_path=MODEL_PATH,
            encoders_dir=ENCODERS_DIR,
            product_master_path=PRODUCT_MASTER_PATH,
        )

    # No engine: choose the explanation based on whether the files were found.
    reason = (
        "Model not loaded - check model files"
        if model_files_exist
        else "Model files not found - this is a template. Add your model files to enable predictions."
    )
    raise HTTPException(status_code=503, detail=reason)

def _parse_company_payload(raw_json: str) -> Dict[str, Any]:
    """Decode the request's JSON string, requiring a top-level JSON object.

    Raises:
        HTTPException: 400 when the text is not valid JSON or decodes to
            something other than an object (list, number, string, null).
    """
    try:
        payload = json.loads(raw_json)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid JSON format: {str(e)}"
        ) from e

    if not isinstance(payload, dict):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid JSON format: expected a JSON object, got {type(payload).__name__}"
        )
    return payload


def _resolve_top_k(request_top_k: Optional[int], payload: Dict[str, Any], default: int = 30) -> int:
    """Choose the effective topK: request field, then payload "topK", then default.

    Non-positive values (from either source) and a null payload value are
    treated as "not provided", matching how the request-level field has always
    been handled.

    Raises:
        HTTPException: 400 when the payload's "topK" cannot be converted to int.
    """
    if request_top_k is not None and request_top_k > 0:
        return int(request_top_k)

    raw_value = payload.get("topK", default)
    if raw_value is None:
        return default
    try:
        top_k = int(raw_value)
    except (TypeError, ValueError, OverflowError) as e:
        # OverflowError covers JSON "Infinity", which json.loads accepts.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid topK value: {raw_value!r}"
        ) from e
    return top_k if top_k > 0 else default


@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """
    Predict gohan categories for a company (CID-based).

    The request carries the company attributes as a JSON-encoded object in
    ``company_data_json``. All input validation happens before the model runs,
    so malformed input is reported as a client error instead of surfacing as a
    500 after inference.

    Args:
        request: Body with ``company_data_json`` and an optional ``topK``.

    Returns:
        PredictionResponse with at most ``topK`` recommendations (default 30).

    Raises:
        HTTPException: 503 when no model is loaded; 400 for invalid JSON, a
            non-object payload, missing required fields or an unusable topK;
            500 when the engine or response construction fails.
    """
    if engine is None:
        if model_files_exist:
            error_msg = "Model not loaded - check model files"
        else:
            error_msg = "Model files not found - this is a template. Add your model files to enable predictions."

        raise HTTPException(
            status_code=503,
            detail=error_msg
        )

    # Parse input
    incoming = _parse_company_payload(request.company_data_json)

    # Validate English field presence
    missing_en = [f for f in REQUIRED_FIELDS_EN if f not in incoming]
    if missing_en:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required fields: {missing_en}"
        )

    # topK handling: resolve once, up front, and hand the engine a clean int.
    requested_k = _resolve_top_k(request.topK, incoming)
    incoming["topK"] = requested_k

    # Ensure TOTAL_VOLUME is present for the inference engine
    # (DELIVERY_NUM is guaranteed by the required-field check above).
    incoming.setdefault('TOTAL_VOLUME', incoming['DELIVERY_NUM'])

    try:
        # Predict
        recommendations = engine.predict(incoming)
        if len(recommendations) > requested_k:
            recommendations = recommendations[:requested_k]

        return PredictionResponse(
            status="success",
            model="gohan",
            recommendations=recommendations,
            metadata={
                "model_version": "gohan_cid_v1.0",
                "total_categories": len(recommendations),
                "requested_k": requested_k
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        # Boundary handler: engine failures and response-validation errors
        # both become a 500 with the original message.
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {str(e)}"
        ) from e

@app.post("/predict_simple", response_model=PredictionResponse)
def predict_simple(request: PredictionRequest):
    """Alias of /predict: the request is forwarded unchanged.

    The body model is the same as for /predict, so a supplied ``topK`` is still
    honored; callers may simply omit it.
    """
    response = predict(request)
    return response

# Script entry point: serve the app directly with uvicorn when run as a program.
if __name__ == "__main__":
    import uvicorn
    # "0.0.0.0" listens on all IPv4 interfaces.
    # NOTE(review): port 7860 presumably matches the earlier Gradio deployment - confirm.
    uvicorn.run(app, host="0.0.0.0", port=7860)