File size: 4,931 Bytes
1f5470c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import copy

from keras.src import backend
from keras.src.utils.module_utils import tensorflow as tf
def get_tensor_spec(t, dynamic_batch=False, name=None):
    """Returns a `tf.TypeSpec` given a single `Tensor` or `TypeSpec`.

    Args:
        t: A `tf.TypeSpec`, a composite tensor (anything that is a
            `tf.__internal__.CompositeTensor`, e.g. an `ExtensionType`), or
            any object exposing both `shape` and `dtype` attributes.
        dynamic_batch: If `True`, the leading dimension of the returned
            spec's shape is replaced with `None`. Specs whose rank is
            unknown or 0 are returned unchanged.
        name: Name for the `tf.TensorSpec` built from objects that only
            expose `shape`/`dtype`. Ignored when `t` is a `TypeSpec` or a
            composite tensor, whose existing spec is used instead.

    Returns:
        A `tf.TypeSpec` describing `t`, or `None` when `t` is none of the
        supported kinds (so non-tensor values can pass through).
    """
    if isinstance(t, tf.TypeSpec):
        spec = t
    elif isinstance(t, tf.__internal__.CompositeTensor):
        # Check for ExtensionTypes: composite tensors carry their own spec.
        spec = t._type_spec
    elif hasattr(t, "shape") and hasattr(t, "dtype"):
        spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
    else:
        return None  # Allow non-Tensors to pass through.

    if not dynamic_batch:
        return spec

    shape = spec.shape
    if shape.rank is None or shape.rank == 0:
        return spec

    shape_list = shape.as_list()
    shape_list[0] = None
    # `spec` may be the caller's own `TypeSpec` object (or the spec object
    # of a composite tensor). Overwriting its private `_shape` in place
    # would silently change the batch dimension for every other holder of
    # that object, so relax the batch dimension on a copy instead.
    spec = copy.copy(spec)
    spec._shape = tf.TensorShape(shape_list)
    return spec
def ensure_tensor(inputs, dtype=None):
    """Ensures the input is a Tensor, SparseTensor or RaggedTensor.

    Values that are already a `tf.Tensor`, `tf.SparseTensor` or
    `tf.RaggedTensor` are kept as they are. Anything else goes through
    `tf.convert_to_tensor`, with `dtype` forwarded.

    Args:
        inputs: Value to convert.
        dtype: Optional target dtype. When given and different from the
            resulting tensor's dtype, the result is cast with `tf.cast`.

    Returns:
        A `tf.Tensor`, `tf.SparseTensor` or `tf.RaggedTensor`.
    """
    tensor = inputs
    if not isinstance(tensor, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)):
        if backend.backend() == "torch" and backend.is_tensor(tensor):
            # Plain `np.asarray()` conversion fails with PyTorch.
            tensor = backend.convert_to_numpy(tensor)
        tensor = tf.convert_to_tensor(tensor, dtype)
    needs_cast = dtype is not None and tensor.dtype != dtype
    return tf.cast(tensor, dtype) if needs_cast else tensor
def is_ragged_tensor(x):
    """Return whether the type of `x` looks like a `RaggedTensor`.

    The test is a substring match on `str(type(x))`, not an `isinstance`
    check: any class whose printed qualified name contains
    `ragged_tensor.RaggedTensor` matches.
    """
    type_repr = str(type(x))
    return type_repr.find("ragged_tensor.RaggedTensor") != -1
def sparse_bincount(inputs, depth, binary_output, dtype, count_weights=None):
    """Apply binary or count encoding to an input and return a sparse tensor.

    Args:
        inputs: Tensor whose last axis is bin-counted (presumably integer
            category ids — confirm with callers).
        depth: Number of bins; passed as both `minlength` and `maxlength`.
        binary_output: Forwarded to `tf.sparse.bincount`; when `True`, bins
            record presence rather than counts.
        dtype: Dtype the output values are cast to.
        count_weights: Optional weights forwarded to `tf.sparse.bincount`.

    Returns:
        A `tf.SparseTensor` whose dense shape is `(depth,)` for rank-1
        inputs and `(batch_size, depth)` otherwise. `batch_size` is read
        dynamically from the bincount result.
    """
    counts = tf.cast(
        tf.sparse.bincount(
            inputs,
            weights=count_weights,
            minlength=depth,
            maxlength=depth,
            axis=-1,
            binary_output=binary_output,
        ),
        dtype,
    )
    # Rebuild the sparse tensor with an explicit dense shape whose last
    # entry is the Python `depth` (presumably to keep it statically known).
    if inputs.shape.rank != 1:
        dense_shape = (tf.shape(counts)[0], depth)
    else:
        dense_shape = (depth,)
    return tf.SparseTensor(
        indices=counts.indices, values=counts.values, dense_shape=dense_shape
    )
