yyang181 committed on
Commit
a5ec286
·
1 Parent(s): ecd9134
Files changed (1) hide show
  1. app.py +77 -30
app.py CHANGED
@@ -17,20 +17,55 @@ import gradio as gr
17
  import spaces
18
  from PIL import Image
19
  import cv2
 
20
 
21
  # ----------------- BASIC INFO -----------------
22
  CHECKPOINT_URL = "https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth"
23
  CHECKPOINT_LOCAL = "DINOv2FeatureV6_LocalAtten_s2_154000.pth"
24
 
25
- TITLE = "ColorMNet — ZeroGPU (CUDA-only) Video Colorization with Reference Image"
26
  DESC = """
27
- 上传**黑白视频**与**参考图像**,点击“开始着色”。
 
28
  此版本在 **app.py 中调度 ZeroGPU**,并**在同一进程**调用 `test.py` 的入口函数。
29
- 临时工作目录结构:
30
- - 抽帧:`_colormnet_tmp/input_video/<视频名>/00000.png ...`
31
- - 参考:`_colormnet_tmp/ref/<视频名>/ref.png`
32
- - 输出:`_colormnet_tmp/output/<视频名>/*.png`
33
  - 合成视频:`_colormnet_tmp/<视频名>.mp4`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  """
35
 
36
  # ----------------- TEMP WORKDIR -----------------
@@ -83,11 +118,11 @@ def video_to_frames_dir(video_path: str, frames_dir: str):
83
  out_path = path.join(frames_dir, f"{idx:05d}.png")
84
  ok = cv2.imwrite(out_path, frame)
85
  if not ok:
86
- raise RuntimeError(f"写入抽帧失败: {out_path}")
87
  idx += 1
88
  cap.release()
89
  if idx == 0:
90
- raise RuntimeError("Input video has no readable frames.")
91
  return w, h, fps, idx
92
 
93
  def encode_frames_to_video(frames_dir: str, out_path: str, fps: float):
@@ -166,9 +201,9 @@ def gradio_infer(
166
  ):
167
  # 1) 基本校验与临时目录
168
  if bw_video is None:
169
- return None, "请上传黑白视频"
170
  if ref_image is None:
171
- return None, "请上传参考图像"
172
  reset_temp_root()
173
 
174
  # 2) 解析视频源路径 & 目标 <video_stem>
