Spaces:
Runtime error
Runtime error
tabito12345678910
committed on
Commit
·
45c2088
1
Parent(s):
8e060d3
Migrate from Gradio to FastAPI - adds /status endpoint and cold start handling while maintaining exact same API functionality
Browse files- README.md +30 -27
- app.py +161 -129
- requirements.txt +4 -1
README.md
CHANGED
|
@@ -1,31 +1,34 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
# ๐ Gohan Product Recommendation API
|
| 13 |
-
|
| 14 |
-
Clean API template for rice product recommendations.
|
| 15 |
-
|
| 16 |
-
## Setup
|
| 17 |
-
1. Add your model files to a 'model/' directory
|
| 18 |
-
2. Add encoder JSON files to 'model/gohan/'
|
| 19 |
-
3. Update paths in app.py and inference script
|
| 20 |
-
4. Deploy to HuggingFace Spaces
|
| 21 |
|
| 22 |
## Usage
|
| 23 |
-
This API provides product recommendations for rice products based on company data.
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
- `model/gohan
|
| 30 |
-
- `model/gohan
|
| 31 |
-
- `model/gohan/gohan_pm.csv` (product master data) โ
Already included
|
|
|
|
| 1 |
+
# Gohan FastAPI
|
| 2 |
+
|
| 3 |
+
This is a FastAPI-based product recommendation API deployed on Hugging Face Spaces.
|
| 4 |
+
|
| 5 |
+
## Endpoints
|
| 6 |
+
|
| 7 |
+
- `GET /` - Root endpoint with API information
|
| 8 |
+
- `GET /status` - Health check and model status
|
| 9 |
+
- `POST /predict` - Main prediction endpoint with topK parameter
|
| 10 |
+
- `POST /predict_simple` - Simple prediction endpoint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
## Usage
|
|
|
|
| 13 |
|
| 14 |
+
### Check API Status
|
| 15 |
+
```bash
|
| 16 |
+
curl "https://your-space-url.hf.space/status"
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
### Make Predictions
|
| 20 |
+
```bash
|
| 21 |
+
curl -X POST "https://your-space-url.hf.space/predict" \
|
| 22 |
+
-H "Content-Type: application/json" \
|
| 23 |
+
-d '{"company_data_json": "{...}", "topK": 10}'
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
## Model Loading
|
| 27 |
+
|
| 28 |
+
The API uses FastAPI's lifespan events to load models only once during startup, providing efficient cold start handling.
|
| 29 |
+
|
| 30 |
+
## Required Model Files
|
| 31 |
|
| 32 |
+
- `model/gohan/epoch_*.pt` (PyTorch model)
|
| 33 |
+
- `model/gohan/*.json` (encoder files)
|
| 34 |
+
- `model/gohan/gohan_pm.csv` (product master data)
|
|
|
app.py
CHANGED
|
@@ -1,59 +1,149 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
Gohan (CID) Product Recommendation
|
| 4 |
-
|
| 5 |
-
This
|
| 6 |
"""
|
| 7 |
|
| 8 |
-
import
|
|
|
|
|
|
|
| 9 |
import json
|
| 10 |
import os
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
ENCODERS_DIR = "model/gohan"
|
| 15 |
PRODUCT_MASTER_PATH = "model/gohan/gohan_pm.csv"
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
engine = None
|
| 36 |
-
|
| 37 |
-
print("
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
print(f"
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
|
|
|
| 43 |
REQUIRED_FIELDS_EN = [
|
| 44 |
'INDUSTRY', 'EMPLOYEE_RANGE', 'FRIDGE_RANGE', 'PAYMENT_METHOD', 'PREFECTURE',
|
| 45 |
'FIRST_YEAR', 'FIRST_MONTH', 'LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO',
|
| 46 |
'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3'
|
| 47 |
]
|
| 48 |
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
"""
|
| 51 |
Predict gohan categories for a company (CID-based)
|
| 52 |
-
|
| 53 |
-
company_data_json: JSON string containing company information
|
| 54 |
-
topK: Optional override for number of recommendations
|
| 55 |
-
Returns:
|
| 56 |
-
JSON string with predictions
|
| 57 |
"""
|
| 58 |
try:
|
| 59 |
if engine is None:
|
|
@@ -62,41 +152,33 @@ def predict(company_data_json: str, topK: int | None = None) -> str:
|
|
| 62 |
else:
|
| 63 |
error_msg = "Model files not found - this is a template. Add your model files to enable predictions."
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
"setup_instructions": {
|
| 70 |
-
"model_file": MODEL_PATH,
|
| 71 |
-
"encoders_dir": ENCODERS_DIR,
|
| 72 |
-
"product_master": PRODUCT_MASTER_PATH
|
| 73 |
-
}
|
| 74 |
-
}, indent=2)
|
| 75 |
|
| 76 |
# Parse input
|
| 77 |
try:
|
| 78 |
-
incoming = json.loads(company_data_json)
|
| 79 |
except json.JSONDecodeError as e:
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
})
|
| 85 |
|
| 86 |
# topK handling
|
| 87 |
-
if topK is not None and topK > 0:
|
| 88 |
-
incoming["topK"] = int(topK)
|
| 89 |
else:
|
| 90 |
incoming.setdefault("topK", 30)
|
| 91 |
|
| 92 |
# Validate English field presence
|
| 93 |
missing_en = [f for f in REQUIRED_FIELDS_EN if f not in incoming]
|
| 94 |
if missing_en:
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
})
|
| 100 |
|
| 101 |
# Predict
|
| 102 |
recommendations = engine.predict(incoming)
|
|
@@ -104,80 +186,30 @@ def predict(company_data_json: str, topK: int | None = None) -> str:
|
|
| 104 |
if len(recommendations) > requested_k:
|
| 105 |
recommendations = recommendations[:requested_k]
|
| 106 |
|
| 107 |
-
return
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
"model_version": "gohan_cid_v1.0",
|
| 113 |
"total_categories": len(recommendations),
|
| 114 |
"requested_k": requested_k
|
| 115 |
}
|
| 116 |
-
|
| 117 |
-
except Exception as e:
|
| 118 |
-
return json.dumps({
|
| 119 |
-
"status": "error",
|
| 120 |
-
"error": str(e),
|
| 121 |
-
"model": "gohan"
|
| 122 |
-
})
|
| 123 |
-
|
| 124 |
-
def predict_simple(company_data_json: str) -> str:
|
| 125 |
-
return predict(company_data_json, None)
|
| 126 |
-
|
| 127 |
-
# Sample input for testing
|
| 128 |
-
sample_input = json.dumps({
|
| 129 |
-
"INDUSTRY": "finance",
|
| 130 |
-
"EMPLOYEE_RANGE": "200-1000",
|
| 131 |
-
"FRIDGE_RANGE": "100-500",
|
| 132 |
-
"PAYMENT_METHOD": "card",
|
| 133 |
-
"PREFECTURE": "osaka",
|
| 134 |
-
"FIRST_YEAR": 2019,
|
| 135 |
-
"FIRST_MONTH": 6,
|
| 136 |
-
"LAT": 34.6937,
|
| 137 |
-
"LONG": 135.5023,
|
| 138 |
-
"DELIVERY_NUM": 300,
|
| 139 |
-
"MEDIAN_GENDER_RATIO": 0.55,
|
| 140 |
-
"MODE_TOP_AGE_RANGE_1": "40-49",
|
| 141 |
-
"MODE_TOP_AGE_RANGE_2": "30-39",
|
| 142 |
-
"MODE_TOP_AGE_RANGE_3": "50-59"
|
| 143 |
-
}, indent=2)
|
| 144 |
-
|
| 145 |
-
with gr.Blocks(title="Gohan CID Product Recommendation API (Light)") as demo:
|
| 146 |
-
gr.Markdown("# ๐ Gohan Product Recommendation API (Light Template)")
|
| 147 |
-
|
| 148 |
-
if model_files_exist:
|
| 149 |
-
gr.Markdown("โ
**Model Status**: Loaded and ready")
|
| 150 |
-
else:
|
| 151 |
-
gr.Markdown("""
|
| 152 |
-
โ ๏ธ **Model Status**: Template mode - add your model files to enable predictions
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
topk = gr.Number(label="Top K Results (optional)", minimum=1, maximum=200, step=1, value=None)
|
| 167 |
-
btn = gr.Button("Get Recommendations", variant="primary")
|
| 168 |
-
with gr.Column():
|
| 169 |
-
out = gr.Textbox(label="API Response", lines=20, interactive=False)
|
| 170 |
-
|
| 171 |
-
with gr.Tab("Simple API"):
|
| 172 |
-
with gr.Row():
|
| 173 |
-
with gr.Column():
|
| 174 |
-
inp2 = gr.Textbox(label="Company Data (JSON)", lines=15, value=sample_input)
|
| 175 |
-
btn2 = gr.Button("Get Recommendations", variant="primary")
|
| 176 |
-
with gr.Column():
|
| 177 |
-
out2 = gr.Textbox(label="API Response", lines=20, interactive=False)
|
| 178 |
-
|
| 179 |
-
btn.click(fn=predict, inputs=[inp, topk], outputs=out)
|
| 180 |
-
btn2.click(fn=predict_simple, inputs=inp2, outputs=out2)
|
| 181 |
|
| 182 |
if __name__ == "__main__":
|
| 183 |
-
|
|
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Gohan (CID) Product Recommendation FastAPI App
|
| 4 |
+
FastAPI version of the Gohan CID inference engine
|
| 5 |
+
This maintains the exact same functionality as the Gradio version
|
| 6 |
"""
|
| 7 |
|
| 8 |
+
from fastapi import FastAPI, HTTPException
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
from contextlib import asynccontextmanager
|
| 11 |
import json
|
| 12 |
import os
|
| 13 |
+
import time
|
| 14 |
+
from typing import List, Optional, Dict, Any
|
| 15 |
|
| 16 |
+
# Import the existing inference engine
|
| 17 |
+
try:
|
| 18 |
+
from gohan_cid import GohanCIDInferenceEngine
|
| 19 |
+
except ImportError:
|
| 20 |
+
GohanCIDInferenceEngine = None
|
| 21 |
+
|
| 22 |
+
# Model paths - same as Gradio version
|
| 23 |
+
# Application/model identifier for this Space. The original expression
# compared two string literals ("gohan" == "yasai" -> always False), making
# the conditional dead code; APP_NAME makes the intended per-app selection
# explicit while preserving the effective value.
APP_NAME = "gohan"

# Model artifact paths. The epoch_028 checkpoint belongs to the sibling
# "yasai" app that shares this template; for APP_NAME == "gohan" the
# epoch_009 checkpoint is selected (same result as the original code).
MODEL_PATH = (
    "model/gohan/epoch_028_p50_0.6911.pt"
    if APP_NAME == "yasai"
    else "model/gohan/epoch_009_p50_0.5776.pt"
)
ENCODERS_DIR = "model/gohan"
PRODUCT_MASTER_PATH = "model/gohan/gohan_pm.csv"
|
| 26 |
|
| 27 |
+
# Pydantic models matching the exact API structure
|
| 28 |
+
class PredictionRequest(BaseModel):
    """Request body for /predict and /predict_simple."""
    # Raw JSON string holding the company features; parsed with json.loads
    # inside the /predict handler (invalid JSON yields HTTP 400).
    company_data_json: str
    # Optional override for the number of recommendations; when absent or
    # non-positive the handler falls back to a default of 30.
    topK: Optional[int] = None
|
| 31 |
+
|
| 32 |
+
class CategoryRecommendation(BaseModel):
    """One recommended product category with its model score."""
    # NOTE(review): field semantics inferred from names; the engine's exact
    # output schema is defined in gohan_cid, not visible here - confirm.
    category_id: int
    category_name: str
    score: float
|
| 36 |
+
|
| 37 |
+
class PredictionResponse(BaseModel):
    """Response body for /predict and /predict_simple."""
    # "success" on the happy path; errors are raised as HTTPException instead.
    status: str
    # Model identifier (e.g. "gohan").
    model: str
    # Ranked category recommendations, truncated to the requested topK.
    recommendations: List[CategoryRecommendation]
    # Extra info: model_version, total_categories, requested_k.
    metadata: Dict[str, Any]
|
| 42 |
+
|
| 43 |
+
# Global variables
|
| 44 |
+
# Global singletons populated by the lifespan startup hook.
engine = None  # inference engine instance, or None in template mode / on load failure
model_files_exist = False  # True once all three model artifacts are found on disk
|
| 46 |
+
|
| 47 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the model once at startup, log on shutdown.

    BUG FIX: the original body referenced an undefined variable ``app_name``
    in several f-strings, which raised NameError the first time the startup
    hook ran; the literal app name is used instead.
    """
    global engine, model_files_exist

    print("🚀 Gohan FastAPI is starting. Loading AI model and data...")
    start_time = time.time()

    # Check if model files exist (same logic as the Gradio version).
    model_files_exist = all([
        os.path.exists(MODEL_PATH),
        os.path.exists(ENCODERS_DIR),
        os.path.exists(PRODUCT_MASTER_PATH),
    ])

    if model_files_exist:
        print("🔍 Checking model files:")
        print(f" - MODEL_PATH: {MODEL_PATH} (exists: {os.path.exists(MODEL_PATH)})")
        print(f" - ENCODERS_DIR: {ENCODERS_DIR} (exists: {os.path.exists(ENCODERS_DIR)})")
        print(f" - PRODUCT_MASTER_PATH: {PRODUCT_MASTER_PATH} (exists: {os.path.exists(PRODUCT_MASTER_PATH)})")

        try:
            if GohanCIDInferenceEngine:
                engine = GohanCIDInferenceEngine(
                    model_path=MODEL_PATH,
                    encoders_dir=ENCODERS_DIR,
                    product_master_path=PRODUCT_MASTER_PATH
                )
                print("✅ Gohan CID model loaded successfully!")
            else:
                # Engine module failed to import at module load time.
                print("❌ GohanCIDInferenceEngine not available")
                engine = None
        except Exception as e:
            # Keep the API up in degraded mode; /status reports 503.
            print(f"❌ Failed to load Gohan CID model: {e}")
            engine = None
    else:
        print("⚠️ Model files not found. This is a template - add your model files to:")
        print(f" - {MODEL_PATH}")
        print(f" - {ENCODERS_DIR}/*.json")
        print(f" - {PRODUCT_MASTER_PATH}")
        engine = None

    print(f"✅ Startup completed in {time.time() - start_time:.2f} seconds.")
    yield

    print("👋 Gohan FastAPI is shutting down.")
