File size: 3,507 Bytes
be611b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

cimport cython
cimport numpy as np

# Integer type for the index and shape arrays handled below.
# DTYPE_t is the C element type used in the ndarray buffer and memoryview
# declarations of this module. DTYPE is the matching NumPy dtype object; it is
# not referenced in the functions below — presumably exported for callers to
# build correctly-typed arrays (NOTE(review): confirm).
DTYPE = np.int64
ctypedef np.int64_t DTYPE_t


cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long max_sentences):
    """Return 1 if the running batch must be closed before adding a sample, else 0.

    ``num_sentences`` is the number of entries already in the batch and
    ``num_tokens`` is the token count to test against ``max_tokens`` (in
    batch_by_size_fast this is the padded size the batch would have with the
    candidate added). A limit that is <= 0 is treated as disabled.
    """
    # An empty batch always accepts the next candidate.
    if num_sentences == 0:
        return 0
    sentences_exhausted = max_sentences > 0 and num_sentences == max_sentences
    tokens_exceeded = max_tokens > 0 and num_tokens > max_tokens
    return 1 if (sentences_exhausted or tokens_exceeded) else 0


@cython.cdivision(True)
cpdef list batch_by_size_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    long max_tokens,
    long max_sentences,
    int bsz_mult,
):
    """Greedily group ``indices`` into batches bounded by size limits.

    Indices are consumed in the given order. Before each index is added, the
    running batch is closed if it already holds ``max_sentences`` entries, or
    if adding the index would push its padded size (entries x longest entry)
    above ``max_tokens``. A closed batch keeps the largest multiple of
    ``bsz_mult`` entries (or all of them if there are fewer than
    ``bsz_mult``). The remaining entries carry over into the next batch. If that
    carried-over remainder cannot share a batch with the new index without
    exceeding the limits, the remainder is emitted as its own, smaller batch.

    Args:
        indices: 1-D int64 array of example indices, in batching order.
        num_tokens_fn: callable mapping an index to its length in tokens.
        max_tokens: maximum padded tokens per batch; <= 0 disables the limit.
        max_sentences: maximum entries per batch; <= 0 disables the limit.
        bsz_mult: preferred multiple for batch sizes; values < 1 act as 1.

    Returns:
        A list of batches, each a list of indices.

    Raises:
        AssertionError: if a single example is longer than ``max_tokens``.
    """
    cdef long sample_len = 0
    cdef list sample_lens = []
    cdef list batch = []
    cdef list batches = []
    cdef long mod_len
    cdef long i
    cdef long idx
    cdef long num_tokens
    cdef long batch_tokens
    cdef DTYPE_t[:] indices_view = indices

    # cdivision(True) removes Python's zero-division check, so bsz_mult == 0
    # would be undefined behaviour in the ``//`` and ``%`` below. A multiple of
    # 1 imposes no constraint, so use it for any non-positive value.
    if bsz_mult < 1:
        bsz_mult = 1

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        # sample_lens mirrors ``batch`` with the current candidate as last entry.
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert max_tokens <= 0 or sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )
        # Padded size of the batch if idx were added to it.
        batch_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(len(batch), batch_tokens, max_tokens, max_sentences):
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0

            # The carried-over remainder plus idx can itself exceed max_tokens
            # (e.g. idx is much longer than the remainder). In that case close
            # the remainder too, so idx starts a fresh batch on its own.
            if _is_batch_full(
                len(batch), (len(batch) + 1) * sample_len, max_tokens, max_sentences
            ):
                batches.append(batch)
                batch = []
                sample_lens = sample_lens[-1:]
                sample_len = num_tokens
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches


cdef _find_valid_shape(
    DTYPE_t[:, :] shapes_view,
    long num_sentences,
    long num_tokens,
):
    """Return the index of the first valid shape, or -1 if none is found.

    Row ``i`` of ``shapes_view`` is valid when
    ``num_sentences <= shapes_view[i, 0]`` and
    ``num_tokens <= shapes_view[i, 1]``. Rows are scanned in order, so the
    first (e.g. smallest, if the rows are sorted) matching row wins.
    """
    cdef Py_ssize_t i
    for i in range(shapes_view.shape[0]):
        # Index both axes at once: chained ``shapes_view[i][0]`` would build an
        # intermediate 1-D memoryview slice for every comparison.
        if num_sentences <= shapes_view[i, 0] and num_tokens <= shapes_view[i, 1]:
            return i
    return -1


@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
    """Greedily group ``indices`` into batches that fit one of the given shapes.

    Each row of ``fixed_shapes_sorted`` is ``(max entries, max sample length)``.
    A batch fits a row when its entry count and its longest sample are both
    within that row's limits. Indices are consumed in order. When the next
    index would make the running batch fit no row, the batch is closed and a
    new one is started with that index.

    An index that fits no shape even on its own is still emitted, as a
    single-element batch. Empty batches are never emitted.

    Args:
        indices: 1-D int64 array of example indices, in batching order.
        num_tokens_fn: callable mapping an index to its length in tokens.
        fixed_shapes_sorted: 2-D int64 array of allowed ``(entries, length)``
            shapes, scanned in row order.

    Returns:
        A list of batches, each a list of indices.
    """
    cdef long sample_len = 0
    cdef list batch = []
    cdef list batches = []
    cdef long i
    cdef long idx
    cdef long num_tokens
    cdef long shape_idx
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_len = max(sample_len, num_tokens)

        shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
        if shape_idx == -1:
            # idx does not fit alongside the running batch: close the batch
            # (never emitting an empty one) and start a new batch with idx.
            if len(batch) > 0:
                batches.append(batch)
            batch = []
            # The new batch already contains idx, so its length must count
            # towards the longest sample when checking the next candidate.
            sample_len = num_tokens
            shapes_view = fixed_shapes_sorted
        elif shape_idx > 0:
            # Entry count and longest sample only grow within a batch, so rows
            # rejected now stay rejected for this batch; skip them next time.
            shapes_view = shapes_view[shape_idx:]

        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)

    return batches