Spaces:
Runtime error
Runtime error
Commit ·
f01632b
1
Parent(s): 279199b
add
Browse files
app.py
CHANGED
|
@@ -242,383 +242,6 @@ class BaseTrainer(object):
|
|
| 242 |
original_shape_t[i, selected_indices] = filtered_t[i]
|
| 243 |
return original_shape_t
|
| 244 |
|
| 245 |
-
def _load_data(self, dict_data):
    """Move one batch onto ``self.rank`` and derive per-part targets and latents.

    Reads the keys "pose", "trans", "trans_v", "facial", "audio", "word",
    "beta" and "id" from ``dict_data``. The first 165 channels of "pose" are
    reshaped as 55 joints x 3 axis-angle values and channels 165:169 are kept
    as the contact signal. Jaw, hands, upper-body and leg rotations are
    converted to the 6D rotation representation, optionally normalised, and
    the upper/hands/lower parts are encoded with their VQ models.

    Returns:
        dict holding the targets, the audio/word inputs, the per-part
        latents, their scaled concatenation ("latent_in") and
        "style_feature" (``None`` unless ``self.args.use_motionclip`` is set).
    """
    device = self.rank

    pose_and_contact = dict_data["pose"]
    tar_pose = pose_and_contact[:, :, :165].to(device)
    tar_contact = pose_and_contact[:, :, 165:169].to(device)
    tar_trans = dict_data["trans"].to(device)
    tar_trans_v = dict_data["trans_v"].to(device)
    tar_exps = dict_data["facial"].to(device)
    in_audio = dict_data["audio"].to(device)
    in_word = dict_data["word"].to(device)
    tar_beta = dict_data["beta"].to(device)
    tar_id = dict_data["id"].to(device).long()
    bs, n = tar_pose.shape[0], tar_pose.shape[1]

    def to_rot6d(axis_angle, joint_count):
        # (bs, n, joint_count*3) axis-angle -> (bs, n, joint_count*6) 6D rotation.
        rot_mat = rc.axis_angle_to_matrix(axis_angle.reshape(bs, n, joint_count, 3))
        return rc.matrix_to_rotation_6d(rot_mat).reshape(bs, n, joint_count * 6)

    tar_pose_jaw = to_rot6d(tar_pose[:, :, 66:69], 1)
    tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
    tar_pose_hands = to_rot6d(tar_pose[:, :, 25 * 3:55 * 3], 30)
    tar_pose_upper = to_rot6d(tar_pose[:, :, self.joint_mask_upper.astype(bool)], 13)
    tar_pose_leg = to_rot6d(tar_pose[:, :, self.joint_mask_lower.astype(bool)], 9)
    tar_pose_lower = tar_pose_leg

    # Built from the un-normalised parts, before any normalisation below.
    tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2)

    if self.args.pose_norm:
        tar_pose_upper = (tar_pose_upper - self.mean_upper) / self.std_upper
        tar_pose_hands = (tar_pose_hands - self.mean_hands) / self.std_hands
        tar_pose_lower = (tar_pose_lower - self.mean_lower) / self.std_lower

    if self.use_trans:
        # The lower-body model additionally sees the normalised "trans_v" channels.
        tar_trans_v = (tar_trans_v - self.trans_mean) / self.trans_std
        tar_pose_lower = torch.cat([tar_pose_lower, tar_trans_v], dim=-1)

    latent_face_top = None  # face encoding is disabled in the original code
    latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper)
    latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands)
    latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower)
    latent_in = torch.cat(
        [latent_upper_top, latent_hands_top, latent_lower_top], dim=2
    ) / self.args.vqvae_latent_scale

    tar_pose_6d = to_rot6d(tar_pose, 55)
    latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)

    style_feature = None
    if self.args.use_motionclip:
        # Only the first 22 joints (22*6 channels) are fed to the style encoder.
        body_feat = tar_pose_6d[..., :22 * 6]
        clip_bs, clip_len, _ = body_feat.shape
        clip_batch = {
            'x': body_feat.permute(0, 2, 1).contiguous(),
            'y': torch.zeros(clip_bs).int().cuda(),
            'mask': torch.ones([clip_bs, clip_len]).bool().cuda(),
        }
        style_feature = self.motionclip.encoder(clip_batch)['mu'].detach().float()

    return {
        "tar_pose_jaw": tar_pose_jaw,
        "tar_pose_face": tar_pose_face,
        "tar_pose_upper": tar_pose_upper,
        "tar_pose_lower": tar_pose_lower,
        "tar_pose_hands": tar_pose_hands,
        'tar_pose_leg': tar_pose_leg,
        "in_audio": in_audio,
        "in_word": in_word,
        "tar_trans": tar_trans,
        "tar_exps": tar_exps,
        "tar_beta": tar_beta,
        "tar_pose": tar_pose,
        "tar4dis": tar4dis,
        "latent_face_top": latent_face_top,
        "latent_upper_top": latent_upper_top,
        "latent_hands_top": latent_hands_top,
        "latent_lower_top": latent_lower_top,
        "latent_in": latent_in,
        "tar_id": tar_id,
        "latent_all": latent_all,
        "tar_pose_6d": tar_pose_6d,
        "tar_contact": tar_contact,
        "style_feature": style_feature,
    }
