Spaces:
Running
Running
| # Define functions and run GaitDynamics. | |
| import numpy as np | |
| import pandas as pd | |
| from torch.utils.data import Dataset | |
| import torch.nn.functional as F | |
| import torch | |
| from torch import nn | |
| from torch import Tensor | |
| import argparse | |
| from typing import Any | |
| import nimblephysics as nimble | |
| import scipy.interpolate as interpo | |
| from scipy.interpolate import interp1d | |
| import random | |
| from scipy.signal import filtfilt, butter | |
| from accelerate import Accelerator, DistributedDataParallelKwargs | |
| import os | |
| from inspect import isfunction | |
| from math import pi | |
| from einops import rearrange, repeat | |
| import math | |
| import copy | |
| import matplotlib.pyplot as plt | |
| from tqdm import tqdm | |
| from functools import partial | |
| from accelerate.state import AcceleratorState | |
# Select the GPU when available; tensors/models throughout the file use this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed every RNG in play (torch, Python, NumPy) for reproducible runs.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
| """ ============================ Start scaler.py ============================ """ | |
| def _handle_zeros_in_scale(scale, copy=True, constant_mask=None): | |
| # if we are fitting on 1D arrays, scale might be a scalar | |
| if constant_mask is None: | |
| # Detect near constant values to avoid dividing by a very small | |
| # value that could lead to surprising results and numerical | |
| # stability issues. | |
| constant_mask = scale < 10 * torch.finfo(scale.dtype).eps | |
| if copy: | |
| # New array to avoid side-effects | |
| scale = scale.clone() | |
| scale[constant_mask] = 1.0 | |
| return scale | |
class MinMaxScaler:
    """Torch port of sklearn's MinMaxScaler: maps each feature column into
    ``feature_range`` using per-column min/max statistics.

    NOTE: ``transform``/``inverse_transform`` operate IN PLACE on the tensor
    they receive (and also return it) — callers such as ``Normalizer`` rely
    on passing in a copy.
    """
    _parameter_constraints: dict = {
        "feature_range": [tuple],
        "copy": ["boolean"],
        "clip": ["boolean"],
    }

    def __init__(self, feature_range=(-1, 1), *, copy=True, clip=True):
        self.feature_range = feature_range
        self.copy = copy
        self.clip = clip

    def _reset(self):
        # Drop all fitted state; these attributes exist only after (partial_)fit.
        if hasattr(self, "scale_"):
            for attr in ("scale_", "min_", "n_samples_seen_",
                         "data_min_", "data_max_", "data_range_"):
                delattr(self, attr)

    def fit(self, X):
        """Forget any previous statistics, then fit on X (rows = samples)."""
        self._reset()
        return self.partial_fit(X)

    def partial_fit(self, X):
        """Compute per-column min/max of X and derive scale_/min_."""
        lo, hi = self.feature_range
        if lo >= hi:
            raise ValueError(
                "Minimum of desired feature range must be smaller than maximum. Got %s."
                % str(self.feature_range)
            )
        col_min = X.min(axis=0)[0]
        col_max = X.max(axis=0)[0]
        self.n_samples_seen_ = X.shape[0]
        span = col_max - col_min
        # Guard against constant columns before dividing.
        self.scale_ = (hi - lo) / _handle_zeros_in_scale(span, copy=True)
        self.min_ = lo - col_min * self.scale_
        self.data_min_ = col_min
        self.data_max_ = col_max
        self.data_range_ = span
        return self

    def transform(self, X):
        """Scale X into feature_range. Mutates X in place and returns it."""
        X *= self.scale_.to(X.device)
        X += self.min_.to(X.device)
        return X

    def inverse_transform(self, X):
        """Undo transform. Mutates X in place and returns it."""
        X -= self.min_.to(X.device)
        X /= self.scale_.to(X.device)
        return X
class Normalizer:
    """Applies a fitted scaler to a chosen subset of feature columns."""

    def __init__(self, scaler, cols_to_normalize):
        self.scaler = scaler
        self.cols_to_normalize = cols_to_normalize

    def normalize(self, x):
        """Return a normalized copy of 2-D x (samples x channels)."""
        # Advanced (list) indexing yields a copy, so the scaler's in-place
        # arithmetic only touches the temporary, never the caller's tensor.
        out = x.clone()
        out[:, self.cols_to_normalize] = self.scaler.transform(out[:, self.cols_to_normalize])
        return out

    def unnormalize(self, x):
        """Return an unnormalized copy of 3-D x (batch x seq x channels)."""
        batch, seq, ch = x.shape
        flat = x.clone().reshape(-1, ch)
        flat[:, self.cols_to_normalize] = self.scaler.inverse_transform(flat[:, self.cols_to_normalize])
        return flat.reshape((batch, seq, ch))
| """ ============================ End scaler.py ============================ """ | |
| """ ============================ Start consts.py ============================ """ | |
# Ball (3-DoF) joints and the OpenSim coordinate names that drive each one.
JOINTS_3D_ALL = {
    'pelvis': ['pelvis_tilt', 'pelvis_list', 'pelvis_rotation'],
    'hip_r': ['hip_flexion_r', 'hip_adduction_r', 'hip_rotation_r'],
    'hip_l': ['hip_flexion_l', 'hip_adduction_l', 'hip_rotation_l'],
    'lumbar': ['lumbar_extension', 'lumbar_bending', 'lumbar_rotation'],
    # 'arm_r': ['arm_flex_r', 'arm_add_r', 'arm_rot_r'],
    # 'arm_l': ['arm_flex_l', 'arm_add_l', 'arm_rot_l']
}
# Every degree of freedom of the full-body OpenSim model, in model order.
OSIM_DOF_ALL = [
    'pelvis_tilt', 'pelvis_list', 'pelvis_rotation', 'pelvis_tx', 'pelvis_ty', 'pelvis_tz', 'hip_flexion_r',
    'hip_adduction_r', 'hip_rotation_r', 'knee_angle_r', 'ankle_angle_r', 'subtalar_angle_r', 'mtp_angle_r',
    'hip_flexion_l', 'hip_adduction_l', 'hip_rotation_l', 'knee_angle_l', 'ankle_angle_l', 'subtalar_angle_l',
    'mtp_angle_l', 'lumbar_extension', 'lumbar_bending', 'lumbar_rotation', 'arm_flex_r', 'arm_add_r', 'arm_rot_r',
    'elbow_flex_r', 'pro_sup_r', 'wrist_flex_r', 'wrist_dev_r', 'arm_flex_l', 'arm_add_l', 'arm_rot_l',
    'elbow_flex_l', 'pro_sup_l', 'wrist_flex_l', 'wrist_dev_l']
# Ground-reaction force and normalized center-of-pressure channels, per foot.
KINETICS_ALL = [body + modality for body in ['calcn_r', 'calcn_l'] for modality in
                ['_force_vx', '_force_vy', '_force_vz', '_force_normed_cop_x', '_force_normed_cop_y', '_force_normed_cop_z']]
# Model state vector layout (arms included): 1-DoF coordinates, then kinetics,
# then the 6 components (suffixes _0.._5) of each ball joint's rotation representation.
MODEL_STATES_COLUMN_NAMES_WITH_ARM = [
    'pelvis_tx', 'pelvis_ty', 'pelvis_tz', 'knee_angle_r', 'ankle_angle_r', 'subtalar_angle_r',
    'knee_angle_l', 'ankle_angle_l', 'subtalar_angle_l', 'elbow_flex_r', 'pro_sup_r', 'elbow_flex_l', 'pro_sup_l'
] + KINETICS_ALL + [
    'pelvis_0', 'pelvis_1', 'pelvis_2', 'pelvis_3', 'pelvis_4', 'pelvis_5',
    'hip_r_0', 'hip_r_1', 'hip_r_2', 'hip_r_3', 'hip_r_4', 'hip_r_5',
    'hip_l_0', 'hip_l_1', 'hip_l_2', 'hip_l_3', 'hip_l_4', 'hip_l_5',
    'lumbar_0', 'lumbar_1', 'lumbar_2', 'lumbar_3', 'lumbar_4', 'lumbar_5',
    'arm_r_0', 'arm_r_1', 'arm_r_2', 'arm_r_3', 'arm_r_4', 'arm_r_5',  # only for with arm
    'arm_l_0', 'arm_l_1', 'arm_l_2', 'arm_l_3', 'arm_l_4', 'arm_l_5'  # only for with arm
]
# Same layout with every arm-related column removed.
MODEL_STATES_COLUMN_NAMES_NO_ARM = copy.deepcopy(MODEL_STATES_COLUMN_NAMES_WITH_ARM)
for name_ in ['elbow_flex_r', 'pro_sup_r', 'elbow_flex_l', 'pro_sup_l', 'arm_r_0', 'arm_r_1', 'arm_r_2', 'arm_r_3',
              'arm_r_4', 'arm_r_5', 'arm_l_0', 'arm_l_1', 'arm_l_2', 'arm_l_3', 'arm_l_4', 'arm_l_5']:
    MODEL_STATES_COLUMN_NAMES_NO_ARM.remove(name_)
# Coordinates held fixed (not estimated) by the pipeline.
FROZEN_DOFS = ['mtp_angle_r', 'mtp_angle_l',
               'wrist_flex_r', 'wrist_dev_r', 'wrist_flex_l', 'wrist_dev_l']
# Lower-body plus lumbar subset of the OpenSim coordinates.
FULL_OSIM_DOF = ['pelvis_tilt', 'pelvis_list', 'pelvis_rotation', 'pelvis_tx', 'pelvis_ty', 'pelvis_tz',
                 'hip_flexion_r', 'hip_adduction_r', 'hip_rotation_r', 'knee_angle_r', 'ankle_angle_r',
                 'subtalar_angle_r', 'mtp_angle_r', 'hip_flexion_l', 'hip_adduction_l', 'hip_rotation_l',
                 'knee_angle_l', 'ankle_angle_l', 'subtalar_angle_l', 'mtp_angle_l', 'lumbar_extension',
                 'lumbar_bending', 'lumbar_rotation']
# Coordinates treated as plain 1-DoF states (translations and pin joints).
JOINTS_1D_ALL = ['pelvis_tx', 'pelvis_ty', 'pelvis_tz',
                 'knee_angle_r', 'ankle_angle_r', 'subtalar_angle_r', 'mtp_angle_r',
                 'knee_angle_l', 'ankle_angle_l', 'subtalar_angle_l', 'mtp_angle_l']
| """ ============================ End consts.py ============================ """ | |
| """ ============================ Start model.py ============================ """ | |
class GaussianDiffusion(nn.Module):
    """DDPM-style Gaussian diffusion wrapper around a denoising model.

    Precomputes the beta/alpha schedules as buffers and exposes ancestral
    sampling (`p_sample_loop`), DDIM inpainting (plain and gradient-guided),
    the forward noising process (`q_sample`), and the training entry point
    (`forward` -> `self.loss`, defined elsewhere in the file).

    NOTE(review): relies on module-level helpers defined elsewhere in this
    file (`make_beta_schedule`, `extract`, `identity`, `EMA`) — presumed to
    follow the usual lucidrains diffusion conventions; confirm against their
    definitions.
    """
    def __init__(
        self,
        model,
        horizon,
        repr_dim,
        opt,
        n_timestep=1000,
        schedule="linear",
        loss_type="l1",
        clip_denoised=False,
        predict_epsilon=True,
        guidance_weight=1,
        use_p2=False,
        cond_drop_prob=0.,
    ):
        super().__init__()
        self.horizon = horizon
        self.transition_dim = repr_dim
        self.model = model
        self.ema = EMA(0.99)
        # Separate copy that accumulates the EMA weights during training.
        self.master_model = copy.deepcopy(self.model)
        self.opt = opt
        self.cond_drop_prob = cond_drop_prob
        # make a SMPL instance for FK module
        betas = torch.Tensor(
            make_beta_schedule(schedule=schedule, n_timestep=n_timestep)
        )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        # alpha-bar_{t-1}: shifted right, with alpha-bar_{-1} defined as 1.
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
        self.n_timestep = int(n_timestep)
        self.clip_denoised = clip_denoised
        self.predict_epsilon = predict_epsilon
        self.register_buffer("betas", betas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
        self.guidance_weight = guidance_weight
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        self.register_buffer("posterior_variance", posterior_variance)
        ## log calculation clipped because the posterior variance
        ## is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            torch.log(torch.clamp(posterior_variance, min=1e-20)),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
        )
        # p2 weighting (https://arxiv.org/abs/2204.00227); gamma=0 disables it.
        self.p2_loss_weight_k = 1
        self.p2_loss_weight_gamma = 0.5 if use_p2 else 0
        self.register_buffer(
            "p2_loss_weight",
            (self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
            ** -self.p2_loss_weight_gamma,
        )
        ## get loss coefficients and initialize objective
        self.loss_fn = F.mse_loss if loss_type == "l2" else F.l1_loss

    def set_normalizer(self, normalizer):
        # Normalizer used to map generated samples back to physical units.
        self.normalizer = normalizer

    # ------------------------------------------ sampling ------------------------------------------#
    def predict_start_from_noise(self, x_t, t, noise):
        """
        if self.predict_epsilon, model output is (scaled) noise;
        otherwise, model predicts x0 directly
        """
        if self.predict_epsilon:
            return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
                - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
            )
        else:
            return noise

    def predict_noise_from_start(self, x_t, t, x0):
        """Invert predict_start_from_noise: recover epsilon from (x_t, x0)."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def model_predictions(self, x, cond, time_cond, weight=None, clip_x_start=False):
        """Run the (guidance-weighted) model and return (pred_noise, x_start).

        The model is assumed to predict x0 directly here (output used as
        x_start); `identity` is a module-level no-op helper.
        """
        weight = weight if weight is not None else self.guidance_weight
        model_output = self.model.guided_forward(x, cond, time_cond, weight)
        maybe_clip = partial(torch.clamp, min=-1., max=1.) if clip_x_start else identity
        x_start = model_output
        x_start = maybe_clip(x_start)
        pred_noise = self.predict_noise_from_start(x, time_cond, x_start)
        return pred_noise, x_start

    def q_posterior(self, x_start, x_t, t):
        """Mean/variance of the true posterior q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, cond, t):
        """Model-based posterior parameters for one ancestral sampling step."""
        # guidance clipping
        # NOTE(review): t is always < n_timestep, so this first branch looks
        # unreachable (weight never clipped to <= 0) — confirm intent.
        if t[0] > 1.0 * self.n_timestep:
            weight = min(self.guidance_weight, 0)
        elif t[0] < 0.1 * self.n_timestep:
            weight = min(self.guidance_weight, 1)
        else:
            weight = self.guidance_weight
        x_recon = self.predict_start_from_noise(
            x, t=t, noise=self.model.guided_forward(x, cond, t, weight)
        )
        if self.clip_denoised:
            x_recon.clamp_(-1.0, 1.0)
        else:
            # NOTE(review): `assert RuntimeError()` is a no-op — the exception
            # INSTANCE is truthy, so nothing is raised and execution falls
            # through. If the intent was to forbid clip_denoised=False this
            # should be `raise RuntimeError(...)`, but sampling currently
            # relies on falling through (clip_denoised defaults to False).
            assert RuntimeError()
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance, x_recon

    def p_sample(self, x, cond, t):
        """One ancestral (DDPM) reverse step: sample x_{t-1} given x_t."""
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x=x, cond=cond, t=t
        )
        noise = torch.randn_like(model_mean)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(
            b, *((1,) * (len(noise.shape) - 1))
        )
        x_out = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return x_out, x_start

    def p_sample_loop(  # Only used during inference
        self,
        shape,
        cond,
        noise=None,
        constraint=None,
        return_diffusion=False,
        start_point=None,
    ):
        """Full ancestral sampling loop from pure noise (or `noise`) to x_0."""
        device = self.betas.device
        # default to diffusion over whole timescale
        start_point = self.n_timestep if start_point is None else start_point
        batch_size = shape[0]
        x = torch.randn(shape, device=device) if noise is None else noise.to(device)
        cond = cond.to(device)
        if return_diffusion:
            diffusion = [x]
        for i in tqdm(reversed(range(0, start_point))):
            # fill with i
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x, _ = self.p_sample(x, cond, timesteps)
            if return_diffusion:
                diffusion.append(x)
        if return_diffusion:
            return x, diffusion
        else:
            return x

    def inpaint_ddim_guided(self, shape, noise=None, constraint=None, return_diffusion=False, start_point=None):
        """DDIM inpainting with gradient guidance pulling x_start toward
        `constraint['value']` (within `value_diff_thd`) on masked entries.

        Uses 50 DDIM steps with eta=0 (deterministic except for the initial
        noise and the q_sample re-noising of the constraint).
        """
        batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 0
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)  # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
        x = torch.randn(shape, device=device)
        cond = constraint["cond"].to(device)
        mask = constraint["mask"].to(device)  # batch x horizon x channels
        value = constraint["value"].to(device)  # batch x horizon x channels
        value_diff_thd = constraint["value_diff_thd"].to(device)  # channels
        value_diff_weight = constraint["value_diff_weight"].to(device)  # channels
        for time, time_next in time_pairs:
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            # Gradient guidance only inside the configured timestep window.
            if self.opt.guide_x_start_the_end_step <= time <= self.opt.guide_x_start_the_beginning_step:
                x.requires_grad_()
                with torch.enable_grad():
                    for step_ in range(self.opt.n_guided_steps):
                        pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, clip_x_start=self.clip_denoised)
                        value_diff = torch.subtract(x_start, value)
                        # Penalize only deviations beyond the per-channel threshold.
                        loss = torch.relu(value_diff.abs() - value_diff_thd) * value_diff_weight
                        grad = torch.autograd.grad([loss.sum()], [x])[0]
                        x = x - self.opt.guidance_lr * grad
            pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, clip_x_start=self.clip_denoised)
            if time_next < 0:
                # Last step: return x_start with observed entries re-imposed.
                x = x_start
                x = value * mask + (1.0 - mask) * x
                return x
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            noise = torch.randn_like(x)
            x = x_start * alpha_next.sqrt() + \
                c * pred_noise + \
                sigma * noise
            # Re-noise the known values to the current level and paste them in.
            # NOTE(review): noising uses `time` while x is at level `time_next`
            # — looks off by one step; confirm against the reference repo.
            timesteps = torch.full((batch,), time_next, device=device, dtype=torch.long)
            value_ = self.q_sample(value, timesteps) if (time > 0) else x
            x = value_ * mask + (1.0 - mask) * x
        return x

    def inpaint_ddim_loop(self, shape, noise=None, constraint=None, return_diffusion=False, start_point=None):
        """Plain DDIM inpainting (no gradient guidance); see inpaint_ddim_guided."""
        batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 0
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))
        x = torch.randn(shape, device=device)
        cond = constraint["cond"].to(device)
        mask = constraint["mask"].to(device)  # batch x horizon x channels
        value = constraint["value"].to(device)  # batch x horizon x channels
        for time, time_next in time_pairs:
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, clip_x_start=self.clip_denoised)
            if time_next < 0:
                x = x_start
                return x
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            noise = torch.randn_like(x)
            x = x_start * alpha_next.sqrt() + \
                c * pred_noise + \
                sigma * noise
            # Paste re-noised observed values over the generated ones.
            timesteps = torch.full((batch,), time_next, device=device, dtype=torch.long)
            value_ = self.q_sample(value, timesteps) if (time > 0) else x
            x = value_ * mask + (1.0 - mask) * x
        return x

    def q_sample(self, x_start, t, noise=None):  # blend noise into state variables
        """Forward process: produce x_t from x_0 at noise level t."""
        if noise is None:
            noise = torch.randn_like(x_start)
        sample = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
        return sample

    def forward(self, x, cond, t_override=None):
        # Training entry point; `self.loss` is defined elsewhere in the file.
        return self.loss(x, cond, t_override)

    def generate_samples(
        self,
        shape,
        normalizer,
        opt,
        mode,
        noise=None,
        constraint=None,
        start_point=None,
    ):
        """Generate (or pass through) samples and unnormalize them to units.

        `shape` may be a tuple (generate via the selected inpainting mode) or
        an already-generated tensor (only unnormalization is applied).
        NOTE(review): reseeds the GLOBAL torch RNG on every call, undoing the
        module-level manual_seed(0) — presumably intentional to diversify
        repeated calls; confirm.
        """
        torch.manual_seed(torch.randint(0, 2 ** 32, (1,)).item())
        if isinstance(shape, tuple):
            if mode == "inpaint":
                func_class = self.inpaint_ddim_loop
            elif mode == "inpaint_ddim_guided":
                func_class = self.inpaint_ddim_guided
            else:
                assert False, "Unrecognized inference mode"
            samples = (
                func_class(
                    shape,
                    noise=noise,
                    constraint=constraint,
                    start_point=start_point,
                )
            )
        else:
            samples = shape
        samples = normalizer.unnormalize(samples.detach().cpu())
        samples = samples.detach().cpu()
        return samples
