| |
| |
| |
| |
| |
| |
|
|
| from typing import Any, Dict |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from accelerate.logging import get_logger |
| from diffusers.utils import is_torch_version |
| from einops import rearrange |
|
|
| from core.models.encoders.sonata import transform |
| from core.models.encoders.sonata.model import ( |
| PointEmbed, |
| PointTransformerV3, |
| offset2bincount, |
| ) |
| from core.models.encoders.sonata.structure import Point |
| from core.models.transformer_block.transformer_dit import SD3PMMJointTransformerBlock |
| from core.models.vggt_transformer import VGGTAggregator |
|
|
| logger = get_logger(__name__) |
|
|
|
|
class PITransformerA4OBase(nn.Module):
    """Point-image multi-modality transformer.

    Couples a point backbone with an image backbone. After every point encoder
    stage past the first, and after every point decoder step, the point
    features and the (multi-view) image tokens are passed together through an
    ``SD3PMMJointTransformerBlock``.
    """

    def __init__(
        self,
        point_backbone: Dict[str, Any] = None,
        image_backbone: Dict[str, Any] = None,
        gradient_checkpointing=False,
        **kwargs,
    ):
        """Build both backbones and the joint encoder/decoder blocks.

        Args:
            point_backbone: Config of the point backbone. Must provide
                ``enc_channels``, ``enc_patch_size`` and
                ``decoder_concat_stop_feat``; an optional ``type`` selects the
                backbone class.
            image_backbone: Config of the image backbone. Must provide
                ``aa_order``; an optional ``type`` selects the backbone class.
            gradient_checkpointing: If True, joint blocks are run under
                activation checkpointing while the module is in training mode.
            **kwargs: Must contain ``img_dim``, ``dim`` and ``num_heads``.
                ``accumulate`` ("concat" or "mean", default "concat") selects
                how the view axis of the image tokens is merged.

        Raises:
            ValueError: If no decoder block can be sized (fewer than two
                encoder stages, or ``decoder_concat_stop_feat`` below 1).
        """
        super().__init__()
        self.point_processing = transform.default()

        self.build_point_transformer(point_backbone)
        self.build_image_transformer(image_backbone)

        depth = len(self.point_backbone.enc)
        assert depth - 1 == self.image_backbone.aa_block_num

        self.pi_transformer_enc = nn.ModuleList()
        self.pi_transformer_dec = nn.ModuleList()

        enc_channels = point_backbone["enc_channels"]
        self.enc_patch_size = point_backbone["enc_patch_size"]
        aa_order = image_backbone["aa_order"]

        # Settings shared by every joint block, encoder and decoder alike.
        shared_cfg = dict(
            img_dim=kwargs["img_dim"],
            dim=kwargs["dim"],
            num_heads=kwargs["num_heads"],
            eps=1e-5,
            qk_norm="rms_norm",
            mlp_ratio=4.0,
            require_mapping=len(aa_order) == 2,
        )

        # One joint block per encoder stage after the first one.
        for i in range(1, depth):
            self.pi_transformer_enc.append(
                SD3PMMJointTransformerBlock(
                    pv_dim=enc_channels[i],
                    context_pre_only=False,
                    **shared_cfg,
                )
            )

        decoder_concat_stop_feat = point_backbone["decoder_concat_stop_feat"]
        if depth < 2 or decoder_concat_stop_feat < 1:
            # Previously surfaced as an unbound-local error on `decoder_channel`.
            raise ValueError(
                "cannot size the decoder blocks: need at least two encoder "
                "stages and decoder_concat_stop_feat >= 1"
            )

        # Running sum over the reversed encoder widths.
        decoder_channels = np.cumsum(enc_channels[::-1])

        # NOTE(review): nesting reconstructed -- one block is appended per
        # decoder step; the width only stops growing once
        # `decoder_concat_stop_feat` is reached.
        for i in range(1, depth):
            if i - 1 < decoder_concat_stop_feat:
                decoder_channel = decoder_channels[i]
            self.pi_transformer_dec.append(
                SD3PMMJointTransformerBlock(
                    pv_dim=decoder_channel,
                    # Only the last decoder block is built with context_pre_only.
                    context_pre_only=(i == depth - 1),
                    **shared_cfg,
                )
            )

        self.outnorm = nn.LayerNorm(decoder_channel)
        self.gradient_checkpointing = gradient_checkpointing

        # `_merge_mv` / `_demerge_mv` read this attribute; without it `forward`
        # fails with AttributeError.
        self.accumulate = kwargs.get("accumulate", "concat")

    def build_point_transformer(self, point_backbone):
        """Instantiate ``self.point_backbone`` from its config.

        Args:
            point_backbone (dict): Constructor arguments plus an optional
                ``type`` key ("pv3", the default, or "pe"). The mapping is
                copied, so the caller's config is left untouched.

        Raises:
            NotImplementedError: If ``type`` is neither "pv3" nor "pe".
        """
        cfg = dict(point_backbone)
        pv_type = cfg.pop("type", "pv3")

        if pv_type == "pv3":
            self.point_backbone = PointTransformerV3(**cfg)
        elif pv_type == "pe":
            self.point_backbone = PointEmbed(**cfg)
        else:
            raise NotImplementedError(f"unsupported point backbone type: {pv_type}")

    def build_image_transformer(self, image_backbone):
        """Instantiate ``self.image_backbone`` from its config.

        Args:
            image_backbone (dict): Constructor arguments plus an optional
                ``type`` key (only "vggt", the default, is supported). The
                mapping is copied, so the caller's config is left untouched.

        Raises:
            NotImplementedError: If ``type`` is not "vggt".
        """
        cfg = dict(image_backbone)
        img_type = cfg.pop("type", "vggt")

        if img_type == "vggt":
            self.image_backbone = VGGTAggregator(**cfg)
        else:
            raise NotImplementedError(f"unsupported image backbone type: {img_type}")

    def assert_runtime_integrity(
        self, x: torch.Tensor, cond: torch.Tensor, mod: torch.Tensor
    ):
        """Assert that the input tensor ``x`` was provided."""
        assert x is not None, "Input tensor must be specified"

    @torch.no_grad()
    def get_padding_and_inverse(self, point, patch_size):
        """Compute patch-padding indices for a batched point structure.

        Every sample holding more than ``patch_size`` points is padded up to a
        multiple of ``patch_size``; the padded slots of its last patch are
        filled with indices taken from the preceding patch. Samples with at
        most ``patch_size`` points are left unpadded.

        Args:
            point: Structure exposing ``offset`` (1D cumulative point counts)
                and supporting item assignment.
            patch_size (int): Number of points per patch.

        Side effects:
            Stores the index tensors under ``point["pad"]`` and
            ``point["unpad"]``.

        Returns:
            tuple: ``(pad, unpad, cu_seqlens, patch_size)`` where ``pad`` maps
            padded slots to original indices, ``unpad`` maps original indices
            to padded slots, and ``cu_seqlens`` (int32) lists the patch
            boundaries followed by the padded total.
        """
        pad_key = "pad"
        unpad_key = "unpad"

        offset = point.offset
        bincount = offset2bincount(offset)
        # Ceil each per-sample count to a multiple of patch_size.
        bincount_pad = (
            torch.div(
                bincount + patch_size - 1,
                patch_size,
                rounding_mode="trunc",
            )
            * patch_size
        )
        # Only samples larger than one patch are actually padded.
        mask_pad = bincount > patch_size
        bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
        _offset = nn.functional.pad(offset, (1, 0))
        _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
        pad = torch.arange(_offset_pad[-1], device=offset.device)
        unpad = torch.arange(_offset[-1], device=offset.device)
        cu_seqlens = []
        for i in range(len(offset)):
            # Shift original indices into the padded coordinate system.
            unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
            if bincount[i] != bincount_pad[i]:
                # Fill the tail of the last patch from the previous patch.
                pad[
                    _offset_pad[i + 1]
                    - patch_size
                    + (bincount[i] % patch_size) : _offset_pad[i + 1]
                ] = pad[
                    _offset_pad[i + 1]
                    - 2 * patch_size
                    + (bincount[i] % patch_size) : _offset_pad[i + 1]
                    - patch_size
                ]
            # Shift padded slots back into the original coordinate system.
            pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
            cu_seqlens.append(
                torch.arange(
                    _offset_pad[i],
                    _offset_pad[i + 1],
                    step=patch_size,
                    dtype=torch.int32,
                    device=offset.device,
                )
            )
        point[pad_key] = pad
        point[unpad_key] = unpad
        cu_seqlens = nn.functional.pad(
            torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
        )

        return pad, unpad, cu_seqlens, patch_size

    def forward_multi_modality_layer(self, layer, x, cond, motion_emb=None):
        """Run one joint block, under activation checkpointing when enabled.

        Args:
            layer: Callable invoked as ``layer(x, cond, motion_emb)`` and
                returning a pair.
            x (Tensor): Point tokens.
            cond (Tensor): Image tokens.
            motion_emb (Tensor, optional): Conditioning embedding forwarded to
                the block.

        Returns:
            Tuple: The ``(x, cond)`` pair returned by ``layer``.
        """
        if self.training and self.gradient_checkpointing:
            ckpt_kwargs: Dict[str, Any] = (
                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            )
            x, cond = torch.utils.checkpoint.checkpoint(
                layer, x, cond, motion_emb, **ckpt_kwargs
            )
        else:
            x, cond = layer(x, cond, motion_emb)

        return x, cond

    def _merge_mv(self, mv_tokens):
        """Collapse the view axis of ``(b, s, p, c)`` tokens.

        Args:
            mv_tokens (Tensor): Multi-view tokens of shape ``(b, s, p, c)``.

        Returns:
            Tensor: ``(b, s*p, c)`` when ``self.accumulate == "concat"``,
            ``(b, p, c)`` (mean over views) when it is ``"mean"``.

        Raises:
            NotImplementedError: For any other ``self.accumulate`` value.
        """
        if self.accumulate == "concat":
            return rearrange(mv_tokens, "b s p c -> b (s p) c")
        if self.accumulate == "mean":
            return mv_tokens.mean(dim=1)
        raise NotImplementedError(f"unknown accumulate mode: {self.accumulate}")

    def _demerge_mv(self, mv_tokens, merge_mv_tokens, mv_p):
        """Restore the view axis after a joint block.

        Args:
            mv_tokens (Tensor): Tokens as they were before merging.
            merge_mv_tokens (Tensor): Merged tokens returned by the block.
            mv_p (int): Number of tokens per view.

        Returns:
            Tensor: For "concat", ``merge_mv_tokens`` reshaped to
            ``(b, s, p, c)``. For "mean", ``mv_tokens`` with the L2-normalised
            merged tokens added to every view (``mv_tokens`` is modified in
            place).

        Raises:
            NotImplementedError: For any other ``self.accumulate`` value.
        """
        if self.accumulate == "concat":
            mv_tokens = rearrange(merge_mv_tokens, "b (s p) c -> b s p c", p=mv_p)
        elif self.accumulate == "mean":
            norm_merge_mv_tokens = merge_mv_tokens / merge_mv_tokens.norm(
                dim=-1, keepdim=True
            )
            mv_tokens += norm_merge_mv_tokens.unsqueeze(1)
        else:
            raise NotImplementedError(f"unknown accumulate mode: {self.accumulate}")

        return mv_tokens

    def multi_modality_layer(
        self,
        mv_tokens,
        query_points,
        motion_emb,
        depth_i,
        mode="enc",
        use_auto_patch=True,
    ):
        """Fuse point features and image tokens with one joint block.

        Args:
            mv_tokens: Multi-view tokens ``[B, S, P, C]``.
            query_points: Point structure exposing ``offset``, ``feat``,
                ``mm_order``, ``mm_inverse`` and ``sparse_conv_feat``.
            motion_emb: Conditioning embedding; required.
            depth_i: 1-based index of the block to use.
            mode: "enc" selects ``pi_transformer_enc``; anything else selects
                ``pi_transformer_dec``.
            use_auto_patch: If True and the point count exceeds ``10 * P``,
                the points are split into padded patches of ``10 * P`` points
                before entering the block. The count is read with
                ``offset.item()``, so this path needs a single-element offset.

        Returns:
            tuple: ``(mv_tokens, query_points)``. ``mv_tokens`` is None when
            the block returns None for the image stream. ``query_points`` has
            ``feat`` and ``sparse_conv_feat`` replaced by the fused features.

        Raises:
            NotImplementedError: If ``motion_emb`` is None.
        """
        if motion_emb is None:
            # Fail before any padding work is done.
            raise NotImplementedError("motion_emb is required")

        _, _, mv_p, _ = mv_tokens.shape
        merge_mv_tokens = self._merge_mv(mv_tokens)

        auto_patch_size = 10 * mv_p
        # Evaluate the device-to-host sync (`.item()`) at most once.
        use_patches = use_auto_patch and query_points.offset.item() > auto_patch_size

        if use_patches:
            pad, unpad, _, patch_size = self.get_padding_and_inverse(
                query_points, auto_patch_size
            )
            order = query_points.mm_order[pad]
            inverse = unpad[query_points.mm_inverse]
            order_feat = query_points.feat[order]
            point_latents = order_feat.view(-1, patch_size, order_feat.shape[-1])
        else:
            order = query_points.mm_order
            inverse = query_points.mm_inverse
            order_feat = query_points.feat[order]
            point_latents = order_feat.unsqueeze(0)

        blocks = self.pi_transformer_enc if mode == "enc" else self.pi_transformer_dec
        transformer = blocks[depth_i - 1]

        point_latents, merge_mv_tokens = self.forward_multi_modality_layer(
            transformer, point_latents, merge_mv_tokens, motion_emb=motion_emb
        )

        # Flatten the patches and undo the serialisation order (and padding).
        points_latents = point_latents.view(-1, order_feat.shape[-1])[inverse]

        if merge_mv_tokens is not None:
            mv_tokens = self._demerge_mv(mv_tokens, merge_mv_tokens, mv_p)
        else:
            mv_tokens = None

        query_points.feat = points_latents
        query_points.sparse_conv_feat = query_points.sparse_conv_feat.replace_feature(
            points_latents
        )

        return mv_tokens, query_points

    def forward(
        self,
        query_points: dict,
        cond_imgs: torch.Tensor = None,
        motion_emb: torch.Tensor = None,
        **kwargs,
    ) -> dict:
        """Run the joint point/image encoder and decoder.

        Args:
            query_points (dict): Raw point data accepted by
                ``self.point_processing``.
            cond_imgs (torch.Tensor): Image features of shape ``[B, S, P, C]``.
            motion_emb (torch.Tensor): Conditioning embedding for the joint
                blocks.
            **kwargs: Must contain ``latent_size``.

        Returns:
            dict: ``query_feats`` (point features with a leading axis of size
            1) and ``img_feats`` (image tokens).
        """
        query_points = self.point_processing(query_points)
        query_points = Point(query_points)
        query_points.cuda()
        query_points = self.point_backbone.serialization_sparisy(query_points)

        # The first encoder stage has no joint block attached.
        query_points = self.point_backbone.enc[0](query_points)

        B, S, P, C = cond_imgs.shape
        img_tokens, pos = self.image_backbone.prepare_image_tokens(
            cond_imgs, latent_size=kwargs["latent_size"]
        )
        _, P, C = img_tokens.shape

        frame_idx = 0
        global_idx = 0
        depth = len(self.point_backbone.enc)
        img_enc_feat_list = []

        # Same criterion as `require_mapping` in __init__: both attention
        # outputs are concatenated only when the order has two entries.
        concat_frame_global = len(self.image_backbone.aa_order) == 2
        patch_start_idx = self.image_backbone.patch_start_idx

        for depth_i in range(1, depth):
            query_points = self.point_backbone.enc[depth_i](query_points)

            for attn_type in self.image_backbone.aa_order:
                if attn_type == "frame":
                    img_tokens, frame_idx, frame_intermediates = (
                        self.image_backbone._process_frame_attention(
                            img_tokens, B, S, P, C, frame_idx, pos=pos
                        )
                    )
                elif attn_type == "global":
                    img_tokens, global_idx, global_intermediates = (
                        self.image_backbone._process_global_attention(
                            img_tokens, B, S, P, C, global_idx, pos=pos
                        )
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            img_enc_feat_list.append(img_tokens)

            if concat_frame_global:
                concat_inter = torch.cat(
                    [frame_intermediates[-1], global_intermediates[-1]], dim=-1
                )
            else:
                concat_inter = frame_intermediates[-1]

            # Tokens before `patch_start_idx` are excluded from the joint block.
            mv_tokens = concat_inter[:, :, patch_start_idx:]
            mv_tokens, query_points = self.multi_modality_layer(
                mv_tokens, query_points, motion_emb, depth_i, mode="enc"
            )

        img_tokens = img_tokens.view(B, S, P, C)

        decoder_layer_i = 0
        img_enc_feat_list = img_enc_feat_list[::-1]

        # NOTE(review): loop extent reconstructed -- everything up to the
        # token re-assembly is assumed to run once per decoder step.
        while "pooling_parent" in query_points.keys():
            dec_img_tokens = torch.cat(
                [img_tokens, img_enc_feat_list[decoder_layer_i].view(B, S, P, C)],
                dim=3,
            )
            query_points, decoder_layer_i = self.point_backbone.decoder(
                query_points, decoder_layer_i
            )

            mv_tokens = dec_img_tokens[:, :, patch_start_idx:]
            mv_tokens, query_points = self.multi_modality_layer(
                mv_tokens, query_points, motion_emb, decoder_layer_i, mode="dec"
            )

            img_tokens = img_tokens.view(B, S, P, C)
            img_tokens = torch.cat(
                [img_tokens[:, :, :patch_start_idx], mv_tokens], dim=2
            )

        query_feats = query_points.feat[query_points.inverse]

        return dict(query_feats=query_feats.unsqueeze(0), img_feats=img_tokens)
|
|
|
|
class PITransformerA4OEncoderOnly(PITransformerA4OBase):
    """Point-image transformer processed one batch sample at a time.

    Builds the same backbones and joint encoder blocks as the base class, but
    delegates decoder construction to ``build_decoder``, optionally freezes
    both backbones, and runs ``forward`` sample by sample so that each sample
    can use its own subset of reference views.

    Attributes:
        point_processing: Preprocessing applied to the raw point dict.
        point_backbone: Point feature backbone.
        image_backbone: Image feature backbone.
        pi_transformer_enc: Joint blocks used after each encoder stage.
        pi_transformer_dec: Joint blocks used after each decoder step.
        enc_patch_size: ``enc_patch_size`` entry of the point backbone config.
        accumulate: How views are merged ("concat" or "mean").
    """

    def __init__(
        self,
        point_backbone: Dict[str, Any] = None,
        image_backbone: Dict[str, Any] = None,
        gradient_checkpointing=False,
        **kwargs,
    ):
        """Build backbones, joint blocks and apply the freezing policy.

        Args:
            point_backbone: Config of the point backbone (``enc_channels``,
                ``enc_patch_size`` and whatever ``build_decoder`` reads).
            image_backbone: Config of the image backbone (``aa_order``).
            gradient_checkpointing: If True, joint blocks are checkpointed
                while training.
            **kwargs: Must contain ``img_dim``, ``dim`` and ``num_heads``.
                Optional: ``freeze_point`` / ``freeze_image`` (both default
                True) and ``accumulate`` (default "concat").
        """
        # Start the MRO lookup after PITransformerA4OBase: the base
        # constructor is skipped on purpose because everything it builds is
        # rebuilt below.
        super(PITransformerA4OBase, self).__init__()

        self.point_processing = transform.default()

        self.build_point_transformer(point_backbone)
        self.build_image_transformer(image_backbone)

        depth = len(self.point_backbone.enc)
        assert depth - 1 == self.image_backbone.aa_block_num

        self.pi_transformer_enc = nn.ModuleList()
        self.pi_transformer_dec = nn.ModuleList()

        enc_channels = point_backbone["enc_channels"]
        self.enc_patch_size = point_backbone["enc_patch_size"]
        aa_order = image_backbone["aa_order"]

        # One joint block per encoder stage after the first one.
        for i in range(1, depth):
            self.pi_transformer_enc.append(
                SD3PMMJointTransformerBlock(
                    pv_dim=enc_channels[i],
                    img_dim=kwargs["img_dim"],
                    dim=kwargs["dim"],
                    num_heads=kwargs["num_heads"],
                    eps=1e-5,
                    context_pre_only=False,
                    qk_norm="rms_norm",
                    mlp_ratio=4.0,
                    require_mapping=len(aa_order) == 2,
                )
            )

        self.build_decoder(point_backbone, image_backbone, **kwargs)
        self.gradient_checkpointing = gradient_checkpointing

        freeze_point = kwargs.get("freeze_point", True)
        freeze_image = kwargs.get("freeze_image", True)
        self.only_global_attention(freeze_point, freeze_image)

        self.accumulate = kwargs.get("accumulate", "concat")

    def only_global_attention(self, freeze_point=True, freeze_image=True):
        """Disable gradients on the selected backbones.

        Args:
            freeze_point (bool, optional): Freeze every parameter of the point
                backbone. Defaults to True.
            freeze_image (bool, optional): Freeze every parameter of the image
                backbone. Defaults to True.
        """
        if freeze_image:
            logger.info("freezing image backbone")
            for param in self.image_backbone.parameters():
                param.requires_grad = False

        if freeze_point:
            logger.info("freezing point backbone")
            for param in self.point_backbone.parameters():
                param.requires_grad = False

    def build_decoder(self, point_backbone, image_backbone, **kwargs):
        """Populate ``pi_transformer_dec`` and create ``outnorm``.

        The width of decoder block ``i`` is the running sum of the reversed
        ``enc_channels``; it stops growing after ``decoder_concat_stop_feat``
        steps.

        Args:
            point_backbone (dict): Must provide ``enc_channels`` and
                ``decoder_concat_stop_feat``.
            image_backbone (dict): Must provide ``aa_order``.
            **kwargs: Must contain ``img_dim``, ``dim`` and ``num_heads``.

        Raises:
            ValueError: If no decoder block can be sized (fewer than two
                encoder stages, or ``decoder_concat_stop_feat`` below 1).
        """
        enc_channels = point_backbone["enc_channels"]
        depth = len(self.point_backbone.enc)
        aa_order = image_backbone["aa_order"]

        decoder_concat_stop_feat = point_backbone["decoder_concat_stop_feat"]
        if depth < 2 or decoder_concat_stop_feat < 1:
            # Previously surfaced as an unbound-local error on `decoder_channel`.
            raise ValueError(
                "cannot size the decoder blocks: need at least two encoder "
                "stages and decoder_concat_stop_feat >= 1"
            )

        decoder_channels = np.cumsum(enc_channels[::-1])

        # NOTE(review): nesting reconstructed -- one block is appended per
        # decoder step, the `if` only controls the width.
        for i in range(1, depth):
            if i - 1 < decoder_concat_stop_feat:
                decoder_channel = decoder_channels[i]
            self.pi_transformer_dec.append(
                SD3PMMJointTransformerBlock(
                    pv_dim=decoder_channel,
                    img_dim=kwargs["img_dim"],
                    dim=kwargs["dim"],
                    num_heads=kwargs["num_heads"],
                    eps=1e-5,
                    # Only the last decoder block is built with context_pre_only.
                    context_pre_only=(i == depth - 1),
                    qk_norm="rms_norm",
                    mlp_ratio=4.0,
                    require_mapping=len(aa_order) == 2,
                )
            )

        self.outnorm = nn.LayerNorm(decoder_channel)

    def _forward_per_batch(self, query_points, cond_imgs, motion_emb=None, **kwargs):
        """Run encoder and decoder for a single sample.

        Args:
            query_points (dict): Raw point data of one sample.
            cond_imgs (torch.Tensor): Image features ``[B, S, P, C]``.
            motion_emb (torch.Tensor, optional): Conditioning embedding passed
                to every joint block.
            **kwargs: Must contain ``latent_size``.

        Returns:
            tuple: ``(query_feats, img_tokens)``: the layer-normalised point
            features with a leading axis of size 1, and the image tokens
            viewed as ``[B, S, P, C]``.
        """
        query_points = self.point_processing(query_points)
        query_points = Point(query_points)
        query_points.cuda()
        query_points = self.point_backbone.serialization_sparisy(query_points)

        # The first encoder stage has no joint block attached.
        query_points = self.point_backbone.enc[0](query_points)

        B, S, P, C = cond_imgs.shape

        img_tokens, pos = self.image_backbone.prepare_image_tokens(
            cond_imgs, latent_size=kwargs["latent_size"]
        )
        _, P, C = img_tokens.shape

        frame_idx = 0
        global_idx = 0
        depth = len(self.point_backbone.enc)
        concat_frame_global = len(self.image_backbone.aa_order) == 2

        for depth_i in range(1, depth):
            query_points = self.point_backbone.enc[depth_i](query_points)

            for attn_type in self.image_backbone.aa_order:
                if attn_type == "frame":
                    img_tokens, frame_idx, frame_intermediates = (
                        self.image_backbone._process_frame_attention(
                            img_tokens, B, S, P, C, frame_idx, pos=pos
                        )
                    )
                elif attn_type == "global":
                    img_tokens, global_idx, global_intermediates = (
                        self.image_backbone._process_global_attention(
                            img_tokens, B, S, P, C, global_idx, pos=pos
                        )
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            if concat_frame_global:
                mv_tokens = torch.cat(
                    [frame_intermediates[-1], global_intermediates[-1]], dim=-1
                )
            else:
                mv_tokens = frame_intermediates[-1]

            mv_tokens, query_points = self.multi_modality_layer(
                mv_tokens, query_points, motion_emb, depth_i, mode="enc"
            )
            # NOTE(review): nesting reconstructed -- the fused tokens are
            # assumed to feed the next image attention stage.
            img_tokens = mv_tokens.view(B, S, P, C)

        decoder_layer_i = 0

        while "pooling_parent" in query_points.keys():
            query_points, decoder_layer_i = self.point_backbone.decoder(
                query_points, decoder_layer_i
            )
            mv_tokens, query_points = self.multi_modality_layer(
                img_tokens, query_points, motion_emb, decoder_layer_i, mode="dec"
            )

            # The joint layer yields None for the image stream when the block
            # returns none; keep the previous tokens in that case.
            if mv_tokens is not None:
                img_tokens = mv_tokens.view(B, S, P, C)

        query_points.feat = self.outnorm(query_points.feat[query_points.inverse])

        return query_points.feat.unsqueeze(0), img_tokens

    def forward(
        self,
        query_points: dict,
        cond_imgs: torch.Tensor = None,
        motion_emb: torch.Tensor = None,
        **kwargs,
    ) -> dict:
        """Run the model sample by sample and collect the outputs.

        For every sample, only the views selected by
        ``kwargs["ref_imgs_bool"][i]`` are kept, and the motion embeddings of
        those views are averaged before being used as conditioning.

        Args:
            query_points (dict): Batched point data; every entry except
                ``mesh_meta`` is sliced along its first axis, ``mesh_meta`` is
                passed through unchanged.
            cond_imgs (torch.Tensor): Image features ``[B, S, P, C]``.
            motion_emb (torch.Tensor): Per-view conditioning embeddings,
                indexed as ``[B, S, ...]``.
            **kwargs: Must contain ``ref_imgs_bool`` and ``latent_size``.

        Returns:
            dict: ``query_feats`` (per-sample point features concatenated on
            axis 0), ``img_feats`` (list of per-sample image tokens) and
            ``motion_embs`` (list with one embedding per sample: the third
            value returned by ``_forward_per_batch`` when it returns one,
            otherwise the averaged conditioning embedding that was used).
        """
        batch_size = cond_imgs.shape[0]
        mesh_meta = query_points["mesh_meta"]
        ref_imgs_bool = kwargs["ref_imgs_bool"]

        query_feats_list = []
        img_feats_list = []
        motion_emb_list = []

        for batch in range(batch_size):
            # Slice out one sample; mesh_meta is shared and passed through.
            sample = {
                key: value[batch : batch + 1]
                for key, value in query_points.items()
                if key != "mesh_meta"
            }
            sample["mesh_meta"] = mesh_meta

            view_mask = ref_imgs_bool[batch]
            valid_cond_imgs = cond_imgs[batch : batch + 1][:, view_mask]
            valid_cond_motion_emb = motion_emb[batch : batch + 1][:, view_mask].mean(
                dim=1
            )

            # Do not rebind `motion_emb` here: the input tensor is sliced again
            # on the next iteration. `_forward_per_batch` returns a pair in
            # this class, so tolerate both 2- and 3-value returns.
            outputs = self._forward_per_batch(
                sample, valid_cond_imgs, valid_cond_motion_emb, **kwargs
            )
            query_feats, img_feats = outputs[0], outputs[1]
            sample_motion_emb = outputs[2] if len(outputs) > 2 else valid_cond_motion_emb

            query_feats_list.append(query_feats)
            img_feats_list.append(img_feats)
            motion_emb_list.append(sample_motion_emb)

        return dict(
            query_feats=torch.cat(query_feats_list, dim=0),
            img_feats=img_feats_list,
            motion_embs=motion_emb_list,
        )
|
|
|
|
class PITransformerA4OE2EEncoderDecoder(PITransformerA4OEncoderOnly):
    """End-to-end encoder-decoder variant of the point-image transformer.

    Differs from ``PITransformerA4OEncoderOnly`` in two ways: decoder blocks
    are sized from the ``dec_channels`` config entry, and the point decoder is
    driven stage by stage through ``point_backbone.dec`` with automatic
    patching disabled in every joint layer.

    Attributes:
        pi_transformer_dec (nn.ModuleList): Joint blocks, one per decoder stage.
        outnorm (nn.LayerNorm): Normalisation applied to the final point
            features.
    """

    def build_decoder(self, point_backbone, image_backbone, **kwargs):
        """Populate ``pi_transformer_dec`` and create ``outnorm``.

        Args:
            point_backbone (dict): Must provide ``dec_channels``; the list is
                used in reversed order, one width per decoder stage.
            image_backbone (dict): Must provide ``aa_order``.
            **kwargs: Must contain ``img_dim``, ``dim`` and ``num_heads``.
        """
        dec_channels = list(point_backbone["dec_channels"])[::-1]
        depth = len(self.point_backbone.dec)
        aa_order = image_backbone["aa_order"]

        for i in range(depth):
            self.pi_transformer_dec.append(
                SD3PMMJointTransformerBlock(
                    pv_dim=dec_channels[i],
                    img_dim=kwargs["img_dim"],
                    dim=kwargs["dim"],
                    num_heads=kwargs["num_heads"],
                    eps=1e-5,
                    # Only the last decoder block is built with context_pre_only.
                    context_pre_only=(i == depth - 1),
                    qk_norm="rms_norm",
                    mlp_ratio=4.0,
                    require_mapping=len(aa_order) == 2,
                )
            )

        self.outnorm = nn.LayerNorm(dec_channels[-1])

    def _forward_per_batch(self, query_points, cond_imgs, motion_emb=None, **kwargs):
        """Run encoder and decoder for a single sample.

        Args:
            query_points (dict): Raw point data of one sample.
            cond_imgs (torch.Tensor): Image features ``[B, S, P, C]``.
            motion_emb (torch.Tensor, optional): Conditioning embedding passed
                to every joint block.
            **kwargs: Must contain ``latent_size``.

        Returns:
            tuple: ``(query_feats, img_tokens)``: the layer-normalised point
            features with a leading axis of size 1, and the image tokens
            viewed as ``[B, S, P, C]``.
        """
        query_points = self.point_processing(query_points)
        query_points = Point(query_points)
        query_points.cuda()
        query_points = self.point_backbone.serialization_sparisy(query_points)

        # The first encoder stage has no joint block attached.
        query_points = self.point_backbone.enc[0](query_points)

        B, S, P, C = cond_imgs.shape

        img_tokens, pos = self.image_backbone.prepare_image_tokens(
            cond_imgs, latent_size=kwargs["latent_size"]
        )
        _, P, C = img_tokens.shape

        frame_idx = 0
        global_idx = 0

        # Same criterion as `require_mapping` in build_decoder and as the
        # parent class: concatenate only when the order has two entries.
        # (`aa_block_num` counts stages, not attention kinds.)
        concat_frame_global = len(self.image_backbone.aa_order) == 2

        enc_depth = len(self.point_backbone.enc)
        for depth_i in range(1, enc_depth):
            query_points = self.point_backbone.enc[depth_i](query_points)

            for attn_type in self.image_backbone.aa_order:
                if attn_type == "frame":
                    img_tokens, frame_idx, frame_intermediates = (
                        self.image_backbone._process_frame_attention(
                            img_tokens, B, S, P, C, frame_idx, pos=pos
                        )
                    )
                elif attn_type == "global":
                    img_tokens, global_idx, global_intermediates = (
                        self.image_backbone._process_global_attention(
                            img_tokens, B, S, P, C, global_idx, pos=pos
                        )
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            if concat_frame_global:
                mv_tokens = torch.cat(
                    [frame_intermediates[-1], global_intermediates[-1]], dim=-1
                )
            else:
                mv_tokens = frame_intermediates[-1]

            mv_tokens, query_points = self.multi_modality_layer(
                mv_tokens,
                query_points,
                motion_emb,
                depth_i,
                mode="enc",
                use_auto_patch=False,
            )
            # NOTE(review): nesting reconstructed -- the fused tokens are
            # assumed to feed the next image attention stage.
            img_tokens = mv_tokens.view(B, S, P, C)

        dec_depth = len(self.point_backbone.dec)
        for depth_i in range(dec_depth):
            query_points = self.point_backbone.dec[depth_i](query_points)

            # Joint blocks are addressed with a 1-based index.
            mv_tokens, query_points = self.multi_modality_layer(
                img_tokens,
                query_points,
                motion_emb,
                depth_i + 1,
                mode="dec",
                use_auto_patch=False,
            )

            # The joint layer yields None for the image stream when the block
            # returns none; keep the previous tokens in that case.
            if mv_tokens is not None:
                img_tokens = mv_tokens.view(B, S, P, C)

        query_points.feat = self.outnorm(query_points.feat[query_points.inverse])

        return query_points.feat.unsqueeze(0), img_tokens
|
|
|
|
class PITransformerA4OEncoder(PITransformerA4OEncoderOnly):
    """
    Point-image transformer with an additional "middle" fusion stage.

    On top of the inherited per-stage encoder fusion, this class owns a stack of
    joint point/image transformer blocks (``pi_transformer_middle``) that is
    applied to the deepest point features and the image tokens once the encoder
    loop has finished and before the point decoder is unrolled.

    Attributes:
        pi_transformer_middle (nn.ModuleList): Joint transformer blocks applied
            between the encoder and decoder stages. All other attributes are
            inherited from the parent classes.
    """

    def __init__(
        self,
        point_backbone: Dict[str, Any] = None,
        image_backbone: Dict[str, Any] = None,
        gradient_checkpointing=False,
        **kwargs,
    ):
        """
        Build the backbones via the parent initializer, then the middle blocks.

        Args:
            point_backbone (Dict[str, Any], optional): Point backbone configuration.
                Must contain ``"enc_channels"`` (read by ``build_middle``).
            image_backbone (Dict[str, Any], optional): Image backbone configuration.
                Must contain ``"aa_order"`` (read by ``build_middle``).
            gradient_checkpointing (bool, optional): Forwarded unchanged to the
                parent initializer. Default: False.
            **kwargs: Forwarded to the parent initializer and to ``build_middle``,
                which reads ``'img_dim'``, ``'dim'``, ``'num_heads'`` and
                ``'middle_layer'``.
        """
        # NOTE(review): the two-argument super() starts method resolution *after*
        # PITransformerA4OEncoderOnly, so that class's own __init__ (if it defines
        # one) is bypassed and its parent's __init__ runs instead. Kept as-is;
        # confirm this is intentional.
        super(PITransformerA4OEncoderOnly, self).__init__(
            point_backbone, image_backbone, gradient_checkpointing, **kwargs
        )

        self.pi_transformer_middle = nn.ModuleList()
        self.build_middle(point_backbone, image_backbone, **kwargs)

    def build_middle(self, point_backbone, image_backbone, **kwargs):
        """
        Populate ``pi_transformer_middle`` with joint point/image blocks.

        Args:
            point_backbone (dict): Point backbone configuration; the last entry of
                ``point_backbone["enc_channels"]`` is used as the point channel width.
            image_backbone (dict): Image backbone configuration; the blocks are built
                with ``require_mapping=True`` exactly when ``aa_order`` has two entries.
            **kwargs: Block configuration:
                - 'img_dim' (int): Image feature channel dimension.
                - 'dim' (int): Hidden dimension of each block.
                - 'num_heads' (int): Number of attention heads per block.
                - 'middle_layer' (int): Number of blocks to create.

        Raises:
            KeyError: If a required configuration key is missing.
        """
        # Point tokens entering the middle stage come from the deepest encoder
        # stage, hence the last encoder channel width.
        channel = point_backbone["enc_channels"][-1]
        require_mapping = len(image_backbone["aa_order"]) == 2

        for _ in range(kwargs["middle_layer"]):
            self.pi_transformer_middle.append(
                SD3PMMJointTransformerBlock(
                    pv_dim=channel,
                    img_dim=kwargs["img_dim"],
                    dim=kwargs["dim"],
                    num_heads=kwargs["num_heads"],
                    eps=1e-5,
                    context_pre_only=False,
                    qk_norm="rms_norm",
                    mlp_ratio=4.0,
                    require_mapping=require_mapping,
                )
            )

    def forward(
        self,
        query_points: dict,
        cond_imgs: torch.Tensor = None,
        motion_emb: torch.Tensor = None,
        **kwargs,
    ) -> dict:
        """
        Run encoder fusion, middle fusion and the point decoder.

        Args:
            query_points (dict): Raw point data; it is passed through
                ``self.point_processing`` and wrapped in a ``Point`` structure.
            cond_imgs (torch.Tensor): 4-D image features, unpacked as [B, S, P, C].
                Required despite the ``None`` default.
            motion_emb (torch.Tensor, optional): Conditioning embedding handed to every
                fusion layer. The middle stage raises if it is ``None``.
            **kwargs: Must contain ``'latent_size'``, which is forwarded to
                ``image_backbone.prepare_image_tokens``.

        Returns:
            dict: ``query_feats`` holds the normalized point features gathered with
            ``query_points.inverse`` and given a leading dimension of size 1;
            ``img_feats`` holds the final image tokens.

        Raises:
            ValueError: If ``cond_imgs`` is ``None`` or an entry of ``aa_order`` is
                neither ``"frame"`` nor ``"global"``.
            KeyError: If ``'latent_size'`` is not supplied.
        """
        if cond_imgs is None:
            # Fail early with a clear message instead of an AttributeError on .shape.
            raise ValueError("cond_imgs is required: expected a [B, S, P, C] tensor.")

        # Point preprocessing and serialization.
        query_points = self.point_processing(query_points)
        query_points = Point(query_points)
        # NOTE(review): hard-codes the CUDA device for the point structure.
        query_points.cuda()
        query_points = self.point_backbone.serialization_sparisy(query_points)

        # The first encoder stage runs on points only, without image fusion.
        query_points = self.point_backbone.enc[0](query_points)

        # Image token preparation; P and C are re-read from the prepared tokens
        # because prepare_image_tokens may change them.
        B, S, P, C = cond_imgs.shape
        img_tokens, pos = self.image_backbone.prepare_image_tokens(
            cond_imgs, latent_size=kwargs["latent_size"]
        )
        _, P, C = img_tokens.shape

        frame_idx = 0
        global_idx = 0
        depth = len(self.point_backbone.enc)

        # Encoder: one point stage, one round of image attention, then joint fusion.
        for depth_i in range(1, depth):
            query_points = self.point_backbone.enc[depth_i](query_points)

            for attn_type in self.image_backbone.aa_order:
                if attn_type == "frame":
                    img_tokens, frame_idx, frame_intermediates = (
                        self.image_backbone._process_frame_attention(
                            img_tokens, B, S, P, C, frame_idx, pos=pos
                        )
                    )
                elif attn_type == "global":
                    img_tokens, global_idx, global_intermediates = (
                        self.image_backbone._process_global_attention(
                            img_tokens, B, S, P, C, global_idx, pos=pos
                        )
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            # With two attention types the frame and global outputs are concatenated
            # along the channel axis; otherwise only the frame output is used.
            if len(self.image_backbone.aa_order) == 2:
                mv_tokens = torch.cat(
                    [frame_intermediates[-1], global_intermediates[-1]], dim=-1
                )
            else:
                mv_tokens = frame_intermediates[-1]

            mv_tokens, query_points = self.multi_modality_layer(
                mv_tokens, query_points, motion_emb, depth_i, mode="enc"
            )
            img_tokens = mv_tokens.view(B, S, P, C)

        # Middle: extra joint fusion on the deepest point features.
        for middle_depth_i in range(len(self.pi_transformer_middle)):
            mv_tokens, query_points = self.middle_multi_modality_layer(
                img_tokens,
                query_points,
                motion_emb,
                middle_depth_i,
            )
            # middle_multi_modality_layer allows mv_tokens to be None; keep the
            # previous image tokens in that case, as the decoder loop below does.
            if mv_tokens is not None:
                img_tokens = mv_tokens.view(B, S, P, C)

        # Decoder: keep unpooling while the point structure still has a parent.
        decoder_layer_i = 0
        while "pooling_parent" in query_points.keys():
            query_points, decoder_layer_i = self.point_backbone.decoder(
                query_points, decoder_layer_i
            )
            mv_tokens, query_points = self.multi_modality_layer(
                img_tokens, query_points, motion_emb, decoder_layer_i, mode="dec"
            )
            if mv_tokens is not None:
                img_tokens = mv_tokens.view(B, S, P, C)

        # Gather features with the stored inverse index and normalize them.
        query_points.feat = self.outnorm(query_points.feat[query_points.inverse])

        return dict(query_feats=query_points.feat.unsqueeze(0), img_feats=img_tokens)

    def middle_multi_modality_layer(self, mv_tokens, query_points, motion_emb, depth_i):
        """
        Apply one middle-stage joint transformer block to point and image tokens.

        Args:
            mv_tokens: Multi-view image tokens, 4-D [B, S, P, C].
            query_points: Point structure providing ``feat``, ``mm_order``,
                ``mm_inverse`` and ``sparse_conv_feat``.
            motion_emb: Conditioning embedding; must not be ``None``.
            depth_i: Index into ``self.pi_transformer_middle``.

        Returns:
            tuple: ``(mv_tokens, query_points)`` where ``mv_tokens`` is reshaped back
            to [B, S, P, C] (or ``None`` if the block returned ``None``) and
            ``query_points`` carries the updated features.

        Raises:
            NotImplementedError: If ``motion_emb`` is ``None``.
        """
        if motion_emb is None:
            # Checked first so no work is done for an unsupported call.
            raise NotImplementedError(
                "middle_multi_modality_layer requires motion_emb; "
                "unconditioned fusion is not implemented."
            )

        # Flatten the view and patch axes into a single token axis.
        _, _, mv_p, _ = mv_tokens.shape
        mv_tokens = rearrange(mv_tokens, "b s p c -> b (s p) c")

        # Reorder point features with mm_order and add a leading batch axis of 1.
        order = query_points.mm_order
        inverse = query_points.mm_inverse
        order_feat = query_points.feat[order]
        point_latents = order_feat.unsqueeze(0)

        transformer = self.pi_transformer_middle[depth_i]
        point_latents, mv_tokens = self.forward_multi_modality_layer(
            transformer,
            point_latents,
            mv_tokens,
            motion_emb=motion_emb,
        )

        # Drop the batch axis and index with mm_inverse (presumably the inverse
        # permutation of mm_order, restoring the original point order).
        point_latents = point_latents.view(-1, order_feat.shape[-1])[inverse]

        if mv_tokens is not None:
            mv_tokens = rearrange(mv_tokens, "b (s p) c -> b s p c", p=mv_p)

        # Keep the dense features and the sparse-conv tensor in sync.
        query_points.feat = point_latents
        query_points.sparse_conv_feat = query_points.sparse_conv_feat.replace_feature(
            point_latents
        )

        return mv_tokens, query_points
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build an encoder-only model, run one forward pass on random
    # image latents and print the elapsed time and the point-feature shape.
    import time

    from core.models.rendering.skinnings.smplx_voxel_skinning import get_query_points

    # Four-stage point backbone. An earlier three-stage configuration used to be
    # assigned here and was immediately overwritten, so only this one is kept.
    point_backbone = dict(
        type="pv3",
        in_channels=6,
        order=("z", "z-trans", "hilbert", "hilbert-trans"),
        stride=(1, 2, 2),
        enc_depths=(4, 4, 4, 4),
        enc_channels=(64, 128, 256, 512),
        enc_num_head=(4, 8, 16, 16),
        enc_patch_size=(4096, 4096, 2048, 1024),
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        shuffle_orders=False,
        pre_norm=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        traceable=True,
        mask_token=False,
        enc_mode=True,
        freeze_encoder=False,
        decoder_concat_stop_feat=2,
    )

    image_backbone = dict(
        type="vggt",
        depth=3,
        aa_order=["frame"],
    )

    pi_model = PITransformerA4OEncoderOnly(
        point_backbone=point_backbone,
        image_backbone=image_backbone,
        img_dim=1024,
        dim=1024,
        num_heads=16,
        gradient_checkpointing=True,
        freeze_point=True,
        accumulate="mean",
    )
    pi_model.cuda()

    points = get_query_points()

    # Inputs are created before the timer starts so only the forward pass is timed.
    img_latents = torch.randn(1, 1, 73 * 47, 1024).float().cuda()
    motion_emb = torch.randn(1, 1, 1024).cuda()
    ref_imgs_bool = torch.ones(1, 1).cuda() > 0

    # CUDA kernels launch asynchronously; synchronize on both sides of the timed
    # region so the measurement covers execution rather than just kernel launch.
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        out = pi_model(
            points,
            img_latents,
            motion_emb=motion_emb,
            latent_size=(73, 47),
            ref_imgs_bool=ref_imgs_bool,
        )
    torch.cuda.synchronize()

    print(time.perf_counter() - start)
    # Forward variants in this file return either a dict with "query_feats" or a
    # (query_feats, img_tokens) tuple; accept both.
    query_feats = out["query_feats"] if isinstance(out, dict) else out[0]
    print(query_feats.shape)
|
|