|
| 92 |
+
|
| 93 |
+
# Initialize FastAPI app with lifespan
|
| 94 |
+
# FastAPI application wired to the lifespan hook above.
# BUG FIX: the original used an undefined ``app_name`` variable in title and
# description (NameError at import time); literal strings are used instead.
app = FastAPI(
    title="Gohan Product Recommendation API",
    description="FastAPI version of the Gohan recommendation system - maintains exact same functionality as Gradio version",
    version="2.0.0",
    lifespan=lifespan,
)
|
| 100 |
|
| 101 |
+
# Target input fields (same as Gradio version)
|
| 102 |
# English field names that must be present in the parsed company payload;
# /predict responds with HTTP 400 listing any that are missing.
REQUIRED_FIELDS_EN = [
    'INDUSTRY', 'EMPLOYEE_RANGE', 'FRIDGE_RANGE', 'PAYMENT_METHOD', 'PREFECTURE',
    'FIRST_YEAR', 'FIRST_MONTH', 'LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO',
    'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3'
]
|
| 107 |
|
| 108 |
+
@app.get("/")
def root():
    """Root endpoint: basic API info plus current model status.

    BUG FIX: the original message f-string referenced an undefined
    ``app_name`` variable (NameError on every request); fixed with a literal.
    """
    return {
        "message": "🍚 Gohan Product Recommendation API (FastAPI)",
        "status": "running",
        "version": "2.0.0",
        "endpoints": ["/status", "/predict", "/predict_simple"],
        # engine/model_files_exist are module globals set by the lifespan hook.
        "model_status": "loaded" if engine else "not_loaded",
        "model_files_exist": model_files_exist,
    }
