tabito12345678910 committed on
Commit
45c2088
·
1 Parent(s): 8e060d3

Migrate from Gradio to FastAPI - adds /status endpoint and cold start handling while maintaining exact same API functionality

Browse files
Files changed (3) hide show
  1. README.md +30 -27
  2. app.py +161 -129
  3. requirements.txt +4 -1
README.md CHANGED
@@ -1,31 +1,34 @@
1
- ---
2
- title: Gohan Product Recommendation API
3
- emoji: ๐Ÿš
4
- colorFrom: red
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: "4.44.0"
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- # ๐Ÿš Gohan Product Recommendation API
13
-
14
- Clean API template for rice product recommendations.
15
-
16
- ## Setup
17
- 1. Add your model files to a 'model/' directory
18
- 2. Add encoder JSON files to 'model/gohan/'
19
- 3. Update paths in app.py and inference script
20
- 4. Deploy to HuggingFace Spaces
21
 
22
  ## Usage
23
- This API provides product recommendations for rice products based on company data.
24
 
25
- ## Current Status
26
- โš ๏ธ **Template Mode** - Add your trained model file to enable predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- **Required files to add:**
29
- - `model/gohan/epoch_012_p50_0.5836.pt` (PyTorch model)
30
- - `model/gohan/*.json` (encoder files) โœ… Already included
31
- - `model/gohan/gohan_pm.csv` (product master data) โœ… Already included
 
1
+ # Gohan FastAPI
2
+
3
+ This is a FastAPI-based product recommendation API deployed on Hugging Face Spaces.
4
+
5
+ ## Endpoints
6
+
7
+ - `GET /` - Root endpoint with API information
8
+ - `GET /status` - Health check and model status
9
+ - `POST /predict` - Main prediction endpoint with topK parameter
10
+ - `POST /predict_simple` - Simple prediction endpoint
 
 
 
 
 
 
 
 
 
 
11
 
12
  ## Usage
 
13
 
14
+ ### Check API Status
15
+ ```bash
16
+ curl "https://your-space-url.hf.space/status"
17
+ ```
18
+
19
+ ### Make Predictions
20
+ ```bash
21
+ curl -X POST "https://your-space-url.hf.space/predict" \
22
+ -H "Content-Type: application/json" \
23
+ -d '{"company_data_json": "{...}", "topK": 10}'
24
+ ```
25
+
26
+ ## Model Loading
27
+
28
+ The API uses FastAPI's lifespan events to load models only once during startup, providing efficient cold start handling.
29
+
30
+ ## Required Model Files
31
 
32
+ - `model/gohan/epoch_*.pt` (PyTorch model)
33
+ - `model/gohan/*.json` (encoder files)
34
+ - `model/gohan/gohan_pm.csv` (product master data)
 
app.py CHANGED
@@ -1,59 +1,149 @@
1
  #!/usr/bin/env python3
2
  """
3
- Gohan (CID) Product Recommendation Gradio App - Light Version
4
- Gradio interface for the Gohan CID inference engine
5
- This is a template - add your model files to make it functional
6
  """
7
 
8
- import gradio as gr
 
 
9
  import json
10
  import os
 
 
11
 
12
- # Model paths - UPDATE THESE when you add your model files
13
- MODEL_PATH = "model/gohan/epoch_009_p50_0.5776.pt"
 
 
 
 
 
 
14
  ENCODERS_DIR = "model/gohan"
15
  PRODUCT_MASTER_PATH = "model/gohan/gohan_pm.csv"
16
 
