import torch
import torch.nn as nn

from vggt.heads.camera_head import CameraHead
from vggt.heads.dpt_head import DPTHead

from .aggregator import Aggregator
from .decoder import Decoder


def freeze_all_params(modules):
    """Disable gradients for every parameter in the given modules (or bare parameters)."""
    for module in modules:
        try:
            for param in module.parameters():
                param.requires_grad = False
        except AttributeError:
            # module is directly a parameter
            module.requires_grad = False


class VDPM(nn.Module):
    def __init__(self, cfg, img_size=518, patch_size=14, embed_dim=1024):
        super().__init__()
        self.cfg = cfg
        self.aggregator = Aggregator(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
        )
        self.decoder = Decoder(
            cfg,
            dim_in=2 * embed_dim,
            embed_dim=embed_dim,
            depth=cfg.model.decoder_depth,
        )
        # Shared DPT head: predicts a 3-channel pointmap plus one confidence channel.
        self.point_head = DPTHead(
            dim_in=2 * embed_dim,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1",
        )
        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        self.set_freeze()

    def set_freeze(self):
        # Keep the pretrained patch embedding frozen during training.
        to_be_frozen = [self.aggregator.patch_embed]
        freeze_all_params(to_be_frozen)

    def forward(self, views, autocast_dpt=None):
        # Stack per-view images into a (B, S, C, H, W) tensor.
        images = torch.stack([view["img"] for view in views], dim=1)
        aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        res_dynamic = dict()
        decoded_tokens = None
        if self.decoder is not None:
            # Condition the decoder on the second entry of each view's index pair.
            cond_view_idxs = torch.stack([view["view_idxs"][:, 1] for view in views], dim=1)
            decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)

        if autocast_dpt is None:
            # Run the DPT head in full precision by default.
            autocast_dpt = torch.amp.autocast("cuda", enabled=False)

        with autocast_dpt:
            # Static pointmap from the aggregator tokens.
            pts3d, pts3d_conf = self.point_head(aggregated_tokens_list, images, patch_start_idx)

            if decoded_tokens is not None:
                # The DPT head only reads the entries listed in intermediate_layer_idx,
                # so pad the decoder outputs to the aggregator's layer count with None.
                padded_decoded_tokens = [None] * len(aggregated_tokens_list)
                for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
                    padded_decoded_tokens[layer_idx] = decoded_tokens[idx]
                # Dynamic pointmap from the decoded tokens.
                pts3d_dyn, pts3d_dyn_conf = self.point_head(padded_decoded_tokens, images, patch_start_idx)
                res_dynamic |= {"pts3d": pts3d_dyn, "conf": pts3d_dyn_conf}

        pose_enc_list = self.camera_head(aggregated_tokens_list)
        res_dynamic |= {"pose_enc_list": pose_enc_list}

        res_static = dict(pts3d=pts3d, conf=pts3d_conf)
        return res_static, res_dynamic

    def inference(self, views, images=None):
        autocast_amp = torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
        if images is None:
            images = torch.stack([view["img"] for view in views], dim=1)
        with autocast_amp:
            aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        S = images.shape[1]
        predictions = dict()
        pointmaps = []
        # Decode one dynamic pointmap per timestep, conditioning every view on that
        # timestep. The index tensor must live on the same device as the images.
        ones = torch.ones(1, S, dtype=torch.int64, device=images.device)
        for time_ in range(S):
            cond_view_idxs = ones * time_
            with autocast_amp:
                decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)
            padded_decoded_tokens = [None] * len(aggregated_tokens_list)
            for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
                padded_decoded_tokens[layer_idx] = decoded_tokens[idx]
            pts3d, pts3d_conf = self.point_head(padded_decoded_tokens, images, patch_start_idx)
            pointmaps.append(dict(pts3d=pts3d, conf=pts3d_conf))

        pose_enc_list = self.camera_head(aggregated_tokens_list)
        predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last refinement iteration
        predictions["pose_enc_list"] = pose_enc_list
        predictions["pointmaps"] = pointmaps
        return predictions

    def load_state_dict(self, ckpt, is_VGGT_static=False, **kw):
        # Drop the VGGT heads this model does not use (is_VGGT_static is currently unused).
        exclude = ["depth_head", "track_head"]
        ckpt = {k: v for k, v in ckpt.items() if k.split(".")[0] not in exclude}
        return super().load_state_dict(ckpt, **kw)
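

# ---------------------------------------------------------------------------
# Usage sketch (not part of the model): a minimal smoke test showing how the
# class above is consumed. The cfg layout (only cfg.model.decoder_depth is
# read here), the view dict keys ("img", "view_idxs"), and the tensor shapes
# are assumptions inferred from forward()/inference() above, not a confirmed
# training or evaluation setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(model=SimpleNamespace(decoder_depth=4))  # assumed minimal cfg
    model = VDPM(cfg).cuda().eval()

    B, S, H, W = 1, 2, 518, 518  # matches the default img_size=518, patch_size=14
    views = [
        {
            "img": torch.randn(B, 3, H, W, device="cuda"),
            # (B, 2) index pairs; forward() reads column 1 as the conditioning view index.
            "view_idxs": torch.zeros(B, 2, dtype=torch.int64, device="cuda"),
        }
        for _ in range(S)
    ]

    with torch.no_grad():
        res_static, res_dynamic = model(views)  # training-style pass: static + dynamic outputs
        predictions = model.inference(views)    # per-timestep dynamic pointmaps + camera poses
    print(res_static["pts3d"].shape, len(predictions["pointmaps"]))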