class FillingBase:
    """Base class for imputing masked channels/frames of motion windows.

    NOTE(review): `update_kinematics_and_masks_for_masking_column` and
    `filling` are declared WITHOUT `self`, yet `fill_param` invokes them
    through the instance — bound-method calling shifts every argument by one
    (and the `progress_callback` keyword then collides with a positional),
    so as written this path raises a TypeError. The dropped `self`
    parameters look like a bug; confirm with the author before relying on
    these methods.
    """
    def fill_param(self, windows, diffusion_model_for_filling, progress_callback=None):
        # Fill the masked columns of every window using the column-masking rule.
        return self.filling(windows, diffusion_model_for_filling, self.update_kinematics_and_masks_for_masking_column, progress_callback=progress_callback)

    def update_kinematics_and_masks_for_masking_column(windows, samples, i_win, masks):
        # Frames with at least one unmasked channel count as observed frames.
        unmasked_samples_in_temporal_dim = (masks.sum(axis=2)).bool()
        for j_win in range(len(samples)):
            windows[j_win+i_win].pose = samples[j_win]
            updated_mask = windows[j_win+i_win].mask
            # Mark observed frames as known across all channels...
            updated_mask[unmasked_samples_in_temporal_dim[j_win], :] = 1
            # ...but keep the kinetic (GRF) columns flagged as unknown.
            # NOTE(review): `opt` here is a module-level global, not a local.
            updated_mask[:, opt.kinetic_diffusion_col_loc] = 0
            windows[j_win+i_win].mask = updated_mask
        return windows

    def update_kinematics_and_masks_for_masking_temporal(self, windows, samples, i_win, masks):
        # Store the filled poses and restore the pre-masking masks
        # (self.mask_original is expected to be set by a subclass/caller).
        for j_win in range(len(samples)):
            windows[j_win+i_win].pose = samples[j_win]
            windows[j_win+i_win].mask = self.mask_original[j_win+i_win]
        return windows

    def filling(windows, diffusion_model_for_filling, windows_update_func):
        # Abstract: subclasses implement the actual imputation strategy.
        raise NotImplementedError
class DiffusionFilling(FillingBase):
    """Imputes masked window content via DDIM inpainting with a diffusion model.

    NOTE(review): `filling` is declared without `self` (same issue as in
    FillingBase) and reads the module-level global `opt`; confirm the missing
    `self` with the author before calling this through `fill_param`.
    """
    def filling(windows, diffusion_model_for_filling, windows_update_func, progress_callback=None):
        # Work on a deep copy so the caller's windows are never mutated.
        windows = copy.deepcopy(windows)
        total_iterations = len(range(0, len(windows), opt.batch_size_inference))
        for iteration_idx, i_win in enumerate(range(0, len(windows), opt.batch_size_inference)):
            state_true = torch.stack([win.pose for win in windows[i_win:i_win+opt.batch_size_inference]])
            masks = torch.stack([win.mask for win in windows[i_win:i_win+opt.batch_size_inference]])
            cond = torch.ones([6])
            constraint = {'mask': masks, 'value': state_true.clone(), 'cond': cond}
            shape = (state_true.shape[0], state_true.shape[1], state_true.shape[2])
            samples = (diffusion_model_for_filling.diffusion.inpaint_ddim_loop(shape, constraint=constraint))
            # Keep observed entries as-is; take generated values only where masked.
            samples = state_true * masks + (1.0 - masks) * samples.to(state_true.device)
            # samples[:, :, opt.kinetic_diffusion_col_loc] = state_true[:, :, opt.kinetic_diffusion_col_loc]
            windows = windows_update_func(windows, samples, i_win, masks)
            # Report progress: 10% for first iteration, 80% for last iteration
            if progress_callback:
                progress = 0.1 + 0.7 * (iteration_idx + 1) / total_iterations
                progress_callback(progress)
        return windows

    def __str__(self):
        return 'diffusion_filling'
def wrap(x):
    """Prefix every key of a state-dict-like mapping with 'module.'
    (the naming DistributedDataParallel/DataParallel checkpoints use)."""
    return {f"module.{k}": v for k, v in x.items()}
def maybe_wrap(x, num):
    """Apply `wrap` only when more than one process/replica is in play."""
    if num == 1:
        return x
    return wrap(x)
def exists(val):
    """True when val is not None (falsy values like 0 or '' still 'exist')."""
    return not (val is None)
def broadcat(tensors, dim=-1):
    """Concatenate tensors along `dim`, broadcasting all other dimensions.

    Like torch.cat, but each non-concat dimension may hold at most two
    distinct sizes across the tensors (e.g. 1 and n); smaller tensors are
    expanded to the larger size before concatenation.
    """
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
    shape_len = list(shape_lens)[0]
    # Normalize a negative concat dim to its positive index.
    dim = (dim + shape_len) if dim < 0 else dim
    # dims[i] = the size of dimension i in each tensor.
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all(
        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
    ), "invalid dimensions for broadcastable concatentation"
    # For every non-concat dimension, broadcast to the largest size seen there.
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    # The concat dimension keeps each tensor's own size.
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)
def rotate_half(x):
    """Map each feature pair (x1, x2) to (-x2, x1) — the 90-degree rotation
    used by rotary position embeddings."""
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, "... d r -> ... (d r)")
def apply_rotary_emb(freqs, t, start_index=0):
    """Rotate the feature slice t[..., start_index:start_index+rot_dim] by the
    position-dependent angles in `freqs`; features outside the slice pass
    through unchanged."""
    freqs = freqs.to(t)
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
    left = t[..., :start_index]
    middle = t[..., start_index:end_index]
    right = t[..., end_index:]
    # Standard RoPE rotation: x*cos(theta) + rotate_half(x)*sin(theta).
    rotated = (middle * freqs.cos()) + (rotate_half(middle) * freqs.sin())
    return torch.cat((left, rotated, right), dim=-1)
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) frequency generator.

    Produces per-position rotation angles and caches the resulting frequency
    tables per sequence length. NOTE(review): `self.cache` is unbounded —
    harmless for a small fixed set of sequence lengths, a slow leak if
    lengths vary freely.
    """
    def __init__(
        self,
        dim,
        custom_freqs=None,
        freqs_for="lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
    ):
        super().__init__()
        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            # Standard transformer inverse-frequency spectrum: theta^(-2i/dim).
            freqs = 1.0 / (
                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
            )
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f"unknown modality {freqs_for}")
        self.cache = dict()
        if learned_freq:
            # Frequencies become trainable parameters.
            self.freqs = nn.Parameter(freqs)
        else:
            self.register_buffer("freqs", freqs)

    def rotate_queries_or_keys(self, t, seq_dim=-2):
        """Apply RoPE to t along its sequence dimension (default: dim -2)."""
        device = t.device
        seq_len = t.shape[seq_dim]
        # Positions are passed as a lazy thunk so they are only materialized
        # on a cache miss (forward() calls it only when needed).
        freqs = self.forward(
            lambda: torch.arange(seq_len, device=device), cache_key=seq_len
        )
        return apply_rotary_emb(freqs, t)

    def forward(self, t, cache_key=None):
        """Return the (possibly cached) frequency table for positions `t`."""
        if exists(cache_key) and cache_key in self.cache:
            return self.cache[cache_key]
        if isfunction(t):
            t = t()
        freqs = self.freqs
        # Outer product: one angle per (position, frequency) pair.
        freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        # Duplicate each angle so it spans an adjacent (x1, x2) feature pair.
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        if exists(cache_key):
            self.cache[cache_key] = freqs
        return freqs
class EMA:
    """Exponential moving average of model parameters with decay `beta`."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of `ma_model` toward `current_model`."""
        for current_params, ma_params in zip(
            current_model.parameters(), ma_model.parameters()
        ):
            ma_params.data = self.update_average(ma_params.data, current_params.data)

    def update_average(self, old, new):
        # No history yet: adopt the new value outright.
        return new if old is None else old * self.beta + (1 - self.beta) * new
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to the input embeddings.

    Uses log(max_len * 2) as the frequency base (instead of the usual
    log(10000)) and stores the table as a (1, max_len, d_model) buffer.
    """

    def __init__(self, d_model, max_len, dropout: float = 0.):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(max_len).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2) * (-math.log(max_len * 2) / d_model))
        table = torch.zeros(max_len, 1, d_model)
        table[:, 0, 0::2] = torch.sin(positions * freqs)
        # For odd d_model the cosine half has one column fewer than
        # sin spans, so slice the product down to the available width.
        n_cos = table[:, 0, 1::2].shape[-1]
        table[:, 0, 1::2] = torch.cos(positions * freqs)[:, :n_cos]
        self.register_buffer('pe', table.transpose(0, 1))

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        return self.dropout(x + self.pe[:, :x.shape[1], :])
class PositionWiseFeedForward(nn.Module):
    """Two-layer position-wise MLP with ReLU, as in the original Transformer."""

    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        # Keep submodule creation order (fc1, fc2, relu) for state-dict parity.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder block with rotary-embedded queries/keys."""

    def __init__(self, d_model, num_heads=8, d_ff=512, dropout=0.1):
        super(EncoderLayer, self).__init__()
        # Keep submodule creation order identical for state-dict/seed parity.
        self.rotary = RotaryEmbedding(dim=d_model)
        self.use_rotary = self.rotary is not None
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Rotary embedding rotates queries/keys only; values stay untouched.
        queries_keys = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        attended, _ = self.self_attn(queries_keys, queries_keys, x, need_weights=False)
        x = self.norm1(x + self.dropout(attended))
        x = self.norm2(x + self.dropout(self.feed_forward(x)))
        return x
class TransformerEncoderArchitecture(nn.Module):
    """Transformer encoder that predicts the non-kinematic (e.g. kinetic)
    channels of the state vector from the kinematic ones."""

    def __init__(self, repr_dim, opt, nlayers=6):
        super(TransformerEncoderArchitecture, self).__init__()
        self.input_dim = len(opt.kinematic_diffusion_col_loc)
        self.output_dim = repr_dim - self.input_dim
        embedding_dim = 192
        self.input_to_embedding = nn.Linear(self.input_dim, embedding_dim)
        self.encoder_layers = nn.Sequential(*[EncoderLayer(embedding_dim) for _ in range(nlayers)])
        self.embedding_to_output = nn.Linear(embedding_dim, self.output_dim)
        self.opt = opt
        self.input_col_loc = opt.kinematic_diffusion_col_loc
        # Output columns are everything the input columns don't cover.
        self.output_col_loc = [i for i in range(repr_dim) if i not in self.input_col_loc]

    def loss_fun(self, output_pred, output_true):
        # Per-element loss so callers can mask/weight before reducing.
        return F.mse_loss(output_pred, output_true, reduction='none')

    def end_to_end_prediction(self, x):
        # x[0]: (batch, seq, repr_dim); only the kinematic columns are consumed.
        inputs = x[0][:, :, self.input_col_loc]
        embedded = self.input_to_embedding(inputs)
        encoded = self.encoder_layers(embedded)
        return self.embedding_to_output(encoded)

    def __str__(self):
        return 'tf'
class DiffusionShellForAdaptingTheOriginalFramework(nn.Module):
    """Wraps a deterministic end-to-end model behind the diffusion-model API
    (same set_normalizer/forward surface), so the rest of the framework can
    treat both interchangeably."""

    def __init__(self, model):
        super(DiffusionShellForAdaptingTheOriginalFramework, self).__init__()
        self.device = device
        self.model = model
        self.ema = EMA(0.99)
        self.master_model = copy.deepcopy(self.model)

    def set_normalizer(self, normalizer):
        self.normalizer = normalizer

    def predict_samples(self, x, constraint):
        # Zero out masked entries, then overwrite output columns with predictions.
        x[0] = x[0] * constraint['mask']
        predictions = self.model.end_to_end_prediction(x)
        x[0][:, :, self.model.output_col_loc] = predictions
        return x[0]

    def forward(self, x, cond, t_override):
        target = x[0][:, :, self.model.output_col_loc]
        predictions = self.model.end_to_end_prediction(x)
        # Per-element loss, placed into a full-width tensor so the layout
        # matches what the diffusion trainer expects.
        per_element = torch.zeros(x[0].shape).to(x[0].device)
        per_element[:, :, self.model.output_col_loc] = self.model.loss_fun(predictions, target)
        # Auxiliary loss slots are unused here and kept at zero.
        losses = [
            1. * per_element.mean(),
            torch.tensor(0.).to(per_element.device),
            torch.tensor(0.).to(per_element.device),
            torch.tensor(0.).to(per_element.device),
            torch.tensor(0).to(per_element.device)]
        return sum(losses), losses + [per_element]
def featurewise_affine(x, scale_shift):
    """FiLM conditioning: scale x by (scale + 1) and add shift, feature-wise.

    The +1 keeps a zero-initialized generator at the identity transform.
    """
    scale, shift = scale_shift
    return x * (scale + 1) + shift
class DenseFiLM(nn.Module):
    """Feature-wise linear modulation (FiLM) generator."""

    def __init__(self, embed_channels):
        super().__init__()
        self.embed_channels = embed_channels
        # Mish then a linear producing both halves (scale and shift) at once.
        self.block = nn.Sequential(
            nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
        )

    def forward(self, position):
        # Returns (scale, shift), each (batch, 1, channels), so they
        # broadcast over the time dimension of the modulated features.
        encoded = self.block(position)
        encoded = rearrange(encoded, "b c -> b 1 c")
        return encoded.chunk(2, dim=-1)