17
- # Check if model files exist
18
- model_files_exist = all([
19
- os.path.exists(MODEL_PATH),
20
- os.path.exists(ENCODERS_DIR),
21
- os.path.exists(PRODUCT_MASTER_PATH)
22
- ])
23
-
24
- if model_files_exist:
25
- try:
26
- from inference_gohan_cid import GohanCIDInferenceEngine
27
- engine = GohanCIDInferenceEngine(
28
- model_path=MODEL_PATH,
29
- encoders_dir=ENCODERS_DIR,
30
- product_master_path=PRODUCT_MASTER_PATH
31
- )
32
- print("โœ… Gohan CID model loaded successfully!")
33
- except Exception as e:
34
- print(f"โŒ Failed to load Gohan CID model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  engine = None
36
- else:
37
- print("โš ๏ธ Model files not found. This is a template - add your model files to:")
38
- print(f" - {MODEL_PATH}")
39
- print(f" - {ENCODERS_DIR}/*.json")
40
- print(f" - {PRODUCT_MASTER_PATH}")
41
- engine = None
 
 
 
 
 
 
 
42
 
 
43
  REQUIRED_FIELDS_EN = [
44
  'INDUSTRY', 'EMPLOYEE_RANGE', 'FRIDGE_RANGE', 'PAYMENT_METHOD', 'PREFECTURE',
45
  'FIRST_YEAR', 'FIRST_MONTH', 'LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO',
46
  'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3'
47
  ]
48
 
49
- def predict(company_data_json: str, topK: int | None = None) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  """
51
  Predict gohan categories for a company (CID-based)
52
- Args:
53
- company_data_json: JSON string containing company information
54
- topK: Optional override for number of recommendations
55
- Returns:
56
- JSON string with predictions
57
  """
58
  try:
59
  if engine is None:
@@ -62,41 +152,33 @@ def predict(company_data_json: str, topK: int | None = None) -> str:
62
  else:
63
  error_msg = "Model files not found - this is a template. Add your model files to enable predictions."
64
 
65
- return json.dumps({
66
- "status": "error",
67
- "error": error_msg,
68
- "model": "gohan",
69
- "setup_instructions": {
70
- "model_file": MODEL_PATH,
71
- "encoders_dir": ENCODERS_DIR,
72
- "product_master": PRODUCT_MASTER_PATH
73
- }
74
- }, indent=2)
75
 
76
  # Parse input
77
  try:
78
- incoming = json.loads(company_data_json)
79
  except json.JSONDecodeError as e:
80
- return json.dumps({
81
- "status": "error",
82
- "error": f"Invalid JSON format: {str(e)}",
83
- "model": "gohan"
84
- })
85
 
86
  # topK handling
87
- if topK is not None and topK > 0:
88
- incoming["topK"] = int(topK)
89
  else:
90
  incoming.setdefault("topK", 30)
91
 
92
  # Validate English field presence
93
  missing_en = [f for f in REQUIRED_FIELDS_EN if f not in incoming]
94
  if missing_en:
95
- return json.dumps({
96
- "status": "error",
97
- "error": f"Missing required fields: {missing_en}",
98
- "model": "gohan"
99
- })
100
 
101
  # Predict
102
  recommendations = engine.predict(incoming)
@@ -104,80 +186,30 @@ def predict(company_data_json: str, topK: int | None = None) -> str:
104
  if len(recommendations) > requested_k:
105
  recommendations = recommendations[:requested_k]
106
 
107
- return json.dumps({
108
- "status": "success",
109
- "model": "gohan",
110
- "recommendations": recommendations,
111
- "metadata": {
112
  "model_version": "gohan_cid_v1.0",
113
  "total_categories": len(recommendations),
114
  "requested_k": requested_k
115
  }
116
- }, ensure_ascii=False, indent=2)
117
- except Exception as e:
118
- return json.dumps({
119
- "status": "error",
120
- "error": str(e),
121
- "model": "gohan"
122
- })
123
-
124
- def predict_simple(company_data_json: str) -> str:
125
- return predict(company_data_json, None)
126
-
127
- # Sample input for testing
128
- sample_input = json.dumps({
129
- "INDUSTRY": "finance",
130
- "EMPLOYEE_RANGE": "200-1000",
131
- "FRIDGE_RANGE": "100-500",
132
- "PAYMENT_METHOD": "card",
133
- "PREFECTURE": "osaka",
134
- "FIRST_YEAR": 2019,
135
- "FIRST_MONTH": 6,
136
- "LAT": 34.6937,
137
- "LONG": 135.5023,
138
- "DELIVERY_NUM": 300,
139
- "MEDIAN_GENDER_RATIO": 0.55,
140
- "MODE_TOP_AGE_RANGE_1": "40-49",
141
- "MODE_TOP_AGE_RANGE_2": "30-39",
142
- "MODE_TOP_AGE_RANGE_3": "50-59"
143
- }, indent=2)
144
-
145
- with gr.Blocks(title="Gohan CID Product Recommendation API (Light)") as demo:
146
- gr.Markdown("# ๐Ÿš Gohan Product Recommendation API (Light Template)")
147
-
148
- if model_files_exist:
149
- gr.Markdown("โœ… **Model Status**: Loaded and ready")
150
- else:
151
- gr.Markdown("""
152
- โš ๏ธ **Model Status**: Template mode - add your model files to enable predictions
153
 
154
- **Required files to add:**
155
- - `model/gohan/epoch_009_p50_0.5776.pt` (PyTorch model)
156
- - `model/gohan/*.json` (encoder files)
157
- - `model/gohan/gohan_pm.csv` (product master data)
158
- """)
159
-
160
- gr.Markdown("Enter company data as JSON to get rice product recommendations.")
161
-
162
- with gr.Tab("Main API"):
163
- with gr.Row():
164
- with gr.Column():
165
- inp = gr.Textbox(label="Company Data (JSON)", lines=15, value=sample_input)
166
- topk = gr.Number(label="Top K Results (optional)", minimum=1, maximum=200, step=1, value=None)
167
- btn = gr.Button("Get Recommendations", variant="primary")
168
- with gr.Column():
169
- out = gr.Textbox(label="API Response", lines=20, interactive=False)
170
-
171
- with gr.Tab("Simple API"):
172
- with gr.Row():
173
- with gr.Column():
174
- inp2 = gr.Textbox(label="Company Data (JSON)", lines=15, value=sample_input)
175
- btn2 = gr.Button("Get Recommendations", variant="primary")
176
- with gr.Column():
177
- out2 = gr.Textbox(label="API Response", lines=20, interactive=False)
178
-
179
- btn.click(fn=predict, inputs=[inp, topk], outputs=out)
180
- btn2.click(fn=predict_simple, inputs=inp2, outputs=out2)
181
 
182
  if __name__ == "__main__":
183
- demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)
 
 
#!/usr/bin/env python3
"""
Gohan (CID) Product Recommendation FastAPI App

FastAPI version of the Gohan CID inference engine.  Maintains the same
prediction behaviour as the earlier Gradio version of this Space.
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
import json
import os
import time
from typing import List, Optional, Dict, Any

# Import the existing inference engine; fall back to None so the app can
# still start in "template" mode when the engine module is absent.
try:
    from gohan_cid import GohanCIDInferenceEngine
except ImportError:
    GohanCIDInferenceEngine = None

# Human-readable app name for log messages and API metadata.  The rest of
# the module previously referenced an undefined `app_name`; this constant
# provides it at module scope.
APP_NAME = "gohan"

# Model paths - same as the Gradio version.
# BUG FIX: the original expression was
#     "model/gohan/epoch_028_p50_0.6911.pt" if "gohan" == "yasai" else ...
# The condition compares two different literals, so it is always False and
# the else-branch is always taken; simplified to that constant.
MODEL_PATH = "model/gohan/epoch_009_p50_0.5776.pt"
ENCODERS_DIR = "model/gohan"
PRODUCT_MASTER_PATH = "model/gohan/gohan_pm.csv"
26
 
# Pydantic models matching the exact API structure.
# NOTE(review): field names and declaration order define the wire schema;
# do not reorder or rename without versioning the API.
class PredictionRequest(BaseModel):
    """Request body for /predict and /predict_simple."""
    # company_data_json: JSON-encoded company record (string, not an object)
    company_data_json: str
    # topK: optional override for the number of recommendations
    topK: Optional[int] = None

class CategoryRecommendation(BaseModel):
    """One recommended product category with its model score."""
    category_id: int
    category_name: str
    score: float

class PredictionResponse(BaseModel):
    """Successful prediction envelope returned by the predict endpoints."""
    status: str
    model: str
    recommendations: List[CategoryRecommendation]
    metadata: Dict[str, Any]
42
+
# Global variables
# engine: loaded GohanCIDInferenceEngine instance, or None until the
# lifespan startup hook runs (or permanently, when model files are missing).
engine = None
# model_files_exist: set during startup after checking the model paths.
model_files_exist = False
46
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model once at startup; log a message on shutdown.

    Using FastAPI's lifespan event gives a single cold-start load instead
    of per-request initialisation.  Sets the module globals `engine` and
    `model_files_exist`.
    """
    global engine, model_files_exist

    # BUG FIX: the original log strings interpolated an undefined name
    # `app_name` (NameError at startup); literals are used instead.
    print("🚀 Gohan FastAPI is starting. Loading AI model and data...")
    start_time = time.time()

    # Check if model files exist (same logic as the Gradio version).
    model_files_exist = all([
        os.path.exists(MODEL_PATH),
        os.path.exists(ENCODERS_DIR),
        os.path.exists(PRODUCT_MASTER_PATH),
    ])

    if model_files_exist:
        print("🔍 Checking model files:")
        print(f" - MODEL_PATH: {MODEL_PATH} (exists: {os.path.exists(MODEL_PATH)})")
        print(f" - ENCODERS_DIR: {ENCODERS_DIR} (exists: {os.path.exists(ENCODERS_DIR)})")
        print(f" - PRODUCT_MASTER_PATH: {PRODUCT_MASTER_PATH} (exists: {os.path.exists(PRODUCT_MASTER_PATH)})")

        try:
            if GohanCIDInferenceEngine:
                engine = GohanCIDInferenceEngine(
                    model_path=MODEL_PATH,
                    encoders_dir=ENCODERS_DIR,
                    product_master_path=PRODUCT_MASTER_PATH,
                )
                print("✅ Gohan CID model loaded successfully!")
            else:
                # Import of gohan_cid failed at module load time.
                print("❌ GohanCIDInferenceEngine not available")
                engine = None
        except Exception as e:
            # Best-effort startup: keep serving /status with a 503 instead
            # of crashing the whole Space.
            print(f"❌ Failed to load Gohan CID model: {e}")
            engine = None
    else:
        print("⚠️ Model files not found. This is a template - add your model files to:")
        print(f" - {MODEL_PATH}")
        print(f" - {ENCODERS_DIR}/*.json")
        print(f" - {PRODUCT_MASTER_PATH}")
        engine = None

    print(f"✅ Startup completed in {time.time() - start_time:.2f} seconds.")
    yield

    print("🔄 Gohan FastAPI is shutting down.")
