# Hosting-page header (kept as a comment so the file parses):
# maotao / fairseq/data/data_utils_fast.pyx -- julse -- "Upload 551 files" -- be611b4 (verified)
# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
cimport cython
cimport numpy as np
# Python-level NumPy dtype for 64-bit signed integers.
# NOTE(review): not referenced by the functions visible in this file; presumably kept for importers -- confirm.
DTYPE = np.int64
# Compile-time C counterpart of DTYPE; used below to type the index and shape
# buffers / memoryviews.
ctypedef np.int64_t DTYPE_t
cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long max_sentences):
    """Return 1 if a batch has hit one of its size limits, otherwise 0.

    An empty batch (``num_sentences == 0``) is never reported as full.
    A limit that is zero or negative is treated as disabled. The sentence
    limit triggers when it is reached exactly; the token limit triggers only
    when it is exceeded.
    """
    cdef bint sentence_cap_hit
    cdef bint token_cap_hit

    if num_sentences == 0:
        return 0

    sentence_cap_hit = max_sentences > 0 and num_sentences == max_sentences
    token_cap_hit = max_tokens > 0 and num_tokens > max_tokens
    return 1 if (sentence_cap_hit or token_cap_hit) else 0
@cython.cdivision(True)
cpdef list batch_by_size_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    long max_tokens,
    long max_sentences,
    int bsz_mult,
):
    """Split ``indices`` into consecutive batches bounded by size limits.

    Samples are visited in the order given. Before a sample is added, the
    batch is checked against the limits using the padded size it would have
    with that sample included: ``(len(batch) + 1) * longest sample seen``.
    When the batch is full it is closed, trimmed to a multiple of
    ``bsz_mult`` where possible; any samples trimmed off are carried over
    into the next batch.

    Args:
        indices: 1-D int64 array of sample indices.
        num_tokens_fn: callable mapping a sample index to its token count.
        max_tokens: cap on the padded token count of a batch; values <= 0
            disable the cap.
        max_sentences: cap on the number of samples per batch; values <= 0
            disable the cap.
        bsz_mult: closed batches are cut to a multiple of this value when
            they contain at least that many samples.

    Returns:
        A list of batches, each a list of sample indices.

    Raises:
        AssertionError: if a single sample is longer than ``max_tokens``.
        ValueError: if a batch has to be closed while ``bsz_mult`` is 0.
    """
    cdef long sample_len = 0
    cdef list sample_lens = []
    cdef list batch = []
    cdef list batches = []
    cdef Py_ssize_t mod_len
    cdef Py_ssize_t batch_len
    cdef Py_ssize_t i
    # Same width as the index buffer, so large indices are never truncated on
    # platforms where C ``long`` is 32-bit.
    cdef DTYPE_t idx
    cdef long num_tokens
    cdef long padded_tokens
    cdef DTYPE_t[:] indices_view = indices

    for i in range(indices_view.shape[0]):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert max_tokens <= 0 or sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )
        batch_len = len(batch)
        # Padded size of the batch if the current sample were added to it.
        padded_tokens = (batch_len + 1) * sample_len

        if _is_batch_full(batch_len, padded_tokens, max_tokens, max_sentences):
            # cdivision(True) removes Cython's zero-division check, so a zero
            # divisor would be an unchecked C division; fail cleanly instead.
            if bsz_mult == 0:
                raise ValueError("bsz_mult must be non-zero")
            mod_len = max(
                bsz_mult * (batch_len // bsz_mult),
                batch_len % bsz_mult,
            )
            batches.append(batch[:mod_len])
            # Carry the trimmed-off samples (and their lengths) forward.
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)
    return batches
cdef _find_valid_shape(
    DTYPE_t[:, :] shapes_view,
    long num_sentences,
    long num_tokens,
):
    """Return the index of the first shape that fits, or -1 if none does.

    Column 0 of each row is compared against ``num_sentences`` and column 1
    against ``num_tokens``; a row fits when both of its values are at least
    the requested amounts. Rows are scanned in order.
    """
    cdef Py_ssize_t row
    cdef Py_ssize_t num_shapes = shapes_view.shape[0]

    for row in range(num_shapes):
        if shapes_view[row, 0] < num_sentences:
            continue
        if shapes_view[row, 1] < num_tokens:
            continue
        return row
    return -1
@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
    """Split ``indices`` into consecutive batches that fit one of the shapes.

    Samples are visited in the order given. A sample joins the current batch
    as long as some row of ``fixed_shapes_sorted`` can hold the batch with
    that sample included (see ``_find_valid_shape``); otherwise the current
    batch is closed and the sample starts a new one.

    Args:
        indices: 1-D int64 array of sample indices.
        num_tokens_fn: callable mapping a sample index to its token count.
        fixed_shapes_sorted: 2-D int64 array; column 0 is compared against
            the sample count and column 1 against the longest sample.
            NOTE(review): the view-narrowing below assumes the rows are
            ordered so that a row rejected for a smaller batch is also
            rejected for a larger one -- confirm against the caller.

    Returns:
        A list of non-empty batches, each a list of sample indices. A sample
        that fits no shape even on its own is still placed in a batch.
    """
    cdef long sample_len = 0
    cdef list batch = []
    cdef list batches = []
    cdef Py_ssize_t i
    # Same width as the index buffer, so large indices are never truncated on
    # platforms where C ``long`` is 32-bit.
    cdef DTYPE_t idx
    cdef long num_tokens
    cdef long shape_idx
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

    for i in range(indices_view.shape[0]):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_len = max(sample_len, num_tokens)

        shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
        if shape_idx == -1:
            # No shape can hold the batch plus this sample: close the batch
            # (never emitting an empty one) and start over with this sample.
            if len(batch) > 0:
                batches.append(batch)
            batch = []
            # The current sample becomes the first member of the new batch,
            # so its length must seed the running maximum.
            sample_len = num_tokens
            shapes_view = fixed_shapes_sorted
        elif shape_idx > 0:
            # small optimization for the next call to _find_valid_shape
            shapes_view = shapes_view[shape_idx:]

        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)
    return batches