# IIIT-5K / utils.py
# Tejveer12's picture
# Upload 6 files
# dca9ee4 verified
import os
import pickle
import torch
import scipy.io as sio
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from crnn import CRNN
class FixHeightResize(object):
    """
    Scale images to a fixed height.

    The width is scaled by the same factor as the height, but is never
    allowed to drop below ``minwidth`` (so narrow images get stretched
    horizontally instead of becoming narrower than ``minwidth``).

    Args:
        height (int, optional): Height of the output image (default: 32)
        minwidth (int, optional): Lower bound for the output width (default: 100)
    """

    def __init__(self, height=32, minwidth=100):
        self.height = height
        self.minwidth = minwidth

    def __call__(self, img):
        """
        Resize ``img`` (an instance of PIL.Image) and return the new image.
        """
        w, h = img.size
        # Proportional width for the target height, clamped from below.
        width = max(int(w * self.height / h), self.minwidth)
        # Image.ANTIALIAS was only an alias of Image.LANCZOS and was removed
        # in Pillow 10; LANCZOS selects the same filter on old and new Pillow.
        return img.resize((width, self.height), Image.LANCZOS)
class IIIT5k(Dataset):
    """
    IIIT-5K dataset,(torch.utils.data.Dataset)

    The annotation file ``traindata.mat`` / ``testdata.mat`` is read from
    ``root``; every image it lists is opened, transformed and kept in memory
    when the dataset is constructed.

    Args:
        root (string): Root directory of dataset
        training (bool, optional): If True, train the model, otherwise test it (default: True)
        fix_width (bool, optional): Scale images to fixed size (default: True)
    """

    def __init__(self, root, training=True, fix_width=True):
        super(IIIT5k, self).__init__()
        data_str = 'traindata' if training else 'testdata'
        data = sio.loadmat(os.path.join(root, data_str + '.mat'))[data_str][0]
        # Each record yields (image path relative to root, label).
        self.img, self.label = zip(*[(x[0][0], x[1][0]) for x in data])

        # image resize + grayscale + transform to tensor
        resize = (transforms.Resize((32, 100), Image.BILINEAR)
                  if fix_width else FixHeightResize(32))
        transform = transforms.Compose(
            [resize, transforms.Grayscale(), transforms.ToTensor()])

        # load images
        self.img = [self._load_image(os.path.join(root, name), transform)
                    for name in self.img]

    @staticmethod
    def _load_image(path, transform):
        """Open the image at ``path``, apply ``transform`` and close the file."""
        # The context manager releases the file handle as soon as the image
        # has been transformed instead of leaving it to the garbage collector.
        with Image.open(path) as image:
            return transform(image)

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        """Return the (transformed image, label) pair stored at ``idx``."""
        return self.img[idx], self.label[idx]
def load_data(root, training=True, fix_width=True):
    """
    load IIIT-5K dataset

    The processed dataset is cached next to the raw data as
    ``train[_fix_width].pkl`` / ``test[_fix_width].pkl``. If the cache file
    exists it is unpickled, otherwise the dataset is built from the raw
    files and the cache is written.

    Args:
        root (string): Root directory of dataset
        training (bool, optional): If True, train the model, otherwise test it (default: True)
        fix_width (bool, optional): Scale images to fixed size (default: True)

    Return:
        Training set or test set (a ``DataLoader``; shuffled only for training)
    """
    training = bool(training)
    split = 'train' if training else 'test'
    suffix = '_fix_width' if fix_width else ''
    filename = os.path.join(root, split + suffix + '.pkl')
    # Without a fixed width the samples are loaded one at a time
    # (presumably because images of different widths cannot be batched).
    batch_size = 128 if fix_width else 1

    if os.path.exists(filename):
        # NOTE: unpickling runs arbitrary code from the file; only cache
        # files produced by this function should ever be placed in `root`.
        with open(filename, 'rb') as cache:
            dataset = pickle.load(cache)
    else:
        print('==== Loading data.. ====')
        dataset = IIIT5k(root, training=training, fix_width=fix_width)
        # `with` guarantees the cache is flushed and closed before returning.
        with open(filename, 'wb') as cache:
            pickle.dump(dataset, cache)

    return DataLoader(dataset, batch_size=batch_size,
                      shuffle=training, num_workers=4)
class LabelTransformer(object):
    """
    Convert between text and integer label sequences.

    Every letter is assigned a code starting at 1 in the order it appears
    in ``letters``; code 0 is left unassigned and is skipped when decoding.

    Args:
        letters (str): Letters contained in the data
    """

    def __init__(self, letters):
        # letter -> code (1-based); position 0 of decode_map is a placeholder
        self.encode_map = {
            letter: code for code, letter in enumerate(letters, start=1)
        }
        self.decode_map = ' ' + letters

    def encode(self, text):
        """
        Encode a single string or an iterable of strings.

        Returns a pair of IntTensors: the codes of all words concatenated,
        and the length of each word.
        """
        # Treat a lone string as a one-word batch so both cases share a loop.
        words = (text,) if isinstance(text, str) else text
        lookup = self.encode_map
        lengths, codes = [], []
        for word in words:
            lengths.append(len(word))
            codes.extend(lookup[letter] for letter in word)
        return torch.IntTensor(codes), torch.IntTensor(lengths)

    def decode(self, text_code):
        """
        Decode a batch of code sequences into a list of strings.

        Within each sequence, zeros are dropped and a run of identical
        consecutive codes produces a single letter.
        """
        decoded = []
        for sequence in text_code:
            letters = []
            previous = None
            for position, symbol in enumerate(sequence):
                starts_new_run = position == 0 or symbol != previous
                previous = symbol
                if symbol != 0 and starts_new_run:
                    letters.append(self.decode_map[symbol])
            decoded.append(''.join(letters))
        return decoded