Spaces:
Runtime error
Runtime error
File size: 6,159 Bytes
470ac18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import os
import pickle
from os.path import join as pjoin
import collections
import json
import torch
import numpy as np
import re
import cv2
from torch.utils import data
def get_data_path(name, config_path='../xgw/segmentation/config.json'):
    """Return the dataset root directory for ``name`` from a JSON config file.

    Args:
        name (str): Key of the dataset entry in the config file.
        config_path (str): Path to the JSON config file. Defaults to
            ``'../xgw/segmentation/config.json'``, which is resolved relative
            to the current working directory.

    Returns:
        str: ``config[name]['data_path']`` with a leading ``~`` expanded.

    Raises:
        KeyError: If ``name`` or its ``'data_path'`` entry is missing.
    """
    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)
    return os.path.expanduser(config[name]['data_path'])
def getDatasets(dir):
    """Return the names of all entries (files and directories) directly in ``dir``.

    The order is whatever ``os.listdir`` reports, which is arbitrary;
    callers that need a stable order must sort the result themselves.
    """
    entry_names = os.listdir(dir)
    return entry_names
def resize_image(origin_img, long_edge=1024, short_edge=960):
    """Resize an image onto a fixed-size, zero-padded canvas.

    The output canvas has ``long_edge`` rows and ``short_edge`` columns
    (1024x960 by default).

    * If the image is taller than wide, its height is scaled to ``long_edge``.
    * Otherwise its width is scaled to ``short_edge``.

    The other side is scaled by the same factor and rounded up to a multiple
    of 32 (minimum 32). It is then clamped to the canvas, which can slightly
    distort the aspect ratio for near-square images. The resized image is
    centred along that side and the remaining border is filled with zeros.

    Args:
        origin_img (np.ndarray): Input image, either HxWx3 or a 2-D HxW
            grayscale array (replicated to 3 channels).
        long_edge (int): Number of rows of the output canvas.
        short_edge (int): Number of columns of the output canvas.

    Returns:
        np.ndarray: A ``(long_edge, short_edge, 3)`` uint8 array.
    """
    def _round_up_32(x):
        # Round up to the next multiple of 32 (at least 32); an already
        # aligned value is kept as-is instead of being bumped by 32.
        return max(32, -(-x // 32) * 32)

    if origin_img.ndim == 2:
        # The canvas is 3-channel; a 2-D image would not broadcast into it.
        origin_img = np.repeat(origin_img[:, :, np.newaxis], 3, axis=2)

    height, width = origin_img.shape[:2]
    new_img = np.zeros([long_edge, short_edge, 3], dtype=np.uint8)

    if height > width:
        new_h = long_edge
        new_w = min(_round_up_32(int(width / height * long_edge)), short_edge)
        resized = cv2.resize(origin_img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        # offset:offset+new_w always spans exactly new_w columns, even when
        # (short_edge - new_w) is odd.
        left = (short_edge - new_w) // 2
        new_img[:, left:left + new_w] = resized
    else:
        new_w = short_edge
        new_h = min(_round_up_32(int(height / width * short_edge)), long_edge)
        resized = cv2.resize(origin_img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        top = (long_edge - new_h) // 2
        new_img[top:top + new_h, :] = resized
    return new_img
class PerturbedDatastsForFiducialPoints_pickle_color_v2_v2(data.Dataset):
    """Dataset of images and, for training splits, fiducial-point labels.

    Directory layouts, as read by this class:

    * ``'test'`` / ``'eval'``: every entry directly under ``root`` is an image
      file read with ``cv2.imread``.
    * ``'train'`` / ``'validate'``: every entry under ``root/color`` is a
      pickled dict with keys ``'image'``, ``'fiducial_points'`` and
      ``'segment'``. File names must match
      ``<prefix>_<int>_<int>_<suffix>`` and are sorted by that tuple, with
      the two integer fields compared numerically.
    """

    # File-name pattern for train/validate pickles; groups feed the sort key.
    _NAME_PATTERN = re.compile(r'(\w+\d*)_(\d+)_(\d+)_(\w+)', re.IGNORECASE)
    # Grid subsampling strides, indexed by row_gap / col_gap.
    # POINTS NUM: 61, 31, 21, 16, 13, 11, 7, 6, 5, 4, 3, 2
    _FIDUCIAL_POINT_GAPS = [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60]

    def __init__(self, root, split='1-1', img_shrink=None, is_return_img_name=False, preproccess=False):
        """Index the files of one split.

        Args:
            root (str): Dataset root directory (``~`` is expanded).
            split (str): One of ``'test'``, ``'eval'``, ``'train'``, ``'validate'``.
            img_shrink: Stored on the instance; not read by this class.
            is_return_img_name (bool): If True, ``__getitem__`` also returns
                the file name.
            preproccess (bool): Stored on the instance; not read by this class.

        Raises:
            ValueError: If ``split`` is not a supported split name, or a
                train/validate file name does not match ``_NAME_PATTERN``.
        """
        self.root = os.path.expanduser(root)
        self.split = split
        self.img_shrink = img_shrink
        self.is_return_img_name = is_return_img_name
        self.preproccess = preproccess
        self.images = collections.defaultdict(list)
        self.labels = collections.defaultdict(list)
        self.row_gap = 1  # index into _FIDUCIAL_POINT_GAPS; value:0, 1, 2; POINTS NUM: 61, 31, 21
        self.col_gap = 1
        if split in ('test', 'eval'):
            self.images[split] = getDatasets(self.root)
        elif split in ('validate', 'train'):
            names = [name.rstrip() for name in getDatasets(os.path.join(self.root, 'color'))]
            self.images[split] = sorted(names, key=self._sort_key)
        else:
            raise ValueError('load data error')

    @classmethod
    def _sort_key(cls, name):
        """Return ``(prefix, int, int, suffix)`` parsed from ``name`` for sorting."""
        match = cls._NAME_PATTERN.match(name)
        if match is None:
            raise ValueError(
                f'unexpected file name (expected <prefix>_<int>_<int>_<suffix>): {name!r}')
        return match.group(1), int(match.group(2)), int(match.group(3)), match.group(4)

    def checkImg(self):
        """Print the name of every 'validate' pickle that cannot be used.

        A file is reported when it cannot be opened or unpickled, or when the
        result has no ``'image'`` entry with a ``shape``. Splits other than
        ``'validate'`` are a no-op.
        """
        if self.split != 'validate':
            return
        for im_name in self.images[self.split]:
            # Same location that __init__ lists and __getitem__ reads.
            im_path = pjoin(self.root, 'color', im_name)
            try:
                with open(im_path, 'rb') as f:
                    perturbed_data = pickle.load(f)
                image = perturbed_data.get('image')
            except Exception:  # best-effort scan: any load failure marks the file bad
                image = None
            if getattr(image, 'shape', None) is None:
                print(im_name)

    def __len__(self):
        """Number of files indexed for the current split."""
        return len(self.images[self.split])

    def _read_image(self, im_name):
        """Read ``root/<im_name>`` with ``cv2.IMREAD_COLOR``; raise OSError if unreadable."""
        im_path = pjoin(self.root, im_name)
        img = cv2.imread(im_path, flags=cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f'could not read image: {im_path}')
        return img

    def __getitem__(self, item):
        """Load sample ``item`` of the current split.

        Returns:
            * ``'test'``: ``im`` (float tensor, channel-first, 992x992),
              plus ``im_name`` if ``is_return_img_name``.
            * ``'eval'``: ``(im, im_name)`` if ``is_return_img_name``, else
              ``(im, img)`` where ``img`` is the unresized array from
              ``cv2.imread``.
            * ``'train'`` / ``'validate'``: ``(im, lbl, segment)``, plus
              ``im_name`` if ``is_return_img_name``. ``im`` keeps the pickled
              image's dtype, ``lbl`` and ``segment`` are float tensors.
        """
        im_name = self.images[self.split][item]
        if self.split == 'test':
            im = self.transform_im(self.resize_im(self._read_image(im_name)))
            if self.is_return_img_name:
                return im, im_name
            return im
        if self.split == 'eval':
            img = self._read_image(im_name)
            im = self.transform_im(self.resize_im(img))
            if self.is_return_img_name:
                return im, im_name
            return im, img

        # train / validate: pickled dict under root/color.
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        im_path = pjoin(self.root, 'color', im_name)
        with open(im_path, 'rb') as f:
            perturbed_data = pickle.load(f)
        im = perturbed_data.get('image')
        lbl = perturbed_data.get('fiducial_points')
        segment = perturbed_data.get('segment')

        im = self.resize_im(im).transpose(2, 0, 1)  # HWC -> CHW
        lbl = self.resize_lbl(lbl)
        lbl, segment = self.fiducal_points_lbl(lbl, segment)
        lbl = lbl.transpose(2, 0, 1)  # move the last (coordinate) axis first

        im = torch.from_numpy(im)
        lbl = torch.from_numpy(lbl).float()
        segment = torch.from_numpy(segment).float()
        if self.is_return_img_name:
            return im, lbl, segment, im_name
        return im, lbl, segment

    def transform_im(self, im):
        """Convert an HxWxC array to a channel-first float32 tensor."""
        im = im.transpose(2, 0, 1)
        return torch.from_numpy(im).float()

    def resize_im(self, im):
        """Resize an image to 992x992 with bilinear interpolation."""
        return cv2.resize(im, (992, 992), interpolation=cv2.INTER_LINEAR)

    def resize_lbl(self, lbl):
        """Rescale point coordinates: divide the last axis by [960, 1024], multiply by 992.

        NOTE(review): presumably ``lbl[..., :]`` are (x, y) pairs in the
        960-wide x 1024-tall frame that resize_image produces — TODO confirm.
        """
        return lbl / [960, 1024] * [992, 992]

    def fiducal_points_lbl(self, fiducial_points, segment):
        """Subsample the fiducial-point grid and scale ``segment`` to match.

        Rows are taken with stride ``_FIDUCIAL_POINT_GAPS[row_gap]`` and
        columns with stride ``_FIDUCIAL_POINT_GAPS[col_gap]``. ``segment`` is
        multiplied elementwise by ``[col_stride, row_stride]``.
        """
        row_stride = self._FIDUCIAL_POINT_GAPS[self.row_gap]
        col_stride = self._FIDUCIAL_POINT_GAPS[self.col_gap]
        fiducial_points = fiducial_points[::row_stride, ::col_stride, :]
        segment = segment * [col_stride, row_stride]
        return fiducial_points, segment
|