import torch
from torch import nn
from einops import rearrange
from polygraphy.backend.trt import Profile

class unet_work(nn.Module):  # "Ugly Power Strip": bundles pose_guider, motion_encoder, unet, vae, and scheduler into one exportable module
    def __init__(self, pose_guider, motion_encoder, unet, vae, scheduler, timestep):
        super().__init__()
        self.pose_guider = pose_guider
        self.motion_encoder = motion_encoder
        self.unet = unet
        self.vae = vae
        self.scheduler = scheduler
        self.timesteps = timestep

    def decode_slice(self, vae, x):
        x = x / 0.18215                          # undo the SD VAE latent scaling factor
        x = vae.decode(x).sample
        x = rearrange(x, "b c h w -> b h w c")   # channels-last for image output
        x = (x / 2 + 0.5).clamp(0, 1)            # map [-1, 1] -> [0, 1]
        return x

    def forward(self, sample, encoder_hidden_states, motion_hidden_states, motion, pose_cond_fea, pose, new_noise,
                d00, d01, d10, d11, d20, d21, m, u10, u11, u12, u20, u21, u22, u30, u31, u32):
        # Encode the incoming pose/motion window and append it to the cached features.
        new_pose_cond_fea = self.pose_guider(pose)
        pose_cond_fea = torch.cat([pose_cond_fea, new_pose_cond_fea], dim=2)
        new_motion_hidden_states = self.motion_encoder(motion)
        motion_hidden_states = torch.cat([motion_hidden_states, new_motion_hidden_states], dim=1)
        encoder_hidden_states = [encoder_hidden_states, motion_hidden_states]
        score = self.unet(sample, self.timesteps, encoder_hidden_states, pose_cond_fea,
                          d00, d01, d10, d11, d20, d21, m, u10, u11, u12, u20, u21, u22, u30, u31, u32)
        # The scheduler steps per frame, so fold the frame axis into the batch axis.
        score = rearrange(score, 'b c f h w -> (b f) c h w')
        sample = rearrange(sample, 'b c f h w -> (b f) c h w')
        latents_model_input, pred_original_sample = self.scheduler.step(
            score, self.timesteps, sample, return_dict=False
        )
        latents_model_input = latents_model_input.to(sample.dtype)
        pred_original_sample = pred_original_sample.to(sample.dtype)
        latents_model_input = rearrange(latents_model_input, '(b f) c h w -> b c f h w', f=16)  # f must equal the temporal batch size (tb = 16)
        # Decode the 4 oldest frames for output, then slide the window:
        # drop those frames from every cache and append fresh noise.
        pred_video = self.decode_slice(self.vae, pred_original_sample[:4])
        latents = torch.cat([latents_model_input[:, :, 4:, :, :], new_noise], dim=2)
        pose_cond_fea_out = pose_cond_fea[:, :, 4:, :, :]
        motion_hidden_states_out = motion_hidden_states[:, 4:, :, :]
        motion_out = motion_hidden_states[:, :1, :, :]
        return pred_video, latents, pose_cond_fea_out, motion_hidden_states_out, motion_out, pred_original_sample[:1]

    def get_sample_input(self, batchsize, height, width, dtype, device):
        tw, ts, tb = 4, 4, 16                # temporal window size | temporal adaptive steps | temporal batch size
        ml, mc, mh, mw = 32, 16, 224, 224    # motion latent size | motion channels | motion height | width
        b, h, w = batchsize, height, width
        lh, lw = height // 8, width // 8     # latent height | width
        cd0, cd1, cd2, cm, cu1, cu2, cu3 = 320, 640, 1280, 1280, 1280, 640, 320  # unet channels
        emb = 768                            # CLIP Embedding Dims | TAESDV Channels
        lc, ic = 4, 3                        # latent | image channels
        profile = {
            "sample": [b, lc, tb, lh, lw],
            "encoder_hidden_states": [b, 1, emb],
            "motion_hidden_states": [b, tw * (ts - 1), ml, mc],
            "motion": [b, ic, tw, mh, mw],
            "pose_cond_fea": [b, cd0, tw * (ts - 1), lh, lw],
            "pose": [b, ic, tw, h, w],
            "new_noise": [b, lc, tw, lh, lw],
            "d00": [b, lh * lw, cd0],
            "d01": [b, lh * lw, cd0],
            "d10": [b, lh * lw // 4, cd1],
            "d11": [b, lh * lw // 4, cd1],
            "d20": [b, lh * lw // 16, cd2],
            "d21": [b, lh * lw // 16, cd2],
            "m": [b, lh * lw // 64, cm],
            "u10": [b, lh * lw // 16, cu1],
            "u11": [b, lh * lw // 16, cu1],
            "u12": [b, lh * lw // 16, cu1],
            "u20": [b, lh * lw // 4, cu2],
            "u21": [b, lh * lw // 4, cu2],
            "u22": [b, lh * lw // 4, cu2],
            "u30": [b, lh * lw, cu3],
            "u31": [b, lh * lw, cu3],
            "u32": [b, lh * lw, cu3],
        }
        return {k: torch.randn(profile[k], dtype=dtype, device=device) for k in profile}

    def get_input_names(self):
        return ["sample", "encoder_hidden_states", "motion_hidden_states",
                "motion", "pose_cond_fea", "pose", "new_noise",
                "d00", "d01", "d10", "d11", "d20", "d21", "m", "u10", "u11", "u12",
                "u20", "u21", "u22", "u30", "u31", "u32"]

    def get_output_names(self):
        return ["pred_video", "latents", "pose_cond_fea_out",
                "motion_hidden_states_out", "motion_out", "latent_first"]

    def get_dynamic_axes(self):
        dynamic_axes = {
            "sample": {3: "h_64", 4: "w_64"},
            "pose_cond_fea": {3: "h_64", 4: "w_64"},
            "pose": {3: "h_512", 4: "w_512"},
            "new_noise": {3: "h_64", 4: "w_64"},
            "d00": {1: "len_4096"},
            "d01": {1: "len_4096"},
            "u30": {1: "len_4096"},
            "u31": {1: "len_4096"},
            "u32": {1: "len_4096"},
            "d10": {1: "len_1024"},
            "d11": {1: "len_1024"},
            "u20": {1: "len_1024"},
            "u21": {1: "len_1024"},
            "u22": {1: "len_1024"},
            "d20": {1: "len_256"},
            "d21": {1: "len_256"},
            "u10": {1: "len_256"},
            "u11": {1: "len_256"},
            "u12": {1: "len_256"},
            "m": {1: "len_64"},
        }
        return dynamic_axes
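
    # A minimal ONNX export sketch using the helpers above, assuming the five
    # submodules are already constructed (the file name and opset below are
    # illustrative, not the project's actual settings):
    #
    #   model = unet_work(pose_guider, motion_encoder, unet, vae, scheduler, timestep)
    #   inputs = model.get_sample_input(1, 512, 512, torch.float16, "cuda")
    #   torch.onnx.export(
    #       model,
    #       tuple(inputs[k] for k in model.get_input_names()),
    #       "unet_work.onnx",
    #       input_names=model.get_input_names(),
    #       output_names=model.get_output_names(),
    #       dynamic_axes=model.get_dynamic_axes(),
    #       opset_version=17,
    #   )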

    def get_dynamic_map(self, batchsize, height, width):
        tw, ts, tb = 4, 4, 16                # temporal window size | temporal adaptive steps | temporal batch size
        ml, mc, mh, mw = 32, 16, 224, 224    # motion latent size | motion channels | motion height | width
        b, h, w = batchsize, height, width
        lh, lw = height // 8, width // 8     # latent height | width
        cd0, cd1, cd2, cm, cu1, cu2, cu3 = 320, 640, 1280, 1280, 1280, 640, 320  # unet channels
        emb = 768                            # CLIP Embedding Dims | TAESDV Channels
        lc, ic = 4, 3                        # latent | image channels
        fixed_inputs_map = {
            "sample": (b, lc, tb, lh, lw),
            "encoder_hidden_states": (b, 1, emb),
            "motion_hidden_states": (b, tw * (ts - 1), ml, mc),
            "motion": (b, ic, tw, mh, mw),
            "pose_cond_fea": (b, cd0, tw * (ts - 1), lh, lw),
            "pose": (b, ic, tw, h, w),
            "new_noise": (b, lc, tw, lh, lw),
        }
        dynamic_inputs_map = {
            "d00": (b, lh * lw, cd0),
            "d01": (b, lh * lw, cd0),
            "d10": (b, lh * lw // 4, cd1),
            "d11": (b, lh * lw // 4, cd1),
            "d20": (b, lh * lw // 16, cd2),
            "d21": (b, lh * lw // 16, cd2),
            "m": (b, lh * lw // 64, cm),
            "u10": (b, lh * lw // 16, cu1),
            "u11": (b, lh * lw // 16, cu1),
            "u12": (b, lh * lw // 16, cu1),
            "u20": (b, lh * lw // 4, cu2),
            "u21": (b, lh * lw // 4, cu2),
            "u22": (b, lh * lw // 4, cu2),
            "u30": (b, lh * lw, cu3),
            "u31": (b, lh * lw, cu3),
            "u32": (b, lh * lw, cu3),
        }
        profile = Profile()
        for name, shape in fixed_inputs_map.items():
            shape_tuple = tuple(shape)
            profile.add(name, min=shape_tuple, opt=shape_tuple, max=shape_tuple)
        for name, base_shape in dynamic_inputs_map.items():
            dim0, dim1_base, dim2 = base_shape
            # Sequence length grows with latent area, so allow 1x (base resolution)
            # up to 4x (double the resolution in each spatial dimension).
            min_shape = (dim0, dim1_base * 1, dim2)
            opt_shape = (dim0, dim1_base * 2, dim2)
            max_shape = (dim0, dim1_base * 4, dim2)
            profile.add(name, min=min_shape, opt=opt_shape, max=max_shape)
            print(f"Dynamic: {name:<5} | Base(1x): {dim1_base:<5} | Range: {dim1_base * 1} ~ {dim1_base * 4} | Opt: {dim1_base * 2}")
        return profile
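
# A minimal engine-build sketch with polygraphy, assuming the ONNX file from the
# export sketch above (the paths and the fp16 choice are illustrative):
#
#   from polygraphy.backend.trt import CreateConfig, engine_from_network, network_from_onnx_path, save_engine
#
#   model = unet_work(pose_guider, motion_encoder, unet, vae, scheduler, timestep)
#   trt_profile = model.get_dynamic_map(batchsize=1, height=512, width=512)
#   engine = engine_from_network(
#       network_from_onnx_path("unet_work.onnx"),
#       config=CreateConfig(fp16=True, profiles=[trt_profile]),
#   )
#   save_engine(engine, "unet_work.plan")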