File size: 8,155 Bytes
d7aa29e
 
e94adc9
 
 
d7aa29e
 
e94adc9
 
 
d7aa29e
 
e94adc9
 
d7aa29e
e94adc9
 
2ddea2f
e94adc9
 
 
 
f3afecf
d7aa29e
 
 
f3afecf
 
 
e94adc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb9bda6
e94adc9
 
eb9bda6
e94adc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7aa29e
e94adc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7aa29e
 
 
 
 
 
e94adc9
 
 
 
 
 
 
 
 
d7aa29e
e94adc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7aa29e
 
e94adc9
d7aa29e
 
 
 
 
 
 
 
e94adc9
 
 
 
d7aa29e
 
 
e94adc9
d7aa29e
e94adc9
 
 
 
d7aa29e
a62750f
 
 
d7aa29e
e94adc9
 
d7aa29e
 
 
a62750f
 
d7aa29e
 
 
a62750f
e94adc9
 
 
 
d7aa29e
a62750f
 
08a1eb5
 
 
 
 
 
a62750f
d7aa29e
a62750f
08a1eb5
a62750f
 
 
 
 
 
 
 
e94adc9
d7aa29e
 
 
e94adc9
 
 
 
 
d7aa29e
 
 
 
e94adc9
d7aa29e
e94adc9
 
 
 
 
 
 
 
 
 
 
 
d7aa29e
 
e94adc9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
#!/usr/bin/env python3
"""
Yasai (CID) Product Recommendation FastAPI App
FastAPI version of the Yasai CID inference engine
This maintains the exact same functionality as the Gradio version
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
import json
import os
import time
from typing import List, Optional, Dict, Any

# Import the existing inference engine
try:
    from inference_yasai_cid import YasaiCIDInferenceEngine
except ImportError:
    YasaiCIDInferenceEngine = None

# Model paths - same as Gradio version
MODEL_PATH = "model/yasai/epoch_028_p50_0.6911.pt"  # trained model checkpoint
ENCODERS_DIR = "model/yasai"  # directory expected to hold the *.json encoder files
PRODUCT_MASTER_PATH = "model/yasai/yasai_pm.csv"  # product master CSV

# App name for consistent messaging
app_name = "yasai"  # used in titles and log/status strings throughout

# Pydantic models matching the exact API structure
class PredictionRequest(BaseModel):
    # Request body for /predict and /predict_simple.
    # company_data_json: the company record serialized as a JSON string
    #   (decoded server-side with json.loads; must contain REQUIRED_FIELDS_EN keys).
    # topK: optional number of recommendations to return; when absent or <= 0
    #   the server falls back to a payload-provided "topK" or a default of 30.
    company_data_json: str
    topK: Optional[int] = None

class CategoryRecommendation(BaseModel):
    # One recommended product category with its model score.
    category_id: int
    category_name: str
    score: float

class PredictionResponse(BaseModel):
    # Envelope returned by /predict: status flag, model identifier,
    # the (possibly truncated) recommendation list, and free-form metadata
    # such as model_version / total_categories / requested_k.
    status: str
    model: str
    recommendations: List[CategoryRecommendation]
    metadata: Dict[str, Any]

# Global variables
engine = None  # YasaiCIDInferenceEngine instance; populated during lifespan startup
model_files_exist = False  # True when all three model artifacts were found on disk

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    On startup: verify the three model artifacts exist and, if so, construct
    the CID inference engine, recording the result in the module-level
    ``engine`` / ``model_files_exist`` globals. Yields control to the app,
    then logs a shutdown message on exit.
    """
    global engine, model_files_exist

    print(f"πŸš€ Yasai FastAPI is starting. Loading AI model and data...")
    t0 = time.time()

    # All three artifacts must be present before attempting a load
    # (same logic as the Gradio version).
    required_paths = (MODEL_PATH, ENCODERS_DIR, PRODUCT_MASTER_PATH)
    model_files_exist = all(os.path.exists(p) for p in required_paths)

    if not model_files_exist:
        # Template mode: tell the operator exactly which paths to populate.
        print(f"⚠️  Model files not found. This is a template - add your model files to:")
        print(f"   - {MODEL_PATH}")
        print(f"   - {ENCODERS_DIR}/*.json")
        print(f"   - {PRODUCT_MASTER_PATH}")
        engine = None
    else:
        print(f"πŸ” Checking model files:")
        print(f"   - MODEL_PATH: {MODEL_PATH} (exists: {os.path.exists(MODEL_PATH)})")
        print(f"   - ENCODERS_DIR: {ENCODERS_DIR} (exists: {os.path.exists(ENCODERS_DIR)})")
        print(f"   - PRODUCT_MASTER_PATH: {PRODUCT_MASTER_PATH} (exists: {os.path.exists(PRODUCT_MASTER_PATH)})")

        if not YasaiCIDInferenceEngine:
            # Import of the inference module failed at file load time.
            print(f"❌ {app_name.title()}CIDInferenceEngine not available")
            engine = None
        else:
            try:
                engine = YasaiCIDInferenceEngine(
                    model_path=MODEL_PATH,
                    encoders_dir=ENCODERS_DIR,
                    product_master_path=PRODUCT_MASTER_PATH
                )
                print(f"βœ… {app_name.title()} CID model loaded successfully!")
            except Exception as e:
                # Keep the app up in degraded mode; /status reports the failure.
                print(f"❌ Failed to load {app_name.title()} CID model: {e}")
                engine = None

    print(f"βœ… Startup completed in {time.time() - t0:.2f} seconds.")
    yield

    print(f"πŸ”„ {app_name.title()} FastAPI is shutting down.")

