|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
cimport cython |
|
|
cimport numpy as np |
|
|
|
|
|
DTYPE = np.int64 |
|
|
ctypedef np.int64_t DTYPE_t |
|
|
|
|
|
|
|
|
cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long max_sentences):
    """Report whether the batch under construction has hit one of its limits.

    Returns 1 when the batch counts as full and 0 otherwise.

    * A batch with no sentences is never full.
    * A limit that is zero or negative is treated as disabled.
    * The sentence limit triggers when ``num_sentences`` equals
      ``max_sentences``.
    * The token limit triggers only when ``num_tokens`` is strictly greater
      than ``max_tokens``.
    """
    cdef bint sentence_cap_hit
    cdef bint token_budget_exceeded

    # Nothing has been collected yet, so there is nothing to flush.
    if num_sentences == 0:
        return 0

    sentence_cap_hit = max_sentences > 0 and num_sentences == max_sentences
    token_budget_exceeded = max_tokens > 0 and num_tokens > max_tokens
    return 1 if (sentence_cap_hit or token_budget_exceeded) else 0
|
|
|
|
|
|
|
|
@cython.cdivision(True)
cpdef list batch_by_size_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    long max_tokens,
    long max_sentences,
    int bsz_mult,
):
    """Group sample indices, in the order given, into size-limited batches.

    Args:
        indices: 1-D array of sample indices; they are consumed in order.
        num_tokens_fn: callable invoked as ``num_tokens_fn(idx)`` for every
            index; its result is stored in a C ``long``.
        max_tokens: upper bound on ``len(batch) * longest sample in batch``.
            A value <= 0 disables the bound.
        max_sentences: upper bound on the number of samples per batch.
            A value <= 0 disables the bound.
        bsz_mult: when a batch is flushed it is cut down to the largest
            multiple of ``bsz_mult`` that fits; a batch shorter than
            ``bsz_mult`` is flushed whole.  Must be non-zero: the function is
            compiled with ``cdivision(True)``, so the division and modulo
            below are plain C operations with no zero check.

    Returns:
        A list of batches, each a list of Python ints taken from ``indices``.

    Raises:
        AssertionError: if ``max_tokens`` is positive and the longest sample
            tracked for the current batch exceeds it.
    """
    cdef long sample_len = 0
    cdef list sample_lens = []
    cdef list batch = []
    cdef list batches = []
    cdef long mod_len
    cdef long i
    cdef long idx
    cdef long num_tokens
    cdef DTYPE_t[:] indices_view = indices

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        # sample_lens always holds the lengths of everything in `batch` plus
        # the current sample, which has not been appended to `batch` yet.
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert max_tokens <= 0 or sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )
        # From here on num_tokens is the size the batch would have if the
        # current sample were added: every row counted at the longest length.
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(len(batch), num_tokens, max_tokens, max_sentences):
            # Flush the largest multiple of bsz_mult that fits.  When the
            # batch is shorter than bsz_mult the first term is 0 and the
            # remainder equals len(batch), so the whole batch is flushed.
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            # Samples beyond mod_len carry over into the next batch.
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
|
|
|
|
|
|
|
|
cdef _find_valid_shape(
    DTYPE_t[:, :] shapes_view,
    long num_sentences,
    long num_tokens,
):
    """Return index of first valid shape or -1 if none is found.

    Each row of ``shapes_view`` is read as a pair: column 0 is compared
    against ``num_sentences`` and column 1 against ``num_tokens``.  A row is
    valid when both requested values are less than or equal to the row's
    values.  Rows are scanned from index 0 upwards and the scan stops at the
    first valid row.
    """
    # Typed index so the loop and the memoryview accesses stay at C level.
    cdef Py_ssize_t i

    for i in range(shapes_view.shape[0]):
        # Index both axes at once; shapes_view[i][0] would first build a
        # temporary 1-D slice for the row.
        if num_sentences <= shapes_view[i, 0] and num_tokens <= shapes_view[i, 1]:
            return i
    return -1
|
|
|
|
|
|
|
|
@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
    """Group sample indices, in the order given, into batches that fit one of
    a fixed set of shapes.

    Args:
        indices: 1-D array of sample indices; they are consumed in order.
        num_tokens_fn: callable invoked as ``num_tokens_fn(idx)`` for every
            index; its result is stored in a C ``long``.
        fixed_shapes_sorted: 2-D array whose rows are passed to
            ``_find_valid_shape``, which compares column 0 with the sentence
            count and column 1 with the longest sample length.  The
            narrowing step below only stays correct if a row that is skipped
            can never become the first valid row later in the same batch.
            NOTE(review): the name suggests the caller pre-sorts the rows
            accordingly; confirm against the caller.

    Returns:
        A list of non-empty batches, each a list of Python ints taken from
        ``indices``.  A sample that fits no shape even on its own is still
        emitted, in a batch by itself.
    """
    cdef long sample_len = 0
    cdef list batch = []
    cdef list batches = []
    cdef long i
    cdef long idx
    cdef long num_tokens
    cdef long shape_idx
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_len = max(sample_len, num_tokens)

        # Would the batch still fit some shape with this sample added?
        shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
        if shape_idx == -1:
            # No shape fits: close the current batch and start a new one
            # with this sample.  Never emit an empty batch.
            if len(batch) > 0:
                batches.append(batch)
            batch = []
            # The new batch already contains the current sample, so its
            # length must seed the running maximum rather than 0.
            sample_len = num_tokens
            shapes_view = fixed_shapes_sorted
        elif shape_idx > 0:
            # Small optimization for the next call to _find_valid_shape:
            # skip the rows that were already rejected for this batch.
            shapes_view = shapes_view[shape_idx:]

        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)

    return batches
|
|
|