|
|
import sys |
|
|
import os |
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(__file__))) |
|
|
from collections import OrderedDict |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.checkpoint import checkpoint |
|
|
from copy import deepcopy |
|
|
from functools import partial |
|
|
from typing import Optional, Tuple, List, Any |
|
|
from dataclasses import dataclass |
|
|
from transformers import PretrainedConfig |
|
|
from transformers.file_utils import ModelOutput |
|
|
from dust3r.utils.misc import ( |
|
|
fill_default_args, |
|
|
freeze_all_params, |
|
|
is_symmetrized, |
|
|
interleave, |
|
|
transpose_to_landscape, |
|
|
) |
|
|
from dust3r.heads import head_factory |
|
|
from dust3r.utils.camera import PoseEncoder |
|
|
from dust3r.patch_embed import get_patch_embed |
|
|
import dust3r.utils.path_to_croco |
|
|
from models.croco import CroCoNet, CrocoConfig |
|
|
from dust3r.blocks import ( |
|
|
Block, |
|
|
DecoderBlock, |
|
|
Mlp, |
|
|
Attention, |
|
|
CrossAttention, |
|
|
DropPath, |
|
|
CustomDecoderBlock, |
|
|
) |
|
|
|
|
|
# Shorthand for +infinity; used in eval()'d instantiation strings and the demo below.
inf = float("inf")

from accelerate.logging import get_logger

# Accelerate-aware module logger (only emits on the main process by default).
printer = get_logger(__name__, log_level="DEBUG")
|
|
|
|
|
|
|
|
@dataclass
class ARCroco3DStereoOutput(ModelOutput):
    """
    Custom output class for ARCroco3DStereo.

    Wraps the per-view head predictions together with the input views so the
    result can be consumed like any HuggingFace ModelOutput.
    """

    # Per-view prediction dicts produced by the downstream head (one per view).
    ress: Optional[List[Any]] = None
    # The input views, echoed back alongside the predictions.
    views: Optional[List[Any]] = None
|
|
|
|
|
|
|
|
def strip_module(state_dict):
    """
    Removes the 'module.' prefix from the keys of a state_dict.

    Args:
        state_dict (dict): The original state_dict with possible 'module.' prefixes
            (as produced by DataParallel / DistributedDataParallel wrappers).

    Returns:
        OrderedDict: A new state_dict with 'module.' prefixes removed.
    """
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
|
|
|
|
|
|
|
|
def load_model(model_path, device, verbose=True):
    """Load an ARCroco3DStereo checkpoint saved with its instantiation string.

    The checkpoint is expected to carry `ckpt["args"].model` (a Python
    expression that rebuilds the network) and `ckpt["model"]` (the weights).
    The expression is patched to force `landscape_only=False` before being
    evaluated.

    Args:
        model_path: path to the torch checkpoint file.
        device: device to move the loaded model to.
        verbose: print progress and the load_state_dict report.

    Returns:
        The instantiated model on `device`.
    """
    if verbose:
        print("... loading model from", model_path)
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects, and the
    # eval() below executes code taken from the checkpoint — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    # Normalize the patch-embed class name used by older checkpoints.
    args = ckpt["args"].model.replace(
        "ManyAR_PatchEmbed", "PatchEmbedDust3R"
    )
    if "landscape_only" not in args:
        # Splice the flag in just before the final "))" of the expression.
        args = args[:-2] + ", landscape_only=False))"
    else:
        args = args.replace(" ", "").replace(
            "landscape_only=True", "landscape_only=False"
        )
    assert "landscape_only=False" in args
    if verbose:
        print(f"instantiating : {args}")
    net = eval(args)
    # strict=False: tolerate missing/extra keys (see load_state_dict fallbacks).
    s = net.load_state_dict(ckpt["model"], strict=False)
    if verbose:
        print(s)
    return net.to(device)
|
|
|
|
|
|
|
|
class ARCroco3DStereoConfig(PretrainedConfig):
    """HuggingFace-style configuration for ARCroco3DStereo.

    Any keyword argument not listed explicitly is collected into
    `croco_kwargs` and forwarded to the underlying CrocoConfig.
    """

    model_type = "arcroco_3d_stereo"

    def __init__(
        self,
        output_mode="pts3d",
        head_type="linear",
        depth_mode=("exp", -float("inf"), float("inf")),
        conf_mode=("exp", 1, float("inf")),
        pose_mode=("exp", -float("inf"), float("inf")),
        freeze="none",
        landscape_only=True,
        patch_embed_cls="PatchEmbedDust3R",
        ray_enc_depth=2,
        state_size=324,
        local_mem_size=256,
        state_pe="2d",
        state_dec_num_heads=16,
        depth_head=False,
        rgb_head=False,
        pose_conf_head=False,
        pose_head=False,
        **croco_kwargs,
    ):
        super().__init__()
        # What the downstream head predicts (e.g. "pts3d", "pts3d+pose").
        self.output_mode = output_mode
        # Head architecture selector for head_factory (e.g. "linear", "dpt").
        self.head_type = head_type
        # (activation, min, max) tuples bounding the respective predictions.
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        self.pose_mode = pose_mode
        # Which parameter groups to freeze; see ARCroco3DStereo.set_freeze.
        self.freeze = freeze
        self.landscape_only = landscape_only
        self.patch_embed_cls = patch_embed_cls
        # Depth of the ray-map encoder block stack.
        self.ray_enc_depth = ray_enc_depth
        # Number of learnable state tokens.
        self.state_size = state_size
        # Positional-encoding scheme for state tokens: "1d", "2d", or "none".
        self.state_pe = state_pe
        self.state_dec_num_heads = state_dec_num_heads
        # Number of slots in the LocalMemory pose retriever.
        self.local_mem_size = local_mem_size
        # Optional auxiliary head toggles.
        self.depth_head = depth_head
        self.rgb_head = rgb_head
        self.pose_conf_head = pose_conf_head
        self.pose_head = pose_head
        # Remaining kwargs are forwarded to CrocoConfig.
        self.croco_kwargs = croco_kwargs
