| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import PreTrainedModel |
| | from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder |
| |
|
| | from .configuration_embodiedmae import EmbodiedMAEConfig |
| | from .modular_embodiedmae import ( |
| | EmbodiedMAEDecoder, |
| | EmbodiedMAEDepthEmbeddings, |
| | EmbodiedMAEPointCloudEmbeddings, |
| | EmbodiedMAERGBEmbeddings, |
| | EncoderModelOutput, |
| | concat_sequence_with_dummy, |
| | prepare_shuffle_idx, |
| | ) |
| |
|
| |
|
class EmbodiedMAEModel(PreTrainedModel):
    """Multimodal MAE encoder over RGB, depth, and point-cloud inputs.

    Each available modality is embedded separately, the token sequences are
    concatenated (absent modalities are dummy-filled so positions stay
    stable), an optional random mask keeps only ``unmask_sz`` tokens, and the
    surviving tokens are run through a DINOv2 transformer encoder.
    """

    config_class = EmbodiedMAEConfig

    def __init__(self, config: EmbodiedMAEConfig):
        super().__init__(config)
        self.config = config

        # Symmetric Dirichlet over the 3 modalities; presumably used by
        # prepare_shuffle_idx to sample per-modality unmask budgets —
        # TODO(review): confirm against modular_embodiedmae.
        self.dirichlet = torch.distributions.Dirichlet(
            torch.full((3,), config.dirichlet_alpha)
        )

        # One embedding module per modality.
        self.rgb_embeddings = EmbodiedMAERGBEmbeddings(config)
        self.depth_embeddings = EmbodiedMAEDepthEmbeddings(config)
        self.pc_embeddings = EmbodiedMAEPointCloudEmbeddings(config)

        # Shared transformer trunk reused from DINOv2.
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Per-modality token counts: RGB patches, depth patches, PC centers.
        num_patches = (config.image_size // config.patch_size) ** 2
        self.embedding_sz = (
            num_patches,
            num_patches,
            config.num_pc_centers,
        )
        self.unmask_sz = config.unmask_sz

    def get_input_embeddings(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        forward_pc: bool = True,
        shuffle_idx: Optional[torch.Tensor] = None,
    ) -> EncoderModelOutput:
        """Embed the given modalities and gather the unmasked token subset.

        NOTE(review): this intentionally shadows
        ``PreTrainedModel.get_input_embeddings`` (which takes no arguments and
        returns an embedding module) with an incompatible signature — confirm
        no HF utilities in this codebase rely on the base-class contract.

        Args:
            rgb: RGB input tensor, or None if the modality is absent/masked.
            depth: Depth input tensor, or None.
            pc: Point-cloud input tensor, or None. At least one of the three
                must be provided.
            add_mask: Forwarded to ``prepare_shuffle_idx``; controls whether
                random masking is applied.
            unmask_sz: Number of tokens to keep; defaults to
                ``config.unmask_sz``.
            forward_pc: When False, point-cloud embeddings are still computed
                (so ``pc_centers``/``pc_knn`` remain available downstream) but
                the PC tokens are excluded from the encoder sequence.
            shuffle_idx: Optional precomputed shuffle permutation; when None,
                ``prepare_shuffle_idx`` generates one.

        Returns:
            EncoderModelOutput holding the unmasked embeddings plus the
            bookkeeping needed to restore/decode (shuffle and restore indices,
            PC centers/knn, mask metadata).
        """
        assert any([rgb is not None, depth is not None, pc is not None])

        # Each embedding module accepts None for an absent modality —
        # presumably returning None/dummy output; verify in modular_embodiedmae.
        rgb_emb = self.rgb_embeddings(rgb)
        depth_emb = self.depth_embeddings(depth)
        pc_emb, pc_centers, pc_knn = self.pc_embeddings(pc)
        if not forward_pc:
            # Keep pc_centers/pc_knn (still returned below) but drop the PC
            # tokens from the encoder input and mark the modality as absent.
            pc = None
            pc_emb = None

        # Concatenate modality sequences; absent ones are dummy-filled to the
        # fixed per-modality lengths in self.embedding_sz.
        all_emb = concat_sequence_with_dummy(
            [rgb_emb, depth_emb, pc_emb], self.embedding_sz
        )

        shuffle_idx, restore_idx, unmask_sz = prepare_shuffle_idx(
            has_rgb=rgb is not None,
            has_depth=depth is not None,
            has_pc=pc is not None,
            batch_size=all_emb.shape[0],
            unmask_sz=self.unmask_sz if unmask_sz is None else unmask_sz,
            dirichlet=self.dirichlet,
            embedding_sz=self.embedding_sz,
            add_mask=add_mask,
            shuffle_idx=shuffle_idx,
            device=all_emb.device,
        )

        # Keep the first `unmask_sz` shuffled tokens. `expand` (a broadcast
        # view) instead of `repeat`: torch.gather only reads the index, so
        # there is no need to materialize a (B, unmask_sz, hidden) tensor.
        unmasked_emb = torch.gather(
            all_emb,
            1,
            shuffle_idx[:, :unmask_sz, None].expand(-1, -1, all_emb.shape[-1]),
        )

        return EncoderModelOutput(
            embedding=unmasked_emb,
            pc_centers=pc_centers,
            pc_knn=pc_knn,
            shuffle_idx=shuffle_idx,
            restore_idx=restore_idx,
            add_mask=add_mask,
            unmask_sz=unmask_sz,
        )

    def get_last_hidden_states(
        self,
        embedding_output: EncoderModelOutput,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> EncoderModelOutput:
        """Run the unmasked embeddings through the encoder and final LayerNorm.

        Mutates ``embedding_output`` in place, attaching
        ``last_hidden_states``, ``hidden_states``, and ``attentions``, and
        returns the same object.
        """
        embedding = embedding_output.embedding

        encoder_outputs = self.encoder(
            embedding,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        embedding_output.last_hidden_states = sequence_output
        embedding_output.hidden_states = encoder_outputs.hidden_states
        embedding_output.attentions = encoder_outputs.attentions

        return embedding_output

    def forward(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        forward_pc: bool = True,
    ) -> EncoderModelOutput:
        """Embed, mask, and encode the given modalities in one call."""
        embedding_output = self.get_input_embeddings(
            rgb, depth, pc, add_mask, unmask_sz, forward_pc
        )
        return self.get_last_hidden_states(
            embedding_output, output_attentions, output_hidden_states
        )
| |
|
| |
|
class EmbodiedMAEForMaskedImageModeling(EmbodiedMAEModel):
    """EmbodiedMAE encoder paired with a reconstruction decoder."""

    def __init__(self, config: EmbodiedMAEConfig):
        super().__init__(config)
        self.decoder = EmbodiedMAEDecoder(config)

    def forward(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        forward_pc: bool = True,
    ):
        """Encode the inputs with the parent model, then decode them."""
        encoded = super().forward(
            rgb,
            depth,
            pc,
            add_mask,
            unmask_sz,
            output_attentions,
            output_hidden_states,
            forward_pc,
        )
        return self.decoder(self.decoder.get_decoder_input(encoded))

    @torch.no_grad()
    def visualize(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        mask_rgb: bool = False,
        mask_depth: bool = False,
        mask_pc: bool = False,
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        forward_pc: bool = True,
    ):
        """Encode with selected modalities hidden, then render reconstructions.

        Modalities flagged ``mask_*`` are withheld from the encoder, while the
        original (unmasked) tensors are still handed to the decoder's
        visualizer as reference targets.
        """
        # Null out any modality the caller asked to hide from the encoder.
        visible = [
            None if hidden else modality
            for hidden, modality in (
                (mask_rgb, rgb),
                (mask_depth, depth),
                (mask_pc, pc),
            )
        ]
        encoded = super().forward(
            visible[0],
            visible[1],
            visible[2],
            add_mask,
            unmask_sz,
            output_attentions,
            output_hidden_states,
            forward_pc,
        )
        decoder_input = self.decoder.get_decoder_input(encoded)
        return self.decoder.visualize(decoder_input, rgb, depth, pc)
| |
|
| |
|
# __all__ must contain strings; listing the class objects themselves makes
# `from <module> import *` raise "TypeError: Item in __all__ must be str".
__all__ = ["EmbodiedMAEModel", "EmbodiedMAEForMaskedImageModeling"]
| |
|