yyang181 committed on
Commit
0580cf1
·
1 Parent(s): e07d0cd
.gitignore CHANGED
@@ -11,6 +11,7 @@ Pytorch-Correlation-extension/
11
  result
12
  src/
13
  DINOv2FeatureV6_LocalAtten_s2_154000.pth
 
14
 
15
  # Byte-compiled / optimized / DLL files
16
  __pycache__/
 
11
  result
12
  src/
13
  DINOv2FeatureV6_LocalAtten_s2_154000.pth
14
+ example/
15
 
16
  # Byte-compiled / optimized / DLL files
17
  __pycache__/
app.py CHANGED
@@ -13,12 +13,23 @@ import uuid
13
  import urllib.request
14
  import warnings
15
  from os import path
16
-
17
- warnings.filterwarnings("ignore", message="The detected CUDA version .* minor version mismatch")
18
- warnings.filterwarnings("ignore", message="There are no g\\+\\+ version bounds defined for CUDA version.*")
19
- warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.cpp_extension")
20
- os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
21
- os.environ.setdefault("MAX_JOBS", "1")
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  import gradio as gr
24
  import spaces # ZeroGPU decorator
@@ -53,6 +64,15 @@ DESC = """
53
  - 参考图 -> `./colormnet_run_<UUID>/input_ref/<视频名不含扩展>/ref.png`
54
  """
55
 
 
 
 
 
 
 
 
 
 
56
  torch.set_grad_enabled(False)
57
 
58
  # ----------------- DEBUG (kept) -----------------
@@ -146,7 +166,7 @@ def video_to_dataset_root(video_path: str, dataset_root: str):
146
  if idx == 0:
147
  raise RuntimeError("Input video has no readable frames.")
148
 
149
- return subdir, stem, w, h, fps, idx
150
 
151
  # ---------- place ref image into ref_root/<video_stem>/ref.png ----------
152
  def ref_to_dataset_root(ref_image_path: str, ref_root: str, video_stem: str):
