# File: training_sem/libs/data/transform.py
# Page metadata: kai-2054 | "Initial commit: add code" | cb0ad2d | 2.95 kB
import math
import torch
import random
import numpy as np
from torchvision.transforms import functional as F
from libs.utils.format_translate import table_to_latex
from .utils import extract_fg_bg_spans, cal_cell_spans
class Compose:
    """Chain several transforms into a single callable.

    Each transform is called with the previous transform's output unpacked
    as positional arguments, so every transform must return an iterable
    (typically a tuple) that matches the next transform's signature.
    """

    def __init__(self, transforms):
        # Iterable of callables, applied in iteration order.
        self.transforms = transforms

    def __call__(self, *data):
        """Run ``data`` through every transform and return the final output.

        With no transforms configured, the positional arguments are returned
        unchanged as a tuple.
        """
        result = data
        for step in self.transforms:
            # The output tuple of one step becomes the argument list of the next.
            result = step(*result)
        return result
class TableToLabel:
    """Attach a token-id label derived from the table's LaTeX form.

    The table is converted with ``table_to_latex`` and the result is mapped
    to ids through ``vocab.words_to_ids``.
    """

    def __init__(self, vocab):
        # Must expose ``words_to_ids``; it is called on the LaTeX output.
        self.vocab = vocab

    def __call__(self, image, table=None):
        """Return ``(image, table, cls_label)``.

        When no table is supplied, both the table and the label slots are
        ``None`` and the image is passed through untouched.
        """
        if table is not None:
            latex_tokens = table_to_latex(table)
            return image, table, self.vocab.words_to_ids(latex_tokens)
        return image, None, None
class CalRowColSpans:
    """Append row/column foreground and background spans to the sample."""

    def __call__(self, image, table=None, cls_label=None):
        """Return a 7-tuple ending with the four row/column span entries.

        Output layout: ``(image, table, cls_label, rows_fg_span,
        rows_bg_span, cols_fg_span, cols_bg_span)``.

        Without a table, every slot after ``image`` is ``None`` -- including
        ``cls_label``, which is discarded in that case.
        """
        if table is None:
            # table is None here, so this is (image, None, None x5).
            return (image, table) + (None,) * 5

        # extract_fg_bg_spans receives the size as (width, height).
        size_wh = (image.width, image.height)
        rows_fg, rows_bg, cols_fg, cols_bg = extract_fg_bg_spans(table, size_wh)
        return image, table, cls_label, rows_fg, rows_bg, cols_fg, cols_bg
class CalCellSpans:
    """Append the table's cell spans to the sample."""

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None):
        """Return the seven inputs followed by ``cells_span``.

        ``cells_span`` is the result of ``cal_cell_spans(table)``, or
        ``None`` when no table is supplied. All other values are passed
        through unchanged.
        """
        cells_span = None if table is None else cal_cell_spans(table)
        return (
            image,
            table,
            cls_label,
            rows_fg_span,
            rows_bg_span,
            cols_fg_span,
            cols_bg_span,
            cells_span,
        )
class CalHeadBodyDivide:
    """Append the head/body divide (number of head rows) to the sample."""

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None):
        """Return the eight inputs followed by ``divide``.

        ``divide`` is ``len(table['head_rows'])``, or ``None`` when no table
        is supplied. All other values are passed through unchanged.
        """
        divide = len(table['head_rows']) if table is not None else None
        return (
            image,
            table,
            cls_label,
            rows_fg_span,
            rows_bg_span,
            cols_fg_span,
            cols_bg_span,
            cells_span,
            divide,
        )
class ToTensor:
    """Convert the image, and the class label if present, to tensors."""

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None, divide=None):
        """Return the sample with ``image`` and ``cls_label`` tensorised.

        ``image`` goes through torchvision's ``to_tensor``. ``cls_label``
        becomes a ``torch.long`` tensor unless it is ``None``. Every other
        value is passed through unchanged.
        """
        image_tensor = F.to_tensor(image)
        label = cls_label
        if label is not None:
            label = torch.tensor(label, dtype=torch.long)
        return (
            image_tensor,
            table,
            label,
            rows_fg_span,
            rows_bg_span,
            cols_fg_span,
            cols_bg_span,
            cells_span,
            divide,
        )
class Normalize:
    """Normalize the image with a fixed mean and standard deviation.

    The stored parameters are forwarded as-is to torchvision's
    ``functional.normalize``.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean, self.std = mean, std
        self.inplace = inplace

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None, divide=None):
        """Return the sample with ``image`` replaced by its normalized form.

        Every value other than ``image`` is passed through unchanged.
        """
        normalized = F.normalize(
            image,
            self.mean,
            self.std,
            self.inplace,
        )
        return (
            normalized,
            table,
            cls_label,
            rows_fg_span,
            rows_bg_span,
            cols_fg_span,
            cols_bg_span,
            cells_span,
            divide,
        )