class FiLMTransformerDecoderLayer(nn.Module):
    """Transformer decoder layer whose residual branches are FiLM-modulated
    by the diffusion timestep embedding `t`.

    NOTE(review): `forward` uses only self-attention and the feed-forward
    branch — `memory` and the cross-attention machinery (`multihead_attn`,
    `_mha_block`) are never invoked in either norm ordering.
    """
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        layer_norm_eps=1e-5,
        batch_first=False,
        norm_first=True,
        device=None,
        dtype=None,
        rotary=None,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        self.multihead_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # Feedforward
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        # NOTE(review): `self.dropout2` is referenced in `_mha_block` below but
        # never created here — calling `_mha_block` would raise AttributeError.
        self.dropout3 = nn.Dropout(dropout)
        self.activation = activation
        self.film1 = DenseFiLM(d_model)
        self.film3 = DenseFiLM(d_model)
        self.rotary = rotary
        self.use_rotary = rotary is not None

    # x, cond, t
    def forward(
        self,
        tgt,
        memory,
        t,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
    ):
        """FiLM(t)-modulated self-attention then feed-forward, pre- or
        post-norm depending on `norm_first`. `memory` is accepted but unused."""
        x = tgt
        if self.norm_first:
            x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
            x = x + featurewise_affine(x_1, self.film1(t))
            x_3 = self._ff_block(self.norm3(x))
            x = x + featurewise_affine(x_3, self.film3(t))
        else:
            x = self.norm1(
                x
                + featurewise_affine(
                    self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
                )
            )
            x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
        return x

    # self-attention block
    # qkv
    def _sa_block(self, x, attn_mask, key_padding_mask):
        # Rotary embedding rotates queries/keys only; values stay untouched.
        qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        x = self.self_attn(
            qk,
            qk,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    # qkv
    def _mha_block(self, x, mem, attn_mask, key_padding_mask):
        # NOTE(review): dead code — never called by forward(); also depends on
        # the missing `self.dropout2` (see __init__).
        q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
        x = self.multihead_attn(
            q,
            k,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)
class DecoderLayerStack(nn.Module):
    """Sequentially apply a stack of decoder layers sharing the (x, cond, t)
    call signature, threading the output of each layer into the next."""

    def __init__(self, stack):
        super().__init__()
        self.stack = stack

    def forward(self, x, cond, t):
        out = x
        for decoder_layer in self.stack:
            out = decoder_layer(out, cond, t)
        return out
class DanceDecoder(nn.Module):
    """Transformer denoiser for the diffusion model (unconditional variant).

    Attribute names are load-bearing: checkpoints key their parameters by them.
    """

    def __init__(
        self,
        nfeats: int,
        seq_len: int = 150,  # 5 seconds, 30 fps
        latent_dim: int = 512,
        ff_size: int = 1024,
        num_layers: int = 8,
        num_heads: int = 8,
        dropout: float = 0.1,
        # cond_feature_dim: int = 6,
        activation=F.gelu,
        use_rotary=True,
        **kwargs
    ) -> None:
        super().__init__()
        # Positional information: rotary embeddings inside attention when
        # enabled, otherwise an absolute positional encoding on the input
        # (and the unused alternative degenerates to None / Identity).
        if use_rotary:
            self.rotary = RotaryEmbedding(dim=latent_dim)
            self.abs_pos_encoding = nn.Identity()
        else:
            self.rotary = None
            self.abs_pos_encoding = PositionalEncoding(
                latent_dim, dropout, batch_first=True
            )
        # Diffusion-timestep embedding: sinusoidal features widened by an MLP,
        # then projected back down to the latent width for FiLM conditioning.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(latent_dim),
            nn.Linear(latent_dim, latent_dim * 4),
            nn.Mish(),
        )
        self.to_time_cond = nn.Sequential(nn.Linear(latent_dim * 4, latent_dim))
        self.input_projection = nn.Linear(nfeats, latent_dim)
        self.seqTransDecoder = DecoderLayerStack(
            nn.ModuleList(
                FiLMTransformerDecoderLayer(
                    latent_dim,
                    num_heads,
                    dim_feedforward=ff_size,
                    dropout=dropout,
                    activation=activation,
                    batch_first=True,
                    rotary=self.rotary,
                )
                for _ in range(num_layers)
            )
        )
        self.final_layer = nn.Linear(latent_dim, nfeats)

    def guided_forward(self, x, cond_embed, time_cond, guidance_weight):
        # Unconditional model: the guidance weight has no effect.
        return self.forward(x, cond_embed, time_cond)

    def __str__(self):
        return 'diffusion'

    # No conditioning version
    def forward(self, x: Tensor, cond_embed: Tensor, time_cond: Tensor, cond_drop_prob: float = 0.0):
        """Denoise ``x`` at diffusion step ``time_cond``.

        ``cond_embed`` and ``cond_drop_prob`` are accepted for interface
        compatibility but ignored.
        """
        hidden = self.abs_pos_encoding(self.input_projection(x))
        t = self.to_time_cond(self.time_mlp(time_cond))
        decoded = self.seqTransDecoder(hidden, None, t)
        return self.final_layer(decoded)
class MotionModel:
    """Wrapper around the diffusion-based motion model.

    Builds the DanceDecoder denoiser and its GaussianDiffusion process,
    optionally restores weights and the feature normalizer from
    ``opt.checkpoint``, and exposes a constrained sample-generation loop.
    """

    def __init__(
        self,
        opt,
        normalizer=None,      # NOTE(review): unused; the normalizer is read from the checkpoint
        EMA=True,             # load EMA weights instead of raw model weights
        learning_rate=4e-4,   # NOTE(review): unused in this inference-oriented wrapper
        weight_decay=0.02,    # NOTE(review): unused in this inference-oriented wrapper
    ):
        self.opt = opt
        # Tolerate parameters that receive no gradient under DDP.
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
        state = AcceleratorState()
        num_processes = state.num_processes
        # Feature dimension of the model-state vector and temporal window length.
        self.repr_dim = len(opt.model_states_column_names)
        self.horizon = horizon = opt.window_len
        self.accelerator.wait_for_everyone()
        checkpoint = None
        if opt.checkpoint != "":
            checkpoint = torch.load(
                opt.checkpoint, map_location=self.accelerator.device,
                weights_only=False
            )
            # NOTE(review): self.normalizer only exists when a checkpoint is
            # given; eval_loop() relies on it — confirm callers always set
            # opt.checkpoint before inference.
            self.normalizer = checkpoint["normalizer"]
        model = DanceDecoder(
            nfeats=self.repr_dim,
            seq_len=horizon,
            latent_dim=512,
            ff_size=1024,
            num_layers=8,
            num_heads=8,
            dropout=0.1,
            activation=F.gelu,
        )
        diffusion = GaussianDiffusion(
            model,
            horizon,
            self.repr_dim,
            opt,
            # schedule="cosine",
            n_timestep=1000,
            predict_epsilon=False,
            loss_type="l2",
            use_p2=False,
            cond_drop_prob=0.,
            guidance_weight=2,
        )
        self.model = self.accelerator.prepare(model)
        self.diffusion = diffusion.to(self.accelerator.device)
        if opt.checkpoint != "":
            # maybe_wrap adjusts state-dict key prefixes for the DDP wrapper.
            self.model.load_state_dict(
                maybe_wrap(
                    checkpoint["ema_state_dict" if EMA else "model_state_dict"],
                    num_processes,
                )
            )

    def eval(self):
        """Switch the diffusion model to evaluation mode."""
        self.diffusion.eval()

    def prepare(self, objects):
        """Wrap the given objects (models, dataloaders, ...) with the accelerator."""
        return self.accelerator.prepare(*objects)

    def eval_loop(self, opt, state_true, masks, value_diff_thd=None, value_diff_weight=None, cond=None,
                  num_of_generation_per_window=1, mode="inpaint"):
        """Generate constrained samples for each window.

        state_true presumably has shape (windows, horizon, repr_dim) — the
        per-channel defaults below index state_true.shape[2]; TODO confirm.
        Returns a stacked tensor with one entry per generation.
        """
        self.eval()
        # Defaults: zero tolerance on value differences, uniform channel weights.
        if value_diff_thd is None:
            value_diff_thd = torch.zeros([state_true.shape[2]])
        if value_diff_weight is None:
            value_diff_weight = torch.ones([state_true.shape[2]])
        if cond is None:
            cond = torch.ones([6])
        constraint = {'mask': masks, 'value': state_true.clone(), 'value_diff_thd': value_diff_thd,
                      'value_diff_weight': value_diff_weight, 'cond': cond}
        shape = (state_true.shape[0], self.horizon, self.repr_dim)
        state_pred_list = [self.diffusion.generate_samples(
            shape,
            self.normalizer,
            opt,
            mode=mode,
            constraint=constraint)
            for _ in range(num_of_generation_per_window)]
        return torch.stack(state_pred_list)
class BaselineModel:
    """Wrapper for baseline (non-diffusion) architectures.

    Adapts a plain network to the same evaluation interface as MotionModel by
    wrapping it in DiffusionShellForAdaptingTheOriginalFramework.
    """

    def __init__(
        self,
        opt,
        model_architecture_class,  # callable: (repr_dim, opt) -> nn.Module
        EMA=True,                  # load EMA weights instead of raw model weights
    ):
        self.opt = opt
        # Tolerate parameters that receive no gradient under DDP.
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
        self.repr_dim = len(opt.model_states_column_names)
        self.horizon = horizon = opt.window_len
        self.accelerator.wait_for_everyone()
        # Build the baseline network and wrap it so it can be driven like the
        # diffusion model elsewhere in the pipeline.
        self.model = model_architecture_class(self.repr_dim, opt)
        self.diffusion = DiffusionShellForAdaptingTheOriginalFramework(self.model)
        self.diffusion = self.accelerator.prepare(self.diffusion)
        checkpoint = None
        if opt.checkpoint_bl != "":
            checkpoint = torch.load(
                opt.checkpoint_bl, map_location=self.accelerator.device,
                weights_only=False
            )
            # NOTE(review): self.normalizer only exists when a baseline
            # checkpoint is supplied; eval_loop() relies on it.
            self.normalizer = checkpoint["normalizer"]
        if opt.checkpoint_bl != "":
            self.model.load_state_dict(
                maybe_wrap(
                    checkpoint["ema_state_dict" if EMA else "model_state_dict"],
                    1,
                )
            )

    def eval_loop(self, opt, state_true, masks, value_diff_thd=None, value_diff_weight=None, cond=None,
                  num_of_generation_per_window=1, mode="inpaint"):
        """Predict, unnormalize, and stack one sample per requested generation.

        value_diff_thd / value_diff_weight are accepted for interface parity
        with MotionModel.eval_loop but are not used by the baseline path.
        """
        self.eval()
        # NOTE(review): 'value' stores the tensor from before the device move
        # below, while 'mask' is moved explicitly — confirm downstream code
        # handles devices for the constraint values itself.
        constraint = {'mask': masks.to(self.accelerator.device), 'value': state_true, 'cond': cond}
        state_true = state_true.to(self.accelerator.device)
        state_pred_list = [self.diffusion.predict_samples([state_true], constraint)
                           for _ in range(num_of_generation_per_window)]
        state_pred_list = [self.normalizer.unnormalize(state_pred.detach().cpu()) for state_pred in state_pred_list]
        return torch.stack(state_pred_list)

    def eval(self):
        """Switch to evaluation mode."""
        self.diffusion.eval()

    def train(self):
        """Switch to training mode."""
        self.diffusion.train()

    def prepare(self, objects):
        """Wrap the given objects with the accelerator."""
        return self.accelerator.prepare(*objects)
| """ ============================ End model.py ============================ """ | |
| """ ============================ Start util.py ============================ """ | |
def load_diffusion_model(opt):
    """Point ``opt`` at the bundled diffusion checkpoint and build the model.

    Returns the loaded MotionModel together with its registry key.
    """
    opt.checkpoint = f"{opt.subject_data_path}/GaitDynamicsDiffusion.pt"
    return MotionModel(opt), 'diffusion'
class SinusoidalPosEmb(nn.Module):
    """Classic transformer sinusoidal embedding for scalar positions/timesteps.

    Maps a 1-D tensor of positions to (len(x), dim) features: the first half
    of the channels are sines, the second half cosines, over a geometric
    frequency ladder from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # Log-spaced frequencies: exp(-k * log(10000) / (half_dim - 1)).
        decay = -math.log(10000) / (half_dim - 1)
        freqs = torch.exp(decay * torch.arange(half_dim, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
def convertDfToGRFMot(df, out_path, dt, time_column):
    """Write ground reaction forces from ``df`` to an OpenSim .mot file.

    Args:
        df: DataFrame holding, for each side in ('r', 'l'), the columns
            ``calcn_{side}_force_{vx,vy,vz}`` (force vector) and
            ``calcn_{side}_force_normed_cop_{x,y,z}`` (center of pressure).
            Must have a default integer index.
        out_path: destination path of the .mot file.
        dt: sampling interval in seconds.
        time_column: sequence whose first element is the start time.
    """
    num_frames = df.shape[0]
    sides = ('r', 'l')  # one force plate per foot, right first
    # One time column plus 9 channels (3 force, 3 CoP, 3 torque) per plate.
    num_columns = 1 + 9 * len(sides)
    with open(out_path, 'w') as out_file:
        # Header. BUGFIX: nColumns was hard-coded to 9, but 19 columns are
        # written below; the declared count must match the data columns.
        out_file.write(f'nColumns={num_columns}\n')
        out_file.write(f'nRows={num_frames}\n')
        out_file.write('DataType=double\n')
        out_file.write('version=3\n')
        out_file.write('OpenSimVersion=4.1\n')
        out_file.write('endheader\n')
        # Column-name row.
        out_file.write('time')
        for side in sides:
            for name in (f'force_{side}_vx', f'force_{side}_vy', f'force_{side}_vz',
                         f'force_{side}_px', f'force_{side}_py', f'force_{side}_pz',
                         f'torque_{side}_x', f'torque_{side}_y', f'torque_{side}_z'):
                out_file.write('\t' + name)
        out_file.write('\n')
        # Data rows: force vector, CoP, and zero free moments for each side.
        for i in range(num_frames):
            out_file.write(str(round(dt * i + time_column[0], 5)))
            for side in sides:
                for col in (f'calcn_{side}_force_vx', f'calcn_{side}_force_vy',
                            f'calcn_{side}_force_vz', f'calcn_{side}_force_normed_cop_x',
                            f'calcn_{side}_force_normed_cop_y', f'calcn_{side}_force_normed_cop_z'):
                    out_file.write('\t' + str(df[col][i]))
                out_file.write('\t0\t0\t0')  # free moments are not modeled
            out_file.write('\n')
    print('Ground reaction forces exported to ' + out_path)
def convertDataframeToMot(df, out_path, dt, time_column):
    """Write a kinematics DataFrame to an OpenSim coordinates .mot file.

    Args:
        df: DataFrame of joint coordinates, one column per coordinate
            (values in radians, per the inDegrees=no header).
        out_path: destination path of the .mot file.
        dt: sampling interval in seconds.
        time_column: sequence whose first element is the start time.
    """
    num_frames = df.shape[0]
    # Use a context manager so the file is closed even if a write fails.
    with open(out_path, 'w') as out_file:
        # Header: the column count includes the time column.
        out_file.write('Coordinates\n')
        out_file.write('version=1\n')
        out_file.write(f'nRows={num_frames}\n')
        out_file.write(f'nColumns={len(df.columns)+1}\n')
        out_file.write('inDegrees=no\n\n')
        out_file.write('If the header above contains a line with \'inDegrees\', this indicates whether rotational values are in degrees (yes) or radians (no).\n\n')
        out_file.write('endheader\n')
        # Column-name row, then one tab-separated row per frame.
        out_file.write('time')
        for column_name in df.columns:
            out_file.write('\t' + column_name)
        out_file.write('\n')
        for i in range(num_frames):
            out_file.write(str(round(dt * i + time_column[0], 5)))
            for j in range(len(df.columns)):
                out_file.write('\t' + str(df.iloc[i, j]))
            out_file.write('\n')
    print('Missing kinematics exported to ' + out_path)
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Convert the 6-D rotation representation (first two matrix rows) into
    full 3x3 rotation matrices via Gram-Schmidt orthogonalization."""
    raw_x, raw_y = d6[..., :3], d6[..., 3:]
    x_axis = F.normalize(raw_x, dim=-1)
    # Remove the component of raw_y along x, then renormalize.
    y_axis = F.normalize(
        raw_y - (x_axis * raw_y).sum(-1, keepdim=True) * x_axis, dim=-1
    )
    z_axis = torch.cross(x_axis, y_axis, dim=-1)
    return torch.stack((x_axis, y_axis, z_axis), dim=-2)
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """Flatten the first two rows of each 3x3 rotation matrix into the 6-D
    rotation representation (inverse of rotation_6d_to_matrix)."""
    leading_shape = matrix.size()[:-2]
    top_two_rows = matrix[..., :2, :].clone()
    return top_two_rows.reshape(*leading_shape, 6)
def euler_from_6v(q, convention="XYZ"):
    """Convert 6-D rotation vectors to Euler angles (via rotation matrices)."""
    assert q.shape[-1] == 6
    return matrix_to_euler_angles(rotation_6d_to_matrix(q), convention)
def euler_to_6v(q, convention="XYZ"):
    """Convert Euler angles to the 6-D rotation representation."""
    assert q.shape[-1] == 3
    return matrix_to_rotation_6d(euler_angles_to_matrix(q, convention))
def second_order_poly(coeff, x):
    """Evaluate the quadratic a*x^2 + b*x + c, with (a, b, c) stored along the
    last dimension of ``coeff`` (matches numpy.polyfit ordering)."""
    quad, lin, const = coeff[..., 0], coeff[..., 1], coeff[..., 2]
    return quad * x ** 2 + lin * x + const
def batch_identity(batch_shape, size):
    """Return a tensor of shape (*batch_shape, size, size) whose trailing 2-D
    slices are all identity matrices, as a fresh writable copy."""
    eye = torch.eye(size)
    target_shape = list(batch_shape) + [size, size]
    # Prepend singleton dims so the identity broadcasts over the batch dims.
    n_lead = len(target_shape) - eye.ndim
    broadcastable = eye.view((1,) * n_lead + tuple(eye.shape))
    return broadcastable.expand(target_shape).clone()
def get_knee_rotation_coefficients():
    """Fit quadratic polynomials to the Walker-knee coupled-motion tables.

    In the walker knee model, the secondary knee rotations and translations are
    functions of the knee flexion angle. Each coupled motion is tabulated at 13
    flexion angles and approximated here by a degree-2 polynomial.

    Returns:
        np.ndarray of shape (3, 5): polyfit coefficients (highest degree first,
        matching second_order_poly's (a, b, c) expectation), one column per
        coupled motion in the order (Y rotation, Z rotation, X translation,
        Y translation, Z translation).
    """
    # Knee flexion angles (rad; 0-120 deg in 10-deg steps) at which the
    # coupled motions are tabulated.
    flexion_grid = np.array([0, 0.174533, 0.349066, 0.523599, 0.698132, 0.872665, 1.0472, 1.22173,
                             1.39626, 1.5708, 1.74533, 1.91986, 2.0944])
    # Coupled-motion values at each grid point. The list order fixes the
    # output column order documented above.
    coupled_motions = [
        # knee Y rotation (rad)
        [0, 0.059461, 0.109399, 0.150618, 0.18392, 0.210107, 0.229983, 0.24435,
         0.254012, 0.25977, 0.262428, 0.262788, 0.261654],
        # knee Z rotation (rad)
        [0, 0.0126809, 0.0226969, 0.0296054, 0.0332049, 0.0335354, 0.0308779,
         0.0257548, 0.0189295, 0.011407, 0.00443314, -0.00050475, -0.0016782],
        # knee X translation (m)
        [0, 5.3e-05, 0.000188, 0.000378, 0.000597, 0.000825, 0.001045, 0.001247,
         0.00142, 0.001558, 0.001661, 0.001728, 0.00176],
        # knee Y translation (m)
        [0, 0.000301, 0.000143, -0.000401, -0.001233, -0.002243, -0.003316,
         -0.004346, -0.005239, -0.005924, -0.006361, -0.006539, -0.00648],
        # knee Z translation (m)
        [0, 0.001055, 0.002061, 0.00289, 0.003447, 0.003676, 0.003559, 0.00311,
         0.002373, 0.001418, 0.000329, -0.000805, -0.001898],
    ]
    coefficient_sets = [np.polyfit(flexion_grid, np.array(values), deg=2)
                        for values in coupled_motions]
    # Shape (3, 5): rows are polynomial degrees, columns are coupled motions.
    return np.stack(coefficient_sets, axis=1)