|
|
|
|
|
|
|
|
class LocalMemory(nn.Module):
    """Learnable key/value memory bank for per-frame pose features.

    The memory holds `size` slots of width `2 * v_dim` (a key half and a
    value half). `update_mem` writes a (key, value) pair into the memory via
    cross-attention; `inquire` reads a value back for a query key, using a
    learned `masked_token` as a placeholder for the unknown value half.
    """

    def __init__(
        self,
        size,
        k_dim,
        v_dim,
        num_heads,
        depth=2,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        norm_mem=True,
        rope=None,
    ) -> None:
        super().__init__()
        self.v_dim = v_dim
        # Projects query/key features (k_dim) into the value space (v_dim).
        self.proj_q = nn.Linear(k_dim, v_dim)
        # Placeholder for the unknown value half when reading.
        self.masked_token = nn.Parameter(
            torch.randn(1, 1, v_dim) * 0.2, requires_grad=True
        )
        # Learnable initial memory content: [1, size, 2 * v_dim].
        self.mem = nn.Parameter(
            torch.randn(1, size, 2 * v_dim) * 0.2, requires_grad=True
        )

        def _make_blocks():
            # The write and read stacks share an identical architecture;
            # building both through one helper keeps them from drifting apart.
            return nn.ModuleList(
                [
                    DecoderBlock(
                        2 * v_dim,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        norm_layer=norm_layer,
                        attn_drop=attn_drop,
                        drop=drop,
                        drop_path=drop_path,
                        act_layer=act_layer,
                        norm_mem=norm_mem,
                        rope=rope,
                    )
                    for _ in range(depth)
                ]
            )

        self.write_blocks = _make_blocks()
        self.read_blocks = _make_blocks()

    def update_mem(self, mem, feat_k, feat_v):
        """Write one (key, value) pair into the memory.

        Args:
            mem: current memory, [B, size, 2 * v_dim].
            feat_k: key features, [B, 1, k_dim] (projected to v_dim here).
            feat_v: value features, [B, 1, v_dim].

        Returns:
            Updated memory, [B, size, 2 * v_dim].
        """
        feat_k = self.proj_q(feat_k)
        feat = torch.cat([feat_k, feat_v], dim=-1)
        for blk in self.write_blocks:
            mem, _ = blk(mem, feat, None, None)
        return mem

    def inquire(self, query, mem):
        """Retrieve the value associated with `query` from the memory.

        Args:
            query: key features, [B, 1, k_dim].
            mem: memory to read from, [B, size, 2 * v_dim].

        Returns:
            Value features, [B, 1, v_dim] (the value half of the read token).
        """
        x = self.proj_q(query)
        x = torch.cat([x, self.masked_token.expand(x.shape[0], -1, -1)], dim=-1)
        for blk in self.read_blocks:
            x, _ = blk(x, mem, None, None)
        return x[..., -self.v_dim :]
|
|
|
|
|
|
|
|
class ARCroco3DStereo(CroCoNet):
    """Autoregressive CroCo-based stereo model with a recurrent state.

    Processes a sequence of views (images and/or ray maps), maintaining a
    set of state tokens and a LocalMemory of pose features across steps.
    """

    config_class = ARCroco3DStereoConfig
    base_model_prefix = "arcroco3dstereo"
    supports_gradient_checkpointing = True

    def __init__(self, config: ARCroco3DStereoConfig):
        # Must be set before super().__init__(), which may trigger submodule
        # construction that reads these flags.
        self.gradient_checkpointing = False
        self.fixed_input_length = True
        config.croco_kwargs = fill_default_args(
            config.croco_kwargs, CrocoConfig.__init__
        )
        self.config = config
        self.patch_embed_cls = config.patch_embed_cls
        self.croco_args = config.croco_kwargs
        croco_cfg = CrocoConfig(**self.croco_args)
        # Builds the base encoder/decoder (enc_blocks, dec_blocks, rope, ...).
        super().__init__(croco_cfg)
        # Separate shallow encoder for 6-channel ray maps.
        self.enc_blocks_ray_map = nn.ModuleList(
            [
                Block(
                    self.enc_embed_dim,
                    16,
                    4,
                    qkv_bias=True,
                    norm_layer=partial(nn.LayerNorm, eps=1e-6),
                    rope=self.rope,
                )
                for _ in range(config.ray_enc_depth)
            ]
        )
        self.enc_norm_ray_map = nn.LayerNorm(self.enc_embed_dim, eps=1e-6)
        self.dec_num_heads = self.croco_args["dec_num_heads"]
        self.pose_head_flag = config.pose_head
        if self.pose_head_flag:
            # Learnable token prepended to the decoder input for pose readout.
            self.pose_token = nn.Parameter(
                torch.randn(1, 1, self.dec_embed_dim) * 0.02, requires_grad=True
            )
            # Memory bank for carrying pose features across time steps.
            self.pose_retriever = LocalMemory(
                size=config.local_mem_size,
                k_dim=self.enc_embed_dim,
                v_dim=self.dec_embed_dim,
                num_heads=self.dec_num_heads,
                mlp_ratio=4,
                qkv_bias=True,
                attn_drop=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                rope=None,
            )
        # Learnable state tokens (the recurrent "scene" representation).
        self.register_tokens = nn.Embedding(config.state_size, self.enc_embed_dim)
        self.state_size = config.state_size
        self.state_pe = config.state_pe
        # Tokens substituted for missing image / ray-map inputs.
        self.masked_img_token = nn.Parameter(
            torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
        )
        self.masked_ray_map_token = nn.Parameter(
            torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
        )
        # Second decoder stack that updates the state tokens.
        self._set_state_decoder(
            self.enc_embed_dim,
            self.dec_embed_dim,
            config.state_dec_num_heads,
            self.dec_depth,
            self.croco_args.get("mlp_ratio", None),
            self.croco_args.get("norm_layer", None),
            self.croco_args.get("norm_im2_in_dec", None),
        )
        self.set_downstream_head(
            config.output_mode,
            config.head_type,
            config.landscape_only,
            config.depth_mode,
            config.conf_mode,
            config.pose_mode,
            config.depth_head,
            config.rgb_head,
            config.pose_conf_head,
            config.pose_head,
            **self.croco_args,
        )
        self.set_freeze(config.freeze)
