shaocong committed on
Commit f75ef88 · 1 Parent(s): f72734a
Files changed (5)
  1. .gitattributes +31 -0
  2. .gitignore +87 -0
  3. README.md +1 -1
  4. app.py +437 -117
  5. requirements.txt +15 -5
.gitattributes CHANGED
@@ -33,4 +33,35 @@
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.egg filter=lfs diff=lfs merge=lfs -text
+**/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/39.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/40.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/1.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/10.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/178db6e89ab682bfc612a3290fec58dd.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/1b0daeb776471c7389b36cee53049417.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/30.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/31.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/8.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/9.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/DJI_20250912164311_0007_D.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/2.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/36.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/32.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/33.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/5.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/8a6dfb8cfe80634f4f77ae9aa830d075.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/b68045aa2128ab63d9c7518f8d62eafe.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/3.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/35.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/69230f105ad8740e08d743a8ee11c651.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/7.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/DJI_20250912163642_0003_D.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/b1f1fa44f414d7731cd7d77751093c44.mp4 filter=lfs diff=lfs merge=lfs -text
+*MOV filter=lfs diff=lfs merge=lfs -text
+*mov filter=lfs diff=lfs merge=lfs -text
+tmp.mp4 filter=lfs diff=lfs merge=lfs -text
+*.MOV filter=lfs diff=lfs merge=lfs -text
 examples/IMG_5703.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,87 @@
+
+
+
+benchmark
+benchmark/*
+
+
+
+*mp4
+!examples/*.mp4
+data/*
+logs/*
+
+
+
+*pyc
+
+checkpoints/*
+
+
+
+
+
+*egg-info
+
+frames
+
+
+
+
+*png
+
+*gif
+
+
+*ipynb
+daniel_tools
+daniel_tools/*
+
+
+*jpg
+
+
+build
+
+
+run*sh
+
+
+.m*
+
+
+
+
+
+
+
+scripts/*
+
+
+
+*.sh
+wandb
+benchmark
+*jsonl
+*json
+*npz
+DKT_models
+trash
+gradio
+tmp*
+*.webp
+*.ico
+*.model
+__pycache__/
+*.pyc
+**/tokenizer_configs/**/vocab.txt
+**/tokenizer_configs/**/spiece.model
+**/tokenizer_configs/**/tokenizer.model
+
+
+
+dist
+build
+
+.gradio
+debug*
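The `*mp4` / `!examples/*.mp4` pair above relies on gitignore negation: the blanket pattern excludes every mp4, and the `!` rule re-includes only the bundled demo clips. A minimal sketch for sanity-checking that interaction, using the third-party `pathspec` package (an assumption; it is not among this Space's dependencies):

```python
# Sanity-check sketch for the ignore rules above, using the third-party
# `pathspec` package (pip install pathspec) -- not a dependency of this Space.
import pathspec

patterns = ["*mp4", "!examples/*.mp4", "data/*", "logs/*"]
spec = pathspec.PathSpec.from_lines("gitwildmatch", patterns)

# "*mp4" ignores every mp4 anywhere, but "!examples/*.mp4" re-includes
# the demo clips, so they stay tracked (via LFS, per .gitattributes).
assert spec.match_file("data/clip.mp4")        # ignored
assert not spec.match_file("examples/1.mp4")   # re-included for the demo
```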
README.md CHANGED
@@ -6,7 +6,7 @@ colorTo: red
 sdk: gradio
 sdk_version: 5.44.0
 app_file: app.py
-pinned: false
+pinned: true
 license: apache-2.0
 short_description: DKT-Depth-14B
 ---
app.py CHANGED
@@ -1,154 +1,474 @@
-import gradio as gr
-import numpy as np
-import random
-
-# import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
-import torch
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
-
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
-
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-
-# @spaces.GPU #[uncomment to use ZeroGPU]
-def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
-):
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
-
-    return image, seed
-
-
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
-
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
-
-        with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
-            run_button = gr.Button("Run", scale=0, variant="primary")
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=2,  # Replace with defaults that work for your model
-                )
-
-    gr.Examples(examples=examples, inputs=[prompt])
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn=infer,
-        inputs=[
-            prompt,
-            negative_prompt,
-            seed,
-            randomize_seed,
-            width,
-            height,
-            guidance_scale,
-            num_inference_steps,
-        ],
-        outputs=[result, seed],
-    )
-
-if __name__ == "__main__":
-    demo.launch()
+import os
+import gradio as gr
+
+import numpy as np
+import torch
+from PIL import Image
+from loguru import logger
+from tqdm import tqdm
+from tools.common_utils import save_video
+from dkt.pipelines.pipeline import DKTPipeline, ModelConfig
+
+import cv2
+import copy
+import trimesh
+
+from os.path import join
+from tools.depth2pcd import depth2pcd
+# from moge.model.v2 import MoGeModel
+
+from tools.eval_utils import transfer_pred_disp2depth, colorize_depth_map
+import datetime
+import tempfile
+import time
+
+
+#* better for bg: logs/outs/train/remote/sft-T2SQNet_glassverse_cleargrasp_HISS_DREDS_DREDS_glassverse_interiorverse-4gpus-origin-lora128-1.3B-rgb_depth-w832-h480-Wan2.1-Fun-Control-2025-10-28-23:26:41/epoch-0-20000.safetensors
+
+NEGATIVE_PROMPT = ''
+height = 480
+width = 832
+window_size = 21
+# DKT_PIPELINE = DKTPipeline()
+DKT_PIPELINE_14B = DKTPipeline(is14B=True)
+# DKT_PIPELINE_14B_NORMAL = DKTPipeline(is14B=True, is_depth=False)
+
+example_inputs = [
+    "examples/1.mp4",
+    "examples/7.mp4",
+    "examples/8.mp4",
+    "examples/39.mp4",
+    "examples/10.mp4",
+    "examples/30.mp4",
+    "examples/35.mp4",
+    "examples/40.mp4",
+    "examples/2.mp4",
+    "examples/4.mp4",
+    "examples/episode_48-camera_head.mp4",
+    "examples/input_20251128_121408.mp4",
+    "examples/input_20251128_122722.mp4",
+    "examples/5eaeaff52b23787a3dc3c610655a49d2.mp4",
+    "examples/9f2909760aff526070f169620ff38290.mp4",
+    "examples/18.mp4",
+    "examples/27.mp4",
+    "examples/28.mp4",
+    "examples/73fc0b2a3af3474de27c7da0bfbf5faa.mp4",
+    "examples/episode_48-camera_third_view.mp4",
+    "examples/extra_5.mp4",
+    "examples/extra_9.mp4",
+    "examples/IMG_5703.MOV",
+    "examples/input_20251202_031811.mp4",
+    "examples/input_20251202_032007.mp4",
+    "examples/teaser_1.mp4",
+    "examples/3.mp4",
+    "examples/teaser_3.mp4",
+    "examples/teaser_7.mp4",
+    "examples/teaser_25.mp4",
+]
+
+
+def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:
+    pts_3d = point_map[valid_mask] * np.array([-1, -1, 1])
+    pts_rgb = frame[valid_mask]
+
+    # Initialize a 3D scene
+    scene_3d = trimesh.Scene()
+
+    # Add point cloud data to the scene
+    point_cloud_data = trimesh.PointCloud(
+        vertices=pts_3d, colors=pts_rgb
+    )
+
+    scene_3d.add_geometry(point_cloud_data)
+    return scene_3d
+
+
+def create_simple_glb_from_pointcloud(points, colors, glb_filename):
+    try:
+        if len(points) == 0:
+            logger.warning(f"No valid points to create GLB for {glb_filename}")
+            return False
+
+        if colors is not None:
+            # logger.info(f"Adding colors to GLB: shape={colors.shape}, range=[{colors.min():.3f}, {colors.max():.3f}]")
+            pts_rgb = colors
+        else:
+            logger.info("No colors provided, adding default white colors")
+            pts_rgb = np.ones((len(points), 3))
+
+        valid_mask = np.ones(len(points), dtype=bool)
+
+        scene_3d = pmap_to_glb(points, valid_mask, pts_rgb)
+
+        scene_3d.export(glb_filename)
+        # logger.info(f"Saved GLB file using trimesh: {glb_filename}")
+
+        return True
+
+    except Exception as e:
+        logger.error(f"Error creating GLB from pointcloud using trimesh: {str(e)}")
+        return False
+
+
+def process_video(
+    video_file,
+    model_size,
+    num_inference_steps,
+    overlap
+):
+    global height
+    global width
+    global window_size
+
+    global DKT_PIPELINE_14B
+    global DKT_PIPELINE
+
+    if model_size == "14B":
+        pipeline = DKT_PIPELINE_14B
+    elif model_size == "1.3B":
+        pipeline = DKT_PIPELINE
+    else:
+        raise ValueError(f"Invalid model size: {model_size}")
+
+    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')
+
+    start_time = time.time()
+
+    prediction_result = pipeline(
+        video_file,
+        negative_prompt=NEGATIVE_PROMPT,
+        height=height,
+        width=width,
+        num_inference_steps=num_inference_steps,
+        overlap=overlap,
+        return_rgb=True,
+        get_moge_intrinsics=False
+    )
+
+    end_time = time.time()
+    spend_time = end_time - start_time
+    logger.info(f"pipeline spend time: {spend_time:.2f} seconds for depth prediction")
+    print(f"pipeline spend time: {spend_time:.2f} seconds for depth prediction")
+
+    #* debug
+    print(f' keys: {prediction_result.keys()}')
+
+    #* save depth predictions video
+    output_filename = f"output_{timestamp}.mp4"
+    output_path = os.path.join(cur_save_dir, output_filename)
+
+    cap = cv2.VideoCapture(video_file)
+    input_fps = cap.get(cv2.CAP_PROP_FPS)
+    cap.release()
+
+    save_video(prediction_result['colored_depth_map'], output_path, fps=input_fps, quality=8)
+    return output_path
+
+    # # Point-cloud visualization code below is commented out
+    # #* vis pc
+    # frame_length = len(prediction_result['rgb_frames'])
+    # vis_pc_num = 4
+    # indices = np.linspace(0, frame_length-1, vis_pc_num)
+    # indices = np.round(indices).astype(np.int32)
+    #
+    # try:
+    #     glb_files = []
+    #     print(f"selective indices: {indices}")
+    #
+    #     if prediction_result['moge_mask'].sum() == 0:
+    #         raise Exception("No valid points to create GLB for")
+    #
+    #     pc_start_time = time.time()
+    #     pcds = DKT_PIPELINE.prediction2pc_v3(prediction_result['depth_map'],
+    #         prediction_result['rgb_frames'], indices,
+    #         prediction_result['scale'], prediction_result['shift'], prediction_result['moge_intrinsics'],
+    #         prediction_result['moge_mask'], return_pcd=True)
+    #
+    #     pc_end_time = time.time()
+    #     pc_spend_time = pc_end_time - pc_start_time
+    #     print(f"prediction2pc_v2 spend time: {pc_spend_time:.2f} seconds for point cloud extraction, len(pcds): {len(pcds)}")
+    #
+    #     for idx, pcd in enumerate(pcds):
+    #         # points = np.asarray(pcd.points)
+    #         # colors = np.asarray(pcd.colors) if pcd.has_colors() else None
+    #         points = pcd['point']
+    #         colors = pcd['color']
+    #
+    #         logger.info(f'points:{points.shape} ')
+    #         print(f'point:{points.shape}')
+    #         if points.shape[0] == 0:
+    #             continue
+    #
+    #         points[:, 2] = -points[:, 2]
+    #         points[:, 0] = -points[:, 0]
+    #
+    #         glb_filename = os.path.join(cur_save_dir, f'{timestamp}_{idx:02d}.glb')
+    #         success = create_simple_glb_from_pointcloud(points, colors, glb_filename)
+    #         if not success:
+    #             logger.warning(f"Failed to save GLB file: {glb_filename}")
+    #             print(f"Failed to save GLB file: {glb_filename}")
+    #
+    #         glb_files.append(glb_filename)
+    # except Exception as e:
+    #     # logger.info(f" len(pcd):{len(pcds)},idx:{idx}, points.shape:{points.shape} e: {e}")
+    #     # print(f"len(pcd):{len(pcds)}, idx:{idx}, points.shape:{points.shape}, e: {e}, ")
+    #     print(e)
+    #
+    # return output_path, glb_files
+
+
+#* gradio creation and initialization
+
+css = """
+#download {
+    height: 118px;
+}
+.slider .inner {
+    width: 5px;
+    background: #FFF;
+}
+.viewport {
+    aspect-ratio: 4/3;
+}
+.tabs button.selected {
+    font-size: 20px !important;
+    color: crimson !important;
+}
+h1 {
+    text-align: center;
+    display: block;
+}
+h2 {
+    text-align: center;
+    display: block;
+}
+h3 {
+    text-align: center;
+    display: block;
+}
+.md_feedback li {
+    margin-bottom: 0px !important;
+}
+"""
+
+head_html = """
+<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+    function gtag() {dataLayer.push(arguments);}
+    gtag('js', new Date());
+    gtag('config', 'G-1FWSVCGZTG');
+</script>
+"""
+
+with gr.Blocks(css=css, title="DKT", head=head_html) as demo:
+    # gr.Markdown(title, elem_classes=["title"])
+    gr.Markdown(
+        """
+        # Diffusion Knows Transparency: Repurposing Video Diffusion for Transparent Object Depth and Normal Estimation
+        <p align="center">
+        <a title="Website" href="https://daniellli.github.io/projects/DKT/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
+        </a>
+        <a title="Github" href="https://github.com/Daniellli/DKT" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <img src="https://img.shields.io/github/stars/Daniellli/DKT?style=social" alt="badge-github-stars">
+        </a>
+        <a title="Social" href="https://x.com/xshocng1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
+        </a>
+        """
+    )
+    # gr.Markdown(description, elem_classes=["description"])
+    # gr.Markdown("### Video Processing Demo", elem_classes=["description"])
+
+    with gr.Row():
+        with gr.Column():
+            input_video = gr.Video(label="Input Video", elem_id='video-display-input')
+
+            model_size = gr.Radio(
+                # choices=["1.3B", "14B"],
+                choices=["14B"],
+                value="14B",
+                label="Model Size"
+            )
+
+            with gr.Accordion("Advanced Parameters", open=False):
+                num_inference_steps = gr.Slider(
+                    minimum=1, maximum=50, value=5, step=1,
+                    label="Number of Inference Steps"
+                )
+                overlap = gr.Slider(
+                    minimum=1, maximum=20, value=3, step=1,
+                    label="Overlap"
+                )
+
+            submit = gr.Button(value="Compute Depth", variant="primary")
+
+        with gr.Column():
+            output_video = gr.Video(
+                label="Depth Outputs",
+                elem_id='video-display-output',
+                autoplay=True
+            )
+            vis_video = gr.Video(
+                label="Visualization Video",
+                visible=False,
+                autoplay=True
+            )
+
+    # # Point-cloud visualization UI below is commented out
+    # with gr.Row():
+    #     gr.Markdown("### 3D Point Cloud Visualization", elem_classes=["title"])
+    #
+    # with gr.Row(equal_height=True):
+    #     with gr.Column(scale=1):
+    #         output_point_map0 = gr.Model3D(
+    #             label="Point Cloud Key Frame 1",
+    #             clear_color=[1.0, 1.0, 1.0, 1.0],
+    #             interactive=False,
+    #         )
+    #     with gr.Column(scale=1):
+    #         output_point_map1 = gr.Model3D(
+    #             label="Point Cloud Key Frame 2",
+    #             clear_color=[1.0, 1.0, 1.0, 1.0],
+    #             interactive=False
+    #         )
+    #
+    # with gr.Row(equal_height=True):
+    #     with gr.Column(scale=1):
+    #         output_point_map2 = gr.Model3D(
+    #             label="Point Cloud Key Frame 3",
+    #             clear_color=[1.0, 1.0, 1.0, 1.0],
+    #             interactive=False
+    #         )
+    #     with gr.Column(scale=1):
+    #         output_point_map3 = gr.Model3D(
+    #             label="Point Cloud Key Frame 4",
+    #             clear_color=[1.0, 1.0, 1.0, 1.0],
+    #             interactive=False
+    #         )
+
+    def on_submit(video_file, model_size, num_inference_steps, overlap):
+        logger.info('on_submit is calling')
+        if video_file is None:
+            return None, None
+
+        try:
+            start_time = time.time()
+            output_path = process_video(
+                video_file, model_size, num_inference_steps, overlap
+            )
+            spend_time = time.time() - start_time
+            logger.info(f"Total spend time in on_submit: {spend_time:.2f} seconds")
+            print(f"Total spend time in on_submit: {spend_time:.2f} seconds")
+
+            if output_path is None:
+                return None, None
+
+            # # Point-cloud visualization outputs are commented out
+            # model3d_outputs = [None] * 4
+            # if glb_files and len(glb_files) != 0:
+            #     for i, glb_file in enumerate(glb_files[:4]):
+            #         if os.path.exists(glb_file):
+            #             model3d_outputs[i] = glb_file
+
+            return output_path, None
+
+        except Exception as e:
+            logger.error(e)
+            return None, None
+
+    submit.click(
+        on_submit,
+        inputs=[
+            input_video, model_size, num_inference_steps, overlap
+        ],
+        outputs=[
+            output_video, vis_video
+            # output_point_map0, output_point_map1, output_point_map2, output_point_map3  # point-cloud visualization commented out
+        ]
+    )
+
+    def on_example_submit(video_file):
+        """Wrapper function for examples with default parameters"""
+        return on_submit(video_file, "14B", 5, 3)
+
+    examples = gr.Examples(
+        examples=example_inputs,
+        inputs=[input_video],
+        outputs=[
+            output_video, vis_video
+            # output_point_map0, output_point_map1, output_point_map2, output_point_map3  # point-cloud visualization commented out
+        ],
+        fn=on_example_submit,
+        examples_per_page=36,
+        cache_examples=False
+    )
+
+
+if __name__ == '__main__':
+    #* main code, model and moge model initialization
+    #* ........
+    demo.queue().launch()
+
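For reference, a minimal headless sketch of the same inference path the new app.py wires into Gradio. It reuses only names that appear in the diff above (DKTPipeline, save_video, the colored_depth_map key), so the argument values should be read as the demo's defaults rather than a documented API; the output fps is hard-coded here, whereas app.py copies the input video's fps via cv2.VideoCapture.

```python
# Headless sketch of the demo's inference path, assuming the repo's
# dkt/ and tools/ modules are importable; all names come from app.py above.
from dkt.pipelines.pipeline import DKTPipeline
from tools.common_utils import save_video

pipeline = DKTPipeline(is14B=True)      # 14B variant, as used by the Space

result = pipeline(
    "examples/1.mp4",                   # one of the bundled example clips
    negative_prompt="",
    height=480, width=832,              # demo resolution
    num_inference_steps=5,              # slider default
    overlap=3,                          # sliding-window overlap default
    return_rgb=True,
    get_moge_intrinsics=False,
)

# The pipeline returns a dict; write the colorized depth frames out the same
# way process_video() does (fps=30 is an assumption -- app.py uses input fps).
save_video(result["colored_depth_map"], "depth_output.mp4", fps=30, quality=8)
```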
requirements.txt CHANGED
@@ -1,6 +1,16 @@
-accelerate
-diffusers
-invisible_watermark
-torch
+torch>=2.0.0
+torchvision
 transformers
-xformers
+imageio
+imageio[ffmpeg]
+safetensors
+einops
+modelscope
+ftfy
+accelerate
+loguru
+sentencepiece
+spaces
+open3d
+
+git+https://github.com/microsoft/MoGe.git -i https://pypi.org/simple/ --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host files.pythonhosted.org
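After installing with `pip install -r requirements.txt`, a quick smoke test can confirm the stack resolves. This is a hypothetical check, not part of the commit; the module names are taken from the list above, assuming each package's import name matches its distribution name (which holds for these), and `moge` is the import name implied by the commented `from moge.model.v2 import MoGeModel` line in app.py.

```python
# Hypothetical post-install smoke test: import each dependency listed above.
import importlib

deps = [
    "torch", "torchvision", "transformers", "imageio", "safetensors",
    "einops", "modelscope", "ftfy", "accelerate", "loguru",
    "sentencepiece", "spaces", "open3d",
    "moge",  # provided by the git+https MoGe requirement
]

for name in deps:
    importlib.import_module(name)
    print(f"ok: {name}")
```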