yyang181 commited on
Commit
95257c4
·
1 Parent(s): 0580cf1
Files changed (4) hide show
  1. .gitignore +1 -1
  2. app.py +182 -405
  3. inference/data/test_datasets.py +104 -24
  4. test.py +235 -174
.gitignore CHANGED
@@ -11,7 +11,7 @@ Pytorch-Correlation-extension/
11
  result
12
  src/
13
  DINOv2FeatureV6_LocalAtten_s2_154000.pth
14
- example/
15
 
16
  # Byte-compiled / optimized / DLL files
17
  __pycache__/
 
11
  result
12
  src/
13
  DINOv2FeatureV6_LocalAtten_s2_154000.pth
14
+ _colormnet_tmp/
15
 
16
  # Byte-compiled / optimized / DLL files
17
  __pycache__/
app.py CHANGED
@@ -1,432 +1,177 @@
1
- # app.py (aligned to main.py logic; keeps debug hooks; Gradio-safe DataLoader)
2
- # Inputs: (1) Black-and-white video (mp4/webm/avi) (2) Reference image (RGB)
3
- # Output: Colored video (mp4)
4
- #
5
- # Model checkpoint is HARD-CODED as required:
6
- # https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth
7
 
8
  import os
9
  import sys
10
  import shutil
11
- import subprocess
12
- import uuid
13
  import urllib.request
14
- import warnings
15
  from os import path
16
- from progressbar import progressbar
17
- import gc
18
-
19
- # # 1) 完全禁止 PyTorch 调用 NVML(ZeroGPU/MIG 下经常拿不到 NVML 句柄)
20
- # os.environ.setdefault("PYTORCH_NO_NVML", "1")
21
- # # 2) 用 cudaMallocAsync 后端,降低碎片/避免旧分配器的 NVML 路径
22
- # os.environ.setdefault(
23
- # "PYTORCH_CUDA_ALLOC_CONF",
24
- # "backend:cudaMallocAsync,expandable_segments:True,garbage_collection_threshold:0.9,max_split_size_mb:64"
25
- # )
26
- # # (可选)定位更准:同步执行
27
- # os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
28
- # warnings.filterwarnings("ignore", message="The detected CUDA version .* minor version mismatch")
29
- # warnings.filterwarnings("ignore", message="There are no g\\+\\+ version bounds defined for CUDA version.*")
30
- # warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.cpp_extension")
31
- # os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
32
- # os.environ.setdefault("MAX_JOBS", "1")
33
 
34
  import gradio as gr
35
- import spaces # ZeroGPU decorator
36
- import numpy as np
37
  from PIL import Image
38
  import cv2
39
- import traceback
40
 
41
- import torch
42
- import torch.nn.functional as F
43
- from torch.utils.data import DataLoader
44
-
45
- # ---- Project imports ----
46
- from inference.data.test_datasets import DAVISTestDataset_221128_TransColorization_batch
47
- from inference.data.mask_mapper import MaskMapper
48
- from model.network import ColorMNet
49
- from inference.inference_core import InferenceCore
50
- from dataset.range_transform import inv_lll2rgb_trans
51
- from skimage import color
52
-
53
- # ----------------- CONFIG -----------------
54
  CHECKPOINT_URL = "https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth"
55
  CHECKPOINT_LOCAL = "DINOv2FeatureV6_LocalAtten_s2_154000.pth"
56
 
57
  TITLE = "ColorMNet — ZeroGPU (CUDA-only) Video Colorization with Reference Image"
58
  DESC = """
59
  上传**黑白视频**与**参考图像**,点击“开始着色”。
60
- Space **仅在 ZeroGPU(CUDA)** 上运行;若未分配到 GPU,会报错提示
61
- 模型权重已固定链接(如需修改,请编辑 `CHECKPOINT_URL`)。
62
- **数据集结构**
63
- - 抽帧 -> `./colormnet_run_<UUID>/input_video/<视频名不含扩展>/00000.png...`
64
- - 参考图 -> `./colormnet_run_<UUID>/input_ref/<视频名不含扩展>/ref.png`
 
65
  """
66
 
67
  # ----------------- TEMP WORKDIR -----------------
68
  TEMP_ROOT = path.join(os.getcwd(), "_colormnet_tmp")
 
 
 
69
 
70
  def reset_temp_root():
71
  """每次运行前清空并重建临时工作目录。"""
72
  if path.isdir(TEMP_ROOT):
73
  shutil.rmtree(TEMP_ROOT, ignore_errors=True)
74
  os.makedirs(TEMP_ROOT, exist_ok=True)
 
 
75
 
76
- torch.set_grad_enabled(False)
77
-
78
- # ----------------- DEBUG (kept) -----------------
79
- def _enable_runtime_debug():
80
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # 同步执行,定位准确
81
- os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1" # 显示 C++ 栈
82
- os.environ["PYTORCH_JIT"] = "0" # 关闭 JIT
83
- try:
84
- torch.autograd.set_detect_anomaly(True) # 捕捉无效 op/grad
85
- except Exception:
86
- pass
87
-
88
- # ----------------- PATH/DIR UTILS -----------------
89
- def ensure_clean_dir(d: str):
90
- if path.exists(d):
91
- if path.isdir(d):
92
- return
93
- else:
94
- os.remove(d)
95
  os.makedirs(d, exist_ok=True)
96
 
97
- # ----------------- MISC UTILS -----------------
98
  def ensure_checkpoint():
99
- if not path.exists(CHECKPOINT_LOCAL):
100
- print(f"[INFO] Downloading checkpoint from: {CHECKPOINT_URL}")
101
- urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_LOCAL)
102
- print("[INFO] Checkpoint downloaded:", CHECKPOINT_LOCAL)
103
-
104
- def detach_to_cpu(x: torch.Tensor) -> torch.Tensor:
105
- return x.detach().cpu()
106
-
107
- def tensor_to_np_float(image: torch.Tensor) -> np.ndarray:
108
- image_np = image.numpy().astype("float32")
109
- return image_np
110
-
111
- def lab2rgb_transform_PIL(mask: torch.Tensor) -> np.ndarray:
112
- mask_d = detach_to_cpu(mask)
113
- mask_d = inv_lll2rgb_trans(mask_d)
114
- im = tensor_to_np_float(mask_d)
115
- if len(im.shape) == 3:
116
- im = im.transpose((1, 2, 0))
117
- else:
118
- im = im[:, :, None]
119
- im = color.lab2rgb(im)
120
- return im.clip(0, 1)
121
 
122
- # ---------- extract frames: dataset-root/<video_stem>/00000.png ----------
123
- def video_to_dataset_root(video_path: str, dataset_root: str):
124
  """
125
- 将单个视频抽帧到 dataset_root/<video_stem>/00000.png...
126
- 返回: (subdir_path, video_stem, width, height, fps, frame_count)
127
  """
128
- ensure_clean_dir(dataset_root)
129
- basename = path.basename(video_path)
130
- stem, _ = path.splitext(basename)
131
- subdir = path.join(dataset_root, stem)
132
- ensure_clean_dir(subdir)
133
-
134
  cap = cv2.VideoCapture(video_path)
135
  assert cap.isOpened(), f"Cannot open video: {video_path}"
136
-
137
- fps = cap.get(cv2.CAP_PROP_FPS)
138
- if not fps or fps <= 0:
139
- fps = 25.0
140
-
141
  idx = 0
142
  w = h = None
143
-
144
  while True:
145
  ret, frame = cap.read()
146
  if not ret:
147
  break
148
  if frame is None:
149
  continue
150
-
151
  h, w = frame.shape[:2]
152
- out_path = path.join(subdir, f"{idx:05d}.png")
153
-
154
- parent = path.dirname(out_path)
155
- if not path.isdir(parent):
156
- if path.exists(parent):
157
- os.remove(parent)
158
- os.makedirs(parent, exist_ok=True)
159
-
160
  ok = cv2.imwrite(out_path, frame)
161
  if not ok:
162
  raise RuntimeError(f"写入抽帧失败: {out_path}")
163
  idx += 1
164
-
165
  cap.release()
166
  if idx == 0:
167
  raise RuntimeError("Input video has no readable frames.")
