# dataloader.py
# Uploaded via huggingface_hub by Richard1231 (revision 470ac18).
import os
import pickle
from os.path import join as pjoin
import collections
import json
import torch
import numpy as np
import re
import cv2
from torch.utils import data
def get_data_path(name, config_path='../xgw/segmentation/config.json'):
    """Extract the path to a dataset's root directory from a JSON config file.

    Args:
        name (str): The name of the dataset (a top-level key in the config).
        config_path (str): Path to the JSON config file. Defaults to the
            project's historical hard-coded location, so existing callers
            are unaffected.

    Returns:
        (str): The path to the root directory containing the dataset,
        with a leading '~' expanded to the user's home directory.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        KeyError: if ``name`` or its 'data_path' entry is missing.
    """
    with open(config_path) as f:
        # Parse straight from the stream instead of read() + json.loads().
        config = json.load(f)
    return os.path.expanduser(config[name]['data_path'])
def getDatasets(dir):
    """Return the names of the entries (files and sub-directories) in *dir*."""
    entries = os.listdir(dir)
    return entries
'''
Resize the input image to 1024x960: scale along the longest side while keeping
the aspect ratio, then zero-pad the remainder of the canvas.
'''
def resize_image(origin_img, long_edge=1024, short_edge=960):
    """Fit *origin_img* into a zero-filled (long_edge, short_edge, 3) canvas.

    The image is scaled so that its longest side fills the matching canvas
    edge (aspect ratio preserved); the scaled minor side is bumped up to the
    next multiple of 32 and clamped to the canvas, then the result is
    centered on the canvas.

    Args:
        origin_img: H x W x 3 uint8 image (BGR, as read by OpenCV).
        long_edge (int): canvas height.
        short_edge (int): canvas width.

    Returns:
        (long_edge, short_edge, 3) uint8 array with the resized image
        centered and zero padding elsewhere.
    """
    src_h, src_w = origin_img.shape[0], origin_img.shape[1]
    canvas = np.zeros([long_edge, short_edge, 3], dtype=np.uint8)
    canvas_h, canvas_w = canvas.shape[:2]
    if src_h > src_w:
        # Portrait: the vertical side fills long_edge; pad left/right.
        new_w = int(src_w / src_h * long_edge)
        # NOTE: adds a full 32 even when new_w is already a multiple of 32
        # (kept for compatibility with existing models/labels).
        new_w += 32 - new_w % 32
        new_w = min(new_w, short_edge)
        scaled = cv2.resize(origin_img, (new_w, long_edge), interpolation=cv2.INTER_CUBIC)
        left = (canvas_w - new_w) // 2
        canvas[:, left:canvas_w - left] = scaled
    else:
        # Landscape (or square): the horizontal side fills short_edge; pad top/bottom.
        new_h = int(src_h / src_w * short_edge)
        new_h += 32 - new_h % 32
        new_h = min(new_h, long_edge)
        scaled = cv2.resize(origin_img, (short_edge, new_h), interpolation=cv2.INTER_CUBIC)
        top = (canvas_h - new_h) // 2
        canvas[top:canvas_h - top, :] = scaled
    return canvas
