| |
|
| | import os |
| | import gradio as gr |
| |
|
| |
|
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| | from loguru import logger |
| | from tqdm import tqdm |
| | from tools.common_utils import save_video |
| | from dkt.pipelines.pipeline import DKTPipeline, ModelConfig |
| |
|
| |
|
| | import cv2 |
| | import copy |
| | import trimesh |
| |
|
| | from os.path import join |
| | from tools.depth2pcd import depth2pcd |
| | |
| |
|
| |
|
| | from tools.eval_utils import transfer_pred_disp2depth, colorize_depth_map |
| | import datetime |
| | import tempfile |
| | import time |
| |
|
| |
|
| | |
| |
|
# Generation / model configuration used by process_video below.
NEGATIVE_PROMPT = ''  # no negative prompt is used for depth prediction
height = 480          # inference frame height (pixels)
width = 832           # inference frame width (pixels)
window_size = 21      # sliding-window length in frames — NOTE(review): declared but not referenced in this file; confirm the pipeline reads it

# Instantiate the 14B pipeline once at import time so every request reuses
# the already-loaded weights instead of reloading them per call.
DKT_PIPELINE_14B = DKTPipeline(is14B=True)
| | |
| |
|
# Demo videos listed in the Gradio "Examples" gallery (paths are relative
# to the app's working directory).
example_inputs = [
    "examples/1.mp4",
    "examples/7.mp4",
    "examples/8.mp4",
    "examples/39.mp4",
    "examples/10.mp4",
    "examples/30.mp4",
    "examples/35.mp4",
    "examples/40.mp4",
    "examples/2.mp4",
    "examples/4.mp4",
    "examples/episode_48-camera_head.mp4",
    "examples/input_20251128_121408.mp4",
    "examples/input_20251128_122722.mp4",
    "examples/5eaeaff52b23787a3dc3c610655a49d2.mp4",
    "examples/9f2909760aff526070f169620ff38290.mp4",
    "examples/18.mp4",
    "examples/28.mp4",
    "examples/73fc0b2a3af3474de27c7da0bfbf5faa.mp4",
    "examples/episode_48-camera_third_view.mp4",
    "examples/extra_5.mp4",
    "examples/extra_9.mp4",
    "examples/IMG_5703.MOV",
    "examples/input_20251202_031811.mp4",
    "examples/input_20251202_032007.mp4",
    "examples/teaser_1.mp4",
    "examples/3.mp4",
    "examples/teaser_3.mp4",
    "examples/teaser_7.mp4",
    "examples/teaser_25.mp4",
]
| |
|
| |
|
| |
|
| |
|
| |
|
def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:
    """Build a trimesh Scene holding a colored point cloud.

    Only entries where ``valid_mask`` is True are kept. The x and y axes
    are negated to match the GLB viewer's coordinate convention.
    """
    # Select valid points and flip x/y for the viewer coordinate frame.
    axis_flip = np.array([-1, -1, 1])
    vertices = point_map[valid_mask] * axis_flip
    vertex_colors = frame[valid_mask]

    cloud = trimesh.PointCloud(vertices=vertices, colors=vertex_colors)

    scene = trimesh.Scene()
    scene.add_geometry(cloud)
    return scene
