|
|
""" |
|
|
Util functions to run inference with MoGe |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import warnings |
|
|
from pathlib import Path |
|
|
|
|
|
warnings.filterwarnings("ignore", category=FutureWarning) |
|
|
|
|
|
import numpy as np |
|
|
import rerun as rr |
|
|
import torch |
|
|
import torchvision |
|
|
import torchvision.transforms as tvf |
|
|
from PIL import Image |
|
|
|
|
|
from mapanything.utils.viz import log_data_to_rerun, script_add_rerun_args |
|
|
|
|
|
|
|
|
def load_moge_model(
    model_code_path: str = "MoGe",
    ckpt_path: str = "Ruicheng/moge-vitl",
    device: str = "cuda",
):
    """
    Load the MoGe (ViT-L) model from the HuggingFace hub (or from a local checkpoint).

    Args:
        model_code_path: Path to a local checkout of the MoGe repository; its
            `moge` package is imported from here (it is not assumed to be pip-installed).
        ckpt_path: HuggingFace model id or local checkpoint path passed to
            `MoGeModel.from_pretrained`.
        device: Torch device to place the model on.

    Returns:
        The MoGe model, moved to `device` and set to eval mode.

    Raises:
        FileNotFoundError: If `model_code_path` does not exist.
    """
    if not Path(model_code_path).exists():
        raise FileNotFoundError(f"MoGe code not found at {model_code_path}")

    import sys

    # Make the local MoGe checkout importable. Guard against appending the
    # same entry repeatedly when this function is called more than once.
    code_path = str(model_code_path)
    if code_path not in sys.path:
        sys.path.append(code_path)

    # Imported lazily: only resolvable once the checkout is on sys.path.
    from moge.model.v1 import MoGeModel

    model = MoGeModel.from_pretrained(ckpt_path).to(device).eval()

    return model
|
|
|
|
|
|
|
|
@torch.no_grad()
def run_moge_inference(model: torch.nn.Module, image: torch.Tensor, device="cuda"):
    """
    Run MoGe inference on a batch of images or a single image.

    Args:
        model: MoGe model (must expose an `.infer(image)` method).
        image: (B, 3, H, W) or (3, H, W) RGB image tensor in range [0, 1].
        device: Device the input image is moved to before inference.

    Returns:
        Dictionary produced by `model.infer` with the following keys:
            - "points": (B, H, W, 3) scale-invariant point map in the OpenCV
              camera coordinate system (x right, y down, z forward)
            - "depth": (B, H, W) scale-invariant depth map
            - "mask": (B, H, W) binary mask of valid pixels
            - "intrinsics": (B, 3, 3) normalized camera intrinsics
    """
    # No autograd graph is built here (torch.no_grad decorator): inference only.
    image = image.to(device)
    return model.infer(image)
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
parser = argparse.ArgumentParser() |
|
|
parser.add_argument("-ip", "--image_path", default='/ocean/projects/cis220039p/mdt2/jkarhade/Any4D/benchmarking/monst3r/demo_data/lady-running/00000.jpg', type=str) |
|
|
parser.add_argument("--viz", action="store_true") |
|
|
script_add_rerun_args(parser) |
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
if args.viz: |
|
|
rr.script_setup(args, f"MoGe_Pred_Viz") |
|
|
rr.set_time_seconds("stable_time", 0) |
|
|
rr.log("moge", rr.ViewCoordinates.RDF, static=True) |
|
|
|
|
|
|
|
|
img = np.array(Image.open(args.image_path)) |
|
|
transform = tvf.Compose([tvf.ToTensor()]) |
|
|
input_img = transform(img).unsqueeze(0) |
|
|
|
|
|
|
|
|
model = load_moge_model() |
|
|
|
|
|
|
|
|
output = run_moge_inference(model, input_img) |
|
|
|
|
|
|
|
|
pts3d = output["points"].cpu().squeeze(0).numpy() |
|
|
depth = output["depth"].cpu().squeeze(0).numpy() |
|
|
mask = output["mask"].cpu().squeeze(0).numpy() |
|
|
intrinsics = output["intrinsics"].cpu().squeeze(0).numpy() |
|
|
intrinsics[0, :] = intrinsics[0, :] * depth.shape[1] |
|
|
intrinsics[1, :] = intrinsics[1, :] * depth.shape[0] |
|
|
|
|
|
|
|
|
if args.viz: |
|
|
base_name = "moge" |
|
|
log_data_to_rerun( |
|
|
image=img, depthmap=depth, pose=np.eye(4), intrinsics=intrinsics, base_name=base_name, mask=np.float32(mask) |
|
|
) |
|
|
|
|
|
filtered_pts = pts3d[mask] |
|
|
filtered_pts_col = img[mask] |
|
|
pts_name = f"{base_name}/points" |
|
|
rr.log( |
|
|
pts_name, |
|
|
rr.Points3D( |
|
|
positions=filtered_pts.reshape(-1, 3), |
|
|
colors=filtered_pts_col.reshape(-1, 3), |
|
|
), |
|
|
) |
|
|
|