# Fit the Walker-knee coupling polynomials once at import time and keep the
# coefficient table on the active compute device for forward_kinematics.
walker_knee_coefficients = torch.tensor(get_knee_rotation_coefficients()).to(device)
def forward_kinematics(pose, offsets, with_arm=False):
    """Compute body-segment transforms from generalized coordinates via
    chained 4x4 homogeneous transforms.

    Pose indices
    0-5: pelvis orientation + translation
    6-8: hip_r
    9: knee_r
    10: ankle_r
    11: subtalar_r
    12: mtp_r
    13-15: hip_l
    16: knee_l
    17: ankle_l
    18: subtalar_l
    19: mtp_l
    20-22: lumbar
    23-25: shoulder_r
    26: elbow_r
    27: radioulnar
    28-30: shoulder_l
    31: elbow_l
    32: radioulnar_l

    Args:
        pose: tensor or ndarray of joint coordinates, last dim indexed per the
            table above; a 2-D input gets a leading batch dim prepended.
        offsets: stacked 4x4 joint/body offset transforms indexed by the
            trailing dimension (see get_model_offsets); a rank-3 input gets a
            batch dim prepended, and a time dim is inserted for broadcasting.
        with_arm: also evaluate the arm chain (requires offsets 20-35).

    Returns:
        (foot_locations, joint_locations, joint_names, segment_orientations).
        NOTE(review): joint_names lists only the lower-body joints even when
        with_arm appends arm joints to joint_locations — confirm intended.
    """
    if isinstance(pose, np.ndarray):
        pose = torch.from_numpy(pose)
    pose = pose.to(torch.device(device))
    # Ensure batch dims: pose -> (batch, frame, coord); offsets gain a batch
    # dim if needed plus a singleton time dim so they broadcast over frames.
    if len(pose.shape) == 2:
        pose = pose[None, ...]
    if len(offsets.shape) == 3:
        offsets = offsets[None, ...]
    offsets = offsets.to(pose.device)
    offsets = offsets[:, None, ...]
    # Materialize the batch shape as a list of python ints for batch_identity.
    batch_shape = pose.shape[:-1]
    batch_shape_list = []
    for i in range(pose.dim()-1):
        batch_shape_list.append(int(batch_shape[i]))
    batch_shape = batch_shape_list
    if batch_shape == ():  # NOTE(review): dead branch — batch_shape is a list here, never ()
        batch_shape = (1,)
    # Walker-knee coupling: secondary knee rotations/translations are quadratic
    # functions of knee flexion (see get_knee_rotation_coefficients).
    coefficients_knee_Y_rotation = walker_knee_coefficients[..., 0]
    coefficients_knee_Z_rotation = walker_knee_coefficients[..., 1]
    coefficients_knee_X_translation = walker_knee_coefficients[..., 2]
    coefficients_knee_Y_translation = walker_knee_coefficients[..., 3]
    coefficients_knee_Z_translation = walker_knee_coefficients[..., 4]
    # Evaluate the coupled motions at the right (index 9) and left (index 16)
    # knee flexion angles.
    knee_r_Y_rot = second_order_poly(coefficients_knee_Y_rotation, pose[..., 9])
    knee_r_Z_rot = second_order_poly(coefficients_knee_Z_rotation, pose[..., 9])
    knee_r_X_trans = second_order_poly(coefficients_knee_X_translation, pose[..., 9])
    knee_r_Y_trans = second_order_poly(coefficients_knee_Y_translation, pose[..., 9])
    knee_r_Z_trans = second_order_poly(coefficients_knee_Z_translation, pose[..., 9])
    knee_l_Y_rot = second_order_poly(coefficients_knee_Y_rotation, pose[..., 16])
    knee_l_Z_rot = second_order_poly(coefficients_knee_Z_rotation, pose[..., 16])
    knee_l_X_trans = second_order_poly(coefficients_knee_X_translation, pose[..., 16])
    knee_l_Y_trans = second_order_poly(coefficients_knee_Y_translation, pose[..., 16])
    knee_l_Z_trans = second_order_poly(coefficients_knee_Z_translation, pose[..., 16])
    # Pelvis: root of the kinematic tree, posed directly from coordinates 0-5.
    pelvis_transform = batch_identity(batch_shape, 4).to(pose.device)
    pelvis_transform[..., :3, :3] = euler_angles_to_matrix(pose[..., 0:3], 'ZXY')
    pelvis_transform[..., :3, 3] = pose[..., 3:6].clone().detach()
    # Get offsets (model and model scaling dependent)
    offset_hip_pelvis_r = offsets[..., 0]
    femur_offset_in_hip_r = offsets[..., 1]
    knee_offset_in_femur_r = offsets[..., 2]
    tibia_offset_in_knee_r = offsets[..., 3]
    ankle_offset_in_tibia_r = offsets[..., 4]
    talus_offset_in_ankle_r = offsets[..., 5]
    subtalar_offset_in_talus_r = offsets[..., 6]
    calcaneus_offset_in_subtalar_r = offsets[..., 7]
    mtp_offset_in_calcaneus_r = offsets[..., 8]
    offset_hip_pelvis_l = offsets[..., 9]
    femur_offset_in_hip_l = offsets[..., 10]
    knee_offset_in_femur_l = offsets[..., 11]
    tibia_offset_in_knee_l = offsets[..., 12]
    ankle_offset_in_tibia_l = offsets[..., 13]
    talus_offset_in_ankle_l = offsets[..., 14]
    subtalar_offset_in_talus_l = offsets[..., 15]
    calcaneus_offset_in_subtalar_l = offsets[..., 16]
    mtp_offset_in_calcaneus_l = offsets[..., 17]
    lumbar_offset_in_pelvis = offsets[..., 18]
    torso_offset_in_lumbar = offsets[..., 19]
    if with_arm:
        shoulder_offset_in_torso_r = offsets[..., 20]
        humerus_offset_in_shoulder_r = offsets[..., 21]
        elbow_offset_in_humerus_r = offsets[..., 22]
        ulna_offset_in_elbow_r = offsets[..., 23]
        radioulnar_offset_in_radius_r = offsets[..., 24]
        radius_offset_in_radioulnar_r = offsets[..., 25]
        wrist_offset_in_radius_r = offsets[..., 26]
        hand_offset_in_wrist_r = offsets[..., 27]
        shoulder_offset_in_torso_l = offsets[..., 28]
        humerus_offset_in_shoulder_l = offsets[..., 29]
        elbow_offset_in_humerus_l = offsets[..., 30]
        ulna_offset_in_elbow_l = offsets[..., 31]
        radioulnar_offset_in_radius_l = offsets[..., 32]
        radius_offset_in_radioulnar_l = offsets[..., 33]
        wrist_offset_in_radius_l = offsets[..., 34]
        hand_offset_in_wrist_l = offsets[..., 35]
    # Coordinates to transformation matrix: start each joint's local transform
    # as identity, then fill in rotation (and knee translation) below.
    hip_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device)
    knee_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device)
    ankle_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device)
    subtalar_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device)
    mtp_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device)
    hip_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device)
    knee_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device)
    ankle_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device)
    subtalar_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device)
    mtp_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device)
    lumbar_coordinates_transform = batch_identity(batch_shape, 4).to(pose.device)
    if with_arm:
        shoulder_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device)
        elbow_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device)
        radioulnar_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device)
        wrist_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device)
        shoulder_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device)
        elbow_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device)
        radioulnar_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device)
        wrist_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device)
    # Knee axis translation (coupled to flexion; Z is mirrored on the left side)
    knee_coordinates_transform_r[..., :3, -1] = torch.stack((knee_r_X_trans, knee_r_Y_trans, knee_r_Z_trans), dim=-1)
    knee_coordinates_transform_l[..., :3, -1] = torch.stack((knee_l_X_trans, knee_l_Y_trans, -knee_l_Z_trans), dim=-1)
    # Joint rotations: zero-padding used for 1-DoF joints below.
    zero_2_shape = batch_shape.copy()
    zero_2_shape.append(2)
    zero_2 = torch.zeros(tuple(zero_2_shape), device=pose.device)
    zero_3_shape = batch_shape.copy()
    zero_3_shape.append(3)
    zero_3 = torch.zeros(tuple(zero_3_shape), device=pose.device)
    hip_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix((pose[..., 6:9]), 'ZXY')
    knee_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix(
        torch.stack((pose[..., 9], knee_r_Y_rot, knee_r_Z_rot), dim=-1), 'XYZ')
    ankle_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix(
        torch.cat((pose[..., 10:11], zero_2), dim=-1), 'ZXY')
    subtalar_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix(
        torch.cat((pose[..., 11:12], zero_2), dim=-1), 'ZXY')
    mtp_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix(
        torch.cat((pose[..., 12:13], zero_2), dim=-1), 'ZXY')
    # Left side mirrors the right: Y/Z hip angles and knee flexion/Y-rotation
    # are negated.
    hip_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix(
        torch.cat(([pose[..., 13:14], -pose[..., 14:15], -pose[..., 15:16]]), dim=-1), 'ZXY')
    knee_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix(
        torch.stack((-pose[..., 16], -knee_l_Y_rot, knee_l_Z_rot), dim=-1), 'XYZ')
    ankle_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix(
        torch.cat((pose[..., 17:18], zero_2), dim=-1), 'ZXY')
    subtalar_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix(
        torch.cat((pose[..., 18:19], zero_2), dim=-1), 'ZXY')
    mtp_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix(
        torch.cat((pose[..., 19:20], zero_2), dim=-1), 'ZXY')
    lumbar_coordinates_transform[..., :3, :3] = euler_angles_to_matrix(
        torch.cat((pose[..., 20:21], pose[..., 21:22], pose[..., 22:23]), dim=-1), 'ZXY')
    if with_arm:
        shoulder_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix(
            torch.cat((pose[..., 23:24], pose[..., 24:25], pose[..., 25:26]), dim=-1), 'ZXY')
        elbow_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix(
            torch.cat((pose[..., 26:27], zero_2), dim=-1), 'ZXY')
        radioulnar_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix(
            torch.cat((pose[..., 27:28], zero_2), dim=-1), 'ZXY')
        # Wrist has no pose coordinate; its local rotation stays identity.
        wrist_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix(
            zero_3, 'ZXY')
        shoulder_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix(
            torch.cat((pose[..., 28:29], -pose[..., 29:30], -pose[..., 30:31]), dim=-1), 'ZXY')
        elbow_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix(
            torch.cat((pose[..., 31:32], zero_2), dim=-1), 'ZXY')
        radioulnar_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix(
            torch.cat((pose[..., 32:33], zero_2), dim=-1), 'ZXY')
        wrist_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix(
            zero_3, 'ZXY')
    # Forward kinematics for the lower body: each joint = parent transform *
    # fixed offset * local coordinate transform; each body = joint * offset.
    hip_transform_r = torch.matmul(torch.matmul(pelvis_transform, offset_hip_pelvis_r), hip_coordinates_transform_r)
    femur_transform_r = torch.matmul(hip_transform_r, femur_offset_in_hip_r)
    knee_transform_r = torch.matmul(torch.matmul(femur_transform_r, knee_offset_in_femur_r),
                                    knee_coordinates_transform_r)
    tibia_transform_r = torch.matmul(knee_transform_r, tibia_offset_in_knee_r)
    ankle_transform_r = torch.matmul(torch.matmul(tibia_transform_r, ankle_offset_in_tibia_r),
                                     ankle_coordinates_transform_r)
    talus_transform_r = torch.matmul(ankle_transform_r, talus_offset_in_ankle_r)
    subtalar_transform_r = torch.matmul(torch.matmul(talus_transform_r, subtalar_offset_in_talus_r),
                                        subtalar_coordinates_transform_r)
    calcaneus_transform_r = torch.matmul(subtalar_transform_r, calcaneus_offset_in_subtalar_r)
    mtp_offset_transform_r = torch.matmul(torch.matmul(calcaneus_transform_r, mtp_offset_in_calcaneus_r),
                                          mtp_coordinates_transform_r)
    hip_transform_l = torch.matmul(torch.matmul(pelvis_transform, offset_hip_pelvis_l), hip_coordinates_transform_l)
    femur_transform_l = torch.matmul(hip_transform_l, femur_offset_in_hip_l)
    knee_transform_l = torch.matmul(torch.matmul(femur_transform_l, knee_offset_in_femur_l),
                                    knee_coordinates_transform_l)
    tibia_transform_l = torch.matmul(knee_transform_l, tibia_offset_in_knee_l)
    ankle_transform_l = torch.matmul(torch.matmul(tibia_transform_l, ankle_offset_in_tibia_l),
                                     ankle_coordinates_transform_l)
    talus_transform_l = torch.matmul(ankle_transform_l, talus_offset_in_ankle_l)
    subtalar_transform_l = torch.matmul(torch.matmul(talus_transform_l, subtalar_offset_in_talus_l),
                                        subtalar_coordinates_transform_l)
    calcaneus_transform_l = torch.matmul(subtalar_transform_l, calcaneus_offset_in_subtalar_l)
    mtp_offset_transform_l = torch.matmul(torch.matmul(calcaneus_transform_l, mtp_offset_in_calcaneus_l),
                                          mtp_coordinates_transform_l)
    # Forward kinematics for the upper body
    lumbar_transform = torch.matmul(torch.matmul(pelvis_transform, lumbar_offset_in_pelvis), lumbar_coordinates_transform)
    torso_transform = torch.matmul(lumbar_transform, torso_offset_in_lumbar)
    if with_arm:
        shoulder_transform_r = torch.matmul(torch.matmul(torso_transform, shoulder_offset_in_torso_r), shoulder_coordinates_transform_r)
        humerus_transform_r = torch.matmul(shoulder_transform_r, humerus_offset_in_shoulder_r)
        elbow_transform_r = torch.matmul(torch.matmul(humerus_transform_r, elbow_offset_in_humerus_r), elbow_coordinates_transform_r)
        ulna_transform_r = torch.matmul(elbow_transform_r, ulna_offset_in_elbow_r)
        radioulnar_transform_r = torch.matmul(torch.matmul(ulna_transform_r, radioulnar_offset_in_radius_r), radioulnar_coordinates_transform_r)
        radius_transform_r = torch.matmul(radioulnar_transform_r, radius_offset_in_radioulnar_r)
        # NOTE(review): the wrist chain is built from ulna_transform with a
        # radius-frame offset (not radius_transform) — confirm intended.
        wrist_transform_r = torch.matmul(torch.matmul(ulna_transform_r, wrist_offset_in_radius_r), wrist_coordinates_transform_r)
        hand_transform_r = torch.matmul(wrist_transform_r, hand_offset_in_wrist_r)
        shoulder_transform_l = torch.matmul(torch.matmul(torso_transform, shoulder_offset_in_torso_l), shoulder_coordinates_transform_l)
        humerus_transform_l = torch.matmul(shoulder_transform_l, humerus_offset_in_shoulder_l)
        elbow_transform_l = torch.matmul(torch.matmul(humerus_transform_l, elbow_offset_in_humerus_l), elbow_coordinates_transform_l)
        ulna_transform_l = torch.matmul(elbow_transform_l, ulna_offset_in_elbow_l)
        radioulnar_transform_l = torch.matmul(torch.matmul(ulna_transform_l, radioulnar_offset_in_radius_l), radioulnar_coordinates_transform_l)
        radius_transform_l = torch.matmul(radioulnar_transform_l, radius_offset_in_radioulnar_l)
        wrist_transform_l = torch.matmul(torch.matmul(ulna_transform_l, wrist_offset_in_radius_l), wrist_coordinates_transform_l)
        hand_transform_l = torch.matmul(wrist_transform_l, hand_offset_in_wrist_l)
    # Collect joint positions (translation column of each 4x4 transform).
    joint_locations = torch.stack((pelvis_transform[..., :3, 3],
                                   hip_transform_r[..., :3, 3], knee_transform_r[..., :3, 3],
                                   ankle_transform_r[..., :3, 3], calcaneus_transform_r[..., :3, 3],
                                   mtp_offset_transform_r[..., :3, 3],
                                   hip_transform_l[..., :3, 3], knee_transform_l[..., :3, 3],
                                   ankle_transform_l[..., :3, 3], calcaneus_transform_l[..., :3, 3],
                                   mtp_offset_transform_l[..., :3, 3]))
    if with_arm:
        joint_locations = torch.stack((*[joint_locations[i] for i in range(joint_locations.shape[0])],
                                       lumbar_transform[..., :3, 3],
                                       shoulder_transform_r[..., :3, 3],
                                       elbow_transform_r[..., :3, 3], wrist_transform_r[..., :3, 3],
                                       shoulder_transform_l[..., :3, 3],
                                       elbow_transform_l[..., :3, 3], wrist_transform_l[..., :3, 3]))
    if torch.isnan(joint_locations).any():
        print('NAN in joint locations')
    # Heel and toe locations for both feet (used for e.g. contact estimation).
    foot_locations = torch.stack((calcaneus_transform_r[..., :3, 3], mtp_offset_transform_r[..., :3, 3],
                                  calcaneus_transform_l[..., :3, 3], mtp_offset_transform_l[..., :3, 3]))
    # Rotation part of each body-segment transform.
    segment_orientations = torch.stack((pelvis_transform[..., :3, :3],
                                        femur_transform_r[..., :3, :3], tibia_transform_r[..., :3, :3],
                                        talus_transform_r[..., :3, :3], calcaneus_transform_r[..., :3, :3],
                                        femur_transform_l[..., :3, :3], tibia_transform_l[..., :3, :3],
                                        talus_transform_l[..., :3, :3], calcaneus_transform_l[..., :3, :3],
                                        torso_transform[..., :3, :3]))
    if with_arm:
        segment_orientations = torch.stack((*[segment_orientations[i] for i in range(segment_orientations.shape[0])],
                                            humerus_transform_r[..., :3, :3], ulna_transform_r[..., :3, :3],
                                            radius_transform_r[..., :3, :3],
                                            humerus_transform_l[..., :3, :3], ulna_transform_l[..., :3, :3],
                                            radius_transform_l[..., :3, :3]))
    joint_names = ['pelvis', 'hip_r', 'knee_r', 'ankle_r', 'calcn_r', 'mtp_r',
                   'hip_l', 'knee_l', 'ankle_l', 'calcn_l', 'mtp_l']
    return foot_locations, joint_locations, joint_names, segment_orientations
