AIDUDE0541 committed on
Commit
7a11456
·
verified ·
1 Parent(s): d219b44

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__init__.py +0 -0
  2. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/__init__.cpython-310.pyc +0 -0
  3. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/backend.cpython-310.pyc +0 -0
  4. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/layers.cpython-310.pyc +0 -0
  5. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/losses.cpython-310.pyc +0 -0
  6. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/backend.py +2291 -0
  7. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/layers.py +244 -0
  8. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/losses.py +20 -0
  9. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/image.py +1892 -0
  10. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/sequence.py +320 -0
  11. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/text.py +336 -0
  12. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__init__.py +0 -0
  13. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/__init__.cpython-310.pyc +0 -0
  14. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/json_utils.cpython-310.pyc +0 -0
  15. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/legacy_h5_format.cpython-310.pyc +0 -0
  16. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/saving_options.cpython-310.pyc +0 -0
  17. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/saving_utils.cpython-310.pyc +0 -0
  18. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/serialization.cpython-310.pyc +0 -0
  19. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/json_utils.py +220 -0
  20. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/legacy_h5_format.py +640 -0
  21. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/saving_options.py +17 -0
  22. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/saving_utils.py +260 -0
  23. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/serialization.py +574 -0
  24. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__init__.py +207 -0
  25. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/__init__.cpython-310.pyc +0 -0
  26. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/loss.cpython-310.pyc +0 -0
  27. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/losses.cpython-310.pyc +0 -0
  28. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/loss.py +256 -0
  29. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/losses.py +2599 -0
  30. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__init__.py +211 -0
  31. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/__init__.cpython-310.pyc +0 -0
  32. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/accuracy_metrics.cpython-310.pyc +0 -0
  33. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/confusion_metrics.cpython-310.pyc +0 -0
  34. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/correlation_metrics.cpython-310.pyc +0 -0
  35. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/f_score_metrics.cpython-310.pyc +0 -0
  36. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/hinge_metrics.cpython-310.pyc +0 -0
  37. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/iou_metrics.cpython-310.pyc +0 -0
  38. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/metric.cpython-310.pyc +0 -0
  39. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/metrics_utils.cpython-310.pyc +0 -0
  40. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/probabilistic_metrics.cpython-310.pyc +0 -0
  41. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/reduction_metrics.cpython-310.pyc +0 -0
  42. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/regression_metrics.cpython-310.pyc +0 -0
  43. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/accuracy_metrics.py +522 -0
  44. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/confusion_metrics.py +1576 -0
  45. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/correlation_metrics.py +215 -0
  46. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/f_score_metrics.py +320 -0
  47. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/hinge_metrics.py +100 -0
  48. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/iou_metrics.py +762 -0
  49. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/metric.py +253 -0
  50. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/metrics_utils.py +683 -0
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (192 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/backend.cpython-310.pyc ADDED
Binary file (51.1 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/layers.cpython-310.pyc ADDED
Binary file (6.7 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/__pycache__/losses.cpython-310.pyc ADDED
Binary file (949 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/backend.py ADDED
@@ -0,0 +1,2291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Legacy Keras 1/2 backend functions."""
2
+
3
+ import itertools
4
+
5
+ import numpy as np
6
+
7
+ from keras.src import backend
8
+ from keras.src.api_export import keras_export
9
+ from keras.src.utils.module_utils import tensorflow as tf
10
+
11
+ py_any = any
12
+ py_all = all
13
+
14
+
15
@keras_export("keras._legacy.backend.abs")
def abs(x):
    """DEPRECATED. Element-wise absolute value (thin wrapper over `tf.abs`)."""
    result = tf.abs(x)
    return result
19
+
20
+
21
@keras_export("keras._legacy.backend.all")
def all(x, axis=None, keepdims=False):
    """DEPRECATED. Logical AND reduction of `x` over `axis`."""
    as_bool = tf.cast(x, tf.bool)
    return tf.reduce_all(as_bool, axis, keepdims)
26
+
27
+
28
@keras_export("keras._legacy.backend.any")
def any(x, axis=None, keepdims=False):
    """DEPRECATED. Logical OR reduction of `x` over `axis`."""
    as_bool = tf.cast(x, tf.bool)
    return tf.reduce_any(as_bool, axis, keepdims)
33
+
34
+
35
@keras_export("keras._legacy.backend.argmax")
def argmax(x, axis=-1):
    """DEPRECATED. Index of the maximum value along `axis`."""
    indices = tf.argmax(x, axis)
    return indices
39
+
40
+
41
@keras_export("keras._legacy.backend.argmin")
def argmin(x, axis=-1):
    """DEPRECATED. Index of the minimum value along `axis`."""
    indices = tf.argmin(x, axis)
    return indices
45
+
46
+
47
@keras_export("keras._legacy.backend.arange")
def arange(start, stop=None, step=1, dtype="int32"):
    """DEPRECATED. Evenly spaced values as a 1-D tensor, like `np.arange`."""
    # Match np.arange(-n): a negative `start` with no `stop` yields an
    # empty range rather than counting down.
    if stop is None and start < 0:
        start = 0
    values = tf.range(start, limit=stop, delta=step, name="arange")
    if dtype != "int32":
        values = tf.cast(values, dtype)
    return values
56
+
57
+
58
@keras_export("keras._legacy.backend.batch_dot")
def batch_dot(x, y, axes=None):
    """DEPRECATED.

    Batch-wise dot product: contracts one axis of `x` against one axis of
    `y` per batch element (axis 0 is always the batch axis and is never
    contracted). `axes` may be None, an int (same axis for both inputs),
    or a pair (axis_in_x, axis_in_y).
    """
    x_shape = x.shape
    y_shape = y.shape

    x_ndim = len(x_shape)
    y_ndim = len(y_shape)

    if x_ndim < 2 or y_ndim < 2:
        raise ValueError(
            "Cannot do batch_dot on inputs "
            "with rank < 2. "
            "Received inputs with tf.shapes "
            + str(x_shape)
            + " and "
            + str(y_shape)
            + "."
        )

    x_batch_size = x_shape[0]
    y_batch_size = y_shape[0]

    # Static batch sizes must agree when both are known.
    if x_batch_size is not None and y_batch_size is not None:
        if x_batch_size != y_batch_size:
            raise ValueError(
                "Cannot do batch_dot on inputs "
                "with different batch sizes. "
                "Received inputs with tf.shapes "
                + str(x_shape)
                + " and "
                + str(y_shape)
                + "."
            )
    if isinstance(axes, int):
        axes = [axes, axes]

    # Default: contract x's last axis with y's last (rank-2 y) or
    # second-to-last (higher-rank y) axis.
    if axes is None:
        if y_ndim == 2:
            axes = [x_ndim - 1, y_ndim - 1]
        else:
            axes = [x_ndim - 1, y_ndim - 2]

    if py_any(isinstance(a, (list, tuple)) for a in axes):
        raise ValueError(
            "Multiple target dimensions are not supported. "
            + "Expected: None, int, (int, int), "
            + "Provided: "
            + str(axes)
        )

    # if tuple, convert to list.
    axes = list(axes)

    # convert negative indices.
    if axes[0] < 0:
        axes[0] += x_ndim
    if axes[1] < 0:
        axes[1] += y_ndim

    # sanity checks
    if 0 in axes:
        raise ValueError(
            "Cannot perform batch_dot over axis 0. "
            "If your inputs are not batched, "
            "add a dummy batch dimension to your "
            "inputs using K.expand_dims(x, 0)"
        )
    a0, a1 = axes
    d1 = x_shape[a0]
    d2 = y_shape[a1]

    # The two contracted dimensions must match when statically known.
    if d1 is not None and d2 is not None and d1 != d2:
        raise ValueError(
            "Cannot do batch_dot on inputs with tf.shapes "
            + str(x_shape)
            + " and "
            + str(y_shape)
            + " with axes="
            + str(axes)
            + ". x.shape[%d] != y.shape[%d] (%d != %d)."
            % (axes[0], axes[1], d1, d2)
        )

    # backup ndims. Need them later.
    orig_x_ndim = x_ndim
    orig_y_ndim = y_ndim

    # if rank is 2, expand to 3: insert a length-1 axis so both inputs
    # can be fed to a single batched matmul below.
    if x_ndim == 2:
        x = tf.expand_dims(x, 1)
        a0 += 1
        x_ndim += 1
    if y_ndim == 2:
        y = tf.expand_dims(y, 2)
        y_ndim += 1

    # bring x's dimension to be reduced to last axis.
    if a0 != x_ndim - 1:
        pattern = list(range(x_ndim))
        for i in range(a0, x_ndim - 1):
            pattern[i] = pattern[i + 1]
        pattern[-1] = a0
        x = tf.transpose(x, pattern)

    # bring y's dimension to be reduced to axis 1.
    if a1 != 1:
        pattern = list(range(y_ndim))
        for i in range(a1, 1, -1):
            pattern[i] = pattern[i - 1]
        pattern[1] = a1
        y = tf.transpose(y, pattern)

    # normalize both inputs to rank 3 so tf.matmul can batch over axis 0.
    if x_ndim > 3:
        # squash middle dimensions of x.
        x_shape = tf.shape(x)
        x_mid_dims = x_shape[1:-1]
        x_squashed_shape = tf.stack([x_shape[0], -1, x_shape[-1]])
        x = tf.reshape(x, x_squashed_shape)
        x_squashed = True
    else:
        x_squashed = False

    if y_ndim > 3:
        # squash trailing dimensions of y.
        y_shape = tf.shape(y)
        y_trail_dims = y_shape[2:]
        y_squashed_shape = tf.stack([y_shape[0], y_shape[1], -1])
        y = tf.reshape(y, y_squashed_shape)
        y_squashed = True
    else:
        y_squashed = False

    result = tf.matmul(x, y)

    # if inputs were squashed, we have to reshape the matmul output.
    output_shape = tf.shape(result)
    do_reshape = False

    if x_squashed:
        output_shape = tf.concat(
            [output_shape[:1], x_mid_dims, output_shape[-1:]], 0
        )
        do_reshape = True

    if y_squashed:
        output_shape = tf.concat([output_shape[:-1], y_trail_dims], 0)
        do_reshape = True

    if do_reshape:
        result = tf.reshape(result, output_shape)

    # if the inputs were originally rank 2, we remove the added 1 dim.
    if orig_x_ndim == 2:
        result = tf.squeeze(result, 1)
    elif orig_y_ndim == 2:
        result = tf.squeeze(result, -1)

    return result
218
+
219
+
220
@keras_export("keras._legacy.backend.batch_flatten")
def batch_flatten(x):
    """DEPRECATED. Flatten each sample to 1-D while keeping the batch axis."""
    # `prod` is the legacy-backend helper defined elsewhere in this module.
    flat_dim = prod(tf.shape(x)[1:])
    return tf.reshape(x, tf.stack([-1, flat_dim]))
225
+
226
+
227
@keras_export("keras._legacy.backend.batch_get_value")
def batch_get_value(tensors):
    """DEPRECATED. Fetch the numpy value of every tensor/variable given."""
    return [tensor.numpy() for tensor in tensors]
231
+
232
+
233
@keras_export("keras._legacy.backend.batch_set_value")
def batch_set_value(tuples):
    """DEPRECATED. Assign each `(variable, value)` pair in one pass."""
    # Only assign in eager mode or inside a tf.function; otherwise the
    # call is a silent no-op (graph-mode sessions are gone).
    if tf.executing_eagerly() or tf.inside_function():
        for variable, value in tuples:
            variable.assign(np.asarray(value, dtype=variable.dtype.name))
240
+
241
+
242
@keras_export("keras._legacy.backend.batch_normalization")
def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
    """DEPRECATED.

    Applies `tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)`.

    NOTE(review): `axis` is accepted for API compatibility but is not used
    by this implementation — broadcasting is left entirely to TF.
    """
    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
246
+
247
+
248
@keras_export("keras._legacy.backend.bias_add")
def bias_add(x, bias, data_format=None):
    """DEPRECATED.

    Adds `bias` to `x`, honoring the Keras image data format. `bias` may be
    rank 1 (one value per channel) or rank `ndim(x) - 1` (a full per-sample
    bias volume). Relies on the sibling helpers `ndim` and `reshape`
    defined elsewhere in this module.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    bias_shape = bias.shape
    if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:
        raise ValueError(
            f"Unexpected bias dimensions {len(bias_shape)}. "
            f"Expected it to be 1 or {ndim(x) - 1} dimensions"
        )

    # Rank-1 bias: delegate to tf.nn.bias_add with the matching layout.
    if len(bias_shape) == 1:
        if data_format == "channels_first":
            return tf.nn.bias_add(x, bias, data_format="NCHW")
        return tf.nn.bias_add(x, bias, data_format="NHWC")
    # Higher-rank bias on 3/4/5-D inputs: reshape to broadcast over batch,
    # moving the channel axis first when data_format is channels_first.
    if ndim(x) in (3, 4, 5):
        if data_format == "channels_first":
            bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1]
            return x + reshape(bias, bias_reshape_axis)
        return x + reshape(bias, (1,) + bias_shape)
    return tf.nn.bias_add(x, bias)
272
+
273
+
274
@keras_export("keras._legacy.backend.binary_crossentropy")
def binary_crossentropy(target, output, from_logits=False):
    """DEPRECATED.

    Element-wise binary crossentropy between `target` and `output`. With
    `from_logits=True`, defers to the numerically stable
    `sigmoid_cross_entropy_with_logits`; otherwise treats `output` as
    probabilities.
    """
    target = tf.convert_to_tensor(target)
    output = tf.convert_to_tensor(output)

    if from_logits:
        return tf.nn.sigmoid_cross_entropy_with_logits(
            labels=target, logits=output
        )

    # Clip probabilities away from 0 and 1 to keep the logs finite.
    epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype)
    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)

    # Compute cross entropy from probabilities.
    # NOTE(review): epsilon is added inside the logs even though `output`
    # was already clipped — preserved verbatim from the legacy backend.
    bce = target * tf.math.log(output + backend.epsilon())
    bce += (1 - target) * tf.math.log(1 - output + backend.epsilon())
    return -bce
292
+
293
+
294
@keras_export("keras._legacy.backend.binary_focal_crossentropy")
def binary_focal_crossentropy(
    target,
    output,
    apply_class_balancing=False,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
):
    """DEPRECATED. Binary focal crossentropy (focal loss, Lin et al.)."""
    probs = tf.sigmoid(output) if from_logits else output

    # Probability assigned to the true class of each element.
    p_t = target * probs + (1 - target) * (1 - probs)

    # Down-weight easy examples by (1 - p_t) ** gamma.
    modulator = tf.pow(1.0 - p_t, gamma)

    plain_bce = binary_crossentropy(
        target=target,
        output=output,
        from_logits=from_logits,
    )
    loss = modulator * plain_bce

    if apply_class_balancing:
        # alpha for positives, (1 - alpha) for negatives.
        balance = target * alpha + (1 - target) * (1 - alpha)
        loss = balance * loss

    return loss
324
+
325
+
326
@keras_export("keras._legacy.backend.cast")
def cast(x, dtype):
    """DEPRECATED. Cast a tensor to `dtype`."""
    converted = tf.cast(x, dtype)
    return converted
330
+
331
+
332
@keras_export("keras._legacy.backend.cast_to_floatx")
def cast_to_floatx(x):
    """DEPRECATED. Cast to the default Keras float dtype.

    TF tensors/variables stay tensors; anything else becomes a numpy array.
    """
    floatx = backend.floatx()
    if isinstance(x, (tf.Tensor, tf.Variable, tf.SparseTensor)):
        return tf.cast(x, dtype=floatx)
    return np.asarray(x, dtype=floatx)
338
+
339
+
340
@keras_export("keras._legacy.backend.categorical_crossentropy")
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """DEPRECATED.

    Categorical crossentropy between a probability distribution `target`
    and predictions `output`, reduced over `axis`. With `from_logits=True`,
    defers to `softmax_cross_entropy_with_logits`.
    """
    target = tf.convert_to_tensor(target)
    output = tf.convert_to_tensor(output)
    target.shape.assert_is_compatible_with(output.shape)

    if from_logits:
        return tf.nn.softmax_cross_entropy_with_logits(
            labels=target, logits=output, axis=axis
        )

    # Adjust the predictions so that the probability of
    # each class for every sample adds up to 1
    # This is needed to ensure that the cross entropy is
    # computed correctly.
    output = output / tf.reduce_sum(output, axis, True)

    # Compute cross entropy from probabilities, clipping away from 0/1
    # so the log stays finite.
    epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype)
    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)
    return -tf.reduce_sum(target * tf.math.log(output), axis)
362
+
363
+
364
@keras_export("keras._legacy.backend.categorical_focal_crossentropy")
def categorical_focal_crossentropy(
    target,
    output,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
    axis=-1,
):
    """DEPRECATED.

    Focal variant of categorical crossentropy: each class term is weighted
    by `alpha * (1 - p) ** gamma`, down-weighting well-classified examples.
    """
    target = tf.convert_to_tensor(target)
    output = tf.convert_to_tensor(output)
    target.shape.assert_is_compatible_with(output.shape)

    if from_logits:
        output = tf.nn.softmax(output, axis=axis)

    # Adjust the predictions so that the probability of
    # each class for every sample adds up to 1
    # This is needed to ensure that the cross entropy is
    # computed correctly.
    output = output / tf.reduce_sum(output, axis=axis, keepdims=True)

    # Clip away from 0/1 so the log stays finite.
    epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype)
    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)

    # Calculate cross entropy
    cce = -target * tf.math.log(output)

    # Calculate factors
    modulating_factor = tf.pow(1.0 - output, gamma)
    weighting_factor = tf.multiply(modulating_factor, alpha)

    # Apply weighting factor
    focal_cce = tf.multiply(weighting_factor, cce)
    focal_cce = tf.reduce_sum(focal_cce, axis=axis)
    return focal_cce
401
+
402
+
403
@keras_export("keras._legacy.backend.clip")
def clip(x, min_value, max_value):
    """DEPRECATED.

    Clip `x` to `[min_value, max_value]`. A `None` bound means unbounded
    on that side.
    """
    # Legacy quirk, preserved: when both bounds are plain numbers and
    # max < min, the max is raised to min (everything clips to min_value).
    if isinstance(min_value, (int, float)) and isinstance(
        max_value, (int, float)
    ):
        if max_value < min_value:
            max_value = min_value
    # None bounds become +/- inf (checked after the numeric-swap above,
    # since None is not int/float).
    if min_value is None:
        min_value = -np.inf
    if max_value is None:
        max_value = np.inf
    return tf.clip_by_value(x, min_value, max_value)
416
+
417
+
418
@keras_export("keras._legacy.backend.concatenate")
def concatenate(tensors, axis=-1):
    """DEPRECATED.

    Concatenate a list of tensors along `axis`. All-sparse inputs use the
    sparse concat op; all-ragged inputs concat directly; any mix is
    densified first via the sibling helper `to_dense`.
    """
    # Normalize a negative axis against the rank of the first tensor
    # (`ndim` is the sibling helper defined elsewhere in this module).
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all(is_sparse(x) for x in tensors):
        return tf.compat.v1.sparse_concat(axis, tensors)
    elif py_all(isinstance(x, tf.RaggedTensor) for x in tensors):
        return tf.concat(tensors, axis)
    else:
        return tf.concat([to_dense(x) for x in tensors], axis)
434
+
435
+
436
@keras_export("keras._legacy.backend.constant")
def constant(value, dtype=None, shape=None, name=None):
    """DEPRECATED. Create a constant tensor; dtype defaults to floatx."""
    resolved_dtype = backend.floatx() if dtype is None else dtype
    return tf.constant(value, dtype=resolved_dtype, shape=shape, name=name)
443
+
444
+
445
+ def _preprocess_conv1d_input(x, data_format):
446
+ tf_data_format = "NWC" # to pass TF Conv2dNative operations
447
+ if data_format == "channels_first":
448
+ tf_data_format = "NCW"
449
+ return x, tf_data_format
450
+
451
+
452
+ def _preprocess_conv2d_input(x, data_format, force_transpose=False):
453
+ tf_data_format = "NHWC"
454
+ if data_format == "channels_first":
455
+ if force_transpose:
456
+ x = tf.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC
457
+ else:
458
+ tf_data_format = "NCHW"
459
+ return x, tf_data_format
460
+
461
+
462
+ def _preprocess_conv3d_input(x, data_format):
463
+ tf_data_format = "NDHWC"
464
+ if data_format == "channels_first":
465
+ tf_data_format = "NCDHW"
466
+ return x, tf_data_format
467
+
468
+
469
+ def _preprocess_padding(padding):
470
+ if padding == "same":
471
+ padding = "SAME"
472
+ elif padding == "valid":
473
+ padding = "VALID"
474
+ else:
475
+ raise ValueError(f"Invalid padding: {padding}")
476
+ return padding
477
+
478
+
479
@keras_export("keras._legacy.backend.conv1d")
def conv1d(
    x, kernel, strides=1, padding="valid", data_format=None, dilation_rate=1
):
    """DEPRECATED.

    1-D convolution. `padding` may be "valid", "same", or "causal";
    causal padding left-pads the time axis by `dilation_rate * (kernel
    width - 1)` via the sibling helper `temporal_padding` and then runs a
    valid convolution.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    kernel_shape = kernel.shape.as_list()
    if padding == "causal":
        # causal (dilated) convolution:
        left_pad = dilation_rate * (kernel_shape[0] - 1)
        x = temporal_padding(x, (left_pad, 0))
        padding = "valid"
    padding = _preprocess_padding(padding)

    x, tf_data_format = _preprocess_conv1d_input(x, data_format)
    x = tf.compat.v1.nn.convolution(
        input=x,
        filter=kernel,
        dilation_rate=dilation_rate,
        strides=strides,
        padding=padding,
        data_format=tf_data_format,
    )
    # NOTE(review): _preprocess_conv1d_input returns "NCW" for
    # channels_first, so this branch looks unreachable; kept verbatim
    # from the legacy backend as a defensive layout restore.
    if data_format == "channels_first" and tf_data_format == "NWC":
        x = tf.transpose(x, (0, 2, 1))  # NWC -> NCW
    return x
509
+
510
+
511
@keras_export("keras._legacy.backend.conv2d")
def conv2d(
    x,
    kernel,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
):
    """DEPRECATED.

    2-D convolution of `x` with `kernel`, honoring the Keras image data
    format. The result is returned in the caller's data format.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    x, tf_data_format = _preprocess_conv2d_input(x, data_format)
    padding = _preprocess_padding(padding)
    x = tf.compat.v1.nn.convolution(
        input=x,
        filter=kernel,
        dilation_rate=dilation_rate,
        strides=strides,
        padding=padding,
        data_format=tf_data_format,
    )
    # If the input was transposed to NHWC for the op, restore NCHW.
    if data_format == "channels_first" and tf_data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
539
+
540
+
541
@keras_export("keras._legacy.backend.conv2d_transpose")
def conv2d_transpose(
    x,
    kernel,
    output_shape,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
):
    """DEPRECATED.

    Transposed (deconvolution) 2-D convolution producing `output_shape`.
    Dilated transposes use `tf.nn.atrous_conv2d_transpose`, which only
    supports NHWC, so channels_first inputs are transposed around the op
    in that case.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
    if data_format == "channels_first" and dilation_rate != (1, 1):
        force_transpose = True
    else:
        force_transpose = False

    x, tf_data_format = _preprocess_conv2d_input(
        x, data_format, force_transpose
    )

    # Reorder the requested NCHW output shape to NHWC when the data was
    # moved to NHWC for the op.
    if data_format == "channels_first" and tf_data_format == "NHWC":
        output_shape = (
            output_shape[0],
            output_shape[2],
            output_shape[3],
            output_shape[1],
        )
    # Unknown batch size: substitute the dynamic batch dim of `x`.
    if output_shape[0] is None:
        output_shape = (tf.shape(x)[0],) + tuple(output_shape[1:])

    if isinstance(output_shape, (tuple, list)):
        output_shape = tf.stack(list(output_shape))

    padding = _preprocess_padding(padding)
    # Expand strides to the 4-element form the low-level op expects.
    if tf_data_format == "NHWC":
        strides = (1,) + strides + (1,)
    else:
        strides = (1, 1) + strides

    if dilation_rate == (1, 1):
        x = tf.compat.v1.nn.conv2d_transpose(
            x,
            kernel,
            output_shape,
            strides,
            padding=padding,
            data_format=tf_data_format,
        )
    else:
        # The atrous transpose op takes a single scalar rate.
        if dilation_rate[0] != dilation_rate[1]:
            raise ValueError(
                "Expected the 2 dimensions of the `dilation_rate` argument "
                "to be equal to each other. "
                f"Received: dilation_rate={dilation_rate}"
            )
        x = tf.nn.atrous_conv2d_transpose(
            x, kernel, output_shape, rate=dilation_rate[0], padding=padding
        )
    # Restore the caller's channels_first layout if we transposed.
    if data_format == "channels_first" and tf_data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
608
+
609
+
610
@keras_export("keras._legacy.backend.conv3d")
def conv3d(
    x,
    kernel,
    strides=(1, 1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1, 1),
):
    """DEPRECATED.

    3-D convolution of `x` with `kernel`, honoring the Keras image data
    format. The result is returned in the caller's data format.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    x, tf_data_format = _preprocess_conv3d_input(x, data_format)
    padding = _preprocess_padding(padding)
    x = tf.compat.v1.nn.convolution(
        input=x,
        filter=kernel,
        dilation_rate=dilation_rate,
        strides=strides,
        padding=padding,
        data_format=tf_data_format,
    )
    # NOTE(review): _preprocess_conv3d_input returns "NCDHW" for
    # channels_first, so this branch looks unreachable; kept verbatim
    # from the legacy backend.
    if data_format == "channels_first" and tf_data_format == "NDHWC":
        x = tf.transpose(x, (0, 4, 1, 2, 3))
    return x
638
+
639
+
640
@keras_export("keras._legacy.backend.cos")
def cos(x):
    """DEPRECATED. Element-wise cosine."""
    result = tf.cos(x)
    return result
644
+
645
+
646
@keras_export("keras._legacy.backend.count_params")
def count_params(x):
    """DEPRECATED. Number of scalar elements in a tensor or variable."""
    static_shape = x.shape.as_list()
    return np.prod(static_shape)
650
+
651
+
652
@keras_export("keras._legacy.backend.ctc_batch_cost")
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
    """DEPRECATED.

    CTC loss per batch element, returned with shape (batch, 1). Dense
    labels are converted to the sparse form `tf.nn.ctc_loss` requires,
    and `y_pred` probabilities are moved to log space in time-major
    layout.
    """
    # Lengths arrive as (batch, 1); squeeze to flat int32 vectors.
    label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
    input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
    sparse_labels = tf.cast(
        ctc_label_dense_to_sparse(y_true, label_length), tf.int32
    )

    # (batch, time, classes) -> (time, batch, classes) log-probabilities;
    # epsilon keeps the log finite for zero probabilities.
    y_pred = tf.math.log(
        tf.transpose(y_pred, perm=[1, 0, 2]) + backend.epsilon()
    )

    return tf.expand_dims(
        tf.compat.v1.nn.ctc_loss(
            inputs=y_pred, labels=sparse_labels, sequence_length=input_length
        ),
        1,
    )
671
+
672
+
673
@keras_export("keras._legacy.backend.ctc_label_dense_to_sparse")
def ctc_label_dense_to_sparse(labels, label_lengths):
    """DEPRECATED.

    Convert a dense (batch, max_label_len) label matrix into a
    `tf.SparseTensor`, keeping only the first `label_lengths[b]` entries
    of each row.
    """
    label_shape = tf.shape(labels)
    num_batches_tns = tf.stack([label_shape[0]])
    max_num_labels_tns = tf.stack([label_shape[1]])

    # For one row: mask of positions < current_input (that row's length).
    def range_less_than(old_input, current_input):
        return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill(
            max_num_labels_tns, current_input
        )

    # Scan over the per-row lengths to build a (batch, max_len) validity
    # mask selecting the real label positions.
    init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool)
    dense_mask = tf.compat.v1.scan(
        range_less_than, label_lengths, initializer=init, parallel_iterations=1
    )
    dense_mask = dense_mask[:, 0, :]

    # Column (label position) index of every kept entry.
    label_array = tf.reshape(
        tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape
    )
    label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)

    # Row (batch) index of every kept entry; `reverse` is the sibling
    # helper defined elsewhere in this module.
    batch_array = tf.transpose(
        tf.reshape(
            tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns),
            reverse(label_shape, 0),
        )
    )
    batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)
    # Pair them up as (batch, position) sparse indices.
    indices = tf.transpose(
        tf.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1])
    )

    vals_sparse = tf.compat.v1.gather_nd(labels, indices)

    return tf.SparseTensor(
        tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64)
    )
712
+
713
+
714
@keras_export("keras._legacy.backend.ctc_decode")
def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
    """DEPRECATED.

    Decodes the output of a softmax via greedy or beam search.

    Args:
        y_pred: Prediction tensor of shape (samples, time, classes).
        input_length: Per-sample sequence lengths (cast to int32).
        greedy: If True, use the fast greedy decoder; otherwise beam search.
        beam_width: Beam width (beam search only).
        top_paths: Number of paths to return (beam search only).

    Returns:
        Tuple `(decoded_dense, log_prob)`: a list of dense decoded tensors
        (padded with -1) and the decoding log-probabilities.
    """
    input_shape = tf.shape(y_pred)
    num_samples, num_steps = input_shape[0], input_shape[1]
    # Decoders expect time-major log-probabilities; epsilon guards log(0).
    y_pred = tf.math.log(
        tf.transpose(y_pred, perm=[1, 0, 2]) + backend.epsilon()
    )
    input_length = tf.cast(input_length, tf.int32)

    if greedy:
        (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
            inputs=y_pred, sequence_length=input_length
        )
    else:
        (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
            inputs=y_pred,
            sequence_length=input_length,
            beam_width=beam_width,
            top_paths=top_paths,
        )
    # Densify each sparse decoding, padding unused slots with -1.
    decoded_dense = []
    for st in decoded:
        st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
        decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
    return (decoded_dense, log_prob)
740
+
741
+
742
@keras_export("keras._legacy.backend.cumsum")
def cumsum(x, axis=0):
    """DEPRECATED. Cumulative sum of `x` along `axis`."""
    summed = tf.cumsum(x, axis=axis)
    return summed
746
+
747
+
748
@keras_export("keras._legacy.backend.cumprod")
def cumprod(x, axis=0):
    """DEPRECATED. Cumulative product of `x` along `axis`."""
    product = tf.math.cumprod(x, axis=axis)
    return product
752
+
753
+
754
@keras_export("keras._legacy.backend.depthwise_conv2d")
def depthwise_conv2d(
    x,
    depthwise_kernel,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
):
    """DEPRECATED.

    2D depthwise convolution.

    Args:
        x: Input tensor (layout given by `data_format`).
        depthwise_kernel: Per-channel convolution kernel.
        strides: Tuple of 2 stride ints.
        padding: Padding mode, normalized by `_preprocess_padding`.
        data_format: `"channels_first"`, `"channels_last"`, or `None` for
            the Keras-wide default.
        dilation_rate: Tuple of 2 dilation ints.

    Returns:
        The depthwise convolution result, in the caller's `data_format`.

    Raises:
        ValueError: If `data_format` is not one of the two known layouts.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    x, tf_data_format = _preprocess_conv2d_input(x, data_format)
    padding = _preprocess_padding(padding)
    # tf.nn.depthwise_conv2d wants rank-4 strides with the batch/channel
    # positions set to 1, placed according to the layout in use.
    if tf_data_format == "NHWC":
        strides = (1,) + strides + (1,)
    else:
        strides = (1, 1) + strides

    x = tf.nn.depthwise_conv2d(
        x,
        depthwise_kernel,
        strides=strides,
        padding=padding,
        dilations=dilation_rate,
        data_format=tf_data_format,
    )
    if data_format == "channels_first" and tf_data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
787
+
788
+
789
@keras_export("keras._legacy.backend.dot")
def dot(x, y):
    """DEPRECATED.

    Multiplies two tensors (or a sparse and a dense tensor).

    For inputs of rank > 2 it contracts the last axis of `x` with the
    second-to-last axis of `y` (theano-style behavior), flattening both
    operands to 2D for a single matmul and reshaping the result back.

    Args:
        x: Tensor (or SparseTensor for the rank<=2 path).
        y: Tensor.

    Returns:
        The product tensor.
    """
    if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
        # Build shapes that use the static dim where known and fall back
        # to the dynamic dim otherwise, so the final reshape works with
        # partially-unknown shapes.
        x_shape = []
        for i, s in zip(x.shape, tf.unstack(tf.shape(x))):
            if i is not None:
                x_shape.append(i)
            else:
                x_shape.append(s)
        x_shape = tuple(x_shape)
        y_shape = []
        for i, s in zip(y.shape, tf.unstack(tf.shape(y))):
            if i is not None:
                y_shape.append(i)
            else:
                y_shape.append(s)
        y_shape = tuple(y_shape)
        # Move y's contraction axis (second-to-last) to the front, then
        # flatten both operands for one 2D matmul.
        y_permute_dim = list(range(ndim(y)))
        y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
        xt = tf.reshape(x, [-1, x_shape[-1]])
        yt = tf.reshape(tf.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
        return tf.reshape(
            tf.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:]
        )
    if is_sparse(x):
        out = tf.sparse.sparse_dense_matmul(x, y)
    else:
        out = tf.matmul(x, y)
    return out
819
+
820
+
821
@keras_export("keras._legacy.backend.dropout")
def dropout(x, level, noise_shape=None, seed=None):
    """DEPRECATED. Randomly zeroes entries of `x` at rate `level`."""
    # Seed from the NumPy RNG when the caller did not supply one.
    seed = np.random.randint(10e6) if seed is None else seed
    return tf.nn.dropout(x, rate=level, noise_shape=noise_shape, seed=seed)
827
+
828
+
829
@keras_export("keras._legacy.backend.dtype")
def dtype(x):
    """DEPRECATED. Name of the base dtype of `x`, as a string."""
    base = x.dtype.base_dtype
    return base.name
833
+
834
+
835
@keras_export("keras._legacy.backend.elu")
def elu(x, alpha=1.0):
    """DEPRECATED. Exponential linear unit, with negative part scaled by `alpha`."""
    activated = tf.nn.elu(x)
    if alpha == 1:
        return activated
    # Rescale only the negative branch; positives pass through unchanged.
    return tf.where(x > 0, activated, alpha * activated)
843
+
844
+
845
@keras_export("keras._legacy.backend.equal")
def equal(x, y):
    """DEPRECATED. Element-wise equality comparison."""
    comparison = tf.equal(x, y)
    return comparison
849
+
850
+
851
@keras_export("keras._legacy.backend.eval")
def eval(x):
    """DEPRECATED. Evaluates `x` (densified first) to a concrete value."""
    dense = to_dense(x)
    return get_value(dense)
855
+
856
+
857
@keras_export("keras._legacy.backend.exp")
def exp(x):
    """DEPRECATED. Element-wise exponential of `x`."""
    result = tf.exp(x)
    return result
861
+
862
+
863
@keras_export("keras._legacy.backend.expand_dims")
def expand_dims(x, axis=-1):
    """DEPRECATED. Inserts a size-1 axis into `x` at `axis`."""
    expanded = tf.expand_dims(x, axis)
    return expanded
867
+
868
+
869
@keras_export("keras._legacy.backend.eye")
def eye(size, dtype=None, name=None):
    """DEPRECATED. Identity matrix of the given size, as a Keras variable."""
    dtype = backend.floatx() if dtype is None else dtype
    identity = tf.eye(size, dtype=tf.as_dtype(dtype))
    return variable(identity, dtype, name)
876
+
877
+
878
@keras_export("keras._legacy.backend.flatten")
def flatten(x):
    """DEPRECATED. Collapses `x` into a rank-1 tensor."""
    flat = tf.reshape(x, [-1])
    return flat
882
+
883
+
884
@keras_export("keras._legacy.backend.foldl")
def foldl(fn, elems, initializer=None, name=None):
    """DEPRECATED. Left fold of `fn` over the first axis of `elems`."""
    folded = tf.compat.v1.foldl(fn, elems, initializer=initializer, name=name)
    return folded
888
+
889
+
890
@keras_export("keras._legacy.backend.foldr")
def foldr(fn, elems, initializer=None, name=None):
    """DEPRECATED. Right fold of `fn` over the first axis of `elems`."""
    folded = tf.compat.v1.foldr(fn, elems, initializer=initializer, name=name)
    return folded
894
+
895
+
896
@keras_export("keras._legacy.backend.gather")
def gather(reference, indices):
    """DEPRECATED. Gathers slices of `reference` at `indices`."""
    gathered = tf.compat.v1.gather(reference, indices)
    return gathered
900
+
901
+
902
@keras_export("keras._legacy.backend.get_value")
def get_value(x):
    """DEPRECATED.

    Fetches the concrete value of a tensor or variable.

    Args:
        x: A tensor, variable, or plain Python value.

    Returns:
        A NumPy array for tensors/variables; `x` unchanged otherwise.
    """
    if not tf.is_tensor(x):
        return x
    if tf.executing_eagerly() or isinstance(x, tf.__internal__.EagerTensor):
        return x.numpy()
    if not getattr(x, "_in_graph_mode", True):
        # This is a variable which was created in an eager context, but is being
        # evaluated from a Graph.
        with tf.__internal__.eager_context.eager_mode():
            return x.numpy()
    # Graph-mode variable: lift the read out of any function-building
    # scope before materializing it.
    with tf.init_scope():
        return x.numpy()
916
+
917
+
918
@keras_export("keras._legacy.backend.gradients")
def gradients(loss, variables):
    """DEPRECATED. Gradients of `loss` with respect to `variables` (graph mode)."""
    grads = tf.compat.v1.gradients(
        loss, variables, colocate_gradients_with_ops=True
    )
    return grads
924
+
925
+
926
@keras_export("keras._legacy.backend.greater")
def greater(x, y):
    """DEPRECATED. Element-wise `x > y`."""
    comparison = tf.greater(x, y)
    return comparison
930
+
931
+
932
@keras_export("keras._legacy.backend.greater_equal")
def greater_equal(x, y):
    """DEPRECATED. Element-wise `x >= y`."""
    comparison = tf.greater_equal(x, y)
    return comparison
936
+
937
+
938
@keras_export("keras._legacy.backend.hard_sigmoid")
def hard_sigmoid(x):
    """DEPRECATED. Piecewise-linear sigmoid: clip(0.2 * x + 0.5, 0, 1)."""
    slope = tf.convert_to_tensor(0.2, dtype=x.dtype)
    offset = tf.convert_to_tensor(0.5, dtype=x.dtype)
    scaled = tf.add(tf.multiply(x, slope), offset)
    return tf.clip_by_value(scaled, 0.0, 1.0)
947
+
948
+
949
@keras_export("keras._legacy.backend.in_top_k")
def in_top_k(predictions, targets, k):
    """DEPRECATED. Whether each target is within the top-`k` predictions."""
    hits = tf.compat.v1.math.in_top_k(predictions, targets, k)
    return hits
953
+
954
+
955
@keras_export("keras._legacy.backend.int_shape")
def int_shape(x):
    """DEPRECATED. Static shape of `x` as a tuple, or None if unavailable."""
    try:
        raw_shape = x.shape
        if isinstance(raw_shape, tuple):
            return raw_shape
        return tuple(raw_shape.as_list())
    except ValueError:
        # Unknown-rank shapes cannot be listed out.
        return None
965
+
966
+
967
@keras_export("keras._legacy.backend.is_sparse")
def is_sparse(tensor):
    """DEPRECATED. True if `tensor` is (or is typed as) a SparseTensor."""
    type_spec = getattr(tensor, "_type_spec", None)
    if type_spec is None:
        return isinstance(tensor, tf.SparseTensor)
    # Composite tensors carry a type spec; inspect it instead.
    return isinstance(type_spec, tf.SparseTensorSpec)
974
+
975
+
976
@keras_export("keras._legacy.backend.l2_normalize")
def l2_normalize(x, axis=None):
    """DEPRECATED. L2-normalizes `x` along `axis`."""
    normalized = tf.linalg.l2_normalize(x, axis=axis)
    return normalized
980
+
981
+
982
@keras_export("keras._legacy.backend.less")
def less(x, y):
    """DEPRECATED. Element-wise `x < y`."""
    comparison = tf.less(x, y)
    return comparison
986
+
987
+
988
@keras_export("keras._legacy.backend.less_equal")
def less_equal(x, y):
    """DEPRECATED. Element-wise `x <= y`."""
    comparison = tf.less_equal(x, y)
    return comparison
992
+
993
+
994
@keras_export("keras._legacy.backend.log")
def log(x):
    """DEPRECATED. Element-wise natural logarithm of `x`."""
    result = tf.math.log(x)
    return result
998
+
999
+
1000
@keras_export("keras._legacy.backend.map_fn")
def map_fn(fn, elems, name=None, dtype=None):
    """DEPRECATED. Maps `fn` over the first axis of `elems`."""
    mapped = tf.compat.v1.map_fn(fn, elems, name=name, dtype=dtype)
    return mapped
1004
+
1005
+
1006
@keras_export("keras._legacy.backend.max")
def max(x, axis=None, keepdims=False):
    """DEPRECATED. Maximum of `x` over `axis`."""
    return tf.reduce_max(x, axis=axis, keepdims=keepdims)
1010
+
1011
+
1012
@keras_export("keras._legacy.backend.maximum")
def maximum(x, y):
    """DEPRECATED. Element-wise maximum of `x` and `y`."""
    result = tf.maximum(x, y)
    return result
1016
+
1017
+
1018
@keras_export("keras._legacy.backend.mean")
def mean(x, axis=None, keepdims=False):
    """DEPRECATED. Mean of `x` over `axis`; booleans are cast to floatx first."""
    if x.dtype.base_dtype == tf.bool:
        # reduce_mean is undefined for bool; average the 0/1 floats instead.
        x = tf.cast(x, backend.floatx())
    return tf.reduce_mean(x, axis=axis, keepdims=keepdims)
1024
+
1025
+
1026
@keras_export("keras._legacy.backend.min")
def min(x, axis=None, keepdims=False):
    """DEPRECATED. Minimum of `x` over `axis`."""
    return tf.reduce_min(x, axis=axis, keepdims=keepdims)
1030
+
1031
+
1032
@keras_export("keras._legacy.backend.minimum")
def minimum(x, y):
    """DEPRECATED. Element-wise minimum of `x` and `y`."""
    result = tf.minimum(x, y)
    return result
1036
+
1037
+
1038
@keras_export("keras._legacy.backend.moving_average_update")
def moving_average_update(x, value, momentum):
    """DEPRECATED. In-place EMA update: x <- x - (x - value) * (1 - momentum)."""
    decay = tf.cast(momentum, x.dtype)
    target = tf.cast(value, x.dtype)
    return x.assign_sub((x - target) * (1 - decay))
1044
+
1045
+
1046
@keras_export("keras._legacy.backend.name_scope")
def name_scope(name):
    """DEPRECATED. Opens a TF name scope context manager."""
    scope = tf.name_scope(name)
    return scope
1050
+
1051
+
1052
@keras_export("keras._legacy.backend.ndim")
def ndim(x):
    """DEPRECATED. Static rank of `x` (may be None)."""
    static_shape = x.shape
    return static_shape.rank
1056
+
1057
+
1058
@keras_export("keras._legacy.backend.not_equal")
def not_equal(x, y):
    """DEPRECATED. Element-wise inequality comparison."""
    comparison = tf.not_equal(x, y)
    return comparison
1062
+
1063
+
1064
@keras_export("keras._legacy.backend.one_hot")
def one_hot(indices, num_classes):
    """DEPRECATED. One-hot encodes `indices` into a trailing class axis."""
    encoded = tf.one_hot(indices, depth=num_classes, axis=-1)
    return encoded
1068
+
1069
+
1070
@keras_export("keras._legacy.backend.ones")
def ones(shape, dtype=None, name=None):
    """DEPRECATED. All-ones tensor; wrapped in a variable when fully static."""
    with tf.init_scope():
        dtype = backend.floatx() if dtype is None else dtype
        filled = tf.ones(shape=shape, dtype=tf.as_dtype(dtype), name=name)
        # A fully-defined static shape (no None/0 dims) gets a variable.
        if py_all(filled.shape.as_list()):
            return variable(filled, dtype=dtype, name=name)
        return filled
1081
+
1082
+
1083
@keras_export("keras._legacy.backend.ones_like")
def ones_like(x, dtype=None, name=None):
    """DEPRECATED. All-ones tensor with the shape (and default dtype) of `x`."""
    filled = tf.ones_like(x, dtype=dtype, name=name)
    return filled
1087
+
1088
+
1089
@keras_export("keras._legacy.backend.permute_dimensions")
def permute_dimensions(x, pattern):
    """DEPRECATED. Permutes the axes of `x` according to `pattern`."""
    permuted = tf.transpose(x, perm=pattern)
    return permuted
1093
+
1094
+
1095
@keras_export("keras._legacy.backend.pool2d")
def pool2d(
    x,
    pool_size,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    pool_mode="max",
):
    """DEPRECATED.

    2D max or average pooling.

    Args:
        x: Input tensor (layout given by `data_format`).
        pool_size: Tuple of 2 pooling-window ints.
        strides: Tuple of 2 stride ints.
        padding: Padding mode, normalized by `_preprocess_padding`.
        data_format: `"channels_first"`, `"channels_last"`, or `None` for
            the Keras-wide default.
        pool_mode: `"max"` or `"avg"`.

    Returns:
        The pooled tensor, in the caller's `data_format`.

    Raises:
        ValueError: On bad `data_format`, tuple lengths, or `pool_mode`.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    if len(pool_size) != 2:
        raise ValueError("`pool_size` must be a tuple of 2 integers.")
    if len(strides) != 2:
        raise ValueError("`strides` must be a tuple of 2 integers.")

    x, tf_data_format = _preprocess_conv2d_input(x, data_format)
    padding = _preprocess_padding(padding)
    # TF pooling ops take rank-4 strides/windows with 1s in the batch and
    # channel slots, positioned according to the layout in use.
    if tf_data_format == "NHWC":
        strides = (1,) + strides + (1,)
        pool_size = (1,) + pool_size + (1,)
    else:
        strides = (1, 1) + strides
        pool_size = (1, 1) + pool_size

    if pool_mode == "max":
        x = tf.compat.v1.nn.max_pool(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    elif pool_mode == "avg":
        x = tf.compat.v1.nn.avg_pool(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    else:
        raise ValueError("Invalid pooling mode: " + str(pool_mode))

    # Transpose back if the op ran in NHWC but the caller wanted NCHW.
    if data_format == "channels_first" and tf_data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
1137
+
1138
+
1139
@keras_export("keras._legacy.backend.pool3d")
def pool3d(
    x,
    pool_size,
    strides=(1, 1, 1),
    padding="valid",
    data_format=None,
    pool_mode="max",
):
    """DEPRECATED.

    3D max or average pooling.

    Args:
        x: Input tensor (layout given by `data_format`).
        pool_size: Tuple of 3 pooling-window ints.
        strides: Tuple of 3 stride ints.
        padding: Padding mode, normalized by `_preprocess_padding`.
        data_format: `"channels_first"`, `"channels_last"`, or `None` for
            the Keras-wide default.
        pool_mode: `"max"` or `"avg"`.

    Returns:
        The pooled tensor, in the caller's `data_format`.

    Raises:
        ValueError: On bad `data_format` or `pool_mode`.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    x, tf_data_format = _preprocess_conv3d_input(x, data_format)
    padding = _preprocess_padding(padding)
    # Rank-5 strides/windows with 1s in batch/channel slots, positioned
    # according to the layout in use.
    if tf_data_format == "NDHWC":
        strides = (1,) + strides + (1,)
        pool_size = (1,) + pool_size + (1,)
    else:
        strides = (1, 1) + strides
        pool_size = (1, 1) + pool_size

    if pool_mode == "max":
        x = tf.nn.max_pool3d(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    elif pool_mode == "avg":
        x = tf.nn.avg_pool3d(
            x, pool_size, strides, padding=padding, data_format=tf_data_format
        )
    else:
        raise ValueError("Invalid pooling mode: " + str(pool_mode))

    # Transpose back if the op ran in NDHWC but the caller wanted NCDHW.
    if data_format == "channels_first" and tf_data_format == "NDHWC":
        x = tf.transpose(x, (0, 4, 1, 2, 3))
    return x
1177
+
1178
+
1179
@keras_export("keras._legacy.backend.pow")
def pow(x, a):
    """DEPRECATED. Element-wise power `x ** a`."""
    result = tf.pow(x, a)
    return result
1183
+
1184
+
1185
@keras_export("keras._legacy.backend.prod")
def prod(x, axis=None, keepdims=False):
    """DEPRECATED. Product of `x` over `axis`."""
    return tf.reduce_prod(x, axis=axis, keepdims=keepdims)
1189
+
1190
+
1191
@keras_export("keras._legacy.backend.random_bernoulli")
def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
    """DEPRECATED. Bernoulli(p) samples: 1 where a uniform draw is <= p."""
    dtype = backend.floatx() if dtype is None else dtype
    if seed is None:
        seed = np.random.randint(10e6)
    draws = tf.random.uniform(shape, dtype=dtype, seed=seed)
    return tf.where(
        draws <= p,
        tf.ones(shape, dtype=dtype),
        tf.zeros(shape, dtype=dtype),
    )
1203
+
1204
+
1205
@keras_export("keras._legacy.backend.random_normal")
def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    """DEPRECATED. Normally-distributed random tensor."""
    dtype = backend.floatx() if dtype is None else dtype
    if seed is None:
        seed = np.random.randint(10e6)
    return tf.random.normal(
        shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
    )
1215
+
1216
+
1217
@keras_export("keras._legacy.backend.random_normal_variable")
def random_normal_variable(
    shape, mean, scale, dtype=None, name=None, seed=None
):
    """DEPRECATED. Keras variable initialized from a normal distribution."""
    dtype = backend.floatx() if dtype is None else dtype
    tf_dtype = tf.as_dtype(dtype)
    if seed is None:
        # ensure that randomness is conditioned by the Numpy RNG
        seed = np.random.randint(10e8)
    initializer = tf.compat.v1.random_normal_initializer(
        mean, scale, dtype=tf_dtype, seed=seed
    )
    return variable(initializer(shape), dtype=dtype, name=name)
1232
+
1233
+
1234
@keras_export("keras._legacy.backend.random_uniform")
def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
    """DEPRECATED. Uniformly-distributed random tensor in [minval, maxval)."""
    dtype = backend.floatx() if dtype is None else dtype
    if seed is None:
        seed = np.random.randint(10e6)
    return tf.random.uniform(
        shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed
    )
1244
+
1245
+
1246
@keras_export("keras._legacy.backend.random_uniform_variable")
def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
    """DEPRECATED. Keras variable initialized from a uniform distribution."""
    dtype = backend.floatx() if dtype is None else dtype
    tf_dtype = tf.as_dtype(dtype)
    if seed is None:
        # ensure that randomness is conditioned by the Numpy RNG
        seed = np.random.randint(10e8)
    initializer = tf.compat.v1.random_uniform_initializer(
        low, high, dtype=tf_dtype, seed=seed
    )
    return variable(initializer(shape), dtype=dtype, name=name)
1259
+
1260
+
1261
@keras_export("keras._legacy.backend.reshape")
def reshape(x, shape):
    """DEPRECATED. Reshapes `x` to `shape`."""
    reshaped = tf.reshape(x, shape)
    return reshaped
1265
+
1266
+
1267
@keras_export("keras._legacy.backend.relu")
def relu(x, alpha=0.0, max_value=None, threshold=0.0):
    """DEPRECATED.

    Rectified linear unit with optional leak, cap, and threshold.

    Args:
        x: Input tensor (NumPy arrays / lists / tuples are also seen).
        alpha: Slope for values below the threshold.
        max_value: Optional saturation ceiling.
        threshold: Activation threshold.

    Returns:
        The activated tensor.
    """
    # While x can be a tensor or variable, we also see cases where
    # numpy arrays, lists, tuples are passed as well.
    # lists, tuples do not have 'dtype' attribute.
    dtype = getattr(x, "dtype", backend.floatx())
    if alpha != 0.0:
        # Simple leaky case: a single fused op handles it.
        if max_value is None and threshold == 0:
            return tf.nn.leaky_relu(x, alpha=alpha)

        # Otherwise keep the below-threshold magnitude around so the leak
        # can be subtracted at the end.
        if threshold != 0:
            negative_part = tf.nn.relu(-x + threshold)
        else:
            negative_part = tf.nn.relu(-x)
    else:
        negative_part = 1

    clip_max = max_value is not None

    if threshold != 0:
        # computes x for x > threshold else 0
        x = x * tf.cast(tf.greater(x, threshold), dtype=dtype)
    elif max_value == 6:
        # if no threshold, then can use nn.relu6 native TF op for performance
        x = tf.nn.relu6(x)
        clip_max = False
    else:
        x = tf.nn.relu(x)

    if clip_max:
        max_value = tf.convert_to_tensor(max_value, dtype=x.dtype)
        zero = tf.convert_to_tensor(0, dtype=x.dtype)
        x = tf.clip_by_value(x, zero, max_value)

    # Apply the leak after clipping, as alpha * (below-threshold part).
    if alpha != 0.0:
        alpha = tf.convert_to_tensor(alpha, dtype=x.dtype)
        x -= alpha * negative_part
    return x
1306
+
1307
+
1308
@keras_export("keras._legacy.backend.repeat")
def repeat(x, n):
    """DEPRECATED. Repeats a 2D tensor `n` times along a new middle axis."""
    assert ndim(x) == 2
    expanded = tf.expand_dims(x, 1)
    tiling = tf.stack([1, n, 1])
    return tf.tile(expanded, tiling)
1315
+
1316
+
1317
@keras_export("keras._legacy.backend.repeat_elements")
def repeat_elements(x, rep, axis):
    """DEPRECATED.

    Repeats the elements of `x` `rep` times along `axis` (like np.repeat).

    Args:
        x: Input tensor.
        rep: Number of repetitions per element.
        axis: Axis along which to repeat.

    Returns:
        The repeated tensor, with static shape restored where known.
    """
    x_shape = x.shape.as_list()
    # For static axis
    if x_shape[axis] is not None:
        # slices along the repeat axis
        splits = tf.split(value=x, num_or_size_splits=x_shape[axis], axis=axis)
        # repeat each slice the given number of reps
        x_rep = [s for s in splits for _ in range(rep)]
        return concatenate(x_rep, axis)

    # Here we use tf.tile to mimic behavior of np.repeat so that
    # we can handle dynamic shapes (that include None).
    # To do that, we need an auxiliary axis to repeat elements along
    # it and then merge them along the desired axis.

    # Repeating
    auxiliary_axis = axis + 1
    x_shape = tf.shape(x)
    x_rep = tf.expand_dims(x, axis=auxiliary_axis)
    reps = np.ones(len(x.shape) + 1)
    reps[auxiliary_axis] = rep
    x_rep = tf.tile(x_rep, reps)

    # Merging
    reps = np.delete(reps, auxiliary_axis)
    reps[axis] = rep
    reps = tf.constant(reps, dtype="int32")
    x_shape *= reps
    x_rep = tf.reshape(x_rep, x_shape)

    # Fix shape representation
    x_shape = x.shape.as_list()
    x_rep.set_shape(x_shape)
    return x_rep
1353
+
1354
+
1355
@keras_export("keras._legacy.backend.resize_images")
def resize_images(
    x, height_factor, width_factor, data_format, interpolation="nearest"
):
    """DEPRECATED.

    Resizes a 4D image batch by integer height/width factors.

    Args:
        x: 4D input tensor.
        height_factor: Integer upscaling factor for rows.
        width_factor: Integer upscaling factor for columns.
        data_format: `"channels_first"` or `"channels_last"`.
        interpolation: One of the `tf.image.ResizeMethod` names below.

    Returns:
        The resized tensor, in the caller's `data_format`.

    Raises:
        ValueError: On bad `data_format` or unknown `interpolation`.
    """
    # Locate the spatial axes for the given layout.
    if data_format == "channels_first":
        rows, cols = 2, 3
    elif data_format == "channels_last":
        rows, cols = 1, 2
    else:
        raise ValueError(f"Invalid `data_format` argument: {data_format}")

    # Target size = current spatial size * factors; use the static shape
    # when fully known, otherwise compute it dynamically.
    new_shape = x.shape[rows : cols + 1]
    if new_shape.is_fully_defined():
        new_shape = tf.constant(new_shape.as_list(), dtype="int32")
    else:
        new_shape = tf.shape(x)[rows : cols + 1]
    new_shape *= tf.constant(
        np.array([height_factor, width_factor], dtype="int32")
    )

    # tf.image.resize works on NHWC; transpose in and back out as needed.
    if data_format == "channels_first":
        x = permute_dimensions(x, [0, 2, 3, 1])
    interpolations = {
        "area": tf.image.ResizeMethod.AREA,
        "bicubic": tf.image.ResizeMethod.BICUBIC,
        "bilinear": tf.image.ResizeMethod.BILINEAR,
        "gaussian": tf.image.ResizeMethod.GAUSSIAN,
        "lanczos3": tf.image.ResizeMethod.LANCZOS3,
        "lanczos5": tf.image.ResizeMethod.LANCZOS5,
        "mitchellcubic": tf.image.ResizeMethod.MITCHELLCUBIC,
        "nearest": tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    }
    interploations_list = '"' + '", "'.join(interpolations.keys()) + '"'
    if interpolation in interpolations:
        x = tf.image.resize(x, new_shape, method=interpolations[interpolation])
    else:
        raise ValueError(
            "`interpolation` argument should be one of: "
            f'{interploations_list}. Received: "{interpolation}".'
        )
    if data_format == "channels_first":
        x = permute_dimensions(x, [0, 3, 1, 2])

    return x
1400
+
1401
+
1402
@keras_export("keras._legacy.backend.resize_volumes")
def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
    """DEPRECATED. Resizes a 5D volume by repeating along each spatial axis."""
    # The three spatial axes start at 2 for channels_first, 1 otherwise.
    if data_format == "channels_first":
        first_spatial_axis = 2
    elif data_format == "channels_last":
        first_spatial_axis = 1
    else:
        raise ValueError(f"Invalid data_format: {data_format}")
    output = repeat_elements(x, depth_factor, axis=first_spatial_axis)
    output = repeat_elements(output, height_factor, axis=first_spatial_axis + 1)
    output = repeat_elements(output, width_factor, axis=first_spatial_axis + 2)
    return output
1417
+
1418
+
1419
@keras_export("keras._legacy.backend.reverse")
def reverse(x, axes):
    """DEPRECATED. Reverses `x` along the given axis or axes."""
    axes_list = [axes] if isinstance(axes, int) else axes
    return tf.reverse(x, axes_list)
1425
+
1426
+
1427
+ @keras_export("keras._legacy.backend.rnn")
1428
+ def rnn(
1429
+ step_function,
1430
+ inputs,
1431
+ initial_states,
1432
+ go_backwards=False,
1433
+ mask=None,
1434
+ constants=None,
1435
+ unroll=False,
1436
+ input_length=None,
1437
+ time_major=False,
1438
+ zero_output_for_mask=False,
1439
+ return_all_outputs=True,
1440
+ ):
1441
+ """DEPRECATED."""
1442
+ if not tf.__internal__.tf2.enabled():
1443
+ return_all_outputs = True # Not supported in TF1.
1444
+
1445
+ def swap_batch_timestep(input_t):
1446
+ # Swap the batch and timestep dim for the incoming tensor.
1447
+ axes = list(range(len(input_t.shape)))
1448
+ axes[0], axes[1] = 1, 0
1449
+ return tf.transpose(input_t, axes)
1450
+
1451
+ if not time_major:
1452
+ inputs = tf.nest.map_structure(swap_batch_timestep, inputs)
1453
+
1454
+ flatted_inputs = tf.nest.flatten(inputs)
1455
+ time_steps = flatted_inputs[0].shape[0]
1456
+ batch = flatted_inputs[0].shape[1]
1457
+ time_steps_t = tf.shape(flatted_inputs[0])[0]
1458
+
1459
+ for input_ in flatted_inputs:
1460
+ input_.shape.with_rank_at_least(3)
1461
+
1462
+ if mask is not None:
1463
+ if mask.dtype != tf.bool:
1464
+ mask = tf.cast(mask, tf.bool)
1465
+ if len(mask.shape) == 2:
1466
+ mask = expand_dims(mask)
1467
+ if not time_major:
1468
+ mask = swap_batch_timestep(mask)
1469
+
1470
+ if constants is None:
1471
+ constants = []
1472
+
1473
+ # tf.where needs its condition tensor to be the same shape as its two
1474
+ # result tensors, but in our case the condition (mask) tensor is
1475
+ # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
1476
+ # So we need to broadcast the mask to match the shape of inputs.
1477
+ # That's what the tile call does, it just repeats the mask along its
1478
+ # second dimension n times.
1479
+ def _expand_mask(mask_t, input_t, fixed_dim=1):
1480
+ if tf.nest.is_nested(mask_t):
1481
+ raise ValueError(
1482
+ f"mask_t is expected to be tensor, but got {mask_t}"
1483
+ )
1484
+ if tf.nest.is_nested(input_t):
1485
+ raise ValueError(
1486
+ f"input_t is expected to be tensor, but got {input_t}"
1487
+ )
1488
+ rank_diff = len(input_t.shape) - len(mask_t.shape)
1489
+ for _ in range(rank_diff):
1490
+ mask_t = tf.expand_dims(mask_t, -1)
1491
+ multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
1492
+ return tf.tile(mask_t, multiples)
1493
+
1494
+ if unroll:
1495
+ if not time_steps:
1496
+ raise ValueError("Unrolling requires a fixed number of timesteps.")
1497
+ states = tuple(initial_states)
1498
+ successive_states = []
1499
+ successive_outputs = []
1500
+
1501
+ # Process the input tensors. The input tensor need to be split on the
1502
+ # time_step dim, and reverse if go_backwards is True. In the case of
1503
+ # nested input, the input is flattened and then transformed
1504
+ # individually. The result of this will be a tuple of lists, each of
1505
+ # the item in tuple is list of the tensor with shape (batch, feature)
1506
+ def _process_single_input_t(input_t):
1507
+ input_t = tf.unstack(input_t) # unstack for time_step dim
1508
+ if go_backwards:
1509
+ input_t.reverse()
1510
+ return input_t
1511
+
1512
+ if tf.nest.is_nested(inputs):
1513
+ processed_input = tf.nest.map_structure(
1514
+ _process_single_input_t, inputs
1515
+ )
1516
+ else:
1517
+ processed_input = (_process_single_input_t(inputs),)
1518
+
1519
+ def _get_input_tensor(time):
1520
+ inp = [t_[time] for t_ in processed_input]
1521
+ return tf.nest.pack_sequence_as(inputs, inp)
1522
+
1523
+ if mask is not None:
1524
+ mask_list = tf.unstack(mask)
1525
+ if go_backwards:
1526
+ mask_list.reverse()
1527
+
1528
+ for i in range(time_steps):
1529
+ inp = _get_input_tensor(i)
1530
+ mask_t = mask_list[i]
1531
+ output, new_states = step_function(
1532
+ inp, tuple(states) + tuple(constants)
1533
+ )
1534
+ tiled_mask_t = _expand_mask(mask_t, output)
1535
+
1536
+ if not successive_outputs:
1537
+ prev_output = zeros_like(output)
1538
+ else:
1539
+ prev_output = successive_outputs[-1]
1540
+
1541
+ output = tf.where(tiled_mask_t, output, prev_output)
1542
+
1543
+ flat_states = tf.nest.flatten(states)
1544
+ flat_new_states = tf.nest.flatten(new_states)
1545
+ tiled_mask_t = tuple(
1546
+ _expand_mask(mask_t, s) for s in flat_states
1547
+ )
1548
+ flat_final_states = tuple(
1549
+ tf.where(m, s, ps)
1550
+ for m, s, ps in zip(
1551
+ tiled_mask_t, flat_new_states, flat_states
1552
+ )
1553
+ )
1554
+ states = tf.nest.pack_sequence_as(states, flat_final_states)
1555
+
1556
+ if return_all_outputs:
1557
+ successive_outputs.append(output)
1558
+ successive_states.append(states)
1559
+ else:
1560
+ successive_outputs = [output]
1561
+ successive_states = [states]
1562
+ last_output = successive_outputs[-1]
1563
+ new_states = successive_states[-1]
1564
+ outputs = tf.stack(successive_outputs)
1565
+
1566
+ if zero_output_for_mask:
1567
+ last_output = tf.where(
1568
+ _expand_mask(mask_list[-1], last_output),
1569
+ last_output,
1570
+ zeros_like(last_output),
1571
+ )
1572
+ outputs = tf.where(
1573
+ _expand_mask(mask, outputs, fixed_dim=2),
1574
+ outputs,
1575
+ zeros_like(outputs),
1576
+ )
1577
+
1578
+ else: # mask is None
1579
+ for i in range(time_steps):
1580
+ inp = _get_input_tensor(i)
1581
+ output, states = step_function(
1582
+ inp, tuple(states) + tuple(constants)
1583
+ )
1584
+ if return_all_outputs:
1585
+ successive_outputs.append(output)
1586
+ successive_states.append(states)
1587
+ else:
1588
+ successive_outputs = [output]
1589
+ successive_states = [states]
1590
+ last_output = successive_outputs[-1]
1591
+ new_states = successive_states[-1]
1592
+ outputs = tf.stack(successive_outputs)
1593
+
1594
+ else: # Unroll == False
1595
+ states = tuple(initial_states)
1596
+
1597
+ # Create input tensor array, if the inputs is nested tensors, then it
1598
+ # will be flattened first, and tensor array will be created one per
1599
+ # flattened tensor.
1600
+ input_ta = tuple(
1601
+ tf.TensorArray(
1602
+ dtype=inp.dtype,
1603
+ size=time_steps_t,
1604
+ tensor_array_name=f"input_ta_{i}",
1605
+ )
1606
+ for i, inp in enumerate(flatted_inputs)
1607
+ )
1608
+ input_ta = tuple(
1609
+ (
1610
+ ta.unstack(input_)
1611
+ if not go_backwards
1612
+ else ta.unstack(reverse(input_, 0))
1613
+ )
1614
+ for ta, input_ in zip(input_ta, flatted_inputs)
1615
+ )
1616
+
1617
+ # Get the time(0) input and compute the output for that, the output will
1618
+ # be used to determine the dtype of output tensor array. Don't read from
1619
+ # input_ta due to TensorArray clear_after_read default to True.
1620
+ input_time_zero = tf.nest.pack_sequence_as(
1621
+ inputs, [inp[0] for inp in flatted_inputs]
1622
+ )
1623
+ # output_time_zero is used to determine the cell output shape and its
1624
+ # dtype. the value is discarded.
1625
+ output_time_zero, _ = step_function(
1626
+ input_time_zero, tuple(initial_states) + tuple(constants)
1627
+ )
1628
+
1629
+ output_ta_size = time_steps_t if return_all_outputs else 1
1630
+ output_ta = tuple(
1631
+ tf.TensorArray(
1632
+ dtype=out.dtype,
1633
+ size=output_ta_size,
1634
+ element_shape=out.shape,
1635
+ tensor_array_name=f"output_ta_{i}",
1636
+ )
1637
+ for i, out in enumerate(tf.nest.flatten(output_time_zero))
1638
+ )
1639
+
1640
+ time = tf.constant(0, dtype="int32", name="time")
1641
+
1642
+ if input_length is None:
1643
+ max_iterations = time_steps_t
1644
+ else:
1645
+ max_iterations = tf.reduce_max(input_length)
1646
+
1647
+ while_loop_kwargs = {
1648
+ "cond": lambda time, *_: time < time_steps_t,
1649
+ "maximum_iterations": max_iterations,
1650
+ "parallel_iterations": 32,
1651
+ "swap_memory": True,
1652
+ }
1653
+ if mask is not None:
1654
+ if go_backwards:
1655
+ mask = reverse(mask, 0)
1656
+
1657
+ mask_ta = tf.TensorArray(
1658
+ dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
1659
+ )
1660
+ mask_ta = mask_ta.unstack(mask)
1661
+
1662
+ def masking_fn(time):
1663
+ return mask_ta.read(time)
1664
+
1665
+ def compute_masked_output(mask_t, flat_out, flat_mask):
1666
+ tiled_mask_t = tuple(
1667
+ _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
1668
+ for o in flat_out
1669
+ )
1670
+ return tuple(
1671
+ tf.where(m, o, fm)
1672
+ for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
1673
+ )
1674
+
1675
+ elif isinstance(input_length, tf.Tensor):
1676
+ if go_backwards:
1677
+ max_len = tf.reduce_max(input_length, axis=0)
1678
+ rev_input_length = tf.subtract(max_len - 1, input_length)
1679
+
1680
+ def masking_fn(time):
1681
+ return tf.less(rev_input_length, time)
1682
+
1683
+ else:
1684
+
1685
+ def masking_fn(time):
1686
+ return tf.greater(input_length, time)
1687
+
1688
+ def compute_masked_output(mask_t, flat_out, flat_mask):
1689
+ return tuple(
1690
+ tf.compat.v1.where(mask_t, o, zo)
1691
+ for (o, zo) in zip(flat_out, flat_mask)
1692
+ )
1693
+
1694
+ else:
1695
+ masking_fn = None
1696
+
1697
+ if masking_fn is not None:
1698
+ # Mask for the T output will be base on the output of T - 1. In the
1699
+ # case T = 0, a zero filled tensor will be used.
1700
+ flat_zero_output = tuple(
1701
+ tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero)
1702
+ )
1703
+
1704
+ def _step(time, output_ta_t, prev_output, *states):
1705
+ """RNN step function.
1706
+
1707
+ Args:
1708
+ time: Current timestep value.
1709
+ output_ta_t: TensorArray.
1710
+ prev_output: tuple of outputs from time - 1.
1711
+ *states: List of states.
1712
+
1713
+ Returns:
1714
+ Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
1715
+ """
1716
+ current_input = tuple(ta.read(time) for ta in input_ta)
1717
+ # maybe set shape.
1718
+ current_input = tf.nest.pack_sequence_as(inputs, current_input)
1719
+ mask_t = masking_fn(time)
1720
+ output, new_states = step_function(
1721
+ current_input, tuple(states) + tuple(constants)
1722
+ )
1723
+ # mask output
1724
+ flat_output = tf.nest.flatten(output)
1725
+ flat_mask_output = (
1726
+ flat_zero_output
1727
+ if zero_output_for_mask
1728
+ else tf.nest.flatten(prev_output)
1729
+ )
1730
+ flat_new_output = compute_masked_output(
1731
+ mask_t, flat_output, flat_mask_output
1732
+ )
1733
+
1734
+ # mask states
1735
+ flat_state = tf.nest.flatten(states)
1736
+ flat_new_state = tf.nest.flatten(new_states)
1737
+ for state, new_state in zip(flat_state, flat_new_state):
1738
+ if isinstance(new_state, tf.Tensor):
1739
+ new_state.set_shape(state.shape)
1740
+ flat_final_state = compute_masked_output(
1741
+ mask_t, flat_new_state, flat_state
1742
+ )
1743
+ new_states = tf.nest.pack_sequence_as(
1744
+ new_states, flat_final_state
1745
+ )
1746
+
1747
+ ta_index_to_write = time if return_all_outputs else 0
1748
+ output_ta_t = tuple(
1749
+ ta.write(ta_index_to_write, out)
1750
+ for ta, out in zip(output_ta_t, flat_new_output)
1751
+ )
1752
+
1753
+ return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(
1754
+ new_states
1755
+ )
1756
+
1757
+ final_outputs = tf.compat.v1.while_loop(
1758
+ body=_step,
1759
+ loop_vars=(time, output_ta, flat_zero_output) + states,
1760
+ **while_loop_kwargs,
1761
+ )
1762
+ # Skip final_outputs[2] which is the output for final timestep.
1763
+ new_states = final_outputs[3:]
1764
+ else:
1765
+
1766
+ def _step(time, output_ta_t, *states):
1767
+ """RNN step function.
1768
+
1769
+ Args:
1770
+ time: Current timestep value.
1771
+ output_ta_t: TensorArray.
1772
+ *states: List of states.
1773
+
1774
+ Returns:
1775
+ Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
1776
+ """
1777
+ current_input = tuple(ta.read(time) for ta in input_ta)
1778
+ current_input = tf.nest.pack_sequence_as(inputs, current_input)
1779
+ output, new_states = step_function(
1780
+ current_input, tuple(states) + tuple(constants)
1781
+ )
1782
+ flat_state = tf.nest.flatten(states)
1783
+ flat_new_state = tf.nest.flatten(new_states)
1784
+ for state, new_state in zip(flat_state, flat_new_state):
1785
+ if isinstance(new_state, tf.Tensor):
1786
+ new_state.set_shape(state.shape)
1787
+
1788
+ flat_output = tf.nest.flatten(output)
1789
+ ta_index_to_write = time if return_all_outputs else 0
1790
+ output_ta_t = tuple(
1791
+ ta.write(ta_index_to_write, out)
1792
+ for ta, out in zip(output_ta_t, flat_output)
1793
+ )
1794
+
1795
+ new_states = tf.nest.pack_sequence_as(
1796
+ initial_states, flat_new_state
1797
+ )
1798
+ return (time + 1, output_ta_t) + tuple(new_states)
1799
+
1800
+ final_outputs = tf.compat.v1.while_loop(
1801
+ body=_step,
1802
+ loop_vars=(time, output_ta) + states,
1803
+ **while_loop_kwargs,
1804
+ )
1805
+ new_states = final_outputs[2:]
1806
+
1807
+ output_ta = final_outputs[1]
1808
+
1809
+ outputs = tuple(o.stack() for o in output_ta)
1810
+ last_output = tuple(o[-1] for o in outputs)
1811
+
1812
+ outputs = tf.nest.pack_sequence_as(output_time_zero, outputs)
1813
+ last_output = tf.nest.pack_sequence_as(output_time_zero, last_output)
1814
+
1815
+ # static shape inference
1816
+ def set_shape(output_):
1817
+ if isinstance(output_, tf.Tensor):
1818
+ shape = output_.shape.as_list()
1819
+ if return_all_outputs:
1820
+ shape[0] = time_steps
1821
+ else:
1822
+ shape[0] = 1
1823
+ shape[1] = batch
1824
+ output_.set_shape(shape)
1825
+ return output_
1826
+
1827
+ outputs = tf.nest.map_structure(set_shape, outputs)
1828
+
1829
+ if not time_major:
1830
+ outputs = tf.nest.map_structure(swap_batch_timestep, outputs)
1831
+
1832
+ return last_output, outputs, new_states
1833
+
1834
+
1835
@keras_export("keras._legacy.backend.round")
def round(x):
    """DEPRECATED. Element-wise rounding of `x` via `tf.round`."""
    rounded = tf.round(x)
    return rounded
1839
+
1840
+
1841
@keras_export("keras._legacy.backend.separable_conv2d")
def separable_conv2d(
    x,
    depthwise_kernel,
    pointwise_kernel,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
):
    """DEPRECATED. 2D convolution with separable (depthwise + pointwise) filters.

    Args:
        x: Input tensor.
        depthwise_kernel: Kernel for the depthwise convolution.
        pointwise_kernel: Kernel for the 1x1 pointwise convolution.
        strides: Tuple of 2 stride integers.
        padding: Padding mode, passed through `_preprocess_padding`.
        data_format: `"channels_first"`, `"channels_last"`, or `None` to
            use the global `backend.image_data_format()`.
        dilation_rate: Tuple of 2 dilation integers.

    Returns:
        Output tensor in the same data format as the input.

    Raises:
        ValueError: if `data_format` is unknown or `strides` is not length 2.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    if len(strides) != 2:
        raise ValueError("`strides` must be a tuple of 2 integers.")

    # `_preprocess_conv2d_input` (helper earlier in this file) returns the
    # tf-level layout actually used; presumably it transposes channels-first
    # input to NHWC when needed -- the post-conv transpose below undoes it.
    x, tf_data_format = _preprocess_conv2d_input(x, data_format)
    padding = _preprocess_padding(padding)
    if not isinstance(strides, tuple):
        strides = tuple(strides)
    # tf.nn expects strides for all 4 dims; batch/channel strides are 1.
    if tf_data_format == "NHWC":
        strides = (1,) + strides + (1,)
    else:
        strides = (1, 1) + strides

    x = tf.nn.separable_conv2d(
        x,
        depthwise_kernel,
        pointwise_kernel,
        strides=strides,
        padding=padding,
        dilations=dilation_rate,
        data_format=tf_data_format,
    )
    if data_format == "channels_first" and tf_data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return x
1880
+
1881
+
1882
@keras_export("keras._legacy.backend.set_value")
def set_value(x, value):
    """DEPRECATED. Assign `value`, cast to `x`'s dtype, to the variable `x`."""
    as_array = np.asarray(value, dtype=x.dtype.name)
    x.assign(as_array)
1887
+
1888
+
1889
@keras_export("keras._legacy.backend.shape")
def shape(x):
    """DEPRECATED. Dynamic shape of `x` as a tensor (`tf.shape`)."""
    dynamic_shape = tf.shape(x)
    return dynamic_shape
1893
+
1894
+
1895
@keras_export("keras._legacy.backend.sigmoid")
def sigmoid(x):
    """DEPRECATED. Element-wise logistic sigmoid of `x`."""
    return tf.sigmoid(x)
1900
+
1901
+
1902
@keras_export("keras._legacy.backend.sign")
def sign(x):
    """DEPRECATED. Element-wise sign of `x` (`tf.sign`)."""
    result = tf.sign(x)
    return result
1906
+
1907
+
1908
@keras_export("keras._legacy.backend.sin")
def sin(x):
    """DEPRECATED. Element-wise sine of `x` (`tf.sin`)."""
    result = tf.sin(x)
    return result
1912
+
1913
+
1914
@keras_export("keras._legacy.backend.softmax")
def softmax(x, axis=-1):
    """DEPRECATED. Softmax of a tensor along `axis`.

    Args:
        x: Input tensor with known rank >= 2.
        axis: Integer axis, or a tuple/list of axes, to normalize over.

    Returns:
        Tensor of the same shape as `x`. The input is also stashed on the
        result as `output._keras_logits` for later crossentropy use.

    Raises:
        ValueError: if `x` is 1D.
    """
    if x.shape.rank <= 1:
        raise ValueError(
            f"Cannot apply softmax to a tensor that is 1D. Received input: {x}"
        )

    if isinstance(axis, int):
        output = tf.nn.softmax(x, axis=axis)
    else:
        # nn.softmax does not support tuple axis: compute manually,
        # subtracting the per-axis max for numerical stability.
        numerator = tf.exp(x - tf.reduce_max(x, axis=axis, keepdims=True))
        denominator = tf.reduce_sum(numerator, axis=axis, keepdims=True)
        output = numerator / denominator

    # Cache the logits to use for crossentropy loss.
    output._keras_logits = x
    return output
1933
+
1934
+
1935
@keras_export("keras._legacy.backend.softplus")
def softplus(x):
    """DEPRECATED. Element-wise softplus of `x` (`tf.math.softplus`)."""
    result = tf.math.softplus(x)
    return result
1939
+
1940
+
1941
@keras_export("keras._legacy.backend.softsign")
def softsign(x):
    """DEPRECATED. Element-wise softsign of `x` (`tf.math.softsign`)."""
    result = tf.math.softsign(x)
    return result
1945
+
1946
+
1947
@keras_export("keras._legacy.backend.sparse_categorical_crossentropy")
def sparse_categorical_crossentropy(
    target, output, from_logits=False, axis=-1, ignore_class=None
):
    """DEPRECATED. Crossentropy between integer targets and an output tensor.

    Args:
        target: Integer class indices; cast to int64 internally.
        output: Probabilities, or logits if `from_logits=True`.
        from_logits: Whether `output` is already logits. If `False`, the
            probabilities are clipped to `[epsilon, 1 - epsilon]` and logged.
        axis: Axis of `output` holding the class dimension.
        ignore_class: Optional class index excluded from the loss; masked
            positions come back as zeros and the mask is attached as
            `res._keras_mask`.

    Returns:
        Per-position loss tensor.

    Raises:
        ValueError: if `axis != -1` while `output` has unknown rank.
    """
    target = tf.convert_to_tensor(target)
    output = tf.convert_to_tensor(output)

    target = cast(target, "int64")

    if not from_logits:
        # Convert probabilities to logits; clip first so log() is finite.
        epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype)
        output = tf.clip_by_value(output, epsilon_, 1 - epsilon_)
        output = tf.math.log(output)

    # Permute output so that the last axis contains the logits/probabilities.
    if isinstance(output.shape, (tuple, list)):
        output_rank = len(output.shape)
    else:
        output_rank = output.shape.ndims
    if output_rank is not None:
        axis %= output_rank
        if axis != output_rank - 1:
            permutation = list(
                itertools.chain(
                    range(axis), range(axis + 1, output_rank), [axis]
                )
            )
            output = tf.transpose(output, perm=permutation)
    elif axis != -1:
        raise ValueError(
            "Cannot compute sparse categorical crossentropy with `axis={}` "
            "on an output tensor with unknown rank".format(axis)
        )

    # Try to adjust the shape so that rank of labels = rank of logits - 1.
    output_shape = tf.shape(output)
    target_rank = target.shape.ndims

    update_shape = (
        target_rank is not None
        and output_rank is not None
        and target_rank != output_rank - 1
    )
    if update_shape:
        # Collapse to (positions,) labels vs (positions, classes) logits.
        target = flatten(target)
        output = tf.reshape(output, [-1, output_shape[-1]])

    if ignore_class is not None:
        # Boolean-mask out the ignored positions before computing the loss.
        valid_mask = tf.not_equal(target, cast(ignore_class, target.dtype))
        target = target[valid_mask]
        output = output[valid_mask]

    res = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=target, logits=output
    )

    if ignore_class is not None:
        # Scatter the per-valid-position losses back into the full shape,
        # leaving zeros at ignored positions, and expose the mask.
        res_shape = cast(output_shape[:-1], "int64")
        valid_mask = tf.reshape(valid_mask, res_shape)
        res = tf.scatter_nd(tf.where(valid_mask), res, res_shape)
        res._keras_mask = valid_mask

        return res

    if update_shape and output_rank >= 3:
        # If our output includes timesteps or
        # spatial dimensions we need to reshape
        res = tf.reshape(res, output_shape[:-1])

    return res
2018
+
2019
+
2020
@keras_export("keras._legacy.backend.spatial_2d_padding")
def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
    """DEPRECATED. Zero-pad the two spatial axes of a 4D tensor.

    `padding` is a pair of (before, after) pairs for rows and columns.
    """
    assert len(padding) == 2
    assert len(padding[0]) == 2
    assert len(padding[1]) == 2
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    rows = list(padding[0])
    cols = list(padding[1])
    if data_format == "channels_first":
        # NCHW: spatial axes are the last two.
        pattern = [[0, 0], [0, 0], rows, cols]
    else:
        # NHWC: spatial axes are the middle two.
        pattern = [[0, 0], rows, cols, [0, 0]]
    return tf.compat.v1.pad(x, pattern)
2036
+
2037
+
2038
@keras_export("keras._legacy.backend.spatial_3d_padding")
def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
    """DEPRECATED. Zero-pad the three spatial axes of a 5D tensor.

    `padding` is three (before, after) pairs, one per spatial dimension.
    """
    assert len(padding) == 3
    for pair in padding:
        assert len(pair) == 2
    if data_format is None:
        data_format = backend.image_data_format()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")

    spatial = [[before, after] for before, after in padding]
    if data_format == "channels_first":
        # NCDHW: batch and channel axes lead, spatial axes trail.
        pattern = [[0, 0], [0, 0]] + spatial
    else:
        # NDHWC: spatial axes sit between batch and channel.
        pattern = [[0, 0]] + spatial + [[0, 0]]
    return tf.compat.v1.pad(x, pattern)
2067
+
2068
+
2069
@keras_export("keras._legacy.backend.sqrt")
def sqrt(x):
    """DEPRECATED. Element-wise square root; negative entries are clamped to 0 first."""
    clamped = tf.maximum(x, tf.convert_to_tensor(0.0, x.dtype))
    return tf.sqrt(clamped)
2075
+
2076
+
2077
@keras_export("keras._legacy.backend.square")
def square(x):
    """DEPRECATED. Element-wise square of `x` (`tf.square`)."""
    result = tf.square(x)
    return result
2081
+
2082
+
2083
@keras_export("keras._legacy.backend.squeeze")
def squeeze(x, axis):
    """DEPRECATED. Drop the size-1 dimension of `x` at `axis`."""
    return tf.squeeze(x, [axis])
2087
+
2088
+
2089
@keras_export("keras._legacy.backend.stack")
def stack(x, axis=0):
    """DEPRECATED. Stack a list of tensors along `axis` (`tf.stack`)."""
    stacked = tf.stack(x, axis=axis)
    return stacked
2093
+
2094
+
2095
@keras_export("keras._legacy.backend.std")
def std(x, axis=None, keepdims=False):
    """DEPRECATED. Standard deviation over `axis`; bool inputs are cast to floatx first."""
    working = (
        tf.cast(x, backend.floatx())
        if x.dtype.base_dtype == tf.bool
        else x
    )
    return tf.math.reduce_std(working, axis=axis, keepdims=keepdims)
2101
+
2102
+
2103
@keras_export("keras._legacy.backend.stop_gradient")
def stop_gradient(variables):
    """DEPRECATED. Wrap `variables` in `tf.stop_gradient`.

    NOTE(review): for list/tuple input this returns a lazy `map` object,
    not a list, so it can only be consumed once. Preserved as-is for
    backward compatibility -- confirm before changing.
    """
    if isinstance(variables, (list, tuple)):
        return map(tf.stop_gradient, variables)
    return tf.stop_gradient(variables)
2109
+
2110
+
2111
@keras_export("keras._legacy.backend.sum")
def sum(x, axis=None, keepdims=False):
    """DEPRECATED. Sum of tensor elements over `axis` (shadows the builtin `sum`)."""
    total = tf.reduce_sum(x, axis, keepdims)
    return total
2115
+
2116
+
2117
@keras_export("keras._legacy.backend.switch")
def switch(condition, then_expression, else_expression):
    """DEPRECATED. Select `then_expression` or `else_expression` per `condition`.

    Args:
        condition: Tensor castable to bool. Scalar conditions use
            `tf.cond`; higher-rank conditions use `tf.where`.
        then_expression: Tensor or zero-arg callable producing one.
        else_expression: Tensor or zero-arg callable producing one.

    Returns:
        The selected tensor.

    Raises:
        ValueError: if `condition`'s rank exceeds `then_expression`'s.
    """
    if condition.dtype != tf.bool:
        condition = tf.cast(condition, "bool")
    cond_ndim = ndim(condition)
    if not cond_ndim:
        # Scalar condition: tf.cond requires callables for both branches,
        # so wrap plain tensors in closures.
        if not callable(then_expression):

            def then_expression_fn():
                return then_expression

        else:
            then_expression_fn = then_expression
        if not callable(else_expression):

            def else_expression_fn():
                return else_expression

        else:
            else_expression_fn = else_expression
        x = tf.compat.v1.cond(condition, then_expression_fn, else_expression_fn)
    else:
        # tf.where needs its condition tensor
        # to be the same shape as its two
        # result tensors
        if callable(then_expression):
            then_expression = then_expression()
        if callable(else_expression):
            else_expression = else_expression()
        expr_ndim = ndim(then_expression)
        if cond_ndim > expr_ndim:
            raise ValueError(
                "Rank of `condition` should be less than or"
                " equal to rank of `then_expression` and "
                "`else_expression`. ndim(condition)="
                + str(cond_ndim)
                + ", ndim(then_expression)="
                + str(expr_ndim)
            )
        if cond_ndim > 1:
            # Append size-1 dims to the condition, then tile it up to the
            # expression shape wherever the expression is wider.
            ndim_diff = expr_ndim - cond_ndim
            cond_shape = tf.concat(
                [tf.shape(condition), [1] * ndim_diff], axis=0
            )
            condition = tf.reshape(condition, cond_shape)
            expr_shape = tf.shape(then_expression)
            shape_diff = expr_shape - cond_shape
            tile_shape = tf.where(
                shape_diff > 0, expr_shape, tf.ones_like(expr_shape)
            )
            condition = tf.tile(condition, tile_shape)
        x = tf.where(condition, then_expression, else_expression)
    return x
2171
+
2172
+
2173
@keras_export("keras._legacy.backend.tanh")
def tanh(x):
    """DEPRECATED. Element-wise hyperbolic tangent of `x` (`tf.tanh`)."""
    result = tf.tanh(x)
    return result
2177
+
2178
+
2179
@keras_export("keras._legacy.backend.temporal_padding")
def temporal_padding(x, padding=(1, 1)):
    """DEPRECATED. Zero-pad the middle (time) axis of a rank-3 tensor."""
    assert len(padding) == 2
    before, after = padding
    return tf.compat.v1.pad(x, [[0, 0], [before, after], [0, 0]])
2185
+
2186
+
2187
@keras_export("keras._legacy.backend.tile")
def tile(x, n):
    """DEPRECATED. Tile `x` by multiples `n`; a bare int is wrapped in a list."""
    multiples = [n] if isinstance(n, int) else n
    return tf.tile(x, multiples)
2193
+
2194
+
2195
@keras_export("keras._legacy.backend.to_dense")
def to_dense(tensor):
    """DEPRECATED. Densify a sparse tensor; dense input passes through unchanged."""
    return tf.sparse.to_dense(tensor) if is_sparse(tensor) else tensor
2202
+
2203
+
2204
@keras_export("keras._legacy.backend.transpose")
def transpose(x):
    """DEPRECATED. Transpose of `x` with default axis order (`tf.transpose`)."""
    result = tf.transpose(x)
    return result
2208
+
2209
+
2210
@keras_export("keras._legacy.backend.truncated_normal")
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    """DEPRECATED. Tensor of truncated-normal random values.

    Args:
        shape: Output shape.
        mean: Mean of the distribution.
        stddev: Standard deviation of the distribution.
        dtype: Output dtype; defaults to `backend.floatx()`.
        seed: Random seed; if `None`, one is drawn from NumPy's global RNG.

    Returns:
        A tensor of shape `shape` drawn via `tf.random.truncated_normal`.
    """
    if dtype is None:
        dtype = backend.floatx()
    if seed is None:
        # `np.random.randint` expects an integer bound; the old code passed
        # the float literal `10e6`, which NumPy deprecated for this argument.
        seed = np.random.randint(10_000_000)
    return tf.random.truncated_normal(
        shape, mean, stddev, dtype=dtype, seed=seed
    )
2220
+
2221
+
2222
@keras_export("keras._legacy.backend.update")
def update(x, new_x):
    """DEPRECATED. Assign `new_x` to the variable `x` (TF1-style assign op)."""
    assign_op = tf.compat.v1.assign(x, new_x)
    return assign_op
2226
+
2227
+
2228
@keras_export("keras._legacy.backend.update_add")
def update_add(x, increment):
    """DEPRECATED. Add `increment` to the variable `x` in place."""
    assign_op = tf.compat.v1.assign_add(x, increment)
    return assign_op
2232
+
2233
+
2234
@keras_export("keras._legacy.backend.update_sub")
def update_sub(x, decrement):
    """DEPRECATED. Subtract `decrement` from the variable `x` in place."""
    assign_op = tf.compat.v1.assign_sub(x, decrement)
    return assign_op
2238
+
2239
+
2240
@keras_export("keras._legacy.backend.var")
def var(x, axis=None, keepdims=False):
    """DEPRECATED. Variance over `axis`; bool inputs are cast to floatx first."""
    working = (
        tf.cast(x, backend.floatx())
        if x.dtype.base_dtype == tf.bool
        else x
    )
    return tf.math.reduce_variance(working, axis=axis, keepdims=keepdims)
2246
+
2247
+
2248
@keras_export("keras._legacy.backend.variable")
def variable(value, dtype=None, name=None, constraint=None):
    """DEPRECATED. Create a variable (or SparseTensor) from `value`.

    Args:
        value: Initial value. Objects with a `tocoo()` method (scipy sparse
            matrices) are converted to a `tf.SparseTensor` instead of a
            `tf.Variable`.
        dtype: Dtype; defaults to `backend.floatx()`. Note: not applied on
            the sparse path, which keeps the scipy data's dtype.
        name: Optional variable name.
        constraint: Optional constraint, forwarded to `tf.Variable`. Ignored
            on the sparse path.

    Returns:
        A `tf.Variable`, or a `tf.SparseTensor` with `_keras_shape` attached.
    """
    if dtype is None:
        dtype = backend.floatx()
    if hasattr(value, "tocoo"):
        # scipy sparse input: build a SparseTensor from COO coordinates.
        sparse_coo = value.tocoo()
        indices = np.concatenate(
            (
                np.expand_dims(sparse_coo.row, 1),
                np.expand_dims(sparse_coo.col, 1),
            ),
            1,
        )
        v = tf.SparseTensor(
            indices=indices,
            values=sparse_coo.data,
            dense_shape=sparse_coo.shape,
        )
        v._keras_shape = sparse_coo.shape
        return v
    v = tf.Variable(
        value, dtype=tf.as_dtype(dtype), name=name, constraint=constraint
    )
    return v
2273
+
2274
+
2275
@keras_export("keras._legacy.backend.zeros")
def zeros(shape, dtype=None, name=None):
    """DEPRECATED. All-zeros variable (or tensor, for dynamic shapes).

    Args:
        shape: Output shape; may contain dynamic (None/0) dimensions.
        dtype: Dtype; defaults to `backend.floatx()`.
        name: Optional name.

    Returns:
        A variable when every dimension is statically known and non-zero
        (`py_all` over the shape list -- presumably the builtin `all`
        aliased earlier in this module; confirm); otherwise the raw
        `tf.zeros` tensor.
    """
    with tf.init_scope():
        if dtype is None:
            dtype = backend.floatx()
        tf_dtype = tf.as_dtype(dtype)
        v = tf.zeros(shape=shape, dtype=tf_dtype, name=name)
        if py_all(v.shape.as_list()):
            return variable(v, dtype=dtype, name=name)
        return v
2286
+
2287
+
2288
@keras_export("keras._legacy.backend.zeros_like")
def zeros_like(x, dtype=None, name=None):
    """DEPRECATED. All-zeros tensor with the shape of `x` (`tf.zeros_like`)."""
    result = tf.zeros_like(x, dtype=dtype, name=name)
    return result
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/layers.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Legacy Keras 1/2 layers.
2
+
3
+ AlphaDropout
4
+ RandomHeight
5
+ RandomWidth
6
+ ThresholdedReLU
7
+ """
8
+
9
+ from keras.src import backend
10
+ from keras.src.api_export import keras_export
11
+ from keras.src.layers.layer import Layer
12
+ from keras.src.utils.module_utils import tensorflow as tf
13
+
14
+
15
@keras_export("keras._legacy.layers.AlphaDropout")
class AlphaDropout(Layer):
    """DEPRECATED. Dropout variant that rescales kept/dropped units.

    Dropped units are set to `alpha_p` (not zero) and the result is put
    through an affine transform. The `alpha`/`scale` constants match the
    SELU activation's constants -- presumably to preserve self-normalizing
    statistics; confirm against the SELU paper before relying on this.
    """

    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
        # rate: drop probability in [0, 1].
        # noise_shape: optional static shape for the dropout mask.
        # seed: int seed used to build the SeedGenerator (also serialized).
        super().__init__(**kwargs)
        self.rate = rate
        self.seed = seed
        self.noise_shape = noise_shape
        self.seed_generator = backend.random.SeedGenerator(seed)
        self.supports_masking = True
        self.built = True

    def call(self, inputs, training=False):
        # Identity unless actively training with a positive rate.
        if training and self.rate > 0:
            alpha = 1.6732632423543772848170429916717
            scale = 1.0507009873554804934193349852946
            alpha_p = -alpha * scale

            if self.noise_shape is None:
                noise_shape = tf.shape(inputs)
            else:
                noise_shape = self.noise_shape
            # kept_idx is 1.0 where the unit survives, 0.0 where dropped.
            kept_idx = tf.greater_equal(
                backend.random.uniform(noise_shape, seed=self.seed_generator),
                self.rate,
            )
            kept_idx = tf.cast(kept_idx, inputs.dtype)

            # Get affine transformation params
            a = ((1 - self.rate) * (1 + self.rate * alpha_p**2)) ** -0.5
            b = -a * alpha_p * self.rate

            # Apply mask: dropped positions take the value alpha_p.
            x = inputs * kept_idx + alpha_p * (1 - kept_idx)

            # Do affine transformation
            return a * x + b
        return inputs

    def get_config(self):
        # noise_shape is intentionally not serialized (matches legacy config).
        config = {"rate": self.rate, "seed": self.seed}
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        return input_shape
62
+
63
+
64
@keras_export("keras._legacy.layers.RandomHeight")
class RandomHeight(Layer):
    """DEPRECATED. Randomly rescales image height during training.

    At train time, the height axis (axis -3) is resized by a factor drawn
    uniformly from `[1 + lower, 1 + upper]`; inference is the identity.
    """

    def __init__(self, factor, interpolation="bilinear", seed=None, **kwargs):
        # factor: single float (symmetric range [-factor, factor]) or a
        # (lower, upper) pair. Values are fractional height deltas.
        super().__init__(**kwargs)
        self.seed_generator = backend.random.SeedGenerator(seed)
        self.factor = factor
        if isinstance(factor, (tuple, list)):
            self.height_lower = factor[0]
            self.height_upper = factor[1]
        else:
            self.height_lower = -factor
            self.height_upper = factor

        if self.height_upper < self.height_lower:
            raise ValueError(
                "`factor` argument cannot have an upper bound lesser than the "
                f"lower bound. Received: factor={factor}"
            )
        if self.height_lower < -1.0 or self.height_upper < -1.0:
            raise ValueError(
                "`factor` argument must have values larger than -1. "
                f"Received: factor={factor}"
            )
        self.interpolation = interpolation
        self.seed = seed

    def call(self, inputs, training=True):
        inputs = tf.convert_to_tensor(inputs, dtype=self.compute_dtype)

        def random_height_inputs(inputs):
            """Inputs height-adjusted with random ops."""
            inputs_shape = tf.shape(inputs)
            # Height is float for the multiply; width stays an int.
            img_hd = tf.cast(inputs_shape[-3], tf.float32)
            img_wd = inputs_shape[-2]
            height_factor = backend.random.uniform(
                shape=[],
                minval=(1.0 + self.height_lower),
                maxval=(1.0 + self.height_upper),
                seed=self.seed_generator,
            )
            adjusted_height = tf.cast(height_factor * img_hd, tf.int32)
            adjusted_size = tf.stack([adjusted_height, img_wd])
            output = tf.image.resize(
                images=inputs,
                size=adjusted_size,
                method=self.interpolation,
            )
            # tf.resize will output float32 regardless of input type.
            output = tf.cast(output, self.compute_dtype)
            # Height becomes dynamic (unknown) in the static shape.
            output_shape = inputs.shape.as_list()
            output_shape[-3] = None
            output.set_shape(output_shape)
            return output

        if training:
            return random_height_inputs(inputs)
        else:
            return inputs

    def compute_output_shape(self, input_shape):
        input_shape = list(input_shape)
        input_shape[-3] = None
        return tuple(input_shape)

    def get_config(self):
        config = {
            "factor": self.factor,
            "interpolation": self.interpolation,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}
138
+
139
+
140
@keras_export("keras._legacy.layers.RandomWidth")
class RandomWidth(Layer):
    """DEPRECATED. Randomly rescales image width during training.

    Mirror of `RandomHeight` for the width axis (axis -2): at train time,
    width is resized by a factor drawn uniformly from
    `[1 + lower, 1 + upper]`; inference is the identity.
    """

    def __init__(self, factor, interpolation="bilinear", seed=None, **kwargs):
        # factor: single float (symmetric range [-factor, factor]) or a
        # (lower, upper) pair. Values are fractional width deltas.
        super().__init__(**kwargs)
        self.seed_generator = backend.random.SeedGenerator(seed)
        self.factor = factor
        if isinstance(factor, (tuple, list)):
            self.width_lower = factor[0]
            self.width_upper = factor[1]
        else:
            self.width_lower = -factor
            self.width_upper = factor
        if self.width_upper < self.width_lower:
            raise ValueError(
                "`factor` argument cannot have an upper bound less than the "
                f"lower bound. Received: factor={factor}"
            )
        if self.width_lower < -1.0 or self.width_upper < -1.0:
            raise ValueError(
                "`factor` argument must have values larger than -1. "
                f"Received: factor={factor}"
            )
        self.interpolation = interpolation
        self.seed = seed

    def call(self, inputs, training=True):
        inputs = tf.convert_to_tensor(inputs, dtype=self.compute_dtype)

        def random_width_inputs(inputs):
            """Inputs width-adjusted with random ops."""
            inputs_shape = tf.shape(inputs)
            img_hd = inputs_shape[-3]
            # Width is float for the multiply; height stays an int.
            img_wd = tf.cast(inputs_shape[-2], tf.float32)
            width_factor = backend.random.uniform(
                shape=[],
                minval=(1.0 + self.width_lower),
                maxval=(1.0 + self.width_upper),
                seed=self.seed_generator,
            )
            adjusted_width = tf.cast(width_factor * img_wd, tf.int32)
            adjusted_size = tf.stack([img_hd, adjusted_width])
            output = tf.image.resize(
                images=inputs,
                size=adjusted_size,
                method=self.interpolation,
            )
            # tf.resize will output float32 regardless of input type.
            output = tf.cast(output, self.compute_dtype)
            # Width becomes dynamic (unknown) in the static shape.
            output_shape = inputs.shape.as_list()
            output_shape[-2] = None
            output.set_shape(output_shape)
            return output

        if training:
            return random_width_inputs(inputs)
        else:
            return inputs

    def compute_output_shape(self, input_shape):
        input_shape = list(input_shape)
        input_shape[-2] = None
        return tuple(input_shape)

    def get_config(self):
        config = {
            "factor": self.factor,
            "interpolation": self.interpolation,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}
213
+
214
+
215
@keras_export("keras._legacy.layers.ThresholdedReLU")
class ThresholdedReLU(Layer):
    """DEPRECATED. Passes values strictly greater than `theta`; zeros the rest."""

    def __init__(self, theta=1.0, **kwargs):
        # theta: non-negative float threshold; stored as a tensor in
        # compute_dtype (get_config converts it back via float()).
        super().__init__(**kwargs)
        if theta is None:
            raise ValueError(
                "Theta of a Thresholded ReLU layer cannot be None, expecting a "
                f"float. Received: {theta}"
            )
        if theta < 0:
            raise ValueError(
                "The theta value of a Thresholded ReLU layer "
                f"should be >=0. Received: {theta}"
            )
        self.supports_masking = True
        self.theta = tf.convert_to_tensor(theta, dtype=self.compute_dtype)

    def call(self, inputs):
        dtype = self.compute_dtype
        # Strict comparison: inputs exactly equal to theta are zeroed.
        return inputs * tf.cast(tf.greater(inputs, self.theta), dtype)

    def get_config(self):
        config = {"theta": float(self.theta)}
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        return input_shape
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/losses.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src.api_export import keras_export
2
+
3
+
4
@keras_export("keras._legacy.losses.Reduction")
class Reduction:
    """DEPRECATED. Legacy string constants naming loss-reduction modes."""

    AUTO = "auto"
    NONE = "none"
    SUM = "sum"
    SUM_OVER_BATCH_SIZE = "sum_over_batch_size"

    @classmethod
    def all(cls):
        """Return every valid reduction key as a tuple."""
        return (cls.AUTO, cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)

    @classmethod
    def validate(cls, key):
        """Raise `ValueError` unless `key` is one of the known reductions."""
        valid = cls.all()
        if key not in valid:
            raise ValueError(
                f'Invalid Reduction Key: {key}. Expected keys are "{valid}"'
            )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/image.py ADDED
@@ -0,0 +1,1892 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Deprecated image preprocessing APIs from Keras 1."""
2
+
3
+ import collections
4
+ import multiprocessing
5
+ import os
6
+ import threading
7
+ import warnings
8
+
9
+ import numpy as np
10
+
11
+ from keras.src import backend
12
+ from keras.src.api_export import keras_export
13
+ from keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset
14
+ from keras.src.utils import image_utils
15
+ from keras.src.utils import io_utils
16
+ from keras.src.utils.module_utils import scipy
17
+
18
+
19
@keras_export("keras._legacy.preprocessing.image.Iterator")
class Iterator(PyDataset):
    """Base class for image data iterators.

    DEPRECATED.

    Every `Iterator` must implement the `_get_batches_of_transformed_samples`
    method.

    Args:
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
    """

    # File extensions accepted by the directory-scanning subclasses.
    white_list_formats = ("png", "jpg", "jpeg", "bmp", "ppm", "tif", "tiff")

    def __init__(self, n, batch_size, shuffle, seed):
        self.n = n
        self.batch_size = batch_size
        self.seed = seed
        self.shuffle = shuffle
        self.batch_index = 0
        self.total_batches_seen = 0
        # Lock guards the shared index generator in `__next__`; the actual
        # image transformation is done outside the lock.
        self.lock = threading.Lock()
        self.index_array = None
        self.index_generator = self._flow_index()

    def _set_index_array(self):
        # Rebuild the epoch's sample ordering; a fresh permutation when
        # shuffling, otherwise the identity ordering.
        self.index_array = np.arange(self.n)
        if self.shuffle:
            self.index_array = np.random.permutation(self.n)

    def __getitem__(self, idx):
        if idx >= len(self):
            raise ValueError(
                "Asked to retrieve element {idx}, "
                "but the Sequence "
                "has length {length}".format(idx=idx, length=len(self))
            )
        # Re-seed NumPy's global RNG per batch so that results are
        # reproducible for a fixed `seed`, regardless of access order.
        if self.seed is not None:
            np.random.seed(self.seed + self.total_batches_seen)
        self.total_batches_seen += 1
        if self.index_array is None:
            self._set_index_array()
        index_array = self.index_array[
            self.batch_size * idx : self.batch_size * (idx + 1)
        ]
        return self._get_batches_of_transformed_samples(index_array)

    def __len__(self):
        return (self.n + self.batch_size - 1) // self.batch_size  # round up

    def on_epoch_end(self):
        self._set_index_array()

    def reset(self):
        self.batch_index = 0

    def _flow_index(self):
        """Infinite generator yielding index slices for sequential iteration."""
        # Ensure self.batch_index is 0.
        self.reset()
        while 1:
            if self.seed is not None:
                np.random.seed(self.seed + self.total_batches_seen)
            if self.batch_index == 0:
                self._set_index_array()

            if self.n == 0:
                # Avoiding modulo by zero error
                current_index = 0
            else:
                current_index = (self.batch_index * self.batch_size) % self.n
            # Advance within the epoch, or wrap back to 0 when the next
            # batch would run past the end (final batch may be short).
            if self.n > current_index + self.batch_size:
                self.batch_index += 1
            else:
                self.batch_index = 0
            self.total_batches_seen += 1
            yield self.index_array[
                current_index : current_index + self.batch_size
            ]

    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        return self

    def __next__(self):
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)

    def _get_batches_of_transformed_samples(self, index_array):
        """Gets a batch of transformed samples.

        Args:
            index_array: Array of sample indices to include in batch.
        Returns:
            A batch of transformed samples.
        """
        raise NotImplementedError
123
+
124
+
125
def _iter_valid_files(directory, white_list_formats, follow_links):
    """Yield `(root, filename)` pairs with an allowed extension.

    Walks `directory` recursively (in sorted, deterministic order) and
    yields every file whose lowercased name ends with one of the
    extensions in `white_list_formats`.

    Args:
        directory: Absolute path to the directory containing files to list.
        white_list_formats: Tuple/set of allowed extension strings.
        follow_links: Boolean, follow symbolic links to subdirectories.
    Yields:
        Tuple of (root, filename) with extension in `white_list_formats`.
    """
    walk_entries = sorted(
        os.walk(directory, followlinks=follow_links),
        key=lambda entry: entry[0],
    )
    for dirpath, _, filenames in walk_entries:
        for filename in sorted(filenames):
            lowered = filename.lower()
            if lowered.endswith(".tiff"):
                warnings.warn(
                    'Using ".tiff" files with multiple bands '
                    "will cause distortion. Please verify your output."
                )
            if lowered.endswith(white_list_formats):
                yield dirpath, filename
152
+
153
+
154
def _list_valid_filenames_in_directory(
    directory, white_list_formats, split, class_indices, follow_links
):
    """Lists paths of files in `subdir` with extensions in `white_list_formats`.

    Args:
        directory: absolute path to a directory containing the files to list.
            The directory name is used as class label
            and must be a key of `class_indices`.
        white_list_formats: set of strings containing allowed extensions for
            the files to be counted.
        split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into
            account a certain fraction of files in each directory.
            E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent
            of images in each directory.
        class_indices: dictionary mapping a class name to its index.
        follow_links: boolean, follow symbolic links to subdirectories.

    Returns:
        classes: a list of class indices
        filenames: the path of valid files in `directory`, relative from
            `directory`'s parent (e.g., if `directory` is "dataset/class1",
            the filenames will be
            `["class1/file1.jpg", "class1/file2.jpg", ...]`).
    """
    dirname = os.path.basename(directory)
    # Lazily iterate unless a split is requested, in which case the full
    # listing is materialized so a deterministic slice can be taken.
    valid_files = _iter_valid_files(directory, white_list_formats, follow_links)
    if split:
        all_files = list(valid_files)
        total = len(all_files)
        valid_files = all_files[int(split[0] * total) : int(split[1] * total)]

    labels = []
    relative_names = []
    for root, fname in valid_files:
        # Class label lookup stays inside the loop: an empty directory
        # never touches `class_indices`.
        labels.append(class_indices[dirname])
        full_path = os.path.join(root, fname)
        relative_names.append(
            os.path.join(dirname, os.path.relpath(full_path, directory))
        )

    return labels, relative_names
202
+
203
+
204
class BatchFromFilesMixin:
    """Adds methods related to getting batches from filenames.

    It includes the logic to transform image files to batches.

    NOTE(review): concrete subclasses are expected to provide `filepaths`,
    `labels`, `sample_weight` (overridden below), plus `dtype`, `class_mode`,
    `classes` and `class_indices` attributes used by
    `_get_batches_of_transformed_samples`.
    """

    def set_processing_attrs(
        self,
        image_data_generator,
        target_size,
        color_mode,
        data_format,
        save_to_dir,
        save_prefix,
        save_format,
        subset,
        interpolation,
        keep_aspect_ratio,
    ):
        """Sets attributes to use later for processing files into a batch.

        Args:
            image_data_generator: Instance of `ImageDataGenerator`
                to use for random transformations and normalization.
            target_size: tuple of integers, dimensions to resize input images
                to.
            color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
                Color mode to read images.
            data_format: String, one of `channels_first`, `channels_last`.
            save_to_dir: Optional directory where to save the pictures
                being yielded, in a viewable format. This is useful
                for visualizing the random transformations being
                applied, for debugging purposes.
            save_prefix: String prefix to use for saving sample
                images (if `save_to_dir` is set).
            save_format: Format to use for saving sample images
                (if `save_to_dir` is set).
            subset: Subset of data (`"training"` or `"validation"`) if
                validation_split is set in ImageDataGenerator.
            interpolation: Interpolation method used to resample the image if
                the target size is different from that of the loaded image.
                Supported methods are "nearest", "bilinear", and "bicubic". If
                PIL version 1.1.3 or newer is installed, "lanczos" is also
                supported. If PIL version 3.4.0 or newer is installed, "box" and
                "hamming" are also supported. By default, "nearest" is used.
            keep_aspect_ratio: Boolean, whether to resize images to a target
                size without aspect ratio distortion. The image is cropped in
                the center with target aspect ratio before resizing.
        """
        self.image_data_generator = image_data_generator
        self.target_size = tuple(target_size)
        self.keep_aspect_ratio = keep_aspect_ratio
        if color_mode not in {"rgb", "rgba", "grayscale"}:
            raise ValueError(
                f"Invalid color mode: {color_mode}"
                '; expected "rgb", "rgba", or "grayscale".'
            )
        self.color_mode = color_mode
        self.data_format = data_format
        # Derive the full per-sample array shape from the color mode
        # (channel count) and the data format (channel position).
        if self.color_mode == "rgba":
            if self.data_format == "channels_last":
                self.image_shape = self.target_size + (4,)
            else:
                self.image_shape = (4,) + self.target_size
        elif self.color_mode == "rgb":
            if self.data_format == "channels_last":
                self.image_shape = self.target_size + (3,)
            else:
                self.image_shape = (3,) + self.target_size
        else:
            if self.data_format == "channels_last":
                self.image_shape = self.target_size + (1,)
            else:
                self.image_shape = (1,) + self.target_size
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        self.interpolation = interpolation
        # Translate the subset name into a (start, stop) fraction pair used
        # by subclasses to slice the file listing.
        if subset is not None:
            validation_split = self.image_data_generator._validation_split
            if subset == "validation":
                split = (0, validation_split)
            elif subset == "training":
                split = (validation_split, 1)
            else:
                raise ValueError(
                    f"Invalid subset name: {subset};"
                    'expected "training" or "validation"'
                )
        else:
            split = None
        self.split = split
        self.subset = subset

    def _get_batches_of_transformed_samples(self, index_array):
        """Gets a batch of transformed samples.

        Args:
            index_array: Array of sample indices to include in batch.
        Returns:
            A batch of transformed samples.
        """
        batch_x = np.zeros(
            (len(index_array),) + self.image_shape, dtype=self.dtype
        )
        # build batch of image data
        # self.filepaths is dynamic, is better to call it once outside the loop
        filepaths = self.filepaths
        for i, j in enumerate(index_array):
            img = image_utils.load_img(
                filepaths[j],
                color_mode=self.color_mode,
                target_size=self.target_size,
                interpolation=self.interpolation,
                keep_aspect_ratio=self.keep_aspect_ratio,
            )
            x = image_utils.img_to_array(img, data_format=self.data_format)
            # Pillow images should be closed after `load_img`,
            # but not PIL images.
            if hasattr(img, "close"):
                img.close()
            if self.image_data_generator:
                params = self.image_data_generator.get_random_transform(x.shape)
                x = self.image_data_generator.apply_transform(x, params)
                x = self.image_data_generator.standardize(x)
            batch_x[i] = x
        # optionally save augmented images to disk for debugging purposes
        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = image_utils.array_to_img(
                    batch_x[i], self.data_format, scale=True
                )
                # Random hash suffix avoids overwriting files when the same
                # sample index is saved more than once.
                fname = "{prefix}_{index}_{hash}.{format}".format(
                    prefix=self.save_prefix,
                    index=j,
                    hash=np.random.randint(1e7),
                    format=self.save_format,
                )
                img.save(os.path.join(self.save_to_dir, fname))
        # build batch of labels
        if self.class_mode == "input":
            batch_y = batch_x.copy()
        elif self.class_mode in {"binary", "sparse"}:
            batch_y = np.empty(len(batch_x), dtype=self.dtype)
            for i, n_observation in enumerate(index_array):
                batch_y[i] = self.classes[n_observation]
        elif self.class_mode == "categorical":
            # One-hot encode class indices.
            batch_y = np.zeros(
                (len(batch_x), len(self.class_indices)), dtype=self.dtype
            )
            for i, n_observation in enumerate(index_array):
                batch_y[i, self.classes[n_observation]] = 1.0
        elif self.class_mode == "multi_output":
            batch_y = [output[index_array] for output in self.labels]
        elif self.class_mode == "raw":
            batch_y = self.labels[index_array]
        else:
            # class_mode is None: yield inputs only.
            return batch_x
        if self.sample_weight is None:
            return batch_x, batch_y
        else:
            return batch_x, batch_y, self.sample_weight[index_array]

    @property
    def filepaths(self):
        """List of absolute paths to image files."""
        raise NotImplementedError(
            "`filepaths` property method has not "
            "been implemented in {}.".format(type(self).__name__)
        )

    @property
    def labels(self):
        """Class labels of every observation."""
        raise NotImplementedError(
            "`labels` property method has not been implemented in {}.".format(
                type(self).__name__
            )
        )

    @property
    def sample_weight(self):
        # Optional per-sample weights; subclasses may return None.
        raise NotImplementedError(
            "`sample_weight` property method has not "
            "been implemented in {}.".format(type(self).__name__)
        )
390
+
391
+
392
@keras_export("keras._legacy.preprocessing.image.DirectoryIterator")
class DirectoryIterator(BatchFromFilesMixin, Iterator):
    """Iterator capable of reading images from a directory on disk.

    DEPRECATED.
    """

    allowed_class_modes = {"categorical", "binary", "sparse", "input", None}

    def __init__(
        self,
        directory,
        image_data_generator,
        target_size=(256, 256),
        color_mode="rgb",
        classes=None,
        class_mode="categorical",
        batch_size=32,
        shuffle=True,
        seed=None,
        data_format=None,
        save_to_dir=None,
        save_prefix="",
        save_format="png",
        follow_links=False,
        subset=None,
        interpolation="nearest",
        keep_aspect_ratio=False,
        dtype=None,
    ):
        if data_format is None:
            data_format = backend.image_data_format()
        if dtype is None:
            dtype = backend.floatx()
        super().set_processing_attrs(
            image_data_generator,
            target_size,
            color_mode,
            data_format,
            save_to_dir,
            save_prefix,
            save_format,
            subset,
            interpolation,
            keep_aspect_ratio,
        )
        self.directory = directory
        self.classes = classes
        if class_mode not in self.allowed_class_modes:
            raise ValueError(
                "Invalid class_mode: {}; expected one of: {}".format(
                    class_mode, self.allowed_class_modes
                )
            )
        self.class_mode = class_mode
        self.dtype = dtype
        # First, count the number of samples and classes.
        self.samples = 0

        # If no explicit class list is given, each subdirectory of
        # `directory` (in sorted order) becomes one class.
        if not classes:
            classes = []
            for subdir in sorted(os.listdir(directory)):
                if os.path.isdir(os.path.join(directory, subdir)):
                    classes.append(subdir)
        self.num_classes = len(classes)
        self.class_indices = dict(zip(classes, range(len(classes))))

        # Thread pool: directory scanning is I/O-bound, so the per-class
        # listing is done concurrently.
        pool = multiprocessing.pool.ThreadPool()

        # Second, build an index of the images
        # in the different class subfolders.
        results = []
        self.filenames = []
        i = 0
        for dirpath in (os.path.join(directory, subdir) for subdir in classes):
            results.append(
                pool.apply_async(
                    _list_valid_filenames_in_directory,
                    (
                        dirpath,
                        self.white_list_formats,
                        self.split,
                        self.class_indices,
                        follow_links,
                    ),
                )
            )
        # Collect results in submission order so filenames and class labels
        # stay aligned.
        classes_list = []
        for res in results:
            classes, filenames = res.get()
            classes_list.append(classes)
            self.filenames += filenames
        self.samples = len(self.filenames)
        self.classes = np.zeros((self.samples,), dtype="int32")
        for classes in classes_list:
            self.classes[i : i + len(classes)] = classes
            i += len(classes)

        io_utils.print_msg(
            f"Found {self.samples} images belonging to "
            f"{self.num_classes} classes."
        )
        pool.close()
        pool.join()
        self._filepaths = [
            os.path.join(self.directory, fname) for fname in self.filenames
        ]
        super().__init__(self.samples, batch_size, shuffle, seed)

    @property
    def filepaths(self):
        return self._filepaths

    @property
    def labels(self):
        return self.classes

    @property  # mixin needs this property to work
    def sample_weight(self):
        # no sample weights will be returned
        return None
513
+
514
+
515
@keras_export("keras._legacy.preprocessing.image.NumpyArrayIterator")
class NumpyArrayIterator(Iterator):
    """Iterator yielding data from a Numpy array.

    DEPRECATED.
    """

    def __init__(
        self,
        x,
        y,
        image_data_generator,
        batch_size=32,
        shuffle=False,
        sample_weight=None,
        seed=None,
        data_format=None,
        save_to_dir=None,
        save_prefix="",
        save_format="png",
        subset=None,
        ignore_class_split=False,
        dtype=None,
    ):
        if data_format is None:
            data_format = backend.image_data_format()
        if dtype is None:
            dtype = backend.floatx()
        self.dtype = dtype
        # `x` may be a (images, extra_inputs) pair; the extras are kept as a
        # list of arrays in `x_misc` and yielded alongside the images.
        if isinstance(x, tuple) or isinstance(x, list):
            if not isinstance(x[1], list):
                x_misc = [np.asarray(x[1])]
            else:
                x_misc = [np.asarray(xx) for xx in x[1]]
            x = x[0]
            for xx in x_misc:
                if len(x) != len(xx):
                    raise ValueError(
                        "All of the arrays in `x` "
                        "should have the same length. "
                        "Found a pair with: "
                        f"len(x[0]) = {len(x)}, len(x[?]) = {len(xx)}"
                    )
        else:
            x_misc = []

        if y is not None and len(x) != len(y):
            raise ValueError(
                "`x` (images tensor) and `y` (labels) "
                "should have the same length. "
                f"Found: x.shape = {np.asarray(x).shape}, "
                f"y.shape = {np.asarray(y).shape}"
            )
        if sample_weight is not None and len(x) != len(sample_weight):
            raise ValueError(
                "`x` (images tensor) and `sample_weight` "
                "should have the same length. "
                f"Found: x.shape = {np.asarray(x).shape}, "
                f"sample_weight.shape = {np.asarray(sample_weight).shape}"
            )
        if subset is not None:
            if subset not in {"training", "validation"}:
                raise ValueError(
                    f"Invalid subset name: {subset}"
                    '; expected "training" or "validation".'
                )
            # Leading fraction is validation, trailing fraction is training.
            split_idx = int(len(x) * image_data_generator._validation_split)

            # Guard against a split where train/validation end up with
            # different label sets (typical when data is sorted by label).
            if (
                y is not None
                and not ignore_class_split
                and not np.array_equal(
                    np.unique(y[:split_idx]), np.unique(y[split_idx:])
                )
            ):
                raise ValueError(
                    "Training and validation subsets "
                    "have different number of classes after "
                    "the split. If your numpy arrays are "
                    "sorted by the label, you might want "
                    "to shuffle them."
                )

            if subset == "validation":
                x = x[:split_idx]
                x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc]
                if y is not None:
                    y = y[:split_idx]
            else:
                x = x[split_idx:]
                x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc]
                if y is not None:
                    y = y[split_idx:]

        self.x = np.asarray(x, dtype=self.dtype)
        self.x_misc = x_misc
        if self.x.ndim != 4:
            raise ValueError(
                "Input data in `NumpyArrayIterator` "
                "should have rank 4. You passed an array "
                f"with shape {self.x.shape}"
            )
        channels_axis = 3 if data_format == "channels_last" else 1
        if self.x.shape[channels_axis] not in {1, 3, 4}:
            warnings.warn(
                'NumpyArrayIterator is set to use the data format convention "'
                + data_format
                + '" (channels on axis '
                + str(channels_axis)
                + "), i.e. expected either 1, 3, or 4 channels on axis "
                + str(channels_axis)
                + ". However, it was passed an array with shape "
                + str(self.x.shape)
                + " ("
                + str(self.x.shape[channels_axis])
                + " channels)."
            )
        if y is not None:
            self.y = np.asarray(y)
        else:
            self.y = None
        if sample_weight is not None:
            self.sample_weight = np.asarray(sample_weight)
        else:
            self.sample_weight = None
        self.image_data_generator = image_data_generator
        self.data_format = data_format
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        # NOTE(review): uses `x.shape[0]` on the (possibly list) input rather
        # than `self.x.shape[0]` — assumes `x` is array-like here; confirm.
        super().__init__(x.shape[0], batch_size, shuffle, seed)

    def _get_batches_of_transformed_samples(self, index_array):
        # Apply per-sample random augmentation and standardization.
        batch_x = np.zeros(
            tuple([len(index_array)] + list(self.x.shape)[1:]), dtype=self.dtype
        )
        for i, j in enumerate(index_array):
            x = self.x[j]
            params = self.image_data_generator.get_random_transform(x.shape)
            x = self.image_data_generator.apply_transform(
                x.astype(self.dtype), params
            )
            x = self.image_data_generator.standardize(x)
            batch_x[i] = x

        # Optionally dump augmented images for visual debugging.
        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = image_utils.array_to_img(
                    batch_x[i], self.data_format, scale=True
                )
                fname = "{prefix}_{index}_{hash}.{format}".format(
                    prefix=self.save_prefix,
                    index=j,
                    hash=np.random.randint(1e4),
                    format=self.save_format,
                )
                img.save(os.path.join(self.save_to_dir, fname))
        batch_x_miscs = [xx[index_array] for xx in self.x_misc]
        output = (batch_x if not batch_x_miscs else [batch_x] + batch_x_miscs,)
        if self.y is None:
            return output[0]
        output += (self.y[index_array],)
        if self.sample_weight is not None:
            output += (self.sample_weight[index_array],)
        return output
680
+
681
+
682
def validate_filename(filename, white_list_formats):
    """Check if a filename refers to a valid file.

    Args:
        filename: String, absolute path to a file
        white_list_formats: Set, allowed file extensions
    Returns:
        A boolean value indicating if the filename is valid or not
    """
    has_allowed_extension = filename.lower().endswith(white_list_formats)
    if not has_allowed_extension:
        return False
    return os.path.isfile(filename)
694
+
695
+
696
class DataFrameIterator(BatchFromFilesMixin, Iterator):
    """Iterator capable of reading images from a directory as a dataframe.

    NOTE(review): `dataframe` is presumably a pandas DataFrame (uses
    `.copy()`, `.iloc`, column access) — confirm against callers.
    """

    allowed_class_modes = {
        "binary",
        "categorical",
        "input",
        "multi_output",
        "raw",
        "sparse",
        None,
    }

    def __init__(
        self,
        dataframe,
        directory=None,
        image_data_generator=None,
        x_col="filename",
        y_col="class",
        weight_col=None,
        target_size=(256, 256),
        color_mode="rgb",
        classes=None,
        class_mode="categorical",
        batch_size=32,
        shuffle=True,
        seed=None,
        data_format="channels_last",
        save_to_dir=None,
        save_prefix="",
        save_format="png",
        subset=None,
        interpolation="nearest",
        keep_aspect_ratio=False,
        dtype="float32",
        validate_filenames=True,
    ):
        super().set_processing_attrs(
            image_data_generator,
            target_size,
            color_mode,
            data_format,
            save_to_dir,
            save_prefix,
            save_format,
            subset,
            interpolation,
            keep_aspect_ratio,
        )
        # Work on a copy so the caller's dataframe is never mutated.
        df = dataframe.copy()
        self.directory = directory or ""
        self.class_mode = class_mode
        self.dtype = dtype
        # check that inputs match the required class_mode
        self._check_params(df, x_col, y_col, weight_col, classes)
        if (
            validate_filenames
        ):  # check which image files are valid and keep them
            df = self._filter_valid_filepaths(df, x_col)
        if class_mode not in ["input", "multi_output", "raw", None]:
            df, classes = self._filter_classes(df, y_col, classes)
            num_classes = len(classes)
            # build an index of all the unique classes
            self.class_indices = dict(zip(classes, range(len(classes))))
        # retrieve only training or validation set
        if self.split:
            num_files = len(df)
            start = int(self.split[0] * num_files)
            stop = int(self.split[1] * num_files)
            df = df.iloc[start:stop, :]
        # get labels for each observation
        if class_mode not in ["input", "multi_output", "raw", None]:
            self.classes = self.get_classes(df, y_col)
        self.filenames = df[x_col].tolist()
        self._sample_weight = df[weight_col].values if weight_col else None

        if class_mode == "multi_output":
            self._targets = [np.array(df[col].tolist()) for col in y_col]
        if class_mode == "raw":
            self._targets = df[y_col].values
        self.samples = len(self.filenames)
        validated_string = (
            "validated" if validate_filenames else "non-validated"
        )
        if class_mode in ["input", "multi_output", "raw", None]:
            io_utils.print_msg(
                f"Found {self.samples} {validated_string} image filenames."
            )
        else:
            io_utils.print_msg(
                f"Found {self.samples} {validated_string} image filenames "
                f"belonging to {num_classes} classes."
            )
        self._filepaths = [
            os.path.join(self.directory, fname) for fname in self.filenames
        ]
        super().__init__(self.samples, batch_size, shuffle, seed)

    def _check_params(self, df, x_col, y_col, weight_col, classes):
        """Validate dataframe columns and options against `class_mode`."""
        # check class mode is one of the currently supported
        if self.class_mode not in self.allowed_class_modes:
            raise ValueError(
                "Invalid class_mode: {}; expected one of: {}".format(
                    self.class_mode, self.allowed_class_modes
                )
            )
        # check that y_col has several column names if class_mode is
        # multi_output
        if (self.class_mode == "multi_output") and not isinstance(y_col, list):
            raise TypeError(
                'If class_mode="{}", y_col must be a list. Received {}.'.format(
                    self.class_mode, type(y_col).__name__
                )
            )
        # check that filenames/filepaths column values are all strings
        if not all(df[x_col].apply(lambda x: isinstance(x, str))):
            raise TypeError(
                f"All values in column x_col={x_col} must be strings."
            )
        # check labels are string if class_mode is binary or sparse
        if self.class_mode in {"binary", "sparse"}:
            if not all(df[y_col].apply(lambda x: isinstance(x, str))):
                raise TypeError(
                    'If class_mode="{}", y_col="{}" column '
                    "values must be strings.".format(self.class_mode, y_col)
                )
        # check that if binary there are only 2 different classes
        if self.class_mode == "binary":
            if classes:
                classes = set(classes)
                if len(classes) != 2:
                    raise ValueError(
                        'If class_mode="binary" there must be 2 '
                        "classes. {} class/es were given.".format(len(classes))
                    )
            elif df[y_col].nunique() != 2:
                raise ValueError(
                    'If class_mode="binary" there must be 2 classes. '
                    "Found {} classes.".format(df[y_col].nunique())
                )
        # check values are string, list or tuple if class_mode is categorical
        if self.class_mode == "categorical":
            types = (str, list, tuple)
            if not all(df[y_col].apply(lambda x: isinstance(x, types))):
                raise TypeError(
                    'If class_mode="{}", y_col="{}" column '
                    "values must be type string, list or tuple.".format(
                        self.class_mode, y_col
                    )
                )
        # raise warning if classes are given but will be unused
        if classes and self.class_mode in {
            "input",
            "multi_output",
            "raw",
            None,
        }:
            warnings.warn(
                '`classes` will be ignored given the class_mode="{}"'.format(
                    self.class_mode
                )
            )
        # check that if weight column that the values are numerical
        if weight_col and not issubclass(df[weight_col].dtype.type, np.number):
            raise TypeError(f"Column weight_col={weight_col} must be numeric.")

    def get_classes(self, df, y_col):
        """Map label values (or lists of them) to integer class indices."""
        labels = []
        for label in df[y_col]:
            if isinstance(label, (list, tuple)):
                labels.append([self.class_indices[lbl] for lbl in label])
            else:
                labels.append(self.class_indices[label])
        return labels

    @staticmethod
    def _filter_classes(df, y_col, classes):
        """Restrict `df` to rows whose labels are in `classes`.

        If `classes` is falsy, the class list is instead inferred (sorted)
        from the unique label values found in `y_col`.
        Returns the filtered dataframe and the final class list.
        """
        df = df.copy()

        def remove_classes(labels, classes):
            if isinstance(labels, (list, tuple)):
                labels = [cls for cls in labels if cls in classes]
                # None marks a row for removal via dropna below.
                return labels or None
            elif isinstance(labels, str):
                return labels if labels in classes else None
            else:
                raise TypeError(
                    "Expect string, list or tuple "
                    "but found {} in {} column ".format(type(labels), y_col)
                )

        if classes:
            # prepare for membership lookup
            classes = list(collections.OrderedDict.fromkeys(classes).keys())
            df[y_col] = df[y_col].apply(lambda x: remove_classes(x, classes))
        else:
            classes = set()
            for v in df[y_col]:
                if isinstance(v, (list, tuple)):
                    classes.update(v)
                else:
                    classes.add(v)
            classes = sorted(classes)
        return df.dropna(subset=[y_col]), classes

    def _filter_valid_filepaths(self, df, x_col):
        """Keep only dataframe rows with valid filenames.

        Args:
            df: Pandas dataframe containing filenames in a column
            x_col: string, column in `df` that contains the filenames or
                filepaths
        Returns:
            absolute paths to image files
        """
        filepaths = df[x_col].map(
            lambda fname: os.path.join(self.directory, fname)
        )
        mask = filepaths.apply(
            validate_filename, args=(self.white_list_formats,)
        )
        n_invalid = (~mask).sum()
        if n_invalid:
            warnings.warn(
                'Found {} invalid image filename(s) in x_col="{}". '
                "These filename(s) will be ignored.".format(n_invalid, x_col)
            )
        return df[mask]

    @property
    def filepaths(self):
        return self._filepaths

    @property
    def labels(self):
        if self.class_mode in {"multi_output", "raw"}:
            return self._targets
        else:
            return self.classes

    @property
    def sample_weight(self):
        return self._sample_weight
940
+
941
+
942
def flip_axis(x, axis):
    """Reverse the order of elements of `x` along the given axis."""
    moved = np.asarray(x).swapaxes(axis, 0)
    flipped = moved[::-1, ...]
    return flipped.swapaxes(0, axis)
947
+
948
+
949
+ @keras_export("keras._legacy.preprocessing.image.ImageDataGenerator")
950
class ImageDataGenerator:
    """Generate batches of tensor image data with real-time augmentation.

    DEPRECATED. Legacy Keras 2 augmentation pipeline, kept only for
    backwards compatibility; prefer the `keras.layers` preprocessing
    layers for new code.
    """

    def __init__(
        self,
        featurewise_center=False,
        samplewise_center=False,
        featurewise_std_normalization=False,
        samplewise_std_normalization=False,
        zca_whitening=False,
        zca_epsilon=1e-6,
        rotation_range=0,
        width_shift_range=0.0,
        height_shift_range=0.0,
        brightness_range=None,
        shear_range=0.0,
        zoom_range=0.0,
        channel_shift_range=0.0,
        fill_mode="nearest",
        cval=0.0,
        horizontal_flip=False,
        vertical_flip=False,
        rescale=None,
        preprocessing_function=None,
        data_format=None,
        validation_split=0.0,
        interpolation_order=1,
        dtype=None,
    ):
        # Fall back to the backend's configured defaults only when the
        # caller did not specify a format/dtype explicitly.
        if data_format is None:
            data_format = backend.image_data_format()
        if dtype is None:
            dtype = backend.floatx()

        self.featurewise_center = featurewise_center
        self.samplewise_center = samplewise_center
        self.featurewise_std_normalization = featurewise_std_normalization
        self.samplewise_std_normalization = samplewise_std_normalization
        self.zca_whitening = zca_whitening
        self.zca_epsilon = zca_epsilon
        self.rotation_range = rotation_range
        self.width_shift_range = width_shift_range
        self.height_shift_range = height_shift_range
        self.shear_range = shear_range
        self.zoom_range = zoom_range
        self.channel_shift_range = channel_shift_range
        self.fill_mode = fill_mode
        self.cval = cval
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip
        self.rescale = rescale
        self.preprocessing_function = preprocessing_function
        self.dtype = dtype
        self.interpolation_order = interpolation_order

        if data_format not in {"channels_last", "channels_first"}:
            raise ValueError(
                '`data_format` should be `"channels_last"` '
                "(channel after row and column) or "
                '`"channels_first"` (channel before row and column). '
                f"Received: {data_format}"
            )
        self.data_format = data_format
        # Axis indices below refer to a 4D batch
        # (sample, row, col, channel) / (sample, channel, row, col).
        if data_format == "channels_first":
            self.channel_axis = 1
            self.row_axis = 2
            self.col_axis = 3
        if data_format == "channels_last":
            self.channel_axis = 3
            self.row_axis = 1
            self.col_axis = 2
        if validation_split and not 0 < validation_split < 1:
            raise ValueError(
                "`validation_split` must be strictly between 0 and 1. "
                f" Received: {validation_split}"
            )
        self._validation_split = validation_split

        # Statistics populated later by `fit()`.
        self.mean = None
        self.std = None
        self.zca_whitening_matrix = None

        if isinstance(zoom_range, (float, int)):
            # A scalar means a symmetric range around 1.
            self.zoom_range = [1 - zoom_range, 1 + zoom_range]
        elif len(zoom_range) == 2 and all(
            isinstance(val, (float, int)) for val in zoom_range
        ):
            self.zoom_range = [zoom_range[0], zoom_range[1]]
        else:
            raise ValueError(
                "`zoom_range` should be a float or "
                "a tuple or list of two floats. "
                f"Received: {zoom_range}"
            )
        # Enforce option interdependencies, warning when a setting is
        # implicitly overridden.
        if zca_whitening:
            if not featurewise_center:
                self.featurewise_center = True
                warnings.warn(
                    "This ImageDataGenerator specifies "
                    "`zca_whitening`, which overrides "
                    "setting of `featurewise_center`."
                )
            if featurewise_std_normalization:
                self.featurewise_std_normalization = False
                warnings.warn(
                    "This ImageDataGenerator specifies "
                    "`zca_whitening` "
                    "which overrides setting of "
                    "`featurewise_std_normalization`."
                )
        if featurewise_std_normalization:
            if not featurewise_center:
                self.featurewise_center = True
                warnings.warn(
                    "This ImageDataGenerator specifies "
                    "`featurewise_std_normalization`, "
                    "which overrides setting of "
                    "`featurewise_center`."
                )
        if samplewise_std_normalization:
            if not samplewise_center:
                self.samplewise_center = True
                warnings.warn(
                    "This ImageDataGenerator specifies "
                    "`samplewise_std_normalization`, "
                    "which overrides setting of "
                    "`samplewise_center`."
                )
        if brightness_range is not None:
            if (
                not isinstance(brightness_range, (tuple, list))
                or len(brightness_range) != 2
            ):
                raise ValueError(
                    "`brightness_range` should be tuple or list of two "
                    f"floats. Received: {brightness_range}"
                )
        self.brightness_range = brightness_range

    def flow(
        self,
        x,
        y=None,
        batch_size=32,
        shuffle=True,
        sample_weight=None,
        seed=None,
        save_to_dir=None,
        save_prefix="",
        save_format="png",
        ignore_class_split=False,
        subset=None,
    ):
        """Yield batches of augmented data from in-memory arrays.

        Thin wrapper that forwards all arguments plus this generator's
        data format and dtype to `NumpyArrayIterator`.
        """
        return NumpyArrayIterator(
            x,
            y,
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            sample_weight=sample_weight,
            seed=seed,
            data_format=self.data_format,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format,
            ignore_class_split=ignore_class_split,
            subset=subset,
            dtype=self.dtype,
        )

    def flow_from_directory(
        self,
        directory,
        target_size=(256, 256),
        color_mode="rgb",
        classes=None,
        class_mode="categorical",
        batch_size=32,
        shuffle=True,
        seed=None,
        save_to_dir=None,
        save_prefix="",
        save_format="png",
        follow_links=False,
        subset=None,
        interpolation="nearest",
        keep_aspect_ratio=False,
    ):
        """Yield batches of augmented images read from `directory`.

        Thin wrapper that forwards all arguments plus this generator's
        data format and dtype to `DirectoryIterator`.
        """
        return DirectoryIterator(
            directory,
            self,
            target_size=target_size,
            color_mode=color_mode,
            keep_aspect_ratio=keep_aspect_ratio,
            classes=classes,
            class_mode=class_mode,
            data_format=self.data_format,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format,
            follow_links=follow_links,
            subset=subset,
            interpolation=interpolation,
            dtype=self.dtype,
        )

    def flow_from_dataframe(
        self,
        dataframe,
        directory=None,
        x_col="filename",
        y_col="class",
        weight_col=None,
        target_size=(256, 256),
        color_mode="rgb",
        classes=None,
        class_mode="categorical",
        batch_size=32,
        shuffle=True,
        seed=None,
        save_to_dir=None,
        save_prefix="",
        save_format="png",
        subset=None,
        interpolation="nearest",
        validate_filenames=True,
        **kwargs,
    ):
        """Yield batches of augmented images described by a dataframe.

        Accepts and warns about several long-deprecated keyword
        arguments (`has_ext`, `sort`, `drop_duplicates`,
        `class_mode="other"`), then delegates to `DataFrameIterator`.
        """
        if "has_ext" in kwargs:
            warnings.warn(
                "has_ext is deprecated, filenames in the dataframe have "
                "to match the exact filenames in disk.",
                DeprecationWarning,
            )
        if "sort" in kwargs:
            warnings.warn(
                "sort is deprecated, batches will be created in the "
                "same order than the filenames provided if `shuffle` "
                "is set to `False`.",
                DeprecationWarning,
            )
        if class_mode == "other":
            warnings.warn(
                '`class_mode="other"` is deprecated, please use '
                '`class_mode="raw"`.',
                DeprecationWarning,
            )
            class_mode = "raw"
        if "drop_duplicates" in kwargs:
            warnings.warn(
                "drop_duplicates is deprecated, you can drop duplicates "
                "by using the pandas.DataFrame.drop_duplicates method.",
                DeprecationWarning,
            )

        return DataFrameIterator(
            dataframe,
            directory,
            self,
            x_col=x_col,
            y_col=y_col,
            weight_col=weight_col,
            target_size=target_size,
            color_mode=color_mode,
            classes=classes,
            class_mode=class_mode,
            data_format=self.data_format,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format,
            subset=subset,
            interpolation=interpolation,
            validate_filenames=validate_filenames,
            dtype=self.dtype,
        )

    def standardize(self, x):
        """Applies the normalization configuration in-place to a batch of
        inputs.

        `x` is changed in-place since the function is mainly used internally
        to standardize images and feed them to your network. If a copy of `x`
        would be created instead it would have a significant performance cost.
        If you want to apply this method without changing the input in-place
        you can call the method creating a copy before:

        standardize(np.copy(x))

        Args:
            x: Batch of inputs to be normalized.

        Returns:
            The inputs, normalized.
        """
        if self.preprocessing_function:
            x = self.preprocessing_function(x)
        if self.rescale:
            x *= self.rescale
        if self.samplewise_center:
            x -= np.mean(x, keepdims=True)
        if self.samplewise_std_normalization:
            # Epsilon guards against division by zero on constant inputs.
            x /= np.std(x, keepdims=True) + 1e-6

        # Feature-wise transforms require statistics from a prior
        # `.fit()` call; warn instead of failing when they are missing.
        if self.featurewise_center:
            if self.mean is not None:
                x -= self.mean
            else:
                warnings.warn(
                    "This ImageDataGenerator specifies "
                    "`featurewise_center`, but it hasn't "
                    "been fit on any training data. Fit it "
                    "first by calling `.fit(numpy_data)`."
                )
        if self.featurewise_std_normalization:
            if self.std is not None:
                x /= self.std + 1e-6
            else:
                warnings.warn(
                    "This ImageDataGenerator specifies "
                    "`featurewise_std_normalization`, "
                    "but it hasn't "
                    "been fit on any training data. Fit it "
                    "first by calling `.fit(numpy_data)`."
                )
        if self.zca_whitening:
            if self.zca_whitening_matrix is not None:
                # Flatten each sample, whiten, then restore the shape.
                flat_x = x.reshape(-1, np.prod(x.shape[-3:]))
                white_x = flat_x @ self.zca_whitening_matrix
                x = np.reshape(white_x, x.shape)
            else:
                warnings.warn(
                    "This ImageDataGenerator specifies "
                    "`zca_whitening`, but it hasn't "
                    "been fit on any training data. Fit it "
                    "first by calling `.fit(numpy_data)`."
                )
        return x

    def get_random_transform(self, img_shape, seed=None):
        """Generates random parameters for a transformation.

        Args:
            img_shape: Tuple of integers.
                Shape of the image that is transformed.
            seed: Random seed.

        Returns:
            A dictionary containing randomly chosen parameters describing the
            transformation.
        """
        # Shift from batch axes (row_axis/col_axis) to single-image axes.
        img_row_axis = self.row_axis - 1
        img_col_axis = self.col_axis - 1

        if seed is not None:
            np.random.seed(seed)

        if self.rotation_range:
            theta = np.random.uniform(-self.rotation_range, self.rotation_range)
        else:
            theta = 0

        if self.height_shift_range:
            try:  # 1-D array-like or int
                tx = np.random.choice(self.height_shift_range)
                tx *= np.random.choice([-1, 1])
            except ValueError:  # floating point
                tx = np.random.uniform(
                    -self.height_shift_range, self.height_shift_range
                )
            # Values < 1 are interpreted as fractions of the image size.
            if np.max(self.height_shift_range) < 1:
                tx *= img_shape[img_row_axis]
        else:
            tx = 0

        if self.width_shift_range:
            try:  # 1-D array-like or int
                ty = np.random.choice(self.width_shift_range)
                ty *= np.random.choice([-1, 1])
            except ValueError:  # floating point
                ty = np.random.uniform(
                    -self.width_shift_range, self.width_shift_range
                )
            # Values < 1 are interpreted as fractions of the image size.
            if np.max(self.width_shift_range) < 1:
                ty *= img_shape[img_col_axis]
        else:
            ty = 0

        if self.shear_range:
            shear = np.random.uniform(-self.shear_range, self.shear_range)
        else:
            shear = 0

        if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:
            zx, zy = 1, 1
        else:
            zx, zy = np.random.uniform(
                self.zoom_range[0], self.zoom_range[1], 2
            )

        # Coin flips, masked by whether each flip kind is enabled.
        flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip
        flip_vertical = (np.random.random() < 0.5) * self.vertical_flip

        channel_shift_intensity = None
        if self.channel_shift_range != 0:
            channel_shift_intensity = np.random.uniform(
                -self.channel_shift_range, self.channel_shift_range
            )

        brightness = None
        if self.brightness_range is not None:
            brightness = np.random.uniform(
                self.brightness_range[0], self.brightness_range[1]
            )

        transform_parameters = {
            "theta": theta,
            "tx": tx,
            "ty": ty,
            "shear": shear,
            "zx": zx,
            "zy": zy,
            "flip_horizontal": flip_horizontal,
            "flip_vertical": flip_vertical,
            "channel_shift_intensity": channel_shift_intensity,
            "brightness": brightness,
        }

        return transform_parameters

    def apply_transform(self, x, transform_parameters):
        """Applies a transformation to an image according to given parameters.

        Args:
            x: 3D tensor, single image.
            transform_parameters: Dictionary with string - parameter pairs
                describing the transformation.
                Currently, the following parameters
                from the dictionary are used:
                - `'theta'`: Float. Rotation angle in degrees.
                - `'tx'`: Float. Shift in the x direction.
                - `'ty'`: Float. Shift in the y direction.
                - `'shear'`: Float. Shear angle in degrees.
                - `'zx'`: Float. Zoom in the x direction.
                - `'zy'`: Float. Zoom in the y direction.
                - `'flip_horizontal'`: Boolean. Horizontal flip.
                - `'flip_vertical'`: Boolean. Vertical flip.
                - `'channel_shift_intensity'`: Float. Channel shift intensity.
                - `'brightness'`: Float. Brightness shift intensity.

        Returns:
            A transformed version of the input (same shape).
        """
        # x is a single image, so it doesn't have image number at index 0
        img_row_axis = self.row_axis - 1
        img_col_axis = self.col_axis - 1
        img_channel_axis = self.channel_axis - 1

        x = apply_affine_transform(
            x,
            transform_parameters.get("theta", 0),
            transform_parameters.get("tx", 0),
            transform_parameters.get("ty", 0),
            transform_parameters.get("shear", 0),
            transform_parameters.get("zx", 1),
            transform_parameters.get("zy", 1),
            row_axis=img_row_axis,
            col_axis=img_col_axis,
            channel_axis=img_channel_axis,
            fill_mode=self.fill_mode,
            cval=self.cval,
            order=self.interpolation_order,
        )

        if transform_parameters.get("channel_shift_intensity") is not None:
            x = apply_channel_shift(
                x,
                transform_parameters["channel_shift_intensity"],
                img_channel_axis,
            )

        if transform_parameters.get("flip_horizontal", False):
            x = flip_axis(x, img_col_axis)

        if transform_parameters.get("flip_vertical", False):
            x = flip_axis(x, img_row_axis)

        if transform_parameters.get("brightness") is not None:
            x = apply_brightness_shift(
                x, transform_parameters["brightness"], False
            )

        return x

    def random_transform(self, x, seed=None):
        """Applies a random transformation to an image.

        Args:
            x: 3D tensor, single image.
            seed: Random seed.

        Returns:
            A randomly transformed version of the input (same shape).
        """
        params = self.get_random_transform(x.shape, seed)
        return self.apply_transform(x, params)

    def fit(self, x, augment=False, rounds=1, seed=None):
        """Fits the data generator to some sample data.

        This computes the internal data stats related to the
        data-dependent transformations, based on an array of sample data.

        Only required if `featurewise_center` or
        `featurewise_std_normalization` or `zca_whitening`
        are set to `True`.

        When `rescale` is set to a value, rescaling is applied to
        sample data before computing the internal data stats.

        Args:
            x: Sample data. Should have rank 4.
                In case of grayscale data,
                the channels axis should have value 1, in case
                of RGB data, it should have value 3, and in case
                of RGBA data, it should have value 4.
            augment: Boolean (default: False).
                Whether to fit on randomly augmented samples.
            rounds: Int (default: 1).
                If using data augmentation (`augment=True`),
                this is how many augmentation passes over the data to use.
            seed: Int (default: None). Random seed.
        """
        x = np.asarray(x, dtype=self.dtype)
        if x.ndim != 4:
            raise ValueError(
                "Input to `.fit()` should have rank 4. Got array with shape: "
                + str(x.shape)
            )
        if x.shape[self.channel_axis] not in {1, 3, 4}:
            warnings.warn(
                "Expected input to be images (as Numpy array) "
                'following the data format convention "'
                + self.data_format
                + '" (channels on axis '
                + str(self.channel_axis)
                + "), i.e. expected either 1, 3 or 4 channels on axis "
                + str(self.channel_axis)
                + ". However, it was passed an array with shape "
                + str(x.shape)
                + " ("
                + str(x.shape[self.channel_axis])
                + " channels)."
            )

        if seed is not None:
            np.random.seed(seed)

        # Work on a copy so the caller's array is left untouched.
        x = np.copy(x)
        if self.rescale:
            x *= self.rescale

        if augment:
            # Build `rounds` augmented copies of the whole dataset and
            # compute statistics on those instead of the raw samples.
            ax = np.zeros(
                tuple([rounds * x.shape[0]] + list(x.shape)[1:]),
                dtype=self.dtype,
            )
            for r in range(rounds):
                for i in range(x.shape[0]):
                    ax[i + r * x.shape[0]] = self.random_transform(x[i])
            x = ax

        if self.featurewise_center:
            self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis))
            broadcast_shape = [1, 1, 1]
            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
            self.mean = np.reshape(self.mean, broadcast_shape)
            x -= self.mean

        if self.featurewise_std_normalization:
            self.std = np.std(x, axis=(0, self.row_axis, self.col_axis))
            broadcast_shape = [1, 1, 1]
            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
            self.std = np.reshape(self.std, broadcast_shape)
            x /= self.std + 1e-6

        if self.zca_whitening:
            n = len(x)
            flat_x = np.reshape(x, (n, -1))

            # ZCA matrix: U * diag(1/s) * U^T from the SVD of the data.
            u, s, _ = np.linalg.svd(flat_x.T, full_matrices=False)
            s_inv = np.sqrt(n) / (s + self.zca_epsilon)
            self.zca_whitening_matrix = (u * s_inv).dot(u.T)
1548
+
1549
+
1550
+ @keras_export("keras._legacy.preprocessing.image.random_rotation")
1551
def random_rotation(
    x,
    rg,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """Randomly rotate a single image tensor. DEPRECATED.

    Draws an angle uniformly from `[-rg, rg]` degrees and delegates to
    `apply_affine_transform`.
    """
    angle = np.random.uniform(-rg, rg)
    return apply_affine_transform(
        x,
        theta=angle,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
1574
+
1575
+
1576
+ @keras_export("keras._legacy.preprocessing.image.random_shift")
1577
def random_shift(
    x,
    wrg,
    hrg,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """Randomly translate a single image tensor. DEPRECATED.

    `wrg`/`hrg` are width/height shift fractions; the drawn offsets are
    scaled by the image dimensions before being applied.
    """
    height, width = x.shape[row_axis], x.shape[col_axis]
    shift_rows = np.random.uniform(-hrg, hrg) * height
    shift_cols = np.random.uniform(-wrg, wrg) * width
    return apply_affine_transform(
        x,
        tx=shift_rows,
        ty=shift_cols,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
1604
+
1605
+
1606
+ @keras_export("keras._legacy.preprocessing.image.random_shear")
1607
def random_shear(
    x,
    intensity,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """Randomly shear a single image tensor. DEPRECATED.

    Draws a shear angle uniformly from `[-intensity, intensity]` degrees
    and delegates to `apply_affine_transform`.
    """
    angle = np.random.uniform(-intensity, intensity)
    return apply_affine_transform(
        x,
        shear=angle,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
1630
+
1631
+
1632
+ @keras_export("keras._legacy.preprocessing.image.random_zoom")
1633
def random_zoom(
    x,
    zoom_range,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    interpolation_order=1,
):
    """Randomly zoom a single image tensor. DEPRECATED.

    Draws independent x/y zoom factors from `zoom_range` and delegates
    to `apply_affine_transform`.
    """
    if len(zoom_range) != 2:
        raise ValueError(
            "`zoom_range` should be a tuple or list of two floats. "
            f"Received: {zoom_range}"
        )

    low, high = zoom_range[0], zoom_range[1]
    if low == 1 and high == 1:
        # Degenerate range: skip the random draw entirely.
        zx, zy = 1, 1
    else:
        zx, zy = np.random.uniform(low, high, 2)
    return apply_affine_transform(
        x,
        zx=zx,
        zy=zy,
        row_axis=row_axis,
        col_axis=col_axis,
        channel_axis=channel_axis,
        fill_mode=fill_mode,
        cval=cval,
        order=interpolation_order,
    )
1666
+
1667
+
1668
+ @keras_export("keras._legacy.preprocessing.image.apply_channel_shift")
1669
def apply_channel_shift(x, intensity, channel_axis=0):
    """Performs a channel shift.

    DEPRECATED.

    Args:
        x: Input tensor. Must be 3D.
        intensity: Transformation intensity.
        channel_axis: Index of axis for channels in the input tensor.

    Returns:
        Numpy image tensor.
    """
    # Each channel is shifted by the same scalar and clipped to the
    # ORIGINAL global [min, max] of the input, so the per-channel axis
    # shuffling of the classic implementation is unnecessary.
    lo, hi = np.min(x), np.max(x)
    return np.clip(x + intensity, lo, hi)
1690
+
1691
+
1692
+ @keras_export("keras._legacy.preprocessing.image.random_channel_shift")
1693
def random_channel_shift(x, intensity_range, channel_axis=0):
    """Performs a random channel shift.

    DEPRECATED.

    Args:
        x: Input tensor. Must be 3D.
        intensity_range: Transformation intensity.
        channel_axis: Index of axis for channels in the input tensor.

    Returns:
        Numpy image tensor.
    """
    drawn = np.random.uniform(-intensity_range, intensity_range)
    return apply_channel_shift(x, drawn, channel_axis=channel_axis)
1708
+
1709
+
1710
+ @keras_export("keras._legacy.preprocessing.image.apply_brightness_shift")
1711
def apply_brightness_shift(x, brightness, scale=True):
    """Performs a brightness shift.

    DEPRECATED.

    Args:
        x: Input tensor. Must be 3D.
        brightness: Float. The new brightness value.
        scale: Whether to rescale the image such that minimum and maximum
            values are 0 and 255 respectively. Default: True.

    Returns:
        Numpy image tensor.

    Raises:
        ImportError: if PIL is not available.
    """
    from PIL import ImageEnhance

    x_min, x_max = np.min(x), np.max(x)
    # PIL needs values in [0, 255]; force a rescale when the data falls
    # outside that range even if the caller asked for `scale=False`.
    local_scale = (x_min < 0) or (x_max > 255)
    img = image_utils.array_to_img(x, scale=local_scale or scale)
    # Fixed: dropped the redundant `x = imgenhancer_Brightness = ...`
    # double assignment from the original.
    img = ImageEnhance.Brightness(img).enhance(brightness)
    x = image_utils.img_to_array(img)
    if not scale and local_scale:
        # Undo the temporary [0, 255] rescale used for PIL processing.
        x = x / 255 * (x_max - x_min) + x_min
    return x
1739
+
1740
+
1741
+ @keras_export("keras._legacy.preprocessing.image.random_brightness")
1742
def random_brightness(x, brightness_range, scale=True):
    """Performs a random brightness shift.

    DEPRECATED.

    Args:
        x: Input tensor. Must be 3D.
        brightness_range: Tuple of floats; brightness range.
        scale: Whether to rescale the image such that minimum and maximum
            values are 0 and 255 respectively. Default: True.

    Returns:
        Numpy image tensor.

    Raises:
        ValueError: if `brightness_range` isn't a tuple or list of two
            floats.
    """
    if len(brightness_range) != 2:
        # Fixed: the original message had an unbalanced backtick
        # ("`brightness_range should be ...").
        raise ValueError(
            "`brightness_range` should be tuple or list of two floats. "
            f"Received: {brightness_range}"
        )

    u = np.random.uniform(brightness_range[0], brightness_range[1])
    return apply_brightness_shift(x, u, scale)
1767
+
1768
+
1769
def transform_matrix_offset_center(matrix, x, y):
    """Recenter an affine `matrix` so it pivots about the image center.

    `x` and `y` are the image dimensions; the returned 3x3 matrix is
    equivalent to translating the center to the origin, applying
    `matrix`, then translating back.
    """
    cx = float(x) / 2 - 0.5
    cy = float(y) / 2 - 0.5
    to_center = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
    from_center = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
    return (to_center @ matrix) @ from_center
1776
+
1777
+
1778
+ @keras_export("keras._legacy.preprocessing.image.apply_affine_transform")
1779
def apply_affine_transform(
    x,
    theta=0,
    tx=0,
    ty=0,
    shear=0,
    zx=1,
    zy=1,
    row_axis=1,
    col_axis=2,
    channel_axis=0,
    fill_mode="nearest",
    cval=0.0,
    order=1,
):
    """Applies an affine transformation specified by the parameters given.

    DEPRECATED.

    Args:
        x: 3D numpy array - a 2D image with one or more channels.
        theta: Rotation angle in degrees.
        tx: Shift along the row axis.
        ty: Shift along the column axis.
        shear: Shear angle in degrees.
        zx: Zoom along the row axis.
        zy: Zoom along the column axis.
        row_axis: Index of the rows axis in `x`.
        col_axis: Index of the columns axis in `x`.
        channel_axis: Index of the channels axis in `x` (0 or 2).
        fill_mode: How points outside the input boundaries are filled
            (passed through to `scipy.ndimage.affine_transform`).
        cval: Fill value used when `fill_mode="constant"`.
        order: Spline interpolation order.

    Returns:
        The transformed version of the input (same shape).

    Raises:
        ValueError: if `x` is not 3D or the axis indices are invalid.
    """
    # Input sanity checks:
    # 1. x must 2D image with one or more channels (i.e., a 3D tensor)
    # 2. channels must be either first or last dimension
    if np.unique([row_axis, col_axis, channel_axis]).size != 3:
        raise ValueError(
            "'row_axis', 'col_axis', and 'channel_axis' must be distinct"
        )

    # shall we support negative indices?
    valid_indices = set([0, 1, 2])
    actual_indices = set([row_axis, col_axis, channel_axis])
    if actual_indices != valid_indices:
        raise ValueError(
            f"Invalid axis' indices: {actual_indices - valid_indices}"
        )

    if x.ndim != 3:
        raise ValueError("Input arrays must be multi-channel 2D images.")
    if channel_axis not in [0, 2]:
        # Fixed garbled message ("allowed and the first...").
        raise ValueError(
            "Channels are allowed at the first and last dimensions."
        )

    # Compose the requested elementary transforms into a single 3x3
    # homogeneous matrix; stays None when every parameter is identity.
    transform_matrix = None
    if theta != 0:
        theta = np.deg2rad(theta)
        rotation_matrix = np.array(
            [
                [np.cos(theta), -np.sin(theta), 0],
                [np.sin(theta), np.cos(theta), 0],
                [0, 0, 1],
            ]
        )
        transform_matrix = rotation_matrix

    if tx != 0 or ty != 0:
        shift_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = shift_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shift_matrix)

    if shear != 0:
        shear = np.deg2rad(shear)
        shear_matrix = np.array(
            [[1, -np.sin(shear), 0], [0, np.cos(shear), 0], [0, 0, 1]]
        )
        if transform_matrix is None:
            transform_matrix = shear_matrix
        else:
            transform_matrix = np.dot(transform_matrix, shear_matrix)

    if zx != 1 or zy != 1:
        zoom_matrix = np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]])
        if transform_matrix is None:
            transform_matrix = zoom_matrix
        else:
            transform_matrix = np.dot(transform_matrix, zoom_matrix)

    if transform_matrix is not None:
        h, w = x.shape[row_axis], x.shape[col_axis]
        transform_matrix = transform_matrix_offset_center(
            transform_matrix, h, w
        )
        x = np.rollaxis(x, channel_axis, 0)

        # Matrix construction assumes that coordinates are x, y (in that order).
        # However, regular numpy arrays use y,x (aka i,j) indexing.
        # Possible solution is:
        # 1. Swap the x and y axes.
        # 2. Apply transform.
        # 3. Swap the x and y axes again to restore image-like data ordering.
        # Mathematically, it is equivalent to the following transformation:
        # M' = PMP, where P is the permutation matrix, M is the original
        # transformation matrix.
        if col_axis > row_axis:
            transform_matrix[:, [0, 1]] = transform_matrix[:, [1, 0]]
            transform_matrix[[0, 1]] = transform_matrix[[1, 0]]
        final_affine_matrix = transform_matrix[:2, :2]
        final_offset = transform_matrix[:2, 2]

        # Fixed: `scipy.ndimage.interpolation` is a deprecated namespace
        # removed in modern SciPy; `scipy.ndimage.affine_transform` is
        # the supported entry point for the same function.
        channel_images = [
            scipy.ndimage.affine_transform(
                x_channel,
                final_affine_matrix,
                final_offset,
                order=order,
                mode=fill_mode,
                cval=cval,
            )
            for x_channel in x
        ]
        x = np.stack(channel_images, axis=0)
        x = np.rollaxis(x, 0, channel_axis + 1)
    return x
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/sequence.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Deprecated sequence preprocessing APIs from Keras 1."""
2
+
3
+ import json
4
+ import random
5
+
6
+ import numpy as np
7
+
8
+ from keras.src.api_export import keras_export
9
+ from keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset
10
+
11
+
12
@keras_export("keras._legacy.preprocessing.sequence.TimeseriesGenerator")
class TimeseriesGenerator(PyDataset):
    """Utility class for generating batches of temporal data.

    DEPRECATED.

    Takes a sequence of data points gathered at equal intervals, plus time
    series parameters (history length, stride, sampling rate, ...), and
    produces batches suitable for training/validation.

    Arguments:
        data: Indexable generator (such as a list or Numpy array) of
            consecutive data points (timesteps). Should be 2D, with axis 0
            being the time dimension.
        targets: Targets corresponding to timesteps in `data`; must have the
            same length as `data`.
        length: Length of the output sequences (in number of timesteps).
        sampling_rate: Period between successive individual timesteps within
            sequences. For rate `r`, timesteps `data[i]`, `data[i-r]`, ...
            `data[i - length]` are used to create a sample sequence.
        stride: Period between successive output sequences. For stride `s`,
            consecutive output samples are centered around `data[i]`,
            `data[i+s]`, `data[i+2*s]`, etc.
        start_index: Data points earlier than `start_index` are excluded.
            Useful to reserve part of the data for test or validation.
        end_index: Data points later than `end_index` are excluded.
            Useful to reserve part of the data for test or validation.
        shuffle: Whether to shuffle output samples, or instead draw them in
            chronological order.
        reverse: Boolean: if `true`, timesteps in each output sample will be
            in reverse chronological order.
        batch_size: Number of timeseries samples in each batch (except maybe
            the last one).

    Returns:
        A PyDataset instance.
    """

    def __init__(
        self,
        data,
        targets,
        length,
        sampling_rate=1,
        stride=1,
        start_index=0,
        end_index=None,
        shuffle=False,
        reverse=False,
        batch_size=128,
    ):
        if len(data) != len(targets):
            raise ValueError(
                "Data and targets have to be "
                f"of same length. Data length is {len(data)} "
                f"while target length is {len(targets)}"
            )

        self.data = data
        self.targets = targets
        self.length = length
        self.sampling_rate = sampling_rate
        self.stride = stride
        # The first `length` rows only serve as history for later samples.
        self.start_index = start_index + length
        self.end_index = len(data) - 1 if end_index is None else end_index
        self.shuffle = shuffle
        self.reverse = reverse
        self.batch_size = batch_size

        if self.start_index > self.end_index:
            raise ValueError(
                f"`start_index+length={self.start_index} "
                f"> end_index={self.end_index}` "
                "is disallowed, as no part of the sequence "
                "would be left to be used as current step."
            )

    def __len__(self):
        # Number of batches, counting a final partial batch.
        span = self.batch_size * self.stride
        return (self.end_index - self.start_index + span) // span

    def __getitem__(self, index):
        if self.shuffle:
            # NOTE(review): when shuffling, `index` is ignored and rows are
            # drawn with replacement — long-standing behavior, kept as-is.
            rows = np.random.randint(
                self.start_index, self.end_index + 1, size=self.batch_size
            )
        else:
            first = self.start_index + self.batch_size * self.stride * index
            last = min(
                first + self.batch_size * self.stride, self.end_index + 1
            )
            rows = np.arange(first, last, self.stride)

        samples = np.array(
            [
                self.data[row - self.length : row : self.sampling_rate]
                for row in rows
            ]
        )
        targets = np.array([self.targets[row] for row in rows])

        if self.reverse:
            # Flip the time axis of each sample.
            return samples[:, ::-1, ...], targets
        return samples, targets

    def get_config(self):
        """Returns the TimeseriesGenerator configuration as Python dictionary.

        Returns:
            A Python dictionary with the TimeseriesGenerator configuration.
        """
        data = self.data
        if type(self.data).__module__ == np.__name__:
            data = self.data.tolist()
        try:
            json_data = json.dumps(data)
        except TypeError as e:
            raise TypeError(f"Data not JSON Serializable: {data}") from e

        targets = self.targets
        if type(self.targets).__module__ == np.__name__:
            targets = self.targets.tolist()
        try:
            json_targets = json.dumps(targets)
        except TypeError as e:
            raise TypeError(f"Targets not JSON Serializable: {targets}") from e

        return {
            "data": json_data,
            "targets": json_targets,
            "length": self.length,
            "sampling_rate": self.sampling_rate,
            "stride": self.stride,
            "start_index": self.start_index,
            "end_index": self.end_index,
            "shuffle": self.shuffle,
            "reverse": self.reverse,
            "batch_size": self.batch_size,
        }

    def to_json(self, **kwargs):
        """Returns a JSON string containing the generator's configuration.

        Args:
            **kwargs: Additional keyword arguments to be passed
                to `json.dumps()`.

        Returns:
            A JSON string containing the tokenizer configuration.
        """
        payload = {
            "class_name": self.__class__.__name__,
            "config": self.get_config(),
        }
        return json.dumps(payload, **kwargs)
177
+
178
+
179
@keras_export("keras._legacy.preprocessing.sequence.make_sampling_table")
def make_sampling_table(size, sampling_factor=1e-5):
    """Generates a word rank-based probabilistic sampling table.

    DEPRECATED.

    Used for generating the `sampling_table` argument for `skipgrams`.
    `sampling_table[i]` is the probability of sampling the i-th most common
    word in a dataset (more common words should be sampled less frequently,
    for balance).

    The sampling probabilities follow the distribution used in word2vec:

    ```
    p(word) = (min(1, sqrt(word_frequency / sampling_factor) /
            (word_frequency / sampling_factor)))
    ```

    Word frequencies are assumed to follow Zipf's law (s=1), giving the
    numerical approximation of frequency(rank):

    `frequency(rank) ~ 1/(rank * (log(rank) + gamma) + 1/2 - 1/(12*rank))`
    where `gamma` is the Euler-Mascheroni constant.

    Args:
        size: Int, number of possible words to sample.
        sampling_factor: The sampling factor in the word2vec formula.

    Returns:
        A 1D Numpy array of length `size` where the ith entry
        is the probability that a word of rank i should be sampled.
    """
    euler_gamma = 0.577
    rank = np.arange(size)
    # Rank 0 is treated like rank 1 to avoid log(0).
    rank[0] = 1
    inv_freq = rank * (np.log(rank) + euler_gamma) + 0.5 - 1.0 / (12.0 * rank)
    scaled = sampling_factor * inv_freq
    # Equivalent to min(1, sqrt(scaled)) for positive inputs.
    return np.minimum(1.0, scaled / np.sqrt(scaled))
219
+
220
+
221
@keras_export("keras._legacy.preprocessing.sequence.skipgrams")
def skipgrams(
    sequence,
    vocabulary_size,
    window_size=4,
    negative_samples=1.0,
    shuffle=True,
    categorical=False,
    sampling_table=None,
    seed=None,
):
    """Generates skipgram word pairs.

    DEPRECATED.

    This function transforms a sequence of word indexes (list of integers)
    into tuples of words of the form:

    - (word, word in the same window), with label 1 (positive samples).
    - (word, random word from the vocabulary), with label 0 (negative samples).

    Read more about Skipgram in this gnomic paper by Mikolov et al.:
    [Efficient Estimation of Word Representations in
    Vector Space](http://arxiv.org/pdf/1301.3781v3.pdf)

    Args:
        sequence: A word sequence (sentence), encoded as a list
            of word indices (integers). If using a `sampling_table`,
            word indices are expected to match the rank
            of the words in a reference dataset (e.g. 10 would encode
            the 10-th most frequently occurring token).
            Note that index 0 is expected to be a non-word and will be skipped.
        vocabulary_size: Int, maximum possible word index + 1
        window_size: Int, size of sampling windows (technically half-window).
            The window of a word `w_i` will be
            `[i - window_size, i + window_size+1]`.
        negative_samples: Float >= 0. 0 for no negative (i.e. random) samples.
            1 for same number as positive samples.
        shuffle: Whether to shuffle the word couples before returning them.
        categorical: bool. if False, labels will be
            integers (eg. `[0, 1, 1 .. ]`),
            if `True`, labels will be categorical, e.g.
            `[[1,0],[0,1],[0,1] .. ]`.
        sampling_table: 1D array of size `vocabulary_size` where the entry i
            encodes the probability to sample a word of rank i.
        seed: Random seed.

    Returns:
        couples, labels: where `couples` are int pairs and
        `labels` are either 0 or 1.

    Note:
        By convention, index 0 in the vocabulary is
        a non-word and will be skipped.
    """
    couples = []
    labels = []
    for i, wi in enumerate(sequence):
        # Index 0 is the reserved non-word.
        if not wi:
            continue
        if sampling_table is not None:
            if sampling_table[wi] < random.random():
                continue

        window_start = max(0, i - window_size)
        window_end = min(len(sequence), i + window_size + 1)
        for j in range(window_start, window_end):
            if j != i:
                wj = sequence[j]
                if not wj:
                    continue
                couples.append([wi, wj])
                if categorical:
                    labels.append([0, 1])
                else:
                    labels.append(1)

    if negative_samples > 0:
        num_negative_samples = int(len(labels) * negative_samples)
        words = [c[0] for c in couples]
        random.shuffle(words)

        couples += [
            [words[i % len(words)], random.randint(1, vocabulary_size - 1)]
            for i in range(num_negative_samples)
        ]
        if categorical:
            labels += [[1, 0]] * num_negative_samples
        else:
            labels += [0] * num_negative_samples

    if shuffle:
        if seed is None:
            # Bug fix: this previously used the float literal `10e6`;
            # non-integer bounds to `random.randint` were deprecated in
            # Python 3.10 and raise TypeError on Python >= 3.12.
            seed = random.randint(0, 10_000_000)
        # Seed twice with the same value so couples and labels are
        # shuffled into the same permutation.
        random.seed(seed)
        random.shuffle(couples)
        random.seed(seed)
        random.shuffle(labels)

    return couples, labels
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/preprocessing/text.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Deprecated text preprocessing APIs from Keras 1."""
2
+
3
+ import collections
4
+ import hashlib
5
+ import json
6
+ import warnings
7
+
8
+ import numpy as np
9
+
10
+ from keras.src.api_export import keras_export
11
+
12
+
13
@keras_export("keras._legacy.preprocessing.text.text_to_word_sequence")
def text_to_word_sequence(
    input_text,
    filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
    lower=True,
    split=" ",
):
    """DEPRECATED. Converts a text into a list of word tokens."""
    if lower:
        input_text = input_text.lower()

    # Replace every filtered character with the split token in one pass.
    table = str.maketrans({ch: split for ch in filters})
    tokens = input_text.translate(table).split(split)
    # Drop the empty strings produced by consecutive separators.
    return [token for token in tokens if token]
30
+
31
+
32
@keras_export("keras._legacy.preprocessing.text.one_hot")
def one_hot(
    input_text,
    n,
    filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
    lower=True,
    split=" ",
    analyzer=None,
):
    """DEPRECATED. Encodes text as indices in `[1, n)` via `hashing_trick`."""
    # Thin wrapper: delegate to the hashing trick with the builtin `hash`.
    options = dict(
        hash_function=hash,
        filters=filters,
        lower=lower,
        split=split,
        analyzer=analyzer,
    )
    return hashing_trick(input_text, n, **options)
51
+
52
+
53
@keras_export("keras._legacy.preprocessing.text.hashing_trick")
def hashing_trick(
    text,
    n,
    hash_function=None,
    filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
    lower=True,
    split=" ",
    analyzer=None,
):
    """DEPRECATED. Maps each word of `text` to an index in `[1, n)`.

    NOTE(review): the default builtin `hash` is salted per process
    (PYTHONHASHSEED), so indices are presumably not stable across runs;
    pass `hash_function="md5"` for reproducible output — confirm with callers.
    """
    if hash_function is None:
        hash_function = hash
    elif hash_function == "md5":

        def hash_function(w):
            return int(hashlib.md5(w.encode()).hexdigest(), 16)

    words = (
        text_to_word_sequence(text, filters=filters, lower=lower, split=split)
        if analyzer is None
        else analyzer(text)
    )

    # Index 0 is reserved, so hashes are folded into the range [1, n - 1].
    return [hash_function(w) % (n - 1) + 1 for w in words]
79
+
80
+
81
@keras_export("keras._legacy.preprocessing.text.Tokenizer")
class Tokenizer:
    """DEPRECATED."""

    def __init__(
        self,
        num_words=None,
        filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
        lower=True,
        split=" ",
        char_level=False,
        oov_token=None,
        analyzer=None,
        **kwargs,
    ):
        # Legacy support for the Keras 1 `nb_words` argument name.
        if "nb_words" in kwargs:
            warnings.warn(
                "The `nb_words` argument in `Tokenizer` "
                "has been renamed `num_words`."
            )
            num_words = kwargs.pop("nb_words")
        document_count = kwargs.pop("document_count", 0)
        if kwargs:
            raise TypeError("Unrecognized keyword arguments: " + str(kwargs))

        self.word_counts = collections.OrderedDict()
        self.word_docs = collections.defaultdict(int)
        self.filters = filters
        self.split = split
        self.lower = lower
        self.num_words = num_words
        self.document_count = document_count
        self.char_level = char_level
        self.oov_token = oov_token
        self.index_docs = collections.defaultdict(int)
        self.word_index = {}
        self.index_word = {}
        self.analyzer = analyzer

    def _split_text(self, text):
        """Turns one input text into a token sequence, honoring settings."""
        if self.char_level or isinstance(text, list):
            if self.lower:
                if isinstance(text, list):
                    return [elem.lower() for elem in text]
                return text.lower()
            return text
        if self.analyzer is None:
            return text_to_word_sequence(
                text,
                filters=self.filters,
                lower=self.lower,
                split=self.split,
            )
        return self.analyzer(text)

    def fit_on_texts(self, texts):
        """Updates the internal vocabulary based on a list of texts."""
        for text in texts:
            self.document_count += 1
            seq = self._split_text(text)
            for w in seq:
                self.word_counts[w] = self.word_counts.get(w, 0) + 1
            for w in set(seq):
                # In how many documents each word occurs.
                self.word_docs[w] += 1

        by_count = sorted(
            self.word_counts.items(), key=lambda item: item[1], reverse=True
        )
        # The oov_token, if set, is forced to index 1.
        voc = [] if self.oov_token is None else [self.oov_token]
        voc.extend(word for word, _ in by_count)

        # Index 0 is reserved and never assigned to an existing word.
        self.word_index = {w: i for i, w in enumerate(voc, start=1)}
        self.index_word = {i: w for w, i in self.word_index.items()}

        for w, c in list(self.word_docs.items()):
            self.index_docs[self.word_index[w]] = c

    def fit_on_sequences(self, sequences):
        """Updates document counts from already-indexed sequences."""
        self.document_count += len(sequences)
        for seq in sequences:
            for i in set(seq):
                self.index_docs[i] += 1

    def texts_to_sequences(self, texts):
        return list(self.texts_to_sequences_generator(texts))

    def texts_to_sequences_generator(self, texts):
        """Yields one index sequence per input text."""
        num_words = self.num_words
        oov_index = self.word_index.get(self.oov_token)
        for text in texts:
            seq = self._split_text(text)
            vect = []
            for w in seq:
                i = self.word_index.get(w)
                if i is not None:
                    if num_words and i >= num_words:
                        # Out-of-budget words map to the OOV index, if any.
                        if oov_index is not None:
                            vect.append(oov_index)
                    else:
                        vect.append(i)
                elif self.oov_token is not None:
                    vect.append(oov_index)
            yield vect

    def sequences_to_texts(self, sequences):
        return list(self.sequences_to_texts_generator(sequences))

    def sequences_to_texts_generator(self, sequences):
        """Yields one space-joined text per input index sequence."""
        num_words = self.num_words
        oov_index = self.word_index.get(self.oov_token)
        for seq in sequences:
            words = []
            for num in seq:
                word = self.index_word.get(num)
                if word is not None:
                    if num_words and num >= num_words:
                        if oov_index is not None:
                            words.append(self.index_word[oov_index])
                    else:
                        words.append(word)
                elif self.oov_token is not None:
                    words.append(self.index_word[oov_index])
            yield " ".join(words)

    def texts_to_matrix(self, texts, mode="binary"):
        return self.sequences_to_matrix(self.texts_to_sequences(texts), mode=mode)

    def sequences_to_matrix(self, sequences, mode="binary"):
        """Builds a (num_sequences, num_words) matrix in the given mode."""
        if not self.num_words:
            if self.word_index:
                num_words = len(self.word_index) + 1
            else:
                raise ValueError(
                    "Specify a dimension (`num_words` argument), "
                    "or fit on some text data first."
                )
        else:
            num_words = self.num_words

        if mode == "tfidf" and not self.document_count:
            raise ValueError(
                "Fit the Tokenizer on some data before using tfidf mode."
            )

        x = np.zeros((len(sequences), num_words))
        for row, seq in enumerate(sequences):
            if not seq:
                continue
            counts = collections.Counter(j for j in seq if j < num_words)
            for j, c in counts.items():
                if mode == "count":
                    x[row][j] = c
                elif mode == "freq":
                    x[row][j] = c / len(seq)
                elif mode == "binary":
                    x[row][j] = 1
                elif mode == "tfidf":
                    # Use weighting scheme 2 in
                    # https://en.wikipedia.org/wiki/Tf%E2%80%93idf
                    tf = 1 + np.log(c)
                    idf = np.log(
                        1
                        + self.document_count / (1 + self.index_docs.get(j, 0))
                    )
                    x[row][j] = tf * idf
                else:
                    raise ValueError("Unknown vectorization mode:", mode)
        return x

    def get_config(self):
        """Returns the tokenizer configuration as a Python dictionary."""
        return {
            "num_words": self.num_words,
            "filters": self.filters,
            "lower": self.lower,
            "split": self.split,
            "char_level": self.char_level,
            "oov_token": self.oov_token,
            "document_count": self.document_count,
            # Vocabulary state is stored pre-serialized as JSON strings.
            "word_counts": json.dumps(self.word_counts),
            "word_docs": json.dumps(self.word_docs),
            "index_docs": json.dumps(self.index_docs),
            "index_word": json.dumps(self.index_word),
            "word_index": json.dumps(self.word_index),
        }

    def to_json(self, **kwargs):
        """Returns a JSON string with the tokenizer's class name and config."""
        payload = {
            "class_name": self.__class__.__name__,
            "config": self.get_config(),
        }
        return json.dumps(payload, **kwargs)
313
+
314
+
315
@keras_export("keras._legacy.preprocessing.text.tokenizer_from_json")
def tokenizer_from_json(json_string):
    """DEPRECATED."""
    config = json.loads(json_string).get("config")

    word_counts = json.loads(config.pop("word_counts"))
    word_docs = json.loads(config.pop("word_docs"))
    # json.dumps() stringifies integer keys; restore them to ints.
    index_docs = {
        int(k): v for k, v in json.loads(config.pop("index_docs")).items()
    }
    index_word = {
        int(k): v for k, v in json.loads(config.pop("index_word")).items()
    }
    word_index = json.loads(config.pop("word_index"))

    tokenizer = Tokenizer(**config)
    tokenizer.word_counts = word_counts
    tokenizer.word_docs = word_docs
    tokenizer.index_docs = index_docs
    tokenizer.word_index = word_index
    tokenizer.index_word = index_word
    return tokenizer
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (199 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/json_utils.cpython-310.pyc ADDED
Binary file (5.84 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/legacy_h5_format.cpython-310.pyc ADDED
Binary file (16.6 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/saving_options.cpython-310.pyc ADDED
Binary file (612 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/saving_utils.cpython-310.pyc ADDED
Binary file (6.75 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/__pycache__/serialization.cpython-310.pyc ADDED
Binary file (13.8 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/json_utils.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """JSON utilities for legacy saving formats (h5 and SavedModel)"""
2
+
3
+ import collections
4
+ import enum
5
+ import functools
6
+ import json
7
+
8
+ import numpy as np
9
+
10
+ from keras.src.legacy.saving import serialization
11
+ from keras.src.saving import serialization_lib
12
+ from keras.src.utils.module_utils import tensorflow as tf
13
+
14
+ _EXTENSION_TYPE_SPEC = "_EXTENSION_TYPE_SPEC"
15
+
16
+
17
class Encoder(json.JSONEncoder):
    """JSON encoder/decoder that understands TensorShapes and tuples."""

    def default(self, obj):
        """Encodes objects the stock JSON encoder cannot handle."""
        if tf.available and isinstance(obj, tf.TensorShape):
            # An unknown-rank shape serializes as None.
            items = obj.as_list() if obj.rank is not None else None
            return {"class_name": "TensorShape", "items": items}
        return get_json_type(obj)

    def encode(self, obj):
        # Tag tuples before encoding, otherwise JSON turns them into lists.
        return super().encode(_encode_tuple(obj))
30
+
31
+
32
+ def _encode_tuple(x):
33
+ if isinstance(x, tuple):
34
+ return {
35
+ "class_name": "__tuple__",
36
+ "items": tuple(_encode_tuple(i) for i in x),
37
+ }
38
+ elif isinstance(x, list):
39
+ return [_encode_tuple(i) for i in x]
40
+ elif isinstance(x, dict):
41
+ return {key: _encode_tuple(value) for key, value in x.items()}
42
+ else:
43
+ return x
44
+
45
+
46
def decode(json_string):
    """Decodes a JSON string, restoring objects tagged by `Encoder`."""
    hook = _decode_helper
    return json.loads(json_string, object_hook=hook)
48
+
49
+
50
def decode_and_deserialize(
    json_string, module_objects=None, custom_objects=None
):
    """Decodes the JSON and deserializes any Keras objects found in the dict."""
    hook = functools.partial(
        _decode_helper,
        deserialize=True,
        module_objects=module_objects,
        custom_objects=custom_objects,
    )
    return json.loads(json_string, object_hook=hook)
63
+
64
+
65
+ def _decode_helper(
66
+ obj, deserialize=False, module_objects=None, custom_objects=None
67
+ ):
68
+ """A decoding helper that is TF-object aware.
69
+
70
+ Args:
71
+ obj: A decoded dictionary that may represent an object.
72
+ deserialize: Boolean. When True, deserializes any Keras
73
+ objects found in `obj`. Defaults to `False`.
74
+ module_objects: A dictionary of built-in objects to look the name up in.
75
+ Generally, `module_objects` is provided by midlevel library
76
+ implementers.
77
+ custom_objects: A dictionary of custom objects to look the name up in.
78
+ Generally, `custom_objects` is provided by the end user.
79
+
80
+ Returns:
81
+ The decoded object.
82
+ """
83
+ if isinstance(obj, dict) and "class_name" in obj:
84
+ if tf.available:
85
+ if obj["class_name"] == "TensorShape":
86
+ return tf.TensorShape(obj["items"])
87
+ elif obj["class_name"] == "TypeSpec":
88
+ from tensorflow.python.framework import type_spec_registry
89
+
90
+ return type_spec_registry.lookup(obj["type_spec"])._deserialize(
91
+ _decode_helper(obj["serialized"])
92
+ )
93
+ elif obj["class_name"] == "CompositeTensor":
94
+ spec = obj["spec"]
95
+ tensors = []
96
+ for dtype, tensor in obj["tensors"]:
97
+ tensors.append(
98
+ tf.constant(tensor, dtype=tf.dtypes.as_dtype(dtype))
99
+ )
100
+ return tf.nest.pack_sequence_as(
101
+ _decode_helper(spec), tensors, expand_composites=True
102
+ )
103
+
104
+ if obj["class_name"] == "__tuple__":
105
+ return tuple(_decode_helper(i) for i in obj["items"])
106
+ elif obj["class_name"] == "__ellipsis__":
107
+ return Ellipsis
108
+ elif deserialize and "__passive_serialization__" in obj:
109
+ # __passive_serialization__ is added by the JSON encoder when
110
+ # encoding an object that has a `get_config()` method.
111
+ try:
112
+ if (
113
+ "module" not in obj
114
+ ): # TODO(nkovela): Add TF SavedModel scope
115
+ return serialization.deserialize_keras_object(
116
+ obj,
117
+ module_objects=module_objects,
118
+ custom_objects=custom_objects,
119
+ )
120
+ else:
121
+ return serialization_lib.deserialize_keras_object(
122
+ obj,
123
+ module_objects=module_objects,
124
+ custom_objects=custom_objects,
125
+ )
126
+ except ValueError:
127
+ pass
128
+ elif obj["class_name"] == "__bytes__":
129
+ return obj["value"].encode("utf-8")
130
+ return obj
131
+
132
+
133
def get_json_type(obj):
    """Serializes any object to a JSON-serializable structure.

    Args:
        obj: the object to serialize

    Returns:
        JSON-serializable structure representing `obj`.

    Raises:
        TypeError: if `obj` cannot be serialized.
    """
    # A serializable Keras class instance, e.g. an optimizer or a layer.
    if hasattr(obj, "get_config"):
        # TODO(nkovela): Replace with legacy serialization
        config = serialization.serialize_keras_object(obj)
        config["__passive_serialization__"] = True
        return config

    # Any numpy scalar or array.
    if type(obj).__module__ == np.__name__:
        return obj.tolist() if isinstance(obj, np.ndarray) else obj.item()

    # Misc callables (e.g. a loss function) serialize by name.
    if callable(obj):
        return obj.__name__

    # Python classes serialize by name as well.
    if type(obj).__name__ == type.__name__:
        return obj.__name__

    if tf.available and isinstance(obj, tf.compat.v1.Dimension):
        return obj.value

    if tf.available and isinstance(obj, tf.TensorShape):
        return obj.as_list()

    if tf.available and isinstance(obj, tf.DType):
        return obj.name

    if isinstance(obj, collections.abc.Mapping):
        return dict(obj)

    if obj is Ellipsis:
        return {"class_name": "__ellipsis__"}

    # if isinstance(obj, wrapt.ObjectProxy):
    #     return obj.__wrapped__

    if tf.available and isinstance(obj, tf.TypeSpec):
        from tensorflow.python.framework import type_spec_registry

        try:
            type_spec_name = type_spec_registry.get_name(type(obj))
            return {
                "class_name": "TypeSpec",
                "type_spec": type_spec_name,
                "serialized": obj._serialize(),
            }
        except ValueError:
            raise ValueError(
                f"Unable to serialize {obj} to JSON, because the TypeSpec "
                f"class {type(obj)} has not been registered."
            )

    if tf.available and isinstance(obj, tf.__internal__.CompositeTensor):
        spec = tf.type_spec_from_value(obj)
        tensors = [
            (tensor.dtype.name, tensor.numpy().tolist())
            for tensor in tf.nest.flatten(obj, expand_composites=True)
        ]
        return {
            "class_name": "CompositeTensor",
            "spec": get_json_type(spec),
            "tensors": tensors,
        }

    if isinstance(obj, enum.Enum):
        return obj.value

    if isinstance(obj, bytes):
        return {"class_name": "__bytes__", "value": obj.decode("utf-8")}

    raise TypeError(
        f"Unable to serialize {obj} to JSON. Unrecognized type {type(obj)}."
    )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/legacy_h5_format.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import warnings
4
+
5
+ import numpy as np
6
+ from absl import logging
7
+
8
+ from keras.src import backend
9
+ from keras.src import optimizers
10
+ from keras.src.backend.common import global_state
11
+ from keras.src.legacy.saving import json_utils
12
+ from keras.src.legacy.saving import saving_options
13
+ from keras.src.legacy.saving import saving_utils
14
+ from keras.src.saving import object_registration
15
+ from keras.src.utils import io_utils
16
+
17
+ try:
18
+ import h5py
19
+ except ImportError:
20
+ h5py = None
21
+
22
+
23
+ HDF5_OBJECT_HEADER_LIMIT = 64512
24
+
25
+
26
+ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
27
+ if h5py is None:
28
+ raise ImportError(
29
+ "`save_model()` using h5 format requires h5py. Could not "
30
+ "import h5py."
31
+ )
32
+
33
+ if not isinstance(filepath, h5py.File):
34
+ # If file exists and should not be overwritten.
35
+ if not overwrite and os.path.isfile(filepath):
36
+ proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
37
+ if not proceed:
38
+ return
39
+
40
+ dirpath = os.path.dirname(filepath)
41
+ if dirpath and not os.path.exists(dirpath):
42
+ os.makedirs(dirpath, exist_ok=True)
43
+
44
+ f = h5py.File(filepath, mode="w")
45
+ opened_new_file = True
46
+ else:
47
+ f = filepath
48
+ opened_new_file = False
49
+ try:
50
+ with saving_options.keras_option_scope(use_legacy_config=True):
51
+ model_metadata = saving_utils.model_metadata(
52
+ model, include_optimizer
53
+ )
54
+ for k, v in model_metadata.items():
55
+ if isinstance(v, (dict, list, tuple)):
56
+ f.attrs[k] = json.dumps(
57
+ v, default=json_utils.get_json_type
58
+ ).encode("utf8")
59
+ else:
60
+ f.attrs[k] = v
61
+
62
+ model_weights_group = f.create_group("model_weights")
63
+ save_weights_to_hdf5_group(model_weights_group, model)
64
+
65
+ # TODO(b/128683857): Add integration tests between tf.keras and
66
+ # external Keras, to avoid breaking TF.js users.
67
+ if include_optimizer and hasattr(model, "optimizer"):
68
+ save_optimizer_weights_to_hdf5_group(f, model.optimizer)
69
+
70
+ f.flush()
71
+ finally:
72
+ if opened_new_file:
73
+ f.close()
74
+
75
+
76
+ def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
77
+ """Loads a model saved via `save_model_to_hdf5`.
78
+
79
+ Args:
80
+ filepath: One of the following:
81
+ - String, path to the saved model
82
+ - `h5py.File` object from which to load the model
83
+ custom_objects: Optional dictionary mapping names
84
+ (strings) to custom classes or functions to be
85
+ considered during deserialization.
86
+ compile: Boolean, whether to compile the model
87
+ after loading.
88
+
89
+ Returns:
90
+ A Keras model instance. If an optimizer was found
91
+ as part of the saved model, the model is already
92
+ compiled. Otherwise, the model is uncompiled and
93
+ a warning will be displayed. When `compile` is set
94
+ to `False`, the compilation is omitted without any
95
+ warning.
96
+
97
+ Raises:
98
+ ImportError: if h5py is not available.
99
+ ValueError: In case of an invalid savefile.
100
+ """
101
+ if h5py is None:
102
+ raise ImportError(
103
+ "`load_model()` using h5 format requires h5py. Could not "
104
+ "import h5py."
105
+ )
106
+
107
+ if not custom_objects:
108
+ custom_objects = {}
109
+
110
+ gco = object_registration.GLOBAL_CUSTOM_OBJECTS
111
+ tlco = global_state.get_global_attribute("custom_objects_scope_dict", {})
112
+ custom_objects = {**custom_objects, **gco, **tlco}
113
+
114
+ opened_new_file = not isinstance(filepath, h5py.File)
115
+ if opened_new_file:
116
+ f = h5py.File(filepath, mode="r")
117
+ else:
118
+ f = filepath
119
+
120
+ model = None
121
+ try:
122
+ # instantiate model
123
+ model_config = f.attrs.get("model_config")
124
+ if model_config is None:
125
+ raise ValueError(
126
+ f"No model config found in the file at {filepath}."
127
+ )
128
+ if hasattr(model_config, "decode"):
129
+ model_config = model_config.decode("utf-8")
130
+ model_config = json_utils.decode(model_config)
131
+
132
+ with saving_options.keras_option_scope(use_legacy_config=True):
133
+ model = saving_utils.model_from_config(
134
+ model_config, custom_objects=custom_objects
135
+ )
136
+
137
+ # set weights
138
+ load_weights_from_hdf5_group(f["model_weights"], model)
139
+
140
+ if compile:
141
+ # instantiate optimizer
142
+ training_config = f.attrs.get("training_config")
143
+ if hasattr(training_config, "decode"):
144
+ training_config = training_config.decode("utf-8")
145
+ if training_config is None:
146
+ logging.warning(
147
+ "No training configuration found in the save file, so "
148
+ "the model was *not* compiled. Compile it manually."
149
+ )
150
+ return model
151
+ training_config = json_utils.decode(training_config)
152
+
153
+ # Compile model.
154
+ model.compile(
155
+ **saving_utils.compile_args_from_training_config(
156
+ training_config, custom_objects
157
+ )
158
+ )
159
+ saving_utils.try_build_compiled_arguments(model)
160
+
161
+ # Set optimizer weights.
162
+ if "optimizer_weights" in f:
163
+ try:
164
+ if isinstance(model.optimizer, optimizers.Optimizer):
165
+ model.optimizer.build(model._trainable_variables)
166
+ else:
167
+ model.optimizer._create_all_weights(
168
+ model._trainable_variables
169
+ )
170
+ except (NotImplementedError, AttributeError):
171
+ logging.warning(
172
+ "Error when creating the weights of optimizer {}, "
173
+ "making it impossible to restore the saved optimizer "
174
+ "state. As a result, your model is starting with "
175
+ "a freshly initialized optimizer."
176
+ )
177
+
178
+ optimizer_weight_values = (
179
+ load_optimizer_weights_from_hdf5_group(f)
180
+ )
181
+ try:
182
+ model.optimizer.set_weights(optimizer_weight_values)
183
+ except ValueError:
184
+ logging.warning(
185
+ "Error in loading the saved optimizer "
186
+ "state. As a result, your model is "
187
+ "starting with a freshly initialized "
188
+ "optimizer."
189
+ )
190
+ finally:
191
+ if opened_new_file:
192
+ f.close()
193
+ return model
194
+
195
+
196
def save_weights_to_hdf5_group(f, model):
    """Saves the weights of a list of layers to a HDF5 group.

    Writes one sub-group per layer (by layer name) plus a
    `top_level_model_weights` group for weights owned directly by the model.

    Args:
        f: HDF5 group.
        model: Model instance.
    """
    from keras.src import __version__ as keras_version

    save_attributes_to_hdf5_group(
        f, "layer_names", [layer.name.encode("utf8") for layer in model.layers]
    )
    f.attrs["backend"] = backend.backend().encode("utf8")
    f.attrs["keras_version"] = str(keras_version).encode("utf8")

    # Sort model layers by layer name to ensure that group names are strictly
    # growing to avoid prefix issues.
    for layer in sorted(model.layers, key=lambda lyr: lyr.name):
        save_subset_weights_to_hdf5_group(
            f.create_group(layer.name), _legacy_weights(layer)
        )

    # Weights attached to the model itself rather than to any layer.
    top_level_weights = [
        v
        for v in model._trainable_variables + model._non_trainable_variables
        if v in model.weights
    ]
    save_subset_weights_to_hdf5_group(
        f.create_group("top_level_model_weights"), top_level_weights
    )
224
+
225
+
226
def save_subset_weights_to_hdf5_group(f, weights):
    """Save top-level weights of a model to a HDF5 group.

    Each weight becomes a dataset named after its `.path`; the names are
    also recorded in the group's `weight_names` attribute.

    Args:
        f: HDF5 group.
        weights: List of weight variables.
    """
    names = [str(w.path).encode("utf8") for w in weights]
    save_attributes_to_hdf5_group(f, "weight_names", names)
    values = (backend.convert_to_numpy(w) for w in weights)
    for name, value in zip(names, values):
        dset = f.create_dataset(name, value.shape, dtype=value.dtype)
        if value.shape:
            dset[:] = value
        else:
            # Scalar datasets require the empty-tuple index.
            dset[()] = value
243
+
244
+
245
def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):
    """Saves optimizer weights of a optimizer to a HDF5 group.

    Args:
        hdf5_group: HDF5 group.
        optimizer: optimizer instance.
    """
    if isinstance(optimizer, optimizers.Optimizer):
        symbolic_weights = optimizer.variables
    else:
        # Legacy optimizers expose their state through `.weights`.
        symbolic_weights = getattr(optimizer, "weights")
    if not symbolic_weights:
        return

    weights_group = hdf5_group.create_group("optimizer_weights")
    weight_names = [str(w.path).encode("utf8") for w in symbolic_weights]
    save_attributes_to_hdf5_group(weights_group, "weight_names", weight_names)
    values = (backend.convert_to_numpy(w) for w in symbolic_weights)
    for name, value in zip(weight_names, values):
        dset = weights_group.create_dataset(
            name, value.shape, dtype=value.dtype
        )
        if value.shape:
            dset[:] = value
        else:
            # Scalar datasets require the empty-tuple index.
            dset[()] = value
272
+
273
+
274
def save_attributes_to_hdf5_group(group, name, data):
    """Saves attributes (data) of the specified name into the HDF5 group.

    This method deals with an inherent problem of HDF5 file which is not
    able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes: oversized
    attribute lists are split across `name0`, `name1`, ... chunks.

    Args:
        group: A pointer to a HDF5 group.
        name: A name of the attributes to save.
        data: Attributes data to store.

    Raises:
        RuntimeError: If any single attribute is too large to be saved.
    """
    # A single element bigger than the header limit cannot be chunked at all,
    # so fail fast. Expecting this to never be true.
    bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]
    if bad_attributes:
        raise RuntimeError(
            "The following attributes cannot be saved to HDF5 file because "
            f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} "
            f"bytes: {bad_attributes}"
        )

    data_npy = np.asarray(data)

    # Find the smallest chunk count whose pieces all fit under the limit.
    # Guaranteed to terminate thanks to the per-element check above.
    num_chunks = 1
    chunked_data = np.array_split(data_npy, num_chunks)
    while any(
        chunk.nbytes > HDF5_OBJECT_HEADER_LIMIT for chunk in chunked_data
    ):
        num_chunks += 1
        chunked_data = np.array_split(data_npy, num_chunks)

    if num_chunks > 1:
        for chunk_id, chunk_data in enumerate(chunked_data):
            group.attrs["%s%d" % (name, chunk_id)] = chunk_data
    else:
        # Store the original (non-numpy) data when it fits in one attribute.
        group.attrs[name] = data
316
+
317
+
318
def load_weights_from_hdf5_group(f, model):
    """Implements topological (order-based) weight loading.

    Saved layer groups are matched to the model's weighted layers by
    position, not by name.

    Args:
        f: A pointer to a HDF5 group.
        model: Model instance.

    Raises:
        ValueError: in case of mismatch between provided layers
            and weights file.
    """
    # Provenance metadata; read for parity with older formats, not used
    # further here.
    if "keras_version" in f.attrs:
        original_keras_version = f.attrs["keras_version"]
        if hasattr(original_keras_version, "decode"):
            original_keras_version = original_keras_version.decode("utf8")
    else:
        original_keras_version = "1"
    if "backend" in f.attrs:
        original_backend = f.attrs["backend"]
        if hasattr(original_backend, "decode"):
            original_backend = original_backend.decode("utf8")
    else:
        original_backend = None

    # Only layers that actually own weights participate in the matching.
    filtered_layers = [
        layer for layer in model.layers if _legacy_weights(layer)
    ]

    saved_names = load_attributes_from_hdf5_group(f, "layer_names")
    # Likewise drop saved groups that contain no weights.
    layer_names = [
        name
        for name in saved_names
        if load_attributes_from_hdf5_group(f[name], "weight_names")
    ]
    if len(layer_names) != len(filtered_layers):
        raise ValueError(
            "Layer count mismatch when loading weights from file. "
            f"Model expected {len(filtered_layers)} layers, found "
            f"{len(layer_names)} saved layers."
        )

    for k, name in enumerate(layer_names):
        layer = filtered_layers[k]
        symbolic_weights = _legacy_weights(layer)
        weight_values = load_subset_weights_from_hdf5_group(f[name])
        if len(weight_values) != len(symbolic_weights):
            raise ValueError(
                f"Weight count mismatch for layer #{k} (named {layer.name} in "
                f"the current model, {name} in the save file). "
                f"Layer expects {len(symbolic_weights)} weight(s). Received "
                f"{len(weight_values)} saved weight(s)"
            )
        _set_weights(
            layer,
            symbolic_weights,
            weight_values,
            name=f"layer #{k} (named {layer.name})",
        )

    if "top_level_model_weights" in f:
        # Weights owned directly by the model (subset of model.weights).
        symbolic_weights = [
            v
            for v in model._trainable_variables + model._non_trainable_variables
            if v in model.weights
        ]
        weight_values = load_subset_weights_from_hdf5_group(
            f["top_level_model_weights"]
        )
        if len(weight_values) != len(symbolic_weights):
            raise ValueError(
                "Weight count mismatch for top-level weights when loading "
                "weights from file. "
                f"Model expects {len(symbolic_weights)} top-level weight(s). "
                f"Received {len(weight_values)} saved top-level weight(s)"
            )
        _set_weights(
            model,
            symbolic_weights,
            weight_values,
            name="top-level model",
        )
405
+
406
+
407
+ def _set_weights(
408
+ instance, symbolic_weights, weight_values, name, skip_mismatch=False
409
+ ):
410
+ """Safely set weights into a model or a layer.
411
+
412
+ Args:
413
+ instance: Model or layer instance,
414
+ symbolic_weights: symbolic tensors representing
415
+ the weights of the variables to load,
416
+ weight_values: values of the weights to load,
417
+ skip_mismatch: Boolean, whether to skip loading of weights
418
+ where there is a mismatch in the shape of the weights,
419
+ name: name used to identify the group.
420
+
421
+ Raises:
422
+ ValueError: in case of mismatch between provided
423
+ model/layer and weights.
424
+ """
425
+ for i, weight_value in enumerate(weight_values):
426
+ expected_shape = symbolic_weights[i].shape
427
+ received_shape = weight_value.shape
428
+ if expected_shape != received_shape:
429
+ if skip_mismatch:
430
+ warnings.warn(
431
+ f"Skipping loading weights for {name}"
432
+ f"due to mismatch in shape for "
433
+ f"weight {symbolic_weights[i].path}. "
434
+ f"Weight expects shape {expected_shape}. "
435
+ "Received saved weight "
436
+ f"with shape {received_shape}",
437
+ stacklevel=2,
438
+ )
439
+ continue
440
+ raise ValueError(
441
+ f"Shape mismatch in {name}"
442
+ f"for weight {symbolic_weights[i].path}. "
443
+ f"Weight expects shape {expected_shape}. "
444
+ "Received saved weight "
445
+ f"with shape {received_shape}"
446
+ )
447
+ symbolic_weights[i].assign(weight_value)
448
+
449
+ if hasattr(instance, "finalize_state") and symbolic_weights:
450
+ instance.finalize_state()
451
+
452
+
453
def load_weights_from_hdf5_group_by_name(f, model, skip_mismatch=False):
    """Implements name-based weight loading (instead of topological loading).

    Layers that have no matching name are skipped.

    Args:
        f: A pointer to a HDF5 group.
        model: Model instance.
        skip_mismatch: Boolean, whether to skip loading of layers
            where there is a mismatch in the number of weights,
            or a mismatch in the shape of the weights.

    Raises:
        ValueError: in case of mismatch between provided layers
            and weights file and skip_match=False.
    """
    # Provenance metadata; read for parity with older formats, not used
    # further here.
    if "keras_version" in f.attrs:
        original_keras_version = f.attrs["keras_version"]
        if hasattr(original_keras_version, "decode"):
            original_keras_version = original_keras_version.decode("utf8")
    else:
        original_keras_version = "1"
    if "backend" in f.attrs:
        original_backend = f.attrs["backend"]
        if hasattr(original_backend, "decode"):
            original_backend = original_backend.decode("utf8")
    else:
        original_backend = None

    # New file format.
    layer_names = load_attributes_from_hdf5_group(f, "layer_names")

    # Reverse index of layer name to list of layers with name.
    layers_by_name = {}
    for layer in model.layers:
        if layer.name:
            layers_by_name.setdefault(layer.name, []).append(layer)

    for k, name in enumerate(layer_names):
        weight_values = load_subset_weights_from_hdf5_group(f[name])
        for layer in layers_by_name.get(name, []):
            symbolic_weights = _legacy_weights(layer)
            if len(weight_values) != len(symbolic_weights):
                if not skip_mismatch:
                    raise ValueError(
                        f"Weight count mismatch for layer #{k} "
                        f"(named {layer.name}). "
                        f"Layer expects {len(symbolic_weights)} weight(s). "
                        f"Received {len(weight_values)} saved weight(s)"
                    )
                warnings.warn(
                    f"Skipping loading of weights for layer #{k} (named "
                    f"{layer.name}) due to mismatch in number of weights. "
                    f"Layer expects {len(symbolic_weights)} weight(s). "
                    f"Received {len(weight_values)} saved weight(s)",
                    stacklevel=2,
                )
                continue
            # Set values.
            _set_weights(
                layer,
                symbolic_weights,
                weight_values,
                skip_mismatch=skip_mismatch,
                name=f"layer #{k} (named {layer.name})",
            )

    if "top_level_model_weights" in f:
        symbolic_weights = (
            model._trainable_variables + model._non_trainable_variables
        )
        weight_values = load_subset_weights_from_hdf5_group(
            f["top_level_model_weights"]
        )

        if len(weight_values) != len(symbolic_weights):
            if skip_mismatch:
                warnings.warn(
                    "Skipping loading top-level weights for model due to "
                    "mismatch in number of weights. "
                    f"Model expects {len(symbolic_weights)} "
                    "top-level weight(s). "
                    f"Received {len(weight_values)} saved top-level weight(s)",
                    stacklevel=2,
                )
            else:
                raise ValueError(
                    "Weight count mismatch for top-level weights of model. "
                    f"Model expects {len(symbolic_weights)} "
                    "top-level weight(s). "
                    f"Received {len(weight_values)} saved top-level weight(s)"
                )
        else:
            _set_weights(
                model,
                symbolic_weights,
                weight_values,
                skip_mismatch=skip_mismatch,
                name="top-level model",
            )
554
+
555
+
556
def load_subset_weights_from_hdf5_group(f):
    """Load layer weights of a model from hdf5.

    Args:
        f: A pointer to a HDF5 group.

    Returns:
        List of NumPy arrays of the weight values, in the order recorded
        by the group's `weight_names` attribute.
    """
    names = load_attributes_from_hdf5_group(f, "weight_names")
    return [np.asarray(f[weight_name]) for weight_name in names]
571
+
572
+
573
def load_optimizer_weights_from_hdf5_group(hdf5_group):
    """Load optimizer weights from a HDF5 group.

    Args:
        hdf5_group: A pointer to a HDF5 group.

    Returns:
        List of the optimizer weight datasets, ordered per the
        `weight_names` attribute of the `optimizer_weights` sub-group.
    """
    weights_group = hdf5_group["optimizer_weights"]
    names = load_attributes_from_hdf5_group(weights_group, "weight_names")
    return [weights_group[weight_name] for weight_name in names]
589
+
590
+
591
def load_attributes_from_hdf5_group(group, name):
    """Loads attributes of the specified name from the HDF5 group.

    Mirrors `save_attributes_to_hdf5_group`: attribute lists too large for
    a single HDF5 object header are stored as `name0`, `name1`, ... chunks.
    Bytes entries are decoded to str.

    Args:
        group: A pointer to a HDF5 group.
        name: A name of the attributes to load.

    Returns:
        data: Attributes data.
    """

    def _decode(item):
        return item.decode("utf8") if hasattr(item, "decode") else item

    if name in group.attrs:
        return [_decode(n) for n in group.attrs[name]]

    # Chunked layout: concatenate name0, name1, ... in order.
    data = []
    chunk_id = 0
    while f"{name}{chunk_id}" in group.attrs:
        data.extend(_decode(n) for n in group.attrs[f"{name}{chunk_id}"])
        chunk_id += 1
    return data
622
+
623
+
624
+ def _legacy_weights(layer):
625
+ """Legacy weight order converter.
626
+
627
+ For legacy reason, the layer.weights was in the order of
628
+ [self.trainable_weights + self.non_trainable_weights], and this order was
629
+ used for preserving the weights in h5 format. The new order of layer.weights
630
+ are the same as layer.get_weights() which is more intuitive for user. To
631
+ keep supporting the existing saved h5 file, this method should be used to
632
+ save/load weights.
633
+
634
+ Args:
635
+ layer: a `Model` or `Layer` instance.
636
+
637
+ Returns:
638
+ A list of variables with the legacy weight order.
639
+ """
640
+ return layer.trainable_weights + layer.non_trainable_weights
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/saving_options.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+
3
+ from keras.src.backend.common import global_state
4
+
5
+
6
@contextlib.contextmanager
def keras_option_scope(use_legacy_config=True):
    """Context manager toggling the global `use_legacy_config` flag.

    The previous value is restored on exit, even if the body raises.
    """
    previous = global_state.get_global_attribute("use_legacy_config", None)
    global_state.set_global_attribute("use_legacy_config", use_legacy_config)
    try:
        yield
    finally:
        global_state.set_global_attribute("use_legacy_config", previous)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/saving_utils.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import threading
3
+
4
+ from absl import logging
5
+
6
+ from keras.src import backend
7
+ from keras.src import layers
8
+ from keras.src import losses
9
+ from keras.src import metrics as metrics_module
10
+ from keras.src import models
11
+ from keras.src import optimizers
12
+ from keras.src import tree
13
+ from keras.src.legacy.saving import serialization
14
+ from keras.src.saving import object_registration
15
+
16
+ MODULE_OBJECTS = threading.local()
17
+
18
+ # Legacy lambda arguments not found in Keras 3
19
+ LAMBDA_DEP_ARGS = (
20
+ "module",
21
+ "function_type",
22
+ "output_shape_type",
23
+ "output_shape_module",
24
+ )
25
+
26
+
27
def model_from_config(config, custom_objects=None):
    """Instantiates a Keras model from its config.

    Args:
        config: Configuration dictionary.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.

    Returns:
        A Keras model instance (uncompiled).

    Raises:
        TypeError: if `config` is not a dictionary.
    """
    if isinstance(config, list):
        raise TypeError(
            "`model_from_config` expects a dictionary, not a list. "
            f"Received: config={config}. Did you meant to use "
            "`Sequential.from_config(config)`?"
        )

    global MODULE_OBJECTS

    # Lazily populate the per-thread registry of deserializable built-ins.
    if not hasattr(MODULE_OBJECTS, "ALL_OBJECTS"):
        MODULE_OBJECTS.ALL_OBJECTS = layers.__dict__
        MODULE_OBJECTS.ALL_OBJECTS["InputLayer"] = layers.InputLayer
        MODULE_OBJECTS.ALL_OBJECTS["Functional"] = models.Functional
        MODULE_OBJECTS.ALL_OBJECTS["Model"] = models.Model
        MODULE_OBJECTS.ALL_OBJECTS["Sequential"] = models.Sequential

    inner_config = config["config"]

    # Legacy `batch_input_shape` maps onto `batch_shape` for InputLayer and
    # `input_shape` for everything else.
    batch_input_shape = inner_config.pop("batch_input_shape", None)
    if batch_input_shape is not None:
        if config["class_name"] == "InputLayer":
            inner_config["batch_shape"] = batch_input_shape
        else:
            inner_config["input_shape"] = batch_input_shape

    # Single-element `axis` lists collapse to a scalar int.
    axis = inner_config.pop("axis", None)
    if axis is not None and isinstance(axis, list) and len(axis) == 1:
        inner_config["axis"] = int(axis[0])

    # Handle backwards compatibility for Keras lambdas
    if config["class_name"] == "Lambda":
        for dep_arg in LAMBDA_DEP_ARGS:
            inner_config.pop(dep_arg, None)
        function_config = inner_config["function"]
        if isinstance(function_config, list):
            inner_config["function"] = {
                "class_name": "__lambda__",
                "config": {
                    "code": function_config[0],
                    "defaults": function_config[1],
                    "closure": function_config[2],
                },
            }

    # TODO(nkovela): Swap find and replace args during Keras 3.0 release
    # Replace keras refs with keras. NOTE: currently a no-op find/replace,
    # but the JSON round-trip it performs also normalizes the config
    # (e.g. tuples become lists).
    config = _find_replace_nested_dict(config, "keras.", "keras.")

    return serialization.deserialize_keras_object(
        config,
        module_objects=MODULE_OBJECTS.ALL_OBJECTS,
        custom_objects=custom_objects,
        printable_module_name="layer",
    )
91
+
92
+
93
def model_metadata(model, include_optimizer=True, require_config=True):
    """Returns a dictionary containing the model metadata.

    The result carries the keras version, backend name, model config and —
    when the model is compiled and `include_optimizer` is True — its
    serialized training configuration.
    """
    from keras.src import __version__ as keras_version

    model_config = {"class_name": model.__class__.__name__}
    try:
        model_config["config"] = model.get_config()
    except NotImplementedError as e:
        # Subclassed models may not implement get_config(); this is only
        # fatal when the caller requires a config.
        if require_config:
            raise e

    metadata = {
        "keras_version": str(keras_version),
        "backend": backend.backend(),
        "model_config": model_config,
    }
    if getattr(model, "optimizer", False) and include_optimizer:
        if model.compiled:
            training_config = model._compile_config.config
            training_config.pop("optimizer", None)  # Handled separately.
            metadata["training_config"] = _serialize_nested_config(
                training_config
            )
            metadata["training_config"]["optimizer_config"] = {
                "class_name": object_registration.get_registered_name(
                    model.optimizer.__class__
                ),
                "config": model.optimizer.get_config(),
            }
    return metadata
+
125
+
126
def compile_args_from_training_config(training_config, custom_objects=None):
    """Return model.compile arguments from training config.

    Deserializes the optimizer, loss, metrics, weighted metrics and loss
    weights stored in a legacy training config, applying backwards-compat
    shims for identifiers saved by older Keras versions.
    """
    if custom_objects is None:
        custom_objects = {}

    with object_registration.CustomObjectScope(custom_objects):
        optimizer_config = training_config["optimizer_config"]
        optimizer = optimizers.deserialize(optimizer_config)
        # Ensure backwards compatibility for optimizers in legacy H5 files
        optimizer = _resolve_compile_arguments_compat(
            optimizer, optimizer_config, optimizers
        )

        # Recover losses.
        loss = None
        loss_config = training_config.get("loss", None)
        if loss_config is not None:
            loss = _deserialize_nested_config(losses.deserialize, loss_config)
            # Ensure backwards compatibility for losses in legacy H5 files
            loss = _resolve_compile_arguments_compat(loss, loss_config, losses)

        # Recover metrics.
        metrics = None
        metrics_config = training_config.get("metrics", None)
        if metrics_config is not None:
            metrics = _deserialize_nested_config(
                _deserialize_metric, metrics_config
            )
            # Ensure backwards compatibility for metrics in legacy H5 files
            metrics = _resolve_compile_arguments_compat(
                metrics, metrics_config, metrics_module
            )

        # Recover weighted metrics.
        weighted_metrics = None
        weighted_metrics_config = training_config.get("weighted_metrics", None)
        if weighted_metrics_config is not None:
            weighted_metrics = _deserialize_nested_config(
                _deserialize_metric, weighted_metrics_config
            )

        return {
            "optimizer": optimizer,
            "loss": loss,
            "metrics": metrics,
            "weighted_metrics": weighted_metrics,
            "loss_weights": training_config["loss_weights"],
        }
176
+
177
+
178
def _serialize_nested_config(config):
    """Serialized a nested structure of Keras objects."""

    def _serialize_fn(obj):
        # Callables (losses, metrics, custom fns) need full serialization;
        # everything else passes through untouched.
        return serialization.serialize_keras_object(obj) if callable(obj) else obj

    return tree.map_structure(_serialize_fn, config)
187
+
188
+
189
+ def _deserialize_nested_config(deserialize_fn, config):
190
+ """Deserializes arbitrary Keras `config` using `deserialize_fn`."""
191
+
192
+ def _is_single_object(obj):
193
+ if isinstance(obj, dict) and "class_name" in obj:
194
+ return True # Serialized Keras object.
195
+ if isinstance(obj, str):
196
+ return True # Serialized function or string.
197
+ return False
198
+
199
+ if config is None:
200
+ return None
201
+ if _is_single_object(config):
202
+ return deserialize_fn(config)
203
+ elif isinstance(config, dict):
204
+ return {
205
+ k: _deserialize_nested_config(deserialize_fn, v)
206
+ for k, v in config.items()
207
+ }
208
+ elif isinstance(config, (tuple, list)):
209
+ return [
210
+ _deserialize_nested_config(deserialize_fn, obj) for obj in config
211
+ ]
212
+
213
+ raise ValueError(
214
+ "Saved configuration not understood. Configuration should be a "
215
+ f"dictionary, string, tuple or list. Received: config={config}."
216
+ )
217
+
218
+
219
+ def _deserialize_metric(metric_config):
220
+ """Deserialize metrics, leaving special strings untouched."""
221
+ if metric_config in ["accuracy", "acc", "crossentropy", "ce"]:
222
+ # Do not deserialize accuracy and cross-entropy strings as we have
223
+ # special case handling for these in compile, based on model output
224
+ # shape.
225
+ return metric_config
226
+ return metrics_module.deserialize(metric_config)
227
+
228
+
229
+ def _find_replace_nested_dict(config, find, replace):
230
+ dict_str = json.dumps(config)
231
+ dict_str = dict_str.replace(find, replace)
232
+ config = json.loads(dict_str)
233
+ return config
234
+
235
+
236
+ def _resolve_compile_arguments_compat(obj, obj_config, module):
237
+ """Resolves backwards compatibility issues with training config arguments.
238
+
239
+ This helper function accepts built-in Keras modules such as optimizers,
240
+ losses, and metrics to ensure an object being deserialized is compatible
241
+ with Keras 3 built-ins. For legacy H5 files saved within Keras 3,
242
+ this does nothing.
243
+ """
244
+ if isinstance(obj, str) and obj not in module.ALL_OBJECTS_DICT:
245
+ obj = module.get(obj_config["config"]["name"])
246
+ return obj
247
+
248
+
249
def try_build_compiled_arguments(model):
    """Best-effort build of `model`'s compiled loss/metric containers.

    Failures are logged rather than raised: the model remains usable and the
    containers will be built lazily on first train/evaluate call.
    """
    try:
        if not model.compiled_loss.built:
            model.compiled_loss.build(model.outputs)
        if not model.compiled_metrics.built:
            model.compiled_metrics.build(model.outputs, model.outputs)
    except Exception:
        # Fix: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.
        logging.warning(
            "Compiled the loaded model, but the compiled metrics have "
            "yet to be built. `model.compile_metrics` will be empty "
            "until you train or evaluate the model."
        )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/legacy/saving/serialization.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Legacy serialization logic for Keras models."""
2
+
3
+ import contextlib
4
+ import inspect
5
+ import json
6
+ import threading
7
+ import weakref
8
+
9
+ # isort: off
10
+ from keras.src.api_export import keras_export
11
+ from keras.src.saving import object_registration
12
+
13
+ # Flag that determines whether to skip the NotImplementedError when calling
14
+ # get_config in custom models and layers. This is only enabled when saving to
15
+ # SavedModel, when the config isn't required.
16
+ _SKIP_FAILED_SERIALIZATION = False
17
+ # If a layer does not have a defined config, then the returned config will be a
18
+ # dictionary with the below key.
19
+ _LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config"
20
+
21
+ # Store a unique, per-object ID for shared objects.
22
+ #
23
+ # We store a unique ID for each object so that we may, at loading time,
24
+ # re-create the network properly. Without this ID, we would have no way of
25
+ # determining whether a config is a description of a new object that
26
+ # should be created or is merely a reference to an already-created object.
27
+ SHARED_OBJECT_KEY = "shared_object_id"
28
+
29
+ SHARED_OBJECT_DISABLED = threading.local()
30
+ SHARED_OBJECT_LOADING = threading.local()
31
+ SHARED_OBJECT_SAVING = threading.local()
32
+
33
+
34
+ # Attributes on the threadlocal variable must be set per-thread, thus we
35
+ # cannot initialize these globally. Instead, we have accessor functions with
36
+ # default values.
37
def _shared_object_disabled():
    """Thread-safely report whether shared-object handling is disabled."""
    # Thread-local attributes may not be set on this thread yet, in which
    # case handling is considered enabled.
    try:
        return SHARED_OBJECT_DISABLED.disabled
    except AttributeError:
        return False
40
+
41
+
42
def _shared_object_loading_scope():
    """Thread-safely fetch the active loading scope (no-op scope if none)."""
    # Thread-local attributes may not be set on this thread yet; fall back
    # to an inert scope so callers never need a None check.
    try:
        return SHARED_OBJECT_LOADING.scope
    except AttributeError:
        return NoopLoadingScope()
45
+
46
+
47
def _shared_object_saving_scope():
    """Thread-safely fetch the active saving scope, or `None` if not saving."""
    try:
        return SHARED_OBJECT_SAVING.scope
    except AttributeError:
        return None
50
+
51
+
52
class DisableSharedObjectScope:
    """A context manager for disabling handling of shared objects.

    Disables shared object handling for both saving and loading.

    Created primarily for use with `clone_model`, which does extra surgery that
    is incompatible with shared objects.
    """

    def __enter__(self):
        # Flip the thread-local kill switch, then remember whatever scopes
        # were active so __exit__ can restore them exactly.
        SHARED_OBJECT_DISABLED.disabled = True
        self._orig_loading_scope = _shared_object_loading_scope()
        self._orig_saving_scope = _shared_object_saving_scope()

    def __exit__(self, *args, **kwargs):
        # Re-enable handling and reinstate the scopes captured on entry.
        SHARED_OBJECT_DISABLED.disabled = False
        SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
        SHARED_OBJECT_SAVING.scope = self._orig_saving_scope
70
+
71
+
72
class NoopLoadingScope:
    """Default shared-object loading scope; it stores and returns nothing.

    Simplifies serialization code that has no interest in shared objects
    (e.g. when a single object is serialized on its own).
    """

    def get(self, unused_object_id):
        """Always report the object as not-yet-seen."""
        return None

    def set(self, object_id, obj):
        """Discard the mapping; no state is kept."""
        return None
84
+
85
+
86
class SharedObjectLoadingScope:
    """Context manager that tracks objects restored during deserialization.

    Configs may reference the same object from several layers. By caching
    each restored object under its shared-object ID, this scope lets the
    deserializer hand back the existing instance instead of cloning it, so
    the loaded network keeps its original sharing structure.
    """

    def __enter__(self):
        # When shared-object handling is globally disabled, hand the caller
        # an inert scope instead of registering this one.
        if _shared_object_disabled():
            return NoopLoadingScope()
        SHARED_OBJECT_LOADING.scope = self
        self._obj_ids_to_obj = {}
        return self

    def get(self, object_id):
        """Return the object previously stored under `object_id`, if any.

        Args:
            object_id: shared object ID to look up. `None` is accepted so
                calling code does not need its own None check.

        Returns:
            The cached object if this ID has been seen before, else `None`.
        """
        if object_id is None:
            return None
        return self._obj_ids_to_obj.get(object_id)

    def set(self, object_id, obj):
        """Cache `obj` under `object_id` for future `get` calls."""
        if object_id is not None:
            self._obj_ids_to_obj[object_id] = obj

    def __exit__(self, *args, **kwargs):
        # Restore the default no-op scope for this thread.
        SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
129
+
130
+
131
class SharedObjectConfig(dict):
    """A config dict that records how often its object is referenced.

    Configs referenced more than once automatically receive a shared-object
    ID entry, which lets the loader reconstruct the sharing relationship.

    Subclassing `dict` (rather than `collections.UserDict` or
    `collections.Mapping`) is deliberate: Python's json encoder only
    understands real dicts, and JSON encoding is central to serialization
    here. This is safe because no core dict method is overridden — only a
    reference-counting helper is added.
    """

    def __init__(self, base_config, object_id, **kwargs):
        self.ref_count = 1
        self.object_id = object_id
        super().__init__(base_config, **kwargs)

    def increment_ref_count(self):
        # Attach the shared-object ID the moment the config is seen a second
        # time. Attaching lazily keeps singly-referenced configs unchanged,
        # minimizing backwards-compatibility risk.
        if self.ref_count == 1:
            self[SHARED_OBJECT_KEY] = self.object_id
        self.ref_count += 1
161
+
162
+
163
class SharedObjectSavingScope:
    """Keeps track of shared object configs when serializing.

    While active, the config produced for each object is cached (keyed
    weakly on the object) so that a second serialization of the same object
    re-uses the config and tags it with a shared-object ID.
    """

    def __enter__(self):
        # Honor the global kill switch: callers get no scope at all.
        if _shared_object_disabled():
            return None

        global SHARED_OBJECT_SAVING

        # Serialization can happen at a number of layers for a number of
        # reasons. We may end up with a case where we're opening a saving scope
        # within another saving scope. In that case, we'd like to use the
        # outermost scope available and ignore inner scopes, since there is not
        # (yet) a reasonable use case for having these nested and distinct.
        if _shared_object_saving_scope() is not None:
            self._passthrough = True
            return _shared_object_saving_scope()
        else:
            self._passthrough = False

        SHARED_OBJECT_SAVING.scope = self
        # Weak keys: caching a config must not keep the object alive.
        self._shared_objects_config = weakref.WeakKeyDictionary()
        self._next_id = 0
        return self

    def get_config(self, obj):
        """Gets a `SharedObjectConfig` if one has already been seen for `obj`.

        Also increments the config's reference count, which is what attaches
        the shared-object ID on the second sighting.

        Args:
            obj: The object for which to retrieve the `SharedObjectConfig`.

        Returns:
            The SharedObjectConfig for a given object, if already seen. Else,
            `None`.
        """
        try:
            shared_object_config = self._shared_objects_config[obj]
        except (TypeError, KeyError):
            # If the object is unhashable (e.g. a subclass of
            # `AbstractBaseClass` that has not overridden `__hash__`), a
            # `TypeError` will be thrown. We'll just continue on without shared
            # object support.
            return None
        shared_object_config.increment_ref_count()
        return shared_object_config

    def create_config(self, base_config, obj):
        """Create a new SharedObjectConfig for a given object."""
        shared_object_config = SharedObjectConfig(base_config, self._next_id)
        self._next_id += 1
        try:
            self._shared_objects_config[obj] = shared_object_config
        except TypeError:
            # If the object is unhashable (e.g. a subclass of
            # `AbstractBaseClass` that has not overridden `__hash__`), a
            # `TypeError` will be thrown. We'll just continue on without shared
            # object support.
            pass
        return shared_object_config

    def __exit__(self, *args, **kwargs):
        # Inner (passthrough) scopes must not tear down the outer scope.
        if not getattr(self, "_passthrough", False):
            global SHARED_OBJECT_SAVING
            SHARED_OBJECT_SAVING.scope = None
227
+
228
+
229
def serialize_keras_class_and_config(
    cls_name, cls_config, obj=None, shared_object_id=None
):
    """Return the `{"class_name": ..., "config": ...}` serialization form.

    Args:
        cls_name: Registered name of the class being serialized.
        cls_config: The object's config dictionary.
        obj: The object itself; enables shared-object tracking when a
            saving scope is active.
        shared_object_id: Shared object ID to retain. Some branches of the
            load path call this helper and already have one.

    Returns:
        A JSON-compatible dict — possibly a `SharedObjectConfig` that tags
        re-used objects with a shared-object ID.
    """
    base_config = {"class_name": cls_name, "config": cls_config}
    if shared_object_id is not None:
        base_config[SHARED_OBJECT_KEY] = shared_object_id

    # Within an active saving scope, a config already emitted for `obj` is
    # re-used (with an incremented ref count) so the shared-object
    # relationship survives a save/load cycle.
    scope = _shared_object_saving_scope()
    if scope is None or obj is None:
        return base_config
    previously_seen = scope.get_config(obj)
    if previously_seen is not None:
        return previously_seen
    return scope.create_config(base_config, obj)
252
+
253
+
254
@contextlib.contextmanager
def skip_failed_serialization():
    """Context manager: tolerate `get_config` raising NotImplementedError.

    While active, objects whose `get_config` is not implemented are
    serialized with a placeholder config instead of propagating the error.
    The previous flag value is always restored, even on exception.
    """
    global _SKIP_FAILED_SERIALIZATION
    saved_flag = _SKIP_FAILED_SERIALIZATION
    _SKIP_FAILED_SERIALIZATION = True
    try:
        yield
    finally:
        _SKIP_FAILED_SERIALIZATION = saved_flag
263
+
264
+
265
@keras_export(
    [
        "keras.legacy.saving.serialize_keras_object",
        "keras.utils.legacy.serialize_keras_object",
    ]
)
def serialize_keras_object(instance):
    """Serialize a Keras object into a JSON-compatible representation.

    Calls to `serialize_keras_object` while underneath the
    `SharedObjectSavingScope` context manager will cause any objects re-used
    across multiple layers to be saved with a special shared object ID. This
    allows the network to be re-created properly during deserialization.

    Args:
        instance: The object to serialize.

    Returns:
        A dict-like, JSON-compatible representation of the object's config.

    Raises:
        ValueError: If the object has no `get_config()` and no `__name__`.
        NotImplementedError: Re-raised from `get_config()` unless inside
            `skip_failed_serialization()`.
    """
    # Unwrap decorated callables so the underlying object is serialized.
    instance = inspect.unwrap(instance)
    if instance is None:
        return None

    if hasattr(instance, "get_config"):
        name = object_registration.get_registered_name(instance.__class__)
        try:
            config = instance.get_config()
        except NotImplementedError:
            if _SKIP_FAILED_SERIALIZATION:
                # Inside `skip_failed_serialization()`: emit a placeholder
                # config instead of propagating the error.
                return serialize_keras_class_and_config(
                    name, {_LAYER_UNDEFINED_CONFIG_KEY: True}
                )
            # Bare `raise` preserves the original traceback.
            raise
        serialization_config = {}
        for key, item in config.items():
            if isinstance(item, str):
                serialization_config[key] = item
                continue

            # Any object of a different type needs to be converted to string or
            # dict for serialization (e.g. custom functions, custom classes)
            try:
                serialized_item = serialize_keras_object(item)
                if isinstance(serialized_item, dict) and not isinstance(
                    item, dict
                ):
                    # Mark nested configs produced here so the deserializer
                    # knows to recurse into them.
                    serialized_item["__passive_serialization__"] = True
                serialization_config[key] = serialized_item
            except ValueError:
                # Not serializable as a Keras object; keep the raw value.
                serialization_config[key] = item

        # `name` was already computed above; the original code looked it up a
        # second time redundantly.
        return serialize_keras_class_and_config(
            name, serialization_config, instance
        )
    if hasattr(instance, "__name__"):
        return object_registration.get_registered_name(instance)
    raise ValueError(
        f"Cannot serialize {instance} because it doesn't implement "
        "`get_config()`."
    )
329
+
330
+
331
def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
):
    """Returns the class name and config for a serialized keras object.

    Args:
        config: Serialized form — a dict with `class_name` and `config` keys.
        module_objects: Dict of built-in objects to look class names up in.
        custom_objects: Dict of user-provided objects to look names up in.
        printable_module_name: Human-readable object kind, used in errors.

    Returns:
        A `(cls, cls_config)` tuple. `cls_config` has nested serialized items
        and custom-function names resolved in place (unless it is a legacy
        list-form Sequential config, which is returned untouched).

    Raises:
        ValueError: If `config` is malformed or the class name is unknown.
    """
    if (
        not isinstance(config, dict)
        or "class_name" not in config
        or "config" not in config
    ):
        raise ValueError(
            f"Improper config format for {config}. "
            "Expecting python dict contains `class_name` and `config` as keys"
        )

    class_name = config["class_name"]
    cls = object_registration.get_registered_object(
        class_name, custom_objects, module_objects
    )
    if cls is None:
        raise ValueError(
            f"Unknown {printable_module_name}: '{class_name}'. "
            "Please ensure you are using a `keras.utils.custom_object_scope` "
            "and that this object is included in the scope. See "
            "https://www.tensorflow.org/guide/keras/save_and_serialize"
            "#registering_the_custom_object for details."
        )

    cls_config = config["config"]
    # Check if `cls_config` is a list. If it is a list, return the class and
    # the associated class configs for recursive deserialization. This happens
    # with old Sequential models (e.g. `keras_version` == "2.0.6"), which were
    # serialized in a different structure, for example
    # "{'class_name': 'Sequential',
    #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
    if isinstance(cls_config, list):
        return (cls, cls_config)

    deserialized_objects = {}
    for key, item in cls_config.items():
        if key == "name":
            # Assume that the value of 'name' is a string that should not be
            # deserialized as a function. This avoids the corner case where
            # cls_config['name'] has an identical name to a custom function
            # and gets converted into that function.
            deserialized_objects[key] = item
        elif isinstance(item, dict) and "__passive_serialization__" in item:
            # Nested configs produced by `serialize_keras_object` are
            # recursively deserialized.
            deserialized_objects[key] = deserialize_keras_object(
                item,
                module_objects=module_objects,
                custom_objects=custom_objects,
                printable_module_name="config_item",
            )
        # TODO(momernick): Should this also have 'module_objects'?
        elif isinstance(item, str) and inspect.isfunction(
            object_registration.get_registered_object(item, custom_objects)
        ):
            # Handle custom functions here. When saving functions, we only
            # save the function's name as a string. If we find a matching
            # string in the custom objects during deserialization, we convert
            # the string back to the original function.
            # Note that a potential issue is that a string field could have a
            # naming conflict with a custom function name, but this should be
            # a rare case. This issue does not occur if a string field has a
            # naming conflict with a custom object, since the config of an
            # object will always be a dict.
            deserialized_objects[key] = (
                object_registration.get_registered_object(item, custom_objects)
            )
    # Write the resolved values back into the config in place. (The original
    # loop re-indexed `deserialized_objects` per key; `update` is equivalent.)
    cls_config.update(deserialized_objects)

    return (cls, cls_config)
407
+
408
+
409
@keras_export(
    [
        "keras.legacy.saving.deserialize_keras_object",
        "keras.utils.legacy.deserialize_keras_object",
    ]
)
def deserialize_keras_object(
    identifier,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
        identifier: the serialized form of the object.
        module_objects: A dictionary of built-in objects to look the name up in.
            Generally, `module_objects` is provided by midlevel library
            implementers.
        custom_objects: A dictionary of custom objects to look the name up in.
            Generally, `custom_objects` is provided by the end user.
        printable_module_name: A human-readable string representing the type of
            the object. Printed in case of exception.

    Returns:
        The deserialized object.

    Raises:
        ValueError: If the identifier is a malformed dict, names an unknown
            object, or is of an uninterpretable type.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
        return deserialize_keras_object(
            config,
            module_objects=globals(),
            custom_objects=custom_objects,
            printable_module_name="MyObjectType",
        )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """

    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name
        )

        # If this object has already been loaded (i.e. it's shared between
        # multiple objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)
        if shared_object is not None:
            return shared_object

        if hasattr(cls, "from_config"):
            arg_spec = inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            # TODO(nkovela): Swap find and replace args during Keras 3.0 release
            # Replace keras refs with keras
            # NOTE(review): `find` equals `replace` here, so today this call is
            # effectively a JSON round-trip of the config (which also converts
            # tuples to lists) — confirm that side effect is still wanted
            # before changing the arguments.
            cls_config = _find_replace_nested_dict(
                cls_config, "keras.", "keras."
            )

            if "custom_objects" in arg_spec.args:
                # `from_config` accepts custom objects directly; merge the
                # global registry with the caller-supplied dict.
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects={
                        **object_registration.GLOBAL_CUSTOM_OBJECTS,
                        **custom_objects,
                    },
                )
            else:
                # Otherwise, expose custom objects via the scope mechanism.
                with object_registration.CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with object_registration.CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

        return deserialized_obj

    elif isinstance(identifier, str):
        # String identifier: resolve by name, preferring caller-supplied
        # custom objects, then thread-local and global registries, and
        # finally the built-in `module_objects`.
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif (
            object_name
            in object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
        ):
            obj = object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__[
                object_name
            ]
        elif object_name in object_registration._GLOBAL_CUSTOM_OBJECTS:
            obj = object_registration._GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            obj = module_objects.get(object_name)
        if obj is None:
            raise ValueError(
                f"Unknown {printable_module_name}: '{object_name}'. "
                "Please ensure you are using a "
                "`keras.utils.custom_object_scope` "
                "and that this object is included in the scope. See "
                "https://www.tensorflow.org/guide/keras/save_and_serialize"
                "#registering_the_custom_object for details."
            )

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if inspect.isclass(obj):
            return obj()
        return obj
    elif inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError(
            "Could not interpret serialized "
            f"{printable_module_name}: {identifier}"
        )
556
+
557
+
558
def validate_config(config):
    """Determines whether config appears to be a valid layer config.

    A valid config is a dict that does not carry the marker written for
    layers saved without a defined `get_config`.
    """
    if not isinstance(config, dict):
        return False
    return _LAYER_UNDEFINED_CONFIG_KEY not in config
563
+
564
+
565
def is_default(method):
    """Check if a method is decorated with the `default` wrapper."""
    try:
        return method._is_default
    except AttributeError:
        return False
568
+
569
+
570
+ def _find_replace_nested_dict(config, find, replace):
571
+ dict_str = json.dumps(config)
572
+ dict_str = dict_str.replace(find, replace)
573
+ config = json.loads(dict_str)
574
+ return config
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__init__.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+
3
+ from keras.src.api_export import keras_export
4
+ from keras.src.losses.loss import Loss
5
+ from keras.src.losses.losses import CTC
6
+ from keras.src.losses.losses import BinaryCrossentropy
7
+ from keras.src.losses.losses import BinaryFocalCrossentropy
8
+ from keras.src.losses.losses import CategoricalCrossentropy
9
+ from keras.src.losses.losses import CategoricalFocalCrossentropy
10
+ from keras.src.losses.losses import CategoricalHinge
11
+ from keras.src.losses.losses import Circle
12
+ from keras.src.losses.losses import CosineSimilarity
13
+ from keras.src.losses.losses import Dice
14
+ from keras.src.losses.losses import Hinge
15
+ from keras.src.losses.losses import Huber
16
+ from keras.src.losses.losses import KLDivergence
17
+ from keras.src.losses.losses import LogCosh
18
+ from keras.src.losses.losses import LossFunctionWrapper
19
+ from keras.src.losses.losses import MeanAbsoluteError
20
+ from keras.src.losses.losses import MeanAbsolutePercentageError
21
+ from keras.src.losses.losses import MeanSquaredError
22
+ from keras.src.losses.losses import MeanSquaredLogarithmicError
23
+ from keras.src.losses.losses import Poisson
24
+ from keras.src.losses.losses import SparseCategoricalCrossentropy
25
+ from keras.src.losses.losses import SquaredHinge
26
+ from keras.src.losses.losses import Tversky
27
+ from keras.src.losses.losses import binary_crossentropy
28
+ from keras.src.losses.losses import binary_focal_crossentropy
29
+ from keras.src.losses.losses import categorical_crossentropy
30
+ from keras.src.losses.losses import categorical_focal_crossentropy
31
+ from keras.src.losses.losses import categorical_hinge
32
+ from keras.src.losses.losses import circle
33
+ from keras.src.losses.losses import cosine_similarity
34
+ from keras.src.losses.losses import ctc
35
+ from keras.src.losses.losses import dice
36
+ from keras.src.losses.losses import hinge
37
+ from keras.src.losses.losses import huber
38
+ from keras.src.losses.losses import kl_divergence
39
+ from keras.src.losses.losses import log_cosh
40
+ from keras.src.losses.losses import mean_absolute_error
41
+ from keras.src.losses.losses import mean_absolute_percentage_error
42
+ from keras.src.losses.losses import mean_squared_error
43
+ from keras.src.losses.losses import mean_squared_logarithmic_error
44
+ from keras.src.losses.losses import poisson
45
+ from keras.src.losses.losses import sparse_categorical_crossentropy
46
+ from keras.src.losses.losses import squared_hinge
47
+ from keras.src.losses.losses import tversky
48
+ from keras.src.saving import serialization_lib
49
+
50
+ ALL_OBJECTS = {
51
+ # Base
52
+ Loss,
53
+ LossFunctionWrapper,
54
+ # Probabilistic
55
+ KLDivergence,
56
+ Poisson,
57
+ BinaryCrossentropy,
58
+ BinaryFocalCrossentropy,
59
+ CategoricalCrossentropy,
60
+ CategoricalFocalCrossentropy,
61
+ SparseCategoricalCrossentropy,
62
+ # Regression
63
+ MeanSquaredError,
64
+ MeanAbsoluteError,
65
+ MeanAbsolutePercentageError,
66
+ MeanSquaredLogarithmicError,
67
+ CosineSimilarity,
68
+ LogCosh,
69
+ Huber,
70
+ # Hinge
71
+ Hinge,
72
+ SquaredHinge,
73
+ CategoricalHinge,
74
+ # Image segmentation
75
+ Dice,
76
+ Tversky,
77
+ # Similarity
78
+ Circle,
79
+ # Sequence
80
+ CTC,
81
+ # Probabilistic
82
+ kl_divergence,
83
+ poisson,
84
+ binary_crossentropy,
85
+ binary_focal_crossentropy,
86
+ categorical_crossentropy,
87
+ categorical_focal_crossentropy,
88
+ sparse_categorical_crossentropy,
89
+ # Regression
90
+ mean_squared_error,
91
+ mean_absolute_error,
92
+ mean_absolute_percentage_error,
93
+ mean_squared_logarithmic_error,
94
+ cosine_similarity,
95
+ log_cosh,
96
+ huber,
97
+ # Hinge
98
+ hinge,
99
+ squared_hinge,
100
+ categorical_hinge,
101
+ # Image segmentation
102
+ dice,
103
+ tversky,
104
+ # Similarity
105
+ circle,
106
+ # Sequence
107
+ ctc,
108
+ }
109
+
110
+ ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
111
+ ALL_OBJECTS_DICT.update(
112
+ {
113
+ "bce": binary_crossentropy,
114
+ "BCE": binary_crossentropy,
115
+ "kld": kl_divergence,
116
+ "KLD": kl_divergence,
117
+ "mae": mean_absolute_error,
118
+ "MAE": mean_absolute_error,
119
+ "mse": mean_squared_error,
120
+ "MSE": mean_squared_error,
121
+ "mape": mean_absolute_percentage_error,
122
+ "MAPE": mean_absolute_percentage_error,
123
+ "msle": mean_squared_logarithmic_error,
124
+ "MSLE": mean_squared_logarithmic_error,
125
+ }
126
+ )
127
+
128
+
129
+ @keras_export("keras.losses.serialize")
130
+ def serialize(loss):
131
+ """Serializes loss function or `Loss` instance.
132
+
133
+ Args:
134
+ loss: A Keras `Loss` instance or a loss function.
135
+
136
+ Returns:
137
+ Loss configuration dictionary.
138
+ """
139
+ return serialization_lib.serialize_keras_object(loss)
140
+
141
+
142
+ @keras_export("keras.losses.deserialize")
143
+ def deserialize(name, custom_objects=None):
144
+ """Deserializes a serialized loss class/function instance.
145
+
146
+ Args:
147
+ name: Loss configuration.
148
+ custom_objects: Optional dictionary mapping names (strings) to custom
149
+ objects (classes and functions) to be considered during
150
+ deserialization.
151
+
152
+ Returns:
153
+ A Keras `Loss` instance or a loss function.
154
+ """
155
+ return serialization_lib.deserialize_keras_object(
156
+ name,
157
+ module_objects=ALL_OBJECTS_DICT,
158
+ custom_objects=custom_objects,
159
+ )
160
+
161
+
162
+ @keras_export("keras.losses.get")
163
+ def get(identifier):
164
+ """Retrieves a Keras loss as a `function`/`Loss` class instance.
165
+
166
+ The `identifier` may be the string name of a loss function or `Loss` class.
167
+
168
+ >>> loss = losses.get("categorical_crossentropy")
169
+ >>> type(loss)
170
+ <class 'function'>
171
+ >>> loss = losses.get("CategoricalCrossentropy")
172
+ >>> type(loss)
173
+ <class '...CategoricalCrossentropy'>
174
+
175
+ You can also specify `config` of the loss to this function by passing dict
176
+ containing `class_name` and `config` as an identifier. Also note that the
177
+ `class_name` must map to a `Loss` class
178
+
179
+ >>> identifier = {"class_name": "CategoricalCrossentropy",
180
+ ... "config": {"from_logits": True}}
181
+ >>> loss = losses.get(identifier)
182
+ >>> type(loss)
183
+ <class '...CategoricalCrossentropy'>
184
+
185
+ Args:
186
+ identifier: A loss identifier. One of None or string name of a loss
187
+ function/class or loss configuration dictionary or a loss function
188
+ or a loss class instance.
189
+
190
+ Returns:
191
+ A Keras loss as a `function`/ `Loss` class instance.
192
+ """
193
+ if identifier is None:
194
+ return None
195
+ if isinstance(identifier, dict):
196
+ obj = deserialize(identifier)
197
+ elif isinstance(identifier, str):
198
+ obj = ALL_OBJECTS_DICT.get(identifier, None)
199
+ else:
200
+ obj = identifier
201
+
202
+ if callable(obj):
203
+ if inspect.isclass(obj):
204
+ obj = obj()
205
+ return obj
206
+ else:
207
+ raise ValueError(f"Could not interpret loss identifier: {identifier}")
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (5.08 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/loss.cpython-310.pyc ADDED
Binary file (6.95 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/__pycache__/losses.cpython-310.pyc ADDED
Binary file (87.3 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/loss.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import backend
2
+ from keras.src import dtype_policies
3
+ from keras.src import ops
4
+ from keras.src import tree
5
+ from keras.src.api_export import keras_export
6
+ from keras.src.saving.keras_saveable import KerasSaveable
7
+ from keras.src.utils.naming import auto_name
8
+
9
+
10
@keras_export(["keras.Loss", "keras.losses.Loss"])
class Loss(KerasSaveable):
    """Loss base class.

    This is the class to subclass in order to create new custom losses.

    Args:
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.

    To be implemented by subclasses:

    * `call()`: Contains the logic for loss calculation using `y_true`,
        `y_pred`.

    Example subclass implementation:

    ```python
    class MeanSquaredError(Loss):
        def call(self, y_true, y_pred):
            return ops.mean(ops.square(y_pred - y_true), axis=-1)
    ```
    """

    def __init__(self, name=None, reduction="sum_over_batch_size", dtype=None):
        # Derive a unique default name from the class name when none is given.
        self.name = name or auto_name(self.__class__.__name__)
        # Validates the reduction argument (raises ValueError on bad values).
        self.reduction = standardize_reduction(reduction)
        # Resolve `dtype` (a dtype string or a DTypePolicy) to a policy; all
        # tensor conversions in `__call__` use its `compute_dtype`.
        self._dtype_policy = dtype_policies.get(dtype or backend.floatx())
        self._dtype = self._dtype_policy.compute_dtype

    @property
    def dtype(self):
        # Read-only compute dtype resolved from the dtype policy in __init__.
        return self._dtype

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Compute the (reduced, weighted) loss for a batch.

        Converts inputs to tensors of `self.dtype`, delegates the elementwise
        math to `call()`, then merges any Keras masks into the sample weights
        and applies `self.reduction`.
        """
        # Grab any Keras mask attached to the predictions BEFORE conversion,
        # since `convert_to_tensor` would drop the mask attribute.
        in_mask = backend.get_keras_mask(y_pred)

        with ops.name_scope(self.name):
            # `y_true`/`y_pred` may be arbitrary nested structures; convert
            # every leaf to a tensor of the loss's compute dtype.
            y_pred = tree.map_structure(
                lambda x: ops.convert_to_tensor(x, dtype=self.dtype), y_pred
            )
            y_true = tree.map_structure(
                lambda x: ops.convert_to_tensor(x, dtype=self.dtype), y_true
            )

            losses = self.call(y_true, y_pred)
            # `call()` itself may attach a mask to its output.
            out_mask = backend.get_keras_mask(losses)

            # Combine input and output masks: both present -> logical AND;
            # one present -> use it; neither -> no masking.
            if in_mask is not None and out_mask is not None:
                mask = in_mask & out_mask
            elif in_mask is not None:
                mask = in_mask
            elif out_mask is not None:
                mask = out_mask
            else:
                mask = None

            # Fold mask + sample_weight into the losses and reduce per
            # `self.reduction`.
            return reduce_weighted_values(
                losses,
                sample_weight=sample_weight,
                mask=mask,
                reduction=self.reduction,
                dtype=self.dtype,
            )

    def call(self, y_true, y_pred):
        # Subclasses must implement the elementwise loss computation.
        raise NotImplementedError

    def get_config(self):
        # NOTE: `dtype` is intentionally not serialized here; only name and
        # reduction round-trip through the config.
        return {"name": self.name, "reduction": self.reduction}

    @classmethod
    def from_config(cls, config):
        # Inverse of `get_config`: rebuild the instance from its config dict.
        return cls(**config)

    def _obj_type(self):
        # Type tag used by the KerasSaveable serialization machinery.
        return "Loss"
99
+
100
+
101
def standardize_reduction(reduction):
    """Validate a loss `reduction` argument and return it unchanged.

    Raises:
        ValueError: If `reduction` is not one of the supported options.
    """
    allowed = {
        "sum_over_batch_size",
        "sum",
        None,
        "none",
        "mean",
        "mean_with_sample_weight",
    }
    # Guard clause: accept known values as-is, reject anything else.
    if reduction in allowed:
        return reduction
    raise ValueError(
        "Invalid value for argument `reduction`. "
        f"Expected one of {allowed}. Received: "
        f"reduction={reduction}"
    )
117
+
118
+
119
def squeeze_or_expand_to_same_rank(x1, x2, expand_rank_1=True):
    """Squeeze/expand last dim if ranks differ from expected by exactly 1."""
    rank1 = len(x1.shape)
    rank2 = len(x2.shape)
    if rank1 == rank2:
        return x1, x2
    # `x1` has one extra trailing axis of size 1: either lift `x2` to match
    # (rank-1 case with expand_rank_1) or drop the extra axis from `x1`.
    if rank1 == rank2 + 1 and x1.shape[-1] == 1:
        if rank2 == 1 and expand_rank_1:
            x2 = ops.expand_dims(x2, axis=-1)
        else:
            x1 = ops.squeeze(x1, axis=-1)
    # Symmetric case: `x2` has the extra trailing axis of size 1.
    if rank2 == rank1 + 1 and x2.shape[-1] == 1:
        if rank1 == 1 and expand_rank_1:
            x1 = ops.expand_dims(x1, axis=-1)
        else:
            x2 = ops.squeeze(x2, axis=-1)
    return x1, x2
138
+
139
+
140
def reduce_values(values, sample_weight=None, reduction="sum_over_batch_size"):
    """Aggregate elementwise loss values according to `reduction`.

    Scalar or empty inputs, and `None`/`"none"` reductions, are returned
    untouched; otherwise values are summed and, for mean-style reductions,
    divided by the appropriate divisor.
    """
    # No-op cases: no reduction requested, or nothing to reduce
    # (scalar input or empty batch).
    skip = reduction is None or reduction == "none"
    if skip or tuple(values.shape) in ((), (0,)):
        return values

    loss = ops.sum(values)
    if reduction in ("sum_over_batch_size", "mean", "mean_with_sample_weight"):
        if reduction == "mean_with_sample_weight" and sample_weight is not None:
            # Normalize by the total sample weight.
            divisor = ops.cast(ops.sum(sample_weight), loss.dtype)
        else:
            # Normalize by the total number of elements.
            num_elements = ops.prod(
                ops.convert_to_tensor(ops.shape(values), dtype="int32")
            )
            divisor = ops.cast(num_elements, loss.dtype)
        # divide_no_nan keeps the loss finite when the divisor is 0.
        loss = ops.divide_no_nan(loss, divisor)
        loss = scale_loss_for_distribution(loss)
    return loss
162
+
163
+
164
def reduce_weighted_values(
    values,
    sample_weight=None,
    mask=None,
    reduction="sum_over_batch_size",
    dtype=None,
):
    """Apply mask and sample weights to loss values, then reduce them.

    Args:
        values: Elementwise loss values.
        sample_weight: Optional per-sample weights.
        mask: Optional boolean-ish mask marking valid entries.
        reduction: One of the reductions accepted by
            `standardize_reduction`.
        dtype: Dtype used to convert the inputs.

    Returns:
        The reduced loss.
    """
    reduction = standardize_reduction(reduction)

    def _as_tensor(x):
        # Convert optional inputs, preserving None.
        return None if x is None else ops.convert_to_tensor(x, dtype=dtype)

    values = ops.convert_to_tensor(values, dtype=dtype)
    sample_weight = _as_tensor(sample_weight)
    mask = _as_tensor(mask)

    # Merge mask and sample weight into a single sample weight.
    sample_weight = apply_mask(
        sample_weight, mask, dtype=values.dtype, reduction=reduction
    )

    if sample_weight is not None:
        sample_weight = ops.cast(sample_weight, values.dtype)
        # Align ranks so the elementwise multiply broadcasts correctly.
        values, sample_weight = squeeze_or_expand_to_same_rank(
            values, sample_weight
        )
        values = values * sample_weight

    # Reduce the individual weighted losses.
    return reduce_values(values, sample_weight, reduction)
195
+
196
+
197
def apply_mask(sample_weight, mask, dtype, reduction):
    """Applies any mask on predictions to sample weights."""
    if mask is None:
        # Nothing to merge; pass the weights through unchanged.
        return sample_weight

    mask = ops.cast(mask, dtype=dtype)
    if reduction in ("mean", "sum_over_batch_size"):
        # Valid entries have weight `total/valid`, while invalid ones
        # have 0. When summed over batch, they will be reduced to:
        #
        # mean(loss * sample_weight * total / valid)
        # = sum(loss * sample_weight * total / valid) / total
        # = sum(loss * sample_weight) / total * total / valid
        # = sum(loss * sample_weight) / valid
        total = ops.cast(
            ops.prod(ops.convert_to_tensor(ops.shape(mask), dtype="int32")),
            dtype,
        )
        valid = ops.sum(mask)  # May be 0!
        mask = mask * (total / (valid + backend.epsilon()))

    if sample_weight is None:
        # The (possibly rescaled) mask becomes the sample weight.
        return mask

    sample_weight = ops.cast(sample_weight, dtype=dtype)
    # Align ranks before multiplying mask into the weights.
    mask, sample_weight = squeeze_or_expand_to_same_rank(mask, sample_weight)
    return sample_weight * mask
225
+
226
+
227
def scale_loss_for_distribution(value):
    """Scales the given value by the number of replicas in the strategy.

    Currently, this function is only effective when using the tensorflow
    backend and `tf.distribute`.
    """
    if backend.backend() != "tensorflow":
        # Non-TF backends have no replica context; no scaling applies.
        return value

    import tensorflow as tf

    replicas = tf.distribute.get_strategy().num_replicas_in_sync
    if replicas > 1:
        # Divide by the replica count so per-replica losses sum to the
        # global loss.
        value = ops.multiply(value, ops.cast(1.0 / replicas, value.dtype))
    return value
242
+
243
+
244
def unscale_loss_for_distribution(value):
    """Unscales the given value by the number of replicas in the strategy.

    Currently, this function is only effective when using the tensorflow
    backend and `tf.distribute`.
    """
    if backend.backend() != "tensorflow":
        # Non-TF backends have no replica context; no unscaling applies.
        return value

    import tensorflow as tf

    replicas = tf.distribute.get_strategy().num_replicas_in_sync
    if replicas > 1:
        # Multiply back by the replica count (inverse of the scaling done
        # in `scale_loss_for_distribution`).
        value = ops.multiply(value, ops.cast(replicas, value.dtype))
    return value
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/losses/losses.py ADDED
@@ -0,0 +1,2599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ from keras.src import backend
4
+ from keras.src import ops
5
+ from keras.src import tree
6
+ from keras.src.api_export import keras_export
7
+ from keras.src.losses.loss import Loss
8
+ from keras.src.losses.loss import squeeze_or_expand_to_same_rank
9
+ from keras.src.saving import serialization_lib
10
+ from keras.src.utils.numerical_utils import build_pos_neg_masks
11
+ from keras.src.utils.numerical_utils import normalize
12
+
13
+
14
class LossFunctionWrapper(Loss):
    """Wraps a stateless loss function into a `Loss` instance.

    Args:
        fn: The loss function to wrap, called as
            `fn(y_true, y_pred, **kwargs)`.
        reduction: Reduction to apply, forwarded to `Loss`.
        name: Optional name, forwarded to `Loss`.
        dtype: Compute dtype, forwarded to `Loss`.
        **kwargs: Extra keyword arguments passed to `fn` on every call.
    """

    def __init__(
        self,
        fn,
        reduction="sum_over_batch_size",
        name=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(name=name, reduction=reduction, dtype=dtype)
        self.fn = fn
        # Keyword arguments forwarded verbatim to `fn` in `call()`.
        self._fn_kwargs = kwargs

    def call(self, y_true, y_pred):
        # Align the ranks of each corresponding (y_true, y_pred) leaf pair
        # before applying `fn`. `squeeze_or_expand_to_same_rank` returns a
        # (y_true, y_pred) tuple per leaf...
        y_true_y_pred = tree.map_structure(
            squeeze_or_expand_to_same_rank, y_true, y_pred
        )
        # ...so unzip the structure of tuples back into two structures.
        y_true = tree.map_structure_up_to(y_true, lambda x: x[0], y_true_y_pred)
        y_pred = tree.map_structure_up_to(y_pred, lambda x: x[1], y_true_y_pred)
        return self.fn(y_true, y_pred, **self._fn_kwargs)

    def get_config(self):
        # Serialize the wrapped function and flatten its kwargs into the
        # config alongside the base `Loss` fields.
        config = super().get_config()
        config.update({"fn": serialization_lib.serialize_keras_object(self.fn)})
        config.update(serialization_lib.serialize_keras_object(self._fn_kwargs))
        return config

    @classmethod
    def from_config(cls, config):
        # Deserialize `fn` (and any nested serialized objects) before
        # rebuilding the wrapper.
        if "fn" in config:
            config = serialization_lib.deserialize_keras_object(config)
        return cls(**config)

    def __repr__(self):
        return f"<LossFunctionWrapper({self.fn}, kwargs={self._fn_kwargs})>"
49
+
50
+
51
@keras_export("keras.losses.MeanSquaredError")
class MeanSquaredError(LossFunctionWrapper):
    """Computes the mean of squares of errors between labels and predictions.

    Formula:

    ```python
    loss = mean(square(y_true - y_pred))
    ```

    Args:
        reduction: Type of reduction to apply to the loss. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. Mean-style options divide the summed loss by
            the sample size (or by the sum of sample weights for
            `"mean_with_sample_weight"`); `"none"`/`None` skip aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: Dtype of the loss's computations; `None` means
            `keras.backend.floatx()`. A `keras.DTypePolicy` resolves to its
            `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="mean_squared_error",
        dtype=None,
    ):
        # Delegate the elementwise math to the `mean_squared_error` function.
        super().__init__(
            mean_squared_error, reduction=reduction, name=name, dtype=dtype
        )

    def get_config(self):
        # Bypass `LossFunctionWrapper.get_config` so the wrapped `fn` (fixed
        # for this class) is not serialized.
        return Loss.get_config(self)
90
+
91
+
92
@keras_export("keras.losses.MeanAbsoluteError")
class MeanAbsoluteError(LossFunctionWrapper):
    """Computes the mean of absolute difference between labels and predictions.

    Formula:

    ```python
    loss = mean(abs(y_true - y_pred))
    ```

    Args:
        reduction: Type of reduction to apply to the loss. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. Mean-style options divide the summed loss by
            the sample size (or by the sum of sample weights for
            `"mean_with_sample_weight"`); `"none"`/`None` skip aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: Dtype of the loss's computations; `None` means
            `keras.backend.floatx()`. A `keras.DTypePolicy` resolves to its
            `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="mean_absolute_error",
        dtype=None,
    ):
        # Delegate the elementwise math to the `mean_absolute_error` function.
        super().__init__(
            mean_absolute_error, reduction=reduction, name=name, dtype=dtype
        )

    def get_config(self):
        # Bypass `LossFunctionWrapper.get_config` so the wrapped `fn` (fixed
        # for this class) is not serialized.
        return Loss.get_config(self)
131
+
132
+
133
@keras_export("keras.losses.MeanAbsolutePercentageError")
class MeanAbsolutePercentageError(LossFunctionWrapper):
    """Computes the mean absolute percentage error between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = 100 * mean(abs((y_true - y_pred) / y_true))
    ```

    Args:
        reduction: Type of reduction to apply to the loss. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. Mean-style options divide the summed loss by
            the sample size (or by the sum of sample weights for
            `"mean_with_sample_weight"`); `"none"`/`None` skip aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: Dtype of the loss's computations; `None` means
            `keras.backend.floatx()`. A `keras.DTypePolicy` resolves to its
            `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="mean_absolute_percentage_error",
        dtype=None,
    ):
        # Delegate the elementwise math to `mean_absolute_percentage_error`.
        super().__init__(
            mean_absolute_percentage_error,
            reduction=reduction,
            name=name,
            dtype=dtype,
        )

    def get_config(self):
        # Bypass `LossFunctionWrapper.get_config` so the wrapped `fn` (fixed
        # for this class) is not serialized.
        return Loss.get_config(self)
175
+
176
+
177
@keras_export("keras.losses.MeanSquaredLogarithmicError")
class MeanSquaredLogarithmicError(LossFunctionWrapper):
    """Computes the mean squared logarithmic error between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = mean(square(log(y_true + 1) - log(y_pred + 1)))
    ```

    Args:
        reduction: Type of reduction to apply to the loss. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. Mean-style options divide the summed loss by
            the sample size (or by the sum of sample weights for
            `"mean_with_sample_weight"`); `"none"`/`None` skip aggregation.
            Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: Dtype of the loss's computations; `None` means
            `keras.backend.floatx()`. A `keras.DTypePolicy` resolves to its
            `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="mean_squared_logarithmic_error",
        dtype=None,
    ):
        # Delegate the elementwise math to `mean_squared_logarithmic_error`.
        super().__init__(
            mean_squared_logarithmic_error,
            reduction=reduction,
            name=name,
            dtype=dtype,
        )

    def get_config(self):
        # Bypass `LossFunctionWrapper.get_config` so the wrapped `fn` (fixed
        # for this class) is not serialized.
        return Loss.get_config(self)
219
+
220
+
221
@keras_export("keras.losses.CosineSimilarity")
class CosineSimilarity(LossFunctionWrapper):
    """Computes the cosine similarity between `y_true` & `y_pred`.

    Note that it is a number between -1 and 1. When it is a negative number
    between -1 and 0, 0 indicates orthogonality and values closer to -1
    indicate greater similarity. This makes it usable as a loss function in a
    setting where you try to maximize the proximity between predictions and
    targets. If either `y_true` or `y_pred` is a zero vector, cosine
    similarity will be 0 regardless of the proximity between predictions and
    targets.

    Formula:

    ```python
    loss = -sum(l2_norm(y_true) * l2_norm(y_pred))
    ```

    Args:
        axis: The axis along which the cosine similarity is computed
            (the features axis). Defaults to `-1`.
        reduction: Type of reduction to apply to the loss. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: Dtype of the loss's computations; `None` means
            `keras.backend.floatx()`. A `keras.DTypePolicy` resolves to its
            `compute_dtype`.
    """

    def __init__(
        self,
        axis=-1,
        reduction="sum_over_batch_size",
        name="cosine_similarity",
        dtype=None,
    ):
        # `axis` is forwarded to the functional `cosine_similarity` as an
        # fn kwarg via LossFunctionWrapper.
        super().__init__(
            cosine_similarity,
            reduction=reduction,
            name=name,
            dtype=dtype,
            axis=axis,
        )

    def get_config(self):
        # Bypass `LossFunctionWrapper.get_config` so the wrapped `fn` (fixed
        # for this class) is not serialized.
        return Loss.get_config(self)
274
+
275
+
276
@keras_export("keras.losses.Huber")
class Huber(LossFunctionWrapper):
    """Computes the Huber loss between `y_true` & `y_pred`.

    Formula:

    ```python
    for x in error:
        if abs(x) <= delta:
            loss.append(0.5 * x^2)
        elif abs(x) > delta:
            loss.append(delta * abs(x) - 0.5 * delta^2)

    loss = mean(loss, axis=-1)
    ```
    See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss).

    Args:
        delta: A float, the point where the Huber loss function changes from
            a quadratic to linear.
        reduction: Type of reduction to apply to the loss. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the instance.
        dtype: Dtype of the loss's computations; `None` means
            `keras.backend.floatx()`. A `keras.DTypePolicy` resolves to its
            `compute_dtype`.
    """

    def __init__(
        self,
        delta=1.0,
        reduction="sum_over_batch_size",
        name="huber_loss",
        dtype=None,
    ):
        # `delta` is forwarded to the functional `huber` as an fn kwarg via
        # LossFunctionWrapper.
        super().__init__(
            huber,
            reduction=reduction,
            name=name,
            dtype=dtype,
            delta=delta,
        )

    def get_config(self):
        # Bypass `LossFunctionWrapper.get_config` so the wrapped `fn` (fixed
        # for this class) is not serialized.
        return Loss.get_config(self)
329
+
330
+
331
@keras_export("keras.losses.LogCosh")
class LogCosh(LossFunctionWrapper):
    """Computes the logarithm of the hyperbolic cosine of the prediction error.

    Formula:

    ```python
    error = y_pred - y_true
    logcosh = mean(log((exp(error) + exp(-error))/2), axis=-1)`
    ```
    where x is the error `y_pred - y_true`.

    Args:
        reduction: Type of reduction to apply to the loss. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the instance.
        dtype: Dtype of the loss's computations; `None` means
            `keras.backend.floatx()`. A `keras.DTypePolicy` resolves to its
            `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="log_cosh",
        dtype=None,
    ):
        # Delegate the elementwise math to the `log_cosh` function.
        super().__init__(
            log_cosh, reduction=reduction, name=name, dtype=dtype
        )

    def get_config(self):
        # Bypass `LossFunctionWrapper.get_config` so the wrapped `fn` (fixed
        # for this class) is not serialized.
        return Loss.get_config(self)
370
+
371
+
372
@keras_export("keras.losses.Hinge")
class Hinge(LossFunctionWrapper):
    """Computes the hinge loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = maximum(1 - y_true * y_pred, 0)
    ```

    `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
    provided we will convert them to -1 or 1.

    Args:
        reduction: Type of reduction to apply to the loss. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: Dtype of the loss's computations; `None` means
            `keras.backend.floatx()`. A `keras.DTypePolicy` resolves to its
            `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="hinge",
        dtype=None,
    ):
        # Delegate the elementwise math to the `hinge` function.
        super().__init__(hinge, reduction=reduction, name=name, dtype=dtype)

    def get_config(self):
        # Bypass `LossFunctionWrapper.get_config` so the wrapped `fn` (fixed
        # for this class) is not serialized.
        return Loss.get_config(self)
412
+
413
+
414
@keras_export("keras.losses.SquaredHinge")
class SquaredHinge(LossFunctionWrapper):
    """Computes the squared hinge loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = square(maximum(1 - y_true * y_pred, 0))
    ```

    `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
    provided we will convert them to -1 or 1.

    Args:
        reduction: Type of reduction to apply to the loss. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: Dtype of the loss's computations; `None` means
            `keras.backend.floatx()`. A `keras.DTypePolicy` resolves to its
            `compute_dtype`.
    """

    def __init__(
        self, reduction="sum_over_batch_size", name="squared_hinge", dtype=None
    ):
        # Delegate the elementwise math to the `squared_hinge` function.
        super().__init__(
            squared_hinge, reduction=reduction, name=name, dtype=dtype
        )

    def get_config(self):
        # Bypass `LossFunctionWrapper.get_config` so the wrapped `fn` (fixed
        # for this class) is not serialized.
        return Loss.get_config(self)
453
+
454
+
455
@keras_export("keras.losses.CategoricalHinge")
class CategoricalHinge(LossFunctionWrapper):
    """Computes the categorical hinge loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = maximum(neg - pos + 1, 0)
    ```

    where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)`

    Args:
        reduction: Type of reduction to apply to the loss. One of `"sum"`,
            `"sum_over_batch_size"`, `"mean"`, `"mean_with_sample_weight"`,
            `"none"` or `None`. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: Dtype of the loss's computations; `None` means
            `keras.backend.floatx()`. A `keras.DTypePolicy` resolves to its
            `compute_dtype`.
    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="categorical_hinge",
        dtype=None,
    ):
        # Delegate the elementwise math to the `categorical_hinge` function.
        super().__init__(
            categorical_hinge, reduction=reduction, name=name, dtype=dtype
        )

    def get_config(self):
        # Bypass `LossFunctionWrapper.get_config` so the wrapped `fn` (fixed
        # for this class) is not serialized.
        return Loss.get_config(self)
496
+
497
+
498
@keras_export("keras.losses.KLDivergence")
class KLDivergence(LossFunctionWrapper):
    """Computes Kullback-Leibler divergence loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = y_true * log(y_true / y_pred)
    ```

    `y_true` and `y_pred` are expected to be probability
    distributions, with values between 0 and 1. They will get
    clipped to the `[0, 1]` range.

    Args:
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.
    """

    def __init__(
        self, reduction="sum_over_batch_size", name="kl_divergence", dtype=None
    ):
        # Delegate entirely to the functional `kl_divergence` loss.
        super().__init__(
            kl_divergence, reduction=reduction, name=name, dtype=dtype
        )

    def get_config(self):
        # No extra kwargs beyond the base Loss fields, so bypass
        # LossFunctionWrapper.get_config on purpose.
        return Loss.get_config(self)
538
+
539
+
540
@keras_export("keras.losses.Poisson")
class Poisson(LossFunctionWrapper):
    """Computes the Poisson loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = y_pred - y_true * log(y_pred)
    ```

    Args:
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.
    """

    def __init__(
        self, reduction="sum_over_batch_size", name="poisson", dtype=None
    ):
        # Delegate entirely to the functional `poisson` loss.
        super().__init__(poisson, reduction=reduction, name=name, dtype=dtype)

    def get_config(self):
        # No extra kwargs beyond the base Loss fields, so bypass
        # LossFunctionWrapper.get_config on purpose.
        return Loss.get_config(self)
574
+
575
+
576
@keras_export("keras.losses.BinaryCrossentropy")
class BinaryCrossentropy(LossFunctionWrapper):
    """Computes the cross-entropy loss between true labels and predicted labels.

    Use this cross-entropy loss for binary (0 or 1) classification applications.
    The loss function requires the following inputs:

    - `y_true` (true label): This is either 0 or 1.
    - `y_pred` (predicted value): This is the model's prediction, i.e, a single
        floating-point value which either represents a
        [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]
        when `from_logits=True`) or a probability (i.e, value in [0., 1.] when
        `from_logits=False`).

    Args:
        from_logits: Whether to interpret `y_pred` as a tensor of
            [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we
            assume that `y_pred` is probabilities (i.e., values in [0, 1]).
        label_smoothing: Float in range [0, 1]. When 0, no smoothing occurs.
            When > 0, we compute the loss between the predicted labels
            and a smoothed version of the true labels, where the smoothing
            squeezes the labels towards 0.5. Larger values of
            `label_smoothing` correspond to heavier smoothing.
        axis: The axis along which to compute crossentropy (the features axis).
            Defaults to `-1`.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.

    Examples:

    **Recommended Usage:** (set `from_logits=True`)

    With `compile()` API:

    ```python
    model.compile(
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
        ...
    )
    ```

    As a standalone function:

    >>> # Example 1: (batch_size = 1, number of samples = 4)
    >>> y_true = np.array([0, 1, 0, 0])
    >>> y_pred = np.array([-18.6, 0.51, 2.94, -12.8])
    >>> bce = keras.losses.BinaryCrossentropy(from_logits=True)
    >>> bce(y_true, y_pred)
    0.8654

    >>> # Example 2: (batch_size = 2, number of samples = 4)
    >>> y_true = np.array([[0, 1], [0, 0]])
    >>> y_pred = np.array([[-18.6, 0.51], [2.94, -12.8]])
    >>> # Using default 'auto'/'sum_over_batch_size' reduction type.
    >>> bce = keras.losses.BinaryCrossentropy(from_logits=True)
    >>> bce(y_true, y_pred)
    0.8654
    >>> # Using 'sample_weight' attribute
    >>> bce(y_true, y_pred, sample_weight=[0.8, 0.2])
    0.243
    >>> # Using 'sum' reduction` type.
    >>> bce = keras.losses.BinaryCrossentropy(from_logits=True,
    ...     reduction="sum")
    >>> bce(y_true, y_pred)
    1.730
    >>> # Using 'none' reduction type.
    >>> bce = keras.losses.BinaryCrossentropy(from_logits=True,
    ...     reduction=None)
    >>> bce(y_true, y_pred)
    array([0.235, 1.496], dtype=float32)

    **Default Usage:** (set `from_logits=False`)

    >>> # Make the following updates to the above "Recommended Usage" section
    >>> # 1. Set `from_logits=False`
    >>> keras.losses.BinaryCrossentropy() # OR ...('from_logits=False')
    >>> # 2. Update `y_pred` to use probabilities instead of logits
    >>> y_pred = [0.6, 0.3, 0.2, 0.8] # OR [[0.6, 0.3], [0.2, 0.8]]
    """

    def __init__(
        self,
        from_logits=False,
        label_smoothing=0.0,
        axis=-1,
        reduction="sum_over_batch_size",
        name="binary_crossentropy",
        dtype=None,
    ):
        # Forward the extra kwargs straight to the functional loss.
        super().__init__(
            binary_crossentropy,
            from_logits=from_logits,
            label_smoothing=label_smoothing,
            axis=axis,
            name=name,
            reduction=reduction,
            dtype=dtype,
        )
        # Retained so get_config() can round-trip the constructor args.
        self.from_logits = from_logits
        self.label_smoothing = label_smoothing
        self.axis = axis

    def get_config(self):
        # Merge the wrapped-function kwargs into the base Loss config.
        return {
            **Loss.get_config(self),
            "from_logits": self.from_logits,
            "label_smoothing": self.label_smoothing,
            "axis": self.axis,
        }
700
+
701
+
702
@keras_export("keras.losses.BinaryFocalCrossentropy")
class BinaryFocalCrossentropy(LossFunctionWrapper):
    """Computes focal cross-entropy loss between true labels and predictions.

    Binary cross-entropy loss is often used for binary (0 or 1) classification
    tasks. The loss function requires the following inputs:

    - `y_true` (true label): This is either 0 or 1.
    - `y_pred` (predicted value): This is the model's prediction, i.e, a single
        floating-point value which either represents a
        [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]
        when `from_logits=True`) or a probability (i.e, value in `[0., 1.]` when
        `from_logits=False`).

    According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
    helps to apply a "focal factor" to down-weight easy examples and focus more
    on hard examples. By default, the focal tensor is computed as follows:

    `focal_factor = (1 - output) ** gamma` for class 1
    `focal_factor = output ** gamma` for class 0
    where `gamma` is a focusing parameter. When `gamma=0`, this function is
    equivalent to the binary crossentropy loss.

    Args:
        apply_class_balancing: A bool, whether to apply weight balancing on the
            binary classes 0 and 1.
        alpha: A weight balancing factor for class 1, default is `0.25` as
            mentioned in reference [Lin et al., 2018](
            https://arxiv.org/pdf/1708.02002.pdf). The weight for class 0 is
            `1.0 - alpha`.
        gamma: A focusing parameter used to compute the focal factor, default is
            `2.0` as mentioned in the reference
            [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf).
        from_logits: Whether to interpret `y_pred` as a tensor of
            [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we
            assume that `y_pred` are probabilities (i.e., values in `[0, 1]`).
        label_smoothing: Float in `[0, 1]`. When `0`, no smoothing occurs.
            When > `0`, we compute the loss between the predicted labels
            and a smoothed version of the true labels, where the smoothing
            squeezes the labels towards `0.5`.
            Larger values of `label_smoothing` correspond to heavier smoothing.
        axis: The axis along which to compute crossentropy (the features axis).
            Defaults to `-1`.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.

    Examples:

    With the `compile()` API:

    ```python
    model.compile(
        loss=keras.losses.BinaryFocalCrossentropy(
            gamma=2.0, from_logits=True),
        ...
    )
    ```

    As a standalone function:

    >>> # Example 1: (batch_size = 1, number of samples = 4)
    >>> y_true = np.array([0, 1, 0, 0])
    >>> y_pred = np.array([-18.6, 0.51, 2.94, -12.8])
    >>> loss = keras.losses.BinaryFocalCrossentropy(
    ...    gamma=2, from_logits=True)
    >>> loss(y_true, y_pred)
    0.691

    >>> # Apply class weight
    >>> loss = keras.losses.BinaryFocalCrossentropy(
    ...     apply_class_balancing=True, gamma=2, from_logits=True)
    >>> loss(y_true, y_pred)
    0.51

    >>> # Example 2: (batch_size = 2, number of samples = 4)
    >>> y_true = np.array([[0, 1], [0, 0]])
    >>> y_pred = np.array([[-18.6, 0.51], [2.94, -12.8]])
    >>> # Using default 'auto'/'sum_over_batch_size' reduction type.
    >>> loss = keras.losses.BinaryFocalCrossentropy(
    ...     gamma=3, from_logits=True)
    >>> loss(y_true, y_pred)
    0.647

    >>> # Apply class weight
    >>> loss = keras.losses.BinaryFocalCrossentropy(
    ...     apply_class_balancing=True, gamma=3, from_logits=True)
    >>> loss(y_true, y_pred)
    0.482

    >>> # Using 'sample_weight' attribute with focal effect
    >>> loss = keras.losses.BinaryFocalCrossentropy(
    ...     gamma=3, from_logits=True)
    >>> loss(y_true, y_pred, sample_weight=[0.8, 0.2])
    0.133

    >>> # Apply class weight
    >>> loss = keras.losses.BinaryFocalCrossentropy(
    ...     apply_class_balancing=True, gamma=3, from_logits=True)
    >>> loss(y_true, y_pred, sample_weight=[0.8, 0.2])
    0.097

    >>> # Using 'sum' reduction` type.
    >>> loss = keras.losses.BinaryFocalCrossentropy(
    ...     gamma=4, from_logits=True,
    ...     reduction="sum")
    >>> loss(y_true, y_pred)
    1.222

    >>> # Apply class weight
    >>> loss = keras.losses.BinaryFocalCrossentropy(
    ...     apply_class_balancing=True, gamma=4, from_logits=True,
    ...     reduction="sum")
    >>> loss(y_true, y_pred)
    0.914

    >>> # Using 'none' reduction type.
    >>> loss = keras.losses.BinaryFocalCrossentropy(
    ...     gamma=5, from_logits=True,
    ...     reduction=None)
    >>> loss(y_true, y_pred)
    array([0.0017 1.1561], dtype=float32)

    >>> # Apply class weight
    >>> loss = keras.losses.BinaryFocalCrossentropy(
    ...     apply_class_balancing=True, gamma=5, from_logits=True,
    ...     reduction=None)
    >>> loss(y_true, y_pred)
    array([0.0004 0.8670], dtype=float32)
    """

    def __init__(
        self,
        apply_class_balancing=False,
        alpha=0.25,
        gamma=2.0,
        from_logits=False,
        label_smoothing=0.0,
        axis=-1,
        reduction="sum_over_batch_size",
        name="binary_focal_crossentropy",
        dtype=None,
    ):
        # Forward all focal/crossentropy kwargs to the functional loss.
        super().__init__(
            binary_focal_crossentropy,
            apply_class_balancing=apply_class_balancing,
            alpha=alpha,
            gamma=gamma,
            from_logits=from_logits,
            label_smoothing=label_smoothing,
            axis=axis,
            name=name,
            reduction=reduction,
            dtype=dtype,
        )
        # Retained so get_config() can round-trip the constructor args.
        self.from_logits = from_logits
        self.label_smoothing = label_smoothing
        self.axis = axis
        self.apply_class_balancing = apply_class_balancing
        self.alpha = alpha
        self.gamma = gamma

    def get_config(self):
        # Merge the wrapped-function kwargs into the base Loss config.
        return {
            **Loss.get_config(self),
            "from_logits": self.from_logits,
            "label_smoothing": self.label_smoothing,
            "axis": self.axis,
            "apply_class_balancing": self.apply_class_balancing,
            "alpha": self.alpha,
            "gamma": self.gamma,
        }
888
+
889
+
890
@keras_export("keras.losses.CategoricalCrossentropy")
class CategoricalCrossentropy(LossFunctionWrapper):
    """Computes the crossentropy loss between the labels and predictions.

    Use this crossentropy loss function when there are two or more label
    classes. We expect labels to be provided in a `one_hot` representation. If
    you want to provide labels as integers, please use
    `SparseCategoricalCrossentropy` loss. There should be `num_classes` floating
    point values per feature, i.e., the shape of both `y_pred` and `y_true` are
    `[batch_size, num_classes]`.

    Args:
        from_logits: Whether `y_pred` is expected to be a logits tensor. By
            default, we assume that `y_pred` encodes a probability distribution.
        label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
            meaning the confidence on label values are relaxed. For example, if
            `0.1`, use `0.1 / num_classes` for non-target labels and
            `0.9 + 0.1 / num_classes` for target labels.
        axis: The axis along which to compute crossentropy (the features
            axis). Defaults to `-1`.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.

    Examples:

    Standalone usage:

    >>> y_true = np.array([[0, 1, 0], [0, 0, 1]])
    >>> y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
    >>> # Using 'auto'/'sum_over_batch_size' reduction type.
    >>> cce = keras.losses.CategoricalCrossentropy()
    >>> cce(y_true, y_pred)
    1.177

    >>> # Calling with 'sample_weight'.
    >>> cce(y_true, y_pred, sample_weight=np.array([0.3, 0.7]))
    0.814

    >>> # Using 'sum' reduction type.
    >>> cce = keras.losses.CategoricalCrossentropy(
    ...     reduction="sum")
    >>> cce(y_true, y_pred)
    2.354

    >>> # Using 'none' reduction type.
    >>> cce = keras.losses.CategoricalCrossentropy(
    ...     reduction=None)
    >>> cce(y_true, y_pred)
    array([0.0513, 2.303], dtype=float32)

    Usage with the `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss=keras.losses.CategoricalCrossentropy())
    ```
    """

    def __init__(
        self,
        from_logits=False,
        label_smoothing=0.0,
        axis=-1,
        reduction="sum_over_batch_size",
        name="categorical_crossentropy",
        dtype=None,
    ):
        # Forward the extra kwargs straight to the functional loss.
        super().__init__(
            categorical_crossentropy,
            from_logits=from_logits,
            label_smoothing=label_smoothing,
            axis=axis,
            name=name,
            reduction=reduction,
            dtype=dtype,
        )
        # Retained so get_config() can round-trip the constructor args.
        self.from_logits = from_logits
        self.label_smoothing = label_smoothing
        self.axis = axis

    def get_config(self):
        # Merge the wrapped-function kwargs into the base Loss config.
        return {
            **Loss.get_config(self),
            "from_logits": self.from_logits,
            "label_smoothing": self.label_smoothing,
            "axis": self.axis,
        }
992
+
993
+
994
@keras_export("keras.losses.CategoricalFocalCrossentropy")
class CategoricalFocalCrossentropy(LossFunctionWrapper):
    """Computes the alpha balanced focal crossentropy loss.

    Use this crossentropy loss function when there are two or more label
    classes and if you want to handle class imbalance without using
    `class_weights`. We expect labels to be provided in a `one_hot`
    representation.

    According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
    helps to apply a focal factor to down-weight easy examples and focus more on
    hard examples. The general formula for the focal loss (FL)
    is as follows:

    `FL(p_t) = (1 - p_t) ** gamma * log(p_t)`

    where `p_t` is defined as follows:
    `p_t = output if y_true == 1, else 1 - output`

    `(1 - p_t) ** gamma` is the `modulating_factor`, where `gamma` is a focusing
    parameter. When `gamma` = 0, there is no focal effect on the cross entropy.
    `gamma` reduces the importance given to simple examples in a smooth manner.

    The authors use alpha-balanced variant of focal loss (FL) in the paper:
    `FL(p_t) = -alpha * (1 - p_t) ** gamma * log(p_t)`

    where `alpha` is the weight factor for the classes. If `alpha` = 1, the
    loss won't be able to handle class imbalance properly as all
    classes will have the same weight. This can be a constant or a list of
    constants. If alpha is a list, it must have the same length as the number
    of classes.

    The formula above can be generalized to:
    `FL(p_t) = alpha * (1 - p_t) ** gamma * CrossEntropy(y_true, y_pred)`

    where minus comes from `CrossEntropy(y_true, y_pred)` (CE).

    Extending this to multi-class case is straightforward:
    `FL(p_t) = alpha * (1 - p_t) ** gamma * CategoricalCE(y_true, y_pred)`

    In the snippet below, there is `num_classes` floating pointing values per
    example. The shape of both `y_pred` and `y_true` are
    `(batch_size, num_classes)`.

    Args:
        alpha: A weight balancing factor for all classes, default is `0.25` as
            mentioned in the reference. It can be a list of floats or a scalar.
            In the multi-class case, alpha may be set by inverse class
            frequency by using `compute_class_weight` from `sklearn.utils`.
        gamma: A focusing parameter, default is `2.0` as mentioned in the
            reference. It helps to gradually reduce the importance given to
            simple (easy) examples in a smooth manner.
        from_logits: Whether `output` is expected to be a logits tensor. By
            default, we consider that `output` encodes a probability
            distribution.
        label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
            meaning the confidence on label values are relaxed. For example, if
            `0.1`, use `0.1 / num_classes` for non-target labels and
            `0.9 + 0.1 / num_classes` for target labels.
        axis: The axis along which to compute crossentropy (the features
            axis). Defaults to `-1`.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.

    Examples:

    Standalone usage:

    >>> y_true = [[0., 1., 0.], [0., 0., 1.]]
    >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
    >>> # Using 'auto'/'sum_over_batch_size' reduction type.
    >>> cce = keras.losses.CategoricalFocalCrossentropy()
    >>> cce(y_true, y_pred)
    0.23315276

    >>> # Calling with 'sample_weight'.
    >>> cce(y_true, y_pred, sample_weight=np.array([0.3, 0.7]))
    0.1632

    >>> # Using 'sum' reduction type.
    >>> cce = keras.losses.CategoricalFocalCrossentropy(
    ...     reduction="sum")
    >>> cce(y_true, y_pred)
    0.46631

    >>> # Using 'none' reduction type.
    >>> cce = keras.losses.CategoricalFocalCrossentropy(
    ...     reduction=None)
    >>> cce(y_true, y_pred)
    array([3.2058331e-05, 4.6627346e-01], dtype=float32)

    Usage with the `compile()` API:

    ```python
    model.compile(optimizer='adam',
                  loss=keras.losses.CategoricalFocalCrossentropy())
    ```
    """

    def __init__(
        self,
        alpha=0.25,
        gamma=2.0,
        from_logits=False,
        label_smoothing=0.0,
        axis=-1,
        reduction="sum_over_batch_size",
        name="categorical_focal_crossentropy",
        dtype=None,
    ):
        """Initializes `CategoricalFocalCrossentropy` instance."""
        # Forward the focal/crossentropy kwargs to the functional loss.
        super().__init__(
            categorical_focal_crossentropy,
            alpha=alpha,
            gamma=gamma,
            from_logits=from_logits,
            label_smoothing=label_smoothing,
            axis=axis,
            name=name,
            reduction=reduction,
            dtype=dtype,
        )
        # Retained so get_config() can round-trip the constructor args.
        self.from_logits = from_logits
        self.label_smoothing = label_smoothing
        self.axis = axis
        self.alpha = alpha
        self.gamma = gamma

    def get_config(self):
        # Merge the wrapped-function kwargs into the base Loss config.
        return {
            **Loss.get_config(self),
            "from_logits": self.from_logits,
            "label_smoothing": self.label_smoothing,
            "axis": self.axis,
            "alpha": self.alpha,
            "gamma": self.gamma,
        }
1146
+
1147
+
1148
@keras_export("keras.losses.SparseCategoricalCrossentropy")
class SparseCategoricalCrossentropy(LossFunctionWrapper):
    """Computes the crossentropy loss between the labels and predictions.

    Use this crossentropy loss function when there are two or more label
    classes. We expect labels to be provided as integers. If you want to
    provide labels using `one-hot` representation, please use
    `CategoricalCrossentropy` loss. There should be `# classes` floating point
    values per feature for `y_pred` and a single floating point value per
    feature for `y_true`.

    In the snippet below, there is a single floating point value per example for
    `y_true` and `num_classes` floating pointing values per example for
    `y_pred`. The shape of `y_true` is `[batch_size]` and the shape of `y_pred`
    is `[batch_size, num_classes]`.

    Args:
        from_logits: Whether `y_pred` is expected to be a logits tensor. By
            default, we assume that `y_pred` encodes a probability distribution.
        ignore_class: Optional integer. The ID of a class to be ignored during
            loss computation. This is useful, for example, in segmentation
            problems featuring a "void" class (commonly -1 or 255) in
            segmentation maps. By default (`ignore_class=None`), all classes
            are considered.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.

    Examples:

    >>> y_true = [1, 2]
    >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
    >>> # Using 'auto'/'sum_over_batch_size' reduction type.
    >>> scce = keras.losses.SparseCategoricalCrossentropy()
    >>> scce(y_true, y_pred)
    1.177

    >>> # Calling with 'sample_weight'.
    >>> scce(y_true, y_pred, sample_weight=np.array([0.3, 0.7]))
    0.814

    >>> # Using 'sum' reduction type.
    >>> scce = keras.losses.SparseCategoricalCrossentropy(
    ...     reduction="sum")
    >>> scce(y_true, y_pred)
    2.354

    >>> # Using 'none' reduction type.
    >>> scce = keras.losses.SparseCategoricalCrossentropy(
    ...     reduction=None)
    >>> scce(y_true, y_pred)
    array([0.0513, 2.303], dtype=float32)

    Usage with the `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss=keras.losses.SparseCategoricalCrossentropy())
    ```
    """

    def __init__(
        self,
        from_logits=False,
        ignore_class=None,
        reduction="sum_over_batch_size",
        name="sparse_categorical_crossentropy",
        dtype=None,
    ):
        # Forward the extra kwargs straight to the functional loss.
        super().__init__(
            sparse_categorical_crossentropy,
            name=name,
            reduction=reduction,
            dtype=dtype,
            from_logits=from_logits,
            ignore_class=ignore_class,
        )
        # Retained so get_config() can round-trip the constructor args.
        self.from_logits = from_logits
        self.ignore_class = ignore_class

    def get_config(self):
        # Start from the base Loss config (deliberately bypassing
        # LossFunctionWrapper.get_config) and add the wrapped kwargs.
        config = Loss.get_config(self)
        config.update(
            {
                "from_logits": self.from_logits,
                "ignore_class": self.ignore_class,
            }
        )
        return config
1243
+
1244
+
1245
@keras_export("keras.losses.CTC")
class CTC(LossFunctionWrapper):
    """CTC (Connectionist Temporal Classification) loss.

    Args:
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.
    """

    def __init__(self, reduction="sum_over_batch_size", name="ctc", dtype=None):
        # Delegate entirely to the functional `ctc` loss.
        super().__init__(ctc, reduction=reduction, name=name, dtype=dtype)

    def get_config(self):
        # No extra kwargs beyond the base Loss fields, so bypass
        # LossFunctionWrapper.get_config on purpose.
        return Loss.get_config(self)
1271
+
1272
+
1273
@keras_export("keras.losses.Dice")
class Dice(LossFunctionWrapper):
    """Computes the Dice loss value between `y_true` and `y_pred`.

    Formula:
    ```python
    loss = 1 - (2 * sum(y_true * y_pred)) / (sum(y_true) + sum(y_pred))
    ```

    Args:
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        axis: Tuple for which dimensions the loss is calculated. Defaults to
            `None`.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.

    Returns:
        Dice loss value.

    Example:

    >>> y_true = [[[[1.0], [1.0]], [[0.0], [0.0]]],
    ...           [[[1.0], [1.0]], [[0.0], [0.0]]]]
    >>> y_pred = [[[[0.0], [1.0]], [[0.0], [1.0]]],
    ...           [[[0.4], [0.0]], [[0.0], [0.9]]]]
    >>> axis = (1, 2, 3)
    >>> loss = keras.losses.dice(y_true, y_pred, axis=axis)
    >>> assert loss.shape == (2,)
    >>> loss
    array([0.5, 0.75757575], shape=(2,), dtype=float32)

    >>> loss = keras.losses.dice(y_true, y_pred)
    >>> assert loss.shape == ()
    >>> loss
    array(0.6164384, shape=(), dtype=float32)

    >>> y_true = np.array(y_true)
    >>> y_pred = np.array(y_pred)
    >>> loss = keras.losses.Dice(axis=axis, reduction=None)(y_true, y_pred)
    >>> assert loss.shape == (2,)
    >>> loss
    array([0.5, 0.75757575], shape=(2,), dtype=float32)

    """

    def __init__(
        self,
        reduction="sum_over_batch_size",
        name="dice",
        axis=None,
        dtype=None,
    ):
        # Forward `axis` to the functional `dice` loss.
        super().__init__(
            dice, axis=axis, reduction=reduction, name=name, dtype=dtype
        )
        # Retained so get_config() can round-trip the constructor arg.
        self.axis = axis

    def get_config(self):
        # Merge the wrapped-function kwarg into the base Loss config.
        return {**Loss.get_config(self), "axis": self.axis}
1345
+
1346
+
1347
@keras_export("keras.losses.Tversky")
class Tversky(LossFunctionWrapper):
    """Computes the Tversky loss value between `y_true` and `y_pred`.

    This loss function is weighted by the alpha and beta coefficients
    that penalize false positives and false negatives.

    With `alpha=0.5` and `beta=0.5`, the loss value becomes equivalent to
    Dice Loss.

    Args:
        alpha: The coefficient controlling incidence of false positives.
            Defaults to `0.5`.
        beta: The coefficient controlling incidence of false negatives.
            Defaults to `0.5`.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        axis: Tuple for which dimensions the loss is calculated. Defaults to
            `None`.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.

    Returns:
        Tversky loss value.

    Reference:

    - [Salehi et al., 2017](https://arxiv.org/abs/1706.05721)
    """

    def __init__(
        self,
        alpha=0.5,
        beta=0.5,
        reduction="sum_over_batch_size",
        name="tversky",
        axis=None,
        dtype=None,
    ):
        # `tversky` is the functional form; `alpha`, `beta` and `axis` are
        # forwarded to it on every call through the wrapper's stored kwargs.
        super().__init__(
            tversky,
            name=name,
            reduction=reduction,
            dtype=dtype,
            alpha=alpha,
            beta=beta,
            axis=axis,
        )
        # Kept on the instance so `get_config` can round-trip them.
        self.alpha = alpha
        self.beta = beta
        self.axis = axis

    def get_config(self):
        # Start from the plain `Loss` config (not the wrapper's, which would
        # also serialize the wrapped function) and add this class's own args.
        config = Loss.get_config(self)
        config.update(
            {"alpha": self.alpha, "beta": self.beta, "axis": self.axis}
        )
        return config
1413
+
1414
+
1415
@keras_export("keras.losses.Circle")
class Circle(LossFunctionWrapper):
    """Computes Circle Loss between integer labels and L2-normalized embeddings.

    This is a metric learning loss designed to minimize within-class distance
    and maximize between-class distance in a flexible manner by dynamically
    adjusting the penalty strength based on optimization status of each
    similarity score.

    To use Circle Loss effectively, the model should output embeddings without
    an activation function (such as a `Dense` layer with `activation=None`)
    followed by UnitNormalization layer to ensure unit-norm embeddings.

    Args:
        gamma: Scaling factor that determines the largest scale of each
            similarity score. Defaults to `80`.
        margin: The relaxation factor, below this distance, negatives are
            up weighted and positives are down weighted. Similarly, above this
            distance negatives are down weighted and positive are up weighted.
            Defaults to `0.4`.
        remove_diagonal: Boolean, whether to remove self-similarities from the
            positive mask. Defaults to `True`.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.

    Examples:

    Usage with the `compile()` API:

    ```python
    model = models.Sequential([
        keras.layers.Input(shape=(224, 224, 3)),
        keras.layers.Conv2D(16, (3, 3), activation='relu'),
        keras.layers.Flatten(),
        keras.layers.Dense(64, activation=None), # No activation
        keras.layers.UnitNormalization() # L2 normalization
    ])

    model.compile(optimizer="adam", loss=keras.losses.Circle())
    ```

    Reference:
    - [Yifan Sun et al., 2020](https://arxiv.org/abs/2002.10857)

    """

    def __init__(
        self,
        gamma=80.0,
        margin=0.4,
        remove_diagonal=True,
        reduction="sum_over_batch_size",
        name="circle",
        dtype=None,
    ):
        # `circle` is the functional form; `gamma`, `margin` and
        # `remove_diagonal` are forwarded to it on every call through the
        # wrapper's stored kwargs.
        super().__init__(
            circle,
            name=name,
            reduction=reduction,
            dtype=dtype,
            gamma=gamma,
            margin=margin,
            remove_diagonal=remove_diagonal,
        )
        # Kept on the instance so `get_config` can round-trip them.
        self.gamma = gamma
        self.margin = margin
        self.remove_diagonal = remove_diagonal

    def get_config(self):
        # Start from the plain `Loss` config (not the wrapper's, which would
        # also serialize the wrapped function) and add this class's own args.
        config = Loss.get_config(self)
        config.update(
            {
                "gamma": self.gamma,
                "margin": self.margin,
                "remove_diagonal": self.remove_diagonal,
            }
        )
        return config
1505
+
1506
+
1507
def convert_binary_labels_to_hinge(y_true):
    """Map 0/1 labels to -1/1 for hinge-style losses; pass others through.

    If every entry of `y_true` is exactly 0 or 1, the labels are rescaled to
    {-1, 1}; otherwise `y_true` is returned unchanged. The branch is taken
    through `ops.cond` so the check stays graph-compatible.
    """
    labels_are_binary = ops.all(
        ops.logical_or(ops.equal(y_true, 0), ops.equal(y_true, 1))
    )

    def _rescale():
        # 0 -> -1 and 1 -> 1.
        return y_true * 2.0 - 1.0

    def _identity():
        # Non-binary labels are left untouched.
        return y_true

    return ops.cond(labels_are_binary, _rescale, _identity)
1525
+
1526
+
1527
@keras_export(
    [
        "keras.metrics.hinge",
        "keras.losses.hinge",
    ]
)
def hinge(y_true, y_pred):
    """Computes the hinge loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = mean(maximum(1 - y_true * y_pred, 0), axis=-1)
    ```

    Args:
        y_true: The ground truth values. `y_true` values are expected to be -1
            or 1. If binary (0 or 1) labels are provided they will be converted
            to -1 or 1 with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Hinge loss values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.choice([-1, 1], size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.hinge(y_true, y_pred)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    # `ops.cast` already returns a tensor, so no extra `convert_to_tensor`
    # is needed afterwards (matches the sibling `squared_hinge`).
    y_true = ops.cast(y_true, y_pred.dtype)
    # 0/1 labels are rescaled to -1/1; -1/1 labels pass through unchanged.
    y_true = convert_binary_labels_to_hinge(y_true)
    return ops.mean(ops.maximum(1.0 - y_true * y_pred, 0.0), axis=-1)
1562
+
1563
+
1564
@keras_export(
    [
        "keras.metrics.squared_hinge",
        "keras.losses.squared_hinge",
    ]
)
def squared_hinge(y_true, y_pred):
    """Computes the squared hinge loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = mean(square(maximum(1 - y_true * y_pred, 0)), axis=-1)
    ```

    Args:
        y_true: The ground truth values. `y_true` values are expected to be -1
            or 1. If binary (0 or 1) labels are provided we will convert them
            to -1 or 1 with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Squared hinge loss values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.choice([-1, 1], size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.squared_hinge(y_true, y_pred)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    # 0/1 labels are rescaled to -1/1; -1/1 labels pass through unchanged.
    y_true = convert_binary_labels_to_hinge(ops.cast(y_true, y_pred.dtype))
    hinge_margin = ops.maximum(1.0 - y_true * y_pred, 0.0)
    return ops.mean(ops.square(hinge_margin), axis=-1)
1600
+
1601
+
1602
@keras_export(
    [
        "keras.metrics.categorical_hinge",
        "keras.losses.categorical_hinge",
    ]
)
def categorical_hinge(y_true, y_pred):
    """Computes the categorical hinge loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = maximum(neg - pos + 1, 0)
    ```

    where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)`

    Args:
        y_true: The ground truth values. `y_true` values are expected to be
            either `{-1, +1}` or `{0, 1}` (i.e. a one-hot-encoded tensor) with
            shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Categorical hinge loss values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.randint(0, 3, size=(2,))
    >>> y_true = np.eye(np.max(y_true) + 1)[y_true]
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.categorical_hinge(y_true, y_pred)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, y_pred.dtype)
    # Score of the true class vs. best score among the other classes.
    positive = ops.sum(y_true * y_pred, axis=-1)
    negative = ops.max((1.0 - y_true) * y_pred, axis=-1)
    # Clamp at a zero of the correct dtype.
    floor = ops.cast(0.0, y_pred.dtype)
    return ops.maximum(negative - positive + 1.0, floor)
1641
+
1642
+
1643
@keras_export(
    [
        "keras.metrics.mean_squared_error",
        "keras.losses.mean_squared_error",
        # Legacy aliases
        "keras._legacy.losses.mse",
        "keras._legacy.losses.MSE",
        "keras._legacy.metrics.mse",
        "keras._legacy.metrics.MSE",
    ]
)
def mean_squared_error(y_true, y_pred):
    """Computes the mean squared error between labels and predictions.

    Formula:

    ```python
    loss = mean(square(y_true - y_pred), axis=-1)
    ```

    Example:

    >>> y_true = np.random.randint(0, 2, size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.mean_squared_error(y_true, y_pred)

    Args:
        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean squared error values with shape = `[batch_size, d0, .. dN-1]`.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    # Align ranks so e.g. `(batch,)` targets match `(batch, 1)` predictions.
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    squared_diff = ops.square(y_true - y_pred)
    return ops.mean(squared_diff, axis=-1)
1680
+
1681
+
1682
@keras_export(
    [
        "keras.metrics.mean_absolute_error",
        "keras.losses.mean_absolute_error",
        # Legacy aliases
        "keras._legacy.losses.MAE",
        "keras._legacy.losses.mae",
        "keras._legacy.metrics.MAE",
        "keras._legacy.metrics.mae",
    ]
)
def mean_absolute_error(y_true, y_pred):
    """Computes the mean absolute error between labels and predictions.

    ```python
    loss = mean(abs(y_true - y_pred), axis=-1)
    ```

    Args:
        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean absolute error values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.randint(0, 2, size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.mean_absolute_error(y_true, y_pred)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    # Align ranks so e.g. `(batch,)` targets match `(batch, 1)` predictions.
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    abs_diff = ops.abs(y_true - y_pred)
    return ops.mean(abs_diff, axis=-1)
1717
+
1718
+
1719
@keras_export(
    [
        "keras.metrics.mean_absolute_percentage_error",
        "keras.losses.mean_absolute_percentage_error",
        # Legacy aliases
        "keras._legacy.losses.mape",
        "keras._legacy.losses.MAPE",
        "keras._legacy.metrics.mape",
        "keras._legacy.metrics.MAPE",
    ]
)
def mean_absolute_percentage_error(y_true, y_pred):
    """Computes the mean absolute percentage error between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = 100 * mean(abs((y_true - y_pred) / y_true), axis=-1)
    ```

    Division by zero is prevented by dividing by `maximum(abs(y_true),
    epsilon)` where `epsilon = keras.backend.epsilon()`
    (default to `1e-7`).

    Args:
        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean absolute percentage error values with shape = `[batch_size, d0, ..
        dN-1]`.

    Example:

    >>> y_true = np.random.random(size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.mean_absolute_percentage_error(y_true, y_pred)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    epsilon = ops.convert_to_tensor(backend.epsilon(), dtype=y_pred.dtype)
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    # The denominator is floored at epsilon to avoid division by zero.
    diff = ops.abs((y_true - y_pred) / ops.maximum(ops.abs(y_true), epsilon))
    return 100.0 * ops.mean(diff, axis=-1)
1763
+
1764
+
1765
@keras_export(
    [
        "keras.metrics.mean_squared_logarithmic_error",
        "keras.losses.mean_squared_logarithmic_error",
        # Legacy aliases
        "keras._legacy.losses.msle",
        "keras._legacy.losses.MSLE",
        "keras._legacy.metrics.msle",
        "keras._legacy.metrics.MSLE",
    ]
)
def mean_squared_logarithmic_error(y_true, y_pred):
    """Computes the mean squared logarithmic error between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)
    ```

    Note that `y_pred` and `y_true` cannot be less or equal to 0. Negative
    values and 0 values will be replaced with `keras.backend.epsilon()`
    (default to `1e-7`).

    Args:
        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean squared logarithmic error values with shape = `[batch_size, d0, ..
        dN-1]`.

    Example:

    >>> y_true = np.random.randint(0, 2, size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.mean_squared_logarithmic_error(y_true, y_pred)
    """
    epsilon = ops.convert_to_tensor(backend.epsilon())
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    # Values <= 0 are floored at epsilon before taking the log.
    log_pred = ops.log(ops.maximum(y_pred, epsilon) + 1.0)
    log_true = ops.log(ops.maximum(y_true, epsilon) + 1.0)
    return ops.mean(ops.square(log_pred - log_true), axis=-1)
1810
+
1811
+
1812
@keras_export("keras.losses.cosine_similarity")
def cosine_similarity(y_true, y_pred, axis=-1):
    """Computes the cosine similarity between labels and predictions.

    Formula:
    ```python
    loss = -sum(l2_norm(y_true) * l2_norm(y_pred))
    ```

    Note that it is a number between -1 and 1. When it is a negative number
    between -1 and 0, 0 indicates orthogonality and values closer to -1
    indicate greater similarity. This makes it usable as a loss function in a
    setting where you try to maximize the proximity between predictions and
    targets. If either `y_true` or `y_pred` is a zero vector, cosine
    similarity will be 0 regardless of the proximity between predictions
    and targets.

    Args:
        y_true: Tensor of true targets.
        y_pred: Tensor of predicted targets.
        axis: Axis along which to determine similarity. Defaults to `-1`.

    Returns:
        Cosine similarity tensor.

    Example:

    >>> y_true = [[0., 1.], [1., 1.], [1., 1.]]
    >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]]
    >>> loss = keras.losses.cosine_similarity(y_true, y_pred, axis=-1)
    [-0., -0.99999994, 0.99999994]
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    # L2-normalize both tensors, then the dot product is the cosine.
    unit_pred = normalize(y_pred, axis=axis)
    unit_true = normalize(y_true, axis=axis)
    return -ops.sum(unit_true * unit_pred, axis=axis)
1850
+
1851
+
1852
@keras_export(["keras.losses.huber", "keras.metrics.huber"])
def huber(y_true, y_pred, delta=1.0):
    """Computes Huber loss value.

    Formula:
    ```python
    for x in error:
        if abs(x) <= delta:
            loss.append(0.5 * x^2)
        elif abs(x) > delta:
            loss.append(delta * abs(x) - 0.5 * delta^2)

    loss = mean(loss, axis=-1)
    ```
    See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss).

    Example:

    >>> y_true = [[0, 1], [0, 0]]
    >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
    >>> loss = keras.losses.huber(y_true, y_pred)
    0.155


    Args:
        y_true: tensor of true targets.
        y_pred: tensor of predicted targets.
        delta: A float, the point where the Huber loss function changes from a
            quadratic to linear. Defaults to `1.0`.

    Returns:
        Tensor with one scalar loss entry per sample.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    delta = ops.convert_to_tensor(delta, dtype=y_pred.dtype)
    half = ops.convert_to_tensor(0.5, dtype=y_pred.dtype)
    error = ops.subtract(y_pred, y_true)
    abs_error = ops.abs(error)
    # Quadratic inside the delta band, linear outside of it.
    quadratic = half * ops.square(error)
    linear = delta * abs_error - half * ops.square(delta)
    return ops.mean(ops.where(abs_error <= delta, quadratic, linear), axis=-1)
1900
+
1901
+
1902
@keras_export(
    [
        "keras.losses.log_cosh",
        "keras.metrics.log_cosh",
        # Legacy aliases
        "keras._legacy.losses.logcosh",
        "keras._legacy.metrics.logcosh",
    ]
)
def log_cosh(y_true, y_pred):
    """Logarithm of the hyperbolic cosine of the prediction error.

    Formula:
    ```python
    loss = mean(log(cosh(y_pred - y_true)), axis=-1)
    ```

    Note that `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small
    `x` and to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works
    mostly like the mean squared error, but will not be so strongly affected by
    the occasional wildly incorrect prediction.

    Example:

    >>> y_true = [[0., 1.], [0., 0.]]
    >>> y_pred = [[1., 1.], [0., 0.]]
    >>> loss = keras.losses.log_cosh(y_true, y_pred)
    0.108

    Args:
        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Logcosh error values with shape = `[batch_size, d0, .. dN-1]`.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    log2 = ops.convert_to_tensor(ops.log(2.0), dtype=y_pred.dtype)
    error = y_pred - y_true
    # Numerically stable identity: log(cosh(x)) = x + softplus(-2x) - log(2).
    return ops.mean(error + ops.softplus(error * -2.0) - log2, axis=-1)
1947
+
1948
+
1949
@keras_export(
    [
        "keras.metrics.kl_divergence",
        "keras.losses.kl_divergence",
        # Legacy aliases
        "keras._legacy.losses.KLD",
        "keras._legacy.losses.kld",
        "keras._legacy.losses.kullback_leibler_divergence",
        "keras._legacy.metrics.KLD",
        "keras._legacy.metrics.kld",
        "keras._legacy.metrics.kullback_leibler_divergence",
    ]
)
def kl_divergence(y_true, y_pred):
    """Computes Kullback-Leibler divergence loss between `y_true` & `y_pred`.

    Formula:

    ```python
    loss = y_true * log(y_true / y_pred)
    ```

    `y_true` and `y_pred` are expected to be probability
    distributions, with values between 0 and 1. They will get
    clipped to the `[0, 1]` range.

    Args:
        y_true: Tensor of true targets.
        y_pred: Tensor of predicted targets.

    Returns:
        KL Divergence loss values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.randint(0, 2, size=(2, 3)).astype(np.float32)
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.kl_divergence(y_true, y_pred)
    >>> assert loss.shape == (2,)
    >>> y_true = ops.clip(y_true, 1e-7, 1)
    >>> y_pred = ops.clip(y_pred, 1e-7, 1)
    >>> assert np.array_equal(
    ...     loss, np.sum(y_true * np.log(y_true / y_pred), axis=-1))
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, y_pred.dtype)
    # Clip both distributions into (0, 1] so the log and the ratio are finite.
    eps = backend.epsilon()
    y_true = ops.clip(y_true, eps, 1)
    y_pred = ops.clip(y_pred, eps, 1)
    return ops.sum(y_true * ops.log(y_true / y_pred), axis=-1)
1998
+
1999
+
2000
@keras_export(
    [
        "keras.metrics.poisson",
        "keras.losses.poisson",
    ]
)
def poisson(y_true, y_pred):
    """Computes the Poisson loss between y_true and y_pred.

    Formula:

    ```python
    loss = y_pred - y_true * log(y_pred)
    ```

    Args:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Poisson loss values with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = np.random.randint(0, 2, size=(2, 3))
    >>> y_pred = np.random.random(size=(2, 3))
    >>> loss = keras.losses.poisson(y_true, y_pred)
    >>> assert loss.shape == (2,)
    >>> y_pred = y_pred + 1e-7
    >>> assert np.allclose(
    ...     loss, np.mean(y_pred - y_true * np.log(y_pred), axis=-1),
    ...     atol=1e-5)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    eps = ops.convert_to_tensor(backend.epsilon(), dtype=y_pred.dtype)
    # Epsilon keeps the log finite when a prediction is exactly zero.
    safe_log_pred = ops.log(y_pred + eps)
    return ops.mean(y_pred - y_true * safe_log_pred, axis=-1)
2037
+
2038
+
2039
@keras_export(
    [
        "keras.metrics.categorical_crossentropy",
        "keras.losses.categorical_crossentropy",
    ]
)
def categorical_crossentropy(
    y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1
):
    """Computes the categorical crossentropy loss.

    Args:
        y_true: Tensor of one-hot true targets.
        y_pred: Tensor of predicted targets.
        from_logits: Whether `y_pred` is expected to be a logits tensor. By
            default, we assume that `y_pred` encodes a probability distribution.
        label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
            example, if `0.1`, use `0.1 / num_classes` for non-target labels
            and `0.9 + 0.1 / num_classes` for target labels.
        axis: Defaults to `-1`. The dimension along which the entropy is
            computed.

    Returns:
        Categorical crossentropy loss value.

    Example:

    >>> y_true = [[0, 1, 0], [0, 0, 1]]
    >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
    >>> loss = keras.losses.categorical_crossentropy(y_true, y_pred)
    >>> assert loss.shape == (2,)
    >>> loss
    array([0.0513, 2.303], dtype=float32)
    """
    # `bool` is a subclass of `int`, so it must be rejected explicitly.
    if isinstance(axis, bool):
        raise ValueError(
            "`axis` must be of type `int`. "
            f"Received: axis={axis} of type {type(axis)}"
        )
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, y_pred.dtype)

    # A single-class last axis almost always means the caller wanted
    # binary crossentropy instead.
    if y_pred.shape[-1] == 1:
        warnings.warn(
            "In loss categorical_crossentropy, expected "
            "y_pred.shape to be (batch_size, num_classes) "
            f"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. "
            "Consider using 'binary_crossentropy' if you only have 2 classes.",
            SyntaxWarning,
            stacklevel=2,
        )

    if label_smoothing:
        # Redistribute `label_smoothing` probability mass uniformly.
        num_classes = ops.cast(ops.shape(y_true)[-1], y_pred.dtype)
        smooth_offset = label_smoothing / num_classes
        y_true = y_true * (1.0 - label_smoothing) + smooth_offset

    return ops.categorical_crossentropy(
        y_true, y_pred, from_logits=from_logits, axis=axis
    )
2100
+
2101
+
2102
@keras_export(
    [
        "keras.metrics.categorical_focal_crossentropy",
        "keras.losses.categorical_focal_crossentropy",
    ]
)
def categorical_focal_crossentropy(
    y_true,
    y_pred,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
    label_smoothing=0.0,
    axis=-1,
):
    """Computes the categorical focal crossentropy loss.

    Args:
        y_true: Tensor of one-hot true targets.
        y_pred: Tensor of predicted targets.
        alpha: A weight balancing factor for all classes, default is `0.25` as
            mentioned in the reference. It can be a list of floats or a scalar.
            In the multi-class case, alpha may be set by inverse class
            frequency by using `compute_class_weight` from `sklearn.utils`.
        gamma: A focusing parameter, default is `2.0` as mentioned in the
            reference. It helps to gradually reduce the importance given to
            simple examples in a smooth manner. When `gamma` = 0, there is
            no focal effect on the categorical crossentropy.
        from_logits: Whether `y_pred` is expected to be a logits tensor. By
            default, we assume that `y_pred` encodes a probability
            distribution.
        label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
            example, if `0.1`, use `0.1 / num_classes` for non-target labels
            and `0.9 + 0.1 / num_classes` for target labels.
        axis: Defaults to `-1`. The dimension along which the entropy is
            computed.

    Returns:
        Categorical focal crossentropy loss value.

    Example:

    >>> y_true = [[0, 1, 0], [0, 0, 1]]
    >>> y_pred = [[0.05, 0.9, 0.05], [0.1, 0.85, 0.05]]
    >>> loss = keras.losses.categorical_focal_crossentropy(y_true, y_pred)
    >>> assert loss.shape == (2,)
    >>> loss
    array([2.63401289e-04, 6.75912094e-01], dtype=float32)
    """
    # `bool` is a subclass of `int`, so it must be rejected explicitly.
    if isinstance(axis, bool):
        raise ValueError(
            "`axis` must be of type `int`. "
            f"Received: axis={axis} of type {type(axis)}"
        )
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, y_pred.dtype)

    # A single-class last axis almost always means the caller wanted binary
    # crossentropy instead.
    if y_pred.shape[-1] == 1:
        warnings.warn(
            "In loss categorical_focal_crossentropy, expected "
            "y_pred.shape to be (batch_size, num_classes) "
            f"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. "
            "Consider using 'binary_crossentropy' if you only have 2 classes.",
            SyntaxWarning,
            stacklevel=2,
        )

    if label_smoothing:
        # Redistribute `label_smoothing` probability mass uniformly over all
        # classes.
        num_classes = ops.cast(ops.shape(y_true)[-1], y_pred.dtype)
        y_true = y_true * (1.0 - label_smoothing) + (
            label_smoothing / num_classes
        )

    if from_logits:
        # Convert logits to probabilities before the focal weighting below.
        y_pred = ops.softmax(y_pred, axis=axis)

    # Adjust the predictions so that the probability of
    # each class for every sample adds up to 1
    # This is needed to ensure that the cross entropy is
    # computed correctly.
    output = y_pred / ops.sum(y_pred, axis=axis, keepdims=True)
    output = ops.clip(output, backend.epsilon(), 1.0 - backend.epsilon())

    # Calculate cross entropy
    cce = -y_true * ops.log(output)

    # Calculate factors: `(1 - p)^gamma` down-weights easy (high-confidence)
    # examples; `alpha` balances class frequencies.
    modulating_factor = ops.power(1.0 - output, gamma)
    weighting_factor = ops.multiply(modulating_factor, alpha)

    # Apply weighting factor
    focal_cce = ops.multiply(weighting_factor, cce)
    focal_cce = ops.sum(focal_cce, axis=axis)
    return focal_cce
2196
+
2197
+
2198
@keras_export(
    [
        "keras.metrics.sparse_categorical_crossentropy",
        "keras.losses.sparse_categorical_crossentropy",
    ]
)
def sparse_categorical_crossentropy(
    y_true, y_pred, from_logits=False, ignore_class=None, axis=-1
):
    """Computes the sparse categorical crossentropy loss.

    Args:
        y_true: Ground truth values.
        y_pred: The predicted values.
        from_logits: Whether `y_pred` is expected to be a logits tensor. By
            default, we assume that `y_pred` encodes a probability distribution.
        ignore_class: Optional integer. The ID of a class to be ignored during
            loss computation. This is useful, for example, in segmentation
            problems featuring a "void" class (commonly -1 or 255) in
            segmentation maps. By default (`ignore_class=None`), all classes are
            considered.
        axis: Defaults to `-1`. The dimension along which the entropy is
            computed.

    Returns:
        Sparse categorical crossentropy loss value.

    Examples:

    >>> y_true = [1, 2]
    >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
    >>> loss = keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    >>> assert loss.shape == (2,)
    >>> loss
    array([0.0513, 2.303], dtype=float32)
    """

    # Drop a trailing size-1 channel axis so labels have rank
    # `rank(y_pred) - 1`, as the backend op expects.
    if len(y_true.shape) == len(y_pred.shape) and y_true.shape[-1] == 1:
        y_true = ops.squeeze(y_true, axis=-1)

    if ignore_class is not None:
        # `res_shape` is the per-element loss shape (predictions minus the
        # class axis); it is reused below to reshape the validity mask.
        res_shape = ops.shape(y_pred)[:-1]
        valid_mask = ops.not_equal(y_true, ops.cast(ignore_class, y_pred.dtype))
        # Zero out ignored positions so the crossentropy op sees in-range
        # labels; their loss contribution is masked out afterwards.
        y_true = y_true * ops.cast(valid_mask, y_true.dtype)
        y_pred = y_pred * ops.cast(
            ops.expand_dims(valid_mask, -1), y_pred.dtype
        )

    res = ops.sparse_categorical_crossentropy(
        y_true,
        y_pred,
        from_logits=from_logits,
        axis=axis,
    )

    if ignore_class is not None:
        # Force the loss at ignored positions to exactly zero and attach the
        # mask so downstream reductions can exclude those positions.
        valid_mask = ops.reshape(valid_mask, res_shape)
        res = ops.where(valid_mask, res, 0.0)
        backend.set_keras_mask(res, mask=valid_mask)

    return res
2259
+
2260
+
2261
@keras_export(
    [
        "keras.metrics.binary_crossentropy",
        "keras.losses.binary_crossentropy",
    ]
)
def binary_crossentropy(
    y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1
):
    """Computes the binary crossentropy loss.

    Args:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
        from_logits: Whether `y_pred` is expected to be a logits tensor. By
            default, we assume that `y_pred` encodes a probability
            distribution.
        label_smoothing: Float in `[0, 1]`. If > `0`, labels are squeezed
            towards 0.5: the target class becomes `1. - 0.5 * label_smoothing`
            and the non-target class `0.5 * label_smoothing`.
        axis: The axis along which the mean is computed. Defaults to `-1`.

    Returns:
        Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = [[0, 1], [0, 0]]
    >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
    >>> loss = keras.losses.binary_crossentropy(y_true, y_pred)
    >>> assert loss.shape == (2,)
    >>> loss
    array([0.916 , 0.714], dtype=float32)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, y_pred.dtype)

    if label_smoothing:
        # Interpolate labels towards the 0.5 midpoint.
        y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

    bce = ops.binary_crossentropy(y_true, y_pred, from_logits=from_logits)
    return ops.mean(bce, axis=axis)
2305
+
2306
+
2307
@keras_export(
    [
        "keras.metrics.binary_focal_crossentropy",
        "keras.losses.binary_focal_crossentropy",
    ]
)
def binary_focal_crossentropy(
    y_true,
    y_pred,
    apply_class_balancing=False,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
    label_smoothing=0.0,
    axis=-1,
):
    """Computes the binary focal crossentropy loss.

    According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), a
    focal factor down-weights easy examples so training focuses on hard
    examples. The focal factor is:

    `focal_factor = (1 - output) ** gamma` for class 1
    `focal_factor = output ** gamma` for class 0

    where `gamma` is a focusing parameter; `gamma == 0` recovers plain
    binary crossentropy.

    If `apply_class_balancing == True`, an additional class weight is used:

    `weight = alpha` for class 1 (`target == 1`)
    `weight = 1 - alpha` for class 0

    where `alpha` is a float in `[0, 1]`.

    Args:
        y_true: Ground truth values, of shape `(batch_size, d0, .. dN)`.
        y_pred: The predicted values, of shape `(batch_size, d0, .. dN)`.
        apply_class_balancing: A bool, whether to apply weight balancing on
            the binary classes 0 and 1.
        alpha: A weight balancing factor for class 1, default is `0.25` as
            mentioned in the reference. The weight for class 0 is
            `1.0 - alpha`.
        gamma: A focusing parameter, default is `2.0` as mentioned in the
            reference.
        from_logits: Whether `y_pred` is expected to be a logits tensor. By
            default, we assume that `y_pred` encodes a probability
            distribution.
        label_smoothing: Float in `[0, 1]`. If > `0`, labels are squeezed
            towards 0.5.
        axis: The axis along which the mean is computed. Defaults to `-1`.

    Returns:
        Binary focal crossentropy loss value
        with shape = `[batch_size, d0, .. dN-1]`.

    Example:

    >>> y_true = [[0, 1], [0, 0]]
    >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
    >>> loss = keras.losses.binary_focal_crossentropy(
    ...     y_true, y_pred, gamma=2)
    >>> assert loss.shape == (2,)
    >>> loss
    array([0.330, 0.206], dtype=float32)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, y_pred.dtype)

    if label_smoothing:
        y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

    # Work in probability space: the focal factor needs probabilities, so
    # logits are squashed up front and the BCE below runs with
    # from_logits=False.
    if from_logits:
        y_pred = ops.sigmoid(y_pred)

    bce = ops.binary_crossentropy(
        target=y_true,
        output=y_pred,
        from_logits=False,
    )

    # p_t is the model's probability of the *true* class for each element.
    p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
    focal_bce = ops.power(1.0 - p_t, gamma) * bce

    if apply_class_balancing:
        # alpha for positives, (1 - alpha) for negatives.
        class_weight = y_true * alpha + (1 - y_true) * (1 - alpha)
        focal_bce = class_weight * focal_bce

    return ops.mean(focal_bce, axis=axis)
2398
+
2399
+
2400
@keras_export("keras.losses.ctc")
def ctc(y_true, y_pred):
    """CTC (Connectionist Temporal Classification) loss.

    Args:
        y_true: A tensor of shape `(batch_size, max_length)` containing
            the true labels in integer format. `0` always represents
            the blank/mask index and should not be used for classes.
        y_pred: A tensor of shape `(batch_size, max_length, num_classes)`
            containing logits (the output of your model).
            They should *not* be normalized via softmax.
    """
    # Validate ranks up front with explicit error messages.
    if len(ops.shape(y_true)) != 2:
        raise ValueError(
            "Targets `y_true` are expected to be a tensor of shape "
            "`(batch_size, max_length)` in integer format. "
            f"Received: y_true.shape={ops.shape(y_true)}"
        )
    if len(ops.shape(y_pred)) != 3:
        raise ValueError(
            "Logits `y_pred` are expected to be a tensor of shape "
            "`(batch_size, max_length, num_classes)`. "
            f"Received: y_pred.shape={ops.shape(y_pred)}"
        )

    mask_index = 0
    batch_size = ops.shape(y_pred)[0]
    # Every sample uses the full time dimension as its input length.
    input_length = ops.shape(y_pred)[1] * ops.ones(
        (batch_size,), dtype="int32"
    )
    # Label length = number of non-blank entries per row.
    label_length = ops.cast(
        ops.sum(y_true != mask_index, axis=-1), dtype="int32"
    )

    return ops.ctc_loss(
        y_true, y_pred, label_length, input_length, mask_index=mask_index
    )
2436
+
2437
+
2438
@keras_export("keras.losses.dice")
def dice(y_true, y_pred, axis=None):
    """Computes the Dice loss value between `y_true` and `y_pred`.

    Formula:
    ```python
    loss = 1 - (2 * sum(y_true * y_pred)) / (sum(y_true) + sum(y_pred))
    ```

    Args:
        y_true: tensor of true targets.
        y_pred: tensor of predicted targets.
        axis: tuple for which dimensions the loss is calculated

    Returns:
        Dice loss value.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, y_pred.dtype)

    overlap = ops.sum(y_true * y_pred, axis=axis)
    # Epsilon keeps the ratio finite when both inputs are all-zero.
    total = (
        ops.sum(y_true, axis=axis)
        + ops.sum(y_pred, axis=axis)
        + backend.epsilon()
    )
    dice_coeff = ops.divide(2.0 * overlap, total)
    return 1 - dice_coeff
2470
+
2471
+
2472
@keras_export("keras.losses.tversky")
def tversky(y_true, y_pred, alpha=0.5, beta=0.5, axis=None):
    """Computes the Tversky loss value between `y_true` and `y_pred`.

    The Tversky index generalizes the Dice coefficient by weighting the two
    kinds of mismatches separately via `alpha` and `beta`. With `alpha=0.5`
    and `beta=0.5`, the loss value becomes equivalent to Dice Loss.

    Args:
        y_true: tensor of true targets.
        y_pred: tensor of predicted targets.
        alpha: coefficient controlling incidence of false positives.
        beta: coefficient controlling incidence of false negatives.
        axis: tuple for which dimensions the loss is calculated.

    Returns:
        Tversky loss value.

    Reference:

    - [Salehi et al., 2017](https://arxiv.org/abs/1706.05721)
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, y_pred.dtype)

    true_pos = ops.sum(y_true * y_pred, axis=axis)
    # NOTE(review): the alpha/beta pairing below reproduces the reference
    # implementation exactly. The term weighted by `alpha` is
    # sum((1 - y_pred) * y_true), which under the usual naming is the
    # false-negative mass -- confirm the intended semantics before relying
    # on asymmetric alpha/beta values.
    term_alpha = ops.sum((1 - y_pred) * y_true, axis=axis)
    term_beta = ops.sum(y_pred * (1 - y_true), axis=axis)

    tversky_index = ops.divide(
        true_pos,
        true_pos + term_alpha * alpha + term_beta * beta + backend.epsilon(),
    )
    return 1 - tversky_index
2512
+
2513
+
2514
@keras_export("keras.losses.circle")
def circle(
    y_true,
    y_pred,
    ref_labels=None,
    ref_embeddings=None,
    remove_diagonal=True,
    gamma=80,
    margin=0.4,
):
    """Computes the Circle loss.

    It is designed to minimize within-class distances and maximize
    between-class distances in L2 normalized embedding space.

    Args:
        y_true: Tensor with ground truth labels in integer format.
        y_pred: Tensor with predicted L2 normalized embeddings.
        ref_labels: Optional integer tensor with labels for reference
            embeddings. If `None`, defaults to `y_true`.
        ref_embeddings: Optional tensor with L2 normalized reference
            embeddings. If `None`, defaults to `y_pred`.
        remove_diagonal: Boolean, whether to remove self-similarities from
            the positive mask. Defaults to `True`.
        gamma: Float, scaling factor for the loss. Defaults to `80`.
        margin: Float, relaxation factor for the loss. Defaults to `0.4`.

    Returns:
        Circle loss value.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, "int32")
    if ref_embeddings is None:
        ref_embeddings = y_pred
    else:
        ref_embeddings = ops.convert_to_tensor(ref_embeddings)
    if ref_labels is None:
        ref_labels = y_true
    else:
        ref_labels = ops.cast(ref_labels, "int32")

    # Optima and margins from the Circle loss formulation.
    optim_pos = margin
    optim_neg = 1 + margin
    delta_pos = margin
    delta_neg = 1 - margin

    # Cosine distance = 1 - cosine similarity (embeddings are assumed
    # L2-normalized), clamped to be non-negative.
    cosine_dist = 1 - ops.matmul(y_pred, ops.transpose(ref_embeddings))
    cosine_dist = ops.maximum(cosine_dist, 0.0)

    positive_mask, negative_mask = build_pos_neg_masks(
        y_true,
        ref_labels,
        remove_diagonal=remove_diagonal,
    )
    positive_mask = ops.cast(positive_mask, dtype=cosine_dist.dtype)
    negative_mask = ops.cast(negative_mask, dtype=cosine_dist.dtype)

    # Self-paced weights: harder pairs get larger weights, clipped at 0.
    pos_weights = ops.maximum((optim_pos + cosine_dist) * positive_mask, 0.0)
    neg_weights = ops.maximum((optim_neg - cosine_dist) * negative_mask, 0.0)

    pos_margin_dists = delta_pos - cosine_dist
    neg_margin_dists = delta_neg - cosine_dist

    pos_logits = -1 * gamma * pos_weights * pos_margin_dists
    neg_logits = gamma * neg_weights * neg_margin_dists

    # logsumexp over masked logits; -inf entries drop out of the sum.
    p_loss = ops.logsumexp(
        ops.where(positive_mask, pos_logits, float("-inf")),
        axis=1,
    )
    n_loss = ops.logsumexp(
        ops.where(negative_mask, neg_logits, float("-inf")),
        axis=1,
    )

    circle_loss = ops.softplus(p_loss + n_loss)
    # Mask out rows that contributed no loss.
    backend.set_keras_mask(circle_loss, circle_loss > 0)
    return circle_loss
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__init__.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+
3
+ from keras.src.api_export import keras_export
4
+ from keras.src.metrics.accuracy_metrics import Accuracy
5
+ from keras.src.metrics.accuracy_metrics import BinaryAccuracy
6
+ from keras.src.metrics.accuracy_metrics import CategoricalAccuracy
7
+ from keras.src.metrics.accuracy_metrics import SparseCategoricalAccuracy
8
+ from keras.src.metrics.accuracy_metrics import SparseTopKCategoricalAccuracy
9
+ from keras.src.metrics.accuracy_metrics import TopKCategoricalAccuracy
10
+ from keras.src.metrics.confusion_metrics import AUC
11
+ from keras.src.metrics.confusion_metrics import FalseNegatives
12
+ from keras.src.metrics.confusion_metrics import FalsePositives
13
+ from keras.src.metrics.confusion_metrics import Precision
14
+ from keras.src.metrics.confusion_metrics import PrecisionAtRecall
15
+ from keras.src.metrics.confusion_metrics import Recall
16
+ from keras.src.metrics.confusion_metrics import RecallAtPrecision
17
+ from keras.src.metrics.confusion_metrics import SensitivityAtSpecificity
18
+ from keras.src.metrics.confusion_metrics import SpecificityAtSensitivity
19
+ from keras.src.metrics.confusion_metrics import TrueNegatives
20
+ from keras.src.metrics.confusion_metrics import TruePositives
21
+ from keras.src.metrics.correlation_metrics import ConcordanceCorrelation
22
+ from keras.src.metrics.correlation_metrics import PearsonCorrelation
23
+ from keras.src.metrics.f_score_metrics import F1Score
24
+ from keras.src.metrics.f_score_metrics import FBetaScore
25
+ from keras.src.metrics.hinge_metrics import CategoricalHinge
26
+ from keras.src.metrics.hinge_metrics import Hinge
27
+ from keras.src.metrics.hinge_metrics import SquaredHinge
28
+ from keras.src.metrics.iou_metrics import BinaryIoU
29
+ from keras.src.metrics.iou_metrics import IoU
30
+ from keras.src.metrics.iou_metrics import MeanIoU
31
+ from keras.src.metrics.iou_metrics import OneHotIoU
32
+ from keras.src.metrics.iou_metrics import OneHotMeanIoU
33
+ from keras.src.metrics.metric import Metric
34
+ from keras.src.metrics.probabilistic_metrics import BinaryCrossentropy
35
+ from keras.src.metrics.probabilistic_metrics import CategoricalCrossentropy
36
+ from keras.src.metrics.probabilistic_metrics import KLDivergence
37
+ from keras.src.metrics.probabilistic_metrics import Poisson
38
+ from keras.src.metrics.probabilistic_metrics import (
39
+ SparseCategoricalCrossentropy,
40
+ )
41
+ from keras.src.metrics.reduction_metrics import Mean
42
+ from keras.src.metrics.reduction_metrics import MeanMetricWrapper
43
+ from keras.src.metrics.reduction_metrics import Sum
44
+ from keras.src.metrics.regression_metrics import CosineSimilarity
45
+ from keras.src.metrics.regression_metrics import LogCoshError
46
+ from keras.src.metrics.regression_metrics import MeanAbsoluteError
47
+ from keras.src.metrics.regression_metrics import MeanAbsolutePercentageError
48
+ from keras.src.metrics.regression_metrics import MeanSquaredError
49
+ from keras.src.metrics.regression_metrics import MeanSquaredLogarithmicError
50
+ from keras.src.metrics.regression_metrics import R2Score
51
+ from keras.src.metrics.regression_metrics import RootMeanSquaredError
52
+ from keras.src.saving import serialization_lib
53
+ from keras.src.utils.naming import to_snake_case
54
+
55
# Registry of all built-in metric classes; consumed by `get()`,
# `serialize()` and `deserialize()` below.
ALL_OBJECTS = {
    # Base
    Metric,
    Mean,
    Sum,
    MeanMetricWrapper,
    # Regression
    MeanSquaredError,
    RootMeanSquaredError,
    MeanAbsoluteError,
    MeanAbsolutePercentageError,
    MeanSquaredLogarithmicError,
    CosineSimilarity,
    LogCoshError,
    R2Score,
    # Classification
    AUC,
    FalseNegatives,
    FalsePositives,
    Precision,
    PrecisionAtRecall,
    Recall,
    RecallAtPrecision,
    SensitivityAtSpecificity,
    SpecificityAtSensitivity,
    TrueNegatives,
    TruePositives,
    # Correlation
    ConcordanceCorrelation,
    PearsonCorrelation,
    # Hinge
    Hinge,
    SquaredHinge,
    CategoricalHinge,
    # Probabilistic
    KLDivergence,
    Poisson,
    BinaryCrossentropy,
    CategoricalCrossentropy,
    SparseCategoricalCrossentropy,
    # Accuracy
    Accuracy,
    BinaryAccuracy,
    CategoricalAccuracy,
    SparseCategoricalAccuracy,
    TopKCategoricalAccuracy,
    SparseTopKCategoricalAccuracy,
    # F-Score
    F1Score,
    FBetaScore,
    # IoU
    IoU,
    BinaryIoU,
    MeanIoU,
    OneHotIoU,
    OneHotMeanIoU,
}
# Each metric is retrievable both by class name ("Accuracy") and by its
# snake_case alias ("accuracy").
ALL_OBJECTS_DICT = {
    name: cls
    for cls in ALL_OBJECTS
    for name in (cls.__name__, to_snake_case(cls.__name__))
}
# TODO: Align with `tf.keras` and set the name attribute of metrics
# with the key name. Currently it uses default name of class definitions.
ALL_OBJECTS_DICT.update(
    {
        "bce": BinaryCrossentropy,
        "BCE": BinaryCrossentropy,
        "mse": MeanSquaredError,
        "MSE": MeanSquaredError,
        "mae": MeanAbsoluteError,
        "MAE": MeanAbsoluteError,
        "mape": MeanAbsolutePercentageError,
        "MAPE": MeanAbsolutePercentageError,
        "msle": MeanSquaredLogarithmicError,
        "MSLE": MeanSquaredLogarithmicError,
    }
)
132
+
133
+
134
@keras_export("keras.metrics.serialize")
def serialize(metric):
    """Serializes a metric function or `Metric` instance.

    Args:
        metric: A Keras `Metric` instance or a metric function.

    Returns:
        Metric configuration dictionary.
    """
    # Delegates entirely to the generic Keras serialization machinery.
    return serialization_lib.serialize_keras_object(metric)
145
+
146
+
147
@keras_export("keras.metrics.deserialize")
def deserialize(config, custom_objects=None):
    """Deserializes a serialized metric class/function instance.

    Args:
        config: Metric configuration.
        custom_objects: Optional dictionary mapping names (strings)
            to custom objects (classes and functions) to be
            considered during deserialization.

    Returns:
        A Keras `Metric` instance or a metric function.
    """
    # Built-in metrics are resolved through ALL_OBJECTS_DICT; user-supplied
    # classes/functions come in through `custom_objects`.
    return serialization_lib.deserialize_keras_object(
        config,
        module_objects=ALL_OBJECTS_DICT,
        custom_objects=custom_objects,
    )
165
+
166
+
167
@keras_export("keras.metrics.get")
def get(identifier):
    """Retrieves a Keras metric as a `function`/`Metric` class instance.

    The `identifier` may be the string name of a metric function or class.

    >>> metric = metrics.get("categorical_crossentropy")
    >>> type(metric)
    <class 'function'>
    >>> metric = metrics.get("CategoricalCrossentropy")
    >>> type(metric)
    <class '...metrics.CategoricalCrossentropy'>

    You can also specify `config` of the metric to this function by passing
    dict containing `class_name` and `config` as an identifier. Also note
    that the `class_name` must map to a `Metric` class

    >>> identifier = {"class_name": "CategoricalCrossentropy",
    ...               "config": {"from_logits": True}}
    >>> metric = metrics.get(identifier)
    >>> type(metric)
    <class '...metrics.CategoricalCrossentropy'>

    Args:
        identifier: A metric identifier. One of None or string name of a
            metric function/class or metric configuration dictionary or a
            metric function or a metric class instance

    Returns:
        A Keras metric as a `function`/ `Metric` class instance.
    """
    if identifier is None:
        return None

    # Resolve dicts via full deserialization, strings via the registry,
    # anything else is assumed to already be a metric object.
    if isinstance(identifier, dict):
        obj = deserialize(identifier)
    elif isinstance(identifier, str):
        obj = ALL_OBJECTS_DICT.get(identifier, None)
    else:
        obj = identifier

    if not callable(obj):
        raise ValueError(f"Could not interpret metric identifier: {identifier}")
    # Bare classes are instantiated with default arguments.
    if inspect.isclass(obj):
        obj = obj()
    return obj
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (5.7 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/accuracy_metrics.cpython-310.pyc ADDED
Binary file (16 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/confusion_metrics.cpython-310.pyc ADDED
Binary file (48.7 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/correlation_metrics.cpython-310.pyc ADDED
Binary file (6.58 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/f_score_metrics.cpython-310.pyc ADDED
Binary file (9.84 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/hinge_metrics.cpython-310.pyc ADDED
Binary file (3.64 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/iou_metrics.cpython-310.pyc ADDED
Binary file (24.6 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/metric.cpython-310.pyc ADDED
Binary file (8.74 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/metrics_utils.cpython-310.pyc ADDED
Binary file (19.6 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/probabilistic_metrics.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/reduction_metrics.cpython-310.pyc ADDED
Binary file (6.78 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/__pycache__/regression_metrics.cpython-310.pyc ADDED
Binary file (17.3 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/accuracy_metrics.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import backend
2
+ from keras.src import ops
3
+ from keras.src.api_export import keras_export
4
+ from keras.src.losses.loss import squeeze_or_expand_to_same_rank
5
+ from keras.src.metrics import reduction_metrics
6
+
7
+
8
def accuracy(y_true, y_pred):
    """Element-wise exact-match indicator (1.0 where equal, else 0.0)."""
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    matches = ops.equal(y_true, y_pred)
    return ops.cast(matches, dtype=backend.floatx())
13
+
14
+
15
@keras_export("keras.metrics.Accuracy")
class Accuracy(reduction_metrics.MeanMetricWrapper):
    """Calculates how often predictions equal labels.

    This metric creates two local variables, `total` and `count` that are
    used to compute the frequency with which `y_pred` matches `y_true`. This
    frequency is ultimately returned as `binary accuracy`: an idempotent
    operation that simply divides `total` by `count`.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Examples:

    >>> m = keras.metrics.Accuracy()
    >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]])
    >>> m.result()
    0.75

    >>> m.reset_state()
    >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]],
    ...                sample_weight=[1, 1, 0, 0])
    >>> m.result()
    0.5

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='binary_crossentropy',
                  metrics=[keras.metrics.Accuracy()])
    ```
    """

    def __init__(self, name="accuracy", dtype=None):
        super().__init__(fn=accuracy, name=name, dtype=dtype)
        # Higher is better: tells tuners/early-stopping to maximize.
        self._direction = "up"

    def get_config(self):
        return dict(name=self.name, dtype=self.dtype)
+ return {"name": self.name, "dtype": self.dtype}
60
+
61
+
62
@keras_export("keras.metrics.binary_accuracy")
def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Element-wise binary accuracy after thresholding `y_pred`."""
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)
    # Binarize predictions: strictly greater than `threshold` counts as 1.
    threshold = ops.cast(threshold, y_pred.dtype)
    y_pred = ops.cast(y_pred > threshold, y_true.dtype)
    return ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx())
70
+
71
+
72
@keras_export("keras.metrics.BinaryAccuracy")
class BinaryAccuracy(reduction_metrics.MeanMetricWrapper):
    """Calculates how often predictions match binary labels.

    This metric creates two local variables, `total` and `count` that are
    used to compute the frequency with which `y_pred` matches `y_true`. This
    frequency is ultimately returned as `binary accuracy`: an idempotent
    operation that simply divides `total` by `count`.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        threshold: (Optional) Float representing the threshold for deciding
            whether prediction values are 1 or 0.

    Example:

    >>> m = keras.metrics.BinaryAccuracy()
    >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]])
    >>> m.result()
    0.75

    >>> m.reset_state()
    >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]],
    ...                sample_weight=[1, 0, 0, 1])
    >>> m.result()
    0.5

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='binary_crossentropy',
                  metrics=[keras.metrics.BinaryAccuracy()])
    ```
    """

    def __init__(self, name="binary_accuracy", dtype=None, threshold=0.5):
        super().__init__(
            fn=binary_accuracy, name=name, dtype=dtype, threshold=threshold
        )
        # Kept as an attribute so get_config() can round-trip it.
        self.threshold = threshold
        # Higher is better: tells tuners/early-stopping to maximize.
        self._direction = "up"

    def get_config(self):
        return dict(
            name=self.name,
            dtype=self.dtype,
            threshold=self.threshold,
        )
+ }
126
+
127
+
128
@keras_export("keras.metrics.categorical_accuracy")
def categorical_accuracy(y_true, y_pred):
    """Element-wise accuracy for one-hot labels vs. predicted scores.

    Both `y_true` (one-hot) and `y_pred` (probabilities or logits) are
    reduced to class indices via argmax before comparison.
    """
    # Collapse one-hot labels to class indices first.
    y_true = ops.argmax(y_true, axis=-1)

    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)

    y_true_org_shape = ops.shape(y_true)
    y_pred_rank = len(y_pred.shape)
    y_true_rank = len(y_true.shape)

    # After the argmax above, y_true has one fewer axis than y_pred. Equal
    # ranks here mean y_true carried a trailing singleton axis, e.g.
    # (num_samples, 1) -> squeeze to (num_samples,).
    reshape_matches = y_true_rank == y_pred_rank
    if reshape_matches:
        y_true = ops.squeeze(y_true, -1)
    y_pred = ops.argmax(y_pred, axis=-1)

    # Argmax dtypes may differ across backends; align them before comparing.
    if y_pred.dtype is not y_true.dtype:
        y_pred = ops.cast(y_pred, dtype=y_true.dtype)
    matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
    if reshape_matches:
        matches = ops.reshape(matches, y_true_org_shape)
    return matches
+ return matches
158
+
159
+
160
@keras_export("keras.metrics.CategoricalAccuracy")
class CategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
    """Calculates how often predictions match one-hot labels.

    You can provide logits of classes as `y_pred`, since argmax of
    logits and probabilities are same.

    This metric creates two local variables, `total` and `count` that are
    used to compute the frequency with which `y_pred` matches `y_true`. This
    frequency is ultimately returned as `categorical accuracy`: an
    idempotent operation that simply divides `total` by `count`.

    `y_pred` and `y_true` should be passed in as vectors of probabilities,
    rather than as labels. If necessary, use `ops.one_hot` to expand
    `y_true` as a vector.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.CategoricalAccuracy()
    >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
    ...                 [0.05, 0.95, 0]])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
    ...                 [0.05, 0.95, 0]],
    ...                sample_weight=[0.7, 0.3])
    >>> m.result()
    0.3

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='categorical_crossentropy',
                  metrics=[keras.metrics.CategoricalAccuracy()])
    ```
    """

    def __init__(self, name="categorical_accuracy", dtype=None):
        super().__init__(fn=categorical_accuracy, name=name, dtype=dtype)
        # Higher is better: tells tuners/early-stopping to maximize.
        self._direction = "up"

    def get_config(self):
        return dict(name=self.name, dtype=self.dtype)
+ return {"name": self.name, "dtype": self.dtype}
214
+
215
+
216
+ @keras_export("keras.metrics.sparse_categorical_accuracy")
217
+ def sparse_categorical_accuracy(y_true, y_pred):
218
+ reshape_matches = False
219
+ y_pred = ops.convert_to_tensor(y_pred)
220
+ y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
221
+ y_true_org_shape = ops.shape(y_true)
222
+ y_pred_rank = len(y_pred.shape)
223
+ y_true_rank = len(y_true.shape)
224
+
225
+ # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
226
+ if (
227
+ (y_true_rank is not None)
228
+ and (y_pred_rank is not None)
229
+ and (len(y_true.shape) == len(y_pred.shape))
230
+ and ops.shape(y_true)[-1] == 1
231
+ ):
232
+ y_true = ops.squeeze(y_true, -1)
233
+ reshape_matches = True
234
+ y_pred = ops.argmax(y_pred, axis=-1)
235
+
236
+ # If the predicted output and actual output types don't match, force cast
237
+ # them to match.
238
+ if y_pred.dtype is not y_true.dtype:
239
+ y_pred = ops.cast(y_pred, y_true.dtype)
240
+ matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
241
+ if reshape_matches:
242
+ matches = ops.reshape(matches, y_true_org_shape)
243
+ # if shape is (num_samples, 1) squeeze
244
+ if len(matches.shape) > 1 and matches.shape[-1] == 1:
245
+ matches = ops.squeeze(matches, -1)
246
+ return matches
247
+
248
+
249
+ @keras_export("keras.metrics.SparseCategoricalAccuracy")
250
+ class SparseCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
251
+ """Calculates how often predictions match integer labels.
252
+
253
+ ```python
254
+ acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1))
255
+ ```
256
+
257
+ You can provide logits of classes as `y_pred`, since argmax of
258
+ logits and probabilities are same.
259
+
260
+ This metric creates two local variables, `total` and `count` that are used
261
+ to compute the frequency with which `y_pred` matches `y_true`. This
262
+ frequency is ultimately returned as `sparse categorical accuracy`: an
263
+ idempotent operation that simply divides `total` by `count`.
264
+
265
+ If `sample_weight` is `None`, weights default to 1.
266
+ Use `sample_weight` of 0 to mask values.
267
+
268
+ Args:
269
+ name: (Optional) string name of the metric instance.
270
+ dtype: (Optional) data type of the metric result.
271
+
272
+ Example:
273
+
274
+ >>> m = keras.metrics.SparseCategoricalAccuracy()
275
+ >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]])
276
+ >>> m.result()
277
+ 0.5
278
+
279
+ >>> m.reset_state()
280
+ >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]],
281
+ ... sample_weight=[0.7, 0.3])
282
+ >>> m.result()
283
+ 0.3
284
+
285
+ Usage with `compile()` API:
286
+
287
+ ```python
288
+ model.compile(optimizer='sgd',
289
+ loss='sparse_categorical_crossentropy',
290
+ metrics=[keras.metrics.SparseCategoricalAccuracy()])
291
+ ```
292
+ """
293
+
294
+ def __init__(self, name="sparse_categorical_accuracy", dtype=None):
295
+ super().__init__(fn=sparse_categorical_accuracy, name=name, dtype=dtype)
296
+ # Metric should be maximized during optimization.
297
+ self._direction = "up"
298
+
299
+ def get_config(self):
300
+ return {"name": self.name, "dtype": self.dtype}
301
+
302
+
303
+ @keras_export("keras.metrics.top_k_categorical_accuracy")
304
+ def top_k_categorical_accuracy(y_true, y_pred, k=5):
305
+ reshape_matches = False
306
+ y_pred = ops.convert_to_tensor(y_pred)
307
+ y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
308
+ y_true = ops.argmax(y_true, axis=-1)
309
+ y_true_rank = len(y_true.shape)
310
+ y_pred_rank = len(y_pred.shape)
311
+ y_true_org_shape = ops.shape(y_true)
312
+
313
+ # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,)
314
+ if (y_true_rank is not None) and (y_pred_rank is not None):
315
+ if y_pred_rank > 2:
316
+ y_pred = ops.reshape(y_pred, [-1, y_pred.shape[-1]])
317
+ if y_true_rank > 1:
318
+ reshape_matches = True
319
+ y_true = ops.reshape(y_true, [-1])
320
+
321
+ matches = ops.cast(
322
+ ops.in_top_k(ops.cast(y_true, "int32"), y_pred, k=k),
323
+ dtype=backend.floatx(),
324
+ )
325
+
326
+ # returned matches is expected to have same shape as y_true input
327
+ if reshape_matches:
328
+ matches = ops.reshape(matches, y_true_org_shape)
329
+
330
+ return matches
331
+
332
+
333
+ @keras_export("keras.metrics.TopKCategoricalAccuracy")
334
+ class TopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
335
+ """Computes how often targets are in the top `K` predictions.
336
+
337
+ Args:
338
+ k: (Optional) Number of top elements to look at for computing accuracy.
339
+ Defaults to `5`.
340
+ name: (Optional) string name of the metric instance.
341
+ dtype: (Optional) data type of the metric result.
342
+
343
+ Example:
344
+
345
+ >>> m = keras.metrics.TopKCategoricalAccuracy(k=1)
346
+ >>> m.update_state([[0, 0, 1], [0, 1, 0]],
347
+ ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
348
+ >>> m.result()
349
+ 0.5
350
+
351
+ >>> m.reset_state()
352
+ >>> m.update_state([[0, 0, 1], [0, 1, 0]],
353
+ ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
354
+ ... sample_weight=[0.7, 0.3])
355
+ >>> m.result()
356
+ 0.3
357
+
358
+ Usage with `compile()` API:
359
+
360
+ ```python
361
+ model.compile(optimizer='sgd',
362
+ loss='categorical_crossentropy',
363
+ metrics=[keras.metrics.TopKCategoricalAccuracy()])
364
+ ```
365
+ """
366
+
367
+ def __init__(self, k=5, name="top_k_categorical_accuracy", dtype=None):
368
+ super().__init__(
369
+ fn=top_k_categorical_accuracy,
370
+ name=name,
371
+ dtype=dtype,
372
+ k=k,
373
+ )
374
+ self.k = k
375
+ # Metric should be maximized during optimization.
376
+ self._direction = "up"
377
+
378
+ def get_config(self):
379
+ return {"name": self.name, "dtype": self.dtype, "k": self.k}
380
+
381
+
382
+ @keras_export("keras.metrics.sparse_top_k_categorical_accuracy")
383
+ def sparse_top_k_categorical_accuracy(
384
+ y_true, y_pred, k=5, from_sorted_ids=False
385
+ ):
386
+ """Computes how often integer targets are in the top `K` predictions.
387
+
388
+ Args:
389
+ y_true: A tensor of shape `(batch_size)` representing indices or IDs of
390
+ true categories.
391
+ y_pred: If `from_sorted_ids=False`, a tensor of shape
392
+ `(batch_size, num_categories)` containing the scores for each sample
393
+ for all possible categories. If `from_sorted_ids=True`, a tensor of
394
+ shape `(batch_size, N)` containing indices or IDs of the top `N`
395
+ categories in order from highest score to lowest score.
396
+ k: (Optional) Number of top elements to look at for computing accuracy.
397
+ Defaults to `5`.
398
+ from_sorted_ids: (Optional) Whether `y_pred` is sorted category IDs or
399
+ scores for all categories (the default).
400
+
401
+ Returns:
402
+ A tensor with the same shape as `y_true` containing ones where `y_true`
403
+ is in the top `k` and zeros elsewhere.
404
+ """
405
+ reshape_matches = False
406
+ y_pred = ops.convert_to_tensor(y_pred)
407
+ y_true_dtype = y_pred.dtype if from_sorted_ids else "int32"
408
+ y_true = ops.convert_to_tensor(y_true, dtype=y_true_dtype)
409
+ y_true_rank = len(y_true.shape)
410
+ y_pred_rank = len(y_pred.shape)
411
+ y_true_org_shape = ops.shape(y_true)
412
+
413
+ # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,)
414
+ if (y_true_rank is not None) and (y_pred_rank is not None):
415
+ if y_pred_rank > 2:
416
+ y_pred = ops.reshape(y_pred, [-1, y_pred.shape[-1]])
417
+ if y_true_rank > 1:
418
+ reshape_matches = True
419
+ y_true = ops.reshape(y_true, [-1])
420
+
421
+ if from_sorted_ids:
422
+ # By slicing the first k items, we assume they are sorted by score.
423
+ # Reduce with `any` to count multiple matches only once.
424
+ matches = ops.any(
425
+ ops.equal(ops.expand_dims(y_true, axis=1), y_pred[:, :k]), axis=1
426
+ )
427
+ else:
428
+ matches = ops.in_top_k(y_true, y_pred, k=k)
429
+
430
+ matches = ops.cast(matches, dtype=backend.floatx())
431
+
432
+ # returned matches is expected to have same shape as y_true input
433
+ if reshape_matches:
434
+ matches = ops.reshape(matches, y_true_org_shape)
435
+
436
+ return matches
437
+
438
+
439
+ @keras_export("keras.metrics.SparseTopKCategoricalAccuracy")
440
+ class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
441
+ """Computes how often integer targets are in the top `K` predictions.
442
+
443
+ By default, the arguments expected by `update_state()` are:
444
+ - `y_true`: a tensor of shape `(batch_size)` representing indices of true
445
+ categories.
446
+ - `y_pred`: a tensor of shape `(batch_size, num_categories)` containing the
447
+ scores for each sample for all possible categories.
448
+
449
+ With `from_sorted_ids=True`, the arguments expected by `update_state` are:
450
+ - `y_true`: a tensor of shape `(batch_size)` representing indices or IDs of
451
+ true categories.
452
+ - `y_pred`: a tensor of shape `(batch_size, N)` containing the indices or
453
+ IDs of the top `N` categories sorted in order from highest score to
454
+ lowest score. `N` must be greater or equal to `k`.
455
+
456
+ The `from_sorted_ids=True` option can be more efficient when the set of
457
+ categories is very large and the model has an optimized way to retrieve the
458
+ top ones either without scoring or without maintaining the scores for all
459
+ the possible categories.
460
+
461
+ Args:
462
+ k: (Optional) Number of top elements to look at for computing accuracy.
463
+ Defaults to `5`.
464
+ name: (Optional) string name of the metric instance.
465
+ dtype: (Optional) data type of the metric result.
466
+ from_sorted_ids: (Optional) When `False`, the default, the tensor passed
467
+ in `y_pred` contains the unsorted scores of all possible categories.
468
+ When `True`, `y_pred` contains a the indices or IDs for the top
469
+ categories.
470
+
471
+ Example:
472
+
473
+ >>> m = keras.metrics.SparseTopKCategoricalAccuracy(k=1)
474
+ >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
475
+ >>> m.result()
476
+ 0.5
477
+
478
+ >>> m.reset_state()
479
+ >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
480
+ ... sample_weight=[0.7, 0.3])
481
+ >>> m.result()
482
+ 0.3
483
+
484
+ >>> m = keras.metrics.SparseTopKCategoricalAccuracy(k=1,
485
+ ... from_sorted_ids=True)
486
+ >>> m.update_state([2, 1], [[1, 0, 3], [1, 2, 3]])
487
+ >>> m.result()
488
+ 0.5
489
+
490
+ Usage with `compile()` API:
491
+
492
+ ```python
493
+ model.compile(optimizer='sgd',
494
+ loss='sparse_categorical_crossentropy',
495
+ metrics=[keras.metrics.SparseTopKCategoricalAccuracy()])
496
+ ```
497
+ """
498
+
499
+ def __init__(
500
+ self,
501
+ k=5,
502
+ name="sparse_top_k_categorical_accuracy",
503
+ dtype=None,
504
+ from_sorted_ids=False,
505
+ ):
506
+ super().__init__(
507
+ fn=sparse_top_k_categorical_accuracy,
508
+ name=name,
509
+ dtype=dtype,
510
+ k=k,
511
+ from_sorted_ids=from_sorted_ids,
512
+ )
513
+ self.k = k
514
+ self.from_sorted_ids = from_sorted_ids
515
+ # Metric should be maximized during optimization.
516
+ self._direction = "up"
517
+
518
+ def get_config(self):
519
+ config = {"name": self.name, "dtype": self.dtype, "k": self.k}
520
+ if self.from_sorted_ids:
521
+ config["from_sorted_ids"] = True
522
+ return config
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/confusion_metrics.py ADDED
@@ -0,0 +1,1576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from keras.src import activations
4
+ from keras.src import backend
5
+ from keras.src import initializers
6
+ from keras.src import ops
7
+ from keras.src.api_export import keras_export
8
+ from keras.src.metrics import metrics_utils
9
+ from keras.src.metrics.metric import Metric
10
+ from keras.src.utils.python_utils import to_list
11
+
12
+
13
+ class _ConfusionMatrixConditionCount(Metric):
14
+ """Calculates the number of the given confusion matrix condition.
15
+
16
+ Args:
17
+ confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix`
18
+ conditions.
19
+ thresholds: (Optional) Defaults to `0.5`. A float value or a python list
20
+ / tuple of float threshold values in `[0, 1]`. A threshold is
21
+ compared with prediction values to determine the truth value of
22
+ predictions (i.e., above the threshold is `True`, below is `False`).
23
+ One metric value is generated for each threshold value.
24
+ name: (Optional) string name of the metric instance.
25
+ dtype: (Optional) data type of the metric result.
26
+ """
27
+
28
+ def __init__(
29
+ self, confusion_matrix_cond, thresholds=None, name=None, dtype=None
30
+ ):
31
+ super().__init__(name=name, dtype=dtype)
32
+ self._confusion_matrix_cond = confusion_matrix_cond
33
+ self.init_thresholds = thresholds
34
+ self.thresholds = metrics_utils.parse_init_thresholds(
35
+ thresholds, default_threshold=0.5
36
+ )
37
+ self._thresholds_distributed_evenly = (
38
+ metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
39
+ )
40
+ self.accumulator = self.add_variable(
41
+ shape=(len(self.thresholds),),
42
+ initializer=initializers.Zeros(),
43
+ name="accumulator",
44
+ )
45
+
46
+ def update_state(self, y_true, y_pred, sample_weight=None):
47
+ """Accumulates the metric statistics.
48
+
49
+ Args:
50
+ y_true: The ground truth values.
51
+ y_pred: The predicted values.
52
+ sample_weight: Optional weighting of each example. Defaults to `1`.
53
+ Can be a tensor whose rank is either 0, or the same rank as
54
+ `y_true`, and must be broadcastable to `y_true`.
55
+ """
56
+ return metrics_utils.update_confusion_matrix_variables(
57
+ {self._confusion_matrix_cond: self.accumulator},
58
+ y_true,
59
+ y_pred,
60
+ thresholds=self.thresholds,
61
+ thresholds_distributed_evenly=self._thresholds_distributed_evenly,
62
+ sample_weight=sample_weight,
63
+ )
64
+
65
+ def result(self):
66
+ if len(self.thresholds) == 1:
67
+ result = self.accumulator[0]
68
+ else:
69
+ result = self.accumulator
70
+ return backend.convert_to_tensor(result)
71
+
72
+ def get_config(self):
73
+ config = {"thresholds": self.init_thresholds}
74
+ base_config = super().get_config()
75
+ return {**base_config, **config}
76
+
77
+
78
+ @keras_export("keras.metrics.FalsePositives")
79
+ class FalsePositives(_ConfusionMatrixConditionCount):
80
+ """Calculates the number of false positives.
81
+
82
+ If `sample_weight` is given, calculates the sum of the weights of
83
+ false positives. This metric creates one local variable, `accumulator`
84
+ that is used to keep track of the number of false positives.
85
+
86
+ If `sample_weight` is `None`, weights default to 1.
87
+ Use `sample_weight` of 0 to mask values.
88
+
89
+ Args:
90
+ thresholds: (Optional) Defaults to `0.5`. A float value, or a Python
91
+ list/tuple of float threshold values in `[0, 1]`. A threshold is
92
+ compared with prediction values to determine the truth value of
93
+ predictions (i.e., above the threshold is `True`, below is `False`).
94
+ If used with a loss function that sets `from_logits=True` (i.e. no
95
+ sigmoid applied to predictions), `thresholds` should be set to 0.
96
+ One metric value is generated for each threshold value.
97
+ name: (Optional) string name of the metric instance.
98
+ dtype: (Optional) data type of the metric result.
99
+
100
+ Examples:
101
+
102
+ >>> m = keras.metrics.FalsePositives()
103
+ >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1])
104
+ >>> m.result()
105
+ 2.0
106
+
107
+ >>> m.reset_state()
108
+ >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0])
109
+ >>> m.result()
110
+ 1.0
111
+ """
112
+
113
+ def __init__(self, thresholds=None, name=None, dtype=None):
114
+ super().__init__(
115
+ confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES,
116
+ thresholds=thresholds,
117
+ name=name,
118
+ dtype=dtype,
119
+ )
120
+
121
+
122
+ @keras_export("keras.metrics.FalseNegatives")
123
+ class FalseNegatives(_ConfusionMatrixConditionCount):
124
+ """Calculates the number of false negatives.
125
+
126
+ If `sample_weight` is given, calculates the sum of the weights of
127
+ false negatives. This metric creates one local variable, `accumulator`
128
+ that is used to keep track of the number of false negatives.
129
+
130
+ If `sample_weight` is `None`, weights default to 1.
131
+ Use `sample_weight` of 0 to mask values.
132
+
133
+ Args:
134
+ thresholds: (Optional) Defaults to `0.5`. A float value, or a Python
135
+ list/tuple of float threshold values in `[0, 1]`. A threshold is
136
+ compared with prediction values to determine the truth value of
137
+ predictions (i.e., above the threshold is `True`, below is `False`).
138
+ If used with a loss function that sets `from_logits=True` (i.e. no
139
+ sigmoid applied to predictions), `thresholds` should be set to 0.
140
+ One metric value is generated for each threshold value.
141
+ name: (Optional) string name of the metric instance.
142
+ dtype: (Optional) data type of the metric result.
143
+
144
+ Example:
145
+
146
+ >>> m = keras.metrics.FalseNegatives()
147
+ >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
148
+ >>> m.result()
149
+ 2.0
150
+
151
+ >>> m.reset_state()
152
+ >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0])
153
+ >>> m.result()
154
+ 1.0
155
+ """
156
+
157
+ def __init__(self, thresholds=None, name=None, dtype=None):
158
+ super().__init__(
159
+ confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES,
160
+ thresholds=thresholds,
161
+ name=name,
162
+ dtype=dtype,
163
+ )
164
+
165
+
166
+ @keras_export("keras.metrics.TrueNegatives")
167
+ class TrueNegatives(_ConfusionMatrixConditionCount):
168
+ """Calculates the number of true negatives.
169
+
170
+ If `sample_weight` is given, calculates the sum of the weights of
171
+ true negatives. This metric creates one local variable, `accumulator`
172
+ that is used to keep track of the number of true negatives.
173
+
174
+ If `sample_weight` is `None`, weights default to 1.
175
+ Use `sample_weight` of 0 to mask values.
176
+
177
+ Args:
178
+ thresholds: (Optional) Defaults to `0.5`. A float value, or a Python
179
+ list/tuple of float threshold values in `[0, 1]`. A threshold is
180
+ compared with prediction values to determine the truth value of
181
+ predictions (i.e., above the threshold is `True`, below is `False`).
182
+ If used with a loss function that sets `from_logits=True` (i.e. no
183
+ sigmoid applied to predictions), `thresholds` should be set to 0.
184
+ One metric value is generated for each threshold value.
185
+ name: (Optional) string name of the metric instance.
186
+ dtype: (Optional) data type of the metric result.
187
+
188
+ Example:
189
+
190
+ >>> m = keras.metrics.TrueNegatives()
191
+ >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0])
192
+ >>> m.result()
193
+ 2.0
194
+
195
+ >>> m.reset_state()
196
+ >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0])
197
+ >>> m.result()
198
+ 1.0
199
+ """
200
+
201
+ def __init__(self, thresholds=None, name=None, dtype=None):
202
+ super().__init__(
203
+ confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES,
204
+ thresholds=thresholds,
205
+ name=name,
206
+ dtype=dtype,
207
+ )
208
+
209
+
210
+ @keras_export("keras.metrics.TruePositives")
211
+ class TruePositives(_ConfusionMatrixConditionCount):
212
+ """Calculates the number of true positives.
213
+
214
+ If `sample_weight` is given, calculates the sum of the weights of
215
+ true positives. This metric creates one local variable, `true_positives`
216
+ that is used to keep track of the number of true positives.
217
+
218
+ If `sample_weight` is `None`, weights default to 1.
219
+ Use `sample_weight` of 0 to mask values.
220
+
221
+ Args:
222
+ thresholds: (Optional) Defaults to `0.5`. A float value, or a Python
223
+ list/tuple of float threshold values in `[0, 1]`. A threshold is
224
+ compared with prediction values to determine the truth value of
225
+ predictions (i.e., above the threshold is `True`, below is `False`).
226
+ If used with a loss function that sets `from_logits=True` (i.e. no
227
+ sigmoid applied to predictions), `thresholds` should be set to 0.
228
+ One metric value is generated for each threshold value.
229
+ name: (Optional) string name of the metric instance.
230
+ dtype: (Optional) data type of the metric result.
231
+
232
+ Example:
233
+
234
+ >>> m = keras.metrics.TruePositives()
235
+ >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
236
+ >>> m.result()
237
+ 2.0
238
+
239
+ >>> m.reset_state()
240
+ >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
241
+ >>> m.result()
242
+ 1.0
243
+ """
244
+
245
+ def __init__(self, thresholds=None, name=None, dtype=None):
246
+ super().__init__(
247
+ confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES,
248
+ thresholds=thresholds,
249
+ name=name,
250
+ dtype=dtype,
251
+ )
252
+
253
+
254
+ @keras_export("keras.metrics.Precision")
255
+ class Precision(Metric):
256
+ """Computes the precision of the predictions with respect to the labels.
257
+
258
+ The metric creates two local variables, `true_positives` and
259
+ `false_positives` that are used to compute the precision. This value is
260
+ ultimately returned as `precision`, an idempotent operation that simply
261
+ divides `true_positives` by the sum of `true_positives` and
262
+ `false_positives`.
263
+
264
+ If `sample_weight` is `None`, weights default to 1.
265
+ Use `sample_weight` of 0 to mask values.
266
+
267
+ If `top_k` is set, we'll calculate precision as how often on average a class
268
+ among the top-k classes with the highest predicted values of a batch entry
269
+ is correct and can be found in the label for that entry.
270
+
271
+ If `class_id` is specified, we calculate precision by considering only the
272
+ entries in the batch for which `class_id` is above the threshold and/or in
273
+ the top-k highest predictions, and computing the fraction of them for which
274
+ `class_id` is indeed a correct label.
275
+
276
+ Args:
277
+ thresholds: (Optional) A float value, or a Python list/tuple of float
278
+ threshold values in `[0, 1]`. A threshold is compared with
279
+ prediction values to determine the truth value of predictions (i.e.,
280
+ above the threshold is `True`, below is `False`). If used with a
281
+ loss function that sets `from_logits=True` (i.e. no sigmoid applied
282
+ to predictions), `thresholds` should be set to 0. One metric value
283
+ is generated for each threshold value. If neither `thresholds` nor
284
+ `top_k` are set, the default is to calculate precision with
285
+ `thresholds=0.5`.
286
+ top_k: (Optional) Unset by default. An int value specifying the top-k
287
+ predictions to consider when calculating precision.
288
+ class_id: (Optional) Integer class ID for which we want binary metrics.
289
+ This must be in the half-open interval `[0, num_classes)`, where
290
+ `num_classes` is the last dimension of predictions.
291
+ name: (Optional) string name of the metric instance.
292
+ dtype: (Optional) data type of the metric result.
293
+
294
+ Example:
295
+
296
+ >>> m = keras.metrics.Precision()
297
+ >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
298
+ >>> m.result()
299
+ 0.6666667
300
+
301
+ >>> m.reset_state()
302
+ >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
303
+ >>> m.result()
304
+ 1.0
305
+
306
+ >>> # With top_k=2, it will calculate precision over y_true[:2]
307
+ >>> # and y_pred[:2]
308
+ >>> m = keras.metrics.Precision(top_k=2)
309
+ >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
310
+ >>> m.result()
311
+ 0.0
312
+
313
+ >>> # With top_k=4, it will calculate precision over y_true[:4]
314
+ >>> # and y_pred[:4]
315
+ >>> m = keras.metrics.Precision(top_k=4)
316
+ >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
317
+ >>> m.result()
318
+ 0.5
319
+
320
+ Usage with `compile()` API:
321
+
322
+ ```python
323
+ model.compile(optimizer='sgd',
324
+ loss='binary_crossentropy',
325
+ metrics=[keras.metrics.Precision()])
326
+ ```
327
+
328
+ Usage with a loss with `from_logits=True`:
329
+
330
+ ```python
331
+ model.compile(optimizer='adam',
332
+ loss=keras.losses.BinaryCrossentropy(from_logits=True),
333
+ metrics=[keras.metrics.Precision(thresholds=0)])
334
+ ```
335
+ """
336
+
337
+ def __init__(
338
+ self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None
339
+ ):
340
+ super().__init__(name=name, dtype=dtype)
341
+ # Metric should be maximized during optimization.
342
+ self._direction = "up"
343
+
344
+ self.init_thresholds = thresholds
345
+ self.top_k = top_k
346
+ self.class_id = class_id
347
+
348
+ default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
349
+ self.thresholds = metrics_utils.parse_init_thresholds(
350
+ thresholds, default_threshold=default_threshold
351
+ )
352
+ self._thresholds_distributed_evenly = (
353
+ metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
354
+ )
355
+ self.true_positives = self.add_variable(
356
+ shape=(len(self.thresholds),),
357
+ initializer=initializers.Zeros(),
358
+ name="true_positives",
359
+ )
360
+ self.false_positives = self.add_variable(
361
+ shape=(len(self.thresholds),),
362
+ initializer=initializers.Zeros(),
363
+ name="false_positives",
364
+ )
365
+
366
+ def update_state(self, y_true, y_pred, sample_weight=None):
367
+ """Accumulates true positive and false positive statistics.
368
+
369
+ Args:
370
+ y_true: The ground truth values, with the same dimensions as
371
+ `y_pred`. Will be cast to `bool`.
372
+ y_pred: The predicted values. Each element must be in the range
373
+ `[0, 1]`.
374
+ sample_weight: Optional weighting of each example. Defaults to `1`.
375
+ Can be a tensor whose rank is either 0, or the same rank as
376
+ `y_true`, and must be broadcastable to `y_true`.
377
+ """
378
+ metrics_utils.update_confusion_matrix_variables(
379
+ {
380
+ metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
381
+ metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
382
+ },
383
+ y_true,
384
+ y_pred,
385
+ thresholds=self.thresholds,
386
+ thresholds_distributed_evenly=self._thresholds_distributed_evenly,
387
+ top_k=self.top_k,
388
+ class_id=self.class_id,
389
+ sample_weight=sample_weight,
390
+ )
391
+
392
+ def result(self):
393
+ result = ops.divide_no_nan(
394
+ self.true_positives,
395
+ ops.add(self.true_positives, self.false_positives),
396
+ )
397
+ return result[0] if len(self.thresholds) == 1 else result
398
+
399
+ def reset_state(self):
400
+ num_thresholds = len(to_list(self.thresholds))
401
+ self.true_positives.assign(ops.zeros((num_thresholds,)))
402
+ self.false_positives.assign(ops.zeros((num_thresholds,)))
403
+
404
+ def get_config(self):
405
+ config = {
406
+ "thresholds": self.init_thresholds,
407
+ "top_k": self.top_k,
408
+ "class_id": self.class_id,
409
+ }
410
+ base_config = super().get_config()
411
+ return {**base_config, **config}
412
+
413
+
414
@keras_export("keras.metrics.Recall")
class Recall(Metric):
    """Computes the recall of the predictions with respect to the labels.

    This metric creates two local variables, `true_positives` and
    `false_negatives`, that are used to compute the recall. This value is
    ultimately returned as `recall`, an idempotent operation that simply divides
    `true_positives` by the sum of `true_positives` and `false_negatives`.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    If `top_k` is set, recall will be computed as how often on average a class
    among the labels of a batch entry is in the top-k predictions.

    If `class_id` is specified, we calculate recall by considering only the
    entries in the batch for which `class_id` is in the label, and computing the
    fraction of them for which `class_id` is above the threshold and/or in the
    top-k predictions.

    Args:
        thresholds: (Optional) A float value, or a Python list/tuple of float
            threshold values in `[0, 1]`. A threshold is compared with
            prediction values to determine the truth value of predictions (i.e.,
            above the threshold is `True`, below is `False`). If used with a
            loss function that sets `from_logits=True` (i.e. no sigmoid
            applied to predictions), `thresholds` should be set to 0.
            One metric value is generated for each threshold value.
            If neither `thresholds` nor `top_k` are set,
            the default is to calculate recall with `thresholds=0.5`.
        top_k: (Optional) Unset by default. An int value specifying the top-k
            predictions to consider when calculating recall.
        class_id: (Optional) Integer class ID for which we want binary metrics.
            This must be in the half-open interval `[0, num_classes)`, where
            `num_classes` is the last dimension of predictions.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.Recall()
    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
    >>> m.result()
    0.6666667

    >>> m.reset_state()
    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
    >>> m.result()
    1.0

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='binary_crossentropy',
                  metrics=[keras.metrics.Recall()])
    ```

    Usage with a loss with `from_logits=True`:

    ```python
    model.compile(optimizer='adam',
                  loss=keras.losses.BinaryCrossentropy(from_logits=True),
                  metrics=[keras.metrics.Recall(thresholds=0)])
    ```
    """

    def __init__(
        self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None
    ):
        super().__init__(name=name, dtype=dtype)
        # Metric should be maximized during optimization.
        self._direction = "up"

        # Keep the raw constructor value so `get_config()` round-trips exactly.
        self.init_thresholds = thresholds
        self.top_k = top_k
        self.class_id = class_id

        # With `top_k` set, membership in the top-k is what matters, so the
        # effective threshold is -inf (every top-k prediction counts).
        default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
        self.thresholds = metrics_utils.parse_init_thresholds(
            thresholds, default_threshold=default_threshold
        )
        self._thresholds_distributed_evenly = (
            metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
        )
        # One accumulator slot per threshold value.
        self.true_positives = self.add_variable(
            shape=(len(self.thresholds),),
            initializer=initializers.Zeros(),
            name="true_positives",
        )
        self.false_negatives = self.add_variable(
            shape=(len(self.thresholds),),
            initializer=initializers.Zeros(),
            name="false_negatives",
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates true positive and false negative statistics.

        Args:
            y_true: The ground truth values, with the same dimensions as
                `y_pred`. Will be cast to `bool`.
            y_pred: The predicted values. Each element must be in the range
                `[0, 1]`.
            sample_weight: Optional weighting of each example. Defaults to `1`.
                Can be a tensor whose rank is either 0, or the same rank as
                `y_true`, and must be broadcastable to `y_true`.
        """
        metrics_utils.update_confusion_matrix_variables(
            {
                metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,  # noqa: E501
                metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,  # noqa: E501
            },
            y_true,
            y_pred,
            thresholds=self.thresholds,
            thresholds_distributed_evenly=self._thresholds_distributed_evenly,
            top_k=self.top_k,
            class_id=self.class_id,
            sample_weight=sample_weight,
        )

    def result(self):
        """Return recall = TP / (TP + FN), per threshold.

        A scalar is returned when a single threshold is configured.
        """
        result = ops.divide_no_nan(
            self.true_positives,
            ops.add(self.true_positives, self.false_negatives),
        )
        return result[0] if len(self.thresholds) == 1 else result

    def reset_state(self):
        """Reset the accumulated TP/FN counts to zero."""
        num_thresholds = len(to_list(self.thresholds))
        self.true_positives.assign(ops.zeros((num_thresholds,)))
        self.false_negatives.assign(ops.zeros((num_thresholds,)))

    def get_config(self):
        """Return the serializable configuration of this metric."""
        config = {
            "thresholds": self.init_thresholds,
            "top_k": self.top_k,
            "class_id": self.class_id,
        }
        base_config = super().get_config()
        return {**base_config, **config}
556
+
557
+
558
class SensitivitySpecificityBase(Metric):
    """Abstract base class for computing sensitivity and specificity.

    For additional information about specificity and sensitivity, see
    [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
    """

    def __init__(
        self, value, num_thresholds=200, class_id=None, name=None, dtype=None
    ):
        super().__init__(name=name, dtype=dtype)
        # Metric should be maximized during optimization.
        self._direction = "up"

        if num_thresholds <= 0:
            raise ValueError(
                "Argument `num_thresholds` must be an integer > 0. "
                f"Received: num_thresholds={num_thresholds}"
            )
        # `value` is the constraint target that subclasses compare against
        # in `_find_max_under_constraint`.
        self.value = value
        self.class_id = class_id

        # Compute `num_thresholds` thresholds in [0, 1]
        if num_thresholds == 1:
            self.thresholds = [0.5]
            self._thresholds_distributed_evenly = False
        else:
            # Evenly spaced interior thresholds, with the endpoints 0.0 and
            # 1.0 appended explicitly.
            thresholds = [
                (i + 1) * 1.0 / (num_thresholds - 1)
                for i in range(num_thresholds - 2)
            ]
            self.thresholds = [0.0] + thresholds + [1.0]
            self._thresholds_distributed_evenly = True

        # One accumulator slot per threshold for each confusion-matrix cell.
        self.true_positives = self.add_variable(
            shape=(len(self.thresholds),),
            initializer=initializers.Zeros(),
            name="true_positives",
        )
        self.false_positives = self.add_variable(
            shape=(len(self.thresholds),),
            initializer=initializers.Zeros(),
            name="false_positives",
        )
        self.true_negatives = self.add_variable(
            shape=(len(self.thresholds),),
            initializer=initializers.Zeros(),
            name="true_negatives",
        )
        self.false_negatives = self.add_variable(
            shape=(len(self.thresholds),),
            initializer=initializers.Zeros(),
            name="false_negatives",
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates confusion matrix statistics.

        Args:
            y_true: The ground truth values.
            y_pred: The predicted values.
            sample_weight: Optional weighting of each example. Defaults to `1`.
                Can be a tensor whose rank is either 0, or the same rank as
                `y_true`, and must be broadcastable to `y_true`.
        """
        metrics_utils.update_confusion_matrix_variables(
            {
                metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,  # noqa: E501
                metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,  # noqa: E501
                metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,  # noqa: E501
                metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,  # noqa: E501
            },
            y_true,
            y_pred,
            thresholds=self.thresholds,
            thresholds_distributed_evenly=self._thresholds_distributed_evenly,
            class_id=self.class_id,
            sample_weight=sample_weight,
        )

    def reset_state(self):
        """Reset all four confusion-matrix accumulators to zero."""
        num_thresholds = len(self.thresholds)
        self.true_positives.assign(ops.zeros((num_thresholds,)))
        self.false_positives.assign(ops.zeros((num_thresholds,)))
        self.true_negatives.assign(ops.zeros((num_thresholds,)))
        self.false_negatives.assign(ops.zeros((num_thresholds,)))

    def get_config(self):
        """Return the serializable configuration of this metric."""
        config = {"class_id": self.class_id}
        base_config = super().get_config()
        return {**base_config, **config}

    def _find_max_under_constraint(self, constrained, dependent, predicate):
        """Returns the maximum of dependent_statistic that satisfies the
        constraint.

        Args:
            constrained: Over these values the constraint is specified. A rank-1
                tensor.
            dependent: From these values the maximum that satisfies the
                constraint is selected. Values in this tensor and in
                `constrained` are linked by having the same threshold at each
                position, hence this tensor must have the same shape.
            predicate: A binary boolean functor to be applied to arguments
                `constrained` and `self.value`, e.g. `ops.greater`.

        Returns:
            maximal dependent value, if no value satisfies the constraint 0.0.
        """
        # Indices of thresholds where the constraint holds.
        feasible = ops.nonzero(predicate(constrained, self.value))
        feasible_exists = ops.greater(ops.size(feasible), 0)
        # `initial=0` keeps `max` well-defined when `feasible` is empty.
        max_dependent = ops.max(ops.take(dependent, feasible), initial=0)

        return ops.where(feasible_exists, max_dependent, 0.0)
672
+
673
+
674
@keras_export("keras.metrics.SensitivityAtSpecificity")
class SensitivityAtSpecificity(SensitivitySpecificityBase):
    """Computes best sensitivity where specificity is >= specified value.

    `Sensitivity` measures the proportion of actual positives that are correctly
    identified as such `(tp / (tp + fn))`.
    `Specificity` measures the proportion of actual negatives that are correctly
    identified as such `(tn / (tn + fp))`.

    This metric creates four local variables, `true_positives`,
    `true_negatives`, `false_positives` and `false_negatives` that are used to
    compute the sensitivity at the given specificity. The threshold for the
    given specificity value is computed and used to evaluate the corresponding
    sensitivity.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    If `class_id` is specified, we calculate precision by considering only the
    entries in the batch for which `class_id` is above the threshold
    predictions, and computing the fraction of them for which `class_id` is
    indeed a correct label.

    For additional information about specificity and sensitivity, see
    [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).

    Args:
        specificity: A scalar value in range `[0, 1]`.
        num_thresholds: (Optional) Defaults to 200. The number of thresholds to
            use for matching the given specificity.
        class_id: (Optional) Integer class ID for which we want binary metrics.
            This must be in the half-open interval `[0, num_classes)`, where
            `num_classes` is the last dimension of predictions.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.SensitivityAtSpecificity(0.5)
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
    ...                sample_weight=[1, 1, 2, 2, 1])
    >>> m.result()
    0.333333

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='binary_crossentropy',
        metrics=[keras.metrics.SensitivityAtSpecificity()])
    ```
    """

    def __init__(
        self,
        specificity,
        num_thresholds=200,
        class_id=None,
        name=None,
        dtype=None,
    ):
        if specificity < 0 or specificity > 1:
            raise ValueError(
                "Argument `specificity` must be in the range [0, 1]. "
                f"Received: specificity={specificity}"
            )
        self.specificity = specificity
        self.num_thresholds = num_thresholds
        # Pass `value` by keyword for consistency with the sibling classes
        # `PrecisionAtRecall` and `RecallAtPrecision`.
        super().__init__(
            value=specificity,
            num_thresholds=num_thresholds,
            class_id=class_id,
            name=name,
            dtype=dtype,
        )

    def result(self):
        """Best sensitivity among thresholds whose specificity >= target.

        Returns 0.0 if no threshold satisfies the specificity constraint.
        """
        sensitivities = ops.divide_no_nan(
            self.true_positives,
            ops.add(self.true_positives, self.false_negatives),
        )
        specificities = ops.divide_no_nan(
            self.true_negatives,
            ops.add(self.true_negatives, self.false_positives),
        )
        return self._find_max_under_constraint(
            specificities, sensitivities, ops.greater_equal
        )

    def get_config(self):
        """Return the serializable configuration of this metric."""
        config = {
            "num_thresholds": self.num_thresholds,
            "specificity": self.specificity,
        }
        base_config = super().get_config()
        return {**base_config, **config}
776
+
777
+
778
@keras_export("keras.metrics.SpecificityAtSensitivity")
class SpecificityAtSensitivity(SensitivitySpecificityBase):
    """Computes best specificity where sensitivity is >= specified value.

    `Sensitivity` measures the proportion of actual positives that are correctly
    identified as such `(tp / (tp + fn))`.
    `Specificity` measures the proportion of actual negatives that are correctly
    identified as such `(tn / (tn + fp))`.

    This metric creates four local variables, `true_positives`,
    `true_negatives`, `false_positives` and `false_negatives` that are used to
    compute the specificity at the given sensitivity. The threshold for the
    given sensitivity value is computed and used to evaluate the corresponding
    specificity.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    If `class_id` is specified, we calculate precision by considering only the
    entries in the batch for which `class_id` is above the threshold
    predictions, and computing the fraction of them for which `class_id` is
    indeed a correct label.

    For additional information about specificity and sensitivity, see
    [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).

    Args:
        sensitivity: A scalar value in range `[0, 1]`.
        num_thresholds: (Optional) Defaults to 200. The number of thresholds to
            use for matching the given sensitivity.
        class_id: (Optional) Integer class ID for which we want binary metrics.
            This must be in the half-open interval `[0, num_classes)`, where
            `num_classes` is the last dimension of predictions.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.SpecificityAtSensitivity(0.5)
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
    >>> m.result()
    0.66666667

    >>> m.reset_state()
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
    ...                sample_weight=[1, 1, 2, 2, 2])
    >>> m.result()
    0.5

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='binary_crossentropy',
        metrics=[keras.metrics.SpecificityAtSensitivity()])
    ```
    """

    def __init__(
        self,
        sensitivity,
        num_thresholds=200,
        class_id=None,
        name=None,
        dtype=None,
    ):
        if sensitivity < 0 or sensitivity > 1:
            raise ValueError(
                "Argument `sensitivity` must be in the range [0, 1]. "
                f"Received: sensitivity={sensitivity}"
            )
        self.sensitivity = sensitivity
        self.num_thresholds = num_thresholds
        # Pass `value` by keyword for consistency with the sibling classes
        # `PrecisionAtRecall` and `RecallAtPrecision`.
        super().__init__(
            value=sensitivity,
            num_thresholds=num_thresholds,
            class_id=class_id,
            name=name,
            dtype=dtype,
        )

    def result(self):
        """Best specificity among thresholds whose sensitivity >= target.

        Returns 0.0 if no threshold satisfies the sensitivity constraint.
        """
        sensitivities = ops.divide_no_nan(
            self.true_positives,
            ops.add(self.true_positives, self.false_negatives),
        )
        specificities = ops.divide_no_nan(
            self.true_negatives,
            ops.add(self.true_negatives, self.false_positives),
        )
        return self._find_max_under_constraint(
            sensitivities, specificities, ops.greater_equal
        )

    def get_config(self):
        """Return the serializable configuration of this metric."""
        config = {
            "num_thresholds": self.num_thresholds,
            "sensitivity": self.sensitivity,
        }
        base_config = super().get_config()
        return {**base_config, **config}
880
+
881
+
882
@keras_export("keras.metrics.PrecisionAtRecall")
class PrecisionAtRecall(SensitivitySpecificityBase):
    """Computes best precision where recall is >= specified value.

    Four local variables — `true_positives`, `true_negatives`,
    `false_positives` and `false_negatives` — accumulate confusion-matrix
    counts over a grid of thresholds. At result time, the precision at the
    threshold that achieves at least the requested recall is reported.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    If `class_id` is specified, we calculate precision by considering only the
    entries in the batch for which `class_id` is above the threshold
    predictions, and computing the fraction of them for which `class_id` is
    indeed a correct label.

    Args:
        recall: A scalar value in range `[0, 1]`.
        num_thresholds: (Optional) Defaults to 200. The number of thresholds to
            use for matching the given recall.
        class_id: (Optional) Integer class ID for which we want binary metrics.
            This must be in the half-open interval `[0, num_classes)`, where
            `num_classes` is the last dimension of predictions.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.PrecisionAtRecall(0.5)
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
    ...                sample_weight=[2, 2, 2, 1, 1])
    >>> m.result()
    0.33333333

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='binary_crossentropy',
        metrics=[keras.metrics.PrecisionAtRecall(recall=0.8)])
    ```
    """

    def __init__(
        self, recall, num_thresholds=200, class_id=None, name=None, dtype=None
    ):
        if recall < 0 or recall > 1:
            raise ValueError(
                "Argument `recall` must be in the range [0, 1]. "
                f"Received: recall={recall}"
            )
        self.recall = recall
        self.num_thresholds = num_thresholds
        super().__init__(
            value=recall,
            num_thresholds=num_thresholds,
            class_id=class_id,
            name=name,
            dtype=dtype,
        )

    def result(self):
        """Best precision among thresholds whose recall >= the target.

        Returns 0.0 if no threshold satisfies the recall constraint.
        """
        tp = self.true_positives
        recalls = ops.divide_no_nan(tp, ops.add(tp, self.false_negatives))
        precisions = ops.divide_no_nan(tp, ops.add(tp, self.false_positives))
        return self._find_max_under_constraint(
            recalls, precisions, ops.greater_equal
        )

    def get_config(self):
        """Return the serializable configuration of this metric."""
        config = dict(super().get_config())
        config["num_thresholds"] = self.num_thresholds
        config["recall"] = self.recall
        return config
967
+
968
+
969
@keras_export("keras.metrics.RecallAtPrecision")
class RecallAtPrecision(SensitivitySpecificityBase):
    """Computes best recall where precision is >= specified value.

    For a given score-label-distribution the required precision might not
    be achievable, in this case 0.0 is returned as recall.

    Four local variables — `true_positives`, `true_negatives`,
    `false_positives` and `false_negatives` — accumulate confusion-matrix
    counts over a grid of thresholds. At result time, the recall at the
    threshold that achieves at least the requested precision is reported.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    If `class_id` is specified, we calculate precision by considering only the
    entries in the batch for which `class_id` is above the threshold
    predictions, and computing the fraction of them for which `class_id` is
    indeed a correct label.

    Args:
        precision: A scalar value in range `[0, 1]`.
        num_thresholds: (Optional) Defaults to 200. The number of thresholds
            to use for matching the given precision.
        class_id: (Optional) Integer class ID for which we want binary metrics.
            This must be in the half-open interval `[0, num_classes)`, where
            `num_classes` is the last dimension of predictions.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.RecallAtPrecision(0.8)
    >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
    ...                sample_weight=[1, 0, 0, 1])
    >>> m.result()
    1.0

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='binary_crossentropy',
        metrics=[keras.metrics.RecallAtPrecision(precision=0.8)])
    ```
    """

    def __init__(
        self,
        precision,
        num_thresholds=200,
        class_id=None,
        name=None,
        dtype=None,
    ):
        if precision < 0 or precision > 1:
            raise ValueError(
                "Argument `precision` must be in the range [0, 1]. "
                f"Received: precision={precision}"
            )
        self.precision = precision
        self.num_thresholds = num_thresholds
        super().__init__(
            value=precision,
            num_thresholds=num_thresholds,
            class_id=class_id,
            name=name,
            dtype=dtype,
        )

    def result(self):
        """Best recall among thresholds whose precision >= the target.

        Returns 0.0 if no threshold satisfies the precision constraint.
        """
        tp = self.true_positives
        recalls = ops.divide_no_nan(tp, ops.add(tp, self.false_negatives))
        precisions = ops.divide_no_nan(tp, ops.add(tp, self.false_positives))
        return self._find_max_under_constraint(
            precisions, recalls, ops.greater_equal
        )

    def get_config(self):
        """Return the serializable configuration of this metric."""
        config = dict(super().get_config())
        config["num_thresholds"] = self.num_thresholds
        config["precision"] = self.precision
        return config
1065
+
1066
+
1067
+ @keras_export("keras.metrics.AUC")
1068
+ class AUC(Metric):
1069
+ """Approximates the AUC (Area under the curve) of the ROC or PR curves.
1070
+
1071
+ The AUC (Area under the curve) of the ROC (Receiver operating
1072
+ characteristic; default) or PR (Precision Recall) curves are quality
1073
+ measures of binary classifiers. Unlike the accuracy, and like cross-entropy
1074
+ losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
1075
+
1076
+ This class approximates AUCs using a Riemann sum. During the metric
1077
+ accumulation phrase, predictions are accumulated within predefined buckets
1078
+ by value. The AUC is then computed by interpolating per-bucket averages.
1079
+ These buckets define the evaluated operational points.
1080
+
1081
+ This metric creates four local variables, `true_positives`,
1082
+ `true_negatives`, `false_positives` and `false_negatives` that are used to
1083
+ compute the AUC. To discretize the AUC curve, a linearly spaced set of
1084
+ thresholds is used to compute pairs of recall and precision values. The area
1085
+ under the ROC-curve is therefore computed using the height of the recall
1086
+ values by the false positive rate, while the area under the PR-curve is the
1087
+ computed using the height of the precision values by the recall.
1088
+
1089
+ This value is ultimately returned as `auc`, an idempotent operation that
1090
+ computes the area under a discretized curve of precision versus recall
1091
+ values (computed using the aforementioned variables). The `num_thresholds`
1092
+ variable controls the degree of discretization with larger numbers of
1093
+ thresholds more closely approximating the true AUC. The quality of the
1094
+ approximation may vary dramatically depending on `num_thresholds`. The
1095
+ `thresholds` parameter can be used to manually specify thresholds which
1096
+ split the predictions more evenly.
1097
+
1098
+ For a best approximation of the real AUC, `predictions` should be
1099
+ distributed approximately uniformly in the range `[0, 1]` (if
1100
+ `from_logits=False`). The quality of the AUC approximation may be poor if
1101
+ this is not the case. Setting `summation_method` to 'minoring' or 'majoring'
1102
+ can help quantify the error in the approximation by providing lower or upper
1103
+ bound estimate of the AUC.
1104
+
1105
+ If `sample_weight` is `None`, weights default to 1.
1106
+ Use `sample_weight` of 0 to mask values.
1107
+
1108
+ Args:
1109
+ num_thresholds: (Optional) The number of thresholds to
1110
+ use when discretizing the roc curve. Values must be > 1.
1111
+ Defaults to `200`.
1112
+ curve: (Optional) Specifies the name of the curve to be computed,
1113
+ `'ROC'` (default) or `'PR'` for the Precision-Recall-curve.
1114
+ summation_method: (Optional) Specifies the [Riemann summation method](
1115
+ https://en.wikipedia.org/wiki/Riemann_sum) used.
1116
+ 'interpolation' (default) applies mid-point summation scheme for
1117
+ `ROC`. For PR-AUC, interpolates (true/false) positives but not
1118
+ the ratio that is precision (see Davis & Goadrich 2006 for
1119
+ details); 'minoring' applies left summation for increasing
1120
+ intervals and right summation for decreasing intervals; 'majoring'
1121
+ does the opposite.
1122
+ name: (Optional) string name of the metric instance.
1123
+ dtype: (Optional) data type of the metric result.
1124
+ thresholds: (Optional) A list of floating point values to use as the
1125
+ thresholds for discretizing the curve. If set, the `num_thresholds`
1126
+ parameter is ignored. Values should be in `[0, 1]`. Endpoint
1127
+ thresholds equal to {`-epsilon`, `1+epsilon`} for a small positive
1128
+ epsilon value will be automatically included with these to correctly
1129
+ handle predictions equal to exactly 0 or 1.
1130
+ multi_label: boolean indicating whether multilabel data should be
1131
+ treated as such, wherein AUC is computed separately for each label
1132
+ and then averaged across labels, or (when `False`) if the data
1133
+ should be flattened into a single label before AUC computation. In
1134
+ the latter case, when multilabel data is passed to AUC, each
1135
+ label-prediction pair is treated as an individual data point. Should
1136
+ be set to `False` for multi-class data.
1137
+ num_labels: (Optional) The number of labels, used when `multi_label` is
1138
+ True. If `num_labels` is not specified, then state variables get
1139
+ created on the first call to `update_state`.
1140
+ label_weights: (Optional) list, array, or tensor of non-negative weights
1141
+ used to compute AUCs for multilabel data. When `multi_label` is
1142
+ True, the weights are applied to the individual label AUCs when they
1143
+ are averaged to produce the multi-label AUC. When it's False, they
1144
+ are used to weight the individual label predictions in computing the
1145
+ confusion matrix on the flattened data. Note that this is unlike
1146
+ `class_weights` in that `class_weights` weights the example
1147
+ depending on the value of its label, whereas `label_weights` depends
1148
+ only on the index of that label before flattening; therefore
1149
+ `label_weights` should not be used for multi-class data.
1150
+ from_logits: boolean indicating whether the predictions (`y_pred` in
1151
+ `update_state`) are probabilities or sigmoid logits. As a rule of thumb,
1152
+ when using a keras loss, the `from_logits` constructor argument of the
1153
+ loss should match the AUC `from_logits` constructor argument.
1154
+
1155
+ Example:
1156
+
1157
+ >>> m = keras.metrics.AUC(num_thresholds=3)
1158
+ >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
1159
+ >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
1160
+ >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
1161
+ >>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
1162
+ >>> # auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0)))
1163
+ >>> # = 0.75
1164
+ >>> m.result()
1165
+ 0.75
1166
+
1167
+ >>> m.reset_state()
1168
+ >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
1169
+ ... sample_weight=[1, 0, 0, 1])
1170
+ >>> m.result()
1171
+ 1.0
1172
+
1173
+ Usage with `compile()` API:
1174
+
1175
+ ```python
1176
+ # Reports the AUC of a model outputting a probability.
1177
+ model.compile(optimizer='sgd',
1178
+ loss=keras.losses.BinaryCrossentropy(),
1179
+ metrics=[keras.metrics.AUC()])
1180
+
1181
+ # Reports the AUC of a model outputting a logit.
1182
+ model.compile(optimizer='sgd',
1183
+ loss=keras.losses.BinaryCrossentropy(from_logits=True),
1184
+ metrics=[keras.metrics.AUC(from_logits=True)])
1185
+ ```
1186
+ """
1187
+
1188
+ def __init__(
1189
+ self,
1190
+ num_thresholds=200,
1191
+ curve="ROC",
1192
+ summation_method="interpolation",
1193
+ name=None,
1194
+ dtype=None,
1195
+ thresholds=None,
1196
+ multi_label=False,
1197
+ num_labels=None,
1198
+ label_weights=None,
1199
+ from_logits=False,
1200
+ ):
1201
+ # Metric should be maximized during optimization.
1202
+ self._direction = "up"
1203
+
1204
+ # Validate configurations.
1205
+ if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
1206
+ metrics_utils.AUCCurve
1207
+ ):
1208
+ raise ValueError(
1209
+ f'Invalid `curve` argument value "{curve}". '
1210
+ f"Expected one of: {list(metrics_utils.AUCCurve)}"
1211
+ )
1212
+ if isinstance(
1213
+ summation_method, metrics_utils.AUCSummationMethod
1214
+ ) and summation_method not in list(metrics_utils.AUCSummationMethod):
1215
+ raise ValueError(
1216
+ "Invalid `summation_method` "
1217
+ f'argument value "{summation_method}". '
1218
+ f"Expected one of: {list(metrics_utils.AUCSummationMethod)}"
1219
+ )
1220
+
1221
+ # Update properties.
1222
+ self._init_from_thresholds = thresholds is not None
1223
+ if thresholds is not None:
1224
+ # If specified, use the supplied thresholds.
1225
+ self.num_thresholds = len(thresholds) + 2
1226
+ thresholds = sorted(thresholds)
1227
+ self._thresholds_distributed_evenly = (
1228
+ metrics_utils.is_evenly_distributed_thresholds(
1229
+ np.array([0.0] + thresholds + [1.0])
1230
+ )
1231
+ )
1232
+ else:
1233
+ if num_thresholds <= 1:
1234
+ raise ValueError(
1235
+ "Argument `num_thresholds` must be an integer > 1. "
1236
+ f"Received: num_thresholds={num_thresholds}"
1237
+ )
1238
+
1239
+ # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
1240
+ # (0, 1).
1241
+ self.num_thresholds = num_thresholds
1242
+ thresholds = [
1243
+ (i + 1) * 1.0 / (num_thresholds - 1)
1244
+ for i in range(num_thresholds - 2)
1245
+ ]
1246
+ self._thresholds_distributed_evenly = True
1247
+
1248
+ # Add an endpoint "threshold" below zero and above one for either
1249
+ # threshold method to account for floating point imprecisions.
1250
+ self._thresholds = np.array(
1251
+ [0.0 - backend.epsilon()] + thresholds + [1.0 + backend.epsilon()]
1252
+ )
1253
+
1254
+ if isinstance(curve, metrics_utils.AUCCurve):
1255
+ self.curve = curve
1256
+ else:
1257
+ self.curve = metrics_utils.AUCCurve.from_str(curve)
1258
+ if isinstance(summation_method, metrics_utils.AUCSummationMethod):
1259
+ self.summation_method = summation_method
1260
+ else:
1261
+ self.summation_method = metrics_utils.AUCSummationMethod.from_str(
1262
+ summation_method
1263
+ )
1264
+ super().__init__(name=name, dtype=dtype)
1265
+
1266
+ # Handle multilabel arguments.
1267
+ self.multi_label = multi_label
1268
+ self.num_labels = num_labels
1269
+ if label_weights is not None:
1270
+ label_weights = ops.array(label_weights, dtype=self.dtype)
1271
+ self.label_weights = label_weights
1272
+
1273
+ else:
1274
+ self.label_weights = None
1275
+
1276
+ self._from_logits = from_logits
1277
+
1278
+ self._built = False
1279
+ if self.multi_label:
1280
+ if num_labels:
1281
+ shape = [None, num_labels]
1282
+ self._build(shape)
1283
+ else:
1284
+ if num_labels:
1285
+ raise ValueError(
1286
+ "`num_labels` is needed only when `multi_label` is True."
1287
+ )
1288
+ self._build(None)
1289
+
1290
+ @property
1291
+ def thresholds(self):
1292
+ """The thresholds used for evaluating AUC."""
1293
+ return list(self._thresholds)
1294
+
1295
+ def _build(self, shape):
1296
+ """Initialize TP, FP, TN, and FN tensors, given the shape of the
1297
+ data."""
1298
+ if self.multi_label:
1299
+ if len(shape) != 2:
1300
+ raise ValueError(
1301
+ "`y_pred` must have rank 2 when `multi_label=True`. "
1302
+ f"Found rank {len(shape)}. "
1303
+ f"Full shape received for `y_pred`: {shape}"
1304
+ )
1305
+ self._num_labels = shape[1]
1306
+ variable_shape = [self.num_thresholds, self._num_labels]
1307
+ else:
1308
+ variable_shape = [self.num_thresholds]
1309
+
1310
+ self._build_input_shape = shape
1311
+ # Create metric variables
1312
+ self.true_positives = self.add_variable(
1313
+ shape=variable_shape,
1314
+ initializer=initializers.Zeros(),
1315
+ name="true_positives",
1316
+ )
1317
+ self.false_positives = self.add_variable(
1318
+ shape=variable_shape,
1319
+ initializer=initializers.Zeros(),
1320
+ name="false_positives",
1321
+ )
1322
+ self.true_negatives = self.add_variable(
1323
+ shape=variable_shape,
1324
+ initializer=initializers.Zeros(),
1325
+ name="true_negatives",
1326
+ )
1327
+ self.false_negatives = self.add_variable(
1328
+ shape=variable_shape,
1329
+ initializer=initializers.Zeros(),
1330
+ name="false_negatives",
1331
+ )
1332
+
1333
+ self._built = True
1334
+
1335
+ def update_state(self, y_true, y_pred, sample_weight=None):
1336
+ """Accumulates confusion matrix statistics.
1337
+
1338
+ Args:
1339
+ y_true: The ground truth values.
1340
+ y_pred: The predicted values.
1341
+ sample_weight: Optional weighting of each example. Can
1342
+ be a tensor whose rank is either 0, or the same rank as
1343
+ `y_true`, and must be broadcastable to `y_true`. Defaults to
1344
+ `1`.
1345
+ """
1346
+ if not self._built:
1347
+ self._build(y_pred.shape)
1348
+
1349
+ if self.multi_label or (self.label_weights is not None):
1350
+ # y_true should have shape (number of examples, number of labels).
1351
+ shapes = [(y_true, ("N", "L"))]
1352
+ if self.multi_label:
1353
+ # TP, TN, FP, and FN should all have shape
1354
+ # (number of thresholds, number of labels).
1355
+ shapes.extend(
1356
+ [
1357
+ (self.true_positives, ("T", "L")),
1358
+ (self.true_negatives, ("T", "L")),
1359
+ (self.false_positives, ("T", "L")),
1360
+ (self.false_negatives, ("T", "L")),
1361
+ ]
1362
+ )
1363
+ if self.label_weights is not None:
1364
+ # label_weights should be of length equal to the number of
1365
+ # labels.
1366
+ shapes.append((self.label_weights, ("L",)))
1367
+
1368
+ # Only forward label_weights to update_confusion_matrix_variables when
1369
+ # multi_label is False. Otherwise the averaging of individual label AUCs
1370
+ # is handled in AUC.result
1371
+ label_weights = None if self.multi_label else self.label_weights
1372
+
1373
+ if self._from_logits:
1374
+ y_pred = activations.sigmoid(y_pred)
1375
+
1376
+ metrics_utils.update_confusion_matrix_variables(
1377
+ {
1378
+ metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
1379
+ metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501
1380
+ metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
1381
+ metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
1382
+ },
1383
+ y_true,
1384
+ y_pred,
1385
+ self._thresholds,
1386
+ thresholds_distributed_evenly=self._thresholds_distributed_evenly,
1387
+ sample_weight=sample_weight,
1388
+ multi_label=self.multi_label,
1389
+ label_weights=label_weights,
1390
+ )
1391
+
1392
+ def interpolate_pr_auc(self):
1393
+ """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
1394
+
1395
+ https://www.biostat.wisc.edu/~page/rocpr.pdf
1396
+
1397
+ Note here we derive & use a closed formula not present in the paper
1398
+ as follows:
1399
+
1400
+ Precision = TP / (TP + FP) = TP / P
1401
+
1402
+ Modeling all of TP (true positive), FP (false positive) and their sum
1403
+ P = TP + FP (predicted positive) as varying linearly within each
1404
+ interval [A, B] between successive thresholds, we get
1405
+
1406
+ Precision slope = dTP / dP
1407
+ = (TP_B - TP_A) / (P_B - P_A)
1408
+ = (TP - TP_A) / (P - P_A)
1409
+ Precision = (TP_A + slope * (P - P_A)) / P
1410
+
1411
+ The area within the interval is (slope / total_pos_weight) times
1412
+
1413
+ int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
1414
+ int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
1415
+
1416
+ where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
1417
+
1418
+ int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
1419
+
1420
+ Bringing back the factor (slope / total_pos_weight) we'd put aside, we
1421
+ get
1422
+
1423
+ slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
1424
+
1425
+ where dTP == TP_B - TP_A.
1426
+
1427
+ Note that when P_A == 0 the above calculation simplifies into
1428
+
1429
+ int_A^B{Precision.dTP} = int_A^B{slope * dTP}
1430
+ = slope * (TP_B - TP_A)
1431
+
1432
+ which is really equivalent to imputing constant precision throughout the
1433
+ first bucket having >0 true positives.
1434
+
1435
+ Returns:
1436
+ pr_auc: an approximation of the area under the P-R curve.
1437
+ """
1438
+
1439
+ dtp = ops.subtract(
1440
+ self.true_positives[: self.num_thresholds - 1],
1441
+ self.true_positives[1:],
1442
+ )
1443
+ p = ops.add(self.true_positives, self.false_positives)
1444
+ dp = ops.subtract(p[: self.num_thresholds - 1], p[1:])
1445
+ prec_slope = ops.divide_no_nan(dtp, ops.maximum(dp, 0))
1446
+ intercept = ops.subtract(
1447
+ self.true_positives[1:], ops.multiply(prec_slope, p[1:])
1448
+ )
1449
+
1450
+ safe_p_ratio = ops.where(
1451
+ ops.logical_and(p[: self.num_thresholds - 1] > 0, p[1:] > 0),
1452
+ ops.divide_no_nan(
1453
+ p[: self.num_thresholds - 1], ops.maximum(p[1:], 0)
1454
+ ),
1455
+ ops.ones_like(p[1:]),
1456
+ )
1457
+
1458
+ pr_auc_increment = ops.divide_no_nan(
1459
+ ops.multiply(
1460
+ prec_slope,
1461
+ (ops.add(dtp, ops.multiply(intercept, ops.log(safe_p_ratio)))),
1462
+ ),
1463
+ ops.maximum(
1464
+ ops.add(self.true_positives[1:], self.false_negatives[1:]), 0
1465
+ ),
1466
+ )
1467
+
1468
+ if self.multi_label:
1469
+ by_label_auc = ops.sum(pr_auc_increment, axis=0)
1470
+ if self.label_weights is None:
1471
+ # Evenly weighted average of the label AUCs.
1472
+ return ops.mean(by_label_auc)
1473
+ else:
1474
+ # Weighted average of the label AUCs.
1475
+ return ops.divide_no_nan(
1476
+ ops.sum(ops.multiply(by_label_auc, self.label_weights)),
1477
+ ops.sum(self.label_weights),
1478
+ )
1479
+ else:
1480
+ return ops.sum(pr_auc_increment)
1481
+
1482
+ def result(self):
1483
+ if (
1484
+ self.curve == metrics_utils.AUCCurve.PR
1485
+ and self.summation_method
1486
+ == metrics_utils.AUCSummationMethod.INTERPOLATION
1487
+ ):
1488
+ # This use case is different and is handled separately.
1489
+ return self.interpolate_pr_auc()
1490
+
1491
+ # Set `x` and `y` values for the curves based on `curve` config.
1492
+ recall = ops.divide_no_nan(
1493
+ self.true_positives,
1494
+ ops.add(self.true_positives, self.false_negatives),
1495
+ )
1496
+ if self.curve == metrics_utils.AUCCurve.ROC:
1497
+ fp_rate = ops.divide_no_nan(
1498
+ self.false_positives,
1499
+ ops.add(self.false_positives, self.true_negatives),
1500
+ )
1501
+ x = fp_rate
1502
+ y = recall
1503
+ else: # curve == 'PR'.
1504
+ precision = ops.divide_no_nan(
1505
+ self.true_positives,
1506
+ ops.add(self.true_positives, self.false_positives),
1507
+ )
1508
+ x = recall
1509
+ y = precision
1510
+
1511
+ # Find the rectangle heights based on `summation_method`.
1512
+ if (
1513
+ self.summation_method
1514
+ == metrics_utils.AUCSummationMethod.INTERPOLATION
1515
+ ):
1516
+ # Note: the case ('PR', 'interpolation') has been handled above.
1517
+ heights = ops.divide(
1518
+ ops.add(y[: self.num_thresholds - 1], y[1:]), 2.0
1519
+ )
1520
+ elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
1521
+ heights = ops.minimum(y[: self.num_thresholds - 1], y[1:])
1522
+ # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING:
1523
+ else:
1524
+ heights = ops.maximum(y[: self.num_thresholds - 1], y[1:])
1525
+
1526
+ # Sum up the areas of all the rectangles.
1527
+ riemann_terms = ops.multiply(
1528
+ ops.subtract(x[: self.num_thresholds - 1], x[1:]), heights
1529
+ )
1530
+ if self.multi_label:
1531
+ by_label_auc = ops.sum(riemann_terms, axis=0)
1532
+
1533
+ if self.label_weights is None:
1534
+ # Unweighted average of the label AUCs.
1535
+ return ops.mean(by_label_auc)
1536
+ else:
1537
+ # Weighted average of the label AUCs.
1538
+ return ops.divide_no_nan(
1539
+ ops.sum(ops.multiply(by_label_auc, self.label_weights)),
1540
+ ops.sum(self.label_weights),
1541
+ )
1542
+ else:
1543
+ return ops.sum(riemann_terms)
1544
+
1545
+ def reset_state(self):
1546
+ if self._built:
1547
+ if self.multi_label:
1548
+ variable_shape = (self.num_thresholds, self._num_labels)
1549
+ else:
1550
+ variable_shape = (self.num_thresholds,)
1551
+
1552
+ self.true_positives.assign(ops.zeros(variable_shape))
1553
+ self.false_positives.assign(ops.zeros(variable_shape))
1554
+ self.true_negatives.assign(ops.zeros(variable_shape))
1555
+ self.false_negatives.assign(ops.zeros(variable_shape))
1556
+
1557
+ def get_config(self):
1558
+ label_weights = self.label_weights
1559
+ config = {
1560
+ "num_thresholds": self.num_thresholds,
1561
+ "curve": self.curve.value,
1562
+ "summation_method": self.summation_method.value,
1563
+ "multi_label": self.multi_label,
1564
+ "num_labels": self.num_labels,
1565
+ "label_weights": label_weights,
1566
+ "from_logits": self._from_logits,
1567
+ }
1568
+ # optimization to avoid serializing a large number of generated
1569
+ # thresholds
1570
+ if self._init_from_thresholds:
1571
+ # We remove the endpoint thresholds as an inverse of how the
1572
+ # thresholds were initialized. This ensures that a metric
1573
+ # initialized from this config has the same thresholds.
1574
+ config["thresholds"] = self.thresholds[1:-1]
1575
+ base_config = super().get_config()
1576
+ return {**base_config, **config}
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/correlation_metrics.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import backend
2
+ from keras.src import ops
3
+ from keras.src.api_export import keras_export
4
+ from keras.src.losses.loss import squeeze_or_expand_to_same_rank
5
+ from keras.src.metrics import reduction_metrics
6
+
7
+
8
@keras_export("keras.metrics.pearson_correlation")
def pearson_correlation(y_true, y_pred, axis=-1):
    """Computes the Pearson coefficient between labels and predictions.

    Formula:

    ```python
    pcc = mean(standardize(y_true) * standardize(y_pred))
    ```

    where `standardize(x) = (x - mean(x)) / std(x)`, with the mean, standard
    deviation and final mean all taken along `axis`.

    Args:
        y_true: Tensor of true targets.
        y_pred: Tensor of predicted targets.
        axis: Axis along which to determine similarity. Defaults to `-1`.

    Returns:
        Pearson Correlation Coefficient tensor.

    Example:

    >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]
    >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]
    >>> pcc = keras.metrics.pearson_correlation(
    ...     y_true, y_pred, axis=-1
    ... ).numpy()
    [1. 0.99339927]
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)

    # Center both tensors along `axis`...
    y_true_norm = y_true - ops.mean(y_true, axis=axis, keepdims=True)
    y_pred_norm = y_pred - ops.mean(y_pred, axis=axis, keepdims=True)

    # ...then scale each to unit standard deviation.
    y_true_norm = y_true_norm / ops.std(y_true_norm, axis=axis, keepdims=True)
    y_pred_norm = y_pred_norm / ops.std(y_pred_norm, axis=axis, keepdims=True)

    return ops.mean(y_true_norm * y_pred_norm, axis=axis)
46
+
47
+
48
@keras_export("keras.metrics.concordance_correlation")
def concordance_correlation(y_true, y_pred, axis=-1):
    """Computes the Concordance coefficient between labels and predictions.

    Formula:

    ```python
    ccc = mean(
        2 * (y_true - mean(y_true)) * (y_pred - mean(y_pred)) / (
            var(y_true) + var(y_pred) + square(mean(y_true) - mean(y_pred))
        )
    )
    ```

    Args:
        y_true: Tensor of true targets.
        y_pred: Tensor of predicted targets.
        axis: Axis along which to determine similarity. Defaults to `-1`.

    Returns:
        Concordance Correlation Coefficient tensor.

    Example:

    >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]
    >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]
    >>> ccc = keras.metrics.concordance_correlation(
    ...     y_true, y_pred, axis=-1
    ... ).numpy()
    [0.97560976 0.98765432]
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)

    y_true_mean = ops.mean(y_true, axis=axis, keepdims=True)
    y_pred_mean = ops.mean(y_pred, axis=axis, keepdims=True)

    # `var(x - mean(x)) == var(x)`; the subtraction here is redundant but
    # harmless.
    y_true_var = ops.var(y_true - y_true_mean, axis=axis, keepdims=True)
    y_pred_var = ops.var(y_pred - y_pred_mean, axis=axis, keepdims=True)

    # NOTE: centering `y_true` with `y_pred_mean` looks like a typo, but the
    # mean of `(y_true - y_pred_mean) * (y_pred - y_pred_mean)` along `axis`
    # equals the covariance, because `mean(y_pred - y_pred_mean) == 0` makes
    # the cross term vanish. The final `ops.mean` below performs that
    # reduction, so the result matches the documented formula.
    covar = (y_true - y_pred_mean) * (y_pred - y_pred_mean)
    norm = y_true_var + y_pred_var + ops.square(y_true_mean - y_pred_mean)

    return ops.mean(2 * covar / (norm + backend.epsilon()), axis=axis)
93
+
94
+
95
@keras_export("keras.metrics.PearsonCorrelation")
class PearsonCorrelation(reduction_metrics.MeanMetricWrapper):
    """Calculates the Pearson Correlation Coefficient (PCC).

    PCC measures the linear relationship between the true values (`y_true`)
    and the predicted values (`y_pred`). The coefficient ranges from -1 to 1:
    1 implies a perfect positive linear correlation, 0 indicates no linear
    correlation, and -1 indicates a perfect negative linear correlation.

    This metric is widely used in regression tasks where the strength of the
    linear relationship between predictions and true labels is an important
    evaluation criterion.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        axis: (Optional) integer or tuple of integers of the axis/axes along
            which to compute the metric. Defaults to `-1`.

    Example:

    >>> pcc = keras.metrics.PearsonCorrelation(axis=-1)
    >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]
    >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]
    >>> pcc.update_state(y_true, y_pred)
    >>> pcc.result()
    0.9966996338993913

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='mean_squared_error',
                  metrics=[keras.metrics.PearsonCorrelation()])
    ```
    """

    def __init__(self, name="pearson_correlation", dtype=None, axis=-1):
        # Wrap the functional form; the mean over batches is handled by
        # MeanMetricWrapper.
        super().__init__(
            fn=pearson_correlation,
            name=name,
            dtype=dtype,
            axis=axis,
        )
        self.axis = axis
        # Metric should be maximized during optimization.
        self._direction = "up"

    def get_config(self):
        config = {"name": self.name, "dtype": self.dtype, "axis": self.axis}
        return config
154
+
155
+
156
@keras_export("keras.metrics.ConcordanceCorrelation")
class ConcordanceCorrelation(reduction_metrics.MeanMetricWrapper):
    """Calculates the Concordance Correlation Coefficient (CCC).

    CCC evaluates the agreement between true values (`y_true`) and predicted
    values (`y_pred`) by considering both precision and accuracy. The
    coefficient ranges from -1 to 1, where a value of 1 indicates perfect
    agreement.

    This metric is useful in regression tasks where it is important to assess
    how well the predictions match the true values, taking into account both
    their correlation and proximity to the 45-degree line of perfect
    concordance.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        axis: (Optional) integer or tuple of integers of the axis/axes along
            which to compute the metric. Defaults to `-1`.

    Example:

    >>> ccc = keras.metrics.ConcordanceCorrelation(axis=-1)
    >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]
    >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]
    >>> ccc.update_state(y_true, y_pred)
    >>> ccc.result()
    0.9816320385426076

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='mean_squared_error',
                  metrics=[keras.metrics.ConcordanceCorrelation()])
    ```
    """

    def __init__(self, name="concordance_correlation", dtype=None, axis=-1):
        # Wrap the functional form; the mean over batches is handled by
        # MeanMetricWrapper.
        super().__init__(
            fn=concordance_correlation,
            name=name,
            dtype=dtype,
            axis=axis,
        )
        self.axis = axis
        # Metric should be maximized during optimization.
        self._direction = "up"

    def get_config(self):
        config = {"name": self.name, "dtype": self.dtype, "axis": self.axis}
        return config
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/f_score_metrics.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import backend
2
+ from keras.src import initializers
3
+ from keras.src import ops
4
+ from keras.src.api_export import keras_export
5
+ from keras.src.metrics.metric import Metric
6
+
7
+
8
@keras_export("keras.metrics.FBetaScore")
class FBetaScore(Metric):
    """Computes F-Beta score.

    Formula:

    ```python
    b2 = beta ** 2
    f_beta_score = (1 + b2) * (precision * recall) / (precision * b2 + recall)
    ```
    This is the weighted harmonic mean of precision and recall.
    Its output range is `[0, 1]`. It works for both multi-class
    and multi-label classification.

    Args:
        average: Type of averaging to be performed across per-class results
            in the multi-class case.
            Acceptable values are `None`, `"micro"`, `"macro"` and
            `"weighted"`. Defaults to `None`.
            If `None`, no averaging is performed and `result()` will return
            the score for each class.
            If `"micro"`, compute metrics globally by counting the total
            true positives, false negatives and false positives.
            If `"macro"`, compute metrics for each label,
            and return their unweighted mean.
            This does not take label imbalance into account.
            If `"weighted"`, compute metrics for each label,
            and return their average weighted by support
            (the number of true instances for each label).
            This alters `"macro"` to account for label imbalance.
            It can result in an score that is not between precision and recall.
        beta: Determines the weight of given to recall
            in the harmonic mean between precision and recall (see pseudocode
            equation above). Defaults to `1`.
        threshold: Elements of `y_pred` greater than `threshold` are
            converted to be 1, and the rest 0. If `threshold` is
            `None`, the argmax of `y_pred` is converted to 1, and the rest to 0.
        name: Optional. String name of the metric instance.
        dtype: Optional. Data type of the metric result.

    Returns:
        F-Beta Score: float.

    Example:

    >>> metric = keras.metrics.FBetaScore(beta=2.0, threshold=0.5)
    >>> y_true = np.array([[1, 1, 1],
    ...                    [1, 0, 0],
    ...                    [1, 1, 0]], np.int32)
    >>> y_pred = np.array([[0.2, 0.6, 0.7],
    ...                    [0.2, 0.6, 0.6],
    ...                    [0.6, 0.8, 0.0]], np.float32)
    >>> metric.update_state(y_true, y_pred)
    >>> result = metric.result()
    >>> result
    [0.3846154 , 0.90909094, 0.8333334 ]
    """

    def __init__(
        self,
        average=None,
        beta=1.0,
        threshold=None,
        name="fbeta_score",
        dtype=None,
    ):
        super().__init__(name=name, dtype=dtype)
        # Metric should be maximized during optimization.
        self._direction = "up"

        # Validate the configuration before storing anything.
        if average not in (None, "micro", "macro", "weighted"):
            raise ValueError(
                "Invalid `average` argument value. Expected one of: "
                "{None, 'micro', 'macro', 'weighted'}. "
                f"Received: average={average}"
            )

        if not isinstance(beta, float):
            raise ValueError(
                "Invalid `beta` argument value. "
                "It should be a Python float. "
                f"Received: beta={beta} of type '{type(beta)}'"
            )
        if beta <= 0.0:
            raise ValueError(
                "Invalid `beta` argument value. "
                "It should be > 0. "
                f"Received: beta={beta}"
            )

        if threshold is not None:
            if not isinstance(threshold, float):
                raise ValueError(
                    "Invalid `threshold` argument value. "
                    "It should be a Python float. "
                    f"Received: threshold={threshold} "
                    f"of type '{type(threshold)}'"
                )
            if threshold > 1.0 or threshold <= 0.0:
                raise ValueError(
                    "Invalid `threshold` argument value. "
                    "It should verify 0 < threshold <= 1. "
                    f"Received: threshold={threshold}"
                )

        self.average = average
        self.beta = beta
        self.threshold = threshold
        self._built = False
        # With "micro" averaging, counts are reduced over all axes to
        # scalars; otherwise they are reduced over the batch axis only,
        # keeping one count per class.
        self.axis = 0 if self.average != "micro" else None

    def _build(self, y_true_shape, y_pred_shape):
        """Create per-class (or scalar, for "micro") count variables."""
        if len(y_pred_shape) != 2 or len(y_true_shape) != 2:
            raise ValueError(
                "FBetaScore expects 2D inputs with shape "
                "(batch_size, output_dim). Received input "
                f"shapes: y_pred.shape={y_pred_shape} and "
                f"y_true.shape={y_true_shape}."
            )
        if y_pred_shape[-1] is None or y_true_shape[-1] is None:
            raise ValueError(
                "FBetaScore expects 2D inputs with shape "
                "(batch_size, output_dim), with output_dim fully "
                "defined (not None). Received input "
                f"shapes: y_pred.shape={y_pred_shape} and "
                f"y_true.shape={y_true_shape}."
            )
        num_classes = y_pred_shape[-1]
        init_shape = (num_classes,) if self.average != "micro" else ()

        def _add_zeros_variable(name):
            return self.add_variable(
                name=name,
                shape=init_shape,
                initializer=initializers.Zeros(),
                dtype=self.dtype,
            )

        self.true_positives = _add_zeros_variable("true_positives")
        self.false_positives = _add_zeros_variable("false_positives")
        self.false_negatives = _add_zeros_variable("false_negatives")
        self.intermediate_weights = _add_zeros_variable("intermediate_weights")
        self._built = True

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate TP/FP/FN counts (and support) for a batch."""
        y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)
        y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype)
        if not self._built:
            self._build(y_true.shape, y_pred.shape)

        # Binarize predictions: either by explicit threshold, or by argmax.
        if self.threshold is None:
            threshold = ops.max(y_pred, axis=-1, keepdims=True)
            # make sure [0, 0, 0] doesn't become [1, 1, 1]
            # Use abs(x) > eps, instead of x != 0 to check for zero
            y_pred = ops.logical_and(
                y_pred >= threshold, ops.abs(y_pred) > 1e-9
            )
        else:
            y_pred = y_pred > self.threshold

        y_pred = ops.cast(y_pred, dtype=self.dtype)
        y_true = ops.cast(y_true, dtype=self.dtype)
        if sample_weight is not None:
            sample_weight = ops.convert_to_tensor(
                sample_weight, dtype=self.dtype
            )

        def _weighted_sum(val, sample_weight):
            # Optionally weight each example, then reduce per-class (or to a
            # scalar for "micro" averaging).
            if sample_weight is not None:
                val = ops.multiply(val, ops.expand_dims(sample_weight, 1))
            return ops.sum(val, axis=self.axis)

        self.true_positives.assign(
            self.true_positives + _weighted_sum(y_pred * y_true, sample_weight)
        )
        self.false_positives.assign(
            self.false_positives
            + _weighted_sum(y_pred * (1 - y_true), sample_weight)
        )
        self.false_negatives.assign(
            self.false_negatives
            + _weighted_sum((1 - y_pred) * y_true, sample_weight)
        )
        # Per-class support, used by "weighted" averaging in result().
        self.intermediate_weights.assign(
            self.intermediate_weights + _weighted_sum(y_true, sample_weight)
        )

    def result(self):
        """Compute the F-beta score(s) from the accumulated counts."""
        precision = ops.divide(
            self.true_positives,
            self.true_positives + self.false_positives + backend.epsilon(),
        )
        recall = ops.divide(
            self.true_positives,
            self.true_positives + self.false_negatives + backend.epsilon(),
        )

        precision = ops.convert_to_tensor(precision, dtype=self.dtype)
        recall = ops.convert_to_tensor(recall, dtype=self.dtype)

        # (1 + b^2) * P * R / (b^2 * P + R), with epsilon for stability.
        numerator = precision * recall
        denominator = ((self.beta**2) * precision) + recall
        f1_score = (
            ops.divide(numerator, denominator + backend.epsilon())
            * (1 + (self.beta**2))
        )

        if self.average == "weighted":
            # Weight each class score by its (normalized) support.
            weights = ops.divide(
                self.intermediate_weights,
                ops.sum(self.intermediate_weights) + backend.epsilon(),
            )
            return ops.sum(f1_score * weights)
        if self.average is not None:  # [micro, macro]
            return ops.mean(f1_score)
        return f1_score

    def get_config(self):
        """Returns the serializable config of the metric."""
        config = {
            "name": self.name,
            "dtype": self.dtype,
            "average": self.average,
            "beta": self.beta,
            "threshold": self.threshold,
        }
        return {**super().get_config(), **config}

    def reset_state(self):
        """Reset all count variables to zero."""
        for v in self.variables:
            v.assign(ops.zeros(v.shape, dtype=v.dtype))
+
249
+
250
@keras_export("keras.metrics.F1Score")
class F1Score(FBetaScore):
    r"""Computes F-1 Score.

    Formula:

    ```python
    f1_score = 2 * (precision * recall) / (precision + recall)
    ```
    This is the harmonic mean of precision and recall.
    Its output range is `[0, 1]`. It works for both multi-class
    and multi-label classification.

    Args:
        average: Type of averaging to be performed on data.
            Acceptable values are `None`, `"micro"`, `"macro"`
            and `"weighted"`. Defaults to `None`.
            If `None`, no averaging is performed and `result()` will return
            the score for each class.
            If `"micro"`, compute metrics globally by counting the total
            true positives, false negatives and false positives.
            If `"macro"`, compute metrics for each label,
            and return their unweighted mean.
            This does not take label imbalance into account.
            If `"weighted"`, compute metrics for each label,
            and return their average weighted by support
            (the number of true instances for each label).
            This alters `"macro"` to account for label imbalance.
            It can result in an score that is not between precision and recall.
        threshold: Elements of `y_pred` greater than `threshold` are
            converted to be 1, and the rest 0. If `threshold` is
            `None`, the argmax of `y_pred` is converted to 1, and the rest to 0.
        name: Optional. String name of the metric instance.
        dtype: Optional. Data type of the metric result.

    Returns:
        F-1 Score: float.

    Example:

    >>> metric = keras.metrics.F1Score(threshold=0.5)
    >>> y_true = np.array([[1, 1, 1],
    ...                    [1, 0, 0],
    ...                    [1, 1, 0]], np.int32)
    >>> y_pred = np.array([[0.2, 0.6, 0.7],
    ...                    [0.2, 0.6, 0.6],
    ...                    [0.6, 0.8, 0.0]], np.float32)
    >>> metric.update_state(y_true, y_pred)
    >>> result = metric.result()
    array([0.5      , 0.8      , 0.6666667], dtype=float32)
    """

    def __init__(
        self,
        average=None,
        threshold=None,
        name="f1_score",
        dtype=None,
    ):
        # F-1 is the beta=1 special case of F-beta.
        super().__init__(
            average=average,
            beta=1.0,
            threshold=threshold,
            name=name,
            dtype=dtype,
        )

    def get_config(self):
        # `beta` is fixed at 1.0, so it is not part of this class's config.
        base_config = super().get_config()
        del base_config["beta"]
        return base_config
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/hinge_metrics.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src.api_export import keras_export
2
+ from keras.src.losses.losses import categorical_hinge
3
+ from keras.src.losses.losses import hinge
4
+ from keras.src.losses.losses import squared_hinge
5
+ from keras.src.metrics import reduction_metrics
6
+
7
+
8
@keras_export("keras.metrics.Hinge")
class Hinge(reduction_metrics.MeanMetricWrapper):
    """Computes the hinge metric between `y_true` and `y_pred`.

    `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
    provided we will convert them to -1 or 1.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Examples:

    >>> m = keras.metrics.Hinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result()
    1.3
    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result()
    1.1
    """

    def __init__(self, name="hinge", dtype=None):
        # Delegate the per-sample computation to the `hinge` loss function;
        # MeanMetricWrapper handles the (weighted) mean across updates.
        super().__init__(fn=hinge, name=name, dtype=dtype)
        # Metric should be minimized during optimization.
        self._direction = "down"

    def get_config(self):
        config = {"name": self.name, "dtype": self.dtype}
        return config
39
+
40
+
41
@keras_export("keras.metrics.SquaredHinge")
class SquaredHinge(reduction_metrics.MeanMetricWrapper):
    """Computes the hinge metric between `y_true` and `y_pred`.

    `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
    provided we will convert them to -1 or 1.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.SquaredHinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result()
    1.86
    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result()
    1.46
    """

    def __init__(self, name="squared_hinge", dtype=None):
        # Delegate the per-sample computation to the `squared_hinge` loss
        # function; MeanMetricWrapper handles the (weighted) mean.
        super().__init__(fn=squared_hinge, name=name, dtype=dtype)
        # Metric should be minimized during optimization.
        self._direction = "down"

    def get_config(self):
        config = {"name": self.name, "dtype": self.dtype}
        return config
72
+
73
+
74
@keras_export("keras.metrics.CategoricalHinge")
class CategoricalHinge(reduction_metrics.MeanMetricWrapper):
    """Computes the categorical hinge metric between `y_true` and `y_pred`.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:
    >>> m = keras.metrics.CategoricalHinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result().numpy()
    1.4000001
    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result()
    1.2
    """

    def __init__(self, name="categorical_hinge", dtype=None):
        # Delegate the per-sample computation to the `categorical_hinge`
        # loss function; MeanMetricWrapper handles the (weighted) mean.
        super().__init__(fn=categorical_hinge, name=name, dtype=dtype)
        # Metric should be minimized during optimization.
        self._direction = "down"

    def get_config(self):
        config = {"name": self.name, "dtype": self.dtype}
        return config
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/iou_metrics.py ADDED
@@ -0,0 +1,762 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ from keras.src import backend
4
+ from keras.src import initializers
5
+ from keras.src import ops
6
+ from keras.src.api_export import keras_export
7
+ from keras.src.metrics.metric import Metric
8
+ from keras.src.metrics.metrics_utils import confusion_matrix
9
+
10
+
11
class _IoUBase(Metric):
    """Computes the confusion matrix for Intersection-Over-Union metrics.

    Formula:

    ```python
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```
    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    From IoUs of individual classes, the MeanIoU can be computed as the mean of
    the individual IoUs.

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
        num_classes: The possible number of labels the prediction task can have.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        ignore_class: Optional integer. The ID of a class to be ignored during
            metric computation. This is useful, for example, in segmentation
            problems featuring a "void" class (commonly -1 or 255) in
            segmentation maps. By default (`ignore_class=None`), all classes are
            considered.
        sparse_y_true: Whether labels are encoded using integers or
            dense floating point vectors. If `False`, the `argmax` function
            is used to determine each sample's most likely associated label.
        sparse_y_pred: Whether predictions are encoded using integers or
            dense floating point vectors. If `False`, the `argmax` function
            is used to determine each sample's most likely associated label.
        axis: (Optional) -1 is the dimension containing the logits.
            Defaults to `-1`.
    """

    def __init__(
        self,
        num_classes,
        name=None,
        dtype=None,
        ignore_class=None,
        sparse_y_true=True,
        sparse_y_pred=True,
        axis=-1,
    ):
        # defaulting to int to avoid issues with confusion matrix
        super().__init__(name=name, dtype=dtype or "int")
        # Metric should be maximized during optimization.
        self._direction = "up"
        self.num_classes = num_classes
        self.ignore_class = ignore_class
        self.sparse_y_true = sparse_y_true
        self.sparse_y_pred = sparse_y_pred
        self.axis = axis

        # Running `(num_classes, num_classes)` confusion matrix, accumulated
        # across `update_state` calls and cleared by `reset_state`.
        self.total_cm = self.add_variable(
            name="total_confusion_matrix",
            shape=(num_classes, num_classes),
            initializer=initializers.Zeros(),
            dtype=self.dtype,
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates the confusion matrix statistics.

        Args:
            y_true: The ground truth values.
            y_pred: The predicted values.
            sample_weight: Optional weighting of each example. Can
                be a `Tensor` whose rank is either 0, or the same as `y_true`,
                and must be broadcastable to `y_true`. Defaults to `1`.

        Returns:
            Update op.
        """

        # Dense (one-hot / probability) inputs are reduced to integer class
        # ids with an argmax along `self.axis`.
        if not self.sparse_y_true:
            y_true = ops.argmax(y_true, axis=self.axis)
        if not self.sparse_y_pred:
            y_pred = ops.argmax(y_pred, axis=self.axis)

        y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)
        y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype)

        # Flatten the input if its rank > 1.
        if len(y_pred.shape) > 1:
            y_pred = ops.reshape(y_pred, [-1])

        if len(y_true.shape) > 1:
            y_true = ops.reshape(y_true, [-1])

        if sample_weight is None:
            sample_weight = 1
        else:
            # The metric dtype defaults to int (see `__init__`), so float
            # weights would be truncated by the conversion below; warn the
            # user instead of silently corrupting the weighting.
            if (
                hasattr(sample_weight, "dtype")
                and "float" in str(sample_weight.dtype)
                and "int" in str(self.dtype)
            ):
                warnings.warn(
                    "You are passing weight as `float`, but dtype is `int`. "
                    "This may result in an incorrect weight due to type casting"
                    " Consider using integer weights."
                )
        sample_weight = ops.convert_to_tensor(sample_weight, dtype=self.dtype)

        if len(sample_weight.shape) > 1:
            sample_weight = ops.reshape(sample_weight, [-1])

        sample_weight = ops.broadcast_to(sample_weight, ops.shape(y_true))

        if self.ignore_class is not None:
            ignore_class = ops.convert_to_tensor(
                self.ignore_class, y_true.dtype
            )
            # Ignored positions are remapped to class 0 and given weight 0,
            # so they contribute nothing to the confusion matrix.
            valid_mask = ops.not_equal(y_true, ignore_class)
            y_true = y_true * ops.cast(valid_mask, y_true.dtype)
            y_pred = y_pred * ops.cast(valid_mask, y_pred.dtype)
            # NOTE(review): `sample_weight` can no longer be `None` here (it
            # was defaulted to 1 above); this guard looks redundant.
            if sample_weight is not None:
                sample_weight = sample_weight * ops.cast(
                    valid_mask, sample_weight.dtype
                )

        y_pred = ops.cast(y_pred, dtype=self.dtype)
        y_true = ops.cast(y_true, dtype=self.dtype)
        sample_weight = ops.cast(sample_weight, dtype=self.dtype)

        current_cm = confusion_matrix(
            y_true,
            y_pred,
            self.num_classes,
            weights=sample_weight,
            dtype=self.dtype,
        )

        # Accumulate into the running total so the metric aggregates over
        # the whole evaluation, not just the current batch.
        return self.total_cm.assign(self.total_cm + current_cm)

    def reset_state(self):
        # Clear the accumulated confusion matrix between evaluations.
        self.total_cm.assign(
            ops.zeros(self.total_cm.shape, dtype=self.total_cm.dtype)
        )
156
+
157
+
158
@keras_export("keras.metrics.IoU")
class IoU(_IoUBase):
    """Computes the Intersection-Over-Union metric for specific target classes.

    Formula:

    ```python
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```
    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Note, this class first computes IoUs for all individual classes, then
    returns the mean of IoUs for the classes that are specified by
    `target_class_ids`. If `target_class_ids` has only one id value, the IoU of
    that specific class is returned.

    Args:
        num_classes: The possible number of labels the prediction task can have.
        target_class_ids: A tuple or list of target class ids for which the
            metric is returned. To compute IoU for a specific class, a list
            (or tuple) of a single id value should be provided.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        ignore_class: Optional integer. The ID of a class to be ignored during
            metric computation. This is useful, for example, in segmentation
            problems featuring a "void" class (commonly -1 or 255) in
            segmentation maps. By default (`ignore_class=None`), all classes are
            considered.
        sparse_y_true: Whether labels are encoded using integers or
            dense floating point vectors. If `False`, the `argmax` function
            is used to determine each sample's most likely associated label.
        sparse_y_pred: Whether predictions are encoded using integers or
            dense floating point vectors. If `False`, the `argmax` function
            is used to determine each sample's most likely associated label.
        axis: (Optional) -1 is the dimension containing the logits.
            Defaults to `-1`.

    Examples:

    >>> # cm = [[1, 1],
    >>> #        [1, 1]]
    >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
    >>> # iou = true_positives / (sum_row + sum_col - true_positives))
    >>> # iou = [0.33, 0.33]
    >>> m = keras.metrics.IoU(num_classes=2, target_class_ids=[0])
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
    >>> m.result()
    0.33333334

    >>> m.reset_state()
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
    ...                sample_weight=[0.3, 0.3, 0.3, 0.1])
    >>> # cm = [[0.3, 0.3],
    >>> #        [0.3, 0.1]]
    >>> # sum_row = [0.6, 0.4], sum_col = [0.6, 0.4],
    >>> # true_positives = [0.3, 0.1]
    >>> # iou = [0.33, 0.14]
    >>> m.result()
    0.33333334

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[keras.metrics.IoU(num_classes=2, target_class_ids=[0])])
    ```
    """

    def __init__(
        self,
        num_classes,
        target_class_ids,
        name=None,
        dtype=None,
        ignore_class=None,
        sparse_y_true=True,
        sparse_y_pred=True,
        axis=-1,
    ):
        super().__init__(
            name=name,
            num_classes=num_classes,
            ignore_class=ignore_class,
            sparse_y_true=sparse_y_true,
            sparse_y_pred=sparse_y_pred,
            axis=axis,
            dtype=dtype,
        )
        # Fail fast on ids that could never index the confusion matrix.
        # NOTE(review): negative ids are not validated here — confirm callers
        # never pass them.
        if max(target_class_ids) >= num_classes:
            raise ValueError(
                f"Target class id {max(target_class_ids)} "
                "is out of range, which is "
                f"[{0}, {num_classes})."
            )
        self.target_class_ids = list(target_class_ids)

    def result(self):
        """Compute the intersection-over-union via the confusion matrix."""
        sum_over_row = ops.cast(
            ops.sum(self.total_cm, axis=0), dtype=self.dtype
        )
        sum_over_col = ops.cast(
            ops.sum(self.total_cm, axis=1), dtype=self.dtype
        )
        # Diagonal entries are the correctly classified counts per class.
        true_positives = ops.cast(ops.diag(self.total_cm), dtype=self.dtype)

        # sum_over_row + sum_over_col =
        #     2 * true_positives + false_positives + false_negatives.
        denominator = sum_over_row + sum_over_col - true_positives

        target_class_ids = ops.convert_to_tensor(
            self.target_class_ids, dtype="int32"
        )

        # Only keep the target classes
        true_positives = ops.take_along_axis(
            true_positives, target_class_ids, axis=-1
        )
        denominator = ops.take_along_axis(
            denominator, target_class_ids, axis=-1
        )
        denominator = ops.cast(denominator, dtype="float32")

        # If the denominator is 0, we need to ignore the class.
        num_valid_entries = ops.sum(
            ops.cast(ops.greater(denominator, 1e-9), dtype="float32")
        )

        # `epsilon` keeps the division finite; classes with an empty
        # denominator contribute 0 to the sum and are excluded from
        # `num_valid_entries`, so they do not skew the mean.
        iou = ops.divide(true_positives, denominator + backend.epsilon())

        return ops.divide(
            ops.sum(iou, axis=self.axis), num_valid_entries + backend.epsilon()
        )

    def get_config(self):
        config = {
            "num_classes": self.num_classes,
            "target_class_ids": self.target_class_ids,
            "ignore_class": self.ignore_class,
            "sparse_y_true": self.sparse_y_true,
            "sparse_y_pred": self.sparse_y_pred,
            "axis": self.axis,
        }
        # Merge with the base `Metric` config (name, dtype).
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
312
+
313
+
314
@keras_export("keras.metrics.BinaryIoU")
class BinaryIoU(IoU):
    """Computes the Intersection-Over-Union metric for class 0 and/or 1.

    Formula:

    ```python
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```
    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    This variant handles binary classification tasks where predictions come
    in as logits/probabilities: each predicted value below `threshold` is
    mapped to class 0, and each value at or above it to class 1. The
    thresholded predictions are accumulated in a confusion matrix, weighted
    by `sample_weight` (defaulting to 1; a weight of 0 masks a value), and
    the IoUs of the classes in `target_class_ids` are averaged to produce
    the result.

    Note: with `threshold=0`, this metric has the same behavior as `IoU`.

    Args:
        target_class_ids: A tuple or list of target class ids for which the
            metric is returned. Options are `[0]`, `[1]`, or `[0, 1]`. With
            `[0]` (or `[1]`), the IoU metric for class 0 (or class 1,
            respectively) is returned. With `[0, 1]`, the mean of IoUs for
            the two classes is returned.
        threshold: A threshold that applies to the prediction logits to
            convert them to either predicted class 0 if the logit is below
            `threshold` or predicted class 1 if the logit is above
            `threshold`.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Example:

    >>> m = keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3)
    >>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7])
    >>> m.result()
    0.33333334

    >>> m.reset_state()
    >>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7],
    ...                sample_weight=[0.2, 0.3, 0.4, 0.1])
    >>> # cm = [[0.2, 0.4],
    >>> #        [0.3, 0.1]]
    >>> # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5],
    >>> # true_positives = [0.2, 0.1]
    >>> # iou = [0.222, 0.125]
    >>> m.result()
    0.17361112

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[keras.metrics.BinaryIoU(
            target_class_ids=[0],
            threshold=0.5
        )]
    )
    ```
    """

    def __init__(
        self,
        target_class_ids=(0, 1),
        threshold=0.5,
        name=None,
        dtype=None,
    ):
        # A binary task always has exactly two classes.
        super().__init__(
            num_classes=2,
            target_class_ids=target_class_ids,
            name=name,
            dtype=dtype,
        )
        self.threshold = threshold

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates the confusion matrix statistics.

        Before the confusion matrix is updated, the predicted values are
        thresholded to be:
        0 for values that are smaller than the `threshold`
        1 for values that are larger or equal to the `threshold`

        Args:
            y_true: The ground truth values.
            y_pred: The predicted values.
            sample_weight: Optional weighting of each example. Can
                be a `Tensor` whose rank is either 0, or the same as `y_true`,
                and must be broadcastable to `y_true`. Defaults to `1`.

        Returns:
            Update op.
        """
        labels = ops.convert_to_tensor(y_true, dtype=self.dtype)
        # Predictions are materialized as float32 first so the threshold
        # comparison is well-defined, then binarized into the metric dtype.
        raw_predictions = ops.convert_to_tensor(y_pred, dtype="float32")
        binarized = ops.cast(raw_predictions >= self.threshold, self.dtype)
        return super().update_state(labels, binarized, sample_weight)

    def get_config(self):
        config = {
            "target_class_ids": self.target_class_ids,
            "threshold": self.threshold,
            "name": self.name,
            "dtype": self._dtype,
        }
        return config
435
+
436
+
437
@keras_export("keras.metrics.MeanIoU")
class MeanIoU(IoU):
    """Computes the mean Intersection-Over-Union metric.

    Formula:

    ```python
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```
    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    Predictions are accumulated in a confusion matrix, weighted by
    `sample_weight` (defaulting to 1; use a weight of 0 to mask values).
    The per-class IoUs are computed from that matrix, and this metric
    returns their mean over *all* classes.

    Args:
        num_classes: The possible number of labels the prediction task can
            have. This value must be provided, since a confusion matrix of
            dimension = [num_classes, num_classes] will be allocated.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        ignore_class: Optional integer. The ID of a class to be ignored
            during metric computation. This is useful, for example, in
            segmentation problems featuring a "void" class (commonly -1 or
            255) in segmentation maps. By default (`ignore_class=None`), all
            classes are considered.
        sparse_y_true: Whether labels are encoded using integers or
            dense floating point vectors. If `False`, the `argmax` function
            is used to determine each sample's most likely associated label.
        sparse_y_pred: Whether predictions are encoded using integers or
            dense floating point vectors. If `False`, the `argmax` function
            is used to determine each sample's most likely associated label.
        axis: (Optional) The dimension containing the logits. Defaults to
            `-1`.

    Example:

    >>> # cm = [[1, 1],
    >>> #        [1, 1]]
    >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
    >>> # iou = true_positives / (sum_row + sum_col - true_positives))
    >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33
    >>> m = keras.metrics.MeanIoU(num_classes=2)
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
    >>> m.result()
    0.33333334

    >>> m.reset_state()
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
    ...                sample_weight=[0.3, 0.3, 0.3, 0.1])
    >>> m.result()
    0.23809525

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[keras.metrics.MeanIoU(num_classes=2)])
    ```
    """

    def __init__(
        self,
        num_classes,
        name=None,
        dtype=None,
        ignore_class=None,
        sparse_y_true=True,
        sparse_y_pred=True,
        axis=-1,
    ):
        # The mean is taken over every class, so the target ids span the
        # full label range.
        super().__init__(
            name=name,
            num_classes=num_classes,
            target_class_ids=list(range(num_classes)),
            axis=axis,
            dtype=dtype,
            ignore_class=ignore_class,
            sparse_y_true=sparse_y_true,
            sparse_y_pred=sparse_y_pred,
        )

    def get_config(self):
        # `target_class_ids` is derived from `num_classes`, so it is not
        # part of the serialized config.
        config = dict(
            num_classes=self.num_classes,
            name=self.name,
            dtype=self._dtype,
            ignore_class=self.ignore_class,
            sparse_y_true=self.sparse_y_true,
            sparse_y_pred=self.sparse_y_pred,
            axis=self.axis,
        )
        return config
539
+
540
+
541
@keras_export("keras.metrics.OneHotIoU")
class OneHotIoU(IoU):
    """Computes the Intersection-Over-Union metric for one-hot encoded labels.

    Formula:

    ```python
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```
    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    This class can be used to compute IoU for multi-class classification tasks
    where the labels are one-hot encoded (the last axis should have one
    dimension per class). Note that the predictions should also have the same
    shape. To compute the IoU, first the labels and predictions are converted
    back into integer format by taking the argmax over the class axis. Then the
    same computation steps as for the base `IoU` class apply.

    Note, if there is only one channel in the labels and predictions, this class
    is the same as class `IoU`. In this case, use `IoU` instead.

    Also, make sure that `num_classes` is equal to the number of classes in the
    data, to avoid a "labels out of bound" error when the confusion matrix is
    computed.

    Args:
        num_classes: The possible number of labels the prediction task can have.
        target_class_ids: A tuple or list of target class ids for which the
            metric is returned. To compute IoU for a specific class, a list
            (or tuple) of a single id value should be provided.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        ignore_class: Optional integer. The ID of a class to be ignored during
            metric computation. This is useful, for example, in segmentation
            problems featuring a "void" class (commonly -1 or 255) in
            segmentation maps. By default (`ignore_class=None`), all classes are
            considered.
        sparse_y_pred: Whether predictions are encoded using integers or
            dense floating point vectors. If `False`, the `argmax` function
            is used to determine each sample's most likely associated label.
        axis: (Optional) The dimension containing the logits. Defaults to `-1`.


    Example:

    >>> y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
    >>> y_pred = np.array([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],
    ...                    [0.1, 0.4, 0.5]])
    >>> sample_weight = [0.1, 0.2, 0.3, 0.4]
    >>> m = keras.metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2])
    >>> m.update_state(
    ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
    >>> # cm = [[0, 0, 0.2+0.4],
    >>> #        [0.3, 0, 0],
    >>> #        [0, 0, 0.1]]
    >>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]
    >>> # true_positives = [0, 0, 0.1]
    >>> # single_iou = true_positives / (sum_row + sum_col - true_positives))
    >>> # mean_iou = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2
    >>> m.result()
    0.071

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[keras.metrics.OneHotIoU(
            num_classes=3,
            target_class_ids=[1]
        )]
    )
    ```
    """

    def __init__(
        self,
        num_classes,
        target_class_ids,
        name=None,
        dtype=None,
        ignore_class=None,
        sparse_y_pred=False,
        axis=-1,
    ):
        # Labels are one-hot encoded by definition of this metric, so
        # `sparse_y_true` is hard-wired to `False`: the base class will
        # argmax them back to integer class ids.
        super().__init__(
            num_classes=num_classes,
            target_class_ids=target_class_ids,
            name=name,
            dtype=dtype,
            ignore_class=ignore_class,
            sparse_y_true=False,
            sparse_y_pred=sparse_y_pred,
            axis=axis,
        )

    def get_config(self):
        # `sparse_y_true` is not serialized: it is fixed to `False` in
        # `__init__`.
        return {
            "num_classes": self.num_classes,
            "target_class_ids": self.target_class_ids,
            "name": self.name,
            "dtype": self._dtype,
            "ignore_class": self.ignore_class,
            "sparse_y_pred": self.sparse_y_pred,
            "axis": self.axis,
        }
655
+
656
+
657
@keras_export("keras.metrics.OneHotMeanIoU")
class OneHotMeanIoU(MeanIoU):
    """Computes mean Intersection-Over-Union metric for one-hot encoded labels.

    Formula:

    ```python
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```
    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    This is `MeanIoU` for the case where the labels are one-hot encoded (the
    last axis has one dimension per class); predictions must share that
    shape. Labels and predictions are first converted back to integer class
    ids by taking the argmax over the class axis, after which the standard
    `MeanIoU` computation applies: a confusion matrix is accumulated
    (weighted by `sample_weight`, defaulting to 1; a weight of 0 masks
    values) and the mean of the per-class IoUs is returned.

    Note, if there is only one channel in the labels and predictions, this
    class is the same as class `MeanIoU`. In this case, use `MeanIoU`
    instead.

    Also, make sure that `num_classes` is equal to the number of classes in
    the data, to avoid a "labels out of bound" error when the confusion
    matrix is computed.

    Args:
        num_classes: The possible number of labels the prediction task can
            have.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        ignore_class: Optional integer. The ID of a class to be ignored
            during metric computation. This is useful, for example, in
            segmentation problems featuring a "void" class (commonly -1 or
            255) in segmentation maps. By default (`ignore_class=None`), all
            classes are considered.
        sparse_y_pred: Whether predictions are encoded using natural numbers
            or probability distribution vectors. If `False`, the `argmax`
            function will be used to determine each sample's most likely
            associated label.
        axis: (Optional) The dimension containing the logits. Defaults to
            `-1`.

    Example:

    >>> y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
    >>> y_pred = np.array([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],
    ...                    [0.1, 0.4, 0.5]])
    >>> sample_weight = [0.1, 0.2, 0.3, 0.4]
    >>> m = keras.metrics.OneHotMeanIoU(num_classes=3)
    >>> m.update_state(
    ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
    >>> # cm = [[0, 0, 0.2+0.4],
    >>> #        [0.3, 0, 0],
    >>> #        [0, 0, 0.1]]
    >>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]
    >>> # true_positives = [0, 0, 0.1]
    >>> # single_iou = true_positives / (sum_row + sum_col - true_positives))
    >>> # mean_iou = (0 + 0 + 0.1 / (0.7 + 0.1 - 0.1)) / 3
    >>> m.result()
    0.048

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[keras.metrics.OneHotMeanIoU(num_classes=3)])
    ```
    """

    def __init__(
        self,
        num_classes,
        name=None,
        dtype=None,
        ignore_class=None,
        sparse_y_pred=False,
        axis=-1,
    ):
        # Labels are one-hot by definition here, so `sparse_y_true` is
        # hard-wired to `False` and the base class will argmax them back to
        # integer ids.
        super().__init__(
            num_classes=num_classes,
            name=name,
            dtype=dtype,
            ignore_class=ignore_class,
            sparse_y_true=False,
            sparse_y_pred=sparse_y_pred,
            axis=axis,
        )

    def get_config(self):
        # `sparse_y_true` is fixed in `__init__`, hence not serialized.
        config = dict(
            num_classes=self.num_classes,
            name=self.name,
            dtype=self._dtype,
            ignore_class=self.ignore_class,
            sparse_y_pred=self.sparse_y_pred,
            axis=self.axis,
        )
        return config
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/metric.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import backend
2
+ from keras.src import dtype_policies
3
+ from keras.src import initializers
4
+ from keras.src import ops
5
+ from keras.src.api_export import keras_export
6
+ from keras.src.saving.keras_saveable import KerasSaveable
7
+ from keras.src.utils.naming import auto_name
8
+ from keras.src.utils.tracking import Tracker
9
+
10
+
11
+ @keras_export(["keras.Metric", "keras.metrics.Metric"])
12
+ class Metric(KerasSaveable):
13
+ """Encapsulates metric logic and state.
14
+
15
+ Args:
16
+ name: Optional name for the metric instance.
17
+ dtype: The dtype of the metric's computations. Defaults to `None`, which
18
+ means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
19
+ `"float32"` unless set to different value
20
+ (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
21
+ provided, then the `compute_dtype` will be utilized.
22
+
23
+ Example:
24
+
25
+ ```python
26
+ m = SomeMetric(...)
27
+ for input in ...:
28
+ m.update_state(input)
29
+ print('Final result: ', m.result())
30
+ ```
31
+
32
+ Usage with `compile()` API:
33
+
34
+ ```python
35
+ model = keras.Sequential()
36
+ model.add(keras.layers.Dense(64, activation='relu'))
37
+ model.add(keras.layers.Dense(64, activation='relu'))
38
+ model.add(keras.layers.Dense(10, activation='softmax'))
39
+
40
+ model.compile(optimizer=keras.optimizers.RMSprop(0.01),
41
+ loss=keras.losses.CategoricalCrossentropy(),
42
+ metrics=[keras.metrics.CategoricalAccuracy()])
43
+
44
+ data = np.random.random((1000, 32))
45
+ labels = np.random.random((1000, 10))
46
+
47
+ model.fit(data, labels, epochs=10)
48
+ ```
49
+
50
+ To be implemented by subclasses:
51
+
52
+ * `__init__()`: All state variables should be created in this method by
53
+ calling `self.add_variable()` like: `self.var = self.add_variable(...)`
54
+ * `update_state()`: Has all updates to the state variables like:
55
+ `self.var.assign(...)`.
56
+ * `result()`: Computes and returns a scalar value or a dict of scalar values
57
+ for the metric from the state variables.
58
+
59
+ Example subclass implementation:
60
+
61
+ ```python
62
+ class BinaryTruePositives(Metric):
63
+
64
+ def __init__(self, name='binary_true_positives', **kwargs):
65
+ super().__init__(name=name, **kwargs)
66
+ self.true_positives = self.add_variable(
67
+ shape=(),
68
+ initializer='zeros',
69
+ name='true_positives'
70
+ )
71
+
72
+ def update_state(self, y_true, y_pred, sample_weight=None):
73
+ y_true = ops.cast(y_true, "bool")
74
+ y_pred = ops.cast(y_pred, "bool")
75
+
76
+ values = ops.logical_and(
77
+ ops.equal(y_true, True), ops.equal(y_pred, True))
78
+ values = ops.cast(values, self.dtype)
79
+ if sample_weight is not None:
80
+ sample_weight = ops.cast(sample_weight, self.dtype)
81
+ sample_weight = ops.broadcast_to(
82
+ sample_weight, ops.shape(values)
83
+ )
84
+ values = ops.multiply(values, sample_weight)
85
+ self.true_positives.assign(self.true_positives + ops.sum(values))
86
+
87
+ def result(self):
88
+ return self.true_positives
89
+ ```
90
+ """
91
+
92
+ def __init__(self, dtype=None, name=None):
93
+ self.name = name or auto_name(self.__class__.__name__)
94
+ self._dtype_policy = dtype_policies.get(dtype or backend.floatx())
95
+ self._dtype = self._dtype_policy.compute_dtype
96
+ self._metrics = []
97
+ self._variables = []
98
+ self._tracker = Tracker(
99
+ {
100
+ "variables": (
101
+ lambda x: isinstance(x, backend.Variable),
102
+ self._variables,
103
+ ),
104
+ "metrics": (lambda x: isinstance(x, Metric), self._metrics),
105
+ }
106
+ )
107
+
108
+ def reset_state(self):
109
+ """Reset all of the metric state variables.
110
+
111
+ This function is called between epochs/steps,
112
+ when a metric is evaluated during training.
113
+ """
114
+ for v in self.variables:
115
+ v.assign(ops.zeros(v.shape, dtype=v.dtype))
116
+
117
+ def update_state(self, *args, **kwargs):
118
+ """Accumulate statistics for the metric."""
119
+ raise NotImplementedError
120
+
121
    def stateless_update_state(self, metric_variables, *args, **kwargs):
        """Functional version of `update_state` that avoids in-place writes.

        Runs `update_state` inside a `StatelessScope` so that variable
        assignments are captured by the scope instead of mutating the
        metric's own variables.

        Args:
            metric_variables: List of tensors corresponding 1:1 to
                `self.variables`, holding the current metric state.
            *args: Positional arguments forwarded to `update_state`.
            **kwargs: Keyword arguments forwarded to `update_state`.

        Returns:
            The list of updated state tensors, in the same order as
            `self.variables`.

        Raises:
            ValueError: If `metric_variables` does not match
                `self.variables` in length.
        """
        if len(metric_variables) != len(self.variables):
            raise ValueError(
                "Argument `metric_variables` must be a list of tensors "
                f"corresponding 1:1 to {self.__class__.__name__}().variables. "
                f"Received list with length {len(metric_variables)}, but "
                f"expected {len(self.variables)} variables."
            )
        # Gather variable mapping
        mapping = list(zip(self.variables, metric_variables))

        # Call in stateless scope
        with backend.StatelessScope(state_mapping=mapping) as scope:
            self.update_state(*args, **kwargs)

        # Gather updated variables
        # Variables untouched by `update_state` keep their incoming value.
        metric_variables = []
        for v in self.variables:
            new_v = scope.get_current_value(v)
            if new_v is not None:
                metric_variables.append(new_v)
            else:
                metric_variables.append(v)
        return metric_variables
145
+
146
+ def result(self):
147
+ """Compute the current metric value.
148
+
149
+ Returns:
150
+ A scalar tensor, or a dictionary of scalar tensors.
151
+ """
152
+ raise NotImplementedError
153
+
154
    def stateless_result(self, metric_variables):
        """Functional version of `result()` reading from given tensors.

        Args:
            metric_variables: List of tensors corresponding 1:1 to
                `self.variables`, holding the state to compute the result
                from.

        Returns:
            The value returned by `result()` when the metric's variables
            are mapped to `metric_variables`; the metric's own variables
            are not modified.

        Raises:
            ValueError: If `metric_variables` does not match
                `self.variables` in length.
        """
        if len(metric_variables) != len(self.variables):
            raise ValueError(
                "Argument `metric_variables` must be a list of tensors "
                f"corresponding 1:1 to {self.__class__.__name__}().variables. "
                f"Received list with length {len(metric_variables)}, but "
                f"expected {len(self.variables)} variables."
            )
        # Gather variable mapping
        mapping = list(zip(self.variables, metric_variables))

        # Call in stateless scope
        with backend.StatelessScope(state_mapping=mapping):
            res = self.result()
        return res
169
+
170
    def stateless_reset_state(self):
        """Functional version of `reset_state()`.

        Returns:
            A list of freshly reset state tensors, in the same order as
            `self.variables`; the metric's own variables are left
            untouched.
        """
        # Call in stateless scope
        with backend.StatelessScope() as scope:
            self.reset_state()

        # Gather updated variables
        # Variables untouched by `reset_state` are returned as-is.
        metric_variables = []
        for v in self.variables:
            new_v = scope.get_current_value(v)
            if new_v is not None:
                metric_variables.append(new_v)
            else:
                metric_variables.append(v)
        return metric_variables
184
+
185
+ @property
186
+ def dtype(self):
187
+ return self._dtype
188
+
189
+ def _obj_type(self):
190
+ return "Metric"
191
+
192
    def add_variable(
        self, shape, initializer, dtype=None, aggregation="sum", name=None
    ):
        """Create a non-trainable state variable owned by this metric.

        Args:
            shape: Shape tuple for the variable.
            initializer: Initializer identifier or instance (resolved via
                `initializers.get`).
            dtype: Optional dtype for the variable.
            aggregation: Cross-replica aggregation mode; defaults to
                `"sum"`.
            name: Optional variable name.

        Returns:
            The newly created `backend.Variable`.

        Raises:
            RuntimeError: If `super().__init__()` was never called.
        """
        self._check_super_called()
        # "/" is replaced since it would nest the name scope.
        with backend.name_scope(self.name.replace("/", ">"), caller=self):
            initializer = initializers.get(initializer)
            variable = backend.Variable(
                initializer=initializer,
                shape=shape,
                dtype=dtype,
                trainable=False,
                aggregation=aggregation,
                name=name,
            )
        # Prevent double-tracking
        self._tracker.add_to_store("variables", variable)
        return variable
209
+
210
+ def add_weight(self, shape=(), initializer=None, dtype=None, name=None):
211
+ # Backwards compatibility alias
212
+ return self.add_variable(
213
+ shape=shape, initializer=initializer, dtype=dtype, name=name
214
+ )
215
+
216
+ @property
217
+ def variables(self):
218
+ variables = list(self._variables)
219
+ for metric in self._metrics:
220
+ variables.extend(metric.variables)
221
+ return variables
222
+
223
+ def __call__(self, *args, **kwargs):
224
+ self._check_super_called()
225
+ self.update_state(*args, **kwargs)
226
+ return self.result()
227
+
228
+ def get_config(self):
229
+ """Return the serializable config of the metric."""
230
+ return {"name": self.name, "dtype": self.dtype}
231
+
232
+ @classmethod
233
+ def from_config(cls, config):
234
+ return cls(**config)
235
+
236
    def __setattr__(self, name, value):
        # Track Variables, Layers, Metrics
        # Once `_tracker` exists (set at the end of `__init__`), every
        # attribute assignment is routed through it so that variables and
        # sub-metrics get registered in the metric's stores.
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)
241
+
242
+ def _check_super_called(self):
243
+ if not hasattr(self, "_tracker"):
244
+ raise RuntimeError(
245
+ "You forgot to call `super().__init__()` "
246
+ "in the `__init__()` method. Go add it!"
247
+ )
248
+
249
+ def __repr__(self):
250
+ return f"<{self.__class__.__name__} " f"name={self.name}>"
251
+
252
+ def __str__(self):
253
+ return self.__repr__()
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/metrics/metrics_utils.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import numpy as np
4
+
5
+ from keras.src import backend
6
+ from keras.src import ops
7
+ from keras.src.losses.loss import squeeze_or_expand_to_same_rank
8
+ from keras.src.utils.python_utils import to_list
9
+
10
# Large negative sentinel used by `_filter_top_k` to mask out non-top-k
# prediction values (effectively negative infinity for comparisons).
NEG_INF = -1e10
11
+
12
+
13
def assert_thresholds_range(thresholds):
    """Validate that every threshold lies within `[0, 1]`.

    Args:
        thresholds: Iterable of threshold values, or `None` (no-op).

    Raises:
        ValueError: If any entry is `None` or outside `[0, 1]`.
    """
    if thresholds is None:
        return
    bad = [t for t in thresholds if t is None or t < 0 or t > 1]
    if bad:
        raise ValueError(
            f"Threshold values must be in [0, 1]. Received: {bad}"
        )
23
+
24
+
25
def parse_init_thresholds(thresholds, default_threshold=0.5):
    """Validate `thresholds` and normalize it to a list.

    Args:
        thresholds: Scalar, list, or `None`.
        default_threshold: Fallback used when `thresholds` is `None`.

    Returns:
        A list of validated thresholds (or `[default_threshold]`).
    """
    if thresholds is None:
        return to_list(default_threshold)
    assert_thresholds_range(to_list(thresholds))
    return to_list(thresholds)
32
+
33
+
34
class ConfusionMatrix(Enum):
    """Keys identifying confusion-matrix entries in `variables_to_update`."""

    TRUE_POSITIVES = "tp"
    FALSE_POSITIVES = "fp"
    TRUE_NEGATIVES = "tn"
    FALSE_NEGATIVES = "fn"
39
+
40
+
41
class AUCCurve(Enum):
    """Type of AUC Curve (ROC or PR)."""

    ROC = "ROC"
    PR = "PR"

    @staticmethod
    def from_str(key):
        """Map `'pr'`/`'PR'` or `'roc'`/`'ROC'` to the matching member."""
        if key in ("pr", "PR"):
            return AUCCurve.PR
        if key in ("roc", "ROC"):
            return AUCCurve.ROC
        raise ValueError(
            f'Invalid AUC curve value: "{key}". '
            'Expected values are ["PR", "ROC"]'
        )
58
+
59
+
60
class AUCSummationMethod(Enum):
    """Type of AUC summation method (Riemann-sum variants).

    https://en.wikipedia.org/wiki/Riemann_sum)

    Values:
    * 'interpolation': mid-point summation for `ROC`; for `PR`,
      interpolates (true/false) positives but not the precision ratio
      (see Davis & Goadrich 2006 for details).
    * 'minoring': left summation for increasing intervals and right
      summation for decreasing intervals.
    * 'majoring': right summation for increasing intervals and left
      summation for decreasing intervals.
    """

    INTERPOLATION = "interpolation"
    MAJORING = "majoring"
    MINORING = "minoring"

    @staticmethod
    def from_str(key):
        """Map a lowercase or Capitalized name to the matching member."""
        if key in ("interpolation", "Interpolation"):
            return AUCSummationMethod.INTERPOLATION
        if key in ("majoring", "Majoring"):
            return AUCSummationMethod.MAJORING
        if key in ("minoring", "Minoring"):
            return AUCSummationMethod.MINORING
        raise ValueError(
            f'Invalid AUC summation method value: "{key}". '
            'Expected values are ["interpolation", "majoring", "minoring"]'
        )
92
+
93
+
94
def _update_confusion_matrix_variables_optimized(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    multi_label=False,
    sample_weights=None,
    label_weights=None,
    thresholds_with_epsilon=False,
):
    """Update confusion matrix variables with memory efficient alternative.

    Note that the thresholds need to be evenly distributed within the list, eg,
    the diff between consecutive elements are the same.

    To compute TP/FP/TN/FN, we are measuring a binary classifier
      C(t) = (predictions >= t)
    at each threshold 't'. So we have
      TP(t) = sum( C(t) * true_labels )
      FP(t) = sum( C(t) * false_labels )

    But, computing C(t) requires computation for each t. To make it fast,
    observe that C(t) is a cumulative integral, and so if we have
      thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
    where n = num_thresholds, and if we can compute the bucket function
      B(i) = Sum( (predictions == t), t_i <= t < t{i+1} )
    then we get
      C(t_i) = sum( B(j), j >= i )
    which is the reversed cumulative sum in ops.cumsum().

    We can compute B(i) efficiently by taking advantage of the fact that
    our thresholds are evenly distributed, in that
      width = 1.0 / (num_thresholds - 1)
      thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
    Given a prediction value p, we can map it to its bucket by
      bucket_index(p) = floor( p * (num_thresholds - 1) )
    so we can use ops.segment_sum() to update the buckets in one pass.

    Consider following example:
    y_true = [0, 0, 1, 1]
    y_pred = [0.1, 0.5, 0.3, 0.9]
    thresholds = [0.0, 0.5, 1.0]
    num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]
    bucket_index(y_pred) = ops.floor(y_pred * num_buckets)
                         = ops.floor([0.2, 1.0, 0.6, 1.8])
                         = [0, 0, 0, 1]
    # The meaning of this bucket is that if any of the label is true,
    # then 1 will be added to the corresponding bucket with the index.
    # Eg, if the label for 0.2 is true, then 1 will be added to bucket 0. If the
    # label for 1.8 is true, then 1 will be added to bucket 1.
    #
    # Note the second item "1.0" is floored to 0, since the value need to be
    # strictly larger than the bucket lower bound.
    # In the implementation, we use ops.ceil() - 1 to achieve this.
    tp_bucket_value = ops.segment_sum(true_labels, bucket_indices,
                                      num_segments=num_thresholds)
                    = [1, 1, 0]
    # For [1, 1, 0] here, it means there is 1 true value contributed by bucket
    # 0, and 1 value contributed by bucket 1. When we aggregate them to
    # together, the result become [a + b + c, b + c, c], since large thresholds
    # will always contribute to the value for smaller thresholds.
    true_positive = ops.cumsum(tp_bucket_value, reverse=True)
                  = [2, 1, 0]

    This implementation exhibits a run time and space complexity of O(T + N),
    where T is the number of thresholds and N is the size of predictions.
    Metrics that rely on standard implementation instead exhibit a complexity
    of O(T * N).

    Args:
        variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
            keys and corresponding variables to update as values.
        y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
            cast to `bool`.
        y_pred: A floating point `Tensor` of arbitrary shape and whose values
            are in the range `[0, 1]`.
        thresholds: A sorted floating point `Tensor` with value in `[0, 1]`.
            It need to be evenly distributed (the diff between each element
            need to be the same).
        multi_label: Optional boolean indicating whether multidimensional
            prediction/labels should be treated as multilabel responses, or
            flattened into a single label. When True, the values of
            `variables_to_update` must have a second dimension equal to the
            number of labels in y_true and y_pred, and those tensors must not
            be RaggedTensors.
        sample_weights: Optional `Tensor` whose rank is either 0, or the same
            rank as `y_true`, and must be broadcastable to `y_true` (i.e., all
            dimensions must be either `1`, or the same as the corresponding
            `y_true` dimension).
        label_weights: Optional tensor of non-negative weights for multilabel
            data. The weights are applied when calculating TP, FP, FN, and TN
            without explicit multilabel handling (i.e. when the data is to be
            flattened).
        thresholds_with_epsilon: Optional boolean indicating whether the
            leading and tailing thresholds has any epsilon added for floating
            point imprecisions. It will change how we handle the leading and
            tailing bucket.
    """
    num_thresholds = ops.shape(thresholds)[0]

    # Combine sample and label weights into a single per-element weight.
    if sample_weights is None:
        sample_weights = 1.0
    else:
        sample_weights = ops.broadcast_to(
            ops.cast(sample_weights, dtype=y_pred.dtype), ops.shape(y_pred)
        )
        if not multi_label:
            sample_weights = ops.reshape(sample_weights, [-1])
    if label_weights is None:
        label_weights = 1.0
    else:
        label_weights = ops.expand_dims(label_weights, 0)
        label_weights = ops.broadcast_to(label_weights, ops.shape(y_pred))
        if not multi_label:
            label_weights = ops.reshape(label_weights, [-1])
    weights = ops.cast(
        ops.multiply(sample_weights, label_weights), y_true.dtype
    )

    # We shouldn't need this, but in case there are predict value that is out
    # of the range of [0.0, 1.0]
    y_pred = ops.clip(y_pred, x_min=0.0, x_max=1.0)

    y_true = ops.cast(ops.cast(y_true, "bool"), y_true.dtype)
    if not multi_label:
        y_true = ops.reshape(y_true, [-1])
        y_pred = ops.reshape(y_pred, [-1])

    true_labels = ops.multiply(y_true, weights)
    false_labels = ops.multiply((1.0 - y_true), weights)

    # Compute the bucket indices for each prediction value.
    # Since the predict value has to be strictly greater than the thresholds,
    # eg, buckets like [0, 0.5], (0.5, 1], and 0.5 belongs to first bucket.
    # We have to use math.ceil(val) - 1 for the bucket.
    bucket_indices = (
        ops.ceil(y_pred * (ops.cast(num_thresholds, dtype=y_pred.dtype) - 1))
        - 1
    )

    if thresholds_with_epsilon:
        # In this case, the first bucket should actually take into account
        # since the any prediction between [0.0, 1.0] should be larger than
        # the first threshold. We change the bucket value from -1 to 0.
        bucket_indices = ops.relu(bucket_indices)

    bucket_indices = ops.cast(bucket_indices, "int32")

    if multi_label:
        # We need to run bucket segment sum for each of the label class. In
        # the multi_label case, the rank of the label is 2. We first transpose
        # it so that the label dim becomes the first and we can parallel run
        # though them.
        true_labels = ops.transpose(true_labels)
        false_labels = ops.transpose(false_labels)
        bucket_indices = ops.transpose(bucket_indices)

        def gather_bucket(label_and_bucket_index):
            label, bucket_index = (
                label_and_bucket_index[0],
                label_and_bucket_index[1],
            )
            return ops.segment_sum(
                data=label,
                segment_ids=bucket_index,
                num_segments=num_thresholds,
            )

        tp_bucket_v = backend.vectorized_map(
            gather_bucket,
            (true_labels, bucket_indices),
        )
        fp_bucket_v = backend.vectorized_map(
            gather_bucket, (false_labels, bucket_indices)
        )
        # Reversed cumulative sum turns per-bucket counts into C(t_i).
        tp = ops.transpose(ops.flip(ops.cumsum(ops.flip(tp_bucket_v), axis=1)))
        fp = ops.transpose(ops.flip(ops.cumsum(ops.flip(fp_bucket_v), axis=1)))
    else:
        tp_bucket_v = ops.segment_sum(
            data=true_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        fp_bucket_v = ops.segment_sum(
            data=false_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        tp = ops.flip(ops.cumsum(ops.flip(tp_bucket_v)))
        fp = ops.flip(ops.cumsum(ops.flip(fp_bucket_v)))

    # fn = sum(true_labels) - tp
    # tn = sum(false_labels) - fp
    if (
        ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
        or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
    ):
        if multi_label:
            total_true_labels = ops.sum(true_labels, axis=1)
            total_false_labels = ops.sum(false_labels, axis=1)
        else:
            total_true_labels = ops.sum(true_labels)
            total_false_labels = ops.sum(false_labels)

    if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
        variable.assign(variable + tp)
    if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
        variable.assign(variable + fp)
    if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
        tn = total_false_labels - fp
        variable.assign(variable + tn)
    if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
        fn = total_true_labels - tp
        variable.assign(variable + fn)
312
+
313
+
314
def is_evenly_distributed_thresholds(thresholds):
    """Check if the thresholds list is evenly distributed.

    Evenly distributed thresholds allow metrics like AUC, where each
    individual threshold needs to be evaluated, to use a memory-efficient
    implementation.

    Args:
        thresholds: A python list or tuple, or 1D numpy array whose values
            are ranged in [0, 1].

    Returns:
        boolean, whether the values in the inputs are evenly distributed.
    """
    count = len(thresholds)
    if count < 3:
        return False
    expected = np.arange(count, dtype=np.float32) / (count - 1)
    return np.allclose(thresholds, expected, atol=backend.epsilon())
336
+
337
+
338
def update_confusion_matrix_variables(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    top_k=None,
    class_id=None,
    sample_weight=None,
    multi_label=False,
    label_weights=None,
    thresholds_distributed_evenly=False,
):
    """Updates the given confusion matrix variables.

    For every pair of values in y_true and y_pred:

    true_positive: y_true == True and y_pred > thresholds
    false_negatives: y_true == True and y_pred <= thresholds
    true_negatives: y_true == False and y_pred <= thresholds
    false_positive: y_true == False and y_pred > thresholds

    The results will be weighted and added together. When multiple thresholds
    are provided, we will repeat the same for every threshold.

    For estimation of these metrics over a stream of data, the function
    creates an `update_op` operation that updates the given variables.

    If `sample_weight` is `None`, weights default to 1.
    Use weights of 0 to mask values.

    Args:
        variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
            keys and corresponding variables to update as values.
        y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to
            `bool`.
        y_pred: A floating point `Tensor` of arbitrary shape and whose values
            are in the range `[0, 1]`.
        thresholds: A float value, float tensor, python list, or tuple of
            float thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
        top_k: Optional int, indicates that the positive labels should be
            limited to the top k predictions.
        class_id: Optional int, limits the prediction and labels to the class
            specified by this argument.
        sample_weight: Optional `Tensor` whose rank is either 0, or the same
            rank as `y_true`, and must be broadcastable to `y_true` (i.e., all
            dimensions must be either `1`, or the same as the corresponding
            `y_true` dimension).
        multi_label: Optional boolean indicating whether multidimensional
            prediction/labels should be treated as multilabel responses, or
            flattened into a single label. When True, the values of
            `variables_to_update` must have a second dimension equal to the
            number of labels in y_true and y_pred, and those tensors must not
            be RaggedTensors.
        label_weights: (optional) tensor of non-negative weights for
            multilabel data. The weights are applied when calculating TP, FP,
            FN, and TN without explicit multilabel handling (i.e. when the
            data is to be flattened).
        thresholds_distributed_evenly: Boolean, whether the thresholds are
            evenly distributed within the list. An optimized method will be
            used if this is the case. See
            _update_confusion_matrix_variables_optimized() for more details.

    Raises:
        ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
            `sample_weight` is not `None` and its shape doesn't match
            `y_pred`, or if `variables_to_update` contains invalid keys.
    """
    if multi_label and label_weights is not None:
        raise ValueError(
            "`label_weights` for multilabel data should be handled "
            "outside of `update_confusion_matrix_variables` when "
            "`multi_label` is True."
        )
    if variables_to_update is None:
        return
    if not any(
        key for key in variables_to_update if key in list(ConfusionMatrix)
    ):
        raise ValueError(
            "Please provide at least one valid confusion matrix "
            "variable to update. Valid variable key options are: "
            f'"{list(ConfusionMatrix)}". '
            f'Received: "{variables_to_update.keys()}"'
        )

    variable_dtype = list(variables_to_update.values())[0].dtype

    y_true = ops.cast(y_true, dtype=variable_dtype)
    y_pred = ops.cast(y_pred, dtype=variable_dtype)

    if thresholds_distributed_evenly:
        # Check whether the thresholds has any leading or tailing epsilon
        # added for floating point imprecision. The leading and tailing
        # threshold will be handled bit differently as the corner case. At
        # this point, thresholds should be a list/array with more than 2
        # items, and ranged between [0, 1]. See
        # is_evenly_distributed_thresholds() for more details.
        thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0

    thresholds = ops.convert_to_tensor(thresholds, dtype=variable_dtype)
    num_thresholds = ops.shape(thresholds)[0]

    if multi_label:
        one_thresh = ops.equal(
            np.array(1, dtype="int32"),
            len(thresholds.shape),
        )
    else:
        one_thresh = np.array(True, dtype="bool")

    invalid_keys = [
        key for key in variables_to_update if key not in list(ConfusionMatrix)
    ]
    if invalid_keys:
        raise ValueError(
            f'Invalid keys: "{invalid_keys}". '
            f'Valid variable key options are: "{list(ConfusionMatrix)}"'
        )

    y_pred, y_true = squeeze_or_expand_to_same_rank(y_pred, y_true)
    if sample_weight is not None:
        sample_weight = ops.expand_dims(
            ops.cast(sample_weight, dtype=variable_dtype), axis=-1
        )
        _, sample_weight = squeeze_or_expand_to_same_rank(
            y_true, sample_weight, expand_rank_1=False
        )

    if top_k is not None:
        y_pred = _filter_top_k(y_pred, top_k)

    if class_id is not None:
        if len(y_pred.shape) == 1:
            raise ValueError(
                "When class_id is provided, y_pred must be a 2D array "
                "with shape (num_samples, num_classes), found shape: "
                f"{y_pred.shape}"
            )

        # Preserve dimension to match with sample_weight
        y_true = y_true[..., class_id, None]
        y_pred = y_pred[..., class_id, None]

    if thresholds_distributed_evenly:
        return _update_confusion_matrix_variables_optimized(
            variables_to_update,
            y_true,
            y_pred,
            thresholds,
            multi_label=multi_label,
            sample_weights=sample_weight,
            label_weights=label_weights,
            thresholds_with_epsilon=thresholds_with_epsilon,
        )

    # With a dynamic (unknown) first dimension, shape math must stay
    # symbolic (ops.*); otherwise plain numpy is used.
    if None in y_pred.shape:
        pred_shape = ops.shape(y_pred)
        num_predictions = pred_shape[0]
        if len(y_pred.shape) == 1:
            num_labels = 1
        else:
            num_labels = ops.cast(
                ops.prod(ops.array(pred_shape[1:]), axis=0), "int32"
            )
        thresh_label_tile = ops.where(one_thresh, num_labels, 1)
    else:
        pred_shape = ops.shape(y_pred)
        num_predictions = pred_shape[0]
        if len(y_pred.shape) == 1:
            num_labels = 1
        else:
            num_labels = np.prod(pred_shape[1:], axis=0).astype("int32")
        thresh_label_tile = np.where(one_thresh, num_labels, 1)

    # Reshape predictions and labels, adding a dim for thresholding.
    if multi_label:
        predictions_extra_dim = ops.expand_dims(y_pred, 0)
        labels_extra_dim = ops.expand_dims(ops.cast(y_true, dtype="bool"), 0)
    else:
        # Flatten predictions and labels when not multilabel.
        predictions_extra_dim = ops.reshape(y_pred, [1, -1])
        labels_extra_dim = ops.reshape(ops.cast(y_true, dtype="bool"), [1, -1])

    # Tile the thresholds for every prediction.
    if multi_label:
        thresh_pretile_shape = [num_thresholds, 1, -1]
        thresh_tiles = [1, num_predictions, thresh_label_tile]
        data_tiles = [num_thresholds, 1, 1]
    else:
        thresh_pretile_shape = [num_thresholds, -1]
        thresh_tiles = [1, num_predictions * num_labels]
        data_tiles = [num_thresholds, 1]

    thresh_tiled = ops.tile(
        ops.reshape(thresholds, thresh_pretile_shape), thresh_tiles
    )

    # Tile the predictions for every threshold.
    preds_tiled = ops.tile(predictions_extra_dim, data_tiles)

    # Compare predictions and threshold.
    pred_is_pos = ops.greater(preds_tiled, thresh_tiled)

    # Tile labels by number of thresholds
    label_is_pos = ops.tile(labels_extra_dim, data_tiles)

    if sample_weight is not None:
        sample_weight = ops.broadcast_to(
            ops.cast(sample_weight, dtype=y_pred.dtype), ops.shape(y_pred)
        )
        weights_tiled = ops.tile(
            ops.reshape(sample_weight, thresh_tiles), data_tiles
        )
    else:
        weights_tiled = None

    if label_weights is not None and not multi_label:
        label_weights = ops.expand_dims(label_weights, 0)
        label_weights = ops.broadcast_to(label_weights, ops.shape(y_pred))
        label_weights_tiled = ops.tile(
            ops.reshape(label_weights, thresh_tiles), data_tiles
        )
        if weights_tiled is None:
            weights_tiled = label_weights_tiled
        else:
            weights_tiled = ops.multiply(weights_tiled, label_weights_tiled)

    def weighted_assign_add(label, pred, weights, var):
        # Accumulate the (optionally weighted) count of `label AND pred`
        # along the flattened sample axis into `var`.
        label_and_pred = ops.cast(ops.logical_and(label, pred), dtype=var.dtype)
        if weights is not None:
            label_and_pred *= ops.cast(weights, dtype=var.dtype)
        var.assign(var + ops.sum(label_and_pred, 1))

    loop_vars = {
        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
    }
    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

    if update_fn or update_tn:
        pred_is_neg = ops.logical_not(pred_is_pos)
        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

    if update_fp or update_tn:
        label_is_neg = ops.logical_not(label_is_pos)
        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
        if update_tn:
            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (
                label_is_neg,
                pred_is_neg,
            )

    for matrix_cond, (label, pred) in loop_vars.items():
        if matrix_cond in variables_to_update:
            weighted_assign_add(
                label, pred, weights_tiled, variables_to_update[matrix_cond]
            )
595
+
596
+
597
def _filter_top_k(x, k):
    """Keep the top-k values along the last dim; set the rest to NEG_INF.

    Used for computing top-k prediction values in dense labels (which has
    the same shape as predictions) for recall and precision top-k metrics.

    Args:
        x: tensor with any dimensions.
        k: the number of values to keep.

    Returns:
        tensor with same shape and dtype as x.
    """
    _, top_k_idx = ops.top_k(x, k)
    keep_mask = ops.sum(
        ops.one_hot(top_k_idx, ops.shape(x)[-1], axis=-1), axis=-2
    )
    return x * keep_mask + NEG_INF * (1 - keep_mask)
615
+
616
+
617
def confusion_matrix(
    labels,
    predictions,
    num_classes,
    weights=None,
    dtype="int32",
):
    """Compute the confusion matrix from predictions and labels.

    The matrix columns represent the prediction labels and the rows
    represent the real labels. The confusion matrix is always a 2-D array
    of shape `(n, n)`, where `n` is `num_classes`. Both `predictions` and
    `labels` must be 1-D arrays of the same shape for this function to
    work. Class labels are expected to start at 0; for example, if
    `num_classes` is 3, the possible labels are `[0, 1, 2]`.

    If `weights` is not `None`, then each prediction contributes its
    corresponding weight to the total value of the confusion matrix cell
    (instead of 1).

    For example:

    ```python
    keras.metrics.metrics_utils.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
    ```

    Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
    resulting in a 5x5 confusion matrix.

    Args:
        labels: 1-D tensor of real labels for the classification task.
        predictions: 1-D tensor of predictions for a given classification.
        num_classes: The possible number of labels the classification
            task can have (must be convertible to `int`).
        weights: An optional tensor whose shape matches `predictions`.
        dtype: Data type of the confusion matrix.

    Returns:
        A tensor of type `dtype` with shape `(num_classes, num_classes)`
        representing the confusion matrix.
    """
    labels = ops.convert_to_tensor(labels, dtype)
    predictions = ops.convert_to_tensor(predictions, dtype)
    labels, predictions = squeeze_or_expand_to_same_rank(labels, predictions)

    predictions = ops.cast(predictions, dtype)
    labels = ops.cast(labels, dtype)

    if weights is not None:
        weights = ops.convert_to_tensor(weights, dtype)

    # Each (label, prediction) pair is a scatter index; the scattered value
    # is the corresponding sample weight (or 1 when unweighted).
    indices = ops.stack([labels, predictions], axis=1)
    values = ops.ones_like(predictions, dtype) if weights is None else weights
    indices = ops.cast(indices, dtype="int64")
    values = ops.cast(values, dtype=dtype)
    num_classes = int(num_classes)
    return ops.scatter(indices, values, (num_classes, num_classes))