| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder |
|
|
| from .configuration_embodiedmae import EmbodiedMAEConfig |
| from .modular_embodiedmae import ( |
| EmbodiedMAEDecoder, |
| EmbodiedMAEDepthEmbeddings, |
| EmbodiedMAEPointCloudEmbeddings, |
| EmbodiedMAERGBEmbeddings, |
| EncoderModelOutput, |
| concat_sequence_with_dummy, |
| prepare_shuffle_idx, |
| ) |
|
|
|
|
class EmbodiedMAEModel(PreTrainedModel):
    """Encoder over RGB, depth and point-cloud inputs.

    Each modality is embedded by its own embedding module, the per-modality
    sequences are concatenated, a subset of tokens is selected through a
    shuffle index, and the selected tokens are encoded by a ``Dinov2Encoder``
    followed by a final ``LayerNorm``.
    """

    config_class = EmbodiedMAEConfig

    def __init__(self, config: EmbodiedMAEConfig):
        """Build the embedding modules, the encoder and the final layer norm.

        Args:
            config: Model configuration. This constructor reads
                ``dirichlet_alpha``, ``hidden_size``, ``layer_norm_eps``,
                ``image_size``, ``patch_size``, ``num_pc_centers`` and
                ``unmask_sz`` from it.
        """
        super().__init__(config)
        self.config = config

        # Symmetric 3-component Dirichlet (every concentration equals
        # ``config.dirichlet_alpha``); it is handed to ``prepare_shuffle_idx``.
        # NOTE(review): this is a plain attribute, not a buffer, so it does not
        # follow ``.to(device)`` -- presumably ``prepare_shuffle_idx`` moves its
        # samples to the requested device; confirm.
        self.dirichlet = torch.distributions.Dirichlet(torch.full((3,), config.dirichlet_alpha))

        self.rgb_embeddings = EmbodiedMAERGBEmbeddings(config)
        self.depth_embeddings = EmbodiedMAEDepthEmbeddings(config)
        self.pc_embeddings = EmbodiedMAEPointCloudEmbeddings(config)

        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Size triple listed in the same order as the (rgb, depth, pc)
        # embeddings passed to ``concat_sequence_with_dummy``.
        num_patches = (config.image_size // config.patch_size) ** 2
        self.embedding_sz = (
            num_patches,
            num_patches,
            config.num_pc_centers,
        )
        self.unmask_sz = config.unmask_sz

    def get_input_embeddings(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        forward_pc: bool = True,
        shuffle_idx: Optional[torch.Tensor] = None,
    ):
        """Embed the given modalities and gather the tokens fed to the encoder.

        NOTE(review): this shares its name with
        ``PreTrainedModel.get_input_embeddings`` but takes arguments and returns
        an ``EncoderModelOutput`` -- confirm that no inherited utility relies on
        the base-class contract.

        Args:
            rgb: RGB input, or ``None`` when the modality is absent.
            depth: Depth input, or ``None`` when the modality is absent.
            pc: Point-cloud input, or ``None`` when the modality is absent.
            add_mask: Forwarded to ``prepare_shuffle_idx`` and stored on the output.
            unmask_sz: Forwarded to ``prepare_shuffle_idx``; ``self.unmask_sz``
                is used when ``None``.
            forward_pc: When ``False`` the point-cloud embedding is dropped from
                the encoder input, while ``pc_centers`` / ``pc_knn`` are still
                computed and returned.
            shuffle_idx: Optional index tensor forwarded to ``prepare_shuffle_idx``.

        Returns:
            An ``EncoderModelOutput`` holding the gathered embeddings, the
            point-cloud centers / kNN, the shuffle and restore indices,
            ``add_mask`` and the ``unmask_sz`` returned by ``prepare_shuffle_idx``.

        Raises:
            ValueError: If ``rgb``, ``depth`` and ``pc`` are all ``None``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if rgb is None and depth is None and pc is None:
            raise ValueError("At least one of `rgb`, `depth` or `pc` must be provided.")

        # The embedding modules are called even for absent modalities.
        # NOTE(review): this assumes each module tolerates a ``None`` input -- confirm.
        rgb_emb = self.rgb_embeddings(rgb)
        depth_emb = self.depth_embeddings(depth)
        pc_emb, pc_centers, pc_knn = self.pc_embeddings(pc)
        if not forward_pc:
            # Treat the point cloud as absent from here on; centers / kNN are kept.
            pc = None
            pc_emb = None

        all_emb = concat_sequence_with_dummy([rgb_emb, depth_emb, pc_emb], self.embedding_sz)

        shuffle_idx, restore_idx, unmask_sz = prepare_shuffle_idx(
            has_rgb=rgb is not None,
            has_depth=depth is not None,
            has_pc=pc is not None,
            batch_size=all_emb.shape[0],
            unmask_sz=self.unmask_sz if unmask_sz is None else unmask_sz,
            dirichlet=self.dirichlet,
            embedding_sz=self.embedding_sz,
            add_mask=add_mask,
            shuffle_idx=shuffle_idx,
            device=all_emb.device,
        )

        # Keep the first ``unmask_sz`` shuffled positions of every sample. The
        # index is broadcast over the last dimension with ``expand`` (a view)
        # rather than ``repeat`` (a copy), so no full-size index is allocated.
        gather_idx = shuffle_idx[:, :unmask_sz, None].expand(-1, -1, all_emb.shape[-1])
        unmasked_emb = torch.gather(all_emb, 1, gather_idx)

        return EncoderModelOutput(
            embedding=unmasked_emb,
            pc_centers=pc_centers,
            pc_knn=pc_knn,
            shuffle_idx=shuffle_idx,
            restore_idx=restore_idx,
            add_mask=add_mask,
            unmask_sz=unmask_sz,
        )

    def get_last_hidden_states(
        self,
        embedding_output: EncoderModelOutput,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        """Run the encoder on ``embedding_output.embedding`` and store the results.

        Args:
            embedding_output: Output of ``get_input_embeddings``; it is updated
                in place.
            output_attentions: Forwarded to the encoder.
            output_hidden_states: Forwarded to the encoder.

        Returns:
            The same ``embedding_output`` object, with ``last_hidden_states``
            (layer-normed encoder output), ``hidden_states`` and ``attentions``
            set from the encoder result.
        """
        embedding = embedding_output.embedding

        encoder_outputs = self.encoder(
            embedding,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        embedding_output.last_hidden_states = sequence_output
        embedding_output.hidden_states = encoder_outputs.hidden_states
        embedding_output.attentions = encoder_outputs.attentions

        return embedding_output

    def forward(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        forward_pc: bool = True,
    ):
        """Embed the inputs, then encode them.

        Chains ``get_input_embeddings`` and ``get_last_hidden_states``; see
        those methods for the meaning of each argument.

        Returns:
            The ``EncoderModelOutput`` produced by ``get_input_embeddings``
            with the encoder results attached.
        """
        embedding_output = self.get_input_embeddings(
            rgb, depth, pc, add_mask, unmask_sz, forward_pc
        )
        return self.get_last_hidden_states(
            embedding_output, output_attentions, output_hidden_states
        )
|
|
|
|
class EmbodiedMAEForMaskedImageModeling(EmbodiedMAEModel):
    """``EmbodiedMAEModel`` encoder combined with an ``EmbodiedMAEDecoder``."""

    def __init__(self, config: EmbodiedMAEConfig):
        """Create the encoder (via the parent class) and attach the decoder."""
        super().__init__(config)
        self.decoder = EmbodiedMAEDecoder(config)

    def forward(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        forward_pc: bool = True,
    ):
        """Encode the inputs with the parent ``forward`` and decode the result.

        All arguments are handed unchanged to ``EmbodiedMAEModel.forward``.

        Returns:
            Whatever ``self.decoder`` returns for the decoder input built from
            the encoder output.
        """
        encoded = super().forward(
            rgb=rgb,
            depth=depth,
            pc=pc,
            add_mask=add_mask,
            unmask_sz=unmask_sz,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            forward_pc=forward_pc,
        )
        return self.decoder(self.decoder.get_decoder_input(encoded))

    @torch.no_grad()
    def visualize(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        mask_rgb: bool = False,
        mask_depth: bool = False,
        mask_pc: bool = False,
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        forward_pc: bool = True,
    ):
        """Encode with selected modalities hidden, then call the decoder's ``visualize``.

        A modality whose ``mask_*`` flag is truthy is replaced by ``None`` for
        the encoder pass only; the original ``rgb`` / ``depth`` / ``pc`` are
        still passed to ``self.decoder.visualize``. Runs under ``torch.no_grad``.

        Returns:
            The result of ``self.decoder.visualize``.
        """
        # Inputs as seen by the encoder: drop each modality that is flagged as masked.
        encoder_rgb = rgb
        if mask_rgb:
            encoder_rgb = None

        encoder_depth = depth
        if mask_depth:
            encoder_depth = None

        encoder_pc = pc
        if mask_pc:
            encoder_pc = None

        # Call the parent forward directly so only the encoder runs here.
        encoded = super().forward(
            rgb=encoder_rgb,
            depth=encoder_depth,
            pc=encoder_pc,
            add_mask=add_mask,
            unmask_sz=unmask_sz,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            forward_pc=forward_pc,
        )
        decoder_in = self.decoder.get_decoder_input(encoded)
        return self.decoder.visualize(decoder_in, rgb, depth, pc)
|
|
|
|
# Public names exported by ``from <module> import *``. Entries must be strings;
# listing the class objects themselves makes a star-import raise ``TypeError``.
__all__ = ["EmbodiedMAEModel", "EmbodiedMAEForMaskedImageModeling"]
|
|