|
|
|
|
|
@classmethod |
|
|
def from_pretrained(cls, pretrained_model_name_or_path, **kw): |
|
|
if os.path.isfile(pretrained_model_name_or_path): |
|
|
return load_model(pretrained_model_name_or_path, device="cpu") |
|
|
else: |
|
|
try: |
|
|
model = super(ARCroco3DStereo, cls).from_pretrained( |
|
|
pretrained_model_name_or_path, **kw |
|
|
) |
|
|
except TypeError as e: |
|
|
raise Exception( |
|
|
f"tried to load {pretrained_model_name_or_path} from huggingface, but failed" |
|
|
) |
|
|
return model |
|
|
|
|
|
def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): |
|
|
self.patch_embed = get_patch_embed( |
|
|
self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3 |
|
|
) |
|
|
self.patch_embed_ray_map = get_patch_embed( |
|
|
self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=6 |
|
|
) |
|
|
|
|
|
    def _set_decoder(
        self,
        enc_embed_dim,
        dec_embed_dim,
        dec_num_heads,
        dec_depth,
        mlp_ratio,
        norm_layer,
        norm_im2_in_dec,
    ):
        """Build the image-token decoder: an encoder->decoder projection
        followed by `dec_depth` cross-attention DecoderBlocks and a final norm.
        """
        self.dec_depth = dec_depth
        self.dec_embed_dim = dec_embed_dim
        # Projects encoder features into the decoder embedding space.
        self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        self.dec_blocks = nn.ModuleList(
            [
                DecoderBlock(
                    dec_embed_dim,
                    dec_num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    norm_mem=norm_im2_in_dec,
                    rope=self.rope,
                )
                for i in range(dec_depth)
            ]
        )
        self.dec_norm = norm_layer(dec_embed_dim)
|
|
|
|
|
    def _set_state_decoder(
        self,
        enc_embed_dim,
        dec_embed_dim,
        dec_num_heads,
        dec_depth,
        mlp_ratio,
        norm_layer,
        norm_im2_in_dec,
    ):
        """Build the state-token decoder, mirroring `_set_decoder` but with
        its own projection, blocks and norm (suffixed `_state`).
        """
        self.dec_depth_state = dec_depth
        self.dec_embed_dim_state = dec_embed_dim
        # Projects state tokens from the encoder space into the decoder space.
        self.decoder_embed_state = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        self.dec_blocks_state = nn.ModuleList(
            [
                DecoderBlock(
                    dec_embed_dim,
                    dec_num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    norm_mem=norm_im2_in_dec,
                    rope=self.rope,
                )
                for i in range(dec_depth)
            ]
        )
        self.dec_norm_state = norm_layer(dec_embed_dim)
|
|
|
|
|
def load_state_dict(self, ckpt, **kw): |
|
|
if all(k.startswith("module") for k in ckpt): |
|
|
ckpt = strip_module(ckpt) |
|
|
new_ckpt = dict(ckpt) |
|
|
if not any(k.startswith("dec_blocks_state") for k in ckpt): |
|
|
for key, value in ckpt.items(): |
|
|
if key.startswith("dec_blocks"): |
|
|
new_ckpt[key.replace("dec_blocks", "dec_blocks_state")] = value |
|
|
try: |
|
|
return super().load_state_dict(new_ckpt, **kw) |
|
|
except: |
|
|
try: |
|
|
new_new_ckpt = { |
|
|
k: v |
|
|
for k, v in new_ckpt.items() |
|
|
if not k.startswith("dec_blocks") |
|
|
and not k.startswith("dec_norm") |
|
|
and not k.startswith("decoder_embed") |
|
|
} |
|
|
return super().load_state_dict(new_new_ckpt, **kw) |
|
|
except: |
|
|
new_new_ckpt = {} |
|
|
for key in new_ckpt: |
|
|
if key in self.state_dict(): |
|
|
if new_ckpt[key].size() == self.state_dict()[key].size(): |
|
|
new_new_ckpt[key] = new_ckpt[key] |
|
|
else: |
|
|
printer.info( |
|
|
f"Skipping '{key}': size mismatch (ckpt: {new_ckpt[key].size()}, model: {self.state_dict()[key].size()})" |
|
|
) |
|
|
else: |
|
|
printer.info(f"Skipping '{key}': not found in model") |
|
|
return super().load_state_dict(new_new_ckpt, **kw) |
|
|
|
|
|
    def set_freeze(self, freeze):
        """Freeze the parameter group named by `freeze`.

        Args:
            freeze: one of "none", "mask", "encoder", "encoder_and_head",
                "encoder_and_decoder", "decoder". Raises KeyError otherwise.
        """
        self.freeze = freeze
        # Map each freeze mode to the modules/parameters it disables.
        to_be_frozen = {
            "none": [],
            "mask": [self.mask_token] if hasattr(self, "mask_token") else [],
            "encoder": [
                self.patch_embed,
                self.patch_embed_ray_map,
                self.masked_img_token,
                self.masked_ray_map_token,
                self.enc_blocks,
                self.enc_blocks_ray_map,
                self.enc_norm,
                self.enc_norm_ray_map,
            ],
            "encoder_and_head": [
                self.patch_embed,
                self.patch_embed_ray_map,
                self.masked_img_token,
                self.masked_ray_map_token,
                self.enc_blocks,
                self.enc_blocks_ray_map,
                self.enc_norm,
                self.enc_norm_ray_map,
                self.downstream_head,
            ],
            "encoder_and_decoder": [
                self.patch_embed,
                self.patch_embed_ray_map,
                self.masked_img_token,
                self.masked_ray_map_token,
                self.enc_blocks,
                self.enc_blocks_ray_map,
                self.enc_norm,
                self.enc_norm_ray_map,
                self.dec_blocks,
                self.dec_blocks_state,
                self.pose_retriever,
                self.pose_token,
                self.register_tokens,
                self.decoder_embed_state,
                self.decoder_embed,
                self.dec_norm,
                self.dec_norm_state,
            ],
            # NOTE(review): "decoder" freezes only the two block stacks and the
            # pose modules — not the decoder norms/embeds; presumably intended.
            "decoder": [
                self.dec_blocks,
                self.dec_blocks_state,
                self.pose_retriever,
                self.pose_token,
            ],
        }
        freeze_all_params(to_be_frozen[freeze])
