import torch
import torch.nn as nn
from vggt.heads.camera_head import CameraHead
from vggt.heads.dpt_head import DPTHead
from .aggregator import Aggregator
from .decoder import Decoder
def freeze_all_params(modules):
    """Disable gradient computation for everything in *modules*.

    Each entry may be an ``nn.Module`` (all of its named parameters are
    frozen) or a bare ``nn.Parameter``/tensor, in which case its own
    ``requires_grad`` flag is cleared directly.
    """
    for entry in modules:
        try:
            params = entry.named_parameters()
        except AttributeError:
            # Not a module: the entry itself is a parameter.
            entry.requires_grad = False
        else:
            for _, p in params:
                p.requires_grad = False
class VDPM(nn.Module):
    """VGGT-style dense prediction model with a temporal decoder.

    Wires together an :class:`Aggregator` backbone, an optional
    :class:`Decoder` conditioned on view indices, a DPT point head
    producing 3D pointmaps with confidences, and a camera head producing
    pose encodings.
    """

    def __init__(self, cfg, img_size=518, patch_size=14, embed_dim=1024):
        """Build the backbone, decoder and prediction heads.

        Args:
            cfg: project config; ``cfg.model.decoder_depth`` sets the
                decoder depth.
            img_size: input image resolution fed to the aggregator.
            patch_size: ViT patch size.
            embed_dim: aggregator embedding dimension; heads consume the
                concatenated ``2 * embed_dim`` token features.
        """
        super().__init__()
        self.cfg = cfg
        self.aggregator = Aggregator(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
        )
        self.decoder = Decoder(
            cfg,
            dim_in=2 * embed_dim,
            embed_dim=embed_dim,
            depth=cfg.model.decoder_depth,
        )
        self.point_head = DPTHead(
            dim_in=2 * embed_dim,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1",
        )
        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        self.set_freeze()

    def set_freeze(self):
        """Freeze the pretrained patch embedding; everything else trains."""
        to_be_frozen = [self.aggregator.patch_embed]
        freeze_all_params(to_be_frozen)

    def _pad_decoded_tokens(self, decoded_tokens, num_layers):
        """Scatter decoder outputs to the point head's layer slots.

        The point head only reads ``intermediate_layer_idx`` positions, so
        all other entries remain ``None``.
        """
        padded = [None] * num_layers
        for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
            padded[layer_idx] = decoded_tokens[idx]
        return padded

    def forward(self, views, autocast_dpt=None):
        """Run a training forward pass.

        Args:
            views: list of per-view dicts with at least ``"img"`` and
                ``"view_idxs"`` tensors (stacked along a new view axis).
            autocast_dpt: optional autocast context for the point head;
                defaults to autocast disabled (full precision).

        Returns:
            Tuple ``(res_static, res_dynamic)`` of dicts: static pointmap
            (``pts3d``/``conf``) from aggregator tokens, and dynamic
            pointmap plus ``pose_enc_list`` from the decoder branch.
        """
        images = torch.stack([view["img"] for view in views], dim=1)
        aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        res_dynamic = dict()
        decoded_tokens = None
        if self.decoder is not None:
            # Second column of view_idxs conditions the decoder on the
            # target view — TODO confirm against the Decoder contract.
            cond_view_idxs = torch.stack(
                [view["view_idxs"][:, 1] for view in views], dim=1
            )
            decoded_tokens = self.decoder(
                images, aggregated_tokens_list, patch_start_idx, cond_view_idxs
            )

        if autocast_dpt is None:
            # Point head runs in full precision unless the caller opts in.
            autocast_dpt = torch.amp.autocast("cuda", enabled=False)
        with autocast_dpt:
            pts3d, pts3d_conf = self.point_head(
                aggregated_tokens_list, images, patch_start_idx
            )
            # BUGFIX: previously decoded_tokens was used unconditionally and
            # raised NameError when the decoder branch was skipped.
            if decoded_tokens is not None:
                padded_decoded_tokens = self._pad_decoded_tokens(
                    decoded_tokens, len(aggregated_tokens_list)
                )
                pts3d_dyn, pts3d_dyn_conf = self.point_head(
                    padded_decoded_tokens, images, patch_start_idx
                )
                res_dynamic |= {
                    "pts3d": pts3d_dyn,
                    "conf": pts3d_dyn_conf,
                }

        pose_enc_list = self.camera_head(aggregated_tokens_list)
        res_dynamic |= {"pose_enc_list": pose_enc_list}
        res_static = dict(
            pts3d=pts3d,
            conf=pts3d_conf,
        )
        return res_static, res_dynamic

    def inference(self, views, images=None):
        """Run bf16-autocast inference, decoding one pointmap per view.

        Args:
            views: list of per-view dicts with an ``"img"`` tensor; only
                used when *images* is not given.
            images: optional pre-stacked image tensor ``(B, S, ...)``.

        Returns:
            dict with ``pose_enc`` (last refinement iteration),
            ``pose_enc_list`` and ``pointmaps`` (one ``pts3d``/``conf``
            dict per view/time step).
        """
        autocast_amp = torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
        if images is None:
            images = torch.stack([view["img"] for view in views], dim=1)
        with autocast_amp:
            aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        S = images.shape[1]
        predictions = dict()
        pointmaps = []
        # Keep the conditioning indices on the input device; a CPU tensor
        # would mismatch GPU inputs inside the decoder.
        ones = torch.ones(1, S, dtype=torch.int64, device=images.device)
        for time_ in range(S):
            cond_view_idxs = ones * time_
            with autocast_amp:
                decoded_tokens = self.decoder(
                    images, aggregated_tokens_list, patch_start_idx, cond_view_idxs
                )
            padded_decoded_tokens = self._pad_decoded_tokens(
                decoded_tokens, len(aggregated_tokens_list)
            )
            pts3d, pts3d_conf = self.point_head(
                padded_decoded_tokens, images, patch_start_idx
            )
            pointmaps.append(dict(
                pts3d=pts3d,
                conf=pts3d_conf,
            ))
        pose_enc_list = self.camera_head(aggregated_tokens_list)
        predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration
        predictions["pose_enc_list"] = pose_enc_list
        predictions["pointmaps"] = pointmaps
        return predictions

    def load_state_dict(self, ckpt, is_VGGT_static=False, **kw):
        """Load weights, dropping VGGT heads this model does not use.

        Args:
            ckpt: state dict to load.
            is_VGGT_static: kept for caller compatibility; currently unused.
            **kw: forwarded to ``nn.Module.load_state_dict`` (e.g. ``strict``).
        """
        # depth_head / track_head come from the VGGT checkpoint but have no
        # counterpart here, so filter them out before loading.
        exclude = ["depth_head", "track_head"]
        ckpt = {k: v for k, v in ckpt.items() if k.split('.')[0] not in exclude}
        return super().load_state_dict(ckpt, **kw)
|