@@ -186,8 +206,8 @@ def run_pipeline_cuda(bw_video_path: str, ref_image_path: str, user_config: dict
186
 
187
  DEVICE = torch.device("cuda")
188
 
189
- # Workspace in CWD
190
- base_run_dir = path.join(os.getcwd(), f"colormnet_run_{uuid.uuid4().hex}")
191
  input_video_root = path.join(base_run_dir, "input_video")
192
  input_ref_root = path.join(base_run_dir, "input_ref")
193
  output_dir = path.join(base_run_dir, "result")
@@ -195,11 +215,11 @@ def run_pipeline_cuda(bw_video_path: str, ref_image_path: str, user_config: dict
195
  for p in (base_run_dir, input_video_root, input_ref_root, output_dir):
196
  ensure_clean_dir(p)
197
 
198
- # 1) 抽帧
199
  vid_subdir, vid_stem, w, h, fps, n_frames = video_to_dataset_root(bw_video_path, input_video_root)
200
  assert n_frames > 0, "Input video has no frames."
201
 
202
- # 2) 参考图
203
  _ = ref_to_dataset_root(ref_image_path, input_ref_root, vid_stem)
204
 
205
  # 3) 配置(字段与 main.py 一致;值从 UI 合并)
@@ -224,6 +244,7 @@ def run_pipeline_cuda(bw_video_path: str, ref_image_path: str, user_config: dict
224
  "save_scores": False,
225
  "flip": False,
226
  "size": -1,
 
227
  }
228
  config = {**default_config, **(user_config or {})}
229
  config["enable_long_term"] = not config["disable_long_term"]
@@ -232,18 +253,7 @@ def run_pipeline_cuda(bw_video_path: str, ref_image_path: str, user_config: dict
232
  meta_dataset = DAVISTestDataset_221128_TransColorization_batch(
233
  input_video_root, imset=input_ref_root, size=config["size"]
234
  )
235
- meta_list = meta_dataset.get_datasets()
236
-
237
- target_reader = None
238
- for vr in meta_list:
239
- if getattr(vr, "vid_name", None) == vid_stem:
240
- target_reader = vr
241
- break
242
- if target_reader is None:
243
- if len(meta_list) == 1:
244
- target_reader = meta_list[0]
245
- else:
246
- raise RuntimeError(f"未在数据集中找到目标视频子目录:{vid_stem};可用={ [getattr(v, 'vid_name', '?') for v in meta_list] }")
247
 
248
  # 输出路径规则(与 main.py 一致)
249
  is_youtube = str(config["dataset"]).startswith("Y")
@@ -264,111 +274,109 @@ def run_pipeline_cuda(bw_video_path: str, ref_image_path: str, user_config: dict
264
  total_process_time = 0.0
265
  total_frames = 0
266
 
267
- # 6) 推理(逐帧;内部逻辑与 main.py 对齐;保留调试打印)
268
- vid_reader = target_reader
269
- # Gradio/Spaces 环境禁止子进程:num_workers=0(否则会触发 daemonic processes 错误)
270
- loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
271
- vid_name = vid_reader.vid_name
272
- vid_length = len(loader)
273
-
274
- # 长时记忆触发逻辑:按 main.py 原样(无除零保护)
275
- config['enable_long_term_count_usage'] = (
276
- config['enable_long_term'] and
277
- (vid_length
278
- / (config['max_mid_term_frames'] - config['min_mid_term_frames'])
279
- * config['num_prototypes'])
280
- >= config['max_long_term_elements']
281
- )
282
 
283
- mapper = MaskMapper()
284
- processor = InferenceCore(network, config=config)
285
- first_mask_loaded = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
- for ti, data in enumerate(loader):
288
- try:
289
- with torch.cuda.amp.autocast(enabled=not config["benchmark"]):
290
- rgb = data['rgb'].cuda()[0]
291
- msk = data.get('mask')
292
- if not config['FirstFrameIsNotExemplar']:
293
- msk = msk[:, 1:3, :, :] if msk is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
- print(rgb.shape, msk.shape)
 
 
 
296
 
297
- info = data['info']
298
- frame = info['frame'][0]
299
- shape = info['shape']
300
- need_resize = info['need_resize'][0]
301
 
302
- if debug_shapes:
303
- print(f"[Loop] frame={ti} rgb={tuple(rgb.shape)} "
304
- f"msk={None if msk is None else tuple(msk.shape)}", flush=True)
 
 
305
 
306
- # timing 与 main.py 一致
307
- start = torch.cuda.Event(enable_timing=True)
308
- end = torch.cuda.Event(enable_timing=True)
309
- start.record()
310
 
311
- if not first_mask_loaded:
312
- if msk is not None:
313
- first_mask_loaded = True
314
- else:
315
- continue
316
 
317
- if config['flip']:
318
- rgb = torch.flip(rgb, dims=[-1])
319
- msk = torch.flip(msk, dims=[-1]) if msk is not None else None
320
 
321
- if msk is not None:
322
- msk = torch.Tensor(msk[0]).cuda()
323
- if need_resize:
324
- msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
325
- processor.set_all_labels(list(range(1, 3)))
326
- labels = range(1, 3)
327
- else:
328
- labels = None
329
-
330
- if config['FirstFrameIsNotExemplar']:
331
- prob = processor.step_AnyExemplar(
332
- rgb,
333
- msk[:1, :, :].repeat(3, 1, 1) if msk is not None else None,
334
- msk[1:3, :, :] if msk is not None else None,
335
- labels,
336
- end=(ti == vid_length - 1)
337
- )
338
- else:
339
- prob = processor.step(rgb, msk, labels, end=(ti == vid_length - 1))
340
-
341
- if need_resize:
342
- prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, 0]
343
-
344
- end.record()
345
- torch.cuda.synchronize()
346
- total_process_time += (start.elapsed_time(end) / 1000.0)
347
- total_frames += 1
348
-
349
- if config['flip']:
350
- prob = torch.flip(prob, dims=[-1])
351
-
352
- if debug_shapes:
353
- try:
354
- print(f"[Loop] prob={tuple(prob.shape)}", flush=True)
355
- except Exception:
356
- pass
357
-
358
- if config['save_scores']:
359
- prob = (prob.detach().cpu().numpy() * 255).astype(np.uint8)
360
-
361
- if config['save_all'] or info['save'][0]:
362
- this_out_path = path.join(out_path, vid_name)
363
- os.makedirs(this_out_path, exist_ok=True)
364
-
365
- out_mask_final = lab2rgb_transform_PIL(torch.cat([rgb[:1, :, :], prob], dim=0))
366
- out_mask_final = (out_mask_final * 255).astype(np.uint8)
367
- Image.fromarray(out_mask_final).save(os.path.join(this_out_path, frame[:-4] + '.png'))
368
-
369
- except Exception as _e:
370
- # 保留完整 traceback,方便定位
371
- raise RuntimeError("FRAME_ERROR:\n" + traceback.format_exc())
372
 
373
  if total_process_time > 0:
374
  print(f'Total processing time: {total_process_time}')
@@ -388,22 +396,24 @@ def run_pipeline_cuda(bw_video_path: str, ref_image_path: str, user_config: dict
388
  colored_mp4 = path.join(base_run_dir, "colored_output.mp4")
389
  encode_frames_to_video(frames_dir, colored_mp4, fps=fps)
390
 
391
- # 8) 输出视频到 CWD
392
  final_mp4 = path.join(os.getcwd(), "result.mp4")
393
  shutil.move(colored_mp4, final_mp4)
 
 
394
  shutil.rmtree(base_run_dir, ignore_errors=True)
395
 
396
  return final_mp4
397
 
398
  # ----------------- GRADIO HANDLERS -----------------
399
- @spaces.GPU(duration=1200)
400
  def gradio_infer(
401
  debug_shapes, # 调试开关(保留)
402
  bw_video, ref_image,
403
  first_not_exemplar, dataset, split, save_all, benchmark,
404
  disable_long_term, max_mid, min_mid, max_long,
405
  num_proto, top_k, mem_every, deep_update,
406
- save_scores, flip, size
407
  ):
408
  if not torch.cuda.is_available():
409
  return None, "ZeroGPU 未分配到 GPU,请重试(或检查 Space 硬件是否为 ZeroGPU)。"
@@ -413,21 +423,36 @@ def gradio_infer(
413
  if ref_image is None:
414
  return None, "请上传参考图像。"
415
 
416
- # Video path
 
 
 
417
  if isinstance(bw_video, dict) and "name" in bw_video:
418
- bw_video_path = bw_video["name"]
419
  elif isinstance(bw_video, str):
420
- bw_video_path = bw_video
421
  else:
422
  return None, "无法读取视频输入。"
423
 
424
- # Ref path
 
 
 
 
 
 
 
 
425
  if isinstance(ref_image, Image.Image):
426
- tmp_ref_path = path.join(os.getcwd(), f"ref_{uuid.uuid4().hex}.png")
427
- ref_image.save(tmp_ref_path)
428
- ref_path = tmp_ref_path
 
429
  elif isinstance(ref_image, str):
430
- ref_path = ref_image
 
 
 
431
  else:
432
  return None, "无法读取参考图像输入。"
433
 
@@ -467,16 +492,22 @@ def gradio_infer(
467
  "save_scores": bool(save_scores) if save_scores is not None else default_config["save_scores"],
468
  "flip": bool(flip) if flip is not None else default_config["flip"],
469
  "size": int(size) if size is not None else default_config["size"],
 
470
  }
471
 
472
  try:
473
  out_mp4 = run_pipeline_cuda(
474
- bw_video_path, ref_path, user_config, debug_shapes=bool(debug_shapes)
475
  )
476
  return out_mp4, "完成 ✅"
477
  except subprocess.CalledProcessError as e:
 
 
 
478
  return None, f"运行时错误:\n{e}"
479
  except Exception as e:
 
 
480
  return None, f"{e}"
481
 
482
  # ----------------- UI -----------------
@@ -490,9 +521,20 @@ with gr.Blocks() as demo:
490
  inp_video = gr.Video(label="黑白视频(mp4/webm/avi)", interactive=True)
491
  inp_ref = gr.Image(label="参考图像(RGB)", type="pil")
492
 
 
 
 
 
 
 
 
 
 
 
493
  with gr.Accordion("高级参数设置(与 main.py 对齐)", open=False):
494
  with gr.Row():
495
- first_not_exemplar = gr.Checkbox(label="FirstFrameIsNotExemplar", value=False)
 
496
  dataset = gr.Textbox(label="dataset", value="D16_batch")
497
  split = gr.Textbox(label="split", value="val")
498
  save_all = gr.Checkbox(label="save_all", value=True)
@@ -524,7 +566,7 @@ with gr.Blocks() as demo:
524
  first_not_exemplar, dataset, split, save_all, benchmark,
525
  disable_long_term, max_mid, min_mid, max_long,
526
  num_proto, top_k, mem_every, deep_update,
527
- save_scores, flip, size
528
  ],
529
  outputs=[out_video, status]
530
  )
@@ -535,4 +577,4 @@ if __name__ == "__main__":
535
  except Exception as e:
536
  print(f"[WARN] 预下载权重失败(首次推理会再试): {e}")
537
 
538
- demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860)
 
13
  import urllib.request
14
  import warnings
15
  from os import path
16
+ from progressbar import progressbar
17
+ import gc
18
+
19
+ # # 1) 完全禁止 PyTorch 调用 NVML(ZeroGPU/MIG 下经常拿不到 NVML 句柄)
20
+ # os.environ.setdefault("PYTORCH_NO_NVML", "1")
21
+ # # 2) 用 cudaMallocAsync 后端,降低碎片/避免旧分配器的 NVML 路径
22
+ # os.environ.setdefault(
23
+ # "PYTORCH_CUDA_ALLOC_CONF",
24
+ # "backend:cudaMallocAsync,expandable_segments:True,garbage_collection_threshold:0.9,max_split_size_mb:64"
25
+ # )
26
+ # # (可选)定位更准:同步执行
27
+ # os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
28
+ # warnings.filterwarnings("ignore", message="The detected CUDA version .* minor version mismatch")
29
+ # warnings.filterwarnings("ignore", message="There are no g\\+\\+ version bounds defined for CUDA version.*")
30
+ # warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.cpp_extension")
31
+ # os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
32
+ # os.environ.setdefault("MAX_JOBS", "1")
33
 
34
  import gradio as gr
35
  import spaces # ZeroGPU decorator
 
64
  - 参考图 -> `./colormnet_run_<UUID>/input_ref/<视频名不含扩展>/ref.png`
65
  """
66
 
67
+ # ----------------- TEMP WORKDIR -----------------
68
+ TEMP_ROOT = path.join(os.getcwd(), "_colormnet_tmp")
69
+
70
def reset_temp_root(root=None):
    """Empty and recreate the temporary working directory before a run.

    Args:
        root: Directory to reset. Defaults to the module-level ``TEMP_ROOT``
            when omitted, which is what existing callers rely on.

    Raises:
        OSError: If a non-directory entry at ``root`` cannot be removed or the
            directory cannot be created.
    """
    target = TEMP_ROOT if root is None else root
    if path.islink(target) or path.isfile(target):
        # A stray file or (possibly broken) symlink at this path would make
        # os.makedirs raise FileExistsError, or leave stale linked content in
        # place, so remove the entry itself instead of calling rmtree on it.
        os.remove(target)
    else:
        # Best-effort removal; a missing directory is not an error.
        shutil.rmtree(target, ignore_errors=True)
    os.makedirs(target, exist_ok=True)
75
+
76
  torch.set_grad_enabled(False)
77
 
78
  # ----------------- DEBUG (kept) -----------------
 
166
  if idx == 0:
167
  raise RuntimeError("Input video has no readable frames.")
168
 
169
+ return subdir, path.splitext(path.basename(video_path))[0], w, h, fps, idx
170
 
171
  # ---------- place ref image into ref_root/<video_stem>/ref.png ----------
172
  def ref_to_dataset_root(ref_image_path: str, ref_root: str, video_stem: str):
 
206
 
207
  DEVICE = torch.device("cuda")
208
 
209
+ # Workspace in TEMP_ROOT
210
+ base_run_dir = path.join(TEMP_ROOT, f"colormnet_run_{uuid.uuid4().hex}")
211
  input_video_root = path.join(base_run_dir, "input_video")
212
  input_ref_root = path.join(base_run_dir, "input_ref")
213
  output_dir = path.join(base_run_dir, "result")
 
215
  for p in (base_run_dir, input_video_root, input_ref_root, output_dir):
216
  ensure_clean_dir(p)
217
 
218
+ # 1) 抽帧(把抽帧输出到临时目录中)
219
  vid_subdir, vid_stem, w, h, fps, n_frames = video_to_dataset_root(bw_video_path, input_video_root)
220
  assert n_frames > 0, "Input video has no frames."
221
 
222
+ # 2) 参考图(存到临时目录)
223
  _ = ref_to_dataset_root(ref_image_path, input_ref_root, vid_stem)
224
 
225
  # 3) 配置(字段与 main.py 一致;值从 UI 合并)
 
244
  "save_scores": False,
245
  "flip": False,
246
  "size": -1,
247
+ "reverse": False,
248
  }
249
  config = {**default_config, **(user_config or {})}
250
  config["enable_long_term"] = not config["disable_long_term"]
 
253
  meta_dataset = DAVISTestDataset_221128_TransColorization_batch(
254
  input_video_root, imset=input_ref_root, size=config["size"]
255
  )
256
+ meta_loader = meta_dataset.get_datasets()
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  # 输出路径规则(与 main.py 一致)
259
  is_youtube = str(config["dataset"]).startswith("Y")
 
274
  total_process_time = 0.0
275
  total_frames = 0
276
 
277
+ for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
278
+ # 6) 推理(逐帧;内部逻辑与 main.py 对齐;保留调试打印)
279
+ # Gradio/Spaces 环境禁止子进程:num_workers=0(否则会触发 daemonic processes 错误)
280
+ loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
281
+ vid_name = vid_reader.vid_name
282
+ vid_length = len(loader)
283
+
284
+ # 长时记忆触发逻辑:按 main.py 原样(无除零保护)
285
+ config['enable_long_term_count_usage'] = (
286
+ config['enable_long_term'] and
287
+ (vid_length
288
+ / (config['max_mid_term_frames'] - config['min_mid_term_frames'])
289
+ * config['num_prototypes'])
290
+ >= config['max_long_term_elements']
291
+ )
292
 
293
+ mapper = MaskMapper()
294
+ processor = InferenceCore(network, config=config)
295
+ first_mask_loaded = False
296
+
297
+ for ti, data in enumerate(loader):
298
+ try:
299
+ with torch.cuda.amp.autocast(enabled=not config["benchmark"]):
300
+ rgb = data['rgb'].cuda()[0]
301
+ msk = data.get('mask')
302
+ if not config['FirstFrameIsNotExemplar']:
303
+ msk = msk[:, 1:3, :, :] if msk is not None else None
304
+
305
+ info = data['info']
306
+ frame = info['frame'][0]
307
+ shape = info['shape']
308
+ need_resize = info['need_resize'][0]
309
+
310
+ if debug_shapes:
311
+ print(f"[Loop] frame={ti} rgb={tuple(rgb.shape)} "
312
+ f"msk={None if msk is None else tuple(msk.shape)}", flush=True)
313
+
314
+ # timing 与 main.py 一致
315
+ start = torch.cuda.Event(enable_timing=True)
316
+ end = torch.cuda.Event(enable_timing=True)
317
+ start.record()
318
+
319
+ if not first_mask_loaded:
320
+ if msk is not None:
321
+ first_mask_loaded = True
322
+ else:
323
+ continue
324
+
325
+ if config['flip']:
326
+ rgb = torch.flip(rgb, dims=[-1])
327
+ msk = torch.flip(msk, dims=[-1]) if msk is not None else None
328
 
329
+ if msk is not None:
330
+ msk = torch.Tensor(msk[0]).cuda()
331
+ if need_resize:
332
+ msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
333
+ processor.set_all_labels(list(range(1, 3)))
334
+ labels = range(1, 3)
335
+ else:
336
+ labels = None
337
+
338
+ if config['FirstFrameIsNotExemplar']:
339
+ prob = processor.step_AnyExemplar(
340
+ rgb,
341
+ msk[:1, :, :].repeat(3, 1, 1) if msk is not None else None,
342
+ msk[1:3, :, :] if msk is not None else None,
343
+ labels,
344
+ end=(ti == vid_length - 1)
345
+ )
346
+ else:
347
+ prob = processor.step(rgb, msk, labels, end=(ti == vid_length - 1))
348
+
349
+ if need_resize:
350
+ prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, 0]
351
 
352
+ end.record()
353
+ torch.cuda.synchronize()
354
+ total_process_time += (start.elapsed_time(end) / 1000.0)
355
+ total_frames += 1
356
 
357
+ if config['flip']:
358
+ prob = torch.flip(prob, dims=[-1])
 
 
359
 
360
+ if debug_shapes:
361
+ try:
362
+ print(f"[Loop] prob={tuple(prob.shape)}", flush=True)
363
+ except Exception:
364
+ pass
365
 
366
+ if config['save_scores']:
367
+ prob = (prob.detach().cpu().numpy() * 255).astype(np.uint8)
 
 
368
 
369
+ if config['save_all'] or info['save'][0]:
370
+ this_out_path = path.join(out_path, vid_name)
371
+ os.makedirs(this_out_path, exist_ok=True)
 
 
372
 
373
+ out_mask_final = lab2rgb_transform_PIL(torch.cat([rgb[:1, :, :], prob], dim=0))
374
+ out_mask_final = (out_mask_final * 255).astype(np.uint8)
375
+ Image.fromarray(out_mask_final).save(os.path.join(this_out_path, frame[:-4] + '.png'))
376
 
377
+ except Exception as _e:
378
+ # 保留完整 traceback,方便定位
379
+ raise RuntimeError("FRAME_ERROR:\n" + traceback.format_exc())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
 
381
  if total_process_time > 0:
382
  print(f'Total processing time: {total_process_time}')
 
396
  colored_mp4 = path.join(base_run_dir, "colored_output.mp4")
397
  encode_frames_to_video(frames_dir, colored_mp4, fps=fps)
398
 
399
+ # 8) 输出视频到 CWD(只保留最终文件)
400
  final_mp4 = path.join(os.getcwd(), "result.mp4")
401
  shutil.move(colored_mp4, final_mp4)
402
+
403
+ # 清理本次 run 的中间目录;(注:上传的原视频/参考帧位于 TEMP_ROOT,将在下次运行开头被 reset_temp_root 清掉)
404
  shutil.rmtree(base_run_dir, ignore_errors=True)
405
 
406
  return final_mp4
407
 
408
  # ----------------- GRADIO HANDLERS -----------------
409
+ @spaces.GPU(duration=600)
410
  def gradio_infer(
411
  debug_shapes, # 调试开关(保留)
412
  bw_video, ref_image,
413
  first_not_exemplar, dataset, split, save_all, benchmark,
414
  disable_long_term, max_mid, min_mid, max_long,
415
  num_proto, top_k, mem_every, deep_update,
416
+ save_scores, flip, size, reverse # 新增
417
  ):
418
  if not torch.cuda.is_available():
419
  return None, "ZeroGPU 未分配到 GPU,请重试(或检查 Space 硬件是否为 ZeroGPU)。"
 
423
  if ref_image is None:
424
  return None, "请上传参考图像。"
425
 
426
+ # —— 每次运行先重置临时目录 —— #
427
+ reset_temp_root()
428
+
429
+ # Video path -> 拷贝到临时目录
430
  if isinstance(bw_video, dict) and "name" in bw_video:
431
+ src_video_path = bw_video["name"]
432
  elif isinstance(bw_video, str):
433
+ src_video_path = bw_video
434
  else:
435
  return None, "无法读取视频输入。"
436
 
437
+ tmp_video_ext = path.splitext(src_video_path)[1] or ".mp4"
438
+ tmp_video_path = path.join(TEMP_ROOT, "input_video" + tmp_video_ext)
439
+ try:
440
+ shutil.copy2(src_video_path, tmp_video_path)
441
+ except Exception as e:
442
+ return None, f"复制视频到临时目录失败:{e}"
443
+
444
+ # Ref path -> 保存/拷贝到临时目录
445
+ tmp_ref_path = path.join(TEMP_ROOT, "ref.png")
446
  if isinstance(ref_image, Image.Image):
447
+ try:
448
+ ref_image.save(tmp_ref_path)
449
+ except Exception as e:
450
+ return None, f"保存参考图像到临时目录失败:{e}"
451
  elif isinstance(ref_image, str):
452
+ try:
453
+ shutil.copy2(ref_image, tmp_ref_path)
454
+ except Exception as e:
455
+ return None, f"复制参考图像到临时目录失败:{e}"
456
  else:
457
  return None, "无法读取参考图像输入。"
458
 
 
492
  "save_scores": bool(save_scores) if save_scores is not None else default_config["save_scores"],
493
  "flip": bool(flip) if flip is not None else default_config["flip"],
494
  "size": int(size) if size is not None else default_config["size"],
495
+ "reverse": bool(reverse) if reverse is not None else False,
496
  }
497
 
498
  try:
499
  out_mp4 = run_pipeline_cuda(
500
+ tmp_video_path, tmp_ref_path, user_config, debug_shapes=bool(debug_shapes)
501
  )
502
  return out_mp4, "完成 ✅"
503
  except subprocess.CalledProcessError as e:
504
+ # 出错也可以顺手清一下临时目录(可选)
505
+ try: shutil.rmtree(TEMP_ROOT, ignore_errors=True)
506
+ except: pass
507
  return None, f"运行时错误:\n{e}"
508
  except Exception as e:
509
+ try: shutil.rmtree(TEMP_ROOT, ignore_errors=True)
510
+ except: pass
511
  return None, f"{e}"
512
 
513
  # ----------------- UI -----------------
 
521
  inp_video = gr.Video(label="黑白视频(mp4/webm/avi)", interactive=True)
522
  inp_ref = gr.Image(label="参考图像(RGB)", type="pil")
523
 
524
+ gr.Examples(
525
+ label="示例输入",
526
+ examples=[
527
+ ["./example/4.mp4", "./example/4.png"],
528
+ ],
529
+ inputs=[inp_video, inp_ref],
530
+ # 不缓存,避免把推理结果当静态示例
531
+ cache_examples=False,
532
+ )
533
+
534
  with gr.Accordion("高级参数设置(与 main.py 对齐)", open=False):
535
  with gr.Row():
536
+ first_not_exemplar = gr.Checkbox(label="FirstFrameIsNotExemplar", value=True)
537
+ reverse = gr.Checkbox(label="reverse", value=False)
538
  dataset = gr.Textbox(label="dataset", value="D16_batch")
539
  split = gr.Textbox(label="split", value="val")
540
  save_all = gr.Checkbox(label="save_all", value=True)
 
566
  first_not_exemplar, dataset, split, save_all, benchmark,
567
  disable_long_term, max_mid, min_mid, max_long,
568
  num_proto, top_k, mem_every, deep_update,
569
+ save_scores, flip, size, reverse # reverse 已接入
570
  ],
571
  outputs=[out_video, status]
572
  )
 
577
  except Exception as e:
578
  print(f"[WARN] 预下载权重失败(首次推理会再试): {e}")
579
 
580
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)
inference/data/test_datasets.py CHANGED
@@ -5,24 +5,31 @@ import json
5
  from inference.data.video_reader import VideoReader_221128_TransColorization
6
 
7
  class DAVISTestDataset_221128_TransColorization_batch:
8
- def __init__(self, data_root, imset='2017/val.txt', size=-1):
9
  self.image_dir = data_root
10
  self.mask_dir = imset
11
  self.size_dir = data_root
12
  self.size = size
13
 
14
- self.vid_list = [clip_name for clip_name in sorted(os.listdir(data_root)) if clip_name != '.DS_Store']
 
 
 
15
 
16
  # print(lst, len(lst), self.vid_list, self.vid_list_DAVIS2016, path.join(data_root, 'ImageSets', imset));assert 1==0
17
 
18
  def get_datasets(self):
19
  for video in self.vid_list:
 
 
 
20
  # print(self.image_dir, video, path.join(self.image_dir, video));assert 1==0
21
  yield VideoReader_221128_TransColorization(video,
22
  path.join(self.image_dir, video),
23
  path.join(self.mask_dir, video),
24
  size=self.size,
25
  size_dir=path.join(self.size_dir, video),
 
26
  )
27
 
28
  def __len__(self):
 
5
  from inference.data.video_reader import VideoReader_221128_TransColorization
6
 
7
  class DAVISTestDataset_221128_TransColorization_batch:
8
+ def __init__(self, data_root, imset='2017/val.txt', size=-1, args=None):
9
  self.image_dir = data_root
10
  self.mask_dir = imset
11
  self.size_dir = data_root
12
  self.size = size
13
 
14
+ self.vid_list = [clip_name for clip_name in sorted(os.listdir(data_root)) if clip_name != '.DS_Store' and not clip_name.startswith('.')]
15
+ self.ref_img_list = [clip_name for clip_name in sorted(os.listdir(imset)) if clip_name != '.DS_Store' and not clip_name.startswith('.')]
16
+
17
+ self.args = args
18
 
19
  # print(lst, len(lst), self.vid_list, self.vid_list_DAVIS2016, path.join(data_root, 'ImageSets', imset));assert 1==0
20
 
21
  def get_datasets(self):
22
  for video in self.vid_list:
23
+ if video not in self.ref_img_list:
24
+ continue
25
+
26
  # print(self.image_dir, video, path.join(self.image_dir, video));assert 1==0
27
  yield VideoReader_221128_TransColorization(video,
28
  path.join(self.image_dir, video),
29
  path.join(self.mask_dir, video),
30
  size=self.size,
31
  size_dir=path.join(self.size_dir, video),
32
+ args=self.args
33
  )
34
 
35
  def __len__(self):
inference/data/video_reader.py CHANGED
@@ -14,7 +14,7 @@ class VideoReader_221128_TransColorization(Dataset):
14
  """
15
  This class is used to read a video, one frame at a time
16
  """
17
- def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None):
18
  """
19
  image_dir - points to a directory of jpg images
20
  mask_dir - points to a directory of png masks
@@ -35,9 +35,10 @@ class VideoReader_221128_TransColorization(Dataset):
35
  else:
36
  self.size_dir = size_dir
37
 
38
- self.frames = [img for img in sorted(os.listdir(self.image_dir)) if img.endswith('.jpg') or img.endswith('.png')]
39
- self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette()
40
- self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0])
 
