# Define functions and run GaitDynamics.
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
from torch import nn
from torch import Tensor
import argparse
from typing import Any
import nimblephysics as nimble
import scipy.interpolate as interpo
from scipy.interpolate import interp1d
import random
from scipy.signal import filtfilt, butter
from accelerate import Accelerator, DistributedDataParallelKwargs
import os
from inspect import isfunction
from math import pi
from einops import rearrange, repeat
import math
import copy
import matplotlib.pyplot as plt
from tqdm import tqdm
from functools import partial
from accelerate.state import AcceleratorState

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fixed seeds so that module import leaves torch / random / numpy in a reproducible state.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


"""
============================ Start scaler.py ============================
"""


def _handle_zeros_in_scale(scale, copy=True, constant_mask=None):
    """Replace (near-)zero entries of ``scale`` by 1.0 so it is safe to divide by.

    Args:
        scale: tensor of per-feature ranges.
        copy: if True, work on a clone so the caller's tensor is not modified.
        constant_mask: optional boolean mask of entries to replace; by default,
            entries smaller than ``10 * eps`` of ``scale``'s dtype are replaced.

    Returns:
        The (possibly cloned) tensor with masked entries set to 1.0.
    """
    if constant_mask is None:
        # Detect near constant values to avoid dividing by a very small value
        # that could lead to surprising results and numerical stability issues.
        constant_mask = scale < 10 * torch.finfo(scale.dtype).eps
    if copy:
        # New tensor to avoid side-effects on the caller's data.
        scale = scale.clone()
    scale[constant_mask] = 1.0
    return scale


class MinMaxScaler:
    """Column-wise min-max scaler for torch tensors (sklearn-style fit/transform API).

    After ``fit(X)``, ``transform`` maps each column of ``X`` linearly so that the
    fitted column minimum goes to ``feature_range[0]`` and the maximum to
    ``feature_range[1]``.

    NOTE(review): ``clip`` is stored but ``transform`` does not clip; confirm
    whether that is intended before relying on it.
    """

    _parameter_constraints: dict = {
        "feature_range": [tuple],
        "copy": ["boolean"],
        "clip": ["boolean"],
    }

    def __init__(self, feature_range=(-1, 1), *, copy=True, clip=True):
        self.feature_range = feature_range
        self.copy = copy
        self.clip = clip

    def _reset(self):
        """Drop all fitted attributes (if any)."""
        if hasattr(self, "scale_"):
            del self.scale_
            del self.min_
            del self.n_samples_seen_
            del self.data_min_
            del self.data_max_
            del self.data_range_

    def fit(self, X):
        """Reset internal state, then fit per-column min/max on ``X`` (rows along dim 0)."""
        self._reset()
        return self.partial_fit(X)

    def partial_fit(self, X):
        """Compute per-column min/max of ``X`` and derive ``scale_`` / ``min_``.

        Raises:
            ValueError: if ``feature_range[0] >= feature_range[1]``.
        """
        feature_range = self.feature_range
        if feature_range[0] >= feature_range[1]:
            raise ValueError(
                "Minimum of desired feature range must be smaller than maximum. Got %s."
                % str(feature_range)
            )

        data_min = torch.min(X, axis=0)[0]
        data_max = torch.max(X, axis=0)[0]
        self.n_samples_seen_ = X.shape[0]

        data_range = data_max - data_min
        self.scale_ = (feature_range[1] - feature_range[0]) / _handle_zeros_in_scale(data_range, copy=True)
        self.min_ = feature_range[0] - data_min * self.scale_
        self.data_min_ = data_min
        self.data_max_ = data_max
        self.data_range_ = data_range
        return self

    def transform(self, X):
        """Scale ``X`` in place (``X * scale_ + min_``) and return it."""
        X *= self.scale_.to(X.device)
        X += self.min_.to(X.device)
        return X

    def inverse_transform(self, X):
        """Undo ``transform`` in place (``(X - min_) / scale_``) and return ``X``."""
        X -= self.min_.to(X.device)
        X /= self.scale_.to(X.device)
        return X


class Normalizer:
    """Apply a fitted scaler to a subset of channels (``cols_to_normalize``)."""

    def __init__(self, scaler, cols_to_normalize):
        self.scaler = scaler
        self.cols_to_normalize = cols_to_normalize

    def normalize(self, x):
        """Return a copy of ``x`` with ``x[:, cols_to_normalize]`` scaled.

        Columns are selected along dim 1, so this is meant for 2-D
        ``(rows, channels)`` input.
        """
        x = x.clone()
        x[:, self.cols_to_normalize] = self.scaler.transform(x[:, self.cols_to_normalize])
        return x

    def unnormalize(self, x):
        """Return a copy of 3-D ``x`` (batch, seq, channels) with the selected channels un-scaled."""
        batch, seq, ch = x.shape
        x = x.clone()
        x = x.reshape(-1, ch)
        x[:, self.cols_to_normalize] = self.scaler.inverse_transform(x[:, self.cols_to_normalize])
        return x.reshape((batch, seq, ch))


"""
============================ End scaler.py ============================
"""


"""
============================ Start consts.py ============================
"""

# Three rotational DoFs of each ball-type joint, keyed by joint name.
JOINTS_3D_ALL = {
    'pelvis': ['pelvis_tilt', 'pelvis_list', 'pelvis_rotation'],
    'hip_r': ['hip_flexion_r', 'hip_adduction_r', 'hip_rotation_r'],
    'hip_l': ['hip_flexion_l', 'hip_adduction_l', 'hip_rotation_l'],
    'lumbar': ['lumbar_extension', 'lumbar_bending', 'lumbar_rotation'],
    # 'arm_r': ['arm_flex_r', 'arm_add_r', 'arm_rot_r'],
    # 'arm_l': ['arm_flex_l', 'arm_add_l', 'arm_rot_l']
}

OSIM_DOF_ALL = [
    'pelvis_tilt', 'pelvis_list', 'pelvis_rotation', 'pelvis_tx', 'pelvis_ty', 'pelvis_tz',
    'hip_flexion_r', 'hip_adduction_r', 'hip_rotation_r', 'knee_angle_r', 'ankle_angle_r',
    'subtalar_angle_r', 'mtp_angle_r',
    'hip_flexion_l', 'hip_adduction_l', 'hip_rotation_l', 'knee_angle_l', 'ankle_angle_l',
    'subtalar_angle_l', 'mtp_angle_l',
    'lumbar_extension', 'lumbar_bending', 'lumbar_rotation',
    'arm_flex_r', 'arm_add_r', 'arm_rot_r', 'elbow_flex_r', 'pro_sup_r', 'wrist_flex_r', 'wrist_dev_r',
    'arm_flex_l', 'arm_add_l', 'arm_rot_l', 'elbow_flex_l', 'pro_sup_l', 'wrist_flex_l', 'wrist_dev_l']

# Per-foot force components and normalized centre-of-pressure columns.
KINETICS_ALL = [body + modality for body in ['calcn_r', 'calcn_l']
                for modality in ['_force_vx', '_force_vy', '_force_vz',
                                 '_force_normed_cop_x', '_force_normed_cop_y', '_force_normed_cop_z']]

# Model state layout: 1-D DoFs, then kinetics, then six values per 3-D joint.
MODEL_STATES_COLUMN_NAMES_WITH_ARM = [
    'pelvis_tx', 'pelvis_ty', 'pelvis_tz',
    'knee_angle_r', 'ankle_angle_r', 'subtalar_angle_r',
    'knee_angle_l', 'ankle_angle_l', 'subtalar_angle_l',
    'elbow_flex_r', 'pro_sup_r', 'elbow_flex_l', 'pro_sup_l'
] + KINETICS_ALL + [
    'pelvis_0', 'pelvis_1', 'pelvis_2', 'pelvis_3', 'pelvis_4', 'pelvis_5',
    'hip_r_0', 'hip_r_1', 'hip_r_2', 'hip_r_3', 'hip_r_4', 'hip_r_5',
    'hip_l_0', 'hip_l_1', 'hip_l_2', 'hip_l_3', 'hip_l_4', 'hip_l_5',
    'lumbar_0', 'lumbar_1', 'lumbar_2', 'lumbar_3', 'lumbar_4', 'lumbar_5',
    'arm_r_0', 'arm_r_1', 'arm_r_2', 'arm_r_3', 'arm_r_4', 'arm_r_5',  # only for with arm
    'arm_l_0', 'arm_l_1', 'arm_l_2', 'arm_l_3', 'arm_l_4', 'arm_l_5'  # only for with arm
]

# Same layout with every arm-related column removed (order otherwise preserved).
MODEL_STATES_COLUMN_NAMES_NO_ARM = copy.deepcopy(MODEL_STATES_COLUMN_NAMES_WITH_ARM)
for name_ in ['elbow_flex_r', 'pro_sup_r', 'elbow_flex_l', 'pro_sup_l',
              'arm_r_0', 'arm_r_1', 'arm_r_2', 'arm_r_3', 'arm_r_4', 'arm_r_5',
              'arm_l_0', 'arm_l_1', 'arm_l_2', 'arm_l_3', 'arm_l_4', 'arm_l_5']:
    MODEL_STATES_COLUMN_NAMES_NO_ARM.remove(name_)

FROZEN_DOFS = ['mtp_angle_r', 'mtp_angle_l', 'wrist_flex_r', 'wrist_dev_r', 'wrist_flex_l', 'wrist_dev_l']

FULL_OSIM_DOF = ['pelvis_tilt', 'pelvis_list', 'pelvis_rotation', 'pelvis_tx', 'pelvis_ty', 'pelvis_tz',
                 'hip_flexion_r', 'hip_adduction_r', 'hip_rotation_r', 'knee_angle_r', 'ankle_angle_r',
                 'subtalar_angle_r', 'mtp_angle_r',
                 'hip_flexion_l', 'hip_adduction_l', 'hip_rotation_l', 'knee_angle_l', 'ankle_angle_l',
                 'subtalar_angle_l', 'mtp_angle_l',
                 'lumbar_extension', 'lumbar_bending', 'lumbar_rotation']

JOINTS_1D_ALL = ['pelvis_tx', 'pelvis_ty', 'pelvis_tz', 'knee_angle_r', 'ankle_angle_r', 'subtalar_angle_r',
                 'mtp_angle_r', 'knee_angle_l', 'ankle_angle_l', 'subtalar_angle_l', 'mtp_angle_l']

"""
============================ End consts.py ============================
"""


"""
============================ Start model.py ============================
"""


