|
|
|
|
|
import gradio as gr |
|
|
|
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from tensorflow.keras.models import load_model |
|
|
from io import BytesIO |
|
|
import xml |
|
|
import untangle |
|
|
import mvnx |
|
|
from mvnx import MVNX |
|
|
import matplotlib |
|
|
import matplotlib.pyplot as plt |
|
|
import scipy.signal |
|
|
import copy |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_movement(file):
    """Classify an uploaded MVNX recording as a bench press or a squat.

    The per-frame positions are read from the MVNX file, zero-padded at the
    end of the time axis to a fixed length of 1200 frames, given a leading
    batch axis, and fed to the TFLite model stored in ``my_model.tflite``.

    Args:
        file: Uploaded file object as passed by ``gr.File``; only its
            ``.name`` attribute (the path on disk) is used.

    Returns:
        dict mapping each label ('bench', 'squat') to the float score read
        from the model's first output, in the format ``gr.Label`` expects.

    Raises:
        ValueError: if the recording yields no 3-D position data, or has
            more frames than the fixed sequence length the model accepts.
    """
    # Output column i of the model corresponds to labels[i].
    labels = ['bench', 'squat']
    max_sequence_length = 1200

    data = mvnx.MVNX(file.name)
    position_data = data.get_info('position')

    # Each frame is turned into a list of its values and stacked; the code
    # below requires the result to be (num_frames, num_joints, num_features).
    # NOTE(review): presumably joints x (x, y, z) per frame -- confirm against
    # the mvnx package.
    position_tensor = np.array([list(frame.values()) for frame in position_data])
    if position_tensor.ndim != 3 or position_tensor.shape[0] == 0:
        raise ValueError(
            "The uploaded file contains no usable position data "
            f"(got array of shape {position_tensor.shape})."
        )

    num_frames, num_joints, num_features = position_tensor.shape
    if num_frames > max_sequence_length:
        # Without this check the padding amount would be negative and the
        # failure would surface as an opaque shape error.
        raise ValueError(
            f"The recording has {num_frames} frames, but at most "
            f"{max_sequence_length} frames are supported."
        )

    interpreter = tf.lite.Interpreter(model_path='my_model.tflite')
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Build the padded batch directly in the dtype the model declares:
    # Interpreter.set_tensor rejects inputs whose dtype does not match.
    # Frames beyond num_frames stay zero, which matches the original
    # end-of-sequence zero padding.
    model_input = np.zeros(
        (1, max_sequence_length, num_joints, num_features),
        dtype=input_details[0]['dtype'],
    )
    model_input[0, :num_frames] = position_tensor

    interpreter.set_tensor(input_details[0]['index'], model_input)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # Scores for the single batch entry, in the same order as `labels`.
    scores = output_data[0]
    return {label: float(scores[i]) for i, label in enumerate(labels)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import plotly.graph_objs as go |
|
|
|
|
|
def plot_3d_joint_trajectories(file):
    """Build an interactive 3D Plotly figure of every joint's trajectory.

    Args:
        file: Uploaded file object as passed by ``gr.File``; only its
            ``.name`` attribute (the path on disk) is used.

    Returns:
        A ``plotly.graph_objs.Figure`` with one line trace per joint, drawn
        from that joint's per-frame x/y/z positions.

    Raises:
        ValueError: if the recording yields no 3-D position data.
    """
    data = mvnx.MVNX(file.name)
    position_data = data.get_info('position')

    # Indexed below as sample[frame, joint, axis].
    sample = np.array([list(frame.values()) for frame in position_data])
    if sample.ndim != 3 or sample.shape[0] == 0:
        raise ValueError(
            "The uploaded file contains no usable position data "
            f"(got array of shape {sample.shape})."
        )

    num_joints = sample.shape[1]

    # Display names for the first 23 joints, in index order.
    joint_names = [
        'pelvis', 'l5', 'l3', 't12', 't8', 'neck', 'head',
        'right_shoulder', 'right_upper_arm', 'right_forearm', 'right_hand',
        'left_shoulder', 'left_upper_arm', 'left_forearm', 'left_hand',
        'right_upper_leg', 'right_lower_leg', 'right_foot', 'right_toe',
        'left_upper_leg', 'left_lower_leg', 'left_foot', 'left_toe',
    ]

    traces = []
    for joint_index in range(num_joints):
        # Fall back to a generic label so files with extra joints still plot
        # instead of raising IndexError.
        if joint_index < len(joint_names):
            name = joint_names[joint_index]
        else:
            name = f'joint_{joint_index}'

        traces.append(
            go.Scatter3d(
                x=sample[:, joint_index, 0],
                y=sample[:, joint_index, 1],
                z=sample[:, joint_index, 2],
                mode='lines',
                name=name,
            )
        )

    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
        ),
        title='Interactive 3D Movement Trajectories of All Joints',
    )
    return fig
|
|
|
|
|
|
|
|
|
|
|
# Heading and help text shown on the combined demo page.
title = "Movement Binary Classification with TensorFlow"
description = (
    "Demo app for Movement type classification with TensorFlow using data "
    "from mvnx file. To use it, simply upload your mvnx file (should be less "
    "than 1200 datapoints) or click on one of the examples to load them."
)

# Sample recordings offered as one-click inputs.
examples = ['squat-test-30kg.mvnx', 'benchpress-test-35kg.mvnx']

# Interface 1: predicts the movement class for an uploaded file.
classifier = gr.Interface(
    fn=classify_movement,
    inputs=gr.File(),
    outputs=gr.Label(),
)

# Interface 2: plots the joint trajectories of an uploaded file.
plotter = gr.Interface(
    fn=plot_3d_joint_trajectories,
    inputs=gr.File(),
    outputs=gr.Plot(),
)

# Combine both interfaces so the same uploaded file drives the prediction
# and the plot.
demo = gr.Parallel(
    classifier,
    plotter,
    examples=examples,
    title=title,
    description=description,
)

demo.launch()
|
|
|
|
|
|
|
|
|