168
-
169
- return subdir, path.splitext(path.basename(video_path))[0], w, h, fps, idx
170
-
171
- # ---------- place ref image into ref_root/<video_stem>/ref.png ----------
172
- def ref_to_dataset_root(ref_image_path: str, ref_root: str, video_stem: str):
173
- ensure_clean_dir(ref_root)
174
- subdir = path.join(ref_root, video_stem)
175
- ensure_clean_dir(subdir)
176
-
177
- img = Image.open(ref_image_path).convert("RGB")
178
- out_path = path.join(subdir, "ref.png")
179
- img.save(out_path)
180
- return subdir
181
 
182
  def encode_frames_to_video(frames_dir: str, out_path: str, fps: float):
183
  frames = sorted([f for f in os.listdir(frames_dir) if f.lower().endswith(".png")])
184
- assert len(frames) > 0, "No frames to encode."
185
-
186
  first = cv2.imread(path.join(frames_dir, frames[0]))
 
 
187
  h, w = first.shape[:2]
188
-
189
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
190
  vw = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
191
  for f in frames:
192
  img = cv2.imread(path.join(frames_dir, f))
 
 
193
  vw.write(img)
194
  vw.release()
195
 
196
- # ----------------- MAIN PIPELINE (CUDA-only) -----------------
197
- def run_pipeline_cuda(bw_video_path: str, ref_image_path: str, user_config: dict, debug_shapes: bool) -> str:
198
- print(bw_video_path, ref_image_path)
199
- if not torch.cuda.is_available():
200
- raise RuntimeError("未检测到 GPU。此 Space 仅支持 ZeroGPU (CUDA)。")
201
-
202
- if debug_shapes:
203
- _enable_runtime_debug()
204
-
205
- ensure_checkpoint()
206
-
207
- DEVICE = torch.device("cuda")
208
-
209
- # Workspace in TEMP_ROOT
210
- base_run_dir = path.join(TEMP_ROOT, f"colormnet_run_{uuid.uuid4().hex}")
211
- input_video_root = path.join(base_run_dir, "input_video")
212
- input_ref_root = path.join(base_run_dir, "input_ref")
213
- output_dir = path.join(base_run_dir, "result")
214
-
215
- for p in (base_run_dir, input_video_root, input_ref_root, output_dir):
216
- ensure_clean_dir(p)
217
-
218
- # 1) 抽帧(把抽帧输出到临时目录中)
219
- vid_subdir, vid_stem, w, h, fps, n_frames = video_to_dataset_root(bw_video_path, input_video_root)
220
- assert n_frames > 0, "Input video has no frames."
221
-
222
- # 2) 参考图(存到临时目录)
223
- _ = ref_to_dataset_root(ref_image_path, input_ref_root, vid_stem)
224
-
225
- # 3) 配置(字段与 main.py 一致;值从 UI 合并)
226
- default_config = {
227
- "FirstFrameIsNotExemplar": True,
228
- "d16_batch_path": "input", # parity only
229
- "ref_path": "ref", # parity only
230
- "output": "result", # parity only
231
- "generic_path": None,
232
- "dataset": "D16_batch",
233
- "split": "val",
234
- "save_all": True,
235
- "benchmark": False,
236
- "disable_long_term": False,
237
- "max_mid_term_frames": 10,
238
- "min_mid_term_frames": 5,
239
- "max_long_term_elements": 10000,
240
- "num_prototypes": 128,
241
- "top_k": 30,
242
- "mem_every": 5,
243
- "deep_update_every": -1,
244
- "save_scores": False,
245
- "flip": False,
246
- "size": -1,
247
- "reverse": False,
248
- }
249
- config = {**default_config, **(user_config or {})}
250
- config["enable_long_term"] = not config["disable_long_term"]
251
-
252
- # 4) 构建数据集(只选本视频 reader)
253
- meta_dataset = DAVISTestDataset_221128_TransColorization_batch(
254
- input_video_root, imset=input_ref_root, size=config["size"]
255
- )
256
- meta_loader = meta_dataset.get_datasets()
257
-
258
- # 输出路径规则(与 main.py 一致)
259
- is_youtube = str(config["dataset"]).startswith("Y")
260
- is_davis = str(config["dataset"]).startswith("D")
261
- is_lv = str(config["dataset"]).startswith("LV")
262
-
263
- app_output_root = output_dir
264
- if is_youtube or config["save_scores"]:
265
- out_path = path.join(app_output_root, "Annotations")
266
- else:
267
- out_path = app_output_root
268
-
269
- # 5) 模型(保持 app 的 URL 权重加载方式)
270
- network = ColorMNet(config, CHECKPOINT_LOCAL).to(DEVICE).eval()
271
- model_weights = torch.load(CHECKPOINT_LOCAL, map_location="cuda")
272
- network.load_weights(model_weights, init_as_zero_if_needed=True)
273
-
274
- total_process_time = 0.0
275
- total_frames = 0
276
-
277
- for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
278
- # 6) 推理(逐帧;内部逻辑与 main.py 对齐;保留调试打印)
279
- # Gradio/Spaces 环境禁止子进程:num_workers=0(否则会触发 daemonic processes 错误)
280
- loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
281
- vid_name = vid_reader.vid_name
282
- vid_length = len(loader)
283
-
284
- # 长时记忆触发逻辑:按 main.py 原样(无除零保护)
285
- config['enable_long_term_count_usage'] = (
286
- config['enable_long_term'] and
287
- (vid_length
288
- / (config['max_mid_term_frames'] - config['min_mid_term_frames'])
289
- * config['num_prototypes'])
290
- >= config['max_long_term_elements']
291
- )
292
-
293
- mapper = MaskMapper()
294
- processor = InferenceCore(network, config=config)
295
- first_mask_loaded = False
296
-
297
- for ti, data in enumerate(loader):
298
- try:
299
- with torch.cuda.amp.autocast(enabled=not config["benchmark"]):
300
- rgb = data['rgb'].cuda()[0]
301
- msk = data.get('mask')
302
- if not config['FirstFrameIsNotExemplar']:
303
- msk = msk[:, 1:3, :, :] if msk is not None else None
304
-
305
- info = data['info']
306
- frame = info['frame'][0]
307
- shape = info['shape']
308
- need_resize = info['need_resize'][0]
309
-
310
- if debug_shapes:
311
- print(f"[Loop] frame={ti} rgb={tuple(rgb.shape)} "
312
- f"msk={None if msk is None else tuple(msk.shape)}", flush=True)
313
-
314
- # timing 与 main.py 一致
315
- start = torch.cuda.Event(enable_timing=True)
316
- end = torch.cuda.Event(enable_timing=True)
317
- start.record()
318
-
319
- if not first_mask_loaded:
320
- if msk is not None:
321
- first_mask_loaded = True
322
- else:
323
- continue
324
-
325
- if config['flip']:
326
- rgb = torch.flip(rgb, dims=[-1])
327
- msk = torch.flip(msk, dims=[-1]) if msk is not None else None
328
-
329
- if msk is not None:
330
- msk = torch.Tensor(msk[0]).cuda()
331
- if need_resize:
332
- msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
333
- processor.set_all_labels(list(range(1, 3)))
334
- labels = range(1, 3)
335
- else:
336
- labels = None
337
-
338
- if config['FirstFrameIsNotExemplar']:
339
- prob = processor.step_AnyExemplar(
340
- rgb,
341
- msk[:1, :, :].repeat(3, 1, 1) if msk is not None else None,
342
- msk[1:3, :, :] if msk is not None else None,
343
- labels,
344
- end=(ti == vid_length - 1)
345
- )
346
- else:
347
- prob = processor.step(rgb, msk, labels, end=(ti == vid_length - 1))
348
-
349
- if need_resize:
350
- prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, 0]
351
-
352
- end.record()
353
- torch.cuda.synchronize()
354
- total_process_time += (start.elapsed_time(end) / 1000.0)
355
- total_frames += 1
356
-
357
- if config['flip']:
358
- prob = torch.flip(prob, dims=[-1])
359
-
360
- if debug_shapes:
361
- try:
362
- print(f"[Loop] prob={tuple(prob.shape)}", flush=True)
363
- except Exception:
364
- pass
365
-
366
- if config['save_scores']:
367
- prob = (prob.detach().cpu().numpy() * 255).astype(np.uint8)
368
-
369
- if config['save_all'] or info['save'][0]:
370
- this_out_path = path.join(out_path, vid_name)
371
- os.makedirs(this_out_path, exist_ok=True)
372
-
373
- out_mask_final = lab2rgb_transform_PIL(torch.cat([rgb[:1, :, :], prob], dim=0))
374
- out_mask_final = (out_mask_final * 255).astype(np.uint8)
375
- Image.fromarray(out_mask_final).save(os.path.join(this_out_path, frame[:-4] + '.png'))
376
-
377
- except Exception as _e:
378
- # 保留完整 traceback,方便定位
379
- raise RuntimeError("FRAME_ERROR:\n" + traceback.format_exc())
380
-
381
- if total_process_time > 0:
382
- print(f'Total processing time: {total_process_time}')
383
- print(f'Total processed frames: {total_frames}')
384
- print(f'FPS: {total_frames / total_process_time}')
385
- print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')
386
-
387
- # 7) 合成 mp4(按 main.py 的 out_path 规则找帧目录)
388
- frames_dir = path.join(out_path, vid_stem if path.isdir(path.join(out_path, vid_stem)) else vid_name)
389
- if not path.isdir(frames_dir):
390
- subs = [d for d in os.listdir(out_path) if path.isdir(path.join(out_path, d))]
391
- if len(subs) == 1:
392
- frames_dir = path.join(out_path, subs[0])
393
  else:
