File size: 6,159 Bytes
470ac18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import os
import pickle
from os.path import join as pjoin
import collections
import json
import torch
import numpy as np

import re
import cv2

from torch.utils import data


def get_data_path(name):
	"""Extract path to data from config file.

	Args:
		name (str): The name of the dataset.

	Returns:
		(str): The path to the root directory containing the dataset,
		with a leading ``~`` expanded to the user's home directory.

	Raises:
		KeyError: if *name* (or its 'data_path' entry) is missing from the config.
	"""
	# Parse the JSON stream directly instead of read() followed by loads().
	# NOTE(review): the config path is relative to the working directory —
	# confirm callers always run from the expected location.
	with open('../xgw/segmentation/config.json') as f:
		data = json.load(f)
	return os.path.expanduser(data[name]['data_path'])

def getDatasets(dir):
	"""Return the names of all entries found directly under *dir*.

	Thin wrapper over ``os.listdir``; order is arbitrary (callers sort).
	"""
	entries = os.listdir(dir)
	return entries
'''
Resize the input image onto a 1024x960 canvas: scale along the longest side
while keeping the aspect ratio, then zero-pad the remaining border.
'''
def resize_image(origin_img, long_edge=1024, short_edge=960):
	"""Fit *origin_img* onto a zero-filled (long_edge, short_edge, 3) canvas.

	The image's longest side is scaled to the matching canvas edge while
	keeping the aspect ratio; the other side is rounded up to the next
	multiple of 32 (capped at the canvas size) and the result is centred
	on the canvas along that axis.

	Args:
		origin_img: HxWx3 uint8 image (BGR, as read by cv2).
		long_edge (int): canvas height.
		short_edge (int): canvas width.

	Returns:
		(long_edge, short_edge, 3) uint8 array with the resized image
		centred and zero padding elsewhere.
	"""
	src_h = origin_img.shape[0]
	src_w = origin_img.shape[1]
	canvas = np.zeros([long_edge, short_edge, 3], dtype=np.uint8)
	canvas_h, canvas_w = canvas.shape[:2]
	if src_h > src_w:
		# Height is the longest side: pin it to long_edge, derive the width.
		new_h = long_edge
		new_w = int(src_w / src_h * long_edge)
		new_w += 32 - new_w % 32  # round up to a multiple of 32 (always adds >= 1)
		new_w = min(new_w, short_edge)
		origin_img = cv2.resize(origin_img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
		pad = (canvas_w - new_w) // 2  # centre horizontally
		canvas[:, pad:canvas_w - pad] = origin_img
	else:
		# Width is the longest side: pin it to short_edge, derive the height.
		new_w = short_edge
		new_h = int(src_h / src_w * short_edge)
		new_h += 32 - new_h % 32  # round up to a multiple of 32 (always adds >= 1)
		new_h = min(new_h, long_edge)
		origin_img = cv2.resize(origin_img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
		pad = (canvas_h - new_h) // 2  # centre vertically
		canvas[pad:canvas_h - pad, :] = origin_img
	return canvas

class PerturbedDatastsForFiducialPoints_pickle_color_v2_v2(data.Dataset):
	"""Dataset of perturbed (warped) document images with fiducial-point labels.

	Two on-disk layouts are supported, selected by *split*:

	* ``'test'`` / ``'eval'``: *root* is a flat directory of image files that
	  are read with OpenCV and only resized/converted (no labels).
	* ``'train'`` / ``'validate'`` / ``'test'`` (pickle layout): *root*/color
	  holds pickled dicts with keys ``'image'``, ``'fiducial_points'`` and
	  ``'segment'``.

	Args:
		root (str): dataset root directory (``~`` is expanded).
		split (str): which subset/layout to load (see above).
		img_shrink: stored but unused here; kept for interface compatibility.
		is_return_img_name (bool): if True, ``__getitem__`` also returns the
			file name.
		preproccess (bool): stored but unused here; spelling preserved for
			backward compatibility with existing callers.
	"""

	# Pickle-layout file names look like '<name>_<i>_<j>_<tag>...'; compiled
	# once so the sort key does not re-parse the pattern per comparison.
	_NAME_RE = re.compile(r'(\w+\d*)_(\d+)_(\d+)_(\w+)', re.IGNORECASE)

	def __init__(self, root, split='1-1', img_shrink=None, is_return_img_name=False, preproccess=False):
		self.root = os.path.expanduser(root)
		self.split = split
		self.img_shrink = img_shrink
		self.is_return_img_name = is_return_img_name
		self.preproccess = preproccess
		self.images = collections.defaultdict(list)
		self.labels = collections.defaultdict(list)
		# Subsampling of the fiducial grid; value 0/1/2 keeps 61/31/21 points
		# per axis (indexes into the gap table in fiducal_points_lbl).
		self.row_gap = 1
		self.col_gap = 1
		datasets = ['validate', 'test', 'train']

		if self.split == 'test' or self.split == 'eval':
			# Flat image directory: every entry is a sample, order as listed.
			self.images[self.split] = getDatasets(os.path.join(self.root))
		elif self.split in datasets:
			img_file_list = [id_.rstrip() for id_ in getDatasets(os.path.join(self.root, 'color'))]
			self.images[self.split] = sorted(img_file_list, key=self._sort_key)
		else:
			raise Exception('load data error')

	@classmethod
	def _sort_key(cls, name):
		"""Return the sort key (name, i, j, tag) parsed from a file name.

		Same ordering as the original inline lambda (string, int, int,
		string), but the name is parsed once instead of four times per key.
		"""
		m = cls._NAME_RE.match(name)
		return (m.group(1), int(m.group(2)), int(m.group(3)), m.group(4))

	def checkImg(self):
		"""Sanity-check that every 'validate' pickle can be loaded.

		Prints the names of unreadable files; never raises. The original used
		a bare ``except:``; narrowed to ``Exception`` so KeyboardInterrupt and
		SystemExit still propagate.
		"""
		if self.split == 'validate':
			for im_name in self.images[self.split]:
				im_path = pjoin(self.root, self.split, 'color', im_name)
				try:
					with open(im_path, 'rb') as f:
						perturbed_data = pickle.load(f)

					im_shape = perturbed_data.shape
				except Exception:
					print(im_name)

	def __len__(self):
		"""Number of samples in the active split."""
		return len(self.images[self.split])

	def __getitem__(self, item):
		"""Return one sample; the exact tuple depends on the split.

		* 'test': float CHW tensor ``im`` (plus ``im_name`` if requested).
		* 'eval': ``(im, im_name)`` if requested, else ``(im, original_bgr)``.
		* otherwise (pickle layout): ``(im, lbl, segment[, im_name])``.
		"""
		if self.split == 'test':
			im_name = self.images[self.split][item]
			im_path = pjoin(self.root, im_name)

			im = cv2.imread(im_path, flags=cv2.IMREAD_COLOR)

			im = self.resize_im(im)
			im = self.transform_im(im)

			if self.is_return_img_name:
				return im, im_name
			return im
		elif self.split == 'eval':
			im_name = self.images[self.split][item]
			im_path = pjoin(self.root, im_name)

			img = cv2.imread(im_path, flags=cv2.IMREAD_COLOR)

			im = self.resize_im(img)
			im = self.transform_im(im)

			if self.is_return_img_name:
				return im, im_name
			# Unlike 'test', also hand back the unresized BGR image.
			return im, img

		else:
			im_name = self.images[self.split][item]

			im_path = pjoin(self.root, 'color', im_name)

			# Each sample is a pickled dict produced by the data generator.
			# NOTE(review): pickle.load must only be fed trusted local data.
			with open(im_path, 'rb') as f:
				perturbed_data = pickle.load(f)

			im = perturbed_data.get('image')
			lbl = perturbed_data.get('fiducial_points')
			segment = perturbed_data.get('segment')

			im = self.resize_im(im)
			im = im.transpose(2, 0, 1)  # HWC -> CHW

			lbl = self.resize_lbl(lbl)
			lbl, segment = self.fiducal_points_lbl(lbl, segment)
			lbl = lbl.transpose(2, 0, 1)  # (rows, cols, 2) -> (2, rows, cols)

			im = torch.from_numpy(im)
			lbl = torch.from_numpy(lbl).float()
			segment = torch.from_numpy(segment).float()

			if self.is_return_img_name:
				return im, lbl, segment, im_name

			return im, lbl, segment

	def transform_im(self, im):
		"""Convert an HWC numpy image to a float CHW torch tensor."""
		im = im.transpose(2, 0, 1)
		im = torch.from_numpy(im).float()

		return im

	def resize_im(self, im):
		"""Resize an image to the fixed 992x992 network input size."""
		im = cv2.resize(im, (992, 992), interpolation=cv2.INTER_LINEAR)
		return im

	def resize_lbl(self, lbl):
		"""Rescale fiducial-point coordinates from the 960x1024 canvas to 992x992.

		*lbl* is a numpy array of (x, y) points; broadcasting divides by the
		source canvas size and multiplies by the target size.
		"""
		lbl = lbl / [960, 1024] * [992, 992]
		return lbl

	def fiducal_points_lbl(self, fiducial_points, segment):
		"""Subsample the fiducial grid by the configured row/col gaps.

		``row_gap``/``col_gap`` index into the gap table below; the segment
		spacing is scaled up by the same factors so it stays consistent with
		the thinned grid.
		"""
		fiducial_point_gaps = [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60]  # POINTS NUM: 61, 31, 21, 16, 13, 11, 7, 6, 5, 4, 3, 2
		fiducial_points = fiducial_points[::fiducial_point_gaps[self.row_gap], ::fiducial_point_gaps[self.col_gap], :]
		segment = segment * [fiducial_point_gaps[self.col_gap], fiducial_point_gaps[self.row_gap]]
		return fiducial_points, segment