File size: 4,580 Bytes
ce167cb 4cfd559 ce167cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from io import BytesIO
import xml
import untangle
import mvnx
from mvnx import MVNX
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
import copy
import numpy as np
def classify_movement(file, *, max_sequence_length=1200, model_path='my_model.tflite'):
    """Classify an MVNX recording as a bench press or a squat.

    The recording's ``'position'`` data is zero-padded at the end of the
    frame axis to ``max_sequence_length`` frames. It is then given a leading
    batch axis and run through the TensorFlow Lite model at ``model_path``.

    Parameters
    ----------
    file:
        The uploaded recording. Either an object exposing the file path as
        ``.name`` (what this function originally relied on) or a plain path
        string (what newer Gradio ``gr.File`` components may pass).
    max_sequence_length:
        Number of frames the model input is padded to (default 1200).
        Keyword-only so the single-input Gradio wiring is unaffected.
    model_path:
        Path of the ``.tflite`` model file. Keyword-only for the same reason.

    Returns
    -------
    dict[str, float]
        Maps each label (``'bench'``, ``'squat'``) to the corresponding value
        in the first row of the model output. This is the mapping shape
        ``gr.Label`` displays.

    Raises
    ------
    ValueError
        If the position data is not 3-D, or if it has more than
        ``max_sequence_length`` frames. Padding would otherwise need a
        negative length.
    """
    labels = ['bench', 'squat']
    mvnx_file = getattr(file, 'name', file)
    data = mvnx.MVNX(mvnx_file)
    position_data = data.get_info('position')
    # Assumed layout (TODO confirm against the mvnx package): one dict per
    # frame whose values are per-segment coordinate lists. That gives an
    # array of shape (num_frames, num_segments, num_coords).
    position_tensor = np.array([list(frame.values()) for frame in position_data])
    if position_tensor.ndim != 3:
        raise ValueError(
            "Expected 3-D position data (frames, segments, coords), "
            f"got shape {position_tensor.shape}."
        )
    num_frames, num_joints, num_features = position_tensor.shape
    if num_frames > max_sequence_length:
        raise ValueError(
            f"Recording has {num_frames} frames; at most "
            f"{max_sequence_length} are supported."
        )

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Zero-pad the tail of the frame axis. The array is built directly in the
    # dtype the model's input tensor declares, because set_tensor rejects a
    # dtype mismatch. The raw positions are presumably float64 after np.array.
    padded = np.zeros(
        (max_sequence_length, num_joints, num_features),
        dtype=input_details[0]['dtype'],
    )
    padded[:num_frames] = position_tensor
    model_input = padded[np.newaxis, ...]  # leading batch axis of size 1

    interpreter.set_tensor(input_details[0]['index'], model_input)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # Pair each label with its column in the first (only) batch row.
    return {label: float(score) for label, score in zip(labels, output_data[0])}
import numpy as np
import plotly.graph_objs as go
def plot_3d_joint_trajectories(file):
    """Build an interactive 3-D line plot of every position track over time.

    Parameters
    ----------
    file:
        The uploaded MVNX recording. Either an object exposing the file path
        as ``.name`` (what this function originally relied on) or a plain
        path string (what newer Gradio ``gr.File`` components may pass).

    Returns
    -------
    plotly.graph_objs.Figure
        One ``Scatter3d`` line trace per position track. Traces are named from
        ``joint_names``; any track beyond that list is named
        ``joint_<index>`` instead of raising ``IndexError``. The figure is
        returned for Gradio to render; nothing is opened here.
    """
    mvnx_path = getattr(file, 'name', file)
    data = mvnx.MVNX(mvnx_path)
    position_data = data.get_info('position')
    # Assumed layout (TODO confirm against the mvnx package): shape
    # (num_frames, num_tracks, >=3), with columns 0/1/2 used as X/Y/Z below.
    sample = np.array([list(frame.values()) for frame in position_data])
    num_joints = sample.shape[1]

    # Display names in track order. Recordings with more tracks than listed
    # here fall back to a generic name.
    joint_names = [
        'pelvis', 'l5', 'l3', 't12', 't8', 'neck', 'head',
        'right_shoulder', 'right_upper_arm', 'right_forearm', 'right_hand',
        'left_shoulder', 'left_upper_arm', 'left_forearm', 'left_hand',
        'right_upper_leg', 'right_lower_leg', 'right_foot', 'right_toe',
        'left_upper_leg', 'left_lower_leg', 'left_foot', 'left_toe',
    ]

    traces = []
    for joint_index in range(num_joints):
        if joint_index < len(joint_names):
            name = joint_names[joint_index]
        else:
            name = f'joint_{joint_index}'
        coords = sample[:, joint_index, :]  # this track across all frames
        traces.append(
            go.Scatter3d(
                x=coords[:, 0],
                y=coords[:, 1],
                z=coords[:, 2],
                mode='lines',
                name=name,
            )
        )

    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
        ),
        title='Interactive 3D Movement Trajectories of All Joints',
    )
    return fig
# Page heading and help text for the combined Gradio app.
title = "Movement Binary Classification with TensorFlow"
description = (
    "Demo app for Movement type classification with TensorFlow using data "
    "from mvnx file. To use it, simply upload your mvnx file (should be less "
    "than 1200 datapoints) or click on one of the examples to load them."
)

# Example recordings offered on the page (paths relative to the app's
# working directory).
examples = ['squat-test-30kg.mvnx', 'benchpress-test-35kg.mvnx']

# Interface 1: uploaded file -> class probabilities shown as a label.
classifier = gr.Interface(
    fn=classify_movement,
    inputs=gr.File(),
    outputs=gr.Label(),
)

# Interface 2: uploaded file -> interactive 3-D trajectory plot.
plotter = gr.Interface(
    fn=plot_3d_joint_trajectories,
    inputs=gr.File(),
    outputs=gr.Plot(),
)

# Run both interfaces on the same input, sharing examples, title and
# description, then start the server.
demo = gr.Parallel(
    classifier,
    plotter,
    examples=examples,
    title=title,
    description=description,
)
demo.launch()
|