394
- frames_dir = path.join(output_dir, vid_stem)
395
-
396
- colored_mp4 = path.join(base_run_dir, "colored_output.mp4")
397
- encode_frames_to_video(frames_dir, colored_mp4, fps=fps)
398
-
399
- # 8) 输出视频到 CWD(只保留最终文件)
400
- final_mp4 = path.join(os.getcwd(), "result.mp4")
401
- shutil.move(colored_mp4, final_mp4)
402
 
403
- # 清理本次 run 的中间目录;(注:上传的原视频/参考帧位于 TEMP_ROOT,将在下次运行开头被 reset_temp_root 清掉)
404
- shutil.rmtree(base_run_dir, ignore_errors=True)
405
-
406
- return final_mp4
407
-
408
- # ----------------- GRADIO HANDLERS -----------------
409
- @spaces.GPU(duration=600)
410
  def gradio_infer(
411
- debug_shapes, # 调���开关(保留)
412
  bw_video, ref_image,
413
  first_not_exemplar, dataset, split, save_all, benchmark,
414
  disable_long_term, max_mid, min_mid, max_long,
415
  num_proto, top_k, mem_every, deep_update,
416
- save_scores, flip, size, reverse # 新增
417
  ):
418
- if not torch.cuda.is_available():
419
- return None, "ZeroGPU 未分配到 GPU,请重试(或检查 Space 硬件是否为 ZeroGPU)。"
420
-
421
  if bw_video is None:
422
  return None, "请上传黑白视频。"
423
  if ref_image is None:
424
  return None, "请上传参考图像。"
425
-
426
- # —— 每次运行先重置临时目录 —— #
427
  reset_temp_root()
428
 
429
- # Video path -> 拷贝到临时
430
  if isinstance(bw_video, dict) and "name" in bw_video:
431
  src_video_path = bw_video["name"]
432
  elif isinstance(bw_video, str):