@@ -177,7 +212,7 @@ def gradio_infer(
177
  elif isinstance(bw_video, str):
178
  src_video_path = bw_video
179
  else:
180
- return None, "无法读取视频输入"
181
 
182
  video_stem = path.splitext(path.basename(src_video_path))[0]
183
 
@@ -195,7 +230,7 @@ def gradio_infer(
195
  try:
196
  _w, _h, fps, _n = video_to_frames_dir(src_video_path, input_frames_dir)
197
  except Exception as e:
198
- return None, f"抽帧失败\n{e}"
199
 
200
  # 5) 参考帧 -> ref/<stem>/ref.png
201
  ref_png_path = path.join(ref_dir, "ref.png")
@@ -203,14 +238,14 @@ def gradio_infer(
203
  try:
204
  ref_image.save(ref_png_path)
205
  except Exception as e:
206
- return None, f"保存参考图像失败\n{e}"
207
  elif isinstance(ref_image, str):
208
  try:
209
  shutil.copy2(ref_image, ref_png_path)
210
  except Exception as e:
211
- return None, f"复制参考图像失败\n{e}"
212
  else:
213
- return None, "无法读取参考图像输入"
214
 
215
  # 6) 收集 UI 配置
216
  default_config = {
@@ -259,7 +294,7 @@ def gradio_infer(
259
  try:
260
  import test # 确保 test.py 同目录且有 run_cli 函数
261
  except Exception as e:
262
- return None, f"导入 test.py 失败:\n{e}"
263
 
264
  args_list = build_args_list_for_test(
265
  d16_batch_path=input_root, # 指向 input_video 根
@@ -280,44 +315,52 @@ def gradio_infer(
280
  log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}\n\nERROR: {e}"
281
  return None, log
282
 
283
- # 在合成 mp4 之前
284
- torch.cuda.synchronize()
 
 
 
285
  try:
286
  del network, processor, loader, vid_reader, data, rgb, msk, prob
287
  except Exception:
288
  pass
289
- torch.cuda.empty_cache()
 
 
 
290
 
291
  # 9) 合成 mp4:从 output/<stem>/ 帧合成 -> TEMP_ROOT/<stem>.mp4
292
  out_frames = path.join(output_root, video_stem)
293
  if not path.isdir(out_frames):
294
- return None, f"未找到输出帧目录:{out_frames}\n\n{log}"
295
  final_mp4 = path.join(TEMP_ROOT, f"{video_stem}.mp4")
296
  try:
297
  encode_frames_to_video(out_frames, final_mp4, fps=fps)
298
  except Exception as e:
299
- return None, f"合成视频失败:\n{e}\n\n{log}"
300
 
301
- return final_mp4, f"完成 ✅\n\n{log}"
302
 
303
  # ----------------- UI -----------------
304
  with gr.Blocks() as demo:
305
  gr.Markdown(f"# {TITLE}")
 
 
306
  gr.Markdown(DESC)
307
 
308
- debug_shapes = gr.Checkbox(label="调试日志(仅用于显示更完整日志)", value=False)
309
 
310
  with gr.Row():
311
- inp_video = gr.Video(label="黑白视频(mp4/webm/avi)", interactive=True)
312
- inp_ref = gr.Image(label="参考图像(RGB)", type="pil")
313
  gr.Examples(
314
- label="示例输入",
315
  examples=[["./example/4.mp4", "./example/4.png"]],
316
  inputs=[inp_video, inp_ref],
317
  cache_examples=False,
318
  )
319
 
320
- with gr.Accordion("高级参数设置(传给 test.py)", open=False):
321
  with gr.Row():
322
  first_not_exemplar = gr.Checkbox(label="FirstFrameIsNotExemplar (--FirstFrameIsNotExemplar)", value=True)
323
  reverse = gr.Checkbox(label="reverse (--reverse)", value=False)
@@ -339,10 +382,10 @@ with gr.Blocks() as demo:
339
  flip = gr.Checkbox(label="flip (--flip)", value=False)
340
  size = gr.Number(label="size (--size)", value=-1, precision=0)
341
 
342
- run_btn = gr.Button("开始着色(同进程调用 test.py)")
343
  with gr.Row():
344
- out_video = gr.Video(label="输出视频(着色结果)")
345
- status = gr.Textbox(label="状态 / 日志输出(test.py stdout/stderr)", interactive=False, lines=16)
346
 
347
  run_btn.click(
348
  fn=gradio_infer,
@@ -357,6 +400,10 @@ with gr.Blocks() as demo:
357
  outputs=[out_video, status]
358
  )
359
 
 
 
 
 
360
  if __name__ == "__main__":
361
  try:
362
  ensure_checkpoint()
 
17
  import spaces
18
  from PIL import Image
19
  import cv2
20
+ import torch # used for cuda sync & empty_cache
21
 
22
  # ----------------- BASIC INFO -----------------
23
  CHECKPOINT_URL = "https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth"
24
  CHECKPOINT_LOCAL = "DINOv2FeatureV6_LocalAtten_s2_154000.pth"
25
 
26
+ TITLE = "ColorMNet — 视频着色 / Video Colorization (ZeroGPU, CUDA-only)"
27
  DESC = """
28
+ **中文**
29
+ 上传**黑白视频**与**参考图像**,点击「开始着色 / Start Coloring」。
30
  此版本在 **app.py 中调度 ZeroGPU**,并**在同一进程**调用 `test.py` 的入口函数。
31
+ 临时工作目录结构:
32
+ - 抽帧:`_colormnet_tmp/input_video/<视频名>/00000.png ...`
33
+ - 参考:`_colormnet_tmp/ref/<视频名>/ref.png`
34
+ - 输出:`_colormnet_tmp/output/<视频名>/*.png`
35
  - 合成视频:`_colormnet_tmp/<视频名>.mp4`
36
+
37
+ **English**
38
+ Upload a **B&W video** and a **reference image**, then click “Start Coloring”.
39
+ This app runs **ZeroGPU scheduling in `app.py`** and calls `test.py` **in-process**.
40
+ Temp workspace layout:
41
+ - Frames: `_colormnet_tmp/input_video/<stem>/00000.png ...`
42
+ - Reference: `_colormnet_tmp/ref/<stem>/ref.png`
43
+ - Output frames: `_colormnet_tmp/output/<stem>/*.png`
44
+ - Final video: `_colormnet_tmp/<stem>.mp4`
45
+ """
46
+
47
+ PAPER = """
48
+ ### 论文 / Paper
49
+ **ECCV 2024 — ColorMNet: A Memory-based Deep Spatial-Temporal Feature Propagation Network for Video Colorization**
50
+
51
+ 如果你喜欢这个项目,欢迎到 GitHub 点个 ⭐ Star:
52
+ **GitHub**: https://github.com/yyang181/colormnet
53
+ """
54
+
55
+ BADGES_HTML = """
56
+ <div style="display:flex;gap:12px;align-items:center;flex-wrap:wrap;">
57
+ <a href="https://github.com/yyang181/colormnet" target="_blank" title="Open GitHub Repo">
58
+ <img alt="GitHub Repo"
59
+ src="https://img.shields.io/badge/GitHub-colormnet-181717?logo=github" />
60
+ </a>
61
+ <a href="https://github.com/yyang181/colormnet/stargazers" target="_blank" title="Star on GitHub">
62
+ <img alt="GitHub Repo stars"
63
+ src="https://img.shields.io/github/stars/yyang181/colormnet?style=social" />
64
+ </a>
65
+ <span style="opacity:0.85">
66
+ 喜欢这个项目就点个 ⭐ Star / If you like it, please ⭐ Star!
67
+ </span>
68
+ </div>
69
  """
70
 
71
  # ----------------- TEMP WORKDIR -----------------
 
118
  out_path = path.join(frames_dir, f"{idx:05d}.png")
119
  ok = cv2.imwrite(out_path, frame)
120
  if not ok:
121
+ raise RuntimeError(f"写入抽帧失败 / Failed to write: {out_path}")
122
  idx += 1
123
  cap.release()
124
  if idx == 0:
125
+ raise RuntimeError("视频无可读帧 / Input video has no readable frames.")
126
  return w, h, fps, idx
127
 
128
  def encode_frames_to_video(frames_dir: str, out_path: str, fps: float):
 
201
  ):
202
  # 1) 基本校验与临时目录
203
  if bw_video is None:
204
+ return None, "请上传黑白视频 / Please upload a B&W video."
205
  if ref_image is None:
206
+ return None, "请上传参考图像 / Please upload a reference image."
207
  reset_temp_root()
208
 
209
  # 2) 解析视频源路径 & 目标 <video_stem>
 
212
  elif isinstance(bw_video, str):
213
  src_video_path = bw_video
214
  else:
215
+ return None, "无法读取视频输入 / Failed to read video input."
216
 
217
  video_stem = path.splitext(path.basename(src_video_path))[0]
218
 
 
230
  try:
231
  _w, _h, fps, _n = video_to_frames_dir(src_video_path, input_frames_dir)
232
  except Exception as e:
233
+ return None, f"抽帧失败 / Frame extraction failed:\n{e}"
234
 
235
  # 5) 参考帧 -> ref/<stem>/ref.png
236
  ref_png_path = path.join(ref_dir, "ref.png")
 
238
  try:
239
  ref_image.save(ref_png_path)
240
  except Exception as e:
241
+ return None, f"保存参考图像失败 / Failed to save reference image:\n{e}"
242
  elif isinstance(ref_image, str):
243
  try:
244
  shutil.copy2(ref_image, ref_png_path)
245
  except Exception as e:
246
+ return None, f"复制参考图像失败 / Failed to copy reference image:\n{e}"
247
  else:
248
+ return None, "无法读取参考图像输入 / Failed to read reference image."
249
 
250
  # 6) 收集 UI 配置
251
  default_config = {
 
294
  try:
295
  import test # 确保 test.py 同目录且有 run_cli 函数
296
  except Exception as e:
297
+ return None, f"导入 test.py 失败 / Failed to import test.py:\n{e}"
298
 
299
  args_list = build_args_list_for_test(
300
  d16_batch_path=input_root, # 指向 input_video 根
 
315
  log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}\n\nERROR: {e}"
316
  return None, log
317
 
318
+ # 在合成 mp4 之前:清空 CUDA
319
+ try:
320
+ torch.cuda.synchronize()
321
+ except Exception:
322
+ pass
323
  try:
324
  del network, processor, loader, vid_reader, data, rgb, msk, prob
325
  except Exception:
326
  pass
327
+ try:
328
+ torch.cuda.empty_cache()
329
+ except Exception:
330
+ pass
331
 
332
  # 9) 合成 mp4:从 output/<stem>/ 帧合成 -> TEMP_ROOT/<stem>.mp4
333
  out_frames = path.join(output_root, video_stem)
334
  if not path.isdir(out_frames):
335
+ return None, f"未找到输出帧目录 / Output frame dir not found:{out_frames}\n\n{log}"
336
  final_mp4 = path.join(TEMP_ROOT, f"{video_stem}.mp4")
337
  try:
338
  encode_frames_to_video(out_frames, final_mp4, fps=fps)
339
  except Exception as e:
340
+ return None, f"合成视频失败 / Video mux failed:\n{e}\n\n{log}"
341
 
342
+ return final_mp4, f"完成 ✅ / Done ✅\n\n{log}"
343
 
344
  # ----------------- UI -----------------
345
  with gr.Blocks() as demo:
346
  gr.Markdown(f"# {TITLE}")
347
+ gr.HTML(BADGES_HTML) # 页头徽章
348
+ gr.Markdown(PAPER)
349
  gr.Markdown(DESC)
350
 
351
+ debug_shapes = gr.Checkbox(label="调试日志 / Debug Logs(仅用于显示更完整日志 / show verbose logs)", value=False)
352
 
353
  with gr.Row():
354
+ inp_video = gr.Video(label="黑白视频(mp4/webm/avi) / B&W Video", interactive=True)
355
+ inp_ref = gr.Image(label="参考图像(RGB) / Reference Image (RGB)", type="pil")
356
  gr.Examples(
357
+ label="示例 / Examples",
358
  examples=[["./example/4.mp4", "./example/4.png"]],
359
  inputs=[inp_video, inp_ref],
360
  cache_examples=False,
361
  )
362
 
363
+ with gr.Accordion("高级参数设置 / Advanced Settings(传给 test.py / passed to test.py)", open=False):
364
  with gr.Row():
365
  first_not_exemplar = gr.Checkbox(label="FirstFrameIsNotExemplar (--FirstFrameIsNotExemplar)", value=True)
366
  reverse = gr.Checkbox(label="reverse (--reverse)", value=False)
 
382
  flip = gr.Checkbox(label="flip (--flip)", value=False)
383
  size = gr.Number(label="size (--size)", value=-1, precision=0)
384
 
385
+ run_btn = gr.Button("开始着色 / Start Coloring(同进程调用 test.py / in-process)")
386
  with gr.Row():
387
+ out_video = gr.Video(label="输出视频(着色结果) / Output (Colorized)")
388
+ status = gr.Textbox(label="状态 / 日志输出(test.py stdout/stderr) / Status & Logs", interactive=False, lines=16)
389
 
390
  run_btn.click(
391
  fn=gradio_infer,
 
400
  outputs=[out_video, status]
401
  )
402
 
403
+ # 页脚徽章
404
+ gr.HTML("<hr/>")
405
+ gr.HTML(BADGES_HTML)
406
+
407
  if __name__ == "__main__":
408
  try:
409
  ensure_checkpoint()