jkim03 commited on
Commit
65afcb7
·
verified ·
1 Parent(s): c126e79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -1
app.py CHANGED
@@ -32,11 +32,27 @@ app = FastAPI()
32
  @app.post("/predict")
33
  def predict(input_data: FatigueInput):
34
  try:
 
35
  input_dict = input_data.dict()
36
  input_df = pd.DataFrame([input_dict])[feature_cols]
 
 
37
  scaled_input = scaler.transform(input_df)
 
 
38
  cluster = gmm.predict(scaled_input)[0]
 
 
 
39
  fatigue_level = cluster_fatigue_map[str(cluster)]
40
- return {"cluster": int(cluster), "fatigue_level": fatigue_level}
 
 
 
 
 
 
 
 
41
  except Exception as e:
42
  return {"error": str(e)}
 
32
  @app.post("/predict")
33
  def predict(input_data: FatigueInput):
34
  try:
35
+ # Convert input to DataFrame
36
  input_dict = input_data.dict()
37
  input_df = pd.DataFrame([input_dict])[feature_cols]
38
+
39
+ # Scale the input
40
  scaled_input = scaler.transform(input_df)
41
+
42
+ # Get cluster prediction and probabilities
43
  cluster = gmm.predict(scaled_input)[0]
44
+ probs = gmm.predict_proba(scaled_input)[0]
45
+
46
+ # Map cluster to fatigue level
47
  fatigue_level = cluster_fatigue_map[str(cluster)]
48
+
49
+ # Return more detailed prediction info
50
+ return {
51
+ "cluster": int(cluster),
52
+ "fatigue_level": fatigue_level,
53
+ "cluster_probabilities": {str(i): float(prob) for i, prob in enumerate(probs)},
54
+ "input_features": input_dict
55
+ }
56
+
57
  except Exception as e:
58
  return {"error": str(e)}