shaocong committed
Commit f4470c4 · Parent(s): 3668549

first commit

Files changed (5):
  1. .gitattributes +51 -0
  2. .gitignore +88 -0
  3. README.md +1 -1
  4. app.py +430 -117
  5. requirements.txt +15 -5
.gitattributes CHANGED
@@ -33,4 +33,55 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.egg filter=lfs diff=lfs merge=lfs -text
+**/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/39.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/40.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/1.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/10.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/178db6e89ab682bfc612a3290fec58dd.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/1b0daeb776471c7389b36cee53049417.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/30.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/31.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/8.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/9.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/DJI_20250912164311_0007_D.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/2.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/36.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/32.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/33.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/5.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/8a6dfb8cfe80634f4f77ae9aa830d075.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/b68045aa2128ab63d9c7518f8d62eafe.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/3.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/35.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/69230f105ad8740e08d743a8ee11c651.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/7.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/DJI_20250912163642_0003_D.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/b1f1fa44f414d7731cd7d77751093c44.mp4 filter=lfs diff=lfs merge=lfs -text
+*MOV filter=lfs diff=lfs merge=lfs -text
+*mov filter=lfs diff=lfs merge=lfs -text
+tmp.mp4 filter=lfs diff=lfs merge=lfs -text
+*.MOV filter=lfs diff=lfs merge=lfs -text
+examples/ filter=lfs diff=lfs merge=lfs -text
+examples/18.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/73fc0b2a3af3474de27c7da0bfbf5faa.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/episode_48-camera_head.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/input_20251128_121408.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/input_20251202_031811.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/input_20251202_032007.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/teaser_7.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/27.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/9f2909760aff526070f169620ff38290.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/episode_48-camera_third_view.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/input_20251128_122722.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/teaser_1.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/teaser_25.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/5eaeaff52b23787a3dc3c610655a49d2.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/teaser_3.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/28.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/4.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/extra_5.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/extra_9.mp4 filter=lfs diff=lfs merge=lfs -text
 examples/IMG_5703.mp4 filter=lfs diff=lfs merge=lfs -text
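Note: entries in this format are what `git lfs track "<pattern>"` writes into .gitattributes. As a rough mental model of which paths the new rules capture, here is a toy matcher using Python's `fnmatchcase`; it only approximates Git's attribute matching (Git treats `**` and `/` specially, fnmatch does not):

```python
from fnmatch import fnmatchcase

# A few of the patterns added above; approximation only, since Git's
# .gitattributes matching has richer semantics than fnmatch.
lfs_patterns = ["*.mp4", "*.MOV", "*.egg", "**/tokenizer.json"]

for path in ["examples/39.mp4", "README.md", "model/tokenizer.json"]:
    tracked = any(fnmatchcase(path, p) for p in lfs_patterns)
    print(path, "->", "LFS" if tracked else "regular")
```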
.gitignore ADDED
@@ -0,0 +1,88 @@
+benchmark
+benchmark/*
+
+*mp4
+!examples/*.mp4
+data/*
+logs/*
+
+*pyc
+
+checkpoints/*
+
+*egg-info
+
+frames
+
+*png
+
+*gif
+
+*ipynb
+daniel_tools
+daniel_tools/*
+
+*jpg
+
+build
+
+run*sh
+
+.m*
+
+scripts/*
+
+*.sh
+wandb
+benchmark
+*jsonl
+*json
+*npz
+DKT_models
+trash
+gradio
+tmp*
+*.webp
+*.ico
+*.model
+__pycache__/
+*.pyc
+**/tokenizer_configs/**/vocab.txt
+**/tokenizer_configs/**/spiece.model
+**/tokenizer_configs/**/tokenizer.model
+
+dist
+build
+
+.gradio
+debug*
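The new .gitignore ignores every mp4 (`*mp4`) but re-includes the demo clips with the negation rule `!examples/*.mp4`; gitignore rules are evaluated in order, with the last matching rule winning. A toy last-match-wins evaluator, assuming simplified fnmatch semantics rather than Git's full rules:

```python
from fnmatch import fnmatchcase

# (pattern, ignored?) pairs in file order; a negated rule sets ignored=False.
# Toy model: real gitignore adds anchoring and directory-specific rules.
rules = [("*mp4", True), ("examples/*.mp4", False)]

def is_ignored(path: str) -> bool:
    ignored = False
    for pattern, ignore in rules:  # last matching rule wins
        if fnmatchcase(path, pattern):
            ignored = ignore
    return ignored

print(is_ignored("data/clip.mp4"))   # True  -> ignored
print(is_ignored("examples/1.mp4"))  # False -> re-included
```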
README.md CHANGED
@@ -6,7 +6,7 @@ colorTo: red
 sdk: gradio
 sdk_version: 5.44.0
 app_file: app.py
-pinned: false
+pinned: true
 license: apache-2.0
 short_description: DKT-Normal
 ---
app.py CHANGED
@@ -1,154 +1,467 @@
-import gradio as gr
-import numpy as np
-import random
-
-# import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
-import torch
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
-
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
-
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-
-# @spaces.GPU #[uncomment to use ZeroGPU]
-def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
-):
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
-
-    return image, seed
-
-
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
-
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
-
-        with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
-            run_button = gr.Button("Run", scale=0, variant="primary")
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=2,  # Replace with defaults that work for your model
-                )
-
-        gr.Examples(examples=examples, inputs=[prompt])
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn=infer,
-        inputs=[
-            prompt,
-            negative_prompt,
-            seed,
-            randomize_seed,
-            width,
-            height,
-            guidance_scale,
-            num_inference_steps,
-        ],
-        outputs=[result, seed],
-    )
-
-if __name__ == "__main__":
-    demo.launch()
+
+import os
+import gradio as gr
+
+import numpy as np
+import torch
+from PIL import Image
+from loguru import logger
+from tqdm import tqdm
+from tools.common_utils import save_video
+from dkt.pipelines.pipeline import DKTPipeline, ModelConfig
+
+import cv2
+import copy
+import trimesh
+
+from os.path import join
+from tools.depth2pcd import depth2pcd
+# from moge.model.v2 import MoGeModel
+
+from tools.eval_utils import transfer_pred_disp2depth, colorize_depth_map
+import datetime
+import tempfile
+import time
+
+
+#* better for bg: logs/outs/train/remote/sft-T2SQNet_glassverse_cleargrasp_HISS_DREDS_DREDS_glassverse_interiorverse-4gpus-origin-lora128-1.3B-rgb_depth-w832-h480-Wan2.1-Fun-Control-2025-10-28-23:26:41/epoch-0-20000.safetensors
+
+NEGATIVE_PROMPT = ''
+height = 480
+width = 832
+window_size = 21
+# DKT_PIPELINE = DKTPipeline()
+DKT_PIPELINE_14B = DKTPipeline(is14B=True, is_depth=False)
+# DKT_PIPELINE_14B_NORMAL = DKTPipeline(is14B=True, is_depth=False)
+
+example_inputs = [
+    "examples/1.mp4",
+    "examples/7.mp4",
+    "examples/8.mp4",
+    "examples/39.mp4",
+    "examples/10.mp4",
+    "examples/30.mp4",
+    "examples/35.mp4",
+    "examples/40.mp4",
+    "examples/2.mp4",
+    "examples/4.mp4",
+    "examples/episode_48-camera_head.mp4",
+    "examples/input_20251128_121408.mp4",
+    "examples/input_20251128_122722.mp4",
+    "examples/5eaeaff52b23787a3dc3c610655a49d2.mp4",
+    "examples/9f2909760aff526070f169620ff38290.mp4",
+    "examples/18.mp4",
+    "examples/27.mp4",
+    "examples/28.mp4",
+    "examples/73fc0b2a3af3474de27c7da0bfbf5faa.mp4",
+    "examples/episode_48-camera_third_view.mp4",
+    "examples/extra_5.mp4",
+    "examples/extra_9.mp4",
+    "examples/IMG_5703.MOV",
+    "examples/input_20251202_031811.mp4",
+    "examples/input_20251202_032007.mp4",
+    "examples/teaser_1.mp4",
+    "examples/3.mp4",
+    "examples/teaser_3.mp4",
+    "examples/teaser_7.mp4",
+    "examples/teaser_25.mp4",
+]
+
+
+def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:
+    pts_3d = point_map[valid_mask] * np.array([-1, -1, 1])
+    pts_rgb = frame[valid_mask]
+
+    # Initialize a 3D scene
+    scene_3d = trimesh.Scene()
+
+    # Add point cloud data to the scene
+    point_cloud_data = trimesh.PointCloud(vertices=pts_3d, colors=pts_rgb)
+
+    scene_3d.add_geometry(point_cloud_data)
+    return scene_3d
+
+
+def create_simple_glb_from_pointcloud(points, colors, glb_filename):
+    try:
+        if len(points) == 0:
+            logger.warning(f"No valid points to create GLB for {glb_filename}")
+            return False
+
+        if colors is not None:
+            # logger.info(f"Adding colors to GLB: shape={colors.shape}, range=[{colors.min():.3f}, {colors.max():.3f}]")
+            pts_rgb = colors
+        else:
+            logger.info("No colors provided, adding default white colors")
+            pts_rgb = np.ones((len(points), 3))
+
+        valid_mask = np.ones(len(points), dtype=bool)
+
+        scene_3d = pmap_to_glb(points, valid_mask, pts_rgb)
+
+        scene_3d.export(glb_filename)
+        # logger.info(f"Saved GLB file using trimesh: {glb_filename}")
+
+        return True
+
+    except Exception as e:
+        logger.error(f"Error creating GLB from pointcloud using trimesh: {str(e)}")
+        return False
+
+
+def process_video(
+    video_file,
+    model_size,
+    num_inference_steps,
+    overlap,
+):
+    global height
+    global width
+    global window_size
+
+    global DKT_PIPELINE_14B
+    global DKT_PIPELINE
+
+    if model_size == "14B":
+        logger.info('14B model is chosen')
+        pipeline = DKT_PIPELINE_14B
+    elif model_size == "1.3B":
+        logger.info('1.3B model is chosen')
+        pipeline = DKT_PIPELINE  # NOTE: undefined while the 1.3B pipeline above stays commented out
+    else:
+        raise ValueError(f"Invalid model size: {model_size}")
+
+    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')
+
+    start_time = time.time()
+
+    prediction_result = pipeline(
+        video_file,
+        negative_prompt=NEGATIVE_PROMPT,
+        height=height,
+        width=width,
+        num_inference_steps=num_inference_steps,
+        overlap=overlap,
+        return_rgb=True,
+        get_moge_intrinsics=False,
+    )
+
+    end_time = time.time()
+    spend_time = end_time - start_time
+    logger.info(f"pipeline spend time: {spend_time:.2f} seconds for depth prediction")
+
+    #* save depth prediction video at the input video's frame rate
+    output_filename = f"output_{timestamp}.mp4"
+    output_path = os.path.join(cur_save_dir, output_filename)
+
+    cap = cv2.VideoCapture(video_file)
+    input_fps = cap.get(cv2.CAP_PROP_FPS)
+    cap.release()
+
+    save_video(prediction_result['colored_depth_map'], output_path, fps=input_fps, quality=8)
+    return output_path
+
+    # # Point cloud visualization code below is commented out
+    # #* vis pc
+    # frame_length = len(prediction_result['rgb_frames'])
+    # vis_pc_num = 4
+    # indices = np.linspace(0, frame_length-1, vis_pc_num)
+    # indices = np.round(indices).astype(np.int32)
+    #
+    # try:
+    #     glb_files = []
+    #     print(f"selective indices: {indices}")
+    #
+    #     if prediction_result['moge_mask'].sum() == 0:
+    #         raise Exception("No valid points to create GLB for")
+    #
+    #     pc_start_time = time.time()
+    #     pcds = DKT_PIPELINE.prediction2pc_v3(prediction_result['depth_map'],
+    #                                          prediction_result['rgb_frames'], indices,
+    #                                          prediction_result['scale'], prediction_result['shift'],
+    #                                          prediction_result['moge_intrinsics'],
+    #                                          prediction_result['moge_mask'], return_pcd=True)
+    #
+    #     pc_end_time = time.time()
+    #     pc_spend_time = pc_end_time - pc_start_time
+    #     print(f"prediction2pc_v2 spend time: {pc_spend_time:.2f} seconds for point cloud extraction, len(pcds): {len(pcds)}")
+    #
+    #     for idx, pcd in enumerate(pcds):
+    #         # points = np.asarray(pcd.points)
+    #         # colors = np.asarray(pcd.colors) if pcd.has_colors() else None
+    #         points = pcd['point']
+    #         colors = pcd['color']
+    #
+    #         logger.info(f'points:{points.shape}')
+    #         print(f'point:{points.shape}')
+    #         if points.shape[0] == 0:
+    #             continue
+    #
+    #         points[:, 2] = -points[:, 2]
+    #         points[:, 0] = -points[:, 0]
+    #
+    #         glb_filename = os.path.join(cur_save_dir, f'{timestamp}_{idx:02d}.glb')
+    #         success = create_simple_glb_from_pointcloud(points, colors, glb_filename)
+    #         if not success:
+    #             logger.warning(f"Failed to save GLB file: {glb_filename}")
+    #             print(f"Failed to save GLB file: {glb_filename}")
+    #
+    #         glb_files.append(glb_filename)
+    # except Exception as e:
+    #     # logger.info(f"len(pcd):{len(pcds)}, idx:{idx}, points.shape:{points.shape}, e: {e}")
+    #     # print(f"len(pcd):{len(pcds)}, idx:{idx}, points.shape:{points.shape}, e: {e}")
+    #     print(e)
+    #
+    # return output_path, glb_files
+
+
+#* gradio creation and initialization
+
+css = """
+#download {
+  height: 118px;
+}
+.slider .inner {
+  width: 5px;
+  background: #FFF;
+}
+.viewport {
+  aspect-ratio: 4/3;
+}
+.tabs button.selected {
+  font-size: 20px !important;
+  color: crimson !important;
+}
+h1 {
+  text-align: center;
+  display: block;
+}
+h2 {
+  text-align: center;
+  display: block;
+}
+h3 {
+  text-align: center;
+  display: block;
+}
+.md_feedback li {
+  margin-bottom: 0px !important;
+}
+"""
+
+head_html = """
+<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
+<script>
+  window.dataLayer = window.dataLayer || [];
+  function gtag() {dataLayer.push(arguments);}
+  gtag('js', new Date());
+  gtag('config', 'G-1FWSVCGZTG');
+</script>
+"""
+
+with gr.Blocks(css=css, title="DKT", head=head_html) as demo:
+    # gr.Markdown(title, elem_classes=["title"])
+    gr.Markdown(
+        """
+        # Diffusion Knows Transparency: Repurposing Video Diffusion for Transparent Object Depth and Normal Estimation
+        <p align="center">
+        <a title="Website" href="https://daniellli.github.io/projects/DKT/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
+        </a>
+        <a title="Github" href="https://github.com/Daniellli/DKT" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <img src="https://img.shields.io/github/stars/Daniellli/DKT?style=social" alt="badge-github-stars">
+        </a>
+        <a title="Social" href="https://x.com/xshocng1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
+        </a>
+        """
+    )
+    # gr.Markdown(description, elem_classes=["description"])
+    # gr.Markdown("### Video Processing Demo", elem_classes=["description"])
+
+    with gr.Row():
+        with gr.Column():
+            input_video = gr.Video(label="Input Video", elem_id='video-display-input')
+
+            model_size = gr.Radio(
+                # choices=["1.3B", "14B"],
+                choices=["14B"],
+                value="14B",
+                label="Model Size"
+            )
+
+            with gr.Accordion("Advanced Parameters", open=False):
+                num_inference_steps = gr.Slider(
+                    minimum=1, maximum=50, value=5, step=1,
+                    label="Number of Inference Steps"
+                )
+                overlap = gr.Slider(
+                    minimum=1, maximum=20, value=3, step=1,
+                    label="Overlap"
+                )
+
+            submit = gr.Button(value="Compute Depth", variant="primary")
+
+        with gr.Column():
+            output_video = gr.Video(
+                label="Depth Outputs",
+                elem_id='video-display-output',
+                autoplay=True
+            )
+            vis_video = gr.Video(
+                label="Visualization Video",
+                visible=False,
+                autoplay=True
+            )
+
+    # # Point cloud visualization UI is commented out
+    # with gr.Row():
+    #     gr.Markdown("### 3D Point Cloud Visualization", elem_classes=["title"])
+    #
+    # with gr.Row(equal_height=True):
+    #     with gr.Column(scale=1):
+    #         output_point_map0 = gr.Model3D(
+    #             label="Point Cloud Key Frame 1",
+    #             clear_color=[1.0, 1.0, 1.0, 1.0],
+    #             interactive=False,
+    #         )
+    #     with gr.Column(scale=1):
+    #         output_point_map1 = gr.Model3D(
+    #             label="Point Cloud Key Frame 2",
+    #             clear_color=[1.0, 1.0, 1.0, 1.0],
+    #             interactive=False
+    #         )
+    #
+    # with gr.Row(equal_height=True):
+    #     with gr.Column(scale=1):
+    #         output_point_map2 = gr.Model3D(
+    #             label="Point Cloud Key Frame 3",
+    #             clear_color=[1.0, 1.0, 1.0, 1.0],
+    #             interactive=False
+    #         )
+    #     with gr.Column(scale=1):
+    #         output_point_map3 = gr.Model3D(
+    #             label="Point Cloud Key Frame 4",
+    #             clear_color=[1.0, 1.0, 1.0, 1.0],
+    #             interactive=False
+    #         )
+
+    def on_submit(video_file, model_size, num_inference_steps, overlap):
+        logger.info('on_submit called')
+        if video_file is None:
+            return None, None
+
+        try:
+            start_time = time.time()
+            output_path = process_video(
+                video_file, model_size, num_inference_steps, overlap
+            )
+            spend_time = time.time() - start_time
+            logger.info(f"Total spend time in on_submit: {spend_time:.2f} seconds")
+            print(f"Total spend time in on_submit: {spend_time:.2f} seconds")
+
+            if output_path is None:
+                return None, None
+
+            # # Point cloud visualization code is commented out
+            # model3d_outputs = [None] * 4
+            # if glb_files and len(glb_files) != 0:
+            #     for i, glb_file in enumerate(glb_files[:4]):
+            #         if os.path.exists(glb_file):
+            #             model3d_outputs[i] = glb_file
+
+            return output_path, None
+
+        except Exception as e:
+            logger.error(e)
+            return None, None
+
+    submit.click(
+        on_submit,
+        inputs=[
+            input_video, model_size, num_inference_steps, overlap
+        ],
+        outputs=[
+            output_video, vis_video
+            # output_point_map0, output_point_map1, output_point_map2, output_point_map3  # point cloud visualization commented out
+        ]
+    )
+
+    def on_example_submit(video_file):
+        """Wrapper function for examples with default parameters"""
+        return on_submit(video_file, "14B", 5, 3)
+
+    examples = gr.Examples(
+        examples=example_inputs,
+        inputs=[input_video],
+        outputs=[
+            output_video, vis_video
+            # output_point_map0, output_point_map1, output_point_map2, output_point_map3  # point cloud visualization commented out
+        ],
+        fn=on_example_submit,
+        examples_per_page=36,
+        cache_examples=False
+    )
+
+
+if __name__ == '__main__':
+    #* main code, model and moge model initialization
+    #* ....
+    demo.queue().launch()
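The app's only advanced parameters are `num_inference_steps` and `overlap`; together with the fixed `window_size = 21` at the top of app.py, `overlap` controls how many frames consecutive inference windows share when a clip is longer than one window. Below is a minimal sketch of how such overlapped windows might be laid out over a frame sequence; `window_indices` is a hypothetical helper, and DKTPipeline's actual chunking and blending logic may differ:

```python
def window_indices(num_frames: int, window_size: int = 21, overlap: int = 3):
    """Illustrative sliding-window layout: each window shares `overlap`
    frames with its predecessor; the last window is clipped to the clip end."""
    stride = window_size - overlap
    windows = []
    for start in range(0, max(num_frames, 1), stride):
        end = min(start + window_size, num_frames)
        windows.append((start, end))
        if end == num_frames:
            break
    return windows

print(window_indices(60))  # [(0, 21), (18, 39), (36, 57), (54, 60)]
```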
requirements.txt CHANGED
@@ -1,6 +1,16 @@
-accelerate
-diffusers
-invisible_watermark
-torch
+torch>=2.0.0
+torchvision
 transformers
-xformers
+imageio[ffmpeg]
+safetensors
+einops
+modelscope
+ftfy
+accelerate
+loguru
+sentencepiece
+spaces
+open3d
+
+-i https://pypi.org/simple/
+--trusted-host pypi.org
+--trusted-host pypi.python.org
+--trusted-host files.pythonhosted.org
+git+https://github.com/microsoft/MoGe.git
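The MoGe dependency is installed straight from GitHub, with the index and trusted-host flags on their own lines, since pip treats `-i`/`--trusted-host` in a requirements file as file-wide options rather than per-requirement ones. After `pip install -r requirements.txt`, a quick import check can confirm the environment; a minimal sketch, where the list mirrors this file plus `cv2` and `trimesh`, which app.py imports even though they do not appear here (presumably preinstalled in the Space image):

```python
import importlib

# Import names, not package names (e.g. opencv-python imports as cv2).
modules = [
    "torch", "torchvision", "transformers", "imageio", "safetensors",
    "einops", "modelscope", "ftfy", "accelerate", "loguru",
    "sentencepiece", "open3d", "cv2", "trimesh", "gradio", "moge",
]

for name in modules:
    try:
        importlib.import_module(name)
        print(f"ok: {name}")
    except ImportError as exc:
        print(f"missing: {name} ({exc})")
```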