|
| 339 |
-
|
| 340 |
-
def _g_test(self, loaded_data):
    """Sample motion latents window by window and decode them to full-body poses.

    The sequence is cut into overlapping windows of ``args.pose_length``
    frames. Each window is sampled with the diffusion model, seeded with the
    last ``args.pre_frames`` latent frames of the previous window (or, for
    the first window, with the first ``pre_frames`` frames of "latent_in").
    The concatenated latents are decoded by the three part VQ models,
    de-normalised, and scattered back into a 55-joint pose. The jaw channels
    (66:69) are copied from the target pose.

    Args:
        loaded_data: dict as produced by ``_load_data``; the keys read here
            are "tar_pose", "tar_beta", "tar_exps", "tar_trans", "in_word",
            "in_audio", "latent_in" and "tar_id".

    Returns:
        dict with "rec_pose" and "tar_pose" (6D rotations, ``j*6`` channels),
        "rec_trans", "tar_exps", "tar_beta", "tar_trans" and "rec_exps"
        (the target expressions, passed through unchanged).

    Raises:
        RuntimeError: if the sequence is too short to form a single window.
    """
    args = self.args
    pre_frames = args.pre_frames
    squeeze_scale = args.vqvae_squeeze_scale
    sample_fn = self.diffusion.ddim_sample_loop if args.use_ddim else self.diffusion.p_sample_loop

    tar_pose = loaded_data["tar_pose"]
    tar_beta = loaded_data["tar_beta"]
    tar_exps = loaded_data["tar_exps"]
    tar_trans = loaded_data["tar_trans"]
    in_word = loaded_data["in_word"]
    in_audio = loaded_data["in_audio"]
    # Only the first `pre_frames` latent frames are ever consumed (seed of
    # the first window), so no tail trimming of the latents is required.
    in_seed = loaded_data['latent_in']
    tar_id = loaded_data['tar_id']
    bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints

    # Drop trailing frames so the frame count is a multiple of 8.
    remain = n % 8
    if remain != 0:
        tar_pose = tar_pose[:, :-remain, :]
        tar_beta = tar_beta[:, :-remain, :]
        tar_trans = tar_trans[:, :-remain, :]
        in_word = in_word[:, :-remain]
        tar_exps = tar_exps[:, :-remain, :]
        n = n - remain

    overlap = pre_frames * squeeze_scale      # window overlap, in pose frames
    round_l = args.pose_length - overlap      # stride between windows
    roundt = (n - overlap) // round_l         # number of full windows
    remain = (n - overlap) % round_l          # frames left after the last window
    if roundt < 1:
        raise RuntimeError(
            f"sequence of {n} frames is too short for one window of {args.pose_length} frames"
        )
    samples_per_frame = 16000 // 30           # audio samples per pose frame, as hard-coded upstream

    rec_all_upper, rec_all_hands, rec_all_lower = [], [], []
    last_sample = None
    for i in range(roundt):
        start, stop = i * round_l, (i + 1) * round_l
        in_word_tmp = in_word[:, start:stop + overlap]
        in_audio_tmp = in_audio[
            :, samples_per_frame * start:samples_per_frame * stop + samples_per_frame * overlap
        ]
        # NOTE(review): unlike the word slice this adds `pre_frames`, not
        # `pre_frames * squeeze_scale`; kept as in the original - confirm.
        in_id_tmp = tar_id[:, start:stop + pre_frames]

        if i == 0:
            in_seed_tmp = in_seed[:, :pre_frames, :]
        else:
            in_seed_tmp = last_sample[:, -pre_frames:, :]

        cond_ = {
            'y': {
                'audio': in_audio_tmp,
                'word': in_word_tmp,
                'id': in_id_tmp,
                'seed': in_seed_tmp,
                # All-True mask, sized by the real batch size rather than args.batch_size.
                'mask': (torch.zeros([bs, 1, 1, args.pose_length]) < 1).cuda(),
                'style_feature': torch.zeros([bs, 512]).cuda(),
            }
        }

        # 1536 channels = 3 parts x 512 (sliced below); 32 latent frames per
        # window - presumably pose_length // squeeze_scale, TODO confirm.
        shape_ = (bs, 1536, 1, 32)
        sample = sample_fn(
            self.model,
            shape_,
            clip_denoised=False,
            model_kwargs=cond_,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )
        # (bs, 1536, 1, 32) -> (bs, 32, 1536); valid for any batch size.
        sample = sample.squeeze(2).permute(0, 2, 1)
        last_sample = sample.clone()

        # Every window after the first repeats `pre_frames` seed frames; drop them.
        keep_from = 0 if i == 0 else pre_frames
        rec_all_upper.append(sample[..., :512][:, keep_from:])
        rec_all_hands.append(sample[..., 512:1024][:, keep_from:])
        rec_all_lower.append(sample[..., 1024:1536][:, keep_from:])

    rec_all_upper = torch.cat(rec_all_upper, dim=1) * self.vqvae_latent_scale
    rec_all_hands = torch.cat(rec_all_hands, dim=1) * self.vqvae_latent_scale
    rec_all_lower = torch.cat(rec_all_lower, dim=1) * self.vqvae_latent_scale

    rec_upper = self.vq_model_upper.latent2origin(rec_all_upper)[0]
    rec_hands = self.vq_model_hands.latent2origin(rec_all_hands)[0]
    rec_lower = self.vq_model_lower.latent2origin(rec_all_lower)[0]

    rec_trans = None
    if self.use_trans:
        # The last 3 decoded lower-body channels are de-normalised and
        # integrated over time; channel 1 is taken directly, not accumulated.
        rec_trans_v = rec_lower[..., -3:] * self.trans_std + self.trans_mean
        rec_trans = torch.cumsum(rec_trans_v, dim=-2)
        rec_trans[..., 1] = rec_trans_v[..., 1]
        rec_lower = rec_lower[..., :-3]

    if args.pose_norm:
        rec_upper = rec_upper * self.std_upper + self.mean_upper
        rec_hands = rec_hands * self.std_hands + self.mean_hands
        rec_lower = rec_lower * self.std_lower + self.mean_lower

    # Cut the targets to the frames actually covered by the windows.
    n = n - remain
    tar_pose = tar_pose[:, :n, :]
    tar_exps = tar_exps[:, :n, :]
    tar_trans = tar_trans[:, :n, :]
    tar_beta = tar_beta[:, :n, :]

    if rec_trans is None:
        # Without a translation-aware lower model no translation is decoded;
        # return zeros instead of failing on an unbound name.
        rec_trans = torch.zeros_like(tar_trans)

    rec_exps = tar_exps
    rec_pose_legs = rec_lower[:, :, :54]
    bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]

    def recover(part_6d, joint_count, joint_mask):
        # 6D rotations of one body part -> axis-angle, scattered into a
        # zero-filled full-pose tensor of bs*n rows at that part's channels.
        rot_mat = rc.rotation_6d_to_matrix(part_6d.reshape(bs, n, joint_count, 6))
        axis_angle = rc.matrix_to_axis_angle(rot_mat).reshape(bs * n, joint_count * 3)
        return self.inverse_selection_tensor(axis_angle, joint_mask, bs * n)

    rec_pose = (
        recover(rec_upper, 13, self.joint_mask_upper)
        + recover(rec_pose_legs, 9, self.joint_mask_lower)
        + recover(rec_hands, 30, self.joint_mask_hands)
    )
    # The jaw is not generated; copy it from the target pose.
    rec_pose[:, 66:69] = tar_pose.reshape(bs * n, 55 * 3)[:, 66:69]

    rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs * n, j, 3))
    rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j * 6)
    tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs * n, j, 3))
    tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j * 6)

    return {
        'rec_pose': rec_pose,
        'rec_trans': rec_trans,
        'tar_pose': tar_pose,
        'tar_exps': tar_exps,
        'tar_beta': tar_beta,
        'tar_trans': tar_trans,
        'rec_exps': rec_exps,
    }
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
def _create_cuda_model(self):
    """Load the generator checkpoint and the three part VQ models onto the GPU.

    Also overwrites the RVQ-VAE hyper-parameters on ``self.args`` with the
    hard-coded values below, stores ``self.vqvae_latent_scale`` and moves the
    normalisation statistics to CUDA. The method is safe to call repeatedly:
    statistics that are already tensors are not passed through
    ``torch.from_numpy`` again.

    NOTE(review): ``_warp`` calls this for every batch, so checkpoints are
    re-read from disk each time - confirm that is intended.
    """
    args = self.args
    other_tools.load_checkpoints(self.model, args.test_ckpt, args.g_name)

    # RVQ-VAE hyper-parameters shared by all three part models.
    args.num_quantizers = 6
    args.shared_codebook = False
    args.quantize_dropout_prob = 0.2
    args.mu = 0.99
    args.nb_code = 512
    args.code_dim = 512
    args.down_t = 2
    args.stride_t = 2
    args.width = 512
    args.depth = 3
    args.dilation_growth_rate = 3
    args.vq_act = "relu"
    args.vq_norm = None

    def build_vq(body_part, dim_pose):
        # args.body_part is set before construction, as in the original code.
        args.body_part = body_part
        return RVQVAE(
            args,
            dim_pose,
            args.nb_code,
            args.code_dim,
            args.code_dim,
            args.down_t,
            args.stride_t,
            args.width,
            args.depth,
            args.dilation_growth_rate,
            args.vq_act,
            args.vq_norm,
        )

    self.vq_model_upper = build_vq("upper", 78)    # 13 joints x 6
    self.vq_model_hands = build_vq("hands", 180)   # 30 joints x 6

    dim_pose = 54                                  # 9 joints x 6
    if args.use_trans:
        # NOTE(review): other methods test self.use_trans, this one
        # args.use_trans - confirm both always agree.
        dim_pose = 57                              # plus 3 translation channels
        args.vqvae_lower_path = args.vqvae_lower_trans_path
    self.vq_model_lower = build_vq("lower", dim_pose)

    self.vq_model_upper.load_state_dict(torch.load(args.vqvae_upper_path)['net'])
    self.vq_model_hands.load_state_dict(torch.load(args.vqvae_hands_path)['net'])
    self.vq_model_lower.load_state_dict(torch.load(args.vqvae_lower_path)['net'])

    self.vqvae_latent_scale = args.vqvae_latent_scale

    self.vq_model_upper.eval().to(self.rank)
    self.vq_model_hands.eval().to(self.rank)
    self.vq_model_lower.eval().to(self.rank)

    self.model = self.model.cuda()
    self.model.eval()

    # Normalisation statistics: numpy on the first call, tensors afterwards.
    for stat_name in (
        "mean_upper", "mean_hands", "mean_lower",
        "std_upper", "std_hands", "std_lower",
        "trans_mean", "trans_std",
    ):
        stat = getattr(self, stat_name)
        if not torch.is_tensor(stat):
            stat = torch.from_numpy(stat)
        setattr(self, stat_name, stat.cuda())