92
+
# Initialize FastAPI app with lifespan-based startup/shutdown handling.
# BUG FIX: the original title/description f-strings referenced an undefined
# `app_name` (NameError at import time); literal text is used instead.
app = FastAPI(
    title="Gohan Product Recommendation API",
    description=(
        "FastAPI version of the Gohan recommendation system - maintains "
        "exact same functionality as Gradio version"
    ),
    version="2.0.0",
    lifespan=lifespan,
)
100
 
# Target input fields (same as Gradio version).
# English field names that must all be present in the incoming company JSON;
# the predict endpoint rejects requests missing any of these with HTTP 400.
REQUIRED_FIELDS_EN = [
    'INDUSTRY', 'EMPLOYEE_RANGE', 'FRIDGE_RANGE', 'PAYMENT_METHOD', 'PREFECTURE',
    'FIRST_YEAR', 'FIRST_MONTH', 'LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO',
    'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3'
]
107
 
@app.get("/")
def root():
    """Root endpoint: basic API info plus current model status.

    BUG FIX: the original message f-string referenced an undefined
    `app_name` (NameError on every request); a literal is used instead.
    """
    return {
        "message": "🍚 Gohan Product Recommendation API (FastAPI)",
        "status": "running",
        "version": "2.0.0",
        "endpoints": ["/status", "/predict", "/predict_simple"],
        "model_status": "loaded" if engine else "not_loaded",
        "model_files_exist": model_files_exist,
    }