def get_model_offsets(skeleton, with_arm=False):
    """Collect the constant joint-offset homogeneous transforms of a nimble skeleton.

    For every joint two 4x4 transforms are needed: the joint frame expressed in
    the parent body frame, and the inverted joint frame expressed in the child
    body frame.  They are stacked along the last dimension in a fixed order:
    right leg (hip..mtp), left leg, lumbar/torso, then optionally both arm
    chains plus the two toe offsets.

    Args:
        skeleton: nimble skeleton whose body node 0 is the pelvis, with child
            joints ordered (0: hip_r, 1: hip_l, 2: lumbar) — see traversal below.
        with_arm: if True, also traverse shoulder/elbow/radioulnar/wrist chains
            and append their offsets (and the toe offsets) to the stack.

    Returns:
        Tensor of shape (4, 4, 20) without arms, or (4, 4, 38) with arms.
    """

    def _offset_in_parent(joint):
        # Joint frame expressed in the parent body frame as a homogeneous matrix.
        transform = torch.eye(4)
        transform[:3, :3] = torch.tensor(joint.getTransformFromParentBodyNode().rotation())
        transform[:3, 3] = torch.tensor(joint.getTransformFromParentBodyNode().translation())
        return transform

    def _offset_in_child(joint):
        # Inverse of the joint frame expressed in the child body frame:
        # R^-1 applied first, then the negated translation (R^-1 @ T(-p)),
        # matching the original per-joint rotation/translation matmul.
        rotation = torch.eye(4)
        rotation[:3, :3] = torch.inverse(torch.tensor(joint.getTransformFromChildBodyNode().rotation()))
        translation = torch.eye(4)
        translation[:3, 3] = -torch.tensor(joint.getTransformFromChildBodyNode().translation())
        return torch.matmul(rotation, translation)

    # Traverse the kinematic tree; child-joint indices encode the model layout.
    pelvis = skeleton.getBodyNode(0)
    hip_r_joint = pelvis.getChildJoint(0)
    femur_r = hip_r_joint.getChildBodyNode()
    knee_r_joint = femur_r.getChildJoint(0)
    tibia_r = knee_r_joint.getChildBodyNode()
    ankle_r_joint = tibia_r.getChildJoint(0)
    talus_r = ankle_r_joint.getChildBodyNode()
    subtalar_r_joint = talus_r.getChildJoint(0)
    calcn_r = subtalar_r_joint.getChildBodyNode()
    mtp_r_joint = calcn_r.getChildJoint(0)
    hip_l_joint = pelvis.getChildJoint(1)
    femur_l = hip_l_joint.getChildBodyNode()
    knee_l_joint = femur_l.getChildJoint(0)
    tibia_l = knee_l_joint.getChildBodyNode()
    ankle_l_joint = tibia_l.getChildJoint(0)
    talus_l = ankle_l_joint.getChildBodyNode()
    subtalar_l_joint = talus_l.getChildJoint(0)
    calcn_l = subtalar_l_joint.getChildBodyNode()
    mtp_l_joint = calcn_l.getChildJoint(0)
    lumbar_joint = pelvis.getChildJoint(2)
    torso = lumbar_joint.getChildBodyNode()

    # Same stacking order as the original: per joint, parent offset then child offset.
    offset_list = [
        _offset_in_parent(hip_r_joint), _offset_in_child(hip_r_joint),
        _offset_in_parent(knee_r_joint), _offset_in_child(knee_r_joint),
        _offset_in_parent(ankle_r_joint), _offset_in_child(ankle_r_joint),
        _offset_in_parent(subtalar_r_joint), _offset_in_child(subtalar_r_joint),
        _offset_in_parent(mtp_r_joint),
        _offset_in_parent(hip_l_joint), _offset_in_child(hip_l_joint),
        _offset_in_parent(knee_l_joint), _offset_in_child(knee_l_joint),
        _offset_in_parent(ankle_l_joint), _offset_in_child(ankle_l_joint),
        _offset_in_parent(subtalar_l_joint), _offset_in_child(subtalar_l_joint),
        _offset_in_parent(mtp_l_joint),
        _offset_in_parent(lumbar_joint), _offset_in_child(lumbar_joint),
    ]

    if with_arm:
        # torso -> (0: shoulder_r, 1: shoulder_l); each arm is a simple chain.
        shoulder_r_joint = torso.getChildJoint(0)
        humerus_r = shoulder_r_joint.getChildBodyNode()
        elbow_r_joint = humerus_r.getChildJoint(0)
        ulna_r = elbow_r_joint.getChildBodyNode()
        radioulnar_r_joint = ulna_r.getChildJoint(0)
        radius_r = radioulnar_r_joint.getChildBodyNode()
        wrist_r_joint = radius_r.getChildJoint(0)
        shoulder_l_joint = torso.getChildJoint(1)
        humerus_l = shoulder_l_joint.getChildBodyNode()
        elbow_l_joint = humerus_l.getChildJoint(0)
        ulna_l = elbow_l_joint.getChildBodyNode()
        radioulnar_l_joint = ulna_l.getChildJoint(0)
        radius_l = radioulnar_l_joint.getChildBodyNode()
        wrist_l_joint = radius_l.getChildJoint(0)
        # BUG FIX: the original re-stacked `offsets` by unpacking dim 0 of a
        # (4, 4, N) tensor (already stacked with dim=2), mixing (4, N) slices
        # with (4, 4) matrices — torch.stack requires equal shapes and raises.
        # Appending to the flat list before a single stack keeps shapes uniform.
        offset_list += [
            _offset_in_parent(shoulder_r_joint), _offset_in_child(shoulder_r_joint),
            _offset_in_parent(elbow_r_joint), _offset_in_child(elbow_r_joint),
            _offset_in_parent(radioulnar_r_joint), _offset_in_child(radioulnar_r_joint),
            _offset_in_parent(wrist_r_joint), _offset_in_child(wrist_r_joint),
            _offset_in_parent(shoulder_l_joint), _offset_in_child(shoulder_l_joint),
            _offset_in_parent(elbow_l_joint), _offset_in_child(elbow_l_joint),
            _offset_in_parent(radioulnar_l_joint), _offset_in_child(radioulnar_l_joint),
            _offset_in_parent(wrist_l_joint), _offset_in_child(wrist_l_joint),
            # toe offsets (child side of the mtp joints), appended last as in the original
            _offset_in_child(mtp_r_joint), _offset_in_child(mtp_l_joint),
        ]
    offsets = torch.stack(offset_list, dim=2)
    return offsets
def euler_to_angular_velocity(q, sampling_fre, convention="XYZ"):
    """Angular velocity from a time series of Euler angles.

    The rotation-matrix difference between consecutive frames is formed and
    the angular velocity is read off its skew-symmetric part, scaled by the
    sampling frequency.

    Args:
        q: (n_frames, 3) Euler angles.
        sampling_fre: sampling frequency in Hz.
        convention: Euler-angle convention string, e.g. "XYZ" or "ZXY".

    Returns:
        (n_frames, 3) angular-velocity tensor.
    """
    assert q.shape[-1] == 3
    rotations = euler_angles_to_matrix(q, convention)
    # Frame-to-frame matrix difference; the last frame copies the previous
    # one so the output keeps the input length.
    finite_diff = rotations.clone()
    finite_diff[:-1, :, :] = finite_diff[:-1, :, :] - finite_diff[1:, :, :]
    finite_diff[-1, :, :] = finite_diff[-2, :, :]
    omega_x = finite_diff[:, 2, 1] - finite_diff[:, 1, 2]
    omega_y = finite_diff[:, 0, 2] - finite_diff[:, 2, 0]
    omega_z = finite_diff[:, 1, 0] - finite_diff[:, 0, 1]
    return torch.stack([omega_x, omega_y, omega_z], dim=1) * sampling_fre * 0.5
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """Convert Euler angles to rotation matrices.

    Args:
        euler_angles: tensor whose last dimension holds the three angles.
        convention: three distinct letters from {"X", "Y", "Z"} giving the
            axis order, e.g. "ZXY".

    Returns:
        Rotation matrices of shape euler_angles.shape[:-1] + (3, 3).

    Raises:
        ValueError: on malformed input shape or convention string.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    # Accumulate the per-axis rotations left-to-right: ((R0 @ R1) @ R2).
    combined = None
    for axis, angle in zip(convention, torch.unbind(euler_angles, -1)):
        single = _axis_angle_rotation(axis, angle)
        combined = single if combined is None else torch.matmul(combined, single)
    return combined
| def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: | |
| cos = torch.cos(angle) | |
| sin = torch.sin(angle) | |
| one = torch.ones_like(angle) | |
| zero = torch.zeros_like(angle) | |
| if axis == "X": | |
| R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) | |
| elif axis == "Y": | |
| R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) | |
| elif axis == "Z": | |
| R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) | |
| else: | |
| raise ValueError("letter must be either X, Y or Z.") | |
| return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) | |
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """Convert rotation matrices to Euler angles under the given convention.

    Args:
        matrix: rotation matrices, shape (..., 3, 3).
        convention: three letters from {"X", "Y", "Z"}; middle letter must
            differ from its neighbours (e.g. "ZXY").

    Returns:
        Euler angles of shape (..., 3).

    Raises:
        ValueError: on a malformed convention string or matrix shape.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    first_idx = _index_from_letter(convention[0])
    last_idx = _index_from_letter(convention[2])
    is_tait_bryan = first_idx != last_idx
    if is_tait_bryan:
        # Sign depends on the cyclic order of the outer two axes.
        central_angle = torch.asin(
            matrix[..., first_idx, last_idx] * (-1.0 if first_idx - last_idx in [-1, 2] else 1.0)
        )
    else:
        central_angle = torch.acos(matrix[..., first_idx, first_idx])
    # Outer angles come from atan2 over a matrix column / row respectively.
    first_angle = _angle_from_tan(
        convention[0], convention[1], matrix[..., last_idx], False, is_tait_bryan
    )
    last_angle = _angle_from_tan(
        convention[2], convention[1], matrix[..., first_idx, :], True, is_tait_bryan
    )
    return torch.stack((first_angle, central_angle, last_angle), -1)
| def _angle_from_tan( | |
| axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool | |
| ) -> torch.Tensor: | |
| i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] | |
| if horizontal: | |
| i2, i1 = i1, i2 | |
| even = (axis + other_axis) in ["XY", "YZ", "ZX"] | |
| if horizontal == even: | |
| return torch.atan2(data[..., i1], data[..., i2]) | |
| if tait_bryan: | |
| return torch.atan2(-data[..., i2], data[..., i1]) | |
| return torch.atan2(data[..., i2], -data[..., i1]) | |
| def _index_from_letter(letter: str) -> int: | |
| if letter == "X": | |
| return 0 | |
| if letter == "Y": | |
| return 1 | |
| if letter == "Z": | |
| return 2 | |
| raise ValueError("letter must be either X, Y or Z.") | |
def get_multi_body_loc_using_nimble_by_body_names(body_names, skel, poses):
    """Resolve body names to skeleton nodes, then gather their world locations per frame."""
    nodes = [skel.getBodyNode(body_name) for body_name in body_names]
    return get_multi_body_loc_using_nimble_by_body_nodes(nodes, skel, poses)
def get_multi_body_loc_using_nimble_by_body_nodes(body_nodes, skel, poses):
    """World translations of the given body nodes for every pose frame.

    For each frame the skeleton is posed, then the (x, y, z) world translation
    of every node is concatenated into one row.

    Returns:
        Array of shape (n_frames, 3 * n_bodies).
    """
    rows = []
    for pose in poses:
        skel.setPositions(pose)
        translations = [node.getWorldTransform().translation() for node in body_nodes]
        rows.append(np.concatenate(translations))
    return np.array(rows)
def inverse_norm_cops(skel, states, opt, sub_mass, height_m, grf_thd_to_zero_cop=20):
    """Convert normalized centre-of-pressure (CoP) columns back to absolute positions.

    The stored CoP is expressed relative to the foot (calcaneus) location and
    normalized by the vertical force component and body height; this function
    inverts that normalization and writes the absolute CoPs back into
    ``states``.

    Args:
        skel: nimble skeleton used to compute the calcaneus world locations.
        states: (n_frames, n_cols) numpy array or torch tensor holding
            kinematics, GRFs, and normalized CoPs at the column indices in
            ``opt``.  Modified and returned.
        opt: options object providing ``kinematic_osim_col_loc``,
            ``grf_osim_col_loc``, and ``cop_osim_col_loc`` index lists.
        sub_mass: subject mass — multiplies the normalized vertical force to
            recover Newtons for the thresholding below.
        height_m: subject height in metres (de-normalizes the CoP offset).
        grf_thd_to_zero_cop: vertical-force threshold (N) below which the CoP
            is zeroed; falsy value disables zeroing.

    Returns:
        ``states`` with the CoP columns replaced by absolute positions.
    """
    poses = states[:, opt.kinematic_osim_col_loc]
    forces = states[:, opt.grf_osim_col_loc]
    normed_cops = states[:, opt.cop_osim_col_loc]
    if len(skel.getDofs()) != poses.shape[1]: # With Arm model
        # NOTE(review): the message says "6 zeros" but 10 zero columns are
        # appended — confirm the intended arm DoF count.
        print('With Arm model is used. Adding 6 zeros to the end of the poses.')
        poses = np.concatenate([poses, np.zeros((poses.shape[0], 10))], axis=-1)
    # World locations of both calcanei: columns 0-2 right foot, 3-5 left foot.
    foot_loc = get_multi_body_loc_using_nimble_by_body_names(('calcn_r', 'calcn_l'), skel, poses)
    for i_plate in range(2):
        # Per-plate force components; replacing exact zeros avoids division by
        # zero below.  NOTE(review): when slicing yields a view this also
        # writes 1e-6 back into ``states`` — presumably acceptable; confirm.
        force_v = forces[:, 3*i_plate:3*(i_plate+1)]
        force_v[force_v == 0] = 1e-6
        # Undo the normalization: scale by height and divide out the vertical
        # force component (index 1 = y).
        vector = normed_cops[:, 3 * i_plate:3 * (i_plate + 1)] / force_v[:, 1:2] * height_m
        vector = np.nan_to_num(vector, posinf=0, neginf=0)
        # vector.clip(min=-0.4, max=0.4, out=vector) # CoP should be within 0.4 m from the foot
        cops = vector + foot_loc[:, 3*i_plate:3*(i_plate+1)]
        if grf_thd_to_zero_cop:
            # Zero the CoP where vertical force (de-normalized by mass) is
            # below the threshold — CoP is meaningless during swing.
            cops[force_v[:, 1] * sub_mass < grf_thd_to_zero_cop] = 0
        # Match the container/dtype of ``states`` before writing back.
        if isinstance(states, torch.Tensor):
            cops = torch.from_numpy(cops).to(states.dtype)
        else:
            cops = cops.astype(states.dtype)
        states[:, opt.cop_osim_col_loc[3*i_plate:3*(i_plate+1)]] = cops
    return states
def inverse_align_moving_direction(results_pred, column_names, rot_mat):
    """Rotate model outputs back to the original heading (inverse of ``align_moving_direction``).

    Applies the transpose of ``rot_mat`` — the rotation previously used to
    align the moving direction — to the pelvis orientation and translation,
    both feet's ground-reaction-force vectors, and the normalized CoP columns.

    Args:
        results_pred: (n_frames, n_cols) numpy array or torch tensor.
        column_names: list mapping column index -> OpenSim column name.
        rot_mat: (3, 3) rotation returned by ``align_moving_direction``.

    Returns:
        Float tensor clone of the input with the relevant columns rotated back.

    Raises:
        ValueError: if the expected columns are missing.
    """
    # BUG FIX: the original assigned `poses` only inside the isinstance branch
    # but read it unconditionally afterwards, raising NameError whenever
    # ``results_pred`` was already a torch.Tensor.
    if isinstance(results_pred, np.ndarray):
        results_pred = torch.from_numpy(results_pred)
    results_pred_clone = results_pred.clone().float()
    pelvis_orientation_col_loc = [column_names.index(col) for col in JOINTS_3D_ALL['pelvis']]
    p_pos_col_loc = [column_names.index(col) for col in [f'pelvis_t{x}' for x in ['x', 'y', 'z']]]
    r_grf_col_loc = [column_names.index(col) for col in ['calcn_r_force_vx', 'calcn_r_force_vy', 'calcn_r_force_vz']]
    l_grf_col_loc = [column_names.index(col) for col in ['calcn_l_force_vx', 'calcn_l_force_vy', 'calcn_l_force_vz']]
    r_cop_col_loc = [column_names.index(col) for col in [f'calcn_r_force_normed_cop_{x}' for x in ['x', 'y', 'z']]]
    l_cop_col_loc = [column_names.index(col) for col in [f'calcn_l_force_normed_cop_{x}' for x in ['x', 'y', 'z']]]
    if len(pelvis_orientation_col_loc) != 3 or len(p_pos_col_loc) != 3 or len(r_grf_col_loc) != 3 or len(
            l_grf_col_loc) != 3:
        raise ValueError('check column names')
    pelvis_orientation = results_pred_clone[:, pelvis_orientation_col_loc]
    pelvis_orientation = euler_angles_to_matrix(pelvis_orientation, "ZXY")
    p_pos = results_pred_clone[:, p_pos_col_loc]
    r_grf = results_pred_clone[:, r_grf_col_loc]
    l_grf = results_pred_clone[:, l_grf_col_loc]
    r_cop = results_pred_clone[:, r_cop_col_loc]
    l_cop = results_pred_clone[:, l_cop_col_loc]
    # Invert the alignment rotation (orthonormal, so transpose == inverse).
    rot_mat = rot_mat.T
    pelvis_orientation_rotated = torch.matmul(rot_mat, pelvis_orientation)
    p_pos_rotated = torch.matmul(rot_mat, p_pos.unsqueeze(2)).squeeze(2)
    r_grf_rotated = torch.matmul(rot_mat, r_grf.unsqueeze(2)).squeeze(2)
    l_grf_rotated = torch.matmul(rot_mat, l_grf.unsqueeze(2)).squeeze(2)
    r_cop_rotated = torch.matmul(rot_mat, r_cop.unsqueeze(2)).squeeze(2)
    l_cop_rotated = torch.matmul(rot_mat, l_cop.unsqueeze(2)).squeeze(2)
    results_pred_clone[:, pelvis_orientation_col_loc] = matrix_to_euler_angles(pelvis_orientation_rotated.float(), "ZXY")
    results_pred_clone[:, p_pos_col_loc] = p_pos_rotated.float()
    results_pred_clone[:, r_grf_col_loc] = r_grf_rotated.float()
    results_pred_clone[:, l_grf_col_loc] = l_grf_rotated.float()
    results_pred_clone[:, r_cop_col_loc] = r_cop_rotated.float()
    results_pred_clone[:, l_cop_col_loc] = l_cop_rotated.float()
    return results_pred_clone
