| import torch |
| from utils.word_vectorizer import WordVectorizer |
| from torch.utils.data import Dataset, DataLoader |
| from os.path import join as pjoin |
| from tqdm import tqdm |
| import numpy as np |
| from eval.evaluator_modules import * |
|
|
| from torch.utils.data._utils.collate import default_collate |
|
|
|
|
class GeneratedDataset(Dataset):
    """
    Dataset of motions generated by ``pipeline`` from the captions of a
    ground-truth ``dataset``, served in the same tuple format as the
    ground-truth data so downstream evaluators can consume either.

    Also builds a multimodality subset: ``mm_num_samples`` random captions
    for which ``mm_num_repeats`` motions each are generated (exposed via
    ``self.mm_generated_motion``).

    Fields read from ``opt``:
    opt.dataset_name
    opt.max_motion_length
    opt.unit_length
    (opt.device is also used to place the length tensors.)
    """

    def __init__(
        self, opt, pipeline, dataset, w_vectorizer, mm_num_samples, mm_num_repeats
    ):
        # mm indices are drawn without replacement, so there must be strictly
        # fewer mm samples than dataset entries.
        assert mm_num_samples < len(dataset)
        self.dataset = dataset
        dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
        generated_motion = []
        # Shortest allowed motion, expressed in multiples of opt.unit_length.
        min_mov_length = 10 if opt.dataset_name == "t2m" else 6

        mm_generated_motions = []
        if mm_num_samples > 0:
            # Sorted dataloader positions at which multimodal (repeated)
            # generation happens. Only defined when mm_num_samples > 0; the
            # is_mm checks below short-circuit before indexing it otherwise.
            mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
            mm_idxs = np.sort(mm_idxs)

        # ---- Pass 1: collect every caption and target length up front so the
        # pipeline can generate everything in a single batched call. Each raw
        # batch is cached in all_data because the dataloader shuffles, so a
        # second iteration would yield a different order.
        all_caption = []
        all_m_lens = []
        all_data = []
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader)):
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                all_data.append(data)
                tokens = tokens[0].split("_")
                # In this pass mm_generated_motions only acts as a counter of
                # mm samples consumed so far (a 0 placeholder is appended per
                # mm sample below); it is rebuilt with real entries in pass 2.
                mm_num_now = len(mm_generated_motions)
                is_mm = (
                    True
                    if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now]))
                    else False
                )
                repeat_times = mm_num_repeats if is_mm else 1
                # Snap the target length down to a multiple of unit_length,
                # then clamp to [min_mov_length * unit_length, max_motion_length].
                # NOTE: builtin max()/min() on a 1-element tensor vs a plain int
                # may return either type, hence the isinstance check below.
                m_lens = max(
                    torch.div(m_lens, opt.unit_length, rounding_mode="trunc")
                    * opt.unit_length,
                    min_mov_length * opt.unit_length,
                )
                m_lens = min(m_lens, opt.max_motion_length)
                if isinstance(m_lens, int):
                    m_lens = torch.LongTensor([m_lens]).to(opt.device)
                else:
                    m_lens = m_lens.to(opt.device)
                # Queue mm captions repeat_times times so the pipeline produces
                # several motions for the same caption.
                for t in range(repeat_times):
                    all_m_lens.append(m_lens)
                    all_caption.extend(caption)
                if is_mm:
                    mm_generated_motions.append(0)
        all_m_lens = torch.stack(all_m_lens)

        # ---- Single batched generation call; t_eval is the generation time
        # reported by the pipeline (surfaced via self.eval_generate_time).
        with torch.no_grad():
            all_pred_motions, t_eval = pipeline.generate(all_caption, all_m_lens)
        self.eval_generate_time = t_eval

        # ---- Pass 2: walk the cached batches in the original order and slice
        # the flat all_pred_motions list back into per-sample (and, for mm
        # samples, per-repeat) entries via cur_idx.
        cur_idx = 0
        mm_generated_motions = []  # reset: now accumulates the real mm dicts
        with torch.no_grad():
            for i, data_dummy in tqdm(enumerate(dataloader)):
                # data_dummy is ignored (shuffled order); use the cached batch
                # so indices line up with pass 1.
                data = all_data[i]
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                tokens = tokens[0].split("_")
                # Recompute is_mm exactly as in pass 1 (same counter trick),
                # so the repeat bookkeeping matches.
                mm_num_now = len(mm_generated_motions)
                is_mm = (
                    True
                    if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now]))
                    else False
                )
                repeat_times = mm_num_repeats if is_mm else 1
                mm_motions = []
                for t in range(repeat_times):
                    pred_motions = all_pred_motions[cur_idx]
                    cur_idx += 1
                    # Only the first repeat goes into the main generated set;
                    # all repeats go into the mm set below.
                    if t == 0:
                        sub_dict = {
                            "motion": pred_motions.cpu().numpy(),
                            "length": pred_motions.shape[0],
                            "caption": caption[0],
                            "cap_len": cap_lens[0].item(),
                            "tokens": tokens,
                        }
                        generated_motion.append(sub_dict)

                    if is_mm:
                        mm_motions.append(
                            {
                                "motion": pred_motions.cpu().numpy(),
                                "length": pred_motions.shape[
                                    0
                                ],
                            }
                        )
                if is_mm:
                    mm_generated_motions.append(
                        {
                            "caption": caption[0],
                            "tokens": tokens,
                            "cap_len": cap_lens[0].item(),
                            "mm_motions": mm_motions,
                        }
                    )
        self.generated_motion = generated_motion
        self.mm_generated_motion = mm_generated_motions
        self.opt = opt
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.generated_motion)

    def __getitem__(self, item):
        data = self.generated_motion[item]
        motion, m_length, caption, tokens = (
            data["motion"],
            data["length"],
            data["caption"],
            data["tokens"],
        )
        sent_len = data["cap_len"]

        # Re-normalize from the generator's statistics to the evaluator's:
        # undo the dataset normalization, then apply the eval mean/std.
        normed_motion = motion
        denormed_motion = self.dataset.inv_transform(normed_motion)
        renormed_motion = (
            denormed_motion - self.dataset.mean_for_eval
        ) / self.dataset.std_for_eval
        motion = renormed_motion

        # Per-token word embeddings and POS one-hots, stacked along axis 0.
        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)
        # Zero-pad the motion up to the fixed evaluation length; m_length
        # still reports the unpadded length.
        length = len(motion)
        if length < self.opt.max_motion_length:
            motion = np.concatenate(
                [
                    motion,
                    np.zeros((self.opt.max_motion_length - length, motion.shape[1])),
                ],
                axis=0,
            )
        return (
            word_embeddings,
            pos_one_hots,
            caption,
            sent_len,
            motion,
            m_length,
            "_".join(tokens),
        )
