Spaces:
Runtime error
Runtime error
| import clip | |
| import numpy as np | |
| import torch | |
| from scipy.spatial.transform import Rotation as R | |
| import os | |
| import sys | |
| utils_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'utils')) | |
| sys.path.append(utils_dir) | |
| from utils.transforms import rigid_transform_3D, transform_points_numpy | |
| from utils.constants import rest_pelvis | |
def _heading_alignment(tail_frames, fixed_frame, batch_size):
    """Build one 4x4 alignment transform per batch item from a motion tail.

    For each batch item, the first nine channels of frame ``-fixed_frame`` are
    reshaped to a 3x3 array and passed to ``rigid_transform_3D`` together with
    ``rest_pelvis``. Only part of the returned rigid transform is kept:

    * the rotation is reduced to its third 'zxy' Euler angle, i.e. a rotation
      about the y axis only;
    * the second translation component is zeroed (presumably the vertical
      axis -- TODO confirm against the skeleton convention).

    Args:
        tail_frames: numpy array indexed as ``[batch, frame, channel]`` with
            at least nine channels.
        fixed_frame: number of overlapping frames; selects frame
            ``-fixed_frame`` as the anchor.
        batch_size: number of batch items in ``tail_frames``.

    Returns:
        numpy array of shape ``(batch_size, 4, 4)`` holding homogeneous
        transforms.
    """
    anchors = tail_frames[:, -fixed_frame, :9].reshape(batch_size, 3, 3)
    mats = np.repeat(np.eye(4)[np.newaxis, :, :], batch_size, axis=0)
    for b, anchor in enumerate(anchors):
        # np.matrix is kept on purpose: what rigid_transform_3D expects of its
        # input type is not visible from this file.
        _, ret_R, ret_t = rigid_transform_3D(np.matrix(anchor), rest_pelvis, False)
        ret_t[1] = 0.0
        rot_euler = R.from_matrix(ret_R).as_euler('zxy')
        shift_euler = np.array([0, 0, rot_euler[2]])
        mats[b, :3, :3] = R.from_euler('zxy', shift_euler).as_matrix()
        mats[b, :3, 3] = ret_t.reshape(-1)
    return mats