|
|
|
|
|
def _set_prediction_head(self, *args, **kwargs): |
|
|
"""No prediction head""" |
|
|
return |
|
|
|
|
|
    def set_downstream_head(
        self,
        output_mode,
        head_type,
        landscape_only,
        depth_mode,
        conf_mode,
        pose_mode,
        depth_head,
        rgb_head,
        pose_conf_head,
        pose_head,
        patch_size,
        img_size,
        **kw,
    ):
        """Instantiate the downstream prediction head and wrap it so it can
        run on portrait images by transposing to landscape.

        Args:
            output_mode / head_type: selectors passed to head_factory.
            landscape_only: if True, transpose portrait inputs to landscape.
            depth_mode / conf_mode / pose_mode: (activation, min, max) bounds.
            depth_head / rgb_head / pose_conf_head / pose_head: auxiliary head toggles.
            patch_size, img_size: used only for the divisibility sanity check.
        """
        assert (
            img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
        ), f"{img_size=} must be multiple of {patch_size=}"
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        self.pose_mode = pose_mode
        self.downstream_head = head_factory(
            head_type,
            output_mode,
            self,
            has_conf=bool(conf_mode),
            has_depth=bool(depth_head),
            has_rgb=bool(rgb_head),
            has_pose_conf=bool(pose_conf_head),
            has_pose=bool(pose_head),
        )
        # `self.head` is what _downstream_head actually calls.
        self.head = transpose_to_landscape(
            self.downstream_head, activate=landscape_only
        )
|
|
|
|
|
    def _encode_image(self, image, true_shape):
        """Patch-embed and encode an RGB image batch.

        Returns:
            ([tokens], positions, None) — tokens wrapped in a single-element
            list for interface symmetry with multi-scale encoders.
        """
        x, pos = self.patch_embed(image, true_shape=true_shape)
        # RoPE is applied inside the blocks; no absolute pos-embed expected.
        assert self.enc_pos_embed is None
        for blk in self.enc_blocks:
            if self.gradient_checkpointing and self.training:
                x = checkpoint(blk, x, pos, use_reentrant=False)
            else:
                x = blk(x, pos)
        x = self.enc_norm(x)
        return [x], pos, None
|
|
|
|
|
    def _encode_ray_map(self, ray_map, true_shape):
        """Patch-embed and encode a 6-channel ray-map batch.

        Mirrors `_encode_image` but uses the dedicated ray-map patch embed,
        block stack and norm.
        """
        x, pos = self.patch_embed_ray_map(ray_map, true_shape=true_shape)
        assert self.enc_pos_embed is None
        for blk in self.enc_blocks_ray_map:
            if self.gradient_checkpointing and self.training:
                x = checkpoint(blk, x, pos, use_reentrant=False)
            else:
                x = blk(x, pos)
        x = self.enc_norm_ray_map(x)
        return [x], pos, None
