Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| from einops import rearrange | |
| from lvdm.models.ddpm3d import LatentDiffusion | |
| from motionctrl.lvdm_modified_modules import ( | |
| TemporalTransformer_forward, selfattn_forward_unet, | |
| spatial_forward_BasicTransformerBlock, | |
| temporal_selfattn_forward_BasicTransformerBlock) | |
| from utils.utils import instantiate_from_config | |
class MotionCtrl(LatentDiffusion):
    """LatentDiffusion subclass that patches motion-control forwards into its UNet.

    On construction it optionally builds an object motion control module
    (``self.omcm``) and rebinds the forward methods of
    ``self.model.diffusion_model`` and of selected sub-modules to the
    replacements imported from ``motionctrl.lvdm_modified_modules``.
    """

    def __init__(self,
                 omcm_config=None,
                 pose_dim=12,
                 context_dim=1024,
                 *args,
                 **kwargs):
        """Build the base model, then install the motion-control hooks.

        Args:
            omcm_config: Config passed to ``instantiate_from_config`` to build
                the object motion control module; ``None`` leaves
                ``self.omcm`` as ``None``.
            pose_dim: Number of extra input features accepted by each
                ``cc_projection`` layer on top of the block's own width.
            context_dim: Width compared against ``attn2.to_k.in_features`` to
                decide which replacement ``_forward`` a
                ``BasicTransformerBlock`` receives.
            *args: Forwarded unchanged to ``LatentDiffusion.__init__``.
            **kwargs: Forwarded unchanged to ``LatentDiffusion.__init__``.
        """
        super().__init__(*args, **kwargs)

        # object motion control module
        self.omcm = (
            instantiate_from_config(omcm_config)
            if omcm_config is not None
            else None
        )

        def _rebind(target, attr_name, func):
            # Bind `func` to this one instance via the descriptor protocol so
            # it is called with `target` as `self`; the class is not modified.
            setattr(target, attr_name, func.__get__(target, target.__class__))

        # camera motion control module
        unet = self.model.diffusion_model
        _rebind(unet, 'forward', selfattn_forward_unet)

        # NOTE: `cc_projection` is registered on a block while named_modules()
        # is still being iterated, exactly as in the original statement order.
        for _, block in unet.named_modules():
            type_name = block.__class__.__name__

            if type_name == 'TemporalTransformer':
                _rebind(block, 'forward', TemporalTransformer_forward)

            if type_name != 'BasicTransformerBlock':
                continue

            feat_dim = block.attn2.to_k.in_features

            if feat_dim == context_dim:
                # attn2 keys take `context_dim` inputs: use the spatial forward.
                _rebind(block, '_forward',
                        spatial_forward_BasicTransformerBlock)
                continue

            # Otherwise (per the original note: TemporalTransformer without
            # crossattn) use the temporal self-attention forward and attach a
            # projection from (feat_dim + pose_dim) features back to feat_dim.
            _rebind(block, '_forward',
                    temporal_selfattn_forward_BasicTransformerBlock)

            projection = nn.Linear(feat_dim + pose_dim, feat_dim)
            # Identity on the leading feat_dim x feat_dim part of the weight;
            # the remaining pose columns keep the values nn.Linear gave them.
            nn.init.eye_(projection.weight[:feat_dim, :feat_dim])
            nn.init.zeros_(projection.bias)
            projection.requires_grad_(True)

            block.add_module('cc_projection', projection)

    def get_traj_features(self, extra_cond):
        """Run ``self.omcm`` frame by frame and restore the time axis.

        Args:
            extra_cond: 5-D tensor unpacked as ``(b, c, t, h, w)``.

        Returns:
            A list with one entry per feature returned by ``self.omcm``, each
            rearranged from ``(b t) c h w`` back to ``b c t h w``.
        """
        batch, _, frames, _, _ = extra_cond.shape

        ## process in 2D manner: fold the time axis into the batch axis
        flattened = rearrange(extra_cond, 'b c t h w -> (b t) c h w')

        restored = []
        for feature in self.omcm(flattened):
            restored.append(
                rearrange(feature, '(b t) c h w -> b c t h w',
                          b=batch, t=frames))
        return restored