# mdm/model/mdm_controlnet.py
# (web-scrape header residue removed: "hassanjbara's picture / update model / 5007d4b")
import contextlib

from numpy import ma
import torch
import torch.nn as nn
import torchvision as tv
from torchvision.models import ResNet18_Weights

from model.mdm import MDM
def load_conditioning_images(images, image_size=224):
    """Stack conditioning frames into a single batch tensor.

    The original ResNet18 preprocessing pipeline (resize, center-crop,
    ToTensor, ImageNet normalization) has been disabled; callers are now
    expected to pass tensors that are already transformed.

    Args:
        images: sequence of image tensors, each of shape [3, H, W]
            (e.g. the first and last frames of a clip).
        image_size: target side length for the disabled resize step; kept
            only for backward compatibility with existing callers and
            currently unused.

    Returns:
        Tensor of shape [len(images), 3, H, W] — e.g. [2, 3, H, W] for a
        first/last-frame pair (NOT [1, 3, H, W] as previously documented).
    """
    return torch.stack(images)
class MDMControlNet(nn.Module):
    """ControlNet-style wrapper around a frozen MDM backbone.

    Conditions the base motion-diffusion model on per-frame image
    embeddings (e.g. first/last frame) by injecting MLP residuals into the
    first ``residual_injectors_count`` transformer encoder layers of the
    base model. Only the ControlNet additions (image-encoder head and
    injectors, plus the encoder itself unless frozen) are trainable.
    """

    # The base MDM seqTransEncoder has exactly 8 layers, so at most 8
    # injectors can ever be applied in forward().
    MAX_INJECTORS = 8

    def __init__(self, base: MDM, args):
        """
        Args:
            base: Pretrained MDM model; all of its parameters are frozen.
            args: config namespace with at least ``use_fine_tuned_img_encoder``,
                ``freeze_img_encoder``, ``num_cond_frames``,
                ``residual_injectors_count`` and ``use_extra_mlp_layer``.

        Raises:
            ValueError: on inconsistent configuration (latent_dim mismatch
                for the fine-tuned encoder, or too many injectors). These
                were previously handled with ``exit(1)`` / a non-fatal print.
        """
        super().__init__()
        self.base = base

        # Mirror frequently-accessed base attributes so external code that
        # expects a plain MDM keeps working.  TODO: this is ugly — expose
        # via properties or __getattr__ delegation instead.
        self.dataset = base.dataset
        self.cond_mask_prob = base.cond_mask_prob
        self.rot2xyz = base.rot2xyz
        self.translation = base.translation
        self.njoints = base.njoints
        self.nfeats = base.nfeats
        self.data_rep = base.data_rep
        self.cond_mode = base.cond_mode
        self.encode_text = base.encode_text

        # Freeze the backbone; only ControlNet additions receive gradients.
        self.base.requires_grad_(False)

        if args.use_fine_tuned_img_encoder:
            # The shipped fine-tuned checkpoint was trained against a
            # 512-dim latent space; reject mismatches up front.
            if self.base.latent_dim != 512:
                raise ValueError(
                    "image encoder is pretrained only for latent_dim = 512"
                )
            self.img_encoder = tv.models.resnet18()
            img_encoder_weights = torch.load(
                "save/image_encoder.pt", weights_only=True, map_location="cpu"
            )
            self.img_encoder.load_state_dict(img_encoder_weights)
        else:
            self.img_encoder = tv.models.resnet18(weights=ResNet18_Weights.DEFAULT)

        # Project the ResNet output (fc.out_features wide) into the MDM
        # latent space.
        self.img_encoder_head = nn.Linear(
            self.img_encoder.fc.out_features, self.base.latent_dim
        )

        self.freeze_img_encoder = args.freeze_img_encoder
        if args.freeze_img_encoder:
            self.img_encoder.requires_grad_(False)
            self.img_encoder.eval()

        self.num_cond_frames = args.num_cond_frames
        self.residual_injectors_count = args.residual_injectors_count
        if self.residual_injectors_count > self.MAX_INJECTORS:
            # Previously this only printed an error and continued, silently
            # allocating injectors that forward() could never use.
            raise ValueError(
                "the MDM transformer only has 8 layers -> max residual "
                "injectors count is 8"
            )

        self.residual_injectors = nn.ModuleList(
            [
                self._make_injector(args.use_extra_mlp_layer)
                for _ in range(self.residual_injectors_count)
            ]
        )

    def _make_injector(self, use_extra_mlp_layer):
        """Build one residual-injection MLP mapping [2*latent] -> [latent]."""
        dim = self.base.latent_dim
        layers = [
            nn.Linear(2 * dim, 4 * dim),
            nn.LayerNorm(4 * dim),
            nn.GELU(),
        ]
        if use_extra_mlp_layer:
            # Optional extra hidden layer at the 4*dim width.
            layers += [
                nn.Linear(4 * dim, 4 * dim),
                nn.LayerNorm(4 * dim),
                nn.GELU(),
            ]
        layers.append(nn.Linear(4 * dim, dim))
        return nn.Sequential(*layers)

    def parameters_wo_clip(self):
        """Return all parameters except those of the CLIP text encoder."""
        return [
            p
            for name, p in self.named_parameters()
            if not name.startswith("clip_model.")
        ]

    def copied_block_parameters(self):
        """Trainable parameters of ``self.copied_blocks``.

        NOTE(review): ``copied_blocks`` is never assigned anywhere in this
        class, so calling this raises AttributeError unless a subclass or
        caller sets it — confirm whether this method is dead code.
        """
        return [p for p in self.copied_blocks.parameters() if p.requires_grad]

    def to(self, *args, **kwargs):
        """Move the module to a device/dtype, logging device moves."""
        device_or_dtype = args[0] if args else kwargs.get("device")
        self = super().to(*args, **kwargs)
        # img_encoder is a registered submodule, so super().to() already
        # moved it; the explicit call below is a cheap, harmless safety net.
        if hasattr(self, "img_encoder"):
            self.img_encoder = self.img_encoder.to(*args, **kwargs)
        if isinstance(device_or_dtype, (torch.device, str)):
            print(f"MDMControlNet moved to {device_or_dtype}")
        return self

    def process_images(self, cond_images, device):
        """Encode conditioning images into the MDM latent space.

        Args:
            cond_images: list of tensors, one per batch element, each either
                [num_frames, 3, H, W] (raw, already-normalized images) or an
                already-computed embedding (any non-4D tensor).
            device: device to run the encoder on.

        Returns:
            List of [num_frames, latent_dim] embedding tensors (or the
            input unchanged when it already holds embeddings).

        Raises:
            TypeError: if ``cond_images`` is not a non-empty list of tensors
                (previously printed "WTF" and killed the interpreter).
        """
        # Lazily migrate encoder + head if the caller is on another device.
        if next(self.img_encoder.parameters()).device != device:
            print(f"Moving image encoder to {device} for processing images")
            self.img_encoder = self.img_encoder.to(device)
            self.img_encoder_head = self.img_encoder_head.to(device)

        if (
            not isinstance(cond_images, list)
            or not cond_images
            or not isinstance(cond_images[0], torch.Tensor)
        ):
            raise TypeError(
                "cond_images must be a non-empty list of torch.Tensor, got "
                f"{type(cond_images).__name__}"
            )

        cond_images = [it.to(device) for it in cond_images]

        if cond_images[0].dim() != 4:
            # Assume these are already embeddings; nothing to encode.
            return cond_images

        # Skip autograd through a frozen encoder; otherwise inherit the
        # caller's grad mode (nullcontext — NOT set_grad_enabled(True),
        # which would force grads on even inside an outer no_grad()).
        grad_ctx = (
            torch.no_grad() if self.freeze_img_encoder else contextlib.nullcontext()
        )
        with grad_ctx:
            embeddings = [self.img_encoder(imgs) for imgs in cond_images]
        # The projection head is always trainable, so it runs outside the
        # no-grad context (matching the original control flow).
        return [self.img_encoder_head(emb) for emb in embeddings]

    def forward(self, motion, timesteps, y=None, cond_images=None, frame_indices=None):
        """Denoise ``motion``, optionally conditioned on image embeddings.

        Args:
            motion: [bs, num_joints, num_features, num_frames] noisy motion.
            timesteps: [bs] diffusion timesteps.
            y: optional conditioning dict; may carry ``cond_embedding``,
                ``cond_images`` and ``frame_indices``.
            cond_images: fallback conditioning images, used only when ``y``
                carries neither embeddings nor images.
            frame_indices: per-sample indices of the conditioned frames;
                falls back to ``y["frame_indices"]`` when omitted.

        Returns:
            Denoised motion with the same shape as ``motion``.
        """
        # Resolve image conditioning, preferring precomputed embeddings.
        if y and "cond_embedding" in y:
            image_embedding = y["cond_embedding"]
        elif y and "cond_images" in y:
            image_embedding = self.process_images(y["cond_images"], motion.device)
        elif cond_images is not None:
            image_embedding = self.process_images(cond_images, motion.device)
        else:
            # No image conditioning at all -> plain MDM forward pass.
            return self.base(motion, timesteps, y)

        if frame_indices is None:
            frame_indices = y["frame_indices"]

        base_sequence = self.base.motion_to_sequence(
            motion, timesteps, y
        )  # [num_frames + 1, bs, latent_dim]
        frames_mask = self.base.prepare_mask(
            base_sequence, motion.device, y, motion.shape[0]
        )
        base_sequence = self.base.sequence_pos_encoder(base_sequence)

        for idx, base_layer in enumerate(self.base.seqTransEncoder.layers):
            if idx < self.residual_injectors_count:
                # Add an image-conditioned residual at the conditioned frame
                # positions; +1 skips the conditioning token at position 0.
                for i in range(len(image_embedding)):
                    f_indices = frame_indices[i]
                    sequence = base_sequence[f_indices + 1, i]
                    base_sequence[f_indices + 1, i] += self.residual_injectors[idx](
                        torch.cat([image_embedding[i], sequence], dim=1)
                    )
            base_sequence = base_layer(base_sequence, src_key_padding_mask=frames_mask)

        return self.base.sequence_to_motion(base_sequence)