def convert_addb_state_to_model_input(pose_df, joints_3d, sampling_fre):
    """Convert an OpenSim-style state dataframe into the model's input features.

    Steps: zero the pelvis translation at the first frame, drop frozen DoFs,
    replace each 3-DoF joint's Euler-angle columns with a 6D rotation
    representation plus low-pass-filtered angular velocities, and append
    finite-difference velocity columns for the remaining kinematic DoFs.

    Args:
        pose_df: dataframe with OpenSim coordinate (and force) columns.
        joints_3d: dict mapping joint name -> its three Euler-angle column names.
        sampling_fre: sampling frequency in Hz.

    Returns:
        Tuple ``(pose_vel_df, pos_vec)``: the augmented dataframe and the
        first-frame pelvis translation ``[tx, ty, tz]`` (saved so the absolute
        position can be restored later).
    """
    # shift root position to start in (x,y) = (0,0)
    pos_vec = [pose_df['pelvis_tx'][0], pose_df['pelvis_ty'][0], pose_df['pelvis_tz'][0]]
    pose_df['pelvis_tx'] = pose_df['pelvis_tx'] - pose_df['pelvis_tx'][0]
    pose_df['pelvis_ty'] = pose_df['pelvis_ty'] - pose_df['pelvis_ty'][0]
    pose_df['pelvis_tz'] = pose_df['pelvis_tz'] - pose_df['pelvis_tz'][0]
    # remove frozen dof
    for frozen_col in FROZEN_DOFS:
        if frozen_col in pose_df.columns:
            pose_df = pose_df.drop(frozen_col, axis=1)
    # convert euler to 6v (continuous 6D rotation representation)
    for joint_name, joints_with_3_dof in joints_3d.items():
        joint_6v = euler_to_6v(torch.tensor(pose_df[joints_with_3_dof].values), "ZXY").numpy()
        for i in range(6):
            pose_df[joint_name + '_' + str(i)] = joint_6v[:, i]
    # replace the Euler columns with filtered angular velocities
    for joint_name, joints_with_3_dof in joints_3d.items():
        joint_angular_v = euler_to_angular_velocity(torch.tensor(pose_df[joints_with_3_dof].values), sampling_fre, "ZXY").numpy()
        joint_angular_v = data_filter(joint_angular_v, 15, sampling_fre, 4)
        for joints_euler_name in joints_with_3_dof:
            pose_df = pose_df.drop(joints_euler_name, axis=1)
        for i, axis in enumerate(['x', 'y', 'z']):
            pose_df[joint_name + '_' + axis + '_angular' + '_vel'] = joint_angular_v[:, i]
    # remaining kinematic columns (excluding forces, pelvis translation/rotation,
    # existing velocity columns, and the 6D components) get velocity features
    vel_col_loc = [i for i, col in enumerate(pose_df.columns) if not np.sum([term in col for term in ['force', 'pelvis_', '_vel', '_0', '_1', '_2', '_3', '_4', '_5']])]
    vel_col_names = [f'{col}_vel' for i, col in enumerate(pose_df.columns) if not np.sum([term in col for term in ['force', 'pelvis_', '_vel', '_0', '_1', '_2', '_3', '_4', '_5']])]
    kinematics_np = pose_df.iloc[:, vel_col_loc].to_numpy().copy()
    kinematics_np_filtered = data_filter(kinematics_np, 15, sampling_fre, 4)
    # first derivative via spline fitting, column by column
    kinematics_vel = np.stack([spline_fitting_1d(kinematics_np_filtered[:, i_col], range(kinematics_np_filtered.shape[0]), 1).ravel()
                               for i_col in range(kinematics_np_filtered.shape[1])]).T
    pose_vel_df = pd.DataFrame(np.concatenate([pose_df.values, kinematics_vel], axis=1), columns=list(pose_df.columns)+vel_col_names)
    return pose_vel_df, pos_vec
def inverse_convert_addb_state_to_model_input(
        model_states, model_states_column_names, treadmill_speed, joints_3d, osim_dof_columns, pos_vec, height_m, sampling_fre=100):
    """Convert model output features back into OpenSim coordinate states.

    Inverse of ``convert_addb_state_to_model_input``: integrates pelvis
    velocities into positions (rescaled by subject height, with treadmill
    speed removed from the forward axis), converts 6D rotation features back
    to Euler angles, re-inserts frozen DoFs as zeros, and restores the saved
    first-frame pelvis translation.

    Args:
        model_states: tensor (..., n_frames, n_features).
        model_states_column_names: feature names matching the last dimension.
        treadmill_speed: forward (tx) speed to subtract before integration.
        joints_3d: dict mapping joint name -> its three Euler-angle columns.
        osim_dof_columns: ordered OpenSim column names for the output.
        pos_vec: first-frame pelvis translation saved during conversion.
        height_m: per-sample subject height tensor used to de-normalize
            pelvis velocities.
        sampling_fre: sampling frequency in Hz used for the integration.

    Returns:
        Float tensor of OpenSim states ordered by ``osim_dof_columns``.
    """
    # start from the features that are already OpenSim columns
    model_states_dict = {col: model_states[..., i] for i, col in enumerate(model_states_column_names) if
                         col in osim_dof_columns}
    for i_col, col in enumerate(['pelvis_tx', 'pelvis_ty', 'pelvis_tz']):
        # de-normalize velocity by height, then integrate into position
        model_states_dict[col] = model_states_dict[col] * height_m.unsqueeze(-1).expand(model_states_dict[col].shape)
        if col == 'pelvis_tx':
            model_states_dict[col] -= treadmill_speed
        model_states_dict[col] = torch.cumsum(model_states_dict[col], dim=-1) / sampling_fre
    # convert 6v to euler
    for joint_name, joints_with_3_dof in joints_3d.items():
        joint_name_6v = [joint_name + '_' + str(i) for i in range(6)]
        index_ = [model_states_column_names.index(joint_name_6v[i]) for i in range(6)]
        joint_euler = euler_from_6v(model_states[..., index_], "ZXY")
        for i, joints_euler_name in enumerate(joints_with_3_dof):
            model_states_dict[joints_euler_name] = joint_euler[..., i]
    # add frozen dof back (zeros, same leading shape as the states)
    for frozen_col in FROZEN_DOFS:
        model_states_dict[frozen_col] = torch.zeros(model_states.shape[:len(model_states.shape)-1]).to(model_states.device)
    # restore the absolute start position saved during the forward conversion
    pos_vec = torch.tensor(pos_vec)
    pos_vec_torch = pos_vec.unsqueeze(-1).repeat(*[1 for _ in range(len(pos_vec.shape))], model_states.shape[-2]).to(model_states.device)
    for i_col, col in enumerate(['pelvis_tx', 'pelvis_ty', 'pelvis_tz']):
        model_states_dict[col] += pos_vec_torch[..., i_col, :]
    osim_states = torch.stack([model_states_dict[col] for col in osim_dof_columns], dim=len(model_states.shape)-1).float()
    return osim_states
def align_moving_direction(poses, column_names):
    """Rotate a trial about the vertical axis so its median heading is aligned.

    The per-frame heading is derived from the pelvis rotation matrix; the
    trial is rotated by the median heading angle so the subject walks along a
    consistent direction.  Pelvis orientation/translation, both GRF vectors,
    and both CoP columns are rotated together.

    Args:
        poses: (n_frames, n_cols) numpy array or torch tensor.
        column_names: list mapping column index -> OpenSim column name.

    Returns:
        ``(rotated_poses, rot_mat)`` on success; ``(False, None)`` when the
        heading varies by more than 45 degrees within the trial (direction
        too inconsistent to align with one rotation).

    Raises:
        ValueError: if the expected columns are missing.
    """
    if isinstance(poses, np.ndarray):
        poses = torch.from_numpy(poses)
    pose_clone = poses.clone().float()
    pelvis_orientation_col_loc = [column_names.index(col) for col in JOINTS_3D_ALL['pelvis']]
    p_pos_col_loc = [column_names.index(col) for col in [f'pelvis_t{x}' for x in ['x', 'y', 'z']]]
    r_grf_col_loc = [column_names.index(col) for col in ['calcn_r_force_vx', 'calcn_r_force_vy', 'calcn_r_force_vz']]
    l_grf_col_loc = [column_names.index(col) for col in ['calcn_l_force_vx', 'calcn_l_force_vy', 'calcn_l_force_vz']]
    r_cop_col_loc = [column_names.index(col) for col in [f'calcn_r_force_normed_cop_{x}' for x in ['x', 'y', 'z']]]
    l_cop_col_loc = [column_names.index(col) for col in [f'calcn_l_force_normed_cop_{x}' for x in ['x', 'y', 'z']]]
    if len(pelvis_orientation_col_loc) != 3 or len(p_pos_col_loc) != 3 or len(r_grf_col_loc) != 3 or len(
            l_grf_col_loc) != 3:
        raise ValueError('check column names')
    pelvis_orientation = pose_clone[:, pelvis_orientation_col_loc]
    pelvis_orientation = euler_angles_to_matrix(pelvis_orientation, "ZXY")
    p_pos = pose_clone[:, p_pos_col_loc]
    r_grf = pose_clone[:, r_grf_col_loc]
    l_grf = pose_clone[:, l_grf_col_loc]
    r_cop = pose_clone[:, r_cop_col_loc]
    l_cop = pose_clone[:, l_cop_col_loc]
    # heading angle about the vertical (y) axis from the pelvis rotation matrix
    # NOTE(review): mixes numpy ufuncs with tensor method calls
    # (`angles.median()`) — verify `angles` is a torch.Tensor here.
    angles = np.arctan2(- pelvis_orientation[:, 0, 2], pelvis_orientation[:, 2, 2])
    if np.rad2deg(angles.max() - angles.min()) > 45:
        return False, None
    angle = angles.median()
    # rotation about y by the median heading angle
    rot_mat = torch.tensor([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]]).float()
    pelvis_orientation_rotated = torch.matmul(rot_mat, pelvis_orientation)
    p_pos_rotated = torch.matmul(rot_mat, p_pos.unsqueeze(2)).squeeze(2)
    r_grf_rotated = torch.matmul(rot_mat, r_grf.unsqueeze(2)).squeeze(2)
    l_grf_rotated = torch.matmul(rot_mat, l_grf.unsqueeze(2)).squeeze(2)
    r_cop_rotated = torch.matmul(rot_mat, r_cop.unsqueeze(2)).squeeze(2)
    l_cop_rotated = torch.matmul(rot_mat, l_cop.unsqueeze(2)).squeeze(2)
    pose_clone[:, pelvis_orientation_col_loc] = matrix_to_euler_angles(pelvis_orientation_rotated.float(), "ZXY")
    pose_clone[:, p_pos_col_loc] = p_pos_rotated.float()
    pose_clone[:, r_grf_col_loc] = r_grf_rotated.float()
    pose_clone[:, l_grf_col_loc] = l_grf_rotated.float()
    pose_clone[:, r_cop_col_loc] = r_cop_rotated.float()
    pose_clone[:, l_cop_col_loc] = l_cop_rotated.float()
    return pose_clone, rot_mat
def data_filter(data, cut_off_fre, sampling_fre, filter_order=4):
    """Zero-phase low-pass Butterworth filter.

    2-D (or higher) input is filtered along axis 0 (time); 1-D input is
    filtered along its only axis.
    """
    normalized_cutoff = cut_off_fre / (sampling_fre / 2)
    b, a = butter(filter_order, normalized_cutoff, 'lowpass')
    filter_axis = -1 if len(data.shape) == 1 else 0
    return filtfilt(b, a, data, axis=filter_axis)
def fix_seed():
    """Reset the torch, random, and numpy RNGs to seed 0 for reproducibility."""
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(0)
def linear_resample_data(trial_data, original_fre, target_fre):
    """Linearly resample ``trial_data`` along axis 0 from ``original_fre`` to
    ``target_fre`` Hz, using a normalized [0, 1] time base."""
    u, du = np.linspace(0., 1., trial_data.shape[0], retstep=True)
    u_new = np.arange(0., 1., du * original_fre / target_fre)
    interpolator = interp1d(u, trial_data, axis=0)
    return interpolator(u_new)
def spline_fitting_1d(data_, step_to_resample, der=0):
    """Fit an exact (s=0) parametric spline to a 1-D series and evaluate it —
    or its ``der``-th derivative — at the sample positions
    ``step_to_resample`` (expressed in original sample indices)."""
    assert len(data_.shape) == 1
    row = data_.reshape(1, -1)
    tck, _ = interpo.splprep(row, u=range(row.shape[1]), s=0)
    evaluated = interpo.splev(step_to_resample, tck, der=der)
    return np.column_stack(evaluated)
def convert_overlapped_list_to_array(trial_len, win_list, s_, e_, fun=np.nanmedian, max_size=150):
    """Merge overlapping prediction windows into one (trial_len, n_feature) array.

    Window i covers frames s_[i]:e_[i]. Windows are stacked into a
    (max_size, trial_len, n_feature) nan-padded buffer — the slot index wraps
    modulo max_size, so with more than max_size overlapping windows the oldest
    slots are overwritten — and reduced frame-wise with ``fun``.

    Returns:
        (merged, std): the reduced array and the frame-wise nan-std, the
        latter usable as an uncertainty estimate.
    """
    n_feature = win_list[0].shape[1]
    stacked = np.full((max_size, trial_len, n_feature), np.nan)
    for slot, (win, start, end) in enumerate(zip(win_list, s_, e_)):
        stacked[slot % max_size, start:end] = win[:end - start]
    merged = fun(stacked, axis=0)
    std_val = np.nanstd(stacked, axis=0)
    return merged, std_val
def identity(t, *args, **kwargs):
    """Return ``t`` unchanged; any extra positional/keyword args are ignored
    (keeps the signature interchangeable with other transform callbacks)."""
    return t
def extract(a, t, x_shape):
    """Gather per-timestep coefficients from ``a`` at indices ``t`` and reshape
    them to (batch, 1, ..., 1) so they broadcast against a batch of shape
    ``x_shape``."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    """Build a diffusion noise (beta) schedule.

    Args:
        schedule: one of "linear", "cosine", "sqrt_linear", "sqrt".
        n_timestep: number of diffusion steps.
        linear_start, linear_end: endpoints for the linear-family schedules.
        cosine_s: small offset used by the cosine schedule.

    Returns:
        np.ndarray of length ``n_timestep`` with the per-step beta values.

    Raises:
        ValueError: for an unknown schedule name.
    """
    if schedule == "linear":
        # Linear in sqrt(beta) space, squared back to beta
        betas = (
            torch.linspace(
                linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Fix: clamp with torch instead of np.clip. np.clip on a torch tensor
        # may come back as an ndarray (which has no .numpy()) depending on the
        # numpy/torch version pairing, breaking the conversion below.
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
| """ ============================ End util.py ============================ """ | |
| """ ============================ Start args.py ============================ """ | |
def parse_opt():
    """Build the default option namespace used throughout the app.

    The parser is invoked with an empty argv, so every value below is the
    effective default; the gradio callbacks later overwrite the user-specific
    fields (height, weight, treadmill speed, file paths) in place.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_name", default="new_orientation_alignment_diffusion", help="save to project/name")
    parser.add_argument("--with_arm", type=bool, default=False, help="whether osim model has arm DoFs")
    parser.add_argument("--with_kinematics_vel", type=bool, default=True,
                        help="whether to include 1st derivative of kinematics")
    parser.add_argument("--epochs", type=int, default=7680)
    parser.add_argument("--target_sampling_rate", type=int, default=100)
    parser.add_argument("--window_len", type=int, default=150)
    # negative value means no guidance
    parser.add_argument("--guide_x_start_the_beginning_step", type=int, default=-10)
    parser.add_argument("--project", default="runs/train", help="project/name")
    parser.add_argument("--processed_data_dir", type=str, default="dataset_backups/", help="Dataset backup path")
    parser.add_argument("--feature_type", type=str, default="jukebox")
    parser.add_argument("--batch_size", type=int, default=256, help="batch size")
    parser.add_argument("--batch_size_inference", type=int, default=5, help="batch size during inference")
    parser.add_argument("--checkpoint", type=str, default="", help="trained checkpoint path (optional)")
    opt = parser.parse_args(args=[])
    set_no_arm_opt(opt)
    # All data and model files are expected next to the script
    cwd = os.getcwd()
    opt.subject_data_path = cwd
    opt.geometry_folder = cwd + '/Geometry/'
    opt.checkpoint_bl = cwd + '/GaitDynamicsRefinement.pt'
    return opt