|
|
|
|
|
def _encode_state(self, image_tokens, image_pos): |
|
|
batch_size = image_tokens.shape[0] |
|
|
state_feat = self.register_tokens( |
|
|
torch.arange(self.state_size, device=image_pos.device) |
|
|
) |
|
|
if self.state_pe == "1d": |
|
|
state_pos = ( |
|
|
torch.tensor( |
|
|
[[i, i] for i in range(self.state_size)], |
|
|
dtype=image_pos.dtype, |
|
|
device=image_pos.device, |
|
|
)[None] |
|
|
.expand(batch_size, -1, -1) |
|
|
.contiguous() |
|
|
) |
|
|
elif self.state_pe == "2d": |
|
|
width = int(self.state_size**0.5) |
|
|
width = width + 1 if width % 2 == 1 else width |
|
|
state_pos = ( |
|
|
torch.tensor( |
|
|
[[i // width, i % width] for i in range(self.state_size)], |
|
|
dtype=image_pos.dtype, |
|
|
device=image_pos.device, |
|
|
)[None] |
|
|
.expand(batch_size, -1, -1) |
|
|
.contiguous() |
|
|
) |
|
|
elif self.state_pe == "none": |
|
|
state_pos = None |
|
|
state_feat = state_feat[None].expand(batch_size, -1, -1) |
|
|
return state_feat, state_pos, None |
|
|
|
|
|
    def _encode_views(self, views, img_mask=None, ray_mask=None):
        """Encode all views, combining image and ray-map encodings per view.

        Each view may provide an image, a ray map, or both; the masks select
        which modality is real. Missing modalities are filled with the
        learned masked tokens so every view yields a full token grid.

        Returns:
            (shapes, [tokens], positions), each chunked into per-view tuples.
        """
        device = views[0]["img"].device
        batch_size = views[0]["img"].shape[0]
        given = True
        if img_mask is None and ray_mask is None:
            given = False
        if not given:
            # Fall back to the per-view masks stored in the views themselves.
            img_mask = torch.stack(
                [view["img_mask"] for view in views], dim=0
            )
            ray_mask = torch.stack(
                [view["ray_mask"] for view in views], dim=0
            )
        imgs = torch.stack(
            [view["img"] for view in views], dim=0
        )
        ray_maps = torch.stack(
            [view["ray_map"] for view in views], dim=0
        )
        shapes = []
        for view in views:
            if "true_shape" in view:
                shapes.append(view["true_shape"])
            else:
                # No explicit true shape: assume the full (H, W) of the image.
                shape = torch.tensor(view["img"].shape[-2:], device=device)
                shapes.append(shape.unsqueeze(0).repeat(batch_size, 1))
        shapes = torch.stack(shapes, dim=0).to(
            imgs.device
        )
        # Flatten (num_views, B, ...) -> (num_views * B, ...).
        imgs = imgs.view(
            -1, *imgs.shape[2:]
        )
        ray_maps = ray_maps.view(
            -1, *ray_maps.shape[2:]
        )
        shapes = shapes.view(-1, 2)
        img_masks_flat = img_mask.view(-1)
        ray_masks_flat = ray_mask.view(-1)
        # Encode only the real images; scatter results back below.
        selected_imgs = imgs[img_masks_flat]
        selected_shapes = shapes[img_masks_flat]
        if selected_imgs.size(0) > 0:
            img_out, img_pos, _ = self._encode_image(selected_imgs, selected_shapes)
        else:
            raise NotImplementedError
        # Dense output buffers for every (view, sample) slot.
        full_out = [
            torch.zeros(
                len(views) * batch_size, *img_out[0].shape[1:], device=img_out[0].device
            )
            for _ in range(len(img_out))
        ]
        full_pos = torch.zeros(
            len(views) * batch_size,
            *img_pos.shape[1:],
            device=img_pos.device,
            dtype=img_pos.dtype,
        )
        for i in range(len(img_out)):
            # Real images get their encodings; the rest get the masked token.
            full_out[i][img_masks_flat] += img_out[i]
            full_out[i][~img_masks_flat] += self.masked_img_token
        full_pos[img_masks_flat] += img_pos
        # Ray maps are stored channels-last; the embedder expects NCHW.
        ray_maps = ray_maps.permute(0, 3, 1, 2)
        selected_ray_maps = ray_maps[ray_masks_flat]
        selected_shapes_ray = shapes[ray_masks_flat]
        if selected_ray_maps.size(0) > 0:
            ray_out, ray_pos, _ = self._encode_ray_map(
                selected_ray_maps, selected_shapes_ray
            )
            assert len(ray_out) == len(full_out), f"{len(ray_out)}, {len(full_out)}"
            for i in range(len(ray_out)):
                # Ray encodings are ADDED on top of any image encodings.
                full_out[i][ray_masks_flat] += ray_out[i]
                full_out[i][~ray_masks_flat] += self.masked_ray_map_token
            # Positions are only taken from rays where no image provided them.
            full_pos[ray_masks_flat] += (
                ray_pos * (~img_masks_flat[ray_masks_flat][:, None, None]).long()
            )
        else:
            # No real ray maps: run a dummy ray-map pass multiplied by zero so
            # the ray-map branch still participates in the graph (keeps DDP /
            # checkpointing happy about unused parameters).
            raymaps = torch.zeros(
                1, 6, imgs[0].shape[-2], imgs[0].shape[-1], device=img_out[0].device
            )
            ray_mask_flat = torch.zeros_like(img_masks_flat)
            ray_mask_flat[:1] = True
            ray_out, ray_pos, _ = self._encode_ray_map(raymaps, shapes[ray_mask_flat])
            for i in range(len(ray_out)):
                full_out[i][ray_mask_flat] += ray_out[i] * 0.0
                full_out[i][~ray_mask_flat] += self.masked_ray_map_token * 0.0
        # Chunk back into per-view tuples.
        return (
            shapes.chunk(len(views), dim=0),
            [out.chunk(len(views), dim=0) for out in full_out],
            full_pos.chunk(len(views), dim=0),
        )
|
|
|
|
|
    def _decoder(self, f_state, pos_state, f_img, pos_img, f_pose, pos_pose):
        """Jointly decode state tokens and current-view tokens.

        At each depth, the state stack attends to the image tokens and the
        image stack attends to the state tokens, both reading from the SAME
        previous layer's pair (final_output[-1]).

        Returns:
            zip of (all state features, all image features) across layers —
            unpacks into two tuples, each of length dec_depth + 1.
        """
        final_output = [(f_state, f_img)]
        assert f_state.shape[-1] == self.dec_embed_dim
        f_img = self.decoder_embed(f_img)
        if self.pose_head_flag:
            assert f_pose is not None and pos_pose is not None
            # Prepend the pose token/position to the image sequence.
            f_img = torch.cat([f_pose, f_img], dim=1)
            pos_img = torch.cat([pos_pose, pos_img], dim=1)
        final_output.append((f_state, f_img))
        for blk_state, blk_img in zip(self.dec_blocks_state, self.dec_blocks):
            if (
                self.gradient_checkpointing
                and self.training
                and torch.is_grad_enabled()
            ):
                # [::+1] feeds (state, img); [::-1] feeds (img, state).
                f_state, _ = checkpoint(
                    blk_state,
                    *final_output[-1][::+1],
                    pos_state,
                    pos_img,
                    use_reentrant=not self.fixed_input_length,
                )
                f_img, _ = checkpoint(
                    blk_img,
                    *final_output[-1][::-1],
                    pos_img,
                    pos_state,
                    use_reentrant=not self.fixed_input_length,
                )
            else:
                f_state, _ = blk_state(*final_output[-1][::+1], pos_state, pos_img)
                f_img, _ = blk_img(*final_output[-1][::-1], pos_img, pos_state)
            final_output.append((f_state, f_img))
        # Drop the duplicated post-embed entry; keep the raw inputs at [0].
        del final_output[1]
        final_output[-1] = (
            self.dec_norm_state(final_output[-1][0]),
            self.dec_norm(final_output[-1][1]),
        )
        return zip(*final_output)
|
|
|
|
|
def _downstream_head(self, decout, img_shape, **kwargs): |
|
|
B, S, D = decout[-1].shape |
|
|
head = getattr(self, f"head") |
|
|
return head(decout, img_shape, **kwargs) |
|
|
|
|
|
def _init_state(self, image_tokens, image_pos): |
|
|
""" |
|
|
Current Version: input the first frame img feature and pose to initialize the state feature and pose |
|
|
""" |
|
|
state_feat, state_pos, _ = self._encode_state(image_tokens, image_pos) |
|
|
state_feat = self.decoder_embed_state(state_feat) |
|
|
return state_feat, state_pos |
|
|
|
|
|
def _recurrent_rollout( |
|
|
self, |
|
|
state_feat, |
|
|
state_pos, |
|
|
current_feat, |
|
|
current_pos, |
|
|
pose_feat, |
|
|
pose_pos, |
|
|
init_state_feat, |
|
|
img_mask=None, |
|
|
reset_mask=None, |
|
|
update=None, |
|
|
): |
|
|
new_state_feat, dec = self._decoder( |
|
|
state_feat, state_pos, current_feat, current_pos, pose_feat, pose_pos |
|
|
) |
|
|
new_state_feat = new_state_feat[-1] |
|
|
return new_state_feat, dec |
|
|
|
|
|
def _get_img_level_feat(self, feat): |
|
|
return torch.mean(feat, dim=1, keepdim=True) |
|
|
|
|
|
def _forward_encoder(self, views): |
|
|
shape, feat_ls, pos = self._encode_views(views) |
|
|
feat = feat_ls[-1] |
|
|
state_feat, state_pos = self._init_state(feat[0], pos[0]) |
|
|
mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1) |
|
|
init_state_feat = state_feat.clone() |
|
|
init_mem = mem.clone() |
|
|
return (feat, pos, shape), ( |
|
|
init_state_feat, |
|
|
init_mem, |
|
|
state_feat, |
|
|
state_pos, |
|
|
mem, |
|
|
) |
|
|
|
|
|
    def _forward_decoder_step(
        self,
        views,
        i,
        feat_i,
        pos_i,
        shape_i,
        init_state_feat,
        init_mem,
        state_feat,
        state_pos,
        mem,
    ):
        """One decoder step for view `i`: predict outputs and update the
        recurrent state and pose memory (with per-sample update/reset masks).

        Returns:
            (res, (state_feat, mem)) — head predictions and the new state.
        """
        if self.pose_head_flag:
            global_img_feat_i = self._get_img_level_feat(feat_i)
            if i == 0:
                # First frame: use the learned pose token directly.
                pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
            else:
                # Later frames: retrieve the pose feature from memory.
                pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
            # Sentinel position (-1, -1) marks the pose token for RoPE.
            pose_pos_i = -torch.ones(
                feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
            )
        else:
            pose_feat_i = None
            pose_pos_i = None
        new_state_feat, dec = self._recurrent_rollout(
            state_feat,
            state_pos,
            feat_i,
            pos_i,
            pose_feat_i,
            pose_pos_i,
            init_state_feat,
            img_mask=views[i]["img_mask"],
            reset_mask=views[i]["reset"],
            update=views[i].get("update", None),
        )
        # NOTE(review): if pose_head_flag is False, global_img_feat_i is
        # unbound here and update_mem would raise NameError — this path seems
        # to assume pose_head is always enabled; confirm.
        out_pose_feat_i = dec[-1][:, 0:1]
        new_mem = self.pose_retriever.update_mem(
            mem, global_img_feat_i, out_pose_feat_i
        )
        # Multi-scale head input: first, mid (1/2 and 3/4 depth, pose token
        # stripped) and final decoder layers.
        head_input = [
            dec[0].float(),
            dec[self.dec_depth * 2 // 4][:, 1:].float(),
            dec[self.dec_depth * 3 // 4][:, 1:].float(),
            dec[self.dec_depth].float(),
        ]
        res = self._downstream_head(head_input, shape_i, pos=pos_i)
        img_mask = views[i]["img_mask"]
        update = views[i].get("update", None)
        if update is not None:
            update_mask = img_mask & update
        else:
            update_mask = img_mask
        # Blend new/old state per sample according to the update mask.
        update_mask = update_mask[:, None, None].float()
        state_feat = new_state_feat * update_mask + state_feat * (
            1 - update_mask
        )
        mem = new_mem * update_mask + mem * (1 - update_mask)
        reset_mask = views[i]["reset"]
        if reset_mask is not None:
            # Reset requested samples back to the pristine initial state.
            reset_mask = reset_mask[:, None, None].float()
            state_feat = init_state_feat * reset_mask + state_feat * (1 - reset_mask)
            mem = init_mem * reset_mask + mem * (1 - reset_mask)
        return res, (state_feat, mem)
|
|
|
|
|
    def _forward_impl(self, views, ret_state=False):
        """Full forward pass: encode every view, then roll the recurrent
        decoder over the sequence, collecting per-view head predictions.

        Args:
            views: list of view dicts (img, ray_map, masks, reset, ...).
            ret_state: if True, also return the state tensors after each step.

        Returns:
            (ress, views) or (ress, views, all_state_args).
        """
        shape, feat_ls, pos = self._encode_views(views)
        feat = feat_ls[-1]
        # Initialize the recurrent state from the first view's tokens.
        state_feat, state_pos = self._init_state(feat[0], pos[0])
        mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
        # Pristine copies kept for per-sample resets.
        init_state_feat = state_feat.clone()
        init_mem = mem.clone()
        all_state_args = [(state_feat, state_pos, init_state_feat, mem, init_mem)]
        ress = []
        for i in range(len(views)):
            feat_i = feat[i]
            pos_i = pos[i]
            if self.pose_head_flag:
                global_img_feat_i = self._get_img_level_feat(feat_i)
                if i == 0:
                    # First frame uses the learned pose token.
                    pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
                else:
                    pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
                # Sentinel position (-1, -1) for the pose token.
                pose_pos_i = -torch.ones(
                    feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
                )
            else:
                pose_feat_i = None
                pose_pos_i = None
            new_state_feat, dec = self._recurrent_rollout(
                state_feat,
                state_pos,
                feat_i,
                pos_i,
                pose_feat_i,
                pose_pos_i,
                init_state_feat,
                img_mask=views[i]["img_mask"],
                reset_mask=views[i]["reset"],
                update=views[i].get("update", None),
            )
            # NOTE(review): global_img_feat_i is unbound when pose_head_flag
            # is False — this loop assumes pose_head is enabled; confirm.
            out_pose_feat_i = dec[-1][:, 0:1]
            new_mem = self.pose_retriever.update_mem(
                mem, global_img_feat_i, out_pose_feat_i
            )
            assert len(dec) == self.dec_depth + 1
            # Multi-scale head input; pose token stripped from mid layers.
            head_input = [
                dec[0].float(),
                dec[self.dec_depth * 2 // 4][:, 1:].float(),
                dec[self.dec_depth * 3 // 4][:, 1:].float(),
                dec[self.dec_depth].float(),
            ]
            res = self._downstream_head(head_input, shape[i], pos=pos_i)
            ress.append(res)
            img_mask = views[i]["img_mask"]
            update = views[i].get("update", None)
            if update is not None:
                update_mask = (
                    img_mask & update
                )
            else:
                update_mask = img_mask
            # Per-sample blend of new/old state according to the update mask.
            update_mask = update_mask[:, None, None].float()
            state_feat = new_state_feat * update_mask + state_feat * (
                1 - update_mask
            )
            mem = new_mem * update_mask + mem * (
                1 - update_mask
            )
            reset_mask = views[i]["reset"]
            if reset_mask is not None:
                # Reset requested samples back to the initial state/memory.
                reset_mask = reset_mask[:, None, None].float()
                state_feat = init_state_feat * reset_mask + state_feat * (
                    1 - reset_mask
                )
                mem = init_mem * reset_mask + mem * (1 - reset_mask)
            all_state_args.append(
                (state_feat, state_pos, init_state_feat, mem, init_mem)
            )
        if ret_state:
            return ress, views, all_state_args
        return ress, views
|
|
|
|
|
def forward(self, views, ret_state=False): |
|
|
if ret_state: |
|
|
ress, views, state_args = self._forward_impl(views, ret_state=ret_state) |
|
|
return ARCroco3DStereoOutput(ress=ress, views=views), state_args |
|
|
else: |
|
|
ress, views = self._forward_impl(views, ret_state=ret_state) |
|
|
return ARCroco3DStereoOutput(ress=ress, views=views) |
|
|
|
|
|
    def inference_step(
        self, view, state_feat, state_pos, init_state_feat, mem, init_mem
    ):
        """Single inference step for a ray-map-only view, given an existing
        recurrent state (does NOT update state_feat/mem for the caller).

        Returns:
            (res, view) — head predictions for this view.
        """
        batch_size = view["img"].shape[0]
        raymaps = []
        shapes = []
        for j in range(batch_size):
            # This entry point only supports ray-map queries.
            assert view["ray_mask"][j]
            # Channels-last -> NCHW for the patch embedder.
            raymap = view["ray_map"][[j]].permute(0, 3, 1, 2)
            raymaps.append(raymap)
            shapes.append(
                view.get(
                    "true_shape",
                    torch.tensor(view["ray_map"].shape[-2:])[None].repeat(
                        view["ray_map"].shape[0], 1
                    ),
                )[[j]]
            )

        raymaps = torch.cat(raymaps, dim=0)
        shape = torch.cat(shapes, dim=0).to(raymaps.device)
        # NOTE(review): this passes the Python LIST `shapes`, not the
        # concatenated tensor `shape` — presumably works because the patch
        # embedder only indexes it, but looks unintended; confirm.
        feat_ls, pos, _ = self._encode_ray_map(raymaps, shapes)

        feat_i = feat_ls[-1]
        pos_i = pos
        if self.pose_head_flag:
            global_img_feat_i = self._get_img_level_feat(feat_i)
            # Mid-sequence step: always read the pose feature from memory.
            pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
            pose_pos_i = -torch.ones(
                feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
            )
        else:
            pose_feat_i = None
            pose_pos_i = None
        new_state_feat, dec = self._recurrent_rollout(
            state_feat,
            state_pos,
            feat_i,
            pos_i,
            pose_feat_i,
            pose_pos_i,
            init_state_feat,
            img_mask=view["img_mask"],
            reset_mask=view["reset"],
            update=view.get("update", None),
        )

        # NOTE(review): global_img_feat_i is unbound if pose_head_flag is
        # False; new_state_feat/new_mem are computed but never returned —
        # this step intentionally leaves the caller's state untouched.
        out_pose_feat_i = dec[-1][:, 0:1]
        new_mem = self.pose_retriever.update_mem(
            mem, global_img_feat_i, out_pose_feat_i
        )
        assert len(dec) == self.dec_depth + 1
        head_input = [
            dec[0].float(),
            dec[self.dec_depth * 2 // 4][:, 1:].float(),
            dec[self.dec_depth * 3 // 4][:, 1:].float(),
            dec[self.dec_depth].float(),
        ]
        res = self._downstream_head(head_input, shape, pos=pos_i)
        return res, view
|
|
|
|
|
    def forward_recurrent(self, views, device, ret_state=False):
        """Streaming forward pass: encode and decode one view at a time.

        Unlike `_forward_impl` (which encodes all views up front), this
        encodes each view as it arrives, summing image and ray-map encodings
        when both modalities are present.

        Args:
            views: list of view dicts.
            device: unused; overwritten by each view's own device.
            ret_state: if True, also return per-step state tensors.

        Returns:
            (ress, views) or (ress, views, all_state_args).
        """
        ress = []
        all_state_args = []
        for i, view in enumerate(views):
            device = view["img"].device
            batch_size = view["img"].shape[0]
            img_mask = view["img_mask"].reshape(
                -1, batch_size
            )
            ray_mask = view["ray_mask"].reshape(
                -1, batch_size
            )
            # Single-view analogue of _encode_views' flattening.
            imgs = view["img"].unsqueeze(0)
            ray_maps = view["ray_map"].unsqueeze(
                0
            )
            shapes = (
                view["true_shape"].unsqueeze(0)
                if "true_shape" in view
                else torch.tensor(view["img"].shape[-2:], device=device)
                .unsqueeze(0)
                .repeat(batch_size, 1)
                .unsqueeze(0)
            )
            imgs = imgs.view(
                -1, *imgs.shape[2:]
            )
            ray_maps = ray_maps.view(
                -1, *ray_maps.shape[2:]
            )
            shapes = shapes.view(-1, 2).to(
                imgs.device
            )
            img_masks_flat = img_mask.view(-1)
            ray_masks_flat = ray_mask.view(-1)
            selected_imgs = imgs[img_masks_flat]
            selected_shapes = shapes[img_masks_flat]
            if selected_imgs.size(0) > 0:
                img_out, img_pos, _ = self._encode_image(selected_imgs, selected_shapes)
            else:
                img_out, img_pos = None, None
            # Channels-last -> NCHW for the ray-map embedder.
            ray_maps = ray_maps.permute(0, 3, 1, 2)
            selected_ray_maps = ray_maps[ray_masks_flat]
            selected_shapes_ray = shapes[ray_masks_flat]
            if selected_ray_maps.size(0) > 0:
                ray_out, ray_pos, _ = self._encode_ray_map(
                    selected_ray_maps, selected_shapes_ray
                )
            else:
                ray_out, ray_pos = None, None

            shape = shapes
            # Combine modalities: image-only, ray-only, or summed encodings.
            if img_out is not None and ray_out is None:
                feat_i = img_out[-1]
                pos_i = img_pos
            elif img_out is None and ray_out is not None:
                feat_i = ray_out[-1]
                pos_i = ray_pos
            elif img_out is not None and ray_out is not None:
                feat_i = img_out[-1] + ray_out[-1]
                pos_i = img_pos
            else:
                raise NotImplementedError

            if i == 0:
                # First view: initialize state, memory, and pristine copies.
                state_feat, state_pos = self._init_state(feat_i, pos_i)
                mem = self.pose_retriever.mem.expand(feat_i.shape[0], -1, -1)
                init_state_feat = state_feat.clone()
                init_mem = mem.clone()
                all_state_args.append(
                    (state_feat, state_pos, init_state_feat, mem, init_mem)
                )

            if self.pose_head_flag:
                global_img_feat_i = self._get_img_level_feat(feat_i)
                if i == 0:
                    pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
                else:
                    pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
                pose_pos_i = -torch.ones(
                    feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
                )
            else:
                pose_feat_i = None
                pose_pos_i = None
            new_state_feat, dec = self._recurrent_rollout(
                state_feat,
                state_pos,
                feat_i,
                pos_i,
                pose_feat_i,
                pose_pos_i,
                init_state_feat,
                img_mask=view["img_mask"],
                reset_mask=view["reset"],
                update=view.get("update", None),
            )
            # NOTE(review): global_img_feat_i is unbound when pose_head_flag
            # is False; this loop assumes pose_head is enabled — confirm.
            out_pose_feat_i = dec[-1][:, 0:1]
            new_mem = self.pose_retriever.update_mem(
                mem, global_img_feat_i, out_pose_feat_i
            )
            assert len(dec) == self.dec_depth + 1
            head_input = [
                dec[0].float(),
                dec[self.dec_depth * 2 // 4][:, 1:].float(),
                dec[self.dec_depth * 3 // 4][:, 1:].float(),
                dec[self.dec_depth].float(),
            ]
            res = self._downstream_head(head_input, shape, pos=pos_i)
            ress.append(res)
            img_mask = view["img_mask"]
            update = view.get("update", None)
            if update is not None:
                update_mask = (
                    img_mask & update
                )
            else:
                update_mask = img_mask
            # Per-sample blend of new/old state according to the update mask.
            update_mask = update_mask[:, None, None].float()
            state_feat = new_state_feat * update_mask + state_feat * (
                1 - update_mask
            )
            mem = new_mem * update_mask + mem * (
                1 - update_mask
            )
            reset_mask = view["reset"]
            if reset_mask is not None:
                # Reset requested samples back to the initial state/memory.
                reset_mask = reset_mask[:, None, None].float()
                state_feat = init_state_feat * reset_mask + state_feat * (
                    1 - reset_mask
                )
                mem = init_mem * reset_mask + mem * (1 - reset_mask)
            all_state_args.append(
                (state_feat, state_pos, init_state_feat, mem, init_mem)
            )
        if ret_state:
            return ress, views, all_state_args
        return ress, views
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: show the resolved MRO and build a model from a demo config.
    print(ARCroco3DStereo.mro())
    cfg = ARCroco3DStereoConfig(
        state_size=256,
        pos_embed="RoPE100",
        rgb_head=True,
        pose_head=True,
        img_size=(224, 224),
        head_type="linear",
        output_mode="pts3d+pose",
        depth_mode=("exp", -inf, inf),
        conf_mode=("exp", 1, inf),
        pose_mode=("exp", -inf, inf),
        enc_embed_dim=1024,
        enc_depth=24,
        enc_num_heads=16,
        dec_embed_dim=768,
        dec_depth=12,
        dec_num_heads=12,
    )
    ARCroco3DStereo(cfg)
|
|
|