# training_sem/libs/data/dataset.py
import os
import copy
import json
import pickle
import random
from torch._C import layout
import tqdm
import torch
import numpy as np
from PIL import Image
from .list_record_cache import ListRecordLoader
from libs.utils.format_translate import table_to_html
class LRCRecordLoader:
    """Loads table records (image + annotation) from a ListRecordLoader cache.

    Image paths stored in the records are relative; they are resolved
    against ``data_dir`` before opening.
    """

    def __init__(self, lrc_path, data_dir=''):
        # Underlying record cache plus the root directory that relative
        # image paths are resolved against.
        self.loader = ListRecordLoader(lrc_path)
        self.data_root_dir = data_dir

    def __len__(self):
        return len(self.loader)

    def _image_path(self, table):
        # Single point of path resolution so get_info and get_data agree.
        return os.path.join(self.data_root_dir, table['image_path'])

    def get_info(self, idx):
        """Return (width, height, n_cells) for the record at ``idx``.

        n_cells is rows * cols of the layout grid.
        """
        table = self.loader.get_record(idx)
        # BUG FIX: the original opened table['image_path'] directly here,
        # ignoring data_root_dir (unlike get_data) — this failed whenever a
        # non-empty data_dir was configured.
        image = Image.open(self._image_path(table)).convert('RGB')
        w = image.width
        h = image.height
        n_rows, n_cols = table['layout'].shape
        n_cells = n_rows * n_cols
        return w, h, n_cells

    def get_data(self, idx):
        """Return (PIL RGB image, record dict) for the record at ``idx``."""
        table = self.loader.get_record(idx)
        image = Image.open(self._image_path(table)).convert('RGB')
        return image, table
class Dataset:
    """Concatenates several record loaders behind one flat, indexable dataset.

    ``loaders`` is a sequence of objects supporting ``len()`` (and, for
    ``__getitem__``/``get_info``, the ``get_data``/``get_info`` methods);
    ``transforms`` is a callable applied to each loaded sample.
    """

    def __init__(self, loaders, transforms):
        self.loaders = loaders
        self.transforms = transforms

    def _match_loader(self, idx):
        """Map a global index to ``(loader, index_within_that_loader)``."""
        offset = 0
        for loader in self.loaders:
            if len(loader) + offset > idx:
                return loader, idx - offset
            offset += len(loader)
        # Past the end of every loader; include sizes to aid debugging
        # (the original raised a bare IndexError()).
        raise IndexError('index %d out of range for dataset of size %d' % (idx, offset))

    def get_info(self, idx):
        """Return the loader-specific info tuple for the sample at ``idx``."""
        loader, rela_idx = self._match_loader(idx)
        return loader.get_info(rela_idx)

    def __len__(self):
        return sum(len(loader) for loader in self.loaders)

    def __getitem__(self, idx):
        try:
            loader, rela_idx = self._match_loader(idx)
            image, table = loader.get_data(rela_idx)
            # Labeled records (those with a 'layout') are transformed together
            # with their annotation; unlabeled ones transform the image only.
            image, _, cls_label, \
                rows_fg_span, rows_bg_span, \
                cols_fg_span, cols_bg_span, \
                cells_span, divide = self.transforms(image, table) if 'layout' in table else self.transforms(image)
            return dict(
                id=idx,
                # (width, height) of the transformed image tensor (C, H, W).
                image_size=(image.shape[2], image.shape[1]),
                image=image,
                cls_label=cls_label,
                rows_fg_span=rows_fg_span,
                rows_bg_span=rows_bg_span,
                cols_fg_span=cols_fg_span,
                cols_bg_span=cols_bg_span,
                cells_span=cells_span,
                layout=table['layout'] if 'layout' in table else None,
                divide=divide,
                table=table
            )
        except Exception as e:
            # Log which sample failed, then re-raise so the caller still sees
            # the original error. (Message typo/grammar fixed: was
            # 'Error occured while load data'.)
            print('Error occurred while loading data: %d' % idx)
            raise e
