import os
import gradio as gr
import numpy as np
import torch
from PIL import Image
from loguru import logger
from tqdm import tqdm
from tools.common_utils import save_video
from dkt.pipelines.pipeline import DKTPipeline, ModelConfig
import cv2
import copy
import trimesh
from os.path import join
from tools.depth2pcd import depth2pcd
# from moge.model.v2 import MoGeModel
from tools.eval_utils import transfer_pred_disp2depth, colorize_depth_map
import datetime
import tempfile
import time

#* better for bg: logs/outs/train/remote/sft-T2SQNet_glassverse_cleargrasp_HISS_DREDS_DREDS_glassverse_interiorverse-4gpus-origin-lora128-1.3B-rgb_depth-w832-h480-Wan2.1-Fun-Control-2025-10-28-23:26:41/epoch-0-20000.safetensors

NEGATIVE_PROMPT = ''
# Resolution passed to the pipeline for every request.
height = 480
width = 832
# NOTE(review): window_size is not used by any code in this file — confirm
# whether it is meant to be passed to the pipeline.
window_size = 21

# Frame rate used for the output video when OpenCV cannot report a usable one.
FALLBACK_FPS = 24

# DKT_PIPELINE = DKTPipeline()
# NOTE(review): is_depth=False here, yet process_video saves
# 'colored_depth_map' and the UI labels its output as depth; the commented-out
# DKT_PIPELINE_14B_NORMAL line below uses identical arguments — confirm the
# intended mode.
DKT_PIPELINE_14B = DKTPipeline(is14B=True, is_depth=False)
# DKT_PIPELINE_14B_NORMAL = DKTPipeline(is14B=True, is_depth=False)

example_inputs = [
    # "examples/1.mp4",
    "examples/7.mp4",
    # "examples/8.mp4",
    # "examples/39.mp4",
    "examples/10.mp4",
    # "examples/30.mp4",
    # "examples/35.mp4",
    # "examples/40.mp4",
    # "examples/2.mp4",
    # "examples/4.mp4",  #* video not available
    # "examples/episode_48-camera_head.mp4",  #* video not available
    # "examples/input_20251128_121408.mp4",
    "examples/input_20251128_122722.mp4",
    # "examples/5eaeaff52b23787a3dc3c610655a49d2.mp4",
    "examples/9f2909760aff526070f169620ff38290.mp4",
    "examples/16.mp4",
    "examples/17.mp4",
    "examples/18.mp4",
    "examples/27.mp4",
    # "examples/28.mp4",
    # "examples/73fc0b2a3af3474de27c7da0bfbf5faa.mp4",
    "examples/episode_48-camera_third_view.mp4",
    "examples/extra_5.mp4",
    "examples/extra_9.mp4",
    # "examples/IMG_5703.MOV",
    "examples/input_20251202_031811.mp4",
    # "examples/input_20251202_032007.mp4",
    # "examples/teaser_1.mp4",
    "examples/3.mp4",
    # "examples/teaser_3.mp4",
    "examples/teaser_7.mp4",
    "examples/teaser_25.mp4",
]


def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:
    """Build a trimesh scene containing the colored points selected by a mask.

    Args:
        point_map: Array of 3-D points. ``point_map[valid_mask]`` must yield an
            array whose last dimension is 3 (e.g. ``(H, W, 3)`` with an
            ``(H, W)`` mask, or ``(N, 3)`` with an ``(N,)`` mask).
        valid_mask: Boolean mask selecting which points to keep.
        frame: Per-point colors, indexed with the same mask as ``point_map``.

    Returns:
        A ``trimesh.Scene`` holding a single ``trimesh.PointCloud``.
    """
    # Negate x and y (presumably converting from the camera convention to the
    # viewer's axes — TODO confirm).
    pts_3d = point_map[valid_mask] * np.array([-1, -1, 1])
    pts_rgb = frame[valid_mask]

    scene_3d = trimesh.Scene()
    point_cloud_data = trimesh.PointCloud(vertices=pts_3d, colors=pts_rgb)
    scene_3d.add_geometry(point_cloud_data)
    return scene_3d