| |
|
| |
|
| |
|
def create_simple_glb_from_pointcloud(points, colors, glb_filename):
    """Export a point cloud to a GLB file on disk.

    Returns True on success, False when there is nothing to export or the
    trimesh export fails; errors are logged, never raised to the caller.
    """
    try:
        n_points = len(points)
        if n_points == 0:
            logger.warning(f"No valid points to create GLB for {glb_filename}")
            return False

        if colors is None:
            # Fall back to white when the caller supplies no per-point colors.
            logger.info("No colors provided, adding default white colors")
            point_colors = np.ones((n_points, 3))
        else:
            point_colors = colors

        # Keep every point; pmap_to_glb handles selection and axis flipping.
        keep_all = np.ones(n_points, dtype=bool)
        scene = pmap_to_glb(points, keep_all, point_colors)
        scene.export(glb_filename)
        return True

    except Exception as e:
        logger.error(f"Error creating GLB from pointcloud using trimesh: {str(e)}")
        return False
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def process_video(
    video_file,
    model_size,
    num_inference_steps,
    overlap
):
    """Run the DKT depth pipeline on a video and save the colored depth map.

    Args:
        video_file: path to the input video file.
        model_size: "14B" or "1.3B". Only the 14B pipeline is created at
            module import; a 1.3B request requires a module-level
            ``DKT_PIPELINE`` to have been instantiated.
        num_inference_steps: number of diffusion denoising steps.
        overlap: frame overlap between consecutive sliding windows.

    Returns:
        Path to the rendered depth video inside a fresh temp directory.

    Raises:
        ValueError: for an unknown ``model_size`` or when the 1.3B
            pipeline was requested but never instantiated.
    """
    if model_size == "14B":
        logger.info('14B model is chosen')
        pipeline = DKT_PIPELINE_14B
    elif model_size == "1.3B":
        logger.info('1.3B model is chosen')
        # DKT_PIPELINE is optional and not created in this module; the bare
        # name would raise a confusing NameError when it is missing, so
        # look it up defensively and fail with a clear message instead.
        pipeline = globals().get('DKT_PIPELINE')
        if pipeline is None:
            raise ValueError("1.3B pipeline (DKT_PIPELINE) is not loaded")
    else:
        raise ValueError(f"Invalid model size: {model_size}")

    # Unique scratch directory per request so concurrent jobs never collide.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')

    start_time = time.time()

    print("[1] Starting pipeline...")
    try:
        prediction_result = pipeline(
            video_file,
            negative_prompt=NEGATIVE_PROMPT,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            overlap=overlap,
            return_rgb=True,
            get_moge_intrinsics=False
        )
        print(f"[2] Pipeline done, keys: {prediction_result.keys()}")
    except Exception as e:
        # Surface the failure in the server log, then re-raise so the
        # Gradio handler can report it.
        print(f"[ERROR] Pipeline failed: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        raise

    spend_time = time.time() - start_time
    print(f"[3] Pipeline time: {spend_time:.2f}s")
    logger.info(f"pipeline spend time: {spend_time:.2f} seconds for depth prediction")

    output_filename = f"output_{timestamp}.mp4"
    output_path = os.path.join(cur_save_dir, output_filename)

    # Reuse the source fps so the depth video stays in sync with the input.
    cap = cv2.VideoCapture(video_file)
    input_fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    if not (input_fps and input_fps > 0):
        # Some containers report 0 (or NaN) fps; fall back to a sane default.
        input_fps = 24

    print(f"[4] Saving video, fps={input_fps}")
    save_video(prediction_result['colored_depth_map'], output_path, fps=input_fps, quality=8)
    print(f"[5] Video saved: {output_path}")
    return output_path
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | |
| |
|
| |
|
# Custom CSS injected into the Gradio app: centered headings, compact
# feedback lists, and sizing tweaks for the download box, sliders, tabs,
# and 3D viewport.
css = """
#download {
    height: 118px;
}
.slider .inner {
    width: 5px;
    background: #FFF;
}
.viewport {
    aspect-ratio: 4/3;
}
.tabs button.selected {
    font-size: 20px !important;
    color: crimson !important;
}
h1 {
    text-align: center;
    display: block;
}
h2 {
    text-align: center;
    display: block;
}
h3 {
    text-align: center;
    display: block;
}
.md_feedback li {
    margin-bottom: 0px !important;
}
"""
| |
|
| |
|
| |
|
# Google Analytics (gtag.js) snippet injected into the page <head> via
# gr.Blocks(head=...) for demo usage tracking.
head_html = """
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
<script>
  window.dataLayer = window.dataLayer || [];
  function gtag() {dataLayer.push(arguments);}
  gtag('js', new Date());
  gtag('config', 'G-1FWSVCGZTG');
</script>
"""
| |
|
| |
|
| |
|
| |
|
# Gradio UI definition: input video + model/parameter controls on the left,
# depth output (and a hidden visualization slot) on the right.
with gr.Blocks(css=css, title="DKT", head=head_html) as demo:

    # Page header: paper title and badge links (website / GitHub / social).
    gr.Markdown(
        """
        # Diffusion Knows Transparency: Repurposing Video Diffusion for Transparent Object Depth and Normal Estimation
        <p align="center">

        <a title="Website" href="https://daniellli.github.io/projects/DKT/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
        </a>
        <a title="Github" href="https://github.com/Daniellli/DKT" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://img.shields.io/github/stars/Daniellli/DKT?style=social" alt="badge-github-stars">
        </a>
        <a title="Social" href="https://x.com/xshocng1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
        </a>
        """
    )

    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video", elem_id='video-display-input')

            model_size = gr.Radio(
                # Only the 14B pipeline is instantiated at import time, so
                # it is the only selectable choice here.
                choices=["14B"],
                value="14B",
                label="Model Size"
            )

            with gr.Accordion("Advanced Parameters", open=False):
                num_inference_steps = gr.Slider(
                    minimum=1, maximum=50, value=5, step=1,
                    label="Number of Inference Steps"
                )
                overlap = gr.Slider(
                    minimum=1, maximum=20, value=3, step=1,
                    label="Overlap"
                )

            submit = gr.Button(value="Compute Depth", variant="primary")

        with gr.Column():
            output_video = gr.Video(
                label="Depth Outputs",
                elem_id='video-display-output',
                autoplay=True
            )
            # Hidden second output slot; on_submit always returns None for it.
            vis_video = gr.Video(
                label="Visualization Video",
                visible=False,
                autoplay=True
            )

    def on_submit(video_file, model_size, num_inference_steps, overlap):
        """Click handler: run the pipeline and return (depth_video, vis_video).

        Returns (None, None) when no video was provided or on any failure;
        errors are logged rather than propagated to the UI.
        """
        logger.info('on_submit is calling')
        if video_file is None:
            return None, None

        try:
            start_time = time.time()
            output_path = process_video(
                video_file, model_size, num_inference_steps, overlap
            )
            spend_time = time.time() - start_time
            logger.info(f"Total spend time in on_submit: {spend_time:.2f} seconds")
            print(f"Total spend time in on_submit: {spend_time:.2f} seconds")

            if output_path is None:
                return None, None

            # Second slot (vis_video) is unused in the current UI.
            return output_path, None

        except Exception as e:
            logger.error(e)
            return None, None

    submit.click(
        on_submit,
        inputs=[
            input_video, model_size, num_inference_steps, overlap
        ],
        outputs=[
            output_video, vis_video
        ]
    )

    def on_example_submit(video_file):
        """Wrapper function for examples with default parameters"""
        return on_submit(video_file, "14B", 5, 3)

    # Example gallery. cache_examples=False, so example outputs are not
    # precomputed; presumably selecting an example only fills the input
    # until the user presses "Compute Depth" — TODO confirm against the
    # installed Gradio version's run_on_click default.
    examples = gr.Examples(
        examples=example_inputs,
        inputs=[input_video],
        outputs=[
            output_video, vis_video
        ],
        fn=on_example_submit,
        examples_per_page=36,
        cache_examples=False
    )
| |
|
| |
|
if __name__ == '__main__':
    # queue() enables request queueing so long-running GPU inference jobs
    # are processed one after another instead of overlapping.
    demo.queue().launch()