yyang181 committed on
Commit
9e837cb
·
1 Parent(s): 95257c4
Files changed (2) hide show
  1. app.py +1 -1
  2. inference/data/test_datasets.py +24 -104
app.py CHANGED
@@ -155,7 +155,7 @@ def build_args_list_for_test(d16_batch_path: str,
155
  return args
156
 
157
  # ----------------- GRADIO HANDLER -----------------
158
- @spaces.GPU(duration=160) # 确保 CUDA 初始化在此函数体内
159
  def gradio_infer(
160
  debug_shapes,
161
  bw_video, ref_image,
 
155
  return args
156
 
157
  # ----------------- GRADIO HANDLER -----------------
158
+ @spaces.GPU(duration=100) # 确保 CUDA 初始化在此函数体内
159
  def gradio_infer(
160
  debug_shapes,
161
  bw_video, ref_image,
inference/data/test_datasets.py CHANGED
@@ -1,116 +1,36 @@
1
  import os
2
  from os import path
 
3
 
4
- from torch.utils.data.dataset import Dataset
5
- from torchvision import transforms
6
- from torchvision.transforms import InterpolationMode
7
- import torch.nn.functional as Ff
8
- from PIL import Image
9
- import numpy as np
10
 
11
- from dataset.range_transform import im_normalization, im_rgb2lab_normalization, ToTensor, RGB2Lab
12
-
13
- class VideoReader_221128_TransColorization(Dataset):
14
- """
15
- This class is used to read a video, one frame at a time
16
- """
17
- def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None, args=None):
18
- """
19
- image_dir - points to a directory of jpg images
20
- mask_dir - points to a directory of png masks
21
- size - resize min. side to size. Does nothing if <0.
22
- to_save - optionally contains a list of file names without extensions
23
- where the segmentation mask is required
24
- use_all_mask - when true, read all available mask in mask_dir.
25
- Default false. Set to true for YouTubeVOS validation.
26
- """
27
- self.vid_name = vid_name
28
- self.image_dir = image_dir
29
- self.mask_dir = mask_dir
30
- self.to_save = to_save
31
- self.use_all_mask = use_all_mask
32
- # print('use_all_mask', use_all_mask);assert 1==0
33
- if size_dir is None:
34
- self.size_dir = self.image_dir
35
- else:
36
- self.size_dir = size_dir
37
-
38
- # flag_reverse = args.getattr('reverse', False) if args is not None else False
39
- flag_reverse = False
40
- self.frames = [img for img in sorted(os.listdir(self.image_dir), reverse=flag_reverse) if (img.endswith('.jpg') or img.endswith('.png')) and not img.startswith('.')]
41
- self.palette = Image.open(path.join(mask_dir, sorted([msk for msk in os.listdir(mask_dir) if not msk.startswith('.')])[0])).getpalette()
42
- self.first_gt_path = path.join(self.mask_dir, sorted([msk for msk in os.listdir(self.mask_dir) if not msk.startswith('.')])[0])
43
- self.suffix = self.first_gt_path.split('.')[-1]
44
-
45
- if size < 0:
46
- self.im_transform = transforms.Compose([
47
- RGB2Lab(),
48
- ToTensor(),
49
- im_rgb2lab_normalization,
50
- ])
51
- else:
52
- self.im_transform = transforms.Compose([
53
- transforms.ToTensor(),
54
- im_normalization,
55
- transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
56
- ])
57
  self.size = size
58
 
 
 
59
 
60
- def __getitem__(self, idx):
61
- frame = self.frames[idx]
62
- info = {}
63
- data = {}
64
- info['frame'] = frame
65
- info['vid_name'] = self.vid_name
66
- info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)
67
-
68
- im_path = path.join(self.image_dir, frame)
69
- img = Image.open(im_path).convert('RGB')
70
-
71
- if self.image_dir == self.size_dir:
72
- shape = np.array(img).shape[:2]
73
- else:
74
- size_path = path.join(self.size_dir, frame)
75
- size_im = Image.open(size_path).convert('RGB')
76
- shape = np.array(size_im).shape[:2]
77
-
78
- gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[idx]) if idx < len(os.listdir(self.mask_dir)) else None
79
-
80
- img = self.im_transform(img)
81
- img_l = img[:1,:,:]
82
- img_lll = img_l.repeat(3,1,1)
83
-
84
- load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
85
- if load_mask and path.exists(gt_path):
86
- mask = Image.open(gt_path).convert('RGB')
87
-
88
- # 用 PIL 先 resize 成和 img 尺寸一致
89
- mask = mask.resize((img.shape[2], img.shape[1]), Image.BILINEAR)
90
-
91
- mask = self.im_transform(mask)
92
-
93
- # keep L channel of reference image in case First frame is not exemplar
94
- # mask_ab = mask[1:3,:,:]
95
- # data['mask'] = mask_ab
96
- data['mask'] = mask
97
-
98
- info['shape'] = shape
99
- info['need_resize'] = not (self.size < 0)
100
- data['rgb'] = img_lll
101
- data['info'] = info
102
 
103
- return data
104
 
105
- def resize_mask(self, mask):
106
- # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
107
- h, w = mask.shape[-2:]
108
- min_hw = min(h, w)
109
- return Ff.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
110
- mode='nearest')
111
 
112
- def get_palette(self):
113
- return self.palette
 
 
 
 
 
 
114
 
115
  def __len__(self):
116
- return len(self.frames)
 
1
  import os
2
  from os import path
3
+ import json
4
 
5
+ from inference.data.video_reader import VideoReader_221128_TransColorization
 
 
 
 
 
6
 
7
+ class DAVISTestDataset_221128_TransColorization_batch:
8
+ def __init__(self, data_root, imset='2017/val.txt', size=-1, args=None):
9
+ self.image_dir = data_root
10
+ self.mask_dir = imset
11
+ self.size_dir = data_root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  self.size = size
13
 
14
+ self.vid_list = [clip_name for clip_name in sorted(os.listdir(data_root)) if clip_name != '.DS_Store' and not clip_name.startswith('.')]
15
+ self.ref_img_list = [clip_name for clip_name in sorted(os.listdir(imset)) if clip_name != '.DS_Store' and not clip_name.startswith('.')]
16
 
17
+ self.args = args
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # print(lst, len(lst), self.vid_list, self.vid_list_DAVIS2016, path.join(data_root, 'ImageSets', imset));assert 1==0
20
 
21
+ def get_datasets(self):
22
+ for video in self.vid_list:
23
+ if video not in self.ref_img_list:
24
+ continue
 
 
25
 
26
+ # print(self.image_dir, video, path.join(self.image_dir, video));assert 1==0
27
+ yield VideoReader_221128_TransColorization(video,
28
+ path.join(self.image_dir, video),
29
+ path.join(self.mask_dir, video),
30
+ size=self.size,
31
+ size_dir=path.join(self.size_dir, video),
32
+ args=self.args
33
+ )
34
 
35
  def __len__(self):
36
+ return len(self.vid_list)