def create_simple_glb_from_pointcloud(points, colors, glb_filename):
    """Export a colored point cloud to a GLB file (best effort).

    Args:
        points: ``(N, 3)`` array of point positions.
        colors: ``(N, 3)`` per-point colors, or ``None`` for all-white points.
        glb_filename: Destination path handed to ``trimesh.Scene.export``.

    Returns:
        ``True`` on success; ``False`` if there are no points or the export
        failed. Errors are logged (with traceback) instead of raised.
    """
    try:
        if len(points) == 0:
            logger.warning(f"No valid points to create GLB for {glb_filename}")
            return False

        if colors is not None:
            pts_rgb = colors
        else:
            logger.info("No colors provided, adding default white colors")
            pts_rgb = np.ones((len(points), 3))

        # Every point is kept; pmap_to_glb still expects a selection mask.
        valid_mask = np.ones(len(points), dtype=bool)
        scene_3d = pmap_to_glb(points, valid_mask, pts_rgb)
        scene_3d.export(glb_filename)
        return True
    except Exception as e:
        # Deliberately best-effort: keep the traceback, report failure to caller.
        logger.exception(f"Error creating GLB from pointcloud using trimesh: {e}")
        return False


def _read_video_fps(video_file, fallback=FALLBACK_FPS):
    """Return the frame rate OpenCV reports for *video_file*, else *fallback*.

    ``cv2.CAP_PROP_FPS`` can come back as 0 (or NaN) when the file lacks
    usable frame-rate metadata; neither is a valid output frame rate.
    """
    cap = cv2.VideoCapture(video_file)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        cap.release()
    # `fps > 0` is False for both 0.0 and NaN, so both take the fallback.
    if fps > 0:
        return fps
    logger.warning(f"Could not read FPS from {video_file}; using {fallback}")
    return fallback


def process_video(
    video_file,
    model_size,
    num_inference_steps,
    overlap
):
    """Run the DKT pipeline on a video and save the colorized result as MP4.

    Args:
        video_file: Path to the input video (as delivered by ``gr.Video``).
        model_size: ``"14B"`` or ``"1.3B"``; selects the pipeline to run.
        num_inference_steps: Number of inference steps passed to the pipeline.
        overlap: Overlap value passed to the pipeline (presumably the number
            of frames shared by consecutive windows — TODO confirm).

    Returns:
        Path of the saved MP4, inside a freshly created temporary directory.

    Raises:
        ValueError: If ``model_size`` is unknown or its pipeline is not loaded.
        Exception: Any pipeline failure is logged with traceback and re-raised.
    """
    if model_size == "14B":
        logger.info('14B model is chosen')
        pipeline = DKT_PIPELINE_14B
    elif model_size == "1.3B":
        logger.info('1.3B model is chosen')
        # DKT_PIPELINE's module-level construction is commented out; look it
        # up defensively so this path fails with a clear message instead of
        # a bare NameError.
        pipeline = globals().get("DKT_PIPELINE")
        if pipeline is None:
            raise ValueError("The 1.3B model is not loaded in this deployment")
    else:
        raise ValueError(f"Invalid model size: {model_size}")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # A unique per-request directory; it is not removed here because the
    # returned path must stay valid for the output video component.
    cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')

    start_time = time.perf_counter()
    logger.info("[1] Starting pipeline call...")
    try:
        prediction_result = pipeline(
            video_file,
            negative_prompt=NEGATIVE_PROMPT,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            overlap=overlap,
            return_rgb=True,
            get_moge_intrinsics=False,
        )
    except Exception:
        logger.exception("[ERROR] Pipeline failed")
        raise
    spend_time = time.perf_counter() - start_time
    logger.info(f"[2] Pipeline returned, keys: {list(prediction_result.keys())}")
    logger.info(f"pipeline spend time: {spend_time:.2f} seconds for depth prediction")

    #* save depth predictions video at the input frame rate
    output_path = os.path.join(cur_save_dir, f"output_{timestamp}.mp4")
    input_fps = _read_video_fps(video_file)
    logger.info(f"[3] Saving video to {output_path}, fps={input_fps}")
    save_video(prediction_result['colored_depth_map'], output_path, fps=input_fps, quality=8)
    logger.info("[4] Video saved successfully")
    return output_path


# NOTE: Point-cloud (GLB) export of a few key frames previously followed the
# depth video (DKT_PIPELINE.prediction2pc_v3 -> create_simple_glb_from_pointcloud)
# and is currently disabled; recover it from version control if needed.
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

