|
|
from __future__ import print_function, division |
|
|
import spaces |
|
|
import sys |
|
|
sys.path.insert(0,'core') |
|
|
sys.path.append('core/utils') |
|
|
|
|
|
import os |
|
|
import argparse |
|
|
import gradio as gr |
|
|
import cv2 |
|
|
from core.raft_stereo_depthbeta_refine import RAFTStereoDepthBetaRefine |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from core.utils.utils import InputPadder |
|
|
import matplotlib.pyplot as plt |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
def _build_arg_parser():
    """Assemble the command-line interface used by this demo script.

    Options are registered in the same order as in the original flat
    listing, so the ``--help`` output keeps its ordering.
    """
    cli = argparse.ArgumentParser()

    def flag(name, text):
        # Boolean switch: absent -> False, present -> True.
        cli.add_argument(name, action='store_true', help=text)

    # Data locations, checkpoints and run-mode switches.
    cli.add_argument('--root', default=None, help="dataset root")
    cli.add_argument('--sv_root', default=None, help="visualization root")
    cli.add_argument(
        '--test_exp_name',
        default='',
        help="name your experiment in testing",
    )
    cli.add_argument(
        '--mast3r_model_path',
        default='MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth',
        help="pretrained model path for MaSt3R",
    )
    cli.add_argument(
        '--depthany_model_dir',
        default='./dav2_models',
        help="directory of pretrained model path for DepthAnything",
    )
    cli.add_argument(
        '--restore_ckpt',
        default="./ckpts/diving_stereo.pth",
        help="restore checkpoint",
    )
    flag('--mixed_precision', 'use mixed precision')
    cli.add_argument(
        '--valid_iters',
        type=int,
        default=32,
        help='number of flow-field updates during forward pass',
    )
    flag('--eval', 'evaluation mode')
    flag('--is_test', 'on testing')

    # Network configuration (encoders, correlation volume, GRU stack).
    cli.add_argument(
        '--hidden_dims',
        nargs='+',
        type=int,
        default=[128] * 3,
        help="hidden state and context dimensions",
    )
    cli.add_argument(
        '--corr_implementation',
        choices=["reg", "alt", "reg_cuda", "alt_cuda"],
        default="reg",
        help="correlation volume implementation",
    )
    flag(
        '--shared_backbone',
        "use a single backbone for the context and feature encoders",
    )
    cli.add_argument(
        '--corr_levels',
        type=int,
        default=4,
        help="number of levels in the correlation pyramid",
    )
    cli.add_argument(
        '--corr_radius',
        type=int,
        default=4,
        help="width of the correlation pyramid",
    )
    cli.add_argument(
        '--n_downsample',
        type=int,
        default=2,
        help="resolution of the disparity field (1/2^K)",
    )
    cli.add_argument(
        '--context_norm',
        type=str,
        default="batch",
        choices=['group', 'batch', 'instance', 'none'],
        help="normalization of context encoder",
    )
    flag('--slow_fast_gru', "iterate the low-res GRUs more frequently")
    cli.add_argument(
        '--n_gru_layers',
        type=int,
        default=3,
        help="number of hidden GRU levels",
    )

    # LBP encoder, modulation and refinement options.
    cli.add_argument(
        '--lbp_neighbor_offsets',
        default='(-5,-5), (5,5), (5,-5), (-5,5), (-3,0), (3,0), (0,-3), (0,3)',
        help="determine the neighbors used in LBP encoder",
    )
    cli.add_argument(
        '--modulation_ratio',
        type=float,
        default=1.,
        help="hyperparameters for modulation",
    )
    cli.add_argument(
        '--modulation_alg',
        choices=["linear", "sigmoid"],
        default="linear",
        help="rescale modulation",
    )
    flag(
        '--conf_from_fea',
        "confidence in refinement not only from cost volume but also from other features",
    )
    flag('--refine_pool', "use pooling in refinement")
    flag('--refine_unet', "use EfficientUnet in refinement")

    # Visualization toggles.
    flag(
        '--improvement',
        "visualize improvement map (error_map[i] - error_map[i-1])",
    )
    flag('--movement', "visualize movement map (flow_pr[i] - flow_pr[i-1])")
    flag(
        '--acceleration',
        "visualize acceleration map (movement_map[i] - movement_map[i-1])",
    )
    flag('--mask', "visualize mask")
    cli.add_argument(
        '--binary_thold',
        type=float,
        default=0.5,
        help="visualize binary mask",
    )

    return cli


parser = _build_arg_parser()
args = parser.parse_args()