@@ -434,28 +179,40 @@ def gradio_infer(
434
  else:
435
  return None, "无法读取视频输入。"
436
 
437
- tmp_video_ext = path.splitext(src_video_path)[1] or ".mp4"
438
- tmp_video_path = path.join(TEMP_ROOT, "input_video" + tmp_video_ext)
 
 
 
 
 
 
 
 
 
 
 
439
  try:
440
- shutil.copy2(src_video_path, tmp_video_path)
441
  except Exception as e:
442
- return None, f"复制视频到临时目录失败:{e}"
443
 
444
- # Ref path -> 保存/拷贝到临时目录
445
- tmp_ref_path = path.join(TEMP_ROOT, "ref.png")
446
  if isinstance(ref_image, Image.Image):
447
  try:
448
- ref_image.save(tmp_ref_path)
449
  except Exception as e:
450
- return None, f"保存参考图像到临时目录失败:{e}"
451
  elif isinstance(ref_image, str):
452
  try:
453
- shutil.copy2(ref_image, tmp_ref_path)
454
  except Exception as e:
455
- return None, f"复制参考图像到临时目录失败:{e}"
456
  else:
457
  return None, "无法读取参考图像输入。"
458
 
 
459
  default_config = {
460
  "FirstFrameIsNotExemplar": True,
461
  "dataset": "D16_batch",
@@ -473,8 +230,8 @@ def gradio_infer(
473
  "save_scores": False,
474
  "flip": False,
475
  "size": -1,
 
476
  }
477
-
478
  user_config = {
479
  "FirstFrameIsNotExemplar": bool(first_not_exemplar) if first_not_exemplar is not None else default_config["FirstFrameIsNotExemplar"],
480
  "dataset": str(dataset) if dataset else default_config["dataset"],
@@ -492,71 +249,92 @@ def gradio_infer(
492
  "save_scores": bool(save_scores) if save_scores is not None else default_config["save_scores"],
493
  "flip": bool(flip) if flip is not None else default_config["flip"],
494
  "size": int(size) if size is not None else default_config["size"],
495
- "reverse": bool(reverse) if reverse is not None else False,
496
  }
497
 
 
 
 
 
498
  try:
499
- out_mp4 = run_pipeline_cuda(
500
- tmp_video_path, tmp_ref_path, user_config, debug_shapes=bool(debug_shapes)
501
- )
502
- return out_mp4, "完成 ✅"
503
- except subprocess.CalledProcessError as e:
504
- # 出错也可以顺手清一下临时目录(可选)
505
- try: shutil.rmtree(TEMP_ROOT, ignore_errors=True)
506
- except: pass
507
- return None, f"运行时错误:\n{e}"
508
  except Exception as e:
509
- try: shutil.rmtree(TEMP_ROOT, ignore_errors=True)
510
- except: pass
511
- return None, f"{e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
 
513
  # ----------------- UI -----------------
514
  with gr.Blocks() as demo:
515
  gr.Markdown(f"# {TITLE}")
516
  gr.Markdown(DESC)
517
 
518
- debug_shapes = gr.Checkbox(label="调试日志(打印形状与完整Traceback)", value=False)
519
 
520
  with gr.Row():
521
  inp_video = gr.Video(label="黑白视频(mp4/webm/avi)", interactive=True)
522
  inp_ref = gr.Image(label="参考图像(RGB)", type="pil")
523
-
524
  gr.Examples(
525
  label="示例输入",
526
- examples=[
527
- ["./example/4.mp4", "./example/4.png"],
528
- ],
529
  inputs=[inp_video, inp_ref],
530
- # 不缓存,避免把推理结果当静态示例
531
  cache_examples=False,
532
  )
533
 
534
- with gr.Accordion("高级参数设置( main.py 对齐)", open=False):
535
  with gr.Row():
536
- first_not_exemplar = gr.Checkbox(label="FirstFrameIsNotExemplar", value=True)
537
- reverse = gr.Checkbox(label="reverse", value=False)
538
- dataset = gr.Textbox(label="dataset", value="D16_batch")
539
- split = gr.Textbox(label="split", value="val")
540
- save_all = gr.Checkbox(label="save_all", value=True)
541
- benchmark = gr.Checkbox(label="benchmark", value=False)
542
  with gr.Row():
543
- disable_long_term = gr.Checkbox(label="disable_long_term", value=False)
544
- max_mid = gr.Number(label="max_mid_term_frames", value=10, precision=0)
545
- min_mid = gr.Number(label="min_mid_term_frames", value=5, precision=0)
546
- max_long = gr.Number(label="max_long_term_elements", value=10000, precision=0)
547
- num_proto = gr.Number(label="num_prototypes", value=128, precision=0)
548
  with gr.Row():
549
- top_k = gr.Number(label="top_k", value=30, precision=0)
550
- mem_every = gr.Number(label="mem_every", value=5, precision=0)
551
- deep_update = gr.Number(label="deep_update_every", value=-1, precision=0)
552
- save_scores = gr.Checkbox(label="save_scores", value=False)
553
- flip = gr.Checkbox(label="flip", value=False)
554
- size = gr.Number(label="size", value=-1, precision=0)
555
-
556
- run_btn = gr.Button("开始着色(ZeroGPU 推理)")
557
  with gr.Row():
558
  out_video = gr.Video(label="输出视频(着色结果)")
559
- status = gr.Textbox(label="状态 / 调试输出", interactive=False, lines=12)
560
 
561
  run_btn.click(
562
  fn=gradio_infer,
@@ -566,7 +344,7 @@ with gr.Blocks() as demo:
566
  first_not_exemplar, dataset, split, save_all, benchmark,
567
  disable_long_term, max_mid, min_mid, max_long,
568
  num_proto, top_k, mem_every, deep_update,
569
- save_scores, flip, size, reverse # reverse 已接入
570
  ],
571
  outputs=[out_video, status]
572
  )
@@ -576,5 +354,4 @@ if __name__ == "__main__":
576
  ensure_checkpoint()
577
  except Exception as e:
578
  print(f"[WARN] 预下载权重失败(首次推理会再试): {e}")
579
-
580
- demo.queue().launch(server_name="0.0.0.0", server_port=7860)
 
1
+ # app.py Gradio front-end that calls test.py IN-PROCESS (ZeroGPU-safe)
2
+ # Folder layout per run (under TEMP_ROOT):
3
+ # input_video/<video_stem>/00000.png ...
4
+ # ref/<video_stem>/ref.png
5
+ # output/<video_stem>/*.png
6
+ # Final mp4: TEMP_ROOT/<video_stem>.mp4
7
 
8
  import os
9
  import sys
10
  import shutil
 
 
11
  import urllib.request
 
12
  from os import path
13
+ import io
14
+ from contextlib import redirect_stdout, redirect_stderr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  import gradio as gr
17
+ import spaces
 
18
  from PIL import Image
19
  import cv2
 
20
 
21
+ # ----------------- BASIC INFO -----------------
 
 
 
 
 
 
 
 
 
 
 
 
22
  CHECKPOINT_URL = "https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth"
23
  CHECKPOINT_LOCAL = "DINOv2FeatureV6_LocalAtten_s2_154000.pth"
24
 
25
  TITLE = "ColorMNet — ZeroGPU (CUDA-only) Video Colorization with Reference Image"
26
  DESC = """
27
  上传**黑白视频**与**参考图像**,点击“开始着色”。
28
+ 此版 **app.py 中调度 ZeroGPU**,并**在同一进程**调用 `test.py` 的入口函数
29
+ 临时工作目录结构:
30
+ - 抽帧`_colormnet_tmp/input_video/<视频名>/00000.png ...`
31
+ - 参考:`_colormnet_tmp/ref/<视频名>/ref.png`
32
+ - 输出:`_colormnet_tmp/output/<视频名>/*.png`
33
+ - 合成视频:`_colormnet_tmp/<视频名>.mp4`
34
  """
35
 
36
  # ----------------- TEMP WORKDIR -----------------
37
  TEMP_ROOT = path.join(os.getcwd(), "_colormnet_tmp")
38
+ INPUT_DIR = "input_video"
39
+ REF_DIR = "ref"
40
+ OUTPUT_DIR = "output"
41
 
42
  def reset_temp_root():
43
  """每次运行前清空并重建临时工作目录。"""
44
  if path.isdir(TEMP_ROOT):
45
  shutil.rmtree(TEMP_ROOT, ignore_errors=True)
46
  os.makedirs(TEMP_ROOT, exist_ok=True)
47
+ for sub in (INPUT_DIR, REF_DIR, OUTPUT_DIR):
48
+ os.makedirs(path.join(TEMP_ROOT, sub), exist_ok=True)
49
 
50
+ def ensure_dir(d: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  os.makedirs(d, exist_ok=True)
52
 
53
+ # ----------------- CHECKPOINT (可选) -----------------
54
  def ensure_checkpoint():
55
+ """若 test.py 会在当前目录加载权重,可提前预下载,避免首次拉取超时。"""
56
+ try:
57
+ if not path.exists(CHECKPOINT_LOCAL):
58
+ print(f"[INFO] Downloading checkpoint from: {CHECKPOINT_URL}")
59
+ urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_LOCAL)
60
+ print("[INFO] Checkpoint downloaded:", CHECKPOINT_LOCAL)
61
+ except Exception as e:
62
+ print(f"[WARN] 预下载权重失败(首次推理会再试): {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ # ----------------- VIDEO UTILS -----------------
65
+ def video_to_frames_dir(video_path: str, frames_dir: str):
66
  """
67
+ 抽帧到 frames_dir/00000.png ...
68
+ 返回: (w, h, fps, n_frames)
69
  """
70
+ ensure_dir(frames_dir)
 
 
 
 
 
71
  cap = cv2.VideoCapture(video_path)
72
  assert cap.isOpened(), f"Cannot open video: {video_path}"
73
+ fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
 
 
 
 
74
  idx = 0
75
  w = h = None
 
76
  while True:
77
  ret, frame = cap.read()
78
  if not ret:
79
  break
80
  if frame is None:
81
  continue
 
82
  h, w = frame.shape[:2]
83
+ out_path = path.join(frames_dir, f"{idx:05d}.png")
 
 
 
 
 
 
 
84
  ok = cv2.imwrite(out_path, frame)
85
  if not ok:
86
  raise RuntimeError(f"写入抽帧失败: {out_path}")
87
  idx += 1
 
88
  cap.release()
89
  if idx == 0:
90
  raise RuntimeError("Input video has no readable frames.")
91
+ return w, h, fps, idx
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  def encode_frames_to_video(frames_dir: str, out_path: str, fps: float):
94
  frames = sorted([f for f in os.listdir(frames_dir) if f.lower().endswith(".png")])
95
+ if not frames:
96
+ raise RuntimeError(f"No frames found in {frames_dir}")
97
  first = cv2.imread(path.join(frames_dir, frames[0]))
98
+ if first is None:
99
+ raise RuntimeError(f"Failed to read first frame {frames[0]}")
100
  h, w = first.shape[:2]
 
101
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
102
  vw = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
103
  for f in frames:
104
  img = cv2.imread(path.join(frames_dir, f))
105
+ if img is None:
106
+ continue
107
  vw.write(img)
108
  vw.release()
109
 
110
+ # ----------------- CLI MAPPING -----------------
111
+ CONFIG_TO_CLI = {
112
+ "FirstFrameIsNotExemplar": "--FirstFrameIsNotExemplar", # bool
113
+ "dataset": "--dataset",
114
+ "split": "--split",
115
+ "save_all": "--save_all", # bool
116
+ "benchmark": "--benchmark", # bool
117
+ "disable_long_term": "--disable_long_term", # bool
118
+ "max_mid_term_frames": "--max_mid_term_frames",
119
+ "min_mid_term_frames": "--min_mid_term_frames",
120
+ "max_long_term_elements": "--max_long_term_elements",
121
+ "num_prototypes": "--num_prototypes",
122
+ "top_k": "--top_k",
123
+ "mem_every": "--mem_every",
124
+ "deep_update_every": "--deep_update_every",
125
+ "save_scores": "--save_scores", # bool
126
+ "flip": "--flip", # bool
127
+ "size": "--size",
128
+ "reverse": "--reverse", # bool
129
+ }
130
+
131
+ def build_args_list_for_test(d16_batch_path: str,
132
+ out_path: str,
133
+ ref_root: str,
134
+ cfg: dict):
135
+ """
136
+ 构造传给 test.run_cli(args_list) 数列表。
137
+ - 必传:--d16_batch_path <input_video_root>、--ref_path <ref_root>、--output <output_root>
138
+ """
139
+ args = [
140
+ "--d16_batch_path", d16_batch_path,
141
+ "--ref_path", ref_root,
142
+ "--output", out_path,
143
+ ]
144
+ for k, v in cfg.items():
145
+ if k not in CONFIG_TO_CLI:
146
+ continue
147
+ flag = CONFIG_TO_CLI[k]
148
+ if isinstance(v, bool):
149
+ if v:
150
+ args.append(flag) # store_true
151
+ elif v is None:
152
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  else:
154
+ args.extend([flag, str(v)])
155
+ return args
 
 
 
 
 
 
156
 
157
+ # ----------------- GRADIO HANDLER -----------------
158
+ @spaces.GPU(duration=160) # 确保 CUDA 初始化在此函数体内
 
 
 
 
 
159
  def gradio_infer(
160
+ debug_shapes,
161
  bw_video, ref_image,
162
  first_not_exemplar, dataset, split, save_all, benchmark,
163
  disable_long_term, max_mid, min_mid, max_long,
164
  num_proto, top_k, mem_every, deep_update,
165
+ save_scores, flip, size, reverse
166
  ):
167
+ # 1) 基本校验与临时目录
 
 
168
  if bw_video is None:
169
  return None, "请上传黑白视频。"
170
  if ref_image is None:
171
  return None, "请上传参考图像。"
 
 
172
  reset_temp_root()
173
 
174
+ # 2) 解析视频源路径 &标 <video_stem>
175
  if isinstance(bw_video, dict) and "name" in bw_video:
176
  src_video_path = bw_video["name"]
177
  elif isinstance(bw_video, str):
 
179
  else:
180
  return None, "无法读取视频输入。"
181
 
182
+ video_stem = path.splitext(path.basename(src_video_path))[0]
183
+
184
+ # 3) 生成临时路径
185
+ input_root = path.join(TEMP_ROOT, INPUT_DIR) # _colormnet_tmp/input_video
186
+ ref_root = path.join(TEMP_ROOT, REF_DIR) # _colormnet_tmp/ref
187
+ output_root= path.join(TEMP_ROOT, OUTPUT_DIR) # _colormnet_tmp/output
188
+ input_frames_dir = path.join(input_root, video_stem)
189
+ ref_dir = path.join(ref_root, video_stem)
190
+ out_frames_dir = path.join(output_root, video_stem)
191
+ for d in (input_root, ref_root, output_root, input_frames_dir, ref_dir, out_frames_dir):
192
+ ensure_dir(d)
193
+
194
+ # 4) 抽帧 -> input_video/<stem>/
195
  try:
196
+ _w, _h, fps, _n = video_to_frames_dir(src_video_path, input_frames_dir)
197
  except Exception as e:
198
+ return None, f"抽帧失败:\n{e}"
199
 
200
+ # 5) 参考帧 -> ref/<stem>/ref.png
201
+ ref_png_path = path.join(ref_dir, "ref.png")
202
  if isinstance(ref_image, Image.Image):
203
  try:
204
+ ref_image.save(ref_png_path)
205
  except Exception as e:
206
+ return None, f"保存参考图像失败:\n{e}"
207
  elif isinstance(ref_image, str):
208
  try:
209
+ shutil.copy2(ref_image, ref_png_path)
210
  except Exception as e:
211
+ return None, f"复制参考图像失败:\n{e}"
212
  else:
213
  return None, "无法读取参考图像输入。"
214
 
215
+ # 6) 收集 UI 配置
216
  default_config = {
217
  "FirstFrameIsNotExemplar": True,
218
  "dataset": "D16_batch",
 
230
  "save_scores": False,
231
  "flip": False,
232
  "size": -1,
233
+ "reverse": False,
234
  }
 
235
  user_config = {
236
  "FirstFrameIsNotExemplar": bool(first_not_exemplar) if first_not_exemplar is not None else default_config["FirstFrameIsNotExemplar"],
237
  "dataset": str(dataset) if dataset else default_config["dataset"],
 
249
  "save_scores": bool(save_scores) if save_scores is not None else default_config["save_scores"],
250
  "flip": bool(flip) if flip is not None else default_config["flip"],
251
  "size": int(size) if size is not None else default_config["size"],
252
+ "reverse": bool(reverse) if reverse is not None else default_config["reverse"],
253
  }
254
 
255
+ # 7) 预下载权重(可选)
256
+ ensure_checkpoint()
257
+
258
+ # 8) 同进程调用 test.py
259
  try:
260
+ import test # 确保 test.py 同目录且有 run_cli 函数
 
 
 
 
 
 
 
 
261
  except Exception as e:
262
+ return None, f"导入 test.py 失败:\n{e}"
263
+
264
+ args_list = build_args_list_for_test(
265
+ d16_batch_path=input_root, # 指向 input_video 根
266
+ out_path=output_root, # 指向 output 根(test.py 写 output/<stem>/*.png)
267
+ ref_root=ref_root, # 指向 ref 根(test.py 读 ref/<stem>/ref.png��
268
+ cfg=user_config
269
+ )
270
+
271
+ buf = io.StringIO()
272
+ try:
273
+ with redirect_stdout(buf), redirect_stderr(buf):
274
+ entry = getattr(test, "run_cli", None)
275
+ if entry is None or not callable(entry):
276
+ raise RuntimeError("test.py 未提供可调用的 run_cli(args_list) 接口。")
277
+ entry(args_list)
278
+ log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}"
279
+ except Exception as e:
280
+ log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}\n\nERROR: {e}"
281
+ return None, log
282
+
283
+ # 9) 合成 mp4:从 output/<stem>/ 帧合成 -> TEMP_ROOT/<stem>.mp4
284
+ out_frames = path.join(output_root, video_stem)
285
+ if not path.isdir(out_frames):
286
+ return None, f"未找到输出帧目录:{out_frames}\n\n{log}"
287
+ final_mp4 = path.join(TEMP_ROOT, f"{video_stem}.mp4")
288
+ try:
289
+ encode_frames_to_video(out_frames, final_mp4, fps=fps)
290
+ except Exception as e:
291
+ return None, f"合成视频失败:\n{e}\n\n{log}"
292
+
293
+ return final_mp4, f"完成 ✅\n\n{log}"
294
 
295
  # ----------------- UI -----------------
296
  with gr.Blocks() as demo:
297
  gr.Markdown(f"# {TITLE}")
298
  gr.Markdown(DESC)
299
 
300
+ debug_shapes = gr.Checkbox(label="调试日志(仅用于显示更完整日志)", value=False)
301
 
302
  with gr.Row():
303
  inp_video = gr.Video(label="黑白视频(mp4/webm/avi)", interactive=True)
304
  inp_ref = gr.Image(label="参考图像(RGB)", type="pil")
 
305
  gr.Examples(
306
  label="示例输入",
307
+ examples=[["./example/4.mp4", "./example/4.png"]],
 
 
308
  inputs=[inp_video, inp_ref],
 
309
  cache_examples=False,
310
  )
311
 
312
+ with gr.Accordion("高级参数设置(传给 test.py)", open=False):
313
  with gr.Row():
314
+ first_not_exemplar = gr.Checkbox(label="FirstFrameIsNotExemplar (--FirstFrameIsNotExemplar)", value=True)
315
+ reverse = gr.Checkbox(label="reverse (--reverse)", value=False)
316
+ dataset = gr.Textbox(label="dataset (--dataset)", value="D16_batch")
317
+ split = gr.Textbox(label="split (--split)", value="val")
318
+ save_all = gr.Checkbox(label="save_all (--save_all)", value=True)
319
+ benchmark = gr.Checkbox(label="benchmark (--benchmark)", value=False)
320
  with gr.Row():
321
+ disable_long_term = gr.Checkbox(label="disable_long_term (--disable_long_term)", value=False)
322
+ max_mid = gr.Number(label="max_mid_term_frames (--max_mid_term_frames)", value=10, precision=0)
323
+ min_mid = gr.Number(label="min_mid_term_frames (--min_mid_term_frames)", value=5, precision=0)
324
+ max_long = gr.Number(label="max_long_term_elements (--max_long_term_elements)", value=10000, precision=0)
325
+ num_proto = gr.Number(label="num_prototypes (--num_prototypes)", value=128, precision=0)
326
  with gr.Row():
327
+ top_k = gr.Number(label="top_k (--top_k)", value=30, precision=0)
328
+ mem_every = gr.Number(label="mem_every (--mem_every)", value=5, precision=0)
329
+ deep_update = gr.Number(label="deep_update_every (--deep_update_every)", value=-1, precision=0)
330
+ save_scores = gr.Checkbox(label="save_scores (--save_scores)", value=False)
331
+ flip = gr.Checkbox(label="flip (--flip)", value=False)
332
+ size = gr.Number(label="size (--size)", value=-1, precision=0)
333
+
334
+ run_btn = gr.Button("开始着色(同进程调用 test.py)")
335
  with gr.Row():
336
  out_video = gr.Video(label="输出视频(着色结果)")
337
+ status = gr.Textbox(label="状态 / 日志输出(test.py stdout/stderr)", interactive=False, lines=16)
338
 
339
  run_btn.click(
340
  fn=gradio_infer,
 
344
  first_not_exemplar, dataset, split, save_all, benchmark,
345
  disable_long_term, max_mid, min_mid, max_long,
346
  num_proto, top_k, mem_every, deep_update,
347
+ save_scores, flip, size, reverse
348
  ],
349
  outputs=[out_video, status]
350
  )
 