# Initialize FastAPI app with lifespan
# (the lifespan context above loads/unloads the model around the app's lifetime)
app = FastAPI(
    title=f"{app_name.title()} Product Recommendation API",
    description=f"FastAPI version of the {app_name.title()} recommendation system - maintains exact same functionality as Gradio version",
    version="2.0.0",
    lifespan=lifespan
)

# Target input fields (same as Gradio version)
# Keys that must be present in the decoded company_data_json payload;
# /predict responds 400 listing any that are missing.
REQUIRED_FIELDS_EN = [
    'INDUSTRY', 'EMPLOYEE_RANGE', 'FRIDGE_RANGE', 'PAYMENT_METHOD', 'PREFECTURE',
    'FIRST_YEAR', 'FIRST_MONTH', 'LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO',
    'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3'
]

@app.get("/")
def root():
    return {
        "message": f"🍚 {app_name.title()} Product Recommendation API (FastAPI)",
        "status": "running",
        "version": "2.0.0",
        "endpoints": ["/status", "/predict", "/predict_simple"],
        "model_status": "loaded" if engine else "not_loaded",
        "model_files_exist": model_files_exist
    }

@app.get("/status")
def get_status():
    if engine is None:
        if model_files_exist:
            raise HTTPException(
                status_code=503,
                detail="Model not loaded - check model files"
            )
        else:
            raise HTTPException(
                status_code=503,
                detail="Model files not found - this is a template. Add your model files to enable predictions."
            )
    
    return {
        "status": "ready",
        "model_loaded": engine is not None,
        "model_files_exist": model_files_exist,
        "model_path": MODEL_PATH,
        "encoders_dir": ENCODERS_DIR,
        "product_master_path": PRODUCT_MASTER_PATH
    }

@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """
    Predict yasai categories for a company (CID-based)
    This is the EXACT same logic as the Gradio version
    """
    try:
        if engine is None:
            if model_files_exist:
                error_msg = "Model not loaded - check model files"
            else:
                error_msg = "Model files not found - this is a template. Add your model files to enable predictions."
            
            raise HTTPException(
                status_code=503,
                detail=error_msg
            )

        # Parse input
        try:
            incoming = json.loads(request.company_data_json)
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid JSON format: {str(e)}"
            )

        print(f"πŸ” Received data: {incoming}")
        print(f"🎯 topK from request: {request.topK}")

        # topK handling
        if request.topK is not None and request.topK > 0:
            incoming["topK"] = int(request.topK)
        else:
            incoming.setdefault("topK", 30)

        print(f"🎯 Final topK: {incoming.get('topK')}")

        # Validate English field presence
        missing_en = [f for f in REQUIRED_FIELDS_EN if f not in incoming]
        if missing_en:
            print(f"❌ Missing required fields: {missing_en}")
            raise HTTPException(
                status_code=400,
                detail=f"Missing required fields: {missing_en}"
            )

        print(f"βœ… All required fields present")

        # Ensure TOTAL_VOLUME is present for the inference engine
        if 'TOTAL_VOLUME' not in incoming and 'DELIVERY_NUM' in incoming:
            incoming['TOTAL_VOLUME'] = incoming['DELIVERY_NUM']
            print(f"πŸ”§ Mapped DELIVERY_NUM to TOTAL_VOLUME: {incoming['TOTAL_VOLUME']}")

        print(f"πŸ”§ Data for inference: {incoming}")

        # Predict
        try:
            recommendations = engine.predict(incoming)
            print(f"βœ… Prediction successful, got {len(recommendations)} recommendations")
        except Exception as e:
            print(f"❌ Prediction failed: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"Prediction error: {str(e)}"
            )
        
        requested_k = int(incoming.get("topK", 30))
        if len(recommendations) > requested_k:
            recommendations = recommendations[:requested_k]

        return PredictionResponse(
            status="success",
            model="yasai",
            recommendations=recommendations,
            metadata={
                "model_version": "yasai_cid_v1.0",
                "total_categories": len(recommendations),
                "requested_k": requested_k
            }
        )
        
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {str(e)}"
        )

@app.post("/predict_simple", response_model=PredictionResponse)
def predict_simple(request: PredictionRequest):
    """Simple endpoint without topK parameter - same as Gradio version"""
    return predict(request)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)