# NOTE(review): removed pasted build-log residue ("Spaces:" / "Build error" x2)
# that preceded the source and was not part of the program.
| from numpy import ma | |
| import torch | |
| import torch.nn as nn | |
| import torchvision as tv | |
| from torchvision.models import ResNet18_Weights | |
| from model.mdm import MDM | |
def load_conditioning_images(images, image_size=224):
    """Stack pre-processed conditioning frames into a single batch tensor.

    The original resize/center-crop/ToTensor/ImageNet-normalize pipeline was
    deliberately disabled (it used to run here); callers are now expected to
    pass tensors that are already preprocessed to the encoder's input format.

    Args:
        images: sequence of image tensors, each of shape [3, H, W].
            NOTE(review): despite the old docstring, PIL Images will NOT work
            here — torch.stack requires tensors.
        image_size: kept for interface compatibility with the old pipeline
            (ResNet18 expects 224x224); currently unused.

    Returns:
        Tensor of shape [len(images), 3, H, W] (NOT [1, 3, H, W] as the old
        comment claimed — stacking N frames yields a batch of N).
    """
    return torch.stack(images)
class MDMControlNet(nn.Module):
    """ControlNet-style wrapper around a frozen MDM motion-diffusion model.

    Encodes conditioning images with a ResNet18, projects the embeddings into
    the base model's latent space, and injects them as residuals into the
    first ``residual_injectors_count`` layers of the base transformer encoder
    at the conditioned frame positions. Only the encoder head / injectors
    (and optionally the image encoder) are trainable; the base MDM is frozen.
    """

    def __init__(self, base: MDM, args):
        """
        Args:
            base: pretrained MDM model; frozen here via ``requires_grad_(False)``.
            args: config namespace; reads ``use_fine_tuned_img_encoder``,
                ``freeze_img_encoder``, ``num_cond_frames``,
                ``residual_injectors_count`` and ``use_extra_mlp_layer``.
        """
        super().__init__()
        self.base = base
        # TODO this is ugly — mirror base attributes so code that expects a
        # plain MDM instance keeps working against this wrapper.
        self.dataset = base.dataset
        self.cond_mask_prob = base.cond_mask_prob
        self.rot2xyz = base.rot2xyz
        self.translation = base.translation
        self.njoints = base.njoints
        self.nfeats = base.nfeats
        self.data_rep = base.data_rep
        self.cond_mode = base.cond_mode
        self.encode_text = base.encode_text
        # Ensure the correct parameters are frozen
        self.base.requires_grad_(False)
        if args.use_fine_tuned_img_encoder:
            if self.base.latent_dim != 512:
                print("ERROR: image encoder is pretrained only for latent_dim = 512")
                exit(1)
            self.img_encoder = tv.models.resnet18()
            img_encoder_weights = torch.load(
                "save/image_encoder.pt", weights_only=True, map_location="cpu"
            )
            self.img_encoder.load_state_dict(img_encoder_weights)
        else:
            self.img_encoder = tv.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        # Project the ResNet classifier output (fc.out_features) into the
        # base transformer's latent space.
        self.img_encoder_head = nn.Linear(
            self.img_encoder.fc.out_features, self.base.latent_dim
        )
        self.freeze_img_encoder = args.freeze_img_encoder
        if args.freeze_img_encoder:
            self.img_encoder.requires_grad_(False)
            self.img_encoder.eval()
        self.num_cond_frames = args.num_cond_frames
        self.residual_injectors_count = args.residual_injectors_count
        if self.residual_injectors_count > 8:
            print(
                "ERROR: the MDM transformer only has 8 layers -> max residual injectors count is 8"
            )
            # FIX: previously this fell through and silently built injectors
            # that forward() can never apply; abort like the other fatal
            # config error above.
            exit(1)
        # One MLP per injected encoder layer. Input is the concatenation
        # [image_embedding ; frame_token] (2 * latent_dim); output is a
        # latent_dim residual added onto the frame token.
        self.residual_injectors = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(2 * self.base.latent_dim, 4 * self.base.latent_dim),
                    nn.LayerNorm(4 * self.base.latent_dim),
                    nn.GELU(),
                    nn.Linear(4 * self.base.latent_dim, 4 * self.base.latent_dim),
                    nn.LayerNorm(4 * self.base.latent_dim),
                    nn.GELU(),
                    nn.Linear(4 * self.base.latent_dim, self.base.latent_dim),
                )
                if args.use_extra_mlp_layer
                else nn.Sequential(
                    nn.Linear(2 * self.base.latent_dim, 4 * self.base.latent_dim),
                    nn.LayerNorm(4 * self.base.latent_dim),
                    nn.GELU(),
                    nn.Linear(4 * self.base.latent_dim, self.base.latent_dim),
                )
                for _ in range(self.residual_injectors_count)
            ]
        )

    def parameters_wo_clip(self):
        """Return all parameters except those of the (frozen) CLIP text model."""
        return [
            p
            for name, p in self.named_parameters()
            if not name.startswith("clip_model.")
        ]

    def copied_block_parameters(self):
        # NOTE(review): ``self.copied_blocks`` is never defined in this class,
        # so calling this raises AttributeError — possibly left over from a
        # ControlNet variant that copied encoder blocks. Confirm before use.
        return [p for p in self.copied_blocks.parameters() if p.requires_grad]

    def to(self, *args, **kwargs):
        """Move the module, explicitly including the image encoder, and log it."""
        # First call the parent class implementation (nn.Module.to)
        device_or_dtype = args[0] if args else kwargs.get("device")
        self = super().to(*args, **kwargs)
        # Ensure the image encoder is moved to the same device
        if hasattr(self, "img_encoder"):
            self.img_encoder = self.img_encoder.to(*args, **kwargs)
        # You might want to log this for debugging
        if isinstance(device_or_dtype, torch.device) or isinstance(
            device_or_dtype, str
        ):
            print(f"MDMControlNet moved to {device_or_dtype}")
        return self

    def process_images(self, cond_images, device):
        """Encode conditioning images into latent-space embeddings.

        Args:
            cond_images: list of per-sample tensors; each of shape
                [num_frames, 3, H, W] (anything else is treated as an
                already-computed embedding and returned as-is).
            device: device to run the encoder on.

        Returns:
            List of embedding tensors of shape [num_frames, latent_dim].
        """
        # Make sure img_encoder is on the same device as we're processing on
        if next(self.img_encoder.parameters()).device != device:
            print(f"Moving image encoder to {device} for processing images")
            self.img_encoder = self.img_encoder.to(device)
            self.img_encoder_head = self.img_encoder_head.to(device)
        if not isinstance(cond_images, list) or not isinstance(
            cond_images[0], torch.Tensor
        ):
            print("WTF")
            exit()
        cond_images = [it.to(device) for it in cond_images]
        if cond_images[0].dim() == 4:  # [num_frames, 3, H, W]
            # Skip autograd through the backbone when it is frozen.
            if self.freeze_img_encoder:
                with torch.no_grad():
                    embeddings = [self.img_encoder(imgs) for imgs in cond_images]
            else:
                embeddings = [self.img_encoder(imgs) for imgs in cond_images]
            # The head stays trainable even when the backbone is frozen.
            embeddings = [self.img_encoder_head(emb) for emb in embeddings]
            return embeddings
        else:
            # Assume it's already an embedding
            return cond_images

    def forward(self, motion, timesteps, y=None, cond_images=None, frame_indices=None):
        """Denoise one diffusion step, with optional image conditioning.

        Args:
            motion: [bs, num_joints, num_features, num_frames]
            timesteps: [bs]
            y: condition dict; may carry ``cond_embedding``, ``cond_images``
                and ``frame_indices``.
            cond_images: fallback conditioning images when not provided in y.
            frame_indices: per-sample indices of the conditioned frames;
                falls back to ``y["frame_indices"]`` when None.
        """
        # First check if we have precomputed embeddings
        if y and "cond_embedding" in y:
            image_embedding = y["cond_embedding"]
        elif y and "cond_images" in y:
            cond_images = y["cond_images"]
            image_embedding = self.process_images(cond_images, motion.device)
        elif cond_images is not None:
            image_embedding = self.process_images(cond_images, motion.device)
        else:
            # No conditioning available: behave exactly like the base MDM.
            return self.base(motion, timesteps, y)
        if frame_indices is None:
            frame_indices = y["frame_indices"]
        base_sequence = self.base.motion_to_sequence(
            motion, timesteps, y
        )  # [num_frames + 1, bs, latent_dim]
        frames_mask = self.base.prepare_mask(
            base_sequence, motion.device, y, motion.shape[0]
        )
        base_sequence = self.base.sequence_pos_encoder(base_sequence)
        # Run the base transformer layer by layer; before each of the first
        # residual_injectors_count layers, add an image-conditioned residual
        # onto the conditioned frame tokens (offset by 1 for the time token).
        for idx, base_layer in enumerate(self.base.seqTransEncoder.layers):
            if idx < self.residual_injectors_count:
                for i in range(len(image_embedding)):
                    f_indices = frame_indices[i]
                    # Read the current tokens BEFORE the in-place add so the
                    # injector sees the pre-residual values.
                    sequence = base_sequence[f_indices + 1, i]
                    base_sequence[f_indices + 1, i] += self.residual_injectors[idx](
                        torch.cat([image_embedding[i], sequence], dim=1)
                    )
            base_sequence = base_layer(base_sequence, src_key_padding_mask=frames_mask)
        motion = self.base.sequence_to_motion(base_sequence)
        return motion