354
  ensure_checkpoint()
355
  except Exception as e:
356
  print(f"[WARN] 预下载权重失败(首次推理会再试): {e}")
357
+ demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860)
 
inference/data/test_datasets.py CHANGED
@@ -1,36 +1,116 @@
1
  import os
2
  from os import path
3
- import json
4
 
5
- from inference.data.video_reader import VideoReader_221128_TransColorization
 
 
 
 
 
6
 
7
- class DAVISTestDataset_221128_TransColorization_batch:
8
- def __init__(self, data_root, imset='2017/val.txt', size=-1, args=None):
9
- self.image_dir = data_root
10
- self.mask_dir = imset
11
- self.size_dir = data_root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  self.size = size
13
 
14
- self.vid_list = [clip_name for clip_name in sorted(os.listdir(data_root)) if clip_name != '.DS_Store' and not clip_name.startswith('.')]
15
- self.ref_img_list = [clip_name for clip_name in sorted(os.listdir(imset)) if clip_name != '.DS_Store' and not clip_name.startswith('.')]
16
 
17
- self.args = args
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # print(lst, len(lst), self.vid_list, self.vid_list_DAVIS2016, path.join(data_root, 'ImageSets', imset));assert 1==0
20
 
21
- def get_datasets(self):
22
- for video in self.vid_list:
23
- if video not in self.ref_img_list:
24
- continue
 
 
25
 
