import os
import subprocess
from pathlib import Path

import cv2
import depth_pro
import gradio as gr
import numpy as np
import rerun as rr
import rerun.blueprint as rrb
import spaces
import torch
from gradio_rerun import Rerun
# Run the script to fetch the pretrained model if it is not present yet.
if not os.path.exists("./checkpoints/depth_pro.pt"):
    print("downloading pretrained model")
    subprocess.run(["bash", "get_pretrained_models.sh"], check=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform.
print("loading model...")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
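
# run_rerun is a generator: each `yield stream.read()` flushes the data logged
# so far to the embedded Rerun viewer, so results appear while the video is
# still being processed.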
@rr.thread_local_stream("rerun_ml_depth_pro")  # give each session its own recording stream
def run_rerun(path_to_video):
    stream = rr.binary_stream()

    # Lay out the viewer: 3D view on top, depth and RGB views side by side below.
    blueprint = rrb.Blueprint(
        rrb.Vertical(
            rrb.Spatial3DView(origin="/"),
            rrb.Horizontal(
                rrb.Spatial2DView(origin="/world/camera/depth"),
                rrb.Spatial2DView(origin="/world/camera/image"),
            ),
        ),
        collapse_panels=True,
    )
    rr.send_blueprint(blueprint)
    yield stream.read()
| print("Loading video from", path_to_video) | |
| video = cv2.VideoCapture(path_to_video) | |
| frame_idx = -1 | |
| while True: | |
| read, frame = video.read() | |
| if not read: | |
| break | |
| frame_idx += 1 | |
| if frame_idx % 3 != 0: | |
| continue | |
| print("processing frame", frame_idx) | |
| # resize to avoid excessive time spent processing | |
| frame = cv2.resize(frame, (640, 480)) | |
| frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| rr.set_time_sequence("frame", frame_idx) | |
| rr.log("world/camera/image", rr.Image(frame)) | |
| yield stream.read() | |
        depth, focal_length = estimate_depth(frame)

        # Log the estimated pinhole camera so the depth image can be
        # back-projected into the 3D view.
        rr.log(
            "world/camera",
            rr.Pinhole(
                width=frame.shape[1],
                height=frame.shape[0],
                focal_length=focal_length,
                principal_point=(frame.shape[1] / 2, frame.shape[0] / 2),
                image_plane_distance=depth.max(),
                camera_xyz=rr.ViewCoordinates.FLU,
            ),
        )
        rr.log("world/camera/depth", rr.DepthImage(depth, meter=1))
        yield stream.read()
    # Clean up the uploaded file once processing is done.
    if os.path.exists(path_to_video):
        os.remove(path_to_video)
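
# estimate_depth runs a single Depth Pro inference pass and returns the
# metric depth map together with the network's focal-length estimate in pixels.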
@spaces.GPU  # request ZeroGPU hardware on Hugging Face Spaces (why `spaces` is imported)
def estimate_depth(frame):
    image = transform(frame)
    image = image.to(device)

    prediction = model.infer(image)
    depth = prediction["depth"].squeeze().detach().cpu().numpy()
    focal_length = prediction["focallength_px"].item()

    return depth, focal_length
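
# Gradio UI: video upload and a trigger button on the left, the embedded
# streaming Rerun viewer on the right.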
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(variant="compact"):
            video = gr.Video(interactive=True, include_audio=False, label="Video")
            visualize = gr.Button("Visualize ML Depth Pro")
        with gr.Column():
            viewer = Rerun(streaming=True)
    visualize.click(run_rerun, inputs=[video], outputs=[viewer])
if __name__ == "__main__":
    demo.launch()
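
# A rough local setup sketch, assuming this file is saved as app.py and the
# usual package sources (exact versions may differ):
#   pip install rerun-sdk gradio gradio-rerun opencv-python torch
#   pip install git+https://github.com/apple/ml-depth-pro.git
#   python app.py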