agentsay commited on
Commit
58ee41f
·
verified ·
1 Parent(s): 9a4a869

Update modelLoanAPI.py

Browse files
Files changed (1) hide show
  1. modelLoanAPI.py +11 -8
modelLoanAPI.py CHANGED
@@ -1,10 +1,10 @@
 
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.responses import JSONResponse, FileResponse
3
  import pandas as pd
4
  import numpy as np
5
  from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor
6
  from sklearn.preprocessing import RobustScaler, LabelEncoder
7
- from sklearn.metrics import accuracy_score
8
  from sklearn.model_selection import TimeSeriesSplit, GridSearchCV
9
  import matplotlib.pyplot as plt
10
  import json
@@ -36,7 +36,7 @@ async def worker_forecast(input_data: WorkerInput):
36
 
37
  # Load dataset
38
  try:
39
- df = pd.read_csv('extended_worker_dataset_random_reduced.csv')
40
  except FileNotFoundError:
41
  raise HTTPException(status_code=500, detail="Dataset file not found")
42
 
@@ -213,10 +213,13 @@ async def worker_forecast(input_data: WorkerInput):
213
  plot_base64 = base64.b64encode(buf_jpg.getvalue()).decode("utf-8")
214
  results['plot'] = f"data:image/jpeg;base64,{plot_base64}"
215
 
216
- # Save plot to file
217
- plot_filename = f"worker_{worker_id}_forecast.jpg"
218
- with open(plot_filename, "wb") as f:
219
- f.write(base64.b64decode(plot_base64))
 
 
 
220
 
221
  # Worker profile
222
  worker_data = df.copy()
@@ -247,8 +250,8 @@ async def worker_forecast(input_data: WorkerInput):
247
 
248
  @app.get("/worker_forecast/plot/{worker_id}")
249
  async def get_forecast_plot(worker_id: int):
250
- plot_filename = f"worker_{worker_id}_forecast.jpg"
251
  if os.path.exists(plot_filename):
252
  return FileResponse(plot_filename, media_type="image/jpeg", filename=f"worker_{worker_id}_forecast.jpg")
253
  else:
254
- raise HTTPException(status_code=404, detail=f"Plot for worker_id {worker_id} not found")
 
1
+
2
  from fastapi import FastAPI, HTTPException
3
  from fastapi.responses import JSONResponse, FileResponse
4
  import pandas as pd
5
  import numpy as np
6
  from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor
7
  from sklearn.preprocessing import RobustScaler, LabelEncoder
 
8
  from sklearn.model_selection import TimeSeriesSplit, GridSearchCV
9
  import matplotlib.pyplot as plt
10
  import json
 
36
 
37
  # Load dataset
38
  try:
39
+ df = pd.read_csv('/app/extended_worker_dataset_random_reduced.csv')
40
  except FileNotFoundError:
41
  raise HTTPException(status_code=500, detail="Dataset file not found")
42
 
 
213
  plot_base64 = base64.b64encode(buf_jpg.getvalue()).decode("utf-8")
214
  results['plot'] = f"data:image/jpeg;base64,{plot_base64}"
215
 
216
+ # Save plot to file in /tmp
217
+ plot_filename = f"/tmp/worker_{worker_id}_forecast.jpg"
218
+ try:
219
+ with open(plot_filename, "wb") as f:
220
+ f.write(base64.b64decode(plot_base64))
221
+ except PermissionError as e:
222
+ raise HTTPException(status_code=500, detail=f"Failed to write plot file: {str(e)}")
223
 
224
  # Worker profile
225
  worker_data = df.copy()
 
250
 
251
  @app.get("/worker_forecast/plot/{worker_id}")
252
  async def get_forecast_plot(worker_id: int):
253
+ plot_filename = f"/tmp/worker_{worker_id}_forecast.jpg"
254
  if os.path.exists(plot_filename):
255
  return FileResponse(plot_filename, media_type="image/jpeg", filename=f"worker_{worker_id}_forecast.jpg")
256
  else:
257
+ raise HTTPException(status_code=404, detail=f"Plot for worker_id {worker_id} not found")