class GaussianDiffusion(nn.Module):
    """Gaussian diffusion wrapper around a denoising ``model``.

    Precomputes the noise-schedule buffers, then offers DDPM ancestral sampling
    (``p_sample_loop``) and two 50-step DDIM inpainting samplers in which
    masked-in values from ``constraint`` are re-imposed after every step.

    NOTE(review): ``make_beta_schedule``, ``extract`` and ``identity`` are not
    defined in this part of the file; they are assumed to be provided elsewhere.
    """

    def __init__(
        self,
        model,
        horizon,
        repr_dim,
        opt,
        n_timestep=1000,
        schedule="linear",
        loss_type="l1",
        clip_denoised=False,
        predict_epsilon=True,
        guidance_weight=1,
        use_p2=False,
        cond_drop_prob=0.,
    ):
        super().__init__()
        self.horizon = horizon
        self.transition_dim = repr_dim
        self.model = model
        self.ema = EMA(0.99)
        self.master_model = copy.deepcopy(self.model)
        self.opt = opt
        self.cond_drop_prob = cond_drop_prob

        # Noise schedule.
        betas = torch.Tensor(make_beta_schedule(schedule=schedule, n_timestep=n_timestep))
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

        self.n_timestep = int(n_timestep)
        self.clip_denoised = clip_denoised
        self.predict_epsilon = predict_epsilon

        self.register_buffer("betas", betas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        self.guidance_weight = guidance_weight

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
        self.register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        self.register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        self.register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer("posterior_variance", posterior_variance)

        # log calculation clipped because the posterior variance
        # is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            torch.log(torch.clamp(posterior_variance, min=1e-20)),
        )
        # torch.sqrt keeps these as float32 tensors without a numpy round-trip.
        self.register_buffer(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
        )

        # p2 weighting
        self.p2_loss_weight_k = 1
        self.p2_loss_weight_gamma = 0.5 if use_p2 else 0
        self.register_buffer(
            "p2_loss_weight",
            (self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -self.p2_loss_weight_gamma,
        )

        # get loss coefficients and initialize objective
        self.loss_fn = F.mse_loss if loss_type == "l2" else F.l1_loss

    def set_normalizer(self, normalizer):
        self.normalizer = normalizer

    # ------------------------------------------ sampling ------------------------------------------#

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x0 from the model output.

        If ``self.predict_epsilon``, the model output is (scaled) noise and x0 is
        reconstructed from it; otherwise the model predicts x0 directly and
        ``noise`` is returned unchanged.
        """
        if self.predict_epsilon:
            return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
                - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
            )
        return noise

    def predict_noise_from_start(self, x_t, t, x0):
        """Return the noise implied by ``x_t`` and a predicted clean sample ``x0``."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0)
            / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def model_predictions(self, x, cond, time_cond, weight=None, clip_x_start=False):
        """Run the model and return ``(pred_noise, x_start)``.

        The model output is treated as the predicted clean sample ``x_start``
        (optionally clamped to [-1, 1]); the noise is then derived from it.
        """
        weight = weight if weight is not None else self.guidance_weight
        model_output = self.model.guided_forward(x, cond, time_cond, weight)
        maybe_clip = partial(torch.clamp, min=-1., max=1.) if clip_x_start else identity

        x_start = maybe_clip(model_output)
        pred_noise = self.predict_noise_from_start(x, time_cond, x_start)
        return pred_noise, x_start

    def q_posterior(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, cond, t):
        """Predict x0 at step ``t`` and return the resulting posterior statistics."""
        # guidance clipping
        # NOTE(review): t never exceeds n_timestep - 1, so the first branch cannot fire.
        if t[0] > 1.0 * self.n_timestep:
            weight = min(self.guidance_weight, 0)
        elif t[0] < 0.1 * self.n_timestep:
            weight = min(self.guidance_weight, 1)
        else:
            weight = self.guidance_weight

        x_recon = self.predict_start_from_noise(x, t=t, noise=self.model.guided_forward(x, cond, t, weight))

        if self.clip_denoised:
            x_recon.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, x_recon

    @torch.no_grad()
    def p_sample(self, x, cond, t):
        """One ancestral sampling step; returns ``(x_{t-1}, predicted x0)``."""
        b = x.shape[0]
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x=x, cond=cond, t=t)
        noise = torch.randn_like(model_mean)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(noise.shape) - 1)))
        x_out = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return x_out, x_start

    @torch.no_grad()
    def p_sample_loop(  # Only used during inference
        self,
        shape,
        cond,
        noise=None,
        constraint=None,
        return_diffusion=False,
        start_point=None,
    ):
        """Full ancestral sampling from ``start_point`` (default ``n_timestep``) down to 0.

        Returns the final sample, plus the list of intermediate samples when
        ``return_diffusion`` is True. ``constraint`` is accepted but unused.
        """
        device = self.betas.device

        # default to diffusion over whole timescale
        start_point = self.n_timestep if start_point is None else start_point
        batch_size = shape[0]
        x = torch.randn(shape, device=device) if noise is None else noise.to(device)
        cond = cond.to(device)

        if return_diffusion:
            diffusion = [x]

        for i in tqdm(reversed(range(0, start_point))):
            # fill with i
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x, _ = self.p_sample(x, cond, timesteps)

            if return_diffusion:
                diffusion.append(x)

        if return_diffusion:
            return x, diffusion
        return x

    def _ddim_inpaint_step(self, x, x_start, pred_noise, time, time_next, value, mask, eta, batch, device):
        """DDIM update from ``time`` to ``time_next``, then re-impose the known values.

        Known entries (``mask == 1``) are replaced by ``value`` noised to
        ``time_next`` via ``q_sample`` while ``time > 0``.
        """
        alpha = self.alphas_cumprod[time]
        alpha_next = self.alphas_cumprod[time_next]

        sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        c = (1 - alpha_next - sigma ** 2).sqrt()

        noise = torch.randn_like(x)
        x = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise

        timesteps = torch.full((batch,), time_next, device=device, dtype=torch.long)
        value_ = self.q_sample(value, timesteps) if time > 0 else x
        return value_ * mask + (1.0 - mask) * x

    @staticmethod
    def _ddim_time_pairs(total_timesteps, sampling_timesteps):
        """Consecutive (time, time_next) pairs, descending and ending at (t, -1)."""
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
        return list(zip(times[:-1], times[1:]))

    @torch.no_grad()
    def inpaint_ddim_guided(self, shape, noise=None, constraint=None, return_diffusion=False, start_point=None):
        """50-step deterministic DDIM inpainting with gradient guidance toward ``constraint['value']``.

        For steps with ``opt.guide_x_start_the_end_step <= time <=
        opt.guide_x_start_the_beginning_step``, ``x`` is first nudged by
        ``opt.n_guided_steps`` gradient steps (lr ``opt.guidance_lr``) on a
        thresholded, weighted |x_start - value| loss.

        ``constraint`` keys: ``cond``, ``mask`` and ``value`` (batch x horizon x
        channels), ``value_diff_thd`` and ``value_diff_weight`` (channels).
        ``noise``, ``return_diffusion`` and ``start_point`` are accepted but unused.
        """
        batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 0
        time_pairs = self._ddim_time_pairs(total_timesteps, sampling_timesteps)

        x = torch.randn(shape, device=device)
        cond = constraint["cond"].to(device)
        mask = constraint["mask"].to(device)  # batch x horizon x channels
        value = constraint["value"].to(device)  # batch x horizon x channels
        value_diff_thd = constraint["value_diff_thd"].to(device)  # channels
        value_diff_weight = constraint["value_diff_weight"].to(device)  # channels

        for time, time_next in time_pairs:
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            if self.opt.guide_x_start_the_end_step <= time <= self.opt.guide_x_start_the_beginning_step:
                x.requires_grad_()
                with torch.enable_grad():
                    for _ in range(self.opt.n_guided_steps):
                        pred_noise, x_start, *_ = self.model_predictions(
                            x, cond, time_cond, clip_x_start=self.clip_denoised)
                        value_diff = torch.subtract(x_start, value)
                        # Only deviations beyond the per-channel threshold are penalised.
                        loss = torch.relu(value_diff.abs() - value_diff_thd) * value_diff_weight
                        grad = torch.autograd.grad([loss.sum()], [x])[0]
                        x = x - self.opt.guidance_lr * grad

            pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, clip_x_start=self.clip_denoised)

            if time_next < 0:
                x = x_start
                x = value * mask + (1.0 - mask) * x
                return x

            x = self._ddim_inpaint_step(x, x_start, pred_noise, time, time_next, value, mask, eta, batch, device)
        return x

    @torch.no_grad()
    def inpaint_ddim_loop(self, shape, noise=None, constraint=None, return_diffusion=False, start_point=None):
        """50-step deterministic DDIM inpainting of the unmasked entries.

        ``constraint`` keys: ``cond``, ``mask`` and ``value`` (batch x horizon x
        channels). The last step returns the model's x_start directly, without
        re-imposing ``value``. ``noise``, ``return_diffusion`` and
        ``start_point`` are accepted but unused.
        """
        batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 0
        time_pairs = self._ddim_time_pairs(total_timesteps, sampling_timesteps)

        x = torch.randn(shape, device=device)
        cond = constraint["cond"].to(device)
        mask = constraint["mask"].to(device)  # batch x horizon x channels
        value = constraint["value"].to(device)  # batch x horizon x channels

        for time, time_next in time_pairs:
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, clip_x_start=self.clip_denoised)

            if time_next < 0:
                x = x_start
                return x

            x = self._ddim_inpaint_step(x, x_start, pred_noise, time, time_next, value, mask, eta, batch, device)
        return x

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to step ``t`` (fresh Gaussian noise unless ``noise`` is given)."""
        if noise is None:
            noise = torch.randn_like(x_start)

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def forward(self, x, cond, t_override=None):
        # NOTE(review): self.loss is not defined in the visible code; confirm it exists elsewhere.
        return self.loss(x, cond, t_override)

    def generate_samples(
        self,
        shape,
        normalizer,
        opt,
        mode,
        noise=None,
        constraint=None,
        start_point=None,
    ):
        """Sample with the chosen inpainting ``mode`` and return un-normalized CPU samples.

        If ``shape`` is a tuple, sampling runs with ``mode`` in {"inpaint",
        "inpaint_ddim_guided"}; otherwise ``shape`` is taken to already be the
        samples tensor.

        Raises:
            ValueError: for an unrecognized ``mode`` (raised explicitly so it
                survives ``python -O``).
        """
        # Re-seed torch from its own current RNG stream.
        torch.manual_seed(torch.randint(0, 2 ** 32, (1,)).item())
        if isinstance(shape, tuple):
            if mode == "inpaint":
                func_class = self.inpaint_ddim_loop
            elif mode == "inpaint_ddim_guided":
                func_class = self.inpaint_ddim_guided
            else:
                raise ValueError("Unrecognized inference mode")
            samples = func_class(
                shape,
                noise=noise,
                constraint=constraint,
                start_point=start_point,
            )
        else:
            samples = shape

        samples = normalizer.unnormalize(samples.detach().cpu())
        samples = samples.detach().cpu()
        return samples


class FillingBase:
    """Base class for strategies that fill masked entries of a sequence of windows.

    NOTE(review): these methods read a module-level ``opt`` that is not defined
    in this part of the file; confirm it is set before use.
    """

    def fill_param(self, windows, diffusion_model_for_filling, progress_callback=None):
        """Fill ``windows`` using the column-masking update rule."""
        return self.filling(windows, diffusion_model_for_filling,
                            self.update_kinematics_and_masks_for_masking_column,
                            progress_callback=progress_callback)

    @staticmethod
    def update_kinematics_and_masks_for_masking_column(windows, samples, i_win, masks):
        """Write ``samples`` into ``windows[i_win:]`` and update their masks.

        Every time step that had at least one known channel becomes fully known
        (mask 1), then the ``opt.kinetic_diffusion_col_loc`` channels are reset
        to unknown (mask 0). Mask tensors are modified in place.
        """
        unmasked_samples_in_temporal_dim = masks.sum(axis=2).bool()
        for j_win in range(len(samples)):
            window = windows[j_win + i_win]
            window.pose = samples[j_win]
            updated_mask = window.mask
            updated_mask[unmasked_samples_in_temporal_dim[j_win], :] = 1
            updated_mask[:, opt.kinetic_diffusion_col_loc] = 0
            window.mask = updated_mask
        return windows

    def update_kinematics_and_masks_for_masking_temporal(self, windows, samples, i_win, masks):
        """Write ``samples`` into ``windows[i_win:]`` and restore masks from ``self.mask_original``."""
        for j_win in range(len(samples)):
            windows[j_win + i_win].pose = samples[j_win]
            windows[j_win + i_win].mask = self.mask_original[j_win + i_win]
        return windows

    @staticmethod
    def filling(windows, diffusion_model_for_filling, windows_update_func, progress_callback=None):
        """Abstract: subclasses fill ``windows``. Accepts ``progress_callback`` like ``fill_param`` passes."""
        raise NotImplementedError


class DiffusionFilling(FillingBase):
    @staticmethod
    def filling(windows, diffusion_model_for_filling, windows_update_func, progress_callback=None):
        """Inpaint a deep copy of ``windows`` in batches of ``opt.batch_size_inference``.

        Each batch is sampled with ``inpaint_ddim_loop``; known entries
        (mask 1) keep their original values, and ``windows_update_func`` writes
        the result back. If given, ``progress_callback`` receives a fraction that
        rises from just above 0.1 to 0.8 over the batches.
        """
        windows = copy.deepcopy(windows)
        batch_size = opt.batch_size_inference
        batch_starts = range(0, len(windows), batch_size)
        total_iterations = len(batch_starts)
        for iteration_idx, i_win in enumerate(batch_starts):
            batch_windows = windows[i_win:i_win + batch_size]
            state_true = torch.stack([win.pose for win in batch_windows])
            masks = torch.stack([win.mask for win in batch_windows])
            cond = torch.ones([6])

            constraint = {'mask': masks, 'value': state_true.clone(), 'cond': cond}
            shape = (state_true.shape[0], state_true.shape[1], state_true.shape[2])
            samples = diffusion_model_for_filling.diffusion.inpaint_ddim_loop(shape, constraint=constraint)
            samples = state_true * masks + (1.0 - masks) * samples.to(state_true.device)
            windows = windows_update_func(windows, samples, i_win, masks)

            # Report progress in the 10%-80% band: 0.1 + 0.7 * fraction of batches done.
            if progress_callback:
                progress = 0.1 + 0.7 * (iteration_idx + 1) / total_iterations
                progress_callback(progress)

        return windows

    def __str__(self):
        return 'diffusion_filling'


def wrap(x):
    """Prefix every key of state dict ``x`` with ``'module.'``."""
    return {f"module.{key}": value for key, value in x.items()}


def maybe_wrap(x, num):
    """Return ``x`` unchanged when ``num == 1``, otherwise ``wrap(x)``."""
    return x if num == 1 else wrap(x)


def exists(val):
    return val is not None


def broadcat(tensors, dim=-1):
    """Concatenate ``tensors`` along ``dim`` after broadcasting all other dimensions.

    All tensors must have the same rank, and every non-``dim`` axis may take at
    most two distinct sizes across the tensors.
    """
    num_tensors = len(tensors)
    shape_lens = set(map(lambda t: len(t.shape), tensors))
    assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
    shape_len = list(shape_lens)[0]

    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))

    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all(
        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
    ), "invalid dimensions for broadcastable concatentation"
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


def rotate_half(x):
    """Map each adjacent pair (a, b) in the last dimension to (-b, a)."""
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


def apply_rotary_emb(freqs, t, start_index=0):
    """Rotate features ``t[..., start_index:start_index + freqs.shape[-1]]`` by ``freqs``.

    Features outside that slice are passed through unchanged.
    """
    freqs = freqs.to(t)
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
    t_left, t, t_right = (
        t[..., :start_index],
        t[..., start_index:end_index],
        t[..., end_index:],
    )
    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
    return torch.cat((t_left, t, t_right), dim=-1)


class RotaryEmbedding(nn.Module):
    """Rotary position embedding; ``rotate_queries_or_keys`` rotates features by sequence position.

    Frequencies are cached per sequence length.
    """

    def __init__(
        self,
        dim,
        custom_freqs=None,
        freqs_for="lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
    ):
        super().__init__()
        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f"unknown modality {freqs_for}")

        self.cache = dict()

        if learned_freq:
            self.freqs = nn.Parameter(freqs)
        else:
            self.register_buffer("freqs", freqs)

    def rotate_queries_or_keys(self, t, seq_dim=-2):
        """Apply position-dependent rotation to ``t`` along its ``seq_dim`` axis."""
        device = t.device
        seq_len = t.shape[seq_dim]
        freqs = self.forward(lambda: torch.arange(seq_len, device=device), cache_key=seq_len)
        return apply_rotary_emb(freqs, t)

    def forward(self, t, cache_key=None):
        """Return per-position angles for positions ``t`` (or ``t()`` if ``t`` is a function)."""
        if exists(cache_key) and cache_key in self.cache:
            return self.cache[cache_key]

        if isfunction(t):
            t = t()

        freqs = self.freqs

        freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if exists(cache_key):
            self.cache[cache_key] = freqs

        return freqs


class EMA:
    """Exponential moving average of model parameters with decay ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Update ``ma_model``'s parameters toward ``current_model``'s, in place."""
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new`` (or ``new`` if ``old`` is None)."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to batch-first input."""

    def __init__(self, d_model, max_len, dropout: float = 0.):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(max_len * 2) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        try:
            pe[:, 0, 1::2] = torch.cos(position * div_term)
        except RuntimeError:
            # Odd d_model: the odd slots hold one fewer column than div_term.
            pe[:, 0, 1::2] = torch.cos(position * div_term)[:, :-1]
        pe = pe.transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)


class PositionWiseFeedForward(nn.Module):
    """Two linear layers with a ReLU in between."""

    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


