Spaces:
Runtime error
Runtime error
| data_path = "./data/" | |
| import pandas as pd | |
| import datasets | |
| # load the csv into motion_capture_data | |
| import streamlit as st | |
| dataset_names = ['Fold_towels', 'Pipette', 'Take_the_item', 'Twist_the_tube'] | |
def _trajectory_period(dataset):
    """Return last minus first ``timestamp`` of *dataset*, or 0 if it has no rows."""
    if dataset.empty:
        # .iloc[0] / .iloc[-1] would raise IndexError on an empty frame.
        return 0
    timestamps = dataset["timestamp"]
    return timestamps.iloc[-1] - timestamps.iloc[0]


def load_data(names=None, repo_prefix="cyberorigin/"):
    """Download each motion-capture dataset from the Hugging Face Hub.

    Args:
        names: Dataset names to load. Defaults to the module-level
            ``dataset_names``.
        repo_prefix: Hub namespace prepended to each name to form the repo id.

    Returns:
        A ``(all_datasets, total_period)`` tuple. ``all_datasets`` maps each
        name to a DataFrame of that dataset's ``train`` split.
        ``total_period`` is the sum over all datasets of last minus first
        ``timestamp``; datasets with no rows contribute 0.
    """
    if names is None:
        names = dataset_names
    print("Loading data")
    all_datasets = {}
    for name in names:
        print("Loading dataset: ", name)
        train_split = datasets.load_dataset(repo_prefix + name)['train']
        # Columnar Arrow -> pandas conversion, instead of building the frame
        # row by row from the dataset's per-example dicts.
        all_datasets[name] = train_split.to_pandas()
    total_period = sum(_trajectory_period(df) for df in all_datasets.values())
    return all_datasets, total_period
# Skeleton joints plotted in every frame. Each name must have "<name>_x",
# "<name>_y" and "<name>_z" columns in the motion-capture DataFrame.
_BODY_PART_NAMES = [
    'Left Shoulder',
    'Right Upper Arm',
    'Left Lower Leg',
    'Spine1',
    'Right Upper Leg',
    'Spine3',
    'Right Lower Arm',
    'Left Foot',
    'Right Lower Leg',
    'Right Shoulder',
    'Left Hand',
    'Left Upper Leg',
    'Right Foot',
    'Spine',
    'Spine2',
    'Left Lower Arm',
    'Left Toe',
    'Neck',
    'Right Hand',
    'Right Toe',
    'Head',
    'Left Upper Arm',
    'Hips',
]


def _axis_range(columns):
    """Return ``[min, max]`` over all non-NaN values in the DataFrame *columns*.

    Returns None when there is no non-NaN value. This leaves Plotly's
    autoscaling in charge of that axis.
    """
    low = columns.min().min()
    high = columns.max().max()
    if pd.isna(low) or pd.isna(high):
        return None
    return [float(low), float(high)]


def _frame_duration_ms(times, default_ms=33.0):
    """Return the median spacing of the Series *times* in milliseconds.

    The app reports timestamp differences as seconds, so the spacing is
    multiplied by 1000. Falls back to *default_ms* when there are fewer than
    two timestamps or the spacing is not positive.
    """
    step = times.diff().median()
    if pd.isna(step) or step <= 0:
        return default_ms
    return float(step) * 1000.0