# These two switches are forced on after parsing, so whatever was passed for
# --conf_from_fea and --eval on the command line has no effect.
args.conf_from_fea = True
args.eval = True
|
|
|
|
|
# Build the network from the parsed options and wrap it in DataParallel
# restricted to device id 0.
model = torch.nn.DataParallel(RAFTStereoDepthBetaRefine(args), device_ids=[0])

# Fetch the released weights from the Hugging Face Hub.
# NOTE: args.restore_ckpt is not consulted here; the hub file is always used.
checkpoint_path = hf_hub_download(
    filename="ckpts/diving_stereo.pth",
    repo_id="BFZD/Diving-into-the-Fusion-of-Monocular-Priors-for-Generalized-Stereo-Matching",
)

# Deserialize onto the CPU first; the model is moved to the GPU afterwards.
# NOTE(review): torch.load unpickles the file, so the hub repository must be
# trusted.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# Keep every entry except those whose key mentions "lbp_encoder.lbp_conv";
# strict=False lets the model load without the skipped entries.
new_state_dict = {
    name: tensor
    for name, tensor in checkpoint.items()
    if "lbp_encoder.lbp_conv" not in name
}
model.load_state_dict(new_state_dict, strict=False)

model.cuda()
model.eval()
|
|
|
|
|
|
|
|
@spaces.GPU
def predict(image1, image2):
    """Predict disparity for a stereo pair and render it with the jet colormap.

    Args:
        image1: Left view as delivered by the ``gr.Image`` input. It is passed
            to ``torch.from_numpy`` and indexed as (height, width, channels);
            only the first three channels are used.
        image2: Right view, same layout as ``image1``.

    Returns:
        Array indexed as (rows, cols, 3): the RGB channels of the
        jet-colormapped, min-max normalized disparity magnitude. The spatial
        size is expected to match the inputs after unpadding.

    Raises:
        gr.Error: If either image is missing, or the two images differ in
            height or width.
    """
    # Fail early with a message the UI can display instead of an opaque
    # TypeError / shape error from deep inside torch or the model.
    if image1 is None or image2 is None:
        raise gr.Error("Please provide both a left and a right image.")
    if image1.shape[:2] != image2.shape[:2]:
        raise gr.Error(
            "Left and right images must have the same height and width."
        )

    with torch.no_grad():
        # HWC array -> 1xCxHxW float tensor on the GPU, dropping any channel
        # beyond the first three (e.g. alpha).
        left = torch.from_numpy(image1).permute(2, 0, 1).float()
        right = torch.from_numpy(image2).permute(2, 0, 1).float()
        left = left[None][:, :3, :, :].cuda()
        right = right[None][:, :3, :, :].cuda()

        # Pad both views with InputPadder (divis_by=32) before the forward
        # pass; the padding is removed from the prediction below.
        padder = InputPadder(left.shape, divis_by=32)
        left, right = padder.pad(left, right)

        # Only the second returned value is used as the disparity prediction.
        _, disp = model(
            left,
            right,
            iters=args.valid_iters,
            test_mode=True,
            vis_mode=True,
        )

        output = disp.abs().cpu().numpy()
        disp = padder.unpad(output).squeeze()

        # Min-max normalize to [0, 1]. A constant disparity map has zero
        # range; dividing would yield NaN, so it is left as all zeros.
        lowest = disp.min()
        span = disp.max() - lowest
        normalized_disp = disp - lowest
        if span > 0:
            normalized_disp = normalized_disp / span

        # The colormap returns RGBA; keep RGB only.
        cmap = plt.get_cmap('jet')
        colored_disp = cmap(normalized_disp)[:, :, :3]

    return colored_disp
|
|
# Banner rendered at the top of the page: paper title plus a GitHub stars
# badge linking to the repository.
_HEADER_HTML = '''
    <div align="center">
    <h1> [ICCV25 Oral] Diving into the Fusion of Monocular Priors for Generalized Stereo Matching
    <a title="Github" href="https://github.com/YaoChengTang/Diving-into-the-Fusion-of-Monocular-Priors-for-Generalized-Stereo-Matching" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> <img src="https://img.shields.io/github/stars/YaoChengTang/Diving-into-the-Fusion-of-Monocular-Priors-for-Generalized-Stereo-Matching?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="GitHub Stars"> </a></h1>
    </div>
    '''

with gr.Blocks() as demo:
    gr.HTML(_HEADER_HTML)

    # One row holding the two input views and the rendered result.
    with gr.Row():
        left_img = gr.Image(label="Left Image")
        right_img = gr.Image(label="Right Image")
        output_img = gr.Image(label="Disparity Map")

    # Clicking the button runs predict on (left, right) and shows its return
    # value in the output image.
    btn = gr.Button("Submit")
    btn.click(fn=predict, inputs=[left_img, right_img], outputs=output_img)

demo.launch()
|
|
|