def set_no_arm_opt(opt):
    """Configure ``opt`` for the armless OpenSim model: build the OpenSim and
    model-state column layouts, then the index lists used to slice them."""
    opt.with_arm = False
    opt.osim_dof_columns = copy.deepcopy(OSIM_DOF_ALL[:23] + KINETICS_ALL)
    # Keep only the 3-D joints present in the armless model
    opt.joints_3d = {name: dofs for name, dofs in JOINTS_3D_ALL.items()
                     if name in ['pelvis', 'hip_r', 'hip_l', 'lumbar']}
    opt.model_states_column_names = copy.deepcopy(MODEL_STATES_COLUMN_NAMES_NO_ARM)
    # Angular-velocity columns for each 3-D joint
    for joint_name, joints_with_3_dof in opt.joints_3d.items():
        opt.model_states_column_names = opt.model_states_column_names + [
            f'{joint_name}_{axis}_angular_vel' for axis in ['x', 'y', 'z']]
    if opt.with_kinematics_vel:
        # First-derivative columns for plain kinematic states only (skip
        # forces, pelvis columns, existing velocities, and 6v components)
        excluded_terms = ['force', 'pelvis_', '_vel', '_0', '_1', '_2', '_3', '_4', '_5']
        opt.model_states_column_names = opt.model_states_column_names + [
            f'{col}_vel' for col in opt.model_states_column_names
            if not sum([term in col for term in excluded_terms])]
    state_cols = opt.model_states_column_names
    opt.knee_diffusion_col_loc = [i for i, col in enumerate(state_cols) if 'knee' in col]
    opt.ankle_diffusion_col_loc = [i for i, col in enumerate(state_cols) if 'ankle' in col]
    opt.hip_diffusion_col_loc = [i for i, col in enumerate(state_cols) if 'hip' in col]
    opt.kinematic_diffusion_col_loc = [i for i, col in enumerate(state_cols) if 'force' not in col]
    opt.kinetic_diffusion_col_loc = [i for i, col in enumerate(state_cols)
                                     if i not in opt.kinematic_diffusion_col_loc]
    opt.grf_osim_col_loc = [i for i, col in enumerate(opt.osim_dof_columns)
                            if 'force' in col and '_cop_' not in col]
    opt.cop_osim_col_loc = [i for i, col in enumerate(opt.osim_dof_columns) if '_cop_' in col]
    opt.kinematic_osim_col_loc = [i for i, col in enumerate(opt.osim_dof_columns) if 'force' not in col]
| """ ============================ Start dataset.py ============================ """ | |
class MotionDataset(Dataset):
    """Loads OpenSim kinematics (.mot) trials, aligns and normalizes them, and
    serves fixed-length windows of model states for GaitDynamics inference.

    Populated by load_addb:
        trials: list of TrialData, one per successfully loaded file.
        file_names: file name of each loaded trial.
        rot_mat_trials: per-trial rotation applied by align_moving_direction.
        time_column: original time columns, kept for writing results back out.
    """
    def __init__(
        self,
        opt,
        normalizer: Any = None,
        max_trial_num=None,  # NOTE(review): currently unused in this class
        check_cop_to_calcn_distance=True,
    ):
        self.data_path = opt.subject_data_path
        self.subject_osim_model = opt.subject_osim_model
        if opt.target_sampling_rate != 100:
            raise ValueError('100 Hz sampling rate is not confirmed. Confirm by setting opt.target_sampling_rate = 100')
        self.target_sampling_rate = opt.target_sampling_rate
        self.window_len = opt.window_len
        self.opt = opt
        self.check_cop_to_calcn_distance = check_cop_to_calcn_distance
        self.skel = None  # nimble skeleton, set by load_addb
        self.dset_set = set()
        print("Loading dataset...")
        self.load_addb(opt)
        # Replace pelvis translation columns with height-normalized velocity
        self.guess_vel_and_replace_txtytz()
        if not len(self.trials):
            print("No trials loaded")
            # NOTE(review): self.normalizer is left unset on this early return
            return
        self.normalizer = normalizer
        # Normalize every trial in place with the externally fitted normalizer
        for i_trial in range(len(self.trials)):
            self.trials[i_trial].converted_pose = self.normalizer.normalize(self.trials[i_trial].converted_pose).clone().detach().float()
    def customized_param_manipulation(self, states_df):
        # Hook for guided diffusion experiments; identity by default.
        # for guided diffusion
        return states_df
    def guess_vel_and_replace_txtytz(self):
        """Overwrite the pelvis_tx/ty/tz state columns with low-pass-filtered
        pelvis velocity divided by subject height; treadmill speed is added to
        the forward (x) component."""
        pelvis_pos_loc = [self.opt.model_states_column_names.index(col) for col in [f'pelvis_t{x}' for x in ['x', 'y', 'z']]]
        for i_trial, trial in enumerate(self.trials):
            # assumes the first three state columns are the pelvis translations — TODO confirm
            body_center = trial.converted_pose[:, 0:3]
            body_center = data_filter(body_center, 10, self.target_sampling_rate).astype(np.float32)
            # Forward difference; last sample is duplicated to keep the length
            vel_from_t = np.diff(body_center, axis=0) * self.target_sampling_rate
            vel_from_t = np.concatenate([vel_from_t, vel_from_t[-1][None, :]], axis=0)
            if self.opt.treadmill_speed is None:
                raise ValueError('Treadmill speed is not set. For overground walking, set opt.treadmill_speed = 0')
            vel_from_t[:, 0] = vel_from_t[:, 0] + self.opt.treadmill_speed
            walking_vel = vel_from_t
            self.trials[i_trial].converted_pose[:, pelvis_pos_loc] = torch.from_numpy(walking_vel) / self.trials[i_trial].height_m
    def __len__(self):
        # Dataset length is virtual; windows are sampled rather than enumerated.
        return self.opt.pseudo_dataset_len
    def get_overlapping_wins(self, col_loc_to_unmask, step_len, start_trial=0, end_trial=None, including_shorter_than_window_len=False):
        """Cut each trial in [start_trial, end_trial) into fixed-length windows.

        Args:
            col_loc_to_unmask: state-column indices treated as observed
                (mask = 1); columns missing from a given trial are re-masked.
            step_len: stride between window starts, in frames.
            start_trial, end_trial: half-open range of trial indices
                (end_trial defaults to all trials).
            including_shorter_than_window_len: also emit trailing windows
                shorter than window_len (zero-padded at the end).

        Returns:
            (windows, s_list, e_list): WindowData objects plus each window's
            start/end frame within its trial.
        """
        if end_trial is None:
            end_trial = len(self.trials)
        windows, s_list, e_list = [], [], []
        for i_trial in range(start_trial, end_trial):
            trial_ = self.trials[i_trial]
            col_loc_to_unmask_trial = copy.deepcopy(col_loc_to_unmask)
            for col in trial_.missing_col:
                # NOTE(review): uses the module-level `opt`, not self.opt —
                # works only because a global `opt` exists at runtime.
                col_index = opt.model_states_column_names.index(col)
                col_loc_to_unmask_trial.remove(col_index)
            trial_len = trial_.converted_pose.shape[0]
            if including_shorter_than_window_len:
                e_of_trial = trial_len
            else:
                e_of_trial = trial_len - self.opt.window_len + step_len
            for i in range(0, e_of_trial, step_len):
                s = max(0, i)
                e = min(trial_len, i + self.opt.window_len)
                s_list.append(s)
                e_list.append(e)
                # mask: 1 = observed entries; 0 = to be generated or padding
                mask = torch.zeros([self.opt.window_len, len(self.opt.model_states_column_names)])
                mask[:, col_loc_to_unmask_trial] = 1
                mask[e-s:, :] = 0
                data_ = torch.zeros([self.opt.window_len, len(self.opt.model_states_column_names)])
                data_[:e-s] = trial_.converted_pose[s:e, ...]
                windows.append(WindowData(data_, trial_.model_offsets, i_trial, None, mask,
                                          trial_.height_m, trial_.weight_kg, trial_.missing_col))
        return windows, s_list, e_list
    def load_addb(self, opt):
        """Parse every file in opt.file_paths into a TrialData.

        Per file: scan the OpenSim-style header for the angle unit, read the
        table, zero-fill missing coordinates (recorded in missing_col),
        convert angles to radians, append 12 zero GRF placeholder columns,
        align the moving direction, resample to the target rate, and convert
        to the model-state representation.
        """
        file_paths = opt.file_paths
        customOsim: nimble.biomechanics.OpenSimFile = nimble.biomechanics.OpenSimParser.parseOsim(self.subject_osim_model, self.opt.geometry_folder)
        skel = customOsim.skeleton
        self.skel = skel
        self.trials, self.file_names, self.rot_mat_trials, self.time_column = [], [], [], []
        for i_file, file_path in enumerate(file_paths):
            model_offsets = get_model_offsets(skel).float()
            with open(file_path) as f:
                # Scan the free-form header for the angle unit and end marker
                line_num = 0
                endheader_line, angle_scale = None, None
                while True:
                    header = f.readline()
                    line_num += 1
                    if 'endheader' in header:
                        endheader_line = line_num
                        break
                    if 'inDegrees' in header and angle_scale is None:
                        if 'yes' in header and 'no' not in header:
                            angle_scale = np.pi / 180
                        elif 'no' in header and 'yes' not in header:
                            angle_scale = 1
                        else:
                            raise ValueError('No inDegrees keyword in the header, cannot determine the unit of angles. '
                                             'Here is an example header: \nCoordinates\nversion=1\nnRows=1380'
                                             '\nnColumns=26\ninDegrees=yes\n')
            if endheader_line is None:
                raise ValueError(f'No \'endheader\' line found in the header of {file_path}.')
            # NOTE(review): angle_scale may still be None here if no inDegrees
            # line preceded endheader; the multiplication below would then fail.
            poses_df = pd.read_csv(file_path, sep='\t', skiprows=endheader_line)
            if 'time' not in poses_df.columns:
                raise ValueError(f'{file_path} does not have time column. Necessary for compuing sampling rate')
            sampling_rate = round((poses_df.shape[0] - 1) / (poses_df['time'].iloc[-1] - poses_df['time'].iloc[0]))
            self.time_column.append(poses_df['time'])
            # Zero-fill any 1-D coordinates missing from the file; non-frozen
            # missing DoFs are recorded so they can be diffused in later.
            missing_col = []
            for col in JOINTS_1D_ALL:
                if col not in poses_df.columns:
                    poses_df[col] = 0.
                    if col not in FROZEN_DOFS:
                        missing_col.append(col)
                        missing_col.append(col + '_vel')
            # Zero-fill 3-D joints with any missing component (all three DoFs
            # are zeroed so the 6v representation stays consistent)
            for key_, value_ in JOINTS_3D_ALL.items():
                if not all([col in poses_df.columns for col in value_]):
                    columns_not_in_poses = [col for col in value_ if col not in poses_df.columns]
                    if len(columns_not_in_poses) < 3:
                        print(f'Warning {key_} is a 3-D joint. Since {columns_not_in_poses.__str__()[1:-1]} is missing, '
                              f'the other DoFs ({[dof for dof in value_ if dof not in columns_not_in_poses].__str__()[1:-1]}) are filled with zeros too.')
                    missing_col.extend([f'{key_}_{n}' for n in range(6)])
                    missing_col.extend([f'{key_}_{axis}_angular_vel' for axis in ['x', 'y', 'z']])
                    poses_df[value_] = 0.
            poses_df = poses_df.astype(float)
            col_list = list(poses_df.columns)
            col_loc = [col_list.index(col) for col in OSIM_DOF_ALL[:23]]
            # Convert angle columns (not pelvis translations) with angle_scale
            angle_col_loc = [col_list.index(col) for col in OSIM_DOF_ALL[:23] if col not in ['pelvis_tx', 'pelvis_ty', 'pelvis_tz']]
            poses_df.iloc[:, angle_col_loc] = poses_df.iloc[:, angle_col_loc] * angle_scale
            poses = poses_df.values[:, col_loc]
            all_zeros = np.zeros([poses_df.shape[0], 12])  # used to fill in the missing GRF in normalization
            states = np.concatenate([np.array(poses), all_zeros], axis=1)
            # Flag implausible trials; see is_lumbar_rotation_reasonable below
            if not self.is_lumbar_rotation_reasonable(np.array(states), opt.osim_dof_columns):
                # NOTE(review): Warning(...) only constructs an exception
                # object — nothing is raised, logged, or shown here.
                Warning(f'Warning: {file_path} has unreasonable lumbar rotation, don\'t trust this trial.')
            states_aligned, rot_mat = align_moving_direction(states, opt.osim_dof_columns)
            if states_aligned is False:
                # Keep the unaligned states with an identity rotation
                print(f'Warning: {file_path} Pelvis orientation changed by more than 45 deg, don\'t trust this trial')
                rot_mat = torch.eye(3).float()
            else:
                states = states_aligned
            # NOTE(review): rot_mat_trials (here) and time_column (above) are
            # appended before the skip below; a skipped trial leaves them one
            # entry longer than self.trials, misaligning per-trial indexing.
            self.rot_mat_trials.append(rot_mat)
            file_name = file_path.split('/')[-1]
            if states.shape[0] / sampling_rate * self.target_sampling_rate < self.window_len + 2:
                print(f'Warning: {file_name} is shorter than 1.5s, skipping.')
                continue
            if sampling_rate != self.target_sampling_rate:
                states = linear_resample_data(states, sampling_rate, self.target_sampling_rate)
            states_df = pd.DataFrame(states, columns=opt.osim_dof_columns)
            states_df = self.customized_param_manipulation(states_df)
            states_df, pos_vec = convert_addb_state_to_model_input(states_df, opt.joints_3d, self.target_sampling_rate)
            assert self.opt.model_states_column_names == list(states_df.columns)
            converted_states = torch.tensor(states_df.values).float()
            trial_data = TrialData(converted_states, model_offsets, opt.height_m, opt.weight_kg, rot_mat, pos_vec,
                                   self.window_len, missing_col)
            self.trials.append(trial_data)
            self.file_names.append(file_name)
| def is_lumbar_rotation_reasonable(states, column_names): | |
| lumbar_rotation_col_loc = column_names.index('lumbar_rotation') | |
| if np.abs(np.mean(states[:, lumbar_rotation_col_loc])) > np.deg2rad(45): | |
| return False | |
| else: | |
| return True | |
class WindowData:
    """Plain container for one fixed-length window of model states plus the
    metadata (mask, subject anthropometrics, provenance) needed at inference."""
    def __init__(self, pose, model_offsets, trial_id, gait_phase_label, mask, height_m, weight_kg, missing_col):
        # Stash everything verbatim; this class adds no behavior.
        self.pose, self.mask = pose, mask
        self.model_offsets = model_offsets
        self.trial_id = trial_id
        self.gait_phase_label = gait_phase_label
        self.height_m, self.weight_kg = height_m, weight_kg
        self.missing_col = missing_col
class TrialData:
    """Container for one preprocessed trial plus the metadata needed to undo
    the preprocessing (direction rotation, position alignment) at export time."""
    def __init__(self, converted_states, model_offsets, height_m, weight_kg, rot_mat_for_moving_direction_alignment,
                 pos_vec_for_pos_alignment, window_len, missing_col):
        self.converted_pose = converted_states
        self.length = converted_states.shape[0]  # number of frames
        self.model_offsets = model_offsets
        self.height_m, self.weight_kg = height_m, weight_kg
        self.rot_mat_for_moving_direction_alignment = rot_mat_for_moving_direction_alignment
        self.pos_vec_for_pos_alignment = pos_vec_for_pos_alignment
        self.window_len = window_len
        self.missing_col = missing_col
| """ ============================ End dataset.py ============================ """ | |
def predict_grf_and_missing_kinematics():
    """Run GaitDynamics on every loaded trial: fill missing kinematics with the
    diffusion model when needed, predict ground reaction forces with the
    refinement model, and write .mot result files next to the inputs.

    Uses the module-level ``opt`` namespace for all configuration.
    """
    refinement_model = BaselineModel(opt, TransformerEncoderArchitecture)
    dataset = MotionDataset(opt, normalizer=refinement_model.normalizer)
    diffusion_model_for_filling = None  # lazily loaded only if kinematics are missing
    filling_method = DiffusionFilling()
    for i_trial in range(len(dataset.trials)):
        # 150-frame windows with stride 150 (no overlap) over this one trial
        windows, s_list, e_list = dataset.get_overlapping_wins(opt.kinematic_diffusion_col_loc, 150, i_trial, i_trial+1)
        if len(windows) == 0:
            continue
        if len(windows[0].missing_col) > 0:
            print(f'File {dataset.file_names[i_trial]} do not have {windows[0].missing_col}. '
                  f'\nGenerating missing kinematics for {dataset.file_names[i_trial]}')
            if diffusion_model_for_filling is None:
                diffusion_model_for_filling, _ = load_diffusion_model(opt)
            windows_reconstructed = filling_method.fill_param(windows, diffusion_model_for_filling)
        else:
            windows_reconstructed = windows
        state_pred_list = []
        print(f'Running GaitDynamics on file {dataset.file_names[i_trial]} for external force prediction.')
        # Batched refinement-model inference over the windows
        for i_win in range(0, len(windows), opt.batch_size_inference):
            state_true = torch.stack([win.pose for win in windows_reconstructed[i_win:i_win+opt.batch_size_inference]])
            masks = torch.stack([win.mask for win in windows_reconstructed[i_win:i_win+opt.batch_size_inference]])
            state_pred_list_batch = refinement_model.eval_loop(opt, state_true, masks, num_of_generation_per_window=1)[0]
            state_pred_list.append(state_pred_list_batch)
        state_pred = torch.cat(state_pred_list, dim=0)
        trial_len = dataset.trials[i_trial].converted_pose.shape[0]
        # Merge the per-window predictions into one trial-long array
        results_pred, _ = convert_overlapped_list_to_array(
            trial_len, state_pred, s_list, e_list)
        height_m_tensor = torch.tensor([windows[0].height_m])
        # Undo preprocessing: model states -> OpenSim states, CoP
        # de-normalization, then the moving-direction rotation
        results_pred = inverse_convert_addb_state_to_model_input(
            torch.from_numpy(results_pred).unsqueeze(0), opt.model_states_column_names, opt.treadmill_speed,
            opt.joints_3d, opt.osim_dof_columns, dataset.trials[i_trial].pos_vec_for_pos_alignment, height_m_tensor)[0].numpy()
        results_pred = inverse_norm_cops(dataset.skel, results_pred, opt, windows[0].weight_kg, windows[0].height_m)
        results_pred = inverse_align_moving_direction(results_pred, opt.osim_dof_columns, dataset.rot_mat_trials[i_trial])
        # De-normalize force magnitudes by body weight; presumably the last 12
        # columns are [r_force(3), r_cop(3), l_force(3), l_cop(3)] and only
        # the force triplets scale — confirm against KINETICS_ALL ordering.
        results_pred[:, -12:-9] = results_pred[:, -12:-9] * opt.weight_kg  # convert to N
        results_pred[:, -6:-3] = results_pred[:, -6:-3] * opt.weight_kg  # convert to N
        df = pd.DataFrame(results_pred, columns=opt.osim_dof_columns)
        trial_save_path = f'{dataset.file_names[i_trial][:-4]}_grf_pred___.mot'
        convertDfToGRFMot(df, trial_save_path, round(1 / opt.target_sampling_rate, 3), dataset.time_column[i_trial])
        if len(windows[0].missing_col) > 0:
            # Also export the reconstructed kinematics when some were missing
            trc_save_path = f'{dataset.file_names[i_trial][:-4]}_missing_kinematics_pred___.mot'
            convertDataframeToMot(df[OSIM_DOF_ALL[:23]], trc_save_path, round(1 / opt.target_sampling_rate, 3), dataset.time_column[i_trial])
    print('You can now download files from the default folder.')
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import io | |
| import shutil | |
# Build the global option namespace once at import time; the gradio handlers
# below mutate it in place (height, weight, file paths) before each run.
opt = parse_opt()
def gradio_predict(mot_file, osim_file, height_m, weight_kg, treadmill_speed, progress_callback=None):
    """Gradio interface function for GRF prediction.

    Always returns a 3-tuple (status message, GRF file path or None,
    kinematics file path or None). Fix: the original returned a bare string
    from the validation branches while the success path returned a 3-tuple —
    an inconsistent output arity for a 3-output gradio wiring.
    """
    # Validate inputs; every early exit matches the success path's shape
    if mot_file is None:
        return "Please upload a .mot file", None, None
    if osim_file is None:
        return "Please upload a .osim file", None, None
    if not (1.0 <= height_m <= 2.5):
        return "Height must be between 1.0 and 2.5 meters", None, None
    if not (30 <= weight_kg <= 200):
        return "Weight must be between 30 and 200 kg", None, None
    if not (0 <= treadmill_speed <= 10):
        return "Treadmill speed must be between 0 and 10 m/s", None, None
    try:
        # Set user inputs on the module-level opt namespace
        opt.height_m = height_m
        opt.weight_kg = weight_kg
        opt.treadmill_speed = treadmill_speed
        # Copy uploaded files into the working directory the loader expects
        mot_path = os.path.join(opt.subject_data_path, os.path.basename(mot_file.name))
        osim_path = os.path.join(opt.subject_data_path, os.path.basename(osim_file.name))
        shutil.copy2(mot_file.name, mot_path)
        shutil.copy2(osim_file.name, osim_path)
        opt.file_paths = [mot_path]
        opt.subject_osim_model = osim_path
        # Run prediction with updated opt
        output_files = predict_grf_and_missing_kinematics_with_opt(progress_callback=progress_callback)
        if output_files:
            return ("Prediction completed! Download files below.", output_files[0],
                    output_files[1] if len(output_files) > 1 else None)
        return "Prediction completed but no output files generated.", None, None
    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None
