File size: 1,270 Bytes
762eb29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from torch.nn.utils.rnn import pad_sequence

# CTC collate
def custom_collate(data):

    target_lengths = [len(d['label']) for d in data]
    labels = [d['label'] for d in data]
    inputs = [d['img'].tolist() for d in data]
    idx = [d['idx'] for d in data]
    raw_label = [d['raw_label'] for d in data]

    target_lengths = torch.tensor(target_lengths)
    labels = pad_sequence(labels, batch_first=True)
    inputs = torch.tensor(inputs)
    idx = torch.tensor(idx)

    return { #(6)
        'idx': idx,
        'img': inputs,
        'label': labels,
        'target_lengths': target_lengths,
        'raw_label': raw_label,
    }

def create_char_dicts(list_strings):
    text_to_seq = {}
    seq_to_text = {}
    value = 1 # 0 is blank token

    for text in list_strings:
        for character in text:
            if character not in text_to_seq:
                text_to_seq[character] = value
                seq_to_text[value] = character
                value += 1
    return text_to_seq, seq_to_text

def sample_text_to_seq(list_strings, mydict):
    return [mydict.get(character, "") for character in list_strings]

def sample_seq_to_text(list_strings, mydict):
    return ''.join([mydict.get(character, "") for character in list_strings])