|
| 613 |
-
|
| 614 |
-
@spaces.GPU(duration=149)
def _warp(self, batch_data):
    """Run the GPU part of inference for one batch and return its output dict.

    Prepares the models with ``_create_cuda_model``, converts ``batch_data``
    with ``_load_data`` and generates motion with ``_g_test``.
    The decorator is presumably the Hugging Face ZeroGPU ``spaces.GPU``
    wrapper with a 149 s budget - verify against the ``spaces`` package.
    """
    self._create_cuda_model()
    return self._g_test(self._load_data(batch_data))
|
| 622 |
|
| 623 |
def test_demo(self, epoch):
|
| 624 |
'''
|
|
@@ -644,7 +267,7 @@ class BaseTrainer(object):
|
|
| 644 |
for its, batch_data in enumerate(self.test_loader):
|
| 645 |
# loaded_data = self._load_data(batch_data)
|
| 646 |
# net_out = self._g_test(loaded_data)
|
| 647 |
-
net_out =
|
| 648 |
tar_pose = net_out['tar_pose']
|
| 649 |
rec_pose = net_out['rec_pose']
|
| 650 |
tar_exps = net_out['tar_exps']
|
|
@@ -708,7 +331,402 @@ class BaseTrainer(object):
|
|
| 708 |
end_time = time.time() - start_time
|
| 709 |
logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")
|
| 710 |
return result
|
| 711 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 712 |
@logger.catch
|
| 713 |
def syntalker(audio_path,sample_stratege):
|
| 714 |
args = config.parse_args()
|
|
|
|
| 242 |
original_shape_t[i, selected_indices] = filtered_t[i]
|
| 243 |
return original_shape_t
|
| 244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
def test_demo(self, epoch):
|
| 247 |
'''
|
|
|
|
| 267 |
for its, batch_data in enumerate(self.test_loader):
|
| 268 |
# loaded_data = self._load_data(batch_data)
|
| 269 |
# net_out = self._g_test(loaded_data)
|
| 270 |
+
net_out = _warp(self.args,self.model, batch_data,self.joints,self.joint_mask_upper,self.joint_mask_hands,self.joint_mask_lower,self.use_trans,self.diffusion)
|
| 271 |
tar_pose = net_out['tar_pose']
|
| 272 |
rec_pose = net_out['rec_pose']
|
| 273 |
tar_exps = net_out['tar_exps']
|
|
|
|
| 331 |
end_time = time.time() - start_time
|
| 332 |
logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")
|
| 333 |
return result
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
@spaces.GPU(duration=149)
def _warp(args,model, batch_data,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,use_trans,diffusion):
    """Module-level GPU entry point: build the models, load one batch, generate motion.

    All state is passed in explicitly instead of being read from a trainer
    instance. The decorator is presumably the Hugging Face ZeroGPU
    ``spaces.GPU`` wrapper with a 149 s budget - verify against the
    ``spaces`` package.

    Returns:
        The dict produced by ``_warp_g_test``.
    """
    # Models and normalisation statistics, as returned by the setup helper.
    (
        args,
        model,
        vq_model_upper,
        vq_model_hands,
        vq_model_lower,
        mean_upper,
        mean_hands,
        mean_lower,
        std_upper,
        std_hands,
        std_lower,
        trans_mean,
        trans_std,
        vqvae_latent_scale,
    ) = _warp_create_cuda_model(args, model)

    loaded_data = _warp_load_data(
        batch_data,
        joints,
        joint_mask_upper,
        joint_mask_hands,
        joint_mask_lower,
        args,
        use_trans,
        mean_upper,
        mean_hands,
        mean_lower,
        std_upper,
        std_hands,
        std_lower,
        trans_mean,
        trans_std,
        vq_model_upper,
        vq_model_hands,
        vq_model_lower,
    )

    return _warp_g_test(
        loaded_data,
        diffusion,
        args,
        joints,
        joint_mask_upper,
        joint_mask_hands,
        joint_mask_lower,
        model,
        vqvae_latent_scale,
        vq_model_upper,
        vq_model_hands,
        vq_model_lower,
        use_trans,
        trans_std,
        trans_mean,
        std_upper,
        std_hands,
        std_lower,
        mean_upper,
        mean_hands,
        mean_lower,
    )
