AIDUDE0541 committed on
Commit
d219b44
·
verified ·
1 Parent(s): 26abf77

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__init__.py +16 -0
  2. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__pycache__/operation_utils.cpython-310.pyc +0 -0
  3. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__pycache__/symbolic_arguments.cpython-310.pyc +0 -0
  4. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/core.py +1167 -0
  5. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/function.py +423 -0
  6. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/image.py +1235 -0
  7. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/linalg.py +707 -0
  8. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/math.py +1046 -0
  9. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/nn.py +2653 -0
  10. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/node.py +143 -0
  11. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/numpy.py +0 -0
  12. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/operation.py +316 -0
  13. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/operation_utils.py +421 -0
  14. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/symbolic_arguments.py +46 -0
  15. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__init__.py +121 -0
  16. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/__init__.cpython-310.pyc +0 -0
  17. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adadelta.cpython-310.pyc +0 -0
  18. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adafactor.cpython-310.pyc +0 -0
  19. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adagrad.cpython-310.pyc +0 -0
  20. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adam.cpython-310.pyc +0 -0
  21. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adamax.cpython-310.pyc +0 -0
  22. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adamw.cpython-310.pyc +0 -0
  23. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/base_optimizer.cpython-310.pyc +0 -0
  24. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/ftrl.cpython-310.pyc +0 -0
  25. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/lamb.cpython-310.pyc +0 -0
  26. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/lion.cpython-310.pyc +0 -0
  27. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/loss_scale_optimizer.cpython-310.pyc +0 -0
  28. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/nadam.cpython-310.pyc +0 -0
  29. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/optimizer.cpython-310.pyc +0 -0
  30. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/rmsprop.cpython-310.pyc +0 -0
  31. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/sgd.cpython-310.pyc +0 -0
  32. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adadelta.py +139 -0
  33. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adafactor.py +208 -0
  34. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adagrad.py +115 -0
  35. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adam.py +167 -0
  36. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adamax.py +156 -0
  37. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adamw.py +100 -0
  38. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/base_optimizer.py +1102 -0
  39. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/ftrl.py +249 -0
  40. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/lamb.py +158 -0
  41. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/lion.py +142 -0
  42. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/loss_scale_optimizer.py +298 -0
  43. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/nadam.py +174 -0
  44. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/optimizer.py +27 -0
  45. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/rmsprop.py +180 -0
  46. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__init__.py +16 -0
  47. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__pycache__/__init__.cpython-310.pyc +0 -0
  48. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__pycache__/learning_rate_schedule.cpython-310.pyc +0 -0
  49. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/learning_rate_schedule.py +969 -0
  50. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/sgd.py +143 -0
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from keras.src.ops.numpy import Matmul, matmul
2
+ # from keras.src.ops.numpy import Add, add
3
+ # from keras.src.ops.numpy import Multiply, multiply
4
+
5
+ from keras.src.backend import cast
6
+ from keras.src.backend import cond
7
+ from keras.src.backend import is_tensor
8
+ from keras.src.backend import name_scope
9
+ from keras.src.backend import random
10
+ from keras.src.ops import image
11
+ from keras.src.ops import operation_utils
12
+ from keras.src.ops.core import * # noqa: F403
13
+ from keras.src.ops.linalg import * # noqa: F403
14
+ from keras.src.ops.math import * # noqa: F403
15
+ from keras.src.ops.nn import * # noqa: F403
16
+ from keras.src.ops.numpy import * # noqa: F403
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__pycache__/operation_utils.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__pycache__/symbolic_arguments.cpython-310.pyc ADDED
Binary file (1.86 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/core.py ADDED
@@ -0,0 +1,1167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ml_dtypes
2
+ import numpy as np
3
+
4
+ from keras.src import backend
5
+ from keras.src import tree
6
+ from keras.src.api_export import keras_export
7
+ from keras.src.backend import KerasTensor
8
+ from keras.src.backend import any_symbolic_tensors
9
+ from keras.src.backend.common.backend_utils import slice_along_axis
10
+ from keras.src.ops.operation import Operation
11
+ from keras.src.utils import traceback_utils
12
+
13
+
14
class Map(Operation):
    """Symbolic `Operation` backing `keras.ops.map`."""

    def __init__(self):
        super().__init__()

    def call(self, f, xs):
        return backend.core.map(f, xs)

    def compute_output_spec(self, f, xs):
        # Trace `f` on a single slice of `xs`, then prepend the mapped
        # (leading) axis to every tensor in the traced output structure.
        first = xs[0]
        batch = xs.shape[0]
        traced = backend.compute_output_spec(f, first)

        def _with_batch_dim(t):
            return KerasTensor(
                shape=(batch,) + t.shape, dtype=t.dtype, sparse=t.sparse
            )

        return tree.map_structure(_with_batch_dim, traced)


@keras_export("keras.ops.map")
def map(f, xs):
    """Map a function over leading array axes.

    Like Python's builtin `map`, except inputs and outputs are stacked
    arrays. Consider the `vectorized_map()` transform instead, unless you
    need to apply the function element by element for reduced memory usage,
    or for heterogeneous computation with other control flow primitives.

    When `xs` is an array type, the semantics of `map()` match this Python
    implementation:

    ```python
    def map(f, xs):
        return np.stack([f(x) for x in xs])
    ```

    Args:
        f: Callable applied element-wise over the first axis or axes of
            `xs`.
        xs: Values over which to map along the leading axis.

    Returns:
        Mapped values.

    Examples:

    >>> f = lambda x: x**2
    >>> xs = keras.ops.arange(10)
    >>> ys = keras.ops.map(f, xs)
    >>> ys
    [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

    >>> f = lambda x: {"y1": x**2, "y2": x * 10}  # Can have nested outputs
    >>> ys = keras.ops.map(f, xs)
    >>> ys["y1"]
    [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
    >>> ys["y2"]
    [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
    """
    # Defer to the symbolic op when tracing; otherwise run eagerly.
    if any_symbolic_tensors((xs,)):
        return Map().symbolic_call(f, xs)
    return backend.core.map(f, xs)
78
+
79
+
80
class Scan(Operation):
    """Symbolic `Operation` backing `keras.ops.scan`."""

    def __init__(self, reverse=False, unroll=1):
        super().__init__()
        self.reverse = reverse
        self.unroll = unroll

    def call(self, f, init, xs, length):
        return backend.core.scan(
            f, init, xs, length, reverse=self.reverse, unroll=self.unroll
        )

    def compute_output_spec(self, f, init, xs, length):
        # Determine the number of iterations `n` and a representative
        # slice `x` of `xs` (or `None` when scanning without inputs).
        if xs is None:
            n = int(length)
            x = None
        else:
            n = (
                int(length)
                if length is not None
                else tree.flatten(xs)[0].shape[0]
            )
            x = xs[0]

        # Trace one iteration of `f`; the stacked output gains a leading
        # axis of size `n`.
        carry, y = backend.compute_output_spec(f, init, x)
        y = KerasTensor(shape=(n,) + y.shape, dtype=y.dtype, sparse=y.sparse)
        return carry, y


@keras_export("keras.ops.scan")
def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
    """Scan a function over leading array axes while carrying along state.

    When the type of `xs` is an array type or `None`, and the type of `ys` is
    an array type, the semantics of `scan()` are given roughly by this Python
    implementation:

    ```python
    def scan(f, init, xs, length=None):
        if xs is None:
            xs = [None] * length
        carry = init
        ys = []
        for x in xs:
            carry, y = f(carry, x)
            ys.append(y)
        return carry, np.stack(ys)
    ```

    The loop-carried value `carry` (`init`) must hold a fixed shape and dtype
    across all iterations.

    In TensorFlow, `y` must match `carry` in shape and dtype. This is not
    required in other backends.

    Args:
        f: Callable defines the logic for each loop iteration. This accepts
            two arguments where the first is a value of the loop carry and
            the second is a slice of `xs` along its leading axis.
            This callable returns a pair where the first represents a new
            value for the loop carry and the second represents a slice of
            the output.
        init: The initial loop carry value. This can be a scalar, tensor, or
            any nested structure. It must match the structure of the first
            element returned by `f`.
        xs: Optional value to scan along its leading axis. This can be a
            tensor or any nested structure. If `xs` is not provided, you
            must specify `length` to define the number of loop iterations.
            Defaults to `None`.
        length: Optional integer specifying the number of loop iterations.
            If `length` is not provided, it defaults to the sizes of leading
            axis of the arrays in `xs`. Defaults to `None`.
        reverse: Optional boolean specifying whether to run the scan
            iteration forward or in reverse, equivalent to reversing the
            leading axes of the arrays in both `xs` and in `ys`.
        unroll: Optional positive integer or boolean specifying how many
            scan iterations to unroll within a single iteration of a loop.
            If an integer is provided, it determines how many unrolled loop
            iterations to run within a single rolled iteration of the loop.
            If a boolean is provided, it will determine if the loop is
            completely unrolled (`unroll=True`) or not unrolled at all
            (`unroll=False`). Note that unrolling is only supported by JAX
            and TensorFlow backends.

    Returns:
        A pair where the first element represents the final loop carry value
        and the second element represents the stacked outputs of `f` when
        scanned over the leading axis of the inputs.

    Examples:

    >>> sum_fn = lambda c, x: (c + x, c + x)
    >>> init = keras.ops.array(0)
    >>> xs = keras.ops.array([1, 2, 3, 4, 5])
    >>> carry, result = keras.ops.scan(sum_fn, init, xs)
    >>> carry
    15
    >>> result
    [1, 3, 6, 10, 15]
    """
    if any_symbolic_tensors((init, xs)):
        return Scan(reverse=reverse, unroll=unroll).symbolic_call(
            f, init, xs, length
        )
    return backend.core.scan(
        f, init, xs, length, reverse=reverse, unroll=unroll
    )
185
+
186
+
187
class AssociativeScan(Operation):
    """Symbolic `Operation` backing `keras.ops.associative_scan`."""

    def __init__(self, reverse=False):
        super().__init__()
        self.reverse = reverse

    def call(self, f, elems, axis=0):
        return backend.core.associative_scan(
            f, elems, reverse=self.reverse, axis=axis
        )

    def compute_output_spec(self, f, elems, axis):
        # All leaves of `elems` must agree along the scanned axis.
        elems_flat = tree.flatten(elems)
        lens = [elem.shape[axis] for elem in elems_flat]
        if len(set(lens)) != 1:
            raise ValueError(
                "Array inputs to associative_scan must have the same "
                "first dimension. (saw: {})".format(
                    [elem.shape for elem in elems_flat]
                )
            )

        # Trace `f` on length-1 slices; the real output keeps the full
        # input shape, so restore it afterwards.
        x = tree.pack_sequence_as(
            elems, [slice_along_axis(x, 0, 1, axis=axis) for x in elems_flat]
        )
        y_spec = backend.compute_output_spec(f, x, x)

        def _restore_shape(x):
            return KerasTensor(
                shape=elems_flat[0].shape, dtype=x.dtype, sparse=x.sparse
            )

        y_spec = tree.map_structure(_restore_shape, y_spec)
        return y_spec


@keras_export("keras.ops.associative_scan")
def associative_scan(f, elems, reverse=False, axis=0):
    """Performs a scan with an associative binary operation, in parallel.

    This operation is similar to `scan`, with the key difference that
    `associative_scan` is a parallel implementation with
    potentially significant performance benefits, especially when jit
    compiled. The catch is that it can only be used when `f` is a binary
    associative operation (i.e. it must verify
    `f(a, f(b, c)) == f(f(a, b), c)`).

    For an introduction to associative scans, refer to this paper:
    Blelloch, Guy E. 1990.
    [Prefix Sums and Their Applications](
        https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf).

    Args:
        f: A Python callable implementing an associative binary operation
            with signature `r = f(a, b)`. Function `f` must be associative,
            i.e., it must satisfy the equation
            `f(a, f(b, c)) == f(f(a, b), c)`.
            The inputs and result are (possibly nested Python tree
            structures of) array(s) matching `elems`. Each array has a
            dimension in place of the `axis` dimension. `f` should be
            applied elementwise over the `axis` dimension.
            The result `r` has the same shape (and structure) as the
            two inputs `a` and `b`.
        elems: A (possibly nested Python tree structure of) array(s), each
            with an `axis` dimension of size `num_elems`.
        reverse: A boolean stating if the scan should be reversed with
            respect to the `axis` dimension.
        axis: an integer identifying the axis over which the scan should
            occur.

    Returns:
        A (possibly nested Python tree structure of) array(s) of the same
        shape and structure as `elems`, in which the `k`'th element of
        `axis` is the result of recursively applying `f` to combine the
        first `k` elements of `elems` along `axis`. For example, given
        `elems = [a, b, c, ...]`, the result would be
        `[a, f(a, b), f(f(a, b), c), ...]`.

    Examples:

    >>> sum_fn = lambda x, y: x + y
    >>> xs = keras.ops.arange(5)
    >>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
    >>> ys
    [0, 1, 3, 6, 10]

    >>> sum_fn = lambda x, y: [x[0] + y[0], x[1] + y[1], x[2] + y[2]]
    >>> xs = [keras.ops.array([[1, 2]]) for _ in range(3)]
    >>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
    >>> ys
    [[1, 3], [1, 3], [1, 3]]
    """
    if any_symbolic_tensors((elems,)):
        return AssociativeScan(reverse=reverse).symbolic_call(f, elems, axis)
    return backend.core.associative_scan(f, elems, reverse=reverse, axis=axis)
279
+
280
+
281
class Scatter(Operation):
    """Symbolic `Operation` backing `keras.ops.scatter`."""

    def call(self, indices, values, shape):
        return backend.core.scatter(indices, values, shape)

    def compute_output_spec(self, indices, values, shape):
        # Output shape is given explicitly; dtype follows the values.
        return KerasTensor(shape, dtype=values.dtype)


@keras_export("keras.ops.scatter")
def scatter(indices, values, shape):
    """Returns a tensor of shape `shape` where `indices` are set to `values`.

    At a high level, this operation does `zeros[indices] = updates` and
    returns the output. It is equivalent to:

    ```python
    zeros = keras.ops.zeros(shape)
    output = keras.ops.scatter_update(zeros, indices, values)
    ```

    Args:
        indices: A tensor or list/tuple specifying
            indices for the values in `values`.
        values: A tensor, the values to be set at `indices`.
        shape: Shape of the output tensor.

    Example:

    >>> indices = [[0, 1], [1, 1]]
    >>> values = np.array([1., 1.])
    >>> keras.ops.scatter(indices, values, shape=(2, 2))
    array([[0., 1.],
           [0., 1.]])
    """
    is_symbolic = any_symbolic_tensors((indices, values, shape))
    if is_symbolic:
        return Scatter().symbolic_call(indices, values, shape)
    return backend.core.scatter(indices, values, shape)
318
+
319
+
320
class ScatterUpdate(Operation):
    """Symbolic `Operation` backing `keras.ops.scatter_update`."""

    def call(self, inputs, indices, updates):
        return backend.core.scatter_update(inputs, indices, updates)

    def compute_output_spec(self, inputs, indices, updates):
        # Scatter-update preserves the shape and dtype of `inputs`.
        return KerasTensor(inputs.shape, dtype=inputs.dtype)


@keras_export("keras.ops.scatter_update")
def scatter_update(inputs, indices, updates):
    """Update inputs via updates at scattered (sparse) indices.

    At a high level, this operation does `inputs[indices] = updates`.
    Assume `inputs` is a tensor of shape `(D0, D1, ..., Dn)`, there are 2
    main usages of `scatter_update`.

    1. `indices` is a 2D tensor of shape `(num_updates, n)`, where
        `num_updates` is the number of updates to perform, and `updates` is
        a 1D tensor of shape `(num_updates,)`. For example, if `inputs` is
        `zeros((4, 4, 4))`, and we want to update `inputs[1, 2, 3]` and
        `inputs[0, 1, 3]` as 1, then we can use:

    ```python
    inputs = np.zeros((4, 4, 4))
    indices = [[1, 2, 3], [0, 1, 3]]
    updates = np.array([1., 1.])
    inputs = keras.ops.scatter_update(inputs, indices, updates)
    ```

    2. `indices` is a 2D tensor of shape `(num_updates, k)`, where
        `num_updates` is the number of updates to perform, and `k`
        (`k < n`) is the size of each index in `indices`. `updates` is a
        `n - k`-D tensor of shape `(num_updates, inputs.shape[k:])`. For
        example, if `inputs = np.zeros((4, 4, 4))`, and we want to update
        `inputs[1, 2, :]` and `inputs[2, 3, :]` as `[1, 1, 1, 1]`, then
        `indices` would have shape `(num_updates, 2)` (`k = 2`), and
        `updates` would have shape `(num_updates, 4)`
        (`inputs.shape[2:] = 4`). See the code below:

    ```python
    inputs = np.zeros((4, 4, 4))
    indices = [[1, 2], [2, 3]]
    updates = np.array([[1., 1., 1., 1.], [1., 1., 1., 1.]])
    inputs = keras.ops.scatter_update(inputs, indices, updates)
    ```

    Args:
        inputs: A tensor, the tensor to be updated.
        indices: A tensor or list/tuple of shape `(N, inputs.ndim)`,
            specifying indices to update. `N` is the number of indices to
            update, must be equal to the first dimension of `updates`.
        updates: A tensor, the new values to be put to `inputs` at
            `indices`.

    Returns:
        A tensor, has the same shape and dtype as `inputs`.
    """
    if any_symbolic_tensors((inputs, indices, updates)):
        return ScatterUpdate().symbolic_call(inputs, indices, updates)
    return backend.core.scatter_update(inputs, indices, updates)
378
+
379
+
380
class Slice(Operation):
    """Symbolic `Operation` backing `keras.ops.slice`."""

    def call(self, inputs, start_indices, shape):
        return backend.core.slice(inputs, start_indices, shape)

    def compute_output_spec(self, inputs, start_indices, shape):
        # The requested `shape` fully determines the output shape.
        return KerasTensor(shape, dtype=inputs.dtype)


@keras_export("keras.ops.slice")
def slice(inputs, start_indices, shape):
    """Return a slice of an input tensor.

    At a high level, this operation is an explicit replacement for array
    slicing e.g. `inputs[start_indices: start_indices + shape]`.
    Unlike slicing via brackets, this operation will accept tensor start
    indices on all backends, which is useful when indices are dynamically
    computed via other tensor operations.

    ```python
    inputs = np.zeros((5, 5))
    start_indices = np.array([3, 3])
    shape = np.array([2, 2])
    inputs = keras.ops.slice(inputs, start_indices, shape)
    ```

    Args:
        inputs: A tensor, the tensor to be sliced.
        start_indices: A list/tuple of shape `(inputs.ndim,)`, specifying
            the starting indices of the slice.
        shape: The full shape of the returned slice.

    Returns:
        A tensor of shape `shape` with the same dtype as `inputs`.
    """
    is_symbolic = any_symbolic_tensors((inputs, start_indices, shape))
    if is_symbolic:
        return Slice().symbolic_call(inputs, start_indices, shape)
    return backend.core.slice(inputs, start_indices, shape)
417
+
418
+
419
class SliceUpdate(Operation):
    """Symbolic `Operation` backing `keras.ops.slice_update`."""

    def call(self, inputs, start_indices, updates):
        return backend.core.slice_update(inputs, start_indices, updates)

    def compute_output_spec(self, inputs, start_indices, updates):
        # Slice-update preserves the shape and dtype of `inputs`.
        return KerasTensor(inputs.shape, dtype=inputs.dtype)


@keras_export("keras.ops.slice_update")
def slice_update(inputs, start_indices, updates):
    """Update an input by slicing in a tensor of updated values.

    At a high level, this operation does
    `inputs[start_indices: start_indices + updates.shape] = updates`.
    Assume inputs is a tensor of shape `(D0, D1, ..., Dn)`,
    `start_indices` must be a list/tuple of n integers, specifying the
    starting indices. `updates` must have the same rank as `inputs`, and
    the size of each dim must not exceed `Di - start_indices[i]`. For
    example, if we have 2D inputs `inputs = np.zeros((5, 5))`, and we want
    to update the intersection of last 2 rows and last 2 columns as 1,
    i.e., `inputs[3:, 3:] = np.ones((2, 2))`, then we can use the code
    below:

    ```python
    inputs = np.zeros((5, 5))
    start_indices = [3, 3]
    updates = np.ones((2, 2))
    inputs = keras.ops.slice_update(inputs, start_indices, updates)
    ```

    Args:
        inputs: A tensor, the tensor to be updated.
        start_indices: A list/tuple of shape `(inputs.ndim,)`, specifying
            the starting indices for updating.
        updates: A tensor, the new values to be put to `inputs` at
            `indices`. `updates` must have the same rank as `inputs`.

    Returns:
        A tensor, has the same shape and dtype as `inputs`.
    """
    is_symbolic = any_symbolic_tensors((inputs, start_indices, updates))
    if is_symbolic:
        return SliceUpdate().symbolic_call(inputs, start_indices, updates)
    return backend.core.slice_update(inputs, start_indices, updates)
461
+
462
+
463
class Switch(Operation):
    """Symbolic `Operation` backing `keras.ops.switch`."""

    def call(self, index, branches, *operands):
        return backend.core.switch(index, branches, *operands)

    def compute_output_spec(self, index, branches, *operands):
        # All branches must share an output structure; trace the first one
        # as the representative spec.
        return backend.compute_output_spec(branches[0], *operands)


@keras_export("keras.ops.switch")
def switch(index, branches, *operands):
    """Apply exactly one of the `branches` given by `index`.

    If `index` is out of bounds, it is clamped to within bounds.

    The semantics of `switch` are given roughly by this Python
    implementation:

    ```python
    def switch(index, branches, *operands):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](*operands)
    ```

    Args:
        index: An integer scalar indicating which branch function to apply.
        branches: A sequence of functions to be applied based on `index`.
        operands: Inputs to whichever branch is applied.

    Returns:
        The outputs of `branch(*operands)` for the branch that was selected
        based on `index`.

    Examples:

    >>> add_fn = lambda x, y: x + y
    >>> subtract_fn = lambda x, y: x - y
    >>> x = keras.ops.array(2.0)
    >>> y = keras.ops.array(0.5)
    >>> branches = [add_fn, subtract_fn]
    >>> keras.ops.switch(0, branches, x, y)
    2.5

    >>> keras.ops.switch(1, branches, x, y)
    1.5
    """
    if any_symbolic_tensors(operands):
        return Switch().symbolic_call(index, branches, *operands)
    return backend.core.switch(index, branches, *operands)
512
+
513
+
514
class WhileLoop(Operation):
    """Symbolic `Operation` backing `keras.ops.while_loop`."""

    def __init__(self, cond, body, maximum_iterations):
        super().__init__()
        # `cond` and `body` are captured at construction; only the loop
        # variables flow through `call`.
        self.cond = cond
        self.body = body
        self.maximum_iterations = maximum_iterations

    def call(self, loop_vars):
        return backend.core.while_loop(
            self.cond,
            self.body,
            loop_vars,
            maximum_iterations=self.maximum_iterations,
        )

    def compute_output_spec(self, loop_vars):
        # The loop returns variables with unchanged shapes and dtypes.
        return [KerasTensor(v.shape, dtype=v.dtype) for v in loop_vars]


@keras_export("keras.ops.while_loop")
def while_loop(
    cond,
    body,
    loop_vars,
    maximum_iterations=None,
):
    """While loop implementation.

    Args:
        cond: A callable that represents the termination condition of the
            loop. Must accept a `loop_vars` like structure as an argument.
            If `loop_vars` is a tuple or list, each element of `loop_vars`
            will be passed positionally to the callable.
        body: A callable that represents the loop body. Must accept a
            `loop_vars` like structure as an argument, and return update
            value with the same structure. If `loop_vars` is a tuple or
            list, each element of `loop_vars` will be passed positionally
            to the callable.
        loop_vars: An arbitrary nested structure of tensor state to persist
            across loop iterations.
        maximum_iterations: Optional maximum number of iterations of the
            while loop to run. If provided, the `cond` output is AND-ed
            with an additional condition ensuring the number of iterations
            executed is no greater than `maximum_iterations`.

    Returns:
        A list/tuple of tensors, has the same shape and dtype as `inputs`.

    Examples:

    >>> i = 0
    >>> cond = lambda i: i < 10
    >>> body = lambda i: i + 1
    >>> keras.ops.while_loop(cond, body, i)
    10

    >>> x, y = 0, 1
    >>> cond = lambda x, y: x < 10
    >>> body = lambda x, y: (x + 1, y + 1)
    >>> keras.ops.while_loop(cond, body, (x, y))
    10, 11
    """
    # Always dispatch to the backend implementation directly.
    return backend.core.while_loop(
        cond,
        body,
        loop_vars,
        maximum_iterations=maximum_iterations,
    )
581
+
582
+
583
class StopGradient(Operation):
    """Symbolic `Operation` backing `keras.ops.stop_gradient`."""

    def __init__(self):
        super().__init__()

    def call(self, variable):
        return backend.core.stop_gradient(variable)

    def compute_output_spec(self, variable):
        # Stopping gradients does not change shape or dtype.
        return KerasTensor(variable.shape, dtype=variable.dtype)


@keras_export("keras.ops.stop_gradient")
def stop_gradient(variable):
    """Stops gradient computation.

    Args:
        variable: A tensor variable for which the gradient
            computation is to be disabled.

    Returns:
        The variable with gradient computation disabled.

    Examples:

    >>> var = keras.backend.convert_to_tensor(
    ...     [1., 2., 3.],
    ...     dtype="float32"
    ... )
    >>> var = keras.ops.stop_gradient(var)
    """
    is_symbolic = any_symbolic_tensors((variable,))
    if is_symbolic:
        return StopGradient().symbolic_call(variable)
    return backend.core.stop_gradient(variable)
616
+
617
+
618
class ForiLoop(Operation):
    """Symbolic wrapper around `backend.core.fori_loop`."""

    def __init__(self, lower, upper, body_fun):
        super().__init__()
        # Loop bounds and body are captured at construction time; only the
        # loop state flows through `call`.
        self.lower = lower
        self.upper = upper
        self.body_fun = body_fun

    def call(self, init_val):
        return backend.core.fori_loop(
            self.lower,
            self.upper,
            self.body_fun,
            init_val,
        )

    def compute_output_spec(self, init_val):
        # The loop state keeps its shape/dtype across iterations.
        return KerasTensor(init_val.shape, dtype=init_val.dtype)
635
+
636
+
637
@keras_export("keras.ops.fori_loop")
def fori_loop(lower, upper, body_fun, init_val):
    """For loop implementation.

    Args:
        lower: The initial value of the loop variable.
        upper: The upper bound of the loop variable.
        body_fun: A callable representing the loop body. It must accept two
            arguments: the loop variable and the loop state. It should
            return the updated loop state.
        init_val: The initial value of the loop state.

    Returns:
        The final state after the loop.

    Example:

    >>> lower = 0
    >>> upper = 10
    >>> body_fun = lambda i, s: (i + 1, s + i)
    >>> init_val = 0
    >>> keras.ops.fori_loop(lower, upper, body_fun, init_val)
    45
    """
    # Dispatch symbolically when any of the bounds or the state is a
    # KerasTensor; otherwise run eagerly on the backend.
    if any_symbolic_tensors((lower, upper, init_val)):
        return ForiLoop(lower, upper, body_fun).symbolic_call(init_val)
    return backend.core.fori_loop(lower, upper, body_fun, init_val)
664
+
665
+
666
class Unstack(Operation):
    """Symbolic wrapper around `backend.core.unstack`."""

    def __init__(self, num=None, axis=0):
        super().__init__()
        self.num = num
        self.axis = axis

    def call(self, x):
        return backend.core.unstack(x, self.num, self.axis)

    def compute_output_spec(self, x):
        # Normalize a negative axis to its positive equivalent.
        axis = self.axis if self.axis >= 0 else len(x.shape) + self.axis
        # Each output drops the unstacked dimension.
        out_shape = x.shape[:axis] + x.shape[axis + 1 :]
        count = self.num if self.num is not None else x.shape[axis]
        if count is None:
            # Dynamic axis dimension and no explicit `num`: the number of
            # outputs cannot be determined symbolically.
            raise ValueError(
                "Cannot infer argument `num` from shape "
                f"{x.shape}. Either provide a tensor with a "
                "concrete shape in the `axis` dimension or "
                "explicitly pass the `num` argument."
            )
        return [
            KerasTensor(shape=out_shape, dtype=x.dtype) for _ in range(count)
        ]
694
+
695
+
696
@keras_export("keras.ops.unstack")
def unstack(x, num=None, axis=0):
    """Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.

    Args:
        x: The input tensor.
        num: The length of the dimension axis. Automatically inferred
            if `None`.
        axis: The axis along which to unpack.

    Returns:
        A list of tensors unpacked along the given axis.

    Example:

    >>> x = keras.ops.array([[1, 2], [3, 4]])
    >>> keras.ops.unstack(x, axis=0)
    [array([1, 2]), array([3, 4])]
    """
    # Symbolic path builds a graph node; eager path calls the backend.
    if any_symbolic_tensors((x,)):
        return Unstack(num, axis).symbolic_call(x)
    return backend.core.unstack(x, num=num, axis=axis)
718
+
719
+
720
@keras_export("keras.ops.shape")
def shape(x):
    """Gets the shape of the tensor input.

    Note: On the TensorFlow backend, when `x` is a `tf.Tensor` with dynamic
    shape, dimensions which are dynamic in the context of a compiled function
    will have a `tf.Tensor` value instead of a static integer value.

    Args:
        x: A tensor. This function will try to access the `shape` attribute
            of the input tensor.

    Returns:
        A tuple of integers or None values, indicating the shape of the
        input tensor.

    Example:

    >>> x = keras.ops.zeros((8, 12))
    >>> keras.ops.shape(x)
    (8, 12)
    """
    # A KerasTensor already carries its (possibly partially unknown) shape.
    if any_symbolic_tensors((x,)):
        return x.shape
    return backend.core.shape(x)
745
+
746
+
747
@keras_export("keras.ops.dtype")
def dtype(x):
    """Return the dtype of the tensor input as a standardized string.

    Note that due to the standardization, the dtype will not compare equal
    to the backend-specific version of the dtype.

    Args:
        x: A tensor. This function will try to access the `dtype` attribute
            of the input tensor.

    Returns:
        A string indicating the dtype of the input tensor, e.g. `"float32"`.

    Example:

    >>> x = keras.ops.zeros((8, 12))
    >>> keras.ops.dtype(x)
    'float32'

    """
    return backend.standardize_dtype(x.dtype)
769
+
770
+
771
class Cast(Operation):
    """Symbolic wrapper around `backend.core.cast`."""

    def __init__(self, dtype):
        super().__init__()
        # Normalize the dtype once so `call` and spec computation agree.
        self.dtype = backend.standardize_dtype(dtype)

    def call(self, x):
        return backend.core.cast(x, self.dtype)

    def compute_output_spec(self, x):
        # Same shape, new dtype.
        return backend.KerasTensor(shape=x.shape, dtype=self.dtype)
781
+
782
+
783
@keras_export("keras.ops.cast")
def cast(x, dtype):
    """Cast a tensor to the desired dtype.

    Args:
        x: A tensor or variable.
        dtype: The target type.

    Returns:
        A tensor of the specified `dtype`.

    Example:

    >>> x = keras.ops.arange(4)
    >>> x = keras.ops.cast(x, dtype="float16")
    """
    # Standardize up front so both the symbolic and eager paths see the
    # same canonical dtype string.
    dtype = backend.standardize_dtype(dtype)
    if any_symbolic_tensors((x,)):
        return Cast(dtype=dtype)(x)
    return backend.core.cast(x, dtype)
804
+
805
+
806
class SaturateCast(Operation):
    """Symbolic wrapper around `_saturate_cast`."""

    def __init__(self, dtype):
        super().__init__()
        self.dtype = backend.standardize_dtype(dtype)

    def call(self, x):
        return _saturate_cast(x, self.dtype)

    def compute_output_spec(self, x):
        # Shape is preserved; only the dtype changes.
        return backend.KerasTensor(shape=x.shape, dtype=self.dtype)
816
+
817
+
818
@keras_export("keras.ops.saturate_cast")
def saturate_cast(x, dtype):
    """Performs a safe saturating cast to the desired dtype.

    Saturating cast prevents data type overflow when casting to `dtype` with
    smaller values range. E.g.
    `ops.cast(ops.cast([-1, 256], "float32"), "uint8")` returns `[255, 0]`,
    but `ops.saturate_cast(ops.cast([-1, 256], "float32"), "uint8")` returns
    `[0, 255]`.

    Args:
        x: A tensor or variable.
        dtype: The target type.

    Returns:
        A safely casted tensor of the specified `dtype`.

    Example:

    Image resizing with bicubic interpolation may produce values outside
    original range.
    >>> image2x2 = np.array([0, 1, 254, 255], dtype="uint8").reshape(1, 2, 2, 1)
    >>> image4x4 = tf.image.resize(image2x2, (4, 4), method="bicubic")
    >>> print(image4x4.numpy().squeeze())
    >>> # [[-22.500004 -22.204624 -21.618908 -21.32353 ]
    >>> #  [ 52.526054  52.82143   53.407146  53.70253 ]
    >>> #  [201.29752  201.59288  202.17859  202.47395 ]
    >>> #  [276.32355  276.61893  277.20465  277.50006 ]]

    Casting this resized image back to `uint8` will cause overflow.
    >>> image4x4_casted = ops.cast(image4x4, "uint8")
    >>> print(image4x4_casted.numpy().squeeze())
    >>> # [[234 234 235 235]
    >>> #  [ 52  52  53  53]
    >>> #  [201 201 202 202]
    >>> #  [ 20  20  21  21]]

    Saturate casting to `uint8` will clip values to `uint8` range before
    casting and will not cause overflow.
    >>> image4x4_saturate_casted = ops.saturate_cast(image4x4, "uint8")
    >>> print(image4x4_saturate_casted.numpy().squeeze())
    >>> # [[  0   0   0   0]
    >>> #  [ 52  52  53  53]
    >>> #  [201 201 202 202]
    >>> #  [255 255 255 255]]

    """
    dtype = backend.standardize_dtype(dtype)

    # Symbolic tensors build a graph node; eager tensors are clipped and
    # cast immediately by the shared helper.
    if any_symbolic_tensors((x,)):
        return SaturateCast(dtype=dtype)(x)
    return _saturate_cast(x, dtype)
870
+
871
+
872
def _saturate_cast(x, dtype, backend_module=None):
    """Clip `x` into the representable range of `dtype`, then cast.

    Args:
        x: Input tensor.
        dtype: Target dtype string (already standardizable).
        backend_module: Optional backend module override; defaults to the
            active Keras backend.

    Returns:
        `x` clipped to `dtype`'s range and cast to `dtype`.
    """
    backend_module = backend_module or backend

    def get_dtype_min_max(dtype):
        # Resolve the representable [min, max] range for a dtype string.
        if "bool" == dtype:
            dtype_min = 0
            dtype_max = 1
        elif "int" in dtype:
            dtype_min = ml_dtypes.iinfo(dtype).min
            dtype_max = ml_dtypes.iinfo(dtype).max
        else:
            dtype_min = ml_dtypes.finfo(dtype).min
            dtype_max = ml_dtypes.finfo(dtype).max
        return dtype_min, dtype_max

    dtype = backend.standardize_dtype(dtype)
    in_dtype = backend.standardize_dtype(x.dtype)
    in_min, in_max = get_dtype_min_max(in_dtype)
    out_min, out_max = get_dtype_min_max(dtype)

    # The output min/max may not actually be representable in the
    # in_dtype (e.g. casting float32 to uint32). This can lead to undefined
    # behavior when trying to cast a value outside the valid range of the
    # target type. We work around this by nudging the min/max to fall within
    # the valid output range. The catch is that we may actually saturate
    # to a value less than the true saturation limit, but this is the best we
    # can do in order to avoid UB without backend op.
    min_limit = np.maximum(in_min, out_min).astype(in_dtype)
    if min_limit < out_min:
        # Rounding pushed the limit outside the output range; step one ULP
        # toward zero to get back inside it.
        min_limit = np.nextafter(min_limit, 0, dtype=in_dtype)
    max_limit = np.minimum(in_max, out_max).astype(in_dtype)
    if max_limit > out_max:
        max_limit = np.nextafter(max_limit, 0, dtype=in_dtype)

    # Unconditionally apply `clip` to fix `inf` behavior.
    x = backend_module.numpy.clip(x, min_limit, max_limit)

    return backend_module.cast(x, dtype)
910
+
911
+
912
class ConvertToTensor(Operation):
    """Symbolic wrapper around `backend.core.convert_to_tensor`."""

    def __init__(self, dtype, sparse):
        super().__init__()
        self.dtype = backend.standardize_dtype(dtype)
        self.sparse = sparse

    def call(self, x):
        return backend.core.convert_to_tensor(
            x, dtype=self.dtype, sparse=self.sparse
        )

    def compute_output_spec(self, x):
        # `None` dtype means "keep the input dtype".
        dtype = x.dtype if self.dtype is None else self.dtype
        # Only an explicit `sparse=False` densifies; otherwise sparseness is
        # inherited from the input.
        sparse = (
            False if self.sparse is not None and not self.sparse else x.sparse
        )
        return backend.KerasTensor(shape=x.shape, dtype=dtype, sparse=sparse)
929
+
930
+
931
@keras_export("keras.ops.convert_to_tensor")
def convert_to_tensor(x, dtype=None, sparse=None):
    """Convert a NumPy array to a tensor.

    Args:
        x: A NumPy array, Python array (can be nested) or a backend tensor.
        dtype: The target type. If `None`, the type of `x` is used.
        sparse: Whether to keep sparse tensors. `False` will cause sparse
            tensors to be densified. The default value of `None` means that
            sparse tensors are kept only if the backend supports them.

    Returns:
        A backend tensor of the specified `dtype` and sparseness.

    Example:

    >>> x = np.array([1, 2, 3])
    >>> y = keras.ops.convert_to_tensor(x)
    """
    if any_symbolic_tensors((x,)):
        return ConvertToTensor(dtype=dtype, sparse=sparse)(x)
    return backend.core.convert_to_tensor(x, dtype=dtype, sparse=sparse)
953
+
954
+
955
@keras_export("keras.ops.convert_to_numpy")
def convert_to_numpy(x):
    """Convert a tensor to a NumPy array.

    Args:
        x: A tensor.

    Returns:
        A NumPy array.
    """
    if any_symbolic_tensors((x,)):
        # This will raise a `ValueError` defined in the `KerasTensor` class.
        # We trigger it rather than duplicate it here.
        return np.array(x)
    return backend.convert_to_numpy(x)
970
+
971
+
972
class Cond(Operation):
    """Operation backing `keras.ops.cond`.

    Dispatches to the symbolic path when any argument is a `KerasTensor`,
    and to `backend.core.cond` otherwise. At trace time it verifies that
    `true_fn` and `false_fn` produce structurally identical outputs.
    """

    @traceback_utils.filter_traceback
    def __call__(self, *args, **kwargs):
        def call_fn(*args, **kwargs):
            if any_symbolic_tensors(args, kwargs):
                return self.symbolic_call(*args, **kwargs)
            else:
                return self.call(*args, **kwargs)

        if traceback_utils.is_traceback_filtering_enabled():
            # Wrap self.call to provide helpful info in case of exception
            call_fn = traceback_utils.inject_argument_info_in_traceback(
                call_fn,
                object_name=(f"{self.__class__.__name__}.call()"),
            )
            return call_fn(*args, **kwargs)

        # Plain flow.
        return call_fn(*args, **kwargs)

    def call(self, pred, true_fn, false_fn):
        return backend.core.cond(pred, true_fn, false_fn)

    def compute_output_spec(self, pred, true_fn, false_fn):
        true_fn_spec = backend.compute_output_spec(true_fn)
        false_fn_spec = backend.compute_output_spec(false_fn)
        if not self._check_output_spec(true_fn_spec, false_fn_spec):
            raise ValueError(
                "`true_fn` and `false_fn` should return outputs "
                "of the same kind (struct, dtype and shape). "
                f"Got {true_fn_spec} and {false_fn_spec} instead."
            )
        return true_fn_spec

    def _check_output_spec(self, true_fn_spec, false_fn_spec):
        """Return True if both branch specs match in structure, shape and dtype."""
        try:
            tree.assert_same_structure(true_fn_spec, false_fn_spec)
        except Exception:
            # Fix: narrowed from a bare `except:` (E722), which would also
            # swallow KeyboardInterrupt/SystemExit. A structure mismatch
            # raised by `assert_same_structure` simply means "not equal".
            return False

        def check_leaf(t_spec, f_spec):
            if t_spec is None or f_spec is None:
                return t_spec is None and f_spec is None
            return t_spec.shape == f_spec.shape and t_spec.dtype == f_spec.dtype

        same = tree.map_structure(check_leaf, true_fn_spec, false_fn_spec)
        return all(tree.flatten(same))
1019
+
1020
+
1021
@keras_export("keras.ops.cond")
def cond(pred, true_fn, false_fn):
    """Conditionally applies `true_fn` or `false_fn`.

    Args:
        pred: Boolean scalar type
        true_fn: Callable returning the output for the `pred == True` case.
        false_fn: Callable returning the output for the `pred == False` case.

    Returns:
        The output of either `true_fn` or `false_fn` depending on pred.
    """
    # Delegate to the Cond operation, which handles both symbolic and eager
    # inputs in its __call__.
    return Cond()(pred, true_fn, false_fn)
1034
+
1035
+
1036
# TODO: also create an Op subclass VectorizedMap.
@keras_export("keras.ops.vectorized_map")
def vectorized_map(function, elements):
    """Parallel map of `function` on axis 0 of tensor(s) `elements`.

    Schematically, `vectorized_map` implements the following,
    in the case of a single tensor input `elements`:

    ```python
    def vectorized_map(function, elements)
        outputs = []
        for e in elements:
            outputs.append(function(e))
        return stack(outputs)
    ```

    In the case of an iterable of tensors `elements`,
    it implements the following:

    ```python
    def vectorized_map(function, elements)
        batch_size = elements[0].shape[0]
        outputs = []
        for index in range(batch_size):
            outputs.append(function([e[index] for e in elements]))
        return np.stack(outputs)
    ```

    In this case, `function` is expected to take as input
    a single list of tensor arguments.
    """
    # Delegates directly to the backend; no symbolic Op wrapper exists yet
    # (see TODO above).
    return backend.core.vectorized_map(function, elements)
1068
+
1069
+
1070
@keras_export("keras.ops.is_tensor")
def is_tensor(x):
    """Check whether the given object is a tensor.

    Note: This checks for backend specific tensors so passing a TensorFlow
    tensor would return `False` if your backend is PyTorch or JAX.

    Args:
        x: A variable.

    Returns:
        `True` if `x` is a tensor, otherwise `False`.
    """
    return backend.core.is_tensor(x)
1084
+
1085
+
1086
@keras_export("keras.ops.custom_gradient")
def custom_gradient(f):
    """Decorator to define a function with a custom gradient.

    This decorator allows fine grained control over the gradients of a
    sequence of operations. This may be useful for multiple reasons,
    including providing a more efficient or numerically stable gradient for
    a sequence of operations.

    Args:
        f: Function `f(*args)` that returns a tuple
            `(output, grad_fn)`, where:
            - `args` is a sequence of (nested structures of) tensor inputs to
                the function.
            - `output` is a (nested structure of) tensor outputs of applying
                operations in `forward_fn` to `args`.
            - `grad_fn` is a function with the signature `grad_fn(*args,
                upstream)` which returns a tuple of tensors the same size as
                (flattened) `args`: the derivatives of tensors in `output` with
                respect to the tensors in `args`. `upstream` is a tensor or
                sequence of tensors holding the initial value gradients for each
                tensor in `output`.

    Returns:
        A function `h(*args)` which returns the same value as
        `f(*args)[0]` and whose gradient is determined by
        `f(*args)[1]`.


    Examples:

    1. Backend-agnostic example.

    ```python
    @ops.custom_gradient
    def log1pexp(x):
        e = ops.exp(x)

        def grad(*args, upstream=None):
            if upstream is None:
                (upstream,) = args
            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

        return ops.log(1 + e), grad
    ```

    Note that the grad function that returns gradient computation
    requires `args` as well as an `upstream` keyword argument, depending
    on the backend being set. With the JAX and TensorFlow backends,
    it requires only one argument, whereas it might use the `upstream`
    argument in the case of the PyTorch backend.

    When working with TensorFlow/JAX backend, `grad(upstream)`
    is sufficient. With PyTorch, the `grad` function requires
    `*args` as well as `upstream`, e.g. `def grad(*args, upstream)`.
    Follow the previous example to use `@ops.custom_gradient` in
    a way that is compatible with all backends.

    2. Here's JAX & TensorFlow-specific example:

    ```python
    @ops.custom_gradient
    def log1pexp(x):
        e = ops.exp(x)
        def grad(upstream):
            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
        return ops.log(1 + e), grad
    ```

    3. Lastly, here's a PyTorch-specific example,
    using `*args` & `upstream`:

    ```python
    @ops.custom_gradient
    def log1pexp(x):
        e = ops.exp(x)
        def grad(*args, upstream):
            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
        return ops.log(1 + e), grad
    ```
    """
    # Each backend supplies its own custom-gradient mechanism (tf.custom_gradient,
    # jax.custom_gradient, torch.autograd.Function, ...).
    return backend.core.custom_gradient(f)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/function.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+
3
+ from keras.src import tree
4
+ from keras.src.api_export import keras_export
5
+ from keras.src.backend import KerasTensor
6
+ from keras.src.backend.config import backend
7
+ from keras.src.ops.operation import Operation
8
+
9
+
10
@keras_export("keras.Function")
class Function(Operation):
    """Class that encapsulates a computation graph of Keras operations.

    You can use a `Function` to capture the computation graph linking
    some input tensors to some output tensors, and reapply the same
    computation on new inputs.

    A `Function` is similar to a Functional Model, with the difference
    that it is stateless (it does not track state variables)
    and does not implement the `Layer` API.

    Example:

    ```python
    input_1 = keras.KerasTensor(shape=(None, 2, 3))
    input_2 = keras.KerasTensor(shape=(None, 2, 3))
    x = input_1 + input_2
    output = keras.ops.sigmoid(x)
    fn = keras.Function(inputs=[input_1, input_2], outputs=output)

    input_1_val = np.random.random((4, 2, 3))
    input_2_val = np.random.random((4, 2, 3))
    output_val = fn([input_1_val, input_2_val])
    ```

    Args:
        inputs: `KerasTensor` instance or nested structured of
            `KerasTensor` instances.
        outputs: `KerasTensor` instance or nested structured of
            `KerasTensor` instances. They should be computable
            given only the values of `inputs`.
        name: String. The name of the function.
    """

    def __init__(self, inputs, outputs, name=None):
        super().__init__(name=name)

        if backend() == "tensorflow":
            # Temporary work around for
            # https://github.com/keras-team/keras/issues/931
            # This stop tensorflow from wrapping tf.function output in a
            # _DictWrapper object.
            _self_setattr_tracking = getattr(
                self, "_self_setattr_tracking", True
            )
            self._self_setattr_tracking = False
        # Keep both the original nested structure (for packing/unpacking)
        # and flat lists (for graph traversal).
        self._inputs_struct = tree.map_structure(lambda x: x, inputs)
        self._outputs_struct = tree.map_structure(lambda x: x, outputs)
        self._inputs = tree.flatten(inputs)
        self._outputs = tree.flatten(outputs)
        if not self._inputs:
            raise ValueError(
                "`inputs` argument cannot be empty. Received:\n"
                f"inputs={inputs}\n"
                f"outputs={outputs}"
            )
        if not self._outputs:
            raise ValueError(
                "`outputs` argument cannot be empty. Received:\n"
                f"inputs={inputs}\n"
                f"outputs={outputs}"
            )

        if backend() == "tensorflow":
            self._self_setattr_tracking = _self_setattr_tracking

        # Validate the graph topology and cache the node/operation ordering
        # used by `_run_through_graph`.
        (nodes, nodes_by_depth, operations, operations_by_depth) = map_graph(
            self._inputs, self._outputs
        )
        self._nodes = nodes
        self._nodes_by_depth = nodes_by_depth
        self._operations = operations
        self._operations_by_depth = operations_by_depth

    @property
    def operations(self):
        # Return a copy so callers cannot mutate the internal list.
        return self._operations[:]

    @property
    def inputs(self):
        """Flat list of the symbolic inputs of the Function."""
        return self._inputs

    @property
    def outputs(self):
        """Flat list of the symbolic outputs of the Function."""
        return self._outputs

    def compute_output_spec(self, inputs):
        self._assert_input_compatibility(inputs)
        # Check if input shapes are identical to ref input shapes,
        # if so take a shortcut.
        shortcut = True
        for x, x_ref in zip(tree.flatten(inputs), self._inputs):
            if x.shape != x_ref.shape:
                shortcut = False
                break
        if shortcut:
            return tree.map_structure(
                lambda x: KerasTensor(shape=x.shape, dtype=x.dtype),
                self._outputs_struct,
            )
        # No luck; take the long road through the graph.
        # Original Keras used a cache to avoid recomputing all this
        # when known input shapes where seen again. Perhaps a good
        # idea to bring that back.
        return self._run_through_graph(
            inputs, operation_fn=lambda op: op.compute_output_spec
        )

    def compute_output_shape(self, input_shape):
        # Wrap `input_shape` into the structure of KerasTensor to utilize
        # `compute_output_spec`.
        input_shape_struct = tree.map_shape_structure(
            lambda x: KerasTensor(shape=x), input_shape
        )
        # Ensure that dtype and sparse settings are the same as self._inputs,
        # because we only care about the shape in this function.
        for x, x_ref in zip(tree.flatten(input_shape_struct), self._inputs):
            x._dtype = x_ref.dtype
            x._sparse = x_ref.sparse
        output_spec = self.compute_output_spec(input_shape_struct)
        return tree.map_structure(lambda x: x.shape, output_spec)

    def call(self, inputs):
        """Computes output tensors for new inputs."""
        self._assert_input_compatibility(inputs)
        return self._run_through_graph(inputs, operation_fn=lambda op: op)

    def _run_through_graph(self, inputs, operation_fn, call_fn=None):
        """Execute the graph.

        At each node we compute outputs via
        `operation_fn(node.operation)(*args, **kwargs)`.
        """
        inputs = tree.flatten(inputs)

        # Dictionary mapping reference tensors to computed tensors.
        tensor_dict = {}
        for x, y in zip(self.inputs, inputs):
            tensor_dict[id(x)] = y

        nodes_by_depth = self._nodes_by_depth
        depth_keys = list(nodes_by_depth.keys())
        # Process from the deepest nodes (closest to the inputs) first.
        depth_keys.sort(reverse=True)

        for depth in depth_keys:
            nodes = nodes_by_depth[depth]
            for node in nodes:
                if not node.operation or node.is_input:
                    continue  # Input tensors already exist.

                if any(id(x) not in tensor_dict for x in node.input_tensors):
                    continue  # Node is not computable, try skipping.

                args, kwargs = node.arguments.fill_in(tensor_dict)
                op = operation_fn(node.operation)
                if call_fn is not None:
                    outputs = call_fn(op, *args, **kwargs)
                else:
                    outputs = op(*args, **kwargs)

                # Update tensor_dict.
                for x, y in zip(node.outputs, tree.flatten(outputs)):
                    tensor_dict[id(x)] = y

        output_tensors = []
        for x in self.outputs:
            output_tensors.append(tensor_dict[id(x)])

        # Restore the caller-facing nested structure.
        return tree.pack_sequence_as(self._outputs_struct, output_tensors)

    def _assert_input_compatibility(self, inputs):
        # Validate nested structure, then rank, then each static dimension.
        try:
            tree.assert_same_structure(inputs, self._inputs_struct)
        except ValueError:
            raise ValueError(
                "Function was called with an invalid input structure. "
                f"Expected input structure: {self._inputs_struct}\n"
                f"Received input structure: {inputs}"
            )
        for x, x_ref in zip(tree.flatten(inputs), self._inputs):
            if len(x.shape) != len(x_ref.shape):
                raise ValueError(
                    f"{self.__class__.__name__} was passed "
                    f"incompatible inputs. For input '{x_ref.name}', "
                    f"expected shape {x_ref.shape}, but received "
                    f"instead a tensor with shape {x.shape}."
                )
            for dim, ref_dim in zip(x.shape, x_ref.shape):
                if ref_dim is not None and dim is not None:
                    if dim != ref_dim:
                        raise ValueError(
                            f"{self.__class__.__name__} was passed "
                            f"incompatible inputs. For input '{x_ref.name}', "
                            f"expected shape {x_ref.shape}, but received "
                            f"instead a tensor with shape {x.shape}."
                        )
209
+
210
+
211
def make_node_key(op, node_index):
    """Build a unique string key for the `node_index`-th inbound node of `op`."""
    return f"{id(op)}_ib-{node_index}"
213
+
214
+
215
def map_graph(inputs, outputs):
    """Validates a graph's topology and gather its operations and nodes.

    Args:
        inputs: List of input tensors.
        outputs: List of outputs tensors.

    Returns:
        A tuple `(nodes, nodes_by_depth, operations, operations_by_depth)`.
        - nodes: set of Node instances
        - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
        - operations: list of Operation instances.
        - operations_by_depth: dict mapping ints (depth) to lists of Operation
            instances.

    Raises:
        ValueError: If the graph is disconnected (some required tensor cannot
            be computed from `inputs`) or if operation names are not unique.
    """
    # "depth" is number of operations between output Node and the Node.
    # Nodes are ordered from inputs -> outputs.
    nodes_in_decreasing_depth, operation_indices = _build_map(inputs, outputs)
    network_nodes = {
        make_node_key(node.operation, node.operation._inbound_nodes.index(node))
        for node in nodes_in_decreasing_depth
    }

    nodes_depths = {}  # dict {node: depth value}
    operations_depths = {}  # dict {operation: depth value}

    for node in reversed(nodes_in_decreasing_depth):
        # If the depth is not set, the node has no outbound nodes (depth 0).
        depth = nodes_depths.setdefault(node, 0)

        # Update the depth of the corresponding operation
        previous_depth = operations_depths.get(node.operation, 0)
        # If we've seen this operation before at a higher depth,
        # we should use that depth instead of the node depth.
        # This is necessary for shared operations that have inputs at different
        # depth levels in the graph.
        depth = max(depth, previous_depth)
        operations_depths[node.operation] = depth
        nodes_depths[node] = depth

        # Update the depth of inbound nodes.
        # The "depth" of a node is the max of the depths
        # of all nodes it is connected to + 1.
        for node_dep in node.parent_nodes:
            previous_depth = nodes_depths.get(node_dep, 0)
            nodes_depths[node_dep] = max(depth + 1, previous_depth)

    # Handle inputs that are not connected to outputs.
    # We do not error out here because the inputs may be used to compute losses
    # and metrics.
    for input_t in inputs:
        input_operation = input_t._keras_history[0]
        if input_operation and input_operation not in operations_depths:
            operations_depths[input_operation] = 0
            operation_indices[input_operation] = -1
            nodes_depths[input_operation._inbound_nodes[0]] = 0
            network_nodes.add(make_node_key(input_operation, 0))

    # Build a dict {depth: list of nodes with this depth}
    nodes_by_depth = collections.defaultdict(list)
    for node, depth in nodes_depths.items():
        nodes_by_depth[depth].append(node)

    # Build a dict {depth: list of operations with this depth}
    operations_by_depth = collections.defaultdict(list)
    for operation, depth in operations_depths.items():
        operations_by_depth[depth].append(operation)

    # Get sorted list of operation depths.
    depth_keys = list(operations_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Set self.operations ordered by depth.
    operations = []
    for depth in depth_keys:
        operations_for_depth = operations_by_depth[depth]
        # Network.operations needs to have a deterministic order:
        # here we order them by traversal order.
        operations_for_depth.sort(key=lambda x: operation_indices[x])
        operations.extend(operations_for_depth)

    # Get sorted list of node depths.
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Check that all tensors required are computable.
    # computable_tensors: all tensors in the graph
    # that can be computed from the inputs provided.
    computable_tensors = set()
    for x in inputs:
        computable_tensors.add(x)

    operations_with_complete_input = []  # To provide a better error msg.
    for depth in depth_keys:
        for node in nodes_by_depth[depth]:
            for x in tree.flatten(node.input_tensors):
                if x not in computable_tensors:
                    operation = node.operation
                    raise ValueError(
                        "Graph disconnected: cannot find parent for "
                        f"tensor {x} at operation '{operation}'. "
                        "The following previous operations were accessed "
                        f"without issue: {operations_with_complete_input}"
                    )
                operations_with_complete_input.append(node.operation.name)

            for x in tree.flatten(node.outputs):
                computable_tensors.add(x)

    # Ensure name unicity, which will be crucial for serialization
    # (since serialized nodes refer to operations by their name).
    all_names = [operation.name for operation in operations]
    for name in all_names:
        if all_names.count(name) != 1:
            raise ValueError(
                f'The name "{name}" is used {all_names.count(name)} '
                "times in the model. All operation names should be unique."
            )
    return network_nodes, nodes_by_depth, operations, operations_by_depth
334
+
335
+
336
def _build_map(inputs, outputs):
    """Topologically sort nodes in order from inputs to outputs.

    It uses a depth-first search to topologically sort nodes that appear in
    the `_keras_history` connectivity metadata of `outputs`.

    Args:
        inputs: input tensors; traversal stops when reaching any of these.
        outputs: the output tensors whose _keras_history metadata should be
            walked. This may be an arbitrary nested structure.

    Returns:
        A tuple `(ordered_nodes, operation_to_first_traversal_index)`:
        - ordered_nodes: list of nodes appearing in the keras history,
          topologically sorted from original inputs to the `outputs`
          (if outputs have different sets of ancestors, the inputs to one
          output may appear after a different output).
        - operation_to_first_traversal_index: dict mapping each operation to
          the traversal index in the DFS where it was *first* seen (shared
          operations keep their first index).
    """
    finished_nodes = set()
    nodes_in_progress = set()
    nodes_in_decreasing_depth = []  # nodes from inputs -> outputs.
    operation_indices = {}  # operation -> in traversal order.
    # Run one DFS per flattened output; the shared accumulator sets make the
    # traversals compose into a single global ordering.
    for output_tensor in tree.flatten(outputs):
        _build_map_helper(
            inputs,
            output_tensor,
            finished_nodes,
            nodes_in_progress,
            nodes_in_decreasing_depth,
            operation_indices,
        )
    return nodes_in_decreasing_depth, operation_indices
372
+
373
+
374
def _build_map_helper(
    inputs,
    tensor,
    finished_nodes,
    nodes_in_progress,
    nodes_in_decreasing_depth,
    operation_indices,
):
    """Recursive DFS step for `_build_map` (post-order node emission)."""
    operation, node_index, _ = tensor._keras_history
    if not operation:
        return

    node = operation._inbound_nodes[node_index]

    # Don't repeat work for shared subgraphs
    if node in finished_nodes:
        return

    # Prevent cycles.
    if node in nodes_in_progress:
        raise ValueError(
            f"Tensor {tensor} from operation '{operation.name}' is part of a "
            "cycle."
        )

    # Store the traversal order for operation sorting.
    operation_indices.setdefault(operation, len(operation_indices))

    # Recurse into the parents of this node, unless it is an input node or
    # the tensor itself is one of the user-provided graph inputs.
    nodes_in_progress.add(node)
    if not node.is_input and tensor not in tree.flatten(inputs):
        for parent_tensor in node.input_tensors:
            _build_map_helper(
                inputs,
                parent_tensor,
                finished_nodes,
                nodes_in_progress,
                nodes_in_decreasing_depth,
                operation_indices,
            )

    finished_nodes.add(node)
    nodes_in_progress.remove(node)
    nodes_in_decreasing_depth.append(node)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/image.py ADDED
@@ -0,0 +1,1235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import backend
2
+ from keras.src import ops
3
+ from keras.src.api_export import keras_export
4
+ from keras.src.backend import KerasTensor
5
+ from keras.src.backend import any_symbolic_tensors
6
+ from keras.src.ops.operation import Operation
7
+ from keras.src.ops.operation_utils import compute_conv_output_shape
8
+
9
+
10
class RGBToGrayscale(Operation):
    """Symbolic op wrapping `backend.image.rgb_to_grayscale`."""

    def __init__(self, data_format=None):
        super().__init__()
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        return backend.image.rgb_to_grayscale(
            images, data_format=self.data_format
        )

    def compute_output_spec(self, images):
        # Grayscale output replaces the RGB channels with a single one;
        # all other dimensions pass through unchanged.
        output_shape = list(images.shape)
        if len(output_shape) not in (3, 4):
            raise ValueError(
                "Invalid images rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). "
                f"Received: images.shape={output_shape}"
            )
        channel_axis = -1 if self.data_format == "channels_last" else -3
        output_shape[channel_axis] = 1
        return KerasTensor(shape=output_shape, dtype=images.dtype)
33
+
34
+
35
@keras_export("keras.ops.image.rgb_to_grayscale")
def rgb_to_grayscale(images, data_format=None):
    """Convert RGB images to grayscale.

    Accepts a single image (3D) or a batch of images (4D) and returns a
    tensor of the same rank where the channel dimension has size 1.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        data_format: Either `"channels_last"` (shape
            `(batch, height, width, channels)`) or `"channels_first"`
            (shape `(batch, channels, height, width)`). Defaults to
            `keras.config.image_data_format` when not specified.

    Returns:
        Grayscale image or batch of grayscale images.

    Examples:

    >>> import numpy as np
    >>> from keras import ops
    >>> ops.image.rgb_to_grayscale(np.random.random((2, 4, 4, 3))).shape
    (2, 4, 4, 1)

    >>> ops.image.rgb_to_grayscale(np.random.random((4, 4, 3))).shape
    (4, 4, 1)

    >>> ops.image.rgb_to_grayscale(
    ...     np.random.random((2, 3, 4, 4)), data_format="channels_first"
    ... ).shape
    (2, 1, 4, 4)
    """
    # Eager tensors go straight to the backend kernel; symbolic tensors
    # are routed through the Operation for shape inference.
    if not any_symbolic_tensors((images,)):
        return backend.image.rgb_to_grayscale(images, data_format=data_format)
    return RGBToGrayscale(data_format=data_format).symbolic_call(images)
77
+
78
+
79
class RGBToHSV(Operation):
    """Symbolic op wrapping `backend.image.rgb_to_hsv`."""

    def __init__(self, data_format=None):
        super().__init__()
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        return backend.image.rgb_to_hsv(images, data_format=self.data_format)

    def compute_output_spec(self, images):
        # Shape and dtype pass through unchanged; only validation here.
        output_shape = list(images.shape)
        if len(output_shape) not in (3, 4):
            raise ValueError(
                "Invalid images rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). "
                f"Received: images.shape={output_shape}"
            )
        if not backend.is_float_dtype(images.dtype):
            raise ValueError(
                "Invalid images dtype: expected float dtype. "
                f"Received: images.dtype={images.dtype}"
            )
        return KerasTensor(shape=output_shape, dtype=images.dtype)
102
+
103
+
104
@keras_export("keras.ops.image.rgb_to_hsv")
def rgb_to_hsv(images, data_format=None):
    """Convert RGB images to HSV.

    `images` must be of float dtype; the output is only well defined when
    the input values lie in `[0, 1]`. All HSV components are returned in
    `[0, 1]`: a hue of `0` is pure red, `1/3` pure green, `2/3` pure blue.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        data_format: Either `"channels_last"` (shape
            `(batch, height, width, channels)`) or `"channels_first"`
            (shape `(batch, channels, height, width)`). Defaults to
            `keras.config.image_data_format` when not specified.

    Returns:
        HSV image or batch of HSV images, same shape as `images`.

    Examples:

    >>> import numpy as np
    >>> from keras import ops
    >>> ops.image.rgb_to_hsv(np.random.random((2, 4, 4, 3))).shape
    (2, 4, 4, 3)

    >>> ops.image.rgb_to_hsv(np.random.random((4, 4, 3))).shape
    (4, 4, 3)

    >>> ops.image.rgb_to_hsv(
    ...     np.random.random((2, 3, 4, 4)), data_format="channels_first"
    ... ).shape
    (2, 3, 4, 4)
    """
    # Eager path hits the backend kernel directly; the symbolic path
    # defers to the Operation for shape/dtype inference.
    if not any_symbolic_tensors((images,)):
        return backend.image.rgb_to_hsv(images, data_format=data_format)
    return RGBToHSV(data_format=data_format).symbolic_call(images)
149
+
150
+
151
class HSVToRGB(Operation):
    """Symbolic op wrapping `backend.image.hsv_to_rgb`."""

    def __init__(self, data_format=None):
        super().__init__()
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        return backend.image.hsv_to_rgb(images, data_format=self.data_format)

    def compute_output_spec(self, images):
        # Shape and dtype pass through unchanged; only validation here.
        output_shape = list(images.shape)
        if len(output_shape) not in (3, 4):
            raise ValueError(
                "Invalid images rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). "
                f"Received: images.shape={output_shape}"
            )
        if not backend.is_float_dtype(images.dtype):
            raise ValueError(
                "Invalid images dtype: expected float dtype. "
                f"Received: images.dtype={images.dtype}"
            )
        return KerasTensor(shape=output_shape, dtype=images.dtype)
174
+
175
+
176
@keras_export("keras.ops.image.hsv_to_rgb")
def hsv_to_rgb(images, data_format=None):
    """Convert HSV images to RGB.

    `images` must be of float dtype; the output is only well defined when
    the input values lie in `[0, 1]`.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        data_format: Either `"channels_last"` (shape
            `(batch, height, width, channels)`) or `"channels_first"`
            (shape `(batch, channels, height, width)`). Defaults to
            `keras.config.image_data_format` when not specified.

    Returns:
        RGB image or batch of RGB images, same shape as `images`.

    Examples:

    >>> import numpy as np
    >>> from keras import ops
    >>> ops.image.hsv_to_rgb(np.random.random((2, 4, 4, 3))).shape
    (2, 4, 4, 3)

    >>> ops.image.hsv_to_rgb(np.random.random((4, 4, 3))).shape
    (4, 4, 3)

    >>> ops.image.hsv_to_rgb(
    ...     np.random.random((2, 3, 4, 4)), data_format="channels_first"
    ... ).shape
    (2, 3, 4, 4)
    """
    # Eager path hits the backend kernel directly; the symbolic path
    # defers to the Operation for shape/dtype inference.
    if not any_symbolic_tensors((images,)):
        return backend.image.hsv_to_rgb(images, data_format=data_format)
    return HSVToRGB(data_format=data_format).symbolic_call(images)
218
+
219
+
220
class Resize(Operation):
    """Symbolic op wrapping the backend image `resize` kernel."""

    def __init__(
        self,
        size,
        interpolation="bilinear",
        antialias=False,
        crop_to_aspect_ratio=False,
        pad_to_aspect_ratio=False,
        fill_mode="constant",
        fill_value=0.0,
        data_format=None,
    ):
        super().__init__()
        # `size` is normalized to a tuple so it is hashable/serializable.
        self.size = tuple(size)
        self.interpolation = interpolation
        self.antialias = antialias
        self.crop_to_aspect_ratio = crop_to_aspect_ratio
        self.pad_to_aspect_ratio = pad_to_aspect_ratio
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        return _resize(
            images,
            self.size,
            interpolation=self.interpolation,
            antialias=self.antialias,
            data_format=self.data_format,
            crop_to_aspect_ratio=self.crop_to_aspect_ratio,
            pad_to_aspect_ratio=self.pad_to_aspect_ratio,
            fill_mode=self.fill_mode,
            fill_value=self.fill_value,
        )

    def compute_output_spec(self, images):
        output_shape = list(images.shape)
        if len(output_shape) not in (3, 4):
            raise ValueError(
                "Invalid images rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). Received input with shape: "
                f"images.shape={images.shape}"
            )
        # Only the two spatial axes change; their positions depend on
        # the data layout.
        if self.data_format == "channels_last":
            output_shape[-3], output_shape[-2] = self.size
        else:
            output_shape[-2], output_shape[-1] = self.size
        return KerasTensor(shape=output_shape, dtype=images.dtype)
270
+
271
+
272
@keras_export("keras.ops.image.resize")
def resize(
    images,
    size,
    interpolation="bilinear",
    antialias=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    fill_mode="constant",
    fill_value=0.0,
    data_format=None,
):
    """Resize images to `size` using the specified interpolation method.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        size: Size of output image in `(height, width)` format.
        interpolation: One of `"nearest"`, `"bilinear"`, `"bicubic"`.
            Defaults to `"bilinear"`.
        antialias: Whether to apply an antialiasing filter when
            downsampling. Defaults to `False`.
        crop_to_aspect_ratio: If `True`, crop the images so that the
            largest possible window matching the target aspect ratio is
            resized, avoiding aspect-ratio distortion. By default
            (`False`) the aspect ratio may not be preserved. Mutually
            exclusive with `pad_to_aspect_ratio`.
        pad_to_aspect_ratio: If `True`, pad the images evenly on the
            short side so the target aspect ratio is reached without
            distortion. Mutually exclusive with `crop_to_aspect_ratio`.
        fill_mode: Fill mode for padded areas when
            `pad_to_aspect_ratio=True`. Only `"constant"` is currently
            supported (fill with `fill_value`).
        fill_value: Float padding value used when
            `pad_to_aspect_ratio=True`.
        data_format: Either `"channels_last"` (shape
            `(batch, height, width, channels)`) or `"channels_first"`
            (shape `(batch, channels, height, width)`). Defaults to
            `keras.config.image_data_format` when not specified.

    Returns:
        Resized image or batch of images.

    Examples:

    >>> x = np.random.random((2, 4, 4, 3))  # batch of 2 RGB images
    >>> keras.ops.image.resize(x, (2, 2)).shape
    (2, 2, 2, 3)

    >>> x = np.random.random((4, 4, 3))  # single RGB image
    >>> keras.ops.image.resize(x, (2, 2)).shape
    (2, 2, 3)

    >>> x = np.random.random((2, 3, 4, 4))
    >>> keras.ops.image.resize(x, (2, 2), data_format="channels_first").shape
    (2, 3, 2, 2)
    """
    # Validate arguments up front so both the eager and symbolic paths
    # raise identical errors.
    if len(size) != 2:
        raise ValueError(
            "Expected `size` to be a tuple of 2 integers. "
            f"Received: size={size}"
        )
    if len(images.shape) not in (3, 4):
        raise ValueError(
            "Invalid images rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"images.shape={images.shape}"
        )
    if pad_to_aspect_ratio and crop_to_aspect_ratio:
        raise ValueError(
            "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` "
            "can be `True`."
        )
    if not any_symbolic_tensors((images,)):
        return _resize(
            images,
            size,
            interpolation=interpolation,
            antialias=antialias,
            crop_to_aspect_ratio=crop_to_aspect_ratio,
            data_format=data_format,
            pad_to_aspect_ratio=pad_to_aspect_ratio,
            fill_mode=fill_mode,
            fill_value=fill_value,
        )
    return Resize(
        size,
        interpolation=interpolation,
        antialias=antialias,
        data_format=data_format,
        crop_to_aspect_ratio=crop_to_aspect_ratio,
        pad_to_aspect_ratio=pad_to_aspect_ratio,
        fill_mode=fill_mode,
        fill_value=fill_value,
    ).symbolic_call(images)
376
+
377
+
378
def _resize(
    images,
    size,
    interpolation="bilinear",
    antialias=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    fill_mode="constant",
    fill_value=0.0,
    data_format=None,
):
    """Run the backend resize and cast back to the input dtype if needed."""
    resized = backend.image.resize(
        images,
        size,
        interpolation=interpolation,
        antialias=antialias,
        crop_to_aspect_ratio=crop_to_aspect_ratio,
        data_format=data_format,
        pad_to_aspect_ratio=pad_to_aspect_ratio,
        fill_mode=fill_mode,
        fill_value=fill_value,
    )
    if resized.dtype != images.dtype:
        # Only the torch backend returns the original dtype already,
        # with correct rounding and without dtype overflow; every other
        # backend needs an explicit round + saturating cast.
        if backend.is_int_dtype(images.dtype):
            resized = ops.round(resized)
        resized = ops.saturate_cast(resized, images.dtype)
    return resized
407
+
408
+
409
class AffineTransform(Operation):
    """Symbolic op wrapping `backend.image.affine_transform`."""

    def __init__(
        self,
        interpolation="bilinear",
        fill_mode="constant",
        fill_value=0,
        data_format=None,
    ):
        super().__init__()
        self.interpolation = interpolation
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images, transform):
        return backend.image.affine_transform(
            images,
            transform,
            interpolation=self.interpolation,
            fill_mode=self.fill_mode,
            fill_value=self.fill_value,
            data_format=self.data_format,
        )

    def compute_output_spec(self, images, transform):
        # Affine transforms only move pixels around, so shape and dtype
        # pass through untouched; only the ranks are validated here.
        if len(images.shape) not in (3, 4):
            raise ValueError(
                "Invalid images rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). Received input with shape: "
                f"images.shape={images.shape}"
            )
        if len(transform.shape) not in (1, 2):
            raise ValueError(
                "Invalid transform rank: expected rank 1 (single transform) "
                "or rank 2 (batch of transforms). Received input with shape: "
                f"transform.shape={transform.shape}"
            )
        return KerasTensor(images.shape, dtype=images.dtype)
447
+
448
+
449
@keras_export("keras.ops.image.affine_transform")
def affine_transform(
    images,
    transform,
    interpolation="bilinear",
    fill_mode="constant",
    fill_value=0,
    data_format=None,
):
    """Applies the given transform(s) to the image(s).

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        transform: Projective transform matrix/matrices: a vector of
            length 8, or a tensor of size N x 8. A row
            `[a0, a1, a2, b0, b1, b2, c0, c1]` maps the output point
            `(x, y)` to the transformed input point
            `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`
            with `k = c0 x + c1 y + 1`. The transform is inverted
            compared to the transform mapping input points to output
            points. Gradients are not backpropagated into the
            transformation parameters. `c0` and `c1` are only effective
            with the TensorFlow backend and are treated as `0` elsewhere.
        interpolation: One of `"nearest"` or `"bilinear"`. Defaults to
            `"bilinear"`.
        fill_mode: How points outside the input boundaries are filled.
            One of:
            - `"reflect"`: `(d c b a | a b c d | d c b a)` — reflect
              about the edge of the last pixel.
            - `"constant"`: `(k k k k | a b c d | k k k k)` — fill with
              the constant `fill_value`.
            - `"wrap"`: `(a b c d | a b c d | a b c d)` — wrap around to
              the opposite edge.
            - `"nearest"`: `(a a a a | a b c d | d d d d)` — extend with
              the nearest pixel.
            Defaults to `"constant"`.
        fill_value: Value used outside the input boundaries when
            `fill_mode="constant"`. Defaults to `0`.
        data_format: Either `"channels_last"` (shape
            `(batch, height, width, channels)`) or `"channels_first"`
            (shape `(batch, channels, height, width)`). Defaults to
            `keras.config.image_data_format` when not specified.

    Returns:
        Applied affine transform image or batch of images.

    Examples:

    >>> x = np.random.random((2, 64, 80, 3))  # batch of 2 RGB images
    >>> transform = np.array(
    ...     [
    ...         [1.5, 0, -20, 0, 1.5, -16, 0, 0],  # zoom
    ...         [1, 0, -20, 0, 1, -16, 0, 0],  # translation
    ...     ]
    ... )
    >>> keras.ops.image.affine_transform(x, transform).shape
    (2, 64, 80, 3)

    >>> x = np.random.random((64, 80, 3))  # single RGB image
    >>> transform = np.array([1.0, 0.5, -20, 0.5, 1.0, -16, 0, 0])  # shear
    >>> keras.ops.image.affine_transform(x, transform).shape
    (64, 80, 3)
    """
    # Eager tensors go straight to the backend kernel; symbolic tensors
    # are routed through the Operation for shape inference.
    if not any_symbolic_tensors((images, transform)):
        return backend.image.affine_transform(
            images,
            transform,
            interpolation=interpolation,
            fill_mode=fill_mode,
            fill_value=fill_value,
            data_format=data_format,
        )
    return AffineTransform(
        interpolation=interpolation,
        fill_mode=fill_mode,
        fill_value=fill_value,
        data_format=data_format,
    ).symbolic_call(images, transform)
547
+
548
+
549
class ExtractPatches(Operation):
    """Symbolic op for `extract_patches`.

    Args:
        size: Patch size, an int or a `(patch_height, patch_width)` tuple.
        strides: Strides along height and width. When `None` (or falsy),
            defaults to the patch size (non-overlapping patches).
        dilation_rate: Input stride (dilation) of the patch sampling.
        padding: Padding algorithm, `"valid"` or `"same"`.
        data_format: `"channels_last"` or `"channels_first"`.
    """

    def __init__(
        self,
        size,
        strides=None,
        dilation_rate=1,
        padding="valid",
        data_format=None,
    ):
        super().__init__()
        if isinstance(size, int):
            size = (size, size)
        self.size = size
        self.strides = strides
        self.dilation_rate = dilation_rate
        self.padding = padding
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        return _extract_patches(
            images=images,
            size=self.size,
            strides=self.strides,
            dilation_rate=self.dilation_rate,
            padding=self.padding,
            data_format=self.data_format,
        )

    def compute_output_spec(self, images):
        images_shape = list(images.shape)
        original_ndim = len(images_shape)
        # Bug fix: `strides` was previously only assigned inside
        # `if not self.strides:`, which raised UnboundLocalError whenever
        # explicit strides were passed on the symbolic path. Mirror the
        # eager implementation: fall back to the patch size.
        if self.strides:
            strides = self.strides
        else:
            strides = (self.size[0], self.size[1])
        if self.data_format == "channels_last":
            channels_in = images_shape[-1]
        else:
            channels_in = images_shape[-3]
        if original_ndim == 3:
            # Temporarily add a batch dim so the conv shape helper applies.
            images_shape = [1] + images_shape
        # Each patch is flattened into the channel dimension, exactly
        # like a conv with an identity kernel (see `_extract_patches`).
        filters = self.size[0] * self.size[1] * channels_in
        kernel_size = (self.size[0], self.size[1])
        out_shape = compute_conv_output_shape(
            images_shape,
            filters,
            kernel_size,
            strides=strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )
        if original_ndim == 3:
            out_shape = out_shape[1:]
        return KerasTensor(shape=out_shape, dtype=images.dtype)
602
+
603
+
604
@keras_export("keras.ops.image.extract_patches")
def extract_patches(
    images,
    size,
    strides=None,
    dilation_rate=1,
    padding="valid",
    data_format=None,
):
    """Extracts patches from the image(s).

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        size: Patch size, an int or a `(patch_height, patch_width)` tuple.
        strides: Strides along height and width. If unspecified or
            `None`, defaults to the same value as `size`.
        dilation_rate: Input stride: how far apart two consecutive patch
            samples are in the input. For values other than 1, strides
            must be 1. NOTE: `strides > 1` is not supported together
            with `dilation_rate > 1`.
        padding: Padding algorithm to use, `"same"` or `"valid"`.
        data_format: Either `"channels_last"` (shape
            `(batch, height, width, channels)`) or `"channels_first"`
            (shape `(batch, channels, height, width)`). Defaults to
            `keras.config.image_data_format` when not specified.

    Returns:
        Extracted patches, 3D (if not batched) or 4D (if batched).

    Examples:

    >>> image = np.random.random(
    ...     (2, 20, 20, 3)
    ... ).astype("float32")  # batch of 2 RGB images
    >>> keras.ops.image.extract_patches(image, (5, 5)).shape
    (2, 4, 4, 75)
    >>> image = np.random.random((20, 20, 3)).astype("float32")
    >>> keras.ops.image.extract_patches(image, (3, 3), (1, 1)).shape
    (18, 18, 27)
    """
    # Eager tensors go straight to the helper; symbolic tensors are
    # routed through the Operation for shape inference.
    if not any_symbolic_tensors((images,)):
        return _extract_patches(
            images,
            size,
            strides,
            dilation_rate,
            padding,
            data_format=data_format,
        )
    return ExtractPatches(
        size=size,
        strides=strides,
        dilation_rate=dilation_rate,
        padding=padding,
        data_format=data_format,
    ).symbolic_call(images)
661
+
662
+
663
def _extract_patches(
    images,
    size,
    strides=None,
    dilation_rate=1,
    padding="valid",
    data_format=None,
):
    """Eager `extract_patches` implemented as an identity convolution."""
    if isinstance(size, int):
        patch_h = patch_w = size
    elif len(size) == 2:
        patch_h, patch_w = size
    else:
        raise TypeError(
            "Invalid `size` argument. Expected an "
            f"int or a tuple of length 2. Received: size={size}"
        )
    data_format = backend.standardize_data_format(data_format)
    if data_format == "channels_last":
        channels_in = images.shape[-1]
    else:
        channels_in = images.shape[-3]
    if not strides:
        # Non-overlapping patches by default.
        strides = size
    # An identity kernel turns the convolution into pure patch
    # extraction: each output channel copies exactly one
    # (row, col, channel) position of the patch.
    out_dim = patch_h * patch_w * channels_in
    kernel = backend.numpy.reshape(
        backend.numpy.eye(out_dim, dtype=images.dtype),
        (patch_h, patch_w, channels_in, out_dim),
    )
    needs_batch_dim = len(images.shape) == 3
    if needs_batch_dim:
        images = backend.numpy.expand_dims(images, axis=0)
    patches = backend.nn.conv(
        inputs=images,
        kernel=kernel,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilation_rate=dilation_rate,
    )
    if needs_batch_dim:
        patches = backend.numpy.squeeze(patches, axis=0)
    return patches
707
+
708
+
709
class MapCoordinates(Operation):
    """Symbolic op wrapping `backend.image.map_coordinates`."""

    def __init__(self, order, fill_mode="constant", fill_value=0):
        super().__init__()
        self.order = order
        self.fill_mode = fill_mode
        self.fill_value = fill_value

    def call(self, inputs, coordinates):
        return backend.image.map_coordinates(
            inputs,
            coordinates,
            order=self.order,
            fill_mode=self.fill_mode,
            fill_value=self.fill_value,
        )

    def compute_output_spec(self, inputs, coordinates):
        # `coordinates` stacks one index array per input dimension along
        # its leading axis, so that axis must equal the input rank.
        inputs_rank = len(inputs.shape)
        if coordinates.shape[0] != inputs_rank:
            raise ValueError(
                "First dim of `coordinates` must be the same as the rank of "
                "`inputs`. "
                f"Received inputs with shape: {inputs.shape} and coordinate "
                f"leading dim of {coordinates.shape[0]}"
            )
        if len(coordinates.shape) < 2:
            raise ValueError(
                "Invalid coordinates rank: expected at least rank 2."
                f" Received input with shape: {coordinates.shape}"
            )
        # The output takes the shape of the coordinate grid.
        return KerasTensor(coordinates.shape[1:], dtype=inputs.dtype)
739
+
740
+
741
@keras_export("keras.ops.image.map_coordinates")
def map_coordinates(
    inputs, coordinates, order, fill_mode="constant", fill_value=0
):
    """Map the input array to new coordinates by interpolation.

    Note that interpolation near boundaries differs from the scipy
    function, because we fixed an outstanding bug
    [scipy/issues/2640](https://github.com/scipy/scipy/issues/2640).

    Args:
        inputs: The input array.
        coordinates: The coordinates at which inputs is evaluated.
        order: Order of the spline interpolation; must be `0` (nearest
            neighbor) or `1` (linear interpolation).
        fill_mode: How points outside the input boundaries are filled.
            One of:
            - `"constant"`: `(k k k k | a b c d | k k k k)` — fill with
              the constant `fill_value`.
            - `"nearest"`: `(a a a a | a b c d | d d d d)` — extend with
              the nearest pixel.
            - `"wrap"`: `(a b c d | a b c d | a b c d)` — wrap around to
              the opposite edge.
            - `"mirror"`: `(c d c b | a b c d | c b a b)` — mirror about
              the edge.
            - `"reflect"`: `(d c b a | a b c d | d c b a)` — reflect
              about the edge of the last pixel.
            Defaults to `"constant"`.
        fill_value: Value used outside the input boundaries when
            `fill_mode="constant"`. Defaults to `0`.

    Returns:
        Output input or batch of inputs.
    """
    # Eager tensors go straight to the backend kernel; symbolic tensors
    # are routed through the Operation for shape inference.
    if not any_symbolic_tensors((inputs, coordinates)):
        return backend.image.map_coordinates(
            inputs,
            coordinates,
            order,
            fill_mode,
            fill_value,
        )
    return MapCoordinates(
        order,
        fill_mode,
        fill_value,
    ).symbolic_call(inputs, coordinates)
794
+
795
+
796
class PadImages(Operation):
    """Symbolic op for `pad_images`.

    Either explicit paddings or a target size (or a mix) can be given;
    `compute_output_spec` derives the missing target dimension from the
    static input size plus the paddings when possible.
    """

    def __init__(
        self,
        top_padding=None,
        left_padding=None,
        bottom_padding=None,
        right_padding=None,
        target_height=None,
        target_width=None,
        data_format=None,
    ):
        super().__init__()
        self.top_padding = top_padding
        self.left_padding = left_padding
        self.bottom_padding = bottom_padding
        self.right_padding = right_padding
        self.target_height = target_height
        self.target_width = target_width
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        return _pad_images(
            images,
            self.top_padding,
            self.left_padding,
            self.bottom_padding,
            self.right_padding,
            self.target_height,
            self.target_width,
            self.data_format,
        )

    def compute_output_spec(self, images):
        output_shape = list(images.shape)
        if self.data_format == "channels_last":
            height_axis, width_axis = -3, -2
        else:
            height_axis, width_axis = -2, -1
        height = output_shape[height_axis]
        width = output_shape[width_axis]

        # Without an explicit target, derive it from the static size
        # plus the paddings; it stays `None` for dynamic dimensions.
        target_height = self.target_height
        if target_height is None and height is not None:
            target_height = self.top_padding + height + self.bottom_padding
        target_width = self.target_width
        if target_width is None and width is not None:
            target_width = self.left_padding + width + self.right_padding

        output_shape[height_axis] = target_height
        output_shape[width_axis] = target_width
        return KerasTensor(shape=output_shape, dtype=images.dtype)
848
+
849
+
850
@keras_export("keras.ops.image.pad_images")
def pad_images(
    images,
    top_padding=None,
    left_padding=None,
    bottom_padding=None,
    right_padding=None,
    target_height=None,
    target_width=None,
    data_format=None,
):
    """Pad `images` with zeros to the specified `height` and `width`.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        top_padding: Number of rows of zeros to add on top.
        left_padding: Number of columns of zeros to add on the left.
        bottom_padding: Number of rows of zeros to add at the bottom.
        right_padding: Number of columns of zeros to add on the right.
        target_height: Height of output images.
        target_width: Width of output images.
        data_format: A string specifying the data format of the input tensor.
            Either `"channels_last"` (inputs of shape
            `(batch, height, width, channels)`) or `"channels_first"`
            (inputs of shape `(batch, channels, height, width)`).
            If not specified, the value defaults to
            `keras.config.image_data_format`.

    Returns:
        Padded image or batch of images.

    Example:

    >>> images = np.random.random((15, 25, 3))
    >>> padded_images = keras.ops.image.pad_images(
    ...     images, 2, 3, target_height=20, target_width=30
    ... )
    >>> padded_images.shape
    (20, 30, 3)

    >>> batch_images = np.random.random((2, 15, 25, 3))
    >>> padded_batch = keras.ops.image.pad_images(
    ...     batch_images, 2, 3, target_height=20, target_width=30
    ... )
    >>> padded_batch.shape
    (2, 20, 30, 3)"""

    # Eager inputs go straight to the backend helper; symbolic inputs are
    # routed through the graph op so shape inference can run.
    if not any_symbolic_tensors((images,)):
        return _pad_images(
            images,
            top_padding,
            left_padding,
            bottom_padding,
            right_padding,
            target_height,
            target_width,
            data_format,
        )
    op = PadImages(
        top_padding,
        left_padding,
        bottom_padding,
        right_padding,
        target_height,
        target_width,
        data_format,
    )
    return op.symbolic_call(images)
919
+
920
+
921
def _pad_images(
    images,
    top_padding,
    left_padding,
    bottom_padding,
    right_padding,
    target_height,
    target_width,
    data_format=None,
):
    """Eager implementation of `pad_images`.

    Validates the arguments, infers whichever of
    padding/target-size is missing, and zero-pads via `backend.numpy.pad`.
    """
    data_format = backend.standardize_data_format(data_format)
    images = backend.convert_to_tensor(images)
    images_shape = ops.shape(images)

    # Check
    if len(images_shape) not in (3, 4):
        raise ValueError(
            f"Invalid shape for argument `images`: "
            "it must have rank 3 or 4. "
            f"Received: images.shape={images_shape}"
        )
    # Exactly one of each triple must be None: the missing value is inferred
    # from the other two below.
    if [top_padding, bottom_padding, target_height].count(None) != 1:
        raise ValueError(
            "Must specify exactly two of "
            "top_padding, bottom_padding, target_height. "
            f"Received: top_padding={top_padding}, "
            f"bottom_padding={bottom_padding}, "
            f"target_height={target_height}"
        )
    if [left_padding, right_padding, target_width].count(None) != 1:
        raise ValueError(
            "Must specify exactly two of "
            "left_padding, right_padding, target_width. "
            f"Received: left_padding={left_padding}, "
            f"right_padding={right_padding}, "
            f"target_width={target_width}"
        )

    is_batch = False if len(images_shape) == 3 else True
    if data_format == "channels_last":
        height, width = images_shape[-3], images_shape[-2]
    else:
        height, width = images_shape[-2], images_shape[-1]

    # Infer padding
    # Only one branch per pair runs, because exactly one value per triple
    # is None (validated above).
    if top_padding is None:
        top_padding = target_height - bottom_padding - height
    if bottom_padding is None:
        bottom_padding = target_height - top_padding - height
    if left_padding is None:
        left_padding = target_width - right_padding - width
    if right_padding is None:
        right_padding = target_width - left_padding - width

    # Negative paddings mean the target size is smaller than the input.
    if top_padding < 0:
        raise ValueError(
            f"top_padding must be >= 0. Received: top_padding={top_padding}"
        )
    if left_padding < 0:
        raise ValueError(
            "left_padding must be >= 0. "
            f"Received: left_padding={left_padding}"
        )
    if right_padding < 0:
        raise ValueError(
            "right_padding must be >= 0. "
            f"Received: right_padding={right_padding}"
        )
    if bottom_padding < 0:
        raise ValueError(
            "bottom_padding must be >= 0. "
            f"Received: bottom_padding={bottom_padding}"
        )

    # Compute pad_width
    # Spatial dims first, then channel (and batch) dims get zero padding.
    pad_width = [[top_padding, bottom_padding], [left_padding, right_padding]]
    if data_format == "channels_last":
        pad_width = pad_width + [[0, 0]]
    else:
        pad_width = [[0, 0]] + pad_width
    if is_batch:
        pad_width = [[0, 0]] + pad_width

    padded_images = backend.numpy.pad(images, pad_width)
    return padded_images
1006
+
1007
+
1008
class CropImages(Operation):
    """Symbolic-graph node for `keras.ops.image.crop_images`.

    Stores the cropping configuration; the eager work is delegated to
    `_crop_images`, which also performs full argument validation.
    """

    def __init__(
        self,
        top_cropping,
        left_cropping,
        bottom_cropping,
        right_cropping,
        target_height,
        target_width,
        data_format=None,
    ):
        super().__init__()
        self.top_cropping = top_cropping
        self.bottom_cropping = bottom_cropping
        self.left_cropping = left_cropping
        self.right_cropping = right_cropping
        self.target_height = target_height
        self.target_width = target_width
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        return _crop_images(
            images,
            self.top_cropping,
            self.left_cropping,
            self.bottom_cropping,
            self.right_cropping,
            self.target_height,
            self.target_width,
            self.data_format,
        )

    def compute_output_spec(self, images):
        images_shape = list(images.shape)

        # Spatial axes are indexed from the end so this works for both
        # 3D (single image) and 4D (batch) inputs.
        if self.data_format == "channels_last":
            height_axis, width_axis = -3, -2
        else:
            height_axis, width_axis = -2, -1
        height, width = images_shape[height_axis], images_shape[width_axis]

        # With an unknown spatial dim, the output size can only come from an
        # explicit target.
        if height is None and self.target_height is None:
            raise ValueError(
                "When the height of the images is unknown, `target_height` "
                "must be specified."
                f"Received images.shape={images_shape} and "
                f"target_height={self.target_height}"
            )
        if width is None and self.target_width is None:
            raise ValueError(
                "When the width of the images is unknown, `target_width` "
                "must be specified."
                f"Received images.shape={images_shape} and "
                f"target_width={self.target_width}"
            )

        # When no explicit target size is given, derive it from croppings.
        target_height = self.target_height
        if target_height is None:
            target_height = height - self.top_cropping - self.bottom_cropping
        target_width = self.target_width
        if target_width is None:
            target_width = width - self.left_cropping - self.right_cropping

        images_shape[height_axis] = target_height
        images_shape[width_axis] = target_width
        return KerasTensor(shape=images_shape, dtype=images.dtype)
1074
+
1075
+
1076
@keras_export("keras.ops.image.crop_images")
def crop_images(
    images,
    top_cropping=None,
    left_cropping=None,
    bottom_cropping=None,
    right_cropping=None,
    target_height=None,
    target_width=None,
    data_format=None,
):
    """Crop `images` to a specified `height` and `width`.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        top_cropping: Number of columns to crop from the top.
        left_cropping: Number of columns to crop from the left.
        bottom_cropping: Number of columns to crop from the bottom.
        right_cropping: Number of columns to crop from the right.
        target_height: Height of the output images.
        target_width: Width of the output images.
        data_format: A string specifying the data format of the input tensor.
            Either `"channels_last"` (inputs of shape
            `(batch, height, width, channels)`) or `"channels_first"`
            (inputs of shape `(batch, channels, height, width)`).
            If not specified, the value defaults to
            `keras.config.image_data_format`.

    Returns:
        Cropped image or batch of images.

    Example:

    >>> images = np.reshape(np.arange(1, 28, dtype="float32"), [3, 3, 3])
    >>> images[:,:,0] # print the first channel of the images
    array([[ 1., 4., 7.],
           [10., 13., 16.],
           [19., 22., 25.]], dtype=float32)
    >>> cropped_images = keras.image.crop_images(images, 0, 0, 2, 2)
    >>> cropped_images[:,:,0] # print the first channel of the cropped images
    array([[ 1., 4.],
           [10., 13.]], dtype=float32)"""

    # Eager inputs go straight to the backend helper; symbolic inputs are
    # routed through the graph op so shape inference can run.
    if not any_symbolic_tensors((images,)):
        return _crop_images(
            images,
            top_cropping,
            left_cropping,
            bottom_cropping,
            right_cropping,
            target_height,
            target_width,
            data_format,
        )
    op = CropImages(
        top_cropping,
        left_cropping,
        bottom_cropping,
        right_cropping,
        target_height,
        target_width,
        data_format,
    )
    return op.symbolic_call(images)
1141
+
1142
+
1143
def _crop_images(
    images,
    top_cropping,
    left_cropping,
    bottom_cropping,
    right_cropping,
    target_height,
    target_width,
    data_format=None,
):
    """Eager implementation of `crop_images`.

    Validates the arguments, infers whichever of cropping/target-size is
    missing, and extracts the region via `ops.slice`.
    """
    data_format = backend.standardize_data_format(data_format)
    images = backend.convert_to_tensor(images)
    images_shape = ops.shape(images)

    # Check
    if len(images_shape) not in (3, 4):
        raise ValueError(
            f"Invalid shape for argument `images`: "
            "it must have rank 3 or 4. "
            f"Received: images.shape={images_shape}"
        )
    # Exactly one of each triple must be None: the missing value is inferred
    # from the other two below.
    if [top_cropping, bottom_cropping, target_height].count(None) != 1:
        raise ValueError(
            "Must specify exactly two of "
            "top_cropping, bottom_cropping, target_height. "
            f"Received: top_cropping={top_cropping}, "
            f"bottom_cropping={bottom_cropping}, "
            f"target_height={target_height}"
        )
    if [left_cropping, right_cropping, target_width].count(None) != 1:
        raise ValueError(
            "Must specify exactly two of "
            "left_cropping, right_cropping, target_width. "
            f"Received: left_cropping={left_cropping}, "
            f"right_cropping={right_cropping}, "
            f"target_width={target_width}"
        )

    is_batch = False if len(images_shape) == 3 else True
    if data_format == "channels_last":
        height, width = images_shape[-3], images_shape[-2]
        channels = images_shape[-1]
    else:
        height, width = images_shape[-2], images_shape[-1]
        channels = images_shape[-3]

    # Infer padding
    # Only one branch per pair runs, because exactly one value per triple
    # is None (validated above).
    if top_cropping is None:
        top_cropping = height - target_height - bottom_cropping
    if target_height is None:
        target_height = height - bottom_cropping - top_cropping
    if left_cropping is None:
        left_cropping = width - target_width - right_cropping
    if target_width is None:
        target_width = width - right_cropping - left_cropping

    # Negative values mean the requested crop exceeds the input extent.
    if top_cropping < 0:
        raise ValueError(
            "top_cropping must be >= 0. "
            f"Received: top_cropping={top_cropping}"
        )
    if target_height < 0:
        raise ValueError(
            "target_height must be >= 0. "
            f"Received: target_height={target_height}"
        )
    if left_cropping < 0:
        raise ValueError(
            "left_cropping must be >= 0. "
            f"Received: left_cropping={left_cropping}"
        )
    if target_width < 0:
        raise ValueError(
            "target_width must be >= 0. "
            f"Received: target_width={target_width}"
        )

    # Compute start_indices and shape
    # Spatial dims first; channel (and batch) dims are taken whole.
    start_indices = [top_cropping, left_cropping]
    shape = [target_height, target_width]
    if data_format == "channels_last":
        start_indices = start_indices + [0]
        shape = shape + [channels]
    else:
        start_indices = [0] + start_indices
        shape = [channels] + shape
    if is_batch:
        batch_size = images_shape[0]
        start_indices = [0] + start_indices
        shape = [batch_size] + shape

    cropped_images = ops.slice(images, start_indices, shape)
    return cropped_images
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/linalg.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import backend
2
+ from keras.src.api_export import keras_export
3
+ from keras.src.backend import KerasTensor
4
+ from keras.src.backend import any_symbolic_tensors
5
+ from keras.src.ops.operation import Operation
6
+ from keras.src.ops.operation_utils import reduce_shape
7
+
8
+
9
class Cholesky(Operation):
    """Symbolic-graph node for `keras.ops.cholesky`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        # Eager path: delegate to the backend implementation.
        return _cholesky(x)

    def compute_output_spec(self, x):
        # Cholesky preserves the input shape: (..., M, M) -> (..., M, M).
        _assert_2d(x)
        _assert_square(x)
        return KerasTensor(x.shape, x.dtype)
20
+
21
+
22
@keras_export(["keras.ops.cholesky", "keras.ops.linalg.cholesky"])
def cholesky(x):
    """Computes the Cholesky decomposition of a positive semi-definite matrix.

    Args:
        x: Input tensor of shape `(..., M, M)`.

    Returns:
        A tensor of shape `(..., M, M)` holding the lower triangular
        Cholesky factor of `x`.
    """
    if not any_symbolic_tensors((x,)):
        return _cholesky(x)
    return Cholesky().symbolic_call(x)
37
+
38
+
39
def _cholesky(x):
    """Eager-path implementation of `cholesky`.

    Raises:
        ValueError: If the input is not a square matrix, or if the backend
            decomposition fails (e.g. the matrix is not positive definite).
    """
    x = backend.convert_to_tensor(x)
    _assert_2d(x)
    _assert_square(x)
    try:
        return backend.linalg.cholesky(x)
    except Exception as e:
        # Chain the backend exception so the original failure reason is
        # preserved in the traceback (was previously dropped).
        raise ValueError(f"Cholesky decomposition failed: {e}") from e
47
+
48
+
49
class Det(Operation):
    """Symbolic-graph node for `keras.ops.det`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return _det(x)

    def compute_output_spec(self, x):
        # The determinant drops the two trailing matrix dims:
        # (..., M, M) -> (...,).
        _assert_2d(x)
        _assert_square(x)
        return KerasTensor(x.shape[:-2], x.dtype)
60
+
61
+
62
@keras_export(["keras.ops.det", "keras.ops.linalg.det"])
def det(x):
    """Computes the determinant of a square tensor.

    Args:
        x: Input tensor of shape `(..., M, M)`.

    Returns:
        A tensor of shape `(...,)` holding the determinant of `x`.
    """
    if not any_symbolic_tensors((x,)):
        return _det(x)
    return Det().symbolic_call(x)
76
+
77
+
78
def _det(x):
    """Eager-path implementation of `det`."""
    tensor = backend.convert_to_tensor(x)
    _assert_2d(tensor)
    _assert_square(tensor)
    return backend.linalg.det(tensor)
83
+
84
+
85
class Eig(Operation):
    """Symbolic-graph node for `keras.ops.eig`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return _eig(x)

    def compute_output_spec(self, x):
        # Eigenvalues have shape (..., M); eigenvectors keep (..., M, M).
        _assert_square(x)
        _assert_2d(x)
        return (
            KerasTensor(x.shape[:-1], x.dtype),
            KerasTensor(x.shape, x.dtype),
        )
99
+
100
+
101
@keras_export(["keras.ops.eig", "keras.ops.linalg.eig"])
def eig(x):
    """Computes the eigenvalues and eigenvectors of a square matrix.

    Args:
        x: Input tensor of shape `(..., M, M)`.

    Returns:
        A tuple `(eigenvalues, eigenvectors)` with shapes `(..., M)` and
        `(..., M, M)` respectively.
    """
    if not any_symbolic_tensors((x,)):
        return _eig(x)
    return Eig().symbolic_call(x)
115
+
116
+
117
def _eig(x):
    """Eager-path implementation of `eig`."""
    tensor = backend.convert_to_tensor(x)
    _assert_square(tensor)
    _assert_2d(tensor)
    return backend.linalg.eig(tensor)
122
+
123
+
124
class Eigh(Operation):
    """Symbolic-graph node for `keras.ops.eigh`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return _eigh(x)

    def compute_output_spec(self, x):
        # Eigenvalues have shape (..., M); eigenvectors keep (..., M, M).
        _assert_square(x)
        _assert_2d(x)
        return (
            KerasTensor(x.shape[:-1], x.dtype),
            KerasTensor(x.shape, x.dtype),
        )
138
+
139
+
140
@keras_export(["keras.ops.eigh", "keras.ops.linalg.eigh"])
def eigh(x):
    """Computes the eigenvalues and eigenvectors of a complex Hermitian
    matrix.

    Args:
        x: Input tensor of shape `(..., M, M)`.

    Returns:
        A tuple of two tensors: a tensor of shape `(..., M)` containing
        eigenvalues and a tensor of shape `(..., M, M)` containing eigenvectors.

    """
    if any_symbolic_tensors((x,)):
        return Eigh().symbolic_call(x)
    return _eigh(x)
155
+
156
+
157
def _eigh(x):
    """Eager-path implementation of `eigh`."""
    tensor = backend.convert_to_tensor(x)
    _assert_square(tensor)
    _assert_2d(tensor)
    return backend.linalg.eigh(tensor)
162
+
163
+
164
class Inv(Operation):
    """Symbolic-graph node for `keras.ops.inv`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return _inv(x)

    def compute_output_spec(self, x):
        # Matrix inversion preserves shape: (..., M, M) -> (..., M, M).
        _assert_2d(x)
        _assert_square(x)
        return KerasTensor(x.shape, x.dtype)
175
+
176
+
177
@keras_export(["keras.ops.inv", "keras.ops.linalg.inv"])
def inv(x):
    """Computes the inverse of a square tensor.

    Args:
        x: Input tensor of shape `(..., M, M)`.

    Returns:
        A tensor of shape `(..., M, M)` holding the inverse of `x`.
    """
    if not any_symbolic_tensors((x,)):
        return _inv(x)
    return Inv().symbolic_call(x)
191
+
192
+
193
def _inv(x):
    """Eager-path implementation of `inv`."""
    tensor = backend.convert_to_tensor(x)
    _assert_2d(tensor)
    _assert_square(tensor)
    return backend.linalg.inv(tensor)
198
+
199
+
200
class LuFactor(Operation):
    """Symbolic-graph node for `keras.ops.lu_factor`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return _lu_factor(x)

    def compute_output_spec(self, x):
        _assert_2d(x)
        batch_shape = x.shape[:-2]
        m, n = x.shape[-2:]
        # The packed LU factors keep the input shape; the pivot vector has
        # min(m, n) entries.
        k = min(m, n)
        return (
            KerasTensor(batch_shape + (m, n), x.dtype),
            KerasTensor(batch_shape + (k,), x.dtype),
        )
216
+
217
+
218
@keras_export(["keras.ops.lu_factor", "keras.ops.linalg.lu_factor"])
def lu_factor(x):
    """Computes the lower-upper decomposition of a square matrix.

    Args:
        x: A tensor of shape `(..., M, M)`.

    Returns:
        A tuple of two tensors: a tensor of shape `(..., M, M)` packing the
        lower and upper triangular factors, and a tensor of shape `(..., M)`
        containing the pivots.
    """
    if not any_symbolic_tensors((x,)):
        return _lu_factor(x)
    return LuFactor().symbolic_call(x)
234
+
235
+
236
def _lu_factor(x):
    """Eager-path implementation of `lu_factor`.

    Raises:
        ValueError: If the input has rank < 2, or if the TensorFlow backend
            is used with a non-square matrix (TF only supports square LU).
    """
    x = backend.convert_to_tensor(x)
    _assert_2d(x)
    if backend.backend() == "tensorflow":
        try:
            _assert_square(x)
        except ValueError as e:
            # Chain the original error so the shape details stay visible in
            # the traceback (was previously dropped).
            raise ValueError(
                f"LU decomposition failed: {e}. LU decomposition is only "
                "supported for square matrices in Tensorflow."
            ) from e
    return backend.linalg.lu_factor(x)
248
+
249
+
250
class Norm(Operation):
    """Symbolic-graph node for `keras.ops.norm`.

    Validates the `ord`/`axis` combination and computes the reduced output
    shape for matrix and vector norms.
    """

    def __init__(self, ord=None, axis=None, keepdims=False):
        super().__init__()
        if isinstance(ord, str):
            # Only the Frobenius and nuclear norms have string names.
            if ord not in ("fro", "nuc"):
                raise ValueError(
                    "Invalid `ord` argument. "
                    "Expected one of {'fro', 'nuc'} when using string. "
                    f"Received: ord={ord}"
                )
        if isinstance(axis, int):
            # Normalize a single axis to a list for uniform handling below.
            axis = [axis]
        self.ord = ord
        self.axis = axis
        self.keepdims = keepdims

    def compute_output_spec(self, x):
        output_dtype = backend.standardize_dtype(x.dtype)
        if "int" in output_dtype or output_dtype == "bool":
            # Norms of integer/boolean inputs are computed in float.
            output_dtype = backend.floatx()
        if self.axis is None:
            # No axis means reducing over every dimension.
            axis = tuple(range(len(x.shape)))
        else:
            axis = self.axis
        num_axes = len(axis)
        # One reduced axis -> vector norm; two -> matrix norm. Each accepts
        # a different set of `ord` values.
        if num_axes == 1 and isinstance(self.ord, str):
            raise ValueError(
                "Invalid `ord` argument for vector norm. "
                f"Received: ord={self.ord}"
            )
        elif num_axes == 2 and self.ord not in (
            None,
            "fro",
            "nuc",
            float("inf"),
            float("-inf"),
            1,
            -1,
            2,
            -2,
        ):
            raise ValueError(
                "Invalid `ord` argument for matrix norm. "
                f"Received: ord={self.ord}"
            )
        return KerasTensor(
            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),
            dtype=output_dtype,
        )

    def call(self, x):
        x = backend.convert_to_tensor(x)
        return backend.linalg.norm(
            x, ord=self.ord, axis=self.axis, keepdims=self.keepdims
        )
305
+
306
+
307
@keras_export(["keras.ops.norm", "keras.ops.linalg.norm"])
def norm(x, ord=None, axis=None, keepdims=False):
    """Matrix or vector norm.

    Depending on the value of the `ord` parameter, this function returns one
    of eight different matrix norms, or one of an infinite number of vector
    norms.

    Args:
        x: Input tensor.
        ord: Order of the norm (see note below). The default is `None`.
        axis: An integer selects the axis of `x` along which vector norms
            are computed; a 2-tuple selects the axes that hold 2-D matrices,
            whose matrix norms are computed.
        keepdims: If `True`, the reduced axes are left in the result as
            dimensions with size one.

    Note:
        For values of `ord < 1` the result is, strictly speaking, not a
        mathematical 'norm', but may still be useful numerically.
        Supported values:
        - Matrices: `None` / `"fro"` (Frobenius norm), `"nuc"` (nuclear
          norm), `np.inf` (`max(sum(abs(x), axis=1))`), `-np.inf`
          (`min(sum(abs(x), axis=1))`), `1` (`max(sum(abs(x), axis=0))`),
          `-1` (`min(sum(abs(x), axis=0))`), `2` (largest singular value),
          `-2` (smallest singular value). `0` and other values are not
          supported.
        - Vectors: `None` (2-norm), `np.inf` (`max(abs(x))`), `-np.inf`
          (`min(abs(x))`), `0` (`sum(x != 0)`), and any other numeric `ord`
          gives `sum(abs(x)**ord)**(1./ord)`. `"fro"` and `"nuc"` are not
          supported.

    Returns:
        Norm of the matrix or vector(s).

    Example:

    >>> x = keras.ops.reshape(keras.ops.arange(9, dtype="float32") - 4, (3, 3))
    >>> keras.ops.linalg.norm(x)
    7.7459664
    """
    if not any_symbolic_tensors((x,)):
        tensor = backend.convert_to_tensor(x)
        return backend.linalg.norm(
            tensor, ord=ord, axis=axis, keepdims=keepdims
        )
    return Norm(ord=ord, axis=axis, keepdims=keepdims).symbolic_call(x)
367
+
368
+
369
class Qr(Operation):
    """Symbolic-graph node for `keras.ops.qr`."""

    def __init__(self, mode="reduced"):
        super().__init__()
        if mode not in {"reduced", "complete"}:
            raise ValueError(
                "`mode` argument value not supported. "
                "Expected one of {'reduced', 'complete'}. "
                f"Received: mode={mode}"
            )
        self.mode = mode

    def compute_output_spec(self, x):
        if len(x.shape) < 2:
            raise ValueError(
                "Input should have rank >= 2. Received: "
                f"input.shape = {x.shape}"
            )
        m = x.shape[-2]
        n = x.shape[-1]
        if m is None or n is None:
            # The static output shapes below depend on m and n, so both
            # must be known at graph-construction time.
            raise ValueError(
                "Input should have its last 2 dimensions "
                "fully-defined. Received: "
                f"input.shape = {x.shape}"
            )
        k = min(m, n)
        base = tuple(x.shape[:-2])
        if self.mode == "reduced":
            # Reduced: q is (..., m, k), r is (..., k, n).
            return (
                KerasTensor(shape=base + (m, k), dtype=x.dtype),
                KerasTensor(shape=base + (k, n), dtype=x.dtype),
            )
        # 'complete' mode.
        return (
            KerasTensor(shape=base + (m, m), dtype=x.dtype),
            KerasTensor(shape=base + (m, n), dtype=x.dtype),
        )

    def call(self, x):
        x = backend.convert_to_tensor(x)
        return backend.linalg.qr(x, mode=self.mode)
410
+
411
+
412
@keras_export(["keras.ops.qr", "keras.ops.linalg.qr"])
def qr(x, mode="reduced"):
    """Computes the QR decomposition of a tensor.

    Args:
        x: Input tensor of shape `(..., M, N)`.
        mode: Either `"reduced"` (the default) for the reduced QR
            decomposition, or `"complete"` for the complete one.

    Returns:
        A tuple `(q, r)` where `q` is an orthogonal matrix of shape
        `(..., M, K)` and `r` is an upper triangular matrix of shape
        `(..., K, N)`, with `K = min(M, N)`.

    Example:

    >>> x = keras.ops.convert_to_tensor([[1., 2.], [3., 4.], [5., 6.]])
    >>> q, r = qr(x)
    >>> print(q)
    array([[-0.16903079 0.897085]
           [-0.5070925 0.2760267 ]
           [-0.8451542 -0.34503305]], shape=(3, 2), dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        tensor = backend.convert_to_tensor(x)
        return backend.linalg.qr(tensor, mode=mode)
    return Qr(mode=mode).symbolic_call(x)
440
+
441
+
442
class Solve(Operation):
    """Symbolic-graph node for `keras.ops.solve`."""

    def __init__(self):
        super().__init__()

    def call(self, a, b):
        return _solve(a, b)

    def compute_output_spec(self, a, b):
        # The solution has the same shape and dtype as the right-hand
        # side `b`.
        _assert_2d(a)
        _assert_square(a)
        _assert_1d(b)
        _assert_a_b_compat(a, b)
        return KerasTensor(b.shape, b.dtype)
455
+
456
+
457
@keras_export(["keras.ops.solve", "keras.ops.linalg.solve"])
def solve(a, b):
    """Solves a linear system of equations given by `a x = b`.

    Args:
        a: A tensor of shape `(..., M, M)` representing the coefficients
            matrix.
        b: A tensor of shape `(..., M)` or `(..., M, N)` representing the
            right-hand side or "dependent variable" matrix.

    Returns:
        A tensor with the same shape as `b` holding the solution of the
        linear system.
    """
    if not any_symbolic_tensors((a, b)):
        return _solve(a, b)
    return Solve().symbolic_call(a, b)
474
+
475
+
476
def _solve(a, b):
    """Eager-path implementation of `solve`."""
    lhs = backend.convert_to_tensor(a)
    rhs = backend.convert_to_tensor(b)
    _assert_2d(lhs)
    _assert_square(lhs)
    _assert_1d(rhs)
    _assert_a_b_compat(lhs, rhs)
    return backend.linalg.solve(lhs, rhs)
484
+
485
+
486
class SolveTriangular(Operation):
    """Symbolic-graph node for `keras.ops.solve_triangular`."""

    def __init__(self, lower=False):
        super().__init__()
        # Whether `a` is treated as lower triangular by the backend solver.
        self.lower = lower

    def call(self, a, b):
        return _solve_triangular(a, b, self.lower)

    def compute_output_spec(self, a, b):
        # The solution has the same shape and dtype as the right-hand
        # side `b`.
        _assert_2d(a)
        _assert_square(a)
        _assert_1d(b)
        _assert_a_b_compat(a, b)
        return KerasTensor(b.shape, b.dtype)
500
+
501
+
502
@keras_export(
    ["keras.ops.solve_triangular", "keras.ops.linalg.solve_triangular"]
)
def solve_triangular(a, b, lower=False):
    """Solves a linear system `a x = b`, where `a` is a triangular matrix.

    Args:
        a: A tensor of shape `(..., M, M)` representing the triangular
            coefficients matrix.
        b: A tensor of shape `(..., M)` or `(..., M, N)` representing the
            right-hand side or "dependent variable" matrix.
        lower: Whether `a` is a lower triangular matrix. Defaults to
            `False`.

    Returns:
        A tensor of shape `(..., M)` or `(..., M, N)` representing the
        solution of the linear system. Returned shape is identical to `b`.

    """
    if any_symbolic_tensors((a, b)):
        return SolveTriangular(lower).symbolic_call(a, b)
    return _solve_triangular(a, b, lower)
521
+
522
+
523
def _solve_triangular(a, b, lower=False):
    """Eager-path implementation of `solve_triangular`."""
    lhs = backend.convert_to_tensor(a)
    rhs = backend.convert_to_tensor(b)
    _assert_2d(lhs)
    _assert_square(lhs)
    _assert_1d(rhs)
    _assert_a_b_compat(lhs, rhs)
    return backend.linalg.solve_triangular(lhs, rhs, lower)
531
+
532
+
533
class SVD(Operation):
    """Symbolic-graph node for `keras.ops.svd`."""

    def __init__(self, full_matrices=True, compute_uv=True):
        super().__init__()
        self.full_matrices = full_matrices
        self.compute_uv = compute_uv

    def call(self, x):
        return _svd(x, self.full_matrices, self.compute_uv)

    def compute_output_spec(self, x):
        _assert_2d(x)
        rows, columns = x.shape[-2:]
        batches = x.shape[:-2]
        # There are min(rows, columns) singular values.
        s_shape = batches + (min(rows, columns),)
        if self.full_matrices:
            u_shape = batches + (rows, rows)
            v_shape = batches + (columns, columns)
        else:
            u_shape = batches + (rows, min(rows, columns))
            v_shape = batches + (min(rows, columns), columns)

        if self.compute_uv:
            return (
                KerasTensor(u_shape, x.dtype),
                KerasTensor(s_shape, x.dtype),
                KerasTensor(v_shape, x.dtype),
            )
        # Singular values only.
        return KerasTensor(s_shape, x.dtype)
561
+
562
+
563
@keras_export(["keras.ops.svd", "keras.ops.linalg.svd"])
def svd(x, full_matrices=True, compute_uv=True):
    """Computes the singular value decomposition of a matrix.

    Args:
        x: Input tensor of shape `(..., M, N)`.
        full_matrices: If `True` (the default), the returned singular-vector
            tensors have shapes `(..., M, M)` and `(..., N, N)`. If `False`,
            they have shapes `(..., M, K)` and `(..., K, N)`, where
            `K = min(M, N)`.
        compute_uv: If `True` (the default), return the singular vectors as
            well. If `False`, return only the singular values.

    Returns:
        A tuple of three tensors: a tensor containing the left singular
        vectors, a tensor of shape `(..., K)` containing the singular
        values (with `K = min(M, N)`), and a tensor containing the right
        singular vectors. If `compute_uv=False`, only the singular values
        are returned.

    """
    if any_symbolic_tensors((x,)):
        return SVD(full_matrices, compute_uv).symbolic_call(x)
    return _svd(x, full_matrices, compute_uv)
580
+
581
+
582
def _svd(x, full_matrices=True, compute_uv=True):
    """Eager-path implementation of `svd`."""
    tensor = backend.convert_to_tensor(x)
    _assert_2d(tensor)
    return backend.linalg.svd(tensor, full_matrices, compute_uv)
586
+
587
+
588
class Lstsq(Operation):
    """Symbolic-graph node for `keras.ops.lstsq`."""

    def __init__(self, rcond=None):
        super().__init__()
        # Cut-off ratio for small singular values; forwarded to the backend.
        self.rcond = rcond

    def call(self, a, b):
        return backend.linalg.lstsq(a, b, rcond=self.rcond)

    def compute_output_spec(self, a, b):
        if len(a.shape) != 2:
            raise ValueError(
                "Expected a to have rank 2. " f"Received: a.shape={a.shape}"
            )
        if len(b.shape) not in (1, 2):
            raise ValueError(
                "Expected b to have rank 1 or 2. "
                f"Received: b.shape={b.shape}"
            )
        m, n = a.shape
        if b.shape[0] != m:
            raise ValueError(
                "Expected b.shape[0] to be equal to "
                "a.shape[0]. Received: "
                f"a.shape={a.shape}, b.shape={b.shape}"
            )
        # Solution shape mirrors b: (n,) for a vector rhs, (n, k) for a
        # matrix rhs with k columns.
        if len(b.shape) == 2:
            k = b.shape[1]
            x = KerasTensor((n, k), dtype=a.dtype)
        else:
            x = KerasTensor((n,), dtype=a.dtype)
        return x
619
+
620
+
621
@keras_export(["keras.ops.lstsq", "keras.ops.linalg.lstsq"])
def lstsq(a, b, rcond=None):
    """Return the least-squares solution to a linear matrix equation.

    Computes the vector `x` that approximately solves `a @ x = b`. The
    equation may be under-, well-, or over-determined (i.e., the number of
    linearly independent rows of `a` can be less than, equal to, or greater
    than its number of linearly independent columns). If `a` is square and
    of full rank, then `x` is (but for round-off error) the exact solution
    of the equation; otherwise `x` minimizes the L2 norm of `b - a @ x`.
    If there are multiple minimizing solutions, the one with the smallest
    L2 norm is returned.

    Args:
        a: "Coefficient" matrix of shape `(M, N)`.
        b: Ordinate or "dependent variable" values, of shape `(M,)` or
            `(M, K)`. If `b` is two-dimensional, the least-squares solution
            is calculated for each of the K columns of `b`.
        rcond: Cut-off ratio for small singular values of `a`. For the
            purposes of rank determination, singular values are treated as
            zero if they are smaller than `rcond` times the largest
            singular value of `a`.

    Returns:
        Tensor with shape `(N,)` or `(N, K)` containing the least-squares
        solutions.

    **NOTE:** The output differs from `numpy.linalg.lstsq`: NumPy returns a
    tuple with four elements, the first being the least-squares solutions
    and the others being essentially never used. Keras only returns the
    first value, both to ensure consistency across backends (which cannot
    be achieved for the other values) and to simplify the API.
    """
    if not any_symbolic_tensors((a, b)):
        return backend.linalg.lstsq(a, b, rcond=rcond)
    return Lstsq(rcond=rcond).symbolic_call(a, b)
663
+
664
+
665
+ def _assert_1d(*arrays):
666
+ for a in arrays:
667
+ if a.ndim < 1:
668
+ raise ValueError(
669
+ "Expected input to have rank >= 1. "
670
+ "Received scalar input {a}."
671
+ )
672
+
673
+
674
+ def _assert_2d(*arrays):
675
+ for a in arrays:
676
+ if a.ndim < 2:
677
+ raise ValueError(
678
+ "Expected input to have rank >= 2. "
679
+ "Received input with shape {a.shape}."
680
+ )
681
+
682
+
683
+ def _assert_square(*arrays):
684
+ for a in arrays:
685
+ m, n = a.shape[-2:]
686
+ if m != n:
687
+ raise ValueError(
688
+ "Expected a square matrix. "
689
+ f"Received non-square input with shape {a.shape}"
690
+ )
691
+
692
+
693
+ def _assert_a_b_compat(a, b):
694
+ if a.ndim == b.ndim:
695
+ if a.shape[-2] != b.shape[-2]:
696
+ raise ValueError(
697
+ "Incompatible shapes between `a` and `b`. "
698
+ "Expected `a.shape[-2] == b.shape[-2]`. "
699
+ f"Received: a.shape={a.shape}, b.shape={b.shape}"
700
+ )
701
+ elif a.ndim == b.ndim - 1:
702
+ if a.shape[-1] != b.shape[-1]:
703
+ raise ValueError(
704
+ "Incompatible shapes between `a` and `b`. "
705
+ "Expected `a.shape[-1] == b.shape[-1]`. "
706
+ f"Received: a.shape={a.shape}, b.shape={b.shape}"
707
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/math.py ADDED
@@ -0,0 +1,1046 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Commonly used math operations not included in NumPy."""
2
+
3
+ from keras.src import backend
4
+ from keras.src.api_export import keras_export
5
+ from keras.src.backend import KerasTensor
6
+ from keras.src.backend import any_symbolic_tensors
7
+ from keras.src.ops.operation import Operation
8
+ from keras.src.ops.operation_utils import reduce_shape
9
+
10
+
11
+ def _segment_reduce_validation(data, segment_ids):
12
+ data_shape = data.shape
13
+ segment_ids_shape = segment_ids.shape
14
+ if len(segment_ids_shape) > 1:
15
+ raise ValueError(
16
+ "Argument `segment_ids` should be an 1-D vector, got shape: "
17
+ f"{len(segment_ids_shape)}. Consider either flatten input with "
18
+ "segment_ids.reshape((-1)) and "
19
+ "data.reshape((-1, ) + data.shape[len(segment_ids.shape):]) or "
20
+ "vectorize with vmap."
21
+ )
22
+ if (
23
+ segment_ids_shape[0] is not None
24
+ and data_shape[0] is not None
25
+ and segment_ids_shape[0] != data_shape[0]
26
+ ):
27
+ raise ValueError(
28
+ "Argument `segment_ids` and `data` should have same leading "
29
+ f"dimension. Got {segment_ids_shape} v.s. "
30
+ f"{data_shape}."
31
+ )
32
+
33
+
34
class SegmentReduction(Operation):
    """Base class for segment reduction ops (e.g. segment sum/max).

    Stores the shared configuration: `num_segments`, the static leading
    dimension of the output (may be `None`), and `sorted`, whether the
    `segment_ids` argument is pre-sorted.
    """

    def __init__(self, num_segments=None, sorted=False):
        super().__init__()
        self.num_segments = num_segments
        self.sorted = sorted

    def compute_output_spec(self, data, _):
        # The output replaces the leading axis of `data` with
        # `num_segments`; all other axes are preserved.
        output_shape = (self.num_segments,) + tuple(data.shape[1:])
        return KerasTensor(shape=output_shape, dtype=data.dtype)
43
+
44
+
45
class SegmentSum(SegmentReduction):
    """Backend-dispatching op computing a segment-wise sum.

    See `keras.ops.segment_sum` for the functional interface.
    """

    def call(self, data, segment_ids):
        # Validate eagerly before dispatching to the backend.
        _segment_reduce_validation(data, segment_ids)
        return backend.math.segment_sum(
            data,
            segment_ids,
            num_segments=self.num_segments,
            sorted=self.sorted,
        )
54
+
55
+
56
@keras_export("keras.ops.segment_sum")
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
    """Computes the sum of segments in a tensor.

    Args:
        data: Input tensor.
        segment_ids: A 1-D tensor of segment indices, one per leading-axis
            entry of `data`; its length must match `data.shape[0]`.
        num_segments: Total number of segments. If not specified, it is
            inferred from the maximum value in `segment_ids`.
        sorted: Whether `segment_ids` is already sorted. Defaults to
            `False`.

    Returns:
        A tensor whose entry `i` is the sum of all `data` rows whose
        segment id equals `i`.

    Example:

    >>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])
    >>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])
    >>> keras.ops.segment_sum(data, segment_ids, 3)
    array([3, 30, 300], dtype=int32)
    """
    _segment_reduce_validation(data, segment_ids)
    if any_symbolic_tensors((data,)):
        op = SegmentSum(num_segments, sorted)
        return op.symbolic_call(data, segment_ids)
    return backend.math.segment_sum(
        data, segment_ids, num_segments=num_segments, sorted=sorted
    )
89
+
90
+
91
class SegmentMax(SegmentReduction):
    """Backend-dispatching op computing a segment-wise maximum.

    See `keras.ops.segment_max` for the functional interface.
    """

    def call(self, data, segment_ids):
        # Validate eagerly before dispatching to the backend.
        _segment_reduce_validation(data, segment_ids)
        return backend.math.segment_max(
            data,
            segment_ids,
            num_segments=self.num_segments,
            sorted=self.sorted,
        )
100
+
101
+
102
@keras_export("keras.ops.segment_max")
def segment_max(data, segment_ids, num_segments=None, sorted=False):
    """Computes the max of segments in a tensor.

    Args:
        data: Input tensor.
        segment_ids: A 1-D tensor of segment indices, one per leading-axis
            entry of `data`; its length must match `data.shape[0]`.
        num_segments: Total number of segments. If not specified, it is
            inferred from the maximum value in `segment_ids`.
        sorted: Whether `segment_ids` is already sorted. Defaults to
            `False`.

    Returns:
        A tensor whose entry `i` is the maximum over all `data` rows whose
        segment id equals `i`.

    Example:

    >>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])
    >>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])
    >>> keras.ops.segment_max(data, segment_ids, 3)
    array([2, 20, 200], dtype=int32)
    """
    _segment_reduce_validation(data, segment_ids)
    if any_symbolic_tensors((data,)):
        op = SegmentMax(num_segments, sorted)
        return op.symbolic_call(data, segment_ids)
    return backend.math.segment_max(
        data, segment_ids, num_segments=num_segments, sorted=sorted
    )
134
+
135
+
136
class TopK(Operation):
    """Op returning the `k` largest values and their indices (last axis).

    See `keras.ops.top_k` for the functional interface.
    """

    def __init__(self, k, sorted=False):
        super().__init__()
        # NOTE: the default `sorted=False` here differs from the wrapper
        # `keras.ops.top_k` (which defaults to `True`); the wrapper always
        # passes `sorted` explicitly, so the defaults never conflict.
        self.k = k
        self.sorted = sorted

    def compute_output_spec(self, x):
        # Every axis of `x` is preserved except the last, which becomes k.
        output_shape = list(x.shape)
        output_shape[-1] = self.k
        # Return a tuple (values, indices).
        return (
            KerasTensor(shape=output_shape, dtype=x.dtype),
            KerasTensor(shape=output_shape, dtype="int32"),
        )

    def call(self, x):
        return backend.math.top_k(x, self.k, self.sorted)
153
+
154
+
155
@keras_export("keras.ops.top_k")
def top_k(x, k, sorted=True):
    """Finds the top-k values and their indices in a tensor.

    Args:
        x: Input tensor.
        k: Number of largest elements to retrieve along the last axis.
        sorted: Whether to return the values in descending order.
            Defaults to `True`.

    Returns:
        A `(values, indices)` tuple: the top-k values, and the positions of
        those values in the input tensor.

    Example:

    >>> x = keras.ops.convert_to_tensor([5, 2, 7, 1, 9, 3])
    >>> values, indices = top_k(x, k=3)
    >>> print(values)
    array([9 7 5], shape=(3,), dtype=int32)
    >>> print(indices)
    array([4 2 0], shape=(3,), dtype=int32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.math.top_k(x, k, sorted)
    return TopK(k, sorted).symbolic_call(x)
183
+
184
+
185
class InTopK(Operation):
    """Op testing whether each target is among the top-`k` predictions.

    See `keras.ops.in_top_k` for the functional interface.
    """

    def __init__(self, k):
        super().__init__()
        self.k = k

    def compute_output_spec(self, targets, predictions):
        # One boolean per target; the output shape follows `targets`.
        return KerasTensor(shape=targets.shape, dtype="bool")

    def call(self, targets, predictions):
        return backend.math.in_top_k(targets, predictions, self.k)
195
+
196
+
197
@keras_export("keras.ops.in_top_k")
def in_top_k(targets, predictions, k):
    """Checks if the targets are in the top-k predictions.

    Args:
        targets: Tensor of true labels.
        predictions: Tensor of predicted scores, one row per target.
        k: Number of top predictions to consider.

    Returns:
        A boolean tensor shaped like `targets`; element `i` is `True` when
        target `i` appears among the top-k entries of `predictions[i]`.

    Example:

    >>> targets = keras.ops.convert_to_tensor([2, 5, 3])
    >>> predictions = keras.ops.convert_to_tensor(
    ... [[0.1, 0.4, 0.6, 0.9, 0.5],
    ...  [0.1, 0.7, 0.9, 0.8, 0.3],
    ...  [0.1, 0.6, 0.9, 0.9, 0.5]])
    >>> in_top_k(targets, predictions, k=3)
    array([ True False  True], shape=(3,), dtype=bool)
    """
    if not any_symbolic_tensors((targets, predictions)):
        return backend.math.in_top_k(targets, predictions, k)
    return InTopK(k).symbolic_call(targets, predictions)
223
+
224
+
225
class Logsumexp(Operation):
    """Op computing `log(sum(exp(x)))` over the configured axes.

    See `keras.ops.logsumexp` for the functional interface.
    """

    def __init__(self, axis=None, keepdims=False):
        super().__init__()
        self.axis = axis
        self.keepdims = keepdims

    def compute_output_spec(self, x):
        output_shape = reduce_shape(x.shape, self.axis, self.keepdims)
        # NOTE(review): no dtype is passed here, so the spec falls back to
        # the KerasTensor default rather than `x.dtype` — confirm whether
        # this is intentional.
        return KerasTensor(shape=output_shape)

    def call(self, x):
        return backend.math.logsumexp(x, axis=self.axis, keepdims=self.keepdims)
237
+
238
+
239
@keras_export("keras.ops.logsumexp")
def logsumexp(x, axis=None, keepdims=False):
    """Computes the logarithm of sum of exponentials of elements in a tensor.

    Args:
        x: Input tensor.
        axis: Integer or tuple of integers giving the axis/axes to reduce
            over. `None` (the default) reduces over all elements.
        keepdims: If `True`, the reduced axes are kept with size 1.
            Defaults to `False`.

    Returns:
        A tensor holding `log(sum(exp(x)))` over the requested axes.

    Example:

    >>> x = keras.ops.convert_to_tensor([1., 2., 3.])
    >>> logsumexp(x)
    3.407606
    """
    if not any_symbolic_tensors((x,)):
        return backend.math.logsumexp(x, axis=axis, keepdims=keepdims)
    return Logsumexp(axis, keepdims).symbolic_call(x)
264
+
265
+
266
class ExtractSequences(Operation):
    """Op slicing the last axis into overlapping fixed-length windows.

    See `keras.ops.extract_sequences` for the functional interface.
    """

    def __init__(self, sequence_length, sequence_stride):
        super().__init__()
        self.sequence_length = sequence_length
        self.sequence_stride = sequence_stride

    def compute_output_spec(self, x):
        if len(x.shape) < 1:
            raise ValueError(
                f"Input should have rank >= 1. "
                f"Received: input.shape = {x.shape}"
            )
        if x.shape[-1] is not None:
            # Number of full windows that fit along the last axis.
            num_sequences = (
                1 + (x.shape[-1] - self.sequence_length) // self.sequence_stride
            )
        else:
            # Dynamic last dimension: window count is unknown statically.
            num_sequences = None
        new_shape = x.shape[:-1] + (num_sequences, self.sequence_length)
        return KerasTensor(shape=new_shape, dtype=x.dtype)

    def call(self, x):
        return backend.math.extract_sequences(
            x,
            sequence_length=self.sequence_length,
            sequence_stride=self.sequence_stride,
        )
293
+
294
+
295
@keras_export("keras.ops.extract_sequences")
def extract_sequences(x, sequence_length, sequence_stride):
    """Expands the dimension of last axis into sequences of `sequence_length`.

    A window of size `sequence_length` slides over the last axis of the
    input with hop `sequence_stride`; the last axis is replaced by two
    axes, `[num_sequences, sequence_length]`. For a last-axis size of N:

    `num_sequences = 1 + (N - sequence_length) // sequence_stride`

    Args:
        x: Input tensor.
        sequence_length: Integer length of each extracted window.
        sequence_stride: Integer hop between consecutive windows.

    Returns:
        A tensor of shape `[..., num_sequences, sequence_length]`.

    Example:

    >>> x = keras.ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
    >>> extract_sequences(x, 3, 2)
    array([[1, 2, 3],
       [3, 4, 5]])
    """
    if not any_symbolic_tensors((x,)):
        return backend.math.extract_sequences(
            x, sequence_length, sequence_stride
        )
    op = ExtractSequences(sequence_length, sequence_stride)
    return op.symbolic_call(x)
328
+
329
+
330
class FFT(Operation):
    """Op computing a 1D FFT over a `(real, imag)` pair of tensors.

    See `keras.ops.fft` for the functional interface.
    """

    def __init__(self, axis=-1):
        super().__init__()
        # NOTE(review): `axis` is stored but the shape checks below and the
        # backend call operate on the last axis; `axis` only appears in the
        # error message. Confirm whether non-default values are supported.
        self.axis = axis

    def compute_output_spec(self, x):
        # The input must be a pair (real, imag) of equal-shaped tensors.
        if not isinstance(x, (tuple, list)) or len(x) != 2:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                f"imaginary. Received: x={x}"
            )

        real, imag = x
        # Both real and imaginary parts should have the same shape.
        if real.shape != imag.shape:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                "imaginary. Both the real and imaginary parts should have the "
                f"same shape. Received: x[0].shape = {real.shape}, "
                f"x[1].shape = {imag.shape}"
            )

        # We are calculating 1D FFT. Hence, rank >= 1.
        if len(real.shape) < 1:
            raise ValueError(
                f"Input should have rank >= 1. "
                f"Received: input.shape = {real.shape}"
            )

        # The axis along which we are calculating FFT should be fully-defined.
        m = real.shape[-1]
        if m is None:
            raise ValueError(
                f"Input should have its {self.axis}th axis fully-defined. "
                f"Received: input.shape = {real.shape}"
            )

        # FFT preserves the input shape; output is also a (real, imag) pair.
        return (
            KerasTensor(shape=real.shape, dtype=real.dtype),
            KerasTensor(shape=imag.shape, dtype=imag.dtype),
        )

    def call(self, x):
        return backend.math.fft(x)
374
+
375
+
376
@keras_export("keras.ops.fft")
def fft(x):
    """Computes the Fast Fourier Transform along last axis of input.

    Args:
        x: Tuple `(real, imag)` of two floating-point tensors holding the
            real and imaginary parts of the signal.

    Returns:
        A tuple `(real, imag)` with the real and imaginary parts of the
        transformed signal.

    Example:

    >>> x = (
    ...     keras.ops.convert_to_tensor([1., 2.]),
    ...     keras.ops.convert_to_tensor([0., 1.]),
    ... )
    >>> fft(x)
    (array([ 3., -1.], dtype=float32), array([ 1., -1.], dtype=float32))
    """
    if not any_symbolic_tensors(x):
        return backend.math.fft(x)
    return FFT().symbolic_call(x)
400
+
401
+
402
class FFT2(Operation):
    """Op computing a 2D FFT over the last two axes of a `(real, imag)` pair.

    See `keras.ops.fft2` for the functional interface.
    """

    def __init__(self):
        super().__init__()
        # The transform always runs over the last two axes.
        self.axes = (-2, -1)

    def compute_output_spec(self, x):
        # The input must be a pair (real, imag) of equal-shaped tensors.
        if not isinstance(x, (tuple, list)) or len(x) != 2:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                f"imaginary. Received: x={x}"
            )

        real, imag = x
        # Both real and imaginary parts should have the same shape.
        if real.shape != imag.shape:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                "imaginary. Both the real and imaginary parts should have the "
                f"same shape. Received: x[0].shape = {real.shape}, "
                f"x[1].shape = {imag.shape}"
            )
        # We are calculating 2D FFT. Hence, rank >= 2.
        if len(real.shape) < 2:
            raise ValueError(
                f"Input should have rank >= 2. "
                f"Received: input.shape = {real.shape}"
            )

        # The axes along which we are calculating FFT should be fully-defined.
        m = real.shape[self.axes[0]]
        n = real.shape[self.axes[1]]
        if m is None or n is None:
            raise ValueError(
                f"Input should have its {self.axes} axes fully-defined. "
                f"Received: input.shape = {real.shape}"
            )

        # FFT2 preserves the input shape; output is also a (real, imag) pair.
        return (
            KerasTensor(shape=real.shape, dtype=real.dtype),
            KerasTensor(shape=imag.shape, dtype=imag.dtype),
        )

    def call(self, x):
        return backend.math.fft2(x)
446
+
447
+
448
@keras_export("keras.ops.fft2")
def fft2(x):
    """Computes the 2D Fast Fourier Transform along the last two axes of input.

    Args:
        x: Tuple `(real, imag)` of two floating-point tensors holding the
            real and imaginary parts of the signal.

    Returns:
        A tuple `(real, imag)` with the real and imaginary parts of the
        transformed signal.

    Example:

    >>> x = (
    ...     keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]),
    ...     keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]),
    ... )
    >>> fft2(x)
    (array([[ 6.,  0.],
        [ 0., -2.]], dtype=float32), array([[ 2.,  0.],
        [ 0., -2.]], dtype=float32))
    """
    if not any_symbolic_tensors(x):
        return backend.math.fft2(x)
    return FFT2().symbolic_call(x)
474
+
475
+
476
class IFFT2(Operation):
    """Op computing a 2D inverse FFT over the last two axes of a
    `(real, imag)` pair.

    See `keras.ops.ifft2` for the functional interface.
    """

    def __init__(self):
        super().__init__()
        # The transform always runs over the last two axes.
        self.axes = (-2, -1)

    def compute_output_spec(self, x):
        # The input must be a pair (real, imag) of equal-shaped tensors.
        if not isinstance(x, (tuple, list)) or len(x) != 2:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                f"imaginary. Received: x={x}"
            )

        real, imag = x
        # Both real and imaginary parts should have the same shape.
        if real.shape != imag.shape:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                "imaginary. Both the real and imaginary parts should have the "
                f"same shape. Received: x[0].shape = {real.shape}, "
                f"x[1].shape = {imag.shape}"
            )
        # We are calculating 2D IFFT. Hence, rank >= 2.
        if len(real.shape) < 2:
            raise ValueError(
                f"Input should have rank >= 2. "
                f"Received: input.shape = {real.shape}"
            )

        # The axes along which we are calculating IFFT should be fully-defined.
        m = real.shape[self.axes[0]]
        n = real.shape[self.axes[1]]
        if m is None or n is None:
            raise ValueError(
                f"Input should have its {self.axes} axes fully-defined. "
                f"Received: input.shape = {real.shape}"
            )

        # IFFT2 preserves the input shape; output is also a (real, imag) pair.
        return (
            KerasTensor(shape=real.shape, dtype=real.dtype),
            KerasTensor(shape=imag.shape, dtype=imag.dtype),
        )

    def call(self, x):
        return backend.math.ifft2(x)
520
+
521
+
522
@keras_export("keras.ops.ifft2")
def ifft2(x):
    """Computes the 2D Inverse Fast Fourier Transform along the last two axes
    of input.

    Args:
        x: Tuple `(real, imag)` of two floating-point tensors holding the
            real and imaginary parts of the signal.

    Returns:
        A tuple `(real, imag)` with the real and imaginary parts of the
        inverse-transformed signal.

    Example:

    >>> x = (
    ...     keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]),
    ...     keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]),
    ... )
    >>> ifft2(x)
    (array([[ 6.,  0.],
        [ 0., -2.]], dtype=float32), array([[ 2.,  0.],
        [ 0., -2.]], dtype=float32))
    """
    if not any_symbolic_tensors(x):
        return backend.math.ifft2(x)
    return IFFT2().symbolic_call(x)
549
+
550
+
551
class RFFT(Operation):
    """Op computing the real-valued FFT along the last axis.

    The output's last dimension is `fft_length // 2 + 1` — only the unique
    Hermitian components are returned. See `keras.ops.rfft`.
    """

    def __init__(self, fft_length=None):
        super().__init__()
        self.fft_length = fft_length

    def compute_output_spec(self, x):
        # We are calculating 1D RFFT. Hence, rank >= 1.
        if len(x.shape) < 1:
            raise ValueError(
                f"Input should have rank >= 1. "
                f"Received: input.shape = {x.shape}"
            )

        if self.fft_length is not None:
            # Explicit FFT length determines the output size directly.
            new_last_dimension = self.fft_length // 2 + 1
        else:
            if x.shape[-1] is not None:
                # Default FFT length is the static size of the last axis.
                new_last_dimension = x.shape[-1] // 2 + 1
            else:
                # Dynamic last axis: output size is unknown statically.
                new_last_dimension = None
        new_shape = x.shape[:-1] + (new_last_dimension,)

        # Output is a (real, imag) pair.
        return (
            KerasTensor(shape=new_shape, dtype=x.dtype),
            KerasTensor(shape=new_shape, dtype=x.dtype),
        )

    def call(self, x):
        return backend.math.rfft(x, fft_length=self.fft_length)
580
+
581
+
582
@keras_export("keras.ops.rfft")
def rfft(x, fft_length=None):
    """Real-valued Fast Fourier Transform along the last axis of the input.

    Computes the 1D Discrete Fourier Transform of a real-valued signal over
    the inner-most dimension of the input. Because the DFT of a real signal
    is Hermitian-symmetric, only the `fft_length / 2 + 1` unique components
    are returned: the zero-frequency term followed by the
    positive-frequency terms.

    Along the transformed axis, the input is cropped when `fft_length` is
    smaller than the input size, and zero-padded when it is larger.

    Args:
        x: Input tensor.
        fft_length: Integer FFT length. When `None` (the default), it is
            taken from the size of the last axis of `x`.

    Returns:
        A tuple `(real, imag)` with the real and imaginary parts of the
        output.

    Examples:

    >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    >>> rfft(x)
    (array([10.0, -2.5, -2.5]), array([0.0, 3.4409548, 0.81229924]))

    >>> rfft(x, 3)
    (array([3.0, -1.5]), array([0.0, 0.8660254]))
    """
    if not any_symbolic_tensors((x,)):
        return backend.math.rfft(x, fft_length)
    return RFFT(fft_length).symbolic_call(x)
620
+
621
+
622
class IRFFT(Operation):
    """Op computing the inverse real-valued FFT along the last axis.

    See `keras.ops.irfft` for the functional interface.
    """

    def __init__(self, fft_length=None):
        super().__init__()
        self.fft_length = fft_length

    def compute_output_spec(self, x):
        # The input must be a pair (real, imag) of equal-shaped tensors.
        if not isinstance(x, (tuple, list)) or len(x) != 2:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                f"imaginary. Received: x={x}"
            )
        real, imag = x
        # Both real and imaginary parts should have the same shape.
        if real.shape != imag.shape:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                "imaginary. Both the real and imaginary parts should have the "
                f"same shape. Received: x[0].shape = {real.shape}, "
                f"x[1].shape = {imag.shape}"
            )
        # We are calculating 1D IRFFT. Hence, rank >= 1.
        if len(real.shape) < 1:
            raise ValueError(
                f"Input should have rank >= 1. "
                f"Received: input.shape = {real.shape}"
            )

        if self.fft_length is not None:
            # Explicit FFT length determines the output size directly.
            new_last_dimension = self.fft_length
        else:
            if real.shape[-1] is not None:
                # Default: invert RFFT's `n // 2 + 1` mapping (even n).
                new_last_dimension = 2 * (real.shape[-1] - 1)
            else:
                # Dynamic last axis: output size is unknown statically.
                new_last_dimension = None
        new_shape = real.shape[:-1] + (new_last_dimension,)
        return KerasTensor(shape=new_shape, dtype=real.dtype)

    def call(self, x):
        return backend.math.irfft(x, fft_length=self.fft_length)
661
+
662
+
663
@keras_export("keras.ops.irfft")
def irfft(x, fft_length=None):
    """Inverse real-valued Fast Fourier transform along the last axis.

    Computes the inverse 1D Discrete Fourier Transform of a real-valued
    signal over the inner-most dimension of the input.

    The inner-most input dimension is assumed to hold the result of RFFT:
    the `fft_length / 2 + 1` unique components of the DFT of a real signal.
    When `fft_length` is not provided, it is inferred as
    `2 * (inner - 1)`; pass it explicitly if the original FFT length was
    odd, since that cannot be inferred.

    Along the transformed axis, the input is cropped when
    `fft_length / 2 + 1` is smaller than its size, and zero-padded when it
    is larger.

    Args:
        x: Tuple `(real, imag)` of two floating-point tensors.
        fft_length: Integer FFT length. When `None` (the default), it is
            inferred from the size of the last axis of `x`.

    Returns:
        A tensor containing the inverse real-valued Fast Fourier Transform
        along the last axis of `x`.

    Examples:

    >>> real = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    >>> imag = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    >>> irfft((real, imag))
    array([0.66666667, -0.9106836, 0.24401694])

    >>> irfft(rfft(real, 5), 5)
    array([0.0, 1.0, 2.0, 3.0, 4.0])
    """
    if not any_symbolic_tensors(x):
        return backend.math.irfft(x, fft_length)
    return IRFFT(fft_length).symbolic_call(x)
705
+
706
+
707
class STFT(Operation):
    """Op computing the Short-Time Fourier Transform along the last axis.

    See `keras.ops.stft` for the functional interface.
    """

    def __init__(
        self,
        sequence_length,
        sequence_stride,
        fft_length,
        window="hann",
        center=True,
    ):
        super().__init__()
        self.sequence_length = sequence_length
        self.sequence_stride = sequence_stride
        self.fft_length = fft_length
        self.window = window
        self.center = center

    def compute_output_spec(self, x):
        if x.shape[-1] is not None:
            # When centered, `fft_length // 2` samples of padding are
            # accounted for on each side of the signal — presumably
            # matching the backend's padding; confirm against the backends.
            padded = 0 if self.center is False else (self.fft_length // 2) * 2
            num_sequences = (
                1
                + (x.shape[-1] + padded - self.fft_length)
                // self.sequence_stride
            )
        else:
            # Dynamic last axis: the number of frames is unknown statically.
            num_sequences = None
        # Each frame keeps the `fft_length // 2 + 1` unique RFFT bins.
        new_shape = x.shape[:-1] + (num_sequences, self.fft_length // 2 + 1)
        # Output is a (real, imag) pair.
        return (
            KerasTensor(shape=new_shape, dtype=x.dtype),
            KerasTensor(shape=new_shape, dtype=x.dtype),
        )

    def call(self, x):
        return backend.math.stft(
            x,
            sequence_length=self.sequence_length,
            sequence_stride=self.sequence_stride,
            fft_length=self.fft_length,
            window=self.window,
            center=self.center,
        )
748
+
749
+
750
@keras_export("keras.ops.stft")
def stft(
    x, sequence_length, sequence_stride, fft_length, window="hann", center=True
):
    """Short-Time Fourier Transform along the last axis of the input.

    The STFT computes the Fourier transform of short, overlapping windows
    of the input, giving the frequency content of the signal as it changes
    over time.

    Args:
        x: Input tensor.
        sequence_length: An integer representing the sequence length.
        sequence_stride: An integer representing the sequence hop size.
        fft_length: An integer representing the size of the FFT to apply.
            If not specified, uses the smallest power of 2 enclosing
            `sequence_length`.
        window: A string, a tensor of the window or `None`. If `window` is
            a string, available values are `"hann"` and `"hamming"`. If
            `window` is a tensor, it will be used directly as the window
            and its length must be `sequence_length`. If `window` is
            `None`, no windowing is used. Defaults to `"hann"`.
        center: Whether to pad `x` on both sides so that the t-th sequence
            is centered at time `t * sequence_stride`. Otherwise, the t-th
            sequence begins at time `t * sequence_stride`. Defaults to
            `True`.

    Returns:
        A tuple `(real, imag)` with the real and imaginary parts of the
        STFT output.

    Example:

    >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    >>> stft(x, 3, 2, 3)
    (array([[0.75, -0.375],
       [3.75, -1.875],
       [5.25, -2.625]]), array([[0.0, 0.64951905],
       [0.0, 0.64951905],
       [0.0, -0.64951905]]))
    """
    if not any_symbolic_tensors((x,)):
        return backend.math.stft(
            x,
            sequence_length=sequence_length,
            sequence_stride=sequence_stride,
            fft_length=fft_length,
            window=window,
            center=center,
        )
    return STFT(
        sequence_length=sequence_length,
        sequence_stride=sequence_stride,
        fft_length=fft_length,
        window=window,
        center=center,
    ).symbolic_call(x)
805
+
806
+
807
class ISTFT(Operation):
    """Op computing the inverse Short-Time Fourier Transform.

    See `keras.ops.istft` for the functional interface.
    """

    def __init__(
        self,
        sequence_length,
        sequence_stride,
        fft_length,
        length=None,
        window="hann",
        center=True,
    ):
        super().__init__()
        self.sequence_length = sequence_length
        self.sequence_stride = sequence_stride
        self.fft_length = fft_length
        self.length = length
        self.window = window
        self.center = center

    def compute_output_spec(self, x):
        # The input must be a pair (real, imag) of equal-shaped tensors.
        if not isinstance(x, (tuple, list)) or len(x) != 2:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                f"imaginary. Received: x={x}"
            )
        real, imag = x
        # Both real and imaginary parts should have the same shape.
        if real.shape != imag.shape:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                "imaginary. Both the real and imaginary parts should have the "
                f"same shape. Received: x[0].shape = {real.shape}, "
                f"x[1].shape = {imag.shape}"
            )
        if len(real.shape) < 2:
            raise ValueError(
                f"Input should have rank >= 2. "
                f"Received: input.shape = {real.shape}"
            )
        if real.shape[-2] is not None:
            # Overlap-add reconstruction length from the frame count.
            output_size = (
                real.shape[-2] - 1
            ) * self.sequence_stride + self.fft_length
            if self.length is not None:
                # An explicit `length` overrides the computed size.
                output_size = self.length
            elif self.center:
                # Remove the centering padding added by the forward STFT.
                output_size = output_size - (self.fft_length // 2) * 2
        else:
            # Dynamic frame count: the output size is unknown statically.
            output_size = None
        new_shape = real.shape[:-2] + (output_size,)
        return KerasTensor(shape=new_shape, dtype=real.dtype)

    def call(self, x):
        return backend.math.istft(
            x,
            sequence_length=self.sequence_length,
            sequence_stride=self.sequence_stride,
            fft_length=self.fft_length,
            length=self.length,
            window=self.window,
            center=self.center,
        )
868
+
869
+
870
@keras_export("keras.ops.istft")
def istft(
    x,
    sequence_length,
    sequence_stride,
    fft_length,
    length=None,
    window="hann",
    center=True,
):
    """Inverse Short-Time Fourier Transform along the last axis of the input.

    To reconstruct an original waveform, the parameters should be the same
    as in `stft`.

    Args:
        x: Tuple of the real and imaginary parts of the input tensor. Both
            tensors in the tuple should be of floating type.
        sequence_length: An integer representing the sequence length.
        sequence_stride: An integer representing the sequence hop size.
        fft_length: An integer representing the size of the FFT that
            produced `stft`. Should be of type `int32`.
        length: An integer representing the output is clipped to exactly
            length. If not specified, no padding or clipping take place.
            Defaults to `None`.
        window: A string, a tensor of the window or `None`. If `window` is
            a string, available values are `"hann"` and `"hamming"`. If
            `window` is a tensor, it will be used directly as the window
            and its length must be `sequence_length`. If `window` is
            `None`, no windowing is used. Defaults to `"hann"`.
        center: Whether `x` was padded on both sides so that the t-th
            sequence is centered at time `t * sequence_stride`. Defaults to
            `True`.

    Returns:
        A tensor containing the inverse Short-Time Fourier Transform along
        the last axis of `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    >>> istft(stft(x, 1, 1, 1), 1, 1, 1)
    array([0.0, 1.0, 2.0, 3.0, 4.0])
    """
    if any_symbolic_tensors(x):
        return ISTFT(
            sequence_length=sequence_length,
            sequence_stride=sequence_stride,
            fft_length=fft_length,
            # Bug fix: `length` was previously dropped here, so the
            # symbolic output shape ignored a user-supplied `length`.
            length=length,
            window=window,
            center=center,
        ).symbolic_call(x)
    return backend.math.istft(
        x,
        sequence_length=sequence_length,
        sequence_stride=sequence_stride,
        fft_length=fft_length,
        length=length,
        window=window,
        center=center,
    )
930
+
931
+
932
class Rsqrt(Operation):
    """Symbolic `Operation` for `keras.ops.rsqrt`."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        tensor = backend.convert_to_tensor(x)
        return backend.math.rsqrt(tensor)
939
+
940
+
941
@keras_export("keras.ops.rsqrt")
def rsqrt(x):
    """Computes reciprocal of square root of x element-wise.

    Args:
        x: input tensor

    Returns:
        A tensor with the same dtype as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([1.0, 10.0, 100.0])
    >>> keras.ops.rsqrt(x)
    array([1.0, 0.31622776, 0.1], dtype=float32)
    """
    # Eager path: convert then compute; symbolic path defers to the Operation.
    if not any_symbolic_tensors((x,)):
        return backend.math.rsqrt(backend.convert_to_tensor(x))
    return Rsqrt().symbolic_call(x)
961
+
962
+
963
class Erf(Operation):
    """Symbolic `Operation` for `keras.ops.erf`."""

    def call(self, x):
        return backend.math.erf(x)

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)
969
+
970
+
971
@keras_export("keras.ops.erf")
def erf(x):
    """Computes the error function of `x`, element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same dtype as `x`.

    Example:

    >>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0])
    >>> keras.ops.erf(x)
    array([-0.99998 , -0.99532, -0.842701, 0., 0.842701], dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.math.erf(backend.convert_to_tensor(x))
    return Erf().symbolic_call(x)
991
+
992
+
993
class Erfinv(Operation):
    """Symbolic `Operation` for `keras.ops.erfinv`."""

    def call(self, x):
        return backend.math.erfinv(x)

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)
999
+
1000
+
1001
@keras_export("keras.ops.erfinv")
def erfinv(x):
    """Computes the inverse error function of `x`, element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same dtype as `x`.

    Example:

    >>> x = np.array([-0.5, -0.2, -0.1, 0.0, 0.3])
    >>> keras.ops.erfinv(x)
    array([-0.47694, -0.17914, -0.08886, 0. , 0.27246], dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.math.erfinv(backend.convert_to_tensor(x))
    return Erfinv().symbolic_call(x)
1021
+
1022
+
1023
class Logdet(Operation):
    """Symbolic `Operation` for `keras.ops.logdet`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return backend.math.logdet(x)

    def compute_output_spec(self, x):
        # The determinant collapses the trailing matrix dimensions.
        return KerasTensor(shape=x.shape[:-2], dtype=x.dtype)
1032
+
1033
+
1034
@keras_export(["keras.ops.logdet"])
def logdet(x):
    """Computes log of the determinant of a hermitian positive definite matrix.

    Args:
        x: Input matrix. It must be 2D and square.

    Returns:
        The natural log of the determinant of matrix.
    """
    if not any_symbolic_tensors((x,)):
        return backend.math.logdet(x)
    return Logdet().symbolic_call(x)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/nn.py ADDED
@@ -0,0 +1,2653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Commonly-used neural network operations not included in NumPy."""
2
+
3
+ import warnings
4
+
5
+ from keras.src import backend
6
+ from keras.src.api_export import keras_export
7
+ from keras.src.backend import KerasTensor
8
+ from keras.src.backend import any_symbolic_tensors
9
+ from keras.src.backend import standardize_data_format
10
+ from keras.src.backend.common.backend_utils import (
11
+ compute_conv_transpose_output_shape,
12
+ )
13
+ from keras.src.ops import operation_utils
14
+ from keras.src.ops.operation import Operation
15
+ from keras.src.ops.operation_utils import reduce_shape
16
+
17
+
18
class Relu(Operation):
    """Symbolic `Operation` for `keras.ops.relu`."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.relu(x)
24
+
25
+
26
@keras_export(["keras.ops.relu", "keras.ops.nn.relu"])
def relu(x):
    """Rectified linear unit activation function.

    Computes `f(x) = max(0, x)` element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x1 = keras.ops.convert_to_tensor([-1.0, 0.0, 1.0, 0.2])
    >>> keras.ops.relu(x1)
    array([0.0, 0.0, 1.0, 0.2], dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.relu(x)
    return Relu().symbolic_call(x)
47
+
48
+
49
class Relu6(Operation):
    """Symbolic `Operation` for `keras.ops.relu6`."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.relu6(x)
55
+
56
+
57
@keras_export(["keras.ops.relu6", "keras.ops.nn.relu6"])
def relu6(x):
    """Rectified linear unit activation function with upper bound of 6.

    Computes `f(x) = np.clip(x, 0, 6)` element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-3.0, -2.0, 0.1, 0.2, 6.0, 8.0])
    >>> keras.ops.relu6(x)
    array([0.0, 0.0, 0.1, 0.2, 6.0, 6.0], dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.relu6(x)
    return Relu6().symbolic_call(x)
78
+
79
+
80
class Sigmoid(Operation):
    """Symbolic `Operation` for `keras.ops.sigmoid`."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.sigmoid(x)
86
+
87
+
88
@keras_export(["keras.ops.sigmoid", "keras.ops.nn.sigmoid"])
def sigmoid(x):
    """Sigmoid activation function.

    Computes `f(x) = 1 / (1 + exp(-x))` element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-6.0, 1.0, 0.0, 1.0, 6.0])
    >>> keras.ops.sigmoid(x)
    array([0.00247262, 0.7310586, 0.5, 0.7310586, 0.9975274], dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.sigmoid(x)
    return Sigmoid().symbolic_call(x)
110
+
111
+
112
class Softplus(Operation):
    """Symbolic `Operation` for `keras.ops.softplus`."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.softplus(x)
118
+
119
+
120
@keras_export(["keras.ops.softplus", "keras.ops.nn.softplus"])
def softplus(x):
    """Softplus activation function.

    Computes `f(x) = log(exp(x) + 1)` element-wise, where `log` is the
    natural logarithm and `exp` is the exponential function.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-0.555, 0.0, 0.555])
    >>> keras.ops.softplus(x)
    array([0.45366603, 0.6931472, 1.008666], dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.softplus(x)
    return Softplus().symbolic_call(x)
143
+
144
+
145
class Softsign(Operation):
    """Symbolic `Operation` for `keras.ops.softsign`."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.softsign(x)
151
+
152
+
153
@keras_export(["keras.ops.softsign", "keras.ops.nn.softsign"])
def softsign(x):
    """Softsign activation function.

    Computes `f(x) = x / (abs(x) + 1)` element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-0.100, -10.0, 1.0, 0.0, 100.0])
    >>> keras.ops.softsign(x)
    Array([-0.09090909, -0.90909094, 0.5, 0.0, 0.990099], dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.softsign(x)
    return Softsign().symbolic_call(x)
175
+
176
+
177
class SoftShrink(Operation):
    """Symbolic `Operation` for `keras.ops.soft_shrink`."""

    def __init__(self, threshold=0.5):
        super().__init__()
        # Shrinkage threshold applied symmetrically around zero.
        self.threshold = threshold

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.soft_shrink(x, self.threshold)
187
+
188
+
189
@keras_export(["keras.ops.soft_shrink", "keras.ops.nn.soft_shrink"])
def soft_shrink(x, threshold=0.5):
    """Soft Shrink activation function.

    Computes, element-wise:

    `f(x) = x - threshold` if `x > threshold`,
    `f(x) = x + threshold` if `x < -threshold`,
    `f(x) = 0` otherwise.

    Args:
        x: Input tensor.
        threshold: Threshold value. Defaults to 0.5.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1.0, 0.0, 1.0])
    >>> keras.ops.soft_shrink(x)
    array([-0.5  0.   0.5], shape=(3,), dtype=float64)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.soft_shrink(x, threshold)
    return SoftShrink(threshold).symbolic_call(x)
217
+
218
+
219
class SparsePlus(Operation):
    """Symbolic `Operation` for `keras.ops.sparse_plus`."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.sparse_plus(x)
225
+
226
+
227
@keras_export(["keras.ops.sparse_plus", "keras.ops.nn.sparse_plus"])
def sparse_plus(x):
    """SparsePlus activation function.

    Computes, element-wise:

    `f(x) = 0` for `x <= -1`.
    `f(x) = (1/4) * (x + 1)^2` for `-1 < x < 1`.
    `f(x) = x` for `x >= 1`.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1.0, 0.0, 1.0])
    >>> keras.ops.sparse_plus(x)
    Array([0.   0.25 1.  ], shape=(3,), dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.sparse_plus(x)
    return SparsePlus().symbolic_call(x)
255
+
256
+
257
class Silu(Operation):
    """Symbolic `Operation` for `keras.ops.silu` (a.k.a. swish)."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.silu(x)
263
+
264
+
265
@keras_export(
    [
        "keras.ops.silu",
        "keras.ops.nn.silu",
        "keras.ops.swish",
        "keras.ops.nn.swish",
    ]
)
def silu(x):
    """Sigmoid Linear Unit (SiLU) activation function, also known as Swish.

    The SiLU activation function is computed by the sigmoid function multiplied
    by its input: `f(x) = x * sigmoid(x)`.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-6.0, 1.0, 0.0, 1.0, 6.0])
    >>> keras.ops.silu(x)
    array([-0.0148357, 0.7310586, 0.0, 0.7310586, 5.9851646], dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.silu(x)
    return Silu().symbolic_call(x)
297
+
298
+
299
class Squareplus(Operation):
    """Symbolic `Operation` for `keras.ops.squareplus`."""

    def __init__(self, b=4):
        super().__init__()
        # Smoothness parameter of the squareplus curve.
        self.b = b

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.squareplus(x, self.b)
309
+
310
+
311
@keras_export(["keras.ops.squareplus", "keras.ops.nn.squareplus"])
def squareplus(x, b=4):
    """Squareplus activation function.

    Computes `f(x) = (x + sqrt(x^2 + b)) / 2` element-wise.

    Args:
        x: Input tensor.
        b: Smoothness parameter. Defaults to 4.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1.0, 0.0, 1.0])
    >>> keras.ops.squareplus(x)
    array([0.6180, 1.0000, 1.6180], dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.squareplus(x, b)
    return Squareplus(b).symbolic_call(x)
337
+
338
+
339
class LogSigmoid(Operation):
    """Symbolic `Operation` for `keras.ops.log_sigmoid`."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.log_sigmoid(x)
345
+
346
+
347
@keras_export(
    [
        "keras.ops.log_sigmoid",
        "keras.ops.nn.log_sigmoid",
    ]
)
def log_sigmoid(x):
    """Logarithm of the sigmoid activation function.

    Computes `f(x) = log(1 / (1 + exp(-x)))` element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-0.541391, 0.0, 0.50, 5.0])
    >>> keras.ops.log_sigmoid(x)
    array([-1.0000418, -0.6931472, -0.474077, -0.00671535], dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.log_sigmoid(x)
    return LogSigmoid().symbolic_call(x)
374
+
375
+
376
class LeakyRelu(Operation):
    """Symbolic `Operation` for `keras.ops.leaky_relu`."""

    def __init__(self, negative_slope=0.2):
        super().__init__()
        # Slope applied to negative inputs.
        self.negative_slope = negative_slope

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.leaky_relu(x, self.negative_slope)
386
+
387
+
388
@keras_export(["keras.ops.leaky_relu", "keras.ops.nn.leaky_relu"])
def leaky_relu(x, negative_slope=0.2):
    """Leaky version of a Rectified Linear Unit activation function.

    It allows a small gradient when the unit is not active:

    `f(x) = negative_slope * x for x < 0` or `f(x) = x for x >= 0`.

    Args:
        x: Input tensor.
        negative_slope: Slope of the activation function at x < 0.
            Defaults to `0.2`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> keras.ops.leaky_relu(x)
    array([-0.2,  0. ,  1. ], shape=(3,), dtype=float64)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.leaky_relu(x, negative_slope=negative_slope)
    return LeakyRelu(negative_slope).symbolic_call(x)
415
+
416
+
417
class HardSigmoid(Operation):
    """Symbolic `Operation` for `keras.ops.hard_sigmoid`."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.hard_sigmoid(x)
423
+
424
+
425
@keras_export(
    [
        "keras.ops.hard_sigmoid",
        "keras.ops.nn.hard_sigmoid",
    ]
)
def hard_sigmoid(x):
    """Hard sigmoid activation function.

    Computes, element-wise:

    `0 if x < -2.5`, `1 if x > 2.5`, `(0.2 * x) + 0.5 if -2.5 <= x <= 2.5`.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> keras.ops.hard_sigmoid(x)
    array([0.3, 0.5, 0.7], shape=(3,), dtype=float64)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.hard_sigmoid(x)
    return HardSigmoid().symbolic_call(x)
455
+
456
+
457
class HardSilu(Operation):
    """Symbolic `Operation` for `keras.ops.hard_silu` (a.k.a. hard swish)."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.hard_silu(x)
463
+
464
+
465
@keras_export(
    [
        "keras.ops.hard_silu",
        "keras.ops.nn.hard_silu",
        "keras.ops.hard_swish",
        "keras.ops.nn.hard_swish",
    ]
)
def hard_silu(x):
    """Hard SiLU activation function, also known as Hard Swish.

    Computes, element-wise:

    - `0` if `x < -3`
    - `x` if `x > 3`
    - `x * (x + 3) / 6` if `-3 <= x <= 3`

    It's a faster, piecewise linear approximation of the silu activation.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
    >>> keras.ops.hard_silu(x)
    array([-0.0, -0.3333333, 0.0, 0.6666667, 3.0], shape=(5,), dtype=float32)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.hard_silu(x)
    return HardSilu().symbolic_call(x)
500
+
501
+
502
class Elu(Operation):
    """Symbolic `Operation` for `keras.ops.elu`."""

    def __init__(self, alpha=1.0):
        super().__init__()
        # Scale of the exponential (negative) section.
        self.alpha = alpha

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.elu(x, alpha=self.alpha)
512
+
513
+
514
@keras_export(["keras.ops.elu", "keras.ops.nn.elu"])
def elu(x, alpha=1.0):
    """Exponential Linear Unit activation function.

    Computes, element-wise:

    `f(x) = alpha * (exp(x) - 1.) for x < 0`, `f(x) = x for x >= 0`.

    Args:
        x: Input tensor.
        alpha: A scalar that scales the negative (exponential) section.
            Defaults to `1.0`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> keras.ops.elu(x)
    array([-0.63212055, 0., 1.], shape=(3,), dtype=float64)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.elu(x, alpha=alpha)
    return Elu(alpha).symbolic_call(x)
540
+
541
+
542
class Selu(Operation):
    """Symbolic `Operation` for `keras.ops.selu`."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.selu(x)
548
+
549
+
550
@keras_export(["keras.ops.selu", "keras.ops.nn.selu"])
def selu(x):
    """Scaled Exponential Linear Unit (SELU) activation function.

    Computes, element-wise:

    `f(x) = scale * alpha * (exp(x) - 1.) for x < 0`,
    `f(x) = scale * x for x >= 0`.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> keras.ops.selu(x)
    array([-1.11133055, 0., 1.05070098], shape=(3,), dtype=float64)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.selu(x)
    return Selu().symbolic_call(x)
576
+
577
+
578
class Gelu(Operation):
    """Symbolic `Operation` for `keras.ops.gelu`."""

    def __init__(self, approximate=True):
        super().__init__()
        # Whether to use the tanh approximation instead of the exact erf form.
        self.approximate = approximate

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.gelu(x, self.approximate)
588
+
589
+
590
@keras_export(["keras.ops.gelu", "keras.ops.nn.gelu"])
def gelu(x, approximate=True):
    """Gaussian Error Linear Unit (GELU) activation function.

    If `approximate` is `True`:
    `f(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`

    If `approximate` is `False`:
    `f(x) = x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`,
    where `P(X) ~ N(0, 1)`.

    Args:
        x: Input tensor.
        approximate: Approximate version of GELU activation. Defaults to `True`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> keras.ops.gelu(x)
    array([-0.15865525, 0., 0.84134475], shape=(3,), dtype=float64)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.gelu(x, approximate)
    return Gelu(approximate).symbolic_call(x)
619
+
620
+
621
class Celu(Operation):
    """Symbolic `Operation` for `keras.ops.celu`."""

    def __init__(self, alpha=1.0):
        super().__init__()
        # The alpha parameter of the CELU formulation.
        self.alpha = alpha

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.celu(x, self.alpha)
631
+
632
+
633
@keras_export(["keras.ops.celu", "keras.ops.nn.celu"])
def celu(x, alpha=1.0):
    """Continuously-differentiable exponential linear unit.

    Computes, element-wise:

    `f(x) = alpha * (exp(x / alpha) - 1) for x < 0`, `f(x) = x for x >= 0`.

    Args:
        x: Input tensor.
        alpha: the α value for the CELU formulation. Defaults to `1.0`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> keras.ops.celu(x)
    array([-0.63212056,  0.        ,  1.        ], shape=(3,), dtype=float64)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.celu(x, alpha)
    return Celu(alpha).symbolic_call(x)
659
+
660
+
661
class Glu(Operation):
    """Symbolic `Operation` for `keras.ops.glu`."""

    def __init__(self, axis=-1):
        super().__init__()
        # Axis along which the input is split into the value and gate halves.
        self.axis = axis

    def call(self, x):
        return backend.nn.glu(x, axis=self.axis)

    def compute_output_spec(self, x):
        # GLU splits `x` into halves `a` and `b` along `self.axis` and returns
        # `a * sigmoid(b)`, so the split axis is halved in the output.
        # (Previously the full input shape was returned, which is wrong.)
        output_shape = list(x.shape)
        if output_shape[self.axis] is not None:
            output_shape[self.axis] = output_shape[self.axis] // 2
        return KerasTensor(tuple(output_shape), dtype=x.dtype)
671
+
672
+
673
@keras_export(["keras.ops.glu", "keras.ops.nn.glu"])
def glu(x, axis=-1):
    """Gated Linear Unit (GLU) activation function.

    Computes `f(x) = a * sigmoid(b)`, where `x` is split into `a` and `b`
    along the given axis.

    Args:
        x: Input tensor.
        axis: The axis along which to split the input tensor. Defaults to `-1`.

    Returns:
        A tensor with the same shape as half of the input.

    Example:

    >>> x = np.array([-1., 0., 1. , 1.])
    >>> keras.ops.glu(x)
    array([-0.73105858,  0.        ], shape=(2,), dtype=float64)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.glu(x, axis=axis)
    return Glu(axis).symbolic_call(x)
700
+
701
+
702
class TanhShrink(Operation):
    """Symbolic `Operation` for `keras.ops.tanh_shrink`."""

    def __init__(self):
        super().__init__()

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.tanh_shrink(x)
711
+
712
+
713
@keras_export(["keras.ops.tanh_shrink", "keras.ops.nn.tanh_shrink"])
def tanh_shrink(x):
    """Applies the tanh shrink function element-wise.

    Computes `f(x) = x - tanh(x)`.

    Args:
        x: Input tensor.

    Returns:
        Output tensor of the same shape as `x`, where each element is
        transformed according to the tanh shrink operation.

    Example:

    >>> x = np.array([ -1., 0., 1.])
    >>> keras.ops.tanh_shrink(x)
    array([-0.23840584  0.          0.23840584], shape=(3,), dtype=float64)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.tanh_shrink(x)
    return TanhShrink().symbolic_call(x)
739
+
740
+
741
class HardTanh(Operation):
    """Symbolic `Operation` for `keras.ops.hard_tanh`."""

    def __init__(self):
        super().__init__()

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.hard_tanh(x)
750
+
751
+
752
@keras_export(["keras.ops.hard_tanh", "keras.ops.nn.hard_tanh"])
def hard_tanh(x):
    """Applies the HardTanh function element-wise.

    Computes:

    `f(x) = -1 for x < -1`, `f(x) = x for -1 <= x <= 1`, `f(x) = 1 for x > 1`.

    Args:
        x: Input tensor.

    Returns:
        Output tensor of same shape as `x`, where values are clamped
        between -1 and 1.

    Example:

    >>> x = np.array([-2., -1., 0., 1., 2.])
    >>> keras.ops.hard_tanh(x)
    array([-1. -1.  0.  1.  1.], shape=(5,), dtype=float64)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.hard_tanh(x)
    return HardTanh().symbolic_call(x)
778
+
779
+
780
class HardShrink(Operation):
    """Symbolic `Operation` for `keras.ops.hard_shrink`."""

    def __init__(self, threshold=0.5):
        super().__init__()
        # Absolute-value cutoff below which outputs are zeroed.
        self.threshold = threshold

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.hard_shrink(x, self.threshold)
790
+
791
+
792
@keras_export(["keras.ops.hard_shrink", "keras.ops.nn.hard_shrink"])
def hard_shrink(x, threshold=0.5):
    """Hard Shrink activation function.

    A thresholding operation computed element-wise as:

    `f(x) = x` if `|x| > threshold`,
    `f(x) = 0` otherwise.

    Args:
        x: Input tensor.
        threshold: Threshold value. Defaults to 0.5.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-0.5, 0., 1.])
    >>> keras.ops.hard_shrink(x)
    array([0. 0. 1.], shape=(3,), dtype=float64)
    """
    if not any_symbolic_tensors((x,)):
        return backend.nn.hard_shrink(x, threshold)
    return HardShrink(threshold).symbolic_call(x)
819
+
820
+
821
class Threshold(Operation):
    """Symbolic `Operation` for `keras.ops.threshold`."""

    def __init__(self, threshold_value, value):
        super().__init__()
        # Cutoff above which inputs are kept as-is.
        self.threshold_value = threshold_value
        # Replacement value for inputs at or below the cutoff.
        self.value = value

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.nn.threshold(x, self.threshold_value, self.value)
832
+
833
+
834
@keras_export(["keras.ops.threshold", "keras.ops.nn.threshold"])
def threshold(x, threshold, default_value):
    """Threshold activation function.

    Keeps values strictly above `threshold` and replaces the rest:
    `f(x) = x` if `x > threshold`, else `f(x) = default_value`.

    Args:
        x: Input tensor.
        threshold: The value that decides when to retain or replace x.
        default_value: Value to assign when `x <= threshold`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1.0, 0.0, 1.0, 2.0])
    >>> keras.ops.threshold(x, 1, 0)
    array([0., 0., 0., 2.], shape=(4,), dtype=float64)
    """
    # Concrete tensors go straight to the backend; symbolic tensors are
    # routed through the Operation wrapper for shape/dtype inference.
    if not any_symbolic_tensors((x,)):
        return backend.nn.threshold(x, threshold, default_value)
    return Threshold(threshold, default_value).symbolic_call(x)
861
+
862
+
863
class Softmax(Operation):
    """Symbolic-graph wrapper for the softmax activation along one axis."""

    def __init__(self, axis=-1):
        super().__init__()
        self.axis = axis

    def call(self, x):
        # Eager path: defer to the backend implementation.
        return backend.nn.softmax(x, axis=self.axis)

    def compute_output_spec(self, x):
        # Softmax normalizes in place along `axis`; shape/dtype unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)
873
+
874
+
875
@keras_export(["keras.ops.softmax", "keras.ops.nn.softmax"])
def softmax(x, axis=-1):
    """Softmax activation function.

    Normalizes `x` along `axis` so that each vector's entries lie in
    `(0, 1)` and sum to 1 (up to floating-point rounding):
    `f(x) = exp(x) / sum(exp(x))`. Each vector is processed independently.

    Args:
        x: Input tensor.
        axis: Integer (or tuple of integers), axis along which the softmax
            is applied.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> keras.ops.softmax(x)
    array([0.09003057, 0.24472847, 0.66524096], shape=(3,), dtype=float64)
    """
    # Inspect `x.shape` directly instead of `backend.shape`: TensorFlow may
    # return symbolic tensors for unknown shapes, which can trigger an error
    # in TensorFlow graph execution.
    if isinstance(axis, int) and x.shape[axis] == 1:
        warnings.warn(
            f"You are using a softmax over axis {axis} "
            f"of a tensor of shape {x.shape}. This axis "
            "has size 1. The softmax operation will always return "
            "the value 1, which is likely not what you intended. "
            "Did you mean to use a sigmoid instead?"
        )
    if any_symbolic_tensors((x,)):
        return Softmax(axis).symbolic_call(x)
    if not isinstance(axis, tuple):
        return backend.nn.softmax(x, axis=axis)
    # Multi-axis case: move the reduction axes to the end, flatten them into
    # a single trailing axis, apply the softmax there, then restore the
    # original layout via the inverse permutation.
    kept_axes = [dim for dim in range(len(x.shape)) if dim not in axis]
    permuted = backend.numpy.transpose(x, axes=(*kept_axes, *axis))
    flattened = backend.numpy.reshape(
        permuted, (*[x.shape[dim] for dim in kept_axes], -1)
    )
    outputs = backend.nn.softmax(flattened, axis=-1)
    outputs = backend.numpy.reshape(outputs, permuted.shape)
    return backend.numpy.transpose(
        outputs, axes=list(backend.numpy.argsort([*kept_axes, *axis]))
    )
933
+
934
+
935
class LogSoftmax(Operation):
    """Symbolic-graph wrapper for the log-softmax activation."""

    def __init__(self, axis=-1):
        super().__init__()
        self.axis = axis

    def call(self, x):
        # Eager path: defer to the backend implementation.
        return backend.nn.log_softmax(x, axis=self.axis)

    def compute_output_spec(self, x):
        # Log-softmax normalizes along `axis`; shape/dtype unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)
945
+
946
+
947
@keras_export(
    [
        "keras.ops.log_softmax",
        "keras.ops.nn.log_softmax",
    ]
)
def log_softmax(x, axis=-1):
    """Log-softmax activation function.

    Computed in a numerically stable form:
    `f(x) = x - max(x) - log(sum(exp(x - max(x))))`

    Args:
        x: Input tensor.
        axis: Integer (or tuple of integers), axis along which the
            log-softmax is applied. Defaults to `-1`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> keras.ops.log_softmax(x)
    array([-2.40760596, -1.40760596, -0.40760596], shape=(3,), dtype=float64)
    """
    if any_symbolic_tensors((x,)):
        return LogSoftmax(axis).symbolic_call(x)
    if not isinstance(axis, tuple):
        return backend.nn.log_softmax(x, axis=axis)
    # Multi-axis case: move the reduction axes to the end, flatten them into
    # a single trailing axis, apply log-softmax there, then restore the
    # original layout via the inverse permutation.
    kept_axes = [dim for dim in range(len(x.shape)) if dim not in axis]
    permuted = backend.numpy.transpose(x, axes=(*kept_axes, *axis))
    flattened = backend.numpy.reshape(
        permuted, (*[x.shape[dim] for dim in kept_axes], -1)
    )
    outputs = backend.nn.log_softmax(flattened, axis=-1)
    outputs = backend.numpy.reshape(outputs, permuted.shape)
    return backend.numpy.transpose(
        outputs, axes=list(backend.numpy.argsort([*kept_axes, *axis]))
    )
994
+
995
+
996
class Sparsemax(Operation):
    """Symbolic-graph wrapper for the sparsemax activation."""

    def __init__(self, axis=-1):
        super().__init__()
        self.axis = axis

    def call(self, x):
        # Eager path: defer to the backend implementation.
        return backend.nn.sparsemax(x, axis=self.axis)

    def compute_output_spec(self, x):
        # Sparsemax normalizes along `axis`; shape/dtype unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)
1006
+
1007
+
1008
@keras_export(["keras.ops.sparsemax", "keras.ops.nn.sparsemax"])
def sparsemax(x, axis=-1):
    """Sparsemax activation function.

    For each batch `i` and class `j`:
    `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0)`,
    producing a sparse probability distribution.

    Args:
        x: Input tensor.
        axis: `int`, axis along which the sparsemax operation is applied.

    Returns:
        A tensor, output of sparsemax transformation. Has the same type and
        shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> keras.ops.sparsemax(x)
    array([0., 0., 1.], shape=(3,), dtype=float64)
    """
    # Concrete tensors go straight to the backend; symbolic tensors are
    # routed through the Operation wrapper for shape/dtype inference.
    if not any_symbolic_tensors((x,)):
        return backend.nn.sparsemax(x, axis=axis)
    return Sparsemax(axis).symbolic_call(x)
1036
+
1037
+
1038
class MaxPool(Operation):
    """Symbolic-graph wrapper for N-D max pooling.

    Captures the pooling hyperparameters; `padding` is normalized to
    lowercase at construction time.
    """

    def __init__(
        self,
        pool_size,
        strides=None,
        padding="valid",
        data_format=None,
    ):
        super().__init__()
        self.pool_size = pool_size
        self.strides = strides
        self.padding = padding.lower()
        self.data_format = data_format

    def call(self, inputs):
        # Eager path: defer to the backend implementation.
        return backend.nn.max_pool(
            inputs,
            self.pool_size,
            self.strides,
            self.padding,
            self.data_format,
        )

    def compute_output_spec(self, inputs):
        # Spatial dims shrink per pool_size/strides/padding; dtype unchanged.
        output_shape = operation_utils.compute_pooling_output_shape(
            inputs.shape,
            self.pool_size,
            self.strides,
            self.padding,
            self.data_format,
        )
        return KerasTensor(output_shape, dtype=inputs.dtype)
1070
+
1071
+
1072
@keras_export(["keras.ops.max_pool", "keras.ops.nn.max_pool"])
def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format=None,
):
    """Max pooling operation.

    Pools over the spatial dimensions of `inputs` only.

    Args:
        inputs: Tensor of rank N+2 with shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`.
        pool_size: int or tuple/list of `len(inputs_spatial_shape)` ints,
            the pooling window size per spatial dimension. An int is shared
            across all spatial dimensions.
        strides: int or tuple/list of `len(inputs_spatial_shape)` ints, the
            stride of the sliding window per spatial dimension. An int is
            shared across all spatial dimensions.
        padding: string, either `"valid"` (no padding) or `"same"` (pad
            evenly so that, with `strides=1`, the output keeps the input's
            height/width).
        data_format: `"channels_last"` or `"channels_first"`, the ordering
            of the dimensions in `inputs`.

    Returns:
        A tensor of rank N+2, the result of the max pooling operation.
    """
    # Normalize arguments once, before either dispatch path.
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    if not any_symbolic_tensors((inputs,)):
        return backend.nn.max_pool(
            inputs, pool_size, strides, padding, data_format
        )
    return MaxPool(
        pool_size,
        strides,
        padding,
        data_format,
    ).symbolic_call(inputs)
1122
+
1123
+
1124
class AveragePool(Operation):
    """Symbolic-graph wrapper for N-D average pooling.

    Captures the pooling hyperparameters; `padding` is normalized to
    lowercase at construction time.
    """

    def __init__(
        self,
        pool_size,
        strides=None,
        padding="valid",
        data_format=None,
    ):
        super().__init__()
        self.pool_size = pool_size
        self.strides = strides
        self.padding = padding.lower()
        self.data_format = data_format

    def call(self, inputs):
        # Eager path: defer to the backend implementation.
        return backend.nn.average_pool(
            inputs,
            self.pool_size,
            self.strides,
            self.padding,
            self.data_format,
        )

    def compute_output_spec(self, inputs):
        # Spatial dims shrink per pool_size/strides/padding; dtype unchanged.
        output_shape = operation_utils.compute_pooling_output_shape(
            inputs.shape,
            self.pool_size,
            self.strides,
            self.padding,
            self.data_format,
        )
        return KerasTensor(output_shape, dtype=inputs.dtype)
1156
+
1157
+
1158
@keras_export(
    [
        "keras.ops.average_pool",
        "keras.ops.nn.average_pool",
    ]
)
def average_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format=None,
):
    """Average pooling operation.

    Pools over the spatial dimensions of `inputs` only.

    Args:
        inputs: Tensor of rank N+2 with shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`.
        pool_size: int or tuple/list of `len(inputs_spatial_shape)` ints,
            the pooling window size per spatial dimension. An int is shared
            across all spatial dimensions.
        strides: int or tuple/list of `len(inputs_spatial_shape)` ints, the
            stride of the sliding window per spatial dimension. An int is
            shared across all spatial dimensions.
        padding: string, either `"valid"` (no padding) or `"same"` (pad
            evenly so that, with `strides=1`, the output keeps the input's
            height/width).
        data_format: `"channels_last"` or `"channels_first"`, the ordering
            of the dimensions in `inputs`.

    Returns:
        A tensor of rank N+2, the result of the average pooling operation.
    """
    # Normalize arguments once, before either dispatch path.
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    if not any_symbolic_tensors((inputs,)):
        return backend.nn.average_pool(
            inputs, pool_size, strides, padding, data_format
        )
    return AveragePool(
        pool_size,
        strides,
        padding,
        data_format,
    ).symbolic_call(inputs)
1215
+
1216
+
1217
class Conv(Operation):
    """Symbolic-graph wrapper for general N-D convolution.

    Captures the convolution hyperparameters; `padding` is normalized to
    lowercase at construction time.
    """

    def __init__(
        self,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
    ):
        super().__init__()
        self.strides = strides
        self.padding = padding.lower()
        self.data_format = data_format
        self.dilation_rate = dilation_rate

    def call(self, inputs, kernel):
        # Eager path: defer to the backend implementation.
        return backend.nn.conv(
            inputs,
            kernel,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )

    def compute_output_spec(self, inputs, kernel):
        # Kernel layout: spatial dims + (num_input_channels, num_output_channels),
        # so kernel.shape[-1] is the output channel count and kernel.shape[:-2]
        # the spatial window.
        output_shape = operation_utils.compute_conv_output_shape(
            inputs.shape,
            kernel.shape[-1],
            kernel.shape[:-2],
            self.strides,
            self.padding,
            self.data_format,
            self.dilation_rate,
        )
        return KerasTensor(output_shape, dtype=inputs.dtype)
1252
+
1253
+
1254
@keras_export(["keras.ops.conv", "keras.ops.nn.conv"])
def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """General N-D convolution.

    This ops supports 1D, 2D and 3D convolution.

    Args:
        inputs: Tensor of rank N+2 with shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`.
        kernel: Tensor of rank N+2 with shape
            `(kernel_spatial_shape, num_input_channels, num_output_channels)`.
            `num_input_channels` should match the number of channels in
            `inputs`.
        strides: int or int tuple/list of `len(inputs_spatial_shape)`, the
            convolution stride per spatial dimension. An int is shared
            across all spatial dimensions.
        padding: string, either `"valid"` (no padding) or `"same"` (pad
            evenly so that, with `strides=1`, the output keeps the input's
            height/width).
        data_format: `"channels_last"` or `"channels_first"`, the ordering
            of the dimensions in `inputs`.
        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
            the dilation rate per spatial dimension. An int is shared across
            all spatial dimensions.

    Returns:
        A tensor of rank N+2, the result of the conv operation.
    """
    # Normalize arguments once, before either dispatch path.
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    if not any_symbolic_tensors((inputs,)):
        return backend.nn.conv(
            inputs, kernel, strides, padding, data_format, dilation_rate
        )
    return Conv(strides, padding, data_format, dilation_rate).symbolic_call(
        inputs, kernel
    )
1308
+
1309
+
1310
class DepthwiseConv(Operation):
    """Symbolic-graph wrapper for N-D depthwise convolution.

    Captures the convolution hyperparameters; `padding` is normalized to
    lowercase at construction time.
    """

    def __init__(
        self,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
    ):
        super().__init__()
        self.strides = strides
        self.padding = padding.lower()
        self.data_format = data_format
        self.dilation_rate = dilation_rate

    def call(self, inputs, kernel):
        # Eager path: defer to the backend implementation.
        return backend.nn.depthwise_conv(
            inputs,
            kernel,
            self.strides,
            self.padding,
            self.data_format,
            self.dilation_rate,
        )

    def compute_output_spec(self, inputs, kernel):
        # Kernel layout: spatial dims + (num_input_channels, channel_multiplier);
        # output channels = num_input_channels * channel_multiplier.
        output_shape = operation_utils.compute_conv_output_shape(
            inputs.shape,
            kernel.shape[-1] * kernel.shape[-2],
            kernel.shape[:-2],
            self.strides,
            self.padding,
            self.data_format,
            self.dilation_rate,
        )
        return KerasTensor(output_shape, dtype=inputs.dtype)
1345
+
1346
+
1347
@keras_export(
    [
        "keras.ops.depthwise_conv",
        "keras.ops.nn.depthwise_conv",
    ]
)
def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """General N-D depthwise convolution.

    This ops supports 1D and 2D depthwise convolution.

    Args:
        inputs: Tensor of rank N+2 with shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`.
        kernel: Tensor of rank N+2 with shape
            [kernel_spatial_shape, num_input_channels, num_channels_multiplier],
            `num_input_channels` should match the number of channels in
            `inputs`.
        strides: int or int tuple/list of `len(inputs_spatial_shape)`, the
            convolution stride per spatial dimension. An int is shared
            across all spatial dimensions.
        padding: string, either `"valid"` (no padding) or `"same"` (pad
            evenly so that, with `strides=1`, the output keeps the input's
            height/width).
        data_format: `"channels_last"` or `"channels_first"`, the ordering
            of the dimensions in `inputs`.
        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
            the dilation rate per spatial dimension. An int is shared across
            all spatial dimensions.

    Returns:
        A tensor of rank N+2, the result of the depthwise conv operation.
    """
    # Normalize arguments once, before either dispatch path.
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    if not any_symbolic_tensors((inputs,)):
        return backend.nn.depthwise_conv(
            inputs,
            kernel,
            strides,
            padding,
            data_format,
            dilation_rate,
        )
    return DepthwiseConv(
        strides, padding, data_format, dilation_rate
    ).symbolic_call(inputs, kernel)
1411
+
1412
+
1413
class SeparableConv(Operation):
    """Symbolic-graph wrapper for N-D separable convolution.

    A separable conv is a depthwise conv followed by a 1x1 pointwise conv;
    `padding` is normalized to lowercase at construction time.
    """

    def __init__(
        self,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
    ):
        super().__init__()
        self.strides = strides
        self.padding = padding.lower()
        self.data_format = data_format
        self.dilation_rate = dilation_rate

    def call(self, inputs, depthwise_kernel, pointwise_kernel):
        # Eager path: defer to the backend implementation.
        return backend.nn.separable_conv(
            inputs,
            depthwise_kernel,
            pointwise_kernel,
            self.strides,
            self.padding,
            self.data_format,
            self.dilation_rate,
        )

    def compute_output_spec(self, inputs, depthwise_kernel, pointwise_kernel):
        # Infer the spatial output shape from the depthwise stage (the
        # module-level `depthwise_conv` runs symbolically here), then
        # overwrite the channel dim with the pointwise output channels.
        output_shape = list(
            depthwise_conv(
                inputs,
                depthwise_kernel,
                self.strides,
                self.padding,
                self.data_format,
                self.dilation_rate,
            ).shape
        )
        if self.data_format == "channels_last":
            output_shape[-1] = pointwise_kernel.shape[-1]
        else:
            output_shape[1] = pointwise_kernel.shape[-1]
        return KerasTensor(output_shape, dtype=inputs.dtype)
1454
+
1455
+
1456
@keras_export(
    [
        "keras.ops.separable_conv",
        "keras.ops.nn.separable_conv",
    ]
)
def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """General N-D separable convolution.

    This ops supports 1D and 2D separable convolution. `separable_conv` is
    a depthwise conv followed by a pointwise conv.

    Args:
        inputs: Tensor of rank N+2 with shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`.
        depthwise_kernel: Tensor of rank N+2 with shape
            [kernel_spatial_shape, num_input_channels, num_channels_multiplier],
            `num_input_channels` should match the number of channels in
            `inputs`.
        pointwise_kernel: Tensor of rank N+2 with shape
            `(*ones_like(kernel_spatial_shape),
            num_input_channels * num_channels_multiplier, num_output_channels)`.
        strides: int or int tuple/list of `len(inputs_spatial_shape)`, the
            convolution stride per spatial dimension. An int is shared
            across all spatial dimensions.
        padding: string, either `"valid"` (no padding) or `"same"` (pad
            evenly so that, with `strides=1`, the output keeps the input's
            height/width).
        data_format: `"channels_last"` or `"channels_first"`, the ordering
            of the dimensions in `inputs`.
        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
            the dilation rate per spatial dimension. An int is shared across
            all spatial dimensions.

    Returns:
        A tensor of rank N+2, the result of the depthwise conv operation.
    """
    # Normalize arguments once, before either dispatch path.
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    if not any_symbolic_tensors((inputs,)):
        return backend.nn.separable_conv(
            inputs,
            depthwise_kernel,
            pointwise_kernel,
            strides,
            padding,
            data_format,
            dilation_rate,
        )
    return SeparableConv(
        strides,
        padding,
        data_format,
        dilation_rate,
    ).symbolic_call(inputs, depthwise_kernel, pointwise_kernel)
1529
+
1530
+
1531
class ConvTranspose(Operation):
    """Symbolic-graph wrapper for N-D transposed ("de-")convolution.

    Captures the hyperparameters; `padding` is normalized to lowercase at
    construction time.
    """

    def __init__(
        self,
        strides,
        padding="valid",
        output_padding=None,
        data_format=None,
        dilation_rate=1,
    ):
        super().__init__()
        self.strides = strides
        self.output_padding = output_padding
        self.padding = padding.lower()
        self.data_format = data_format
        self.dilation_rate = dilation_rate

    def call(
        self,
        inputs,
        kernel,
    ):
        # Pass the options as keywords: the previous positional call passed
        # `self.output_padding` before `self.padding`, which is the reverse
        # of the positional order used by the functional `conv_transpose`
        # wrapper in this module (`strides, padding, output_padding, ...`).
        # Keywords make the call unambiguous regardless of the backend's
        # parameter order.
        return backend.nn.conv_transpose(
            inputs,
            kernel,
            strides=self.strides,
            padding=self.padding,
            output_padding=self.output_padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )

    def compute_output_spec(self, inputs, kernel):
        # Kernel layout: spatial dims + (num_output_channels,
        # num_input_channels), so filters come from kernel.shape[-2].
        kernel_size = kernel.shape[:-2]
        filters = kernel.shape[-2]
        output_shape = compute_conv_transpose_output_shape(
            inputs.shape,
            kernel_size,
            filters,
            self.strides,
            self.padding,
            self.output_padding,
            self.data_format,
            self.dilation_rate,
        )
        return KerasTensor(output_shape, dtype=inputs.dtype)
1576
+
1577
+
1578
@keras_export(
    [
        "keras.ops.conv_transpose",
        "keras.ops.nn.conv_transpose",
    ]
)
def conv_transpose(
    inputs,
    kernel,
    strides,
    padding="valid",
    output_padding=None,
    data_format=None,
    dilation_rate=1,
):
    """General N-D convolution transpose.

    Also known as de-convolution. This ops supports 1D, 2D and 3D convolution.

    Args:
        inputs: Tensor of rank N+2 with shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`.
        kernel: Tensor of rank N+2 with shape
            [kernel_spatial_shape, num_output_channels, num_input_channels],
            `num_input_channels` should match the number of channels in
            `inputs`.
        strides: int or int tuple/list of `len(inputs_spatial_shape)`, the
            convolution stride per spatial dimension. An int is shared
            across all spatial dimensions.
        padding: string, either `"valid"` (no padding) or `"same"` (pad
            evenly so that, with `strides=1`, the output keeps the input's
            height/width).
        output_padding: int or int tuple/list of `len(inputs_spatial_shape)`,
            the amount of extra padding along each spatial dimension of the
            output tensor. Must be lower than the stride along the same
            dimension. If `None` (default), the output shape is inferred.
        data_format: `"channels_last"` or `"channels_first"`, the ordering
            of the dimensions in `inputs`.
        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
            the dilation rate per spatial dimension. An int is shared across
            all spatial dimensions.

    Returns:
        A tensor of rank N+2, the result of the conv operation.
    """
    # Normalize arguments once, before either dispatch path.
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    if not any_symbolic_tensors((inputs,)):
        return backend.nn.conv_transpose(
            inputs,
            kernel,
            strides,
            padding,
            output_padding,
            data_format,
            dilation_rate,
        )
    return ConvTranspose(
        strides, padding, output_padding, data_format, dilation_rate
    ).symbolic_call(inputs, kernel)
1651
+
1652
+
1653
class OneHot(Operation):
    """Symbolic-graph wrapper for one-hot encoding.

    Stores the number of classes, target axis, output dtype (defaulting to
    the backend floatx) and whether a sparse tensor is requested.
    """

    def __init__(self, num_classes, axis=-1, dtype=None, sparse=False):
        super().__init__()
        self.num_classes = num_classes
        self.axis = axis
        # Resolve the dtype eagerly so compute_output_spec can report it.
        self.dtype = dtype or backend.floatx()
        self.sparse = sparse

    def call(self, x):
        # Eager path: defer to the backend implementation.
        return backend.nn.one_hot(
            x,
            self.num_classes,
            axis=self.axis,
            dtype=self.dtype,
            sparse=self.sparse,
        )

    def compute_output_spec(self, x):
        # Output gains one dimension of size num_classes at `axis`;
        # scalar inputs (no `shape` attribute) become rank-1 outputs.
        x_shape = list(getattr(x, "shape", []))
        if self.axis == -1:
            x_shape.append(self.num_classes)
        elif self.axis >= 0 and self.axis < len(x_shape):
            x_shape.insert(self.axis, self.num_classes)
        else:
            raise ValueError(
                f"axis must be -1 or between [0, {len(x.shape)}), but "
                f"received {self.axis}."
            )
        return KerasTensor(x_shape, dtype=self.dtype, sparse=self.sparse)
1682
+
1683
+
1684
@keras_export(["keras.ops.one_hot", "keras.ops.nn.one_hot"])
def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
    """Converts integer tensor `x` into a one-hot tensor.

    Each integer value becomes a binary vector of length `num_classes` with
    a 1 at the index equal to the value and 0 everywhere else.

    Args:
        x: Integer tensor to be encoded. The shape can be
            arbitrary, but the dtype should be integer.
        num_classes: Number of classes for the one-hot encoding.
        axis: Axis along which the encoding is performed.
            `-1` represents the last axis. Defaults to `-1`.
        dtype: (Optional) Data type of the output tensor. If not
            provided, it defaults to the default data type of the backend.
        sparse: Whether to return a sparse tensor; for backends that support
            sparse tensors.

    Returns:
        One-hot encoded tensor with the same shape as `x` except for the
        `axis` dimension, which has length `num_classes`.

    Example:

    >>> x = keras.ops.convert_to_tensor([1, 3, 2, 0])
    >>> one_hot(x, num_classes=4)
    array([[0. 1. 0. 0.]
        [0. 0. 0. 1.]
        [0. 0. 1. 0.]
        [1. 0. 0. 0.]], shape=(4, 4), dtype=float32)
    """
    # Concrete tensors go straight to the backend (resolving the default
    # dtype here); symbolic tensors are routed through the Operation wrapper.
    if not any_symbolic_tensors((x,)):
        return backend.nn.one_hot(
            x,
            num_classes,
            axis=axis,
            dtype=dtype or backend.floatx(),
            sparse=sparse,
        )
    return OneHot(
        num_classes, axis=axis, dtype=dtype, sparse=sparse
    ).symbolic_call(x)
1730
+
1731
+
1732
class BinaryCrossentropy(Operation):
    """Symbolic-graph wrapper for elementwise binary cross-entropy."""

    def __init__(self, from_logits=False):
        super().__init__()
        self.from_logits = from_logits

    def call(self, target, output):
        # Eager path: defer to the backend implementation.
        return backend.nn.binary_crossentropy(
            target, output, from_logits=self.from_logits
        )

    def compute_output_spec(self, target, output):
        # The loss is computed elementwise, so both operands must agree in
        # shape and the result keeps `output`'s shape/dtype.
        if target.shape != output.shape:
            raise ValueError(
                "Arguments `target` and `output` must have the same shape. "
                "Received: "
                f"target.shape={target.shape}, output.shape={output.shape}"
            )
        return KerasTensor(output.shape, dtype=output.dtype)
1750
+
1751
+
1752
@keras_export(
    [
        "keras.ops.binary_crossentropy",
        "keras.ops.nn.binary_crossentropy",
    ]
)
def binary_crossentropy(target, output, from_logits=False):
    """Computes binary cross-entropy loss between target and output tensor.

    Commonly used for binary classification; measures the dissimilarity
    between the target labels and the predicted probabilities or logits.

    Args:
        target: The target tensor representing the true binary labels.
            Its shape should match the shape of the `output` tensor.
        output: The output tensor representing the predicted probabilities
            or logits. Its shape should match the shape of the
            `target` tensor.
        from_logits: (optional) Whether `output` is a tensor of logits or
            probabilities. Set it to `True` if `output` represents logits;
            otherwise, set it to `False` if `output` represents
            probabilities. Defaults to `False`.

    Returns:
        The computed binary cross-entropy loss between `target` and
        `output`, with the same shape.

    Example:

    >>> target = keras.ops.convert_to_tensor([0, 1, 1, 0])
    >>> output = keras.ops.convert_to_tensor([0.1, 0.9, 0.8, 0.2])
    >>> binary_crossentropy(target, output)
    array([0.10536054 0.10536054 0.22314355 0.22314355],
        shape=(4,), dtype=float32)
    """
    # Concrete tensors go straight to the backend; any symbolic operand
    # routes the call through the Operation wrapper.
    if not any_symbolic_tensors((target, output)):
        return backend.nn.binary_crossentropy(
            target, output, from_logits=from_logits
        )
    return BinaryCrossentropy(from_logits=from_logits).symbolic_call(
        target, output
    )
1797
+
1798
+
1799
class CategoricalCrossentropy(Operation):
    """Symbolic op wrapping `backend.nn.categorical_crossentropy`."""

    def __init__(self, from_logits=False, axis=-1):
        super().__init__()
        self.from_logits = from_logits
        self.axis = axis

    def call(self, target, output):
        # Eager path: delegate to the active backend implementation.
        return backend.nn.categorical_crossentropy(
            target, output, from_logits=self.from_logits, axis=self.axis
        )

    def compute_output_spec(self, target, output):
        # Inputs must match exactly (targets are one-hot encoded) and be
        # at least rank 1; the class dimension is reduced away.
        shapes_match = target.shape == output.shape
        if not shapes_match:
            raise ValueError(
                "Arguments `target` and `output` must have the same shape. "
                "Received: "
                f"target.shape={target.shape}, output.shape={output.shape}"
            )
        if len(target.shape) < 1:
            raise ValueError(
                "Arguments `target` and `output` must be at least rank 1. "
                "Received: "
                f"target.shape={target.shape}, output.shape={output.shape}"
            )
        return KerasTensor(output.shape[:-1], dtype=output.dtype)
1824
+
1825
+
1826
@keras_export(
    [
        "keras.ops.categorical_crossentropy",
        "keras.ops.nn.categorical_crossentropy",
    ]
)
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Computes categorical cross-entropy loss between target and output tensor.

    The categorical cross-entropy loss is commonly used in multi-class
    classification tasks where each input sample can belong to one of
    multiple classes. It measures the dissimilarity
    between the target and output probabilities or logits.

    Args:
        target: The target tensor representing the true categorical labels,
            one-hot encoded along the class axis. Its shape must match the
            shape of the `output` tensor.
        output: The output tensor representing the predicted probabilities
            or logits. Its shape must match the shape of the `target`
            tensor.
        from_logits: (optional) Whether `output` is a tensor of logits or
            probabilities.
            Set it to `True` if `output` represents logits; otherwise,
            set it to `False` if `output` represents probabilities.
            Defaults to `False`.
        axis: (optional) The axis along which the categorical cross-entropy
            is computed.
            Defaults to `-1`, which corresponds to the last dimension of
            the tensors.

    Returns:
        A tensor of floats holding the computed categorical cross-entropy
        loss between `target` and `output`, with the class dimension
        reduced away.

    Example:

    >>> target = keras.ops.convert_to_tensor(
    ...     [[1, 0, 0],
    ...      [0, 1, 0],
    ...      [0, 0, 1]])
    >>> output = keras.ops.convert_to_tensor(
    ...     [[0.9, 0.05, 0.05],
    ...      [0.1, 0.8, 0.1],
    ...      [0.2, 0.3, 0.5]])
    >>> categorical_crossentropy(target, output)
    array([0.10536054 0.22314355 0.6931472 ], shape=(3,), dtype=float32)
    """
    # Symbolic inputs are traced through the Operation node.
    if any_symbolic_tensors((target, output)):
        return CategoricalCrossentropy(
            from_logits=from_logits, axis=axis
        ).symbolic_call(target, output)
    return backend.nn.categorical_crossentropy(
        target, output, from_logits=from_logits, axis=axis
    )
1881
+
1882
+
1883
class SparseCategoricalCrossentropy(Operation):
    """Symbolic op wrapping `backend.nn.sparse_categorical_crossentropy`."""

    def __init__(self, from_logits=False, axis=-1):
        super().__init__()
        self.from_logits = from_logits
        self.axis = axis

    def call(self, target, output):
        # Eager path: delegate to the active backend implementation.
        return backend.nn.sparse_categorical_crossentropy(
            target, output, from_logits=self.from_logits, axis=self.axis
        )

    def compute_output_spec(self, target, output):
        if len(output.shape) < 1:
            raise ValueError(
                "Argument `output` must be at least rank 1. "
                "Received: "
                f"output.shape={output.shape}"
            )
        # A trailing singleton label dimension on `target` is tolerated
        # and implicitly squeezed before comparing shapes.
        target_shape = target.shape
        has_trailing_one = (
            len(target_shape) == len(output.shape) and target_shape[-1] == 1
        )
        if has_trailing_one:
            target_shape = target_shape[:-1]
        if target_shape != output.shape[:-1]:
            raise ValueError(
                "Arguments `target` and `output` must have the same shape "
                "up until the last dimension: "
                f"target.shape={target.shape}, output.shape={output.shape}"
            )
        return KerasTensor(output.shape[:-1], dtype=output.dtype)
1911
+
1912
+
1913
@keras_export(
    [
        "keras.ops.sparse_categorical_crossentropy",
        "keras.ops.nn.sparse_categorical_crossentropy",
    ]
)
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Computes sparse categorical cross-entropy loss.

    The sparse categorical cross-entropy loss is similar to categorical
    cross-entropy, but it is used when the target tensor contains integer
    class labels instead of one-hot encoded vectors. It measures the
    dissimilarity between the target and output probabilities or logits.

    Args:
        target: The target tensor representing the true class labels as
            integers. Its shape should match the shape of the `output`
            tensor except for the last dimension.
        output: The output tensor representing the predicted probabilities
            or logits.
            Its shape should match the shape of the `target` tensor except
            for the last dimension.
        from_logits: (optional) Whether `output` is a tensor of logits
            or probabilities.
            Set it to `True` if `output` represents logits; otherwise,
            set it to `False` if `output` represents probabilities.
            Defaults to `False`.
        axis: (optional) The axis along which the sparse categorical
            cross-entropy is computed.
            Defaults to `-1`, which corresponds to the last dimension
            of the tensors.

    Returns:
        A tensor of floats holding the computed sparse categorical
        cross-entropy loss between `target` and `output`, with the class
        dimension reduced away.

    Example:

    >>> target = keras.ops.convert_to_tensor([0, 1, 2], dtype="int32")
    >>> output = keras.ops.convert_to_tensor(
    ...     [[0.9, 0.05, 0.05],
    ...      [0.1, 0.8, 0.1],
    ...      [0.2, 0.3, 0.5]])
    >>> sparse_categorical_crossentropy(target, output)
    array([0.10536056 0.22314355 0.6931472 ], shape=(3,), dtype=float32)
    """
    # Symbolic inputs are traced through the Operation node.
    if any_symbolic_tensors((target, output)):
        return SparseCategoricalCrossentropy(
            from_logits=from_logits, axis=axis
        ).symbolic_call(target, output)
    return backend.nn.sparse_categorical_crossentropy(
        target, output, from_logits=from_logits, axis=axis
    )
1966
+
1967
+
1968
class MultiHot(Operation):
    """Symbolic op wrapping `backend.nn.multi_hot`.

    Accepts the legacy keyword `num_tokens` as an alias for `num_classes`.
    """

    def __init__(
        self, num_classes=None, axis=-1, dtype=None, sparse=False, **kwargs
    ):
        # Backward compatibility: `num_tokens` is the legacy alias.
        if num_classes is None and "num_tokens" in kwargs:
            num_classes = kwargs.pop("num_tokens")
        if num_classes is None:
            raise ValueError("Argument `num_classes` must be specified.")
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.axis = axis
        self.dtype = dtype or backend.floatx()
        self.sparse = sparse

    def call(self, inputs):
        return backend.nn.multi_hot(
            inputs,
            num_classes=self.num_classes,
            axis=self.axis,
            dtype=self.dtype,
            # Fix: forward `sparse` to the backend. Previously it was
            # dropped on the eager path, so the setting only took effect
            # in the symbolic output spec. The functional wrapper already
            # passes `sparse` through, confirming the backend accepts it.
            sparse=self.sparse,
        )

    def compute_output_spec(self, inputs):
        x_shape = list(getattr(inputs, "shape", []))
        # First insert the class dimension at the requested axis...
        if self.axis == -1:
            x_shape.append(self.num_classes)
        elif self.axis >= 0 and self.axis < len(x_shape):
            x_shape.insert(self.axis, self.num_classes)
        else:
            raise ValueError(
                f"axis must be -1 or between [0, {len(inputs.shape)}), but "
                f"received {self.axis}."
            )

        # ...then drop the original label dimension, which multi-hot
        # reduces over (the second dimension for batched inputs).
        if len(x_shape) == 2:
            x_shape = [x_shape[-1]]
        else:
            x_shape = [x_shape[0]] + x_shape[2:]

        # NOTE(review): the spec keeps `inputs.dtype` while `call()` casts
        # to `self.dtype` — confirm upstream whether the spec should use
        # `self.dtype` instead.
        return KerasTensor(x_shape, dtype=inputs.dtype, sparse=self.sparse)
2008
+
2009
+
2010
@keras_export(
    [
        "keras.ops.multi_hot",
        "keras.ops.nn.multi_hot",
    ]
)
def multi_hot(
    inputs, num_classes=None, axis=-1, dtype=None, sparse=False, **kwargs
):
    """Encodes integer labels as multi-hot vectors.

    This function encodes integer labels as multi-hot vectors, where each label
    is mapped to a binary value in the resulting vector.

    Args:
        inputs: Tensor of integer labels to be converted to multi-hot vectors.
        num_classes: Integer, the total number of unique classes.
        axis: (optional) Axis along which the multi-hot encoding should be
            added. Defaults to `-1`, which corresponds to the last dimension.
        dtype: (optional) The data type of the resulting tensor. Default
            is backend's float type.
        sparse: Whether to return a sparse tensor; for backends that support
            sparse tensors.

    Returns:
        Tensor: The multi-hot encoded tensor.

    Example:

    >>> data = keras.ops.convert_to_tensor([0, 4])
    >>> keras.ops.multi_hot(data, num_classes=5)
    array([1.0, 0.0, 0.0, 0.0, 1.0], dtype=float32)

    """
    # Backward compatibility: accept the legacy `num_tokens` alias.
    if num_classes is None:
        num_classes = kwargs.pop("num_tokens", None)
    if num_classes is None:
        raise ValueError("Argument `num_classes` must be specified.")

    if not any_symbolic_tensors((inputs,)):
        return backend.nn.multi_hot(inputs, num_classes, axis, dtype, sparse)

    op = MultiHot(num_classes, axis, dtype, sparse)
    return op.symbolic_call(inputs)
2053
+
2054
+
2055
class Moments(Operation):
    """Symbolic op wrapping `backend.nn.moments`."""

    def __init__(self, axes, keepdims=False, synchronized=False):
        super().__init__()
        self.axes = axes
        self.keepdims = keepdims
        self.synchronized = synchronized

    def call(self, x):
        return backend.nn.moments(
            x,
            axes=self.axes,
            keepdims=self.keepdims,
            synchronized=self.synchronized,
        )

    def compute_output_spec(self, x):
        # Mean and variance share the reduced shape, so compute it once.
        reduced = reduce_shape(
            x.shape, axis=self.axes, keepdims=self.keepdims
        )
        mean_spec = KerasTensor(reduced, dtype=x.dtype)
        variance_spec = KerasTensor(reduced, dtype=x.dtype)
        return (mean_spec, variance_spec)
2081
+
2082
+
2083
@keras_export(
    [
        "keras.ops.moments",
        "keras.ops.nn.moments",
    ]
)
def moments(x, axes, keepdims=False, synchronized=False):
    """Calculates the mean and variance of `x`.

    The mean and variance are calculated by aggregating the contents of `x`
    across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean and
    variance of a vector.

    Args:
        x: Input tensor.
        axes: A list of axes which to compute mean and variance.
        keepdims: If this is set to `True`, the axes which are reduced are left
            in the result as dimensions with size one.
        synchronized: Only applicable with the TensorFlow backend.
            If `True`, synchronizes the global batch statistics (mean and
            variance) across all devices at each training step in a
            distributed training strategy. If `False`, each replica uses its own
            local batch statistics.

    Returns:
        A tuple containing two tensors - mean and variance.

    Example:

    >>> x = keras.ops.convert_to_tensor([0, 1, 2, 3, 100], dtype="float32")
    >>> keras.ops.moments(x, axes=[0])
    (array(21.2, dtype=float32), array(1553.3601, dtype=float32))

    """
    # Eager inputs go straight to the backend; symbolic inputs are traced.
    if not any_symbolic_tensors((x,)):
        return backend.nn.moments(x, axes, keepdims, synchronized=synchronized)

    op = Moments(axes, keepdims, synchronized=synchronized)
    return op.symbolic_call(x)
2123
+
2124
+
2125
class BatchNorm(Operation):
    """Symbolic op wrapping `backend.nn.batch_normalization`."""

    def __init__(self, axis, epsilon):
        super().__init__()
        self.axis = axis
        self.epsilon = epsilon

    def _check_shape(self, name, shape, expected_shape):
        # Every statistics vector must have length `x.shape[axis]`.
        if shape != expected_shape:
            raise ValueError(
                f"Arguments `{name}` must be a vector of length "
                f"`x.shape[axis]`. Expected: `{expected_shape}`. "
                # Fix: closing backtick was missing after the shape.
                f"Received: `{shape}`."
            )

    def compute_output_spec(self, x, mean, variance, offset, scale):
        shape = (x.shape[self.axis],)
        self._check_shape("mean", tuple(mean.shape), shape)
        self._check_shape("variance", tuple(variance.shape), shape)
        if offset is not None:
            self._check_shape("offset", tuple(offset.shape), shape)
        # Fix: the original `if offset is not scale:` identity test raised
        # an AttributeError when `scale` was None (None has no `.shape`)
        # and skipped validation when `offset` and `scale` aliased the
        # same object. `scale` must be checked iff it is provided.
        if scale is not None:
            self._check_shape("scale", tuple(scale.shape), shape)
        return KerasTensor(x.shape, dtype=x.dtype)
2148
+
2149
+
2150
@keras_export(
    [
        "keras.ops.batch_normalization",
        "keras.ops.nn.batch_normalization",
    ]
)
def batch_normalization(
    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
):
    """Normalizes `x` by `mean` and `variance`.

    This op is typically used by the batch normalization step in a neural
    network. It normalizes the input tensor along the given axis.

    Args:
        x: Input tensor.
        mean: A mean vector of the same length as the `axis` dimension of the
            input tensor.
        variance: A variance vector of the same length as the `axis` dimension
            of the input tensor.
        axis: Integer, the axis that should be normalized.
        offset: An offset vector of the same length as the `axis` dimension of
            the input tensor. If not `None`, `offset` is added to the normalized
            tensor. Defaults to `None`.
        scale: A scale vector of the same length as the `axis` dimension of the
            input tensor. If not `None`, the normalized tensor is multiplied by
            `scale`. Defaults to `None`.
        epsilon: Small float added to variance to avoid dividing by zero.
            Defaults to 1e-3.

    Returns:
        The normalized tensor.

    Example:

    >>> x = keras.ops.convert_to_tensor(
    ...     [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
    ... )
    >>> keras.ops.batch_normalization(
    ...     x,
    ...     mean=[0.4, 0.5, 0.6],
    ...     variance=[0.67, 0.67, 0.67],
    ...     axis=-1
    ... )
    array([[-3.6624e-01, -3.6624e-01, -3.6624e-01],
           [-4.6445e-09,  0.0000e+00, -1.8578e-08],
           [ 3.6624e-01,  3.6624e-01,  3.6624e-01]])

    """
    # Symbolic inputs are traced through the Operation node.
    if any_symbolic_tensors((x, mean, variance, offset, scale)):
        return BatchNorm(axis, epsilon).symbolic_call(
            x, mean, variance, offset, scale
        )

    return backend.nn.batch_normalization(
        x, mean, variance, axis, offset, scale, epsilon
    )
2207
+
2208
+
2209
class CTCLoss(Operation):
    """Symbolic op wrapping `backend.nn.ctc_loss`."""

    def __init__(self, mask_index=0):
        super().__init__()
        self.mask_index = mask_index

    def call(self, target, output, target_length, output_length):
        return backend.nn.ctc_loss(
            target, output, target_length, output_length, self.mask_index
        )

    def _check_shape_first_dim(self, name1, shape1, name2, shape2):
        # All inputs must share the same leading (batch) dimension.
        batch_dims_agree = shape1[0] == shape2[0]
        if not batch_dims_agree:
            raise ValueError(
                f"Arguments `{name1}` and `{name2}` must have the same "
                "first dimension. "
                f"Received shapes: `{shape1}` and `{shape2}`."
            )

    def compute_output_spec(self, target, output, target_length, output_length):
        # Validate batch-dimension agreement for each dependent pair.
        checks = (
            ("target", target.shape, "output", output.shape),
            ("target_length", target_length.shape, "target", target.shape),
            ("output_length", output_length.shape, "output", output.shape),
        )
        for name1, shape1, name2, shape2 in checks:
            self._check_shape_first_dim(name1, shape1, name2, shape2)
        # The loss is a float per batch element.
        dtype = backend.result_type(output.dtype, "float32")
        return KerasTensor((target.shape[0],), dtype=dtype)
2239
+
2240
+
2241
@keras_export(
    [
        "keras.ops.ctc_loss",
        "keras.ops.nn.ctc_loss",
    ]
)
def ctc_loss(target, output, target_length, output_length, mask_index=0):
    """CTC (Connectionist Temporal Classification) loss.

    Args:
        target: A tensor of shape `(batch_size, max_length)` containing
            the true labels in integer format.
        output: A tensor of shape `(batch_size, max_length, num_classes)`
            containing logits (the output of your model).
        target_length: A tensor of shape `(batch_size,)` containing the
            true label lengths.
        output_length: A tensor of shape `(batch_size,)` containing the
            output lengths.
        mask_index: The index of the mask character in the vocabulary.
            Defaults to `0`.

    Returns:
        A float tensor of shape `(batch_size,)` holding the CTC loss for
        each batch element.
    """
    # Symbolic inputs are traced through the Operation node.
    if any_symbolic_tensors((target, output, target_length, output_length)):
        return CTCLoss(mask_index).symbolic_call(
            target, output, target_length, output_length
        )
    return backend.nn.ctc_loss(
        target, output, target_length, output_length, mask_index
    )
2270
+
2271
+
2272
class CTCDecode(Operation):
    """Symbolic op wrapping `backend.nn.ctc_decode`."""

    def __init__(
        self,
        strategy="greedy",
        beam_width=100,
        top_paths=1,
        merge_repeated=True,
        mask_index=0,
    ):
        super().__init__()
        self.strategy = strategy
        self.beam_width = beam_width
        self.top_paths = top_paths
        self.merge_repeated = merge_repeated
        self.mask_index = mask_index

    def call(self, inputs, sequence_lengths):
        return backend.nn.ctc_decode(
            inputs,
            sequence_lengths,
            strategy=self.strategy,
            beam_width=self.beam_width,
            top_paths=self.top_paths,
            merge_repeated=self.merge_repeated,
            mask_index=self.mask_index,
        )

    def compute_output_spec(self, inputs, sequence_lengths):
        inputs_shape = inputs.shape
        # Greedy decoding always yields a single path regardless of
        # the configured `top_paths`.
        top_paths = 1 if self.strategy == "greedy" else self.top_paths
        dtype = backend.result_type(inputs.dtype, "float32")
        decoded_spec = KerasTensor(
            (top_paths, inputs_shape[0], inputs_shape[1]), dtype="int32"
        )
        scores_spec = KerasTensor((inputs_shape[0], top_paths), dtype=dtype)
        return (decoded_spec, scores_spec)
2312
+
2313
+
2314
@keras_export(
    [
        "keras.ops.ctc_decode",
        "keras.ops.nn.ctc_decode",
    ]
)
def ctc_decode(
    inputs,
    sequence_lengths,
    strategy="greedy",
    beam_width=100,
    top_paths=1,
    merge_repeated=True,
    mask_index=0,
):
    """Decodes the output of a CTC model.

    Args:
        inputs: A tensor of shape `(batch_size, max_length, num_classes)`
            containing the logits (the output of the model).
            They should *not* be normalized via softmax.
        sequence_lengths: A tensor of shape `(batch_size,)` containing the
            sequence lengths for the batch.
        strategy: A string for the decoding strategy. Supported values are
            `"greedy"` and `"beam_search"`.
        beam_width: An integer scalar beam width used in beam search.
            Defaults to 100.
        top_paths: An integer scalar, the number of top paths to return.
            Defaults to 1.
        merge_repeated: A boolean scalar, whether to merge repeated
            labels in the output. Defaults to `True`.
        mask_index: An integer scalar, the index of the mask character in
            the vocabulary. Defaults to `0`.

    Returns:
        A tuple containing:
        - The tensor representing the list of decoded sequences. If
          `strategy="greedy"`, the shape is `(1, batch_size, max_length)`. If
          `strategy="beam_search"`, the shape is
          `(top_paths, batch_size, max_length)`. Note that: `-1` indicates the
          blank label.
        - If `strategy="greedy"`, a tensor of shape `(batch_size, 1)`
          representing the negative of the sum of the probability logits for
          each sequence. If `strategy="beam_search"`, a tensor of shape
          `(batch_size, top_paths)` representing the log probability for each
          sequence.
    """
    # Symbolic inputs are traced through the Operation node.
    if any_symbolic_tensors((inputs, sequence_lengths)):
        return CTCDecode(
            strategy=strategy,
            beam_width=beam_width,
            top_paths=top_paths,
            merge_repeated=merge_repeated,
            mask_index=mask_index,
        ).symbolic_call(inputs, sequence_lengths)
    return backend.nn.ctc_decode(
        inputs=inputs,
        sequence_lengths=sequence_lengths,
        strategy=strategy,
        beam_width=beam_width,
        top_paths=top_paths,
        merge_repeated=merge_repeated,
        mask_index=mask_index,
    )
2379
+
2380
+
2381
class Normalize(Operation):
    """Symbolic op wrapping the module-level `_normalize` helper."""

    def __init__(self, axis=-1, order=2, epsilon=None):
        super().__init__()
        self.axis = axis
        self.order = order
        self.epsilon = epsilon

    def call(self, x):
        return _normalize(
            x, axis=self.axis, order=self.order, epsilon=self.epsilon
        )

    def compute_output_spec(self, x):
        # Normalization rescales elementwise, so the shape is preserved.
        # NOTE(review): dtype is not forwarded here, so the spec uses the
        # KerasTensor default dtype — confirm this is intentional.
        return KerasTensor(shape=x.shape)
2395
+
2396
+
2397
@keras_export(
    [
        "keras.ops.normalize",
        "keras.ops.nn.normalize",
    ]
)
def normalize(x, axis=-1, order=2, epsilon=None):
    """Normalizes `x` over the specified axis.

    It is defined as: `normalize(x) = x / max(norm(x), epsilon)`.

    Args:
        x: Input tensor.
        axis: The axis or axes along which to perform normalization.
            Default to -1.
        order: The exponent value in the norm formulation.
            Defaults to 2.
        epsilon: A lower bound value for the norm.
            Defaults to `backend.epsilon()`.

    Returns:
        The normalized array.

    Example:

    >>> x = keras.ops.convert_to_tensor([[1, 2, 3], [4, 5, 6]])
    >>> x_norm = keras.ops.math.normalize(x)
    >>> print(x_norm)
    array([[0.26726124 0.5345225  0.8017837 ]
           [0.45584232 0.5698029  0.68376344]], shape=(2, 3), dtype=float32)

    """
    # Eager inputs are handled directly; symbolic inputs are traced.
    if not any_symbolic_tensors((x,)):
        return _normalize(x, axis=axis, order=order, epsilon=epsilon)

    op = Normalize(axis=axis, order=order, epsilon=epsilon)
    return op.symbolic_call(x)
2434
+
2435
+
2436
def _normalize(x, axis=-1, order=2, epsilon=None):
    """Backend implementation of `normalize`.

    Divides `x` by its L-`order` norm along `axis`, with the norm bounded
    below by `epsilon` to avoid division by zero.

    Args:
        x: Input tensor (a scalar is promoted to rank 1).
        axis: Axis or axes to normalize over. Defaults to `-1`.
        order: Integer norm order, must be `>= 1`. Defaults to `2`.
        epsilon: Lower bound for the norm; `backend.epsilon()` when `None`.

    Returns:
        The normalized tensor.

    Raises:
        ValueError: If `order` is not an int `>= 1`.
    """
    if not isinstance(order, int) or order < 1:
        raise ValueError(
            f"Argument `order` must be an int >= 1. Received: order={order}"
        )
    x = backend.convert_to_tensor(x)
    if len(x.shape) == 0:
        x = backend.numpy.expand_dims(x, axis=0)
    if epsilon is None:
        epsilon = backend.epsilon()
    if order == 2:
        # A special case: L2 normalization with `x * rsqrt(...)`
        # instead of `x / sqrt(...)`.
        square_sum = backend.numpy.sum(
            backend.numpy.square(x), axis=axis, keepdims=True
        )
        inv_norm = backend.math.rsqrt(square_sum)
        # Clamp the reciprocal so the effective norm never drops below
        # `epsilon` (equivalent to `max(norm, epsilon)` in the divide).
        inv_norm = backend.numpy.minimum(inv_norm, 1.0 / epsilon)
        return x * inv_norm
    norm = backend.linalg.norm(x, ord=order, axis=axis, keepdims=True)
    denom = backend.numpy.maximum(norm, epsilon)
    return backend.numpy.divide(x, denom)
2458
+
2459
+
2460
class PSNR(Operation):
    """Symbolic op wrapping `backend.nn.psnr`."""

    def __init__(self, max_val):
        super().__init__()
        self.max_val = max_val

    def call(self, x1, x2):
        return backend.nn.psnr(x1=x1, x2=x2, max_val=self.max_val)

    def compute_output_spec(self, x1, x2):
        # PSNR reduces both inputs to a single scalar; the inputs must at
        # least agree in rank.
        ranks_differ = len(x1.shape) != len(x2.shape)
        if ranks_differ:
            raise ValueError("Inputs must have the same rank")

        return KerasTensor(shape=())
2480
+
2481
+
2482
@keras_export(
    [
        "keras.ops.psnr",
        "keras.ops.nn.psnr",
    ]
)
def psnr(x1, x2, max_val):
    """Peak Signal-to-Noise Ratio (PSNR) function.

    This function computes the Peak Signal-to-Noise Ratio between two signals,
    `x1` and `x2`. PSNR is a measure of the quality of a reconstructed signal.
    The higher the PSNR, the closer the reconstructed signal is to the original
    signal. Note that it can become negative when the signal power is
    smaller than the noise power.

    Args:
        x1: The first input signal.
        x2: The second input signal. Must have the same shape as `x1`.
        max_val: The maximum possible value in the signals.

    Returns:
        float: The PSNR value between `x1` and `x2`.

    Examples:

    >>> x1 = keras.random.normal((2, 4, 4, 3))
    >>> x2 = keras.random.normal((2, 4, 4, 3))
    >>> max_val = 1.0
    >>> keras.ops.nn.psnr(x1, x2, max_val)
    -3.1697404
    """
    # Symbolic inputs are traced through the Operation node.
    if any_symbolic_tensors((x1, x2)):
        return PSNR(max_val).symbolic_call(x1, x2)
    return backend.nn.psnr(x1, x2, max_val)
2531
+
2532
+
2533
class DotProductAttention(Operation):
    """Symbolic op wrapping `backend.nn.dot_product_attention`."""

    def __init__(self, is_causal=False):
        super().__init__()
        self.is_causal = is_causal

    def call(
        self,
        query,
        key,
        value,
        bias=None,
        mask=None,
        scale=None,
        flash_attention=None,
    ):
        # Forward the optional arguments as a single keyword bundle.
        options = dict(
            bias=bias,
            mask=mask,
            scale=scale,
            is_causal=self.is_causal,
            flash_attention=flash_attention,
        )
        return backend.nn.dot_product_attention(query, key, value, **options)

    def compute_output_spec(
        self,
        query,
        key,
        value,
        bias=None,
        mask=None,
        scale=None,
        flash_attention=None,
    ):
        # The attention output always matches the query's shape and dtype.
        return KerasTensor(query.shape, dtype=query.dtype)
2570
+
2571
+
2572
@keras_export(
    ["keras.ops.dot_product_attention", "keras.ops.nn.dot_product_attention"]
)
def dot_product_attention(
    query,
    key,
    value,
    bias=None,
    mask=None,
    scale=None,
    is_causal=False,
    flash_attention=None,
):
    """Scaled dot product attention function.

    Computes the attention function on Q (`query`), K (`key`), and V(`value`):
    `attention(Q, K, V) = softmax(Q * K / sqrt(d)) * V`. We refer to the
    output of `Q * K` as `logits` and to the output of `softmax` as `probs`.

    Throughout this function, we use the following notation to represent the
    shapes of the arrays:
    - B: batch size
    - S: length of the key/value
    - T: length of the query
    - N: number of attention heads
    - H: dimensions of each attention head
    - K: number of key/value heads
    - G: number of groups, which equals to `N // K`

    Args:
        query: The query array with the shape of `(B, T, N, H)`.
        key: The key array with the shape of `(B, S, K, H)`. When `K` equals
            `N`, multi-headed attention (MHA) is performed. Otherwise, grouped
            query attention (GQA) is performed if `N` is a multiple of `K`. and
            multi-query attention (MQA) is performed if `K==1` (a special case
            of GQA).
        value: The value array with the same shape of `key`.
        bias: Optional bias array to be added to logits. The shape must be
            broadcastable to `(B, N, T, S)`.
        mask: Optional mask array used to filter out logits. It is a boolean
            mask where `True` indicates the element should take part in
            attention. For an additive mask, users should pass it to bias. The
            shape must be broadcastable to `(B, N, T, S)`.
        scale: Optional scale for the logits. If `None`, the scale will be set
            to `1.0 / sqrt(H)`.
        is_causal: Whether to apply causal mask.
        flash_attention: Whether to use flash attention. If `None`, it will
            attempt to use flash attention if the required conditions are met.
            Typically, the inputs must be in float16 and bfloat16 dtype and the
            input layout requirements may vary depending on the backend.

    Returns:
        An array of the attention output with the same shape of `query`.

    Example:

    >>> query = keras.random.normal((2, 4, 8, 16))
    >>> key = keras.random.normal((2, 6, 8, 16))
    >>> value = keras.random.normal((2, 6, 8, 16))
    >>> keras.ops.nn.dot_product_attention(query, key, value).shape
    (2, 4, 8, 16)
    """
    # Symbolic inputs are traced through the Operation node.
    if any_symbolic_tensors((query, key, value)):
        return DotProductAttention(is_causal=is_causal).symbolic_call(
            query,
            key,
            value,
            bias=bias,
            mask=mask,
            scale=scale,
            flash_attention=flash_attention,
        )
    return backend.nn.dot_product_attention(
        query,
        key,
        value,
        bias=bias,
        mask=mask,
        scale=scale,
        is_causal=is_causal,
        flash_attention=flash_attention,
    )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/node.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+
3
+ from keras.src import tree
4
+ from keras.src.backend import KerasTensor
5
+ from keras.src.ops.symbolic_arguments import SymbolicArguments
6
+
7
+
8
class Node:
    """A `Node` describes an operation `__call__()` event.

    A Keras Function is a DAG with `Node` instances as nodes, and
    `KerasTensor` instances as edges. Nodes aren't `Operation` instances,
    because a single operation could be called multiple times, which would
    result in graph cycles.

    A `__call__()` event involves input tensors (and other input arguments),
    the operation that was called, and the resulting output tensors.
    A `Node` will include all this information.

    Since a single `Operation` could be called multiple times,
    the `Node` instances are stored on operations as a list.
    Each time an operation is called, a node is added to `op._inbound_nodes`.
    Each time the output of an operation is used by another operation,
    a node is added to `op._outbound_nodes`.

    Every `KerasTensor` instance has a `KerasHistory` object attached,
    which tracks the `Node` that records the `__call__()` event that created
    the tensor. By recursively walking through `Node` instances
    via the `KerasHistory` metadata of `KerasTensor` instances, one can
    retrieve the entire DAG of a Keras Function.

    Args:
        operation: The Operation that was called in the `op.__call__()`
            event that this node represents.
        call_args: The positional arguments the operation was called with.
        call_kwargs: The keyword arguments the operation was called with.
        outputs: The output tensors of the `op.__call__()` call.
    """

    def __init__(
        self, operation, call_args=None, call_kwargs=None, outputs=None
    ):
        self.operation = operation
        # NOTE(review): `call_args`/`call_kwargs` default to `None` but are
        # unpacked with `*`/`**` below, which would raise a TypeError if the
        # defaults were actually used — callers appear to always pass both.
        self.arguments = SymbolicArguments(*call_args, **call_kwargs)
        # Flatten outputs so that `tensor_index` in `KerasHistory` below is
        # well defined for nested output structures.
        self.outputs = [] if outputs is None else tree.flatten(outputs)
        for x in self.outputs:
            if not isinstance(x, KerasTensor):
                raise ValueError(
                    "All operation outputs must be tensors. "
                    f"Operation {operation} returned a non-tensor. "
                    f"Non-tensor received: {x}"
                )

        # True when any input tensor opted out of history tracking; in that
        # case no `KerasHistory` metadata is written, on inputs or outputs.
        zero_history = any(
            not x.record_history for x in self.arguments.keras_tensors
        )

        # If inputs don't have metadata yet, add it.
        if not zero_history:
            for tensor in self.arguments.keras_tensors:
                if not hasattr(tensor, "_keras_history"):
                    # `operation=None` marks a graph entry point; the wiring
                    # loop below skips entry points when recording outbound
                    # edges.
                    tensor._keras_history = KerasHistory(
                        operation=None, node_index=0, tensor_index=0
                    )

        # Wire up Node to Operations.
        self.operation._inbound_nodes.append(self)
        for kt in self.arguments.keras_tensors:
            inbound_op = kt._keras_history.operation
            if inbound_op is not None:  # It's a graph entry point.
                inbound_op._outbound_nodes.append(self)

        # Set metadata on outputs.
        if not zero_history:
            # This node was appended above, so it is the last inbound node.
            node_index = len(self.operation._inbound_nodes) - 1
            for i, tensor in enumerate(self.outputs):
                tensor._keras_history = KerasHistory(
                    operation=operation, node_index=node_index, tensor_index=i
                )

        # Whether this is a root node (no symbolic tensors among the inputs).
        self.is_input = not self.arguments.keras_tensors

    def __repr__(self):
        return f"<Node operation={self.operation.name}, id={id(self)}>"

    @property
    def input_tensors(self):
        """The `KerasTensor` inputs of the recorded `__call__()` event."""
        return self.arguments.keras_tensors

    @property
    def output_tensors(self):
        """The `KerasTensor` outputs of the recorded `__call__()` event."""
        return self.outputs

    @property
    def parent_nodes(self):
        """The parent `Node`s.

        Returns:
            all the `Node`s whose output this node immediately depends on.
        """
        node_deps = []
        for kt in self.arguments.keras_tensors:
            op = kt._keras_history.operation
            node_index = kt._keras_history.node_index
            if op is not None:  # `None` for `Input` tensors.
                node_deps.append(op._inbound_nodes[node_index])
        return node_deps
109
+
110
+
111
class KerasHistory(
    collections.namedtuple(
        "KerasHistory", ["operation", "node_index", "tensor_index"]
    )
):
    """Tracks the Operation call that created a Tensor.

    During construction of Keras Functions, this metadata is added to
    each Tensor produced as the output of an Operation.
    This allows Keras to track how each Tensor was produced, and
    this information is later retraced by the `Function` class to
    reconstruct the Operations graph.

    Attributes:
        operation: The Operation instance that produced the Tensor.
            `None` marks a graph entry point (e.g. an `Input` tensor).
        node_index: The specific call to the Operation that produced this Tensor.
            Operations can be called multiple times in order to share weights. A new
            node is created every time an Operation is called. The corresponding
            node that represents the call event that produced the Tensor can be
            found at `op._inbound_nodes[node_index]`.
        tensor_index: The output index for this Tensor.
            Always zero if the Operation that produced this Tensor
            only has one output. Nested structures of
            Tensors are deterministically assigned an index via `tree.flatten`.
    """

    # Added to maintain memory and performance characteristics of `namedtuple`
    # while subclassing.
    __slots__ = ()
140
+
141
+
142
def is_keras_tensor(obj):
    """Return whether `obj` carries Keras functional-graph metadata.

    A tensor produced during symbolic graph construction has a
    `_keras_history` attribute attached; any object without that
    attribute is not part of a Keras graph.
    """
    try:
        obj._keras_history
    except AttributeError:
        return False
    return True
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/numpy.py ADDED
The diff for this file is too large to render. See raw diff
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/operation.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import textwrap
3
+
4
+ from keras.src import backend
5
+ from keras.src import dtype_policies
6
+ from keras.src import tree
7
+ from keras.src.api_export import keras_export
8
+ from keras.src.backend.common.keras_tensor import any_symbolic_tensors
9
+ from keras.src.ops.node import Node
10
+ from keras.src.utils import python_utils
11
+ from keras.src.utils import traceback_utils
12
+ from keras.src.utils.naming import auto_name
13
+
14
+
15
+ @keras_export("keras.Operation")
16
+ class Operation:
17
+ def __init__(self, dtype=None, name=None):
18
+ if name is None:
19
+ name = auto_name(self.__class__.__name__)
20
+ if not isinstance(name, str) or "/" in name:
21
+ raise ValueError(
22
+ "Argument `name` must be a string and "
23
+ "cannot contain character `/`. "
24
+ f"Received: name={name} (of type {type(name)})"
25
+ )
26
+ self._dtype_policy = dtype_policies.get(dtype)
27
+ self.name = name
28
+ self._inbound_nodes = []
29
+ self._outbound_nodes = []
30
+
31
+ @traceback_utils.filter_traceback
32
+ def __call__(self, *args, **kwargs):
33
+ if traceback_utils.is_traceback_filtering_enabled():
34
+ # Wrap self.call to provide helpful info in case of exception
35
+ if any_symbolic_tensors(args, kwargs):
36
+ call_fn = self.symbolic_call
37
+ else:
38
+ if getattr(self, "quantization_mode", None) is not None:
39
+ call_fn = self.quantized_call
40
+ else:
41
+ call_fn = self.call
42
+ call_fn = traceback_utils.inject_argument_info_in_traceback(
43
+ call_fn,
44
+ object_name=(f"{self.__class__.__name__}.call()"),
45
+ )
46
+ return call_fn(*args, **kwargs)
47
+
48
+ # Plain flow.
49
+ if any_symbolic_tensors(args, kwargs):
50
+ return self.symbolic_call(*args, **kwargs)
51
+ if getattr(self, "quantization_mode", None) is not None:
52
+ return self.quantized_call(*args, **kwargs)
53
+ else:
54
+ return self.call(*args, **kwargs)
55
+
56
+ def symbolic_call(self, *args, **kwargs):
57
+ # Perform shape/dtype inference.
58
+ outputs = self.compute_output_spec(*args, **kwargs)
59
+ # Record a new node in the operations graph.
60
+ # The Node wires itself to inbound and outbound ops. The
61
+ # Node constructor updates this op's self._inbound_nodes,
62
+ # sets _keras_history on the outputs, and adds itself to the
63
+ # `_outbound_nodes` of the ops that produced the inputs to this
64
+ # call.
65
+ Node(
66
+ operation=self, call_args=args, call_kwargs=kwargs, outputs=outputs
67
+ )
68
+ return outputs
69
+
70
+ def call(self, *args, **kwargs):
71
+ raise NotImplementedError
72
+
73
+ def quantized_call(self, *args, **kwargs):
74
+ raise NotImplementedError
75
+
76
+ def compute_output_spec(self, *args, **kwargs):
77
+ try:
78
+ return backend.compute_output_spec(self.call, *args, **kwargs)
79
+ except Exception as e:
80
+ new_e = e.__class__(
81
+ "Could not automatically infer the output shape / dtype of "
82
+ f"'{self.name}' (of type {self.__class__.__name__}). "
83
+ f"Either the `{self.__class__.__name__}.call()` method "
84
+ f"is incorrect, or you need to implement the "
85
+ f"`{self.__class__.__name__}.compute_output_spec() / "
86
+ "compute_output_shape()` method. "
87
+ f"Error encountered:\n\n{e}"
88
+ )
89
+ raise new_e.with_traceback(e.__traceback__) from None
90
+
91
+ def __new__(cls, *args, **kwargs):
92
+ """We override __new__ to saving serializable constructor arguments.
93
+
94
+ These arguments are used to auto-generate an object serialization
95
+ config, which enables user-created subclasses to be serializable
96
+ out of the box in most cases without forcing the user
97
+ to manually implement `get_config()`.
98
+ """
99
+ instance = super(Operation, cls).__new__(cls)
100
+
101
+ # Generate a config to be returned by default by `get_config()`.
102
+ arg_names = inspect.getfullargspec(cls.__init__).args
103
+ kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
104
+
105
+ # Explicitly serialize `dtype` to support auto_config
106
+ dtype = kwargs.get("dtype", None)
107
+ if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy):
108
+ # For backward compatibility, we use a str (`name`) for
109
+ # `DTypePolicy`
110
+ if dtype.quantization_mode is None:
111
+ kwargs["dtype"] = dtype.name
112
+ # Otherwise, use `dtype_policies.serialize`
113
+ else:
114
+ kwargs["dtype"] = dtype_policies.serialize(dtype)
115
+
116
+ # For safety, we only rely on auto-configs for a small set of
117
+ # serializable types.
118
+ supported_types = (str, int, float, bool, type(None))
119
+ try:
120
+ flat_arg_values = tree.flatten(kwargs)
121
+ auto_config = True
122
+ for value in flat_arg_values:
123
+ if not isinstance(value, supported_types):
124
+ auto_config = False
125
+ break
126
+ except TypeError:
127
+ auto_config = False
128
+ try:
129
+ instance._lock = False
130
+ if auto_config:
131
+ from keras.src.saving import serialization_lib
132
+
133
+ instance._auto_config = serialization_lib.SerializableDict(
134
+ **kwargs
135
+ )
136
+ else:
137
+ instance._auto_config = None
138
+ instance._lock = True
139
+ except RecursionError:
140
+ # Setting an instance attribute in __new__ has the potential
141
+ # to trigger an infinite recursion if a subclass overrides
142
+ # setattr in an unsafe way.
143
+ pass
144
+ return instance
145
+
146
+ @python_utils.default
147
+ def get_config(self):
148
+ """Returns the config of the object.
149
+
150
+ An object config is a Python dictionary (serializable)
151
+ containing the information needed to re-instantiate it.
152
+ """
153
+ config = {
154
+ "name": self.name,
155
+ }
156
+
157
+ if not python_utils.is_default(self.get_config):
158
+ # In this case the subclass implements get_config()
159
+ return config
160
+
161
+ # In this case the subclass doesn't implement get_config():
162
+ # Let's see if we can autogenerate it.
163
+ if getattr(self, "_auto_config", None) is not None:
164
+ xtra_args = set(config.keys())
165
+ config.update(self._auto_config.config)
166
+ # Remove args non explicitly supported
167
+ argspec = inspect.getfullargspec(self.__init__)
168
+ if argspec.varkw != "kwargs":
169
+ for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
170
+ config.pop(key, None)
171
+ return config
172
+ else:
173
+ raise NotImplementedError(
174
+ textwrap.dedent(
175
+ f"""
176
+ Object {self.__class__.__name__} was created by passing
177
+ non-serializable argument values in `__init__()`,
178
+ and therefore the object must override `get_config()` in
179
+ order to be serializable. Please implement `get_config()`.
180
+
181
+ Example:
182
+
183
+ class CustomLayer(keras.layers.Layer):
184
+ def __init__(self, arg1, arg2, **kwargs):
185
+ super().__init__(**kwargs)
186
+ self.arg1 = arg1
187
+ self.arg2 = arg2
188
+
189
+ def get_config(self):
190
+ config = super().get_config()
191
+ config.update({{
192
+ "arg1": self.arg1,
193
+ "arg2": self.arg2,
194
+ }})
195
+ return config"""
196
+ )
197
+ )
198
+
199
+ @classmethod
200
+ def from_config(cls, config):
201
+ """Creates an operation from its config.
202
+
203
+ This method is the reverse of `get_config`, capable of instantiating the
204
+ same operation from the config dictionary.
205
+
206
+ Note: If you override this method, you might receive a serialized dtype
207
+ config, which is a `dict`. You can deserialize it as follows:
208
+
209
+ ```python
210
+ if "dtype" in config and isinstance(config["dtype"], dict):
211
+ policy = dtype_policies.deserialize(config["dtype"])
212
+ ```
213
+
214
+ Args:
215
+ config: A Python dictionary, typically the output of `get_config`.
216
+
217
+ Returns:
218
+ An operation instance.
219
+ """
220
+ # Explicitly deserialize dtype config if needed. This enables users to
221
+ # directly interact with the instance of `DTypePolicy`.
222
+ if "dtype" in config and isinstance(config["dtype"], dict):
223
+ config = config.copy()
224
+ policy = dtype_policies.deserialize(config["dtype"])
225
+ if (
226
+ not isinstance(policy, dtype_policies.DTypePolicyMap)
227
+ and policy.quantization_mode is None
228
+ ):
229
+ # For backward compatibility, we use a str (`name`) for
230
+ # `DTypePolicy`
231
+ policy = policy.name
232
+ config["dtype"] = policy
233
+ try:
234
+ return cls(**config)
235
+ except Exception as e:
236
+ raise TypeError(
237
+ f"Error when deserializing class '{cls.__name__}' using "
238
+ f"config={config}.\n\nException encountered: {e}"
239
+ )
240
+
241
+ def __repr__(self):
242
+ return f"<Operation name={self.name}>"
243
+
244
+ @property
245
+ def input(self):
246
+ """Retrieves the input tensor(s) of a symbolic operation.
247
+
248
+ Only returns the tensor(s) corresponding to the *first time*
249
+ the operation was called.
250
+
251
+ Returns:
252
+ Input tensor or list of input tensors.
253
+ """
254
+ return self._get_node_attribute_at_index(0, "input_tensors", "input")
255
+
256
+ @property
257
+ def output(self):
258
+ """Retrieves the output tensor(s) of a layer.
259
+
260
+ Only returns the tensor(s) corresponding to the *first time*
261
+ the operation was called.
262
+
263
+ Returns:
264
+ Output tensor or list of output tensors.
265
+ """
266
+ return self._get_node_attribute_at_index(0, "output_tensors", "output")
267
+
268
+ def _get_node_attribute_at_index(self, node_index, attr, attr_name):
269
+ """Private utility to retrieves an attribute (e.g. inputs) from a node.
270
+
271
+ This is used to implement the properties:
272
+ - output
273
+ - input
274
+
275
+ Args:
276
+ node_index: Integer index of the node from which
277
+ to retrieve the attribute.
278
+ attr: Exact node attribute name.
279
+ attr_name: Human-readable attribute name, for error messages.
280
+
281
+ Returns:
282
+ The operation's attribute `attr` at the node of index `node_index`.
283
+ """
284
+ if not self._inbound_nodes:
285
+ raise AttributeError(
286
+ f"The layer {self.name} has never been called "
287
+ f"and thus has no defined {attr_name}."
288
+ )
289
+ if not len(self._inbound_nodes) > node_index:
290
+ raise ValueError(
291
+ f"Asked to get {attr_name} at node "
292
+ f"{node_index}, but the operation has only "
293
+ f"{len(self._inbound_nodes)} inbound nodes."
294
+ )
295
+ values = getattr(self._inbound_nodes[node_index], attr)
296
+ if isinstance(values, list) and len(values) == 1:
297
+ return values[0]
298
+ else:
299
+ return values
300
+
301
+ # Hooks for backend layer classes
302
+ def _post_build(self):
303
+ """Can be overridden for per backend post build actions."""
304
+ pass
305
+
306
+ def _setattr_hook(self, name, value):
307
+ """Can be overridden for per backend post build actions."""
308
+ return name, value
309
+
310
+ def _post_track_variable(self, variable):
311
+ """Can be overridden for per backend post track actions."""
312
+ pass
313
+
314
+ def _post_untrack_variable(self, variable):
315
+ """Can be overridden for per backend post untrack actions."""
316
+ pass
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/operation_utils.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+
5
+ from keras.src import tree
6
+ from keras.src.api_export import keras_export
7
+ from keras.src.backend.common.backend_utils import canonicalize_axis
8
+ from keras.src.backend.common.backend_utils import to_tuple_or_list
9
+
10
+
11
def broadcast_shapes(shape1, shape2):
    """Broadcast two shapes into a single unified shape.

    `None` entries denote unknown dimensions; an unknown dimension
    broadcasts with anything except a conflicting concrete size.

    Args:
        shape1: A tuple or list of integers (or `None`).
        shape2: A tuple or list of integers (or `None`).

    Returns:
        The broadcasted shape, as a list.

    Raises:
        ValueError: If two concrete dimensions conflict.

    Example:
    >>> broadcast_shapes((5, 3), (1, 3))
    [5, 3]
    """
    original1 = list(shape1)
    original2 = list(shape2)
    # Left-pad the shorter shape with 1s, as in NumPy broadcasting.
    rank = max(len(original1), len(original2))
    padded1 = [1] * (rank - len(original1)) + original1
    padded2 = [1] * (rank - len(original2)) + original2

    result = []
    for dim1, dim2 in zip(padded1, padded2):
        if dim1 == 1:
            result.append(dim2)
        elif dim1 is None:
            result.append(None if dim2 == 1 else dim2)
        elif dim2 == 1 or dim2 is None or dim2 == dim1:
            result.append(dim1)
        else:
            raise ValueError(
                "Cannot broadcast shape, the failure dim has value "
                f"{dim1}, which cannot be broadcasted to {dim2}. "
                f"Input shapes are: {original1} and {original2}."
            )
    return result
53
+
54
+
55
def compute_expand_dims_output_shape(input_shape, axis):
    """Compute the output shape for the `expand_dims` operation.

    Args:
        input_shape: Input shape.
        axis: int or sequence of ints for the axis to expand; `None`
            appends a new trailing axis.

    Returns:
        Tuple of ints: The output shape after the `expand_dims` operation.
    """
    source_dims = list(input_shape)
    if axis is None:
        axis = len(source_dims)
    axes = to_tuple_or_list(axis)
    result_rank = len(axes) + len(source_dims)
    # Normalize negative axes against the *output* rank.
    inserted = [canonicalize_axis(a, result_rank) for a in axes]
    remaining = iter(source_dims)
    output = []
    for position in range(result_rank):
        if position in inserted:
            output.append(1)
        else:
            output.append(next(remaining))
    return tuple(output)
76
+
77
+
78
def compute_pooling_output_shape(
    input_shape,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    """Computes the output shape of pooling operations.

    Args:
        input_shape: Input shape. Must be a tuple of integers.
        pool_size: Size of the pooling operation. Must be a tuple of integers.
        strides: Stride of the pooling operation. Must be a tuple of integers.
            Defaults to `pool_size`.
        padding: Padding method. Available methods are `"valid"` or `"same"`.
            Defaults to `"valid"`.
        data_format: String, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, channels, height, weight)`. Defaults to `"channels_last"`.

    Returns:
        Tuple of ints: The output shape of the pooling operation.

    Examples:

    # Basic usage with square pooling on a single image
    >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2))
    (1, 2, 2, 1)

    # Strided pooling on a single image with strides different from pool_size
    >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2), strides=(1, 1))
    (1, 3, 3, 1)

    # Pooling on a batch of images
    >>> compute_pooling_output_shape((32, 4, 4, 3), (2, 2))
    (32, 2, 2, 3)
    """
    # Fix: `strides` now defaults to `None`, making it optional as the
    # docstring ("Defaults to `pool_size`") and its examples already
    # promised; previously omitting it raised a TypeError.
    strides = pool_size if strides is None else strides
    input_shape_origin = list(input_shape)
    input_shape = np.array(input_shape)
    if data_format == "channels_last":
        spatial_shape = input_shape[1:-1]
    else:
        spatial_shape = input_shape[2:]
    none_dims = []
    for i in range(len(spatial_shape)):
        if spatial_shape[i] is None:
            # Set `None` shape to a manual value so that we can run numpy
            # computation on `spatial_shape`.
            spatial_shape[i] = -1
            none_dims.append(i)
    pool_size = np.array(pool_size)
    if padding == "valid":
        output_spatial_shape = (
            np.floor((spatial_shape - pool_size) / strides) + 1
        )
        for i in range(len(output_spatial_shape)):
            if i not in none_dims and output_spatial_shape[i] < 0:
                raise ValueError(
                    "Computed output size would be negative. Received: "
                    f"`inputs.shape={input_shape}` and `pool_size={pool_size}`."
                )
    elif padding == "same":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        raise ValueError(
            "Argument `padding` must be either 'valid' or 'same'. Received: "
            f"padding={padding}"
        )
    output_spatial_shape = [int(i) for i in output_spatial_shape]
    for i in none_dims:
        # Restore the unknown dimensions that were temporarily set to -1.
        output_spatial_shape[i] = None
    output_spatial_shape = tuple(output_spatial_shape)
    if data_format == "channels_last":
        output_shape = (
            (input_shape_origin[0],)
            + output_spatial_shape
            + (input_shape_origin[-1],)
        )
    else:
        output_shape = (
            input_shape_origin[0],
            input_shape_origin[1],
        ) + output_spatial_shape
    return output_shape
165
+
166
+
167
def compute_conv_output_shape(
    input_shape,
    filters,
    kernel_size,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    """Compute the output shape of conv ops.

    Unknown (`None`) spatial dimensions stay unknown in the result.
    """
    # Split input into spatial dims, and build the full kernel shape
    # `kernel_size + (in_channels, filters)`.
    if data_format == "channels_last":
        spatial = input_shape[1:-1]
        kernel_shape = kernel_size + (input_shape[-1], filters)
    else:
        spatial = input_shape[2:]
        kernel_shape = kernel_size + (input_shape[1], filters)
    if len(kernel_shape) != len(input_shape):
        raise ValueError(
            "Kernel shape must have the same length as input, but received "
            f"kernel of shape {kernel_shape} and "
            f"input of shape {input_shape}."
        )
    num_spatial = len(spatial)
    # Scalar strides/dilation apply uniformly to every spatial dim.
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial
    if isinstance(strides, int):
        strides = (strides,) * num_spatial
    if len(dilation_rate) != num_spatial:
        raise ValueError(
            "Dilation must be None, scalar or tuple/list of length of "
            "inputs' spatial shape, but received "
            f"`dilation_rate={dilation_rate}` and "
            f"input of shape {input_shape}."
        )
    spatial = np.array(spatial)
    # Temporarily replace unknown dims with -1 so numpy arithmetic works;
    # they are restored to `None` at the end.
    unknown_axes = []
    for i in range(len(spatial)):
        if spatial[i] is None:
            spatial[i] = -1
            unknown_axes.append(i)

    kernel_spatial = np.array(kernel_shape[:-2])
    dilation_rate = np.array(dilation_rate)
    if padding == "valid":
        out_spatial = (
            np.floor(
                (spatial - dilation_rate * (kernel_spatial - 1) - 1) / strides
            )
            + 1
        )
        for i in range(len(out_spatial)):
            if i not in unknown_axes and out_spatial[i] < 0:
                raise ValueError(
                    "Computed output size would be negative. Received "
                    f"`inputs shape={input_shape}`, "
                    f"`kernel shape={kernel_shape}`, "
                    f"`dilation_rate={dilation_rate}`."
                )
    elif padding == "same" or padding == "causal":
        out_spatial = np.floor((spatial - 1) / strides) + 1
    else:
        raise ValueError(
            "`padding` must be either `'valid'` or `'same'`. Received "
            f"{padding}."
        )
    out_spatial = [int(d) for d in out_spatial]
    for i in unknown_axes:
        out_spatial[i] = None
    out_spatial = tuple(out_spatial)
    if data_format == "channels_last":
        return (input_shape[0],) + out_spatial + (kernel_shape[-1],)
    return (input_shape[0], kernel_shape[-1]) + out_spatial
245
+
246
+
247
def compute_matmul_output_shape(shape1, shape2):
    """Compute the output shape of a `matmul` operation.

    Follows NumPy `matmul` semantics: a 1-D left operand is promoted to a
    row vector and a 1-D right operand to a column vector, and the
    corresponding prepended/appended dimension is removed from the result.

    Args:
        shape1: Shape of the left operand.
        shape2: Shape of the right operand.

    Returns:
        Tuple of ints: The output shape for the `matmul` operation.
    """
    # Record the original ranks *before* rebinding the shapes below.
    # Previously these checks were re-evaluated on the promoted rank-2
    # shapes, so the squeeze branches at the end were dead code and 1-D
    # operands produced `(1, n)` / `(m, 1)` instead of the squeezed shape.
    shape1_is_1d = len(shape1) == 1
    shape2_is_1d = len(shape2) == 1
    if shape1_is_1d:
        shape1 = (1, shape1[0])
    if shape2_is_1d:
        shape2 = (shape2[0], 1)
    if (
        shape1[-1] is not None
        and shape2[-2] is not None
        and shape1[-1] != shape2[-2]
    ):
        raise ValueError(
            "Inner dimensions (`x1.shape[-1]` and `x2.shape[-2]`) must be "
            f"equal, but received `x1.shape={shape1}` and "
            f"`x2.shape={shape2}`."
        )

    leading_shape = broadcast_shapes(shape1[:-2], shape2[:-2])
    last_2_dims_shape = [shape1[-2], shape2[-1]]
    output_shape = leading_shape + last_2_dims_shape
    if shape1_is_1d:
        del output_shape[-2]
    if shape2_is_1d:
        del output_shape[-1]
    return tuple(output_shape)
280
+
281
+
282
def compute_reshape_output_shape(input_shape, newshape, newshape_arg_name):
    """Resolve `-1` in `newshape` to an actual dimension or `None`.

    This utility does not special case the 0th dimension (batch size):
    when `input_shape` contains `None`, every `-1` becomes `None`.

    Args:
        input_shape: Shape of the tensor being reshaped.
        newshape: Target shape, possibly containing a single `-1`.
        newshape_arg_name: Name of the argument, used in error messages.

    Returns:
        The fully resolved shape.
    """
    num_unknown = newshape.count(-1)
    if num_unknown > 1:
        raise ValueError(
            "There must be at most one unknown dimension (-1) in "
            f"{newshape_arg_name}. Received: {newshape_arg_name}={newshape}."
        )

    # With an unknown input dimension the -1 cannot be inferred.
    if None in input_shape:
        return tuple(None if dim == -1 else dim for dim in newshape)

    total_size = math.prod(input_shape)
    if num_unknown == 0:
        # Fully specified target: just validate the element count.
        if total_size != math.prod(newshape):
            raise ValueError(
                "The total size of the tensor must be unchanged. Received: "
                f"input_shape={input_shape}, {newshape_arg_name}={newshape}"
            )
        return newshape

    # Exactly one -1: infer it from the remaining dimensions.
    known_size = 1
    minus_one_at = None
    for position, dim in enumerate(newshape):
        if dim == -1:
            minus_one_at = position
        else:
            known_size *= dim

    if known_size == 0 or total_size % known_size != 0:
        raise ValueError(
            "The total size of the tensor must be unchanged, however, the "
            "input size cannot by divided by the specified dimensions in "
            f"{newshape_arg_name}. Received: input_shape={input_shape}, "
            f"{newshape_arg_name}={newshape}"
        )

    resolved = list(newshape)
    resolved[minus_one_at] = total_size // known_size
    return tuple(resolved)
328
+
329
+
330
def compute_transpose_output_shape(input_shape, axes):
    """Compute the output shape for the `transpose` operation.

    Args:
        input_shape: Input shape.
        axes: Permutation of the dimensions, or `None` to fully reverse
            the dimension order.

    Returns:
        Tuple of ints: The output shape after the `transpose` operation.
    """
    dims = list(input_shape)
    if axes is None:
        # Default transpose: reverse all dimensions.
        dims.reverse()
        return tuple(dims)

    if len(axes) != len(dims):
        raise ValueError(
            "axis must be a list of the same length as the input shape, "
            f"expected {len(dims)}, but received {len(axes)}."
        )
    return tuple(dims[a] for a in axes)
350
+
351
+
352
def compute_take_along_axis_output_shape(input_shape, indices_shape, axis):
    """Compute the output shape of a `take_along_axis` operation.

    Args:
        input_shape: Shape of the source array.
        indices_shape: Shape of the indices array; must have the same rank
            as `input_shape` (the flattened rank of 1 when `axis is None`).
        axis: Axis along which values are taken, or `None` to index into
            the flattened array.

    Returns:
        The broadcasted output shape, as a list.
    """
    input_shape = list(input_shape)
    indices_shape = list(indices_shape)
    if axis is None:
        input_shape = (
            [None] if None in input_shape else [int(np.prod(input_shape))]
        )
        # `axis=None` indexes the flattened array, which after the line
        # above has exactly one axis — axis 0. Previously `axis` stayed
        # `None` and `input_shape[axis]` below raised a TypeError on
        # every `axis=None` call.
        axis = 0

    if len(input_shape) != len(indices_shape):
        raise ValueError(
            "`x` and `indices` must have the same number of dimensions, "
            f"but receive shape {input_shape} and {indices_shape}."
        )

    # Along `axis` the output takes the indices' extent; all other axes
    # broadcast between input and indices.
    input_shape[axis] = indices_shape[axis]
    output_shape = broadcast_shapes(input_shape, indices_shape)
    return output_shape
369
+
370
+
371
def reduce_shape(shape, axis=None, keepdims=False):
    """Compute the shape left after reducing `shape` over `axis`.

    Args:
        shape: The input shape.
        axis: Iterable of axes to reduce over, or `None` to reduce over
            all axes.
        keepdims: If `True`, reduced axes are kept with size 1 instead
            of being removed.

    Returns:
        Tuple of ints: the reduced shape.
    """
    dims = list(shape)
    if axis is None:
        # Full reduction: all-ones shape, or a scalar shape.
        return tuple(1 for _ in dims) if keepdims else ()

    if keepdims:
        for ax in axis:
            dims[ax] = 1
        return tuple(dims)

    # Delete from the back so earlier indices stay valid.
    for ax in sorted(axis, reverse=True):
        del dims[ax]
    return tuple(dims)
387
+
388
+
389
+ @keras_export("keras.utils.get_source_inputs")
390
+ def get_source_inputs(tensor):
391
+ """Returns the list of input tensors necessary to compute `tensor`.
392
+
393
+ Output will always be a list of tensors
394
+ (potentially with 1 element).
395
+
396
+ Args:
397
+ tensor: The tensor to start from.
398
+
399
+ Returns:
400
+ List of input tensors.
401
+ """
402
+ if not hasattr(tensor, "_keras_history"):
403
+ return tensor
404
+
405
+ operation, node_index, _ = tensor._keras_history
406
+ if not operation or not operation._inbound_nodes:
407
+ return [tensor]
408
+ else:
409
+ node = operation._inbound_nodes[node_index]
410
+ if node.is_input:
411
+ # Reached input node, stop recursion.
412
+ return tree.flatten(node.output_tensors)
413
+ else:
414
+ source_tensors = []
415
+ for tensor in node.input_tensors:
416
+ previous_sources = get_source_inputs(tensor)
417
+ # Avoid input redundancy.
418
+ for x in previous_sources:
419
+ if all(x is not t for t in source_tensors):
420
+ source_tensors.append(x)
421
+ return source_tensors
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/symbolic_arguments.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import tree
2
+ from keras.src.backend import KerasTensor
3
+
4
+
5
class SymbolicArguments:
    """Holds the symbolic `(args, kwargs)` of an operation call."""

    def __init__(self, *args, **kwargs):
        # Shallow-copy the nested structures so later mutation of the
        # caller's containers does not affect us.
        self.args = tree.map_structure(lambda x: x, args)
        self.kwargs = tree.map_structure(lambda x: x, kwargs)
        self._flat_arguments = tree.flatten((self.args, self.kwargs))

        # Fast-path marker: exactly one positional KerasTensor and no
        # kwargs lets `fill_in` skip the expensive `tree` traversal.
        if (
            not self.kwargs
            and len(self.args) == 1
            and isinstance(self.args[0], KerasTensor)
        ):
            self._single_positional_tensor = self.args[0]
        else:
            self._single_positional_tensor = None

        self.keras_tensors = [
            leaf
            for leaf in self._flat_arguments
            if isinstance(leaf, KerasTensor)
        ]

    def convert(self, conversion_fn):
        """Applies `conversion_fn` to every leaf of `(args, kwargs)`."""
        converted_args = tree.map_structure(conversion_fn, self.args)
        converted_kwargs = tree.map_structure(conversion_fn, self.kwargs)
        return converted_args, converted_kwargs

    def fill_in(self, tensor_dict):
        """Maps KerasTensors to computed values using `tensor_dict`.

        `tensor_dict` maps `KerasTensor` instances to their current values.
        """
        single = self._single_positional_tensor
        if single is not None:
            # Performance optimization for most common case.
            # Approx. 70x faster.
            return (tensor_dict[id(single)],), {}

        def switch_fn(x):
            if not isinstance(x, KerasTensor):
                return x
            return tensor_dict.get(id(x), None)

        return self.convert(switch_fn)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__init__.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src.api_export import keras_export
2
+ from keras.src.optimizers.adadelta import Adadelta
3
+ from keras.src.optimizers.adafactor import Adafactor
4
+ from keras.src.optimizers.adagrad import Adagrad
5
+ from keras.src.optimizers.adam import Adam
6
+ from keras.src.optimizers.adamax import Adamax
7
+ from keras.src.optimizers.adamw import AdamW
8
+ from keras.src.optimizers.ftrl import Ftrl
9
+ from keras.src.optimizers.lion import Lion
10
+ from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer
11
+ from keras.src.optimizers.nadam import Nadam
12
+ from keras.src.optimizers.optimizer import Optimizer
13
+ from keras.src.optimizers.rmsprop import RMSprop
14
+ from keras.src.optimizers.sgd import SGD
15
+ from keras.src.saving import serialization_lib
16
+
17
# All built-in optimizer classes exposed by this module.
ALL_OBJECTS = {
    Optimizer,
    Adam,
    SGD,
    RMSprop,
    Adadelta,
    AdamW,
    Adagrad,
    Adamax,
    Adafactor,
    Nadam,
    Ftrl,
    Lion,
    LossScaleOptimizer,
}
# Lowercased class name -> class; used by `deserialize()` / `get()` for
# case-insensitive lookup of built-in optimizers.
ALL_OBJECTS_DICT = {cls.__name__.lower(): cls for cls in ALL_OBJECTS}
33
+
34
+
35
+ @keras_export("keras.optimizers.serialize")
36
+ def serialize(optimizer):
37
+ """Returns the optimizer configuration as a Python dict.
38
+
39
+ Args:
40
+ optimizer: An `Optimizer` instance to serialize.
41
+
42
+ Returns:
43
+ Python dict which contains the configuration of the optimizer.
44
+ """
45
+ return serialization_lib.serialize_keras_object(optimizer)
46
+
47
+
48
+ @keras_export("keras.optimizers.deserialize")
49
+ def deserialize(config, custom_objects=None):
50
+ """Returns a Keras optimizer object via its configuration.
51
+
52
+ Args:
53
+ config: Optimizer configuration dictionary.
54
+ custom_objects: Optional dictionary mapping names (strings) to custom
55
+ objects (classes and functions) to be considered during
56
+ deserialization.
57
+
58
+ Returns:
59
+ A Keras Optimizer instance.
60
+ """
61
+ # Make deserialization case-insensitive for built-in optimizers.
62
+ if config["class_name"].lower() in ALL_OBJECTS_DICT:
63
+ config["class_name"] = config["class_name"].lower()
64
+
65
+ return serialization_lib.deserialize_keras_object(
66
+ config,
67
+ module_objects=ALL_OBJECTS_DICT,
68
+ custom_objects=custom_objects,
69
+ )
70
+
71
+
72
+ @keras_export("keras.optimizers.get")
73
+ def get(identifier):
74
+ """Retrieves a Keras Optimizer instance.
75
+
76
+ Args:
77
+ identifier: Optimizer identifier, one of:
78
+ - String: name of an optimizer
79
+ - Dictionary: configuration dictionary.
80
+ - Keras Optimizer instance (it will be returned unchanged).
81
+
82
+ Returns:
83
+ A Keras Optimizer instance.
84
+ """
85
+ if identifier is None:
86
+ return None
87
+ elif isinstance(identifier, dict):
88
+ obj = deserialize(identifier)
89
+ elif isinstance(identifier, str):
90
+ config = {"class_name": identifier, "config": {}}
91
+ obj = deserialize(config)
92
+ else:
93
+ obj = identifier
94
+
95
+ if isinstance(obj, Optimizer):
96
+ return obj
97
+ raise ValueError(f"Could not interpret optimizer identifier: {identifier}")
98
+
99
+
100
+ # We will add this temporarily so that tensorflow packages that depend on
101
+ # estimators will continue to import (there are a large number). Note that
102
+ # Keras 3 will not work with the estimators API.
103
@keras_export(
    [
        "keras.optimizers.legacy.Adagrad",
        "keras.optimizers.legacy.Adam",
        "keras.optimizers.legacy.Ftrl",
        "keras.optimizers.legacy.RMSprop",
        "keras.optimizers.legacy.SGD",
        "keras.optimizers.legacy.Optimizer",
    ]
)
class LegacyOptimizerWarning:
    """Placeholder exported under all `keras.optimizers.legacy.*` names.

    Importing these names succeeds (so code that merely references them
    keeps importing), but instantiating any of them raises an
    `ImportError` with migration instructions.
    """

    def __init__(self, *args, **kwargs):
        # Fail loudly at construction time — there is no legacy
        # implementation to fall back to in Keras 3.
        raise ImportError(
            "`keras.optimizers.legacy` is not supported in Keras 3. When using "
            "`tf.keras`, to continue using a `tf.keras.optimizers.legacy` "
            "optimizer, you can install the `tf_keras` package (Keras 2) and "
            "set the environment variable `TF_USE_LEGACY_KERAS=True` to "
            "configure TensorFlow to use `tf_keras` when accessing `tf.keras`."
        )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.92 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adadelta.cpython-310.pyc ADDED
Binary file (4.21 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adafactor.cpython-310.pyc ADDED
Binary file (5.65 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adagrad.cpython-310.pyc ADDED
Binary file (3.54 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adam.cpython-310.pyc ADDED
Binary file (4.9 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adamax.cpython-310.pyc ADDED
Binary file (4.57 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adamw.cpython-310.pyc ADDED
Binary file (3.65 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/base_optimizer.cpython-310.pyc ADDED
Binary file (35.5 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/ftrl.cpython-310.pyc ADDED
Binary file (6.72 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/lamb.cpython-310.pyc ADDED
Binary file (4.42 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/lion.cpython-310.pyc ADDED
Binary file (4.46 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/loss_scale_optimizer.cpython-310.pyc ADDED
Binary file (11.4 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/nadam.cpython-310.pyc ADDED
Binary file (4.94 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/optimizer.cpython-310.pyc ADDED
Binary file (1.06 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/rmsprop.cpython-310.pyc ADDED
Binary file (4.81 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/sgd.cpython-310.pyc ADDED
Binary file (3.84 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adadelta.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import ops
2
+ from keras.src.api_export import keras_export
3
+ from keras.src.optimizers import optimizer
4
+
5
+
6
+ @keras_export(["keras.optimizers.Adadelta"])
7
+ class Adadelta(optimizer.Optimizer):
8
+ """Optimizer that implements the Adadelta algorithm.
9
+
10
+ Adadelta optimization is a stochastic gradient descent method that is based
11
+ on adaptive learning rate per dimension to address two drawbacks:
12
+
13
+ - The continual decay of learning rates throughout training.
14
+ - The need for a manually selected global learning rate.
15
+
16
+ Adadelta is a more robust extension of Adagrad that adapts learning rates
17
+ based on a moving window of gradient updates, instead of accumulating all
18
+ past gradients. This way, Adadelta continues learning even when many updates
19
+ have been done. Compared to Adagrad, in the original version of Adadelta you
20
+ don't have to set an initial learning rate. In this version, the initial
21
+ learning rate can be set, as in most other Keras optimizers.
22
+
23
+ Args:
24
+ learning_rate: A float, a
25
+ `keras.optimizers.schedules.LearningRateSchedule` instance, or
26
+ a callable that takes no arguments and returns the actual value to
27
+ use. The learning rate. Defaults to `0.001`. Note that `Adadelta`
28
+ tends to benefit from higher initial learning rate values compared
29
+ to other optimizers. To match the exact form in the original paper,
30
+ use 1.0.
31
+ rho: A floating point value. The decay rate. Defaults to `0.95`.
32
+ epsilon: Small floating point value for maintaining numerical stability.
33
+ {{base_optimizer_keyword_args}}
34
+
35
+ Reference:
36
+
37
+ - [Zeiler, 2012](http://arxiv.org/abs/1212.5701)
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ learning_rate=0.001,
43
+ rho=0.95,
44
+ epsilon=1e-7,
45
+ weight_decay=None,
46
+ clipnorm=None,
47
+ clipvalue=None,
48
+ global_clipnorm=None,
49
+ use_ema=False,
50
+ ema_momentum=0.99,
51
+ ema_overwrite_frequency=None,
52
+ loss_scale_factor=None,
53
+ gradient_accumulation_steps=None,
54
+ name="adadelta",
55
+ **kwargs,
56
+ ):
57
+ super().__init__(
58
+ learning_rate=learning_rate,
59
+ weight_decay=weight_decay,
60
+ clipnorm=clipnorm,
61
+ clipvalue=clipvalue,
62
+ global_clipnorm=global_clipnorm,
63
+ use_ema=use_ema,
64
+ ema_momentum=ema_momentum,
65
+ ema_overwrite_frequency=ema_overwrite_frequency,
66
+ loss_scale_factor=loss_scale_factor,
67
+ gradient_accumulation_steps=gradient_accumulation_steps,
68
+ name=name,
69
+ **kwargs,
70
+ )
71
+ self.rho = rho
72
+ self.epsilon = epsilon
73
+
74
+ def build(self, var_list):
75
+ if self.built:
76
+ return
77
+ super().build(var_list)
78
+ self._accumulated_grads = []
79
+ self._accumulated_delta_vars = []
80
+ for var in var_list:
81
+ self._accumulated_grads.append(
82
+ self.add_variable_from_reference(var, "accumulated_grad")
83
+ )
84
+ self._accumulated_delta_vars.append(
85
+ self.add_variable_from_reference(var, "accumulated_delta_var")
86
+ )
87
+
88
+ def update_step(self, grad, variable, learning_rate):
89
+ """Update step given gradient and the associated model variable."""
90
+ lr = ops.cast(learning_rate, variable.dtype)
91
+ grad = ops.cast(grad, variable.dtype)
92
+
93
+ rho = self.rho
94
+ accumulated_grad = self._accumulated_grads[
95
+ self._get_variable_index(variable)
96
+ ]
97
+ accumulated_delta_var = self._accumulated_delta_vars[
98
+ self._get_variable_index(variable)
99
+ ]
100
+
101
+ def rms(x):
102
+ return ops.sqrt(ops.add(x, self.epsilon))
103
+
104
+ self.assign(
105
+ accumulated_grad,
106
+ ops.add(
107
+ rho * accumulated_grad, ops.multiply(1 - rho, ops.square(grad))
108
+ ),
109
+ )
110
+ delta_var = ops.negative(
111
+ ops.divide(
112
+ ops.multiply(rms(accumulated_delta_var), grad),
113
+ rms(accumulated_grad),
114
+ )
115
+ )
116
+ self.assign(
117
+ accumulated_delta_var,
118
+ ops.add(
119
+ ops.multiply(rho, accumulated_delta_var),
120
+ ops.multiply(1 - rho, ops.square(delta_var)),
121
+ ),
122
+ )
123
+ self.assign_add(variable, ops.multiply(lr, delta_var))
124
+
125
+ def get_config(self):
126
+ config = super().get_config()
127
+
128
+ config.update(
129
+ {
130
+ "rho": self.rho,
131
+ "epsilon": self.epsilon,
132
+ }
133
+ )
134
+ return config
135
+
136
+
137
+ Adadelta.__doc__ = Adadelta.__doc__.replace(
138
+ "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
139
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adafactor.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import backend
2
+ from keras.src import ops
3
+ from keras.src.api_export import keras_export
4
+ from keras.src.optimizers import optimizer
5
+
6
+
7
+ @keras_export(["keras.optimizers.Adafactor"])
8
+ class Adafactor(optimizer.Optimizer):
9
+ """Optimizer that implements the Adafactor algorithm.
10
+
11
+ Adafactor is commonly used in NLP tasks, and has the advantage
12
+ of taking less memory because it only saves partial information of previous
13
+ gradients.
14
+
15
+ The default argument setup is based on the original paper (see reference).
16
+ When gradients are of dimension > 2, Adafactor optimizer will delete the
17
+ last 2 dimensions separately in its accumulator variables.
18
+
19
+ Args:
20
+ learning_rate: A float, a
21
+ `keras.optimizers.schedules.LearningRateSchedule` instance, or
22
+ a callable that takes no arguments and returns the actual value to
23
+ use. The learning rate. Defaults to `0.001`.
24
+ beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`.
25
+ epsilon_1: float, defaults to 1e-30. A small offset to keep denominator
26
+ away from 0.
27
+ epsilon_2: float, defaults to 1e-3. A small offset to avoid learning
28
+ rate becoming too small by time.
29
+ clip_threshold: float, defaults to 1.0. Clipping threshold. This is a
30
+ part of Adafactor algorithm, independent from `clipnorm`,
31
+ `clipvalue`, and `global_clipnorm`.
32
+ relative_step: bool, defaults to `True`. If `learning_rate` is a
33
+ constant and `relative_step=True`, learning rate will be adjusted
34
+ based on current iterations. This is a default learning rate decay
35
+ in Adafactor.
36
+ {{base_optimizer_keyword_args}}
37
+
38
+ Reference:
39
+
40
+ - [Shazeer, Noam et al., 2018](https://arxiv.org/abs/1804.04235).
41
+
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ learning_rate=0.001,
47
+ beta_2_decay=-0.8,
48
+ epsilon_1=1e-30,
49
+ epsilon_2=1e-3,
50
+ clip_threshold=1.0,
51
+ relative_step=True,
52
+ weight_decay=None,
53
+ clipnorm=None,
54
+ clipvalue=None,
55
+ global_clipnorm=None,
56
+ use_ema=False,
57
+ ema_momentum=0.99,
58
+ ema_overwrite_frequency=None,
59
+ loss_scale_factor=None,
60
+ gradient_accumulation_steps=None,
61
+ name="adafactor",
62
+ **kwargs,
63
+ ):
64
+ super().__init__(
65
+ learning_rate=learning_rate,
66
+ name=name,
67
+ weight_decay=weight_decay,
68
+ clipnorm=clipnorm,
69
+ clipvalue=clipvalue,
70
+ global_clipnorm=global_clipnorm,
71
+ use_ema=use_ema,
72
+ ema_momentum=ema_momentum,
73
+ ema_overwrite_frequency=ema_overwrite_frequency,
74
+ loss_scale_factor=loss_scale_factor,
75
+ gradient_accumulation_steps=gradient_accumulation_steps,
76
+ **kwargs,
77
+ )
78
+ self.beta_2_decay = beta_2_decay
79
+ self.epsilon_1 = epsilon_1
80
+ self.epsilon_2 = epsilon_2
81
+ self.clip_threshold = clip_threshold
82
+ self.relative_step = relative_step
83
+
84
+ def build(self, var_list):
85
+ """Initialize optimizer variables.
86
+
87
+ Adam optimizer has 3 types of variables: momentums, velocities and
88
+ velocity_hat (only set when amsgrad is applied),
89
+
90
+ Args:
91
+ var_list: list of model variables to build Adam variables on.
92
+ """
93
+ if self.built:
94
+ return
95
+ super().build(var_list)
96
+ self._r = []
97
+ self._c = []
98
+ self._v = []
99
+ for var in var_list:
100
+ if len(var.shape) < 2:
101
+ # Don't factor if variable is of dimension < 2, but we still
102
+ # need to create dummy variables as placeholder.
103
+ with backend.name_scope(self.name, caller=self):
104
+ self._r.append(
105
+ backend.Variable(0, name=var.name, trainable=False)
106
+ )
107
+ self._c.append(
108
+ backend.Variable(0, name=var.name, trainable=False)
109
+ )
110
+ else:
111
+ # Always factor the last 2 dimensions.
112
+ r_shape = var.shape[:-1]
113
+ c_shape = var.shape[:-2] + (var.shape[-1],)
114
+ self._r.append(
115
+ self.add_variable(
116
+ shape=r_shape,
117
+ dtype=var.dtype,
118
+ name=var.name,
119
+ )
120
+ )
121
+ self._c.append(
122
+ self.add_variable(
123
+ shape=c_shape,
124
+ dtype=var.dtype,
125
+ name=var.name,
126
+ )
127
+ )
128
+ self._v.append(
129
+ self.add_variable_from_reference(
130
+ reference_variable=var, name="velocity"
131
+ )
132
+ )
133
+
134
+ def _rms(self, x):
135
+ return ops.sqrt(ops.mean(ops.square(x)))
136
+
137
+ def update_step(self, gradient, variable, learning_rate):
138
+ """Update step given gradient and the associated model variable."""
139
+
140
+ lr = ops.cast(learning_rate, variable.dtype)
141
+ gradient = ops.cast(gradient, variable.dtype)
142
+ epsilon_2 = ops.cast(self.epsilon_2, variable.dtype)
143
+ one = ops.cast(1.0, variable.dtype)
144
+ local_step = ops.cast(self.iterations + 1, variable.dtype)
145
+ if not callable(self._learning_rate) and self.relative_step:
146
+ lr = ops.minimum(lr, 1 / ops.sqrt(local_step))
147
+
148
+ r = self._r[self._get_variable_index(variable)]
149
+ c = self._c[self._get_variable_index(variable)]
150
+ v = self._v[self._get_variable_index(variable)]
151
+
152
+ rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step))
153
+ alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t
154
+ regulated_grad_square = ops.add(ops.square(gradient), self.epsilon_1)
155
+ beta_2_t = 1 - ops.power(local_step, self.beta_2_decay)
156
+
157
+ if len(variable.shape) >= 2:
158
+ # `r` deletes the last dimension of gradient, so it is of shape
159
+ # `gradient.shape[:-1]`.
160
+ self.assign(
161
+ r,
162
+ beta_2_t * r
163
+ + (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-1),
164
+ )
165
+ # `c` deletes the second last dimension of gradient, so it is of
166
+ # shape `gradient.shape[:-2] + gradient.shape[-1]`.
167
+ self.assign(
168
+ c,
169
+ beta_2_t * c
170
+ + (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-2),
171
+ )
172
+ self.assign(
173
+ v,
174
+ ops.expand_dims(
175
+ r / ops.mean(r, axis=-1, keepdims=True), axis=-1
176
+ )
177
+ * ops.expand_dims(c, -2),
178
+ )
179
+ else:
180
+ self.assign(
181
+ v, beta_2_t * v + (1 - beta_2_t) * regulated_grad_square
182
+ )
183
+
184
+ u_t = ops.divide(gradient, ops.sqrt(v))
185
+ u_t_hat = ops.divide(
186
+ u_t,
187
+ ops.maximum(one, ops.divide(self._rms(u_t), self.clip_threshold)),
188
+ )
189
+ self.assign_sub(variable, ops.multiply(alpha_t, u_t_hat))
190
+
191
+ def get_config(self):
192
+ config = super().get_config()
193
+
194
+ config.update(
195
+ {
196
+ "beta_2_decay": self.beta_2_decay,
197
+ "epsilon_1": self.epsilon_1,
198
+ "epsilon_2": self.epsilon_2,
199
+ "clip_threshold": self.clip_threshold,
200
+ "relative_step": self.relative_step,
201
+ }
202
+ )
203
+ return config
204
+
205
+
206
+ Adafactor.__doc__ = Adafactor.__doc__.replace(
207
+ "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
208
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adagrad.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import initializers
2
+ from keras.src import ops
3
+ from keras.src.api_export import keras_export
4
+ from keras.src.optimizers import optimizer
5
+
6
+
7
+ @keras_export(["keras.optimizers.Adagrad"])
8
+ class Adagrad(optimizer.Optimizer):
9
+ """Optimizer that implements the Adagrad algorithm.
10
+
11
+ Adagrad is an optimizer with parameter-specific learning rates,
12
+ which are adapted relative to how frequently a parameter gets
13
+ updated during training. The more updates a parameter receives,
14
+ the smaller the updates.
15
+
16
+ Args:
17
+ learning_rate: A float, a
18
+ `keras.optimizers.schedules.LearningRateSchedule` instance, or
19
+ a callable that takes no arguments and returns the actual value to
20
+ use. The learning rate. Defaults to `0.001`. Note that `Adagrad`
21
+ tends to benefit from higher initial learning rate values compared
22
+ to other optimizers. To match the exact form in the original paper,
23
+ use `1.0`.
24
+ initial_accumulator_value: Floating point value. Starting value for the
25
+ accumulators (per-parameter momentum values). Must be non-negative.
26
+ epsilon: Small floating point value for maintaining numerical stability.
27
+ {{base_optimizer_keyword_args}}
28
+
29
+ Reference:
30
+
31
+ - [Duchi et al., 2011](
32
+ http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf).
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ learning_rate=0.001,
38
+ initial_accumulator_value=0.1,
39
+ epsilon=1e-7,
40
+ weight_decay=None,
41
+ clipnorm=None,
42
+ clipvalue=None,
43
+ global_clipnorm=None,
44
+ use_ema=False,
45
+ ema_momentum=0.99,
46
+ ema_overwrite_frequency=None,
47
+ loss_scale_factor=None,
48
+ gradient_accumulation_steps=None,
49
+ name="adagrad",
50
+ **kwargs,
51
+ ):
52
+ super().__init__(
53
+ learning_rate=learning_rate,
54
+ weight_decay=weight_decay,
55
+ clipnorm=clipnorm,
56
+ clipvalue=clipvalue,
57
+ global_clipnorm=global_clipnorm,
58
+ use_ema=use_ema,
59
+ ema_momentum=ema_momentum,
60
+ ema_overwrite_frequency=ema_overwrite_frequency,
61
+ loss_scale_factor=loss_scale_factor,
62
+ gradient_accumulation_steps=gradient_accumulation_steps,
63
+ name=name,
64
+ **kwargs,
65
+ )
66
+ self.initial_accumulator_value = initial_accumulator_value
67
+ self.epsilon = epsilon
68
+
69
+ def build(self, var_list):
70
+ if self.built:
71
+ return
72
+ super().build(var_list)
73
+ self._accumulators = []
74
+ initializer = initializers.Constant(self.initial_accumulator_value)
75
+ for var in var_list:
76
+ self._accumulators.append(
77
+ self.add_variable(
78
+ shape=var.shape,
79
+ initializer=initializer,
80
+ dtype=var.dtype,
81
+ name="accumulator",
82
+ )
83
+ )
84
+
85
+ def update_step(self, gradient, variable, learning_rate):
86
+ """Update step given gradient and the associated model variable."""
87
+ lr = ops.cast(learning_rate, variable.dtype)
88
+ gradient = ops.cast(gradient, variable.dtype)
89
+
90
+ accumulator = self._accumulators[self._get_variable_index(variable)]
91
+
92
+ self.assign_add(accumulator, ops.square(gradient))
93
+ self.assign_sub(
94
+ variable,
95
+ ops.divide(
96
+ ops.multiply(lr, gradient),
97
+ ops.sqrt(ops.add(accumulator, self.epsilon)),
98
+ ),
99
+ )
100
+
101
+ def get_config(self):
102
+ config = super().get_config()
103
+
104
+ config.update(
105
+ {
106
+ "initial_accumulator_value": self.initial_accumulator_value,
107
+ "epsilon": self.epsilon,
108
+ }
109
+ )
110
+ return config
111
+
112
+
113
+ Adagrad.__doc__ = Adagrad.__doc__.replace(
114
+ "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
115
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adam.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import ops
2
+ from keras.src.api_export import keras_export
3
+ from keras.src.optimizers import optimizer
4
+
5
+
6
+ @keras_export(["keras.optimizers.Adam"])
7
+ class Adam(optimizer.Optimizer):
8
+ """Optimizer that implements the Adam algorithm.
9
+
10
+ Adam optimization is a stochastic gradient descent method that is based on
11
+ adaptive estimation of first-order and second-order moments.
12
+
13
+ According to
14
+ [Kingma et al., 2014](http://arxiv.org/abs/1412.6980),
15
+ the method is "*computationally
16
+ efficient, has little memory requirement, invariant to diagonal rescaling of
17
+ gradients, and is well suited for problems that are large in terms of
18
+ data/parameters*".
19
+
20
+ Args:
21
+ learning_rate: A float, a
22
+ `keras.optimizers.schedules.LearningRateSchedule` instance, or
23
+ a callable that takes no arguments and returns the actual value to
24
+ use. The learning rate. Defaults to `0.001`.
25
+ beta_1: A float value or a constant float tensor, or a callable
26
+ that takes no arguments and returns the actual value to use. The
27
+ exponential decay rate for the 1st moment estimates. Defaults to
28
+ `0.9`.
29
+ beta_2: A float value or a constant float tensor, or a callable
30
+ that takes no arguments and returns the actual value to use. The
31
+ exponential decay rate for the 2nd moment estimates. Defaults to
32
+ `0.999`.
33
+ epsilon: A small constant for numerical stability. This epsilon is
34
+ "epsilon hat" in the Kingma and Ba paper (in the formula just before
35
+ Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults
36
+ to `1e-7`.
37
+ amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm
38
+ from the paper "On the Convergence of Adam and beyond". Defaults
39
+ to `False`.
40
+ {{base_optimizer_keyword_args}}
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ learning_rate=0.001,
46
+ beta_1=0.9,
47
+ beta_2=0.999,
48
+ epsilon=1e-7,
49
+ amsgrad=False,
50
+ weight_decay=None,
51
+ clipnorm=None,
52
+ clipvalue=None,
53
+ global_clipnorm=None,
54
+ use_ema=False,
55
+ ema_momentum=0.99,
56
+ ema_overwrite_frequency=None,
57
+ loss_scale_factor=None,
58
+ gradient_accumulation_steps=None,
59
+ name="adam",
60
+ **kwargs,
61
+ ):
62
+ super().__init__(
63
+ learning_rate=learning_rate,
64
+ name=name,
65
+ weight_decay=weight_decay,
66
+ clipnorm=clipnorm,
67
+ clipvalue=clipvalue,
68
+ global_clipnorm=global_clipnorm,
69
+ use_ema=use_ema,
70
+ ema_momentum=ema_momentum,
71
+ ema_overwrite_frequency=ema_overwrite_frequency,
72
+ loss_scale_factor=loss_scale_factor,
73
+ gradient_accumulation_steps=gradient_accumulation_steps,
74
+ **kwargs,
75
+ )
76
+ self.beta_1 = beta_1
77
+ self.beta_2 = beta_2
78
+ self.epsilon = epsilon
79
+ self.amsgrad = amsgrad
80
+
81
+ def build(self, var_list):
82
+ """Initialize optimizer variables.
83
+
84
+ Adam optimizer has 3 types of variables: momentums, velocities and
85
+ velocity_hat (only set when amsgrad is applied),
86
+
87
+ Args:
88
+ var_list: list of model variables to build Adam variables on.
89
+ """
90
+ if self.built:
91
+ return
92
+ super().build(var_list)
93
+ self._momentums = []
94
+ self._velocities = []
95
+ for var in var_list:
96
+ self._momentums.append(
97
+ self.add_variable_from_reference(
98
+ reference_variable=var, name="momentum"
99
+ )
100
+ )
101
+ self._velocities.append(
102
+ self.add_variable_from_reference(
103
+ reference_variable=var, name="velocity"
104
+ )
105
+ )
106
+ if self.amsgrad:
107
+ self._velocity_hats = []
108
+ for var in var_list:
109
+ self._velocity_hats.append(
110
+ self.add_variable_from_reference(
111
+ reference_variable=var, name="velocity_hat"
112
+ )
113
+ )
114
+
115
+ def update_step(self, gradient, variable, learning_rate):
116
+ """Update step given gradient and the associated model variable."""
117
+ lr = ops.cast(learning_rate, variable.dtype)
118
+ gradient = ops.cast(gradient, variable.dtype)
119
+ local_step = ops.cast(self.iterations + 1, variable.dtype)
120
+ beta_1_power = ops.power(
121
+ ops.cast(self.beta_1, variable.dtype), local_step
122
+ )
123
+ beta_2_power = ops.power(
124
+ ops.cast(self.beta_2, variable.dtype), local_step
125
+ )
126
+
127
+ m = self._momentums[self._get_variable_index(variable)]
128
+ v = self._velocities[self._get_variable_index(variable)]
129
+
130
+ alpha = lr * ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)
131
+
132
+ self.assign_add(
133
+ m, ops.multiply(ops.subtract(gradient, m), 1 - self.beta_1)
134
+ )
135
+ self.assign_add(
136
+ v,
137
+ ops.multiply(
138
+ ops.subtract(ops.square(gradient), v), 1 - self.beta_2
139
+ ),
140
+ )
141
+ if self.amsgrad:
142
+ v_hat = self._velocity_hats[self._get_variable_index(variable)]
143
+ self.assign(v_hat, ops.maximum(v_hat, v))
144
+ v = v_hat
145
+ self.assign_sub(
146
+ variable,
147
+ ops.divide(
148
+ ops.multiply(m, alpha), ops.add(ops.sqrt(v), self.epsilon)
149
+ ),
150
+ )
151
+
152
+ def get_config(self):
153
+ config = super().get_config()
154
+ config.update(
155
+ {
156
+ "beta_1": self.beta_1,
157
+ "beta_2": self.beta_2,
158
+ "epsilon": self.epsilon,
159
+ "amsgrad": self.amsgrad,
160
+ }
161
+ )
162
+ return config
163
+
164
+
165
+ Adam.__doc__ = Adam.__doc__.replace(
166
+ "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
167
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adamax.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import ops
2
+ from keras.src.api_export import keras_export
3
+ from keras.src.optimizers import optimizer
4
+
5
+
6
@keras_export(["keras.optimizers.Adamax"])
class Adamax(optimizer.Optimizer):
    """Optimizer that implements the Adamax algorithm.

    Adamax, a variant of Adam based on the infinity norm, is a first-order
    gradient-based optimization method. Due to its capability of adjusting the
    learning rate based on data characteristics, it is suited to learn
    time-variant process, e.g., speech data with dynamically changed noise
    conditions. Default parameters follow those provided in the paper (see
    references below).

    Initialization:

    ```python
    m = 0 # Initialize initial 1st moment vector
    u = 0 # Initialize the exponentially weighted infinity norm
    t = 0 # Initialize timestep
    ```

    The update rule for parameter `w` with gradient `g` is described at the end
    of section 7.1 of the paper (see the reference section):

    ```python
    t += 1
    m = beta1 * m + (1 - beta1) * g
    u = max(beta2 * u, abs(g))
    current_lr = learning_rate / (1 - beta1 ** t)
    w = w - current_lr * m / (u + epsilon)
    ```

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor. The exponential decay
            rate for the 1st moment estimates.
        beta_2: A float value or a constant float tensor. The exponential decay
            rate for the exponentially weighted infinity norm.
        epsilon: A small constant for numerical stability.
        {{base_optimizer_keyword_args}}

    Reference:

    - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="adamax",
        **kwargs,
    ):
        # All generic optimizer arguments are forwarded to the base class;
        # only the Adamax hyperparameters are stored here.
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon

    def build(self, var_list):
        """Initialize optimizer variables.

        Adamax optimizer has 2 types of variables: momentums (denoted as m),
        exponentially weighted infinity norm (denoted as u).

        Args:
            var_list: list of model variables to build Adamax variables on.
        """
        if self.built:
            return
        super().build(var_list)
        # One (m, u) slot pair per trainable variable, in var_list order.
        self._m = []
        self._u = []
        for var in var_list:
            self._m.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="momentum"
                )
            )
            self._u.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="norm"
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        # Cast to the variable's dtype for mixed-precision consistency.
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        # 1-based step for the bias-correction term beta_1^t.
        local_step = ops.cast(self.iterations + 1, variable.dtype)
        beta_1_power = ops.power(
            ops.cast(self.beta_1, variable.dtype), local_step
        )

        m = self._m[self._get_variable_index(variable)]
        u = self._u[self._get_variable_index(variable)]

        # m <- m + (g - m) * (1 - beta_1), i.e. the standard EMA update.
        self.assign_add(
            m, ops.multiply(ops.subtract(gradient, m), (1 - self.beta_1))
        )
        # u <- max(beta_2 * u, |g|): exponentially weighted infinity norm.
        self.assign(
            u, ops.maximum(ops.multiply(self.beta_2, u), ops.abs(gradient))
        )
        # w <- w - lr * m / ((1 - beta_1^t) * (u + epsilon))
        self.assign_sub(
            variable,
            ops.divide(
                ops.multiply(lr, m),
                ops.multiply((1 - beta_1_power), ops.add(u, self.epsilon)),
            ),
        )

    def get_config(self):
        """Return the serializable configuration of this optimizer."""
        config = super().get_config()

        config.update(
            {
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
                "epsilon": self.epsilon,
            }
        )
        return config
152
+
153
+
154
# Substitute the shared base-optimizer keyword-argument documentation into
# the class docstring in place of the `{{base_optimizer_keyword_args}}`
# placeholder.
Adamax.__doc__ = Adamax.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adamw.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src.api_export import keras_export
2
+ from keras.src.optimizers import adam
3
+ from keras.src.optimizers import optimizer
4
+
5
+
6
@keras_export(["keras.optimizers.AdamW"])
class AdamW(adam.Adam):
    """Optimizer that implements the AdamW algorithm.

    AdamW optimization is a stochastic gradient descent method that is based on
    adaptive estimation of first-order and second-order moments with an added
    method to decay weights per the techniques discussed in the paper,
    'Decoupled Weight Decay Regularization' by
    [Loshchilov, Hutter et al., 2019](https://arxiv.org/abs/1711.05101).

    According to
    [Kingma et al., 2014](http://arxiv.org/abs/1412.6980),
    the underlying Adam method is "*computationally
    efficient, has little memory requirement, invariant to diagonal rescaling of
    gradients, and is well suited for problems that are large in terms of
    data/parameters*".

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimates.
            Defaults to `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 2nd moment estimates.
            Defaults to `0.999`.
        epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just
            before Section 2.1), not the epsilon in Algorithm 1 of the paper.
            Defaults to 1e-7.
        amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm
            from the paper "On the Convergence of Adam and beyond".
            Defaults to `False`.
        {{base_optimizer_keyword_args}}

    References:

    - [Loshchilov et al., 2019](https://arxiv.org/abs/1711.05101)
    - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) for `adam`
    - [Reddi et al., 2018](
        https://openreview.net/pdf?id=ryQu7f-RZ) for `amsgrad`.
    """

    def __init__(
        self,
        learning_rate=0.001,
        # Unlike plain Adam, AdamW defaults weight_decay to a non-None
        # value (0.004) — decoupled weight decay is the point of AdamW.
        weight_decay=0.004,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="adamw",
        **kwargs,
    ):
        # All behavior is inherited from Adam; this subclass only changes
        # the weight_decay default and forbids disabling it.
        super().__init__(
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )

        # Reject weight_decay=None explicitly: without decay AdamW
        # degenerates into plain Adam.
        if self.weight_decay is None:
            raise ValueError(
                "Argument `weight_decay` must be a float. Received: "
                "weight_decay=None"
            )
96
+
97
+
98
# Substitute the shared base-optimizer keyword-argument documentation into
# the class docstring in place of the `{{base_optimizer_keyword_args}}`
# placeholder.
AdamW.__doc__ = AdamW.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/base_optimizer.py ADDED
@@ -0,0 +1,1102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import warnings
3
+
4
+ from keras.src import backend
5
+ from keras.src import initializers
6
+ from keras.src import ops
7
+ from keras.src.optimizers.schedules import learning_rate_schedule
8
+ from keras.src.saving import serialization_lib
9
+ from keras.src.saving.keras_saveable import KerasSaveable
10
+ from keras.src.utils import tracking
11
+ from keras.src.utils.naming import auto_name
12
+
13
+
14
+ class BaseOptimizer(KerasSaveable):
15
+ """Abstract optimizer base class.
16
+
17
+ If you intend to create your own optimization algorithm, please inherit from
18
+ this class and override the following methods:
19
+
20
+ - `build`: Create your optimizer-related variables, such as momentum
21
+ variables in the SGD optimizer.
22
+ - `update_step`: Implement your optimizer's variable updating logic.
23
+ - `get_config`: serialization of the optimizer.
24
+
25
+ Example:
26
+
27
+ ```python
28
+ class SGD(Optimizer):
29
+ def __init__(self, **kwargs):
30
+ super().__init__(**kwargs)
31
+ self.momentum = 0.9
32
+
33
+ def build(self, variables):
34
+ super().build(variables)
35
+ self.momentums = []
36
+ for variable in variables:
37
+ self.momentums.append(
38
+ self.add_variable_from_reference(
39
+ reference_variable=variable, name="momentum"
40
+ )
41
+ )
42
+
43
+ def update_step(self, gradient, variable, learning_rate):
44
+ learning_rate = ops.cast(learning_rate, variable.dtype)
45
+ gradient = ops.cast(gradient, variable.dtype)
46
+ m = self.momentums[self._get_variable_index(variable)]
47
+ self.assign(
48
+ m,
49
+ ops.subtract(
50
+ ops.multiply(m, ops.cast(self.momentum, variable.dtype)),
51
+ ops.multiply(gradient, learning_rate),
52
+ ),
53
+ )
54
+ self.assign_add(variable, m)
55
+
56
+ def get_config(self):
57
+ config = super().get_config()
58
+ config.update(
59
+ {
60
+ "momentum": self.momentum,
61
+ "nesterov": self.nesterov,
62
+ }
63
+ )
64
+ return config
65
+ ```
66
+ """
67
+
68
    def __init__(
        self,
        learning_rate,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name=None,
        **kwargs,
    ):
        """Validate arguments and create the optimizer's core state.

        Creates the iteration counter and (for a float learning rate) the
        learning-rate variable; schedules and callables are stored as-is.
        Raises `ValueError` for unrecognized kwargs, out-of-range EMA
        arguments, `gradient_accumulation_steps < 2`, or more than one of
        `clipnorm`/`clipvalue`/`global_clipnorm` being set.
        """
        self._lock = False

        # `decay` was a tf.keras 2.x argument; accept and ignore it with a
        # warning, then reject any other unknown kwargs.
        if kwargs.pop("decay", None) is not None:
            warnings.warn(
                "Argument `decay` is no longer supported and will be ignored."
            )
        if kwargs:
            raise ValueError(f"Argument(s) not recognized: {kwargs}")

        if name is None:
            name = auto_name(self.__class__.__name__)
        self.name = name
        self.weight_decay = weight_decay
        self.clipnorm = clipnorm
        self.global_clipnorm = global_clipnorm
        self.clipvalue = clipvalue
        self.use_ema = use_ema
        self.loss_scale_factor = loss_scale_factor
        self.gradient_accumulation_steps = gradient_accumulation_steps

        if gradient_accumulation_steps:
            if not gradient_accumulation_steps >= 2:
                raise ValueError(
                    "`gradient_accumulation_steps` must be an integer >= 2. "
                    "Received: gradient_accumulation_steps="
                    f"{gradient_accumulation_steps}"
                )

        if use_ema:
            # Verify the arguments related to EMA.
            if ema_momentum > 1 or ema_momentum < 0:
                raise ValueError(
                    "`ema_momentum` must be in the range [0, 1]. "
                    f"Received: ema_momentum={ema_momentum}"
                )
            if ema_overwrite_frequency and (
                not isinstance(ema_overwrite_frequency, int)
                or ema_overwrite_frequency < 1
            ):
                raise ValueError(
                    "`ema_overwrite_frequency` must be an integer >= 1 or "
                    "None. Received: ema_overwrite_frequency="
                    f"{ema_overwrite_frequency}"
                )
        # NOTE: EMA arguments are stored even when use_ema is False (and
        # thus unvalidated in that case).
        self.ema_momentum = ema_momentum
        self.ema_overwrite_frequency = ema_overwrite_frequency

        # At most one clipping mode may be active.
        clip_args_sum = sum(
            a is not None for a in [clipnorm, clipvalue, global_clipnorm]
        )
        if clip_args_sum > 1:
            raise ValueError(
                "Only one of `clipnorm`, `clipvalue` and `global_clipnorm` can "
                f"be set. Received: clipnorm={clipnorm}, "
                f"clipvalue={clipvalue}, global_clipnorm={global_clipnorm}"
            )
        self.built = False

        # Set up variable tracking.
        self._variables = []
        self._trainable_variables = []
        self._tracker = tracking.Tracker(
            {
                "variables": (
                    lambda x: isinstance(x, backend.Variable),
                    self._variables,
                ),
            }
        )
        self._trainable_variables_indices = {}

        # Create iteration variable
        # Note: dtype="int" will resolve to int32 in JAX
        # (since int64 is disallowed in JAX) and to int64 in TF.
        with backend.name_scope(self.name, caller=self):
            iterations = backend.Variable(
                0,
                name="iteration",
                dtype="int",
                trainable=False,
                aggregation="only_first_replica",
            )
        self._track_variable(iterations)
        self._iterations = iterations

        # Create learning rate (schedule or variable)
        if isinstance(
            learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
            self._learning_rate = learning_rate
        elif callable(learning_rate):
            self._learning_rate = learning_rate
        else:
            if not isinstance(learning_rate, float):
                raise ValueError(
                    "Argument `learning_rate` should be float, or an instance "
                    "of LearningRateSchedule, or a callable "
                    "(that takes in the current iteration value "
                    "and returns the corresponding learning rate value). "
                    f"Received instead: learning_rate={learning_rate}"
                )
            with backend.name_scope(self.name, caller=self):
                learning_rate = backend.Variable(
                    learning_rate,
                    name="learning_rate",
                    dtype=backend.floatx(),
                    trainable=False,
                    aggregation="only_first_replica",
                )
            self._track_variable(learning_rate)
            self._learning_rate = learning_rate
194
+
195
+ @property
196
+ def iterations(self):
197
+ if self.gradient_accumulation_steps:
198
+ return ops.floor_divide(
199
+ self._iterations, self.gradient_accumulation_steps
200
+ )
201
+
202
+ return self._iterations
203
+
204
    def _track_variable(self, variable):
        """Register `variable` in the tracker's "variables" store so it is
        included in `self.variables` (and thus in saving/loading)."""
        self._tracker.add_to_store("variables", variable)
206
+
207
    @tracking.no_automatic_dependency_tracking
    def build(self, variables):
        """Create optimizer state for the given trainable variables.

        Records each variable's index (used later to look up slot
        variables), and creates EMA shadow variables and/or gradient
        accumulators when those features are enabled. Subclasses extend
        this to add their own slot variables.
        """
        if self.use_ema:
            self._model_variables_moving_average = []
        if self.gradient_accumulation_steps:
            self._accumulated_gradients = []
        for i, variable in enumerate(variables):
            # Index mapping keyed by id(variable) via _var_key().
            self._trainable_variables_indices[self._var_key(variable)] = i
            if self.use_ema:
                self._model_variables_moving_average.append(
                    self.add_variable_from_reference(
                        variable,
                        name="average",
                    )
                )
            if self.gradient_accumulation_steps:
                self._accumulated_gradients.append(
                    self.add_variable_from_reference(
                        variable,
                        name="gradient_accumulator",
                    )
                )
        # Keep a private copy so later mutation of the caller's list does
        # not affect the optimizer.
        self._trainable_variables = variables[:]
        self.built = True
231
+
232
    def _var_key(self, variable):
        # Helper function to get a stable ID and the variable instance mapping.
        return id(variable)

    @property
    def variables(self):
        """A shallow copy of all variables tracked by this optimizer."""
        return self._variables[:]

    def _get_variable_index(self, variable):
        """Return the build-time index of a trainable variable.

        Raises `KeyError` if the variable was not part of the list the
        optimizer was built with.
        """
        return self._trainable_variables_indices[self._var_key(variable)]
242
+
243
    def add_variable(
        self,
        shape,
        initializer="zeros",
        dtype=None,
        aggregation="none",
        name=None,
    ):
        """Add a variable to the optimizer.

        Args:
            shape: Shape tuple for the variable. Must be fully-defined
                (no `None` entries).
            initializer: Initializer object to use to populate the initial
                variable value, or string name of a built-in initializer
                (e.g. `"random_normal"`). Defaults to `"zeros"`.
            dtype: Dtype of the variable to create, e.g. `"float32"`. If
                unspecified, defaults to the `keras.backend.floatx()`.
            aggregation: Optional string, one of `None`, `"none"`, `"mean"`,
                `"sum"` or `"only_first_replica"`. Annotates the variable with
                the type of multi-replica aggregation to be used for this
                variable when writing custom data parallel training loops.
                Defaults to `"none"`.
            name: String name of the variable. Useful for debugging purposes.

        Returns:
            An optimizer variable, in the format of `keras.Variable`.
        """
        self._check_super_called()
        initializer = initializers.get(initializer)
        # Create the variable inside this optimizer's name scope, mark it
        # non-trainable, and register it with the tracker.
        with backend.name_scope(self.name, caller=self):
            variable = backend.Variable(
                initializer=initializer,
                shape=shape,
                dtype=dtype,
                trainable=False,
                aggregation=aggregation,
                name=name,
            )
        self._track_variable(variable)
        return variable
284
+
285
+ def add_variable_from_reference(
286
+ self, reference_variable, name=None, initializer="zeros"
287
+ ):
288
+ """Add an optimizer variable from the model variable.
289
+
290
+ Create an optimizer variable based on the information of model variable.
291
+ For example, in SGD optimizer momemtum, for each model variable, a
292
+ corresponding momemtum variable is created of the same shape and dtype.
293
+
294
+ Args:
295
+ reference_variable: `keras.Variable`. The corresponding model
296
+ variable to the optimizer variable to be created.
297
+ name: Optional string. The name prefix of the optimizer variable to
298
+ be created. If not provided, it will be set to `"var"`. The
299
+ variable name will follow the pattern
300
+ `{variable_name}_{reference_variable.name}`,
301
+ e.g., `momemtum/dense_1`. Defaults to `None`.
302
+ initializer: Initializer object to use to populate the initial
303
+ variable value, or string name of a built-in initializer
304
+ (e.g. `"random_normal"`). If unspecified, defaults to
305
+ `"zeros"`.
306
+
307
+ Returns:
308
+ An optimizer variable, in the format of `keras.Variable`.
309
+ """
310
+ name = name or "var"
311
+ if hasattr(reference_variable, "path"):
312
+ name = reference_variable.path.replace("/", "_") + "_" + name
313
+ else:
314
+ name = (
315
+ str(reference_variable.name).replace("/", "_").replace(":", "_")
316
+ + "_"
317
+ + name
318
+ )
319
+ return self.add_variable(
320
+ shape=reference_variable.shape,
321
+ initializer=initializer,
322
+ dtype=reference_variable.dtype,
323
+ name=name,
324
+ )
325
+
326
+ def _check_variables_are_known(self, variables):
327
+ for v in variables:
328
+ if self._var_key(v) not in self._trainable_variables_indices:
329
+ raise ValueError(
330
+ f"Unknown variable: {v}. This optimizer can only "
331
+ "be called for the variables it was originally built with. "
332
+ "When working with a new set of variables, you should "
333
+ "recreate a new optimizer instance."
334
+ )
335
+
336
    def assign(self, variable, value):
        """Assign a value to a variable.

        This should be used in optimizers instead of `variable.assign(value)` to
        support backend specific optimizations.
        Note that the variable can be a model variable or an optimizer variable;
        it can be a backend native variable or a Keras variable.

        Args:
            variable: The variable to update.
            value: The value to assign to the variable.
        """
        variable.assign(value)

    def assign_add(self, variable, value):
        """Add a value to a variable.

        This should be used in optimizers instead of
        `variable.assign_add(value)` to support backend specific optimizations.
        Note that the variable can be a model variable or an optimizer variable;
        it can be a backend native variable or a Keras variable.

        Args:
            variable: The variable to update.
            value: The value to add to the variable.
        """
        variable.assign_add(value)

    def assign_sub(self, variable, value):
        """Subtract a value from a variable.

        This should be used in optimizers instead of
        `variable.assign_sub(value)` to support backend specific optimizations.
        Note that the variable can be a model variable or an optimizer variable;
        it can be a backend native variable or a Keras variable.

        Args:
            variable: The variable to update.
            value: The value to subtract from the variable.
        """
        variable.assign_sub(value)
377
+
378
    def update_step(self, gradient, variable, learning_rate):
        """Apply one update to `variable` from `gradient`.

        Abstract: every concrete optimizer must override this with its
        update rule.
        """
        raise NotImplementedError
380
+
381
    def apply_gradients(self, grads_and_vars):
        """Apply gradients given as an iterable of `(gradient, variable)`
        pairs.

        Thin wrapper over `apply()` that splits the pairs into parallel
        lists.
        """
        grads, trainable_variables = zip(*grads_and_vars)
        self.apply(grads, trainable_variables)
        # Return iterations for compat with tf.keras.
        return self._iterations
386
+
387
+ def apply(self, grads, trainable_variables=None):
388
+ """Update traininable variables according to provided gradient values.
389
+
390
+ `grads` should be a list of gradient tensors
391
+ with 1:1 mapping to the list of variables the optimizer was built with.
392
+
393
+ `trainable_variables` can be provided
394
+ on the first call to build the optimizer.
395
+ """
396
+ if len(grads) == 0:
397
+ # It is possible that the grad is empty. In this case,
398
+ # `apply_gradients` is a no-op.
399
+ return
400
+
401
+ if trainable_variables is None:
402
+ if not self.built:
403
+ raise ValueError(
404
+ "When passing `grads` without `variables`, the optimizer "
405
+ "must already be built on a list of variables. "
406
+ "Call `optimizer.build(trainable_variables)` first. "
407
+ )
408
+ if len(grads) != len(self._trainable_variables_indices):
409
+ raise ValueError(
410
+ "When passing `grads` as a list of gradient tensors, the "
411
+ f"gradients must match `optimizer.variables` one-to-on. "
412
+ f"Received a list of {len(grads)} gradients, but the "
413
+ f"optimizer is tracking {len(self._trainable_variables)} "
414
+ "trainable variables."
415
+ )
416
+ trainable_variables = self._trainable_variables
417
+ else:
418
+ trainable_variables = list(trainable_variables)
419
+ # Optionally build optimizer.
420
+ if not self.built:
421
+ with backend.name_scope(self.name, caller=self):
422
+ self.build(trainable_variables)
423
+ self.built = True
424
+ self._check_variables_are_known(trainable_variables)
425
+
426
+ with backend.name_scope(self.name, caller=self):
427
+ # Overwrite targeted variables directly with their gradients if
428
+ # their `overwrite_with_gradient` is set.
429
+ grads, trainable_variables = (
430
+ self._overwrite_variables_directly_with_gradients(
431
+ grads, trainable_variables
432
+ )
433
+ )
434
+
435
+ # Filter empty gradients.
436
+ grads, trainable_variables = self._filter_empty_gradients(
437
+ grads, trainable_variables
438
+ )
439
+ if len(list(grads)) == 0:
440
+ return
441
+
442
+ # Unscale gradients.
443
+ scale = self.loss_scale_factor
444
+ if scale is not None:
445
+ grads = [g if g is None else g / scale for g in grads]
446
+
447
+ # Apply gradient updates.
448
+ self._backend_apply_gradients(grads, trainable_variables)
449
+ # Apply variable constraints after applying gradients.
450
+ for variable in trainable_variables:
451
+ if variable.constraint is not None:
452
+ variable.assign(variable.constraint(variable))
453
+
454
    def _backend_apply_gradients(self, grads, trainable_variables):
        """Apply method that can be overridden by different backends.

        JAX overrides it in order to deal with statelessness in gradient
        accumulation and EMA handling.

        The below implementation is intended to be generally backend-agnostic,
        but may not work with all backends.

        This method does 4 things:
        - Call the optimizer's update_step() to update trainable variables
          and optimizer variables.
        - Update EMA variables, if EMA is configured.
        - Update gradient accumulators, if gradient accumulation is configured.
        - Update the iteration counter.
        """
        if self.gradient_accumulation_steps:
            # A real update only happens every
            # `gradient_accumulation_steps` raw iterations; other calls
            # only accumulate.
            is_update_step = (
                self._iterations + 1
            ) % self.gradient_accumulation_steps == 0
            # `trainable_variables` might have been filtered in previous
            # processing steps, so we need to ensure the correct mapping between
            # `self._accumulated_gradients` and `trainable_variables`
            acc_grads = [
                self._accumulated_gradients[self._get_variable_index(v)]
                for v in trainable_variables
            ]

            def _update_step_fn(grads, trainable_variables):
                # Run update step with accumulated grads + reset accumulators
                steps = self.gradient_accumulation_steps
                grads = [
                    (g + acc_g) / steps for g, acc_g in zip(grads, acc_grads)
                ]

                # Apply clipping and weight decay.
                grads = self._clip_gradients(grads)
                self._apply_weight_decay(trainable_variables)

                self._backend_update_step(
                    grads, trainable_variables, self.learning_rate
                )
                self._backend_reset_gradient_accumulators()

            # ops.cond keeps both branches traceable for compiled backends.
            ops.cond(
                is_update_step,
                lambda: _update_step_fn(grads, trainable_variables),
                lambda: self._backend_increment_gradient_accumulators(
                    grads, acc_grads
                ),
            )
        else:
            # Apply clipping and weight decay.
            grads = self._clip_gradients(grads)
            self._apply_weight_decay(trainable_variables)

            # Run update step.
            self._backend_update_step(
                grads, trainable_variables, self.learning_rate
            )

        if self.use_ema:
            self._update_model_variables_moving_average(
                self._trainable_variables
            )
            if self.ema_overwrite_frequency:
                # Only when self.ema_overwrite_frequency is not None, we
                # overwrite the model variables.
                should_overwrite_model_vars = (
                    self.iterations + 1
                ) % self.ema_overwrite_frequency == 0
                ops.cond(
                    should_overwrite_model_vars,
                    lambda: self._overwrite_model_variables_with_average_value(
                        self._trainable_variables
                    ),
                    lambda: None,
                )
        # Update iteration counter.
        self._iterations.assign_add(1)
534
+
535
+ def _backend_update_step(self, grads, trainable_variables, learning_rate):
536
+ """Collective update_step that can be overridden by the backend.
537
+
538
+ It is overridden by torch for performance reasons, and
539
+ by TF to support tf.distribute.
540
+ """
541
+ for grad, var in zip(grads, trainable_variables):
542
+ self.update_step(grad, var, learning_rate)
543
+
544
+ def _backend_reset_gradient_accumulators(self):
545
+ for g_acc in self._accumulated_gradients:
546
+ g_acc.assign(ops.zeros(g_acc.shape, dtype=g_acc.dtype))
547
+
548
+ def _backend_increment_gradient_accumulators(self, grads, acc_grads):
549
+ new_g_accs = [(g + acc_g) for g, acc_g in zip(grads, acc_grads)]
550
+ for n_g_acc, g_acc in zip(new_g_accs, acc_grads):
551
+ g_acc.assign(n_g_acc)
552
+
553
    def stateless_apply(self, optimizer_variables, grads, trainable_variables):
        """Functionally apply `grads` without mutating variables in place.

        Runs `apply()` inside a `backend.StatelessScope` so that every
        variable update is captured by the scope rather than written to the
        underlying variables, then returns the updated values.

        Args:
            optimizer_variables: list of tensors, 1:1 with `self.variables`.
            grads: list of gradient tensors for the trainable variables.
            trainable_variables: list of tensors, 1:1 with
                `self._trainable_variables` (the list passed to `build()`).

        Returns:
            A `(trainable_variables, optimizer_variables)` tuple holding the
            updated values, in the same order as the inputs.

        Raises:
            ValueError: if the optimizer is not built, or either input list
                does not match the corresponding variable list in length.
        """
        self._check_super_called()

        if not self.built:
            raise ValueError(
                f"To call `stateless_apply`, {self.__class__.__name__} "
                "must be built (i.e. its variables must have been created). "
                "You can build it via `optimizer.build(trainable_variables)`."
            )
        if len(optimizer_variables) != len(self.variables):
            raise ValueError(
                "Argument `optimizer_variables` must be a list of tensors "
                f"corresponding 1:1 to {self.__class__.__name__}().variables. "
                f"Received list with length {len(optimizer_variables)}, but "
                f"expected {len(self.variables)} variables."
            )
        if len(trainable_variables) != len(self._trainable_variables):
            raise ValueError(
                "Argument `optimizer_variables` must be a list of tensors "
                "corresponding 1:1 to the trainable variables list that "
                "the optimizer was built with. Received "
                f"len(trainable_variables) == {len(trainable_variables)} "
                "whereas the optimizer was built with "
                f"{len(self._trainable_variables)} variables."
            )

        # Gather variable mapping: each tracked variable is paired with the
        # caller-provided value it should take inside the scope.
        mapping = list(
            zip(self._trainable_variables, trainable_variables)
        ) + list(zip(self.variables, optimizer_variables))

        # Call in stateless scope; `apply` mutates nothing — the scope
        # records the would-be assignments instead.
        with backend.StatelessScope(state_mapping=mapping) as scope:
            self.apply(grads)

        # Gather updated variables: fall back to the original variable when
        # the scope recorded no new value for it.
        trainable_variables = []
        for v in self._trainable_variables:
            new_v = scope.get_current_value(v)
            if new_v is not None:
                trainable_variables.append(new_v)
            else:
                trainable_variables.append(v)
        optimizer_variables = []
        for v in self.variables:
            new_v = scope.get_current_value(v)
            if new_v is not None:
                optimizer_variables.append(new_v)
            else:
                optimizer_variables.append(v)
        return trainable_variables, optimizer_variables
604
+
605
+ def scale_loss(self, loss):
606
+ """Scale the loss before computing gradients.
607
+
608
+ Scales the loss before gradients are computed in a `train_step`. This
609
+ is primarily useful during mixed precision training to prevent numeric
610
+ underflow.
611
+ """
612
+ if self.loss_scale_factor is not None:
613
+ return loss * self.loss_scale_factor
614
+ return loss
615
+
616
    @property
    def learning_rate(self):
        # Resolves schedules and callables down to the current scalar value.
        return self._get_current_learning_rate()

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        # Remember a pre-existing learning-rate `Variable` so it can be
        # untracked if it gets replaced by a schedule/callable below.
        if isinstance(self._learning_rate, backend.Variable):
            prev_lr_var = self._learning_rate
        else:
            prev_lr_var = None
        if isinstance(
            learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
            self._learning_rate = learning_rate
        elif callable(learning_rate):
            self._learning_rate = learning_rate
        else:
            # A plain number: only assignable when the current learning rate
            # is a variable, not a schedule.
            if isinstance(
                self._learning_rate, learning_rate_schedule.LearningRateSchedule
            ):
                raise TypeError(
                    "This optimizer was created with a `LearningRateSchedule`"
                    " object as its `learning_rate` constructor argument, "
                    "hence its learning rate is not settable. If you need the"
                    " learning rate to be settable, you should instantiate "
                    "the optimizer with a float `learning_rate` argument."
                )
            self._learning_rate.assign(learning_rate)
        if prev_lr_var is not None and not isinstance(
            self._learning_rate, backend.Variable
        ):
            # Untrack learning rate variable
            self._untrack_variable(prev_lr_var)
649
+
650
+ def set_weights(self, weights):
651
+ """Set the weights of the optimizer."""
652
+ if not self.built:
653
+ raise ValueError(
654
+ "You are calling `set_weights()` on an optimizer that has not "
655
+ "yet been built. Please call "
656
+ "`optimizer.build(trainable_variables)` to create the "
657
+ "optimizer weights before calling `set_weights()`."
658
+ )
659
+ for variable, weight in zip(self._variables, weights):
660
+ if variable.shape != weight.shape:
661
+ raise ValueError(
662
+ f"Optimizer variable {self._var_key(variable)} has shape "
663
+ f"{str(variable.shape)} not compatible with provided "
664
+ f"weight shape {str(weight.shape)}."
665
+ )
666
+ variable.assign(weight)
667
+
668
+ def save_own_variables(self, store):
669
+ """Get the state of this optimizer object."""
670
+ for i, variable in enumerate(self.variables):
671
+ store[str(i)] = variable.numpy()
672
+
673
+ def load_own_variables(self, store):
674
+ """Set the state of this optimizer object."""
675
+ if len(store.keys()) != len(self.variables):
676
+ msg = (
677
+ f"Skipping variable loading for optimizer '{self.name}', "
678
+ f"because it has {len(self.variables)} variables whereas "
679
+ f"the saved optimizer has {len(store.keys())} variables. "
680
+ )
681
+ if len(self.variables) == 0:
682
+ msg += (
683
+ "This is likely because the optimizer has not been "
684
+ "called/built yet."
685
+ )
686
+ warnings.warn(msg, stacklevel=2)
687
+ return
688
+ for i, variable in enumerate(self.variables):
689
+ variable.assign(store[str(i)])
690
+
691
+ def _get_current_learning_rate(self):
692
+ if isinstance(
693
+ self._learning_rate, learning_rate_schedule.LearningRateSchedule
694
+ ):
695
+ return self._learning_rate(self._iterations)
696
+ elif callable(self._learning_rate):
697
+ return self._learning_rate()
698
+ return self._learning_rate
699
+
700
    def _overwrite_variables_directly_with_gradients(self, grads, vars):
        """Overwrite the variables directly by their gradients.

        This method is designed for a special case where we want to overwrite
        the variable directly with its computed gradient. For example, in float8
        training, new `scale` and `amax_history` are computed as gradients, and
        we want to overwrite them directly instead of following the typical
        procedure such as gradient descent with a learning rate, gradient
        clipping and weight decaying.

        After the update, the processed pairs will be filtered out.

        Returns:
            A `(filtered_grads, filtered_vars)` tuple with the overwritten
            pairs removed.
        """
        # Shortcut for `tf.Variable` because it doesn't have a
        # `overwrite_with_gradient` attr
        if any(not hasattr(v, "overwrite_with_gradient") for v in vars):
            return grads, vars

        # Shallow copies
        filtered_grads = list(grads)
        filtered_vars = list(vars)

        # Iterate from right to left for safe popping
        for i in range(len(filtered_grads) - 1, -1, -1):
            g, v = filtered_grads[i], filtered_vars[i]
            if v.overwrite_with_gradient:
                if self.gradient_accumulation_steps:
                    # Utilize a stateless manner for JAX compatibility
                    steps = self.gradient_accumulation_steps
                    is_update_step = (self._iterations + 1) % steps == 0
                    acc_g = self._accumulated_gradients[
                        self._get_variable_index(v)
                    ]
                    # `ops.maximum` is utilized for gradient accumulation for
                    # `overwrite_with_gradient=True` variables
                    # On an update step the accumulator is reset to zeros;
                    # otherwise it keeps the running elementwise max.
                    new_g_acc = ops.cond(
                        is_update_step,
                        lambda: ops.zeros(g.shape, dtype=g.dtype),
                        lambda: ops.maximum(g, acc_g),
                    )
                    new_g = ops.cond(
                        is_update_step,
                        lambda: ops.maximum(g, acc_g),
                        lambda: g,
                    )
                    # The variable itself is only overwritten on update steps.
                    new_v = ops.cond(
                        is_update_step, lambda: new_g, lambda: v.value
                    )
                    v.assign(new_v)
                    acc_g.assign(new_g_acc)
                else:
                    v.assign(g)
                filtered_grads.pop(i)
                filtered_vars.pop(i)
        return filtered_grads, filtered_vars
754
+
755
+ def _filter_empty_gradients(self, grads, vars):
756
+ filtered_grads = list(grads)
757
+ filtered_vars = list(vars)
758
+ missing_grad_vars = []
759
+
760
+ # Iterate from right to left for safe popping
761
+ for i in range(len(filtered_grads) - 1, -1, -1):
762
+ if filtered_grads[i] is None:
763
+ filtered_grads.pop(i)
764
+ v = filtered_vars.pop(i)
765
+ try:
766
+ missing_grad_vars.append(v.path)
767
+ except AttributeError:
768
+ # `tf.Variable` doesn't have `path` attr.
769
+ missing_grad_vars.append(v.name)
770
+
771
+ if not filtered_grads:
772
+ raise ValueError("No gradients provided for any variable.")
773
+ if missing_grad_vars:
774
+ warnings.warn(
775
+ "Gradients do not exist for variables "
776
+ f"{list(reversed(missing_grad_vars))} when minimizing the loss."
777
+ " If using `model.compile()`, did you forget to provide a "
778
+ "`loss` argument?"
779
+ )
780
+ return filtered_grads, filtered_vars
781
+
782
+ def _clip_gradients(self, grads):
783
+ if self.clipnorm and self.clipnorm > 0:
784
+ return [
785
+ self._clip_by_norm(g) if g is not None else g for g in grads
786
+ ]
787
+ elif self.global_clipnorm and self.global_clipnorm > 0:
788
+ return clip_by_global_norm(grads, self.global_clipnorm)
789
+ elif self.clipvalue and self.clipvalue > 0:
790
+ v = self.clipvalue
791
+ return [ops.clip(g, -v, v) if g is not None else g for g in grads]
792
+ else:
793
+ return grads
794
+
795
+ def exclude_from_weight_decay(self, var_list=None, var_names=None):
796
+ """Exclude variables from weight decay.
797
+
798
+ This method must be called before the optimizer's `build` method is
799
+ called. You can set specific variables to exclude out, or set a list of
800
+ strings as the anchor words, if any of which appear in a variable's
801
+ name, then the variable is excluded.
802
+
803
+ Args:
804
+ var_list: A list of `Variable`s to exclude from weight decay.
805
+ var_names: A list of strings. If any string in `var_names` appear
806
+ in the model variable's name, then this model variable is
807
+ excluded from weight decay. For example, `var_names=['bias']`
808
+ excludes all bias variables from weight decay.
809
+ """
810
+ if hasattr(self, "_built") and self._built:
811
+ raise ValueError(
812
+ "`exclude_from_weight_decay()` can only be configured before "
813
+ "the optimizer is built."
814
+ )
815
+
816
+ # Use a `set` for the ids of `var_list` to speed up the searching
817
+ if var_list:
818
+ self._exclude_from_weight_decay = set(
819
+ self._var_key(variable) for variable in var_list
820
+ )
821
+ else:
822
+ self._exclude_from_weight_decay = set()
823
+
824
+ # Precompile the pattern for `var_names` to speed up the searching
825
+ if var_names and len(var_names) > 0:
826
+ self._exclude_from_weight_decay_pattern = re.compile(
827
+ "|".join(set(var_names))
828
+ )
829
+ else:
830
+ self._exclude_from_weight_decay_pattern = None
831
+
832
+ # Reset cache
833
+ self._exclude_from_weight_decay_cache = dict()
834
+
835
+ def _use_weight_decay(self, variable):
836
+ variable_id = self._var_key(variable)
837
+
838
+ # Immediately return the value if `variable_id` hits the cache
839
+ if not hasattr(self, "_exclude_from_weight_decay_cache"):
840
+ self._exclude_from_weight_decay_cache = dict()
841
+ if variable_id in self._exclude_from_weight_decay_cache:
842
+ return self._exclude_from_weight_decay_cache[variable_id]
843
+
844
+ # Determine whether the variable should apply weight decay or not
845
+ exclude_from_weight_decay = getattr(
846
+ self, "_exclude_from_weight_decay", set()
847
+ )
848
+ exclude_from_weight_decay_pattern = getattr(
849
+ self, "_exclude_from_weight_decay_pattern", None
850
+ )
851
+ if variable_id in exclude_from_weight_decay:
852
+ self._exclude_from_weight_decay_cache[variable_id] = False
853
+ return False
854
+ if exclude_from_weight_decay_pattern is not None:
855
+ if (
856
+ re.search(exclude_from_weight_decay_pattern, variable.name)
857
+ is not None
858
+ ):
859
+ self._exclude_from_weight_decay_cache[variable_id] = False
860
+ return False
861
+ self._exclude_from_weight_decay_cache[variable_id] = True
862
+ return True
863
+
864
+ def _apply_weight_decay(self, variables):
865
+ if self.weight_decay is None:
866
+ return
867
+ for variable in variables:
868
+ if self._use_weight_decay(variable):
869
+ lr = ops.cast(self.learning_rate, variable.dtype)
870
+ wd = ops.cast(self.weight_decay, variable.dtype)
871
+ variable.assign(variable - variable * wd * lr)
872
+
873
+ def _check_super_called(self):
874
+ if not hasattr(self, "_lock"):
875
+ raise RuntimeError(
876
+ f"In optimizer '{self.__class__.__name__}', you forgot to call "
877
+ "`super().__init__()` as the first statement "
878
+ "in the `__init__()` method. "
879
+ "Go add it!"
880
+ )
881
+
882
+ def _update_model_variables_moving_average(self, trainable_variables):
883
+ """Update the stored moving average using the latest value."""
884
+ if self.use_ema:
885
+ for var, average in zip(
886
+ trainable_variables, self._model_variables_moving_average
887
+ ):
888
+ not_first_step = ops.not_equal(self.iterations, 0)
889
+ momentum = (
890
+ ops.cast(not_first_step, var.dtype) * self.ema_momentum
891
+ )
892
+ average.assign(momentum * average + (1 - momentum) * var)
893
+
894
+ def _overwrite_model_variables_with_average_value(
895
+ self, trainable_variables
896
+ ):
897
+ """Overwrite model variables with its moving average."""
898
+ if len(trainable_variables) != len(
899
+ self._model_variables_moving_average
900
+ ):
901
+ raise ValueError(
902
+ f"The length of model variables ({len(trainable_variables)}) "
903
+ "to override does not match the length of model variables "
904
+ "stored in the optimizer "
905
+ f"({len(self._model_variables_moving_average)}). Please "
906
+ "check if the optimizer was called on your model."
907
+ )
908
+ for var, average_var in zip(
909
+ trainable_variables, self._model_variables_moving_average
910
+ ):
911
+ var.assign(average_var)
912
+
913
+ def finalize_variable_values(self, var_list):
914
+ """Set the final value of model's trainable variables.
915
+
916
+ Sometimes there are some extra steps before ending the variable updates,
917
+ such as overriding the model variables with its average value.
918
+
919
+ Args:
920
+ var_list: list of model variables.
921
+ """
922
+ if self.use_ema:
923
+ # If the optimizer uses EMA, then when finalizing, we replace the
924
+ # model variable value with its moving average stored inside
925
+ # optimizer.
926
+ self._overwrite_model_variables_with_average_value(var_list)
927
+
928
+ def _obj_type(self):
929
+ return "Optimizer"
930
+
931
    def get_config(self):
        """Returns the config of the optimizer.

        An optimizer config is a Python dictionary (serializable)
        containing the configuration of an optimizer.
        The same optimizer can be reinstantiated later
        (without any saved state) from this configuration.

        Subclass optimizer should override this method to include other
        hyperparameters.

        Returns:
            Python dictionary.
        """

        # Serialize the learning rate according to its kind; the schedule
        # check must precede the `callable` check since schedules are
        # themselves callable.
        if isinstance(
            self._learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
            learning_rate = learning_rate_schedule.serialize(
                self._learning_rate
            )
        elif isinstance(self._learning_rate, backend.Variable):
            learning_rate = float(self._learning_rate.numpy())
        elif ops.is_tensor(self._learning_rate):
            learning_rate = float(self._learning_rate)
        elif callable(self._learning_rate):
            learning_rate = serialization_lib.serialize_keras_object(
                self._learning_rate
            )
        else:
            # Fallback placeholder when the learning rate cannot be
            # serialized — NOTE(review): 0.5 looks arbitrary; confirm it is
            # the intended sentinel value.
            learning_rate = 0.5

        config = {
            "name": self.name,
            "learning_rate": learning_rate,
            "weight_decay": self.weight_decay,
            "clipnorm": self.clipnorm,
            "global_clipnorm": self.global_clipnorm,
            "clipvalue": self.clipvalue,
            "use_ema": self.use_ema,
            "ema_momentum": self.ema_momentum,
            "ema_overwrite_frequency": self.ema_overwrite_frequency,
            "loss_scale_factor": self.loss_scale_factor,
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
        }
        return config
977
+
978
+ @classmethod
979
+ def from_config(cls, config, custom_objects=None):
980
+ """Creates an optimizer from its config.
981
+
982
+ This method is the reverse of `get_config`, capable of instantiating the
983
+ same optimizer from the config dictionary.
984
+
985
+ Args:
986
+ config: A Python dictionary, typically the output of get_config.
987
+ custom_objects: A Python dictionary mapping names to additional
988
+ user-defined Python objects needed to recreate this optimizer.
989
+
990
+ Returns:
991
+ An optimizer instance.
992
+ """
993
+ if "learning_rate" in config:
994
+ if isinstance(config["learning_rate"], dict):
995
+ config["learning_rate"] = (
996
+ serialization_lib.deserialize_keras_object(
997
+ config["learning_rate"], custom_objects=custom_objects
998
+ )
999
+ )
1000
+ return cls(**config)
1001
+
1002
    def __setattr__(self, name, value):
        """Attribute setter that routes values through the state tracker."""
        # Prevent users from attaching state to the
        # layer before `super()` is called -- since that
        # state would silently not be tracked.
        if name != "_lock":
            self._check_super_called()
        # Track Variables.
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)
1012
+
1013
    def _clip_by_norm(self, values, axes=None):
        """Clip `values` so its L2 norm (over `axes`) is at most `clipnorm`."""
        # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
        l2sum = ops.sum(ops.square(values), axes, keepdims=True)
        pred = l2sum > 0
        # Two-tap tf.where trick to bypass NaN gradients: sqrt is only taken
        # where l2sum is strictly positive, using a safe stand-in elsewhere.
        l2sum_safe = ops.where(pred, l2sum, ops.ones_like(l2sum))
        l2norm = ops.where(pred, ops.sqrt(l2sum_safe), l2sum)
        # values * clipnorm / max(l2norm, clipnorm): identity when the norm
        # is already within bounds, scaled down otherwise.
        intermediate = ops.multiply(values, self.clipnorm)
        values_clip = ops.convert_to_tensor(intermediate) / ops.maximum(
            l2norm, self.clipnorm
        )
        return values_clip
1025
+
1026
+ def _untrack_variable(self, variable):
1027
+ previous_lock_state = self._tracker.locked
1028
+ self._tracker.unlock()
1029
+ self._tracker.untrack(variable)
1030
+ if previous_lock_state is True:
1031
+ self._tracker.lock()
1032
+
1033
+
1034
# Shared keyword-argument documentation, spliced into each optimizer
# subclass docstring via the `{{base_optimizer_keyword_args}}` placeholder.
base_optimizer_keyword_args = """name: String. The name to use
        for momentum accumulator weights created by
        the optimizer.
    weight_decay: Float. If set, weight decay is applied.
    clipnorm: Float. If set, the gradient of each weight is individually
        clipped so that its norm is no higher than this value.
    clipvalue: Float. If set, the gradient of each weight is clipped to be
        no higher than this value.
    global_clipnorm: Float. If set, the gradient of all weights is clipped
        so that their global norm is no higher than this value.
    use_ema: Boolean, defaults to `False`.
        If `True`, exponential moving average
        (EMA) is applied. EMA consists of computing an exponential moving
        average of the weights of the model (as the weight values change
        after each training batch), and periodically overwriting the
        weights with their moving average.
    ema_momentum: Float, defaults to 0.99. Only used if `use_ema=True`.
        This is the momentum to use when computing
        the EMA of the model's weights:
        `new_average = ema_momentum * old_average + (1 - ema_momentum) *
        current_variable_value`.
    ema_overwrite_frequency: Int or None, defaults to None. Only used if
        `use_ema=True`. Every `ema_overwrite_frequency` steps of iterations,
        we overwrite the model variable by its moving average.
        If None, the optimizer
        does not overwrite model variables in the middle of training,
        and you need to explicitly overwrite the variables
        at the end of training by calling
        `optimizer.finalize_variable_values()` (which updates the model
        variables in-place). When using the built-in `fit()` training loop,
        this happens automatically after the last epoch,
        and you don't need to do anything.
    loss_scale_factor: Float or `None`. If a float, the scale factor will
        be multiplied the loss before computing gradients, and the inverse
        of the scale factor will be multiplied by the gradients before
        updating variables. Useful for preventing underflow during
        mixed precision training. Alternately,
        `keras.optimizers.LossScaleOptimizer` will
        automatically set a loss scale factor.
    gradient_accumulation_steps: Int or `None`. If an int, model & optimizer
        variables will not be updated at every step; instead they will be
        updated every `gradient_accumulation_steps` steps, using the average
        value of the gradients since the last update. This is known as
        "gradient accumulation". This can be useful
        when your batch size is very small, in order to reduce gradient
        noise at each update step. EMA frequency will look at "accumulated"
        iterations value (optimizer steps // gradient_accumulation_steps).
        Learning rate schedules will look at "real" iterations value
        (optimizer steps).
"""
1084
+
1085
+
1086
def global_norm(value_list):
    """Computes the global norm of multiple tensors."""
    # sqrt(sum_i ||v_i||^2), skipping entries with no gradient.
    per_tensor = [
        ops.sum(ops.square(value))
        for value in value_list
        if value is not None
    ]
    return ops.sqrt(ops.sum(ops.stack(per_tensor)))
1093
+
1094
+
1095
def clip_by_global_norm(value_list, clip_norm):
    """Scale all tensors in `value_list` so their global norm <= `clip_norm`."""
    use_norm = global_norm(value_list)
    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
    scale_for_finite = clip_norm * ops.minimum(1.0 / use_norm, 1.0 / clip_norm)
    # If use_norm is any finite number, this is a no-op. For inf/-inf/NaN,
    # this will make scale NaN.
    scale = scale_for_finite + (use_norm - use_norm)
    # `None` gradients pass through untouched.
    return [v * scale if v is not None else v for v in value_list]
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/ftrl.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import initializers
2
+ from keras.src import ops
3
+ from keras.src.api_export import keras_export
4
+ from keras.src.optimizers import optimizer
5
+
6
+
7
@keras_export(["keras.optimizers.Ftrl"])
class Ftrl(optimizer.Optimizer):
    r"""Optimizer that implements the FTRL algorithm.

    "Follow The Regularized Leader" (FTRL) is an optimization algorithm
    developed at Google for click-through rate prediction in the early 2010s. It
    is most suitable for shallow models with large and sparse feature spaces.
    The algorithm is described by
    [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).
    The Keras version has support for both online L2 regularization
    (the L2 regularization described in the paper
    above) and shrinkage-type L2 regularization
    (which is the addition of an L2 penalty to the loss function).

    Initialization:

    ```python
    n = 0
    sigma = 0
    z = 0
    ```

    Update rule for one variable `w`:

    ```python
    prev_n = n
    n = n + g ** 2
    sigma = (n ** -lr_power - prev_n ** -lr_power) / lr
    z = z + g - sigma * w
    if abs(z) < lambda_1:
      w = 0
    else:
      w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / alpha + lambda_2)
    ```

    Notation:

    - `lr` is the learning rate
    - `g` is the gradient for the variable
    - `lambda_1` is the L1 regularization strength
    - `lambda_2` is the L2 regularization strength
    - `lr_power` is the power to scale n.

    Check the documentation for the `l2_shrinkage_regularization_strength`
    parameter for more details when shrinkage is enabled, in which case gradient
    is replaced with a gradient with shrinkage.

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        learning_rate_power: A float value, must be less or equal to zero.
            Controls how the learning rate decreases during training. Use zero
            for a fixed learning rate.
        initial_accumulator_value: The starting value for accumulators. Only
            zero or positive values are allowed.
        l1_regularization_strength: A float value, must be greater than or equal
            to zero. Defaults to `0.0`.
        l2_regularization_strength: A float value, must be greater than or equal
            to zero. Defaults to `0.0`.
        l2_shrinkage_regularization_strength: A float value, must be greater
            than or equal to zero. This differs from L2 above in that the L2
            above is a stabilization penalty, whereas this L2 shrinkage is a
            magnitude penalty. When input is sparse shrinkage will only happen
            on the active weights.
        beta: A float value, representing the beta value from the paper.
            Defaults to `0.0`.
        {{base_optimizer_keyword_args}}
    """

    def __init__(
        self,
        learning_rate=0.001,
        learning_rate_power=-0.5,
        initial_accumulator_value=0.1,
        l1_regularization_strength=0.0,
        l2_regularization_strength=0.0,
        l2_shrinkage_regularization_strength=0.0,
        beta=0.0,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="ftrl",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )

        # Validate hyperparameters eagerly so misconfiguration fails at
        # construction time rather than mid-training.
        if initial_accumulator_value < 0.0:
            raise ValueError(
                "`initial_accumulator_value` needs to be positive or zero. "
                "Received: initial_accumulator_value="
                f"{initial_accumulator_value}."
            )
        if learning_rate_power > 0.0:
            raise ValueError(
                "`learning_rate_power` needs to be negative or zero. Received: "
                f"learning_rate_power={learning_rate_power}."
            )
        if l1_regularization_strength < 0.0:
            raise ValueError(
                "`l1_regularization_strength` needs to be positive or zero. "
                "Received: l1_regularization_strength="
                f"{l1_regularization_strength}."
            )
        if l2_regularization_strength < 0.0:
            raise ValueError(
                "`l2_regularization_strength` needs to be positive or zero. "
                "Received: l2_regularization_strength="
                f"{l2_regularization_strength}."
            )
        if l2_shrinkage_regularization_strength < 0.0:
            raise ValueError(
                "`l2_shrinkage_regularization_strength` needs to be positive "
                "or zero. Received: l2_shrinkage_regularization_strength"
                f"={l2_shrinkage_regularization_strength}."
            )

        self.learning_rate_power = learning_rate_power
        self.initial_accumulator_value = initial_accumulator_value
        self.l1_regularization_strength = l1_regularization_strength
        self.l2_regularization_strength = l2_regularization_strength
        self.l2_shrinkage_regularization_strength = (
            l2_shrinkage_regularization_strength
        )
        self.beta = beta

    def build(self, var_list):
        """Initialize optimizer variables.

        Args:
            var_list: list of model variables to build Ftrl variables on.
        """
        if self.built:
            return
        super().build(var_list)
        self._accumulators = []
        self._linears = []
        for var in var_list:
            # `n` accumulator: starts at `initial_accumulator_value`.
            self._accumulators.append(
                self.add_variable(
                    shape=var.shape,
                    dtype=var.dtype,
                    name="accumulator",
                    initializer=initializers.Constant(
                        self.initial_accumulator_value,
                    ),
                )
            )
            # `z` ("linear") accumulator: starts at zero.
            self._linears.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="linear"
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""

        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)

        accum = self._accumulators[self._get_variable_index(variable)]
        linear = self._linears[self._get_variable_index(variable)]

        lr_power = self.learning_rate_power
        # Fold `beta` into the effective L2 strength: l2 + beta / (2 * lr).
        l2_reg = self.l2_regularization_strength
        l2_reg = l2_reg + self.beta / (2.0 * lr)

        # Shrinkage: gradient is augmented with 2 * l2_shrinkage * w.
        grad_to_use = ops.add(
            gradient,
            ops.multiply(
                2 * self.l2_shrinkage_regularization_strength, variable
            ),
        )
        # n_t = n_{t-1} + g^2 (uses the raw gradient, not the shrunk one).
        new_accum = ops.add(accum, ops.square(gradient))
        # z += g_shrunk - sigma * w, with
        # sigma = (n_t^{-lr_power} - n_{t-1}^{-lr_power}) / lr.
        self.assign_add(
            linear,
            ops.subtract(
                grad_to_use,
                ops.multiply(
                    ops.divide(
                        ops.subtract(
                            ops.power(new_accum, -lr_power),
                            ops.power(accum, -lr_power),
                        ),
                        lr,
                    ),
                    variable,
                ),
            ),
        )
        # w = (clip(z, ±lambda_1) - z) / (n_t^{-lr_power} / lr + 2 * l2_reg)
        quadratic = ops.add(
            ops.divide(ops.power(new_accum, (-lr_power)), lr), 2 * l2_reg
        )
        linear_clipped = ops.clip(
            linear,
            -self.l1_regularization_strength,
            self.l1_regularization_strength,
        )
        self.assign(
            variable,
            ops.divide(ops.subtract(linear_clipped, linear), quadratic),
        )
        self.assign(accum, new_accum)

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate_power": self.learning_rate_power,
                "initial_accumulator_value": self.initial_accumulator_value,
                "l1_regularization_strength": self.l1_regularization_strength,
                "l2_regularization_strength": self.l2_regularization_strength,
                "l2_shrinkage_regularization_strength": self.l2_shrinkage_regularization_strength,  # noqa: E501
                "beta": self.beta,
            }
        )
        return config


Ftrl.__doc__ = Ftrl.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/lamb.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import ops
2
+ from keras.src.api_export import keras_export
3
+ from keras.src.optimizers import optimizer
4
+
5
+
6
@keras_export("keras.optimizers.Lamb")
class Lamb(optimizer.Optimizer):
    """Optimizer that implements the Lamb algorithm.

    Lamb is a stochastic gradient descent method that
    uses layer-wise adaptive moments to adjusts the
    learning rate for each parameter based on the ratio of the
    norm of the weight to the norm of the gradient
    This helps to stabilize the training process and improves convergence
    especially for large batch sizes.

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimates. Defaults to
            `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 2nd moment estimates. Defaults to
            `0.999`.
        epsilon: A small constant for numerical stability.
            Defaults to `1e-7`.
        {{base_optimizer_keyword_args}}

    References:
        - [Yang et al.](https://arxiv.org/pdf/1904.00962)
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="lamb",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon

    def build(self, var_list):
        """Initialize optimizer variables.

        Lamb optimizer has 2 types of variables: momentums and velocities

        Args:
            var_list: list of model variables to build Lamb variables on.
        """
        if self.built:
            return
        super().build(var_list)
        self._momentums = []
        self._velocities = []
        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="momentum"
                )
            )
            self._velocities.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="velocity"
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        local_step = ops.cast(self.iterations + 1, variable.dtype)

        # Bias-correction terms beta^t, as in Adam.
        beta_1_power = ops.power(
            ops.cast(self.beta_1, variable.dtype), local_step
        )
        beta_2_power = ops.power(
            ops.cast(self.beta_2, variable.dtype), local_step
        )

        m = self._momentums[self._get_variable_index(variable)]
        v = self._velocities[self._get_variable_index(variable)]

        # m += (g - m) * (1 - beta_1)  — EMA of the gradient.
        self.assign_add(
            m, ops.multiply(ops.subtract(gradient, m), 1 - self.beta_1)
        )

        # v += (g^2 - v) * (1 - beta_2)  — EMA of the squared gradient.
        self.assign_add(
            v,
            ops.multiply(
                ops.subtract(ops.square(gradient), v), 1 - self.beta_2
            ),
        )

        # Bias-corrected estimates.
        m_t_hat = ops.divide(m, (1.0 - beta_1_power))
        v_sqrt = ops.add(
            ops.sqrt(ops.divide(v, (1.0 - beta_2_power))), self.epsilon
        )

        update = ops.divide(m_t_hat, v_sqrt)
        # Layer-wise trust ratio: ||w|| / ||update||.
        w_norm = ops.sqrt(ops.sum(ops.power(variable, 2)))
        g_norm = ops.sqrt(ops.sum(ops.power(update, 2)))

        # ratio = w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1
        ratio = ops.where(
            ops.greater(w_norm, 0),
            ops.where(ops.greater(g_norm, 0), (w_norm / g_norm), 1.0),
            1.0,
        )

        self.assign_sub(variable, ratio * lr * update)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
                "epsilon": self.epsilon,
            }
        )
        return config


Lamb.__doc__ = Lamb.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/lion.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import ops
2
+ from keras.src.api_export import keras_export
3
+ from keras.src.optimizers import optimizer
4
+
5
+
6
@keras_export(["keras.optimizers.Lion"])
class Lion(optimizer.Optimizer):
    """Optimizer that implements the Lion algorithm.

    The Lion optimizer is a stochastic-gradient-descent method that uses the
    sign operator to control the magnitude of the update, unlike other adaptive
    optimizers such as Adam that rely on second-order moments. This makes
    Lion more memory-efficient as it only keeps track of the momentum. According
    to the authors (see reference), its performance gain over Adam grows with
    the batch size. Because the update of Lion is produced through the sign
    operation, resulting in a larger norm, a suitable learning rate for Lion is
    typically 3-10x smaller than that for AdamW. The weight decay for Lion
    should be in turn 3-10x larger than that for AdamW to maintain a
    similar strength (lr * wd).

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            rate to combine the current gradient and the 1st moment estimate.
            Defaults to `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimate. Defaults to
            `0.99`.
        {{base_optimizer_keyword_args}}

    References:

    - [Chen et al., 2023](http://arxiv.org/abs/2302.06675)
    - [Authors' implementation](
        http://github.com/google/automl/tree/master/lion)

    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.99,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="lion",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        # beta_1 == 0 would make the update direction sign(gradient), i.e.
        # plain SignSGD, hence the exclusive lower bound in the check.
        # NOTE(review): the message says "[0, 1]" while the check enforces
        # (0, 1] -- consider aligning the message with the condition.
        if beta_1 <= 0 or beta_1 > 1:
            raise ValueError(
                "Argument `beta_1` must be in the [0, 1] range. Otherwise, the "
                f"optimizer degenerates to SignSGD. Received: beta_1={beta_1}."
            )

    def build(self, var_list):
        """Initialize optimizer variables.

        Lion optimizer has one variable `momentums`.

        Args:
            var_list: list of model variables to build Lion variables on.
        """
        if self.built:
            return
        super().build(var_list)
        # One momentum slot per model variable, indexed via
        # `_get_variable_index` in `update_step`.
        self._momentums = []
        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="momentum"
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        beta_1 = ops.cast(self.beta_1, variable.dtype)
        beta_2 = ops.cast(self.beta_2, variable.dtype)
        m = self._momentums[self._get_variable_index(variable)]

        # Apply the weight update first: the step direction is the sign of
        # the beta_1-interpolation between the *previous* momentum and the
        # current gradient, so `variable` must be updated before `m` is
        # overwritten below.
        self.assign_sub(
            variable,
            ops.multiply(
                lr,
                ops.sign(
                    ops.add(
                        ops.multiply(m, beta_1),
                        ops.multiply(gradient, (1.0 - beta_1)),
                    )
                ),
            ),
        )
        # Then refresh the momentum as a beta_2 EMA of the gradient.
        self.assign(
            m,
            ops.add(
                ops.multiply(m, beta_2), ops.multiply(gradient, (1.0 - beta_2))
            ),
        )

    def get_config(self):
        """Return the serialization config, including Lion hyperparameters."""
        config = super().get_config()
        config.update(
            {
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
            }
        )
        return config
139
+
140
# Splice the shared base-optimizer keyword-argument docs into the class
# docstring at the `{{base_optimizer_keyword_args}}` placeholder.
Lion.__doc__ = Lion.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/loss_scale_optimizer.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import backend
2
+ from keras.src import initializers
3
+ from keras.src import ops
4
+ from keras.src.api_export import keras_export
5
+ from keras.src.optimizers import optimizer
6
+ from keras.src.saving import serialization_lib
7
+ from keras.src.utils import tracking
8
+
9
+
10
@keras_export(
    [
        "keras.optimizers.LossScaleOptimizer",
        "keras.mixed_precision.LossScaleOptimizer",
    ]
)
class LossScaleOptimizer(optimizer.Optimizer):
    """An optimizer that dynamically scales the loss to prevent underflow.

    Loss scaling is a technique to prevent numeric underflow in intermediate
    gradients when float16 is used. To prevent underflow, the loss is multiplied
    (or "scaled") by a certain factor called the "loss scale", which causes
    intermediate gradients to be scaled by the loss scale as well. The final
    gradients are divided (or "unscaled") by the loss scale to bring them back
    to their original value.

    `LossScaleOptimizer` wraps another optimizer and applies dynamic loss
    scaling to it. This loss scale is dynamically updated over time as follows:
    - On any train step, if a nonfinite gradient is encountered, the loss scale
      is halved, and the train step is skipped.
    - If `dynamic_growth_steps` have occurred since the last time the loss scale
      was updated, and no nonfinite gradients have occurred, the loss scale
      is doubled.

    Args:
        inner_optimizer: The `keras.optimizers.Optimizer` instance to wrap.
        initial_scale: Float. The initial loss scale. This scale will be updated
            during training. It is recommended for this to be a very high
            number, because a loss scale that is too high gets lowered far more
            quickly than a loss scale that is too low gets raised.
        dynamic_growth_steps: Int. How often to update the scale upwards. After
            every `dynamic_growth_steps` steps with finite gradients, the
            loss scale is doubled.
        {{base_optimizer_keyword_args}}
    """

    def __init__(
        self,
        inner_optimizer,
        initial_scale=2.0**15,
        dynamic_growth_steps=2000,
        **kwargs,
    ):
        # Static loss scaling (`dynamic=False`) is no longer supported here;
        # the equivalent is setting `loss_scale_factor` on the inner optimizer.
        if not kwargs.pop("dynamic", True):
            raise ValueError(
                "LossScaleOptimizer no longer supports `dynamic=False`. "
                "Instead, simply set `loss_scale_factor` directly on the "
                "`inner_optimizer`."
            )
        # The wrapper has no learning rate of its own; `learning_rate` below
        # proxies the inner optimizer's, and the config drops this dummy value.
        super().__init__(learning_rate=0.0, **kwargs)
        self.inner_optimizer = inner_optimizer
        self.initial_scale = initial_scale
        self.dynamic_growth_steps = dynamic_growth_steps

    @tracking.no_automatic_dependency_tracking
    def build(self, var_list):
        # Consecutive finite-gradient steps since the last scale change;
        # reset to 0 whenever the scale is adjusted in either direction.
        self.step_counter = self.add_variable(
            shape=(),
            dtype="int",
            initializer=initializers.Zeros(),
            aggregation="none",
            name="step_counter",
        )
        # Current loss scale: halved on non-finite gradients, doubled after
        # `dynamic_growth_steps` clean steps.
        self.dynamic_scale = self.add_variable(
            shape=(),
            dtype="float32",
            initializer=initializers.Constant(self.initial_scale),
            aggregation="none",
            name="dynamic_scale",
        )
        self.inner_optimizer.build(var_list)
        self.built = True

    @property
    def variables(self):
        # Own bookkeeping variables first, then the wrapped optimizer's.
        return self._variables + self.inner_optimizer.variables

    def stateless_apply(self, optimizer_variables, grads, trainable_variables):
        if not self.built:
            raise ValueError(
                f"To call `stateless_apply`, {self.__class__.__name__} "
                "must be built (i.e. its variables must have been created). "
                "You can build it via `optimizer.build(trainable_variables)`."
            )
        finite = self.check_finite(grads)
        # Branch functionally: finite grads apply an (unscaled) update, while
        # non-finite grads skip the step and halve the scale.
        return ops.cond(
            finite,
            lambda: self._stateless_handle_finite_grads(
                optimizer_variables, grads, trainable_variables
            ),
            lambda: self._stateless_handle_non_finite_grads(
                optimizer_variables, trainable_variables
            ),
        )

    def _stateless_handle_finite_grads(
        self, optimizer_variables, grads, trainable_variables
    ):
        def upscale():
            # Growth interval completed: double the scale, reset the counter.
            mapping = list(zip(self.variables, optimizer_variables))
            with backend.StatelessScope(state_mapping=mapping) as scope:
                self.step_counter.assign(0)
                self.dynamic_scale.assign(self.dynamic_scale * 2.0)
            return [scope.get_current_value(v) for v in self._variables]

        def increment():
            mapping = list(zip(self.variables, optimizer_variables))
            with backend.StatelessScope(state_mapping=mapping) as scope:
                self.step_counter.assign_add(1)
            return [scope.get_current_value(v) for v in self._variables]

        mapping = list(zip(self.variables, optimizer_variables))
        with backend.StatelessScope(state_mapping=mapping):
            # Potentially upscale loss and reset counter.
            own_variables = ops.cond(
                ops.equal(self.step_counter, self.dynamic_growth_steps - 1),
                upscale,
                increment,
            )

            # Unscale gradients.
            scale = self.dynamic_scale
            unscaled_grads = [
                g if g is None else ops.divide(g, scale) for g in grads
            ]
            (
                new_trainable_variables,
                new_inner_variables,
            ) = self.inner_optimizer.stateless_apply(
                self.inner_optimizer.variables,
                unscaled_grads,
                trainable_variables,
            )

        # Concatenation order matches the `variables` property above.
        new_optimizer_variables = own_variables + new_inner_variables
        return new_trainable_variables, new_optimizer_variables

    def _stateless_handle_non_finite_grads(
        self, optimizer_variables, trainable_variables
    ):
        # Skip the update entirely: halve the scale, reset the counter, and
        # return the trainable variables unchanged.
        mapping = list(zip(self.variables, optimizer_variables))
        with backend.StatelessScope(state_mapping=mapping) as scope:
            self.step_counter.assign(0)
            self.dynamic_scale.assign(self.dynamic_scale / 2.0)
        new_optimizer_variables = []
        for v in self.variables:
            new_optimizer_variables.append(scope.get_current_value(v))
        return trainable_variables, new_optimizer_variables

    def apply(self, grads, trainable_variables=None):
        # Optionally build optimizer.
        if not self.built:
            with backend.name_scope(self.name, caller=self):
                self.build(trainable_variables)
            self.built = True

        # TensorFlow needs distribution-aware handling (merge_call); all
        # other backends share the common path.
        if backend.backend() == "tensorflow":
            self._tf_apply(grads, trainable_variables)
        else:
            self._common_apply(grads, trainable_variables)

    def _stateful_handle_finite_grads(self, grads, trainable_variables):
        scale = self.dynamic_scale
        # Unscale gradients.
        unscaled_grads = [
            g if g is None else ops.divide(g, scale) for g in grads
        ]
        self.inner_optimizer.apply(
            unscaled_grads, trainable_variables=trainable_variables
        )

        def upscale():
            self.step_counter.assign(0)
            self.dynamic_scale.assign(self.dynamic_scale * 2.0)

        def increment():
            self.step_counter.assign_add(1)

        # Potentially upscale loss and reset counter.
        ops.cond(
            ops.equal(self.step_counter, self.dynamic_growth_steps - 1),
            upscale,
            increment,
        )

    def _stateful_handle_non_finite_grads(self):
        # If any inf or nan in grads, downscale loss and reset counter.
        self.step_counter.assign(0)
        self.dynamic_scale.assign(self.dynamic_scale / 2.0)

    def _common_apply(self, grads, trainable_variables=None):
        finite = self.check_finite(grads)
        ops.cond(
            finite,
            lambda: self._stateful_handle_finite_grads(
                grads, trainable_variables
            ),
            self._stateful_handle_non_finite_grads,
        )

    def _tf_apply(self, grads, trainable_variables=None):
        """Tensorflow specific logic for apply, which handles distribution."""
        from keras.src.utils.module_utils import tensorflow as tf

        if tf.distribute.in_cross_replica_context():
            raise ValueError("apply() must be called in a replica context.")

        if tf.__internal__.distribute.strategy_supports_no_merge_call():
            self._common_apply(grads, trainable_variables=trainable_variables)
        else:

            def _handle_cross_replica(distribution, grads, trainable_variables):
                finite_per_replica = (
                    distribution.extended.call_for_each_replica(
                        self.check_finite, args=(grads,)
                    )
                )
                # Each replica computed the same `finite` value, since
                # `grads` is all-reduced across replicas. Arbitrarily take
                # `finite` from the first replica.
                finite = distribution.experimental_local_results(
                    finite_per_replica
                )[0]

                def apply_fn():
                    distribution.extended.call_for_each_replica(
                        self._stateful_handle_finite_grads,
                        args=(grads, trainable_variables),
                    )

                # Note: We must call this cond() in a cross-replica context.
                # DistributionStrategy does not support having a cond in a
                # replica context with a branch that calls `merge_call`, and
                # self._optimizer.apply_gradients calls `merge_call`.
                ops.cond(
                    finite, apply_fn, self._stateful_handle_non_finite_grads
                )

            tf.distribute.get_replica_context().merge_call(
                _handle_cross_replica, args=(grads, trainable_variables)
            )

    def check_finite(self, grads):
        """Return a scalar boolean tensor: True iff all gradients are finite.

        `None` gradients are ignored (treated as finite).
        """
        tensor_grads = [g for g in grads if g is not None]
        finite_grads = [ops.all(ops.isfinite(g)) for g in tensor_grads]
        return ops.all(ops.convert_to_tensor(finite_grads))

    @property
    def learning_rate(self):
        # Proxy to the wrapped optimizer; the wrapper's own lr is a dummy.
        return self.inner_optimizer.learning_rate

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self.inner_optimizer.learning_rate = learning_rate

    def scale_loss(self, loss):
        # Before `build`, the scale variable does not exist yet; fall back to
        # the configured initial scale.
        scale = self.dynamic_scale if self.built else self.initial_scale
        return loss * scale

    def finalize_variable_values(self, var_list):
        self.inner_optimizer.finalize_variable_values(var_list)

    def get_config(self):
        config = super().get_config()
        inner_optimizer_config = serialization_lib.serialize_keras_object(
            self.inner_optimizer
        )
        config.update(
            {
                "inner_optimizer": inner_optimizer_config,
                "initial_scale": self.initial_scale,
                "dynamic_growth_steps": self.dynamic_growth_steps,
            }
        )
        # The dummy learning rate passed to the base class is not part of the
        # public config.
        del config["learning_rate"]
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        inner_optimizer = serialization_lib.deserialize_keras_object(
            config.pop("inner_optimizer"),
            custom_objects=custom_objects,
        )
        return cls(inner_optimizer, **config)
+
295
+
296
# Splice the shared base-optimizer keyword-argument docs into the class
# docstring at the `{{base_optimizer_keyword_args}}` placeholder.
LossScaleOptimizer.__doc__ = LossScaleOptimizer.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/nadam.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import backend
2
+ from keras.src import ops
3
+ from keras.src.api_export import keras_export
4
+ from keras.src.optimizers import optimizer
5
+
6
+
7
@keras_export(["keras.optimizers.Nadam"])
class Nadam(optimizer.Optimizer):
    """Optimizer that implements the Nadam algorithm.

    Much like Adam is essentially RMSprop with momentum, Nadam is Adam with
    Nesterov momentum.

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimates.
            Defaults to `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 2nd moment estimates. Defaults to
            `0.999`.
        epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just before
            Section 2.1), not the epsilon in Algorithm 1 of the paper.
            Defaults to `1e-7`.
        {{base_optimizer_keyword_args}}

    Reference:

    - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).

    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="nadam",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon

    def build(self, var_list):
        """Initialize optimizer variables.

        Nadam optimizer has 2 types of variables: momentums and velocities.

        Args:
            var_list: list of model variables to build Nadam variables on.
        """
        if self.built:
            return
        # The running momentum-schedule product is kept in the dtype of the
        # first model variable (floatx if there are none).
        if var_list:
            dtype = var_list[0].dtype
        else:
            dtype = backend.floatx()
        super().build(var_list)
        self._momentums = []
        self._velocities = []
        # Running product of the per-step momentum coefficients u_t; updated
        # once per step in `_backend_update_step`.
        self._u_product = backend.Variable(1.0, dtype=dtype)

        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="momentum"
                )
            )
            self._velocities.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="velocity"
                )
            )

    def _backend_update_step(self, grads, trainable_variables, learning_rate):
        # Fold this step's momentum-schedule coefficient
        # u_t = beta_1 * (1 - 0.5 * 0.96^t) into the running product exactly
        # once per optimizer step, before any per-variable `update_step`.
        dtype = self._u_product.dtype
        self.assign(
            self._u_product,
            self._u_product
            * self.beta_1
            * (
                1.0
                - 0.5 * ops.power(0.96, ops.cast(self.iterations + 1, dtype))
            ),
        )
        super()._backend_update_step(grads, trainable_variables, learning_rate)

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        var_dtype = variable.dtype
        lr = ops.cast(learning_rate, var_dtype)
        gradient = ops.cast(gradient, var_dtype)

        local_step = ops.cast(self.iterations + 1, var_dtype)
        next_step = ops.cast(self.iterations + 2, var_dtype)
        decay = ops.cast(0.96, var_dtype)
        beta_1 = ops.cast(self.beta_1, var_dtype)
        beta_2 = ops.cast(self.beta_2, var_dtype)
        # Momentum-schedule coefficients for the current and the next step
        # (the Nesterov look-ahead uses u_{t+1}).
        u_t = beta_1 * (1.0 - 0.5 * (ops.power(decay, local_step)))
        u_t_1 = beta_1 * (1.0 - 0.5 * (ops.power(decay, next_step)))
        u_product_t = ops.cast(self._u_product, var_dtype)

        u_product_t_1 = u_product_t * u_t_1
        beta_2_power = ops.power(beta_2, local_step)

        m = self._momentums[self._get_variable_index(variable)]
        v = self._velocities[self._get_variable_index(variable)]

        # m += (g - m) * (1 - beta_1): EMA of the gradient.
        self.assign_add(
            m, ops.multiply(ops.subtract(gradient, m), (1 - beta_1))
        )
        # v += (g^2 - v) * (1 - beta_2): EMA of the squared gradient.
        self.assign_add(
            v, ops.multiply(ops.subtract(ops.square(gradient), v), (1 - beta_2))
        )
        # Nesterov-corrected first moment: mixes the look-ahead momentum term
        # with the current gradient, each bias-corrected by the schedule
        # products.
        m_hat = ops.add(
            ops.divide(ops.multiply(u_t_1, m), 1 - u_product_t_1),
            ops.divide(ops.multiply(1 - u_t, gradient), 1 - u_product_t),
        )
        v_hat = ops.divide(v, (1 - beta_2_power))

        self.assign_sub(
            variable,
            ops.divide(
                ops.multiply(m_hat, lr), ops.add(ops.sqrt(v_hat), self.epsilon)
            ),
        )

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
                "epsilon": self.epsilon,
            }
        )
        return config
170
+
171
+
172
# Splice the shared base-optimizer keyword-argument docs into the class
# docstring at the `{{base_optimizer_keyword_args}}` placeholder.
Nadam.__doc__ = Nadam.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/optimizer.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import backend
2
+ from keras.src.api_export import keras_export
3
+ from keras.src.optimizers import base_optimizer
4
+
5
# Select the backend-specific optimizer base class at import time, so that
# the public `Optimizer` inherits backend-native apply/variable handling.
if backend.backend() == "tensorflow":
    from keras.src.backend.tensorflow.optimizer import (
        TFOptimizer as BackendOptimizer,
    )
elif backend.backend() == "torch":
    from keras.src.backend.torch.optimizers import (
        TorchOptimizer as BackendOptimizer,
    )
elif backend.backend() == "jax":
    from keras.src.backend.jax.optimizer import JaxOptimizer as BackendOptimizer
else:
    # Backends without a specialized implementation fall back to the pure
    # base optimizer behavior.
    class BackendOptimizer(base_optimizer.BaseOptimizer):
        pass


@keras_export(["keras.Optimizer", "keras.optimizers.Optimizer"])
class Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer):
    # MRO: backend-specific overrides first, then the shared base class.
    pass


# Re-export the base class docstring and keyword-arg docs so subclasses and
# docs tooling can reference them from this module.
Optimizer.__doc__ = base_optimizer.BaseOptimizer.__doc__
base_optimizer_keyword_args = base_optimizer.base_optimizer_keyword_args
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/rmsprop.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import ops
2
+ from keras.src.api_export import keras_export
3
+ from keras.src.optimizers import optimizer
4
+
5
+
6
@keras_export(["keras.optimizers.RMSprop"])
class RMSprop(optimizer.Optimizer):
    """Optimizer that implements the RMSprop algorithm.

    The gist of RMSprop is to:

    - Maintain a moving (discounted) average of the square of gradients
    - Divide the gradient by the root of this average

    This implementation of RMSprop uses plain momentum, not Nesterov momentum.

    The centered version additionally maintains a moving average of the
    gradients, and uses that average to estimate the variance.

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        rho: float, defaults to 0.9. Discounting factor for the old gradients.
        momentum: float, defaults to 0.0. If not 0.0., the optimizer tracks the
            momentum value, with a decay rate equals to `1 - momentum`.
        epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just before
            Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults
            to 1e-7.
        centered: Boolean. If `True`, gradients are normalized by the estimated
            variance of the gradient; if False, by the uncentered second moment.
            Setting this to `True` may help with training, but is slightly more
            expensive in terms of computation and memory. Defaults to `False`.
        {{base_optimizer_keyword_args}}

    Example:

    >>> opt = keras.optimizers.RMSprop(learning_rate=0.1)
    >>> var1 = keras.backend.Variable(10.0)
    >>> loss = lambda: (var1 ** 2) / 2.0  # d(loss) / d(var1) = var1
    >>> opt.minimize(loss, [var1])
    >>> var1
    9.683772

    Reference:

    - [Hinton, 2012](
        http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
    """

    def __init__(
        self,
        learning_rate=0.001,
        rho=0.9,
        momentum=0.0,
        epsilon=1e-7,
        centered=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="rmsprop",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            name=name,
            **kwargs,
        )
        self.rho = rho
        self.momentum = momentum
        self.epsilon = epsilon
        self.centered = centered

    def build(self, var_list):
        """Initialize optimizer variables.

        Always creates one velocity slot per variable; momentum slots are
        only created when `momentum > 0`, and average-gradient slots only
        when `centered=True`.

        Args:
            var_list: list of model variables to build RMSprop variables on.
        """
        if self.built:
            return

        super().build(var_list)

        self._velocities = []
        for var in var_list:
            self._velocities.append(
                self.add_variable_from_reference(var, "velocity")
            )

        self._momentums = []
        if self.momentum > 0:
            for var in var_list:
                self._momentums.append(
                    self.add_variable_from_reference(var, "momentum")
                )

        self._average_gradients = []
        if self.centered:
            for var in var_list:
                self._average_gradients.append(
                    self.add_variable_from_reference(var, "average_gradient")
                )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)

        velocity = self._velocities[self._get_variable_index(variable)]
        momentum = None
        if self.momentum > 0:
            momentum = self._momentums[self._get_variable_index(variable)]
        average_grad = None
        if self.centered:
            average_grad = self._average_gradients[
                self._get_variable_index(variable)
            ]

        rho = self.rho

        # velocity = rho * velocity + (1 - rho) * g^2  (EMA of squared grads).
        self.assign(
            velocity,
            ops.add(
                ops.multiply(rho, velocity),
                ops.multiply(1 - rho, ops.square(gradient)),
            ),
        )
        if self.centered:
            # Also track an EMA of the raw gradient and normalize by the
            # estimated variance: E[g^2] - (E[g])^2.
            self.assign(
                average_grad,
                ops.add(
                    ops.multiply(rho, average_grad),
                    ops.multiply(1 - rho, gradient),
                ),
            )
            denominator = velocity - ops.square(average_grad) + self.epsilon
        else:
            denominator = ops.add(velocity, self.epsilon)
        increment = ops.divide(
            ops.multiply(lr, gradient), ops.sqrt(denominator)
        )
        if self.momentum > 0:
            # Heavy-ball momentum: accumulate the increment, then apply the
            # accumulated step.
            self.assign(
                momentum,
                ops.add(ops.multiply(self.momentum, momentum), increment),
            )
            self.assign_sub(variable, momentum)
        else:
            self.assign_sub(variable, increment)

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "rho": self.rho,
                "momentum": self.momentum,
                "epsilon": self.epsilon,
                "centered": self.centered,
            }
        )
        return config
+ return config
176
+
177
+
178
+ RMSprop.__doc__ = RMSprop.__doc__.replace(
179
+ "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
180
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src.optimizers.schedules.learning_rate_schedule import CosineDecay
2
+ from keras.src.optimizers.schedules.learning_rate_schedule import (
3
+ CosineDecayRestarts,
4
+ )
5
+ from keras.src.optimizers.schedules.learning_rate_schedule import (
6
+ ExponentialDecay,
7
+ )
8
+ from keras.src.optimizers.schedules.learning_rate_schedule import (
9
+ InverseTimeDecay,
10
+ )
11
+ from keras.src.optimizers.schedules.learning_rate_schedule import (
12
+ PiecewiseConstantDecay,
13
+ )
14
+ from keras.src.optimizers.schedules.learning_rate_schedule import (
15
+ PolynomialDecay,
16
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (498 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__pycache__/learning_rate_schedule.cpython-310.pyc ADDED
Binary file (30.4 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/learning_rate_schedule.py ADDED
@@ -0,0 +1,969 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Various learning rate schedule functions."""
2
+
3
+ import math
4
+
5
+ from keras.src import ops
6
+ from keras.src.api_export import keras_export
7
+ from keras.src.saving import serialization_lib
8
+
9
+
10
@keras_export("keras.optimizers.schedules.LearningRateSchedule")
class LearningRateSchedule:
    """The learning rate schedule base class.

    A learning rate schedule modulates how the learning rate of an
    optimizer changes over time. An instance can be passed as the
    `learning_rate` argument of any optimizer.

    Several built-in schedules are available, such as
    `keras.optimizers.schedules.ExponentialDecay` or
    `keras.optimizers.schedules.PiecewiseConstantDecay`:

    ```python
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-2,
        decay_steps=10000,
        decay_rate=0.9)
    optimizer = keras.optimizers.SGD(learning_rate=lr_schedule)
    ```

    To implement your own schedule, override the `__call__` method, which
    receives a `step` argument (a scalar integer tensor holding the current
    training step count). As with any Keras object, you may also implement
    `get_config` / `from_config` to make the schedule serializable.

    Example:

    ```python
    class MyLRSchedule(keras.optimizers.schedules.LearningRateSchedule):

        def __init__(self, initial_learning_rate):
            self.initial_learning_rate = initial_learning_rate

        def __call__(self, step):
            return self.initial_learning_rate / (step + 1)

    optimizer = keras.optimizers.SGD(learning_rate=MyLRSchedule(0.1))
    ```
    """

    def __call__(self, step):
        # Subclasses must map a step count to a learning rate value.
        raise NotImplementedError(
            f"Learning rate schedule '{self.__class__.__name__}' "
            "must override `__call__(self, step)`."
        )

    def get_config(self):
        # Subclasses must expose their constructor arguments to be
        # serializable.
        raise NotImplementedError(
            f"Learning rate schedule '{self.__class__.__name__}' "
            "must override `get_config()` in order to be serializable."
        )

    @classmethod
    def from_config(cls, config):
        """Instantiates a `LearningRateSchedule` from its config.

        Args:
            config: Output of `get_config()`.

        Returns:
            A `LearningRateSchedule` instance.
        """
        return cls(**config)
77
+
78
+
79
@keras_export("keras.optimizers.schedules.ExponentialDecay")
class ExponentialDecay(LearningRateSchedule):
    """A `LearningRateSchedule` that uses an exponential decay schedule.

    When training a model, it is often useful to lower the learning rate as
    the training progresses. This schedule applies an exponential decay
    function to an optimizer step, given a provided initial learning rate.

    The schedule is a 1-arg callable producing a decayed learning rate when
    passed the current optimizer step, computed as:

    ```python
    def decayed_learning_rate(step):
        return initial_learning_rate * decay_rate ^ (step / decay_steps)
    ```

    If `staircase` is `True`, `step / decay_steps` is an integer division
    and the decayed learning rate follows a staircase function.

    You can pass this schedule directly into a `keras.optimizers.Optimizer`
    as the learning rate.
    Example: When fitting a Keras model, decay every 100000 steps with a
    base of 0.96:

    ```python
    initial_learning_rate = 0.1
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=100000,
        decay_rate=0.96,
        staircase=True)

    model.compile(optimizer=keras.optimizers.SGD(learning_rate=lr_schedule),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(data, labels, epochs=5)
    ```

    The schedule is serializable and deserializable via
    `keras.optimizers.schedules.serialize` and
    `keras.optimizers.schedules.deserialize`.

    Args:
        initial_learning_rate: A Python float. The initial learning rate.
        decay_steps: A Python integer. Must be positive. See the decay
            computation above.
        decay_rate: A Python float. The decay rate.
        staircase: Boolean. If `True` decay the learning rate at discrete
            intervals.
        name: String. Optional name of the operation. Defaults to
            `"ExponentialDecay"`.

    Returns:
        A 1-arg callable learning rate schedule that takes the current
        optimizer step and outputs the decayed learning rate, a scalar
        tensor of the same type as `initial_learning_rate`.
    """

    def __init__(
        self,
        initial_learning_rate,
        decay_steps,
        decay_rate,
        staircase=False,
        name="ExponentialDecay",
    ):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase
        self.name = name

        # A non-positive horizon would make the exponent ill-defined.
        if self.decay_steps <= 0:
            raise ValueError(
                "Argument `decay_steps` must be > 0. "
                f"Received: decay_steps={self.decay_steps}"
            )

    def __call__(self, step):
        with ops.name_scope(self.name):
            base_lr = ops.convert_to_tensor(self.initial_learning_rate)
            dtype = base_lr.dtype
            steps = ops.cast(self.decay_steps, dtype)
            rate = ops.cast(self.decay_rate, dtype)

            # Fraction of the decay horizon covered so far; floored to an
            # integer count when staircase mode is on.
            exponent = ops.cast(step, dtype) / steps
            if self.staircase:
                exponent = ops.floor(exponent)
            return ops.multiply(base_lr, ops.power(rate, exponent))

    def get_config(self):
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_steps": self.decay_steps,
            "decay_rate": self.decay_rate,
            "staircase": self.staircase,
            "name": self.name,
        }
185
+
186
+
187
@keras_export("keras.optimizers.schedules.PiecewiseConstantDecay")
class PiecewiseConstantDecay(LearningRateSchedule):
    """A `LearningRateSchedule` that uses a piecewise constant decay schedule.

    The function returns a 1-arg callable computing the piecewise constant
    value for the current optimizer step.

    Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
    for the next 10000 steps, and 0.1 for any additional steps.

    ```python
    step = ops.array(0)
    boundaries = [100000, 110000]
    values = [1.0, 0.5, 0.1]
    learning_rate_fn = keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries, values)

    # Later, whenever we perform an optimization step, we pass in the step.
    learning_rate = learning_rate_fn(step)
    ```

    You can pass this schedule directly into a `keras.optimizers.Optimizer`
    as the learning rate. The schedule is serializable and deserializable
    via `keras.optimizers.schedules.serialize` and
    `keras.optimizers.schedules.deserialize`.

    Args:
        boundaries: A list of Python numbers with strictly increasing
            entries, and with all elements having the same type as the
            optimizer step.
        values: A list of Python numbers that specifies the values for the
            intervals defined by `boundaries`. It should have one more
            element than `boundaries`, and all elements should have the
            same type.
        name: A string. Optional name of the operation. Defaults to
            `"PiecewiseConstant"`.

    Returns:
        A 1-arg callable learning rate schedule that takes the current
        optimizer step and outputs the decayed learning rate, a scalar
        tensor of the same type as the boundary tensors.

        The output is `values[0]` when `step <= boundaries[0]`,
        `values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`,
        ..., and `values[-1]` when `step > boundaries[-1]`.

    Raises:
        ValueError: if the number of elements in the `boundaries` and
            `values` lists do not match.
    """

    def __init__(self, boundaries, values, name="PiecewiseConstant"):
        super().__init__()

        # N boundaries delimit N + 1 intervals, so N + 1 values are needed.
        if len(boundaries) != len(values) - 1:
            raise ValueError(
                "The length of boundaries should be 1 less than the length of "
                f"values. Received: boundaries={boundaries} of length "
                f"{len(boundaries)}, and values={values} "
                f"of length {len(values)}."
            )

        self.boundaries = boundaries
        self.values = values
        self.name = name

    def __call__(self, step):
        with ops.name_scope(self.name):
            bounds = [ops.convert_to_tensor(b) for b in self.boundaries]
            interval_values = [ops.convert_to_tensor(v) for v in self.values]
            step = ops.convert_to_tensor(step)

            # Align boundary dtypes with the step so comparisons are valid.
            bounds = [ops.cast(b, step.dtype) for b in bounds]

            out_dtype = interval_values[0].dtype
            total = ops.array(0, dtype=out_dtype)

            # For each interval, a boolean mask (cast to a number) selects
            # whether the step falls inside it; multiplying by the interval
            # value and summing yields the piecewise constant result.
            total += (
                ops.cast(step <= bounds[0], out_dtype) * interval_values[0]
            )
            total += (
                ops.cast(step > bounds[-1], out_dtype) * interval_values[-1]
            )
            for low, high, val in zip(
                bounds[:-1], bounds[1:], interval_values[1:-1]
            ):
                in_range = ops.cast((step > low) & (step <= high), out_dtype)
                total += in_range * val

            return total

    def get_config(self):
        return {
            "boundaries": self.boundaries,
            "values": self.values,
            "name": self.name,
        }
302
+
303
+
304
@keras_export("keras.optimizers.schedules.PolynomialDecay")
class PolynomialDecay(LearningRateSchedule):
    """A `LearningRateSchedule` that uses a polynomial decay schedule.

    It is commonly observed that a monotonically decreasing learning rate,
    whose degree of change is carefully chosen, results in a better
    performing model. This schedule applies a polynomial decay function to
    an optimizer step, given a provided `initial_learning_rate`, to reach
    an `end_learning_rate` in the given `decay_steps`.

    The schedule is a 1-arg callable producing a decayed learning rate when
    passed the current optimizer step, computed as:

    ```python
    def decayed_learning_rate(step):
        step = min(step, decay_steps)
        return ((initial_learning_rate - end_learning_rate) *
                (1 - step / decay_steps) ^ (power)
               ) + end_learning_rate
    ```

    If `cycle` is True then a multiple of `decay_steps` is used, the first
    one that is bigger than `step`.

    ```python
    def decayed_learning_rate(step):
        decay_steps = decay_steps * ceil(step / decay_steps)
        return ((initial_learning_rate - end_learning_rate) *
                (1 - step / decay_steps) ^ (power)
               ) + end_learning_rate
    ```

    You can pass this schedule directly into a `keras.optimizers.Optimizer`
    as the learning rate.
    Example: Fit a model while decaying from 0.1 to 0.01 in 10000 steps
    using sqrt (i.e. power=0.5):

    ```python
    ...
    starter_learning_rate = 0.1
    end_learning_rate = 0.01
    decay_steps = 10000
    learning_rate_fn = keras.optimizers.schedules.PolynomialDecay(
        starter_learning_rate,
        decay_steps,
        end_learning_rate,
        power=0.5)

    model.compile(optimizer=keras.optimizers.SGD(
                      learning_rate=learning_rate_fn),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(data, labels, epochs=5)
    ```

    The schedule is serializable and deserializable via
    `keras.optimizers.schedules.serialize` and
    `keras.optimizers.schedules.deserialize`.

    Args:
        initial_learning_rate: A Python float. The initial learning rate.
        decay_steps: A Python integer. Must be positive. See the decay
            computation above.
        end_learning_rate: A Python float. The minimal end learning rate.
        power: A Python float. The power of the polynomial. Defaults to
            `1.0`.
        cycle: A boolean, whether it should cycle beyond decay_steps.
        name: String. Optional name of the operation. Defaults to
            `"PolynomialDecay"`.

    Returns:
        A 1-arg callable learning rate schedule that takes the current
        optimizer step and outputs the decayed learning rate, a scalar
        tensor of the same type as `initial_learning_rate`.
    """

    def __init__(
        self,
        initial_learning_rate,
        decay_steps,
        end_learning_rate=0.0001,
        power=1.0,
        cycle=False,
        name="PolynomialDecay",
    ):
        super().__init__()

        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.cycle = cycle
        self.name = name

        # A non-positive horizon would make the decay fraction ill-defined.
        if self.decay_steps <= 0:
            raise ValueError(
                "Argument `decay_steps` must be > 0. "
                f"Received: decay_steps={self.decay_steps}"
            )

    def __call__(self, step):
        with ops.name_scope(self.name):
            start_lr = ops.convert_to_tensor(self.initial_learning_rate)
            dtype = start_lr.dtype
            end_lr = ops.cast(self.end_learning_rate, dtype)
            poly_power = ops.cast(self.power, dtype)

            current_step = ops.cast(step, dtype)
            horizon = ops.cast(self.decay_steps, dtype)
            if self.cycle:
                # Stretch the horizon to the first multiple of decay_steps
                # that covers the current step. A step of exactly zero maps
                # to a multiplier of 1 to avoid a zero-length horizon.
                multiplier = ops.where(
                    ops.equal(current_step, 0),
                    1.0,
                    ops.ceil(current_step / self.decay_steps),
                )
                horizon = ops.multiply(horizon, multiplier)
            else:
                # Clamp the step so the decayed rate never moves past
                # end_learning_rate.
                current_step = ops.minimum(current_step, horizon)

            fraction = ops.divide(current_step, horizon)
            decay_term = ops.power(1 - fraction, poly_power)
            return ops.add(
                ops.multiply(start_lr - end_lr, decay_term),
                end_lr,
            )

    def get_config(self):
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_steps": self.decay_steps,
            "end_learning_rate": self.end_learning_rate,
            "power": self.power,
            "cycle": self.cycle,
            "name": self.name,
        }
458
+
459
+
460
@keras_export("keras.optimizers.schedules.InverseTimeDecay")
class InverseTimeDecay(LearningRateSchedule):
    """A `LearningRateSchedule` that uses an inverse time decay schedule.

    When training a model, it is often useful to lower the learning rate as
    the training progresses. This schedule applies the inverse decay
    function to an optimizer step, given a provided initial learning rate.

    The schedule is a 1-arg callable producing a decayed learning rate when
    passed the current optimizer step, computed as:

    ```python
    def decayed_learning_rate(step):
        return initial_learning_rate / (1 + decay_rate * step / decay_step)
    ```

    or, if `staircase` is `True`, as:

    ```python
    def decayed_learning_rate(step):
        return initial_learning_rate /
               (1 + decay_rate * floor(step / decay_step))
    ```

    You can pass this schedule directly into a `keras.optimizers.Optimizer`
    as the learning rate.
    Example: Fit a Keras model when decaying 1/t with a rate of 0.5:

    ```python
    ...
    initial_learning_rate = 0.1
    decay_steps = 1.0
    decay_rate = 0.5
    learning_rate_fn = keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate, decay_steps, decay_rate)

    model.compile(optimizer=keras.optimizers.SGD(
                      learning_rate=learning_rate_fn),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(data, labels, epochs=5)
    ```

    Args:
        initial_learning_rate: A Python float. The initial learning rate.
        decay_steps: How often to apply decay.
        decay_rate: A Python number. The decay rate.
        staircase: Whether to apply decay in a discrete staircase, as
            opposed to continuous, fashion.
        name: String. Optional name of the operation. Defaults to
            `"InverseTimeDecay"`.

    Returns:
        A 1-arg callable learning rate schedule that takes the current
        optimizer step and outputs the decayed learning rate, a scalar
        tensor of the same type as `initial_learning_rate`.
    """

    def __init__(
        self,
        initial_learning_rate,
        decay_steps,
        decay_rate,
        staircase=False,
        name="InverseTimeDecay",
    ):
        super().__init__()

        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase
        self.name = name

        # A non-positive horizon would make the decay fraction ill-defined.
        if self.decay_steps <= 0:
            raise ValueError(
                "Argument `decay_steps` must be > 0. "
                f"Received: decay_steps={self.decay_steps}"
            )

    def __call__(self, step):
        with ops.name_scope(self.name):
            base_lr = ops.convert_to_tensor(self.initial_learning_rate)
            dtype = base_lr.dtype
            steps = ops.cast(self.decay_steps, dtype)
            rate = ops.cast(self.decay_rate, dtype)

            # Progress through the decay horizon; floored in staircase mode.
            progress = ops.cast(step, dtype) / steps
            if self.staircase:
                progress = ops.floor(progress)
            one = ops.cast(ops.array(1), dtype)
            denominator = ops.add(one, ops.multiply(rate, progress))
            return ops.divide(base_lr, denominator)

    def get_config(self):
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_steps": self.decay_steps,
            "decay_rate": self.decay_rate,
            "staircase": self.staircase,
            "name": self.name,
        }
570
+
571
+
572
@keras_export("keras.optimizers.schedules.CosineDecay")
class CosineDecay(LearningRateSchedule):
    """A `LearningRateSchedule` that uses a cosine decay with optional warmup.

    See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),
    SGDR: Stochastic Gradient Descent with Warm Restarts, and
    [Goyal et al.](https://arxiv.org/pdf/1706.02677.pdf) for the idea of a
    linear warmup of the learning rate.

    If `warmup_target` is set, this schedule applies a linear increase per
    optimizer step from `initial_learning_rate` to `warmup_target` over
    `warmup_steps` steps, then a cosine decay from `warmup_target` towards
    `alpha` over `decay_steps` steps. If `warmup_target` is `None`, warmup
    is skipped and the decay starts from `initial_learning_rate`.

    The warmup is computed as:

    ```python
    def warmup_learning_rate(step):
        completed_fraction = step / warmup_steps
        total_delta = target_warmup - initial_learning_rate
        return completed_fraction * total_delta
    ```

    And the decay is computed as:

    ```python
    if warmup_target is None:
        initial_decay_lr = initial_learning_rate
    else:
        initial_decay_lr = warmup_target

    def decayed_learning_rate(step):
        step = min(step, decay_steps)
        cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps))
        decayed = (1 - alpha) * cosine_decay + alpha
        return initial_decay_lr * decayed
    ```

    Example usage without warmup:

    ```python
    decay_steps = 1000
    initial_learning_rate = 0.1
    lr_decayed_fn = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate, decay_steps)
    ```

    Example usage with warmup:

    ```python
    decay_steps = 1000
    initial_learning_rate = 0
    warmup_steps = 1000
    target_learning_rate = 0.1
    lr_warmup_decayed_fn = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate, decay_steps,
        warmup_target=target_learning_rate,
        warmup_steps=warmup_steps
    )
    ```

    You can pass this schedule directly into a `keras.optimizers.Optimizer`
    as the learning rate. The schedule is serializable and deserializable
    via `keras.optimizers.schedules.serialize` and
    `keras.optimizers.schedules.deserialize`.

    Args:
        initial_learning_rate: A Python float. The initial learning rate.
        decay_steps: A Python int. Number of steps to decay over.
        alpha: A Python float. Minimum learning rate value for decay as a
            fraction of `initial_learning_rate`.
        name: String. Optional name of the operation. Defaults to
            `"CosineDecay"`.
        warmup_target: A Python float. The target learning rate for our
            warmup phase. Will cast to the `initial_learning_rate`
            datatype. Setting to `None` will skip warmup and begins decay
            phase from `initial_learning_rate`. Otherwise scheduler will
            warmup from `initial_learning_rate` to `warmup_target`.
        warmup_steps: A Python int. Number of steps to warmup over.

    Returns:
        A 1-arg callable learning rate schedule that takes the current
        optimizer step and outputs the decayed learning rate, a scalar
        tensor of the same type as `initial_learning_rate`.
    """

    def __init__(
        self,
        initial_learning_rate,
        decay_steps,
        alpha=0.0,
        name="CosineDecay",
        warmup_target=None,
        warmup_steps=0,
    ):
        super().__init__()

        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.alpha = alpha
        self.name = name
        self.warmup_steps = warmup_steps
        self.warmup_target = warmup_target

        # A non-positive horizon would make the cosine argument ill-defined.
        if self.decay_steps <= 0:
            raise ValueError(
                "Argument `decay_steps` must be > 0. "
                f"Received: decay_steps={self.decay_steps}"
            )

    def _decay_function(self, step, decay_steps, decay_from_lr, dtype):
        # Cosine interpolation from `decay_from_lr` down to
        # `alpha * decay_from_lr` over `decay_steps` steps.
        with ops.name_scope(self.name):
            progress = step / decay_steps
            pi = ops.array(math.pi, dtype=dtype)
            cosine_factor = 0.5 * (1.0 + ops.cos(pi * progress))
            scale = (1 - self.alpha) * cosine_factor + self.alpha
            return ops.multiply(decay_from_lr, scale)

    def _warmup_function(
        self, step, warmup_steps, warmup_target, initial_learning_rate
    ):
        # Linear ramp from `initial_learning_rate` to `warmup_target` over
        # `warmup_steps` steps.
        with ops.name_scope(self.name):
            progress = step / warmup_steps
            lr_delta = warmup_target - initial_learning_rate
            return lr_delta * progress + initial_learning_rate

    def __call__(self, step):
        with ops.name_scope(self.name):
            init_lr = ops.convert_to_tensor(self.initial_learning_rate)
            dtype = init_lr.dtype
            decay_steps = ops.cast(self.decay_steps, dtype)
            current = ops.cast(step, dtype)

            if self.warmup_target is None:
                # No warmup phase: clamp to the decay window and decay
                # directly from the initial learning rate.
                current = ops.minimum(current, decay_steps)
                return self._decay_function(
                    current, decay_steps, init_lr, dtype
                )

            warmup_target = ops.cast(self.warmup_target, dtype)
            warmup_steps = ops.cast(self.warmup_steps, dtype)

            # Clamp to the end of the combined warmup + decay window.
            current = ops.minimum(current, decay_steps + warmup_steps)

            # Ramp up while inside the warmup window; afterwards decay from
            # the warmup target, with the step shifted past the warmup.
            return ops.cond(
                current < warmup_steps,
                lambda: self._warmup_function(
                    current, warmup_steps, warmup_target, init_lr
                ),
                lambda: self._decay_function(
                    current - warmup_steps, decay_steps, warmup_target, dtype
                ),
            )

    def get_config(self):
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_steps": self.decay_steps,
            "alpha": self.alpha,
            "name": self.name,
            "warmup_target": self.warmup_target,
            "warmup_steps": self.warmup_steps,
        }
761
+
762
+
763
@keras_export("keras.optimizers.schedules.CosineDecayRestarts")
class CosineDecayRestarts(LearningRateSchedule):
    """A `LearningRateSchedule` that uses a cosine decay schedule with restarts.

    See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),
    SGDR: Stochastic Gradient Descent with Warm Restarts.

    When training a model, it is often useful to lower the learning rate as
    the training progresses. This schedule applies a cosine decay function with
    restarts to an optimizer step, given a provided initial learning rate.
    It requires a `step` value to compute the decayed learning rate. You can
    just pass a backend variable that you increment at each training step.

    The schedule is a 1-arg callable that produces a decayed learning
    rate when passed the current optimizer step. This can be useful for changing
    the learning rate value across different invocations of optimizer functions.

    The learning rate multiplier first decays
    from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
    restart is performed. Each new warm restart runs for `t_mul` times more
    steps and with `m_mul` times initial learning rate as the new learning rate.

    Example:
    ```python
    first_decay_steps = 1000
    lr_decayed_fn = (
        keras.optimizers.schedules.CosineDecayRestarts(
            initial_learning_rate,
            first_decay_steps))
    ```

    You can pass this schedule directly into a `keras.optimizers.Optimizer`
    as the learning rate. The learning rate schedule is also serializable and
    deserializable using `keras.optimizers.schedules.serialize` and
    `keras.optimizers.schedules.deserialize`.

    Args:
        initial_learning_rate: A Python float. The initial learning rate.
        first_decay_steps: A Python integer. Number of steps to decay over.
        t_mul: A Python float. Used to derive the number of iterations in
            the i-th period.
        m_mul: A Python float. Used to derive the initial learning rate of
            the i-th period.
        alpha: A Python float. Minimum learning rate value as a fraction of
            the `initial_learning_rate`.
        name: String. Optional name of the operation. Defaults to
            `"SGDRDecay"`.

    Returns:
        A 1-arg callable learning rate schedule that takes the current optimizer
        step and outputs the decayed learning rate, a scalar tensor of the
        same type as `initial_learning_rate`.
    """

    def __init__(
        self,
        initial_learning_rate,
        first_decay_steps,
        t_mul=2.0,
        m_mul=1.0,
        alpha=0.0,
        name="SGDRDecay",
    ):
        super().__init__()

        self.initial_learning_rate = initial_learning_rate
        # Length of the first decay period; later periods are scaled by
        # powers of `t_mul`.
        self.first_decay_steps = first_decay_steps
        # Period-length multiplier between consecutive restarts.
        self._t_mul = t_mul
        # Learning-rate multiplier applied at each restart.
        self._m_mul = m_mul
        self.alpha = alpha
        self.name = name

        # A non-positive first period would make the decay fraction
        # ill-defined.
        if self.first_decay_steps <= 0:
            raise ValueError(
                "Argument `first_decay_steps` must be > 0. "
                f"Received: first_decay_steps={self.first_decay_steps}"
            )

    def __call__(self, step):
        with ops.name_scope(self.name):
            initial_learning_rate = ops.convert_to_tensor(
                self.initial_learning_rate
            )
            dtype = initial_learning_rate.dtype
            first_decay_steps = ops.cast(self.first_decay_steps, dtype)
            alpha = ops.cast(self.alpha, dtype)
            t_mul = ops.cast(self._t_mul, dtype)
            m_mul = ops.cast(self._m_mul, dtype)

            # Progress measured in units of the first decay period.
            global_step_recomp = ops.cast(step, dtype)
            completed_fraction = global_step_recomp / first_decay_steps

            def compute_step(completed_fraction, geometric=False):
                """Helper for `cond` operation.

                Returns the index of the current restart period
                (`i_restart`) and the fraction completed within it. The
                geometric branch handles `t_mul != 1`, where period lengths
                form a geometric series; the other branch handles equal
                period lengths (`t_mul == 1`).
                """
                if geometric:
                    # Invert the geometric series sum to find which period
                    # the current fraction falls in.
                    # ops.log is sensitive to the precision of dtype, so we need
                    # the additional casting
                    i_restart = ops.floor(
                        ops.log(
                            ops.cast(
                                1.0 - completed_fraction * (1.0 - t_mul), dtype
                            )
                        )
                        / ops.log(t_mul)
                    )

                    # Total length (in first-period units) of all periods
                    # before the current one.
                    sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
                    # Rescale the remainder to a 0..1 fraction within the
                    # current period (which is t_mul**i_restart long).
                    completed_fraction = (
                        completed_fraction - sum_r
                    ) / t_mul**i_restart

                else:
                    # Equal-length periods: integer and fractional parts.
                    i_restart = ops.floor(completed_fraction)
                    completed_fraction -= i_restart

                return i_restart, completed_fraction

            # The closed form above divides by log(t_mul) and (1 - t_mul),
            # so t_mul == 1 must take the arithmetic branch.
            i_restart, completed_fraction = ops.cond(
                ops.equal(t_mul, 1.0),
                lambda: compute_step(completed_fraction, geometric=False),
                lambda: compute_step(completed_fraction, geometric=True),
            )

            # Each restart shrinks (or grows) the peak rate by m_mul.
            m_fac = m_mul**i_restart
            cosine_decayed = (
                0.5
                * m_fac
                * (
                    1.0
                    + ops.cos(
                        ops.array(math.pi, dtype=dtype) * completed_fraction
                    )
                )
            )
            # Floor the multiplier at `alpha`.
            decayed = (1 - alpha) * cosine_decayed + alpha

            return ops.multiply(initial_learning_rate, decayed)

    def get_config(self):
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "first_decay_steps": self.first_decay_steps,
            "t_mul": self._t_mul,
            "m_mul": self._m_mul,
            "alpha": self.alpha,
            "name": self.name,
        }
910
+
911
+
912
@keras_export("keras.optimizers.schedules.serialize")
def serialize(learning_rate_schedule):
    """Serializes a `LearningRateSchedule` into a JSON-compatible dict.

    Args:
        learning_rate_schedule: The `LearningRateSchedule` object to
            serialize.

    Returns:
        A JSON-serializable dict representing the object's config.

    Example:

    >>> lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    ...     0.1, decay_steps=100000, decay_rate=0.96, staircase=True)
    >>> keras.optimizers.schedules.serialize(lr_schedule)
    {'module': 'keras.optimizers.schedules',
    'class_name': 'ExponentialDecay', 'config': {...},
    'registered_name': None}
    """
    # Delegate to the generic Keras object serializer.
    return serialization_lib.serialize_keras_object(learning_rate_schedule)
932
+
933
+
934
@keras_export("keras.optimizers.schedules.deserialize")
def deserialize(config, custom_objects=None):
    """Instantiates a `LearningRateSchedule` object from a serialized form.

    Args:
        config: The serialized form of the `LearningRateSchedule`.
            Dictionary of the form {'class_name': str, 'config': dict}.
        custom_objects: A dictionary mapping class names (or function
            names) of custom (non-Keras) objects to class/functions.

    Returns:
        A `LearningRateSchedule` object.

    Example:

    ```python
    # Configuration for PolynomialDecay
    config = {
        'class_name': 'PolynomialDecay',
        'config': {'cycle': False,
            'decay_steps': 10000,
            'end_learning_rate': 0.01,
            'initial_learning_rate': 0.1,
            'name': None,
            'power': 0.5
        }
    }
    lr_schedule = keras.optimizers.schedules.deserialize(config)
    ```
    """
    # `globals()` exposes every schedule class defined in this module so
    # the deserializer can resolve built-in class names.
    return serialization_lib.deserialize_keras_object(
        config,
        module_objects=globals(),
        custom_objects=custom_objects,
        printable_module_name="decay",
    )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/sgd.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import ops
2
+ from keras.src.api_export import keras_export
3
+ from keras.src.optimizers import optimizer
4
+
5
+
6
@keras_export("keras.optimizers.SGD")
class SGD(optimizer.Optimizer):
    """Gradient descent (with momentum) optimizer.

    Update rule for parameter `w` with gradient `g` when `momentum` is 0:

    ```python
    w = w - learning_rate * g
    ```

    Update rule when `momentum` is larger than 0:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + velocity
    ```

    When `nesterov=True`, this rule becomes:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + momentum * velocity - learning_rate * g
    ```

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.01`.
        momentum: float hyperparameter >= 0 that accelerates gradient descent in
            the relevant direction and dampens oscillations. 0 is vanilla
            gradient descent. Defaults to `0.0`.
        nesterov: boolean. Whether to apply Nesterov momentum.
            Defaults to `False`.
        {{base_optimizer_keyword_args}}
    """

    def __init__(
        self,
        learning_rate=0.01,
        momentum=0.0,
        nesterov=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="SGD",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        # Coerce before validating: a strict `isinstance(momentum, float)`
        # check would reject valid numeric values such as `momentum=0`
        # (an int) or a NumPy scalar. Anything convertible to float is
        # accepted, then range-checked.
        try:
            momentum = float(momentum)
        except (TypeError, ValueError):
            raise ValueError("`momentum` must be a float between [0, 1].")
        if momentum < 0 or momentum > 1:
            raise ValueError("`momentum` must be a float between [0, 1].")
        self.momentum = momentum
        self.nesterov = nesterov

    def build(self, variables):
        """Initialize optimizer variables.

        SGD optimizer has one variable `momentums`, only set if `self.momentum`
        is not 0.

        Args:
            variables: list of model variables to build SGD variables on.
        """
        if self.built:
            return
        super().build(variables)
        # One velocity slot per model variable; skipped entirely for
        # vanilla (momentum == 0) gradient descent.
        self.momentums = []
        if self.momentum != 0:
            for variable in variables:
                self.momentums.append(
                    self.add_variable_from_reference(
                        reference_variable=variable, name="momentum"
                    )
                )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        learning_rate = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        m = None
        if self.momentum != 0:
            m = self.momentums[self._get_variable_index(variable)]

        if m is not None:
            momentum = ops.cast(self.momentum, variable.dtype)
            # The scaled gradient is used both to update the velocity and,
            # for Nesterov, in the look-ahead variable update.
            scaled_gradient = ops.multiply(gradient, learning_rate)
            # velocity = momentum * velocity - learning_rate * g
            self.assign(
                m,
                ops.subtract(ops.multiply(m, momentum), scaled_gradient),
            )
            if self.nesterov:
                # w = w + momentum * velocity - learning_rate * g
                # (uses the freshly updated velocity `m`)
                self.assign_add(
                    variable,
                    ops.subtract(ops.multiply(m, momentum), scaled_gradient),
                )
            else:
                # w = w + velocity
                self.assign_add(variable, m)
        else:
            # Vanilla SGD: w = w - learning_rate * g
            self.assign_sub(variable, ops.multiply(gradient, learning_rate))

    def get_config(self):
        """Return the optimizer config, including SGD-specific fields."""
        config = super().get_config()
        config.update(
            {
                "momentum": self.momentum,
                "nesterov": self.nesterov,
            }
        )
        return config
139
+
140
+
141
# Splice the base-optimizer keyword-argument documentation (shared by all
# optimizers) into the class docstring in place of its placeholder.
SGD.__doc__ = SGD.__doc__.replace(
    "{{base_optimizer_keyword_args}}",
    optimizer.base_optimizer_keyword_args,
)