def test_model(models, diffuser, normalizer, configs, text_embedder, hint_text, prog_ind, joint_orig=None, All_one_model=True, **kwargs):
    """Autoregressively sample motion clips and stitch them into one sequence.

    Every entry of every ``prog_ind[k]`` is one generation step conditioned on
    the text embedding of ``hint_text[k]``. After each step the last
    ``fixed_frame`` frames of the result are re-expressed through the inverse
    of a freshly computed alignment transform, normalised, and used as the
    fixed points of the next step. The same slicing and stitching is applied
    to the reference motion in ``joint_orig``.

    Args:
        models: mapping of models. With ``All_one_model`` it is read at
            ``'model'`` and, optionally, ``'disc_model'``; otherwise it is
            read at each task name and at ``task name + '_disc'``.
        diffuser: object exposing ``sample(model, ...)``; the last element of
            its return value is used as the generated clip.
        normalizer: pair ``(normalize, denormalize)`` of callables.
        configs: mapping with keys ``batch_size``, ``seq_len``, ``channels``,
            ``fixed_frame``, ``use_cfg``, ``cfg_alpha``, ``cg_alpha`` and
            ``cg_diffusion_steps``.
        text_embedder: object exposing ``encode_text(tokens)``.
        hint_text: non-empty sequence of prompts, one per ``prog_ind`` segment.
        prog_ind: nested sequence; ``prog_ind[k][s]`` is forwarded to the
            sampler as ``prog_ind`` for step ``s`` of segment ``k``.
        joint_orig: tensor indexed by global step; ``joint_orig[i]`` is
            reshaped to ``(-1, seq_len, channels)``. Required despite the
            ``None`` default.
        All_one_model: when true a single model serves every step; otherwise
            ``kwargs['model_type'][k][s]`` names the model for each step.
        **kwargs: ``model_type`` is required when ``All_one_model`` is false.

    Returns:
        Tuple ``(samples_total, orig_samples_total)`` of numpy arrays, each
        the per-step clips concatenated along axis 1.

    Raises:
        ValueError: if ``hint_text`` is empty, ``joint_orig`` is ``None`` or
            ``prog_ind`` contains no steps.
        AssertionError: if ``model_type`` and ``hint_text`` differ in length
            when ``All_one_model`` is false.
    """
    if not hint_text:
        raise ValueError("hint_text must contain at least one prompt")
    if joint_orig is None:
        raise ValueError("joint_orig is required")

    # set up
    if All_one_model:
        model = models['model']
        try:
            disc_model = models['disc_model']
        except KeyError:
            # The discriminator is optional; sampling proceeds without it.
            print("disc_model is not provided!", flush=True)
            disc_model = None
    else:
        assert len(kwargs['model_type']) == len(hint_text), "model_type should have the same length as hint_text"

    device = joint_orig.device
    normalize, denormalize = normalizer

    batch_size = configs['batch_size']
    seq_len = configs['seq_len']
    channels = configs['channels']
    fixed_frame = configs['fixed_frame']
    use_cfg = configs['use_cfg']
    cfg_alpha = configs['cfg_alpha']
    # for classifier guidance
    cg_alpha = configs['cg_alpha']
    cg_diffusion_steps = configs['cg_diffusion_steps']

    # Flatten the nested schedule once: global step i -> (segment, step in segment).
    schedule = [
        (seg_idx, step_idx)
        for seg_idx, segment in enumerate(prog_ind)
        for step_idx in range(len(segment))
    ]
    if not schedule:
        raise ValueError("prog_ind does not contain any generation step")

    # The first step is seeded with the leading frames of the reference motion.
    begining_frame = joint_orig[0, :fixed_frame, ...].reshape(-1, fixed_frame, channels)
    samples_total = []
    orig_samples_total = []

    hint_token = clip.tokenize(hint_text).to(device)
    hint_emb = text_embedder.encode_text(hint_token).to(device=device, dtype=torch.float32)

    ################################################################################
    # autoregressive diffusion
    trans_mats = np.repeat(np.eye(4)[np.newaxis, :, :], batch_size, axis=0)
    trans_mats_orig = np.repeat(np.eye(4)[np.newaxis, :, :], batch_size, axis=0)
    for i, (seg_idx, step_idx) in enumerate(schedule):
        prog_ind_i = torch.tensor(prog_ind[seg_idx][step_idx]).unsqueeze(0).to(device)
        hint_emb_i = hint_emb[seg_idx].unsqueeze(0)
        joint_orig_i = joint_orig[i].reshape(-1, seq_len, channels)

        if not All_one_model:
            task_model = kwargs['model_type'][seg_idx][step_idx]
            model = models[task_model]
            disc_model = models[task_model + '_disc']

        samples = diffuser.sample(
            model,
            batch_size=batch_size,
            seq_len=seq_len,
            channels=channels,
            fixed_points=begining_frame,
            text=hint_emb_i,
            prog_ind=prog_ind_i,
            joints_orig=joint_orig_i,
            use_cfg=use_cfg,
            cfg_alpha=cfg_alpha,
            disc_model=disc_model,
            cg_alpha=cg_alpha,
            cg_diffusion_steps=cg_diffusion_steps,
        )
        samples = samples[-1]  # only consider the last timestep
        samples = denormalize(samples).detach().cpu().numpy()

        # for original motion
        orig_samples = denormalize(joint_orig_i).detach().cpu().numpy()

        if i > 0:
            # Later clips repeat the previous clip's last fixed_frame frames:
            # drop that overlap, then apply the transform computed from the
            # previous clip's tail.
            samples = transform_points_numpy(samples[:, fixed_frame:, :], trans_mats)
            orig_samples = transform_points_numpy(orig_samples[:, fixed_frame:, :], trans_mats_orig)
        samples_total.append(samples)
        orig_samples_total.append(orig_samples)

        # Seed the next step: express the tail through the inverse alignment
        # transform and normalise it for the sampler.
        tail = samples[:, -fixed_frame:, :]
        trans_mats = _heading_alignment(tail, fixed_frame, batch_size)
        begining_frame = normalize(
            torch.tensor(
                transform_points_numpy(tail, np.linalg.inv(trans_mats)),
                device=device,
                dtype=torch.float32,
            )
        )

        # The reference motion only needs its transform for stitching; it is
        # never fed back into the sampler.
        trans_mats_orig = _heading_alignment(orig_samples[:, -fixed_frame:, :], fixed_frame, batch_size)

    samples_total = np.concatenate(samples_total, axis=1)
    orig_samples_total = np.concatenate(orig_samples_total, axis=1)
    return samples_total, orig_samples_total