Spaces:
Sleeping
Sleeping
Update modelLoanAPI.py
Browse files- modelLoanAPI.py +11 -8
modelLoanAPI.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
|
|
| 1 |
from fastapi import FastAPI, HTTPException
|
| 2 |
from fastapi.responses import JSONResponse, FileResponse
|
| 3 |
import pandas as pd
|
| 4 |
import numpy as np
|
| 5 |
from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor
|
| 6 |
from sklearn.preprocessing import RobustScaler, LabelEncoder
|
| 7 |
-
from sklearn.metrics import accuracy_score
|
| 8 |
from sklearn.model_selection import TimeSeriesSplit, GridSearchCV
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
import json
|
|
@@ -36,7 +36,7 @@ async def worker_forecast(input_data: WorkerInput):
|
|
| 36 |
|
| 37 |
# Load dataset
|
| 38 |
try:
|
| 39 |
-
df = pd.read_csv('extended_worker_dataset_random_reduced.csv')
|
| 40 |
except FileNotFoundError:
|
| 41 |
raise HTTPException(status_code=500, detail="Dataset file not found")
|
| 42 |
|
|
@@ -213,10 +213,13 @@ async def worker_forecast(input_data: WorkerInput):
|
|
| 213 |
plot_base64 = base64.b64encode(buf_jpg.getvalue()).decode("utf-8")
|
| 214 |
results['plot'] = f"data:image/jpeg;base64,{plot_base64}"
|
| 215 |
|
| 216 |
-
# Save plot to file
|
| 217 |
-
plot_filename = f"worker_{worker_id}_forecast.jpg"
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
# Worker profile
|
| 222 |
worker_data = df.copy()
|
|
@@ -247,8 +250,8 @@ async def worker_forecast(input_data: WorkerInput):
|
|
| 247 |
|
| 248 |
@app.get("/worker_forecast/plot/{worker_id}")
|
| 249 |
async def get_forecast_plot(worker_id: int):
|
| 250 |
-
plot_filename = f"worker_{worker_id}_forecast.jpg"
|
| 251 |
if os.path.exists(plot_filename):
|
| 252 |
return FileResponse(plot_filename, media_type="image/jpeg", filename=f"worker_{worker_id}_forecast.jpg")
|
| 253 |
else:
|
| 254 |
-
raise HTTPException(status_code=404, detail=f"Plot for worker_id {worker_id} not found")
|
|
|
|
| 1 |
+
|
| 2 |
from fastapi import FastAPI, HTTPException
|
| 3 |
from fastapi.responses import JSONResponse, FileResponse
|
| 4 |
import pandas as pd
|
| 5 |
import numpy as np
|
| 6 |
from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor
|
| 7 |
from sklearn.preprocessing import RobustScaler, LabelEncoder
|
|
|
|
| 8 |
from sklearn.model_selection import TimeSeriesSplit, GridSearchCV
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
import json
|
|
|
|
| 36 |
|
| 37 |
# Load dataset
|
| 38 |
try:
|
| 39 |
+
df = pd.read_csv('/app/extended_worker_dataset_random_reduced.csv')
|
| 40 |
except FileNotFoundError:
|
| 41 |
raise HTTPException(status_code=500, detail="Dataset file not found")
|
| 42 |
|
|
|
|
| 213 |
plot_base64 = base64.b64encode(buf_jpg.getvalue()).decode("utf-8")
|
| 214 |
results['plot'] = f"data:image/jpeg;base64,{plot_base64}"
|
| 215 |
|
| 216 |
+
# Save plot to file in /tmp
|
| 217 |
+
plot_filename = f"/tmp/worker_{worker_id}_forecast.jpg"
|
| 218 |
+
try:
|
| 219 |
+
with open(plot_filename, "wb") as f:
|
| 220 |
+
f.write(base64.b64decode(plot_base64))
|
| 221 |
+
except PermissionError as e:
|
| 222 |
+
raise HTTPException(status_code=500, detail=f"Failed to write plot file: {str(e)}")
|
| 223 |
|
| 224 |
# Worker profile
|
| 225 |
worker_data = df.copy()
|
|
|
|
| 250 |
|
| 251 |
@app.get("/worker_forecast/plot/{worker_id}")
|
| 252 |
async def get_forecast_plot(worker_id: int):
|
| 253 |
+
plot_filename = f"/tmp/worker_{worker_id}_forecast.jpg"
|
| 254 |
if os.path.exists(plot_filename):
|
| 255 |
return FileResponse(plot_filename, media_type="image/jpeg", filename=f"worker_{worker_id}_forecast.jpg")
|
| 256 |
else:
|
| 257 |
+
raise HTTPException(status_code=404, detail=f"Plot for worker_id {worker_id} not found")
|