Spaces:
Runtime error
Runtime error
| import spaces | |
| import os | |
| from os.path import basename, splitext, join | |
| import tempfile | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import cv2 | |
| from torchvision.transforms.functional import to_tensor, to_pil_image | |
| from torch import Tensor | |
| from genstereo import GenStereo, AdaptiveFusionLayer | |
| import ssl | |
| from huggingface_hub import hf_hub_download | |
| from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2 | |
| ssl._create_default_https_context = ssl._create_unverified_context | |
| IMAGE_SIZE = 768 | |
| DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' | |
| CHECKPOINT_NAME = 'genstereo-v2.1' | |
| def download_models(): | |
| models = [ | |
| { | |
| 'repo': 'stabilityai/sd-vae-ft-mse', | |
| 'sub': None, | |
| 'dst': 'checkpoints/sd-vae-ft-mse', | |
| 'files': ['config.json', 'diffusion_pytorch_model.safetensors'], | |
| 'token': None | |
| }, | |
| { | |
| 'repo': 'lambdalabs/sd-image-variations-diffusers', | |
| 'sub': 'image_encoder', | |
| 'dst': 'checkpoints', | |
| 'files': ['config.json', 'pytorch_model.bin'], | |
| 'token': None | |
| }, | |
| { | |
| 'repo': 'FQiao/GenStereo', | |
| 'sub': None, | |
| 'dst': 'checkpoints/genstereo-v1.5', | |
| 'files': ['config.json', 'denoising_unet.pth', 'fusion_layer.pth', 'pose_guider.pth', 'reference_unet.pth'], | |
| 'token': None | |
| }, | |
| { | |
| 'repo': 'FQiao/GenStereo-sd2.1', | |
| 'sub': None, | |
| 'dst': 'checkpoints/genstereo-v2.1', | |
| 'files': ['config.json', 'denoising_unet.pth', 'fusion_layer.pth', 'pose_guider.pth', 'reference_unet.pth'], | |
| 'token': None | |
| }, | |
| { | |
| 'repo': 'depth-anything/Depth-Anything-V2-Large', | |
| 'sub': None, | |
| 'dst': 'checkpoints', | |
| 'files': [f'depth_anything_v2_vitl.pth'], | |
| 'token': None | |
| } | |
| ] | |
| for model in models: | |
| for file in model['files']: | |
| hf_hub_download( | |
| repo_id=model['repo'], | |
| subfolder=model['sub'], | |
| filename=file, | |
| local_dir=model['dst'], | |
| token=model['token'] | |
| ) | |
| # Setup. | |
| download_models() | |
| # DepthAnythingV2 | |
| def get_dam2_model(): | |
| model_configs = { | |
| 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, | |
| 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, | |
| 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, | |
| } | |
| encoder = 'vitl' | |
| encoder_size_map = {'vits': 'Small', 'vitb': 'Base', 'vitl': 'Large'} | |
| if encoder not in encoder_size_map: | |
| raise ValueError(f"Unsupported encoder: {encoder}. Supported: {list(encoder_size_map.keys())}") | |
| dam2 = DepthAnythingV2(**model_configs[encoder]) | |
| dam2_checkpoint = f'checkpoints/depth_anything_v2_{encoder}.pth' | |
| dam2.load_state_dict(torch.load(dam2_checkpoint, map_location='cpu')) | |
| dam2 = dam2.to(DEVICE).eval() | |
| return dam2 | |
| # GenStereo | |
| def get_genstereo_model(sd_version): | |
| genstereo_cfg = dict( | |
| pretrained_model_path='checkpoints', | |
| checkpoint_name=CHECKPOINT_NAME, | |
| half_precision_weights=True | |
| ) | |
| genstereo = GenStereo(cfg=genstereo_cfg, device=DEVICE, sd_version=sd_version) | |
| return genstereo | |
| # Adaptive Fusion | |
| def get_fusion_model(): | |
| fusion_model = AdaptiveFusionLayer() | |
| fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth') | |
| fusion_model.load_state_dict(torch.load(fusion_checkpoint, map_location='cpu')) | |
| fusion_model = fusion_model.to(DEVICE).eval() | |
| return fusion_model | |
| # Crop the image to the shorter side. | |
| def crop(img: Image) -> Image: | |
| W, H = img.size | |
| if W < H: | |
| left, right = 0, W | |
| top, bottom = np.ceil((H - W) / 2.), np.floor((H - W) / 2.) + W | |
| else: | |
| left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H | |
| top, bottom = 0, H | |
| return img.crop((left, top, right, bottom)) | |
| def normalize_disp(disp): | |
| return (disp - disp.min()) / (disp.max() - disp.min()) | |
| # Gradio app | |
| with tempfile.TemporaryDirectory() as tmpdir: | |
| with gr.Blocks( | |
| title='StereoGen Demo', | |
| css='img {display: inline;}' | |
| ) as demo: | |
| # Internal states. | |
| src_image = gr.State() | |
| src_depth = gr.State() | |
| # Callbacks | |
| def cb_update_sd_version(sd_version_choice): | |
| global IMAGE_SIZE, CHECKPOINT_NAME | |
| if sd_version_choice == "v1.5": | |
| IMAGE_SIZE = 512 | |
| CHECKPOINT_NAME = 'genstereo-v1.5' | |
| print(f"Switched to GenStereo {sd_version_choice}. IMAGE_SIZE: {IMAGE_SIZE}, CHECKPOINT: {CHECKPOINT_NAME}") | |
| elif sd_version_choice == "v2.1": | |
| IMAGE_SIZE = 768 | |
| CHECKPOINT_NAME = 'genstereo-v2.1' | |
| print(f"Switched to GenStereo {sd_version_choice}. IMAGE_SIZE: {IMAGE_SIZE}, CHECKPOINT: {CHECKPOINT_NAME}") | |
| return None, None, None, None, None, None, None | |
| def cb_mde(image_file: str, sd_version): | |
| if not image_file: | |
| return None, None, None, None | |
| image = crop(Image.open(image_file).convert('RGB')) | |
| if sd_version == "v1.5": | |
| image = image.resize((IMAGE_SIZE, IMAGE_SIZE)) | |
| elif sd_version == "v2.1": | |
| image = image.resize((IMAGE_SIZE, IMAGE_SIZE)) | |
| else: | |
| gr.Warning(f"Unknown SD version: {sd_version}. Defaulting to {IMAGE_SIZE}.") | |
| image = image.resize((IMAGE_SIZE, IMAGE_SIZE)) | |
| gr.Info(f"Generating with GenStereo {sd_version} at {IMAGE_SIZE}px resolution.") | |
| image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) | |
| dam2 = get_dam2_model() | |
| depth_dam2 = dam2.infer_image(image_bgr) | |
| depth_image = cv2.applyColorMap((normalize_disp(depth_dam2) * 255).astype(np.uint8), cv2.COLORMAP_JET) | |
| return image, depth_image, image, depth_dam2 | |
| def cb_generate(image, depth, scale_factor, sd_version): | |
| depth_tensor = torch.tensor(depth).unsqueeze(0).unsqueeze(0).float() | |
| norm_disp = normalize_disp(depth_tensor.cuda()) | |
| disp = norm_disp * scale_factor / 100 * IMAGE_SIZE | |
| genstereo = get_genstereo_model(sd_version) | |
| fusion_model = get_fusion_model() | |
| renders = genstereo( | |
| src_image=image, | |
| src_disparity=disp, | |
| ratio=None, | |
| ) | |
| warped = (renders['warped'] + 1) / 2 | |
| synthesized = renders['synthesized'] | |
| mask = renders['mask'] | |
| fusion_image = fusion_model(synthesized.float(), warped.float(), mask.float()) | |
| warped_pil = to_pil_image(warped[0]) | |
| fusion_pil = to_pil_image(fusion_image[0]) | |
| # Create full SBS for Quest 2 | |
| left_resized = image.resize((1832, 1920)) | |
| right_resized = fusion_pil.resize((1832, 1920)) | |
| sbs = Image.new('RGB', (3664, 1920)) | |
| sbs.paste(left_resized, (0, 0)) | |
| sbs.paste(right_resized, (1832, 0)) | |
| return warped_pil, fusion_pil, sbs | |
| # Blocks. | |
| gr.Markdown( | |
| """ | |
| # [ICCV 2025] Towards Open-World Generation of Stereo Images and Unsupervised Matching | |
| [Project Site](https://qjizhi.github.io/genstereo) | [Spaces](https://huggingface.co/spaces/FQiao/GenStereo) | [Github](https://github.com/Qjizhi/GenStereo) | [Models](https://huggingface.co/FQiao/GenStereo-sd2.1/tree/main) | [arXiv](https://arxiv.org/abs/2503.12720) | |
| ## Introduction | |
| This is an official demo for the paper "[Towards Open-World Generation of Stereo Images and Unsupervised Matching](https://qjizhi.github.io/genstereo)". Given an arbitrary reference image, GenStereo can generate the corresponding right-view image. | |
| ## How to Use | |
| 1. Select the GenStereo version | |
| - v1.5: 512px, faster. | |
| - v2.1: 768px, better performance, high resolution, takes more time. | |
| 2. Upload a reference image to "Left Image" | |
| - You can also select an image from "Examples" | |
| 3. Hit "Generate a right image" button and check the result. | |
| """ | |
| ) | |
| sd_version_radio = gr.Radio( | |
| label="GenStereo Version", | |
| choices=["v1.5", "v2.1"], | |
| value="v2.1", | |
| ) | |
| with gr.Row(): | |
| file = gr.File(label='Left', file_types=['image']) | |
| examples = gr.Examples( | |
| examples=['./assets/COCO_val2017_000000070229.jpg', | |
| './assets/COCO_val2017_000000092839.jpg', | |
| './assets/KITTI2015_000003_10.png', | |
| './assets/KITTI2015_000147_10.png'], | |
| inputs=file | |
| ) |