26
- # print(self.image_dir, video, path.join(self.image_dir, video));assert 1==0
27
- yield VideoReader_221128_TransColorization(video,
28
- path.join(self.image_dir, video),
29
- path.join(self.mask_dir, video),
30
- size=self.size,
31
- size_dir=path.join(self.size_dir, video),
32
- args=self.args
33
- )
34
 
35
  def __len__(self):
36
- return len(self.vid_list)
 
1
  import os
2
  from os import path
 
3
 
4
+ from torch.utils.data.dataset import Dataset
5
+ from torchvision import transforms
6
+ from torchvision.transforms import InterpolationMode
7
+ import torch.nn.functional as Ff
8
+ from PIL import Image
9
+ import numpy as np
10
 
11
+ from dataset.range_transform import im_normalization, im_rgb2lab_normalization, ToTensor, RGB2Lab
12
+
13
+ class VideoReader_221128_TransColorization(Dataset):
14
+ """
15
+ This class is used to read a video, one frame at a time
16
+ """
17
+ def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None, args=None):
18
+ """
19
+ image_dir - points to a directory of jpg images
20
+ mask_dir - points to a directory of png masks
21
+ size - resize min. side to size. Does nothing if <0.
22
+ to_save - optionally contains a list of file names without extensions
23
+ where the segmentation mask is required
24
+ use_all_mask - when true, read all available mask in mask_dir.
25
+ Default false. Set to true for YouTubeVOS validation.
26
+ """
27
+ self.vid_name = vid_name
28
+ self.image_dir = image_dir
29
+ self.mask_dir = mask_dir
30
+ self.to_save = to_save
31
+ self.use_all_mask = use_all_mask
32
+ # print('use_all_mask', use_all_mask);assert 1==0
33
+ if size_dir is None:
34
+ self.size_dir = self.image_dir
35
+ else:
36
+ self.size_dir = size_dir
37
+
38
+ # flag_reverse = args.getattr('reverse', False) if args is not None else False
39
+ flag_reverse = False
40
+ self.frames = [img for img in sorted(os.listdir(self.image_dir), reverse=flag_reverse) if (img.endswith('.jpg') or img.endswith('.png')) and not img.startswith('.')]
41
+ self.palette = Image.open(path.join(mask_dir, sorted([msk for msk in os.listdir(mask_dir) if not msk.startswith('.')])[0])).getpalette()
42
+ self.first_gt_path = path.join(self.mask_dir, sorted([msk for msk in os.listdir(self.mask_dir) if not msk.startswith('.')])[0])
43
+ self.suffix = self.first_gt_path.split('.')[-1]
44
+
45
+ if size < 0:
46
+ self.im_transform = transforms.Compose([
47
+ RGB2Lab(),
48
+ ToTensor(),
49
+ im_rgb2lab_normalization,
50
+ ])
51
+ else:
52
+ self.im_transform = transforms.Compose([
53
+ transforms.ToTensor(),
54
+ im_normalization,
55
+ transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
56
+ ])
57
  self.size = size
58
 
 
 
59
 
60
+ def __getitem__(self, idx):
61
+ frame = self.frames[idx]
62
+ info = {}
63
+ data = {}
64
+ info['frame'] = frame
65
+ info['vid_name'] = self.vid_name
66
+ info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)
67
+
68
+ im_path = path.join(self.image_dir, frame)
69
+ img = Image.open(im_path).convert('RGB')
70
+
71
+ if self.image_dir == self.size_dir:
72
+ shape = np.array(img).shape[:2]
73
+ else:
74
+ size_path = path.join(self.size_dir, frame)
75
+ size_im = Image.open(size_path).convert('RGB')
76
+ shape = np.array(size_im).shape[:2]
77
+
78
+ gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[idx]) if idx < len(os.listdir(self.mask_dir)) else None
79
+
80
+ img = self.im_transform(img)
81
+ img_l = img[:1,:,:]
82
+ img_lll = img_l.repeat(3,1,1)
83
+
84
+ load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
85
+ if load_mask and path.exists(gt_path):
86
+ mask = Image.open(gt_path).convert('RGB')
87
+
88
+ # 用 PIL 先 resize 成和 img 尺寸一致
89
+ mask = mask.resize((img.shape[2], img.shape[1]), Image.BILINEAR)
90
+
91
+ mask = self.im_transform(mask)
92
+
93
+ # keep L channel of reference image in case First frame is not exemplar
94
+ # mask_ab = mask[1:3,:,:]
95
+ # data['mask'] = mask_ab
96
+ data['mask'] = mask
97
+
98
+ info['shape'] = shape
99
+ info['need_resize'] = not (self.size < 0)
100
+ data['rgb'] = img_lll
101
+ data['info'] = info
102
 
103
+ return data
104
 
105
+ def resize_mask(self, mask):
106
+ # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
107
+ h, w = mask.shape[-2:]
108
+ min_hw = min(h, w)
109
+ return Ff.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
110
+ mode='nearest')
111
 
112
+ def get_palette(self):
113
+ return self.palette
 
 
 
 
 
 
114
 
115
  def __len__(self):
116
+ return len(self.frames)
test.py CHANGED
@@ -1,8 +1,15 @@
 
 
 
 
1
  import os
2
  from os import path
3
  from argparse import ArgumentParser
4
  import shutil
5
 
 
 
 
6
  import torch
7
  import torch.nn.functional as F
8
  from torch.utils.data import DataLoader
@@ -27,47 +34,7 @@ except ImportError:
27
  print('Failed to import hickle. Fine if not using multi-scale testing.')
28
 
29
 
