| """Wrapper for DUSt3R model to estimate focal length. |
| |
| DUSt3R: Geometric 3D Vision Made Easy, https://arxiv.org/abs/2312.14132 |
| """ |
|
|
| import sys |
|
|
| sys.path.append("third_party/dust3r") |
|
|
| import torch |
| from dust3r.cloud_opt import GlobalAlignerMode, global_aligner |
| from dust3r.image_pairs import make_pairs |
| from dust3r.inference import inference, load_model |
| from dust3r.utils.image import load_images |
|
|
| from siclib.geometry.base_camera import BaseCamera |
| from siclib.geometry.gravity import Gravity |
| from siclib.models import BaseModel |
|
|
| |
|
|
|
|
class Dust3R(BaseModel):
    """DUSt3R model for focal length estimation.

    The wrapper feeds the same image twice to DUSt3R as an image pair, runs the
    global alignment optimisation on the resulting prediction and reads the
    focal length back from the optimised scene.
    """

    default_conf = {
        # Checkpoint handed to ``load_model``.
        "model_path": "weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
        # Device used for loading the model, inference and global alignment.
        "device": "cuda",
        # Batch size forwarded to ``inference`` (not the size of ``data["path"]``).
        "batch_size": 1,
        # Arguments forwarded to ``scene.compute_global_alignment``.
        "schedule": "cosine",
        "lr": 0.01,
        "niter": 300,
        # When true, ``scene.show()`` is called after the alignment.
        "show_scene": False,
    }

    required_data_keys = ["path"]

    def _init(self, conf):
        """Load the DUSt3R weights from ``conf["model_path"]`` onto ``conf["device"]``."""
        self.model = load_model(conf["model_path"], conf["device"])

    def _forward(self, data):
        """Estimate the focal length of a single image with DUSt3R.

        Args:
            data: Dictionary whose ``"path"`` entry is a sequence holding exactly
                one image path.

        Returns:
            Dictionary with a ``"camera"`` entry (``BaseCamera`` built from the
            loaded image size and the mean focal length) and a ``"gravity"``
            entry (``Gravity`` built from zero-valued inputs).

        Raises:
            AssertionError: If ``data["path"]`` does not contain exactly one path.
        """
        num_paths = len(data["path"])
        assert num_paths == 1, f"Only batch size of 1 is supported (bs={num_paths})"

        device = self.conf["device"]

        # The same file is passed twice so that an image pair can be formed.
        image_paths = [data["path"][0]] * 2

        # Gradients are explicitly enabled here: the global alignment below is an
        # iterative optimisation (lr / niter / schedule), even if the caller
        # disabled autograd.
        with torch.enable_grad():
            images = load_images(image_paths, size=512)
            pairs = make_pairs(images, scene_graph="complete", prefilter=None, symmetrize=True)
            output = inference(pairs, self.model, device, batch_size=self.conf["batch_size"])
            scene = global_aligner(
                output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer
            )
            scene.compute_global_alignment(
                init="mst",
                niter=self.conf["niter"],
                schedule=self.conf["schedule"],
                lr=self.conf["lr"],
            )

        # Retrieve the focal lengths from the optimised scene and average them
        # over the first dimension (presumably one entry per image of the pair).
        focals = scene.get_focals().mean(dim=0)

        # NOTE(review): columns 0 and 1 of "true_shape" are read as height and
        # width of the image as loaded (size=512) - confirm against load_images.
        true_shape = images[0]["true_shape"]
        h = focals.new_tensor(true_shape[:, 0])
        w = focals.new_tensor(true_shape[:, 1])

        camera = BaseCamera.from_dict({"height": h, "width": w, "f": focals})
        # Gravity is not taken from the scene; a fixed zero-valued one is returned.
        gravity = Gravity.from_rp([0.0], [0.0])

        if self.conf["show_scene"]:
            scene.show()

        return {"camera": camera, "gravity": gravity}

    def loss(self, pred, data):
        """Loss function for DUSt3R model.

        Nothing is computed: two empty dictionaries are returned.
        """
        return {}, {}
|
|