# Defaults shared by the interactive controls and the example runner so the
# two cannot drift apart.
DEFAULT_MODEL_SIZE = "14B"
DEFAULT_NUM_INFERENCE_STEPS = 5
DEFAULT_OVERLAP = 3

css = """
#download {
    height: 118px;
}
.slider .inner {
    width: 5px;
    background: #FFF;
}
.viewport {
    aspect-ratio: 4/3;
}
.tabs button.selected {
    font-size: 20px !important;
    color: crimson !important;
}
h1 {
    text-align: center;
    display: block;
}
h2 {
    text-align: center;
    display: block;
}
h3 {
    text-align: center;
    display: block;
}
.md_feedback li {
    margin-bottom: 0px !important;
}
"""

head_html = """
"""

with gr.Blocks(css=css, title="DKT", head=head_html) as demo:
    gr.Markdown(
        """
# Diffusion Knows Transparency: Repurposing Video Diffusion for Transparent Object Depth and Normal Estimation
"""
    )

    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video", elem_id='video-display-input')
            model_size = gr.Radio(
                # "1.3B" is not offered: only the 14B pipeline is instantiated.
                choices=["14B"],
                value=DEFAULT_MODEL_SIZE,
                label="Model Size"
            )
            with gr.Accordion("Advanced Parameters", open=False):
                num_inference_steps = gr.Slider(
                    minimum=1, maximum=50, value=DEFAULT_NUM_INFERENCE_STEPS, step=1,
                    label="Number of Inference Steps"
                )
                overlap = gr.Slider(
                    minimum=1, maximum=20, value=DEFAULT_OVERLAP, step=1,
                    label="Overlap"
                )
            submit = gr.Button(value="Compute Depth", variant="primary")

        with gr.Column():
            output_video = gr.Video(
                label="Depth Outputs",
                elem_id='video-display-output',
                autoplay=True
            )
            # Hidden placeholder output; on_submit always returns None for it.
            vis_video = gr.Video(
                label="Visualization Video",
                visible=False,
                autoplay=True
            )

    # The four "Point Cloud Key Frame" gr.Model3D viewers were removed while
    # point-cloud export is disabled; restore them from version control
    # together with that feature.

    def on_submit(video_file, model_size, num_inference_steps, overlap):
        """Gradio callback for the "Compute Depth" button.

        Returns:
            ``(depth_video_path, None)``; the second value feeds the hidden
            ``vis_video`` component. ``(None, None)`` when no video is given.

        Raises:
            gr.Error: If processing fails, so the failure is shown in the UI
                instead of the outputs silently staying empty.
        """
        logger.info('on_submit is calling')
        if video_file is None:
            gr.Warning("Please upload an input video first.")
            return None, None

        start_time = time.time()
        try:
            output_path = process_video(
                video_file, model_size, num_inference_steps, overlap
            )
        except Exception as e:
            # Keep the full traceback in the server log; show a short message
            # to the user without leaking internal details.
            logger.exception("Depth processing failed in on_submit")
            raise gr.Error(
                f"Processing failed ({type(e).__name__}). See server logs for details."
            ) from e

        spend_time = time.time() - start_time
        logger.info(f"Total spend time in on_submit: {spend_time:.2f} seconds")
        return output_path, None

    submit.click(
        on_submit,
        inputs=[
            input_video, model_size, num_inference_steps, overlap
        ],
        outputs=[
            output_video, vis_video
        ]
    )

    def on_example_submit(video_file):
        """Run ``on_submit`` on an example video with the default parameters."""
        return on_submit(
            video_file,
            DEFAULT_MODEL_SIZE,
            DEFAULT_NUM_INFERENCE_STEPS,
            DEFAULT_OVERLAP,
        )

    # NOTE(review): per Gradio's Examples documentation, with
    # cache_examples=False `fn` only runs on click when run_on_click=True;
    # otherwise clicking an example just fills input_video. Confirm which
    # behaviour is intended.
    examples = gr.Examples(
        examples=example_inputs,
        inputs=[input_video],
        outputs=[
            output_video, vis_video
        ],
        fn=on_example_submit,
        examples_per_page=36,
        cache_examples=False
    )
if __name__ == '__main__':
    # The pipelines are already constructed at import time (module level), so
    # the entry point only enables Gradio's request queue and starts serving.
    app = demo.queue()
    app.launch()