Amol Kaushik commited on
Commit
fc4219d
·
1 Parent(s): 926f4f7

gradio app

Browse files
Files changed (1) hide show
  1. A2/app.py +233 -1
A2/app.py CHANGED
@@ -1 +1,233 @@
1
- # placeholder for the app version 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import pickle
5
+ import os
6
+
7
+ MODEL_PATH = "models/champion_model_final.pkl"
8
+ DATA_PATH = "A2_dataset.csv"
9
+
10
+ model = None
11
+ FEATURE_NAMES = None
12
+ MODEL_METRICS = None
13
+
14
+
15
+ def load_champion_model():
16
+ global model, FEATURE_NAMES, MODEL_METRICS
17
+
18
+ possible_paths = [
19
+ MODEL_PATH,
20
+ "A2/models/champion_model_final.pkl",
21
+ "../A2/models/champion_model_final.pkl",
22
+ ]
23
+
24
+ for path in possible_paths:
25
+ if os.path.exists(path):
26
+ print(f"Loading champion model from {path}")
27
+ with open(path, "rb") as f:
28
+ artifact = pickle.load(f)
29
+
30
+ model = artifact["model"]
31
+ FEATURE_NAMES = artifact["feature_columns"]
32
+ MODEL_METRICS = artifact.get("test_metrics", {})
33
+
34
+ print(f"model loaded successfully")
35
+ print(f"Features: {len(FEATURE_NAMES)} columns")
36
+ print(f"Test R2: {MODEL_METRICS.get('r2', 'N/A')}")
37
+ return True
38
+
39
+ print("champion model not found")
40
+ return False
41
+
42
+
43
+ load_champion_model()
44
+
45
+
46
+ # prediction function
47
+ def predict_score(*feature_values):
48
+ if model is None:
49
+ return "Error", "Model not loaded"
50
+
51
+ # Convert inputs to dataframe with correct feature names
52
+ features_df = pd.DataFrame([feature_values], columns=FEATURE_NAMES)
53
+
54
+ raw_score = model.predict(features_df)[0]
55
+
56
+ # score to valid range and change to %
57
+ score = max(0, min(1, raw_score)) * 100
58
+
59
+ if score >= 80:
60
+ interpretation = "Excellent, great squat form"
61
+ elif score >= 60:
62
+ interpretation = "Good, minor improvements needed"
63
+ elif score >= 40:
64
+ interpretation = "Average, a lot of areas to work on"
65
+ else:
66
+ interpretation = "Needs work, focus on proper form"
67
+
68
+ # Create output
69
+ r2 = MODEL_METRICS.get('r2', 'N/A')
70
+ correlation = MODEL_METRICS.get('correlation', 'N/A')
71
+
72
+ # Format metrics
73
+ r2_str = f"{r2:.4f}" if isinstance(r2, (int, float)) else str(r2)
74
+ corr_str = f"{correlation:.4f}" if isinstance(correlation, (int, float)) else str(correlation)
75
+
76
+ details = f"""
77
+ ### Prediction Details
78
+ - **Raw Model Output:** {raw_score:.4f}
79
+ - **Normalized Score:** {score:.1f}%
80
+ - **Assessment:** {interpretation}
81
+
82
+ ### Model Performance
83
+ - **Test R-squared:** {r2_str}
84
+ - **Test Correlation:** {corr_str}
85
+
86
+ *Lower deviation values = better form*
87
+ """
88
+
89
+ return f"{score:.1f}%", interpretation, details
90
+
91
+
92
+ # load example for tesitng
93
+ def load_example():
94
+ if FEATURE_NAMES is None:
95
+ return [0.5] * 35
96
+
97
+ try:
98
+ possible_paths = [
99
+ DATA_PATH,
100
+ "A2/A2_dataset.csv",
101
+ "../A2/A2_dataset.csv",
102
+ "../Datasets_all/A2_dataset_80.csv",
103
+ ]
104
+
105
+ df = None
106
+ for path in possible_paths:
107
+ if os.path.exists(path):
108
+ df = pd.read_csv(path)
109
+ break
110
+
111
+ # Get a random row with only the features we need
112
+ available_features = [f for f in FEATURE_NAMES if f in df.columns]
113
+ sample = df[available_features].sample(1).values[0]
114
+ # Convert to float list to ensure proper types for Gradio sliders
115
+ return [float(x) for x in sample]
116
+ except Exception as e:
117
+ print(f"Error loading example: {e}")
118
+ return [0.5] * len(FEATURE_NAMES)
119
+
120
+
121
+ # create gradio interface
122
+ def create_interface():
123
+ if FEATURE_NAMES is None:
124
+ return gr.Interface(
125
+ fn=lambda: "Model not loaded",
126
+ inputs=[],
127
+ outputs="text",
128
+ title="Error: Model not loaded"
129
+ )
130
+
131
+ # Create input sliders for features
132
+ inputs = []
133
+ for name in FEATURE_NAMES:
134
+ slider = gr.Slider(
135
+ minimum=0,
136
+ maximum=1,
137
+ value=0.5,
138
+ step=0.01,
139
+ label=name.replace("_", " "),
140
+ )
141
+ inputs.append(slider)
142
+
143
+ # Build the interface
144
+ description = """
145
+ ## Deep Squat Movement Assessment
146
+
147
+ **How to use:**
148
+ 1. Adjust the sliders to input deviation values (0 = no deviation, 1 = maximum deviation)
149
+ 2. Click "Submit" to get your predicted score
150
+ 3. Or click "Load Random Example" to test with real data
151
+
152
+ **Score Interpretation:**
153
+ - 80-100%: Excellent form
154
+ - 60-79%: Good form
155
+ - 40-59%: Average form
156
+ - 0-39%: Needs improvement
157
+ """
158
+
159
+ # features into categories
160
+ angle_features = [n for n in FEATURE_NAMES if "Angle" in n]
161
+ nasm_features = [n for n in FEATURE_NAMES if "NASM" in n]
162
+ time_features = [n for n in FEATURE_NAMES if "Time" in n]
163
+
164
+ # Get indices for each category
165
+ angle_indices = [FEATURE_NAMES.index(f) for f in angle_features]
166
+ nasm_indices = [FEATURE_NAMES.index(f) for f in nasm_features]
167
+ time_indices = [FEATURE_NAMES.index(f) for f in time_features]
168
+
169
+ # Create the main interface
170
+ with gr.Blocks(title="Deep Squat Assessment") as demo:
171
+ gr.Markdown("# Deep Squat Movement Assessment")
172
+ gr.Markdown(description)
173
+
174
+ with gr.Row():
175
+ with gr.Column(scale=2):
176
+ gr.Markdown("### Input Features")
177
+ gr.Markdown(f"*{len(FEATURE_NAMES)} features loaded from champion model*")
178
+ gr.Markdown("*Deviation values: 0 = perfect, 1 = maximum deviation*")
179
+
180
+ with gr.Tabs():
181
+ with gr.TabItem(f"Angle Deviations ({len(angle_indices)})"):
182
+ for idx in angle_indices:
183
+ inputs[idx].render()
184
+
185
+ with gr.TabItem(f"NASM Deviations ({len(nasm_indices)})"):
186
+ for idx in nasm_indices:
187
+ inputs[idx].render()
188
+
189
+ with gr.TabItem(f"Time Deviations ({len(time_indices)})"):
190
+ for idx in time_indices:
191
+ inputs[idx].render()
192
+
193
+ with gr.Column(scale=1):
194
+ gr.Markdown("### Results")
195
+ score_output = gr.Textbox(label="Predicted Score")
196
+ interp_output = gr.Textbox(label="Assessment")
197
+ details_output = gr.Markdown(label="Details")
198
+
199
+ with gr.Row():
200
+ submit_btn = gr.Button("Submit", variant="primary")
201
+ example_btn = gr.Button("Load Random Example")
202
+ clear_btn = gr.Button("Clear")
203
+
204
+ submit_btn.click(
205
+ fn=predict_score,
206
+ inputs=inputs,
207
+ outputs=[score_output, interp_output, details_output],
208
+ )
209
+
210
+ example_btn.click(
211
+ fn=load_example,
212
+ inputs=[],
213
+ outputs=inputs
214
+ )
215
+
216
+ clear_btn.click(
217
+ fn=lambda: [0.5] * len(FEATURE_NAMES) + ["", "", ""],
218
+ inputs=[],
219
+ outputs=inputs + [score_output, interp_output, details_output],
220
+ )
221
+
222
+ return demo
223
+
224
+
225
+ # Create the interface
226
+ demo = create_interface()
227
+
228
+ if __name__ == "__main__":
229
+ demo.launch(
230
+ share=False,
231
+ server_name="0.0.0.0",
232
+ server_port=7860,
233
+ )