def predict_grf_and_missing_kinematics_with_opt(progress_callback=None):
    """Modified prediction function that accepts opt as parameter.

    Same pipeline as predict_grf_and_missing_kinematics, but forwards
    ``progress_callback`` to the diffusion filling step and returns the list
    of written result-file paths for the gradio download buttons.
    """
    refinement_model = BaselineModel(opt, TransformerEncoderArchitecture)
    dataset = MotionDataset(opt, normalizer=refinement_model.normalizer)
    diffusion_model_for_filling = None  # lazily loaded only if kinematics are missing
    filling_method = DiffusionFilling()
    output_files = []
    for i_trial in range(len(dataset.trials)):
        # 150-frame windows with stride 150 (no overlap) over this one trial
        windows, s_list, e_list = dataset.get_overlapping_wins(opt.kinematic_diffusion_col_loc, 150, i_trial, i_trial+1)
        if len(windows) == 0:
            continue
        if len(windows[0].missing_col) > 0:
            print(f'File {dataset.file_names[i_trial]} do not have {windows[0].missing_col}. '
                  f'\nGenerating missing kinematics for {dataset.file_names[i_trial]}')
            if diffusion_model_for_filling is None:
                diffusion_model_for_filling, _ = load_diffusion_model(opt)
            windows_reconstructed = filling_method.fill_param(windows, diffusion_model_for_filling, progress_callback=progress_callback)
        else:
            windows_reconstructed = windows
        state_pred_list = []
        print(f'Running GaitDynamics on file {dataset.file_names[i_trial]} for external force prediction.')
        # Batched refinement-model inference over the windows
        for i_win in range(0, len(windows), opt.batch_size_inference):
            state_true = torch.stack([win.pose for win in windows_reconstructed[i_win:i_win+opt.batch_size_inference]])
            masks = torch.stack([win.mask for win in windows_reconstructed[i_win:i_win+opt.batch_size_inference]])
            state_pred_list_batch = refinement_model.eval_loop(opt, state_true, masks, num_of_generation_per_window=1)[0]
            state_pred_list.append(state_pred_list_batch)
        state_pred = torch.cat(state_pred_list, dim=0)
        trial_len = dataset.trials[i_trial].converted_pose.shape[0]
        # Merge the per-window predictions into one trial-long array
        results_pred, _ = convert_overlapped_list_to_array(
            trial_len, state_pred, s_list, e_list)
        height_m_tensor = torch.tensor([windows[0].height_m])
        # Undo preprocessing: model states -> OpenSim states, CoP
        # de-normalization, then the moving-direction rotation
        results_pred = inverse_convert_addb_state_to_model_input(
            torch.from_numpy(results_pred).unsqueeze(0), opt.model_states_column_names, opt.treadmill_speed,
            opt.joints_3d, opt.osim_dof_columns, dataset.trials[i_trial].pos_vec_for_pos_alignment, height_m_tensor)[0].numpy()
        results_pred = inverse_norm_cops(dataset.skel, results_pred, opt, windows[0].weight_kg, windows[0].height_m)
        results_pred = inverse_align_moving_direction(results_pred, opt.osim_dof_columns, dataset.rot_mat_trials[i_trial])
        # De-normalize force magnitudes by body weight; presumably the last 12
        # columns are [r_force(3), r_cop(3), l_force(3), l_cop(3)] and only
        # the force triplets scale — confirm against KINETICS_ALL ordering.
        results_pred[:, -12:-9] = results_pred[:, -12:-9] * opt.weight_kg  # convert to N
        results_pred[:, -6:-3] = results_pred[:, -6:-3] * opt.weight_kg  # convert to N
        df = pd.DataFrame(results_pred, columns=opt.osim_dof_columns)
        trial_save_path = f'{dataset.file_names[i_trial][:-4]}_grf_pred___.mot'
        convertDfToGRFMot(df, trial_save_path, round(1 / opt.target_sampling_rate, 3), dataset.time_column[i_trial])
        output_files.append(trial_save_path)
        if len(windows[0].missing_col) > 0:
            # Also export the reconstructed kinematics when some were missing
            trc_save_path = f'{dataset.file_names[i_trial][:-4]}_missing_kinematics_pred___.mot'
            convertDataframeToMot(df[OSIM_DOF_ALL[:23]], trc_save_path, round(1 / opt.target_sampling_rate, 3), dataset.time_column[i_trial])
            output_files.append(trc_save_path)
    print('You can now download files from the default folder.')
    return output_files
| def _extract_path(file_obj): | |
| try: | |
| if file_obj is None: | |
| return None | |
| if isinstance(file_obj, str): | |
| return file_obj | |
| if isinstance(file_obj, dict) and 'name' in file_obj: | |
| return file_obj['name'] | |
| if hasattr(file_obj, 'name'): | |
| return file_obj.name | |
| except Exception: | |
| pass | |
| return None | |
def read_mot_to_dataframe(file_path):
    """Parse a .mot result file into a DataFrame, skipping the free-form
    header; the table is assumed to start at the first line beginning with
    'time'. Returns None for a missing path or an unrecognized layout."""
    if file_path is None:
        return None
    with open(file_path, 'r') as fh:
        all_lines = fh.readlines()
    table_start = next(
        (i for i, line in enumerate(all_lines) if line.strip().startswith('time')), None)
    if table_start is None:
        return None
    payload = ''.join(all_lines[table_start:])
    try:
        return pd.read_csv(io.StringIO(payload), sep='\t')
    except Exception:
        # Fall back to whitespace-delimited parsing
        return pd.read_csv(io.StringIO(payload), delim_whitespace=True)
def list_mot_columns(file_choice, grf_file, kin_file):
    """Populate the column dropdown from whichever result file is selected,
    hiding the time axis and torque channels."""
    source = grf_file if file_choice == 'GRF Results' else kin_file
    df = read_mot_to_dataframe(_extract_path(source))
    if df is None or 'time' not in df.columns:
        return gr.update(choices=[], value=None)
    cols = [c for c in df.columns if c != 'time' and 'torque' not in c.lower()]
    default_choice = cols[0] if len(cols) > 0 else None
    return gr.update(choices=cols, value=default_choice)
def plot_mot_signal(file_choice, column, grf_file, kin_file):
    """Return a matplotlib figure of ``column`` vs time from the selected
    result file, or a placeholder figure when the selection is unusable."""
    source = grf_file if file_choice == 'GRF Results' else kin_file
    df = read_mot_to_dataframe(_extract_path(source))
    have_data = (df is not None and 'time' in df.columns
                 and column is not None and column in df.columns)
    fig = plt.figure(figsize=(10, 5))
    if not have_data:
        plt.title('No data to display')
        plt.tight_layout()
        return fig
    ax = fig.add_subplot(111)
    ax.plot(df['time'], df[column])
    ax.set_xlabel('time')
    ax.set_ylabel(column)
    ax.set_title(f'{column} over time')
    ax.grid(True)
    fig.tight_layout()
    return fig
| # Create Gradio Blocks interface with plotting panel | |
| with gr.Blocks(css=""" | |
| .download-file .file-preview { | |
| background-color: #e8f5e8 !important; | |
| border-left: 4px solid #4caf50 !important; | |
| padding: 6px 10px !important; | |
| font-weight: bold !important; | |
| color: #2e7d32 !important; | |
| } | |
| #time_series_plot {height: 500px !important;} | |
| """) as demo: | |
| gr.HTML( | |
| """ | |
| <h1>GaitDynamics - Ground Reaction Force and Kinematics Prediction</h1> | |
| <h3>General instructions</h3> | |
| <p>This code is for ground reaction force and missing kinematics prediction using flexible combinations of OpenSim joint angles. The joint angles should meet the following criteria:</p> | |
| <div style="padding-left: 2em; text-indent: -1.0em;">1. Use OpenSim <a href="https://simtk.org/projects/full_body">Rajagopal Model without Arms</a>.</div> | |
| <div style="padding-left: 2em; text-indent: -1.0em;">2. Resample to 100 Hz if you have a different sampling rate.</div> | |
| <p>To use this code:</p> | |
| <div style="padding-left: 2em; text-indent: -1.0em;">1. Upload OpenSim model (.osim) and kinematics (.mot) file. If kinematics are provided for all coordinates, the model will only predict ground reaction forces. If kinematics are missing, the model will also predict missing kinematics.</div> | |
| <div style="padding-left: 2em; text-indent: -1.0em;">2. Enter the height and weight of the participant, and the treadmill speed (if applicable).</div> | |
| <div style="padding-left: 2em; text-indent: -1.0em;">3. Click the "Submit " button.</div> | |
| <h3>Processing example data</h3> | |
| <div style="padding-left: 2em; text-indent: -1.0em;">1. Download the example files:</div> | |
| """ | |
| ) | |
| with gr.Column(): | |
| with gr.Row(): | |
    # --- Example files: downloadable sample inputs for trying the demo ---
    gr.DownloadButton("a. OpenSim model file (.osim)", value="example_opensim_model.osim", size="sm", scale=1)
    gr.DownloadButton("b. complete kinematics file with all coordinates (.mot)", value="example_mot_complete_kinematics.mot", size="sm", scale=1)
    gr.DownloadButton("c. incomplete kinematics file with missing knee coordinate kinematics (.mot)", value="example_mot_missing_knee_kinematics.mot", size="sm", scale=1)
    # Step-by-step usage instructions rendered as static HTML.
    gr.HTML(
        """
        <div style="padding-left: 2em; text-indent: -1.0em;">2. Upload data for either of the two example .mot files.</div>
        <div style="padding-left: 3.5em; text-indent: -1.0em;">a. Predict ground reaction forces with a complete kinematics file by uploading the model file (.osim) and complete kinematics file (.mot).</div>
        <div style="padding-left: 3.5em; text-indent: -1.0em;">b. Predict ground reaction forces with an incomplete kinematics file by uploading the model file (.osim) and the incomplete kinematics file (.mot).</div>
        <div style="padding-left: 2em; text-indent: -1.0em;">3. Update the following input parameters.</div>
        <div style="padding-left: 3.5em; text-indent: -1.0em;">a. Height = 1.83 m</div>
        <div style="padding-left: 3.5em; text-indent: -1.0em;">b. Weight = 71.4 kg</div>
        <div style="padding-left: 3.5em; text-indent: -1.0em;">c. Treadmill speed = 1.15 m/s</div>
        <div style="padding-left: 2em; text-indent: -1.0em;">4. Click "Submit".</div>
        <div style="padding-left: 3.5em; text-indent: -1.0em;">a. The example with complete kinematics should take a few seconds to complete.</div>
        <div style="padding-left: 3.5em; text-indent: -1.0em;">b. The example with incomplete kinematics should take a few minutes to complete.</div>
        <div style="padding-left: 2em; text-indent: -1.0em;">5. Plot the data, starting with GRF Results and choosing columns (e.g., force_r_vy).</div>
        """
    )
    # --- User inputs: the two files plus subject/trial parameters ---
    gr.HTML("<h3>Input Data and Parameters</h3>")
    with gr.Row():
        mot_in = gr.File(label="Upload .mot file", file_types=[".mot"])
        osim_in = gr.File(label="Upload .osim file", file_types=[".osim"])
    with gr.Row():
        height_in = gr.Number(label="Height (meters)", value=1.7, minimum=1.0, maximum=2.5, step=0.01)
        weight_in = gr.Number(label="Weight (kg)", value=70, minimum=30, maximum=200, step=0.1)
        speed_in = gr.Number(label="Treadmill speed (m/s, 0 for overground)", value=0, minimum=0, maximum=10, step=0.01)
    submit_btn = gr.Button("Submit")
    # --- Results: status text, hidden state holding output file paths, download buttons ---
    gr.HTML("<h3>Results</h3>")
    with gr.Row():
        status_out = gr.Textbox(label="Status")
    # gr.State components persist the result file paths across events so the
    # plotting callbacks can re-read them without re-running the prediction.
    grf_path = gr.State(None)
    kin_path = gr.State(None)
    with gr.Row():
        # Hidden until the prediction produces the corresponding file.
        grf_file_out = gr.DownloadButton("Download Ground Reaction Force Results", visible=False)
        kin_file_out = gr.DownloadButton("Download Missing Kinematics (0 for overground)", visible=False)
    # Plotting panel
    with gr.Column(elem_id="plotting_panel"):
        gr.HTML("""
        <h3>Visualization</h3>
        <div style="padding-left: 2em; text-indent: -1.0em;">Ground reaction force (GRF) results are available for all inputs. Missing kinematics results are only available if incomplete kinematics were used as inputs.</div>
        <div style="padding-left: 2em; text-indent: -1.0em;">GRF column names follow these conventions:</div>
        <div style="padding-left: 3.5em; text-indent: -1.0em;">&bull; Starts with force_l or force_r: indicates the left or right foot, respectively</div>
        <div style="padding-left: 3.5em; text-indent: -1.0em;">&bull; Ends in _vx, _vy, _vz: indicates the magnitude value of the force in the x, y, and z directions, respectively.</div>
        <div style="padding-left: 3.5em; text-indent: -1.0em;">&bull; Ends in _px, _py, _pz: indicates the point location of the center of pressure in the x, y, and z directions, respectively.</div>
        <div style="padding-left: 3.5em; text-indent: -1.0em;">&bull; x is the anterior direction, y is the vertical direction, and z is the direction pointing to the right (medial or lateral direction depending on the leg).</div>
        """)
        with gr.Row():
            # Which output file to plot, and which of its columns.
            file_select = gr.Radio(choices=["GRF Results", "Missing Kinematics"], value="GRF Results", label="Select output file")
            column_select = gr.Dropdown(choices=[], label="Column", interactive=True)
        plot_btn = gr.Button("Plot")
        plot_out = gr.Plot(label="Time series plot", elem_id="time_series_plot")
    # Citation block for the GaitDynamics paper.
    gr.HTML("""
    <h3>Citation</h3>
    <div style="padding-left: 2em; text-indent: -1.0em;">
    <p>Tan, T., Van Wouwe, T., Werling, K.F. et al. GaitDynamics: a generative foundation model for analyzing human walking and running. Nat. Biomed. Eng (2026).</p>
    <p>https://doi.org/10.1038/s41551-025-01565-8</p>
    </div>
    """)
| # Wire interactions | |
| def enhanced_predict(mot_in, osim_in, height_in, weight_in, speed_in, current_select, current_grf, current_kin, progress=gr.Progress()): | |
| progress(0.0, desc="Starting...") | |
| def progress_callback(pct): | |
| progress(pct, desc=f"Processing diffusion model... {int(pct*100)}%") | |
| status, new_grf, new_kin = gradio_predict(mot_in, osim_in, height_in, weight_in, speed_in, progress_callback=progress_callback) | |
| progress(1.0, desc="Complete!") | |
| grf_visible = new_grf is not None | |
| kin_visible = new_kin is not None | |
| return status, new_grf, new_kin, gr.update(value=new_grf, visible=grf_visible), gr.update(value=new_kin, visible=kin_visible) | |
    # Submit: run the prediction, store the result file paths in state, toggle
    # the download buttons, then (.then) repopulate the column dropdown for
    # whichever output file is currently selected.
    submit_btn.click(
        fn=enhanced_predict,
        inputs=[mot_in, osim_in, height_in, weight_in, speed_in, file_select, grf_path, kin_path],
        outputs=[status_out, grf_path, kin_path, grf_file_out, kin_file_out]
    ).then(
        fn=list_mot_columns,
        inputs=[file_select, grf_path, kin_path],
        outputs=[column_select]
    )
    # Switching between GRF / kinematics output re-lists the plottable columns.
    file_select.change(
        fn=list_mot_columns,
        inputs=[file_select, grf_path, kin_path],
        outputs=[column_select]
    )
    # Plot the chosen column of the chosen output file as a time series.
    plot_btn.click(
        fn=plot_mot_signal,
        inputs=[file_select, column_select, grf_path, kin_path],
        outputs=[plot_out]
    )
if __name__ == '__main__':
    # Start the Gradio server; blocks until the app is stopped.
    demo.launch()