| import torch | |
| import copy | |
| import numpy as np | |
| from typing import OrderedDict | |
| from scipy.ndimage import gaussian_filter1d | |
| from transformers import PreTrainedModel | |
| from in2in.utils.configs import get_config | |
| from in2in.models.in2in import in2IN | |
| from in2in.utils.preprocess import MotionNormalizer | |
| from .config import in2INConfig | |
class in2INModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the in2IN generator.

    Builds the inference batch expected by ``in2IN.forward_test`` from text
    prompts and post-processes the raw output into two smoothed per-person
    joint sequences.
    """

    config_class = in2INConfig

    def __init__(self, config):
        """Build the wrapped in2IN model and the motion normalizer.

        Args:
            config: ``in2INConfig``; its ``MODE`` attribute is forwarded to
                ``in2IN`` and also controls whether individual prompts are
                used by ``forward``/``generate_loop``.
        """
        super().__init__(config)
        self.mode = config.MODE
        self.model = in2IN(config, mode=config.MODE)
        self.normalizer = MotionNormalizer()

    def forward(self, prompt_interaction, prompt_individual1=None,
                prompt_individual2=None, window_size=210):
        """Generate a two-person motion from text prompts.

        Args:
            prompt_interaction: Text describing the interaction.
            prompt_individual1: Text for the first individual. Required unless
                ``self.mode == "individual"``, in which case it is ignored.
            prompt_individual2: Text for the second individual; same rules as
                ``prompt_individual1``.
            window_size: Value written into every entry of ``motion_lens``;
                presumably the number of frames to generate (TODO confirm).
                Defaults to 210, the previously hard-coded value.

        Returns:
            The list of two numpy arrays produced by ``generate_loop``.

        Raises:
            ValueError: If an individual prompt is missing while not in
                ``"individual"`` mode.
        """
        self.model.eval()
        batch = OrderedDict()
        # Placeholder tensor; generate_loop overwrites its entries with
        # window_size on a deep copy.
        batch["motion_lens"] = torch.zeros(1, 1).long()
        batch["prompt_interaction"] = prompt_interaction
        if self.mode != "individual":
            if prompt_individual1 is None or prompt_individual2 is None:
                raise ValueError(
                    "prompt_individual1 and prompt_individual2 are required "
                    f"in mode {self.mode!r}"
                )
            batch["prompt_individual1"] = prompt_individual1
            batch["prompt_individual2"] = prompt_individual2
        return self.generate_loop(batch, window_size)

    def generate_loop(self, batch, window_size):
        """Run in2IN inference on ``batch`` and return smoothed joint sequences.

        ``batch`` is deep-copied first, so the caller's mapping (including its
        ``motion_lens`` tensor) is never mutated.

        Args:
            batch: Mapping with a ``"motion_lens"`` tensor, a
                ``"prompt_interaction"`` entry and, unless
                ``self.mode == "individual"``, ``"prompt_individual1"`` and
                ``"prompt_individual2"`` entries.
            window_size: Value assigned to every entry of ``motion_lens``.

        Returns:
            A list of two numpy arrays, one per slot of the size-2 axis of the
            reshaped model output (presumably the two individuals). Each has
            shape ``(N, 22, 3)`` and holds the first ``22 * 3`` features of
            the de-normalized output, Gaussian-smoothed with sigma=1 along
            axis 0 (``mode="nearest"``).
        """
        batch = copy.deepcopy(batch)
        batch["motion_lens"][:] = window_size
        batch["text"] = [batch["prompt_interaction"]]
        if self.mode != "individual":
            batch["text_individual1"] = [batch["prompt_individual1"]]
            batch["text_individual2"] = [batch["prompt_individual2"]]

        # Pure inference: never record an autograd graph over the sampler,
        # whatever the wrapped model does internally.
        with torch.no_grad():
            batch = self.model.forward_test(batch)

        output = batch["output"][0]
        # Split each leading-axis entry (presumably a frame) into 2 halves.
        motion_both = output.reshape(output.shape[0], 2, -1)
        # NOTE(review): presumably undoes dataset normalization — confirm.
        motion_both = self.normalizer.backward(motion_both.cpu().detach().numpy())

        sequences = []
        for person in range(2):
            joints3d = motion_both[:, person, :22 * 3].reshape(-1, 22, 3)
            sequences.append(
                gaussian_filter1d(joints3d, 1, axis=0, mode="nearest")
            )
        return sequences