import spaces
import os
from os.path import basename, splitext, join
import tempfile
import gradio as gr
import numpy as np
from PIL import Image
import torch
import cv2
from torchvision.transforms.functional import to_tensor, to_pil_image
from torch import Tensor
from genstereo import GenStereo, AdaptiveFusionLayer
import ssl
from huggingface_hub import hf_hub_download
from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2

# NOTE(review): this replaces the stdlib's default HTTPS context factory with
# one that skips certificate verification, for the whole process. Presumably a
# workaround for some download path; confirm it is still needed, since it
# silently weakens every stdlib HTTPS client that relies on the default context.
ssl._create_default_https_context = ssl._create_unverified_context

# Working resolution and checkpoint. cb_update_sd_version rebinds both at runtime.
IMAGE_SIZE = 768
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'
CHECKPOINT_NAME = 'genstereo-v2.1'


def download_models():
    """Fetch every checkpoint file the demo loads from the Hugging Face Hub.

    Each manifest entry is ``(repo_id, subfolder, local_dir, filenames)``.
    Files are requested one at a time, in manifest order, through
    ``hf_hub_download`` without an auth token.
    """
    genstereo_files = (
        'config.json',
        'denoising_unet.pth',
        'fusion_layer.pth',
        'pose_guider.pth',
        'reference_unet.pth',
    )
    manifest = [
        ('stabilityai/sd-vae-ft-mse', None, 'checkpoints/sd-vae-ft-mse',
         ('config.json', 'diffusion_pytorch_model.safetensors')),
        ('lambdalabs/sd-image-variations-diffusers', 'image_encoder', 'checkpoints',
         ('config.json', 'pytorch_model.bin')),
        ('FQiao/GenStereo', None, 'checkpoints/genstereo-v1.5', genstereo_files),
        ('FQiao/GenStereo-sd2.1', None, 'checkpoints/genstereo-v2.1', genstereo_files),
        ('depth-anything/Depth-Anything-V2-Large', None, 'checkpoints',
         ('depth_anything_v2_vitl.pth',)),
    ]
    # Flatten to one (repo, subfolder, destination, file) job per download.
    jobs = (
        (repo_id, subfolder, local_dir, filename)
        for repo_id, subfolder, local_dir, filenames in manifest
        for filename in filenames
    )
    for repo_id, subfolder, local_dir, filename in jobs:
        hf_hub_download(
            repo_id=repo_id,
            subfolder=subfolder,
            filename=filename,
            local_dir=local_dir,
            token=None,
        )