|
| 346 |
+
|
| 347 |
+
def _warp_inverse_selection_tensor(filtered_t, selection_array, n):
    """Scatter selected channels back into a zero-filled ``(n, 165)`` tensor.

    Args:
        filtered_t: tensor whose first ``n`` rows hold one value per selected
            channel, i.e. shape ``(n, k)`` with ``k`` the number of ones in
            ``selection_array``.
        selection_array: numpy array in which entries equal to 1 mark the
            columns of the output that receive the values.
        n: number of rows to produce.

    Returns:
        Float tensor of shape ``(n, 165)`` on the same device as
        ``filtered_t`` (the original always used the default CUDA device).
    """
    device = filtered_t.device
    selection_mask = torch.from_numpy(selection_array).to(device)
    selected_indices = torch.where(selection_mask == 1)[0]
    original_shape_t = torch.zeros((n, 165), device=device)
    # One vectorised assignment replaces the per-row Python loop.
    original_shape_t[:, selected_indices] = filtered_t[:n]
    return original_shape_t
|
| 354 |
+
|
| 355 |
+
def _warp_g_test(loaded_data,diffusion,args,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,model,vqvae_latent_scale,vq_model_upper,vq_model_hands,vq_model_lower,use_trans,trans_std,trans_mean,std_upper,std_hands,std_lower,mean_upper,mean_hands,mean_lower):
|
| 356 |
+
sample_fn = diffusion.p_sample_loop
|
| 357 |
+
if args.use_ddim:
|
| 358 |
+
sample_fn = diffusion.ddim_sample_loop
|
| 359 |
+
mode = 'test'
|
| 360 |
+
bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], joints
|
| 361 |
+
tar_pose = loaded_data["tar_pose"]
|
| 362 |
+
tar_beta = loaded_data["tar_beta"]
|
| 363 |
+
tar_exps = loaded_data["tar_exps"]
|
| 364 |
+
tar_contact = loaded_data["tar_contact"]
|
| 365 |
+
tar_trans = loaded_data["tar_trans"]
|
| 366 |
+
in_word = loaded_data["in_word"]
|
| 367 |
+
in_audio = loaded_data["in_audio"]
|
| 368 |
+
in_x0 = loaded_data['latent_in']
|
| 369 |
+
in_seed = loaded_data['latent_in']
|
| 370 |
+
|
| 371 |
+
remain = n%8
|
| 372 |
+
if remain != 0:
|
| 373 |
+
tar_pose = tar_pose[:, :-remain, :]
|
| 374 |
+
tar_beta = tar_beta[:, :-remain, :]
|
| 375 |
+
tar_trans = tar_trans[:, :-remain, :]
|
| 376 |
+
in_word = in_word[:, :-remain]
|
| 377 |
+
tar_exps = tar_exps[:, :-remain, :]
|
| 378 |
+
tar_contact = tar_contact[:, :-remain, :]
|
| 379 |
+
in_x0 = in_x0[:, :in_x0.shape[1]-(remain//args.vqvae_squeeze_scale), :]
|
| 380 |
+
in_seed = in_seed[:, :in_x0.shape[1]-(remain//args.vqvae_squeeze_scale), :]
|
| 381 |
+
n = n - remain
|
| 382 |
+
|
| 383 |
+
tar_pose_jaw = tar_pose[:, :, 66:69]
|
| 384 |
+
tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
|
| 385 |
+
tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6)
|
| 386 |
+
tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
|
| 387 |
+
|
| 388 |
+
tar_pose_hands = tar_pose[:, :, 25*3:55*3]
|
| 389 |
+
tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
|
| 390 |
+
tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
|
| 391 |
+
|
| 392 |
+
tar_pose_upper = tar_pose[:, :, joint_mask_upper.astype(bool)]
|
| 393 |
+
tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
|
| 394 |
+
tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
|
| 395 |
+
|
| 396 |
+
tar_pose_leg = tar_pose[:, :, joint_mask_lower.astype(bool)]
|
| 397 |
+
tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
|
| 398 |
+
tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
|
| 399 |
+
tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2)
|
| 400 |
+
|
| 401 |
+
tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
|
| 402 |
+
tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6)
|
| 403 |
+
latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
|
| 404 |
+
|
| 405 |
+
rec_all_face = []
|
| 406 |
+
rec_all_upper = []
|
| 407 |
+
rec_all_lower = []
|
| 408 |
+
rec_all_hands = []
|
| 409 |
+
vqvae_squeeze_scale = args.vqvae_squeeze_scale
|
| 410 |
+
roundt = (n - args.pre_frames * vqvae_squeeze_scale) // (args.pose_length - args.pre_frames * vqvae_squeeze_scale)
|
| 411 |
+
remain = (n - args.pre_frames * vqvae_squeeze_scale) % (args.pose_length - args.pre_frames * vqvae_squeeze_scale)
|
| 412 |
+
round_l = args.pose_length - args.pre_frames * vqvae_squeeze_scale
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
for i in range(0, roundt):
|
| 416 |
+
in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+args.pre_frames * vqvae_squeeze_scale]
|
| 417 |
+
|
| 418 |
+
in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*args.pre_frames * vqvae_squeeze_scale]
|
| 419 |
+
in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+args.pre_frames]
|
| 420 |
+
in_seed_tmp = in_seed[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+args.pre_frames]
|
| 421 |
+
in_x0_tmp = in_x0[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+args.pre_frames]
|
| 422 |
+
mask_val = torch.ones(bs, args.pose_length, args.pose_dims+3+4).float().cuda()
|
| 423 |
+
mask_val[:, :args.pre_frames, :] = 0.0
|
| 424 |
+
if i == 0:
|
| 425 |
+
in_seed_tmp = in_seed_tmp[:, :args.pre_frames, :]
|
| 426 |
+
else:
|
| 427 |
+
in_seed_tmp = last_sample[:, -args.pre_frames:, :]
|
| 428 |
+
|
| 429 |
+
cond_ = {'y':{}}
|
| 430 |
+
cond_['y']['audio'] = in_audio_tmp
|
| 431 |
+
cond_['y']['word'] = in_word_tmp
|
| 432 |
+
cond_['y']['id'] = in_id_tmp
|
| 433 |
+
cond_['y']['seed'] =in_seed_tmp
|
| 434 |
+
cond_['y']['mask'] = (torch.zeros([args.batch_size, 1, 1, args.pose_length]) < 1).cuda()
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
cond_['y']['style_feature'] = torch.zeros([bs, 512]).cuda()
|
| 439 |
+
|
| 440 |
+
shape_ = (bs, 1536, 1, 32)
|
| 441 |
+
sample = sample_fn(
|
| 442 |
+
model,
|
| 443 |
+
shape_,
|
| 444 |
+
clip_denoised=False,
|
| 445 |
+
model_kwargs=cond_,
|
| 446 |
+
skip_timesteps=0,
|
| 447 |
+
init_image=None,
|
| 448 |
+
progress=True,
|
| 449 |
+
dump_steps=None,
|
| 450 |
+
noise=None,
|
| 451 |
+
const_noise=False,
|
| 452 |
+
)
|
| 453 |
+
sample = sample.squeeze().permute(1,0).unsqueeze(0)
|
| 454 |
+
|
| 455 |
+
last_sample = sample.clone()
|
| 456 |
+
|
| 457 |
+
rec_latent_upper = sample[...,:512]
|
| 458 |
+
rec_latent_hands = sample[...,512:1024]
|
| 459 |
+
rec_latent_lower = sample[...,1024:1536]
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
if i == 0:
|
| 464 |
+
rec_all_upper.append(rec_latent_upper)
|
| 465 |
+
rec_all_hands.append(rec_latent_hands)
|
| 466 |
+
rec_all_lower.append(rec_latent_lower)
|
| 467 |
+
else:
|
| 468 |
+
rec_all_upper.append(rec_latent_upper[:, args.pre_frames:])
|
| 469 |
+
rec_all_hands.append(rec_latent_hands[:, args.pre_frames:])
|
| 470 |
+
rec_all_lower.append(rec_latent_lower[:, args.pre_frames:])
|
| 471 |
+
|
| 472 |
+
rec_all_upper = torch.cat(rec_all_upper, dim=1) * vqvae_latent_scale
|
| 473 |
+
rec_all_hands = torch.cat(rec_all_hands, dim=1) * vqvae_latent_scale
|
| 474 |
+
rec_all_lower = torch.cat(rec_all_lower, dim=1) * vqvae_latent_scale
|
| 475 |
+
|
| 476 |
+
rec_upper = vq_model_upper.latent2origin(rec_all_upper)[0]
|
| 477 |
+
rec_hands = vq_model_hands.latent2origin(rec_all_hands)[0]
|
| 478 |
+
rec_lower = vq_model_lower.latent2origin(rec_all_lower)[0]
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
if use_trans:
|
| 482 |
+
rec_trans_v = rec_lower[...,-3:]
|
| 483 |
+
rec_trans_v = rec_trans_v * trans_std + trans_mean
|
| 484 |
+
rec_trans = torch.zeros_like(rec_trans_v)
|
| 485 |
+
rec_trans = torch.cumsum(rec_trans_v, dim=-2)
|
| 486 |
+
rec_trans[...,1]=rec_trans_v[...,1]
|
| 487 |
+
rec_lower = rec_lower[...,:-3]
|
| 488 |
+
|
| 489 |
+
if args.pose_norm:
|
| 490 |
+
rec_upper = rec_upper * std_upper + mean_upper
|
| 491 |
+
rec_hands = rec_hands * std_hands + mean_hands
|
| 492 |
+
rec_lower = rec_lower * std_lower + mean_lower
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
n = n - remain
|
| 498 |
+
tar_pose = tar_pose[:, :n, :]
|
| 499 |
+
tar_exps = tar_exps[:, :n, :]
|
| 500 |
+
tar_trans = tar_trans[:, :n, :]
|
| 501 |
+
tar_beta = tar_beta[:, :n, :]
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
rec_exps = tar_exps
|
| 505 |
+
#rec_pose_jaw = rec_face[:, :, :6]
|
| 506 |
+
rec_pose_legs = rec_lower[:, :, :54]
|
| 507 |
+
bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]
|
| 508 |
+
rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
|
| 509 |
+
rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)#
|
| 510 |
+
rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3)
|
| 511 |
+
rec_pose_upper_recover = _warp_inverse_selection_tensor(rec_pose_upper, joint_mask_upper, bs*n)
|
| 512 |
+
rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
|
| 513 |
+
rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
|
| 514 |
+
rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6)
|
| 515 |
+
rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3)
|
| 516 |
+
rec_pose_lower_recover = _warp_inverse_selection_tensor(rec_pose_lower, joint_mask_lower, bs*n)
|
| 517 |
+
rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
|
| 518 |
+
rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
|
| 519 |
+
rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3)
|
| 520 |
+
rec_pose_hands_recover = _warp_inverse_selection_tensor(rec_pose_hands, joint_mask_hands, bs*n)
|
| 521 |
+
rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
|
| 522 |
+
rec_pose[:, 66:69] = tar_pose.reshape(bs*n, 55*3)[:, 66:69]
|
| 523 |
+
|
| 524 |
+
rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3))
|
| 525 |
+
rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
|
| 526 |
+
tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3))
|
| 527 |
+
tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
|
| 528 |
+
|
| 529 |
+
return {
|
| 530 |
+
'rec_pose': rec_pose,
|
| 531 |
+
'rec_trans': rec_trans,
|
| 532 |
+
'tar_pose': tar_pose,
|
| 533 |
+
'tar_exps': tar_exps,
|
| 534 |
+
'tar_beta': tar_beta,
|
| 535 |
+
'tar_trans': tar_trans,
|
| 536 |
+
'rec_exps': rec_exps,
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def _warp_load_data(dict_data,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,args,use_trans,mean_upper,mean_hands,mean_lower,std_upper,std_hands,std_lower,trans_mean,trans_std,vq_model_upper,vq_model_hands,vq_model_lower):
    """Move one batch to the GPU and derive the per-body-part targets and latents.

    The first 165 channels of ``dict_data["pose"]`` are treated as 55
    axis-angle joints (they are reshaped to ``(bs, n, 55, 3)`` below) and
    channels 165:169 as the contact signal.  Jaw, hands, upper body and legs
    are each converted to the 6D rotation representation, optionally
    normalised with the supplied mean/std, and encoded with the matching
    VQ model's ``map2latent``.

    Args:
        dict_data: mapping with the keys "pose", "trans", "trans_v",
            "facial", "audio", "word", "beta" and "id"; every value must
            support ``.cuda()``.
        joints, joint_mask_hands: accepted for signature compatibility; not
            read by this function.
        joint_mask_upper, joint_mask_lower: arrays exposing ``.astype(bool)``,
            used to select pose channels (presumably numpy masks over the 165
            channels selecting 13 and 9 joints - TODO confirm).
        args: config object; ``pose_norm``, ``vqvae_latent_scale`` and
            ``use_motionclip`` are read.
        use_trans: when truthy, the normalised ``trans_v`` is appended to the
            lower-body features before encoding.
        mean_* / std_* / trans_mean / trans_std: normalisation statistics,
            broadcast against the corresponding features.
        vq_model_upper, vq_model_hands, vq_model_lower: objects exposing
            ``map2latent``.

    Returns:
        dict holding the (possibly normalised) targets, the raw inputs moved
        to the GPU, the per-part latents, their concatenation scaled by
        ``1 / args.vqvae_latent_scale`` ("latent_in"), and "style_feature"
        (``None`` unless ``args.use_motionclip`` is set).
    """
    pose_raw = dict_data["pose"]
    tar_pose = pose_raw[:, :, :165].cuda()
    tar_contact = pose_raw[:, :, 165:169].cuda()
    tar_trans = dict_data["trans"].cuda()
    trans_v = dict_data["trans_v"].cuda()
    tar_exps = dict_data["facial"].cuda()
    in_audio = dict_data["audio"].cuda()
    in_word = dict_data["word"].cuda()
    tar_beta = dict_data["beta"].cuda()
    tar_id = dict_data["id"].cuda().long()
    bs, n = tar_pose.shape[0], tar_pose.shape[1]

    def to_rot6d(axis_angle, num_joints):
        # Axis-angle (..., num_joints*3) -> 6D rotation (bs, n, num_joints*6).
        rot_mats = rc.axis_angle_to_matrix(axis_angle.reshape(bs, n, num_joints, 3))
        return rc.matrix_to_rotation_6d(rot_mats).reshape(bs, n, num_joints * 6)

    # Jaw is channels 66:69, i.e. a single joint; the face target is jaw + expressions.
    jaw_6d = to_rot6d(tar_pose[:, :, 66:69], 1)
    face_feat = torch.cat([jaw_6d, tar_exps], dim=2)

    # Joints 25..54 (30 joints) are the hands.
    hands_feat = to_rot6d(tar_pose[:, :, 25 * 3:55 * 3], 30)
    upper_feat = to_rot6d(tar_pose[:, :, joint_mask_upper.astype(bool)], 13)
    leg_feat = to_rot6d(tar_pose[:, :, joint_mask_lower.astype(bool)], 9)

    # The discriminator target is built from the un-normalised features.
    tar4dis = torch.cat([jaw_6d, upper_feat, hands_feat, leg_feat], dim=2)

    # ``lower_feat`` starts as an alias of ``leg_feat``; the steps below rebind
    # it to new tensors, so ``leg_feat`` itself stays un-normalised.
    lower_feat = leg_feat
    if args.pose_norm:
        upper_feat = (upper_feat - mean_upper) / std_upper
        hands_feat = (hands_feat - mean_hands) / std_hands
        lower_feat = (lower_feat - mean_lower) / std_lower

    if use_trans:
        trans_v = (trans_v - trans_mean) / trans_std
        lower_feat = torch.cat([lower_feat, trans_v], dim=-1)

    # The face branch is disabled; the key is kept in the result with a None value.
    latent_face_top = None
    latent_upper_top = vq_model_upper.map2latent(upper_feat)
    latent_hands_top = vq_model_hands.map2latent(hands_feat)
    latent_lower_top = vq_model_lower.map2latent(lower_feat)

    stacked_latents = torch.cat(
        [latent_upper_top, latent_hands_top, latent_lower_top], dim=2
    )
    latent_in = stacked_latents / args.vqvae_latent_scale

    tar_pose_6d = to_rot6d(tar_pose, 55)
    latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)

    style_feature = None
    if args.use_motionclip:
        # Only the first 22 joints (22*6 channels) are fed to the style encoder.
        clip_input = tar_pose_6d[..., :22 * 6]
        clip_bs, clip_len, _ = clip_input.shape
        batch = {
            'x': clip_input.permute(0, 2, 1).contiguous(),
            'y': torch.zeros(clip_bs).int().cuda(),
            'mask': torch.ones([clip_bs, clip_len]).bool().cuda(),
        }
        # NOTE(review): ``motionclip`` is not a parameter of this function, so
        # it must resolve to a module-level name - confirm it is defined.
        style_feature = motionclip.encoder(batch)['mu'].detach().float()

    return {
        "tar_pose_jaw": jaw_6d,
        "tar_pose_face": face_feat,
        "tar_pose_upper": upper_feat,
        "tar_pose_lower": lower_feat,
        "tar_pose_hands": hands_feat,
        'tar_pose_leg': leg_feat,
        "in_audio": in_audio,
        "in_word": in_word,
        "tar_trans": tar_trans,
        "tar_exps": tar_exps,
        "tar_beta": tar_beta,
        "tar_pose": tar_pose,
        "tar4dis": tar4dis,
        "latent_face_top": latent_face_top,
        "latent_upper_top": latent_upper_top,
        "latent_hands_top": latent_hands_top,
        "latent_lower_top": latent_lower_top,
        "latent_in": latent_in,
        "tar_id": tar_id,
        "latent_all": latent_all,
        "tar_pose_6d": tar_pose_6d,
        "tar_contact": tar_contact,
        "style_feature": style_feature,
    }
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def _build_part_rvqvae(args, body_part, dim_pose):
    """Build the RVQ-VAE for one body part from the hyper-parameters on ``args``."""
    # ``args.body_part`` is written before construction, exactly as the
    # inline code did for each of the three models.
    args.body_part = body_part
    return RVQVAE(
        args,
        dim_pose,
        args.nb_code,
        args.code_dim,
        args.code_dim,
        args.down_t,
        args.stride_t,
        args.width,
        args.depth,
        args.dilation_growth_rate,
        args.vq_act,
        args.vq_norm,
    )


