| import os | |
| import copy | |
| import json | |
| import pickle | |
| import random | |
| from torch._C import layout | |
| import tqdm | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from .list_record_cache import ListRecordLoader | |
| from libs.utils.format_translate import table_to_html | |
class LRCRecordLoader:
    """Serve table records from a list-record cache together with their images.

    Each record is indexed with the keys ``'image_path'`` and ``'layout'``;
    ``'image_path'`` is resolved relative to ``data_dir`` before the image is
    opened.
    """

    def __init__(self, lrc_path, data_dir=''):
        """Open the record cache at ``lrc_path``.

        Args:
            lrc_path: Path handed to ``ListRecordLoader``.
            data_dir: Directory that record image paths are joined onto. The
                default empty string leaves the stored paths unchanged.
        """
        self.loader = ListRecordLoader(lrc_path)
        self.data_root_dir = data_dir

    def __len__(self):
        """Return the number of records reported by the underlying loader."""
        return len(self.loader)

    def _image_path(self, table):
        """Return the record's image path resolved against the data root."""
        return os.path.join(self.data_root_dir, table['image_path'])

    def get_info(self, idx):
        """Return ``(width, height, n_cells)`` for record ``idx``.

        ``n_cells`` is the product of the two dimensions of
        ``table['layout'].shape``.
        """
        table = self.loader.get_record(idx)
        # Resolve the path exactly like get_data() so both methods agree when
        # a data root is configured. Only the size is needed, so the pixel
        # data is never decoded, and the file handle is closed on exit.
        with Image.open(self._image_path(table)) as image:
            w, h = image.size
        n_rows, n_cols = table['layout'].shape
        return w, h, n_rows * n_cols

    def get_data(self, idx):
        """Return ``(image, table)`` for record ``idx``; the image is RGB."""
        table = self.loader.get_record(idx)
        # convert() produces an independent in-memory image, so the source
        # file can be closed as soon as the conversion is done.
        with Image.open(self._image_path(table)) as source:
            image = source.convert('RGB')
        return image, table
class Dataset:
    """Concatenation of several record loaders behind one flat index space.

    Global index ``i`` addresses the loaders in order: the first
    ``len(loaders[0])`` indices belong to the first loader, the next
    ``len(loaders[1])`` to the second, and so on.
    """

    def __init__(self, loaders, transforms):
        """Store the loaders and the transform applied to every sample.

        Args:
            loaders: Sequence of objects providing ``__len__``, ``get_info``
                and ``get_data``.
            transforms: Callable invoked as ``transforms(image, table)`` when
                the record has a ``'layout'`` entry, otherwise as
                ``transforms(image)``. It must return nine values.
        """
        self.loaders = loaders
        self.transforms = transforms

    def _match_loader(self, idx):
        """Map a global index to ``(loader, index_within_loader)``.

        Negative indices count from the end of the whole dataset.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        position = idx
        if position < 0:
            # Without this normalisation a negative index would be paired
            # with the first loader and address the wrong record.
            position += len(self)
            if position < 0:
                raise IndexError('dataset index out of range: %r' % (idx,))
        for loader in self.loaders:
            size = len(loader)
            if position < size:
                return loader, position
            position -= size
        raise IndexError('dataset index out of range: %r' % (idx,))

    def get_info(self, idx):
        """Return the owning loader's ``get_info`` result for global ``idx``."""
        loader, rela_idx = self._match_loader(idx)
        return loader.get_info(rela_idx)

    def __len__(self):
        """Return the total number of records across all loaders."""
        return sum(len(loader) for loader in self.loaders)

    def __getitem__(self, idx):
        """Load, transform and package the sample at global index ``idx``.

        Returns:
            dict with keys ``id``, ``image_size``, ``image``, ``cls_label``,
            ``rows_fg_span``, ``rows_bg_span``, ``cols_fg_span``,
            ``cols_bg_span``, ``cells_span``, ``layout``, ``divide`` and
            ``table``. ``layout`` is ``None`` when the record has no
            ``'layout'`` entry.

        Any exception is reported on stdout with the offending index and then
        re-raised unchanged.
        """
        try:
            loader, rela_idx = self._match_loader(idx)
            image, table = loader.get_data(rela_idx)
            has_layout = 'layout' in table
            if has_layout:
                outputs = self.transforms(image, table)
            else:
                outputs = self.transforms(image)
            (image, _, cls_label,
             rows_fg_span, rows_bg_span,
             cols_fg_span, cols_bg_span,
             cells_span, divide) = outputs
            return dict(
                id=idx,
                # (shape[2], shape[1]) -- presumably (width, height) of a
                # channel-first image; confirm against the transforms.
                image_size=(image.shape[2], image.shape[1]),
                image=image,
                cls_label=cls_label,
                rows_fg_span=rows_fg_span,
                rows_bg_span=rows_bg_span,
                cols_fg_span=cols_fg_span,
                cols_bg_span=cols_bg_span,
                cells_span=cells_span,
                layout=table['layout'] if has_layout else None,
                divide=divide,
                table=table
            )
        except Exception:
            print('Error occured while load data: %d' % idx)
            # Bare raise keeps the original traceback intact.
            raise
