# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
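"""Gradio demo of RollingDepth (video depth estimation) on Hugging Face Spaces.

The app loads the prs-eth/rollingdepth-v1-0 pipeline, accepts an uploaded
video, runs depth inference on up to MAX_FRAMES frames, and displays the
preprocessed input next to a colorized depth video.
"""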
import functools
import os
import sys
import tempfile

import av
import einops
import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import login

from colorize import colorize_depth_multi_thread
from gradio_patches.examples import Examples
from video_io import get_video_fps, write_video_from_numpy

VERBOSE = False
MAX_FRAMES = 100  # demo-only cap on the number of processed frames


def process(pipe, device, path_input):
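    """Run RollingDepth on one input video.

    Returns paths to two MP4 files written to a temporary directory: the
    preprocessed input video and the colorized depth video.
    """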
| print(f"Processing {path_input}") | |
| path_output_dir = tempfile.mkdtemp() | |
| os.makedirs(path_output_dir, exist_ok=True) | |
| name_base = os.path.splitext(os.path.basename(path_input))[0] | |
| path_out_in = os.path.join(path_output_dir, f"{name_base}_depth_input.mp4") | |
| path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4") | |
| output_fps = int(get_video_fps(path_input)) | |
| container = av.open(path_input) | |
| stream = container.streams.video[0] | |
| fps = float(stream.average_rate) | |
| duration_sec = float(stream.duration * stream.time_base) if stream.duration else 0 | |
| total_frames = int(duration_sec * fps) | |
| if total_frames > MAX_FRAMES: | |
| gr.Warning( | |
| f"Only the first {MAX_FRAMES} frames (~{MAX_FRAMES / fps:.1f} sec.) will be processed for demonstration; " | |
| f"use the code from GitHub for full processing" | |
| ) | |
| generator = torch.Generator(device=device) | |
| generator.manual_seed(2024) | |
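
    # The settings below are the demo defaults; dilations select the frame
    # spacing within the 3-frame snippets (see the RollingDepth paper).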
    pipe_out = pipe(  # returns a rollingdepth.RollingDepthOutput
        # input settings
        input_video_path=path_input,
        start_frame=0,
        # frame_count=0 means "all frames", so guard against unknown duration
        frame_count=min(MAX_FRAMES, total_frames) if total_frames > 0 else MAX_FRAMES,
        processing_res=768,
        # inference settings
        dilations=[1, 25],
        cap_dilation=True,
        snippet_lengths=[3],
        init_infer_steps=[1],
        strides=[1],
        coalign_kwargs=None,
        refine_step=0,  # 0 = off
        max_vae_bs=8,  # batch size for the VAE encoder/decoder
        # other settings
        generator=generator,
        verbose=VERBOSE,
        # output settings
        restore_res=False,
        unload_snippet=False,
    )
    depth_pred = pipe_out.depth_pred  # [N 1 H W]

    # Colorize the depth predictions and write them as a video
    cmap = "Spectral_r"
    colored_np = colorize_depth_multi_thread(
        depth=depth_pred.numpy(),
        valid_mask=None,
        chunk_size=4,
        num_threads=4,
        color_map=cmap,
        verbose=VERBOSE,
    )  # [N H W 3], in [0, 255]
    write_video_from_numpy(
        frames=colored_np,
        output_path=path_out_vis,
        fps=output_fps,
        crf=23,
        preset="medium",
        verbose=VERBOSE,
    )

    # Save the preprocessed RGB input, converted to channel-last layout
    rgb = (pipe_out.input_rgb.numpy() * 255).astype(np.uint8)  # [N 3 H W]
    rgb = einops.rearrange(rgb, "n c h w -> n h w c")
    write_video_from_numpy(
        frames=rgb,
        output_path=path_out_in,
        fps=output_fps,
        crf=23,
        preset="medium",
        verbose=VERBOSE,
    )

    return path_out_in, path_out_vis


def run_demo_server(pipe, device):
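    """Build and launch the Gradio demo around a loaded RollingDepth pipeline."""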
    # spaces.GPU runs the wrapped function on a ZeroGPU-allocated GPU worker.
    process_pipe = spaces.GPU(functools.partial(process, pipe, device))
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="RollingDepth",
        css="""
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
        """,
    ) as demo:
        gr.HTML(
            """
            <h1>🛹 RollingDepth: Video Depth without Video Models</h1>
            <div style="text-align: center; margin-top: 20px;">
                <a title="Website" href="https://rollingdepth.github.io" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
                    <img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="Website Badge">
                </a>
                <a title="arXiv" href="https://arxiv.org/abs/2411.xxxxx" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
                    <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg" alt="arXiv Badge">
                </a>
                <a title="GitHub" href="https://github.com/prs-eth/rollingdepth" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
                    <img src="https://img.shields.io/github/stars/prs-eth/rollingdepth?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="GitHub Stars Badge">
                </a>
                <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
                    <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
                </a>
            </div>
            <p style="margin-top: 20px; text-align: justify;">
                RollingDepth is the state-of-the-art depth estimator for videos in the wild. Upload your video into the
                <b>left</b> pane, or click any of the <b>examples</b> below. The result preview will be computed and
                appear in the <b>right</b> panes. For full functionality, use the code on GitHub.
                <b>TIP:</b> If you run out of GPU time, fork the demo.
            </p>
            """
        )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(label="Input Video")
            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
                    output_video_1 = gr.Video(
                        label="Preprocessed video",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=5,
                    )
                    output_video_2 = gr.Video(
                        label="Generated Depth Video",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=5,
                    )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    generate_btn = gr.Button("Generate")
            with gr.Column(scale=2):
                pass

        Examples(
            examples=[
                ["files/gokart.mp4"],
                ["files/horse.mp4"],
                ["files/walking.mp4"],
            ],
            inputs=[input_video],
            outputs=[output_video_1, output_video_2],
            fn=process_pipe,
            cache_examples=True,
            directory_name="examples_video",
        )
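
        # The Examples block above uses gradio_patches.Examples (a patched
        # gr.Examples); with cache_examples=True the outputs are precomputed,
        # and directory_name is assumed to name the cache folder.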
        generate_btn.click(
            fn=process_pipe,
            inputs=[input_video],
            outputs=[output_video_1, output_video_2],
        )

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


def main():
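    """Install the bundled diffusers fork, load the pipeline, and launch the demo."""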
| os.system("pip freeze") | |
| os.system("pip uninstall -y diffusers") | |
| os.system("pip install rollingdepth_src/diffusers") | |
| os.system("pip freeze") | |
| if "HF_TOKEN_LOGIN" in os.environ: | |
| login(token=os.environ["HF_TOKEN_LOGIN"]) | |
| if torch.cuda.is_available(): | |
| device = torch.device("cuda") | |
| elif torch.backends.mps.is_available(): | |
| device = torch.device("mps") | |
| else: | |
| device = torch.device("cpu") | |
| sys.path.append(os.path.join(os.path.dirname(__file__), "rollingdepth_src")) | |
| from rollingdepth import RollingDepthOutput, RollingDepthPipeline | |
| pipe: RollingDepthPipeline = RollingDepthPipeline.from_pretrained( | |
| "prs-eth/rollingdepth-v1-0", | |
| torch_dtype=torch.float16, | |
| ) | |
| pipe.set_progress_bar_config(disable=True) | |
| try: | |
| import xformers | |
| pipe.enable_xformers_memory_efficient_attention() | |
| except: | |
| pass # run without xformers | |
| pipe = pipe.to(device) | |

    run_demo_server(pipe, device)
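

# A minimal standalone usage sketch (outside this demo), assuming the same
# checkpoint and that parameters omitted from the call have defaults:
#
#     pipe = RollingDepthPipeline.from_pretrained(
#         "prs-eth/rollingdepth-v1-0", torch_dtype=torch.float16
#     ).to("cuda")
#     out = pipe(input_video_path="video.mp4", frame_count=0)  # 0 = all frames
#     depth = out.depth_pred  # torch tensor, [N 1 H W]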


if __name__ == "__main__":
    main()