Upload 8 files
Browse files- AESKConv_240_100.bin +3 -0
- __init__.py +0 -0
- decoders.py +56 -0
- mean_vel_smplxflame_30.npy +3 -0
- mertic.py +357 -0
- motion_encoder.py +193 -0
- skeleton.py +298 -0
- skeleton_DME.py +473 -0
AESKConv_240_100.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5cd9566b24264f34d44003b3de62cdfd50aa85b7cdde2d369214599023c40f55
|
| 3 |
+
size 17558653
|
__init__.py
ADDED
|
File without changes
|
decoders.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This script is modified from https://github.com/EricGuo5513/TM2T
|
| 2 |
+
# Licensed under: https://github.com/EricGuo5513/TM2T/blob/main/LICENSE
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class VQDecoderV3(nn.Module):
|
| 8 |
+
def __init__(self, args):
|
| 9 |
+
super(VQDecoderV3, self).__init__()
|
| 10 |
+
n_up = args.vae_layer
|
| 11 |
+
channels = []
|
| 12 |
+
for i in range(n_up - 1):
|
| 13 |
+
channels.append(args.vae_length)
|
| 14 |
+
channels.append(args.vae_length)
|
| 15 |
+
channels.append(args.vae_test_dim)
|
| 16 |
+
input_size = args.vae_length
|
| 17 |
+
n_resblk = 2
|
| 18 |
+
assert len(channels) == n_up + 1
|
| 19 |
+
if input_size == channels[0]:
|
| 20 |
+
layers = []
|
| 21 |
+
else:
|
| 22 |
+
layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)]
|
| 23 |
+
|
| 24 |
+
for i in range(n_resblk):
|
| 25 |
+
layers += [ResBlock(channels[0])]
|
| 26 |
+
# channels = channels
|
| 27 |
+
for i in range(n_up):
|
| 28 |
+
layers += [
|
| 29 |
+
nn.Upsample(scale_factor=2, mode="nearest"),
|
| 30 |
+
nn.Conv1d(channels[i], channels[i + 1], kernel_size=3, stride=1, padding=1),
|
| 31 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 32 |
+
]
|
| 33 |
+
layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)]
|
| 34 |
+
self.main = nn.Sequential(*layers)
|
| 35 |
+
# self.main.apply(init_weight)
|
| 36 |
+
|
| 37 |
+
def forward(self, inputs):
|
| 38 |
+
inputs = inputs.permute(0, 2, 1)
|
| 39 |
+
outputs = self.main(inputs).permute(0, 2, 1)
|
| 40 |
+
return outputs
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ResBlock(nn.Module):
|
| 44 |
+
def __init__(self, channel):
|
| 45 |
+
super(ResBlock, self).__init__()
|
| 46 |
+
self.model = nn.Sequential(
|
| 47 |
+
nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1),
|
| 48 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 49 |
+
nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1),
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
residual = x
|
| 54 |
+
out = self.model(x)
|
| 55 |
+
out += residual
|
| 56 |
+
return out
|
mean_vel_smplxflame_30.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:53b5e48f2a7bf78c41a6de6395d6bb4f29018465ca5d0ee2820a2be3eebb7137
|
| 3 |
+
size 348
|
mertic.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import wget
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import librosa
|
| 6 |
+
import librosa.display
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from scipy.signal import argrelextrema
|
| 9 |
+
from scipy import linalg
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from .motion_encoder import VAESKConv
|
| 13 |
+
|
| 14 |
+
class LVDFace(object):
|
| 15 |
+
def __init__(self):
|
| 16 |
+
self.counter = 0
|
| 17 |
+
self.sum = 0
|
| 18 |
+
|
| 19 |
+
def compute(self, pred_vertices, target_vertices):
|
| 20 |
+
t, c = pred_vertices.shape
|
| 21 |
+
diff_pred = pred_vertices[1:, :] - pred_vertices[:-1, :]
|
| 22 |
+
diff_target = target_vertices[1:, :] - target_vertices[:-1, :]
|
| 23 |
+
loss = np.abs(diff_pred - diff_target)
|
| 24 |
+
loss = np.sum(loss)
|
| 25 |
+
self.counter += t * c
|
| 26 |
+
self.sum += loss
|
| 27 |
+
|
| 28 |
+
def avg(self):
|
| 29 |
+
return self.sum / self.counter
|
| 30 |
+
|
| 31 |
+
def reset(self):
|
| 32 |
+
self.counter = 0
|
| 33 |
+
self.sum = 0
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class MSEFace(object):
|
| 37 |
+
def __init__(self):
|
| 38 |
+
self.counter = 0
|
| 39 |
+
self.sum = 0
|
| 40 |
+
|
| 41 |
+
def compute(self, pred_vertices, target_vertices):
|
| 42 |
+
t, c = pred_vertices.shape
|
| 43 |
+
loss = np.square(pred_vertices - target_vertices)
|
| 44 |
+
self.sum += np.sum(loss)
|
| 45 |
+
self.counter += t * c
|
| 46 |
+
|
| 47 |
+
def avg(self):
|
| 48 |
+
if self.counter == 0:
|
| 49 |
+
return 0
|
| 50 |
+
return self.sum / self.counter
|
| 51 |
+
|
| 52 |
+
def reset(self):
|
| 53 |
+
self.counter = 0
|
| 54 |
+
self.sum = 0
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class L1div(object):
|
| 58 |
+
def __init__(self):
|
| 59 |
+
self.counter = 0
|
| 60 |
+
self.sum = 0
|
| 61 |
+
|
| 62 |
+
def compute(self, results):
|
| 63 |
+
self.counter += results.shape[0]
|
| 64 |
+
mean = np.mean(results, axis=0)
|
| 65 |
+
sum_l1 = np.sum(np.abs(results - mean), axis=None)
|
| 66 |
+
self.sum += sum_l1
|
| 67 |
+
|
| 68 |
+
def avg(self):
|
| 69 |
+
if self.counter == 0:
|
| 70 |
+
return 0
|
| 71 |
+
return self.sum / self.counter
|
| 72 |
+
|
| 73 |
+
def reset(self):
|
| 74 |
+
self.counter = 0
|
| 75 |
+
self.sum = 0
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class SRGR(object):
|
| 79 |
+
def __init__(self, threshold=0.1, joints=47, joint_dim=3):
|
| 80 |
+
self.threshold = threshold
|
| 81 |
+
self.pose_dimes = joints
|
| 82 |
+
self.joint_dim = joint_dim
|
| 83 |
+
self.counter = 0
|
| 84 |
+
self.sum = 0
|
| 85 |
+
|
| 86 |
+
def run(self, results, targets, semantic=None, verbose=False):
|
| 87 |
+
if semantic is None:
|
| 88 |
+
semantic = np.ones(results.shape[0])
|
| 89 |
+
avg_weight = 1.0
|
| 90 |
+
else:
|
| 91 |
+
# srgr == 0.165 when all success, scale range to [0, 1]
|
| 92 |
+
avg_weight = 0.165
|
| 93 |
+
results = results.reshape(-1, self.pose_dimes, self.joint_dim)
|
| 94 |
+
targets = targets.reshape(-1, self.pose_dimes, self.joint_dim)
|
| 95 |
+
semantic = semantic.reshape(-1)
|
| 96 |
+
diff = np.linalg.norm(results - targets, axis=2) # T, J
|
| 97 |
+
if verbose:
|
| 98 |
+
print(diff)
|
| 99 |
+
success = np.where(diff < self.threshold, 1.0, 0.0)
|
| 100 |
+
for i in range(success.shape[0]):
|
| 101 |
+
success[i, :] *= semantic[i] * (1 / avg_weight)
|
| 102 |
+
rate = np.sum(success) / (success.shape[0] * success.shape[1])
|
| 103 |
+
self.counter += success.shape[0]
|
| 104 |
+
self.sum += rate * success.shape[0]
|
| 105 |
+
return rate
|
| 106 |
+
|
| 107 |
+
def avg(self):
|
| 108 |
+
return self.sum / self.counter
|
| 109 |
+
|
| 110 |
+
def reset(self):
|
| 111 |
+
self.counter = 0
|
| 112 |
+
self.sum = 0
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class BC(object):
|
| 116 |
+
def __init__(self, download_path=None, sigma=0.3, order=7, upper_body=[3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]):
|
| 117 |
+
self.sigma = sigma
|
| 118 |
+
self.order = order
|
| 119 |
+
self.upper_body = upper_body
|
| 120 |
+
self.pose_data = []
|
| 121 |
+
if download_path is not None:
|
| 122 |
+
os.makedirs(download_path, exist_ok=True)
|
| 123 |
+
model_file_path = os.path.join(download_path, "mean_vel_smplxflame_30.npy")
|
| 124 |
+
if not os.path.exists(model_file_path):
|
| 125 |
+
print(f"Downloading {model_file_path}")
|
| 126 |
+
wget.download(
|
| 127 |
+
"https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/test_sequences/weights/mean_vel_smplxflame_30.npy",
|
| 128 |
+
model_file_path,
|
| 129 |
+
)
|
| 130 |
+
self.mmae = np.load(os.path.join(download_path, "mean_vel_smplxflame_30.npy")) if download_path is not None else None
|
| 131 |
+
self.threshold = 0.10
|
| 132 |
+
self.counter = 0
|
| 133 |
+
self.sum = 0
|
| 134 |
+
|
| 135 |
+
def load_audio(self, wave, t_start=None, t_end=None, without_file=False, sr_audio=16000):
|
| 136 |
+
hop_length = 512
|
| 137 |
+
if without_file:
|
| 138 |
+
y = wave
|
| 139 |
+
else:
|
| 140 |
+
y, sr = librosa.load(wave, sr=sr_audio)
|
| 141 |
+
|
| 142 |
+
short_y = y[t_start:t_end] if t_start is not None else y
|
| 143 |
+
short_y = short_y.astype(np.float32)
|
| 144 |
+
onset_t = librosa.onset.onset_detect(y=short_y, sr=sr_audio, hop_length=hop_length, units="time")
|
| 145 |
+
return onset_t
|
| 146 |
+
|
| 147 |
+
def load_motion(self, pose, t_start, t_end, pose_fps, without_file=False):
|
| 148 |
+
data_each_file = []
|
| 149 |
+
if without_file:
|
| 150 |
+
data_each_file = pose
|
| 151 |
+
else:
|
| 152 |
+
with open(pose, "r") as f:
|
| 153 |
+
for i, line_data in enumerate(f.readlines()):
|
| 154 |
+
if i < 432:
|
| 155 |
+
continue
|
| 156 |
+
line_data_np = np.fromstring(line_data, sep=" ")
|
| 157 |
+
if pose_fps == 15 and i % 2 == 0:
|
| 158 |
+
continue
|
| 159 |
+
data_each_file.append(np.concatenate([line_data_np[30:39], line_data_np[112:121]], 0))
|
| 160 |
+
data_each_file = np.array(data_each_file) # T*165
|
| 161 |
+
# print(data_each_file.shape)
|
| 162 |
+
joints = data_each_file.transpose(1, 0)
|
| 163 |
+
dt = 1 / pose_fps
|
| 164 |
+
init_vel = (joints[:, 1:2] - joints[:, :1]) / dt
|
| 165 |
+
middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt)
|
| 166 |
+
final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
|
| 167 |
+
vel = np.concatenate([init_vel, middle_vel, final_vel], 1).transpose(1, 0).reshape(data_each_file.shape[0], -1, 3)
|
| 168 |
+
# print(vel.shape)
|
| 169 |
+
|
| 170 |
+
if self.mmae is not None:
|
| 171 |
+
vel = np.linalg.norm(vel, axis=2) / self.mmae
|
| 172 |
+
else:
|
| 173 |
+
print("Warning: mmae is not provided, using max value of vel as mmae")
|
| 174 |
+
self.mmae = np.linalg.norm(vel, axis=2).max()
|
| 175 |
+
vel = np.linalg.norm(vel, axis=2) / self.mmae
|
| 176 |
+
# print(vel.shape) # T*J
|
| 177 |
+
|
| 178 |
+
beat_vel_all = []
|
| 179 |
+
for i in range(vel.shape[1]):
|
| 180 |
+
vel_mask = np.where(vel[:, i] > self.threshold)
|
| 181 |
+
beat_vel = argrelextrema(vel[t_start:t_end, i], np.less, order=self.order)
|
| 182 |
+
beat_vel_list = [j for j in beat_vel[0] if j in vel_mask[0]]
|
| 183 |
+
beat_vel_all.append(np.array(beat_vel_list))
|
| 184 |
+
return beat_vel_all
|
| 185 |
+
|
| 186 |
+
def eval_random_pose(self, wave, pose, t_start, t_end, pose_fps, num_random=60):
|
| 187 |
+
onset_raw = self.load_audio(wave, t_start, t_end)
|
| 188 |
+
dur = t_end - t_start
|
| 189 |
+
for i in range(num_random):
|
| 190 |
+
beat_vel_all = self.load_motion(pose, i, i + dur, pose_fps)
|
| 191 |
+
dis_all_b2a = self.compute(onset_raw, beat_vel_all)
|
| 192 |
+
print(f"{i}s: ", dis_all_b2a)
|
| 193 |
+
|
| 194 |
+
@staticmethod
|
| 195 |
+
def plot_onsets(audio, sr, onset_times_1, onset_times_2):
|
| 196 |
+
fig, axarr = plt.subplots(2, 1, figsize=(10, 10), sharex=True)
|
| 197 |
+
librosa.display.waveshow(audio, sr=sr, alpha=0.7, ax=axarr[0])
|
| 198 |
+
librosa.display.waveshow(audio, sr=sr, alpha=0.7, ax=axarr[1])
|
| 199 |
+
|
| 200 |
+
for onset in onset_times_1:
|
| 201 |
+
axarr[0].axvline(onset, color="r", linestyle="--", alpha=0.9, label="Onset Method 1")
|
| 202 |
+
axarr[0].legend()
|
| 203 |
+
axarr[0].set(title="Onset Method 1", xlabel="", ylabel="Amplitude")
|
| 204 |
+
|
| 205 |
+
for onset in onset_times_2:
|
| 206 |
+
axarr[1].axvline(onset, color="b", linestyle="-", alpha=0.7, label="Onset Method 2")
|
| 207 |
+
axarr[1].legend()
|
| 208 |
+
axarr[1].set(title="Onset Method 2", xlabel="Time (s)", ylabel="Amplitude")
|
| 209 |
+
|
| 210 |
+
handles, labels = plt.gca().get_legend_handles_labels()
|
| 211 |
+
by_label = dict(zip(labels, handles))
|
| 212 |
+
plt.legend(by_label.values(), by_label.keys())
|
| 213 |
+
plt.title("Audio waveform with Onsets")
|
| 214 |
+
plt.savefig("./onset.png", dpi=500)
|
| 215 |
+
|
| 216 |
+
def audio_beat_vis(self, onset_raw, onset_bt, onset_bt_rms):
|
| 217 |
+
fig, ax = plt.subplots(nrows=4, sharex=True)
|
| 218 |
+
librosa.display.specshow(librosa.amplitude_to_db(self.S, ref=np.max), y_axis="log", x_axis="time", ax=ax[0])
|
| 219 |
+
ax[1].plot(self.times, self.oenv, label="Onset strength")
|
| 220 |
+
ax[1].vlines(librosa.frames_to_time(onset_raw), 0, self.oenv.max(), label="Raw onsets", color="r")
|
| 221 |
+
ax[1].legend()
|
| 222 |
+
ax[2].vlines(librosa.frames_to_time(onset_bt), 0, self.oenv.max(), label="Backtracked", color="r")
|
| 223 |
+
ax[2].legend()
|
| 224 |
+
ax[3].vlines(librosa.frames_to_time(onset_bt_rms), 0, self.oenv.max(), label="Backtracked (RMS)", color="r")
|
| 225 |
+
ax[3].legend()
|
| 226 |
+
fig.savefig("./onset.png", dpi=500)
|
| 227 |
+
|
| 228 |
+
@staticmethod
|
| 229 |
+
def motion_frames2time(vel, offset, pose_fps):
|
| 230 |
+
return vel / pose_fps + offset
|
| 231 |
+
|
| 232 |
+
@staticmethod
|
| 233 |
+
def GAHR(a, b, sigma):
|
| 234 |
+
dis_all_b2a = 0
|
| 235 |
+
for b_each in b:
|
| 236 |
+
l2_min = min(abs(a_each - b_each) for a_each in a)
|
| 237 |
+
dis_all_b2a += math.exp(-(l2_min**2) / (2 * sigma**2))
|
| 238 |
+
return dis_all_b2a / len(b)
|
| 239 |
+
|
| 240 |
+
@staticmethod
|
| 241 |
+
def fix_directed_GAHR(a, b, sigma):
|
| 242 |
+
a = BC.motion_frames2time(a, 0, 30)
|
| 243 |
+
b = BC.motion_frames2time(b, 0, 30)
|
| 244 |
+
a = [0] + a + [len(a) / 30]
|
| 245 |
+
b = [0] + b + [len(b) / 30]
|
| 246 |
+
return BC.GAHR(a, b, sigma)
|
| 247 |
+
|
| 248 |
+
def compute(self, onset_bt_rms, beat_vel, length=1, pose_fps=30):
|
| 249 |
+
avg_dis_all_b2a_list = []
|
| 250 |
+
for its, beat_vel_each in enumerate(beat_vel):
|
| 251 |
+
if its not in self.upper_body:
|
| 252 |
+
continue
|
| 253 |
+
if beat_vel_each.size == 0:
|
| 254 |
+
avg_dis_all_b2a_list.append(0)
|
| 255 |
+
continue
|
| 256 |
+
pose_bt = self.motion_frames2time(beat_vel_each, 0, pose_fps)
|
| 257 |
+
avg_dis_all_b2a_list.append(self.GAHR(pose_bt, onset_bt_rms, self.sigma))
|
| 258 |
+
self.sum += (sum(avg_dis_all_b2a_list) / len(self.upper_body)) * length
|
| 259 |
+
self.counter += length
|
| 260 |
+
|
| 261 |
+
def avg(self):
|
| 262 |
+
return self.sum / self.counter
|
| 263 |
+
|
| 264 |
+
def reset(self):
|
| 265 |
+
self.counter = 0
|
| 266 |
+
self.sum = 0
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class Arg(object):
|
| 270 |
+
def __init__(self):
|
| 271 |
+
self.vae_length = 240
|
| 272 |
+
self.vae_test_dim = 330
|
| 273 |
+
self.vae_test_len = 32
|
| 274 |
+
self.vae_layer = 4
|
| 275 |
+
self.vae_test_stride = 20
|
| 276 |
+
self.vae_grow = [1, 1, 2, 1]
|
| 277 |
+
self.variational = False
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class FGD(object):
|
| 281 |
+
def __init__(self, download_path="./emage/", device="cuda"):
|
| 282 |
+
if download_path is not None:
|
| 283 |
+
os.makedirs(download_path, exist_ok=True)
|
| 284 |
+
model_file_path = os.path.join(download_path, "AESKConv_240_100.bin")
|
| 285 |
+
smplx_model_dir = os.path.join(download_path, "smplx_models", "smplx")
|
| 286 |
+
smplx_model_file_path = os.path.join(smplx_model_dir, "SMPLX_NEUTRAL_2020.npz")
|
| 287 |
+
if not os.path.exists(model_file_path):
|
| 288 |
+
print(f"Downloading {model_file_path}")
|
| 289 |
+
wget.download(
|
| 290 |
+
"https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/test_sequences/weights/AESKConv_240_100.bin",
|
| 291 |
+
model_file_path,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
os.makedirs(smplx_model_dir, exist_ok=True)
|
| 295 |
+
if not os.path.exists(smplx_model_file_path):
|
| 296 |
+
print(f"Downloading {smplx_model_file_path}")
|
| 297 |
+
wget.download(
|
| 298 |
+
"https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz",
|
| 299 |
+
smplx_model_file_path,
|
| 300 |
+
)
|
| 301 |
+
args = Arg()
|
| 302 |
+
self.eval_model = VAESKConv(args, model_save_path=download_path) # Assumes LocalEncoder is defined elsewhere
|
| 303 |
+
old_stat = torch.load(download_path + "AESKConv_240_100.bin")["model_state"]
|
| 304 |
+
new_stat = {}
|
| 305 |
+
for k, v in old_stat.items():
|
| 306 |
+
# If 'module.' is in the key, remove it
|
| 307 |
+
new_key = k.replace("module.", "") if "module." in k else k
|
| 308 |
+
new_stat[new_key] = v
|
| 309 |
+
self.eval_model.load_state_dict(new_stat)
|
| 310 |
+
|
| 311 |
+
self.eval_model.eval()
|
| 312 |
+
if torch.cuda.is_available():
|
| 313 |
+
self.eval_model.to(device)
|
| 314 |
+
|
| 315 |
+
self.pred_features = []
|
| 316 |
+
self.target_features = []
|
| 317 |
+
self.device = device
|
| 318 |
+
|
| 319 |
+
def reset(self):
|
| 320 |
+
self.pred_features = []
|
| 321 |
+
self.target_features = []
|
| 322 |
+
|
| 323 |
+
def get_feature(self, data):
|
| 324 |
+
assert len(data.shape) == 3
|
| 325 |
+
if data.shape[1] % 32 != 0:
|
| 326 |
+
drop_len = data.shape[1] % 32
|
| 327 |
+
data = data[:, :-drop_len]
|
| 328 |
+
# print(data.shape)
|
| 329 |
+
with torch.no_grad():
|
| 330 |
+
if torch.cuda.is_available():
|
| 331 |
+
data = data.to(self.device)
|
| 332 |
+
feature = self.eval_model.map2latent(data).cpu().numpy()
|
| 333 |
+
# print(feature.shape)
|
| 334 |
+
return feature
|
| 335 |
+
|
| 336 |
+
def update(self, pred, target):
|
| 337 |
+
self.pred_features.append(self.get_feature(pred))
|
| 338 |
+
self.target_features.append(self.get_feature(target))
|
| 339 |
+
|
| 340 |
+
def compute(self):
|
| 341 |
+
pred_features = np.concatenate([x.reshape(-1, x.shape[-1]) for x in self.pred_features], axis=0)
|
| 342 |
+
target_features = np.concatenate([x.reshape(-1, x.shape[-1]) for x in self.target_features], axis=0)
|
| 343 |
+
# print(pred_features.shape, target_features.shape)
|
| 344 |
+
return self.frechet_distance(pred_features, target_features)
|
| 345 |
+
|
| 346 |
+
@staticmethod
|
| 347 |
+
def frechet_distance(samples_A, samples_B, eps=1e-6):
|
| 348 |
+
mu1 = np.mean(samples_A, axis=0)
|
| 349 |
+
sigma1 = np.cov(samples_A, rowvar=False)
|
| 350 |
+
mu2 = np.mean(samples_B, axis=0)
|
| 351 |
+
sigma2 = np.cov(samples_B, rowvar=False)
|
| 352 |
+
diff = mu1 - mu2
|
| 353 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
| 354 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
| 355 |
+
if np.iscomplexobj(covmean):
|
| 356 |
+
covmean = covmean.real
|
| 357 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
|
motion_encoder.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from .skeleton_DME import SkeletonConv, SkeletonPool, find_neighbor, build_edge_topology
|
| 5 |
+
from .skeleton import SkeletonResidual
|
| 6 |
+
from .decoders import VQDecoderV3
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LocalEncoder(nn.Module):
|
| 10 |
+
def __init__(self, args, topology):
|
| 11 |
+
super(LocalEncoder, self).__init__()
|
| 12 |
+
args.channel_base = 6
|
| 13 |
+
args.activation = "tanh"
|
| 14 |
+
args.use_residual_blocks = True
|
| 15 |
+
args.z_dim = 1024
|
| 16 |
+
args.temporal_scale = 8
|
| 17 |
+
args.kernel_size = 4
|
| 18 |
+
args.num_layers = args.vae_layer
|
| 19 |
+
args.skeleton_dist = 2
|
| 20 |
+
args.extra_conv = 0
|
| 21 |
+
# check how to reflect in 1d
|
| 22 |
+
args.padding_mode = "constant"
|
| 23 |
+
args.skeleton_pool = "mean"
|
| 24 |
+
args.upsampling = "linear"
|
| 25 |
+
|
| 26 |
+
self.topologies = [topology]
|
| 27 |
+
self.channel_base = [args.channel_base]
|
| 28 |
+
|
| 29 |
+
self.channel_list = []
|
| 30 |
+
self.edge_num = [len(topology)]
|
| 31 |
+
self.pooling_list = []
|
| 32 |
+
self.layers = nn.ModuleList()
|
| 33 |
+
self.args = args
|
| 34 |
+
# self.convs = []
|
| 35 |
+
|
| 36 |
+
kernel_size = args.kernel_size
|
| 37 |
+
kernel_even = False if kernel_size % 2 else True
|
| 38 |
+
padding = (kernel_size - 1) // 2
|
| 39 |
+
bias = True
|
| 40 |
+
self.grow = args.vae_grow
|
| 41 |
+
for i in range(args.num_layers):
|
| 42 |
+
self.channel_base.append(self.channel_base[-1] * self.grow[i])
|
| 43 |
+
|
| 44 |
+
for i in range(args.num_layers):
|
| 45 |
+
seq = []
|
| 46 |
+
neighbour_list = find_neighbor(self.topologies[i], args.skeleton_dist)
|
| 47 |
+
in_channels = self.channel_base[i] * self.edge_num[i]
|
| 48 |
+
out_channels = self.channel_base[i + 1] * self.edge_num[i]
|
| 49 |
+
if i == 0:
|
| 50 |
+
self.channel_list.append(in_channels)
|
| 51 |
+
self.channel_list.append(out_channels)
|
| 52 |
+
last_pool = True if i == args.num_layers - 1 else False
|
| 53 |
+
|
| 54 |
+
# (T, J, D) => (T, J', D)
|
| 55 |
+
pool = SkeletonPool(
|
| 56 |
+
edges=self.topologies[i],
|
| 57 |
+
pooling_mode=args.skeleton_pool,
|
| 58 |
+
channels_per_edge=out_channels // len(neighbour_list),
|
| 59 |
+
last_pool=last_pool,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
if args.use_residual_blocks:
|
| 63 |
+
# (T, J, D) => (T/2, J', 2D)
|
| 64 |
+
seq.append(
|
| 65 |
+
SkeletonResidual(
|
| 66 |
+
self.topologies[i],
|
| 67 |
+
neighbour_list,
|
| 68 |
+
joint_num=self.edge_num[i],
|
| 69 |
+
in_channels=in_channels,
|
| 70 |
+
out_channels=out_channels,
|
| 71 |
+
kernel_size=kernel_size,
|
| 72 |
+
stride=2,
|
| 73 |
+
padding=padding,
|
| 74 |
+
padding_mode=args.padding_mode,
|
| 75 |
+
bias=bias,
|
| 76 |
+
extra_conv=args.extra_conv,
|
| 77 |
+
pooling_mode=args.skeleton_pool,
|
| 78 |
+
activation=args.activation,
|
| 79 |
+
last_pool=last_pool,
|
| 80 |
+
)
|
| 81 |
+
)
|
| 82 |
+
else:
|
| 83 |
+
for _ in range(args.extra_conv):
|
| 84 |
+
# (T, J, D) => (T, J, D)
|
| 85 |
+
seq.append(
|
| 86 |
+
SkeletonConv(
|
| 87 |
+
neighbour_list,
|
| 88 |
+
in_channels=in_channels,
|
| 89 |
+
out_channels=in_channels,
|
| 90 |
+
joint_num=self.edge_num[i],
|
| 91 |
+
kernel_size=kernel_size - 1 if kernel_even else kernel_size,
|
| 92 |
+
stride=1,
|
| 93 |
+
padding=padding,
|
| 94 |
+
padding_mode=args.padding_mode,
|
| 95 |
+
bias=bias,
|
| 96 |
+
)
|
| 97 |
+
)
|
| 98 |
+
seq.append(nn.PReLU() if args.activation == "relu" else nn.Tanh())
|
| 99 |
+
# (T, J, D) => (T/2, J, 2D)
|
| 100 |
+
seq.append(
|
| 101 |
+
SkeletonConv(
|
| 102 |
+
neighbour_list,
|
| 103 |
+
in_channels=in_channels,
|
| 104 |
+
out_channels=out_channels,
|
| 105 |
+
joint_num=self.edge_num[i],
|
| 106 |
+
kernel_size=kernel_size,
|
| 107 |
+
stride=2,
|
| 108 |
+
padding=padding,
|
| 109 |
+
padding_mode=args.padding_mode,
|
| 110 |
+
bias=bias,
|
| 111 |
+
add_offset=False,
|
| 112 |
+
in_offset_channel=3 * self.channel_base[i] // self.channel_base[0],
|
| 113 |
+
)
|
| 114 |
+
)
|
| 115 |
+
# self.convs.append(seq[-1])
|
| 116 |
+
|
| 117 |
+
seq.append(pool)
|
| 118 |
+
seq.append(nn.PReLU() if args.activation == "relu" else nn.Tanh())
|
| 119 |
+
self.layers.append(nn.Sequential(*seq))
|
| 120 |
+
|
| 121 |
+
self.topologies.append(pool.new_edges)
|
| 122 |
+
self.pooling_list.append(pool.pooling_list)
|
| 123 |
+
self.edge_num.append(len(self.topologies[-1]))
|
| 124 |
+
|
| 125 |
+
# in_features = self.channel_base[-1] * len(self.pooling_list[-1])
|
| 126 |
+
# in_features *= int(args.temporal_scale / 2)
|
| 127 |
+
# self.reduce = nn.Linear(in_features, args.z_dim)
|
| 128 |
+
# self.mu = nn.Linear(in_features, args.z_dim)
|
| 129 |
+
# self.logvar = nn.Linear(in_features, args.z_dim)
|
| 130 |
+
|
| 131 |
+
def forward(self, input):
|
| 132 |
+
# bs, n, c = input.shape[0], input.shape[1], input.shape[2]
|
| 133 |
+
output = input.permute(0, 2, 1) # input.reshape(bs, n, -1, 6)
|
| 134 |
+
for layer in self.layers:
|
| 135 |
+
output = layer(output)
|
| 136 |
+
# output = output.view(output.shape[0], -1)
|
| 137 |
+
output = output.permute(0, 2, 1)
|
| 138 |
+
return output
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def reparameterize(mu, logvar):
|
| 142 |
+
std = torch.exp(0.5 * logvar)
|
| 143 |
+
eps = torch.randn_like(std)
|
| 144 |
+
return mu + eps * std
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class VAEConv(nn.Module):
|
| 148 |
+
def __init__(self, args):
|
| 149 |
+
super(VAEConv, self).__init__()
|
| 150 |
+
# self.encoder = VQEncoderV3(args)
|
| 151 |
+
# self.decoder = VQDecoderV3(args)
|
| 152 |
+
self.fc_mu = nn.Linear(args.vae_length, args.vae_length)
|
| 153 |
+
self.fc_logvar = nn.Linear(args.vae_length, args.vae_length)
|
| 154 |
+
self.variational = args.variational
|
| 155 |
+
|
| 156 |
+
def forward(self, inputs):
|
| 157 |
+
pre_latent = self.encoder(inputs)
|
| 158 |
+
mu, logvar = None, None
|
| 159 |
+
if self.variational:
|
| 160 |
+
mu = self.fc_mu(pre_latent)
|
| 161 |
+
logvar = self.fc_logvar(pre_latent)
|
| 162 |
+
pre_latent = reparameterize(mu, logvar)
|
| 163 |
+
rec_pose = self.decoder(pre_latent)
|
| 164 |
+
return {
|
| 165 |
+
"poses_feat": pre_latent,
|
| 166 |
+
"rec_pose": rec_pose,
|
| 167 |
+
"pose_mu": mu,
|
| 168 |
+
"pose_logvar": logvar,
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
def map2latent(self, inputs):
|
| 172 |
+
pre_latent = self.encoder(inputs)
|
| 173 |
+
if self.variational:
|
| 174 |
+
mu = self.fc_mu(pre_latent)
|
| 175 |
+
logvar = self.fc_logvar(pre_latent)
|
| 176 |
+
pre_latent = reparameterize(mu, logvar)
|
| 177 |
+
return pre_latent
|
| 178 |
+
|
| 179 |
+
def decode(self, pre_latent):
|
| 180 |
+
rec_pose = self.decoder(pre_latent)
|
| 181 |
+
return rec_pose
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class VAESKConv(VAEConv):
|
| 185 |
+
def __init__(self, args, model_save_path="./emage/"):
|
| 186 |
+
# args = args()
|
| 187 |
+
super(VAESKConv, self).__init__(args)
|
| 188 |
+
smpl_fname = model_save_path + "smplx_models/smplx/SMPLX_NEUTRAL_2020.npz"
|
| 189 |
+
smpl_data = np.load(smpl_fname, encoding="latin1")
|
| 190 |
+
parents = smpl_data["kintree_table"][0].astype(np.int32)
|
| 191 |
+
edges = build_edge_topology(parents)
|
| 192 |
+
self.encoder = LocalEncoder(args, edges)
|
| 193 |
+
self.decoder = VQDecoderV3(args)
|
skeleton.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from .skeleton_DME import SkeletonConv, SkeletonPool, SkeletonUnpool
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def calc_node_depth(topology):
|
| 8 |
+
def dfs(node, topology):
|
| 9 |
+
if topology[node] < 0:
|
| 10 |
+
return 0
|
| 11 |
+
return 1 + dfs(topology[node], topology)
|
| 12 |
+
|
| 13 |
+
depth = []
|
| 14 |
+
for i in range(len(topology)):
|
| 15 |
+
depth.append(dfs(i, topology))
|
| 16 |
+
|
| 17 |
+
return depth
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def residual_ratio(k):
|
| 21 |
+
return 1 / (k + 1)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Affine(nn.Module):
|
| 25 |
+
def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0):
|
| 26 |
+
super(Affine, self).__init__()
|
| 27 |
+
if scale:
|
| 28 |
+
self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init)
|
| 29 |
+
else:
|
| 30 |
+
self.register_parameter("scale", None)
|
| 31 |
+
|
| 32 |
+
if bias:
|
| 33 |
+
self.bias = nn.Parameter(torch.zeros(num_parameters))
|
| 34 |
+
else:
|
| 35 |
+
self.register_parameter("bias", None)
|
| 36 |
+
|
| 37 |
+
def forward(self, input):
|
| 38 |
+
output = input
|
| 39 |
+
if self.scale is not None:
|
| 40 |
+
scale = self.scale.unsqueeze(0)
|
| 41 |
+
while scale.dim() < input.dim():
|
| 42 |
+
scale = scale.unsqueeze(2)
|
| 43 |
+
output = output.mul(scale)
|
| 44 |
+
|
| 45 |
+
if self.bias is not None:
|
| 46 |
+
bias = self.bias.unsqueeze(0)
|
| 47 |
+
while bias.dim() < input.dim():
|
| 48 |
+
bias = bias.unsqueeze(2)
|
| 49 |
+
output += bias
|
| 50 |
+
|
| 51 |
+
return output
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class BatchStatistics(nn.Module):
|
| 55 |
+
def __init__(self, affine=-1):
|
| 56 |
+
super(BatchStatistics, self).__init__()
|
| 57 |
+
self.affine = nn.Sequential() if affine == -1 else Affine(affine)
|
| 58 |
+
self.loss = 0
|
| 59 |
+
|
| 60 |
+
def clear_loss(self):
|
| 61 |
+
self.loss = 0
|
| 62 |
+
|
| 63 |
+
def compute_loss(self, input):
|
| 64 |
+
input_flat = input.view(input.size(1), input.numel() // input.size(1))
|
| 65 |
+
mu = input_flat.mean(1)
|
| 66 |
+
logvar = (input_flat.pow(2).mean(1) - mu.pow(2)).sqrt().log()
|
| 67 |
+
|
| 68 |
+
self.loss = mu.pow(2).mean() + logvar.pow(2).mean()
|
| 69 |
+
|
| 70 |
+
def forward(self, input):
|
| 71 |
+
self.compute_loss(input)
|
| 72 |
+
return self.affine(input)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ResidualBlock(nn.Module):
|
| 76 |
+
def __init__(
|
| 77 |
+
self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation, batch_statistics=False, last_layer=False
|
| 78 |
+
):
|
| 79 |
+
super(ResidualBlock, self).__init__()
|
| 80 |
+
|
| 81 |
+
self.residual_ratio = residual_ratio
|
| 82 |
+
self.shortcut_ratio = 1 - residual_ratio
|
| 83 |
+
|
| 84 |
+
residual = []
|
| 85 |
+
residual.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding))
|
| 86 |
+
if batch_statistics:
|
| 87 |
+
residual.append(BatchStatistics(out_channels))
|
| 88 |
+
if not last_layer:
|
| 89 |
+
residual.append(nn.PReLU() if activation == "relu" else nn.Tanh())
|
| 90 |
+
self.residual = nn.Sequential(*residual)
|
| 91 |
+
|
| 92 |
+
self.shortcut = nn.Sequential(
|
| 93 |
+
nn.AvgPool1d(kernel_size=2) if stride == 2 else nn.Sequential(),
|
| 94 |
+
nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
|
| 95 |
+
BatchStatistics(out_channels) if (in_channels != out_channels and batch_statistics is True) else nn.Sequential(),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
def forward(self, input):
|
| 99 |
+
return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class ResidualBlockTranspose(nn.Module):
|
| 103 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation):
|
| 104 |
+
super(ResidualBlockTranspose, self).__init__()
|
| 105 |
+
|
| 106 |
+
self.residual_ratio = residual_ratio
|
| 107 |
+
self.shortcut_ratio = 1 - residual_ratio
|
| 108 |
+
|
| 109 |
+
self.residual = nn.Sequential(
|
| 110 |
+
nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding), nn.PReLU() if activation == "relu" else nn.Tanh()
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
self.shortcut = nn.Sequential(
|
| 114 |
+
nn.Upsample(scale_factor=2, mode="linear", align_corners=False) if stride == 2 else nn.Sequential(),
|
| 115 |
+
nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def forward(self, input):
|
| 119 |
+
return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class SkeletonResidual(nn.Module):
|
| 123 |
+
def __init__(
|
| 124 |
+
self,
|
| 125 |
+
topology,
|
| 126 |
+
neighbour_list,
|
| 127 |
+
joint_num,
|
| 128 |
+
in_channels,
|
| 129 |
+
out_channels,
|
| 130 |
+
kernel_size,
|
| 131 |
+
stride,
|
| 132 |
+
padding,
|
| 133 |
+
padding_mode,
|
| 134 |
+
bias,
|
| 135 |
+
extra_conv,
|
| 136 |
+
pooling_mode,
|
| 137 |
+
activation,
|
| 138 |
+
last_pool,
|
| 139 |
+
):
|
| 140 |
+
super(SkeletonResidual, self).__init__()
|
| 141 |
+
|
| 142 |
+
kernel_even = False if kernel_size % 2 else True
|
| 143 |
+
|
| 144 |
+
seq = []
|
| 145 |
+
for _ in range(extra_conv):
|
| 146 |
+
# (T, J, D) => (T, J, D)
|
| 147 |
+
seq.append(
|
| 148 |
+
SkeletonConv(
|
| 149 |
+
neighbour_list,
|
| 150 |
+
in_channels=in_channels,
|
| 151 |
+
out_channels=in_channels,
|
| 152 |
+
joint_num=joint_num,
|
| 153 |
+
kernel_size=kernel_size - 1 if kernel_even else kernel_size,
|
| 154 |
+
stride=1,
|
| 155 |
+
padding=padding,
|
| 156 |
+
padding_mode=padding_mode,
|
| 157 |
+
bias=bias,
|
| 158 |
+
)
|
| 159 |
+
)
|
| 160 |
+
seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
|
| 161 |
+
# (T, J, D) => (T/2, J, 2D)
|
| 162 |
+
seq.append(
|
| 163 |
+
SkeletonConv(
|
| 164 |
+
neighbour_list,
|
| 165 |
+
in_channels=in_channels,
|
| 166 |
+
out_channels=out_channels,
|
| 167 |
+
joint_num=joint_num,
|
| 168 |
+
kernel_size=kernel_size,
|
| 169 |
+
stride=stride,
|
| 170 |
+
padding=padding,
|
| 171 |
+
padding_mode=padding_mode,
|
| 172 |
+
bias=bias,
|
| 173 |
+
add_offset=False,
|
| 174 |
+
)
|
| 175 |
+
)
|
| 176 |
+
seq.append(nn.GroupNorm(10, out_channels)) # FIXME: REMEMBER TO CHANGE BACK !!!
|
| 177 |
+
self.residual = nn.Sequential(*seq)
|
| 178 |
+
|
| 179 |
+
# (T, J, D) => (T/2, J, 2D)
|
| 180 |
+
self.shortcut = SkeletonConv(
|
| 181 |
+
neighbour_list,
|
| 182 |
+
in_channels=in_channels,
|
| 183 |
+
out_channels=out_channels,
|
| 184 |
+
joint_num=joint_num,
|
| 185 |
+
kernel_size=1,
|
| 186 |
+
stride=stride,
|
| 187 |
+
padding=0,
|
| 188 |
+
bias=True,
|
| 189 |
+
add_offset=False,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
seq = []
|
| 193 |
+
# (T/2, J, 2D) => (T/2, J', 2D)
|
| 194 |
+
pool = SkeletonPool(
|
| 195 |
+
edges=topology, pooling_mode=pooling_mode, channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool
|
| 196 |
+
)
|
| 197 |
+
if len(pool.pooling_list) != pool.edge_num:
|
| 198 |
+
seq.append(pool)
|
| 199 |
+
seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
|
| 200 |
+
self.common = nn.Sequential(*seq)
|
| 201 |
+
|
| 202 |
+
def forward(self, input):
|
| 203 |
+
output = self.residual(input) + self.shortcut(input)
|
| 204 |
+
|
| 205 |
+
return self.common(output)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class SkeletonResidualTranspose(nn.Module):
|
| 209 |
+
def __init__(
|
| 210 |
+
self,
|
| 211 |
+
neighbour_list,
|
| 212 |
+
joint_num,
|
| 213 |
+
in_channels,
|
| 214 |
+
out_channels,
|
| 215 |
+
kernel_size,
|
| 216 |
+
padding,
|
| 217 |
+
padding_mode,
|
| 218 |
+
bias,
|
| 219 |
+
extra_conv,
|
| 220 |
+
pooling_list,
|
| 221 |
+
upsampling,
|
| 222 |
+
activation,
|
| 223 |
+
last_layer,
|
| 224 |
+
):
|
| 225 |
+
super(SkeletonResidualTranspose, self).__init__()
|
| 226 |
+
|
| 227 |
+
kernel_even = False if kernel_size % 2 else True
|
| 228 |
+
|
| 229 |
+
seq = []
|
| 230 |
+
# (T, J, D) => (2T, J, D)
|
| 231 |
+
if upsampling is not None:
|
| 232 |
+
seq.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=False))
|
| 233 |
+
# (2T, J, D) => (2T, J', D)
|
| 234 |
+
unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list))
|
| 235 |
+
if unpool.input_edge_num != unpool.output_edge_num:
|
| 236 |
+
seq.append(unpool)
|
| 237 |
+
self.common = nn.Sequential(*seq)
|
| 238 |
+
|
| 239 |
+
seq = []
|
| 240 |
+
for _ in range(extra_conv):
|
| 241 |
+
# (2T, J', D) => (2T, J', D)
|
| 242 |
+
seq.append(
|
| 243 |
+
SkeletonConv(
|
| 244 |
+
neighbour_list,
|
| 245 |
+
in_channels=in_channels,
|
| 246 |
+
out_channels=in_channels,
|
| 247 |
+
joint_num=joint_num,
|
| 248 |
+
kernel_size=kernel_size - 1 if kernel_even else kernel_size,
|
| 249 |
+
stride=1,
|
| 250 |
+
padding=padding,
|
| 251 |
+
padding_mode=padding_mode,
|
| 252 |
+
bias=bias,
|
| 253 |
+
)
|
| 254 |
+
)
|
| 255 |
+
seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
|
| 256 |
+
# (2T, J', D) => (2T, J', D/2)
|
| 257 |
+
seq.append(
|
| 258 |
+
SkeletonConv(
|
| 259 |
+
neighbour_list,
|
| 260 |
+
in_channels=in_channels,
|
| 261 |
+
out_channels=out_channels,
|
| 262 |
+
joint_num=joint_num,
|
| 263 |
+
kernel_size=kernel_size - 1 if kernel_even else kernel_size,
|
| 264 |
+
stride=1,
|
| 265 |
+
padding=padding,
|
| 266 |
+
padding_mode=padding_mode,
|
| 267 |
+
bias=bias,
|
| 268 |
+
add_offset=False,
|
| 269 |
+
)
|
| 270 |
+
)
|
| 271 |
+
self.residual = nn.Sequential(*seq)
|
| 272 |
+
|
| 273 |
+
# (2T, J', D) => (2T, J', D/2)
|
| 274 |
+
self.shortcut = SkeletonConv(
|
| 275 |
+
neighbour_list,
|
| 276 |
+
in_channels=in_channels,
|
| 277 |
+
out_channels=out_channels,
|
| 278 |
+
joint_num=joint_num,
|
| 279 |
+
kernel_size=1,
|
| 280 |
+
stride=1,
|
| 281 |
+
padding=0,
|
| 282 |
+
bias=True,
|
| 283 |
+
add_offset=False,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
if activation == "relu":
|
| 287 |
+
self.activation = nn.PReLU() if not last_layer else None
|
| 288 |
+
else:
|
| 289 |
+
self.activation = nn.Tanh() if not last_layer else None
|
| 290 |
+
|
| 291 |
+
def forward(self, input):
|
| 292 |
+
output = self.common(input)
|
| 293 |
+
output = self.residual(output) + self.shortcut(output)
|
| 294 |
+
|
| 295 |
+
if self.activation is not None:
|
| 296 |
+
return self.activation(output)
|
| 297 |
+
else:
|
| 298 |
+
return output
|
skeleton_DME.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This script is modified from https://github.com/DeepMotionEditing/deep-motion-editing
|
| 2 |
+
# Licensed under:
|
| 3 |
+
"""
|
| 4 |
+
Copyright (c) 2020, Kfir Aberman, Peizhuo Li, Yijia Weng, Dani Lischinski, Olga Sorkine-Hornung, Daniel Cohen-Or and Baoquan Chen.
|
| 5 |
+
All rights reserved.
|
| 6 |
+
|
| 7 |
+
Redistribution and use in source and binary forms, with or without
|
| 8 |
+
modification, are permitted provided that the following conditions are met:
|
| 9 |
+
|
| 10 |
+
* Redistributions of source code must retain the above copyright notice, this
|
| 11 |
+
list of conditions and the following disclaimer.
|
| 12 |
+
|
| 13 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
| 14 |
+
this list of conditions and the following disclaimer in the documentation
|
| 15 |
+
and/or other materials provided with the distribution.
|
| 16 |
+
|
| 17 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 18 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 19 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 20 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 21 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 22 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 23 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 24 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 25 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 26 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import math
|
| 30 |
+
import numpy as np
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SkeletonConv(nn.Module):
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
neighbour_list,
|
| 40 |
+
in_channels,
|
| 41 |
+
out_channels,
|
| 42 |
+
kernel_size,
|
| 43 |
+
joint_num,
|
| 44 |
+
stride=1,
|
| 45 |
+
padding=0,
|
| 46 |
+
bias=True,
|
| 47 |
+
padding_mode="zeros",
|
| 48 |
+
add_offset=False,
|
| 49 |
+
in_offset_channel=0,
|
| 50 |
+
):
|
| 51 |
+
self.in_channels_per_joint = in_channels // joint_num
|
| 52 |
+
self.out_channels_per_joint = out_channels // joint_num
|
| 53 |
+
if in_channels % joint_num != 0 or out_channels % joint_num != 0:
|
| 54 |
+
raise Exception("BAD")
|
| 55 |
+
super(SkeletonConv, self).__init__()
|
| 56 |
+
|
| 57 |
+
if padding_mode == "zeros":
|
| 58 |
+
padding_mode = "constant"
|
| 59 |
+
if padding_mode == "reflection":
|
| 60 |
+
padding_mode = "reflect"
|
| 61 |
+
|
| 62 |
+
self.expanded_neighbour_list = []
|
| 63 |
+
self.expanded_neighbour_list_offset = []
|
| 64 |
+
self.neighbour_list = neighbour_list
|
| 65 |
+
self.add_offset = add_offset
|
| 66 |
+
self.joint_num = joint_num
|
| 67 |
+
|
| 68 |
+
self.stride = stride
|
| 69 |
+
self.dilation = 1
|
| 70 |
+
self.groups = 1
|
| 71 |
+
self.padding = padding
|
| 72 |
+
self.padding_mode = padding_mode
|
| 73 |
+
self._padding_repeated_twice = (padding, padding)
|
| 74 |
+
|
| 75 |
+
for neighbour in neighbour_list:
|
| 76 |
+
expanded = []
|
| 77 |
+
for k in neighbour:
|
| 78 |
+
for i in range(self.in_channels_per_joint):
|
| 79 |
+
expanded.append(k * self.in_channels_per_joint + i)
|
| 80 |
+
self.expanded_neighbour_list.append(expanded)
|
| 81 |
+
|
| 82 |
+
if self.add_offset:
|
| 83 |
+
self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)
|
| 84 |
+
|
| 85 |
+
for neighbour in neighbour_list:
|
| 86 |
+
expanded = []
|
| 87 |
+
for k in neighbour:
|
| 88 |
+
for i in range(add_offset):
|
| 89 |
+
expanded.append(k * in_offset_channel + i)
|
| 90 |
+
self.expanded_neighbour_list_offset.append(expanded)
|
| 91 |
+
|
| 92 |
+
self.weight = torch.zeros(out_channels, in_channels, kernel_size)
|
| 93 |
+
if bias:
|
| 94 |
+
self.bias = torch.zeros(out_channels)
|
| 95 |
+
else:
|
| 96 |
+
self.register_parameter("bias", None)
|
| 97 |
+
|
| 98 |
+
self.mask = torch.zeros_like(self.weight)
|
| 99 |
+
for i, neighbour in enumerate(self.expanded_neighbour_list):
|
| 100 |
+
self.mask[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
|
| 101 |
+
self.mask = nn.Parameter(self.mask, requires_grad=False)
|
| 102 |
+
|
| 103 |
+
self.description = (
|
| 104 |
+
"SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, "
|
| 105 |
+
"joint_num={}, stride={}, padding={}, bias={})".format(
|
| 106 |
+
in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
|
| 107 |
+
)
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
self.reset_parameters()
|
| 111 |
+
|
| 112 |
+
def reset_parameters(self):
|
| 113 |
+
for i, neighbour in enumerate(self.expanded_neighbour_list):
|
| 114 |
+
""" Use temporary variable to avoid assign to copy of slice, which might lead to unexpected result """
|
| 115 |
+
tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1), neighbour, ...])
|
| 116 |
+
nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
|
| 117 |
+
self.weight[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1), neighbour, ...] = tmp
|
| 118 |
+
if self.bias is not None:
|
| 119 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
|
| 120 |
+
self.weight[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1), neighbour, ...]
|
| 121 |
+
)
|
| 122 |
+
bound = 1 / math.sqrt(fan_in)
|
| 123 |
+
tmp = torch.zeros_like(self.bias[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1)])
|
| 124 |
+
nn.init.uniform_(tmp, -bound, bound)
|
| 125 |
+
self.bias[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1)] = tmp
|
| 126 |
+
|
| 127 |
+
self.weight = nn.Parameter(self.weight)
|
| 128 |
+
if self.bias is not None:
|
| 129 |
+
self.bias = nn.Parameter(self.bias)
|
| 130 |
+
|
| 131 |
+
def set_offset(self, offset):
|
| 132 |
+
if not self.add_offset:
|
| 133 |
+
raise Exception("Wrong Combination of Parameters")
|
| 134 |
+
self.offset = offset.reshape(offset.shape[0], -1)
|
| 135 |
+
|
| 136 |
+
def forward(self, input):
|
| 137 |
+
# print('SkeletonConv')
|
| 138 |
+
weight_masked = self.weight * self.mask
|
| 139 |
+
# print(f'input: {input.size()}')
|
| 140 |
+
res = F.conv1d(
|
| 141 |
+
F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
|
| 142 |
+
weight_masked,
|
| 143 |
+
self.bias,
|
| 144 |
+
self.stride,
|
| 145 |
+
0,
|
| 146 |
+
self.dilation,
|
| 147 |
+
self.groups,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
if self.add_offset:
|
| 151 |
+
offset_res = self.offset_enc(self.offset)
|
| 152 |
+
offset_res = offset_res.reshape(offset_res.shape + (1,))
|
| 153 |
+
res += offset_res / 100
|
| 154 |
+
# print(f'res: {res.size()}')
|
| 155 |
+
return res
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class SkeletonLinear(nn.Module):
|
| 159 |
+
def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
|
| 160 |
+
super(SkeletonLinear, self).__init__()
|
| 161 |
+
self.neighbour_list = neighbour_list
|
| 162 |
+
self.in_channels = in_channels
|
| 163 |
+
self.out_channels = out_channels
|
| 164 |
+
self.in_channels_per_joint = in_channels // len(neighbour_list)
|
| 165 |
+
self.out_channels_per_joint = out_channels // len(neighbour_list)
|
| 166 |
+
self.extra_dim1 = extra_dim1
|
| 167 |
+
self.expanded_neighbour_list = []
|
| 168 |
+
|
| 169 |
+
for neighbour in neighbour_list:
|
| 170 |
+
expanded = []
|
| 171 |
+
for k in neighbour:
|
| 172 |
+
for i in range(self.in_channels_per_joint):
|
| 173 |
+
expanded.append(k * self.in_channels_per_joint + i)
|
| 174 |
+
self.expanded_neighbour_list.append(expanded)
|
| 175 |
+
|
| 176 |
+
self.weight = torch.zeros(out_channels, in_channels)
|
| 177 |
+
self.mask = torch.zeros(out_channels, in_channels)
|
| 178 |
+
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
| 179 |
+
|
| 180 |
+
self.reset_parameters()
|
| 181 |
+
|
| 182 |
+
def reset_parameters(self):
|
| 183 |
+
for i, neighbour in enumerate(self.expanded_neighbour_list):
|
| 184 |
+
tmp = torch.zeros_like(self.weight[i * self.out_channels_per_joint : (i + 1) * self.out_channels_per_joint, neighbour])
|
| 185 |
+
self.mask[i * self.out_channels_per_joint : (i + 1) * self.out_channels_per_joint, neighbour] = 1
|
| 186 |
+
nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
|
| 187 |
+
self.weight[i * self.out_channels_per_joint : (i + 1) * self.out_channels_per_joint, neighbour] = tmp
|
| 188 |
+
|
| 189 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
| 190 |
+
bound = 1 / math.sqrt(fan_in)
|
| 191 |
+
nn.init.uniform_(self.bias, -bound, bound)
|
| 192 |
+
|
| 193 |
+
self.weight = nn.Parameter(self.weight)
|
| 194 |
+
self.mask = nn.Parameter(self.mask, requires_grad=False)
|
| 195 |
+
|
| 196 |
+
def forward(self, input):
|
| 197 |
+
input = input.reshape(input.shape[0], -1)
|
| 198 |
+
weight_masked = self.weight * self.mask
|
| 199 |
+
res = F.linear(input, weight_masked, self.bias)
|
| 200 |
+
if self.extra_dim1:
|
| 201 |
+
res = res.reshape(res.shape + (1,))
|
| 202 |
+
return res
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class SkeletonPool(nn.Module):
|
| 206 |
+
def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
|
| 207 |
+
super(SkeletonPool, self).__init__()
|
| 208 |
+
|
| 209 |
+
if pooling_mode != "mean":
|
| 210 |
+
raise Exception("Unimplemented pooling mode in matrix_implementation")
|
| 211 |
+
|
| 212 |
+
self.channels_per_edge = channels_per_edge
|
| 213 |
+
self.pooling_mode = pooling_mode
|
| 214 |
+
self.edge_num = len(edges)
|
| 215 |
+
# self.edge_num = len(edges) + 1
|
| 216 |
+
self.seq_list = []
|
| 217 |
+
self.pooling_list = []
|
| 218 |
+
self.new_edges = []
|
| 219 |
+
degree = [0] * 100 # each element represents the degree of the corresponding joint
|
| 220 |
+
|
| 221 |
+
for edge in edges:
|
| 222 |
+
degree[edge[0]] += 1
|
| 223 |
+
degree[edge[1]] += 1
|
| 224 |
+
|
| 225 |
+
# seq_list contains multiple sub-lists where each sub-list is an edge chain from the joint whose degree > 2 to the end effectors or joints whose degree > 2.
|
| 226 |
+
def find_seq(j, seq):
|
| 227 |
+
nonlocal self, degree, edges
|
| 228 |
+
|
| 229 |
+
if degree[j] > 2 and j != 0:
|
| 230 |
+
self.seq_list.append(seq)
|
| 231 |
+
seq = []
|
| 232 |
+
|
| 233 |
+
if degree[j] == 1:
|
| 234 |
+
self.seq_list.append(seq)
|
| 235 |
+
return
|
| 236 |
+
|
| 237 |
+
for idx, edge in enumerate(edges):
|
| 238 |
+
if edge[0] == j:
|
| 239 |
+
find_seq(edge[1], seq + [idx])
|
| 240 |
+
|
| 241 |
+
find_seq(0, [])
|
| 242 |
+
# print(f'self.seq_list: {self.seq_list}')
|
| 243 |
+
|
| 244 |
+
for seq in self.seq_list:
|
| 245 |
+
if last_pool:
|
| 246 |
+
self.pooling_list.append(seq)
|
| 247 |
+
continue
|
| 248 |
+
if len(seq) % 2 == 1:
|
| 249 |
+
self.pooling_list.append([seq[0]])
|
| 250 |
+
self.new_edges.append(edges[seq[0]])
|
| 251 |
+
seq = seq[1:]
|
| 252 |
+
for i in range(0, len(seq), 2):
|
| 253 |
+
self.pooling_list.append([seq[i], seq[i + 1]])
|
| 254 |
+
self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])
|
| 255 |
+
# print(f'self.pooling_list: {self.pooling_list}')
|
| 256 |
+
# print(f'self.new_egdes: {self.new_edges}')
|
| 257 |
+
|
| 258 |
+
# add global position
|
| 259 |
+
# self.pooling_list.append([self.edge_num - 1])
|
| 260 |
+
|
| 261 |
+
self.description = "SkeletonPool(in_edge_num={}, out_edge_num={})".format(len(edges), len(self.pooling_list))
|
| 262 |
+
|
| 263 |
+
self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)
|
| 264 |
+
|
| 265 |
+
for i, pair in enumerate(self.pooling_list):
|
| 266 |
+
for j in pair:
|
| 267 |
+
for c in range(channels_per_edge):
|
| 268 |
+
self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)
|
| 269 |
+
|
| 270 |
+
self.weight = nn.Parameter(self.weight, requires_grad=False)
|
| 271 |
+
|
| 272 |
+
def forward(self, input: torch.Tensor):
|
| 273 |
+
# print('SkeletonPool')
|
| 274 |
+
# print(f'input: {input.size()}')
|
| 275 |
+
# print(f'self.weight: {self.weight.size()}')
|
| 276 |
+
return torch.matmul(self.weight, input)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class SkeletonUnpool(nn.Module):
|
| 280 |
+
def __init__(self, pooling_list, channels_per_edge):
|
| 281 |
+
super(SkeletonUnpool, self).__init__()
|
| 282 |
+
self.pooling_list = pooling_list
|
| 283 |
+
self.input_edge_num = len(pooling_list)
|
| 284 |
+
self.output_edge_num = 0
|
| 285 |
+
self.channels_per_edge = channels_per_edge
|
| 286 |
+
for t in self.pooling_list:
|
| 287 |
+
self.output_edge_num += len(t)
|
| 288 |
+
|
| 289 |
+
self.description = "SkeletonUnpool(in_edge_num={}, out_edge_num={})".format(
|
| 290 |
+
self.input_edge_num,
|
| 291 |
+
self.output_edge_num,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
self.weight = torch.zeros(self.output_edge_num * channels_per_edge, self.input_edge_num * channels_per_edge)
|
| 295 |
+
|
| 296 |
+
for i, pair in enumerate(self.pooling_list):
|
| 297 |
+
for j in pair:
|
| 298 |
+
for c in range(channels_per_edge):
|
| 299 |
+
self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1
|
| 300 |
+
|
| 301 |
+
self.weight = nn.Parameter(self.weight)
|
| 302 |
+
self.weight.requires_grad_(False)
|
| 303 |
+
|
| 304 |
+
def forward(self, input: torch.Tensor):
|
| 305 |
+
# print('SkeletonUnpool')
|
| 306 |
+
# print(f'input: {input.size()}')
|
| 307 |
+
# print(f'self.weight: {self.weight.size()}')
|
| 308 |
+
return torch.matmul(self.weight, input)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
"""
|
| 312 |
+
Helper functions for skeleton operation
|
| 313 |
+
"""
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def dfs(x, fa, vis, dist):
|
| 317 |
+
vis[x] = 1
|
| 318 |
+
for y in range(len(fa)):
|
| 319 |
+
if (fa[y] == x or fa[x] == y) and vis[y] == 0:
|
| 320 |
+
dist[y] = dist[x] + 1
|
| 321 |
+
dfs(y, fa, vis, dist)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
"""
|
| 325 |
+
def find_neighbor_joint(fa, threshold):
|
| 326 |
+
neighbor_list = [[]]
|
| 327 |
+
for x in range(1, len(fa)):
|
| 328 |
+
vis = [0 for _ in range(len(fa))]
|
| 329 |
+
dist = [0 for _ in range(len(fa))]
|
| 330 |
+
dist[0] = 10000
|
| 331 |
+
dfs(x, fa, vis, dist)
|
| 332 |
+
neighbor = []
|
| 333 |
+
for j in range(1, len(fa)):
|
| 334 |
+
if dist[j] <= threshold:
|
| 335 |
+
neighbor.append(j)
|
| 336 |
+
neighbor_list.append(neighbor)
|
| 337 |
+
|
| 338 |
+
neighbor = [0]
|
| 339 |
+
for i, x in enumerate(neighbor_list):
|
| 340 |
+
if i == 0: continue
|
| 341 |
+
if 1 in x:
|
| 342 |
+
neighbor.append(i)
|
| 343 |
+
neighbor_list[i] = [0] + neighbor_list[i]
|
| 344 |
+
neighbor_list[0] = neighbor
|
| 345 |
+
return neighbor_list
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def build_edge_topology(topology, offset):
|
| 349 |
+
# get all edges (pa, child, offset)
|
| 350 |
+
edges = []
|
| 351 |
+
joint_num = len(topology)
|
| 352 |
+
for i in range(1, joint_num):
|
| 353 |
+
edges.append((topology[i], i, offset[i]))
|
| 354 |
+
return edges
|
| 355 |
+
"""
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def build_edge_topology(topology):
|
| 359 |
+
# get all edges (pa, child)
|
| 360 |
+
edges = []
|
| 361 |
+
joint_num = len(topology)
|
| 362 |
+
edges.append((0, joint_num)) # add an edge between the root joint and a virtual joint
|
| 363 |
+
for i in range(1, joint_num):
|
| 364 |
+
edges.append((topology[i], i))
|
| 365 |
+
return edges
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def build_joint_topology(edges, origin_names):
|
| 369 |
+
parent = []
|
| 370 |
+
offset = []
|
| 371 |
+
names = []
|
| 372 |
+
edge2joint = []
|
| 373 |
+
joint_from_edge = [] # -1 means virtual joint
|
| 374 |
+
joint_cnt = 0
|
| 375 |
+
out_degree = [0] * (len(edges) + 10)
|
| 376 |
+
for edge in edges:
|
| 377 |
+
out_degree[edge[0]] += 1
|
| 378 |
+
|
| 379 |
+
# add root joint
|
| 380 |
+
joint_from_edge.append(-1)
|
| 381 |
+
parent.append(0)
|
| 382 |
+
offset.append(np.array([0, 0, 0]))
|
| 383 |
+
names.append(origin_names[0])
|
| 384 |
+
joint_cnt += 1
|
| 385 |
+
|
| 386 |
+
def make_topology(edge_idx, pa):
|
| 387 |
+
nonlocal edges, parent, offset, names, edge2joint, joint_from_edge, joint_cnt
|
| 388 |
+
edge = edges[edge_idx]
|
| 389 |
+
if out_degree[edge[0]] > 1:
|
| 390 |
+
parent.append(pa)
|
| 391 |
+
offset.append(np.array([0, 0, 0]))
|
| 392 |
+
names.append(origin_names[edge[1]] + "_virtual")
|
| 393 |
+
edge2joint.append(-1)
|
| 394 |
+
pa = joint_cnt
|
| 395 |
+
joint_cnt += 1
|
| 396 |
+
|
| 397 |
+
parent.append(pa)
|
| 398 |
+
offset.append(edge[2])
|
| 399 |
+
names.append(origin_names[edge[1]])
|
| 400 |
+
edge2joint.append(edge_idx)
|
| 401 |
+
pa = joint_cnt
|
| 402 |
+
joint_cnt += 1
|
| 403 |
+
|
| 404 |
+
for idx, e in enumerate(edges):
|
| 405 |
+
if e[0] == edge[1]:
|
| 406 |
+
make_topology(idx, pa)
|
| 407 |
+
|
| 408 |
+
for idx, e in enumerate(edges):
|
| 409 |
+
if e[0] == 0:
|
| 410 |
+
make_topology(idx, 0)
|
| 411 |
+
|
| 412 |
+
return parent, offset, names, edge2joint
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def calc_edge_mat(edges):
|
| 416 |
+
edge_num = len(edges)
|
| 417 |
+
# edge_mat[i][j] = distance between edge(i) and edge(j)
|
| 418 |
+
edge_mat = [[100000] * edge_num for _ in range(edge_num)]
|
| 419 |
+
for i in range(edge_num):
|
| 420 |
+
edge_mat[i][i] = 0
|
| 421 |
+
|
| 422 |
+
# initialize edge_mat with direct neighbor
|
| 423 |
+
for i, a in enumerate(edges):
|
| 424 |
+
for j, b in enumerate(edges):
|
| 425 |
+
link = 0
|
| 426 |
+
for x in range(2):
|
| 427 |
+
for y in range(2):
|
| 428 |
+
if a[x] == b[y]:
|
| 429 |
+
link = 1
|
| 430 |
+
if link:
|
| 431 |
+
edge_mat[i][j] = 1
|
| 432 |
+
|
| 433 |
+
# calculate all the pairs distance
|
| 434 |
+
for k in range(edge_num):
|
| 435 |
+
for i in range(edge_num):
|
| 436 |
+
for j in range(edge_num):
|
| 437 |
+
edge_mat[i][j] = min(edge_mat[i][j], edge_mat[i][k] + edge_mat[k][j])
|
| 438 |
+
return edge_mat
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def find_neighbor(edges, d):
|
| 442 |
+
"""
|
| 443 |
+
Args:
|
| 444 |
+
edges: The list contains N elements, each element represents (parent, child).
|
| 445 |
+
d: Distance between edges (the distance of the same edge is 0 and the distance of adjacent edges is 1).
|
| 446 |
+
|
| 447 |
+
Returns:
|
| 448 |
+
The list contains N elements, each element is a list of edge indices whose distance <= d.
|
| 449 |
+
"""
|
| 450 |
+
edge_mat = calc_edge_mat(edges)
|
| 451 |
+
neighbor_list = []
|
| 452 |
+
edge_num = len(edge_mat)
|
| 453 |
+
for i in range(edge_num):
|
| 454 |
+
neighbor = []
|
| 455 |
+
for j in range(edge_num):
|
| 456 |
+
if edge_mat[i][j] <= d:
|
| 457 |
+
neighbor.append(j)
|
| 458 |
+
neighbor_list.append(neighbor)
|
| 459 |
+
|
| 460 |
+
# # add neighbor for global part
|
| 461 |
+
# global_part_neighbor = neighbor_list[0].copy()
|
| 462 |
+
# """
|
| 463 |
+
# Line #373 is buggy. Thanks @crissallan!!
|
| 464 |
+
# See issue #30 (https://github.com/DeepMotionEditing/deep-motion-editing/issues/30)
|
| 465 |
+
# However, fixing this bug will make it unable to load the pretrained model and
|
| 466 |
+
# affect the reproducibility of quantitative error reported in the paper.
|
| 467 |
+
# It is not a fatal bug so we didn't touch it and we are looking for possible solutions.
|
| 468 |
+
# """
|
| 469 |
+
# for i in global_part_neighbor:
|
| 470 |
+
# neighbor_list[i].append(edge_num)
|
| 471 |
+
# neighbor_list.append(global_part_neighbor)
|
| 472 |
+
|
| 473 |
+
return neighbor_list
|