Spaces:
Configuration error
Configuration error
backup
Browse files- .gitignore +1 -1
- app.py +182 -405
- inference/data/test_datasets.py +104 -24
- test.py +235 -174
.gitignore
CHANGED
|
@@ -11,7 +11,7 @@ Pytorch-Correlation-extension/
|
|
| 11 |
result
|
| 12 |
src/
|
| 13 |
DINOv2FeatureV6_LocalAtten_s2_154000.pth
|
| 14 |
-
|
| 15 |
|
| 16 |
# Byte-compiled / optimized / DLL files
|
| 17 |
__pycache__/
|
|
|
|
| 11 |
result
|
| 12 |
src/
|
| 13 |
DINOv2FeatureV6_LocalAtten_s2_154000.pth
|
| 14 |
+
_colormnet_tmp/
|
| 15 |
|
| 16 |
# Byte-compiled / optimized / DLL files
|
| 17 |
__pycache__/
|
app.py
CHANGED
|
@@ -1,432 +1,177 @@
|
|
| 1 |
-
# app.py
|
| 2 |
-
#
|
| 3 |
-
#
|
| 4 |
-
#
|
| 5 |
-
#
|
| 6 |
-
#
|
| 7 |
|
| 8 |
import os
|
| 9 |
import sys
|
| 10 |
import shutil
|
| 11 |
-
import subprocess
|
| 12 |
-
import uuid
|
| 13 |
import urllib.request
|
| 14 |
-
import warnings
|
| 15 |
from os import path
|
| 16 |
-
|
| 17 |
-
import
|
| 18 |
-
|
| 19 |
-
# # 1) 完全禁止 PyTorch 调用 NVML(ZeroGPU/MIG 下经常拿不到 NVML 句柄)
|
| 20 |
-
# os.environ.setdefault("PYTORCH_NO_NVML", "1")
|
| 21 |
-
# # 2) 用 cudaMallocAsync 后端,降低碎片/避免旧分配器的 NVML 路径
|
| 22 |
-
# os.environ.setdefault(
|
| 23 |
-
# "PYTORCH_CUDA_ALLOC_CONF",
|
| 24 |
-
# "backend:cudaMallocAsync,expandable_segments:True,garbage_collection_threshold:0.9,max_split_size_mb:64"
|
| 25 |
-
# )
|
| 26 |
-
# # (可选)定位更准:同步执行
|
| 27 |
-
# os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
|
| 28 |
-
# warnings.filterwarnings("ignore", message="The detected CUDA version .* minor version mismatch")
|
| 29 |
-
# warnings.filterwarnings("ignore", message="There are no g\\+\\+ version bounds defined for CUDA version.*")
|
| 30 |
-
# warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.cpp_extension")
|
| 31 |
-
# os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
|
| 32 |
-
# os.environ.setdefault("MAX_JOBS", "1")
|
| 33 |
|
| 34 |
import gradio as gr
|
| 35 |
-
import spaces
|
| 36 |
-
import numpy as np
|
| 37 |
from PIL import Image
|
| 38 |
import cv2
|
| 39 |
-
import traceback
|
| 40 |
|
| 41 |
-
|
| 42 |
-
import torch.nn.functional as F
|
| 43 |
-
from torch.utils.data import DataLoader
|
| 44 |
-
|
| 45 |
-
# ---- Project imports ----
|
| 46 |
-
from inference.data.test_datasets import DAVISTestDataset_221128_TransColorization_batch
|
| 47 |
-
from inference.data.mask_mapper import MaskMapper
|
| 48 |
-
from model.network import ColorMNet
|
| 49 |
-
from inference.inference_core import InferenceCore
|
| 50 |
-
from dataset.range_transform import inv_lll2rgb_trans
|
| 51 |
-
from skimage import color
|
| 52 |
-
|
| 53 |
-
# ----------------- CONFIG -----------------
|
| 54 |
CHECKPOINT_URL = "https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth"
|
| 55 |
CHECKPOINT_LOCAL = "DINOv2FeatureV6_LocalAtten_s2_154000.pth"
|
| 56 |
|
| 57 |
TITLE = "ColorMNet — ZeroGPU (CUDA-only) Video Colorization with Reference Image"
|
| 58 |
DESC = """
|
| 59 |
上传**黑白视频**与**参考图像**,点击“开始着色”。
|
| 60 |
-
本
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
-
|
| 64 |
-
-
|
|
|
|
| 65 |
"""
|
| 66 |
|
| 67 |
# ----------------- TEMP WORKDIR -----------------
|
| 68 |
TEMP_ROOT = path.join(os.getcwd(), "_colormnet_tmp")
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
def reset_temp_root():
|
| 71 |
"""每次运行前清空并重建临时工作目录。"""
|
| 72 |
if path.isdir(TEMP_ROOT):
|
| 73 |
shutil.rmtree(TEMP_ROOT, ignore_errors=True)
|
| 74 |
os.makedirs(TEMP_ROOT, exist_ok=True)
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
# ----------------- DEBUG (kept) -----------------
|
| 79 |
-
def _enable_runtime_debug():
|
| 80 |
-
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # 同步执行,定位准确
|
| 81 |
-
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1" # 显示 C++ 栈
|
| 82 |
-
os.environ["PYTORCH_JIT"] = "0" # 关闭 JIT
|
| 83 |
-
try:
|
| 84 |
-
torch.autograd.set_detect_anomaly(True) # 捕捉无效 op/grad
|
| 85 |
-
except Exception:
|
| 86 |
-
pass
|
| 87 |
-
|
| 88 |
-
# ----------------- PATH/DIR UTILS -----------------
|
| 89 |
-
def ensure_clean_dir(d: str):
|
| 90 |
-
if path.exists(d):
|
| 91 |
-
if path.isdir(d):
|
| 92 |
-
return
|
| 93 |
-
else:
|
| 94 |
-
os.remove(d)
|
| 95 |
os.makedirs(d, exist_ok=True)
|
| 96 |
|
| 97 |
-
# -----------------
|
| 98 |
def ensure_checkpoint():
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def tensor_to_np_float(image: torch.Tensor) -> np.ndarray:
|
| 108 |
-
image_np = image.numpy().astype("float32")
|
| 109 |
-
return image_np
|
| 110 |
-
|
| 111 |
-
def lab2rgb_transform_PIL(mask: torch.Tensor) -> np.ndarray:
|
| 112 |
-
mask_d = detach_to_cpu(mask)
|
| 113 |
-
mask_d = inv_lll2rgb_trans(mask_d)
|
| 114 |
-
im = tensor_to_np_float(mask_d)
|
| 115 |
-
if len(im.shape) == 3:
|
| 116 |
-
im = im.transpose((1, 2, 0))
|
| 117 |
-
else:
|
| 118 |
-
im = im[:, :, None]
|
| 119 |
-
im = color.lab2rgb(im)
|
| 120 |
-
return im.clip(0, 1)
|
| 121 |
|
| 122 |
-
# ----------
|
| 123 |
-
def
|
| 124 |
"""
|
| 125 |
-
|
| 126 |
-
返回: (
|
| 127 |
"""
|
| 128 |
-
|
| 129 |
-
basename = path.basename(video_path)
|
| 130 |
-
stem, _ = path.splitext(basename)
|
| 131 |
-
subdir = path.join(dataset_root, stem)
|
| 132 |
-
ensure_clean_dir(subdir)
|
| 133 |
-
|
| 134 |
cap = cv2.VideoCapture(video_path)
|
| 135 |
assert cap.isOpened(), f"Cannot open video: {video_path}"
|
| 136 |
-
|
| 137 |
-
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 138 |
-
if not fps or fps <= 0:
|
| 139 |
-
fps = 25.0
|
| 140 |
-
|
| 141 |
idx = 0
|
| 142 |
w = h = None
|
| 143 |
-
|
| 144 |
while True:
|
| 145 |
ret, frame = cap.read()
|
| 146 |
if not ret:
|
| 147 |
break
|
| 148 |
if frame is None:
|
| 149 |
continue
|
| 150 |
-
|
| 151 |
h, w = frame.shape[:2]
|
| 152 |
-
out_path = path.join(
|
| 153 |
-
|
| 154 |
-
parent = path.dirname(out_path)
|
| 155 |
-
if not path.isdir(parent):
|
| 156 |
-
if path.exists(parent):
|
| 157 |
-
os.remove(parent)
|
| 158 |
-
os.makedirs(parent, exist_ok=True)
|
| 159 |
-
|
| 160 |
ok = cv2.imwrite(out_path, frame)
|
| 161 |
if not ok:
|
| 162 |
raise RuntimeError(f"写入抽帧失败: {out_path}")
|
| 163 |
idx += 1
|
| 164 |
-
|
| 165 |
cap.release()
|
| 166 |
if idx == 0:
|
| 167 |
raise RuntimeError("Input video has no readable frames.")
|
| 168 |
-
|
| 169 |
-
return subdir, path.splitext(path.basename(video_path))[0], w, h, fps, idx
|
| 170 |
-
|
| 171 |
-
# ---------- place ref image into ref_root/<video_stem>/ref.png ----------
|
| 172 |
-
def ref_to_dataset_root(ref_image_path: str, ref_root: str, video_stem: str):
|
| 173 |
-
ensure_clean_dir(ref_root)
|
| 174 |
-
subdir = path.join(ref_root, video_stem)
|
| 175 |
-
ensure_clean_dir(subdir)
|
| 176 |
-
|
| 177 |
-
img = Image.open(ref_image_path).convert("RGB")
|
| 178 |
-
out_path = path.join(subdir, "ref.png")
|
| 179 |
-
img.save(out_path)
|
| 180 |
-
return subdir
|
| 181 |
|
| 182 |
def encode_frames_to_video(frames_dir: str, out_path: str, fps: float):
|
| 183 |
frames = sorted([f for f in os.listdir(frames_dir) if f.lower().endswith(".png")])
|
| 184 |
-
|
| 185 |
-
|
| 186 |
first = cv2.imread(path.join(frames_dir, frames[0]))
|
|
|
|
|
|
|
| 187 |
h, w = first.shape[:2]
|
| 188 |
-
|
| 189 |
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 190 |
vw = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
|
| 191 |
for f in frames:
|
| 192 |
img = cv2.imread(path.join(frames_dir, f))
|
|
|
|
|
|
|
| 193 |
vw.write(img)
|
| 194 |
vw.release()
|
| 195 |
|
| 196 |
-
# -----------------
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
"
|
| 228 |
-
"
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
"max_long_term_elements": 10000,
|
| 240 |
-
"num_prototypes": 128,
|
| 241 |
-
"top_k": 30,
|
| 242 |
-
"mem_every": 5,
|
| 243 |
-
"deep_update_every": -1,
|
| 244 |
-
"save_scores": False,
|
| 245 |
-
"flip": False,
|
| 246 |
-
"size": -1,
|
| 247 |
-
"reverse": False,
|
| 248 |
-
}
|
| 249 |
-
config = {**default_config, **(user_config or {})}
|
| 250 |
-
config["enable_long_term"] = not config["disable_long_term"]
|
| 251 |
-
|
| 252 |
-
# 4) 构建数据集(只选本视频 reader)
|
| 253 |
-
meta_dataset = DAVISTestDataset_221128_TransColorization_batch(
|
| 254 |
-
input_video_root, imset=input_ref_root, size=config["size"]
|
| 255 |
-
)
|
| 256 |
-
meta_loader = meta_dataset.get_datasets()
|
| 257 |
-
|
| 258 |
-
# 输出路径规则(与 main.py 一致)
|
| 259 |
-
is_youtube = str(config["dataset"]).startswith("Y")
|
| 260 |
-
is_davis = str(config["dataset"]).startswith("D")
|
| 261 |
-
is_lv = str(config["dataset"]).startswith("LV")
|
| 262 |
-
|
| 263 |
-
app_output_root = output_dir
|
| 264 |
-
if is_youtube or config["save_scores"]:
|
| 265 |
-
out_path = path.join(app_output_root, "Annotations")
|
| 266 |
-
else:
|
| 267 |
-
out_path = app_output_root
|
| 268 |
-
|
| 269 |
-
# 5) 模型(保持 app 的 URL 权重加载方式)
|
| 270 |
-
network = ColorMNet(config, CHECKPOINT_LOCAL).to(DEVICE).eval()
|
| 271 |
-
model_weights = torch.load(CHECKPOINT_LOCAL, map_location="cuda")
|
| 272 |
-
network.load_weights(model_weights, init_as_zero_if_needed=True)
|
| 273 |
-
|
| 274 |
-
total_process_time = 0.0
|
| 275 |
-
total_frames = 0
|
| 276 |
-
|
| 277 |
-
for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
|
| 278 |
-
# 6) 推理(逐帧;内部逻辑与 main.py 对齐;保留调试打印)
|
| 279 |
-
# Gradio/Spaces 环境禁止子进程:num_workers=0(否则会触发 daemonic processes 错误)
|
| 280 |
-
loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
|
| 281 |
-
vid_name = vid_reader.vid_name
|
| 282 |
-
vid_length = len(loader)
|
| 283 |
-
|
| 284 |
-
# 长时记忆触发逻辑:按 main.py 原样(无除零保护)
|
| 285 |
-
config['enable_long_term_count_usage'] = (
|
| 286 |
-
config['enable_long_term'] and
|
| 287 |
-
(vid_length
|
| 288 |
-
/ (config['max_mid_term_frames'] - config['min_mid_term_frames'])
|
| 289 |
-
* config['num_prototypes'])
|
| 290 |
-
>= config['max_long_term_elements']
|
| 291 |
-
)
|
| 292 |
-
|
| 293 |
-
mapper = MaskMapper()
|
| 294 |
-
processor = InferenceCore(network, config=config)
|
| 295 |
-
first_mask_loaded = False
|
| 296 |
-
|
| 297 |
-
for ti, data in enumerate(loader):
|
| 298 |
-
try:
|
| 299 |
-
with torch.cuda.amp.autocast(enabled=not config["benchmark"]):
|
| 300 |
-
rgb = data['rgb'].cuda()[0]
|
| 301 |
-
msk = data.get('mask')
|
| 302 |
-
if not config['FirstFrameIsNotExemplar']:
|
| 303 |
-
msk = msk[:, 1:3, :, :] if msk is not None else None
|
| 304 |
-
|
| 305 |
-
info = data['info']
|
| 306 |
-
frame = info['frame'][0]
|
| 307 |
-
shape = info['shape']
|
| 308 |
-
need_resize = info['need_resize'][0]
|
| 309 |
-
|
| 310 |
-
if debug_shapes:
|
| 311 |
-
print(f"[Loop] frame={ti} rgb={tuple(rgb.shape)} "
|
| 312 |
-
f"msk={None if msk is None else tuple(msk.shape)}", flush=True)
|
| 313 |
-
|
| 314 |
-
# timing 与 main.py 一致
|
| 315 |
-
start = torch.cuda.Event(enable_timing=True)
|
| 316 |
-
end = torch.cuda.Event(enable_timing=True)
|
| 317 |
-
start.record()
|
| 318 |
-
|
| 319 |
-
if not first_mask_loaded:
|
| 320 |
-
if msk is not None:
|
| 321 |
-
first_mask_loaded = True
|
| 322 |
-
else:
|
| 323 |
-
continue
|
| 324 |
-
|
| 325 |
-
if config['flip']:
|
| 326 |
-
rgb = torch.flip(rgb, dims=[-1])
|
| 327 |
-
msk = torch.flip(msk, dims=[-1]) if msk is not None else None
|
| 328 |
-
|
| 329 |
-
if msk is not None:
|
| 330 |
-
msk = torch.Tensor(msk[0]).cuda()
|
| 331 |
-
if need_resize:
|
| 332 |
-
msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
|
| 333 |
-
processor.set_all_labels(list(range(1, 3)))
|
| 334 |
-
labels = range(1, 3)
|
| 335 |
-
else:
|
| 336 |
-
labels = None
|
| 337 |
-
|
| 338 |
-
if config['FirstFrameIsNotExemplar']:
|
| 339 |
-
prob = processor.step_AnyExemplar(
|
| 340 |
-
rgb,
|
| 341 |
-
msk[:1, :, :].repeat(3, 1, 1) if msk is not None else None,
|
| 342 |
-
msk[1:3, :, :] if msk is not None else None,
|
| 343 |
-
labels,
|
| 344 |
-
end=(ti == vid_length - 1)
|
| 345 |
-
)
|
| 346 |
-
else:
|
| 347 |
-
prob = processor.step(rgb, msk, labels, end=(ti == vid_length - 1))
|
| 348 |
-
|
| 349 |
-
if need_resize:
|
| 350 |
-
prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, 0]
|
| 351 |
-
|
| 352 |
-
end.record()
|
| 353 |
-
torch.cuda.synchronize()
|
| 354 |
-
total_process_time += (start.elapsed_time(end) / 1000.0)
|
| 355 |
-
total_frames += 1
|
| 356 |
-
|
| 357 |
-
if config['flip']:
|
| 358 |
-
prob = torch.flip(prob, dims=[-1])
|
| 359 |
-
|
| 360 |
-
if debug_shapes:
|
| 361 |
-
try:
|
| 362 |
-
print(f"[Loop] prob={tuple(prob.shape)}", flush=True)
|
| 363 |
-
except Exception:
|
| 364 |
-
pass
|
| 365 |
-
|
| 366 |
-
if config['save_scores']:
|
| 367 |
-
prob = (prob.detach().cpu().numpy() * 255).astype(np.uint8)
|
| 368 |
-
|
| 369 |
-
if config['save_all'] or info['save'][0]:
|
| 370 |
-
this_out_path = path.join(out_path, vid_name)
|
| 371 |
-
os.makedirs(this_out_path, exist_ok=True)
|
| 372 |
-
|
| 373 |
-
out_mask_final = lab2rgb_transform_PIL(torch.cat([rgb[:1, :, :], prob], dim=0))
|
| 374 |
-
out_mask_final = (out_mask_final * 255).astype(np.uint8)
|
| 375 |
-
Image.fromarray(out_mask_final).save(os.path.join(this_out_path, frame[:-4] + '.png'))
|
| 376 |
-
|
| 377 |
-
except Exception as _e:
|
| 378 |
-
# 保留完整 traceback,方便定位
|
| 379 |
-
raise RuntimeError("FRAME_ERROR:\n" + traceback.format_exc())
|
| 380 |
-
|
| 381 |
-
if total_process_time > 0:
|
| 382 |
-
print(f'Total processing time: {total_process_time}')
|
| 383 |
-
print(f'Total processed frames: {total_frames}')
|
| 384 |
-
print(f'FPS: {total_frames / total_process_time}')
|
| 385 |
-
print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')
|
| 386 |
-
|
| 387 |
-
# 7) 合成 mp4(按 main.py 的 out_path 规则找帧目录)
|
| 388 |
-
frames_dir = path.join(out_path, vid_stem if path.isdir(path.join(out_path, vid_stem)) else vid_name)
|
| 389 |
-
if not path.isdir(frames_dir):
|
| 390 |
-
subs = [d for d in os.listdir(out_path) if path.isdir(path.join(out_path, d))]
|
| 391 |
-
if len(subs) == 1:
|
| 392 |
-
frames_dir = path.join(out_path, subs[0])
|
| 393 |
else:
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
colored_mp4 = path.join(base_run_dir, "colored_output.mp4")
|
| 397 |
-
encode_frames_to_video(frames_dir, colored_mp4, fps=fps)
|
| 398 |
-
|
| 399 |
-
# 8) 输出视频到 CWD(只保留最终文件)
|
| 400 |
-
final_mp4 = path.join(os.getcwd(), "result.mp4")
|
| 401 |
-
shutil.move(colored_mp4, final_mp4)
|
| 402 |
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
return final_mp4
|
| 407 |
-
|
| 408 |
-
# ----------------- GRADIO HANDLERS -----------------
|
| 409 |
-
@spaces.GPU(duration=600)
|
| 410 |
def gradio_infer(
|
| 411 |
-
debug_shapes,
|
| 412 |
bw_video, ref_image,
|
| 413 |
first_not_exemplar, dataset, split, save_all, benchmark,
|
| 414 |
disable_long_term, max_mid, min_mid, max_long,
|
| 415 |
num_proto, top_k, mem_every, deep_update,
|
| 416 |
-
save_scores, flip, size, reverse
|
| 417 |
):
|
| 418 |
-
|
| 419 |
-
return None, "ZeroGPU 未分配到 GPU,请重试(或检查 Space 硬件是否为 ZeroGPU)。"
|
| 420 |
-
|
| 421 |
if bw_video is None:
|
| 422 |
return None, "请上传黑白视频。"
|
| 423 |
if ref_image is None:
|
| 424 |
return None, "请上传参考图像。"
|
| 425 |
-
|
| 426 |
-
# —— 每次运行先重置临时目录 —— #
|
| 427 |
reset_temp_root()
|
| 428 |
|
| 429 |
-
#
|
| 430 |
if isinstance(bw_video, dict) and "name" in bw_video:
|
| 431 |
src_video_path = bw_video["name"]
|
| 432 |
elif isinstance(bw_video, str):
|
|
@@ -434,28 +179,40 @@ def gradio_infer(
|
|
| 434 |
else:
|
| 435 |
return None, "无法读取视频输入。"
|
| 436 |
|
| 437 |
-
|
| 438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
try:
|
| 440 |
-
|
| 441 |
except Exception as e:
|
| 442 |
-
return None, f"
|
| 443 |
|
| 444 |
-
#
|
| 445 |
-
|
| 446 |
if isinstance(ref_image, Image.Image):
|
| 447 |
try:
|
| 448 |
-
ref_image.save(
|
| 449 |
except Exception as e:
|
| 450 |
-
return None, f"保存参考图像
|
| 451 |
elif isinstance(ref_image, str):
|
| 452 |
try:
|
| 453 |
-
shutil.copy2(ref_image,
|
| 454 |
except Exception as e:
|
| 455 |
-
return None, f"复制参考图像
|
| 456 |
else:
|
| 457 |
return None, "无法读取参考图像输入。"
|
| 458 |
|
|
|
|
| 459 |
default_config = {
|
| 460 |
"FirstFrameIsNotExemplar": True,
|
| 461 |
"dataset": "D16_batch",
|
|
@@ -473,8 +230,8 @@ def gradio_infer(
|
|
| 473 |
"save_scores": False,
|
| 474 |
"flip": False,
|
| 475 |
"size": -1,
|
|
|
|
| 476 |
}
|
| 477 |
-
|
| 478 |
user_config = {
|
| 479 |
"FirstFrameIsNotExemplar": bool(first_not_exemplar) if first_not_exemplar is not None else default_config["FirstFrameIsNotExemplar"],
|
| 480 |
"dataset": str(dataset) if dataset else default_config["dataset"],
|
|
@@ -492,71 +249,92 @@ def gradio_infer(
|
|
| 492 |
"save_scores": bool(save_scores) if save_scores is not None else default_config["save_scores"],
|
| 493 |
"flip": bool(flip) if flip is not None else default_config["flip"],
|
| 494 |
"size": int(size) if size is not None else default_config["size"],
|
| 495 |
-
"reverse": bool(reverse) if reverse is not None else
|
| 496 |
}
|
| 497 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
try:
|
| 499 |
-
|
| 500 |
-
tmp_video_path, tmp_ref_path, user_config, debug_shapes=bool(debug_shapes)
|
| 501 |
-
)
|
| 502 |
-
return out_mp4, "完成 ✅"
|
| 503 |
-
except subprocess.CalledProcessError as e:
|
| 504 |
-
# 出错也可以顺手清一下临时目录(可选)
|
| 505 |
-
try: shutil.rmtree(TEMP_ROOT, ignore_errors=True)
|
| 506 |
-
except: pass
|
| 507 |
-
return None, f"运行时错误:\n{e}"
|
| 508 |
except Exception as e:
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
|
| 513 |
# ----------------- UI -----------------
|
| 514 |
with gr.Blocks() as demo:
|
| 515 |
gr.Markdown(f"# {TITLE}")
|
| 516 |
gr.Markdown(DESC)
|
| 517 |
|
| 518 |
-
debug_shapes = gr.Checkbox(label="调试日志(
|
| 519 |
|
| 520 |
with gr.Row():
|
| 521 |
inp_video = gr.Video(label="黑白视频(mp4/webm/avi)", interactive=True)
|
| 522 |
inp_ref = gr.Image(label="参考图像(RGB)", type="pil")
|
| 523 |
-
|
| 524 |
gr.Examples(
|
| 525 |
label="示例输入",
|
| 526 |
-
examples=[
|
| 527 |
-
["./example/4.mp4", "./example/4.png"],
|
| 528 |
-
],
|
| 529 |
inputs=[inp_video, inp_ref],
|
| 530 |
-
# 不缓存,避免把推理结果当静态示例
|
| 531 |
cache_examples=False,
|
| 532 |
)
|
| 533 |
|
| 534 |
-
with gr.Accordion("高级参数设置(
|
| 535 |
with gr.Row():
|
| 536 |
-
first_not_exemplar = gr.Checkbox(label="FirstFrameIsNotExemplar", value=True)
|
| 537 |
-
reverse = gr.Checkbox(label="reverse", value=False)
|
| 538 |
-
dataset = gr.Textbox(label="dataset", value="D16_batch")
|
| 539 |
-
split = gr.Textbox(label="split", value="val")
|
| 540 |
-
save_all = gr.Checkbox(label="save_all", value=True)
|
| 541 |
-
benchmark = gr.Checkbox(label="benchmark", value=False)
|
| 542 |
with gr.Row():
|
| 543 |
-
disable_long_term = gr.Checkbox(label="disable_long_term", value=False)
|
| 544 |
-
max_mid = gr.Number(label="max_mid_term_frames", value=10, precision=0)
|
| 545 |
-
min_mid = gr.Number(label="min_mid_term_frames", value=5, precision=0)
|
| 546 |
-
max_long = gr.Number(label="max_long_term_elements", value=10000, precision=0)
|
| 547 |
-
num_proto = gr.Number(label="num_prototypes", value=128, precision=0)
|
| 548 |
with gr.Row():
|
| 549 |
-
top_k = gr.Number(label="top_k", value=30, precision=0)
|
| 550 |
-
mem_every = gr.Number(label="mem_every", value=5, precision=0)
|
| 551 |
-
deep_update = gr.Number(label="deep_update_every", value=-1, precision=0)
|
| 552 |
-
save_scores = gr.Checkbox(label="save_scores", value=False)
|
| 553 |
-
flip = gr.Checkbox(label="flip", value=False)
|
| 554 |
-
size = gr.Number(label="size", value=-1, precision=0)
|
| 555 |
-
|
| 556 |
-
run_btn = gr.Button("开始着色(
|
| 557 |
with gr.Row():
|
| 558 |
out_video = gr.Video(label="输出视频(着色结果)")
|
| 559 |
-
status = gr.Textbox(label="状态 /
|
| 560 |
|
| 561 |
run_btn.click(
|
| 562 |
fn=gradio_infer,
|
|
@@ -566,7 +344,7 @@ with gr.Blocks() as demo:
|
|
| 566 |
first_not_exemplar, dataset, split, save_all, benchmark,
|
| 567 |
disable_long_term, max_mid, min_mid, max_long,
|
| 568 |
num_proto, top_k, mem_every, deep_update,
|
| 569 |
-
save_scores, flip, size, reverse
|
| 570 |
],
|
| 571 |
outputs=[out_video, status]
|
| 572 |
)
|
|
@@ -576,5 +354,4 @@ if __name__ == "__main__":
|
|
| 576 |
ensure_checkpoint()
|
| 577 |
except Exception as e:
|
| 578 |
print(f"[WARN] 预下载权重失败(首次推理会再试): {e}")
|
| 579 |
-
|
| 580 |
-
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
| 1 |
+
# app.py — Gradio front-end that calls test.py IN-PROCESS (ZeroGPU-safe)
|
| 2 |
+
# Folder layout per run (under TEMP_ROOT):
|
| 3 |
+
# input_video/<video_stem>/00000.png ...
|
| 4 |
+
# ref/<video_stem>/ref.png
|
| 5 |
+
# output/<video_stem>/*.png
|
| 6 |
+
# Final mp4: TEMP_ROOT/<video_stem>.mp4
|
| 7 |
|
| 8 |
import os
|
| 9 |
import sys
|
| 10 |
import shutil
|
|
|
|
|
|
|
| 11 |
import urllib.request
|
|
|
|
| 12 |
from os import path
|
| 13 |
+
import io
|
| 14 |
+
from contextlib import redirect_stdout, redirect_stderr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
import gradio as gr
|
| 17 |
+
import spaces
|
|
|
|
| 18 |
from PIL import Image
|
| 19 |
import cv2
|
|
|
|
| 20 |
|
| 21 |
+
# ----------------- BASIC INFO -----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
CHECKPOINT_URL = "https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth"
|
| 23 |
CHECKPOINT_LOCAL = "DINOv2FeatureV6_LocalAtten_s2_154000.pth"
|
| 24 |
|
| 25 |
TITLE = "ColorMNet — ZeroGPU (CUDA-only) Video Colorization with Reference Image"
|
| 26 |
DESC = """
|
| 27 |
上传**黑白视频**与**参考图像**,点击“开始着色”。
|
| 28 |
+
此版本在 **app.py 中调度 ZeroGPU**,并**在同一进程**调用 `test.py` 的入口函数。
|
| 29 |
+
临时工作目录结构:
|
| 30 |
+
- 抽帧:`_colormnet_tmp/input_video/<视频名>/00000.png ...`
|
| 31 |
+
- 参考:`_colormnet_tmp/ref/<视频名>/ref.png`
|
| 32 |
+
- 输出:`_colormnet_tmp/output/<视频名>/*.png`
|
| 33 |
+
- 合成视频:`_colormnet_tmp/<视频名>.mp4`
|
| 34 |
"""
|
| 35 |
|
| 36 |
# ----------------- TEMP WORKDIR -----------------
|
| 37 |
TEMP_ROOT = path.join(os.getcwd(), "_colormnet_tmp")
|
| 38 |
+
INPUT_DIR = "input_video"
|
| 39 |
+
REF_DIR = "ref"
|
| 40 |
+
OUTPUT_DIR = "output"
|
| 41 |
|
| 42 |
def reset_temp_root():
|
| 43 |
"""每次运行前清空并重建临时工作目录。"""
|
| 44 |
if path.isdir(TEMP_ROOT):
|
| 45 |
shutil.rmtree(TEMP_ROOT, ignore_errors=True)
|
| 46 |
os.makedirs(TEMP_ROOT, exist_ok=True)
|
| 47 |
+
for sub in (INPUT_DIR, REF_DIR, OUTPUT_DIR):
|
| 48 |
+
os.makedirs(path.join(TEMP_ROOT, sub), exist_ok=True)
|
| 49 |
|
| 50 |
+
def ensure_dir(d: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
os.makedirs(d, exist_ok=True)
|
| 52 |
|
| 53 |
+
# ----------------- CHECKPOINT (可选) -----------------
|
| 54 |
def ensure_checkpoint():
|
| 55 |
+
"""若 test.py 会在当前目录加载权重,可提前预下载,避免首次拉取超时。"""
|
| 56 |
+
try:
|
| 57 |
+
if not path.exists(CHECKPOINT_LOCAL):
|
| 58 |
+
print(f"[INFO] Downloading checkpoint from: {CHECKPOINT_URL}")
|
| 59 |
+
urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_LOCAL)
|
| 60 |
+
print("[INFO] Checkpoint downloaded:", CHECKPOINT_LOCAL)
|
| 61 |
+
except Exception as e:
|
| 62 |
+
print(f"[WARN] 预下载权重失败(首次推理会再试): {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
# ----------------- VIDEO UTILS -----------------
|
| 65 |
+
def video_to_frames_dir(video_path: str, frames_dir: str):
|
| 66 |
"""
|
| 67 |
+
抽帧到 frames_dir/00000.png ...
|
| 68 |
+
返回: (w, h, fps, n_frames)
|
| 69 |
"""
|
| 70 |
+
ensure_dir(frames_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
cap = cv2.VideoCapture(video_path)
|
| 72 |
assert cap.isOpened(), f"Cannot open video: {video_path}"
|
| 73 |
+
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
idx = 0
|
| 75 |
w = h = None
|
|
|
|
| 76 |
while True:
|
| 77 |
ret, frame = cap.read()
|
| 78 |
if not ret:
|
| 79 |
break
|
| 80 |
if frame is None:
|
| 81 |
continue
|
|
|
|
| 82 |
h, w = frame.shape[:2]
|
| 83 |
+
out_path = path.join(frames_dir, f"{idx:05d}.png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
ok = cv2.imwrite(out_path, frame)
|
| 85 |
if not ok:
|
| 86 |
raise RuntimeError(f"写入抽帧失败: {out_path}")
|
| 87 |
idx += 1
|
|
|
|
| 88 |
cap.release()
|
| 89 |
if idx == 0:
|
| 90 |
raise RuntimeError("Input video has no readable frames.")
|
| 91 |
+
return w, h, fps, idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
def encode_frames_to_video(frames_dir: str, out_path: str, fps: float):
|
| 94 |
frames = sorted([f for f in os.listdir(frames_dir) if f.lower().endswith(".png")])
|
| 95 |
+
if not frames:
|
| 96 |
+
raise RuntimeError(f"No frames found in {frames_dir}")
|
| 97 |
first = cv2.imread(path.join(frames_dir, frames[0]))
|
| 98 |
+
if first is None:
|
| 99 |
+
raise RuntimeError(f"Failed to read first frame {frames[0]}")
|
| 100 |
h, w = first.shape[:2]
|
|
|
|
| 101 |
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 102 |
vw = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
|
| 103 |
for f in frames:
|
| 104 |
img = cv2.imread(path.join(frames_dir, f))
|
| 105 |
+
if img is None:
|
| 106 |
+
continue
|
| 107 |
vw.write(img)
|
| 108 |
vw.release()
|
| 109 |
|
| 110 |
+
# ----------------- CLI MAPPING -----------------
|
| 111 |
+
CONFIG_TO_CLI = {
|
| 112 |
+
"FirstFrameIsNotExemplar": "--FirstFrameIsNotExemplar", # bool
|
| 113 |
+
"dataset": "--dataset",
|
| 114 |
+
"split": "--split",
|
| 115 |
+
"save_all": "--save_all", # bool
|
| 116 |
+
"benchmark": "--benchmark", # bool
|
| 117 |
+
"disable_long_term": "--disable_long_term", # bool
|
| 118 |
+
"max_mid_term_frames": "--max_mid_term_frames",
|
| 119 |
+
"min_mid_term_frames": "--min_mid_term_frames",
|
| 120 |
+
"max_long_term_elements": "--max_long_term_elements",
|
| 121 |
+
"num_prototypes": "--num_prototypes",
|
| 122 |
+
"top_k": "--top_k",
|
| 123 |
+
"mem_every": "--mem_every",
|
| 124 |
+
"deep_update_every": "--deep_update_every",
|
| 125 |
+
"save_scores": "--save_scores", # bool
|
| 126 |
+
"flip": "--flip", # bool
|
| 127 |
+
"size": "--size",
|
| 128 |
+
"reverse": "--reverse", # bool
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
def build_args_list_for_test(d16_batch_path: str,
|
| 132 |
+
out_path: str,
|
| 133 |
+
ref_root: str,
|
| 134 |
+
cfg: dict):
|
| 135 |
+
"""
|
| 136 |
+
构造传给 test.run_cli(args_list) 的参数列表。
|
| 137 |
+
- 必传:--d16_batch_path <input_video_root>、--ref_path <ref_root>、--output <output_root>
|
| 138 |
+
"""
|
| 139 |
+
args = [
|
| 140 |
+
"--d16_batch_path", d16_batch_path,
|
| 141 |
+
"--ref_path", ref_root,
|
| 142 |
+
"--output", out_path,
|
| 143 |
+
]
|
| 144 |
+
for k, v in cfg.items():
|
| 145 |
+
if k not in CONFIG_TO_CLI:
|
| 146 |
+
continue
|
| 147 |
+
flag = CONFIG_TO_CLI[k]
|
| 148 |
+
if isinstance(v, bool):
|
| 149 |
+
if v:
|
| 150 |
+
args.append(flag) # store_true
|
| 151 |
+
elif v is None:
|
| 152 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
else:
|
| 154 |
+
args.extend([flag, str(v)])
|
| 155 |
+
return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
+
# ----------------- GRADIO HANDLER -----------------
|
| 158 |
+
@spaces.GPU(duration=160) # 确保 CUDA 初始化在此函数体内
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
def gradio_infer(
|
| 160 |
+
debug_shapes,
|
| 161 |
bw_video, ref_image,
|
| 162 |
first_not_exemplar, dataset, split, save_all, benchmark,
|
| 163 |
disable_long_term, max_mid, min_mid, max_long,
|
| 164 |
num_proto, top_k, mem_every, deep_update,
|
| 165 |
+
save_scores, flip, size, reverse
|
| 166 |
):
|
| 167 |
+
# 1) 基本校验与临时目录
|
|
|
|
|
|
|
| 168 |
if bw_video is None:
|
| 169 |
return None, "请上传黑白视频。"
|
| 170 |
if ref_image is None:
|
| 171 |
return None, "请上传参考图像。"
|
|
|
|
|
|
|
| 172 |
reset_temp_root()
|
| 173 |
|
| 174 |
+
# 2) 解析视频源路径 & 目标 <video_stem>
|
| 175 |
if isinstance(bw_video, dict) and "name" in bw_video:
|
| 176 |
src_video_path = bw_video["name"]
|
| 177 |
elif isinstance(bw_video, str):
|
|
|
|
| 179 |
else:
|
| 180 |
return None, "无法读取视频输入。"
|
| 181 |
|
| 182 |
+
video_stem = path.splitext(path.basename(src_video_path))[0]
|
| 183 |
+
|
| 184 |
+
# 3) 生成临时路径
|
| 185 |
+
input_root = path.join(TEMP_ROOT, INPUT_DIR) # _colormnet_tmp/input_video
|
| 186 |
+
ref_root = path.join(TEMP_ROOT, REF_DIR) # _colormnet_tmp/ref
|
| 187 |
+
output_root= path.join(TEMP_ROOT, OUTPUT_DIR) # _colormnet_tmp/output
|
| 188 |
+
input_frames_dir = path.join(input_root, video_stem)
|
| 189 |
+
ref_dir = path.join(ref_root, video_stem)
|
| 190 |
+
out_frames_dir = path.join(output_root, video_stem)
|
| 191 |
+
for d in (input_root, ref_root, output_root, input_frames_dir, ref_dir, out_frames_dir):
|
| 192 |
+
ensure_dir(d)
|
| 193 |
+
|
| 194 |
+
# 4) 抽帧 -> input_video/<stem>/
|
| 195 |
try:
|
| 196 |
+
_w, _h, fps, _n = video_to_frames_dir(src_video_path, input_frames_dir)
|
| 197 |
except Exception as e:
|
| 198 |
+
return None, f"抽帧失败:\n{e}"
|
| 199 |
|
| 200 |
+
# 5) 参考帧 -> ref/<stem>/ref.png
|
| 201 |
+
ref_png_path = path.join(ref_dir, "ref.png")
|
| 202 |
if isinstance(ref_image, Image.Image):
|
| 203 |
try:
|
| 204 |
+
ref_image.save(ref_png_path)
|
| 205 |
except Exception as e:
|
| 206 |
+
return None, f"保存参考图像失败:\n{e}"
|
| 207 |
elif isinstance(ref_image, str):
|
| 208 |
try:
|
| 209 |
+
shutil.copy2(ref_image, ref_png_path)
|
| 210 |
except Exception as e:
|
| 211 |
+
return None, f"复制参考图像失败:\n{e}"
|
| 212 |
else:
|
| 213 |
return None, "无法读取参考图像输入。"
|
| 214 |
|
| 215 |
+
# 6) 收集 UI 配置
|
| 216 |
default_config = {
|
| 217 |
"FirstFrameIsNotExemplar": True,
|
| 218 |
"dataset": "D16_batch",
|
|
|
|
| 230 |
"save_scores": False,
|
| 231 |
"flip": False,
|
| 232 |
"size": -1,
|
| 233 |
+
"reverse": False,
|
| 234 |
}
|
|
|
|
| 235 |
user_config = {
|
| 236 |
"FirstFrameIsNotExemplar": bool(first_not_exemplar) if first_not_exemplar is not None else default_config["FirstFrameIsNotExemplar"],
|
| 237 |
"dataset": str(dataset) if dataset else default_config["dataset"],
|
|
|
|
| 249 |
"save_scores": bool(save_scores) if save_scores is not None else default_config["save_scores"],
|
| 250 |
"flip": bool(flip) if flip is not None else default_config["flip"],
|
| 251 |
"size": int(size) if size is not None else default_config["size"],
|
| 252 |
+
"reverse": bool(reverse) if reverse is not None else default_config["reverse"],
|
| 253 |
}
|
| 254 |
|
| 255 |
+
# 7) 预下载权重(可选)
|
| 256 |
+
ensure_checkpoint()
|
| 257 |
+
|
| 258 |
+
# 8) 同进程调用 test.py
|
| 259 |
try:
|
| 260 |
+
import test # 确保 test.py 同目录且有 run_cli 函数
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
except Exception as e:
|
| 262 |
+
return None, f"导入 test.py 失败:\n{e}"
|
| 263 |
+
|
| 264 |
+
args_list = build_args_list_for_test(
|
| 265 |
+
d16_batch_path=input_root, # 指向 input_video 根
|
| 266 |
+
out_path=output_root, # 指向 output 根(test.py 写 output/<stem>/*.png)
|
| 267 |
+
ref_root=ref_root, # 指向 ref 根(test.py 读 ref/<stem>/ref.png)
|
| 268 |
+
cfg=user_config
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
buf = io.StringIO()
|
| 272 |
+
try:
|
| 273 |
+
with redirect_stdout(buf), redirect_stderr(buf):
|
| 274 |
+
entry = getattr(test, "run_cli", None)
|
| 275 |
+
if entry is None or not callable(entry):
|
| 276 |
+
raise RuntimeError("test.py 未提供可调用的 run_cli(args_list) 接口。")
|
| 277 |
+
entry(args_list)
|
| 278 |
+
log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}"
|
| 279 |
+
except Exception as e:
|
| 280 |
+
log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}\n\nERROR: {e}"
|
| 281 |
+
return None, log
|
| 282 |
+
|
| 283 |
+
# 9) 合成 mp4:从 output/<stem>/ 帧合成 -> TEMP_ROOT/<stem>.mp4
|
| 284 |
+
out_frames = path.join(output_root, video_stem)
|
| 285 |
+
if not path.isdir(out_frames):
|
| 286 |
+
return None, f"未找到输出帧目录:{out_frames}\n\n{log}"
|
| 287 |
+
final_mp4 = path.join(TEMP_ROOT, f"{video_stem}.mp4")
|
| 288 |
+
try:
|
| 289 |
+
encode_frames_to_video(out_frames, final_mp4, fps=fps)
|
| 290 |
+
except Exception as e:
|
| 291 |
+
return None, f"合成视频失败:\n{e}\n\n{log}"
|
| 292 |
+
|
| 293 |
+
return final_mp4, f"完成 ✅\n\n{log}"
|
| 294 |
|
| 295 |
# ----------------- UI -----------------
|
| 296 |
with gr.Blocks() as demo:
|
| 297 |
gr.Markdown(f"# {TITLE}")
|
| 298 |
gr.Markdown(DESC)
|
| 299 |
|
| 300 |
+
debug_shapes = gr.Checkbox(label="调试日志(仅用于显示更完整日志)", value=False)
|
| 301 |
|
| 302 |
with gr.Row():
|
| 303 |
inp_video = gr.Video(label="黑白视频(mp4/webm/avi)", interactive=True)
|
| 304 |
inp_ref = gr.Image(label="参考图像(RGB)", type="pil")
|
|
|
|
| 305 |
gr.Examples(
|
| 306 |
label="示例输入",
|
| 307 |
+
examples=[["./example/4.mp4", "./example/4.png"]],
|
|
|
|
|
|
|
| 308 |
inputs=[inp_video, inp_ref],
|
|
|
|
| 309 |
cache_examples=False,
|
| 310 |
)
|
| 311 |
|
| 312 |
+
with gr.Accordion("高级参数设置(传给 test.py)", open=False):
|
| 313 |
with gr.Row():
|
| 314 |
+
first_not_exemplar = gr.Checkbox(label="FirstFrameIsNotExemplar (--FirstFrameIsNotExemplar)", value=True)
|
| 315 |
+
reverse = gr.Checkbox(label="reverse (--reverse)", value=False)
|
| 316 |
+
dataset = gr.Textbox(label="dataset (--dataset)", value="D16_batch")
|
| 317 |
+
split = gr.Textbox(label="split (--split)", value="val")
|
| 318 |
+
save_all = gr.Checkbox(label="save_all (--save_all)", value=True)
|
| 319 |
+
benchmark = gr.Checkbox(label="benchmark (--benchmark)", value=False)
|
| 320 |
with gr.Row():
|
| 321 |
+
disable_long_term = gr.Checkbox(label="disable_long_term (--disable_long_term)", value=False)
|
| 322 |
+
max_mid = gr.Number(label="max_mid_term_frames (--max_mid_term_frames)", value=10, precision=0)
|
| 323 |
+
min_mid = gr.Number(label="min_mid_term_frames (--min_mid_term_frames)", value=5, precision=0)
|
| 324 |
+
max_long = gr.Number(label="max_long_term_elements (--max_long_term_elements)", value=10000, precision=0)
|
| 325 |
+
num_proto = gr.Number(label="num_prototypes (--num_prototypes)", value=128, precision=0)
|
| 326 |
with gr.Row():
|
| 327 |
+
top_k = gr.Number(label="top_k (--top_k)", value=30, precision=0)
|
| 328 |
+
mem_every = gr.Number(label="mem_every (--mem_every)", value=5, precision=0)
|
| 329 |
+
deep_update = gr.Number(label="deep_update_every (--deep_update_every)", value=-1, precision=0)
|
| 330 |
+
save_scores = gr.Checkbox(label="save_scores (--save_scores)", value=False)
|
| 331 |
+
flip = gr.Checkbox(label="flip (--flip)", value=False)
|
| 332 |
+
size = gr.Number(label="size (--size)", value=-1, precision=0)
|
| 333 |
+
|
| 334 |
+
run_btn = gr.Button("开始着色(同进程调用 test.py)")
|
| 335 |
with gr.Row():
|
| 336 |
out_video = gr.Video(label="输出视频(着色结果)")
|
| 337 |
+
status = gr.Textbox(label="状态 / 日志输出(test.py stdout/stderr)", interactive=False, lines=16)
|
| 338 |
|
| 339 |
run_btn.click(
|
| 340 |
fn=gradio_infer,
|
|
|
|
| 344 |
first_not_exemplar, dataset, split, save_all, benchmark,
|
| 345 |
disable_long_term, max_mid, min_mid, max_long,
|
| 346 |
num_proto, top_k, mem_every, deep_update,
|
| 347 |
+
save_scores, flip, size, reverse
|
| 348 |
],
|
| 349 |
outputs=[out_video, status]
|
| 350 |
)
|
|
|
|
| 354 |
ensure_checkpoint()
|
| 355 |
except Exception as e:
|
| 356 |
print(f"[WARN] 预下载权重失败(首次推理会再试): {e}")
|
| 357 |
+
demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860)
|
|
|
inference/data/test_datasets.py
CHANGED
|
@@ -1,36 +1,116 @@
|
|
| 1 |
import os
|
| 2 |
from os import path
|
| 3 |
-
import json
|
| 4 |
|
| 5 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
self.size = size
|
| 13 |
|
| 14 |
-
self.vid_list = [clip_name for clip_name in sorted(os.listdir(data_root)) if clip_name != '.DS_Store' and not clip_name.startswith('.')]
|
| 15 |
-
self.ref_img_list = [clip_name for clip_name in sorted(os.listdir(imset)) if clip_name != '.DS_Store' and not clip_name.startswith('.')]
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
|
| 21 |
-
def
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
path.join(self.image_dir, video),
|
| 29 |
-
path.join(self.mask_dir, video),
|
| 30 |
-
size=self.size,
|
| 31 |
-
size_dir=path.join(self.size_dir, video),
|
| 32 |
-
args=self.args
|
| 33 |
-
)
|
| 34 |
|
| 35 |
def __len__(self):
|
| 36 |
-
return len(self.
|
|
|
|
| 1 |
import os
|
| 2 |
from os import path
|
|
|
|
| 3 |
|
| 4 |
+
from torch.utils.data.dataset import Dataset
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
from torchvision.transforms import InterpolationMode
|
| 7 |
+
import torch.nn.functional as Ff
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
|
| 11 |
+
from dataset.range_transform import im_normalization, im_rgb2lab_normalization, ToTensor, RGB2Lab
|
| 12 |
+
|
| 13 |
+
class VideoReader_221128_TransColorization(Dataset):
    """Read one video folder, one frame at a time.

    Each item is a dict with:
      * ``'rgb'``  - the first channel of the transformed frame repeated
        three times,
      * ``'mask'`` - the transformed reference image (only present for the
        frames whose mask is loaded, see ``__getitem__``),
      * ``'info'`` - frame name, video name, original (height, width),
        whether the frame should be saved and whether a resize is needed.
    """

    def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None, args=None):
        """
        image_dir - points to a directory of jpg/png frames
        mask_dir - points to a directory of reference images ("masks")
        size - resize min. side to size. Does nothing if <0.
        to_save - optionally contains a list of file names without extensions
            where the output is required
        use_all_mask - when true, read all available masks in mask_dir.
            Default false.
        size_dir - directory whose frames define the reported original size;
            defaults to image_dir.
        args - accepted for interface parity; not read by this class.

        Raises IndexError when mask_dir contains no non-hidden file.
        """
        self.vid_name = vid_name
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.to_save = to_save
        self.use_all_mask = use_all_mask
        self.size_dir = self.image_dir if size_dir is None else size_dir

        # NOTE(review): frame reversal is deliberately disabled here; the
        # `reverse` option carried by `args` is not consulted - confirm
        # whether it is handled elsewhere before enabling it.
        flag_reverse = False
        self.frames = [
            img for img in sorted(os.listdir(self.image_dir), reverse=flag_reverse)
            if img.endswith(('.jpg', '.png')) and not img.startswith('.')
        ]

        # List the reference directory once, skipping hidden files, and use
        # this same list everywhere so that per-frame lookups in __getitem__
        # agree with first_gt_path.
        self.mask_files = sorted(msk for msk in os.listdir(mask_dir) if not msk.startswith('.'))
        self.first_gt_path = path.join(self.mask_dir, self.mask_files[0])
        with Image.open(self.first_gt_path) as first_mask:
            self.palette = first_mask.getpalette()
        self.suffix = self.first_gt_path.split('.')[-1]

        if size < 0:
            self.im_transform = transforms.Compose([
                RGB2Lab(),
                ToTensor(),
                im_rgb2lab_normalization,
            ])
        else:
            self.im_transform = transforms.Compose([
                transforms.ToTensor(),
                im_normalization,
                transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
            ])
        self.size = size

    def __getitem__(self, idx):
        """Return the data dict for frame ``idx`` (see the class docstring)."""
        frame = self.frames[idx]
        info = {}
        data = {}
        info['frame'] = frame
        info['vid_name'] = self.vid_name
        info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)

        # convert() returns an independent, fully loaded image, so the file
        # handle can be closed right away.
        with Image.open(path.join(self.image_dir, frame)) as raw_frame:
            img = raw_frame.convert('RGB')

        # Original (height, width), read from the image header instead of
        # decoding the whole frame into a numpy array.
        if self.image_dir == self.size_dir:
            shape = (img.height, img.width)
        else:
            with Image.open(path.join(self.size_dir, frame)) as size_im:
                shape = (size_im.height, size_im.width)

        # idx-th reference file, if there is one for this frame.
        if idx < len(self.mask_files):
            gt_path = path.join(self.mask_dir, self.mask_files[idx])
        else:
            gt_path = None

        img = self.im_transform(img)
        img_l = img[:1, :, :]
        img_lll = img_l.repeat(3, 1, 1)

        # Without use_all_mask only the first reference file is loaded.
        load_mask = gt_path is not None and (self.use_all_mask or gt_path == self.first_gt_path)
        if load_mask and path.exists(gt_path):
            with Image.open(gt_path) as raw_mask:
                mask = raw_mask.convert('RGB')

            # Resize with PIL first so the reference matches the frame size.
            mask = mask.resize((img.shape[2], img.shape[1]), Image.BILINEAR)

            # All channels of the reference are kept (not only the last two)
            # in case the first frame is not the exemplar.
            data['mask'] = self.im_transform(mask)

        info['shape'] = shape
        info['need_resize'] = not (self.size < 0)
        data['rgb'] = img_lll
        data['info'] = info

        return data

    def resize_mask(self, mask):
        """Nearest-neighbour resize of ``mask`` so its shorter side equals ``self.size``."""
        # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
        h, w = mask.shape[-2:]
        min_hw = min(h, w)
        return Ff.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
                              mode='nearest')

    def get_palette(self):
        """Return the palette read from the first reference image (may be None)."""
        return self.palette

    def __len__(self):
        """Number of frames found in image_dir."""
        return len(self.frames)
|
test.py
CHANGED
|
@@ -1,8 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
from os import path
|
| 3 |
from argparse import ArgumentParser
|
| 4 |
import shutil
|
| 5 |
|
|
|
|
|
|
|
|
|
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from torch.utils.data import DataLoader
|
|
@@ -27,47 +34,7 @@ except ImportError:
|
|
| 27 |
print('Failed to import hickle. Fine if not using multi-scale testing.')
|
| 28 |
|
| 29 |
|
| 30 |
-
|
| 31 |
-
Arguments loading
|
| 32 |
-
"""
|
| 33 |
-
parser = ArgumentParser()
|
| 34 |
-
parser.add_argument('--model', default='saves/DINOv2FeatureV6_LocalAtten_s2_154000.pth')
|
| 35 |
-
|
| 36 |
-
# dataset setting
|
| 37 |
-
parser.add_argument('--d16_batch_path', default='input')
|
| 38 |
-
parser.add_argument('--deoldify_path', default='ref')
|
| 39 |
-
parser.add_argument('--output', default='result')
|
| 40 |
-
|
| 41 |
-
# For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
|
| 42 |
-
parser.add_argument('--generic_path')
|
| 43 |
-
parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D16_batch')
|
| 44 |
-
parser.add_argument('--split', help='val/test', default='val')
|
| 45 |
-
parser.add_argument('--save_all', action='store_true',
|
| 46 |
-
help='Save all frames. Useful only in YouTubeVOS/long-time video', )
|
| 47 |
-
parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking')
|
| 48 |
-
|
| 49 |
-
# Long-term memory options
|
| 50 |
-
parser.add_argument('--disable_long_term', action='store_true')
|
| 51 |
-
parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
|
| 52 |
-
parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
|
| 53 |
-
parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
|
| 54 |
-
type=int, default=10000)
|
| 55 |
-
parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
|
| 56 |
-
|
| 57 |
-
parser.add_argument('--top_k', type=int, default=30)
|
| 58 |
-
parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5)
|
| 59 |
-
parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
|
| 60 |
-
|
| 61 |
-
# Multi-scale options
|
| 62 |
-
parser.add_argument('--save_scores', action='store_true')
|
| 63 |
-
parser.add_argument('--flip', action='store_true')
|
| 64 |
-
parser.add_argument('--size', default=-1, type=int,
|
| 65 |
-
help='Resize the shorter side to this size. -1 to use original resolution. ')
|
| 66 |
-
|
| 67 |
-
args = parser.parse_args()
|
| 68 |
-
config = vars(args)
|
| 69 |
-
config['enable_long_term'] = not config['disable_long_term']
|
| 70 |
-
|
| 71 |
def detach_to_cpu(x):
    """Return ``x.detach().cpu()``.

    Assumes ``x`` is a ``torch.Tensor`` (or exposes the same
    ``detach``/``cpu`` methods) - TODO confirm against callers.
    """
    return x.detach().cpu()
|
| 73 |
|
|
@@ -89,142 +56,236 @@ def lab2rgb_transform_PIL(mask):
|
|
| 89 |
|
| 90 |
return im.clip(0, 1)
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
)
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
if msk is not None:
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
else:
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
-
|
| 180 |
-
if msk is not None:
|
| 181 |
-
msk = torch.Tensor(msk[0]).cuda()
|
| 182 |
if need_resize:
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# test.py — In-process callable version (for ZeroGPU stateless)
|
| 2 |
+
# Keep original logic; add build_parser(), run_cli(args_list), and run_inference(args)
|
| 3 |
+
# Do NOT initialize CUDA at import-time.
|
| 4 |
+
|
| 5 |
import os
|
| 6 |
from os import path
|
| 7 |
from argparse import ArgumentParser
|
| 8 |
import shutil
|
| 9 |
|
| 10 |
+
# 不在这里做 @spaces.GPU 装饰,避免与 app.py 的 @spaces.GPU 双重调度
|
| 11 |
+
# import spaces
|
| 12 |
+
|
| 13 |
import torch
|
| 14 |
import torch.nn.functional as F
|
| 15 |
from torch.utils.data import DataLoader
|
|
|
|
| 34 |
print('Failed to import hickle. Fine if not using multi-scale testing.')
|
| 35 |
|
| 36 |
|
| 37 |
+
# ----------------- small utils -----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
def detach_to_cpu(x):
    """Return ``x.detach().cpu()``.

    Assumes ``x`` is a ``torch.Tensor`` (or exposes the same
    ``detach``/``cpu`` methods) - TODO confirm against callers.
    """
    return x.detach().cpu()
|
| 40 |
|
|
|
|
| 56 |
|
| 57 |
return im.clip(0, 1)
|
| 58 |
|
| 59 |
+
|
| 60 |
+
# ----------------- argparse -----------------
|
| 61 |
+
def build_parser() -> ArgumentParser:
    """Build the argument parser used by ``run_cli`` / ``main``.

    The parser is created on every call and nothing is parsed here, so
    importing this module has no command-line side effects.

    ``--deoldify_path`` (the option name used before it was renamed) is
    accepted as an alias of ``--ref_path`` so that existing command lines
    keep working; both store into ``args.ref_path``.

    Returns:
        A configured ``ArgumentParser``.
    """
    parser = ArgumentParser()
    parser.add_argument('--model', default='saves/DINOv2FeatureV6_LocalAtten_s2_154000.pth')
    parser.add_argument('--FirstFrameIsNotExemplar', help='Whether the provided reference frame is exactly the first input frame', action='store_true')

    # dataset setting
    parser.add_argument('--d16_batch_path', default='input', help='Point to folder A/ which contains <video_name>/00000.png etc.')
    parser.add_argument('--ref_path', '--deoldify_path', dest='ref_path', default='ref',
                        help='Kept for parity; dataset will also read ref.png under each video folder when args provided')
    parser.add_argument('--output', default='result', help='Directory to save results')

    # store_true already defaults to False.
    parser.add_argument('--reverse', action='store_true', help='whether to reverse the frame order')
    parser.add_argument('--allow_resume', action='store_true',
                        help='skip existing videos that have been colorized')

    # For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
    parser.add_argument('--generic_path')
    parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D16_batch')
    parser.add_argument('--split', help='val/test', default='val')
    parser.add_argument('--save_all', action='store_true',
                        help='Save all frames. Useful only in YouTubeVOS/long-time video')
    parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking')

    # Long-term memory options
    parser.add_argument('--disable_long_term', action='store_true')
    parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
    parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
    parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
                        type=int, default=10000)
    parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)

    parser.add_argument('--top_k', type=int, default=30)
    parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5)
    parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)

    # Multi-scale options
    parser.add_argument('--save_scores', action='store_true')
    parser.add_argument('--flip', action='store_true')
    parser.add_argument('--size', default=-1, type=int,
                        help='Resize the shorter side to this size. -1 to use original resolution. ')
    return parser
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ----------------- core inference -----------------
|
| 104 |
+
def run_inference(args):
    """Colorize every video found under ``args.d16_batch_path``.

    For each video the frames are pushed through the network one by one and
    the colorized result is written as ``<output>/<video_name>/<frame>.png``.
    Timing statistics are printed at the end.

    This is the real inference routine. On Hugging Face ZeroGPU it must be
    called from inside the scheduling context provided by the caller (the
    ``@spaces.GPU`` wrapper in app.py); nothing here runs at import time.

    Args:
        args: namespace produced by ``build_parser().parse_args(...)``.
            It is also used as the mutable ``config`` dict (``vars(args)``),
            which gains the keys ``enable_long_term`` and
            ``enable_long_term_count_usage``.

    Raises:
        NotImplementedError: if ``args.split`` is not ``'val'``.
    """
    # Local import: wall-clock fallback for timing when CUDA is unavailable.
    import time

    config = vars(args)
    config['enable_long_term'] = not config['disable_long_term']

    if args.output is None:
        args.output = f'.output/{args.dataset}_{args.split}'
        print(f'Output path not provided. Defaulting to {args.output}')

    # ----- Data preparation -----
    is_youtube = args.dataset.startswith('Y')
    is_davis = args.dataset.startswith('D')

    if is_youtube or args.save_scores:
        out_path = path.join(args.output, 'Annotations')
    else:
        out_path = args.output

    if args.split != 'val':
        raise NotImplementedError('Only split=val is supported in this script.')

    # Dataset layout: <d16_batch_path>/<video>/00000.png ...; the reference
    # image root is passed as `imset`.
    meta_dataset = DAVISTestDataset_221128_TransColorization_batch(
        args.d16_batch_path, imset=args.ref_path, size=args.size, args=args
    )

    # NOTE: this disables autograd globally for the process, as before.
    torch.autograd.set_grad_enabled(False)

    # Set up loader list (video readers)
    meta_loader = meta_dataset.get_datasets()

    # Load checkpoint/model. Everything CUDA-specific below is guarded by
    # `use_cuda` so the CPU fallback chosen here actually works.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if use_cuda else torch.device('cpu')
    network = ColorMNet(config, args.model).to(device).eval()
    if args.model is not None:
        # Weights are mapped directly onto the selected device.
        model_weights = torch.load(args.model, map_location=device)
        network.load_weights(model_weights, init_as_zero_if_needed=True)
    else:
        print('No model loaded.')

    total_process_time = 0.0
    total_frames = 0

    # ----- Start eval over videos -----
    for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
        vid_name = vid_reader.vid_name
        this_out_path = path.join(out_path, vid_name)

        # Skip existing videos before doing any per-video setup.
        if args.allow_resume and path.exists(this_out_path):
            print(f'Skipping {this_out_path} because output already exists.')
            continue

        # ZeroGPU/Spaces does not allow worker subprocesses: keep num_workers=0.
        loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)
        vid_length = len(loader)

        # LT usage check per original logic
        config['enable_long_term_count_usage'] = (
            config['enable_long_term'] and
            (vid_length
             / (config['max_mid_term_frames'] - config['min_mid_term_frames'])
             * config['num_prototypes'])
            >= config['max_long_term_elements']
        )

        processor = InferenceCore(network, config=config)
        first_mask_loaded = False

        for ti, data in enumerate(loader):
            with torch.cuda.amp.autocast(enabled=use_cuda and not args.benchmark):
                rgb = data['rgb'].to(device)[0]

                msk = data.get('mask')
                if not config['FirstFrameIsNotExemplar']:
                    # Reference is the first frame itself: keep channels 1..2 only.
                    msk = msk[:, 1:3, :, :] if msk is not None else None

                info = data['info']
                frame = info['frame'][0]
                shape = info['shape']
                need_resize = info['need_resize'][0]

                if use_cuda:
                    start = torch.cuda.Event(enable_timing=True)
                    end = torch.cuda.Event(enable_timing=True)
                    start.record()
                else:
                    start_time = time.perf_counter()

                # Nothing can be done until the first reference has been seen.
                if not first_mask_loaded:
                    if msk is None:
                        continue
                    first_mask_loaded = True

                if args.flip:
                    rgb = torch.flip(rgb, dims=[-1])
                    msk = torch.flip(msk, dims=[-1]) if msk is not None else None

                # Map possibly non-continuous labels to continuous ones
                if msk is not None:
                    msk = torch.Tensor(msk[0]).to(device)
                    if need_resize:
                        msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
                    processor.set_all_labels(list(range(1, 3)))
                    labels = range(1, 3)
                else:
                    labels = None

                # Run the model on this frame
                if config['FirstFrameIsNotExemplar']:
                    prob = processor.step_AnyExemplar(
                        rgb,
                        msk[:1, :, :].repeat(3, 1, 1) if msk is not None else None,
                        msk[1:3, :, :] if msk is not None else None,
                        labels,
                        end=(ti == vid_length - 1)
                    )
                else:
                    prob = processor.step(rgb, msk, labels, end=(ti == vid_length - 1))

                # Upsample to original size if needed
                if need_resize:
                    prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, 0]

                if use_cuda:
                    end.record()
                    torch.cuda.synchronize()
                    total_process_time += (start.elapsed_time(end) / 1000)
                else:
                    total_process_time += time.perf_counter() - start_time
                total_frames += 1

                if args.flip:
                    # Undo the flip on both tensors that are concatenated
                    # below, otherwise they would be mirrored against each other.
                    prob = torch.flip(prob, dims=[-1])
                    rgb = torch.flip(rgb, dims=[-1])

                # `prob` is intentionally kept as a tensor even with
                # --save_scores: it is passed to torch.cat below, and the
                # former uint8 numpy conversion was never written anywhere.

                # Save the colorized frame
                if args.save_all or info['save'][0]:
                    os.makedirs(this_out_path, exist_ok=True)

                    out_mask_final = lab2rgb_transform_PIL(torch.cat([rgb[:1, :, :], prob], dim=0))
                    out_mask_final = (out_mask_final * 255).astype(np.uint8)

                    out_img = Image.fromarray(out_mask_final)
                    out_img.save(os.path.join(this_out_path, path.splitext(frame)[0] + '.png'))

    print(f'Total processing time: {total_process_time}')
    print(f'Total processed frames: {total_frames}')
    if total_process_time > 0:
        print(f'FPS: {total_frames / total_process_time}')
    else:
        # No frame was processed (empty input, or everything was skipped).
        print('FPS: n/a (no frames processed)')
    if use_cuda:
        print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')

    # Same as the original script: archives are only made when
    # save_scores is off and only for specific datasets/splits.
    if not args.save_scores:
        if is_youtube:
            print('Making zip for YouTubeVOS...')
            shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
        elif is_davis and args.split == 'test':
            print('Making zip for DAVIS test-dev...')
            shutil.make_archive(args.output, 'zip', args.output)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# ----------------- public entrypoints -----------------
|
| 272 |
+
def run_cli(args_list=None):
    """Parse ``args_list`` and run inference in the current process.

    Intended for in-process use from app.py: ``test.run_cli(args_list)``.
    When ``args_list`` is ``None`` argparse falls back to ``sys.argv``.

    Returns:
        Whatever ``run_inference`` returns.
    """
    namespace = build_parser().parse_args(args_list)
    return run_inference(namespace)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def main():
    """Command-line entry point.

    Keeps the script runnable as
    ``python test.py --d16_batch_path A --output result ...``.

    Note: when run directly in a stateless Hugging Face Spaces/ZeroGPU
    environment, the scheduling context has to be provided by the caller
    (e.g. the ``@spaces.GPU`` wrapper in app.py).
    """
    # args_list=None makes argparse read the process command line.
    run_cli(args_list=None)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# Run the CLI only when executed as a script, never on import.
if __name__ == '__main__':
    main()
|