Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See the raw diff for the full change set.
See raw diff
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/backend.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/layers.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/losses.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/backend.py +2291 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/layers.py +244 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/losses.py +20 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/image.py +1892 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/sequence.py +320 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/text.py +336 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/json_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/legacy_h5_format.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/saving_options.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/saving_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/serialization.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/json_utils.py +220 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/legacy_h5_format.py +640 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/saving_options.py +17 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/saving_utils.py +260 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/serialization.py +574 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__init__.py +207 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/loss.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/losses.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/loss.py +256 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/losses.py +2599 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__init__.py +211 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/accuracy_metrics.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/confusion_metrics.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/correlation_metrics.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/f_score_metrics.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/hinge_metrics.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/iou_metrics.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/metric.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/metrics_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/probabilistic_metrics.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/reduction_metrics.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/regression_metrics.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/accuracy_metrics.py +522 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/confusion_metrics.py +1576 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/correlation_metrics.py +215 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/f_score_metrics.py +320 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/hinge_metrics.py +100 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/iou_metrics.py +762 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/metric.py +253 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/metrics_utils.py +683 -0
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (192 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/backend.cpython-310.pyc
ADDED
|
Binary file (51.1 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/layers.cpython-310.pyc
ADDED
|
Binary file (6.7 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/losses.cpython-310.pyc
ADDED
|
Binary file (949 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/backend.py
ADDED
|
@@ -0,0 +1,2291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Legacy Keras 1/2 backend functions."""
|
| 2 |
+
|
| 3 |
+
import itertools
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from keras.src import backend
|
| 8 |
+
from keras.src.api_export import keras_export
|
| 9 |
+
from keras.src.utils.module_utils import tensorflow as tf
|
| 10 |
+
|
| 11 |
+
py_any = any
|
| 12 |
+
py_all = all
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@keras_export("keras._legacy.backend.abs")
def abs(x):
    """DEPRECATED.

    Element-wise absolute value of the input tensor.
    """
    # NOTE: shadows the builtin `abs` on purpose — legacy Keras backend API.
    result = tf.abs(x)
    return result
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@keras_export("keras._legacy.backend.all")
def all(x, axis=None, keepdims=False):
    """DEPRECATED.

    Logical AND reduction over `axis` (all axes when `axis` is None).
    """
    # Cast first so numeric tensors reduce like booleans (0 is False).
    as_bool = tf.cast(x, tf.bool)
    return tf.reduce_all(as_bool, axis, keepdims)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@keras_export("keras._legacy.backend.any")
def any(x, axis=None, keepdims=False):
    """DEPRECATED.

    Logical OR reduction over `axis` (all axes when `axis` is None).
    """
    # Cast first so numeric tensors reduce like booleans (0 is False).
    as_bool = tf.cast(x, tf.bool)
    return tf.reduce_any(as_bool, axis, keepdims)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@keras_export("keras._legacy.backend.argmax")
def argmax(x, axis=-1):
    """DEPRECATED.

    Index of the maximum value along `axis` (last axis by default).
    """
    indices = tf.argmax(x, axis)
    return indices
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@keras_export("keras._legacy.backend.argmin")
def argmin(x, axis=-1):
    """DEPRECATED.

    Index of the minimum value along `axis` (last axis by default).
    """
    indices = tf.argmin(x, axis)
    return indices
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@keras_export("keras._legacy.backend.arange")
def arange(start, stop=None, step=1, dtype="int32"):
    """DEPRECATED.

    Evenly spaced values in `[start, stop)` with the given `step`,
    cast to `dtype` (default "int32").
    """
    # A lone negative `start` (no stop) is clamped to 0, which yields an
    # empty range rather than counting up from a negative bound.
    if stop is None and start < 0:
        start = 0
    values = tf.range(start, limit=stop, delta=step, name="arange")
    if dtype != "int32":
        values = tf.cast(values, dtype)
    return values
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@keras_export("keras._legacy.backend.batch_dot")
def batch_dot(x, y, axes=None):
    """DEPRECATED.

    Batchwise dot product: contracts `x` and `y` over the given `axes`
    while keeping axis 0 as the batch dimension.

    Args:
        x: Tensor of rank >= 2.
        y: Tensor of rank >= 2, same batch size as `x`.
        axes: int or (int, int) — the axis of `x` and of `y` to contract.
            Defaults to the last axis of `x` and the second-to-last of `y`
            (or last, when `y` is rank 2). Axis 0 is not allowed.

    Returns:
        A tensor equal to batched matmul of `x` and `y` after moving the
        contraction axes into matmul position.

    Raises:
        ValueError: on rank < 2 inputs, mismatched batch sizes, tuple/list
            entries inside `axes`, contraction over axis 0, or statically
            incompatible contraction dimensions.
    """
    x_shape = x.shape
    y_shape = y.shape

    x_ndim = len(x_shape)
    y_ndim = len(y_shape)

    if x_ndim < 2 or y_ndim < 2:
        raise ValueError(
            "Cannot do batch_dot on inputs "
            "with rank < 2. "
            "Received inputs with tf.shapes "
            + str(x_shape)
            + " and "
            + str(y_shape)
            + "."
        )

    x_batch_size = x_shape[0]
    y_batch_size = y_shape[0]

    # Batch sizes are only comparable when both are statically known.
    if x_batch_size is not None and y_batch_size is not None:
        if x_batch_size != y_batch_size:
            raise ValueError(
                "Cannot do batch_dot on inputs "
                "with different batch sizes. "
                "Received inputs with tf.shapes "
                + str(x_shape)
                + " and "
                + str(y_shape)
                + "."
            )
    if isinstance(axes, int):
        axes = [axes, axes]

    if axes is None:
        if y_ndim == 2:
            axes = [x_ndim - 1, y_ndim - 1]
        else:
            axes = [x_ndim - 1, y_ndim - 2]

    if py_any(isinstance(a, (list, tuple)) for a in axes):
        raise ValueError(
            "Multiple target dimensions are not supported. "
            + "Expected: None, int, (int, int), "
            + "Provided: "
            + str(axes)
        )

    # if tuple, convert to list.
    axes = list(axes)

    # convert negative indices.
    if axes[0] < 0:
        axes[0] += x_ndim
    if axes[1] < 0:
        axes[1] += y_ndim

    # sanity checks
    if 0 in axes:
        raise ValueError(
            "Cannot perform batch_dot over axis 0. "
            "If your inputs are not batched, "
            "add a dummy batch dimension to your "
            "inputs using K.expand_dims(x, 0)"
        )
    a0, a1 = axes
    d1 = x_shape[a0]
    d2 = y_shape[a1]

    # Static shape check of the two contraction dimensions (skipped when
    # either is dynamic/None).
    if d1 is not None and d2 is not None and d1 != d2:
        raise ValueError(
            "Cannot do batch_dot on inputs with tf.shapes "
            + str(x_shape)
            + " and "
            + str(y_shape)
            + " with axes="
            + str(axes)
            + ". x.shape[%d] != y.shape[%d] (%d != %d)."
            % (axes[0], axes[1], d1, d2)
        )

    # backup ndims. Need them later.
    orig_x_ndim = x_ndim
    orig_y_ndim = y_ndim

    # if rank is 2, expand to 3.
    if x_ndim == 2:
        x = tf.expand_dims(x, 1)
        a0 += 1
        x_ndim += 1
    if y_ndim == 2:
        y = tf.expand_dims(y, 2)
        y_ndim += 1

    # bring x's dimension to be reduced to last axis.
    if a0 != x_ndim - 1:
        pattern = list(range(x_ndim))
        for i in range(a0, x_ndim - 1):
            pattern[i] = pattern[i + 1]
        pattern[-1] = a0
        x = tf.transpose(x, pattern)

    # bring y's dimension to be reduced to axis 1.
    if a1 != 1:
        pattern = list(range(y_ndim))
        for i in range(a1, 1, -1):
            pattern[i] = pattern[i - 1]
        pattern[1] = a1
        y = tf.transpose(y, pattern)

    # normalize both inputs to rank 3.
    if x_ndim > 3:
        # squash middle dimensions of x.
        x_shape = tf.shape(x)
        x_mid_dims = x_shape[1:-1]
        x_squashed_shape = tf.stack([x_shape[0], -1, x_shape[-1]])
        x = tf.reshape(x, x_squashed_shape)
        x_squashed = True
    else:
        x_squashed = False

    if y_ndim > 3:
        # squash trailing dimensions of y.
        y_shape = tf.shape(y)
        y_trail_dims = y_shape[2:]
        y_squashed_shape = tf.stack([y_shape[0], y_shape[1], -1])
        y = tf.reshape(y, y_squashed_shape)
        y_squashed = True
    else:
        y_squashed = False

    result = tf.matmul(x, y)

    # if inputs were squashed, we have to reshape the matmul output.
    output_shape = tf.shape(result)
    do_reshape = False

    if x_squashed:
        output_shape = tf.concat(
            [output_shape[:1], x_mid_dims, output_shape[-1:]], 0
        )
        do_reshape = True

    if y_squashed:
        output_shape = tf.concat([output_shape[:-1], y_trail_dims], 0)
        do_reshape = True

    if do_reshape:
        result = tf.reshape(result, output_shape)

    # if the inputs were originally rank 2, we remove the added 1 dim.
    if orig_x_ndim == 2:
        result = tf.squeeze(result, 1)
    elif orig_y_ndim == 2:
        result = tf.squeeze(result, -1)

    return result
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
@keras_export("keras._legacy.backend.batch_flatten")
def batch_flatten(x):
    """DEPRECATED.

    Collapse every dimension after the batch axis into one, yielding a
    rank-2 tensor of shape `(batch, -1)`.
    """
    flat_shape = tf.stack([-1, prod(tf.shape(x)[1:])])
    return tf.reshape(x, flat_shape)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
@keras_export("keras._legacy.backend.batch_get_value")
def batch_get_value(tensors):
    """DEPRECATED.

    Fetch the values of several tensors/variables as NumPy arrays.
    """
    return [tensor.numpy() for tensor in tensors]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@keras_export("keras._legacy.backend.batch_set_value")
def batch_set_value(tuples):
    """DEPRECATED.

    Assign values to several variables at once. `tuples` is an iterable
    of `(variable, value)` pairs. Silently does nothing outside eager
    mode or a traced function.
    """
    if not (tf.executing_eagerly() or tf.inside_function()):
        return
    for variable, value in tuples:
        # Coerce to the variable's dtype before assignment.
        variable.assign(np.asarray(value, dtype=variable.dtype.name))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@keras_export("keras._legacy.backend.batch_normalization")
def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
    """DEPRECATED.

    Apply batch normalization: `gamma * (x - mean) / sqrt(var + epsilon)
    + beta`.

    NOTE: `axis` is accepted for API compatibility but is not forwarded
    to `tf.nn.batch_normalization`.
    """
    normalized = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
    return normalized
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@keras_export("keras._legacy.backend.bias_add")
def bias_add(x, bias, data_format=None):
    """DEPRECATED.

    Add a bias vector (or bias tensor of rank `ndim(x) - 1`) to `x`,
    honoring the channels-first / channels-last data format.

    Args:
        x: Input tensor.
        bias: Bias to add; rank 1, or rank `ndim(x) - 1`.
        data_format: "channels_first", "channels_last", or None (falls
            back to the global Keras image data format).

    Raises:
        ValueError: on an unknown `data_format` or a bias of
            unexpected rank.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    bias_shape = bias.shape
    if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:
        raise ValueError(
            f"Unexpected bias dimensions {len(bias_shape)}. "
            f"Expected it to be 1 or {ndim(x) - 1} dimensions"
        )

    # Rank-1 bias: delegate to tf.nn.bias_add with the matching layout.
    if len(bias_shape) == 1:
        if data_format == "channels_first":
            return tf.nn.bias_add(x, bias, data_format="NCHW")
        return tf.nn.bias_add(x, bias, data_format="NHWC")
    # Higher-rank bias on 3D/4D/5D inputs: broadcast-add after reshaping
    # the bias to align the channel axis with the data format.
    if ndim(x) in (3, 4, 5):
        if data_format == "channels_first":
            bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1]
            return x + reshape(bias, bias_reshape_axis)
        return x + reshape(bias, (1,) + bias_shape)
    return tf.nn.bias_add(x, bias)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
@keras_export("keras._legacy.backend.binary_crossentropy")
def binary_crossentropy(target, output, from_logits=False):
    """DEPRECATED.

    Element-wise binary crossentropy between `target` and `output`.
    When `from_logits` is True, `output` is interpreted as raw logits;
    otherwise as probabilities in (0, 1).
    """
    target = tf.convert_to_tensor(target)
    output = tf.convert_to_tensor(output)

    if from_logits:
        return tf.nn.sigmoid_cross_entropy_with_logits(
            labels=target, logits=output
        )

    # Clip probabilities away from 0/1 so the logs stay finite.
    eps = tf.convert_to_tensor(backend.epsilon(), output.dtype)
    output = tf.clip_by_value(output, eps, 1.0 - eps)

    # Cross entropy from probabilities (extra epsilon matches the
    # historical Keras implementation).
    positive_term = target * tf.math.log(output + backend.epsilon())
    negative_term = (1 - target) * tf.math.log(1 - output + backend.epsilon())
    return -(positive_term + negative_term)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@keras_export("keras._legacy.backend.binary_focal_crossentropy")
def binary_focal_crossentropy(
    target,
    output,
    apply_class_balancing=False,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
):
    """DEPRECATED.

    Binary focal crossentropy: standard binary crossentropy scaled by a
    focusing factor `(1 - p_t) ** gamma`, optionally alpha-balanced per
    class.
    """
    probs = tf.sigmoid(output) if from_logits else output

    # p_t is the model's probability for the true class of each element.
    p_t = target * probs + (1 - target) * (1 - probs)

    # Focusing factor down-weights well-classified examples.
    modulating_factor = tf.pow(1.0 - p_t, gamma)

    base_loss = binary_crossentropy(
        target=target,
        output=output,
        from_logits=from_logits,
    )
    loss = modulating_factor * base_loss

    if apply_class_balancing:
        # alpha weights the positive class, (1 - alpha) the negative.
        balance = target * alpha + (1 - target) * (1 - alpha)
        loss = balance * loss

    return loss
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
@keras_export("keras._legacy.backend.cast")
def cast(x, dtype):
    """DEPRECATED.

    Cast a tensor to the given dtype.
    """
    result = tf.cast(x, dtype)
    return result
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@keras_export("keras._legacy.backend.cast_to_floatx")
def cast_to_floatx(x):
    """DEPRECATED. Cast `x` to Keras' default float type.

    TF tensors/variables are cast in-graph; anything else is converted
    through NumPy.
    """
    floatx = backend.floatx()
    if not isinstance(x, (tf.Tensor, tf.Variable, tf.SparseTensor)):
        return np.asarray(x, dtype=floatx)
    return tf.cast(x, dtype=floatx)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
@keras_export("keras._legacy.backend.categorical_crossentropy")
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """DEPRECATED."""
    # Cross-entropy between one-hot/soft targets and predictions along `axis`.
    target = tf.convert_to_tensor(target)
    output = tf.convert_to_tensor(output)
    target.shape.assert_is_compatible_with(output.shape)

    if from_logits:
        # Logits path: fused softmax + cross-entropy for stability.
        return tf.nn.softmax_cross_entropy_with_logits(
            labels=target, logits=output, axis=axis
        )

    # Adjust the predictions so that the probability of
    # each class for every sample adds up to 1
    # This is needed to ensure that the cross entropy is
    # computed correctly.
    output = output / tf.reduce_sum(output, axis, True)

    # Compute cross entropy from probabilities.
    epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype)
    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)
    return -tf.reduce_sum(target * tf.math.log(output), axis)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
@keras_export("keras._legacy.backend.categorical_focal_crossentropy")
def categorical_focal_crossentropy(
    target,
    output,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
    axis=-1,
):
    """DEPRECATED."""
    # Focal variant of categorical cross-entropy: down-weights easy classes.
    target = tf.convert_to_tensor(target)
    output = tf.convert_to_tensor(output)
    target.shape.assert_is_compatible_with(output.shape)

    if from_logits:
        output = tf.nn.softmax(output, axis=axis)

    # Adjust the predictions so that the probability of
    # each class for every sample adds up to 1
    # This is needed to ensure that the cross entropy is
    # computed correctly.
    output = output / tf.reduce_sum(output, axis=axis, keepdims=True)

    # Clip away from 0/1 so the log below stays finite.
    epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype)
    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)

    # Calculate cross entropy
    cce = -target * tf.math.log(output)

    # Calculate factors
    modulating_factor = tf.pow(1.0 - output, gamma)
    weighting_factor = tf.multiply(modulating_factor, alpha)

    # Apply weighting factor
    focal_cce = tf.multiply(weighting_factor, cce)
    focal_cce = tf.reduce_sum(focal_cce, axis=axis)
    return focal_cce
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
@keras_export("keras._legacy.backend.clip")
def clip(x, min_value, max_value):
    """DEPRECATED. Element-wise clip of `x` into [min_value, max_value].

    A `None` bound means unbounded on that side. When both bounds are
    plain numbers and the maximum is below the minimum, the maximum is
    raised to the minimum (historical behavior).
    """
    numeric = (int, float)
    if isinstance(min_value, numeric) and isinstance(max_value, numeric):
        if max_value < min_value:
            max_value = min_value
    lower = -np.inf if min_value is None else min_value
    upper = np.inf if max_value is None else max_value
    return tf.clip_by_value(x, lower, upper)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
@keras_export("keras._legacy.backend.concatenate")
def concatenate(tensors, axis=-1):
    """DEPRECATED."""
    # Normalize a negative axis against the rank of the first tensor.
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all(is_sparse(x) for x in tensors):
        # All-sparse inputs can be concatenated without densifying.
        return tf.compat.v1.sparse_concat(axis, tensors)
    elif py_all(isinstance(x, tf.RaggedTensor) for x in tensors):
        return tf.concat(tensors, axis)
    else:
        # Mixed or dense inputs: densify everything first.
        return tf.concat([to_dense(x) for x in tensors], axis)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
@keras_export("keras._legacy.backend.constant")
def constant(value, dtype=None, shape=None, name=None):
    """DEPRECATED. Constant tensor, defaulting dtype to Keras' floatx."""
    resolved_dtype = backend.floatx() if dtype is None else dtype
    return tf.constant(value, dtype=resolved_dtype, shape=shape, name=name)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def _preprocess_conv1d_input(x, data_format):
|
| 446 |
+
tf_data_format = "NWC" # to pass TF Conv2dNative operations
|
| 447 |
+
if data_format == "channels_first":
|
| 448 |
+
tf_data_format = "NCW"
|
| 449 |
+
return x, tf_data_format
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def _preprocess_conv2d_input(x, data_format, force_transpose=False):
|
| 453 |
+
tf_data_format = "NHWC"
|
| 454 |
+
if data_format == "channels_first":
|
| 455 |
+
if force_transpose:
|
| 456 |
+
x = tf.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC
|
| 457 |
+
else:
|
| 458 |
+
tf_data_format = "NCHW"
|
| 459 |
+
return x, tf_data_format
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def _preprocess_conv3d_input(x, data_format):
|
| 463 |
+
tf_data_format = "NDHWC"
|
| 464 |
+
if data_format == "channels_first":
|
| 465 |
+
tf_data_format = "NCDHW"
|
| 466 |
+
return x, tf_data_format
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def _preprocess_padding(padding):
|
| 470 |
+
if padding == "same":
|
| 471 |
+
padding = "SAME"
|
| 472 |
+
elif padding == "valid":
|
| 473 |
+
padding = "VALID"
|
| 474 |
+
else:
|
| 475 |
+
raise ValueError(f"Invalid padding: {padding}")
|
| 476 |
+
return padding
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
@keras_export("keras._legacy.backend.conv1d")
def conv1d(
    x, kernel, strides=1, padding="valid", data_format=None, dilation_rate=1
):
    """DEPRECATED."""
    # 1-D convolution with optional causal (left-only) padding.
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    kernel_shape = kernel.shape.as_list()
    if padding == "causal":
        # causal (dilated) convolution:
        # pad the left side so outputs never see future timesteps.
        left_pad = dilation_rate * (kernel_shape[0] - 1)
        x = temporal_padding(x, (left_pad, 0))
        padding = "valid"
    padding = _preprocess_padding(padding)

    x, tf_data_format = _preprocess_conv1d_input(x, data_format)
    x = tf.compat.v1.nn.convolution(
        input=x,
        filter=kernel,
        dilation_rate=dilation_rate,
        strides=strides,
        padding=padding,
        data_format=tf_data_format,
    )
    if data_format == "channels_first" and tf_data_format == "NWC":
        x = tf.transpose(x, (0, 2, 1))  # NWC -> NCW
    return x
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
@keras_export("keras._legacy.backend.conv2d")
def conv2d(
    x,
    kernel,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
):
    """DEPRECATED."""
    # 2-D convolution; layout translation handled by the helpers above.
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    x, tf_data_format = _preprocess_conv2d_input(x, data_format)
    padding = _preprocess_padding(padding)
    x = tf.compat.v1.nn.convolution(
        input=x,
        filter=kernel,
        dilation_rate=dilation_rate,
        strides=strides,
        padding=padding,
        data_format=tf_data_format,
    )
    if data_format == "channels_first" and tf_data_format == "NHWC":
        # The op ran in NHWC; restore the caller's channels-first layout.
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
@keras_export("keras._legacy.backend.conv2d_transpose")
def conv2d_transpose(
    x,
    kernel,
    output_shape,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
):
    """DEPRECATED."""
    # Transposed (a.k.a. deconvolution) 2-D convolution.
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
    if data_format == "channels_first" and dilation_rate != (1, 1):
        force_transpose = True
    else:
        force_transpose = False

    x, tf_data_format = _preprocess_conv2d_input(
        x, data_format, force_transpose
    )

    if data_format == "channels_first" and tf_data_format == "NHWC":
        # Input was transposed to NHWC, so reorder output_shape to match.
        output_shape = (
            output_shape[0],
            output_shape[2],
            output_shape[3],
            output_shape[1],
        )
    if output_shape[0] is None:
        # Unknown batch size: take it from the input at runtime.
        output_shape = (tf.shape(x)[0],) + tuple(output_shape[1:])

    if isinstance(output_shape, (tuple, list)):
        output_shape = tf.stack(list(output_shape))

    padding = _preprocess_padding(padding)
    if tf_data_format == "NHWC":
        strides = (1,) + strides + (1,)
    else:
        strides = (1, 1) + strides

    if dilation_rate == (1, 1):
        x = tf.compat.v1.nn.conv2d_transpose(
            x,
            kernel,
            output_shape,
            strides,
            padding=padding,
            data_format=tf_data_format,
        )
    else:
        # Dilated path: the atrous op only takes a single rate.
        if dilation_rate[0] != dilation_rate[1]:
            raise ValueError(
                "Expected the 2 dimensions of the `dilation_rate` argument "
                "to be equal to each other. "
                f"Received: dilation_rate={dilation_rate}"
            )
        x = tf.nn.atrous_conv2d_transpose(
            x, kernel, output_shape, rate=dilation_rate[0], padding=padding
        )
    if data_format == "channels_first" and tf_data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
@keras_export("keras._legacy.backend.conv3d")
def conv3d(
    x,
    kernel,
    strides=(1, 1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1, 1),
):
    """DEPRECATED."""
    # 3-D convolution; mirrors conv2d above with volumetric layouts.
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    x, tf_data_format = _preprocess_conv3d_input(x, data_format)
    padding = _preprocess_padding(padding)
    x = tf.compat.v1.nn.convolution(
        input=x,
        filter=kernel,
        dilation_rate=dilation_rate,
        strides=strides,
        padding=padding,
        data_format=tf_data_format,
    )
    if data_format == "channels_first" and tf_data_format == "NDHWC":
        # Restore the caller's channels-first layout.
        x = tf.transpose(x, (0, 4, 1, 2, 3))
    return x
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
@keras_export("keras._legacy.backend.cos")
def cos(x):
    """DEPRECATED. Element-wise cosine."""
    result = tf.cos(x)
    return result
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
@keras_export("keras._legacy.backend.count_params")
def count_params(x):
    """DEPRECATED. Total number of scalar elements in `x`'s static shape."""
    static_shape = x.shape.as_list()
    return np.prod(static_shape)
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
@keras_export("keras._legacy.backend.ctc_batch_cost")
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
    """DEPRECATED."""
    # Per-sample CTC loss; lengths arrive as (batch, 1) and are squeezed.
    label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
    input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
    sparse_labels = tf.cast(
        ctc_label_dense_to_sparse(y_true, label_length), tf.int32
    )

    # ctc_loss wants time-major log-probs; epsilon keeps log finite.
    y_pred = tf.math.log(
        tf.transpose(y_pred, perm=[1, 0, 2]) + backend.epsilon()
    )

    # Re-add the trailing unit axis so the result is (batch, 1).
    return tf.expand_dims(
        tf.compat.v1.nn.ctc_loss(
            inputs=y_pred, labels=sparse_labels, sequence_length=input_length
        ),
        1,
    )
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
@keras_export("keras._legacy.backend.ctc_label_dense_to_sparse")
def ctc_label_dense_to_sparse(labels, label_lengths):
    """DEPRECATED."""
    # Convert a dense (batch, max_len) label matrix into a SparseTensor,
    # keeping only the first label_lengths[i] entries of each row.
    label_shape = tf.shape(labels)
    num_batches_tns = tf.stack([label_shape[0]])
    max_num_labels_tns = tf.stack([label_shape[1]])

    def range_less_than(old_input, current_input):
        # Row mask: True for positions < current row's label length.
        return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill(
            max_num_labels_tns, current_input
        )

    init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool)
    # Build one mask row per batch element via scan over the lengths.
    dense_mask = tf.compat.v1.scan(
        range_less_than, label_lengths, initializer=init, parallel_iterations=1
    )
    dense_mask = dense_mask[:, 0, :]

    # Column indices of the kept labels, flattened.
    label_array = tf.reshape(
        tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape
    )
    label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)

    # Matching row (batch) indices of the kept labels.
    batch_array = tf.transpose(
        tf.reshape(
            tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns),
            reverse(label_shape, 0),
        )
    )
    batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)
    # Stack into (n, 2) [batch, position] index pairs.
    indices = tf.transpose(
        tf.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1])
    )

    vals_sparse = tf.compat.v1.gather_nd(labels, indices)

    return tf.SparseTensor(
        tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64)
    )
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
@keras_export("keras._legacy.backend.ctc_decode")
def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
    """DEPRECATED."""
    # Decode softmax output with either greedy or beam-search CTC decoding.
    input_shape = tf.shape(y_pred)
    num_samples, num_steps = input_shape[0], input_shape[1]
    # Decoders want time-major log-probs; epsilon keeps log finite.
    y_pred = tf.math.log(
        tf.transpose(y_pred, perm=[1, 0, 2]) + backend.epsilon()
    )
    input_length = tf.cast(input_length, tf.int32)

    if greedy:
        (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
            inputs=y_pred, sequence_length=input_length
        )
    else:
        (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
            inputs=y_pred,
            sequence_length=input_length,
            beam_width=beam_width,
            top_paths=top_paths,
        )
    # Densify each decoded path, padding unused positions with -1.
    decoded_dense = []
    for st in decoded:
        st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
        decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
    return (decoded_dense, log_prob)
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
@keras_export("keras._legacy.backend.cumsum")
def cumsum(x, axis=0):
    """DEPRECATED. Cumulative sum along `axis`."""
    summed = tf.cumsum(x, axis=axis)
    return summed
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
@keras_export("keras._legacy.backend.cumprod")
def cumprod(x, axis=0):
    """DEPRECATED. Cumulative product along `axis`."""
    product = tf.math.cumprod(x, axis=axis)
    return product
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
@keras_export("keras._legacy.backend.depthwise_conv2d")
def depthwise_conv2d(
    x,
    depthwise_kernel,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
):
    """DEPRECATED."""
    # Depthwise 2-D convolution: one filter per input channel.
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    x, tf_data_format = _preprocess_conv2d_input(x, data_format)
    padding = _preprocess_padding(padding)
    # tf.nn.depthwise_conv2d takes 4-element strides matching the layout.
    if tf_data_format == "NHWC":
        strides = (1,) + strides + (1,)
    else:
        strides = (1, 1) + strides

    x = tf.nn.depthwise_conv2d(
        x,
        depthwise_kernel,
        strides=strides,
        padding=padding,
        dilations=dilation_rate,
        data_format=tf_data_format,
    )
    if data_format == "channels_first" and tf_data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
@keras_export("keras._legacy.backend.dot")
def dot(x, y):
    """DEPRECATED."""
    # Theano-style dot product: for rank > 2 operands, contracts the last
    # axis of x with the second-to-last axis of y.
    if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
        # Build shapes that mix static dims with dynamic ones so the final
        # reshape works even with unknown dimensions.
        x_shape = []
        for i, s in zip(x.shape, tf.unstack(tf.shape(x))):
            if i is not None:
                x_shape.append(i)
            else:
                x_shape.append(s)
        x_shape = tuple(x_shape)
        y_shape = []
        for i, s in zip(y.shape, tf.unstack(tf.shape(y))):
            if i is not None:
                y_shape.append(i)
            else:
                y_shape.append(s)
        y_shape = tuple(y_shape)
        # Move y's contraction axis (second-to-last) to the front.
        y_permute_dim = list(range(ndim(y)))
        y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
        # Flatten both operands to 2-D, matmul, then restore the shape.
        xt = tf.reshape(x, [-1, x_shape[-1]])
        yt = tf.reshape(tf.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
        return tf.reshape(
            tf.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:]
        )
    if is_sparse(x):
        out = tf.sparse.sparse_dense_matmul(x, y)
    else:
        out = tf.matmul(x, y)
    return out
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
@keras_export("keras._legacy.backend.dropout")
def dropout(x, level, noise_shape=None, seed=None):
    """DEPRECATED. Randomly zero a `level` fraction of entries of `x`."""
    # When no seed is given, derive one from the NumPy RNG state.
    rng_seed = np.random.randint(10e6) if seed is None else seed
    return tf.nn.dropout(x, rate=level, noise_shape=noise_shape, seed=rng_seed)
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
@keras_export("keras._legacy.backend.dtype")
def dtype(x):
    """DEPRECATED. Name of `x`'s base dtype as a string."""
    base = x.dtype.base_dtype
    return base.name
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
@keras_export("keras._legacy.backend.elu")
def elu(x, alpha=1.0):
    """DEPRECATED. Exponential linear unit, optionally scaled by `alpha`."""
    activated = tf.nn.elu(x)
    if alpha == 1:
        return activated
    # Scale only the negative part; positives pass through unchanged.
    return tf.where(x > 0, activated, alpha * activated)
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
@keras_export("keras._legacy.backend.equal")
def equal(x, y):
    """DEPRECATED. Element-wise equality test."""
    comparison = tf.equal(x, y)
    return comparison
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
@keras_export("keras._legacy.backend.eval")
def eval(x):
    """DEPRECATED. Evaluate a tensor to a NumPy value (densifying first)."""
    dense = to_dense(x)
    return get_value(dense)
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
@keras_export("keras._legacy.backend.exp")
def exp(x):
    """DEPRECATED. Element-wise exponential."""
    result = tf.exp(x)
    return result
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
@keras_export("keras._legacy.backend.expand_dims")
def expand_dims(x, axis=-1):
    """DEPRECATED. Insert a length-1 axis at position `axis`."""
    expanded = tf.expand_dims(x, axis)
    return expanded
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
@keras_export("keras._legacy.backend.eye")
def eye(size, dtype=None, name=None):
    """DEPRECATED. Identity matrix wrapped in a Keras variable."""
    resolved = backend.floatx() if dtype is None else dtype
    identity = tf.eye(size, dtype=tf.as_dtype(resolved))
    return variable(identity, resolved, name)
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
@keras_export("keras._legacy.backend.flatten")
def flatten(x):
    """DEPRECATED. Collapse `x` into a rank-1 tensor."""
    flat = tf.reshape(x, [-1])
    return flat
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
@keras_export("keras._legacy.backend.foldl")
def foldl(fn, elems, initializer=None, name=None):
    """DEPRECATED. Left fold of `fn` over `elems`."""
    folded = tf.compat.v1.foldl(fn, elems, initializer=initializer, name=name)
    return folded
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
@keras_export("keras._legacy.backend.foldr")
def foldr(fn, elems, initializer=None, name=None):
    """DEPRECATED. Right fold of `fn` over `elems`."""
    folded = tf.compat.v1.foldr(fn, elems, initializer=initializer, name=name)
    return folded
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
@keras_export("keras._legacy.backend.gather")
def gather(reference, indices):
    """DEPRECATED. Gather slices of `reference` at `indices`."""
    gathered = tf.compat.v1.gather(reference, indices)
    return gathered
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
@keras_export("keras._legacy.backend.get_value")
def get_value(x):
    """DEPRECATED."""
    # Fetch the NumPy value of a tensor/variable, handling eager vs graph.
    if not tf.is_tensor(x):
        # Plain Python/NumPy values pass through untouched.
        return x
    if tf.executing_eagerly() or isinstance(x, tf.__internal__.EagerTensor):
        return x.numpy()
    if not getattr(x, "_in_graph_mode", True):
        # This is a variable which was created in an eager context, but is being
        # evaluated from a Graph.
        with tf.__internal__.eager_context.eager_mode():
            return x.numpy()
    with tf.init_scope():
        return x.numpy()
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
@keras_export("keras._legacy.backend.gradients")
def gradients(loss, variables):
    """DEPRECATED. Symbolic gradients of `loss` w.r.t. `variables`."""
    grads = tf.compat.v1.gradients(
        loss, variables, colocate_gradients_with_ops=True
    )
    return grads
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
@keras_export("keras._legacy.backend.greater")
def greater(x, y):
    """DEPRECATED. Element-wise x > y."""
    comparison = tf.greater(x, y)
    return comparison
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
@keras_export("keras._legacy.backend.greater_equal")
def greater_equal(x, y):
    """DEPRECATED. Element-wise x >= y."""
    comparison = tf.greater_equal(x, y)
    return comparison
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
@keras_export("keras._legacy.backend.hard_sigmoid")
def hard_sigmoid(x):
    """DEPRECATED. Piecewise-linear sigmoid: clip(0.2 * x + 0.5, 0, 1)."""
    slope = tf.convert_to_tensor(0.2, dtype=x.dtype)
    offset = tf.convert_to_tensor(0.5, dtype=x.dtype)
    scaled = tf.add(tf.multiply(x, slope), offset)
    return tf.clip_by_value(scaled, 0.0, 1.0)
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
@keras_export("keras._legacy.backend.in_top_k")
def in_top_k(predictions, targets, k):
    """DEPRECATED. Whether each target is among the top-`k` predictions."""
    membership = tf.compat.v1.math.in_top_k(predictions, targets, k)
    return membership
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
@keras_export("keras._legacy.backend.int_shape")
def int_shape(x):
    """DEPRECATED. Static shape as a tuple (None for unknown dims).

    Returns None when the tensor has no known rank.
    """
    try:
        shape = x.shape
        if isinstance(shape, tuple):
            return shape
        # `as_list()` raises ValueError for fully-unknown shapes.
        return tuple(shape.as_list())
    except ValueError:
        return None
|
| 965 |
+
|
| 966 |
+
|
| 967 |
+
@keras_export("keras._legacy.backend.is_sparse")
def is_sparse(tensor):
    """DEPRECATED. Whether `tensor` is (or is specced as) a SparseTensor."""
    spec = getattr(tensor, "_type_spec", None)
    if spec is None:
        return isinstance(tensor, tf.SparseTensor)
    # Composite tensors expose their kind through the type spec.
    return isinstance(spec, tf.SparseTensorSpec)
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
@keras_export("keras._legacy.backend.l2_normalize")
def l2_normalize(x, axis=None):
    """DEPRECATED. L2-normalize `x` along `axis`."""
    normalized = tf.linalg.l2_normalize(x, axis=axis)
    return normalized
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
@keras_export("keras._legacy.backend.less")
def less(x, y):
    """DEPRECATED. Element-wise x < y."""
    comparison = tf.less(x, y)
    return comparison
|
| 986 |
+
|
| 987 |
+
|
| 988 |
+
@keras_export("keras._legacy.backend.less_equal")
def less_equal(x, y):
    """DEPRECATED. Element-wise x <= y."""
    comparison = tf.less_equal(x, y)
    return comparison
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
@keras_export("keras._legacy.backend.log")
def log(x):
    """DEPRECATED. Element-wise natural logarithm."""
    result = tf.math.log(x)
    return result
|
| 998 |
+
|
| 999 |
+
|
| 1000 |
+
@keras_export("keras._legacy.backend.map_fn")
def map_fn(fn, elems, name=None, dtype=None):
    """DEPRECATED. Map `fn` over the first axis of `elems`."""
    mapped = tf.compat.v1.map_fn(fn, elems, name=name, dtype=dtype)
    return mapped
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
@keras_export("keras._legacy.backend.max")
def max(x, axis=None, keepdims=False):
    """DEPRECATED. Maximum of `x` along `axis`."""
    reduced = tf.reduce_max(x, axis, keepdims)
    return reduced
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
@keras_export("keras._legacy.backend.maximum")
def maximum(x, y):
    """DEPRECATED. Element-wise maximum of two tensors."""
    result = tf.maximum(x, y)
    return result
|
| 1016 |
+
|
| 1017 |
+
|
| 1018 |
+
@keras_export("keras._legacy.backend.mean")
def mean(x, axis=None, keepdims=False):
    """DEPRECATED. Mean of `x`, casting booleans to floatx first."""
    values = x
    if values.dtype.base_dtype == tf.bool:
        # A mean over booleans is meaningless; treat them as 0/1 floats.
        values = tf.cast(values, backend.floatx())
    return tf.reduce_mean(values, axis, keepdims)
|
| 1024 |
+
|
| 1025 |
+
|
| 1026 |
+
@keras_export("keras._legacy.backend.min")
def min(x, axis=None, keepdims=False):
    """DEPRECATED. Minimum of `x` along `axis`."""
    reduced = tf.reduce_min(x, axis, keepdims)
    return reduced
|
| 1030 |
+
|
| 1031 |
+
|
| 1032 |
+
@keras_export("keras._legacy.backend.minimum")
def minimum(x, y):
    """DEPRECATED. Element-wise minimum of two tensors."""
    result = tf.minimum(x, y)
    return result
|
| 1036 |
+
|
| 1037 |
+
|
| 1038 |
+
@keras_export("keras._legacy.backend.moving_average_update")
def moving_average_update(x, value, momentum):
    """DEPRECATED. In-place EMA update: x -= (x - value) * (1 - momentum)."""
    decay = tf.cast(momentum, x.dtype)
    target = tf.cast(value, x.dtype)
    delta = (x - target) * (1 - decay)
    return x.assign_sub(delta)
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
@keras_export("keras._legacy.backend.name_scope")
def name_scope(name):
    """DEPRECATED. Context manager that prefixes op names with `name`."""
    scope = tf.name_scope(name)
    return scope
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
@keras_export("keras._legacy.backend.ndim")
def ndim(x):
    """DEPRECATED. Static rank of `x` (None if unknown)."""
    rank = x.shape.rank
    return rank
|
| 1056 |
+
|
| 1057 |
+
|
| 1058 |
+
@keras_export("keras._legacy.backend.not_equal")
def not_equal(x, y):
    """DEPRECATED. Element-wise inequality test."""
    comparison = tf.not_equal(x, y)
    return comparison
|
| 1062 |
+
|
| 1063 |
+
|
| 1064 |
+
@keras_export("keras._legacy.backend.one_hot")
def one_hot(indices, num_classes):
    """DEPRECATED. One-hot encode `indices` on a trailing class axis."""
    encoded = tf.one_hot(indices, depth=num_classes, axis=-1)
    return encoded
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
@keras_export("keras._legacy.backend.ones")
def ones(shape, dtype=None, name=None):
    """DEPRECATED. All-ones Keras variable (plain tensor if shape is not
    fully defined)."""
    with tf.init_scope():
        resolved = backend.floatx() if dtype is None else dtype
        filled = tf.ones(shape=shape, dtype=tf.as_dtype(resolved), name=name)
        if not py_all(filled.shape.as_list()):
            # Unknown or zero-sized dimension: cannot wrap in a variable.
            return filled
        return variable(filled, dtype=resolved, name=name)
|
| 1081 |
+
|
| 1082 |
+
|
| 1083 |
+
@keras_export("keras._legacy.backend.ones_like")
def ones_like(x, dtype=None, name=None):
    """DEPRECATED. Ones tensor with the same shape as `x`."""
    filled = tf.ones_like(x, dtype=dtype, name=name)
    return filled
|
| 1087 |
+
|
| 1088 |
+
|
| 1089 |
+
@keras_export("keras._legacy.backend.permute_dimensions")
def permute_dimensions(x, pattern):
    """DEPRECATED. Transpose `x` according to the axis `pattern`."""
    permuted = tf.transpose(x, perm=pattern)
    return permuted
|
| 1093 |
+
|
| 1094 |
+
|
| 1095 |
+
@keras_export("keras._legacy.backend.pool2d")
def pool2d(
    x,
    pool_size,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    pool_mode="max",
):
    """DEPRECATED."""
    # 2-D max/average pooling with Keras-style layout handling.
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    if len(pool_size) != 2:
        raise ValueError("`pool_size` must be a tuple of 2 integers.")
    if len(strides) != 2:
        raise ValueError("`strides` must be a tuple of 2 integers.")

    x, tf_data_format = _preprocess_conv2d_input(x, data_format)
    padding = _preprocess_padding(padding)
    # TF pooling ops take 4-element strides/windows matching the layout.
    if tf_data_format == "NHWC":
        strides = (1,) + strides + (1,)
        pool_size = (1,) + pool_size + (1,)
    else:
        strides = (1, 1) + strides
        pool_size = (1, 1) + pool_size

    if pool_mode == "max":
        x = tf.compat.v1.nn.max_pool(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    elif pool_mode == "avg":
        x = tf.compat.v1.nn.avg_pool(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    else:
        raise ValueError("Invalid pooling mode: " + str(pool_mode))

    if data_format == "channels_first" and tf_data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
|
| 1137 |
+
|
| 1138 |
+
|
| 1139 |
+
@keras_export("keras._legacy.backend.pool3d")
def pool3d(
    x,
    pool_size,
    strides=(1, 1, 1),
    padding="valid",
    data_format=None,
    pool_mode="max",
):
    """DEPRECATED."""
    # 3-D max/average pooling; mirrors pool2d with volumetric layouts.
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    x, tf_data_format = _preprocess_conv3d_input(x, data_format)
    padding = _preprocess_padding(padding)
    # TF pooling ops take 5-element strides/windows matching the layout.
    if tf_data_format == "NDHWC":
        strides = (1,) + strides + (1,)
        pool_size = (1,) + pool_size + (1,)
    else:
        strides = (1, 1) + strides
        pool_size = (1, 1) + pool_size

    if pool_mode == "max":
        x = tf.nn.max_pool3d(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    elif pool_mode == "avg":
        x = tf.nn.avg_pool3d(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    else:
        raise ValueError("Invalid pooling mode: " + str(pool_mode))

    if data_format == "channels_first" and tf_data_format == "NDHWC":
        x = tf.transpose(x, (0, 4, 1, 2, 3))
    return x
|
| 1177 |
+
|
| 1178 |
+
|
| 1179 |
+
@keras_export("keras._legacy.backend.pow")
def pow(x, a):
    """DEPRECATED. Element-wise power x ** a."""
    result = tf.pow(x, a)
    return result
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
@keras_export("keras._legacy.backend.prod")
def prod(x, axis=None, keepdims=False):
    """DEPRECATED. Product of `x` along `axis`."""
    reduced = tf.reduce_prod(x, axis, keepdims)
    return reduced
|
| 1189 |
+
|
| 1190 |
+
|
| 1191 |
+
@keras_export("keras._legacy.backend.random_bernoulli")
def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
    """DEPRECATED. Tensor of ones (with probability `p`) and zeros."""
    if dtype is None:
        dtype = backend.floatx()
    if seed is None:
        # Derive a seed from the NumPy RNG state when none is given.
        seed = np.random.randint(10e6)
    draws = tf.random.uniform(shape, dtype=dtype, seed=seed)
    hits = tf.ones(shape, dtype=dtype)
    misses = tf.zeros(shape, dtype=dtype)
    return tf.where(draws <= p, hits, misses)
|
| 1203 |
+
|
| 1204 |
+
|
| 1205 |
+
@keras_export("keras._legacy.backend.random_normal")
|
| 1206 |
+
def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
|
| 1207 |
+
"""DEPRECATED."""
|
| 1208 |
+
if dtype is None:
|
| 1209 |
+
dtype = backend.floatx()
|
| 1210 |
+
if seed is None:
|
| 1211 |
+
seed = np.random.randint(10e6)
|
| 1212 |
+
return tf.random.normal(
|
| 1213 |
+
shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
|
| 1214 |
+
)
|
| 1215 |
+
|
| 1216 |
+
|
| 1217 |
+
@keras_export("keras._legacy.backend.random_normal_variable")
|
| 1218 |
+
def random_normal_variable(
|
| 1219 |
+
shape, mean, scale, dtype=None, name=None, seed=None
|
| 1220 |
+
):
|
| 1221 |
+
"""DEPRECATED."""
|
| 1222 |
+
if dtype is None:
|
| 1223 |
+
dtype = backend.floatx()
|
| 1224 |
+
tf_dtype = tf.as_dtype(dtype)
|
| 1225 |
+
if seed is None:
|
| 1226 |
+
# ensure that randomness is conditioned by the Numpy RNG
|
| 1227 |
+
seed = np.random.randint(10e8)
|
| 1228 |
+
value = tf.compat.v1.random_normal_initializer(
|
| 1229 |
+
mean, scale, dtype=tf_dtype, seed=seed
|
| 1230 |
+
)(shape)
|
| 1231 |
+
return variable(value, dtype=dtype, name=name)
|
| 1232 |
+
|
| 1233 |
+
|
| 1234 |
+
@keras_export("keras._legacy.backend.random_uniform")
|
| 1235 |
+
def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
|
| 1236 |
+
"""DEPRECATED."""
|
| 1237 |
+
if dtype is None:
|
| 1238 |
+
dtype = backend.floatx()
|
| 1239 |
+
if seed is None:
|
| 1240 |
+
seed = np.random.randint(10e6)
|
| 1241 |
+
return tf.random.uniform(
|
| 1242 |
+
shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed
|
| 1243 |
+
)
|
| 1244 |
+
|
| 1245 |
+
|
| 1246 |
+
@keras_export("keras._legacy.backend.random_uniform_variable")
|
| 1247 |
+
def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
|
| 1248 |
+
"""DEPRECATED."""
|
| 1249 |
+
if dtype is None:
|
| 1250 |
+
dtype = backend.floatx()
|
| 1251 |
+
tf_dtype = tf.as_dtype(dtype)
|
| 1252 |
+
if seed is None:
|
| 1253 |
+
# ensure that randomness is conditioned by the Numpy RNG
|
| 1254 |
+
seed = np.random.randint(10e8)
|
| 1255 |
+
value = tf.compat.v1.random_uniform_initializer(
|
| 1256 |
+
low, high, dtype=tf_dtype, seed=seed
|
| 1257 |
+
)(shape)
|
| 1258 |
+
return variable(value, dtype=dtype, name=name)
|
| 1259 |
+
|
| 1260 |
+
|
| 1261 |
+
@keras_export("keras._legacy.backend.reshape")
|
| 1262 |
+
def reshape(x, shape):
|
| 1263 |
+
"""DEPRECATED."""
|
| 1264 |
+
return tf.reshape(x, shape)
|
| 1265 |
+
|
| 1266 |
+
|
| 1267 |
+
@keras_export("keras._legacy.backend.relu")
|
| 1268 |
+
def relu(x, alpha=0.0, max_value=None, threshold=0.0):
|
| 1269 |
+
"""DEPRECATED."""
|
| 1270 |
+
# While x can be a tensor or variable, we also see cases where
|
| 1271 |
+
# numpy arrays, lists, tuples are passed as well.
|
| 1272 |
+
# lists, tuples do not have 'dtype' attribute.
|
| 1273 |
+
dtype = getattr(x, "dtype", backend.floatx())
|
| 1274 |
+
if alpha != 0.0:
|
| 1275 |
+
if max_value is None and threshold == 0:
|
| 1276 |
+
return tf.nn.leaky_relu(x, alpha=alpha)
|
| 1277 |
+
|
| 1278 |
+
if threshold != 0:
|
| 1279 |
+
negative_part = tf.nn.relu(-x + threshold)
|
| 1280 |
+
else:
|
| 1281 |
+
negative_part = tf.nn.relu(-x)
|
| 1282 |
+
else:
|
| 1283 |
+
negative_part = 1
|
| 1284 |
+
|
| 1285 |
+
clip_max = max_value is not None
|
| 1286 |
+
|
| 1287 |
+
if threshold != 0:
|
| 1288 |
+
# computes x for x > threshold else 0
|
| 1289 |
+
x = x * tf.cast(tf.greater(x, threshold), dtype=dtype)
|
| 1290 |
+
elif max_value == 6:
|
| 1291 |
+
# if no threshold, then can use nn.relu6 native TF op for performance
|
| 1292 |
+
x = tf.nn.relu6(x)
|
| 1293 |
+
clip_max = False
|
| 1294 |
+
else:
|
| 1295 |
+
x = tf.nn.relu(x)
|
| 1296 |
+
|
| 1297 |
+
if clip_max:
|
| 1298 |
+
max_value = tf.convert_to_tensor(max_value, dtype=x.dtype)
|
| 1299 |
+
zero = tf.convert_to_tensor(0, dtype=x.dtype)
|
| 1300 |
+
x = tf.clip_by_value(x, zero, max_value)
|
| 1301 |
+
|
| 1302 |
+
if alpha != 0.0:
|
| 1303 |
+
alpha = tf.convert_to_tensor(alpha, dtype=x.dtype)
|
| 1304 |
+
x -= alpha * negative_part
|
| 1305 |
+
return x
|
| 1306 |
+
|
| 1307 |
+
|
| 1308 |
+
@keras_export("keras._legacy.backend.repeat")
|
| 1309 |
+
def repeat(x, n):
|
| 1310 |
+
"""DEPRECATED."""
|
| 1311 |
+
assert ndim(x) == 2
|
| 1312 |
+
x = tf.expand_dims(x, 1)
|
| 1313 |
+
pattern = tf.stack([1, n, 1])
|
| 1314 |
+
return tf.tile(x, pattern)
|
| 1315 |
+
|
| 1316 |
+
|
| 1317 |
+
@keras_export("keras._legacy.backend.repeat_elements")
|
| 1318 |
+
def repeat_elements(x, rep, axis):
|
| 1319 |
+
"""DEPRECATED."""
|
| 1320 |
+
x_shape = x.shape.as_list()
|
| 1321 |
+
# For static axis
|
| 1322 |
+
if x_shape[axis] is not None:
|
| 1323 |
+
# slices along the repeat axis
|
| 1324 |
+
splits = tf.split(value=x, num_or_size_splits=x_shape[axis], axis=axis)
|
| 1325 |
+
# repeat each slice the given number of reps
|
| 1326 |
+
x_rep = [s for s in splits for _ in range(rep)]
|
| 1327 |
+
return concatenate(x_rep, axis)
|
| 1328 |
+
|
| 1329 |
+
# Here we use tf.tile to mimic behavior of np.repeat so that
|
| 1330 |
+
# we can handle dynamic shapes (that include None).
|
| 1331 |
+
# To do that, we need an auxiliary axis to repeat elements along
|
| 1332 |
+
# it and then merge them along the desired axis.
|
| 1333 |
+
|
| 1334 |
+
# Repeating
|
| 1335 |
+
auxiliary_axis = axis + 1
|
| 1336 |
+
x_shape = tf.shape(x)
|
| 1337 |
+
x_rep = tf.expand_dims(x, axis=auxiliary_axis)
|
| 1338 |
+
reps = np.ones(len(x.shape) + 1)
|
| 1339 |
+
reps[auxiliary_axis] = rep
|
| 1340 |
+
x_rep = tf.tile(x_rep, reps)
|
| 1341 |
+
|
| 1342 |
+
# Merging
|
| 1343 |
+
reps = np.delete(reps, auxiliary_axis)
|
| 1344 |
+
reps[axis] = rep
|
| 1345 |
+
reps = tf.constant(reps, dtype="int32")
|
| 1346 |
+
x_shape *= reps
|
| 1347 |
+
x_rep = tf.reshape(x_rep, x_shape)
|
| 1348 |
+
|
| 1349 |
+
# Fix shape representation
|
| 1350 |
+
x_shape = x.shape.as_list()
|
| 1351 |
+
x_rep.set_shape(x_shape)
|
| 1352 |
+
return x_rep
|
| 1353 |
+
|
| 1354 |
+
|
| 1355 |
+
@keras_export("keras._legacy.backend.resize_images")
|
| 1356 |
+
def resize_images(
|
| 1357 |
+
x, height_factor, width_factor, data_format, interpolation="nearest"
|
| 1358 |
+
):
|
| 1359 |
+
"""DEPRECATED."""
|
| 1360 |
+
if data_format == "channels_first":
|
| 1361 |
+
rows, cols = 2, 3
|
| 1362 |
+
elif data_format == "channels_last":
|
| 1363 |
+
rows, cols = 1, 2
|
| 1364 |
+
else:
|
| 1365 |
+
raise ValueError(f"Invalid `data_format` argument: {data_format}")
|
| 1366 |
+
|
| 1367 |
+
new_shape = x.shape[rows : cols + 1]
|
| 1368 |
+
if new_shape.is_fully_defined():
|
| 1369 |
+
new_shape = tf.constant(new_shape.as_list(), dtype="int32")
|
| 1370 |
+
else:
|
| 1371 |
+
new_shape = tf.shape(x)[rows : cols + 1]
|
| 1372 |
+
new_shape *= tf.constant(
|
| 1373 |
+
np.array([height_factor, width_factor], dtype="int32")
|
| 1374 |
+
)
|
| 1375 |
+
|
| 1376 |
+
if data_format == "channels_first":
|
| 1377 |
+
x = permute_dimensions(x, [0, 2, 3, 1])
|
| 1378 |
+
interpolations = {
|
| 1379 |
+
"area": tf.image.ResizeMethod.AREA,
|
| 1380 |
+
"bicubic": tf.image.ResizeMethod.BICUBIC,
|
| 1381 |
+
"bilinear": tf.image.ResizeMethod.BILINEAR,
|
| 1382 |
+
"gaussian": tf.image.ResizeMethod.GAUSSIAN,
|
| 1383 |
+
"lanczos3": tf.image.ResizeMethod.LANCZOS3,
|
| 1384 |
+
"lanczos5": tf.image.ResizeMethod.LANCZOS5,
|
| 1385 |
+
"mitchellcubic": tf.image.ResizeMethod.MITCHELLCUBIC,
|
| 1386 |
+
"nearest": tf.image.ResizeMethod.NEAREST_NEIGHBOR,
|
| 1387 |
+
}
|
| 1388 |
+
interploations_list = '"' + '", "'.join(interpolations.keys()) + '"'
|
| 1389 |
+
if interpolation in interpolations:
|
| 1390 |
+
x = tf.image.resize(x, new_shape, method=interpolations[interpolation])
|
| 1391 |
+
else:
|
| 1392 |
+
raise ValueError(
|
| 1393 |
+
"`interpolation` argument should be one of: "
|
| 1394 |
+
f'{interploations_list}. Received: "{interpolation}".'
|
| 1395 |
+
)
|
| 1396 |
+
if data_format == "channels_first":
|
| 1397 |
+
x = permute_dimensions(x, [0, 3, 1, 2])
|
| 1398 |
+
|
| 1399 |
+
return x
|
| 1400 |
+
|
| 1401 |
+
|
| 1402 |
+
@keras_export("keras._legacy.backend.resize_volumes")
|
| 1403 |
+
def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
|
| 1404 |
+
"""DEPRECATED."""
|
| 1405 |
+
if data_format == "channels_first":
|
| 1406 |
+
output = repeat_elements(x, depth_factor, axis=2)
|
| 1407 |
+
output = repeat_elements(output, height_factor, axis=3)
|
| 1408 |
+
output = repeat_elements(output, width_factor, axis=4)
|
| 1409 |
+
return output
|
| 1410 |
+
elif data_format == "channels_last":
|
| 1411 |
+
output = repeat_elements(x, depth_factor, axis=1)
|
| 1412 |
+
output = repeat_elements(output, height_factor, axis=2)
|
| 1413 |
+
output = repeat_elements(output, width_factor, axis=3)
|
| 1414 |
+
return output
|
| 1415 |
+
else:
|
| 1416 |
+
raise ValueError(f"Invalid data_format: {data_format}")
|
| 1417 |
+
|
| 1418 |
+
|
| 1419 |
+
@keras_export("keras._legacy.backend.reverse")
|
| 1420 |
+
def reverse(x, axes):
|
| 1421 |
+
"""DEPRECATED."""
|
| 1422 |
+
if isinstance(axes, int):
|
| 1423 |
+
axes = [axes]
|
| 1424 |
+
return tf.reverse(x, axes)
|
| 1425 |
+
|
| 1426 |
+
|
| 1427 |
+
@keras_export("keras._legacy.backend.rnn")
|
| 1428 |
+
def rnn(
|
| 1429 |
+
step_function,
|
| 1430 |
+
inputs,
|
| 1431 |
+
initial_states,
|
| 1432 |
+
go_backwards=False,
|
| 1433 |
+
mask=None,
|
| 1434 |
+
constants=None,
|
| 1435 |
+
unroll=False,
|
| 1436 |
+
input_length=None,
|
| 1437 |
+
time_major=False,
|
| 1438 |
+
zero_output_for_mask=False,
|
| 1439 |
+
return_all_outputs=True,
|
| 1440 |
+
):
|
| 1441 |
+
"""DEPRECATED."""
|
| 1442 |
+
if not tf.__internal__.tf2.enabled():
|
| 1443 |
+
return_all_outputs = True # Not supported in TF1.
|
| 1444 |
+
|
| 1445 |
+
def swap_batch_timestep(input_t):
|
| 1446 |
+
# Swap the batch and timestep dim for the incoming tensor.
|
| 1447 |
+
axes = list(range(len(input_t.shape)))
|
| 1448 |
+
axes[0], axes[1] = 1, 0
|
| 1449 |
+
return tf.transpose(input_t, axes)
|
| 1450 |
+
|
| 1451 |
+
if not time_major:
|
| 1452 |
+
inputs = tf.nest.map_structure(swap_batch_timestep, inputs)
|
| 1453 |
+
|
| 1454 |
+
flatted_inputs = tf.nest.flatten(inputs)
|
| 1455 |
+
time_steps = flatted_inputs[0].shape[0]
|
| 1456 |
+
batch = flatted_inputs[0].shape[1]
|
| 1457 |
+
time_steps_t = tf.shape(flatted_inputs[0])[0]
|
| 1458 |
+
|
| 1459 |
+
for input_ in flatted_inputs:
|
| 1460 |
+
input_.shape.with_rank_at_least(3)
|
| 1461 |
+
|
| 1462 |
+
if mask is not None:
|
| 1463 |
+
if mask.dtype != tf.bool:
|
| 1464 |
+
mask = tf.cast(mask, tf.bool)
|
| 1465 |
+
if len(mask.shape) == 2:
|
| 1466 |
+
mask = expand_dims(mask)
|
| 1467 |
+
if not time_major:
|
| 1468 |
+
mask = swap_batch_timestep(mask)
|
| 1469 |
+
|
| 1470 |
+
if constants is None:
|
| 1471 |
+
constants = []
|
| 1472 |
+
|
| 1473 |
+
# tf.where needs its condition tensor to be the same shape as its two
|
| 1474 |
+
# result tensors, but in our case the condition (mask) tensor is
|
| 1475 |
+
# (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
|
| 1476 |
+
# So we need to broadcast the mask to match the shape of inputs.
|
| 1477 |
+
# That's what the tile call does, it just repeats the mask along its
|
| 1478 |
+
# second dimension n times.
|
| 1479 |
+
def _expand_mask(mask_t, input_t, fixed_dim=1):
|
| 1480 |
+
if tf.nest.is_nested(mask_t):
|
| 1481 |
+
raise ValueError(
|
| 1482 |
+
f"mask_t is expected to be tensor, but got {mask_t}"
|
| 1483 |
+
)
|
| 1484 |
+
if tf.nest.is_nested(input_t):
|
| 1485 |
+
raise ValueError(
|
| 1486 |
+
f"input_t is expected to be tensor, but got {input_t}"
|
| 1487 |
+
)
|
| 1488 |
+
rank_diff = len(input_t.shape) - len(mask_t.shape)
|
| 1489 |
+
for _ in range(rank_diff):
|
| 1490 |
+
mask_t = tf.expand_dims(mask_t, -1)
|
| 1491 |
+
multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
|
| 1492 |
+
return tf.tile(mask_t, multiples)
|
| 1493 |
+
|
| 1494 |
+
if unroll:
|
| 1495 |
+
if not time_steps:
|
| 1496 |
+
raise ValueError("Unrolling requires a fixed number of timesteps.")
|
| 1497 |
+
states = tuple(initial_states)
|
| 1498 |
+
successive_states = []
|
| 1499 |
+
successive_outputs = []
|
| 1500 |
+
|
| 1501 |
+
# Process the input tensors. The input tensor need to be split on the
|
| 1502 |
+
# time_step dim, and reverse if go_backwards is True. In the case of
|
| 1503 |
+
# nested input, the input is flattened and then transformed
|
| 1504 |
+
# individually. The result of this will be a tuple of lists, each of
|
| 1505 |
+
# the item in tuple is list of the tensor with shape (batch, feature)
|
| 1506 |
+
def _process_single_input_t(input_t):
|
| 1507 |
+
input_t = tf.unstack(input_t) # unstack for time_step dim
|
| 1508 |
+
if go_backwards:
|
| 1509 |
+
input_t.reverse()
|
| 1510 |
+
return input_t
|
| 1511 |
+
|
| 1512 |
+
if tf.nest.is_nested(inputs):
|
| 1513 |
+
processed_input = tf.nest.map_structure(
|
| 1514 |
+
_process_single_input_t, inputs
|
| 1515 |
+
)
|
| 1516 |
+
else:
|
| 1517 |
+
processed_input = (_process_single_input_t(inputs),)
|
| 1518 |
+
|
| 1519 |
+
def _get_input_tensor(time):
|
| 1520 |
+
inp = [t_[time] for t_ in processed_input]
|
| 1521 |
+
return tf.nest.pack_sequence_as(inputs, inp)
|
| 1522 |
+
|
| 1523 |
+
if mask is not None:
|
| 1524 |
+
mask_list = tf.unstack(mask)
|
| 1525 |
+
if go_backwards:
|
| 1526 |
+
mask_list.reverse()
|
| 1527 |
+
|
| 1528 |
+
for i in range(time_steps):
|
| 1529 |
+
inp = _get_input_tensor(i)
|
| 1530 |
+
mask_t = mask_list[i]
|
| 1531 |
+
output, new_states = step_function(
|
| 1532 |
+
inp, tuple(states) + tuple(constants)
|
| 1533 |
+
)
|
| 1534 |
+
tiled_mask_t = _expand_mask(mask_t, output)
|
| 1535 |
+
|
| 1536 |
+
if not successive_outputs:
|
| 1537 |
+
prev_output = zeros_like(output)
|
| 1538 |
+
else:
|
| 1539 |
+
prev_output = successive_outputs[-1]
|
| 1540 |
+
|
| 1541 |
+
output = tf.where(tiled_mask_t, output, prev_output)
|
| 1542 |
+
|
| 1543 |
+
flat_states = tf.nest.flatten(states)
|
| 1544 |
+
flat_new_states = tf.nest.flatten(new_states)
|
| 1545 |
+
tiled_mask_t = tuple(
|
| 1546 |
+
_expand_mask(mask_t, s) for s in flat_states
|
| 1547 |
+
)
|
| 1548 |
+
flat_final_states = tuple(
|
| 1549 |
+
tf.where(m, s, ps)
|
| 1550 |
+
for m, s, ps in zip(
|
| 1551 |
+
tiled_mask_t, flat_new_states, flat_states
|
| 1552 |
+
)
|
| 1553 |
+
)
|
| 1554 |
+
states = tf.nest.pack_sequence_as(states, flat_final_states)
|
| 1555 |
+
|
| 1556 |
+
if return_all_outputs:
|
| 1557 |
+
successive_outputs.append(output)
|
| 1558 |
+
successive_states.append(states)
|
| 1559 |
+
else:
|
| 1560 |
+
successive_outputs = [output]
|
| 1561 |
+
successive_states = [states]
|
| 1562 |
+
last_output = successive_outputs[-1]
|
| 1563 |
+
new_states = successive_states[-1]
|
| 1564 |
+
outputs = tf.stack(successive_outputs)
|
| 1565 |
+
|
| 1566 |
+
if zero_output_for_mask:
|
| 1567 |
+
last_output = tf.where(
|
| 1568 |
+
_expand_mask(mask_list[-1], last_output),
|
| 1569 |
+
last_output,
|
| 1570 |
+
zeros_like(last_output),
|
| 1571 |
+
)
|
| 1572 |
+
outputs = tf.where(
|
| 1573 |
+
_expand_mask(mask, outputs, fixed_dim=2),
|
| 1574 |
+
outputs,
|
| 1575 |
+
zeros_like(outputs),
|
| 1576 |
+
)
|
| 1577 |
+
|
| 1578 |
+
else: # mask is None
|
| 1579 |
+
for i in range(time_steps):
|
| 1580 |
+
inp = _get_input_tensor(i)
|
| 1581 |
+
output, states = step_function(
|
| 1582 |
+
inp, tuple(states) + tuple(constants)
|
| 1583 |
+
)
|
| 1584 |
+
if return_all_outputs:
|
| 1585 |
+
successive_outputs.append(output)
|
| 1586 |
+
successive_states.append(states)
|
| 1587 |
+
else:
|
| 1588 |
+
successive_outputs = [output]
|
| 1589 |
+
successive_states = [states]
|
| 1590 |
+
last_output = successive_outputs[-1]
|
| 1591 |
+
new_states = successive_states[-1]
|
| 1592 |
+
outputs = tf.stack(successive_outputs)
|
| 1593 |
+
|
| 1594 |
+
else: # Unroll == False
|
| 1595 |
+
states = tuple(initial_states)
|
| 1596 |
+
|
| 1597 |
+
# Create input tensor array, if the inputs is nested tensors, then it
|
| 1598 |
+
# will be flattened first, and tensor array will be created one per
|
| 1599 |
+
# flattened tensor.
|
| 1600 |
+
input_ta = tuple(
|
| 1601 |
+
tf.TensorArray(
|
| 1602 |
+
dtype=inp.dtype,
|
| 1603 |
+
size=time_steps_t,
|
| 1604 |
+
tensor_array_name=f"input_ta_{i}",
|
| 1605 |
+
)
|
| 1606 |
+
for i, inp in enumerate(flatted_inputs)
|
| 1607 |
+
)
|
| 1608 |
+
input_ta = tuple(
|
| 1609 |
+
(
|
| 1610 |
+
ta.unstack(input_)
|
| 1611 |
+
if not go_backwards
|
| 1612 |
+
else ta.unstack(reverse(input_, 0))
|
| 1613 |
+
)
|
| 1614 |
+
for ta, input_ in zip(input_ta, flatted_inputs)
|
| 1615 |
+
)
|
| 1616 |
+
|
| 1617 |
+
# Get the time(0) input and compute the output for that, the output will
|
| 1618 |
+
# be used to determine the dtype of output tensor array. Don't read from
|
| 1619 |
+
# input_ta due to TensorArray clear_after_read default to True.
|
| 1620 |
+
input_time_zero = tf.nest.pack_sequence_as(
|
| 1621 |
+
inputs, [inp[0] for inp in flatted_inputs]
|
| 1622 |
+
)
|
| 1623 |
+
# output_time_zero is used to determine the cell output shape and its
|
| 1624 |
+
# dtype. the value is discarded.
|
| 1625 |
+
output_time_zero, _ = step_function(
|
| 1626 |
+
input_time_zero, tuple(initial_states) + tuple(constants)
|
| 1627 |
+
)
|
| 1628 |
+
|
| 1629 |
+
output_ta_size = time_steps_t if return_all_outputs else 1
|
| 1630 |
+
output_ta = tuple(
|
| 1631 |
+
tf.TensorArray(
|
| 1632 |
+
dtype=out.dtype,
|
| 1633 |
+
size=output_ta_size,
|
| 1634 |
+
element_shape=out.shape,
|
| 1635 |
+
tensor_array_name=f"output_ta_{i}",
|
| 1636 |
+
)
|
| 1637 |
+
for i, out in enumerate(tf.nest.flatten(output_time_zero))
|
| 1638 |
+
)
|
| 1639 |
+
|
| 1640 |
+
time = tf.constant(0, dtype="int32", name="time")
|
| 1641 |
+
|
| 1642 |
+
if input_length is None:
|
| 1643 |
+
max_iterations = time_steps_t
|
| 1644 |
+
else:
|
| 1645 |
+
max_iterations = tf.reduce_max(input_length)
|
| 1646 |
+
|
| 1647 |
+
while_loop_kwargs = {
|
| 1648 |
+
"cond": lambda time, *_: time < time_steps_t,
|
| 1649 |
+
"maximum_iterations": max_iterations,
|
| 1650 |
+
"parallel_iterations": 32,
|
| 1651 |
+
"swap_memory": True,
|
| 1652 |
+
}
|
| 1653 |
+
if mask is not None:
|
| 1654 |
+
if go_backwards:
|
| 1655 |
+
mask = reverse(mask, 0)
|
| 1656 |
+
|
| 1657 |
+
mask_ta = tf.TensorArray(
|
| 1658 |
+
dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
|
| 1659 |
+
)
|
| 1660 |
+
mask_ta = mask_ta.unstack(mask)
|
| 1661 |
+
|
| 1662 |
+
def masking_fn(time):
|
| 1663 |
+
return mask_ta.read(time)
|
| 1664 |
+
|
| 1665 |
+
def compute_masked_output(mask_t, flat_out, flat_mask):
|
| 1666 |
+
tiled_mask_t = tuple(
|
| 1667 |
+
_expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
|
| 1668 |
+
for o in flat_out
|
| 1669 |
+
)
|
| 1670 |
+
return tuple(
|
| 1671 |
+
tf.where(m, o, fm)
|
| 1672 |
+
for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
|
| 1673 |
+
)
|
| 1674 |
+
|
| 1675 |
+
elif isinstance(input_length, tf.Tensor):
|
| 1676 |
+
if go_backwards:
|
| 1677 |
+
max_len = tf.reduce_max(input_length, axis=0)
|
| 1678 |
+
rev_input_length = tf.subtract(max_len - 1, input_length)
|
| 1679 |
+
|
| 1680 |
+
def masking_fn(time):
|
| 1681 |
+
return tf.less(rev_input_length, time)
|
| 1682 |
+
|
| 1683 |
+
else:
|
| 1684 |
+
|
| 1685 |
+
def masking_fn(time):
|
| 1686 |
+
return tf.greater(input_length, time)
|
| 1687 |
+
|
| 1688 |
+
def compute_masked_output(mask_t, flat_out, flat_mask):
|
| 1689 |
+
return tuple(
|
| 1690 |
+
tf.compat.v1.where(mask_t, o, zo)
|
| 1691 |
+
for (o, zo) in zip(flat_out, flat_mask)
|
| 1692 |
+
)
|
| 1693 |
+
|
| 1694 |
+
else:
|
| 1695 |
+
masking_fn = None
|
| 1696 |
+
|
| 1697 |
+
if masking_fn is not None:
|
| 1698 |
+
# Mask for the T output will be base on the output of T - 1. In the
|
| 1699 |
+
# case T = 0, a zero filled tensor will be used.
|
| 1700 |
+
flat_zero_output = tuple(
|
| 1701 |
+
tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero)
|
| 1702 |
+
)
|
| 1703 |
+
|
| 1704 |
+
def _step(time, output_ta_t, prev_output, *states):
|
| 1705 |
+
"""RNN step function.
|
| 1706 |
+
|
| 1707 |
+
Args:
|
| 1708 |
+
time: Current timestep value.
|
| 1709 |
+
output_ta_t: TensorArray.
|
| 1710 |
+
prev_output: tuple of outputs from time - 1.
|
| 1711 |
+
*states: List of states.
|
| 1712 |
+
|
| 1713 |
+
Returns:
|
| 1714 |
+
Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
|
| 1715 |
+
"""
|
| 1716 |
+
current_input = tuple(ta.read(time) for ta in input_ta)
|
| 1717 |
+
# maybe set shape.
|
| 1718 |
+
current_input = tf.nest.pack_sequence_as(inputs, current_input)
|
| 1719 |
+
mask_t = masking_fn(time)
|
| 1720 |
+
output, new_states = step_function(
|
| 1721 |
+
current_input, tuple(states) + tuple(constants)
|
| 1722 |
+
)
|
| 1723 |
+
# mask output
|
| 1724 |
+
flat_output = tf.nest.flatten(output)
|
| 1725 |
+
flat_mask_output = (
|
| 1726 |
+
flat_zero_output
|
| 1727 |
+
if zero_output_for_mask
|
| 1728 |
+
else tf.nest.flatten(prev_output)
|
| 1729 |
+
)
|
| 1730 |
+
flat_new_output = compute_masked_output(
|
| 1731 |
+
mask_t, flat_output, flat_mask_output
|
| 1732 |
+
)
|
| 1733 |
+
|
| 1734 |
+
# mask states
|
| 1735 |
+
flat_state = tf.nest.flatten(states)
|
| 1736 |
+
flat_new_state = tf.nest.flatten(new_states)
|
| 1737 |
+
for state, new_state in zip(flat_state, flat_new_state):
|
| 1738 |
+
if isinstance(new_state, tf.Tensor):
|
| 1739 |
+
new_state.set_shape(state.shape)
|
| 1740 |
+
flat_final_state = compute_masked_output(
|
| 1741 |
+
mask_t, flat_new_state, flat_state
|
| 1742 |
+
)
|
| 1743 |
+
new_states = tf.nest.pack_sequence_as(
|
| 1744 |
+
new_states, flat_final_state
|
| 1745 |
+
)
|
| 1746 |
+
|
| 1747 |
+
ta_index_to_write = time if return_all_outputs else 0
|
| 1748 |
+
output_ta_t = tuple(
|
| 1749 |
+
ta.write(ta_index_to_write, out)
|
| 1750 |
+
for ta, out in zip(output_ta_t, flat_new_output)
|
| 1751 |
+
)
|
| 1752 |
+
|
| 1753 |
+
return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(
|
| 1754 |
+
new_states
|
| 1755 |
+
)
|
| 1756 |
+
|
| 1757 |
+
final_outputs = tf.compat.v1.while_loop(
|
| 1758 |
+
body=_step,
|
| 1759 |
+
loop_vars=(time, output_ta, flat_zero_output) + states,
|
| 1760 |
+
**while_loop_kwargs,
|
| 1761 |
+
)
|
| 1762 |
+
# Skip final_outputs[2] which is the output for final timestep.
|
| 1763 |
+
new_states = final_outputs[3:]
|
| 1764 |
+
else:
|
| 1765 |
+
|
| 1766 |
+
def _step(time, output_ta_t, *states):
|
| 1767 |
+
"""RNN step function.
|
| 1768 |
+
|
| 1769 |
+
Args:
|
| 1770 |
+
time: Current timestep value.
|
| 1771 |
+
output_ta_t: TensorArray.
|
| 1772 |
+
*states: List of states.
|
| 1773 |
+
|
| 1774 |
+
Returns:
|
| 1775 |
+
Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
|
| 1776 |
+
"""
|
| 1777 |
+
current_input = tuple(ta.read(time) for ta in input_ta)
|
| 1778 |
+
current_input = tf.nest.pack_sequence_as(inputs, current_input)
|
| 1779 |
+
output, new_states = step_function(
|
| 1780 |
+
current_input, tuple(states) + tuple(constants)
|
| 1781 |
+
)
|
| 1782 |
+
flat_state = tf.nest.flatten(states)
|
| 1783 |
+
flat_new_state = tf.nest.flatten(new_states)
|
| 1784 |
+
for state, new_state in zip(flat_state, flat_new_state):
|
| 1785 |
+
if isinstance(new_state, tf.Tensor):
|
| 1786 |
+
new_state.set_shape(state.shape)
|
| 1787 |
+
|
| 1788 |
+
flat_output = tf.nest.flatten(output)
|
| 1789 |
+
ta_index_to_write = time if return_all_outputs else 0
|
| 1790 |
+
output_ta_t = tuple(
|
| 1791 |
+
ta.write(ta_index_to_write, out)
|
| 1792 |
+
for ta, out in zip(output_ta_t, flat_output)
|
| 1793 |
+
)
|
| 1794 |
+
|
| 1795 |
+
new_states = tf.nest.pack_sequence_as(
|
| 1796 |
+
initial_states, flat_new_state
|
| 1797 |
+
)
|
| 1798 |
+
return (time + 1, output_ta_t) + tuple(new_states)
|
| 1799 |
+
|
| 1800 |
+
final_outputs = tf.compat.v1.while_loop(
|
| 1801 |
+
body=_step,
|
| 1802 |
+
loop_vars=(time, output_ta) + states,
|
| 1803 |
+
**while_loop_kwargs,
|
| 1804 |
+
)
|
| 1805 |
+
new_states = final_outputs[2:]
|
| 1806 |
+
|
| 1807 |
+
output_ta = final_outputs[1]
|
| 1808 |
+
|
| 1809 |
+
outputs = tuple(o.stack() for o in output_ta)
|
| 1810 |
+
last_output = tuple(o[-1] for o in outputs)
|
| 1811 |
+
|
| 1812 |
+
outputs = tf.nest.pack_sequence_as(output_time_zero, outputs)
|
| 1813 |
+
last_output = tf.nest.pack_sequence_as(output_time_zero, last_output)
|
| 1814 |
+
|
| 1815 |
+
# static shape inference
|
| 1816 |
+
def set_shape(output_):
|
| 1817 |
+
if isinstance(output_, tf.Tensor):
|
| 1818 |
+
shape = output_.shape.as_list()
|
| 1819 |
+
if return_all_outputs:
|
| 1820 |
+
shape[0] = time_steps
|
| 1821 |
+
else:
|
| 1822 |
+
shape[0] = 1
|
| 1823 |
+
shape[1] = batch
|
| 1824 |
+
output_.set_shape(shape)
|
| 1825 |
+
return output_
|
| 1826 |
+
|
| 1827 |
+
outputs = tf.nest.map_structure(set_shape, outputs)
|
| 1828 |
+
|
| 1829 |
+
if not time_major:
|
| 1830 |
+
outputs = tf.nest.map_structure(swap_batch_timestep, outputs)
|
| 1831 |
+
|
| 1832 |
+
return last_output, outputs, new_states
|
| 1833 |
+
|
| 1834 |
+
|
| 1835 |
+
@keras_export("keras._legacy.backend.round")
|
| 1836 |
+
def round(x):
|
| 1837 |
+
"""DEPRECATED."""
|
| 1838 |
+
return tf.round(x)
|
| 1839 |
+
|
| 1840 |
+
|
| 1841 |
+
@keras_export("keras._legacy.backend.separable_conv2d")
|
| 1842 |
+
def separable_conv2d(
|
| 1843 |
+
x,
|
| 1844 |
+
depthwise_kernel,
|
| 1845 |
+
pointwise_kernel,
|
| 1846 |
+
strides=(1, 1),
|
| 1847 |
+
padding="valid",
|
| 1848 |
+
data_format=None,
|
| 1849 |
+
dilation_rate=(1, 1),
|
| 1850 |
+
):
|
| 1851 |
+
"""DEPRECATED."""
|
| 1852 |
+
if data_format is None:
|
| 1853 |
+
data_format = backend.image_data_format()
|
| 1854 |
+
if data_format not in {"channels_first", "channels_last"}:
|
| 1855 |
+
raise ValueError(f"Unknown data_format: {data_format}")
|
| 1856 |
+
if len(strides) != 2:
|
| 1857 |
+
raise ValueError("`strides` must be a tuple of 2 integers.")
|
| 1858 |
+
|
| 1859 |
+
x, tf_data_format = _preprocess_conv2d_input(x, data_format)
|
| 1860 |
+
padding = _preprocess_padding(padding)
|
| 1861 |
+
if not isinstance(strides, tuple):
|
| 1862 |
+
strides = tuple(strides)
|
| 1863 |
+
if tf_data_format == "NHWC":
|
| 1864 |
+
strides = (1,) + strides + (1,)
|
| 1865 |
+
else:
|
| 1866 |
+
strides = (1, 1) + strides
|
| 1867 |
+
|
| 1868 |
+
x = tf.nn.separable_conv2d(
|
| 1869 |
+
x,
|
| 1870 |
+
depthwise_kernel,
|
| 1871 |
+
pointwise_kernel,
|
| 1872 |
+
strides=strides,
|
| 1873 |
+
padding=padding,
|
| 1874 |
+
dilations=dilation_rate,
|
| 1875 |
+
data_format=tf_data_format,
|
| 1876 |
+
)
|
| 1877 |
+
if data_format == "channels_first" and tf_data_format == "NHWC":
|
| 1878 |
+
x = tf.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
|
| 1879 |
+
return x
|
| 1880 |
+
|
| 1881 |
+
|
| 1882 |
+
@keras_export("keras._legacy.backend.set_value")
|
| 1883 |
+
def set_value(x, value):
|
| 1884 |
+
"""DEPRECATED."""
|
| 1885 |
+
value = np.asarray(value, dtype=x.dtype.name)
|
| 1886 |
+
x.assign(value)
|
| 1887 |
+
|
| 1888 |
+
|
| 1889 |
+
@keras_export("keras._legacy.backend.shape")
|
| 1890 |
+
def shape(x):
|
| 1891 |
+
"""DEPRECATED."""
|
| 1892 |
+
return tf.shape(x)
|
| 1893 |
+
|
| 1894 |
+
|
| 1895 |
+
@keras_export("keras._legacy.backend.sigmoid")
|
| 1896 |
+
def sigmoid(x):
|
| 1897 |
+
"""DEPRECATED."""
|
| 1898 |
+
output = tf.sigmoid(x)
|
| 1899 |
+
return output
|
| 1900 |
+
|
| 1901 |
+
|
| 1902 |
+
@keras_export("keras._legacy.backend.sign")
|
| 1903 |
+
def sign(x):
|
| 1904 |
+
"""DEPRECATED."""
|
| 1905 |
+
return tf.sign(x)
|
| 1906 |
+
|
| 1907 |
+
|
| 1908 |
+
@keras_export("keras._legacy.backend.sin")
|
| 1909 |
+
def sin(x):
|
| 1910 |
+
"""DEPRECATED."""
|
| 1911 |
+
return tf.sin(x)
|
| 1912 |
+
|
| 1913 |
+
|
| 1914 |
+
@keras_export("keras._legacy.backend.softmax")
|
| 1915 |
+
def softmax(x, axis=-1):
|
| 1916 |
+
"""DEPRECATED."""
|
| 1917 |
+
if x.shape.rank <= 1:
|
| 1918 |
+
raise ValueError(
|
| 1919 |
+
f"Cannot apply softmax to a tensor that is 1D. Received input: {x}"
|
| 1920 |
+
)
|
| 1921 |
+
|
| 1922 |
+
if isinstance(axis, int):
|
| 1923 |
+
output = tf.nn.softmax(x, axis=axis)
|
| 1924 |
+
else:
|
| 1925 |
+
# nn.softmax does not support tuple axis.
|
| 1926 |
+
numerator = tf.exp(x - tf.reduce_max(x, axis=axis, keepdims=True))
|
| 1927 |
+
denominator = tf.reduce_sum(numerator, axis=axis, keepdims=True)
|
| 1928 |
+
output = numerator / denominator
|
| 1929 |
+
|
| 1930 |
+
# Cache the logits to use for crossentropy loss.
|
| 1931 |
+
output._keras_logits = x
|
| 1932 |
+
return output
|
| 1933 |
+
|
| 1934 |
+
|
| 1935 |
+
@keras_export("keras._legacy.backend.softplus")
|
| 1936 |
+
def softplus(x):
|
| 1937 |
+
"""DEPRECATED."""
|
| 1938 |
+
return tf.math.softplus(x)
|
| 1939 |
+
|
| 1940 |
+
|
| 1941 |
+
@keras_export("keras._legacy.backend.softsign")
|
| 1942 |
+
def softsign(x):
|
| 1943 |
+
"""DEPRECATED."""
|
| 1944 |
+
return tf.math.softsign(x)
|
| 1945 |
+
|
| 1946 |
+
|
| 1947 |
+
@keras_export("keras._legacy.backend.sparse_categorical_crossentropy")
|
| 1948 |
+
def sparse_categorical_crossentropy(
|
| 1949 |
+
target, output, from_logits=False, axis=-1, ignore_class=None
|
| 1950 |
+
):
|
| 1951 |
+
"""DEPRECATED."""
|
| 1952 |
+
target = tf.convert_to_tensor(target)
|
| 1953 |
+
output = tf.convert_to_tensor(output)
|
| 1954 |
+
|
| 1955 |
+
target = cast(target, "int64")
|
| 1956 |
+
|
| 1957 |
+
if not from_logits:
|
| 1958 |
+
epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype)
|
| 1959 |
+
output = tf.clip_by_value(output, epsilon_, 1 - epsilon_)
|
| 1960 |
+
output = tf.math.log(output)
|
| 1961 |
+
|
| 1962 |
+
# Permute output so that the last axis contains the logits/probabilities.
|
| 1963 |
+
if isinstance(output.shape, (tuple, list)):
|
| 1964 |
+
output_rank = len(output.shape)
|
| 1965 |
+
else:
|
| 1966 |
+
output_rank = output.shape.ndims
|
| 1967 |
+
if output_rank is not None:
|
| 1968 |
+
axis %= output_rank
|
| 1969 |
+
if axis != output_rank - 1:
|
| 1970 |
+
permutation = list(
|
| 1971 |
+
itertools.chain(
|
| 1972 |
+
range(axis), range(axis + 1, output_rank), [axis]
|
| 1973 |
+
)
|
| 1974 |
+
)
|
| 1975 |
+
output = tf.transpose(output, perm=permutation)
|
| 1976 |
+
elif axis != -1:
|
| 1977 |
+
raise ValueError(
|
| 1978 |
+
"Cannot compute sparse categorical crossentropy with `axis={}` "
|
| 1979 |
+
"on an output tensor with unknown rank".format(axis)
|
| 1980 |
+
)
|
| 1981 |
+
|
| 1982 |
+
# Try to adjust the shape so that rank of labels = rank of logits - 1.
|
| 1983 |
+
output_shape = tf.shape(output)
|
| 1984 |
+
target_rank = target.shape.ndims
|
| 1985 |
+
|
| 1986 |
+
update_shape = (
|
| 1987 |
+
target_rank is not None
|
| 1988 |
+
and output_rank is not None
|
| 1989 |
+
and target_rank != output_rank - 1
|
| 1990 |
+
)
|
| 1991 |
+
if update_shape:
|
| 1992 |
+
target = flatten(target)
|
| 1993 |
+
output = tf.reshape(output, [-1, output_shape[-1]])
|
| 1994 |
+
|
| 1995 |
+
if ignore_class is not None:
|
| 1996 |
+
valid_mask = tf.not_equal(target, cast(ignore_class, target.dtype))
|
| 1997 |
+
target = target[valid_mask]
|
| 1998 |
+
output = output[valid_mask]
|
| 1999 |
+
|
| 2000 |
+
res = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
| 2001 |
+
labels=target, logits=output
|
| 2002 |
+
)
|
| 2003 |
+
|
| 2004 |
+
if ignore_class is not None:
|
| 2005 |
+
res_shape = cast(output_shape[:-1], "int64")
|
| 2006 |
+
valid_mask = tf.reshape(valid_mask, res_shape)
|
| 2007 |
+
res = tf.scatter_nd(tf.where(valid_mask), res, res_shape)
|
| 2008 |
+
res._keras_mask = valid_mask
|
| 2009 |
+
|
| 2010 |
+
return res
|
| 2011 |
+
|
| 2012 |
+
if update_shape and output_rank >= 3:
|
| 2013 |
+
# If our output includes timesteps or
|
| 2014 |
+
# spatial dimensions we need to reshape
|
| 2015 |
+
res = tf.reshape(res, output_shape[:-1])
|
| 2016 |
+
|
| 2017 |
+
return res
|
| 2018 |
+
|
| 2019 |
+
|
| 2020 |
+
@keras_export("keras._legacy.backend.spatial_2d_padding")
|
| 2021 |
+
def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
|
| 2022 |
+
"""DEPRECATED."""
|
| 2023 |
+
assert len(padding) == 2
|
| 2024 |
+
assert len(padding[0]) == 2
|
| 2025 |
+
assert len(padding[1]) == 2
|
| 2026 |
+
if data_format is None:
|
| 2027 |
+
data_format = backend.image_data_format()
|
| 2028 |
+
if data_format not in {"channels_first", "channels_last"}:
|
| 2029 |
+
raise ValueError(f"Unknown data_format: {data_format}")
|
| 2030 |
+
|
| 2031 |
+
if data_format == "channels_first":
|
| 2032 |
+
pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]
|
| 2033 |
+
else:
|
| 2034 |
+
pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]
|
| 2035 |
+
return tf.compat.v1.pad(x, pattern)
|
| 2036 |
+
|
| 2037 |
+
|
| 2038 |
+
@keras_export("keras._legacy.backend.spatial_3d_padding")
|
| 2039 |
+
def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
|
| 2040 |
+
"""DEPRECATED."""
|
| 2041 |
+
assert len(padding) == 3
|
| 2042 |
+
assert len(padding[0]) == 2
|
| 2043 |
+
assert len(padding[1]) == 2
|
| 2044 |
+
assert len(padding[2]) == 2
|
| 2045 |
+
if data_format is None:
|
| 2046 |
+
data_format = backend.image_data_format()
|
| 2047 |
+
if data_format not in {"channels_first", "channels_last"}:
|
| 2048 |
+
raise ValueError(f"Unknown data_format: {data_format}")
|
| 2049 |
+
|
| 2050 |
+
if data_format == "channels_first":
|
| 2051 |
+
pattern = [
|
| 2052 |
+
[0, 0],
|
| 2053 |
+
[0, 0],
|
| 2054 |
+
[padding[0][0], padding[0][1]],
|
| 2055 |
+
[padding[1][0], padding[1][1]],
|
| 2056 |
+
[padding[2][0], padding[2][1]],
|
| 2057 |
+
]
|
| 2058 |
+
else:
|
| 2059 |
+
pattern = [
|
| 2060 |
+
[0, 0],
|
| 2061 |
+
[padding[0][0], padding[0][1]],
|
| 2062 |
+
[padding[1][0], padding[1][1]],
|
| 2063 |
+
[padding[2][0], padding[2][1]],
|
| 2064 |
+
[0, 0],
|
| 2065 |
+
]
|
| 2066 |
+
return tf.compat.v1.pad(x, pattern)
|
| 2067 |
+
|
| 2068 |
+
|
| 2069 |
+
@keras_export("keras._legacy.backend.sqrt")
|
| 2070 |
+
def sqrt(x):
|
| 2071 |
+
"""DEPRECATED."""
|
| 2072 |
+
zero = tf.convert_to_tensor(0.0, x.dtype)
|
| 2073 |
+
x = tf.maximum(x, zero)
|
| 2074 |
+
return tf.sqrt(x)
|
| 2075 |
+
|
| 2076 |
+
|
| 2077 |
+
@keras_export("keras._legacy.backend.square")
|
| 2078 |
+
def square(x):
|
| 2079 |
+
"""DEPRECATED."""
|
| 2080 |
+
return tf.square(x)
|
| 2081 |
+
|
| 2082 |
+
|
| 2083 |
+
@keras_export("keras._legacy.backend.squeeze")
|
| 2084 |
+
def squeeze(x, axis):
|
| 2085 |
+
"""DEPRECATED."""
|
| 2086 |
+
return tf.squeeze(x, [axis])
|
| 2087 |
+
|
| 2088 |
+
|
| 2089 |
+
@keras_export("keras._legacy.backend.stack")
|
| 2090 |
+
def stack(x, axis=0):
|
| 2091 |
+
"""DEPRECATED."""
|
| 2092 |
+
return tf.stack(x, axis=axis)
|
| 2093 |
+
|
| 2094 |
+
|
| 2095 |
+
@keras_export("keras._legacy.backend.std")
|
| 2096 |
+
def std(x, axis=None, keepdims=False):
|
| 2097 |
+
"""DEPRECATED."""
|
| 2098 |
+
if x.dtype.base_dtype == tf.bool:
|
| 2099 |
+
x = tf.cast(x, backend.floatx())
|
| 2100 |
+
return tf.math.reduce_std(x, axis=axis, keepdims=keepdims)
|
| 2101 |
+
|
| 2102 |
+
|
| 2103 |
+
@keras_export("keras._legacy.backend.stop_gradient")
|
| 2104 |
+
def stop_gradient(variables):
|
| 2105 |
+
"""DEPRECATED."""
|
| 2106 |
+
if isinstance(variables, (list, tuple)):
|
| 2107 |
+
return map(tf.stop_gradient, variables)
|
| 2108 |
+
return tf.stop_gradient(variables)
|
| 2109 |
+
|
| 2110 |
+
|
| 2111 |
+
@keras_export("keras._legacy.backend.sum")
|
| 2112 |
+
def sum(x, axis=None, keepdims=False):
|
| 2113 |
+
"""DEPRECATED."""
|
| 2114 |
+
return tf.reduce_sum(x, axis, keepdims)
|
| 2115 |
+
|
| 2116 |
+
|
| 2117 |
+
@keras_export("keras._legacy.backend.switch")
|
| 2118 |
+
def switch(condition, then_expression, else_expression):
|
| 2119 |
+
"""DEPRECATED."""
|
| 2120 |
+
if condition.dtype != tf.bool:
|
| 2121 |
+
condition = tf.cast(condition, "bool")
|
| 2122 |
+
cond_ndim = ndim(condition)
|
| 2123 |
+
if not cond_ndim:
|
| 2124 |
+
if not callable(then_expression):
|
| 2125 |
+
|
| 2126 |
+
def then_expression_fn():
|
| 2127 |
+
return then_expression
|
| 2128 |
+
|
| 2129 |
+
else:
|
| 2130 |
+
then_expression_fn = then_expression
|
| 2131 |
+
if not callable(else_expression):
|
| 2132 |
+
|
| 2133 |
+
def else_expression_fn():
|
| 2134 |
+
return else_expression
|
| 2135 |
+
|
| 2136 |
+
else:
|
| 2137 |
+
else_expression_fn = else_expression
|
| 2138 |
+
x = tf.compat.v1.cond(condition, then_expression_fn, else_expression_fn)
|
| 2139 |
+
else:
|
| 2140 |
+
# tf.where needs its condition tensor
|
| 2141 |
+
# to be the same shape as its two
|
| 2142 |
+
# result tensors
|
| 2143 |
+
if callable(then_expression):
|
| 2144 |
+
then_expression = then_expression()
|
| 2145 |
+
if callable(else_expression):
|
| 2146 |
+
else_expression = else_expression()
|
| 2147 |
+
expr_ndim = ndim(then_expression)
|
| 2148 |
+
if cond_ndim > expr_ndim:
|
| 2149 |
+
raise ValueError(
|
| 2150 |
+
"Rank of `condition` should be less than or"
|
| 2151 |
+
" equal to rank of `then_expression` and "
|
| 2152 |
+
"`else_expression`. ndim(condition)="
|
| 2153 |
+
+ str(cond_ndim)
|
| 2154 |
+
+ ", ndim(then_expression)="
|
| 2155 |
+
+ str(expr_ndim)
|
| 2156 |
+
)
|
| 2157 |
+
if cond_ndim > 1:
|
| 2158 |
+
ndim_diff = expr_ndim - cond_ndim
|
| 2159 |
+
cond_shape = tf.concat(
|
| 2160 |
+
[tf.shape(condition), [1] * ndim_diff], axis=0
|
| 2161 |
+
)
|
| 2162 |
+
condition = tf.reshape(condition, cond_shape)
|
| 2163 |
+
expr_shape = tf.shape(then_expression)
|
| 2164 |
+
shape_diff = expr_shape - cond_shape
|
| 2165 |
+
tile_shape = tf.where(
|
| 2166 |
+
shape_diff > 0, expr_shape, tf.ones_like(expr_shape)
|
| 2167 |
+
)
|
| 2168 |
+
condition = tf.tile(condition, tile_shape)
|
| 2169 |
+
x = tf.where(condition, then_expression, else_expression)
|
| 2170 |
+
return x
|
| 2171 |
+
|
| 2172 |
+
|
| 2173 |
+
@keras_export("keras._legacy.backend.tanh")
|
| 2174 |
+
def tanh(x):
|
| 2175 |
+
"""DEPRECATED."""
|
| 2176 |
+
return tf.tanh(x)
|
| 2177 |
+
|
| 2178 |
+
|
| 2179 |
+
@keras_export("keras._legacy.backend.temporal_padding")
|
| 2180 |
+
def temporal_padding(x, padding=(1, 1)):
|
| 2181 |
+
"""DEPRECATED."""
|
| 2182 |
+
assert len(padding) == 2
|
| 2183 |
+
pattern = [[0, 0], [padding[0], padding[1]], [0, 0]]
|
| 2184 |
+
return tf.compat.v1.pad(x, pattern)
|
| 2185 |
+
|
| 2186 |
+
|
| 2187 |
+
@keras_export("keras._legacy.backend.tile")
|
| 2188 |
+
def tile(x, n):
|
| 2189 |
+
"""DEPRECATED."""
|
| 2190 |
+
if isinstance(n, int):
|
| 2191 |
+
n = [n]
|
| 2192 |
+
return tf.tile(x, n)
|
| 2193 |
+
|
| 2194 |
+
|
| 2195 |
+
@keras_export("keras._legacy.backend.to_dense")
|
| 2196 |
+
def to_dense(tensor):
|
| 2197 |
+
"""DEPRECATED."""
|
| 2198 |
+
if is_sparse(tensor):
|
| 2199 |
+
return tf.sparse.to_dense(tensor)
|
| 2200 |
+
else:
|
| 2201 |
+
return tensor
|
| 2202 |
+
|
| 2203 |
+
|
| 2204 |
+
@keras_export("keras._legacy.backend.transpose")
|
| 2205 |
+
def transpose(x):
|
| 2206 |
+
"""DEPRECATED."""
|
| 2207 |
+
return tf.transpose(x)
|
| 2208 |
+
|
| 2209 |
+
|
| 2210 |
+
@keras_export("keras._legacy.backend.truncated_normal")
|
| 2211 |
+
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
|
| 2212 |
+
"""DEPRECATED."""
|
| 2213 |
+
if dtype is None:
|
| 2214 |
+
dtype = backend.floatx()
|
| 2215 |
+
if seed is None:
|
| 2216 |
+
seed = np.random.randint(10e6)
|
| 2217 |
+
return tf.random.truncated_normal(
|
| 2218 |
+
shape, mean, stddev, dtype=dtype, seed=seed
|
| 2219 |
+
)
|
| 2220 |
+
|
| 2221 |
+
|
| 2222 |
+
@keras_export("keras._legacy.backend.update")
|
| 2223 |
+
def update(x, new_x):
|
| 2224 |
+
"""DEPRECATED."""
|
| 2225 |
+
return tf.compat.v1.assign(x, new_x)
|
| 2226 |
+
|
| 2227 |
+
|
| 2228 |
+
@keras_export("keras._legacy.backend.update_add")
|
| 2229 |
+
def update_add(x, increment):
|
| 2230 |
+
"""DEPRECATED."""
|
| 2231 |
+
return tf.compat.v1.assign_add(x, increment)
|
| 2232 |
+
|
| 2233 |
+
|
| 2234 |
+
@keras_export("keras._legacy.backend.update_sub")
|
| 2235 |
+
def update_sub(x, decrement):
|
| 2236 |
+
"""DEPRECATED."""
|
| 2237 |
+
return tf.compat.v1.assign_sub(x, decrement)
|
| 2238 |
+
|
| 2239 |
+
|
| 2240 |
+
@keras_export("keras._legacy.backend.var")
|
| 2241 |
+
def var(x, axis=None, keepdims=False):
|
| 2242 |
+
"""DEPRECATED."""
|
| 2243 |
+
if x.dtype.base_dtype == tf.bool:
|
| 2244 |
+
x = tf.cast(x, backend.floatx())
|
| 2245 |
+
return tf.math.reduce_variance(x, axis=axis, keepdims=keepdims)
|
| 2246 |
+
|
| 2247 |
+
|
| 2248 |
+
@keras_export("keras._legacy.backend.variable")
|
| 2249 |
+
def variable(value, dtype=None, name=None, constraint=None):
|
| 2250 |
+
"""DEPRECATED."""
|
| 2251 |
+
if dtype is None:
|
| 2252 |
+
dtype = backend.floatx()
|
| 2253 |
+
if hasattr(value, "tocoo"):
|
| 2254 |
+
sparse_coo = value.tocoo()
|
| 2255 |
+
indices = np.concatenate(
|
| 2256 |
+
(
|
| 2257 |
+
np.expand_dims(sparse_coo.row, 1),
|
| 2258 |
+
np.expand_dims(sparse_coo.col, 1),
|
| 2259 |
+
),
|
| 2260 |
+
1,
|
| 2261 |
+
)
|
| 2262 |
+
v = tf.SparseTensor(
|
| 2263 |
+
indices=indices,
|
| 2264 |
+
values=sparse_coo.data,
|
| 2265 |
+
dense_shape=sparse_coo.shape,
|
| 2266 |
+
)
|
| 2267 |
+
v._keras_shape = sparse_coo.shape
|
| 2268 |
+
return v
|
| 2269 |
+
v = tf.Variable(
|
| 2270 |
+
value, dtype=tf.as_dtype(dtype), name=name, constraint=constraint
|
| 2271 |
+
)
|
| 2272 |
+
return v
|
| 2273 |
+
|
| 2274 |
+
|
| 2275 |
+
@keras_export("keras._legacy.backend.zeros")
|
| 2276 |
+
def zeros(shape, dtype=None, name=None):
|
| 2277 |
+
"""DEPRECATED."""
|
| 2278 |
+
with tf.init_scope():
|
| 2279 |
+
if dtype is None:
|
| 2280 |
+
dtype = backend.floatx()
|
| 2281 |
+
tf_dtype = tf.as_dtype(dtype)
|
| 2282 |
+
v = tf.zeros(shape=shape, dtype=tf_dtype, name=name)
|
| 2283 |
+
if py_all(v.shape.as_list()):
|
| 2284 |
+
return variable(v, dtype=dtype, name=name)
|
| 2285 |
+
return v
|
| 2286 |
+
|
| 2287 |
+
|
| 2288 |
+
@keras_export("keras._legacy.backend.zeros_like")
|
| 2289 |
+
def zeros_like(x, dtype=None, name=None):
|
| 2290 |
+
"""DEPRECATED."""
|
| 2291 |
+
return tf.zeros_like(x, dtype=dtype, name=name)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/layers.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Legacy Keras 1/2 layers.
|
| 2 |
+
|
| 3 |
+
AlphaDropout
|
| 4 |
+
RandomHeight
|
| 5 |
+
RandomWidth
|
| 6 |
+
ThresholdedReLU
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from keras.src import backend
|
| 10 |
+
from keras.src.api_export import keras_export
|
| 11 |
+
from keras.src.layers.layer import Layer
|
| 12 |
+
from keras.src.utils.module_utils import tensorflow as tf
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@keras_export("keras._legacy.layers.AlphaDropout")
|
| 16 |
+
class AlphaDropout(Layer):
|
| 17 |
+
"""DEPRECATED."""
|
| 18 |
+
|
| 19 |
+
def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
|
| 20 |
+
super().__init__(**kwargs)
|
| 21 |
+
self.rate = rate
|
| 22 |
+
self.seed = seed
|
| 23 |
+
self.noise_shape = noise_shape
|
| 24 |
+
self.seed_generator = backend.random.SeedGenerator(seed)
|
| 25 |
+
self.supports_masking = True
|
| 26 |
+
self.built = True
|
| 27 |
+
|
| 28 |
+
def call(self, inputs, training=False):
|
| 29 |
+
if training and self.rate > 0:
|
| 30 |
+
alpha = 1.6732632423543772848170429916717
|
| 31 |
+
scale = 1.0507009873554804934193349852946
|
| 32 |
+
alpha_p = -alpha * scale
|
| 33 |
+
|
| 34 |
+
if self.noise_shape is None:
|
| 35 |
+
noise_shape = tf.shape(inputs)
|
| 36 |
+
else:
|
| 37 |
+
noise_shape = self.noise_shape
|
| 38 |
+
kept_idx = tf.greater_equal(
|
| 39 |
+
backend.random.uniform(noise_shape, seed=self.seed_generator),
|
| 40 |
+
self.rate,
|
| 41 |
+
)
|
| 42 |
+
kept_idx = tf.cast(kept_idx, inputs.dtype)
|
| 43 |
+
|
| 44 |
+
# Get affine transformation params
|
| 45 |
+
a = ((1 - self.rate) * (1 + self.rate * alpha_p**2)) ** -0.5
|
| 46 |
+
b = -a * alpha_p * self.rate
|
| 47 |
+
|
| 48 |
+
# Apply mask
|
| 49 |
+
x = inputs * kept_idx + alpha_p * (1 - kept_idx)
|
| 50 |
+
|
| 51 |
+
# Do affine transformation
|
| 52 |
+
return a * x + b
|
| 53 |
+
return inputs
|
| 54 |
+
|
| 55 |
+
def get_config(self):
|
| 56 |
+
config = {"rate": self.rate, "seed": self.seed}
|
| 57 |
+
base_config = super().get_config()
|
| 58 |
+
return {**base_config, **config}
|
| 59 |
+
|
| 60 |
+
def compute_output_shape(self, input_shape):
|
| 61 |
+
return input_shape
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@keras_export("keras._legacy.layers.RandomHeight")
|
| 65 |
+
class RandomHeight(Layer):
|
| 66 |
+
"""DEPRECATED."""
|
| 67 |
+
|
| 68 |
+
def __init__(self, factor, interpolation="bilinear", seed=None, **kwargs):
|
| 69 |
+
super().__init__(**kwargs)
|
| 70 |
+
self.seed_generator = backend.random.SeedGenerator(seed)
|
| 71 |
+
self.factor = factor
|
| 72 |
+
if isinstance(factor, (tuple, list)):
|
| 73 |
+
self.height_lower = factor[0]
|
| 74 |
+
self.height_upper = factor[1]
|
| 75 |
+
else:
|
| 76 |
+
self.height_lower = -factor
|
| 77 |
+
self.height_upper = factor
|
| 78 |
+
|
| 79 |
+
if self.height_upper < self.height_lower:
|
| 80 |
+
raise ValueError(
|
| 81 |
+
"`factor` argument cannot have an upper bound lesser than the "
|
| 82 |
+
f"lower bound. Received: factor={factor}"
|
| 83 |
+
)
|
| 84 |
+
if self.height_lower < -1.0 or self.height_upper < -1.0:
|
| 85 |
+
raise ValueError(
|
| 86 |
+
"`factor` argument must have values larger than -1. "
|
| 87 |
+
f"Received: factor={factor}"
|
| 88 |
+
)
|
| 89 |
+
self.interpolation = interpolation
|
| 90 |
+
self.seed = seed
|
| 91 |
+
|
| 92 |
+
def call(self, inputs, training=True):
|
| 93 |
+
inputs = tf.convert_to_tensor(inputs, dtype=self.compute_dtype)
|
| 94 |
+
|
| 95 |
+
def random_height_inputs(inputs):
|
| 96 |
+
"""Inputs height-adjusted with random ops."""
|
| 97 |
+
inputs_shape = tf.shape(inputs)
|
| 98 |
+
img_hd = tf.cast(inputs_shape[-3], tf.float32)
|
| 99 |
+
img_wd = inputs_shape[-2]
|
| 100 |
+
height_factor = backend.random.uniform(
|
| 101 |
+
shape=[],
|
| 102 |
+
minval=(1.0 + self.height_lower),
|
| 103 |
+
maxval=(1.0 + self.height_upper),
|
| 104 |
+
seed=self.seed_generator,
|
| 105 |
+
)
|
| 106 |
+
adjusted_height = tf.cast(height_factor * img_hd, tf.int32)
|
| 107 |
+
adjusted_size = tf.stack([adjusted_height, img_wd])
|
| 108 |
+
output = tf.image.resize(
|
| 109 |
+
images=inputs,
|
| 110 |
+
size=adjusted_size,
|
| 111 |
+
method=self.interpolation,
|
| 112 |
+
)
|
| 113 |
+
# tf.resize will output float32 regardless of input type.
|
| 114 |
+
output = tf.cast(output, self.compute_dtype)
|
| 115 |
+
output_shape = inputs.shape.as_list()
|
| 116 |
+
output_shape[-3] = None
|
| 117 |
+
output.set_shape(output_shape)
|
| 118 |
+
return output
|
| 119 |
+
|
| 120 |
+
if training:
|
| 121 |
+
return random_height_inputs(inputs)
|
| 122 |
+
else:
|
| 123 |
+
return inputs
|
| 124 |
+
|
| 125 |
+
def compute_output_shape(self, input_shape):
|
| 126 |
+
input_shape = list(input_shape)
|
| 127 |
+
input_shape[-3] = None
|
| 128 |
+
return tuple(input_shape)
|
| 129 |
+
|
| 130 |
+
def get_config(self):
|
| 131 |
+
config = {
|
| 132 |
+
"factor": self.factor,
|
| 133 |
+
"interpolation": self.interpolation,
|
| 134 |
+
"seed": self.seed,
|
| 135 |
+
}
|
| 136 |
+
base_config = super().get_config()
|
| 137 |
+
return {**base_config, **config}
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@keras_export("keras._legacy.layers.RandomWidth")
|
| 141 |
+
class RandomWidth(Layer):
|
| 142 |
+
"""DEPRECATED."""
|
| 143 |
+
|
| 144 |
+
def __init__(self, factor, interpolation="bilinear", seed=None, **kwargs):
|
| 145 |
+
super().__init__(**kwargs)
|
| 146 |
+
self.seed_generator = backend.random.SeedGenerator(seed)
|
| 147 |
+
self.factor = factor
|
| 148 |
+
if isinstance(factor, (tuple, list)):
|
| 149 |
+
self.width_lower = factor[0]
|
| 150 |
+
self.width_upper = factor[1]
|
| 151 |
+
else:
|
| 152 |
+
self.width_lower = -factor
|
| 153 |
+
self.width_upper = factor
|
| 154 |
+
if self.width_upper < self.width_lower:
|
| 155 |
+
raise ValueError(
|
| 156 |
+
"`factor` argument cannot have an upper bound less than the "
|
| 157 |
+
f"lower bound. Received: factor={factor}"
|
| 158 |
+
)
|
| 159 |
+
if self.width_lower < -1.0 or self.width_upper < -1.0:
|
| 160 |
+
raise ValueError(
|
| 161 |
+
"`factor` argument must have values larger than -1. "
|
| 162 |
+
f"Received: factor={factor}"
|
| 163 |
+
)
|
| 164 |
+
self.interpolation = interpolation
|
| 165 |
+
self.seed = seed
|
| 166 |
+
|
| 167 |
+
def call(self, inputs, training=True):
|
| 168 |
+
inputs = tf.convert_to_tensor(inputs, dtype=self.compute_dtype)
|
| 169 |
+
|
| 170 |
+
def random_width_inputs(inputs):
|
| 171 |
+
"""Inputs width-adjusted with random ops."""
|
| 172 |
+
inputs_shape = tf.shape(inputs)
|
| 173 |
+
img_hd = inputs_shape[-3]
|
| 174 |
+
img_wd = tf.cast(inputs_shape[-2], tf.float32)
|
| 175 |
+
width_factor = backend.random.uniform(
|
| 176 |
+
shape=[],
|
| 177 |
+
minval=(1.0 + self.width_lower),
|
| 178 |
+
maxval=(1.0 + self.width_upper),
|
| 179 |
+
seed=self.seed_generator,
|
| 180 |
+
)
|
| 181 |
+
adjusted_width = tf.cast(width_factor * img_wd, tf.int32)
|
| 182 |
+
adjusted_size = tf.stack([img_hd, adjusted_width])
|
| 183 |
+
output = tf.image.resize(
|
| 184 |
+
images=inputs,
|
| 185 |
+
size=adjusted_size,
|
| 186 |
+
method=self.interpolation,
|
| 187 |
+
)
|
| 188 |
+
# tf.resize will output float32 regardless of input type.
|
| 189 |
+
output = tf.cast(output, self.compute_dtype)
|
| 190 |
+
output_shape = inputs.shape.as_list()
|
| 191 |
+
output_shape[-2] = None
|
| 192 |
+
output.set_shape(output_shape)
|
| 193 |
+
return output
|
| 194 |
+
|
| 195 |
+
if training:
|
| 196 |
+
return random_width_inputs(inputs)
|
| 197 |
+
else:
|
| 198 |
+
return inputs
|
| 199 |
+
|
| 200 |
+
def compute_output_shape(self, input_shape):
|
| 201 |
+
input_shape = list(input_shape)
|
| 202 |
+
input_shape[-2] = None
|
| 203 |
+
return tuple(input_shape)
|
| 204 |
+
|
| 205 |
+
def get_config(self):
|
| 206 |
+
config = {
|
| 207 |
+
"factor": self.factor,
|
| 208 |
+
"interpolation": self.interpolation,
|
| 209 |
+
"seed": self.seed,
|
| 210 |
+
}
|
| 211 |
+
base_config = super().get_config()
|
| 212 |
+
return {**base_config, **config}
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
@keras_export("keras._legacy.layers.ThresholdedReLU")
|
| 216 |
+
class ThresholdedReLU(Layer):
|
| 217 |
+
"""DEPRECATED."""
|
| 218 |
+
|
| 219 |
+
def __init__(self, theta=1.0, **kwargs):
|
| 220 |
+
super().__init__(**kwargs)
|
| 221 |
+
if theta is None:
|
| 222 |
+
raise ValueError(
|
| 223 |
+
"Theta of a Thresholded ReLU layer cannot be None, expecting a "
|
| 224 |
+
f"float. Received: {theta}"
|
| 225 |
+
)
|
| 226 |
+
if theta < 0:
|
| 227 |
+
raise ValueError(
|
| 228 |
+
"The theta value of a Thresholded ReLU layer "
|
| 229 |
+
f"should be >=0. Received: {theta}"
|
| 230 |
+
)
|
| 231 |
+
self.supports_masking = True
|
| 232 |
+
self.theta = tf.convert_to_tensor(theta, dtype=self.compute_dtype)
|
| 233 |
+
|
| 234 |
+
def call(self, inputs):
|
| 235 |
+
dtype = self.compute_dtype
|
| 236 |
+
return inputs * tf.cast(tf.greater(inputs, self.theta), dtype)
|
| 237 |
+
|
| 238 |
+
def get_config(self):
|
| 239 |
+
config = {"theta": float(self.theta)}
|
| 240 |
+
base_config = super().get_config()
|
| 241 |
+
return {**base_config, **config}
|
| 242 |
+
|
| 243 |
+
def compute_output_shape(self, input_shape):
|
| 244 |
+
return input_shape
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/losses.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src.api_export import keras_export
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@keras_export("keras._legacy.losses.Reduction")
|
| 5 |
+
class Reduction:
|
| 6 |
+
AUTO = "auto"
|
| 7 |
+
NONE = "none"
|
| 8 |
+
SUM = "sum"
|
| 9 |
+
SUM_OVER_BATCH_SIZE = "sum_over_batch_size"
|
| 10 |
+
|
| 11 |
+
@classmethod
|
| 12 |
+
def all(cls):
|
| 13 |
+
return (cls.AUTO, cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)
|
| 14 |
+
|
| 15 |
+
@classmethod
|
| 16 |
+
def validate(cls, key):
|
| 17 |
+
if key not in cls.all():
|
| 18 |
+
raise ValueError(
|
| 19 |
+
f'Invalid Reduction Key: {key}. Expected keys are "{cls.all()}"'
|
| 20 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/image.py
ADDED
|
@@ -0,0 +1,1892 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deprecated image preprocessing APIs from Keras 1."""
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
import multiprocessing
|
| 5 |
+
import os
|
| 6 |
+
import threading
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from keras.src import backend
|
| 12 |
+
from keras.src.api_export import keras_export
|
| 13 |
+
from keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset
|
| 14 |
+
from keras.src.utils import image_utils
|
| 15 |
+
from keras.src.utils import io_utils
|
| 16 |
+
from keras.src.utils.module_utils import scipy
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@keras_export("keras._legacy.preprocessing.image.Iterator")
class Iterator(PyDataset):
    """Base class for image data iterators.

    DEPRECATED.

    Every `Iterator` must implement the `_get_batches_of_transformed_samples`
    method.

    Args:
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
    """

    white_list_formats = ("png", "jpg", "jpeg", "bmp", "ppm", "tif", "tiff")

    def __init__(self, n, batch_size, shuffle, seed):
        self.n = n
        self.batch_size = batch_size
        self.seed = seed
        self.shuffle = shuffle
        # Index of the next batch within the current epoch.
        self.batch_index = 0
        # Total batches served across all epochs; added to `seed` when
        # reseeding numpy so shuffles differ per batch yet stay reproducible.
        self.total_batches_seen = 0
        # Guards `index_generator`, which is shared between threads when the
        # iterator is consumed via `__next__`.
        self.lock = threading.Lock()
        self.index_array = None
        self.index_generator = self._flow_index()

    def _set_index_array(self):
        # Identity ordering by default; replaced with a fresh random
        # permutation when shuffling is enabled.
        self.index_array = np.arange(self.n)
        if self.shuffle:
            self.index_array = np.random.permutation(self.n)

    def __getitem__(self, idx):
        if idx >= len(self):
            raise ValueError(
                "Asked to retrieve element {idx}, "
                "but the Sequence "
                "has length {length}".format(idx=idx, length=len(self))
            )
        # Reseed the global numpy RNG so random transforms downstream are
        # reproducible for a given (seed, batch counter) pair.
        if self.seed is not None:
            np.random.seed(self.seed + self.total_batches_seen)
        self.total_batches_seen += 1
        if self.index_array is None:
            self._set_index_array()
        index_array = self.index_array[
            self.batch_size * idx : self.batch_size * (idx + 1)
        ]
        return self._get_batches_of_transformed_samples(index_array)

    def __len__(self):
        return (self.n + self.batch_size - 1) // self.batch_size  # round up

    def on_epoch_end(self):
        self._set_index_array()

    def reset(self):
        self.batch_index = 0

    def _flow_index(self):
        """Infinite generator of index slices, one per batch."""
        # Ensure self.batch_index is 0.
        self.reset()
        while 1:
            if self.seed is not None:
                np.random.seed(self.seed + self.total_batches_seen)
            # A new epoch starts whenever batch_index wraps back to 0.
            if self.batch_index == 0:
                self._set_index_array()

            if self.n == 0:
                # Avoiding modulo by zero error
                current_index = 0
            else:
                current_index = (self.batch_index * self.batch_size) % self.n
            if self.n > current_index + self.batch_size:
                self.batch_index += 1
            else:
                # Last (possibly short) batch of the epoch.
                self.batch_index = 0
            self.total_batches_seen += 1
            yield self.index_array[
                current_index : current_index + self.batch_size
            ]

    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        return self

    def __next__(self):
        # Only the index bookkeeping is serialized.
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)

    def _get_batches_of_transformed_samples(self, index_array):
        """Gets a batch of transformed samples.

        Args:
            index_array: Array of sample indices to include in batch.
        Returns:
            A batch of transformed samples.
        """
        raise NotImplementedError
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _iter_valid_files(directory, white_list_formats, follow_links):
    """Yield (root, filename) pairs for files with an allowed extension.

    Args:
        directory: Absolute path to the directory to scan recursively.
        white_list_formats: Tuple of allowed (lowercase) file extensions,
            matched via `str.endswith`.
        follow_links: Boolean, whether to follow symlinked subdirectories.

    Yields:
        Tuples of (root, filename) whose extension is in
        `white_list_formats`.
    """
    # Walk once, sorted by directory path so output order is deterministic.
    walk_entries = sorted(
        os.walk(directory, followlinks=follow_links),
        key=lambda entry: entry[0],
    )
    for root, _, files in walk_entries:
        for fname in sorted(files):
            lowered = fname.lower()
            if lowered.endswith(".tiff"):
                warnings.warn(
                    'Using ".tiff" files with multiple bands '
                    "will cause distortion. Please verify your output."
                )
            if lowered.endswith(white_list_formats):
                yield root, fname
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _list_valid_filenames_in_directory(
    directory, white_list_formats, split, class_indices, follow_links
):
    """List valid files in `directory` together with their class indices.

    Args:
        directory: absolute path to a directory containing the files to
            list. The directory name is used as class label and must be a
            key of `class_indices`.
        white_list_formats: set of strings containing allowed extensions
            for the files to be counted.
        split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into
            account a certain fraction of files in each directory.
            E.g.: `segment=(0.6, 1.0)` would only account for last 40
            percent of images in each directory.
        class_indices: dictionary mapping a class name to its index.
        follow_links: boolean, follow symbolic links to subdirectories.

    Returns:
        classes: a list of class indices
        filenames: the path of valid files in `directory`, relative from
            `directory`'s parent (e.g., if `directory` is "dataset/class1",
            the filenames will be
            `["class1/file1.jpg", "class1/file2.jpg", ...]`).
    """
    dirname = os.path.basename(directory)
    if split:
        # Materialize so the fractional slice can be computed.
        all_files = list(
            _iter_valid_files(directory, white_list_formats, follow_links)
        )
        total = len(all_files)
        valid_files = all_files[int(split[0] * total) : int(split[1] * total)]
    else:
        valid_files = _iter_valid_files(
            directory, white_list_formats, follow_links
        )
    classes = []
    filenames = []
    for root, fname in valid_files:
        # Lookup stays inside the loop: an unknown dirname only raises when
        # the directory actually contains valid files (matches original).
        classes.append(class_indices[dirname])
        full_path = os.path.join(root, fname)
        filenames.append(
            os.path.join(dirname, os.path.relpath(full_path, directory))
        )
    return classes, filenames
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class BatchFromFilesMixin:
    """Adds methods related to getting batches from filenames.

    It includes the logic to transform image files to batches.
    """

    def set_processing_attrs(
        self,
        image_data_generator,
        target_size,
        color_mode,
        data_format,
        save_to_dir,
        save_prefix,
        save_format,
        subset,
        interpolation,
        keep_aspect_ratio,
    ):
        """Sets attributes to use later for processing files into a batch.

        Args:
            image_data_generator: Instance of `ImageDataGenerator`
                to use for random transformations and normalization.
            target_size: tuple of integers, dimensions to resize input images
                to.
            color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
                Color mode to read images.
            data_format: String, one of `channels_first`, `channels_last`.
            save_to_dir: Optional directory where to save the pictures
                being yielded, in a viewable format. This is useful
                for visualizing the random transformations being
                applied, for debugging purposes.
            save_prefix: String prefix to use for saving sample
                images (if `save_to_dir` is set).
            save_format: Format to use for saving sample images
                (if `save_to_dir` is set).
            subset: Subset of data (`"training"` or `"validation"`) if
                validation_split is set in ImageDataGenerator.
            interpolation: Interpolation method used to resample the image if
                the target size is different from that of the loaded image.
                Supported methods are "nearest", "bilinear", and "bicubic". If
                PIL version 1.1.3 or newer is installed, "lanczos" is also
                supported. If PIL version 3.4.0 or newer is installed, "box" and
                "hamming" are also supported. By default, "nearest" is used.
            keep_aspect_ratio: Boolean, whether to resize images to a target
                size without aspect ratio distortion. The image is cropped in
                the center with target aspect ratio before resizing.
        """
        self.image_data_generator = image_data_generator
        self.target_size = tuple(target_size)
        self.keep_aspect_ratio = keep_aspect_ratio
        if color_mode not in {"rgb", "rgba", "grayscale"}:
            raise ValueError(
                f"Invalid color mode: {color_mode}"
                '; expected "rgb", "rgba", or "grayscale".'
            )
        self.color_mode = color_mode
        self.data_format = data_format
        # Channel count follows the color mode (rgba=4, rgb=3, grayscale=1);
        # its axis position follows the data format.
        if self.color_mode == "rgba":
            if self.data_format == "channels_last":
                self.image_shape = self.target_size + (4,)
            else:
                self.image_shape = (4,) + self.target_size
        elif self.color_mode == "rgb":
            if self.data_format == "channels_last":
                self.image_shape = self.target_size + (3,)
            else:
                self.image_shape = (3,) + self.target_size
        else:
            if self.data_format == "channels_last":
                self.image_shape = self.target_size + (1,)
            else:
                self.image_shape = (1,) + self.target_size
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        self.interpolation = interpolation
        if subset is not None:
            validation_split = self.image_data_generator._validation_split
            # Validation takes the first `validation_split` fraction of the
            # files; training takes the remainder.
            if subset == "validation":
                split = (0, validation_split)
            elif subset == "training":
                split = (validation_split, 1)
            else:
                raise ValueError(
                    f"Invalid subset name: {subset};"
                    'expected "training" or "validation"'
                )
        else:
            split = None
        self.split = split
        self.subset = subset

    def _get_batches_of_transformed_samples(self, index_array):
        """Gets a batch of transformed samples.

        Args:
            index_array: Array of sample indices to include in batch.
        Returns:
            A batch of transformed samples.
        """
        # NOTE(review): relies on attributes supplied by the concrete
        # subclass — `filepaths`, `classes`, `class_indices`, `labels`,
        # `sample_weight`, `dtype`, `class_mode`.
        batch_x = np.zeros(
            (len(index_array),) + self.image_shape, dtype=self.dtype
        )
        # build batch of image data
        # self.filepaths is dynamic, is better to call it once outside the loop
        filepaths = self.filepaths
        for i, j in enumerate(index_array):
            img = image_utils.load_img(
                filepaths[j],
                color_mode=self.color_mode,
                target_size=self.target_size,
                interpolation=self.interpolation,
                keep_aspect_ratio=self.keep_aspect_ratio,
            )
            x = image_utils.img_to_array(img, data_format=self.data_format)
            # Pillow images should be closed after `load_img`,
            # but not PIL images.
            if hasattr(img, "close"):
                img.close()
            if self.image_data_generator:
                params = self.image_data_generator.get_random_transform(x.shape)
                x = self.image_data_generator.apply_transform(x, params)
                x = self.image_data_generator.standardize(x)
            batch_x[i] = x
        # optionally save augmented images to disk for debugging purposes
        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = image_utils.array_to_img(
                    batch_x[i], self.data_format, scale=True
                )
                fname = "{prefix}_{index}_{hash}.{format}".format(
                    prefix=self.save_prefix,
                    index=j,
                    hash=np.random.randint(1e7),
                    format=self.save_format,
                )
                img.save(os.path.join(self.save_to_dir, fname))
        # build batch of labels
        if self.class_mode == "input":
            batch_y = batch_x.copy()
        elif self.class_mode in {"binary", "sparse"}:
            batch_y = np.empty(len(batch_x), dtype=self.dtype)
            for i, n_observation in enumerate(index_array):
                batch_y[i] = self.classes[n_observation]
        elif self.class_mode == "categorical":
            # One-hot encode against the full class index set.
            batch_y = np.zeros(
                (len(batch_x), len(self.class_indices)), dtype=self.dtype
            )
            for i, n_observation in enumerate(index_array):
                batch_y[i, self.classes[n_observation]] = 1.0
        elif self.class_mode == "multi_output":
            batch_y = [output[index_array] for output in self.labels]
        elif self.class_mode == "raw":
            batch_y = self.labels[index_array]
        else:
            # class_mode is None: inputs only, no targets.
            return batch_x
        if self.sample_weight is None:
            return batch_x, batch_y
        else:
            return batch_x, batch_y, self.sample_weight[index_array]

    @property
    def filepaths(self):
        """List of absolute paths to image files."""
        raise NotImplementedError(
            "`filepaths` property method has not "
            "been implemented in {}.".format(type(self).__name__)
        )

    @property
    def labels(self):
        """Class labels of every observation."""
        raise NotImplementedError(
            "`labels` property method has not been implemented in {}.".format(
                type(self).__name__
            )
        )

    @property
    def sample_weight(self):
        # Per-sample weights; concrete subclasses must override.
        raise NotImplementedError(
            "`sample_weight` property method has not "
            "been implemented in {}.".format(type(self).__name__)
        )
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
@keras_export("keras._legacy.preprocessing.image.DirectoryIterator")
class DirectoryIterator(BatchFromFilesMixin, Iterator):
    """Iterator capable of reading images from a directory on disk.

    DEPRECATED.
    """

    allowed_class_modes = {"categorical", "binary", "sparse", "input", None}

    def __init__(
        self,
        directory,
        image_data_generator,
        target_size=(256, 256),
        color_mode="rgb",
        classes=None,
        class_mode="categorical",
        batch_size=32,
        shuffle=True,
        seed=None,
        data_format=None,
        save_to_dir=None,
        save_prefix="",
        save_format="png",
        follow_links=False,
        subset=None,
        interpolation="nearest",
        keep_aspect_ratio=False,
        dtype=None,
    ):
        if data_format is None:
            data_format = backend.image_data_format()
        if dtype is None:
            dtype = backend.floatx()
        super().set_processing_attrs(
            image_data_generator,
            target_size,
            color_mode,
            data_format,
            save_to_dir,
            save_prefix,
            save_format,
            subset,
            interpolation,
            keep_aspect_ratio,
        )
        self.directory = directory
        self.classes = classes
        if class_mode not in self.allowed_class_modes:
            raise ValueError(
                "Invalid class_mode: {}; expected one of: {}".format(
                    class_mode, self.allowed_class_modes
                )
            )
        self.class_mode = class_mode
        self.dtype = dtype
        # First, count the number of samples and classes.
        self.samples = 0

        # Infer class names from sorted subdirectory names when not given.
        if not classes:
            classes = []
            for subdir in sorted(os.listdir(directory)):
                if os.path.isdir(os.path.join(directory, subdir)):
                    classes.append(subdir)
        self.num_classes = len(classes)
        self.class_indices = dict(zip(classes, range(len(classes))))

        # Thread pool used to scan the class subdirectories concurrently.
        pool = multiprocessing.pool.ThreadPool()

        # Second, build an index of the images
        # in the different class subfolders.
        results = []
        self.filenames = []
        i = 0
        for dirpath in (os.path.join(directory, subdir) for subdir in classes):
            results.append(
                pool.apply_async(
                    _list_valid_filenames_in_directory,
                    (
                        dirpath,
                        self.white_list_formats,
                        self.split,
                        self.class_indices,
                        follow_links,
                    ),
                )
            )
        classes_list = []
        # Collect results in submission order so labels stay aligned with
        # filenames regardless of thread completion order.
        for res in results:
            classes, filenames = res.get()
            classes_list.append(classes)
            self.filenames += filenames
        self.samples = len(self.filenames)
        self.classes = np.zeros((self.samples,), dtype="int32")
        for classes in classes_list:
            self.classes[i : i + len(classes)] = classes
            i += len(classes)

        io_utils.print_msg(
            f"Found {self.samples} images belonging to "
            f"{self.num_classes} classes."
        )
        pool.close()
        pool.join()
        self._filepaths = [
            os.path.join(self.directory, fname) for fname in self.filenames
        ]
        super().__init__(self.samples, batch_size, shuffle, seed)

    @property
    def filepaths(self):
        return self._filepaths

    @property
    def labels(self):
        return self.classes

    @property  # mixin needs this property to work
    def sample_weight(self):
        # no sample weights will be returned
        return None
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
@keras_export("keras._legacy.preprocessing.image.NumpyArrayIterator")
class NumpyArrayIterator(Iterator):
    """Iterator yielding data from a Numpy array.

    DEPRECATED.
    """

    def __init__(
        self,
        x,
        y,
        image_data_generator,
        batch_size=32,
        shuffle=False,
        sample_weight=None,
        seed=None,
        data_format=None,
        save_to_dir=None,
        save_prefix="",
        save_format="png",
        subset=None,
        ignore_class_split=False,
        dtype=None,
    ):
        if data_format is None:
            data_format = backend.image_data_format()
        if dtype is None:
            dtype = backend.floatx()
        self.dtype = dtype
        # `x` may be `(images, misc_inputs)` where `misc_inputs` is one
        # array or a list of arrays carried alongside the image batch.
        if isinstance(x, tuple) or isinstance(x, list):
            if not isinstance(x[1], list):
                x_misc = [np.asarray(x[1])]
            else:
                x_misc = [np.asarray(xx) for xx in x[1]]
            x = x[0]
            for xx in x_misc:
                if len(x) != len(xx):
                    raise ValueError(
                        "All of the arrays in `x` "
                        "should have the same length. "
                        "Found a pair with: "
                        f"len(x[0]) = {len(x)}, len(x[?]) = {len(xx)}"
                    )
        else:
            x_misc = []

        if y is not None and len(x) != len(y):
            raise ValueError(
                "`x` (images tensor) and `y` (labels) "
                "should have the same length. "
                f"Found: x.shape = {np.asarray(x).shape}, "
                f"y.shape = {np.asarray(y).shape}"
            )
        if sample_weight is not None and len(x) != len(sample_weight):
            raise ValueError(
                "`x` (images tensor) and `sample_weight` "
                "should have the same length. "
                f"Found: x.shape = {np.asarray(x).shape}, "
                f"sample_weight.shape = {np.asarray(sample_weight).shape}"
            )
        if subset is not None:
            if subset not in {"training", "validation"}:
                raise ValueError(
                    f"Invalid subset name: {subset}"
                    '; expected "training" or "validation".'
                )
            split_idx = int(len(x) * image_data_generator._validation_split)

            # Guard against label-sorted arrays: both subsets must contain
            # the same set of classes after the split.
            if (
                y is not None
                and not ignore_class_split
                and not np.array_equal(
                    np.unique(y[:split_idx]), np.unique(y[split_idx:])
                )
            ):
                raise ValueError(
                    "Training and validation subsets "
                    "have different number of classes after "
                    "the split. If your numpy arrays are "
                    "sorted by the label, you might want "
                    "to shuffle them."
                )

            # Validation takes the leading slice; training the remainder.
            if subset == "validation":
                x = x[:split_idx]
                x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc]
                if y is not None:
                    y = y[:split_idx]
            else:
                x = x[split_idx:]
                x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc]
                if y is not None:
                    y = y[split_idx:]

        self.x = np.asarray(x, dtype=self.dtype)
        self.x_misc = x_misc
        if self.x.ndim != 4:
            raise ValueError(
                "Input data in `NumpyArrayIterator` "
                "should have rank 4. You passed an array "
                f"with shape {self.x.shape}"
            )
        channels_axis = 3 if data_format == "channels_last" else 1
        if self.x.shape[channels_axis] not in {1, 3, 4}:
            warnings.warn(
                'NumpyArrayIterator is set to use the data format convention "'
                + data_format
                + '" (channels on axis '
                + str(channels_axis)
                + "), i.e. expected either 1, 3, or 4 channels on axis "
                + str(channels_axis)
                + ". However, it was passed an array with shape "
                + str(self.x.shape)
                + " ("
                + str(self.x.shape[channels_axis])
                + " channels)."
            )
        if y is not None:
            self.y = np.asarray(y)
        else:
            self.y = None
        if sample_weight is not None:
            self.sample_weight = np.asarray(sample_weight)
        else:
            self.sample_weight = None
        self.image_data_generator = image_data_generator
        self.data_format = data_format
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        super().__init__(x.shape[0], batch_size, shuffle, seed)

    def _get_batches_of_transformed_samples(self, index_array):
        """Transform and return the samples selected by `index_array`."""
        batch_x = np.zeros(
            tuple([len(index_array)] + list(self.x.shape)[1:]), dtype=self.dtype
        )
        for i, j in enumerate(index_array):
            x = self.x[j]
            params = self.image_data_generator.get_random_transform(x.shape)
            x = self.image_data_generator.apply_transform(
                x.astype(self.dtype), params
            )
            x = self.image_data_generator.standardize(x)
            batch_x[i] = x

        # optionally save augmented images to disk for debugging purposes
        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = image_utils.array_to_img(
                    batch_x[i], self.data_format, scale=True
                )
                fname = "{prefix}_{index}_{hash}.{format}".format(
                    prefix=self.save_prefix,
                    index=j,
                    hash=np.random.randint(1e4),
                    format=self.save_format,
                )
                img.save(os.path.join(self.save_to_dir, fname))
        batch_x_miscs = [xx[index_array] for xx in self.x_misc]
        output = (batch_x if not batch_x_miscs else [batch_x] + batch_x_miscs,)
        if self.y is None:
            # Inputs only; unwrap the single-element tuple.
            return output[0]
        output += (self.y[index_array],)
        if self.sample_weight is not None:
            output += (self.sample_weight[index_array],)
        return output
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
def validate_filename(filename, white_list_formats):
    """Return True if `filename` exists and has an allowed extension.

    Args:
        filename: String, absolute path to a file.
        white_list_formats: Tuple of allowed file extensions (lowercase),
            matched via `str.endswith`.

    Returns:
        A boolean value indicating if the filename is valid or not.
    """
    # Check the cheap extension test first; only hit the filesystem when it
    # passes (same short-circuit order as before).
    has_valid_extension = filename.lower().endswith(white_list_formats)
    return has_valid_extension and os.path.isfile(filename)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
class DataFrameIterator(BatchFromFilesMixin, Iterator):
|
| 697 |
+
"""Iterator capable of reading images from a directory as a dataframe."""
|
| 698 |
+
|
| 699 |
+
allowed_class_modes = {
|
| 700 |
+
"binary",
|
| 701 |
+
"categorical",
|
| 702 |
+
"input",
|
| 703 |
+
"multi_output",
|
| 704 |
+
"raw",
|
| 705 |
+
"sparse",
|
| 706 |
+
None,
|
| 707 |
+
}
|
| 708 |
+
|
| 709 |
+
def __init__(
|
| 710 |
+
self,
|
| 711 |
+
dataframe,
|
| 712 |
+
directory=None,
|
| 713 |
+
image_data_generator=None,
|
| 714 |
+
x_col="filename",
|
| 715 |
+
y_col="class",
|
| 716 |
+
weight_col=None,
|
| 717 |
+
target_size=(256, 256),
|
| 718 |
+
color_mode="rgb",
|
| 719 |
+
classes=None,
|
| 720 |
+
class_mode="categorical",
|
| 721 |
+
batch_size=32,
|
| 722 |
+
shuffle=True,
|
| 723 |
+
seed=None,
|
| 724 |
+
data_format="channels_last",
|
| 725 |
+
save_to_dir=None,
|
| 726 |
+
save_prefix="",
|
| 727 |
+
save_format="png",
|
| 728 |
+
subset=None,
|
| 729 |
+
interpolation="nearest",
|
| 730 |
+
keep_aspect_ratio=False,
|
| 731 |
+
dtype="float32",
|
| 732 |
+
validate_filenames=True,
|
| 733 |
+
):
|
| 734 |
+
super().set_processing_attrs(
|
| 735 |
+
image_data_generator,
|
| 736 |
+
target_size,
|
| 737 |
+
color_mode,
|
| 738 |
+
data_format,
|
| 739 |
+
save_to_dir,
|
| 740 |
+
save_prefix,
|
| 741 |
+
save_format,
|
| 742 |
+
subset,
|
| 743 |
+
interpolation,
|
| 744 |
+
keep_aspect_ratio,
|
| 745 |
+
)
|
| 746 |
+
df = dataframe.copy()
|
| 747 |
+
self.directory = directory or ""
|
| 748 |
+
self.class_mode = class_mode
|
| 749 |
+
self.dtype = dtype
|
| 750 |
+
# check that inputs match the required class_mode
|
| 751 |
+
self._check_params(df, x_col, y_col, weight_col, classes)
|
| 752 |
+
if (
|
| 753 |
+
validate_filenames
|
| 754 |
+
): # check which image files are valid and keep them
|
| 755 |
+
df = self._filter_valid_filepaths(df, x_col)
|
| 756 |
+
if class_mode not in ["input", "multi_output", "raw", None]:
|
| 757 |
+
df, classes = self._filter_classes(df, y_col, classes)
|
| 758 |
+
num_classes = len(classes)
|
| 759 |
+
# build an index of all the unique classes
|
| 760 |
+
self.class_indices = dict(zip(classes, range(len(classes))))
|
| 761 |
+
# retrieve only training or validation set
|
| 762 |
+
if self.split:
|
| 763 |
+
num_files = len(df)
|
| 764 |
+
start = int(self.split[0] * num_files)
|
| 765 |
+
stop = int(self.split[1] * num_files)
|
| 766 |
+
df = df.iloc[start:stop, :]
|
| 767 |
+
# get labels for each observation
|
| 768 |
+
if class_mode not in ["input", "multi_output", "raw", None]:
|
| 769 |
+
self.classes = self.get_classes(df, y_col)
|
| 770 |
+
self.filenames = df[x_col].tolist()
|
| 771 |
+
self._sample_weight = df[weight_col].values if weight_col else None
|
| 772 |
+
|
| 773 |
+
if class_mode == "multi_output":
|
| 774 |
+
self._targets = [np.array(df[col].tolist()) for col in y_col]
|
| 775 |
+
if class_mode == "raw":
|
| 776 |
+
self._targets = df[y_col].values
|
| 777 |
+
self.samples = len(self.filenames)
|
| 778 |
+
validated_string = (
|
| 779 |
+
"validated" if validate_filenames else "non-validated"
|
| 780 |
+
)
|
| 781 |
+
if class_mode in ["input", "multi_output", "raw", None]:
|
| 782 |
+
io_utils.print_msg(
|
| 783 |
+
f"Found {self.samples} {validated_string} image filenames."
|
| 784 |
+
)
|
| 785 |
+
else:
|
| 786 |
+
io_utils.print_msg(
|
| 787 |
+
f"Found {self.samples} {validated_string} image filenames "
|
| 788 |
+
f"belonging to {num_classes} classes."
|
| 789 |
+
)
|
| 790 |
+
self._filepaths = [
|
| 791 |
+
os.path.join(self.directory, fname) for fname in self.filenames
|
| 792 |
+
]
|
| 793 |
+
super().__init__(self.samples, batch_size, shuffle, seed)
|
| 794 |
+
|
| 795 |
+
def _check_params(self, df, x_col, y_col, weight_col, classes):
    """Validate the dataframe columns against the iterator configuration.

    Raises `ValueError`/`TypeError` when the class mode, label column,
    filename column, or weight column are inconsistent with the data.
    """
    # Only the modes declared on the class are supported.
    if self.class_mode not in self.allowed_class_modes:
        raise ValueError(
            "Invalid class_mode: {}; expected one of: {}".format(
                self.class_mode, self.allowed_class_modes
            )
        )
    # multi_output requires one target column per output, i.e. a list.
    if self.class_mode == "multi_output" and not isinstance(y_col, list):
        raise TypeError(
            'If class_mode="{}", y_col must be a list. Received {}.'.format(
                self.class_mode, type(y_col).__name__
            )
        )
    # Filenames / filepaths must all be strings.
    if not all(df[x_col].apply(lambda value: isinstance(value, str))):
        raise TypeError(
            f"All values in column x_col={x_col} must be strings."
        )
    # binary / sparse labels must be strings as well.
    if self.class_mode in {"binary", "sparse"}:
        if not all(df[y_col].apply(lambda value: isinstance(value, str))):
            raise TypeError(
                'If class_mode="{}", y_col="{}" column '
                "values must be strings.".format(self.class_mode, y_col)
            )
    # binary mode needs exactly two distinct classes.
    if self.class_mode == "binary":
        if classes:
            classes = set(classes)
            if len(classes) != 2:
                raise ValueError(
                    'If class_mode="binary" there must be 2 '
                    "classes. {} class/es were given.".format(len(classes))
                )
        elif df[y_col].nunique() != 2:
            raise ValueError(
                'If class_mode="binary" there must be 2 classes. '
                "Found {} classes.".format(df[y_col].nunique())
            )
    # categorical labels may be a string or a list/tuple of strings.
    if self.class_mode == "categorical":
        allowed_types = (str, list, tuple)
        if not all(
            df[y_col].apply(lambda value: isinstance(value, allowed_types))
        ):
            raise TypeError(
                'If class_mode="{}", y_col="{}" column '
                "values must be type string, list or tuple.".format(
                    self.class_mode, y_col
                )
            )
    # `classes` is ignored for these modes; warn so the caller knows.
    if classes and self.class_mode in {
        "input",
        "multi_output",
        "raw",
        None,
    }:
        warnings.warn(
            '`classes` will be ignored given the class_mode="{}"'.format(
                self.class_mode
            )
        )
    # Sample weights must be numeric.
    if weight_col and not issubclass(df[weight_col].dtype.type, np.number):
        raise TypeError(f"Column weight_col={weight_col} must be numeric.")
|
| 862 |
+
|
| 863 |
+
def get_classes(self, df, y_col):
    """Map the label column to integer class indices.

    A scalar label becomes one index; a list/tuple label (multi-label
    case) becomes a list of indices, using `self.class_indices`.
    """
    index_of = self.class_indices
    labels = []
    for raw in df[y_col]:
        if isinstance(raw, (list, tuple)):
            labels.append([index_of[item] for item in raw])
        else:
            labels.append(index_of[raw])
    return labels
|
| 871 |
+
|
| 872 |
+
@staticmethod
def _filter_classes(df, y_col, classes):
    """Restrict rows of `df` to labels contained in `classes`.

    When `classes` is falsy, the class list is inferred from the data
    instead (sorted union of every label seen). Returns the filtered
    dataframe and the class list actually used.
    """
    df = df.copy()

    def remove_classes(labels, classes):
        # Unknown labels are replaced by None so dropna removes the row.
        if isinstance(labels, (list, tuple)):
            kept = [cls for cls in labels if cls in classes]
            return kept or None
        if isinstance(labels, str):
            return labels if labels in classes else None
        raise TypeError(
            "Expect string, list or tuple "
            "but found {} in {} column ".format(type(labels), y_col)
        )

    if classes:
        # De-duplicate while preserving the caller's ordering.
        classes = list(collections.OrderedDict.fromkeys(classes).keys())
        df[y_col] = df[y_col].apply(lambda x: remove_classes(x, classes))
    else:
        # No explicit classes: collect every label present in the data.
        classes = set()
        for value in df[y_col]:
            if isinstance(value, (list, tuple)):
                classes.update(value)
            else:
                classes.add(value)
        classes = sorted(classes)
    return df.dropna(subset=[y_col]), classes
|
| 901 |
+
|
| 902 |
+
def _filter_valid_filepaths(self, df, x_col):
    """Keep only dataframe rows whose filenames look like valid images.

    Args:
        df: Pandas dataframe containing filenames in a column.
        x_col: string, column in `df` that contains the filenames or
            filepaths.
    Returns:
        `df` restricted to rows whose joined path passes
        `validate_filename` against `self.white_list_formats`.
    """
    full_paths = df[x_col].map(
        lambda fname: os.path.join(self.directory, fname)
    )
    # validate_filename checks the extension against the whitelist.
    valid = full_paths.apply(
        lambda path: validate_filename(path, self.white_list_formats)
    )
    n_invalid = (~valid).sum()
    if n_invalid:
        warnings.warn(
            'Found {} invalid image filename(s) in x_col="{}". '
            "These filename(s) will be ignored.".format(n_invalid, x_col)
        )
    return df[valid]
|
| 925 |
+
|
| 926 |
+
@property
def filepaths(self):
    """Directory-joined absolute path for every kept filename."""
    return self._filepaths
|
| 929 |
+
|
| 930 |
+
@property
def labels(self):
    """Raw targets for raw/multi_output modes, class indices otherwise."""
    if self.class_mode in {"multi_output", "raw"}:
        return self._targets
    return self.classes
|
| 936 |
+
|
| 937 |
+
@property
def sample_weight(self):
    """Per-sample weights taken from `weight_col`, or None if unset."""
    return self._sample_weight
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
def flip_axis(x, axis):
    """Return `x` reversed along `axis`."""
    # Bring the target axis to the front, reverse it, then restore order.
    arr = np.asarray(x).swapaxes(axis, 0)
    return arr[::-1, ...].swapaxes(0, axis)
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
@keras_export("keras._legacy.preprocessing.image.ImageDataGenerator")
|
| 950 |
+
class ImageDataGenerator:
|
| 951 |
+
"""DEPRECATED."""
|
| 952 |
+
|
| 953 |
+
def __init__(
    self,
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    zca_epsilon=1e-6,
    rotation_range=0,
    width_shift_range=0.0,
    height_shift_range=0.0,
    brightness_range=None,
    shear_range=0.0,
    zoom_range=0.0,
    channel_shift_range=0.0,
    fill_mode="nearest",
    cval=0.0,
    horizontal_flip=False,
    vertical_flip=False,
    rescale=None,
    preprocessing_function=None,
    data_format=None,
    validation_split=0.0,
    interpolation_order=1,
    dtype=None,
):
    """Configure the (deprecated) augmentation/normalization pipeline.

    Validates `data_format`, `validation_split`, `zoom_range` and
    `brightness_range`, and reconciles mutually dependent flags
    (e.g. `zca_whitening` forces `featurewise_center` on and
    `featurewise_std_normalization` off).
    """
    # Resolve backend-dependent defaults.
    if data_format is None:
        data_format = backend.image_data_format()
    if dtype is None:
        dtype = backend.floatx()

    self.featurewise_center = featurewise_center
    self.samplewise_center = samplewise_center
    self.featurewise_std_normalization = featurewise_std_normalization
    self.samplewise_std_normalization = samplewise_std_normalization
    self.zca_whitening = zca_whitening
    self.zca_epsilon = zca_epsilon
    self.rotation_range = rotation_range
    self.width_shift_range = width_shift_range
    self.height_shift_range = height_shift_range
    self.shear_range = shear_range
    self.zoom_range = zoom_range
    self.channel_shift_range = channel_shift_range
    self.fill_mode = fill_mode
    self.cval = cval
    self.horizontal_flip = horizontal_flip
    self.vertical_flip = vertical_flip
    self.rescale = rescale
    self.preprocessing_function = preprocessing_function
    self.dtype = dtype
    self.interpolation_order = interpolation_order

    if data_format not in {"channels_last", "channels_first"}:
        raise ValueError(
            '`data_format` should be `"channels_last"` '
            "(channel after row and column) or "
            '`"channels_first"` (channel before row and column). '
            f"Received: {data_format}"
        )
    self.data_format = data_format
    # Axis indices include the batch dimension at position 0.
    if data_format == "channels_first":
        self.channel_axis = 1
        self.row_axis = 2
        self.col_axis = 3
    if data_format == "channels_last":
        self.channel_axis = 3
        self.row_axis = 1
        self.col_axis = 2
    if validation_split and not 0 < validation_split < 1:
        raise ValueError(
            "`validation_split` must be strictly between 0 and 1. "
            f" Received: {validation_split}"
        )
    self._validation_split = validation_split

    # Fitted statistics; populated by `.fit()`.
    self.mean = None
    self.std = None
    self.zca_whitening_matrix = None

    # A scalar zoom z means the interval [1 - z, 1 + z].
    if isinstance(zoom_range, (float, int)):
        self.zoom_range = [1 - zoom_range, 1 + zoom_range]
    elif len(zoom_range) == 2 and all(
        isinstance(val, (float, int)) for val in zoom_range
    ):
        self.zoom_range = [zoom_range[0], zoom_range[1]]
    else:
        raise ValueError(
            "`zoom_range` should be a float or "
            "a tuple or list of two floats. "
            f"Received: {zoom_range}"
        )
    if zca_whitening:
        # ZCA requires centered data and replaces std normalization.
        if not featurewise_center:
            self.featurewise_center = True
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`zca_whitening`, which overrides "
                "setting of `featurewise_center`."
            )
        if featurewise_std_normalization:
            self.featurewise_std_normalization = False
            # Fixed: added the missing space before the backticked name.
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`zca_whitening` "
                "which overrides setting of "
                "`featurewise_std_normalization`."
            )
    if featurewise_std_normalization:
        if not featurewise_center:
            self.featurewise_center = True
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`featurewise_std_normalization`, "
                "which overrides setting of "
                "`featurewise_center`."
            )
    if samplewise_std_normalization:
        if not samplewise_center:
            self.samplewise_center = True
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`samplewise_std_normalization`, "
                "which overrides setting of "
                "`samplewise_center`."
            )
    if brightness_range is not None:
        if (
            not isinstance(brightness_range, (tuple, list))
            or len(brightness_range) != 2
        ):
            # Fixed: the message previously had an unbalanced backtick.
            raise ValueError(
                "`brightness_range` should be tuple or list of two "
                f"floats. Received: {brightness_range}"
            )
    self.brightness_range = brightness_range
|
| 1088 |
+
|
| 1089 |
+
def flow(
    self,
    x,
    y=None,
    batch_size=32,
    shuffle=True,
    sample_weight=None,
    seed=None,
    save_to_dir=None,
    save_prefix="",
    save_format="png",
    ignore_class_split=False,
    subset=None,
):
    """Build a `NumpyArrayIterator` over in-memory arrays.

    All options are forwarded unchanged; the generator's own
    `data_format` and `dtype` are injected.
    """
    return NumpyArrayIterator(
        x,
        y,
        self,
        batch_size=batch_size,
        shuffle=shuffle,
        sample_weight=sample_weight,
        seed=seed,
        data_format=self.data_format,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        ignore_class_split=ignore_class_split,
        subset=subset,
        dtype=self.dtype,
    )
|
| 1119 |
+
|
| 1120 |
+
def flow_from_directory(
    self,
    directory,
    target_size=(256, 256),
    color_mode="rgb",
    classes=None,
    class_mode="categorical",
    batch_size=32,
    shuffle=True,
    seed=None,
    save_to_dir=None,
    save_prefix="",
    save_format="png",
    follow_links=False,
    subset=None,
    interpolation="nearest",
    keep_aspect_ratio=False,
):
    """Build a `DirectoryIterator` reading class-per-subfolder images.

    All options are forwarded unchanged; the generator's own
    `data_format` and `dtype` are injected.
    """
    return DirectoryIterator(
        directory,
        self,
        target_size=target_size,
        color_mode=color_mode,
        keep_aspect_ratio=keep_aspect_ratio,
        classes=classes,
        class_mode=class_mode,
        data_format=self.data_format,
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        follow_links=follow_links,
        subset=subset,
        interpolation=interpolation,
        dtype=self.dtype,
    )
|
| 1158 |
+
|
| 1159 |
+
def flow_from_dataframe(
    self,
    dataframe,
    directory=None,
    x_col="filename",
    y_col="class",
    weight_col=None,
    target_size=(256, 256),
    color_mode="rgb",
    classes=None,
    class_mode="categorical",
    batch_size=32,
    shuffle=True,
    seed=None,
    save_to_dir=None,
    save_prefix="",
    save_format="png",
    subset=None,
    interpolation="nearest",
    validate_filenames=True,
    **kwargs,
):
    """Build a `DataFrameIterator` reading images listed in `dataframe`.

    Emits deprecation warnings for the removed `has_ext`, `sort` and
    `drop_duplicates` keyword arguments, and maps the legacy
    `class_mode="other"` to `"raw"` before constructing the iterator.
    """
    if "has_ext" in kwargs:
        warnings.warn(
            "has_ext is deprecated, filenames in the dataframe have "
            "to match the exact filenames in disk.",
            DeprecationWarning,
        )
    if "sort" in kwargs:
        # Fixed: the original message concatenated fragments without
        # spaces ("in thesame", "`shuffle`is").
        warnings.warn(
            "sort is deprecated, batches will be created in the "
            "same order than the filenames provided if `shuffle` "
            "is set to `False`.",
            DeprecationWarning,
        )
    if class_mode == "other":
        warnings.warn(
            '`class_mode="other"` is deprecated, please use '
            '`class_mode="raw"`.',
            DeprecationWarning,
        )
        class_mode = "raw"
    if "drop_duplicates" in kwargs:
        warnings.warn(
            "drop_duplicates is deprecated, you can drop duplicates "
            "by using the pandas.DataFrame.drop_duplicates method.",
            DeprecationWarning,
        )

    return DataFrameIterator(
        dataframe,
        directory,
        self,
        x_col=x_col,
        y_col=y_col,
        weight_col=weight_col,
        target_size=target_size,
        color_mode=color_mode,
        classes=classes,
        class_mode=class_mode,
        data_format=self.data_format,
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        subset=subset,
        interpolation=interpolation,
        validate_filenames=validate_filenames,
        dtype=self.dtype,
    )
|
| 1231 |
+
|
| 1232 |
+
def standardize(self, x):
    """Apply the normalization configuration in-place to a batch.

    `x` is mutated in place because this runs on every batch; call
    `standardize(np.copy(x))` to keep the original untouched.

    Args:
        x: Batch of inputs to be normalized.

    Returns:
        The inputs, normalized.
    """
    if self.preprocessing_function:
        x = self.preprocessing_function(x)
    if self.rescale:
        x *= self.rescale
    if self.samplewise_center:
        x -= np.mean(x, keepdims=True)
    if self.samplewise_std_normalization:
        x /= np.std(x, keepdims=True) + 1e-6

    if self.featurewise_center:
        if self.mean is not None:
            x -= self.mean
        else:
            warnings.warn(
                "This ImageDataGenerator specifies `featurewise_center`, "
                "but it hasn't been fit on any training data. Fit it "
                "first by calling `.fit(numpy_data)`."
            )
    if self.featurewise_std_normalization:
        if self.std is not None:
            x /= self.std + 1e-6
        else:
            warnings.warn(
                "This ImageDataGenerator specifies "
                "`featurewise_std_normalization`, but it hasn't been fit "
                "on any training data. Fit it "
                "first by calling `.fit(numpy_data)`."
            )
    if self.zca_whitening:
        if self.zca_whitening_matrix is not None:
            # Flatten each sample, whiten, then restore the batch shape.
            flat = x.reshape(-1, np.prod(x.shape[-3:]))
            x = np.reshape(flat @ self.zca_whitening_matrix, x.shape)
        else:
            warnings.warn(
                "This ImageDataGenerator specifies `zca_whitening`, "
                "but it hasn't been fit on any training data. Fit it "
                "first by calling `.fit(numpy_data)`."
            )
    return x
|
| 1293 |
+
|
| 1294 |
+
def get_random_transform(self, img_shape, seed=None):
    """Generate random parameters for one transformation.

    The order of `np.random` draws matches the original so seeded
    sequences are reproducible.

    Args:
        img_shape: Tuple of integers, shape of the image to transform.
        seed: Optional random seed.

    Returns:
        Dict of randomly chosen parameters describing the transformation.
    """
    img_row_axis = self.row_axis - 1
    img_col_axis = self.col_axis - 1

    if seed is not None:
        np.random.seed(seed)

    theta = (
        np.random.uniform(-self.rotation_range, self.rotation_range)
        if self.rotation_range
        else 0
    )

    if self.height_shift_range:
        try:  # 1-D array-like or int: pick a magnitude, then a sign
            tx = np.random.choice(self.height_shift_range)
            tx *= np.random.choice([-1, 1])
        except ValueError:  # scalar float: uniform in +/- range
            tx = np.random.uniform(
                -self.height_shift_range, self.height_shift_range
            )
        # Values < 1 are fractions of the image height.
        if np.max(self.height_shift_range) < 1:
            tx *= img_shape[img_row_axis]
    else:
        tx = 0

    if self.width_shift_range:
        try:  # 1-D array-like or int
            ty = np.random.choice(self.width_shift_range)
            ty *= np.random.choice([-1, 1])
        except ValueError:  # scalar float
            ty = np.random.uniform(
                -self.width_shift_range, self.width_shift_range
            )
        if np.max(self.width_shift_range) < 1:
            ty *= img_shape[img_col_axis]
    else:
        ty = 0

    shear = (
        np.random.uniform(-self.shear_range, self.shear_range)
        if self.shear_range
        else 0
    )

    if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:
        zx, zy = 1, 1
    else:
        zx, zy = np.random.uniform(
            self.zoom_range[0], self.zoom_range[1], 2
        )

    # The random draw always happens so RNG state advances identically.
    flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip
    flip_vertical = (np.random.random() < 0.5) * self.vertical_flip

    channel_shift_intensity = None
    if self.channel_shift_range != 0:
        channel_shift_intensity = np.random.uniform(
            -self.channel_shift_range, self.channel_shift_range
        )

    brightness = None
    if self.brightness_range is not None:
        brightness = np.random.uniform(
            self.brightness_range[0], self.brightness_range[1]
        )

    return {
        "theta": theta,
        "tx": tx,
        "ty": ty,
        "shear": shear,
        "zx": zx,
        "zy": zy,
        "flip_horizontal": flip_horizontal,
        "flip_vertical": flip_vertical,
        "channel_shift_intensity": channel_shift_intensity,
        "brightness": brightness,
    }
|
| 1384 |
+
|
| 1385 |
+
def apply_transform(self, x, transform_parameters):
    """Apply a transformation to an image according to given parameters.

    Args:
        x: 3D tensor, single image.
        transform_parameters: Dict of parameters describing the
            transformation. Recognized keys: `'theta'` (rotation angle,
            degrees), `'tx'`/`'ty'` (shifts), `'shear'` (degrees),
            `'zx'`/`'zy'` (zooms), `'flip_horizontal'`,
            `'flip_vertical'`, `'channel_shift_intensity'`,
            `'brightness'`.

    Returns:
        A transformed version of the input (same shape).
    """
    # x is a single image: axis indices drop the batch dimension.
    img_row_axis = self.row_axis - 1
    img_col_axis = self.col_axis - 1
    img_channel_axis = self.channel_axis - 1

    params = transform_parameters
    x = apply_affine_transform(
        x,
        params.get("theta", 0),
        params.get("tx", 0),
        params.get("ty", 0),
        params.get("shear", 0),
        params.get("zx", 1),
        params.get("zy", 1),
        row_axis=img_row_axis,
        col_axis=img_col_axis,
        channel_axis=img_channel_axis,
        fill_mode=self.fill_mode,
        cval=self.cval,
        order=self.interpolation_order,
    )

    if params.get("channel_shift_intensity") is not None:
        x = apply_channel_shift(
            x,
            params["channel_shift_intensity"],
            img_channel_axis,
        )

    if params.get("flip_horizontal", False):
        x = flip_axis(x, img_col_axis)

    if params.get("flip_vertical", False):
        x = flip_axis(x, img_row_axis)

    if params.get("brightness") is not None:
        x = apply_brightness_shift(x, params["brightness"], False)

    return x
|
| 1448 |
+
|
| 1449 |
+
def random_transform(self, x, seed=None):
    """Apply a random transformation to an image.

    Args:
        x: 3D tensor, single image.
        seed: Random seed.

    Returns:
        A randomly transformed version of the input (same shape).
    """
    # Draw the parameters, then apply them in one step.
    return self.apply_transform(x, self.get_random_transform(x.shape, seed))
|
| 1461 |
+
|
| 1462 |
+
def fit(self, x, augment=False, rounds=1, seed=None):
|
| 1463 |
+
"""Fits the data generator to some sample data.
|
| 1464 |
+
|
| 1465 |
+
This computes the internal data stats related to the
|
| 1466 |
+
data-dependent transformations, based on an array of sample data.
|
| 1467 |
+
|
| 1468 |
+
Only required if `featurewise_center` or
|
| 1469 |
+
`featurewise_std_normalization` or `zca_whitening`
|
| 1470 |
+
are set to `True`.
|
| 1471 |
+
|
| 1472 |
+
When `rescale` is set to a value, rescaling is applied to
|
| 1473 |
+
sample data before computing the internal data stats.
|
| 1474 |
+
|
| 1475 |
+
Args:
|
| 1476 |
+
x: Sample data. Should have rank 4.
|
| 1477 |
+
In case of grayscale data,
|
| 1478 |
+
the channels axis should have value 1, in case
|
| 1479 |
+
of RGB data, it should have value 3, and in case
|
| 1480 |
+
of RGBA data, it should have value 4.
|
| 1481 |
+
augment: Boolean (default: False).
|
| 1482 |
+
Whether to fit on randomly augmented samples.
|
| 1483 |
+
rounds: Int (default: 1).
|
| 1484 |
+
If using data augmentation (`augment=True`),
|
| 1485 |
+
this is how many augmentation passes over the data to use.
|
| 1486 |
+
seed: Int (default: None). Random seed.
|
| 1487 |
+
"""
|
| 1488 |
+
x = np.asarray(x, dtype=self.dtype)
|
| 1489 |
+
if x.ndim != 4:
|
| 1490 |
+
raise ValueError(
|
| 1491 |
+
"Input to `.fit()` should have rank 4. Got array with shape: "
|
| 1492 |
+
+ str(x.shape)
|
| 1493 |
+
)
|
| 1494 |
+
if x.shape[self.channel_axis] not in {1, 3, 4}:
|
| 1495 |
+
warnings.warn(
|
| 1496 |
+
"Expected input to be images (as Numpy array) "
|
| 1497 |
+
'following the data format convention "'
|
| 1498 |
+
+ self.data_format
|
| 1499 |
+
+ '" (channels on axis '
|
| 1500 |
+
+ str(self.channel_axis)
|
| 1501 |
+
+ "), i.e. expected either 1, 3 or 4 channels on axis "
|
| 1502 |
+
+ str(self.channel_axis)
|
| 1503 |
+
+ ". However, it was passed an array with shape "
|
| 1504 |
+
+ str(x.shape)
|
| 1505 |
+
+ " ("
|
| 1506 |
+
+ str(x.shape[self.channel_axis])
|
| 1507 |
+
+ " channels)."
|
| 1508 |
+
)
|
| 1509 |
+
|
| 1510 |
+
if seed is not None:
|
| 1511 |
+
np.random.seed(seed)
|
| 1512 |
+
|
| 1513 |
+
x = np.copy(x)
|
| 1514 |
+
if self.rescale:
|
| 1515 |
+
x *= self.rescale
|
| 1516 |
+
|
| 1517 |
+
if augment:
|
| 1518 |
+
ax = np.zeros(
|
| 1519 |
+
tuple([rounds * x.shape[0]] + list(x.shape)[1:]),
|
| 1520 |
+
dtype=self.dtype,
|
| 1521 |
+
)
|
| 1522 |
+
for r in range(rounds):
|
| 1523 |
+
for i in range(x.shape[0]):
|
| 1524 |
+
ax[i + r * x.shape[0]] = self.random_transform(x[i])
|
| 1525 |
+
x = ax
|
| 1526 |
+
|
| 1527 |
+
if self.featurewise_center:
|
| 1528 |
+
self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis))
|
| 1529 |
+
broadcast_shape = [1, 1, 1]
|
| 1530 |
+
broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
|
| 1531 |
+
self.mean = np.reshape(self.mean, broadcast_shape)
|
| 1532 |
+
x -= self.mean
|
| 1533 |
+
|
| 1534 |
+
if self.featurewise_std_normalization:
|
| 1535 |
+
self.std = np.std(x, axis=(0, self.row_axis, self.col_axis))
|
| 1536 |
+
broadcast_shape = [1, 1, 1]
|
| 1537 |
+
broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
|
| 1538 |
+
self.std = np.reshape(self.std, broadcast_shape)
|
| 1539 |
+
x /= self.std + 1e-6
|
| 1540 |
+
|
| 1541 |
+
if self.zca_whitening:
|
| 1542 |
+
n = len(x)
|
| 1543 |
+
flat_x = np.reshape(x, (n, -1))
|
| 1544 |
+
|
| 1545 |
+
u, s, _ = np.linalg.svd(flat_x.T, full_matrices=False)
|
| 1546 |
+
s_inv = np.sqrt(n) / (s + self.zca_epsilon)
|
| 1547 |
+
self.zca_whitening_matrix = (u * s_inv).dot(u.T)
|
| 1548 |
+
|
| 1549 |
+
|
| 1550 |
+
@keras_export("keras._legacy.preprocessing.image.random_rotation")
def random_rotation(
    x,
    rg,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """DEPRECATED. Rotate `x` by a random angle drawn from [-rg, rg] degrees."""
    angle = np.random.uniform(-rg, rg)
    return apply_affine_transform(
        x,
        theta=angle,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
|
| 1574 |
+
|
| 1575 |
+
|
| 1576 |
+
@keras_export("keras._legacy.preprocessing.image.random_shift")
def random_shift(
    x,
    wrg,
    hrg,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """DEPRECATED. Shift `x` by random fractions of its height and width."""
    height, width = x.shape[row_axis], x.shape[col_axis]
    shift_x = np.random.uniform(-hrg, hrg) * height
    shift_y = np.random.uniform(-wrg, wrg) * width
    return apply_affine_transform(
        x,
        tx=shift_x,
        ty=shift_y,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
|
| 1604 |
+
|
| 1605 |
+
|
| 1606 |
+
@keras_export("keras._legacy.preprocessing.image.random_shear")
def random_shear(
    x,
    intensity,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """DEPRECATED. Shear `x` by a random angle drawn from +/- `intensity`."""
    angle = np.random.uniform(-intensity, intensity)
    return apply_affine_transform(
        x,
        shear=angle,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
|
| 1630 |
+
|
| 1631 |
+
|
| 1632 |
+
@keras_export("keras._legacy.preprocessing.image.random_zoom")
def random_zoom(
    x,
    zoom_range,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """DEPRECATED. Apply a random zoom drawn uniformly from `zoom_range`."""
    if len(zoom_range) != 2:
        raise ValueError(
            "`zoom_range` should be a tuple or list of two floats. "
            f"Received: {zoom_range}"
        )

    # (1, 1) means "no zoom" — skip the random draw entirely.
    if zoom_range[0] == 1 and zoom_range[1] == 1:
        zx = zy = 1
    else:
        zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
    return apply_affine_transform(
        x,
        zx=zx,
        zy=zy,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
|
| 1666 |
+
|
| 1667 |
+
|
| 1668 |
+
@keras_export("keras._legacy.preprocessing.image.apply_channel_shift")
|
| 1669 |
+
def apply_channel_shift(x, intensity, channel_axis=0):
    """Performs a channel shift.

    DEPRECATED.

    Args:
        x: Input tensor. Must be 3D.
        intensity: Transformation intensity.
        channel_axis: Index of axis for channels in the input tensor.

    Returns:
        Numpy image tensor.
    """
    # Bring the channel axis to the front so we can iterate over channels.
    rolled = np.rollaxis(x, channel_axis, 0)
    low, high = np.min(rolled), np.max(rolled)
    # Shift every channel by `intensity`, clipping to the overall value
    # range of the original input.
    rolled = np.stack(
        [np.clip(channel + intensity, low, high) for channel in rolled],
        axis=0,
    )
    # Restore the original axis ordering.
    return np.rollaxis(rolled, 0, channel_axis + 1)
|
| 1690 |
+
|
| 1691 |
+
|
| 1692 |
+
@keras_export("keras._legacy.preprocessing.image.random_channel_shift")
|
| 1693 |
+
def random_channel_shift(x, intensity_range, channel_axis=0):
    """Performs a random channel shift.

    DEPRECATED.

    Args:
        x: Input tensor. Must be 3D.
        intensity_range: Transformation intensity bound; the shift is
            drawn uniformly from `[-intensity_range, intensity_range)`.
        channel_axis: Index of axis for channels in the input tensor.

    Returns:
        Numpy image tensor.
    """
    # Draw the shift and delegate to the deterministic helper.
    drawn_intensity = np.random.uniform(-intensity_range, intensity_range)
    return apply_channel_shift(x, drawn_intensity, channel_axis=channel_axis)
|
| 1708 |
+
|
| 1709 |
+
|
| 1710 |
+
@keras_export("keras._legacy.preprocessing.image.apply_brightness_shift")
|
| 1711 |
+
def apply_brightness_shift(x, brightness, scale=True):
    """Performs a brightness shift.

    DEPRECATED.

    Args:
        x: Input tensor. Must be 3D.
        brightness: Float. The new brightness value.
        scale: Whether to rescale the image such that minimum and maximum
            values are 0 and 255 respectively. Default: True.

    Returns:
        Numpy image tensor.

    Raises:
        ImportError: if PIL is not available.
    """
    from PIL import ImageEnhance

    x_min, x_max = np.min(x), np.max(x)
    # array_to_img requires values in [0, 255]; force scaling whenever the
    # input lies outside that range, even if the caller asked scale=False.
    local_scale = (x_min < 0) or (x_max > 255)
    x = image_utils.array_to_img(x, scale=local_scale or scale)
    # Fixed: the original chained assignment
    # `x = imgenhancer_Brightness = ImageEnhance.Brightness(x)` bound `x`
    # to the enhancer object only to immediately overwrite it; a single
    # clearly-named local expresses the intent.
    enhancer = ImageEnhance.Brightness(x)
    x = enhancer.enhance(brightness)
    x = image_utils.img_to_array(x)
    if not scale and local_scale:
        # Undo the forced 0-255 normalization so the output is returned
        # in the input's original value range.
        x = x / 255 * (x_max - x_min) + x_min
    return x
|
| 1739 |
+
|
| 1740 |
+
|
| 1741 |
+
@keras_export("keras._legacy.preprocessing.image.random_brightness")
|
| 1742 |
+
def random_brightness(x, brightness_range, scale=True):
    """Performs a random brightness shift.

    DEPRECATED.

    Args:
        x: Input tensor. Must be 3D.
        brightness_range: Tuple of floats; brightness range.
        scale: Whether to rescale the image such that minimum and maximum
            values are 0 and 255 respectively. Default: True.

    Returns:
        Numpy image tensor.

    Raises:
        ValueError: if `brightness_range` isn't a tuple or list of two
            floats.
    """
    if len(brightness_range) != 2:
        # Fixed: the original message was missing the closing backtick
        # and article ("`brightness_range should be tuple...").
        raise ValueError(
            "`brightness_range` should be a tuple or list of two floats. "
            f"Received: {brightness_range}"
        )

    # Sample a brightness factor uniformly from the requested range.
    u = np.random.uniform(brightness_range[0], brightness_range[1])
    return apply_brightness_shift(x, u, scale)
|
| 1767 |
+
|
| 1768 |
+
|
| 1769 |
+
def transform_matrix_offset_center(matrix, x, y):
    """Re-centers a 3x3 homogeneous transform about the image center.

    Conjugates `matrix` with a translation to the pixel-grid center
    `((x - 1) / 2, (y - 1) / 2)`, i.e. returns
    `T(center) @ matrix @ T(-center)`, so the transform is applied about
    the center rather than the origin.
    """
    center_x = float(x) / 2 - 0.5
    center_y = float(y) / 2 - 0.5
    to_center = np.array(
        [[1, 0, center_x], [0, 1, center_y], [0, 0, 1]]
    )
    from_center = np.array(
        [[1, 0, -center_x], [0, 1, -center_y], [0, 0, 1]]
    )
    return np.dot(np.dot(to_center, matrix), from_center)
|
| 1776 |
+
|
| 1777 |
+
|
| 1778 |
+
@keras_export("keras._legacy.preprocessing.image.apply_affine_transform")
|
| 1779 |
+
def apply_affine_transform(
    x,
    theta=0,
    tx=0,
    ty=0,
    shear=0,
    zx=1,
    zy=1,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    order=1,
):
    """Applies an affine transformation specified by the parameters given.

    DEPRECATED.

    Args:
        x: 3D numpy array - a 2D image with one or more channels.
        theta: Rotation angle in degrees.
        tx: Shift along the row axis.
        ty: Shift along the column axis.
        shear: Shear angle in degrees.
        zx: Zoom along the row axis.
        zy: Zoom along the column axis.
        row_axis: Index of axis for rows in the input image.
        col_axis: Index of axis for columns in the input image.
        channel_axis: Index of axis for channels in the input image.
        fill_mode: Points outside the boundaries of the input are filled
            according to the given mode
            (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
        cval: Value used for points outside the boundaries when
            `fill_mode='constant'`.
        order: Int, order of spline interpolation.

    Returns:
        The transformed version of the input (unchanged when all
        parameters describe the identity).

    Raises:
        ValueError: if the input is not 3D or the axis arguments are
            invalid.
    """
    # Input sanity checks:
    # 1. x must 2D image with one or more channels (i.e., a 3D tensor)
    # 2. channels must be either first or last dimension
    if np.unique([row_axis, col_axis, channel_axis]).size != 3:
        raise ValueError(
            "'row_axis', 'col_axis', and 'channel_axis' must be distinct"
        )

    # shall we support negative indices?
    valid_indices = set([0, 1, 2])
    actual_indices = set([row_axis, col_axis, channel_axis])
    if actual_indices != valid_indices:
        raise ValueError(
            f"Invalid axis' indices: {actual_indices - valid_indices}"
        )

    if x.ndim != 3:
        raise ValueError("Input arrays must be multi-channel 2D images.")
    if channel_axis not in [0, 2]:
        raise ValueError(
            "Channels are allowed and the first and last dimensions."
        )

    # Compose the requested elementary transforms (rotation, shift, shear,
    # zoom) into a single 3x3 homogeneous matrix; `None` means identity
    # and lets us skip interpolation entirely.
    transform_matrix = None
    if theta != 0:
        theta = np.deg2rad(theta)
        rotation_matrix = np.array(
            [
                [np.cos(theta), -np.sin(theta), 0],
                [np.sin(theta), np.cos(theta), 0],
                [0, 0, 1],
            ]
        )
        transform_matrix = rotation_matrix

    if tx != 0 or ty != 0:
        shift_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = shift_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shift_matrix)

    if shear != 0:
        shear = np.deg2rad(shear)
        shear_matrix = np.array(
            [[1, -np.sin(shear), 0], [0, np.cos(shear), 0], [0, 0, 1]]
        )
        if transform_matrix is None:
            transform_matrix = shear_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shear_matrix)

    if zx != 1 or zy != 1:
        zoom_matrix = np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = zoom_matrix
        else:
            transform_matrix = np.dot(transform_matrix, zoom_matrix)

    if transform_matrix is not None:
        h, w = x.shape[row_axis], x.shape[col_axis]
        transform_matrix = transform_matrix_offset_center(
            transform_matrix, h, w
        )
        x = np.rollaxis(x, channel_axis, 0)

        # Matrix construction assumes that coordinates are x, y (in that
        # order). However, regular numpy arrays use y,x (aka i,j) indexing.
        # Possible solution is:
        # 1. Swap the x and y axes.
        # 2. Apply transform.
        # 3. Swap the x and y axes again to restore image-like data ordering.
        # Mathematically, it is equivalent to the following transformation:
        # M' = PMP, where P is the permutation matrix, M is the original
        # transformation matrix.
        if col_axis > row_axis:
            transform_matrix[:, [0, 1]] = transform_matrix[:, [1, 0]]
            transform_matrix[[0, 1]] = transform_matrix[[1, 0]]
        final_affine_matrix = transform_matrix[:2, :2]
        final_offset = transform_matrix[:2, 2]

        channel_images = [
            # Fixed: use the public `scipy.ndimage.affine_transform`; the
            # legacy `scipy.ndimage.interpolation` namespace has been
            # deprecated and removed in recent SciPy releases.
            scipy.ndimage.affine_transform(
                x_channel,
                final_affine_matrix,
                final_offset,
                order=order,
                mode=fill_mode,
                cval=cval,
            )
            for x_channel in x
        ]
        x = np.stack(channel_images, axis=0)
        x = np.rollaxis(x, 0, channel_axis + 1)
    return x
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/sequence.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deprecated sequence preprocessing APIs from Keras 1."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from keras.src.api_export import keras_export
|
| 9 |
+
from keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@keras_export("keras._legacy.preprocessing.sequence.TimeseriesGenerator")
|
| 13 |
+
class TimeseriesGenerator(PyDataset):
    """Utility class for generating batches of temporal data.

    DEPRECATED.

    This class takes in a sequence of data-points gathered at
    equal intervals, along with time series parameters such as
    stride, length of history, etc., to produce batches for
    training/validation.

    Arguments:
        data: Indexable generator (such as list or Numpy array)
            containing consecutive data points (timesteps).
            The data should be at 2D, and axis 0 is expected
            to be the time dimension.
        targets: Targets corresponding to timesteps in `data`.
            It should have same length as `data`.
        length: Length of the output sequences (in number of timesteps).
        sampling_rate: Period between successive individual timesteps
            within sequences. For rate `r`, timesteps
            `data[i]`, `data[i-r]`, ... `data[i - length]`
            are used for create a sample sequence.
        stride: Period between successive output sequences.
            For stride `s`, consecutive output samples would
            be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc.
        start_index: Data points earlier than `start_index` will not be used
            in the output sequences. This is useful to reserve part of the
            data for test or validation.
        end_index: Data points later than `end_index` will not be used
            in the output sequences. This is useful to reserve part of the
            data for test or validation.
        shuffle: Whether to shuffle output samples,
            or instead draw them in chronological order.
        reverse: Boolean: if `true`, timesteps in each output sample will be
            in reverse chronological order.
        batch_size: Number of timeseries samples in each batch
            (except maybe the last one).

    Returns:
        A PyDataset instance.
    """

    def __init__(
        self,
        data,
        targets,
        length,
        sampling_rate=1,
        stride=1,
        start_index=0,
        end_index=None,
        shuffle=False,
        reverse=False,
        batch_size=128,
    ):
        if len(data) != len(targets):
            raise ValueError(
                "Data and targets have to be "
                f"of same length. Data length is {len(data)} "
                f"while target length is {len(targets)}"
            )

        self.data = data
        self.targets = targets
        self.length = length
        self.sampling_rate = sampling_rate
        self.stride = stride
        # The first usable "current step" is `length` timesteps in, since
        # each sample needs `length` timesteps of history before it.
        self.start_index = start_index + length
        if end_index is None:
            end_index = len(data) - 1
        self.end_index = end_index
        self.shuffle = shuffle
        self.reverse = reverse
        self.batch_size = batch_size

        if self.start_index > self.end_index:
            raise ValueError(
                f"`start_index+length={self.start_index} "
                f"> end_index={self.end_index}` "
                "is disallowed, as no part of the sequence "
                "would be left to be used as current step."
            )

    def __len__(self):
        # Number of batches; the numerator is padded by one full batch
        # span so a trailing partial batch still counts.
        return (
            self.end_index - self.start_index + self.batch_size * self.stride
        ) // (self.batch_size * self.stride)

    def __getitem__(self, index):
        # Pick the "current step" row of each sample in this batch:
        # random rows when shuffling, else a strided chronological range.
        if self.shuffle:
            rows = np.random.randint(
                self.start_index, self.end_index + 1, size=self.batch_size
            )
        else:
            i = self.start_index + self.batch_size * self.stride * index
            rows = np.arange(
                i,
                min(i + self.batch_size * self.stride, self.end_index + 1),
                self.stride,
            )

        # Each sample is the `length`-step history preceding its row,
        # subsampled every `sampling_rate` steps.
        samples = np.array(
            [
                self.data[row - self.length : row : self.sampling_rate]
                for row in rows
            ]
        )
        targets = np.array([self.targets[row] for row in rows])

        if self.reverse:
            # Flip the time axis of every sample.
            return samples[:, ::-1, ...], targets
        return samples, targets

    def get_config(self):
        """Returns the TimeseriesGenerator configuration as Python dictionary.

        Returns:
            A Python dictionary with the TimeseriesGenerator configuration.
        """
        data = self.data
        # Numpy arrays are not JSON serializable; convert to nested lists.
        if type(self.data).__module__ == np.__name__:
            data = self.data.tolist()
        try:
            json_data = json.dumps(data)
        except TypeError as e:
            raise TypeError(f"Data not JSON Serializable: {data}") from e

        targets = self.targets
        if type(self.targets).__module__ == np.__name__:
            targets = self.targets.tolist()
        try:
            json_targets = json.dumps(targets)
        except TypeError as e:
            raise TypeError(f"Targets not JSON Serializable: {targets}") from e

        return {
            "data": json_data,
            "targets": json_targets,
            "length": self.length,
            "sampling_rate": self.sampling_rate,
            "stride": self.stride,
            "start_index": self.start_index,
            "end_index": self.end_index,
            "shuffle": self.shuffle,
            "reverse": self.reverse,
            "batch_size": self.batch_size,
        }

    def to_json(self, **kwargs):
        """Returns a JSON string containing the generator's configuration.

        Args:
            **kwargs: Additional keyword arguments to be passed
                to `json.dumps()`.

        Returns:
            A JSON string containing the tokenizer configuration.
        """
        config = self.get_config()
        timeseries_generator_config = {
            "class_name": self.__class__.__name__,
            "config": config,
        }
        return json.dumps(timeseries_generator_config, **kwargs)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@keras_export("keras._legacy.preprocessing.sequence.make_sampling_table")
|
| 180 |
+
def make_sampling_table(size, sampling_factor=1e-5):
    """Generates a word rank-based probabilistic sampling table.

    DEPRECATED.

    Used for generating the `sampling_table` argument for `skipgrams`.
    `sampling_table[i]` is the probability of sampling the i-th most
    common word in a dataset (more common words should be sampled less
    frequently, for balance), following the word2vec sampling
    distribution with word frequencies approximated by Zipf's law:

    `frequency(rank) ~ 1/(rank * (log(rank) + gamma) + 1/2 - 1/(12*rank))`

    where `gamma` is the Euler-Mascheroni constant.

    Args:
        size: Int, number of possible words to sample.
        sampling_factor: The sampling factor in the word2vec formula.

    Returns:
        A 1D Numpy array of length `size` where the ith entry
        is the probability that a word of rank i should be sampled.
    """
    euler_gamma = 0.577
    ranks = np.arange(size)
    # Rank 0 is the reserved non-word index; treat it as rank 1 so the
    # logarithm below stays finite.
    ranks[0] = 1
    # Zipf-law approximation of inverse word frequency by rank.
    inv_frequency = (
        ranks * (np.log(ranks) + euler_gamma) + 0.5 - 1.0 / (12.0 * ranks)
    )
    scaled = sampling_factor * inv_frequency

    # p = min(1, sqrt(f/factor) / (f/factor)) == min(1, scaled/sqrt(scaled))
    return np.minimum(1.0, scaled / np.sqrt(scaled))
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@keras_export("keras._legacy.preprocessing.sequence.skipgrams")
|
| 222 |
+
def skipgrams(
    sequence,
    vocabulary_size,
    window_size=4,
    negative_samples=1.0,
    shuffle=True,
    categorical=False,
    sampling_table=None,
    seed=None,
):
    """Generates skipgram word pairs.

    DEPRECATED.

    This function transforms a sequence of word indexes (list of integers)
    into tuples of words of the form:

    - (word, word in the same window), with label 1 (positive samples).
    - (word, random word from the vocabulary), with label 0 (negative samples).

    Read more about Skipgram in this gnomic paper by Mikolov et al.:
    [Efficient Estimation of Word Representations in
    Vector Space](http://arxiv.org/pdf/1301.3781v3.pdf)

    Args:
        sequence: A word sequence (sentence), encoded as a list
            of word indices (integers). If using a `sampling_table`,
            word indices are expected to match the rank
            of the words in a reference dataset (e.g. 10 would encode
            the 10-th most frequently occurring token).
            Note that index 0 is expected to be a non-word and will be skipped.
        vocabulary_size: Int, maximum possible word index + 1
        window_size: Int, size of sampling windows (technically half-window).
            The window of a word `w_i` will be
            `[i - window_size, i + window_size+1]`.
        negative_samples: Float >= 0. 0 for no negative (i.e. random) samples.
            1 for same number as positive samples.
        shuffle: Whether to shuffle the word couples before returning them.
        categorical: bool. if False, labels will be
            integers (eg. `[0, 1, 1 .. ]`),
            if `True`, labels will be categorical, e.g.
            `[[1,0],[0,1],[0,1] .. ]`.
        sampling_table: 1D array of size `vocabulary_size` where the entry i
            encodes the probability to sample a word of rank i.
        seed: Random seed.

    Returns:
        couples, labels: where `couples` are int pairs and
        `labels` are either 0 or 1.

    Note:
        By convention, index 0 in the vocabulary is
        a non-word and will be skipped.
    """
    couples = []
    labels = []
    for i, wi in enumerate(sequence):
        # Index 0 is a reserved non-word; skip it.
        if not wi:
            continue
        if sampling_table is not None:
            # Probabilistically subsample frequent words.
            if sampling_table[wi] < random.random():
                continue

        window_start = max(0, i - window_size)
        window_end = min(len(sequence), i + window_size + 1)
        for j in range(window_start, window_end):
            if j != i:
                wj = sequence[j]
                if not wj:
                    continue
                couples.append([wi, wj])
                if categorical:
                    labels.append([0, 1])
                else:
                    labels.append(1)

    if negative_samples > 0:
        num_negative_samples = int(len(labels) * negative_samples)
        words = [c[0] for c in couples]
        random.shuffle(words)

        # Pair each (shuffled) focus word with a random vocabulary index.
        couples += [
            [words[i % len(words)], random.randint(1, vocabulary_size - 1)]
            for i in range(num_negative_samples)
        ]
        if categorical:
            labels += [[1, 0]] * num_negative_samples
        else:
            labels += [0] * num_negative_samples

    if shuffle:
        if seed is None:
            # Fixed: the historical bound `10e6` is a float, which
            # `random.randint` deprecates on Python 3.10 and rejects on
            # newer versions; use the equivalent int.
            seed = random.randint(0, 10_000_000)
        # Seed twice with the same value so couples and labels are
        # shuffled into the same permutation.
        random.seed(seed)
        random.shuffle(couples)
        random.seed(seed)
        random.shuffle(labels)

    return couples, labels
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/text.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deprecated text preprocessing APIs from Keras 1."""
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
import hashlib
|
| 5 |
+
import json
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from keras.src.api_export import keras_export
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@keras_export("keras._legacy.preprocessing.text.text_to_word_sequence")
|
| 14 |
+
def text_to_word_sequence(
    input_text,
    filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
    lower=True,
    split=" ",
):
    """Converts a text to a sequence of words (tokens).

    DEPRECATED.

    Args:
        input_text: Input text (string).
        filters: String of characters to filter out, each replaced by
            `split` before tokenizing.
        lower: Boolean, whether to lowercase the input first.
        split: String, separator for word splitting.

    Returns:
        List of non-empty tokens.
    """
    if lower:
        input_text = input_text.lower()

    # Replace every filtered character with the split token in a single
    # C-level pass, then tokenize on the split token.
    translation = str.maketrans({ch: split for ch in filters})
    tokens = input_text.translate(translation).split(split)
    return [token for token in tokens if token]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@keras_export("keras._legacy.preprocessing.text.one_hot")
|
| 33 |
+
def one_hot(
    input_text,
    n,
    filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
    lower=True,
    split=" ",
    analyzer=None,
):
    """Hashes a text into a list of word indexes in `[1, n)`.

    DEPRECATED.

    Thin wrapper around `hashing_trick` using the builtin `hash` as the
    hashing function; uniqueness of the word-to-index mapping is
    therefore not guaranteed.
    """
    # Delegate everything to the generic hashing helper.
    return hashing_trick(
        input_text,
        n,
        hash_function=hash,
        filters=filters,
        lower=lower,
        split=split,
        analyzer=analyzer,
    )
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@keras_export("keras._legacy.preprocessing.text.hashing_trick")
|
| 54 |
+
def hashing_trick(
    text,
    n,
    hash_function=None,
    filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
    lower=True,
    split=" ",
    analyzer=None,
):
    """Converts a text to a sequence of indexes in a fixed-size hashing space.

    DEPRECATED.

    Args:
        text: Input text (string).
        n: Dimension of the hashing space.
        hash_function: `None` for the builtin `hash`, the string `"md5"`
            for a stable md5-based hash, or any callable mapping a string
            to an int.
        filters: Characters to filter out before tokenizing.
        lower: Boolean, whether to lowercase the input first.
        split: String, separator for word splitting.
        analyzer: Optional callable used to tokenize instead of
            `text_to_word_sequence`.

    Returns:
        List of integer word indexes in `[1, n)` (index 0 is reserved).
    """
    if hash_function is None:
        hash_function = hash
    elif hash_function == "md5":

        def hash_function(w):
            # md5 is stable across processes, unlike the builtin `hash`.
            return int(hashlib.md5(w.encode()).hexdigest(), 16)

    if analyzer is None:
        words = text_to_word_sequence(
            text, filters=filters, lower=lower, split=split
        )
    else:
        words = analyzer(text)

    # Map each hash into [1, n - 1]; index 0 stays reserved.
    return [hash_function(w) % (n - 1) + 1 for w in words]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@keras_export("keras._legacy.preprocessing.text.Tokenizer")
|
| 82 |
+
class Tokenizer:
|
| 83 |
+
"""DEPRECATED."""
|
| 84 |
+
|
| 85 |
+
    def __init__(
        self,
        num_words=None,
        filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
        lower=True,
        split=" ",
        char_level=False,
        oov_token=None,
        analyzer=None,
        **kwargs,
    ):
        """Initializes the tokenizer's vocabulary bookkeeping.

        Args:
            num_words: Maximum number of words to keep, based on word
                frequency; only the most common `num_words - 1` words
                are kept.
            filters: Characters filtered out of texts before tokenizing.
            lower: Boolean, whether to lowercase the texts.
            split: String, separator for word splitting.
            char_level: If True, every character is treated as a token.
            oov_token: If given, added to `word_index` and used to
                replace out-of-vocabulary words.
            analyzer: Optional callable used to tokenize instead of
                `text_to_word_sequence`.
            **kwargs: Legacy arguments `nb_words` (renamed `num_words`)
                and `document_count`; anything else raises TypeError.
        """
        # Legacy support
        if "nb_words" in kwargs:
            warnings.warn(
                "The `nb_words` argument in `Tokenizer` "
                "has been renamed `num_words`."
            )
            num_words = kwargs.pop("nb_words")
        document_count = kwargs.pop("document_count", 0)
        if kwargs:
            raise TypeError("Unrecognized keyword arguments: " + str(kwargs))

        # word -> total number of occurrences across fitted texts.
        self.word_counts = collections.OrderedDict()
        # word -> number of documents the word appears in.
        self.word_docs = collections.defaultdict(int)
        self.filters = filters
        self.split = split
        self.lower = lower
        self.num_words = num_words
        self.document_count = document_count
        self.char_level = char_level
        self.oov_token = oov_token
        # word index -> number of documents containing that word.
        self.index_docs = collections.defaultdict(int)
        self.word_index = {}
        self.index_word = {}
        self.analyzer = analyzer
|
| 120 |
+
|
| 121 |
+
    def fit_on_texts(self, texts):
        """Updates internal vocabulary based on a list of texts.

        In the case where `texts` contains lists, each entry of those
        lists is taken to be a token. Required before using
        `texts_to_sequences` or `texts_to_matrix`.

        Args:
            texts: A list of strings, a generator of strings, or a list
                of lists of strings (pre-tokenized texts).
        """
        for text in texts:
            self.document_count += 1
            if self.char_level or isinstance(text, list):
                # Character-level mode, or pre-tokenized input: iterate
                # the text itself; optionally lowercase first.
                if self.lower:
                    if isinstance(text, list):
                        text = [text_elem.lower() for text_elem in text]
                    else:
                        text = text.lower()
                seq = text
            else:
                if self.analyzer is None:
                    seq = text_to_word_sequence(
                        text,
                        filters=self.filters,
                        lower=self.lower,
                        split=self.split,
                    )
                else:
                    seq = self.analyzer(text)
            for w in seq:
                if w in self.word_counts:
                    self.word_counts[w] += 1
                else:
                    self.word_counts[w] = 1
            for w in set(seq):
                # In how many documents each word occurs
                self.word_docs[w] += 1

        # Rebuild the index: most frequent word gets the lowest index.
        wcounts = list(self.word_counts.items())
        wcounts.sort(key=lambda x: x[1], reverse=True)
        # forcing the oov_token to index 1 if it exists
        if self.oov_token is None:
            sorted_voc = []
        else:
            sorted_voc = [self.oov_token]
        sorted_voc.extend(wc[0] for wc in wcounts)

        # note that index 0 is reserved, never assigned to an existing word
        self.word_index = dict(
            zip(sorted_voc, list(range(1, len(sorted_voc) + 1)))
        )

        self.index_word = {c: w for w, c in self.word_index.items()}

        for w, c in list(self.word_docs.items()):
            self.index_docs[self.word_index[w]] = c
|
| 168 |
+
|
| 169 |
+
def fit_on_sequences(self, sequences):
|
| 170 |
+
self.document_count += len(sequences)
|
| 171 |
+
for seq in sequences:
|
| 172 |
+
seq = set(seq)
|
| 173 |
+
for i in seq:
|
| 174 |
+
self.index_docs[i] += 1
|
| 175 |
+
|
| 176 |
+
    def texts_to_sequences(self, texts):
        """Transforms each text in `texts` to a sequence of integers.

        Only words known by the tokenizer (and, when `num_words` is set,
        ranked below it) are taken into account.

        Args:
            texts: A list of texts (strings).

        Returns:
            A list of sequences (lists of integer word indices).
        """
        return list(self.texts_to_sequences_generator(texts))
|
| 178 |
+
|
| 179 |
+
def texts_to_sequences_generator(self, texts):
|
| 180 |
+
num_words = self.num_words
|
| 181 |
+
oov_token_index = self.word_index.get(self.oov_token)
|
| 182 |
+
for text in texts:
|
| 183 |
+
if self.char_level or isinstance(text, list):
|
| 184 |
+
if self.lower:
|
| 185 |
+
if isinstance(text, list):
|
| 186 |
+
text = [text_elem.lower() for text_elem in text]
|
| 187 |
+
else:
|
| 188 |
+
text = text.lower()
|
| 189 |
+
seq = text
|
| 190 |
+
else:
|
| 191 |
+
if self.analyzer is None:
|
| 192 |
+
seq = text_to_word_sequence(
|
| 193 |
+
text,
|
| 194 |
+
filters=self.filters,
|
| 195 |
+
lower=self.lower,
|
| 196 |
+
split=self.split,
|
| 197 |
+
)
|
| 198 |
+
else:
|
| 199 |
+
seq = self.analyzer(text)
|
| 200 |
+
vect = []
|
| 201 |
+
for w in seq:
|
| 202 |
+
i = self.word_index.get(w)
|
| 203 |
+
if i is not None:
|
| 204 |
+
if num_words and i >= num_words:
|
| 205 |
+
if oov_token_index is not None:
|
| 206 |
+
vect.append(oov_token_index)
|
| 207 |
+
else:
|
| 208 |
+
vect.append(i)
|
| 209 |
+
elif self.oov_token is not None:
|
| 210 |
+
vect.append(oov_token_index)
|
| 211 |
+
yield vect
|
| 212 |
+
|
| 213 |
+
def sequences_to_texts(self, sequences):
|
| 214 |
+
return list(self.sequences_to_texts_generator(sequences))
|
| 215 |
+
|
| 216 |
+
def sequences_to_texts_generator(self, sequences):
|
| 217 |
+
num_words = self.num_words
|
| 218 |
+
oov_token_index = self.word_index.get(self.oov_token)
|
| 219 |
+
for seq in sequences:
|
| 220 |
+
vect = []
|
| 221 |
+
for num in seq:
|
| 222 |
+
word = self.index_word.get(num)
|
| 223 |
+
if word is not None:
|
| 224 |
+
if num_words and num >= num_words:
|
| 225 |
+
if oov_token_index is not None:
|
| 226 |
+
vect.append(self.index_word[oov_token_index])
|
| 227 |
+
else:
|
| 228 |
+
vect.append(word)
|
| 229 |
+
elif self.oov_token is not None:
|
| 230 |
+
vect.append(self.index_word[oov_token_index])
|
| 231 |
+
vect = " ".join(vect)
|
| 232 |
+
yield vect
|
| 233 |
+
|
| 234 |
+
def texts_to_matrix(self, texts, mode="binary"):
|
| 235 |
+
sequences = self.texts_to_sequences(texts)
|
| 236 |
+
return self.sequences_to_matrix(sequences, mode=mode)
|
| 237 |
+
|
| 238 |
+
def sequences_to_matrix(self, sequences, mode="binary"):
|
| 239 |
+
if not self.num_words:
|
| 240 |
+
if self.word_index:
|
| 241 |
+
num_words = len(self.word_index) + 1
|
| 242 |
+
else:
|
| 243 |
+
raise ValueError(
|
| 244 |
+
"Specify a dimension (`num_words` argument), "
|
| 245 |
+
"or fit on some text data first."
|
| 246 |
+
)
|
| 247 |
+
else:
|
| 248 |
+
num_words = self.num_words
|
| 249 |
+
|
| 250 |
+
if mode == "tfidf" and not self.document_count:
|
| 251 |
+
raise ValueError(
|
| 252 |
+
"Fit the Tokenizer on some data before using tfidf mode."
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
x = np.zeros((len(sequences), num_words))
|
| 256 |
+
for i, seq in enumerate(sequences):
|
| 257 |
+
if not seq:
|
| 258 |
+
continue
|
| 259 |
+
counts = collections.defaultdict(int)
|
| 260 |
+
for j in seq:
|
| 261 |
+
if j >= num_words:
|
| 262 |
+
continue
|
| 263 |
+
counts[j] += 1
|
| 264 |
+
for j, c in list(counts.items()):
|
| 265 |
+
if mode == "count":
|
| 266 |
+
x[i][j] = c
|
| 267 |
+
elif mode == "freq":
|
| 268 |
+
x[i][j] = c / len(seq)
|
| 269 |
+
elif mode == "binary":
|
| 270 |
+
x[i][j] = 1
|
| 271 |
+
elif mode == "tfidf":
|
| 272 |
+
# Use weighting scheme 2 in
|
| 273 |
+
# https://en.wikipedia.org/wiki/Tf%E2%80%93idf
|
| 274 |
+
tf = 1 + np.log(c)
|
| 275 |
+
idf = np.log(
|
| 276 |
+
1
|
| 277 |
+
+ self.document_count / (1 + self.index_docs.get(j, 0))
|
| 278 |
+
)
|
| 279 |
+
x[i][j] = tf * idf
|
| 280 |
+
else:
|
| 281 |
+
raise ValueError("Unknown vectorization mode:", mode)
|
| 282 |
+
return x
|
| 283 |
+
|
| 284 |
+
def get_config(self):
|
| 285 |
+
json_word_counts = json.dumps(self.word_counts)
|
| 286 |
+
json_word_docs = json.dumps(self.word_docs)
|
| 287 |
+
json_index_docs = json.dumps(self.index_docs)
|
| 288 |
+
json_word_index = json.dumps(self.word_index)
|
| 289 |
+
json_index_word = json.dumps(self.index_word)
|
| 290 |
+
|
| 291 |
+
return {
|
| 292 |
+
"num_words": self.num_words,
|
| 293 |
+
"filters": self.filters,
|
| 294 |
+
"lower": self.lower,
|
| 295 |
+
"split": self.split,
|
| 296 |
+
"char_level": self.char_level,
|
| 297 |
+
"oov_token": self.oov_token,
|
| 298 |
+
"document_count": self.document_count,
|
| 299 |
+
"word_counts": json_word_counts,
|
| 300 |
+
"word_docs": json_word_docs,
|
| 301 |
+
"index_docs": json_index_docs,
|
| 302 |
+
"index_word": json_index_word,
|
| 303 |
+
"word_index": json_word_index,
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
def to_json(self, **kwargs):
|
| 307 |
+
config = self.get_config()
|
| 308 |
+
tokenizer_config = {
|
| 309 |
+
"class_name": self.__class__.__name__,
|
| 310 |
+
"config": config,
|
| 311 |
+
}
|
| 312 |
+
return json.dumps(tokenizer_config, **kwargs)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@keras_export("keras._legacy.preprocessing.text.tokenizer_from_json")
|
| 316 |
+
def tokenizer_from_json(json_string):
|
| 317 |
+
"""DEPRECATED."""
|
| 318 |
+
tokenizer_config = json.loads(json_string)
|
| 319 |
+
config = tokenizer_config.get("config")
|
| 320 |
+
|
| 321 |
+
word_counts = json.loads(config.pop("word_counts"))
|
| 322 |
+
word_docs = json.loads(config.pop("word_docs"))
|
| 323 |
+
index_docs = json.loads(config.pop("index_docs"))
|
| 324 |
+
# Integer indexing gets converted to strings with json.dumps()
|
| 325 |
+
index_docs = {int(k): v for k, v in index_docs.items()}
|
| 326 |
+
index_word = json.loads(config.pop("index_word"))
|
| 327 |
+
index_word = {int(k): v for k, v in index_word.items()}
|
| 328 |
+
word_index = json.loads(config.pop("word_index"))
|
| 329 |
+
|
| 330 |
+
tokenizer = Tokenizer(**config)
|
| 331 |
+
tokenizer.word_counts = word_counts
|
| 332 |
+
tokenizer.word_docs = word_docs
|
| 333 |
+
tokenizer.index_docs = index_docs
|
| 334 |
+
tokenizer.word_index = word_index
|
| 335 |
+
tokenizer.index_word = index_word
|
| 336 |
+
return tokenizer
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (199 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/json_utils.cpython-310.pyc
ADDED
|
Binary file (5.84 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/legacy_h5_format.cpython-310.pyc
ADDED
|
Binary file (16.6 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/saving_options.cpython-310.pyc
ADDED
|
Binary file (612 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/saving_utils.cpython-310.pyc
ADDED
|
Binary file (6.75 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/serialization.cpython-310.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/json_utils.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""JSON utilities for legacy saving formats (h5 and SavedModel)"""
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
import enum
|
| 5 |
+
import functools
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from keras.src.legacy.saving import serialization
|
| 11 |
+
from keras.src.saving import serialization_lib
|
| 12 |
+
from keras.src.utils.module_utils import tensorflow as tf
|
| 13 |
+
|
| 14 |
+
_EXTENSION_TYPE_SPEC = "_EXTENSION_TYPE_SPEC"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Encoder(json.JSONEncoder):
|
| 18 |
+
"""JSON encoder and decoder that handles TensorShapes and tuples."""
|
| 19 |
+
|
| 20 |
+
def default(self, obj):
|
| 21 |
+
"""Encodes objects for types that aren't handled by the default
|
| 22 |
+
encoder."""
|
| 23 |
+
if tf.available and isinstance(obj, tf.TensorShape):
|
| 24 |
+
items = obj.as_list() if obj.rank is not None else None
|
| 25 |
+
return {"class_name": "TensorShape", "items": items}
|
| 26 |
+
return get_json_type(obj)
|
| 27 |
+
|
| 28 |
+
def encode(self, obj):
|
| 29 |
+
return super().encode(_encode_tuple(obj))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _encode_tuple(x):
|
| 33 |
+
if isinstance(x, tuple):
|
| 34 |
+
return {
|
| 35 |
+
"class_name": "__tuple__",
|
| 36 |
+
"items": tuple(_encode_tuple(i) for i in x),
|
| 37 |
+
}
|
| 38 |
+
elif isinstance(x, list):
|
| 39 |
+
return [_encode_tuple(i) for i in x]
|
| 40 |
+
elif isinstance(x, dict):
|
| 41 |
+
return {key: _encode_tuple(value) for key, value in x.items()}
|
| 42 |
+
else:
|
| 43 |
+
return x
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def decode(json_string):
|
| 47 |
+
return json.loads(json_string, object_hook=_decode_helper)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def decode_and_deserialize(
|
| 51 |
+
json_string, module_objects=None, custom_objects=None
|
| 52 |
+
):
|
| 53 |
+
"""Decodes the JSON and deserializes any Keras objects found in the dict."""
|
| 54 |
+
return json.loads(
|
| 55 |
+
json_string,
|
| 56 |
+
object_hook=functools.partial(
|
| 57 |
+
_decode_helper,
|
| 58 |
+
deserialize=True,
|
| 59 |
+
module_objects=module_objects,
|
| 60 |
+
custom_objects=custom_objects,
|
| 61 |
+
),
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _decode_helper(
|
| 66 |
+
obj, deserialize=False, module_objects=None, custom_objects=None
|
| 67 |
+
):
|
| 68 |
+
"""A decoding helper that is TF-object aware.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
obj: A decoded dictionary that may represent an object.
|
| 72 |
+
deserialize: Boolean. When True, deserializes any Keras
|
| 73 |
+
objects found in `obj`. Defaults to `False`.
|
| 74 |
+
module_objects: A dictionary of built-in objects to look the name up in.
|
| 75 |
+
Generally, `module_objects` is provided by midlevel library
|
| 76 |
+
implementers.
|
| 77 |
+
custom_objects: A dictionary of custom objects to look the name up in.
|
| 78 |
+
Generally, `custom_objects` is provided by the end user.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
The decoded object.
|
| 82 |
+
"""
|
| 83 |
+
if isinstance(obj, dict) and "class_name" in obj:
|
| 84 |
+
if tf.available:
|
| 85 |
+
if obj["class_name"] == "TensorShape":
|
| 86 |
+
return tf.TensorShape(obj["items"])
|
| 87 |
+
elif obj["class_name"] == "TypeSpec":
|
| 88 |
+
from tensorflow.python.framework import type_spec_registry
|
| 89 |
+
|
| 90 |
+
return type_spec_registry.lookup(obj["type_spec"])._deserialize(
|
| 91 |
+
_decode_helper(obj["serialized"])
|
| 92 |
+
)
|
| 93 |
+
elif obj["class_name"] == "CompositeTensor":
|
| 94 |
+
spec = obj["spec"]
|
| 95 |
+
tensors = []
|
| 96 |
+
for dtype, tensor in obj["tensors"]:
|
| 97 |
+
tensors.append(
|
| 98 |
+
tf.constant(tensor, dtype=tf.dtypes.as_dtype(dtype))
|
| 99 |
+
)
|
| 100 |
+
return tf.nest.pack_sequence_as(
|
| 101 |
+
_decode_helper(spec), tensors, expand_composites=True
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
if obj["class_name"] == "__tuple__":
|
| 105 |
+
return tuple(_decode_helper(i) for i in obj["items"])
|
| 106 |
+
elif obj["class_name"] == "__ellipsis__":
|
| 107 |
+
return Ellipsis
|
| 108 |
+
elif deserialize and "__passive_serialization__" in obj:
|
| 109 |
+
# __passive_serialization__ is added by the JSON encoder when
|
| 110 |
+
# encoding an object that has a `get_config()` method.
|
| 111 |
+
try:
|
| 112 |
+
if (
|
| 113 |
+
"module" not in obj
|
| 114 |
+
): # TODO(nkovela): Add TF SavedModel scope
|
| 115 |
+
return serialization.deserialize_keras_object(
|
| 116 |
+
obj,
|
| 117 |
+
module_objects=module_objects,
|
| 118 |
+
custom_objects=custom_objects,
|
| 119 |
+
)
|
| 120 |
+
else:
|
| 121 |
+
return serialization_lib.deserialize_keras_object(
|
| 122 |
+
obj,
|
| 123 |
+
module_objects=module_objects,
|
| 124 |
+
custom_objects=custom_objects,
|
| 125 |
+
)
|
| 126 |
+
except ValueError:
|
| 127 |
+
pass
|
| 128 |
+
elif obj["class_name"] == "__bytes__":
|
| 129 |
+
return obj["value"].encode("utf-8")
|
| 130 |
+
return obj
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def get_json_type(obj):
|
| 134 |
+
"""Serializes any object to a JSON-serializable structure.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
obj: the object to serialize
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
JSON-serializable structure representing `obj`.
|
| 141 |
+
|
| 142 |
+
Raises:
|
| 143 |
+
TypeError: if `obj` cannot be serialized.
|
| 144 |
+
"""
|
| 145 |
+
# if obj is a serializable Keras class instance
|
| 146 |
+
# e.g. optimizer, layer
|
| 147 |
+
if hasattr(obj, "get_config"):
|
| 148 |
+
# TODO(nkovela): Replace with legacy serialization
|
| 149 |
+
serialized = serialization.serialize_keras_object(obj)
|
| 150 |
+
serialized["__passive_serialization__"] = True
|
| 151 |
+
return serialized
|
| 152 |
+
|
| 153 |
+
# if obj is any numpy type
|
| 154 |
+
if type(obj).__module__ == np.__name__:
|
| 155 |
+
if isinstance(obj, np.ndarray):
|
| 156 |
+
return obj.tolist()
|
| 157 |
+
else:
|
| 158 |
+
return obj.item()
|
| 159 |
+
|
| 160 |
+
# misc functions (e.g. loss function)
|
| 161 |
+
if callable(obj):
|
| 162 |
+
return obj.__name__
|
| 163 |
+
|
| 164 |
+
# if obj is a python 'type'
|
| 165 |
+
if type(obj).__name__ == type.__name__:
|
| 166 |
+
return obj.__name__
|
| 167 |
+
|
| 168 |
+
if tf.available and isinstance(obj, tf.compat.v1.Dimension):
|
| 169 |
+
return obj.value
|
| 170 |
+
|
| 171 |
+
if tf.available and isinstance(obj, tf.TensorShape):
|
| 172 |
+
return obj.as_list()
|
| 173 |
+
|
| 174 |
+
if tf.available and isinstance(obj, tf.DType):
|
| 175 |
+
return obj.name
|
| 176 |
+
|
| 177 |
+
if isinstance(obj, collections.abc.Mapping):
|
| 178 |
+
return dict(obj)
|
| 179 |
+
|
| 180 |
+
if obj is Ellipsis:
|
| 181 |
+
return {"class_name": "__ellipsis__"}
|
| 182 |
+
|
| 183 |
+
# if isinstance(obj, wrapt.ObjectProxy):
|
| 184 |
+
# return obj.__wrapped__
|
| 185 |
+
|
| 186 |
+
if tf.available and isinstance(obj, tf.TypeSpec):
|
| 187 |
+
from tensorflow.python.framework import type_spec_registry
|
| 188 |
+
|
| 189 |
+
try:
|
| 190 |
+
type_spec_name = type_spec_registry.get_name(type(obj))
|
| 191 |
+
return {
|
| 192 |
+
"class_name": "TypeSpec",
|
| 193 |
+
"type_spec": type_spec_name,
|
| 194 |
+
"serialized": obj._serialize(),
|
| 195 |
+
}
|
| 196 |
+
except ValueError:
|
| 197 |
+
raise ValueError(
|
| 198 |
+
f"Unable to serialize {obj} to JSON, because the TypeSpec "
|
| 199 |
+
f"class {type(obj)} has not been registered."
|
| 200 |
+
)
|
| 201 |
+
if tf.available and isinstance(obj, tf.__internal__.CompositeTensor):
|
| 202 |
+
spec = tf.type_spec_from_value(obj)
|
| 203 |
+
tensors = []
|
| 204 |
+
for tensor in tf.nest.flatten(obj, expand_composites=True):
|
| 205 |
+
tensors.append((tensor.dtype.name, tensor.numpy().tolist()))
|
| 206 |
+
return {
|
| 207 |
+
"class_name": "CompositeTensor",
|
| 208 |
+
"spec": get_json_type(spec),
|
| 209 |
+
"tensors": tensors,
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
if isinstance(obj, enum.Enum):
|
| 213 |
+
return obj.value
|
| 214 |
+
|
| 215 |
+
if isinstance(obj, bytes):
|
| 216 |
+
return {"class_name": "__bytes__", "value": obj.decode("utf-8")}
|
| 217 |
+
|
| 218 |
+
raise TypeError(
|
| 219 |
+
f"Unable to serialize {obj} to JSON. Unrecognized type {type(obj)}."
|
| 220 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/legacy_h5_format.py
ADDED
|
@@ -0,0 +1,640 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from absl import logging
|
| 7 |
+
|
| 8 |
+
from keras.src import backend
|
| 9 |
+
from keras.src import optimizers
|
| 10 |
+
from keras.src.backend.common import global_state
|
| 11 |
+
from keras.src.legacy.saving import json_utils
|
| 12 |
+
from keras.src.legacy.saving import saving_options
|
| 13 |
+
from keras.src.legacy.saving import saving_utils
|
| 14 |
+
from keras.src.saving import object_registration
|
| 15 |
+
from keras.src.utils import io_utils
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
import h5py
|
| 19 |
+
except ImportError:
|
| 20 |
+
h5py = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
HDF5_OBJECT_HEADER_LIMIT = 64512
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
|
| 27 |
+
if h5py is None:
|
| 28 |
+
raise ImportError(
|
| 29 |
+
"`save_model()` using h5 format requires h5py. Could not "
|
| 30 |
+
"import h5py."
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
if not isinstance(filepath, h5py.File):
|
| 34 |
+
# If file exists and should not be overwritten.
|
| 35 |
+
if not overwrite and os.path.isfile(filepath):
|
| 36 |
+
proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
|
| 37 |
+
if not proceed:
|
| 38 |
+
return
|
| 39 |
+
|
| 40 |
+
dirpath = os.path.dirname(filepath)
|
| 41 |
+
if dirpath and not os.path.exists(dirpath):
|
| 42 |
+
os.makedirs(dirpath, exist_ok=True)
|
| 43 |
+
|
| 44 |
+
f = h5py.File(filepath, mode="w")
|
| 45 |
+
opened_new_file = True
|
| 46 |
+
else:
|
| 47 |
+
f = filepath
|
| 48 |
+
opened_new_file = False
|
| 49 |
+
try:
|
| 50 |
+
with saving_options.keras_option_scope(use_legacy_config=True):
|
| 51 |
+
model_metadata = saving_utils.model_metadata(
|
| 52 |
+
model, include_optimizer
|
| 53 |
+
)
|
| 54 |
+
for k, v in model_metadata.items():
|
| 55 |
+
if isinstance(v, (dict, list, tuple)):
|
| 56 |
+
f.attrs[k] = json.dumps(
|
| 57 |
+
v, default=json_utils.get_json_type
|
| 58 |
+
).encode("utf8")
|
| 59 |
+
else:
|
| 60 |
+
f.attrs[k] = v
|
| 61 |
+
|
| 62 |
+
model_weights_group = f.create_group("model_weights")
|
| 63 |
+
save_weights_to_hdf5_group(model_weights_group, model)
|
| 64 |
+
|
| 65 |
+
# TODO(b/128683857): Add integration tests between tf.keras and
|
| 66 |
+
# external Keras, to avoid breaking TF.js users.
|
| 67 |
+
if include_optimizer and hasattr(model, "optimizer"):
|
| 68 |
+
save_optimizer_weights_to_hdf5_group(f, model.optimizer)
|
| 69 |
+
|
| 70 |
+
f.flush()
|
| 71 |
+
finally:
|
| 72 |
+
if opened_new_file:
|
| 73 |
+
f.close()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
|
| 77 |
+
"""Loads a model saved via `save_model_to_hdf5`.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
filepath: One of the following:
|
| 81 |
+
- String, path to the saved model
|
| 82 |
+
- `h5py.File` object from which to load the model
|
| 83 |
+
custom_objects: Optional dictionary mapping names
|
| 84 |
+
(strings) to custom classes or functions to be
|
| 85 |
+
considered during deserialization.
|
| 86 |
+
compile: Boolean, whether to compile the model
|
| 87 |
+
after loading.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
A Keras model instance. If an optimizer was found
|
| 91 |
+
as part of the saved model, the model is already
|
| 92 |
+
compiled. Otherwise, the model is uncompiled and
|
| 93 |
+
a warning will be displayed. When `compile` is set
|
| 94 |
+
to `False`, the compilation is omitted without any
|
| 95 |
+
warning.
|
| 96 |
+
|
| 97 |
+
Raises:
|
| 98 |
+
ImportError: if h5py is not available.
|
| 99 |
+
ValueError: In case of an invalid savefile.
|
| 100 |
+
"""
|
| 101 |
+
if h5py is None:
|
| 102 |
+
raise ImportError(
|
| 103 |
+
"`load_model()` using h5 format requires h5py. Could not "
|
| 104 |
+
"import h5py."
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
if not custom_objects:
|
| 108 |
+
custom_objects = {}
|
| 109 |
+
|
| 110 |
+
gco = object_registration.GLOBAL_CUSTOM_OBJECTS
|
| 111 |
+
tlco = global_state.get_global_attribute("custom_objects_scope_dict", {})
|
| 112 |
+
custom_objects = {**custom_objects, **gco, **tlco}
|
| 113 |
+
|
| 114 |
+
opened_new_file = not isinstance(filepath, h5py.File)
|
| 115 |
+
if opened_new_file:
|
| 116 |
+
f = h5py.File(filepath, mode="r")
|
| 117 |
+
else:
|
| 118 |
+
f = filepath
|
| 119 |
+
|
| 120 |
+
model = None
|
| 121 |
+
try:
|
| 122 |
+
# instantiate model
|
| 123 |
+
model_config = f.attrs.get("model_config")
|
| 124 |
+
if model_config is None:
|
| 125 |
+
raise ValueError(
|
| 126 |
+
f"No model config found in the file at {filepath}."
|
| 127 |
+
)
|
| 128 |
+
if hasattr(model_config, "decode"):
|
| 129 |
+
model_config = model_config.decode("utf-8")
|
| 130 |
+
model_config = json_utils.decode(model_config)
|
| 131 |
+
|
| 132 |
+
with saving_options.keras_option_scope(use_legacy_config=True):
|
| 133 |
+
model = saving_utils.model_from_config(
|
| 134 |
+
model_config, custom_objects=custom_objects
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# set weights
|
| 138 |
+
load_weights_from_hdf5_group(f["model_weights"], model)
|
| 139 |
+
|
| 140 |
+
if compile:
|
| 141 |
+
# instantiate optimizer
|
| 142 |
+
training_config = f.attrs.get("training_config")
|
| 143 |
+
if hasattr(training_config, "decode"):
|
| 144 |
+
training_config = training_config.decode("utf-8")
|
| 145 |
+
if training_config is None:
|
| 146 |
+
logging.warning(
|
| 147 |
+
"No training configuration found in the save file, so "
|
| 148 |
+
"the model was *not* compiled. Compile it manually."
|
| 149 |
+
)
|
| 150 |
+
return model
|
| 151 |
+
training_config = json_utils.decode(training_config)
|
| 152 |
+
|
| 153 |
+
# Compile model.
|
| 154 |
+
model.compile(
|
| 155 |
+
**saving_utils.compile_args_from_training_config(
|
| 156 |
+
training_config, custom_objects
|
| 157 |
+
)
|
| 158 |
+
)
|
| 159 |
+
saving_utils.try_build_compiled_arguments(model)
|
| 160 |
+
|
| 161 |
+
# Set optimizer weights.
|
| 162 |
+
if "optimizer_weights" in f:
|
| 163 |
+
try:
|
| 164 |
+
if isinstance(model.optimizer, optimizers.Optimizer):
|
| 165 |
+
model.optimizer.build(model._trainable_variables)
|
| 166 |
+
else:
|
| 167 |
+
model.optimizer._create_all_weights(
|
| 168 |
+
model._trainable_variables
|
| 169 |
+
)
|
| 170 |
+
except (NotImplementedError, AttributeError):
|
| 171 |
+
logging.warning(
|
| 172 |
+
"Error when creating the weights of optimizer {}, "
|
| 173 |
+
"making it impossible to restore the saved optimizer "
|
| 174 |
+
"state. As a result, your model is starting with "
|
| 175 |
+
"a freshly initialized optimizer."
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
optimizer_weight_values = (
|
| 179 |
+
load_optimizer_weights_from_hdf5_group(f)
|
| 180 |
+
)
|
| 181 |
+
try:
|
| 182 |
+
model.optimizer.set_weights(optimizer_weight_values)
|
| 183 |
+
except ValueError:
|
| 184 |
+
logging.warning(
|
| 185 |
+
"Error in loading the saved optimizer "
|
| 186 |
+
"state. As a result, your model is "
|
| 187 |
+
"starting with a freshly initialized "
|
| 188 |
+
"optimizer."
|
| 189 |
+
)
|
| 190 |
+
finally:
|
| 191 |
+
if opened_new_file:
|
| 192 |
+
f.close()
|
| 193 |
+
return model
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def save_weights_to_hdf5_group(f, model):
|
| 197 |
+
"""Saves the weights of a list of layers to a HDF5 group.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
f: HDF5 group.
|
| 201 |
+
model: Model instance.
|
| 202 |
+
"""
|
| 203 |
+
from keras.src import __version__ as keras_version
|
| 204 |
+
|
| 205 |
+
save_attributes_to_hdf5_group(
|
| 206 |
+
f, "layer_names", [layer.name.encode("utf8") for layer in model.layers]
|
| 207 |
+
)
|
| 208 |
+
f.attrs["backend"] = backend.backend().encode("utf8")
|
| 209 |
+
f.attrs["keras_version"] = str(keras_version).encode("utf8")
|
| 210 |
+
|
| 211 |
+
# Sort model layers by layer name to ensure that group names are strictly
|
| 212 |
+
# growing to avoid prefix issues.
|
| 213 |
+
for layer in sorted(model.layers, key=lambda x: x.name):
|
| 214 |
+
g = f.create_group(layer.name)
|
| 215 |
+
weights = _legacy_weights(layer)
|
| 216 |
+
save_subset_weights_to_hdf5_group(g, weights)
|
| 217 |
+
weights = list(
|
| 218 |
+
v
|
| 219 |
+
for v in model._trainable_variables + model._non_trainable_variables
|
| 220 |
+
if v in model.weights
|
| 221 |
+
)
|
| 222 |
+
g = f.create_group("top_level_model_weights")
|
| 223 |
+
save_subset_weights_to_hdf5_group(g, weights)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def save_subset_weights_to_hdf5_group(f, weights):
|
| 227 |
+
"""Save top-level weights of a model to a HDF5 group.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
f: HDF5 group.
|
| 231 |
+
weights: List of weight variables.
|
| 232 |
+
"""
|
| 233 |
+
weight_values = [backend.convert_to_numpy(w) for w in weights]
|
| 234 |
+
weight_names = [str(w.path).encode("utf8") for w in weights]
|
| 235 |
+
save_attributes_to_hdf5_group(f, "weight_names", weight_names)
|
| 236 |
+
for name, val in zip(weight_names, weight_values):
|
| 237 |
+
param_dset = f.create_dataset(name, val.shape, dtype=val.dtype)
|
| 238 |
+
if not val.shape:
|
| 239 |
+
# scalar
|
| 240 |
+
param_dset[()] = val
|
| 241 |
+
else:
|
| 242 |
+
param_dset[:] = val
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):
|
| 246 |
+
"""Saves optimizer weights of a optimizer to a HDF5 group.
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
hdf5_group: HDF5 group.
|
| 250 |
+
optimizer: optimizer instance.
|
| 251 |
+
"""
|
| 252 |
+
if isinstance(optimizer, optimizers.Optimizer):
|
| 253 |
+
symbolic_weights = optimizer.variables
|
| 254 |
+
else:
|
| 255 |
+
symbolic_weights = getattr(optimizer, "weights")
|
| 256 |
+
if symbolic_weights:
|
| 257 |
+
weights_group = hdf5_group.create_group("optimizer_weights")
|
| 258 |
+
weight_names = [str(w.path).encode("utf8") for w in symbolic_weights]
|
| 259 |
+
save_attributes_to_hdf5_group(
|
| 260 |
+
weights_group, "weight_names", weight_names
|
| 261 |
+
)
|
| 262 |
+
weight_values = [backend.convert_to_numpy(w) for w in symbolic_weights]
|
| 263 |
+
for name, val in zip(weight_names, weight_values):
|
| 264 |
+
param_dset = weights_group.create_dataset(
|
| 265 |
+
name, val.shape, dtype=val.dtype
|
| 266 |
+
)
|
| 267 |
+
if not val.shape:
|
| 268 |
+
# scalar
|
| 269 |
+
param_dset[()] = val
|
| 270 |
+
else:
|
| 271 |
+
param_dset[:] = val
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def save_attributes_to_hdf5_group(group, name, data):
    """Saves attributes (data) of the specified name into the HDF5 group.

    HDF5 cannot store attribute payloads larger than
    HDF5_OBJECT_HEADER_LIMIT bytes, so oversized lists are split across
    numbered attributes ("<name>0", "<name>1", ...).

    Args:
        group: A pointer to a HDF5 group.
        name: A name of the attributes to save.
        data: Attributes data to store.

    Raises:
        RuntimeError: If any single attribute is too large to be saved.
    """
    # A single item above the limit cannot be saved at all, no matter how
    # the list is chunked, so fail loudly up front.
    oversized = [item for item in data if len(item) > HDF5_OBJECT_HEADER_LIMIT]
    if oversized:
        raise RuntimeError(
            "The following attributes cannot be saved to HDF5 file because "
            f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} "
            f"bytes: {oversized}"
        )

    as_array = np.asarray(data)

    # Grow the chunk count until every chunk fits; guaranteed to terminate
    # by the per-item check above.
    num_chunks = 1
    chunks = np.array_split(as_array, num_chunks)
    while any(chunk.nbytes > HDF5_OBJECT_HEADER_LIMIT for chunk in chunks):
        num_chunks += 1
        chunks = np.array_split(as_array, num_chunks)

    if num_chunks > 1:
        for chunk_id, chunk in enumerate(chunks):
            group.attrs["%s%d" % (name, chunk_id)] = chunk
    else:
        # Store the original (unconverted) data when it fits in one piece.
        group.attrs[name] = data
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def load_weights_from_hdf5_group(f, model):
    """Implements topological (order-based) weight loading.

    Args:
        f: A pointer to a HDF5 group.
        model: Model instance.

    Raises:
        ValueError: in case of mismatch between provided layers
            and weights file.
    """
    # Read provenance attributes; older HDF5 files store them as bytes.
    if "keras_version" in f.attrs:
        original_keras_version = f.attrs["keras_version"]
        if hasattr(original_keras_version, "decode"):
            original_keras_version = original_keras_version.decode("utf8")
    else:
        original_keras_version = "1"
    if "backend" in f.attrs:
        original_backend = f.attrs["backend"]
        if hasattr(original_backend, "decode"):
            original_backend = original_backend.decode("utf8")
    else:
        original_backend = None

    # Only layers that actually own weights participate in the positional
    # match-up below.
    filtered_layers = []
    for layer in model.layers:
        weights = _legacy_weights(layer)
        if weights:
            filtered_layers.append(layer)

    layer_names = load_attributes_from_hdf5_group(f, "layer_names")
    # Likewise drop saved layer groups that carry no weights.
    filtered_layer_names = []
    for name in layer_names:
        g = f[name]
        weight_names = load_attributes_from_hdf5_group(g, "weight_names")
        if weight_names:
            filtered_layer_names.append(name)
    layer_names = filtered_layer_names
    if len(layer_names) != len(filtered_layers):
        raise ValueError(
            "Layer count mismatch when loading weights from file. "
            f"Model expected {len(filtered_layers)} layers, found "
            f"{len(layer_names)} saved layers."
        )

    # Topological loading: the k-th saved group is assigned to the k-th
    # weighted layer of the model, regardless of names.
    for k, name in enumerate(layer_names):
        g = f[name]
        layer = filtered_layers[k]
        symbolic_weights = _legacy_weights(layer)
        weight_values = load_subset_weights_from_hdf5_group(g)
        if len(weight_values) != len(symbolic_weights):
            raise ValueError(
                f"Weight count mismatch for layer #{k} (named {layer.name} in "
                f"the current model, {name} in the save file). "
                f"Layer expects {len(symbolic_weights)} weight(s). Received "
                f"{len(weight_values)} saved weight(s)"
            )
        _set_weights(
            layer,
            symbolic_weights,
            weight_values,
            name=f"layer #{k} (named {layer.name})",
        )

    # Weights owned directly by the model (not by any layer).
    if "top_level_model_weights" in f:
        symbolic_weights = list(
            # Keep only direct model variables that are also reported in
            # `model.weights` — presumably filters out non-weight state;
            # TODO(review): confirm intent against the saving side.
            v
            for v in model._trainable_variables + model._non_trainable_variables
            if v in model.weights
        )
        weight_values = load_subset_weights_from_hdf5_group(
            f["top_level_model_weights"]
        )
        if len(weight_values) != len(symbolic_weights):
            raise ValueError(
                "Weight count mismatch for top-level weights when loading "
                "weights from file. "
                f"Model expects {len(symbolic_weights)} top-level weight(s). "
                f"Received {len(weight_values)} saved top-level weight(s)"
            )
        _set_weights(
            model,
            symbolic_weights,
            weight_values,
            name="top-level model",
        )
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _set_weights(
|
| 408 |
+
instance, symbolic_weights, weight_values, name, skip_mismatch=False
|
| 409 |
+
):
|
| 410 |
+
"""Safely set weights into a model or a layer.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
instance: Model or layer instance,
|
| 414 |
+
symbolic_weights: symbolic tensors representing
|
| 415 |
+
the weights of the variables to load,
|
| 416 |
+
weight_values: values of the weights to load,
|
| 417 |
+
skip_mismatch: Boolean, whether to skip loading of weights
|
| 418 |
+
where there is a mismatch in the shape of the weights,
|
| 419 |
+
name: name used to identify the group.
|
| 420 |
+
|
| 421 |
+
Raises:
|
| 422 |
+
ValueError: in case of mismatch between provided
|
| 423 |
+
model/layer and weights.
|
| 424 |
+
"""
|
| 425 |
+
for i, weight_value in enumerate(weight_values):
|
| 426 |
+
expected_shape = symbolic_weights[i].shape
|
| 427 |
+
received_shape = weight_value.shape
|
| 428 |
+
if expected_shape != received_shape:
|
| 429 |
+
if skip_mismatch:
|
| 430 |
+
warnings.warn(
|
| 431 |
+
f"Skipping loading weights for {name}"
|
| 432 |
+
f"due to mismatch in shape for "
|
| 433 |
+
f"weight {symbolic_weights[i].path}. "
|
| 434 |
+
f"Weight expects shape {expected_shape}. "
|
| 435 |
+
"Received saved weight "
|
| 436 |
+
f"with shape {received_shape}",
|
| 437 |
+
stacklevel=2,
|
| 438 |
+
)
|
| 439 |
+
continue
|
| 440 |
+
raise ValueError(
|
| 441 |
+
f"Shape mismatch in {name}"
|
| 442 |
+
f"for weight {symbolic_weights[i].path}. "
|
| 443 |
+
f"Weight expects shape {expected_shape}. "
|
| 444 |
+
"Received saved weight "
|
| 445 |
+
f"with shape {received_shape}"
|
| 446 |
+
)
|
| 447 |
+
symbolic_weights[i].assign(weight_value)
|
| 448 |
+
|
| 449 |
+
if hasattr(instance, "finalize_state") and symbolic_weights:
|
| 450 |
+
instance.finalize_state()
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def load_weights_from_hdf5_group_by_name(f, model, skip_mismatch=False):
    """Implements name-based weight loading (instead of topological loading).

    Layers that have no matching name are skipped.

    Args:
        f: A pointer to a HDF5 group.
        model: Model instance.
        skip_mismatch: Boolean, whether to skip loading of layers
            where there is a mismatch in the number of weights,
            or a mismatch in the shape of the weights.

    Raises:
        ValueError: in case of mismatch between provided layers
            and weights file and skip_match=False.
    """
    # Read provenance attributes; older HDF5 files store them as bytes.
    if "keras_version" in f.attrs:
        original_keras_version = f.attrs["keras_version"]
        if hasattr(original_keras_version, "decode"):
            original_keras_version = original_keras_version.decode("utf8")
    else:
        original_keras_version = "1"
    if "backend" in f.attrs:
        original_backend = f.attrs["backend"]
        if hasattr(original_backend, "decode"):
            original_backend = original_backend.decode("utf8")
    else:
        original_backend = None

    # New file format.
    layer_names = load_attributes_from_hdf5_group(f, "layer_names")

    # Reverse index of layer name to list of layers with name.
    index = {}
    for layer in model.layers:
        if layer.name:
            index.setdefault(layer.name, []).append(layer)

    for k, name in enumerate(layer_names):
        g = f[name]
        weight_values = load_subset_weights_from_hdf5_group(g)
        # Several model layers may share one saved name; each receives the
        # same saved values.
        for layer in index.get(name, []):
            symbolic_weights = _legacy_weights(layer)
            if len(weight_values) != len(symbolic_weights):
                if skip_mismatch:
                    warnings.warn(
                        f"Skipping loading of weights for layer #{k} (named "
                        f"{layer.name}) due to mismatch in number of weights. "
                        f"Layer expects {len(symbolic_weights)} weight(s). "
                        f"Received {len(weight_values)} saved weight(s)",
                        stacklevel=2,
                    )
                    continue
                raise ValueError(
                    f"Weight count mismatch for layer #{k} "
                    f"(named {layer.name}). "
                    f"Layer expects {len(symbolic_weights)} weight(s). "
                    f"Received {len(weight_values)} saved weight(s)"
                )
            # Set values.
            _set_weights(
                layer,
                symbolic_weights,
                weight_values,
                skip_mismatch=skip_mismatch,
                name=f"layer #{k} (named {layer.name})",
            )

    # Weights owned directly by the model (not by any layer).
    if "top_level_model_weights" in f:
        # NOTE(review): unlike the topological loader, this takes ALL direct
        # variables without filtering against `model.weights` — confirm the
        # asymmetry is intentional.
        symbolic_weights = (
            model._trainable_variables + model._non_trainable_variables
        )
        weight_values = load_subset_weights_from_hdf5_group(
            f["top_level_model_weights"]
        )

        if len(weight_values) != len(symbolic_weights):
            if skip_mismatch:
                warnings.warn(
                    "Skipping loading top-level weights for model due to "
                    "mismatch in number of weights. "
                    f"Model expects {len(symbolic_weights)} "
                    "top-level weight(s). "
                    f"Received {len(weight_values)} saved top-level weight(s)",
                    stacklevel=2,
                )
            else:
                raise ValueError(
                    "Weight count mismatch for top-level weights of model. "
                    f"Model expects {len(symbolic_weights)} "
                    "top-level weight(s). "
                    f"Received {len(weight_values)} saved top-level weight(s)"
                )
        else:
            _set_weights(
                model,
                symbolic_weights,
                weight_values,
                skip_mismatch=skip_mismatch,
                name="top-level model",
            )
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def load_subset_weights_from_hdf5_group(f):
    """Load layer weights of a model from hdf5.

    Args:
        f: A pointer to a HDF5 group.

    Returns:
        List of NumPy arrays of the weight values.

    Raises:
        ValueError: in case of mismatch between provided model
            and weights file.
    """
    names = load_attributes_from_hdf5_group(f, "weight_names")
    values = []
    for weight_name in names:
        # Materialize each HDF5 dataset as an in-memory NumPy array.
        values.append(np.asarray(f[weight_name]))
    return values
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
def load_optimizer_weights_from_hdf5_group(hdf5_group):
    """Load optimizer weights from a HDF5 group.

    Args:
        hdf5_group: A pointer to a HDF5 group.

    Returns:
        List of optimizer weight datasets, in saved order.
    """
    weights_group = hdf5_group["optimizer_weights"]
    names = load_attributes_from_hdf5_group(weights_group, "weight_names")
    # Note: returns the raw HDF5 datasets, not NumPy copies.
    return [weights_group[weight_name] for weight_name in names]
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def load_attributes_from_hdf5_group(group, name):
    """Loads attributes of the specified name from the HDF5 group.

    Mirrors `save_attributes_to_hdf5_group`: a payload too large for one
    HDF5 attribute is reassembled from numbered chunks
    ("<name>0", "<name>1", ...).

    Args:
        group: A pointer to a HDF5 group.
        name: A name of the attributes to load.

    Returns:
        data: Attributes data.
    """

    def _decode(item):
        # Old files store strings as bytes; pass other values through.
        return item.decode("utf8") if hasattr(item, "decode") else item

    if name in group.attrs:
        return [_decode(item) for item in group.attrs[name]]

    data = []
    chunk_id = 0
    while True:
        chunk_key = f"{name}{chunk_id}"
        if chunk_key not in group.attrs:
            break
        data.extend(_decode(item) for item in group.attrs[chunk_key])
        chunk_id += 1
    return data
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def _legacy_weights(layer):
|
| 625 |
+
"""Legacy weight order converter.
|
| 626 |
+
|
| 627 |
+
For legacy reason, the layer.weights was in the order of
|
| 628 |
+
[self.trainable_weights + self.non_trainable_weights], and this order was
|
| 629 |
+
used for preserving the weights in h5 format. The new order of layer.weights
|
| 630 |
+
are the same as layer.get_weights() which is more intuitive for user. To
|
| 631 |
+
keep supporting the existing saved h5 file, this method should be used to
|
| 632 |
+
save/load weights.
|
| 633 |
+
|
| 634 |
+
Args:
|
| 635 |
+
layer: a `Model` or `Layer` instance.
|
| 636 |
+
|
| 637 |
+
Returns:
|
| 638 |
+
A list of variables with the legacy weight order.
|
| 639 |
+
"""
|
| 640 |
+
return layer.trainable_weights + layer.non_trainable_weights
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/saving_options.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
|
| 3 |
+
from keras.src.backend.common import global_state
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@contextlib.contextmanager
def keras_option_scope(use_legacy_config=True):
    """Context manager that temporarily sets the `use_legacy_config` flag.

    The previous value of the global attribute is restored on exit, even
    when the body raises.
    """
    previous = global_state.get_global_attribute("use_legacy_config", None)
    global_state.set_global_attribute("use_legacy_config", use_legacy_config)
    try:
        yield
    finally:
        global_state.set_global_attribute("use_legacy_config", previous)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/saving_utils.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import threading
|
| 3 |
+
|
| 4 |
+
from absl import logging
|
| 5 |
+
|
| 6 |
+
from keras.src import backend
|
| 7 |
+
from keras.src import layers
|
| 8 |
+
from keras.src import losses
|
| 9 |
+
from keras.src import metrics as metrics_module
|
| 10 |
+
from keras.src import models
|
| 11 |
+
from keras.src import optimizers
|
| 12 |
+
from keras.src import tree
|
| 13 |
+
from keras.src.legacy.saving import serialization
|
| 14 |
+
from keras.src.saving import object_registration
|
| 15 |
+
|
| 16 |
+
MODULE_OBJECTS = threading.local()
|
| 17 |
+
|
| 18 |
+
# Legacy lambda arguments not found in Keras 3
|
| 19 |
+
LAMBDA_DEP_ARGS = (
|
| 20 |
+
"module",
|
| 21 |
+
"function_type",
|
| 22 |
+
"output_shape_type",
|
| 23 |
+
"output_shape_module",
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def model_from_config(config, custom_objects=None):
    """Instantiates a Keras model from its config.

    Args:
        config: Configuration dictionary.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.

    Returns:
        A Keras model instance (uncompiled).

    Raises:
        TypeError: if `config` is not a dictionary.
    """
    if isinstance(config, list):
        raise TypeError(
            "`model_from_config` expects a dictionary, not a list. "
            f"Received: config={config}. Did you meant to use "
            "`Sequential.from_config(config)`?"
        )

    global MODULE_OBJECTS

    # Lazily populate the per-thread lookup of deserializable classes.
    if not hasattr(MODULE_OBJECTS, "ALL_OBJECTS"):
        MODULE_OBJECTS.ALL_OBJECTS = layers.__dict__
        MODULE_OBJECTS.ALL_OBJECTS["InputLayer"] = layers.InputLayer
        MODULE_OBJECTS.ALL_OBJECTS["Functional"] = models.Functional
        MODULE_OBJECTS.ALL_OBJECTS["Model"] = models.Model
        MODULE_OBJECTS.ALL_OBJECTS["Sequential"] = models.Sequential

    # Legacy configs used `batch_input_shape`; map it onto the Keras 3
    # equivalent argument for the target layer class.
    batch_input_shape = config["config"].pop("batch_input_shape", None)
    if batch_input_shape is not None:
        if config["class_name"] == "InputLayer":
            config["config"]["batch_shape"] = batch_input_shape
        else:
            config["config"]["input_shape"] = batch_input_shape

    # Legacy configs may store `axis` as a single-element list.
    axis = config["config"].pop("axis", None)
    if axis is not None and isinstance(axis, list) and len(axis) == 1:
        config["config"]["axis"] = int(axis[0])

    # Handle backwards compatibility for Keras lambdas
    if config["class_name"] == "Lambda":
        # Drop lambda arguments that no longer exist in Keras 3.
        for dep_arg in LAMBDA_DEP_ARGS:
            _ = config["config"].pop(dep_arg, None)
        function_config = config["config"]["function"]
        if isinstance(function_config, list):
            # Old list form [code, defaults, closure] -> dict form.
            function_dict = {"class_name": "__lambda__", "config": {}}
            function_dict["config"]["code"] = function_config[0]
            function_dict["config"]["defaults"] = function_config[1]
            function_dict["config"]["closure"] = function_config[2]
            config["config"]["function"] = function_dict

    # TODO(nkovela): Swap find and replace args during Keras 3.0 release
    # Replace keras refs with keras
    # NOTE(review): find and replace are currently identical, so this is a
    # no-op round-trip through JSON — presumably a placeholder per the TODO.
    config = _find_replace_nested_dict(config, "keras.", "keras.")

    return serialization.deserialize_keras_object(
        config,
        module_objects=MODULE_OBJECTS.ALL_OBJECTS,
        custom_objects=custom_objects,
        printable_module_name="layer",
    )
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def model_metadata(model, include_optimizer=True, require_config=True):
    """Returns a dictionary containing the model metadata."""
    # Local import to avoid a circular dependency at module load time.
    from keras.src import __version__ as keras_version

    model_config = {"class_name": model.__class__.__name__}
    try:
        model_config["config"] = model.get_config()
    except NotImplementedError as e:
        # Models without a config can still be saved when the caller does
        # not require one (e.g. SavedModel export).
        if require_config:
            raise e

    metadata = dict(
        keras_version=str(keras_version),
        backend=backend.backend(),
        model_config=model_config,
    )
    # Only compiled models with an optimizer carry a training config.
    if getattr(model, "optimizer", False) and include_optimizer:
        if model.compiled:
            training_config = model._compile_config.config
            training_config.pop("optimizer", None)  # Handled separately.
            metadata["training_config"] = _serialize_nested_config(
                training_config
            )
            optimizer_config = {
                "class_name": object_registration.get_registered_name(
                    model.optimizer.__class__
                ),
                "config": model.optimizer.get_config(),
            }
            metadata["training_config"]["optimizer_config"] = optimizer_config
    return metadata
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def compile_args_from_training_config(training_config, custom_objects=None):
    """Return model.compile arguments from training config."""
    if custom_objects is None:
        custom_objects = {}

    # Deserialize within the scope so custom classes/functions resolve.
    with object_registration.CustomObjectScope(custom_objects):
        optimizer_config = training_config["optimizer_config"]
        optimizer = optimizers.deserialize(optimizer_config)
        # Ensure backwards compatibility for optimizers in legacy H5 files
        optimizer = _resolve_compile_arguments_compat(
            optimizer, optimizer_config, optimizers
        )

        # Recover losses.
        loss = None
        loss_config = training_config.get("loss", None)
        if loss_config is not None:
            loss = _deserialize_nested_config(losses.deserialize, loss_config)
            # Ensure backwards compatibility for losses in legacy H5 files
            loss = _resolve_compile_arguments_compat(loss, loss_config, losses)

        # Recover metrics.
        metrics = None
        metrics_config = training_config.get("metrics", None)
        if metrics_config is not None:
            metrics = _deserialize_nested_config(
                _deserialize_metric, metrics_config
            )
            # Ensure backwards compatibility for metrics in legacy H5 files
            metrics = _resolve_compile_arguments_compat(
                metrics, metrics_config, metrics_module
            )

        # Recover weighted metrics.
        weighted_metrics = None
        weighted_metrics_config = training_config.get("weighted_metrics", None)
        if weighted_metrics_config is not None:
            weighted_metrics = _deserialize_nested_config(
                _deserialize_metric, weighted_metrics_config
            )

        # NOTE(review): direct indexing — assumes "loss_weights" is always
        # present in a legacy training config; confirm against the saver.
        loss_weights = training_config["loss_weights"]

        return dict(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            weighted_metrics=weighted_metrics,
            loss_weights=loss_weights,
        )
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _serialize_nested_config(config):
    """Serialize a nested structure of Keras objects.

    Callables anywhere in the structure are replaced by their serialized
    form; all other leaves pass through unchanged.
    """

    def _maybe_serialize(leaf):
        if callable(leaf):
            return serialization.serialize_keras_object(leaf)
        return leaf

    return tree.map_structure(_maybe_serialize, config)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _deserialize_nested_config(deserialize_fn, config):
|
| 190 |
+
"""Deserializes arbitrary Keras `config` using `deserialize_fn`."""
|
| 191 |
+
|
| 192 |
+
def _is_single_object(obj):
|
| 193 |
+
if isinstance(obj, dict) and "class_name" in obj:
|
| 194 |
+
return True # Serialized Keras object.
|
| 195 |
+
if isinstance(obj, str):
|
| 196 |
+
return True # Serialized function or string.
|
| 197 |
+
return False
|
| 198 |
+
|
| 199 |
+
if config is None:
|
| 200 |
+
return None
|
| 201 |
+
if _is_single_object(config):
|
| 202 |
+
return deserialize_fn(config)
|
| 203 |
+
elif isinstance(config, dict):
|
| 204 |
+
return {
|
| 205 |
+
k: _deserialize_nested_config(deserialize_fn, v)
|
| 206 |
+
for k, v in config.items()
|
| 207 |
+
}
|
| 208 |
+
elif isinstance(config, (tuple, list)):
|
| 209 |
+
return [
|
| 210 |
+
_deserialize_nested_config(deserialize_fn, obj) for obj in config
|
| 211 |
+
]
|
| 212 |
+
|
| 213 |
+
raise ValueError(
|
| 214 |
+
"Saved configuration not understood. Configuration should be a "
|
| 215 |
+
f"dictionary, string, tuple or list. Received: config={config}."
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def _deserialize_metric(metric_config):
|
| 220 |
+
"""Deserialize metrics, leaving special strings untouched."""
|
| 221 |
+
if metric_config in ["accuracy", "acc", "crossentropy", "ce"]:
|
| 222 |
+
# Do not deserialize accuracy and cross-entropy strings as we have
|
| 223 |
+
# special case handling for these in compile, based on model output
|
| 224 |
+
# shape.
|
| 225 |
+
return metric_config
|
| 226 |
+
return metrics_module.deserialize(metric_config)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _find_replace_nested_dict(config, find, replace):
|
| 230 |
+
dict_str = json.dumps(config)
|
| 231 |
+
dict_str = dict_str.replace(find, replace)
|
| 232 |
+
config = json.loads(dict_str)
|
| 233 |
+
return config
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _resolve_compile_arguments_compat(obj, obj_config, module):
|
| 237 |
+
"""Resolves backwards compatibility issues with training config arguments.
|
| 238 |
+
|
| 239 |
+
This helper function accepts built-in Keras modules such as optimizers,
|
| 240 |
+
losses, and metrics to ensure an object being deserialized is compatible
|
| 241 |
+
with Keras 3 built-ins. For legacy H5 files saved within Keras 3,
|
| 242 |
+
this does nothing.
|
| 243 |
+
"""
|
| 244 |
+
if isinstance(obj, str) and obj not in module.ALL_OBJECTS_DICT:
|
| 245 |
+
obj = module.get(obj_config["config"]["name"])
|
| 246 |
+
return obj
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def try_build_compiled_arguments(model):
    """Best-effort build of a model's compiled loss and metric containers.

    Building requires the model outputs, which may be unavailable at load
    time; failures are logged and swallowed because the containers are
    built lazily on the first train/evaluate call anyway.

    Args:
        model: Model instance with `compiled_loss` and `compiled_metrics`.
    """
    try:
        if not model.compiled_loss.built:
            model.compiled_loss.build(model.outputs)
        if not model.compiled_metrics.built:
            model.compiled_metrics.build(model.outputs, model.outputs)
    # Bug fix: the original bare `except:` also swallowed SystemExit and
    # KeyboardInterrupt; narrow to Exception while keeping best-effort
    # semantics.
    except Exception:
        logging.warning(
            "Compiled the loaded model, but the compiled metrics have "
            "yet to be built. `model.compile_metrics` will be empty "
            "until you train or evaluate the model."
        )
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/serialization.py
ADDED
|
@@ -0,0 +1,574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Legacy serialization logic for Keras models."""
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import inspect
|
| 5 |
+
import json
|
| 6 |
+
import threading
|
| 7 |
+
import weakref
|
| 8 |
+
|
| 9 |
+
# isort: off
|
| 10 |
+
from keras.src.api_export import keras_export
|
| 11 |
+
from keras.src.saving import object_registration
|
| 12 |
+
|
| 13 |
+
# Flag that determines whether to skip the NotImplementedError when calling
# get_config in custom models and layers. This is only enabled when saving to
# SavedModel, when the config isn't required.
_SKIP_FAILED_SERIALIZATION = False
# If a layer does not have a defined config, then the returned config will be a
# dictionary with the below key.
_LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config"

# Store a unique, per-object ID for shared objects.
#
# We store a unique ID for each object so that we may, at loading time,
# re-create the network properly. Without this ID, we would have no way of
# determining whether a config is a description of a new object that
# should be created or is merely a reference to an already-created object.
SHARED_OBJECT_KEY = "shared_object_id"

# Thread-local holders: the disable flag and the active loading/saving scopes
# are tracked independently per thread so concurrent (de)serialization on
# different threads cannot interfere with each other.
SHARED_OBJECT_DISABLED = threading.local()
SHARED_OBJECT_LOADING = threading.local()
SHARED_OBJECT_SAVING = threading.local()


# Attributes on the threadlocal variable must be set per-thread, thus we
# cannot initialize these globally. Instead, we have accessor functions with
# default values.
|
| 37 |
+
def _shared_object_disabled():
    """Return True when shared-object handling is switched off on this thread."""
    # The thread-local attribute may never have been set on this thread;
    # "never set" is treated the same as "not disabled".
    is_disabled = getattr(SHARED_OBJECT_DISABLED, "disabled", False)
    return is_disabled
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _shared_object_loading_scope():
    """Get the current shared object loading scope in a threadsafe manner.

    Falls back to a fresh `NoopLoadingScope` when no scope has been installed
    on the current thread, so callers may use the result unconditionally.
    """
    return getattr(SHARED_OBJECT_LOADING, "scope", NoopLoadingScope())
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _shared_object_saving_scope():
    """Return this thread's active saving scope, or `None` when there is none."""
    scope = getattr(SHARED_OBJECT_SAVING, "scope", None)
    return scope
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class DisableSharedObjectScope:
    """A context manager for disabling handling of shared objects.

    Disables shared object handling for both saving and loading.

    Created primarily for use with `clone_model`, which does extra surgery that
    is incompatible with shared objects.
    """

    def __enter__(self):
        # Flip the thread-local "disabled" flag first, then remember the
        # scopes that were active so they can be reinstalled on exit.
        SHARED_OBJECT_DISABLED.disabled = True
        self._orig_loading_scope = _shared_object_loading_scope()
        self._orig_saving_scope = _shared_object_saving_scope()

    def __exit__(self, *args, **kwargs):
        # NOTE(review): the flag is re-enabled unconditionally rather than
        # restored to its prior value, so nesting two of these scopes would
        # re-enable shared objects when the inner one exits — confirm nesting
        # is never required before relying on it.
        SHARED_OBJECT_DISABLED.disabled = False
        SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
        SHARED_OBJECT_SAVING.scope = self._orig_saving_scope
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class NoopLoadingScope:
    """Do-nothing stand-in for a shared object loading scope.

    Installed as the default so that serialization code which does not care
    about shared objects (e.g. when deserializing a single object) can call
    `get`/`set` unconditionally.
    """

    def get(self, unused_object_id):
        """Always report the object as unseen."""
        return None

    def set(self, object_id, obj):
        """Discard the object; nothing is tracked."""
        return None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class SharedObjectLoadingScope:
    """A context manager for keeping track of loaded objects.

    During the deserialization process, we may come across objects that are
    shared across multiple layers. In order to accurately restore the network
    structure to its original state, `SharedObjectLoadingScope` allows us to
    re-use shared objects rather than cloning them.
    """

    def __enter__(self):
        # When shared-object handling is disabled, hand back a no-op scope
        # without installing ourselves on the thread-local holder.
        if _shared_object_disabled():
            return NoopLoadingScope()

        global SHARED_OBJECT_LOADING
        SHARED_OBJECT_LOADING.scope = self
        # Maps shared object ID -> already-instantiated object.
        self._obj_ids_to_obj = {}
        return self

    def get(self, object_id):
        """Given a shared object ID, returns a previously instantiated object.

        Args:
          object_id: shared object ID to use when attempting to find
            already-loaded object.

        Returns:
          The object, if we've seen this ID before. Else, `None`.
        """
        # Explicitly check for `None` internally to make external calling code a
        # bit cleaner.
        if object_id is None:
            return
        return self._obj_ids_to_obj.get(object_id)

    def set(self, object_id, obj):
        """Stores an instantiated object for future lookup and sharing."""
        if object_id is None:
            return
        self._obj_ids_to_obj[object_id] = obj

    def __exit__(self, *args, **kwargs):
        # Reinstall the do-nothing scope rather than removing the attribute,
        # so the accessor's default path stays uniform.
        global SHARED_OBJECT_LOADING
        SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class SharedObjectConfig(dict):
    """A configuration container that keeps track of references.

    `SharedObjectConfig` will automatically attach a shared object ID to any
    configs which are referenced more than once, allowing for proper shared
    object reconstruction at load time.

    In most cases, it would be more proper to subclass something like
    `collections.UserDict` or `collections.Mapping` rather than `dict` directly.
    Unfortunately, python's json encoder does not support `Mapping`s. This is
    important functionality to retain, since we are dealing with serialization.

    We should be safe to subclass `dict` here, since we aren't actually
    overriding any core methods, only augmenting with a new one for reference
    counting.
    """

    def __init__(self, base_config, object_id, **kwargs):
        super().__init__(base_config, **kwargs)
        # Every config starts with exactly one referrer.
        self.ref_count = 1
        self.object_id = object_id

    def increment_ref_count(self):
        # The shared object ID is only attached the first time a second
        # reference is seen; configs referenced exactly once stay untouched,
        # which keeps backwards-compatibility breakage unlikely.
        first_reuse = self.ref_count == 1
        if first_reuse:
            self[SHARED_OBJECT_KEY] = self.object_id
        self.ref_count += 1
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class SharedObjectSavingScope:
    """Keeps track of shared object configs when serializing.

    Installed as the per-thread saving scope on `__enter__`; while active,
    `serialize_keras_class_and_config` consults it to detect objects that are
    serialized more than once.
    """

    def __enter__(self):
        # Shared-object tracking can be globally disabled (see
        # `DisableSharedObjectScope`); in that case no scope is returned.
        if _shared_object_disabled():
            return None

        global SHARED_OBJECT_SAVING

        # Serialization can happen at a number of layers for a number of
        # reasons. We may end up with a case where we're opening a saving scope
        # within another saving scope. In that case, we'd like to use the
        # outermost scope available and ignore inner scopes, since there is not
        # (yet) a reasonable use case for having these nested and distinct.
        if _shared_object_saving_scope() is not None:
            self._passthrough = True
            return _shared_object_saving_scope()
        else:
            self._passthrough = False

        SHARED_OBJECT_SAVING.scope = self
        # Weak keys so this scope does not keep serialized objects alive.
        self._shared_objects_config = weakref.WeakKeyDictionary()
        self._next_id = 0
        return self

    def get_config(self, obj):
        """Gets a `SharedObjectConfig` if one has already been seen for `obj`.

        Args:
          obj: The object for which to retrieve the `SharedObjectConfig`.

        Returns:
          The SharedObjectConfig for a given object, if already seen. Else,
            `None`.
        """
        try:
            shared_object_config = self._shared_objects_config[obj]
        except (TypeError, KeyError):
            # If the object is unhashable (e.g. a subclass of
            # `AbstractBaseClass` that has not overridden `__hash__`), a
            # `TypeError` will be thrown. We'll just continue on without shared
            # object support.
            return None
        # A second (or later) sighting: bump the count, which also attaches
        # the shared object ID on first reuse.
        shared_object_config.increment_ref_count()
        return shared_object_config

    def create_config(self, base_config, obj):
        """Create a new SharedObjectConfig for a given object."""
        shared_object_config = SharedObjectConfig(base_config, self._next_id)
        self._next_id += 1
        try:
            self._shared_objects_config[obj] = shared_object_config
        except TypeError:
            # If the object is unhashable (e.g. a subclass of
            # `AbstractBaseClass` that has not overridden `__hash__`), a
            # `TypeError` will be thrown. We'll just continue on without shared
            # object support.
            pass
        return shared_object_config

    def __exit__(self, *args, **kwargs):
        # A passthrough scope never installed itself, so it must not clear
        # the outer scope it deferred to.
        if not getattr(self, "_passthrough", False):
            global SHARED_OBJECT_SAVING
            SHARED_OBJECT_SAVING.scope = None
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def serialize_keras_class_and_config(
    cls_name, cls_config, obj=None, shared_object_id=None
):
    """Returns the serialization of the class with the given config.

    Args:
      cls_name: Registered name of the class being serialized.
      cls_config: The class's (already serialized) config.
      obj: Optional original object; enables shared-object dedup when a
        `SharedObjectSavingScope` is active.
      shared_object_id: Optional pre-existing shared object ID to retain.

    Returns:
      A dict with `class_name` and `config` keys (possibly a
      `SharedObjectConfig` carrying a shared object ID).
    """
    base_config = {"class_name": cls_name, "config": cls_config}

    # We call `serialize_keras_class_and_config` for some branches of the load
    # path. In that case, we may already have a shared object ID we'd like to
    # retain.
    if shared_object_id is not None:
        base_config[SHARED_OBJECT_KEY] = shared_object_id

    # If we have an active `SharedObjectSavingScope`, check whether we've
    # already serialized this config. If so, just use that config. This will
    # store an extra ID field in the config, allowing us to re-create the shared
    # object relationship at load time.
    if _shared_object_saving_scope() is not None and obj is not None:
        shared_object_config = _shared_object_saving_scope().get_config(obj)
        if shared_object_config is None:
            return _shared_object_saving_scope().create_config(base_config, obj)
        return shared_object_config

    return base_config
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
@contextlib.contextmanager
def skip_failed_serialization():
    """Context manager under which `get_config` failures are tolerated.

    While active, `serialize_keras_object` replaces a `NotImplementedError`
    from `get_config` with a placeholder config instead of raising. The
    previous flag value is restored on exit, so nesting is safe.
    """
    global _SKIP_FAILED_SERIALIZATION
    prev = _SKIP_FAILED_SERIALIZATION
    try:
        _SKIP_FAILED_SERIALIZATION = True
        yield
    finally:
        _SKIP_FAILED_SERIALIZATION = prev
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
@keras_export(
    [
        "keras.legacy.saving.serialize_keras_object",
        "keras.utils.legacy.serialize_keras_object",
    ]
)
def serialize_keras_object(instance):
    """Serialize a Keras object into a JSON-compatible representation.

    Calls to `serialize_keras_object` while underneath the
    `SharedObjectSavingScope` context manager will cause any objects re-used
    across multiple layers to be saved with a special shared object ID. This
    allows the network to be re-created properly during deserialization.

    Args:
        instance: The object to serialize.

    Returns:
        A dict-like, JSON-compatible representation of the object's config.
    """

    # _, instance = tf.__internal__.decorator.unwrap(instance)
    instance = inspect.unwrap(instance)
    if instance is None:
        return None

    if hasattr(instance, "get_config"):
        name = object_registration.get_registered_name(instance.__class__)
        try:
            config = instance.get_config()
        except NotImplementedError as e:
            # Under `skip_failed_serialization`, emit a placeholder config
            # instead of propagating the error.
            if _SKIP_FAILED_SERIALIZATION:
                return serialize_keras_class_and_config(
                    name, {_LAYER_UNDEFINED_CONFIG_KEY: True}
                )
            raise e
        serialization_config = {}
        for key, item in config.items():
            # Strings pass through untouched.
            if isinstance(item, str):
                serialization_config[key] = item
                continue

            # Any object of a different type needs to be converted to string or
            # dict for serialization (e.g. custom functions, custom classes)
            try:
                serialized_item = serialize_keras_object(item)
                # Mark configs produced implicitly (not from a dict value) so
                # the loader knows they may need recursive deserialization.
                if isinstance(serialized_item, dict) and not isinstance(
                    item, dict
                ):
                    serialized_item["__passive_serialization__"] = True
                serialization_config[key] = serialized_item
            except ValueError:
                # Unserializable values are kept as-is on a best-effort basis.
                serialization_config[key] = item

        name = object_registration.get_registered_name(instance.__class__)
        return serialize_keras_class_and_config(
            name, serialization_config, instance
        )
    # Functions/classes without `get_config` are saved by registered name.
    if hasattr(instance, "__name__"):
        return object_registration.get_registered_name(instance)
    raise ValueError(
        f"Cannot serialize {instance} because it doesn't implement "
        "`get_config()`."
    )
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
):
    """Returns the class name and config for a serialized keras object.

    Args:
        config: Serialized object dict with `class_name` and `config` keys.
        module_objects: Dict of built-in objects to look class names up in.
        custom_objects: Dict of user-provided objects to look names up in.
        printable_module_name: Human-readable object type, used in errors.

    Returns:
        A `(cls, cls_config)` tuple; `cls_config` has any nested
        passively-serialized items and custom-function names resolved.

    Raises:
        ValueError: If `config` is malformed or the class is not registered.
    """

    if (
        not isinstance(config, dict)
        or "class_name" not in config
        or "config" not in config
    ):
        raise ValueError(
            f"Improper config format for {config}. "
            "Expecting python dict contains `class_name` and `config` as keys"
        )

    class_name = config["class_name"]
    cls = object_registration.get_registered_object(
        class_name, custom_objects, module_objects
    )
    if cls is None:
        raise ValueError(
            f"Unknown {printable_module_name}: '{class_name}'. "
            "Please ensure you are using a `keras.utils.custom_object_scope` "
            "and that this object is included in the scope. See "
            "https://www.tensorflow.org/guide/keras/save_and_serialize"
            "#registering_the_custom_object for details."
        )

    cls_config = config["config"]
    # Check if `cls_config` is a list. If it is a list, return the class and the
    # associated class configs for recursively deserialization. This case will
    # happen on the old version of sequential model (e.g. `keras_version` ==
    # "2.0.6"), which is serialized in a different structure, for example
    # "{'class_name': 'Sequential',
    #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
    if isinstance(cls_config, list):
        return (cls, cls_config)

    deserialized_objects = {}
    for key, item in cls_config.items():
        if key == "name":
            # Assume that the value of 'name' is a string that should not be
            # deserialized as a function. This avoids the corner case where
            # cls_config['name'] has an identical name to a custom function and
            # gets converted into that function.
            deserialized_objects[key] = item
        elif isinstance(item, dict) and "__passive_serialization__" in item:
            deserialized_objects[key] = deserialize_keras_object(
                item,
                module_objects=module_objects,
                custom_objects=custom_objects,
                printable_module_name="config_item",
            )
        # TODO(momernick): Should this also have 'module_objects'?
        elif isinstance(item, str):
            # Handle custom functions here. When saving functions, we only save
            # the function's name as a string. If we find a matching string in
            # the custom objects during deserialization, we convert the string
            # back to the original function.
            # Note that a potential issue is that a string field could have a
            # naming conflict with a custom function name, but this should be a
            # rare case. This issue does not occur if a string field has a
            # naming conflict with a custom object, since the config of an
            # object will always be a dict.
            # Look the name up once and reuse the result (the original code
            # performed the registry lookup twice).
            registered = object_registration.get_registered_object(
                item, custom_objects
            )
            if inspect.isfunction(registered):
                deserialized_objects[key] = registered
    # Fold the resolved items back into the config in one pass.
    cls_config.update(deserialized_objects)

    return (cls, cls_config)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
@keras_export(
    [
        "keras.legacy.saving.deserialize_keras_object",
        "keras.utils.legacy.deserialize_keras_object",
    ]
)
def deserialize_keras_object(
    identifier,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
        identifier: the serialized form of the object.
        module_objects: A dictionary of built-in objects to look the name up in.
            Generally, `module_objects` is provided by midlevel library
            implementers.
        custom_objects: A dictionary of custom objects to look the name up in.
            Generally, `custom_objects` is provided by the end user.
        printable_module_name: A human-readable string representing the type of
            the object. Printed in case of exception.

    Returns:
        The deserialized object.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
        return deserialize_keras_object(
            identifier,
            module_objects=globals(),
            custom_objects=custom_objects,
            name="MyObjectType",
        )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """

    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name
        )

        # If this object has already been loaded (i.e. it's shared between
        # multiple objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)
        if shared_object is not None:
            return shared_object

        if hasattr(cls, "from_config"):
            arg_spec = inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            # TODO(nkovela): Swap find and replace args during Keras 3.0 release
            # Replace keras refs with keras
            # NOTE(review): with identical find/replace strings this call is
            # currently a no-op round-trip through JSON, kept as a placeholder
            # per the TODO above — confirm before removing.
            cls_config = _find_replace_nested_dict(
                cls_config, "keras.", "keras."
            )

            if "custom_objects" in arg_spec.args:
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects={
                        **object_registration.GLOBAL_CUSTOM_OBJECTS,
                        **custom_objects,
                    },
                )
            else:
                with object_registration.CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with object_registration.CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

        return deserialized_obj

    elif isinstance(identifier, str):
        # Plain-string identifier: resolve by name with precedence
        # user custom_objects > thread-local registry > global registry >
        # module_objects.
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif (
            object_name
            in object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
        ):
            obj = object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__[
                object_name
            ]
        elif object_name in object_registration._GLOBAL_CUSTOM_OBJECTS:
            obj = object_registration._GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            obj = module_objects.get(object_name)
            if obj is None:
                raise ValueError(
                    f"Unknown {printable_module_name}: '{object_name}'. "
                    "Please ensure you are using a "
                    "`keras.utils.custom_object_scope` "
                    "and that this object is included in the scope. See "
                    "https://www.tensorflow.org/guide/keras/save_and_serialize"
                    "#registering_the_custom_object for details."
                )

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if inspect.isclass(obj):
            return obj()
        return obj
    elif inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError(
            "Could not interpret serialized "
            f"{printable_module_name}: {identifier}"
        )
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def validate_config(config):
    """Check that `config` looks like a usable layer config.

    A config is valid when it is a dict that was not produced by the
    "layer saved without config" serialization fallback.
    """
    if not isinstance(config, dict):
        return False
    return _LAYER_UNDEFINED_CONFIG_KEY not in config
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
def is_default(method):
    """Check if a method is decorated with the `default` wrapper."""
    # EAFP equivalent of getattr-with-default: undecorated methods simply
    # lack the `_is_default` attribute.
    try:
        return method._is_default
    except AttributeError:
        return False
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def _find_replace_nested_dict(config, find, replace):
|
| 571 |
+
dict_str = json.dumps(config)
|
| 572 |
+
dict_str = dict_str.replace(find, replace)
|
| 573 |
+
config = json.loads(dict_str)
|
| 574 |
+
return config
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__init__.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
|
| 3 |
+
from keras.src.api_export import keras_export
|
| 4 |
+
from keras.src.losses.loss import Loss
|
| 5 |
+
from keras.src.losses.losses import CTC
|
| 6 |
+
from keras.src.losses.losses import BinaryCrossentropy
|
| 7 |
+
from keras.src.losses.losses import BinaryFocalCrossentropy
|
| 8 |
+
from keras.src.losses.losses import CategoricalCrossentropy
|
| 9 |
+
from keras.src.losses.losses import CategoricalFocalCrossentropy
|
| 10 |
+
from keras.src.losses.losses import CategoricalHinge
|
| 11 |
+
from keras.src.losses.losses import Circle
|
| 12 |
+
from keras.src.losses.losses import CosineSimilarity
|
| 13 |
+
from keras.src.losses.losses import Dice
|
| 14 |
+
from keras.src.losses.losses import Hinge
|
| 15 |
+
from keras.src.losses.losses import Huber
|
| 16 |
+
from keras.src.losses.losses import KLDivergence
|
| 17 |
+
from keras.src.losses.losses import LogCosh
|
| 18 |
+
from keras.src.losses.losses import LossFunctionWrapper
|
| 19 |
+
from keras.src.losses.losses import MeanAbsoluteError
|
| 20 |
+
from keras.src.losses.losses import MeanAbsolutePercentageError
|
| 21 |
+
from keras.src.losses.losses import MeanSquaredError
|
| 22 |
+
from keras.src.losses.losses import MeanSquaredLogarithmicError
|
| 23 |
+
from keras.src.losses.losses import Poisson
|
| 24 |
+
from keras.src.losses.losses import SparseCategoricalCrossentropy
|
| 25 |
+
from keras.src.losses.losses import SquaredHinge
|
| 26 |
+
from keras.src.losses.losses import Tversky
|
| 27 |
+
from keras.src.losses.losses import binary_crossentropy
|
| 28 |
+
from keras.src.losses.losses import binary_focal_crossentropy
|
| 29 |
+
from keras.src.losses.losses import categorical_crossentropy
|
| 30 |
+
from keras.src.losses.losses import categorical_focal_crossentropy
|
| 31 |
+
from keras.src.losses.losses import categorical_hinge
|
| 32 |
+
from keras.src.losses.losses import circle
|
| 33 |
+
from keras.src.losses.losses import cosine_similarity
|
| 34 |
+
from keras.src.losses.losses import ctc
|
| 35 |
+
from keras.src.losses.losses import dice
|
| 36 |
+
from keras.src.losses.losses import hinge
|
| 37 |
+
from keras.src.losses.losses import huber
|
| 38 |
+
from keras.src.losses.losses import kl_divergence
|
| 39 |
+
from keras.src.losses.losses import log_cosh
|
| 40 |
+
from keras.src.losses.losses import mean_absolute_error
|
| 41 |
+
from keras.src.losses.losses import mean_absolute_percentage_error
|
| 42 |
+
from keras.src.losses.losses import mean_squared_error
|
| 43 |
+
from keras.src.losses.losses import mean_squared_logarithmic_error
|
| 44 |
+
from keras.src.losses.losses import poisson
|
| 45 |
+
from keras.src.losses.losses import sparse_categorical_crossentropy
|
| 46 |
+
from keras.src.losses.losses import squared_hinge
|
| 47 |
+
from keras.src.losses.losses import tversky
|
| 48 |
+
from keras.src.saving import serialization_lib
|
| 49 |
+
|
| 50 |
+
# Registry of every built-in loss class and loss function, grouped by
# category; used to build the name -> object lookup table below.
ALL_OBJECTS = {
    # Base
    Loss,
    LossFunctionWrapper,
    # Probabilistic
    KLDivergence,
    Poisson,
    BinaryCrossentropy,
    BinaryFocalCrossentropy,
    CategoricalCrossentropy,
    CategoricalFocalCrossentropy,
    SparseCategoricalCrossentropy,
    # Regression
    MeanSquaredError,
    MeanAbsoluteError,
    MeanAbsolutePercentageError,
    MeanSquaredLogarithmicError,
    CosineSimilarity,
    LogCosh,
    Huber,
    # Hinge
    Hinge,
    SquaredHinge,
    CategoricalHinge,
    # Image segmentation
    Dice,
    Tversky,
    # Similarity
    Circle,
    # Sequence
    CTC,
    # Probabilistic
    kl_divergence,
    poisson,
    binary_crossentropy,
    binary_focal_crossentropy,
    categorical_crossentropy,
    categorical_focal_crossentropy,
    sparse_categorical_crossentropy,
    # Regression
    mean_squared_error,
    mean_absolute_error,
    mean_absolute_percentage_error,
    mean_squared_logarithmic_error,
    cosine_similarity,
    log_cosh,
    huber,
    # Hinge
    hinge,
    squared_hinge,
    categorical_hinge,
    # Image segmentation
    dice,
    tversky,
    # Similarity
    circle,
    # Sequence
    ctc,
}

# Lookup table keyed by public name (class or function __name__), extended
# with the conventional short-form aliases in both lower and upper case.
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
ALL_OBJECTS_DICT.update(
    {
        "bce": binary_crossentropy,
        "BCE": binary_crossentropy,
        "kld": kl_divergence,
        "KLD": kl_divergence,
        "mae": mean_absolute_error,
        "MAE": mean_absolute_error,
        "mse": mean_squared_error,
        "MSE": mean_squared_error,
        "mape": mean_absolute_percentage_error,
        "MAPE": mean_absolute_percentage_error,
        "msle": mean_squared_logarithmic_error,
        "MSLE": mean_squared_logarithmic_error,
    }
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@keras_export("keras.losses.serialize")
def serialize(loss):
    """Serializes loss function or `Loss` instance.

    Args:
        loss: A Keras `Loss` instance or a loss function.

    Returns:
        Loss configuration dictionary.
    """
    # Thin wrapper: delegation to the central serialization library keeps
    # loss serialization consistent with every other Keras object type.
    return serialization_lib.serialize_keras_object(loss)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@keras_export("keras.losses.deserialize")
def deserialize(name, custom_objects=None):
    """Deserializes a serialized loss class/function instance.

    Args:
        name: Loss configuration.
        custom_objects: Optional dictionary mapping names (strings) to custom
            objects (classes and functions) to be considered during
            deserialization.

    Returns:
        A Keras `Loss` instance or a loss function.
    """
    # `ALL_OBJECTS_DICT` supplies every built-in loss (plus aliases such as
    # "mse"); user-provided `custom_objects` take part in the lookup too.
    return serialization_lib.deserialize_keras_object(
        name,
        module_objects=ALL_OBJECTS_DICT,
        custom_objects=custom_objects,
    )
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@keras_export("keras.losses.get")
def get(identifier):
    """Retrieve a Keras loss as a `function`/`Loss` class instance.

    The `identifier` may be the string name of a loss function or `Loss`
    class.

    >>> loss = losses.get("categorical_crossentropy")
    >>> type(loss)
    <class 'function'>
    >>> loss = losses.get("CategoricalCrossentropy")
    >>> type(loss)
    <class '...CategoricalCrossentropy'>

    You can also specify `config` of the loss to this function by passing a
    dict containing `class_name` and `config` as an identifier. Note that in
    this case the `class_name` must map to a `Loss` class.

    >>> identifier = {"class_name": "CategoricalCrossentropy",
    ...               "config": {"from_logits": True}}
    >>> loss = losses.get(identifier)
    >>> type(loss)
    <class '...CategoricalCrossentropy'>

    Args:
        identifier: A loss identifier. One of `None`, a string name of a
            loss function/class, a loss configuration dictionary, a loss
            function, or a loss class instance.

    Returns:
        A Keras loss as a `function`/`Loss` class instance.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        resolved = deserialize(identifier)
    elif isinstance(identifier, str):
        # Unknown names resolve to None and fall through to the error below.
        resolved = ALL_OBJECTS_DICT.get(identifier)
    else:
        resolved = identifier

    if not callable(resolved):
        raise ValueError(f"Could not interpret loss identifier: {identifier}")
    # A class identifier is instantiated with default arguments; plain
    # functions and instances are returned as-is.
    return resolved() if inspect.isclass(resolved) else resolved
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (5.08 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/loss.cpython-310.pyc
ADDED
|
Binary file (6.95 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/losses.cpython-310.pyc
ADDED
|
Binary file (87.3 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/loss.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import backend
|
| 2 |
+
from keras.src import dtype_policies
|
| 3 |
+
from keras.src import ops
|
| 4 |
+
from keras.src import tree
|
| 5 |
+
from keras.src.api_export import keras_export
|
| 6 |
+
from keras.src.saving.keras_saveable import KerasSaveable
|
| 7 |
+
from keras.src.utils.naming import auto_name
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@keras_export(["keras.Loss", "keras.losses.Loss"])
class Loss(KerasSaveable):
    """Base class for all Keras losses.

    Subclass this to create a custom loss: implement `call()`, which
    receives `y_true` and `y_pred` and returns the (unreduced) loss values.

    Args:
        reduction: Type of reduction to apply to the loss. In almost all
            cases this should be `"sum_over_batch_size"`. Supported options
            are `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by
            the sample size, and `"mean_with_sample_weight"` sums the loss
            and divides by the sum of the sample weights. `"none"` and
            `None` perform no aggregation. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`,
            which means using `keras.backend.floatx()` (itself `"float32"`
            unless changed via `keras.backend.set_floatx()`). If a
            `keras.DTypePolicy` is provided, its `compute_dtype` is used.

    To be implemented by subclasses:

    * `call()`: Contains the logic for loss calculation using `y_true`,
      `y_pred`.

    Example subclass implementation:

    ```python
    class MeanSquaredError(Loss):
        def call(self, y_true, y_pred):
            return ops.mean(ops.square(y_pred - y_true), axis=-1)
    ```
    """

    def __init__(self, name=None, reduction="sum_over_batch_size", dtype=None):
        self.name = name or auto_name(self.__class__.__name__)
        self.reduction = standardize_reduction(reduction)
        # Resolving through the dtype-policy machinery accepts both plain
        # dtype strings and `DTypePolicy` instances.
        self._dtype_policy = dtype_policies.get(dtype or backend.floatx())
        self._dtype = self._dtype_policy.compute_dtype

    @property
    def dtype(self):
        # Compute dtype used for all internal casts and the reduction.
        return self._dtype

    def __call__(self, y_true, y_pred, sample_weight=None):
        # Capture any Keras mask attached to the predictions before tensor
        # conversion (conversion would not carry the mask over).
        input_mask = backend.get_keras_mask(y_pred)

        with ops.name_scope(self.name):

            def _as_tensor(x):
                return ops.convert_to_tensor(x, dtype=self.dtype)

            y_pred = tree.map_structure(_as_tensor, y_pred)
            y_true = tree.map_structure(_as_tensor, y_true)

            losses = self.call(y_true, y_pred)
            output_mask = backend.get_keras_mask(losses)

            # Merge the masks: logical AND when both exist, otherwise use
            # whichever one is present (None when neither is).
            if input_mask is None:
                mask = output_mask
            elif output_mask is None:
                mask = input_mask
            else:
                mask = input_mask & output_mask

            return reduce_weighted_values(
                losses,
                sample_weight=sample_weight,
                mask=mask,
                reduction=self.reduction,
                dtype=self.dtype,
            )

    def call(self, y_true, y_pred):
        """Compute the unreduced loss values; implemented by subclasses."""
        raise NotImplementedError

    def get_config(self):
        return {"name": self.name, "reduction": self.reduction}

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def _obj_type(self):
        return "Loss"
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def standardize_reduction(reduction):
    """Validate a `reduction` argument and return it unchanged.

    Args:
        reduction: One of `"sum_over_batch_size"`, `"sum"`, `"mean"`,
            `"mean_with_sample_weight"`, `"none"` or `None`.

    Returns:
        The `reduction` value, unmodified.

    Raises:
        ValueError: If `reduction` is not one of the supported options.
    """
    allowed = {
        "sum_over_batch_size",
        "sum",
        None,
        "none",
        "mean",
        "mean_with_sample_weight",
    }
    if reduction in allowed:
        return reduction
    raise ValueError(
        "Invalid value for argument `reduction`. "
        f"Expected one of {allowed}. Received: "
        f"reduction={reduction}"
    )
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def squeeze_or_expand_to_same_rank(x1, x2, expand_rank_1=True):
    """Squeeze/expand last dim if ranks differ from expected by exactly 1."""
    rank1 = len(x1.shape)
    rank2 = len(x2.shape)
    if rank1 == rank2:
        # Already aligned; nothing to do.
        return x1, x2
    if rank1 == rank2 + 1 and x1.shape[-1] == 1:
        # Prefer expanding a rank-1 tensor over squeezing, when allowed.
        if rank2 == 1 and expand_rank_1:
            x2 = ops.expand_dims(x2, axis=-1)
        else:
            x1 = ops.squeeze(x1, axis=-1)
    if rank2 == rank1 + 1 and x2.shape[-1] == 1:
        if rank1 == 1 and expand_rank_1:
            x1 = ops.expand_dims(x1, axis=-1)
        else:
            x2 = ops.squeeze(x2, axis=-1)
    return x1, x2
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def reduce_values(values, sample_weight=None, reduction="sum_over_batch_size"):
    """Reduce per-sample loss values according to `reduction`.

    Args:
        values: Tensor of loss values.
        sample_weight: Optional weights; only consulted for
            `"mean_with_sample_weight"`.
        reduction: One of the reductions accepted by
            `standardize_reduction`.

    Returns:
        The reduced (scalar) loss, or `values` unchanged for `None`/
        `"none"` reductions and for empty/scalar inputs.
    """
    # No-op cases: no reduction requested, or already scalar / empty.
    if reduction in (None, "none") or tuple(values.shape) in ((), (0,)):
        return values
    loss = ops.sum(values)
    if reduction in ("sum_over_batch_size", "mean", "mean_with_sample_weight"):
        if reduction == "mean_with_sample_weight" and sample_weight is not None:
            divisor = ops.cast(ops.sum(sample_weight), loss.dtype)
        else:
            # Total number of elements, computed symbolically so dynamic
            # shapes are handled.
            num_elements = ops.prod(
                ops.convert_to_tensor(ops.shape(values), dtype="int32")
            )
            divisor = ops.cast(num_elements, loss.dtype)
        # divide_no_nan guards against an empty batch (divisor == 0).
        loss = ops.divide_no_nan(loss, divisor)
        loss = scale_loss_for_distribution(loss)
    return loss
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def reduce_weighted_values(
    values,
    sample_weight=None,
    mask=None,
    reduction="sum_over_batch_size",
    dtype=None,
):
    """Weight, mask, and reduce raw loss values.

    Args:
        values: Tensor of per-sample loss values.
        sample_weight: Optional per-sample weights.
        mask: Optional boolean mask over `values`.
        reduction: Reduction mode; see `standardize_reduction`.
        dtype: Dtype used for all conversions.

    Returns:
        The reduced loss.
    """
    reduction = standardize_reduction(reduction)

    values = ops.convert_to_tensor(values, dtype=dtype)
    if sample_weight is not None:
        sample_weight = ops.convert_to_tensor(sample_weight, dtype=dtype)
    if mask is not None:
        mask = ops.convert_to_tensor(mask, dtype=dtype)

    # Fold the mask into the sample weights so a single weighting pass
    # handles both.
    sample_weight = apply_mask(
        sample_weight, mask, dtype=values.dtype, reduction=reduction
    )

    if sample_weight is not None:
        sample_weight = ops.cast(sample_weight, values.dtype)
        # Align ranks so broadcasting the weights is well-defined.
        values, sample_weight = squeeze_or_expand_to_same_rank(
            values, sample_weight
        )
        values = values * sample_weight

    # Finally reduce the individual weighted losses.
    return reduce_values(values, sample_weight, reduction)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def apply_mask(sample_weight, mask, dtype, reduction):
    """Applies any mask on predictions to sample weights."""
    if mask is None:
        # Nothing to merge; pass the weights through untouched.
        return sample_weight

    mask = ops.cast(mask, dtype=dtype)
    if reduction in ("mean", "sum_over_batch_size"):
        # Valid entries get weight `total/valid`, invalid ones 0, so that
        # after the batch-size division the result equals
        # sum(loss * sample_weight) / valid.
        total = ops.cast(
            ops.prod(ops.convert_to_tensor(ops.shape(mask), dtype="int32")),
            dtype,
        )
        valid = ops.sum(mask)  # May be 0!
        mask = mask * (total / (valid + backend.epsilon()))

    if sample_weight is None:
        # The mask alone becomes the effective sample weight.
        return mask
    sample_weight = ops.cast(sample_weight, dtype=dtype)
    mask, sample_weight = squeeze_or_expand_to_same_rank(mask, sample_weight)
    return sample_weight * mask
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def scale_loss_for_distribution(value):
    """Scales the given value by the number of replicas in the strategy.

    Currently, this function is only effective when using the tensorflow
    backend and `tf.distribute`.
    """
    if backend.backend() != "tensorflow":
        return value
    import tensorflow as tf

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    if num_replicas > 1:
        # Divide by the replica count so per-replica losses sum to the
        # global loss under tf.distribute.
        value = ops.multiply(value, ops.cast(1.0 / num_replicas, value.dtype))
    return value
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def unscale_loss_for_distribution(value):
    """Unscales the given value by the number of replicas in the strategy.

    Currently, this function is only effective when using the tensorflow
    backend and `tf.distribute`.
    """
    if backend.backend() != "tensorflow":
        return value
    import tensorflow as tf

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    if num_replicas > 1:
        # Inverse of `scale_loss_for_distribution`.
        value = ops.multiply(value, ops.cast(num_replicas, value.dtype))
    return value
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/losses.py
ADDED
|
@@ -0,0 +1,2599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
from keras.src import backend
|
| 4 |
+
from keras.src import ops
|
| 5 |
+
from keras.src import tree
|
| 6 |
+
from keras.src.api_export import keras_export
|
| 7 |
+
from keras.src.losses.loss import Loss
|
| 8 |
+
from keras.src.losses.loss import squeeze_or_expand_to_same_rank
|
| 9 |
+
from keras.src.saving import serialization_lib
|
| 10 |
+
from keras.src.utils.numerical_utils import build_pos_neg_masks
|
| 11 |
+
from keras.src.utils.numerical_utils import normalize
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LossFunctionWrapper(Loss):
    """Wraps a plain loss function in a `Loss` instance.

    Args:
        fn: The loss function to wrap; invoked as
            `fn(y_true, y_pred, **kwargs)`.
        reduction: Reduction to apply; see `Loss`.
        name: Optional name for the loss instance.
        dtype: Dtype of the loss's computations; see `Loss`.
        **kwargs: Extra keyword arguments forwarded to `fn` on every call.
    """

    def __init__(
        self,
        fn,
        reduction="sum_over_batch_size",
        name=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(name=name, reduction=reduction, dtype=dtype)
        self.fn = fn
        self._fn_kwargs = kwargs

    def call(self, y_true, y_pred):
        # Align the ranks of each (y_true, y_pred) leaf pair, then unzip
        # the paired structure back into separate structures.
        paired = tree.map_structure(
            squeeze_or_expand_to_same_rank, y_true, y_pred
        )
        y_true = tree.map_structure_up_to(y_true, lambda pair: pair[0], paired)
        y_pred = tree.map_structure_up_to(y_pred, lambda pair: pair[1], paired)
        return self.fn(y_true, y_pred, **self._fn_kwargs)

    def get_config(self):
        config = super().get_config()
        config.update({"fn": serialization_lib.serialize_keras_object(self.fn)})
        config.update(serialization_lib.serialize_keras_object(self._fn_kwargs))
        return config

    @classmethod
    def from_config(cls, config):
        # A serialized "fn" entry means the config came from `get_config()`
        # and its nested objects must be deserialized first.
        if "fn" in config:
            config = serialization_lib.deserialize_keras_object(config)
        return cls(**config)

    def __repr__(self):
        return f"<LossFunctionWrapper({self.fn}, kwargs={self._fn_kwargs})>"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@keras_export("keras.losses.MeanSquaredError")
class MeanSquaredError(LossFunctionWrapper):
    """Mean squared error loss.

    Formula:

    ```python
    loss = mean(square(y_true - y_pred))
    ```

    Args:
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the loss instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="mean_squared_error",
        dtype=None,
    ):
        # The actual math lives in the functional `mean_squared_error`.
        super().__init__(
            mean_squared_error, dtype=dtype, name=name, reduction=reduction
        )

    def get_config(self):
        # The wrapped fn is implied by the class itself, so only the base
        # `Loss` fields are serialized.
        return Loss.get_config(self)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@keras_export("keras.losses.MeanAbsoluteError")
class MeanAbsoluteError(LossFunctionWrapper):
    """Mean absolute error loss.

    Formula:

    ```python
    loss = mean(abs(y_true - y_pred))
    ```

    Args:
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the loss instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="mean_absolute_error",
        dtype=None,
    ):
        # Delegate the computation to the functional `mean_absolute_error`.
        super().__init__(
            mean_absolute_error, dtype=dtype, name=name, reduction=reduction
        )

    def get_config(self):
        # Fixed wrapped fn; serialize only the base `Loss` fields.
        return Loss.get_config(self)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@keras_export("keras.losses.MeanAbsolutePercentageError")
class MeanAbsolutePercentageError(LossFunctionWrapper):
    """Mean absolute percentage error loss.

    Formula:

    ```python
    loss = 100 * mean(abs((y_true - y_pred) / y_true))
    ```

    Args:
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the loss instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="mean_absolute_percentage_error",
        dtype=None,
    ):
        # Delegate to the functional `mean_absolute_percentage_error`.
        super().__init__(
            mean_absolute_percentage_error,
            dtype=dtype,
            name=name,
            reduction=reduction,
        )

    def get_config(self):
        # Fixed wrapped fn; serialize only the base `Loss` fields.
        return Loss.get_config(self)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@keras_export("keras.losses.MeanSquaredLogarithmicError")
class MeanSquaredLogarithmicError(LossFunctionWrapper):
    """Mean squared logarithmic error loss.

    Formula:

    ```python
    loss = mean(square(log(y_true + 1) - log(y_pred + 1)))
    ```

    Args:
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the loss instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="mean_squared_logarithmic_error",
        dtype=None,
    ):
        # Delegate to the functional `mean_squared_logarithmic_error`.
        super().__init__(
            mean_squared_logarithmic_error,
            dtype=dtype,
            name=name,
            reduction=reduction,
        )

    def get_config(self):
        # Fixed wrapped fn; serialize only the base `Loss` fields.
        return Loss.get_config(self)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@keras_export("keras.losses.CosineSimilarity")
class CosineSimilarity(LossFunctionWrapper):
    """Negative cosine similarity loss.

    The result lies in `[-1, 1]`: values near -1 indicate high similarity,
    0 indicates orthogonality, which makes the loss suitable for maximizing
    the proximity between predictions and targets. If either `y_true` or
    `y_pred` is a zero vector the loss is 0 regardless of proximity.

    Formula:

    ```python
    loss = -sum(l2_norm(y_true) * l2_norm(y_pred))
    ```

    Args:
        axis: Features axis along which the cosine similarity is computed.
            Defaults to `-1`.
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the loss instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        axis=-1,
        reduction="sum_over_batch_size",
        name="cosine_similarity",
        dtype=None,
    ):
        # `axis` is stored as an fn kwarg and forwarded to the functional
        # `cosine_similarity` on every call.
        super().__init__(
            cosine_similarity,
            axis=axis,
            dtype=dtype,
            name=name,
            reduction=reduction,
        )

    def get_config(self):
        # Fixed wrapped fn; serialize only the base `Loss` fields.
        return Loss.get_config(self)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
@keras_export("keras.losses.Huber")
class Huber(LossFunctionWrapper):
    """Huber loss: quadratic for small errors, linear for large ones.

    Formula:

    ```python
    for x in error:
        if abs(x) <= delta:
            loss.append(0.5 * x^2)
        elif abs(x) > delta:
            loss.append(delta * abs(x) - 0.5 * delta^2)

    loss = mean(loss, axis=-1)
    ```

    See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss).

    Args:
        delta: Float threshold at which the loss switches from quadratic to
            linear.
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        delta=1.0,
        reduction="sum_over_batch_size",
        name="huber_loss",
        dtype=None,
    ):
        # `delta` is stored as an fn kwarg and forwarded to the functional
        # `huber` on every call.
        super().__init__(
            huber,
            delta=delta,
            dtype=dtype,
            name=name,
            reduction=reduction,
        )

    def get_config(self):
        # Fixed wrapped fn; serialize only the base `Loss` fields.
        return Loss.get_config(self)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
@keras_export("keras.losses.LogCosh")
class LogCosh(LossFunctionWrapper):
    """Log-cosh loss: logarithm of the hyperbolic cosine of the error.

    Formula:

    ```python
    error = y_pred - y_true
    logcosh = mean(log((exp(error) + exp(-error))/2), axis=-1)`
    ```

    where x is the error `y_pred - y_true`.

    Args:
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="log_cosh",
        dtype=None,
    ):
        # Delegate the computation to the functional `log_cosh`.
        super().__init__(log_cosh, dtype=dtype, name=name, reduction=reduction)

    def get_config(self):
        # Fixed wrapped fn; serialize only the base `Loss` fields.
        return Loss.get_config(self)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
@keras_export("keras.losses.Hinge")
class Hinge(LossFunctionWrapper):
    """Hinge loss for "maximum-margin" classification.

    Formula:

    ```python
    loss = maximum(1 - y_true * y_pred, 0)
    ```

    `y_true` values are expected to be -1 or 1. Binary (0 or 1) labels are
    converted to -1 or 1 automatically.

    Args:
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the loss instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="hinge",
        dtype=None,
    ):
        # Delegate the computation to the functional `hinge`.
        super().__init__(hinge, dtype=dtype, name=name, reduction=reduction)

    def get_config(self):
        # Fixed wrapped fn; serialize only the base `Loss` fields.
        return Loss.get_config(self)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
@keras_export("keras.losses.SquaredHinge")
class SquaredHinge(LossFunctionWrapper):
    """Squared hinge loss.

    Formula:

    ```python
    loss = square(maximum(1 - y_true * y_pred, 0))
    ```

    `y_true` values are expected to be -1 or 1. Binary (0 or 1) labels are
    converted to -1 or 1 automatically.

    Args:
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the loss instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="squared_hinge",
        dtype=None,
    ):
        # Delegate the computation to the functional `squared_hinge`.
        super().__init__(
            squared_hinge, dtype=dtype, name=name, reduction=reduction
        )

    def get_config(self):
        # Fixed wrapped fn; serialize only the base `Loss` fields.
        return Loss.get_config(self)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
@keras_export("keras.losses.CategoricalHinge")
class CategoricalHinge(LossFunctionWrapper):
    """Categorical hinge loss.

    Formula:

    ```python
    loss = maximum(neg - pos + 1, 0)
    ```

    where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)`

    Args:
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the loss instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="categorical_hinge",
        dtype=None,
    ):
        # Delegate the computation to the functional `categorical_hinge`.
        super().__init__(
            categorical_hinge, dtype=dtype, name=name, reduction=reduction
        )

    def get_config(self):
        # Fixed wrapped fn; serialize only the base `Loss` fields.
        return Loss.get_config(self)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
@keras_export("keras.losses.KLDivergence")
class KLDivergence(LossFunctionWrapper):
    """Kullback-Leibler divergence loss.

    Formula:

    ```python
    loss = y_true * log(y_true / y_pred)
    ```

    Both `y_true` and `y_pred` are expected to be probability distributions
    with values in `[0, 1]`; they are clipped to that range.

    Args:
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the loss instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="kl_divergence",
        dtype=None,
    ):
        # Delegate the computation to the functional `kl_divergence`.
        super().__init__(
            kl_divergence, dtype=dtype, name=name, reduction=reduction
        )

    def get_config(self):
        # Fixed wrapped fn; serialize only the base `Loss` fields.
        return Loss.get_config(self)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
@keras_export("keras.losses.Poisson")
class Poisson(LossFunctionWrapper):
    """Poisson loss.

    Formula:

    ```python
    loss = y_pred - y_true * log(y_pred)
    ```

    Args:
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the loss instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="poisson",
        dtype=None,
    ):
        # Delegate the computation to the functional `poisson`.
        super().__init__(poisson, dtype=dtype, name=name, reduction=reduction)

    def get_config(self):
        # Fixed wrapped fn; serialize only the base `Loss` fields.
        return Loss.get_config(self)
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
@keras_export("keras.losses.BinaryCrossentropy")
class BinaryCrossentropy(LossFunctionWrapper):
    """Cross-entropy loss for binary (0 or 1) classification.

    Expected inputs:

    - `y_true` (true label): 0 or 1.
    - `y_pred` (predicted value): a single floating-point value per sample,
      either a [logit](https://en.wikipedia.org/wiki/Logit) in `[-inf, inf]`
      (when `from_logits=True`) or a probability in `[0., 1.]`
      (when `from_logits=False`).

    Args:
        from_logits: Whether `y_pred` holds
            [logit](https://en.wikipedia.org/wiki/Logit) values. By default
            `y_pred` is assumed to hold probabilities (values in [0, 1]).
        label_smoothing: Float in `[0, 1]`. 0 disables smoothing; larger
            values squeeze the true labels towards 0.5, i.e. the loss is
            computed against a smoothed version of the labels.
        axis: Features axis along which crossentropy is computed.
            Defaults to `-1`.
        reduction: How to aggregate the per-sample losses. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. `"sum"` adds everything up,
            `"sum_over_batch_size"` and `"mean"` divide the sum by the number
            of samples, `"mean_with_sample_weight"` divides it by the total
            sample weight, and `"none"`/`None` perform no aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional identifier for the loss instance.
        dtype: Dtype of the loss computations. `None` (the default) means
            `keras.backend.floatx()`, which is `"float32"` unless changed via
            `keras.backend.set_floatx()`. A `keras.DTypePolicy` contributes
            its `compute_dtype`.

    Examples:

    Recommended usage with `from_logits=True` and the `compile()` API:

    ```python
    model.compile(
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
        ...
    )
    ```

    Standalone, with logits:

    ```python
    y_true = np.array([0, 1, 0, 0])
    y_pred = np.array([-18.6, 0.51, 2.94, -12.8])
    bce = keras.losses.BinaryCrossentropy(from_logits=True)
    bce(y_true, y_pred)  # ~0.8654
    ```

    With `from_logits=False` (the default), pass probabilities instead of
    logits, e.g. `y_pred = [0.6, 0.3, 0.2, 0.8]`. The `reduction` argument
    controls aggregation as described above (e.g. `reduction="sum"` or
    `reduction=None` for per-sample losses), and `sample_weight` may be
    passed at call time to weight individual samples.
    """

    def __init__(
        self,
        from_logits=False,
        label_smoothing=0.0,
        axis=-1,
        reduction="sum_over_batch_size",
        name="binary_crossentropy",
        dtype=None,
    ):
        # `from_logits`, `label_smoothing` and `axis` are stored both as fn
        # kwargs (forwarded to the functional `binary_crossentropy`) and as
        # attributes for serialization.
        super().__init__(
            binary_crossentropy,
            from_logits=from_logits,
            label_smoothing=label_smoothing,
            axis=axis,
            dtype=dtype,
            name=name,
            reduction=reduction,
        )
        self.from_logits = from_logits
        self.label_smoothing = label_smoothing
        self.axis = axis

    def get_config(self):
        # Base `Loss` fields plus this class's own constructor arguments; the
        # wrapped fn itself is implied by the class and not serialized.
        config = Loss.get_config(self)
        config.update(
            {
                "from_logits": self.from_logits,
                "label_smoothing": self.label_smoothing,
                "axis": self.axis,
            }
        )
        return config
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
@keras_export("keras.losses.BinaryFocalCrossentropy")
|
| 703 |
+
class BinaryFocalCrossentropy(LossFunctionWrapper):
|
| 704 |
+
"""Computes focal cross-entropy loss between true labels and predictions.
|
| 705 |
+
|
| 706 |
+
Binary cross-entropy loss is often used for binary (0 or 1) classification
|
| 707 |
+
tasks. The loss function requires the following inputs:
|
| 708 |
+
|
| 709 |
+
- `y_true` (true label): This is either 0 or 1.
|
| 710 |
+
- `y_pred` (predicted value): This is the model's prediction, i.e, a single
|
| 711 |
+
floating-point value which either represents a
|
| 712 |
+
[logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]
|
| 713 |
+
when `from_logits=True`) or a probability (i.e, value in `[0., 1.]` when
|
| 714 |
+
`from_logits=False`).
|
| 715 |
+
|
| 716 |
+
According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
|
| 717 |
+
helps to apply a "focal factor" to down-weight easy examples and focus more
|
| 718 |
+
on hard examples. By default, the focal tensor is computed as follows:
|
| 719 |
+
|
| 720 |
+
`focal_factor = (1 - output) ** gamma` for class 1
|
| 721 |
+
`focal_factor = output ** gamma` for class 0
|
| 722 |
+
where `gamma` is a focusing parameter. When `gamma=0`, this function is
|
| 723 |
+
equivalent to the binary crossentropy loss.
|
| 724 |
+
|
| 725 |
+
Args:
|
| 726 |
+
apply_class_balancing: A bool, whether to apply weight balancing on the
|
| 727 |
+
binary classes 0 and 1.
|
| 728 |
+
alpha: A weight balancing factor for class 1, default is `0.25` as
|
| 729 |
+
mentioned in reference [Lin et al., 2018](
|
| 730 |
+
https://arxiv.org/pdf/1708.02002.pdf). The weight for class 0 is
|
| 731 |
+
`1.0 - alpha`.
|
| 732 |
+
gamma: A focusing parameter used to compute the focal factor, default is
|
| 733 |
+
`2.0` as mentioned in the reference
|
| 734 |
+
[Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf).
|
| 735 |
+
from_logits: Whether to interpret `y_pred` as a tensor of
|
| 736 |
+
[logit](https://en.wikipedia.org/wiki/Logit) values. By default, we
|
| 737 |
+
assume that `y_pred` are probabilities (i.e., values in `[0, 1]`).
|
| 738 |
+
label_smoothing: Float in `[0, 1]`. When `0`, no smoothing occurs.
|
| 739 |
+
When > `0`, we compute the loss between the predicted labels
|
| 740 |
+
and a smoothed version of the true labels, where the smoothing
|
| 741 |
+
squeezes the labels towards `0.5`.
|
| 742 |
+
Larger values of `label_smoothing` correspond to heavier smoothing.
|
| 743 |
+
axis: The axis along which to compute crossentropy (the features axis).
|
| 744 |
+
Defaults to `-1`.
|
| 745 |
+
reduction: Type of reduction to apply to the loss. In almost all cases
|
| 746 |
+
this should be `"sum_over_batch_size"`. Supported options are
|
| 747 |
+
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
|
| 748 |
+
`"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
|
| 749 |
+
`"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
|
| 750 |
+
sample size, and `"mean_with_sample_weight"` sums the loss and
|
| 751 |
+
divides by the sum of the sample weights. `"none"` and `None`
|
| 752 |
+
perform no aggregation. Defaults to `"sum_over_batch_size"`.
|
| 753 |
+
name: Optional name for the loss instance.
|
| 754 |
+
dtype: The dtype of the loss's computations. Defaults to `None`, which
|
| 755 |
+
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
|
| 756 |
+
`"float32"` unless set to different value
|
| 757 |
+
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
|
| 758 |
+
provided, then the `compute_dtype` will be utilized.
|
| 759 |
+
|
| 760 |
+
Examples:
|
| 761 |
+
|
| 762 |
+
With the `compile()` API:
|
| 763 |
+
|
| 764 |
+
```python
|
| 765 |
+
model.compile(
|
| 766 |
+
loss=keras.losses.BinaryFocalCrossentropy(
|
| 767 |
+
gamma=2.0, from_logits=True),
|
| 768 |
+
...
|
| 769 |
+
)
|
| 770 |
+
```
|
| 771 |
+
|
| 772 |
+
As a standalone function:
|
| 773 |
+
|
| 774 |
+
>>> # Example 1: (batch_size = 1, number of samples = 4)
|
| 775 |
+
>>> y_true = np.array([0, 1, 0, 0])
|
| 776 |
+
>>> y_pred = np.array([-18.6, 0.51, 2.94, -12.8])
|
| 777 |
+
>>> loss = keras.losses.BinaryFocalCrossentropy(
|
| 778 |
+
... gamma=2, from_logits=True)
|
| 779 |
+
>>> loss(y_true, y_pred)
|
| 780 |
+
0.691
|
| 781 |
+
|
| 782 |
+
>>> # Apply class weight
|
| 783 |
+
>>> loss = keras.losses.BinaryFocalCrossentropy(
|
| 784 |
+
... apply_class_balancing=True, gamma=2, from_logits=True)
|
| 785 |
+
>>> loss(y_true, y_pred)
|
| 786 |
+
0.51
|
| 787 |
+
|
| 788 |
+
>>> # Example 2: (batch_size = 2, number of samples = 4)
|
| 789 |
+
>>> y_true = np.array([[0, 1], [0, 0]])
|
| 790 |
+
>>> y_pred = np.array([[-18.6, 0.51], [2.94, -12.8]])
|
| 791 |
+
>>> # Using default 'auto'/'sum_over_batch_size' reduction type.
|
| 792 |
+
>>> loss = keras.losses.BinaryFocalCrossentropy(
|
| 793 |
+
... gamma=3, from_logits=True)
|
| 794 |
+
>>> loss(y_true, y_pred)
|
| 795 |
+
0.647
|
| 796 |
+
|
| 797 |
+
>>> # Apply class weight
|
| 798 |
+
>>> loss = keras.losses.BinaryFocalCrossentropy(
|
| 799 |
+
... apply_class_balancing=True, gamma=3, from_logits=True)
|
| 800 |
+
>>> loss(y_true, y_pred)
|
| 801 |
+
0.482
|
| 802 |
+
|
| 803 |
+
>>> # Using 'sample_weight' attribute with focal effect
|
| 804 |
+
>>> loss = keras.losses.BinaryFocalCrossentropy(
|
| 805 |
+
... gamma=3, from_logits=True)
|
| 806 |
+
>>> loss(y_true, y_pred, sample_weight=[0.8, 0.2])
|
| 807 |
+
0.133
|
| 808 |
+
|
| 809 |
+
>>> # Apply class weight
|
| 810 |
+
>>> loss = keras.losses.BinaryFocalCrossentropy(
|
| 811 |
+
... apply_class_balancing=True, gamma=3, from_logits=True)
|
| 812 |
+
>>> loss(y_true, y_pred, sample_weight=[0.8, 0.2])
|
| 813 |
+
0.097
|
| 814 |
+
|
| 815 |
+
>>> # Using 'sum' reduction` type.
|
| 816 |
+
>>> loss = keras.losses.BinaryFocalCrossentropy(
|
| 817 |
+
... gamma=4, from_logits=True,
|
| 818 |
+
... reduction="sum")
|
| 819 |
+
>>> loss(y_true, y_pred)
|
| 820 |
+
1.222
|
| 821 |
+
|
| 822 |
+
>>> # Apply class weight
|
| 823 |
+
>>> loss = keras.losses.BinaryFocalCrossentropy(
|
| 824 |
+
... apply_class_balancing=True, gamma=4, from_logits=True,
|
| 825 |
+
... reduction="sum")
|
| 826 |
+
>>> loss(y_true, y_pred)
|
| 827 |
+
0.914
|
| 828 |
+
|
| 829 |
+
>>> # Using 'none' reduction type.
|
| 830 |
+
>>> loss = keras.losses.BinaryFocalCrossentropy(
|
| 831 |
+
... gamma=5, from_logits=True,
|
| 832 |
+
... reduction=None)
|
| 833 |
+
>>> loss(y_true, y_pred)
|
| 834 |
+
array([0.0017 1.1561], dtype=float32)
|
| 835 |
+
|
| 836 |
+
>>> # Apply class weight
|
| 837 |
+
>>> loss = keras.losses.BinaryFocalCrossentropy(
|
| 838 |
+
... apply_class_balancing=True, gamma=5, from_logits=True,
|
| 839 |
+
... reduction=None)
|
| 840 |
+
>>> loss(y_true, y_pred)
|
| 841 |
+
array([0.0004 0.8670], dtype=float32)
|
| 842 |
+
"""
|
| 843 |
+
|
| 844 |
+
def __init__(
|
| 845 |
+
self,
|
| 846 |
+
apply_class_balancing=False,
|
| 847 |
+
alpha=0.25,
|
| 848 |
+
gamma=2.0,
|
| 849 |
+
from_logits=False,
|
| 850 |
+
label_smoothing=0.0,
|
| 851 |
+
axis=-1,
|
| 852 |
+
reduction="sum_over_batch_size",
|
| 853 |
+
name="binary_focal_crossentropy",
|
| 854 |
+
dtype=None,
|
| 855 |
+
):
|
| 856 |
+
super().__init__(
|
| 857 |
+
binary_focal_crossentropy,
|
| 858 |
+
name=name,
|
| 859 |
+
reduction=reduction,
|
| 860 |
+
dtype=dtype,
|
| 861 |
+
apply_class_balancing=apply_class_balancing,
|
| 862 |
+
alpha=alpha,
|
| 863 |
+
gamma=gamma,
|
| 864 |
+
from_logits=from_logits,
|
| 865 |
+
label_smoothing=label_smoothing,
|
| 866 |
+
axis=axis,
|
| 867 |
+
)
|
| 868 |
+
self.from_logits = from_logits
|
| 869 |
+
self.label_smoothing = label_smoothing
|
| 870 |
+
self.axis = axis
|
| 871 |
+
self.apply_class_balancing = apply_class_balancing
|
| 872 |
+
self.alpha = alpha
|
| 873 |
+
self.gamma = gamma
|
| 874 |
+
|
| 875 |
+
def get_config(self):
|
| 876 |
+
config = Loss.get_config(self)
|
| 877 |
+
config.update(
|
| 878 |
+
{
|
| 879 |
+
"from_logits": self.from_logits,
|
| 880 |
+
"label_smoothing": self.label_smoothing,
|
| 881 |
+
"axis": self.axis,
|
| 882 |
+
"apply_class_balancing": self.apply_class_balancing,
|
| 883 |
+
"alpha": self.alpha,
|
| 884 |
+
"gamma": self.gamma,
|
| 885 |
+
}
|
| 886 |
+
)
|
| 887 |
+
return config
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
@keras_export("keras.losses.CategoricalCrossentropy")
|
| 891 |
+
class CategoricalCrossentropy(LossFunctionWrapper):
|
| 892 |
+
"""Computes the crossentropy loss between the labels and predictions.
|
| 893 |
+
|
| 894 |
+
Use this crossentropy loss function when there are two or more label
|
| 895 |
+
classes. We expect labels to be provided in a `one_hot` representation. If
|
| 896 |
+
you want to provide labels as integers, please use
|
| 897 |
+
`SparseCategoricalCrossentropy` loss. There should be `num_classes` floating
|
| 898 |
+
point values per feature, i.e., the shape of both `y_pred` and `y_true` are
|
| 899 |
+
`[batch_size, num_classes]`.
|
| 900 |
+
|
| 901 |
+
Args:
|
| 902 |
+
from_logits: Whether `y_pred` is expected to be a logits tensor. By
|
| 903 |
+
default, we assume that `y_pred` encodes a probability distribution.
|
| 904 |
+
label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
|
| 905 |
+
meaning the confidence on label values are relaxed. For example, if
|
| 906 |
+
`0.1`, use `0.1 / num_classes` for non-target labels and
|
| 907 |
+
`0.9 + 0.1 / num_classes` for target labels.
|
| 908 |
+
axis: The axis along which to compute crossentropy (the features
|
| 909 |
+
axis). Defaults to `-1`.
|
| 910 |
+
reduction: Type of reduction to apply to the loss. In almost all cases
|
| 911 |
+
this should be `"sum_over_batch_size"`. Supported options are
|
| 912 |
+
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
|
| 913 |
+
`"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
|
| 914 |
+
`"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
|
| 915 |
+
sample size, and `"mean_with_sample_weight"` sums the loss and
|
| 916 |
+
divides by the sum of the sample weights. `"none"` and `None`
|
| 917 |
+
perform no aggregation. Defaults to `"sum_over_batch_size"`.
|
| 918 |
+
name: Optional name for the loss instance.
|
| 919 |
+
dtype: The dtype of the loss's computations. Defaults to `None`, which
|
| 920 |
+
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
|
| 921 |
+
`"float32"` unless set to different value
|
| 922 |
+
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
|
| 923 |
+
provided, then the `compute_dtype` will be utilized.
|
| 924 |
+
|
| 925 |
+
Examples:
|
| 926 |
+
|
| 927 |
+
Standalone usage:
|
| 928 |
+
|
| 929 |
+
>>> y_true = np.array([[0, 1, 0], [0, 0, 1]])
|
| 930 |
+
>>> y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
|
| 931 |
+
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
|
| 932 |
+
>>> cce = keras.losses.CategoricalCrossentropy()
|
| 933 |
+
>>> cce(y_true, y_pred)
|
| 934 |
+
1.177
|
| 935 |
+
|
| 936 |
+
>>> # Calling with 'sample_weight'.
|
| 937 |
+
>>> cce(y_true, y_pred, sample_weight=np.array([0.3, 0.7]))
|
| 938 |
+
0.814
|
| 939 |
+
|
| 940 |
+
>>> # Using 'sum' reduction type.
|
| 941 |
+
>>> cce = keras.losses.CategoricalCrossentropy(
|
| 942 |
+
... reduction="sum")
|
| 943 |
+
>>> cce(y_true, y_pred)
|
| 944 |
+
2.354
|
| 945 |
+
|
| 946 |
+
>>> # Using 'none' reduction type.
|
| 947 |
+
>>> cce = keras.losses.CategoricalCrossentropy(
|
| 948 |
+
... reduction=None)
|
| 949 |
+
>>> cce(y_true, y_pred)
|
| 950 |
+
array([0.0513, 2.303], dtype=float32)
|
| 951 |
+
|
| 952 |
+
Usage with the `compile()` API:
|
| 953 |
+
|
| 954 |
+
```python
|
| 955 |
+
model.compile(optimizer='sgd',
|
| 956 |
+
loss=keras.losses.CategoricalCrossentropy())
|
| 957 |
+
```
|
| 958 |
+
"""
|
| 959 |
+
|
| 960 |
+
def __init__(
|
| 961 |
+
self,
|
| 962 |
+
from_logits=False,
|
| 963 |
+
label_smoothing=0.0,
|
| 964 |
+
axis=-1,
|
| 965 |
+
reduction="sum_over_batch_size",
|
| 966 |
+
name="categorical_crossentropy",
|
| 967 |
+
dtype=None,
|
| 968 |
+
):
|
| 969 |
+
super().__init__(
|
| 970 |
+
categorical_crossentropy,
|
| 971 |
+
name=name,
|
| 972 |
+
reduction=reduction,
|
| 973 |
+
dtype=dtype,
|
| 974 |
+
from_logits=from_logits,
|
| 975 |
+
label_smoothing=label_smoothing,
|
| 976 |
+
axis=axis,
|
| 977 |
+
)
|
| 978 |
+
self.from_logits = from_logits
|
| 979 |
+
self.label_smoothing = label_smoothing
|
| 980 |
+
self.axis = axis
|
| 981 |
+
|
| 982 |
+
def get_config(self):
|
| 983 |
+
config = Loss.get_config(self)
|
| 984 |
+
config.update(
|
| 985 |
+
{
|
| 986 |
+
"from_logits": self.from_logits,
|
| 987 |
+
"label_smoothing": self.label_smoothing,
|
| 988 |
+
"axis": self.axis,
|
| 989 |
+
}
|
| 990 |
+
)
|
| 991 |
+
return config
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
@keras_export("keras.losses.CategoricalFocalCrossentropy")
|
| 995 |
+
class CategoricalFocalCrossentropy(LossFunctionWrapper):
|
| 996 |
+
"""Computes the alpha balanced focal crossentropy loss.
|
| 997 |
+
|
| 998 |
+
Use this crossentropy loss function when there are two or more label
|
| 999 |
+
classes and if you want to handle class imbalance without using
|
| 1000 |
+
`class_weights`. We expect labels to be provided in a `one_hot`
|
| 1001 |
+
representation.
|
| 1002 |
+
|
| 1003 |
+
According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
|
| 1004 |
+
helps to apply a focal factor to down-weight easy examples and focus more on
|
| 1005 |
+
hard examples. The general formula for the focal loss (FL)
|
| 1006 |
+
is as follows:
|
| 1007 |
+
|
| 1008 |
+
`FL(p_t) = (1 - p_t) ** gamma * log(p_t)`
|
| 1009 |
+
|
| 1010 |
+
where `p_t` is defined as follows:
|
| 1011 |
+
`p_t = output if y_true == 1, else 1 - output`
|
| 1012 |
+
|
| 1013 |
+
`(1 - p_t) ** gamma` is the `modulating_factor`, where `gamma` is a focusing
|
| 1014 |
+
parameter. When `gamma` = 0, there is no focal effect on the cross entropy.
|
| 1015 |
+
`gamma` reduces the importance given to simple examples in a smooth manner.
|
| 1016 |
+
|
| 1017 |
+
The authors use alpha-balanced variant of focal loss (FL) in the paper:
|
| 1018 |
+
`FL(p_t) = -alpha * (1 - p_t) ** gamma * log(p_t)`
|
| 1019 |
+
|
| 1020 |
+
where `alpha` is the weight factor for the classes. If `alpha` = 1, the
|
| 1021 |
+
loss won't be able to handle class imbalance properly as all
|
| 1022 |
+
classes will have the same weight. This can be a constant or a list of
|
| 1023 |
+
constants. If alpha is a list, it must have the same length as the number
|
| 1024 |
+
of classes.
|
| 1025 |
+
|
| 1026 |
+
The formula above can be generalized to:
|
| 1027 |
+
`FL(p_t) = alpha * (1 - p_t) ** gamma * CrossEntropy(y_true, y_pred)`
|
| 1028 |
+
|
| 1029 |
+
where minus comes from `CrossEntropy(y_true, y_pred)` (CE).
|
| 1030 |
+
|
| 1031 |
+
Extending this to multi-class case is straightforward:
|
| 1032 |
+
`FL(p_t) = alpha * (1 - p_t) ** gamma * CategoricalCE(y_true, y_pred)`
|
| 1033 |
+
|
| 1034 |
+
In the snippet below, there is `num_classes` floating pointing values per
|
| 1035 |
+
example. The shape of both `y_pred` and `y_true` are
|
| 1036 |
+
`(batch_size, num_classes)`.
|
| 1037 |
+
|
| 1038 |
+
Args:
|
| 1039 |
+
alpha: A weight balancing factor for all classes, default is `0.25` as
|
| 1040 |
+
mentioned in the reference. It can be a list of floats or a scalar.
|
| 1041 |
+
In the multi-class case, alpha may be set by inverse class
|
| 1042 |
+
frequency by using `compute_class_weight` from `sklearn.utils`.
|
| 1043 |
+
gamma: A focusing parameter, default is `2.0` as mentioned in the
|
| 1044 |
+
reference. It helps to gradually reduce the importance given to
|
| 1045 |
+
simple (easy) examples in a smooth manner.
|
| 1046 |
+
from_logits: Whether `output` is expected to be a logits tensor. By
|
| 1047 |
+
default, we consider that `output` encodes a probability
|
| 1048 |
+
distribution.
|
| 1049 |
+
label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
|
| 1050 |
+
meaning the confidence on label values are relaxed. For example, if
|
| 1051 |
+
`0.1`, use `0.1 / num_classes` for non-target labels and
|
| 1052 |
+
`0.9 + 0.1 / num_classes` for target labels.
|
| 1053 |
+
axis: The axis along which to compute crossentropy (the features
|
| 1054 |
+
axis). Defaults to `-1`.
|
| 1055 |
+
reduction: Type of reduction to apply to the loss. In almost all cases
|
| 1056 |
+
this should be `"sum_over_batch_size"`. Supported options are
|
| 1057 |
+
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
|
| 1058 |
+
`"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
|
| 1059 |
+
`"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
|
| 1060 |
+
sample size, and `"mean_with_sample_weight"` sums the loss and
|
| 1061 |
+
divides by the sum of the sample weights. `"none"` and `None`
|
| 1062 |
+
perform no aggregation. Defaults to `"sum_over_batch_size"`.
|
| 1063 |
+
name: Optional name for the loss instance.
|
| 1064 |
+
dtype: The dtype of the loss's computations. Defaults to `None`, which
|
| 1065 |
+
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
|
| 1066 |
+
`"float32"` unless set to different value
|
| 1067 |
+
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
|
| 1068 |
+
provided, then the `compute_dtype` will be utilized.
|
| 1069 |
+
|
| 1070 |
+
Examples:
|
| 1071 |
+
|
| 1072 |
+
Standalone usage:
|
| 1073 |
+
|
| 1074 |
+
>>> y_true = [[0., 1., 0.], [0., 0., 1.]]
|
| 1075 |
+
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
|
| 1076 |
+
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
|
| 1077 |
+
>>> cce = keras.losses.CategoricalFocalCrossentropy()
|
| 1078 |
+
>>> cce(y_true, y_pred)
|
| 1079 |
+
0.23315276
|
| 1080 |
+
|
| 1081 |
+
>>> # Calling with 'sample_weight'.
|
| 1082 |
+
>>> cce(y_true, y_pred, sample_weight=np.array([0.3, 0.7]))
|
| 1083 |
+
0.1632
|
| 1084 |
+
|
| 1085 |
+
>>> # Using 'sum' reduction type.
|
| 1086 |
+
>>> cce = keras.losses.CategoricalFocalCrossentropy(
|
| 1087 |
+
... reduction="sum")
|
| 1088 |
+
>>> cce(y_true, y_pred)
|
| 1089 |
+
0.46631
|
| 1090 |
+
|
| 1091 |
+
>>> # Using 'none' reduction type.
|
| 1092 |
+
>>> cce = keras.losses.CategoricalFocalCrossentropy(
|
| 1093 |
+
... reduction=None)
|
| 1094 |
+
>>> cce(y_true, y_pred)
|
| 1095 |
+
array([3.2058331e-05, 4.6627346e-01], dtype=float32)
|
| 1096 |
+
|
| 1097 |
+
Usage with the `compile()` API:
|
| 1098 |
+
|
| 1099 |
+
```python
|
| 1100 |
+
model.compile(optimizer='adam',
|
| 1101 |
+
loss=keras.losses.CategoricalFocalCrossentropy())
|
| 1102 |
+
```
|
| 1103 |
+
"""
|
| 1104 |
+
|
| 1105 |
+
def __init__(
|
| 1106 |
+
self,
|
| 1107 |
+
alpha=0.25,
|
| 1108 |
+
gamma=2.0,
|
| 1109 |
+
from_logits=False,
|
| 1110 |
+
label_smoothing=0.0,
|
| 1111 |
+
axis=-1,
|
| 1112 |
+
reduction="sum_over_batch_size",
|
| 1113 |
+
name="categorical_focal_crossentropy",
|
| 1114 |
+
dtype=None,
|
| 1115 |
+
):
|
| 1116 |
+
"""Initializes `CategoricalFocalCrossentropy` instance."""
|
| 1117 |
+
super().__init__(
|
| 1118 |
+
categorical_focal_crossentropy,
|
| 1119 |
+
name=name,
|
| 1120 |
+
reduction=reduction,
|
| 1121 |
+
dtype=dtype,
|
| 1122 |
+
alpha=alpha,
|
| 1123 |
+
gamma=gamma,
|
| 1124 |
+
from_logits=from_logits,
|
| 1125 |
+
label_smoothing=label_smoothing,
|
| 1126 |
+
axis=axis,
|
| 1127 |
+
)
|
| 1128 |
+
self.from_logits = from_logits
|
| 1129 |
+
self.label_smoothing = label_smoothing
|
| 1130 |
+
self.axis = axis
|
| 1131 |
+
self.alpha = alpha
|
| 1132 |
+
self.gamma = gamma
|
| 1133 |
+
|
| 1134 |
+
def get_config(self):
|
| 1135 |
+
config = Loss.get_config(self)
|
| 1136 |
+
config.update(
|
| 1137 |
+
{
|
| 1138 |
+
"from_logits": self.from_logits,
|
| 1139 |
+
"label_smoothing": self.label_smoothing,
|
| 1140 |
+
"axis": self.axis,
|
| 1141 |
+
"alpha": self.alpha,
|
| 1142 |
+
"gamma": self.gamma,
|
| 1143 |
+
}
|
| 1144 |
+
)
|
| 1145 |
+
return config
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
@keras_export("keras.losses.SparseCategoricalCrossentropy")
|
| 1149 |
+
class SparseCategoricalCrossentropy(LossFunctionWrapper):
|
| 1150 |
+
"""Computes the crossentropy loss between the labels and predictions.
|
| 1151 |
+
|
| 1152 |
+
Use this crossentropy loss function when there are two or more label
|
| 1153 |
+
classes. We expect labels to be provided as integers. If you want to
|
| 1154 |
+
provide labels using `one-hot` representation, please use
|
| 1155 |
+
`CategoricalCrossentropy` loss. There should be `# classes` floating point
|
| 1156 |
+
values per feature for `y_pred` and a single floating point value per
|
| 1157 |
+
feature for `y_true`.
|
| 1158 |
+
|
| 1159 |
+
In the snippet below, there is a single floating point value per example for
|
| 1160 |
+
`y_true` and `num_classes` floating pointing values per example for
|
| 1161 |
+
`y_pred`. The shape of `y_true` is `[batch_size]` and the shape of `y_pred`
|
| 1162 |
+
is `[batch_size, num_classes]`.
|
| 1163 |
+
|
| 1164 |
+
Args:
|
| 1165 |
+
from_logits: Whether `y_pred` is expected to be a logits tensor. By
|
| 1166 |
+
default, we assume that `y_pred` encodes a probability distribution.
|
| 1167 |
+
reduction: Type of reduction to apply to the loss. In almost all cases
|
| 1168 |
+
this should be `"sum_over_batch_size"`. Supported options are
|
| 1169 |
+
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
|
| 1170 |
+
`"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
|
| 1171 |
+
`"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
|
| 1172 |
+
sample size, and `"mean_with_sample_weight"` sums the loss and
|
| 1173 |
+
divides by the sum of the sample weights. `"none"` and `None`
|
| 1174 |
+
perform no aggregation. Defaults to `"sum_over_batch_size"`.
|
| 1175 |
+
name: Optional name for the loss instance.
|
| 1176 |
+
dtype: The dtype of the loss's computations. Defaults to `None`, which
|
| 1177 |
+
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
|
| 1178 |
+
`"float32"` unless set to different value
|
| 1179 |
+
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
|
| 1180 |
+
provided, then the `compute_dtype` will be utilized.
|
| 1181 |
+
|
| 1182 |
+
Examples:
|
| 1183 |
+
|
| 1184 |
+
>>> y_true = [1, 2]
|
| 1185 |
+
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
|
| 1186 |
+
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
|
| 1187 |
+
>>> scce = keras.losses.SparseCategoricalCrossentropy()
|
| 1188 |
+
>>> scce(y_true, y_pred)
|
| 1189 |
+
1.177
|
| 1190 |
+
|
| 1191 |
+
>>> # Calling with 'sample_weight'.
|
| 1192 |
+
>>> scce(y_true, y_pred, sample_weight=np.array([0.3, 0.7]))
|
| 1193 |
+
0.814
|
| 1194 |
+
|
| 1195 |
+
>>> # Using 'sum' reduction type.
|
| 1196 |
+
>>> scce = keras.losses.SparseCategoricalCrossentropy(
|
| 1197 |
+
... reduction="sum")
|
| 1198 |
+
>>> scce(y_true, y_pred)
|
| 1199 |
+
2.354
|
| 1200 |
+
|
| 1201 |
+
>>> # Using 'none' reduction type.
|
| 1202 |
+
>>> scce = keras.losses.SparseCategoricalCrossentropy(
|
| 1203 |
+
... reduction=None)
|
| 1204 |
+
>>> scce(y_true, y_pred)
|
| 1205 |
+
array([0.0513, 2.303], dtype=float32)
|
| 1206 |
+
|
| 1207 |
+
Usage with the `compile()` API:
|
| 1208 |
+
|
| 1209 |
+
```python
|
| 1210 |
+
model.compile(optimizer='sgd',
|
| 1211 |
+
loss=keras.losses.SparseCategoricalCrossentropy())
|
| 1212 |
+
```
|
| 1213 |
+
"""
|
| 1214 |
+
|
| 1215 |
+
def __init__(
|
| 1216 |
+
self,
|
| 1217 |
+
from_logits=False,
|
| 1218 |
+
ignore_class=None,
|
| 1219 |
+
reduction="sum_over_batch_size",
|
| 1220 |
+
name="sparse_categorical_crossentropy",
|
| 1221 |
+
dtype=None,
|
| 1222 |
+
):
|
| 1223 |
+
super().__init__(
|
| 1224 |
+
sparse_categorical_crossentropy,
|
| 1225 |
+
name=name,
|
| 1226 |
+
reduction=reduction,
|
| 1227 |
+
dtype=dtype,
|
| 1228 |
+
from_logits=from_logits,
|
| 1229 |
+
ignore_class=ignore_class,
|
| 1230 |
+
)
|
| 1231 |
+
self.from_logits = from_logits
|
| 1232 |
+
self.ignore_class = ignore_class
|
| 1233 |
+
|
| 1234 |
+
def get_config(self):
|
| 1235 |
+
config = Loss.get_config(self)
|
| 1236 |
+
config.update(
|
| 1237 |
+
{
|
| 1238 |
+
"from_logits": self.from_logits,
|
| 1239 |
+
"ignore_class": self.ignore_class,
|
| 1240 |
+
}
|
| 1241 |
+
)
|
| 1242 |
+
return config
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
@keras_export("keras.losses.CTC")
|
| 1246 |
+
class CTC(LossFunctionWrapper):
|
| 1247 |
+
"""CTC (Connectionist Temporal Classification) loss.
|
| 1248 |
+
|
| 1249 |
+
Args:
|
| 1250 |
+
reduction: Type of reduction to apply to the loss. In almost all cases
|
| 1251 |
+
this should be `"sum_over_batch_size"`. Supported options are
|
| 1252 |
+
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
|
| 1253 |
+
`"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
|
| 1254 |
+
`"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
|
| 1255 |
+
sample size, and `"mean_with_sample_weight"` sums the loss and
|
| 1256 |
+
divides by the sum of the sample weights. `"none"` and `None`
|
| 1257 |
+
perform no aggregation. Defaults to `"sum_over_batch_size"`.
|
| 1258 |
+
name: Optional name for the loss instance.
|
| 1259 |
+
dtype: The dtype of the loss's computations. Defaults to `None`, which
|
| 1260 |
+
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
|
| 1261 |
+
`"float32"` unless set to different value
|
| 1262 |
+
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
|
| 1263 |
+
provided, then the `compute_dtype` will be utilized.
|
| 1264 |
+
"""
|
| 1265 |
+
|
| 1266 |
+
def __init__(self, reduction="sum_over_batch_size", name="ctc", dtype=None):
|
| 1267 |
+
super().__init__(ctc, name=name, reduction=reduction, dtype=dtype)
|
| 1268 |
+
|
| 1269 |
+
def get_config(self):
|
| 1270 |
+
return Loss.get_config(self)
|
| 1271 |
+
|
| 1272 |
+
|
| 1273 |
+
@keras_export("keras.losses.Dice")
|
| 1274 |
+
class Dice(LossFunctionWrapper):
|
| 1275 |
+
"""Computes the Dice loss value between `y_true` and `y_pred`.
|
| 1276 |
+
|
| 1277 |
+
Formula:
|
| 1278 |
+
```python
|
| 1279 |
+
loss = 1 - (2 * sum(y_true * y_pred)) / (sum(y_true) + sum(y_pred))
|
| 1280 |
+
```
|
| 1281 |
+
|
| 1282 |
+
Args:
|
| 1283 |
+
reduction: Type of reduction to apply to the loss. In almost all cases
|
| 1284 |
+
this should be `"sum_over_batch_size"`. Supported options are
|
| 1285 |
+
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
|
| 1286 |
+
`"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
|
| 1287 |
+
`"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
|
| 1288 |
+
sample size, and `"mean_with_sample_weight"` sums the loss and
|
| 1289 |
+
divides by the sum of the sample weights. `"none"` and `None`
|
| 1290 |
+
perform no aggregation. Defaults to `"sum_over_batch_size"`.
|
| 1291 |
+
name: Optional name for the loss instance.
|
| 1292 |
+
axis: Tuple for which dimensions the loss is calculated. Defaults to
|
| 1293 |
+
`None`.
|
| 1294 |
+
dtype: The dtype of the loss's computations. Defaults to `None`, which
|
| 1295 |
+
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
|
| 1296 |
+
`"float32"` unless set to different value
|
| 1297 |
+
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
|
| 1298 |
+
provided, then the `compute_dtype` will be utilized.
|
| 1299 |
+
|
| 1300 |
+
Returns:
|
| 1301 |
+
Dice loss value.
|
| 1302 |
+
|
| 1303 |
+
Example:
|
| 1304 |
+
|
| 1305 |
+
>>> y_true = [[[[1.0], [1.0]], [[0.0], [0.0]]],
|
| 1306 |
+
... [[[1.0], [1.0]], [[0.0], [0.0]]]]
|
| 1307 |
+
>>> y_pred = [[[[0.0], [1.0]], [[0.0], [1.0]]],
|
| 1308 |
+
... [[[0.4], [0.0]], [[0.0], [0.9]]]]
|
| 1309 |
+
>>> axis = (1, 2, 3)
|
| 1310 |
+
>>> loss = keras.losses.dice(y_true, y_pred, axis=axis)
|
| 1311 |
+
>>> assert loss.shape == (2,)
|
| 1312 |
+
>>> loss
|
| 1313 |
+
array([0.5, 0.75757575], shape=(2,), dtype=float32)
|
| 1314 |
+
|
| 1315 |
+
>>> loss = keras.losses.dice(y_true, y_pred)
|
| 1316 |
+
>>> assert loss.shape == ()
|
| 1317 |
+
>>> loss
|
| 1318 |
+
array(0.6164384, shape=(), dtype=float32)
|
| 1319 |
+
|
| 1320 |
+
>>> y_true = np.array(y_true)
|
| 1321 |
+
>>> y_pred = np.array(y_pred)
|
| 1322 |
+
>>> loss = keras.losses.Dice(axis=axis, reduction=None)(y_true, y_pred)
|
| 1323 |
+
>>> assert loss.shape == (2,)
|
| 1324 |
+
>>> loss
|
| 1325 |
+
array([0.5, 0.75757575], shape=(2,), dtype=float32)
|
| 1326 |
+
|
| 1327 |
+
"""
|
| 1328 |
+
|
| 1329 |
+
def __init__(
|
| 1330 |
+
self,
|
| 1331 |
+
reduction="sum_over_batch_size",
|
| 1332 |
+
name="dice",
|
| 1333 |
+
axis=None,
|
| 1334 |
+
dtype=None,
|
| 1335 |
+
):
|
| 1336 |
+
super().__init__(
|
| 1337 |
+
dice, name=name, reduction=reduction, dtype=dtype, axis=axis
|
| 1338 |
+
)
|
| 1339 |
+
self.axis = axis
|
| 1340 |
+
|
| 1341 |
+
def get_config(self):
|
| 1342 |
+
config = Loss.get_config(self)
|
| 1343 |
+
config.update({"axis": self.axis})
|
| 1344 |
+
return config
|
| 1345 |
+
|
| 1346 |
+
|
| 1347 |
+
@keras_export("keras.losses.Tversky")
|
| 1348 |
+
class Tversky(LossFunctionWrapper):
|
| 1349 |
+
"""Computes the Tversky loss value between `y_true` and `y_pred`.
|
| 1350 |
+
|
| 1351 |
+
This loss function is weighted by the alpha and beta coefficients
|
| 1352 |
+
that penalize false positives and false negatives.
|
| 1353 |
+
|
| 1354 |
+
With `alpha=0.5` and `beta=0.5`, the loss value becomes equivalent to
|
| 1355 |
+
Dice Loss.
|
| 1356 |
+
|
| 1357 |
+
Args:
|
| 1358 |
+
alpha: The coefficient controlling incidence of false positives.
|
| 1359 |
+
Defaults to `0.5`.
|
| 1360 |
+
beta: The coefficient controlling incidence of false negatives.
|
| 1361 |
+
Defaults to `0.5`.
|
| 1362 |
+
reduction: Type of reduction to apply to the loss. In almost all cases
|
| 1363 |
+
this should be `"sum_over_batch_size"`. Supported options are
|
| 1364 |
+
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
|
| 1365 |
+
`"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
|
| 1366 |
+
`"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
|
| 1367 |
+
sample size, and `"mean_with_sample_weight"` sums the loss and
|
| 1368 |
+
divides by the sum of the sample weights. `"none"` and `None`
|
| 1369 |
+
perform no aggregation. Defaults to `"sum_over_batch_size"`.
|
| 1370 |
+
name: Optional name for the loss instance.
|
| 1371 |
+
dtype: The dtype of the loss's computations. Defaults to `None`, which
|
| 1372 |
+
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
|
| 1373 |
+
`"float32"` unless set to different value
|
| 1374 |
+
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
|
| 1375 |
+
provided, then the `compute_dtype` will be utilized.
|
| 1376 |
+
|
| 1377 |
+
Returns:
|
| 1378 |
+
Tversky loss value.
|
| 1379 |
+
|
| 1380 |
+
Reference:
|
| 1381 |
+
|
| 1382 |
+
- [Salehi et al., 2017](https://arxiv.org/abs/1706.05721)
|
| 1383 |
+
"""
|
| 1384 |
+
|
| 1385 |
+
def __init__(
|
| 1386 |
+
self,
|
| 1387 |
+
alpha=0.5,
|
| 1388 |
+
beta=0.5,
|
| 1389 |
+
reduction="sum_over_batch_size",
|
| 1390 |
+
name="tversky",
|
| 1391 |
+
axis=None,
|
| 1392 |
+
dtype=None,
|
| 1393 |
+
):
|
| 1394 |
+
super().__init__(
|
| 1395 |
+
tversky,
|
| 1396 |
+
name=name,
|
| 1397 |
+
reduction=reduction,
|
| 1398 |
+
dtype=dtype,
|
| 1399 |
+
alpha=alpha,
|
| 1400 |
+
beta=beta,
|
| 1401 |
+
axis=axis,
|
| 1402 |
+
)
|
| 1403 |
+
self.alpha = alpha
|
| 1404 |
+
self.beta = beta
|
| 1405 |
+
self.axis = axis
|
| 1406 |
+
|
| 1407 |
+
def get_config(self):
|
| 1408 |
+
config = Loss.get_config(self)
|
| 1409 |
+
config.update(
|
| 1410 |
+
{"alpha": self.alpha, "beta": self.beta, "axis": self.axis}
|
| 1411 |
+
)
|
| 1412 |
+
return config
|
| 1413 |
+
|
| 1414 |
+
|
| 1415 |
+
@keras_export("keras.losses.Circle")
class Circle(LossFunctionWrapper):
    """Computes Circle Loss between integer labels and L2-normalized embeddings.

    This is a metric learning loss designed to minimize within-class distance
    and maximize between-class distance in a flexible manner by dynamically
    adjusting the penalty strength based on optimization status of each
    similarity score.

    To use Circle Loss effectively, the model should output embeddings without
    an activation function (such as a `Dense` layer with `activation=None`)
    followed by UnitNormalization layer to ensure unit-norm embeddings.

    Args:
        gamma: Scaling factor that determines the largest scale of each
            similarity score. Defaults to `80`.
        margin: The relaxation factor, below this distance, negatives are
            up weighted and positives are down weighted. Similarly, above this
            distance negatives are down weighted and positive are up weighted.
            Defaults to `0.4`.
        remove_diagonal: Boolean, whether to remove self-similarities from the
            positive mask. Defaults to `True`.
        reduction: Type of reduction to apply to the loss. Supported options
            are `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`,
            which means using `keras.backend.floatx()`. If a
            `keras.DTypePolicy` is provided, then the `compute_dtype` will
            be utilized.

    Examples:

    Usage with the `compile()` API:

    ```python
    model = models.Sequential([
        keras.layers.Input(shape=(224, 224, 3)),
        keras.layers.Conv2D(16, (3, 3), activation='relu'),
        keras.layers.Flatten(),
        keras.layers.Dense(64, activation=None), # No activation
        keras.layers.UnitNormalization() # L2 normalization
    ])

    model.compile(optimizer="adam", loss=keras.losses.Circle())
    ```

    Reference:
        - [Yifan Sun et al., 2020](https://arxiv.org/abs/2002.10857)
    """

    def __init__(
        self,
        gamma=80.0,
        margin=0.4,
        remove_diagonal=True,
        reduction="sum_over_batch_size",
        name="circle",
        dtype=None,
    ):
        # Delegate the computation to the functional `circle`
        # implementation; hyperparameters are forwarded on every call.
        super().__init__(
            circle,
            name=name,
            reduction=reduction,
            dtype=dtype,
            gamma=gamma,
            margin=margin,
            remove_diagonal=remove_diagonal,
        )
        # Retain the hyperparameters for serialization.
        self.gamma = gamma
        self.margin = margin
        self.remove_diagonal = remove_diagonal

    def get_config(self):
        """Return the serializable configuration of this loss."""
        # Use `Loss.get_config` directly so the wrapped function itself
        # is not serialized.
        base_config = Loss.get_config(self)
        base_config["gamma"] = self.gamma
        base_config["margin"] = self.margin
        base_config["remove_diagonal"] = self.remove_diagonal
        return base_config
|
| 1505 |
+
|
| 1506 |
+
|
| 1507 |
+
def convert_binary_labels_to_hinge(y_true):
    """Map 0/1 labels to -1/1 for hinge losses; leave other labels unchanged."""
    zeros_mask = ops.equal(y_true, 0)
    ones_mask = ops.equal(y_true, 1)
    # The labels count as "binary" only if every entry is exactly 0 or 1.
    all_binary = ops.all(ops.logical_or(zeros_mask, ones_mask))

    def _to_signed():
        # Affine map: 0 -> -1, 1 -> 1.
        return 2.0 * y_true - 1.0

    def _unchanged():
        # Non-binary labels pass through untouched.
        return y_true

    # `ops.cond` keeps this graph-friendly: the predicate is a tensor, so
    # a Python `if` would not work under tracing.
    return ops.cond(all_binary, _to_signed, _unchanged)
|
| 1525 |
+
|
| 1526 |
+
|
| 1527 |
+
@keras_export(
    [
        "keras.metrics.hinge",
        "keras.losses.hinge",
    ]
)
def hinge(y_true, y_pred):
    """Computes the hinge loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = mean(maximum(1 - y_true * y_pred, 0), axis=-1)
    ```

    Args:
        y_true: The ground truth values. `y_true` values are expected to be -1
            or 1. If binary (0 or 1) labels are provided they will be converted
            to -1 or 1 with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Hinge loss values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.choice([-1, 1], size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.hinge(y_true, y_pred)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, dtype=y_pred.dtype)
    y_true = ops.convert_to_tensor(y_true)
    # 0/1 labels are remapped to -1/1 so the margin formula applies.
    y_true = convert_binary_labels_to_hinge(y_true)
    margins = 1.0 - y_true * y_pred
    return ops.mean(ops.maximum(margins, 0.0), axis=-1)
|
| 1562 |
+
|
| 1563 |
+
|
| 1564 |
+
@keras_export(
    [
        "keras.metrics.squared_hinge",
        "keras.losses.squared_hinge",
    ]
)
def squared_hinge(y_true, y_pred):
    """Computes the squared hinge loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = mean(square(maximum(1 - y_true * y_pred, 0)), axis=-1)
    ```

    Args:
        y_true: The ground truth values. `y_true` values are expected to be -1
            or 1. If binary (0 or 1) labels are provided we will convert them
            to -1 or 1 with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Squared hinge loss values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.choice([-1, 1], size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.squared_hinge(y_true, y_pred)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, y_pred.dtype)
    # 0/1 labels are remapped to -1/1 so the margin formula applies.
    y_true = convert_binary_labels_to_hinge(y_true)
    margins = ops.maximum(1.0 - y_true * y_pred, 0.0)
    return ops.mean(ops.square(margins), axis=-1)
|
| 1600 |
+
|
| 1601 |
+
|
| 1602 |
+
@keras_export(
    [
        "keras.metrics.categorical_hinge",
        "keras.losses.categorical_hinge",
    ]
)
def categorical_hinge(y_true, y_pred):
    """Computes the categorical hinge loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = maximum(neg - pos + 1, 0)
    ```

    where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)`

    Args:
        y_true: The ground truth values. `y_true` values are expected to be
            either `{-1, +1}` or `{0, 1}` (i.e. a one-hot-encoded tensor) with
            shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Categorical hinge loss values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.randint(0, 3, size=(2,))
    >>> y_true = np.eye(np.max(y_true) + 1)[y_true]
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.categorical_hinge(y_true, y_pred)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, y_pred.dtype)
    # Score assigned to the true class(es).
    true_score = ops.sum(y_true * y_pred, axis=-1)
    # Largest score assigned to any wrong class.
    best_wrong = ops.max((1.0 - y_true) * y_pred, axis=-1)
    # Cast 0 to the prediction dtype so `maximum` does not promote dtypes.
    zero = ops.cast(0.0, y_pred.dtype)
    return ops.maximum(best_wrong - true_score + 1.0, zero)
|
| 1641 |
+
|
| 1642 |
+
|
| 1643 |
+
@keras_export(
    [
        "keras.metrics.mean_squared_error",
        "keras.losses.mean_squared_error",
        # Legacy aliases
        "keras._legacy.losses.mse",
        "keras._legacy.losses.MSE",
        "keras._legacy.metrics.mse",
        "keras._legacy.metrics.MSE",
    ]
)
def mean_squared_error(y_true, y_pred):
    """Computes the mean squared error between labels and predictions.

    Formula:

    ```python
    loss = mean(square(y_true - y_pred), axis=-1)
    ```

    Args:
        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean squared error values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.randint(0, 2, size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.mean_squared_error(y_true, y_pred)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    # Align ranks so e.g. `(batch,)` labels broadcast against `(batch, 1)`.
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    squared_diff = ops.square(y_true - y_pred)
    return ops.mean(squared_diff, axis=-1)
|
| 1680 |
+
|
| 1681 |
+
|
| 1682 |
+
@keras_export(
    [
        "keras.metrics.mean_absolute_error",
        "keras.losses.mean_absolute_error",
        # Legacy aliases
        "keras._legacy.losses.MAE",
        "keras._legacy.losses.mae",
        "keras._legacy.metrics.MAE",
        "keras._legacy.metrics.mae",
    ]
)
def mean_absolute_error(y_true, y_pred):
    """Computes the mean absolute error between labels and predictions.

    ```python
    loss = mean(abs(y_true - y_pred), axis=-1)
    ```

    Args:
        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean absolute error values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.randint(0, 2, size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.mean_absolute_error(y_true, y_pred)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    # Align ranks so broadcasting between labels and predictions is safe.
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    abs_diff = ops.abs(y_true - y_pred)
    return ops.mean(abs_diff, axis=-1)
|
| 1717 |
+
|
| 1718 |
+
|
| 1719 |
+
@keras_export(
    [
        "keras.metrics.mean_absolute_percentage_error",
        "keras.losses.mean_absolute_percentage_error",
        # Legacy aliases
        "keras._legacy.losses.mape",
        "keras._legacy.losses.MAPE",
        "keras._legacy.metrics.mape",
        "keras._legacy.metrics.MAPE",
    ]
)
def mean_absolute_percentage_error(y_true, y_pred):
    """Computes the mean absolute percentage error between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = 100 * mean(abs((y_true - y_pred) / y_true), axis=-1)
    ```

    Division by zero is prevented by dividing by `maximum(y_true, epsilon)`
    where `epsilon = keras.backend.epsilon()`
    (default to `1e-7`).

    Args:
        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean absolute percentage error values with shape = `[batch_size, d0, ..
        dN-1]`.

    Example:

    >>> y_true = np.random.random(size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.mean_absolute_percentage_error(y_true, y_pred)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    epsilon = ops.convert_to_tensor(backend.epsilon(), dtype=y_pred.dtype)
    # Align ranks so broadcasting between labels and predictions is safe.
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    # Clamp the denominator at epsilon to avoid division by zero.
    denominator = ops.maximum(ops.abs(y_true), epsilon)
    relative_error = ops.abs((y_true - y_pred) / denominator)
    return 100.0 * ops.mean(relative_error, axis=-1)
|
| 1763 |
+
|
| 1764 |
+
|
| 1765 |
+
@keras_export(
    [
        "keras.metrics.mean_squared_logarithmic_error",
        "keras.losses.mean_squared_logarithmic_error",
        # Legacy aliases
        "keras._legacy.losses.msle",
        "keras._legacy.losses.MSLE",
        "keras._legacy.metrics.msle",
        "keras._legacy.metrics.MSLE",
    ]
)
def mean_squared_logarithmic_error(y_true, y_pred):
    """Computes the mean squared logarithmic error between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)
    ```

    Note that `y_pred` and `y_true` cannot be less or equal to 0. Negative
    values and 0 values will be replaced with `keras.backend.epsilon()`
    (default to `1e-7`).

    Args:
        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean squared logarithmic error values with shape = `[batch_size, d0, ..
        dN-1]`.

    Example:

    >>> y_true = np.random.randint(0, 2, size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.mean_squared_logarithmic_error(y_true, y_pred)
    """
    epsilon = ops.convert_to_tensor(backend.epsilon())
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    # Align ranks so broadcasting between labels and predictions is safe.
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    # Clamp both inputs at epsilon before taking log1p-style logs, so
    # zeros/negatives do not produce -inf or NaN.
    log_pred = ops.log(ops.maximum(y_pred, epsilon) + 1.0)
    log_true = ops.log(ops.maximum(y_true, epsilon) + 1.0)
    return ops.mean(ops.square(log_pred - log_true), axis=-1)
|
| 1810 |
+
|
| 1811 |
+
|
| 1812 |
+
@keras_export("keras.losses.cosine_similarity")
def cosine_similarity(y_true, y_pred, axis=-1):
    """Computes the cosine similarity between labels and predictions.

    Formula:
    ```python
    loss = -sum(l2_norm(y_true) * l2_norm(y_pred))
    ```

    Note that it is a number between -1 and 1. When it is a negative number
    between -1 and 0, 0 indicates orthogonality and values closer to -1
    indicate greater similarity. This makes it usable as a loss function in a
    setting where you try to maximize the proximity between predictions and
    targets. If either `y_true` or `y_pred` is a zero vector, cosine
    similarity will be 0 regardless of the proximity between predictions
    and targets.

    Args:
        y_true: Tensor of true targets.
        y_pred: Tensor of predicted targets.
        axis: Axis along which to determine similarity. Defaults to `-1`.

    Returns:
        Cosine similarity tensor.

    Example:

    >>> y_true = [[0., 1.], [1., 1.], [1., 1.]]
    >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]]
    >>> loss = keras.losses.cosine_similarity(y_true, y_pred, axis=-1)
    [-0., -0.99999994, 0.99999994]
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    # Align ranks so broadcasting between labels and predictions is safe.
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    # L2-normalize both tensors, then the dot product is the cosine.
    y_pred = normalize(y_pred, axis=axis)
    y_true = normalize(y_true, axis=axis)
    # Negated so that minimizing the loss maximizes similarity.
    return -ops.sum(y_true * y_pred, axis=axis)
|
| 1850 |
+
|
| 1851 |
+
|
| 1852 |
+
@keras_export(["keras.losses.huber", "keras.metrics.huber"])
def huber(y_true, y_pred, delta=1.0):
    """Computes Huber loss value.

    Formula:
    ```python
    for x in error:
        if abs(x) <= delta:
            loss.append(0.5 * x^2)
        elif abs(x) > delta:
            loss.append(delta * abs(x) - 0.5 * delta^2)

    loss = mean(loss, axis=-1)
    ```
    See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss).

    Example:

    >>> y_true = [[0, 1], [0, 0]]
    >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
    >>> loss = keras.losses.huber(y_true, y_pred)
    0.155


    Args:
        y_true: tensor of true targets.
        y_pred: tensor of predicted targets.
        delta: A float, the point where the Huber loss function changes from a
            quadratic to linear. Defaults to `1.0`.

    Returns:
        Tensor with one scalar loss entry per sample.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    # Align ranks so broadcasting between labels and predictions is safe.
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    delta = ops.convert_to_tensor(delta, dtype=y_pred.dtype)
    error = ops.subtract(y_pred, y_true)
    abs_error = ops.abs(error)
    # Materialize 0.5 in the working dtype to avoid implicit promotion.
    half = ops.convert_to_tensor(0.5, dtype=abs_error.dtype)
    # Quadratic inside [-delta, delta], linear outside.
    quadratic = half * ops.square(error)
    linear = delta * abs_error - half * ops.square(delta)
    per_element = ops.where(abs_error <= delta, quadratic, linear)
    return ops.mean(per_element, axis=-1)
|
| 1900 |
+
|
| 1901 |
+
|
| 1902 |
+
@keras_export(
    [
        "keras.losses.log_cosh",
        "keras.metrics.log_cosh",
        # Legacy aliases
        "keras._legacy.losses.logcosh",
        "keras._legacy.metrics.logcosh",
    ]
)
def log_cosh(y_true, y_pred):
    """Logarithm of the hyperbolic cosine of the prediction error.

    Formula:
    ```python
    loss = mean(log(cosh(y_pred - y_true)), axis=-1)
    ```

    Note that `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small
    `x` and to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works
    mostly like the mean squared error, but will not be so strongly affected by
    the occasional wildly incorrect prediction.

    Example:

    >>> y_true = [[0., 1.], [0., 0.]]
    >>> y_pred = [[1., 1.], [0., 0.]]
    >>> loss = keras.losses.log_cosh(y_true, y_pred)
    0.108

    Args:
        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Logcosh error values with shape = `[batch_size, d0, .. dN-1]`.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    # Align ranks so broadcasting between labels and predictions is safe.
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    log2 = ops.convert_to_tensor(ops.log(2.0), dtype=y_pred.dtype)

    def _logcosh(x):
        # Numerically stable identity:
        # log(cosh(x)) = x + softplus(-2x) - log(2).
        return x + ops.softplus(x * -2.0) - log2

    return ops.mean(_logcosh(y_pred - y_true), axis=-1)
|
| 1947 |
+
|
| 1948 |
+
|
| 1949 |
+
@keras_export(
    [
        "keras.metrics.kl_divergence",
        "keras.losses.kl_divergence",
        # Legacy aliases
        "keras._legacy.losses.KLD",
        "keras._legacy.losses.kld",
        "keras._legacy.losses.kullback_leibler_divergence",
        "keras._legacy.metrics.KLD",
        "keras._legacy.metrics.kld",
        "keras._legacy.metrics.kullback_leibler_divergence",
    ]
)
def kl_divergence(y_true, y_pred):
    """Computes Kullback-Leibler divergence loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = y_true * log(y_true / y_pred)
    ```

    `y_true` and `y_pred` are expected to be probability
    distributions, with values between 0 and 1. They will get
    clipped to the `[0, 1]` range.

    Args:
        y_true: Tensor of true targets.
        y_pred: Tensor of predicted targets.

    Returns:
        KL Divergence loss values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.randint(0, 2, size=(2, 3)).astype(np.float32)
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.kl_divergence(y_true, y_pred)
    >>> assert loss.shape == (2,)
    >>> y_true = ops.clip(y_true, 1e-7, 1)
    >>> y_pred = ops.clip(y_pred, 1e-7, 1)
    >>> assert np.array_equal(
    ...     loss, np.sum(y_true * np.log(y_true / y_pred), axis=-1))
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, y_pred.dtype)
    # Clip both distributions into (0, 1] so the log and division below
    # cannot produce infinities or NaNs.
    y_true = ops.clip(y_true, backend.epsilon(), 1)
    y_pred = ops.clip(y_pred, backend.epsilon(), 1)
    return ops.sum(y_true * ops.log(y_true / y_pred), axis=-1)
|
| 1998 |
+
|
| 1999 |
+
|
| 2000 |
+
@keras_export(
    [
        "keras.metrics.poisson",
        "keras.losses.poisson",
    ]
)
def poisson(y_true, y_pred):
    """Computes the Poisson loss between y_true and y_pred.

    Formula:

    ```python
    loss = y_pred - y_true * log(y_pred)
    ```

    Args:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Poisson loss values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.randint(0, 2, size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.poisson(y_true, y_pred)
    >>> assert loss.shape == (2,)
    >>> y_pred = y_pred + 1e-7
    >>> assert np.allclose(
    ...     loss, np.mean(y_pred - y_true * np.log(y_pred), axis=-1),
    ...     atol=1e-5)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    epsilon = ops.convert_to_tensor(backend.epsilon(), dtype=y_pred.dtype)
    # Epsilon shifts the log argument away from zero to avoid -inf.
    log_pred = ops.log(y_pred + epsilon)
    return ops.mean(y_pred - y_true * log_pred, axis=-1)
|
| 2037 |
+
|
| 2038 |
+
|
| 2039 |
+
@keras_export(
    [
        "keras.metrics.categorical_crossentropy",
        "keras.losses.categorical_crossentropy",
    ]
)
def categorical_crossentropy(
    y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1
):
    """Computes the categorical crossentropy loss.

    Args:
        y_true: Tensor of one-hot true targets.
        y_pred: Tensor of predicted targets.
        from_logits: Whether `y_pred` is expected to be a logits tensor. By
            default, we assume that `y_pred` encodes a probability distribution.
        label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
            example, if `0.1`, use `0.1 / num_classes` for non-target labels
            and `0.9 + 0.1 / num_classes` for target labels.
        axis: Defaults to `-1`. The dimension along which the entropy is
            computed.

    Returns:
        Categorical crossentropy loss value.

    Example:

    >>> y_true = [[0, 1, 0], [0, 0, 1]]
    >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
    >>> loss = keras.losses.categorical_crossentropy(y_true, y_pred)
    >>> assert loss.shape == (2,)
    >>> loss
    array([0.0513, 2.303], dtype=float32)
    """
    # Guard against `axis=True/False`: `bool` is a subclass of `int`, so
    # it would otherwise be silently accepted as axis 1/0.
    if isinstance(axis, bool):
        raise ValueError(
            "`axis` must be of type `int`. "
            f"Received: axis={axis} of type {type(axis)}"
        )
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, y_pred.dtype)

    # A single output class almost certainly means the caller wanted
    # binary crossentropy; warn but keep going.
    if y_pred.shape[-1] == 1:
        warnings.warn(
            "In loss categorical_crossentropy, expected "
            "y_pred.shape to be (batch_size, num_classes) "
            f"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. "
            "Consider using 'binary_crossentropy' if you only have 2 classes.",
            SyntaxWarning,
            stacklevel=2,
        )

    if label_smoothing:
        # Redistribute `label_smoothing` probability mass uniformly
        # across all classes.
        num_classes = ops.cast(ops.shape(y_true)[-1], y_pred.dtype)
        smooth_share = label_smoothing / num_classes
        y_true = y_true * (1.0 - label_smoothing) + smooth_share

    return ops.categorical_crossentropy(
        y_true, y_pred, from_logits=from_logits, axis=axis
    )
|
| 2100 |
+
|
| 2101 |
+
|
| 2102 |
+
@keras_export(
    [
        "keras.metrics.categorical_focal_crossentropy",
        "keras.losses.categorical_focal_crossentropy",
    ]
)
def categorical_focal_crossentropy(
    y_true,
    y_pred,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
    label_smoothing=0.0,
    axis=-1,
):
    """Computes the categorical focal crossentropy loss.

    Args:
        y_true: Tensor of one-hot true targets.
        y_pred: Tensor of predicted targets.
        alpha: A weight balancing factor for all classes, default is `0.25` as
            mentioned in the reference. It can be a list of floats or a scalar.
            In the multi-class case, alpha may be set by inverse class
            frequency by using `compute_class_weight` from `sklearn.utils`.
        gamma: A focusing parameter, default is `2.0` as mentioned in the
            reference. It helps to gradually reduce the importance given to
            simple examples in a smooth manner. When `gamma` = 0, there is
            no focal effect on the categorical crossentropy.
        from_logits: Whether `y_pred` is expected to be a logits tensor. By
            default, we assume that `y_pred` encodes a probability
            distribution.
        label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
            example, if `0.1`, use `0.1 / num_classes` for non-target labels
            and `0.9 + 0.1 / num_classes` for target labels.
        axis: Defaults to `-1`. The dimension along which the entropy is
            computed.

    Returns:
        Categorical focal crossentropy loss value.

    Example:

    >>> y_true = [[0, 1, 0], [0, 0, 1]]
    >>> y_pred = [[0.05, 0.9, 0.05], [0.1, 0.85, 0.05]]
    >>> loss = keras.losses.categorical_focal_crossentropy(y_true, y_pred)
    >>> assert loss.shape == (2,)
    >>> loss
    array([2.63401289e-04, 6.75912094e-01], dtype=float32)
    """
    # Guard against `axis=True/False`: `bool` is a subclass of `int`, so
    # it would otherwise be silently accepted as axis 1/0.
    if isinstance(axis, bool):
        raise ValueError(
            "`axis` must be of type `int`. "
            f"Received: axis={axis} of type {type(axis)}"
        )
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, y_pred.dtype)

    # A single output class almost certainly means the caller wanted
    # binary crossentropy; warn but keep going.
    if y_pred.shape[-1] == 1:
        warnings.warn(
            "In loss categorical_focal_crossentropy, expected "
            "y_pred.shape to be (batch_size, num_classes) "
            f"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. "
            "Consider using 'binary_crossentropy' if you only have 2 classes.",
            SyntaxWarning,
            stacklevel=2,
        )

    if label_smoothing:
        # Redistribute `label_smoothing` probability mass uniformly
        # across all classes.
        num_classes = ops.cast(ops.shape(y_true)[-1], y_pred.dtype)
        smooth_share = label_smoothing / num_classes
        y_true = y_true * (1.0 - label_smoothing) + smooth_share

    if from_logits:
        y_pred = ops.softmax(y_pred, axis=axis)

    # Re-normalize predictions so every sample's class probabilities sum
    # to 1, then clip into (0, 1) so the log below stays finite.
    output = y_pred / ops.sum(y_pred, axis=axis, keepdims=True)
    output = ops.clip(output, backend.epsilon(), 1.0 - backend.epsilon())

    # Standard per-class cross entropy term.
    cce = -y_true * ops.log(output)

    # Focal modulation: (1 - p)^gamma downweights easy (high-confidence)
    # examples; alpha rebalances classes.
    modulating_factor = ops.power(1.0 - output, gamma)
    weighting_factor = ops.multiply(modulating_factor, alpha)

    focal_cce = ops.multiply(weighting_factor, cce)
    return ops.sum(focal_cce, axis=axis)
|
| 2196 |
+
|
| 2197 |
+
|
| 2198 |
+
@keras_export(
|
| 2199 |
+
[
|
| 2200 |
+
"keras.metrics.sparse_categorical_crossentropy",
|
| 2201 |
+
"keras.losses.sparse_categorical_crossentropy",
|
| 2202 |
+
]
|
| 2203 |
+
)
|
| 2204 |
+
def sparse_categorical_crossentropy(
|
| 2205 |
+
y_true, y_pred, from_logits=False, ignore_class=None, axis=-1
|
| 2206 |
+
):
|
| 2207 |
+
"""Computes the sparse categorical crossentropy loss.
|
| 2208 |
+
|
| 2209 |
+
Args:
|
| 2210 |
+
y_true: Ground truth values.
|
| 2211 |
+
y_pred: The predicted values.
|
| 2212 |
+
from_logits: Whether `y_pred` is expected to be a logits tensor. By
|
| 2213 |
+
default, we assume that `y_pred` encodes a probability distribution.
|
| 2214 |
+
ignore_class: Optional integer. The ID of a class to be ignored during
|
| 2215 |
+
loss computation. This is useful, for example, in segmentation
|
| 2216 |
+
problems featuring a "void" class (commonly -1 or 255) in
|
| 2217 |
+
segmentation maps. By default (`ignore_class=None`), all classes are
|
| 2218 |
+
considered.
|
| 2219 |
+
axis: Defaults to `-1`. The dimension along which the entropy is
|
| 2220 |
+
computed.
|
| 2221 |
+
|
| 2222 |
+
Returns:
|
| 2223 |
+
Sparse categorical crossentropy loss value.
|
| 2224 |
+
|
| 2225 |
+
Examples:
|
| 2226 |
+
|
| 2227 |
+
>>> y_true = [1, 2]
|
| 2228 |
+
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
|
| 2229 |
+
>>> loss = keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
|
| 2230 |
+
>>> assert loss.shape == (2,)
|
| 2231 |
+
>>> loss
|
| 2232 |
+
array([0.0513, 2.303], dtype=float32)
|
| 2233 |
+
"""
|
| 2234 |
+
|
| 2235 |
+
if len(y_true.shape) == len(y_pred.shape) and y_true.shape[-1] == 1:
|
| 2236 |
+
y_true = ops.squeeze(y_true, axis=-1)
|
| 2237 |
+
|
| 2238 |
+
if ignore_class is not None:
|
| 2239 |
+
res_shape = ops.shape(y_pred)[:-1]
|
| 2240 |
+
valid_mask = ops.not_equal(y_true, ops.cast(ignore_class, y_pred.dtype))
|
| 2241 |
+
y_true = y_true * ops.cast(valid_mask, y_true.dtype)
|
| 2242 |
+
y_pred = y_pred * ops.cast(
|
| 2243 |
+
ops.expand_dims(valid_mask, -1), y_pred.dtype
|
| 2244 |
+
)
|
| 2245 |
+
|
| 2246 |
+
res = ops.sparse_categorical_crossentropy(
|
| 2247 |
+
y_true,
|
| 2248 |
+
y_pred,
|
| 2249 |
+
from_logits=from_logits,
|
| 2250 |
+
axis=axis,
|
| 2251 |
+
)
|
| 2252 |
+
|
| 2253 |
+
if ignore_class is not None:
|
| 2254 |
+
valid_mask = ops.reshape(valid_mask, res_shape)
|
| 2255 |
+
res = ops.where(valid_mask, res, 0.0)
|
| 2256 |
+
backend.set_keras_mask(res, mask=valid_mask)
|
| 2257 |
+
|
| 2258 |
+
return res
|
| 2259 |
+
|
| 2260 |
+
|
| 2261 |
+
@keras_export(
|
| 2262 |
+
[
|
| 2263 |
+
"keras.metrics.binary_crossentropy",
|
| 2264 |
+
"keras.losses.binary_crossentropy",
|
| 2265 |
+
]
|
| 2266 |
+
)
|
| 2267 |
+
def binary_crossentropy(
|
| 2268 |
+
y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1
|
| 2269 |
+
):
|
| 2270 |
+
"""Computes the binary crossentropy loss.
|
| 2271 |
+
|
| 2272 |
+
Args:
|
| 2273 |
+
y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
|
| 2274 |
+
y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
|
| 2275 |
+
from_logits: Whether `y_pred` is expected to be a logits tensor. By
|
| 2276 |
+
default, we assume that `y_pred` encodes a probability distribution.
|
| 2277 |
+
label_smoothing: Float in `[0, 1]`. If > `0` then smooth the labels by
|
| 2278 |
+
squeezing them towards 0.5, that is,
|
| 2279 |
+
using `1. - 0.5 * label_smoothing` for the target class
|
| 2280 |
+
and `0.5 * label_smoothing` for the non-target class.
|
| 2281 |
+
axis: The axis along which the mean is computed. Defaults to `-1`.
|
| 2282 |
+
|
| 2283 |
+
Returns:
|
| 2284 |
+
Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`.
|
| 2285 |
+
|
| 2286 |
+
Example:
|
| 2287 |
+
|
| 2288 |
+
>>> y_true = [[0, 1], [0, 0]]
|
| 2289 |
+
>>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
|
| 2290 |
+
>>> loss = keras.losses.binary_crossentropy(y_true, y_pred)
|
| 2291 |
+
>>> assert loss.shape == (2,)
|
| 2292 |
+
>>> loss
|
| 2293 |
+
array([0.916 , 0.714], dtype=float32)
|
| 2294 |
+
"""
|
| 2295 |
+
y_pred = ops.convert_to_tensor(y_pred)
|
| 2296 |
+
y_true = ops.cast(y_true, y_pred.dtype)
|
| 2297 |
+
|
| 2298 |
+
if label_smoothing:
|
| 2299 |
+
y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing
|
| 2300 |
+
|
| 2301 |
+
return ops.mean(
|
| 2302 |
+
ops.binary_crossentropy(y_true, y_pred, from_logits=from_logits),
|
| 2303 |
+
axis=axis,
|
| 2304 |
+
)
|
| 2305 |
+
|
| 2306 |
+
|
| 2307 |
+
@keras_export(
|
| 2308 |
+
[
|
| 2309 |
+
"keras.metrics.binary_focal_crossentropy",
|
| 2310 |
+
"keras.losses.binary_focal_crossentropy",
|
| 2311 |
+
]
|
| 2312 |
+
)
|
| 2313 |
+
def binary_focal_crossentropy(
|
| 2314 |
+
y_true,
|
| 2315 |
+
y_pred,
|
| 2316 |
+
apply_class_balancing=False,
|
| 2317 |
+
alpha=0.25,
|
| 2318 |
+
gamma=2.0,
|
| 2319 |
+
from_logits=False,
|
| 2320 |
+
label_smoothing=0.0,
|
| 2321 |
+
axis=-1,
|
| 2322 |
+
):
|
| 2323 |
+
"""Computes the binary focal crossentropy loss.
|
| 2324 |
+
|
| 2325 |
+
According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
|
| 2326 |
+
helps to apply a focal factor to down-weight easy examples and focus more on
|
| 2327 |
+
hard examples. By default, the focal tensor is computed as follows:
|
| 2328 |
+
|
| 2329 |
+
`focal_factor = (1 - output) ** gamma` for class 1
|
| 2330 |
+
`focal_factor = output ** gamma` for class 0
|
| 2331 |
+
where `gamma` is a focusing parameter. When `gamma` = 0, there is no focal
|
| 2332 |
+
effect on the binary crossentropy loss.
|
| 2333 |
+
|
| 2334 |
+
If `apply_class_balancing == True`, this function also takes into account a
|
| 2335 |
+
weight balancing factor for the binary classes 0 and 1 as follows:
|
| 2336 |
+
|
| 2337 |
+
`weight = alpha` for class 1 (`target == 1`)
|
| 2338 |
+
`weight = 1 - alpha` for class 0
|
| 2339 |
+
where `alpha` is a float in the range of `[0, 1]`.
|
| 2340 |
+
|
| 2341 |
+
Args:
|
| 2342 |
+
y_true: Ground truth values, of shape `(batch_size, d0, .. dN)`.
|
| 2343 |
+
y_pred: The predicted values, of shape `(batch_size, d0, .. dN)`.
|
| 2344 |
+
apply_class_balancing: A bool, whether to apply weight balancing on the
|
| 2345 |
+
binary classes 0 and 1.
|
| 2346 |
+
alpha: A weight balancing factor for class 1, default is `0.25` as
|
| 2347 |
+
mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
|
| 2348 |
+
gamma: A focusing parameter, default is `2.0` as mentioned in the
|
| 2349 |
+
reference.
|
| 2350 |
+
from_logits: Whether `y_pred` is expected to be a logits tensor. By
|
| 2351 |
+
default, we assume that `y_pred` encodes a probability distribution.
|
| 2352 |
+
label_smoothing: Float in `[0, 1]`. If > `0` then smooth the labels by
|
| 2353 |
+
squeezing them towards 0.5, that is,
|
| 2354 |
+
using `1. - 0.5 * label_smoothing` for the target class
|
| 2355 |
+
and `0.5 * label_smoothing` for the non-target class.
|
| 2356 |
+
axis: The axis along which the mean is computed. Defaults to `-1`.
|
| 2357 |
+
|
| 2358 |
+
Returns:
|
| 2359 |
+
Binary focal crossentropy loss value
|
| 2360 |
+
with shape = `[batch_size, d0, .. dN-1]`.
|
| 2361 |
+
|
| 2362 |
+
Example:
|
| 2363 |
+
|
| 2364 |
+
>>> y_true = [[0, 1], [0, 0]]
|
| 2365 |
+
>>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
|
| 2366 |
+
>>> loss = keras.losses.binary_focal_crossentropy(
|
| 2367 |
+
... y_true, y_pred, gamma=2)
|
| 2368 |
+
>>> assert loss.shape == (2,)
|
| 2369 |
+
>>> loss
|
| 2370 |
+
array([0.330, 0.206], dtype=float32)
|
| 2371 |
+
"""
|
| 2372 |
+
y_pred = ops.convert_to_tensor(y_pred)
|
| 2373 |
+
y_true = ops.cast(y_true, y_pred.dtype)
|
| 2374 |
+
|
| 2375 |
+
if label_smoothing:
|
| 2376 |
+
y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing
|
| 2377 |
+
|
| 2378 |
+
if from_logits:
|
| 2379 |
+
y_pred = ops.sigmoid(y_pred)
|
| 2380 |
+
|
| 2381 |
+
bce = ops.binary_crossentropy(
|
| 2382 |
+
target=y_true,
|
| 2383 |
+
output=y_pred,
|
| 2384 |
+
from_logits=False,
|
| 2385 |
+
)
|
| 2386 |
+
|
| 2387 |
+
# Calculate focal factor
|
| 2388 |
+
p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
|
| 2389 |
+
focal_factor = ops.power(1.0 - p_t, gamma)
|
| 2390 |
+
|
| 2391 |
+
focal_bce = focal_factor * bce
|
| 2392 |
+
|
| 2393 |
+
if apply_class_balancing:
|
| 2394 |
+
weight = y_true * alpha + (1 - y_true) * (1 - alpha)
|
| 2395 |
+
focal_bce = weight * focal_bce
|
| 2396 |
+
|
| 2397 |
+
return ops.mean(focal_bce, axis=axis)
|
| 2398 |
+
|
| 2399 |
+
|
| 2400 |
+
@keras_export("keras.losses.ctc")
|
| 2401 |
+
def ctc(y_true, y_pred):
|
| 2402 |
+
"""CTC (Connectionist Temporal Classification) loss.
|
| 2403 |
+
|
| 2404 |
+
Args:
|
| 2405 |
+
y_true: A tensor of shape `(batch_size, max_length)` containing
|
| 2406 |
+
the true labels in integer format. `0` always represents
|
| 2407 |
+
the blank/mask index and should not be used for classes.
|
| 2408 |
+
y_pred: A tensor of shape `(batch_size, max_length, num_classes)`
|
| 2409 |
+
containing logits (the output of your model).
|
| 2410 |
+
They should *not* be normalized via softmax.
|
| 2411 |
+
"""
|
| 2412 |
+
if len(ops.shape(y_true)) != 2:
|
| 2413 |
+
raise ValueError(
|
| 2414 |
+
"Targets `y_true` are expected to be a tensor of shape "
|
| 2415 |
+
"`(batch_size, max_length)` in integer format. "
|
| 2416 |
+
f"Received: y_true.shape={ops.shape(y_true)}"
|
| 2417 |
+
)
|
| 2418 |
+
if len(ops.shape(y_pred)) != 3:
|
| 2419 |
+
raise ValueError(
|
| 2420 |
+
"Logits `y_pred` are expected to be a tensor of shape "
|
| 2421 |
+
"`(batch_size, max_length, num_classes)`. "
|
| 2422 |
+
f"Received: y_pred.shape={ops.shape(y_pred)}"
|
| 2423 |
+
)
|
| 2424 |
+
|
| 2425 |
+
mask_index = 0
|
| 2426 |
+
batch_length = ops.shape(y_pred)[0]
|
| 2427 |
+
input_length = ops.shape(y_pred)[1]
|
| 2428 |
+
input_length = input_length * ops.ones((batch_length,), dtype="int32")
|
| 2429 |
+
label_length = ops.cast(
|
| 2430 |
+
ops.sum(y_true != mask_index, axis=-1), dtype="int32"
|
| 2431 |
+
)
|
| 2432 |
+
|
| 2433 |
+
return ops.ctc_loss(
|
| 2434 |
+
y_true, y_pred, label_length, input_length, mask_index=mask_index
|
| 2435 |
+
)
|
| 2436 |
+
|
| 2437 |
+
|
| 2438 |
+
@keras_export("keras.losses.dice")
|
| 2439 |
+
def dice(y_true, y_pred, axis=None):
|
| 2440 |
+
"""Computes the Dice loss value between `y_true` and `y_pred`.
|
| 2441 |
+
|
| 2442 |
+
Formula:
|
| 2443 |
+
```python
|
| 2444 |
+
loss = 1 - (2 * sum(y_true * y_pred)) / (sum(y_true) + sum(y_pred))
|
| 2445 |
+
```
|
| 2446 |
+
|
| 2447 |
+
Args:
|
| 2448 |
+
y_true: tensor of true targets.
|
| 2449 |
+
y_pred: tensor of predicted targets.
|
| 2450 |
+
axis: tuple for which dimensions the loss is calculated
|
| 2451 |
+
|
| 2452 |
+
Returns:
|
| 2453 |
+
Dice loss value.
|
| 2454 |
+
"""
|
| 2455 |
+
y_pred = ops.convert_to_tensor(y_pred)
|
| 2456 |
+
y_true = ops.cast(y_true, y_pred.dtype)
|
| 2457 |
+
|
| 2458 |
+
inputs = y_true
|
| 2459 |
+
targets = y_pred
|
| 2460 |
+
|
| 2461 |
+
intersection = ops.sum(inputs * targets, axis=axis)
|
| 2462 |
+
dice = ops.divide(
|
| 2463 |
+
2.0 * intersection,
|
| 2464 |
+
ops.sum(y_true, axis=axis)
|
| 2465 |
+
+ ops.sum(y_pred, axis=axis)
|
| 2466 |
+
+ backend.epsilon(),
|
| 2467 |
+
)
|
| 2468 |
+
|
| 2469 |
+
return 1 - dice
|
| 2470 |
+
|
| 2471 |
+
|
| 2472 |
+
@keras_export("keras.losses.tversky")
|
| 2473 |
+
def tversky(y_true, y_pred, alpha=0.5, beta=0.5, axis=None):
|
| 2474 |
+
"""Computes the Tversky loss value between `y_true` and `y_pred`.
|
| 2475 |
+
|
| 2476 |
+
This loss function is weighted by the alpha and beta coefficients
|
| 2477 |
+
that penalize false positives and false negatives.
|
| 2478 |
+
|
| 2479 |
+
With `alpha=0.5` and `beta=0.5`, the loss value becomes equivalent to
|
| 2480 |
+
Dice Loss.
|
| 2481 |
+
|
| 2482 |
+
Args:
|
| 2483 |
+
y_true: tensor of true targets.
|
| 2484 |
+
y_pred: tensor of predicted targets.
|
| 2485 |
+
alpha: coefficient controlling incidence of false positives.
|
| 2486 |
+
beta: coefficient controlling incidence of false negatives.
|
| 2487 |
+
axis: tuple for which dimensions the loss is calculated.
|
| 2488 |
+
|
| 2489 |
+
Returns:
|
| 2490 |
+
Tversky loss value.
|
| 2491 |
+
|
| 2492 |
+
Reference:
|
| 2493 |
+
|
| 2494 |
+
- [Salehi et al., 2017](https://arxiv.org/abs/1706.05721)
|
| 2495 |
+
"""
|
| 2496 |
+
y_pred = ops.convert_to_tensor(y_pred)
|
| 2497 |
+
y_true = ops.cast(y_true, y_pred.dtype)
|
| 2498 |
+
|
| 2499 |
+
inputs = y_true
|
| 2500 |
+
targets = y_pred
|
| 2501 |
+
|
| 2502 |
+
intersection = ops.sum(inputs * targets, axis=axis)
|
| 2503 |
+
fp = ops.sum((1 - targets) * inputs, axis=axis)
|
| 2504 |
+
fn = ops.sum(targets * (1 - inputs), axis=axis)
|
| 2505 |
+
|
| 2506 |
+
tversky = ops.divide(
|
| 2507 |
+
intersection,
|
| 2508 |
+
intersection + fp * alpha + fn * beta + backend.epsilon(),
|
| 2509 |
+
)
|
| 2510 |
+
|
| 2511 |
+
return 1 - tversky
|
| 2512 |
+
|
| 2513 |
+
|
| 2514 |
+
@keras_export("keras.losses.circle")
|
| 2515 |
+
def circle(
|
| 2516 |
+
y_true,
|
| 2517 |
+
y_pred,
|
| 2518 |
+
ref_labels=None,
|
| 2519 |
+
ref_embeddings=None,
|
| 2520 |
+
remove_diagonal=True,
|
| 2521 |
+
gamma=80,
|
| 2522 |
+
margin=0.4,
|
| 2523 |
+
):
|
| 2524 |
+
"""Computes the Circle loss.
|
| 2525 |
+
|
| 2526 |
+
It is designed to minimize within-class distances and maximize between-class
|
| 2527 |
+
distances in L2 normalized embedding space.
|
| 2528 |
+
|
| 2529 |
+
Args:
|
| 2530 |
+
y_true: Tensor with ground truth labels in integer format.
|
| 2531 |
+
y_pred: Tensor with predicted L2 normalized embeddings.
|
| 2532 |
+
ref_labels: Optional integer tensor with labels for reference
|
| 2533 |
+
embeddings. If `None`, defaults to `y_true`.
|
| 2534 |
+
ref_embeddings: Optional tensor with L2 normalized reference embeddings.
|
| 2535 |
+
If `None`, defaults to `y_pred`.
|
| 2536 |
+
remove_diagonal: Boolean, whether to remove self-similarities from
|
| 2537 |
+
positive mask. Defaults to `True`.
|
| 2538 |
+
gamma: Float, scaling factor for the loss. Defaults to `80`.
|
| 2539 |
+
margin: Float, relaxation factor for the loss. Defaults to `0.4`.
|
| 2540 |
+
|
| 2541 |
+
Returns:
|
| 2542 |
+
Circle loss value.
|
| 2543 |
+
"""
|
| 2544 |
+
y_pred = ops.convert_to_tensor(y_pred)
|
| 2545 |
+
y_true = ops.cast(y_true, "int32")
|
| 2546 |
+
ref_embeddings = (
|
| 2547 |
+
y_pred
|
| 2548 |
+
if ref_embeddings is None
|
| 2549 |
+
else ops.convert_to_tensor(ref_embeddings)
|
| 2550 |
+
)
|
| 2551 |
+
ref_labels = y_true if ref_labels is None else ops.cast(ref_labels, "int32")
|
| 2552 |
+
|
| 2553 |
+
optim_pos = margin
|
| 2554 |
+
optim_neg = 1 + margin
|
| 2555 |
+
delta_pos = margin
|
| 2556 |
+
delta_neg = 1 - margin
|
| 2557 |
+
|
| 2558 |
+
pairwise_cosine_distances = 1 - ops.matmul(
|
| 2559 |
+
y_pred, ops.transpose(ref_embeddings)
|
| 2560 |
+
)
|
| 2561 |
+
|
| 2562 |
+
pairwise_cosine_distances = ops.maximum(pairwise_cosine_distances, 0.0)
|
| 2563 |
+
positive_mask, negative_mask = build_pos_neg_masks(
|
| 2564 |
+
y_true,
|
| 2565 |
+
ref_labels,
|
| 2566 |
+
remove_diagonal=remove_diagonal,
|
| 2567 |
+
)
|
| 2568 |
+
positive_mask = ops.cast(
|
| 2569 |
+
positive_mask, dtype=pairwise_cosine_distances.dtype
|
| 2570 |
+
)
|
| 2571 |
+
negative_mask = ops.cast(
|
| 2572 |
+
negative_mask, dtype=pairwise_cosine_distances.dtype
|
| 2573 |
+
)
|
| 2574 |
+
|
| 2575 |
+
pos_weights = optim_pos + pairwise_cosine_distances
|
| 2576 |
+
pos_weights = pos_weights * positive_mask
|
| 2577 |
+
pos_weights = ops.maximum(pos_weights, 0.0)
|
| 2578 |
+
neg_weights = optim_neg - pairwise_cosine_distances
|
| 2579 |
+
neg_weights = neg_weights * negative_mask
|
| 2580 |
+
neg_weights = ops.maximum(neg_weights, 0.0)
|
| 2581 |
+
|
| 2582 |
+
pos_dists = delta_pos - pairwise_cosine_distances
|
| 2583 |
+
neg_dists = delta_neg - pairwise_cosine_distances
|
| 2584 |
+
|
| 2585 |
+
pos_wdists = -1 * gamma * pos_weights * pos_dists
|
| 2586 |
+
neg_wdists = gamma * neg_weights * neg_dists
|
| 2587 |
+
|
| 2588 |
+
p_loss = ops.logsumexp(
|
| 2589 |
+
ops.where(positive_mask, pos_wdists, float("-inf")),
|
| 2590 |
+
axis=1,
|
| 2591 |
+
)
|
| 2592 |
+
n_loss = ops.logsumexp(
|
| 2593 |
+
ops.where(negative_mask, neg_wdists, float("-inf")),
|
| 2594 |
+
axis=1,
|
| 2595 |
+
)
|
| 2596 |
+
|
| 2597 |
+
circle_loss = ops.softplus(p_loss + n_loss)
|
| 2598 |
+
backend.set_keras_mask(circle_loss, circle_loss > 0)
|
| 2599 |
+
return circle_loss
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__init__.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
|
| 3 |
+
from keras.src.api_export import keras_export
|
| 4 |
+
from keras.src.metrics.accuracy_metrics import Accuracy
|
| 5 |
+
from keras.src.metrics.accuracy_metrics import BinaryAccuracy
|
| 6 |
+
from keras.src.metrics.accuracy_metrics import CategoricalAccuracy
|
| 7 |
+
from keras.src.metrics.accuracy_metrics import SparseCategoricalAccuracy
|
| 8 |
+
from keras.src.metrics.accuracy_metrics import SparseTopKCategoricalAccuracy
|
| 9 |
+
from keras.src.metrics.accuracy_metrics import TopKCategoricalAccuracy
|
| 10 |
+
from keras.src.metrics.confusion_metrics import AUC
|
| 11 |
+
from keras.src.metrics.confusion_metrics import FalseNegatives
|
| 12 |
+
from keras.src.metrics.confusion_metrics import FalsePositives
|
| 13 |
+
from keras.src.metrics.confusion_metrics import Precision
|
| 14 |
+
from keras.src.metrics.confusion_metrics import PrecisionAtRecall
|
| 15 |
+
from keras.src.metrics.confusion_metrics import Recall
|
| 16 |
+
from keras.src.metrics.confusion_metrics import RecallAtPrecision
|
| 17 |
+
from keras.src.metrics.confusion_metrics import SensitivityAtSpecificity
|
| 18 |
+
from keras.src.metrics.confusion_metrics import SpecificityAtSensitivity
|
| 19 |
+
from keras.src.metrics.confusion_metrics import TrueNegatives
|
| 20 |
+
from keras.src.metrics.confusion_metrics import TruePositives
|
| 21 |
+
from keras.src.metrics.correlation_metrics import ConcordanceCorrelation
|
| 22 |
+
from keras.src.metrics.correlation_metrics import PearsonCorrelation
|
| 23 |
+
from keras.src.metrics.f_score_metrics import F1Score
|
| 24 |
+
from keras.src.metrics.f_score_metrics import FBetaScore
|
| 25 |
+
from keras.src.metrics.hinge_metrics import CategoricalHinge
|
| 26 |
+
from keras.src.metrics.hinge_metrics import Hinge
|
| 27 |
+
from keras.src.metrics.hinge_metrics import SquaredHinge
|
| 28 |
+
from keras.src.metrics.iou_metrics import BinaryIoU
|
| 29 |
+
from keras.src.metrics.iou_metrics import IoU
|
| 30 |
+
from keras.src.metrics.iou_metrics import MeanIoU
|
| 31 |
+
from keras.src.metrics.iou_metrics import OneHotIoU
|
| 32 |
+
from keras.src.metrics.iou_metrics import OneHotMeanIoU
|
| 33 |
+
from keras.src.metrics.metric import Metric
|
| 34 |
+
from keras.src.metrics.probabilistic_metrics import BinaryCrossentropy
|
| 35 |
+
from keras.src.metrics.probabilistic_metrics import CategoricalCrossentropy
|
| 36 |
+
from keras.src.metrics.probabilistic_metrics import KLDivergence
|
| 37 |
+
from keras.src.metrics.probabilistic_metrics import Poisson
|
| 38 |
+
from keras.src.metrics.probabilistic_metrics import (
|
| 39 |
+
SparseCategoricalCrossentropy,
|
| 40 |
+
)
|
| 41 |
+
from keras.src.metrics.reduction_metrics import Mean
|
| 42 |
+
from keras.src.metrics.reduction_metrics import MeanMetricWrapper
|
| 43 |
+
from keras.src.metrics.reduction_metrics import Sum
|
| 44 |
+
from keras.src.metrics.regression_metrics import CosineSimilarity
|
| 45 |
+
from keras.src.metrics.regression_metrics import LogCoshError
|
| 46 |
+
from keras.src.metrics.regression_metrics import MeanAbsoluteError
|
| 47 |
+
from keras.src.metrics.regression_metrics import MeanAbsolutePercentageError
|
| 48 |
+
from keras.src.metrics.regression_metrics import MeanSquaredError
|
| 49 |
+
from keras.src.metrics.regression_metrics import MeanSquaredLogarithmicError
|
| 50 |
+
from keras.src.metrics.regression_metrics import R2Score
|
| 51 |
+
from keras.src.metrics.regression_metrics import RootMeanSquaredError
|
| 52 |
+
from keras.src.saving import serialization_lib
|
| 53 |
+
from keras.src.utils.naming import to_snake_case
|
| 54 |
+
|
| 55 |
+
ALL_OBJECTS = {
|
| 56 |
+
# Base
|
| 57 |
+
Metric,
|
| 58 |
+
Mean,
|
| 59 |
+
Sum,
|
| 60 |
+
MeanMetricWrapper,
|
| 61 |
+
# Regression
|
| 62 |
+
MeanSquaredError,
|
| 63 |
+
RootMeanSquaredError,
|
| 64 |
+
MeanAbsoluteError,
|
| 65 |
+
MeanAbsolutePercentageError,
|
| 66 |
+
MeanSquaredLogarithmicError,
|
| 67 |
+
CosineSimilarity,
|
| 68 |
+
LogCoshError,
|
| 69 |
+
R2Score,
|
| 70 |
+
# Classification
|
| 71 |
+
AUC,
|
| 72 |
+
FalseNegatives,
|
| 73 |
+
FalsePositives,
|
| 74 |
+
Precision,
|
| 75 |
+
PrecisionAtRecall,
|
| 76 |
+
Recall,
|
| 77 |
+
RecallAtPrecision,
|
| 78 |
+
SensitivityAtSpecificity,
|
| 79 |
+
SpecificityAtSensitivity,
|
| 80 |
+
TrueNegatives,
|
| 81 |
+
TruePositives,
|
| 82 |
+
# Correlation
|
| 83 |
+
ConcordanceCorrelation,
|
| 84 |
+
PearsonCorrelation,
|
| 85 |
+
# Hinge
|
| 86 |
+
Hinge,
|
| 87 |
+
SquaredHinge,
|
| 88 |
+
CategoricalHinge,
|
| 89 |
+
# Probabilistic
|
| 90 |
+
KLDivergence,
|
| 91 |
+
Poisson,
|
| 92 |
+
BinaryCrossentropy,
|
| 93 |
+
CategoricalCrossentropy,
|
| 94 |
+
SparseCategoricalCrossentropy,
|
| 95 |
+
# Accuracy
|
| 96 |
+
Accuracy,
|
| 97 |
+
BinaryAccuracy,
|
| 98 |
+
CategoricalAccuracy,
|
| 99 |
+
SparseCategoricalAccuracy,
|
| 100 |
+
TopKCategoricalAccuracy,
|
| 101 |
+
SparseTopKCategoricalAccuracy,
|
| 102 |
+
# F-Score
|
| 103 |
+
F1Score,
|
| 104 |
+
FBetaScore,
|
| 105 |
+
# IoU
|
| 106 |
+
IoU,
|
| 107 |
+
BinaryIoU,
|
| 108 |
+
MeanIoU,
|
| 109 |
+
OneHotIoU,
|
| 110 |
+
OneHotMeanIoU,
|
| 111 |
+
}
|
| 112 |
+
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
|
| 113 |
+
ALL_OBJECTS_DICT.update(
|
| 114 |
+
{to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS}
|
| 115 |
+
)
|
| 116 |
+
# TODO: Align with `tf.keras` and set the name attribute of metrics
|
| 117 |
+
# with the key name. Currently it uses default name of class definitions.
|
| 118 |
+
ALL_OBJECTS_DICT.update(
|
| 119 |
+
{
|
| 120 |
+
"bce": BinaryCrossentropy,
|
| 121 |
+
"BCE": BinaryCrossentropy,
|
| 122 |
+
"mse": MeanSquaredError,
|
| 123 |
+
"MSE": MeanSquaredError,
|
| 124 |
+
"mae": MeanAbsoluteError,
|
| 125 |
+
"MAE": MeanAbsoluteError,
|
| 126 |
+
"mape": MeanAbsolutePercentageError,
|
| 127 |
+
"MAPE": MeanAbsolutePercentageError,
|
| 128 |
+
"msle": MeanSquaredLogarithmicError,
|
| 129 |
+
"MSLE": MeanSquaredLogarithmicError,
|
| 130 |
+
}
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@keras_export("keras.metrics.serialize")
|
| 135 |
+
def serialize(metric):
|
| 136 |
+
"""Serializes metric function or `Metric` instance.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
metric: A Keras `Metric` instance or a metric function.
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
Metric configuration dictionary.
|
| 143 |
+
"""
|
| 144 |
+
return serialization_lib.serialize_keras_object(metric)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@keras_export("keras.metrics.deserialize")
|
| 148 |
+
def deserialize(config, custom_objects=None):
|
| 149 |
+
"""Deserializes a serialized metric class/function instance.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
config: Metric configuration.
|
| 153 |
+
custom_objects: Optional dictionary mapping names (strings)
|
| 154 |
+
to custom objects (classes and functions) to be
|
| 155 |
+
considered during deserialization.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
A Keras `Metric` instance or a metric function.
|
| 159 |
+
"""
|
| 160 |
+
return serialization_lib.deserialize_keras_object(
|
| 161 |
+
config,
|
| 162 |
+
module_objects=ALL_OBJECTS_DICT,
|
| 163 |
+
custom_objects=custom_objects,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@keras_export("keras.metrics.get")
|
| 168 |
+
def get(identifier):
|
| 169 |
+
"""Retrieves a Keras metric as a `function`/`Metric` class instance.
|
| 170 |
+
|
| 171 |
+
The `identifier` may be the string name of a metric function or class.
|
| 172 |
+
|
| 173 |
+
>>> metric = metrics.get("categorical_crossentropy")
|
| 174 |
+
>>> type(metric)
|
| 175 |
+
<class 'function'>
|
| 176 |
+
>>> metric = metrics.get("CategoricalCrossentropy")
|
| 177 |
+
>>> type(metric)
|
| 178 |
+
<class '...metrics.CategoricalCrossentropy'>
|
| 179 |
+
|
| 180 |
+
You can also specify `config` of the metric to this function by passing dict
|
| 181 |
+
containing `class_name` and `config` as an identifier. Also note that the
|
| 182 |
+
`class_name` must map to a `Metric` class
|
| 183 |
+
|
| 184 |
+
>>> identifier = {"class_name": "CategoricalCrossentropy",
|
| 185 |
+
... "config": {"from_logits": True}}
|
| 186 |
+
>>> metric = metrics.get(identifier)
|
| 187 |
+
>>> type(metric)
|
| 188 |
+
<class '...metrics.CategoricalCrossentropy'>
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
identifier: A metric identifier. One of None or string name of a metric
|
| 192 |
+
function/class or metric configuration dictionary or a metric
|
| 193 |
+
function or a metric class instance
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
A Keras metric as a `function`/ `Metric` class instance.
|
| 197 |
+
"""
|
| 198 |
+
if identifier is None:
|
| 199 |
+
return None
|
| 200 |
+
if isinstance(identifier, dict):
|
| 201 |
+
obj = deserialize(identifier)
|
| 202 |
+
elif isinstance(identifier, str):
|
| 203 |
+
obj = ALL_OBJECTS_DICT.get(identifier, None)
|
| 204 |
+
else:
|
| 205 |
+
obj = identifier
|
| 206 |
+
if callable(obj):
|
| 207 |
+
if inspect.isclass(obj):
|
| 208 |
+
obj = obj()
|
| 209 |
+
return obj
|
| 210 |
+
else:
|
| 211 |
+
raise ValueError(f"Could not interpret metric identifier: {identifier}")
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (5.7 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/accuracy_metrics.cpython-310.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/confusion_metrics.cpython-310.pyc
ADDED
|
Binary file (48.7 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/correlation_metrics.cpython-310.pyc
ADDED
|
Binary file (6.58 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/f_score_metrics.cpython-310.pyc
ADDED
|
Binary file (9.84 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/hinge_metrics.cpython-310.pyc
ADDED
|
Binary file (3.64 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/iou_metrics.cpython-310.pyc
ADDED
|
Binary file (24.6 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/metric.cpython-310.pyc
ADDED
|
Binary file (8.74 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/metrics_utils.cpython-310.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/probabilistic_metrics.cpython-310.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/reduction_metrics.cpython-310.pyc
ADDED
|
Binary file (6.78 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/regression_metrics.cpython-310.pyc
ADDED
|
Binary file (17.3 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/accuracy_metrics.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import backend
|
| 2 |
+
from keras.src import ops
|
| 3 |
+
from keras.src.api_export import keras_export
|
| 4 |
+
from keras.src.losses.loss import squeeze_or_expand_to_same_rank
|
| 5 |
+
from keras.src.metrics import reduction_metrics
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def accuracy(y_true, y_pred):
    """Return a 0/1 float tensor marking where `y_true` equals `y_pred`.

    Both inputs are converted to tensors sharing `y_pred`'s dtype and
    brought to a common rank before the element-wise comparison; the
    boolean result is cast to the backend's default float dtype.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    elementwise_match = ops.equal(y_true, y_pred)
    return ops.cast(elementwise_match, dtype=backend.floatx())
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@keras_export("keras.metrics.Accuracy")
class Accuracy(reduction_metrics.MeanMetricWrapper):
    """Calculates how often predictions equal labels.

    The metric keeps two accumulators, `total` and `count`, and reports
    their ratio: the fraction of elements where `y_pred` exactly equals
    `y_true`. With `sample_weight=None` every element counts equally; a
    weight of 0 masks an element out.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Examples:

    >>> m = keras.metrics.Accuracy()
    >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]])
    >>> m.result()
    0.75

    >>> m.reset_state()
    >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]],
    ...                sample_weight=[1, 1, 0, 0])
    >>> m.result()
    0.5

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='binary_crossentropy',
                  metrics=[keras.metrics.Accuracy()])
    ```
    """

    def __init__(self, name="accuracy", dtype=None):
        super().__init__(fn=accuracy, name=name, dtype=dtype)
        # Higher accuracy is better: tuners/callbacks should maximize it.
        self._direction = "up"

    def get_config(self):
        return {"name": self.name, "dtype": self.dtype}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@keras_export("keras.metrics.binary_accuracy")
def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Return a 0/1 float tensor marking correct binary predictions.

    Predictions strictly above `threshold` are treated as class 1, the
    rest as class 0, before being compared against `y_true`.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    # Binarize: strictly greater than the threshold counts as class 1.
    cutoff = ops.cast(threshold, y_pred.dtype)
    binarized = ops.cast(y_pred > cutoff, y_true.dtype)
    return ops.cast(ops.equal(y_true, binarized), dtype=backend.floatx())
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@keras_export("keras.metrics.BinaryAccuracy")
class BinaryAccuracy(reduction_metrics.MeanMetricWrapper):
    """Calculates how often predictions match binary labels.

    Keeps two accumulators, `total` and `count`, and reports their ratio:
    the fraction of elements where the thresholded `y_pred` equals
    `y_true`. With `sample_weight=None` every element counts equally; a
    weight of 0 masks an element out.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        threshold: (Optional) Float representing the threshold for deciding
            whether prediction values are 1 or 0.

    Example:

    >>> m = keras.metrics.BinaryAccuracy()
    >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]])
    >>> m.result()
    0.75

    >>> m.reset_state()
    >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]],
    ...                sample_weight=[1, 0, 0, 1])
    >>> m.result()
    0.5

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='binary_crossentropy',
                  metrics=[keras.metrics.BinaryAccuracy()])
    ```
    """

    def __init__(self, name="binary_accuracy", dtype=None, threshold=0.5):
        super().__init__(
            fn=binary_accuracy, name=name, dtype=dtype, threshold=threshold
        )
        self.threshold = threshold
        # Higher accuracy is better: tuners/callbacks should maximize it.
        self._direction = "up"

    def get_config(self):
        return {
            "name": self.name,
            "dtype": self.dtype,
            "threshold": self.threshold,
        }
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@keras_export("keras.metrics.categorical_accuracy")
def categorical_accuracy(y_true, y_pred):
    """Per-sample 0/1 indicator of one-hot label / prediction agreement.

    Both `y_true` (one-hot) and `y_pred` (scores or logits) are reduced
    to class indices via `argmax` and then compared element-wise.
    """
    # Reduce one-hot labels to class indices.
    y_true = ops.argmax(y_true, axis=-1)

    needs_reshape = False
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)

    original_true_shape = ops.shape(y_true)
    y_pred_rank = len(y_pred.shape)
    y_true_rank = len(y_true.shape)

    # If y_true still carries a trailing length-1 axis (i.e. its rank after
    # argmax equals y_pred's rank), squeeze it and remember to restore the
    # shape on the way out.
    ranks_known = (y_true_rank is not None) and (y_pred_rank is not None)
    if ranks_known and (y_true_rank == y_pred_rank):
        y_true = ops.squeeze(y_true, -1)
        needs_reshape = True
    y_pred = ops.argmax(y_pred, axis=-1)

    # Align dtypes before comparing indices.
    if y_pred.dtype is not y_true.dtype:
        y_pred = ops.cast(y_pred, dtype=y_true.dtype)
    matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
    if needs_reshape:
        matches = ops.reshape(matches, original_true_shape)
    return matches
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@keras_export("keras.metrics.CategoricalAccuracy")
class CategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
    """Calculates how often predictions match one-hot labels.

    You can provide logits of classes as `y_pred`, since argmax of
    logits and probabilities are same.

    Keeps two accumulators, `total` and `count`, and reports their ratio:
    the fraction of samples whose predicted class matches the one-hot
    label. `y_pred` and `y_true` should be passed in as vectors of
    probabilities, rather than as labels; use `ops.one_hot` to expand
    `y_true` as a vector if necessary.

    With `sample_weight=None` every sample counts equally; a weight of 0
    masks a sample out.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.CategoricalAccuracy()
    >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
    ...                 [0.05, 0.95, 0]])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
    ...                 [0.05, 0.95, 0]],
    ...                sample_weight=[0.7, 0.3])
    >>> m.result()
    0.3

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='categorical_crossentropy',
                  metrics=[keras.metrics.CategoricalAccuracy()])
    ```
    """

    def __init__(self, name="categorical_accuracy", dtype=None):
        super().__init__(fn=categorical_accuracy, name=name, dtype=dtype)
        # Higher accuracy is better: tuners/callbacks should maximize it.
        self._direction = "up"

    def get_config(self):
        return {"name": self.name, "dtype": self.dtype}
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@keras_export("keras.metrics.sparse_categorical_accuracy")
def sparse_categorical_accuracy(y_true, y_pred):
    """Per-sample 0/1 indicator of integer-label / prediction agreement.

    `y_true` holds integer class indices; `y_pred` holds per-class scores
    or logits, which are reduced via `argmax` before comparison.
    """
    needs_reshape = False
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    original_true_shape = ops.shape(y_true)
    y_pred_rank = len(y_pred.shape)
    y_true_rank = len(y_true.shape)

    # If y_true is (num_samples, 1), squeeze to (num_samples,) and restore
    # the original shape at the end.
    ranks_known = (y_true_rank is not None) and (y_pred_rank is not None)
    if (
        ranks_known
        and (y_true_rank == y_pred_rank)
        and ops.shape(y_true)[-1] == 1
    ):
        y_true = ops.squeeze(y_true, -1)
        needs_reshape = True
    y_pred = ops.argmax(y_pred, axis=-1)

    # Align dtypes before comparing indices.
    if y_pred.dtype is not y_true.dtype:
        y_pred = ops.cast(y_pred, y_true.dtype)
    matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
    if needs_reshape:
        matches = ops.reshape(matches, original_true_shape)
    # Drop a trailing length-1 axis so the result is (num_samples,).
    if len(matches.shape) > 1 and matches.shape[-1] == 1:
        matches = ops.squeeze(matches, -1)
    return matches
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
@keras_export("keras.metrics.SparseCategoricalAccuracy")
class SparseCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
    """Calculates how often predictions match integer labels.

    ```python
    acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1))
    ```

    You can provide logits of classes as `y_pred`, since argmax of
    logits and probabilities are same.

    Keeps two accumulators, `total` and `count`, and reports their ratio:
    the fraction of samples whose predicted class matches the integer
    label. With `sample_weight=None` every sample counts equally; a
    weight of 0 masks a sample out.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.SparseCategoricalAccuracy()
    >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]],
    ...                sample_weight=[0.7, 0.3])
    >>> m.result()
    0.3

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='sparse_categorical_crossentropy',
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])
    ```
    """

    def __init__(self, name="sparse_categorical_accuracy", dtype=None):
        super().__init__(fn=sparse_categorical_accuracy, name=name, dtype=dtype)
        # Higher accuracy is better: tuners/callbacks should maximize it.
        self._direction = "up"

    def get_config(self):
        return {"name": self.name, "dtype": self.dtype}
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
@keras_export("keras.metrics.top_k_categorical_accuracy")
def top_k_categorical_accuracy(y_true, y_pred, k=5):
    """Per-sample 0/1 indicator of whether the one-hot true class is among
    the top-`k` scored classes of `y_pred`."""
    needs_reshape = False
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    # Reduce one-hot labels to class indices.
    y_true = ops.argmax(y_true, axis=-1)
    y_true_rank = len(y_true.shape)
    y_pred_rank = len(y_pred.shape)
    original_true_shape = ops.shape(y_true)

    # Flatten y_pred to (num_samples, num_classes) and y_true to
    # (num_samples,) so `in_top_k` sees 2D scores and 1D labels.
    if (y_true_rank is not None) and (y_pred_rank is not None):
        if y_pred_rank > 2:
            y_pred = ops.reshape(y_pred, [-1, y_pred.shape[-1]])
        if y_true_rank > 1:
            needs_reshape = True
            y_true = ops.reshape(y_true, [-1])

    hits = ops.in_top_k(ops.cast(y_true, "int32"), y_pred, k=k)
    matches = ops.cast(hits, dtype=backend.floatx())

    # Restore the caller's original label shape.
    if needs_reshape:
        matches = ops.reshape(matches, original_true_shape)

    return matches
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
@keras_export("keras.metrics.TopKCategoricalAccuracy")
class TopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
    """Computes how often targets are in the top `K` predictions.

    Args:
        k: (Optional) Number of top elements to look at for computing accuracy.
            Defaults to `5`.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.TopKCategoricalAccuracy(k=1)
    >>> m.update_state([[0, 0, 1], [0, 1, 0]],
    ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([[0, 0, 1], [0, 1, 0]],
    ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
    ...                sample_weight=[0.7, 0.3])
    >>> m.result()
    0.3

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='categorical_crossentropy',
                  metrics=[keras.metrics.TopKCategoricalAccuracy()])
    ```
    """

    def __init__(self, k=5, name="top_k_categorical_accuracy", dtype=None):
        super().__init__(
            fn=top_k_categorical_accuracy,
            name=name,
            dtype=dtype,
            k=k,
        )
        self.k = k
        # Higher accuracy is better: tuners/callbacks should maximize it.
        self._direction = "up"

    def get_config(self):
        return {"name": self.name, "dtype": self.dtype, "k": self.k}
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
@keras_export("keras.metrics.sparse_top_k_categorical_accuracy")
def sparse_top_k_categorical_accuracy(
    y_true, y_pred, k=5, from_sorted_ids=False
):
    """Computes how often integer targets are in the top `K` predictions.

    Args:
        y_true: A tensor of shape `(batch_size)` representing indices or IDs
            of true categories.
        y_pred: If `from_sorted_ids=False`, a tensor of shape
            `(batch_size, num_categories)` containing the scores for each
            sample for all possible categories. If `from_sorted_ids=True`, a
            tensor of shape `(batch_size, N)` containing indices or IDs of
            the top `N` categories in order from highest score to lowest
            score.
        k: (Optional) Number of top elements to look at for computing
            accuracy. Defaults to `5`.
        from_sorted_ids: (Optional) Whether `y_pred` is sorted category IDs
            or scores for all categories (the default).

    Returns:
        A tensor with the same shape as `y_true` containing ones where
        `y_true` is in the top `k` and zeros elsewhere.
    """
    needs_reshape = False
    y_pred = ops.convert_to_tensor(y_pred)
    # Sorted IDs must be directly comparable to y_pred entries; otherwise
    # `in_top_k` expects int32 labels.
    labels_dtype = y_pred.dtype if from_sorted_ids else "int32"
    y_true = ops.convert_to_tensor(y_true, dtype=labels_dtype)
    y_true_rank = len(y_true.shape)
    y_pred_rank = len(y_pred.shape)
    original_true_shape = ops.shape(y_true)

    # Flatten y_pred to (num_samples, num_categories) and y_true to
    # (num_samples,).
    if (y_true_rank is not None) and (y_pred_rank is not None):
        if y_pred_rank > 2:
            y_pred = ops.reshape(y_pred, [-1, y_pred.shape[-1]])
        if y_true_rank > 1:
            needs_reshape = True
            y_true = ops.reshape(y_true, [-1])

    if from_sorted_ids:
        # The first k columns are assumed to hold the k best categories;
        # reducing with `any` counts duplicate hits only once.
        matches = ops.any(
            ops.equal(ops.expand_dims(y_true, axis=1), y_pred[:, :k]), axis=1
        )
    else:
        matches = ops.in_top_k(y_true, y_pred, k=k)

    matches = ops.cast(matches, dtype=backend.floatx())

    # Restore the caller's original label shape.
    if needs_reshape:
        matches = ops.reshape(matches, original_true_shape)

    return matches
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
@keras_export("keras.metrics.SparseTopKCategoricalAccuracy")
class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
    """Computes how often integer targets are in the top `K` predictions.

    By default, the arguments expected by `update_state()` are:
    - `y_true`: a tensor of shape `(batch_size)` representing indices of true
      categories.
    - `y_pred`: a tensor of shape `(batch_size, num_categories)` containing
      the scores for each sample for all possible categories.

    With `from_sorted_ids=True`, the arguments expected by `update_state` are:
    - `y_true`: a tensor of shape `(batch_size)` representing indices or IDs
      of true categories.
    - `y_pred`: a tensor of shape `(batch_size, N)` containing the indices or
      IDs of the top `N` categories sorted in order from highest score to
      lowest score. `N` must be greater or equal to `k`.

    The `from_sorted_ids=True` option can be more efficient when the set of
    categories is very large and the model has an optimized way to retrieve
    the top ones either without scoring or without maintaining the scores for
    all the possible categories.

    Args:
        k: (Optional) Number of top elements to look at for computing
            accuracy. Defaults to `5`.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        from_sorted_ids: (Optional) When `False`, the default, the tensor
            passed in `y_pred` contains the unsorted scores of all possible
            categories. When `True`, `y_pred` contains the indices or IDs for
            the top categories.

    Example:

    >>> m = keras.metrics.SparseTopKCategoricalAccuracy(k=1)
    >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
    ...                sample_weight=[0.7, 0.3])
    >>> m.result()
    0.3

    >>> m = keras.metrics.SparseTopKCategoricalAccuracy(k=1,
    ...                                                 from_sorted_ids=True)
    >>> m.update_state([2, 1], [[1, 0, 3], [1, 2, 3]])
    >>> m.result()
    0.5

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='sparse_categorical_crossentropy',
                  metrics=[keras.metrics.SparseTopKCategoricalAccuracy()])
    ```
    """

    def __init__(
        self,
        k=5,
        name="sparse_top_k_categorical_accuracy",
        dtype=None,
        from_sorted_ids=False,
    ):
        super().__init__(
            fn=sparse_top_k_categorical_accuracy,
            name=name,
            dtype=dtype,
            k=k,
            from_sorted_ids=from_sorted_ids,
        )
        self.k = k
        self.from_sorted_ids = from_sorted_ids
        # Higher accuracy is better: tuners/callbacks should maximize it.
        self._direction = "up"

    def get_config(self):
        config = {"name": self.name, "dtype": self.dtype, "k": self.k}
        # Serialize the flag only when it differs from the default, keeping
        # configs minimal and backward-compatible.
        if self.from_sorted_ids:
            config["from_sorted_ids"] = True
        return config
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/confusion_metrics.py
ADDED
|
@@ -0,0 +1,1576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from keras.src import activations
|
| 4 |
+
from keras.src import backend
|
| 5 |
+
from keras.src import initializers
|
| 6 |
+
from keras.src import ops
|
| 7 |
+
from keras.src.api_export import keras_export
|
| 8 |
+
from keras.src.metrics import metrics_utils
|
| 9 |
+
from keras.src.metrics.metric import Metric
|
| 10 |
+
from keras.src.utils.python_utils import to_list
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class _ConfusionMatrixConditionCount(Metric):
    """Counts how often a single confusion-matrix condition occurs.

    Args:
        confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix`
            conditions.
        thresholds: (Optional) Defaults to `0.5`. A float value or a python
            list / tuple of float threshold values in `[0, 1]`. Each
            threshold is compared with prediction values to determine the
            truth value of predictions (i.e., above the threshold is `True`,
            below is `False`). One metric value is generated for each
            threshold value.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
    """

    def __init__(
        self, confusion_matrix_cond, thresholds=None, name=None, dtype=None
    ):
        super().__init__(name=name, dtype=dtype)
        self._confusion_matrix_cond = confusion_matrix_cond
        # Keep the raw constructor argument for `get_config()`.
        self.init_thresholds = thresholds
        self.thresholds = metrics_utils.parse_init_thresholds(
            thresholds, default_threshold=0.5
        )
        self._thresholds_distributed_evenly = (
            metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
        )
        # One accumulator slot per threshold value.
        self.accumulator = self.add_variable(
            name="accumulator",
            shape=(len(self.thresholds),),
            initializer=initializers.Zeros(),
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates the metric statistics.

        Args:
            y_true: The ground truth values.
            y_pred: The predicted values.
            sample_weight: Optional weighting of each example. Defaults to
                `1`. Can be a tensor whose rank is either 0, or the same
                rank as `y_true`, and must be broadcastable to `y_true`.
        """
        variables_to_update = {self._confusion_matrix_cond: self.accumulator}
        return metrics_utils.update_confusion_matrix_variables(
            variables_to_update,
            y_true,
            y_pred,
            thresholds=self.thresholds,
            thresholds_distributed_evenly=self._thresholds_distributed_evenly,
            sample_weight=sample_weight,
        )

    def result(self):
        # With a single threshold, expose a scalar rather than a
        # length-1 vector.
        value = (
            self.accumulator[0]
            if len(self.thresholds) == 1
            else self.accumulator
        )
        return backend.convert_to_tensor(value)

    def get_config(self):
        return {**super().get_config(), "thresholds": self.init_thresholds}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@keras_export("keras.metrics.FalsePositives")
class FalsePositives(_ConfusionMatrixConditionCount):
    """Calculates the number of false positives.

    When `sample_weight` is given, this metric accumulates the sum of the
    weights of the false positives rather than a raw count. A single local
    variable, `accumulator`, keeps track of this total.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
        thresholds: (Optional) Defaults to `0.5`. A float value, or a Python
            list/tuple of float threshold values in `[0, 1]`. A threshold is
            compared with prediction values to determine the truth value of
            predictions (i.e., above the threshold is `True`, below is
            `False`). If used with a loss function that sets
            `from_logits=True` (i.e. no sigmoid applied to predictions),
            `thresholds` should be set to 0. One metric value is generated
            for each threshold value.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Examples:

    >>> m = keras.metrics.FalsePositives()
    >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1])
    >>> m.result()
    2.0

    >>> m.reset_state()
    >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0])
    >>> m.result()
    1.0
    """

    def __init__(self, thresholds=None, name=None, dtype=None):
        super().__init__(
            metrics_utils.ConfusionMatrix.FALSE_POSITIVES,
            thresholds=thresholds,
            name=name,
            dtype=dtype,
        )
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@keras_export("keras.metrics.FalseNegatives")
class FalseNegatives(_ConfusionMatrixConditionCount):
    """Calculates the number of false negatives.

    When `sample_weight` is given, this metric accumulates the sum of the
    weights of the false negatives rather than a raw count. A single local
    variable, `accumulator`, keeps track of this total.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
        thresholds: (Optional) Defaults to `0.5`. A float value, or a Python
            list/tuple of float threshold values in `[0, 1]`. A threshold is
            compared with prediction values to determine the truth value of
            predictions (i.e., above the threshold is `True`, below is
            `False`). If used with a loss function that sets
            `from_logits=True` (i.e. no sigmoid applied to predictions),
            `thresholds` should be set to 0. One metric value is generated
            for each threshold value.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.FalseNegatives()
    >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
    >>> m.result()
    2.0

    >>> m.reset_state()
    >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0])
    >>> m.result()
    1.0
    """

    def __init__(self, thresholds=None, name=None, dtype=None):
        super().__init__(
            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES,
            thresholds=thresholds,
            name=name,
            dtype=dtype,
        )
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@keras_export("keras.metrics.TrueNegatives")
class TrueNegatives(_ConfusionMatrixConditionCount):
    """Calculates the number of true negatives.

    When `sample_weight` is given, this metric accumulates the sum of the
    weights of the true negatives rather than a raw count. A single local
    variable, `accumulator`, keeps track of this total.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
        thresholds: (Optional) Defaults to `0.5`. A float value, or a Python
            list/tuple of float threshold values in `[0, 1]`. A threshold is
            compared with prediction values to determine the truth value of
            predictions (i.e., above the threshold is `True`, below is
            `False`). If used with a loss function that sets
            `from_logits=True` (i.e. no sigmoid applied to predictions),
            `thresholds` should be set to 0. One metric value is generated
            for each threshold value.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.TrueNegatives()
    >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0])
    >>> m.result()
    2.0

    >>> m.reset_state()
    >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0])
    >>> m.result()
    1.0
    """

    def __init__(self, thresholds=None, name=None, dtype=None):
        super().__init__(
            metrics_utils.ConfusionMatrix.TRUE_NEGATIVES,
            thresholds=thresholds,
            name=name,
            dtype=dtype,
        )
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@keras_export("keras.metrics.TruePositives")
class TruePositives(_ConfusionMatrixConditionCount):
    """Calculates the number of true positives.

    When `sample_weight` is given, this metric accumulates the sum of the
    weights of the true positives rather than a raw count. A single local
    variable, `true_positives`, keeps track of this total.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
        thresholds: (Optional) Defaults to `0.5`. A float value, or a Python
            list/tuple of float threshold values in `[0, 1]`. A threshold is
            compared with prediction values to determine the truth value of
            predictions (i.e., above the threshold is `True`, below is
            `False`). If used with a loss function that sets
            `from_logits=True` (i.e. no sigmoid applied to predictions),
            `thresholds` should be set to 0. One metric value is generated
            for each threshold value.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.TruePositives()
    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
    >>> m.result()
    2.0

    >>> m.reset_state()
    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
    >>> m.result()
    1.0
    """

    def __init__(self, thresholds=None, name=None, dtype=None):
        super().__init__(
            metrics_utils.ConfusionMatrix.TRUE_POSITIVES,
            thresholds=thresholds,
            name=name,
            dtype=dtype,
        )
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
@keras_export("keras.metrics.Precision")
class Precision(Metric):
    """Computes the precision of the predictions with respect to the labels.

    Two local variables, `true_positives` and `false_positives`, are
    maintained and combined into `precision`: an idempotent division of
    `true_positives` by the sum of `true_positives` and `false_positives`.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    If `top_k` is set, we'll calculate precision as how often on average a
    class among the top-k classes with the highest predicted values of a
    batch entry is correct and can be found in the label for that entry.

    If `class_id` is specified, we calculate precision by considering only
    the entries in the batch for which `class_id` is above the threshold
    and/or in the top-k highest predictions, and computing the fraction of
    them for which `class_id` is indeed a correct label.

    Args:
        thresholds: (Optional) A float value, or a Python list/tuple of
            float threshold values in `[0, 1]`. A threshold is compared with
            prediction values to determine the truth value of predictions
            (i.e., above the threshold is `True`, below is `False`). If used
            with a loss function that sets `from_logits=True` (i.e. no
            sigmoid applied to predictions), `thresholds` should be set to
            0. One metric value is generated for each threshold value. If
            neither `thresholds` nor `top_k` are set, the default is to
            calculate precision with `thresholds=0.5`.
        top_k: (Optional) Unset by default. An int value specifying the
            top-k predictions to consider when calculating precision.
        class_id: (Optional) Integer class ID for which we want binary
            metrics. This must be in the half-open interval
            `[0, num_classes)`, where `num_classes` is the last dimension of
            predictions.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.Precision()
    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
    >>> m.result()
    0.6666667

    >>> m.reset_state()
    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
    >>> m.result()
    1.0

    >>> # With top_k=2, it will calculate precision over y_true[:2]
    >>> # and y_pred[:2]
    >>> m = keras.metrics.Precision(top_k=2)
    >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
    >>> m.result()
    0.0

    >>> # With top_k=4, it will calculate precision over y_true[:4]
    >>> # and y_pred[:4]
    >>> m = keras.metrics.Precision(top_k=4)
    >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
    >>> m.result()
    0.5

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='binary_crossentropy',
                  metrics=[keras.metrics.Precision()])
    ```

    Usage with a loss with `from_logits=True`:

    ```python
    model.compile(optimizer='adam',
                  loss=keras.losses.BinaryCrossentropy(from_logits=True),
                  metrics=[keras.metrics.Precision(thresholds=0)])
    ```
    """

    def __init__(
        self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None
    ):
        super().__init__(name=name, dtype=dtype)
        # Metric should be maximized during optimization.
        self._direction = "up"

        self.init_thresholds = thresholds
        self.top_k = top_k
        self.class_id = class_id

        # When `top_k` is active, every prediction counts regardless of its
        # value, hence the -inf default threshold.
        if top_k is None:
            default_threshold = 0.5
        else:
            default_threshold = metrics_utils.NEG_INF
        self.thresholds = metrics_utils.parse_init_thresholds(
            thresholds, default_threshold=default_threshold
        )
        self._thresholds_distributed_evenly = (
            metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
        )
        self.true_positives = self.add_variable(
            name="true_positives",
            shape=(len(self.thresholds),),
            initializer=initializers.Zeros(),
        )
        self.false_positives = self.add_variable(
            name="false_positives",
            shape=(len(self.thresholds),),
            initializer=initializers.Zeros(),
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates true positive and false positive statistics.

        Args:
            y_true: The ground truth values, with the same dimensions as
                `y_pred`. Will be cast to `bool`.
            y_pred: The predicted values. Each element must be in the range
                `[0, 1]`.
            sample_weight: Optional weighting of each example. Defaults to
                `1`. Can be a tensor whose rank is either 0, or the same
                rank as `y_true`, and must be broadcastable to `y_true`.
        """
        variables_to_update = {
            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
            metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,  # noqa: E501
        }
        metrics_utils.update_confusion_matrix_variables(
            variables_to_update,
            y_true,
            y_pred,
            thresholds=self.thresholds,
            thresholds_distributed_evenly=self._thresholds_distributed_evenly,
            top_k=self.top_k,
            class_id=self.class_id,
            sample_weight=sample_weight,
        )

    def result(self):
        precision = ops.divide_no_nan(
            self.true_positives,
            ops.add(self.true_positives, self.false_positives),
        )
        # A single threshold yields a scalar result.
        if len(self.thresholds) == 1:
            return precision[0]
        return precision

    def reset_state(self):
        num_thresholds = len(to_list(self.thresholds))
        zeros = ops.zeros((num_thresholds,))
        self.true_positives.assign(zeros)
        self.false_positives.assign(zeros)

    def get_config(self):
        return {
            **super().get_config(),
            "thresholds": self.init_thresholds,
            "top_k": self.top_k,
            "class_id": self.class_id,
        }
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
@keras_export("keras.metrics.Recall")
class Recall(Metric):
    """Computes the recall of the predictions with respect to the labels.

    Two local variables, `true_positives` and `false_negatives`, are
    maintained and combined into `recall`: an idempotent division of
    `true_positives` by the sum of `true_positives` and `false_negatives`.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    If `top_k` is set, recall will be computed as how often on average a
    class among the labels of a batch entry is in the top-k predictions.

    If `class_id` is specified, we calculate recall by considering only the
    entries in the batch for which `class_id` is in the label, and computing
    the fraction of them for which `class_id` is above the threshold and/or
    in the top-k predictions.

    Args:
        thresholds: (Optional) A float value, or a Python list/tuple of
            float threshold values in `[0, 1]`. A threshold is compared with
            prediction values to determine the truth value of predictions
            (i.e., above the threshold is `True`, below is `False`). If used
            with a loss function that sets `from_logits=True` (i.e. no
            sigmoid applied to predictions), `thresholds` should be set to
            0. One metric value is generated for each threshold value. If
            neither `thresholds` nor `top_k` are set, the default is to
            calculate recall with `thresholds=0.5`.
        top_k: (Optional) Unset by default. An int value specifying the
            top-k predictions to consider when calculating recall.
        class_id: (Optional) Integer class ID for which we want binary
            metrics. This must be in the half-open interval
            `[0, num_classes)`, where `num_classes` is the last dimension of
            predictions.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.Recall()
    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
    >>> m.result()
    0.6666667

    >>> m.reset_state()
    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
    >>> m.result()
    1.0

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='binary_crossentropy',
                  metrics=[keras.metrics.Recall()])
    ```

    Usage with a loss with `from_logits=True`:

    ```python
    model.compile(optimizer='adam',
                  loss=keras.losses.BinaryCrossentropy(from_logits=True),
                  metrics=[keras.metrics.Recall(thresholds=0)])
    ```
    """

    def __init__(
        self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None
    ):
        super().__init__(name=name, dtype=dtype)
        # Metric should be maximized during optimization.
        self._direction = "up"

        self.init_thresholds = thresholds
        self.top_k = top_k
        self.class_id = class_id

        # When `top_k` is active, every prediction counts regardless of its
        # value, hence the -inf default threshold.
        if top_k is None:
            default_threshold = 0.5
        else:
            default_threshold = metrics_utils.NEG_INF
        self.thresholds = metrics_utils.parse_init_thresholds(
            thresholds, default_threshold=default_threshold
        )
        self._thresholds_distributed_evenly = (
            metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
        )
        self.true_positives = self.add_variable(
            name="true_positives",
            shape=(len(self.thresholds),),
            initializer=initializers.Zeros(),
        )
        self.false_negatives = self.add_variable(
            name="false_negatives",
            shape=(len(self.thresholds),),
            initializer=initializers.Zeros(),
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates true positive and false negative statistics.

        Args:
            y_true: The ground truth values, with the same dimensions as
                `y_pred`. Will be cast to `bool`.
            y_pred: The predicted values. Each element must be in the range
                `[0, 1]`.
            sample_weight: Optional weighting of each example. Defaults to
                `1`. Can be a tensor whose rank is either 0, or the same
                rank as `y_true`, and must be broadcastable to `y_true`.
        """
        variables_to_update = {
            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,  # noqa: E501
        }
        metrics_utils.update_confusion_matrix_variables(
            variables_to_update,
            y_true,
            y_pred,
            thresholds=self.thresholds,
            thresholds_distributed_evenly=self._thresholds_distributed_evenly,
            top_k=self.top_k,
            class_id=self.class_id,
            sample_weight=sample_weight,
        )

    def result(self):
        recall = ops.divide_no_nan(
            self.true_positives,
            ops.add(self.true_positives, self.false_negatives),
        )
        # A single threshold yields a scalar result.
        if len(self.thresholds) == 1:
            return recall[0]
        return recall

    def reset_state(self):
        num_thresholds = len(to_list(self.thresholds))
        zeros = ops.zeros((num_thresholds,))
        self.true_positives.assign(zeros)
        self.false_negatives.assign(zeros)

    def get_config(self):
        return {
            **super().get_config(),
            "thresholds": self.init_thresholds,
            "top_k": self.top_k,
            "class_id": self.class_id,
        }
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
class SensitivitySpecificityBase(Metric):
    """Abstract base class for computing sensitivity and specificity.

    For additional information about specificity and sensitivity, see
    [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
    """

    def __init__(
        self, value, num_thresholds=200, class_id=None, name=None, dtype=None
    ):
        super().__init__(name=name, dtype=dtype)
        # Metric should be maximized during optimization.
        self._direction = "up"

        if num_thresholds <= 0:
            raise ValueError(
                "Argument `num_thresholds` must be an integer > 0. "
                f"Received: num_thresholds={num_thresholds}"
            )
        self.value = value
        self.class_id = class_id

        # Build `num_thresholds` thresholds in [0, 1]: the endpoints plus
        # evenly spaced interior points.
        if num_thresholds == 1:
            self.thresholds = [0.5]
            self._thresholds_distributed_evenly = False
        else:
            interior = [
                step / (num_thresholds - 1.0)
                for step in range(1, num_thresholds - 1)
            ]
            self.thresholds = [0.0] + interior + [1.0]
            self._thresholds_distributed_evenly = True

        def _make_counter(name):
            # All four confusion-matrix counters share the same shape:
            # one slot per threshold.
            return self.add_variable(
                shape=(len(self.thresholds),),
                initializer=initializers.Zeros(),
                name=name,
            )

        self.true_positives = _make_counter("true_positives")
        self.false_positives = _make_counter("false_positives")
        self.true_negatives = _make_counter("true_negatives")
        self.false_negatives = _make_counter("false_negatives")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates confusion matrix statistics.

        Args:
            y_true: The ground truth values.
            y_pred: The predicted values.
            sample_weight: Optional weighting of each example. Defaults to
                `1`. Can be a tensor whose rank is either 0, or the same
                rank as `y_true`, and must be broadcastable to `y_true`.
        """
        variables_to_update = {
            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
            metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
            metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,  # noqa: E501
            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,  # noqa: E501
        }
        metrics_utils.update_confusion_matrix_variables(
            variables_to_update,
            y_true,
            y_pred,
            thresholds=self.thresholds,
            thresholds_distributed_evenly=self._thresholds_distributed_evenly,
            class_id=self.class_id,
            sample_weight=sample_weight,
        )

    def reset_state(self):
        zeros = ops.zeros((len(self.thresholds),))
        for counter in (
            self.true_positives,
            self.false_positives,
            self.true_negatives,
            self.false_negatives,
        ):
            counter.assign(zeros)

    def get_config(self):
        return {**super().get_config(), "class_id": self.class_id}

    def _find_max_under_constraint(self, constrained, dependent, predicate):
        """Returns the maximum of `dependent` that satisfies the constraint.

        Args:
            constrained: Over these values the constraint is specified. A
                rank-1 tensor.
            dependent: From these values the maximum that satisfies the
                constraint is selected. Values in this tensor and in
                `constrained` are linked by having the same threshold at
                each position, hence this tensor must have the same shape.
            predicate: A binary boolean functor to be applied to arguments
                `constrained` and `self.value`, e.g. `ops.greater`.

        Returns:
            The maximal dependent value; 0.0 if no value satisfies the
            constraint.
        """
        # Threshold indices at which the constraint holds.
        feasible_indices = ops.nonzero(predicate(constrained, self.value))
        has_feasible = ops.greater(ops.size(feasible_indices), 0)
        # `initial=0` keeps the reduction well-defined when no index is
        # feasible; the `where` below then selects 0.0 anyway.
        best_dependent = ops.max(
            ops.take(dependent, feasible_indices), initial=0
        )

        return ops.where(has_feasible, best_dependent, 0.0)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
@keras_export("keras.metrics.SensitivityAtSpecificity")
class SensitivityAtSpecificity(SensitivitySpecificityBase):
    """Computes best sensitivity where specificity is >= specified value.

    `Sensitivity` measures the proportion of actual positives that are
    correctly identified as such `(tp / (tp + fn))`.
    `Specificity` measures the proportion of actual negatives that are
    correctly identified as such `(tn / (tn + fp))`.

    Four local variables — `true_positives`, `true_negatives`,
    `false_positives` and `false_negatives` — are used to compute the
    sensitivity at the given specificity: the threshold achieving the
    required specificity is found and the corresponding sensitivity is
    reported.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    If `class_id` is specified, we calculate precision by considering only
    the entries in the batch for which `class_id` is above the threshold
    predictions, and computing the fraction of them for which `class_id` is
    indeed a correct label.

    For additional information about specificity and sensitivity, see
    [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).

    Args:
        specificity: A scalar value in range `[0, 1]`.
        num_thresholds: (Optional) Defaults to 200. The number of thresholds
            to use for matching the given specificity.
        class_id: (Optional) Integer class ID for which we want binary
            metrics. This must be in the half-open interval
            `[0, num_classes)`, where `num_classes` is the last dimension of
            predictions.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.SensitivityAtSpecificity(0.5)
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
    ...                sample_weight=[1, 1, 2, 2, 1])
    >>> m.result()
    0.333333

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='binary_crossentropy',
        metrics=[keras.metrics.SensitivityAtSpecificity()])
    ```
    """

    def __init__(
        self,
        specificity,
        num_thresholds=200,
        class_id=None,
        name=None,
        dtype=None,
    ):
        if specificity > 1 or specificity < 0:
            raise ValueError(
                "Argument `specificity` must be in the range [0, 1]. "
                f"Received: specificity={specificity}"
            )
        self.specificity = specificity
        self.num_thresholds = num_thresholds
        super().__init__(
            specificity,
            num_thresholds=num_thresholds,
            class_id=class_id,
            name=name,
            dtype=dtype,
        )

    def result(self):
        # Per-threshold sensitivity (tp rate) and specificity (tn rate).
        sensitivity_per_threshold = ops.divide_no_nan(
            self.true_positives,
            ops.add(self.true_positives, self.false_negatives),
        )
        specificity_per_threshold = ops.divide_no_nan(
            self.true_negatives,
            ops.add(self.true_negatives, self.false_positives),
        )
        # Best sensitivity among thresholds meeting the specificity target.
        return self._find_max_under_constraint(
            specificity_per_threshold,
            sensitivity_per_threshold,
            ops.greater_equal,
        )

    def get_config(self):
        return {
            **super().get_config(),
            "num_thresholds": self.num_thresholds,
            "specificity": self.specificity,
        }
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
@keras_export("keras.metrics.SpecificityAtSensitivity")
class SpecificityAtSensitivity(SensitivitySpecificityBase):
    """Computes best specificity where sensitivity is >= specified value.

    `Sensitivity` measures the proportion of actual positives that are
    correctly identified as such `(tp / (tp + fn))`.
    `Specificity` measures the proportion of actual negatives that are
    correctly identified as such `(tn / (tn + fp))`.

    This metric creates four local variables, `true_positives`,
    `true_negatives`, `false_positives` and `false_negatives` that are used to
    compute the specificity at the given sensitivity. The threshold for the
    given sensitivity value is computed and used to evaluate the corresponding
    specificity.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    If `class_id` is specified, the metric is computed by considering only the
    entries in the batch for which `class_id` is above the threshold
    predictions, and computing the fraction of them for which `class_id` is
    indeed a correct label.

    For additional information about specificity and sensitivity, see
    [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).

    Args:
        sensitivity: A scalar value in range `[0, 1]`.
        num_thresholds: (Optional) Defaults to 200. The number of thresholds to
            use for matching the given sensitivity.
        class_id: (Optional) Integer class ID for which we want binary metrics.
            This must be in the half-open interval `[0, num_classes)`, where
            `num_classes` is the last dimension of predictions.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.SpecificityAtSensitivity(0.5)
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
    >>> m.result()
    0.66666667

    >>> m.reset_state()
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
    ...                sample_weight=[1, 1, 2, 2, 2])
    >>> m.result()
    0.5

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='binary_crossentropy',
        metrics=[keras.metrics.SpecificityAtSensitivity(sensitivity=0.5)])
    ```
    """

    def __init__(
        self,
        sensitivity,
        num_thresholds=200,
        class_id=None,
        name=None,
        dtype=None,
    ):
        if sensitivity < 0 or sensitivity > 1:
            raise ValueError(
                "Argument `sensitivity` must be in the range [0, 1]. "
                f"Received: sensitivity={sensitivity}"
            )
        self.sensitivity = sensitivity
        self.num_thresholds = num_thresholds
        # Pass `value` by keyword for consistency with PrecisionAtRecall and
        # RecallAtPrecision.
        super().__init__(
            value=sensitivity,
            num_thresholds=num_thresholds,
            class_id=class_id,
            name=name,
            dtype=dtype,
        )

    def result(self):
        # Per-threshold sensitivity (recall): TP / (TP + FN).
        sensitivities = ops.divide_no_nan(
            self.true_positives,
            ops.add(self.true_positives, self.false_negatives),
        )
        # Per-threshold specificity: TN / (TN + FP).
        specificities = ops.divide_no_nan(
            self.true_negatives,
            ops.add(self.true_negatives, self.false_positives),
        )
        # Maximize specificity subject to sensitivity >= the target value.
        return self._find_max_under_constraint(
            sensitivities, specificities, ops.greater_equal
        )

    def get_config(self):
        """Return the serializable config of the metric."""
        config = {
            "num_thresholds": self.num_thresholds,
            "sensitivity": self.sensitivity,
        }
        base_config = super().get_config()
        return {**base_config, **config}
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
@keras_export("keras.metrics.PrecisionAtRecall")
class PrecisionAtRecall(SensitivitySpecificityBase):
    """Computes best precision where recall is >= specified value.

    This metric creates four local variables, `true_positives`,
    `true_negatives`, `false_positives` and `false_negatives` that are used to
    compute the precision at the given recall. The threshold for the given
    recall value is computed and used to evaluate the corresponding precision.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    If `class_id` is specified, we calculate precision by considering only the
    entries in the batch for which `class_id` is above the threshold
    predictions, and computing the fraction of them for which `class_id` is
    indeed a correct label.

    Args:
        recall: A scalar value in range `[0, 1]`.
        num_thresholds: (Optional) Defaults to 200. The number of thresholds to
            use for matching the given recall.
        class_id: (Optional) Integer class ID for which we want binary metrics.
            This must be in the half-open interval `[0, num_classes)`, where
            `num_classes` is the last dimension of predictions.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.PrecisionAtRecall(0.5)
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
    ...                sample_weight=[2, 2, 2, 1, 1])
    >>> m.result()
    0.33333333

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='binary_crossentropy',
        metrics=[keras.metrics.PrecisionAtRecall(recall=0.8)])
    ```
    """

    def __init__(
        self, recall, num_thresholds=200, class_id=None, name=None, dtype=None
    ):
        # Out-of-range targets are rejected before any state is initialized.
        if recall < 0 or recall > 1:
            raise ValueError(
                "Argument `recall` must be in the range [0, 1]. "
                f"Received: recall={recall}"
            )
        self.recall = recall
        self.num_thresholds = num_thresholds
        super().__init__(
            value=recall,
            num_thresholds=num_thresholds,
            class_id=class_id,
            name=name,
            dtype=dtype,
        )

    def result(self):
        """Best precision whose paired recall meets the target."""
        tp = self.true_positives
        # Recall: TP / (TP + FN); precision: TP / (TP + FP), per threshold.
        recalls = ops.divide_no_nan(tp, ops.add(tp, self.false_negatives))
        precisions = ops.divide_no_nan(tp, ops.add(tp, self.false_positives))
        # Maximize precision subject to recall >= the target value.
        return self._find_max_under_constraint(
            recalls, precisions, ops.greater_equal
        )

    def get_config(self):
        """Return the serializable config of the metric."""
        return {
            **super().get_config(),
            "num_thresholds": self.num_thresholds,
            "recall": self.recall,
        }
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
@keras_export("keras.metrics.RecallAtPrecision")
class RecallAtPrecision(SensitivitySpecificityBase):
    """Computes best recall where precision is >= specified value.

    For a given score-label-distribution the required precision might not
    be achievable, in this case 0.0 is returned as recall.

    This metric creates four local variables, `true_positives`,
    `true_negatives`, `false_positives` and `false_negatives` that are used to
    compute the recall at the given precision. The threshold for the given
    precision value is computed and used to evaluate the corresponding recall.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    If `class_id` is specified, the metric is computed by considering only the
    entries in the batch for which `class_id` is above the threshold
    predictions, and computing the fraction of them for which `class_id` is
    indeed a correct label.

    Args:
        precision: A scalar value in range `[0, 1]`.
        num_thresholds: (Optional) Defaults to 200. The number of thresholds
            to use for matching the given precision.
        class_id: (Optional) Integer class ID for which we want binary metrics.
            This must be in the half-open interval `[0, num_classes)`, where
            `num_classes` is the last dimension of predictions.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.RecallAtPrecision(0.8)
    >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
    ...                sample_weight=[1, 0, 0, 1])
    >>> m.result()
    1.0

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='binary_crossentropy',
        metrics=[keras.metrics.RecallAtPrecision(precision=0.8)])
    ```
    """

    def __init__(
        self,
        precision,
        num_thresholds=200,
        class_id=None,
        name=None,
        dtype=None,
    ):
        if precision < 0 or precision > 1:
            raise ValueError(
                "Argument `precision` must be in the range [0, 1]. "
                f"Received: precision={precision}"
            )
        self.precision = precision
        self.num_thresholds = num_thresholds
        super().__init__(
            value=precision,
            num_thresholds=num_thresholds,
            class_id=class_id,
            name=name,
            dtype=dtype,
        )

    def result(self):
        # Recall: TP / (TP + FN); precision: TP / (TP + FP), per threshold.
        recalls = ops.divide_no_nan(
            self.true_positives,
            ops.add(self.true_positives, self.false_negatives),
        )
        precisions = ops.divide_no_nan(
            self.true_positives,
            ops.add(self.true_positives, self.false_positives),
        )
        # Maximize recall subject to precision >= the target value.
        return self._find_max_under_constraint(
            precisions, recalls, ops.greater_equal
        )

    def get_config(self):
        """Return the serializable config of the metric."""
        config = {
            "num_thresholds": self.num_thresholds,
            "precision": self.precision,
        }
        base_config = super().get_config()
        return {**base_config, **config}
|
| 1065 |
+
|
| 1066 |
+
|
| 1067 |
+
@keras_export("keras.metrics.AUC")
|
| 1068 |
+
class AUC(Metric):
|
| 1069 |
+
"""Approximates the AUC (Area under the curve) of the ROC or PR curves.
|
| 1070 |
+
|
| 1071 |
+
The AUC (Area under the curve) of the ROC (Receiver operating
|
| 1072 |
+
characteristic; default) or PR (Precision Recall) curves are quality
|
| 1073 |
+
measures of binary classifiers. Unlike the accuracy, and like cross-entropy
|
| 1074 |
+
losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
|
| 1075 |
+
|
| 1076 |
+
This class approximates AUCs using a Riemann sum. During the metric
|
| 1077 |
+
accumulation phrase, predictions are accumulated within predefined buckets
|
| 1078 |
+
by value. The AUC is then computed by interpolating per-bucket averages.
|
| 1079 |
+
These buckets define the evaluated operational points.
|
| 1080 |
+
|
| 1081 |
+
This metric creates four local variables, `true_positives`,
|
| 1082 |
+
`true_negatives`, `false_positives` and `false_negatives` that are used to
|
| 1083 |
+
compute the AUC. To discretize the AUC curve, a linearly spaced set of
|
| 1084 |
+
thresholds is used to compute pairs of recall and precision values. The area
|
| 1085 |
+
under the ROC-curve is therefore computed using the height of the recall
|
| 1086 |
+
values by the false positive rate, while the area under the PR-curve is the
|
| 1087 |
+
computed using the height of the precision values by the recall.
|
| 1088 |
+
|
| 1089 |
+
This value is ultimately returned as `auc`, an idempotent operation that
|
| 1090 |
+
computes the area under a discretized curve of precision versus recall
|
| 1091 |
+
values (computed using the aforementioned variables). The `num_thresholds`
|
| 1092 |
+
variable controls the degree of discretization with larger numbers of
|
| 1093 |
+
thresholds more closely approximating the true AUC. The quality of the
|
| 1094 |
+
approximation may vary dramatically depending on `num_thresholds`. The
|
| 1095 |
+
`thresholds` parameter can be used to manually specify thresholds which
|
| 1096 |
+
split the predictions more evenly.
|
| 1097 |
+
|
| 1098 |
+
For a best approximation of the real AUC, `predictions` should be
|
| 1099 |
+
distributed approximately uniformly in the range `[0, 1]` (if
|
| 1100 |
+
`from_logits=False`). The quality of the AUC approximation may be poor if
|
| 1101 |
+
this is not the case. Setting `summation_method` to 'minoring' or 'majoring'
|
| 1102 |
+
can help quantify the error in the approximation by providing lower or upper
|
| 1103 |
+
bound estimate of the AUC.
|
| 1104 |
+
|
| 1105 |
+
If `sample_weight` is `None`, weights default to 1.
|
| 1106 |
+
Use `sample_weight` of 0 to mask values.
|
| 1107 |
+
|
| 1108 |
+
Args:
|
| 1109 |
+
num_thresholds: (Optional) The number of thresholds to
|
| 1110 |
+
use when discretizing the roc curve. Values must be > 1.
|
| 1111 |
+
Defaults to `200`.
|
| 1112 |
+
curve: (Optional) Specifies the name of the curve to be computed,
|
| 1113 |
+
`'ROC'` (default) or `'PR'` for the Precision-Recall-curve.
|
| 1114 |
+
summation_method: (Optional) Specifies the [Riemann summation method](
|
| 1115 |
+
https://en.wikipedia.org/wiki/Riemann_sum) used.
|
| 1116 |
+
'interpolation' (default) applies mid-point summation scheme for
|
| 1117 |
+
`ROC`. For PR-AUC, interpolates (true/false) positives but not
|
| 1118 |
+
the ratio that is precision (see Davis & Goadrich 2006 for
|
| 1119 |
+
details); 'minoring' applies left summation for increasing
|
| 1120 |
+
intervals and right summation for decreasing intervals; 'majoring'
|
| 1121 |
+
does the opposite.
|
| 1122 |
+
name: (Optional) string name of the metric instance.
|
| 1123 |
+
dtype: (Optional) data type of the metric result.
|
| 1124 |
+
thresholds: (Optional) A list of floating point values to use as the
|
| 1125 |
+
thresholds for discretizing the curve. If set, the `num_thresholds`
|
| 1126 |
+
parameter is ignored. Values should be in `[0, 1]`. Endpoint
|
| 1127 |
+
thresholds equal to {`-epsilon`, `1+epsilon`} for a small positive
|
| 1128 |
+
epsilon value will be automatically included with these to correctly
|
| 1129 |
+
handle predictions equal to exactly 0 or 1.
|
| 1130 |
+
multi_label: boolean indicating whether multilabel data should be
|
| 1131 |
+
treated as such, wherein AUC is computed separately for each label
|
| 1132 |
+
and then averaged across labels, or (when `False`) if the data
|
| 1133 |
+
should be flattened into a single label before AUC computation. In
|
| 1134 |
+
the latter case, when multilabel data is passed to AUC, each
|
| 1135 |
+
label-prediction pair is treated as an individual data point. Should
|
| 1136 |
+
be set to `False` for multi-class data.
|
| 1137 |
+
num_labels: (Optional) The number of labels, used when `multi_label` is
|
| 1138 |
+
True. If `num_labels` is not specified, then state variables get
|
| 1139 |
+
created on the first call to `update_state`.
|
| 1140 |
+
label_weights: (Optional) list, array, or tensor of non-negative weights
|
| 1141 |
+
used to compute AUCs for multilabel data. When `multi_label` is
|
| 1142 |
+
True, the weights are applied to the individual label AUCs when they
|
| 1143 |
+
are averaged to produce the multi-label AUC. When it's False, they
|
| 1144 |
+
are used to weight the individual label predictions in computing the
|
| 1145 |
+
confusion matrix on the flattened data. Note that this is unlike
|
| 1146 |
+
`class_weights` in that `class_weights` weights the example
|
| 1147 |
+
depending on the value of its label, whereas `label_weights` depends
|
| 1148 |
+
only on the index of that label before flattening; therefore
|
| 1149 |
+
`label_weights` should not be used for multi-class data.
|
| 1150 |
+
from_logits: boolean indicating whether the predictions (`y_pred` in
|
| 1151 |
+
`update_state`) are probabilities or sigmoid logits. As a rule of thumb,
|
| 1152 |
+
when using a keras loss, the `from_logits` constructor argument of the
|
| 1153 |
+
loss should match the AUC `from_logits` constructor argument.
|
| 1154 |
+
|
| 1155 |
+
Example:
|
| 1156 |
+
|
| 1157 |
+
>>> m = keras.metrics.AUC(num_thresholds=3)
|
| 1158 |
+
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
|
| 1159 |
+
>>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
|
| 1160 |
+
>>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
|
| 1161 |
+
>>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
|
| 1162 |
+
>>> # auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0)))
|
| 1163 |
+
>>> # = 0.75
|
| 1164 |
+
>>> m.result()
|
| 1165 |
+
0.75
|
| 1166 |
+
|
| 1167 |
+
>>> m.reset_state()
|
| 1168 |
+
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
|
| 1169 |
+
... sample_weight=[1, 0, 0, 1])
|
| 1170 |
+
>>> m.result()
|
| 1171 |
+
1.0
|
| 1172 |
+
|
| 1173 |
+
Usage with `compile()` API:
|
| 1174 |
+
|
| 1175 |
+
```python
|
| 1176 |
+
# Reports the AUC of a model outputting a probability.
|
| 1177 |
+
model.compile(optimizer='sgd',
|
| 1178 |
+
loss=keras.losses.BinaryCrossentropy(),
|
| 1179 |
+
metrics=[keras.metrics.AUC()])
|
| 1180 |
+
|
| 1181 |
+
# Reports the AUC of a model outputting a logit.
|
| 1182 |
+
model.compile(optimizer='sgd',
|
| 1183 |
+
loss=keras.losses.BinaryCrossentropy(from_logits=True),
|
| 1184 |
+
metrics=[keras.metrics.AUC(from_logits=True)])
|
| 1185 |
+
```
|
| 1186 |
+
"""
|
| 1187 |
+
|
| 1188 |
+
def __init__(
|
| 1189 |
+
self,
|
| 1190 |
+
num_thresholds=200,
|
| 1191 |
+
curve="ROC",
|
| 1192 |
+
summation_method="interpolation",
|
| 1193 |
+
name=None,
|
| 1194 |
+
dtype=None,
|
| 1195 |
+
thresholds=None,
|
| 1196 |
+
multi_label=False,
|
| 1197 |
+
num_labels=None,
|
| 1198 |
+
label_weights=None,
|
| 1199 |
+
from_logits=False,
|
| 1200 |
+
):
|
| 1201 |
+
# Metric should be maximized during optimization.
|
| 1202 |
+
self._direction = "up"
|
| 1203 |
+
|
| 1204 |
+
# Validate configurations.
|
| 1205 |
+
if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
|
| 1206 |
+
metrics_utils.AUCCurve
|
| 1207 |
+
):
|
| 1208 |
+
raise ValueError(
|
| 1209 |
+
f'Invalid `curve` argument value "{curve}". '
|
| 1210 |
+
f"Expected one of: {list(metrics_utils.AUCCurve)}"
|
| 1211 |
+
)
|
| 1212 |
+
if isinstance(
|
| 1213 |
+
summation_method, metrics_utils.AUCSummationMethod
|
| 1214 |
+
) and summation_method not in list(metrics_utils.AUCSummationMethod):
|
| 1215 |
+
raise ValueError(
|
| 1216 |
+
"Invalid `summation_method` "
|
| 1217 |
+
f'argument value "{summation_method}". '
|
| 1218 |
+
f"Expected one of: {list(metrics_utils.AUCSummationMethod)}"
|
| 1219 |
+
)
|
| 1220 |
+
|
| 1221 |
+
# Update properties.
|
| 1222 |
+
self._init_from_thresholds = thresholds is not None
|
| 1223 |
+
if thresholds is not None:
|
| 1224 |
+
# If specified, use the supplied thresholds.
|
| 1225 |
+
self.num_thresholds = len(thresholds) + 2
|
| 1226 |
+
thresholds = sorted(thresholds)
|
| 1227 |
+
self._thresholds_distributed_evenly = (
|
| 1228 |
+
metrics_utils.is_evenly_distributed_thresholds(
|
| 1229 |
+
np.array([0.0] + thresholds + [1.0])
|
| 1230 |
+
)
|
| 1231 |
+
)
|
| 1232 |
+
else:
|
| 1233 |
+
if num_thresholds <= 1:
|
| 1234 |
+
raise ValueError(
|
| 1235 |
+
"Argument `num_thresholds` must be an integer > 1. "
|
| 1236 |
+
f"Received: num_thresholds={num_thresholds}"
|
| 1237 |
+
)
|
| 1238 |
+
|
| 1239 |
+
# Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
|
| 1240 |
+
# (0, 1).
|
| 1241 |
+
self.num_thresholds = num_thresholds
|
| 1242 |
+
thresholds = [
|
| 1243 |
+
(i + 1) * 1.0 / (num_thresholds - 1)
|
| 1244 |
+
for i in range(num_thresholds - 2)
|
| 1245 |
+
]
|
| 1246 |
+
self._thresholds_distributed_evenly = True
|
| 1247 |
+
|
| 1248 |
+
# Add an endpoint "threshold" below zero and above one for either
|
| 1249 |
+
# threshold method to account for floating point imprecisions.
|
| 1250 |
+
self._thresholds = np.array(
|
| 1251 |
+
[0.0 - backend.epsilon()] + thresholds + [1.0 + backend.epsilon()]
|
| 1252 |
+
)
|
| 1253 |
+
|
| 1254 |
+
if isinstance(curve, metrics_utils.AUCCurve):
|
| 1255 |
+
self.curve = curve
|
| 1256 |
+
else:
|
| 1257 |
+
self.curve = metrics_utils.AUCCurve.from_str(curve)
|
| 1258 |
+
if isinstance(summation_method, metrics_utils.AUCSummationMethod):
|
| 1259 |
+
self.summation_method = summation_method
|
| 1260 |
+
else:
|
| 1261 |
+
self.summation_method = metrics_utils.AUCSummationMethod.from_str(
|
| 1262 |
+
summation_method
|
| 1263 |
+
)
|
| 1264 |
+
super().__init__(name=name, dtype=dtype)
|
| 1265 |
+
|
| 1266 |
+
# Handle multilabel arguments.
|
| 1267 |
+
self.multi_label = multi_label
|
| 1268 |
+
self.num_labels = num_labels
|
| 1269 |
+
if label_weights is not None:
|
| 1270 |
+
label_weights = ops.array(label_weights, dtype=self.dtype)
|
| 1271 |
+
self.label_weights = label_weights
|
| 1272 |
+
|
| 1273 |
+
else:
|
| 1274 |
+
self.label_weights = None
|
| 1275 |
+
|
| 1276 |
+
self._from_logits = from_logits
|
| 1277 |
+
|
| 1278 |
+
self._built = False
|
| 1279 |
+
if self.multi_label:
|
| 1280 |
+
if num_labels:
|
| 1281 |
+
shape = [None, num_labels]
|
| 1282 |
+
self._build(shape)
|
| 1283 |
+
else:
|
| 1284 |
+
if num_labels:
|
| 1285 |
+
raise ValueError(
|
| 1286 |
+
"`num_labels` is needed only when `multi_label` is True."
|
| 1287 |
+
)
|
| 1288 |
+
self._build(None)
|
| 1289 |
+
|
| 1290 |
+
@property
|
| 1291 |
+
def thresholds(self):
|
| 1292 |
+
"""The thresholds used for evaluating AUC."""
|
| 1293 |
+
return list(self._thresholds)
|
| 1294 |
+
|
| 1295 |
+
def _build(self, shape):
|
| 1296 |
+
"""Initialize TP, FP, TN, and FN tensors, given the shape of the
|
| 1297 |
+
data."""
|
| 1298 |
+
if self.multi_label:
|
| 1299 |
+
if len(shape) != 2:
|
| 1300 |
+
raise ValueError(
|
| 1301 |
+
"`y_pred` must have rank 2 when `multi_label=True`. "
|
| 1302 |
+
f"Found rank {len(shape)}. "
|
| 1303 |
+
f"Full shape received for `y_pred`: {shape}"
|
| 1304 |
+
)
|
| 1305 |
+
self._num_labels = shape[1]
|
| 1306 |
+
variable_shape = [self.num_thresholds, self._num_labels]
|
| 1307 |
+
else:
|
| 1308 |
+
variable_shape = [self.num_thresholds]
|
| 1309 |
+
|
| 1310 |
+
self._build_input_shape = shape
|
| 1311 |
+
# Create metric variables
|
| 1312 |
+
self.true_positives = self.add_variable(
|
| 1313 |
+
shape=variable_shape,
|
| 1314 |
+
initializer=initializers.Zeros(),
|
| 1315 |
+
name="true_positives",
|
| 1316 |
+
)
|
| 1317 |
+
self.false_positives = self.add_variable(
|
| 1318 |
+
shape=variable_shape,
|
| 1319 |
+
initializer=initializers.Zeros(),
|
| 1320 |
+
name="false_positives",
|
| 1321 |
+
)
|
| 1322 |
+
self.true_negatives = self.add_variable(
|
| 1323 |
+
shape=variable_shape,
|
| 1324 |
+
initializer=initializers.Zeros(),
|
| 1325 |
+
name="true_negatives",
|
| 1326 |
+
)
|
| 1327 |
+
self.false_negatives = self.add_variable(
|
| 1328 |
+
shape=variable_shape,
|
| 1329 |
+
initializer=initializers.Zeros(),
|
| 1330 |
+
name="false_negatives",
|
| 1331 |
+
)
|
| 1332 |
+
|
| 1333 |
+
self._built = True
|
| 1334 |
+
|
| 1335 |
+
def update_state(self, y_true, y_pred, sample_weight=None):
|
| 1336 |
+
"""Accumulates confusion matrix statistics.
|
| 1337 |
+
|
| 1338 |
+
Args:
|
| 1339 |
+
y_true: The ground truth values.
|
| 1340 |
+
y_pred: The predicted values.
|
| 1341 |
+
sample_weight: Optional weighting of each example. Can
|
| 1342 |
+
be a tensor whose rank is either 0, or the same rank as
|
| 1343 |
+
`y_true`, and must be broadcastable to `y_true`. Defaults to
|
| 1344 |
+
`1`.
|
| 1345 |
+
"""
|
| 1346 |
+
if not self._built:
|
| 1347 |
+
self._build(y_pred.shape)
|
| 1348 |
+
|
| 1349 |
+
if self.multi_label or (self.label_weights is not None):
|
| 1350 |
+
# y_true should have shape (number of examples, number of labels).
|
| 1351 |
+
shapes = [(y_true, ("N", "L"))]
|
| 1352 |
+
if self.multi_label:
|
| 1353 |
+
# TP, TN, FP, and FN should all have shape
|
| 1354 |
+
# (number of thresholds, number of labels).
|
| 1355 |
+
shapes.extend(
|
| 1356 |
+
[
|
| 1357 |
+
(self.true_positives, ("T", "L")),
|
| 1358 |
+
(self.true_negatives, ("T", "L")),
|
| 1359 |
+
(self.false_positives, ("T", "L")),
|
| 1360 |
+
(self.false_negatives, ("T", "L")),
|
| 1361 |
+
]
|
| 1362 |
+
)
|
| 1363 |
+
if self.label_weights is not None:
|
| 1364 |
+
# label_weights should be of length equal to the number of
|
| 1365 |
+
# labels.
|
| 1366 |
+
shapes.append((self.label_weights, ("L",)))
|
| 1367 |
+
|
| 1368 |
+
# Only forward label_weights to update_confusion_matrix_variables when
|
| 1369 |
+
# multi_label is False. Otherwise the averaging of individual label AUCs
|
| 1370 |
+
# is handled in AUC.result
|
| 1371 |
+
label_weights = None if self.multi_label else self.label_weights
|
| 1372 |
+
|
| 1373 |
+
if self._from_logits:
|
| 1374 |
+
y_pred = activations.sigmoid(y_pred)
|
| 1375 |
+
|
| 1376 |
+
metrics_utils.update_confusion_matrix_variables(
|
| 1377 |
+
{
|
| 1378 |
+
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
|
| 1379 |
+
metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501
|
| 1380 |
+
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
|
| 1381 |
+
metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
|
| 1382 |
+
},
|
| 1383 |
+
y_true,
|
| 1384 |
+
y_pred,
|
| 1385 |
+
self._thresholds,
|
| 1386 |
+
thresholds_distributed_evenly=self._thresholds_distributed_evenly,
|
| 1387 |
+
sample_weight=sample_weight,
|
| 1388 |
+
multi_label=self.multi_label,
|
| 1389 |
+
label_weights=label_weights,
|
| 1390 |
+
)
|
| 1391 |
+
|
| 1392 |
+
def interpolate_pr_auc(self):
|
| 1393 |
+
"""Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
|
| 1394 |
+
|
| 1395 |
+
https://www.biostat.wisc.edu/~page/rocpr.pdf
|
| 1396 |
+
|
| 1397 |
+
Note here we derive & use a closed formula not present in the paper
|
| 1398 |
+
as follows:
|
| 1399 |
+
|
| 1400 |
+
Precision = TP / (TP + FP) = TP / P
|
| 1401 |
+
|
| 1402 |
+
Modeling all of TP (true positive), FP (false positive) and their sum
|
| 1403 |
+
P = TP + FP (predicted positive) as varying linearly within each
|
| 1404 |
+
interval [A, B] between successive thresholds, we get
|
| 1405 |
+
|
| 1406 |
+
Precision slope = dTP / dP
|
| 1407 |
+
= (TP_B - TP_A) / (P_B - P_A)
|
| 1408 |
+
= (TP - TP_A) / (P - P_A)
|
| 1409 |
+
Precision = (TP_A + slope * (P - P_A)) / P
|
| 1410 |
+
|
| 1411 |
+
The area within the interval is (slope / total_pos_weight) times
|
| 1412 |
+
|
| 1413 |
+
int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
|
| 1414 |
+
int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
|
| 1415 |
+
|
| 1416 |
+
where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
|
| 1417 |
+
|
| 1418 |
+
int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
|
| 1419 |
+
|
| 1420 |
+
Bringing back the factor (slope / total_pos_weight) we'd put aside, we
|
| 1421 |
+
get
|
| 1422 |
+
|
| 1423 |
+
slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
|
| 1424 |
+
|
| 1425 |
+
where dTP == TP_B - TP_A.
|
| 1426 |
+
|
| 1427 |
+
Note that when P_A == 0 the above calculation simplifies into
|
| 1428 |
+
|
| 1429 |
+
int_A^B{Precision.dTP} = int_A^B{slope * dTP}
|
| 1430 |
+
= slope * (TP_B - TP_A)
|
| 1431 |
+
|
| 1432 |
+
which is really equivalent to imputing constant precision throughout the
|
| 1433 |
+
first bucket having >0 true positives.
|
| 1434 |
+
|
| 1435 |
+
Returns:
|
| 1436 |
+
pr_auc: an approximation of the area under the P-R curve.
|
| 1437 |
+
"""
|
| 1438 |
+
|
| 1439 |
+
dtp = ops.subtract(
|
| 1440 |
+
self.true_positives[: self.num_thresholds - 1],
|
| 1441 |
+
self.true_positives[1:],
|
| 1442 |
+
)
|
| 1443 |
+
p = ops.add(self.true_positives, self.false_positives)
|
| 1444 |
+
dp = ops.subtract(p[: self.num_thresholds - 1], p[1:])
|
| 1445 |
+
prec_slope = ops.divide_no_nan(dtp, ops.maximum(dp, 0))
|
| 1446 |
+
intercept = ops.subtract(
|
| 1447 |
+
self.true_positives[1:], ops.multiply(prec_slope, p[1:])
|
| 1448 |
+
)
|
| 1449 |
+
|
| 1450 |
+
safe_p_ratio = ops.where(
|
| 1451 |
+
ops.logical_and(p[: self.num_thresholds - 1] > 0, p[1:] > 0),
|
| 1452 |
+
ops.divide_no_nan(
|
| 1453 |
+
p[: self.num_thresholds - 1], ops.maximum(p[1:], 0)
|
| 1454 |
+
),
|
| 1455 |
+
ops.ones_like(p[1:]),
|
| 1456 |
+
)
|
| 1457 |
+
|
| 1458 |
+
pr_auc_increment = ops.divide_no_nan(
|
| 1459 |
+
ops.multiply(
|
| 1460 |
+
prec_slope,
|
| 1461 |
+
(ops.add(dtp, ops.multiply(intercept, ops.log(safe_p_ratio)))),
|
| 1462 |
+
),
|
| 1463 |
+
ops.maximum(
|
| 1464 |
+
ops.add(self.true_positives[1:], self.false_negatives[1:]), 0
|
| 1465 |
+
),
|
| 1466 |
+
)
|
| 1467 |
+
|
| 1468 |
+
if self.multi_label:
|
| 1469 |
+
by_label_auc = ops.sum(pr_auc_increment, axis=0)
|
| 1470 |
+
if self.label_weights is None:
|
| 1471 |
+
# Evenly weighted average of the label AUCs.
|
| 1472 |
+
return ops.mean(by_label_auc)
|
| 1473 |
+
else:
|
| 1474 |
+
# Weighted average of the label AUCs.
|
| 1475 |
+
return ops.divide_no_nan(
|
| 1476 |
+
ops.sum(ops.multiply(by_label_auc, self.label_weights)),
|
| 1477 |
+
ops.sum(self.label_weights),
|
| 1478 |
+
)
|
| 1479 |
+
else:
|
| 1480 |
+
return ops.sum(pr_auc_increment)
|
| 1481 |
+
|
| 1482 |
+
    def result(self):
        """Return the approximated AUC.

        Approximates the area under the chosen curve (ROC or PR) with a
        Riemann sum over the `num_thresholds` confusion-matrix buckets.
        The ('PR', 'interpolation') combination is handled separately by
        `interpolate_pr_auc()`.

        Returns:
            A scalar AUC, or — when `multi_label=True` — the (optionally
            `label_weights`-weighted) average of the per-label AUCs.
        """
        if (
            self.curve == metrics_utils.AUCCurve.PR
            and self.summation_method
            == metrics_utils.AUCSummationMethod.INTERPOLATION
        ):
            # This use case is different and is handled separately.
            return self.interpolate_pr_auc()

        # Set `x` and `y` values for the curves based on `curve` config.
        recall = ops.divide_no_nan(
            self.true_positives,
            ops.add(self.true_positives, self.false_negatives),
        )
        if self.curve == metrics_utils.AUCCurve.ROC:
            fp_rate = ops.divide_no_nan(
                self.false_positives,
                ops.add(self.false_positives, self.true_negatives),
            )
            x = fp_rate
            y = recall
        else:  # curve == 'PR'.
            precision = ops.divide_no_nan(
                self.true_positives,
                ops.add(self.true_positives, self.false_positives),
            )
            x = recall
            y = precision

        # Find the rectangle heights based on `summation_method`.
        if (
            self.summation_method
            == metrics_utils.AUCSummationMethod.INTERPOLATION
        ):
            # Note: the case ('PR', 'interpolation') has been handled above.
            # Trapezoidal rule: average of adjacent y values.
            heights = ops.divide(
                ops.add(y[: self.num_thresholds - 1], y[1:]), 2.0
            )
        elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
            heights = ops.minimum(y[: self.num_thresholds - 1], y[1:])
        # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING:
        else:
            heights = ops.maximum(y[: self.num_thresholds - 1], y[1:])

        # Sum up the areas of all the rectangles.
        riemann_terms = ops.multiply(
            ops.subtract(x[: self.num_thresholds - 1], x[1:]), heights
        )
        if self.multi_label:
            # Reduce over the threshold axis to get one AUC per label.
            by_label_auc = ops.sum(riemann_terms, axis=0)

            if self.label_weights is None:
                # Unweighted average of the label AUCs.
                return ops.mean(by_label_auc)
            else:
                # Weighted average of the label AUCs.
                return ops.divide_no_nan(
                    ops.sum(ops.multiply(by_label_auc, self.label_weights)),
                    ops.sum(self.label_weights),
                )
        else:
            return ops.sum(riemann_terms)
|
| 1544 |
+
|
| 1545 |
+
    def reset_state(self):
        """Zero out all confusion-matrix accumulator variables.

        No-op when the variables have not been built yet (i.e. before the
        first `update_state` call creates them).
        """
        if self._built:
            if self.multi_label:
                # One column of threshold counts per label.
                variable_shape = (self.num_thresholds, self._num_labels)
            else:
                variable_shape = (self.num_thresholds,)

            self.true_positives.assign(ops.zeros(variable_shape))
            self.false_positives.assign(ops.zeros(variable_shape))
            self.true_negatives.assign(ops.zeros(variable_shape))
            self.false_negatives.assign(ops.zeros(variable_shape))
|
| 1556 |
+
|
| 1557 |
+
    def get_config(self):
        """Return the serializable config of the metric."""
        label_weights = self.label_weights
        config = {
            "num_thresholds": self.num_thresholds,
            "curve": self.curve.value,
            "summation_method": self.summation_method.value,
            "multi_label": self.multi_label,
            "num_labels": self.num_labels,
            "label_weights": label_weights,
            "from_logits": self._from_logits,
        }
        # optimization to avoid serializing a large number of generated
        # thresholds
        if self._init_from_thresholds:
            # We remove the endpoint thresholds as an inverse of how the
            # thresholds were initialized. This ensures that a metric
            # initialized from this config has the same thresholds.
            config["thresholds"] = self.thresholds[1:-1]
        base_config = super().get_config()
        return {**base_config, **config}
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/correlation_metrics.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import backend
|
| 2 |
+
from keras.src import ops
|
| 3 |
+
from keras.src.api_export import keras_export
|
| 4 |
+
from keras.src.losses.loss import squeeze_or_expand_to_same_rank
|
| 5 |
+
from keras.src.metrics import reduction_metrics
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@keras_export("keras.metrics.pearson_correlation")
|
| 9 |
+
def pearson_correlation(y_true, y_pred, axis=-1):
|
| 10 |
+
"""Computes the Pearson coefficient between labels and predictions.
|
| 11 |
+
|
| 12 |
+
Formula:
|
| 13 |
+
|
| 14 |
+
```python
|
| 15 |
+
loss = mean(l2norm(y_true - mean(y_true) * l2norm(y_pred - mean(y_pred)))
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
y_true: Tensor of true targets.
|
| 20 |
+
y_pred: Tensor of predicted targets.
|
| 21 |
+
axis: Axis along which to determine similarity. Defaults to `-1`.
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Pearson Correlation Coefficient tensor.
|
| 25 |
+
|
| 26 |
+
Example:
|
| 27 |
+
|
| 28 |
+
>>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]
|
| 29 |
+
>>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]
|
| 30 |
+
>>> loss = keras.losses.concordance_correlation(
|
| 31 |
+
... y_true, y_pred, axis=-1
|
| 32 |
+
... ).numpy()
|
| 33 |
+
[1. 0.99339927]
|
| 34 |
+
"""
|
| 35 |
+
y_pred = ops.convert_to_tensor(y_pred)
|
| 36 |
+
y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
|
| 37 |
+
y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
|
| 38 |
+
|
| 39 |
+
y_true_norm = y_true - ops.mean(y_true, axis=axis, keepdims=True)
|
| 40 |
+
y_pred_norm = y_pred - ops.mean(y_pred, axis=axis, keepdims=True)
|
| 41 |
+
|
| 42 |
+
y_true_norm = y_true_norm / ops.std(y_true_norm, axis=axis, keepdims=True)
|
| 43 |
+
y_pred_norm = y_pred_norm / ops.std(y_pred_norm, axis=axis, keepdims=True)
|
| 44 |
+
|
| 45 |
+
return ops.mean(y_true_norm * y_pred_norm, axis=axis)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@keras_export("keras.metrics.concordance_correlation")
|
| 49 |
+
def concordance_correlation(y_true, y_pred, axis=-1):
|
| 50 |
+
"""Computes the Concordance coefficient between labels and predictions.
|
| 51 |
+
|
| 52 |
+
Formula:
|
| 53 |
+
|
| 54 |
+
```python
|
| 55 |
+
loss = mean(
|
| 56 |
+
2 * (y_true - mean(y_true) * (y_pred - mean(y_pred)) / (
|
| 57 |
+
var(y_true) + var(y_pred) + square(mean(y_true) - mean(y_pred))
|
| 58 |
+
)
|
| 59 |
+
)
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
y_true: Tensor of true targets.
|
| 64 |
+
y_pred: Tensor of predicted targets.
|
| 65 |
+
axis: Axis along which to determine similarity. Defaults to `-1`.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Concordance Correlation Coefficient tensor.
|
| 69 |
+
|
| 70 |
+
Example:
|
| 71 |
+
|
| 72 |
+
>>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]
|
| 73 |
+
>>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]
|
| 74 |
+
>>> loss = keras.losses.concordance_correlation(
|
| 75 |
+
... y_true, y_pred, axis=-1
|
| 76 |
+
... ).numpy()
|
| 77 |
+
[0.97560976 0.98765432]
|
| 78 |
+
"""
|
| 79 |
+
y_pred = ops.convert_to_tensor(y_pred)
|
| 80 |
+
y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
|
| 81 |
+
y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
|
| 82 |
+
|
| 83 |
+
y_true_mean = ops.mean(y_true, axis=axis, keepdims=True)
|
| 84 |
+
y_pred_mean = ops.mean(y_pred, axis=axis, keepdims=True)
|
| 85 |
+
|
| 86 |
+
y_true_var = ops.var(y_true - y_true_mean, axis=axis, keepdims=True)
|
| 87 |
+
y_pred_var = ops.var(y_pred - y_pred_mean, axis=axis, keepdims=True)
|
| 88 |
+
|
| 89 |
+
covar = (y_true - y_pred_mean) * (y_pred - y_pred_mean)
|
| 90 |
+
norm = y_true_var + y_pred_var + ops.square(y_true_mean - y_pred_mean)
|
| 91 |
+
|
| 92 |
+
return ops.mean(2 * covar / (norm + backend.epsilon()), axis=axis)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@keras_export("keras.metrics.PearsonCorrelation")
|
| 96 |
+
class PearsonCorrelation(reduction_metrics.MeanMetricWrapper):
|
| 97 |
+
"""Calculates the Pearson Correlation Coefficient (PCC).
|
| 98 |
+
|
| 99 |
+
PCC measures the linear relationship between the true values (`y_true`) and
|
| 100 |
+
the predicted values (`y_pred`). The coefficient ranges from -1 to 1, where
|
| 101 |
+
a value of 1 implies a perfect positive linear correlation, 0 indicates no
|
| 102 |
+
linear correlation, and -1 indicates a perfect negative linear correlation.
|
| 103 |
+
|
| 104 |
+
This metric is widely used in regression tasks where the strength of the
|
| 105 |
+
linear relationship between predictions and true labels is an
|
| 106 |
+
important evaluation criterion.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
name: (Optional) string name of the metric instance.
|
| 110 |
+
dtype: (Optional) data type of the metric result.
|
| 111 |
+
axis: (Optional) integer or tuple of integers of the axis/axes along
|
| 112 |
+
which to compute the metric. Defaults to `-1`.
|
| 113 |
+
|
| 114 |
+
Example:
|
| 115 |
+
|
| 116 |
+
>>> pcc = keras.metrics.PearsonCorrelation(axis=-1)
|
| 117 |
+
>>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]
|
| 118 |
+
>>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]
|
| 119 |
+
>>> pcc.update_state(y_true, y_pred)
|
| 120 |
+
>>> pcc.result()
|
| 121 |
+
0.9966996338993913
|
| 122 |
+
|
| 123 |
+
Usage with `compile()` API:
|
| 124 |
+
|
| 125 |
+
```python
|
| 126 |
+
model.compile(optimizer='sgd',
|
| 127 |
+
loss='mean_squared_error',
|
| 128 |
+
metrics=[keras.metrics.PearsonCorrelation()])
|
| 129 |
+
```
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
def __init__(
|
| 133 |
+
self,
|
| 134 |
+
name="pearson_correlation",
|
| 135 |
+
dtype=None,
|
| 136 |
+
axis=-1,
|
| 137 |
+
):
|
| 138 |
+
super().__init__(
|
| 139 |
+
fn=pearson_correlation,
|
| 140 |
+
name=name,
|
| 141 |
+
dtype=dtype,
|
| 142 |
+
axis=axis,
|
| 143 |
+
)
|
| 144 |
+
self.axis = axis
|
| 145 |
+
# Metric should be maximized during optimization.
|
| 146 |
+
self._direction = "up"
|
| 147 |
+
|
| 148 |
+
def get_config(self):
|
| 149 |
+
return {
|
| 150 |
+
"name": self.name,
|
| 151 |
+
"dtype": self.dtype,
|
| 152 |
+
"axis": self.axis,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@keras_export("keras.metrics.ConcordanceCorrelation")
|
| 157 |
+
class ConcordanceCorrelation(reduction_metrics.MeanMetricWrapper):
|
| 158 |
+
"""Calculates the Concordance Correlation Coefficient (CCC).
|
| 159 |
+
|
| 160 |
+
CCC evaluates the agreement between true values (`y_true`) and predicted
|
| 161 |
+
values (`y_pred`) by considering both precision and accuracy. The
|
| 162 |
+
coefficient ranges from -1 to 1, where a value of 1 indicates perfect
|
| 163 |
+
agreement.
|
| 164 |
+
|
| 165 |
+
This metric is useful in regression tasks where it is important to assess
|
| 166 |
+
how well the predictions match the true values, taking into account both
|
| 167 |
+
their correlation and proximity to the 45-degree line of perfect
|
| 168 |
+
concordance.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
name: (Optional) string name of the metric instance.
|
| 172 |
+
dtype: (Optional) data type of the metric result.
|
| 173 |
+
axis: (Optional) integer or tuple of integers of the axis/axes along
|
| 174 |
+
which to compute the metric. Defaults to `-1`.
|
| 175 |
+
|
| 176 |
+
Example:
|
| 177 |
+
|
| 178 |
+
>>> ccc = keras.metrics.ConcordanceCorrelation(axis=-1)
|
| 179 |
+
>>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]
|
| 180 |
+
>>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]
|
| 181 |
+
>>> ccc.update_state(y_true, y_pred)
|
| 182 |
+
>>> ccc.result()
|
| 183 |
+
0.9816320385426076
|
| 184 |
+
|
| 185 |
+
Usage with `compile()` API:
|
| 186 |
+
|
| 187 |
+
```python
|
| 188 |
+
model.compile(optimizer='sgd',
|
| 189 |
+
loss='mean_squared_error',
|
| 190 |
+
metrics=[keras.metrics.ConcordanceCorrelation()])
|
| 191 |
+
```
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
def __init__(
|
| 195 |
+
self,
|
| 196 |
+
name="concordance_correlation",
|
| 197 |
+
dtype=None,
|
| 198 |
+
axis=-1,
|
| 199 |
+
):
|
| 200 |
+
super().__init__(
|
| 201 |
+
fn=concordance_correlation,
|
| 202 |
+
name=name,
|
| 203 |
+
dtype=dtype,
|
| 204 |
+
axis=axis,
|
| 205 |
+
)
|
| 206 |
+
self.axis = axis
|
| 207 |
+
# Metric should be maximized during optimization.
|
| 208 |
+
self._direction = "up"
|
| 209 |
+
|
| 210 |
+
def get_config(self):
|
| 211 |
+
return {
|
| 212 |
+
"name": self.name,
|
| 213 |
+
"dtype": self.dtype,
|
| 214 |
+
"axis": self.axis,
|
| 215 |
+
}
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/f_score_metrics.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import backend
|
| 2 |
+
from keras.src import initializers
|
| 3 |
+
from keras.src import ops
|
| 4 |
+
from keras.src.api_export import keras_export
|
| 5 |
+
from keras.src.metrics.metric import Metric
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@keras_export("keras.metrics.FBetaScore")
|
| 9 |
+
class FBetaScore(Metric):
|
| 10 |
+
"""Computes F-Beta score.
|
| 11 |
+
|
| 12 |
+
Formula:
|
| 13 |
+
|
| 14 |
+
```python
|
| 15 |
+
b2 = beta ** 2
|
| 16 |
+
f_beta_score = (1 + b2) * (precision * recall) / (precision * b2 + recall)
|
| 17 |
+
```
|
| 18 |
+
This is the weighted harmonic mean of precision and recall.
|
| 19 |
+
Its output range is `[0, 1]`. It works for both multi-class
|
| 20 |
+
and multi-label classification.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
average: Type of averaging to be performed across per-class results
|
| 24 |
+
in the multi-class case.
|
| 25 |
+
Acceptable values are `None`, `"micro"`, `"macro"` and
|
| 26 |
+
`"weighted"`. Defaults to `None`.
|
| 27 |
+
If `None`, no averaging is performed and `result()` will return
|
| 28 |
+
the score for each class.
|
| 29 |
+
If `"micro"`, compute metrics globally by counting the total
|
| 30 |
+
true positives, false negatives and false positives.
|
| 31 |
+
If `"macro"`, compute metrics for each label,
|
| 32 |
+
and return their unweighted mean.
|
| 33 |
+
This does not take label imbalance into account.
|
| 34 |
+
If `"weighted"`, compute metrics for each label,
|
| 35 |
+
and return their average weighted by support
|
| 36 |
+
(the number of true instances for each label).
|
| 37 |
+
This alters `"macro"` to account for label imbalance.
|
| 38 |
+
It can result in an score that is not between precision and recall.
|
| 39 |
+
beta: Determines the weight of given to recall
|
| 40 |
+
in the harmonic mean between precision and recall (see pseudocode
|
| 41 |
+
equation above). Defaults to `1`.
|
| 42 |
+
threshold: Elements of `y_pred` greater than `threshold` are
|
| 43 |
+
converted to be 1, and the rest 0. If `threshold` is
|
| 44 |
+
`None`, the argmax of `y_pred` is converted to 1, and the rest to 0.
|
| 45 |
+
name: Optional. String name of the metric instance.
|
| 46 |
+
dtype: Optional. Data type of the metric result.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
F-Beta Score: float.
|
| 50 |
+
|
| 51 |
+
Example:
|
| 52 |
+
|
| 53 |
+
>>> metric = keras.metrics.FBetaScore(beta=2.0, threshold=0.5)
|
| 54 |
+
>>> y_true = np.array([[1, 1, 1],
|
| 55 |
+
... [1, 0, 0],
|
| 56 |
+
... [1, 1, 0]], np.int32)
|
| 57 |
+
>>> y_pred = np.array([[0.2, 0.6, 0.7],
|
| 58 |
+
... [0.2, 0.6, 0.6],
|
| 59 |
+
... [0.6, 0.8, 0.0]], np.float32)
|
| 60 |
+
>>> metric.update_state(y_true, y_pred)
|
| 61 |
+
>>> result = metric.result()
|
| 62 |
+
>>> result
|
| 63 |
+
[0.3846154 , 0.90909094, 0.8333334 ]
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(
|
| 67 |
+
self,
|
| 68 |
+
average=None,
|
| 69 |
+
beta=1.0,
|
| 70 |
+
threshold=None,
|
| 71 |
+
name="fbeta_score",
|
| 72 |
+
dtype=None,
|
| 73 |
+
):
|
| 74 |
+
super().__init__(name=name, dtype=dtype)
|
| 75 |
+
# Metric should be maximized during optimization.
|
| 76 |
+
self._direction = "up"
|
| 77 |
+
|
| 78 |
+
if average not in (None, "micro", "macro", "weighted"):
|
| 79 |
+
raise ValueError(
|
| 80 |
+
"Invalid `average` argument value. Expected one of: "
|
| 81 |
+
"{None, 'micro', 'macro', 'weighted'}. "
|
| 82 |
+
f"Received: average={average}"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
if not isinstance(beta, float):
|
| 86 |
+
raise ValueError(
|
| 87 |
+
"Invalid `beta` argument value. "
|
| 88 |
+
"It should be a Python float. "
|
| 89 |
+
f"Received: beta={beta} of type '{type(beta)}'"
|
| 90 |
+
)
|
| 91 |
+
if beta <= 0.0:
|
| 92 |
+
raise ValueError(
|
| 93 |
+
"Invalid `beta` argument value. "
|
| 94 |
+
"It should be > 0. "
|
| 95 |
+
f"Received: beta={beta}"
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
if threshold is not None:
|
| 99 |
+
if not isinstance(threshold, float):
|
| 100 |
+
raise ValueError(
|
| 101 |
+
"Invalid `threshold` argument value. "
|
| 102 |
+
"It should be a Python float. "
|
| 103 |
+
f"Received: threshold={threshold} "
|
| 104 |
+
f"of type '{type(threshold)}'"
|
| 105 |
+
)
|
| 106 |
+
if threshold > 1.0 or threshold <= 0.0:
|
| 107 |
+
raise ValueError(
|
| 108 |
+
"Invalid `threshold` argument value. "
|
| 109 |
+
"It should verify 0 < threshold <= 1. "
|
| 110 |
+
f"Received: threshold={threshold}"
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
self.average = average
|
| 114 |
+
self.beta = beta
|
| 115 |
+
self.threshold = threshold
|
| 116 |
+
self.axis = None
|
| 117 |
+
self._built = False
|
| 118 |
+
|
| 119 |
+
if self.average != "micro":
|
| 120 |
+
self.axis = 0
|
| 121 |
+
|
| 122 |
+
def _build(self, y_true_shape, y_pred_shape):
|
| 123 |
+
if len(y_pred_shape) != 2 or len(y_true_shape) != 2:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
"FBetaScore expects 2D inputs with shape "
|
| 126 |
+
"(batch_size, output_dim). Received input "
|
| 127 |
+
f"shapes: y_pred.shape={y_pred_shape} and "
|
| 128 |
+
f"y_true.shape={y_true_shape}."
|
| 129 |
+
)
|
| 130 |
+
if y_pred_shape[-1] is None or y_true_shape[-1] is None:
|
| 131 |
+
raise ValueError(
|
| 132 |
+
"FBetaScore expects 2D inputs with shape "
|
| 133 |
+
"(batch_size, output_dim), with output_dim fully "
|
| 134 |
+
"defined (not None). Received input "
|
| 135 |
+
f"shapes: y_pred.shape={y_pred_shape} and "
|
| 136 |
+
f"y_true.shape={y_true_shape}."
|
| 137 |
+
)
|
| 138 |
+
num_classes = y_pred_shape[-1]
|
| 139 |
+
if self.average != "micro":
|
| 140 |
+
init_shape = (num_classes,)
|
| 141 |
+
else:
|
| 142 |
+
init_shape = ()
|
| 143 |
+
|
| 144 |
+
def _add_zeros_variable(name):
|
| 145 |
+
return self.add_variable(
|
| 146 |
+
name=name,
|
| 147 |
+
shape=init_shape,
|
| 148 |
+
initializer=initializers.Zeros(),
|
| 149 |
+
dtype=self.dtype,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
self.true_positives = _add_zeros_variable("true_positives")
|
| 153 |
+
self.false_positives = _add_zeros_variable("false_positives")
|
| 154 |
+
self.false_negatives = _add_zeros_variable("false_negatives")
|
| 155 |
+
self.intermediate_weights = _add_zeros_variable("intermediate_weights")
|
| 156 |
+
self._built = True
|
| 157 |
+
|
| 158 |
+
def update_state(self, y_true, y_pred, sample_weight=None):
|
| 159 |
+
y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)
|
| 160 |
+
y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype)
|
| 161 |
+
if not self._built:
|
| 162 |
+
self._build(y_true.shape, y_pred.shape)
|
| 163 |
+
|
| 164 |
+
if self.threshold is None:
|
| 165 |
+
threshold = ops.max(y_pred, axis=-1, keepdims=True)
|
| 166 |
+
# make sure [0, 0, 0] doesn't become [1, 1, 1]
|
| 167 |
+
# Use abs(x) > eps, instead of x != 0 to check for zero
|
| 168 |
+
y_pred = ops.logical_and(
|
| 169 |
+
y_pred >= threshold, ops.abs(y_pred) > 1e-9
|
| 170 |
+
)
|
| 171 |
+
else:
|
| 172 |
+
y_pred = y_pred > self.threshold
|
| 173 |
+
|
| 174 |
+
y_pred = ops.cast(y_pred, dtype=self.dtype)
|
| 175 |
+
y_true = ops.cast(y_true, dtype=self.dtype)
|
| 176 |
+
if sample_weight is not None:
|
| 177 |
+
sample_weight = ops.convert_to_tensor(
|
| 178 |
+
sample_weight, dtype=self.dtype
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
def _weighted_sum(val, sample_weight):
|
| 182 |
+
if sample_weight is not None:
|
| 183 |
+
val = ops.multiply(val, ops.expand_dims(sample_weight, 1))
|
| 184 |
+
return ops.sum(val, axis=self.axis)
|
| 185 |
+
|
| 186 |
+
self.true_positives.assign(
|
| 187 |
+
self.true_positives + _weighted_sum(y_pred * y_true, sample_weight)
|
| 188 |
+
)
|
| 189 |
+
self.false_positives.assign(
|
| 190 |
+
self.false_positives
|
| 191 |
+
+ _weighted_sum(y_pred * (1 - y_true), sample_weight)
|
| 192 |
+
)
|
| 193 |
+
self.false_negatives.assign(
|
| 194 |
+
self.false_negatives
|
| 195 |
+
+ _weighted_sum((1 - y_pred) * y_true, sample_weight)
|
| 196 |
+
)
|
| 197 |
+
self.intermediate_weights.assign(
|
| 198 |
+
self.intermediate_weights + _weighted_sum(y_true, sample_weight)
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def result(self):
|
| 202 |
+
precision = ops.divide(
|
| 203 |
+
self.true_positives,
|
| 204 |
+
self.true_positives + self.false_positives + backend.epsilon(),
|
| 205 |
+
)
|
| 206 |
+
recall = ops.divide(
|
| 207 |
+
self.true_positives,
|
| 208 |
+
self.true_positives + self.false_negatives + backend.epsilon(),
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
precision = ops.convert_to_tensor(precision, dtype=self.dtype)
|
| 212 |
+
recall = ops.convert_to_tensor(recall, dtype=self.dtype)
|
| 213 |
+
|
| 214 |
+
mul_value = precision * recall
|
| 215 |
+
add_value = ((self.beta**2) * precision) + recall
|
| 216 |
+
mean = ops.divide(mul_value, add_value + backend.epsilon())
|
| 217 |
+
f1_score = mean * (1 + (self.beta**2))
|
| 218 |
+
|
| 219 |
+
if self.average == "weighted":
|
| 220 |
+
weights = ops.divide(
|
| 221 |
+
self.intermediate_weights,
|
| 222 |
+
ops.sum(self.intermediate_weights) + backend.epsilon(),
|
| 223 |
+
)
|
| 224 |
+
f1_score = ops.sum(f1_score * weights)
|
| 225 |
+
|
| 226 |
+
elif self.average is not None: # [micro, macro]
|
| 227 |
+
f1_score = ops.mean(f1_score)
|
| 228 |
+
|
| 229 |
+
return f1_score
|
| 230 |
+
|
| 231 |
+
def get_config(self):
|
| 232 |
+
"""Returns the serializable config of the metric."""
|
| 233 |
+
|
| 234 |
+
config = {
|
| 235 |
+
"name": self.name,
|
| 236 |
+
"dtype": self.dtype,
|
| 237 |
+
"average": self.average,
|
| 238 |
+
"beta": self.beta,
|
| 239 |
+
"threshold": self.threshold,
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
base_config = super().get_config()
|
| 243 |
+
return {**base_config, **config}
|
| 244 |
+
|
| 245 |
+
def reset_state(self):
|
| 246 |
+
for v in self.variables:
|
| 247 |
+
v.assign(ops.zeros(v.shape, dtype=v.dtype))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@keras_export("keras.metrics.F1Score")
|
| 251 |
+
class F1Score(FBetaScore):
|
| 252 |
+
r"""Computes F-1 Score.
|
| 253 |
+
|
| 254 |
+
Formula:
|
| 255 |
+
|
| 256 |
+
```python
|
| 257 |
+
f1_score = 2 * (precision * recall) / (precision + recall)
|
| 258 |
+
```
|
| 259 |
+
This is the harmonic mean of precision and recall.
|
| 260 |
+
Its output range is `[0, 1]`. It works for both multi-class
|
| 261 |
+
and multi-label classification.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
average: Type of averaging to be performed on data.
|
| 265 |
+
Acceptable values are `None`, `"micro"`, `"macro"`
|
| 266 |
+
and `"weighted"`. Defaults to `None`.
|
| 267 |
+
If `None`, no averaging is performed and `result()` will return
|
| 268 |
+
the score for each class.
|
| 269 |
+
If `"micro"`, compute metrics globally by counting the total
|
| 270 |
+
true positives, false negatives and false positives.
|
| 271 |
+
If `"macro"`, compute metrics for each label,
|
| 272 |
+
and return their unweighted mean.
|
| 273 |
+
This does not take label imbalance into account.
|
| 274 |
+
If `"weighted"`, compute metrics for each label,
|
| 275 |
+
and return their average weighted by support
|
| 276 |
+
(the number of true instances for each label).
|
| 277 |
+
This alters `"macro"` to account for label imbalance.
|
| 278 |
+
It can result in an score that is not between precision and recall.
|
| 279 |
+
threshold: Elements of `y_pred` greater than `threshold` are
|
| 280 |
+
converted to be 1, and the rest 0. If `threshold` is
|
| 281 |
+
`None`, the argmax of `y_pred` is converted to 1, and the rest to 0.
|
| 282 |
+
name: Optional. String name of the metric instance.
|
| 283 |
+
dtype: Optional. Data type of the metric result.
|
| 284 |
+
|
| 285 |
+
Returns:
|
| 286 |
+
F-1 Score: float.
|
| 287 |
+
|
| 288 |
+
Example:
|
| 289 |
+
|
| 290 |
+
>>> metric = keras.metrics.F1Score(threshold=0.5)
|
| 291 |
+
>>> y_true = np.array([[1, 1, 1],
|
| 292 |
+
... [1, 0, 0],
|
| 293 |
+
... [1, 1, 0]], np.int32)
|
| 294 |
+
>>> y_pred = np.array([[0.2, 0.6, 0.7],
|
| 295 |
+
... [0.2, 0.6, 0.6],
|
| 296 |
+
... [0.6, 0.8, 0.0]], np.float32)
|
| 297 |
+
>>> metric.update_state(y_true, y_pred)
|
| 298 |
+
>>> result = metric.result()
|
| 299 |
+
array([0.5 , 0.8 , 0.6666667], dtype=float32)
|
| 300 |
+
"""
|
| 301 |
+
|
| 302 |
+
def __init__(
|
| 303 |
+
self,
|
| 304 |
+
average=None,
|
| 305 |
+
threshold=None,
|
| 306 |
+
name="f1_score",
|
| 307 |
+
dtype=None,
|
| 308 |
+
):
|
| 309 |
+
super().__init__(
|
| 310 |
+
average=average,
|
| 311 |
+
beta=1.0,
|
| 312 |
+
threshold=threshold,
|
| 313 |
+
name=name,
|
| 314 |
+
dtype=dtype,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
def get_config(self):
|
| 318 |
+
base_config = super().get_config()
|
| 319 |
+
del base_config["beta"]
|
| 320 |
+
return base_config
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/hinge_metrics.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src.api_export import keras_export
|
| 2 |
+
from keras.src.losses.losses import categorical_hinge
|
| 3 |
+
from keras.src.losses.losses import hinge
|
| 4 |
+
from keras.src.losses.losses import squared_hinge
|
| 5 |
+
from keras.src.metrics import reduction_metrics
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@keras_export("keras.metrics.Hinge")
|
| 9 |
+
class Hinge(reduction_metrics.MeanMetricWrapper):
|
| 10 |
+
"""Computes the hinge metric between `y_true` and `y_pred`.
|
| 11 |
+
|
| 12 |
+
`y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
|
| 13 |
+
provided we will convert them to -1 or 1.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
name: (Optional) string name of the metric instance.
|
| 17 |
+
dtype: (Optional) data type of the metric result.
|
| 18 |
+
|
| 19 |
+
Examples:
|
| 20 |
+
|
| 21 |
+
>>> m = keras.metrics.Hinge()
|
| 22 |
+
>>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
|
| 23 |
+
>>> m.result()
|
| 24 |
+
1.3
|
| 25 |
+
>>> m.reset_state()
|
| 26 |
+
>>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
|
| 27 |
+
... sample_weight=[1, 0])
|
| 28 |
+
>>> m.result()
|
| 29 |
+
1.1
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, name="hinge", dtype=None):
|
| 33 |
+
super().__init__(fn=hinge, name=name, dtype=dtype)
|
| 34 |
+
# Metric should be minimized during optimization.
|
| 35 |
+
self._direction = "down"
|
| 36 |
+
|
| 37 |
+
def get_config(self):
|
| 38 |
+
return {"name": self.name, "dtype": self.dtype}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@keras_export("keras.metrics.SquaredHinge")
|
| 42 |
+
class SquaredHinge(reduction_metrics.MeanMetricWrapper):
|
| 43 |
+
"""Computes the hinge metric between `y_true` and `y_pred`.
|
| 44 |
+
|
| 45 |
+
`y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
|
| 46 |
+
provided we will convert them to -1 or 1.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
name: (Optional) string name of the metric instance.
|
| 50 |
+
dtype: (Optional) data type of the metric result.
|
| 51 |
+
|
| 52 |
+
Example:
|
| 53 |
+
|
| 54 |
+
>>> m = keras.metrics.SquaredHinge()
|
| 55 |
+
>>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
|
| 56 |
+
>>> m.result()
|
| 57 |
+
1.86
|
| 58 |
+
>>> m.reset_state()
|
| 59 |
+
>>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
|
| 60 |
+
... sample_weight=[1, 0])
|
| 61 |
+
>>> m.result()
|
| 62 |
+
1.46
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def __init__(self, name="squared_hinge", dtype=None):
|
| 66 |
+
super().__init__(fn=squared_hinge, name=name, dtype=dtype)
|
| 67 |
+
# Metric should be minimized during optimization.
|
| 68 |
+
self._direction = "down"
|
| 69 |
+
|
| 70 |
+
def get_config(self):
|
| 71 |
+
return {"name": self.name, "dtype": self.dtype}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@keras_export("keras.metrics.CategoricalHinge")
|
| 75 |
+
class CategoricalHinge(reduction_metrics.MeanMetricWrapper):
|
| 76 |
+
"""Computes the categorical hinge metric between `y_true` and `y_pred`.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
name: (Optional) string name of the metric instance.
|
| 80 |
+
dtype: (Optional) data type of the metric result.
|
| 81 |
+
|
| 82 |
+
Example:
|
| 83 |
+
>>> m = keras.metrics.CategoricalHinge()
|
| 84 |
+
>>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
|
| 85 |
+
>>> m.result().numpy()
|
| 86 |
+
1.4000001
|
| 87 |
+
>>> m.reset_state()
|
| 88 |
+
>>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
|
| 89 |
+
... sample_weight=[1, 0])
|
| 90 |
+
>>> m.result()
|
| 91 |
+
1.2
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
def __init__(self, name="categorical_hinge", dtype=None):
|
| 95 |
+
super().__init__(fn=categorical_hinge, name=name, dtype=dtype)
|
| 96 |
+
# Metric should be minimized during optimization.
|
| 97 |
+
self._direction = "down"
|
| 98 |
+
|
| 99 |
+
def get_config(self):
|
| 100 |
+
return {"name": self.name, "dtype": self.dtype}
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/iou_metrics.py
ADDED
|
@@ -0,0 +1,762 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
from keras.src import backend
|
| 4 |
+
from keras.src import initializers
|
| 5 |
+
from keras.src import ops
|
| 6 |
+
from keras.src.api_export import keras_export
|
| 7 |
+
from keras.src.metrics.metric import Metric
|
| 8 |
+
from keras.src.metrics.metrics_utils import confusion_matrix
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class _IoUBase(Metric):
|
| 12 |
+
"""Computes the confusion matrix for Intersection-Over-Union metrics.
|
| 13 |
+
|
| 14 |
+
Formula:
|
| 15 |
+
|
| 16 |
+
```python
|
| 17 |
+
iou = true_positives / (true_positives + false_positives + false_negatives)
|
| 18 |
+
```
|
| 19 |
+
Intersection-Over-Union is a common evaluation metric for semantic image
|
| 20 |
+
segmentation.
|
| 21 |
+
|
| 22 |
+
From IoUs of individual classes, the MeanIoU can be computed as the mean of
|
| 23 |
+
the individual IoUs.
|
| 24 |
+
|
| 25 |
+
To compute IoUs, the predictions are accumulated in a confusion matrix,
|
| 26 |
+
weighted by `sample_weight` and the metric is then calculated from it.
|
| 27 |
+
|
| 28 |
+
If `sample_weight` is `None`, weights default to 1.
|
| 29 |
+
Use `sample_weight` of 0 to mask values.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
num_classes: The possible number of labels the prediction task can have.
|
| 33 |
+
name: (Optional) string name of the metric instance.
|
| 34 |
+
dtype: (Optional) data type of the metric result.
|
| 35 |
+
ignore_class: Optional integer. The ID of a class to be ignored during
|
| 36 |
+
metric computation. This is useful, for example, in segmentation
|
| 37 |
+
problems featuring a "void" class (commonly -1 or 255) in
|
| 38 |
+
segmentation maps. By default (`ignore_class=None`), all classes are
|
| 39 |
+
considered.
|
| 40 |
+
sparse_y_true: Whether labels are encoded using integers or
|
| 41 |
+
dense floating point vectors. If `False`, the `argmax` function
|
| 42 |
+
is used to determine each sample's most likely associated label.
|
| 43 |
+
sparse_y_pred: Whether predictions are encoded using integers or
|
| 44 |
+
dense floating point vectors. If `False`, the `argmax` function
|
| 45 |
+
is used to determine each sample's most likely associated label.
|
| 46 |
+
axis: (Optional) -1 is the dimension containing the logits.
|
| 47 |
+
Defaults to `-1`.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
num_classes,
|
| 53 |
+
name=None,
|
| 54 |
+
dtype=None,
|
| 55 |
+
ignore_class=None,
|
| 56 |
+
sparse_y_true=True,
|
| 57 |
+
sparse_y_pred=True,
|
| 58 |
+
axis=-1,
|
| 59 |
+
):
|
| 60 |
+
# defaulting to int to avoid issues with confusion matrix
|
| 61 |
+
super().__init__(name=name, dtype=dtype or "int")
|
| 62 |
+
# Metric should be maximized during optimization.
|
| 63 |
+
self._direction = "up"
|
| 64 |
+
self.num_classes = num_classes
|
| 65 |
+
self.ignore_class = ignore_class
|
| 66 |
+
self.sparse_y_true = sparse_y_true
|
| 67 |
+
self.sparse_y_pred = sparse_y_pred
|
| 68 |
+
self.axis = axis
|
| 69 |
+
|
| 70 |
+
self.total_cm = self.add_variable(
|
| 71 |
+
name="total_confusion_matrix",
|
| 72 |
+
shape=(num_classes, num_classes),
|
| 73 |
+
initializer=initializers.Zeros(),
|
| 74 |
+
dtype=self.dtype,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def update_state(self, y_true, y_pred, sample_weight=None):
|
| 78 |
+
"""Accumulates the confusion matrix statistics.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
y_true: The ground truth values.
|
| 82 |
+
y_pred: The predicted values.
|
| 83 |
+
sample_weight: Optional weighting of each example. Can
|
| 84 |
+
be a `Tensor` whose rank is either 0, or the same as `y_true`,
|
| 85 |
+
and must be broadcastable to `y_true`. Defaults to `1`.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Update op.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
if not self.sparse_y_true:
|
| 92 |
+
y_true = ops.argmax(y_true, axis=self.axis)
|
| 93 |
+
if not self.sparse_y_pred:
|
| 94 |
+
y_pred = ops.argmax(y_pred, axis=self.axis)
|
| 95 |
+
|
| 96 |
+
y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)
|
| 97 |
+
y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype)
|
| 98 |
+
|
| 99 |
+
# Flatten the input if its rank > 1.
|
| 100 |
+
if len(y_pred.shape) > 1:
|
| 101 |
+
y_pred = ops.reshape(y_pred, [-1])
|
| 102 |
+
|
| 103 |
+
if len(y_true.shape) > 1:
|
| 104 |
+
y_true = ops.reshape(y_true, [-1])
|
| 105 |
+
|
| 106 |
+
if sample_weight is None:
|
| 107 |
+
sample_weight = 1
|
| 108 |
+
else:
|
| 109 |
+
if (
|
| 110 |
+
hasattr(sample_weight, "dtype")
|
| 111 |
+
and "float" in str(sample_weight.dtype)
|
| 112 |
+
and "int" in str(self.dtype)
|
| 113 |
+
):
|
| 114 |
+
warnings.warn(
|
| 115 |
+
"You are passing weight as `float`, but dtype is `int`. "
|
| 116 |
+
"This may result in an incorrect weight due to type casting"
|
| 117 |
+
" Consider using integer weights."
|
| 118 |
+
)
|
| 119 |
+
sample_weight = ops.convert_to_tensor(sample_weight, dtype=self.dtype)
|
| 120 |
+
|
| 121 |
+
if len(sample_weight.shape) > 1:
|
| 122 |
+
sample_weight = ops.reshape(sample_weight, [-1])
|
| 123 |
+
|
| 124 |
+
sample_weight = ops.broadcast_to(sample_weight, ops.shape(y_true))
|
| 125 |
+
|
| 126 |
+
if self.ignore_class is not None:
|
| 127 |
+
ignore_class = ops.convert_to_tensor(
|
| 128 |
+
self.ignore_class, y_true.dtype
|
| 129 |
+
)
|
| 130 |
+
valid_mask = ops.not_equal(y_true, ignore_class)
|
| 131 |
+
y_true = y_true * ops.cast(valid_mask, y_true.dtype)
|
| 132 |
+
y_pred = y_pred * ops.cast(valid_mask, y_pred.dtype)
|
| 133 |
+
if sample_weight is not None:
|
| 134 |
+
sample_weight = sample_weight * ops.cast(
|
| 135 |
+
valid_mask, sample_weight.dtype
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
y_pred = ops.cast(y_pred, dtype=self.dtype)
|
| 139 |
+
y_true = ops.cast(y_true, dtype=self.dtype)
|
| 140 |
+
sample_weight = ops.cast(sample_weight, dtype=self.dtype)
|
| 141 |
+
|
| 142 |
+
current_cm = confusion_matrix(
|
| 143 |
+
y_true,
|
| 144 |
+
y_pred,
|
| 145 |
+
self.num_classes,
|
| 146 |
+
weights=sample_weight,
|
| 147 |
+
dtype=self.dtype,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
return self.total_cm.assign(self.total_cm + current_cm)
|
| 151 |
+
|
| 152 |
+
def reset_state(self):
|
| 153 |
+
self.total_cm.assign(
|
| 154 |
+
ops.zeros(self.total_cm.shape, dtype=self.total_cm.dtype)
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@keras_export("keras.metrics.IoU")
|
| 159 |
+
class IoU(_IoUBase):
|
| 160 |
+
"""Computes the Intersection-Over-Union metric for specific target classes.
|
| 161 |
+
|
| 162 |
+
Formula:
|
| 163 |
+
|
| 164 |
+
```python
|
| 165 |
+
iou = true_positives / (true_positives + false_positives + false_negatives)
|
| 166 |
+
```
|
| 167 |
+
Intersection-Over-Union is a common evaluation metric for semantic image
|
| 168 |
+
segmentation.
|
| 169 |
+
|
| 170 |
+
To compute IoUs, the predictions are accumulated in a confusion matrix,
|
| 171 |
+
weighted by `sample_weight` and the metric is then calculated from it.
|
| 172 |
+
|
| 173 |
+
If `sample_weight` is `None`, weights default to 1.
|
| 174 |
+
Use `sample_weight` of 0 to mask values.
|
| 175 |
+
|
| 176 |
+
Note, this class first computes IoUs for all individual classes, then
|
| 177 |
+
returns the mean of IoUs for the classes that are specified by
|
| 178 |
+
`target_class_ids`. If `target_class_ids` has only one id value, the IoU of
|
| 179 |
+
that specific class is returned.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
num_classes: The possible number of labels the prediction task can have.
|
| 183 |
+
target_class_ids: A tuple or list of target class ids for which the
|
| 184 |
+
metric is returned. To compute IoU for a specific class, a list
|
| 185 |
+
(or tuple) of a single id value should be provided.
|
| 186 |
+
name: (Optional) string name of the metric instance.
|
| 187 |
+
dtype: (Optional) data type of the metric result.
|
| 188 |
+
ignore_class: Optional integer. The ID of a class to be ignored during
|
| 189 |
+
metric computation. This is useful, for example, in segmentation
|
| 190 |
+
problems featuring a "void" class (commonly -1 or 255) in
|
| 191 |
+
segmentation maps. By default (`ignore_class=None`), all classes are
|
| 192 |
+
considered.
|
| 193 |
+
sparse_y_true: Whether labels are encoded using integers or
|
| 194 |
+
dense floating point vectors. If `False`, the `argmax` function
|
| 195 |
+
is used to determine each sample's most likely associated label.
|
| 196 |
+
sparse_y_pred: Whether predictions are encoded using integers or
|
| 197 |
+
dense floating point vectors. If `False`, the `argmax` function
|
| 198 |
+
is used to determine each sample's most likely associated label.
|
| 199 |
+
axis: (Optional) -1 is the dimension containing the logits.
|
| 200 |
+
Defaults to `-1`.
|
| 201 |
+
|
| 202 |
+
Examples:
|
| 203 |
+
|
| 204 |
+
>>> # cm = [[1, 1],
|
| 205 |
+
>>> # [1, 1]]
|
| 206 |
+
>>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
|
| 207 |
+
>>> # iou = true_positives / (sum_row + sum_col - true_positives))
|
| 208 |
+
>>> # iou = [0.33, 0.33]
|
| 209 |
+
>>> m = keras.metrics.IoU(num_classes=2, target_class_ids=[0])
|
| 210 |
+
>>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
|
| 211 |
+
>>> m.result()
|
| 212 |
+
0.33333334
|
| 213 |
+
|
| 214 |
+
>>> m.reset_state()
|
| 215 |
+
>>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
|
| 216 |
+
... sample_weight=[0.3, 0.3, 0.3, 0.1])
|
| 217 |
+
>>> # cm = [[0.3, 0.3],
|
| 218 |
+
>>> # [0.3, 0.1]]
|
| 219 |
+
>>> # sum_row = [0.6, 0.4], sum_col = [0.6, 0.4],
|
| 220 |
+
>>> # true_positives = [0.3, 0.1]
|
| 221 |
+
>>> # iou = [0.33, 0.14]
|
| 222 |
+
>>> m.result()
|
| 223 |
+
0.33333334
|
| 224 |
+
|
| 225 |
+
Usage with `compile()` API:
|
| 226 |
+
|
| 227 |
+
```python
|
| 228 |
+
model.compile(
|
| 229 |
+
optimizer='sgd',
|
| 230 |
+
loss='mse',
|
| 231 |
+
metrics=[keras.metrics.IoU(num_classes=2, target_class_ids=[0])])
|
| 232 |
+
```
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
+
def __init__(
|
| 236 |
+
self,
|
| 237 |
+
num_classes,
|
| 238 |
+
target_class_ids,
|
| 239 |
+
name=None,
|
| 240 |
+
dtype=None,
|
| 241 |
+
ignore_class=None,
|
| 242 |
+
sparse_y_true=True,
|
| 243 |
+
sparse_y_pred=True,
|
| 244 |
+
axis=-1,
|
| 245 |
+
):
|
| 246 |
+
super().__init__(
|
| 247 |
+
name=name,
|
| 248 |
+
num_classes=num_classes,
|
| 249 |
+
ignore_class=ignore_class,
|
| 250 |
+
sparse_y_true=sparse_y_true,
|
| 251 |
+
sparse_y_pred=sparse_y_pred,
|
| 252 |
+
axis=axis,
|
| 253 |
+
dtype=dtype,
|
| 254 |
+
)
|
| 255 |
+
if max(target_class_ids) >= num_classes:
|
| 256 |
+
raise ValueError(
|
| 257 |
+
f"Target class id {max(target_class_ids)} "
|
| 258 |
+
"is out of range, which is "
|
| 259 |
+
f"[{0}, {num_classes})."
|
| 260 |
+
)
|
| 261 |
+
self.target_class_ids = list(target_class_ids)
|
| 262 |
+
|
| 263 |
+
def result(self):
|
| 264 |
+
"""Compute the intersection-over-union via the confusion matrix."""
|
| 265 |
+
sum_over_row = ops.cast(
|
| 266 |
+
ops.sum(self.total_cm, axis=0), dtype=self.dtype
|
| 267 |
+
)
|
| 268 |
+
sum_over_col = ops.cast(
|
| 269 |
+
ops.sum(self.total_cm, axis=1), dtype=self.dtype
|
| 270 |
+
)
|
| 271 |
+
true_positives = ops.cast(ops.diag(self.total_cm), dtype=self.dtype)
|
| 272 |
+
|
| 273 |
+
# sum_over_row + sum_over_col =
|
| 274 |
+
# 2 * true_positives + false_positives + false_negatives.
|
| 275 |
+
denominator = sum_over_row + sum_over_col - true_positives
|
| 276 |
+
|
| 277 |
+
target_class_ids = ops.convert_to_tensor(
|
| 278 |
+
self.target_class_ids, dtype="int32"
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# Only keep the target classes
|
| 282 |
+
true_positives = ops.take_along_axis(
|
| 283 |
+
true_positives, target_class_ids, axis=-1
|
| 284 |
+
)
|
| 285 |
+
denominator = ops.take_along_axis(
|
| 286 |
+
denominator, target_class_ids, axis=-1
|
| 287 |
+
)
|
| 288 |
+
denominator = ops.cast(denominator, dtype="float32")
|
| 289 |
+
|
| 290 |
+
# If the denominator is 0, we need to ignore the class.
|
| 291 |
+
num_valid_entries = ops.sum(
|
| 292 |
+
ops.cast(ops.greater(denominator, 1e-9), dtype="float32")
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
iou = ops.divide(true_positives, denominator + backend.epsilon())
|
| 296 |
+
|
| 297 |
+
return ops.divide(
|
| 298 |
+
ops.sum(iou, axis=self.axis), num_valid_entries + backend.epsilon()
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
def get_config(self):
|
| 302 |
+
config = {
|
| 303 |
+
"num_classes": self.num_classes,
|
| 304 |
+
"target_class_ids": self.target_class_ids,
|
| 305 |
+
"ignore_class": self.ignore_class,
|
| 306 |
+
"sparse_y_true": self.sparse_y_true,
|
| 307 |
+
"sparse_y_pred": self.sparse_y_pred,
|
| 308 |
+
"axis": self.axis,
|
| 309 |
+
}
|
| 310 |
+
base_config = super().get_config()
|
| 311 |
+
return dict(list(base_config.items()) + list(config.items()))
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
@keras_export("keras.metrics.BinaryIoU")
|
| 315 |
+
class BinaryIoU(IoU):
|
| 316 |
+
"""Computes the Intersection-Over-Union metric for class 0 and/or 1.
|
| 317 |
+
|
| 318 |
+
Formula:
|
| 319 |
+
|
| 320 |
+
```python
|
| 321 |
+
iou = true_positives / (true_positives + false_positives + false_negatives)
|
| 322 |
+
```
|
| 323 |
+
Intersection-Over-Union is a common evaluation metric for semantic image
|
| 324 |
+
segmentation.
|
| 325 |
+
|
| 326 |
+
To compute IoUs, the predictions are accumulated in a confusion matrix,
|
| 327 |
+
weighted by `sample_weight` and the metric is then calculated from it.
|
| 328 |
+
|
| 329 |
+
If `sample_weight` is `None`, weights default to 1.
|
| 330 |
+
Use `sample_weight` of 0 to mask values.
|
| 331 |
+
|
| 332 |
+
This class can be used to compute IoUs for a binary classification task
|
| 333 |
+
where the predictions are provided as logits. First a `threshold` is applied
|
| 334 |
+
to the predicted values such that those that are below the `threshold` are
|
| 335 |
+
converted to class 0 and those that are above the `threshold` are converted
|
| 336 |
+
to class 1.
|
| 337 |
+
|
| 338 |
+
IoUs for classes 0 and 1 are then computed, the mean of IoUs for the classes
|
| 339 |
+
that are specified by `target_class_ids` is returned.
|
| 340 |
+
|
| 341 |
+
Note: with `threshold=0`, this metric has the same behavior as `IoU`.
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
target_class_ids: A tuple or list of target class ids for which the
|
| 345 |
+
metric is returned. Options are `[0]`, `[1]`, or `[0, 1]`. With
|
| 346 |
+
`[0]` (or `[1]`), the IoU metric for class 0 (or class 1,
|
| 347 |
+
respectively) is returned. With `[0, 1]`, the mean of IoUs for the
|
| 348 |
+
two classes is returned.
|
| 349 |
+
threshold: A threshold that applies to the prediction logits to convert
|
| 350 |
+
them to either predicted class 0 if the logit is below `threshold`
|
| 351 |
+
or predicted class 1 if the logit is above `threshold`.
|
| 352 |
+
name: (Optional) string name of the metric instance.
|
| 353 |
+
dtype: (Optional) data type of the metric result.
|
| 354 |
+
|
| 355 |
+
Example:
|
| 356 |
+
|
| 357 |
+
Example:
|
| 358 |
+
|
| 359 |
+
>>> m = keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3)
|
| 360 |
+
>>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7])
|
| 361 |
+
>>> m.result()
|
| 362 |
+
0.33333334
|
| 363 |
+
|
| 364 |
+
>>> m.reset_state()
|
| 365 |
+
>>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7],
|
| 366 |
+
... sample_weight=[0.2, 0.3, 0.4, 0.1])
|
| 367 |
+
>>> # cm = [[0.2, 0.4],
|
| 368 |
+
>>> # [0.3, 0.1]]
|
| 369 |
+
>>> # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5],
|
| 370 |
+
>>> # true_positives = [0.2, 0.1]
|
| 371 |
+
>>> # iou = [0.222, 0.125]
|
| 372 |
+
>>> m.result()
|
| 373 |
+
0.17361112
|
| 374 |
+
|
| 375 |
+
Usage with `compile()` API:
|
| 376 |
+
|
| 377 |
+
```python
|
| 378 |
+
model.compile(
|
| 379 |
+
optimizer='sgd',
|
| 380 |
+
loss='mse',
|
| 381 |
+
metrics=[keras.metrics.BinaryIoU(
|
| 382 |
+
target_class_ids=[0],
|
| 383 |
+
threshold=0.5
|
| 384 |
+
)]
|
| 385 |
+
)
|
| 386 |
+
```
|
| 387 |
+
"""
|
| 388 |
+
|
| 389 |
+
def __init__(
|
| 390 |
+
self,
|
| 391 |
+
target_class_ids=(0, 1),
|
| 392 |
+
threshold=0.5,
|
| 393 |
+
name=None,
|
| 394 |
+
dtype=None,
|
| 395 |
+
):
|
| 396 |
+
super().__init__(
|
| 397 |
+
num_classes=2,
|
| 398 |
+
target_class_ids=target_class_ids,
|
| 399 |
+
name=name,
|
| 400 |
+
dtype=dtype,
|
| 401 |
+
)
|
| 402 |
+
self.threshold = threshold
|
| 403 |
+
|
| 404 |
+
def update_state(self, y_true, y_pred, sample_weight=None):
|
| 405 |
+
"""Accumulates the confusion matrix statistics.
|
| 406 |
+
|
| 407 |
+
Before the confusion matrix is updated, the predicted values are
|
| 408 |
+
thresholded to be:
|
| 409 |
+
0 for values that are smaller than the `threshold`
|
| 410 |
+
1 for values that are larger or equal to the `threshold`
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
y_true: The ground truth values.
|
| 414 |
+
y_pred: The predicted values.
|
| 415 |
+
sample_weight: Optional weighting of each example. Can
|
| 416 |
+
be a `Tensor` whose rank is either 0, or the same as `y_true`,
|
| 417 |
+
and must be broadcastable to `y_true`. Defaults to `1`.
|
| 418 |
+
|
| 419 |
+
Returns:
|
| 420 |
+
Update op.
|
| 421 |
+
"""
|
| 422 |
+
y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)
|
| 423 |
+
# convert y_pred on float 32 and cast just after to dtype
|
| 424 |
+
y_pred = ops.convert_to_tensor(y_pred, dtype="float32")
|
| 425 |
+
y_pred = ops.cast(y_pred >= self.threshold, self.dtype)
|
| 426 |
+
return super().update_state(y_true, y_pred, sample_weight)
|
| 427 |
+
|
| 428 |
+
def get_config(self):
|
| 429 |
+
return {
|
| 430 |
+
"target_class_ids": self.target_class_ids,
|
| 431 |
+
"threshold": self.threshold,
|
| 432 |
+
"name": self.name,
|
| 433 |
+
"dtype": self._dtype,
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
@keras_export("keras.metrics.MeanIoU")
|
| 438 |
+
class MeanIoU(IoU):
|
| 439 |
+
"""Computes the mean Intersection-Over-Union metric.
|
| 440 |
+
|
| 441 |
+
Formula:
|
| 442 |
+
|
| 443 |
+
```python
|
| 444 |
+
iou = true_positives / (true_positives + false_positives + false_negatives)
|
| 445 |
+
```
|
| 446 |
+
Intersection-Over-Union is a common evaluation metric for semantic image
|
| 447 |
+
segmentation.
|
| 448 |
+
|
| 449 |
+
To compute IoUs, the predictions are accumulated in a confusion matrix,
|
| 450 |
+
weighted by `sample_weight` and the metric is then calculated from it.
|
| 451 |
+
|
| 452 |
+
If `sample_weight` is `None`, weights default to 1.
|
| 453 |
+
Use `sample_weight` of 0 to mask values.
|
| 454 |
+
|
| 455 |
+
Note that this class first computes IoUs for all individual classes, then
|
| 456 |
+
returns the mean of these values.
|
| 457 |
+
|
| 458 |
+
Args:
|
| 459 |
+
num_classes: The possible number of labels the prediction task can have.
|
| 460 |
+
This value must be provided, since a confusion matrix of dimension =
|
| 461 |
+
[num_classes, num_classes] will be allocated.
|
| 462 |
+
name: (Optional) string name of the metric instance.
|
| 463 |
+
dtype: (Optional) data type of the metric result.
|
| 464 |
+
ignore_class: Optional integer. The ID of a class to be ignored during
|
| 465 |
+
metric computation. This is useful, for example, in segmentation
|
| 466 |
+
problems featuring a "void" class (commonly -1 or 255) in
|
| 467 |
+
segmentation maps. By default (`ignore_class=None`), all classes are
|
| 468 |
+
considered.
|
| 469 |
+
sparse_y_true: Whether labels are encoded using integers or
|
| 470 |
+
dense floating point vectors. If `False`, the `argmax` function
|
| 471 |
+
is used to determine each sample's most likely associated label.
|
| 472 |
+
sparse_y_pred: Whether predictions are encoded using integers or
|
| 473 |
+
dense floating point vectors. If `False`, the `argmax` function
|
| 474 |
+
is used to determine each sample's most likely associated label.
|
| 475 |
+
axis: (Optional) The dimension containing the logits. Defaults to `-1`.
|
| 476 |
+
|
| 477 |
+
Example:
|
| 478 |
+
|
| 479 |
+
Example:
|
| 480 |
+
|
| 481 |
+
>>> # cm = [[1, 1],
|
| 482 |
+
>>> # [1, 1]]
|
| 483 |
+
>>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
|
| 484 |
+
>>> # iou = true_positives / (sum_row + sum_col - true_positives))
|
| 485 |
+
>>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33
|
| 486 |
+
>>> m = keras.metrics.MeanIoU(num_classes=2)
|
| 487 |
+
>>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
|
| 488 |
+
>>> m.result()
|
| 489 |
+
0.33333334
|
| 490 |
+
|
| 491 |
+
>>> m.reset_state()
|
| 492 |
+
>>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
|
| 493 |
+
... sample_weight=[0.3, 0.3, 0.3, 0.1])
|
| 494 |
+
>>> m.result().numpy()
|
| 495 |
+
0.23809525
|
| 496 |
+
|
| 497 |
+
Usage with `compile()` API:
|
| 498 |
+
|
| 499 |
+
```python
|
| 500 |
+
model.compile(
|
| 501 |
+
optimizer='sgd',
|
| 502 |
+
loss='mse',
|
| 503 |
+
metrics=[keras.metrics.MeanIoU(num_classes=2)])
|
| 504 |
+
```
|
| 505 |
+
"""
|
| 506 |
+
|
| 507 |
+
def __init__(
|
| 508 |
+
self,
|
| 509 |
+
num_classes,
|
| 510 |
+
name=None,
|
| 511 |
+
dtype=None,
|
| 512 |
+
ignore_class=None,
|
| 513 |
+
sparse_y_true=True,
|
| 514 |
+
sparse_y_pred=True,
|
| 515 |
+
axis=-1,
|
| 516 |
+
):
|
| 517 |
+
target_class_ids = list(range(num_classes))
|
| 518 |
+
super().__init__(
|
| 519 |
+
name=name,
|
| 520 |
+
num_classes=num_classes,
|
| 521 |
+
target_class_ids=target_class_ids,
|
| 522 |
+
axis=axis,
|
| 523 |
+
dtype=dtype,
|
| 524 |
+
ignore_class=ignore_class,
|
| 525 |
+
sparse_y_true=sparse_y_true,
|
| 526 |
+
sparse_y_pred=sparse_y_pred,
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
def get_config(self):
|
| 530 |
+
return {
|
| 531 |
+
"num_classes": self.num_classes,
|
| 532 |
+
"name": self.name,
|
| 533 |
+
"dtype": self._dtype,
|
| 534 |
+
"ignore_class": self.ignore_class,
|
| 535 |
+
"sparse_y_true": self.sparse_y_true,
|
| 536 |
+
"sparse_y_pred": self.sparse_y_pred,
|
| 537 |
+
"axis": self.axis,
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
@keras_export("keras.metrics.OneHotIoU")
|
| 542 |
+
class OneHotIoU(IoU):
|
| 543 |
+
"""Computes the Intersection-Over-Union metric for one-hot encoded labels.
|
| 544 |
+
|
| 545 |
+
Formula:
|
| 546 |
+
|
| 547 |
+
```python
|
| 548 |
+
iou = true_positives / (true_positives + false_positives + false_negatives)
|
| 549 |
+
```
|
| 550 |
+
Intersection-Over-Union is a common evaluation metric for semantic image
|
| 551 |
+
segmentation.
|
| 552 |
+
|
| 553 |
+
To compute IoUs, the predictions are accumulated in a confusion matrix,
|
| 554 |
+
weighted by `sample_weight` and the metric is then calculated from it.
|
| 555 |
+
|
| 556 |
+
If `sample_weight` is `None`, weights default to 1.
|
| 557 |
+
Use `sample_weight` of 0 to mask values.
|
| 558 |
+
|
| 559 |
+
This class can be used to compute IoU for multi-class classification tasks
|
| 560 |
+
where the labels are one-hot encoded (the last axis should have one
|
| 561 |
+
dimension per class). Note that the predictions should also have the same
|
| 562 |
+
shape. To compute the IoU, first the labels and predictions are converted
|
| 563 |
+
back into integer format by taking the argmax over the class axis. Then the
|
| 564 |
+
same computation steps as for the base `IoU` class apply.
|
| 565 |
+
|
| 566 |
+
Note, if there is only one channel in the labels and predictions, this class
|
| 567 |
+
is the same as class `IoU`. In this case, use `IoU` instead.
|
| 568 |
+
|
| 569 |
+
Also, make sure that `num_classes` is equal to the number of classes in the
|
| 570 |
+
data, to avoid a "labels out of bound" error when the confusion matrix is
|
| 571 |
+
computed.
|
| 572 |
+
|
| 573 |
+
Args:
|
| 574 |
+
num_classes: The possible number of labels the prediction task can have.
|
| 575 |
+
target_class_ids: A tuple or list of target class ids for which the
|
| 576 |
+
metric is returned. To compute IoU for a specific class, a list
|
| 577 |
+
(or tuple) of a single id value should be provided.
|
| 578 |
+
name: (Optional) string name of the metric instance.
|
| 579 |
+
dtype: (Optional) data type of the metric result.
|
| 580 |
+
ignore_class: Optional integer. The ID of a class to be ignored during
|
| 581 |
+
metric computation. This is useful, for example, in segmentation
|
| 582 |
+
problems featuring a "void" class (commonly -1 or 255) in
|
| 583 |
+
segmentation maps. By default (`ignore_class=None`), all classes are
|
| 584 |
+
considered.
|
| 585 |
+
sparse_y_pred: Whether predictions are encoded using integers or
|
| 586 |
+
dense floating point vectors. If `False`, the `argmax` function
|
| 587 |
+
is used to determine each sample's most likely associated label.
|
| 588 |
+
axis: (Optional) The dimension containing the logits. Defaults to `-1`.
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
Example:
|
| 592 |
+
|
| 593 |
+
>>> y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
|
| 594 |
+
>>> y_pred = np.array([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],
|
| 595 |
+
... [0.1, 0.4, 0.5]])
|
| 596 |
+
>>> sample_weight = [0.1, 0.2, 0.3, 0.4]
|
| 597 |
+
>>> m = keras.metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2])
|
| 598 |
+
>>> m.update_state(
|
| 599 |
+
... y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
|
| 600 |
+
>>> # cm = [[0, 0, 0.2+0.4],
|
| 601 |
+
>>> # [0.3, 0, 0],
|
| 602 |
+
>>> # [0, 0, 0.1]]
|
| 603 |
+
>>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]
|
| 604 |
+
>>> # true_positives = [0, 0, 0.1]
|
| 605 |
+
>>> # single_iou = true_positives / (sum_row + sum_col - true_positives))
|
| 606 |
+
>>> # mean_iou = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2
|
| 607 |
+
>>> m.result()
|
| 608 |
+
0.071
|
| 609 |
+
|
| 610 |
+
Usage with `compile()` API:
|
| 611 |
+
|
| 612 |
+
```python
|
| 613 |
+
model.compile(
|
| 614 |
+
optimizer='sgd',
|
| 615 |
+
loss='mse',
|
| 616 |
+
metrics=[keras.metrics.OneHotIoU(
|
| 617 |
+
num_classes=3,
|
| 618 |
+
target_class_id=[1]
|
| 619 |
+
)]
|
| 620 |
+
)
|
| 621 |
+
```
|
| 622 |
+
"""
|
| 623 |
+
|
| 624 |
+
def __init__(
|
| 625 |
+
self,
|
| 626 |
+
num_classes,
|
| 627 |
+
target_class_ids,
|
| 628 |
+
name=None,
|
| 629 |
+
dtype=None,
|
| 630 |
+
ignore_class=None,
|
| 631 |
+
sparse_y_pred=False,
|
| 632 |
+
axis=-1,
|
| 633 |
+
):
|
| 634 |
+
super().__init__(
|
| 635 |
+
num_classes=num_classes,
|
| 636 |
+
target_class_ids=target_class_ids,
|
| 637 |
+
name=name,
|
| 638 |
+
dtype=dtype,
|
| 639 |
+
ignore_class=ignore_class,
|
| 640 |
+
sparse_y_true=False,
|
| 641 |
+
sparse_y_pred=sparse_y_pred,
|
| 642 |
+
axis=axis,
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
def get_config(self):
|
| 646 |
+
return {
|
| 647 |
+
"num_classes": self.num_classes,
|
| 648 |
+
"target_class_ids": self.target_class_ids,
|
| 649 |
+
"name": self.name,
|
| 650 |
+
"dtype": self._dtype,
|
| 651 |
+
"ignore_class": self.ignore_class,
|
| 652 |
+
"sparse_y_pred": self.sparse_y_pred,
|
| 653 |
+
"axis": self.axis,
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
@keras_export("keras.metrics.OneHotMeanIoU")
|
| 658 |
+
class OneHotMeanIoU(MeanIoU):
|
| 659 |
+
"""Computes mean Intersection-Over-Union metric for one-hot encoded labels.
|
| 660 |
+
|
| 661 |
+
Formula:
|
| 662 |
+
|
| 663 |
+
```python
|
| 664 |
+
iou = true_positives / (true_positives + false_positives + false_negatives)
|
| 665 |
+
```
|
| 666 |
+
Intersection-Over-Union is a common evaluation metric for semantic image
|
| 667 |
+
segmentation.
|
| 668 |
+
|
| 669 |
+
To compute IoUs, the predictions are accumulated in a confusion matrix,
|
| 670 |
+
weighted by `sample_weight` and the metric is then calculated from it.
|
| 671 |
+
|
| 672 |
+
If `sample_weight` is `None`, weights default to 1.
|
| 673 |
+
Use `sample_weight` of 0 to mask values.
|
| 674 |
+
|
| 675 |
+
This class can be used to compute the mean IoU for multi-class
|
| 676 |
+
classification tasks where the labels are one-hot encoded (the last axis
|
| 677 |
+
should have one dimension per class). Note that the predictions should also
|
| 678 |
+
have the same shape. To compute the mean IoU, first the labels and
|
| 679 |
+
predictions are converted back into integer format by taking the argmax over
|
| 680 |
+
the class axis. Then the same computation steps as for the base `MeanIoU`
|
| 681 |
+
class apply.
|
| 682 |
+
|
| 683 |
+
Note, if there is only one channel in the labels and predictions, this class
|
| 684 |
+
is the same as class `MeanIoU`. In this case, use `MeanIoU` instead.
|
| 685 |
+
|
| 686 |
+
Also, make sure that `num_classes` is equal to the number of classes in the
|
| 687 |
+
data, to avoid a "labels out of bound" error when the confusion matrix is
|
| 688 |
+
computed.
|
| 689 |
+
|
| 690 |
+
Args:
|
| 691 |
+
num_classes: The possible number of labels the prediction task can have.
|
| 692 |
+
name: (Optional) string name of the metric instance.
|
| 693 |
+
dtype: (Optional) data type of the metric result.
|
| 694 |
+
ignore_class: Optional integer. The ID of a class to be ignored during
|
| 695 |
+
metric computation. This is useful, for example, in segmentation
|
| 696 |
+
problems featuring a "void" class (commonly -1 or 255) in
|
| 697 |
+
segmentation maps. By default (`ignore_class=None`), all classes are
|
| 698 |
+
considered.
|
| 699 |
+
sparse_y_pred: Whether predictions are encoded using natural numbers or
|
| 700 |
+
probability distribution vectors. If `False`, the `argmax`
|
| 701 |
+
function will be used to determine each sample's most likely
|
| 702 |
+
associated label.
|
| 703 |
+
axis: (Optional) The dimension containing the logits. Defaults to `-1`.
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
Example:
|
| 707 |
+
|
| 708 |
+
>>> y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
|
| 709 |
+
>>> y_pred = np.array([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],
|
| 710 |
+
... [0.1, 0.4, 0.5]])
|
| 711 |
+
>>> sample_weight = [0.1, 0.2, 0.3, 0.4]
|
| 712 |
+
>>> m = keras.metrics.OneHotMeanIoU(num_classes=3)
|
| 713 |
+
>>> m.update_state(
|
| 714 |
+
... y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
|
| 715 |
+
>>> # cm = [[0, 0, 0.2+0.4],
|
| 716 |
+
>>> # [0.3, 0, 0],
|
| 717 |
+
>>> # [0, 0, 0.1]]
|
| 718 |
+
>>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]
|
| 719 |
+
>>> # true_positives = [0, 0, 0.1]
|
| 720 |
+
>>> # single_iou = true_positives / (sum_row + sum_col - true_positives))
|
| 721 |
+
>>> # mean_iou = (0 + 0 + 0.1 / (0.7 + 0.1 - 0.1)) / 3
|
| 722 |
+
>>> m.result()
|
| 723 |
+
0.048
|
| 724 |
+
|
| 725 |
+
Usage with `compile()` API:
|
| 726 |
+
|
| 727 |
+
```python
|
| 728 |
+
model.compile(
|
| 729 |
+
optimizer='sgd',
|
| 730 |
+
loss='mse',
|
| 731 |
+
metrics=[keras.metrics.OneHotMeanIoU(num_classes=3)])
|
| 732 |
+
```
|
| 733 |
+
"""
|
| 734 |
+
|
| 735 |
+
def __init__(
|
| 736 |
+
self,
|
| 737 |
+
num_classes,
|
| 738 |
+
name=None,
|
| 739 |
+
dtype=None,
|
| 740 |
+
ignore_class=None,
|
| 741 |
+
sparse_y_pred=False,
|
| 742 |
+
axis=-1,
|
| 743 |
+
):
|
| 744 |
+
super().__init__(
|
| 745 |
+
num_classes=num_classes,
|
| 746 |
+
axis=axis,
|
| 747 |
+
name=name,
|
| 748 |
+
dtype=dtype,
|
| 749 |
+
ignore_class=ignore_class,
|
| 750 |
+
sparse_y_true=False,
|
| 751 |
+
sparse_y_pred=sparse_y_pred,
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
def get_config(self):
|
| 755 |
+
return {
|
| 756 |
+
"num_classes": self.num_classes,
|
| 757 |
+
"name": self.name,
|
| 758 |
+
"dtype": self._dtype,
|
| 759 |
+
"ignore_class": self.ignore_class,
|
| 760 |
+
"sparse_y_pred": self.sparse_y_pred,
|
| 761 |
+
"axis": self.axis,
|
| 762 |
+
}
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/metric.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import backend
|
| 2 |
+
from keras.src import dtype_policies
|
| 3 |
+
from keras.src import initializers
|
| 4 |
+
from keras.src import ops
|
| 5 |
+
from keras.src.api_export import keras_export
|
| 6 |
+
from keras.src.saving.keras_saveable import KerasSaveable
|
| 7 |
+
from keras.src.utils.naming import auto_name
|
| 8 |
+
from keras.src.utils.tracking import Tracker
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@keras_export(["keras.Metric", "keras.metrics.Metric"])
|
| 12 |
+
class Metric(KerasSaveable):
|
| 13 |
+
"""Encapsulates metric logic and state.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
name: Optional name for the metric instance.
|
| 17 |
+
dtype: The dtype of the metric's computations. Defaults to `None`, which
|
| 18 |
+
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
|
| 19 |
+
`"float32"` unless set to different value
|
| 20 |
+
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
|
| 21 |
+
provided, then the `compute_dtype` will be utilized.
|
| 22 |
+
|
| 23 |
+
Example:
|
| 24 |
+
|
| 25 |
+
```python
|
| 26 |
+
m = SomeMetric(...)
|
| 27 |
+
for input in ...:
|
| 28 |
+
m.update_state(input)
|
| 29 |
+
print('Final result: ', m.result())
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
Usage with `compile()` API:
|
| 33 |
+
|
| 34 |
+
```python
|
| 35 |
+
model = keras.Sequential()
|
| 36 |
+
model.add(keras.layers.Dense(64, activation='relu'))
|
| 37 |
+
model.add(keras.layers.Dense(64, activation='relu'))
|
| 38 |
+
model.add(keras.layers.Dense(10, activation='softmax'))
|
| 39 |
+
|
| 40 |
+
model.compile(optimizer=keras.optimizers.RMSprop(0.01),
|
| 41 |
+
loss=keras.losses.CategoricalCrossentropy(),
|
| 42 |
+
metrics=[keras.metrics.CategoricalAccuracy()])
|
| 43 |
+
|
| 44 |
+
data = np.random.random((1000, 32))
|
| 45 |
+
labels = np.random.random((1000, 10))
|
| 46 |
+
|
| 47 |
+
model.fit(data, labels, epochs=10)
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
To be implemented by subclasses:
|
| 51 |
+
|
| 52 |
+
* `__init__()`: All state variables should be created in this method by
|
| 53 |
+
calling `self.add_variable()` like: `self.var = self.add_variable(...)`
|
| 54 |
+
* `update_state()`: Has all updates to the state variables like:
|
| 55 |
+
`self.var.assign(...)`.
|
| 56 |
+
* `result()`: Computes and returns a scalar value or a dict of scalar values
|
| 57 |
+
for the metric from the state variables.
|
| 58 |
+
|
| 59 |
+
Example subclass implementation:
|
| 60 |
+
|
| 61 |
+
```python
|
| 62 |
+
class BinaryTruePositives(Metric):
|
| 63 |
+
|
| 64 |
+
def __init__(self, name='binary_true_positives', **kwargs):
|
| 65 |
+
super().__init__(name=name, **kwargs)
|
| 66 |
+
self.true_positives = self.add_variable(
|
| 67 |
+
shape=(),
|
| 68 |
+
initializer='zeros',
|
| 69 |
+
name='true_positives'
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
def update_state(self, y_true, y_pred, sample_weight=None):
|
| 73 |
+
y_true = ops.cast(y_true, "bool")
|
| 74 |
+
y_pred = ops.cast(y_pred, "bool")
|
| 75 |
+
|
| 76 |
+
values = ops.logical_and(
|
| 77 |
+
ops.equal(y_true, True), ops.equal(y_pred, True))
|
| 78 |
+
values = ops.cast(values, self.dtype)
|
| 79 |
+
if sample_weight is not None:
|
| 80 |
+
sample_weight = ops.cast(sample_weight, self.dtype)
|
| 81 |
+
sample_weight = ops.broadcast_to(
|
| 82 |
+
sample_weight, ops.shape(values)
|
| 83 |
+
)
|
| 84 |
+
values = ops.multiply(values, sample_weight)
|
| 85 |
+
self.true_positives.assign(self.true_positives + ops.sum(values))
|
| 86 |
+
|
| 87 |
+
def result(self):
|
| 88 |
+
return self.true_positives
|
| 89 |
+
```
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def __init__(self, dtype=None, name=None):
|
| 93 |
+
self.name = name or auto_name(self.__class__.__name__)
|
| 94 |
+
self._dtype_policy = dtype_policies.get(dtype or backend.floatx())
|
| 95 |
+
self._dtype = self._dtype_policy.compute_dtype
|
| 96 |
+
self._metrics = []
|
| 97 |
+
self._variables = []
|
| 98 |
+
self._tracker = Tracker(
|
| 99 |
+
{
|
| 100 |
+
"variables": (
|
| 101 |
+
lambda x: isinstance(x, backend.Variable),
|
| 102 |
+
self._variables,
|
| 103 |
+
),
|
| 104 |
+
"metrics": (lambda x: isinstance(x, Metric), self._metrics),
|
| 105 |
+
}
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
def reset_state(self):
|
| 109 |
+
"""Reset all of the metric state variables.
|
| 110 |
+
|
| 111 |
+
This function is called between epochs/steps,
|
| 112 |
+
when a metric is evaluated during training.
|
| 113 |
+
"""
|
| 114 |
+
for v in self.variables:
|
| 115 |
+
v.assign(ops.zeros(v.shape, dtype=v.dtype))
|
| 116 |
+
|
| 117 |
+
def update_state(self, *args, **kwargs):
|
| 118 |
+
"""Accumulate statistics for the metric."""
|
| 119 |
+
raise NotImplementedError
|
| 120 |
+
|
| 121 |
+
def stateless_update_state(self, metric_variables, *args, **kwargs):
|
| 122 |
+
if len(metric_variables) != len(self.variables):
|
| 123 |
+
raise ValueError(
|
| 124 |
+
"Argument `metric_variables` must be a list of tensors "
|
| 125 |
+
f"corresponding 1:1 to {self.__class__.__name__}().variables. "
|
| 126 |
+
f"Received list with length {len(metric_variables)}, but "
|
| 127 |
+
f"expected {len(self.variables)} variables."
|
| 128 |
+
)
|
| 129 |
+
# Gather variable mapping
|
| 130 |
+
mapping = list(zip(self.variables, metric_variables))
|
| 131 |
+
|
| 132 |
+
# Call in stateless scope
|
| 133 |
+
with backend.StatelessScope(state_mapping=mapping) as scope:
|
| 134 |
+
self.update_state(*args, **kwargs)
|
| 135 |
+
|
| 136 |
+
# Gather updated variables
|
| 137 |
+
metric_variables = []
|
| 138 |
+
for v in self.variables:
|
| 139 |
+
new_v = scope.get_current_value(v)
|
| 140 |
+
if new_v is not None:
|
| 141 |
+
metric_variables.append(new_v)
|
| 142 |
+
else:
|
| 143 |
+
metric_variables.append(v)
|
| 144 |
+
return metric_variables
|
| 145 |
+
|
| 146 |
+
def result(self):
|
| 147 |
+
"""Compute the current metric value.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
A scalar tensor, or a dictionary of scalar tensors.
|
| 151 |
+
"""
|
| 152 |
+
raise NotImplementedError
|
| 153 |
+
|
| 154 |
+
def stateless_result(self, metric_variables):
|
| 155 |
+
if len(metric_variables) != len(self.variables):
|
| 156 |
+
raise ValueError(
|
| 157 |
+
"Argument `metric_variables` must be a list of tensors "
|
| 158 |
+
f"corresponding 1:1 to {self.__class__.__name__}().variables. "
|
| 159 |
+
f"Received list with length {len(metric_variables)}, but "
|
| 160 |
+
f"expected {len(self.variables)} variables."
|
| 161 |
+
)
|
| 162 |
+
# Gather variable mapping
|
| 163 |
+
mapping = list(zip(self.variables, metric_variables))
|
| 164 |
+
|
| 165 |
+
# Call in stateless scope
|
| 166 |
+
with backend.StatelessScope(state_mapping=mapping):
|
| 167 |
+
res = self.result()
|
| 168 |
+
return res
|
| 169 |
+
|
| 170 |
+
def stateless_reset_state(self):
|
| 171 |
+
# Call in stateless scope
|
| 172 |
+
with backend.StatelessScope() as scope:
|
| 173 |
+
self.reset_state()
|
| 174 |
+
|
| 175 |
+
# Gather updated variables
|
| 176 |
+
metric_variables = []
|
| 177 |
+
for v in self.variables:
|
| 178 |
+
new_v = scope.get_current_value(v)
|
| 179 |
+
if new_v is not None:
|
| 180 |
+
metric_variables.append(new_v)
|
| 181 |
+
else:
|
| 182 |
+
metric_variables.append(v)
|
| 183 |
+
return metric_variables
|
| 184 |
+
|
| 185 |
+
@property
|
| 186 |
+
def dtype(self):
|
| 187 |
+
return self._dtype
|
| 188 |
+
|
| 189 |
+
def _obj_type(self):
|
| 190 |
+
return "Metric"
|
| 191 |
+
|
| 192 |
+
def add_variable(
|
| 193 |
+
self, shape, initializer, dtype=None, aggregation="sum", name=None
|
| 194 |
+
):
|
| 195 |
+
self._check_super_called()
|
| 196 |
+
with backend.name_scope(self.name.replace("/", ">"), caller=self):
|
| 197 |
+
initializer = initializers.get(initializer)
|
| 198 |
+
variable = backend.Variable(
|
| 199 |
+
initializer=initializer,
|
| 200 |
+
shape=shape,
|
| 201 |
+
dtype=dtype,
|
| 202 |
+
trainable=False,
|
| 203 |
+
aggregation=aggregation,
|
| 204 |
+
name=name,
|
| 205 |
+
)
|
| 206 |
+
# Prevent double-tracking
|
| 207 |
+
self._tracker.add_to_store("variables", variable)
|
| 208 |
+
return variable
|
| 209 |
+
|
| 210 |
+
def add_weight(self, shape=(), initializer=None, dtype=None, name=None):
|
| 211 |
+
# Backwards compatibility alias
|
| 212 |
+
return self.add_variable(
|
| 213 |
+
shape=shape, initializer=initializer, dtype=dtype, name=name
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
@property
|
| 217 |
+
def variables(self):
|
| 218 |
+
variables = list(self._variables)
|
| 219 |
+
for metric in self._metrics:
|
| 220 |
+
variables.extend(metric.variables)
|
| 221 |
+
return variables
|
| 222 |
+
|
| 223 |
+
def __call__(self, *args, **kwargs):
|
| 224 |
+
self._check_super_called()
|
| 225 |
+
self.update_state(*args, **kwargs)
|
| 226 |
+
return self.result()
|
| 227 |
+
|
| 228 |
+
def get_config(self):
|
| 229 |
+
"""Return the serializable config of the metric."""
|
| 230 |
+
return {"name": self.name, "dtype": self.dtype}
|
| 231 |
+
|
| 232 |
+
@classmethod
|
| 233 |
+
def from_config(cls, config):
|
| 234 |
+
return cls(**config)
|
| 235 |
+
|
| 236 |
+
def __setattr__(self, name, value):
|
| 237 |
+
# Track Variables, Layers, Metrics
|
| 238 |
+
if hasattr(self, "_tracker"):
|
| 239 |
+
value = self._tracker.track(value)
|
| 240 |
+
return super().__setattr__(name, value)
|
| 241 |
+
|
| 242 |
+
def _check_super_called(self):
|
| 243 |
+
if not hasattr(self, "_tracker"):
|
| 244 |
+
raise RuntimeError(
|
| 245 |
+
"You forgot to call `super().__init__()` "
|
| 246 |
+
"in the `__init__()` method. Go add it!"
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
def __repr__(self):
|
| 250 |
+
return f"<{self.__class__.__name__} " f"name={self.name}>"
|
| 251 |
+
|
| 252 |
+
def __str__(self):
|
| 253 |
+
return self.__repr__()
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/metrics_utils.py
ADDED
|
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from keras.src import backend
|
| 6 |
+
from keras.src import ops
|
| 7 |
+
from keras.src.losses.loss import squeeze_or_expand_to_same_rank
|
| 8 |
+
from keras.src.utils.python_utils import to_list
|
| 9 |
+
|
| 10 |
+
NEG_INF = -1e10
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def assert_thresholds_range(thresholds):
    """Validate that every threshold lies within the closed interval [0, 1].

    A `None` `thresholds` argument is accepted and treated as "nothing to
    check". Entries that are `None` or fall outside [0, 1] are collected
    and reported together.

    Raises:
        ValueError: if any entry is `None` or outside `[0, 1]`.
    """
    if thresholds is None:
        return
    out_of_range = [
        value
        for value in thresholds
        if value is None or not (0 <= value <= 1)
    ]
    if out_of_range:
        raise ValueError(
            "Threshold values must be in [0, 1]. "
            f"Received: {out_of_range}"
        )
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parse_init_thresholds(thresholds, default_threshold=0.5):
    """Normalize a `thresholds` constructor argument into a validated list.

    `None` maps to `[default_threshold]`; scalars and sequences are
    range-checked and coerced to a list via `to_list`.
    """
    if thresholds is None:
        return to_list(default_threshold)
    assert_thresholds_range(to_list(thresholds))
    return to_list(thresholds)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ConfusionMatrix(Enum):
    """The four confusion-matrix cells.

    The short string values ("tp", "fp", "tn", "fn") are the keys used when
    wiring metric state variables to the confusion-matrix update helpers.
    """

    TRUE_POSITIVES = "tp"
    FALSE_POSITIVES = "fp"
    TRUE_NEGATIVES = "tn"
    FALSE_NEGATIVES = "fn"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class AUCCurve(Enum):
    """Type of AUC curve: receiver operating characteristic (`ROC`) or
    precision-recall (`PR`)."""

    ROC = "ROC"
    PR = "PR"

    @staticmethod
    def from_str(key):
        """Map the exact strings 'roc'/'ROC'/'pr'/'PR' to enum members.

        Raises:
            ValueError: for any other value (including mixed case).
        """
        if key in ("roc", "ROC"):
            return AUCCurve.ROC
        if key in ("pr", "PR"):
            return AUCCurve.PR
        raise ValueError(
            f'Invalid AUC curve value: "{key}". '
            'Expected values are ["PR", "ROC"]'
        )
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class AUCSummationMethod(Enum):
    """Type of AUC summation method (see
    https://en.wikipedia.org/wiki/Riemann_sum).

    Members:
    * `INTERPOLATION`: mid-point summation for `ROC`; for `PR`, interpolates
      (true/false) positives but not precision (Davis & Goadrich 2006).
    * `MINORING`: left summation for increasing intervals, right for
      decreasing intervals.
    * `MAJORING`: right summation for increasing intervals, left for
      decreasing intervals.
    """

    INTERPOLATION = "interpolation"
    MAJORING = "majoring"
    MINORING = "minoring"

    @staticmethod
    def from_str(key):
        """Map lowercase/Capitalized method names to enum members.

        Raises:
            ValueError: for any unrecognized value.
        """
        if key in ("interpolation", "Interpolation"):
            return AUCSummationMethod.INTERPOLATION
        if key in ("majoring", "Majoring"):
            return AUCSummationMethod.MAJORING
        if key in ("minoring", "Minoring"):
            return AUCSummationMethod.MINORING
        raise ValueError(
            f'Invalid AUC summation method value: "{key}". '
            'Expected values are ["interpolation", "majoring", "minoring"]'
        )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _update_confusion_matrix_variables_optimized(
|
| 95 |
+
variables_to_update,
|
| 96 |
+
y_true,
|
| 97 |
+
y_pred,
|
| 98 |
+
thresholds,
|
| 99 |
+
multi_label=False,
|
| 100 |
+
sample_weights=None,
|
| 101 |
+
label_weights=None,
|
| 102 |
+
thresholds_with_epsilon=False,
|
| 103 |
+
):
|
| 104 |
+
"""Update confusion matrix variables with memory efficient alternative.
|
| 105 |
+
|
| 106 |
+
Note that the thresholds need to be evenly distributed within the list, eg,
|
| 107 |
+
the diff between consecutive elements are the same.
|
| 108 |
+
|
| 109 |
+
To compute TP/FP/TN/FN, we are measuring a binary classifier
|
| 110 |
+
C(t) = (predictions >= t)
|
| 111 |
+
at each threshold 't'. So we have
|
| 112 |
+
TP(t) = sum( C(t) * true_labels )
|
| 113 |
+
FP(t) = sum( C(t) * false_labels )
|
| 114 |
+
|
| 115 |
+
But, computing C(t) requires computation for each t. To make it fast,
|
| 116 |
+
observe that C(t) is a cumulative integral, and so if we have
|
| 117 |
+
thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1}
|
| 118 |
+
where n = num_thresholds, and if we can compute the bucket function
|
| 119 |
+
B(i) = Sum( (predictions == t), t_i <= t < t{i+1} )
|
| 120 |
+
then we get
|
| 121 |
+
C(t_i) = sum( B(j), j >= i )
|
| 122 |
+
which is the reversed cumulative sum in ops.cumsum().
|
| 123 |
+
|
| 124 |
+
We can compute B(i) efficiently by taking advantage of the fact that
|
| 125 |
+
our thresholds are evenly distributed, in that
|
| 126 |
+
width = 1.0 / (num_thresholds - 1)
|
| 127 |
+
thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
|
| 128 |
+
Given a prediction value p, we can map it to its bucket by
|
| 129 |
+
bucket_index(p) = floor( p * (num_thresholds - 1) )
|
| 130 |
+
so we can use ops.segment_sum() to update the buckets in one pass.
|
| 131 |
+
|
| 132 |
+
Consider following example:
|
| 133 |
+
y_true = [0, 0, 1, 1]
|
| 134 |
+
y_pred = [0.1, 0.5, 0.3, 0.9]
|
| 135 |
+
thresholds = [0.0, 0.5, 1.0]
|
| 136 |
+
num_buckets = 2 # [0.0, 1.0], (1.0, 2.0]
|
| 137 |
+
bucket_index(y_pred) = ops.floor(y_pred * num_buckets)
|
| 138 |
+
= ops.floor([0.2, 1.0, 0.6, 1.8])
|
| 139 |
+
= [0, 0, 0, 1]
|
| 140 |
+
# The meaning of this bucket is that if any of the label is true,
|
| 141 |
+
# then 1 will be added to the corresponding bucket with the index.
|
| 142 |
+
# Eg, if the label for 0.2 is true, then 1 will be added to bucket 0. If the
|
| 143 |
+
# label for 1.8 is true, then 1 will be added to bucket 1.
|
| 144 |
+
#
|
| 145 |
+
# Note the second item "1.0" is floored to 0, since the value need to be
|
| 146 |
+
# strictly larger than the bucket lower bound.
|
| 147 |
+
# In the implementation, we use ops.ceil() - 1 to achieve this.
|
| 148 |
+
tp_bucket_value = ops.segment_sum(true_labels, bucket_indices,
|
| 149 |
+
num_segments=num_thresholds)
|
| 150 |
+
= [1, 1, 0]
|
| 151 |
+
# For [1, 1, 0] here, it means there is 1 true value contributed by bucket
|
| 152 |
+
# 0, and 1 value contributed by bucket 1. When we aggregate them to
|
| 153 |
+
# together, the result become [a + b + c, b + c, c], since large thresholds
|
| 154 |
+
# will always contribute to the value for smaller thresholds.
|
| 155 |
+
true_positive = ops.cumsum(tp_bucket_value, reverse=True)
|
| 156 |
+
= [2, 1, 0]
|
| 157 |
+
|
| 158 |
+
This implementation exhibits a run time and space complexity of O(T + N),
|
| 159 |
+
where T is the number of thresholds and N is the size of predictions.
|
| 160 |
+
Metrics that rely on standard implementation instead exhibit a complexity of
|
| 161 |
+
O(T * N).
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
|
| 165 |
+
keys and corresponding variables to update as values.
|
| 166 |
+
y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
|
| 167 |
+
cast to `bool`.
|
| 168 |
+
y_pred: A floating point `Tensor` of arbitrary shape and whose values
|
| 169 |
+
are in the range `[0, 1]`.
|
| 170 |
+
thresholds: A sorted floating point `Tensor` with value in `[0, 1]`.
|
| 171 |
+
It need to be evenly distributed (the diff between each element need
|
| 172 |
+
to be the same).
|
| 173 |
+
multi_label: Optional boolean indicating whether multidimensional
|
| 174 |
+
prediction/labels should be treated as multilabel responses, or
|
| 175 |
+
flattened into a single label. When True, the values of
|
| 176 |
+
`variables_to_update` must have a second dimension equal to the
|
| 177 |
+
number of labels in y_true and y_pred, and those tensors must not be
|
| 178 |
+
RaggedTensors.
|
| 179 |
+
sample_weights: Optional `Tensor` whose rank is either 0, or the same
|
| 180 |
+
rank as `y_true`, and must be broadcastable to `y_true` (i.e., all
|
| 181 |
+
dimensions must be either `1`, or the same as the corresponding
|
| 182 |
+
`y_true` dimension).
|
| 183 |
+
label_weights: Optional tensor of non-negative weights for multilabel
|
| 184 |
+
data. The weights are applied when calculating TP, FP, FN, and TN
|
| 185 |
+
without explicit multilabel handling (i.e. when the data is to be
|
| 186 |
+
flattened).
|
| 187 |
+
thresholds_with_epsilon: Optional boolean indicating whether the leading
|
| 188 |
+
and tailing thresholds has any epsilon added for floating point
|
| 189 |
+
imprecisions. It will change how we handle the leading and tailing
|
| 190 |
+
bucket.
|
| 191 |
+
"""
|
| 192 |
+
num_thresholds = ops.shape(thresholds)[0]
|
| 193 |
+
|
| 194 |
+
if sample_weights is None:
|
| 195 |
+
sample_weights = 1.0
|
| 196 |
+
else:
|
| 197 |
+
sample_weights = ops.broadcast_to(
|
| 198 |
+
ops.cast(sample_weights, dtype=y_pred.dtype), ops.shape(y_pred)
|
| 199 |
+
)
|
| 200 |
+
if not multi_label:
|
| 201 |
+
sample_weights = ops.reshape(sample_weights, [-1])
|
| 202 |
+
if label_weights is None:
|
| 203 |
+
label_weights = 1.0
|
| 204 |
+
else:
|
| 205 |
+
label_weights = ops.expand_dims(label_weights, 0)
|
| 206 |
+
label_weights = ops.broadcast_to(label_weights, ops.shape(y_pred))
|
| 207 |
+
if not multi_label:
|
| 208 |
+
label_weights = ops.reshape(label_weights, [-1])
|
| 209 |
+
weights = ops.cast(
|
| 210 |
+
ops.multiply(sample_weights, label_weights), y_true.dtype
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# We shouldn't need this, but in case there are predict value that is out of
|
| 214 |
+
# the range of [0.0, 1.0]
|
| 215 |
+
y_pred = ops.clip(y_pred, x_min=0.0, x_max=1.0)
|
| 216 |
+
|
| 217 |
+
y_true = ops.cast(ops.cast(y_true, "bool"), y_true.dtype)
|
| 218 |
+
if not multi_label:
|
| 219 |
+
y_true = ops.reshape(y_true, [-1])
|
| 220 |
+
y_pred = ops.reshape(y_pred, [-1])
|
| 221 |
+
|
| 222 |
+
true_labels = ops.multiply(y_true, weights)
|
| 223 |
+
false_labels = ops.multiply((1.0 - y_true), weights)
|
| 224 |
+
|
| 225 |
+
# Compute the bucket indices for each prediction value.
|
| 226 |
+
# Since the predict value has to be strictly greater than the thresholds,
|
| 227 |
+
# eg, buckets like [0, 0.5], (0.5, 1], and 0.5 belongs to first bucket.
|
| 228 |
+
# We have to use math.ceil(val) - 1 for the bucket.
|
| 229 |
+
bucket_indices = (
|
| 230 |
+
ops.ceil(y_pred * (ops.cast(num_thresholds, dtype=y_pred.dtype) - 1))
|
| 231 |
+
- 1
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
if thresholds_with_epsilon:
|
| 235 |
+
# In this case, the first bucket should actually take into account since
|
| 236 |
+
# the any prediction between [0.0, 1.0] should be larger than the first
|
| 237 |
+
# threshold. We change the bucket value from -1 to 0.
|
| 238 |
+
bucket_indices = ops.relu(bucket_indices)
|
| 239 |
+
|
| 240 |
+
bucket_indices = ops.cast(bucket_indices, "int32")
|
| 241 |
+
|
| 242 |
+
if multi_label:
|
| 243 |
+
# We need to run bucket segment sum for each of the label class. In the
|
| 244 |
+
# multi_label case, the rank of the label is 2. We first transpose it so
|
| 245 |
+
# that the label dim becomes the first and we can parallel run though
|
| 246 |
+
# them.
|
| 247 |
+
true_labels = ops.transpose(true_labels)
|
| 248 |
+
false_labels = ops.transpose(false_labels)
|
| 249 |
+
bucket_indices = ops.transpose(bucket_indices)
|
| 250 |
+
|
| 251 |
+
def gather_bucket(label_and_bucket_index):
|
| 252 |
+
label, bucket_index = (
|
| 253 |
+
label_and_bucket_index[0],
|
| 254 |
+
label_and_bucket_index[1],
|
| 255 |
+
)
|
| 256 |
+
return ops.segment_sum(
|
| 257 |
+
data=label,
|
| 258 |
+
segment_ids=bucket_index,
|
| 259 |
+
num_segments=num_thresholds,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
tp_bucket_v = backend.vectorized_map(
|
| 263 |
+
gather_bucket,
|
| 264 |
+
(true_labels, bucket_indices),
|
| 265 |
+
)
|
| 266 |
+
fp_bucket_v = backend.vectorized_map(
|
| 267 |
+
gather_bucket, (false_labels, bucket_indices)
|
| 268 |
+
)
|
| 269 |
+
tp = ops.transpose(ops.flip(ops.cumsum(ops.flip(tp_bucket_v), axis=1)))
|
| 270 |
+
fp = ops.transpose(ops.flip(ops.cumsum(ops.flip(fp_bucket_v), axis=1)))
|
| 271 |
+
else:
|
| 272 |
+
tp_bucket_v = ops.segment_sum(
|
| 273 |
+
data=true_labels,
|
| 274 |
+
segment_ids=bucket_indices,
|
| 275 |
+
num_segments=num_thresholds,
|
| 276 |
+
)
|
| 277 |
+
fp_bucket_v = ops.segment_sum(
|
| 278 |
+
data=false_labels,
|
| 279 |
+
segment_ids=bucket_indices,
|
| 280 |
+
num_segments=num_thresholds,
|
| 281 |
+
)
|
| 282 |
+
tp = ops.flip(ops.cumsum(ops.flip(tp_bucket_v)))
|
| 283 |
+
fp = ops.flip(ops.cumsum(ops.flip(fp_bucket_v)))
|
| 284 |
+
|
| 285 |
+
# fn = sum(true_labels) - tp
|
| 286 |
+
# tn = sum(false_labels) - fp
|
| 287 |
+
if (
|
| 288 |
+
ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
|
| 289 |
+
or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
|
| 290 |
+
):
|
| 291 |
+
if multi_label:
|
| 292 |
+
total_true_labels = ops.sum(true_labels, axis=1)
|
| 293 |
+
total_false_labels = ops.sum(false_labels, axis=1)
|
| 294 |
+
else:
|
| 295 |
+
total_true_labels = ops.sum(true_labels)
|
| 296 |
+
total_false_labels = ops.sum(false_labels)
|
| 297 |
+
|
| 298 |
+
if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
|
| 299 |
+
variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
|
| 300 |
+
variable.assign(variable + tp)
|
| 301 |
+
if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
|
| 302 |
+
variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
|
| 303 |
+
variable.assign(variable + fp)
|
| 304 |
+
if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
|
| 305 |
+
variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
|
| 306 |
+
tn = total_false_labels - fp
|
| 307 |
+
variable.assign(variable + tn)
|
| 308 |
+
if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
|
| 309 |
+
variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
|
| 310 |
+
fn = total_true_labels - tp
|
| 311 |
+
variable.assign(variable + fn)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def is_evenly_distributed_thresholds(thresholds):
    """Check whether a list of thresholds is evenly distributed over [0, 1].

    Evenly distributed thresholds allow a memory-optimized code path when
    computing metrics like AUC, where each individual threshold would
    otherwise need to be evaluated separately.

    Args:
        thresholds: A python list or tuple, or 1D numpy array whose values
            are ranged in [0, 1].

    Returns:
        boolean, whether the values in the inputs are evenly distributed.
    """
    count = len(thresholds)
    # Fewer than 3 thresholds cannot benefit from the optimized path.
    if count < 3:
        return False
    # Reference grid: `count` equally spaced points spanning [0, 1].
    reference = np.arange(count, dtype=np.float32) / (count - 1)
    return np.allclose(thresholds, reference, atol=backend.epsilon())
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def update_confusion_matrix_variables(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    top_k=None,
    class_id=None,
    sample_weight=None,
    multi_label=False,
    label_weights=None,
    thresholds_distributed_evenly=False,
):
    """Updates the given confusion matrix variables.

    For every pair of values in y_true and y_pred:

    true_positive: y_true == True and y_pred > thresholds
    false_negatives: y_true == True and y_pred <= thresholds
    true_negatives: y_true == False and y_pred <= thresholds
    false_positive: y_true == False and y_pred > thresholds

    The results will be weighted and added together. When multiple thresholds
    are provided, we will repeat the same for every threshold.

    For estimation of these metrics over a stream of data, the function creates
    an `update_op` operation that updates the given variables.

    If `sample_weight` is `None`, weights default to 1.
    Use weights of 0 to mask values.

    Args:
      variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
        and corresponding variables to update as values.
      y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
      y_pred: A floating point `Tensor` of arbitrary shape and whose values are
        in the range `[0, 1]`.
      thresholds: A float value, float tensor, python list, or tuple of float
        thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
      top_k: Optional int, indicates that the positive labels should be limited
        to the top k predictions.
      class_id: Optional int, limits the prediction and labels to the class
        specified by this argument.
      sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
        as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `y_true`
        dimension).
      multi_label: Optional boolean indicating whether multidimensional
        prediction/labels should be treated as multilabel responses, or
        flattened into a single label. When True, the values of
        `variables_to_update` must have a second dimension equal to the number
        of labels in y_true and y_pred, and those tensors must not be
        RaggedTensors.
      label_weights: (optional) tensor of non-negative weights for multilabel
        data. The weights are applied when calculating TP, FP, FN, and TN
        without explicit multilabel handling (i.e. when the data is to be
        flattened).
      thresholds_distributed_evenly: Boolean, whether the thresholds are evenly
        distributed within the list. An optimized method will be used if this is
        the case. See _update_confusion_matrix_variables_optimized() for more
        details.

    Raises:
      ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
        `sample_weight` is not `None` and its shape doesn't match `y_pred`, or
        if `variables_to_update` contains invalid keys.
    """
    if multi_label and label_weights is not None:
        raise ValueError(
            "`label_weights` for multilabel data should be handled "
            "outside of `update_confusion_matrix_variables` when "
            "`multi_label` is True."
        )
    # Nothing to update: treat a missing dict as a no-op rather than an error.
    if variables_to_update is None:
        return
    if not any(
        key for key in variables_to_update if key in list(ConfusionMatrix)
    ):
        raise ValueError(
            "Please provide at least one valid confusion matrix "
            "variable to update. Valid variable key options are: "
            f'"{list(ConfusionMatrix)}". '
            f'Received: "{variables_to_update.keys()}"'
        )

    # All intermediate math is done in the dtype of the tracked variables.
    variable_dtype = list(variables_to_update.values())[0].dtype

    y_true = ops.cast(y_true, dtype=variable_dtype)
    y_pred = ops.cast(y_pred, dtype=variable_dtype)

    if thresholds_distributed_evenly:
        # Check whether the thresholds have any leading or tailing epsilon
        # added for floating point imprecision. The leading and tailing
        # threshold will be handled a bit differently as the corner case. At
        # this point, thresholds should be a list/array with more than 2
        # items, and ranged between [0, 1]. See
        # is_evenly_distributed_thresholds() for more details.
        thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0

    thresholds = ops.convert_to_tensor(thresholds, dtype=variable_dtype)
    num_thresholds = ops.shape(thresholds)[0]

    if multi_label:
        # True when a single 1-D threshold vector is shared across all labels
        # (as opposed to a per-label 2-D threshold matrix).
        one_thresh = ops.equal(
            np.array(1, dtype="int32"),
            len(thresholds.shape),
        )
    else:
        one_thresh = np.array(True, dtype="bool")

    invalid_keys = [
        key for key in variables_to_update if key not in list(ConfusionMatrix)
    ]
    if invalid_keys:
        raise ValueError(
            f'Invalid keys: "{invalid_keys}". '
            f'Valid variable key options are: "{list(ConfusionMatrix)}"'
        )

    y_pred, y_true = squeeze_or_expand_to_same_rank(y_pred, y_true)
    if sample_weight is not None:
        # Add a trailing axis so the weights line up with per-label entries,
        # then reconcile ranks against y_true.
        sample_weight = ops.expand_dims(
            ops.cast(sample_weight, dtype=variable_dtype), axis=-1
        )
        _, sample_weight = squeeze_or_expand_to_same_rank(
            y_true, sample_weight, expand_rank_1=False
        )

    if top_k is not None:
        # Suppress all but the top-k predictions per sample (the rest become
        # NEG_INF and thus never exceed any threshold).
        y_pred = _filter_top_k(y_pred, top_k)

    if class_id is not None:
        if len(y_pred.shape) == 1:
            raise ValueError(
                "When class_id is provided, y_pred must be a 2D array "
                "with shape (num_samples, num_classes), found shape: "
                f"{y_pred.shape}"
            )

        # Preserve dimension to match with sample_weight
        y_true = y_true[..., class_id, None]
        y_pred = y_pred[..., class_id, None]

    if thresholds_distributed_evenly:
        # Delegate to the bucketing-based implementation, which avoids tiling
        # the data once per threshold.
        return _update_confusion_matrix_variables_optimized(
            variables_to_update,
            y_true,
            y_pred,
            thresholds,
            multi_label=multi_label,
            sample_weights=sample_weight,
            label_weights=label_weights,
            thresholds_with_epsilon=thresholds_with_epsilon,
        )

    if None in y_pred.shape:
        # Dynamic (unknown) shape: compute sizes with tensor ops at runtime.
        pred_shape = ops.shape(y_pred)
        num_predictions = pred_shape[0]
        if len(y_pred.shape) == 1:
            num_labels = 1
        else:
            num_labels = ops.cast(
                ops.prod(ops.array(pred_shape[1:]), axis=0), "int32"
            )
        thresh_label_tile = ops.where(one_thresh, num_labels, 1)
    else:
        # Static shape: the same sizes can be computed with plain numpy.
        pred_shape = ops.shape(y_pred)
        num_predictions = pred_shape[0]
        if len(y_pred.shape) == 1:
            num_labels = 1
        else:
            num_labels = np.prod(pred_shape[1:], axis=0).astype("int32")
        thresh_label_tile = np.where(one_thresh, num_labels, 1)

    # Reshape predictions and labels, adding a dim for thresholding.
    if multi_label:
        predictions_extra_dim = ops.expand_dims(y_pred, 0)
        labels_extra_dim = ops.expand_dims(ops.cast(y_true, dtype="bool"), 0)
    else:
        # Flatten predictions and labels when not multilabel.
        predictions_extra_dim = ops.reshape(y_pred, [1, -1])
        labels_extra_dim = ops.reshape(ops.cast(y_true, dtype="bool"), [1, -1])

    # Tile the thresholds for every prediction.
    if multi_label:
        thresh_pretile_shape = [num_thresholds, 1, -1]
        thresh_tiles = [1, num_predictions, thresh_label_tile]
        data_tiles = [num_thresholds, 1, 1]
    else:
        thresh_pretile_shape = [num_thresholds, -1]
        thresh_tiles = [1, num_predictions * num_labels]
        data_tiles = [num_thresholds, 1]

    thresh_tiled = ops.tile(
        ops.reshape(thresholds, thresh_pretile_shape), thresh_tiles
    )

    # Tile the predictions for every threshold.
    preds_tiled = ops.tile(predictions_extra_dim, data_tiles)

    # Compare predictions and threshold.
    pred_is_pos = ops.greater(preds_tiled, thresh_tiled)

    # Tile labels by number of thresholds
    label_is_pos = ops.tile(labels_extra_dim, data_tiles)

    if sample_weight is not None:
        sample_weight = ops.broadcast_to(
            ops.cast(sample_weight, dtype=y_pred.dtype), ops.shape(y_pred)
        )
        weights_tiled = ops.tile(
            ops.reshape(sample_weight, thresh_tiles), data_tiles
        )
    else:
        weights_tiled = None

    if label_weights is not None and not multi_label:
        # Fold the per-label weights into the per-sample weights (or use them
        # alone when there are no sample weights).
        label_weights = ops.expand_dims(label_weights, 0)
        label_weights = ops.broadcast_to(label_weights, ops.shape(y_pred))
        label_weights_tiled = ops.tile(
            ops.reshape(label_weights, thresh_tiles), data_tiles
        )
        if weights_tiled is None:
            weights_tiled = label_weights_tiled
        else:
            weights_tiled = ops.multiply(weights_tiled, label_weights_tiled)

    def weighted_assign_add(label, pred, weights, var):
        # Accumulate the (optionally weighted) count of positions where both
        # conditions hold, summed over the data axis, into `var`.
        label_and_pred = ops.cast(ops.logical_and(label, pred), dtype=var.dtype)
        if weights is not None:
            label_and_pred *= ops.cast(weights, dtype=var.dtype)
        var.assign(var + ops.sum(label_and_pred, 1))

    # Map each confusion-matrix entry to the pair of boolean masks whose
    # conjunction defines it. TP is always present; the others are added only
    # when needed, since their masks require extra logical_not work.
    loop_vars = {
        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
    }
    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

    if update_fn or update_tn:
        pred_is_neg = ops.logical_not(pred_is_pos)
        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

    if update_fp or update_tn:
        label_is_neg = ops.logical_not(label_is_pos)
        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
        if update_tn:
            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (
                label_is_neg,
                pred_is_neg,
            )

    for matrix_cond, (label, pred) in loop_vars.items():
        if matrix_cond in variables_to_update:
            weighted_assign_add(
                label, pred, weights_tiled, variables_to_update[matrix_cond]
            )
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def _filter_top_k(x, k):
    """Keep the top-k values in the last dim of x; set the rest to NEG_INF.

    Used for computing top-k prediction values in dense labels (which has the
    same shape as predictions) for recall and precision top-k metrics.

    Args:
        x: tensor with any dimensions.
        k: the number of values to keep.

    Returns:
        tensor with same shape and dtype as x.
    """
    _, kept_idx = ops.top_k(x, k)
    # One-hot each kept index over the last axis, then collapse the k one-hot
    # rows into a single {0, 1} mask with the same shape as x.
    mask = ops.sum(
        ops.one_hot(kept_idx, ops.shape(x)[-1], axis=-1), axis=-2
    )
    # Kept entries pass through unchanged; everything else becomes NEG_INF.
    kept = x * mask
    suppressed = NEG_INF * (1 - mask)
    return kept + suppressed
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def confusion_matrix(
    labels,
    predictions,
    num_classes,
    weights=None,
    dtype="int32",
):
    """Computes the confusion matrix from predictions and labels.

    The matrix columns represent the prediction labels and the rows represent
    the real labels. The confusion matrix is always a 2-D array of shape
    `(n, n)`, where `n` is the number of valid labels for a given
    classification task. Both prediction and labels must be 1-D arrays of the
    same shape in order for this function to work.

    Class labels are expected to start at 0. For example, if `num_classes` is
    3, then the possible labels would be `[0, 1, 2]`.

    If `weights` is not `None`, then each prediction contributes its
    corresponding weight to the total value of the confusion matrix cell.

    For example:

    ```python
    keras.metrics.metrics_utils.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
    ```

    Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
    resulting in a 5x5 confusion matrix.

    Args:
        labels: 1-D tensor of real labels for the classification task.
        predictions: 1-D tensor of predictions for a given classification.
        num_classes: The possible number of labels the classification
            task can have.
        weights: An optional tensor whose shape matches `predictions`.
        dtype: Data type of the confusion matrix.

    Returns:
        A tensor of type `dtype` with shape `(n, n)` representing the
        confusion matrix, where `n` is the number of possible labels in the
        classification task.
    """
    labels = ops.convert_to_tensor(labels, dtype)
    predictions = ops.convert_to_tensor(predictions, dtype)
    labels, predictions = squeeze_or_expand_to_same_rank(labels, predictions)

    labels = ops.cast(labels, dtype)
    predictions = ops.cast(predictions, dtype)

    # Each (label, prediction) pair addresses one cell of the matrix.
    cell_indices = ops.cast(
        ops.stack([labels, predictions], axis=1), dtype="int64"
    )

    # Without weights every observation contributes exactly 1 to its cell.
    if weights is None:
        cell_values = ops.ones_like(predictions, dtype)
    else:
        cell_values = ops.convert_to_tensor(weights, dtype)
    cell_values = ops.cast(cell_values, dtype=dtype)

    side = int(num_classes)
    # Scatter-add the values into an (n, n) zero matrix.
    return ops.scatter(cell_indices, cell_values, (side, side))
|