118
+
@app.get("/status")
def get_status():
    """Health check: 200 with model/path info when ready, 503 otherwise."""
    if engine is None:
        # Distinguish "files present but load failed" from "template mode"
        # (files missing entirely); both are reported as 503.
        detail = (
            "Model not loaded - check model files"
            if model_files_exist
            else "Model files not found - this is a template. Add your model files to enable predictions."
        )
        raise HTTPException(status_code=503, detail=detail)

    return {
        "status": "ready",
        "model_loaded": engine is not None,
        "model_files_exist": model_files_exist,
        "model_path": MODEL_PATH,
        "encoders_dir": ENCODERS_DIR,
        "product_master_path": PRODUCT_MASTER_PATH,
    }
141
+
142
+ @app.post("/predict", response_model=PredictionResponse)
143
+ def predict(request: PredictionRequest):
144
  """
145
  Predict gohan categories for a company (CID-based)
146
+ This is the EXACT same logic as the Gradio version
 
 
 
 
147
  """
148
  try:
149
  if engine is None:
 
152
  else:
153
  error_msg = "Model files not found - this is a template. Add your model files to enable predictions."
154
 
155
+ raise HTTPException(
156
+ status_code=503,
157
+ detail=error_msg
158
+ )
 
 
 
 
 
 
159
 
