File size: 1,639 Bytes
2c76547 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import ClassVar
import torch
from pytorch3d.implicitron.tools.config import Configurable
from dynamic_stereo.models.core.dynamic_stereo import DynamicStereo
class DynamicStereoModel(Configurable, torch.nn.Module):
    """Evaluation wrapper around a pretrained ``DynamicStereo`` network.

    ``__post_init__`` builds a ``DynamicStereo`` with a fixed architecture,
    loads weights from ``model_weights``, moves the network to ``device`` and
    switches it to eval mode. ``forward`` delegates to
    ``DynamicStereo.forward_batch_test``.

    Fields:
        model_weights: Path of the checkpoint file read with ``torch.load``.
        kernel_size: Forwarded as ``kernel_size`` to ``forward_batch_test``.
        device: Device the network is moved to after loading.
    """

    MODEL_CONFIG_NAME: ClassVar[str] = "DynamicStereoModel"

    # NOTE: an alternative checkpoint is "./checkpoints/dynamic_stereo_sf.pth".
    model_weights: str = "./checkpoints/dynamic_stereo_dr_sf.pth"
    kernel_size: int = 20
    # Previously hard-coded; default keeps the original behaviour.
    device: str = "cuda"

    def __post_init__(self):
        """Build the network, load its checkpoint and prepare it for inference.

        Raises:
            RuntimeError: if none of the model's parameters/buffers were found
                in the checkpoint (the network would otherwise silently run
                with its random initial weights).
        """
        # Initialise torch.nn.Module state; Configurable supplies the
        # dataclass-style __init__, so Module setup happens here.
        super().__init__()

        self.mixed_precision = False
        model = DynamicStereo(
            mixed_precision=self.mixed_precision,
            num_frames=5,
            attention_type="self_stereo_temporal_update_time_update_space",
            use_3d_update_block=True,
            different_update_blocks=True,
        )

        # NOTE: torch.load unpickles the file; only point model_weights at
        # trusted checkpoints.
        checkpoint = torch.load(self.model_weights, map_location="cpu")
        state_dict = self._extract_state_dict(checkpoint)

        # strict=False tolerates partially matching checkpoints, but a total
        # mismatch would leave the model at random init without any error.
        load_result = model.load_state_dict(state_dict, strict=False)
        model_keys = set(model.state_dict().keys())
        if model_keys and set(load_result.missing_keys) == model_keys:
            raise RuntimeError(
                f"No weights from {self.model_weights!r} matched the "
                "DynamicStereo model; check the checkpoint's key layout."
            )

        self.model = model
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the parameter mapping stored in a loaded checkpoint.

        Unwraps an optional ``"model"`` container and then an optional
        ``"state_dict"`` container, and strips a leading ``"module."`` from
        keys (the prefix a DataParallel/DDP-wrapped model writes when saved)
        so the keys line up with the bare ``DynamicStereo`` instance.
        """
        if "model" in checkpoint:
            checkpoint = checkpoint["model"]
        if "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        prefix = "module."
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint.items()
        }

    def forward(self, batch_dict, iters=20):
        """Run ``DynamicStereo.forward_batch_test`` on ``batch_dict``.

        Args:
            batch_dict: Passed through unchanged; its expected structure is
                defined by ``forward_batch_test`` (not visible here).
            iters: Forwarded as ``iters`` (presumably the number of update
                iterations — confirm against ``forward_batch_test``).

        Returns:
            Whatever ``forward_batch_test`` returns.
        """
        return self.model.forward_batch_test(
            batch_dict, kernel_size=self.kernel_size, iters=iters
        )
|