File size: 4,716 Bytes
1f5470c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import numpy as np
from keras.src.api_export import keras_export
@keras_export(
[
"keras.utils.pad_sequences",
"keras.preprocessing.sequence.pad_sequences",
]
)
def pad_sequences(
    sequences,
    maxlen=None,
    dtype="int32",
    padding="pre",
    truncating="pre",
    value=0.0,
):
    """Pads sequences to the same length.

    This function transforms a list (of length `num_samples`)
    of sequences (lists of integers)
    into a 2D NumPy array of shape `(num_samples, num_timesteps)`.
    `num_timesteps` is either the `maxlen` argument if provided,
    or the length of the longest sequence in the list.

    Sequences that are shorter than `num_timesteps`
    are padded with `value` until they are `num_timesteps` long.

    Sequences longer than `num_timesteps` are truncated
    so that they fit the desired length.

    The position where padding or truncation happens is determined by
    the arguments `padding` and `truncating`, respectively.
    Pre-padding or removing values from the beginning of the sequence is the
    default.

    Example:

    >>> sequence = [[1], [2, 3], [4, 5, 6]]
    >>> keras.utils.pad_sequences(sequence)
    array([[0, 0, 1],
           [0, 2, 3],
           [4, 5, 6]], dtype=int32)

    >>> keras.utils.pad_sequences(sequence, value=-1)
    array([[-1, -1,  1],
           [-1,  2,  3],
           [ 4,  5,  6]], dtype=int32)

    >>> keras.utils.pad_sequences(sequence, padding='post')
    array([[1, 0, 0],
           [2, 3, 0],
           [4, 5, 6]], dtype=int32)

    >>> keras.utils.pad_sequences(sequence, maxlen=2)
    array([[0, 1],
           [2, 3],
           [5, 6]], dtype=int32)

    Args:
        sequences: List of sequences (each sequence is a list of integers).
        maxlen: Optional Int, maximum length of all sequences. If not
            provided, sequences will be padded to the length of the longest
            individual sequence (`0` when `sequences` is empty).
        dtype: (Optional, defaults to `"int32"`). Type of the output
            sequences. To pad sequences with variable length strings, you
            can use `object`.
        padding: String, "pre" or "post" (optional, defaults to `"pre"`):
            pad either before or after each sequence.
        truncating: String, "pre" or "post" (optional, defaults to
            `"pre"`): remove values from sequences larger than
            `maxlen`, either at the beginning or at the end of the
            sequences.
        value: Float or String, padding value. (Optional, defaults to `0.`)

    Returns:
        NumPy array with shape `(len(sequences), maxlen)`, extended by the
        trailing shape of the individual timesteps when those are
        themselves arrays.

    Raises:
        ValueError: If `sequences` has no length, if one of its entries has
            no length, if `value` is a string while `dtype` is neither a
            string nor an object dtype, if `padding` or `truncating` is not
            `"pre"` or `"post"`, or if a sequence's per-timestep shape
            differs from that of the first non-empty sequence.
    """
    if not hasattr(sequences, "__len__"):
        raise ValueError("`sequences` must be iterable.")
    num_samples = len(sequences)

    lengths = []
    sample_shape = ()
    found_sample = False
    # Take the sample shape from the first non-empty sequence; consistency
    # of the remaining sequences is checked in the main loop below.
    for seq in sequences:
        try:
            seq_len = len(seq)
            lengths.append(seq_len)
            if not found_sample and seq_len:
                sample_shape = np.asarray(seq).shape[1:]
                found_sample = True
        except TypeError as e:
            raise ValueError(
                "`sequences` must be a list of iterables. "
                f"Found non-iterable: {str(seq)}"
            ) from e

    if maxlen is None:
        # `default=0` keeps an empty `sequences` from raising on the
        # reduction; the result is then an array with zero rows.
        maxlen = max(lengths, default=0)

    # Use `issubdtype` for both checks so that every spelling of a dtype
    # (`object`, `"object"`, `"O"`, `np.dtype(object)`, ...) is recognised.
    is_dtype_str = np.issubdtype(dtype, np.str_)
    is_dtype_object = np.issubdtype(dtype, np.object_)
    if isinstance(value, str) and not (is_dtype_str or is_dtype_object):
        raise ValueError(
            f"`dtype` {dtype} is not compatible with `value`'s type: "
            f"{type(value)}\nYou should set `dtype=object` for variable length "
            "strings."
        )

    # Validate the modes once, up front, so that a bad argument is reported
    # even when every sequence is empty and before the output is allocated.
    if truncating not in ("pre", "post"):
        raise ValueError(f'Truncating type "{truncating}" not understood')
    if padding not in ("pre", "post"):
        raise ValueError(f'Padding type "{padding}" not understood')

    padded = np.full(
        (num_samples, maxlen) + sample_shape, value, dtype=dtype
    )
    for idx, seq in enumerate(sequences):
        if not len(seq):
            continue  # empty list/array was found
        if truncating == "pre":
            # A non-negative start index is used instead of `seq[-maxlen:]`
            # because `seq[-0:]` would keep the whole sequence.
            trunc = seq[max(len(seq) - maxlen, 0) :]
        else:
            trunc = seq[:maxlen]
        if not len(trunc):
            continue  # `maxlen == 0`: nothing left to copy

        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError(
                f"Shape of sample {trunc.shape[1:]} of sequence at "
                f"position {idx} is different from expected shape "
                f"{sample_shape}"
            )

        if padding == "post":
            padded[idx, : len(trunc)] = trunc
        else:
            padded[idx, -len(trunc) :] = trunc
    return padded
|