# -*- coding: utf-8 -*-
"""Classes for feeding data during training."""
import numpy as np
import pandas as pd
from .helpers import img_extend
from .datahelpers import sequences_to_sparse
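

# Illustrative sketch, not used by BucketDataIterator. The constructor
# below sorts samples by image width and cuts the sorted order into
# num_buckets equal slices, with the last slice taking the remainder.
# next_batch then visits every bucket once per round, in a random order.
# The hypothetical helpers bucket_indices_sketch and bucket_schedule_sketch
# reproduce both ideas on plain index arrays using only NumPy. One
# deliberate difference: ties in width are broken by a stable sort here,
# while pandas' sort_values uses quicksort by default.
import numpy as np


def bucket_indices_sketch(lengths, num_buckets):
    """Return one index array per bucket, grouping samples of similar length.

    Hypothetical helper; assumes len(lengths) >= num_buckets, otherwise
    the first buckets come out empty.
    """
    # Stable sort so samples of equal length keep their original order
    order = np.argsort(np.asarray(lengths), kind="stable")
    bsize = len(order) // num_buckets
    buckets = [order[b * bsize:(b + 1) * bsize]
               for b in range(num_buckets - 1)]
    # The last bucket absorbs the len(lengths) % num_buckets leftovers
    buckets.append(order[(num_buckets - 1) * bsize:])
    return buckets


def bucket_schedule_sketch(num_buckets, num_steps, seed=0):
    """Return the bucket index used at each of num_steps steps.

    Hypothetical helper: every round visits each bucket exactly once,
    in a fresh random order, as next_batch does with bucket_order.
    """
    rng = np.random.default_rng(seed)
    schedule = []
    while len(schedule) < num_steps:
        # One round: a new random permutation of all bucket indices
        schedule.extend(rng.permutation(num_buckets).tolist())
    return schedule[:num_steps]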


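class BucketDataIterator:
    """Iterator that feeds a CTC model during training.

    Samples are sorted by image width and split into `num_buckets`
    buckets, so each batch holds images of similar width and needs
    little padding.

    Args:
        images: list of 2-D image arrays (height x width)
        targets: label sequence for each image
        num_buckets: number of width buckets
        slider: (height, min_width); every image is extended to this
            height and to at least this width
        augmentation: augmenter exposing `augment_images` (for example an
            imgaug augmenter), applied to batches when `train` is True
        dropout: stored as `self.dropout`; not used by this class
        train: whether to apply `augmentation` in `next_batch`
    """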
    def __init__(self,
                 images,
                 targets,
                 num_buckets=5,
                 slider=(60, 30),
                 augmentation=None,
                 dropout=0.0,
                 train=True):

        self.train = train
        self.slider = slider
        self.augmentation = augmentation
        self.dropout = dropout
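        # Extend each image to height slider[0] and width >= slider[1].
        # Note: this overwrites the entries of the caller's `images` list.
        for i in range(len(images)):
            images[i] = img_extend(
                images[i],
                (self.slider[0],
                 max(images[i].shape[1], self.slider[1])))
        in_length = [image.shape[1] for image in images]

        # Build a DataFrame sorted by image width (`in_length`)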
        self.dataFrame = pd.DataFrame({
            'in_length': in_length,
            'images': images,
            'targets': targets}).sort_values('in_length').reset_index(drop=True)

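        # Split the sorted samples into equal-size buckets of similar width;
        # the last bucket also takes the len(images) % num_buckets leftovers
        # (assumes len(images) >= num_buckets, otherwise buckets are empty)
        bsize = len(images) // num_buckets
        self.num_buckets = num_buckets
        self.buckets = []
        for bucket in range(num_buckets - 1):
            self.buckets.append(
                self.dataFrame.iloc[bucket * bsize:(bucket + 1) * bsize])
        self.buckets.append(self.dataFrame.iloc[(num_buckets - 1) * bsize:])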

        self.buckets_size = [len(bucket) for bucket in self.buckets]
        self.cursor = np.zeros(num_buckets, dtype=int)
        self.bucket_order = np.random.permutation(num_buckets)
        self.bucket_cursor = 0
        self.shuffle()
        print("Iterator created.")


    def shuffle(self, idx=None):
        """Shuffle idx bucket or each bucket separately."""
        for i in [idx] if idx is not None else range(self.num_buckets):
            self.buckets[i] = self.buckets[i].sample(frac=1).reset_index(drop=True)
            self.cursor[i] = 0


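    def next_batch(self, batch_size):
        """Create the next training batch from the next scheduled bucket.

        Args:
            batch_size: number of samples in the batch; clamped to the
                size of the chosen bucket if larger
        Returns:
            (images, targets, image widths): float32 image batch of shape
            (batch_size, slider[0], max_width, 1), the label sequences
            converted by `sequences_to_sparse`, and the unpadded width
            of each image
        """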
        i_bucket = self.bucket_order[self.bucket_cursor]
        # Advance to the next bucket; once every bucket has been visited,
        # draw a new random bucket order for the next round
        self.bucket_cursor = (self.bucket_cursor + 1) % self.num_buckets
        if self.bucket_cursor == 0:
            self.bucket_order = np.random.permutation(self.num_buckets)

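        # Not enough unused samples left in this bucket: reshuffle it and
        # start again from its beginning
        if self.cursor[i_bucket] + batch_size > self.buckets_size[i_bucket]:
            self.shuffle(i_bucket)

        # Clamp the batch size to the bucket size
        if batch_size > self.buckets_size[i_bucket]:
            batch_size = self.buckets_size[i_bucket]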

        res = self.buckets[i_bucket].iloc[self.cursor[i_bucket]:
                                          self.cursor[i_bucket]+batch_size]
        self.cursor[i_bucket] += batch_size

        # Zero-pad every image to the width of the widest image in the batch
        input_max = res['in_length'].max()

        input_imgs = np.zeros(
            (batch_size, self.slider[0], input_max, 1), dtype=np.uint8)
        for i, img in enumerate(res['images']):
            input_imgs[i][:, :res['in_length'].values[i], 0] = img
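
        # Augment only in training mode and only if an augmenter was given
        if self.train and self.augmentation is not None:
            input_imgs = self.augmentation.augment_images(input_imgs)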
        input_imgs = input_imgs.astype(np.float32)

        targets = sequences_to_sparse(res['targets'].values)
        return input_imgs, targets, res['in_length'].values
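

# Illustrative sketch, not part of this module's API. next_batch()
# zero-pads every image in a batch to the widest one and hands the label
# sequences to sequences_to_sparse. The hypothetical helpers below show
# the same padding on plain NumPy arrays, plus one common
# (indices, values, dense_shape) encoding of label sequences, the form
# accepted by e.g. tf.SparseTensor. That encoding is an assumption: the
# real output of sequences_to_sparse is defined in .datahelpers and is
# not shown here.
import numpy as np


def pad_batch_sketch(images, height):
    """Zero-pad 2-D images to one array of shape (N, height, max_w, 1).

    Hypothetical helper; assumes every image already has exactly
    `height` rows. Returns the padded batch and the original widths.
    """
    widths = np.array([img.shape[1] for img in images])
    batch = np.zeros((len(images), height, widths.max(), 1), dtype=np.uint8)
    for i, img in enumerate(images):
        # Copy each image into the left part; the right part stays zero
        batch[i, :, :img.shape[1], 0] = img
    return batch, widths


def sparse_labels_sketch(sequences):
    """Encode label sequences as (indices, values, dense_shape) arrays.

    Hypothetical helper: indices holds one (row, position) pair per label.
    """
    indices = [(row, col) for row, seq in enumerate(sequences)
               for col in range(len(seq))]
    values = [label for seq in sequences for label in seq]
    max_len = max((len(seq) for seq in sequences), default=0)
    return (np.array(indices, dtype=np.int64).reshape(-1, 2),
            np.array(values, dtype=np.int32),
            np.array([len(sequences), max_len], dtype=np.int64))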