def dense_bincount(inputs, depth, binary_output, dtype, count_weights=None):
    """Apply binary or count encoding to an input, returning a dense tensor.

    Args:
        inputs: Tensor whose last axis is bin-counted (presumably integer
            category ids — confirm with callers).
        depth: Number of bins; passed as both `minlength` and `maxlength`.
        binary_output: Forwarded to `tf.math.bincount`; when `True`, bins
            record presence rather than counts.
        dtype: Forwarded to `tf.math.bincount` as the output dtype.
        count_weights: Optional weights forwarded to `tf.math.bincount`.

    Returns:
        A dense tensor whose last axis has `depth` bins. Its static shape is
        the input's static shape with the last dimension replaced by
        `depth`; for inputs of unknown rank no static shape is set.
    """
    result = tf.math.bincount(
        inputs,
        weights=count_weights,
        minlength=depth,
        maxlength=depth,
        dtype=dtype,
        axis=-1,
        binary_output=binary_output,
    )
    # The last axis becomes `depth` bins and the leading axes keep their
    # static sizes: `(depth,)` for rank 1, `(batch, depth)` for rank 2.
    # Unlike `as_list()`, slicing a TensorShape of unknown rank yields an
    # unknown shape, so `set_shape` becomes a no-op instead of raising.
    result.set_shape(inputs.shape[:-1].concatenate([depth]))
    return result
def expand_dims(inputs, axis):
    """Expand dims on sparse, ragged, or dense tensors.

    A `tf.SparseTensor` is routed to `tf.sparse.expand_dims`. Every other
    input is passed to `tf.expand_dims`.

    Args:
        inputs: Tensor to expand.
        axis: Position of the new length-1 axis.

    Returns:
        The expanded tensor, of the same kind as `inputs`.
    """
    expand_fn = (
        tf.sparse.expand_dims
        if isinstance(inputs, tf.SparseTensor)
        else tf.expand_dims
    )
    return expand_fn(inputs, axis)
def tf_encode_categorical_inputs(
    inputs,
    output_mode,
    depth,
    dtype="float32",
    sparse=False,
    count_weights=None,
    idf_weights=None,
):
    """Encodes categorical inputs according to output_mode.

    Faster method that relies on bincount.

    Args:
        inputs: Tensor of category indices (a `tf.SparseTensor` is also
            handled when up-ranking). Unless `output_mode="int"`, its rank
            must be known and at most 2 after up-ranking.
        output_mode: `"int"` returns `inputs` cast to `dtype`.
            `"one_hot"` and `"multi_hot"` produce binary bin indicators.
            `"tf_idf"` produces counts multiplied by `idf_weights`. Any
            other value produces plain counts.
        depth: Number of bins in the encoded last axis.
        dtype: Output dtype.
        sparse: If `True`, return a `tf.SparseTensor`, otherwise a dense
            tensor.
        count_weights: Optional per-value weights forwarded to bincount.
        idf_weights: Per-bin weights; required when `output_mode="tf_idf"`.

    Returns:
        For `"int"`, `inputs` cast to `dtype`. Otherwise a dense or sparse
        tensor whose last axis has `depth` bins.

    Raises:
        ValueError: If the rank of `inputs` is unknown, if the rank after
            up-ranking exceeds 2, or if `output_mode="tf_idf"` and
            `idf_weights` is None.
    """
    if output_mode == "int":
        return tf.identity(tf.cast(inputs, dtype))

    original_shape = inputs.shape
    # In all cases, we should uprank scalar input to a single sample.
    if inputs.shape.rank == 0:
        inputs = expand_dims(inputs, -1)
    # One hot will uprank only if the final output dimension is not already 1.
    if output_mode == "one_hot":
        if inputs.shape[-1] != 1:
            inputs = expand_dims(inputs, -1)

    rank = inputs.shape.rank
    # An unknown rank would otherwise hit `None > 2` below and surface as an
    # opaque `TypeError`; report it explicitly instead.
    if rank is None:
        raise ValueError(
            "When output_mode is not `'int'`, the rank of the inputs must be "
            f"known. Received output_mode {output_mode} and input shape "
            f"{original_shape}."
        )
    if rank > 2:
        raise ValueError(
            "When output_mode is not `'int'`, maximum supported output rank "
            f"is 2. Received output_mode {output_mode} and input shape "
            f"{original_shape}, "
            f"which would result in output rank {inputs.shape.rank}."
        )

    binary_output = output_mode in ("multi_hot", "one_hot")
    if sparse:
        bincounts = sparse_bincount(
            inputs, depth, binary_output, dtype, count_weights
        )
    else:
        bincounts = dense_bincount(
            inputs, depth, binary_output, dtype, count_weights
        )

    # Normalize the output dtype once more. `tf.math.bincount` documents its
    # `dtype` argument as applying when no weights are given.
    bincounts = tf.cast(bincounts, dtype)

    if output_mode != "tf_idf":
        return bincounts

    if idf_weights is None:
        raise ValueError(
            "When output mode is `'tf_idf'`, idf_weights must be provided. "
            f"Received: output_mode={output_mode} and idf_weights={idf_weights}"
        )

    if sparse:
        # The last index column of each non-zero entry is its bin id, so
        # gather the matching idf weight per stored value.
        value_weights = tf.gather(idf_weights, bincounts.indices[:, -1])
        return tf.SparseTensor(
            bincounts.indices,
            value_weights * bincounts.values,
            bincounts.dense_shape,
        )
    else:
        return tf.multiply(bincounts, idf_weights)
|