41
  self.suffix = self.first_gt_path.split('.')[-1]
42
 
43
  if size < 0:
@@ -87,8 +88,11 @@ class VideoReader_221128_TransColorization(Dataset):
87
  mask = mask.resize((img.shape[2], img.shape[1]), Image.BILINEAR)
88
 
89
  mask = self.im_transform(mask)
90
- mask_ab = mask[1:3,:,:]
91
- data['mask'] = mask_ab
 
 
 
92
 
93
  info['shape'] = shape
94
  info['need_resize'] = not (self.size < 0)
 
14
  """
15
  This class is used to read a video, one frame at a time
16
  """
17
+ def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None, args=None):
18
  """
19
  image_dir - points to a directory of jpg images
20
  mask_dir - points to a directory of png masks
 
35
  else:
36
  self.size_dir = size_dir
37
 
38
+ flag_reverse = args.getattr('reverse', False) if args is not None else False
39
+ self.frames = [img for img in sorted(os.listdir(self.image_dir), reverse=flag_reverse) if (img.endswith('.jpg') or img.endswith('.png')) and not img.startswith('.')]
40
+ self.palette = Image.open(path.join(mask_dir, sorted([msk for msk in os.listdir(mask_dir) if not msk.startswith('.')])[0])).getpalette()
41
+ self.first_gt_path = path.join(self.mask_dir, sorted([msk for msk in os.listdir(self.mask_dir) if not msk.startswith('.')])[0])
42
  self.suffix = self.first_gt_path.split('.')[-1]
43
 
44
  if size < 0:
 
88
  mask = mask.resize((img.shape[2], img.shape[1]), Image.BILINEAR)
89
 
90
  mask = self.im_transform(mask)
91
+
92
+ # keep L channel of reference image in case First frame is not exemplar
93
+ # mask_ab = mask[1:3,:,:]
94
+ # data['mask'] = mask_ab
95
+ data['mask'] = mask
96
 
97
  info['shape'] = shape
98
  info['need_resize'] = not (self.size < 0)
inference/inference_core.py CHANGED
@@ -109,3 +109,102 @@ class InferenceCore:
109
  self.last_deep_update_ti = self.curr_ti
110
 
111
  return unpad(pred_prob_with_bg, self.pad)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  self.last_deep_update_ti = self.curr_ti
110
 
111
  return unpad(pred_prob_with_bg, self.pad)
112
+
113
    def step_AnyExemplar(self, image, msk_lll=None, msk_ab=None, valid_labels=None, end=False, flag_FirstframeIsExemplar=False):
        """Process one frame, optionally seeding memory from an exemplar.

        When ``msk_ab`` is given and ``flag_FirstframeIsExemplar`` is False, the
        exemplar (``msk_lll`` with ``msk_ab``) is encoded and stored in memory
        first, and the current frame is then segmented against it. When
        ``flag_FirstframeIsExemplar`` is True, a given ``msk_ab`` is instead
        used directly as this frame's prediction.

        Args:
            image: Current frame (3*H*W per the original comment).
            msk_lll: Exemplar input that is passed through ``encode_key`` like
                an image. Only read when ``msk_ab`` is given and
                ``flag_FirstframeIsExemplar`` is False.
            msk_ab: Exemplar target that is stored as the memory value and/or
                returned as the prediction, or None.
            valid_labels: Only its length is compared with ``self.all_labels``
                to decide whether segmentation is needed; may be None.
            end: When True, no memory frame is added and no hidden-state
                update (deep or normal) is requested.
            flag_FirstframeIsExemplar: Selects between the two exemplar modes
                described above.

        Returns:
            ``unpad(pred_prob_with_bg, self.pad)`` for the current frame.
        """
        # image: 3*H*W
        # mask: num_objects*H*W or None
        # Divisor handed to pad_divide_by; the trailing "16" is presumably an
        # earlier value — TODO confirm.
        divide_by = 112 # 16
        self.curr_ti += 1
        image, self.pad = pad_divide_by(image, divide_by)
        image = image.unsqueeze(0) # add the batch dimension

        # A frame is stored in memory when enough frames have passed since the
        # last stored one, or whenever an exemplar target is supplied.
        is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (msk_ab is not None)) and (not end)
        # Parsed as "(A and B) if not flag else (C and B)": the two arms differ
        # only in ">= 0" versus "> 0" on curr_ti.
        need_segment = (self.curr_ti >= 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels))) if not flag_FirstframeIsExemplar else (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels)))
        is_deep_update = (
            (self.deep_update_sync and is_mem_frame) or # synchronized
            (not self.deep_update_sync and self.curr_ti-self.last_deep_update_ti >= self.deep_update_every) # no-sync
        ) and (not end)
        is_normal_update = (not self.deep_update_sync or not is_deep_update) and (not end)

        key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image,
                                                    need_ek=(self.enable_long_term or need_segment),
                                                    need_sk=is_mem_frame)
        multi_scale_features = (f16, f8, f4)

        # Separate-exemplar mode: encode the exemplar and store it as memory
        # before the current frame is segmented.
        if msk_ab is not None and not flag_FirstframeIsExemplar:
            need_segment = True
            # NOTE(review): is_normal_update was computed above from the
            # previous is_deep_update value and is not recomputed — confirm
            # that this is intended.
            is_deep_update = False

            msk_lll, _ = pad_divide_by(msk_lll, divide_by)
            msk_lll = msk_lll.unsqueeze(0) # add the batch dimension
            # f8_mask and f4_mask are not read again in this method.
            key_mask, shrinkage_mask, selection_mask, f16_mask, f8_mask, f4_mask = self.network.encode_key(msk_lll,
                                                    need_ek=(self.enable_long_term or need_segment),
                                                    need_sk=is_mem_frame)

            msk_ab, _ = pad_divide_by(msk_ab, divide_by)
            pred_prob_with_bg = msk_ab


            self.memory.create_hidden_state(2, key)


            value_mask, hidden_mask = self.network.encode_value(msk_lll, f16_mask, self.memory.get_hidden(),
                                    pred_prob_with_bg.unsqueeze(0), is_deep_update=False)

            # save key-value to memory
            self.memory.add_memory(key_mask, shrinkage_mask, value_mask, self.all_labels,
                                    selection=selection_mask if self.enable_long_term else None)
            self.last_mem_ti = self.curr_ti

            # The exemplar becomes the "previous frame" used by the short-term
            # branch on later calls.
            self.last_ti_key = key_mask
            self.last_ti_value = value_mask

            # NOTE(review): is_deep_update was forced to False at the top of
            # this branch, so this block looks unreachable — confirm.
            if is_deep_update:
                self.memory.set_hidden(hidden_mask)
                self.last_deep_update_ti = self.curr_ti

        # segment the current frame is needed
        if need_segment:
            memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)

            # short term memory
            batch, num_objects, value_dim, h, w = self.last_ti_value.shape
            last_ti_value = self.last_ti_value.flatten(start_dim=1, end_dim=2)

            # Short-term attention is skipped on the call where the exemplar
            # was just written to memory in the branch above.
            if not (msk_ab is not None and not flag_FirstframeIsExemplar):
                memory_value_short, _ = self.network.short_term_attn(key, self.last_ti_key, last_ti_value, None, key.shape[-2:])
                memory_value_short = memory_value_short.permute(1, 2, 0).view(batch, num_objects, value_dim, h, w)
                memory_readout += memory_value_short
            hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout,
                                    self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False)
            # remove batch dim
            pred_prob_with_bg = pred_prob_with_bg[0]
            # pred_prob_no_bg is assigned but not read again in this method.
            pred_prob_no_bg = pred_prob_with_bg
            if is_normal_update:
                self.memory.set_hidden(hidden)
        else:
            pred_prob_no_bg = pred_prob_with_bg = None

        # use the input mask if any
        if msk_ab is not None and flag_FirstframeIsExemplar:
            msk_ab, _ = pad_divide_by(msk_ab, divide_by)
            pred_prob_with_bg = msk_ab

        # save as memory if needed
        # NOTE(review): if no segmentation ran and no msk_ab was given,
        # pred_prob_with_bg is still None here and at the final return —
        # confirm that callers never reach that combination.
        if is_mem_frame:
            value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(),
                                    pred_prob_with_bg.unsqueeze(0), is_deep_update=is_deep_update)

            self.memory.add_memory(key, shrinkage, value, self.all_labels,
                                    selection=selection if self.enable_long_term else None)
            self.last_mem_ti = self.curr_ti

            self.last_ti_key = key
            self.last_ti_value = value

            if is_deep_update:
                self.memory.set_hidden(hidden)
                self.last_deep_update_ti = self.curr_ti

        return unpad(pred_prob_with_bg, self.pad)