class EncoderLayer(nn.Module):
    """Post-norm transformer encoder layer with rotary queries/keys."""

    def __init__(self, d_model, num_heads=8, d_ff=512, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.rotary = RotaryEmbedding(dim=d_model)
        self.use_rotary = self.rotary is not None
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        attn_output, _ = self.self_attn(qk, qk, x, need_weights=False)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x


class TransformerEncoderArchitecture(nn.Module):
    """Baseline regressor: predicts the non-kinematic columns from the kinematic ones.

    Input columns are ``opt.kinematic_diffusion_col_loc``; all remaining
    columns of the ``repr_dim`` state are outputs.
    """

    def __init__(self, repr_dim, opt, nlayers=6):
        super(TransformerEncoderArchitecture, self).__init__()
        self.input_dim = len(opt.kinematic_diffusion_col_loc)
        self.output_dim = repr_dim - self.input_dim
        embedding_dim = 192
        self.input_to_embedding = nn.Linear(self.input_dim, embedding_dim)
        self.encoder_layers = nn.Sequential(*[EncoderLayer(embedding_dim) for _ in range(nlayers)])
        self.embedding_to_output = nn.Linear(embedding_dim, self.output_dim)
        self.opt = opt
        self.input_col_loc = opt.kinematic_diffusion_col_loc
        input_cols = set(self.input_col_loc)
        self.output_col_loc = [i for i in range(repr_dim) if i not in input_cols]

    def loss_fun(self, output_pred, output_true):
        """Element-wise (unreduced) MSE."""
        return F.mse_loss(output_pred, output_true, reduction='none')

    def end_to_end_prediction(self, x):
        """Predict output columns from ``x[0][:, :, input_col_loc]``."""
        inputs = x[0][:, :, self.input_col_loc]
        sequence = self.input_to_embedding(inputs)
        sequence = self.encoder_layers(sequence)
        output_pred = self.embedding_to_output(sequence)
        return output_pred

    def __str__(self):
        return 'tf'


class DiffusionShellForAdaptingTheOriginalFramework(nn.Module):
    """Adapter giving a direct-regression model the same interface as the diffusion model."""

    def __init__(self, model):
        super(DiffusionShellForAdaptingTheOriginalFramework, self).__init__()
        self.device = device
        self.model = model
        self.ema = EMA(0.99)
        self.master_model = copy.deepcopy(self.model)

    def set_normalizer(self, normalizer):
        self.normalizer = normalizer

    def predict_samples(self, x, constraint):
        """Zero out masked entries of ``x[0]``, then overwrite output columns with predictions.

        ``x`` is a list; ``x[0]`` is replaced by a new tensor and returned.
        """
        x[0] = x[0] * constraint['mask']
        output_pred = self.model.end_to_end_prediction(x)
        x[0][:, :, self.model.output_col_loc] = output_pred
        return x[0]

    def forward(self, x, cond, t_override):
        """Return ``(total_loss, [loss terms..., per-element loss tensor])``."""
        output_true = x[0][:, :, self.model.output_col_loc]
        output_pred = self.model.end_to_end_prediction(x)
        loss_simple = torch.zeros(x[0].shape).to(x[0].device)
        loss_simple[:, :, self.model.output_col_loc] = self.model.loss_fun(output_pred, output_true)

        losses = [
            1. * loss_simple.mean(),
            torch.tensor(0.).to(loss_simple.device),
            torch.tensor(0.).to(loss_simple.device),
            torch.tensor(0.).to(loss_simple.device),
            torch.tensor(0).to(loss_simple.device)]
        return sum(losses), losses + [loss_simple]


def featurewise_affine(x, scale_shift):
    """Return ``(scale + 1) * x + shift`` for ``scale_shift = (scale, shift)``."""
    scale, shift = scale_shift
    return (scale + 1) * x + shift


class DenseFiLM(nn.Module):
    """Feature-wise linear modulation (FiLM) generator."""

    def __init__(self, embed_channels):
        super().__init__()
        self.embed_channels = embed_channels
        self.block = nn.Sequential(
            nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
        )

    def forward(self, position):
        """Map a (b, c) embedding to a ``(scale, shift)`` pair, each shaped (b, 1, c)."""
        pos_encoding = self.block(position)
        pos_encoding = rearrange(pos_encoding, "b c -> b 1 c")
        scale_shift = pos_encoding.chunk(2, dim=-1)
        return scale_shift


class FiLMTransformerDecoderLayer(nn.Module):
    """Transformer layer whose self-attention and feed-forward outputs are FiLM-modulated by ``t``.

    ``forward`` does not use ``memory`` (no cross-attention is applied);
    ``_mha_block`` is defined but not called by ``forward``.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        layer_norm_eps=1e-5,
        batch_first=False,
        norm_first=True,
        device=None,
        dtype=None,
        rotary=None,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        # Feedforward
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        # Used by _mha_block; Dropout has no parameters, so state_dict keys are unchanged.
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = activation

        self.film1 = DenseFiLM(d_model)
        self.film3 = DenseFiLM(d_model)

        self.rotary = rotary
        self.use_rotary = rotary is not None

    # x, cond, t
    def forward(
        self,
        tgt,
        memory,
        t,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
    ):
        x = tgt
        if self.norm_first:
            x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
            x = x + featurewise_affine(x_1, self.film1(t))
            x_3 = self._ff_block(self.norm3(x))
            x = x + featurewise_affine(x_3, self.film3(t))
        else:
            x = self.norm1(
                x + featurewise_affine(self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t))
            )
            x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
        return x

    # self-attention block
    # qkv
    def _sa_block(self, x, attn_mask, key_padding_mask):
        qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        x = self.self_attn(
            qk,
            qk,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    # qkv
    def _mha_block(self, x, mem, attn_mask, key_padding_mask):
        q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
        x = self.multihead_attn(
            q,
            k,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)


class DecoderLayerStack(nn.Module):
    """Apply each layer of ``stack`` in order as ``layer(x, cond, t)``."""

    def __init__(self, stack):
        super().__init__()
        self.stack = stack

    def forward(self, x, cond, t):
        for layer in self.stack:
            x = layer(x, cond, t)
        return x


class DanceDecoder(nn.Module):
    """Denoiser: input projection, FiLM-conditioned transformer stack, linear output head.

    Conditioning on the diffusion step uses sinusoidal time features; the
    ``cond_embed`` argument is accepted but not used.
    """

    def __init__(
        self,
        nfeats: int,
        seq_len: int = 150,  # 5 seconds, 30 fps
        latent_dim: int = 512,
        ff_size: int = 1024,
        num_layers: int = 8,
        num_heads: int = 8,
        dropout: float = 0.1,
        # cond_feature_dim: int = 6,
        activation=F.gelu,
        use_rotary=True,
        **kwargs
    ) -> None:
        super().__init__()

        output_feats = nfeats

        # positional embeddings
        self.rotary = None
        self.abs_pos_encoding = nn.Identity()
        # if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
        if use_rotary:
            self.rotary = RotaryEmbedding(dim=latent_dim)
        else:
            # PositionalEncoding is always batch-first and takes (d_model, max_len, dropout).
            self.abs_pos_encoding = PositionalEncoding(latent_dim, seq_len, dropout)

        # time embedding processing
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(latent_dim),
            nn.Linear(latent_dim, latent_dim * 4),
            nn.Mish(),
        )

        # input projection
        self.input_projection = nn.Linear(nfeats, latent_dim)
        self.to_time_cond = nn.Sequential(nn.Linear(latent_dim * 4, latent_dim),)

        decoderstack = nn.ModuleList([])
        for _ in range(num_layers):
            decoderstack.append(
                FiLMTransformerDecoderLayer(  # decoder layers
                    latent_dim,
                    num_heads,
                    dim_feedforward=ff_size,
                    dropout=dropout,
                    activation=activation,
                    batch_first=True,
                    rotary=self.rotary,
                )
            )

        self.seqTransDecoder = DecoderLayerStack(decoderstack)
        self.final_layer = nn.Linear(latent_dim, output_feats)

    def guided_forward(self, x, cond_embed, time_cond, guidance_weight):
        """Same as ``forward``; ``guidance_weight`` is ignored."""
        return self.forward(x, cond_embed, time_cond)

    def __str__(self):
        return 'diffusion'

    # No conditioning version
    def forward(self, x: Tensor, cond_embed: Tensor, time_cond: Tensor, cond_drop_prob: float = 0.0):
        x = self.input_projection(x)
        x = self.abs_pos_encoding(x)

        t_hidden = self.time_mlp(time_cond)
        t = self.to_time_cond(t_hidden)

        output = self.seqTransDecoder(x, None, t)
        output = self.final_layer(output)
        return output


class MotionModel:
    """Inference wrapper: builds the diffusion model and optionally loads ``opt.checkpoint``.

    Args:
        opt: options object; reads ``model_states_column_names``,
            ``window_len`` and ``checkpoint`` ("" means no checkpoint).
        normalizer: normalizer to use; overridden by the checkpoint's
            ``"normalizer"`` entry when a checkpoint is loaded.
        EMA: load ``"ema_state_dict"`` (True) or ``"model_state_dict"``.
        learning_rate, weight_decay: accepted for compatibility; unused here.
    """

    def __init__(
        self,
        opt,
        normalizer=None,
        EMA=True,
        learning_rate=4e-4,
        weight_decay=0.02,
    ):
        self.opt = opt
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
        state = AcceleratorState()
        num_processes = state.num_processes

        self.repr_dim = len(opt.model_states_column_names)
        self.horizon = horizon = opt.window_len
        self.accelerator.wait_for_everyone()

        # Without a checkpoint, the caller-supplied normalizer is what eval_loop uses.
        self.normalizer = normalizer
        checkpoint = None
        if opt.checkpoint != "":
            # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
            checkpoint = torch.load(opt.checkpoint, map_location=self.accelerator.device, weights_only=False)
            self.normalizer = checkpoint["normalizer"]

        model = DanceDecoder(
            nfeats=self.repr_dim,
            seq_len=horizon,
            latent_dim=512,
            ff_size=1024,
            num_layers=8,
            num_heads=8,
            dropout=0.1,
            activation=F.gelu,
        )

        diffusion = GaussianDiffusion(
            model,
            horizon,
            self.repr_dim,
            opt,
            # schedule="cosine",
            n_timestep=1000,
            predict_epsilon=False,
            loss_type="l2",
            use_p2=False,
            cond_drop_prob=0.,
            guidance_weight=2,
        )

        self.model = self.accelerator.prepare(model)
        self.diffusion = diffusion.to(self.accelerator.device)
        if checkpoint is not None:
            self.model.load_state_dict(
                maybe_wrap(
                    checkpoint["ema_state_dict" if EMA else "model_state_dict"],
                    num_processes,
                )
            )

    def eval(self):
        self.diffusion.eval()

    def prepare(self, objects):
        return self.accelerator.prepare(*objects)

    def eval_loop(self, opt, state_true, masks, value_diff_thd=None, value_diff_weight=None, cond=None,
                  num_of_generation_per_window=1, mode="inpaint"):
        """Generate ``num_of_generation_per_window`` samples for the batch and stack them on dim 0.

        Defaults: zero thresholds and unit weights per channel, and
        ``cond = torch.ones([6])``.
        """
        self.eval()
        if value_diff_thd is None:
            value_diff_thd = torch.zeros([state_true.shape[2]])
        if value_diff_weight is None:
            value_diff_weight = torch.ones([state_true.shape[2]])
        if cond is None:
            cond = torch.ones([6])

        constraint = {'mask': masks, 'value': state_true.clone(), 'value_diff_thd': value_diff_thd,
                      'value_diff_weight': value_diff_weight, 'cond': cond}
        shape = (state_true.shape[0], self.horizon, self.repr_dim)
        state_pred_list = [self.diffusion.generate_samples(shape, self.normalizer, opt, mode=mode,
                                                           constraint=constraint)
                           for _ in range(num_of_generation_per_window)]
        return torch.stack(state_pred_list)


class BaselineModel:
    """Inference wrapper for a direct-regression baseline; optionally loads ``opt.checkpoint_bl``."""

    def __init__(
        self,
        opt,
        model_architecture_class,
        EMA=True,
    ):
        self.opt = opt
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
        self.repr_dim = len(opt.model_states_column_names)
        self.horizon = opt.window_len
        self.accelerator.wait_for_everyone()

        self.model = model_architecture_class(self.repr_dim, opt)
        self.diffusion = DiffusionShellForAdaptingTheOriginalFramework(self.model)
        self.diffusion = self.accelerator.prepare(self.diffusion)

        if opt.checkpoint_bl != "":
            # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
            checkpoint = torch.load(opt.checkpoint_bl, map_location=self.accelerator.device, weights_only=False)
            self.normalizer = checkpoint["normalizer"]
            self.model.load_state_dict(
                maybe_wrap(
                    checkpoint["ema_state_dict" if EMA else "model_state_dict"],
                    1,
                )
            )

    def eval_loop(self, opt, state_true, masks, value_diff_thd=None, value_diff_weight=None, cond=None,
                  num_of_generation_per_window=1, mode="inpaint"):
        """Predict ``num_of_generation_per_window`` times, un-normalize on CPU and stack on dim 0.

        ``value_diff_thd``, ``value_diff_weight`` and ``mode`` are accepted for
        interface compatibility and are unused.
        """
        self.eval()
        constraint = {'mask': masks.to(self.accelerator.device), 'value': state_true, 'cond': cond}
        state_true = state_true.to(self.accelerator.device)
        state_pred_list = [self.diffusion.predict_samples([state_true], constraint)
                           for _ in range(num_of_generation_per_window)]
        state_pred_list = [self.normalizer.unnormalize(state_pred.detach().cpu()) for state_pred in state_pred_list]
        return torch.stack(state_pred_list)

    def eval(self):
        self.diffusion.eval()

    def train(self):
        self.diffusion.train()

    def prepare(self, objects):
        return self.accelerator.prepare(*objects)


"""
============================ End model.py ============================
"""


"""
============================ Start util.py ============================
"""


def load_diffusion_model(opt):
    """Point ``opt.checkpoint`` at ``<subject_data_path>/GaitDynamicsDiffusion.pt`` and build a MotionModel.

    Returns:
        ``(model, 'diffusion')``. Note that ``opt`` is mutated.
    """
    opt.checkpoint = opt.subject_data_path + '/GaitDynamicsDiffusion.pt'
    model = MotionModel(opt)
    model_key = 'diffusion'
    return model, model_key


class SinusoidalPosEmb(nn.Module):
    """Map a 1-D tensor of positions/timesteps to ``[sin, cos]`` features of width ``2 * (dim // 2)``."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


def convertDfToGRFMot(df, out_path, dt, time_column):
    """Write the foot ground-reaction forces in ``df`` to an OpenSim-style .mot file.

    For each side ('r' then 'l') the columns ``calcn_{side}_force_v{x,y,z}`` and
    ``calcn_{side}_force_normed_cop_{x,y,z}`` are written as
    ``force_{side}_v*`` / ``force_{side}_p*``, followed by three zero torque
    columns. Rows are read positionally; row ``i`` gets time
    ``round(dt * i + time_column[0], 5)``.

    Args:
        df: DataFrame holding the ``calcn_*`` force / CoP columns.
        out_path: destination path (overwritten).
        dt: time step between consecutive rows.
        time_column: sequence whose first element is the start time.
    """
    numFrames = df.shape[0]
    sides = ['r', 'l']
    source_suffixes = ['force_vx', 'force_vy', 'force_vz',
                       'force_normed_cop_x', 'force_normed_cop_y', 'force_normed_cop_z']
    # Header counts every written column: time + 9 per side (3 force, 3 CoP, 3 torque).
    n_columns = 1 + 9 * len(sides)
    # Extract each source column once as a positional array.
    columns_by_side = {side: [df[f'calcn_{side}_{suffix}'].to_numpy() for suffix in source_suffixes]
                       for side in sides}

    with open(out_path, 'w') as out_file:
        out_file.write(f'nColumns={n_columns}\n')
        out_file.write('nRows=' + str(numFrames) + '\n')
        out_file.write('DataType=double\n')
        out_file.write('version=3\n')
        out_file.write('OpenSimVersion=4.1\n')
        out_file.write('endheader\n')

        out_file.write('time')
        for side in sides:
            for label in ('vx', 'vy', 'vz', 'px', 'py', 'pz'):
                out_file.write('\t' + f'force_{side}_{label}')
            for axis in ('x', 'y', 'z'):
                out_file.write('\t' + f'torque_{side}_{axis}')
        out_file.write('\n')

        for i in range(numFrames):
            out_file.write(str(round(dt * i + time_column[0], 5)))
            for side in sides:
                for column in columns_by_side[side]:
                    out_file.write('\t' + str(column[i]))
                # Torques are not estimated; write zeros.
                out_file.write('\t' + str(0))
                out_file.write('\t' + str(0))
                out_file.write('\t' + str(0))
            out_file.write('\n')
    print('Ground reaction forces exported to ' + out_path)


def convertDataframeToMot(df, out_path, dt, time_column):
    numFrames = df.shape[0]
    out_file = open(out_path, 'w')
    out_file.write('Coordinates\n')
    out_file.write('version=1\n')
    out_file.write(f'nRows={numFrames}\n')
    out_file.write(f'nColumns={len(df.columns)+1}\n')
    out_file.write('inDegrees=no\n\n')
    out_file.write('If the header above contains a line with \'inDegrees\', this indicates whether rotational values are in degrees (yes) or radians (no).\n\n')
    out_file.write('endheader\n')
    out_file.write('time')
    for i in range(len(df.columns)):
        out_file.write('\t' + df.columns[i])
    out_file.write('\n')
    for i in range(numFrames):
        out_file.write(str(round(dt * i + time_column[0], 5)))
        for j in range(len(df.columns)):
            out_file.write('\t' + str(df.iloc[i, j]))
out_file.write('\n') out_file.close() print('Missing kinematics exported to ' + out_path) def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: a1, a2 = d6[..., :3], d6[..., 3:] b1 = F.normalize(a1, dim=-1) b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 b2 = F.normalize(b2, dim=-1) b3 = torch.cross(b1, b2, dim=-1) return torch.stack((b1, b2, b3), dim=-2) def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: batch_dim = matrix.size()[:-2] return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) def euler_from_6v(q, convention="XYZ"): assert q.shape[-1] == 6 mat = rotation_6d_to_matrix(q) eul = matrix_to_euler_angles(mat, convention) return eul def euler_to_6v(q, convention="XYZ"): assert q.shape[-1] == 3 mat = euler_angles_to_matrix(q, convention) mat = matrix_to_rotation_6d(mat) return mat def second_order_poly(coeff, x): y = coeff[...,0] * x**2 + coeff[...,1] * x + coeff[...,2] return y def batch_identity(batch_shape, size): batch_identity = torch.eye(size) output_shape = batch_shape.copy() output_shape.append(size) output_shape.append(size) batch_identity_out = batch_identity.view(*(1,) * (len(output_shape) - batch_identity.ndim),*batch_identity.shape).expand(output_shape) return batch_identity_out.clone() def get_knee_rotation_coefficients(): knee_Z_rotation_function = np.array([[0, 0.174533, 0.349066, 0.523599, 0.698132, 0.872665, 1.0472, 1.22173, 1.39626, 1.5708, 1.74533, 1.91986, 2.0944], [0, 0.0126809, 0.0226969, 0.0296054, 0.0332049, 0.0335354, 0.0308779, 0.0257548, 0.0189295, 0.011407, 0.00443314, -0.00050475, -0.0016782]]).T polyfit_knee_Z_rotation = np.polyfit(knee_Z_rotation_function[:,0], knee_Z_rotation_function[:,1], deg=2, full = True) coefficients_knee_Z_rotation = polyfit_knee_Z_rotation[0] knee_Y_rotation_function = np.array([[0, 0.174533, 0.349066, 0.523599, 0.698132, 0.872665, 1.0472, 1.22173, 1.39626, 1.5708, 1.74533, 1.91986, 2.0944], [0, 0.059461, 0.109399, 0.150618, 0.18392, 0.210107, 0.229983, 0.24435, 0.254012, 
0.25977, 0.262428, 0.262788, 0.261654]]).T polyfit_knee_Y_rotation = np.polyfit(knee_Y_rotation_function[:, 0], knee_Y_rotation_function[:, 1], deg=2, full=True) coefficients_knee_Y_rotation = polyfit_knee_Y_rotation[0] knee_X_translation_function = np.array([[0, 0.174533, 0.349066, 0.523599, 0.698132, 0.872665, 1.0472, 1.22173, 1.39626, 1.5708, 1.74533, 1.91986, 2.0944], [0, 5.3e-05, 0.000188, 0.000378, 0.000597, 0.000825, 0.001045, 0.001247, 0.00142, 0.001558, 0.001661, 0.001728, 0.00176]]).T polyfit_knee_X_translation = np.polyfit(knee_X_translation_function[:, 0], knee_X_translation_function[:, 1], deg=2, full=True) coefficients_knee_X_translation = polyfit_knee_X_translation[0] knee_Y_translation_function = np.array([[0, 0.174533, 0.349066, 0.523599, 0.698132, 0.872665, 1.0472, 1.22173, 1.39626, 1.5708, 1.74533, 1.91986, 2.0944], [0, 0.000301, 0.000143, -0.000401, -0.001233, -0.002243, -0.003316, -0.004346, -0.005239, -0.005924, -0.006361, -0.006539, -0.00648]]).T polyfit_knee_Y_translation = np.polyfit(knee_Y_translation_function[:, 0], knee_Y_translation_function[:, 1], deg=2, full=True) coefficients_knee_Y_translation = polyfit_knee_Y_translation[0] knee_Z_translation_function = np.array([[0, 0.174533, 0.349066, 0.523599, 0.698132, 0.872665, 1.0472, 1.22173, 1.39626, 1.5708, 1.74533, 1.91986, 2.0944], [0, 0.001055, 0.002061, 0.00289, 0.003447, 0.003676, 0.003559, 0.00311, 0.002373, 0.001418, 0.000329, -0.000805, -0.001898]]).T polyfit_knee_Z_translation = np.polyfit(knee_Z_translation_function[:, 0], knee_Z_translation_function[:, 1], deg=2, full=True) coefficients_knee_Z_translation = polyfit_knee_Z_translation[0] walker_knee_coefficients = np.stack((coefficients_knee_Y_rotation, coefficients_knee_Z_rotation, coefficients_knee_X_translation, coefficients_knee_Y_translation, coefficients_knee_Z_translation), axis=1) return walker_knee_coefficients walker_knee_coefficients = get_knee_rotation_coefficients() walker_knee_coefficients = 
torch.tensor(walker_knee_coefficients).to(device) def forward_kinematics(pose, offsets, with_arm=False): """ Pose indices 0-5: pelvis orientation + translation 6-8: hip_r 9: knee_r 10: ankle_r 11: subtalar_r 12: mtp_r 13-15: hip_l 16: knee_l 17: ankle_l 18: subtalar_l 19: mtp_l 20-22: lumbar 23-25: shoulder_r 26: elbow_r 27: radioulnar 28-30: shoulder_l 31: elbow_l 32: radioulnar_l """ if isinstance(pose, np.ndarray): pose = torch.from_numpy(pose) pose = pose.to(torch.device(device)) if len(pose.shape) == 2: pose = pose[None, ...] if len(offsets.shape) == 3: offsets = offsets[None, ...] offsets = offsets.to(pose.device) offsets = offsets[:, None, ...] batch_shape = pose.shape[:-1] batch_shape_list = [] for i in range(pose.dim()-1): batch_shape_list.append(int(batch_shape[i])) batch_shape = batch_shape_list if batch_shape == (): batch_shape = (1,) coefficients_knee_Y_rotation = walker_knee_coefficients[..., 0] coefficients_knee_Z_rotation = walker_knee_coefficients[..., 1] coefficients_knee_X_translation = walker_knee_coefficients[..., 2] coefficients_knee_Y_translation = walker_knee_coefficients[..., 3] coefficients_knee_Z_translation = walker_knee_coefficients[..., 4] knee_r_Y_rot = second_order_poly(coefficients_knee_Y_rotation, pose[..., 9]) knee_r_Z_rot = second_order_poly(coefficients_knee_Z_rotation, pose[..., 9]) knee_r_X_trans = second_order_poly(coefficients_knee_X_translation, pose[..., 9]) knee_r_Y_trans = second_order_poly(coefficients_knee_Y_translation, pose[..., 9]) knee_r_Z_trans = second_order_poly(coefficients_knee_Z_translation, pose[..., 9]) knee_l_Y_rot = second_order_poly(coefficients_knee_Y_rotation, pose[..., 16]) knee_l_Z_rot = second_order_poly(coefficients_knee_Z_rotation, pose[..., 16]) knee_l_X_trans = second_order_poly(coefficients_knee_X_translation, pose[..., 16]) knee_l_Y_trans = second_order_poly(coefficients_knee_Y_translation, pose[..., 16]) knee_l_Z_trans = second_order_poly(coefficients_knee_Z_translation, pose[..., 16]) # 
Pelvis pelvis_transform = batch_identity(batch_shape, 4).to(pose.device) pelvis_transform[..., :3, :3] = euler_angles_to_matrix(pose[..., 0:3], 'ZXY') pelvis_transform[..., :3, 3] = pose[..., 3:6].clone().detach() # Get offsets (model and model scaling dependent) offset_hip_pelvis_r = offsets[..., 0] femur_offset_in_hip_r = offsets[..., 1] knee_offset_in_femur_r = offsets[..., 2] tibia_offset_in_knee_r = offsets[..., 3] ankle_offset_in_tibia_r = offsets[..., 4] talus_offset_in_ankle_r = offsets[..., 5] subtalar_offset_in_talus_r = offsets[..., 6] calcaneus_offset_in_subtalar_r = offsets[..., 7] mtp_offset_in_calcaneus_r = offsets[..., 8] offset_hip_pelvis_l = offsets[..., 9] femur_offset_in_hip_l = offsets[..., 10] knee_offset_in_femur_l = offsets[..., 11] tibia_offset_in_knee_l = offsets[..., 12] ankle_offset_in_tibia_l = offsets[..., 13] talus_offset_in_ankle_l = offsets[..., 14] subtalar_offset_in_talus_l = offsets[..., 15] calcaneus_offset_in_subtalar_l = offsets[..., 16] mtp_offset_in_calcaneus_l = offsets[..., 17] lumbar_offset_in_pelvis = offsets[..., 18] torso_offset_in_lumbar = offsets[..., 19] if with_arm: shoulder_offset_in_torso_r = offsets[..., 20] humerus_offset_in_shoulder_r = offsets[..., 21] elbow_offset_in_humerus_r = offsets[..., 22] ulna_offset_in_elbow_r = offsets[..., 23] radioulnar_offset_in_radius_r = offsets[..., 24] radius_offset_in_radioulnar_r = offsets[..., 25] wrist_offset_in_radius_r = offsets[..., 26] hand_offset_in_wrist_r = offsets[..., 27] shoulder_offset_in_torso_l = offsets[..., 28] humerus_offset_in_shoulder_l = offsets[..., 29] elbow_offset_in_humerus_l = offsets[..., 30] ulna_offset_in_elbow_l = offsets[..., 31] radioulnar_offset_in_radius_l = offsets[..., 32] radius_offset_in_radioulnar_l = offsets[..., 33] wrist_offset_in_radius_l = offsets[..., 34] hand_offset_in_wrist_l = offsets[..., 35] # Coordinates to transformation matrix hip_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device) 
knee_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device) ankle_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device) subtalar_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device) mtp_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device) hip_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device) knee_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device) ankle_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device) subtalar_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device) mtp_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device) lumbar_coordinates_transform = batch_identity(batch_shape, 4).to(pose.device) if with_arm: shoulder_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device) elbow_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device) radioulnar_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device) wrist_coordinates_transform_r = batch_identity(batch_shape, 4).to(pose.device) shoulder_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device) elbow_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device) radioulnar_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device) wrist_coordinates_transform_l = batch_identity(batch_shape, 4).to(pose.device) # Knee axis translation knee_coordinates_transform_r[..., :3, -1] = torch.stack((knee_r_X_trans, knee_r_Y_trans, knee_r_Z_trans), dim=-1) knee_coordinates_transform_l[..., :3, -1] = torch.stack((knee_l_X_trans, knee_l_Y_trans, -knee_l_Z_trans), dim=-1) # Joint rotations zero_2_shape = batch_shape.copy() zero_2_shape.append(2) zero_2 = torch.zeros(tuple(zero_2_shape), device=pose.device) zero_3_shape = batch_shape.copy() zero_3_shape.append(3) zero_3 = torch.zeros(tuple(zero_3_shape), device=pose.device) hip_coordinates_transform_r[..., :3, :3] = 
euler_angles_to_matrix((pose[..., 6:9]), 'ZXY') knee_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix( torch.stack((pose[..., 9], knee_r_Y_rot, knee_r_Z_rot), dim=-1), 'XYZ') ankle_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 10:11], zero_2), dim=-1), 'ZXY') subtalar_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 11:12], zero_2), dim=-1), 'ZXY') mtp_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 12:13], zero_2), dim=-1), 'ZXY') hip_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix( torch.cat(([pose[..., 13:14], -pose[..., 14:15], -pose[..., 15:16]]), dim=-1), 'ZXY') knee_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix( torch.stack((-pose[..., 16], -knee_l_Y_rot, knee_l_Z_rot), dim=-1), 'XYZ') ankle_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 17:18], zero_2), dim=-1), 'ZXY') subtalar_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 18:19], zero_2), dim=-1), 'ZXY') mtp_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 19:20], zero_2), dim=-1), 'ZXY') lumbar_coordinates_transform[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 20:21], pose[..., 21:22], pose[..., 22:23]), dim=-1), 'ZXY') if with_arm: shoulder_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 23:24], pose[..., 24:25], pose[..., 25:26]), dim=-1), 'ZXY') elbow_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 26:27], zero_2), dim=-1), 'ZXY') radioulnar_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 27:28], zero_2), dim=-1), 'ZXY') wrist_coordinates_transform_r[..., :3, :3] = euler_angles_to_matrix( zero_3, 'ZXY') shoulder_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 28:29], -pose[..., 29:30], 
-pose[..., 30:31]), dim=-1), 'ZXY') elbow_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 31:32], zero_2), dim=-1), 'ZXY') radioulnar_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix( torch.cat((pose[..., 32:33], zero_2), dim=-1), 'ZXY') wrist_coordinates_transform_l[..., :3, :3] = euler_angles_to_matrix( zero_3, 'ZXY') # Forward kinematics for the lower body hip_transform_r = torch.matmul(torch.matmul(pelvis_transform, offset_hip_pelvis_r), hip_coordinates_transform_r) femur_transform_r = torch.matmul(hip_transform_r, femur_offset_in_hip_r) knee_transform_r = torch.matmul(torch.matmul(femur_transform_r, knee_offset_in_femur_r), knee_coordinates_transform_r) tibia_transform_r = torch.matmul(knee_transform_r, tibia_offset_in_knee_r) ankle_transform_r = torch.matmul(torch.matmul(tibia_transform_r, ankle_offset_in_tibia_r), ankle_coordinates_transform_r) talus_transform_r = torch.matmul(ankle_transform_r, talus_offset_in_ankle_r) subtalar_transform_r = torch.matmul(torch.matmul(talus_transform_r, subtalar_offset_in_talus_r), subtalar_coordinates_transform_r) calcaneus_transform_r = torch.matmul(subtalar_transform_r, calcaneus_offset_in_subtalar_r) mtp_offset_transform_r = torch.matmul(torch.matmul(calcaneus_transform_r, mtp_offset_in_calcaneus_r), mtp_coordinates_transform_r) hip_transform_l = torch.matmul(torch.matmul(pelvis_transform, offset_hip_pelvis_l), hip_coordinates_transform_l) femur_transform_l = torch.matmul(hip_transform_l, femur_offset_in_hip_l) knee_transform_l = torch.matmul(torch.matmul(femur_transform_l, knee_offset_in_femur_l), knee_coordinates_transform_l) tibia_transform_l = torch.matmul(knee_transform_l, tibia_offset_in_knee_l) ankle_transform_l = torch.matmul(torch.matmul(tibia_transform_l, ankle_offset_in_tibia_l), ankle_coordinates_transform_l) talus_transform_l = torch.matmul(ankle_transform_l, talus_offset_in_ankle_l) subtalar_transform_l = torch.matmul(torch.matmul(talus_transform_l, 
subtalar_offset_in_talus_l), subtalar_coordinates_transform_l) calcaneus_transform_l = torch.matmul(subtalar_transform_l, calcaneus_offset_in_subtalar_l) mtp_offset_transform_l = torch.matmul(torch.matmul(calcaneus_transform_l, mtp_offset_in_calcaneus_l), mtp_coordinates_transform_l) # Forward kinematics for the upper body lumbar_transform = torch.matmul(torch.matmul(pelvis_transform, lumbar_offset_in_pelvis), lumbar_coordinates_transform) torso_transform = torch.matmul(lumbar_transform, torso_offset_in_lumbar) if with_arm: shoulder_transform_r = torch.matmul(torch.matmul(torso_transform, shoulder_offset_in_torso_r), shoulder_coordinates_transform_r) humerus_transform_r = torch.matmul(shoulder_transform_r, humerus_offset_in_shoulder_r) elbow_transform_r = torch.matmul(torch.matmul(humerus_transform_r, elbow_offset_in_humerus_r), elbow_coordinates_transform_r) ulna_transform_r = torch.matmul(elbow_transform_r, ulna_offset_in_elbow_r) radioulnar_transform_r = torch.matmul(torch.matmul(ulna_transform_r, radioulnar_offset_in_radius_r), radioulnar_coordinates_transform_r) radius_transform_r = torch.matmul(radioulnar_transform_r, radius_offset_in_radioulnar_r) wrist_transform_r = torch.matmul(torch.matmul(ulna_transform_r, wrist_offset_in_radius_r), wrist_coordinates_transform_r) hand_transform_r = torch.matmul(wrist_transform_r, hand_offset_in_wrist_r) shoulder_transform_l = torch.matmul(torch.matmul(torso_transform, shoulder_offset_in_torso_l), shoulder_coordinates_transform_l) humerus_transform_l = torch.matmul(shoulder_transform_l, humerus_offset_in_shoulder_l) elbow_transform_l = torch.matmul(torch.matmul(humerus_transform_l, elbow_offset_in_humerus_l), elbow_coordinates_transform_l) ulna_transform_l = torch.matmul(elbow_transform_l, ulna_offset_in_elbow_l) radioulnar_transform_l = torch.matmul(torch.matmul(ulna_transform_l, radioulnar_offset_in_radius_l), radioulnar_coordinates_transform_l) radius_transform_l = torch.matmul(radioulnar_transform_l, 
radius_offset_in_radioulnar_l) wrist_transform_l = torch.matmul(torch.matmul(ulna_transform_l, wrist_offset_in_radius_l), wrist_coordinates_transform_l) hand_transform_l = torch.matmul(wrist_transform_l, hand_offset_in_wrist_l) joint_locations = torch.stack((pelvis_transform[..., :3, 3], hip_transform_r[..., :3, 3], knee_transform_r[..., :3, 3], ankle_transform_r[..., :3, 3], calcaneus_transform_r[..., :3, 3], mtp_offset_transform_r[..., :3, 3], hip_transform_l[..., :3, 3], knee_transform_l[..., :3, 3], ankle_transform_l[..., :3, 3], calcaneus_transform_l[..., :3, 3], mtp_offset_transform_l[..., :3, 3])) if with_arm: joint_locations = torch.stack((*[joint_locations[i] for i in range(joint_locations.shape[0])], lumbar_transform[..., :3, 3], shoulder_transform_r[..., :3, 3], elbow_transform_r[..., :3, 3], wrist_transform_r[..., :3, 3], shoulder_transform_l[..., :3, 3], elbow_transform_l[..., :3, 3], wrist_transform_l[..., :3, 3])) if torch.isnan(joint_locations).any(): print('NAN in joint locations') foot_locations = torch.stack((calcaneus_transform_r[..., :3, 3], mtp_offset_transform_r[..., :3, 3], calcaneus_transform_l[..., :3, 3], mtp_offset_transform_l[..., :3, 3])) segment_orientations = torch.stack((pelvis_transform[..., :3, :3], femur_transform_r[..., :3, :3], tibia_transform_r[..., :3, :3], talus_transform_r[..., :3, :3], calcaneus_transform_r[..., :3, :3], femur_transform_l[..., :3, :3], tibia_transform_l[..., :3, :3], talus_transform_l[..., :3, :3], calcaneus_transform_l[..., :3, :3], torso_transform[..., :3, :3])) if with_arm: segment_orientations = torch.stack((*[segment_orientations[i] for i in range(segment_orientations.shape[0])], humerus_transform_r[..., :3, :3], ulna_transform_r[..., :3, :3], radius_transform_r[..., :3, :3], humerus_transform_l[..., :3, :3], ulna_transform_l[..., :3, :3], radius_transform_l[..., :3, :3])) joint_names = ['pelvis', 'hip_r', 'knee_r', 'ankle_r', 'calcn_r', 'mtp_r', 'hip_l', 'knee_l', 'ankle_l', 'calcn_l', 'mtp_l'] 
return foot_locations, joint_locations, joint_names, segment_orientations def get_model_offsets(skeleton, with_arm=False): pelvis = skeleton.getBodyNode(0) hip_r_joint = pelvis.getChildJoint(0) femur_r = hip_r_joint.getChildBodyNode() knee_r_joint = femur_r.getChildJoint(0) tibia_r = knee_r_joint.getChildBodyNode() ankle_r_joint = tibia_r.getChildJoint(0) talus_r = ankle_r_joint.getChildBodyNode() subtalar_r_joint = talus_r.getChildJoint(0) calcn_r = subtalar_r_joint.getChildBodyNode() mtp_r_joint = calcn_r.getChildJoint(0) hip_l_joint = pelvis.getChildJoint(1) femur_l = hip_l_joint.getChildBodyNode() knee_l_joint = femur_l.getChildJoint(0) tibia_l = knee_l_joint.getChildBodyNode() ankle_l_joint = tibia_l.getChildJoint(0) talus_l = ankle_l_joint.getChildBodyNode() subtalar_l_joint = talus_l.getChildJoint(0) calcn_l = subtalar_l_joint.getChildBodyNode() mtp_l_joint = calcn_l.getChildJoint(0) lumbar_joint = pelvis.getChildJoint(2) torso = lumbar_joint.getChildBodyNode() if with_arm: shoulder_r_joint = torso.getChildJoint(0) humerus_r = shoulder_r_joint.getChildBodyNode() elbow_r_joint = humerus_r.getChildJoint(0) ulna_r = elbow_r_joint.getChildBodyNode() radioulnar_r_joint = ulna_r.getChildJoint(0) radius_r = radioulnar_r_joint.getChildBodyNode() wrist_r_joint = radius_r.getChildJoint(0) hand_r = wrist_r_joint.getChildBodyNode() shoulder_l_joint = torso.getChildJoint(1) humerus_l = shoulder_l_joint.getChildBodyNode() elbow_l_joint = humerus_l.getChildJoint(0) ulna_l = elbow_l_joint.getChildBodyNode() radioulnar_l_joint = ulna_l.getChildJoint(0) radius_l = radioulnar_l_joint.getChildBodyNode() wrist_l_joint = radius_l.getChildJoint(0) hand_l = wrist_l_joint.getChildBodyNode() # hip offset hip_offset_r = torch.eye(4) hip_offset_r[:3, :3] = torch.tensor(hip_r_joint.getTransformFromParentBodyNode().rotation()) hip_offset_r[:3, 3] = torch.tensor(hip_r_joint.getTransformFromParentBodyNode().translation()) hip_offset_l = torch.eye(4) hip_offset_l[:3, :3] = 
torch.tensor(hip_l_joint.getTransformFromParentBodyNode().rotation()) hip_offset_l[:3, 3] = torch.tensor(hip_l_joint.getTransformFromParentBodyNode().translation()) # femur offset femur_offset_to_knee_in_femur_r = -torch.tensor(hip_r_joint.getTransformFromChildBodyNode().translation()) femur_rotation_to_knee_in_femur_r = torch.inverse(torch.tensor(hip_r_joint.getTransformFromChildBodyNode().rotation())) femur_offset_rotation_r = torch.eye(4) femur_offset_rotation_r[:3, :3] = femur_rotation_to_knee_in_femur_r femur_offset_translation_r = torch.eye(4) femur_offset_translation_r[:3, 3] = femur_offset_to_knee_in_femur_r femur_offset_r = torch.matmul(femur_offset_rotation_r, femur_offset_translation_r) femur_offset_to_knee_in_femur_l = -torch.tensor(hip_l_joint.getTransformFromChildBodyNode().translation()) femur_rotation_to_knee_in_femur_l = torch.inverse(torch.tensor(hip_l_joint.getTransformFromChildBodyNode().rotation())) femur_offset_rotation_l = torch.eye(4) femur_offset_rotation_l[:3, :3] = femur_rotation_to_knee_in_femur_l femur_offset_translation_l = torch.eye(4) femur_offset_translation_l[:3, 3] = femur_offset_to_knee_in_femur_l femur_offset_l = torch.matmul(femur_offset_rotation_l, femur_offset_translation_l) # knee offset knee_offset_r = torch.eye(4) knee_offset_r[:3, :3] = torch.tensor(knee_r_joint.getTransformFromParentBodyNode().rotation()) knee_offset_r[:3, 3] = torch.tensor(knee_r_joint.getTransformFromParentBodyNode().translation()) knee_offset_l = torch.eye(4) knee_offset_l[:3, :3] = torch.tensor(knee_l_joint.getTransformFromParentBodyNode().rotation()) knee_offset_l[:3, 3] = torch.tensor(knee_l_joint.getTransformFromParentBodyNode().translation()) # tibia offset tibia_offset_to_knee_in_tibia_r = -torch.tensor(knee_r_joint.getTransformFromChildBodyNode().translation()) tibia_rotation_to_knee_in_tibia_r = torch.inverse(torch.tensor(knee_r_joint.getTransformFromChildBodyNode().rotation())) tibia_offset_rotation_r = torch.eye(4) 
tibia_offset_rotation_r[:3, :3] = tibia_rotation_to_knee_in_tibia_r tibia_offset_translation_r = torch.eye(4) tibia_offset_translation_r[:3, 3] = tibia_offset_to_knee_in_tibia_r tibia_offset_r = torch.matmul(tibia_offset_rotation_r, tibia_offset_translation_r) tibia_offset_to_knee_in_tibia_l = -torch.tensor(knee_l_joint.getTransformFromChildBodyNode().translation()) tibia_rotation_to_knee_in_tibia_l = torch.inverse(torch.tensor(knee_l_joint.getTransformFromChildBodyNode().rotation())) tibia_offset_rotation_l = torch.eye(4) tibia_offset_rotation_l[:3, :3] = tibia_rotation_to_knee_in_tibia_l tibia_offset_translation_l = torch.eye(4) tibia_offset_translation_l[:3, 3] = tibia_offset_to_knee_in_tibia_l tibia_offset_l = torch.matmul(tibia_offset_rotation_l, tibia_offset_translation_l) # ankle offset ankle_offset_r = torch.eye(4) ankle_offset_r[:3,:3] = torch.tensor(ankle_r_joint.getTransformFromParentBodyNode().rotation()) ankle_offset_r[:3, 3] = torch.tensor(ankle_r_joint.getTransformFromParentBodyNode().translation()) ankle_offset_l = torch.eye(4) ankle_offset_l[:3, :3] = torch.tensor(ankle_l_joint.getTransformFromParentBodyNode().rotation()) ankle_offset_l[:3, 3] = torch.tensor(ankle_l_joint.getTransformFromParentBodyNode().translation()) # talus offset talus_offset_to_ankle_in_talus_r = -torch.tensor(ankle_r_joint.getTransformFromChildBodyNode().translation()) talus_rotation_to_ankle_in_talus_r = torch.inverse(torch.tensor(ankle_r_joint.getTransformFromChildBodyNode().rotation())) talus_offset_rotation_r = torch.eye(4) talus_offset_rotation_r[:3, :3] = talus_rotation_to_ankle_in_talus_r talus_offset_translation_r = torch.eye(4) talus_offset_translation_r[:3, 3] = talus_offset_to_ankle_in_talus_r talus_offset_r = torch.matmul(talus_offset_rotation_r, talus_offset_translation_r) talus_offset_to_ankle_in_talus_l = -torch.tensor(ankle_l_joint.getTransformFromChildBodyNode().translation()) talus_rotation_to_ankle_in_talus_l = 
torch.inverse(torch.tensor(ankle_l_joint.getTransformFromChildBodyNode().rotation())) talus_offset_rotation_l = torch.eye(4) talus_offset_rotation_l[:3, :3] = talus_rotation_to_ankle_in_talus_l talus_offset_translation_l = torch.eye(4) talus_offset_translation_l[:3, 3] = talus_offset_to_ankle_in_talus_l talus_offset_l = torch.matmul(talus_offset_rotation_l, talus_offset_translation_l) # subtalar offset subtalar_offset_r = torch.eye(4) subtalar_offset_r[:3,:3] = torch.tensor(subtalar_r_joint.getTransformFromParentBodyNode().rotation()) subtalar_offset_r[:3, 3] = torch.tensor(subtalar_r_joint.getTransformFromParentBodyNode().translation()) subtalar_offset_l = torch.eye(4) subtalar_offset_l[:3, :3] = torch.tensor(subtalar_l_joint.getTransformFromParentBodyNode().rotation()) subtalar_offset_l[:3, 3] = torch.tensor(subtalar_l_joint.getTransformFromParentBodyNode().translation()) # calcaneus offset calcaneus_offset_to_subtalar_in_calcaneus_r = -torch.tensor(subtalar_r_joint.getTransformFromChildBodyNode().translation()) calcaneus_rotation_to_subtalar_in_calcaneus_r = torch.inverse(torch.tensor(subtalar_r_joint.getTransformFromChildBodyNode().rotation())) calcaneus_offset_rotation_r = torch.eye(4) calcaneus_offset_rotation_r[:3, :3] = calcaneus_rotation_to_subtalar_in_calcaneus_r calcaneus_offset_translation_r = torch.eye(4) calcaneus_offset_translation_r[:3, 3] = calcaneus_offset_to_subtalar_in_calcaneus_r calcaneus_offset_r = torch.matmul(calcaneus_offset_rotation_r, calcaneus_offset_translation_r) calcaneus_offset_to_subtalar_in_calcaneus_l = -torch.tensor(subtalar_l_joint.getTransformFromChildBodyNode().translation()) calcaneus_rotation_to_subtalar_in_calcaneus_l = torch.inverse(torch.tensor(subtalar_l_joint.getTransformFromChildBodyNode().rotation())) calcaneus_offset_rotation_l = torch.eye(4) calcaneus_offset_rotation_l[:3, :3] = calcaneus_rotation_to_subtalar_in_calcaneus_l calcaneus_offset_translation_l = torch.eye(4) calcaneus_offset_translation_l[:3, 3] = 
calcaneus_offset_to_subtalar_in_calcaneus_l calcaneus_offset_l = torch.matmul(calcaneus_offset_rotation_l, calcaneus_offset_translation_l) # mtp offset mtp_offset_r = torch.eye(4) mtp_offset_r[:3, :3] = torch.tensor(mtp_r_joint.getTransformFromParentBodyNode().rotation()) mtp_offset_r[:3, 3] = torch.tensor(mtp_r_joint.getTransformFromParentBodyNode().translation()) mtp_offset_l = torch.eye(4) mtp_offset_l[:3, :3] = torch.tensor(mtp_l_joint.getTransformFromParentBodyNode().rotation()) mtp_offset_l[:3, 3] = torch.tensor(mtp_l_joint.getTransformFromParentBodyNode().translation()) # toes offset toes_offset_to_mtp_in_toes_r = -torch.tensor(mtp_r_joint.getTransformFromChildBodyNode().translation()) toes_rotation_to_mtp_in_toes_r = torch.inverse(torch.tensor(mtp_r_joint.getTransformFromChildBodyNode().rotation())) toes_offset_rotation_r = torch.eye(4) toes_offset_rotation_r[:3, :3] = toes_rotation_to_mtp_in_toes_r toes_offset_translation_r = torch.eye(4) toes_offset_translation_r[:3, 3] = toes_offset_to_mtp_in_toes_r toes_offset_r = torch.matmul(toes_offset_rotation_r, toes_offset_translation_r) toes_offset_to_mtp_in_toes_l = -torch.tensor(mtp_l_joint.getTransformFromChildBodyNode().translation()) toes_rotation_to_mtp_in_toes_l = torch.inverse(torch.tensor(mtp_l_joint.getTransformFromChildBodyNode().rotation())) toes_offset_rotation_l = torch.eye(4) toes_offset_rotation_l[:3, :3] = toes_rotation_to_mtp_in_toes_l toes_offset_translation_l = torch.eye(4) toes_offset_translation_l[:3, 3] = toes_offset_to_mtp_in_toes_l toes_offset_l = torch.matmul(toes_offset_rotation_l, toes_offset_translation_l) # lumbar offset lumbar_offset = torch.eye(4) lumbar_offset[:3, :3] = torch.tensor(lumbar_joint.getTransformFromParentBodyNode().rotation()) lumbar_offset[:3, 3] = torch.tensor(lumbar_joint.getTransformFromParentBodyNode().translation()) # torso offset torso_offset_to_lumbar_in_torso = -torch.tensor(lumbar_joint.getTransformFromChildBodyNode().translation()) 
torso_offset_rotation_to_lumbar_in_torso = torch.inverse(torch.tensor(lumbar_joint.getTransformFromChildBodyNode().rotation())) torso_offset_rotation = torch.eye(4) torso_offset_rotation[:3, :3] = torso_offset_rotation_to_lumbar_in_torso torso_offset_translation = torch.eye(4) torso_offset_translation[:3, 3] = torso_offset_to_lumbar_in_torso torso_offset = torch.matmul(torso_offset_rotation, torso_offset_translation) if with_arm: # shoulder offset shoulder_offset_r = torch.eye(4) shoulder_offset_r[:3, :3] = torch.tensor(shoulder_r_joint.getTransformFromParentBodyNode().rotation()) shoulder_offset_r[:3, 3] = torch.tensor(shoulder_r_joint.getTransformFromParentBodyNode().translation()) shoulder_offset_l = torch.eye(4) shoulder_offset_l[:3, :3] = torch.tensor(shoulder_l_joint.getTransformFromParentBodyNode().rotation()) shoulder_offset_l[:3, 3] = torch.tensor(shoulder_l_joint.getTransformFromParentBodyNode().translation()) # humerus offset humerus_offset_to_shoulder_in_humerus_r = -torch.tensor(shoulder_r_joint.getTransformFromChildBodyNode().translation()) humerus_offset_rotation_to_shoulder_in_humerus_r = torch.inverse(torch.tensor(shoulder_r_joint.getTransformFromChildBodyNode().rotation())) humerus_offset_rotation_r = torch.eye(4) humerus_offset_rotation_r[:3, :3] = humerus_offset_rotation_to_shoulder_in_humerus_r humerus_offset_translation_r = torch.eye(4) humerus_offset_translation_r[:3, 3] = humerus_offset_to_shoulder_in_humerus_r humerus_offset_r = torch.matmul(humerus_offset_rotation_r, humerus_offset_translation_r) humerus_offset_to_shoulder_in_humerus_l = -torch.tensor(shoulder_l_joint.getTransformFromChildBodyNode().translation()) humerus_offset_rotation_to_shoulder_in_humerus_l = torch.inverse(torch.tensor(shoulder_l_joint.getTransformFromChildBodyNode().rotation())) humerus_offset_rotation_l = torch.eye(4) humerus_offset_rotation_l[:3, :3] = humerus_offset_rotation_to_shoulder_in_humerus_l humerus_offset_translation_l = torch.eye(4) 
humerus_offset_translation_l[:3, 3] = humerus_offset_to_shoulder_in_humerus_l humerus_offset_l = torch.matmul(humerus_offset_rotation_l, humerus_offset_translation_l) # elbow offset elbow_offset_r = torch.eye(4) elbow_offset_r[:3, :3] = torch.tensor(elbow_r_joint.getTransformFromParentBodyNode().rotation()) elbow_offset_r[:3, 3] = torch.tensor(elbow_r_joint.getTransformFromParentBodyNode().translation()) elbow_offset_l = torch.eye(4) elbow_offset_l[:3, :3] = torch.tensor(elbow_l_joint.getTransformFromParentBodyNode().rotation()) elbow_offset_l[:3, 3] = torch.tensor(elbow_l_joint.getTransformFromParentBodyNode().translation()) # ulna offset ulna_offset_to_elbow_in_ulna_r = -torch.tensor(elbow_r_joint.getTransformFromChildBodyNode().translation()) ulna_offset_rotation_to_elbow_in_ulna_r = torch.inverse(torch.tensor(elbow_r_joint.getTransformFromChildBodyNode().rotation())) ulna_offset_rotation_r = torch.eye(4) ulna_offset_rotation_r[:3, :3] = ulna_offset_rotation_to_elbow_in_ulna_r ulna_offset_translation_r = torch.eye(4) ulna_offset_translation_r[:3, 3] = ulna_offset_to_elbow_in_ulna_r ulna_offset_r = torch.matmul(ulna_offset_rotation_r, ulna_offset_translation_r) ulna_offset_to_elbow_in_ulna_l = -torch.tensor(elbow_l_joint.getTransformFromChildBodyNode().translation()) ulna_offset_rotation_to_elbow_in_ulna_l = torch.inverse(torch.tensor(elbow_l_joint.getTransformFromChildBodyNode().rotation())) ulna_offset_rotation_l = torch.eye(4) ulna_offset_rotation_l[:3, :3] = ulna_offset_rotation_to_elbow_in_ulna_l ulna_offset_translation_l = torch.eye(4) ulna_offset_translation_l[:3, 3] = ulna_offset_to_elbow_in_ulna_l ulna_offset_l = torch.matmul(ulna_offset_rotation_l, ulna_offset_translation_l) # radioulnar offset radioulnar_offset_r = torch.eye(4) radioulnar_offset_r[:3, :3] = torch.tensor(radioulnar_r_joint.getTransformFromParentBodyNode().rotation()) radioulnar_offset_r[:3, 3] = torch.tensor(radioulnar_r_joint.getTransformFromParentBodyNode().translation()) 
radioulnar_offset_l = torch.eye(4) radioulnar_offset_l[:3, :3] = torch.tensor(radioulnar_l_joint.getTransformFromParentBodyNode().rotation()) radioulnar_offset_l[:3, 3] = torch.tensor(radioulnar_l_joint.getTransformFromParentBodyNode().translation()) # radius offset radius_offset_to_radioulnar_in_radius_r = -torch.tensor(radioulnar_r_joint.getTransformFromChildBodyNode().translation()) radius_offset_rotation_to_radioulnar_in_radius_r = torch.inverse(torch.tensor(radioulnar_r_joint.getTransformFromChildBodyNode().rotation())) radius_offset_rotation_r = torch.eye(4) radius_offset_rotation_r[:3, :3] = radius_offset_rotation_to_radioulnar_in_radius_r radius_offset_translation_r = torch.eye(4) radius_offset_translation_r[:3, 3] = radius_offset_to_radioulnar_in_radius_r radius_offset_r = torch.matmul(radius_offset_rotation_r, radius_offset_translation_r) radius_offset_to_radioulnar_in_radius_l = -torch.tensor(radioulnar_l_joint.getTransformFromChildBodyNode().translation()) radius_offset_rotation_to_radioulnar_in_radius_l = torch.inverse(torch.tensor(radioulnar_l_joint.getTransformFromChildBodyNode().rotation())) radius_offset_rotation_l = torch.eye(4) radius_offset_rotation_l[:3, :3] = radius_offset_rotation_to_radioulnar_in_radius_l radius_offset_translation_l = torch.eye(4) radius_offset_translation_l[:3, 3] = radius_offset_to_radioulnar_in_radius_l radius_offset_l = torch.matmul(radius_offset_rotation_l, radius_offset_translation_l) # wrist offset wrist_offset_r = torch.eye(4) wrist_offset_r[:3, :3] = torch.tensor(wrist_r_joint.getTransformFromParentBodyNode().rotation()) wrist_offset_r[:3, 3] = torch.tensor(wrist_r_joint.getTransformFromParentBodyNode().translation()) wrist_offset_l = torch.eye(4) wrist_offset_l[:3, :3] = torch.tensor(wrist_l_joint.getTransformFromParentBodyNode().rotation()) wrist_offset_l[:3, 3] = torch.tensor(wrist_l_joint.getTransformFromParentBodyNode().translation()) # hand offset hand_offset_to_wrist_in_hand_r = 
-torch.tensor(wrist_r_joint.getTransformFromChildBodyNode().translation()) hand_offset_rotation_to_wrist_in_hand_r = torch.inverse(torch.tensor(wrist_r_joint.getTransformFromChildBodyNode().rotation())) hand_offset_rotation_r = torch.eye(4) hand_offset_rotation_r[:3, :3] = hand_offset_rotation_to_wrist_in_hand_r hand_offset_translation_r = torch.eye(4) hand_offset_translation_r[:3, 3] = hand_offset_to_wrist_in_hand_r hand_offset_r = torch.matmul(hand_offset_rotation_r, hand_offset_translation_r) hand_offset_to_wrist_in_hand_l = -torch.tensor(wrist_l_joint.getTransformFromChildBodyNode().translation()) hand_offset_rotation_to_wrist_in_hand_l = torch.inverse(torch.tensor(wrist_l_joint.getTransformFromChildBodyNode().rotation())) hand_offset_rotation_l = torch.eye(4) hand_offset_rotation_l[:3, :3] = hand_offset_rotation_to_wrist_in_hand_l hand_offset_translation_l = torch.eye(4) hand_offset_translation_l[:3, 3] = hand_offset_to_wrist_in_hand_l hand_offset_l = torch.matmul(hand_offset_rotation_l, hand_offset_translation_l) offsets = torch.stack((hip_offset_r, femur_offset_r, knee_offset_r, tibia_offset_r, ankle_offset_r, talus_offset_r, subtalar_offset_r, calcaneus_offset_r, mtp_offset_r, hip_offset_l, femur_offset_l, knee_offset_l, tibia_offset_l, ankle_offset_l, talus_offset_l, subtalar_offset_l, calcaneus_offset_l, mtp_offset_l, lumbar_offset, torso_offset), dim=2) if with_arm: offsets = torch.stack((*[offsets[i] for i in range(offsets.shape[0])], shoulder_offset_r, humerus_offset_r, elbow_offset_r, ulna_offset_r, radioulnar_offset_r, radius_offset_r, wrist_offset_r, hand_offset_r, shoulder_offset_l, humerus_offset_l, elbow_offset_l, ulna_offset_l, radioulnar_offset_l, radius_offset_l, wrist_offset_l, hand_offset_l,toes_offset_r, toes_offset_l), dim=2) return offsets def euler_to_angular_velocity(q, sampling_fre, convention="XYZ"): assert q.shape[-1] == 3 mat = euler_angles_to_matrix(q, convention) mat_diff = mat.clone() mat_diff[:-1, :, :] = mat_diff[:-1, :, :] - 
mat_diff[1:, :, :] mat_diff[-1, :, :] = mat_diff[-2, :, :] angular_velocity = torch.stack([mat_diff[:, 2, 1] - mat_diff[:, 1, 2], mat_diff[:, 0, 2] - mat_diff[:, 2, 0], mat_diff[:, 1, 0] - mat_diff[:, 0, 1]], dim=1) * sampling_fre * 0.5 return angular_velocity def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: raise ValueError("Invalid input euler angles.") if len(convention) != 3: raise ValueError("Convention must have 3 letters.") if convention[1] in (convention[0], convention[2]): raise ValueError(f"Invalid convention {convention}.") for letter in convention: if letter not in ("X", "Y", "Z"): raise ValueError(f"Invalid letter {letter} in convention string.") matrices = [ _axis_angle_rotation(c, e) for c, e in zip(convention, torch.unbind(euler_angles, -1)) ] return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: cos = torch.cos(angle) sin = torch.sin(angle) one = torch.ones_like(angle) zero = torch.zeros_like(angle) if axis == "X": R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) elif axis == "Y": R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) elif axis == "Z": R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) else: raise ValueError("letter must be either X, Y or Z.") return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: if len(convention) != 3: raise ValueError("Convention must have 3 letters.") if convention[1] in (convention[0], convention[2]): raise ValueError(f"Invalid convention {convention}.") for letter in convention: if letter not in ("X", "Y", "Z"): raise ValueError(f"Invalid letter {letter} in convention string.") if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") i0 = 
_index_from_letter(convention[0]) i2 = _index_from_letter(convention[2]) tait_bryan = i0 != i2 if tait_bryan: central_angle = torch.asin( matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) ) else: central_angle = torch.acos(matrix[..., i0, i0]) o = ( _angle_from_tan( convention[0], convention[1], matrix[..., i2], False, tait_bryan ), central_angle, _angle_from_tan( convention[2], convention[1], matrix[..., i0, :], True, tait_bryan ), ) return torch.stack(o, -1) def _angle_from_tan( axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool ) -> torch.Tensor: i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] if horizontal: i2, i1 = i1, i2 even = (axis + other_axis) in ["XY", "YZ", "ZX"] if horizontal == even: return torch.atan2(data[..., i1], data[..., i2]) if tait_bryan: return torch.atan2(-data[..., i2], data[..., i1]) return torch.atan2(data[..., i2], -data[..., i1]) def _index_from_letter(letter: str) -> int: if letter == "X": return 0 if letter == "Y": return 1 if letter == "Z": return 2 raise ValueError("letter must be either X, Y or Z.") def get_multi_body_loc_using_nimble_by_body_names(body_names, skel, poses): body_ids = [skel.getBodyNode(name) for name in body_names] return get_multi_body_loc_using_nimble_by_body_nodes(body_ids, skel, poses) def get_multi_body_loc_using_nimble_by_body_nodes(body_nodes, skel, poses): body_loc = [] for i_frame in range(len(poses)): skel.setPositions(poses[i_frame]) body_loc.append(np.concatenate([body_node.getWorldTransform().translation() for body_node in body_nodes])) body_loc = np.array(body_loc) return body_loc def inverse_norm_cops(skel, states, opt, sub_mass, height_m, grf_thd_to_zero_cop=20): poses = states[:, opt.kinematic_osim_col_loc] forces = states[:, opt.grf_osim_col_loc] normed_cops = states[:, opt.cop_osim_col_loc] if len(skel.getDofs()) != poses.shape[1]: # With Arm model print('With Arm model is used. 
Adding 6 zeros to the end of the poses.') poses = np.concatenate([poses, np.zeros((poses.shape[0], 10))], axis=-1) foot_loc = get_multi_body_loc_using_nimble_by_body_names(('calcn_r', 'calcn_l'), skel, poses) for i_plate in range(2): force_v = forces[:, 3*i_plate:3*(i_plate+1)] force_v[force_v == 0] = 1e-6 vector = normed_cops[:, 3 * i_plate:3 * (i_plate + 1)] / force_v[:, 1:2] * height_m vector = np.nan_to_num(vector, posinf=0, neginf=0) # vector.clip(min=-0.4, max=0.4, out=vector) # CoP should be within 0.4 m from the foot cops = vector + foot_loc[:, 3*i_plate:3*(i_plate+1)] if grf_thd_to_zero_cop: cops[force_v[:, 1] * sub_mass < grf_thd_to_zero_cop] = 0 if isinstance(states, torch.Tensor): cops = torch.from_numpy(cops).to(states.dtype) else: cops = cops.astype(states.dtype) states[:, opt.cop_osim_col_loc[3*i_plate:3*(i_plate+1)]] = cops return states def inverse_align_moving_direction(results_pred, column_names, rot_mat): if isinstance(results_pred, np.ndarray): poses = torch.from_numpy(results_pred) results_pred_clone = poses.clone().float() pelvis_orientation_col_loc = [column_names.index(col) for col in JOINTS_3D_ALL['pelvis']] p_pos_col_loc = [column_names.index(col) for col in [f'pelvis_t{x}' for x in ['x', 'y', 'z']]] r_grf_col_loc = [column_names.index(col) for col in ['calcn_r_force_vx', 'calcn_r_force_vy', 'calcn_r_force_vz']] l_grf_col_loc = [column_names.index(col) for col in ['calcn_l_force_vx', 'calcn_l_force_vy', 'calcn_l_force_vz']] r_cop_col_loc = [column_names.index(col) for col in [f'calcn_r_force_normed_cop_{x}' for x in ['x', 'y', 'z']]] l_cop_col_loc = [column_names.index(col) for col in [f'calcn_l_force_normed_cop_{x}' for x in ['x', 'y', 'z']]] if len(pelvis_orientation_col_loc) != 3 or len(p_pos_col_loc) != 3 or len(r_grf_col_loc) != 3 or len( l_grf_col_loc) != 3: raise ValueError('check column names') pelvis_orientation = results_pred_clone[:, pelvis_orientation_col_loc] pelvis_orientation = euler_angles_to_matrix(pelvis_orientation, 
"ZXY") p_pos = results_pred_clone[:, p_pos_col_loc] r_grf = results_pred_clone[:, r_grf_col_loc] l_grf = results_pred_clone[:, l_grf_col_loc] r_cop = results_pred_clone[:, r_cop_col_loc] l_cop = results_pred_clone[:, l_cop_col_loc] rot_mat = rot_mat.T pelvis_orientation_rotated = torch.matmul(rot_mat, pelvis_orientation) p_pos_rotated = torch.matmul(rot_mat, p_pos.unsqueeze(2)).squeeze(2) r_grf_rotated = torch.matmul(rot_mat, r_grf.unsqueeze(2)).squeeze(2) l_grf_rotated = torch.matmul(rot_mat, l_grf.unsqueeze(2)).squeeze(2) r_cop_rotated = torch.matmul(rot_mat, r_cop.unsqueeze(2)).squeeze(2) l_cop_rotated = torch.matmul(rot_mat, l_cop.unsqueeze(2)).squeeze(2) results_pred_clone[:, pelvis_orientation_col_loc] = matrix_to_euler_angles(pelvis_orientation_rotated.float(), "ZXY") results_pred_clone[:, p_pos_col_loc] = p_pos_rotated.float() results_pred_clone[:, r_grf_col_loc] = r_grf_rotated.float() results_pred_clone[:, l_grf_col_loc] = l_grf_rotated.float() results_pred_clone[:, r_cop_col_loc] = r_cop_rotated.float() results_pred_clone[:, l_cop_col_loc] = l_cop_rotated.float() return results_pred_clone def convert_addb_state_to_model_input(pose_df, joints_3d, sampling_fre): # shift root position to start in (x,y) = (0,0) pos_vec = [pose_df['pelvis_tx'][0], pose_df['pelvis_ty'][0], pose_df['pelvis_tz'][0]] pose_df['pelvis_tx'] = pose_df['pelvis_tx'] - pose_df['pelvis_tx'][0] pose_df['pelvis_ty'] = pose_df['pelvis_ty'] - pose_df['pelvis_ty'][0] pose_df['pelvis_tz'] = pose_df['pelvis_tz'] - pose_df['pelvis_tz'][0] # remove frozen dof for frozen_col in FROZEN_DOFS: if frozen_col in pose_df.columns: pose_df = pose_df.drop(frozen_col, axis=1) # convert euler to 6v for joint_name, joints_with_3_dof in joints_3d.items(): joint_6v = euler_to_6v(torch.tensor(pose_df[joints_with_3_dof].values), "ZXY").numpy() for i in range(6): pose_df[joint_name + '_' + str(i)] = joint_6v[:, i] for joint_name, joints_with_3_dof in joints_3d.items(): joint_angular_v = 
euler_to_angular_velocity(torch.tensor(pose_df[joints_with_3_dof].values), sampling_fre, "ZXY").numpy() joint_angular_v = data_filter(joint_angular_v, 15, sampling_fre, 4) for joints_euler_name in joints_with_3_dof: pose_df = pose_df.drop(joints_euler_name, axis=1) for i, axis in enumerate(['x', 'y', 'z']): pose_df[joint_name + '_' + axis + '_angular' + '_vel'] = joint_angular_v[:, i] vel_col_loc = [i for i, col in enumerate(pose_df.columns) if not np.sum([term in col for term in ['force', 'pelvis_', '_vel', '_0', '_1', '_2', '_3', '_4', '_5']])] vel_col_names = [f'{col}_vel' for i, col in enumerate(pose_df.columns) if not np.sum([term in col for term in ['force', 'pelvis_', '_vel', '_0', '_1', '_2', '_3', '_4', '_5']])] kinematics_np = pose_df.iloc[:, vel_col_loc].to_numpy().copy() kinematics_np_filtered = data_filter(kinematics_np, 15, sampling_fre, 4) kinematics_vel = np.stack([spline_fitting_1d(kinematics_np_filtered[:, i_col], range(kinematics_np_filtered.shape[0]), 1).ravel() for i_col in range(kinematics_np_filtered.shape[1])]).T pose_vel_df = pd.DataFrame(np.concatenate([pose_df.values, kinematics_vel], axis=1), columns=list(pose_df.columns)+vel_col_names) return pose_vel_df, pos_vec def inverse_convert_addb_state_to_model_input( model_states, model_states_column_names, treadmill_speed, joints_3d, osim_dof_columns, pos_vec, height_m, sampling_fre=100): model_states_dict = {col: model_states[..., i] for i, col in enumerate(model_states_column_names) if col in osim_dof_columns} for i_col, col in enumerate(['pelvis_tx', 'pelvis_ty', 'pelvis_tz']): model_states_dict[col] = model_states_dict[col] * height_m.unsqueeze(-1).expand(model_states_dict[col].shape) if col == 'pelvis_tx': model_states_dict[col] -= treadmill_speed model_states_dict[col] = torch.cumsum(model_states_dict[col], dim=-1) / sampling_fre # convert 6v to euler for joint_name, joints_with_3_dof in joints_3d.items(): joint_name_6v = [joint_name + '_' + str(i) for i in range(6)] index_ = 
[model_states_column_names.index(joint_name_6v[i]) for i in range(6)] joint_euler = euler_from_6v(model_states[..., index_], "ZXY") for i, joints_euler_name in enumerate(joints_with_3_dof): model_states_dict[joints_euler_name] = joint_euler[..., i] # add frozen dof back for frozen_col in FROZEN_DOFS: model_states_dict[frozen_col] = torch.zeros(model_states.shape[:len(model_states.shape)-1]).to(model_states.device) pos_vec = torch.tensor(pos_vec) pos_vec_torch = pos_vec.unsqueeze(-1).repeat(*[1 for _ in range(len(pos_vec.shape))], model_states.shape[-2]).to(model_states.device) for i_col, col in enumerate(['pelvis_tx', 'pelvis_ty', 'pelvis_tz']): model_states_dict[col] += pos_vec_torch[..., i_col, :] osim_states = torch.stack([model_states_dict[col] for col in osim_dof_columns], dim=len(model_states.shape)-1).float() return osim_states def align_moving_direction(poses, column_names): if isinstance(poses, np.ndarray): poses = torch.from_numpy(poses) pose_clone = poses.clone().float() pelvis_orientation_col_loc = [column_names.index(col) for col in JOINTS_3D_ALL['pelvis']] p_pos_col_loc = [column_names.index(col) for col in [f'pelvis_t{x}' for x in ['x', 'y', 'z']]] r_grf_col_loc = [column_names.index(col) for col in ['calcn_r_force_vx', 'calcn_r_force_vy', 'calcn_r_force_vz']] l_grf_col_loc = [column_names.index(col) for col in ['calcn_l_force_vx', 'calcn_l_force_vy', 'calcn_l_force_vz']] r_cop_col_loc = [column_names.index(col) for col in [f'calcn_r_force_normed_cop_{x}' for x in ['x', 'y', 'z']]] l_cop_col_loc = [column_names.index(col) for col in [f'calcn_l_force_normed_cop_{x}' for x in ['x', 'y', 'z']]] if len(pelvis_orientation_col_loc) != 3 or len(p_pos_col_loc) != 3 or len(r_grf_col_loc) != 3 or len( l_grf_col_loc) != 3: raise ValueError('check column names') pelvis_orientation = pose_clone[:, pelvis_orientation_col_loc] pelvis_orientation = euler_angles_to_matrix(pelvis_orientation, "ZXY") p_pos = pose_clone[:, p_pos_col_loc] r_grf = pose_clone[:, 
r_grf_col_loc] l_grf = pose_clone[:, l_grf_col_loc] r_cop = pose_clone[:, r_cop_col_loc] l_cop = pose_clone[:, l_cop_col_loc] angles = np.arctan2(- pelvis_orientation[:, 0, 2], pelvis_orientation[:, 2, 2]) if np.rad2deg(angles.max() - angles.min()) > 45: return False, None angle = angles.median() rot_mat = torch.tensor([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]]).float() pelvis_orientation_rotated = torch.matmul(rot_mat, pelvis_orientation) p_pos_rotated = torch.matmul(rot_mat, p_pos.unsqueeze(2)).squeeze(2) r_grf_rotated = torch.matmul(rot_mat, r_grf.unsqueeze(2)).squeeze(2) l_grf_rotated = torch.matmul(rot_mat, l_grf.unsqueeze(2)).squeeze(2) r_cop_rotated = torch.matmul(rot_mat, r_cop.unsqueeze(2)).squeeze(2) l_cop_rotated = torch.matmul(rot_mat, l_cop.unsqueeze(2)).squeeze(2) pose_clone[:, pelvis_orientation_col_loc] = matrix_to_euler_angles(pelvis_orientation_rotated.float(), "ZXY") pose_clone[:, p_pos_col_loc] = p_pos_rotated.float() pose_clone[:, r_grf_col_loc] = r_grf_rotated.float() pose_clone[:, l_grf_col_loc] = l_grf_rotated.float() pose_clone[:, r_cop_col_loc] = r_cop_rotated.float() pose_clone[:, l_cop_col_loc] = l_cop_rotated.float() return pose_clone, rot_mat def data_filter(data, cut_off_fre, sampling_fre, filter_order=4): fre = cut_off_fre / (sampling_fre / 2) b, a = butter(filter_order, fre, 'lowpass') if len(data.shape) == 1: data_filtered = filtfilt(b, a, data) else: data_filtered = filtfilt(b, a, data, axis=0) return data_filtered def fix_seed(): torch.manual_seed(0) random.seed(0) np.random.seed(0) def linear_resample_data(trial_data, original_fre, target_fre): x, step = np.linspace(0., 1., trial_data.shape[0], retstep=True) new_x = np.arange(0., 1., step * original_fre / target_fre) f = interp1d(x, trial_data, axis=0) trial_data_resampled = f(new_x) return trial_data_resampled def spline_fitting_1d(data_, step_to_resample, der=0): assert len(data_.shape) == 1 data_ = data_.reshape(1, -1) tck, step = 
interpo.splprep(data_, u=range(data_.shape[1]), s=0) data_resampled = interpo.splev(step_to_resample, tck, der=der) data_resampled = np.column_stack(data_resampled) return data_resampled def convert_overlapped_list_to_array(trial_len, win_list, s_, e_, fun=np.nanmedian, max_size=150): array_val_expand = np.full((max_size, trial_len, win_list[0].shape[1]), np.nan) for i_win, (win, s, e) in enumerate(zip(win_list, s_, e_)): array_val_expand[i_win%max_size, s:e] = win[:e-s] array_val = fun(array_val_expand, axis=0) std_val = np.nanstd(array_val_expand, axis=0) return array_val, std_val def identity(t, *args, **kwargs): return t def extract(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) return out.reshape(b, *((1,) * (len(x_shape) - 1))) def make_beta_schedule( schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 ): if schedule == "linear": betas = ( torch.linspace( linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64 ) ** 2 ) elif schedule == "cosine": timesteps = ( torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s ) alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = torch.cos(alphas).pow(2) alphas = alphas / alphas[0] betas = 1 - alphas[1:] / alphas[:-1] betas = np.clip(betas, a_min=0, a_max=0.999) elif schedule == "sqrt_linear": betas = torch.linspace( linear_start, linear_end, n_timestep, dtype=torch.float64 ) elif schedule == "sqrt": betas = ( torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 ) else: raise ValueError(f"schedule '{schedule}' unknown.") return betas.numpy() """ ============================ End util.py ============================ """ """ ============================ Start args.py ============================ """ def parse_opt(): parser = argparse.ArgumentParser() parser.add_argument("--exp_name", default="new_orientation_alignment_diffusion", help="save to project/name") parser.add_argument("--with_arm", type=bool, default=False, help="whether 
osim model has arm DoFs") parser.add_argument("--with_kinematics_vel", type=bool, default=True, help="whether to include 1st derivative of kinematics") parser.add_argument("--epochs", type=int, default=7680) parser.add_argument("--target_sampling_rate", type=int, default=100) parser.add_argument("--window_len", type=int, default=150) parser.add_argument("--guide_x_start_the_beginning_step", type=int, default=-10) # negative value means no guidance parser.add_argument("--project", default="runs/train", help="project/name") parser.add_argument( "--processed_data_dir", type=str, default="dataset_backups/", help="Dataset backup path", ) parser.add_argument("--feature_type", type=str, default="jukebox") parser.add_argument("--batch_size", type=int, default=256, help="batch size") parser.add_argument("--batch_size_inference", type=int, default=5, help="batch size during inference") parser.add_argument( "--checkpoint", type=str, default="", help="trained checkpoint path (optional)" ) # parser.add_argument( # "--checkpoint_bl", type=str, default="", help="trained checkpoint path (optional)" # ) opt = parser.parse_args(args=[]) set_no_arm_opt(opt) current_folder = os.getcwd() opt.subject_data_path = current_folder opt.geometry_folder = current_folder + '/Geometry/' opt.checkpoint_bl = current_folder + '/GaitDynamicsRefinement.pt' return opt def set_no_arm_opt(opt): opt.with_arm = False opt.osim_dof_columns = copy.deepcopy(OSIM_DOF_ALL[:23] + KINETICS_ALL) opt.joints_3d = {key_: value_ for key_, value_ in JOINTS_3D_ALL.items() if key_ in ['pelvis', 'hip_r', 'hip_l', 'lumbar']} opt.model_states_column_names = copy.deepcopy(MODEL_STATES_COLUMN_NAMES_NO_ARM) for joint_name, joints_with_3_dof in opt.joints_3d.items(): opt.model_states_column_names = opt.model_states_column_names + [ joint_name + '_' + axis + '_angular_vel' for axis in ['x', 'y', 'z']] if opt.with_kinematics_vel: opt.model_states_column_names = opt.model_states_column_names + [ f'{col}_vel' for i_col, col in 
enumerate(opt.model_states_column_names) if not sum([term in col for term in ['force', 'pelvis_', '_vel', '_0', '_1', '_2', '_3', '_4', '_5']])] opt.knee_diffusion_col_loc = [i_col for i_col, col in enumerate(opt.model_states_column_names) if 'knee' in col] opt.ankle_diffusion_col_loc = [i_col for i_col, col in enumerate(opt.model_states_column_names) if 'ankle' in col] opt.hip_diffusion_col_loc = [i_col for i_col, col in enumerate(opt.model_states_column_names) if 'hip' in col] opt.kinematic_diffusion_col_loc = [i_col for i_col, col in enumerate(opt.model_states_column_names) if 'force' not in col] opt.kinetic_diffusion_col_loc = [i_col for i_col, col in enumerate(opt.model_states_column_names) if i_col not in opt.kinematic_diffusion_col_loc] opt.grf_osim_col_loc = [i_col for i_col, col in enumerate(opt.osim_dof_columns) if 'force' in col and '_cop_' not in col] opt.cop_osim_col_loc = [i_col for i_col, col in enumerate(opt.osim_dof_columns) if '_cop_' in col] opt.kinematic_osim_col_loc = [i_col for i_col, col in enumerate(opt.osim_dof_columns) if 'force' not in col] """ ============================ Start dataset.py ============================ """ class MotionDataset(Dataset): def __init__( self, opt, normalizer: Any = None, max_trial_num=None, check_cop_to_calcn_distance=True, ): self.data_path = opt.subject_data_path self.subject_osim_model = opt.subject_osim_model if opt.target_sampling_rate != 100: raise ValueError('100 Hz sampling rate is not confirmed. 
Confirm by setting opt.target_sampling_rate = 100') self.target_sampling_rate = opt.target_sampling_rate self.window_len = opt.window_len self.opt = opt self.check_cop_to_calcn_distance = check_cop_to_calcn_distance self.skel = None self.dset_set = set() print("Loading dataset...") self.load_addb(opt) self.guess_vel_and_replace_txtytz() if not len(self.trials): print("No trials loaded") return self.normalizer = normalizer for i_trial in range(len(self.trials)): self.trials[i_trial].converted_pose = self.normalizer.normalize(self.trials[i_trial].converted_pose).clone().detach().float() def customized_param_manipulation(self, states_df): # for guided diffusion return states_df def guess_vel_and_replace_txtytz(self): pelvis_pos_loc = [self.opt.model_states_column_names.index(col) for col in [f'pelvis_t{x}' for x in ['x', 'y', 'z']]] for i_trial, trial in enumerate(self.trials): body_center = trial.converted_pose[:, 0:3] body_center = data_filter(body_center, 10, self.target_sampling_rate).astype(np.float32) vel_from_t = np.diff(body_center, axis=0) * self.target_sampling_rate vel_from_t = np.concatenate([vel_from_t, vel_from_t[-1][None, :]], axis=0) if self.opt.treadmill_speed is None: raise ValueError('Treadmill speed is not set. 
For overground walking, set opt.treadmill_speed = 0') vel_from_t[:, 0] = vel_from_t[:, 0] + self.opt.treadmill_speed walking_vel = vel_from_t self.trials[i_trial].converted_pose[:, pelvis_pos_loc] = torch.from_numpy(walking_vel) / self.trials[i_trial].height_m def __len__(self): return self.opt.pseudo_dataset_len def get_overlapping_wins(self, col_loc_to_unmask, step_len, start_trial=0, end_trial=None, including_shorter_than_window_len=False): if end_trial is None: end_trial = len(self.trials) windows, s_list, e_list = [], [], [] for i_trial in range(start_trial, end_trial): trial_ = self.trials[i_trial] col_loc_to_unmask_trial = copy.deepcopy(col_loc_to_unmask) for col in trial_.missing_col: col_index = opt.model_states_column_names.index(col) col_loc_to_unmask_trial.remove(col_index) trial_len = trial_.converted_pose.shape[0] if including_shorter_than_window_len: e_of_trial = trial_len else: e_of_trial = trial_len - self.opt.window_len + step_len for i in range(0, e_of_trial, step_len): s = max(0, i) e = min(trial_len, i + self.opt.window_len) s_list.append(s) e_list.append(e) mask = torch.zeros([self.opt.window_len, len(self.opt.model_states_column_names)]) mask[:, col_loc_to_unmask_trial] = 1 mask[e-s:, :] = 0 data_ = torch.zeros([self.opt.window_len, len(self.opt.model_states_column_names)]) data_[:e-s] = trial_.converted_pose[s:e, ...] 
windows.append(WindowData(data_, trial_.model_offsets, i_trial, None, mask, trial_.height_m, trial_.weight_kg, trial_.missing_col)) return windows, s_list, e_list def load_addb(self, opt): file_paths = opt.file_paths customOsim: nimble.biomechanics.OpenSimFile = nimble.biomechanics.OpenSimParser.parseOsim(self.subject_osim_model, self.opt.geometry_folder) skel = customOsim.skeleton self.skel = skel self.trials, self.file_names, self.rot_mat_trials, self.time_column = [], [], [], [] for i_file, file_path in enumerate(file_paths): model_offsets = get_model_offsets(skel).float() with open(file_path) as f: line_num = 0 endheader_line, angle_scale = None, None while True: header = f.readline() line_num += 1 if 'endheader' in header: endheader_line = line_num break if 'inDegrees' in header and angle_scale is None: if 'yes' in header and 'no' not in header: angle_scale = np.pi / 180 elif 'no' in header and 'yes' not in header: angle_scale = 1 else: raise ValueError('No inDegrees keyword in the header, cannot determine the unit of angles. ' 'Here is an example header: \nCoordinates\nversion=1\nnRows=1380' '\nnColumns=26\ninDegrees=yes\n') if endheader_line is None: raise ValueError(f'No \'endheader\' line found in the header of {file_path}.') poses_df = pd.read_csv(file_path, sep='\t', skiprows=endheader_line) if 'time' not in poses_df.columns: raise ValueError(f'{file_path} does not have time column. Necessary for compuing sampling rate') sampling_rate = round((poses_df.shape[0] - 1) / (poses_df['time'].iloc[-1] - poses_df['time'].iloc[0])) self.time_column.append(poses_df['time']) missing_col = [] for col in JOINTS_1D_ALL: if col not in poses_df.columns: poses_df[col] = 0. 
if col not in FROZEN_DOFS: missing_col.append(col) missing_col.append(col + '_vel') for key_, value_ in JOINTS_3D_ALL.items(): if not all([col in poses_df.columns for col in value_]): columns_not_in_poses = [col for col in value_ if col not in poses_df.columns] if len(columns_not_in_poses) < 3: print(f'Warning {key_} is a 3-D joint. Since {columns_not_in_poses.__str__()[1:-1]} is missing, ' f'the other DoFs ({[dof for dof in value_ if dof not in columns_not_in_poses].__str__()[1:-1]}) are filled with zeros too.') missing_col.extend([f'{key_}_{n}' for n in range(6)]) missing_col.extend([f'{key_}_{axis}_angular_vel' for axis in ['x', 'y', 'z']]) poses_df[value_] = 0. poses_df = poses_df.astype(float) col_list = list(poses_df.columns) col_loc = [col_list.index(col) for col in OSIM_DOF_ALL[:23]] angle_col_loc = [col_list.index(col) for col in OSIM_DOF_ALL[:23] if col not in ['pelvis_tx', 'pelvis_ty', 'pelvis_tz']] poses_df.iloc[:, angle_col_loc] = poses_df.iloc[:, angle_col_loc] * angle_scale poses = poses_df.values[:, col_loc] all_zeros = np.zeros([poses_df.shape[0], 12]) # used to fill in the missing GRF in normalization states = np.concatenate([np.array(poses), all_zeros], axis=1) if not self.is_lumbar_rotation_reasonable(np.array(states), opt.osim_dof_columns): Warning(f'Warning: {file_path} has unreasonable lumbar rotation, don\'t trust this trial.') states_aligned, rot_mat = align_moving_direction(states, opt.osim_dof_columns) if states_aligned is False: print(f'Warning: {file_path} Pelvis orientation changed by more than 45 deg, don\'t trust this trial') rot_mat = torch.eye(3).float() else: states = states_aligned self.rot_mat_trials.append(rot_mat) file_name = file_path.split('/')[-1] if states.shape[0] / sampling_rate * self.target_sampling_rate < self.window_len + 2: print(f'Warning: {file_name} is shorter than 1.5s, skipping.') continue if sampling_rate != self.target_sampling_rate: states = linear_resample_data(states, sampling_rate, 
self.target_sampling_rate) states_df = pd.DataFrame(states, columns=opt.osim_dof_columns) states_df = self.customized_param_manipulation(states_df) states_df, pos_vec = convert_addb_state_to_model_input(states_df, opt.joints_3d, self.target_sampling_rate) assert self.opt.model_states_column_names == list(states_df.columns) converted_states = torch.tensor(states_df.values).float() trial_data = TrialData(converted_states, model_offsets, opt.height_m, opt.weight_kg, rot_mat, pos_vec, self.window_len, missing_col) self.trials.append(trial_data) self.file_names.append(file_name) @staticmethod def is_lumbar_rotation_reasonable(states, column_names): lumbar_rotation_col_loc = column_names.index('lumbar_rotation') if np.abs(np.mean(states[:, lumbar_rotation_col_loc])) > np.deg2rad(45): return False else: return True class WindowData: def __init__(self, pose, model_offsets, trial_id, gait_phase_label, mask, height_m, weight_kg, missing_col): self.pose = pose self.model_offsets = model_offsets self.trial_id = trial_id self.gait_phase_label = gait_phase_label self.mask = mask self.height_m = height_m self.weight_kg = weight_kg self.missing_col = missing_col class TrialData: def __init__(self, converted_states, model_offsets, height_m, weight_kg, rot_mat_for_moving_direction_alignment, pos_vec_for_pos_alignment, window_len, missing_col): self.converted_pose = converted_states self.model_offsets = model_offsets self.height_m = height_m self.weight_kg = weight_kg self.length = converted_states.shape[0] self.rot_mat_for_moving_direction_alignment = rot_mat_for_moving_direction_alignment self.pos_vec_for_pos_alignment = pos_vec_for_pos_alignment self.window_len = window_len self.missing_col = missing_col """ ============================ End dataset.py ============================ """ def predict_grf_and_missing_kinematics(): refinement_model = BaselineModel(opt, TransformerEncoderArchitecture) dataset = MotionDataset(opt, normalizer=refinement_model.normalizer) 
diffusion_model_for_filling = None filling_method = DiffusionFilling() for i_trial in range(len(dataset.trials)): windows, s_list, e_list = dataset.get_overlapping_wins(opt.kinematic_diffusion_col_loc, 150, i_trial, i_trial+1) if len(windows) == 0: continue if len(windows[0].missing_col) > 0: print(f'File {dataset.file_names[i_trial]} do not have {windows[0].missing_col}. ' f'\nGenerating missing kinematics for {dataset.file_names[i_trial]}') if diffusion_model_for_filling is None: diffusion_model_for_filling, _ = load_diffusion_model(opt) windows_reconstructed = filling_method.fill_param(windows, diffusion_model_for_filling) else: windows_reconstructed = windows state_pred_list = [] print(f'Running GaitDynamics on file {dataset.file_names[i_trial]} for external force prediction.') for i_win in range(0, len(windows), opt.batch_size_inference): state_true = torch.stack([win.pose for win in windows_reconstructed[i_win:i_win+opt.batch_size_inference]]) masks = torch.stack([win.mask for win in windows_reconstructed[i_win:i_win+opt.batch_size_inference]]) state_pred_list_batch = refinement_model.eval_loop(opt, state_true, masks, num_of_generation_per_window=1)[0] state_pred_list.append(state_pred_list_batch) state_pred = torch.cat(state_pred_list, dim=0) trial_len = dataset.trials[i_trial].converted_pose.shape[0] results_pred, _ = convert_overlapped_list_to_array( trial_len, state_pred, s_list, e_list) height_m_tensor = torch.tensor([windows[0].height_m]) results_pred = inverse_convert_addb_state_to_model_input( torch.from_numpy(results_pred).unsqueeze(0), opt.model_states_column_names, opt.treadmill_speed, opt.joints_3d, opt.osim_dof_columns, dataset.trials[i_trial].pos_vec_for_pos_alignment, height_m_tensor)[0].numpy() results_pred = inverse_norm_cops(dataset.skel, results_pred, opt, windows[0].weight_kg, windows[0].height_m) results_pred = inverse_align_moving_direction(results_pred, opt.osim_dof_columns, dataset.rot_mat_trials[i_trial]) results_pred[:, -12:-9] = 
results_pred[:, -12:-9] * opt.weight_kg # convert to N results_pred[:, -6:-3] = results_pred[:, -6:-3] * opt.weight_kg # convert to N df = pd.DataFrame(results_pred, columns=opt.osim_dof_columns) trial_save_path = f'{dataset.file_names[i_trial][:-4]}_grf_pred___.mot' convertDfToGRFMot(df, trial_save_path, round(1 / opt.target_sampling_rate, 3), dataset.time_column[i_trial]) if len(windows[0].missing_col) > 0: trc_save_path = f'{dataset.file_names[i_trial][:-4]}_missing_kinematics_pred___.mot' convertDataframeToMot(df[OSIM_DOF_ALL[:23]], trc_save_path, round(1 / opt.target_sampling_rate, 3), dataset.time_column[i_trial]) print('You can now download files from the default folder.') import gradio as gr import matplotlib.pyplot as plt import io import shutil opt = parse_opt() def gradio_predict(mot_file, osim_file, height_m, weight_kg, treadmill_speed, progress_callback=None): """Gradio interface function for GRF prediction.""" if mot_file is None: return "Please upload a .mot file" if osim_file is None: return "Please upload a .osim file" # Validate inputs if not (1.0 <= height_m <= 2.5): return "Height must be between 1.0 and 2.5 meters" if not (30 <= weight_kg <= 200): return "Weight must be between 30 and 200 kg" if not (0 <= treadmill_speed <= 10): return "Treadmill speed must be between 0 and 10 m/s" try: # Get opt with default settings # Set user inputs opt.height_m = height_m opt.weight_kg = weight_kg opt.treadmill_speed = treadmill_speed # Copy uploaded files to the working directory mot_path = os.path.join(opt.subject_data_path, os.path.basename(mot_file.name)) osim_path = os.path.join(opt.subject_data_path, os.path.basename(osim_file.name)) shutil.copy2(mot_file.name, mot_path) shutil.copy2(osim_file.name, osim_path) opt.file_paths = [mot_path] opt.subject_osim_model = osim_path # Run prediction with updated opt output_files = predict_grf_and_missing_kinematics_with_opt(progress_callback=progress_callback) if output_files: return "Prediction completed! 
Download files below.", output_files[0], output_files[1] if len(output_files) > 1 else None else: return "Prediction completed but no output files generated.", None, None except Exception as e: return f"Error during prediction: {str(e)}", None, None def predict_grf_and_missing_kinematics_with_opt(progress_callback=None): """Modified prediction function that accepts opt as parameter.""" refinement_model = BaselineModel(opt, TransformerEncoderArchitecture) dataset = MotionDataset(opt, normalizer=refinement_model.normalizer) diffusion_model_for_filling = None filling_method = DiffusionFilling() output_files = [] for i_trial in range(len(dataset.trials)): windows, s_list, e_list = dataset.get_overlapping_wins(opt.kinematic_diffusion_col_loc, 150, i_trial, i_trial+1) if len(windows) == 0: continue if len(windows[0].missing_col) > 0: print(f'File {dataset.file_names[i_trial]} do not have {windows[0].missing_col}. ' f'\nGenerating missing kinematics for {dataset.file_names[i_trial]}') if diffusion_model_for_filling is None: diffusion_model_for_filling, _ = load_diffusion_model(opt) windows_reconstructed = filling_method.fill_param(windows, diffusion_model_for_filling, progress_callback=progress_callback) else: windows_reconstructed = windows state_pred_list = [] print(f'Running GaitDynamics on file {dataset.file_names[i_trial]} for external force prediction.') for i_win in range(0, len(windows), opt.batch_size_inference): state_true = torch.stack([win.pose for win in windows_reconstructed[i_win:i_win+opt.batch_size_inference]]) masks = torch.stack([win.mask for win in windows_reconstructed[i_win:i_win+opt.batch_size_inference]]) state_pred_list_batch = refinement_model.eval_loop(opt, state_true, masks, num_of_generation_per_window=1)[0] state_pred_list.append(state_pred_list_batch) state_pred = torch.cat(state_pred_list, dim=0) trial_len = dataset.trials[i_trial].converted_pose.shape[0] results_pred, _ = convert_overlapped_list_to_array( trial_len, state_pred, s_list, 
e_list) height_m_tensor = torch.tensor([windows[0].height_m]) results_pred = inverse_convert_addb_state_to_model_input( torch.from_numpy(results_pred).unsqueeze(0), opt.model_states_column_names, opt.treadmill_speed, opt.joints_3d, opt.osim_dof_columns, dataset.trials[i_trial].pos_vec_for_pos_alignment, height_m_tensor)[0].numpy() results_pred = inverse_norm_cops(dataset.skel, results_pred, opt, windows[0].weight_kg, windows[0].height_m) results_pred = inverse_align_moving_direction(results_pred, opt.osim_dof_columns, dataset.rot_mat_trials[i_trial]) results_pred[:, -12:-9] = results_pred[:, -12:-9] * opt.weight_kg # convert to N results_pred[:, -6:-3] = results_pred[:, -6:-3] * opt.weight_kg # convert to N df = pd.DataFrame(results_pred, columns=opt.osim_dof_columns) trial_save_path = f'{dataset.file_names[i_trial][:-4]}_grf_pred___.mot' convertDfToGRFMot(df, trial_save_path, round(1 / opt.target_sampling_rate, 3), dataset.time_column[i_trial]) output_files.append(trial_save_path) if len(windows[0].missing_col) > 0: trc_save_path = f'{dataset.file_names[i_trial][:-4]}_missing_kinematics_pred___.mot' convertDataframeToMot(df[OSIM_DOF_ALL[:23]], trc_save_path, round(1 / opt.target_sampling_rate, 3), dataset.time_column[i_trial]) output_files.append(trc_save_path) print('You can now download files from the default folder.') return output_files def _extract_path(file_obj): try: if file_obj is None: return None if isinstance(file_obj, str): return file_obj if isinstance(file_obj, dict) and 'name' in file_obj: return file_obj['name'] if hasattr(file_obj, 'name'): return file_obj.name except Exception: pass return None def read_mot_to_dataframe(file_path): if file_path is None: return None with open(file_path, 'r') as f: lines = f.readlines() header_idx = None for i, line in enumerate(lines): if line.strip().startswith('time'): header_idx = i break if header_idx is None: return None data_str = ''.join(lines[header_idx:]) try: df = pd.read_csv(io.StringIO(data_str), 
sep='\t') except Exception: df = pd.read_csv(io.StringIO(data_str), delim_whitespace=True) return df def list_mot_columns(file_choice, grf_file, kin_file): selected = _extract_path(grf_file) if file_choice == 'GRF Results' else _extract_path(kin_file) df = read_mot_to_dataframe(selected) if df is None or 'time' not in df.columns: return gr.update(choices=[], value=None) cols = [c for c in df.columns if c != 'time' and 'torque' not in c.lower()] return gr.update(choices=cols, value=(cols[0] if len(cols) > 0 else None)) def plot_mot_signal(file_choice, column, grf_file, kin_file): selected = _extract_path(grf_file) if file_choice == 'GRF Results' else _extract_path(kin_file) df = read_mot_to_dataframe(selected) if df is None or 'time' not in df.columns or column is None or column not in df.columns: fig = plt.figure(figsize=(10, 5)) plt.title('No data to display') plt.tight_layout() return fig fig = plt.figure(figsize=(10, 5)) ax = fig.add_subplot(111) ax.plot(df['time'], df[column]) ax.set_xlabel('time') ax.set_ylabel(column) ax.set_title(f'{column} over time') ax.grid(True) fig.tight_layout() return fig # Create Gradio Blocks interface with plotting panel with gr.Blocks(css=""" .download-file .file-preview { background-color: #e8f5e8 !important; border-left: 4px solid #4caf50 !important; padding: 6px 10px !important; font-weight: bold !important; color: #2e7d32 !important; } #time_series_plot {height: 500px !important;} """) as demo: gr.HTML( """

GaitDynamics - Ground Reaction Force and Kinematics Prediction

General instructions

This tool predicts ground reaction forces and missing kinematics from flexible combinations of OpenSim joint angles. The joint angles should meet the following criteria:

1. Use OpenSim Rajagopal Model without Arms.
2. Resample to 100 Hz if you have a different sampling rate.

To use this code:

1. Upload the OpenSim model (.osim) and kinematics (.mot) files. If kinematics are provided for all coordinates, the model predicts only ground reaction forces. If some coordinates are missing, the model also predicts the missing kinematics.
2. Enter the height and weight of the participant, and the treadmill speed (if applicable).
3. Click the "Submit" button.

Processing example data

1. Download the example files:
""" ) with gr.Column(): with gr.Row(): gr.DownloadButton("a. OpenSim model file (.osim)", value="example_opensim_model.osim", size="sm", scale=1) gr.DownloadButton("b. complete kinematics file with all coordinates (.mot)", value="example_mot_complete_kinematics.mot", size="sm", scale=1) gr.DownloadButton("c. incomplete kinematics file with missing knee coordinate kinematics (.mot)", value="example_mot_missing_knee_kinematics.mot", size="sm", scale=1) gr.HTML( """
2. Upload the example model file together with either of the two example .mot files:
a. Predict ground reaction forces with a complete kinematics file by uploading the model file (.osim) and complete kinematics file (.mot).
b. Predict ground reaction forces with an incomplete kinematics file by uploading the model file (.osim) and the incomplete kinematics file (.mot).
3. Update the following input parameters.
a. Height = 1.83 m
b. Weight = 71.4 kg
c. Treadmill speed = 1.15 m/s
4. Click "Submit".
a. The example with complete kinematics should take a few seconds to complete.
b. The example with incomplete kinematics should take a few minutes to complete.
5. Plot the results: select "GRF Results", choose a column (e.g., force_r_vy), and click "Plot".
""" ) gr.HTML("

Input Data and Parameters

") with gr.Row(): mot_in = gr.File(label="Upload .mot file", file_types=[".mot"]) osim_in = gr.File(label="Upload .osim file", file_types=[".osim"]) with gr.Row(): height_in = gr.Number(label="Height (meters)", value=1.7, minimum=1.0, maximum=2.5, step=0.01) weight_in = gr.Number(label="Weight (kg)", value=70, minimum=30, maximum=200, step=0.1) speed_in = gr.Number(label="Treadmill speed (m/s, 0 for overground)", value=0, minimum=0, maximum=10, step=0.01) submit_btn = gr.Button("Submit") gr.HTML("

Results

") with gr.Row(): status_out = gr.Textbox(label="Status") grf_path = gr.State(None) kin_path = gr.State(None) with gr.Row(): grf_file_out = gr.DownloadButton("Download Ground Reaction Force Results", visible=False) kin_file_out = gr.DownloadButton("Download Missing Kinematics (0 for overground)", visible=False) # Plotting panel with gr.Column(elem_id="plotting_panel"): gr.HTML("""

Visualization

Ground reaction force (GRF) results are available for all inputs. Missing kinematics results are only available if incomplete kinematics were used as inputs.
GRF column names follow these conventions:
• Starts with force_l or force_r: indicates the left or right foot, respectively
• Ends in _vx, _vy, _vz: the force component along the x, y, and z axes, respectively.
• Ends in _px, _py, _pz: the position of the center of pressure along the x, y, and z axes, respectively.
• x is the anterior direction, y is the vertical direction, and z points to the right (medial for the left leg, lateral for the right leg).
""") with gr.Row(): file_select = gr.Radio(choices=["GRF Results", "Missing Kinematics"], value="GRF Results", label="Select output file") column_select = gr.Dropdown(choices=[], label="Column", interactive=True) plot_btn = gr.Button("Plot") plot_out = gr.Plot(label="Time series plot", elem_id="time_series_plot") gr.HTML("""

Citation

Tan, T., Van Wouwe, T., Werling, K.F. et al. GaitDynamics: a generative foundation model for analyzing human walking and running. Nat. Biomed. Eng. (2026).

https://doi.org/10.1038/s41551-025-01565-8

""") # Wire interactions def enhanced_predict(mot_in, osim_in, height_in, weight_in, speed_in, current_select, current_grf, current_kin, progress=gr.Progress()): progress(0.0, desc="Starting...") def progress_callback(pct): progress(pct, desc=f"Processing diffusion model... {int(pct*100)}%") status, new_grf, new_kin = gradio_predict(mot_in, osim_in, height_in, weight_in, speed_in, progress_callback=progress_callback) progress(1.0, desc="Complete!") grf_visible = new_grf is not None kin_visible = new_kin is not None return status, new_grf, new_kin, gr.update(value=new_grf, visible=grf_visible), gr.update(value=new_kin, visible=kin_visible) submit_btn.click( fn=enhanced_predict, inputs=[mot_in, osim_in, height_in, weight_in, speed_in, file_select, grf_path, kin_path], outputs=[status_out, grf_path, kin_path, grf_file_out, kin_file_out] ).then( fn=list_mot_columns, inputs=[file_select, grf_path, kin_path], outputs=[column_select] ) file_select.change( fn=list_mot_columns, inputs=[file_select, grf_path, kin_path], outputs=[column_select] ) plot_btn.click( fn=plot_mot_signal, inputs=[file_select, column_select, grf_path, kin_path], outputs=[plot_out] ) if __name__ == '__main__': demo.launch()