30
- """
31
- Arguments loading
32
- """
33
- parser = ArgumentParser()
34
- parser.add_argument('--model', default='saves/DINOv2FeatureV6_LocalAtten_s2_154000.pth')
35
-
36
- # dataset setting
37
- parser.add_argument('--d16_batch_path', default='input')
38
- parser.add_argument('--deoldify_path', default='ref')
39
- parser.add_argument('--output', default='result')
40
-
41
- # For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
42
- parser.add_argument('--generic_path')
43
- parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D16_batch')
44
- parser.add_argument('--split', help='val/test', default='val')
45
- parser.add_argument('--save_all', action='store_true',
46
- help='Save all frames. Useful only in YouTubeVOS/long-time video', )
47
- parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking')
48
-
49
- # Long-term memory options
50
- parser.add_argument('--disable_long_term', action='store_true')
51
- parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
52
- parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
53
- parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
54
- type=int, default=10000)
55
- parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
56
-
57
- parser.add_argument('--top_k', type=int, default=30)
58
- parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5)
59
- parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
60
-
61
- # Multi-scale options
62
- parser.add_argument('--save_scores', action='store_true')
63
- parser.add_argument('--flip', action='store_true')
64
- parser.add_argument('--size', default=-1, type=int,
65
- help='Resize the shorter side to this size. -1 to use original resolution. ')
66
-
67
- args = parser.parse_args()
68
- config = vars(args)
69
- config['enable_long_term'] = not config['disable_long_term']
70
-
71
  def detach_to_cpu(x):
72
  return x.detach().cpu()
73
 
@@ -89,142 +56,236 @@ def lab2rgb_transform_PIL(mask):
89
 
90
  return im.clip(0, 1)
91
 
92
- if args.output is None:
93
- args.output = f'.output/{args.dataset}_{args.split}'
94
- print(f'Output path not provided. Defaulting to {args.output}')
95
-
96
- """
97
- Data preparation
98
- """
99
- is_youtube = args.dataset.startswith('Y')
100
- is_davis = args.dataset.startswith('D')
101
- is_lv = args.dataset.startswith('LV')
102
-
103
- if is_youtube or args.save_scores:
104
- out_path = path.join(args.output, 'Annotations')
105
- else:
106
- out_path = args.output
107
-
108
- if args.split == 'val':
109
- # Set up Dataset, a small hack to use the image set in the 2017 folder because the 2016 one is of a different format
110
- meta_dataset = DAVISTestDataset_221128_TransColorization_batch(args.d16_batch_path, imset=args.deoldify_path, size=args.size)
111
- else:
112
- raise NotImplementedError
113
- palette = None
114
-
115
- torch.autograd.set_grad_enabled(False)
116
-
117
- # Set up loader
118
- meta_loader = meta_dataset.get_datasets()
119
-
120
- # Load our checkpoint
121
- network = ColorMNet(config, args.model).cuda().eval()
122
- if args.model is not None:
123
- model_weights = torch.load(args.model)
124
- network.load_weights(model_weights, init_as_zero_if_needed=True)
125
- else:
126
- print('No model loaded.')
127
-
128
- total_process_time = 0
129
- total_frames = 0
130
-
131
- # Start eval
132
- for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
133
-
134
- loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2)
135
- vid_name = vid_reader.vid_name
136
- vid_length = len(loader)
137
- # no need to count usage for LT if the video is not that long anyway
138
- config['enable_long_term_count_usage'] = (
139
- config['enable_long_term'] and
140
- (vid_length
141
- / (config['max_mid_term_frames']-config['min_mid_term_frames'])
142
- * config['num_prototypes'])
143
- >= config['max_long_term_elements']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  )
 
 
 
145
 
146
- mapper = MaskMapper()
147
- processor = InferenceCore(network, config=config)
148
- first_mask_loaded = False
149
-
150
- for ti, data in enumerate(loader):
151
- with torch.cuda.amp.autocast(enabled=not args.benchmark):
152
- rgb = data['rgb'].cuda()[0]
153
- msk = data.get('mask')
154
- info = data['info']
155
- frame = info['frame'][0]
156
- shape = info['shape']
157
- need_resize = info['need_resize'][0]
158
-
159
- """
160
- For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964
161
- Seems to be very similar in testing as my previous timing method
162
- with two cuda sync + time.time() in STCN though
163
- """
164
- start = torch.cuda.Event(enable_timing=True)
165
- end = torch.cuda.Event(enable_timing=True)
166
- start.record()
167
-
168
- if not first_mask_loaded:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  if msk is not None:
170
- first_mask_loaded = True
 
 
 
 
171
  else:
172
- # no point to do anything without a mask
173
- continue
174
-
175
- if args.flip:
176
- rgb = torch.flip(rgb, dims=[-1])
177
- msk = torch.flip(msk, dims=[-1]) if msk is not None else None
 
 
 
 
 
 
 
178
 
179
- # Map possibly non-continuous labels to continuous ones
180
- if msk is not None:
181
- msk = torch.Tensor(msk[0]).cuda()
182
  if need_resize:
183
- msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
184
- processor.set_all_labels(list(range(1,3)))
185
- labels = range(1,3)
186
- else:
187
- labels = None
188
-
189
- # Run the model on this frame
190
- prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1))
191
-
192
- # Upsample to original size if needed
193
- if need_resize:
194
- prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0]
195
-
196
- end.record()
197
- torch.cuda.synchronize()
198
- total_process_time += (start.elapsed_time(end)/1000)
199
- total_frames += 1
200
-
201
- if args.flip:
202
- prob = torch.flip(prob, dims=[-1])
203
-
204
- if args.save_scores:
205
- prob = (prob.detach().cpu().numpy()*255).astype(np.uint8)
206
-
207
- # Save the mask
208
- if args.save_all or info['save'][0]:
209
- this_out_path = path.join(out_path, vid_name)
210
- os.makedirs(this_out_path, exist_ok=True)
211
-
212
- out_mask_final = lab2rgb_transform_PIL(torch.cat([rgb[:1,:,:], prob], dim=0))
213
- out_mask_final = out_mask_final * 255
214
- out_mask_final = out_mask_final.astype(np.uint8)
215
-
216
- out_img = Image.fromarray(out_mask_final)
217
- out_img.save(os.path.join(this_out_path, frame[:-4]+'.png'))
218
-
219
- print(f'Total processing time: {total_process_time}')
220
- print(f'Total processed frames: {total_frames}')
221
- print(f'FPS: {total_frames / total_process_time}')
222
- print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')
223
-
224
- if not args.save_scores:
225
- if is_youtube:
226
- print('Making zip for YouTubeVOS...')
227
- shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
228
- elif is_davis and args.split == 'test':
229
- print('Making zip for DAVIS test-dev...')
230
- shutil.make_archive(args.output, 'zip', args.output)
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # test.py — In-process callable version (for ZeroGPU stateless)
2
+ # Keep original logic; add build_parser(), run_cli(args_list), and run_inference(args)
3
+ # Do NOT initialize CUDA at import-time.
4
+
5
  import os
6
  from os import path
7
  from argparse import ArgumentParser
8
  import shutil
9
 
10
+ # 不在这里做 @spaces.GPU 装饰,避免与 app.py 的 @spaces.GPU 双重调度
11
+ # import spaces
12
+
13
  import torch
14
  import torch.nn.functional as F
15
  from torch.utils.data import DataLoader
 
34
  print('Failed to import hickle. Fine if not using multi-scale testing.')
35
 
36
 
37
+ # ----------------- small utils -----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def detach_to_cpu(x):
39
  return x.detach().cpu()
40
 
 
56
 
57
  return im.clip(0, 1)
58
 
