|
|
import math |
|
|
import torch |
|
|
import random |
|
|
import numpy as np |
|
|
from torchvision.transforms import functional as F |
|
|
from libs.utils.format_translate import table_to_latex |
|
|
from .utils import extract_fg_bg_spans, cal_cell_spans |
|
|
|
|
|
|
|
|
class Compose:
    """Chain a sequence of transforms into a single callable.

    Each transform is called with the previous transform's output unpacked
    as positional arguments, so every transform is expected to return an
    iterable (typically a tuple) matching the next transform's signature.
    """

    def __init__(self, transforms):
        """Store the ordered iterable of callables to apply."""
        self.transforms = transforms

    def __call__(self, *data):
        """Run every transform in order and return the final output.

        With no transforms configured, the positional arguments are returned
        unchanged as a tuple.
        """
        current = data
        for step in self.transforms:
            # Unpack so each step receives the previous step's outputs as
            # separate positional arguments.
            current = step(*current)
        return current
|
|
|
|
|
class TableToLabel:
    """Attach a vocabulary-encoded label derived from the table annotation."""

    def __init__(self, vocab):
        """Keep the vocabulary; it must provide a ``words_to_ids`` method."""
        self.vocab = vocab

    def __call__(self, image, table=None):
        """Return ``(image, table, cls_label)``.

        When ``table`` is given, it is converted with ``table_to_latex`` and
        the result is mapped through ``self.vocab.words_to_ids``. Without a
        table, both the table and label slots are ``None``.
        """
        if table is not None:
            latex_tokens = table_to_latex(table)
            return image, table, self.vocab.words_to_ids(latex_tokens)
        return image, None, None
|
|
|
|
|
class CalRowColSpans:
    """Append row/column foreground and background spans to the sample."""

    def __call__(self, image, table=None, cls_label=None):
        """Return the sample extended with four span entries.

        Args:
            image: Object exposing ``width`` and ``height`` attributes
                (presumably a PIL image; confirm against the caller).
            table: Table annotation passed to ``extract_fg_bg_spans``, or
                ``None`` when no annotation is available.
            cls_label: Label from an earlier transform; always forwarded
                unchanged.

        Returns:
            ``(image, table, cls_label, rows_fg_span, rows_bg_span,
            cols_fg_span, cols_bg_span)``. The four span entries are ``None``
            when ``table`` is ``None``.
        """
        if table is None:
            # Forward cls_label instead of discarding it, consistent with the
            # other transforms in this pipeline, which pass it through.
            return image, table, cls_label, None, None, None, None

        image_size = (image.width, image.height)
        rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span = extract_fg_bg_spans(table, image_size)
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span
|
|
|
|
|
class CalCellSpans:
    """Append the cell spans computed from the table annotation."""

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None,
                 rows_bg_span=None, cols_fg_span=None, cols_bg_span=None):
        """Return the incoming fields followed by ``cells_span``.

        ``cells_span`` is the result of ``cal_cell_spans(table)``, or ``None``
        when no table is supplied. All other fields pass through unchanged.
        """
        cells_span = None if table is None else cal_cell_spans(table)
        return (
            image,
            table,
            cls_label,
            rows_fg_span,
            rows_bg_span,
            cols_fg_span,
            cols_bg_span,
            cells_span,
        )
|
|
|
|
|
class CalHeadBodyDivide:
    """Append the number of header rows of the table to the sample."""

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None,
                 rows_bg_span=None, cols_fg_span=None, cols_bg_span=None,
                 cells_span=None):
        """Return the incoming fields followed by ``divide``.

        ``divide`` is ``len(table['head_rows'])`` when a table is present and
        ``None`` otherwise. All other fields pass through unchanged.
        """
        divide = len(table['head_rows']) if table is not None else None
        return (
            image,
            table,
            cls_label,
            rows_fg_span,
            rows_bg_span,
            cols_fg_span,
            cols_bg_span,
            cells_span,
            divide,
        )
|
|
|
|
|
class ToTensor:
    """Convert the image, and the label when present, to torch tensors."""

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None,
                 rows_bg_span=None, cols_fg_span=None, cols_bg_span=None,
                 cells_span=None, divide=None):
        """Return the sample with ``image`` and ``cls_label`` tensorized.

        The image goes through ``F.to_tensor``. A non-``None`` ``cls_label``
        becomes a ``torch.long`` tensor; ``None`` stays ``None``. The
        remaining fields pass through unchanged.
        """
        image_tensor = F.to_tensor(image)

        if cls_label is None:
            label_tensor = None
        else:
            label_tensor = torch.tensor(cls_label, dtype=torch.long)

        return (
            image_tensor,
            table,
            label_tensor,
            rows_fg_span,
            rows_bg_span,
            cols_fg_span,
            cols_bg_span,
            cells_span,
            divide,
        )
|
|
|
|
|
class Normalize:
    """Normalize the image tensor with a fixed mean and standard deviation."""

    def __init__(self, mean, std, inplace=False):
        """Store the arguments later forwarded to ``F.normalize``."""
        self.mean, self.std = mean, std
        self.inplace = inplace

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None,
                 rows_bg_span=None, cols_fg_span=None, cols_bg_span=None,
                 cells_span=None, divide=None):
        """Return the sample with ``image`` replaced by its normalized form.

        Only the image is modified, via
        ``F.normalize(image, mean, std, inplace)``; every other field passes
        through unchanged.
        """
        normalized = F.normalize(image, self.mean, self.std, self.inplace)
        return (
            normalized,
            table,
            cls_label,
            rows_fg_span,
            rows_bg_span,
            cols_fg_span,
            cols_bg_span,
            cells_span,
            divide,
        )
|
|
|