| import numpy as np |
| import torch |
| from torch import nn |
|
|
| from typing import Optional, Dict, List |
|
|
| from mogen.models.utils.misc import zero_module |
|
|
| from ..builder import SUBMODULES, build_attention |
| from ..utils.stylization_block import StylizationBlock |
| from .motion_transformer import MotionTransformer |
|
|
|
|
def get_kit_slice(idx: int) -> List[int]:
    """
    Get the slice indices for the KIT skeleton.

    Args:
        idx (int): The index of the skeleton part.

    Returns:
        List[int]: Slice indices for the specified skeleton part.
    """
    # Part 0 (root): root channels plus global velocity and foot-contact channels.
    if idx == 0:
        return [0, 1, 2, 3, 184, 185, 186, 247, 248, 249, 250]
    # Feature layout: joint positions start at channel 4 (3 per joint),
    # rotations at 64 (6 per joint), velocities at 184 (3 per joint).
    pos_base = 4 + (idx - 1) * 3
    rot_base = 64 + (idx - 1) * 6
    vel_base = 184 + idx * 3
    return ([pos_base + k for k in range(3)] +
            [rot_base + k for k in range(6)] +
            [vel_base + k for k in range(3)])
|
|
|
|
def get_t2m_slice(idx: int) -> List[int]:
    """
    Get the slice indices for the T2M (HumanML3D) skeleton.

    Args:
        idx (int): The index of the skeleton part.

    Returns:
        List[int]: Slice indices for the specified skeleton part.
    """
    # Part 0 (root): root channels plus global velocity and foot-contact channels.
    if idx == 0:
        return [0, 1, 2, 3, 193, 194, 195, 259, 260, 261, 262]
    # Feature layout: joint positions start at channel 4 (3 per joint),
    # rotations at 67 (6 per joint), velocities at 193 (3 per joint).
    pos_base = 4 + (idx - 1) * 3
    rot_base = 67 + (idx - 1) * 6
    vel_base = 193 + idx * 3
    return ([pos_base + k for k in range(3)] +
            [rot_base + k for k in range(6)] +
            [vel_base + k for k in range(3)])
|
|
|
|
def get_part_slice(idx_list: List[int], func) -> List[int]:
    """
    Get the slice indices for a list of part indices.

    Args:
        idx_list (List[int]): List of part indices.
        func (Callable): Function mapping one part index to its channel indices.

    Returns:
        List[int]: Concatenated channel indices for all requested parts.
    """
    # Flatten the per-part index lists, preserving part order.
    return [channel for idx in idx_list for channel in func(idx)]
|
|
|
|
class PoseEncoder(nn.Module):
    """
    Encode raw motion features into per-body-part latent vectors.

    The raw feature vector is split into seven disjoint part groups (head,
    stem, left/right arm, left/right leg, root) plus one whole-body group;
    each group is linearly projected to ``latent_dim`` and the projections
    are concatenated along the channel axis.
    """

    # Fixed part order; this is also the concatenation order of the output.
    PART_NAMES = ("head", "stem", "larm", "rarm", "lleg", "rleg", "root",
                  "body")

    def __init__(self,
                 dataset_name: str = "human_ml3d",
                 latent_dim: int = 64,
                 input_dim: int = 263):
        super().__init__()
        self.dataset_name = dataset_name
        if dataset_name == "human_ml3d":
            func = get_t2m_slice
            joints = {
                "head": [12, 15],
                "stem": [3, 6, 9],
                "larm": [14, 17, 19, 21],
                "rarm": [13, 16, 18, 20],
                "lleg": [2, 5, 8, 11],
                "rleg": [1, 4, 7, 10],
                "root": [0],
                "body": list(range(22)),
            }
        elif dataset_name == "kit_ml":
            func = get_kit_slice
            joints = {
                "head": [4],
                "stem": [1, 2, 3],
                "larm": [8, 9, 10],
                "rarm": [5, 6, 7],
                "lleg": [16, 17, 18, 19, 20],
                "rleg": [11, 12, 13, 14, 15],
                "root": [0],
                "body": list(range(21)),
            }
        else:
            raise ValueError()

        # One channel-index list and one linear projection per part group;
        # attribute names (e.g. head_slice / head_embed) match the part order.
        for name in self.PART_NAMES:
            part_slice = get_part_slice(joints[name], func)
            setattr(self, name + "_slice", part_slice)
            setattr(self, name + "_embed",
                    nn.Linear(len(part_slice), latent_dim))

        # The whole-body group must cover every raw input channel exactly once.
        assert len(set(self.body_slice)) == input_dim

    def forward(self, motion: torch.Tensor) -> torch.Tensor:
        """
        Project each body-part channel group and concatenate the results.

        Args:
            motion (torch.Tensor): Input motion tensor of shape (B, T, D).

        Returns:
            torch.Tensor: Part latents of shape (B, T, 8 * latent_dim).
        """
        feats = []
        for name in self.PART_NAMES:
            part_slice = getattr(self, name + "_slice")
            embed = getattr(self, name + "_embed")
            feats.append(embed(motion[:, :, part_slice].contiguous()))
        return torch.cat(feats, dim=-1)