160
  # Parse input
161
  try:
162
+ incoming = json.loads(request.company_data_json)
163
  except json.JSONDecodeError as e:
164
+ raise HTTPException(
165
+ status_code=400,
166
+ detail=f"Invalid JSON format: {str(e)}"
167
+ )
 
168
 
169
  # topK handling
170
+ if request.topK is not None and request.topK > 0:
171
+ incoming["topK"] = int(request.topK)
172
  else:
173
  incoming.setdefault("topK", 30)
174
 
175
  # Validate English field presence
176
  missing_en = [f for f in REQUIRED_FIELDS_EN if f not in incoming]
177
  if missing_en:
178
+ raise HTTPException(
179
+ status_code=400,
180
+ detail=f"Missing required fields: {missing_en}"
181
+ )
 
182
 
183
  # Predict
184
  recommendations = engine.predict(incoming)
 
186
  if len(recommendations) > requested_k:
187
  recommendations = recommendations[:requested_k]
188
 
189
+ return PredictionResponse(
190
+ status="success",
191
+ model="gohan",
192
+ recommendations=recommendations,
193
+ metadata={
194
  "model_version": "gohan_cid_v1.0",
195
  "total_categories": len(recommendations),
196
  "requested_k": requested_k
197
  }
198
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ except HTTPException:
201
+ raise
202
+ except Exception as e:
203
+ raise HTTPException(
204
+ status_code=500,
205
+ detail=f"Prediction error: {str(e)}"
206
+ )
207
+
@app.post("/predict_simple", response_model=PredictionResponse)
def predict_simple(request: PredictionRequest):
    """Simple endpoint without topK parameter - same as Gradio version.

    BUG FIX: the original forwarded the request object unchanged, so a
    client-supplied topK was silently honoured even though this endpoint is
    documented as topK-free (the Gradio version called predict(json, None)).
    Strip topK before delegating so the default of 30 applies.
    """
    simple_request = PredictionRequest(
        company_data_json=request.company_data_json,
        topK=None,
    )
    return predict(simple_request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
if __name__ == "__main__":
    # Local development entry point; Hugging Face Spaces expects port 7860.
    import uvicorn

    serve_host = "0.0.0.0"
    serve_port = 7860
    uvicorn.run(app, host=serve_host, port=serve_port)
requirements.txt CHANGED
@@ -1,6 +1,9 @@
1
- gradio>=4.0.0
 
 
2
  torch==2.8.0
3
  git+https://github.com/Yura52/rtdl.git@main
4
  pandas
5
  numpy
6
  scipy
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ pydantic==2.5.0
4
  torch==2.8.0
5
  git+https://github.com/Yura52/rtdl.git@main
6
  pandas
7
  numpy
8
  scipy
9
+ python-multipart==0.0.6