import gradio as gr
import logging
import os
import random
import tempfile
import time
import spaces
from easydict import EasyDict
import numpy as np
import torch
from dav.pipelines import DAVPipeline
from dav.models import UNetSpatioTemporalRopeConditionModel
from diffusers import AutoencoderKLTemporalDecoder, FlowMatchEulerDiscreteScheduler
from dav.utils import img_utils
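
# `spaces` provides the Hugging Face ZeroGPU decorator used below; the `dav`
# package is the Depth Any Video codebase bundled with this Space.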


def seed_all(seed: int = 0):
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Initialize logging
logging.basicConfig(level=logging.INFO)


# Load models once to avoid reloading on every inference
def load_models(model_base, device):
    vae = AutoencoderKLTemporalDecoder.from_pretrained(model_base, subfolder="vae")
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        model_base, subfolder="scheduler"
    )
    unet = UNetSpatioTemporalRopeConditionModel.from_pretrained(
        model_base, subfolder="unet"
    )
    unet_interp = UNetSpatioTemporalRopeConditionModel.from_pretrained(
        model_base, subfolder="unet_interp"
    )
    pipe = DAVPipeline(
        vae=vae,
        unet=unet,
        unet_interp=unet_interp,
        scheduler=scheduler,
    )
    return pipe.to(device)
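

# DAVPipeline above wires a temporal-decoder VAE, a key-frame UNet, and an
# interpolation UNet together under a flow-matching scheduler.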

# Load models at startup
MODEL_BASE = "hhyangcs/depth-any-video"
DEVICE_TYPE = "cuda"
DEVICE = torch.device(DEVICE_TYPE)
pipe = load_models(MODEL_BASE, DEVICE)


@spaces.GPU  # request a ZeroGPU slot for the duration of the call
def depth_any_video(
    file,
    denoise_steps=3,
    num_frames=32,
    decode_chunk_size=16,
    num_interp_frames=16,
    num_overlap_frames=6,
    max_resolution=1024,
):
    """
    Perform depth estimation on the uploaded video or image and return the
    path of the side-by-side RGB/depth result.
    """
    # gr.File passes the upload as a filepath; copy it into a scratch directory.
    # Use mkdtemp rather than TemporaryDirectory so the output file still exists
    # when Gradio reads it after this function returns.
    tmp_dir = tempfile.mkdtemp()
    input_path = os.path.join(tmp_dir, os.path.basename(file))
    with open(file, "rb") as src, open(input_path, "wb") as dst:
        dst.write(src.read())

    # Set up output directory
    output_dir = os.path.join(tmp_dir, "output")
    os.makedirs(output_dir, exist_ok=True)

    # Prepare configuration
    cfg = EasyDict(
        {
            "model_base": MODEL_BASE,
            "data_path": input_path,
            "output_dir": output_dir,
            "denoise_steps": denoise_steps,
            "num_frames": num_frames,
            "decode_chunk_size": decode_chunk_size,
            "num_interp_frames": num_interp_frames,
            "num_overlap_frames": num_overlap_frames,
            "max_resolution": max_resolution,
            "seed": 666,
        }
    )
    seed_all(cfg.seed)

    file_name = os.path.splitext(os.path.basename(cfg.data_path))[0]
    is_video = cfg.data_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv"))

    if is_video:
        num_interp_frames = cfg.num_interp_frames
        num_overlap_frames = cfg.num_overlap_frames
        num_frames = cfg.num_frames
        assert num_frames % 2 == 0, "num_frames should be even."
        assert (
            2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2
        ), "Invalid frame overlap."
        max_frames = (num_interp_frames + 2 - num_overlap_frames) * (
            num_frames // 2
        )
        image, fps = img_utils.read_video(cfg.data_path, max_frames=max_frames)
    else:
        image = img_utils.read_image(cfg.data_path)

    image = img_utils.imresize_max(image, cfg.max_resolution)
    image = img_utils.imcrop_multi(image)
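    # Note: imresize_max (above) caps the longer side at max_resolution, and
    # imcrop_multi appears, judging by its name, to crop H and W to multiples
    # the VAE can downsample cleanly.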

    # Stack frames as NCHW float tensors in [0, 1] and move them to the GPU
    image_tensor = np.ascontiguousarray(
        [_img.transpose(2, 0, 1) / 255.0 for _img in image]
    )
    image_tensor = torch.from_numpy(image_tensor).to(DEVICE)

    # Run the pipeline under fp16 autocast to reduce memory use
    with torch.no_grad(), torch.autocast(
        device_type=DEVICE_TYPE, dtype=torch.float16
    ):
        pipe_out = pipe(
            image_tensor,
            num_frames=cfg.num_frames,
            num_overlap_frames=cfg.num_overlap_frames,
            num_interp_frames=cfg.num_interp_frames,
            decode_chunk_size=cfg.decode_chunk_size,
            num_inference_steps=cfg.denoise_steps,
        )
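
    # pipe_out carries the raw disparity, a colorized rendering of it, and the
    # (possibly resized/cropped) input frames; only the colorized map is shown.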
    disparity = pipe_out.disparity
    disparity_colored = pipe_out.disparity_colored
    image = pipe_out.image

    # Put RGB and colorized disparity side by side: (N, H, 2 * W, 3)
    merged = np.concatenate([image, disparity_colored], axis=2)

    if is_video:
        output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.mp4")
        img_utils.write_video(output_path, merged, fps)
    else:
        output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.png")
        img_utils.write_image(output_path, merged[0])
    return output_path
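

# Direct (non-Gradio) usage sketch; the demo paths below are the ones bundled
# with this Space and are assumed to exist:
#   out_path = depth_any_video("demos/wooly_mammoth.mp4", denoise_steps=3)
#   print(out_path)  # path to the side-by-side RGB/depth .mp4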


# Define Gradio interface
title = "Depth Any Video with Scalable Synthetic Data"
description = """
Upload a video or image to perform depth estimation using the Depth Any Video model.
Adjust the parameters as needed to control the inference process.
"""
iface = gr.Interface(
    fn=depth_any_video,
    inputs=[
        gr.File(label="Upload Video/Image"),
        gr.Slider(1, 10, step=1, value=3, label="Denoise Steps"),
        gr.Slider(16, 64, step=1, value=32, label="Number of Frames"),
        gr.Slider(8, 32, step=1, value=16, label="Decode Chunk Size"),
        gr.Slider(8, 32, step=1, value=16, label="Number of Interpolation Frames"),
        gr.Slider(2, 10, step=1, value=6, label="Number of Overlap Frames"),
        gr.Slider(512, 2048, step=32, value=1024, label="Maximum Resolution"),
    ],
    # gr.File serves both result types; gr.Video would fail on image inputs,
    # which produce a .png rather than an .mp4
    outputs=gr.File(label="Depth Result (Video or Image)"),
    title=title,
    description=description,
    examples=[["demos/arch_2.jpg"], ["demos/wooly_mammoth.mp4"]],
    allow_flagging="never",
    analytics_enabled=False,
)


if __name__ == "__main__":
    # share=True requests a public gradio.live link; when running inside a
    # Hugging Face Space, launch() alone is sufficient.
    iface.launch(share=True)