|
|
|
|
class PoseDecoder(nn.Module):
    """
    Pose Decoder to decode the latent representations of body parts back into motion.

    The input latent is split into 8 equal chunks of ``latent_dim`` channels
    (head, stem, left/right arm, left/right leg, root, whole body). Each chunk
    is projected back to the raw channels of its part, and the part-wise
    reconstruction is averaged with the whole-body reconstruction.

    Args:
        dataset_name (str): Either "human_ml3d" or "kit_ml".
        latent_dim (int): Dimensionality of each part latent chunk.
        output_dim (int): Dimensionality of the raw motion representation.
    """

    def __init__(self,
                 dataset_name: str = "human_ml3d",
                 latent_dim: int = 64,
                 output_dim: int = 263):
        super().__init__()
        self.dataset_name = dataset_name
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        # Part-to-channel maps; must mirror PoseEncoder's definitions.
        if dataset_name == "human_ml3d":
            func = get_t2m_slice
            self.head_slice = get_part_slice([12, 15], func)
            self.stem_slice = get_part_slice([3, 6, 9], func)
            self.larm_slice = get_part_slice([14, 17, 19, 21], func)
            self.rarm_slice = get_part_slice([13, 16, 18, 20], func)
            self.lleg_slice = get_part_slice([2, 5, 8, 11], func)
            self.rleg_slice = get_part_slice([1, 4, 7, 10], func)
            self.root_slice = get_part_slice([0], func)
            self.body_slice = get_part_slice([_ for _ in range(22)], func)
        elif dataset_name == "kit_ml":
            func = get_kit_slice
            self.head_slice = get_part_slice([4], func)
            self.stem_slice = get_part_slice([1, 2, 3], func)
            self.larm_slice = get_part_slice([8, 9, 10], func)
            self.rarm_slice = get_part_slice([5, 6, 7], func)
            self.lleg_slice = get_part_slice([16, 17, 18, 19, 20], func)
            self.rleg_slice = get_part_slice([11, 12, 13, 14, 15], func)
            self.root_slice = get_part_slice([0], func)
            self.body_slice = get_part_slice([_ for _ in range(21)], func)
        else:
            raise ValueError()

        # One linear head per part, mapping the part latent back to the
        # raw channels of that part.
        self.head_out = nn.Linear(latent_dim, len(self.head_slice))
        self.stem_out = nn.Linear(latent_dim, len(self.stem_slice))
        self.larm_out = nn.Linear(latent_dim, len(self.larm_slice))
        self.rarm_out = nn.Linear(latent_dim, len(self.rarm_slice))
        self.lleg_out = nn.Linear(latent_dim, len(self.lleg_slice))
        self.rleg_out = nn.Linear(latent_dim, len(self.rleg_slice))
        self.root_out = nn.Linear(latent_dim, len(self.root_slice))
        self.body_out = nn.Linear(latent_dim, len(self.body_slice))

    def forward(self, motion: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to decode the latent body part features back to motion.

        Args:
            motion (torch.Tensor): Input tensor of shape (B, T, 8 * latent_dim).

        Returns:
            torch.Tensor: Output motion tensor of shape (B, T, output_dim).
        """
        B, T = motion.shape[:2]
        D = self.latent_dim
        head_feat = self.head_out(motion[:, :, :D].contiguous())
        stem_feat = self.stem_out(motion[:, :, D:2 * D].contiguous())
        larm_feat = self.larm_out(motion[:, :, 2 * D:3 * D].contiguous())
        rarm_feat = self.rarm_out(motion[:, :, 3 * D:4 * D].contiguous())
        lleg_feat = self.lleg_out(motion[:, :, 4 * D:5 * D].contiguous())
        rleg_feat = self.rleg_out(motion[:, :, 5 * D:6 * D].contiguous())
        root_feat = self.root_out(motion[:, :, 6 * D:7 * D].contiguous())
        body_feat = self.body_out(motion[:, :, 7 * D:].contiguous())
        output = torch.zeros(B, T, self.output_dim).type_as(motion)
        output[:, :, self.head_slice] = head_feat
        output[:, :, self.stem_slice] = stem_feat
        output[:, :, self.larm_slice] = larm_feat
        output[:, :, self.rarm_slice] = rarm_feat
        output[:, :, self.lleg_slice] = lleg_feat
        output[:, :, self.rleg_slice] = rleg_feat
        output[:, :, self.root_slice] = root_feat
        # BUGFIX: body_feat's channels are ordered by self.body_slice, which is
        # a *permutation* of the raw channels (e.g. get_t2m_slice(0) starts
        # [0, 1, 2, 3, 193, ...]). The previous `output = (output + body_feat) / 2.0`
        # averaged mismatched channels; the whole-body reconstruction must be
        # scattered through body_slice before averaging.
        output[:, :, self.body_slice] = (output[:, :, self.body_slice] +
                                         body_feat) / 2.0
        return output
|
|
|
|
class SFFN(nn.Module):
    """
    A Stylized Feed-Forward Network (SFFN) module for transformer layers.

    Applies an independent two-layer MLP ("expert") to each of the 8
    body-part latent groups, then adds a time-modulated residual through a
    StylizationBlock.

    Args:
        latent_dim (int): Dimensionality of each of the 8 part groups.
        ffn_dim (int): Hidden dimensionality of each per-part MLP.
        dropout (float): Dropout probability.
        time_embed_dim (int): Dimensionality of the time embedding.
        norm (str): Normalization type; only 'None' is supported.
        activation (str): Activation function; only 'GELU' is supported.

    Raises:
        ValueError: If ``norm`` or ``activation`` is not supported.
    """

    # Number of independent body-part branches (matches PoseEncoder output).
    NUM_PARTS = 8

    def __init__(self,
                 latent_dim: int,
                 ffn_dim: int,
                 dropout: float,
                 time_embed_dim: int,
                 norm: str = "None",
                 activation: str = "GELU",
                 **kwargs):
        super().__init__()
        # Fail fast on unsupported configs: previously an unknown activation
        # or norm left the attribute unset and surfaced later as a confusing
        # AttributeError inside forward().
        if activation == "GELU":
            self.activation = nn.GELU()
        else:
            raise ValueError(f"Unsupported activation: {activation!r}")

        self.linear1_list = nn.ModuleList()
        self.linear2_list = nn.ModuleList()
        for _ in range(self.NUM_PARTS):
            self.linear1_list.append(nn.Linear(latent_dim, ffn_dim))
            self.linear2_list.append(nn.Linear(ffn_dim, latent_dim))

        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim * self.NUM_PARTS,
                                         time_embed_dim, dropout)

        if norm == "None":
            self.norm = nn.Identity()
        else:
            raise ValueError(f"Unsupported norm: {norm!r}")

    def forward(self, x: torch.Tensor, emb: torch.Tensor,
                **kwargs) -> torch.Tensor:
        """
        Forward pass of the SFFN layer.

        Args:
            x (torch.Tensor): Input tensor of shape (B, T, D) with
                D == NUM_PARTS * latent_dim.
            emb (torch.Tensor): Embedding tensor for time step modulation.

        Returns:
            torch.Tensor: Output tensor of shape (B, T, D).
        """
        B, T, D = x.shape
        x = self.norm(x)
        # Split the channel axis into the part groups and run each group
        # through its own expert MLP.
        x = x.reshape(B, T, self.NUM_PARTS, -1)
        output = []
        for i in range(self.NUM_PARTS):
            feat = x[:, :, i].contiguous()
            feat = self.dropout(self.activation(self.linear1_list[i](feat)))
            feat = self.linear2_list[i](feat)
            output.append(feat)
        y = torch.cat(output, dim=-1)
        # Residual connection with time-conditioned stylization.
        y = x.reshape(B, T, D) + self.proj_out(y, emb)
        return y
|
|
|
|
class DecoderLayer(nn.Module):
    """
    A transformer decoder layer: cross-attention followed by a stylized
    feed-forward network (SFFN).

    Args:
        ca_block_cfg (Optional[Dict]): Configuration for the cross-attention block.
        ffn_cfg (Optional[Dict]): Configuration for the feed-forward network (SFFN).
    """

    def __init__(self,
                 ca_block_cfg: Optional[Dict] = None,
                 ffn_cfg: Optional[Dict] = None):
        super().__init__()
        self.ca_block = build_attention(ca_block_cfg)
        self.ffn = SFFN(**ffn_cfg)

    def forward(self, **kwargs) -> torch.Tensor:
        """
        Run the cross-attention block, then the SFFN.

        The attention output replaces ``kwargs['x']`` before the SFFN runs,
        so the SFFN always sees the attended features. A stage is skipped
        when the corresponding sub-module is ``None``.

        Args:
            kwargs: Keyword arguments forwarded to both sub-modules.

        Returns:
            torch.Tensor: Output features of the layer.
        """
        attn = self.ca_block
        if attn is not None:
            kwargs['x'] = attn(**kwargs)
            x = kwargs['x']
        if self.ffn is not None:
            x = self.ffn(**kwargs)
        return x
|
|
|
|
@SUBMODULES.register_module()
class FineMoGenTransformer(MotionTransformer):
    """
    A transformer model for motion generation using fine-grained control with Diffusion.

    Args:
        scale_func_cfg (Optional[Dict]): Configuration for the classifier-free
            guidance scaling function; must provide a 'scale' entry.
        pose_encoder_cfg (Optional[Dict]): Configuration for the PoseEncoder.
        pose_decoder_cfg (Optional[Dict]): Configuration for the PoseDecoder.
        moe_route_loss_weight (float): Weight for the Mixture of Experts (MoE) routing loss.
        template_kl_loss_weight (float): Weight for the KL loss in template generation.
        fine_mode (bool): If True, a separate condition code is sampled for
            each of the 8 body parts during training.
    """

    def __init__(self,
                 scale_func_cfg: Optional[Dict] = None,
                 pose_encoder_cfg: Optional[Dict] = None,
                 pose_decoder_cfg: Optional[Dict] = None,
                 moe_route_loss_weight: float = 1.0,
                 template_kl_loss_weight: float = 0.0001,
                 fine_mode: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        self.scale_func_cfg = scale_func_cfg
        self.joint_embed = PoseEncoder(**pose_encoder_cfg)
        # Zero-init the output decoder so the model initially predicts zeros,
        # the standard trick to stabilize early diffusion training.
        self.out = zero_module(PoseDecoder(**pose_decoder_cfg))
        self.moe_route_loss_weight = moe_route_loss_weight
        self.template_kl_loss_weight = template_kl_loss_weight
        # NOTE(review): removed the unconditional
        # np.load("data/datasets/kit_ml/{mean,std}.npy") calls that used to run
        # here. The resulting attributes were never read in this class
        # (post_process loads statistics from post_process_cfg paths), and the
        # hard-coded kit_ml path crashed construction for any other dataset or
        # any machine without that exact file.
        self.fine_mode = fine_mode

    def build_temporal_blocks(self,
                              sa_block_cfg: Optional[Dict],
                              ca_block_cfg: Optional[Dict],
                              ffn_cfg: Optional[Dict]):
        """
        Build the stack of temporal decoder blocks.

        Args:
            sa_block_cfg (Optional[Dict]): Self-attention config; unused here
                (FineMoGen layers use only cross-attention) but kept for
                interface compatibility with the base class.
            ca_block_cfg (Optional[Dict]): Configuration for cross-attention
                blocks, shared by all layers.
            ffn_cfg (Optional[Dict]): Feed-forward config; either one shared
                dict or a list with one entry per layer.
        """
        self.temporal_decoder_blocks = nn.ModuleList()
        for i in range(self.num_layers):
            # Allow either a single shared FFN config or one per layer.
            ffn_cfg_block = ffn_cfg[i] if isinstance(ffn_cfg, list) else ffn_cfg
            self.temporal_decoder_blocks.append(
                DecoderLayer(ca_block_cfg=ca_block_cfg, ffn_cfg=ffn_cfg_block))

    def scale_func(self, timestep: int) -> Dict[str, float]:
        """
        Classifier-free guidance coefficients for the given timestep.

        The text coefficient grows linearly from 1 (timestep 0) to
        1 + scale (timestep 1000); the unconditional branch is weighted
        with 1 - w so the two always sum to 1.

        Args:
            timestep (int): Current diffusion timestep.

        Returns:
            Dict[str, float]: Scaling factors for text and non-text conditioning.
        """
        scale = self.scale_func_cfg['scale']
        w = (1 - (1000 - timestep) / 1000) * scale + 1
        return {'text_coef': w, 'none_coef': 1 - w}

    def aux_loss(self) -> Dict[str, torch.Tensor]:
        """
        Auxiliary loss computation for MoE routing and template KL.

        Sums the per-layer losses exposed by the cross-attention blocks
        (when present) and applies the configured weights.

        Returns:
            Dict[str, torch.Tensor]: Computed auxiliary losses; a key is only
                present when the corresponding accumulated loss is positive.
        """
        aux_loss = 0
        kl_loss = 0
        for module in self.temporal_decoder_blocks:
            if hasattr(module.ca_block, 'aux_loss'):
                aux_loss = aux_loss + module.ca_block.aux_loss
            if hasattr(module.ca_block, 'kl_loss'):
                kl_loss = kl_loss + module.ca_block.kl_loss
        losses = {}
        if aux_loss > 0:
            losses['moe_route_loss'] = aux_loss * self.moe_route_loss_weight
        if kl_loss > 0:
            losses['template_kl_loss'] = kl_loss * self.template_kl_loss_weight
        return losses

    def get_precompute_condition(self,
                                 text: Optional[str] = None,
                                 motion_length: Optional[torch.Tensor] = None,
                                 xf_out: Optional[torch.Tensor] = None,
                                 re_dict: Optional[Dict] = None,
                                 device: Optional[torch.device] = None,
                                 sample_idx: Optional[int] = None,
                                 clip_feat: Optional[torch.Tensor] = None,
                                 **kwargs) -> Dict[str, torch.Tensor]:
        """
        Precompute conditioning features for text or other modalities.

        Only the text encoding is computed here; the remaining parameters are
        accepted (and ignored) for interface compatibility with callers that
        pass the full conditioning set.

        Args:
            text (Optional[str]): Text input for conditioning.
            motion_length (Optional[torch.Tensor]): Length of the motion sequence.
            xf_out (Optional[torch.Tensor]): Precomputed text features; when
                given, text encoding is skipped.
            re_dict (Optional[Dict]): Additional features dictionary.
            device (Optional[torch.device]): Target device for the model.
            sample_idx (Optional[int]): Sample index for specific conditioning.
            clip_feat (Optional[torch.Tensor]): Precomputed CLIP features.

        Returns:
            Dict[str, torch.Tensor]: Precomputed conditioning features.
        """
        if xf_out is None:
            xf_out = self.encode_text(text, clip_feat, device)
        output = {'xf_out': xf_out}
        return output

    def post_process(self, motion: torch.Tensor) -> torch.Tensor:
        """
        Post-process motion data by unnormalizing if configured.

        Args:
            motion (torch.Tensor): Input motion data.

        Returns:
            torch.Tensor: Processed motion data.
        """
        if self.post_process_cfg is not None:
            if self.post_process_cfg.get("unnormalized_infer", False):
                mean = torch.from_numpy(
                    np.load(self.post_process_cfg['mean_path'])).type_as(motion)
                std = torch.from_numpy(
                    np.load(self.post_process_cfg['std_path'])).type_as(motion)
                motion = motion * std + mean
        return motion

    def forward_train(self,
                      h: torch.Tensor,
                      src_mask: Optional[torch.Tensor] = None,
                      emb: Optional[torch.Tensor] = None,
                      xf_out: Optional[torch.Tensor] = None,
                      motion_length: Optional[torch.Tensor] = None,
                      num_intervals: int = 1,
                      **kwargs) -> torch.Tensor:
        """
        Forward pass during training.

        Args:
            h (torch.Tensor): Input tensor of shape (B, T, D).
            src_mask (Optional[torch.Tensor]): Source mask tensor.
            emb (Optional[torch.Tensor]): Time embedding tensor.
            xf_out (Optional[torch.Tensor]): Precomputed text features.
            motion_length (Optional[torch.Tensor]): Lengths of motion sequences.
            num_intervals (int): Number of intervals for processing.

        Returns:
            torch.Tensor: Output tensor of shape (B, T, D).
        """
        B, T = h.shape[0], h.shape[1]
        # Random condition-dropout code in [0, 100); in fine mode one code is
        # drawn per body part (8 parts) instead of one per sample.
        if self.fine_mode:
            cond_type = torch.randint(
                0, 100, size=(B, 1, 1)).repeat(1, 8, 1).to(h.device)
        else:
            cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device)
        for module in self.temporal_decoder_blocks:
            h = module(x=h,
                       xf=xf_out,
                       emb=emb,
                       src_mask=src_mask,
                       cond_type=cond_type,
                       motion_length=motion_length,
                       num_intervals=num_intervals)

        output = self.out(h).view(B, T, -1).contiguous()
        return output

    def forward_test(self,
                     h: torch.Tensor,
                     src_mask: Optional[torch.Tensor] = None,
                     emb: Optional[torch.Tensor] = None,
                     xf_out: Optional[torch.Tensor] = None,
                     timesteps: Optional[torch.Tensor] = None,
                     motion_length: Optional[torch.Tensor] = None,
                     num_intervals: int = 1,
                     **kwargs) -> torch.Tensor:
        """
        Forward pass during inference with classifier-free guidance.

        The batch is duplicated: the first half runs text-conditioned, the
        second half unconditioned, and the two outputs are blended with the
        timestep-dependent coefficients from :meth:`scale_func`.

        Args:
            h (torch.Tensor): Input tensor of shape (B, T, D).
            src_mask (Optional[torch.Tensor]): Source mask tensor.
            emb (Optional[torch.Tensor]): Time embedding tensor.
            xf_out (Optional[torch.Tensor]): Precomputed text features.
            timesteps (Optional[torch.Tensor]): Diffusion timesteps.
            motion_length (Optional[torch.Tensor]): Lengths of motion sequences.
            num_intervals (int): Number of intervals for processing.

        Returns:
            torch.Tensor: Output tensor of shape (B, T, D).
        """
        B, T = h.shape[0], h.shape[1]
        # Condition codes: 1 selects the text branch, 0 the unconditional one.
        text_cond_type = torch.zeros(B, 1, 1).to(h.device) + 1
        none_cond_type = torch.zeros(B, 1, 1).to(h.device)

        # Run both branches in a single doubled batch.
        all_cond_type = torch.cat((text_cond_type, none_cond_type), dim=0)
        h = h.repeat(2, 1, 1)
        xf_out = xf_out.repeat(2, 1, 1)
        emb = emb.repeat(2, 1)
        src_mask = src_mask.repeat(2, 1, 1)
        motion_length = motion_length.repeat(2, 1)
        for module in self.temporal_decoder_blocks:
            h = module(x=h,
                       xf=xf_out,
                       emb=emb,
                       src_mask=src_mask,
                       cond_type=all_cond_type,
                       motion_length=motion_length,
                       num_intervals=num_intervals)
        out = self.out(h).view(2 * B, T, -1).contiguous()
        out_text = out[:B].contiguous()
        out_none = out[B:].contiguous()

        # Blend the two branches with guidance coefficients for this timestep.
        coef_cfg = self.scale_func(int(timesteps[0]))
        text_coef = coef_cfg['text_coef']
        none_coef = coef_cfg['none_coef']
        output = out_text * text_coef + out_none * none_coef
        return output
|
|
|
|
|
|