# Setup.
download_models() # DepthAnythingV2 def get_dam2_model(): model_configs = { 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, } encoder = 'vitl' encoder_size_map = {'vits': 'Small', 'vitb': 'Base', 'vitl': 'Large'} if encoder not in encoder_size_map: raise ValueError(f"Unsupported encoder: {encoder}. Supported: {list(encoder_size_map.keys())}") dam2 = DepthAnythingV2(**model_configs[encoder]) dam2_checkpoint = f'checkpoints/depth_anything_v2_{encoder}.pth' dam2.load_state_dict(torch.load(dam2_checkpoint, map_location='cpu')) dam2 = dam2.to(DEVICE).eval() return dam2 # GenStereo def get_genstereo_model(sd_version): genstereo_cfg = dict( pretrained_model_path='checkpoints', checkpoint_name=CHECKPOINT_NAME, half_precision_weights=True ) genstereo = GenStereo(cfg=genstereo_cfg, device=DEVICE, sd_version=sd_version) return genstereo # Adaptive Fusion def get_fusion_model(): fusion_model = AdaptiveFusionLayer() fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth') fusion_model.load_state_dict(torch.load(fusion_checkpoint, map_location='cpu')) fusion_model = fusion_model.to(DEVICE).eval() return fusion_model # Crop the image to the shorter side. def crop(img: Image) -> Image: W, H = img.size if W < H: left, right = 0, W top, bottom = np.ceil((H - W) / 2.), np.floor((H - W) / 2.) + W else: left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H top, bottom = 0, H return img.crop((left, top, right, bottom)) def normalize_disp(disp): return (disp - disp.min()) / (disp.max() - disp.min()) # Gradio app with tempfile.TemporaryDirectory() as tmpdir: with gr.Blocks( title='StereoGen Demo', css='img {display: inline;}' ) as demo: # Internal states. 
src_image = gr.State() src_depth = gr.State() # Callbacks def cb_update_sd_version(sd_version_choice): global IMAGE_SIZE, CHECKPOINT_NAME if sd_version_choice == "v1.5": IMAGE_SIZE = 512 CHECKPOINT_NAME = 'genstereo-v1.5' print(f"Switched to GenStereo {sd_version_choice}. IMAGE_SIZE: {IMAGE_SIZE}, CHECKPOINT: {CHECKPOINT_NAME}") elif sd_version_choice == "v2.1": IMAGE_SIZE = 768 CHECKPOINT_NAME = 'genstereo-v2.1' print(f"Switched to GenStereo {sd_version_choice}. IMAGE_SIZE: {IMAGE_SIZE}, CHECKPOINT: {CHECKPOINT_NAME}") return None, None, None, None, None, None, None @spaces.GPU() def cb_mde(image_file: str, sd_version): if not image_file: return None, None, None, None image = crop(Image.open(image_file).convert('RGB')) if sd_version == "v1.5": image = image.resize((IMAGE_SIZE, IMAGE_SIZE)) elif sd_version == "v2.1": image = image.resize((IMAGE_SIZE, IMAGE_SIZE)) else: gr.Warning(f"Unknown SD version: {sd_version}. Defaulting to {IMAGE_SIZE}.") image = image.resize((IMAGE_SIZE, IMAGE_SIZE)) gr.Info(f"Generating with GenStereo {sd_version} at {IMAGE_SIZE}px resolution.") image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) dam2 = get_dam2_model() depth_dam2 = dam2.infer_image(image_bgr) depth_image = cv2.applyColorMap((normalize_disp(depth_dam2) * 255).astype(np.uint8), cv2.COLORMAP_JET) return image, depth_image, image, depth_dam2 @spaces.GPU() def cb_generate(image, depth, scale_factor, sd_version): depth_tensor = torch.tensor(depth).unsqueeze(0).unsqueeze(0).float() norm_disp = normalize_disp(depth_tensor.cuda()) disp = norm_disp * scale_factor / 100 * IMAGE_SIZE genstereo = get_genstereo_model(sd_version) fusion_model = get_fusion_model() renders = genstereo( src_image=image, src_disparity=disp, ratio=None, ) warped = (renders['warped'] + 1) / 2 synthesized = renders['synthesized'] mask = renders['mask'] fusion_image = fusion_model(synthesized.float(), warped.float(), mask.float()) warped_pil = to_pil_image(warped[0]) fusion_pil = 
to_pil_image(fusion_image[0]) # Create full SBS for Quest 2 left_resized = image.resize((1832, 1920)) right_resized = fusion_pil.resize((1832, 1920)) sbs = Image.new('RGB', (3664, 1920)) sbs.paste(left_resized, (0, 0)) sbs.paste(right_resized, (1832, 0)) return warped_pil, fusion_pil, sbs # Blocks. gr.Markdown( """ # [ICCV 2025] Towards Open-World Generation of Stereo Images and Unsupervised Matching [Project Site](https://qjizhi.github.io/genstereo) | [Spaces](https://huggingface.co/spaces/FQiao/GenStereo) | [Github](https://github.com/Qjizhi/GenStereo) | [Models](https://huggingface.co/FQiao/GenStereo-sd2.1/tree/main) | [arXiv](https://arxiv.org/abs/2503.12720) ## Introduction This is an official demo for the paper "[Towards Open-World Generation of Stereo Images and Unsupervised Matching](https://qjizhi.github.io/genstereo)". Given an arbitrary reference image, GenStereo can generate the corresponding right-view image. ## How to Use 1. Select the GenStereo version - v1.5: 512px, faster. - v2.1: 768px, better performance, high resolution, takes more time. 2. Upload a reference image to "Left Image" - You can also select an image from "Examples" 3. Hit "Generate a right image" button and check the result. """ ) sd_version_radio = gr.Radio( label="GenStereo Version", choices=["v1.5", "v2.1"], value="v2.1", ) with gr.Row(): file = gr.File(label='Left', file_types=['image']) examples = gr.Examples( examples=['./assets/COCO_val2017_000000070229.jpg', './assets/COCO_val2017_000000092839.jpg', './assets/KITTI2015_000003_10.png', './assets/KITTI2015_000147_10.png'], inputs=file )