import math
import torch
import random
import numpy as np
from torchvision.transforms import functional as F
from libs.utils.format_translate import table_to_latex
from .utils import extract_fg_bg_spans, cal_cell_spans


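class Compose:
    """Chain transforms that pass a growing tuple of sample fields along.

    Each transform receives the previous transform's output unpacked as
    positional arguments, so every transform must return a tuple.
    """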
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *data):
        for transform in self.transforms:
            data = transform(*data)
        return data
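
# Illustrative sketch (not part of the original pipeline): Compose threads a
# tuple through its transforms, unpacking each stage's return value into the
# next stage's positional arguments. Every stage must therefore accept what
# the previous stage returned and must itself return a tuple. The toy stages
# below are hypothetical and only demonstrate that contract on plain values.
def _chain_sketch(transforms, *data):
    # Same loop as Compose.__call__: one stage's output tuple becomes the
    # *args of the next stage.
    for transform in transforms:
        data = transform(*data)
    return data


def _toy_add_length(text):
    # Hypothetical stage: grows the tuple from 1 to 2 fields.
    return text, len(text)


def _toy_double_length(text, length):
    # Hypothetical stage: accepts the 2-tuple and appends a third field.
    return text, length, 2 * length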

class TableToLabel:
    """Encode the table's LaTeX representation as vocabulary ids (cls_label)."""
    def __init__(self, vocab):
        self.vocab = vocab

    def __call__(self, image, table=None):
        if table is None:
            return image, None, None
        latex = table_to_latex(table)  # LaTeX representation of the table structure
        cls_label = self.vocab.words_to_ids(latex)
        return image, table, cls_label
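
# Hypothetical sketch of the object TableToLabel expects as `vocab`. The code
# above only requires a words_to_ids() method; the token format (a list of
# LaTeX tokens), the id numbering and the <unk> fallback here are assumptions,
# not the project's actual vocabulary class.
class _ToyVocab:
    def __init__(self, words, unk_word="<unk>"):
        # Reserve id 0 for the unknown token, then number new words in order.
        self.unk_word = unk_word
        self.word_to_id = {unk_word: 0}
        for word in words:
            self.word_to_id.setdefault(word, len(self.word_to_id))

    def words_to_ids(self, words):
        # Map each token to its id, falling back to the unknown-token id.
        unk_id = self.word_to_id[self.unk_word]
        return [self.word_to_id.get(word, unk_id) for word in words]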

class CalRowColSpans:
    """Compute foreground/background spans of the table's rows and columns."""
    def __call__(self, image, table=None, cls_label=None):
        if table is None:
            return image, table, cls_label, None, None, None, None
        image_size = (image.width, image.height)  # (width, height), same order as PIL's image.size
        rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span = extract_fg_bg_spans(table, image_size)
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span
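
# Hypothetical 1-D sketch of the idea behind extract_fg_bg_spans. The real
# implementation lives in .utils and is not shown here. Assumption: along one
# image axis (height for rows, width for columns), each row or column covers
# a content interval [start, end). The foreground spans are those intervals,
# and the background spans are the gaps between neighbouring foreground spans,
# i.e. the separator regions.
def _fg_bg_spans_1d(intervals):
    # Foreground spans: the content intervals, ordered along the axis.
    fg = sorted(intervals)
    # Background spans: from the end of one foreground span to the start of
    # the next (assumes the intervals do not overlap).
    bg = [(fg[i][1], fg[i + 1][0]) for i in range(len(fg) - 1)]
    return fg, bg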

class CalCellSpans:
    """Compute per-cell spans from the table annotation (None if no table)."""
    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None):
        if table is not None:
            cells_span = cal_cell_spans(table)
        else:
            cells_span = None
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span
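
# Hypothetical sketch of a cell-span computation in the spirit of
# cal_cell_spans. The real function is in .utils, and its output format is not
# shown here. Assumption: the table layout is a 2-D grid of cell ids in which
# a spanning cell repeats its id over every grid position it covers, and each
# span is reported as an inclusive (row_start, col_start, row_end, col_end).
def _cell_spans_from_layout(layout):
    spans = {}
    for r, row in enumerate(layout):
        for c, cell_id in enumerate(row):
            if cell_id not in spans:
                spans[cell_id] = [r, c, r, c]
            else:
                span = spans[cell_id]
                # Grow the cell's bounding box to include this grid position.
                span[0] = min(span[0], r)
                span[1] = min(span[1], c)
                span[2] = max(span[2], r)
                span[3] = max(span[3], c)
    # One tuple per cell, ordered by cell id.
    return [tuple(spans[k]) for k in sorted(spans)]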

class CalHeadBodyDivide:
    """Record the head/body boundary as the number of header rows."""
    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None):
        if table is None:
            divide = None
        else:
            head_rows = table['head_rows']
            divide = len(head_rows)
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span, divide
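
# Illustrative sketch (hypothetical helper, not from the original code):
# CalHeadBodyDivide reduces the header to a single boundary index,
# divide = len(table['head_rows']). One way to consume that index is as a
# per-row label that is 1 for header rows and 0 for body rows. This assumes
# the header rows come first in the table, which a single boundary index
# implies.
def _head_body_row_labels(num_rows, divide):
    if not 0 <= divide <= num_rows:
        raise ValueError("divide must lie between 0 and num_rows")
    return [1 if i < divide else 0 for i in range(num_rows)]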

class ToTensor:
    """Convert the image to a tensor and cls_label to a LongTensor."""
    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None, divide=None):
        image = F.to_tensor(image)
        if cls_label is not None:
            cls_label = torch.tensor(cls_label, dtype=torch.long)
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span, divide
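
# Illustrative pure-Python sketch of what F.to_tensor does to an 8-bit image.
# The real call accepts PIL images or numpy arrays and returns a torch tensor.
# Pixels stored as height x width x channels are reordered to
# channels x height x width and scaled from [0, 255] to [0.0, 1.0].
def _hwc_uint8_to_chw_float(pixels):
    height = len(pixels)
    width = len(pixels[0]) if height else 0
    channels = len(pixels[0][0]) if width else 0
    return [
        [[pixels[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]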

class Normalize:
    """Normalize the image tensor with per-channel mean and std."""
    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None, divide=None):
        image = F.normalize(image, self.mean, self.std, self.inplace)
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span, divide
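
# Illustrative pure-Python sketch of the per-channel arithmetic behind
# F.normalize: each value x in channel c becomes (x - mean[c]) / std[c]. The
# real function operates on a torch tensor (in place when inplace=True). This
# version works on nested channels x height x width lists, requires one mean
# and std entry per channel, and always returns a new list.
def _normalize_chw(image, mean, std):
    if len(mean) != len(image) or len(std) != len(image):
        raise ValueError("mean and std need one entry per channel")
    return [
        [[(x - m) / s for x in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]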