jkim03 commited on
Commit
1e67d7e
·
verified ·
1 Parent(s): 017a9c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -1
app.py CHANGED
@@ -11,6 +11,9 @@ pca = joblib.load("pca.joblib")
11
  with open("cluster_fatigue_map.json") as f:
12
  cluster_fatigue_map = json.load(f)
13
 
 
 
 
14
  feature_cols = [
15
  "AVRR", "SDNN", "RMSSD", "PNN50", "Coefficient_of_Variation",
16
  "Age", "Weight", "Height"
@@ -45,7 +48,7 @@ def debug_predict(input_data: FatigueInput):
45
  "scaler_scale": scaler.scale_.tolist()
46
  }
47
 
48
- @app.post("/predict")
49
  def predict(input_data: FatigueInput):
50
  try:
51
  input_dict = input_data.dict()
@@ -73,3 +76,18 @@ def predict(input_data: FatigueInput):
73
  }
74
  except Exception as e:
75
  return {"error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  with open("cluster_fatigue_map.json") as f:
12
  cluster_fatigue_map = json.load(f)
13
 
14
+ # Load trained regressor
15
+ model = joblib.load("gb_regressor.joblib")
16
+
17
  feature_cols = [
18
  "AVRR", "SDNN", "RMSSD", "PNN50", "Coefficient_of_Variation",
19
  "Age", "Weight", "Height"
 
48
  "scaler_scale": scaler.scale_.tolist()
49
  }
50
 
51
+ @app.post("/predictGMM")
52
  def predict(input_data: FatigueInput):
53
  try:
54
  input_dict = input_data.dict()
 
76
  }
77
  except Exception as e:
78
  return {"error": str(e)}
79
+
80
+ @app.post("/predict")
81
+ def predict(input_data: FatigueInput):
82
+ try:
83
+ input_df = pd.DataFrame([input_data.dict()])[feature_cols]
84
+
85
+ # Predict fatigue level
86
+ predicted_fatigue = model.predict(input_df)[0]
87
+
88
+ return {
89
+ "predicted_fatigue_level": float(predicted_fatigue),
90
+ "rounded_level": int(round(predicted_fatigue))
91
+ }
92
+ except Exception as e:
93
+ return {"error": str(e)}