| | |
| | import math |
| | import types |
| | from copy import deepcopy |
| | from einops import rearrange |
| | from typing import List |
| | import numpy as np |
| | import torch |
| | import torch.cuda.amp as amp |
| | import torch.nn as nn |
| |
|
| | def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values): |
| | pose_latents = self.pose_patch_embedding(pose_latents.to(self.pose_patch_embedding.weight.dtype)) |
| | x[:, :, 1:] += pose_latents |
| | |
| | b,c,T,h,w = face_pixel_values.shape |
| | face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") |
| | encode_bs = 8 |
| | face_pixel_values_tmp = [] |
| | for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): |
| | face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) |
| |
|
| | motion_vec = torch.cat(face_pixel_values_tmp) |
| | |
| | motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) |
| | motion_vec = self.face_encoder(motion_vec.to(self.face_encoder.conv1_local.conv.weight.dtype)) |
| |
|
| | B, L, H, C = motion_vec.shape |
| | pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) |
| | motion_vec = torch.cat([pad_face, motion_vec], dim=1) |
| | return x, motion_vec |
| |
|