Spaces:
Configuration error
Configuration error
- app.py +1 -1
- inference/data/test_datasets.py +24 -104
app.py
CHANGED
|
@@ -155,7 +155,7 @@ def build_args_list_for_test(d16_batch_path: str,
|
|
| 155 |
return args
|
| 156 |
|
| 157 |
# ----------------- GRADIO HANDLER -----------------
|
| 158 |
-
@spaces.GPU(duration=
|
| 159 |
def gradio_infer(
|
| 160 |
debug_shapes,
|
| 161 |
bw_video, ref_image,
|
|
|
|
| 155 |
return args
|
| 156 |
|
| 157 |
# ----------------- GRADIO HANDLER -----------------
|
| 158 |
+
@spaces.GPU(duration=100) # 确保 CUDA 初始化在此函数体内
|
| 159 |
def gradio_infer(
|
| 160 |
debug_shapes,
|
| 161 |
bw_video, ref_image,
|
inference/data/test_datasets.py
CHANGED
|
@@ -1,116 +1,36 @@
|
|
| 1 |
import os
|
| 2 |
from os import path
|
|
|
|
| 3 |
|
| 4 |
-
from
|
| 5 |
-
from torchvision import transforms
|
| 6 |
-
from torchvision.transforms import InterpolationMode
|
| 7 |
-
import torch.nn.functional as Ff
|
| 8 |
-
from PIL import Image
|
| 9 |
-
import numpy as np
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
"""
|
| 17 |
-
def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None, args=None):
|
| 18 |
-
"""
|
| 19 |
-
image_dir - points to a directory of jpg images
|
| 20 |
-
mask_dir - points to a directory of png masks
|
| 21 |
-
size - resize min. side to size. Does nothing if <0.
|
| 22 |
-
to_save - optionally contains a list of file names without extensions
|
| 23 |
-
where the segmentation mask is required
|
| 24 |
-
use_all_mask - when true, read all available mask in mask_dir.
|
| 25 |
-
Default false. Set to true for YouTubeVOS validation.
|
| 26 |
-
"""
|
| 27 |
-
self.vid_name = vid_name
|
| 28 |
-
self.image_dir = image_dir
|
| 29 |
-
self.mask_dir = mask_dir
|
| 30 |
-
self.to_save = to_save
|
| 31 |
-
self.use_all_mask = use_all_mask
|
| 32 |
-
# print('use_all_mask', use_all_mask);assert 1==0
|
| 33 |
-
if size_dir is None:
|
| 34 |
-
self.size_dir = self.image_dir
|
| 35 |
-
else:
|
| 36 |
-
self.size_dir = size_dir
|
| 37 |
-
|
| 38 |
-
# flag_reverse = args.getattr('reverse', False) if args is not None else False
|
| 39 |
-
flag_reverse = False
|
| 40 |
-
self.frames = [img for img in sorted(os.listdir(self.image_dir), reverse=flag_reverse) if (img.endswith('.jpg') or img.endswith('.png')) and not img.startswith('.')]
|
| 41 |
-
self.palette = Image.open(path.join(mask_dir, sorted([msk for msk in os.listdir(mask_dir) if not msk.startswith('.')])[0])).getpalette()
|
| 42 |
-
self.first_gt_path = path.join(self.mask_dir, sorted([msk for msk in os.listdir(self.mask_dir) if not msk.startswith('.')])[0])
|
| 43 |
-
self.suffix = self.first_gt_path.split('.')[-1]
|
| 44 |
-
|
| 45 |
-
if size < 0:
|
| 46 |
-
self.im_transform = transforms.Compose([
|
| 47 |
-
RGB2Lab(),
|
| 48 |
-
ToTensor(),
|
| 49 |
-
im_rgb2lab_normalization,
|
| 50 |
-
])
|
| 51 |
-
else:
|
| 52 |
-
self.im_transform = transforms.Compose([
|
| 53 |
-
transforms.ToTensor(),
|
| 54 |
-
im_normalization,
|
| 55 |
-
transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
|
| 56 |
-
])
|
| 57 |
self.size = size
|
| 58 |
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
-
frame = self.frames[idx]
|
| 62 |
-
info = {}
|
| 63 |
-
data = {}
|
| 64 |
-
info['frame'] = frame
|
| 65 |
-
info['vid_name'] = self.vid_name
|
| 66 |
-
info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)
|
| 67 |
-
|
| 68 |
-
im_path = path.join(self.image_dir, frame)
|
| 69 |
-
img = Image.open(im_path).convert('RGB')
|
| 70 |
-
|
| 71 |
-
if self.image_dir == self.size_dir:
|
| 72 |
-
shape = np.array(img).shape[:2]
|
| 73 |
-
else:
|
| 74 |
-
size_path = path.join(self.size_dir, frame)
|
| 75 |
-
size_im = Image.open(size_path).convert('RGB')
|
| 76 |
-
shape = np.array(size_im).shape[:2]
|
| 77 |
-
|
| 78 |
-
gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[idx]) if idx < len(os.listdir(self.mask_dir)) else None
|
| 79 |
-
|
| 80 |
-
img = self.im_transform(img)
|
| 81 |
-
img_l = img[:1,:,:]
|
| 82 |
-
img_lll = img_l.repeat(3,1,1)
|
| 83 |
-
|
| 84 |
-
load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
|
| 85 |
-
if load_mask and path.exists(gt_path):
|
| 86 |
-
mask = Image.open(gt_path).convert('RGB')
|
| 87 |
-
|
| 88 |
-
# 用 PIL 先 resize 成和 img 尺寸一致
|
| 89 |
-
mask = mask.resize((img.shape[2], img.shape[1]), Image.BILINEAR)
|
| 90 |
-
|
| 91 |
-
mask = self.im_transform(mask)
|
| 92 |
-
|
| 93 |
-
# keep L channel of reference image in case First frame is not exemplar
|
| 94 |
-
# mask_ab = mask[1:3,:,:]
|
| 95 |
-
# data['mask'] = mask_ab
|
| 96 |
-
data['mask'] = mask
|
| 97 |
-
|
| 98 |
-
info['shape'] = shape
|
| 99 |
-
info['need_resize'] = not (self.size < 0)
|
| 100 |
-
data['rgb'] = img_lll
|
| 101 |
-
data['info'] = info
|
| 102 |
|
| 103 |
-
|
| 104 |
|
| 105 |
-
def
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
return Ff.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
|
| 110 |
-
mode='nearest')
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
def __len__(self):
|
| 116 |
-
return len(self.
|
|
|
|
| 1 |
import os
|
| 2 |
from os import path
|
| 3 |
+
import json
|
| 4 |
|
| 5 |
+
from inference.data.video_reader import VideoReader_221128_TransColorization
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
class DAVISTestDataset_221128_TransColorization_batch:
|
| 8 |
+
def __init__(self, data_root, imset='2017/val.txt', size=-1, args=None):
|
| 9 |
+
self.image_dir = data_root
|
| 10 |
+
self.mask_dir = imset
|
| 11 |
+
self.size_dir = data_root
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
self.size = size
|
| 13 |
|
| 14 |
+
self.vid_list = [clip_name for clip_name in sorted(os.listdir(data_root)) if clip_name != '.DS_Store' and not clip_name.startswith('.')]
|
| 15 |
+
self.ref_img_list = [clip_name for clip_name in sorted(os.listdir(imset)) if clip_name != '.DS_Store' and not clip_name.startswith('.')]
|
| 16 |
|
| 17 |
+
self.args = args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
# print(lst, len(lst), self.vid_list, self.vid_list_DAVIS2016, path.join(data_root, 'ImageSets', imset));assert 1==0
|
| 20 |
|
| 21 |
+
def get_datasets(self):
|
| 22 |
+
for video in self.vid_list:
|
| 23 |
+
if video not in self.ref_img_list:
|
| 24 |
+
continue
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
# print(self.image_dir, video, path.join(self.image_dir, video));assert 1==0
|
| 27 |
+
yield VideoReader_221128_TransColorization(video,
|
| 28 |
+
path.join(self.image_dir, video),
|
| 29 |
+
path.join(self.mask_dir, video),
|
| 30 |
+
size=self.size,
|
| 31 |
+
size_dir=path.join(self.size_dir, video),
|
| 32 |
+
args=self.args
|
| 33 |
+
)
|
| 34 |
|
| 35 |
def __len__(self):
|
| 36 |
+
return len(self.vid_list)
|