import einops
from torch import nn
from ..vision.fpn import EfficientFeaturePyramidNetwork
from .base_encoder import Encoder as BaseEncoder
class Encoder(BaseEncoder):
def __init__(self,
backbone="clip",
embedding_dim=60,
nhist=1,
num_attn_heads=9,
num_vis_instr_attn_layers=2,
fps_subsampling_factor=5,
finetune_backbone=False,
finetune_text_encoder=False,
rot_dim=3):
super().__init__(
backbone=backbone,
embedding_dim=embedding_dim,
nhist=nhist,
num_attn_heads=num_attn_heads,
num_vis_instr_attn_layers=num_vis_instr_attn_layers,
fps_subsampling_factor=fps_subsampling_factor,
finetune_backbone=finetune_backbone,
finetune_text_encoder=finetune_text_encoder
)
# Postprocess scene features
if self._backbone_name == 'clip':
self.feature_pyramid = EfficientFeaturePyramidNetwork(
[64, 256, 512, 1024, 2048],
embedding_dim, output_level="res4"
)
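            # NOTE: the channel list above presumably matches the stem and
            # four residual stages of CLIP's modified ResNet-50; the FPN
            # projects them to embedding_dim and "res4" selects the level
            # that is returned.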
self.rgb2d_proj = nn.Conv2d(2048, embedding_dim, 1)
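            # 1x1 projection of the deepest backbone features for the 2D
            # cameras (defined here but not called within this file)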
        # Learnable camera-id embeddings (up to 5 cameras)
self.camera_ids = nn.Embedding(5, embedding_dim)
        # Learnable projection of proprioception, used when no 3D information is available
self.rot_dim = rot_dim
self.proprio_feat = nn.Linear(3 + rot_dim, embedding_dim)
def encode_proprio(self, proprio, context_feats, context_pos):
"""
Compute proprioception features.
        Args:
            - proprio: (B, nhist, K) with K >= 3 + rot_dim; only the
              first 3 + rot_dim channels (position, rotation) are used
            - context_feats: (B, npt, C), unused here
            - context_pos: (B, npt, 3), unused here
Returns:
- gripper_feats: (B, nhist, F)
"""
return self.proprio_feat(proprio[..., :3 + self.rot_dim])
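
    # Shape sketch (illustrative, assuming rot_dim=3, embedding_dim=F):
    #   proprio (B, nhist, 8) -> proprio[..., :6] -> Linear(6, F) -> (B, nhist, F)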
def encode_clip(self, rgb3d, rgb2d, pcd, text):
"""
Compute visual features/pos embeddings.
        Args:
            - rgb3d: (B, ncam3d, 3, H, W), RGB observations from 3D cameras
            - rgb2d: (B, ncam2d, 3, H, W), RGB observations from 2D cameras
            - pcd: (B, ncam3d, 3, H, W) point cloud, or None
            - text: list of str of len=B, text instructions
        Returns:
            - rgb3d_feats: (B, Np, F)
            - rgb2d_feats: (B, ncam2d, F), returned as None here
            - pcd: (B, Np, 3), returned as None here
            - instr_feats: (B, L, F)
"""
# Encode language
instruction = self.text_encoder(text)
instr_feats = self.instruction_encoder(instruction)
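        # instr_feats: (B, L, F), one feature per language token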
        # Features from the "3D" cameras (the features themselves are 2D;
        # the name only reflects that these cameras come with point clouds)
rgb3d_feats = None
if rgb3d is not None:
num_cameras = rgb3d.shape[1]
# Pass each view independently through backbone
rgb3d = einops.rearrange(rgb3d, "bt ncam c h w -> (bt ncam) c h w")
rgb3d = self.normalize(rgb3d)
rgb3d_feats = self.backbone(rgb3d)
# Pass visual features through feature pyramid network
rgb3d_feats = self.feature_pyramid(rgb3d_feats)["res4"]
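            # rgb3d_feats: (B*ncam, F, h, w) at the "res4" pyramid level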
# Add camera id embeddings
rgb3d_feats = einops.rearrange(
rgb3d_feats,
"(bt ncam) c h w -> bt ncam c h w", ncam=num_cameras
)
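            # camera_ids.weight[:num_cameras] is (ncam, c); the indexing
            # below broadcasts it to (1, ncam, c, 1, 1) before the add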
rgb3d_feats = rgb3d_feats + self.camera_ids.weight[:num_cameras][
None, :, :, None, None
]
# Merge different cameras
rgb3d_feats = einops.rearrange(
rgb3d_feats, "bt ncam c h w -> bt (ncam h w) c"
)
            # Cross-attention from vision to language; keep the last layer's output
rgb3d_feats = self.vl_attention(seq1=rgb3d_feats, seq2=instr_feats)[-1]
        # 2D camera features are not computed by this encoder
rgb2d_feats = None
return rgb3d_feats, rgb2d_feats, None, instr_feats
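

# Minimal usage sketch (illustrative; assumes the base Encoder supplies
# text_encoder, instruction_encoder, normalize, backbone and vl_attention,
# and uses a hypothetical camera count and resolution):
#
#   import torch
#   encoder = Encoder(backbone="clip", embedding_dim=60)
#   rgb3d = torch.zeros(2, 2, 3, 256, 256)   # (B, ncam3d, 3, H, W)
#   text = ["pick up the cup", "open the drawer"]
#   rgb3d_feats, rgb2d_feats, pcd, instr_feats = encoder.encode_clip(
#       rgb3d, rgb2d=None, pcd=None, text=text
#   )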