GeorgyVlasov committed on
Commit
ce167cb
·
1 Parent(s): 63c2d15

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +178 -0
  2. my_model.tflite +3 -0
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow.keras.models import load_model
6
+ from io import BytesIO
7
+ import xml
8
+ import untangle
9
+ import mvnx
10
+ from mvnx import MVNX
11
+ import matplotlib
12
+ import matplotlib.pyplot as plt
13
+ import scipy.signal
14
+ import copy
15
+ import numpy as np
16
+
17
+
18
# Keras model loaded once at import time.
# NOTE(review): `model` appears to be unused below — classify_movement() runs
# inference through the TFLite interpreter ('my_model.tflite') instead; confirm
# whether this load (and the 'model/my_model.keras' artifact) can be removed.
model = load_model('model/my_model.keras')
19
+
20
+
21
+
22
+
23
def classify_movement(file):
    """Classify an uploaded MVNX motion-capture file as 'bench' or 'squat'.

    Parameters
    ----------
    file : uploaded file object from ``gr.File``; ``file.name`` is the path
        on disk.

    Returns
    -------
    dict
        Label -> probability mapping consumable by ``gr.Label``.
    """
    mvnx_file = file.name

    data = mvnx.MVNX(mvnx_file)
    position_data = data.get_info('position')

    # Stack per-frame joint positions into (num_frames, num_joints, num_features).
    # Cast to float32 up front: tf.zeros defaults to float32 and tf.concat
    # requires matching dtypes — the original float64 numpy array would raise
    # a dtype-mismatch error (and TFLite inputs are float32 as well).
    position_tensor = np.asarray(
        [list(frame.values()) for frame in position_data], dtype=np.float32)

    num_frames, num_joints, num_features = position_tensor.shape

    # The model expects a fixed-length sequence of 1200 frames.
    max_sequence_length = 1200

    if num_frames < max_sequence_length:
        # Zero-pad the time axis up to the fixed length.
        padding_tensor = tf.zeros(
            (max_sequence_length - num_frames, num_joints, num_features),
            dtype=tf.float32)
        padded_position_tensor = tf.concat(
            [position_tensor, padding_tensor], axis=0)
    else:
        # Longer recordings previously produced a negative padding amount and
        # crashed; truncate to the model's expected length instead.
        padded_position_tensor = tf.convert_to_tensor(
            position_tensor[:max_sequence_length])

    # Add a batch dimension: (1, max_sequence_length, num_joints, num_features).
    transformed_tensor = tf.expand_dims(padded_position_tensor, axis=0)

    # Run inference with the bundled TFLite model.
    interpreter = tf.lite.Interpreter(model_path='my_model.tflite')
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    interpreter.set_tensor(input_details[0]['index'], transformed_tensor)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # Map the two output probabilities to their class labels.
    pn = output_data
    prob_dict = {'bench': float(pn[0][0]), 'squat': float(pn[0][1])}

    return prob_dict
80
+
81
+
82
+
83
+
84
+ import numpy as np
85
+ import plotly.graph_objs as go
86
+
87
def plot_3d_joint_trajectories(file):
    """Build an interactive 3D Plotly figure of every joint's trajectory.

    Parameters
    ----------
    file : uploaded file object from ``gr.File``; ``file.name`` is the path
        on disk.

    Returns
    -------
    plotly.graph_objs.Figure
        One Scatter3d line trace per joint, consumable by ``gr.Plot``.
    """
    mvnx_file = file.name

    data = mvnx.MVNX(mvnx_file)
    position_data = data.get_info('position')

    # (num_timesteps, num_joints, 3) array of XYZ positions per frame.
    sample = np.array([list(frame.values()) for frame in position_data])

    num_joints = sample.shape[1]

    # Human-readable names for the 23 Xsens MVN segments, in segment order.
    joint_names = [
        'pelvis', 'l5', 'l3', 't12', 't8', 'neck', 'head',
        'right_shoulder', 'right_upper_arm', 'right_forearm', 'right_hand',
        'left_shoulder', 'left_upper_arm', 'left_forearm', 'left_hand',
        'right_upper_leg', 'right_lower_leg', 'right_foot', 'right_toe',
        'left_upper_leg', 'left_lower_leg', 'left_foot', 'left_toe'
    ]

    traces = []
    for joint_index in range(num_joints):
        # Fall back to a numeric label when the file carries more joints than
        # names — direct indexing previously raised IndexError in that case.
        if joint_index < len(joint_names):
            trace_name = joint_names[joint_index]
        else:
            trace_name = f'joint_{joint_index}'

        # One line trace per joint over all time steps.
        traces.append(go.Scatter3d(
            x=sample[:, joint_index, 0],
            y=sample[:, joint_index, 1],
            z=sample[:, joint_index, 2],
            mode='lines',
            name=trace_name,
        ))

    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        ),
        title='Interactive 3D Movement Trajectories of All Joints'
    )

    return fig
144
+
145
+
146
+
147
# --- Gradio app wiring -------------------------------------------------------
title = "Movement Binary Classification with TensorFlow"
description = "Demo app for Movement type classification with TensorFlow using data from mvnx file. To use it, simply upload your mvnx file (should be less than 1200 datapoints) or click on one of the examples to load them."

# Example MVNX recordings shipped alongside the app.
examples = ['squat-test-30kg.mvnx', 'benchpress-test-35kg.mvnx']

# Classification interface: uploaded MVNX file -> class-probability label.
classifier = gr.Interface(fn=classify_movement,
                          inputs=gr.File(),
                          outputs=gr.Label())

# Visualization interface: same uploaded file -> interactive 3D trajectory plot.
plotter = gr.Interface(fn=plot_3d_joint_trajectories,
                       inputs=gr.File(),
                       outputs=gr.Plot())

# gr.Parallel feeds the single upload to both interfaces side by side.
demo = gr.Parallel(classifier, plotter,
                   examples=examples, title=title, description=description)

# Guard the launch so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch()
177
+
178
+
my_model.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8c470c4ef1377204f2e78542987308f2c0e28b6b9c6c26510e06c5bf37ac11a
3
+ size 206101256