from __future__ import print_function, division
import spaces
import sys

# Make the project's core modules importable
sys.path.insert(0, 'core')
sys.path.append('core/utils')

import os
import argparse
import gradio as gr
import cv2
from core.raft_stereo_depthbeta_refine import RAFTStereoDepthBetaRefine
import torch
import torch.nn as nn
from core.utils.utils import InputPadder
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

parser = argparse.ArgumentParser()
parser.add_argument('--root', help="dataset root", default=None)
parser.add_argument('--sv_root', help="visualization root", default=None)
parser.add_argument('--test_exp_name', default='', help="name your experiment in testing")
parser.add_argument('--mast3r_model_path', default='MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth', help="pretrained model path for MASt3R")
parser.add_argument('--depthany_model_dir', default='./dav2_models', help="directory of the pretrained model for DepthAnything")
parser.add_argument('--restore_ckpt', help="restore checkpoint", default="./ckpts/diving_stereo.pth")
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass')
parser.add_argument('--eval', action='store_true', help='evaluation mode')
parser.add_argument('--is_test', action='store_true', help='testing mode')

# Architecture choices
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
parser.add_argument('--corr_levels', type=int, default=4, help="number of levels in the correlation pyramid")
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
parser.add_argument('--context_norm', type=str, default="batch", choices=['group', 'batch', 'instance', 'none'], help="normalization of the context encoder")
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
parser.add_argument('--lbp_neighbor_offsets', default='(-5,-5), (5,5), (5,-5), (-5,5), (-3,0), (3,0), (0,-3), (0,3)', help="neighbor offsets used in the LBP encoder")
parser.add_argument('--modulation_ratio', type=float, default=1., help="hyperparameter for modulation")
parser.add_argument('--modulation_alg', choices=["linear", "sigmoid"], default="linear", help="rescaling of the modulation")
parser.add_argument('--conf_from_fea', action='store_true', help="derive refinement confidence not only from the cost volume but also from other features")
parser.add_argument('--refine_pool', action='store_true', help="use pooling in refinement")
parser.add_argument('--refine_unet', action='store_true', help="use EfficientUnet in refinement")
parser.add_argument('--improvement', action='store_true', help="visualize the improvement map (error_map[i] - error_map[i-1])")
parser.add_argument('--movement', action='store_true', help="visualize the movement map (flow_pr[i] - flow_pr[i-1])")
parser.add_argument('--acceleration', action='store_true', help="visualize the acceleration map (movement_map[i] - movement_map[i-1])")
parser.add_argument('--mask', action='store_true', help="visualize mask")
parser.add_argument('--binary_thold', type=float, default=0.5, help="threshold for the binary mask visualization")
args = parser.parse_args()
args.conf_from_fea = True
args.eval = True

# Build the model and load the pretrained checkpoint from the Hugging Face Hub
model = RAFTStereoDepthBetaRefine(args)
model = torch.nn.DataParallel(model, device_ids=[0])

checkpoint_path = hf_hub_download(
    repo_id="BFZD/Diving-into-the-Fusion-of-Monocular-Priors-for-Generalized-Stereo-Matching",
    filename="ckpts/diving_stereo.pth",
)
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
# model.load_state_dict(checkpoint, strict=True)

# Drop the LBP-encoder neighbor-convolution weights and load the rest with strict=False,
# so those layers keep their freshly initialized parameters.
new_state_dict = {}
for key, value in checkpoint.items():
    if key.find("lbp_encoder.lbp_conv") != -1:
        continue
    new_state_dict[key] = value
# model.load_state_dict(new_state_dict, strict=True)
model.load_state_dict(new_state_dict, strict=False)

model.cuda()
model.eval()


@spaces.GPU
def predict(image1, image2):
    """Run stereo matching on an RGB image pair and return a colormapped disparity image."""
    with torch.no_grad():
        # HWC numpy arrays -> CHW float tensors, keep the first 3 channels, add a batch dim
        image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
        image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
        image1 = image1[None][:, :3, :, :].cuda()
        image2 = image2[None][:, :3, :, :].cuda()

        # Pad both images so their height and width are divisible by 32
        padder = InputPadder(image1.shape, divis_by=32)
        image1, image2 = padder.pad(image1, image2)

        _, disp = model(image1, image2, iters=args.valid_iters, test_mode=True, vis_mode=True)
        output = disp.abs().cpu().numpy()
        disp = padder.unpad(output)
        disp = disp.squeeze()

        # Normalize the disparity to [0, 1] and map it through the 'jet' colormap
        normalized_disp = (disp - disp.min()) / (disp.max() - disp.min())
        cmap = plt.get_cmap('jet')
        colored_disp = cmap(normalized_disp)[:, :, :3]  # keep the RGB channels, drop alpha
        return colored_disp
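
# Optional helper (not part of the original demo): a minimal sketch of running the
# model on a stereo pair loaded from disk, bypassing the Gradio UI. The function
# name and file paths are illustrative. predict() expects HxWx3 RGB arrays (as
# Gradio passes for gr.Image inputs) and returns an HxWx3 float RGB image in [0, 1].
def predict_from_files(left_path, right_path):
    left = cv2.cvtColor(cv2.imread(left_path), cv2.COLOR_BGR2RGB)    # OpenCV loads BGR
    right = cv2.cvtColor(cv2.imread(right_path), cv2.COLOR_BGR2RGB)
    colored_disp = predict(left, right)
    plt.imsave('disparity.png', colored_disp)  # save the colormapped disparity
    return colored_disp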

with gr.Blocks() as demo:
    gr.HTML('''
        <h1 style="text-align: center;">[ICCV25 Oral] Diving into the Fusion of Monocular Priors for Generalized Stereo Matching</h1>
    ''')
    with gr.Row():
        left_img = gr.Image(label="Left Image")
        right_img = gr.Image(label="Right Image")
        output_img = gr.Image(label="Disparity Map")
    btn = gr.Button("Submit")
    btn.click(
        fn=predict,
        inputs=[left_img, right_img],
        outputs=output_img
    )

demo.launch()
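
# Usage note: assuming this script is saved as app.py and run locally (rather than
# on Hugging Face Spaces), it can be started with
#     python app.py
# Gradio then serves the demo on its default local port; passing share=True to
# demo.launch() would additionally create a temporary public link.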