def _warp_create_cuda_model(args, model):
    """Load all inference weights and move the models and statistics to the GPU.

    Loads the generator checkpoint ``args.test_ckpt`` into ``model``, writes
    the fixed RVQ-VAE hyper-parameters onto ``args``, builds and loads the
    upper-body, hands and lower-body RVQ-VAEs, switches every network to eval
    mode on CUDA, and converts the normalisation statistics to CUDA tensors.

    Args:
        args: config object.  It is mutated: the RVQ-VAE hyper-parameters and
            ``body_part`` are set, and ``vqvae_lower_path`` is replaced by
            ``vqvae_lower_trans_path`` when ``args.use_trans`` is truthy.
        model: the generator network to load the checkpoint into.

    Returns:
        A 14-tuple, in this order: ``args``, ``model``, ``vq_model_upper``,
        ``vq_model_hands``, ``vq_model_lower``, ``mean_upper``,
        ``mean_hands``, ``mean_lower``, ``std_upper``, ``std_hands``,
        ``std_lower``, ``trans_mean``, ``trans_std``, ``vqvae_latent_scale``.
    """
    other_tools.load_checkpoints(model, args.test_ckpt, args.g_name)

    # Fixed RVQ-VAE hyper-parameters, stored on ``args`` because the RVQVAE
    # constructor receives ``args`` as its first argument.
    args.num_quantizers = 6
    args.shared_codebook = False
    args.quantize_dropout_prob = 0.2
    args.mu = 0.99

    args.nb_code = 512
    args.code_dim = 512
    args.down_t = 2
    args.stride_t = 2
    args.width = 512
    args.depth = 3
    args.dilation_growth_rate = 3
    args.vq_act = "relu"
    args.vq_norm = None

    # Input widths: 78 = 13 joints * 6, 180 = 30 joints * 6, 54 = 9 joints * 6.
    vq_model_upper = _build_part_rvqvae(args, "upper", 78)
    vq_model_hands = _build_part_rvqvae(args, "hands", 180)

    lower_dim = 54
    if args.use_trans:
        # Three extra channels carry the translation velocity, and the
        # checkpoint trained with them replaces the plain lower-body one.
        lower_dim = 57
        args.vqvae_lower_path = args.vqvae_lower_trans_path
    vq_model_lower = _build_part_rvqvae(args, "lower", lower_dim)

    vq_model_upper.load_state_dict(torch.load(args.vqvae_upper_path)['net'])
    vq_model_hands.load_state_dict(torch.load(args.vqvae_hands_path)['net'])
    vq_model_lower.load_state_dict(torch.load(args.vqvae_lower_path)['net'])

    vqvae_latent_scale = args.vqvae_latent_scale

    vq_model_upper.eval().cuda()
    vq_model_hands.eval().cuda()
    vq_model_lower.eval().cuda()

    model = model.cuda()
    model.eval()

    # The statistics are deliberately NOT assigned back to their own names.
    # Assigning ``mean_upper = torch.from_numpy(mean_upper)`` inside this
    # function makes ``mean_upper`` a local, so the read on the right-hand
    # side raises UnboundLocalError before anything is converted.
    # NOTE(review): assumes module-level numpy arrays named mean_upper,
    # mean_hands, mean_lower, std_upper, std_hands, std_lower, trans_mean and
    # trans_std exist - confirm they are defined before this is called.
    stats_on_gpu = tuple(
        torch.from_numpy(array).cuda()
        for array in (
            mean_upper,
            mean_hands,
            mean_lower,
            std_upper,
            std_hands,
            std_lower,
            trans_mean,
            trans_std,
        )
    )

    return (
        args,
        model,
        vq_model_upper,
        vq_model_hands,
        vq_model_lower,
        *stats_on_gpu,
        vqvae_latent_scale,
    )
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
|
| 730 |
@logger.catch
|
| 731 |
def syntalker(audio_path,sample_stratege):
|
| 732 |
args = config.parse_args()
|