|
|
from keras.src import backend |
|
|
from keras.src.utils.module_utils import tensorflow as tf |
|
|
|
|
|
|
|
|
def get_tensor_spec(t, dynamic_batch=False, name=None): |
|
|
"""Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`.""" |
|
|
if isinstance(t, tf.TypeSpec): |
|
|
spec = t |
|
|
elif isinstance(t, tf.__internal__.CompositeTensor): |
|
|
|
|
|
spec = t._type_spec |
|
|
elif hasattr(t, "shape") and hasattr(t, "dtype"): |
|
|
spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name) |
|
|
else: |
|
|
return None |
|
|
|
|
|
if not dynamic_batch: |
|
|
return spec |
|
|
|
|
|
shape = spec.shape |
|
|
if shape.rank is None or shape.rank == 0: |
|
|
return spec |
|
|
|
|
|
shape_list = shape.as_list() |
|
|
shape_list[0] = None |
|
|
shape = tf.TensorShape(shape_list) |
|
|
spec._shape = shape |
|
|
return spec |
|
|
|
|
|
|
|
|
def ensure_tensor(inputs, dtype=None): |
|
|
"""Ensures the input is a Tensor, SparseTensor or RaggedTensor.""" |
|
|
if not isinstance(inputs, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)): |
|
|
if backend.backend() == "torch" and backend.is_tensor(inputs): |
|
|
|
|
|
inputs = backend.convert_to_numpy(inputs) |
|
|
inputs = tf.convert_to_tensor(inputs, dtype) |
|
|
if dtype is not None and inputs.dtype != dtype: |
|
|
inputs = tf.cast(inputs, dtype) |
|
|
return inputs |
|
|
|
|
|
|
|
|
def is_ragged_tensor(x): |
|
|
return "ragged_tensor.RaggedTensor" in str(type(x)) |
|
|
|
|
|
|
|
|
def sparse_bincount(inputs, depth, binary_output, dtype, count_weights=None): |
|
|
"""Apply binary or count encoding to an input and return a sparse tensor.""" |
|
|
result = tf.sparse.bincount( |
|
|
inputs, |
|
|
weights=count_weights, |
|
|
minlength=depth, |
|
|
maxlength=depth, |
|
|
axis=-1, |
|
|
binary_output=binary_output, |
|
|
) |
|
|
result = tf.cast(result, dtype) |
|
|
if inputs.shape.rank == 1: |
|
|
output_shape = (depth,) |
|
|
else: |
|
|
batch_size = tf.shape(result)[0] |
|
|
output_shape = (batch_size, depth) |
|
|
result = tf.SparseTensor( |
|
|
indices=result.indices, values=result.values, dense_shape=output_shape |
|
|
) |
|
|
return result |
|
|
|
|
|
|
|
|
def dense_bincount(inputs, depth, binary_output, dtype, count_weights=None): |
|
|
"""Apply binary or count encoding to an input.""" |
|
|
result = tf.math.bincount( |
|
|
inputs, |
|
|
weights=count_weights, |
|
|
minlength=depth, |
|
|
maxlength=depth, |
|
|
dtype=dtype, |
|
|
axis=-1, |
|
|
binary_output=binary_output, |
|
|
) |
|
|
if inputs.shape.rank == 1: |
|
|
result.set_shape(tf.TensorShape((depth,))) |
|
|
else: |
|
|
batch_size = inputs.shape.as_list()[0] |
|
|
result.set_shape(tf.TensorShape((batch_size, depth))) |
|
|
return result |
|
|
|
|
|
|
|
|
def expand_dims(inputs, axis): |
|
|
"""Expand dims on sparse, ragged, or dense tensors.""" |
|
|
if isinstance(inputs, tf.SparseTensor): |
|
|
return tf.sparse.expand_dims(inputs, axis) |
|
|
return tf.expand_dims(inputs, axis) |
|
|
|
|
|
|
|
|
def tf_encode_categorical_inputs( |
|
|
inputs, |
|
|
output_mode, |
|
|
depth, |
|
|
dtype="float32", |
|
|
sparse=False, |
|
|
count_weights=None, |
|
|
idf_weights=None, |
|
|
): |
|
|
"""Encodes categorical inputs according to output_mode. |
|
|
|
|
|
Faster method that relies on bincount. |
|
|
""" |
|
|
|
|
|
if output_mode == "int": |
|
|
return tf.identity(tf.cast(inputs, dtype)) |
|
|
|
|
|
original_shape = inputs.shape |
|
|
|
|
|
if inputs.shape.rank == 0: |
|
|
inputs = expand_dims(inputs, -1) |
|
|
|
|
|
if output_mode == "one_hot": |
|
|
if inputs.shape[-1] != 1: |
|
|
inputs = expand_dims(inputs, -1) |
|
|
|
|
|
if inputs.shape.rank > 2: |
|
|
raise ValueError( |
|
|
"When output_mode is not `'int'`, maximum supported output rank " |
|
|
f"is 2. Received output_mode {output_mode} and input shape " |
|
|
f"{original_shape}, " |
|
|
f"which would result in output rank {inputs.shape.rank}." |
|
|
) |
|
|
|
|
|
binary_output = output_mode in ("multi_hot", "one_hot") |
|
|
if sparse: |
|
|
bincounts = sparse_bincount( |
|
|
inputs, depth, binary_output, dtype, count_weights |
|
|
) |
|
|
else: |
|
|
bincounts = dense_bincount( |
|
|
inputs, depth, binary_output, dtype, count_weights |
|
|
) |
|
|
|
|
|
bincounts = tf.cast(bincounts, dtype) |
|
|
if output_mode != "tf_idf": |
|
|
return bincounts |
|
|
|
|
|
if idf_weights is None: |
|
|
raise ValueError( |
|
|
"When output mode is `'tf_idf'`, idf_weights must be provided. " |
|
|
f"Received: output_mode={output_mode} and idf_weights={idf_weights}" |
|
|
) |
|
|
|
|
|
if sparse: |
|
|
value_weights = tf.gather(idf_weights, bincounts.indices[:, -1]) |
|
|
return tf.SparseTensor( |
|
|
bincounts.indices, |
|
|
value_weights * bincounts.values, |
|
|
bincounts.dense_shape, |
|
|
) |
|
|
else: |
|
|
return tf.multiply(bincounts, idf_weights) |
|
|
|