jkim03 commited on
Commit
6d565ea
·
verified ·
1 Parent(s): 60d1870

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -27
app.py CHANGED
@@ -3,25 +3,25 @@ import joblib
3
  import pandas as pd
4
  import json
5
 
6
- # Load your model and preprocessing tools
7
  scaler = joblib.load("scaler.joblib")
8
  gmm = joblib.load("gmm_model.joblib")
 
9
  with open("cluster_fatigue_map.json") as f:
10
  cluster_fatigue_map = json.load(f)
11
 
12
- # Define the input features
13
  feature_cols = [
14
  'AVRR', 'SDNN', 'RMSSD', 'PNN50', 'Coefficient_of_Variation',
15
  'Age', 'Weight', 'Height'
16
  ]
17
 
18
- # Main prediction function
19
  def predict_fatigue(
20
  AVRR, SDNN, RMSSD, PNN50, Coefficient_of_Variation,
21
  Age, Weight, Height
22
  ):
23
  try:
24
- # Prepare input as dict
25
  input_dict = {
26
  'AVRR': AVRR,
27
  'SDNN': SDNN,
@@ -33,35 +33,35 @@ def predict_fatigue(
33
  'Height': Height
34
  }
35
 
36
- # Convert to DataFrame and scale
37
  df = pd.DataFrame([input_dict])[feature_cols]
38
  scaled = scaler.transform(df)
39
-
40
- # Predict cluster and map to fatigue level
41
  cluster = gmm.predict(scaled)[0]
42
  fatigue_level = cluster_fatigue_map[str(cluster)]
43
 
44
- return f"Predicted Fatigue Level: {fatigue_level} (Cluster {cluster})"
 
 
 
45
 
46
  except Exception as e:
47
- return f"Error: {str(e)}"
48
 
49
- # Gradio UI definition
50
- iface = gr.Interface(
51
- fn=predict_fatigue,
52
- inputs=[
53
- gr.Number(label='AVRR'),
54
- gr.Number(label='SDNN'),
55
- gr.Number(label='RMSSD'),
56
- gr.Number(label='PNN50'),
57
- gr.Number(label='Coefficient of Variation'),
58
- gr.Number(label='Age'),
59
- gr.Number(label='Weight'),
60
- gr.Number(label='Height'),
61
- ],
62
- outputs="text",
63
- title="Fatigue Level Predictor",
64
- description="Enter HRV metrics and demographic data to estimate fatigue level using a Gaussian Mixture Model."
65
- )
66
 
67
- iface.launch()
 
3
  import pandas as pd
4
  import json
5
 
6
+ # Load the model, scaler, and mapping
7
  scaler = joblib.load("scaler.joblib")
8
  gmm = joblib.load("gmm_model.joblib")
9
+
10
  with open("cluster_fatigue_map.json") as f:
11
  cluster_fatigue_map = json.load(f)
12
 
13
+ # Define expected input features
14
  feature_cols = [
15
  'AVRR', 'SDNN', 'RMSSD', 'PNN50', 'Coefficient_of_Variation',
16
  'Age', 'Weight', 'Height'
17
  ]
18
 
19
+ # Inference function — accepts Python list of values
20
  def predict_fatigue(
21
  AVRR, SDNN, RMSSD, PNN50, Coefficient_of_Variation,
22
  Age, Weight, Height
23
  ):
24
  try:
 
25
  input_dict = {
26
  'AVRR': AVRR,
27
  'SDNN': SDNN,
 
33
  'Height': Height
34
  }
35
 
 
36
  df = pd.DataFrame([input_dict])[feature_cols]
37
  scaled = scaler.transform(df)
 
 
38
  cluster = gmm.predict(scaled)[0]
39
  fatigue_level = cluster_fatigue_map[str(cluster)]
40
 
41
+ return {
42
+ "cluster": int(cluster),
43
+ "fatigue_level": fatigue_level
44
+ }
45
 
46
  except Exception as e:
47
+ return {"error": str(e)}
48
 
49
+ # Gradio app in REST-friendly format
50
+ with gr.Blocks() as demo:
51
+ gr.Interface(
52
+ fn=predict_fatigue,
53
+ inputs=[
54
+ gr.Number(label='AVRR'),
55
+ gr.Number(label='SDNN'),
56
+ gr.Number(label='RMSSD'),
57
+ gr.Number(label='PNN50'),
58
+ gr.Number(label='Coefficient of Variation'),
59
+ gr.Number(label='Age'),
60
+ gr.Number(label='Weight'),
61
+ gr.Number(label='Height'),
62
+ ],
63
+ outputs="json",
64
+ live=False
65
+ )
66
 
67
+ demo.launch()