def collate_func(batch_data):
    """Collate a list of ``Dataset`` samples into one padded batch dict.

    Images are zero-padded to the largest (H, W) in the batch, with a 0/1
    float mask marking valid pixels. Class labels are zero-padded to the
    longest label sequence with a matching float mask; layouts are padded
    with -1. Span lists, divide flags, and raw tables are passed through
    as plain Python lists (``divide_labels`` becomes a long tensor when
    the samples carry divide labels).

    If every sample is unlabeled (``cls_label`` and ``layout`` both None),
    the label/mask/layout outputs remain plain lists of Nones. Mixed
    labeled/unlabeled batches are rejected by assertion.
    """
    batch_size = len(batch_data)
    image_dim = batch_data[0]['image'].shape[0]
    # Pad every image (and its mask) to the largest spatial size in the batch.
    max_h = max(data['image'].shape[1] for data in batch_data)
    max_w = max(data['image'].shape[2] for data in batch_data)

    batch_id = list()
    batch_image_size = list()
    batch_image = torch.zeros([batch_size, image_dim, max_h, max_w], dtype=torch.float)
    batch_image_mask = torch.zeros([batch_size, 1, max_h, max_w], dtype=torch.float)
    batch_rows_fg_span = list()
    batch_rows_bg_span = list()
    batch_cols_fg_span = list()
    batch_cols_bg_span = list()
    batch_cells_span = list()
    batch_divide = list()
    tables = list()

    if all((data['cls_label'] is None) and (data['layout'] is None) for data in batch_data):
        # Inference-style batch: nothing to pad, keep plain lists.
        batch_cls_label = list()
        batch_label_mask = list()
        batch_layout = list()
    else:
        # Labels must be all-or-nothing within a batch.
        assert not any((data['cls_label'] is None) or (data['layout'] is None) for data in batch_data)
        max_label_length = max(data['cls_label'].shape[0] for data in batch_data)
        batch_cls_label = torch.zeros([batch_size, max_label_length], dtype=torch.long)
        batch_label_mask = torch.zeros([batch_size, max_label_length], dtype=torch.float)
        max_nr = max(data['layout'].shape[0] for data in batch_data)
        max_nc = max(data['layout'].shape[1] for data in batch_data)
        # -1 marks padded layout cells (valid cell ids are >= 0).
        batch_layout = torch.full([batch_size, max_nr, max_nc], -1, dtype=torch.float)

    for batch_idx, data in enumerate(batch_data):
        batch_id.append(data['id'])
        batch_image_size.append(data['image_size'])
        _, cur_h, cur_w = data['image'].shape
        batch_image[batch_idx, :, :cur_h, :cur_w] = data['image']
        batch_image_mask[batch_idx, :, :cur_h, :cur_w] = 1
        if (data['cls_label'] is None) and (data['layout'] is None):
            batch_cls_label.append(data['cls_label'])
            batch_label_mask.append(None)
            batch_layout.append(data['layout'])
        else:
            label_length = data['cls_label'].shape[0]
            batch_cls_label[batch_idx, :label_length] = data['cls_label']
            batch_label_mask[batch_idx, :label_length] = 1.0
            layout_nr, layout_nc = data['layout'].shape
            # layout is a numpy int grid; stored as float to share the -1 padding.
            batch_layout[batch_idx, :layout_nr, :layout_nc] = torch.from_numpy(data['layout']).float()
        batch_rows_fg_span.append(data['rows_fg_span'])
        batch_rows_bg_span.append(data['rows_bg_span'])
        batch_cols_fg_span.append(data['cols_fg_span'])
        batch_cols_bg_span.append(data['cols_bg_span'])
        batch_cells_span.append(data['cells_span'])
        batch_divide.append(data['divide'])
        tables.append(data['table'])

    # Tensorize divide flags only when present (None means unlabeled samples).
    batch_divide = torch.tensor(batch_divide, dtype=torch.long) if batch_divide[0] is not None else batch_divide

    return dict(
        ids=batch_id,
        images_size=batch_image_size,
        images=batch_image,
        images_mask=batch_image_mask,
        cls_labels=batch_cls_label,
        labels_mask=batch_label_mask,
        rows_fg_spans=batch_rows_fg_span,
        rows_bg_spans=batch_rows_bg_span,
        cols_fg_spans=batch_cols_fg_span,
        cols_bg_spans=batch_cols_bg_span,
        cells_spans=batch_cells_span,
        divide_labels=batch_divide,
        layouts=batch_layout,
        tables=tables
    )