sid22669 commited on
Commit
d7a8003
·
verified ·
1 Parent(s): 81dd9a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -19
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  import joblib
4
  import numpy as np
@@ -6,31 +6,43 @@ import numpy as np
6
  app = FastAPI()
7
 
8
  class InputData(BaseModel):
9
- input1 : float
10
- input2 : float
11
- input3 : float
12
- input4 : float
13
- input5 : float
14
- input6 : float
15
- input7 : float
16
 
17
- try:
 
18
  model = joblib.load('random_forest_model.joblib')
19
  status = 'Loaded'
20
- print(status)
21
- except:
22
- status = "not loaded"
23
- print(status)
24
 
25
  @app.get('/')
26
  def health_check():
27
- return {'status' : f'{status}'}
 
28
 
29
  @app.post('/predict')
30
- def predict(input : InputData):
 
 
 
 
 
31
  data = np.array([[input.input1, input.input2,
32
- input.input3, input.input4,
33
- input.input5, input.input6,
34
- input.input7]])
 
 
35
  prediction = model.predict(data).tolist()
36
- return {'prediction' : prediction[0]}
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  import joblib
4
  import numpy as np
 
6
  app = FastAPI()
7
 
8
  class InputData(BaseModel):
9
+ input1: float
10
+ input2: float
11
+ input3: float
12
+ input4: float
13
+ input5: float
14
+ input6: float
15
+ input7: float
16
 
17
+ # Load the model and handle potential errors gracefully
18
+ try:
19
  model = joblib.load('random_forest_model.joblib')
20
  status = 'Loaded'
21
+ print(f"Model {status}")
22
+ except Exception as e:
23
+ status = f"not loaded: {e}"
24
+ print(f"Model {status}")
25
 
26
  @app.get('/')
27
  def health_check():
28
+ # Return the current status of the app (whether the model is loaded or not)
29
+ return {'status': f'{status}'}
30
 
31
  @app.post('/predict')
32
+ def predict(input: InputData):
33
+ # Ensure the model is loaded before making predictions
34
+ if status != 'Loaded':
35
+ raise HTTPException(status_code=500, detail="Model not loaded. Please check the server logs.")
36
+
37
+ # Prepare the input data for prediction
38
  data = np.array([[input.input1, input.input2,
39
+ input.input3, input.input4,
40
+ input.input5, input.input6,
41
+ input.input7]])
42
+
43
+ # Make prediction using the loaded model
44
  prediction = model.predict(data).tolist()
45
+
46
+ # Return the prediction in JSON format
47
+ return {'prediction': prediction[0]}
48
+