Commit
·
ce167cb
1
Parent(s):
63c2d15
Upload 2 files
Browse files- app.py +178 -0
- my_model.tflite +3 -0
app.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
from tensorflow.keras.models import load_model
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
import xml
|
| 8 |
+
import untangle
|
| 9 |
+
import mvnx
|
| 10 |
+
from mvnx import MVNX
|
| 11 |
+
import matplotlib
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import scipy.signal
|
| 14 |
+
import copy
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
model = load_model('model/my_model.keras')
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def classify_movement(file):
|
| 24 |
+
|
| 25 |
+
labels = ['bench', 'squat']
|
| 26 |
+
|
| 27 |
+
mvnx_file = file.name
|
| 28 |
+
|
| 29 |
+
data = mvnx.MVNX(mvnx_file)
|
| 30 |
+
|
| 31 |
+
position_data = data.get_info('position') # Change to the appropriate attribute
|
| 32 |
+
|
| 33 |
+
position_tensor = np.array([list(frame.values()) for frame in position_data])
|
| 34 |
+
|
| 35 |
+
# Get the dimensions of the position tensor
|
| 36 |
+
num_frames, num_joints, num_features = position_tensor.shape
|
| 37 |
+
|
| 38 |
+
# Determine the desired sequence length
|
| 39 |
+
max_sequence_length = 1200 # Choose an appropriate value
|
| 40 |
+
|
| 41 |
+
# Calculate the amount of padding needed
|
| 42 |
+
padding_amount = max_sequence_length - num_frames
|
| 43 |
+
|
| 44 |
+
# Create a tensor with zeros to represent padding
|
| 45 |
+
padding_tensor = tf.zeros((padding_amount, num_joints, num_features))
|
| 46 |
+
|
| 47 |
+
# Concatenate the padding tensor to the position tensor along the first axis
|
| 48 |
+
padded_position_tensor = tf.concat([position_tensor, padding_tensor], axis=0)
|
| 49 |
+
|
| 50 |
+
transformed_tensor = tf.expand_dims(padded_position_tensor, axis=0)
|
| 51 |
+
# Predict
|
| 52 |
+
|
| 53 |
+
interpreter = tf.lite.Interpreter(model_path='my_model.tflite')
|
| 54 |
+
interpreter.allocate_tensors()
|
| 55 |
+
|
| 56 |
+
# Get input and output details
|
| 57 |
+
input_details = interpreter.get_input_details()
|
| 58 |
+
output_details = interpreter.get_output_details()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Set input tensor
|
| 63 |
+
interpreter.set_tensor(input_details[0]['index'], transformed_tensor)
|
| 64 |
+
|
| 65 |
+
# Run inference
|
| 66 |
+
interpreter.invoke()
|
| 67 |
+
|
| 68 |
+
# Get the output tensor
|
| 69 |
+
output_data = interpreter.get_tensor(output_details[0]['index'])
|
| 70 |
+
|
| 71 |
+
# Use the output for your task
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
pn = output_data
|
| 75 |
+
|
| 76 |
+
prob_dict = {'bench':float(pn[0][0]), 'squat':float(pn[0][1])}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
return prob_dict
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
import numpy as np
|
| 85 |
+
import plotly.graph_objs as go
|
| 86 |
+
|
| 87 |
+
def plot_3d_joint_trajectories(file):
|
| 88 |
+
mvnx_file = file.name
|
| 89 |
+
|
| 90 |
+
data = mvnx.MVNX(mvnx_file)
|
| 91 |
+
|
| 92 |
+
position_data = data.get_info('position') # Change to the appropriate attribute
|
| 93 |
+
|
| 94 |
+
sample = np.array([list(frame.values()) for frame in position_data])
|
| 95 |
+
|
| 96 |
+
# Get the number of joints and timesteps
|
| 97 |
+
num_joints = sample.shape[1]
|
| 98 |
+
num_timesteps = sample.shape[0]
|
| 99 |
+
|
| 100 |
+
# List of joint names
|
| 101 |
+
joint_names = [
|
| 102 |
+
'pelvis', 'l5', 'l3', 't12', 't8', 'neck', 'head',
|
| 103 |
+
'right_shoulder', 'right_upper_arm', 'right_forearm', 'right_hand',
|
| 104 |
+
'left_shoulder', 'left_upper_arm', 'left_forearm', 'left_hand',
|
| 105 |
+
'right_upper_leg', 'right_lower_leg', 'right_foot', 'right_toe',
|
| 106 |
+
'left_upper_leg', 'left_lower_leg', 'left_foot', 'left_toe'
|
| 107 |
+
]
|
| 108 |
+
|
| 109 |
+
# Create a list to store scatter plot traces
|
| 110 |
+
traces = []
|
| 111 |
+
|
| 112 |
+
# Iterate through all joints
|
| 113 |
+
for joint_index in range(num_joints):
|
| 114 |
+
# Extract the X, Y, and Z coordinates of the joint over all time steps
|
| 115 |
+
x_coords = sample[:, joint_index, 0]
|
| 116 |
+
y_coords = sample[:, joint_index, 1]
|
| 117 |
+
z_coords = sample[:, joint_index, 2]
|
| 118 |
+
|
| 119 |
+
# Create a scatter plot trace for the joint's movement trajectory
|
| 120 |
+
trace = go.Scatter3d(
|
| 121 |
+
x=x_coords,
|
| 122 |
+
y=y_coords,
|
| 123 |
+
z=z_coords,
|
| 124 |
+
mode='lines',
|
| 125 |
+
name=joint_names[joint_index] # Use joint names instead of indices
|
| 126 |
+
)
|
| 127 |
+
traces.append(trace)
|
| 128 |
+
|
| 129 |
+
# Create the figure
|
| 130 |
+
fig = go.Figure(data=traces)
|
| 131 |
+
|
| 132 |
+
# Set layout
|
| 133 |
+
fig.update_layout(
|
| 134 |
+
scene=dict(
|
| 135 |
+
xaxis_title='X',
|
| 136 |
+
yaxis_title='Y',
|
| 137 |
+
zaxis_title='Z'
|
| 138 |
+
),
|
| 139 |
+
title='Interactive 3D Movement Trajectories of All Joints'
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# Show the interactive plot in a web browser
|
| 143 |
+
return fig
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
title="Movement Binary Classification with TensorFlow"
|
| 148 |
+
description="Demo app for Movement type classification with TensorFlow using data from mvnx file. To use it, simply upload your mvnx file (should be less than 1200 datapoints) or click on one of the examples to load them."
|
| 149 |
+
|
| 150 |
+
examples=['squat-test-30kg.mvnx', 'benchpress-test-35kg.mvnx']
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
classifier = gr.Interface(fn = classify_movement,
|
| 156 |
+
inputs=gr.File(),
|
| 157 |
+
outputs=gr.Label(),
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
plotter = gr.Interface(fn=plot_3d_joint_trajectories,
|
| 166 |
+
inputs=gr.File(),
|
| 167 |
+
outputs=gr.Plot(),
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
demo = gr.Parallel(classifier, plotter, examples = examples, title=title, description = description)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
demo.launch()
|
| 177 |
+
|
| 178 |
+
|
my_model.tflite
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e8c470c4ef1377204f2e78542987308f2c0e28b6b9c6c26510e06c5bf37ac11a
|
| 3 |
+
size 206101256
|