59
+
60
+ # ----------------- argparse -----------------
61
+ def build_parser() -> ArgumentParser:
62
+ parser = ArgumentParser()
63
+ parser.add_argument('--model', default='saves/DINOv2FeatureV6_LocalAtten_s2_154000.pth')
64
+ parser.add_argument('--FirstFrameIsNotExemplar', help='Whether the provided reference frame is exactly the first input frame', action='store_true')
65
+
66
+ # dataset setting
67
+ parser.add_argument('--d16_batch_path', default='input', help='Point to folder A/ which contains <video_name>/00000.png etc.')
68
+ parser.add_argument('--ref_path', default='ref', help='Kept for parity; dataset will also read ref.png under each video folder when args provided')
69
+ parser.add_argument('--output', default='result', help='Directory to save results')
70
+
71
+ parser.add_argument('--reverse', default=False, action='store_true', help='whether to reverse the frame order')
72
+ parser.add_argument('--allow_resume', action='store_true',
73
+ help='skip existing videos that have been colorized')
74
+
75
+ # For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
76
+ parser.add_argument('--generic_path')
77
+ parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D16_batch')
78
+ parser.add_argument('--split', help='val/test', default='val')
79
+ parser.add_argument('--save_all', action='store_true',
80
+ help='Save all frames. Useful only in YouTubeVOS/long-time video')
81
+ parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking')
82
+
83
+ # Long-term memory options
84
+ parser.add_argument('--disable_long_term', action='store_true')
85
+ parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
86
+ parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
87
+ parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
88
+ type=int, default=10000)
89
+ parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
90
+
91
+ parser.add_argument('--top_k', type=int, default=30)
92
+ parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5)
93
+ parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
94
+
95
+ # Multi-scale options
96
+ parser.add_argument('--save_scores', action='store_true')
97
+ parser.add_argument('--flip', action='store_true')
98
+ parser.add_argument('--size', default=-1, type=int,
99
+ help='Resize the shorter side to this size. -1 to use original resolution. ')
100
+ return parser
101
+
102
+
103
+ # ----------------- core inference -----------------
104
+ def run_inference(args):
105
+ """
106
+ 真正的推理流程。必须在 ZeroGPU 的调度上下文里被调用(由 app.py 的 @spaces.GPU 包裹)。
107
+ 不要在导入模块时做任何 CUDA 初始化。
108
+ """
109
+ config = vars(args)
110
+ config['enable_long_term'] = not config['disable_long_term']
111
+
112
+ if args.output is None:
113
+ args.output = f'.output/{args.dataset}_{args.split}'
114
+ print(f'Output path not provided. Defaulting to {args.output}')
115
+
116
+ # ----- Data preparation -----
117
+ is_youtube = args.dataset.startswith('Y')
118
+ is_davis = args.dataset.startswith('D')
119
+ is_lv = args.dataset.startswith('LV')
120
+
121
+ if is_youtube or args.save_scores:
122
+ out_path = path.join(args.output, 'Annotations')
123
+ else:
124
+ out_path = args.output
125
+
126
+ if args.split != 'val':
127
+ raise NotImplementedError('Only split=val is supported in this script.')
128
+
129
+ # 数据集:支持 A/<video>/00000.png ... 且读取 A/<video>/ref.png
130
+ meta_dataset = DAVISTestDataset_221128_TransColorization_batch(
131
+ args.d16_batch_path, imset=args.ref_path, size=args.size, args=args
132
  )
133
+ palette = None # 兼容保留
134
+
135
+ torch.autograd.set_grad_enabled(False)
136
 
137
+ # Set up loader list (video readers)
138
+ meta_loader = meta_dataset.get_datasets()
139
+
140
+ # Load checkpoint/model
141
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
142
+ network = ColorMNet(config, args.model).to(device).eval()
143
+ if args.model is not None:
144
+ # map_location 不指定,按默认走(ZeroGPU 下会在被调度的设备上加载)
145
+ model_weights = torch.load(args.model, map_location=device)
146
+ network.load_weights(model_weights, init_as_zero_if_needed=True)
147
+ else:
148
+ print('No model loaded.')
149
+
150
+ total_process_time = 0.0
151
+ total_frames = 0
152
+
153
+ # ----- Start eval over videos -----
154
+ for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
155
+ # 注意:ZeroGPU/Spaces 环境不允许子进程多线程加载,保持 num_workers=0
156
+ loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)
157
+ vid_name = vid_reader.vid_name
158
+ vid_length = len(loader)
159
+
160
+ # LT usage check per original logic
161
+ config['enable_long_term_count_usage'] = (
162
+ config['enable_long_term'] and
163
+ (vid_length
164
+ / (config['max_mid_term_frames'] - config['min_mid_term_frames'])
165
+ * config['num_prototypes'])
166
+ >= config['max_long_term_elements']
167
+ )
168
+
169
+ mapper = MaskMapper()
170
+ processor = InferenceCore(network, config=config)
171
+ first_mask_loaded = False
172
+
173
+ # skip existing videos
174
+ if args.allow_resume:
175
+ this_out_path = path.join(out_path, vid_name)
176
+ if path.exists(this_out_path):
177
+ print(f'Skipping {this_out_path} because output already exists.')
178
+ continue
179
+
180
+ for ti, data in enumerate(loader):
181
+ with torch.cuda.amp.autocast(enabled=not args.benchmark):
182
+ rgb = data['rgb'].to(device)[0]
183
+
184
+ msk = data.get('mask')
185
+ if not config['FirstFrameIsNotExemplar']:
186
+ msk = msk[:, 1:3, :, :] if msk is not None else None
187
+
188
+ info = data['info']
189
+ frame = info['frame'][0]
190
+ shape = info['shape']
191
+ need_resize = info['need_resize'][0]
192
+
193
+ start = torch.cuda.Event(enable_timing=True)
194
+ end = torch.cuda.Event(enable_timing=True)
195
+ start.record()
196
+
197
+ # 第一次必须有 mask
198
+ if not first_mask_loaded:
199
+ if msk is not None:
200
+ first_mask_loaded = True
201
+ else:
202
+ continue
203
+
204
+ if args.flip:
205
+ rgb = torch.flip(rgb, dims=[-1])
206
+ msk = torch.flip(msk, dims=[-1]) if msk is not None else None
207
+
208
+ # Map possibly non-continuous labels to continuous ones
209
  if msk is not None:
210
+ msk = torch.Tensor(msk[0]).to(device)
211
+ if need_resize:
212
+ msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
213
+ processor.set_all_labels(list(range(1, 3)))
214
+ labels = range(1, 3)
215
  else:
216
+ labels = None
217
+
218
+ # Run the model on this frame
219
+ if config['FirstFrameIsNotExemplar']:
220
+ prob = processor.step_AnyExemplar(
221
+ rgb,
222
+ msk[:1, :, :].repeat(3, 1, 1) if msk is not None else None,
223
+ msk[1:3, :, :] if msk is not None else None,
224
+ labels,
225
+ end=(ti == vid_length - 1)
226
+ )
227
+ else:
228
+ prob = processor.step(rgb, msk, labels, end=(ti == vid_length - 1))
229
 
230
+ # Upsample to original size if needed
 
 
231
  if need_resize:
232
+ prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, 0]
233
+
234
+ end.record()
235
+ torch.cuda.synchronize()
236
+ total_process_time += (start.elapsed_time(end)/1000)
237
+ total_frames += 1
238
+
239
+ if args.flip:
240
+ prob = torch.flip(prob, dims=[-1])
241
+
242
+ if args.save_scores:
243
+ prob = (prob.detach().cpu().numpy() * 255).astype(np.uint8)
244
+
245
+ # Save the mask
246
+ if args.save_all or info['save'][0]:
247
+ this_out_path = path.join(out_path, vid_name)
248
+ os.makedirs(this_out_path, exist_ok=True)
249
+
250
+ out_mask_final = lab2rgb_transform_PIL(torch.cat([rgb[:1, :, :], prob], dim=0))
251
+ out_mask_final = (out_mask_final * 255).astype(np.uint8)
252
+
253
+ out_img = Image.fromarray(out_mask_final)
254
+ out_img.save(os.path.join(this_out_path, frame[:-4] + '.png'))
255
+
256
+ print(f'Total processing time: {total_process_time}')
257
+ print(f'Total processed frames: {total_frames}')
258
+ print(f'FPS: {total_frames / total_process_time}')
259
+ print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')
260
+
261
+ # 与原版一致:只在 save_scores=False 且特定数据集/子集时打 zip
262
+ if not args.save_scores:
263
+ if is_youtube:
264
+ print('Making zip for YouTubeVOS...')
265
+ shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
266
+ elif is_davis and args.split == 'test':
267
+ print('Making zip for DAVIS test-dev...')
268
+ shutil.make_archive(args.output, 'zip', args.output)
269
+
270
+
271
+ # ----------------- public entrypoints -----------------
272
+ def run_cli(args_list=None):
273
+ """
274
+ app.py 同进程调用:test.run_cli(args_list)
275
+ """
276
+ parser = build_parser()
277
+ args = parser.parse_args(args_list)
278
+ return run_inference(args)
279
+
280
+
281
+ def main():
282
+ """
283
+ 保留命令行可运行:python test.py --d16_batch_path A --output result ...
284
+ 注意:若在 Hugging Face Spaces/ZeroGPU 无状态环境下直接 run main(),
285
+ 需要由上层(如 app.py 的 @spaces.GPU)提供调度上下文。
286
+ """
287
+ run_cli()
288
+
289
+
290
+ if __name__ == '__main__':
291
+ main()