def _build_motion_figure(motion_capture_data):
    """Build an animated 3D scatter of the body-part markers, one frame per row.

    *motion_capture_data* must have at least one row, a ``timestamp`` column,
    and ``_x``/``_y``/``_z`` columns for every name in ``_BODY_PART_NAMES``.
    """
    import plotly.graph_objects as go

    x_cols = motion_capture_data[[part + "_x" for part in _BODY_PART_NAMES]]
    y_cols = motion_capture_data[[part + "_y" for part in _BODY_PART_NAMES]]
    z_cols = motion_capture_data[[part + "_z" for part in _BODY_PART_NAMES]]
    # Plain arrays: indexing a row is much cheaper than DataFrame.iloc per frame.
    xs, ys, zs = x_cols.to_numpy(), y_cols.to_numpy(), z_cols.to_numpy()
    times = motion_capture_data["timestamp"]
    time_values = times.to_numpy()
    n_frames = len(time_values)

    def markers(k):
        """Scatter trace of every joint position at row *k*."""
        return go.Scatter3d(
            x=xs[k],
            y=ys[k],
            z=zs[k],
            mode='markers',
            marker=dict(size=5, color='blue'),
        )

    frames = [go.Frame(data=[markers(k)], name=str(k)) for k in range(n_frames)]

    layout = go.Layout(
        title='Motion Capture Visualization',
        # Pin every axis to the whole recording's extent so the scene does not
        # autoscale to each frame's points while the animation plays.
        scene=dict(
            xaxis=dict(range=_axis_range(x_cols)),
            yaxis=dict(range=_axis_range(y_cols)),
            zaxis=dict(range=_axis_range(z_cols)),
            aspectmode='data',
        ),
        updatemenus=[{
            'buttons': [
                {
                    # Advance at the recording's own frame spacing.
                    'args': [None, {'frame': {'duration': _frame_duration_ms(times), 'redraw': True}, 'fromcurrent': True}],
                    'label': 'Play',
                    'method': 'animate'
                },
                {
                    'args': [[None], {'frame': {'duration': 0, 'redraw': True}, 'mode': 'immediate', 'transition': {'duration': 0}}],
                    'label': 'Pause',
                    'method': 'animate'
                }
            ],
            'direction': 'left',
            'pad': {'r': 10, 't': 87},
            'showactive': True,
            'type': 'buttons',
            'x': 0.1,
            'xanchor': 'right',
            'y': 0,
            'yanchor': 'top'
        }],
        sliders=[{
            'active': 0,
            'steps': [{
                # Label with the timestamp (not the frame index) so the
                # "Time: " read-out is accurate; args still target frame str(k).
                'label': '{:.2f}'.format(float(time_values[k])),
                'method': 'animate',
                'args': [
                    [str(k)],
                    # Jump straight to the frame; a transition length tied to
                    # the number of frames made scrubbing slower on longer takes.
                    {'mode': 'immediate', 'frame': {'duration': 300, 'redraw': True}, 'transition': {'duration': 0}}
                ]
            } for k in range(n_frames)],
            'currentvalue': {
                'prefix': 'Time: ',
                'visible': True,
                'xanchor': 'right'
            },
            'pad': {'b': 10},
            'len': 0.9,
            'x': 0.1,
            'y': 0,
        }]
    )
    return go.Figure(data=[markers(0)], frames=frames, layout=layout)


def visualize(data):
    """Render the dataset picker, the dataset's video and an animated skeleton plot.

    Args:
        data: Mapping of dataset name -> motion-capture DataFrame, as returned
            by ``load_data``.
    """
    # Offer exactly the datasets that were loaded, so data[...] below cannot miss.
    dataset_option = st.selectbox(
        'Select a dataset:',
        list(data)
    )
    if dataset_option is None:
        # selectbox returns None when it has no options.
        st.warning("No datasets were loaded.")
        return
    st.video("https://huggingface.co/datasets/cyberorigin/"+dataset_option+"/resolve/main/Video/video.mp4")
    motion_capture_data = data[dataset_option]
    if motion_capture_data.empty:
        # The figure starts from the first row; there is nothing to plot.
        st.warning("Dataset " + dataset_option + " has no motion capture rows.")
        return
    st.plotly_chart(_build_motion_figure(motion_capture_data))
st.title("CyberOrigin Data Visualization")
# Streamlit re-executes this whole script on every widget interaction (e.g. each
# selectbox change). Cache the result so the Hub datasets are downloaded and
# converted once rather than on every rerun.
data, period = st.cache_data(load_data)()
# display the total period of the data up to 2 decimal places
st.write("Total period of data: ", round(period, 2), " seconds")
visualize(data)