|
|
|
|
def collate_fn(batch):
    """Collate samples after ordering them by sent_len (field 3), longest first.

    Sorts ``batch`` in place so the collated tensors come out in descending
    caption-length order, then delegates to torch's default collation.
    """
    def _sent_len(sample):
        return sample[3]

    batch.sort(key=_sent_len, reverse=True)
    return default_collate(batch)
|
|
|
|
class MMGeneratedDataset(Dataset):
    """Multimodality view over a GeneratedDataset.

    Each item corresponds to one caption and yields every motion generated
    for it, zero-padded to ``opt.max_motion_length``, stacked, and sorted by
    descending length.
    """

    def __init__(self, opt, motion_dataset, w_vectorizer):
        self.opt = opt
        self.dataset = motion_dataset.mm_generated_motion
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        entry = self.dataset[item]
        max_len = self.opt.max_motion_length

        lengths = []
        padded = []
        for rep in entry["mm_motions"]:
            lengths.append(rep["length"])
            seq = rep["motion"]
            deficit = max_len - len(seq)
            if deficit > 0:
                # Zero-pad along the time axis up to the fixed length.
                pad = np.zeros((deficit, seq.shape[1]))
                seq = np.concatenate([seq, pad], axis=0)
            padded.append(seq[None, :])

        lengths = np.array(lengths, dtype=np.int32)
        stacked = np.concatenate(padded, axis=0)
        # Descending-length order for downstream packed-sequence consumers.
        order = np.argsort(lengths)[::-1].copy()
        return stacked[order], lengths[order]
|
|
|
|
def get_motion_loader(
    opt, batch_size, pipeline, ground_truth_dataset, mm_num_samples, mm_num_repeats
):
    """Build evaluation loaders over motions generated by ``pipeline``.

    Returns a (motion_loader, mm_motion_loader, generation_time) triple:
    the main loader serves one generated motion per caption, the mm loader
    serves the repeated-generation (multimodality) subset one caption per
    batch, and generation_time is the time the pipeline spent generating.
    """
    # Only the HumanML3D ("t2m") and KIT datasets ship the GloVe vocabulary
    # the evaluators need.
    if opt.dataset_name in ("t2m", "kit"):
        w_vectorizer = WordVectorizer(opt.glove_dir, "our_vab")
    else:
        raise KeyError("Dataset not recognized!!")

    generated = GeneratedDataset(
        opt,
        pipeline,
        ground_truth_dataset,
        w_vectorizer,
        mm_num_samples,
        mm_num_repeats,
    )
    mm_generated = MMGeneratedDataset(opt, generated, w_vectorizer)

    motion_loader = DataLoader(
        generated,
        batch_size=batch_size,
        collate_fn=collate_fn,
        drop_last=True,
        num_workers=4,
    )
    mm_motion_loader = DataLoader(mm_generated, batch_size=1, num_workers=1)

    return motion_loader, mm_motion_loader, generated.eval_generate_time
|
|