def collate_func(batch_data):
    """Merge a list of per-sample dicts into one padded batch dict.

    Images are copied into the top-left corner of a zero tensor sized to the
    largest height and width in the batch; ``images_mask`` is 1 over each
    sample's real pixels and 0 over the padding.

    Either every sample has both ``cls_label`` and ``layout`` set to ``None``
    (then ``cls_labels``, ``labels_mask`` and ``layouts`` are lists of
    ``None``), or no sample has either set to ``None`` (then labels are
    zero-padded into a long tensor with a float mask, and layouts are padded
    with -1 into a float tensor). Span fields and tables are passed through
    as plain lists. ``divide_labels`` becomes a long tensor unless the first
    sample's ``divide`` is ``None``.
    """
    n_samples = len(batch_data)
    n_channels = batch_data[0]['image'].shape[0]
    pad_h = max(sample['image'].shape[1] for sample in batch_data)
    pad_w = max(sample['image'].shape[2] for sample in batch_data)

    images = torch.zeros([n_samples, n_channels, pad_h, pad_w], dtype=torch.float)
    images_mask = torch.zeros([n_samples, 1, pad_h, pad_w], dtype=torch.float)

    unlabeled = all(
        sample['cls_label'] is None and sample['layout'] is None
        for sample in batch_data
    )
    if unlabeled:
        cls_labels = [sample['cls_label'] for sample in batch_data]
        labels_mask = [None for _ in batch_data]
        layouts = [sample['layout'] for sample in batch_data]
    else:
        # A batch mixing labeled and unlabeled samples is not supported.
        assert all(
            sample['cls_label'] is not None and sample['layout'] is not None
            for sample in batch_data
        )
        longest_label = max(sample['cls_label'].shape[0] for sample in batch_data)
        cls_labels = torch.zeros([n_samples, longest_label], dtype=torch.long)
        labels_mask = torch.zeros([n_samples, longest_label], dtype=torch.float)
        most_rows = max(sample['layout'].shape[0] for sample in batch_data)
        most_cols = max(sample['layout'].shape[1] for sample in batch_data)
        layouts = torch.full([n_samples, most_rows, most_cols], -1, dtype=torch.float)

    for row, sample in enumerate(batch_data):
        _, height, width = sample['image'].shape
        images[row, :, :height, :width] = sample['image']
        images_mask[row, :, :height, :width] = 1
        if unlabeled:
            continue
        n_labels = sample['cls_label'].shape[0]
        cls_labels[row, :n_labels] = sample['cls_label']
        labels_mask[row, :n_labels] = 1.0
        n_rows, n_cols = sample['layout'].shape
        layouts[row, :n_rows, :n_cols] = torch.from_numpy(sample['layout']).float()

    divides = [sample['divide'] for sample in batch_data]
    if divides[0] is not None:
        divides = torch.tensor(divides, dtype=torch.long)

    return {
        'ids': [sample['id'] for sample in batch_data],
        'images_size': [sample['image_size'] for sample in batch_data],
        'images': images,
        'images_mask': images_mask,
        'cls_labels': cls_labels,
        'labels_mask': labels_mask,
        'rows_fg_spans': [sample['rows_fg_span'] for sample in batch_data],
        'rows_bg_spans': [sample['rows_bg_span'] for sample in batch_data],
        'cols_fg_spans': [sample['cols_fg_span'] for sample in batch_data],
        'cols_bg_spans': [sample['cols_bg_span'] for sample in batch_data],
        'cells_spans': [sample['cells_span'] for sample in batch_data],
        'divide_labels': divides,
        'layouts': layouts,
        'tables': [sample['table'] for sample in batch_data],
    }