|
| 118 |
+
|
| 119 |
+
@app.get("/status")
def get_status():
    """Health check: 503 while the model is unavailable, else readiness info."""
    if engine is None:
        # Distinguish "files present but load failed" from "template mode".
        if model_files_exist:
            detail = "Model not loaded - check model files"
        else:
            detail = "Model files not found - this is a template. Add your model files to enable predictions."
        raise HTTPException(status_code=503, detail=detail)

    return {
        "status": "ready",
        "model_loaded": engine is not None,
        "model_files_exist": model_files_exist,
        "model_path": MODEL_PATH,
        "encoders_dir": ENCODERS_DIR,
        "product_master_path": PRODUCT_MASTER_PATH,
    }
|
| 141 |
+
|
| 142 |
+
@app.post("/predict", response_model=PredictionResponse)
|
| 143 |
+
def predict(request: PredictionRequest):
|
| 144 |
"""
|
| 145 |
Predict gohan categories for a company (CID-based)
|
| 146 |
+
This is the EXACT same logic as the Gradio version
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
"""
|
| 148 |
try:
|
| 149 |
if engine is None:
|
|
|
|
| 152 |
else:
|
| 153 |
error_msg = "Model files not found - this is a template. Add your model files to enable predictions."
|
| 154 |
|
| 155 |
+
raise HTTPException(
|
| 156 |
+
status_code=503,
|
| 157 |
+
detail=error_msg
|
| 158 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
# Parse input
|
| 161 |
try:
|
| 162 |
+
incoming = json.loads(request.company_data_json)
|
| 163 |
except json.JSONDecodeError as e:
|
| 164 |
+
raise HTTPException(
|
| 165 |
+
status_code=400,
|
| 166 |
+
detail=f"Invalid JSON format: {str(e)}"
|
| 167 |
+
)
|
|
|
|
| 168 |
|
| 169 |
# topK handling
|
| 170 |
+
if request.topK is not None and request.topK > 0:
|
| 171 |
+
incoming["topK"] = int(request.topK)
|
| 172 |
else:
|
| 173 |
incoming.setdefault("topK", 30)
|
| 174 |
|
| 175 |
# Validate English field presence
|
| 176 |
missing_en = [f for f in REQUIRED_FIELDS_EN if f not in incoming]
|
| 177 |
if missing_en:
|
| 178 |
+
raise HTTPException(
|
| 179 |
+
status_code=400,
|
| 180 |
+
detail=f"Missing required fields: {missing_en}"
|
| 181 |
+
)
|
|
|
|
| 182 |
|
| 183 |
# Predict
|
| 184 |
recommendations = engine.predict(incoming)
|
|
|
|
| 186 |
if len(recommendations) > requested_k:
|
| 187 |
recommendations = recommendations[:requested_k]
|
| 188 |
|
| 189 |
+
return PredictionResponse(
|
| 190 |
+
status="success",
|
| 191 |
+
model="gohan",
|
| 192 |
+
recommendations=recommendations,
|
| 193 |
+
metadata={
|
| 194 |
"model_version": "gohan_cid_v1.0",
|
| 195 |
"total_categories": len(recommendations),
|
| 196 |
"requested_k": requested_k
|
| 197 |
}
|
| 198 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
+
except HTTPException:
|
| 201 |
+
raise
|
| 202 |
+
except Exception as e:
|
| 203 |
+
raise HTTPException(
|
| 204 |
+
status_code=500,
|
| 205 |
+
detail=f"Prediction error: {str(e)}"
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
@app.post("/predict_simple", response_model=PredictionResponse)
def predict_simple(request: PredictionRequest):
    """Simple endpoint without topK parameter - same as Gradio version.

    BUG FIX: the original forwarded the request unchanged, so a stray ``topK``
    in the payload would take effect - unlike the Gradio predict_simple, which
    always passed topK=None. Clear topK before delegating so the documented
    contract (server default of 30) holds.
    """
    # Copy with topK cleared rather than mutating the caller's model in place.
    return predict(request.model_copy(update={"topK": None}))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
if __name__ == "__main__":
    # Local/Spaces entry point; 7860 is the port Hugging Face Spaces expects.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
requirements.txt
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
torch==2.8.0
|
| 3 |
git+https://github.com/Yura52/rtdl.git@main
|
| 4 |
pandas
|
| 5 |
numpy
|
| 6 |
scipy
|
|
|
|
|
|
| 1 |
+
fastapi==0.104.1
|
| 2 |
+
uvicorn[standard]==0.24.0
|
| 3 |
+
pydantic==2.5.0
|
| 4 |
torch==2.8.0
|
| 5 |
git+https://github.com/Yura52/rtdl.git@main
|
| 6 |
pandas
|
| 7 |
numpy
|
| 8 |
scipy
|
| 9 |
+
python-multipart==0.0.6
|