class PerturbedDatastsForFiducialPoints_pickle_color_v2_v2(data.Dataset):
    """Dataset of perturbed (warped) document images with fiducial-point labels.

    Behavior depends on ``split``:
      * 'test' / 'eval': ``root`` is a flat directory of image files read
        with OpenCV; ``__getitem__`` returns resized tensors (plus the raw
        image for 'eval').
      * 'train' / 'validate' (and a pickled 'test' set): ``root``/'color'
        holds pickled dicts with keys 'image', 'fiducial_points', 'segment'.
    """

    # Training file names look like '<prefix>_<num>_<num>_<suffix>'.
    # Compiled once so the sort key below matches each name a single time
    # instead of re-matching the pattern four times per comparison key.
    _NAME_RE = re.compile(r'(\w+\d*)_(\d+)_(\d+)_(\w+)', re.IGNORECASE)

    def __init__(self, root, split='1-1', img_shrink=None, is_return_img_name=False, preproccess=False):
        """
        Args:
            root (str): dataset root directory ('~' is expanded).
            split (str): 'test', 'eval', 'train' or 'validate'.
            img_shrink: stored but not used by any method visible here.
            is_return_img_name (bool): also return the file name from
                ``__getitem__``.
            preproccess: stored but not used by any method visible here
                (name kept as-is — sic — for caller compatibility).

        Raises:
            Exception: if ``split`` is not one of the recognized values.
        """
        self.root = os.path.expanduser(root)
        self.split = split
        self.img_shrink = img_shrink
        self.is_return_img_name = is_return_img_name
        self.preproccess = preproccess
        self.images = collections.defaultdict(list)
        self.labels = collections.defaultdict(list)
        # Sub-sampling gap indices into the fiducial grid;
        # value 0, 1, 2 -> 61, 31, 21 points per row/column.
        self.row_gap = 1
        self.col_gap = 1
        datasets = ['validate', 'test', 'train']
        if self.split in ('test', 'eval'):
            # Flat directory of plain image files.
            self.images[self.split] = os.listdir(os.path.join(self.root))
        elif self.split in datasets:
            # Pickled samples live under root/color; sort by the parsed
            # (prefix, int, int, suffix) key so numeric parts order naturally.
            names = [n.rstrip() for n in os.listdir(os.path.join(self.root, 'color'))]
            self.images[self.split] = sorted(names, key=self._sort_key)
        else:
            raise Exception('load data error')

    @classmethod
    def _sort_key(cls, name):
        """Parse '<prefix>_<num>_<num>_<suffix>' into a sortable tuple."""
        m = cls._NAME_RE.match(name)
        return (m.group(1), int(m.group(2)), int(m.group(3)), m.group(4))

    def checkImg(self):
        """Best-effort scan of the 'validate' pickles; print unreadable names."""
        if self.split != 'validate':
            return
        for im_name in self.images[self.split]:
            im_path = pjoin(self.root, self.split, 'color', im_name)
            try:
                with open(im_path, 'rb') as f:
                    # NOTE(review): pickle.load is unsafe on untrusted input;
                    # these files are assumed locally generated and trusted.
                    perturbed_data = pickle.load(f)
                im_shape = perturbed_data.shape
            except Exception:  # narrowed from a bare except; still best-effort
                print(im_name)

    def __len__(self):
        """Number of samples in the active split."""
        return len(self.images[self.split])

    def __getitem__(self, item):
        """Load one sample.

        Returns (depending on split / is_return_img_name):
            'test': im  [, im_name]
            'eval': (im, im_name) if is_return_img_name else (im, raw_img)
            otherwise: (im, lbl, segment) [, im_name]
        """
        if self.split == 'test':
            im_name = self.images[self.split][item]
            im_path = pjoin(self.root, im_name)
            im = cv2.imread(im_path, flags=cv2.IMREAD_COLOR)
            im = self.resize_im(im)
            im = self.transform_im(im)
            if self.is_return_img_name:
                return im, im_name
            return im
        elif self.split == 'eval':
            im_name = self.images[self.split][item]
            im_path = pjoin(self.root, im_name)
            img = cv2.imread(im_path, flags=cv2.IMREAD_COLOR)
            im = self.resize_im(img)
            im = self.transform_im(im)
            if self.is_return_img_name:
                return im, im_name
            return im, img
        else:
            im_name = self.images[self.split][item]
            im_path = pjoin(self.root, 'color', im_name)
            with open(im_path, 'rb') as f:
                # NOTE(review): pickle on trusted, locally generated data only.
                perturbed_data = pickle.load(f)
            im = perturbed_data.get('image')
            lbl = perturbed_data.get('fiducial_points')
            segment = perturbed_data.get('segment')
            im = self.resize_im(im)
            im = im.transpose(2, 0, 1)  # HWC -> CHW
            lbl = self.resize_lbl(lbl)
            lbl, segment = self.fiducal_points_lbl(lbl, segment)
            lbl = lbl.transpose(2, 0, 1)
            im = torch.from_numpy(im)
            lbl = torch.from_numpy(lbl).float()
            segment = torch.from_numpy(segment).float()
            if self.is_return_img_name:
                return im, lbl, segment, im_name
            return im, lbl, segment

    def transform_im(self, im):
        """Convert an HWC uint8 image to a CHW float tensor."""
        im = im.transpose(2, 0, 1)
        im = torch.from_numpy(im).float()
        return im

    def resize_im(self, im):
        """Resize an image to the fixed 992x992 network input size."""
        im = cv2.resize(im, (992, 992), interpolation=cv2.INTER_LINEAR)
        return im

    def resize_lbl(self, lbl):
        """Rescale (x, y) fiducial coordinates from the 960x1024 canvas to 992x992."""
        lbl = lbl/[960, 1024]*[992, 992]
        return lbl

    def fiducal_points_lbl(self, fiducial_points, segment):
        """Sub-sample the fiducial grid by the configured row/col gaps.

        Gap indices map to these grid point counts per axis:
        [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60] ->
        61, 31, 21, 16, 13, 11, 7, 6, 5, 4, 3, 2.
        The segment spacing is scaled up accordingly.
        """
        fiducial_point_gaps = [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60]
        fiducial_points = fiducial_points[::fiducial_point_gaps[self.row_gap], ::fiducial_point_gaps[self.col_gap], :]
        segment = segment * [fiducial_point_gaps[self.col_gap], fiducial_point_gaps[self.row_gap]]
        return fiducial_points, segment