Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__init__.py +16 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__pycache__/operation_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__pycache__/symbolic_arguments.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/core.py +1167 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/function.py +423 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/image.py +1235 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/linalg.py +707 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/math.py +1046 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/nn.py +2653 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/node.py +143 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/numpy.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/operation.py +316 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/operation_utils.py +421 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/symbolic_arguments.py +46 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__init__.py +121 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adadelta.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adafactor.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adagrad.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adam.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adamax.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adamw.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/base_optimizer.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/ftrl.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/lamb.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/lion.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/loss_scale_optimizer.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/nadam.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/optimizer.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/rmsprop.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/sgd.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adadelta.py +139 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adafactor.py +208 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adagrad.py +115 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adam.py +167 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adamax.py +156 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adamw.py +100 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/base_optimizer.py +1102 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/ftrl.py +249 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/lamb.py +158 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/lion.py +142 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/loss_scale_optimizer.py +298 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/nadam.py +174 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/optimizer.py +27 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/rmsprop.py +180 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__init__.py +16 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__pycache__/learning_rate_schedule.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/learning_rate_schedule.py +969 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/sgd.py +143 -0
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from keras.src.ops.numpy import Matmul, matmul
|
| 2 |
+
# from keras.src.ops.numpy import Add, add
|
| 3 |
+
# from keras.src.ops.numpy import Multiply, multiply
|
| 4 |
+
|
| 5 |
+
from keras.src.backend import cast
|
| 6 |
+
from keras.src.backend import cond
|
| 7 |
+
from keras.src.backend import is_tensor
|
| 8 |
+
from keras.src.backend import name_scope
|
| 9 |
+
from keras.src.backend import random
|
| 10 |
+
from keras.src.ops import image
|
| 11 |
+
from keras.src.ops import operation_utils
|
| 12 |
+
from keras.src.ops.core import * # noqa: F403
|
| 13 |
+
from keras.src.ops.linalg import * # noqa: F403
|
| 14 |
+
from keras.src.ops.math import * # noqa: F403
|
| 15 |
+
from keras.src.ops.nn import * # noqa: F403
|
| 16 |
+
from keras.src.ops.numpy import * # noqa: F403
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__pycache__/operation_utils.cpython-310.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/__pycache__/symbolic_arguments.cpython-310.pyc
ADDED
|
Binary file (1.86 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/core.py
ADDED
|
@@ -0,0 +1,1167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ml_dtypes
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from keras.src import backend
|
| 5 |
+
from keras.src import tree
|
| 6 |
+
from keras.src.api_export import keras_export
|
| 7 |
+
from keras.src.backend import KerasTensor
|
| 8 |
+
from keras.src.backend import any_symbolic_tensors
|
| 9 |
+
from keras.src.backend.common.backend_utils import slice_along_axis
|
| 10 |
+
from keras.src.ops.operation import Operation
|
| 11 |
+
from keras.src.utils import traceback_utils
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Map(Operation):
|
| 15 |
+
def __init__(self):
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
def call(self, f, xs):
|
| 19 |
+
return backend.core.map(f, xs)
|
| 20 |
+
|
| 21 |
+
def compute_output_spec(self, f, xs):
|
| 22 |
+
x = xs[0]
|
| 23 |
+
n = xs.shape[0]
|
| 24 |
+
y = backend.compute_output_spec(f, x)
|
| 25 |
+
|
| 26 |
+
def append_batch_axis(x):
|
| 27 |
+
return KerasTensor(
|
| 28 |
+
shape=(n,) + x.shape, dtype=x.dtype, sparse=x.sparse
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
y = tree.map_structure(append_batch_axis, y)
|
| 32 |
+
return y
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@keras_export("keras.ops.map")
|
| 36 |
+
def map(f, xs):
|
| 37 |
+
"""Map a function over leading array axes.
|
| 38 |
+
|
| 39 |
+
Like Python’s builtin map, except inputs and outputs are in the form of
|
| 40 |
+
stacked arrays. Consider using the `vectorized_map()` transform instead,
|
| 41 |
+
unless you need to apply a function element by element for reduced memory
|
| 42 |
+
usage or heterogeneous computation with other control flow primitives.
|
| 43 |
+
|
| 44 |
+
When `xs` is an array type, the semantics of `map()` are given by this
|
| 45 |
+
Python implementation:
|
| 46 |
+
|
| 47 |
+
```python
|
| 48 |
+
def map(f, xs):
|
| 49 |
+
return np.stack([f(x) for x in xs])
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
f: Callable defines the function to apply element-wise over the first
|
| 54 |
+
axis or axes of `xs`.
|
| 55 |
+
xs: Values over which to map along the leading axis.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
Mapped values.
|
| 59 |
+
|
| 60 |
+
Examples:
|
| 61 |
+
|
| 62 |
+
>>> f = lambda x: x**2
|
| 63 |
+
>>> xs = keras.ops.arange(10)
|
| 64 |
+
>>> ys = keras.ops.map(f, xs)
|
| 65 |
+
>>> ys
|
| 66 |
+
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
|
| 67 |
+
|
| 68 |
+
>>> f = lambda x: {"y1": x**2, "y2": x * 10} # Can have nested outputs
|
| 69 |
+
>>> ys = keras.ops.map(f, xs)
|
| 70 |
+
>>> ys["y1"]
|
| 71 |
+
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
|
| 72 |
+
>>> ys["y2"]
|
| 73 |
+
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
|
| 74 |
+
"""
|
| 75 |
+
if any_symbolic_tensors((xs,)):
|
| 76 |
+
return Map().symbolic_call(f, xs)
|
| 77 |
+
return backend.core.map(f, xs)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class Scan(Operation):
|
| 81 |
+
def __init__(self, reverse=False, unroll=1):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.reverse = reverse
|
| 84 |
+
self.unroll = unroll
|
| 85 |
+
|
| 86 |
+
def call(self, f, init, xs, length):
|
| 87 |
+
return backend.core.scan(
|
| 88 |
+
f, init, xs, length, reverse=self.reverse, unroll=self.unroll
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def compute_output_spec(self, f, init, xs, length):
|
| 92 |
+
if xs is None:
|
| 93 |
+
n = int(length)
|
| 94 |
+
x = None
|
| 95 |
+
else:
|
| 96 |
+
n = (
|
| 97 |
+
int(length)
|
| 98 |
+
if length is not None
|
| 99 |
+
else tree.flatten(xs)[0].shape[0]
|
| 100 |
+
)
|
| 101 |
+
x = xs[0]
|
| 102 |
+
|
| 103 |
+
carry, y = backend.compute_output_spec(f, init, x)
|
| 104 |
+
y = KerasTensor(shape=(n,) + y.shape, dtype=y.dtype, sparse=y.sparse)
|
| 105 |
+
return carry, y
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@keras_export("keras.ops.scan")
|
| 109 |
+
def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
|
| 110 |
+
"""Scan a function over leading array axes while carrying along state.
|
| 111 |
+
|
| 112 |
+
When the type of `xs` is an array type or `None`, and the type of `ys` is an
|
| 113 |
+
array type, the semantics of `scan()` are given roughly by this Python
|
| 114 |
+
implementation:
|
| 115 |
+
|
| 116 |
+
```python
|
| 117 |
+
def scan(f, init, xs, length=None):
|
| 118 |
+
if xs is None:
|
| 119 |
+
xs = [None] * length
|
| 120 |
+
carry = init
|
| 121 |
+
ys = []
|
| 122 |
+
for x in xs:
|
| 123 |
+
carry, y = f(carry, x)
|
| 124 |
+
ys.append(y)
|
| 125 |
+
return carry, np.stack(ys)
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
The loop-carried value `carry` (`init`) must hold a fixed shape and dtype
|
| 129 |
+
across all iterations.
|
| 130 |
+
|
| 131 |
+
In TensorFlow, `y` must match `carry` in shape and dtype. This is not
|
| 132 |
+
required in other backends.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
f: Callable defines the logic for each loop iteration. This accepts two
|
| 136 |
+
arguments where the first is a value of the loop carry and the
|
| 137 |
+
second is a slice of `xs` along its leading axis.
|
| 138 |
+
This callable returns a pair where the first represents a new value
|
| 139 |
+
for the loop carry and the second represents a slice of the output.
|
| 140 |
+
init: The initial loop carry value. This can be a scalar, tensor, or any
|
| 141 |
+
nested structure. It must match the structure of the first element
|
| 142 |
+
returned by `f`.
|
| 143 |
+
xs: Optional value to scan along its leading axis. This can be a tensor
|
| 144 |
+
or any nested structure. If `xs` is not provided, you must specify
|
| 145 |
+
`length` to define the number of loop iterations.
|
| 146 |
+
Defaults to `None`.
|
| 147 |
+
length: Optional integer specifying the number of loop iterations.
|
| 148 |
+
If `length` is not provided, it defaults to the sizes of leading
|
| 149 |
+
axis of the arrays in `xs`. Defaults to `None`.
|
| 150 |
+
reverse: Optional boolean specifying whether to run the scan iteration
|
| 151 |
+
forward or in reverse, equivalent to reversing the leading axes of
|
| 152 |
+
the arrays in both `xs` and in `ys`.
|
| 153 |
+
unroll: Optional positive integer or boolean specifying how many scan
|
| 154 |
+
iterations to unroll within a single iteration of a loop. If an
|
| 155 |
+
integer is provided, it determines how many unrolled loop iterations
|
| 156 |
+
to run within a single rolled iteration of the loop. If a boolean is
|
| 157 |
+
provided, it will determine if the loop is completely unrolled
|
| 158 |
+
(`unroll=True`) or left completely unrolled (`unroll=False`).
|
| 159 |
+
Note that unrolling is only supported by JAX and TensorFlow
|
| 160 |
+
backends.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
A pair where the first element represents the final loop carry value and
|
| 164 |
+
the second element represents the stacked outputs of `f` when scanned
|
| 165 |
+
over the leading axis of the inputs.
|
| 166 |
+
|
| 167 |
+
Examples:
|
| 168 |
+
|
| 169 |
+
>>> sum_fn = lambda c, x: (c + x, c + x)
|
| 170 |
+
>>> init = keras.ops.array(0)
|
| 171 |
+
>>> xs = keras.ops.array([1, 2, 3, 4, 5])
|
| 172 |
+
>>> carry, result = keras.ops.scan(sum_fn, init, xs)
|
| 173 |
+
>>> carry
|
| 174 |
+
15
|
| 175 |
+
>>> result
|
| 176 |
+
[1, 3, 6, 10, 15]
|
| 177 |
+
"""
|
| 178 |
+
if any_symbolic_tensors((init, xs)):
|
| 179 |
+
return Scan(reverse=reverse, unroll=unroll).symbolic_call(
|
| 180 |
+
f, init, xs, length
|
| 181 |
+
)
|
| 182 |
+
return backend.core.scan(
|
| 183 |
+
f, init, xs, length, reverse=reverse, unroll=unroll
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class AssociativeScan(Operation):
|
| 188 |
+
def __init__(self, reverse=False):
|
| 189 |
+
super().__init__()
|
| 190 |
+
self.reverse = reverse
|
| 191 |
+
|
| 192 |
+
def call(self, f, elems, axis=0):
|
| 193 |
+
return backend.core.associative_scan(
|
| 194 |
+
f, elems, reverse=self.reverse, axis=axis
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
def compute_output_spec(self, f, elems, axis):
|
| 198 |
+
elems_flat = tree.flatten(elems)
|
| 199 |
+
lens = [elem.shape[axis] for elem in elems_flat]
|
| 200 |
+
if len(set(lens)) != 1:
|
| 201 |
+
raise ValueError(
|
| 202 |
+
"Array inputs to associative_scan must have the same "
|
| 203 |
+
"first dimension. (saw: {})".format(
|
| 204 |
+
[elem.shape for elem in elems_flat]
|
| 205 |
+
)
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
x = tree.pack_sequence_as(
|
| 209 |
+
elems, [slice_along_axis(x, 0, 1, axis=axis) for x in elems_flat]
|
| 210 |
+
)
|
| 211 |
+
y_spec = backend.compute_output_spec(f, x, x)
|
| 212 |
+
|
| 213 |
+
def _restore_shape(x):
|
| 214 |
+
return KerasTensor(
|
| 215 |
+
shape=elems_flat[0].shape, dtype=x.dtype, sparse=x.sparse
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
y_spec = tree.map_structure(_restore_shape, y_spec)
|
| 219 |
+
return y_spec
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@keras_export("keras.ops.associative_scan")
|
| 223 |
+
def associative_scan(f, elems, reverse=False, axis=0):
|
| 224 |
+
"""Performs a scan with an associative binary operation, in parallel.
|
| 225 |
+
|
| 226 |
+
This operation his similar to `scan`, with the key difference that
|
| 227 |
+
`associative_scan` is a parallel implementation with
|
| 228 |
+
potentially significant performance benefits, especially when jit compiled.
|
| 229 |
+
The catch is that it can only be used when `f` is a binary associative
|
| 230 |
+
operation (i.e. it must verify `f(a, f(b, c)) == f(f(a, b), c)`).
|
| 231 |
+
|
| 232 |
+
For an introduction to associative scans, refer to this paper:
|
| 233 |
+
Blelloch, Guy E. 1990.
|
| 234 |
+
[Prefix Sums and Their Applications](
|
| 235 |
+
https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf).
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
f: A Python callable implementing an associative binary operation with
|
| 239 |
+
signature `r = f(a, b)`. Function `f` must be associative, i.e.,
|
| 240 |
+
it must satisfy the equation
|
| 241 |
+
`f(a, f(b, c)) == f(f(a, b), c)`.
|
| 242 |
+
The inputs and result are (possibly nested Python tree structures
|
| 243 |
+
of) array(s) matching `elems`. Each array has a dimension in place
|
| 244 |
+
of the `axis` dimension. `f` should be applied elementwise over
|
| 245 |
+
the `axis` dimension.
|
| 246 |
+
The result `r` has the same shape (and structure) as the
|
| 247 |
+
two inputs `a` and `b`.
|
| 248 |
+
elems: A (possibly nested Python tree structure of) array(s), each with
|
| 249 |
+
an `axis` dimension of size `num_elems`.
|
| 250 |
+
reverse: A boolean stating if the scan should be reversed with respect
|
| 251 |
+
to the `axis` dimension.
|
| 252 |
+
axis: an integer identifying the axis over which the scan should occur.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
A (possibly nested Python tree structure of) array(s) of the same shape
|
| 256 |
+
and structure as `elems`, in which the `k`'th element of `axis` is
|
| 257 |
+
the result of recursively applying `f` to combine the first `k`
|
| 258 |
+
elements of `elems` along `axis`. For example, given
|
| 259 |
+
`elems = [a, b, c, ...]`, the result would be
|
| 260 |
+
`[a, f(a, b), f(f(a, b), c), ...]`.
|
| 261 |
+
|
| 262 |
+
Examples:
|
| 263 |
+
|
| 264 |
+
>>> sum_fn = lambda x, y: x + y
|
| 265 |
+
>>> xs = keras.ops.arange(5)
|
| 266 |
+
>>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
|
| 267 |
+
>>> ys
|
| 268 |
+
[0, 1, 3, 6, 10]
|
| 269 |
+
|
| 270 |
+
>>> sum_fn = lambda x, y: [x[0] + y[0], x[1] + y[1], x[2] + y[2]]
|
| 271 |
+
>>> xs = [keras.ops.array([[1, 2]]) for _ in range(3)]
|
| 272 |
+
>>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
|
| 273 |
+
>>> ys
|
| 274 |
+
[[1, 3], [1, 3], [1, 3]]
|
| 275 |
+
"""
|
| 276 |
+
if any_symbolic_tensors((elems,)):
|
| 277 |
+
return AssociativeScan(reverse=reverse).symbolic_call(f, elems, axis)
|
| 278 |
+
return backend.core.associative_scan(f, elems, reverse=reverse, axis=axis)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class Scatter(Operation):
|
| 282 |
+
def call(self, indices, values, shape):
|
| 283 |
+
return backend.core.scatter(indices, values, shape)
|
| 284 |
+
|
| 285 |
+
def compute_output_spec(self, indices, values, shape):
|
| 286 |
+
return KerasTensor(shape, dtype=values.dtype)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
@keras_export("keras.ops.scatter")
|
| 290 |
+
def scatter(indices, values, shape):
|
| 291 |
+
"""Returns a tensor of shape `shape` where `indices` are set to `values`.
|
| 292 |
+
|
| 293 |
+
At a high level, this operation does `zeros[indices] = updates` and
|
| 294 |
+
returns the output. It is equivalent to:
|
| 295 |
+
|
| 296 |
+
```python
|
| 297 |
+
zeros = keras.ops.zeros(shape)
|
| 298 |
+
output = keras.ops.scatter_update(zeros, indices, values)
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
Args:
|
| 302 |
+
indices: A tensor or list/tuple specifying
|
| 303 |
+
indices for the values in `values`.
|
| 304 |
+
values: A tensor, the values to be set at `indices`.
|
| 305 |
+
shape: Shape of the output tensor.
|
| 306 |
+
|
| 307 |
+
Example:
|
| 308 |
+
|
| 309 |
+
>>> indices = [[0, 1], [1, 1]]
|
| 310 |
+
>>> values = np.array([1., 1.])
|
| 311 |
+
>>> keras.ops.scatter(indices, values, shape=(2, 2))
|
| 312 |
+
array([[0., 1.],
|
| 313 |
+
[0., 1.]])
|
| 314 |
+
"""
|
| 315 |
+
if any_symbolic_tensors((indices, values, shape)):
|
| 316 |
+
return Scatter().symbolic_call(indices, values, shape)
|
| 317 |
+
return backend.core.scatter(indices, values, shape)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class ScatterUpdate(Operation):
|
| 321 |
+
def call(self, inputs, indices, updates):
|
| 322 |
+
return backend.core.scatter_update(inputs, indices, updates)
|
| 323 |
+
|
| 324 |
+
def compute_output_spec(self, inputs, indices, updates):
|
| 325 |
+
return KerasTensor(inputs.shape, dtype=inputs.dtype)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@keras_export("keras.ops.scatter_update")
|
| 329 |
+
def scatter_update(inputs, indices, updates):
|
| 330 |
+
"""Update inputs via updates at scattered (sparse) indices.
|
| 331 |
+
|
| 332 |
+
At a high level, this operation does `inputs[indices] = updates`.
|
| 333 |
+
Assume `inputs` is a tensor of shape `(D0, D1, ..., Dn)`, there are 2 main
|
| 334 |
+
usages of `scatter_update`.
|
| 335 |
+
|
| 336 |
+
1. `indices` is a 2D tensor of shape `(num_updates, n)`, where `num_updates`
|
| 337 |
+
is the number of updates to perform, and `updates` is a 1D tensor of
|
| 338 |
+
shape `(num_updates,)`. For example, if `inputs` is `zeros((4, 4, 4))`,
|
| 339 |
+
and we want to update `inputs[1, 2, 3]` and `inputs[0, 1, 3]` as 1, then
|
| 340 |
+
we can use:
|
| 341 |
+
|
| 342 |
+
```python
|
| 343 |
+
inputs = np.zeros((4, 4, 4))
|
| 344 |
+
indices = [[1, 2, 3], [0, 1, 3]]
|
| 345 |
+
updates = np.array([1., 1.])
|
| 346 |
+
inputs = keras.ops.scatter_update(inputs, indices, updates)
|
| 347 |
+
```
|
| 348 |
+
|
| 349 |
+
2 `indices` is a 2D tensor of shape `(num_updates, k)`, where `num_updates`
|
| 350 |
+
is the number of updates to perform, and `k` (`k < n`) is the size of
|
| 351 |
+
each index in `indices`. `updates` is a `n - k`-D tensor of shape
|
| 352 |
+
`(num_updates, inputs.shape[k:])`. For example, if
|
| 353 |
+
`inputs = np.zeros((4, 4, 4))`, and we want to update `inputs[1, 2, :]`
|
| 354 |
+
and `inputs[2, 3, :]` as `[1, 1, 1, 1]`, then `indices` would have shape
|
| 355 |
+
`(num_updates, 2)` (`k = 2`), and `updates` would have shape
|
| 356 |
+
`(num_updates, 4)` (`inputs.shape[2:] = 4`). See the code below:
|
| 357 |
+
|
| 358 |
+
```python
|
| 359 |
+
inputs = np.zeros((4, 4, 4))
|
| 360 |
+
indices = [[1, 2], [2, 3]]
|
| 361 |
+
updates = np.array([[1., 1., 1, 1,], [1., 1., 1, 1,])
|
| 362 |
+
inputs = keras.ops.scatter_update(inputs, indices, updates)
|
| 363 |
+
```
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
inputs: A tensor, the tensor to be updated.
|
| 367 |
+
indices: A tensor or list/tuple of shape `(N, inputs.ndim)`, specifying
|
| 368 |
+
indices to update. `N` is the number of indices to update, must be
|
| 369 |
+
equal to the first dimension of `updates`.
|
| 370 |
+
updates: A tensor, the new values to be put to `inputs` at `indices`.
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
A tensor, has the same shape and dtype as `inputs`.
|
| 374 |
+
"""
|
| 375 |
+
if any_symbolic_tensors((inputs, indices, updates)):
|
| 376 |
+
return ScatterUpdate().symbolic_call(inputs, indices, updates)
|
| 377 |
+
return backend.core.scatter_update(inputs, indices, updates)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class Slice(Operation):
|
| 381 |
+
def call(self, inputs, start_indices, shape):
|
| 382 |
+
return backend.core.slice(inputs, start_indices, shape)
|
| 383 |
+
|
| 384 |
+
def compute_output_spec(self, inputs, start_indices, shape):
|
| 385 |
+
return KerasTensor(shape, dtype=inputs.dtype)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
@keras_export("keras.ops.slice")
|
| 389 |
+
def slice(inputs, start_indices, shape):
|
| 390 |
+
"""Return a slice of an input tensor.
|
| 391 |
+
|
| 392 |
+
At a high level, this operation is an explicit replacement for array slicing
|
| 393 |
+
e.g. `inputs[start_indices: start_indices + shape]`.
|
| 394 |
+
Unlike slicing via brackets, this operation will accept tensor start
|
| 395 |
+
indices on all backends, which is useful when indices dynamically computed
|
| 396 |
+
via other tensor operations.
|
| 397 |
+
|
| 398 |
+
```python
|
| 399 |
+
inputs = np.zeros((5, 5))
|
| 400 |
+
start_indices = np.array([3, 3])
|
| 401 |
+
shape = np.array([2, 2])
|
| 402 |
+
inputs = keras.ops.slice(inputs, start_indices, shape)
|
| 403 |
+
```
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
inputs: A tensor, the tensor to be updated.
|
| 407 |
+
start_indices: A list/tuple of shape `(inputs.ndim,)`, specifying
|
| 408 |
+
the starting indices for updating.
|
| 409 |
+
shape: The full shape of the returned slice.
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
A tensor, has the same shape and dtype as `inputs`.
|
| 413 |
+
"""
|
| 414 |
+
if any_symbolic_tensors((inputs, start_indices, shape)):
|
| 415 |
+
return Slice().symbolic_call(inputs, start_indices, shape)
|
| 416 |
+
return backend.core.slice(inputs, start_indices, shape)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class SliceUpdate(Operation):
    """Operation wrapper for `backend.core.slice_update`."""

    def call(self, inputs, start_indices, updates):
        return backend.core.slice_update(inputs, start_indices, updates)

    def compute_output_spec(self, inputs, start_indices, updates):
        # Updating a slice never changes the overall shape or dtype.
        return KerasTensor(inputs.shape, dtype=inputs.dtype)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@keras_export("keras.ops.slice_update")
def slice_update(inputs, start_indices, updates):
    """Update an input by slicing in a tensor of updated values.

    At a high level, this operation does
    `inputs[start_indices: start_indices + updates.shape] = updates`.
    Assume `inputs` is a tensor of shape `(D0, D1, ..., Dn)`;
    `start_indices` must be a list/tuple of n integers, specifying the
    starting indices. `updates` must have the same rank as `inputs`, and the
    size of each dim must not exceed `Di - start_indices[i]`. For example, to
    set the intersection of the last 2 rows and last 2 columns of a 2D zeros
    tensor to ones, i.e. `inputs[3:, 3:] = np.ones((2, 2))`:

    ```python
    inputs = np.zeros((5, 5))
    start_indices = [3, 3]
    updates = np.ones((2, 2))
    inputs = keras.ops.slice_update(inputs, start_indices, updates)
    ```

    Args:
        inputs: A tensor, the tensor to be updated.
        start_indices: A list/tuple of shape `(inputs.ndim,)`, specifying
            the starting indices for updating.
        updates: A tensor, the new values to be put to `inputs` at `indices`.
            `updates` must have the same rank as `inputs`.

    Returns:
        A tensor, has the same shape and dtype as `inputs`.
    """
    is_symbolic = any_symbolic_tensors((inputs, start_indices, updates))
    if is_symbolic:
        return SliceUpdate().symbolic_call(inputs, start_indices, updates)
    return backend.core.slice_update(inputs, start_indices, updates)
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class Switch(Operation):
    """Operation wrapper for `backend.core.switch`."""

    def call(self, index, branches, *operands):
        return backend.core.switch(index, branches, *operands)

    def compute_output_spec(self, index, branches, *operands):
        # All branches are expected to produce compatible outputs, so the
        # spec of the first branch is used as the output spec.
        return backend.compute_output_spec(branches[0], *operands)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
@keras_export("keras.ops.switch")
def switch(index, branches, *operands):
    """Apply exactly one of the `branches` given by `index`.

    If `index` is out of bounds, it is clamped to within bounds.

    The semantics of `switch` are given roughly by this Python
    implementation:

    ```python
    def switch(index, branches, *operands):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](*operands)
    ```

    Args:
        index: An integer scalar indicating which branch function to apply.
        branches: A sequence of functions to be applied based on `index`.
        operands: Inputs to whichever branch is applied.

    Returns:
        The outputs of `branch(*operands)` for the branch that was selected
        based on `index`.

    Examples:

    >>> add_fn = lambda x, y: x + y
    >>> subtract_fn = lambda x, y: x - y
    >>> x = keras.ops.array(2.0)
    >>> y = keras.ops.array(0.5)
    >>> branches = [add_fn, subtract_fn]
    >>> keras.ops.switch(0, branches, x, y)
    2.5

    >>> keras.ops.switch(1, branches, x, y)
    1.5
    """
    if any_symbolic_tensors(operands):
        return Switch().symbolic_call(index, branches, *operands)
    return backend.core.switch(index, branches, *operands)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class WhileLoop(Operation):
    """Operation wrapper for `backend.core.while_loop`."""

    def __init__(self, cond, body, maximum_iterations):
        super().__init__()
        # The loop condition, body and iteration cap are fixed at
        # construction time; only the loop variables flow through `call`.
        self.cond = cond
        self.body = body
        self.maximum_iterations = maximum_iterations

    def call(self, loop_vars):
        return backend.core.while_loop(
            self.cond,
            self.body,
            loop_vars,
            maximum_iterations=self.maximum_iterations,
        )

    def compute_output_spec(self, loop_vars):
        # Each loop variable keeps its shape and dtype across iterations.
        return [KerasTensor(var.shape, dtype=var.dtype) for var in loop_vars]
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
@keras_export("keras.ops.while_loop")
def while_loop(
    cond,
    body,
    loop_vars,
    maximum_iterations=None,
):
    """While loop implementation.

    Args:
        cond: A callable that represents the termination condition of the
            loop. Must accept a `loop_vars` like structure as an argument. If
            `loop_vars` is a tuple or list, each element of `loop_vars` will
            be passed positionally to the callable.
        body: A callable that represents the loop body. Must accept a
            `loop_vars` like structure as an argument, and return update
            value with the same structure. If `loop_vars` is a tuple or list,
            each element of `loop_vars` will be passed positionally to the
            callable.
        loop_vars: An arbitrary nested structure of tensor state to persist
            across loop iterations.
        maximum_iterations: Optional maximum number of iterations of the
            while loop to run. If provided, the `cond` output is AND-ed with
            an additional condition ensuring the number of iterations
            executed is no greater than `maximum_iterations`.

    Returns:
        The final loop variables, with the same structure, shapes and dtypes
        as `loop_vars`.

    Examples:

    >>> i = 0
    >>> cond = lambda i: i < 10
    >>> body = lambda i: i + 1
    >>> keras.ops.while_loop(cond, body, i)
    10

    >>> x, y = 0, 1
    >>> cond = lambda x, y: x < 10
    >>> body = lambda x, y: (x + 1, y + 1)
    >>> keras.ops.while_loop(cond, body, (x, y))
    10, 11
    """
    return backend.core.while_loop(
        cond,
        body,
        loop_vars,
        maximum_iterations=maximum_iterations,
    )
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
class StopGradient(Operation):
    """Operation wrapper for `backend.core.stop_gradient`."""

    def __init__(self):
        super().__init__()

    def call(self, variable):
        return backend.core.stop_gradient(variable)

    def compute_output_spec(self, variable):
        # Stopping gradients is a no-op on shape and dtype.
        return KerasTensor(variable.shape, dtype=variable.dtype)
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
@keras_export("keras.ops.stop_gradient")
def stop_gradient(variable):
    """Stops gradient computation.

    Args:
        variable: A tensor variable for which the gradient
            computation is to be disabled.

    Returns:
        The variable with gradient computation disabled.

    Examples:

    >>> var = keras.backend.convert_to_tensor(
    ...     [1., 2., 3.],
    ...     dtype="float32"
    ... )
    >>> var = keras.ops.stop_gradient(var)
    """
    if any_symbolic_tensors((variable,)):
        return StopGradient().symbolic_call(variable)
    return backend.core.stop_gradient(variable)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
class ForiLoop(Operation):
    """Operation wrapper for `backend.core.fori_loop`."""

    def __init__(self, lower, upper, body_fun):
        super().__init__()
        # Loop bounds and body are fixed at construction time; only the
        # initial state flows through `call`.
        self.lower = lower
        self.upper = upper
        self.body_fun = body_fun

    def call(self, init_val):
        return backend.core.fori_loop(
            self.lower,
            self.upper,
            self.body_fun,
            init_val,
        )

    def compute_output_spec(self, init_val):
        # The loop state keeps its shape and dtype across iterations.
        return KerasTensor(init_val.shape, dtype=init_val.dtype)
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
@keras_export("keras.ops.fori_loop")
def fori_loop(lower, upper, body_fun, init_val):
    """For loop implementation.

    Args:
        lower: The initial value of the loop variable.
        upper: The upper bound of the loop variable.
        body_fun: A callable that represents the loop body. Must take two
            arguments: the loop variable and the loop state. The loop state
            should be updated and returned by this function.
        init_val: The initial value of the loop state.

    Returns:
        The final state after the loop.

    Example:

    >>> lower = 0
    >>> upper = 10
    >>> body_fun = lambda i, s: (i + 1, s + i)
    >>> init_val = 0
    >>> keras.ops.fori_loop(lower, upper, body_fun, init_val)
    45
    """
    is_symbolic = any_symbolic_tensors((lower, upper, init_val))
    if is_symbolic:
        return ForiLoop(lower, upper, body_fun).symbolic_call(init_val)
    return backend.core.fori_loop(lower, upper, body_fun, init_val)
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
class Unstack(Operation):
    """Operation wrapper for `backend.core.unstack`."""

    def __init__(self, num=None, axis=0):
        super().__init__()
        self.num = num
        self.axis = axis

    def call(self, x):
        return backend.core.unstack(x, self.num, self.axis)

    def compute_output_spec(self, x):
        axis = self.axis
        if axis < 0:
            # Normalize a negative axis to its positive equivalent.
            axis += len(x.shape)
        # Each output piece drops the unstacked dimension.
        piece_shape = x.shape[:axis] + x.shape[axis + 1 :]
        num = self.num if self.num is not None else x.shape[axis]
        if num is None:
            # The unstacked dimension is dynamic and no `num` was given,
            # so the number of outputs cannot be determined.
            raise ValueError(
                "Cannot infer argument `num` from shape "
                f"{x.shape}. Either provide a tensor with a "
                "concrete shape in the `axis` dimension or "
                "explicitly pass the `num` argument."
            )
        return [
            KerasTensor(shape=piece_shape, dtype=x.dtype) for _ in range(num)
        ]
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
@keras_export("keras.ops.unstack")
def unstack(x, num=None, axis=0):
    """Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.

    Args:
        x: The input tensor.
        num: The length of the dimension axis. Automatically inferred
            if `None`.
        axis: The axis along which to unpack.

    Returns:
        A list of tensors unpacked along the given axis.

    Example:

    >>> x = keras.ops.array([[1, 2], [3, 4]])
    >>> keras.ops.unstack(x, axis=0)
    [array([1, 2]), array([3, 4])]
    """
    if any_symbolic_tensors((x,)):
        return Unstack(num, axis).symbolic_call(x)
    return backend.core.unstack(x, num=num, axis=axis)
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
@keras_export("keras.ops.shape")
def shape(x):
    """Gets the shape of the tensor input.

    Note: On the TensorFlow backend, when `x` is a `tf.Tensor` with dynamic
    shape, dimensions which are dynamic in the context of a compiled function
    will have a `tf.Tensor` value instead of a static integer value.

    Args:
        x: A tensor. This function will try to access the `shape` attribute
            of the input tensor.

    Returns:
        A tuple of integers or None values, indicating the shape of the
        input tensor.

    Example:

    >>> x = keras.ops.zeros((8, 12))
    >>> keras.ops.shape(x)
    (8, 12)
    """
    if any_symbolic_tensors((x,)):
        # Symbolic tensors already carry a static shape tuple.
        return x.shape
    return backend.core.shape(x)
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
@keras_export("keras.ops.dtype")
def dtype(x):
    """Return the dtype of the tensor input as a standardized string.

    Note that due to the standardization, the dtype will not compare equal
    to the backend-specific version of the dtype.

    Args:
        x: A tensor. This function will try to access the `dtype` attribute
            of the input tensor.

    Returns:
        A string indicating the dtype of the input tensor, e.g. `"float32"`.

    Example:

    >>> x = keras.ops.zeros((8, 12))
    >>> keras.ops.dtype(x)
    'float32'

    """
    return backend.standardize_dtype(x.dtype)
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
class Cast(Operation):
    """Operation wrapper for `backend.core.cast`."""

    def __init__(self, dtype):
        super().__init__()
        # Normalize the dtype once so symbolic and eager paths agree.
        self.dtype = backend.standardize_dtype(dtype)

    def call(self, x):
        return backend.core.cast(x, self.dtype)

    def compute_output_spec(self, x):
        # Casting preserves shape; only the dtype changes.
        return backend.KerasTensor(shape=x.shape, dtype=self.dtype)
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
@keras_export("keras.ops.cast")
def cast(x, dtype):
    """Cast a tensor to the desired dtype.

    Args:
        x: A tensor or variable.
        dtype: The target type.

    Returns:
        A tensor of the specified `dtype`.

    Example:

    >>> x = keras.ops.arange(4)
    >>> x = keras.ops.cast(x, dtype="float16")
    """
    dtype = backend.standardize_dtype(dtype)
    if any_symbolic_tensors((x,)):
        return Cast(dtype=dtype)(x)
    return backend.core.cast(x, dtype)
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
class SaturateCast(Operation):
    """Operation wrapper for `_saturate_cast`."""

    def __init__(self, dtype):
        super().__init__()
        # Normalize the dtype once so symbolic and eager paths agree.
        self.dtype = backend.standardize_dtype(dtype)

    def call(self, x):
        return _saturate_cast(x, self.dtype)

    def compute_output_spec(self, x):
        # Saturating cast preserves shape; only the dtype changes.
        return backend.KerasTensor(shape=x.shape, dtype=self.dtype)
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
@keras_export("keras.ops.saturate_cast")
def saturate_cast(x, dtype):
    """Performs a safe saturating cast to the desired dtype.

    Saturating cast prevents data type overflow when casting to `dtype` with
    smaller values range. E.g.
    `ops.cast(ops.cast([-1, 256], "float32"), "uint8")` returns `[255, 0]`,
    but `ops.saturate_cast(ops.cast([-1, 256], "float32"), "uint8")` returns
    `[0, 255]`.

    Args:
        x: A tensor or variable.
        dtype: The target type.

    Returns:
        A safely casted tensor of the specified `dtype`.

    Example:

    Image resizing with bicubic interpolation may produce values outside
    original range.
    >>> image2x2 = np.array([0, 1, 254, 255], dtype="uint8").reshape(1, 2, 2, 1)
    >>> image4x4 = tf.image.resize(image2x2, (4, 4), method="bicubic")
    >>> print(image4x4.numpy().squeeze())
    >>> # [[-22.500004 -22.204624 -21.618908 -21.32353 ]
    >>> #  [ 52.526054  52.82143   53.407146  53.70253 ]
    >>> #  [201.29752  201.59288  202.17859  202.47395 ]
    >>> #  [276.32355  276.61893  277.20465  277.50006 ]]

    Casting this resized image back to `uint8` will cause overflow.
    >>> image4x4_casted = ops.cast(image4x4, "uint8")
    >>> print(image4x4_casted.numpy().squeeze())
    >>> # [[234 234 235 235]
    >>> #  [ 52  52  53  53]
    >>> #  [201 201 202 202]
    >>> #  [ 20  20  21  21]]

    Saturate casting to `uint8` will clip values to `uint8` range before
    casting and will not cause overflow.
    >>> image4x4_saturate_casted = ops.saturate_cast(image4x4, "uint8")
    >>> print(image4x4_saturate_casted.numpy().squeeze())
    >>> # [[  0   0   0   0]
    >>> #  [ 52  52  53  53]
    >>> #  [201 201 202 202]
    >>> #  [255 255 255 255]]

    """
    dtype = backend.standardize_dtype(dtype)
    if any_symbolic_tensors((x,)):
        return SaturateCast(dtype=dtype)(x)
    return _saturate_cast(x, dtype)
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
def _saturate_cast(x, dtype, backend_module=None):
    """Clip `x` into the representable range of `dtype`, then cast.

    `backend_module` defaults to the active backend; tests may inject a
    different module exposing `numpy.clip` and `cast`.
    """
    backend_module = backend_module or backend

    def get_dtype_min_max(dtype):
        # Booleans have no iinfo/finfo entry; treat them as {0, 1}.
        if "bool" == dtype:
            return 0, 1
        if "int" in dtype:
            info = ml_dtypes.iinfo(dtype)
        else:
            info = ml_dtypes.finfo(dtype)
        return info.min, info.max

    dtype = backend.standardize_dtype(dtype)
    in_dtype = backend.standardize_dtype(x.dtype)
    in_min, in_max = get_dtype_min_max(in_dtype)
    out_min, out_max = get_dtype_min_max(dtype)

    # The output min/max may not actually be representable in the
    # in_dtype (e.g. casting float32 to uint32). This can lead to undefined
    # behavior when trying to cast a value outside the valid range of the
    # target type. We work around this by nudging the min/max to fall within
    # the valid output range. The catch is that we may actually saturate
    # to a value less than the true saturation limit, but this is the best we
    # can do in order to avoid UB without backend op.
    min_limit = np.maximum(in_min, out_min).astype(in_dtype)
    if min_limit < out_min:
        min_limit = np.nextafter(min_limit, 0, dtype=in_dtype)
    max_limit = np.minimum(in_max, out_max).astype(in_dtype)
    if max_limit > out_max:
        max_limit = np.nextafter(max_limit, 0, dtype=in_dtype)

    # Unconditionally apply `clip` to fix `inf` behavior.
    clipped = backend_module.numpy.clip(x, min_limit, max_limit)

    return backend_module.cast(clipped, dtype)
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
class ConvertToTensor(Operation):
    """Operation wrapper for `backend.core.convert_to_tensor`."""

    def __init__(self, dtype, sparse):
        super().__init__()
        self.dtype = backend.standardize_dtype(dtype)
        self.sparse = sparse

    def call(self, x):
        return backend.core.convert_to_tensor(
            x, dtype=self.dtype, sparse=self.sparse
        )

    def compute_output_spec(self, x):
        dtype = x.dtype if self.dtype is None else self.dtype
        # `sparse=False` densifies; otherwise sparseness is inherited.
        if self.sparse is None or self.sparse:
            sparse = x.sparse
        else:
            sparse = False
        return backend.KerasTensor(shape=x.shape, dtype=dtype, sparse=sparse)
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
@keras_export("keras.ops.convert_to_tensor")
def convert_to_tensor(x, dtype=None, sparse=None):
    """Convert a NumPy array to a tensor.

    Args:
        x: A NumPy array, Python array (can be nested) or a backend tensor.
        dtype: The target type. If `None`, the type of `x` is used.
        sparse: Whether to keep sparse tensors. `False` will cause sparse
            tensors to be densified. The default value of `None` means that
            sparse tensors are kept only if the backend supports them.

    Returns:
        A backend tensor of the specified `dtype` and sparseness.

    Example:

    >>> x = np.array([1, 2, 3])
    >>> y = keras.ops.convert_to_tensor(x)
    """
    if any_symbolic_tensors((x,)):
        return ConvertToTensor(dtype=dtype, sparse=sparse)(x)
    return backend.core.convert_to_tensor(x, dtype=dtype, sparse=sparse)
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
@keras_export("keras.ops.convert_to_numpy")
def convert_to_numpy(x):
    """Convert a tensor to a NumPy array.

    Args:
        x: A tensor.

    Returns:
        A NumPy array.
    """
    if any_symbolic_tensors((x,)):
        # This will raise a `ValueError` defined in the `KerasTensor` class.
        # We trigger it rather than duplicate it here.
        return np.array(x)
    return backend.convert_to_numpy(x)
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
class Cond(Operation):
    """Operation wrapper for `backend.core.cond`.

    Routes symbolic inputs to `symbolic_call` and concrete inputs to the
    backend implementation, and validates that both branches produce
    structurally identical output specs.
    """

    @traceback_utils.filter_traceback
    def __call__(self, *args, **kwargs):
        def call_fn(*args, **kwargs):
            if any_symbolic_tensors(args, kwargs):
                return self.symbolic_call(*args, **kwargs)
            else:
                return self.call(*args, **kwargs)

        if traceback_utils.is_traceback_filtering_enabled():
            # Wrap self.call to provide helpful info in case of exception
            call_fn = traceback_utils.inject_argument_info_in_traceback(
                call_fn,
                object_name=(f"{self.__class__.__name__}.call()"),
            )
            return call_fn(*args, **kwargs)

        # Plain flow.
        return call_fn(*args, **kwargs)

    def call(self, pred, true_fn, false_fn):
        return backend.core.cond(pred, true_fn, false_fn)

    def compute_output_spec(self, pred, true_fn, false_fn):
        true_fn_spec = backend.compute_output_spec(true_fn)
        false_fn_spec = backend.compute_output_spec(false_fn)
        if not self._check_output_spec(true_fn_spec, false_fn_spec):
            raise ValueError(
                "`true_fn` and `false_fn` should return outputs "
                "of the same kind (struct, dtype and shape). "
                f"Got {true_fn_spec} and {false_fn_spec} instead."
            )
        return true_fn_spec

    def _check_output_spec(self, true_fn_spec, false_fn_spec):
        """Return whether both branch specs match in structure, shape, dtype."""
        try:
            tree.assert_same_structure(true_fn_spec, false_fn_spec)
        except Exception:
            # Fixed: was a bare `except:`, which also swallowed
            # SystemExit/KeyboardInterrupt. A structure mismatch simply
            # means the specs don't match.
            return False

        def check_leaf(t_spec, f_spec):
            # `None` leaves only match `None` leaves.
            if t_spec is None or f_spec is None:
                return t_spec is None and f_spec is None
            return t_spec.shape == f_spec.shape and t_spec.dtype == f_spec.dtype

        same = tree.map_structure(check_leaf, true_fn_spec, false_fn_spec)
        return all(tree.flatten(same))
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
@keras_export("keras.ops.cond")
def cond(pred, true_fn, false_fn):
    """Conditionally applies `true_fn` or `false_fn`.

    Args:
        pred: Boolean scalar type
        true_fn: Callable returning the output for the `pred == True` case.
        false_fn: Callable returning the output for the `pred == False` case.

    Returns:
        The output of either `true_fn` or `false_fn` depending on pred.
    """
    # `Cond.__call__` dispatches between the symbolic and eager paths.
    return Cond()(pred, true_fn, false_fn)
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
# TODO: also create an Op subclass VectorizedMap.
|
| 1037 |
+
@keras_export("keras.ops.vectorized_map")
def vectorized_map(function, elements):
    """Parallel map of `function` on axis 0 of tensor(s) `elements`.

    Schematically, `vectorized_map` implements the following,
    in the case of a single tensor input `elements`:

    ```python
    def vectorized_map(function, elements)
        outputs = []
        for e in elements:
            outputs.append(function(e))
        return stack(outputs)
    ```

    In the case of an iterable of tensors `elements`,
    it implements the following:

    ```python
    def vectorized_map(function, elements)
        batch_size = elements[0].shape[0]
        outputs = []
        for index in range(batch_size):
            outputs.append(function([e[index] for e in elements]))
        return np.stack(outputs)
    ```

    In this case, `function` is expected to take as input
    a single list of tensor arguments.
    """
    return backend.core.vectorized_map(function, elements)
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
@keras_export("keras.ops.is_tensor")
def is_tensor(x):
    """Check whether the given object is a tensor.

    Note: This checks for backend specific tensors so passing a TensorFlow
    tensor would return `False` if your backend is PyTorch or JAX.

    Args:
        x: A variable.

    Returns:
        `True` if `x` is a tensor, otherwise `False`.
    """
    return backend.core.is_tensor(x)
|
| 1084 |
+
|
| 1085 |
+
|
| 1086 |
+
@keras_export("keras.ops.custom_gradient")
def custom_gradient(f):
    """Decorator to define a function with a custom gradient.

    This decorator allows fine grained control over the gradients of a
    sequence for operations. This may be useful for multiple reasons,
    including providing a more efficient or numerically stable gradient for
    a sequence of operations.

    Args:
        f: Function `f(*args)` that returns a tuple `(output, grad_fn)`,
            where:
            - `args` is a sequence of (nested structures of) tensor inputs
              to the function.
            - `output` is a (nested structure of) tensor outputs of applying
              operations in `forward_fn` to `args`.
            - `grad_fn` is a function with the signature
              `grad_fn(*args, upstream)` which returns a tuple of tensors
              the same size as (flattened) `args`: the derivatives of
              tensors in `output` with respect to the tensors in `args`.
              `upstream` is a tensor or sequence of tensors holding the
              initial value gradients for each tensor in `output`.

    Returns:
        A function `h(*args)` which returns the same value as `f(*args)[0]`
        and whose gradient is determined by `f(*args)[1]`.

    Examples:

    1. Backend-agnostic example.

    ```python
    @ops.custom_gradient
    def log1pexp(x):
        e = ops.exp(x)

        def grad(*args, upstream=None):
            if upstream is None:
                (upstream,) = args
            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

        return ops.log(1 + e), grad
    ```

    Note that the grad function that returns gradient computation requires
    `args` as well as an `upstream` keyword argument, depending on the
    backend being set. With the JAX and TensorFlow backends, it requires
    only one argument, whereas it might use the `upstream` argument in the
    case of the PyTorch backend.

    When working with TensorFlow/JAX backend, `grad(upstream)` is
    sufficient. With PyTorch, the `grad` function requires `*args` as well
    as `upstream`, e.g. `def grad(*args, upstream)`. Follow the previous
    example to use `@ops.custom_gradient` in a way that is compatible with
    all backends.

    2. Here's JAX & TensorFlow-specific example:

    ```python
    @ops.custom_gradient
    def log1pexp(x):
        e = ops.exp(x)
        def grad(upstream):
            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
        return ops.log(1 + e), grad
    ```

    3. Lastly, here's a PyTorch-specific example,
    using `*args` & `upstream`:

    ```python
    @ops.custom_gradient
    def log1pexp(x):
        e = ops.exp(x)
        def grad(*args, upstream):
            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
        return ops.log(1 + e), grad
    ```
    """
    # Delegate entirely to the backend's custom-gradient machinery.
    return backend.core.custom_gradient(f)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/function.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
|
| 3 |
+
from keras.src import tree
|
| 4 |
+
from keras.src.api_export import keras_export
|
| 5 |
+
from keras.src.backend import KerasTensor
|
| 6 |
+
from keras.src.backend.config import backend
|
| 7 |
+
from keras.src.ops.operation import Operation
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@keras_export("keras.Function")
class Function(Operation):
    """Class that encapsulates a computation graph of Keras operations.

    You can use a `Function` to capture the computation graph linking
    some input tensors to some output tensors, and reapply the same
    computation on new inputs.

    A `Function` is similar to a Functional Model, with the difference
    that it is stateless (it does not track state variables)
    and does not implement the `Layer` API.

    Example:

    ```python
    input_1 = keras.KerasTensor(shape=(None, 2, 3))
    input_2 = keras.KerasTensor(shape=(None, 2, 3))
    x = input_1 + input_2
    output = keras.ops.sigmoid(x)
    fn = keras.Function(inputs=[input_1, input_2], outputs=output)

    input_1_val = np.random.random((4, 2, 3))
    input_2_val = np.random.random((4, 2, 3))
    output_val = fn([input_1_val, input_2_val])
    ```

    Args:
        inputs: `KerasTensor` instance or nested structured of
            `KerasTensor` instances.
        outputs: `KerasTensor` instance or nested structured of
            `KerasTensor` instances. They should be computable
            given only the values of `inputs`.
        name: String. The name of the function.
    """

    def __init__(self, inputs, outputs, name=None):
        super().__init__(name=name)

        if backend() == "tensorflow":
            # Temporary work around for
            # https://github.com/keras-team/keras/issues/931
            # This stop tensorflow from wrapping tf.function output in a
            # _DictWrapper object.
            _self_setattr_tracking = getattr(
                self, "_self_setattr_tracking", True
            )
            self._self_setattr_tracking = False
        # Keep shallow copies of the original (possibly nested) structures
        # for later repacking, plus flat lists for graph traversal.
        self._inputs_struct = tree.map_structure(lambda x: x, inputs)
        self._outputs_struct = tree.map_structure(lambda x: x, outputs)
        self._inputs = tree.flatten(inputs)
        self._outputs = tree.flatten(outputs)
        if not self._inputs:
            raise ValueError(
                "`inputs` argument cannot be empty. Received:\n"
                f"inputs={inputs}\n"
                f"outputs={outputs}"
            )
        if not self._outputs:
            raise ValueError(
                "`outputs` argument cannot be empty. Received:\n"
                f"inputs={inputs}\n"
                f"outputs={outputs}"
            )

        if backend() == "tensorflow":
            # Restore the attribute-tracking flag disabled above.
            self._self_setattr_tracking = _self_setattr_tracking

        (nodes, nodes_by_depth, operations, operations_by_depth) = map_graph(
            self._inputs, self._outputs
        )
        self._nodes = nodes
        self._nodes_by_depth = nodes_by_depth
        self._operations = operations
        self._operations_by_depth = operations_by_depth

    @property
    def operations(self):
        # Return a copy so callers cannot mutate the internal list.
        return self._operations[:]

    @property
    def inputs(self):
        """Flat list of the symbolic inputs of the Function."""
        return self._inputs

    @property
    def outputs(self):
        """Flat list of the symbolic outputs of the Function."""
        return self._outputs

    def compute_output_spec(self, inputs):
        """Return symbolic output specs (`KerasTensor`s) for `inputs`."""
        self._assert_input_compatibility(inputs)
        # Check if input shapes are identical to ref input shapes,
        # if so take a shortcut.
        shortcut = True
        for x, x_ref in zip(tree.flatten(inputs), self._inputs):
            if x.shape != x_ref.shape:
                shortcut = False
                break
        if shortcut:
            return tree.map_structure(
                lambda x: KerasTensor(shape=x.shape, dtype=x.dtype),
                self._outputs_struct,
            )
        # No luck; take the long road through the graph.
        # Original Keras used a cache to avoid recomputing all this
        # when known input shapes where seen again. Perhaps a good
        # idea to bring that back.
        return self._run_through_graph(
            inputs, operation_fn=lambda op: op.compute_output_spec
        )

    def compute_output_shape(self, input_shape):
        """Return output shapes corresponding to `input_shape`."""
        # Wrap `input_shape` into the structure of KerasTensor to utilize
        # `compute_output_spec`.
        input_shape_struct = tree.map_shape_structure(
            lambda x: KerasTensor(shape=x), input_shape
        )
        # Ensure that dtype and sparse settings are the same as self._inputs,
        # because we only care about the shape in this function.
        for x, x_ref in zip(tree.flatten(input_shape_struct), self._inputs):
            x._dtype = x_ref.dtype
            x._sparse = x_ref.sparse
        output_spec = self.compute_output_spec(input_shape_struct)
        return tree.map_structure(lambda x: x.shape, output_spec)

    def call(self, inputs):
        """Computes output tensors for new inputs."""
        self._assert_input_compatibility(inputs)
        return self._run_through_graph(inputs, operation_fn=lambda op: op)

    def _run_through_graph(self, inputs, operation_fn, call_fn=None):
        """Execute the graph.

        At each node we compute outputs via
        `operation_fn(node.operation)(*args, **kwargs)`.
        """
        inputs = tree.flatten(inputs)

        # Dictionary mapping reference tensors to computed tensors.
        # Keyed by id() since symbolic tensors are tracked by identity.
        tensor_dict = {}
        for x, y in zip(self.inputs, inputs):
            tensor_dict[id(x)] = y

        nodes_by_depth = self._nodes_by_depth
        depth_keys = list(nodes_by_depth.keys())
        # Process deepest nodes (closest to the inputs) first.
        depth_keys.sort(reverse=True)

        for depth in depth_keys:
            nodes = nodes_by_depth[depth]
            for node in nodes:
                if not node.operation or node.is_input:
                    continue  # Input tensors already exist.

                if any(id(x) not in tensor_dict for x in node.input_tensors):
                    continue  # Node is not computable, try skipping.

                args, kwargs = node.arguments.fill_in(tensor_dict)
                op = operation_fn(node.operation)
                if call_fn is not None:
                    outputs = call_fn(op, *args, **kwargs)
                else:
                    outputs = op(*args, **kwargs)

                # Update tensor_dict.
                for x, y in zip(node.outputs, tree.flatten(outputs)):
                    tensor_dict[id(x)] = y

        output_tensors = []
        for x in self.outputs:
            output_tensors.append(tensor_dict[id(x)])

        # Repack the flat outputs into the original nested structure.
        return tree.pack_sequence_as(self._outputs_struct, output_tensors)

    def _assert_input_compatibility(self, inputs):
        """Raise `ValueError` if `inputs` mismatch the reference structure or shapes."""
        try:
            tree.assert_same_structure(inputs, self._inputs_struct)
        except ValueError:
            raise ValueError(
                "Function was called with an invalid input structure. "
                f"Expected input structure: {self._inputs_struct}\n"
                f"Received input structure: {inputs}"
            )
        for x, x_ref in zip(tree.flatten(inputs), self._inputs):
            if len(x.shape) != len(x_ref.shape):
                raise ValueError(
                    f"{self.__class__.__name__} was passed "
                    f"incompatible inputs. For input '{x_ref.name}', "
                    f"expected shape {x_ref.shape}, but received "
                    f"instead a tensor with shape {x.shape}."
                )
            for dim, ref_dim in zip(x.shape, x_ref.shape):
                # None dims are treated as wildcards; only concrete
                # dimensions are compared.
                if ref_dim is not None and dim is not None:
                    if dim != ref_dim:
                        raise ValueError(
                            f"{self.__class__.__name__} was passed "
                            f"incompatible inputs. For input '{x_ref.name}', "
                            f"expected shape {x_ref.shape}, but received "
                            f"instead a tensor with shape {x.shape}."
                        )
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def make_node_key(op, node_index):
    """Build a unique string key for the `node_index`-th inbound node of `op`."""
    return f"{id(op)}_ib-{node_index}"
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def map_graph(inputs, outputs):
    """Validates a graph's topology and gather its operations and nodes.

    Args:
        inputs: List of input tensors.
        outputs: List of outputs tensors.

    Returns:
        A tuple `(nodes, nodes_by_depth, operations, operations_by_depth)`.
        - nodes: set of Node instances
        - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
        - operations: list of Operation instances.
        - operations_by_depth: dict mapping ints (depth) to lists of Operation
            instances.

    Raises:
        ValueError: if the graph is disconnected (some required tensor is
            not computable from `inputs`) or if operation names collide.
    """
    # "depth" is number of operations between output Node and the Node.
    # Nodes are ordered from inputs -> outputs.
    nodes_in_decreasing_depth, operation_indices = _build_map(inputs, outputs)
    network_nodes = {
        make_node_key(node.operation, node.operation._inbound_nodes.index(node))
        for node in nodes_in_decreasing_depth
    }

    nodes_depths = {}  # dict {node: depth value}
    operations_depths = {}  # dict {operation: depth value}

    # Iterate from outputs toward inputs so each node's depth is known
    # before its parents are visited.
    for node in reversed(nodes_in_decreasing_depth):
        # If the depth is not set, the node has no outbound nodes (depth 0).
        depth = nodes_depths.setdefault(node, 0)

        # Update the depth of the corresponding operation
        previous_depth = operations_depths.get(node.operation, 0)
        # If we've seen this operation before at a higher depth,
        # we should use that depth instead of the node depth.
        # This is necessary for shared operations that have inputs at different
        # depth levels in the graph.
        depth = max(depth, previous_depth)
        operations_depths[node.operation] = depth
        nodes_depths[node] = depth

        # Update the depth of inbound nodes.
        # The "depth" of a node is the max of the depths
        # of all nodes it is connected to + 1.
        for node_dep in node.parent_nodes:
            previous_depth = nodes_depths.get(node_dep, 0)
            nodes_depths[node_dep] = max(depth + 1, previous_depth)

    # Handle inputs that are not connected to outputs.
    # We do not error out here because the inputs may be used to compute losses
    # and metrics.
    for input_t in inputs:
        input_operation = input_t._keras_history[0]
        if input_operation and input_operation not in operations_depths:
            operations_depths[input_operation] = 0
            operation_indices[input_operation] = -1
            nodes_depths[input_operation._inbound_nodes[0]] = 0
            network_nodes.add(make_node_key(input_operation, 0))

    # Build a dict {depth: list of nodes with this depth}
    nodes_by_depth = collections.defaultdict(list)
    for node, depth in nodes_depths.items():
        nodes_by_depth[depth].append(node)

    # Build a dict {depth: list of operations with this depth}
    operations_by_depth = collections.defaultdict(list)
    for operation, depth in operations_depths.items():
        operations_by_depth[depth].append(operation)

    # Get sorted list of operation depths.
    depth_keys = list(operations_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Set self.operations ordered by depth.
    operations = []
    for depth in depth_keys:
        operations_for_depth = operations_by_depth[depth]
        # Network.operations needs to have a deterministic order:
        # here we order them by traversal order.
        operations_for_depth.sort(key=lambda x: operation_indices[x])
        operations.extend(operations_for_depth)

    # Get sorted list of node depths.
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Check that all tensors required are computable.
    # computable_tensors: all tensors in the graph
    # that can be computed from the inputs provided.
    computable_tensors = set()
    for x in inputs:
        computable_tensors.add(x)

    operations_with_complete_input = []  # To provide a better error msg.
    for depth in depth_keys:
        for node in nodes_by_depth[depth]:
            for x in tree.flatten(node.input_tensors):
                if x not in computable_tensors:
                    operation = node.operation
                    raise ValueError(
                        "Graph disconnected: cannot find parent for "
                        f"tensor {x} at operation '{operation}'. "
                        "The following previous operations were accessed "
                        f"without issue: {operations_with_complete_input}"
                    )
                operations_with_complete_input.append(node.operation.name)

            # All of this node's outputs become computable once its inputs
            # were verified above.
            for x in tree.flatten(node.outputs):
                computable_tensors.add(x)

    # Ensure name unicity, which will be crucial for serialization
    # (since serialized nodes refer to operations by their name).
    all_names = [operation.name for operation in operations]
    for name in all_names:
        if all_names.count(name) != 1:
            raise ValueError(
                f'The name "{name}" is used {all_names.count(name)} '
                "times in the model. All operation names should be unique."
            )
    return network_nodes, nodes_by_depth, operations, operations_by_depth
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def _build_map(inputs, outputs):
    """Topologically sort nodes in order from inputs to outputs.

    It uses a depth-first search to topologically sort nodes that appear in
    the `_keras_history` connectivity metadata of `outputs`.

    Args:
        inputs: input tensors; traversal stops when one of these is reached.
        outputs: the output tensors whose `_keras_history` metadata should be
            walked. This may be an arbitrary nested structure.

    Returns:
        A tuple `(ordered_nodes, operation_to_first_traversal_index)`.
        ordered_nodes: list of nodes appearing in the keras history,
            topologically sorted from original inputs to the `outputs`
            (if outputs have different sets of ancestors, the inputs to one
            output may appear after a different output).
        operation_to_first_traversal_index: a dict mapping each operation to
            the traversal index in the DFS where it is first seen. If an
            operation is shared by several nodes, only the index of the
            *first* encounter is stored.
    """
    done = set()
    active = set()
    ordered_nodes = []  # nodes from inputs -> outputs.
    first_seen_index = {}  # operation -> in traversal order.
    for leaf in tree.flatten(outputs):
        _build_map_helper(
            inputs,
            leaf,
            done,
            active,
            ordered_nodes,
            first_seen_index,
        )
    return ordered_nodes, first_seen_index
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _build_map_helper(
    inputs,
    tensor,
    finished_nodes,
    nodes_in_progress,
    nodes_in_decreasing_depth,
    operation_indices,
):
    """Recursive helper for `_build_map`.

    Performs a depth-first post-order walk starting at `tensor`, appending
    each fully-explored node to `nodes_in_decreasing_depth` and recording
    each operation's first-seen index in `operation_indices`.

    Raises:
        ValueError: if a cycle is detected (a node is revisited while still
            in `nodes_in_progress`).
    """
    (
        operation,
        node_index,
        _,
    ) = tensor._keras_history
    if not operation:
        return

    node = operation._inbound_nodes[node_index]

    # Don't repeat work for shared subgraphs
    if node in finished_nodes:
        return

    # Prevent cycles.
    if node in nodes_in_progress:
        raise ValueError(
            f"Tensor {tensor} from operation '{operation.name}' is part of a "
            "cycle."
        )

    # Store the traversal order for operation sorting.
    if operation not in operation_indices:
        operation_indices[operation] = len(operation_indices)

    # Propagate to all previous tensors connected to this node.
    # Stop descending when `tensor` is one of the graph's declared inputs.
    nodes_in_progress.add(node)
    if not node.is_input and tensor not in tree.flatten(inputs):
        # NOTE(review): this loop deliberately rebinds the `tensor`
        # parameter; it is not used again after the loop.
        for tensor in node.input_tensors:
            _build_map_helper(
                inputs,
                tensor,
                finished_nodes,
                nodes_in_progress,
                nodes_in_decreasing_depth,
                operation_indices,
            )

    finished_nodes.add(node)
    nodes_in_progress.remove(node)
    nodes_in_decreasing_depth.append(node)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/image.py
ADDED
|
@@ -0,0 +1,1235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import backend
|
| 2 |
+
from keras.src import ops
|
| 3 |
+
from keras.src.api_export import keras_export
|
| 4 |
+
from keras.src.backend import KerasTensor
|
| 5 |
+
from keras.src.backend import any_symbolic_tensors
|
| 6 |
+
from keras.src.ops.operation import Operation
|
| 7 |
+
from keras.src.ops.operation_utils import compute_conv_output_shape
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class RGBToGrayscale(Operation):
    """Operation converting RGB images to single-channel grayscale."""

    def __init__(self, data_format=None):
        super().__init__()
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        """Run the backend grayscale conversion on concrete tensors."""
        return backend.image.rgb_to_grayscale(
            images, data_format=self.data_format
        )

    def compute_output_spec(self, images):
        """Return a `KerasTensor` spec: same shape, channel dim set to 1."""
        out_shape = list(images.shape)
        if len(out_shape) not in (3, 4):
            raise ValueError(
                "Invalid images rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). "
                f"Received: images.shape={out_shape}"
            )
        # Channel axis is last for "channels_last", third-from-last otherwise.
        channel_axis = -1 if self.data_format == "channels_last" else -3
        out_shape[channel_axis] = 1
        return KerasTensor(shape=out_shape, dtype=images.dtype)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@keras_export("keras.ops.image.rgb_to_grayscale")
def rgb_to_grayscale(images, data_format=None):
    """Convert RGB images to grayscale.

    This function converts RGB images to grayscale images. It supports both
    3D and 4D tensors.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        data_format: A string specifying the data format of the input tensor.
            It can be either `"channels_last"` or `"channels_first"`.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)`, while `"channels_first"`
            corresponds to inputs with shape `(batch, channels, height, width)`.
            If not specified, the value will default to
            `keras.config.image_data_format`.

    Returns:
        Grayscale image or batch of grayscale images.

    Examples:

    >>> import numpy as np
    >>> from keras import ops
    >>> x = np.random.random((2, 4, 4, 3))
    >>> y = ops.image.rgb_to_grayscale(x)
    >>> y.shape
    (2, 4, 4, 1)

    >>> x = np.random.random((4, 4, 3))  # Single RGB image
    >>> y = ops.image.rgb_to_grayscale(x)
    >>> y.shape
    (4, 4, 1)

    >>> x = np.random.random((2, 3, 4, 4))
    >>> y = ops.image.rgb_to_grayscale(x, data_format="channels_first")
    >>> y.shape
    (2, 1, 4, 4)
    """
    # Symbolic inputs go through the Operation so output specs are tracked;
    # eager inputs dispatch straight to the backend implementation.
    if any_symbolic_tensors((images,)):
        return RGBToGrayscale(data_format=data_format).symbolic_call(images)
    return backend.image.rgb_to_grayscale(images, data_format=data_format)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class RGBToHSV(Operation):
    """Operation converting RGB images to HSV (same shape, float dtype only)."""

    def __init__(self, data_format=None):
        super().__init__()
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        # Delegate to the active backend's implementation.
        return backend.image.rgb_to_hsv(images, data_format=self.data_format)

    def compute_output_spec(self, images):
        """Return a `KerasTensor` spec with the input's shape and dtype."""
        images_shape = list(images.shape)
        dtype = images.dtype
        if len(images_shape) not in (3, 4):
            raise ValueError(
                "Invalid images rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). "
                f"Received: images.shape={images_shape}"
            )
        # HSV conversion is only defined for float inputs.
        if not backend.is_float_dtype(dtype):
            raise ValueError(
                "Invalid images dtype: expected float dtype. "
                f"Received: images.dtype={dtype}"
            )
        return KerasTensor(shape=images_shape, dtype=images.dtype)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@keras_export("keras.ops.image.rgb_to_hsv")
def rgb_to_hsv(images, data_format=None):
    """Convert RGB images to HSV.

    `images` must be of float dtype, and the output is only well defined if the
    values in `images` are in `[0, 1]`.

    All HSV values are in `[0, 1]`. A hue of `0` corresponds to pure red, `1/3`
    is pure green, and `2/3` is pure blue.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        data_format: A string specifying the data format of the input tensor.
            It can be either `"channels_last"` or `"channels_first"`.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)`, while `"channels_first"`
            corresponds to inputs with shape `(batch, channels, height, width)`.
            If not specified, the value will default to
            `keras.config.image_data_format`.

    Returns:
        HSV image or batch of HSV images.

    Examples:

    >>> import numpy as np
    >>> from keras import ops
    >>> x = np.random.random((2, 4, 4, 3))
    >>> y = ops.image.rgb_to_hsv(x)
    >>> y.shape
    (2, 4, 4, 3)

    >>> x = np.random.random((4, 4, 3))  # Single RGB image
    >>> y = ops.image.rgb_to_hsv(x)
    >>> y.shape
    (4, 4, 3)

    >>> x = np.random.random((2, 3, 4, 4))
    >>> y = ops.image.rgb_to_hsv(x, data_format="channels_first")
    >>> y.shape
    (2, 3, 4, 4)
    """
    # Symbolic inputs go through the Operation so output specs are tracked;
    # eager inputs dispatch straight to the backend implementation.
    if any_symbolic_tensors((images,)):
        return RGBToHSV(data_format=data_format).symbolic_call(images)
    return backend.image.rgb_to_hsv(images, data_format=data_format)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class HSVToRGB(Operation):
    """Operation converting HSV images back to RGB (same shape, float dtype only)."""

    def __init__(self, data_format=None):
        super().__init__()
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        """Run the backend HSV-to-RGB conversion on concrete tensors."""
        return backend.image.hsv_to_rgb(images, data_format=self.data_format)

    def compute_output_spec(self, images):
        """Return a `KerasTensor` spec with the input's shape and dtype."""
        shape = list(images.shape)
        dtype = images.dtype
        rank_ok = len(shape) in (3, 4)
        if not rank_ok:
            raise ValueError(
                "Invalid images rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). "
                f"Received: images.shape={shape}"
            )
        # The conversion is only defined for float inputs.
        if not backend.is_float_dtype(dtype):
            raise ValueError(
                "Invalid images dtype: expected float dtype. "
                f"Received: images.dtype={dtype}"
            )
        return KerasTensor(shape=shape, dtype=images.dtype)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
@keras_export("keras.ops.image.hsv_to_rgb")
def hsv_to_rgb(images, data_format=None):
    """Convert HSV images to RGB.

    `images` must be of float dtype, and the output is only well defined if
    the values in `images` are in `[0, 1]`.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        data_format: A string specifying the data format of the input tensor.
            It can be either `"channels_last"` or `"channels_first"`.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)`, while `"channels_first"`
            corresponds to inputs with shape `(batch, channels, height, width)`.
            If not specified, the value will default to
            `keras.config.image_data_format`.

    Returns:
        RGB image or batch of RGB images.

    Examples:

    >>> import numpy as np
    >>> from keras import ops
    >>> x = np.random.random((2, 4, 4, 3))
    >>> y = ops.image.hsv_to_rgb(x)
    >>> y.shape
    (2, 4, 4, 3)

    >>> x = np.random.random((4, 4, 3))  # Single HSV image
    >>> y = ops.image.hsv_to_rgb(x)
    >>> y.shape
    (4, 4, 3)

    >>> x = np.random.random((2, 3, 4, 4))
    >>> y = ops.image.hsv_to_rgb(x, data_format="channels_first")
    >>> y.shape
    (2, 3, 4, 4)
    """
    # Eager inputs go straight to the backend; symbolic inputs are routed
    # through the Operation subclass so shape/dtype checks run.
    if not any_symbolic_tensors((images,)):
        return backend.image.hsv_to_rgb(images, data_format=data_format)
    return HSVToRGB(data_format=data_format).symbolic_call(images)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class Resize(Operation):
    """Symbolic wrapper for `keras.ops.image.resize`."""

    def __init__(
        self,
        size,
        interpolation="bilinear",
        antialias=False,
        crop_to_aspect_ratio=False,
        pad_to_aspect_ratio=False,
        fill_mode="constant",
        fill_value=0.0,
        data_format=None,
    ):
        super().__init__()
        self.size = tuple(size)
        self.interpolation = interpolation
        self.antialias = antialias
        self.crop_to_aspect_ratio = crop_to_aspect_ratio
        self.pad_to_aspect_ratio = pad_to_aspect_ratio
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        return _resize(
            images,
            self.size,
            interpolation=self.interpolation,
            antialias=self.antialias,
            data_format=self.data_format,
            crop_to_aspect_ratio=self.crop_to_aspect_ratio,
            pad_to_aspect_ratio=self.pad_to_aspect_ratio,
            fill_mode=self.fill_mode,
            fill_value=self.fill_value,
        )

    def compute_output_spec(self, images):
        # Only the spatial dimensions change; batch/channel dims pass through.
        output_shape = list(images.shape)
        if len(output_shape) not in (3, 4):
            raise ValueError(
                "Invalid images rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). Received input with shape: "
                f"images.shape={images.shape}"
            )
        if self.data_format == "channels_last":
            height_axis, width_axis = -3, -2
        else:
            height_axis, width_axis = -2, -1
        output_shape[height_axis], output_shape[width_axis] = self.size
        return KerasTensor(shape=output_shape, dtype=images.dtype)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
@keras_export("keras.ops.image.resize")
def resize(
    images,
    size,
    interpolation="bilinear",
    antialias=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    fill_mode="constant",
    fill_value=0.0,
    data_format=None,
):
    """Resize images to size using the specified interpolation method.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        size: Size of output image in `(height, width)` format.
        interpolation: Interpolation method. Available methods are `"nearest"`,
            `"bilinear"`, and `"bicubic"`. Defaults to `"bilinear"`.
        antialias: Whether to use an antialiasing filter when downsampling an
            image. Defaults to `False`.
        crop_to_aspect_ratio: If `True`, resize the images without aspect
            ratio distortion. When the original aspect ratio differs from the
            target aspect ratio, the output image will be cropped so as to
            return the largest possible window in the image (of size
            `(height, width)`) that matches the target aspect ratio. By
            default (`crop_to_aspect_ratio=False`), aspect ratio may not be
            preserved.
        pad_to_aspect_ratio: If `True`, pad the images without aspect ratio
            distortion. When the original aspect ratio differs from the
            target aspect ratio, the output image will be evenly padded on
            the short side.
        fill_mode: When using `pad_to_aspect_ratio=True`, padded areas are
            filled according to the given mode. Only `"constant"` is
            supported at this time
            (fill with constant value, equal to `fill_value`).
        fill_value: Float. Padding value to use when `pad_to_aspect_ratio=True`.
        data_format: A string specifying the data format of the input tensor.
            It can be either `"channels_last"` or `"channels_first"`.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)`, while `"channels_first"`
            corresponds to inputs with shape `(batch, channels, height, width)`.
            If not specified, the value will default to
            `keras.config.image_data_format`.

    Returns:
        Resized image or batch of images.

    Examples:

    >>> x = np.random.random((2, 4, 4, 3)) # batch of 2 RGB images
    >>> y = keras.ops.image.resize(x, (2, 2))
    >>> y.shape
    (2, 2, 2, 3)

    >>> x = np.random.random((4, 4, 3)) # single RGB image
    >>> y = keras.ops.image.resize(x, (2, 2))
    >>> y.shape
    (2, 2, 3)

    >>> x = np.random.random((2, 3, 4, 4)) # batch of 2 RGB images
    >>> y = keras.ops.image.resize(x, (2, 2),
    ...     data_format="channels_first")
    >>> y.shape
    (2, 3, 2, 2)
    """
    # Validate cheap preconditions before dispatching; the validation order
    # is preserved so callers see the same first error as before.
    if len(size) != 2:
        raise ValueError(
            "Expected `size` to be a tuple of 2 integers. "
            f"Received: size={size}"
        )
    if len(images.shape) not in (3, 4):
        raise ValueError(
            "Invalid images rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"images.shape={images.shape}"
        )
    if pad_to_aspect_ratio and crop_to_aspect_ratio:
        raise ValueError(
            "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` "
            "can be `True`."
        )
    # Both dispatch targets accept the same keyword set.
    shared_kwargs = dict(
        interpolation=interpolation,
        antialias=antialias,
        data_format=data_format,
        crop_to_aspect_ratio=crop_to_aspect_ratio,
        pad_to_aspect_ratio=pad_to_aspect_ratio,
        fill_mode=fill_mode,
        fill_value=fill_value,
    )
    if any_symbolic_tensors((images,)):
        return Resize(size, **shared_kwargs).symbolic_call(images)
    return _resize(images, size, **shared_kwargs)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def _resize(
    images,
    size,
    interpolation="bilinear",
    antialias=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    fill_mode="constant",
    fill_value=0.0,
    data_format=None,
):
    """Eager resize that restores the input's dtype on the result.

    Backends may return a float tensor even for integer inputs; this helper
    rounds and saturate-casts back so callers always get the input dtype.
    """
    resized = backend.image.resize(
        images,
        size,
        interpolation=interpolation,
        antialias=antialias,
        crop_to_aspect_ratio=crop_to_aspect_ratio,
        data_format=data_format,
        pad_to_aspect_ratio=pad_to_aspect_ratio,
        fill_mode=fill_mode,
        fill_value=fill_value,
    )
    if resized.dtype != images.dtype:
        # Only `torch` backend will cast result to original dtype with
        # correct rounding and without dtype overflow; emulate that here.
        if backend.is_int_dtype(images.dtype):
            resized = ops.round(resized)
        resized = ops.saturate_cast(resized, images.dtype)
    return resized
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class AffineTransform(Operation):
    """Symbolic wrapper for `keras.ops.image.affine_transform`."""

    def __init__(
        self,
        interpolation="bilinear",
        fill_mode="constant",
        fill_value=0,
        data_format=None,
    ):
        super().__init__()
        self.interpolation = interpolation
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images, transform):
        return backend.image.affine_transform(
            images,
            transform,
            interpolation=self.interpolation,
            fill_mode=self.fill_mode,
            fill_value=self.fill_value,
            data_format=self.data_format,
        )

    def compute_output_spec(self, images, transform):
        # An affine transform resamples in place: output shape equals the
        # input shape. Validate both operand ranks before returning the spec.
        images_rank = len(images.shape)
        if images_rank not in (3, 4):
            raise ValueError(
                "Invalid images rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). Received input with shape: "
                f"images.shape={images.shape}"
            )
        transform_rank = len(transform.shape)
        if transform_rank not in (1, 2):
            raise ValueError(
                "Invalid transform rank: expected rank 1 (single transform) "
                "or rank 2 (batch of transforms). Received input with shape: "
                f"transform.shape={transform.shape}"
            )
        return KerasTensor(images.shape, dtype=images.dtype)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
@keras_export("keras.ops.image.affine_transform")
def affine_transform(
    images,
    transform,
    interpolation="bilinear",
    fill_mode="constant",
    fill_value=0,
    data_format=None,
):
    """Applies the given transform(s) to the image(s).

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        transform: Projective transform matrix/matrices. A vector of length 8 or
            tensor of size N x 8. If one row of transform is
            `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps the output point
            `(x, y)` to a transformed input point
            `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
            where `k = c0 x + c1 y + 1`. The transform is inverted compared to
            the transform mapping input points to output points. Note that
            gradients are not backpropagated into transformation parameters.
            Note that `c0` and `c1` are only effective when using TensorFlow
            backend and will be considered as `0` when using other backends.
        interpolation: Interpolation method. Available methods are `"nearest"`,
            and `"bilinear"`. Defaults to `"bilinear"`.
        fill_mode: Points outside the boundaries of the input are filled
            according to the given mode. Available methods are `"constant"`,
            `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`.
            - `"reflect"`: `(d c b a | a b c d | d c b a)`
                The input is extended by reflecting about the edge of the last
                pixel.
            - `"constant"`: `(k k k k | a b c d | k k k k)`
                The input is extended by filling all values beyond
                the edge with the same constant value k specified by
                `fill_value`.
            - `"wrap"`: `(a b c d | a b c d | a b c d)`
                The input is extended by wrapping around to the opposite edge.
            - `"nearest"`: `(a a a a | a b c d | d d d d)`
                The input is extended by the nearest pixel.
        fill_value: Value used for points outside the boundaries of the input if
            `fill_mode="constant"`. Defaults to `0`.
        data_format: A string specifying the data format of the input tensor.
            It can be either `"channels_last"` or `"channels_first"`.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)`, while `"channels_first"`
            corresponds to inputs with shape `(batch, channels, height, width)`.
            If not specified, the value will default to
            `keras.config.image_data_format`.

    Returns:
        Applied affine transform image or batch of images.

    Examples:

    >>> x = np.random.random((2, 64, 80, 3)) # batch of 2 RGB images
    >>> transform = np.array(
    ...     [
    ...         [1.5, 0, -20, 0, 1.5, -16, 0, 0],  # zoom
    ...         [1, 0, -20, 0, 1, -16, 0, 0],  # translation
    ...     ]
    ... )
    >>> y = keras.ops.image.affine_transform(x, transform)
    >>> y.shape
    (2, 64, 80, 3)

    >>> x = np.random.random((64, 80, 3)) # single RGB image
    >>> transform = np.array([1.0, 0.5, -20, 0.5, 1.0, -16, 0, 0])  # shear
    >>> y = keras.ops.image.affine_transform(x, transform)
    >>> y.shape
    (64, 80, 3)

    >>> x = np.random.random((2, 3, 64, 80)) # batch of 2 RGB images
    >>> transform = np.array(
    ...     [
    ...         [1.5, 0, -20, 0, 1.5, -16, 0, 0],  # zoom
    ...         [1, 0, -20, 0, 1, -16, 0, 0],  # translation
    ...     ]
    ... )
    >>> y = keras.ops.image.affine_transform(x, transform,
    ...     data_format="channels_first")
    >>> y.shape
    (2, 3, 64, 80)
    """
    # Eager path goes straight to the backend; symbolic inputs run through
    # the Operation subclass so rank validation happens at trace time.
    if not any_symbolic_tensors((images, transform)):
        return backend.image.affine_transform(
            images,
            transform,
            interpolation=interpolation,
            fill_mode=fill_mode,
            fill_value=fill_value,
            data_format=data_format,
        )
    return AffineTransform(
        interpolation=interpolation,
        fill_mode=fill_mode,
        fill_value=fill_value,
        data_format=data_format,
    ).symbolic_call(images, transform)
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
class ExtractPatches(Operation):
    """Symbolic wrapper for `keras.ops.image.extract_patches`."""

    def __init__(
        self,
        size,
        strides=None,
        dilation_rate=1,
        padding="valid",
        data_format=None,
    ):
        super().__init__()
        # Normalize an int patch size to a (height, width) pair.
        if isinstance(size, int):
            size = (size, size)
        self.size = size
        self.strides = strides
        self.dilation_rate = dilation_rate
        self.padding = padding
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        return _extract_patches(
            images=images,
            size=self.size,
            strides=self.strides,
            dilation_rate=self.dilation_rate,
            padding=self.padding,
            data_format=self.data_format,
        )

    def compute_output_spec(self, images):
        """Compute the output spec by modeling patch extraction as a conv.

        Patch extraction is implemented (see `_extract_patches`) as a
        convolution with an identity kernel, so the output shape is exactly
        a conv output shape with `size[0] * size[1] * channels` filters.
        """
        images_shape = list(images.shape)
        original_ndim = len(images_shape)
        # Bug fix: previously `strides` was only assigned inside
        # `if not self.strides:`, so any user-supplied stride raised
        # UnboundLocalError at the `compute_conv_output_shape` call below.
        # Mirror `_extract_patches`: default to the patch size when no
        # stride was given.
        if not self.strides:
            strides = (self.size[0], self.size[1])
        else:
            strides = self.strides
        if self.data_format == "channels_last":
            channels_in = images_shape[-1]
        else:
            channels_in = images_shape[-3]
        if original_ndim == 3:
            # Temporarily add a batch dim so conv shape inference applies.
            images_shape = [1] + images_shape
        filters = self.size[0] * self.size[1] * channels_in
        kernel_size = (self.size[0], self.size[1])
        out_shape = compute_conv_output_shape(
            images_shape,
            filters,
            kernel_size,
            strides=strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )
        if original_ndim == 3:
            # Drop the synthetic batch dim for unbatched inputs.
            out_shape = out_shape[1:]
        return KerasTensor(shape=out_shape, dtype=images.dtype)
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
@keras_export("keras.ops.image.extract_patches")
def extract_patches(
    images,
    size,
    strides=None,
    dilation_rate=1,
    padding="valid",
    data_format=None,
):
    """Extracts patches from the image(s).

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        size: Patch size int or tuple (patch_height, patch_width)
        strides: strides along height and width. If not specified, or
            if `None`, it defaults to the same value as `size`.
        dilation_rate: This is the input stride, specifying how far two
            consecutive patch samples are in the input. For value other than 1,
            strides must be 1. NOTE: `strides > 1` is not supported in
            conjunction with `dilation_rate > 1`
        padding: The type of padding algorithm to use: `"same"` or `"valid"`.
        data_format: A string specifying the data format of the input tensor.
            It can be either `"channels_last"` or `"channels_first"`.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)`, while `"channels_first"`
            corresponds to inputs with shape `(batch, channels, height, width)`.
            If not specified, the value will default to
            `keras.config.image_data_format`.

    Returns:
        Extracted patches 3D (if not batched) or 4D (if batched)

    Examples:

    >>> image = np.random.random(
    ...     (2, 20, 20, 3)
    ... ).astype("float32")  # batch of 2 RGB images
    >>> patches = keras.ops.image.extract_patches(image, (5, 5))
    >>> patches.shape
    (2, 4, 4, 75)
    >>> image = np.random.random((20, 20, 3)).astype("float32")  # 1 RGB image
    >>> patches = keras.ops.image.extract_patches(image, (3, 3), (1, 1))
    >>> patches.shape
    (18, 18, 27)
    """
    # Symbolic inputs run through the Operation subclass for shape
    # inference; eager inputs call the helper directly.
    if not any_symbolic_tensors((images,)):
        return _extract_patches(
            images, size, strides, dilation_rate, padding,
            data_format=data_format,
        )
    return ExtractPatches(
        size=size,
        strides=strides,
        dilation_rate=dilation_rate,
        padding=padding,
        data_format=data_format,
    ).symbolic_call(images)
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def _extract_patches(
    images,
    size,
    strides=None,
    dilation_rate=1,
    padding="valid",
    data_format=None,
):
    """Eager patch extraction implemented as a conv with an identity kernel.

    Each output "pixel" holds one flattened patch of
    `patch_h * patch_w * channels` values.
    """
    if isinstance(size, int):
        patch_h = patch_w = size
    elif len(size) == 2:
        patch_h, patch_w = size[0], size[1]
    else:
        raise TypeError(
            "Invalid `size` argument. Expected an "
            f"int or a tuple of length 2. Received: size={size}"
        )
    data_format = backend.standardize_data_format(data_format)
    if data_format == "channels_last":
        channels_in = images.shape[-1]
    elif data_format == "channels_first":
        channels_in = images.shape[-3]
    if not strides:
        # Default: non-overlapping patches (stride equals the patch size).
        strides = size
    # Identity kernel: each output channel picks one patch element.
    num_patch_values = patch_h * patch_w * channels_in
    identity = backend.numpy.eye(num_patch_values, dtype=images.dtype)
    identity = backend.numpy.reshape(
        identity, (patch_h, patch_w, channels_in, num_patch_values)
    )
    needs_batch_dim = len(images.shape) == 3
    if needs_batch_dim:
        images = backend.numpy.expand_dims(images, axis=0)
    patches = backend.nn.conv(
        inputs=images,
        kernel=identity,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilation_rate=dilation_rate,
    )
    if needs_batch_dim:
        patches = backend.numpy.squeeze(patches, axis=0)
    return patches
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
class MapCoordinates(Operation):
    """Symbolic wrapper for `keras.ops.image.map_coordinates`."""

    def __init__(self, order, fill_mode="constant", fill_value=0):
        super().__init__()
        self.order = order
        self.fill_mode = fill_mode
        self.fill_value = fill_value

    def call(self, inputs, coordinates):
        return backend.image.map_coordinates(
            inputs,
            coordinates,
            order=self.order,
            fill_mode=self.fill_mode,
            fill_value=self.fill_value,
        )

    def compute_output_spec(self, inputs, coordinates):
        # `coordinates` stacks one index array per input axis along its
        # leading dim, so that dim must equal the rank of `inputs` and the
        # output takes the shape of the remaining coordinate dims.
        if coordinates.shape[0] != len(inputs.shape):
            raise ValueError(
                "First dim of `coordinates` must be the same as the rank of "
                "`inputs`. "
                f"Received inputs with shape: {inputs.shape} and coordinate "
                f"leading dim of {coordinates.shape[0]}"
            )
        if len(coordinates.shape) < 2:
            raise ValueError(
                "Invalid coordinates rank: expected at least rank 2."
                f" Received input with shape: {coordinates.shape}"
            )
        return KerasTensor(coordinates.shape[1:], dtype=inputs.dtype)
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
@keras_export("keras.ops.image.map_coordinates")
def map_coordinates(
    inputs, coordinates, order, fill_mode="constant", fill_value=0
):
    """Map the input array to new coordinates by interpolation.

    Note that interpolation near boundaries differs from the scipy function,
    because we fixed an outstanding bug
    [scipy/issues/2640](https://github.com/scipy/scipy/issues/2640).

    Args:
        inputs: The input array.
        coordinates: The coordinates at which inputs is evaluated.
        order: The order of the spline interpolation. The order must be `0` or
            `1`. `0` indicates the nearest neighbor and `1` indicates the linear
            interpolation.
        fill_mode: Points outside the boundaries of the inputs are filled
            according to the given mode. Available methods are `"constant"`,
            `"nearest"`, `"wrap"` and `"mirror"` and `"reflect"`. Defaults to
            `"constant"`.
            - `"constant"`: `(k k k k | a b c d | k k k k)`
                The inputs is extended by filling all values beyond
                the edge with the same constant value k specified by
                `fill_value`.
            - `"nearest"`: `(a a a a | a b c d | d d d d)`
                The inputs is extended by the nearest pixel.
            - `"wrap"`: `(a b c d | a b c d | a b c d)`
                The inputs is extended by wrapping around to the opposite edge.
            - `"mirror"`: `(c d c b | a b c d | c b a b)`
                The inputs is extended by mirroring about the edge.
            - `"reflect"`: `(d c b a | a b c d | d c b a)`
                The inputs is extended by reflecting about the edge of the last
                pixel.
        fill_value: Value used for points outside the boundaries of the inputs
            if `fill_mode="constant"`. Defaults to `0`.

    Returns:
        Output input or batch of inputs.

    """
    # Symbolic inputs go through the Operation subclass (which validates
    # the coordinates rank); eager inputs call the backend directly.
    if not any_symbolic_tensors((inputs, coordinates)):
        return backend.image.map_coordinates(
            inputs,
            coordinates,
            order,
            fill_mode,
            fill_value,
        )
    return MapCoordinates(
        order,
        fill_mode,
        fill_value,
    ).symbolic_call(inputs, coordinates)
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
class PadImages(Operation):
    """Symbolic wrapper for `keras.ops.image.pad_images`."""

    def __init__(
        self,
        top_padding=None,
        left_padding=None,
        bottom_padding=None,
        right_padding=None,
        target_height=None,
        target_width=None,
        data_format=None,
    ):
        super().__init__()
        self.top_padding = top_padding
        self.left_padding = left_padding
        self.bottom_padding = bottom_padding
        self.right_padding = right_padding
        self.target_height = target_height
        self.target_width = target_width
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        return _pad_images(
            images,
            self.top_padding,
            self.left_padding,
            self.bottom_padding,
            self.right_padding,
            self.target_height,
            self.target_width,
            self.data_format,
        )

    def compute_output_spec(self, images):
        # Output spatial dims are either the explicit targets or, when a
        # target is omitted, derived as padding + input dim + padding.
        out_shape = list(images.shape)

        if self.data_format == "channels_last":
            height_axis, width_axis = -3, -2
        else:
            height_axis, width_axis = -2, -1
        height = out_shape[height_axis]
        width = out_shape[width_axis]

        new_height = self.target_height
        if new_height is None and height is not None:
            new_height = self.top_padding + height + self.bottom_padding
        new_width = self.target_width
        if new_width is None and width is not None:
            new_width = self.left_padding + width + self.right_padding

        out_shape[height_axis] = new_height
        out_shape[width_axis] = new_width
        return KerasTensor(shape=out_shape, dtype=images.dtype)
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
@keras_export("keras.ops.image.pad_images")
def pad_images(
    images,
    top_padding=None,
    left_padding=None,
    bottom_padding=None,
    right_padding=None,
    target_height=None,
    target_width=None,
    data_format=None,
):
    """Pad `images` with zeros to the specified `height` and `width`.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        top_padding: Number of rows of zeros to add on top.
        left_padding: Number of columns of zeros to add on the left.
        bottom_padding: Number of rows of zeros to add at the bottom.
        right_padding: Number of columns of zeros to add on the right.
        target_height: Height of output images.
        target_width: Width of output images.
        data_format: A string specifying the data format of the input tensor.
            It can be either `"channels_last"` or `"channels_first"`.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)`, while `"channels_first"`
            corresponds to inputs with shape `(batch, channels, height, width)`.
            If not specified, the value will default to
            `keras.config.image_data_format`.

    Returns:
        Padded image or batch of images.

    Example:

    >>> images = np.random.random((15, 25, 3))
    >>> padded_images = keras.ops.image.pad_images(
    ...     images, 2, 3, target_height=20, target_width=30
    ... )
    >>> padded_images.shape
    (20, 30, 3)

    >>> batch_images = np.random.random((2, 15, 25, 3))
    >>> padded_batch = keras.ops.image.pad_images(
    ...     batch_images, 2, 3, target_height=20, target_width=30
    ... )
    >>> padded_batch.shape
    (2, 20, 30, 3)
    """
    # Both dispatch targets take the same positional arguments after
    # `images`, so pack them once.
    pad_args = (
        top_padding,
        left_padding,
        bottom_padding,
        right_padding,
        target_height,
        target_width,
        data_format,
    )
    if any_symbolic_tensors((images,)):
        return PadImages(*pad_args).symbolic_call(images)
    return _pad_images(images, *pad_args)
|
| 919 |
+
|
| 920 |
+
|
| 921 |
+
def _pad_images(
    images,
    top_padding,
    left_padding,
    bottom_padding,
    right_padding,
    target_height,
    target_width,
    data_format=None,
):
    """Eager implementation of `keras.ops.image.pad_images`.

    Exactly one of (`top_padding`, `bottom_padding`, `target_height`) and
    exactly one of (`left_padding`, `right_padding`, `target_width`) may be
    `None`; the missing value is inferred from the other two and the input
    size.

    Args:
        images: 3D or 4D tensor (single image or batch).
        top_padding: Rows of zeros to add at the top, or `None` to infer.
        left_padding: Columns of zeros to add on the left, or `None` to infer.
        bottom_padding: Rows of zeros to add at the bottom, or `None` to infer.
        right_padding: Columns of zeros to add on the right, or `None` to
            infer.
        target_height: Output height, or `None` to infer.
        target_width: Output width, or `None` to infer.
        data_format: `"channels_last"` or `"channels_first"`; defaults to
            `keras.config.image_data_format`.

    Returns:
        Zero-padded image or batch of images.

    Raises:
        ValueError: If the rank is not 3 or 4, if the `None` contract above is
            violated, or if any inferred padding is negative.
    """
    data_format = backend.standardize_data_format(data_format)
    images = backend.convert_to_tensor(images)
    images_shape = ops.shape(images)

    # Check
    if len(images_shape) not in (3, 4):
        raise ValueError(
            "Invalid shape for argument `images`: "
            "it must have rank 3 or 4. "
            f"Received: images.shape={images_shape}"
        )
    if [top_padding, bottom_padding, target_height].count(None) != 1:
        raise ValueError(
            "Must specify exactly two of "
            "top_padding, bottom_padding, target_height. "
            f"Received: top_padding={top_padding}, "
            f"bottom_padding={bottom_padding}, "
            f"target_height={target_height}"
        )
    if [left_padding, right_padding, target_width].count(None) != 1:
        raise ValueError(
            "Must specify exactly two of "
            "left_padding, right_padding, target_width. "
            f"Received: left_padding={left_padding}, "
            f"right_padding={right_padding}, "
            f"target_width={target_width}"
        )

    is_batch = False if len(images_shape) == 3 else True
    if data_format == "channels_last":
        height, width = images_shape[-3], images_shape[-2]
    else:
        height, width = images_shape[-2], images_shape[-1]

    # Infer padding
    if top_padding is None:
        top_padding = target_height - bottom_padding - height
    if bottom_padding is None:
        bottom_padding = target_height - top_padding - height
    if left_padding is None:
        left_padding = target_width - right_padding - width
    if right_padding is None:
        right_padding = target_width - left_padding - width

    if top_padding < 0:
        raise ValueError(
            f"top_padding must be >= 0. Received: top_padding={top_padding}"
        )
    if left_padding < 0:
        raise ValueError(
            "left_padding must be >= 0. "
            f"Received: left_padding={left_padding}"
        )
    if right_padding < 0:
        raise ValueError(
            "right_padding must be >= 0. "
            f"Received: right_padding={right_padding}"
        )
    if bottom_padding < 0:
        raise ValueError(
            "bottom_padding must be >= 0. "
            f"Received: bottom_padding={bottom_padding}"
        )

    # Compute pad_width in the spatial-axis order, then splice in the
    # channel axis (and the batch axis, if present) as zero padding.
    pad_width = [[top_padding, bottom_padding], [left_padding, right_padding]]
    if data_format == "channels_last":
        pad_width = pad_width + [[0, 0]]
    else:
        pad_width = [[0, 0]] + pad_width
    if is_batch:
        pad_width = [[0, 0]] + pad_width

    padded_images = backend.numpy.pad(images, pad_width)
    return padded_images
|
| 1006 |
+
|
| 1007 |
+
|
| 1008 |
+
class CropImages(Operation):
    """Symbolic counterpart of `keras.ops.image.crop_images`.

    Stores the cropping configuration so the op can be applied to symbolic
    (`KerasTensor`) inputs and its output shape inferred without executing.
    """

    def __init__(
        self,
        top_cropping,
        left_cropping,
        bottom_cropping,
        right_cropping,
        target_height,
        target_width,
        data_format=None,
    ):
        super().__init__()
        self.top_cropping = top_cropping
        self.bottom_cropping = bottom_cropping
        self.left_cropping = left_cropping
        self.right_cropping = right_cropping
        self.target_height = target_height
        self.target_width = target_width
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        # Delegate to the eager implementation with the stored configuration.
        return _crop_images(
            images,
            self.top_cropping,
            self.left_cropping,
            self.bottom_cropping,
            self.right_cropping,
            self.target_height,
            self.target_width,
            self.data_format,
        )

    def compute_output_spec(self, images):
        """Infer the cropped output shape without executing the op."""
        images_shape = list(images.shape)

        # Height/width axis positions depend on the data format.
        if self.data_format == "channels_last":
            height_axis, width_axis = -3, -2
        else:
            height_axis, width_axis = -2, -1
        height, width = images_shape[height_axis], images_shape[width_axis]

        if height is None and self.target_height is None:
            raise ValueError(
                "When the height of the images is unknown, `target_height` "
                "must be specified."
                f"Received images.shape={images_shape} and "
                f"target_height={self.target_height}"
            )
        if width is None and self.target_width is None:
            raise ValueError(
                "When the width of the images is unknown, `target_width` "
                "must be specified."
                f"Received images.shape={images_shape} and "
                f"target_width={self.target_width}"
            )

        # When no explicit target size is given, derive it from the crops.
        target_height = self.target_height
        if target_height is None:
            target_height = height - self.top_cropping - self.bottom_cropping
        target_width = self.target_width
        if target_width is None:
            target_width = width - self.left_cropping - self.right_cropping

        images_shape[height_axis] = target_height
        images_shape[width_axis] = target_width
        return KerasTensor(shape=images_shape, dtype=images.dtype)
|
| 1074 |
+
|
| 1075 |
+
|
| 1076 |
+
@keras_export("keras.ops.image.crop_images")
def crop_images(
    images,
    top_cropping=None,
    left_cropping=None,
    bottom_cropping=None,
    right_cropping=None,
    target_height=None,
    target_width=None,
    data_format=None,
):
    """Crop `images` to a specified `height` and `width`.

    Args:
        images: Input image or batch of images. Must be 3D or 4D.
        top_cropping: Number of rows to crop from the top.
        left_cropping: Number of columns to crop from the left.
        bottom_cropping: Number of rows to crop from the bottom.
        right_cropping: Number of columns to crop from the right.
        target_height: Height of the output images.
        target_width: Width of the output images.
        data_format: A string specifying the data format of the input tensor.
            It can be either `"channels_last"` or `"channels_first"`.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)`, while `"channels_first"`
            corresponds to inputs with shape `(batch, channels, height, width)`.
            If not specified, the value will default to
            `keras.config.image_data_format`.

    Returns:
        Cropped image or batch of images.

    Example:

    >>> images = np.reshape(np.arange(1, 28, dtype="float32"), [3, 3, 3])
    >>> images[:,:,0] # print the first channel of the images
    array([[ 1.,  4.,  7.],
           [10., 13., 16.],
           [19., 22., 25.]], dtype=float32)
    >>> cropped_images = keras.ops.image.crop_images(images, 0, 0, 2, 2)
    >>> cropped_images[:,:,0] # print the first channel of the cropped images
    array([[ 1.,  4.],
           [10., 13.]], dtype=float32)
    """
    if any_symbolic_tensors((images,)):
        return CropImages(
            top_cropping,
            left_cropping,
            bottom_cropping,
            right_cropping,
            target_height,
            target_width,
            data_format,
        ).symbolic_call(images)

    return _crop_images(
        images,
        top_cropping,
        left_cropping,
        bottom_cropping,
        right_cropping,
        target_height,
        target_width,
        data_format,
    )


def _crop_images(
    images,
    top_cropping,
    left_cropping,
    bottom_cropping,
    right_cropping,
    target_height,
    target_width,
    data_format=None,
):
    """Eager implementation of `keras.ops.image.crop_images`.

    Exactly one of (`top_cropping`, `bottom_cropping`, `target_height`) and
    exactly one of (`left_cropping`, `right_cropping`, `target_width`) may be
    `None`; the missing value is inferred from the other two and the input
    size.
    """
    data_format = backend.standardize_data_format(data_format)
    images = backend.convert_to_tensor(images)
    images_shape = ops.shape(images)

    # Check
    if len(images_shape) not in (3, 4):
        raise ValueError(
            "Invalid shape for argument `images`: "
            "it must have rank 3 or 4. "
            f"Received: images.shape={images_shape}"
        )
    if [top_cropping, bottom_cropping, target_height].count(None) != 1:
        raise ValueError(
            "Must specify exactly two of "
            "top_cropping, bottom_cropping, target_height. "
            f"Received: top_cropping={top_cropping}, "
            f"bottom_cropping={bottom_cropping}, "
            f"target_height={target_height}"
        )
    if [left_cropping, right_cropping, target_width].count(None) != 1:
        raise ValueError(
            "Must specify exactly two of "
            "left_cropping, right_cropping, target_width. "
            f"Received: left_cropping={left_cropping}, "
            f"right_cropping={right_cropping}, "
            f"target_width={target_width}"
        )

    is_batch = False if len(images_shape) == 3 else True
    if data_format == "channels_last":
        height, width = images_shape[-3], images_shape[-2]
        channels = images_shape[-1]
    else:
        height, width = images_shape[-2], images_shape[-1]
        channels = images_shape[-3]

    # Infer cropping
    if top_cropping is None:
        top_cropping = height - target_height - bottom_cropping
    if target_height is None:
        target_height = height - bottom_cropping - top_cropping
    if left_cropping is None:
        left_cropping = width - target_width - right_cropping
    if target_width is None:
        target_width = width - right_cropping - left_cropping

    if top_cropping < 0:
        raise ValueError(
            "top_cropping must be >= 0. "
            f"Received: top_cropping={top_cropping}"
        )
    if target_height < 0:
        raise ValueError(
            "target_height must be >= 0. "
            f"Received: target_height={target_height}"
        )
    if left_cropping < 0:
        raise ValueError(
            "left_cropping must be >= 0. "
            f"Received: left_cropping={left_cropping}"
        )
    if target_width < 0:
        raise ValueError(
            "target_width must be >= 0. "
            f"Received: target_width={target_width}"
        )

    # Compute start_indices and shape for the slice in spatial order,
    # then splice in the channel axis (and batch axis, if present).
    start_indices = [top_cropping, left_cropping]
    shape = [target_height, target_width]
    if data_format == "channels_last":
        start_indices = start_indices + [0]
        shape = shape + [channels]
    else:
        start_indices = [0] + start_indices
        shape = [channels] + shape
    if is_batch:
        batch_size = images_shape[0]
        start_indices = [0] + start_indices
        shape = [batch_size] + shape

    cropped_images = ops.slice(images, start_indices, shape)
    return cropped_images
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/linalg.py
ADDED
|
@@ -0,0 +1,707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import backend
|
| 2 |
+
from keras.src.api_export import keras_export
|
| 3 |
+
from keras.src.backend import KerasTensor
|
| 4 |
+
from keras.src.backend import any_symbolic_tensors
|
| 5 |
+
from keras.src.ops.operation import Operation
|
| 6 |
+
from keras.src.ops.operation_utils import reduce_shape
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Cholesky(Operation):
    """Symbolic op wrapping `backend.linalg.cholesky`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return _cholesky(x)

    def compute_output_spec(self, x):
        # The Cholesky factor has the same shape and dtype as the input.
        _assert_2d(x)
        _assert_square(x)
        return KerasTensor(x.shape, x.dtype)


@keras_export(["keras.ops.cholesky", "keras.ops.linalg.cholesky"])
def cholesky(x):
    """Computes the Cholesky decomposition of a positive semi-definite matrix.

    Args:
        x: Input tensor of shape `(..., M, M)`.

    Returns:
        A tensor of shape `(..., M, M)` representing the lower triangular
        Cholesky factor of `x`.

    """
    if any_symbolic_tensors((x,)):
        return Cholesky().symbolic_call(x)
    return _cholesky(x)


def _cholesky(x):
    """Eager implementation: validate the input and call the backend."""
    x = backend.convert_to_tensor(x)
    _assert_2d(x)
    _assert_square(x)
    try:
        return backend.linalg.cholesky(x)
    except Exception as e:
        # Chain the backend exception so the original traceback is preserved.
        raise ValueError(f"Cholesky decomposition failed: {e}") from e
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Det(Operation):
    """Symbolic op wrapping `backend.linalg.det`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return _det(x)

    def compute_output_spec(self, x):
        _assert_2d(x)
        _assert_square(x)
        # The determinant collapses the trailing (M, M) matrix dimensions.
        return KerasTensor(x.shape[:-2], x.dtype)


@keras_export(["keras.ops.det", "keras.ops.linalg.det"])
def det(x):
    """Computes the determinant of a square tensor.

    Args:
        x: Input tensor of shape `(..., M, M)`.

    Returns:
        A tensor of shape `(...,)` representing the determinant of `x`.

    """
    is_symbolic = any_symbolic_tensors((x,))
    return Det().symbolic_call(x) if is_symbolic else _det(x)


def _det(x):
    # Eager path: validate, then delegate to the backend kernel.
    tensor = backend.convert_to_tensor(x)
    _assert_2d(tensor)
    _assert_square(tensor)
    return backend.linalg.det(tensor)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class Eig(Operation):
    """Symbolic op wrapping `backend.linalg.eig`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return _eig(x)

    def compute_output_spec(self, x):
        _assert_square(x)
        _assert_2d(x)
        # Eigenvalues drop the last axis; eigenvectors keep the full shape.
        eigenvalues_spec = KerasTensor(x.shape[:-1], x.dtype)
        eigenvectors_spec = KerasTensor(x.shape, x.dtype)
        return (eigenvalues_spec, eigenvectors_spec)


@keras_export(["keras.ops.eig", "keras.ops.linalg.eig"])
def eig(x):
    """Computes the eigenvalues and eigenvectors of a square matrix.

    Args:
        x: Input tensor of shape `(..., M, M)`.

    Returns:
        A tuple of two tensors: a tensor of shape `(..., M)` containing
        eigenvalues and a tensor of shape `(..., M, M)` containing eigenvectors.
    """
    is_symbolic = any_symbolic_tensors((x,))
    return Eig().symbolic_call(x) if is_symbolic else _eig(x)


def _eig(x):
    # Eager path: validate, then delegate to the backend kernel.
    tensor = backend.convert_to_tensor(x)
    _assert_square(tensor)
    _assert_2d(tensor)
    return backend.linalg.eig(tensor)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Eigh(Operation):
    """Symbolic op wrapping `backend.linalg.eigh`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return _eigh(x)

    def compute_output_spec(self, x):
        _assert_square(x)
        _assert_2d(x)
        # Eigenvalues drop the last axis; eigenvectors keep the full shape.
        return (
            KerasTensor(x.shape[:-1], x.dtype),
            KerasTensor(x.shape, x.dtype),
        )


@keras_export(["keras.ops.eigh", "keras.ops.linalg.eigh"])
def eigh(x):
    """Computes the eigenvalues and eigenvectors of a complex Hermitian matrix.

    Args:
        x: Input tensor of shape `(..., M, M)`.

    Returns:
        A tuple of two tensors: a tensor of shape `(..., M)` containing
        eigenvalues and a tensor of shape `(..., M, M)` containing eigenvectors.

    """
    if any_symbolic_tensors((x,)):
        return Eigh().symbolic_call(x)
    return _eigh(x)


def _eigh(x):
    """Eager implementation: validate the input and call the backend."""
    x = backend.convert_to_tensor(x)
    _assert_square(x)
    _assert_2d(x)
    return backend.linalg.eigh(x)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class Inv(Operation):
    """Symbolic op wrapping `backend.linalg.inv`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return _inv(x)

    def compute_output_spec(self, x):
        _assert_2d(x)
        _assert_square(x)
        # The inverse has the same shape and dtype as the input.
        return KerasTensor(x.shape, x.dtype)


@keras_export(["keras.ops.inv", "keras.ops.linalg.inv"])
def inv(x):
    """Computes the inverse of a square tensor.

    Args:
        x: Input tensor of shape `(..., M, M)`.

    Returns:
        A tensor of shape `(..., M, M)` representing the inverse of `x`.

    """
    is_symbolic = any_symbolic_tensors((x,))
    return Inv().symbolic_call(x) if is_symbolic else _inv(x)


def _inv(x):
    # Eager path: validate, then delegate to the backend kernel.
    tensor = backend.convert_to_tensor(x)
    _assert_2d(tensor)
    _assert_square(tensor)
    return backend.linalg.inv(tensor)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class LuFactor(Operation):
    """Symbolic op wrapping `backend.linalg.lu_factor`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return _lu_factor(x)

    def compute_output_spec(self, x):
        _assert_2d(x)
        batch_shape = x.shape[:-2]
        m, n = x.shape[-2:]
        # The pivot vector has min(M, N) entries.
        k = min(m, n)
        return (
            KerasTensor(batch_shape + (m, n), x.dtype),
            KerasTensor(batch_shape + (k,), x.dtype),
        )


@keras_export(["keras.ops.lu_factor", "keras.ops.linalg.lu_factor"])
def lu_factor(x):
    """Computes the lower-upper decomposition of a square matrix.

    Args:
        x: A tensor of shape `(..., M, M)`.

    Returns:
        A tuple of two tensors: a tensor of shape `(..., M, M)` containing the
        lower and upper triangular matrices and a tensor of shape `(..., M)`
        containing the pivots.

    """
    if any_symbolic_tensors((x,)):
        return LuFactor().symbolic_call(x)
    return _lu_factor(x)


def _lu_factor(x):
    """Eager implementation: validate the input and call the backend."""
    x = backend.convert_to_tensor(x)
    _assert_2d(x)
    if backend.backend() == "tensorflow":
        # TensorFlow's LU kernel only accepts square matrices.
        try:
            _assert_square(x)
        except ValueError as e:
            # Chain the original error so the failing shape is preserved.
            raise ValueError(
                f"LU decomposition failed: {e}. LU decomposition is only "
                "supported for square matrices in Tensorflow."
            ) from e
    return backend.linalg.lu_factor(x)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class Norm(Operation):
    """Symbolic op computing matrix/vector norms (see `keras.ops.norm`).

    Validates the `ord`/`axis` combination eagerly in the constructor and at
    shape-inference time, mirroring the constraints of the backend kernels.
    """

    def __init__(self, ord=None, axis=None, keepdims=False):
        super().__init__()
        if isinstance(ord, str):
            if ord not in ("fro", "nuc"):
                raise ValueError(
                    "Invalid `ord` argument. "
                    "Expected one of {'fro', 'nuc'} when using string. "
                    f"Received: ord={ord}"
                )
        if isinstance(axis, int):
            # Normalize a single int axis to a list for uniform handling.
            axis = [axis]
        self.ord = ord
        self.axis = axis
        self.keepdims = keepdims

    def compute_output_spec(self, x):
        output_dtype = backend.standardize_dtype(x.dtype)
        # Norms of integer/bool inputs are computed in the default float dtype.
        if "int" in output_dtype or output_dtype == "bool":
            output_dtype = backend.floatx()
        if self.axis is None:
            # No axis given: the norm reduces over all dimensions.
            axis = tuple(range(len(x.shape)))
        else:
            axis = self.axis
        num_axes = len(axis)
        # String orders ('fro'/'nuc') are only valid for matrix norms.
        if num_axes == 1 and isinstance(self.ord, str):
            raise ValueError(
                "Invalid `ord` argument for vector norm. "
                f"Received: ord={self.ord}"
            )
        elif num_axes == 2 and self.ord not in (
            None,
            "fro",
            "nuc",
            float("inf"),
            float("-inf"),
            1,
            -1,
            2,
            -2,
        ):
            raise ValueError(
                "Invalid `ord` argument for matrix norm. "
                f"Received: ord={self.ord}"
            )
        return KerasTensor(
            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),
            dtype=output_dtype,
        )

    def call(self, x):
        x = backend.convert_to_tensor(x)
        return backend.linalg.norm(
            x, ord=self.ord, axis=self.axis, keepdims=self.keepdims
        )


@keras_export(["keras.ops.norm", "keras.ops.linalg.norm"])
def norm(x, ord=None, axis=None, keepdims=False):
    """Matrix or vector norm.

    This function is able to return one of eight different matrix norms, or one
    of an infinite number of vector norms (described below), depending on the
    value of the `ord` parameter.

    Args:
        x: Input tensor.
        ord: Order of the norm (see table under Notes). The default is `None`.
        axis: If `axis` is an integer, it specifies the axis of `x` along which
            to compute the vector norms. If `axis` is a 2-tuple, it specifies
            the axes that hold 2-D matrices, and the matrix norms of these
            matrices are computed.
        keepdims: If this is set to `True`, the axes which are reduced are left
            in the result as dimensions with size one.

    Note:
        For values of `ord < 1`, the result is, strictly speaking, not a
        mathematical 'norm', but it may still be useful for various numerical
        purposes. The following norms can be calculated:
        - For matrices:
            - `ord=None`: Frobenius norm
            - `ord="fro"`: Frobenius norm
            - `ord="nuc"`: nuclear norm
            - `ord=np.inf`: `max(sum(abs(x), axis=1))`
            - `ord=-np.inf`: `min(sum(abs(x), axis=1))`
            - `ord=0`: not supported
            - `ord=1`: `max(sum(abs(x), axis=0))`
            - `ord=-1`: `min(sum(abs(x), axis=0))`
            - `ord=2`: 2-norm (largest sing. value)
            - `ord=-2`: smallest singular value
            - other: not supported
        - For vectors:
            - `ord=None`: 2-norm
            - `ord="fro"`: not supported
            - `ord="nuc"`: not supported
            - `ord=np.inf`: `max(abs(x))`
            - `ord=-np.inf`: `min(abs(x))`
            - `ord=0`: `sum(x != 0)`
            - `ord=1`: as below
            - `ord=-1`: as below
            - `ord=2`: as below
            - `ord=-2`: as below
            - other: `sum(abs(x)**ord)**(1./ord)`

    Returns:
        Norm of the matrix or vector(s).

    Example:

    >>> x = keras.ops.reshape(keras.ops.arange(9, dtype="float32") - 4, (3, 3))
    >>> keras.ops.linalg.norm(x)
    7.7459664
    """
    # Symbolic inputs go through the Operation machinery for shape inference.
    if any_symbolic_tensors((x,)):
        return Norm(ord=ord, axis=axis, keepdims=keepdims).symbolic_call(x)
    x = backend.convert_to_tensor(x)
    return backend.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class Qr(Operation):
    """Symbolic op wrapping `backend.linalg.qr`.

    `mode` selects between the reduced and the complete decomposition and
    determines the inferred output shapes.
    """

    def __init__(self, mode="reduced"):
        super().__init__()
        if mode not in {"reduced", "complete"}:
            raise ValueError(
                "`mode` argument value not supported. "
                "Expected one of {'reduced', 'complete'}. "
                f"Received: mode={mode}"
            )
        self.mode = mode

    def compute_output_spec(self, x):
        """Infer the (q, r) output shapes for the configured mode."""
        if len(x.shape) < 2:
            raise ValueError(
                "Input should have rank >= 2. Received: "
                f"input.shape = {x.shape}"
            )
        m = x.shape[-2]
        n = x.shape[-1]
        # Static shapes are required to size q and r.
        if m is None or n is None:
            raise ValueError(
                "Input should have its last 2 dimensions "
                "fully-defined. Received: "
                f"input.shape = {x.shape}"
            )
        k = min(m, n)
        base = tuple(x.shape[:-2])
        if self.mode == "reduced":
            # Reduced: q is (..., M, K), r is (..., K, N) with K = min(M, N).
            return (
                KerasTensor(shape=base + (m, k), dtype=x.dtype),
                KerasTensor(shape=base + (k, n), dtype=x.dtype),
            )
        # 'complete' mode.
        return (
            KerasTensor(shape=base + (m, m), dtype=x.dtype),
            KerasTensor(shape=base + (m, n), dtype=x.dtype),
        )

    def call(self, x):
        x = backend.convert_to_tensor(x)
        return backend.linalg.qr(x, mode=self.mode)


@keras_export(["keras.ops.qr", "keras.ops.linalg.qr"])
def qr(x, mode="reduced"):
    """Computes the QR decomposition of a tensor.

    Args:
        x: Input tensor of shape `(..., M, N)`.
        mode: A string specifying the mode of the QR decomposition.
            - 'reduced': Returns the reduced QR decomposition. (default)
            - 'complete': Returns the complete QR decomposition.

    Returns:
        A tuple containing two tensors. The first tensor of shape `(..., M, K)`
        is the orthogonal matrix `q` and the second tensor of shape
        `(..., K, N)` is the upper triangular matrix `r`, where `K = min(M, N)`.

    Example:

    >>> x = keras.ops.convert_to_tensor([[1., 2.], [3., 4.], [5., 6.]])
    >>> q, r = qr(x)
    >>> print(q)
    array([[-0.16903079  0.897085]
           [-0.5070925   0.2760267 ]
           [-0.8451542  -0.34503305]], shape=(3, 2), dtype=float32)
    """
    # Symbolic inputs go through the Operation machinery for shape inference.
    if any_symbolic_tensors((x,)):
        return Qr(mode=mode).symbolic_call(x)
    x = backend.convert_to_tensor(x)
    return backend.linalg.qr(x, mode=mode)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
class Solve(Operation):
    """Symbolic wrapper for `solve`: validates shapes, infers output spec."""

    def __init__(self):
        super().__init__()

    def call(self, a, b):
        # Eager execution path shared with the module-level `solve`.
        return _solve(a, b)

    def compute_output_spec(self, a, b):
        # `a` must be a square matrix (rank >= 2); `b` must be rank >= 1
        # with a leading/contracting dimension compatible with `a`.
        _assert_2d(a)
        _assert_square(a)
        _assert_1d(b)
        _assert_a_b_compat(a, b)
        # The solution always has the same shape and dtype as `b`.
        return KerasTensor(b.shape, b.dtype)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
@keras_export(["keras.ops.solve", "keras.ops.linalg.solve"])
def solve(a, b):
    """Solves a linear system of equations given by `a x = b`.

    Args:
        a: Coefficients tensor of shape `(..., M, M)`.
        b: Right-hand side ("dependent variable") tensor of shape
            `(..., M)` or `(..., M, N)`.

    Returns:
        The solution of the linear system, with a shape identical to `b`.
    """
    # Eager inputs are solved immediately; symbolic ones build a graph node.
    if not any_symbolic_tensors((a, b)):
        return _solve(a, b)
    return Solve().symbolic_call(a, b)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def _solve(a, b):
    """Eagerly solve `a x = b` after validating the argument shapes."""
    a, b = (backend.convert_to_tensor(t) for t in (a, b))
    # Same validation as the symbolic path in `Solve.compute_output_spec`.
    _assert_2d(a)
    _assert_square(a)
    _assert_1d(b)
    _assert_a_b_compat(a, b)
    return backend.linalg.solve(a, b)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
class SolveTriangular(Operation):
    """Symbolic wrapper for `solve_triangular`."""

    def __init__(self, lower=False):
        super().__init__()
        # Whether `a` is treated as lower (True) or upper (False) triangular.
        self.lower = lower

    def call(self, a, b):
        # Eager execution path shared with the module-level function.
        return _solve_triangular(a, b, self.lower)

    def compute_output_spec(self, a, b):
        # Same validation as `Solve`: square `a`, compatible `b`.
        _assert_2d(a)
        _assert_square(a)
        _assert_1d(b)
        _assert_a_b_compat(a, b)
        # The solution always has the same shape and dtype as `b`.
        return KerasTensor(b.shape, b.dtype)
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
@keras_export(
    ["keras.ops.solve_triangular", "keras.ops.linalg.solve_triangular"]
)
def solve_triangular(a, b, lower=False):
    """Solves a triangular linear system of equations given by `a x = b`.

    Args:
        a: Coefficients tensor of shape `(..., M, M)`.
        b: Right-hand side ("dependent variable") tensor of shape
            `(..., M)` or `(..., M, N)`.
        lower: Whether `a` is lower triangular (`True`) or upper
            triangular (`False`, the default).

    Returns:
        The solution of the linear system, with a shape identical to `b`.
    """
    if not any_symbolic_tensors((a, b)):
        return _solve_triangular(a, b, lower)
    return SolveTriangular(lower).symbolic_call(a, b)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def _solve_triangular(a, b, lower=False):
    """Eagerly solve triangular `a x = b` after validating shapes."""
    a, b = (backend.convert_to_tensor(t) for t in (a, b))
    # Same validation as `SolveTriangular.compute_output_spec`.
    _assert_2d(a)
    _assert_square(a)
    _assert_1d(b)
    _assert_a_b_compat(a, b)
    return backend.linalg.solve_triangular(a, b, lower)
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
class SVD(Operation):
    """Symbolic wrapper for `svd`: infers factor shapes from the input."""

    def __init__(self, full_matrices=True, compute_uv=True):
        super().__init__()
        self.full_matrices = full_matrices
        self.compute_uv = compute_uv

    def call(self, x):
        return _svd(x, self.full_matrices, self.compute_uv)

    def compute_output_spec(self, x):
        _assert_2d(x)
        rows, columns = x.shape[-2:]
        batches = x.shape[:-2]
        # NOTE(review): min() assumes the last two dims are statically
        # known; a None dim here would raise TypeError — confirm callers
        # always provide fully-defined trailing dims.
        s_shape = batches + (min(rows, columns),)
        if self.full_matrices:
            # Full-size factors: u is (..., M, M), v is (..., N, N).
            u_shape = batches + (rows, rows)
            v_shape = batches + (columns, columns)
        else:
            # Thin factors truncated to K = min(M, N).
            u_shape = batches + (rows, min(rows, columns))
            v_shape = batches + (min(rows, columns), columns)

        if self.compute_uv:
            return (
                KerasTensor(u_shape, x.dtype),
                KerasTensor(s_shape, x.dtype),
                KerasTensor(v_shape, x.dtype),
            )
        # Only the singular values were requested.
        return KerasTensor(s_shape, x.dtype)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
@keras_export(["keras.ops.svd", "keras.ops.linalg.svd"])
def svd(x, full_matrices=True, compute_uv=True):
    """Computes the singular value decomposition of a matrix.

    Args:
        x: Input tensor of shape `(..., M, N)`.
        full_matrices: Whether to compute full-sized singular-vector
            matrices (`(..., M, M)` and `(..., N, N)`) rather than the
            thin factors. Defaults to `True`.
        compute_uv: Whether to return the singular vectors in addition
            to the singular values. Defaults to `True`.

    Returns:
        A tuple of three tensors: the left singular vectors, the singular
        values, and the right singular vectors (or only the singular
        values when `compute_uv=False`).
    """
    if not any_symbolic_tensors((x,)):
        return _svd(x, full_matrices, compute_uv)
    return SVD(full_matrices, compute_uv).symbolic_call(x)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def _svd(x, full_matrices=True, compute_uv=True):
    """Eager SVD shared by `svd` and `SVD.call`."""
    # Only the rank is validated here; the backend handles the rest.
    x = backend.convert_to_tensor(x)
    _assert_2d(x)
    return backend.linalg.svd(x, full_matrices, compute_uv)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
class Lstsq(Operation):
    """Symbolic wrapper for `lstsq`: validates shapes, infers output."""

    def __init__(self, rcond=None):
        super().__init__()
        self.rcond = rcond

    def call(self, a, b):
        return backend.linalg.lstsq(a, b, rcond=self.rcond)

    def compute_output_spec(self, a, b):
        # `a` must be a plain (M, N) matrix — no batch dimensions here.
        if len(a.shape) != 2:
            raise ValueError(
                "Expected a to have rank 2. " f"Received: a.shape={a.shape}"
            )
        # `b` is either a vector (M,) or a matrix (M, K).
        if len(b.shape) not in (1, 2):
            raise ValueError(
                "Expected b to have rank 1 or 2. "
                f"Received: b.shape={b.shape}"
            )
        m, n = a.shape
        # The leading dims of `a` and `b` must agree (both are M).
        if b.shape[0] != m:
            raise ValueError(
                "Expected b.shape[0] to be equal to "
                "a.shape[0]. Received: "
                f"a.shape={a.shape}, b.shape={b.shape}"
            )
        # Solution has shape (N, K) for matrix `b`, (N,) for vector `b`.
        if len(b.shape) == 2:
            k = b.shape[1]
            x = KerasTensor((n, k), dtype=a.dtype)
        else:
            x = KerasTensor((n,), dtype=a.dtype)
        return x
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
@keras_export(["keras.ops.lstsq", "keras.ops.linalg.lstsq"])
def lstsq(a, b, rcond=None):
    """Return the least-squares solution to a linear matrix equation.

    Computes the vector `x` that approximately solves `a @ x = b`. The
    system may be under-, well-, or over-determined (the number of
    linearly independent rows of `a` may be less than, equal to, or
    greater than its number of linearly independent columns). If `a` is
    square and of full rank, `x` is (up to round-off error) the exact
    solution; otherwise `x` minimizes the L2 norm of `b - a x`. When
    several solutions minimize the norm, the one with the smallest L2
    norm is returned.

    Args:
        a: "Coefficient" matrix of shape `(M, N)`.
        b: Ordinate or "dependent variable" values of shape `(M,)` or
            `(M, K)`. If `b` is two-dimensional, the least-squares
            solution is calculated for each of its K columns.
        rcond: Cut-off ratio for small singular values of `a`. For rank
            determination, singular values smaller than `rcond` times the
            largest singular value of `a` are treated as zero.

    Returns:
        Tensor of shape `(N,)` or `(N, K)` containing the least-squares
        solutions.

    **NOTE:** The output differs from `numpy.linalg.lstsq`, which returns
    a 4-tuple whose extra elements are essentially never used. Keras only
    returns the solutions, both to ensure consistency across backends
    (which cannot be achieved for the other values) and to simplify the
    API.
    """
    if not any_symbolic_tensors((a, b)):
        return backend.linalg.lstsq(a, b, rcond=rcond)
    return Lstsq(rcond=rcond).symbolic_call(a, b)
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def _assert_1d(*arrays):
|
| 666 |
+
for a in arrays:
|
| 667 |
+
if a.ndim < 1:
|
| 668 |
+
raise ValueError(
|
| 669 |
+
"Expected input to have rank >= 1. "
|
| 670 |
+
"Received scalar input {a}."
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def _assert_2d(*arrays):
|
| 675 |
+
for a in arrays:
|
| 676 |
+
if a.ndim < 2:
|
| 677 |
+
raise ValueError(
|
| 678 |
+
"Expected input to have rank >= 2. "
|
| 679 |
+
"Received input with shape {a.shape}."
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def _assert_square(*arrays):
|
| 684 |
+
for a in arrays:
|
| 685 |
+
m, n = a.shape[-2:]
|
| 686 |
+
if m != n:
|
| 687 |
+
raise ValueError(
|
| 688 |
+
"Expected a square matrix. "
|
| 689 |
+
f"Received non-square input with shape {a.shape}"
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
def _assert_a_b_compat(a, b):
|
| 694 |
+
if a.ndim == b.ndim:
|
| 695 |
+
if a.shape[-2] != b.shape[-2]:
|
| 696 |
+
raise ValueError(
|
| 697 |
+
"Incompatible shapes between `a` and `b`. "
|
| 698 |
+
"Expected `a.shape[-2] == b.shape[-2]`. "
|
| 699 |
+
f"Received: a.shape={a.shape}, b.shape={b.shape}"
|
| 700 |
+
)
|
| 701 |
+
elif a.ndim == b.ndim - 1:
|
| 702 |
+
if a.shape[-1] != b.shape[-1]:
|
| 703 |
+
raise ValueError(
|
| 704 |
+
"Incompatible shapes between `a` and `b`. "
|
| 705 |
+
"Expected `a.shape[-1] == b.shape[-1]`. "
|
| 706 |
+
f"Received: a.shape={a.shape}, b.shape={b.shape}"
|
| 707 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/math.py
ADDED
|
@@ -0,0 +1,1046 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Commonly used math operations not included in NumPy."""
|
| 2 |
+
|
| 3 |
+
from keras.src import backend
|
| 4 |
+
from keras.src.api_export import keras_export
|
| 5 |
+
from keras.src.backend import KerasTensor
|
| 6 |
+
from keras.src.backend import any_symbolic_tensors
|
| 7 |
+
from keras.src.ops.operation import Operation
|
| 8 |
+
from keras.src.ops.operation_utils import reduce_shape
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _segment_reduce_validation(data, segment_ids):
|
| 12 |
+
data_shape = data.shape
|
| 13 |
+
segment_ids_shape = segment_ids.shape
|
| 14 |
+
if len(segment_ids_shape) > 1:
|
| 15 |
+
raise ValueError(
|
| 16 |
+
"Argument `segment_ids` should be an 1-D vector, got shape: "
|
| 17 |
+
f"{len(segment_ids_shape)}. Consider either flatten input with "
|
| 18 |
+
"segment_ids.reshape((-1)) and "
|
| 19 |
+
"data.reshape((-1, ) + data.shape[len(segment_ids.shape):]) or "
|
| 20 |
+
"vectorize with vmap."
|
| 21 |
+
)
|
| 22 |
+
if (
|
| 23 |
+
segment_ids_shape[0] is not None
|
| 24 |
+
and data_shape[0] is not None
|
| 25 |
+
and segment_ids_shape[0] != data_shape[0]
|
| 26 |
+
):
|
| 27 |
+
raise ValueError(
|
| 28 |
+
"Argument `segment_ids` and `data` should have same leading "
|
| 29 |
+
f"dimension. Got {segment_ids_shape} v.s. "
|
| 30 |
+
f"{data_shape}."
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class SegmentReduction(Operation):
    """Base class for segment reductions: holds shared configuration."""

    def __init__(self, num_segments=None, sorted=False):
        super().__init__()
        # Total number of output segments; None means unknown statically.
        self.num_segments = num_segments
        # Whether segment ids are pre-sorted (backend-dependent fast path).
        self.sorted = sorted

    def compute_output_spec(self, data, _):
        # The leading axis becomes one entry per segment; trailing dims
        # are unchanged. A None `num_segments` yields an unknown dim.
        output_shape = (self.num_segments,) + tuple(data.shape[1:])
        return KerasTensor(shape=output_shape, dtype=data.dtype)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SegmentSum(SegmentReduction):
    """Segment-sum operation; see `segment_sum` for the public contract."""

    def call(self, data, segment_ids):
        # Re-validate eagerly: shapes may only be fully known at call time.
        _segment_reduce_validation(data, segment_ids)
        return backend.math.segment_sum(
            data,
            segment_ids,
            num_segments=self.num_segments,
            sorted=self.sorted,
        )
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@keras_export("keras.ops.segment_sum")
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
    """Computes the sum of segments in a tensor.

    Args:
        data: Input tensor.
        segment_ids: Tensor of segment indices, one per leading-axis
            element of `data`. Its number of dims must be strictly
            smaller than or equal to that of `data`.
        num_segments: Total number of segments. If not specified, it is
            inferred from the maximum value in `segment_ids`.
        sorted: Whether `segment_ids` is already sorted.
            Defaults to `False`.

    Returns:
        A tensor where each element is the sum of the corresponding
        segment in `data`.

    Example:

    >>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])
    >>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])
    >>> num_segments = 3
    >>> keras.ops.segment_sum(data, segment_ids,num_segments)
    array([3, 30, 300], dtype=int32)
    """
    # Shape checks apply to both the symbolic and the eager path.
    _segment_reduce_validation(data, segment_ids)
    if not any_symbolic_tensors((data,)):
        return backend.math.segment_sum(
            data, segment_ids, num_segments=num_segments, sorted=sorted
        )
    return SegmentSum(num_segments, sorted).symbolic_call(data, segment_ids)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class SegmentMax(SegmentReduction):
    """Segment-max operation; see `segment_max` for the public contract."""

    def call(self, data, segment_ids):
        # Re-validate eagerly: shapes may only be fully known at call time.
        _segment_reduce_validation(data, segment_ids)
        return backend.math.segment_max(
            data,
            segment_ids,
            num_segments=self.num_segments,
            sorted=self.sorted,
        )
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@keras_export("keras.ops.segment_max")
def segment_max(data, segment_ids, num_segments=None, sorted=False):
    """Computes the max of segments in a tensor.

    Args:
        data: Input tensor.
        segment_ids: Tensor of segment indices, one per leading-axis
            element of `data`; `data.shape[:len(segment_ids.shape)]`
            should match.
        num_segments: Total number of segments. If not specified, it is
            inferred from the maximum value in `segment_ids`.
        sorted: Whether `segment_ids` is already sorted.
            Defaults to `False`.

    Returns:
        A tensor where each element is the max of the corresponding
        segment in `data`.

    Example:

    >>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])
    >>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])
    >>> num_segments = 3
    >>> keras.ops.segment_max(data, segment_ids, num_segments)
    array([2, 20, 200], dtype=int32)
    """
    # Shape checks apply to both the symbolic and the eager path.
    _segment_reduce_validation(data, segment_ids)
    if not any_symbolic_tensors((data,)):
        return backend.math.segment_max(
            data, segment_ids, num_segments=num_segments, sorted=sorted
        )
    return SegmentMax(num_segments, sorted).symbolic_call(data, segment_ids)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class TopK(Operation):
    """Symbolic wrapper for `top_k`."""

    def __init__(self, k, sorted=False):
        super().__init__()
        self.k = k
        self.sorted = sorted

    def compute_output_spec(self, x):
        output_shape = list(x.shape)
        # The last axis is reduced to the k selected entries.
        output_shape[-1] = self.k
        # Return a tuple (values, indices).
        return (
            KerasTensor(shape=output_shape, dtype=x.dtype),
            KerasTensor(shape=output_shape, dtype="int32"),
        )

    def call(self, x):
        return backend.math.top_k(x, self.k, self.sorted)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@keras_export("keras.ops.top_k")
def top_k(x, k, sorted=True):
    """Finds the top-k values and their indices in a tensor.

    Args:
        x: Input tensor.
        k: Number of top elements to retrieve.
        sorted: Whether to sort the output in descending order.
            Defaults to `True`.

    Returns:
        A tuple `(values, indices)`: the top-k values of `x` and their
        positions in the input tensor.

    Example:

    >>> x = keras.ops.convert_to_tensor([5, 2, 7, 1, 9, 3])
    >>> values, indices = top_k(x, k=3)
    >>> print(values)
    array([9 7 5], shape=(3,), dtype=int32)
    >>> print(indices)
    array([4 2 0], shape=(3,), dtype=int32)

    """
    if not any_symbolic_tensors((x,)):
        return backend.math.top_k(x, k, sorted)
    return TopK(k, sorted).symbolic_call(x)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class InTopK(Operation):
    """Symbolic wrapper for `in_top_k`."""

    def __init__(self, k):
        super().__init__()
        self.k = k

    def compute_output_spec(self, targets, predictions):
        # One boolean per target.
        return KerasTensor(shape=targets.shape, dtype="bool")

    def call(self, targets, predictions):
        return backend.math.in_top_k(targets, predictions, self.k)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@keras_export("keras.ops.in_top_k")
def in_top_k(targets, predictions, k):
    """Checks if the targets are in the top-k predictions.

    Args:
        targets: A tensor of true labels.
        predictions: A tensor of predicted labels.
        k: Number of top predictions to consider.

    Returns:
        A boolean tensor with the same shape as `targets`; each element
        tells whether the corresponding target is among the top-k
        predictions.

    Example:

    >>> targets = keras.ops.convert_to_tensor([2, 5, 3])
    >>> predictions = keras.ops.convert_to_tensor(
    ... [[0.1, 0.4, 0.6, 0.9, 0.5],
    ...  [0.1, 0.7, 0.9, 0.8, 0.3],
    ...  [0.1, 0.6, 0.9, 0.9, 0.5]])
    >>> in_top_k(targets, predictions, k=3)
    array([ True False True], shape=(3,), dtype=bool)
    """
    if not any_symbolic_tensors((targets, predictions)):
        return backend.math.in_top_k(targets, predictions, k)
    return InTopK(k).symbolic_call(targets, predictions)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class Logsumexp(Operation):
    """Symbolic wrapper for `logsumexp`."""

    def __init__(self, axis=None, keepdims=False):
        super().__init__()
        self.axis = axis
        self.keepdims = keepdims

    def compute_output_spec(self, x):
        # Standard reduction-shape rule over `axis` / `keepdims`.
        output_shape = reduce_shape(x.shape, self.axis, self.keepdims)
        # NOTE(review): `x.dtype` is not propagated here (KerasTensor will
        # use its default dtype) — confirm whether that is intentional.
        return KerasTensor(shape=output_shape)

    def call(self, x):
        return backend.math.logsumexp(x, axis=self.axis, keepdims=self.keepdims)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@keras_export("keras.ops.logsumexp")
def logsumexp(x, axis=None, keepdims=False):
    """Computes the logarithm of sum of exponentials of elements in a tensor.

    Args:
        x: Input tensor.
        axis: Integer or tuple of integers giving the axis/axes to reduce
            over. `None` (the default) reduces over all elements.
        keepdims: Whether the reduced axes are kept with size 1.
            Defaults to `False`.

    Returns:
        A tensor containing the logarithm of the sum of exponentials of
        elements in `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([1., 2., 3.])
    >>> logsumexp(x)
    3.407606
    """
    if not any_symbolic_tensors((x,)):
        return backend.math.logsumexp(x, axis=axis, keepdims=keepdims)
    return Logsumexp(axis, keepdims).symbolic_call(x)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class ExtractSequences(Operation):
    """Symbolic wrapper for `extract_sequences`."""

    def __init__(self, sequence_length, sequence_stride):
        super().__init__()
        self.sequence_length = sequence_length
        self.sequence_stride = sequence_stride

    def compute_output_spec(self, x):
        if len(x.shape) < 1:
            raise ValueError(
                f"Input should have rank >= 1. "
                f"Received: input.shape = {x.shape}"
            )
        if x.shape[-1] is not None:
            # Number of full windows of `sequence_length` at the given
            # stride over the last axis.
            num_sequences = (
                1 + (x.shape[-1] - self.sequence_length) // self.sequence_stride
            )
        else:
            # Unknown last dim -> unknown number of sequences.
            num_sequences = None
        # The last axis is replaced by (num_sequences, sequence_length).
        new_shape = x.shape[:-1] + (num_sequences, self.sequence_length)
        return KerasTensor(shape=new_shape, dtype=x.dtype)

    def call(self, x):
        return backend.math.extract_sequences(
            x,
            sequence_length=self.sequence_length,
            sequence_stride=self.sequence_stride,
        )
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
@keras_export("keras.ops.extract_sequences")
def extract_sequences(x, sequence_length, sequence_stride):
    """Expands the dimension of last axis into sequences of `sequence_length`.

    A window of size `sequence_length` slides over the last axis of the
    input with a hop of `sequence_stride`, replacing the last axis with
    `[num_sequences, sequence_length]` sequences, where, for a last
    dimension of N:

    `num_sequences = 1 + (N - sequence_length) // sequence_stride`

    Args:
        x: Input tensor.
        sequence_length: Length of each extracted sequence.
        sequence_stride: Hop size between consecutive sequences.

    Returns:
        A tensor of sequences with shape [..., num_sequences, sequence_length].

    Example:

    >>> x = keras.ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
    >>> extract_sequences(x, 3, 2)
    array([[1, 2, 3],
           [3, 4, 5]])
    """
    if not any_symbolic_tensors((x,)):
        return backend.math.extract_sequences(
            x, sequence_length, sequence_stride
        )
    return ExtractSequences(sequence_length, sequence_stride).symbolic_call(x)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class FFT(Operation):
    """Symbolic wrapper for `fft` (1-D FFT over the last axis)."""

    def __init__(self, axis=-1):
        super().__init__()
        self.axis = axis

    def compute_output_spec(self, x):
        # `x` is a (real, imag) pair, not a single tensor.
        if not isinstance(x, (tuple, list)) or len(x) != 2:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                f"imaginary. Received: x={x}"
            )

        real, imag = x
        # Both real and imaginary parts should have the same shape.
        if real.shape != imag.shape:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                "imaginary. Both the real and imaginary parts should have the "
                f"same shape. Received: x[0].shape = {real.shape}, "
                f"x[1].shape = {imag.shape}"
            )

        # We are calculating 1D FFT. Hence, rank >= 1.
        if len(real.shape) < 1:
            raise ValueError(
                f"Input should have rank >= 1. "
                f"Received: input.shape = {real.shape}"
            )

        # The axis along which we are calculating FFT should be fully-defined.
        m = real.shape[-1]
        if m is None:
            raise ValueError(
                f"Input should have its {self.axis}th axis fully-defined. "
                f"Received: input.shape = {real.shape}"
            )

        # The FFT output mirrors the input pair's shapes and dtypes.
        return (
            KerasTensor(shape=real.shape, dtype=real.dtype),
            KerasTensor(shape=imag.shape, dtype=imag.dtype),
        )

    def call(self, x):
        return backend.math.fft(x)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
@keras_export("keras.ops.fft")
def fft(x):
    """Computes the Fast Fourier Transform along last axis of input.

    Args:
        x: Tuple `(real, imag)` of the real and imaginary parts of the
            input tensor. Both tensors should be of floating type.

    Returns:
        A tuple of two tensors: the real and imaginary parts of the
        output tensor.

    Example:

    >>> x = (
    ...     keras.ops.convert_to_tensor([1., 2.]),
    ...     keras.ops.convert_to_tensor([0., 1.]),
    ... )
    >>> fft(x)
    (array([ 3., -1.], dtype=float32), array([ 1., -1.], dtype=float32))
    """
    if not any_symbolic_tensors(x):
        return backend.math.fft(x)
    return FFT().symbolic_call(x)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class FFT2(Operation):
    """Symbolic op computing the 2D FFT over the last two axes.

    Inputs and outputs are `(real, imag)` tensor pairs, since symbolic
    KerasTensors do not carry complex dtypes.
    """

    def __init__(self):
        super().__init__()
        # The 2D transform is always taken over the last two axes.
        self.axes = (-2, -1)

    def compute_output_spec(self, x):
        if not isinstance(x, (tuple, list)) or len(x) != 2:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                f"imaginary. Received: x={x}"
            )

        real, imag = x
        # Both real and imaginary parts should have the same shape.
        if real.shape != imag.shape:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                "imaginary. Both the real and imaginary parts should have the "
                f"same shape. Received: x[0].shape = {real.shape}, "
                f"x[1].shape = {imag.shape}"
            )
        # We are calculating 2D FFT. Hence, rank >= 2.
        if len(real.shape) < 2:
            raise ValueError(
                f"Input should have rank >= 2. "
                f"Received: input.shape = {real.shape}"
            )

        # The axes along which we are calculating FFT should be fully-defined.
        m = real.shape[self.axes[0]]
        n = real.shape[self.axes[1]]
        if m is None or n is None:
            raise ValueError(
                f"Input should have its {self.axes} axes fully-defined. "
                f"Received: input.shape = {real.shape}"
            )

        # The transform preserves shape; output is also a (real, imag) pair.
        return (
            KerasTensor(shape=real.shape, dtype=real.dtype),
            KerasTensor(shape=imag.shape, dtype=imag.dtype),
        )

    def call(self, x):
        return backend.math.fft2(x)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
@keras_export("keras.ops.fft2")
def fft2(x):
    """Computes the 2D Fast Fourier Transform along the last two axes of input.

    Args:
        x: Tuple of the real and imaginary parts of the input tensor. Both
            tensors in the tuple should be of floating type.

    Returns:
        A tuple containing two tensors - the real and imaginary parts of the
        output.

    Example:

    >>> x = (
    ...     keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]),
    ...     keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]),
    ... )
    >>> fft2(x)
    (array([[ 6.,  0.],
        [ 0., -2.]], dtype=float32), array([[ 2.,  0.],
        [ 0., -2.]], dtype=float32))
    """
    # Eager tensors go straight to the backend implementation; symbolic
    # inputs are routed through the op so output specs can be inferred.
    if not any_symbolic_tensors(x):
        return backend.math.fft2(x)
    return FFT2().symbolic_call(x)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
class IFFT2(Operation):
    """Symbolic op computing the 2D inverse FFT over the last two axes.

    Inputs and outputs are `(real, imag)` tensor pairs, since symbolic
    KerasTensors do not carry complex dtypes.
    """

    def __init__(self):
        super().__init__()
        # The 2D inverse transform is always taken over the last two axes.
        self.axes = (-2, -1)

    def compute_output_spec(self, x):
        if not isinstance(x, (tuple, list)) or len(x) != 2:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                f"imaginary. Received: x={x}"
            )

        real, imag = x
        # Both real and imaginary parts should have the same shape.
        if real.shape != imag.shape:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                "imaginary. Both the real and imaginary parts should have the "
                f"same shape. Received: x[0].shape = {real.shape}, "
                f"x[1].shape = {imag.shape}"
            )
        # We are calculating 2D IFFT. Hence, rank >= 2.
        if len(real.shape) < 2:
            raise ValueError(
                f"Input should have rank >= 2. "
                f"Received: input.shape = {real.shape}"
            )

        # The axes along which we are calculating IFFT should be fully-defined.
        m = real.shape[self.axes[0]]
        n = real.shape[self.axes[1]]
        if m is None or n is None:
            raise ValueError(
                f"Input should have its {self.axes} axes fully-defined. "
                f"Received: input.shape = {real.shape}"
            )

        # The transform preserves shape; output is also a (real, imag) pair.
        return (
            KerasTensor(shape=real.shape, dtype=real.dtype),
            KerasTensor(shape=imag.shape, dtype=imag.dtype),
        )

    def call(self, x):
        return backend.math.ifft2(x)
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
@keras_export("keras.ops.ifft2")
def ifft2(x):
    """Computes the 2D Inverse Fast Fourier Transform along the last two axes of
    input.

    Args:
        x: Tuple of the real and imaginary parts of the input tensor. Both
            tensors in the tuple should be of floating type.

    Returns:
        A tuple containing two tensors - the real and imaginary parts of the
        output.

    Example:

    >>> x = (
    ...     keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]),
    ...     keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]),
    ... )
    >>> ifft2(x)
    (array([[ 1.5,  0. ],
        [ 0. , -0.5]], dtype=float32), array([[ 0.5,  0. ],
        [ 0. , -0.5]], dtype=float32))
    """
    if any_symbolic_tensors(x):
        return IFFT2().symbolic_call(x)
    return backend.math.ifft2(x)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
class RFFT(Operation):
    """Symbolic op computing the real-input FFT along the last axis.

    Only the `fft_length // 2 + 1` unique components of the
    Hermitian-symmetric spectrum are produced, as a (real, imag) pair.
    """

    def __init__(self, fft_length=None):
        super().__init__()
        # When None, the FFT length is inferred from the input's last axis.
        self.fft_length = fft_length

    def compute_output_spec(self, x):
        # We are calculating 1D RFFT. Hence, rank >= 1.
        if len(x.shape) < 1:
            raise ValueError(
                f"Input should have rank >= 1. "
                f"Received: input.shape = {x.shape}"
            )

        if self.fft_length is not None:
            new_last_dimension = self.fft_length // 2 + 1
        else:
            if x.shape[-1] is not None:
                new_last_dimension = x.shape[-1] // 2 + 1
            else:
                # Unknown input length -> unknown output length.
                new_last_dimension = None
        new_shape = x.shape[:-1] + (new_last_dimension,)

        # Real and imaginary outputs share the same shape and dtype.
        return (
            KerasTensor(shape=new_shape, dtype=x.dtype),
            KerasTensor(shape=new_shape, dtype=x.dtype),
        )

    def call(self, x):
        return backend.math.rfft(x, fft_length=self.fft_length)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
@keras_export("keras.ops.rfft")
def rfft(x, fft_length=None):
    """Real-valued Fast Fourier Transform along the last axis of the input.

    Computes the 1D Discrete Fourier Transform of a real-valued signal over
    the inner-most dimension of input.

    Since the Discrete Fourier Transform of a real-valued signal is
    Hermitian-symmetric, RFFT only returns the `fft_length / 2 + 1` unique
    components of the FFT: the zero-frequency term, followed by the
    `fft_length / 2` positive-frequency terms.

    Along the axis RFFT is computed on, if `fft_length` is smaller than the
    corresponding dimension of the input, the dimension is cropped. If it is
    larger, the dimension is padded with zeros.

    Args:
        x: Input tensor.
        fft_length: An integer representing the number of the fft length. If
            not specified, it is inferred from the length of the last axis of
            `x`. Defaults to `None`.

    Returns:
        A tuple containing two tensors - the real and imaginary parts of the
        output.

    Examples:

    >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    >>> rfft(x)
    (array([10.0, -2.5, -2.5]), array([0.0, 3.4409548, 0.81229924]))

    >>> rfft(x, 3)
    (array([3.0, -1.5]), array([0.0, 0.8660254]))
    """
    # Eager tensors go straight to the backend implementation; symbolic
    # inputs are routed through the op so output specs can be inferred.
    if not any_symbolic_tensors((x,)):
        return backend.math.rfft(x, fft_length)
    return RFFT(fft_length).symbolic_call(x)
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
class IRFFT(Operation):
    """Symbolic op computing the inverse real-input FFT along the last axis.

    The input is a (real, imag) half-spectrum pair; the output is a single
    real tensor.
    """

    def __init__(self, fft_length=None):
        super().__init__()
        # When None, the output length is inferred as 2 * (inner - 1),
        # which is only correct for even original FFT lengths.
        self.fft_length = fft_length

    def compute_output_spec(self, x):
        if not isinstance(x, (tuple, list)) or len(x) != 2:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                f"imaginary. Received: x={x}"
            )
        real, imag = x
        # Both real and imaginary parts should have the same shape.
        if real.shape != imag.shape:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                "imaginary. Both the real and imaginary parts should have the "
                f"same shape. Received: x[0].shape = {real.shape}, "
                f"x[1].shape = {imag.shape}"
            )
        # We are calculating 1D IRFFT. Hence, rank >= 1.
        if len(real.shape) < 1:
            raise ValueError(
                f"Input should have rank >= 1. "
                f"Received: input.shape = {real.shape}"
            )

        if self.fft_length is not None:
            new_last_dimension = self.fft_length
        else:
            if real.shape[-1] is not None:
                new_last_dimension = 2 * (real.shape[-1] - 1)
            else:
                # Unknown input length -> unknown output length.
                new_last_dimension = None
        new_shape = real.shape[:-1] + (new_last_dimension,)
        # IRFFT produces a single real-valued tensor.
        return KerasTensor(shape=new_shape, dtype=real.dtype)

    def call(self, x):
        return backend.math.irfft(x, fft_length=self.fft_length)
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
@keras_export("keras.ops.irfft")
def irfft(x, fft_length=None):
    """Inverse real-valued Fast Fourier transform along the last axis.

    Computes the inverse 1D Discrete Fourier Transform of a real-valued signal
    over the inner-most dimension of input.

    The inner-most dimension of the input is assumed to be the result of RFFT:
    the `fft_length / 2 + 1` unique components of the DFT of a real-valued
    signal. If `fft_length` is not provided, it is computed from the size of
    the inner-most dimension of the input `(fft_length = 2 * (inner - 1))`. If
    the FFT length used to compute is odd, it should be provided since it
    cannot be inferred properly.

    Along the axis IRFFT is computed on, if `fft_length / 2 + 1` is smaller
    than the corresponding dimension of the input, the dimension is cropped.
    If it is larger, the dimension is padded with zeros.

    Args:
        x: Tuple of the real and imaginary parts of the input tensor. Both
            tensors in the tuple should be of floating type.
        fft_length: An integer representing the number of the fft length. If
            not specified, it is inferred from the length of the last axis of
            `x`. Defaults to `None`.

    Returns:
        A tensor containing the inverse real-valued Fast Fourier Transform
        along the last axis of `x`.

    Examples:

    >>> real = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    >>> imag = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    >>> irfft((real, imag))
    array([0.66666667, -0.9106836, 0.24401694])

    >>> irfft(rfft(real, 5), 5)
    array([0.0, 1.0, 2.0, 3.0, 4.0])
    """
    # Eager tensors go straight to the backend implementation; symbolic
    # inputs are routed through the op so output specs can be inferred.
    if not any_symbolic_tensors(x):
        return backend.math.irfft(x, fft_length)
    return IRFFT(fft_length).symbolic_call(x)
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
class STFT(Operation):
    """Symbolic op for the Short-Time Fourier Transform.

    Produces a (real, imag) pair of shape
    `(..., num_sequences, fft_length // 2 + 1)`.
    """

    def __init__(
        self,
        sequence_length,
        sequence_stride,
        fft_length,
        window="hann",
        center=True,
    ):
        super().__init__()
        # Length of each analysis frame.
        self.sequence_length = sequence_length
        # Hop size between consecutive frames.
        self.sequence_stride = sequence_stride
        # FFT size applied to each frame.
        self.fft_length = fft_length
        # Window name, window tensor, or None (no windowing).
        self.window = window
        # Whether the signal is padded so frames are centered.
        self.center = center

    def compute_output_spec(self, x):
        if x.shape[-1] is not None:
            # With center=True the signal is padded by fft_length // 2
            # on each side before framing.
            padded = 0 if self.center is False else (self.fft_length // 2) * 2
            num_sequences = (
                1
                + (x.shape[-1] + padded - self.fft_length)
                // self.sequence_stride
            )
        else:
            # Unknown signal length -> unknown number of frames.
            num_sequences = None
        new_shape = x.shape[:-1] + (num_sequences, self.fft_length // 2 + 1)
        return (
            KerasTensor(shape=new_shape, dtype=x.dtype),
            KerasTensor(shape=new_shape, dtype=x.dtype),
        )

    def call(self, x):
        return backend.math.stft(
            x,
            sequence_length=self.sequence_length,
            sequence_stride=self.sequence_stride,
            fft_length=self.fft_length,
            window=self.window,
            center=self.center,
        )
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
@keras_export("keras.ops.stft")
def stft(
    x, sequence_length, sequence_stride, fft_length, window="hann", center=True
):
    """Short-Time Fourier Transform along the last axis of the input.

    The STFT computes the Fourier transform of short overlapping windows of
    the input. This giving frequency components of the signal as they change
    over time.

    Args:
        x: Input tensor.
        sequence_length: An integer representing the sequence length.
        sequence_stride: An integer representing the sequence hop size.
        fft_length: An integer representing the size of the FFT to apply. If
            not specified, uses the smallest power of 2 enclosing
            `sequence_length`.
        window: A string, a tensor of the window or `None`. If `window` is a
            string, available values are `"hann"` and `"hamming"`. If `window`
            is a tensor, it will be used directly as the window and its length
            must be `sequence_length`. If `window` is `None`, no windowing is
            used. Defaults to `"hann"`.
        center: Whether to pad `x` on both sides so that the t-th sequence is
            centered at time `t * sequence_stride`. Otherwise, the t-th
            sequence begins at time `t * sequence_stride`. Defaults to `True`.

    Returns:
        A tuple containing two tensors - the real and imaginary parts of the
        STFT output.

    Example:

    >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    >>> stft(x, 3, 2, 3)
    (array([[0.75, -0.375],
       [3.75, -1.875],
       [5.25, -2.625]]), array([[0.0, 0.64951905],
       [0.0, 0.64951905],
       [0.0, -0.64951905]]))
    """
    # Eager tensors go straight to the backend implementation; symbolic
    # inputs are routed through the op so output specs can be inferred.
    if not any_symbolic_tensors((x,)):
        return backend.math.stft(
            x,
            sequence_length=sequence_length,
            sequence_stride=sequence_stride,
            fft_length=fft_length,
            window=window,
            center=center,
        )
    return STFT(
        sequence_length=sequence_length,
        sequence_stride=sequence_stride,
        fft_length=fft_length,
        window=window,
        center=center,
    ).symbolic_call(x)
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
class ISTFT(Operation):
    """Symbolic op for the inverse Short-Time Fourier Transform.

    The input is a (real, imag) spectrogram pair; the output is a single
    real tensor with the frame axis folded back into a time axis.
    """

    def __init__(
        self,
        sequence_length,
        sequence_stride,
        fft_length,
        length=None,
        window="hann",
        center=True,
    ):
        super().__init__()
        # Length of each analysis frame used to produce the STFT.
        self.sequence_length = sequence_length
        # Hop size between consecutive frames.
        self.sequence_stride = sequence_stride
        # FFT size used to produce the STFT.
        self.fft_length = fft_length
        # Optional exact output length; overrides the inferred size.
        self.length = length
        # Window name, window tensor, or None (no windowing).
        self.window = window
        # Whether the forward STFT padded the signal for centered frames.
        self.center = center

    def compute_output_spec(self, x):
        if not isinstance(x, (tuple, list)) or len(x) != 2:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                f"imaginary. Received: x={x}"
            )
        real, imag = x
        # Both real and imaginary parts should have the same shape.
        if real.shape != imag.shape:
            raise ValueError(
                "Input `x` should be a tuple of two tensors - real and "
                "imaginary. Both the real and imaginary parts should have the "
                f"same shape. Received: x[0].shape = {real.shape}, "
                f"x[1].shape = {imag.shape}"
            )
        if len(real.shape) < 2:
            raise ValueError(
                f"Input should have rank >= 2. "
                f"Received: input.shape = {real.shape}"
            )
        if real.shape[-2] is not None:
            # Overlap-add reconstruction size from the number of frames.
            output_size = (
                real.shape[-2] - 1
            ) * self.sequence_stride + self.fft_length
            if self.length is not None:
                # An explicit length wins over the inferred size.
                output_size = self.length
            elif self.center:
                # Undo the centering padding added by the forward STFT.
                output_size = output_size - (self.fft_length // 2) * 2
        else:
            # Unknown frame count -> unknown output length.
            output_size = None
        new_shape = real.shape[:-2] + (output_size,)
        return KerasTensor(shape=new_shape, dtype=real.dtype)

    def call(self, x):
        return backend.math.istft(
            x,
            sequence_length=self.sequence_length,
            sequence_stride=self.sequence_stride,
            fft_length=self.fft_length,
            length=self.length,
            window=self.window,
            center=self.center,
        )
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
@keras_export("keras.ops.istft")
def istft(
    x,
    sequence_length,
    sequence_stride,
    fft_length,
    length=None,
    window="hann",
    center=True,
):
    """Inverse Short-Time Fourier Transform along the last axis of the input.

    To reconstruct an original waveform, the parameters should be the same in
    `stft`.

    Args:
        x: Tuple of the real and imaginary parts of the input tensor. Both
            tensors in the tuple should be of floating type.
        sequence_length: An integer representing the sequence length.
        sequence_stride: An integer representing the sequence hop size.
        fft_length: An integer representing the size of the FFT that produced
            `stft`. Should be of type `int32`.
        length: An integer representing the output is clipped to exactly
            length. If not specified, no padding or clipping take place.
            Defaults to `None`.
        window: A string, a tensor of the window or `None`. If `window` is a
            string, available values are `"hann"` and `"hamming"`. If `window`
            is a tensor, it will be used directly as the window and its length
            must be `sequence_length`. If `window` is `None`, no windowing is
            used. Defaults to `"hann"`.
        center: Whether `x` was padded on both sides so that the t-th sequence
            is centered at time `t * sequence_stride`. Defaults to `True`.

    Returns:
        A tensor containing the inverse Short-Time Fourier Transform along the
        last axis of `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    >>> istft(stft(x, 1, 1, 1), 1, 1, 1)
    array([0.0, 1.0, 2.0, 3.0, 4.0])
    """
    if any_symbolic_tensors(x):
        return ISTFT(
            sequence_length=sequence_length,
            sequence_stride=sequence_stride,
            fft_length=fft_length,
            # Bug fix: `length` was previously not forwarded on the symbolic
            # path, so a user-provided `length` was silently ignored when
            # inferring the output shape (the eager path below honors it).
            length=length,
            window=window,
            center=center,
        ).symbolic_call(x)
    return backend.math.istft(
        x,
        sequence_length=sequence_length,
        sequence_stride=sequence_stride,
        fft_length=fft_length,
        length=length,
        window=window,
        center=center,
    )
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
class Rsqrt(Operation):
    """Symbolic op computing the element-wise reciprocal square root."""

    def call(self, x):
        x = backend.convert_to_tensor(x)
        return backend.math.rsqrt(x)

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(x.shape, dtype=x.dtype)
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
@keras_export("keras.ops.rsqrt")
def rsqrt(x):
    """Computes reciprocal of square root of x element-wise.

    Args:
        x: input tensor

    Returns:
        A tensor with the same dtype as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([1.0, 10.0, 100.0])
    >>> keras.ops.rsqrt(x)
    array([1.0, 0.31622776, 0.1], dtype=float32)
    """
    # Eager inputs are converted and handed to the backend; symbolic
    # inputs are routed through the op so output specs can be inferred.
    if not any_symbolic_tensors((x,)):
        return backend.math.rsqrt(backend.convert_to_tensor(x))
    return Rsqrt().symbolic_call(x)
|
| 961 |
+
|
| 962 |
+
|
| 963 |
+
class Erf(Operation):
    """Symbolic op computing the element-wise error function."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.math.erf(x)
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
@keras_export("keras.ops.erf")
def erf(x):
    """Computes the error function of `x`, element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same dtype as `x`.

    Example:

    >>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0])
    >>> keras.ops.erf(x)
    array([-0.99998 , -0.99532, -0.842701,  0.,  0.842701], dtype=float32)
    """
    # Eager inputs are converted and handed to the backend; symbolic
    # inputs are routed through the op so output specs can be inferred.
    if not any_symbolic_tensors((x,)):
        return backend.math.erf(backend.convert_to_tensor(x))
    return Erf().symbolic_call(x)
|
| 991 |
+
|
| 992 |
+
|
| 993 |
+
class Erfinv(Operation):
    """Symbolic op computing the element-wise inverse error function."""

    def compute_output_spec(self, x):
        # Element-wise op: shape and dtype are preserved.
        return KerasTensor(shape=x.shape, dtype=x.dtype)

    def call(self, x):
        return backend.math.erfinv(x)
|
| 999 |
+
|
| 1000 |
+
|
| 1001 |
+
@keras_export("keras.ops.erfinv")
def erfinv(x):
    """Computes the inverse error function of `x`, element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same dtype as `x`.

    Example:

    >>> x = np.array([-0.5, -0.2, -0.1, 0.0, 0.3])
    >>> keras.ops.erfinv(x)
    array([-0.47694, -0.17914, -0.08886,  0. ,  0.27246], dtype=float32)
    """
    # Eager inputs are converted and handed to the backend; symbolic
    # inputs are routed through the op so output specs can be inferred.
    if not any_symbolic_tensors((x,)):
        return backend.math.erfinv(backend.convert_to_tensor(x))
    return Erfinv().symbolic_call(x)
|
| 1021 |
+
|
| 1022 |
+
|
| 1023 |
+
class Logdet(Operation):
    """Symbolic op for the log-determinant of a positive definite matrix."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        return backend.math.logdet(x)

    def compute_output_spec(self, x):
        # The determinant collapses the trailing (matrix) dimensions,
        # leaving only any leading batch dimensions.
        return KerasTensor(x.shape[:-2], dtype=x.dtype)
|
| 1032 |
+
|
| 1033 |
+
|
| 1034 |
+
@keras_export(["keras.ops.logdet"])
def logdet(x):
    """Computes log of the determinant of a hermitian positive definite matrix.

    Args:
        x: Input matrix. It must be 2D and square.

    Returns:
        The natural log of the determinant of the matrix.
    """
    if any_symbolic_tensors((x,)):
        return Logdet().symbolic_call(x)
    return backend.math.logdet(x)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/nn.py
ADDED
|
@@ -0,0 +1,2653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Commonly-used neural network operations not included in NumPy."""
|
| 2 |
+
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
from keras.src import backend
|
| 6 |
+
from keras.src.api_export import keras_export
|
| 7 |
+
from keras.src.backend import KerasTensor
|
| 8 |
+
from keras.src.backend import any_symbolic_tensors
|
| 9 |
+
from keras.src.backend import standardize_data_format
|
| 10 |
+
from keras.src.backend.common.backend_utils import (
|
| 11 |
+
compute_conv_transpose_output_shape,
|
| 12 |
+
)
|
| 13 |
+
from keras.src.ops import operation_utils
|
| 14 |
+
from keras.src.ops.operation import Operation
|
| 15 |
+
from keras.src.ops.operation_utils import reduce_shape
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Relu(Operation):
    """Graph-mode op backing `keras.ops.relu`."""

    def call(self, x):
        return backend.nn.relu(x)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.relu", "keras.ops.nn.relu"])
def relu(x):
    """Rectified linear unit activation function.

    Computes `f(x) = max(0, x)` elementwise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x1 = keras.ops.convert_to_tensor([-1.0, 0.0, 1.0, 0.2])
    >>> keras.ops.relu(x1)
    array([0.0, 0.0, 1.0, 0.2], dtype=float32)
    """
    # Eager tensors take the direct backend path; symbolic inputs
    # record an op node in the graph instead.
    if not any_symbolic_tensors((x,)):
        return backend.nn.relu(x)
    return Relu().symbolic_call(x)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Relu6(Operation):
    """Graph-mode op backing `keras.ops.relu6`."""

    def call(self, x):
        return backend.nn.relu6(x)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.relu6", "keras.ops.nn.relu6"])
def relu6(x):
    """Rectified linear unit activation function capped at 6.

    Computes `f(x) = np.clip(x, 0, 6)` elementwise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-3.0, -2.0, 0.1, 0.2, 6.0, 8.0])
    >>> keras.ops.relu6(x)
    array([0.0, 0.0, 0.1, 0.2, 6.0, 6.0], dtype=float32)
    """
    # Eager inputs go straight to the backend kernel.
    if not any_symbolic_tensors((x,)):
        return backend.nn.relu6(x)
    return Relu6().symbolic_call(x)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class Sigmoid(Operation):
    """Graph-mode op backing `keras.ops.sigmoid`."""

    def call(self, x):
        return backend.nn.sigmoid(x)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.sigmoid", "keras.ops.nn.sigmoid"])
def sigmoid(x):
    """Sigmoid activation function.

    Computes `f(x) = 1 / (1 + exp(-x))` elementwise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-6.0, 1.0, 0.0, 1.0, 6.0])
    >>> keras.ops.sigmoid(x)
    array([0.00247262, 0.7310586, 0.5, 0.7310586, 0.9975274], dtype=float32)

    """
    # Symbolic inputs are deferred to a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.sigmoid(x)
    return Sigmoid().symbolic_call(x)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Softplus(Operation):
    """Graph-mode op backing `keras.ops.softplus`."""

    def call(self, x):
        return backend.nn.softplus(x)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.softplus", "keras.ops.nn.softplus"])
def softplus(x):
    """Softplus activation function.

    Computes `f(x) = log(exp(x) + 1)` elementwise, where `log` is the
    natural logarithm and `exp` is the exponential function.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-0.555, 0.0, 0.555])
    >>> keras.ops.softplus(x)
    array([0.45366603, 0.6931472, 1.008666], dtype=float32)

    """
    # Symbolic inputs are deferred to a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.softplus(x)
    return Softplus().symbolic_call(x)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Softsign(Operation):
    """Graph-mode op backing `keras.ops.softsign`."""

    def call(self, x):
        return backend.nn.softsign(x)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.softsign", "keras.ops.nn.softsign"])
def softsign(x):
    """Softsign activation function.

    Computes `f(x) = x / (abs(x) + 1)` elementwise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-0.100, -10.0, 1.0, 0.0, 100.0])
    >>> keras.ops.softsign(x)
    Array([-0.09090909, -0.90909094, 0.5, 0.0, 0.990099], dtype=float32)

    """
    # Symbolic inputs are deferred to a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.softsign(x)
    return Softsign().symbolic_call(x)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class SoftShrink(Operation):
    """Graph-mode op backing `keras.ops.soft_shrink`."""

    def __init__(self, threshold=0.5):
        super().__init__()
        # Shrinkage threshold applied symmetrically around zero.
        self.threshold = threshold

    def call(self, x):
        return backend.nn.soft_shrink(x, self.threshold)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.soft_shrink", "keras.ops.nn.soft_shrink"])
def soft_shrink(x, threshold=0.5):
    """Soft Shrink activation function.

    Computes, elementwise:

    `f(x) = x - threshold` if `x > threshold`,
    `f(x) = x + threshold` if `x < -threshold`,
    `f(x) = 0` otherwise.

    Args:
        x: Input tensor.
        threshold: Threshold value. Defaults to 0.5.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1.0, 0.0, 1.0])
    >>> x_soft_shrink = keras.ops.soft_shrink(x)
    >>> print(x_soft_shrink)
    array([-0.5  0.   0.5], shape=(3,), dtype=float64)

    """
    # Symbolic inputs build a parameterized op node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.soft_shrink(x, threshold)
    return SoftShrink(threshold=threshold).symbolic_call(x)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class SparsePlus(Operation):
    """Graph-mode op backing `keras.ops.sparse_plus`."""

    def call(self, x):
        return backend.nn.sparse_plus(x)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.sparse_plus", "keras.ops.nn.sparse_plus"])
def sparse_plus(x):
    """SparsePlus activation function.

    Computes, elementwise:

    `f(x) = 0` for `x <= -1`.
    `f(x) = (1/4) * (x + 1)^2` for `-1 < x < 1`.
    `f(x) = x` for `x >= 1`.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1.0, 0.0, 1.0])
    >>> x_sparse_plus = keras.ops.sparse_plus(x)
    >>> print(x_sparse_plus)
    Array([0.   0.25 1.  ], shape=(3,), dtype=float32)

    """
    # Symbolic inputs are deferred to a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.sparse_plus(x)
    return SparsePlus().symbolic_call(x)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class Silu(Operation):
    """Graph-mode op backing `keras.ops.silu` (a.k.a. swish)."""

    def call(self, x):
        return backend.nn.silu(x)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(
    [
        "keras.ops.silu",
        "keras.ops.nn.silu",
        "keras.ops.swish",
        "keras.ops.nn.swish",
    ]
)
def silu(x):
    """Sigmoid Linear Unit (SiLU) activation function, also known as Swish.

    The input is multiplied by its own sigmoid:
    `f(x) = x * sigmoid(x)`.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-6.0, 1.0, 0.0, 1.0, 6.0])
    >>> keras.ops.sigmoid(x)
    array([0.00247262, 0.7310586, 0.5, 0.7310586, 0.9975274], dtype=float32)
    >>> keras.ops.silu(x)
    array([-0.0148357, 0.7310586, 0.0, 0.7310586, 5.9851646], dtype=float32)

    """
    # Symbolic inputs are deferred to a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.silu(x)
    return Silu().symbolic_call(x)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class Squareplus(Operation):
    """Graph-mode op backing `keras.ops.squareplus`."""

    def __init__(self, b=4):
        super().__init__()
        # Smoothness hyperparameter of the squareplus curve.
        self.b = b

    def call(self, x):
        return backend.nn.squareplus(x, self.b)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.squareplus", "keras.ops.nn.squareplus"])
def squareplus(x, b=4):
    """Squareplus activation function.

    Computes, elementwise:

    `f(x) = (x + sqrt(x^2 + b)) / 2`

    Args:
        x: Input tensor.
        b: Smoothness parameter. Defaults to 4.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1.0, 0.0, 1.0])
    >>> x_squareplus = keras.ops.squareplus(x)
    >>> print(x_squareplus)
    array([0.6180, 1.0000, 1.6180], dtype=float32)

    """
    # Symbolic inputs build a parameterized op node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.squareplus(x, b)
    return Squareplus(b=b).symbolic_call(x)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class LogSigmoid(Operation):
    """Graph-mode op backing `keras.ops.log_sigmoid`."""

    def call(self, x):
        return backend.nn.log_sigmoid(x)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(
    [
        "keras.ops.log_sigmoid",
        "keras.ops.nn.log_sigmoid",
    ]
)
def log_sigmoid(x):
    """Logarithm of the sigmoid activation function.

    Computes `f(x) = log(1 / (1 + exp(-x)))` elementwise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-0.541391, 0.0, 0.50, 5.0])
    >>> keras.ops.log_sigmoid(x)
    array([-1.0000418, -0.6931472, -0.474077, -0.00671535], dtype=float32)

    """
    # Symbolic inputs are deferred to a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.log_sigmoid(x)
    return LogSigmoid().symbolic_call(x)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class LeakyRelu(Operation):
    """Graph-mode op backing `keras.ops.leaky_relu`."""

    def __init__(self, negative_slope=0.2):
        super().__init__()
        # Gradient used on the negative half of the domain.
        self.negative_slope = negative_slope

    def call(self, x):
        return backend.nn.leaky_relu(x, self.negative_slope)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.leaky_relu", "keras.ops.nn.leaky_relu"])
def leaky_relu(x, negative_slope=0.2):
    """Leaky version of a Rectified Linear Unit activation function.

    A small gradient is allowed when the unit is not active:

    `f(x) = alpha * x for x < 0` or `f(x) = x for x >= 0`.

    Args:
        x: Input tensor.
        negative_slope: Slope of the activation function at x < 0.
            Defaults to `0.2`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> x_leaky_relu = keras.ops.leaky_relu(x)
    >>> print(x_leaky_relu)
    array([-0.2,  0. ,  1. ], shape=(3,), dtype=float64)

    """
    # Symbolic inputs build a parameterized op node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.leaky_relu(x, negative_slope=negative_slope)
    return LeakyRelu(negative_slope=negative_slope).symbolic_call(x)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class HardSigmoid(Operation):
    """Graph-mode op backing `keras.ops.hard_sigmoid`."""

    def call(self, x):
        return backend.nn.hard_sigmoid(x)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(
    [
        "keras.ops.hard_sigmoid",
        "keras.ops.nn.hard_sigmoid",
    ]
)
def hard_sigmoid(x):
    """Hard sigmoid activation function.

    Computes, elementwise:

    `0 if x < -2.5`, `1 if x > 2.5`, `(0.2 * x) + 0.5 if -2.5 <= x <= 2.5`.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> x_hard_sigmoid = keras.ops.hard_sigmoid(x)
    >>> print(x_hard_sigmoid)
    array([0.3, 0.5, 0.7], shape=(3,), dtype=float64)

    """
    # Symbolic inputs are deferred to a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.hard_sigmoid(x)
    return HardSigmoid().symbolic_call(x)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
class HardSilu(Operation):
    """Graph-mode op backing `keras.ops.hard_silu` (a.k.a. hard swish)."""

    def call(self, x):
        return backend.nn.hard_silu(x)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(
    [
        "keras.ops.hard_silu",
        "keras.ops.nn.hard_silu",
        "keras.ops.hard_swish",
        "keras.ops.nn.hard_swish",
    ]
)
def hard_silu(x):
    """Hard SiLU activation function, also known as Hard Swish.

    Computes, elementwise:

    - `0` if `if x < -3`
    - `x` if `x > 3`
    - `x * (x + 3) / 6` if `-3 <= x <= 3`

    It's a faster, piecewise linear approximation of the silu activation.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = keras.ops.convert_to_tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
    >>> keras.ops.hard_silu(x)
    array([-0.0, -0.3333333, 0.0, 0.6666667, 3.0], shape=(5,), dtype=float32)

    """
    # Symbolic inputs are deferred to a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.hard_silu(x)
    return HardSilu().symbolic_call(x)
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
class Elu(Operation):
    """Graph-mode op backing `keras.ops.elu`."""

    def __init__(self, alpha=1.0):
        super().__init__()
        # Scale of the exponential branch for negative inputs.
        self.alpha = alpha

    def call(self, x):
        return backend.nn.elu(x, alpha=self.alpha)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.elu", "keras.ops.nn.elu"])
def elu(x, alpha=1.0):
    """Exponential Linear Unit activation function.

    Computes, elementwise:

    `f(x) = alpha * (exp(x) - 1.) for x < 0`, `f(x) = x for x >= 0`.

    Args:
        x: Input tensor.
        alpha: A scalar, slope of positive section. Defaults to `1.0`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> x_elu = keras.ops.elu(x)
    >>> print(x_elu)
    array([-0.63212055, 0., 1.], shape=(3,), dtype=float64)

    """
    # Symbolic inputs build a parameterized op node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.elu(x, alpha=alpha)
    return Elu(alpha=alpha).symbolic_call(x)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class Selu(Operation):
    """Graph-mode op backing `keras.ops.selu`."""

    def call(self, x):
        return backend.nn.selu(x)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.selu", "keras.ops.nn.selu"])
def selu(x):
    """Scaled Exponential Linear Unit (SELU) activation function.

    Computes, elementwise:

    `f(x) = scale * alpha * (exp(x) - 1.) for x < 0`,
    `f(x) = scale * x for x >= 0`.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> x_selu = keras.ops.selu(x)
    >>> print(x_selu)
    array([-1.11133055, 0., 1.05070098], shape=(3,), dtype=float64)

    """
    # Symbolic inputs are deferred to a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.selu(x)
    return Selu().symbolic_call(x)
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
class Gelu(Operation):
    """Graph-mode op backing `keras.ops.gelu`."""

    def __init__(self, approximate=True):
        super().__init__()
        # Whether to use the tanh-based approximation instead of erf.
        self.approximate = approximate

    def call(self, x):
        return backend.nn.gelu(x, self.approximate)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.gelu", "keras.ops.nn.gelu"])
def gelu(x, approximate=True):
    """Gaussian Error Linear Unit (GELU) activation function.

    If `approximate` is `True`, it is defined as:
    `f(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`

    Or if `approximate` is `False`, it is defined as:
    `f(x) = x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`,
    where `P(X) ~ N(0, 1)`.

    Args:
        x: Input tensor.
        approximate: Approximate version of GELU activation. Defaults to `True`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> x_gelu = keras.ops.gelu(x)
    >>> print(x_gelu)
    array([-0.15865525, 0., 0.84134475], shape=(3,), dtype=float64)

    """
    # Symbolic inputs build a parameterized op node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.gelu(x, approximate)
    return Gelu(approximate=approximate).symbolic_call(x)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class Celu(Operation):
    """Graph-mode op backing `keras.ops.celu`."""

    def __init__(self, alpha=1.0):
        super().__init__()
        # The alpha value of the CELU formulation.
        self.alpha = alpha

    def call(self, x):
        return backend.nn.celu(x, self.alpha)

    def compute_output_spec(self, x):
        # Elementwise op: shape and dtype pass through unchanged.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.celu", "keras.ops.nn.celu"])
def celu(x, alpha=1.0):
    """Continuously-differentiable exponential linear unit.

    Computes, elementwise:

    `f(x) = alpha * (exp(x / alpha) - 1) for x < 0`, `f(x) = x for x >= 0`.

    Args:
        x: Input tensor.
        alpha: the α value for the CELU formulation. Defaults to `1.0`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> x_celu = keras.ops.celu(x)
    >>> print(x_celu)
    array([-0.63212056, 0. , 1. ], shape=(3,), dtype=float64)

    """
    # Symbolic inputs build a parameterized op node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.celu(x, alpha)
    return Celu(alpha=alpha).symbolic_call(x)
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
class Glu(Operation):
    """Graph-mode op backing `keras.ops.glu`."""

    def __init__(self, axis=-1):
        super().__init__()
        # Axis along which the input is halved into value and gate.
        self.axis = axis

    def call(self, x):
        return backend.nn.glu(x, axis=self.axis)

    def compute_output_spec(self, x):
        # NOTE(review): the true output halves the size along `axis`;
        # this spec mirrors the eager backend's shape handling upstream.
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.glu", "keras.ops.nn.glu"])
def glu(x, axis=-1):
    """Gated Linear Unit (GLU) activation function.

    It is defined as:

    `f(x) = a * sigmoid(b)`
    where `x` is split into `a` and `b` along the given axis.

    Args:
        x: Input tensor.
        axis: The axis along which to split the input tensor. Defaults to `-1`.

    Returns:
        A tensor with the same shape as half of the input.

    Example:

    >>> x = np.array([-1., 0., 1. , 1.])
    >>> x_glu = keras.ops.glu(x)
    >>> print(x_glu)
    array([-0.73105858, 0. ], shape=(2,), dtype=float64)

    """
    # Symbolic inputs build a parameterized op node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.glu(x, axis=axis)
    return Glu(axis=axis).symbolic_call(x)
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
class TanhShrink(Operation):
    """Symbolic `Operation` wrapper for the tanh-shrink activation.

    Used by `tanh_shrink()` when any input is a symbolic Keras tensor.
    """

    def __init__(self):
        super().__init__()

    def call(self, x):
        return backend.nn.tanh_shrink(x)

    def compute_output_spec(self, x):
        # Elementwise op: output shape and dtype mirror the input.
        return KerasTensor(x.shape, dtype=x.dtype)
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
@keras_export(["keras.ops.tanh_shrink", "keras.ops.nn.tanh_shrink"])
def tanh_shrink(x):
    """Applies the tanh shrink function element-wise.

    It is defined as:

    `f(x) = x - tanh(x)`.

    Args:
        x: Input tensor.

    Returns:
        Output tensor of the same shape as `x`, where each element is
        transformed according to the tanh shrink operation.

    Example:

    >>> x = np.array([ -1., 0., 1.])
    >>> x_tanh_shrink = keras.ops.tanh_shrink(x)
    >>> print(x_tanh_shrink)
    array([-0.23840584 0. 0.23840584], shape=(3,), dtype=float64)

    """
    # Eager path first; symbolic tensors fall through to the graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.tanh_shrink(x)
    return TanhShrink().symbolic_call(x)
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
class HardTanh(Operation):
    """Symbolic `Operation` wrapper for the hard-tanh activation.

    Used by `hard_tanh()` when any input is a symbolic Keras tensor.
    """

    def __init__(self):
        super().__init__()

    def call(self, x):
        return backend.nn.hard_tanh(x)

    def compute_output_spec(self, x):
        # Elementwise op: output shape and dtype mirror the input.
        return KerasTensor(x.shape, dtype=x.dtype)
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
@keras_export(["keras.ops.hard_tanh", "keras.ops.nn.hard_tanh"])
def hard_tanh(x):
    """Applies the HardTanh function element-wise.

    It is defined as:

    `f(x) = -1 for x < -1`, `f(x) = x for -1 <= x <= 1`, `f(x) = 1 for x > 1`.

    Args:
        x: Input tensor.

    Returns:
        Output tensor of same shape as `x`
        where values are clamped between -1 and 1.

    Example:

    >>> x = np.array([-2., -1., 0., 1., 2.])
    >>> x_hard_tanh = keras.ops.hard_tanh(x)
    >>> print(x_hard_tanh)
    array([-1. -1. 0. 1. 1.], shape=(5,), dtype=float64)

    """
    # Eager path first; symbolic tensors fall through to the graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.hard_tanh(x)
    return HardTanh().symbolic_call(x)
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
class HardShrink(Operation):
    """Symbolic `Operation` wrapper for the hard-shrink activation.

    Used by `hard_shrink()` when any input is a symbolic Keras tensor.
    """

    def __init__(self, threshold=0.5):
        super().__init__()
        # Magnitude below which values are zeroed; see `hard_shrink()`.
        self.threshold = threshold

    def call(self, x):
        return backend.nn.hard_shrink(x, self.threshold)

    def compute_output_spec(self, x):
        # Elementwise op: output shape and dtype mirror the input.
        return KerasTensor(x.shape, dtype=x.dtype)
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
@keras_export(["keras.ops.hard_shrink", "keras.ops.nn.hard_shrink"])
def hard_shrink(x, threshold=0.5):
    """Hard Shrink activation function.

    The Hard Shrink function is a thresholding operation defined as:

    `f(x) = x` if `|x| > threshold`,
    `f(x) = 0` otherwise.

    Args:
        x: Input tensor.
        threshold: Threshold value. Defaults to 0.5.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-0.5, 0., 1.])
    >>> x_hard_shrink = keras.ops.hard_shrink(x)
    >>> print(x_hard_shrink)
    array([0. 0. 1.], shape=(3,), dtype=float64)

    """
    # Eager tensors run directly; symbolic tensors build a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.hard_shrink(x, threshold)
    return HardShrink(threshold).symbolic_call(x)
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
class Threshold(Operation):
    """Symbolic `Operation` wrapper for the threshold activation.

    Used by `threshold()` when any input is a symbolic Keras tensor.
    """

    def __init__(self, threshold_value, value):
        super().__init__()
        # Cutoff above which inputs are kept unchanged.
        self.threshold_value = threshold_value
        # Replacement assigned to inputs at or below the cutoff.
        self.value = value

    def call(self, x):
        return backend.nn.threshold(x, self.threshold_value, self.value)

    def compute_output_spec(self, x):
        # Elementwise op: output shape and dtype mirror the input.
        return KerasTensor(x.shape, dtype=x.dtype)
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
@keras_export(["keras.ops.threshold", "keras.ops.nn.threshold"])
def threshold(x, threshold, default_value):
    """Threshold activation function.

    The function thresholds the input `x` as follows:
    `f(x) = x` if `x > threshold`,
    `f(x) = default_value` otherwise.

    Args:
        x: Input tensor.
        threshold: The value that decides when to retain or replace x.
        default_value: Value to assign when `x <= threshold`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1.0, 0.0, 1.0, 2.0])
    >>> x_threshold = keras.ops.threshold(x, 1, 0)
    >>> print(x_threshold)
    array([0., 0., 0., 2.], shape=(4,), dtype=float64)

    """
    # Eager tensors run directly; symbolic tensors build a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.threshold(x, threshold, default_value)
    return Threshold(threshold, default_value).symbolic_call(x)
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
class Softmax(Operation):
    """Symbolic `Operation` wrapper for the softmax activation.

    Used by `softmax()` when any input is a symbolic Keras tensor.
    """

    def __init__(self, axis=-1):
        super().__init__()
        # Axis (or axes) the backend normalizes over.
        self.axis = axis

    def call(self, x):
        return backend.nn.softmax(x, axis=self.axis)

    def compute_output_spec(self, x):
        # Softmax preserves shape and dtype; only values change.
        return KerasTensor(x.shape, dtype=x.dtype)
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
@keras_export(["keras.ops.softmax", "keras.ops.nn.softmax"])
def softmax(x, axis=-1):
    """Softmax activation function.

    The elements of the output vector lie within the range `(0, 1)`, and their
    total sum is exactly 1 (excluding the floating point rounding error).

    Each vector is processed independently. The `axis` argument specifies the
    axis along which the function is applied within the input.

    It is defined as:
    `f(x) = exp(x) / sum(exp(x))`

    Args:
        x: Input tensor.
        axis: Integer, axis along which the softmax is applied.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> x_softmax = keras.ops.softmax(x)
    >>> print(x_softmax)
    array([0.09003057, 0.24472847, 0.66524096], shape=(3,), dtype=float64)

    """
    # Don't use `backend.shape` since TensorFlow returns
    # symbolic tensors for unknown shape which can trigger
    # an error in TensorFlow graph execution.
    if isinstance(axis, int) and x.shape[axis] == 1:
        # Softmax over a size-1 axis is identically 1 — almost certainly
        # a user error, so warn but still compute the result.
        warnings.warn(
            f"You are using a softmax over axis {axis} "
            f"of a tensor of shape {x.shape}. This axis "
            "has size 1. The softmax operation will always return "
            "the value 1, which is likely not what you intended. "
            "Did you mean to use a sigmoid instead?"
        )
    if any_symbolic_tensors((x,)):
        return Softmax(axis).symbolic_call(x)
    if isinstance(axis, tuple):
        # Multi-axis softmax: move the reduced axes to the end, flatten
        # them into a single trailing axis, softmax over it, then undo
        # the reshape and transpose to restore the original layout.
        axis_to_keep = [v for v in range(len(x.shape)) if v not in axis]

        x_transposed = backend.numpy.transpose(x, axes=(*axis_to_keep, *axis))
        x_reshaped = backend.numpy.reshape(
            x_transposed, (*[x.shape[v] for v in axis_to_keep], -1)
        )

        x = backend.nn.softmax(x_reshaped, axis=-1)

        x = backend.numpy.reshape(x, x_transposed.shape)
        # argsort of the forward permutation yields its inverse.
        x = backend.numpy.transpose(
            x, axes=list(backend.numpy.argsort([*axis_to_keep, *axis]))
        )
        return x
    else:
        return backend.nn.softmax(x, axis=axis)
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
class LogSoftmax(Operation):
    """Symbolic `Operation` wrapper for the log-softmax activation.

    Used by `log_softmax()` when any input is a symbolic Keras tensor.
    """

    def __init__(self, axis=-1):
        super().__init__()
        # Axis (or axes) the backend normalizes over.
        self.axis = axis

    def call(self, x):
        return backend.nn.log_softmax(x, axis=self.axis)

    def compute_output_spec(self, x):
        # Log-softmax preserves shape and dtype; only values change.
        return KerasTensor(x.shape, dtype=x.dtype)
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
@keras_export(
    [
        "keras.ops.log_softmax",
        "keras.ops.nn.log_softmax",
    ]
)
def log_softmax(x, axis=-1):
    """Log-softmax activation function.

    It is defined as:
    `f(x) = x - max(x) - log(sum(exp(x - max(x))))`

    Args:
        x: Input tensor.
        axis: Integer, axis along which the log-softmax is applied.
            Defaults to `-1`.

    Returns:
        A tensor with the same shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> x_log_softmax = keras.ops.log_softmax(x)
    >>> print(x_log_softmax)
    array([-2.40760596, -1.40760596, -0.40760596], shape=(3,), dtype=float64)

    """
    if any_symbolic_tensors((x,)):
        return LogSoftmax(axis).symbolic_call(x)
    if isinstance(axis, tuple):
        # Multi-axis log-softmax: same flatten-reduce-restore dance as
        # `softmax` — move reduced axes last, collapse them to one axis,
        # apply the op, then invert the reshape and transpose.
        axis_to_keep = [v for v in range(len(x.shape)) if v not in axis]

        x_transposed = backend.numpy.transpose(x, axes=(*axis_to_keep, *axis))
        x_reshaped = backend.numpy.reshape(
            x_transposed, (*[x.shape[v] for v in axis_to_keep], -1)
        )

        x = backend.nn.log_softmax(x_reshaped, axis=-1)

        x = backend.numpy.reshape(x, x_transposed.shape)
        # argsort of the forward permutation yields its inverse.
        x = backend.numpy.transpose(
            x, axes=list(backend.numpy.argsort([*axis_to_keep, *axis]))
        )
        return x
    else:
        return backend.nn.log_softmax(x, axis=axis)
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
class Sparsemax(Operation):
    """Symbolic `Operation` wrapper for the sparsemax activation.

    Used by `sparsemax()` when any input is a symbolic Keras tensor.
    """

    def __init__(self, axis=-1):
        super().__init__()
        # Axis the backend applies the sparsemax projection over.
        self.axis = axis

    def call(self, x):
        return backend.nn.sparsemax(x, axis=self.axis)

    def compute_output_spec(self, x):
        # Sparsemax preserves shape and dtype; only values change.
        return KerasTensor(x.shape, dtype=x.dtype)
|
| 1006 |
+
|
| 1007 |
+
|
| 1008 |
+
@keras_export(["keras.ops.sparsemax", "keras.ops.nn.sparsemax"])
def sparsemax(x, axis=-1):
    """Sparsemax activation function.

    For each batch `i`, and class `j`,
    sparsemax activation function is defined as:

    `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0).`

    Args:
        x: Input tensor.
        axis: `int`, axis along which the sparsemax operation is applied.

    Returns:
        A tensor, output of sparsemax transformation. Has the same type and
        shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> x_sparsemax = keras.ops.sparsemax(x)
    >>> print(x_sparsemax)
    array([0., 0., 1.], shape=(3,), dtype=float64)

    """
    # Eager tensors run directly; symbolic tensors build a graph node.
    if not any_symbolic_tensors((x,)):
        return backend.nn.sparsemax(x, axis=axis)
    return Sparsemax(axis).symbolic_call(x)
|
| 1036 |
+
|
| 1037 |
+
|
| 1038 |
+
class MaxPool(Operation):
    """Symbolic `Operation` wrapper for N-D max pooling.

    Used by `max_pool()` when any input is a symbolic Keras tensor.
    """

    def __init__(
        self,
        pool_size,
        strides=None,
        padding="valid",
        data_format=None,
    ):
        super().__init__()
        self.pool_size = pool_size
        self.strides = strides
        # Normalize so "SAME"/"VALID" spellings compare equal downstream.
        self.padding = padding.lower()
        self.data_format = data_format

    def call(self, inputs):
        return backend.nn.max_pool(
            inputs,
            self.pool_size,
            self.strides,
            self.padding,
            self.data_format,
        )

    def compute_output_spec(self, inputs):
        # Shape inference is delegated to the shared pooling-shape helper;
        # dtype is unchanged by pooling.
        output_shape = operation_utils.compute_pooling_output_shape(
            inputs.shape,
            self.pool_size,
            self.strides,
            self.padding,
            self.data_format,
        )
        return KerasTensor(output_shape, dtype=inputs.dtype)
|
| 1070 |
+
|
| 1071 |
+
|
| 1072 |
+
@keras_export(["keras.ops.max_pool", "keras.ops.nn.max_pool"])
def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format=None,
):
    """Max pooling operation.

    Args:
        inputs: Tensor of rank N+2. `inputs` has shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`. Pooling happens over the spatial
            dimensions only.
        pool_size: int or tuple/list of integers of size
            `len(inputs_spatial_shape)`, specifying the size of the pooling
            window for each spatial dimension of the input tensor. If
            `pool_size` is int, then every spatial dimension shares the same
            `pool_size`.
        strides: int or tuple/list of integers of size
            `len(inputs_spatial_shape)`. The stride of the sliding window for
            each spatial dimension of the input tensor. If `strides` is int,
            then every spatial dimension shares the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to the
            left/right or up/down of the input such that output has the
            same height/width dimension as the input when `strides=1`.
        data_format: A string, either `"channels_last"` or `"channels_first"`.
            `data_format` determines the ordering of the dimensions in the
            inputs. If `data_format="channels_last"`, `inputs` is of shape
            `(batch_size, ..., channels)` while if
            `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, ...)`.

    Returns:
        A tensor of rank N+2, the result of the max pooling operation.
    """
    # Normalize arguments once, before either dispatch path.
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    if not any_symbolic_tensors((inputs,)):
        # Eager path: call the backend directly.
        return backend.nn.max_pool(
            inputs, pool_size, strides, padding, data_format
        )
    # Symbolic path: build a graph node via the Operation wrapper.
    op = MaxPool(pool_size, strides, padding, data_format)
    return op.symbolic_call(inputs)
|
| 1122 |
+
|
| 1123 |
+
|
| 1124 |
+
class AveragePool(Operation):
    """Symbolic `Operation` wrapper for N-D average pooling.

    Used by `average_pool()` when any input is a symbolic Keras tensor.
    """

    def __init__(
        self,
        pool_size,
        strides=None,
        padding="valid",
        data_format=None,
    ):
        super().__init__()
        self.pool_size = pool_size
        self.strides = strides
        # Normalize so "SAME"/"VALID" spellings compare equal downstream.
        self.padding = padding.lower()
        self.data_format = data_format

    def call(self, inputs):
        return backend.nn.average_pool(
            inputs,
            self.pool_size,
            self.strides,
            self.padding,
            self.data_format,
        )

    def compute_output_spec(self, inputs):
        # Shape inference is delegated to the shared pooling-shape helper;
        # dtype is unchanged by pooling.
        output_shape = operation_utils.compute_pooling_output_shape(
            inputs.shape,
            self.pool_size,
            self.strides,
            self.padding,
            self.data_format,
        )
        return KerasTensor(output_shape, dtype=inputs.dtype)
|
| 1156 |
+
|
| 1157 |
+
|
| 1158 |
+
@keras_export(
    [
        "keras.ops.average_pool",
        "keras.ops.nn.average_pool",
    ]
)
def average_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format=None,
):
    """Average pooling operation.

    Args:
        inputs: Tensor of rank N+2. `inputs` has shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`. Pooling happens over the spatial
            dimensions only.
        pool_size: int or tuple/list of integers of size
            `len(inputs_spatial_shape)`, specifying the size of the pooling
            window for each spatial dimension of the input tensor. If
            `pool_size` is int, then every spatial dimension shares the same
            `pool_size`.
        strides: int or tuple/list of integers of size
            `len(inputs_spatial_shape)`. The stride of the sliding window for
            each spatial dimension of the input tensor. If `strides` is int,
            then every spatial dimension shares the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to the
            left/right or up/down of the input such that output has the
            same height/width dimension as the input when `strides=1`.
        data_format: A string, either `"channels_last"` or `"channels_first"`.
            `data_format` determines the ordering of the dimensions in the
            inputs. If `data_format="channels_last"`, `inputs` is of shape
            `(batch_size, ..., channels)` while if
            `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, ...)`.

    Returns:
        A tensor of rank N+2, the result of the average pooling operation.
    """
    # Normalize arguments once, before either dispatch path.
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    if not any_symbolic_tensors((inputs,)):
        # Eager path: call the backend directly.
        return backend.nn.average_pool(
            inputs, pool_size, strides, padding, data_format
        )
    # Symbolic path: build a graph node via the Operation wrapper.
    op = AveragePool(pool_size, strides, padding, data_format)
    return op.symbolic_call(inputs)
|
| 1215 |
+
|
| 1216 |
+
|
| 1217 |
+
class Conv(Operation):
    """Symbolic `Operation` wrapper for general N-D convolution.

    Used by `conv()` when any input is a symbolic Keras tensor.
    """

    def __init__(
        self,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
    ):
        super().__init__()
        self.strides = strides
        # Normalize so "SAME"/"VALID" spellings compare equal downstream.
        self.padding = padding.lower()
        self.data_format = data_format
        self.dilation_rate = dilation_rate

    def call(self, inputs, kernel):
        return backend.nn.conv(
            inputs,
            kernel,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )

    def compute_output_spec(self, inputs, kernel):
        # Kernel layout is `(*spatial_kernel_shape, in_channels,
        # out_channels)` (see the `conv()` docstring), so the last dim is
        # the output channel count and everything before the final two
        # dims is the spatial kernel shape.
        output_shape = operation_utils.compute_conv_output_shape(
            inputs.shape,
            kernel.shape[-1],
            kernel.shape[:-2],
            self.strides,
            self.padding,
            self.data_format,
            self.dilation_rate,
        )
        return KerasTensor(output_shape, dtype=inputs.dtype)
|
| 1252 |
+
|
| 1253 |
+
|
| 1254 |
+
@keras_export(["keras.ops.conv", "keras.ops.nn.conv"])
def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """General N-D convolution.

    This ops supports 1D, 2D and 3D convolution.

    Args:
        inputs: Tensor of rank N+2. `inputs` has shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`.
        kernel: Tensor of rank N+2. `kernel` has shape
            `(kernel_spatial_shape, num_input_channels, num_output_channels)`.
            `num_input_channels` should match the number of channels in
            `inputs`.
        strides: int or int tuple/list of `len(inputs_spatial_shape)`,
            specifying the strides of the convolution along each spatial
            dimension. If `strides` is int, then every spatial dimension shares
            the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to the
            left/right or up/down of the input such that output has the
            same height/width dimension as the input when `strides=1`.
        data_format: A string, either `"channels_last"` or `"channels_first"`.
            `data_format` determines the ordering of the dimensions in the
            inputs. If `data_format="channels_last"`, `inputs` is of shape
            `(batch_size, ..., channels)` while if
            `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, ...)`.
        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
            specifying the dilation rate to use for dilated convolution. If
            `dilation_rate` is int, then every spatial dimension shares
            the same `dilation_rate`.

    Returns:
        A tensor of rank N+2, the result of the conv operation.
    """
    # Normalize arguments once, before either dispatch path.
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    if not any_symbolic_tensors((inputs,)):
        # Eager path: call the backend directly.
        return backend.nn.conv(
            inputs, kernel, strides, padding, data_format, dilation_rate
        )
    # Symbolic path: build a graph node via the Operation wrapper.
    op = Conv(strides, padding, data_format, dilation_rate)
    return op.symbolic_call(inputs, kernel)
|
| 1308 |
+
|
| 1309 |
+
|
| 1310 |
+
class DepthwiseConv(Operation):
    """Symbolic `Operation` wrapper for N-D depthwise convolution.

    Used by `depthwise_conv()` when any input is a symbolic Keras tensor.
    """

    def __init__(
        self,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
    ):
        super().__init__()
        self.strides = strides
        # Normalize so "SAME"/"VALID" spellings compare equal downstream.
        self.padding = padding.lower()
        self.data_format = data_format
        self.dilation_rate = dilation_rate

    def call(self, inputs, kernel):
        return backend.nn.depthwise_conv(
            inputs,
            kernel,
            self.strides,
            self.padding,
            self.data_format,
            self.dilation_rate,
        )

    def compute_output_spec(self, inputs, kernel):
        # Kernel layout is `(*spatial, in_channels, channel_multiplier)`
        # (see `depthwise_conv()`), so output channels =
        # in_channels * channel_multiplier.
        output_shape = operation_utils.compute_conv_output_shape(
            inputs.shape,
            kernel.shape[-1] * kernel.shape[-2],
            kernel.shape[:-2],
            self.strides,
            self.padding,
            self.data_format,
            self.dilation_rate,
        )
        return KerasTensor(output_shape, dtype=inputs.dtype)
|
| 1345 |
+
|
| 1346 |
+
|
| 1347 |
+
@keras_export(
    [
        "keras.ops.depthwise_conv",
        "keras.ops.nn.depthwise_conv",
    ]
)
def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """General N-D depthwise convolution.

    This ops supports 1D and 2D depthwise convolution.

    Args:
        inputs: Tensor of rank N+2. `inputs` has shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`.
        kernel: Tensor of rank N+2. `kernel` has shape
            [kernel_spatial_shape, num_input_channels, num_channels_multiplier],
            `num_input_channels` should match the number of channels in
            `inputs`.
        strides: int or int tuple/list of `len(inputs_spatial_shape)`,
            specifying the strides of the convolution along each spatial
            dimension. If `strides` is int, then every spatial dimension shares
            the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to the
            left/right or up/down of the input such that output has the
            same height/width dimension as the input when `strides=1`.
        data_format: A string, either `"channels_last"` or `"channels_first"`.
            `data_format` determines the ordering of the dimensions in the
            inputs. If `data_format="channels_last"`, `inputs` is of shape
            `(batch_size, ..., channels)` while if
            `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, ...)`.
        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
            specifying the dilation rate to use for dilated convolution. If
            `dilation_rate` is int, then every spatial dimension shares
            the same `dilation_rate`.

    Returns:
        A tensor of rank N+2, the result of the depthwise conv operation.
    """
    # Normalize arguments once, before either dispatch path.
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    if not any_symbolic_tensors((inputs,)):
        # Eager path: call the backend directly.
        return backend.nn.depthwise_conv(
            inputs,
            kernel,
            strides,
            padding,
            data_format,
            dilation_rate,
        )
    # Symbolic path: build a graph node via the Operation wrapper.
    op = DepthwiseConv(strides, padding, data_format, dilation_rate)
    return op.symbolic_call(inputs, kernel)
|
| 1411 |
+
|
| 1412 |
+
|
| 1413 |
+
class SeparableConv(Operation):
    """Symbolic `Operation` wrapper for N-D separable convolution.

    A separable conv is a depthwise conv followed by a 1x1 pointwise
    conv. Used by `separable_conv()` when any input is symbolic.
    """

    def __init__(
        self,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
    ):
        super().__init__()
        self.strides = strides
        # Normalize so "SAME"/"VALID" spellings compare equal downstream.
        self.padding = padding.lower()
        self.data_format = data_format
        self.dilation_rate = dilation_rate

    def call(self, inputs, depthwise_kernel, pointwise_kernel):
        return backend.nn.separable_conv(
            inputs,
            depthwise_kernel,
            pointwise_kernel,
            self.strides,
            self.padding,
            self.data_format,
            self.dilation_rate,
        )

    def compute_output_spec(self, inputs, depthwise_kernel, pointwise_kernel):
        # Infer the spatial dims from the depthwise stage (via the
        # module-level `depthwise_conv`), then replace the channel dim
        # with the pointwise kernel's output-channel count.
        output_shape = list(
            depthwise_conv(
                inputs,
                depthwise_kernel,
                self.strides,
                self.padding,
                self.data_format,
                self.dilation_rate,
            ).shape
        )
        # NOTE(review): assumes `self.data_format` was standardized before
        # construction (as `separable_conv()` does); a raw `None` here
        # would fall into the channels_first branch — confirm intended.
        if self.data_format == "channels_last":
            output_shape[-1] = pointwise_kernel.shape[-1]
        else:
            output_shape[1] = pointwise_kernel.shape[-1]
        return KerasTensor(output_shape, dtype=inputs.dtype)
|
| 1454 |
+
|
| 1455 |
+
|
| 1456 |
+
@keras_export(
    [
        "keras.ops.separable_conv",
        "keras.ops.nn.separable_conv",
    ]
)
def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """General N-D separable convolution.

    This ops supports 1D and 2D separable convolution. `separable_conv` is
    a depthwise conv followed by a pointwise conv.

    Args:
        inputs: Tensor of rank N+2. `inputs` has shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`.
        depthwise_kernel: Tensor of rank N+2. `depthwise_kernel` has shape
            [kernel_spatial_shape, num_input_channels, num_channels_multiplier],
            `num_input_channels` should match the number of channels in
            `inputs`.
        pointwise_kernel: Tensor of rank N+2. `pointwise_kernel` has shape
            `(*ones_like(kernel_spatial_shape),
            num_input_channels * num_channels_multiplier, num_output_channels)`.
        strides: int or int tuple/list of `len(inputs_spatial_shape)`,
            specifying the strides of the convolution along each spatial
            dimension. If `strides` is int, then every spatial dimension shares
            the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to the
            left/right or up/down of the input such that output has the
            same height/width dimension as the input when `strides=1`.
        data_format: A string, either `"channels_last"` or `"channels_first"`.
            `data_format` determines the ordering of the dimensions in the
            inputs. If `data_format="channels_last"`, `inputs` is of shape
            `(batch_size, ..., channels)` while if
            `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, ...)`.
        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
            specifying the dilation rate to use for dilated convolution. If
            `dilation_rate` is int, then every spatial dimension shares
            the same `dilation_rate`.

    Returns:
        A tensor of rank N+2, the result of the depthwise conv operation.
    """
    # Normalize arguments once, before either dispatch path.
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    if not any_symbolic_tensors((inputs,)):
        # Eager path: call the backend directly.
        return backend.nn.separable_conv(
            inputs,
            depthwise_kernel,
            pointwise_kernel,
            strides,
            padding,
            data_format,
            dilation_rate,
        )
    # Symbolic path: build a graph node via the Operation wrapper.
    op = SeparableConv(strides, padding, data_format, dilation_rate)
    return op.symbolic_call(inputs, depthwise_kernel, pointwise_kernel)
|
| 1529 |
+
|
| 1530 |
+
|
| 1531 |
+
class ConvTranspose(Operation):
    """Symbolic op wrapping `backend.nn.conv_transpose` for graph tracing."""

    def __init__(
        self,
        strides,
        padding="valid",
        output_padding=None,
        data_format=None,
        dilation_rate=1,
    ):
        super().__init__()
        self.strides = strides
        self.output_padding = output_padding
        # Padding mode is normalized to lowercase once, at construction.
        self.padding = padding.lower()
        self.data_format = data_format
        self.dilation_rate = dilation_rate

    def call(self, inputs, kernel):
        return backend.nn.conv_transpose(
            inputs,
            kernel,
            self.strides,
            self.output_padding,
            self.padding,
            self.data_format,
            self.dilation_rate,
        )

    def compute_output_spec(self, inputs, kernel):
        # Kernel layout: spatial dims + (output_channels, input_channels).
        spatial_kernel_shape = kernel.shape[:-2]
        num_filters = kernel.shape[-2]
        result_shape = compute_conv_transpose_output_shape(
            inputs.shape,
            spatial_kernel_shape,
            num_filters,
            self.strides,
            self.padding,
            self.output_padding,
            self.data_format,
            self.dilation_rate,
        )
        return KerasTensor(result_shape, dtype=inputs.dtype)
|
| 1576 |
+
|
| 1577 |
+
|
| 1578 |
+
@keras_export(
    [
        "keras.ops.conv_transpose",
        "keras.ops.nn.conv_transpose",
    ]
)
def conv_transpose(
    inputs,
    kernel,
    strides,
    padding="valid",
    output_padding=None,
    data_format=None,
    dilation_rate=1,
):
    """General N-D convolution transpose.

    Also known as de-convolution. This ops supports 1D, 2D and 3D convolution.

    Args:
        inputs: Tensor of rank N+2. `inputs` has shape
            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if
            `data_format="channels_last"`, or
            `(batch_size, num_channels) + inputs_spatial_shape` if
            `data_format="channels_first"`.
        kernel: Tensor of rank N+2. `kernel` has shape
            [kernel_spatial_shape, num_output_channels, num_input_channels],
            `num_input_channels` should match the number of channels in
            `inputs`.
        strides: int or int tuple/list of `len(inputs_spatial_shape)`,
            specifying the strides of the convolution along each spatial
            dimension. If `strides` is int, then every spatial dimension shares
            the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to the
            left/right or up/down of the input such that output has the
            same height/width dimension as the input when `strides=1`.
        output_padding: int or int tuple/list of `len(inputs_spatial_shape)`,
            specifying the amount of padding along the height and width of
            the output tensor. Can be a single integer to specify the same
            value for all spatial dimensions. The amount of output padding
            along a given dimension must be lower than the stride along that
            same dimension. If set to `None` (default), the output shape is
            inferred.
        data_format: A string, either `"channels_last"` or `"channels_first"`.
            `data_format` determines the ordering of the dimensions in the
            inputs. If `data_format="channels_last"`, `inputs` is of shape
            `(batch_size, ..., channels)` while if
            `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, ...)`.
        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
            specifying the dilation rate to use for dilated convolution. If
            `dilation_rate` is int, then every spatial dimension shares
            the same `dilation_rate`.

    Returns:
        A tensor of rank N+2, the result of the conv operation.
    """
    data_format = standardize_data_format(data_format)
    padding = padding.lower()
    # When tracing a functional graph, defer to the symbolic op so output
    # shape/dtype can be inferred without running the backend kernel.
    if any_symbolic_tensors((inputs,)):
        return ConvTranspose(
            strides, padding, output_padding, data_format, dilation_rate
        ).symbolic_call(inputs, kernel)
    # Eager path: dispatch straight to the backend implementation.
    return backend.nn.conv_transpose(
        inputs,
        kernel,
        strides,
        padding,
        output_padding,
        data_format,
        dilation_rate,
    )
|
| 1651 |
+
|
| 1652 |
+
|
| 1653 |
+
class OneHot(Operation):
    """Symbolic op wrapping `backend.nn.one_hot` for graph tracing."""

    def __init__(self, num_classes, axis=-1, dtype=None, sparse=False):
        super().__init__()
        self.num_classes = num_classes
        self.axis = axis
        # Fall back to the backend's default float type when unspecified.
        self.dtype = dtype or backend.floatx()
        self.sparse = sparse

    def call(self, x):
        return backend.nn.one_hot(
            x,
            self.num_classes,
            axis=self.axis,
            dtype=self.dtype,
            sparse=self.sparse,
        )

    def compute_output_spec(self, x):
        # Encoding inserts a `num_classes`-sized dimension at `axis`.
        output_shape = list(getattr(x, "shape", []))
        if self.axis == -1:
            output_shape.append(self.num_classes)
        elif 0 <= self.axis < len(output_shape):
            output_shape.insert(self.axis, self.num_classes)
        else:
            raise ValueError(
                f"axis must be -1 or between [0, {len(x.shape)}), but "
                f"received {self.axis}."
            )
        return KerasTensor(output_shape, dtype=self.dtype, sparse=self.sparse)
|
| 1682 |
+
|
| 1683 |
+
|
| 1684 |
+
@keras_export(["keras.ops.one_hot", "keras.ops.nn.one_hot"])
def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
    """Converts integer tensor `x` into a one-hot tensor.

    The one-hot encoding is a representation where each integer value is
    converted into a binary vector with a length equal to `num_classes`,
    and the index corresponding to the integer value is marked as 1, while
    all other indices are marked as 0.

    Args:
        x: Integer tensor to be encoded. The shape can be
            arbitrary, but the dtype should be integer.
        num_classes: Number of classes for the one-hot encoding.
        axis: Axis along which the encoding is performed.
            `-1` represents the last axis. Defaults to `-1`.
        dtype: (Optional) Data type of the output tensor. If not
            provided, it defaults to the default data type of the backend.
        sparse: Whether to return a sparse tensor; for backends that support
            sparse tensors.

    Returns:
        Integer tensor: One-hot encoded tensor with the same shape as `x`
        except for the specified `axis` dimension, which will have
        a length of `num_classes`. The dtype of the output tensor
        is determined by `dtype` or the default data type of the backend.

    Example:

    >>> x = keras.ops.convert_to_tensor([1, 3, 2, 0])
    >>> one_hot(x, num_classes=4)
    array([[0. 1. 0. 0.]
        [0. 0. 0. 1.]
        [0. 0. 1. 0.]
        [1. 0. 0. 0.]], shape=(4, 4), dtype=float32)
    """
    # When tracing a functional graph, defer to the symbolic op.
    if any_symbolic_tensors((x,)):
        return OneHot(
            num_classes, axis=axis, dtype=dtype, sparse=sparse
        ).symbolic_call(x)
    return backend.nn.one_hot(
        x,
        num_classes,
        axis=axis,
        # Resolve the default dtype eagerly before calling the backend.
        dtype=dtype or backend.floatx(),
        sparse=sparse,
    )
|
| 1730 |
+
|
| 1731 |
+
|
| 1732 |
+
class BinaryCrossentropy(Operation):
    """Symbolic op wrapping `backend.nn.binary_crossentropy`."""

    def __init__(self, from_logits=False):
        super().__init__()
        self.from_logits = from_logits

    def call(self, target, output):
        return backend.nn.binary_crossentropy(
            target, output, from_logits=self.from_logits
        )

    def compute_output_spec(self, target, output):
        # The loss is elementwise, so the two shapes must agree exactly
        # and the output keeps the full input shape.
        if target.shape != output.shape:
            msg = (
                "Arguments `target` and `output` must have the same shape. "
                "Received: "
                f"target.shape={target.shape}, output.shape={output.shape}"
            )
            raise ValueError(msg)
        return KerasTensor(output.shape, dtype=output.dtype)
|
| 1750 |
+
|
| 1751 |
+
|
| 1752 |
+
@keras_export(
    [
        "keras.ops.binary_crossentropy",
        "keras.ops.nn.binary_crossentropy",
    ]
)
def binary_crossentropy(target, output, from_logits=False):
    """Computes binary cross-entropy loss between target and output tensor.

    The binary cross-entropy loss is commonly used in binary
    classification tasks where each input sample belongs to one
    of the two classes. It measures the dissimilarity between the
    target and output probabilities or logits.

    Args:
        target: The target tensor representing the true binary labels.
            Its shape should match the shape of the `output` tensor.
        output: The output tensor representing the predicted probabilities
            or logits. Its shape should match the shape of the
            `target` tensor.
        from_logits: (optional) Whether `output` is a tensor of logits or
            probabilities.
            Set it to `True` if `output` represents logits; otherwise,
            set it to `False` if `output` represents probabilities.
            Defaults to `False`.

    Returns:
        Integer tensor: The computed binary cross-entropy loss between
        `target` and `output`.

    Example:

    >>> target = keras.ops.convert_to_tensor([0, 1, 1, 0])
    >>> output = keras.ops.convert_to_tensor([0.1, 0.9, 0.8, 0.2])
    >>> binary_crossentropy(target, output)
    array([0.10536054 0.10536054 0.22314355 0.22314355],
        shape=(4,), dtype=float32)
    """
    # When tracing a functional graph, defer to the symbolic op.
    if any_symbolic_tensors((target, output)):
        return BinaryCrossentropy(from_logits=from_logits).symbolic_call(
            target, output
        )
    return backend.nn.binary_crossentropy(
        target, output, from_logits=from_logits
    )
|
| 1797 |
+
|
| 1798 |
+
|
| 1799 |
+
class CategoricalCrossentropy(Operation):
    """Symbolic op wrapping `backend.nn.categorical_crossentropy`."""

    def __init__(self, from_logits=False, axis=-1):
        super().__init__()
        self.from_logits = from_logits
        self.axis = axis

    def call(self, target, output):
        return backend.nn.categorical_crossentropy(
            target, output, from_logits=self.from_logits, axis=self.axis
        )

    def compute_output_spec(self, target, output):
        # One-hot targets must line up elementwise with the predictions.
        shapes_differ = target.shape != output.shape
        if shapes_differ:
            raise ValueError(
                "Arguments `target` and `output` must have the same shape. "
                "Received: "
                f"target.shape={target.shape}, output.shape={output.shape}"
            )
        if len(target.shape) < 1:
            raise ValueError(
                "Arguments `target` and `output` must be at least rank 1. "
                "Received: "
                f"target.shape={target.shape}, output.shape={output.shape}"
            )
        # The class dimension is reduced away by the loss.
        return KerasTensor(output.shape[:-1], dtype=output.dtype)
|
| 1824 |
+
|
| 1825 |
+
|
| 1826 |
+
@keras_export(
    [
        "keras.ops.categorical_crossentropy",
        "keras.ops.nn.categorical_crossentropy",
    ]
)
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Computes categorical cross-entropy loss between target and output tensor.

    The categorical cross-entropy loss is commonly used in multi-class
    classification tasks where each input sample can belong to one of
    multiple classes. It measures the dissimilarity
    between the target and output probabilities or logits.

    Args:
        target: The target tensor representing the true categorical labels
            (e.g. one-hot encoded). Its shape should match the shape of the
            `output` tensor.
        output: The output tensor representing the predicted probabilities
            or logits. Its shape should match the shape of the `target`
            tensor.
        from_logits: (optional) Whether `output` is a tensor of logits or
            probabilities.
            Set it to `True` if `output` represents logits; otherwise,
            set it to `False` if `output` represents probabilities.
            Defaults to `False`.
        axis: (optional) The axis along which the categorical cross-entropy
            is computed.
            Defaults to `-1`, which corresponds to the last dimension of
            the tensors.

    Returns:
        Integer tensor: The computed categorical cross-entropy loss between
        `target` and `output`.

    Example:

    >>> target = keras.ops.convert_to_tensor(
    ...     [[1, 0, 0],
    ...      [0, 1, 0],
    ...      [0, 0, 1]])
    >>> output = keras.ops.convert_to_tensor(
    ...     [[0.9, 0.05, 0.05],
    ...      [0.1, 0.8, 0.1],
    ...      [0.2, 0.3, 0.5]])
    >>> categorical_crossentropy(target, output)
    array([0.10536054 0.22314355 0.6931472 ], shape=(3,), dtype=float32)
    """
    # When tracing a functional graph, defer to the symbolic op.
    if any_symbolic_tensors((target, output)):
        return CategoricalCrossentropy(
            from_logits=from_logits, axis=axis
        ).symbolic_call(target, output)
    return backend.nn.categorical_crossentropy(
        target, output, from_logits=from_logits, axis=axis
    )
|
| 1881 |
+
|
| 1882 |
+
|
| 1883 |
+
class SparseCategoricalCrossentropy(Operation):
    """Symbolic op wrapping `backend.nn.sparse_categorical_crossentropy`."""

    def __init__(self, from_logits=False, axis=-1):
        super().__init__()
        self.from_logits = from_logits
        self.axis = axis

    def call(self, target, output):
        return backend.nn.sparse_categorical_crossentropy(
            target, output, from_logits=self.from_logits, axis=self.axis
        )

    def compute_output_spec(self, target, output):
        if len(output.shape) < 1:
            raise ValueError(
                "Argument `output` must be at least rank 1. "
                "Received: "
                f"output.shape={output.shape}"
            )
        # Allow integer labels carrying a trailing singleton dimension
        # (e.g. shape (batch, 1) against output (batch, num_classes)).
        labels_shape = target.shape
        has_trailing_one = (
            len(labels_shape) == len(output.shape) and labels_shape[-1] == 1
        )
        if has_trailing_one:
            labels_shape = labels_shape[:-1]
        if labels_shape != output.shape[:-1]:
            raise ValueError(
                "Arguments `target` and `output` must have the same shape "
                "up until the last dimension: "
                f"target.shape={target.shape}, output.shape={output.shape}"
            )
        # The class dimension of `output` is reduced away by the loss.
        return KerasTensor(output.shape[:-1], dtype=output.dtype)
|
| 1911 |
+
|
| 1912 |
+
|
| 1913 |
+
@keras_export(
    [
        "keras.ops.sparse_categorical_crossentropy",
        "keras.ops.nn.sparse_categorical_crossentropy",
    ]
)
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Computes sparse categorical cross-entropy loss.

    The sparse categorical cross-entropy loss is similar to categorical
    cross-entropy, but it is used when the target tensor contains integer
    class labels instead of one-hot encoded vectors. It measures the
    dissimilarity between the target and output probabilities or logits.

    Args:
        target: The target tensor representing the true class labels as
            integers. Its shape should match the shape of the `output`
            tensor except for the last dimension.
        output: The output tensor representing the predicted probabilities
            or logits.
            Its shape should match the shape of the `target` tensor except
            for the last dimension.
        from_logits: (optional) Whether `output` is a tensor of logits
            or probabilities.
            Set it to `True` if `output` represents logits; otherwise,
            set it to `False` if `output` represents probabilities.
            Defaults to `False`.
        axis: (optional) The axis along which the sparse categorical
            cross-entropy is computed.
            Defaults to `-1`, which corresponds to the last dimension
            of the tensors.

    Returns:
        Integer tensor: The computed sparse categorical cross-entropy
        loss between `target` and `output`.

    Example:

    >>> target = keras.ops.convert_to_tensor([0, 1, 2], dtype="int32")
    >>> output = keras.ops.convert_to_tensor(
    ...     [[0.9, 0.05, 0.05],
    ...      [0.1, 0.8, 0.1],
    ...      [0.2, 0.3, 0.5]])
    >>> sparse_categorical_crossentropy(target, output)
    array([0.10536056 0.22314355 0.6931472 ], shape=(3,), dtype=float32)
    """
    # When tracing a functional graph, defer to the symbolic op.
    if any_symbolic_tensors((target, output)):
        return SparseCategoricalCrossentropy(
            from_logits=from_logits, axis=axis
        ).symbolic_call(target, output)
    return backend.nn.sparse_categorical_crossentropy(
        target, output, from_logits=from_logits, axis=axis
    )
|
| 1966 |
+
|
| 1967 |
+
|
| 1968 |
+
class MultiHot(Operation):
    """Symbolic op wrapping `backend.nn.multi_hot` for graph tracing."""

    def __init__(
        self, num_classes=None, axis=-1, dtype=None, sparse=False, **kwargs
    ):
        # Accept the legacy `num_tokens` alias for `num_classes`.
        if num_classes is None and "num_tokens" in kwargs:
            num_classes = kwargs.pop("num_tokens")
        if num_classes is None:
            raise ValueError("Argument `num_classes` must be specified.")
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.axis = axis
        # Fall back to the backend's default float type when unspecified.
        self.dtype = dtype or backend.floatx()
        self.sparse = sparse

    def call(self, inputs):
        return backend.nn.multi_hot(
            inputs,
            num_classes=self.num_classes,
            axis=self.axis,
            dtype=self.dtype,
            # Bug fix: `sparse` was previously not forwarded, so the op
            # silently ignored the flag while the functional `multi_hot`
            # passed it through to the backend.
            sparse=self.sparse,
        )

    def compute_output_spec(self, inputs):
        # Encoding inserts a `num_classes`-sized dimension at `axis`...
        x_shape = list(getattr(inputs, "shape", []))
        if self.axis == -1:
            x_shape.append(self.num_classes)
        elif self.axis >= 0 and self.axis < len(x_shape):
            x_shape.insert(self.axis, self.num_classes)
        else:
            raise ValueError(
                f"axis must be -1 or between [0, {len(inputs.shape)}), but "
                f"received {self.axis}."
            )

        # ...then the per-sample dimension collapses: multi-hot aggregates
        # the labels of each sample into a single vector.
        if len(x_shape) == 2:
            x_shape = [x_shape[-1]]
        else:
            x_shape = [x_shape[0]] + x_shape[2:]

        # NOTE(review): this returns `inputs.dtype` while `OneHot` returns
        # `self.dtype` — possibly intentional; left unchanged, confirm.
        return KerasTensor(x_shape, dtype=inputs.dtype, sparse=self.sparse)
|
| 2008 |
+
|
| 2009 |
+
|
| 2010 |
+
@keras_export(
    [
        "keras.ops.multi_hot",
        "keras.ops.nn.multi_hot",
    ]
)
def multi_hot(
    inputs, num_classes=None, axis=-1, dtype=None, sparse=False, **kwargs
):
    """Encodes integer labels as multi-hot vectors.

    This function encodes integer labels as multi-hot vectors, where each label
    is mapped to a binary value in the resulting vector.

    Args:
        inputs: Tensor of integer labels to be converted to multi-hot vectors.
        num_classes: Integer, the total number of unique classes.
            May also be passed via the legacy `num_tokens` keyword.
        axis: (optional) Axis along which the multi-hot encoding should be
            added. Defaults to `-1`, which corresponds to the last dimension.
        dtype: (optional) The data type of the resulting tensor. Default
            is backend's float type.
        sparse: Whether to return a sparse tensor; for backends that support
            sparse tensors.

    Returns:
        Tensor: The multi-hot encoded tensor.

    Raises:
        ValueError: If neither `num_classes` nor `num_tokens` is given.

    Example:

    >>> data = keras.ops.convert_to_tensor([0, 4])
    >>> keras.ops.multi_hot(data, num_classes=5)
    array([1.0, 0.0, 0.0, 0.0, 1.0], dtype=float32)

    """
    # Accept the legacy `num_tokens` alias for `num_classes`.
    if num_classes is None and "num_tokens" in kwargs:
        num_classes = kwargs.pop("num_tokens")
    if num_classes is None:
        raise ValueError("Argument `num_classes` must be specified.")

    # When tracing a functional graph, defer to the symbolic op.
    if any_symbolic_tensors((inputs,)):
        return MultiHot(num_classes, axis, dtype, sparse).symbolic_call(inputs)

    return backend.nn.multi_hot(inputs, num_classes, axis, dtype, sparse)
|
| 2053 |
+
|
| 2054 |
+
|
| 2055 |
+
class Moments(Operation):
    """Symbolic op wrapping `backend.nn.moments`."""

    def __init__(self, axes, keepdims=False, synchronized=False):
        super().__init__()
        self.axes = axes
        self.keepdims = keepdims
        self.synchronized = synchronized

    def call(self, x):
        return backend.nn.moments(
            x,
            axes=self.axes,
            keepdims=self.keepdims,
            synchronized=self.synchronized,
        )

    def compute_output_spec(self, x):
        # Mean and variance share the same reduced shape and dtype.
        reduced = reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims)
        mean_spec = KerasTensor(reduced, dtype=x.dtype)
        variance_spec = KerasTensor(reduced, dtype=x.dtype)
        return (mean_spec, variance_spec)
|
| 2081 |
+
|
| 2082 |
+
|
| 2083 |
+
@keras_export(
    [
        "keras.ops.moments",
        "keras.ops.nn.moments",
    ]
)
def moments(x, axes, keepdims=False, synchronized=False):
    """Calculates the mean and variance of `x`.

    The mean and variance are calculated by aggregating the contents of `x`
    across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean and
    variance of a vector.

    Args:
        x: Input tensor.
        axes: A list of axes which to compute mean and variance.
        keepdims: If this is set to `True`, the axes which are reduced are left
            in the result as dimensions with size one.
        synchronized: Only applicable with the TensorFlow backend.
            If `True`, synchronizes the global batch statistics (mean and
            variance) across all devices at each training step in a
            distributed training strategy. If `False`, each replica uses its own
            local batch statistics.

    Returns:
        A tuple containing two tensors - mean and variance.

    Example:

    >>> x = keras.ops.convert_to_tensor([0, 1, 2, 3, 100], dtype="float32")
    >>> keras.ops.moments(x, axes=[0])
    (array(21.2, dtype=float32), array(1553.3601, dtype=float32))

    """
    # When tracing a functional graph, defer to the symbolic op.
    if any_symbolic_tensors((x,)):
        return Moments(axes, keepdims, synchronized=synchronized).symbolic_call(
            x
        )

    return backend.nn.moments(x, axes, keepdims, synchronized=synchronized)
|
| 2123 |
+
|
| 2124 |
+
|
| 2125 |
+
class BatchNorm(Operation):
    """Symbolic op for `batch_normalization` shape inference and checking."""

    def __init__(self, axis, epsilon):
        super().__init__()
        self.axis = axis
        self.epsilon = epsilon

    def _check_shape(self, name, shape, expected_shape):
        """Raise `ValueError` unless `shape` equals `(x.shape[axis],)`."""
        if shape != expected_shape:
            raise ValueError(
                f"Arguments `{name}` must be a vector of length "
                f"`x.shape[axis]`. Expected: `{expected_shape}`. "
                # Bug fix: the closing backtick after {shape} was missing.
                f"Received: `{shape}`."
            )

    def compute_output_spec(self, x, mean, variance, offset, scale):
        # Every statistics vector must match the normalized axis length.
        shape = (x.shape[self.axis],)
        self._check_shape("mean", tuple(mean.shape), shape)
        self._check_shape("variance", tuple(variance.shape), shape)
        if offset is not None:
            self._check_shape("offset", tuple(offset.shape), shape)
        # Bug fix: this was `if offset is not scale:`, which crashed with
        # AttributeError (`None.shape`) whenever `offset` was provided but
        # `scale` was None, and only skipped the check by accident when
        # both were None.
        if scale is not None:
            self._check_shape("scale", tuple(scale.shape), shape)
        # Normalization is elementwise: output shape/dtype match the input.
        return KerasTensor(x.shape, dtype=x.dtype)
|
| 2148 |
+
|
| 2149 |
+
|
| 2150 |
+
@keras_export(
    [
        "keras.ops.batch_normalization",
        "keras.ops.nn.batch_normalization",
    ]
)
def batch_normalization(
    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
):
    """Normalizes `x` by `mean` and `variance`.

    This op is typically used by the batch normalization step in a neural
    network. It normalizes the input tensor along the given axis.

    Args:
        x: Input tensor.
        mean: A mean vector of the same length as the `axis` dimension of the
            input tensor.
        variance: A variance vector of the same length as the `axis` dimension
            of the input tensor.
        axis: Integer, the axis that should be normalized.
        offset: An offset vector of the same length as the `axis` dimension of
            the input tensor. If not `None`, `offset` is added to the normalized
            tensor. Defaults to `None`.
        scale: A scale vector of the same length as the `axis` dimension of the
            input tensor. If not `None`, the normalized tensor is multiplied by
            `scale`. Defaults to `None`.
        epsilon: Small float added to variance to avoid dividing by zero.
            Defaults to 1e-3.

    Returns:
        The normalized tensor.

    Example:

    >>> x = keras.ops.convert_to_tensor(
    ...     [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
    ... )
    >>> keras.ops.batch_normalization(
    ...     x,
    ...     mean=[0.4, 0.5, 0.6],
    ...     variance=[0.67, 0.67, 0.67],
    ...     axis=-1
    ... )
    array([[-3.6624e-01, -3.6624e-01, -3.6624e-01],
           [-4.6445e-09,  0.0000e+00, -1.8578e-08],
           [ 3.6624e-01,  3.6624e-01,  3.6624e-01]])

    """
    # When tracing a functional graph, defer to the symbolic op.
    if any_symbolic_tensors((x, mean, variance, offset, scale)):
        return BatchNorm(axis, epsilon).symbolic_call(
            x, mean, variance, offset, scale
        )

    return backend.nn.batch_normalization(
        x, mean, variance, axis, offset, scale, epsilon
    )
|
| 2207 |
+
|
| 2208 |
+
|
| 2209 |
+
class CTCLoss(Operation):
    """Symbolic op wrapping `backend.nn.ctc_loss`."""

    def __init__(self, mask_index=0):
        super().__init__()
        self.mask_index = mask_index

    def call(self, target, output, target_length, output_length):
        return backend.nn.ctc_loss(
            target, output, target_length, output_length, self.mask_index
        )

    def _check_shape_first_dim(self, name1, shape1, name2, shape2):
        # The batch (first) dimensions of the two arguments must agree.
        first_dims_match = shape1[0] == shape2[0]
        if not first_dims_match:
            raise ValueError(
                f"Arguments `{name1}` and `{name2}` must have the same "
                "first dimension. "
                f"Received shapes: `{shape1}` and `{shape2}`."
            )

    def compute_output_spec(self, target, output, target_length, output_length):
        self._check_shape_first_dim(
            "target", target.shape, "output", output.shape
        )
        self._check_shape_first_dim(
            "target_length", target_length.shape, "target", target.shape
        )
        self._check_shape_first_dim(
            "output_length", output_length.shape, "output", output.shape
        )
        # One scalar loss per batch element, promoted to at least float32.
        dtype = backend.result_type(output.dtype, "float32")
        return KerasTensor((target.shape[0],), dtype=dtype)
|
| 2239 |
+
|
| 2240 |
+
|
| 2241 |
+
@keras_export(
    [
        "keras.ops.ctc_loss",
        "keras.ops.nn.ctc_loss",
    ]
)
def ctc_loss(target, output, target_length, output_length, mask_index=0):
    """CTC (Connectionist Temporal Classification) loss.

    Args:
        target: A tensor of shape `(batch_size, max_length)` containing
            the true labels in integer format.
        output: A tensor of shape `(batch_size, max_length, num_classes)`
            containing logits (the output of your model).
        target_length: A tensor of shape `(batch_size,)` containing the
            true label lengths.
        output_length: A tensor of shape `(batch_size,)` containing the
            output lengths.
        mask_index: The index of the mask character in the vocabulary.
            Defaults to `0`.

    Returns:
        A tensor of shape `(batch_size,)` containing the CTC loss for
        each batch element.
    """

    # When tracing a functional graph, defer to the symbolic op.
    if any_symbolic_tensors((target, output, target_length, output_length)):
        return CTCLoss(mask_index).symbolic_call(
            target, output, target_length, output_length
        )
    return backend.nn.ctc_loss(
        target, output, target_length, output_length, mask_index
    )
|
| 2270 |
+
|
| 2271 |
+
|
| 2272 |
+
class CTCDecode(Operation):
    """Symbolic op wrapping `backend.nn.ctc_decode`."""

    def __init__(
        self,
        strategy="greedy",
        beam_width=100,
        top_paths=1,
        merge_repeated=True,
        mask_index=0,
    ):
        super().__init__()
        self.strategy = strategy
        self.beam_width = beam_width
        self.top_paths = top_paths
        self.merge_repeated = merge_repeated
        self.mask_index = mask_index

    def call(self, inputs, sequence_lengths):
        return backend.nn.ctc_decode(
            inputs,
            sequence_lengths,
            strategy=self.strategy,
            beam_width=self.beam_width,
            top_paths=self.top_paths,
            merge_repeated=self.merge_repeated,
            mask_index=self.mask_index,
        )

    def compute_output_spec(self, inputs, sequence_lengths):
        inputs_shape = inputs.shape
        # Greedy decoding always yields exactly one path, regardless of
        # the configured `top_paths`.
        top_paths = 1 if self.strategy == "greedy" else self.top_paths
        dtype = backend.result_type(inputs.dtype, "float32")
        decoded_spec = KerasTensor(
            (top_paths, inputs_shape[0], inputs_shape[1]), dtype="int32"
        )
        scores_spec = KerasTensor((inputs_shape[0], top_paths), dtype=dtype)
        return (decoded_spec, scores_spec)
|
| 2312 |
+
|
| 2313 |
+
|
| 2314 |
+
@keras_export(
|
| 2315 |
+
[
|
| 2316 |
+
"keras.ops.ctc_decode",
|
| 2317 |
+
"keras.ops.nn.ctc_decode",
|
| 2318 |
+
]
|
| 2319 |
+
)
|
| 2320 |
+
def ctc_decode(
|
| 2321 |
+
inputs,
|
| 2322 |
+
sequence_lengths,
|
| 2323 |
+
strategy="greedy",
|
| 2324 |
+
beam_width=100,
|
| 2325 |
+
top_paths=1,
|
| 2326 |
+
merge_repeated=True,
|
| 2327 |
+
mask_index=0,
|
| 2328 |
+
):
|
| 2329 |
+
"""Decodes the output of a CTC model.
|
| 2330 |
+
|
| 2331 |
+
Args:
|
| 2332 |
+
inputs: A tensor of shape `(batch_size, max_length, num_classes)`
|
| 2333 |
+
containing the logits (the output of the model).
|
| 2334 |
+
They should *not* be normalized via softmax.
|
| 2335 |
+
sequence_lengths: A tensor of shape `(batch_size,)` containing the
|
| 2336 |
+
sequence lengths for the batch.
|
| 2337 |
+
strategy: A string for the decoding strategy. Supported values are
|
| 2338 |
+
`"greedy"` and `"beam_search"`.
|
| 2339 |
+
beam_width: An integer scalar beam width used in beam search.
|
| 2340 |
+
Defaults to 100.
|
| 2341 |
+
top_paths: An integer scalar, the number of top paths to return.
|
| 2342 |
+
Defaults to 1.
|
| 2343 |
+
merge_repeated: A boolean scalar, whether to merge repeated
|
| 2344 |
+
labels in the output. Defaults to `True`.
|
| 2345 |
+
mask_index: An integer scalar, the index of the mask character in
|
| 2346 |
+
the vocabulary. Defaults to `0`.
|
| 2347 |
+
|
| 2348 |
+
Returns:
|
| 2349 |
+
A tuple containing:
|
| 2350 |
+
- The tensor representing the list of decoded sequences. If
|
| 2351 |
+
`strategy="greedy"`, the shape is `(1, batch_size, max_length)`. If
|
| 2352 |
+
`strategy="beam_search"`, the shape is
|
| 2353 |
+
`(top_paths, batch_size, max_length)`. Note that: `-1` indicates the
|
| 2354 |
+
blank label.
|
| 2355 |
+
- If `strategy="greedy"`, a tensor of shape `(batch_size, 1)`
|
| 2356 |
+
representing the negative of the sum of the probability logits for
|
| 2357 |
+
each sequence. If `strategy="beam_seatch"`, a tensor of shape
|
| 2358 |
+
`(batch_size, top_paths)` representing the log probability for each
|
| 2359 |
+
sequence.
|
| 2360 |
+
"""
|
| 2361 |
+
|
| 2362 |
+
if any_symbolic_tensors((inputs, sequence_lengths)):
|
| 2363 |
+
return CTCDecode(
|
| 2364 |
+
strategy=strategy,
|
| 2365 |
+
beam_width=beam_width,
|
| 2366 |
+
top_paths=top_paths,
|
| 2367 |
+
merge_repeated=merge_repeated,
|
| 2368 |
+
mask_index=mask_index,
|
| 2369 |
+
).symbolic_call(inputs, sequence_lengths)
|
| 2370 |
+
return backend.nn.ctc_decode(
|
| 2371 |
+
inputs=inputs,
|
| 2372 |
+
sequence_lengths=sequence_lengths,
|
| 2373 |
+
strategy=strategy,
|
| 2374 |
+
beam_width=beam_width,
|
| 2375 |
+
top_paths=top_paths,
|
| 2376 |
+
merge_repeated=merge_repeated,
|
| 2377 |
+
mask_index=mask_index,
|
| 2378 |
+
)
|
| 2379 |
+
|
| 2380 |
+
|
| 2381 |
+
class Normalize(Operation):
|
| 2382 |
+
def __init__(self, axis=-1, order=2, epsilon=None):
|
| 2383 |
+
super().__init__()
|
| 2384 |
+
self.axis = axis
|
| 2385 |
+
self.order = order
|
| 2386 |
+
self.epsilon = epsilon
|
| 2387 |
+
|
| 2388 |
+
def compute_output_spec(self, x):
|
| 2389 |
+
return KerasTensor(shape=x.shape)
|
| 2390 |
+
|
| 2391 |
+
def call(self, x):
|
| 2392 |
+
return _normalize(
|
| 2393 |
+
x, axis=self.axis, order=self.order, epsilon=self.epsilon
|
| 2394 |
+
)
|
| 2395 |
+
|
| 2396 |
+
|
| 2397 |
+
@keras_export(
|
| 2398 |
+
[
|
| 2399 |
+
"keras.ops.normalize",
|
| 2400 |
+
"keras.ops.nn.normalize",
|
| 2401 |
+
]
|
| 2402 |
+
)
|
| 2403 |
+
def normalize(x, axis=-1, order=2, epsilon=None):
|
| 2404 |
+
"""Normalizes `x` over the specified axis.
|
| 2405 |
+
|
| 2406 |
+
It is defined as: `normalize(x) = x / max(norm(x), epsilon)`.
|
| 2407 |
+
|
| 2408 |
+
Args:
|
| 2409 |
+
x: Input tensor.
|
| 2410 |
+
axis: The axis or axes along which to perform normalization.
|
| 2411 |
+
Default to -1.
|
| 2412 |
+
order: The exponent value in the norm formulation.
|
| 2413 |
+
Defaults to 2.
|
| 2414 |
+
epsilon: A lower bound value for the norm.
|
| 2415 |
+
Defaults to `backend.epsilon()`.
|
| 2416 |
+
|
| 2417 |
+
Returns:
|
| 2418 |
+
The normalized array.
|
| 2419 |
+
|
| 2420 |
+
Example:
|
| 2421 |
+
|
| 2422 |
+
>>> x = keras.ops.convert_to_tensor([[1, 2, 3], [4, 5, 6]])
|
| 2423 |
+
>>> x_norm = keras.ops.math.normalize(x)
|
| 2424 |
+
>>> print(x_norm)
|
| 2425 |
+
array([[0.26726124 0.5345225 0.8017837 ]
|
| 2426 |
+
[0.45584232 0.5698029 0.68376344]], shape=(2, 3), dtype=float32)
|
| 2427 |
+
|
| 2428 |
+
"""
|
| 2429 |
+
if any_symbolic_tensors((x,)):
|
| 2430 |
+
return Normalize(axis=axis, order=order, epsilon=epsilon).symbolic_call(
|
| 2431 |
+
x
|
| 2432 |
+
)
|
| 2433 |
+
return _normalize(x, axis=axis, order=order, epsilon=epsilon)
|
| 2434 |
+
|
| 2435 |
+
|
| 2436 |
+
def _normalize(x, axis=-1, order=2, epsilon=None):
|
| 2437 |
+
if not isinstance(order, int) or not order >= 1:
|
| 2438 |
+
raise ValueError(
|
| 2439 |
+
f"Argument `order` must be an int >= 1. Received: order={order}"
|
| 2440 |
+
)
|
| 2441 |
+
x = backend.convert_to_tensor(x)
|
| 2442 |
+
if len(x.shape) == 0:
|
| 2443 |
+
x = backend.numpy.expand_dims(x, axis=0)
|
| 2444 |
+
if epsilon is None:
|
| 2445 |
+
epsilon = backend.epsilon()
|
| 2446 |
+
if 2 == order:
|
| 2447 |
+
# A special case: L2 normalization with `x * rsqrt(...)`
|
| 2448 |
+
# instead of `x / sqrt(...)`
|
| 2449 |
+
square_sum = backend.numpy.sum(
|
| 2450 |
+
backend.numpy.square(x), axis=axis, keepdims=True
|
| 2451 |
+
)
|
| 2452 |
+
inv_norm = backend.math.rsqrt(square_sum)
|
| 2453 |
+
inv_norm = backend.numpy.minimum(inv_norm, 1.0 / epsilon)
|
| 2454 |
+
return x * inv_norm
|
| 2455 |
+
norm = backend.linalg.norm(x, ord=order, axis=axis, keepdims=True)
|
| 2456 |
+
denom = backend.numpy.maximum(norm, epsilon)
|
| 2457 |
+
return backend.numpy.divide(x, denom)
|
| 2458 |
+
|
| 2459 |
+
|
| 2460 |
+
class PSNR(Operation):
|
| 2461 |
+
def __init__(
|
| 2462 |
+
self,
|
| 2463 |
+
max_val,
|
| 2464 |
+
):
|
| 2465 |
+
super().__init__()
|
| 2466 |
+
self.max_val = max_val
|
| 2467 |
+
|
| 2468 |
+
def call(self, x1, x2):
|
| 2469 |
+
return backend.nn.psnr(
|
| 2470 |
+
x1=x1,
|
| 2471 |
+
x2=x2,
|
| 2472 |
+
max_val=self.max_val,
|
| 2473 |
+
)
|
| 2474 |
+
|
| 2475 |
+
def compute_output_spec(self, x1, x2):
|
| 2476 |
+
if len(x1.shape) != len(x2.shape):
|
| 2477 |
+
raise ValueError("Inputs must have the same rank")
|
| 2478 |
+
|
| 2479 |
+
return KerasTensor(shape=())
|
| 2480 |
+
|
| 2481 |
+
|
| 2482 |
+
@keras_export(
|
| 2483 |
+
[
|
| 2484 |
+
"keras.ops.psnr",
|
| 2485 |
+
"keras.ops.nn.psnr",
|
| 2486 |
+
]
|
| 2487 |
+
)
|
| 2488 |
+
def psnr(
|
| 2489 |
+
x1,
|
| 2490 |
+
x2,
|
| 2491 |
+
max_val,
|
| 2492 |
+
):
|
| 2493 |
+
"""Peak Signal-to-Noise Ratio (PSNR) function.
|
| 2494 |
+
|
| 2495 |
+
This function computes the Peak Signal-to-Noise Ratio between two signals,
|
| 2496 |
+
`x1` and `x2`. PSNR is a measure of the quality of a reconstructed signal.
|
| 2497 |
+
The higher the PSNR, the closer the reconstructed signal is to the original
|
| 2498 |
+
signal. Note that it can become negative when the signal power is
|
| 2499 |
+
smaller that the noise power.
|
| 2500 |
+
|
| 2501 |
+
Args:
|
| 2502 |
+
x1: The first input signal.
|
| 2503 |
+
x2: The second input signal. Must have the same shape as `x1`.
|
| 2504 |
+
max_val: The maximum possible value in the signals.
|
| 2505 |
+
|
| 2506 |
+
Returns:
|
| 2507 |
+
float: The PSNR value between `x1` and `x2`.
|
| 2508 |
+
|
| 2509 |
+
Examples:
|
| 2510 |
+
|
| 2511 |
+
>>> x1 = keras.random.normal((2, 4, 4, 3))
|
| 2512 |
+
>>> x2 = keras.random.normal((2, 4, 4, 3))
|
| 2513 |
+
>>> max_val = 1.0
|
| 2514 |
+
>>> keras.ops.nn.psnr(x1, x2, max_val)
|
| 2515 |
+
-3.1697404
|
| 2516 |
+
"""
|
| 2517 |
+
if any_symbolic_tensors(
|
| 2518 |
+
(
|
| 2519 |
+
x1,
|
| 2520 |
+
x2,
|
| 2521 |
+
)
|
| 2522 |
+
):
|
| 2523 |
+
return PSNR(
|
| 2524 |
+
max_val,
|
| 2525 |
+
).symbolic_call(x1, x2)
|
| 2526 |
+
return backend.nn.psnr(
|
| 2527 |
+
x1,
|
| 2528 |
+
x2,
|
| 2529 |
+
max_val,
|
| 2530 |
+
)
|
| 2531 |
+
|
| 2532 |
+
|
| 2533 |
+
class DotProductAttention(Operation):
|
| 2534 |
+
def __init__(self, is_causal=False):
|
| 2535 |
+
super().__init__()
|
| 2536 |
+
self.is_causal = is_causal
|
| 2537 |
+
|
| 2538 |
+
def call(
|
| 2539 |
+
self,
|
| 2540 |
+
query,
|
| 2541 |
+
key,
|
| 2542 |
+
value,
|
| 2543 |
+
bias=None,
|
| 2544 |
+
mask=None,
|
| 2545 |
+
scale=None,
|
| 2546 |
+
flash_attention=None,
|
| 2547 |
+
):
|
| 2548 |
+
return backend.nn.dot_product_attention(
|
| 2549 |
+
query,
|
| 2550 |
+
key,
|
| 2551 |
+
value,
|
| 2552 |
+
bias=bias,
|
| 2553 |
+
mask=mask,
|
| 2554 |
+
scale=scale,
|
| 2555 |
+
is_causal=self.is_causal,
|
| 2556 |
+
flash_attention=flash_attention,
|
| 2557 |
+
)
|
| 2558 |
+
|
| 2559 |
+
def compute_output_spec(
|
| 2560 |
+
self,
|
| 2561 |
+
query,
|
| 2562 |
+
key,
|
| 2563 |
+
value,
|
| 2564 |
+
bias=None,
|
| 2565 |
+
mask=None,
|
| 2566 |
+
scale=None,
|
| 2567 |
+
flash_attention=None,
|
| 2568 |
+
):
|
| 2569 |
+
return KerasTensor(query.shape, dtype=query.dtype)
|
| 2570 |
+
|
| 2571 |
+
|
| 2572 |
+
@keras_export(
|
| 2573 |
+
["keras.ops.dot_product_attention", "keras.ops.nn.dot_product_attention"]
|
| 2574 |
+
)
|
| 2575 |
+
def dot_product_attention(
|
| 2576 |
+
query,
|
| 2577 |
+
key,
|
| 2578 |
+
value,
|
| 2579 |
+
bias=None,
|
| 2580 |
+
mask=None,
|
| 2581 |
+
scale=None,
|
| 2582 |
+
is_causal=False,
|
| 2583 |
+
flash_attention=None,
|
| 2584 |
+
):
|
| 2585 |
+
"""Scaled dot product attention function.
|
| 2586 |
+
|
| 2587 |
+
Computes the attention function on Q (`query`), K (`key`), and V(`value`):
|
| 2588 |
+
`attention(Q, K, V) = softmax(Q * K / sqrt(d)) * V`. If we define `logits`
|
| 2589 |
+
as the output of `Q * K` and the `probs` as the output of `softmax`.
|
| 2590 |
+
|
| 2591 |
+
Throughout this function, we utilize the following notation to represent the
|
| 2592 |
+
shape of array:
|
| 2593 |
+
- B: batch size
|
| 2594 |
+
- S: length of the key/value
|
| 2595 |
+
- T: length of the query
|
| 2596 |
+
- N: number of attention heads
|
| 2597 |
+
- H: dimensions of each attention head
|
| 2598 |
+
- K: number of key/value heads
|
| 2599 |
+
- G: number of groups, which equals to `N // K`
|
| 2600 |
+
|
| 2601 |
+
Args:
|
| 2602 |
+
query: The query array with the shape of `(B, T, N, H)`.
|
| 2603 |
+
key: The key array with the shape of `(B, S, K, H)`. When `K` equals
|
| 2604 |
+
`N`, multi-headed attention (MHA) is performed. Otherwise, grouped
|
| 2605 |
+
query attention (GQA) is performed if `N` is a multiple of `K`. and
|
| 2606 |
+
multi-query attention (MQA) is performed if `K==1` (a special case
|
| 2607 |
+
of GQA).
|
| 2608 |
+
value: The value array with the same shape of `key`.
|
| 2609 |
+
bias: Optional bias array to be added to logits. The shape must be
|
| 2610 |
+
broadcastable to `(B, N, T, S)`.
|
| 2611 |
+
mask: Optional mask array used to filter out logits. It is a boolean
|
| 2612 |
+
mask where `True` indicates the element should take part in
|
| 2613 |
+
attention. For an additive mask, users should pass it to bias. The
|
| 2614 |
+
shape must be broadcastable to `(B, N, T, S)`.
|
| 2615 |
+
scale: Optional scale for the logits. If `None`, the scale will be set
|
| 2616 |
+
to `1.0 / sqrt(H)`.
|
| 2617 |
+
is_causal: Whether to apply causal mask.
|
| 2618 |
+
flash_attention: Whether to use flash attention. If `None`, it will
|
| 2619 |
+
attempt to use flash attention if the required conditions are met.
|
| 2620 |
+
Typically, the inputs must be in float16 and bfloat16 dtype and the
|
| 2621 |
+
input layout requirements may vary depending on the backend.
|
| 2622 |
+
|
| 2623 |
+
Returns:
|
| 2624 |
+
An array of the attention output with the same shape of `query`.
|
| 2625 |
+
|
| 2626 |
+
Example:
|
| 2627 |
+
|
| 2628 |
+
>>> query = keras.random.normal((2, 4, 8, 16))
|
| 2629 |
+
>>> key = keras.random.normal((2, 6, 8, 16))
|
| 2630 |
+
>>> value = keras.random.normal((2, 6, 8, 16))
|
| 2631 |
+
>>> keras.ops.nn.dot_product_attention(query, key, value).shape
|
| 2632 |
+
(2, 4, 8, 16)
|
| 2633 |
+
"""
|
| 2634 |
+
if any_symbolic_tensors((query, key, value)):
|
| 2635 |
+
return DotProductAttention(is_causal=is_causal).symbolic_call(
|
| 2636 |
+
query,
|
| 2637 |
+
key,
|
| 2638 |
+
value,
|
| 2639 |
+
bias=bias,
|
| 2640 |
+
mask=mask,
|
| 2641 |
+
scale=scale,
|
| 2642 |
+
flash_attention=flash_attention,
|
| 2643 |
+
)
|
| 2644 |
+
return backend.nn.dot_product_attention(
|
| 2645 |
+
query,
|
| 2646 |
+
key,
|
| 2647 |
+
value,
|
| 2648 |
+
bias=bias,
|
| 2649 |
+
mask=mask,
|
| 2650 |
+
scale=scale,
|
| 2651 |
+
is_causal=is_causal,
|
| 2652 |
+
flash_attention=flash_attention,
|
| 2653 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/node.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
|
| 3 |
+
from keras.src import tree
|
| 4 |
+
from keras.src.backend import KerasTensor
|
| 5 |
+
from keras.src.ops.symbolic_arguments import SymbolicArguments
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Node:
|
| 9 |
+
"""A `Node` describes an operation `__call__()` event.
|
| 10 |
+
|
| 11 |
+
A Keras Function is a DAG with `Node` instances as nodes, and
|
| 12 |
+
`KerasTensor` instances as edges. Nodes aren't `Operation` instances,
|
| 13 |
+
because a single operation could be called multiple times, which would
|
| 14 |
+
result in graph cycles.
|
| 15 |
+
|
| 16 |
+
A `__call__()` event involves input tensors (and other input arguments),
|
| 17 |
+
the operation that was called, and the resulting output tensors.
|
| 18 |
+
A `Node` will include all this information.
|
| 19 |
+
|
| 20 |
+
Since a single `Operation` could be called multiple times,
|
| 21 |
+
the `Node` instances are stored on operations as a list.
|
| 22 |
+
Each time an operation is called, a node is added to `op._inbound_nodes`.
|
| 23 |
+
Each time the output of an operation is used by another operation,
|
| 24 |
+
a node is added to `op._outbound_nodes`.
|
| 25 |
+
|
| 26 |
+
Every `KerasTensor` instance has a `KerasHistory` object attached,
|
| 27 |
+
which tracks the `Node` that records the `__call__()` event that created
|
| 28 |
+
the tensor. By recursively walking through `Node` instances
|
| 29 |
+
via the `KerasHistory` metadata of `KerasTensor` instances, once can
|
| 30 |
+
retrieve the entire DAG of a Keras Function.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
operation: The Operation that was called in the `op.__call__()`
|
| 34 |
+
event that this node represents.
|
| 35 |
+
call_args: The positional arguments the operation was called with.
|
| 36 |
+
call_kwargs: The keyword arguments the operation was called with.
|
| 37 |
+
outputs: The output tensors of the `op.__call__()` call.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self, operation, call_args=None, call_kwargs=None, outputs=None
|
| 42 |
+
):
|
| 43 |
+
self.operation = operation
|
| 44 |
+
self.arguments = SymbolicArguments(*call_args, **call_kwargs)
|
| 45 |
+
self.outputs = [] if outputs is None else tree.flatten(outputs)
|
| 46 |
+
for x in self.outputs:
|
| 47 |
+
if not isinstance(x, KerasTensor):
|
| 48 |
+
raise ValueError(
|
| 49 |
+
"All operation outputs must be tensors. "
|
| 50 |
+
f"Operation {operation} returned a non-tensor. "
|
| 51 |
+
f"Non-tensor received: {x}"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
zero_history = any(
|
| 55 |
+
not x.record_history for x in self.arguments.keras_tensors
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# If inputs don't have metadata yet, add it.
|
| 59 |
+
if not zero_history:
|
| 60 |
+
for tensor in self.arguments.keras_tensors:
|
| 61 |
+
if not hasattr(tensor, "_keras_history"):
|
| 62 |
+
tensor._keras_history = KerasHistory(
|
| 63 |
+
operation=None, node_index=0, tensor_index=0
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# Wire up Node to Operations.
|
| 67 |
+
self.operation._inbound_nodes.append(self)
|
| 68 |
+
for kt in self.arguments.keras_tensors:
|
| 69 |
+
inbound_op = kt._keras_history.operation
|
| 70 |
+
if inbound_op is not None: # It's a graph entry point.
|
| 71 |
+
inbound_op._outbound_nodes.append(self)
|
| 72 |
+
|
| 73 |
+
# Set metadata on outputs.
|
| 74 |
+
if not zero_history:
|
| 75 |
+
node_index = len(self.operation._inbound_nodes) - 1
|
| 76 |
+
for i, tensor in enumerate(self.outputs):
|
| 77 |
+
tensor._keras_history = KerasHistory(
|
| 78 |
+
operation=operation, node_index=node_index, tensor_index=i
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Whether this is a root node.
|
| 82 |
+
self.is_input = not self.arguments.keras_tensors
|
| 83 |
+
|
| 84 |
+
def __repr__(self):
|
| 85 |
+
return f"<Node operation={self.operation.name}, id={id(self)}>"
|
| 86 |
+
|
| 87 |
+
@property
|
| 88 |
+
def input_tensors(self):
|
| 89 |
+
return self.arguments.keras_tensors
|
| 90 |
+
|
| 91 |
+
@property
|
| 92 |
+
def output_tensors(self):
|
| 93 |
+
return self.outputs
|
| 94 |
+
|
| 95 |
+
@property
|
| 96 |
+
def parent_nodes(self):
|
| 97 |
+
"""The parent `Node`s.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
all the `Node`s whose output this node immediately depends on.
|
| 101 |
+
"""
|
| 102 |
+
node_deps = []
|
| 103 |
+
for kt in self.arguments.keras_tensors:
|
| 104 |
+
op = kt._keras_history.operation
|
| 105 |
+
node_index = kt._keras_history.node_index
|
| 106 |
+
if op is not None: # `None` for `Input` tensors.
|
| 107 |
+
node_deps.append(op._inbound_nodes[node_index])
|
| 108 |
+
return node_deps
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class KerasHistory(
|
| 112 |
+
collections.namedtuple(
|
| 113 |
+
"KerasHistory", ["operation", "node_index", "tensor_index"]
|
| 114 |
+
)
|
| 115 |
+
):
|
| 116 |
+
"""Tracks the Operation call that created a Tensor.
|
| 117 |
+
|
| 118 |
+
During construction of Keras Functions, this metadata is added to
|
| 119 |
+
each Tensor produced as the output of an Operation.
|
| 120 |
+
This allows Keras to track how each Tensor was produced, and
|
| 121 |
+
this information is later retraced by the `Function` class to
|
| 122 |
+
reconstruct the Operations graph.
|
| 123 |
+
|
| 124 |
+
Attributes:
|
| 125 |
+
operation: The Operation instance that produced the Tensor.
|
| 126 |
+
node_index: The specific call to the Operation that produced this Tensor.
|
| 127 |
+
Operations can be called multiple times in order to share weights. A new
|
| 128 |
+
node is created every time an Operation is called. The corresponding
|
| 129 |
+
node that represents the call event that produced the Tensor can be
|
| 130 |
+
found at `op._inbound_nodes[node_index]`.
|
| 131 |
+
tensor_index: The output index for this Tensor.
|
| 132 |
+
Always zero if the Operation that produced this Tensor
|
| 133 |
+
only has one output. Nested structures of
|
| 134 |
+
Tensors are deterministically assigned an index via `nest.flatten`.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
# Added to maintain memory and performance characteristics of `namedtuple`
|
| 138 |
+
# while subclassing.
|
| 139 |
+
__slots__ = ()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def is_keras_tensor(obj):
|
| 143 |
+
return hasattr(obj, "_keras_history")
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/numpy.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/operation.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import textwrap
|
| 3 |
+
|
| 4 |
+
from keras.src import backend
|
| 5 |
+
from keras.src import dtype_policies
|
| 6 |
+
from keras.src import tree
|
| 7 |
+
from keras.src.api_export import keras_export
|
| 8 |
+
from keras.src.backend.common.keras_tensor import any_symbolic_tensors
|
| 9 |
+
from keras.src.ops.node import Node
|
| 10 |
+
from keras.src.utils import python_utils
|
| 11 |
+
from keras.src.utils import traceback_utils
|
| 12 |
+
from keras.src.utils.naming import auto_name
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@keras_export("keras.Operation")
|
| 16 |
+
class Operation:
|
| 17 |
+
def __init__(self, dtype=None, name=None):
|
| 18 |
+
if name is None:
|
| 19 |
+
name = auto_name(self.__class__.__name__)
|
| 20 |
+
if not isinstance(name, str) or "/" in name:
|
| 21 |
+
raise ValueError(
|
| 22 |
+
"Argument `name` must be a string and "
|
| 23 |
+
"cannot contain character `/`. "
|
| 24 |
+
f"Received: name={name} (of type {type(name)})"
|
| 25 |
+
)
|
| 26 |
+
self._dtype_policy = dtype_policies.get(dtype)
|
| 27 |
+
self.name = name
|
| 28 |
+
self._inbound_nodes = []
|
| 29 |
+
self._outbound_nodes = []
|
| 30 |
+
|
| 31 |
+
@traceback_utils.filter_traceback
|
| 32 |
+
def __call__(self, *args, **kwargs):
|
| 33 |
+
if traceback_utils.is_traceback_filtering_enabled():
|
| 34 |
+
# Wrap self.call to provide helpful info in case of exception
|
| 35 |
+
if any_symbolic_tensors(args, kwargs):
|
| 36 |
+
call_fn = self.symbolic_call
|
| 37 |
+
else:
|
| 38 |
+
if getattr(self, "quantization_mode", None) is not None:
|
| 39 |
+
call_fn = self.quantized_call
|
| 40 |
+
else:
|
| 41 |
+
call_fn = self.call
|
| 42 |
+
call_fn = traceback_utils.inject_argument_info_in_traceback(
|
| 43 |
+
call_fn,
|
| 44 |
+
object_name=(f"{self.__class__.__name__}.call()"),
|
| 45 |
+
)
|
| 46 |
+
return call_fn(*args, **kwargs)
|
| 47 |
+
|
| 48 |
+
# Plain flow.
|
| 49 |
+
if any_symbolic_tensors(args, kwargs):
|
| 50 |
+
return self.symbolic_call(*args, **kwargs)
|
| 51 |
+
if getattr(self, "quantization_mode", None) is not None:
|
| 52 |
+
return self.quantized_call(*args, **kwargs)
|
| 53 |
+
else:
|
| 54 |
+
return self.call(*args, **kwargs)
|
| 55 |
+
|
| 56 |
+
def symbolic_call(self, *args, **kwargs):
|
| 57 |
+
# Perform shape/dtype inference.
|
| 58 |
+
outputs = self.compute_output_spec(*args, **kwargs)
|
| 59 |
+
# Record a new node in the operations graph.
|
| 60 |
+
# The Node wires itself to inbound and outbound ops. The
|
| 61 |
+
# Node constructor updates this op's self._inbound_nodes,
|
| 62 |
+
# sets _keras_history on the outputs, and adds itself to the
|
| 63 |
+
# `_outbound_nodes` of the ops that produced the inputs to this
|
| 64 |
+
# call.
|
| 65 |
+
Node(
|
| 66 |
+
operation=self, call_args=args, call_kwargs=kwargs, outputs=outputs
|
| 67 |
+
)
|
| 68 |
+
return outputs
|
| 69 |
+
|
| 70 |
+
def call(self, *args, **kwargs):
|
| 71 |
+
raise NotImplementedError
|
| 72 |
+
|
| 73 |
+
def quantized_call(self, *args, **kwargs):
|
| 74 |
+
raise NotImplementedError
|
| 75 |
+
|
| 76 |
+
def compute_output_spec(self, *args, **kwargs):
|
| 77 |
+
try:
|
| 78 |
+
return backend.compute_output_spec(self.call, *args, **kwargs)
|
| 79 |
+
except Exception as e:
|
| 80 |
+
new_e = e.__class__(
|
| 81 |
+
"Could not automatically infer the output shape / dtype of "
|
| 82 |
+
f"'{self.name}' (of type {self.__class__.__name__}). "
|
| 83 |
+
f"Either the `{self.__class__.__name__}.call()` method "
|
| 84 |
+
f"is incorrect, or you need to implement the "
|
| 85 |
+
f"`{self.__class__.__name__}.compute_output_spec() / "
|
| 86 |
+
"compute_output_shape()` method. "
|
| 87 |
+
f"Error encountered:\n\n{e}"
|
| 88 |
+
)
|
| 89 |
+
raise new_e.with_traceback(e.__traceback__) from None
|
| 90 |
+
|
| 91 |
+
def __new__(cls, *args, **kwargs):
|
| 92 |
+
"""We override __new__ to saving serializable constructor arguments.
|
| 93 |
+
|
| 94 |
+
These arguments are used to auto-generate an object serialization
|
| 95 |
+
config, which enables user-created subclasses to be serializable
|
| 96 |
+
out of the box in most cases without forcing the user
|
| 97 |
+
to manually implement `get_config()`.
|
| 98 |
+
"""
|
| 99 |
+
instance = super(Operation, cls).__new__(cls)
|
| 100 |
+
|
| 101 |
+
# Generate a config to be returned by default by `get_config()`.
|
| 102 |
+
arg_names = inspect.getfullargspec(cls.__init__).args
|
| 103 |
+
kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
|
| 104 |
+
|
| 105 |
+
# Explicitly serialize `dtype` to support auto_config
|
| 106 |
+
dtype = kwargs.get("dtype", None)
|
| 107 |
+
if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy):
|
| 108 |
+
# For backward compatibility, we use a str (`name`) for
|
| 109 |
+
# `DTypePolicy`
|
| 110 |
+
if dtype.quantization_mode is None:
|
| 111 |
+
kwargs["dtype"] = dtype.name
|
| 112 |
+
# Otherwise, use `dtype_policies.serialize`
|
| 113 |
+
else:
|
| 114 |
+
kwargs["dtype"] = dtype_policies.serialize(dtype)
|
| 115 |
+
|
| 116 |
+
# For safety, we only rely on auto-configs for a small set of
|
| 117 |
+
# serializable types.
|
| 118 |
+
supported_types = (str, int, float, bool, type(None))
|
| 119 |
+
try:
|
| 120 |
+
flat_arg_values = tree.flatten(kwargs)
|
| 121 |
+
auto_config = True
|
| 122 |
+
for value in flat_arg_values:
|
| 123 |
+
if not isinstance(value, supported_types):
|
| 124 |
+
auto_config = False
|
| 125 |
+
break
|
| 126 |
+
except TypeError:
|
| 127 |
+
auto_config = False
|
| 128 |
+
try:
|
| 129 |
+
instance._lock = False
|
| 130 |
+
if auto_config:
|
| 131 |
+
from keras.src.saving import serialization_lib
|
| 132 |
+
|
| 133 |
+
instance._auto_config = serialization_lib.SerializableDict(
|
| 134 |
+
**kwargs
|
| 135 |
+
)
|
| 136 |
+
else:
|
| 137 |
+
instance._auto_config = None
|
| 138 |
+
instance._lock = True
|
| 139 |
+
except RecursionError:
|
| 140 |
+
# Setting an instance attribute in __new__ has the potential
|
| 141 |
+
# to trigger an infinite recursion if a subclass overrides
|
| 142 |
+
# setattr in an unsafe way.
|
| 143 |
+
pass
|
| 144 |
+
return instance
|
| 145 |
+
|
| 146 |
+
@python_utils.default
|
| 147 |
+
def get_config(self):
|
| 148 |
+
"""Returns the config of the object.
|
| 149 |
+
|
| 150 |
+
An object config is a Python dictionary (serializable)
|
| 151 |
+
containing the information needed to re-instantiate it.
|
| 152 |
+
"""
|
| 153 |
+
config = {
|
| 154 |
+
"name": self.name,
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
if not python_utils.is_default(self.get_config):
|
| 158 |
+
# In this case the subclass implements get_config()
|
| 159 |
+
return config
|
| 160 |
+
|
| 161 |
+
# In this case the subclass doesn't implement get_config():
|
| 162 |
+
# Let's see if we can autogenerate it.
|
| 163 |
+
if getattr(self, "_auto_config", None) is not None:
|
| 164 |
+
xtra_args = set(config.keys())
|
| 165 |
+
config.update(self._auto_config.config)
|
| 166 |
+
# Remove args non explicitly supported
|
| 167 |
+
argspec = inspect.getfullargspec(self.__init__)
|
| 168 |
+
if argspec.varkw != "kwargs":
|
| 169 |
+
for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
|
| 170 |
+
config.pop(key, None)
|
| 171 |
+
return config
|
| 172 |
+
else:
|
| 173 |
+
raise NotImplementedError(
|
| 174 |
+
textwrap.dedent(
|
| 175 |
+
f"""
|
| 176 |
+
Object {self.__class__.__name__} was created by passing
|
| 177 |
+
non-serializable argument values in `__init__()`,
|
| 178 |
+
and therefore the object must override `get_config()` in
|
| 179 |
+
order to be serializable. Please implement `get_config()`.
|
| 180 |
+
|
| 181 |
+
Example:
|
| 182 |
+
|
| 183 |
+
class CustomLayer(keras.layers.Layer):
|
| 184 |
+
def __init__(self, arg1, arg2, **kwargs):
|
| 185 |
+
super().__init__(**kwargs)
|
| 186 |
+
self.arg1 = arg1
|
| 187 |
+
self.arg2 = arg2
|
| 188 |
+
|
| 189 |
+
def get_config(self):
|
| 190 |
+
config = super().get_config()
|
| 191 |
+
config.update({{
|
| 192 |
+
"arg1": self.arg1,
|
| 193 |
+
"arg2": self.arg2,
|
| 194 |
+
}})
|
| 195 |
+
return config"""
|
| 196 |
+
)
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
@classmethod
|
| 200 |
+
def from_config(cls, config):
|
| 201 |
+
"""Creates an operation from its config.
|
| 202 |
+
|
| 203 |
+
This method is the reverse of `get_config`, capable of instantiating the
|
| 204 |
+
same operation from the config dictionary.
|
| 205 |
+
|
| 206 |
+
Note: If you override this method, you might receive a serialized dtype
|
| 207 |
+
config, which is a `dict`. You can deserialize it as follows:
|
| 208 |
+
|
| 209 |
+
```python
|
| 210 |
+
if "dtype" in config and isinstance(config["dtype"], dict):
|
| 211 |
+
policy = dtype_policies.deserialize(config["dtype"])
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
config: A Python dictionary, typically the output of `get_config`.
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
An operation instance.
|
| 219 |
+
"""
|
| 220 |
+
# Explicitly deserialize dtype config if needed. This enables users to
|
| 221 |
+
# directly interact with the instance of `DTypePolicy`.
|
| 222 |
+
if "dtype" in config and isinstance(config["dtype"], dict):
|
| 223 |
+
config = config.copy()
|
| 224 |
+
policy = dtype_policies.deserialize(config["dtype"])
|
| 225 |
+
if (
|
| 226 |
+
not isinstance(policy, dtype_policies.DTypePolicyMap)
|
| 227 |
+
and policy.quantization_mode is None
|
| 228 |
+
):
|
| 229 |
+
# For backward compatibility, we use a str (`name`) for
|
| 230 |
+
# `DTypePolicy`
|
| 231 |
+
policy = policy.name
|
| 232 |
+
config["dtype"] = policy
|
| 233 |
+
try:
|
| 234 |
+
return cls(**config)
|
| 235 |
+
except Exception as e:
|
| 236 |
+
raise TypeError(
|
| 237 |
+
f"Error when deserializing class '{cls.__name__}' using "
|
| 238 |
+
f"config={config}.\n\nException encountered: {e}"
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
def __repr__(self):
|
| 242 |
+
return f"<Operation name={self.name}>"
|
| 243 |
+
|
| 244 |
+
@property
|
| 245 |
+
def input(self):
|
| 246 |
+
"""Retrieves the input tensor(s) of a symbolic operation.
|
| 247 |
+
|
| 248 |
+
Only returns the tensor(s) corresponding to the *first time*
|
| 249 |
+
the operation was called.
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
Input tensor or list of input tensors.
|
| 253 |
+
"""
|
| 254 |
+
return self._get_node_attribute_at_index(0, "input_tensors", "input")
|
| 255 |
+
|
| 256 |
+
@property
|
| 257 |
+
def output(self):
|
| 258 |
+
"""Retrieves the output tensor(s) of a layer.
|
| 259 |
+
|
| 260 |
+
Only returns the tensor(s) corresponding to the *first time*
|
| 261 |
+
the operation was called.
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
Output tensor or list of output tensors.
|
| 265 |
+
"""
|
| 266 |
+
return self._get_node_attribute_at_index(0, "output_tensors", "output")
|
| 267 |
+
|
| 268 |
+
def _get_node_attribute_at_index(self, node_index, attr, attr_name):
|
| 269 |
+
"""Private utility to retrieves an attribute (e.g. inputs) from a node.
|
| 270 |
+
|
| 271 |
+
This is used to implement the properties:
|
| 272 |
+
- output
|
| 273 |
+
- input
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
node_index: Integer index of the node from which
|
| 277 |
+
to retrieve the attribute.
|
| 278 |
+
attr: Exact node attribute name.
|
| 279 |
+
attr_name: Human-readable attribute name, for error messages.
|
| 280 |
+
|
| 281 |
+
Returns:
|
| 282 |
+
The operation's attribute `attr` at the node of index `node_index`.
|
| 283 |
+
"""
|
| 284 |
+
if not self._inbound_nodes:
|
| 285 |
+
raise AttributeError(
|
| 286 |
+
f"The layer {self.name} has never been called "
|
| 287 |
+
f"and thus has no defined {attr_name}."
|
| 288 |
+
)
|
| 289 |
+
if not len(self._inbound_nodes) > node_index:
|
| 290 |
+
raise ValueError(
|
| 291 |
+
f"Asked to get {attr_name} at node "
|
| 292 |
+
f"{node_index}, but the operation has only "
|
| 293 |
+
f"{len(self._inbound_nodes)} inbound nodes."
|
| 294 |
+
)
|
| 295 |
+
values = getattr(self._inbound_nodes[node_index], attr)
|
| 296 |
+
if isinstance(values, list) and len(values) == 1:
|
| 297 |
+
return values[0]
|
| 298 |
+
else:
|
| 299 |
+
return values
|
| 300 |
+
|
| 301 |
+
# Hooks for backend layer classes
|
| 302 |
+
def _post_build(self):
|
| 303 |
+
"""Can be overridden for per backend post build actions."""
|
| 304 |
+
pass
|
| 305 |
+
|
| 306 |
+
    def _setattr_hook(self, name, value):
        """Per-backend hook invoked when an attribute is set; may be overridden.

        NOTE(review): the original docstring said "post build actions", which
        appears copy-pasted from `_post_build`; this hook receives an
        attribute `(name, value)` pair and returns the pair to use instead.

        Args:
            name: Name of the attribute being assigned.
            value: Value of the attribute being assigned.

        Returns:
            Tuple `(name, value)` — unchanged in the default implementation.
        """
        return name, value
|
| 309 |
+
|
| 310 |
+
    def _post_track_variable(self, variable):
        """Can be overridden for per backend post track actions.

        Args:
            variable: The variable that was just tracked. Ignored by the
                default no-op implementation.
        """
        pass
|
| 313 |
+
|
| 314 |
+
    def _post_untrack_variable(self, variable):
        """Can be overridden for per backend post untrack actions.

        Args:
            variable: The variable that was just untracked. Ignored by the
                default no-op implementation.
        """
        pass
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/operation_utils.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from keras.src import tree
|
| 6 |
+
from keras.src.api_export import keras_export
|
| 7 |
+
from keras.src.backend.common.backend_utils import canonicalize_axis
|
| 8 |
+
from keras.src.backend.common.backend_utils import to_tuple_or_list
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def broadcast_shapes(shape1, shape2):
    """Broadcast two input shapes into one unified shape.

    `None` entries denote dimensions unknown at trace time.

    Args:
        shape1: A tuple or list of integers (entries may be `None`).
        shape2: A tuple or list of integers (entries may be `None`).

    Returns:
        output_shape (list of integers or `None`): The broadcasted shape.

    Example:
    >>> broadcast_shapes((5, 3), (1, 3))
    [5, 3]
    """
    origin_shape1 = list(shape1)
    origin_shape2 = list(shape2)

    # Left-pad the shorter shape with 1s so both ranks match.
    rank = max(len(origin_shape1), len(origin_shape2))
    padded1 = [1] * (rank - len(origin_shape1)) + origin_shape1
    padded2 = [1] * (rank - len(origin_shape2)) + origin_shape2

    output_shape = []
    for dim1, dim2 in zip(padded1, padded2):
        if dim1 == 1:
            output_shape.append(dim2)
        elif dim1 is None:
            output_shape.append(None if dim2 == 1 else dim2)
        elif dim2 == 1 or dim2 is None or dim2 == dim1:
            output_shape.append(dim1)
        else:
            raise ValueError(
                "Cannot broadcast shape, the failure dim has value "
                f"{dim1}, which cannot be broadcasted to {dim2}. "
                f"Input shapes are: {origin_shape1} and {origin_shape2}."
            )
    return output_shape
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def compute_expand_dims_output_shape(input_shape, axis):
    """Compute the output shape for the `expand_dims` operation.

    Args:
        input_shape: Input shape.
        axis: int or sequence of ints for the axis to expand.

    Returns:
        Tuple of ints: The output shape after the `expand_dims` operation.
    """
    input_shape = list(input_shape)
    # `axis=None` appends a trailing singleton dimension.
    if axis is None:
        axis = len(input_shape)
    axis = to_tuple_or_list(axis)
    output_rank = len(axis) + len(input_shape)
    axis = [canonicalize_axis(a, output_rank) for a in axis]
    remaining_dims = iter(input_shape)
    # Insert a 1 at each expanded position; thread the original dims
    # through the remaining positions in order.
    return tuple(
        1 if position in axis else next(remaining_dims)
        for position in range(output_rank)
    )
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def compute_pooling_output_shape(
    input_shape,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    """Computes the output shape of pooling operations.

    Args:
        input_shape: Input shape. Must be a tuple of integers.
        pool_size: Size of the pooling operation. Must be a tuple of integers.
        strides: Stride of the pooling operation. Must be a tuple of integers.
            Defaults to `pool_size`.
        padding: Padding method. Available methods are `"valid"` or `"same"`.
            Defaults to `"valid"`.
        data_format: String, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, channels, height, weight)`. Defaults to `"channels_last"`.

    Returns:
        Tuple of ints: The output shape of the pooling operation.

    Examples:

    # Basic usage with square pooling on a single image
    >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2))
    (1, 2, 2, 1)

    # Strided pooling on a single image with strides different from pool_size
    >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2), strides=(1, 1))
    (1, 3, 3, 1)

    # Pooling on a batch of images
    >>> compute_pooling_output_shape((32, 4, 4, 3), (2, 2))
    (32, 2, 2, 3)
    """
    # Fix: `strides` used to be a required positional argument even though
    # the docstring (and examples above) document it as defaulting to
    # `pool_size`. It now genuinely defaults to `pool_size`.
    strides = pool_size if strides is None else strides
    input_shape_origin = list(input_shape)
    input_shape = np.array(input_shape)
    if data_format == "channels_last":
        spatial_shape = input_shape[1:-1]
    else:
        spatial_shape = input_shape[2:]
    none_dims = []
    for i in range(len(spatial_shape)):
        if spatial_shape[i] is None:
            # Set `None` shape to a manual value so that we can run numpy
            # computation on `spatial_shape`.
            spatial_shape[i] = -1
            none_dims.append(i)
    pool_size = np.array(pool_size)
    if padding == "valid":
        output_spatial_shape = (
            np.floor((spatial_shape - pool_size) / strides) + 1
        )
        for i in range(len(output_spatial_shape)):
            if i not in none_dims and output_spatial_shape[i] < 0:
                raise ValueError(
                    "Computed output size would be negative. Received: "
                    f"`inputs.shape={input_shape}` and `pool_size={pool_size}`."
                )
    elif padding == "same":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        raise ValueError(
            "Argument `padding` must be either 'valid' or 'same'. Received: "
            f"padding={padding}"
        )
    output_spatial_shape = [int(i) for i in output_spatial_shape]
    # Restore `None` for the dimensions that were unknown on input.
    for i in none_dims:
        output_spatial_shape[i] = None
    output_spatial_shape = tuple(output_spatial_shape)
    if data_format == "channels_last":
        output_shape = (
            (input_shape_origin[0],)
            + output_spatial_shape
            + (input_shape_origin[-1],)
        )
    else:
        output_shape = (
            input_shape_origin[0],
            input_shape_origin[1],
        ) + output_spatial_shape
    return output_shape
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def compute_conv_output_shape(
    input_shape,
    filters,
    kernel_size,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    """Compute the output shape of conv ops."""
    if data_format == "channels_last":
        spatial_shape = input_shape[1:-1]
        kernel_shape = kernel_size + (input_shape[-1], filters)
    else:
        spatial_shape = input_shape[2:]
        kernel_shape = kernel_size + (input_shape[1], filters)
    if len(kernel_shape) != len(input_shape):
        raise ValueError(
            "Kernel shape must have the same length as input, but received "
            f"kernel of shape {kernel_shape} and "
            f"input of shape {input_shape}."
        )
    spatial_rank = len(spatial_shape)
    # Normalize scalar strides/dilation to per-dimension tuples.
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * spatial_rank
    if isinstance(strides, int):
        strides = (strides,) * spatial_rank
    if len(dilation_rate) != spatial_rank:
        raise ValueError(
            "Dilation must be None, scalar or tuple/list of length of "
            "inputs' spatial shape, but received "
            f"`dilation_rate={dilation_rate}` and "
            f"input of shape {input_shape}."
        )

    # Substitute -1 for unknown (None) spatial dims so numpy arithmetic
    # works; remember their positions to restore None afterwards.
    none_dims = [i for i, dim in enumerate(spatial_shape) if dim is None]
    spatial_shape = np.array(
        [-1 if dim is None else dim for dim in spatial_shape]
    )

    kernel_spatial_shape = np.array(kernel_shape[:-2])
    dilation_rate = np.array(dilation_rate)
    if padding == "valid":
        output_spatial_shape = (
            np.floor(
                (spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1)
                / strides
            )
            + 1
        )
        for i, dim in enumerate(output_spatial_shape):
            if i not in none_dims and dim < 0:
                raise ValueError(
                    "Computed output size would be negative. Received "
                    f"`inputs shape={input_shape}`, "
                    f"`kernel shape={kernel_shape}`, "
                    f"`dilation_rate={dilation_rate}`."
                )
    elif padding == "same" or padding == "causal":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        raise ValueError(
            "`padding` must be either `'valid'` or `'same'`. Received "
            f"{padding}."
        )
    # Convert back to Python ints, restoring None where the input was None.
    output_spatial_shape = tuple(
        None if i in none_dims else int(dim)
        for i, dim in enumerate(output_spatial_shape)
    )
    if data_format == "channels_last":
        return (input_shape[0],) + output_spatial_shape + (kernel_shape[-1],)
    return (input_shape[0], kernel_shape[-1]) + output_spatial_shape
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def compute_matmul_output_shape(shape1, shape2):
    """Compute the output shape of a `matmul` operation.

    Follows `np.matmul` semantics: 1-D operands are temporarily promoted
    to 2-D and the inserted dimension is squeezed out of the result.

    Args:
        shape1: Shape of the left operand.
        shape2: Shape of the right operand.

    Returns:
        Tuple of ints: The output shape for the `matmul` operation.

    Raises:
        ValueError: If the inner dimensions are known and incompatible.
    """
    # Fix: remember 1-D-ness *before* promotion. The previous code
    # re-checked `len(shape1) == 1` after reassigning `shape1` to a 2-D
    # tuple, so the squeeze branches below were dead code (e.g. a
    # vector-vector product returned (1, 1) instead of ()).
    shape1_is_1d = len(shape1) == 1
    shape2_is_1d = len(shape2) == 1
    if shape1_is_1d:
        shape1 = (1, shape1[0])
    if shape2_is_1d:
        shape2 = (shape2[0], 1)
    if (
        shape1[-1] is not None
        and shape2[-2] is not None
        and shape1[-1] != shape2[-2]
    ):
        raise ValueError(
            "Inner dimensions (`x1.shape[-1]` and `x2.shape[-2]`) must be "
            f"equal, but received `x1.shape={shape1}` and "
            f"`x2.shape={shape2}`."
        )

    leading_shape = broadcast_shapes(shape1[:-2], shape2[:-2])
    last_2_dims_shape = [shape1[-2], shape2[-1]]
    output_shape = leading_shape + last_2_dims_shape
    if shape1_is_1d:
        del output_shape[-2]
    if shape2_is_1d:
        del output_shape[-1]
    return tuple(output_shape)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def compute_reshape_output_shape(input_shape, newshape, newshape_arg_name):
    """Converts `-1` in `newshape` to either an actual dimension or `None`.

    This utility does not special case the 0th dimension (batch size).
    """
    unknown_dim_count = newshape.count(-1)
    if unknown_dim_count > 1:
        raise ValueError(
            "There must be at most one unknown dimension (-1) in "
            f"{newshape_arg_name}. Received: {newshape_arg_name}={newshape}."
        )

    # If there is a None in input_shape, we can't infer what the -1 is
    if None in input_shape:
        return tuple(None if dim == -1 else dim for dim in newshape)

    input_size = math.prod(input_shape)
    if unknown_dim_count == 0:
        # Fully specified target shape: only validate the element count.
        if input_size != math.prod(newshape):
            raise ValueError(
                "The total size of the tensor must be unchanged. Received: "
                f"input_shape={input_shape}, {newshape_arg_name}={newshape}"
            )
        return newshape

    # Exactly one -1: infer its value from the remaining dimensions.
    unknown_dim_index = newshape.index(-1)
    known_output_size = math.prod(dim for dim in newshape if dim != -1)

    if known_output_size == 0 or input_size % known_output_size != 0:
        raise ValueError(
            "The total size of the tensor must be unchanged, however, the "
            "input size cannot by divided by the specified dimensions in "
            f"{newshape_arg_name}. Received: input_shape={input_shape}, "
            f"{newshape_arg_name}={newshape}"
        )

    output_shape = list(newshape)
    output_shape[unknown_dim_index] = input_size // known_output_size
    return tuple(output_shape)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def compute_transpose_output_shape(input_shape, axes):
    """Compute the output shape for the `transpose` operation.

    Args:
        input_shape: Input shape.
        axes: Permutation of the dimensions for the `transpose` operation,
            or `None` to reverse all dimensions.

    Returns:
        Tuple of ints: The output shape after the `transpose` operation.
    """
    input_shape = list(input_shape)
    if axes is None:
        # Default transpose: reverse the dimension order.
        return tuple(reversed(input_shape))

    if len(axes) != len(input_shape):
        raise ValueError(
            "axis must be a list of the same length as the input shape, "
            f"expected {len(input_shape)}, but received {len(axes)}."
        )
    return tuple(input_shape[axis] for axis in axes)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def compute_take_along_axis_output_shape(input_shape, indices_shape, axis):
    """Compute the output shape of `take_along_axis`.

    Args:
        input_shape: Shape of the source tensor.
        indices_shape: Shape of the indices tensor.
        axis: Axis along which values are taken, or `None` to operate on
            the flattened input.

    Returns:
        The broadcasted output shape (a list; entries may be `None`).

    Raises:
        ValueError: If the two shapes have a different number of dimensions
            (after flattening when `axis is None`).
    """
    input_shape = list(input_shape)
    indices_shape = list(indices_shape)
    if axis is None:
        # `axis=None` means the input is flattened to 1-D first.
        input_shape = (
            [None] if None in input_shape else [int(np.prod(input_shape))]
        )
        # Fix: the single remaining dimension is the one indexed into.
        # Previously `axis` stayed `None` and `input_shape[axis]` below
        # raised `TypeError: list indices must be integers`.
        axis = 0

    if len(input_shape) != len(indices_shape):
        raise ValueError(
            "`x` and `indices` must have the same number of dimensions, "
            f"but receive shape {input_shape} and {indices_shape}."
        )

    input_shape[axis] = indices_shape[axis]
    output_shape = broadcast_shapes(input_shape, indices_shape)
    return output_shape
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def reduce_shape(shape, axis=None, keepdims=False):
    """Shape left after reducing `shape` over `axis`.

    Args:
        shape: The input shape (tuple or list of ints).
        axis: Iterable of axes to reduce over, or `None` for all axes.
        keepdims: If True, reduced axes are kept with size 1.

    Returns:
        Tuple of ints: The reduced shape.
    """
    shape = list(shape)
    if axis is None:
        # Full reduction over every dimension.
        return tuple([1] * len(shape)) if keepdims else ()

    if keepdims:
        for reduced_axis in axis:
            shape[reduced_axis] = 1
        return tuple(shape)

    # Delete from the highest index down so earlier indices stay valid.
    for reduced_axis in sorted(axis, reverse=True):
        del shape[reduced_axis]
    return tuple(shape)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
@keras_export("keras.utils.get_source_inputs")
def get_source_inputs(tensor):
    """Returns the list of input tensors necessary to compute `tensor`.

    Output will always be a list of tensors
    (potentially with 1 element).

    NOTE(review): if `tensor` carries no `_keras_history` (i.e. it was not
    produced by a Keras op), the tensor itself is returned unwrapped — not
    a one-element list as the summary above suggests.

    Args:
        tensor: The tensor to start from.

    Returns:
        List of input tensors.
    """
    # Tensors without Keras history were not produced by a Keras operation;
    # there is no graph to trace back through.
    if not hasattr(tensor, "_keras_history"):
        return tensor

    operation, node_index, _ = tensor._keras_history
    if not operation or not operation._inbound_nodes:
        return [tensor]
    else:
        node = operation._inbound_nodes[node_index]
        if node.is_input:
            # Reached input node, stop recursion.
            return tree.flatten(node.output_tensors)
        else:
            # Recurse through every input of the producing node, collecting
            # sources while de-duplicating by identity (`is`), not equality.
            source_tensors = []
            for tensor in node.input_tensors:
                previous_sources = get_source_inputs(tensor)
                # Avoid input redundancy.
                for x in previous_sources:
                    if all(x is not t for t in source_tensors):
                        source_tensors.append(x)
            return source_tensors
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/ops/symbolic_arguments.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import tree
|
| 2 |
+
from keras.src.backend import KerasTensor
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class SymbolicArguments:
    """Container for the symbolic call arguments of an operation.

    Records the positional and keyword arguments of a symbolic call and
    provides helpers to map the contained `KerasTensor`s to values.
    """

    def __init__(self, *args, **kwargs):
        # Shallow-copy the nested structures via an identity map.
        self.args = tree.map_structure(lambda x: x, args)
        self.kwargs = tree.map_structure(lambda x: x, kwargs)
        self._flat_arguments = tree.flatten((self.args, self.kwargs))

        # Used to avoid expensive `tree` operations in the most common case:
        # exactly one positional KerasTensor and no keyword arguments.
        single = None
        if not self.kwargs and len(self.args) == 1:
            if isinstance(self.args[0], KerasTensor):
                single = self.args[0]
        self._single_positional_tensor = single

        self.keras_tensors = [
            arg
            for arg in self._flat_arguments
            if isinstance(arg, KerasTensor)
        ]

    def convert(self, conversion_fn):
        """Apply `conversion_fn` to every leaf of `args` and `kwargs`."""
        converted_args = tree.map_structure(conversion_fn, self.args)
        converted_kwargs = tree.map_structure(conversion_fn, self.kwargs)
        return converted_args, converted_kwargs

    def fill_in(self, tensor_dict):
        """Maps KerasTensors to computed values using `tensor_dict`.

        `tensor_dict` maps `KerasTensor` instances (keyed by `id`) to their
        current values.
        """
        if self._single_positional_tensor is not None:
            # Performance optimization for most common case.
            # Approx. 70x faster.
            return (tensor_dict[id(self._single_positional_tensor)],), {}

        def switch_fn(x):
            if isinstance(x, KerasTensor):
                return tensor_dict.get(id(x), None)
            return x

        return self.convert(switch_fn)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__init__.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src.api_export import keras_export
|
| 2 |
+
from keras.src.optimizers.adadelta import Adadelta
|
| 3 |
+
from keras.src.optimizers.adafactor import Adafactor
|
| 4 |
+
from keras.src.optimizers.adagrad import Adagrad
|
| 5 |
+
from keras.src.optimizers.adam import Adam
|
| 6 |
+
from keras.src.optimizers.adamax import Adamax
|
| 7 |
+
from keras.src.optimizers.adamw import AdamW
|
| 8 |
+
from keras.src.optimizers.ftrl import Ftrl
|
| 9 |
+
from keras.src.optimizers.lion import Lion
|
| 10 |
+
from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer
|
| 11 |
+
from keras.src.optimizers.nadam import Nadam
|
| 12 |
+
from keras.src.optimizers.optimizer import Optimizer
|
| 13 |
+
from keras.src.optimizers.rmsprop import RMSprop
|
| 14 |
+
from keras.src.optimizers.sgd import SGD
|
| 15 |
+
from keras.src.saving import serialization_lib
|
| 16 |
+
|
| 17 |
+
# Set of every built-in optimizer class exposed by this module.
ALL_OBJECTS = {
    Optimizer,
    Adam,
    SGD,
    RMSprop,
    Adadelta,
    AdamW,
    Adagrad,
    Adamax,
    Adafactor,
    Nadam,
    Ftrl,
    Lion,
    LossScaleOptimizer,
}
# Lowercased class-name -> class lookup table used for case-insensitive
# deserialization of built-in optimizers.
ALL_OBJECTS_DICT = {cls.__name__.lower(): cls for cls in ALL_OBJECTS}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@keras_export("keras.optimizers.serialize")
def serialize(optimizer):
    """Returns the optimizer configuration as a Python dict.

    Args:
        optimizer: An `Optimizer` instance to serialize.

    Returns:
        Python dict which contains the configuration of the optimizer.
    """
    # Delegate entirely to the generic Keras object serializer.
    config = serialization_lib.serialize_keras_object(optimizer)
    return config
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@keras_export("keras.optimizers.deserialize")
def deserialize(config, custom_objects=None):
    """Returns a Keras optimizer object via its configuration.

    Args:
        config: Optimizer configuration dictionary.
        custom_objects: Optional dictionary mapping names (strings) to custom
            objects (classes and functions) to be considered during
            deserialization.

    Returns:
        A Keras Optimizer instance.
    """
    # Make deserialization case-insensitive for built-in optimizers.
    if config["class_name"].lower() in ALL_OBJECTS_DICT:
        # Fix: copy before rewriting `class_name` so the caller's dict is
        # not mutated as a side effect of deserialization.
        config = {**config, "class_name": config["class_name"].lower()}

    return serialization_lib.deserialize_keras_object(
        config,
        module_objects=ALL_OBJECTS_DICT,
        custom_objects=custom_objects,
    )
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@keras_export("keras.optimizers.get")
def get(identifier):
    """Retrieves a Keras Optimizer instance.

    Args:
        identifier: Optimizer identifier, one of:
            - String: name of an optimizer
            - Dictionary: configuration dictionary.
            - Keras Optimizer instance (it will be returned unchanged).

    Returns:
        A Keras Optimizer instance.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        candidate = deserialize(identifier)
    elif isinstance(identifier, str):
        # A bare string names an optimizer class with default arguments.
        candidate = deserialize({"class_name": identifier, "config": {}})
    else:
        candidate = identifier

    if isinstance(candidate, Optimizer):
        return candidate
    raise ValueError(f"Could not interpret optimizer identifier: {identifier}")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# We will add this temporarily so that tensorflow packages that depend on
# estimators will continue to import (there are a large number). Note that
# Keras 3 will not work with the estimators API.
@keras_export(
    [
        "keras.optimizers.legacy.Adagrad",
        "keras.optimizers.legacy.Adam",
        "keras.optimizers.legacy.Ftrl",
        "keras.optimizers.legacy.RMSprop",
        "keras.optimizers.legacy.SGD",
        "keras.optimizers.legacy.Optimizer",
    ]
)
class LegacyOptimizerWarning:
    """Placeholder that fails loudly when a legacy optimizer is requested."""

    def __init__(self, *args, **kwargs):
        # Instantiation always fails: legacy optimizers are not available
        # in Keras 3.
        raise ImportError(
            "`keras.optimizers.legacy` is not supported in Keras 3. When using "
            "`tf.keras`, to continue using a `tf.keras.optimizers.legacy` "
            "optimizer, you can install the `tf_keras` package (Keras 2) and "
            "set the environment variable `TF_USE_LEGACY_KERAS=True` to "
            "configure TensorFlow to use `tf_keras` when accessing `tf.keras`."
        )
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (3.92 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adadelta.cpython-310.pyc
ADDED
|
Binary file (4.21 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adafactor.cpython-310.pyc
ADDED
|
Binary file (5.65 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adagrad.cpython-310.pyc
ADDED
|
Binary file (3.54 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adam.cpython-310.pyc
ADDED
|
Binary file (4.9 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adamax.cpython-310.pyc
ADDED
|
Binary file (4.57 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/adamw.cpython-310.pyc
ADDED
|
Binary file (3.65 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/base_optimizer.cpython-310.pyc
ADDED
|
Binary file (35.5 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/ftrl.cpython-310.pyc
ADDED
|
Binary file (6.72 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/lamb.cpython-310.pyc
ADDED
|
Binary file (4.42 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/lion.cpython-310.pyc
ADDED
|
Binary file (4.46 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/loss_scale_optimizer.cpython-310.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/nadam.cpython-310.pyc
ADDED
|
Binary file (4.94 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/optimizer.cpython-310.pyc
ADDED
|
Binary file (1.06 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/rmsprop.cpython-310.pyc
ADDED
|
Binary file (4.81 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/__pycache__/sgd.cpython-310.pyc
ADDED
|
Binary file (3.84 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adadelta.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import ops
|
| 2 |
+
from keras.src.api_export import keras_export
|
| 3 |
+
from keras.src.optimizers import optimizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@keras_export(["keras.optimizers.Adadelta"])
|
| 7 |
+
class Adadelta(optimizer.Optimizer):
|
| 8 |
+
"""Optimizer that implements the Adadelta algorithm.
|
| 9 |
+
|
| 10 |
+
Adadelta optimization is a stochastic gradient descent method that is based
|
| 11 |
+
on adaptive learning rate per dimension to address two drawbacks:
|
| 12 |
+
|
| 13 |
+
- The continual decay of learning rates throughout training.
|
| 14 |
+
- The need for a manually selected global learning rate.
|
| 15 |
+
|
| 16 |
+
Adadelta is a more robust extension of Adagrad that adapts learning rates
|
| 17 |
+
based on a moving window of gradient updates, instead of accumulating all
|
| 18 |
+
past gradients. This way, Adadelta continues learning even when many updates
|
| 19 |
+
have been done. Compared to Adagrad, in the original version of Adadelta you
|
| 20 |
+
don't have to set an initial learning rate. In this version, the initial
|
| 21 |
+
learning rate can be set, as in most other Keras optimizers.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
learning_rate: A float, a
|
| 25 |
+
`keras.optimizers.schedules.LearningRateSchedule` instance, or
|
| 26 |
+
a callable that takes no arguments and returns the actual value to
|
| 27 |
+
use. The learning rate. Defaults to `0.001`. Note that `Adadelta`
|
| 28 |
+
tends to benefit from higher initial learning rate values compared
|
| 29 |
+
to other optimizers. To match the exact form in the original paper,
|
| 30 |
+
use 1.0.
|
| 31 |
+
rho: A floating point value. The decay rate. Defaults to `0.95`.
|
| 32 |
+
epsilon: Small floating point value for maintaining numerical stability.
|
| 33 |
+
{{base_optimizer_keyword_args}}
|
| 34 |
+
|
| 35 |
+
Reference:
|
| 36 |
+
|
| 37 |
+
- [Zeiler, 2012](http://arxiv.org/abs/1212.5701)
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
learning_rate=0.001,
|
| 43 |
+
rho=0.95,
|
| 44 |
+
epsilon=1e-7,
|
| 45 |
+
weight_decay=None,
|
| 46 |
+
clipnorm=None,
|
| 47 |
+
clipvalue=None,
|
| 48 |
+
global_clipnorm=None,
|
| 49 |
+
use_ema=False,
|
| 50 |
+
ema_momentum=0.99,
|
| 51 |
+
ema_overwrite_frequency=None,
|
| 52 |
+
loss_scale_factor=None,
|
| 53 |
+
gradient_accumulation_steps=None,
|
| 54 |
+
name="adadelta",
|
| 55 |
+
**kwargs,
|
| 56 |
+
):
|
| 57 |
+
super().__init__(
|
| 58 |
+
learning_rate=learning_rate,
|
| 59 |
+
weight_decay=weight_decay,
|
| 60 |
+
clipnorm=clipnorm,
|
| 61 |
+
clipvalue=clipvalue,
|
| 62 |
+
global_clipnorm=global_clipnorm,
|
| 63 |
+
use_ema=use_ema,
|
| 64 |
+
ema_momentum=ema_momentum,
|
| 65 |
+
ema_overwrite_frequency=ema_overwrite_frequency,
|
| 66 |
+
loss_scale_factor=loss_scale_factor,
|
| 67 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 68 |
+
name=name,
|
| 69 |
+
**kwargs,
|
| 70 |
+
)
|
| 71 |
+
self.rho = rho
|
| 72 |
+
self.epsilon = epsilon
|
| 73 |
+
|
| 74 |
+
def build(self, var_list):
|
| 75 |
+
if self.built:
|
| 76 |
+
return
|
| 77 |
+
super().build(var_list)
|
| 78 |
+
self._accumulated_grads = []
|
| 79 |
+
self._accumulated_delta_vars = []
|
| 80 |
+
for var in var_list:
|
| 81 |
+
self._accumulated_grads.append(
|
| 82 |
+
self.add_variable_from_reference(var, "accumulated_grad")
|
| 83 |
+
)
|
| 84 |
+
self._accumulated_delta_vars.append(
|
| 85 |
+
self.add_variable_from_reference(var, "accumulated_delta_var")
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def update_step(self, grad, variable, learning_rate):
|
| 89 |
+
"""Update step given gradient and the associated model variable."""
|
| 90 |
+
lr = ops.cast(learning_rate, variable.dtype)
|
| 91 |
+
grad = ops.cast(grad, variable.dtype)
|
| 92 |
+
|
| 93 |
+
rho = self.rho
|
| 94 |
+
accumulated_grad = self._accumulated_grads[
|
| 95 |
+
self._get_variable_index(variable)
|
| 96 |
+
]
|
| 97 |
+
accumulated_delta_var = self._accumulated_delta_vars[
|
| 98 |
+
self._get_variable_index(variable)
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
def rms(x):
|
| 102 |
+
return ops.sqrt(ops.add(x, self.epsilon))
|
| 103 |
+
|
| 104 |
+
self.assign(
|
| 105 |
+
accumulated_grad,
|
| 106 |
+
ops.add(
|
| 107 |
+
rho * accumulated_grad, ops.multiply(1 - rho, ops.square(grad))
|
| 108 |
+
),
|
| 109 |
+
)
|
| 110 |
+
delta_var = ops.negative(
|
| 111 |
+
ops.divide(
|
| 112 |
+
ops.multiply(rms(accumulated_delta_var), grad),
|
| 113 |
+
rms(accumulated_grad),
|
| 114 |
+
)
|
| 115 |
+
)
|
| 116 |
+
self.assign(
|
| 117 |
+
accumulated_delta_var,
|
| 118 |
+
ops.add(
|
| 119 |
+
ops.multiply(rho, accumulated_delta_var),
|
| 120 |
+
ops.multiply(1 - rho, ops.square(delta_var)),
|
| 121 |
+
),
|
| 122 |
+
)
|
| 123 |
+
self.assign_add(variable, ops.multiply(lr, delta_var))
|
| 124 |
+
|
| 125 |
+
def get_config(self):
|
| 126 |
+
config = super().get_config()
|
| 127 |
+
|
| 128 |
+
config.update(
|
| 129 |
+
{
|
| 130 |
+
"rho": self.rho,
|
| 131 |
+
"epsilon": self.epsilon,
|
| 132 |
+
}
|
| 133 |
+
)
|
| 134 |
+
return config
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
Adadelta.__doc__ = Adadelta.__doc__.replace(
|
| 138 |
+
"{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
|
| 139 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adafactor.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import backend
|
| 2 |
+
from keras.src import ops
|
| 3 |
+
from keras.src.api_export import keras_export
|
| 4 |
+
from keras.src.optimizers import optimizer
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@keras_export(["keras.optimizers.Adafactor"])
|
| 8 |
+
class Adafactor(optimizer.Optimizer):
|
| 9 |
+
"""Optimizer that implements the Adafactor algorithm.
|
| 10 |
+
|
| 11 |
+
Adafactor is commonly used in NLP tasks, and has the advantage
|
| 12 |
+
of taking less memory because it only saves partial information of previous
|
| 13 |
+
gradients.
|
| 14 |
+
|
| 15 |
+
The default argument setup is based on the original paper (see reference).
|
| 16 |
+
When gradients are of dimension > 2, Adafactor optimizer will delete the
|
| 17 |
+
last 2 dimensions separately in its accumulator variables.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
learning_rate: A float, a
|
| 21 |
+
`keras.optimizers.schedules.LearningRateSchedule` instance, or
|
| 22 |
+
a callable that takes no arguments and returns the actual value to
|
| 23 |
+
use. The learning rate. Defaults to `0.001`.
|
| 24 |
+
beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`.
|
| 25 |
+
epsilon_1: float, defaults to 1e-30. A small offset to keep denominator
|
| 26 |
+
away from 0.
|
| 27 |
+
epsilon_2: float, defaults to 1e-3. A small offset to avoid learning
|
| 28 |
+
rate becoming too small by time.
|
| 29 |
+
clip_threshold: float, defaults to 1.0. Clipping threshold. This is a
|
| 30 |
+
part of Adafactor algorithm, independent from `clipnorm`,
|
| 31 |
+
`clipvalue`, and `global_clipnorm`.
|
| 32 |
+
relative_step: bool, defaults to `True`. If `learning_rate` is a
|
| 33 |
+
constant and `relative_step=True`, learning rate will be adjusted
|
| 34 |
+
based on current iterations. This is a default learning rate decay
|
| 35 |
+
in Adafactor.
|
| 36 |
+
{{base_optimizer_keyword_args}}
|
| 37 |
+
|
| 38 |
+
Reference:
|
| 39 |
+
|
| 40 |
+
- [Shazeer, Noam et al., 2018](https://arxiv.org/abs/1804.04235).
|
| 41 |
+
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
learning_rate=0.001,
|
| 47 |
+
beta_2_decay=-0.8,
|
| 48 |
+
epsilon_1=1e-30,
|
| 49 |
+
epsilon_2=1e-3,
|
| 50 |
+
clip_threshold=1.0,
|
| 51 |
+
relative_step=True,
|
| 52 |
+
weight_decay=None,
|
| 53 |
+
clipnorm=None,
|
| 54 |
+
clipvalue=None,
|
| 55 |
+
global_clipnorm=None,
|
| 56 |
+
use_ema=False,
|
| 57 |
+
ema_momentum=0.99,
|
| 58 |
+
ema_overwrite_frequency=None,
|
| 59 |
+
loss_scale_factor=None,
|
| 60 |
+
gradient_accumulation_steps=None,
|
| 61 |
+
name="adafactor",
|
| 62 |
+
**kwargs,
|
| 63 |
+
):
|
| 64 |
+
super().__init__(
|
| 65 |
+
learning_rate=learning_rate,
|
| 66 |
+
name=name,
|
| 67 |
+
weight_decay=weight_decay,
|
| 68 |
+
clipnorm=clipnorm,
|
| 69 |
+
clipvalue=clipvalue,
|
| 70 |
+
global_clipnorm=global_clipnorm,
|
| 71 |
+
use_ema=use_ema,
|
| 72 |
+
ema_momentum=ema_momentum,
|
| 73 |
+
ema_overwrite_frequency=ema_overwrite_frequency,
|
| 74 |
+
loss_scale_factor=loss_scale_factor,
|
| 75 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 76 |
+
**kwargs,
|
| 77 |
+
)
|
| 78 |
+
self.beta_2_decay = beta_2_decay
|
| 79 |
+
self.epsilon_1 = epsilon_1
|
| 80 |
+
self.epsilon_2 = epsilon_2
|
| 81 |
+
self.clip_threshold = clip_threshold
|
| 82 |
+
self.relative_step = relative_step
|
| 83 |
+
|
| 84 |
+
def build(self, var_list):
|
| 85 |
+
"""Initialize optimizer variables.
|
| 86 |
+
|
| 87 |
+
Adam optimizer has 3 types of variables: momentums, velocities and
|
| 88 |
+
velocity_hat (only set when amsgrad is applied),
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
var_list: list of model variables to build Adam variables on.
|
| 92 |
+
"""
|
| 93 |
+
if self.built:
|
| 94 |
+
return
|
| 95 |
+
super().build(var_list)
|
| 96 |
+
self._r = []
|
| 97 |
+
self._c = []
|
| 98 |
+
self._v = []
|
| 99 |
+
for var in var_list:
|
| 100 |
+
if len(var.shape) < 2:
|
| 101 |
+
# Don't factor if variable is of dimension < 2, but we still
|
| 102 |
+
# need to create dummy variables as placeholder.
|
| 103 |
+
with backend.name_scope(self.name, caller=self):
|
| 104 |
+
self._r.append(
|
| 105 |
+
backend.Variable(0, name=var.name, trainable=False)
|
| 106 |
+
)
|
| 107 |
+
self._c.append(
|
| 108 |
+
backend.Variable(0, name=var.name, trainable=False)
|
| 109 |
+
)
|
| 110 |
+
else:
|
| 111 |
+
# Always factor the last 2 dimensions.
|
| 112 |
+
r_shape = var.shape[:-1]
|
| 113 |
+
c_shape = var.shape[:-2] + (var.shape[-1],)
|
| 114 |
+
self._r.append(
|
| 115 |
+
self.add_variable(
|
| 116 |
+
shape=r_shape,
|
| 117 |
+
dtype=var.dtype,
|
| 118 |
+
name=var.name,
|
| 119 |
+
)
|
| 120 |
+
)
|
| 121 |
+
self._c.append(
|
| 122 |
+
self.add_variable(
|
| 123 |
+
shape=c_shape,
|
| 124 |
+
dtype=var.dtype,
|
| 125 |
+
name=var.name,
|
| 126 |
+
)
|
| 127 |
+
)
|
| 128 |
+
self._v.append(
|
| 129 |
+
self.add_variable_from_reference(
|
| 130 |
+
reference_variable=var, name="velocity"
|
| 131 |
+
)
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
def _rms(self, x):
|
| 135 |
+
return ops.sqrt(ops.mean(ops.square(x)))
|
| 136 |
+
|
| 137 |
+
def update_step(self, gradient, variable, learning_rate):
|
| 138 |
+
"""Update step given gradient and the associated model variable."""
|
| 139 |
+
|
| 140 |
+
lr = ops.cast(learning_rate, variable.dtype)
|
| 141 |
+
gradient = ops.cast(gradient, variable.dtype)
|
| 142 |
+
epsilon_2 = ops.cast(self.epsilon_2, variable.dtype)
|
| 143 |
+
one = ops.cast(1.0, variable.dtype)
|
| 144 |
+
local_step = ops.cast(self.iterations + 1, variable.dtype)
|
| 145 |
+
if not callable(self._learning_rate) and self.relative_step:
|
| 146 |
+
lr = ops.minimum(lr, 1 / ops.sqrt(local_step))
|
| 147 |
+
|
| 148 |
+
r = self._r[self._get_variable_index(variable)]
|
| 149 |
+
c = self._c[self._get_variable_index(variable)]
|
| 150 |
+
v = self._v[self._get_variable_index(variable)]
|
| 151 |
+
|
| 152 |
+
rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step))
|
| 153 |
+
alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t
|
| 154 |
+
regulated_grad_square = ops.add(ops.square(gradient), self.epsilon_1)
|
| 155 |
+
beta_2_t = 1 - ops.power(local_step, self.beta_2_decay)
|
| 156 |
+
|
| 157 |
+
if len(variable.shape) >= 2:
|
| 158 |
+
# `r` deletes the last dimension of gradient, so it is of shape
|
| 159 |
+
# `gradient.shape[:-1]`.
|
| 160 |
+
self.assign(
|
| 161 |
+
r,
|
| 162 |
+
beta_2_t * r
|
| 163 |
+
+ (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-1),
|
| 164 |
+
)
|
| 165 |
+
# `c` deletes the second last dimension of gradient, so it is of
|
| 166 |
+
# shape `gradient.shape[:-2] + gradient.shape[-1]`.
|
| 167 |
+
self.assign(
|
| 168 |
+
c,
|
| 169 |
+
beta_2_t * c
|
| 170 |
+
+ (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-2),
|
| 171 |
+
)
|
| 172 |
+
self.assign(
|
| 173 |
+
v,
|
| 174 |
+
ops.expand_dims(
|
| 175 |
+
r / ops.mean(r, axis=-1, keepdims=True), axis=-1
|
| 176 |
+
)
|
| 177 |
+
* ops.expand_dims(c, -2),
|
| 178 |
+
)
|
| 179 |
+
else:
|
| 180 |
+
self.assign(
|
| 181 |
+
v, beta_2_t * v + (1 - beta_2_t) * regulated_grad_square
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
u_t = ops.divide(gradient, ops.sqrt(v))
|
| 185 |
+
u_t_hat = ops.divide(
|
| 186 |
+
u_t,
|
| 187 |
+
ops.maximum(one, ops.divide(self._rms(u_t), self.clip_threshold)),
|
| 188 |
+
)
|
| 189 |
+
self.assign_sub(variable, ops.multiply(alpha_t, u_t_hat))
|
| 190 |
+
|
| 191 |
+
def get_config(self):
|
| 192 |
+
config = super().get_config()
|
| 193 |
+
|
| 194 |
+
config.update(
|
| 195 |
+
{
|
| 196 |
+
"beta_2_decay": self.beta_2_decay,
|
| 197 |
+
"epsilon_1": self.epsilon_1,
|
| 198 |
+
"epsilon_2": self.epsilon_2,
|
| 199 |
+
"clip_threshold": self.clip_threshold,
|
| 200 |
+
"relative_step": self.relative_step,
|
| 201 |
+
}
|
| 202 |
+
)
|
| 203 |
+
return config
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
Adafactor.__doc__ = Adafactor.__doc__.replace(
|
| 207 |
+
"{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
|
| 208 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adagrad.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import initializers
|
| 2 |
+
from keras.src import ops
|
| 3 |
+
from keras.src.api_export import keras_export
|
| 4 |
+
from keras.src.optimizers import optimizer
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@keras_export(["keras.optimizers.Adagrad"])
|
| 8 |
+
class Adagrad(optimizer.Optimizer):
|
| 9 |
+
"""Optimizer that implements the Adagrad algorithm.
|
| 10 |
+
|
| 11 |
+
Adagrad is an optimizer with parameter-specific learning rates,
|
| 12 |
+
which are adapted relative to how frequently a parameter gets
|
| 13 |
+
updated during training. The more updates a parameter receives,
|
| 14 |
+
the smaller the updates.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
learning_rate: A float, a
|
| 18 |
+
`keras.optimizers.schedules.LearningRateSchedule` instance, or
|
| 19 |
+
a callable that takes no arguments and returns the actual value to
|
| 20 |
+
use. The learning rate. Defaults to `0.001`. Note that `Adagrad`
|
| 21 |
+
tends to benefit from higher initial learning rate values compared
|
| 22 |
+
to other optimizers. To match the exact form in the original paper,
|
| 23 |
+
use `1.0`.
|
| 24 |
+
initial_accumulator_value: Floating point value. Starting value for the
|
| 25 |
+
accumulators (per-parameter momentum values). Must be non-negative.
|
| 26 |
+
epsilon: Small floating point value for maintaining numerical stability.
|
| 27 |
+
{{base_optimizer_keyword_args}}
|
| 28 |
+
|
| 29 |
+
Reference:
|
| 30 |
+
|
| 31 |
+
- [Duchi et al., 2011](
|
| 32 |
+
http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf).
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
learning_rate=0.001,
|
| 38 |
+
initial_accumulator_value=0.1,
|
| 39 |
+
epsilon=1e-7,
|
| 40 |
+
weight_decay=None,
|
| 41 |
+
clipnorm=None,
|
| 42 |
+
clipvalue=None,
|
| 43 |
+
global_clipnorm=None,
|
| 44 |
+
use_ema=False,
|
| 45 |
+
ema_momentum=0.99,
|
| 46 |
+
ema_overwrite_frequency=None,
|
| 47 |
+
loss_scale_factor=None,
|
| 48 |
+
gradient_accumulation_steps=None,
|
| 49 |
+
name="adagrad",
|
| 50 |
+
**kwargs,
|
| 51 |
+
):
|
| 52 |
+
super().__init__(
|
| 53 |
+
learning_rate=learning_rate,
|
| 54 |
+
weight_decay=weight_decay,
|
| 55 |
+
clipnorm=clipnorm,
|
| 56 |
+
clipvalue=clipvalue,
|
| 57 |
+
global_clipnorm=global_clipnorm,
|
| 58 |
+
use_ema=use_ema,
|
| 59 |
+
ema_momentum=ema_momentum,
|
| 60 |
+
ema_overwrite_frequency=ema_overwrite_frequency,
|
| 61 |
+
loss_scale_factor=loss_scale_factor,
|
| 62 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 63 |
+
name=name,
|
| 64 |
+
**kwargs,
|
| 65 |
+
)
|
| 66 |
+
self.initial_accumulator_value = initial_accumulator_value
|
| 67 |
+
self.epsilon = epsilon
|
| 68 |
+
|
| 69 |
+
def build(self, var_list):
|
| 70 |
+
if self.built:
|
| 71 |
+
return
|
| 72 |
+
super().build(var_list)
|
| 73 |
+
self._accumulators = []
|
| 74 |
+
initializer = initializers.Constant(self.initial_accumulator_value)
|
| 75 |
+
for var in var_list:
|
| 76 |
+
self._accumulators.append(
|
| 77 |
+
self.add_variable(
|
| 78 |
+
shape=var.shape,
|
| 79 |
+
initializer=initializer,
|
| 80 |
+
dtype=var.dtype,
|
| 81 |
+
name="accumulator",
|
| 82 |
+
)
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
def update_step(self, gradient, variable, learning_rate):
|
| 86 |
+
"""Update step given gradient and the associated model variable."""
|
| 87 |
+
lr = ops.cast(learning_rate, variable.dtype)
|
| 88 |
+
gradient = ops.cast(gradient, variable.dtype)
|
| 89 |
+
|
| 90 |
+
accumulator = self._accumulators[self._get_variable_index(variable)]
|
| 91 |
+
|
| 92 |
+
self.assign_add(accumulator, ops.square(gradient))
|
| 93 |
+
self.assign_sub(
|
| 94 |
+
variable,
|
| 95 |
+
ops.divide(
|
| 96 |
+
ops.multiply(lr, gradient),
|
| 97 |
+
ops.sqrt(ops.add(accumulator, self.epsilon)),
|
| 98 |
+
),
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
def get_config(self):
|
| 102 |
+
config = super().get_config()
|
| 103 |
+
|
| 104 |
+
config.update(
|
| 105 |
+
{
|
| 106 |
+
"initial_accumulator_value": self.initial_accumulator_value,
|
| 107 |
+
"epsilon": self.epsilon,
|
| 108 |
+
}
|
| 109 |
+
)
|
| 110 |
+
return config
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
Adagrad.__doc__ = Adagrad.__doc__.replace(
|
| 114 |
+
"{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
|
| 115 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adam.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import ops
|
| 2 |
+
from keras.src.api_export import keras_export
|
| 3 |
+
from keras.src.optimizers import optimizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@keras_export(["keras.optimizers.Adam"])
|
| 7 |
+
class Adam(optimizer.Optimizer):
|
| 8 |
+
"""Optimizer that implements the Adam algorithm.
|
| 9 |
+
|
| 10 |
+
Adam optimization is a stochastic gradient descent method that is based on
|
| 11 |
+
adaptive estimation of first-order and second-order moments.
|
| 12 |
+
|
| 13 |
+
According to
|
| 14 |
+
[Kingma et al., 2014](http://arxiv.org/abs/1412.6980),
|
| 15 |
+
the method is "*computationally
|
| 16 |
+
efficient, has little memory requirement, invariant to diagonal rescaling of
|
| 17 |
+
gradients, and is well suited for problems that are large in terms of
|
| 18 |
+
data/parameters*".
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
learning_rate: A float, a
|
| 22 |
+
`keras.optimizers.schedules.LearningRateSchedule` instance, or
|
| 23 |
+
a callable that takes no arguments and returns the actual value to
|
| 24 |
+
use. The learning rate. Defaults to `0.001`.
|
| 25 |
+
beta_1: A float value or a constant float tensor, or a callable
|
| 26 |
+
that takes no arguments and returns the actual value to use. The
|
| 27 |
+
exponential decay rate for the 1st moment estimates. Defaults to
|
| 28 |
+
`0.9`.
|
| 29 |
+
beta_2: A float value or a constant float tensor, or a callable
|
| 30 |
+
that takes no arguments and returns the actual value to use. The
|
| 31 |
+
exponential decay rate for the 2nd moment estimates. Defaults to
|
| 32 |
+
`0.999`.
|
| 33 |
+
epsilon: A small constant for numerical stability. This epsilon is
|
| 34 |
+
"epsilon hat" in the Kingma and Ba paper (in the formula just before
|
| 35 |
+
Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults
|
| 36 |
+
to `1e-7`.
|
| 37 |
+
amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm
|
| 38 |
+
from the paper "On the Convergence of Adam and beyond". Defaults
|
| 39 |
+
to `False`.
|
| 40 |
+
{{base_optimizer_keyword_args}}
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
learning_rate=0.001,
|
| 46 |
+
beta_1=0.9,
|
| 47 |
+
beta_2=0.999,
|
| 48 |
+
epsilon=1e-7,
|
| 49 |
+
amsgrad=False,
|
| 50 |
+
weight_decay=None,
|
| 51 |
+
clipnorm=None,
|
| 52 |
+
clipvalue=None,
|
| 53 |
+
global_clipnorm=None,
|
| 54 |
+
use_ema=False,
|
| 55 |
+
ema_momentum=0.99,
|
| 56 |
+
ema_overwrite_frequency=None,
|
| 57 |
+
loss_scale_factor=None,
|
| 58 |
+
gradient_accumulation_steps=None,
|
| 59 |
+
name="adam",
|
| 60 |
+
**kwargs,
|
| 61 |
+
):
|
| 62 |
+
super().__init__(
|
| 63 |
+
learning_rate=learning_rate,
|
| 64 |
+
name=name,
|
| 65 |
+
weight_decay=weight_decay,
|
| 66 |
+
clipnorm=clipnorm,
|
| 67 |
+
clipvalue=clipvalue,
|
| 68 |
+
global_clipnorm=global_clipnorm,
|
| 69 |
+
use_ema=use_ema,
|
| 70 |
+
ema_momentum=ema_momentum,
|
| 71 |
+
ema_overwrite_frequency=ema_overwrite_frequency,
|
| 72 |
+
loss_scale_factor=loss_scale_factor,
|
| 73 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 74 |
+
**kwargs,
|
| 75 |
+
)
|
| 76 |
+
self.beta_1 = beta_1
|
| 77 |
+
self.beta_2 = beta_2
|
| 78 |
+
self.epsilon = epsilon
|
| 79 |
+
self.amsgrad = amsgrad
|
| 80 |
+
|
| 81 |
+
def build(self, var_list):
|
| 82 |
+
"""Initialize optimizer variables.
|
| 83 |
+
|
| 84 |
+
Adam optimizer has 3 types of variables: momentums, velocities and
|
| 85 |
+
velocity_hat (only set when amsgrad is applied),
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
var_list: list of model variables to build Adam variables on.
|
| 89 |
+
"""
|
| 90 |
+
if self.built:
|
| 91 |
+
return
|
| 92 |
+
super().build(var_list)
|
| 93 |
+
self._momentums = []
|
| 94 |
+
self._velocities = []
|
| 95 |
+
for var in var_list:
|
| 96 |
+
self._momentums.append(
|
| 97 |
+
self.add_variable_from_reference(
|
| 98 |
+
reference_variable=var, name="momentum"
|
| 99 |
+
)
|
| 100 |
+
)
|
| 101 |
+
self._velocities.append(
|
| 102 |
+
self.add_variable_from_reference(
|
| 103 |
+
reference_variable=var, name="velocity"
|
| 104 |
+
)
|
| 105 |
+
)
|
| 106 |
+
if self.amsgrad:
|
| 107 |
+
self._velocity_hats = []
|
| 108 |
+
for var in var_list:
|
| 109 |
+
self._velocity_hats.append(
|
| 110 |
+
self.add_variable_from_reference(
|
| 111 |
+
reference_variable=var, name="velocity_hat"
|
| 112 |
+
)
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def update_step(self, gradient, variable, learning_rate):
|
| 116 |
+
"""Update step given gradient and the associated model variable."""
|
| 117 |
+
lr = ops.cast(learning_rate, variable.dtype)
|
| 118 |
+
gradient = ops.cast(gradient, variable.dtype)
|
| 119 |
+
local_step = ops.cast(self.iterations + 1, variable.dtype)
|
| 120 |
+
beta_1_power = ops.power(
|
| 121 |
+
ops.cast(self.beta_1, variable.dtype), local_step
|
| 122 |
+
)
|
| 123 |
+
beta_2_power = ops.power(
|
| 124 |
+
ops.cast(self.beta_2, variable.dtype), local_step
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
m = self._momentums[self._get_variable_index(variable)]
|
| 128 |
+
v = self._velocities[self._get_variable_index(variable)]
|
| 129 |
+
|
| 130 |
+
alpha = lr * ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)
|
| 131 |
+
|
| 132 |
+
self.assign_add(
|
| 133 |
+
m, ops.multiply(ops.subtract(gradient, m), 1 - self.beta_1)
|
| 134 |
+
)
|
| 135 |
+
self.assign_add(
|
| 136 |
+
v,
|
| 137 |
+
ops.multiply(
|
| 138 |
+
ops.subtract(ops.square(gradient), v), 1 - self.beta_2
|
| 139 |
+
),
|
| 140 |
+
)
|
| 141 |
+
if self.amsgrad:
|
| 142 |
+
v_hat = self._velocity_hats[self._get_variable_index(variable)]
|
| 143 |
+
self.assign(v_hat, ops.maximum(v_hat, v))
|
| 144 |
+
v = v_hat
|
| 145 |
+
self.assign_sub(
|
| 146 |
+
variable,
|
| 147 |
+
ops.divide(
|
| 148 |
+
ops.multiply(m, alpha), ops.add(ops.sqrt(v), self.epsilon)
|
| 149 |
+
),
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
def get_config(self):
|
| 153 |
+
config = super().get_config()
|
| 154 |
+
config.update(
|
| 155 |
+
{
|
| 156 |
+
"beta_1": self.beta_1,
|
| 157 |
+
"beta_2": self.beta_2,
|
| 158 |
+
"epsilon": self.epsilon,
|
| 159 |
+
"amsgrad": self.amsgrad,
|
| 160 |
+
}
|
| 161 |
+
)
|
| 162 |
+
return config
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
Adam.__doc__ = Adam.__doc__.replace(
|
| 166 |
+
"{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
|
| 167 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adamax.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import ops
|
| 2 |
+
from keras.src.api_export import keras_export
|
| 3 |
+
from keras.src.optimizers import optimizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@keras_export(["keras.optimizers.Adamax"])
|
| 7 |
+
class Adamax(optimizer.Optimizer):
|
| 8 |
+
"""Optimizer that implements the Adamax algorithm.
|
| 9 |
+
|
| 10 |
+
Adamax, a variant of Adam based on the infinity norm, is a first-order
|
| 11 |
+
gradient-based optimization method. Due to its capability of adjusting the
|
| 12 |
+
learning rate based on data characteristics, it is suited to learn
|
| 13 |
+
time-variant process, e.g., speech data with dynamically changed noise
|
| 14 |
+
conditions. Default parameters follow those provided in the paper (see
|
| 15 |
+
references below).
|
| 16 |
+
|
| 17 |
+
Initialization:
|
| 18 |
+
|
| 19 |
+
```python
|
| 20 |
+
m = 0 # Initialize initial 1st moment vector
|
| 21 |
+
u = 0 # Initialize the exponentially weighted infinity norm
|
| 22 |
+
t = 0 # Initialize timestep
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
The update rule for parameter `w` with gradient `g` is described at the end
|
| 26 |
+
of section 7.1 of the paper (see the reference section):
|
| 27 |
+
|
| 28 |
+
```python
|
| 29 |
+
t += 1
|
| 30 |
+
m = beta1 * m + (1 - beta) * g
|
| 31 |
+
u = max(beta2 * u, abs(g))
|
| 32 |
+
current_lr = learning_rate / (1 - beta1 ** t)
|
| 33 |
+
w = w - current_lr * m / (u + epsilon)
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
learning_rate: A float, a
|
| 38 |
+
`keras.optimizers.schedules.LearningRateSchedule` instance, or
|
| 39 |
+
a callable that takes no arguments and returns the actual value to
|
| 40 |
+
use. The learning rate. Defaults to `0.001`.
|
| 41 |
+
beta_1: A float value or a constant float tensor. The exponential decay
|
| 42 |
+
rate for the 1st moment estimates.
|
| 43 |
+
beta_2: A float value or a constant float tensor. The exponential decay
|
| 44 |
+
rate for the exponentially weighted infinity norm.
|
| 45 |
+
epsilon: A small constant for numerical stability.
|
| 46 |
+
{{base_optimizer_keyword_args}}
|
| 47 |
+
|
| 48 |
+
Reference:
|
| 49 |
+
|
| 50 |
+
- [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
learning_rate=0.001,
|
| 56 |
+
beta_1=0.9,
|
| 57 |
+
beta_2=0.999,
|
| 58 |
+
epsilon=1e-7,
|
| 59 |
+
weight_decay=None,
|
| 60 |
+
clipnorm=None,
|
| 61 |
+
clipvalue=None,
|
| 62 |
+
global_clipnorm=None,
|
| 63 |
+
use_ema=False,
|
| 64 |
+
ema_momentum=0.99,
|
| 65 |
+
ema_overwrite_frequency=None,
|
| 66 |
+
loss_scale_factor=None,
|
| 67 |
+
gradient_accumulation_steps=None,
|
| 68 |
+
name="adamax",
|
| 69 |
+
**kwargs,
|
| 70 |
+
):
|
| 71 |
+
super().__init__(
|
| 72 |
+
learning_rate=learning_rate,
|
| 73 |
+
name=name,
|
| 74 |
+
weight_decay=weight_decay,
|
| 75 |
+
clipnorm=clipnorm,
|
| 76 |
+
clipvalue=clipvalue,
|
| 77 |
+
global_clipnorm=global_clipnorm,
|
| 78 |
+
use_ema=use_ema,
|
| 79 |
+
ema_momentum=ema_momentum,
|
| 80 |
+
ema_overwrite_frequency=ema_overwrite_frequency,
|
| 81 |
+
loss_scale_factor=loss_scale_factor,
|
| 82 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 83 |
+
**kwargs,
|
| 84 |
+
)
|
| 85 |
+
self.beta_1 = beta_1
|
| 86 |
+
self.beta_2 = beta_2
|
| 87 |
+
self.epsilon = epsilon
|
| 88 |
+
|
| 89 |
+
def build(self, var_list):
|
| 90 |
+
"""Initialize optimizer variables.
|
| 91 |
+
|
| 92 |
+
Adamax optimizer has 2 types of variables: momentums (denoted as m),
|
| 93 |
+
exponentially weighted infinity norm (denoted as u).
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
var_list: list of model variables to build Adamax variables on.
|
| 97 |
+
"""
|
| 98 |
+
if self.built:
|
| 99 |
+
return
|
| 100 |
+
super().build(var_list)
|
| 101 |
+
self._m = []
|
| 102 |
+
self._u = []
|
| 103 |
+
for var in var_list:
|
| 104 |
+
self._m.append(
|
| 105 |
+
self.add_variable_from_reference(
|
| 106 |
+
reference_variable=var, name="momentum"
|
| 107 |
+
)
|
| 108 |
+
)
|
| 109 |
+
self._u.append(
|
| 110 |
+
self.add_variable_from_reference(
|
| 111 |
+
reference_variable=var, name="norm"
|
| 112 |
+
)
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def update_step(self, gradient, variable, learning_rate):
|
| 116 |
+
"""Update step given gradient and the associated model variable."""
|
| 117 |
+
lr = ops.cast(learning_rate, variable.dtype)
|
| 118 |
+
gradient = ops.cast(gradient, variable.dtype)
|
| 119 |
+
local_step = ops.cast(self.iterations + 1, variable.dtype)
|
| 120 |
+
beta_1_power = ops.power(
|
| 121 |
+
ops.cast(self.beta_1, variable.dtype), local_step
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
m = self._m[self._get_variable_index(variable)]
|
| 125 |
+
u = self._u[self._get_variable_index(variable)]
|
| 126 |
+
|
| 127 |
+
self.assign_add(
|
| 128 |
+
m, ops.multiply(ops.subtract(gradient, m), (1 - self.beta_1))
|
| 129 |
+
)
|
| 130 |
+
self.assign(
|
| 131 |
+
u, ops.maximum(ops.multiply(self.beta_2, u), ops.abs(gradient))
|
| 132 |
+
)
|
| 133 |
+
self.assign_sub(
|
| 134 |
+
variable,
|
| 135 |
+
ops.divide(
|
| 136 |
+
ops.multiply(lr, m),
|
| 137 |
+
ops.multiply((1 - beta_1_power), ops.add(u, self.epsilon)),
|
| 138 |
+
),
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
def get_config(self):
|
| 142 |
+
config = super().get_config()
|
| 143 |
+
|
| 144 |
+
config.update(
|
| 145 |
+
{
|
| 146 |
+
"beta_1": self.beta_1,
|
| 147 |
+
"beta_2": self.beta_2,
|
| 148 |
+
"epsilon": self.epsilon,
|
| 149 |
+
}
|
| 150 |
+
)
|
| 151 |
+
return config
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
Adamax.__doc__ = Adamax.__doc__.replace(
|
| 155 |
+
"{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
|
| 156 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/adamw.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src.api_export import keras_export
|
| 2 |
+
from keras.src.optimizers import adam
|
| 3 |
+
from keras.src.optimizers import optimizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@keras_export(["keras.optimizers.AdamW"])
|
| 7 |
+
class AdamW(adam.Adam):
|
| 8 |
+
"""Optimizer that implements the AdamW algorithm.
|
| 9 |
+
|
| 10 |
+
AdamW optimization is a stochastic gradient descent method that is based on
|
| 11 |
+
adaptive estimation of first-order and second-order moments with an added
|
| 12 |
+
method to decay weights per the techniques discussed in the paper,
|
| 13 |
+
'Decoupled Weight Decay Regularization' by
|
| 14 |
+
[Loshchilov, Hutter et al., 2019](https://arxiv.org/abs/1711.05101).
|
| 15 |
+
|
| 16 |
+
According to
|
| 17 |
+
[Kingma et al., 2014](http://arxiv.org/abs/1412.6980),
|
| 18 |
+
the underlying Adam method is "*computationally
|
| 19 |
+
efficient, has little memory requirement, invariant to diagonal rescaling of
|
| 20 |
+
gradients, and is well suited for problems that are large in terms of
|
| 21 |
+
data/parameters*".
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
learning_rate: A float, a
|
| 25 |
+
`keras.optimizers.schedules.LearningRateSchedule` instance, or
|
| 26 |
+
a callable that takes no arguments and returns the actual value to
|
| 27 |
+
use. The learning rate. Defaults to `0.001`.
|
| 28 |
+
beta_1: A float value or a constant float tensor, or a callable
|
| 29 |
+
that takes no arguments and returns the actual value to use. The
|
| 30 |
+
exponential decay rate for the 1st moment estimates.
|
| 31 |
+
Defaults to `0.9`.
|
| 32 |
+
beta_2: A float value or a constant float tensor, or a callable
|
| 33 |
+
that takes no arguments and returns the actual value to use. The
|
| 34 |
+
exponential decay rate for the 2nd moment estimates.
|
| 35 |
+
Defaults to `0.999`.
|
| 36 |
+
epsilon: A small constant for numerical stability. This epsilon is
|
| 37 |
+
"epsilon hat" in the Kingma and Ba paper (in the formula just
|
| 38 |
+
before Section 2.1), not the epsilon in Algorithm 1 of the paper.
|
| 39 |
+
Defaults to 1e-7.
|
| 40 |
+
amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm
|
| 41 |
+
from the paper "On the Convergence of Adam and beyond".
|
| 42 |
+
Defaults to `False`.
|
| 43 |
+
{{base_optimizer_keyword_args}}
|
| 44 |
+
|
| 45 |
+
References:
|
| 46 |
+
|
| 47 |
+
- [Loshchilov et al., 2019](https://arxiv.org/abs/1711.05101)
|
| 48 |
+
- [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) for `adam`
|
| 49 |
+
- [Reddi et al., 2018](
|
| 50 |
+
https://openreview.net/pdf?id=ryQu7f-RZ) for `amsgrad`.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
learning_rate=0.001,
|
| 56 |
+
weight_decay=0.004,
|
| 57 |
+
beta_1=0.9,
|
| 58 |
+
beta_2=0.999,
|
| 59 |
+
epsilon=1e-7,
|
| 60 |
+
amsgrad=False,
|
| 61 |
+
clipnorm=None,
|
| 62 |
+
clipvalue=None,
|
| 63 |
+
global_clipnorm=None,
|
| 64 |
+
use_ema=False,
|
| 65 |
+
ema_momentum=0.99,
|
| 66 |
+
ema_overwrite_frequency=None,
|
| 67 |
+
loss_scale_factor=None,
|
| 68 |
+
gradient_accumulation_steps=None,
|
| 69 |
+
name="adamw",
|
| 70 |
+
**kwargs,
|
| 71 |
+
):
|
| 72 |
+
super().__init__(
|
| 73 |
+
learning_rate=learning_rate,
|
| 74 |
+
beta_1=beta_1,
|
| 75 |
+
beta_2=beta_2,
|
| 76 |
+
epsilon=epsilon,
|
| 77 |
+
amsgrad=amsgrad,
|
| 78 |
+
name=name,
|
| 79 |
+
weight_decay=weight_decay,
|
| 80 |
+
clipnorm=clipnorm,
|
| 81 |
+
clipvalue=clipvalue,
|
| 82 |
+
global_clipnorm=global_clipnorm,
|
| 83 |
+
use_ema=use_ema,
|
| 84 |
+
ema_momentum=ema_momentum,
|
| 85 |
+
ema_overwrite_frequency=ema_overwrite_frequency,
|
| 86 |
+
loss_scale_factor=loss_scale_factor,
|
| 87 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 88 |
+
**kwargs,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
if self.weight_decay is None:
|
| 92 |
+
raise ValueError(
|
| 93 |
+
"Argument `weight_decay` must be a float. Received: "
|
| 94 |
+
"weight_decay=None"
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
AdamW.__doc__ = AdamW.__doc__.replace(
|
| 99 |
+
"{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
|
| 100 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/base_optimizer.py
ADDED
|
@@ -0,0 +1,1102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
from keras.src import backend
|
| 5 |
+
from keras.src import initializers
|
| 6 |
+
from keras.src import ops
|
| 7 |
+
from keras.src.optimizers.schedules import learning_rate_schedule
|
| 8 |
+
from keras.src.saving import serialization_lib
|
| 9 |
+
from keras.src.saving.keras_saveable import KerasSaveable
|
| 10 |
+
from keras.src.utils import tracking
|
| 11 |
+
from keras.src.utils.naming import auto_name
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseOptimizer(KerasSaveable):
|
| 15 |
+
"""Abstract optimizer base class.
|
| 16 |
+
|
| 17 |
+
If you intend to create your own optimization algorithm, please inherit from
|
| 18 |
+
this class and override the following methods:
|
| 19 |
+
|
| 20 |
+
- `build`: Create your optimizer-related variables, such as momentum
|
| 21 |
+
variables in the SGD optimizer.
|
| 22 |
+
- `update_step`: Implement your optimizer's variable updating logic.
|
| 23 |
+
- `get_config`: serialization of the optimizer.
|
| 24 |
+
|
| 25 |
+
Example:
|
| 26 |
+
|
| 27 |
+
```python
|
| 28 |
+
class SGD(Optimizer):
|
| 29 |
+
def __init__(self, **kwargs):
|
| 30 |
+
super().__init__(**kwargs)
|
| 31 |
+
self.momentum = 0.9
|
| 32 |
+
|
| 33 |
+
def build(self, variables):
|
| 34 |
+
super().build(variables)
|
| 35 |
+
self.momentums = []
|
| 36 |
+
for variable in variables:
|
| 37 |
+
self.momentums.append(
|
| 38 |
+
self.add_variable_from_reference(
|
| 39 |
+
reference_variable=variable, name="momentum"
|
| 40 |
+
)
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
def update_step(self, gradient, variable, learning_rate):
|
| 44 |
+
learning_rate = ops.cast(learning_rate, variable.dtype)
|
| 45 |
+
gradient = ops.cast(gradient, variable.dtype)
|
| 46 |
+
m = self.momentums[self._get_variable_index(variable)]
|
| 47 |
+
self.assign(
|
| 48 |
+
m,
|
| 49 |
+
ops.subtract(
|
| 50 |
+
ops.multiply(m, ops.cast(self.momentum, variable.dtype)),
|
| 51 |
+
ops.multiply(gradient, learning_rate),
|
| 52 |
+
),
|
| 53 |
+
)
|
| 54 |
+
self.assign_add(variable, m)
|
| 55 |
+
|
| 56 |
+
def get_config(self):
|
| 57 |
+
config = super().get_config()
|
| 58 |
+
config.update(
|
| 59 |
+
{
|
| 60 |
+
"momentum": self.momentum,
|
| 61 |
+
"nesterov": self.nesterov,
|
| 62 |
+
}
|
| 63 |
+
)
|
| 64 |
+
return config
|
| 65 |
+
```
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def __init__(
|
| 69 |
+
self,
|
| 70 |
+
learning_rate,
|
| 71 |
+
weight_decay=None,
|
| 72 |
+
clipnorm=None,
|
| 73 |
+
clipvalue=None,
|
| 74 |
+
global_clipnorm=None,
|
| 75 |
+
use_ema=False,
|
| 76 |
+
ema_momentum=0.99,
|
| 77 |
+
ema_overwrite_frequency=None,
|
| 78 |
+
loss_scale_factor=None,
|
| 79 |
+
gradient_accumulation_steps=None,
|
| 80 |
+
name=None,
|
| 81 |
+
**kwargs,
|
| 82 |
+
):
|
| 83 |
+
self._lock = False
|
| 84 |
+
|
| 85 |
+
if kwargs.pop("decay", None) is not None:
|
| 86 |
+
warnings.warn(
|
| 87 |
+
"Argument `decay` is no longer supported and will be ignored."
|
| 88 |
+
)
|
| 89 |
+
if kwargs:
|
| 90 |
+
raise ValueError(f"Argument(s) not recognized: {kwargs}")
|
| 91 |
+
|
| 92 |
+
if name is None:
|
| 93 |
+
name = auto_name(self.__class__.__name__)
|
| 94 |
+
self.name = name
|
| 95 |
+
self.weight_decay = weight_decay
|
| 96 |
+
self.clipnorm = clipnorm
|
| 97 |
+
self.global_clipnorm = global_clipnorm
|
| 98 |
+
self.clipvalue = clipvalue
|
| 99 |
+
self.use_ema = use_ema
|
| 100 |
+
self.loss_scale_factor = loss_scale_factor
|
| 101 |
+
self.gradient_accumulation_steps = gradient_accumulation_steps
|
| 102 |
+
|
| 103 |
+
if gradient_accumulation_steps:
|
| 104 |
+
if not gradient_accumulation_steps >= 2:
|
| 105 |
+
raise ValueError(
|
| 106 |
+
"`gradient_accumulation_steps` must be an integer >= 2. "
|
| 107 |
+
"Received: gradient_accumulation_steps="
|
| 108 |
+
f"{gradient_accumulation_steps}"
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
if use_ema:
|
| 112 |
+
# Verify the arguments related to EMA.
|
| 113 |
+
if ema_momentum > 1 or ema_momentum < 0:
|
| 114 |
+
raise ValueError(
|
| 115 |
+
"`ema_momentum` must be in the range [0, 1]. "
|
| 116 |
+
f"Received: ema_momentum={ema_momentum}"
|
| 117 |
+
)
|
| 118 |
+
if ema_overwrite_frequency and (
|
| 119 |
+
not isinstance(ema_overwrite_frequency, int)
|
| 120 |
+
or ema_overwrite_frequency < 1
|
| 121 |
+
):
|
| 122 |
+
raise ValueError(
|
| 123 |
+
"`ema_overwrite_frequency` must be an integer >= 1 or "
|
| 124 |
+
"None. Received: ema_overwrite_frequency="
|
| 125 |
+
f"{ema_overwrite_frequency}"
|
| 126 |
+
)
|
| 127 |
+
self.ema_momentum = ema_momentum
|
| 128 |
+
self.ema_overwrite_frequency = ema_overwrite_frequency
|
| 129 |
+
|
| 130 |
+
clip_args_sum = sum(
|
| 131 |
+
a is not None for a in [clipnorm, clipvalue, global_clipnorm]
|
| 132 |
+
)
|
| 133 |
+
if clip_args_sum > 1:
|
| 134 |
+
raise ValueError(
|
| 135 |
+
"Only one of `clipnorm`, `clipvalue` and `global_clipnorm` can "
|
| 136 |
+
f"be set. Received: clipnorm={clipnorm}, "
|
| 137 |
+
f"clipvalue={clipvalue}, global_clipnorm={global_clipnorm}"
|
| 138 |
+
)
|
| 139 |
+
self.built = False
|
| 140 |
+
|
| 141 |
+
# Set up variable tracking.
|
| 142 |
+
self._variables = []
|
| 143 |
+
self._trainable_variables = []
|
| 144 |
+
self._tracker = tracking.Tracker(
|
| 145 |
+
{
|
| 146 |
+
"variables": (
|
| 147 |
+
lambda x: isinstance(x, backend.Variable),
|
| 148 |
+
self._variables,
|
| 149 |
+
),
|
| 150 |
+
}
|
| 151 |
+
)
|
| 152 |
+
self._trainable_variables_indices = {}
|
| 153 |
+
|
| 154 |
+
# Create iteration variable
|
| 155 |
+
# Note: dtype="int" will resolve to int32 in JAX
|
| 156 |
+
# (since int64 is disallowed in JAX) and to int64 in TF.
|
| 157 |
+
with backend.name_scope(self.name, caller=self):
|
| 158 |
+
iterations = backend.Variable(
|
| 159 |
+
0,
|
| 160 |
+
name="iteration",
|
| 161 |
+
dtype="int",
|
| 162 |
+
trainable=False,
|
| 163 |
+
aggregation="only_first_replica",
|
| 164 |
+
)
|
| 165 |
+
self._track_variable(iterations)
|
| 166 |
+
self._iterations = iterations
|
| 167 |
+
|
| 168 |
+
# Create learning rate (schedule or variable)
|
| 169 |
+
if isinstance(
|
| 170 |
+
learning_rate, learning_rate_schedule.LearningRateSchedule
|
| 171 |
+
):
|
| 172 |
+
self._learning_rate = learning_rate
|
| 173 |
+
elif callable(learning_rate):
|
| 174 |
+
self._learning_rate = learning_rate
|
| 175 |
+
else:
|
| 176 |
+
if not isinstance(learning_rate, float):
|
| 177 |
+
raise ValueError(
|
| 178 |
+
"Argument `learning_rate` should be float, or an instance "
|
| 179 |
+
"of LearningRateSchedule, or a callable "
|
| 180 |
+
"(that takes in the current iteration value "
|
| 181 |
+
"and returns the corresponding learning rate value). "
|
| 182 |
+
f"Received instead: learning_rate={learning_rate}"
|
| 183 |
+
)
|
| 184 |
+
with backend.name_scope(self.name, caller=self):
|
| 185 |
+
learning_rate = backend.Variable(
|
| 186 |
+
learning_rate,
|
| 187 |
+
name="learning_rate",
|
| 188 |
+
dtype=backend.floatx(),
|
| 189 |
+
trainable=False,
|
| 190 |
+
aggregation="only_first_replica",
|
| 191 |
+
)
|
| 192 |
+
self._track_variable(learning_rate)
|
| 193 |
+
self._learning_rate = learning_rate
|
| 194 |
+
|
| 195 |
+
@property
|
| 196 |
+
def iterations(self):
|
| 197 |
+
if self.gradient_accumulation_steps:
|
| 198 |
+
return ops.floor_divide(
|
| 199 |
+
self._iterations, self.gradient_accumulation_steps
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
return self._iterations
|
| 203 |
+
|
| 204 |
+
def _track_variable(self, variable):
|
| 205 |
+
self._tracker.add_to_store("variables", variable)
|
| 206 |
+
|
| 207 |
+
@tracking.no_automatic_dependency_tracking
|
| 208 |
+
def build(self, variables):
|
| 209 |
+
if self.use_ema:
|
| 210 |
+
self._model_variables_moving_average = []
|
| 211 |
+
if self.gradient_accumulation_steps:
|
| 212 |
+
self._accumulated_gradients = []
|
| 213 |
+
for i, variable in enumerate(variables):
|
| 214 |
+
self._trainable_variables_indices[self._var_key(variable)] = i
|
| 215 |
+
if self.use_ema:
|
| 216 |
+
self._model_variables_moving_average.append(
|
| 217 |
+
self.add_variable_from_reference(
|
| 218 |
+
variable,
|
| 219 |
+
name="average",
|
| 220 |
+
)
|
| 221 |
+
)
|
| 222 |
+
if self.gradient_accumulation_steps:
|
| 223 |
+
self._accumulated_gradients.append(
|
| 224 |
+
self.add_variable_from_reference(
|
| 225 |
+
variable,
|
| 226 |
+
name="gradient_accumulator",
|
| 227 |
+
)
|
| 228 |
+
)
|
| 229 |
+
self._trainable_variables = variables[:]
|
| 230 |
+
self.built = True
|
| 231 |
+
|
| 232 |
+
def _var_key(self, variable):
|
| 233 |
+
# Helper function to get a stable ID and the variable instance mapping.
|
| 234 |
+
return id(variable)
|
| 235 |
+
|
| 236 |
+
@property
|
| 237 |
+
def variables(self):
|
| 238 |
+
return self._variables[:]
|
| 239 |
+
|
| 240 |
+
def _get_variable_index(self, variable):
|
| 241 |
+
return self._trainable_variables_indices[self._var_key(variable)]
|
| 242 |
+
|
| 243 |
+
def add_variable(
|
| 244 |
+
self,
|
| 245 |
+
shape,
|
| 246 |
+
initializer="zeros",
|
| 247 |
+
dtype=None,
|
| 248 |
+
aggregation="none",
|
| 249 |
+
name=None,
|
| 250 |
+
):
|
| 251 |
+
"""Add a variable to the optimizer.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
shape: Shape tuple for the variable. Must be fully-defined
|
| 255 |
+
(no `None` entries).
|
| 256 |
+
initializer: Initializer object to use to populate the initial
|
| 257 |
+
variable value, or string name of a built-in initializer
|
| 258 |
+
(e.g. `"random_normal"`). Defaults to `"zeros"`.
|
| 259 |
+
dtype: Dtype of the variable to create, e.g. `"float32"`. If
|
| 260 |
+
unspecified, defaults to the `keras.backend.floatx()`.
|
| 261 |
+
aggregation: Optional string, one of `None`, `"none"`, `"mean"`,
|
| 262 |
+
`"sum"` or `"only_first_replica"`. Annotates the variable with
|
| 263 |
+
the type of multi-replica aggregation to be used for this
|
| 264 |
+
variable when writing custom data parallel training loops.
|
| 265 |
+
Defaults to `"none"`.
|
| 266 |
+
name: String name of the variable. Useful for debugging purposes.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
An optimizer variable, in the format of `keras.Variable`.
|
| 270 |
+
"""
|
| 271 |
+
self._check_super_called()
|
| 272 |
+
initializer = initializers.get(initializer)
|
| 273 |
+
with backend.name_scope(self.name, caller=self):
|
| 274 |
+
variable = backend.Variable(
|
| 275 |
+
initializer=initializer,
|
| 276 |
+
shape=shape,
|
| 277 |
+
dtype=dtype,
|
| 278 |
+
trainable=False,
|
| 279 |
+
aggregation=aggregation,
|
| 280 |
+
name=name,
|
| 281 |
+
)
|
| 282 |
+
self._track_variable(variable)
|
| 283 |
+
return variable
|
| 284 |
+
|
| 285 |
+
def add_variable_from_reference(
|
| 286 |
+
self, reference_variable, name=None, initializer="zeros"
|
| 287 |
+
):
|
| 288 |
+
"""Add an optimizer variable from the model variable.
|
| 289 |
+
|
| 290 |
+
Create an optimizer variable based on the information of model variable.
|
| 291 |
+
For example, in SGD optimizer momemtum, for each model variable, a
|
| 292 |
+
corresponding momemtum variable is created of the same shape and dtype.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
reference_variable: `keras.Variable`. The corresponding model
|
| 296 |
+
variable to the optimizer variable to be created.
|
| 297 |
+
name: Optional string. The name prefix of the optimizer variable to
|
| 298 |
+
be created. If not provided, it will be set to `"var"`. The
|
| 299 |
+
variable name will follow the pattern
|
| 300 |
+
`{variable_name}_{reference_variable.name}`,
|
| 301 |
+
e.g., `momemtum/dense_1`. Defaults to `None`.
|
| 302 |
+
initializer: Initializer object to use to populate the initial
|
| 303 |
+
variable value, or string name of a built-in initializer
|
| 304 |
+
(e.g. `"random_normal"`). If unspecified, defaults to
|
| 305 |
+
`"zeros"`.
|
| 306 |
+
|
| 307 |
+
Returns:
|
| 308 |
+
An optimizer variable, in the format of `keras.Variable`.
|
| 309 |
+
"""
|
| 310 |
+
name = name or "var"
|
| 311 |
+
if hasattr(reference_variable, "path"):
|
| 312 |
+
name = reference_variable.path.replace("/", "_") + "_" + name
|
| 313 |
+
else:
|
| 314 |
+
name = (
|
| 315 |
+
str(reference_variable.name).replace("/", "_").replace(":", "_")
|
| 316 |
+
+ "_"
|
| 317 |
+
+ name
|
| 318 |
+
)
|
| 319 |
+
return self.add_variable(
|
| 320 |
+
shape=reference_variable.shape,
|
| 321 |
+
initializer=initializer,
|
| 322 |
+
dtype=reference_variable.dtype,
|
| 323 |
+
name=name,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
def _check_variables_are_known(self, variables):
|
| 327 |
+
for v in variables:
|
| 328 |
+
if self._var_key(v) not in self._trainable_variables_indices:
|
| 329 |
+
raise ValueError(
|
| 330 |
+
f"Unknown variable: {v}. This optimizer can only "
|
| 331 |
+
"be called for the variables it was originally built with. "
|
| 332 |
+
"When working with a new set of variables, you should "
|
| 333 |
+
"recreate a new optimizer instance."
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
def assign(self, variable, value):
|
| 337 |
+
"""Assign a value to a variable.
|
| 338 |
+
|
| 339 |
+
This should be used in optimizers instead of `variable.assign(value)` to
|
| 340 |
+
support backend specific optimizations.
|
| 341 |
+
Note that the variable can be a model variable or an optimizer variable;
|
| 342 |
+
it can be a backend native variable or a Keras variable.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
variable: The variable to update.
|
| 346 |
+
value: The value to add to the variable.
|
| 347 |
+
"""
|
| 348 |
+
variable.assign(value)
|
| 349 |
+
|
| 350 |
+
def assign_add(self, variable, value):
|
| 351 |
+
"""Add a value to a variable.
|
| 352 |
+
|
| 353 |
+
This should be used in optimizers instead of
|
| 354 |
+
`variable.assign_add(value)` to support backend specific optimizations.
|
| 355 |
+
Note that the variable can be a model variable or an optimizer variable;
|
| 356 |
+
it can be a backend native variable or a Keras variable.
|
| 357 |
+
|
| 358 |
+
Args:
|
| 359 |
+
variable: The variable to update.
|
| 360 |
+
value: The value to add to the variable.
|
| 361 |
+
"""
|
| 362 |
+
variable.assign_add(value)
|
| 363 |
+
|
| 364 |
+
def assign_sub(self, variable, value):
|
| 365 |
+
"""Subtract a value from a variable.
|
| 366 |
+
|
| 367 |
+
This should be used in optimizers instead of
|
| 368 |
+
`variable.assign_sub(value)` to support backend specific optimizations.
|
| 369 |
+
Note that the variable can be a model variable or an optimizer variable;
|
| 370 |
+
it can be a backend native variable or a Keras variable.
|
| 371 |
+
|
| 372 |
+
Args:
|
| 373 |
+
variable: The variable to update.
|
| 374 |
+
value: The value to add to the variable.
|
| 375 |
+
"""
|
| 376 |
+
variable.assign_sub(value)
|
| 377 |
+
|
| 378 |
+
def update_step(self, gradient, variable, learning_rate):
|
| 379 |
+
raise NotImplementedError
|
| 380 |
+
|
| 381 |
+
def apply_gradients(self, grads_and_vars):
|
| 382 |
+
grads, trainable_variables = zip(*grads_and_vars)
|
| 383 |
+
self.apply(grads, trainable_variables)
|
| 384 |
+
# Return iterations for compat with tf.keras.
|
| 385 |
+
return self._iterations
|
| 386 |
+
|
| 387 |
+
def apply(self, grads, trainable_variables=None):
|
| 388 |
+
"""Update traininable variables according to provided gradient values.
|
| 389 |
+
|
| 390 |
+
`grads` should be a list of gradient tensors
|
| 391 |
+
with 1:1 mapping to the list of variables the optimizer was built with.
|
| 392 |
+
|
| 393 |
+
`trainable_variables` can be provided
|
| 394 |
+
on the first call to build the optimizer.
|
| 395 |
+
"""
|
| 396 |
+
if len(grads) == 0:
|
| 397 |
+
# It is possible that the grad is empty. In this case,
|
| 398 |
+
# `apply_gradients` is a no-op.
|
| 399 |
+
return
|
| 400 |
+
|
| 401 |
+
if trainable_variables is None:
|
| 402 |
+
if not self.built:
|
| 403 |
+
raise ValueError(
|
| 404 |
+
"When passing `grads` without `variables`, the optimizer "
|
| 405 |
+
"must already be built on a list of variables. "
|
| 406 |
+
"Call `optimizer.build(trainable_variables)` first. "
|
| 407 |
+
)
|
| 408 |
+
if len(grads) != len(self._trainable_variables_indices):
|
| 409 |
+
raise ValueError(
|
| 410 |
+
"When passing `grads` as a list of gradient tensors, the "
|
| 411 |
+
f"gradients must match `optimizer.variables` one-to-on. "
|
| 412 |
+
f"Received a list of {len(grads)} gradients, but the "
|
| 413 |
+
f"optimizer is tracking {len(self._trainable_variables)} "
|
| 414 |
+
"trainable variables."
|
| 415 |
+
)
|
| 416 |
+
trainable_variables = self._trainable_variables
|
| 417 |
+
else:
|
| 418 |
+
trainable_variables = list(trainable_variables)
|
| 419 |
+
# Optionally build optimizer.
|
| 420 |
+
if not self.built:
|
| 421 |
+
with backend.name_scope(self.name, caller=self):
|
| 422 |
+
self.build(trainable_variables)
|
| 423 |
+
self.built = True
|
| 424 |
+
self._check_variables_are_known(trainable_variables)
|
| 425 |
+
|
| 426 |
+
with backend.name_scope(self.name, caller=self):
|
| 427 |
+
# Overwrite targeted variables directly with their gradients if
|
| 428 |
+
# their `overwrite_with_gradient` is set.
|
| 429 |
+
grads, trainable_variables = (
|
| 430 |
+
self._overwrite_variables_directly_with_gradients(
|
| 431 |
+
grads, trainable_variables
|
| 432 |
+
)
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
# Filter empty gradients.
|
| 436 |
+
grads, trainable_variables = self._filter_empty_gradients(
|
| 437 |
+
grads, trainable_variables
|
| 438 |
+
)
|
| 439 |
+
if len(list(grads)) == 0:
|
| 440 |
+
return
|
| 441 |
+
|
| 442 |
+
# Unscale gradients.
|
| 443 |
+
scale = self.loss_scale_factor
|
| 444 |
+
if scale is not None:
|
| 445 |
+
grads = [g if g is None else g / scale for g in grads]
|
| 446 |
+
|
| 447 |
+
# Apply gradient updates.
|
| 448 |
+
self._backend_apply_gradients(grads, trainable_variables)
|
| 449 |
+
# Apply variable constraints after applying gradients.
|
| 450 |
+
for variable in trainable_variables:
|
| 451 |
+
if variable.constraint is not None:
|
| 452 |
+
variable.assign(variable.constraint(variable))
|
| 453 |
+
|
| 454 |
+
def _backend_apply_gradients(self, grads, trainable_variables):
    """Apply method that can be overridden by different backends.

    JAX overrides it in order to deal with statelessness in gradient
    accumulation and EMA handling.

    The below implementation is intended to be generally backend-agnostic,
    but may not work with all backends.

    This method does 4 things:
    - Call the optimizer's update_step() to update trainable variables
      and optimizer variables.
    - Update EMA variables, if EMA is configured.
    - Update gradient accumulators, if gradient accumulation is configured.
    - Update the iteration counter.
    """
    if self.gradient_accumulation_steps:
        # Accumulation mode: only every `gradient_accumulation_steps`-th
        # call applies a real update; all other calls just fold the new
        # gradients into the accumulators.
        is_update_step = (
            self._iterations + 1
        ) % self.gradient_accumulation_steps == 0
        # `trainable_variables` might have been filtered in previous
        # processing steps, so we need to ensure the correct mapping between
        # `self._accumulated_gradients` and `trainable_variables`
        acc_grads = [
            self._accumulated_gradients[self._get_variable_index(v)]
            for v in trainable_variables
        ]

        def _update_step_fn(grads, trainable_variables):
            # Run update step with accumulated grads + reset accumulators
            steps = self.gradient_accumulation_steps
            # Average the fresh gradient together with what was
            # accumulated over the previous `steps - 1` calls.
            grads = [
                (g + acc_g) / steps for g, acc_g in zip(grads, acc_grads)
            ]

            # Apply clipping and weight decay.
            grads = self._clip_gradients(grads)
            self._apply_weight_decay(trainable_variables)

            self._backend_update_step(
                grads, trainable_variables, self.learning_rate
            )
            self._backend_reset_gradient_accumulators()

        # Either perform the real update or just accumulate, depending on
        # the step index.
        ops.cond(
            is_update_step,
            lambda: _update_step_fn(grads, trainable_variables),
            lambda: self._backend_increment_gradient_accumulators(
                grads, acc_grads
            ),
        )
    else:
        # Apply clipping and weight decay.
        grads = self._clip_gradients(grads)
        self._apply_weight_decay(trainable_variables)

        # Run update step.
        self._backend_update_step(
            grads, trainable_variables, self.learning_rate
        )

    if self.use_ema:
        self._update_model_variables_moving_average(
            self._trainable_variables
        )
        if self.ema_overwrite_frequency:
            # Only when self.ema_overwrite_frequency is not None, we
            # overwrite the model variables.
            should_overwrite_model_vars = (
                self.iterations + 1
            ) % self.ema_overwrite_frequency == 0
            ops.cond(
                should_overwrite_model_vars,
                lambda: self._overwrite_model_variables_with_average_value(
                    self._trainable_variables
                ),
                lambda: None,
            )
    # Update iteration counter.
    self._iterations.assign_add(1)
|
| 534 |
+
|
| 535 |
+
def _backend_update_step(self, grads, trainable_variables, learning_rate):
|
| 536 |
+
"""Collective update_step that can be overridden by the backend.
|
| 537 |
+
|
| 538 |
+
It is overridden by torch for performance reasons, and
|
| 539 |
+
by TF to support tf.distribute.
|
| 540 |
+
"""
|
| 541 |
+
for grad, var in zip(grads, trainable_variables):
|
| 542 |
+
self.update_step(grad, var, learning_rate)
|
| 543 |
+
|
| 544 |
+
def _backend_reset_gradient_accumulators(self):
    """Zero out every gradient-accumulator variable."""
    for accumulator in self._accumulated_gradients:
        zeros = ops.zeros(accumulator.shape, dtype=accumulator.dtype)
        accumulator.assign(zeros)
|
| 547 |
+
|
| 548 |
+
def _backend_increment_gradient_accumulators(self, grads, acc_grads):
|
| 549 |
+
new_g_accs = [(g + acc_g) for g, acc_g in zip(grads, acc_grads)]
|
| 550 |
+
for n_g_acc, g_acc in zip(new_g_accs, acc_grads):
|
| 551 |
+
g_acc.assign(n_g_acc)
|
| 552 |
+
|
| 553 |
+
def stateless_apply(self, optimizer_variables, grads, trainable_variables):
    """Stateless version of `apply`: returns updated values, mutates nothing.

    Runs `apply` inside a `backend.StatelessScope` so that no variable is
    assigned in place; the updated values of both the trainable variables
    and the optimizer variables are returned to the caller instead.

    Args:
        optimizer_variables: list of tensors corresponding 1:1 to
            `self.variables`, holding the optimizer's current state.
        grads: list of gradient tensors for the trainable variables.
        trainable_variables: list of tensors corresponding 1:1 to the
            trainable variables the optimizer was built with.

    Returns:
        A `(trainable_variables, optimizer_variables)` tuple of updated
        values.

    Raises:
        ValueError: if the optimizer is not built, or if either input
            list does not match the optimizer's tracked variables.
    """
    self._check_super_called()

    if not self.built:
        raise ValueError(
            f"To call `stateless_apply`, {self.__class__.__name__} "
            "must be built (i.e. its variables must have been created). "
            "You can build it via `optimizer.build(trainable_variables)`."
        )
    if len(optimizer_variables) != len(self.variables):
        raise ValueError(
            "Argument `optimizer_variables` must be a list of tensors "
            f"corresponding 1:1 to {self.__class__.__name__}().variables. "
            f"Received list with length {len(optimizer_variables)}, but "
            f"expected {len(self.variables)} variables."
        )
    if len(trainable_variables) != len(self._trainable_variables):
        # Bug fix: this message previously named `optimizer_variables`,
        # but the check (and the advice) is about `trainable_variables`.
        raise ValueError(
            "Argument `trainable_variables` must be a list of tensors "
            "corresponding 1:1 to the trainable variables list that "
            "the optimizer was built with. Received "
            f"len(trainable_variables) == {len(trainable_variables)} "
            "whereas the optimizer was built with "
            f"{len(self._trainable_variables)} variables."
        )

    # Gather variable mapping
    mapping = list(
        zip(self._trainable_variables, trainable_variables)
    ) + list(zip(self.variables, optimizer_variables))

    # Call in stateless scope
    with backend.StatelessScope(state_mapping=mapping) as scope:
        self.apply(grads)

    # Gather updated variables: fall back to the input value when the
    # scope recorded no new value for a variable.
    trainable_variables = []
    for v in self._trainable_variables:
        new_v = scope.get_current_value(v)
        if new_v is not None:
            trainable_variables.append(new_v)
        else:
            trainable_variables.append(v)
    optimizer_variables = []
    for v in self.variables:
        new_v = scope.get_current_value(v)
        if new_v is not None:
            optimizer_variables.append(new_v)
        else:
            optimizer_variables.append(v)
    return trainable_variables, optimizer_variables
|
| 604 |
+
|
| 605 |
+
def scale_loss(self, loss):
    """Scale the loss before computing gradients.

    Scales the loss before gradients are computed in a `train_step`. This
    is primarily useful during mixed precision training to prevent numeric
    underflow.
    """
    factor = self.loss_scale_factor
    return loss if factor is None else loss * factor
|
| 615 |
+
|
| 616 |
+
@property
def learning_rate(self):
    # Resolve the configured learning rate (plain value, schedule, or
    # callable) to its current concrete value.
    return self._get_current_learning_rate()
|
| 619 |
+
|
| 620 |
+
@learning_rate.setter
def learning_rate(self, learning_rate):
    """Set the learning rate.

    Accepts a `LearningRateSchedule`, a callable, or a plain numeric
    value. Assigning a numeric value is only valid when the optimizer
    was not created with a schedule.
    """
    if isinstance(self._learning_rate, backend.Variable):
        # Remember the existing LR variable so it can be untracked if it
        # gets replaced by a schedule/callable below.
        prev_lr_var = self._learning_rate
    else:
        prev_lr_var = None
    if isinstance(
        learning_rate, learning_rate_schedule.LearningRateSchedule
    ):
        self._learning_rate = learning_rate
    elif callable(learning_rate):
        self._learning_rate = learning_rate
    else:
        if isinstance(
            self._learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
            raise TypeError(
                "This optimizer was created with a `LearningRateSchedule`"
                " object as its `learning_rate` constructor argument, "
                "hence its learning rate is not settable. If you need the"
                " learning rate to be settable, you should instantiate "
                "the optimizer with a float `learning_rate` argument."
            )
        # Numeric value: update the existing LR variable in place.
        self._learning_rate.assign(learning_rate)
    if prev_lr_var is not None and not isinstance(
        self._learning_rate, backend.Variable
    ):
        # Untrack learning rate variable
        self._untrack_variable(prev_lr_var)
|
| 649 |
+
|
| 650 |
+
def set_weights(self, weights):
    """Set the weights of the optimizer."""
    if not self.built:
        raise ValueError(
            "You are calling `set_weights()` on an optimizer that has not "
            "yet been built. Please call "
            "`optimizer.build(trainable_variables)` to create the "
            "optimizer weights before calling `set_weights()`."
        )
    for target, source in zip(self._variables, weights):
        if target.shape == source.shape:
            target.assign(source)
            continue
        raise ValueError(
            f"Optimizer variable {self._var_key(target)} has shape "
            f"{str(target.shape)} not compatible with provided "
            f"weight shape {str(source.shape)}."
        )
|
| 667 |
+
|
| 668 |
+
def save_own_variables(self, store):
    """Get the state of this optimizer object."""
    # Keys are stringified indices into `self.variables`.
    for index, variable in enumerate(self.variables):
        store[str(index)] = variable.numpy()
|
| 672 |
+
|
| 673 |
+
def load_own_variables(self, store):
    """Set the state of this optimizer object."""
    num_vars = len(self.variables)
    if len(store.keys()) != num_vars:
        # Mismatched checkpoint: warn and skip loading entirely.
        msg = (
            f"Skipping variable loading for optimizer '{self.name}', "
            f"because it has {len(self.variables)} variables whereas "
            f"the saved optimizer has {len(store.keys())} variables. "
        )
        if num_vars == 0:
            msg += (
                "This is likely because the optimizer has not been "
                "called/built yet."
            )
        warnings.warn(msg, stacklevel=2)
        return
    for index, variable in enumerate(self.variables):
        variable.assign(store[str(index)])
|
| 690 |
+
|
| 691 |
+
def _get_current_learning_rate(self):
    """Resolve `_learning_rate` (schedule/callable/value) to a value."""
    lr = self._learning_rate
    if isinstance(lr, learning_rate_schedule.LearningRateSchedule):
        # Schedules are evaluated at the current iteration count.
        return lr(self._iterations)
    if callable(lr):
        return lr()
    return lr
|
| 699 |
+
|
| 700 |
+
def _overwrite_variables_directly_with_gradients(self, grads, vars):
    """Overwrite the variables directly by their gradients.

    This method is designed for a special case where we want to overwrite
    the variable directly with its computed gradient. For example, in float8
    training, new `scale` and `amax_history` are computed as gradients, and
    we want to overwrite them directly instead of following the typical
    procedure such as gradient descent with a learning rate, gradient
    clipping and weight decaying.

    After the update, the processed pairs will be filtered out.
    """
    # Shortcut for `tf.Variable` because it doesn't have a
    # `overwrite_with_gradient` attr
    if any(not hasattr(v, "overwrite_with_gradient") for v in vars):
        return grads, vars

    # Shallow copies
    filtered_grads = list(grads)
    filtered_vars = list(vars)

    # Iterate from right to left for safe popping
    for i in range(len(filtered_grads) - 1, -1, -1):
        g, v = filtered_grads[i], filtered_vars[i]
        if v.overwrite_with_gradient:
            if self.gradient_accumulation_steps:
                # Utilize a stateless manner for JAX compatibility
                steps = self.gradient_accumulation_steps
                is_update_step = (self._iterations + 1) % steps == 0
                acc_g = self._accumulated_gradients[
                    self._get_variable_index(v)
                ]
                # `ops.maximum` is utilized for gradient accumulation for
                # `overwrite_with_gradient=True` variables
                # NOTE: the lambdas below are consumed by `ops.cond` within
                # this same loop iteration, so capturing `g`/`acc_g` from
                # the loop is safe (no late-binding issue).
                new_g_acc = ops.cond(
                    is_update_step,
                    lambda: ops.zeros(g.shape, dtype=g.dtype),
                    lambda: ops.maximum(g, acc_g),
                )
                new_g = ops.cond(
                    is_update_step,
                    lambda: ops.maximum(g, acc_g),
                    lambda: g,
                )
                new_v = ops.cond(
                    is_update_step, lambda: new_g, lambda: v.value
                )
                v.assign(new_v)
                acc_g.assign(new_g_acc)
            else:
                # No accumulation: the gradient simply becomes the new
                # variable value.
                v.assign(g)
            # This pair is fully handled; drop it from the lists handed
            # back to the regular update path.
            filtered_grads.pop(i)
            filtered_vars.pop(i)
    return filtered_grads, filtered_vars
|
| 754 |
+
|
| 755 |
+
def _filter_empty_gradients(self, grads, vars):
|
| 756 |
+
filtered_grads = list(grads)
|
| 757 |
+
filtered_vars = list(vars)
|
| 758 |
+
missing_grad_vars = []
|
| 759 |
+
|
| 760 |
+
# Iterate from right to left for safe popping
|
| 761 |
+
for i in range(len(filtered_grads) - 1, -1, -1):
|
| 762 |
+
if filtered_grads[i] is None:
|
| 763 |
+
filtered_grads.pop(i)
|
| 764 |
+
v = filtered_vars.pop(i)
|
| 765 |
+
try:
|
| 766 |
+
missing_grad_vars.append(v.path)
|
| 767 |
+
except AttributeError:
|
| 768 |
+
# `tf.Variable` doesn't have `path` attr.
|
| 769 |
+
missing_grad_vars.append(v.name)
|
| 770 |
+
|
| 771 |
+
if not filtered_grads:
|
| 772 |
+
raise ValueError("No gradients provided for any variable.")
|
| 773 |
+
if missing_grad_vars:
|
| 774 |
+
warnings.warn(
|
| 775 |
+
"Gradients do not exist for variables "
|
| 776 |
+
f"{list(reversed(missing_grad_vars))} when minimizing the loss."
|
| 777 |
+
" If using `model.compile()`, did you forget to provide a "
|
| 778 |
+
"`loss` argument?"
|
| 779 |
+
)
|
| 780 |
+
return filtered_grads, filtered_vars
|
| 781 |
+
|
| 782 |
+
def _clip_gradients(self, grads):
|
| 783 |
+
if self.clipnorm and self.clipnorm > 0:
|
| 784 |
+
return [
|
| 785 |
+
self._clip_by_norm(g) if g is not None else g for g in grads
|
| 786 |
+
]
|
| 787 |
+
elif self.global_clipnorm and self.global_clipnorm > 0:
|
| 788 |
+
return clip_by_global_norm(grads, self.global_clipnorm)
|
| 789 |
+
elif self.clipvalue and self.clipvalue > 0:
|
| 790 |
+
v = self.clipvalue
|
| 791 |
+
return [ops.clip(g, -v, v) if g is not None else g for g in grads]
|
| 792 |
+
else:
|
| 793 |
+
return grads
|
| 794 |
+
|
| 795 |
+
def exclude_from_weight_decay(self, var_list=None, var_names=None):
    """Exclude variables from weight decay.

    This method must be called before the optimizer's `build` method is
    called. You can set specific variables to exclude out, or set a list of
    strings as the anchor words, if any of which appear in a variable's
    name, then the variable is excluded.

    Args:
        var_list: A list of `Variable`s to exclude from weight decay.
        var_names: A list of strings. If any string in `var_names` appear
            in the model variable's name, then this model variable is
            excluded from weight decay. For example, `var_names=['bias']`
            excludes all bias variables from weight decay.
    """
    if getattr(self, "_built", False):
        raise ValueError(
            "`exclude_from_weight_decay()` can only be configured before "
            "the optimizer is built."
        )

    # A `set` of variable keys gives O(1) membership checks later.
    self._exclude_from_weight_decay = (
        set(self._var_key(variable) for variable in var_list)
        if var_list
        else set()
    )

    # Precompile a single alternation pattern from `var_names`.
    if var_names and len(var_names) > 0:
        self._exclude_from_weight_decay_pattern = re.compile(
            "|".join(set(var_names))
        )
    else:
        self._exclude_from_weight_decay_pattern = None

    # Reset cache
    self._exclude_from_weight_decay_cache = dict()
|
| 834 |
+
|
| 835 |
+
def _use_weight_decay(self, variable):
|
| 836 |
+
variable_id = self._var_key(variable)
|
| 837 |
+
|
| 838 |
+
# Immediately return the value if `variable_id` hits the cache
|
| 839 |
+
if not hasattr(self, "_exclude_from_weight_decay_cache"):
|
| 840 |
+
self._exclude_from_weight_decay_cache = dict()
|
| 841 |
+
if variable_id in self._exclude_from_weight_decay_cache:
|
| 842 |
+
return self._exclude_from_weight_decay_cache[variable_id]
|
| 843 |
+
|
| 844 |
+
# Determine whether the variable should apply weight decay or not
|
| 845 |
+
exclude_from_weight_decay = getattr(
|
| 846 |
+
self, "_exclude_from_weight_decay", set()
|
| 847 |
+
)
|
| 848 |
+
exclude_from_weight_decay_pattern = getattr(
|
| 849 |
+
self, "_exclude_from_weight_decay_pattern", None
|
| 850 |
+
)
|
| 851 |
+
if variable_id in exclude_from_weight_decay:
|
| 852 |
+
self._exclude_from_weight_decay_cache[variable_id] = False
|
| 853 |
+
return False
|
| 854 |
+
if exclude_from_weight_decay_pattern is not None:
|
| 855 |
+
if (
|
| 856 |
+
re.search(exclude_from_weight_decay_pattern, variable.name)
|
| 857 |
+
is not None
|
| 858 |
+
):
|
| 859 |
+
self._exclude_from_weight_decay_cache[variable_id] = False
|
| 860 |
+
return False
|
| 861 |
+
self._exclude_from_weight_decay_cache[variable_id] = True
|
| 862 |
+
return True
|
| 863 |
+
|
| 864 |
+
def _apply_weight_decay(self, variables):
    """Apply decoupled weight decay in place to eligible variables."""
    if self.weight_decay is None:
        return
    for variable in variables:
        if not self._use_weight_decay(variable):
            continue
        lr = ops.cast(self.learning_rate, variable.dtype)
        wd = ops.cast(self.weight_decay, variable.dtype)
        variable.assign(variable - variable * wd * lr)
|
| 872 |
+
|
| 873 |
+
def _check_super_called(self):
|
| 874 |
+
if not hasattr(self, "_lock"):
|
| 875 |
+
raise RuntimeError(
|
| 876 |
+
f"In optimizer '{self.__class__.__name__}', you forgot to call "
|
| 877 |
+
"`super().__init__()` as the first statement "
|
| 878 |
+
"in the `__init__()` method. "
|
| 879 |
+
"Go add it!"
|
| 880 |
+
)
|
| 881 |
+
|
| 882 |
+
def _update_model_variables_moving_average(self, trainable_variables):
    """Update the stored moving average using the latest value."""
    if not self.use_ema:
        return
    for var, average in zip(
        trainable_variables, self._model_variables_moving_average
    ):
        # On the very first iteration the effective momentum is 0, so
        # the average is seeded with the raw variable value.
        not_first_step = ops.not_equal(self.iterations, 0)
        momentum = ops.cast(not_first_step, var.dtype) * self.ema_momentum
        average.assign(momentum * average + (1 - momentum) * var)
|
| 893 |
+
|
| 894 |
+
def _overwrite_model_variables_with_average_value(
|
| 895 |
+
self, trainable_variables
|
| 896 |
+
):
|
| 897 |
+
"""Overwrite model variables with its moving average."""
|
| 898 |
+
if len(trainable_variables) != len(
|
| 899 |
+
self._model_variables_moving_average
|
| 900 |
+
):
|
| 901 |
+
raise ValueError(
|
| 902 |
+
f"The length of model variables ({len(trainable_variables)}) "
|
| 903 |
+
"to override does not match the length of model variables "
|
| 904 |
+
"stored in the optimizer "
|
| 905 |
+
f"({len(self._model_variables_moving_average)}). Please "
|
| 906 |
+
"check if the optimizer was called on your model."
|
| 907 |
+
)
|
| 908 |
+
for var, average_var in zip(
|
| 909 |
+
trainable_variables, self._model_variables_moving_average
|
| 910 |
+
):
|
| 911 |
+
var.assign(average_var)
|
| 912 |
+
|
| 913 |
+
def finalize_variable_values(self, var_list):
    """Set the final value of model's trainable variables.

    Sometimes there are some extra steps before ending the variable updates,
    such as overriding the model variables with its average value.

    Args:
        var_list: list of model variables.
    """
    if not self.use_ema:
        return
    # With EMA enabled, finalize by replacing each model variable with
    # the moving average tracked inside the optimizer.
    self._overwrite_model_variables_with_average_value(var_list)
|
| 927 |
+
|
| 928 |
+
def _obj_type(self):
|
| 929 |
+
return "Optimizer"
|
| 930 |
+
|
| 931 |
+
def get_config(self):
    """Returns the config of the optimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Subclass optimizer should override this method to include other
    hyperparameters.

    Returns:
        Python dictionary.
    """

    # Serialize the learning rate according to its runtime type:
    # schedule object, backend variable, tensor, callable, or plain value.
    if isinstance(
        self._learning_rate, learning_rate_schedule.LearningRateSchedule
    ):
        learning_rate = learning_rate_schedule.serialize(
            self._learning_rate
        )
    elif isinstance(self._learning_rate, backend.Variable):
        learning_rate = float(self._learning_rate.numpy())
    elif ops.is_tensor(self._learning_rate):
        learning_rate = float(self._learning_rate)
    elif callable(self._learning_rate):
        learning_rate = serialization_lib.serialize_keras_object(
            self._learning_rate
        )
    else:
        # Fallback for unrecognized learning-rate types.
        # NOTE(review): this silently substitutes a 0.5 default rather
        # than raising — confirm this is the intended behavior.
        learning_rate = 0.5

    config = {
        "name": self.name,
        "learning_rate": learning_rate,
        "weight_decay": self.weight_decay,
        "clipnorm": self.clipnorm,
        "global_clipnorm": self.global_clipnorm,
        "clipvalue": self.clipvalue,
        "use_ema": self.use_ema,
        "ema_momentum": self.ema_momentum,
        "ema_overwrite_frequency": self.ema_overwrite_frequency,
        "loss_scale_factor": self.loss_scale_factor,
        "gradient_accumulation_steps": self.gradient_accumulation_steps,
    }
    return config
|
| 977 |
+
|
| 978 |
+
@classmethod
def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`, capable of instantiating the
    same optimizer from the config dictionary.

    Args:
        config: A Python dictionary, typically the output of get_config.
        custom_objects: A Python dictionary mapping names to additional
            user-defined Python objects needed to recreate this optimizer.

    Returns:
        An optimizer instance.
    """
    lr = config.get("learning_rate")
    if isinstance(lr, dict):
        # A serialized schedule/callable: deserialize it back first.
        config["learning_rate"] = (
            serialization_lib.deserialize_keras_object(
                lr, custom_objects=custom_objects
            )
        )
    return cls(**config)
|
| 1001 |
+
|
| 1002 |
+
def __setattr__(self, name, value):
    # Prevent users from attaching state to the
    # layer before `super()` is called -- since that
    # state would silently not be tracked.
    if name != "_lock":
        self._check_super_called()
    # Track Variables: route every attribute assignment through the
    # tracker (once it exists) so Variables get registered.
    if hasattr(self, "_tracker"):
        value = self._tracker.track(value)
    return super().__setattr__(name, value)
|
| 1012 |
+
|
| 1013 |
+
def _clip_by_norm(self, values, axes=None):
    """Clip `values` so its L2 norm does not exceed `self.clipnorm`.

    Args:
        values: tensor to clip.
        axes: axes over which to compute the norm; `None` means all.

    Returns:
        The (possibly rescaled) tensor.
    """
    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
    l2sum = ops.sum(ops.square(values), axes, keepdims=True)
    pred = l2sum > 0
    # Two-tap tf.where trick to bypass NaN gradients
    l2sum_safe = ops.where(pred, l2sum, ops.ones_like(l2sum))
    l2norm = ops.where(pred, ops.sqrt(l2sum_safe), l2sum)
    # values * clipnorm / max(l2norm, clipnorm): a no-op when the norm is
    # already within bounds, a proportional rescale otherwise.
    intermediate = ops.multiply(values, self.clipnorm)
    values_clip = ops.convert_to_tensor(intermediate) / ops.maximum(
        l2norm, self.clipnorm
    )
    return values_clip
|
| 1025 |
+
|
| 1026 |
+
def _untrack_variable(self, variable):
|
| 1027 |
+
previous_lock_state = self._tracker.locked
|
| 1028 |
+
self._tracker.unlock()
|
| 1029 |
+
self._tracker.untrack(variable)
|
| 1030 |
+
if previous_lock_state is True:
|
| 1031 |
+
self._tracker.lock()
|
| 1032 |
+
|
| 1033 |
+
|
| 1034 |
+
base_optimizer_keyword_args = """name: String. The name to use
|
| 1035 |
+
for momentum accumulator weights created by
|
| 1036 |
+
the optimizer.
|
| 1037 |
+
weight_decay: Float. If set, weight decay is applied.
|
| 1038 |
+
clipnorm: Float. If set, the gradient of each weight is individually
|
| 1039 |
+
clipped so that its norm is no higher than this value.
|
| 1040 |
+
clipvalue: Float. If set, the gradient of each weight is clipped to be
|
| 1041 |
+
no higher than this value.
|
| 1042 |
+
global_clipnorm: Float. If set, the gradient of all weights is clipped
|
| 1043 |
+
so that their global norm is no higher than this value.
|
| 1044 |
+
use_ema: Boolean, defaults to `False`.
|
| 1045 |
+
If `True`, exponential moving average
|
| 1046 |
+
(EMA) is applied. EMA consists of computing an exponential moving
|
| 1047 |
+
average of the weights of the model (as the weight values change
|
| 1048 |
+
after each training batch), and periodically overwriting the
|
| 1049 |
+
weights with their moving average.
|
| 1050 |
+
ema_momentum: Float, defaults to 0.99. Only used if `use_ema=True`.
|
| 1051 |
+
This is the momentum to use when computing
|
| 1052 |
+
the EMA of the model's weights:
|
| 1053 |
+
`new_average = ema_momentum * old_average + (1 - ema_momentum) *
|
| 1054 |
+
current_variable_value`.
|
| 1055 |
+
ema_overwrite_frequency: Int or None, defaults to None. Only used if
|
| 1056 |
+
`use_ema=True`. Every `ema_overwrite_frequency` steps of iterations,
|
| 1057 |
+
we overwrite the model variable by its moving average.
|
| 1058 |
+
If None, the optimizer
|
| 1059 |
+
does not overwrite model variables in the middle of training,
|
| 1060 |
+
and you need to explicitly overwrite the variables
|
| 1061 |
+
at the end of training by calling
|
| 1062 |
+
`optimizer.finalize_variable_values()` (which updates the model
|
| 1063 |
+
variables in-place). When using the built-in `fit()` training loop,
|
| 1064 |
+
this happens automatically after the last epoch,
|
| 1065 |
+
and you don't need to do anything.
|
| 1066 |
+
loss_scale_factor: Float or `None`. If a float, the scale factor will
|
| 1067 |
+
be multiplied the loss before computing gradients, and the inverse
|
| 1068 |
+
of the scale factor will be multiplied by the gradients before
|
| 1069 |
+
updating variables. Useful for preventing underflow during
|
| 1070 |
+
mixed precision training. Alternately,
|
| 1071 |
+
`keras.optimizers.LossScaleOptimizer` will
|
| 1072 |
+
automatically set a loss scale factor.
|
| 1073 |
+
gradient_accumulation_steps: Int or `None`. If an int, model & optimizer
|
| 1074 |
+
variables will not be updated at every step; instead they will be
|
| 1075 |
+
updated every `gradient_accumulation_steps` steps, using the average
|
| 1076 |
+
value of the gradients since the last update. This is known as
|
| 1077 |
+
"gradient accumulation". This can be useful
|
| 1078 |
+
when your batch size is very small, in order to reduce gradient
|
| 1079 |
+
noise at each update step. EMA frequency will look at "accumulated"
|
| 1080 |
+
iterations value (optimizer steps // gradient_accumulation_steps).
|
| 1081 |
+
Learning rate schedules will look at "real" iterations value
|
| 1082 |
+
(optimizer steps).
|
| 1083 |
+
"""
|
| 1084 |
+
|
| 1085 |
+
|
| 1086 |
+
def global_norm(value_list):
    """Computes the global norm of multiple tensors."""
    # Sum of squares over every non-None tensor, then a single sqrt.
    squared = [
        ops.sum(ops.square(value))
        for value in value_list
        if value is not None
    ]
    return ops.sqrt(ops.sum(ops.stack(squared)))
|
| 1093 |
+
|
| 1094 |
+
|
| 1095 |
+
def clip_by_global_norm(value_list, clip_norm):
    """Scale `value_list` so its joint L2 norm is at most `clip_norm`."""
    use_norm = global_norm(value_list)
    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
    scale_for_finite = clip_norm * ops.minimum(
        1.0 / use_norm, 1.0 / clip_norm
    )
    # If use_norm is any finite number, this is a no-op. For inf/-inf/NaN,
    # this will make scale NaN.
    scale = scale_for_finite + (use_norm - use_norm)
    return [None if v is None else v * scale for v in value_list]
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/ftrl.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import initializers
|
| 2 |
+
from keras.src import ops
|
| 3 |
+
from keras.src.api_export import keras_export
|
| 4 |
+
from keras.src.optimizers import optimizer
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@keras_export(["keras.optimizers.Ftrl"])
class Ftrl(optimizer.Optimizer):
    r"""Optimizer that implements the FTRL algorithm.

    "Follow The Regularized Leader" (FTRL) is an optimization algorithm
    developed at Google for click-through rate prediction in the early 2010s. It
    is most suitable for shallow models with large and sparse feature spaces.
    The algorithm is described by
    [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).
    The Keras version has support for both online L2 regularization
    (the L2 regularization described in the paper
    above) and shrinkage-type L2 regularization
    (which is the addition of an L2 penalty to the loss function).

    Initialization:

    ```python
    n = 0
    sigma = 0
    z = 0
    ```

    Update rule for one variable `w`:

    ```python
    prev_n = n
    n = n + g ** 2
    sigma = (n ** -lr_power - prev_n ** -lr_power) / lr
    z = z + g - sigma * w
    if abs(z) < lambda_1:
        w = 0
    else:
        w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / alpha + lambda_2)
    ```

    Notation:

    - `lr` is the learning rate
    - `g` is the gradient for the variable
    - `lambda_1` is the L1 regularization strength
    - `lambda_2` is the L2 regularization strength
    - `lr_power` is the power to scale n.

    Check the documentation for the `l2_shrinkage_regularization_strength`
    parameter for more details when shrinkage is enabled, in which case gradient
    is replaced with a gradient with shrinkage.

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        learning_rate_power: A float value, must be less or equal to zero.
            Controls how the learning rate decreases during training. Use zero
            for a fixed learning rate.
        initial_accumulator_value: The starting value for accumulators. Only
            zero or positive values are allowed.
        l1_regularization_strength: A float value, must be greater than or equal
            to zero. Defaults to `0.0`.
        l2_regularization_strength: A float value, must be greater than or equal
            to zero. Defaults to `0.0`.
        l2_shrinkage_regularization_strength: A float value, must be greater
            than or equal to zero. This differs from L2 above in that the L2
            above is a stabilization penalty, whereas this L2 shrinkage is a
            magnitude penalty. When input is sparse shrinkage will only happen
            on the active weights.
        beta: A float value, representing the beta value from the paper.
            Defaults to `0.0`.
        {{base_optimizer_keyword_args}}
    """

    def __init__(
        self,
        learning_rate=0.001,
        learning_rate_power=-0.5,
        initial_accumulator_value=0.1,
        l1_regularization_strength=0.0,
        l2_regularization_strength=0.0,
        l2_shrinkage_regularization_strength=0.0,
        beta=0.0,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="ftrl",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )

        # Validate hyperparameters eagerly so misconfiguration fails at
        # construction time rather than mid-training.
        if initial_accumulator_value < 0.0:
            raise ValueError(
                "`initial_accumulator_value` needs to be positive or zero. "
                "Received: initial_accumulator_value="
                f"{initial_accumulator_value}."
            )
        if learning_rate_power > 0.0:
            raise ValueError(
                "`learning_rate_power` needs to be negative or zero. Received: "
                f"learning_rate_power={learning_rate_power}."
            )
        if l1_regularization_strength < 0.0:
            raise ValueError(
                "`l1_regularization_strength` needs to be positive or zero. "
                "Received: l1_regularization_strength="
                f"{l1_regularization_strength}."
            )
        if l2_regularization_strength < 0.0:
            raise ValueError(
                "`l2_regularization_strength` needs to be positive or zero. "
                "Received: l2_regularization_strength="
                f"{l2_regularization_strength}."
            )
        if l2_shrinkage_regularization_strength < 0.0:
            raise ValueError(
                "`l2_shrinkage_regularization_strength` needs to be positive "
                "or zero. Received: l2_shrinkage_regularization_strength"
                f"={l2_shrinkage_regularization_strength}."
            )

        self.learning_rate_power = learning_rate_power
        self.initial_accumulator_value = initial_accumulator_value
        self.l1_regularization_strength = l1_regularization_strength
        self.l2_regularization_strength = l2_regularization_strength
        self.l2_shrinkage_regularization_strength = (
            l2_shrinkage_regularization_strength
        )
        self.beta = beta

    def build(self, var_list):
        """Initialize optimizer variables.

        Creates one gradient-squared accumulator (`n` in the paper, started at
        `initial_accumulator_value`) and one linear slot (`z` in the paper,
        started at zero) per trainable variable.

        Args:
            var_list: list of model variables to build Ftrl variables on.
        """
        if self.built:
            return
        super().build(var_list)
        self._accumulators = []
        self._linears = []
        for var in var_list:
            # Accumulator uses `add_variable` (not the from-reference helper)
            # so it can be seeded with a non-zero constant initializer.
            self._accumulators.append(
                self.add_variable(
                    shape=var.shape,
                    dtype=var.dtype,
                    name="accumulator",
                    initializer=initializers.Constant(
                        self.initial_accumulator_value,
                    ),
                )
            )
            self._linears.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="linear"
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""

        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)

        accum = self._accumulators[self._get_variable_index(variable)]
        linear = self._linears[self._get_variable_index(variable)]

        lr_power = self.learning_rate_power
        # Effective online L2 strength: lambda_2 + beta / (2 * lr),
        # folding the paper's `beta` term into the quadratic denominator.
        l2_reg = self.l2_regularization_strength
        l2_reg = l2_reg + self.beta / (2.0 * lr)

        # Shrinkage-adjusted gradient: g + 2 * lambda_2_shrinkage * w.
        # Equals the raw gradient when shrinkage is disabled (strength 0).
        grad_to_use = ops.add(
            gradient,
            ops.multiply(
                2 * self.l2_shrinkage_regularization_strength, variable
            ),
        )
        # n_new = n + g^2 — note the accumulator uses the *raw* gradient,
        # not the shrinkage-adjusted one.
        new_accum = ops.add(accum, ops.square(gradient))
        # z += g_shrunk - sigma * w, where
        # sigma = (n_new^{-lr_power} - n^{-lr_power}) / lr.
        # This must read the *old* accumulator, so `accum` is committed last.
        self.assign_add(
            linear,
            ops.subtract(
                grad_to_use,
                ops.multiply(
                    ops.divide(
                        ops.subtract(
                            ops.power(new_accum, -lr_power),
                            ops.power(accum, -lr_power),
                        ),
                        lr,
                    ),
                    variable,
                ),
            ),
        )
        # Denominator of the closed-form solution:
        # n_new^{-lr_power} / lr + 2 * l2_reg.
        quadratic = ops.add(
            ops.divide(ops.power(new_accum, (-lr_power)), lr), 2 * l2_reg
        )
        linear_clipped = ops.clip(
            linear,
            -self.l1_regularization_strength,
            self.l1_regularization_strength,
        )
        # w = (clip(z, -l1, l1) - z) / quadratic: soft-thresholding — the
        # numerator (and hence w) is exactly 0 whenever |z| <= l1.
        self.assign(
            variable,
            ops.divide(ops.subtract(linear_clipped, linear), quadratic),
        )
        # Commit the accumulator only after `linear` consumed the old value.
        self.assign(accum, new_accum)

    def get_config(self):
        """Return the config dict needed to re-create this optimizer."""
        config = super().get_config()

        config.update(
            {
                "learning_rate_power": self.learning_rate_power,
                "initial_accumulator_value": self.initial_accumulator_value,
                "l1_regularization_strength": self.l1_regularization_strength,
                "l2_regularization_strength": self.l2_regularization_strength,
                "l2_shrinkage_regularization_strength": self.l2_shrinkage_regularization_strength,  # noqa: E501
                "beta": self.beta,
            }
        )
        return config


Ftrl.__doc__ = Ftrl.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/lamb.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import ops
|
| 2 |
+
from keras.src.api_export import keras_export
|
| 3 |
+
from keras.src.optimizers import optimizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@keras_export("keras.optimizers.Lamb")
class Lamb(optimizer.Optimizer):
    """Optimizer that implements the Lamb algorithm.

    Lamb is a stochastic gradient descent method that
    uses layer-wise adaptive moments to adjusts the
    learning rate for each parameter based on the ratio of the
    norm of the weight to the norm of the gradient
    This helps to stabilize the training process and improves convergence
    especially for large batch sizes.

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimates. Defaults to
            `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 2nd moment estimates. Defaults to
            `0.999`.
        epsilon: A small constant for numerical stability.
            Defaults to `1e-7`.
        {{base_optimizer_keyword_args}}

    References:
        - [Yang et al.](https://arxiv.org/pdf/1904.00962)
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="lamb",
        **kwargs,
    ):
        # All generic options are forwarded to the base optimizer; only the
        # Adam-style moment hyperparameters live on this class.
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon

    def build(self, var_list):
        """Create optimizer slot variables.

        Lamb keeps two slots per model variable: a first-moment estimate
        ("momentum") and a second-moment estimate ("velocity").

        Args:
            var_list: list of model variables to build Lamb variables on.
        """
        if self.built:
            return
        super().build(var_list)
        self._momentums = []
        self._velocities = []
        # Slots are created pairwise per variable so their creation order
        # (and thus checkpoint layout) follows `var_list`.
        for model_var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    reference_variable=model_var, name="momentum"
                )
            )
            self._velocities.append(
                self.add_variable_from_reference(
                    reference_variable=model_var, name="velocity"
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        """Apply a single Lamb update to `variable` for the given gradient."""
        step_size = ops.cast(learning_rate, variable.dtype)
        grad = ops.cast(gradient, variable.dtype)
        # Steps are 1-indexed for bias correction.
        step = ops.cast(self.iterations + 1, variable.dtype)

        beta_1_power = ops.power(
            ops.cast(self.beta_1, variable.dtype), step
        )
        beta_2_power = ops.power(
            ops.cast(self.beta_2, variable.dtype), step
        )

        slot_index = self._get_variable_index(variable)
        first_moment = self._momentums[slot_index]
        second_moment = self._velocities[slot_index]

        # m <- m + (1 - beta_1) * (g - m): exponential moving average of g.
        self.assign_add(
            first_moment,
            ops.multiply(ops.subtract(grad, first_moment), 1 - self.beta_1),
        )
        # v <- v + (1 - beta_2) * (g^2 - v): moving average of g squared.
        self.assign_add(
            second_moment,
            ops.multiply(
                ops.subtract(ops.square(grad), second_moment), 1 - self.beta_2
            ),
        )

        # Bias-corrected Adam-style step direction.
        m_hat = ops.divide(first_moment, (1.0 - beta_1_power))
        denom = ops.add(
            ops.sqrt(ops.divide(second_moment, (1.0 - beta_2_power))),
            self.epsilon,
        )
        adam_step = ops.divide(m_hat, denom)

        weight_norm = ops.sqrt(ops.sum(ops.power(variable, 2)))
        update_norm = ops.sqrt(ops.sum(ops.power(adam_step, 2)))

        # Layer-wise trust ratio: ||w|| / ||update|| when both norms are
        # strictly positive, otherwise fall back to 1.
        trust_ratio = ops.where(
            ops.greater(weight_norm, 0),
            ops.where(
                ops.greater(update_norm, 0), (weight_norm / update_norm), 1.0
            ),
            1.0,
        )

        self.assign_sub(variable, trust_ratio * step_size * adam_step)

    def get_config(self):
        """Return the serializable configuration of this optimizer."""
        base_config = super().get_config()
        base_config.update(
            {
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
                "epsilon": self.epsilon,
            }
        )
        return base_config


Lamb.__doc__ = Lamb.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/lion.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import ops
|
| 2 |
+
from keras.src.api_export import keras_export
|
| 3 |
+
from keras.src.optimizers import optimizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@keras_export(["keras.optimizers.Lion"])
class Lion(optimizer.Optimizer):
    """Optimizer that implements the Lion algorithm.

    The Lion optimizer is a stochastic-gradient-descent method that uses the
    sign operator to control the magnitude of the update, unlike other adaptive
    optimizers such as Adam that rely on second-order moments. This makes
    Lion more memory-efficient as it only keeps track of the momentum. According
    to the authors (see reference), its performance gain over Adam grows with
    the batch size. Because the update of Lion is produced through the sign
    operation, resulting in a larger norm, a suitable learning rate for Lion is
    typically 3-10x smaller than that for AdamW. The weight decay for Lion
    should be in turn 3-10x larger than that for AdamW to maintain a
    similar strength (lr * wd).

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            rate to combine the current gradient and the 1st moment estimate.
            Must be in the `(0, 1]` range. Defaults to `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimate. Defaults to
            `0.99`.
        {{base_optimizer_keyword_args}}

    References:

    - [Chen et al., 2023](http://arxiv.org/abs/2302.06675)
    - [Authors' implementation](
        http://github.com/google/automl/tree/master/lion)

    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.99,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="lion",
        **kwargs,
    ):
        # Fail fast: validate before any base-class state is created, so an
        # invalid configuration cannot leave a partially-initialized
        # optimizer behind. (Previously this check ran after
        # `super().__init__()` and the attribute assignments.)
        # The accepted interval is (0, 1] — beta_1 == 0 would make the
        # interpolated update ignore the momentum entirely (SignSGD), which
        # the original error message ("[0, 1]") mis-stated.
        if beta_1 <= 0 or beta_1 > 1:
            raise ValueError(
                "Argument `beta_1` must be in the (0, 1] range. Otherwise, "
                f"the optimizer degenerates to SignSGD. Received: "
                f"beta_1={beta_1}."
            )
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        self.beta_1 = beta_1
        self.beta_2 = beta_2

    def build(self, var_list):
        """Initialize optimizer variables.

        Lion optimizer has one variable `momentums`.

        Args:
            var_list: list of model variables to build Lion variables on.
        """
        if self.built:
            return
        super().build(var_list)
        self._momentums = []
        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="momentum"
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        beta_1 = ops.cast(self.beta_1, variable.dtype)
        beta_2 = ops.cast(self.beta_2, variable.dtype)
        m = self._momentums[self._get_variable_index(variable)]

        # w <- w - lr * sign(beta_1 * m + (1 - beta_1) * g):
        # the update direction interpolates momentum and gradient, then the
        # sign fixes the per-element magnitude to lr.
        self.assign_sub(
            variable,
            ops.multiply(
                lr,
                ops.sign(
                    ops.add(
                        ops.multiply(m, beta_1),
                        ops.multiply(gradient, (1.0 - beta_1)),
                    )
                ),
            ),
        )
        # m <- beta_2 * m + (1 - beta_2) * g. Note the momentum uses beta_2,
        # distinct from the beta_1 used for the update direction above; it is
        # written *after* the variable update, which consumed the old m.
        self.assign(
            m,
            ops.add(
                ops.multiply(m, beta_2), ops.multiply(gradient, (1.0 - beta_2))
            ),
        )

    def get_config(self):
        """Return the config dict needed to re-create this optimizer."""
        config = super().get_config()
        config.update(
            {
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
            }
        )
        return config


Lion.__doc__ = Lion.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/loss_scale_optimizer.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import backend
|
| 2 |
+
from keras.src import initializers
|
| 3 |
+
from keras.src import ops
|
| 4 |
+
from keras.src.api_export import keras_export
|
| 5 |
+
from keras.src.optimizers import optimizer
|
| 6 |
+
from keras.src.saving import serialization_lib
|
| 7 |
+
from keras.src.utils import tracking
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@keras_export(
|
| 11 |
+
[
|
| 12 |
+
"keras.optimizers.LossScaleOptimizer",
|
| 13 |
+
"keras.mixed_precision.LossScaleOptimizer",
|
| 14 |
+
]
|
| 15 |
+
)
|
| 16 |
+
class LossScaleOptimizer(optimizer.Optimizer):
|
| 17 |
+
"""An optimizer that dynamically scales the loss to prevent underflow.
|
| 18 |
+
|
| 19 |
+
Loss scaling is a technique to prevent numeric underflow in intermediate
|
| 20 |
+
gradients when float16 is used. To prevent underflow, the loss is multiplied
|
| 21 |
+
(or "scaled") by a certain factor called the "loss scale", which causes
|
| 22 |
+
intermediate gradients to be scaled by the loss scale as well. The final
|
| 23 |
+
gradients are divided (or "unscaled") by the loss scale to bring them back
|
| 24 |
+
to their original value.
|
| 25 |
+
|
| 26 |
+
`LossScaleOptimizer` wraps another optimizer and applies dynamic loss
|
| 27 |
+
scaling to it. This loss scale is dynamically updated over time as follows:
|
| 28 |
+
- On any train step, if a nonfinite gradient is encountered, the loss scale
|
| 29 |
+
is halved, and the train step is skipped.
|
| 30 |
+
- If `dynamic_growth_steps` have occurred since the last time the loss scale
|
| 31 |
+
was updated, and no nonfinite gradients have occurred, the loss scale
|
| 32 |
+
is doubled.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
inner_optimizer: The `keras.optimizers.Optimizer` instance to wrap.
|
| 36 |
+
initial_scale: Float. The initial loss scale. This scale will be updated
|
| 37 |
+
during training. It is recommended for this to be a very high
|
| 38 |
+
number, because a loss scale that is too high gets lowered far more
|
| 39 |
+
quickly than a loss scale that is too low gets raised.
|
| 40 |
+
dynamic_growth_steps: Int. How often to update the scale upwards. After
|
| 41 |
+
every `dynamic_growth_steps` steps with finite gradients, the
|
| 42 |
+
loss scale is doubled.
|
| 43 |
+
{{base_optimizer_keyword_args}}
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
inner_optimizer,
|
| 49 |
+
initial_scale=2.0**15,
|
| 50 |
+
dynamic_growth_steps=2000,
|
| 51 |
+
**kwargs,
|
| 52 |
+
):
|
| 53 |
+
if not kwargs.pop("dynamic", True):
|
| 54 |
+
raise ValueError(
|
| 55 |
+
"LossScaleOptimizer no longer supports `dynamic=False`. "
|
| 56 |
+
"Instead, simply set `loss_scale_factor` directly on the "
|
| 57 |
+
"`inner_optimizer`."
|
| 58 |
+
)
|
| 59 |
+
super().__init__(learning_rate=0.0, **kwargs)
|
| 60 |
+
self.inner_optimizer = inner_optimizer
|
| 61 |
+
self.initial_scale = initial_scale
|
| 62 |
+
self.dynamic_growth_steps = dynamic_growth_steps
|
| 63 |
+
|
| 64 |
+
@tracking.no_automatic_dependency_tracking
|
| 65 |
+
def build(self, var_list):
|
| 66 |
+
self.step_counter = self.add_variable(
|
| 67 |
+
shape=(),
|
| 68 |
+
dtype="int",
|
| 69 |
+
initializer=initializers.Zeros(),
|
| 70 |
+
aggregation="none",
|
| 71 |
+
name="step_counter",
|
| 72 |
+
)
|
| 73 |
+
self.dynamic_scale = self.add_variable(
|
| 74 |
+
shape=(),
|
| 75 |
+
dtype="float32",
|
| 76 |
+
initializer=initializers.Constant(self.initial_scale),
|
| 77 |
+
aggregation="none",
|
| 78 |
+
name="dynamic_scale",
|
| 79 |
+
)
|
| 80 |
+
self.inner_optimizer.build(var_list)
|
| 81 |
+
self.built = True
|
| 82 |
+
|
| 83 |
+
@property
|
| 84 |
+
def variables(self):
|
| 85 |
+
return self._variables + self.inner_optimizer.variables
|
| 86 |
+
|
| 87 |
+
def stateless_apply(self, optimizer_variables, grads, trainable_variables):
|
| 88 |
+
if not self.built:
|
| 89 |
+
raise ValueError(
|
| 90 |
+
f"To call `stateless_apply`, {self.__class__.__name__} "
|
| 91 |
+
"must be built (i.e. its variables must have been created). "
|
| 92 |
+
"You can build it via `optimizer.build(trainable_variables)`."
|
| 93 |
+
)
|
| 94 |
+
finite = self.check_finite(grads)
|
| 95 |
+
return ops.cond(
|
| 96 |
+
finite,
|
| 97 |
+
lambda: self._stateless_handle_finite_grads(
|
| 98 |
+
optimizer_variables, grads, trainable_variables
|
| 99 |
+
),
|
| 100 |
+
lambda: self._stateless_handle_non_finite_grads(
|
| 101 |
+
optimizer_variables, trainable_variables
|
| 102 |
+
),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def _stateless_handle_finite_grads(
|
| 106 |
+
self, optimizer_variables, grads, trainable_variables
|
| 107 |
+
):
|
| 108 |
+
def upscale():
|
| 109 |
+
mapping = list(zip(self.variables, optimizer_variables))
|
| 110 |
+
with backend.StatelessScope(state_mapping=mapping) as scope:
|
| 111 |
+
self.step_counter.assign(0)
|
| 112 |
+
self.dynamic_scale.assign(self.dynamic_scale * 2.0)
|
| 113 |
+
return [scope.get_current_value(v) for v in self._variables]
|
| 114 |
+
|
| 115 |
+
def increment():
|
| 116 |
+
mapping = list(zip(self.variables, optimizer_variables))
|
| 117 |
+
with backend.StatelessScope(state_mapping=mapping) as scope:
|
| 118 |
+
self.step_counter.assign_add(1)
|
| 119 |
+
return [scope.get_current_value(v) for v in self._variables]
|
| 120 |
+
|
| 121 |
+
mapping = list(zip(self.variables, optimizer_variables))
|
| 122 |
+
with backend.StatelessScope(state_mapping=mapping):
|
| 123 |
+
# Potentially upscale loss and reset counter.
|
| 124 |
+
own_variables = ops.cond(
|
| 125 |
+
ops.equal(self.step_counter, self.dynamic_growth_steps - 1),
|
| 126 |
+
upscale,
|
| 127 |
+
increment,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Unscale gradients.
|
| 131 |
+
scale = self.dynamic_scale
|
| 132 |
+
unscaled_grads = [
|
| 133 |
+
g if g is None else ops.divide(g, scale) for g in grads
|
| 134 |
+
]
|
| 135 |
+
(
|
| 136 |
+
new_trainable_variables,
|
| 137 |
+
new_inner_variables,
|
| 138 |
+
) = self.inner_optimizer.stateless_apply(
|
| 139 |
+
self.inner_optimizer.variables,
|
| 140 |
+
unscaled_grads,
|
| 141 |
+
trainable_variables,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
new_optimizer_variables = own_variables + new_inner_variables
|
| 145 |
+
return new_trainable_variables, new_optimizer_variables
|
| 146 |
+
|
| 147 |
+
def _stateless_handle_non_finite_grads(
|
| 148 |
+
self, optimizer_variables, trainable_variables
|
| 149 |
+
):
|
| 150 |
+
mapping = list(zip(self.variables, optimizer_variables))
|
| 151 |
+
with backend.StatelessScope(state_mapping=mapping) as scope:
|
| 152 |
+
self.step_counter.assign(0)
|
| 153 |
+
self.dynamic_scale.assign(self.dynamic_scale / 2.0)
|
| 154 |
+
new_optimizer_variables = []
|
| 155 |
+
for v in self.variables:
|
| 156 |
+
new_optimizer_variables.append(scope.get_current_value(v))
|
| 157 |
+
return trainable_variables, new_optimizer_variables
|
| 158 |
+
|
| 159 |
+
def apply(self, grads, trainable_variables=None):
|
| 160 |
+
# Optionally build optimizer.
|
| 161 |
+
if not self.built:
|
| 162 |
+
with backend.name_scope(self.name, caller=self):
|
| 163 |
+
self.build(trainable_variables)
|
| 164 |
+
self.built = True
|
| 165 |
+
|
| 166 |
+
if backend.backend() == "tensorflow":
|
| 167 |
+
self._tf_apply(grads, trainable_variables)
|
| 168 |
+
else:
|
| 169 |
+
self._common_apply(grads, trainable_variables)
|
| 170 |
+
|
| 171 |
+
def _stateful_handle_finite_grads(self, grads, trainable_variables):
    """Finite-gradient branch of the stateful `apply`.

    Unscales the gradients by the current dynamic loss scale, delegates
    the actual update to the inner optimizer, and then either doubles
    the loss scale (after `dynamic_growth_steps` consecutive finite
    steps) or increments the step counter.
    """
    scale = self.dynamic_scale
    # Unscale gradients.
    unscaled_grads = [
        g if g is None else ops.divide(g, scale) for g in grads
    ]
    self.inner_optimizer.apply(
        unscaled_grads, trainable_variables=trainable_variables
    )

    def upscale():
        # Enough consecutive finite steps: grow the scale and restart
        # the counter.
        self.step_counter.assign(0)
        self.dynamic_scale.assign(self.dynamic_scale * 2.0)

    def increment():
        self.step_counter.assign_add(1)

    # Potentially upscale loss and reset counter.
    ops.cond(
        ops.equal(self.step_counter, self.dynamic_growth_steps - 1),
        upscale,
        increment,
    )
|
| 194 |
+
|
| 195 |
+
def _stateful_handle_non_finite_grads(self):
|
| 196 |
+
# If any inf or nan in grads, downscale loss and reset counter.
|
| 197 |
+
self.step_counter.assign(0)
|
| 198 |
+
self.dynamic_scale.assign(self.dynamic_scale / 2.0)
|
| 199 |
+
|
| 200 |
+
def _common_apply(self, grads, trainable_variables=None):
    """Backend-agnostic apply: branch on gradient finiteness.

    Runs the normal (unscale + inner apply + maybe upscale) path when
    all gradients are finite, otherwise skips the update and downscales
    the loss scale.
    """
    finite = self.check_finite(grads)
    ops.cond(
        finite,
        lambda: self._stateful_handle_finite_grads(
            grads, trainable_variables
        ),
        self._stateful_handle_non_finite_grads,
    )
|
| 209 |
+
|
| 210 |
+
def _tf_apply(self, grads, trainable_variables=None):
    """Tensorflow specific logic for apply, which handles distribution."""
    from keras.src.utils.module_utils import tensorflow as tf

    if tf.distribute.in_cross_replica_context():
        raise ValueError("apply() must be called in a replica context.")

    if tf.__internal__.distribute.strategy_supports_no_merge_call():
        # Strategies that don't require merge_call can use the plain
        # per-replica path directly.
        self._common_apply(grads, trainable_variables=trainable_variables)
    else:

        def _handle_cross_replica(distribution, grads, trainable_variables):
            finite_per_replica = (
                distribution.extended.call_for_each_replica(
                    self.check_finite, args=(grads,)
                )
            )
            # Each replica computed the same `finite` value, since
            # `grads` is all-reduced across replicas. Arbitrarily take
            # `finite` from the first replica.
            finite = distribution.experimental_local_results(
                finite_per_replica
            )[0]

            def apply_fn():
                distribution.extended.call_for_each_replica(
                    self._stateful_handle_finite_grads,
                    args=(grads, trainable_variables),
                )

            # Note: We must call this cond() in a cross-replica context.
            # DistributionStrategy does not support having a cond in a
            # replica context with a branch that calls `merge_call`, and
            # self._optimizer.apply_gradients calls `merge_call`.
            ops.cond(
                finite, apply_fn, self._stateful_handle_non_finite_grads
            )

        tf.distribute.get_replica_context().merge_call(
            _handle_cross_replica, args=(grads, trainable_variables)
        )
|
| 251 |
+
|
| 252 |
+
def check_finite(self, grads):
    """Return a scalar boolean tensor that is True iff every non-None
    gradient in `grads` contains only finite values."""
    per_grad_finite = [
        ops.all(ops.isfinite(g)) for g in grads if g is not None
    ]
    return ops.all(ops.convert_to_tensor(per_grad_finite))
|
| 256 |
+
|
| 257 |
+
@property
def learning_rate(self):
    """Learning rate, delegated to the wrapped inner optimizer."""
    return self.inner_optimizer.learning_rate

@learning_rate.setter
def learning_rate(self, learning_rate):
    # Forward assignment so schedules/floats land on the inner optimizer.
    self.inner_optimizer.learning_rate = learning_rate
|
| 264 |
+
|
| 265 |
+
def scale_loss(self, loss):
    """Multiply `loss` by the current loss scale.

    Before the optimizer is built the configured `initial_scale` is
    used; afterwards the live `dynamic_scale` variable is used.
    """
    if self.built:
        scale = self.dynamic_scale
    else:
        scale = self.initial_scale
    return loss * scale
|
| 268 |
+
|
| 269 |
+
def finalize_variable_values(self, var_list):
    """Delegate end-of-training finalization (e.g. EMA overwrite) to the
    wrapped inner optimizer."""
    self.inner_optimizer.finalize_variable_values(var_list)
|
| 271 |
+
|
| 272 |
+
def get_config(self):
    """Serialize the wrapper: inner optimizer plus loss-scale settings.

    The inherited `learning_rate` entry is removed because the learning
    rate belongs to (and is serialized with) the inner optimizer.
    """
    config = super().get_config()
    config["inner_optimizer"] = serialization_lib.serialize_keras_object(
        self.inner_optimizer
    )
    config["initial_scale"] = self.initial_scale
    config["dynamic_growth_steps"] = self.dynamic_growth_steps
    config.pop("learning_rate")
    return config
|
| 286 |
+
|
| 287 |
+
@classmethod
def from_config(cls, config, custom_objects=None):
    """Rebuild the wrapper from `get_config()` output.

    The serialized inner optimizer is deserialized first and passed
    positionally; the remaining keys feed `__init__` as keyword args.
    """
    inner_optimizer = serialization_lib.deserialize_keras_object(
        config.pop("inner_optimizer"),
        custom_objects=custom_objects,
    )
    return cls(inner_optimizer, **config)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# Substitute the shared base-optimizer keyword-argument docs into the
# class docstring template placeholder.
LossScaleOptimizer.__doc__ = LossScaleOptimizer.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/nadam.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import backend
|
| 2 |
+
from keras.src import ops
|
| 3 |
+
from keras.src.api_export import keras_export
|
| 4 |
+
from keras.src.optimizers import optimizer
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@keras_export(["keras.optimizers.Nadam"])
class Nadam(optimizer.Optimizer):
    """Optimizer that implements the Nadam algorithm.

    Much like Adam is essentially RMSprop with momentum, Nadam is Adam with
    Nesterov momentum.

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimates.
            Defaults to `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 2nd moment estimates. Defaults to
            `0.999`.
        epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just before
            Section 2.1), not the epsilon in Algorithm 1 of the paper.
            Defaults to `1e-7`.
        {{base_optimizer_keyword_args}}

    Reference:

    - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).

    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="nadam",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon

    def build(self, var_list):
        """Initialize optimizer variables.

        Nadam optimizer has 2 types of variables: momentums and velocities.

        Args:
            var_list: list of model variables to build Nadam variables on.
        """
        if self.built:
            return
        if var_list:
            dtype = var_list[0].dtype
        else:
            dtype = backend.floatx()
        super().build(var_list)
        self._momentums = []
        self._velocities = []
        # Running product of the momentum-schedule coefficients u_t,
        # needed for the Nesterov bias correction in `update_step`.
        self._u_product = backend.Variable(1.0, dtype=dtype)

        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="momentum"
                )
            )
            self._velocities.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="velocity"
                )
            )

    def _backend_update_step(self, grads, trainable_variables, learning_rate):
        # Advance the running u-product once per optimizer step (it is
        # shared across all variables), then run the per-variable updates.
        dtype = self._u_product.dtype
        self.assign(
            self._u_product,
            self._u_product
            * self.beta_1
            * (
                1.0
                - 0.5 * ops.power(0.96, ops.cast(self.iterations + 1, dtype))
            ),
        )
        super()._backend_update_step(grads, trainable_variables, learning_rate)

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        var_dtype = variable.dtype
        lr = ops.cast(learning_rate, var_dtype)
        gradient = ops.cast(gradient, var_dtype)

        local_step = ops.cast(self.iterations + 1, var_dtype)
        next_step = ops.cast(self.iterations + 2, var_dtype)
        decay = ops.cast(0.96, var_dtype)
        beta_1 = ops.cast(self.beta_1, var_dtype)
        beta_2 = ops.cast(self.beta_2, var_dtype)
        # Momentum-schedule coefficients for the current and next step
        # (Dozat 2015, Eq. 6-7).
        u_t = beta_1 * (1.0 - 0.5 * (ops.power(decay, local_step)))
        u_t_1 = beta_1 * (1.0 - 0.5 * (ops.power(decay, next_step)))
        u_product_t = ops.cast(self._u_product, var_dtype)

        u_product_t_1 = u_product_t * u_t_1
        beta_2_power = ops.power(beta_2, local_step)

        m = self._momentums[self._get_variable_index(variable)]
        v = self._velocities[self._get_variable_index(variable)]

        self.assign_add(
            m, ops.multiply(ops.subtract(gradient, m), (1 - beta_1))
        )
        self.assign_add(
            v, ops.multiply(ops.subtract(ops.square(gradient), v), (1 - beta_2))
        )
        # Nesterov look-ahead: mix the bias-corrected next-step momentum
        # with the bias-corrected current gradient.
        m_hat = ops.add(
            ops.divide(ops.multiply(u_t_1, m), 1 - u_product_t_1),
            ops.divide(ops.multiply(1 - u_t, gradient), 1 - u_product_t),
        )
        v_hat = ops.divide(v, (1 - beta_2_power))

        self.assign_sub(
            variable,
            ops.divide(
                ops.multiply(m_hat, lr), ops.add(ops.sqrt(v_hat), self.epsilon)
            ),
        )

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
                "epsilon": self.epsilon,
            }
        )
        return config
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# Substitute the shared base-optimizer keyword-argument docs into the
# class docstring template placeholder.
Nadam.__doc__ = Nadam.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/optimizer.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import backend
|
| 2 |
+
from keras.src.api_export import keras_export
|
| 3 |
+
from keras.src.optimizers import base_optimizer
|
| 4 |
+
|
| 5 |
+
# Select the backend-specific optimizer base class at import time. Each
# backend may override the update machinery (e.g. tf.distribute handling
# for TensorFlow); any other backend falls back to the pure-Python base.
if backend.backend() == "tensorflow":
    from keras.src.backend.tensorflow.optimizer import (
        TFOptimizer as BackendOptimizer,
    )
elif backend.backend() == "torch":
    from keras.src.backend.torch.optimizers import (
        TorchOptimizer as BackendOptimizer,
    )
elif backend.backend() == "jax":
    from keras.src.backend.jax.optimizer import JaxOptimizer as BackendOptimizer
else:

    class BackendOptimizer(base_optimizer.BaseOptimizer):
        # No specialized backend: use the base implementation unchanged.
        pass


@keras_export(["keras.Optimizer", "keras.optimizers.Optimizer"])
class Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer):
    # Public optimizer base class: the backend-specific mixin comes first
    # in the MRO so it can override BaseOptimizer behavior.
    pass


# Re-export the base class documentation and the shared keyword-arg docs
# used by concrete optimizers' docstring templates.
Optimizer.__doc__ = base_optimizer.BaseOptimizer.__doc__
base_optimizer_keyword_args = base_optimizer.base_optimizer_keyword_args
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/rmsprop.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import ops
|
| 2 |
+
from keras.src.api_export import keras_export
|
| 3 |
+
from keras.src.optimizers import optimizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@keras_export(["keras.optimizers.RMSprop"])
class RMSprop(optimizer.Optimizer):
    """Optimizer that implements the RMSprop algorithm.

    The gist of RMSprop is to:

    - Maintain a moving (discounted) average of the square of gradients
    - Divide the gradient by the root of this average

    This implementation of RMSprop uses plain momentum, not Nesterov momentum.

    The centered version additionally maintains a moving average of the
    gradients, and uses that average to estimate the variance.

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        rho: float, defaults to 0.9. Discounting factor for the old gradients.
        momentum: float, defaults to 0.0. If not 0.0., the optimizer tracks the
            momentum value, with a decay rate equals to `1 - momentum`.
        epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just before
            Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults
            to 1e-7.
        centered: Boolean. If `True`, gradients are normalized by the estimated
            variance of the gradient; if False, by the uncentered second moment.
            Setting this to `True` may help with training, but is slightly more
            expensive in terms of computation and memory. Defaults to `False`.
        {{base_optimizer_keyword_args}}

    Example:

    >>> opt = keras.optimizers.RMSprop(learning_rate=0.1)
    >>> var1 = keras.backend.Variable(10.0)
    >>> loss = lambda: (var1 ** 2) / 2.0  # d(loss) / d(var1) = var1
    >>> opt.minimize(loss, [var1])
    >>> var1
    9.683772

    Reference:

    - [Hinton, 2012](
        http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
    """

    def __init__(
        self,
        learning_rate=0.001,
        rho=0.9,
        momentum=0.0,
        epsilon=1e-7,
        centered=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="rmsprop",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            name=name,
            **kwargs,
        )
        self.rho = rho
        self.momentum = momentum
        self.epsilon = epsilon
        self.centered = centered

    def build(self, var_list):
        """Create per-variable slots: velocities always; momentums only
        when `momentum > 0`; gradient averages only when `centered`."""
        if self.built:
            return

        super().build(var_list)

        self._velocities = []
        for var in var_list:
            self._velocities.append(
                self.add_variable_from_reference(var, "velocity")
            )

        self._momentums = []
        if self.momentum > 0:
            for var in var_list:
                self._momentums.append(
                    self.add_variable_from_reference(var, "momentum")
                )

        self._average_gradients = []
        if self.centered:
            for var in var_list:
                self._average_gradients.append(
                    self.add_variable_from_reference(var, "average_gradient")
                )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)

        velocity = self._velocities[self._get_variable_index(variable)]
        momentum = None
        if self.momentum > 0:
            momentum = self._momentums[self._get_variable_index(variable)]
        average_grad = None
        if self.centered:
            average_grad = self._average_gradients[
                self._get_variable_index(variable)
            ]

        rho = self.rho

        # velocity <- rho * velocity + (1 - rho) * grad^2
        self.assign(
            velocity,
            ops.add(
                ops.multiply(rho, velocity),
                ops.multiply(1 - rho, ops.square(gradient)),
            ),
        )
        if self.centered:
            # Track the mean gradient too, and normalize by the estimated
            # variance E[g^2] - E[g]^2 rather than the raw second moment.
            self.assign(
                average_grad,
                ops.add(
                    ops.multiply(rho, average_grad),
                    ops.multiply(1 - rho, gradient),
                ),
            )
            denominator = velocity - ops.square(average_grad) + self.epsilon
        else:
            denominator = ops.add(velocity, self.epsilon)
        increment = ops.divide(
            ops.multiply(lr, gradient), ops.sqrt(denominator)
        )
        if self.momentum > 0:
            # Plain (heavy-ball) momentum accumulation of the increment.
            self.assign(
                momentum,
                ops.add(ops.multiply(self.momentum, momentum), increment),
            )
            self.assign_sub(variable, momentum)
        else:
            self.assign_sub(variable, increment)

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "rho": self.rho,
                "momentum": self.momentum,
                "epsilon": self.epsilon,
                "centered": self.centered,
            }
        )
        return config
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# Substitute the shared base-optimizer keyword-argument docs into the
# class docstring template placeholder.
RMSprop.__doc__ = RMSprop.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src.optimizers.schedules.learning_rate_schedule import CosineDecay
|
| 2 |
+
from keras.src.optimizers.schedules.learning_rate_schedule import (
|
| 3 |
+
CosineDecayRestarts,
|
| 4 |
+
)
|
| 5 |
+
from keras.src.optimizers.schedules.learning_rate_schedule import (
|
| 6 |
+
ExponentialDecay,
|
| 7 |
+
)
|
| 8 |
+
from keras.src.optimizers.schedules.learning_rate_schedule import (
|
| 9 |
+
InverseTimeDecay,
|
| 10 |
+
)
|
| 11 |
+
from keras.src.optimizers.schedules.learning_rate_schedule import (
|
| 12 |
+
PiecewiseConstantDecay,
|
| 13 |
+
)
|
| 14 |
+
from keras.src.optimizers.schedules.learning_rate_schedule import (
|
| 15 |
+
PolynomialDecay,
|
| 16 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (498 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/__pycache__/learning_rate_schedule.cpython-310.pyc
ADDED
|
Binary file (30.4 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/schedules/learning_rate_schedule.py
ADDED
|
@@ -0,0 +1,969 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Various learning rate schedule functions."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
from keras.src import ops
|
| 6 |
+
from keras.src.api_export import keras_export
|
| 7 |
+
from keras.src.saving import serialization_lib
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@keras_export("keras.optimizers.schedules.LearningRateSchedule")
class LearningRateSchedule:
    """The learning rate schedule base class.

    You can use a learning rate schedule to modulate how the learning rate
    of your optimizer changes over time.

    Several built-in learning rate schedules are available, such as
    `keras.optimizers.schedules.ExponentialDecay` or
    `keras.optimizers.schedules.PiecewiseConstantDecay`:

    ```python
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-2,
        decay_steps=10000,
        decay_rate=0.9)
    optimizer = keras.optimizers.SGD(learning_rate=lr_schedule)
    ```

    A `LearningRateSchedule` instance can be passed in as the `learning_rate`
    argument of any optimizer.

    To implement your own schedule object, you should implement the `__call__`
    method, which takes a `step` argument (scalar integer tensor, the
    current training step count).
    Like for any other Keras object, you can also optionally
    make your object serializable by implementing the `get_config`
    and `from_config` methods.

    Example:

    ```python
    class MyLRSchedule(keras.optimizers.schedules.LearningRateSchedule):

        def __init__(self, initial_learning_rate):
            self.initial_learning_rate = initial_learning_rate

        def __call__(self, step):
            return self.initial_learning_rate / (step + 1)

    optimizer = keras.optimizers.SGD(learning_rate=MyLRSchedule(0.1))
    ```
    """

    def __call__(self, step):
        # Abstract: subclasses must map a step count to a learning rate.
        raise NotImplementedError(
            f"Learning rate schedule '{self.__class__.__name__}' "
            "must override `__call__(self, step)`."
        )

    def get_config(self):
        # Abstract: required only if the schedule must be serializable.
        raise NotImplementedError(
            f"Learning rate schedule '{self.__class__.__name__}' "
            "must override `get_config()` in order to be serializable."
        )

    @classmethod
    def from_config(cls, config):
        """Instantiates a `LearningRateSchedule` from its config.

        Args:
            config: Output of `get_config()`.

        Returns:
            A `LearningRateSchedule` instance.
        """
        return cls(**config)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@keras_export("keras.optimizers.schedules.ExponentialDecay")
class ExponentialDecay(LearningRateSchedule):
    """A `LearningRateSchedule` that uses an exponential decay schedule.

    When training a model, it is often useful to lower the learning rate as
    the training progresses. This schedule applies an exponential decay function
    to an optimizer step, given a provided initial learning rate.

    The schedule is a 1-arg callable that produces a decayed learning
    rate when passed the current optimizer step. This can be useful for changing
    the learning rate value across different invocations of optimizer functions.
    It is computed as:

    ```python
    def decayed_learning_rate(step):
        return initial_learning_rate * decay_rate ^ (step / decay_steps)
    ```

    If the argument `staircase` is `True`, then `step / decay_steps` is
    an integer division and the decayed learning rate follows a
    staircase function.

    You can pass this schedule directly into a `keras.optimizers.Optimizer`
    as the learning rate.
    Example: When fitting a Keras model, decay every 100000 steps with a base
    of 0.96:

    ```python
    initial_learning_rate = 0.1
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=100000,
        decay_rate=0.96,
        staircase=True)

    model.compile(optimizer=keras.optimizers.SGD(learning_rate=lr_schedule),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(data, labels, epochs=5)
    ```

    The learning rate schedule is also serializable and deserializable using
    `keras.optimizers.schedules.serialize` and
    `keras.optimizers.schedules.deserialize`.

    Args:
        initial_learning_rate: A Python float. The initial learning rate.
        decay_steps: A Python integer. Must be positive. See the decay
            computation above.
        decay_rate: A Python float. The decay rate.
        staircase: Boolean. If `True` decay the learning rate at discrete
            intervals.
        name: String. Optional name of the operation. Defaults to
            `"ExponentialDecay"`.

    Returns:
        A 1-arg callable learning rate schedule that takes the current optimizer
        step and outputs the decayed learning rate, a scalar tensor of the
        same type as `initial_learning_rate`.
    """

    def __init__(
        self,
        initial_learning_rate,
        decay_steps,
        decay_rate,
        staircase=False,
        name="ExponentialDecay",
    ):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase
        self.name = name

        # A non-positive decay_steps would make the exponent in __call__
        # meaningless (division by zero or an inverted schedule).
        if self.decay_steps <= 0:
            raise ValueError(
                "Argument `decay_steps` must be > 0. "
                f"Received: decay_steps={self.decay_steps}"
            )

    def __call__(self, step):
        """Return `initial_learning_rate * decay_rate ** (step / decay_steps)`."""
        with ops.name_scope(self.name):
            initial_learning_rate = ops.convert_to_tensor(
                self.initial_learning_rate
            )
            # All arithmetic happens in the dtype of initial_learning_rate.
            dtype = initial_learning_rate.dtype
            decay_steps = ops.cast(self.decay_steps, dtype)
            decay_rate = ops.cast(self.decay_rate, dtype)

            global_step_recomp = ops.cast(step, dtype)
            p = global_step_recomp / decay_steps
            if self.staircase:
                # Flooring the exponent makes the decay jump at discrete
                # multiples of decay_steps instead of decaying continuously.
                p = ops.floor(p)
            return ops.multiply(initial_learning_rate, ops.power(decay_rate, p))

    def get_config(self):
        """Return the schedule's constructor arguments as a plain dict."""
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_steps": self.decay_steps,
            "decay_rate": self.decay_rate,
            "staircase": self.staircase,
            "name": self.name,
        }
@keras_export("keras.optimizers.schedules.PiecewiseConstantDecay")
class PiecewiseConstantDecay(LearningRateSchedule):
    """A `LearningRateSchedule` that uses a piecewise constant decay schedule.

    The function returns a 1-arg callable to compute the piecewise constant
    when passed the current optimizer step. This can be useful for changing the
    learning rate value across different invocations of optimizer functions.

    Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
    for the next 10000 steps, and 0.1 for any additional steps.

    ```python
    step = ops.array(0)
    boundaries = [100000, 110000]
    values = [1.0, 0.5, 0.1]
    learning_rate_fn = keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries, values)

    # Later, whenever we perform an optimization step, we pass in the step.
    learning_rate = learning_rate_fn(step)
    ```

    You can pass this schedule directly into a `keras.optimizers.Optimizer`
    as the learning rate. The learning rate schedule is also serializable and
    deserializable using `keras.optimizers.schedules.serialize` and
    `keras.optimizers.schedules.deserialize`.

    Args:
        boundaries: A list of Python numbers with strictly increasing
            entries, and with all elements having the same type as the
            optimizer step.
        values: A list of Python numbers that specifies the values for the
            intervals defined by `boundaries`. It should have one more
            element than `boundaries`, and all elements should have the same
            type.
        name: A string. Optional name of the operation. Defaults to
            `"PiecewiseConstant"`.

    Returns:
        A 1-arg callable learning rate schedule that takes the current optimizer
        step and outputs the decayed learning rate, a scalar tensor of the
        same type as the boundary tensors.

        The output of the 1-arg function that takes the `step`
        is `values[0]` when `step <= boundaries[0]`,
        `values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`,
        ..., and `values[-1]` when `step > boundaries[-1]`.


    Raises:
        ValueError: if the number of elements in the `boundaries` and `values`
            lists do not match.
    """

    def __init__(self, boundaries, values, name="PiecewiseConstant"):
        super().__init__()

        # Each boundary separates two intervals, so N boundaries define
        # N + 1 intervals and therefore require N + 1 values.
        if len(boundaries) != len(values) - 1:
            raise ValueError(
                "The length of boundaries should be 1 less than the length of "
                f"values. Received: boundaries={boundaries} of length "
                f"{len(boundaries)}, and values={values} "
                f"of length {len(values)}."
            )

        self.boundaries = boundaries
        self.values = values
        self.name = name

    def __call__(self, step):
        """Return the constant value of the interval that contains `step`."""
        with ops.name_scope(self.name):
            boundaries = [ops.convert_to_tensor(x) for x in self.boundaries]
            values = [ops.convert_to_tensor(x) for x in self.values]
            step = ops.convert_to_tensor(step)

            for i, b in enumerate(boundaries):
                if b.dtype != step.dtype:
                    # We cast the boundaries to have the same type as the step
                    b = ops.cast(b, step.dtype)
                    boundaries[i] = b

            # The returned learning rate uses the dtype of the values,
            # which may differ from the step/boundary dtype.
            result_dtype = values[0].dtype
            result_value = ops.array(0, dtype=result_dtype)

            # For each range between boundaries, we check whether the step is
            # within that range, cast the resulting boolean to a number,
            # and multiply the result by the corresponding value for the range.
            # Taking the sum of these yields a piecewise constant function.
            step_less_than_first_boundary = ops.cast(
                step <= boundaries[0], result_dtype
            )
            result_value += step_less_than_first_boundary * values[0]

            step_greater_than_last_boundary = ops.cast(
                step > boundaries[-1], result_dtype
            )
            result_value += step_greater_than_last_boundary * values[-1]

            for low, high, value in zip(
                boundaries[:-1], boundaries[1:], values[1:-1]
            ):
                # Indicator for the half-open interval (low, high]; exactly
                # one indicator across all terms is 1 for any given step.
                step_in_range = ops.cast(
                    (step > low) & (step <= high), result_dtype
                )
                result_value += step_in_range * value

            return result_value

    def get_config(self):
        """Return the schedule's constructor arguments as a plain dict."""
        return {
            "boundaries": self.boundaries,
            "values": self.values,
            "name": self.name,
        }
@keras_export("keras.optimizers.schedules.PolynomialDecay")
class PolynomialDecay(LearningRateSchedule):
    """A `LearningRateSchedule` that uses a polynomial decay schedule.

    It is commonly observed that a monotonically decreasing learning rate, whose
    degree of change is carefully chosen, results in a better performing model.
    This schedule applies a polynomial decay function to an optimizer step,
    given a provided `initial_learning_rate`, to reach an `end_learning_rate`
    in the given `decay_steps`.

    It requires a `step` value to compute the decayed learning rate. You
    can just pass a backend variable that you increment at each training
    step.

    The schedule is a 1-arg callable that produces a decayed learning rate
    when passed the current optimizer step. This can be useful for changing the
    learning rate value across different invocations of optimizer functions.
    It is computed as:

    ```python
    def decayed_learning_rate(step):
        step = min(step, decay_steps)
        return ((initial_learning_rate - end_learning_rate) *
                (1 - step / decay_steps) ^ (power)
               ) + end_learning_rate
    ```

    If `cycle` is True then a multiple of `decay_steps` is used, the first one
    that is bigger than `step`.

    ```python
    def decayed_learning_rate(step):
        decay_steps = decay_steps * ceil(step / decay_steps)
        return ((initial_learning_rate - end_learning_rate) *
                (1 - step / decay_steps) ^ (power)
               ) + end_learning_rate
    ```

    You can pass this schedule directly into a `keras.optimizers.Optimizer`
    as the learning rate.
    Example: Fit a model while decaying from 0.1 to 0.01 in 10000 steps using
    sqrt (i.e. power=0.5):

    ```python
    ...
    starter_learning_rate = 0.1
    end_learning_rate = 0.01
    decay_steps = 10000
    learning_rate_fn = keras.optimizers.schedules.PolynomialDecay(
        starter_learning_rate,
        decay_steps,
        end_learning_rate,
        power=0.5)

    model.compile(optimizer=keras.optimizers.SGD(
                      learning_rate=learning_rate_fn),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(data, labels, epochs=5)
    ```

    The learning rate schedule is also serializable and deserializable using
    `keras.optimizers.schedules.serialize` and
    `keras.optimizers.schedules.deserialize`.

    Args:
        initial_learning_rate: A Python float. The initial learning rate.
        decay_steps: A Python integer. Must be positive. See the decay
            computation above.
        end_learning_rate: A Python float. The minimal end learning rate.
        power: A Python float. The power of the polynomial. Defaults to
            `1.0`.
        cycle: A boolean, whether it should cycle beyond decay_steps.
        name: String. Optional name of the operation. Defaults to
            `"PolynomialDecay"`.

    Returns:
        A 1-arg callable learning rate schedule that takes the current optimizer
        step and outputs the decayed learning rate, a scalar tensor of the
        same type as `initial_learning_rate`.
    """

    def __init__(
        self,
        initial_learning_rate,
        decay_steps,
        end_learning_rate=0.0001,
        power=1.0,
        cycle=False,
        name="PolynomialDecay",
    ):
        super().__init__()

        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.cycle = cycle
        self.name = name

        # A non-positive decay_steps would make `step / decay_steps` in
        # __call__ undefined or meaningless.
        if self.decay_steps <= 0:
            raise ValueError(
                "Argument `decay_steps` must be > 0. "
                f"Received: decay_steps={self.decay_steps}"
            )

    def __call__(self, step):
        """Return the polynomially decayed learning rate at `step`."""
        with ops.name_scope(self.name):
            initial_learning_rate = ops.convert_to_tensor(
                self.initial_learning_rate
            )
            # All arithmetic happens in the dtype of initial_learning_rate.
            dtype = initial_learning_rate.dtype
            end_learning_rate = ops.cast(self.end_learning_rate, dtype)
            power = ops.cast(self.power, dtype)

            global_step_recomp = ops.cast(step, dtype)
            decay_steps_recomp = ops.cast(self.decay_steps, dtype)
            if self.cycle:
                # Find the first multiple of decay_steps that is bigger than
                # global_step. If global_step is zero set the multiplier to 1
                multiplier = ops.where(
                    ops.equal(global_step_recomp, 0),
                    1.0,
                    ops.ceil(global_step_recomp / self.decay_steps),
                )
                decay_steps_recomp = ops.multiply(
                    decay_steps_recomp, multiplier
                )
            else:
                # Make sure that the global_step used is not bigger than
                # decay_steps.
                global_step_recomp = ops.minimum(
                    global_step_recomp, decay_steps_recomp
                )

            # (1 - p) goes from 1 to 0 over the decay window, so the result
            # interpolates from initial_learning_rate to end_learning_rate.
            p = ops.divide(global_step_recomp, decay_steps_recomp)
            return ops.add(
                ops.multiply(
                    initial_learning_rate - end_learning_rate,
                    ops.power(1 - p, power),
                ),
                end_learning_rate,
            )

    def get_config(self):
        """Return the schedule's constructor arguments as a plain dict."""
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_steps": self.decay_steps,
            "end_learning_rate": self.end_learning_rate,
            "power": self.power,
            "cycle": self.cycle,
            "name": self.name,
        }
@keras_export("keras.optimizers.schedules.InverseTimeDecay")
class InverseTimeDecay(LearningRateSchedule):
    """A `LearningRateSchedule` that uses an inverse time decay schedule.

    When training a model, it is often useful to lower the learning rate as
    the training progresses. This schedule applies the inverse decay function
    to an optimizer step, given a provided initial learning rate.
    It requires a `step` value to compute the decayed learning rate. You can
    just pass a backend variable that you increment at each training step.

    The schedule is a 1-arg callable that produces a decayed learning
    rate when passed the current optimizer step. This can be useful for changing
    the learning rate value across different invocations of optimizer functions.
    It is computed as:

    ```python
    def decayed_learning_rate(step):
        return initial_learning_rate / (1 + decay_rate * step / decay_step)
    ```

    or, if `staircase` is `True`, as:

    ```python
    def decayed_learning_rate(step):
        return initial_learning_rate /
               (1 + decay_rate * floor(step / decay_step))
    ```

    You can pass this schedule directly into a `keras.optimizers.Optimizer`
    as the learning rate.
    Example: Fit a Keras model when decaying 1/t with a rate of 0.5:

    ```python
    ...
    initial_learning_rate = 0.1
    decay_steps = 1.0
    decay_rate = 0.5
    learning_rate_fn = keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate, decay_steps, decay_rate)

    model.compile(optimizer=keras.optimizers.SGD(
                      learning_rate=learning_rate_fn),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(data, labels, epochs=5)
    ```

    Args:
        initial_learning_rate: A Python float. The initial learning rate.
        decay_steps: How often to apply decay.
        decay_rate: A Python number. The decay rate.
        staircase: Whether to apply decay in a discrete staircase, as
            opposed to continuous, fashion.
        name: String. Optional name of the operation. Defaults to
            `"InverseTimeDecay"`.

    Returns:
        A 1-arg callable learning rate schedule that takes the current optimizer
        step and outputs the decayed learning rate, a scalar tensor of the
        same type as `initial_learning_rate`.
    """

    def __init__(
        self,
        initial_learning_rate,
        decay_steps,
        decay_rate,
        staircase=False,
        name="InverseTimeDecay",
    ):
        super().__init__()

        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase
        self.name = name

        # A non-positive decay_steps would make `step / decay_steps` in
        # __call__ undefined or meaningless.
        if self.decay_steps <= 0:
            raise ValueError(
                "Argument `decay_steps` must be > 0. "
                f"Received: decay_steps={self.decay_steps}"
            )

    def __call__(self, step):
        """Return `initial_learning_rate / (1 + decay_rate * step / decay_steps)`."""
        with ops.name_scope(self.name):
            initial_learning_rate = ops.convert_to_tensor(
                self.initial_learning_rate
            )
            # All arithmetic happens in the dtype of initial_learning_rate.
            dtype = initial_learning_rate.dtype
            decay_steps = ops.cast(self.decay_steps, dtype)
            decay_rate = ops.cast(self.decay_rate, dtype)

            global_step_recomp = ops.cast(step, dtype)
            p = global_step_recomp / decay_steps
            if self.staircase:
                # Flooring p makes the decay change at discrete multiples of
                # decay_steps instead of continuously.
                p = ops.floor(p)
            # Constant 1 cast to the learning-rate dtype so the division
            # below stays in a single dtype.
            const = ops.cast(ops.array(1), dtype)
            denom = ops.add(const, ops.multiply(decay_rate, p))
            return ops.divide(initial_learning_rate, denom)

    def get_config(self):
        """Return the schedule's constructor arguments as a plain dict."""
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_steps": self.decay_steps,
            "decay_rate": self.decay_rate,
            "staircase": self.staircase,
            "name": self.name,
        }
@keras_export("keras.optimizers.schedules.CosineDecay")
class CosineDecay(LearningRateSchedule):
    """A `LearningRateSchedule` that uses a cosine decay with optional warmup.

    See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),
    SGDR: Stochastic Gradient Descent with Warm Restarts.

    For the idea of a linear warmup of our learning rate,
    see [Goyal et al.](https://arxiv.org/pdf/1706.02677.pdf).

    When we begin training a model, we often want an initial increase in our
    learning rate followed by a decay. If `warmup_target` is not `None`, this
    schedule applies a linear increase per optimizer step to our learning rate
    from `initial_learning_rate` to `warmup_target` for a duration of
    `warmup_steps`. Afterwards, it applies a cosine decay function taking our
    learning rate from `warmup_target` to `alpha` for a duration of
    `decay_steps`. If `warmup_target` is None we skip warmup and our decay
    will take our learning rate from `initial_learning_rate` to `alpha`.
    It requires a `step` value to compute the learning rate. You can
    just pass a backend variable that you increment at each training step.

    The schedule is a 1-arg callable that produces a warmup followed by a
    decayed learning rate when passed the current optimizer step. This can be
    useful for changing the learning rate value across different invocations of
    optimizer functions.

    Our warmup is computed as:

    ```python
    def warmup_learning_rate(step):
        completed_fraction = step / warmup_steps
        total_delta = warmup_target - initial_learning_rate
        return completed_fraction * total_delta + initial_learning_rate
    ```

    And our decay is computed as:

    ```python
    if warmup_target is None:
        initial_decay_lr = initial_learning_rate
    else:
        initial_decay_lr = warmup_target

    def decayed_learning_rate(step):
        step = min(step, decay_steps)
        cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps))
        decayed = (1 - alpha) * cosine_decay + alpha
        return initial_decay_lr * decayed
    ```

    Example usage without warmup:

    ```python
    decay_steps = 1000
    initial_learning_rate = 0.1
    lr_decayed_fn = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate, decay_steps)
    ```

    Example usage with warmup:

    ```python
    decay_steps = 1000
    initial_learning_rate = 0
    warmup_steps = 1000
    target_learning_rate = 0.1
    lr_warmup_decayed_fn = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate, decay_steps, warmup_target=target_learning_rate,
        warmup_steps=warmup_steps
    )
    ```

    You can pass this schedule directly into a `keras.optimizers.Optimizer`
    as the learning rate. The learning rate schedule is also serializable and
    deserializable using `keras.optimizers.schedules.serialize` and
    `keras.optimizers.schedules.deserialize`.

    Args:
        initial_learning_rate: A Python float. The initial learning rate.
        decay_steps: A Python int. Number of steps to decay over.
        alpha: A Python float. Minimum learning rate value for decay as a
            fraction of `initial_learning_rate`.
        name: String. Optional name of the operation. Defaults to
            `"CosineDecay"`.
        warmup_target: A Python float. The target learning rate for our
            warmup phase. Will cast to the `initial_learning_rate` datatype.
            Setting to `None` will skip warmup and begins decay phase from
            `initial_learning_rate`. Otherwise scheduler will warmup from
            `initial_learning_rate` to `warmup_target`.
        warmup_steps: A Python int. Number of steps to warmup over.

    Returns:
        A 1-arg callable learning rate schedule that takes the current optimizer
        step and outputs the decayed learning rate, a scalar tensor of the
        same type as `initial_learning_rate`.
    """

    def __init__(
        self,
        initial_learning_rate,
        decay_steps,
        alpha=0.0,
        name="CosineDecay",
        warmup_target=None,
        warmup_steps=0,
    ):
        super().__init__()

        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.alpha = alpha
        self.name = name
        self.warmup_steps = warmup_steps
        self.warmup_target = warmup_target

        # A non-positive decay_steps would make `step / decay_steps` in the
        # decay function undefined or meaningless.
        if self.decay_steps <= 0:
            raise ValueError(
                "Argument `decay_steps` must be > 0. "
                f"Received: decay_steps={self.decay_steps}"
            )

    def _decay_function(self, step, decay_steps, decay_from_lr, dtype):
        # Cosine decay from `decay_from_lr` toward `alpha * decay_from_lr`
        # over `decay_steps` steps.
        with ops.name_scope(self.name):
            completed_fraction = step / decay_steps
            pi = ops.array(math.pi, dtype=dtype)
            cosine_decayed = 0.5 * (1.0 + ops.cos(pi * completed_fraction))
            # `alpha` sets the floor of the decay as a fraction of
            # `decay_from_lr`.
            decayed = (1 - self.alpha) * cosine_decayed + self.alpha
            return ops.multiply(decay_from_lr, decayed)

    def _warmup_function(
        self, step, warmup_steps, warmup_target, initial_learning_rate
    ):
        # Linear interpolation from `initial_learning_rate` to
        # `warmup_target` over `warmup_steps` steps.
        with ops.name_scope(self.name):
            completed_fraction = step / warmup_steps
            total_step_delta = warmup_target - initial_learning_rate
            return total_step_delta * completed_fraction + initial_learning_rate

    def __call__(self, step):
        """Return the (optionally warmed-up) cosine-decayed rate at `step`."""
        with ops.name_scope(self.name):
            initial_learning_rate = ops.convert_to_tensor(
                self.initial_learning_rate
            )
            # All arithmetic happens in the dtype of initial_learning_rate.
            dtype = initial_learning_rate.dtype
            decay_steps = ops.cast(self.decay_steps, dtype)
            global_step_recomp = ops.cast(step, dtype)

            if self.warmup_target is None:
                # No warmup: clamp the step and decay directly from
                # initial_learning_rate.
                global_step_recomp = ops.minimum(
                    global_step_recomp, decay_steps
                )
                return self._decay_function(
                    global_step_recomp,
                    decay_steps,
                    initial_learning_rate,
                    dtype,
                )

            warmup_target = ops.cast(self.warmup_target, dtype)
            warmup_steps = ops.cast(self.warmup_steps, dtype)

            # Clamp so the schedule holds its final value after
            # warmup + decay is complete.
            global_step_recomp = ops.minimum(
                global_step_recomp, decay_steps + warmup_steps
            )

            # Warmup phase first, then decay measured from the end of warmup.
            return ops.cond(
                global_step_recomp < warmup_steps,
                lambda: self._warmup_function(
                    global_step_recomp,
                    warmup_steps,
                    warmup_target,
                    initial_learning_rate,
                ),
                lambda: self._decay_function(
                    global_step_recomp - warmup_steps,
                    decay_steps,
                    warmup_target,
                    dtype,
                ),
            )

    def get_config(self):
        """Return the schedule's constructor arguments as a plain dict."""
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_steps": self.decay_steps,
            "alpha": self.alpha,
            "name": self.name,
            "warmup_target": self.warmup_target,
            "warmup_steps": self.warmup_steps,
        }
@keras_export("keras.optimizers.schedules.CosineDecayRestarts")
|
| 764 |
+
class CosineDecayRestarts(LearningRateSchedule):
|
| 765 |
+
"""A `LearningRateSchedule` that uses a cosine decay schedule with restarts.
|
| 766 |
+
|
| 767 |
+
See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),
|
| 768 |
+
SGDR: Stochastic Gradient Descent with Warm Restarts.
|
| 769 |
+
|
| 770 |
+
When training a model, it is often useful to lower the learning rate as
|
| 771 |
+
the training progresses. This schedule applies a cosine decay function with
|
| 772 |
+
restarts to an optimizer step, given a provided initial learning rate.
|
| 773 |
+
It requires a `step` value to compute the decayed learning rate. You can
|
| 774 |
+
just pass a backend variable that you increment at each training step.
|
| 775 |
+
|
| 776 |
+
The schedule is a 1-arg callable that produces a decayed learning
|
| 777 |
+
rate when passed the current optimizer step. This can be useful for changing
|
| 778 |
+
the learning rate value across different invocations of optimizer functions.
|
| 779 |
+
|
| 780 |
+
The learning rate multiplier first decays
|
| 781 |
+
from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
|
| 782 |
+
restart is performed. Each new warm restart runs for `t_mul` times more
|
| 783 |
+
steps and with `m_mul` times initial learning rate as the new learning rate.
|
| 784 |
+
|
| 785 |
+
Example:
|
| 786 |
+
```python
|
| 787 |
+
first_decay_steps = 1000
|
| 788 |
+
lr_decayed_fn = (
|
| 789 |
+
keras.optimizers.schedules.CosineDecayRestarts(
|
| 790 |
+
initial_learning_rate,
|
| 791 |
+
first_decay_steps))
|
| 792 |
+
```
|
| 793 |
+
|
| 794 |
+
You can pass this schedule directly into a `keras.optimizers.Optimizer`
|
| 795 |
+
as the learning rate. The learning rate schedule is also serializable and
|
| 796 |
+
deserializable using `keras.optimizers.schedules.serialize` and
|
| 797 |
+
`keras.optimizers.schedules.deserialize`.
|
| 798 |
+
|
| 799 |
+
Args:
|
| 800 |
+
initial_learning_rate: A Python float. The initial learning rate.
|
| 801 |
+
first_decay_steps: A Python integer. Number of steps to decay over.
|
| 802 |
+
t_mul: A Python float. Used to derive the number of iterations in
|
| 803 |
+
the i-th period.
|
| 804 |
+
m_mul: A Python float. Used to derive the initial learning rate of
|
| 805 |
+
the i-th period.
|
| 806 |
+
alpha: A Python float. Minimum learning rate value as a fraction of
|
| 807 |
+
the `initial_learning_rate`.
|
| 808 |
+
name: String. Optional name of the operation. Defaults to
|
| 809 |
+
`"SGDRDecay"`.
|
| 810 |
+
|
| 811 |
+
Returns:
|
| 812 |
+
A 1-arg callable learning rate schedule that takes the current optimizer
|
| 813 |
+
step and outputs the decayed learning rate, a scalar tensor of the
|
| 814 |
+
same type as `initial_learning_rate`.
|
| 815 |
+
"""
|
| 816 |
+
|
| 817 |
+
def __init__(
    self,
    initial_learning_rate,
    first_decay_steps,
    t_mul=2.0,
    m_mul=1.0,
    alpha=0.0,
    name="SGDRDecay",
):
    """Construct a cosine-decay-with-restarts learning rate schedule.

    Args:
        initial_learning_rate: A Python float. The initial learning rate.
        first_decay_steps: A Python integer. Number of steps to decay over
            in the first period. Must be > 0.
        t_mul: A Python float. Used to derive the number of iterations in
            the i-th period.
        m_mul: A Python float. Used to derive the initial learning rate of
            the i-th period.
        alpha: A Python float. Minimum learning rate value as a fraction of
            `initial_learning_rate`.
        name: String. Optional name of the operation. Defaults to
            `"SGDRDecay"`.

    Raises:
        ValueError: If `first_decay_steps` is not strictly positive.
    """
    super().__init__()

    # Fail fast before touching any state so that an invalid argument
    # never leaves a half-initialized schedule behind.
    if first_decay_steps <= 0:
        raise ValueError(
            "Argument `first_decay_steps` must be > 0. "
            f"Received: first_decay_steps={first_decay_steps}"
        )

    self.initial_learning_rate = initial_learning_rate
    self.first_decay_steps = first_decay_steps
    self._t_mul = t_mul
    self._m_mul = m_mul
    self.alpha = alpha
    self.name = name
|
| 840 |
+
|
| 841 |
+
def __call__(self, step):
    """Return the decayed learning rate for the given optimizer `step`.

    Computes SGDR (cosine decay with warm restarts): the rate decays from
    `initial_learning_rate` toward `alpha * initial_learning_rate` over each
    period; period i runs for `first_decay_steps * t_mul**i` steps and
    restarts at `initial_learning_rate * m_mul**i`.
    """
    with ops.name_scope(self.name):
        initial_learning_rate = ops.convert_to_tensor(
            self.initial_learning_rate
        )
        # All arithmetic below is done in the dtype of the initial rate.
        dtype = initial_learning_rate.dtype
        first_decay_steps = ops.cast(self.first_decay_steps, dtype)
        alpha = ops.cast(self.alpha, dtype)
        t_mul = ops.cast(self._t_mul, dtype)
        m_mul = ops.cast(self._m_mul, dtype)

        global_step_recomp = ops.cast(step, dtype)
        # Progress measured in units of the first period's length.
        completed_fraction = global_step_recomp / first_decay_steps

        def compute_step(completed_fraction, geometric=False):
            """Helper for `cond` operation.

            Returns `(i_restart, fraction)`: the index of the current
            restart period and the fraction of that period completed.
            The `geometric` branch handles `t_mul != 1` (period lengths
            form a geometric series); the other branch handles `t_mul == 1`
            (all periods equal, where the geometric formula would divide
            by zero).
            """
            if geometric:
                # ops.log is sensitive to the precision of dtype, so we need
                # the additional casting
                i_restart = ops.floor(
                    ops.log(
                        ops.cast(
                            1.0 - completed_fraction * (1.0 - t_mul), dtype
                        )
                    )
                    / ops.log(t_mul)
                )

                # Total length (in first-period units) of all periods
                # before the current one: geometric series partial sum.
                sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
                completed_fraction = (
                    completed_fraction - sum_r
                ) / t_mul**i_restart

            else:
                i_restart = ops.floor(completed_fraction)
                completed_fraction -= i_restart

            return i_restart, completed_fraction

        i_restart, completed_fraction = ops.cond(
            ops.equal(t_mul, 1.0),
            lambda: compute_step(completed_fraction, geometric=False),
            lambda: compute_step(completed_fraction, geometric=True),
        )

        # Each restart scales the starting rate by m_mul**i_restart.
        m_fac = m_mul**i_restart
        cosine_decayed = (
            0.5
            * m_fac
            * (
                1.0
                + ops.cos(
                    ops.array(math.pi, dtype=dtype) * completed_fraction
                )
            )
        )
        # Interpolate so the decayed multiplier bottoms out at `alpha`
        # instead of 0.
        decayed = (1 - alpha) * cosine_decayed + alpha

        return ops.multiply(initial_learning_rate, decayed)
|
| 900 |
+
|
| 901 |
+
def get_config(self):
    """Return the configuration needed to re-create this schedule."""
    return dict(
        initial_learning_rate=self.initial_learning_rate,
        first_decay_steps=self.first_decay_steps,
        t_mul=self._t_mul,
        m_mul=self._m_mul,
        alpha=self.alpha,
        name=self.name,
    )
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
@keras_export("keras.optimizers.schedules.serialize")
def serialize(learning_rate_schedule):
    """Serializes a `LearningRateSchedule` into a JSON-compatible dict.

    Args:
        learning_rate_schedule: The `LearningRateSchedule` object to serialize.

    Returns:
        A JSON-serializable dict representing the object's config.

    Example:

    >>> lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    ...     0.1, decay_steps=100000, decay_rate=0.96, staircase=True)
    >>> keras.optimizers.schedules.serialize(lr_schedule)
    {'module': 'keras.optimizers.schedules',
    'class_name': 'ExponentialDecay', 'config': {...},
    'registered_name': None}
    """
    serialized = serialization_lib.serialize_keras_object(
        learning_rate_schedule
    )
    return serialized
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
@keras_export("keras.optimizers.schedules.deserialize")
def deserialize(config, custom_objects=None):
    """Instantiates a `LearningRateSchedule` object from a serialized form.

    Args:
        config: The serialized form of the `LearningRateSchedule`. Dictionary of
            the form {'class_name': str, 'config': dict}.
        custom_objects: A dictionary mapping class names (or function names) of
            custom (non-Keras) objects to class/functions.

    Returns:
        A `LearningRateSchedule` object.

    Example:

    ```python
    # Configuration for PolynomialDecay
    config = {
        'class_name': 'PolynomialDecay',
        'config': {'cycle': False,
            'decay_steps': 10000,
            'end_learning_rate': 0.01,
            'initial_learning_rate': 0.1,
            'name': None,
            'power': 0.5
            }
        }
    lr_schedule = keras.optimizers.schedules.deserialize(config)
    ```
    """
    # Resolve class names against this module's globals, falling back to
    # any caller-supplied custom objects.
    deserialize_kwargs = {
        "module_objects": globals(),
        "custom_objects": custom_objects,
        "printable_module_name": "decay",
    }
    return serialization_lib.deserialize_keras_object(
        config, **deserialize_kwargs
    )
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/optimizers/sgd.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import ops
|
| 2 |
+
from keras.src.api_export import keras_export
|
| 3 |
+
from keras.src.optimizers import optimizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@keras_export("keras.optimizers.SGD")
class SGD(optimizer.Optimizer):
    """Gradient descent (with momentum) optimizer.

    Update rule for parameter `w` with gradient `g` when `momentum` is 0:

    ```python
    w = w - learning_rate * g
    ```

    Update rule when `momentum` is larger than 0:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + velocity
    ```

    When `nesterov=True`, this rule becomes:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + momentum * velocity - learning_rate * g
    ```

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.01`.
        momentum: float hyperparameter >= 0 that accelerates gradient descent in
            the relevant direction and dampens oscillations. 0 is vanilla
            gradient descent. Defaults to `0.0`.
        nesterov: boolean. Whether to apply Nesterov momentum.
            Defaults to `False`.
        {{base_optimizer_keyword_args}}
    """

    def __init__(
        self,
        learning_rate=0.01,
        momentum=0.0,
        nesterov=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="SGD",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        # Accept any real number in [0, 1]; the previous strict `float`
        # check rejected valid integer values such as `momentum=0`.
        # Booleans are excluded explicitly because `isinstance(True, int)`
        # is True.
        if (
            isinstance(momentum, bool)
            or not isinstance(momentum, (int, float))
            or momentum < 0
            or momentum > 1
        ):
            raise ValueError(
                "`momentum` must be a float between [0, 1]. "
                f"Received: momentum={momentum}"
            )
        self.momentum = float(momentum)
        self.nesterov = nesterov

    def build(self, variables):
        """Initialize optimizer variables.

        SGD optimizer has one variable `momentums`, only set if `self.momentum`
        is not 0.

        Args:
            variables: list of model variables to build SGD variables on.
        """
        if self.built:
            return
        super().build(variables)
        self.momentums = []
        if self.momentum != 0:
            # One velocity accumulator per variable, indexed in the same
            # order as `variables` (looked up via _get_variable_index).
            for variable in variables:
                self.momentums.append(
                    self.add_variable_from_reference(
                        reference_variable=variable, name="momentum"
                    )
                )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        learning_rate = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        m = None
        if self.momentum != 0:
            m = self.momentums[self._get_variable_index(variable)]

        if m is not None:
            momentum = ops.cast(self.momentum, variable.dtype)
            # velocity = momentum * velocity - learning_rate * g
            self.assign(
                m,
                ops.subtract(
                    ops.multiply(m, momentum),
                    ops.multiply(gradient, learning_rate),
                ),
            )
            if self.nesterov:
                # w = w + momentum * velocity - learning_rate * g
                self.assign_add(
                    variable,
                    ops.subtract(
                        ops.multiply(m, momentum),
                        ops.multiply(gradient, learning_rate),
                    ),
                )
            else:
                # w = w + velocity
                self.assign_add(variable, m)
        else:
            # Vanilla gradient descent: w = w - learning_rate * g
            self.assign_sub(variable, ops.multiply(gradient, learning_rate))

    def get_config(self):
        """Return the optimizer config, extending the base config."""
        config = super().get_config()
        config.update(
            {
                "momentum": self.momentum,
                "nesterov": self.nesterov,
            }
        )
        return config
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# Substitute the shared base-optimizer keyword-argument documentation into
# the class docstring at import time, replacing the template placeholder.
SGD.__doc__ = SGD.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
|