Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__init__.py +16 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/bijector_test_util.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/categorical.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/dirichlet.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/dirichlet_multinomial.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/distribution.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/distributions.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/exponential.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/gamma.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/identity_bijector.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/kullback_leibler.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/laplace.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/multinomial.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/normal.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/special_math.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/student_t.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/uniform.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/util.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/bernoulli.py +183 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/beta.py +407 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/bijector.py +21 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/bijector_impl.py +1113 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/bijector_test_util.py +221 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/categorical.py +345 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/dirichlet.py +410 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/dirichlet_multinomial.py +353 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/distribution.py +1316 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/distributions.py +36 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/exponential.py +162 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/gamma.py +338 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/identity_bijector.py +68 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/kullback_leibler.py +210 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/laplace.py +238 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/multinomial.py +314 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/normal.py +291 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/special_math.py +470 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/student_t.py +391 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/uniform.py +204 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/util.py +1448 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/losses/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/losses/__pycache__/losses.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/losses/__pycache__/losses_impl.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/losses/__pycache__/util.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__init__.py +37 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/dct_ops.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/fft_ops.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/mel_ops.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/mfcc_ops.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/reconstruction_ops.cpython-310.pyc +0 -0
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__init__.py
ADDED
@@ -0,0 +1,16 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Core module for TensorFlow distribution objects and helpers."""
+from tensorflow.python.ops.distributions import distributions
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/bijector_test_util.cpython-310.pyc
ADDED
Binary file (5.12 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/categorical.cpython-310.pyc
ADDED
Binary file (11.1 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/dirichlet.cpython-310.pyc
ADDED
Binary file (12.6 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/dirichlet_multinomial.cpython-310.pyc
ADDED
Binary file (12.4 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/distribution.cpython-310.pyc
ADDED
Binary file (42.7 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/distributions.cpython-310.pyc
ADDED
Binary file (1.5 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/exponential.cpython-310.pyc
ADDED
Binary file (4.86 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/gamma.cpython-310.pyc
ADDED
Binary file (11.4 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/identity_bijector.cpython-310.pyc
ADDED
Binary file (2.13 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/kullback_leibler.cpython-310.pyc
ADDED
Binary file (6.47 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/laplace.cpython-310.pyc
ADDED
Binary file (8.43 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/multinomial.cpython-310.pyc
ADDED
Binary file (11.1 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/normal.cpython-310.pyc
ADDED
Binary file (10.4 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/special_math.cpython-310.pyc
ADDED
Binary file (9.82 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/student_t.cpython-310.pyc
ADDED
Binary file (12.7 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/uniform.cpython-310.pyc
ADDED
Binary file (7.23 kB).

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/__pycache__/util.cpython-310.pyc
ADDED
Binary file (41.6 kB).
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/bernoulli.py
ADDED
@@ -0,0 +1,183 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Bernoulli distribution class."""
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import kullback_leibler
+from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export(v1=["distributions.Bernoulli"])
+class Bernoulli(distribution.Distribution):
+  """Bernoulli distribution.
+
+  The Bernoulli distribution with `probs` parameter, i.e., the probability of a
+  `1` outcome (vs a `0` outcome).
+  """
+
+  @deprecation.deprecated(
+      "2019-01-01",
+      "The TensorFlow Distributions library has moved to "
+      "TensorFlow Probability "
+      "(https://github.com/tensorflow/probability). You "
+      "should update all references to use `tfp.distributions` "
+      "instead of `tf.distributions`.",
+      warn_once=True)
+  def __init__(self,
+               logits=None,
+               probs=None,
+               dtype=dtypes.int32,
+               validate_args=False,
+               allow_nan_stats=True,
+               name="Bernoulli"):
+    """Construct Bernoulli distributions.
+
+    Args:
+      logits: An N-D `Tensor` representing the log-odds of a `1` event. Each
+        entry in the `Tensor` parametrizes an independent Bernoulli distribution
+        where the probability of an event is sigmoid(logits). Only one of
+        `logits` or `probs` should be passed in.
+      probs: An N-D `Tensor` representing the probability of a `1`
+        event. Each entry in the `Tensor` parameterizes an independent
+        Bernoulli distribution. Only one of `logits` or `probs` should be passed
+        in.
+      dtype: The type of the event samples. Default: `int32`.
+      validate_args: Python `bool`, default `False`. When `True` distribution
+        parameters are checked for validity despite possibly degrading runtime
+        performance. When `False` invalid inputs may silently render incorrect
+        outputs.
+      allow_nan_stats: Python `bool`, default `True`. When `True`,
+        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+        indicate the result is undefined. When `False`, an exception is raised
+        if one or more of the statistic's batch members are undefined.
+      name: Python `str` name prefixed to Ops created by this class.
+
+    Raises:
+      ValueError: If p and logits are passed, or if neither are passed.
+    """
+    parameters = dict(locals())
+    with ops.name_scope(name) as name:
+      self._logits, self._probs = distribution_util.get_logits_and_probs(
+          logits=logits,
+          probs=probs,
+          validate_args=validate_args,
+          name=name)
+    super(Bernoulli, self).__init__(
+        dtype=dtype,
+        reparameterization_type=distribution.NOT_REPARAMETERIZED,
+        validate_args=validate_args,
+        allow_nan_stats=allow_nan_stats,
+        parameters=parameters,
+        graph_parents=[self._logits, self._probs],
+        name=name)
+
+  @staticmethod
+  def _param_shapes(sample_shape):
+    return {"logits": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}
+
+  @property
+  def logits(self):
+    """Log-odds of a `1` outcome (vs `0`)."""
+    return self._logits
+
+  @property
+  def probs(self):
+    """Probability of a `1` outcome (vs `0`)."""
+    return self._probs
+
+  def _batch_shape_tensor(self):
+    return array_ops.shape(self._logits)
+
+  def _batch_shape(self):
+    return self._logits.get_shape()
+
+  def _event_shape_tensor(self):
+    return array_ops.constant([], dtype=dtypes.int32)
+
+  def _event_shape(self):
+    return tensor_shape.TensorShape([])
+
+  def _sample_n(self, n, seed=None):
+    new_shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
+    uniform = random_ops.random_uniform(
+        new_shape, seed=seed, dtype=self.probs.dtype)
+    sample = math_ops.less(uniform, self.probs)
+    return math_ops.cast(sample, self.dtype)
+
+  def _log_prob(self, event):
+    if self.validate_args:
+      event = distribution_util.embed_check_integer_casting_closed(
+          event, target_dtype=dtypes.bool)
+
+    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
+    # inconsistent behavior for logits = inf/-inf.
+    event = math_ops.cast(event, self.logits.dtype)
+    logits = self.logits
+    # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
+    # so we do this here.
+
+    def _broadcast(logits, event):
+      return (array_ops.ones_like(event) * logits,
+              array_ops.ones_like(logits) * event)
+
+    if not (event.get_shape().is_fully_defined() and
+            logits.get_shape().is_fully_defined() and
+            event.get_shape() == logits.get_shape()):
+      logits, event = _broadcast(logits, event)
+    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
+
+  def _entropy(self):
+    return (-self.logits * (math_ops.sigmoid(self.logits) - 1) +  # pylint: disable=invalid-unary-operand-type
+            nn.softplus(-self.logits))  # pylint: disable=invalid-unary-operand-type
+
+  def _mean(self):
+    return array_ops.identity(self.probs)
+
+  def _variance(self):
+    return self._mean() * (1. - self.probs)
+
+  def _mode(self):
+    """Returns `1` if `prob > 0.5` and `0` otherwise."""
+    return math_ops.cast(self.probs > 0.5, self.dtype)
+
+
+@kullback_leibler.RegisterKL(Bernoulli, Bernoulli)
+def _kl_bernoulli_bernoulli(a, b, name=None):
+  """Calculate the batched KL divergence KL(a || b) with a and b Bernoulli.
+
+  Args:
+    a: instance of a Bernoulli distribution object.
+    b: instance of a Bernoulli distribution object.
+    name: (optional) Name to use for created operations.
+      default is "kl_bernoulli_bernoulli".
+
+  Returns:
+    Batchwise KL(a || b)
+  """
+  with ops.name_scope(name, "kl_bernoulli_bernoulli",
+                      values=[a.logits, b.logits]):
+    delta_probs0 = nn.softplus(-b.logits) - nn.softplus(-a.logits)
+    delta_probs1 = nn.softplus(b.logits) - nn.softplus(a.logits)
+    return (math_ops.sigmoid(a.logits) * delta_probs0
+            + math_ops.sigmoid(-a.logits) * delta_probs1)
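The `_kl_bernoulli_bernoulli` routine in the diff above expresses the Bernoulli KL divergence through `softplus` and `sigmoid` of the logits, which avoids taking logs of probabilities near 0 or 1. As a sketch of why that works (standalone pure-Python helpers of our own, not part of the diff or of TensorFlow), the softplus form can be checked against the textbook `p*log(p/q) + (1-p)*log((1-p)/(1-q))`, using `log p = -softplus(-logits)` and `log(1-p) = -softplus(logits)`:

```python
import math


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def kl_bernoulli_logits(logits_a: float, logits_b: float) -> float:
    # Mirrors the softplus/sigmoid form used in _kl_bernoulli_bernoulli.
    delta_probs0 = softplus(-logits_b) - softplus(-logits_a)
    delta_probs1 = softplus(logits_b) - softplus(logits_a)
    return sigmoid(logits_a) * delta_probs0 + sigmoid(-logits_a) * delta_probs1


def kl_bernoulli_direct(p: float, q: float) -> float:
    # Textbook KL(Bernoulli(p) || Bernoulli(q)); unstable near p, q in {0, 1}.
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))
```

For moderate logits the two forms agree to machine precision, e.g. `kl_bernoulli_logits(0.3, -1.2)` matches `kl_bernoulli_direct(sigmoid(0.3), sigmoid(-1.2))`, and the KL of a distribution with itself is exactly zero.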
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/beta.py
ADDED
@@ -0,0 +1,407 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Beta distribution class."""
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import kullback_leibler
+from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
+
+
+__all__ = [
+    "Beta",
+    "BetaWithSoftplusConcentration",
+]
+
+
+_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
+`[0, 1].` It must have a shape compatible with `self.batch_shape()`."""
+
+
+@tf_export(v1=["distributions.Beta"])
+class Beta(distribution.Distribution):
+  """Beta distribution.
+
+  The Beta distribution is defined over the `(0, 1)` interval using parameters
+  `concentration1` (aka "alpha") and `concentration0` (aka "beta").
+
+  #### Mathematical Details
+
+  The probability density function (pdf) is,
+
+  ```none
+  pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
+  Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
+  ```
+
+  where:
+
+  * `concentration1 = alpha`,
+  * `concentration0 = beta`,
+  * `Z` is the normalization constant, and,
+  * `Gamma` is the [gamma function](
+    https://en.wikipedia.org/wiki/Gamma_function).
+
+  The concentration parameters represent mean total counts of a `1` or a `0`,
+  i.e.,
+
+  ```none
+  concentration1 = alpha = mean * total_concentration
+  concentration0 = beta = (1. - mean) * total_concentration
+  ```
+
+  where `mean` in `(0, 1)` and `total_concentration` is a positive real number
+  representing a mean `total_count = concentration1 + concentration0`.
+
+  Distribution parameters are automatically broadcast in all functions; see
+  examples for details.
+
+  Warning: The samples can be zero due to finite precision.
+  This happens more often when some of the concentrations are very small.
+  Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
+  density.
+
+  Samples of this distribution are reparameterized (pathwise differentiable).
+  The derivatives are computed using the approach described in
+  (Figurnov et al., 2018).
+
+  #### Examples
+
+  ```python
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
+  # Create a batch of three Beta distributions.
+  alpha = [1, 2, 3]
+  beta = [1, 2, 3]
+  dist = tfd.Beta(alpha, beta)
+
+  dist.sample([4, 5])  # Shape [4, 5, 3]
+
+  # `x` has three batch entries, each with two samples.
+  x = [[.1, .4, .5],
+       [.2, .3, .5]]
+  # Calculate the probability of each pair of samples under the corresponding
+  # distribution in `dist`.
+  dist.prob(x)  # Shape [2, 3]
+  ```
+
+  ```python
+  # Create batch_shape=[2, 3] via parameter broadcast:
+  alpha = [[1.], [2]]  # Shape [2, 1]
+  beta = [3., 4, 5]    # Shape [3]
+  dist = tfd.Beta(alpha, beta)
+
+  # alpha broadcast as: [[1., 1, 1,],
+  #                      [2, 2, 2]]
+  # beta broadcast as:  [[3., 4, 5],
+  #                      [3, 4, 5]]
+  # batch_Shape [2, 3]
+  dist.sample([4, 5])  # Shape [4, 5, 2, 3]
+
+  x = [.2, .3, .5]
+  # x will be broadcast as [[.2, .3, .5],
+  #                         [.2, .3, .5]],
+  # thus matching batch_shape [2, 3].
+  dist.prob(x)  # Shape [2, 3]
+  ```
+
+  Compute the gradients of samples w.r.t. the parameters:
+
+  ```python
+  alpha = tf.constant(1.0)
+  beta = tf.constant(2.0)
+  dist = tfd.Beta(alpha, beta)
+  samples = dist.sample(5)  # Shape [5]
+  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
+  # Unbiased stochastic gradients of the loss function
+  grads = tf.gradients(loss, [alpha, beta])
+  ```
+
+  References:
+    Implicit Reparameterization Gradients:
+      [Figurnov et al., 2018]
+      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
+      ([pdf]
+      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
+  """
+
+  @deprecation.deprecated(
+      "2019-01-01",
+      "The TensorFlow Distributions library has moved to "
+      "TensorFlow Probability "
+      "(https://github.com/tensorflow/probability). You "
+      "should update all references to use `tfp.distributions` "
+      "instead of `tf.distributions`.",
+      warn_once=True)
+  def __init__(self,
+               concentration1=None,
+               concentration0=None,
+               validate_args=False,
+               allow_nan_stats=True,
+               name="Beta"):
+    """Initialize a batch of Beta distributions.
+
+    Args:
+      concentration1: Positive floating-point `Tensor` indicating mean
+        number of successes; aka "alpha". Implies `self.dtype` and
+        `self.batch_shape`, i.e.,
+        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
+      concentration0: Positive floating-point `Tensor` indicating mean
+        number of failures; aka "beta". Otherwise has same semantics as
+        `concentration1`.
+      validate_args: Python `bool`, default `False`. When `True` distribution
+        parameters are checked for validity despite possibly degrading runtime
+        performance. When `False` invalid inputs may silently render incorrect
+        outputs.
+      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
+        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+        result is undefined. When `False`, an exception is raised if one or
+        more of the statistic's batch members are undefined.
+      name: Python `str` name prefixed to Ops created by this class.
+    """
+    parameters = dict(locals())
+    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
+      self._concentration1 = self._maybe_assert_valid_concentration(
+          ops.convert_to_tensor(concentration1, name="concentration1"),
+          validate_args)
+      self._concentration0 = self._maybe_assert_valid_concentration(
+          ops.convert_to_tensor(concentration0, name="concentration0"),
+          validate_args)
+      check_ops.assert_same_float_dtype([
+          self._concentration1, self._concentration0])
+      self._total_concentration = self._concentration1 + self._concentration0
+    super(Beta, self).__init__(
+        dtype=self._total_concentration.dtype,
+        validate_args=validate_args,
+        allow_nan_stats=allow_nan_stats,
+        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
+        parameters=parameters,
+        graph_parents=[self._concentration1,
+                       self._concentration0,
+                       self._total_concentration],
+        name=name)
+
+  @staticmethod
+  def _param_shapes(sample_shape):
+    return dict(zip(
+        ["concentration1", "concentration0"],
+        [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))
+
+  @property
+  def concentration1(self):
+    """Concentration parameter associated with a `1` outcome."""
+    return self._concentration1
+
+  @property
+  def concentration0(self):
+    """Concentration parameter associated with a `0` outcome."""
+    return self._concentration0
+
+  @property
+  def total_concentration(self):
+    """Sum of concentration parameters."""
+    return self._total_concentration
+
+  def _batch_shape_tensor(self):
+    return array_ops.shape(self.total_concentration)
+
+  def _batch_shape(self):
+    return self.total_concentration.get_shape()
+
+  def _event_shape_tensor(self):
+    return constant_op.constant([], dtype=dtypes.int32)
+
+  def _event_shape(self):
+    return tensor_shape.TensorShape([])
+
+  def _sample_n(self, n, seed=None):
+    expanded_concentration1 = array_ops.ones_like(
+        self.total_concentration, dtype=self.dtype) * self.concentration1
+    expanded_concentration0 = array_ops.ones_like(
+        self.total_concentration, dtype=self.dtype) * self.concentration0
+    gamma1_sample = random_ops.random_gamma(
|
| 249 |
+
shape=[n],
|
| 250 |
+
alpha=expanded_concentration1,
|
| 251 |
+
dtype=self.dtype,
|
| 252 |
+
seed=seed)
|
| 253 |
+
gamma2_sample = random_ops.random_gamma(
|
| 254 |
+
shape=[n],
|
| 255 |
+
alpha=expanded_concentration0,
|
| 256 |
+
dtype=self.dtype,
|
| 257 |
+
seed=distribution_util.gen_new_seed(seed, "beta"))
|
| 258 |
+
beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
|
| 259 |
+
return beta_sample
|
| 260 |
+
|
| 261 |
+
@distribution_util.AppendDocstring(_beta_sample_note)
|
| 262 |
+
def _log_prob(self, x):
|
| 263 |
+
return self._log_unnormalized_prob(x) - self._log_normalization()
|
| 264 |
+
|
| 265 |
+
@distribution_util.AppendDocstring(_beta_sample_note)
|
| 266 |
+
def _prob(self, x):
|
| 267 |
+
return math_ops.exp(self._log_prob(x))
|
| 268 |
+
|
| 269 |
+
@distribution_util.AppendDocstring(_beta_sample_note)
|
| 270 |
+
def _log_cdf(self, x):
|
| 271 |
+
return math_ops.log(self._cdf(x))
|
| 272 |
+
|
| 273 |
+
@distribution_util.AppendDocstring(_beta_sample_note)
|
| 274 |
+
def _cdf(self, x):
|
| 275 |
+
return math_ops.betainc(self.concentration1, self.concentration0, x)
|
| 276 |
+
|
| 277 |
+
def _log_unnormalized_prob(self, x):
|
| 278 |
+
x = self._maybe_assert_valid_sample(x)
|
| 279 |
+
return (math_ops.xlogy(self.concentration1 - 1., x) +
|
| 280 |
+
(self.concentration0 - 1.) * math_ops.log1p(-x)) # pylint: disable=invalid-unary-operand-type
|
| 281 |
+
|
| 282 |
+
def _log_normalization(self):
|
| 283 |
+
return (math_ops.lgamma(self.concentration1)
|
| 284 |
+
+ math_ops.lgamma(self.concentration0)
|
| 285 |
+
- math_ops.lgamma(self.total_concentration))
|
| 286 |
+
|
| 287 |
+
def _entropy(self):
|
| 288 |
+
return (
|
| 289 |
+
self._log_normalization()
|
| 290 |
+
- (self.concentration1 - 1.) * math_ops.digamma(self.concentration1)
|
| 291 |
+
- (self.concentration0 - 1.) * math_ops.digamma(self.concentration0)
|
| 292 |
+
+ ((self.total_concentration - 2.) *
|
| 293 |
+
math_ops.digamma(self.total_concentration)))
|
| 294 |
+
|
| 295 |
+
def _mean(self):
|
| 296 |
+
return self._concentration1 / self._total_concentration
|
| 297 |
+
|
| 298 |
+
def _variance(self):
|
| 299 |
+
return self._mean() * (1. - self._mean()) / (1. + self.total_concentration)
|
| 300 |
+
|
| 301 |
+
@distribution_util.AppendDocstring(
|
| 302 |
+
"""Note: The mode is undefined when `concentration1 <= 1` or
|
| 303 |
+
`concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN`
|
| 304 |
+
is used for undefined modes. If `self.allow_nan_stats` is `False` an
|
| 305 |
+
exception is raised when one or more modes are undefined.""")
|
| 306 |
+
def _mode(self):
|
| 307 |
+
mode = (self.concentration1 - 1.) / (self.total_concentration - 2.)
|
| 308 |
+
if self.allow_nan_stats:
|
| 309 |
+
nan = array_ops.fill(
|
| 310 |
+
self.batch_shape_tensor(),
|
| 311 |
+
np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
|
| 312 |
+
name="nan")
|
| 313 |
+
is_defined = math_ops.logical_and(self.concentration1 > 1.,
|
| 314 |
+
self.concentration0 > 1.)
|
| 315 |
+
return array_ops.where_v2(is_defined, mode, nan)
|
| 316 |
+
return control_flow_ops.with_dependencies([
|
| 317 |
+
check_ops.assert_less(
|
| 318 |
+
array_ops.ones([], dtype=self.dtype),
|
| 319 |
+
self.concentration1,
|
| 320 |
+
message="Mode undefined for concentration1 <= 1."),
|
| 321 |
+
check_ops.assert_less(
|
| 322 |
+
array_ops.ones([], dtype=self.dtype),
|
| 323 |
+
self.concentration0,
|
| 324 |
+
message="Mode undefined for concentration0 <= 1.")
|
| 325 |
+
], mode)
|
| 326 |
+
|
| 327 |
+
def _maybe_assert_valid_concentration(self, concentration, validate_args):
|
| 328 |
+
"""Checks the validity of a concentration parameter."""
|
| 329 |
+
if not validate_args:
|
| 330 |
+
return concentration
|
| 331 |
+
return control_flow_ops.with_dependencies([
|
| 332 |
+
check_ops.assert_positive(
|
| 333 |
+
concentration,
|
| 334 |
+
message="Concentration parameter must be positive."),
|
| 335 |
+
], concentration)
|
| 336 |
+
|
| 337 |
+
def _maybe_assert_valid_sample(self, x):
|
| 338 |
+
"""Checks the validity of a sample."""
|
| 339 |
+
if not self.validate_args:
|
| 340 |
+
return x
|
| 341 |
+
return control_flow_ops.with_dependencies([
|
| 342 |
+
check_ops.assert_positive(x, message="sample must be positive"),
|
| 343 |
+
check_ops.assert_less(
|
| 344 |
+
x,
|
| 345 |
+
array_ops.ones([], self.dtype),
|
| 346 |
+
message="sample must be less than `1`."),
|
| 347 |
+
], x)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
class BetaWithSoftplusConcentration(Beta):
|
| 351 |
+
"""Beta with softplus transform of `concentration1` and `concentration0`."""
|
| 352 |
+
|
| 353 |
+
@deprecation.deprecated(
|
| 354 |
+
"2019-01-01",
|
| 355 |
+
"Use `tfd.Beta(tf.nn.softplus(concentration1), "
|
| 356 |
+
"tf.nn.softplus(concentration2))` instead.",
|
| 357 |
+
warn_once=True)
|
| 358 |
+
def __init__(self,
|
| 359 |
+
concentration1,
|
| 360 |
+
concentration0,
|
| 361 |
+
validate_args=False,
|
| 362 |
+
allow_nan_stats=True,
|
| 363 |
+
name="BetaWithSoftplusConcentration"):
|
| 364 |
+
parameters = dict(locals())
|
| 365 |
+
with ops.name_scope(name, values=[concentration1,
|
| 366 |
+
concentration0]) as name:
|
| 367 |
+
super(BetaWithSoftplusConcentration, self).__init__(
|
| 368 |
+
concentration1=nn.softplus(concentration1,
|
| 369 |
+
name="softplus_concentration1"),
|
| 370 |
+
concentration0=nn.softplus(concentration0,
|
| 371 |
+
name="softplus_concentration0"),
|
| 372 |
+
validate_args=validate_args,
|
| 373 |
+
allow_nan_stats=allow_nan_stats,
|
| 374 |
+
name=name)
|
| 375 |
+
self._parameters = parameters
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
@kullback_leibler.RegisterKL(Beta, Beta)
|
| 379 |
+
def _kl_beta_beta(d1, d2, name=None):
|
| 380 |
+
"""Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta.
|
| 381 |
+
|
| 382 |
+
Args:
|
| 383 |
+
d1: instance of a Beta distribution object.
|
| 384 |
+
d2: instance of a Beta distribution object.
|
| 385 |
+
name: (optional) Name to use for created operations.
|
| 386 |
+
default is "kl_beta_beta".
|
| 387 |
+
|
| 388 |
+
Returns:
|
| 389 |
+
Batchwise KL(d1 || d2)
|
| 390 |
+
"""
|
| 391 |
+
def delta(fn, is_property=True):
|
| 392 |
+
fn1 = getattr(d1, fn)
|
| 393 |
+
fn2 = getattr(d2, fn)
|
| 394 |
+
return (fn2 - fn1) if is_property else (fn2() - fn1())
|
| 395 |
+
with ops.name_scope(name, "kl_beta_beta", values=[
|
| 396 |
+
d1.concentration1,
|
| 397 |
+
d1.concentration0,
|
| 398 |
+
d1.total_concentration,
|
| 399 |
+
d2.concentration1,
|
| 400 |
+
d2.concentration0,
|
| 401 |
+
d2.total_concentration,
|
| 402 |
+
]):
|
| 403 |
+
return (delta("_log_normalization", is_property=False)
|
| 404 |
+
- math_ops.digamma(d1.concentration1) * delta("concentration1")
|
| 405 |
+
- math_ops.digamma(d1.concentration0) * delta("concentration0")
|
| 406 |
+
+ (math_ops.digamma(d1.total_concentration)
|
| 407 |
+
* delta("total_concentration")))
|
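`_sample_n` above draws Beta samples via the Gamma-ratio construction: if `G1 ~ Gamma(a, 1)` and `G2 ~ Gamma(b, 1)`, then `G1 / (G1 + G2) ~ Beta(a, b)`. A minimal NumPy sketch of the same idea (NumPy's generator stands in for `random_ops.random_gamma`; this is an illustrative translation, not part of the vendored file):

```python
import numpy as np

def sample_beta(concentration1, concentration0, n, seed=None):
    # Gamma-ratio construction mirrored from Beta._sample_n:
    # G1 ~ Gamma(a), G2 ~ Gamma(b)  =>  G1 / (G1 + G2) ~ Beta(a, b).
    rng = np.random.default_rng(seed)
    gamma1 = rng.gamma(shape=concentration1, size=n)
    gamma2 = rng.gamma(shape=concentration0, size=n)
    return gamma1 / (gamma1 + gamma2)

samples = sample_beta(2.0, 3.0, n=100_000, seed=0)
# Empirical mean approaches concentration1 / total_concentration = 2 / 5.
```

As in `_mean` above, the sample mean converges to `a / (a + b)`.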
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/bijector.py
ADDED
@@ -0,0 +1,21 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bijector base."""

# go/tf-wildcard-import
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops.distributions.bijector_impl import Bijector

# pylint: enable=wildcard-import,unused-import
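The `Bijector` re-exported here packages the change-of-variables rule `log p_Y(y) = log p_X(g^{-1}(y)) + log|det J_{g^{-1}}(y)|`. A hypothetical pure-`math` sketch of that rule for the scalar `exp` transform, where a standard normal base becomes a log-normal (illustrative only, not the TensorFlow API):

```python
import math

def normal_log_prob(x):
    # Log-density of the standard normal base distribution.
    return -0.5 * (x * x + math.log(2.0 * math.pi))

def exp_forward(x):
    return math.exp(x)

def exp_inverse(y):
    return math.log(y)

def exp_inverse_log_det_jacobian(y):
    # d/dy log(y) = 1/y, so log|det J| = -log(y).
    return -math.log(y)

def transformed_log_prob(y):
    # log p_Y(y) = log p_X(g^{-1}(y)) + log|det J_{g^{-1}}(y)|
    return normal_log_prob(exp_inverse(y)) + exp_inverse_log_det_jacobian(y)
```

At `y = 1` the Jacobian term vanishes, so `transformed_log_prob(1.0)` equals the base log-density at `0`.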
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/bijector_impl.py
ADDED
@@ -0,0 +1,1113 @@
| 1 |
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Bijector base."""
|
| 16 |
+
|
| 17 |
+
import abc
|
| 18 |
+
import collections
|
| 19 |
+
import contextlib
|
| 20 |
+
import re
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
from tensorflow.python.framework import dtypes
|
| 25 |
+
from tensorflow.python.framework import ops
|
| 26 |
+
from tensorflow.python.framework import tensor_shape
|
| 27 |
+
from tensorflow.python.framework import tensor_util
|
| 28 |
+
from tensorflow.python.ops import array_ops
|
| 29 |
+
from tensorflow.python.ops import check_ops
|
| 30 |
+
from tensorflow.python.ops import math_ops
|
| 31 |
+
from tensorflow.python.ops.distributions import util as distribution_util
|
| 32 |
+
from tensorflow.python.util import object_identity
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
__all__ = [
|
| 36 |
+
"Bijector",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class _Mapping(collections.namedtuple(
|
| 41 |
+
"_Mapping", ["x", "y", "ildj_map", "kwargs"])):
|
| 42 |
+
"""Helper class to make it easier to manage caching in `Bijector`."""
|
| 43 |
+
|
| 44 |
+
def __new__(cls, x=None, y=None, ildj_map=None, kwargs=None):
|
| 45 |
+
"""Custom __new__ so namedtuple items have defaults.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
x: `Tensor`. Forward.
|
| 49 |
+
y: `Tensor`. Inverse.
|
| 50 |
+
ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor`
|
| 51 |
+
representing the inverse log det jacobian.
|
| 52 |
+
kwargs: Python dictionary. Extra args supplied to
|
| 53 |
+
forward/inverse/etc functions.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
mapping: New instance of _Mapping.
|
| 57 |
+
"""
|
| 58 |
+
return super(_Mapping, cls).__new__(cls, x, y, ildj_map, kwargs)
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def x_key(self):
|
| 62 |
+
"""Returns key used for caching Y=g(X)."""
|
| 63 |
+
return ((object_identity.Reference(self.x),) +
|
| 64 |
+
self._deep_tuple(tuple(sorted(self.kwargs.items()))))
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def y_key(self):
|
| 68 |
+
"""Returns key used for caching X=g^{-1}(Y)."""
|
| 69 |
+
return ((object_identity.Reference(self.y),) +
|
| 70 |
+
self._deep_tuple(tuple(sorted(self.kwargs.items()))))
|
| 71 |
+
|
| 72 |
+
def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None):
|
| 73 |
+
"""Returns new _Mapping with args merged with self.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
x: `Tensor`. Forward.
|
| 77 |
+
y: `Tensor`. Inverse.
|
| 78 |
+
ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor`
|
| 79 |
+
representing the inverse log det jacobian.
|
| 80 |
+
kwargs: Python dictionary. Extra args supplied to
|
| 81 |
+
forward/inverse/etc functions.
|
| 82 |
+
mapping: Instance of _Mapping to merge. Can only be specified if no other
|
| 83 |
+
arg is specified.
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
mapping: New instance of `_Mapping` which has inputs merged with self.
|
| 87 |
+
|
| 88 |
+
Raises:
|
| 89 |
+
ValueError: if mapping and any other arg is not `None`.
|
| 90 |
+
"""
|
| 91 |
+
if mapping is None:
|
| 92 |
+
mapping = _Mapping(x=x, y=y, ildj_map=ildj_map, kwargs=kwargs)
|
| 93 |
+
elif any(arg is not None for arg in [x, y, ildj_map, kwargs]):
|
| 94 |
+
raise ValueError("Cannot simultaneously specify mapping and individual "
|
| 95 |
+
"arguments.")
|
| 96 |
+
|
| 97 |
+
return _Mapping(
|
| 98 |
+
x=self._merge(self.x, mapping.x),
|
| 99 |
+
y=self._merge(self.y, mapping.y),
|
| 100 |
+
ildj_map=self._merge_dicts(self.ildj_map, mapping.ildj_map),
|
| 101 |
+
kwargs=self._merge(self.kwargs, mapping.kwargs))
|
| 102 |
+
|
| 103 |
+
def _merge_dicts(self, old=None, new=None):
|
| 104 |
+
"""Helper to merge two dictionaries."""
|
| 105 |
+
old = {} if old is None else old
|
| 106 |
+
new = {} if new is None else new
|
| 107 |
+
for k, v in new.items():
|
| 108 |
+
val = old.get(k, None)
|
| 109 |
+
if val is not None and val is not v:
|
| 110 |
+
raise ValueError("Found different value for existing key "
|
| 111 |
+
"(key:{} old_value:{} new_value:{}".format(
|
| 112 |
+
k, old[k], v))
|
| 113 |
+
old[k] = v
|
| 114 |
+
return old
|
| 115 |
+
|
| 116 |
+
def _merge(self, old, new):
|
| 117 |
+
"""Helper to merge which handles merging one value."""
|
| 118 |
+
if old is None:
|
| 119 |
+
return new
|
| 120 |
+
elif new is not None and old is not new:
|
| 121 |
+
raise ValueError("Incompatible values: %s != %s" % (old, new))
|
| 122 |
+
return old
|
| 123 |
+
|
| 124 |
+
def _deep_tuple(self, x):
|
| 125 |
+
"""Converts lists of lists to tuples of tuples."""
|
| 126 |
+
return (tuple(map(self._deep_tuple, x))
|
| 127 |
+
if isinstance(x, (list, tuple)) else x)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class Bijector(metaclass=abc.ABCMeta):
|
| 131 |
+
r"""Interface for transformations of a `Distribution` sample.
|
| 132 |
+
|
| 133 |
+
Bijectors can be used to represent any differentiable and injective
|
| 134 |
+
(one to one) function defined on an open subset of `R^n`. Some non-injective
|
| 135 |
+
transformations are also supported (see "Non Injective Transforms" below).
|
| 136 |
+
|
| 137 |
+
#### Mathematical Details
|
| 138 |
+
|
| 139 |
+
A `Bijector` implements a [smooth covering map](
|
| 140 |
+
https://en.wikipedia.org/wiki/Local_diffeomorphism), i.e., a local
|
| 141 |
+
diffeomorphism such that every point in the target has a neighborhood evenly
|
| 142 |
+
covered by a map ([see also](
|
| 143 |
+
https://en.wikipedia.org/wiki/Covering_space#Covering_of_a_manifold)).
|
| 144 |
+
A `Bijector` is used by `TransformedDistribution` but can be generally used
|
| 145 |
+
for transforming a `Distribution` generated `Tensor`. A `Bijector` is
|
| 146 |
+
characterized by three operations:
|
| 147 |
+
|
| 148 |
+
1. Forward
|
| 149 |
+
|
| 150 |
+
Useful for turning one random outcome into another random outcome from a
|
| 151 |
+
different distribution.
|
| 152 |
+
|
| 153 |
+
2. Inverse
|
| 154 |
+
|
| 155 |
+
Useful for "reversing" a transformation to compute one probability in
|
| 156 |
+
terms of another.
|
| 157 |
+
|
| 158 |
+
3. `log_det_jacobian(x)`
|
| 159 |
+
|
| 160 |
+
"The log of the absolute value of the determinant of the matrix of all
|
| 161 |
+
first-order partial derivatives of the inverse function."
|
| 162 |
+
|
| 163 |
+
Useful for inverting a transformation to compute one probability in terms
|
| 164 |
+
of another. Geometrically, the Jacobian determinant is the volume of the
|
| 165 |
+
transformation and is used to scale the probability.
|
| 166 |
+
|
| 167 |
+
We take the absolute value of the determinant before log to avoid NaN
|
| 168 |
+
values. Geometrically, a negative determinant corresponds to an
|
| 169 |
+
orientation-reversing transformation. It is ok for us to discard the sign
|
| 170 |
+
of the determinant because we only integrate everywhere-nonnegative
|
| 171 |
+
functions (probability densities) and the correct orientation is always the
|
| 172 |
+
one that produces a nonnegative integrand.
|
| 173 |
+
|
| 174 |
+
By convention, transformations of random variables are named in terms of the
|
| 175 |
+
forward transformation. The forward transformation creates samples, the
|
| 176 |
+
inverse is useful for computing probabilities.
|
| 177 |
+
|
| 178 |
+
#### Example Uses
|
| 179 |
+
|
| 180 |
+
- Basic properties:
|
| 181 |
+
|
| 182 |
+
```python
|
| 183 |
+
x = ... # A tensor.
|
| 184 |
+
# Evaluate forward transformation.
|
| 185 |
+
fwd_x = my_bijector.forward(x)
|
| 186 |
+
x == my_bijector.inverse(fwd_x)
|
| 187 |
+
x != my_bijector.forward(fwd_x) # Not equal because x != g(g(x)).
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
- Computing a log-likelihood:
|
| 191 |
+
|
| 192 |
+
```python
|
| 193 |
+
def transformed_log_prob(bijector, log_prob, x):
|
| 194 |
+
return (bijector.inverse_log_det_jacobian(x, event_ndims=0) +
|
| 195 |
+
log_prob(bijector.inverse(x)))
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
- Transforming a random outcome:
|
| 199 |
+
|
| 200 |
+
```python
|
| 201 |
+
def transformed_sample(bijector, x):
|
| 202 |
+
return bijector.forward(x)
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
#### Example Bijectors
|
| 206 |
+
|
| 207 |
+
- "Exponential"
|
| 208 |
+
|
| 209 |
+
```none
|
| 210 |
+
Y = g(X) = exp(X)
|
| 211 |
+
X ~ Normal(0, 1) # Univariate.
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
Implies:
|
| 215 |
+
|
| 216 |
+
```none
|
| 217 |
+
g^{-1}(Y) = log(Y)
|
| 218 |
+
|Jacobian(g^{-1})(y)| = 1 / y
|
| 219 |
+
Y ~ LogNormal(0, 1), i.e.,
|
| 220 |
+
prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
|
| 221 |
+
= (1 / y) Normal(log(y); 0, 1)
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
Here is an example of how one might implement the `Exp` bijector:
|
| 225 |
+
|
| 226 |
+
```python
|
| 227 |
+
class Exp(Bijector):
|
| 228 |
+
|
| 229 |
+
def __init__(self, validate_args=False, name="exp"):
|
| 230 |
+
super(Exp, self).__init__(
|
| 231 |
+
validate_args=validate_args,
|
| 232 |
+
forward_min_event_ndims=0,
|
| 233 |
+
name=name)
|
| 234 |
+
|
| 235 |
+
def _forward(self, x):
|
| 236 |
+
return math_ops.exp(x)
|
| 237 |
+
|
| 238 |
+
def _inverse(self, y):
|
| 239 |
+
return math_ops.log(y)
|
| 240 |
+
|
| 241 |
+
def _inverse_log_det_jacobian(self, y):
|
| 242 |
+
return -self._forward_log_det_jacobian(self._inverse(y))
|
| 243 |
+
|
| 244 |
+
def _forward_log_det_jacobian(self, x):
|
| 245 |
+
# Notice that we needn't do any reducing, even when`event_ndims > 0`.
|
| 246 |
+
# The base Bijector class will handle reducing for us; it knows how
|
| 247 |
+
# to do so because we called `super` `__init__` with
|
| 248 |
+
# `forward_min_event_ndims = 0`.
|
| 249 |
+
return x
|
| 250 |
+
```
|
| 251 |
+
|
| 252 |
+
- "Affine"
|
| 253 |
+
|
| 254 |
+
```none
|
| 255 |
+
Y = g(X) = sqrtSigma * X + mu
|
| 256 |
+
X ~ MultivariateNormal(0, I_d)
|
| 257 |
+
```
|
| 258 |
+
|
| 259 |
+
Implies:
|
| 260 |
+
|
| 261 |
+
```none
|
| 262 |
+
g^{-1}(Y) = inv(sqrtSigma) * (Y - mu)
|
| 263 |
+
|Jacobian(g^{-1})(y)| = det(inv(sqrtSigma))
|
| 264 |
+
Y ~ MultivariateNormal(mu, sqrtSigma) , i.e.,
|
| 265 |
+
prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
|
| 266 |
+
= det(sqrtSigma)^(-d) *
|
| 267 |
+
MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
#### Min_event_ndims and Naming
|
| 271 |
+
|
| 272 |
+
Bijectors are named for the dimensionality of data they act on (i.e. without
|
| 273 |
+
broadcasting). We can think of bijectors having an intrinsic `min_event_ndims`
|
| 274 |
+
, which is the minimum number of dimensions for the bijector act on. For
|
| 275 |
+
instance, a Cholesky decomposition requires a matrix, and hence
|
| 276 |
+
`min_event_ndims=2`.
|
| 277 |
+
|
| 278 |
+
Some examples:
|
| 279 |
+
|
| 280 |
+
`AffineScalar: min_event_ndims=0`
|
| 281 |
+
`Affine: min_event_ndims=1`
|
| 282 |
+
`Cholesky: min_event_ndims=2`
|
| 283 |
+
`Exp: min_event_ndims=0`
|
| 284 |
+
`Sigmoid: min_event_ndims=0`
|
| 285 |
+
`SoftmaxCentered: min_event_ndims=1`
|
| 286 |
+
|
| 287 |
+
Note the difference between `Affine` and `AffineScalar`. `AffineScalar`
|
| 288 |
+
operates on scalar events, whereas `Affine` operates on vector-valued events.
|
| 289 |
+
|
| 290 |
+
More generally, there is a `forward_min_event_ndims` and an
|
| 291 |
+
`inverse_min_event_ndims`. In most cases, these will be the same.
|
| 292 |
+
However, for some shape changing bijectors, these will be different
|
| 293 |
+
(e.g. a bijector which pads an extra dimension at the end, might have
|
| 294 |
+
`forward_min_event_ndims=0` and `inverse_min_event_ndims=1`.
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
#### Jacobian Determinant
|
| 298 |
+
|
| 299 |
+
The Jacobian determinant is a reduction over `event_ndims - min_event_ndims`
|
| 300 |
+
(`forward_min_event_ndims` for `forward_log_det_jacobian` and
|
| 301 |
+
`inverse_min_event_ndims` for `inverse_log_det_jacobian`).
|
| 302 |
+
To see this, consider the `Exp` `Bijector` applied to a `Tensor` which has
|
| 303 |
+
sample, batch, and event (S, B, E) shape semantics. Suppose the `Tensor`'s
|
| 304 |
+
partitioned-shape is `(S=[4], B=[2], E=[3, 3])`. The shape of the `Tensor`
|
| 305 |
+
returned by `forward` and `inverse` is unchanged, i.e., `[4, 2, 3, 3]`.
|
| 306 |
+
However the shape returned by `inverse_log_det_jacobian` is `[4, 2]` because
|
| 307 |
+
the Jacobian determinant is a reduction over the event dimensions.
|
| 308 |
+
|
| 309 |
+
Another example is the `Affine` `Bijector`. Because `min_event_ndims = 1`, the
|
| 310 |
+
Jacobian determinant reduction is over `event_ndims - 1`.

It is sometimes useful to implement the inverse Jacobian determinant as the
negative forward Jacobian determinant. For example,

```python
def _inverse_log_det_jacobian(self, y):
  return -self._forward_log_det_jac(self._inverse(y))  # Note negation.
```

The correctness of this approach can be seen from the following claim.

- Claim:

    Assume `Y = g(X)` is a bijection whose derivative exists and is nonzero
    for its domain, i.e., `dY/dX = d/dX g(X) != 0`. Then:

    ```none
    (log o det o jacobian o g^{-1})(Y) = -(log o det o jacobian o g)(X)
    ```

- Proof:

    From the bijective, nonzero differentiability of `g`, the
    [inverse function theorem](
    https://en.wikipedia.org/wiki/Inverse_function_theorem)
    implies `g^{-1}` is differentiable in the image of `g`.
    Applying the chain rule to `y = g(x) = g(g^{-1}(y))` yields
    `I = g'(g^{-1}(y))*g^{-1}'(y)`.
    The same theorem also implies `g^{-1}'` is non-singular, therefore
    `inv[ g'(g^{-1}(y)) ] = g^{-1}'(y)`.
    The claim follows from [properties of determinant](
    https://en.wikipedia.org/wiki/Determinant#Multiplicativity_and_matrix_groups).

Generally it's preferable to directly implement the inverse Jacobian
determinant. This should have superior numerical stability and will often
share subgraphs with the `_inverse` implementation.
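The claim is easy to check numerically. A minimal sketch in plain Python (no
TensorFlow), assuming the bijection `g(x) = exp(x)`, for which
`fldj(x) = log|g'(x)| = x` and `ildj(y) = log|(g^{-1})'(y)| = -log(y)`:

```python
import math

def forward_log_det_jacobian(x):
  return x                # log|d/dx exp(x)| = x

def inverse(y):
  return math.log(y)      # g^{-1}(y) = log(y)

def inverse_log_det_jacobian(y):
  return -math.log(y)     # log|d/dy log(y)| = -log(y)

y = 2.5
# The identity from the claim: ildj(y) == -fldj(g^{-1}(y)).
lhs = inverse_log_det_jacobian(y)
rhs = -forward_log_det_jacobian(inverse(y))
assert abs(lhs - rhs) < 1e-12
```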

#### Is_constant_jacobian

Certain bijectors will have constant jacobian matrices. For instance, the
`Affine` bijector encodes multiplication by a matrix plus a shift, and its
jacobian matrix is simply that same matrix.

`is_constant_jacobian` encodes the fact that the jacobian matrix is constant.
The semantics of this argument are the following:

* Repeated calls to "log_det_jacobian" functions with the same
  `event_ndims` (but not necessarily the same input), will return the first
  computed jacobian (because the matrix is constant, and hence is input
  independent).
* `log_det_jacobian` implementations are merely broadcastable to the true
  `log_det_jacobian` (because, again, the jacobian matrix is input
  independent). Specifically, `log_det_jacobian` is implemented as the
  log jacobian determinant for a single input.

```python
class Identity(Bijector):

  def __init__(self, validate_args=False, name="identity"):
    super(Identity, self).__init__(
        is_constant_jacobian=True,
        validate_args=validate_args,
        forward_min_event_ndims=0,
        name=name)

  def _forward(self, x):
    return x

  def _inverse(self, y):
    return y

  def _inverse_log_det_jacobian(self, y):
    return -self._forward_log_det_jacobian(self._inverse(y))

  def _forward_log_det_jacobian(self, x):
    # The full log jacobian determinant would be array_ops.zeros_like(x).
    # However, we circumvent materializing that, since the jacobian
    # calculation is input independent, and we specify it for one input.
    return constant_op.constant(0., x.dtype.base_dtype)
```

#### Subclass Requirements

- Subclasses typically implement:

  - `_forward`,
  - `_inverse`,
  - `_inverse_log_det_jacobian`,
  - `_forward_log_det_jacobian` (optional).

  The `_forward_log_det_jacobian` is called when the bijector is inverted via
  the `Invert` bijector. If undefined, a slightly less efficient
  calculation, `-1 * _inverse_log_det_jacobian`, is used.

  If the bijector changes the shape of the input, you must also implement:

    - _forward_event_shape_tensor,
    - _forward_event_shape (optional),
    - _inverse_event_shape_tensor,
    - _inverse_event_shape (optional).

  By default the event-shape is assumed unchanged from input.

- If the `Bijector`'s use is limited to `TransformedDistribution` (or friends
  like `QuantizedDistribution`) then depending on your use, you may not need
  to implement all of the `_forward` and `_inverse` functions.

  Examples:

  1. Sampling (e.g., `sample`) only requires `_forward`.
  2. Probability functions (e.g., `prob`, `cdf`, `survival`) only require
     `_inverse` (and related).
  3. Only calling probability functions on the output of `sample` means
     `_inverse` can be implemented as a cache lookup.
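Point 3 can be sketched in plain Python (a hypothetical toy bijector, not the
caching machinery of this class): `forward` records each `(x, y)` pair, so
inverting a sampled output is a dictionary lookup rather than an analytic
inversion:

```python
import math

class CachingExp(object):
  """Toy sketch: y = exp(x), with `inverse` served from a cache."""

  def __init__(self):
    self._x_from_y = {}

  def forward(self, x):
    y = math.exp(x)
    self._x_from_y[y] = x   # remember the preimage
    return y

  def inverse(self, y):
    # Exact cache hit for any y produced by `forward`; no log() needed.
    return self._x_from_y[y]

bij = CachingExp()
y = bij.forward(1.5)
assert bij.inverse(y) == 1.5
```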

  See "Example Uses" [above] which shows how these functions are used to
  transform a distribution. (Note: `_forward` could theoretically be
  implemented as a cache lookup but this would require controlling the
  underlying sample generation mechanism.)

#### Non Injective Transforms

**WARNING** Handling of non-injective transforms is subject to change.

Non injective maps `g` are supported, provided their domain `D` can be
partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
ignoring sets of measure zero, the restriction of `g` to each subset is a
differentiable bijection onto `g(D)`. In particular, this implies that for
`y in g(D)`, the set inverse, i.e. `g^{-1}(y) = {x in D : g(x) = y}`, always
contains exactly `k` distinct points.

The property `_is_injective` is set to `False` to indicate that the bijector
is not injective, yet satisfies the above condition.

The usual bijector API is modified in the case `_is_injective is False` (see
method docstrings for specifics). Here we show by example the `AbsoluteValue`
bijector. In this case, the domain `D = (-inf, inf)` can be partitioned
into `D1 = (-inf, 0)`, `D2 = {0}`, and `D3 = (0, inf)`. Let `gi` be the
restriction of `g` to `Di`, then both `g1` and `g3` are bijections onto
`(0, inf)`, with `g1^{-1}(y) = -y`, and `g3^{-1}(y) = y`. We will use
`g1` and `g3` to define bijector methods over `D1` and `D3`. `D2 = {0}` is
an oddball in that `g2` is one to one, and the derivative is not well defined.
Fortunately, when considering transformations of probability densities
(e.g. in `TransformedDistribution`), sets of measure zero have no effect in
theory, and only a small effect in 32 or 64 bit precision. For that reason,
we define `inverse(0)` and `inverse_log_det_jacobian(0)` both as `[0, 0]`,
which is convenient and results in a left-semicontinuous pdf.

```python
abs = tfp.distributions.bijectors.AbsoluteValue()

abs.forward(-1.)
==> 1.

abs.forward(1.)
==> 1.

abs.inverse(1.)
==> (-1., 1.)

# The |dX/dY| is constant, == 1. So Log|dX/dY| == 0.
abs.inverse_log_det_jacobian(1., event_ndims=0)
==> (0., 0.)

# Special case handling of 0.
abs.inverse(0.)
==> (0., 0.)

abs.inverse_log_det_jacobian(0., event_ndims=0)
==> (0., 0.)
```
"""

  @abc.abstractmethod
  def __init__(self,
               graph_parents=None,
               is_constant_jacobian=False,
               validate_args=False,
               dtype=None,
               forward_min_event_ndims=None,
               inverse_min_event_ndims=None,
               name=None):
    """Constructs Bijector.

    A `Bijector` transforms random variables into new random variables.

    Examples:

    ```python
    # Create the Y = g(X) = X transform.
    identity = Identity()

    # Create the Y = g(X) = exp(X) transform.
    exp = Exp()
    ```

    See `Bijector` subclass docstring for more details and specific examples.

    Args:
      graph_parents: Python list of graph prerequisites of this `Bijector`.
      is_constant_jacobian: Python `bool` indicating that the Jacobian matrix
        is not a function of the input.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
        enforced.
      forward_min_event_ndims: Python `integer` indicating the minimum number
        of dimensions `forward` operates on.
      inverse_min_event_ndims: Python `integer` indicating the minimum number
        of dimensions `inverse` operates on. Will be set to
        `forward_min_event_ndims` by default, if no value is provided.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: If neither `forward_min_event_ndims` nor
        `inverse_min_event_ndims` is specified, or if either of them is
        negative.
      ValueError: If a member of `graph_parents` is not a `Tensor`.
    """
    self._graph_parents = graph_parents or []

    if forward_min_event_ndims is None and inverse_min_event_ndims is None:
      raise ValueError("Must specify at least one of `forward_min_event_ndims` "
                       "and `inverse_min_event_ndims`.")
    elif inverse_min_event_ndims is None:
      inverse_min_event_ndims = forward_min_event_ndims
    elif forward_min_event_ndims is None:
      forward_min_event_ndims = inverse_min_event_ndims

    if not isinstance(forward_min_event_ndims, int):
      raise TypeError("Expected forward_min_event_ndims to be of "
                      "type int, got {}".format(
                          type(forward_min_event_ndims).__name__))

    if not isinstance(inverse_min_event_ndims, int):
      raise TypeError("Expected inverse_min_event_ndims to be of "
                      "type int, got {}".format(
                          type(inverse_min_event_ndims).__name__))

    if forward_min_event_ndims < 0:
      raise ValueError("forward_min_event_ndims must be a non-negative "
                       "integer.")
    if inverse_min_event_ndims < 0:
      raise ValueError("inverse_min_event_ndims must be a non-negative "
                       "integer.")

    self._forward_min_event_ndims = forward_min_event_ndims
    self._inverse_min_event_ndims = inverse_min_event_ndims
    self._is_constant_jacobian = is_constant_jacobian
    self._constant_ildj_map = {}
    self._validate_args = validate_args
    self._dtype = dtype
    # These dicts can only be accessed using _Mapping.x_key or _Mapping.y_key
    self._from_y = {}
    self._from_x = {}
    if name:
      self._name = name
    else:
      # We want the default convention to be snake_case rather than CamelCase
      # since `Chain` uses bijector.name as the kwargs dictionary key.
      def camel_to_snake(name):
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
      self._name = camel_to_snake(type(self).__name__.lstrip("_"))

    for i, t in enumerate(self._graph_parents):
      if t is None or not tensor_util.is_tf_type(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))

  @property
  def graph_parents(self):
    """Returns this `Bijector`'s graph_parents as a Python list."""
    return self._graph_parents

  @property
  def forward_min_event_ndims(self):
    """Returns the minimal number of dimensions bijector.forward operates on."""
    return self._forward_min_event_ndims

  @property
  def inverse_min_event_ndims(self):
    """Returns the minimal number of dimensions bijector.inverse operates on."""
    return self._inverse_min_event_ndims

  @property
  def is_constant_jacobian(self):
    """Returns true iff the Jacobian matrix is not a function of x.

    Note: Jacobian matrix is either constant for both forward and inverse or
    neither.

    Returns:
      is_constant_jacobian: Python `bool`.
    """
    return self._is_constant_jacobian

  @property
  def _is_injective(self):
    """Returns true iff the forward map `g` is injective (one-to-one function).

    **WARNING** This hidden property and its behavior are subject to change.

    Note: Non-injective maps `g` are supported, provided their domain `D` can
    be partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
    ignoring sets of measure zero, the restriction of `g` to each subset is a
    differentiable bijection onto `g(D)`.

    Returns:
      is_injective: Python `bool`.
    """
    return True

  @property
  def validate_args(self):
    """Returns True if Tensor arguments will be validated."""
    return self._validate_args

  @property
  def dtype(self):
    """dtype of `Tensor`s transformable by this distribution."""
    return self._dtype

  @property
  def name(self):
    """Returns the string name of this `Bijector`."""
    return self._name

  def _forward_event_shape_tensor(self, input_shape):
    """Subclass implementation for `forward_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape_tensor(self,
                                 input_shape,
                                 name="forward_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      input_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `forward` function.
      name: name to give to the op

    Returns:
      forward_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `forward`.
    """
    with self._name_scope(name, [input_shape]):
      input_shape = ops.convert_to_tensor(input_shape, dtype=dtypes.int32,
                                          name="input_shape")
      return self._forward_event_shape_tensor(input_shape)

  def _forward_event_shape(self, input_shape):
    """Subclass implementation for `forward_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape(self, input_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `forward_event_shape_tensor`. May be only partially defined.

    Args:
      input_shape: `TensorShape` indicating event-portion shape passed into
        `forward` function.

    Returns:
      forward_event_shape_tensor: `TensorShape` indicating event-portion shape
        after applying `forward`. Possibly unknown.
    """
    return self._forward_event_shape(tensor_shape.TensorShape(input_shape))

  def _inverse_event_shape_tensor(self, output_shape):
    """Subclass implementation for `inverse_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return output_shape

  def inverse_event_shape_tensor(self,
                                 output_shape,
                                 name="inverse_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      output_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `inverse` function.
      name: name to give to the op

    Returns:
      inverse_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `inverse`.
    """
    with self._name_scope(name, [output_shape]):
      output_shape = ops.convert_to_tensor(output_shape, dtype=dtypes.int32,
                                           name="output_shape")
      return self._inverse_event_shape_tensor(output_shape)

  def _inverse_event_shape(self, output_shape):
    """Subclass implementation for `inverse_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return tensor_shape.TensorShape(output_shape)

  def inverse_event_shape(self, output_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `inverse_event_shape_tensor`. May be only partially defined.

    Args:
      output_shape: `TensorShape` indicating event-portion shape passed into
        `inverse` function.

    Returns:
      inverse_event_shape_tensor: `TensorShape` indicating event-portion shape
        after applying `inverse`. Possibly unknown.
    """
    return self._inverse_event_shape(output_shape)

  def _forward(self, x):
    """Subclass implementation for `forward` public function."""
    raise NotImplementedError("forward not implemented.")

  def _call_forward(self, x, name, **kwargs):
    with self._name_scope(name, [x]):
      x = ops.convert_to_tensor(x, name="x")
      self._maybe_assert_dtype(x)
      if not self._is_injective:  # No caching for non-injective
        return self._forward(x, **kwargs)
      mapping = self._lookup(x=x, kwargs=kwargs)
      if mapping.y is not None:
        return mapping.y
      mapping = mapping.merge(y=self._forward(x, **kwargs))
      self._cache(mapping)
      return mapping.y

  def forward(self, x, name="forward"):
    """Returns the forward `Bijector` evaluation, i.e., Y = g(X).

    Args:
      x: `Tensor`. The input to the "forward" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_forward` is not implemented.
    """
    return self._call_forward(x, name)

  def _inverse(self, y):
    """Subclass implementation for `inverse` public function."""
    raise NotImplementedError("inverse not implemented")

  def _call_inverse(self, y, name, **kwargs):
    with self._name_scope(name, [y]):
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      if not self._is_injective:  # No caching for non-injective
        return self._inverse(y, **kwargs)
      mapping = self._lookup(y=y, kwargs=kwargs)
      if mapping.x is not None:
        return mapping.x
      mapping = mapping.merge(x=self._inverse(y, **kwargs))
      self._cache(mapping)
      return mapping.x

  def inverse(self, y, name="inverse"):
    """Returns the inverse `Bijector` evaluation, i.e., X = g^{-1}(Y).

    Args:
      y: `Tensor`. The input to the "inverse" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective, returns the k-tuple containing the unique
        `k` points `(x1, ..., xk)` such that `g(xi) = y`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse` is not implemented.
    """
    return self._call_inverse(y, name)

  def _inverse_log_det_jacobian(self, y):
    """Subclass implementation of `inverse_log_det_jacobian` public function.

    In particular, this method differs from the public function, in that it
    does not take `event_ndims`. Thus, this implements the minimal Jacobian
    determinant calculation (i.e. over `inverse_min_event_ndims`).

    Args:
      y: `Tensor`. The input to the "inverse_log_det_jacobian" evaluation.

    Returns:
      inverse_log_det_jacobian: `Tensor`, if this bijector is injective.
        If not injective, returns the k-tuple containing jacobians for the
        unique `k` points `(x1, ..., xk)` such that `g(xi) = y`.
    """
    raise NotImplementedError("inverse_log_det_jacobian not implemented.")

  def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
    with self._name_scope(name, [y]):
      if event_ndims in self._constant_ildj_map:
        return self._constant_ildj_map[event_ndims]
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      with ops.control_dependencies(self._check_valid_event_ndims(
          min_event_ndims=self.inverse_min_event_ndims,
          event_ndims=event_ndims)):
        if not self._is_injective:  # No caching for non-injective
          try:
            ildjs = self._inverse_log_det_jacobian(y, **kwargs)
            return tuple(self._reduce_jacobian_det_over_event(
                y, ildj, self.inverse_min_event_ndims, event_ndims)
                         for ildj in ildjs)
          except NotImplementedError as original_exception:
            try:
              x = self._inverse(y, **kwargs)
              fldjs = self._forward_log_det_jacobian(x, **kwargs)
              return tuple(self._reduce_jacobian_det_over_event(
                  x, -fldj, self.forward_min_event_ndims, event_ndims)
                           for fldj in fldjs)
            except NotImplementedError:
              raise original_exception

        mapping = self._lookup(y=y, kwargs=kwargs)
        if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
          return mapping.ildj_map[event_ndims]
        try:
          x = None  # Not needed; leave cache as is.
          ildj = self._inverse_log_det_jacobian(y, **kwargs)
          ildj = self._reduce_jacobian_det_over_event(
              y, ildj, self.inverse_min_event_ndims, event_ndims)
        except NotImplementedError as original_exception:
          try:
            x = (mapping.x if mapping.x is not None
                 else self._inverse(y, **kwargs))
            ildj = -self._forward_log_det_jacobian(x, **kwargs)
            ildj = self._reduce_jacobian_det_over_event(
                x, ildj, self.forward_min_event_ndims, event_ndims)
          except NotImplementedError:
            raise original_exception

        mapping = mapping.merge(x=x, ildj_map={event_ndims: ildj})
        self._cache(mapping)
        if self.is_constant_jacobian:
          self._constant_ildj_map[event_ndims] = ildj
        return ildj

  def inverse_log_det_jacobian(
      self, y, event_ndims, name="inverse_log_det_jacobian"):
    """Returns the (log o det o Jacobian o inverse)(y).

    Mathematically, returns: `log(det(dX/dY))(Y)`. (Recall that: `X=g^{-1}(Y)`.)

    Note that `forward_log_det_jacobian` is the negative of this function,
    evaluated at `g^{-1}(y)`.

    Args:
      y: `Tensor`. The input to the "inverse" Jacobian determinant evaluation.
      event_ndims: Number of dimensions in the probabilistic events being
        transformed. Must be greater than or equal to
        `self.inverse_min_event_ndims`. The result is summed over the final
        dimensions to produce a scalar Jacobian determinant for each event,
        i.e. it has shape `y.shape.ndims - event_ndims` dimensions.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective, returns the tuple of local log det
        Jacobians, `log(det(Dg_i^{-1}(y)))`, where `g_i` is the restriction
        of `g` to the `ith` partition `Di`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse_log_det_jacobian` is not implemented.
    """
    return self._call_inverse_log_det_jacobian(y, event_ndims, name)

  def _forward_log_det_jacobian(self, x):
    """Subclass implementation of `forward_log_det_jacobian` public function.

    In particular, this method differs from the public function, in that it
    does not take `event_ndims`. Thus, this implements the minimal Jacobian
    determinant calculation (i.e. over `forward_min_event_ndims`).

    Args:
      x: `Tensor`. The input to the "forward_log_det_jacobian" evaluation.

    Returns:
      forward_log_det_jacobian: `Tensor`, if this bijector is injective.
        If not injective, returns the k-tuple containing jacobians for the
        unique `k` points `(x1, ..., xk)` such that `g(xi) = y`.
    """
    raise NotImplementedError(
        "forward_log_det_jacobian not implemented.")

  def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
    if not self._is_injective:
      raise NotImplementedError(
          "forward_log_det_jacobian cannot be implemented for non-injective "
          "transforms.")
    with self._name_scope(name, [x]):
      with ops.control_dependencies(self._check_valid_event_ndims(
          min_event_ndims=self.forward_min_event_ndims,
          event_ndims=event_ndims)):
        if event_ndims in self._constant_ildj_map:
          # Need "-1. *" to avoid invalid-unary-operand-type linter warning.
          return -1. * self._constant_ildj_map[event_ndims]
        x = ops.convert_to_tensor(x, name="x")
        self._maybe_assert_dtype(x)
        if not self._is_injective:  # No caching for non-injective
          try:
            fldjs = self._forward_log_det_jacobian(x, **kwargs)  # No caching.
            return tuple(self._reduce_jacobian_det_over_event(
                x, fldj, self.forward_min_event_ndims, event_ndims)
                         for fldj in fldjs)
          except NotImplementedError as original_exception:
            try:
              y = self._forward(x, **kwargs)
              ildjs = self._inverse_log_det_jacobian(y, **kwargs)
              return tuple(self._reduce_jacobian_det_over_event(
                  y, -ildj, self.inverse_min_event_ndims, event_ndims)
                           for ildj in ildjs)
            except NotImplementedError:
              raise original_exception
        mapping = self._lookup(x=x, kwargs=kwargs)
        if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
          return -mapping.ildj_map[event_ndims]
        try:
          y = None  # Not needed; leave cache as is.
          ildj = -self._forward_log_det_jacobian(x, **kwargs)
          ildj = self._reduce_jacobian_det_over_event(
              x, ildj, self.forward_min_event_ndims, event_ndims)
        except NotImplementedError as original_exception:
          try:
            y = (mapping.y if mapping.y is not None
                 else self._forward(x, **kwargs))
            ildj = self._inverse_log_det_jacobian(y, **kwargs)
            ildj = self._reduce_jacobian_det_over_event(
                y, ildj, self.inverse_min_event_ndims, event_ndims)
          except NotImplementedError:
            raise original_exception
        mapping = mapping.merge(y=y, ildj_map={event_ndims: ildj})
        self._cache(mapping)
        if self.is_constant_jacobian:
          self._constant_ildj_map[event_ndims] = ildj
        return -ildj

  def forward_log_det_jacobian(
      self, x, event_ndims, name="forward_log_det_jacobian"):
    """Returns the forward_log_det_jacobian.

    Args:
      x: `Tensor`. The input to the "forward" Jacobian determinant evaluation.
      event_ndims: Number of dimensions in the probabilistic events being
        transformed. Must be greater than or equal to
        `self.forward_min_event_ndims`. The result is summed over the final
        dimensions to produce a scalar Jacobian determinant for each event,
        i.e. it has shape `x.shape.ndims - event_ndims` dimensions.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective this is not implemented.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if neither `_forward_log_det_jacobian`
        nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented, or
        this is a non-injective bijector.
    """
    return self._call_forward_log_det_jacobian(x, event_ndims, name)

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(
          name, values=(values or []) + self.graph_parents) as scope:
        yield scope

  def _maybe_assert_dtype(self, x):
    """Helper to check dtype when self.dtype is known."""
    if self.dtype is not None and self.dtype.base_dtype != x.dtype.base_dtype:
      raise TypeError("Input had dtype %s but expected %s." %
                      (self.dtype, x.dtype))

  def _cache(self, mapping):
    """Helper which stores mapping info in forward/inverse dicts."""
    # Merging from lookup is an added check that we're not overwriting anything
    # which is not None.
    mapping = mapping.merge(mapping=self._lookup(
        mapping.x, mapping.y, mapping.kwargs))
    if mapping.x is None and mapping.y is None:
      raise ValueError("Caching expects at least one of (x,y) to be known, "
                       "i.e., not None.")
    self._from_x[mapping.x_key] = mapping
    self._from_y[mapping.y_key] = mapping

  def _lookup(self, x=None, y=None, kwargs=None):
    """Helper which retrieves mapping info from forward/inverse dicts."""
    mapping = _Mapping(x=x, y=y, kwargs=kwargs)
    # Since _cache requires both x,y to be set, we only need to do one cache
    # lookup since the mapping is always in both or neither.
    if mapping.x is not None:
      return self._from_x.get(mapping.x_key, mapping)
    if mapping.y is not None:
      return self._from_y.get(mapping.y_key, mapping)
    return mapping

  def _reduce_jacobian_det_over_event(
      self, y, ildj, min_event_ndims, event_ndims):
    """Reduce jacobian over event_ndims - min_event_ndims."""
    # In this case, we need to tile the Jacobian over the event and reduce.
    y_rank = array_ops.rank(y)
    y_shape = array_ops.shape(y)[
        y_rank - event_ndims : y_rank - min_event_ndims]

    ones = array_ops.ones(y_shape, ildj.dtype)
    reduced_ildj = math_ops.reduce_sum(
        ones * ildj,
        axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
    # The multiplication by ones can change the inferred static shape so we try
    # to recover as much as possible.
    event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
    if (event_ndims_ is not None and
        y.shape.ndims is not None and
        ildj.shape.ndims is not None):
      y_shape = y.shape[y.shape.ndims - event_ndims_ :
                        y.shape.ndims - min_event_ndims]
      broadcast_shape = array_ops.broadcast_static_shape(ildj.shape, y_shape)
      reduced_ildj.set_shape(
          broadcast_shape[: broadcast_shape.ndims - (
              event_ndims_ - min_event_ndims)])

    return reduced_ildj

  def _get_event_reduce_dims(self, min_event_ndims, event_ndims):
    """Compute the reduction dimensions given event_ndims."""
    event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)

    if event_ndims_ is not None:
      return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)]
    else:
      reduce_ndims = event_ndims - min_event_ndims
      return math_ops.range(-reduce_ndims, 0)

  def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
    """Check whether event_ndims is at least min_event_ndims."""
    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
    event_ndims_ = tensor_util.constant_value(event_ndims)
    assertions = []

    if not event_ndims.dtype.is_integer
|
| 1074 |
+
raise ValueError("Expected integer dtype, got dtype {}".format(
|
| 1075 |
+
event_ndims.dtype))
|
| 1076 |
+
|
| 1077 |
+
if event_ndims_ is not None:
|
| 1078 |
+
if event_ndims.shape.ndims != 0:
|
| 1079 |
+
raise ValueError("Expected scalar event_ndims, got shape {}".format(
|
| 1080 |
+
event_ndims.shape))
|
| 1081 |
+
if min_event_ndims > event_ndims_:
|
| 1082 |
+
raise ValueError("event_ndims ({}) must be larger than "
|
| 1083 |
+
"min_event_ndims ({})".format(
|
| 1084 |
+
event_ndims_, min_event_ndims))
|
| 1085 |
+
elif self.validate_args:
|
| 1086 |
+
assertions += [
|
| 1087 |
+
check_ops.assert_greater_equal(event_ndims, min_event_ndims)]
|
| 1088 |
+
|
| 1089 |
+
if event_ndims.shape.is_fully_defined():
|
| 1090 |
+
if event_ndims.shape.ndims != 0:
|
| 1091 |
+
raise ValueError("Expected scalar shape, got ndims {}".format(
|
| 1092 |
+
event_ndims.shape.ndims))
|
| 1093 |
+
|
| 1094 |
+
elif self.validate_args:
|
| 1095 |
+
assertions += [
|
| 1096 |
+
check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")]
|
| 1097 |
+
return assertions
|
| 1098 |
+
|
| 1099 |
+
def _maybe_get_static_event_ndims(self, event_ndims):
|
| 1100 |
+
"""Helper which returns tries to return an integer static value."""
|
| 1101 |
+
event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
|
| 1102 |
+
|
| 1103 |
+
if isinstance(event_ndims_, (np.generic, np.ndarray)):
|
| 1104 |
+
if event_ndims_.dtype not in (np.int32, np.int64):
|
| 1105 |
+
raise ValueError("Expected integer dtype, got dtype {}".format(
|
| 1106 |
+
event_ndims_.dtype))
|
| 1107 |
+
|
| 1108 |
+
if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape):
|
| 1109 |
+
raise ValueError("Expected a scalar integer, got {}".format(
|
| 1110 |
+
event_ndims_))
|
| 1111 |
+
event_ndims_ = int(event_ndims_)
|
| 1112 |
+
|
| 1113 |
+
return event_ndims_
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/bijector_test_util.py
ADDED
|
@@ -0,0 +1,221 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bijector unit-test utilities."""

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import uniform as uniform_lib


def assert_finite(array):
  if not np.isfinite(array).all():
    raise AssertionError("array was not all finite. %s" % array[:15])


def assert_strictly_increasing(array):
  np.testing.assert_array_less(0., np.diff(array))


def assert_strictly_decreasing(array):
  np.testing.assert_array_less(np.diff(array), 0.)


def assert_strictly_monotonic(array):
  if array[0] < array[-1]:
    assert_strictly_increasing(array)
  else:
    assert_strictly_decreasing(array)


def assert_scalar_congruency(bijector,
                             lower_x,
                             upper_x,
                             n=int(10e3),
                             rtol=0.01,
                             sess=None):
  """Assert `bijector`'s forward/inverse/inverse_log_det_jacobian are congruent.

  We draw samples `X ~ U(lower_x, upper_x)`, then feed these through the
  `bijector` in order to check that:

  1. the forward is strictly monotonic.
  2. the forward/inverse methods are inverses of each other.
  3. the jacobian is the correct change of measure.

  This can only be used for a Bijector mapping open subsets of the real line
  to themselves. This is due to the fact that this test compares the `prob`
  before/after transformation with the Lebesgue measure on the line.

  Args:
    bijector: Instance of Bijector
    lower_x: Python scalar.
    upper_x: Python scalar. Must have `lower_x < upper_x`, and both must be in
      the domain of the `bijector`. The `bijector` should probably not produce
      huge variation in values in the interval `(lower_x, upper_x)`, or else
      the variance based check of the Jacobian will require small `rtol` or
      huge `n`.
    n: Number of samples to draw for the checks.
    rtol: Positive number. Used for the Jacobian check.
    sess: `tf.compat.v1.Session`. Defaults to the default session.

  Raises:
    AssertionError: If tests fail.
  """
  # Checks and defaults.
  if sess is None:
    sess = ops.get_default_session()

  # Should be monotonic over this interval
  ten_x_pts = np.linspace(lower_x, upper_x, num=10).astype(np.float32)
  if bijector.dtype is not None:
    ten_x_pts = ten_x_pts.astype(bijector.dtype.as_numpy_dtype)
  forward_on_10_pts = bijector.forward(ten_x_pts)

  # Set the lower/upper limits in the range of the bijector.
  lower_y, upper_y = sess.run(
      [bijector.forward(lower_x), bijector.forward(upper_x)])
  if upper_y < lower_y:  # If bijector.forward is a decreasing function.
    lower_y, upper_y = upper_y, lower_y

  # Uniform samples from the domain, range.
  uniform_x_samps = uniform_lib.Uniform(
      low=lower_x, high=upper_x).sample(n, seed=0)
  uniform_y_samps = uniform_lib.Uniform(
      low=lower_y, high=upper_y).sample(n, seed=1)

  # These compositions should be the identity.
  inverse_forward_x = bijector.inverse(bijector.forward(uniform_x_samps))
  forward_inverse_y = bijector.forward(bijector.inverse(uniform_y_samps))

  # For a < b, and transformation y = y(x),
  # (b - a) = \int_a^b dx = \int_{y(a)}^{y(b)} |dx/dy| dy
  # "change_measure_dy_dx" below is a Monte Carlo approximation to the right
  # hand side, which should then be close to the left, which is (b - a).
  # We assume event_ndims=0 because we assume scalar -> scalar. The log_det
  # methods will handle whether they expect event_ndims > 0.
  dy_dx = math_ops.exp(bijector.inverse_log_det_jacobian(
      uniform_y_samps, event_ndims=0))
  # E[|dx/dy|] under Uniform[lower_y, upper_y]
  # = \int_{y(a)}^{y(b)} |dx/dy| dP(u), where dP(u) is the uniform measure
  expectation_of_dy_dx_under_uniform = math_ops.reduce_mean(dy_dx)
  # dy = dP(u) * (upper_y - lower_y)
  change_measure_dy_dx = (
      (upper_y - lower_y) * expectation_of_dy_dx_under_uniform)

  # We'll also check that dy_dx = 1 / dx_dy.
  dx_dy = math_ops.exp(
      bijector.forward_log_det_jacobian(
          bijector.inverse(uniform_y_samps), event_ndims=0))

  [
      forward_on_10_pts_v,
      dy_dx_v,
      dx_dy_v,
      change_measure_dy_dx_v,
      uniform_x_samps_v,
      uniform_y_samps_v,
      inverse_forward_x_v,
      forward_inverse_y_v,
  ] = sess.run([
      forward_on_10_pts,
      dy_dx,
      dx_dy,
      change_measure_dy_dx,
      uniform_x_samps,
      uniform_y_samps,
      inverse_forward_x,
      forward_inverse_y,
  ])

  assert_strictly_monotonic(forward_on_10_pts_v)
  # Composition of forward/inverse should be the identity.
  np.testing.assert_allclose(
      inverse_forward_x_v, uniform_x_samps_v, atol=1e-5, rtol=1e-3)
  np.testing.assert_allclose(
      forward_inverse_y_v, uniform_y_samps_v, atol=1e-5, rtol=1e-3)
  # Change of measure should be correct.
  np.testing.assert_allclose(
      upper_x - lower_x, change_measure_dy_dx_v, atol=0, rtol=rtol)
  # Inverse Jacobian should be equivalent to the reciprocal of the forward
  # Jacobian.
  np.testing.assert_allclose(
      dy_dx_v, np.divide(1., dx_dy_v), atol=1e-5, rtol=1e-3)


def assert_bijective_and_finite(
    bijector, x, y, event_ndims, atol=0, rtol=1e-5, sess=None):
  """Assert that forward/inverse (along with jacobians) are inverses and finite.

  It is recommended to use x and y values that are very close to the edge
  of the Bijector's domain.

  Args:
    bijector: A Bijector instance.
    x: np.array of values in the domain of bijector.forward.
    y: np.array of values in the domain of bijector.inverse.
    event_ndims: Integer describing the number of event dimensions this bijector
      operates on.
    atol: Absolute tolerance.
    rtol: Relative tolerance.
    sess: TensorFlow session. Defaults to the default session.

  Raises:
    AssertionError: If tests fail.
  """
  sess = sess or ops.get_default_session()

  # These are the incoming points, but people often create a crazy range of
  # values for which these end up being bad, especially in 16bit.
  assert_finite(x)
  assert_finite(y)

  f_x = bijector.forward(x)
  g_y = bijector.inverse(y)

  [
      x_from_x,
      y_from_y,
      ildj_f_x,
      fldj_x,
      ildj_y,
      fldj_g_y,
      f_x_v,
      g_y_v,
  ] = sess.run([
      bijector.inverse(f_x),
      bijector.forward(g_y),
      bijector.inverse_log_det_jacobian(f_x, event_ndims=event_ndims),
      bijector.forward_log_det_jacobian(x, event_ndims=event_ndims),
      bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims),
      bijector.forward_log_det_jacobian(g_y, event_ndims=event_ndims),
      f_x,
      g_y,
  ])

  assert_finite(x_from_x)
  assert_finite(y_from_y)
  assert_finite(ildj_f_x)
  assert_finite(fldj_x)
  assert_finite(ildj_y)
  assert_finite(fldj_g_y)
  assert_finite(f_x_v)
  assert_finite(g_y_v)

  np.testing.assert_allclose(x_from_x, x, atol=atol, rtol=rtol)
  np.testing.assert_allclose(y_from_y, y, atol=atol, rtol=rtol)
  np.testing.assert_allclose(-ildj_f_x, fldj_x, atol=atol, rtol=rtol)
  np.testing.assert_allclose(-ildj_y, fldj_g_y, atol=atol, rtol=rtol)
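The change-of-measure check in `assert_scalar_congruency` can be reproduced in plain numpy for a concrete bijector. The sketch below uses `forward = exp` (so the inverse is `log` and `|dx/dy| = 1/y`) and verifies that `(upper_y - lower_y) * E[|dx/dy|]` under uniform samples of `y` recovers `upper_x - lower_x`; the variable names mirror the test utility but nothing here is TensorFlow API:

```python
import numpy as np

rng = np.random.default_rng(0)
lower_x, upper_x = 0.5, 1.5                           # domain interval
lower_y, upper_y = np.exp(lower_x), np.exp(upper_x)   # forward = exp

# Inverse is log, so |dx/dy| = 1/y and inverse_log_det_jacobian = -log(y).
y = rng.uniform(lower_y, upper_y, size=100_000)
dy_dx = np.exp(-np.log(y))                            # |dx/dy| at sampled ys

# Monte Carlo estimate of \int_{y(a)}^{y(b)} |dx/dy| dy.
change_measure = (upper_y - lower_y) * dy_dx.mean()
# This should be close to upper_x - lower_x.
```

The exact integral is `log(upper_y) - log(lower_y) = upper_x - lower_x`, so with 100k samples the estimate lands within a fraction of a percent.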
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/categorical.py
ADDED
|
@@ -0,0 +1,345 @@
| 1 |
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""The Categorical distribution class."""
|
| 16 |
+
|
| 17 |
+
from tensorflow.python.framework import constant_op
|
| 18 |
+
from tensorflow.python.framework import dtypes
|
| 19 |
+
from tensorflow.python.framework import ops
|
| 20 |
+
from tensorflow.python.framework import tensor_shape
|
| 21 |
+
from tensorflow.python.ops import array_ops
|
| 22 |
+
from tensorflow.python.ops import math_ops
|
| 23 |
+
from tensorflow.python.ops import nn_ops
|
| 24 |
+
from tensorflow.python.ops import random_ops
|
| 25 |
+
from tensorflow.python.ops.distributions import distribution
|
| 26 |
+
from tensorflow.python.ops.distributions import kullback_leibler
|
| 27 |
+
from tensorflow.python.ops.distributions import util as distribution_util
|
| 28 |
+
from tensorflow.python.util import deprecation
|
| 29 |
+
from tensorflow.python.util.tf_export import tf_export
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _broadcast_cat_event_and_params(event, params, base_dtype):
|
| 33 |
+
"""Broadcasts the event or distribution parameters."""
|
| 34 |
+
if event.dtype.is_integer:
|
| 35 |
+
pass
|
| 36 |
+
elif event.dtype.is_floating:
|
| 37 |
+
# When `validate_args=True` we've already ensured int/float casting
|
| 38 |
+
# is closed.
|
| 39 |
+
event = math_ops.cast(event, dtype=dtypes.int32)
|
| 40 |
+
else:
|
| 41 |
+
raise TypeError("`value` should have integer `dtype` or "
|
| 42 |
+
"`self.dtype` ({})".format(base_dtype))
|
| 43 |
+
shape_known_statically = (
|
| 44 |
+
params.shape.ndims is not None and
|
| 45 |
+
params.shape[:-1].is_fully_defined() and
|
| 46 |
+
event.shape.is_fully_defined())
|
| 47 |
+
if not shape_known_statically or params.shape[:-1] != event.shape:
|
| 48 |
+
params *= array_ops.ones_like(event[..., array_ops.newaxis],
|
| 49 |
+
dtype=params.dtype)
|
| 50 |
+
params_shape = array_ops.shape(params)[:-1]
|
| 51 |
+
event *= array_ops.ones(params_shape, dtype=event.dtype)
|
| 52 |
+
if params.shape.ndims is not None:
|
| 53 |
+
event.set_shape(tensor_shape.TensorShape(params.shape[:-1]))
|
| 54 |
+
|
| 55 |
+
return event, params
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@tf_export(v1=["distributions.Categorical"])
|
| 59 |
+
class Categorical(distribution.Distribution):
|
| 60 |
+
"""Categorical distribution.
|
| 61 |
+
|
| 62 |
+
The Categorical distribution is parameterized by either probabilities or
|
| 63 |
+
log-probabilities of a set of `K` classes. It is defined over the integers
|
| 64 |
+
`{0, 1, ..., K}`.
|
| 65 |
+
|
| 66 |
+
The Categorical distribution is closely related to the `OneHotCategorical` and
|
| 67 |
+
`Multinomial` distributions. The Categorical distribution can be intuited as
|
| 68 |
+
generating samples according to `argmax{ OneHotCategorical(probs) }` itself
|
| 69 |
+
being identical to `argmax{ Multinomial(probs, total_count=1) }`.
|
| 70 |
+
|
| 71 |
+
#### Mathematical Details
|
| 72 |
+
|
| 73 |
+
The probability mass function (pmf) is,
|
| 74 |
+
|
| 75 |
+
```none
|
| 76 |
+
pmf(k; pi) = prod_j pi_j**[k == j]
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
#### Pitfalls
|
| 80 |
+
|
| 81 |
+
The number of classes, `K`, must not exceed:
|
| 82 |
+
- the largest integer representable by `self.dtype`, i.e.,
|
| 83 |
+
`2**(mantissa_bits+1)` (IEEE 754),
|
| 84 |
+
- the maximum `Tensor` index, i.e., `2**31-1`.
|
| 85 |
+
|
| 86 |
+
In other words,
|
| 87 |
+
|
| 88 |
+
```python
|
| 89 |
+
K <= min(2**31-1, {
|
| 90 |
+
tf.float16: 2**11,
|
| 91 |
+
tf.float32: 2**24,
|
| 92 |
+
tf.float64: 2**53 }[param.dtype])
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
Note: This condition is validated only when `self.validate_args = True`.
|
| 96 |
+
|
| 97 |
+
#### Examples
|
| 98 |
+
|
| 99 |
+
Creates a 3-class distribution with the 2nd class being most likely.
|
| 100 |
+
|
| 101 |
+
```python
|
| 102 |
+
dist = Categorical(probs=[0.1, 0.5, 0.4])
|
| 103 |
+
n = 1e4
|
| 104 |
+
empirical_prob = tf.cast(
|
| 105 |
+
tf.histogram_fixed_width(
|
| 106 |
+
dist.sample(int(n)),
|
| 107 |
+
[0., 2],
|
| 108 |
+
nbins=3),
|
| 109 |
+
dtype=tf.float32) / n
|
| 110 |
+
# ==> array([ 0.1005, 0.5037, 0.3958], dtype=float32)
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
Creates a 3-class distribution with the 2nd class being most likely.
|
| 114 |
+
Parameterized by [logits](https://en.wikipedia.org/wiki/Logit) rather than
|
| 115 |
+
probabilities.
|
| 116 |
+
|
| 117 |
+
```python
|
| 118 |
+
dist = Categorical(logits=np.log([0.1, 0.5, 0.4])
|
| 119 |
+
n = 1e4
|
| 120 |
+
empirical_prob = tf.cast(
|
| 121 |
+
tf.histogram_fixed_width(
|
| 122 |
+
dist.sample(int(n)),
|
| 123 |
+
[0., 2],
|
| 124 |
+
nbins=3),
|
| 125 |
+
dtype=tf.float32) / n
|
| 126 |
+
# ==> array([0.1045, 0.5047, 0.3908], dtype=float32)
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
Creates a 3-class distribution with the 3rd class being most likely.
|
| 130 |
+
The distribution functions can be evaluated on counts.
|
| 131 |
+
|
| 132 |
+
```python
|
| 133 |
+
# counts is a scalar.
|
| 134 |
+
p = [0.1, 0.4, 0.5]
|
| 135 |
+
dist = Categorical(probs=p)
|
| 136 |
+
dist.prob(0) # Shape []
|
| 137 |
+
|
| 138 |
+
# p will be broadcast to [[0.1, 0.4, 0.5], [0.1, 0.4, 0.5]] to match counts.
|
| 139 |
+
counts = [1, 0]
|
| 140 |
+
dist.prob(counts) # Shape [2]
|
| 141 |
+
|
| 142 |
+
# p will be broadcast to shape [3, 5, 7, 3] to match counts.
|
| 143 |
+
counts = [[...]] # Shape [5, 7, 3]
|
| 144 |
+
dist.prob(counts) # Shape [5, 7, 3]
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
@deprecation.deprecated(
|
| 150 |
+
"2019-01-01",
|
| 151 |
+
"The TensorFlow Distributions library has moved to "
|
| 152 |
+
"TensorFlow Probability "
|
| 153 |
+
"(https://github.com/tensorflow/probability). You "
|
| 154 |
+
"should update all references to use `tfp.distributions` "
|
| 155 |
+
"instead of `tf.distributions`.",
|
| 156 |
+
warn_once=True)
|
| 157 |
+
def __init__(
|
| 158 |
+
self,
|
| 159 |
+
logits=None,
|
| 160 |
+
probs=None,
|
| 161 |
+
dtype=dtypes.int32,
|
| 162 |
+
validate_args=False,
|
| 163 |
+
allow_nan_stats=True,
|
| 164 |
+
name="Categorical"):
|
| 165 |
+
"""Initialize Categorical distributions using class log-probabilities.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
|
| 169 |
+
of a set of Categorical distributions. The first `N - 1` dimensions
|
| 170 |
+
index into a batch of independent distributions and the last dimension
|
| 171 |
+
represents a vector of logits for each class. Only one of `logits` or
|
| 172 |
+
`probs` should be passed in.
|
| 173 |
+
probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
|
| 174 |
+
of a set of Categorical distributions. The first `N - 1` dimensions
|
| 175 |
+
index into a batch of independent distributions and the last dimension
|
| 176 |
+
represents a vector of probabilities for each class. Only one of
|
| 177 |
+
`logits` or `probs` should be passed in.
|
| 178 |
+
dtype: The type of the event samples (default: int32).
|
| 179 |
+
validate_args: Python `bool`, default `False`. When `True` distribution
|
| 180 |
+
parameters are checked for validity despite possibly degrading runtime
|
| 181 |
+
performance. When `False` invalid inputs may silently render incorrect
|
| 182 |
+
outputs.
|
| 183 |
+
allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
|
| 184 |
+
(e.g., mean, mode, variance) use the value "`NaN`" to indicate the
|
| 185 |
+
result is undefined. When `False`, an exception is raised if one or
|
| 186 |
+
more of the statistic's batch members are undefined.
|
| 187 |
+
name: Python `str` name prefixed to Ops created by this class.
|
| 188 |
+
"""
|
| 189 |
+
parameters = dict(locals())
|
| 190 |
+
with ops.name_scope(name, values=[logits, probs]) as name:
|
| 191 |
+
self._logits, self._probs = distribution_util.get_logits_and_probs(
|
| 192 |
+
logits=logits,
|
| 193 |
+
probs=probs,
|
| 194 |
+
validate_args=validate_args,
|
| 195 |
+
multidimensional=True,
|
| 196 |
+
name=name)
|
| 197 |
+
|
| 198 |
+
if validate_args:
|
| 199 |
+
self._logits = distribution_util.embed_check_categorical_event_shape(
|
| 200 |
+
self._logits)
|
| 201 |
+
|
| 202 |
+
logits_shape_static = self._logits.get_shape().with_rank_at_least(1)
|
| 203 |
+
if logits_shape_static.ndims is not None:
|
| 204 |
+
self._batch_rank = ops.convert_to_tensor(
|
| 205 |
+
logits_shape_static.ndims - 1,
|
| 206 |
+
dtype=dtypes.int32,
|
| 207 |
+
name="batch_rank")
|
| 208 |
+
else:
|
| 209 |
+
with ops.name_scope(name="batch_rank"):
|
| 210 |
+
self._batch_rank = array_ops.rank(self._logits) - 1
|
| 211 |
+
|
| 212 |
+
logits_shape = array_ops.shape(self._logits, name="logits_shape")
|
| 213 |
+
if tensor_shape.dimension_value(logits_shape_static[-1]) is not None:
|
| 214 |
+
self._event_size = ops.convert_to_tensor(
|
| 215 |
+
logits_shape_static.dims[-1].value,
|
| 216 |
+
dtype=dtypes.int32,
|
| 217 |
+
name="event_size")
|
| 218 |
+
else:
|
| 219 |
+
with ops.name_scope(name="event_size"):
|
| 220 |
+
self._event_size = logits_shape[self._batch_rank]
|
| 221 |
+
|
| 222 |
+
if logits_shape_static[:-1].is_fully_defined():
|
| 223 |
+
self._batch_shape_val = constant_op.constant(
|
| 224 |
+
logits_shape_static[:-1].as_list(),
|
| 225 |
+
dtype=dtypes.int32,
|
| 226 |
+
name="batch_shape")
|
| 227 |
+
else:
|
| 228 |
+
with ops.name_scope(name="batch_shape"):
|
| 229 |
+
self._batch_shape_val = logits_shape[:-1]
|
| 230 |
+
super(Categorical, self).__init__(
|
| 231 |
+
dtype=dtype,
|
| 232 |
+
reparameterization_type=distribution.NOT_REPARAMETERIZED,
|
| 233 |
+
validate_args=validate_args,
|
| 234 |
+
allow_nan_stats=allow_nan_stats,
|
| 235 |
+
parameters=parameters,
|
| 236 |
+
graph_parents=[self._logits,
|
| 237 |
+
self._probs],
|
| 238 |
+
name=name)
|
| 239 |
+
|
| 240 |
+
@property
|
| 241 |
+
def event_size(self):
|
| 242 |
+
"""Scalar `int32` tensor: the number of classes."""
|
| 243 |
+
return self._event_size
|
| 244 |
+
|
| 245 |
+
@property
|
| 246 |
+
def logits(self):
|
| 247 |
+
"""Vector of coordinatewise logits."""
|
| 248 |
+
return self._logits
|
| 249 |
+
|
| 250 |
+
@property
|
| 251 |
+
def probs(self):
|
| 252 |
+
"""Vector of coordinatewise probabilities."""
|
| 253 |
+
return self._probs
|
| 254 |
+
|
| 255 |
+
def _batch_shape_tensor(self):
|
| 256 |
+
return array_ops.identity(self._batch_shape_val)
|
| 257 |
+
|
| 258 |
+
def _batch_shape(self):
|
| 259 |
+
return self.logits.get_shape()[:-1]
|
| 260 |
+
|
| 261 |
+
def _event_shape_tensor(self):
|
| 262 |
+
return constant_op.constant([], dtype=dtypes.int32)
|
| 263 |
+
|
| 264 |
+
def _event_shape(self):
|
| 265 |
+
return tensor_shape.TensorShape([])
|
| 266 |
+
|
| 267 |
+
def _sample_n(self, n, seed=None):
|
| 268 |
+
if self.logits.get_shape().ndims == 2:
|
| 269 |
+
logits_2d = self.logits
|
| 270 |
+
else:
|
| 271 |
+
logits_2d = array_ops.reshape(self.logits, [-1, self.event_size])
|
| 272 |
+
sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32
|
| 273 |
+
draws = random_ops.multinomial(
|
| 274 |
+
logits_2d, n, seed=seed, output_dtype=sample_dtype)
|
| 275 |
+
draws = array_ops.reshape(
|
| 276 |
+
array_ops.transpose(draws),
|
| 277 |
+
array_ops.concat([[n], self.batch_shape_tensor()], 0))
|
| 278 |
+
return math_ops.cast(draws, self.dtype)
|
| 279 |
+
|
| 280 |
+
def _cdf(self, k):
|
| 281 |
+
k = ops.convert_to_tensor(k, name="k")
|
| 282 |
+
if self.validate_args:
|
| 283 |
+
k = distribution_util.embed_check_integer_casting_closed(
|
| 284 |
+
k, target_dtype=dtypes.int32)
|
| 285 |
+
|
| 286 |
+
k, probs = _broadcast_cat_event_and_params(
|
| 287 |
+
k, self.probs, base_dtype=self.dtype.base_dtype)
|
| 288 |
+
|
| 289 |
+
# batch-flatten everything in order to use `sequence_mask()`.
|
| 290 |
+
batch_flattened_probs = array_ops.reshape(probs,
|
| 291 |
+
(-1, self._event_size))
|
| 292 |
+
batch_flattened_k = array_ops.reshape(k, [-1])
|
| 293 |
+
|
| 294 |
+
to_sum_over = array_ops.where(
|
| 295 |
+
          array_ops.sequence_mask(batch_flattened_k, self._event_size),
          batch_flattened_probs,
          array_ops.zeros_like(batch_flattened_probs))
      batch_flattened_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
      # Reshape back to the shape of the argument.
      return array_ops.reshape(batch_flattened_cdf, array_ops.shape(k))

  def _log_prob(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)
    k, logits = _broadcast_cat_event_and_params(
        k, self.logits, base_dtype=self.dtype.base_dtype)

    # pylint: disable=invalid-unary-operand-type
    return -nn_ops.sparse_softmax_cross_entropy_with_logits(
        labels=k,
        logits=logits)

  def _entropy(self):
    return -math_ops.reduce_sum(
        nn_ops.log_softmax(self.logits) * self.probs, axis=-1)

  def _mode(self):
    ret = math_ops.argmax(self.logits, axis=self._batch_rank)
    ret = math_ops.cast(ret, self.dtype)
    ret.set_shape(self.batch_shape)
    return ret


@kullback_leibler.RegisterKL(Categorical, Categorical)
def _kl_categorical_categorical(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Categorical.

  Args:
    a: instance of a Categorical distribution object.
    b: instance of a Categorical distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_categorical_categorical".

  Returns:
    Batchwise KL(a || b)
  """
  with ops.name_scope(name, "kl_categorical_categorical",
                      values=[a.logits, b.logits]):
    # sum_i: softmax(a.logits)[i] * (log_softmax(a.logits)[i]
    #                                - log_softmax(b.logits)[i])
    delta_log_probs1 = (nn_ops.log_softmax(a.logits) -
                        nn_ops.log_softmax(b.logits))
    return math_ops.reduce_sum(nn_ops.softmax(a.logits) * delta_log_probs1,
                               axis=-1)
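The registered KL above reduces to `sum_i p_a[i] * (log p_a[i] - log p_b[i])` over the softmax probabilities of the two logit vectors. A minimal pure-Python sketch of that same formula (independent of TensorFlow; the helper names here are illustrative, not part of the library):

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def kl_categorical(a_logits, b_logits):
    # KL(a || b) = sum_i softmax(a)[i] * (log_softmax(a)[i] - log_softmax(b)[i])
    la, lb = log_softmax(a_logits), log_softmax(b_logits)
    return sum(math.exp(pa) * (pa - pb) for pa, pb in zip(la, lb))

print(kl_categorical([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 0.0: KL of a dist with itself
print(kl_categorical([1.0, 2.0, 3.0], [3.0, 2.0, 1.0]))  # > 0 by Gibbs' inequality
```

This mirrors the graph computation line for line: `delta_log_probs1` corresponds to `pa - pb`, and the final `reduce_sum` over the last axis corresponds to the `sum(...)`.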
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/dirichlet.py
ADDED
@@ -0,0 +1,410 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Dirichlet distribution class."""

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Dirichlet",
]


_dirichlet_sample_note = """Note: `value` must be a non-negative tensor with
dtype `self.dtype` and be in the `(self.event_shape() - 1)`-simplex, i.e.,
`tf.reduce_sum(value, -1) = 1`. It must have a shape compatible with
`self.batch_shape() + self.event_shape()`."""


@tf_export(v1=["distributions.Dirichlet"])
class Dirichlet(distribution.Distribution):
  """Dirichlet distribution.

  The Dirichlet distribution is defined over the
  [`(k-1)`-simplex](https://en.wikipedia.org/wiki/Simplex) using a positive,
  length-`k` vector `concentration` (`k > 1`). The Dirichlet is identically the
  Beta distribution when `k = 2`.

  #### Mathematical Details

  The Dirichlet is a distribution over the open `(k-1)`-simplex, i.e.,

  ```none
  S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }.
  ```

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
  Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)
  ```

  where:

  * `x in S^{k-1}`, i.e., the `(k-1)`-simplex,
  * `concentration = alpha = [alpha_0, ..., alpha_{k-1}]`, `alpha_j > 0`,
  * `Z` is the normalization constant aka the [multivariate beta function](
    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
    and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The `concentration` represents mean total counts of class occurrence, i.e.,

  ```none
  concentration = alpha = mean * total_concentration
  ```

  where `mean` in `S^{k-1}` and `total_concentration` is a positive real number
  representing a mean total count.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: Some components of the samples can be zero due to finite precision.
  This happens more often when some of the concentrations are very small.
  Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
  density.

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Create a single trivariate Dirichlet, with the 3rd class being three times
  # more frequent than the first. I.e., batch_shape=[], event_shape=[3].
  alpha = [1., 2, 3]
  dist = tfd.Dirichlet(alpha)

  dist.sample([4, 5])  # shape: [4, 5, 3]

  # x has one sample, one batch, three classes:
  x = [.2, .3, .5]  # shape: [3]
  dist.prob(x)  # shape: []

  # x has two samples from one batch:
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  dist.prob(x)  # shape: [2]

  # alpha will be broadcast to shape [5, 7, 3] to match x.
  x = [[...]]  # shape: [5, 7, 3]
  dist.prob(x)  # shape: [5, 7]
  ```

  ```python
  # Create batch_shape=[2], event_shape=[3]:
  alpha = [[1., 2, 3],
           [4, 5, 6]]  # shape: [2, 3]
  dist = tfd.Dirichlet(alpha)

  dist.sample([4, 5])  # shape: [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape [2, 3].
  dist.prob(x)  # shape: [2]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant([1.0, 2.0, 3.0])
  dist = tfd.Dirichlet(alpha)
  samples = dist.sample(5)  # Shape [5, 3]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, alpha)
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="Dirichlet"):
    """Initialize a batch of Dirichlet distributions.

    Args:
      concentration: Positive floating-point `Tensor` indicating mean number
        of class occurrences; aka "alpha". Implies `self.dtype`, and
        `self.batch_shape`, `self.event_shape`, i.e., if
        `concentration.shape = [N1, N2, ..., Nm, k]` then
        `batch_shape = [N1, N2, ..., Nm]` and
        `event_shape = [k]`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration]) as name:
      self._concentration = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration, name="concentration"),
          validate_args)
      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
    super(Dirichlet, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration,
                       self._total_concentration],
        name=name)

  @property
  def concentration(self):
    """Concentration parameter; expected counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return array_ops.shape(self.concentration)[-1:]

  def _event_shape(self):
    return self.concentration.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    gamma_sample = random_ops.random_gamma(
        shape=[n],
        alpha=self.concentration,
        dtype=self.dtype,
        seed=seed)
    return gamma_sample / math_ops.reduce_sum(gamma_sample, -1, keepdims=True)

  @distribution_util.AppendDocstring(_dirichlet_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_dirichlet_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    return math_ops.reduce_sum(math_ops.xlogy(self.concentration - 1., x), -1)

  def _log_normalization(self):
    return special_math_ops.lbeta(self.concentration)

  def _entropy(self):
    k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
    return (
        self._log_normalization()
        + ((self.total_concentration - k)
           * math_ops.digamma(self.total_concentration))
        - math_ops.reduce_sum(
            (self.concentration - 1.) * math_ops.digamma(self.concentration),
            axis=-1))

  def _mean(self):
    return self.concentration / self.total_concentration[..., array_ops.newaxis]

  def _covariance(self):
    x = self._variance_scale_term() * self._mean()
    # pylint: disable=invalid-unary-operand-type
    return array_ops.matrix_set_diag(
        -math_ops.matmul(
            x[..., array_ops.newaxis],
            x[..., array_ops.newaxis, :]),  # outer prod
        self._variance())

  def _variance(self):
    scale = self._variance_scale_term()
    x = scale * self._mean()
    return x * (scale - x)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale."""
    return math_ops.rsqrt(1. + self.total_concentration[..., array_ops.newaxis])

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when any `concentration <= 1`. If
      `self.allow_nan_stats` is `True`, `NaN` is used for undefined modes. If
      `self.allow_nan_stats` is `False` an exception is raised when one or more
      modes are undefined.""")
  def _mode(self):
    k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
    mode = (self.concentration - 1.) / (
        self.total_concentration[..., array_ops.newaxis] - k)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          array_ops.shape(mode),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      return array_ops.where_v2(
          math_ops.reduce_all(self.concentration > 1., axis=-1), mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], self.dtype),
            self.concentration,
            message="Mode undefined when any concentration <= 1"),
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
        check_ops.assert_rank_at_least(
            concentration, 1,
            message="Concentration parameter must have >=1 dimensions."),
        check_ops.assert_less(
            1, array_ops.shape(concentration)[-1],
            message="Concentration parameter must have event_size >= 2."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="samples must be positive"),
        check_ops.assert_near(
            array_ops.ones([], dtype=self.dtype),
            math_ops.reduce_sum(x, -1),
            message="sample last-dimension must sum to `1`"),
    ], x)


@kullback_leibler.RegisterKL(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(d1, d2, name=None):
  """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet.

  Args:
    d1: instance of a Dirichlet distribution object.
    d2: instance of a Dirichlet distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_dirichlet_dirichlet".

  Returns:
    Batchwise KL(d1 || d2)
  """
  with ops.name_scope(name, "kl_dirichlet_dirichlet", values=[
      d1.concentration, d2.concentration]):
    # The KL between Dirichlet distributions can be derived as follows. We have
    #
    #   Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)]
    #
    # where B(a) is the multivariate Beta function:
    #
    #   B(a) = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n])
    #
    # The KL is
    #
    #   KL(Dir(x; a), Dir(x; b)) = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #
    # so we'll need to know the log density of the Dirichlet. This is
    #
    #   log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a)
    #
    # The only term that matters for the expectations is the log(x[i]). To
    # compute the expectation of this term over the Dirichlet density, we can
    # use the following facts about the Dirichlet in exponential family form:
    #   1. log(x[i]) is a sufficient statistic
    #   2. expected sufficient statistics (of any exp family distribution) are
    #      equal to derivatives of the log normalizer with respect to
    #      corresponding natural parameters: E{T[i](x)} = dA/d(eta[i])
    #
    # To proceed, we can rewrite the Dirichlet density in exponential family
    # form as follows:
    #
    #   Dir(x; a) = exp{eta(a) . T(x) - A(a)}
    #
    # where '.' is the dot product of vectors eta and T, and A is a scalar:
    #
    #   eta[i](a) = a[i] - 1
    #   T[i](x) = log(x[i])
    #   A(a) = log B(a)
    #
    # Now, we can use fact (2) above to write
    #
    #   E_Dir(x; a)[log(x[i])]
    #       = dA(a) / da[i]
    #       = d/da[i] log B(a)
    #       = d/da[i] [(sum_j lgamma(a[j])) - lgamma(sum_j a[j])]
    #       = digamma(a[i]) - digamma(sum_j a[j])
    #
    # Putting it all together, we have
    #
    #   KL[Dir(x; a) || Dir(x; b)]
    #       = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #       = E_Dir(x; a){sum_i[(a[i] - b[i]) log(x[i])]} - (lbeta(a) - lbeta(b))
    #       = sum_i[(a[i] - b[i]) * E_Dir(x; a){log(x[i])}] - lbeta(a) + lbeta(b)
    #       = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))]
    #         - lbeta(a) + lbeta(b)

    digamma_sum_d1 = math_ops.digamma(
        math_ops.reduce_sum(d1.concentration, axis=-1, keepdims=True))
    digamma_diff = math_ops.digamma(d1.concentration) - digamma_sum_d1
    concentration_diff = d1.concentration - d2.concentration

    return (math_ops.reduce_sum(concentration_diff * digamma_diff, axis=-1) -
            special_math_ops.lbeta(d1.concentration) +
            special_math_ops.lbeta(d2.concentration))
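The closed form derived in the comments above can be sanity-checked numerically with only the standard library. This is a standalone sketch, not part of TensorFlow: `digamma` is approximated here by a central difference of `math.lgamma`, which is accurate enough for a check but not how the library computes it.

```python
import math

def digamma(x, h=1e-5):
    # Numerical derivative of lgamma; adequate for a sanity check.
    return (math.lgamma(x + h) - math.lgamma(x - h)) / (2 * h)

def lbeta(a):
    # log of the multivariate beta function B(a).
    return sum(math.lgamma(ai) for ai in a) - math.lgamma(sum(a))

def kl_dirichlet(a, b):
    # sum_i (a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))
    #   - lbeta(a) + lbeta(b)
    dsum = digamma(sum(a))
    return (sum((ai - bi) * (digamma(ai) - dsum) for ai, bi in zip(a, b))
            - lbeta(a) + lbeta(b))

a = [1.0, 2.0, 3.0]
b = [2.0, 2.0, 2.0]
print(kl_dirichlet(a, a))  # ~0: KL of a distribution with itself
print(kl_dirichlet(a, b))  # > 0 for distinct distributions
```

Term by term this matches the graph code: `dsum` plays the role of `digamma_sum_d1`, the generator expression is `concentration_diff * digamma_diff` reduced over the last axis, and the `lbeta` corrections mirror the `special_math_ops.lbeta` calls.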
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/dirichlet_multinomial.py
ADDED
@@ -0,0 +1,353 @@
| 1 |
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""The DirichletMultinomial distribution class."""
|
| 16 |
+
|
| 17 |
+
from tensorflow.python.framework import dtypes
|
| 18 |
+
from tensorflow.python.framework import ops
|
| 19 |
+
from tensorflow.python.ops import array_ops
|
| 20 |
+
from tensorflow.python.ops import check_ops
|
| 21 |
+
from tensorflow.python.ops import control_flow_ops
|
| 22 |
+
from tensorflow.python.ops import math_ops
|
| 23 |
+
from tensorflow.python.ops import random_ops
|
| 24 |
+
from tensorflow.python.ops import special_math_ops
|
| 25 |
+
from tensorflow.python.ops.distributions import distribution
|
| 26 |
+
from tensorflow.python.ops.distributions import util as distribution_util
|
| 27 |
+
from tensorflow.python.util import deprecation
|
| 28 |
+
from tensorflow.python.util.tf_export import tf_export
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
__all__ = [
|
| 32 |
+
"DirichletMultinomial",
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
_dirichlet_multinomial_sample_note = """For each batch of counts,
|
| 37 |
+
`value = [n_0, ..., n_{K-1}]`, `P[value]` is the probability that after
|
| 38 |
+
sampling `self.total_count` draws from this Dirichlet-Multinomial distribution,
|
| 39 |
+
the number of draws falling in class `j` is `n_j`. Since this definition is
|
| 40 |
+
[exchangeable](https://en.wikipedia.org/wiki/Exchangeable_random_variables);
|
| 41 |
+
different sequences have the same counts so the probability includes a
|
| 42 |
+
combinatorial coefficient.
|
| 43 |
+
|
| 44 |
+
Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
|
| 45 |
+
fractional components, and such that
|
| 46 |
+
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
|
| 47 |
+
with `self.concentration` and `self.total_count`."""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@tf_export(v1=["distributions.DirichletMultinomial"])
|
| 51 |
+
class DirichletMultinomial(distribution.Distribution):
|
| 52 |
+
"""Dirichlet-Multinomial compound distribution.
|
| 53 |
+
|
| 54 |
+
The Dirichlet-Multinomial distribution is parameterized by a (batch of)
|
| 55 |
+
length-`K` `concentration` vectors (`K > 1`) and a `total_count` number of
|
| 56 |
+
trials, i.e., the number of trials per draw from the DirichletMultinomial. It
|
| 57 |
+
is defined over a (batch of) length-`K` vector `counts` such that
|
| 58 |
+
`tf.reduce_sum(counts, -1) = total_count`. The Dirichlet-Multinomial is
|
| 59 |
+
identically the Beta-Binomial distribution when `K = 2`.
|
| 60 |
+
|
| 61 |
+
#### Mathematical Details
|
| 62 |
+
|
| 63 |
+
The Dirichlet-Multinomial is a distribution over `K`-class counts, i.e., a
|
| 64 |
+
length-`K` vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.
|
| 65 |
+
|
| 66 |
+
The probability mass function (pmf) is,
|
| 67 |
+
|
| 68 |
+
```none
|
| 69 |
+
pmf(n; alpha, N) = Beta(alpha + n) / (prod_j n_j!) / Z
|
| 70 |
+
Z = Beta(alpha) / N!
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
where:
|
| 74 |
+
|
| 75 |
+
* `concentration = alpha = [alpha_0, ..., alpha_{K-1}]`, `alpha_j > 0`,
|
| 76 |
+
* `total_count = N`, `N` a positive integer,
|
| 77 |
+
* `N!` is `N` factorial, and,
|
| 78 |
+
* `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the
|
| 79 |
+
[multivariate beta function](
|
| 80 |
+
https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
|
| 81 |
+
and,
|
| 82 |
+
* `Gamma` is the [gamma function](
|
| 83 |
+
https://en.wikipedia.org/wiki/Gamma_function).
|
| 84 |
+
|
| 85 |
+
Dirichlet-Multinomial is a [compound distribution](
|
| 86 |
+
https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e., its
|
| 87 |
+
samples are generated as follows.
|
| 88 |
+
|
| 89 |
+
1. Choose class probabilities:
|
| 90 |
+
`probs = [p_0,...,p_{K-1}] ~ Dir(concentration)`
|
| 91 |
+
2. Draw integers:
|
| 92 |
+
`counts = [n_0,...,n_{K-1}] ~ Multinomial(total_count, probs)`
|
| 93 |
+
|
| 94 |
+
The last `concentration` dimension parametrizes a single Dirichlet-Multinomial
|
| 95 |
+
distribution. When calling distribution functions (e.g., `dist.prob(counts)`),
|
| 96 |
+
`concentration`, `total_count` and `counts` are broadcast to the same shape.
|
| 97 |
+
The last dimension of `counts` corresponds single Dirichlet-Multinomial
|
| 98 |
+
distributions.
|
| 99 |
+
|
| 100 |
+
Distribution parameters are automatically broadcast in all functions; see
|
| 101 |
+
examples for details.
|
| 102 |
+
|
| 103 |
+
#### Pitfalls
|
| 104 |
+
|
| 105 |
+
The number of classes, `K`, must not exceed:
|
| 106 |
+
- the largest integer representable by `self.dtype`, i.e.,
|
| 107 |
+
`2**(mantissa_bits+1)` (IEE754),
|
| 108 |
+
- the maximum `Tensor` index, i.e., `2**31-1`.
|
| 109 |
+
|
| 110 |
+
In other words,
|
| 111 |
+
|
| 112 |
+
```python
|
| 113 |
+
K <= min(2**31-1, {
|
| 114 |
+
tf.float16: 2**11,
|
| 115 |
+
tf.float32: 2**24,
|
| 116 |
+
tf.float64: 2**53 }[param.dtype])
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
Note: This condition is validated only when `self.validate_args = True`.
|
| 120 |
+
|
| 121 |
+
#### Examples
|
| 122 |
+
|
| 123 |
+
```python
|
| 124 |
+
alpha = [1., 2., 3.]
|
| 125 |
+
n = 2.
|
| 126 |
+
dist = DirichletMultinomial(n, alpha)
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
Creates a 3-class distribution, with the 3rd class is most likely to be
|
| 130 |
+
drawn.
|
| 131 |
+
The distribution functions can be evaluated on counts.
|
| 132 |
+
|
| 133 |
+
```python
|
| 134 |
+
# counts same shape as alpha.
|
| 135 |
+
counts = [0., 0., 2.]
|
| 136 |
+
dist.prob(counts) # Shape []
|
| 137 |
+
|
| 138 |
+
# alpha will be broadcast to [[1., 2., 3.], [1., 2., 3.]] to match counts.
|
| 139 |
+
counts = [[1., 1., 0.], [1., 0., 1.]]
|
| 140 |
+
dist.prob(counts) # Shape [2]
|
| 141 |
+
|
| 142 |
+
# alpha will be broadcast to shape [5, 7, 3] to match counts.
|
| 143 |
+
counts = [[...]] # Shape [5, 7, 3]
|
| 144 |
+
dist.prob(counts) # Shape [5, 7]
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
Creates a 2-batch of 3-class distributions.
|
| 148 |
+
|
| 149 |
+
```python
|
| 150 |
+
alpha = [[1., 2., 3.], [4., 5., 6.]] # Shape [2, 3]
|
| 151 |
+
n = [3., 3.]
|
| 152 |
+
dist = DirichletMultinomial(n, alpha)
|
| 153 |
+
|
| 154 |
+
# counts will be broadcast to [[2., 1., 0.], [2., 1., 0.]] to match alpha.
|
| 155 |
+
counts = [2., 1., 0.]
|
| 156 |
+
dist.prob(counts) # Shape [2]
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
  # TODO(b/27419586) Change docstring for dtype of concentration once int
  # allowed.
  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               total_count,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="DirichletMultinomial"):
    """Initialize a batch of DirichletMultinomial distributions.

    Args:
      total_count: Non-negative floating point tensor, whose dtype is the same
        as `concentration`. The shape is broadcastable to `[N1,..., Nm]` with
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different
        Dirichlet multinomial distributions. Its components should be equal to
        integer values.
      concentration: Positive floating point tensor, whose dtype is the
        same as `n` with shape broadcastable to `[N1,..., Nm, K]` `m >= 0`.
        Defines this as a batch of `N1 x ... x Nm` different `K` class
        Dirichlet multinomial distributions.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, concentration]) as name:
      # Broadcasting works because:
      # * The broadcasting convention is to prepend dimensions of size [1], and
      #   we use the last dimension for the distribution, whereas
      #   the batch dimensions are the leading dimensions, which forces the
      #   distribution dimension to be defined explicitly (i.e. it cannot be
      #   created automatically by prepending). This forces enough explicitness.
      # * All calls involving `counts` eventually require a broadcast between
      #   `counts` and concentration.
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      self._concentration = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration,
                                name="concentration"),
          validate_args)
      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
    super(DirichletMultinomial, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._concentration],
        name=name)
  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def concentration(self):
    """Concentration parameter; expected prior counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return array_ops.shape(self.concentration)[-1:]

  def _event_shape(self):
    # Event shape depends only on total_concentration, not "n".
    return self.concentration.get_shape().with_rank_at_least(1)[-1:]
  def _sample_n(self, n, seed=None):
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]
    unnormalized_logits = array_ops.reshape(
        math_ops.log(random_ops.random_gamma(
            shape=[n],
            alpha=self.concentration,
            dtype=self.dtype,
            seed=seed)),
        shape=[-1, k])
    draws = random_ops.multinomial(
        logits=unnormalized_logits,
        num_samples=n_draws,
        seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2)
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    return math_ops.cast(x, self.dtype)
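The `_sample_n` above relies on the Gamma-to-Dirichlet trick: the logs of unnormalized `Gamma(alpha_j, 1)` draws are used as categorical logits, so the multinomial draws are implicitly conditioned on a Dirichlet sample. A minimal NumPy sketch of the same compound sampling scheme (the function name and scalar-batch signature here are illustrative, not part of the TF API):

```python
import numpy as np

def dirichlet_multinomial_sample(total_count, concentration, size, rng=None):
    """Draw `size` count vectors from a Dirichlet-multinomial (single batch)."""
    rng = np.random.default_rng(0) if rng is None else rng
    k = len(concentration)
    out = np.empty((size, k), dtype=np.int64)
    for i in range(size):
        # Gamma(alpha_j, 1) draws; normalizing them yields a Dirichlet sample.
        g = rng.gamma(concentration)
        p = g / g.sum()
        # Conditional on p, the counts are Multinomial(total_count, p).
        out[i] = rng.multinomial(total_count, p)
    return out

samples = dirichlet_multinomial_sample(10, np.array([1.0, 2.0, 3.0]), size=5)
```

Each returned row is a length-`k` count vector summing to `total_count`, matching the `[n] + batch_shape + [k]` layout the TF code reshapes into.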
  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _log_prob(self, counts):
    counts = self._maybe_assert_valid_sample(counts)
    ordered_prob = (
        special_math_ops.lbeta(self.concentration + counts)
        - special_math_ops.lbeta(self.concentration))
    return ordered_prob + distribution_util.log_combinations(
        self.total_count, counts)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _prob(self, counts):
    return math_ops.exp(self._log_prob(counts))

  def _mean(self):
    return self.total_count * (self.concentration /
                               self.total_concentration[..., array_ops.newaxis])
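The `_log_prob` and `_mean` above implement `lbeta(alpha + counts) - lbeta(alpha) + log_combinations(n, counts)` and `n * alpha / alpha_0`. A small pure-Python check of both identities (the helper names are illustrative, not the TF API):

```python
import math

def lbeta(xs):
    # Log multivariate beta: sum_j lgamma(x_j) - lgamma(sum_j x_j).
    return sum(math.lgamma(x) for x in xs) - math.lgamma(sum(xs))

def dm_log_prob(counts, alpha):
    n = sum(counts)
    # Log multinomial coefficient: log n! - sum_j log c_j!
    log_comb = math.lgamma(n + 1) - sum(math.lgamma(c + 1) for c in counts)
    return log_comb + lbeta([a + c for a, c in zip(alpha, counts)]) - lbeta(alpha)

alpha, n = [1.0, 2.0], 2
pmf = {c: math.exp(dm_log_prob([c, n - c], alpha)) for c in range(n + 1)}
total = sum(pmf.values())                   # probabilities over all counts
mean0 = sum(c * p for c, p in pmf.items())  # E[X_0], should be n * alpha_0 / sum(alpha)
```

With `alpha = [1, 2]` and `n = 2` the pmf sums to 1 and `E[X_0] = 2 * 1/3 = 2/3`, agreeing with `_mean`.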
  @distribution_util.AppendDocstring(
      """The covariance for each batch member is defined as the following:

      ```none
      Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
      (n + alpha_0) / (1 + alpha_0)
      ```

      where `concentration = alpha` and
      `total_concentration = alpha_0 = sum_j alpha_j`.

      The covariance between elements in a batch is defined as:

      ```none
      Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
      (n + alpha_0) / (1 + alpha_0)
      ```
      """)
  def _covariance(self):
    x = self._variance_scale_term() * self._mean()
    # pylint: disable=invalid-unary-operand-type
    return array_ops.matrix_set_diag(
        -math_ops.matmul(
            x[..., array_ops.newaxis],
            x[..., array_ops.newaxis, :]),  # outer prod
        self._variance())

  def _variance(self):
    scale = self._variance_scale_term()
    x = scale * self._mean()
    return x * (self.total_count * scale - x)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale."""
    # We must take care to expand back the last dim whenever we use the
    # total_concentration.
    c0 = self.total_concentration[..., array_ops.newaxis]
    return math_ops.sqrt((1. + c0 / self.total_count) / (1. + c0))
  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
      return concentration
    concentration = distribution_util.embed_check_categorical_event_shape(
        concentration)
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts last-dimension must sum to `self.total_count`"),
    ], counts)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/distribution.py
ADDED
@@ -0,0 +1,1316 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for probability distributions."""

import abc
import contextlib
import types

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "ReparameterizationType",
    "FULLY_REPARAMETERIZED",
    "NOT_REPARAMETERIZED",
    "Distribution",
]

_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
    "batch_shape",
    "batch_shape_tensor",
    "cdf",
    "covariance",
    "cross_entropy",
    "entropy",
    "event_shape",
    "event_shape_tensor",
    "kl_divergence",
    "log_cdf",
    "log_prob",
    "log_survival_function",
    "mean",
    "mode",
    "prob",
    "sample",
    "stddev",
    "survival_function",
    "variance",
]


class _BaseDistribution(metaclass=abc.ABCMeta):
  """Abstract base class needed for resolving subclass hierarchy."""
  pass


def _copy_fn(fn):
  """Create a deep copy of fn.

  Args:
    fn: a callable

  Returns:
    A `FunctionType`: a deep copy of fn.

  Raises:
    TypeError: if `fn` is not a callable.
  """
  if not callable(fn):
    raise TypeError("fn is not callable: %s" % fn)
  # The blessed way to copy a function. copy.deepcopy fails to create a
  # non-reference copy. Since:
  #   types.FunctionType == type(lambda: None),
  # and the docstring for the function type states:
  #
  #   function(code, globals[, name[, argdefs[, closure]]])
  #
  #   Create a function object from a code object and a dictionary.
  #   ...
  #
  # Here we can use this to create a new function with the old function's
  # code, globals, closure, etc.
  return types.FunctionType(
      code=fn.__code__, globals=fn.__globals__,
      name=fn.__name__, argdefs=fn.__defaults__,
      closure=fn.__closure__)

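`_copy_fn` above rebuilds a function from its code object so that the copy's `__doc__` can later be mutated without touching the original. A standalone sketch of the same `types.FunctionType` constructor call (the example function is hypothetical):

```python
import types

def scale(x, factor=2):
    """Multiply x by factor."""
    return x * factor

# Constructor-based copy, as in _copy_fn: a new function object sharing the
# original's code, globals, defaults, and closure.
scale_copy = types.FunctionType(
    code=scale.__code__, globals=scale.__globals__,
    name=scale.__name__, argdefs=scale.__defaults__,
    closure=scale.__closure__)
scale_copy.__doc__ = "An independently mutable docstring."
```

`copy.deepcopy` would return the same function object here, which is exactly why the module takes this route.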
def _update_docstring(old_str, append_str):
  """Update old_str by inserting append_str just before the "Args:" section."""
  old_str = old_str or ""
  old_str_lines = old_str.split("\n")

  # Step 0: Prepend spaces to all lines of append_str. This is
  # necessary for correct markdown generation.
  append_str = "\n".join("    %s" % line for line in append_str.split("\n"))

  # Step 1: Find mention of "Args":
  has_args_ix = [
      ix for ix, line in enumerate(old_str_lines)
      if line.strip().lower() == "args:"]
  if has_args_ix:
    final_args_ix = has_args_ix[-1]
    return ("\n".join(old_str_lines[:final_args_ix])
            + "\n\n" + append_str + "\n\n"
            + "\n".join(old_str_lines[final_args_ix:]))
  else:
    return old_str + "\n\n" + append_str

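`_update_docstring` above splices the appended text just before the last `Args:` line, or at the end when there is none. A self-contained restatement of that rule, for illustration (the function name is not part of the module):

```python
def insert_before_args(old_str, append_str):
    # Indent the appended text, as _update_docstring does for markdown output.
    append_str = "\n".join("    %s" % line for line in append_str.split("\n"))
    lines = (old_str or "").split("\n")
    # Find every line that is exactly "Args:" (case-insensitive, stripped).
    args_ix = [i for i, l in enumerate(lines) if l.strip().lower() == "args:"]
    if not args_ix:
        return (old_str or "") + "\n\n" + append_str
    i = args_ix[-1]
    return ("\n".join(lines[:i]) + "\n\n" + append_str + "\n\n"
            + "\n".join(lines[i:]))

doc = "Summary line.\n\nArgs:\n  x: input."
updated = insert_before_args(doc, "Extra note.")
```

Here the note lands between the summary and the `Args:` block, which keeps the generated API docs well-formed.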
def _convert_to_tensor(value, name=None, preferred_dtype=None):
  """Converts to tensor avoiding an eager bug that loses float precision."""
  # TODO(b/116672045): Remove this function.
  if (context.executing_eagerly() and preferred_dtype is not None and
      (preferred_dtype.is_integer or preferred_dtype.is_bool)):
    v = ops.convert_to_tensor(value, name=name)
    if v.dtype.is_floating:
      return v
  return ops.convert_to_tensor(
      value, name=name, preferred_dtype=preferred_dtype)


class _DistributionMeta(abc.ABCMeta):

  def __new__(mcs, classname, baseclasses, attrs):
    """Control the creation of subclasses of the Distribution class.

    The main purpose of this method is to properly propagate docstrings
    from private Distribution methods, like `_log_prob`, into their
    public wrappers as inherited by the Distribution base class
    (e.g. `log_prob`).

    Args:
      classname: The name of the subclass being created.
      baseclasses: A tuple of parent classes.
      attrs: A dict mapping new attributes to their values.

    Returns:
      The class object.

    Raises:
      TypeError: If `Distribution` is not a subclass of `BaseDistribution`, or
        the new class is derived via multiple inheritance and the first
        parent class is not a subclass of `BaseDistribution`.
      AttributeError: If `Distribution` does not implement e.g. `log_prob`.
      ValueError: If a `Distribution` public method lacks a docstring.
    """
    if not baseclasses:  # Nothing to be done for Distribution
      raise TypeError("Expected non-empty baseclass. Does Distribution "
                      "not subclass _BaseDistribution?")
    which_base = [
        base for base in baseclasses
        if base == _BaseDistribution or issubclass(base, Distribution)]
    base = which_base[0]
    if base == _BaseDistribution:  # Nothing to be done for Distribution
      return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
    if not issubclass(base, Distribution):
      raise TypeError("First parent class declared for %s must be "
                      "Distribution, but saw '%s'" % (classname, base.__name__))
    for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS:
      special_attr = "_%s" % attr
      class_attr_value = attrs.get(attr, None)
      if attr in attrs:
        # The method is being overridden, do not update its docstring
        continue
      base_attr_value = getattr(base, attr, None)
      if not base_attr_value:
        raise AttributeError(
            "Internal error: expected base class '%s' to implement method '%s'"
            % (base.__name__, attr))
      class_special_attr_value = attrs.get(special_attr, None)
      if class_special_attr_value is None:
        # No _special method available, no need to update the docstring.
        continue
      class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
      if not class_special_attr_docstring:
        # No docstring to append.
        continue
      class_attr_value = _copy_fn(base_attr_value)
      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
      if class_attr_docstring is None:
        raise ValueError(
            "Expected base class fn to contain a docstring: %s.%s"
            % (base.__name__, attr))
      class_attr_value.__doc__ = _update_docstring(
          class_attr_value.__doc__,
          ("Additional documentation from `%s`:\n\n%s"
           % (classname, class_special_attr_docstring)))
      attrs[attr] = class_attr_value

    return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)


@tf_export(v1=["distributions.ReparameterizationType"])
class ReparameterizationType:
  """Instances of this class represent how sampling is reparameterized.

  Two static instances exist in the distributions library, signifying
  one of two possible properties for samples from a distribution:

  `FULLY_REPARAMETERIZED`: Samples from the distribution are fully
    reparameterized, and straight-through gradients are supported.

  `NOT_REPARAMETERIZED`: Samples from the distribution are not fully
    reparameterized, and straight-through gradients are either partially
    unsupported or are not supported at all. In this case, for purposes of
    e.g. RL or variational inference, it is generally safest to wrap the
    sample results in a `stop_gradients` call and use policy
    gradients / surrogate loss instead.
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self, rep_type):
    self._rep_type = rep_type

  def __repr__(self):
    return "<Reparameterization Type: %s>" % self._rep_type

  def __eq__(self, other):
    """Determine if this `ReparameterizationType` is equal to another.

    Since ReparameterizationType instances are constant static global
    instances, equality checks if two instances' id() values are equal.

    Args:
      other: Object to compare against.

    Returns:
      `self is other`.
    """
    return self is other


# Fully reparameterized distribution: samples from a fully
# reparameterized distribution support straight-through gradients with
# respect to all parameters.
FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED")
tf_export(v1=["distributions.FULLY_REPARAMETERIZED"]).export_constant(
    __name__, "FULLY_REPARAMETERIZED")


# Not reparameterized distribution: samples from a non-
# reparameterized distribution do not support straight-through gradients for
# at least some of the parameters.
NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED")
tf_export(v1=["distributions.NOT_REPARAMETERIZED"]).export_constant(
    __name__, "NOT_REPARAMETERIZED")


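The two module-level constants above are singleton markers, so `__eq__` can reduce to object identity. A minimal standalone sketch of that pattern (the `Marker` class is illustrative, not part of the module):

```python
class Marker:
    """Singleton-style marker whose equality is object identity."""

    def __init__(self, name):
        self._name = name

    def __repr__(self):
        return "<Marker: %s>" % self._name

    def __eq__(self, other):
        # Constant global instances: equal iff they are the very same object.
        return self is other

FULLY = Marker("FULLY_REPARAMETERIZED")
NOT = Marker("NOT_REPARAMETERIZED")
```

Two markers with the same name would still compare unequal, which is the point: code is expected to compare against the shared module-level instances.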
@tf_export(v1=["distributions.Distribution"])
|
| 273 |
+
class Distribution(_BaseDistribution, metaclass=_DistributionMeta):
|
| 274 |
+
"""A generic probability distribution base class.
|
| 275 |
+
|
| 276 |
+
`Distribution` is a base class for constructing and organizing properties
|
| 277 |
+
(e.g., mean, variance) of random variables (e.g, Bernoulli, Gaussian).
|
| 278 |
+
|
| 279 |
+
#### Subclassing
|
| 280 |
+
|
| 281 |
+
Subclasses are expected to implement a leading-underscore version of the
|
| 282 |
+
same-named function. The argument signature should be identical except for
|
| 283 |
+
the omission of `name="..."`. For example, to enable `log_prob(value,
|
| 284 |
+
name="log_prob")` a subclass should implement `_log_prob(value)`.
|
| 285 |
+
|
| 286 |
+
Subclasses can append to public-level docstrings by providing
|
| 287 |
+
docstrings for their method specializations. For example:
|
| 288 |
+
|
| 289 |
+
```python
|
| 290 |
+
@util.AppendDocstring("Some other details.")
|
| 291 |
+
def _log_prob(self, value):
|
| 292 |
+
...
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
would add the string "Some other details." to the `log_prob` function
|
| 296 |
+
docstring. This is implemented as a simple decorator to avoid python
|
| 297 |
+
linter complaining about missing Args/Returns/Raises sections in the
|
| 298 |
+
partial docstrings.
|
| 299 |
+
|
| 300 |
+
#### Broadcasting, batching, and shapes
|
| 301 |
+
|
| 302 |
+
All distributions support batches of independent distributions of that type.
|
| 303 |
+
The batch shape is determined by broadcasting together the parameters.
|
| 304 |
+
|
| 305 |
+
The shape of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and
|
| 306 |
+
`log_prob` reflect this broadcasting, as does the return value of `sample` and
|
| 307 |
+
`sample_n`.
|
| 308 |
+
|
| 309 |
+
`sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is
|
| 310 |
+
the shape of the `Tensor` returned from `sample_n`, `n` is the number of
|
| 311 |
+
samples, `batch_shape` defines how many independent distributions there are,
|
| 312 |
+
and `event_shape` defines the shape of samples from each of those independent
|
| 313 |
+
distributions. Samples are independent along the `batch_shape` dimensions, but
|
| 314 |
+
  not necessarily so along the `event_shape` dimensions (depending on the
  particulars of the underlying distribution).

  Using the `Uniform` distribution as an example:

  ```python
  minval = 3.0
  maxval = [[4.0, 6.0],
            [10.0, 12.0]]

  # Broadcasting:
  # This instance represents 4 Uniform distributions. Each has a lower bound at
  # 3.0 as the `minval` parameter was broadcasted to match `maxval`'s shape.
  u = Uniform(minval, maxval)

  # `event_shape` is `TensorShape([])`.
  event_shape = u.event_shape
  # `event_shape_t` is a `Tensor` which will evaluate to [].
  event_shape_t = u.event_shape_tensor()

  # Sampling returns a sample per distribution. `samples` has shape
  # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5,
  # batch_shape=[2, 2], and event_shape=[].
  samples = u.sample_n(5)

  # The broadcasting holds across methods. Here we use `cdf` as an example. The
  # same holds for `log_cdf` and the likelihood functions.

  # `cum_prob` has shape [2, 2] as the `value` argument was broadcasted to the
  # shape of the `Uniform` instance.
  cum_prob_broadcast = u.cdf(4.0)

  # `cum_prob`'s shape is [2, 2], one per distribution. No broadcasting
  # occurred.
  cum_prob_per_dist = u.cdf([[4.0, 5.0],
                             [6.0, 7.0]])

  # INVALID as the `value` argument is not broadcastable to the distribution's
  # shape.
  cum_prob_invalid = u.cdf([4.0, 5.0, 6.0])
  ```

  #### Shapes

  There are three important concepts associated with TensorFlow Distributions
  shapes:
  - Event shape describes the shape of a single draw from the distribution;
    it may be dependent across dimensions. For scalar distributions, the event
    shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is
    `[5]`.
  - Batch shape describes independent, not identically distributed draws, aka a
    "collection" or "bunch" of distributions.
  - Sample shape describes independent, identically distributed draws of batches
    from the distribution family.

  The event shape and the batch shape are properties of a Distribution object,
  whereas the sample shape is associated with a specific call to `sample` or
  `log_prob`.

  For detailed usage examples of TensorFlow Distributions shapes, see
  [this tutorial](
  https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb)
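As a pure-Python sketch of how the three shapes compose (this mirrors the
shape that `sample` produces; no distribution class is needed for the
arithmetic itself):

```python
# A draw from `sample(sample_shape)` has shape
# sample_shape + batch_shape + event_shape.
sample_shape = [7, 5]
batch_shape = [2, 2]   # e.g. the 2x2 `maxval` Uniform example above
event_shape = []       # scalar distributions have an empty event shape
result_shape = list(sample_shape) + list(batch_shape) + list(event_shape)
print(result_shape)  # [7, 5, 2, 2]
```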
  #### Parameter values leading to undefined statistics or distributions.

  Some distributions do not have well-defined statistics for all initialization
  parameter values. For example, the beta distribution is parameterized by
  positive real numbers `concentration1` and `concentration0`, and does not have
  a well-defined mode if `concentration1 < 1` or `concentration0 < 1`.

  The user is given the option of raising an exception or returning `NaN`.

  ```python
  a = tf.exp(tf.matmul(logits, weights_a))
  b = tf.exp(tf.matmul(logits, weights_b))

  # Will raise exception if ANY batch member has a < 1 or b < 1.
  dist = distributions.beta(a, b, allow_nan_stats=False)
  mode = dist.mode().eval()

  # Will return NaN for batch members with either a < 1 or b < 1.
  dist = distributions.beta(a, b, allow_nan_stats=True)  # Default behavior
  mode = dist.mode().eval()
  ```

  In all cases, an exception is raised if *invalid* parameters are passed, e.g.

  ```python
  # Will raise an exception if any Op is run.
  negative_a = -1.0 * a  # beta distribution by definition has a > 0.
  dist = distributions.beta(negative_a, b, allow_nan_stats=True)
  dist.mean().eval()
  ```
  """
  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               dtype,
               reparameterization_type,
               validate_args,
               allow_nan_stats,
               parameters=None,
               graph_parents=None,
               name=None):
    """Constructs the `Distribution`.

    **This is a private method for subclass use.**

    Args:
      dtype: The type of the event samples. `None` implies no type-enforcement.
      reparameterization_type: Instance of `ReparameterizationType`.
        If `distributions.FULLY_REPARAMETERIZED`, this
        `Distribution` can be reparameterized in terms of some standard
        distribution with a function whose Jacobian is constant for the support
        of the standard distribution. If `distributions.NOT_REPARAMETERIZED`,
        then no such reparameterization is available.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      parameters: Python `dict` of parameters used to instantiate this
        `Distribution`.
      graph_parents: Python `list` of graph prerequisites of this
        `Distribution`.
      name: Python `str` name prefixed to Ops created by this class. Default:
        subclass name.

    Raises:
      ValueError: if any member of graph_parents is `None` or not a `Tensor`.
    """
    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not tensor_util.is_tf_type(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    if not name or name[-1] != "/":  # `name` is not a name scope
      non_unique_name = name or type(self).__name__
      with ops.name_scope(non_unique_name) as name:
        pass
    self._dtype = dtype
    self._reparameterization_type = reparameterization_type
    self._allow_nan_stats = allow_nan_stats
    self._validate_args = validate_args
    self._parameters = parameters or {}
    self._graph_parents = graph_parents
    self._name = name

  @property
  def _parameters(self):
    return self._parameter_dict

  @_parameters.setter
  def _parameters(self, value):
    """Intercept assignments to self._parameters to avoid reference cycles.

    Parameters are often created using locals(), so we need to clean out any
    references to `self` before assigning it to an attribute.

    Args:
      value: A dictionary of parameters to assign to the `_parameters` property.
    """
    if "self" in value:
      del value["self"]
    self._parameter_dict = value

  @classmethod
  def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
    """Shapes of parameters given the desired shape of a call to `sample()`.

    This is a class method that describes what key/value arguments are required
    to instantiate the given `Distribution` so that a particular shape is
    returned for that instance's call to `sample()`.

    Subclasses should override class method `_param_shapes`.

    Args:
      sample_shape: `Tensor` or python list/tuple. Desired shape of a call to
        `sample()`.
      name: name to prepend ops with.

    Returns:
      `dict` of parameter name to `Tensor` shapes.
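A minimal sketch of the contract, using a hypothetical scalar subclass
(`_ScalarSketch` and its parameter names are illustrative only, not part of
this module): each parameter of a scalar distribution must match the
requested sample shape.

```python
# Hypothetical subclass: a scalar distribution whose `loc` and `scale`
# parameters must each have the requested sample shape.
class _ScalarSketch(object):

  @staticmethod
  def _param_shapes(sample_shape):
    return {"loc": list(sample_shape), "scale": list(sample_shape)}

shapes = _ScalarSketch._param_shapes([100])
print(shapes)  # {'loc': [100], 'scale': [100]}
```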
    """
    with ops.name_scope(name, values=[sample_shape]):
      return cls._param_shapes(sample_shape)

  @classmethod
  def param_static_shapes(cls, sample_shape):
    """param_shapes with static (i.e. `TensorShape`) shapes.

    This is a class method that describes what key/value arguments are required
    to instantiate the given `Distribution` so that a particular shape is
    returned for that instance's call to `sample()`. Assumes that the sample's
    shape is known statically.

    Subclasses should override class method `_param_shapes` to return
    constant-valued tensors when constant values are fed.

    Args:
      sample_shape: `TensorShape` or python list/tuple. Desired shape of a call
        to `sample()`.

    Returns:
      `dict` of parameter name to `TensorShape`.

    Raises:
      ValueError: if `sample_shape` is a `TensorShape` and is not fully defined.
    """
    if isinstance(sample_shape, tensor_shape.TensorShape):
      if not sample_shape.is_fully_defined():
        raise ValueError("TensorShape sample_shape must be fully defined")
      sample_shape = sample_shape.as_list()

    params = cls.param_shapes(sample_shape)

    static_params = {}
    for name, shape in params.items():
      static_shape = tensor_util.constant_value(shape)
      if static_shape is None:
        raise ValueError(
            "sample_shape must be a fully-defined TensorShape or list/tuple")
      static_params[name] = tensor_shape.TensorShape(static_shape)

    return static_params

  @staticmethod
  def _param_shapes(sample_shape):
    raise NotImplementedError("_param_shapes not implemented")

  @property
  def name(self):
    """Name prepended to all ops created by this `Distribution`."""
    return self._name

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `Distribution`."""
    return self._dtype

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `Distribution`."""
    # Remove "self", "__class__", or other special variables. These can appear
    # if the subclass used:
    # `parameters = dict(locals())`.
    return {k: v for k, v in self._parameters.items()
            if not k.startswith("__") and k != "self"}

  @property
  def reparameterization_type(self):
    """Describes how samples from the distribution are reparameterized.

    Currently this is one of the static instances
    `distributions.FULLY_REPARAMETERIZED`
    or `distributions.NOT_REPARAMETERIZED`.

    Returns:
      An instance of `ReparameterizationType`.
    """
    return self._reparameterization_type

  @property
  def allow_nan_stats(self):
    """Python `bool` describing behavior when a stat is undefined.

    Stats return +/- infinity when it makes sense. E.g., the variance of a
    Cauchy distribution is infinity. However, sometimes the statistic is
    undefined, e.g., if a distribution's pdf does not achieve a maximum within
    the support of the distribution, the mode is undefined. If the mean is
    undefined, then by definition the variance is undefined. E.g. the mean for
    Student's T for df = 1 is undefined (no clear way to say it is either + or -
    infinity), so the variance = E[(X - mean)**2] is also undefined.

    Returns:
      allow_nan_stats: Python `bool`.
    """
    return self._allow_nan_stats

  @property
  def validate_args(self):
    """Python `bool` indicating possibly expensive checks are enabled."""
    return self._validate_args

  def copy(self, **override_parameters_kwargs):
    """Creates a deep copy of the distribution.

    Note: the copy distribution may continue to depend on the original
    initialization arguments.

    Args:
      **override_parameters_kwargs: String/value dictionary of initialization
        arguments to override with new values.

    Returns:
      distribution: A new instance of `type(self)` initialized from the union
        of self.parameters and override_parameters_kwargs, i.e.,
        `dict(self.parameters, **override_parameters_kwargs)`.
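The merge rule above is plain `dict` union with the override winning; a
standalone sketch (the parameter names are illustrative):

```python
# `copy(**overrides)` re-invokes the constructor with
# dict(self.parameters, **overrides): overrides win, the rest carry over.
parameters = {"loc": 0.0, "scale": 1.0, "validate_args": False}
overrides = {"scale": 2.0}
merged = dict(parameters, **overrides)
print(merged)  # {'loc': 0.0, 'scale': 2.0, 'validate_args': False}
```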
    """
    parameters = dict(self.parameters, **override_parameters_kwargs)
    return type(self)(**parameters)

  def _batch_shape_tensor(self):
    raise NotImplementedError(
        "batch_shape_tensor is not implemented: {}".format(type(self).__name__))

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of a single sample from a single event index as a 1-D `Tensor`.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Args:
      name: name to give to the op

    Returns:
      batch_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.batch_shape.is_fully_defined():
        return ops.convert_to_tensor(self.batch_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="batch_shape")
      return self._batch_shape_tensor()

  def _batch_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def batch_shape(self):
    """Shape of a single sample from a single event index as a `TensorShape`.

    May be partially defined or unknown.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Returns:
      batch_shape: `TensorShape`, possibly unknown.
    """
    return tensor_shape.as_shape(self._batch_shape())

  def _event_shape_tensor(self):
    raise NotImplementedError(
        "event_shape_tensor is not implemented: {}".format(type(self).__name__))

  def event_shape_tensor(self, name="event_shape_tensor"):
    """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

    Args:
      name: name to give to the op

    Returns:
      event_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.event_shape.is_fully_defined():
        return ops.convert_to_tensor(self.event_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="event_shape")
      return self._event_shape_tensor()

  def _event_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def event_shape(self):
    """Shape of a single sample from a single batch as a `TensorShape`.

    May be partially defined or unknown.

    Returns:
      event_shape: `TensorShape`, possibly unknown.
    """
    return tensor_shape.as_shape(self._event_shape())

  def is_scalar_event(self, name="is_scalar_event"):
    """Indicates that `event_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_event: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.event_shape, self.event_shape_tensor),
          name="is_scalar_event")

  def is_scalar_batch(self, name="is_scalar_batch"):
    """Indicates that `batch_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_batch: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.batch_shape, self.batch_shape_tensor),
          name="is_scalar_batch")

  def _sample_n(self, n, seed=None):
    raise NotImplementedError("sample_n is not implemented: {}".format(
        type(self).__name__))

  def _call_sample_n(self, sample_shape, seed, name, **kwargs):
    with self._name_scope(name, values=[sample_shape]):
      sample_shape = ops.convert_to_tensor(
          sample_shape, dtype=dtypes.int32, name="sample_shape")
      sample_shape, n = self._expand_sample_shape_to_vector(
          sample_shape, "sample_shape")
      samples = self._sample_n(n, seed, **kwargs)
      batch_event_shape = array_ops.shape(samples)[1:]
      final_shape = array_ops.concat([sample_shape, batch_event_shape], 0)
      samples = array_ops.reshape(samples, final_shape)
      samples = self._set_sample_static_shape(samples, sample_shape)
      return samples

  def sample(self, sample_shape=(), seed=None, name="sample"):
    """Generate samples of the specified shape.

    Note that a call to `sample()` without arguments will generate a single
    sample.

    Args:
      sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
      seed: Python integer seed for RNG
      name: name to give to the op.

    Returns:
      samples: a `Tensor` with prepended dimensions `sample_shape`.
    """
    return self._call_sample_n(sample_shape, seed, name)

  def _log_prob(self, value):
    raise NotImplementedError("log_prob is not implemented: {}".format(
        type(self).__name__))

  def _call_log_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._log_prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_prob(self, value, name="log_prob"):
    """Log probability density/mass function.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_log_prob(value, name)

  def _prob(self, value):
    raise NotImplementedError("prob is not implemented: {}".format(
        type(self).__name__))

  def _call_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def prob(self, value, name="prob"):
    """Probability density/mass function.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_prob(value, name)

  def _log_cdf(self, value):
    raise NotImplementedError("log_cdf is not implemented: {}".format(
        type(self).__name__))

  def _call_log_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._log_cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_cdf(self, value, name="log_cdf"):
    """Log cumulative distribution function.

    Given random variable `X`, the cumulative distribution function `cdf` is:

    ```none
    log_cdf(x) := Log[ P[X <= x] ]
    ```

    Often, a numerical approximation can be used for `log_cdf(x)` that yields
    a more accurate answer than simply taking the logarithm of the `cdf` when
    `x << -1`.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      logcdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_log_cdf(value, name)

  def _cdf(self, value):
    raise NotImplementedError("cdf is not implemented: {}".format(
        type(self).__name__))

  def _call_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def cdf(self, value, name="cdf"):
    """Cumulative distribution function.

    Given random variable `X`, the cumulative distribution function `cdf` is:

    ```none
    cdf(x) := P[X <= x]
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_cdf(value, name)

  def _log_survival_function(self, value):
    raise NotImplementedError(
        "log_survival_function is not implemented: {}".format(
            type(self).__name__))

  def _call_log_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._log_survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log1p(-self.cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_survival_function(self, value, name="log_survival_function"):
    """Log survival function.

    Given random variable `X`, the survival function is defined:

    ```none
    log_survival_function(x) = Log[ P[X > x] ]
                             = Log[ 1 - P[X <= x] ]
                             = Log[ 1 - cdf(x) ]
    ```

    Typically, different numerical approximations can be used for the log
    survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`.
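A concrete instance of the accuracy note, using the standard exponential
distribution as an illustration (its log survival function is exactly `-x`,
which is an assumption of this sketch, not something this class provides):

```python
import math

x = 40.0
# Exponential(rate=1): cdf(x) = 1 - exp(-x), so 1 - cdf(x) = exp(-x).
cdf = 1.0 - math.exp(-x)   # rounds to exactly 1.0 in float64
naive = 1.0 - cdf          # 0.0, so log(naive) would be -inf
exact_log_sf = -x          # the dedicated formula stays finite
print(naive, exact_log_sf)  # 0.0 -40.0
```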
    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
        `self.dtype`.
    """
    return self._call_log_survival_function(value, name)

  def _survival_function(self, value):
    raise NotImplementedError("survival_function is not implemented: {}".format(
        type(self).__name__))

  def _call_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return 1. - self.cdf(value, **kwargs)
        except NotImplementedError:
          raise original_exception

  def survival_function(self, value, name="survival_function"):
    """Survival function.

    Given random variable `X`, the survival function is defined:

    ```none
    survival_function(x) = P[X > x]
                         = 1 - P[X <= x]
                         = 1 - cdf(x).
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
        `self.dtype`.
    """
    return self._call_survival_function(value, name)

  def _entropy(self):
    raise NotImplementedError("entropy is not implemented: {}".format(
        type(self).__name__))

  def entropy(self, name="entropy"):
    """Shannon entropy in nats."""
    with self._name_scope(name):
      return self._entropy()

  def _mean(self):
    raise NotImplementedError("mean is not implemented: {}".format(
        type(self).__name__))

  def mean(self, name="mean"):
    """Mean."""
    with self._name_scope(name):
      return self._mean()

  def _quantile(self, value):
    raise NotImplementedError("quantile is not implemented: {}".format(
        type(self).__name__))

  def _call_quantile(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      return self._quantile(value, **kwargs)

  def quantile(self, value, name="quantile"):
    """Quantile function. Aka "inverse cdf" or "percent point function".

    Given random variable `X` and `p in [0, 1]`, the `quantile` is:

    ```none
    quantile(p) := x such that P[X <= x] == p
    ```
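As a standalone sketch (not this module's API), the quantile of a
`Uniform(minval, maxval)` has a closed form obtained by inverting
`cdf(x) = (x - minval) / (maxval - minval)`:

```python
def uniform_quantile(p, minval, maxval):
  # Solve (x - minval) / (maxval - minval) == p for x.
  return minval + p * (maxval - minval)

print(uniform_quantile(0.25, 3.0, 7.0))  # 4.0
```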
    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      quantile: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_quantile(value, name)

  def _variance(self):
    raise NotImplementedError("variance is not implemented: {}".format(
        type(self).__name__))

  def variance(self, name="variance"):
    """Variance.

    Variance is defined as,

    ```none
    Var = E[(X - E[X])**2]
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `Var.shape = batch_shape + event_shape`.
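For a distribution with finitely many outcomes the definition reduces to a
weighted sum; a standalone sketch (toy probabilities, chosen for
illustration):

```python
# Var = E[(X - E[X])**2] for a distribution on {0, 1, 2}.
outcomes = [0.0, 1.0, 2.0]
probs = [0.25, 0.5, 0.25]
mean = sum(p * x for p, x in zip(probs, outcomes))               # E[X]
var = sum(p * (x - mean) ** 2 for p, x in zip(probs, outcomes))  # E[(X-E[X])^2]
print(mean, var)  # 1.0 0.5
```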
    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      variance: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """
    with self._name_scope(name):
      try:
        return self._variance()
      except NotImplementedError as original_exception:
        try:
          return math_ops.square(self._stddev())
        except NotImplementedError:
          raise original_exception

  def _stddev(self):
    raise NotImplementedError("stddev is not implemented: {}".format(
        type(self).__name__))

  def stddev(self, name="stddev"):
    """Standard deviation.

    Standard deviation is defined as,

    ```none
    stddev = E[(X - E[X])**2]**0.5
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `stddev.shape = batch_shape + event_shape`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      stddev: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """
    with self._name_scope(name):
      try:
        return self._stddev()
      except NotImplementedError as original_exception:
        try:
          return math_ops.sqrt(self._variance())
        except NotImplementedError:
          raise original_exception

  def _covariance(self):
    raise NotImplementedError("covariance is not implemented: {}".format(
        type(self).__name__))

  def covariance(self, name="covariance"):
    """Covariance.

    Covariance is (possibly) defined only for non-scalar-event distributions.

    For example, for a length-`k`, vector-valued distribution, it is calculated
    as,

    ```none
    Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]
    ```

    where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E`
    denotes expectation.

    Alternatively, for non-vector, multivariate distributions (e.g.,
    matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices
    under some vectorization of the events, i.e.,

    ```none
    Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]
    ```

    where `Cov` is a (batch of) `k' x k'` matrices,
    `0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function
    mapping indices of this distribution's event dimensions to indices of a
    length-`k'` vector.
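For the vector case, the `k x k` matrix can be sketched with a sample
estimate of the expectation (toy draws below; subclasses of this class
supply an analytic `_covariance` instead):

```python
# Estimate Cov[i, j] = E[(X_i - E[X_i]) (X_j - E[X_j])] from draws of a
# length-2 event, replacing E[.] by the sample average.
draws = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]
n, k = len(draws), len(draws[0])
means = [sum(d[i] for d in draws) / n for i in range(k)]
cov = [[sum((d[i] - means[i]) * (d[j] - means[j]) for d in draws) / n
        for j in range(k)] for i in range(k)]
# cov is a 2 x 2 matrix with cov[0][0] == 2/3 and cov[0][1] == 4/3.
```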
    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']`
        where the first `n` dimensions are batch coordinates and
        `k' = reduce_prod(self.event_shape)`.
    """
    with self._name_scope(name):
      return self._covariance()

  def _mode(self):
    raise NotImplementedError("mode is not implemented: {}".format(
        type(self).__name__))

  def mode(self, name="mode"):
    """Mode."""
    with self._name_scope(name):
      return self._mode()

  def _cross_entropy(self, other):
    return kullback_leibler.cross_entropy(
        self, other, allow_nan_stats=self.allow_nan_stats)

  def cross_entropy(self, other, name="cross_entropy"):
    """Computes the (Shannon) cross entropy.

    Denote this distribution (`self`) by `P` and the `other` distribution by
    `Q`. Assuming `P, Q` are absolutely continuous with respect to
    one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, (Shannon)
    cross entropy is defined as:

    ```none
    H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
    ```

    where `F` denotes the support of the random variable `X ~ P`.
Args:
|
| 1154 |
+
other: `tfp.distributions.Distribution` instance.
|
| 1155 |
+
name: Python `str` prepended to names of ops created by this function.
|
| 1156 |
+
|
| 1157 |
+
Returns:
|
| 1158 |
+
cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
|
| 1159 |
+
representing `n` different calculations of (Shanon) cross entropy.
|
| 1160 |
+
"""
|
| 1161 |
+
with self._name_scope(name):
|
| 1162 |
+
return self._cross_entropy(other)
|
| 1163 |
+
|
| 1164 |
+
def _kl_divergence(self, other):
|
| 1165 |
+
return kullback_leibler.kl_divergence(
|
| 1166 |
+
self, other, allow_nan_stats=self.allow_nan_stats)
|
| 1167 |
+
|
| 1168 |
+
def kl_divergence(self, other, name="kl_divergence"):
|
| 1169 |
+
"""Computes the Kullback--Leibler divergence.
|
| 1170 |
+
|
| 1171 |
+
Denote this distribution (`self`) by `p` and the `other` distribution by
|
| 1172 |
+
`q`. Assuming `p, q` are absolutely continuous with respect to reference
|
| 1173 |
+
measure `r`, the KL divergence is defined as:
|
| 1174 |
+
|
| 1175 |
+
```none
|
| 1176 |
+
KL[p, q] = E_p[log(p(X)/q(X))]
|
| 1177 |
+
= -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
|
| 1178 |
+
= H[p, q] - H[p]
|
| 1179 |
+
```
|
| 1180 |
+
|
| 1181 |
+
where `F` denotes the support of the random variable `X ~ p`, `H[., .]`
|
| 1182 |
+
denotes (Shanon) cross entropy, and `H[.]` denotes (Shanon) entropy.
|
| 1183 |
+
|
| 1184 |
+
Args:
|
| 1185 |
+
other: `tfp.distributions.Distribution` instance.
|
| 1186 |
+
name: Python `str` prepended to names of ops created by this function.
|
| 1187 |
+
|
| 1188 |
+
Returns:
|
| 1189 |
+
kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
|
| 1190 |
+
representing `n` different calculations of the Kullback-Leibler
|
| 1191 |
+
divergence.
|
| 1192 |
+
"""
|
| 1193 |
+
with self._name_scope(name):
|
| 1194 |
+
return self._kl_divergence(other)
|
| 1195 |
+
|
| 1196 |
+
def __str__(self):
|
| 1197 |
+
return ("tfp.distributions.{type_name}("
|
| 1198 |
+
"\"{self_name}\""
|
| 1199 |
+
"{maybe_batch_shape}"
|
| 1200 |
+
"{maybe_event_shape}"
|
| 1201 |
+
", dtype={dtype})".format(
|
| 1202 |
+
type_name=type(self).__name__,
|
| 1203 |
+
self_name=self.name,
|
| 1204 |
+
maybe_batch_shape=(", batch_shape={}".format(self.batch_shape)
|
| 1205 |
+
if self.batch_shape.ndims is not None
|
| 1206 |
+
else ""),
|
| 1207 |
+
maybe_event_shape=(", event_shape={}".format(self.event_shape)
|
| 1208 |
+
if self.event_shape.ndims is not None
|
| 1209 |
+
else ""),
|
| 1210 |
+
dtype=self.dtype.name))
|
| 1211 |
+
|
| 1212 |
+
def __repr__(self):
|
| 1213 |
+
return ("<tfp.distributions.{type_name} "
|
| 1214 |
+
"'{self_name}'"
|
| 1215 |
+
" batch_shape={batch_shape}"
|
| 1216 |
+
" event_shape={event_shape}"
|
| 1217 |
+
" dtype={dtype}>".format(
|
| 1218 |
+
type_name=type(self).__name__,
|
| 1219 |
+
self_name=self.name,
|
| 1220 |
+
batch_shape=self.batch_shape,
|
| 1221 |
+
event_shape=self.event_shape,
|
| 1222 |
+
dtype=self.dtype.name))
|
| 1223 |
+
|
| 1224 |
+
@contextlib.contextmanager
|
| 1225 |
+
def _name_scope(self, name=None, values=None):
|
| 1226 |
+
"""Helper function to standardize op scope."""
|
| 1227 |
+
with ops.name_scope(self.name):
|
| 1228 |
+
with ops.name_scope(name, values=(
|
| 1229 |
+
([] if values is None else values) + self._graph_parents)) as scope:
|
| 1230 |
+
yield scope
|
| 1231 |
+
|
| 1232 |
+
def _expand_sample_shape_to_vector(self, x, name):
|
| 1233 |
+
"""Helper to `sample` which ensures input is 1D."""
|
| 1234 |
+
x_static_val = tensor_util.constant_value(x)
|
| 1235 |
+
if x_static_val is None:
|
| 1236 |
+
prod = math_ops.reduce_prod(x)
|
| 1237 |
+
else:
|
| 1238 |
+
prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype())
|
| 1239 |
+
|
| 1240 |
+
ndims = x.get_shape().ndims # != sample_ndims
|
| 1241 |
+
if ndims is None:
|
| 1242 |
+
# Maybe expand_dims.
|
| 1243 |
+
ndims = array_ops.rank(x)
|
| 1244 |
+
expanded_shape = util.pick_vector(
|
| 1245 |
+
math_ops.equal(ndims, 0),
|
| 1246 |
+
np.array([1], dtype=np.int32), array_ops.shape(x))
|
| 1247 |
+
x = array_ops.reshape(x, expanded_shape)
|
| 1248 |
+
elif ndims == 0:
|
| 1249 |
+
# Definitely expand_dims.
|
| 1250 |
+
if x_static_val is not None:
|
| 1251 |
+
x = ops.convert_to_tensor(
|
| 1252 |
+
np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()),
|
| 1253 |
+
name=name)
|
| 1254 |
+
else:
|
| 1255 |
+
x = array_ops.reshape(x, [1])
|
| 1256 |
+
elif ndims != 1:
|
| 1257 |
+
raise ValueError("Input is neither scalar nor vector.")
|
| 1258 |
+
|
| 1259 |
+
return x, prod
|
| 1260 |
+
|
| 1261 |
+
def _set_sample_static_shape(self, x, sample_shape):
|
| 1262 |
+
"""Helper to `sample`; sets static shape info."""
|
| 1263 |
+
# Set shape hints.
|
| 1264 |
+
sample_shape = tensor_shape.TensorShape(
|
| 1265 |
+
tensor_util.constant_value(sample_shape))
|
| 1266 |
+
|
| 1267 |
+
ndims = x.get_shape().ndims
|
| 1268 |
+
sample_ndims = sample_shape.ndims
|
| 1269 |
+
batch_ndims = self.batch_shape.ndims
|
| 1270 |
+
event_ndims = self.event_shape.ndims
|
| 1271 |
+
|
| 1272 |
+
# Infer rank(x).
|
| 1273 |
+
if (ndims is None and
|
| 1274 |
+
sample_ndims is not None and
|
| 1275 |
+
batch_ndims is not None and
|
| 1276 |
+
event_ndims is not None):
|
| 1277 |
+
ndims = sample_ndims + batch_ndims + event_ndims
|
| 1278 |
+
x.set_shape([None] * ndims)
|
| 1279 |
+
|
| 1280 |
+
# Infer sample shape.
|
| 1281 |
+
if ndims is not None and sample_ndims is not None:
|
| 1282 |
+
shape = sample_shape.concatenate([None]*(ndims - sample_ndims))
|
| 1283 |
+
x.set_shape(x.get_shape().merge_with(shape))
|
| 1284 |
+
|
| 1285 |
+
# Infer event shape.
|
| 1286 |
+
if ndims is not None and event_ndims is not None:
|
| 1287 |
+
shape = tensor_shape.TensorShape(
|
| 1288 |
+
[None]*(ndims - event_ndims)).concatenate(self.event_shape)
|
| 1289 |
+
x.set_shape(x.get_shape().merge_with(shape))
|
| 1290 |
+
|
| 1291 |
+
# Infer batch shape.
|
| 1292 |
+
if batch_ndims is not None:
|
| 1293 |
+
if ndims is not None:
|
| 1294 |
+
if sample_ndims is None and event_ndims is not None:
|
| 1295 |
+
sample_ndims = ndims - batch_ndims - event_ndims
|
| 1296 |
+
elif event_ndims is None and sample_ndims is not None:
|
| 1297 |
+
event_ndims = ndims - batch_ndims - sample_ndims
|
| 1298 |
+
if sample_ndims is not None and event_ndims is not None:
|
| 1299 |
+
shape = tensor_shape.TensorShape([None]*sample_ndims).concatenate(
|
| 1300 |
+
self.batch_shape).concatenate([None]*event_ndims)
|
| 1301 |
+
x.set_shape(x.get_shape().merge_with(shape))
|
| 1302 |
+
|
| 1303 |
+
return x
|
| 1304 |
+
|
| 1305 |
+
def _is_scalar_helper(self, static_shape, dynamic_shape_fn):
|
| 1306 |
+
"""Implementation for `is_scalar_batch` and `is_scalar_event`."""
|
| 1307 |
+
if static_shape.ndims is not None:
|
| 1308 |
+
return static_shape.ndims == 0
|
| 1309 |
+
shape = dynamic_shape_fn()
|
| 1310 |
+
if (shape.get_shape().ndims is not None and
|
| 1311 |
+
shape.get_shape().dims[0].value is not None):
|
| 1312 |
+
# If the static_shape is correctly written then we should never execute
|
| 1313 |
+
# this branch. We keep it just in case there's some unimagined corner
|
| 1314 |
+
# case.
|
| 1315 |
+
return shape.get_shape().as_list() == [0]
|
| 1316 |
+
return math_ops.equal(array_ops.shape(shape)[0], 0)
|
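The `stddev` method above prefers a subclass's `_stddev` hook and falls back to the square root of `_variance`, re-raising the original error only when neither hook exists. A minimal standalone sketch of that dispatch pattern (hypothetical class names, not the TensorFlow code):

```python
import math


class Base:
    """Mimics the Distribution fallback: stddev() derives from variance()."""

    def _stddev(self):
        raise NotImplementedError("stddev is not implemented")

    def _variance(self):
        raise NotImplementedError("variance is not implemented")

    def stddev(self):
        # Prefer a direct _stddev; otherwise derive it from _variance,
        # re-raising the *original* error if neither hook is provided.
        try:
            return self._stddev()
        except NotImplementedError as original_exception:
            try:
                return math.sqrt(self._variance())
            except NotImplementedError:
                raise original_exception


class VarianceOnly(Base):
    def _variance(self):
        return 4.0


print(VarianceOnly().stddev())  # 2.0
```

Re-raising `original_exception` (rather than the inner error) keeps the message pointing at the method the caller actually invoked.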
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/distributions.py
ADDED
@@ -0,0 +1,36 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core module for TensorFlow distribution objects and helpers."""
from tensorflow.python.util import deprecation


# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
with deprecation.silence():
  from tensorflow.python.ops.distributions.bernoulli import Bernoulli
  from tensorflow.python.ops.distributions.beta import Beta
  from tensorflow.python.ops.distributions.categorical import Categorical
  from tensorflow.python.ops.distributions.dirichlet import Dirichlet
  from tensorflow.python.ops.distributions.dirichlet_multinomial import DirichletMultinomial
  from tensorflow.python.ops.distributions.distribution import *
  from tensorflow.python.ops.distributions.exponential import Exponential
  from tensorflow.python.ops.distributions.gamma import Gamma
  from tensorflow.python.ops.distributions.kullback_leibler import *
  from tensorflow.python.ops.distributions.laplace import Laplace
  from tensorflow.python.ops.distributions.multinomial import Multinomial
  from tensorflow.python.ops.distributions.normal import Normal
  from tensorflow.python.ops.distributions.student_t import StudentT
  from tensorflow.python.ops.distributions.uniform import Uniform
# pylint: enable=wildcard-import,unused-import
del deprecation
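The module above performs its re-exports inside `deprecation.silence()` so that merely importing the deprecated classes does not emit warnings. A minimal sketch of that context-manager pattern built on the standard `warnings` module (a hypothetical `silence`, not TensorFlow's helper):

```python
import contextlib
import warnings


@contextlib.contextmanager
def silence():
    # Suppress DeprecationWarning within the block, mirroring how the
    # module wraps its wildcard imports in deprecation.silence().
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        yield


def old_api():
    warnings.warn("use new_api instead", DeprecationWarning)
    return 42


with silence():
    result = old_api()  # warning suppressed inside the block
print(result)  # 42
```

Because `catch_warnings` restores the previous filter state on exit, code outside the block still sees deprecation warnings as usual.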
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/exponential.py
ADDED
@@ -0,0 +1,162 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Exponential distribution class."""

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import gamma
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Exponential",
    "ExponentialWithSoftplusRate",
]


@tf_export(v1=["distributions.Exponential"])
class Exponential(gamma.Gamma):
  """Exponential distribution.

  The Exponential distribution is parameterized by an event `rate` parameter.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; lambda, x > 0) = exp(-lambda x) / Z
  Z = 1 / lambda
  ```

  where `rate = lambda` and `Z` is the normalizing constant.

  The Exponential distribution is a special case of the Gamma distribution,
  i.e.,

  ```python
  Exponential(rate) = Gamma(concentration=1., rate)
  ```

  The Exponential distribution uses a `rate` parameter, or "inverse scale",
  which can be intuited as,

  ```none
  X ~ Exponential(rate=1)
  Y = X / rate
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="Exponential"):
    """Construct Exponential distribution with parameter `rate`.

    Args:
      rate: Floating point tensor, equivalent to `1 / mean`. Must contain only
        positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    # Even though all statistics of an Exponential are defined for valid
    # inputs, this is not true in the parent class "Gamma." Therefore, passing
    # allow_nan_stats=True through to the parent class results in unnecessary
    # asserts.
    with ops.name_scope(name, values=[rate]) as name:
      self._rate = ops.convert_to_tensor(rate, name="rate")
      super(Exponential, self).__init__(
          concentration=array_ops.ones([], dtype=self._rate.dtype),
          rate=self._rate,
          allow_nan_stats=allow_nan_stats,
          validate_args=validate_args,
          name=name)
    self._parameters = parameters
    self._graph_parents += [self._rate]

  @staticmethod
  def _param_shapes(sample_shape):
    return {"rate": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}

  @property
  def rate(self):
    return self._rate

  def _log_survival_function(self, value):
    return self._log_prob(value) - math_ops.log(self._rate)

  def _sample_n(self, n, seed=None):
    shape = array_ops.concat([[n], array_ops.shape(self._rate)], 0)
    # Uniform variates must be sampled from the open-interval `(0, 1)` rather
    # than `[0, 1)`. To do so, we use `np.finfo(self.dtype.as_numpy_dtype).tiny`
    # because it is the smallest, positive, "normal" number. A "normal" number
    # is such that the mantissa has an implicit leading 1. Normal, positive
    # numbers x, y have the reasonable property that, `x + y >= max(x, y)`. In
    # this case, a subnormal number (i.e., np.nextafter) can cause us to sample
    # 0.
    sampled = random_ops.random_uniform(
        shape,
        minval=np.finfo(self.dtype.as_numpy_dtype).tiny,
        maxval=1.,
        seed=seed,
        dtype=self.dtype)
    return -math_ops.log(sampled) / self._rate


class ExponentialWithSoftplusRate(Exponential):
  """Exponential with softplus transform on `rate`."""

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Exponential(tf.nn.softplus(rate))`.",
      warn_once=True)
  def __init__(self,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="ExponentialWithSoftplusRate"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[rate]) as name:
      super(ExponentialWithSoftplusRate, self).__init__(
          rate=nn.softplus(rate, name="softplus_rate"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters
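`_sample_n` above is inverse-transform sampling: draw `U ~ Uniform(0, 1)` and return `-log(U) / rate`, which is distributed as `Exponential(rate)`. A standalone pure-Python check of that identity (a sketch with an arbitrary `rate = 2.0`, using a small lower bound in place of `np.finfo(dtype).tiny` to keep `log(U)` finite):

```python
import math
import random

random.seed(0)
rate = 2.0
n = 200_000

# Inverse-CDF trick: if U ~ Uniform(0, 1), then -log(U) / rate ~ Exponential(rate).
# Bounding U away from 0 mirrors the minval=tiny choice in _sample_n above.
samples = [-math.log(random.uniform(1e-300, 1.0)) / rate for _ in range(n)]

mean = sum(samples) / n
var = sum((s - mean) ** 2 for s in samples) / n
print(mean)  # ≈ 1 / rate = 0.5
print(var)   # ≈ 1 / rate**2 = 0.25
```

The empirical mean and variance converge to `1/rate` and `1/rate**2`, the known Exponential moments, which is consistent with the `rate = 1 / mean` documentation in the constructor.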
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/gamma.py
ADDED
@@ -0,0 +1,338 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Gamma distribution class."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Gamma",
    "GammaWithSoftplusConcentrationRate",
]


@tf_export(v1=["distributions.Gamma"])
class Gamma(distribution.Distribution):
  """Gamma distribution.

  The Gamma distribution is defined over positive real numbers using
  parameters `concentration` (aka "alpha") and `rate` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta, x > 0) = x**(alpha - 1) exp(-x beta) / Z
  Z = Gamma(alpha) beta**(-alpha)
  ```

  where:

  * `concentration = alpha`, `alpha > 0`,
  * `rate = beta`, `beta > 0`,
  * `Z` is the normalizing constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The cumulative distribution function (cdf) is,

  ```none
  cdf(x; alpha, beta, x > 0) = GammaInc(alpha, beta x) / Gamma(alpha)
  ```

  where `GammaInc` is the [lower incomplete Gamma function](
  https://en.wikipedia.org/wiki/Incomplete_gamma_function).

  The parameters can be intuited via their relationship to mean and stddev,

  ```none
  concentration = alpha = (mean / stddev)**2
  rate = beta = mean / stddev**2 = concentration / mean
  ```

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: The samples of this distribution are always non-negative. However,
  samples smaller than `np.finfo(dtype).tiny` are rounded to that value, so it
  appears more often than it should. This should only be noticeable when the
  `concentration` is very small, or the `rate` is very large. See the note in
  the `tf.random.gamma` docstring.

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  dist = tfd.Gamma(concentration=3.0, rate=2.0)
  dist2 = tfd.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  concentration = tf.constant(3.0)
  rate = tf.constant(2.0)
  dist = tfd.Gamma(concentration, rate)
  samples = dist.sample(5)  # Shape [5]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, [concentration, rate])
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf](http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="Gamma"):
    """Construct Gamma with `concentration` and `rate` parameters.

    The parameters `concentration` and `rate` must be shaped in a way that
    supports broadcasting (e.g. `concentration + rate` is a valid operation).

    Args:
      concentration: Floating point tensor, the concentration params of the
        distribution(s). Must contain only positive values.
      rate: Floating point tensor, the inverse scale params of the
        distribution(s). Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `concentration` and `rate` are different dtypes.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration, rate]) as name:
      with ops.control_dependencies([
          check_ops.assert_positive(concentration),
          check_ops.assert_positive(rate),
      ] if validate_args else []):
        self._concentration = array_ops.identity(
            concentration, name="concentration")
        self._rate = array_ops.identity(rate, name="rate")
        check_ops.assert_same_float_dtype(
            [self._concentration, self._rate])
    super(Gamma, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration,
                       self._rate],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(
        zip(("concentration", "rate"), ([ops.convert_to_tensor(
            sample_shape, dtype=dtypes.int32)] * 2)))

  @property
  def concentration(self):
    """Concentration parameter."""
    return self._concentration

  @property
  def rate(self):
    """Rate parameter."""
    return self._rate

  def _batch_shape_tensor(self):
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.concentration),
        array_ops.shape(self.rate))

  def _batch_shape(self):
    return array_ops.broadcast_static_shape(
        self.concentration.get_shape(),
        self.rate.get_shape())

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  @distribution_util.AppendDocstring(
      """Note: See `tf.random.gamma` docstring for sampling details and
      caveats.""")
  def _sample_n(self, n, seed=None):
    return random_ops.random_gamma(
        shape=[n],
        alpha=self.concentration,
        beta=self.rate,
        dtype=self.dtype,
        seed=seed)

  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  def _cdf(self, x):
    x = self._maybe_assert_valid_sample(x)
    # Note that igamma returns the regularized incomplete gamma function,
    # which is what we want for the CDF.
    return math_ops.igamma(self.concentration, self.rate * x)

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    return math_ops.xlogy(self.concentration - 1., x) - self.rate * x

  def _log_normalization(self):
    return (math_ops.lgamma(self.concentration)
            - self.concentration * math_ops.log(self.rate))

  def _entropy(self):
    return (self.concentration
            - math_ops.log(self.rate)
            + math_ops.lgamma(self.concentration)
            + ((1. - self.concentration) *
               math_ops.digamma(self.concentration)))

  def _mean(self):
    return self.concentration / self.rate

  def _variance(self):
    return self.concentration / math_ops.square(self.rate)

  def _stddev(self):
    return math_ops.sqrt(self.concentration) / self.rate

  @distribution_util.AppendDocstring(
      """The mode of a gamma distribution is `(shape - 1) / rate` when
      `shape > 1`, and `NaN` otherwise. If `self.allow_nan_stats` is `False`,
      an exception will be raised rather than returning `NaN`.""")
  def _mode(self):
    mode = (self.concentration - 1.) / self.rate
    if self.allow_nan_stats:
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      return array_ops.where_v2(self.concentration > 1., mode, nan)
    else:
      return control_flow_ops.with_dependencies([
          check_ops.assert_less(
              array_ops.ones([], self.dtype),
              self.concentration,
              message="mode not defined when any concentration <= 1"),
      ], mode)

  def _maybe_assert_valid_sample(self, x):
    check_ops.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x),
    ], x)


class GammaWithSoftplusConcentrationRate(Gamma):
  """`Gamma` with softplus of `concentration` and `rate`."""

  @deprecation.deprecated(
|
| 291 |
+
"2019-01-01",
|
| 292 |
+
"Use `tfd.Gamma(tf.nn.softplus(concentration), "
|
| 293 |
+
"tf.nn.softplus(rate))` instead.",
|
| 294 |
+
warn_once=True)
|
| 295 |
+
def __init__(self,
|
| 296 |
+
concentration,
|
| 297 |
+
rate,
|
| 298 |
+
validate_args=False,
|
| 299 |
+
allow_nan_stats=True,
|
| 300 |
+
name="GammaWithSoftplusConcentrationRate"):
|
| 301 |
+
parameters = dict(locals())
|
| 302 |
+
with ops.name_scope(name, values=[concentration, rate]) as name:
|
| 303 |
+
super(GammaWithSoftplusConcentrationRate, self).__init__(
|
| 304 |
+
concentration=nn.softplus(concentration,
|
| 305 |
+
name="softplus_concentration"),
|
| 306 |
+
rate=nn.softplus(rate, name="softplus_rate"),
|
| 307 |
+
validate_args=validate_args,
|
| 308 |
+
allow_nan_stats=allow_nan_stats,
|
| 309 |
+
name=name)
|
| 310 |
+
self._parameters = parameters
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
@kullback_leibler.RegisterKL(Gamma, Gamma)
|
| 314 |
+
def _kl_gamma_gamma(g0, g1, name=None):
|
| 315 |
+
"""Calculate the batched KL divergence KL(g0 || g1) with g0 and g1 Gamma.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
g0: instance of a Gamma distribution object.
|
| 319 |
+
g1: instance of a Gamma distribution object.
|
| 320 |
+
name: (optional) Name to use for created operations.
|
| 321 |
+
Default is "kl_gamma_gamma".
|
| 322 |
+
|
| 323 |
+
Returns:
|
| 324 |
+
kl_gamma_gamma: `Tensor`. The batchwise KL(g0 || g1).
|
| 325 |
+
"""
|
| 326 |
+
with ops.name_scope(name, "kl_gamma_gamma", values=[
|
| 327 |
+
g0.concentration, g0.rate, g1.concentration, g1.rate]):
|
| 328 |
+
# Result from:
|
| 329 |
+
# http://www.fil.ion.ucl.ac.uk/~wpenny/publications/densities.ps
|
| 330 |
+
# For derivation see:
|
| 331 |
+
# http://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions pylint: disable=line-too-long
|
| 332 |
+
return (((g0.concentration - g1.concentration)
|
| 333 |
+
* math_ops.digamma(g0.concentration))
|
| 334 |
+
+ math_ops.lgamma(g1.concentration)
|
| 335 |
+
- math_ops.lgamma(g0.concentration)
|
| 336 |
+
+ g1.concentration * math_ops.log(g0.rate)
|
| 337 |
+
- g1.concentration * math_ops.log(g1.rate)
|
| 338 |
+
+ g0.concentration * (g1.rate / g0.rate - 1.))
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/identity_bijector.py
ADDED
|
@@ -0,0 +1,68 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Identity bijector."""

from tensorflow.python.framework import constant_op
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.util import deprecation


__all__ = [
    "Identity",
]


class Identity(bijector.Bijector):
  """Compute Y = g(X) = X.

  Example Use:

  ```python
  # Create the Y=g(X)=X transform which is intended for Tensors with 1 batch
  # ndim and 1 event ndim (i.e., vector of vectors).
  identity = Identity()
  x = [[1., 2],
       [3, 4]]
  x == identity.forward(x) == identity.inverse(x)
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self, validate_args=False, name="identity"):
    super(Identity, self).__init__(
        forward_min_event_ndims=0,
        is_constant_jacobian=True,
        validate_args=validate_args,
        name=name)

  def _forward(self, x):
    return x

  def _inverse(self, y):
    return y

  def _inverse_log_det_jacobian(self, y):
    return constant_op.constant(0., dtype=y.dtype)

  def _forward_log_det_jacobian(self, x):
    return constant_op.constant(0., dtype=x.dtype)
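The bijector contract (forward map, inverse map, and log-determinant of the Jacobian) can be illustrated without TensorFlow. This minimal sketch mirrors the `Identity` bijector above, where the Jacobian of `Y = X` is 1 everywhere, so both log-det terms are exactly 0; the class name is ours, not a library API:

```python
class IdentityBijectorSketch:
    """Pure-Python sketch of the identity bijector Y = g(X) = X."""

    def forward(self, x):
        return x

    def inverse(self, y):
        return y

    def forward_log_det_jacobian(self, x):
        # d(g(x))/dx = 1, so log|det J| = 0 for every input.
        return 0.0

    def inverse_log_det_jacobian(self, y):
        # The inverse is also the identity; its log-det is 0 as well.
        return 0.0
```

Because the log-det is 0, the change-of-variables density `p_Y(y) = p_X(inverse(y)) * exp(inverse_log_det_jacobian(y))` reduces to `p_X(y)`: the identity bijector leaves densities unchanged.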
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/kullback_leibler.py
ADDED
|
@@ -0,0 +1,210 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registration and usage mechanisms for KL-divergences."""

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


_DIVERGENCES = {}


__all__ = [
    "RegisterKL",
    "kl_divergence",
]


def _registered_kl(type_a, type_b):
  """Get the KL function registered for classes a and b."""
  hierarchy_a = tf_inspect.getmro(type_a)
  hierarchy_b = tf_inspect.getmro(type_b)
  dist_to_children = None
  kl_fn = None
  for mro_to_a, parent_a in enumerate(hierarchy_a):
    for mro_to_b, parent_b in enumerate(hierarchy_b):
      candidate_dist = mro_to_a + mro_to_b
      candidate_kl_fn = _DIVERGENCES.get((parent_a, parent_b), None)
      if not kl_fn or (candidate_kl_fn and candidate_dist < dist_to_children):
        dist_to_children = candidate_dist
        kl_fn = candidate_kl_fn
  return kl_fn


@deprecation.deprecated(
    "2019-01-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.distributions`.",
    warn_once=True)
@tf_export(v1=["distributions.kl_divergence"])
def kl_divergence(distribution_a, distribution_b,
                  allow_nan_stats=True, name=None):
  """Get the KL-divergence KL(distribution_a || distribution_b).

  If there is no KL method registered specifically for `type(distribution_a)`
  and `type(distribution_b)`, then the class hierarchies of these types are
  searched.

  If one KL method is registered between any pairs of classes in these two
  parent hierarchies, it is used.

  If more than one such registered method exists, the method whose registered
  classes have the shortest sum MRO paths to the input types is used.

  If more than one such shortest path exists, the first method
  identified in the search is used (favoring a shorter MRO distance to
  `type(distribution_a)`).

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    A Tensor with the batchwise KL-divergence between `distribution_a`
    and `distribution_b`.

  Raises:
    NotImplementedError: If no KL method is defined for distribution types
      of `distribution_a` and `distribution_b`.
  """
  kl_fn = _registered_kl(type(distribution_a), type(distribution_b))
  if kl_fn is None:
    raise NotImplementedError(
        "No KL(distribution_a || distribution_b) registered for distribution_a "
        "type %s and distribution_b type %s"
        % (type(distribution_a).__name__, type(distribution_b).__name__))

  with ops.name_scope("KullbackLeibler"):
    kl_t = kl_fn(distribution_a, distribution_b, name=name)
    if allow_nan_stats:
      return kl_t

    # Check KL for NaNs
    kl_t = array_ops.identity(kl_t, name="kl")

    with ops.control_dependencies([
        control_flow_assert.Assert(
            math_ops.logical_not(math_ops.reduce_any(math_ops.is_nan(kl_t))), [
                "KL calculation between %s and %s returned NaN values "
                "(and was called with allow_nan_stats=False). Values:" %
                (distribution_a.name, distribution_b.name), kl_t
            ])
    ]):
      return array_ops.identity(kl_t, name="checked_kl")


@deprecation.deprecated(
    "2019-01-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.distributions`.",
    warn_once=True)
def cross_entropy(ref, other,
                  allow_nan_stats=True, name=None):
  """Computes the (Shannon) cross entropy.

  Denote two distributions by `P` (`ref`) and `Q` (`other`). Assuming `P, Q`
  are absolutely continuous with respect to one another and permit densities
  `p(x) dr(x)` and `q(x) dr(x)`, (Shannon) cross entropy is defined as:

  ```none
  H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
  ```

  where `F` denotes the support of the random variable `X ~ P`.

  Args:
    ref: `tfd.Distribution` instance.
    other: `tfd.Distribution` instance.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    cross_entropy: `ref.dtype` `Tensor` with shape `[B1, ..., Bn]`
      representing `n` different calculations of (Shannon) cross entropy.
  """
  with ops.name_scope(name, "cross_entropy"):
    return ref.entropy() + kl_divergence(
        ref, other, allow_nan_stats=allow_nan_stats)


@tf_export(v1=["distributions.RegisterKL"])
class RegisterKL:
  """Decorator to register a KL divergence implementation function.

  Usage:

  @distributions.RegisterKL(distributions.Normal, distributions.Normal)
  def _kl_normal_mvn(norm_a, norm_b):
    # Return KL(norm_a || norm_b)
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self, dist_cls_a, dist_cls_b):
    """Initialize the KL registrar.

    Args:
      dist_cls_a: the class of the first argument of the KL divergence.
      dist_cls_b: the class of the second argument of the KL divergence.
    """
    self._key = (dist_cls_a, dist_cls_b)

  def __call__(self, kl_fn):
    """Perform the KL registration.

    Args:
      kl_fn: The function to use for the KL divergence.

    Returns:
      kl_fn

    Raises:
      TypeError: if kl_fn is not a callable.
      ValueError: if a KL divergence function has already been registered for
        the given argument classes.
    """
    if not callable(kl_fn):
      raise TypeError("kl_fn must be callable, received: %s" % kl_fn)
    if self._key in _DIVERGENCES:
      raise ValueError("KL(%s || %s) has already been registered to: %s"
                       % (self._key[0].__name__, self._key[1].__name__,
                          _DIVERGENCES[self._key]))
    _DIVERGENCES[self._key] = kl_fn
    return kl_fn
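The registry mechanism in `kullback_leibler.py` resolves the most specific KL implementation by minimizing the summed MRO (method-resolution-order) distance of the registered class pair to the query types. A pure-Python sketch of the same idea, with made-up class names and a slightly simplified lookup loop, shows how a KL registered for a base class is found for subclasses, while a more specific registration wins:

```python
import inspect

_REGISTRY = {}


def register_kl(cls_a, cls_b):
    """Decorator registering a KL function for a pair of classes."""
    def wrap(fn):
        _REGISTRY[(cls_a, cls_b)] = fn
        return fn
    return wrap


def registered_kl(type_a, type_b):
    """Return the registered KL whose class pair has the shortest
    summed MRO distance to (type_a, type_b), or None."""
    best_dist, best_fn = None, None
    for i, parent_a in enumerate(inspect.getmro(type_a)):
        for j, parent_b in enumerate(inspect.getmro(type_b)):
            fn = _REGISTRY.get((parent_a, parent_b))
            if fn is not None and (best_fn is None or i + j < best_dist):
                best_dist, best_fn = i + j, fn
    return best_fn


# Hypothetical class hierarchy for illustration.
class Dist: pass
class Normal(Dist): pass
class StudentT(Normal): pass


@register_kl(Dist, Dist)
def kl_generic(a, b):
    return "generic"


@register_kl(Normal, Normal)
def kl_normal(a, b):
    return "normal"
```

Querying `registered_kl(StudentT, Normal)` walks StudentT's MRO, finds `(Normal, Normal)` at distance 1 and `(Dist, Dist)` at distance 3, and returns the more specific `kl_normal`.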
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/laplace.py
ADDED
|
@@ -0,0 +1,238 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Laplace distribution class."""

import math

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Laplace",
    "LaplaceWithSoftplusScale",
]


@tf_export(v1=["distributions.Laplace"])
class Laplace(distribution.Distribution):
  """The Laplace distribution with location `loc` and `scale` parameters.

  #### Mathematical details

  The probability density function (pdf) of this distribution is,

  ```none
  pdf(x; mu, sigma) = exp(-|x - mu| / sigma) / Z
  Z = 2 sigma
  ```

  where `loc = mu`, `scale = sigma`, and `Z` is the normalization constant.

  Note that the Laplace distribution can be thought of as two exponential
  distributions spliced together "back-to-back."

  The Laplace distribution is a member of the [location-scale family](
  https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed as,

  ```none
  X ~ Laplace(loc=0, scale=1)
  Y = loc + scale * X
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               loc,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="Laplace"):
    """Construct Laplace distribution with parameters `loc` and `scale`.

    The parameters `loc` and `scale` must be shaped in a way that supports
    broadcasting (e.g., `loc / scale` is a valid operation).

    Args:
      loc: Floating point tensor which characterizes the location (center)
        of the distribution.
      scale: Positive floating point tensor which characterizes the spread of
        the distribution.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `loc` and `scale` are of different dtype.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[loc, scale]) as name:
      with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                    validate_args else []):
        self._loc = array_ops.identity(loc, name="loc")
        self._scale = array_ops.identity(scale, name="scale")
        check_ops.assert_same_float_dtype([self._loc, self._scale])
      super(Laplace, self).__init__(
          dtype=self._loc.dtype,
          reparameterization_type=distribution.FULLY_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=[self._loc, self._scale],
          name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(
        zip(("loc", "scale"), ([ops.convert_to_tensor(
            sample_shape, dtype=dtypes.int32)] * 2)))

  @property
  def loc(self):
    """Distribution parameter for the location."""
    return self._loc

  @property
  def scale(self):
    """Distribution parameter for scale."""
    return self._scale

  def _batch_shape_tensor(self):
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.loc), array_ops.shape(self.scale))

  def _batch_shape(self):
    return array_ops.broadcast_static_shape(
        self.loc.get_shape(), self.scale.get_shape())

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    # Uniform variates must be sampled from the open-interval `(-1, 1)` rather
    # than `[-1, 1)`. In the case of `(0, 1)` we'd use
    # `np.finfo(self.dtype.as_numpy_dtype).tiny` because it is the smallest,
    # positive, "normal" number. However, the concept of subnormality exists
    # only at zero; here we need the smallest usable number larger than -1,
    # i.e., `-1 + eps/2`.
    uniform_samples = random_ops.random_uniform(
        shape=shape,
        minval=np.nextafter(self.dtype.as_numpy_dtype(-1.),
                            self.dtype.as_numpy_dtype(0.)),
        maxval=1.,
        dtype=self.dtype,
        seed=seed)
    return (self.loc - self.scale * math_ops.sign(uniform_samples) *
            math_ops.log1p(-math_ops.abs(uniform_samples)))

  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  def _log_cdf(self, x):
    return special_math.log_cdf_laplace(self._z(x))

  def _log_survival_function(self, x):
    return special_math.log_cdf_laplace(-self._z(x))

  def _cdf(self, x):
    z = self._z(x)
    return (0.5 + 0.5 * math_ops.sign(z) *
            (1. - math_ops.exp(-math_ops.abs(z))))

  def _log_unnormalized_prob(self, x):
    return -math_ops.abs(self._z(x))

  def _log_normalization(self):
    return math.log(2.) + math_ops.log(self.scale)

  def _entropy(self):
    # Use broadcasting rules to calculate the full broadcast scale.
    scale = self.scale + array_ops.zeros_like(self.loc)
    return math.log(2.) + 1. + math_ops.log(scale)

  def _mean(self):
    return self.loc + array_ops.zeros_like(self.scale)

  def _stddev(self):
    return math.sqrt(2.) * self.scale + array_ops.zeros_like(self.loc)

  def _median(self):
    return self._mean()

  def _mode(self):
    return self._mean()

  def _z(self, x):
    return (x - self.loc) / self.scale


class LaplaceWithSoftplusScale(Laplace):
  """Laplace with softplus applied to `scale`."""

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Laplace(loc, tf.nn.softplus(scale)) "
      "instead.",
      warn_once=True)
  def __init__(self,
               loc,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="LaplaceWithSoftplusScale"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[loc, scale]) as name:
      super(LaplaceWithSoftplusScale, self).__init__(
          loc=loc,
          scale=nn.softplus(scale, name="softplus_scale"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters
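The `_sample_n` method above draws Laplace variates by inverse-CDF transform: for `u ~ Uniform(-1, 1)`, `x = loc - scale * sign(u) * log1p(-|u|)` is Laplace-distributed. The same transform can be sketched in pure Python (stdlib only, no TensorFlow); `sample_laplace` is our helper name for illustration:

```python
import math
import random


def sample_laplace(loc, scale, n, rng):
    """Draw n Laplace(loc, scale) samples via the inverse-CDF transform:
    u ~ Uniform(-1, 1);  x = loc - scale * sign(u) * log1p(-|u|)."""
    out = []
    for _ in range(n):
        u = rng.uniform(-1.0, 1.0)
        sign = 1.0 if u >= 0 else -1.0
        # log1p(-|u|) = log(1 - |u|) is numerically stable near u = 0.
        out.append(loc - scale * sign * math.log1p(-abs(u)))
    return out


rng = random.Random(0)
samples = sample_laplace(2.0, 1.0, 20000, rng)
```

With 20,000 draws the sample mean should land very close to `loc = 2.0` (the standard error of the mean is `sqrt(2) * scale / sqrt(n)`, about 0.01 here), and samples fall on both sides of the location.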
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/multinomial.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
````python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Multinomial distribution class."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Multinomial",
]


_multinomial_sample_note = """For each batch of counts,
`value = [n_0, ..., n_{k-1}]`, `P[value]` is the probability that after
sampling `self.total_count` draws from this Multinomial distribution, the
number of draws falling in class `j` is `n_j`. Since this definition is
[exchangeable](https://en.wikipedia.org/wiki/Exchangeable_random_variables),
different sequences have the same counts, so the probability includes a
combinatorial coefficient.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and such that
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.probs` and `self.total_count`."""


@tf_export(v1=["distributions.Multinomial"])
class Multinomial(distribution.Distribution):
  """Multinomial distribution.

  This Multinomial distribution is parameterized by `probs`, a (batch of)
  length-`K` `prob` (probability) vectors (`K > 1`) such that
  `tf.reduce_sum(probs, -1) = 1`, and a `total_count` number of trials, i.e.,
  the number of trials per draw from the Multinomial. It is defined over a
  (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Multinomial is identically the
  Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Multinomial is a distribution over `K`-class counts, i.e., a length-`K`
  vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; pi, N) = prod_j (pi_j)**n_j / Z
  Z = (prod_j n_j!) / N!
  ```

  where:
  * `probs = pi = [pi_0, ..., pi_{K-1}]`, `pi_j > 0`, `sum_j pi_j = 1`,
  * `total_count = N`, `N` a positive integer,
  * `Z` is the normalization constant, and,
  * `N!` denotes `N` factorial.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
      tf.float16: 2**11,
      tf.float32: 2**24,
      tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  Create a 3-class distribution, with the 3rd class most likely to be drawn,
  using logits.

  ```python
  logits = [-50., -43, 0]
  dist = Multinomial(total_count=4., logits=logits)
  ```

  Create a 3-class distribution, with the 3rd class most likely to be drawn.

  ```python
  p = [.2, .3, .5]
  dist = Multinomial(total_count=4., probs=p)
  ```

  The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as p.
  counts = [1., 0, 3]
  dist.prob(counts)  # Shape []

  # p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
  counts = [[1., 2, 1], [2, 2, 0]]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Create a 2-batch of 3-class distributions.

  ```python
  p = [[.1, .2, .7], [.3, .3, .4]]  # Shape [2, 3]
  dist = Multinomial(total_count=[4., 5], probs=p)

  counts = [[2., 1, 1], [3, 1, 1]]
  dist.prob(counts)  # Shape [2]

  dist.sample(5)  # Shape [5, 2, 3]
  ```
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing unnormalized log-probabilities
        of a positive event with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0`, and the same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. Only one of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0` and same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. `probs`'s components in the last portion of its shape
        should sum to `1`. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          multidimensional=True,
          validate_args=validate_args,
          name=name)
      self._mean_val = self._total_count[..., array_ops.newaxis] * self._probs
    super(Multinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._logits,
                       self._probs],
        name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def logits(self):
    """Vector of coordinatewise logits."""
    return self._logits

  @property
  def probs(self):
    """Probability of drawing a `1` in that coordinate."""
    return self._probs

  def _batch_shape_tensor(self):
    return array_ops.shape(self._mean_val)[:-1]

  def _batch_shape(self):
    return self._mean_val.get_shape().with_rank_at_least(1)[:-1]

  def _event_shape_tensor(self):
    return array_ops.shape(self._mean_val)[-1:]

  def _event_shape(self):
    return self._mean_val.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]

    # broadcast the total_count and logits to same shape
    n_draws = array_ops.ones_like(
        self.logits[..., 0], dtype=n_draws.dtype) * n_draws
    logits = array_ops.ones_like(
        n_draws[..., array_ops.newaxis], dtype=self.logits.dtype) * self.logits

    # flatten the total_count and logits
    flat_logits = array_ops.reshape(logits, [-1, k])  # [B1B2...Bm, k]
    flat_ndraws = n * array_ops.reshape(n_draws, [-1])  # [B1B2...Bm]

    # computes each total_count and logits situation by map_fn
    def _sample_single(args):
      logits, n_draw = args[0], args[1]  # [K], []
      x = random_ops.multinomial(logits[array_ops.newaxis, ...], n_draw,
                                 seed)  # [1, n*n_draw]
      x = array_ops.reshape(x, shape=[n, -1])  # [n, n_draw]
      x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k), axis=-2)  # [n, k]
      return x

    x = map_fn.map_fn(
        _sample_single, [flat_logits, flat_ndraws],
        dtype=self.dtype)  # [B1B2...Bm, n, k]

    # reshape the results to proper shape
    x = array_ops.transpose(x, perm=[1, 0, 2])
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)  # [n, B1, B2,..., Bm, k]
    return x

  @distribution_util.AppendDocstring(_multinomial_sample_note)
  def _log_prob(self, counts):
    return self._log_unnormalized_prob(counts) - self._log_normalization(counts)

  def _log_unnormalized_prob(self, counts):
    counts = self._maybe_assert_valid_sample(counts)
    return math_ops.reduce_sum(counts * nn_ops.log_softmax(self.logits), -1)

  def _log_normalization(self, counts):
    counts = self._maybe_assert_valid_sample(counts)
    return -distribution_util.log_combinations(self.total_count, counts)

  def _mean(self):
    return array_ops.identity(self._mean_val)

  def _covariance(self):
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    # pylint: disable=invalid-unary-operand-type
    return array_ops.matrix_set_diag(
        -math_ops.matmul(
            self._mean_val[..., array_ops.newaxis],
            p[..., array_ops.newaxis, :]),  # outer product
        self._variance())

  def _variance(self):
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    return self._mean_val - self._mean_val * p

  def _maybe_assert_valid_sample(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts must sum to `self.total_count`"),
    ], counts)
````
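The pmf in the Multinomial docstring above, `pmf(n; pi, N) = prod_j (pi_j)**n_j / Z` with `Z = (prod_j n_j!) / N!`, can be checked without TensorFlow. The sketch below is a hypothetical pure-Python helper (`multinomial_pmf` is not part of this module) that evaluates the same formula for the docstring's example, `total_count=4.`, `probs=[.2, .3, .5]`, `counts=[1., 0, 3]`:

```python
from math import factorial, prod


def multinomial_pmf(counts, probs):
    """pmf(n; pi, N) = N! / (prod_j n_j!) * prod_j pi_j**n_j.

    A pure-Python restatement of the docstring formula, with the
    normalization constant Z = (prod_j n_j!) / N! folded in.
    """
    total = sum(counts)  # N = total_count is implied by the counts.
    coeff = factorial(total) / prod(factorial(n) for n in counts)
    return coeff * prod(p ** n for p, n in zip(probs, counts))


# Docstring example: 4!/(1!*0!*3!) * 0.2**1 * 0.3**0 * 0.5**3 = 4 * 0.025
print(multinomial_pmf([1, 0, 3], [0.2, 0.3, 0.5]))  # ≈ 0.1
```

Summing the pmf over all 15 count vectors with `sum(counts) == 4` and `K == 3` returns 1, matching the normalization claim in the docstring.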
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/normal.py
ADDED
@@ -0,0 +1,291 @@
````python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Normal (Gaussian) distribution class."""

import math

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import special_math
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Normal",
    "NormalWithSoftplusScale",
]


@tf_export(v1=["distributions.Normal"])
class Normal(distribution.Distribution):
  """The Normal distribution with location `loc` and `scale` parameters.

  #### Mathematical details

  The probability density function (pdf) is,

  ```none
  pdf(x; mu, sigma) = exp(-0.5 (x - mu)**2 / sigma**2) / Z
  Z = (2 pi sigma**2)**0.5
  ```

  where `loc = mu` is the mean, `scale = sigma` is the std. deviation, and, `Z`
  is the normalization constant.

  The Normal distribution is a member of the [location-scale family](
  https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed as,

  ```none
  X ~ Normal(loc=0, scale=1)
  Y = loc + scale * X
  ```

  #### Examples

  Examples of initialization of one or a batch of distributions.

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Define a single scalar Normal distribution.
  dist = tfd.Normal(loc=0., scale=3.)

  # Evaluate the cdf at 1, returning a scalar.
  dist.cdf(1.)

  # Define a batch of two scalar valued Normals.
  # The first has mean 1 and standard deviation 11, the second 2 and 22.
  dist = tfd.Normal(loc=[1, 2.], scale=[11, 22.])

  # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
  # returning a length two tensor.
  dist.prob([0, 1.5])

  # Get 3 samples, returning a 3 x 2 tensor.
  dist.sample([3])
  ```

  Arguments are broadcast when possible.

  ```python
  # Define a batch of two scalar valued Normals.
  # Both have mean 1, but different standard deviations.
  dist = tfd.Normal(loc=1., scale=[11, 22.])

  # Evaluate the pdf of both distributions on the same point, 3.0,
  # returning a length 2 tensor.
  dist.prob(3.0)
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               loc,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="Normal"):
    """Construct Normal distributions with mean and stddev `loc` and `scale`.

    The parameters `loc` and `scale` must be shaped in a way that supports
    broadcasting (e.g. `loc + scale` is a valid operation).

    Args:
      loc: Floating point tensor; the means of the distribution(s).
      scale: Floating point tensor; the stddevs of the distribution(s).
        Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `loc` and `scale` have different `dtype`.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[loc, scale]) as name:
      with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                    validate_args else []):
        self._loc = array_ops.identity(loc, name="loc")
        self._scale = array_ops.identity(scale, name="scale")
        check_ops.assert_same_float_dtype([self._loc, self._scale])
    super(Normal, self).__init__(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._loc, self._scale],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(
        zip(("loc", "scale"), ([ops.convert_to_tensor(
            sample_shape, dtype=dtypes.int32)] * 2)))

  @property
  def loc(self):
    """Distribution parameter for the mean."""
    return self._loc

  @property
  def scale(self):
    """Distribution parameter for standard deviation."""
    return self._scale

  def _batch_shape_tensor(self):
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.loc),
        array_ops.shape(self.scale))

  def _batch_shape(self):
    return array_ops.broadcast_static_shape(
        self.loc.get_shape(),
        self.scale.get_shape())

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    sampled = random_ops.random_normal(
        shape=shape, mean=0., stddev=1., dtype=self.loc.dtype, seed=seed)
    return sampled * self.scale + self.loc

  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  def _log_cdf(self, x):
    return special_math.log_ndtr(self._z(x))

  def _cdf(self, x):
    return special_math.ndtr(self._z(x))

  def _log_survival_function(self, x):
    return special_math.log_ndtr(-self._z(x))

  def _survival_function(self, x):
    return special_math.ndtr(-self._z(x))

  def _log_unnormalized_prob(self, x):
    return -0.5 * math_ops.square(self._z(x))

  def _log_normalization(self):
    return 0.5 * math.log(2. * math.pi) + math_ops.log(self.scale)

  def _entropy(self):
    # Use broadcasting rules to calculate the full broadcast scale.
    scale = self.scale * array_ops.ones_like(self.loc)
    return 0.5 * math.log(2. * math.pi * math.e) + math_ops.log(scale)

  def _mean(self):
    return self.loc * array_ops.ones_like(self.scale)

  def _quantile(self, p):
    return self._inv_z(special_math.ndtri(p))

  def _stddev(self):
    return self.scale * array_ops.ones_like(self.loc)

  def _mode(self):
    return self._mean()

  def _z(self, x):
    """Standardize input `x` to a unit normal."""
    with ops.name_scope("standardize", values=[x]):
      return (x - self.loc) / self.scale

  def _inv_z(self, z):
    """Reconstruct input `x` from its normalized version."""
    with ops.name_scope("reconstruct", values=[z]):
      return z * self.scale + self.loc


class NormalWithSoftplusScale(Normal):
  """Normal with softplus applied to `scale`."""

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Normal(loc, tf.nn.softplus(scale))` "
      "instead.",
      warn_once=True)
  def __init__(self,
               loc,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="NormalWithSoftplusScale"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[scale]) as name:
      super(NormalWithSoftplusScale, self).__init__(
          loc=loc,
          scale=nn.softplus(scale, name="softplus_scale"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters


@kullback_leibler.RegisterKL(Normal, Normal)
def _kl_normal_normal(n_a, n_b, name=None):
  """Calculate the batched KL divergence KL(n_a || n_b) with n_a and n_b Normal.

  Args:
    n_a: instance of a Normal distribution object.
    n_b: instance of a Normal distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_normal_normal".

  Returns:
    Batchwise KL(n_a || n_b)
  """
  with ops.name_scope(name, "kl_normal_normal", [n_a.loc, n_b.loc]):
    one = constant_op.constant(1, dtype=n_a.dtype)
    two = constant_op.constant(2, dtype=n_a.dtype)
    half = constant_op.constant(0.5, dtype=n_a.dtype)
    s_a_squared = math_ops.square(n_a.scale)
    s_b_squared = math_ops.square(n_b.scale)
    ratio = s_a_squared / s_b_squared
    return (math_ops.squared_difference(n_a.loc, n_b.loc) / (two * s_b_squared)
            + half * (ratio - one - math_ops.log(ratio)))
````
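The closed form computed by `_kl_normal_normal` above, `KL(a || b) = (mu_a - mu_b)**2 / (2 * s_b**2) + 0.5 * (ratio - 1 - log(ratio))` with `ratio = s_a**2 / s_b**2`, can be sanity-checked without TensorFlow. The sketch below is a hypothetical scalar restatement (the helper name and signature are illustrative, not part of this module):

```python
import math


def kl_normal_normal(mu_a, s_a, mu_b, s_b):
    """Scalar KL(N(mu_a, s_a) || N(mu_b, s_b)), same formula as above."""
    ratio = (s_a ** 2) / (s_b ** 2)
    return ((mu_a - mu_b) ** 2 / (2.0 * s_b ** 2)
            + 0.5 * (ratio - 1.0 - math.log(ratio)))


# KL is zero for identical distributions, positive otherwise, and asymmetric.
print(kl_normal_normal(1.0, 2.0, 1.0, 2.0))  # 0.0
print(kl_normal_normal(0.0, 2.0, 0.0, 1.0))  # positive, != KL(b || a)
```

The three printed facts (zero at equality, positivity, asymmetry) are the standard properties of KL divergence and follow directly from the formula: both summands vanish exactly when the locations and scales coincide.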
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/special_math.py
ADDED
@@ -0,0 +1,470 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Functions "ndtr" and "ndtri" are derived from calculations made in:
# https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
# In the following email exchange, the author gives his consent to redistribute
# derived works under an Apache 2.0 license.
#
# From: Stephen Moshier <steve@moshier.net>
# Date: Sat, Jun 9, 2018 at 2:36 PM
# Subject: Re: Licensing cephes under Apache (BSD-like) license.
# To: rif <rif@google.com>
#
#
#
# Hello Rif,
#
# Yes, Google may distribute Cephes files under the Apache 2 license.
#
# If clarification is needed, I do not favor BSD over other free licenses.
# I would agree that Apache 2 seems to cover the concern you mentioned
# about sublicensees.
#
# Best wishes for good luck with your projects!
# Steve Moshier
#
#
#
# On Thu, 31 May 2018, rif wrote:
#
# > Hello Steve.
# > My name is Rif. I work on machine learning software at Google.
# >
# > Your cephes software continues to be incredibly useful and widely used. I
# > was wondering whether it would be permissible for us to use the Cephes code
# > under the Apache 2.0 license, which is extremely similar in permissions to
# > the BSD license (Wikipedia comparisons). This would be quite helpful to us
# > in terms of avoiding multiple licenses on software.
# >
# > I'm sorry to bother you with this (I can imagine you're sick of hearing
# > about this by now), but I want to be absolutely clear we're on the level and
# > not misusing your important software. In former conversation with Eugene
# > Brevdo (ebrevdo@google.com), you wrote "If your licensing is similar to BSD,
# > the formal way that has been handled is simply to add a statement to the
# > effect that you are incorporating the Cephes software by permission of the
# > author." I wanted to confirm that (a) we could use the Apache license, (b)
# > that we don't need to (and probably you don't want to) keep getting
# > contacted about individual uses, because your intent is generally to allow
# > this software to be reused under "BSD-like" license, and (c) you're OK
# > letting incorporators decide whether a license is sufficiently BSD-like?
# >
# > Best,
# >
# > rif
# >
# >
# >

"""Special Math Ops."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

__all__ = [
    "erfinv",
    "ndtr",
    "ndtri",
    "log_ndtr",
    "log_cdf_laplace",
]


# log_ndtr uses different functions over the ranges
# (-infty, lower](lower, upper](upper, infty)
# Lower bound values were chosen by examining where the support of ndtr
# appears to be zero, relative to scipy's (which is always 64bit). They were
# then made more conservative just to be safe. (Conservative means use the
# expansion more than we probably need to.) See `NdtrTest` in
# special_math_test.py.
LOGNDTR_FLOAT64_LOWER = np.array(-20, np.float64)
LOGNDTR_FLOAT32_LOWER = np.array(-10, np.float32)

# Upper bound values were chosen by examining for which values of 'x'
# Log[cdf(x)] is 0, after which point we need to use the approximation
# Log[cdf(x)] = Log[1 - cdf(-x)] approx -cdf(-x). We chose a value slightly
# conservative, meaning we use the approximation earlier than needed.
LOGNDTR_FLOAT64_UPPER = np.array(8, np.float64)
LOGNDTR_FLOAT32_UPPER = np.array(5, np.float32)

def ndtr(x, name="ndtr"):
  """Normal distribution function.

  Returns the area under the Gaussian probability density function, integrated
  from minus infinity to x:

  ```
                    1       / x
     ndtr(x)  = ----------  |   exp(-0.5 t**2) dt
                sqrt(2 pi)  /-inf

              = 0.5 (1 + erf(x / sqrt(2)))
              = 0.5 erfc(-x / sqrt(2))
  ```

  Args:
    x: `Tensor` of type `float32`, `float64`.
    name: Python string. A name for the operation (default="ndtr").

  Returns:
    ndtr: `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x` is not floating-type.
  """

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.as_numpy_dtype not in [np.float32, np.float64]:
      raise TypeError(
          "x.dtype=%s is not handled, see docstring for supported types."
          % x.dtype)
    return _ndtr(x)


def _ndtr(x):
  """Implements ndtr core logic."""
  half_sqrt_2 = constant_op.constant(
      0.5 * np.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
  w = x * half_sqrt_2
  z = math_ops.abs(w)
  y = array_ops.where_v2(
      math_ops.less(z, half_sqrt_2), 1. + math_ops.erf(w),
      array_ops.where_v2(
          math_ops.greater(w, 0.), 2. - math_ops.erfc(z), math_ops.erfc(z)))
  return 0.5 * y
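As a standalone sanity check of the erf identity in the docstring, the same CDF can be computed with only the standard library (`ndtr_reference` is an illustrative name, not part of this module):

```python
import math

def ndtr_reference(x):
    # Standard normal CDF via ndtr(x) = 0.5 * (1 + erf(x / sqrt(2))).
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

print(ndtr_reference(0.0))  # 0.5 by symmetry
```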


def ndtri(p, name="ndtri"):
  """The inverse of the CDF of the Normal distribution function.

  Returns x such that the area under the pdf from minus infinity to x is equal
  to p.

  A piece-wise rational approximation is done for the function.
  This is a port of the implementation in netlib.

  Args:
    p: `Tensor` of type `float32`, `float64`.
    name: Python string. A name for the operation (default="ndtri").

  Returns:
    x: `Tensor` with `dtype=p.dtype`.

  Raises:
    TypeError: if `p` is not floating-type.
  """

  with ops.name_scope(name, values=[p]):
    p = ops.convert_to_tensor(p, name="p")
    if p.dtype.as_numpy_dtype not in [np.float32, np.float64]:
      raise TypeError(
          "p.dtype=%s is not handled, see docstring for supported types."
          % p.dtype)
    return _ndtri(p)


def _ndtri(p):
  """Implements ndtri core logic."""

  # Constants used in piece-wise rational approximations. Taken from the cephes
  # library:
  # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
  p0 = [
      -1.23916583867381258016E0, 1.39312609387279679503E1,
      -5.66762857469070293439E1, 9.80010754185999661536E1,
      -5.99633501014107895267E1
  ]
  q0 = [
      -1.18331621121330003142E0, 1.59056225126211695515E1,
      -8.20372256168333339912E1, 2.00260212380060660359E2,
      -2.25462687854119370527E2, 8.63602421390890590575E1,
      4.67627912898881538453E0, 1.95448858338141759834E0, 1.0
  ]
  p1 = [
      -8.57456785154685413611E-4, -3.50424626827848203418E-2,
      -1.40256079171354495875E-1, 2.18663306850790267539E0,
      1.46849561928858024014E1, 4.40805073893200834700E1,
      5.71628192246421288162E1, 3.15251094599893866154E1,
      4.05544892305962419923E0
  ]
  q1 = [
      -9.33259480895457427372E-4, -3.80806407691578277194E-2,
      -1.42182922854787788574E-1, 2.50464946208309415979E0,
      1.50425385692907503408E1, 4.13172038254672030440E1,
      4.53907635128879210584E1, 1.57799883256466749731E1, 1.0
  ]
  p2 = [
      6.23974539184983293730E-9, 2.65806974686737550832E-6,
      3.01581553508235416007E-4, 1.23716634817820021358E-2,
      2.01485389549179081538E-1, 1.33303460815807542389E0,
      3.93881025292474443415E0, 6.91522889068984211695E0,
      3.23774891776946035970E0
  ]
  q2 = [
      6.79019408009981274425E-9, 2.89247864745380683936E-6,
      3.28014464682127739104E-4, 1.34204006088543189037E-2,
      2.16236993594496635890E-1, 1.37702099489081330271E0,
      3.67983563856160859403E0, 6.02427039364742014255E0, 1.0
  ]

  def _create_polynomial(var, coeffs):
    """Compute n_th order polynomial via Horner's method."""
    coeffs = np.array(coeffs, var.dtype.as_numpy_dtype)
    if not coeffs.size:
      return array_ops.zeros_like(var)
    return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var

  maybe_complement_p = array_ops.where_v2(p > -np.expm1(-2.), 1. - p, p)
  # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
  # later on. The result from the computation when p == 0 is not used so any
  # number that doesn't result in NaNs is fine.
  sanitized_mcp = array_ops.where_v2(
      maybe_complement_p <= 0.,
      array_ops.fill(array_ops.shape(p), np.array(0.5, p.dtype.as_numpy_dtype)),
      maybe_complement_p)

  # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
  w = sanitized_mcp - 0.5
  ww = w ** 2
  x_for_big_p = w + w * ww * (_create_polynomial(ww, p0)
                              / _create_polynomial(ww, q0))
  x_for_big_p *= -np.sqrt(2. * np.pi)

  # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
  # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
  # arrays based on whether p < exp(-32).
  z = math_ops.sqrt(-2. * math_ops.log(sanitized_mcp))
  first_term = z - math_ops.log(z) / z
  second_term_small_p = (
      _create_polynomial(1. / z, p2) /
      _create_polynomial(1. / z, q2) / z)
  second_term_otherwise = (
      _create_polynomial(1. / z, p1) /
      _create_polynomial(1. / z, q1) / z)
  x_for_small_p = first_term - second_term_small_p
  x_otherwise = first_term - second_term_otherwise

  x = array_ops.where_v2(
      sanitized_mcp > np.exp(-2.), x_for_big_p,
      array_ops.where_v2(z >= 8.0, x_for_small_p, x_otherwise))

  x = array_ops.where_v2(p > 1. - np.exp(-2.), x, -x)
  infinity_scalar = constant_op.constant(np.inf, dtype=p.dtype)
  infinity = array_ops.fill(array_ops.shape(p), infinity_scalar)
  x_nan_replaced = array_ops.where_v2(p <= 0.0, -infinity,
                                      array_ops.where_v2(p >= 1.0, infinity, x))
  return x_nan_replaced
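The nested `_create_polynomial` above is a recursive Horner evaluation in which the first coefficient is the constant term. A minimal plain-Python equivalent of just that recursion (`horner` is a hypothetical name, shown only to illustrate the scheme):

```python
def horner(coeffs, x):
    # Evaluates coeffs[0] + coeffs[1]*x + coeffs[2]*x**2 + ... recursively,
    # in the same shape as _create_polynomial above (empty list -> 0).
    if not coeffs:
        return 0.0
    return coeffs[0] + horner(coeffs[1:], x) * x

print(horner([3.0, 2.0, 1.0], 2.0))  # 3 + 2*2 + 1*4 = 11.0
```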


def log_ndtr(x, series_order=3, name="log_ndtr"):
  """Log Normal distribution function.

  For details of the Normal distribution function see `ndtr`.

  This function calculates `(log o ndtr)(x)` by either calling `log(ndtr(x))` or
  using an asymptotic series. Specifically:
  - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
    `log(1-x) ~= -x, x << 1`.
  - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
    and take a log.
  - For `x <= lower_segment`, we use the series approximation of erf to compute
    the log CDF directly.

  The `lower_segment` is set based on the precision of the input:

  ```
  lower_segment = { -20,  x.dtype=float64
                  { -10,  x.dtype=float32
  upper_segment = {  8,   x.dtype=float64
                  {  5,   x.dtype=float32
  ```

  When `x < lower_segment`, the `ndtr` asymptotic series approximation is:

  ```
     ndtr(x) = scale * (1 + sum) + R_N
     scale   = exp(-0.5 x**2) / (-x sqrt(2 pi))
     sum     = Sum{(-1)^n (2n-1)!! / (x**2)^n, n=1:N}
     R_N     = O(exp(-0.5 x**2) (2N+1)!! / |x|^{2N+3})
  ```

  where `(2n-1)!! = (2n-1) (2n-3) (2n-5) ... (3) (1)` is a
  [double-factorial](https://en.wikipedia.org/wiki/Double_factorial).


  Args:
    x: `Tensor` of type `float32`, `float64`.
    series_order: Positive Python `integer`. Maximum depth to
      evaluate the asymptotic expansion. This is the `N` above.
    name: Python string. A name for the operation (default="log_ndtr").

  Returns:
    log_ndtr: `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
    TypeError: if `series_order` is not a Python `integer`.
    ValueError: if `series_order` is not in `[0, 30]`.
  """
  if not isinstance(series_order, int):
    raise TypeError("series_order must be a Python integer.")
  if series_order < 0:
    raise ValueError("series_order must be non-negative.")
  if series_order > 30:
    raise ValueError("series_order must be <= 30.")

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    if x.dtype.as_numpy_dtype == np.float64:
      lower_segment = LOGNDTR_FLOAT64_LOWER
      upper_segment = LOGNDTR_FLOAT64_UPPER
    elif x.dtype.as_numpy_dtype == np.float32:
      lower_segment = LOGNDTR_FLOAT32_LOWER
      upper_segment = LOGNDTR_FLOAT32_UPPER
    else:
      raise TypeError("x.dtype=%s is not supported." % x.dtype)

    # The basic idea here was ported from:
    #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # We copy the main idea, with a few changes
    # * For x >> 1, and X ~ Normal(0, 1),
    #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
    #   which extends the range of validity of this function.
    # * We use one fixed series_order for all of 'x', rather than adaptive.
    # * Our docstring properly reflects that this is an asymptotic series, not a
    #   Taylor series. We also provided a correct bound on the remainder.
    # * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when
    #   x=0. This happens even though the branch is unchosen because when x=0
    #   the gradient of a select involves the calculation 1*dy+0*(-inf)=nan
    #   regardless of whether dy is finite. Note that the minimum is a NOP if
    #   the branch is chosen.
    return array_ops.where_v2(
        math_ops.greater(x, upper_segment),
        -_ndtr(-x),  # log(1-x) ~= -x, x << 1  # pylint: disable=invalid-unary-operand-type
        array_ops.where_v2(
            math_ops.greater(x, lower_segment),
            math_ops.log(_ndtr(math_ops.maximum(x, lower_segment))),
            _log_ndtr_lower(math_ops.minimum(x, lower_segment), series_order)))
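The upper-branch shortcut `log(ndtr(x)) ~= -ndtr(-x)` can be sanity-checked in plain Python; `ndtr_ref` below is an illustrative helper, not part of this module:

```python
import math

def ndtr_ref(x):
    # Standard normal CDF: ndtr(x) = 0.5 * erfc(-x / sqrt(2)).
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

x = 6.0  # large positive x, where 1 - ndtr(x) is tiny
approx = -ndtr_ref(-x)        # log(1 - eps) ~= -eps for eps << 1
exact = math.log(ndtr_ref(x))
print(abs(approx - exact) < 1e-12)
```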


def _log_ndtr_lower(x, series_order):
  """Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
  x_2 = math_ops.square(x)
  # Log of the term multiplying (1 + sum)
  log_scale = -0.5 * x_2 - math_ops.log(-x) - 0.5 * np.log(2. * np.pi)
  return log_scale + math_ops.log(_log_ndtr_asymptotic_series(x, series_order))


def _log_ndtr_asymptotic_series(x, series_order):
  """Calculates the asymptotic series used in log_ndtr."""
  dtype = x.dtype.as_numpy_dtype
  if series_order <= 0:
    return np.array(1, dtype)
  x_2 = math_ops.square(x)
  even_sum = array_ops.zeros_like(x)
  odd_sum = array_ops.zeros_like(x)
  x_2n = x_2  # Start with x^{2*1} = x^{2*n} with n = 1.
  for n in range(1, series_order + 1):
    y = np.array(_double_factorial(2 * n - 1), dtype) / x_2n
    if n % 2:
      odd_sum += y
    else:
      even_sum += y
    x_2n *= x_2
  return 1. + even_sum - odd_sum
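The lower-tail expansion in `_log_ndtr_lower` and the series above can be mirrored in plain Python and compared against the exact CDF at a moderately negative `x` (a sketch; `log_ndtr_lower_ref` is an illustrative name, not part of this module):

```python
import math

def log_ndtr_lower_ref(x, series_order=3):
    # log(ndtr(x)) ~= -x**2/2 - log(-x) - log(2*pi)/2 + log(1 + sum), where
    # sum = Sum{(-1)^n (2n-1)!! / (x**2)^n, n=1..N}, as in the docstring above.
    total = 1.0
    sign = -1.0
    dfact = 1.0  # running (2n-1)!!
    x_2n = x * x
    for n in range(1, series_order + 1):
        dfact *= (2 * n - 1)
        total += sign * dfact / x_2n
        sign = -sign
        x_2n *= x * x
    return (-0.5 * x * x - math.log(-x)
            - 0.5 * math.log(2 * math.pi) + math.log(total))

# Compare to the log of the exact CDF at x = -9 (still representable in float64).
exact = math.log(0.5 * math.erfc(9.0 / math.sqrt(2.0)))
print(abs(log_ndtr_lower_ref(-9.0) - exact) < 1e-4)
```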


def erfinv(x, name="erfinv"):
  """The inverse function for erf, the error function.

  Args:
    x: `Tensor` of type `float32`, `float64`.
    name: Python string. A name for the operation (default="erfinv").

  Returns:
    x: `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x` is not floating-type.
  """

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.as_numpy_dtype not in [np.float32, np.float64]:
      raise TypeError(
          "x.dtype=%s is not handled, see docstring for supported types."
          % x.dtype)
    return ndtri((x + 1.0) / 2.0) / np.sqrt(2)


def _double_factorial(n):
  """The double factorial function for small Python integer `n`."""
  return np.prod(np.arange(n, 1, -2))
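The `np.prod(np.arange(n, 1, -2))` trick relies on the empty product being 1, so `n <= 1` falls out correctly. A standalone copy of the same one-liner (`double_factorial` is an illustrative name):

```python
import numpy as np

def double_factorial(n):
    # n * (n-2) * (n-4) * ... down to 2 or 1; empty range -> product 1.
    return int(np.prod(np.arange(n, 1, -2)))

print(double_factorial(5))  # 5 * 3 * 1 = 15
print(double_factorial(1))  # empty product -> 1
```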


def log_cdf_laplace(x, name="log_cdf_laplace"):
  """Log Laplace distribution function.

  This function calculates `Log[L(x)]`, where `L(x)` is the cumulative
  distribution function of the Laplace distribution, i.e.

  ```L(x) := 0.5 * int_{-infty}^x e^{-|t|} dt```

  For numerical accuracy, `L(x)` is computed in different ways depending on `x`,

  ```
  x <= 0:
    Log[L(x)] = Log[0.5] + x, which is exact

  0 < x:
    Log[L(x)] = Log[1 - 0.5 * e^{-x}], which is exact
  ```

  Args:
    x: `Tensor` of type `float32`, `float64`.
    name: Python string. A name for the operation (default="log_cdf_laplace").

  Returns:
    `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
  """

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    # For x < 0, L(x) = 0.5 * exp{x} exactly, so Log[L(x)] = log(0.5) + x.
    lower_solution = -np.log(2.) + x

    # safe_exp_neg_x = exp{-x} for x > 0, but is
    # bounded above by 1, which avoids
    #   log[1 - 1] = -inf for x = log(1/2), AND
    #   exp{-x} --> inf, for x << -1
    safe_exp_neg_x = math_ops.exp(-math_ops.abs(x))

    # log1p(z) = log(1 + z) approx z for |z| << 1. This approximation is used
    # internally by log1p, rather than being done explicitly here.
    upper_solution = math_ops.log1p(-0.5 * safe_exp_neg_x)

    return array_ops.where_v2(x < 0., lower_solution, upper_solution)
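The two exact branches can be mirrored with scalars in plain Python, and checked at `x = 0`, where both forms give `log(0.5)` (`log_cdf_laplace_ref` is an illustrative name, not part of this module):

```python
import math

def log_cdf_laplace_ref(x):
    # Same branching as the function above, for a scalar x.
    if x < 0:
        return -math.log(2.0) + x           # Log[0.5 * e^x], exact
    return math.log1p(-0.5 * math.exp(-x))  # Log[1 - 0.5 * e^-x], exact

print(abs(log_cdf_laplace_ref(0.0) - math.log(0.5)) < 1e-15)
```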

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/student_t.py
ADDED
@@ -0,0 +1,391 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Student's t distribution class."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "StudentT",
    "StudentTWithAbsDfSoftplusScale",
]


@tf_export(v1=["distributions.StudentT"])
class StudentT(distribution.Distribution):
  """Student's t-distribution.

  This distribution has parameters: degree of freedom `df`, location `loc`,
  and `scale`.

  #### Mathematical details

  The probability density function (pdf) is,

  ```none
  pdf(x; df, mu, sigma) = (1 + y**2 / df)**(-0.5 (df + 1)) / Z
  where,
  y = (x - mu) / sigma
  Z = abs(sigma) sqrt(df pi) Gamma(0.5 df) / Gamma(0.5 (df + 1))
  ```

  where:
  * `loc = mu`,
  * `scale = sigma`, and,
  * `Z` is the normalization constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The StudentT distribution is a member of the [location-scale family](
  https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed as,

  ```none
  X ~ StudentT(df, loc=0, scale=1)
  Y = loc + scale * X
  ```

  Notice that `scale` has semantics more similar to standard deviation than
  variance. However it is not actually the std. deviation; the Student's
  t-distribution std. dev. is `scale sqrt(df / (df - 2))` when `df > 2`.

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  Examples of initialization of one or a batch of distributions.

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Define a single scalar Student t distribution.
  single_dist = tfd.StudentT(df=3)

  # Evaluate the pdf at 1, returning a scalar Tensor.
  single_dist.prob(1.)

  # Define a batch of two scalar valued Student t's.
  # The first has degrees of freedom 2, mean 1, and scale 11.
  # The second 3, 2 and 22.
  multi_dist = tfd.StudentT(df=[2, 3], loc=[1, 2.], scale=[11, 22.])

  # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
  # returning a length two tensor.
  multi_dist.prob([0, 1.5])

  # Get 3 samples, returning a 3 x 2 tensor.
  multi_dist.sample(3)
  ```

  Arguments are broadcast when possible.

  ```python
  # Define a batch of two Student's t distributions.
  # Both have df 2 and mean 1, but different scales.
  dist = tfd.StudentT(df=2, loc=1, scale=[11, 22.])

  # Evaluate the pdf of both distributions on the same point, 3.0,
  # returning a length 2 tensor.
  dist.prob(3.0)
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  df = tf.constant(2.0)
  loc = tf.constant(2.0)
  scale = tf.constant(11.0)
  dist = tfd.StudentT(df=df, loc=loc, scale=scale)
  samples = dist.sample(5)  # Shape [5]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, [df, loc, scale])
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf](http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               df,
               loc,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="StudentT"):
    """Construct Student's t distributions.

    The distributions have degree of freedom `df`, mean `loc`, and scale
    `scale`.

    The parameters `df`, `loc`, and `scale` must be shaped in a way that
    supports broadcasting (e.g. `df + loc + scale` is a valid operation).

    Args:
      df: Floating-point `Tensor`. The degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
      scale: Floating-point `Tensor`. The scaling factor(s) for the
        distribution(s). Note that `scale` is not technically the standard
        deviation of this distribution but has semantics more similar to
        standard deviation than variance.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if loc and scale are different dtypes.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[df, loc, scale]) as name:
      with ops.control_dependencies([check_ops.assert_positive(df)]
                                    if validate_args else []):
        self._df = array_ops.identity(df, name="df")
        self._loc = array_ops.identity(loc, name="loc")
        self._scale = array_ops.identity(scale, name="scale")
        check_ops.assert_same_float_dtype(
            (self._df, self._loc, self._scale))
    super(StudentT, self).__init__(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._df, self._loc, self._scale],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(
        zip(("df", "loc", "scale"), (
            [ops.convert_to_tensor(
                sample_shape, dtype=dtypes.int32)] * 3)))

  @property
  def df(self):
    """Degrees of freedom in these Student's t distribution(s)."""
    return self._df

  @property
  def loc(self):
    """Locations of these Student's t distribution(s)."""
    return self._loc

  @property
  def scale(self):
    """Scaling factors of these Student's t distribution(s)."""
    return self._scale

  def _batch_shape_tensor(self):
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.df),
        array_ops.broadcast_dynamic_shape(
            array_ops.shape(self.loc), array_ops.shape(self.scale)))

  def _batch_shape(self):
    return array_ops.broadcast_static_shape(
        array_ops.broadcast_static_shape(self.df.get_shape(),
                                         self.loc.get_shape()),
        self.scale.get_shape())

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    # The sampling method comes from the fact that if:
    #   X ~ Normal(0, 1)
    #   Z ~ Chi2(df)
    #   Y = X / sqrt(Z / df)
    # then:
    #   Y ~ StudentT(df).
    shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    normal_sample = random_ops.random_normal(shape, dtype=self.dtype, seed=seed)
    df = self.df * array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)
    gamma_sample = random_ops.random_gamma(
        [n],
        0.5 * df,
        beta=0.5,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, salt="student_t"))
    samples = normal_sample * math_ops.rsqrt(gamma_sample / df)
    return samples * self.scale + self.loc  # Abs(scale) not wanted.

  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  def _log_unnormalized_prob(self, x):
    y = (x - self.loc) / self.scale  # Abs(scale) superfluous.
|
| 269 |
+
return -0.5 * (self.df + 1.) * math_ops.log1p(y**2. / self.df)
|
| 270 |
+
|
| 271 |
+
def _log_normalization(self):
|
| 272 |
+
return (math_ops.log(math_ops.abs(self.scale)) +
|
| 273 |
+
0.5 * math_ops.log(self.df) +
|
| 274 |
+
0.5 * np.log(np.pi) +
|
| 275 |
+
math_ops.lgamma(0.5 * self.df) -
|
| 276 |
+
math_ops.lgamma(0.5 * (self.df + 1.)))
|
| 277 |
+
|
| 278 |
+
def _cdf(self, x):
|
| 279 |
+
# Take Abs(scale) to make subsequent where work correctly.
|
| 280 |
+
y = (x - self.loc) / math_ops.abs(self.scale)
|
| 281 |
+
x_t = self.df / (y**2. + self.df)
|
| 282 |
+
neg_cdf = 0.5 * math_ops.betainc(0.5 * self.df, 0.5, x_t)
|
| 283 |
+
return array_ops.where_v2(math_ops.less(y, 0.), neg_cdf, 1. - neg_cdf)
|
| 284 |
+
|
| 285 |
+
def _entropy(self):
|
| 286 |
+
v = array_ops.ones(self.batch_shape_tensor(),
|
| 287 |
+
dtype=self.dtype)[..., array_ops.newaxis]
|
| 288 |
+
u = v * self.df[..., array_ops.newaxis]
|
| 289 |
+
beta_arg = array_ops.concat([u, v], -1) / 2.
|
| 290 |
+
return (math_ops.log(math_ops.abs(self.scale)) +
|
| 291 |
+
0.5 * math_ops.log(self.df) +
|
| 292 |
+
special_math_ops.lbeta(beta_arg) +
|
| 293 |
+
0.5 * (self.df + 1.) *
|
| 294 |
+
(math_ops.digamma(0.5 * (self.df + 1.)) -
|
| 295 |
+
math_ops.digamma(0.5 * self.df)))
|
| 296 |
+
|
| 297 |
+
@distribution_util.AppendDocstring(
|
| 298 |
+
"""The mean of Student's T equals `loc` if `df > 1`, otherwise it is
|
| 299 |
+
`NaN`. If `self.allow_nan_stats=True`, then an exception will be raised
|
| 300 |
+
rather than returning `NaN`.""")
|
| 301 |
+
def _mean(self):
|
| 302 |
+
mean = self.loc * array_ops.ones(self.batch_shape_tensor(),
|
| 303 |
+
dtype=self.dtype)
|
| 304 |
+
if self.allow_nan_stats:
|
| 305 |
+
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
|
| 306 |
+
return array_ops.where_v2(
|
| 307 |
+
math_ops.greater(
|
| 308 |
+
self.df,
|
| 309 |
+
array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)),
|
| 310 |
+
mean, array_ops.fill(self.batch_shape_tensor(), nan, name="nan"))
|
| 311 |
+
else:
|
| 312 |
+
return control_flow_ops.with_dependencies(
|
| 313 |
+
[
|
| 314 |
+
check_ops.assert_less(
|
| 315 |
+
array_ops.ones([], dtype=self.dtype),
|
| 316 |
+
self.df,
|
| 317 |
+
message="mean not defined for components of df <= 1"),
|
| 318 |
+
],
|
| 319 |
+
mean)
|
| 320 |
+
|
| 321 |
+
@distribution_util.AppendDocstring("""
|
| 322 |
+
The variance for Student's T equals
|
| 323 |
+
|
| 324 |
+
```
|
| 325 |
+
df / (df - 2), when df > 2
|
| 326 |
+
infinity, when 1 < df <= 2
|
| 327 |
+
NaN, when df <= 1
|
| 328 |
+
```
|
| 329 |
+
""")
|
| 330 |
+
def _variance(self):
|
| 331 |
+
# We need to put the tf.where inside the outer tf.where to ensure we never
|
| 332 |
+
# hit a NaN in the gradient.
|
| 333 |
+
denom = array_ops.where_v2(
|
| 334 |
+
math_ops.greater(self.df, 2.), self.df - 2.,
|
| 335 |
+
array_ops.ones_like(self.df))
|
| 336 |
+
# Abs(scale) superfluous.
|
| 337 |
+
var = (array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype) *
|
| 338 |
+
math_ops.square(self.scale) * self.df / denom)
|
| 339 |
+
# When 1 < df <= 2, variance is infinite.
|
| 340 |
+
inf = np.array(np.inf, dtype=self.dtype.as_numpy_dtype())
|
| 341 |
+
result_where_defined = array_ops.where_v2(
|
| 342 |
+
self.df > array_ops.fill(self.batch_shape_tensor(), 2.), var,
|
| 343 |
+
array_ops.fill(self.batch_shape_tensor(), inf, name="inf"))
|
| 344 |
+
|
| 345 |
+
if self.allow_nan_stats:
|
| 346 |
+
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
|
| 347 |
+
return array_ops.where_v2(
|
| 348 |
+
math_ops.greater(
|
| 349 |
+
self.df,
|
| 350 |
+
array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)),
|
| 351 |
+
result_where_defined,
|
| 352 |
+
array_ops.fill(self.batch_shape_tensor(), nan, name="nan"))
|
| 353 |
+
else:
|
| 354 |
+
return control_flow_ops.with_dependencies(
|
| 355 |
+
[
|
| 356 |
+
check_ops.assert_less(
|
| 357 |
+
array_ops.ones([], dtype=self.dtype),
|
| 358 |
+
self.df,
|
| 359 |
+
message="variance not defined for components of df <= 1"),
|
| 360 |
+
],
|
| 361 |
+
result_where_defined)
|
| 362 |
+
|
| 363 |
+
def _mode(self):
|
| 364 |
+
return array_ops.identity(self.loc)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class StudentTWithAbsDfSoftplusScale(StudentT):
|
| 368 |
+
"""StudentT with `df = floor(abs(df))` and `scale = softplus(scale)`."""
|
| 369 |
+
|
| 370 |
+
@deprecation.deprecated(
|
| 371 |
+
"2019-01-01",
|
| 372 |
+
"Use `tfd.StudentT(tf.floor(tf.abs(df)), loc, "
|
| 373 |
+
"tf.nn.softplus(scale)) instead.",
|
| 374 |
+
warn_once=True)
|
| 375 |
+
def __init__(self,
|
| 376 |
+
df,
|
| 377 |
+
loc,
|
| 378 |
+
scale,
|
| 379 |
+
validate_args=False,
|
| 380 |
+
allow_nan_stats=True,
|
| 381 |
+
name="StudentTWithAbsDfSoftplusScale"):
|
| 382 |
+
parameters = dict(locals())
|
| 383 |
+
with ops.name_scope(name, values=[df, scale]) as name:
|
| 384 |
+
super(StudentTWithAbsDfSoftplusScale, self).__init__(
|
| 385 |
+
df=math_ops.floor(math_ops.abs(df)),
|
| 386 |
+
loc=loc,
|
| 387 |
+
scale=nn.softplus(scale, name="softplus_scale"),
|
| 388 |
+
validate_args=validate_args,
|
| 389 |
+
allow_nan_stats=allow_nan_stats,
|
| 390 |
+
name=name)
|
| 391 |
+
self._parameters = parameters
|
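As a sanity check on the construction quoted in `_sample_n` above (if `X ~ Normal(0, 1)` and `Z ~ Chi2(df)`, then `X / sqrt(Z / df) ~ StudentT(df)`), the following standalone NumPy sketch is illustrative only and is not part of the TensorFlow sources in this diff. For `df > 2` the empirical variance should approach `df / (df - 2)`, the closed form given in `_variance`'s docstring.

```python
# Illustrative sketch (not library code): verify the normal/chi-square
# construction of Student's t samples used by StudentT._sample_n.
import numpy as np


def sample_student_t(df, size, rng):
    """Draw StudentT(df) samples via X / sqrt(Z / df)."""
    x = rng.standard_normal(size)   # X ~ Normal(0, 1)
    z = rng.chisquare(df, size)     # Z ~ Chi2(df)
    return x / np.sqrt(z / df)      # Y ~ StudentT(df)


rng = np.random.default_rng(0)
df = 5.0
samples = sample_student_t(df, 200_000, rng)
empirical_var = samples.var()
theoretical_var = df / (df - 2.0)  # df/(df-2) for df > 2
```

With 200k draws the empirical variance lands well within a few percent of `df / (df - 2)`.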
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/uniform.py
ADDED
@@ -0,0 +1,204 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Uniform distribution class."""

import math

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["distributions.Uniform"])
class Uniform(distribution.Distribution):
  """Uniform distribution with `low` and `high` parameters.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; a, b) = I[a <= x < b] / Z
  Z = b - a
  ```

  where

  - `low = a`,
  - `high = b`,
  - `Z` is the normalizing constant, and
  - `I[predicate]` is the [indicator function](
    https://en.wikipedia.org/wiki/Indicator_function) for `predicate`.

  The parameters `low` and `high` must be shaped in a way that supports
  broadcasting (e.g., `high - low` is a valid operation).

  #### Examples

  ```python
  # Without broadcasting:
  u1 = Uniform(low=3.0, high=4.0)  # a single uniform distribution [3, 4]
  u2 = Uniform(low=[1.0, 2.0],
               high=[3.0, 4.0])  # 2 distributions [1, 3], [2, 4]
  u3 = Uniform(low=[[1.0, 2.0],
                    [3.0, 4.0]],
               high=[[1.5, 2.5],
                     [3.5, 4.5]])  # 4 distributions
  ```

  ```python
  # With broadcasting:
  u1 = Uniform(low=3.0, high=[5.0, 6.0, 7.0])  # 3 distributions
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               low=0.,
               high=1.,
               validate_args=False,
               allow_nan_stats=True,
               name="Uniform"):
    """Initialize a batch of Uniform distributions.

    Args:
      low: Floating point tensor, lower boundary of the output interval. Must
        have `low < high`.
      high: Floating point tensor, upper boundary of the output interval. Must
        have `low < high`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      InvalidArgumentError: if `low >= high` and `validate_args=True`.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[low, high]) as name:
      with ops.control_dependencies([
          check_ops.assert_less(
              low, high, message="uniform not defined when low >= high.")
      ] if validate_args else []):
        self._low = array_ops.identity(low, name="low")
        self._high = array_ops.identity(high, name="high")
        check_ops.assert_same_float_dtype([self._low, self._high])
    super(Uniform, self).__init__(
        dtype=self._low.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._low,
                       self._high],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(
        zip(("low", "high"),
            ([ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2)))

  @property
  def low(self):
    """Lower boundary of the output interval."""
    return self._low

  @property
  def high(self):
    """Upper boundary of the output interval."""
    return self._high

  def range(self, name="range"):
    """`high - low`."""
    with self._name_scope(name):
      return self.high - self.low

  def _batch_shape_tensor(self):
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.low),
        array_ops.shape(self.high))

  def _batch_shape(self):
    return array_ops.broadcast_static_shape(
        self.low.get_shape(),
        self.high.get_shape())

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    samples = random_ops.random_uniform(shape=shape,
                                        dtype=self.dtype,
                                        seed=seed)
    return self.low + self.range() * samples

  def _prob(self, x):
    broadcasted_x = x * array_ops.ones(
        self.batch_shape_tensor(), dtype=x.dtype)
    return array_ops.where_v2(
        math_ops.is_nan(broadcasted_x), broadcasted_x,
        array_ops.where_v2(
            math_ops.logical_or(broadcasted_x < self.low,
                                broadcasted_x >= self.high),
            array_ops.zeros_like(broadcasted_x),
            array_ops.ones_like(broadcasted_x) / self.range()))

  def _cdf(self, x):
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(x), self.batch_shape_tensor())
    zeros = array_ops.zeros(broadcast_shape, dtype=self.dtype)
    ones = array_ops.ones(broadcast_shape, dtype=self.dtype)
    broadcasted_x = x * ones
    result_if_not_big = array_ops.where_v2(
        x < self.low, zeros, (broadcasted_x - self.low) / self.range())
    return array_ops.where_v2(x >= self.high, ones, result_if_not_big)

  def _entropy(self):
    return math_ops.log(self.range())

  def _mean(self):
    return (self.low + self.high) / 2.

  def _variance(self):
    return math_ops.square(self.range()) / 12.

  def _stddev(self):
    return self.range() / math.sqrt(12.)
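The piecewise logic in `_prob` and `_cdf` above can be transcribed directly into NumPy. The sketch below is illustrative only (not part of the TensorFlow sources in this diff); it mirrors the same branch structure: NaN propagates through `_prob`, the density is `1 / (high - low)` on `[low, high)` and zero elsewhere, and the CDF clamps to 0 below `low` and 1 at or above `high`.

```python
# Illustrative sketch (not library code): NumPy transcription of
# Uniform._prob and Uniform._cdf for scalar or array x.
import numpy as np


def uniform_prob(x, low, high):
    """pdf(x; low, high) = I[low <= x < high] / (high - low); NaN propagates."""
    x = np.asarray(x, dtype=float)
    out = np.where((x < low) | (x >= high), 0.0, 1.0 / (high - low))
    return np.where(np.isnan(x), x, out)  # outer where: pass NaN through


def uniform_cdf(x, low, high):
    """0 below low, linear ramp on [low, high), 1 at or above high."""
    x = np.asarray(x, dtype=float)
    result_if_not_big = np.where(x < low, 0.0, (x - low) / (high - low))
    return np.where(x >= high, 1.0, result_if_not_big)
```

For example, for `Uniform(low=3.0, high=4.0)` the density at `x = 3.5` is `1.0` and the CDF there is `0.5`, matching `_mean`'s midpoint of the interval.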
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/distributions/util.py
ADDED
@@ -0,0 +1,1448 @@
| 1 |
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Utilities for probability distributions."""
|
| 16 |
+
|
| 17 |
+
import functools
|
| 18 |
+
import hashlib
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
from tensorflow.python.framework import constant_op
|
| 23 |
+
from tensorflow.python.framework import dtypes
|
| 24 |
+
from tensorflow.python.framework import ops
|
| 25 |
+
from tensorflow.python.framework import tensor_shape
|
| 26 |
+
from tensorflow.python.framework import tensor_util
|
| 27 |
+
from tensorflow.python.ops import array_ops
|
| 28 |
+
from tensorflow.python.ops import array_ops_stack
|
| 29 |
+
from tensorflow.python.ops import check_ops
|
| 30 |
+
from tensorflow.python.ops import cond as tf_cond
|
| 31 |
+
from tensorflow.python.ops import control_flow_ops
|
| 32 |
+
from tensorflow.python.ops import linalg_ops
|
| 33 |
+
from tensorflow.python.ops import math_ops
|
| 34 |
+
from tensorflow.python.ops import nn
|
| 35 |
+
from tensorflow.python.util import tf_inspect
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def assert_integer_form(x,
|
| 39 |
+
data=None,
|
| 40 |
+
summarize=None,
|
| 41 |
+
message=None,
|
| 42 |
+
int_dtype=None,
|
| 43 |
+
name="assert_integer_form"):
|
| 44 |
+
"""Assert that x has integer components (or floats equal to integers).
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
x: Floating-point `Tensor`
|
| 48 |
+
data: The tensors to print out if the condition is `False`. Defaults to
|
| 49 |
+
error message and first few entries of `x` and `y`.
|
| 50 |
+
summarize: Print this many entries of each tensor.
|
| 51 |
+
message: A string to prefix to the default message.
|
| 52 |
+
int_dtype: A `tf.dtype` used to cast the float to. The default (`None`)
|
| 53 |
+
implies the smallest possible signed int will be used for casting.
|
| 54 |
+
name: A name for this operation (optional).
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.
|
| 58 |
+
"""
|
| 59 |
+
with ops.name_scope(name, values=[x, data]):
|
| 60 |
+
x = ops.convert_to_tensor(x, name="x")
|
| 61 |
+
if x.dtype.is_integer:
|
| 62 |
+
return control_flow_ops.no_op()
|
| 63 |
+
message = message or "{} has non-integer components".format(x)
|
| 64 |
+
if int_dtype is None:
|
| 65 |
+
try:
|
| 66 |
+
int_dtype = {
|
| 67 |
+
dtypes.float16: dtypes.int16,
|
| 68 |
+
dtypes.float32: dtypes.int32,
|
| 69 |
+
dtypes.float64: dtypes.int64,
|
| 70 |
+
}[x.dtype.base_dtype]
|
| 71 |
+
except KeyError:
|
| 72 |
+
raise TypeError("Unrecognized type {}".format(x.dtype.name))
|
| 73 |
+
return check_ops.assert_equal(
|
| 74 |
+
x,
|
| 75 |
+
math_ops.cast(math_ops.cast(x, int_dtype), x.dtype),
|
| 76 |
+
data=data,
|
| 77 |
+
summarize=summarize,
|
| 78 |
+
message=message,
|
| 79 |
+
name=name)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def assert_symmetric(matrix):
|
| 83 |
+
matrix_t = array_ops.matrix_transpose(matrix)
|
| 84 |
+
return control_flow_ops.with_dependencies(
|
| 85 |
+
[check_ops.assert_equal(matrix, matrix_t)], matrix)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def embed_check_nonnegative_integer_form(
|
| 89 |
+
x, name="embed_check_nonnegative_integer_form"):
|
| 90 |
+
"""Assert x is a non-negative tensor, and optionally of integers."""
|
| 91 |
+
with ops.name_scope(name, values=[x]):
|
| 92 |
+
x = ops.convert_to_tensor(x, name="x")
|
| 93 |
+
assertions = [
|
| 94 |
+
check_ops.assert_non_negative(
|
| 95 |
+
x, message="'{}' must be non-negative.".format(x)),
|
| 96 |
+
]
|
| 97 |
+
if not x.dtype.is_integer:
|
| 98 |
+
assertions += [
|
| 99 |
+
assert_integer_form(
|
| 100 |
+
x,
|
| 101 |
+
message="'{}' cannot contain fractional components.".format(x)),
|
| 102 |
+
]
|
| 103 |
+
return control_flow_ops.with_dependencies(assertions, x)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def same_dynamic_shape(a, b):
|
| 107 |
+
"""Returns whether a and b have the same dynamic shape.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
a: `Tensor`
|
| 111 |
+
b: `Tensor`
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
`bool` `Tensor` representing if both tensors have the same shape.
|
| 115 |
+
"""
|
| 116 |
+
a = ops.convert_to_tensor(a, name="a")
|
| 117 |
+
b = ops.convert_to_tensor(b, name="b")
|
| 118 |
+
|
| 119 |
+
# Here we can't just do math_ops.equal(a.shape, b.shape), since
|
| 120 |
+
# static shape inference may break the equality comparison between
|
| 121 |
+
# shape(a) and shape(b) in math_ops.equal.
|
| 122 |
+
def all_shapes_equal():
|
| 123 |
+
return math_ops.reduce_all(
|
| 124 |
+
math_ops.equal(
|
| 125 |
+
array_ops.concat(
|
| 126 |
+
[array_ops.shape(a), array_ops.shape(b)], 0),
|
| 127 |
+
array_ops.concat(
|
| 128 |
+
[array_ops.shape(b), array_ops.shape(a)], 0)))
|
| 129 |
+
|
| 130 |
+
# One of the shapes isn't fully defined, so we need to use the dynamic
|
| 131 |
+
# shape.
|
| 132 |
+
return tf_cond.cond(
|
| 133 |
+
math_ops.equal(array_ops.rank(a), array_ops.rank(b)),
|
| 134 |
+
all_shapes_equal, lambda: constant_op.constant(False))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def maybe_get_static_value(x, dtype=None):
  """Helper which tries to return a static value.

  Given `x`, extract its value statically, optionally casting to a specific
  dtype. If this is not possible, None is returned.

  Args:
    x: `Tensor` for which to extract a value statically.
    dtype: Optional dtype to cast to.

  Returns:
    Statically inferred value if possible, otherwise None.
  """
  if x is None:
    return x
  try:
    # This returns an np.ndarray.
    x_ = tensor_util.constant_value(x)
  except TypeError:
    x_ = x
  if x_ is None or dtype is None:
    return x_
  return np.array(x_, dtype)


def get_logits_and_probs(logits=None,
                         probs=None,
                         multidimensional=False,
                         validate_args=False,
                         name="get_logits_and_probs",
                         dtype=None):
  """Converts logits to probabilities (or vice-versa), and returns both.

  Args:
    logits: Floating-point `Tensor` representing log-odds.
    probs: Floating-point `Tensor` representing probabilities.
    multidimensional: Python `bool`, default `False`. If `True`, represents
      whether the last dimension of `logits` or `probs`, a `[N1, N2, ... k]`
      dimensional tensor, represents the logit or probability of `shape[-1]`
      classes.
    validate_args: Python `bool`, default `False`. When `True`, either assert
      `0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
      of `probs` sums to one.
    name: A name for this operation (optional).
    dtype: `tf.DType` to prefer when converting args to `Tensor`s.

  Returns:
    logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
      `1`, then the corresponding entry in the returned logit will be `-Inf`
      and `Inf` respectively.

  Raises:
    ValueError: if neither `probs` nor `logits` were passed in, or both were.
  """
  with ops.name_scope(name, values=[probs, logits]):
    if (probs is None) == (logits is None):
      raise ValueError("Must pass probs or logits, but not both.")

    if probs is None:
      logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
      if not logits.dtype.is_floating:
        raise TypeError("logits must have floating type.")
      # We can early return since we constructed probs and therefore know
      # they're valid.
      if multidimensional:
        if validate_args:
          logits = embed_check_categorical_event_shape(logits)
        return logits, nn.softmax(logits, name="probs")
      return logits, math_ops.sigmoid(logits, name="probs")

    probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
    if not probs.dtype.is_floating:
      raise TypeError("probs must have floating type.")

    if validate_args:
      with ops.name_scope("validate_probs"):
        one = constant_op.constant(1., probs.dtype)
        dependencies = [check_ops.assert_non_negative(probs)]
        if multidimensional:
          probs = embed_check_categorical_event_shape(probs)
          dependencies += [
              check_ops.assert_near(
                  math_ops.reduce_sum(probs, -1),
                  one,
                  message="probs does not sum to 1.")
          ]
        else:
          dependencies += [
              check_ops.assert_less_equal(
                  probs, one, message="probs has components greater than 1.")
          ]
        probs = control_flow_ops.with_dependencies(dependencies, probs)

    with ops.name_scope("logits"):
      if multidimensional:
        # Here we don't compute the multidimensional case in a manner
        # consistent with the unidimensional case. Instead we follow the TF
        # convention. Typically, you might expect to see
        # logits = log(probs) - log(probs[pivot]). A side-effect of being
        # consistent with the TF approach is that the unidimensional case
        # implicitly handles the second dimension but the multidimensional
        # case explicitly keeps the pivot dimension.
        return math_ops.log(probs), probs
      return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs


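A hedged NumPy sketch of the scalar logit/probability relationship used in the unidimensional branch above: `logits = log(p) - log1p(-p)` is the inverse of the sigmoid, and probabilities of exactly `0` or `1` map to `-Inf`/`Inf` as the docstring notes. This is plain NumPy, not the TF ops themselves.

```python
import numpy as np

def probs_from_logits(logits):
    # Sigmoid: inverse of the binary logit transform.
    return 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))

def logits_from_probs(probs):
    # Mirrors: math_ops.log(probs) - math_ops.log1p(-1. * probs)
    probs = np.asarray(probs, dtype=np.float64)
    return np.log(probs) - np.log1p(-probs)

p = np.array([0.1, 0.5, 0.9])
roundtrip = probs_from_logits(logits_from_probs(p))
```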
def _is_known_unsigned_by_dtype(dt):
  """Helper returning True if dtype is known to be unsigned."""
  return {
      dtypes.bool: True,
      dtypes.uint8: True,
      dtypes.uint16: True,
  }.get(dt.base_dtype, False)


def _is_known_signed_by_dtype(dt):
  """Helper returning True if dtype is known to be signed."""
  return {
      dtypes.float16: True,
      dtypes.float32: True,
      dtypes.float64: True,
      dtypes.int8: True,
      dtypes.int16: True,
      dtypes.int32: True,
      dtypes.int64: True,
  }.get(dt.base_dtype, False)


def _is_known_dtype(dt):
  """Helper returning True if dtype is known."""
  return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt)


def _largest_integer_by_dtype(dt):
  """Helper returning the largest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if dt.is_floating:
    return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1))
  if dt.is_integer:
    return np.iinfo(dt.as_numpy_dtype).max
  if dt.base_dtype == dtypes.bool:
    return int(1)
  # We can't actually land here, but keep the case for completeness.
  raise TypeError("Unrecognized dtype: {}".format(dt.name))


def _smallest_integer_by_dtype(dt):
  """Helper returning the smallest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if _is_known_unsigned_by_dtype(dt):
    return 0
  return -1 * _largest_integer_by_dtype(dt)


def _is_integer_like_by_dtype(dt):
  """Helper returning True if dtype.is_integer or is `bool`."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  return dt.is_integer or dt.base_dtype == dtypes.bool


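A small NumPy check of the floating-point bound used above: the largest integer exactly representable by a float dtype is `2**(1 + mantissa_bits)` (`np.finfo(...).nmant` gives the mantissa bits), and just past that bound adjacent integers become indistinguishable.

```python
import numpy as np

def largest_exact_int(float_dtype):
    # Same formula as _largest_integer_by_dtype for floating dtypes.
    return int(2 ** (np.finfo(float_dtype).nmant + 1))

limit32 = largest_exact_int(np.float32)  # 2**24 for float32
# At the limit, n and n + 1 round to the same float32 value.
collides = np.float32(limit32) + np.float32(1) == np.float32(limit32)
```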
def embed_check_categorical_event_shape(
    categorical_param, name="embed_check_categorical_event_shape"):
  """Embeds checks that categorical distributions don't have too many classes.

  A categorical-type distribution is one which, e.g., returns the class label
  rather than a one-hot encoding. E.g., `Categorical(probs)`.

  Since distributions output samples in the same dtype as the parameters, we
  must ensure that casting doesn't lose precision. That is, the
  `parameter.dtype` implies a maximum number of classes. However, since shape
  is `int32` and categorical variables are presumed to be indexes into a
  `Tensor`, we must also ensure that the number of classes is no larger than
  the largest possible `int32` index, i.e., `2**31 - 1`.

  In other words the number of classes, `K`, must satisfy the following
  condition:

  ```python
  K <= min(
      int(2**31 - 1),  # Largest `int32` index.
      {
          dtypes.float16: int(2**11),  # Largest int as a float16.
          dtypes.float32: int(2**24),
          dtypes.float64: int(2**53),
      }.get(categorical_param.dtype.base_dtype, 0))
  ```

  Args:
    categorical_param: Floating-point `Tensor` representing parameters of
      distribution over categories. The rightmost shape is presumed to be the
      number of categories.
    name: A name for this operation (optional).

  Returns:
    categorical_param: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `categorical_param` has an unknown `dtype`.
    ValueError: if we can statically identify `categorical_param` as being too
      large (for being closed under int32/float casting).
  """
  with ops.name_scope(name, values=[categorical_param]):
    x = ops.convert_to_tensor(categorical_param, name="categorical_param")
    # The size must not exceed both of:
    # - The largest possible int32 (since categorical values are presumed to
    #   be indexes into a Tensor).
    # - The largest possible integer exactly representable under the given
    #   floating-point dtype (since we need to cast to/from).
    #
    # The chosen floating-point thresholds are 2**(1 + mantissa_bits).
    # For more details, see:
    # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
    x_dtype = x.dtype.base_dtype
    max_event_size = (
        _largest_integer_by_dtype(x_dtype) if x_dtype.is_floating else 0)
    if max_event_size == 0:
      raise TypeError("Unable to validate size of unrecognized dtype "
                      "({}).".format(x_dtype.name))
    try:
      x_shape_static = x.get_shape().with_rank_at_least(1)
    except ValueError:
      raise ValueError("A categorical-distribution parameter must have "
                       "at least 1 dimension.")
    if tensor_shape.dimension_value(x_shape_static[-1]) is not None:
      event_size = x_shape_static.dims[-1].value
      if event_size < 2:
        raise ValueError("A categorical-distribution parameter must have at "
                         "least 2 events.")
      if event_size > max_event_size:
        raise ValueError("Number of classes exceeds `dtype` precision, i.e., "
                         "{} implies shape ({}) cannot exceed {}.".format(
                             x_dtype.name, event_size, max_event_size))
      return x
    else:
      event_size = array_ops.shape(x, name="x_shape")[-1]
      return control_flow_ops.with_dependencies([
          check_ops.assert_rank_at_least(
              x,
              1,
              message=("A categorical-distribution parameter must have "
                       "at least 1 dimension.")),
          check_ops.assert_greater_equal(
              array_ops.shape(x)[-1],
              2,
              message=("A categorical-distribution parameter must have at "
                       "least 2 events.")),
          check_ops.assert_less_equal(
              event_size,
              max_event_size,
              message="Number of classes exceeds `dtype` precision, "
              "i.e., {} dtype cannot exceed {} classes.".format(
                  x_dtype.name, max_event_size)),
      ], x)


def embed_check_integer_casting_closed(x,
                                       target_dtype,
                                       assert_nonnegative=True,
                                       name="embed_check_casting_closed"):
  """Ensures integers remain unaffected despite casting to/from int/float types.

  Example integer-types: `uint8`, `int32`, `bool`.
  Example floating-types: `float32`, `float64`.

  The largest possible integer representable by an IEEE754 floating-point is
  `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is
  `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to
  have integer-form values can be cast to some other type without loss of
  precision.

  The smallest representable integer is the negative of the largest
  representable integer, except for types: `uint8`, `uint16`, `bool`. For
  these types, the smallest representable integer is `0`.

  Args:
    x: `Tensor` representing integer-form values.
    target_dtype: TF `dtype` under which `x` should have identical values.
    assert_nonnegative: `bool` indicating `x` should contain nonnegative
      values.
    name: A name for this operation (optional).

  Returns:
    x: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `x` is neither integer- nor floating-type.
    TypeError: if `target_dtype` is neither integer- nor floating-type.
    TypeError: if neither `x` nor `target_dtype` are integer-type.
  """

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if (not _is_integer_like_by_dtype(x.dtype) and not x.dtype.is_floating):
      raise TypeError("{}.dtype must be floating- or "
                      "integer-type.".format(x.dtype.name))
    if (not _is_integer_like_by_dtype(target_dtype) and
        not target_dtype.is_floating):
      raise TypeError("target_dtype ({}) must be floating- or "
                      "integer-type.".format(target_dtype.name))
    if (not _is_integer_like_by_dtype(x.dtype) and
        not _is_integer_like_by_dtype(target_dtype)):
      raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) "
                      "must be integer-type.".format(x, x.dtype.name,
                                                     target_dtype.name))

    assertions = []
    if assert_nonnegative:
      assertions += [
          check_ops.assert_non_negative(
              x, message="Elements must be non-negative."),
      ]

    if x.dtype.is_floating:
      # Being here means _is_integer_like_by_dtype(target_dtype) = True.
      # Since this check implies the magnitude check below, it is the only
      # one we need.
      assertions += [
          assert_integer_form(
              x,
              int_dtype=target_dtype,
              message="Elements must be {}-equivalent.".format(
                  target_dtype.name)),
      ]
    else:
      if (_largest_integer_by_dtype(x.dtype) >
          _largest_integer_by_dtype(target_dtype)):
        # Cast may lose integer precision.
        assertions += [
            check_ops.assert_less_equal(
                x,
                _largest_integer_by_dtype(target_dtype),
                message=("Elements cannot exceed {}.".format(
                    _largest_integer_by_dtype(target_dtype)))),
        ]
      if (not assert_nonnegative and (_smallest_integer_by_dtype(
          x.dtype) < _smallest_integer_by_dtype(target_dtype))):
        assertions += [
            check_ops.assert_greater_equal(
                x,
                _smallest_integer_by_dtype(target_dtype),
                message=("Elements cannot be smaller than {}.".format(
                    _smallest_integer_by_dtype(target_dtype)))),
        ]

    if not assertions:
      return x
    return control_flow_ops.with_dependencies(assertions, x)


def log_combinations(n, counts, name="log_combinations"):
  """Multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
  the multinomial coefficient as:

  ```n! / prod_i n_i!```

  where `i` runs over all `k` classes.

  Args:
    n: Floating-point `Tensor` broadcastable with `counts`. This represents
      `n` outcomes.
    counts: Floating-point `Tensor` broadcastable with `n`. This represents
      counts in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    `Tensor` representing the multinomial coefficient between `n` and
    `counts`.
  """
  # First a bit about the number of ways counts could have come in:
  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!)
  # The sum should be along the last dimension of counts. This is the
  # "distribution" dimension. Here n a priori represents the sum of counts.
  with ops.name_scope(name, values=[n, counts]):
    n = ops.convert_to_tensor(n, name="n")
    counts = ops.convert_to_tensor(counts, name="counts")
    total_permutations = math_ops.lgamma(n + 1)
    counts_factorial = math_ops.lgamma(counts + 1)
    redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1])
    return total_permutations - redundant_permutations


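A quick stdlib check of the lgamma identity used above: the log multinomial coefficient `log(n! / prod_i n_i!)` equals `lgamma(n + 1) - sum_i lgamma(n_i + 1)`, compared against a direct factorial computation for a small case.

```python
import math

def log_combinations_py(counts):
    # Same identity as the TF op, in scalar Python: n is the sum of counts.
    n = sum(counts)
    return math.lgamma(n + 1) - sum(math.lgamma(c + 1) for c in counts)

counts = [1, 2, 3]
# Direct computation: 6! / (1! * 2! * 3!) = 720 / 12 = 60.
direct = math.log(
    math.factorial(6)
    // (math.factorial(1) * math.factorial(2) * math.factorial(3)))
```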
def matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive. If the upper triangle was zero, this would be
  # a valid Cholesky factor.
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # LinearOperatorLowerTriangular ignores the upper triangle.
  operator = LinearOperatorLowerTriangular(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  tfd = tfp.distributions

  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tfd.MultivariateNormalTriL(mu, chol)

  # Standard log loss. Minimizing this will "train" mu and chol, and then
  # dist will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_prob(labels))
  ```

  Args:
    matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform: Element-wise function mapping `Tensors` to `Tensors`. To be
      applied to the diagonal of `matrix`. If `None`, `matrix` is returned
      unchanged. Defaults to `None`.
    name: A name to give created ops. Defaults to "matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.name_scope(name, "matrix_diag_transform", [matrix]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag)

    return transformed_mat


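A NumPy sketch of the diagonal-only transform (illustrative, not the TF op): apply softplus to the diagonal entries and leave every off-diagonal entry untouched, which is exactly what makes the result usable as a Cholesky factor with positive diagonal.

```python
import numpy as np

def softplus(v):
    # Softplus maps any real value to a positive one.
    return np.log1p(np.exp(v))

def diag_transform(matrix, transform):
    # Replace only the diagonal; off-diagonal entries are untouched.
    out = matrix.copy()
    idx = np.arange(matrix.shape[-1])
    out[..., idx, idx] = transform(matrix[..., idx, idx])
    return out

m = np.array([[-1.0, 2.0], [3.0, -4.0]])
t = diag_transform(m, softplus)
```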
def rotate_transpose(x, shift, name="rotate_transpose"):
  """Circularly moves dims left or right.

  Effectively identical to:

  ```python
  numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
  ```

  When `validate_args=False` additional graph-runtime checks are
  performed. These checks entail moving data from GPU to CPU.

  Example:

  ```python
  x = tf.random.normal([1, 2, 3, 4])  # Tensor of shape [1, 2, 3, 4].
  rotate_transpose(x, -1).shape == [2, 3, 4, 1]
  rotate_transpose(x, -2).shape == [3, 4, 1, 2]
  rotate_transpose(x, 1).shape == [4, 1, 2, 3]
  rotate_transpose(x, 2).shape == [3, 4, 1, 2]
  rotate_transpose(x, 7).shape == rotate_transpose(x, 3).shape  # [2, 3, 4, 1]
  rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape  # [4, 1, 2, 3]
  ```

  Args:
    x: `Tensor`.
    shift: `Tensor`. Number of dimensions to transpose left (shift<0) or
      transpose right (shift>0).
    name: Python `str`. The name to give this op.

  Returns:
    rotated_x: Input `Tensor` with dimensions circularly rotated by shift.

  Raises:
    TypeError: if shift is not integer type.
  """
  with ops.name_scope(name, values=[x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    shift = ops.convert_to_tensor(shift, name="shift")
    # We do not assign back to preserve constant-ness.
    check_ops.assert_integer(shift)
    shift_value_static = tensor_util.constant_value(shift)
    ndims = x.get_shape().ndims
    if ndims is not None and shift_value_static is not None:
      if ndims < 2:
        return x
      shift_value_static = np.sign(shift_value_static) * (
          abs(shift_value_static) % ndims)
      if shift_value_static == 0:
        return x
      perm = np.roll(np.arange(ndims), shift_value_static)
      return array_ops.transpose(x, perm=perm)
    else:
      # Consider if we always had a positive shift, and some specified
      # direction.
      # When shifting left we want the new array:
      #   last(x, n-shift) + first(x, shift)
      # and if shifting right then we want:
      #   last(x, shift) + first(x, n-shift)
      # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
      # Also, we can encode direction and shift as one: direction * shift.
      # Combining these facts, we have:
      #   a = cond(shift<0, -shift, n-shift)
      #   last(x, n-a) + first(x, a) == x[a:n] + x[0:a]
      # Finally, we transform shift by modulo length so it can be specified
      # independently from the array upon which it operates (like python).
      ndims = array_ops.rank(x)
      shift = array_ops.where_v2(
          math_ops.less(shift, 0),
          math_ops.mod(-shift, ndims),  # pylint: disable=invalid-unary-operand-type
          ndims - math_ops.mod(shift, ndims))
      first = math_ops.range(0, shift)
      last = math_ops.range(shift, ndims)
      perm = array_ops.concat([last, first], 0)
      return array_ops.transpose(x, perm=perm)


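A hedged NumPy sketch of the equivalence named in the docstring above: `rotate_transpose` is a transpose with a circularly rolled permutation, and shifts wrap modulo the rank.

```python
import numpy as np

def rotate_transpose_np(x, shift):
    # The docstring's "effectively identical" NumPy form.
    perm = np.roll(np.arange(x.ndim), shift)
    return np.transpose(x, perm)

x = np.zeros([1, 2, 3, 4])
```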
def pick_vector(cond, true_vector, false_vector, name="pick_vector"):
  """Picks possibly different length row `Tensor`s based on condition.

  Value `Tensor`s should have exactly one dimension.

  If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
  `false_vector` is immediately returned. I.e., no graph nodes are created
  and no validation happens.

  Example:

  ```python
  pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))  # [10, 11]
  pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))  # [15, 16, 17]
  ```

  Args:
    cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
    true_vector: `Tensor` of one dimension. Returned when cond is `True`.
    false_vector: `Tensor` of one dimension. Returned when cond is `False`.
    name: Python `str`. The name to give this op.

  Returns:
    true_or_false_vector: `Tensor`.

  Raises:
    TypeError: if `cond.dtype != tf.bool`
    TypeError: if `cond` is not a constant and
      `true_vector.dtype != false_vector.dtype`
  """
  with ops.name_scope(name, values=(cond, true_vector, false_vector)):
    cond = ops.convert_to_tensor(cond, name="cond")
    if cond.dtype != dtypes.bool:
      raise TypeError("%s.dtype=%s which is not %s" %
                      (cond, cond.dtype, dtypes.bool))
    cond_value_static = tensor_util.constant_value(cond)
    if cond_value_static is not None:
      return true_vector if cond_value_static else false_vector
    true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
    false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
    if true_vector.dtype != false_vector.dtype:
      raise TypeError(
          "%s.dtype=%s does not match %s.dtype=%s" %
          (true_vector, true_vector.dtype, false_vector, false_vector.dtype))
    n = array_ops.shape(true_vector)[0]
    return array_ops.slice(
        array_ops.concat([true_vector, false_vector], 0),
        [array_ops.where_v2(cond, 0, n)], [array_ops.where(cond, n, -1)])


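A NumPy sketch of the concat-and-slice trick `pick_vector` uses for the non-static case (illustrative only): both vectors are concatenated, and the slice bounds are chosen by the condition, so the op's output shape can depend on a runtime value.

```python
import numpy as np

def pick_vector_np(cond, true_vector, false_vector):
    # concat([t, f]) then slice: [0, n) when cond, [n, end) otherwise,
    # mirroring the graph-mode branch above.
    joined = np.concatenate([true_vector, false_vector])
    n = len(true_vector)
    start = 0 if cond else n
    end = n if cond else len(joined)
    return joined[start:end]
```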
def prefer_static_broadcast_shape(shape1,
                                  shape2,
                                  name="prefer_static_broadcast_shape"):
  """Convenience function which statically broadcasts shape when possible.

  Args:
    shape1: `1-D` integer `Tensor`. Already converted to tensor!
    shape2: `1-D` integer `Tensor`. Already converted to tensor!
    name: A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
    statically), or as a `Tensor`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):

    def make_shape_tensor(x):
      return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)

    def get_tensor_shape(s):
      if isinstance(s, tensor_shape.TensorShape):
        return s
      s_ = tensor_util.constant_value(make_shape_tensor(s))
      if s_ is not None:
        return tensor_shape.TensorShape(s_)
      return None

    def get_shape_tensor(s):
      if not isinstance(s, tensor_shape.TensorShape):
        return make_shape_tensor(s)
      if s.is_fully_defined():
        return make_shape_tensor(s.as_list())
      raise ValueError("Cannot broadcast from partially "
                       "defined `TensorShape`.")

    shape1_ = get_tensor_shape(shape1)
    shape2_ = get_tensor_shape(shape2)
    if shape1_ is not None and shape2_ is not None:
      return array_ops.broadcast_static_shape(shape1_, shape2_)

    shape1_ = get_shape_tensor(shape1)
    shape2_ = get_shape_tensor(shape2)
    return array_ops.broadcast_dynamic_shape(shape1_, shape2_)


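A hedged note on the fully-static branch above: static shape broadcasting follows NumPy's broadcasting rules, so `np.broadcast_shapes` (NumPy 1.20+) reproduces the result when both shapes are known.

```python
import numpy as np

# Dimensions are aligned from the right; size-1 dims stretch to match.
s = np.broadcast_shapes((3, 1, 5), (2, 5))
```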
def prefer_static_rank(x):
  """Return static rank of tensor `x` if available, else `tf.rank(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static rank is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.rank(x))


def prefer_static_shape(x):
  """Return static shape of tensor `x` if available, else `tf.shape(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static shape is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.shape(x))


def prefer_static_value(x):
  """Return static value of tensor `x` if available, else `x`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static value is obtainable), else `Tensor`.
  """
  static_x = tensor_util.constant_value(x)
  if static_x is not None:
    return static_x
  return x


def gen_new_seed(seed, salt):
  """Generate a new seed, from the given seed and salt."""
  if seed is None:
    return None
  string = (str(seed) + salt).encode("utf-8")
  return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF


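A standalone re-derivation of the hashing scheme in `gen_new_seed`, runnable with only the standard library: md5 the seed+salt string, take the first 8 hex digits, and mask to a nonnegative 31-bit int. The scheme is deterministic, so the same (seed, salt) pair always yields the same derived seed.

```python
import hashlib

def gen_new_seed_demo(seed, salt):
    # Same scheme as gen_new_seed above, stdlib only.
    if seed is None:
        return None
    string = (str(seed) + salt).encode("utf-8")
    # First 8 hex digits = 32 bits; the mask keeps a nonnegative int31.
    return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF

a = gen_new_seed_demo(42, "gamma")
b = gen_new_seed_demo(42, "gamma")
```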
def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """

  with ops.name_scope(name, "fill_triangular", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(1)[-1]) is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(x.shape.dims[-1].value)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError("Input right-most shape ({}) does not "
                         "correspond to a triangular matrix.".format(m))
      n = np.int32(n)
      static_final_shape = x.shape[:-1].concatenate([n, n])
    else:
      m = array_ops.shape(x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so
      # we omit it. We don't validate n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape, below.
      n = math_ops.cast(
          math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
          dtype=dtypes.int32)
      static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
          [None, None])
    # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
    #
    # We do this based on the insight that the input `x` provides `ceil(n/2)`
    # rows of an `n x n` matrix, some of which will get zeroed out being on
    # the wrong side of the diagonal. The first row will not get zeroed out at
    # all, and we need `floor(n/2)` more rows, so the first is what we omit
    # from `x_tail`. If we then stack those `ceil(n/2)` rows with the
    # `floor(n/2)` rows provided by a reversed tail, it is exactly the other
    # set of elements of the reversed tail which will be zeroed out for being
    # on the wrong side of the diagonal further up/down the matrix. And, in
    # doing so, we've filled the triangular matrix in a clockwise spiral
    # pattern. Neat!
    #
    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #  #            [5, 4, 3],
    #  #            [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #  #            [3, 4, 5],
    #  #            [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
|
| 890 |
+
# correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
|
| 891 |
+
# Furthermore observe that:
|
| 892 |
+
# m - (n**2 - m)
|
| 893 |
+
# = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
|
| 894 |
+
# = 2 (n**2 / 2 + n / 2) - n**2
|
| 895 |
+
# = n**2 + n - n**2
|
| 896 |
+
# = n
|
| 897 |
+
ndims = prefer_static_rank(x)
|
| 898 |
+
if upper:
|
| 899 |
+
x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
|
| 900 |
+
else:
|
| 901 |
+
x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
|
| 902 |
+
new_shape = (
|
| 903 |
+
static_final_shape.as_list() if static_final_shape.is_fully_defined()
|
| 904 |
+
else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
|
| 905 |
+
x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
|
| 906 |
+
x = array_ops.matrix_band_part(
|
| 907 |
+
x, num_lower=(0 if upper else -1), num_upper=(-1 if upper else 0))
|
| 908 |
+
x.set_shape(static_final_shape)
|
| 909 |
+
return x
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
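The clockwise-spiral fill described in the comments above can be sketched in pure Python. This is a hypothetical stand-alone helper (the name `fill_triangular_list` is introduced here, not part of TensorFlow), assuming plain Python lists instead of tensors:

```python
import math


def fill_triangular_list(x, upper=False):
    # Pure-Python sketch of fill_triangular: concatenate `x` with its
    # reversed tail, reshape to n x n, then zero out the wrong triangle.
    m = len(x)
    n = int(math.sqrt(0.25 + 2.0 * m) - 0.5)
    if n * (n + 1) // 2 != m:
        raise ValueError("len(x) = %d is not a triangular number" % m)
    x = list(x)
    # The derivation in the comments shows m - (n**2 - m) == n.
    x_tail = x[m - (n * n - m):]
    if upper:
        flat = x + list(reversed(x_tail))
        keep = lambda i, j: j >= i  # keep upper triangle
    else:
        flat = x_tail + list(reversed(x))
        keep = lambda i, j: j <= i  # keep lower triangle
    return [[flat[i * n + j] if keep(i, j) else 0 for j in range(n)]
            for i in range(n)]
```

Running it on the docstring's example input reproduces the documented matrices for both `upper=False` and `upper=True`.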
def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """

  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(2)[-1]) is not None:
      n = np.int32(x.shape.dims[-1].value)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = x.shape[:-2].concatenate([m])
    else:
      n = array_ops.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
          [None])
    ndims = prefer_static_rank(x)
    if upper:
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    rotated_triangular_portion = array_ops.reverse(
        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
        axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    end_sequence = array_ops.reshape(
        consolidated_matrix,
        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
    y.set_shape(static_final_shape)
    return y

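The inversion above exploits the same spiral structure: summing the triangular portion with its 180-degree rotation re-creates the concatenated flat sequence. A pure-Python sketch (hypothetical helper, not part of TensorFlow) for a single unbatched matrix:

```python
def fill_triangular_inverse_list(mat, upper=False):
    # Pure-Python sketch of fill_triangular_inverse for one n x n matrix
    # given as a list of row lists.
    n = len(mat)
    m = n * (n + 1) // 2
    if upper:
        initial = list(mat[0])             # first row survives intact
        portion = mat[1:]
    else:
        initial = list(reversed(mat[-1]))  # last row, reversed
        portion = mat[:-1]
    # Rotate the remaining rows by 180 degrees and add elementwise; the
    # zeroed halves of the two copies are complementary, so the sum is the
    # original concatenated sequence.
    rotated = [list(reversed(row)) for row in reversed(portion)]
    consolidated = [a + b
                    for row_p, row_r in zip(portion, rotated)
                    for a, b in zip(row_p, row_r)]
    return initial + consolidated[:m - n]
```

On the docstring's examples this recovers `[1, 2, 3, 4, 5, 6]` for both orientations.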
def tridiag(below=None, diag=None, above=None, name=None):
  """Creates a matrix with values set above, below, and on the diagonal.

  Example:

  ```python
  tridiag(below=[1., 2., 3.],
          diag=[4., 5., 6., 7.],
          above=[8., 9., 10.])
  # ==> array([[ 4.,  8.,  0.,  0.],
  #            [ 1.,  5.,  9.,  0.],
  #            [ 0.,  2.,  6., 10.],
  #            [ 0.,  0.,  3.,  7.]], dtype=float32)
  ```

  Warning: This Op is intended for convenience, not efficiency.

  Args:
    below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below
      diagonal part. `None` is logically equivalent to `below = 0`.
    diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal
      part. `None` is logically equivalent to `diag = 0`.
    above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above
      diagonal part. `None` is logically equivalent to `above = 0`.
    name: Python `str`. The name to give this op.

  Returns:
    tridiag: `Tensor` with values set above, below and on the diagonal.

  Raises:
    ValueError: if all inputs are `None`.
  """

  def _pad(x):
    """Prepends and appends a zero to every vector in a batch of vectors."""
    shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0)
    z = array_ops.zeros(shape, dtype=x.dtype)
    return array_ops.concat([z, x, z], axis=-1)

  def _add(*x):
    """Adds list of Tensors, ignoring `None`."""
    s = None
    for y in x:
      if y is None:
        continue
      elif s is None:
        s = y
      else:
        s += y
    if s is None:
      raise ValueError("Must specify at least one of `below`, `diag`, `above`.")
    return s

  with ops.name_scope(name, "tridiag", [below, diag, above]):
    if below is not None:
      below = ops.convert_to_tensor(below, name="below")
      below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:]
    if diag is not None:
      diag = ops.convert_to_tensor(diag, name="diag")
      diag = array_ops.matrix_diag(diag)
    if above is not None:
      above = ops.convert_to_tensor(above, name="above")
      above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1]
    # TODO(jvdillon): Consider using scatter_nd instead of creating three full
    # matrices.
    return _add(below, diag, above)

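The placement rule the function implements (diagonal at `(i, i)`, `below` at `(i, i-1)`, `above` at `(i, i+1)`) can be sketched directly in pure Python. `tridiag_list` is a hypothetical helper introduced here for illustration:

```python
def tridiag_list(below, diag, above):
    # Pure-Python sketch of tridiag for a single (unbatched) matrix:
    # diag[i] on the main diagonal, below[j] one step under it,
    # above[i] one step over it, zeros elsewhere.
    d = len(diag)
    return [[diag[i] if i == j
             else below[j] if i == j + 1
             else above[i] if j == i + 1
             else 0.
             for j in range(d)]
            for i in range(d)]
```

Applied to the docstring's example inputs it yields the documented 4x4 matrix.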
def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to directly
  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.math.log(w))` is
  more efficient than `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `input_tensor` along the dimensions given in `axis`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than log(sum(w * exp(input))). It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
    logx = ops.convert_to_tensor(logx, name="logx")
    if w is None:
      lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
      if return_sign:
        sgn = array_ops.ones_like(lswe)
        return lswe, sgn
      return lswe
    w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
    log_absw_x = logx + math_ops.log(math_ops.abs(w))
    max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
    # If the largest element is `-inf` or `inf` then we don't bother subtracting
    # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
    # this is ok follows from the fact that we're actually free to subtract any
    # value we like, so long as we add it back after taking the `log(sum(...))`.
    max_log_absw_x = array_ops.where_v2(
        math_ops.is_inf(max_log_absw_x), array_ops.zeros_like(max_log_absw_x),
        max_log_absw_x)
    wx_over_max_absw_x = (
        math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
    sum_wx_over_max_absw_x = math_ops.reduce_sum(
        wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
    if not keep_dims:
      max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
    sgn = math_ops.sign(sum_wx_over_max_absw_x)
    lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x)
    if return_sign:
      return lswe, sgn
    return lswe

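The max-shift stabilization above can be sketched for a flat list of scalars in pure Python. This is a hypothetical helper (`reduce_weighted_logsumexp_list` is introduced here), assuming every weight is nonzero so `log|w|` is finite:

```python
import math


def reduce_weighted_logsumexp_list(logx, w):
    # Pure-Python, 1-D sketch of the stabilized weighted logsumexp:
    # shift by the max of log|w| + logx, sum signed exponentials, shift back.
    log_absw_x = [lx + math.log(abs(wi)) for lx, wi in zip(logx, w)]
    mx = max(log_absw_x)
    if math.isinf(mx):
        mx = 0.  # don't shift by +/-inf; avoids inf - inf = nan
    total = sum((1. if wi > 0 else -1.) * math.exp(la - mx)
                for wi, la in zip(w, log_absw_x))
    sgn = 1. if total > 0 else (-1. if total < 0 else 0.)
    return mx + math.log(sgn * total), sgn
```

With the docstring's data (`logx` all zero, one weight of -1 among five of +1) this returns `log(4)` with sign `+1`, matching the documented result.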
# TODO(jvdillon): Merge this test back into:
# tensorflow/python/ops/softplus_op_test.py
# once TF core is accepting new ops.
def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with ops.name_scope(name, "softplus_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # We begin by deriving a more numerically stable softplus_inverse:
    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
    # ==> exp{x} = 1 + exp{y}                                (1)
    # ==> y = Log[exp{x} - 1]                                (2)
    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #       = Log[1 - exp{-x}] + x                           (3)
    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x.
    # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will
    # be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
    #
    # In addition to the numerically stable derivation above, we clamp
    # small/large values to be congruent with the logic in:
    # tensorflow/core/kernels/softplus_op.h
    #
    # Finally, we set the input to one whenever the input is too large or too
    # small. This ensures that no unchosen codepath is +/- inf. This is
    # necessary to ensure the gradient doesn't get NaNs. Recall that the
    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
    # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful
    # to overwrite `x` with ones only when we will never actually use this
    # value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`.
    threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
    is_too_small = math_ops.less(x, np.exp(threshold))
    is_too_large = math_ops.greater(x, -threshold)
    too_small_value = math_ops.log(x)
    too_large_value = x
    # This `where` will ultimately be a NOP because we won't select this
    # codepath whenever we used the surrogate `ones_like`.
    x = array_ops.where_v2(
        math_ops.logical_or(is_too_small, is_too_large), array_ops.ones_like(x),
        x)
    y = x + math_ops.log(-math_ops.expm1(-x))  # == log(expm1(x))
    return array_ops.where_v2(
        is_too_small, too_small_value,
        array_ops.where_v2(is_too_large, too_large_value, y))

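Form (3) from the derivation above can be demonstrated on plain scalars. A hypothetical sketch (these helper names are introduced here; the clamping for extreme inputs is omitted since scalars in a moderate range don't need it):

```python
import math


def softplus_scalar(y):
    # Reference forward map: softplus(y) = log(1 + exp(y)).
    return math.log1p(math.exp(y))


def softplus_inverse_scalar(x):
    # Form (3) above: log(exp(x) - 1) == x + log(1 - exp(-x)),
    # which stays stable for large x where exp(x) would overflow form (2).
    return x + math.log(-math.expm1(-x))
```

Round-tripping `softplus_inverse_scalar(softplus_scalar(y))` recovers `y` to high precision across a range of inputs.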
# TODO(b/35290280): Add unit-tests.
def dimension_size(x, axis):
  """Returns the size of a specific dimension."""
  # Since tf.gather isn't "constant-in, constant-out", we must first check the
  # static shape or fallback to dynamic shape.
  s = tensor_shape.dimension_value(
      x.shape.with_rank_at_least(np.abs(axis))[axis])
  if s is not None:
    return s
  return array_ops.shape(x)[axis]

def process_quadrature_grid_and_probs(quadrature_grid_and_probs,
                                      dtype,
                                      validate_args,
                                      name=None):
  """Validates quadrature grid, probs or computes them as necessary.

  Args:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weight. When `None`, defaults to:
      `np.polynomial.hermite.hermgauss(deg=8)`.
    dtype: The expected `dtype` of `grid` and `probs`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weight.

  Raises:
    ValueError: if `quadrature_grid_and_probs is not None` and
      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
  """
  with ops.name_scope(name, "process_quadrature_grid_and_probs",
                      [quadrature_grid_and_probs]):
    if quadrature_grid_and_probs is None:
      grid, probs = np.polynomial.hermite.hermgauss(deg=8)
      grid = grid.astype(dtype.as_numpy_dtype)
      probs = probs.astype(dtype.as_numpy_dtype)
      probs /= np.linalg.norm(probs, ord=1, keepdims=True)
      grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
      probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
      return grid, probs

    grid, probs = tuple(quadrature_grid_and_probs)
    grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
    probs = ops.convert_to_tensor(probs, name="unnormalized_probs", dtype=dtype)
    probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs")

    def _static_event_size(x):
      """Returns the static size of a specific dimension or `None`."""
      return tensor_shape.dimension_value(x.shape.with_rank_at_least(1)[-1])

    m, n = _static_event_size(probs), _static_event_size(grid)
    if m is not None and n is not None:
      if m != n:
        raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
                         "same-length zero-th-dimension `Tensor`s "
                         "(saw lengths {}, {})".format(m, n))
    elif validate_args:
      assertions = [
          check_ops.assert_equal(
              dimension_size(probs, axis=-1),
              dimension_size(grid, axis=-1),
              message=("`quadrature_grid_and_probs` must be a `tuple` of "
                       "same-length zero-th-dimension `Tensor`s")),
      ]
      with ops.control_dependencies(assertions):
        grid = array_ops.identity(grid)
        probs = array_ops.identity(probs)
    return grid, probs

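The l1 normalization applied to the quadrature weights above (via `np.linalg.norm(..., ord=1)` / `linalg_ops.norm(..., ord=1)`) reduces to dividing by the sum of absolute values. A hypothetical pure-Python sketch of just that step:

```python
def normalize_quadrature_probs(probs):
    # l1-normalize quadrature weights: divide by the sum of absolute
    # values so (nonnegative) weights sum to one, preserving ratios.
    total = sum(abs(p) for p in probs)
    return [p / total for p in probs]
```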
def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
  """Pads `value` to the front and/or back of a `Tensor` dim, `count` times.

  Args:
    x: `Tensor` input.
    axis: Scalar `int`-like `Tensor` representing the single dimension to pad.
      (Negative indexing is supported.)
    front: Python `bool`; if `True` the beginning of the `axis` dimension is
      padded with `value`, `count` times. If `False` no front padding is made.
    back: Python `bool`; if `True` the end of the `axis` dimension is padded
      with `value`, `count` times. If `False` no end padding is made.
    value: Scalar `int`-like `Tensor` representing the actual value added to the
      front and/or back of the `axis` dimension of `x`.
    count: Scalar `int`-like `Tensor` representing number of elements added to
      the front and/or back of the `axis` dimension of `x`. E.g., if `front =
      back = True` then `2 * count` elements are added.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    pad: The padded version of input `x`.

  Raises:
    ValueError: if both `front` and `back` are `False`.
    TypeError: if `count` is not `int`-like.
  """
  with ops.name_scope(name, "pad", [x, value, count]):
    x = ops.convert_to_tensor(x, name="x")
    value = ops.convert_to_tensor(value, dtype=x.dtype, name="value")
    count = ops.convert_to_tensor(count, name="count")
    if not count.dtype.is_integer:
      raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format(
          count.dtype.name))
    if not front and not back:
      raise ValueError("At least one of `front`, `back` must be `True`.")
    ndims = (
        x.shape.ndims if x.shape.ndims is not None else array_ops.rank(
            x, name="ndims"))
    axis = ops.convert_to_tensor(axis, name="axis")
    axis_ = tensor_util.constant_value(axis)
    if axis_ is not None:
      axis = axis_
      if axis < 0:
        axis = ndims + axis
      count_ = tensor_util.constant_value(count)
      if axis_ >= 0 or x.shape.ndims is not None:
        head = x.shape[:axis]
        middle = tensor_shape.TensorShape(None if count_ is None else (
            tensor_shape.dimension_at_index(x.shape, axis) + count_ *
            (front + back)))
        tail = x.shape[axis + 1:]
        final_shape = head.concatenate(middle.concatenate(tail))
      else:
        final_shape = None
    else:
      axis = array_ops.where_v2(axis < 0, ndims + axis, axis)
      final_shape = None
    x = array_ops.pad(
        x,
        paddings=array_ops.one_hot(
            indices=array_ops_stack.stack(
                [axis if front else -1, axis if back else -1]),
            depth=ndims,
            axis=0,
            on_value=count,
            dtype=dtypes.int32),
        constant_values=value)
    if final_shape is not None:
      x.set_shape(final_shape)
    return x

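For a single dimension, the padding behavior documented above reduces to gluing `count` copies of `value` onto the chosen end(s). A hypothetical 1-D pure-Python sketch (`pad_1d` is a name introduced here):

```python
def pad_1d(x, front=False, back=False, value=0, count=1):
    # 1-D sketch of pad: `count` copies of `value` on the front and/or back.
    if not front and not back:
        raise ValueError("At least one of `front`, `back` must be True.")
    padding = [value] * count
    return (padding if front else []) + list(x) + (padding if back else [])
```

With `front = back = True`, `2 * count` elements are added, as the docstring notes.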
def parent_frame_arguments():
  """Returns parent frame arguments.

  When called inside a function, returns a dictionary with the caller's function
  arguments. These are positional arguments and keyword arguments (**kwargs),
  while variable arguments (*varargs) are excluded.

  When called at global scope, this will return an empty dictionary, since there
  are no arguments.

  WARNING: If caller function argument names are overloaded before invoking
  this method, then values will reflect the overloaded value. For this reason,
  we recommend calling `parent_frame_arguments` at the beginning of the
  function.
  """
  # All arguments and the names used for *varargs, and **kwargs
  arg_names, variable_arg_name, keyword_arg_name, local_vars = (
      tf_inspect._inspect.getargvalues(  # pylint: disable=protected-access
          # Get the first frame of the caller of this method.
          tf_inspect._inspect.stack()[1][0]))  # pylint: disable=protected-access

  # Remove the *varargs, and flatten the **kwargs. Both are
  # nested lists.
  local_vars.pop(variable_arg_name, {})
  keyword_args = local_vars.pop(keyword_arg_name, {})

  final_args = {}
  # Copy over arguments and their values. In general, local_vars
  # may contain more than just the arguments, since this method
  # can be called anywhere in a function.
  for arg_name in arg_names:
    final_args[arg_name] = local_vars.pop(arg_name)
  final_args.update(keyword_args)

  return final_args

class AppendDocstring:
  """Helper class to promote private subclass docstring to public counterpart.

  Example:

  ```python
  class TransformedDistribution(Distribution):
    @distribution_util.AppendDocstring(
        additional_note="A special note!",
        kwargs_dict={"foo": "An extra arg."})
    def _prob(self, y, foo=None):
      pass
  ```

  In this case, the `AppendDocstring` decorator appends the `additional_note` to
  the docstring of `prob` (not `_prob`) and adds a new `kwargs`
  section with each dictionary item as a bullet-point.

  For a more detailed example, see `TransformedDistribution`.
  """

  def __init__(self, additional_note="", kwargs_dict=None):
    """Initializes the AppendDocstring object.

    Args:
      additional_note: Python string added as additional docstring to public
        version of function.
      kwargs_dict: Python string/string dictionary representing specific kwargs
        expanded from the **kwargs input.

    Raises:
      ValueError: if kwargs_dict.key contains whitespace.
      ValueError: if kwargs_dict.value contains newlines.
    """
    self._additional_note = additional_note
    if kwargs_dict:
      bullets = []
      for key in sorted(kwargs_dict.keys()):
        value = kwargs_dict[key]
        if any(x.isspace() for x in key):
          raise ValueError("Parameter name \"%s\" contains whitespace." % key)
        value = value.lstrip()
        if "\n" in value:
          raise ValueError(
              "Parameter description for \"%s\" contains newlines." % key)
        bullets.append("* `%s`: %s" % (key, value))
      self._additional_note += ("\n\n##### `kwargs`:\n\n" + "\n".join(bullets))

  def __call__(self, fn):

    @functools.wraps(fn)
    def _fn(*args, **kwargs):
      return fn(*args, **kwargs)

    if _fn.__doc__ is None:
      _fn.__doc__ = self._additional_note
    else:
      _fn.__doc__ += "\n%s" % self._additional_note
    return _fn
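The decorator pattern `AppendDocstring` uses (wrap the function with `functools.wraps`, then mutate the wrapper's `__doc__`) can be demonstrated in isolation. A minimal self-contained sketch (the class `AppendNoteSketch` is hypothetical, introduced here, and omits the `kwargs_dict` bullet logic):

```python
import functools


class AppendNoteSketch:
    # Minimal sketch of the decorator pattern above: wrap `fn` so the
    # wrapper's docstring gains an extra note while name/behavior survive.
    def __init__(self, additional_note=""):
        self._additional_note = additional_note

    def __call__(self, fn):
        @functools.wraps(fn)
        def _fn(*args, **kwargs):
            return fn(*args, **kwargs)
        if _fn.__doc__ is None:
            _fn.__doc__ = self._additional_note
        else:
            _fn.__doc__ += "\n%s" % self._additional_note
        return _fn


@AppendNoteSketch("A special note!")
def prob(y):
    """Original docstring."""
    return y
```

Because `functools.wraps` copies `__name__` and `__doc__` before the note is appended, `prob` keeps its identity and behavior while its docstring gains the note.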
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/losses/__pycache__/__init__.cpython-310.pyc
ADDED: Binary file (204 Bytes)

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/losses/__pycache__/losses.cpython-310.pyc
ADDED: Binary file (480 Bytes)

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/losses/__pycache__/losses_impl.cpython-310.pyc
ADDED: Binary file (40.4 kB)

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/losses/__pycache__/util.cpython-310.pyc
ADDED: Binary file (8.32 kB)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__init__.py
ADDED
@@ -0,0 +1,37 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Signal processing operations.

See the [tf.signal](https://tensorflow.org/api_guides/python/contrib.signal)
guide.

@@frame
@@hamming_window
@@hann_window
@@inverse_stft
@@inverse_stft_window_fn
@@mfccs_from_log_mel_spectrograms
@@linear_to_mel_weight_matrix
@@overlap_and_add
@@stft

[hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window
[hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window
[mel]: https://en.wikipedia.org/wiki/Mel_scale
[mfcc]: https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform

API docstring: tensorflow.signal
"""
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/__init__.cpython-310.pyc
ADDED: Binary file (851 Bytes)

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/dct_ops.cpython-310.pyc
ADDED: Binary file (8.56 kB)

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/fft_ops.cpython-310.pyc
ADDED: Binary file (17 kB)

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/mel_ops.cpython-310.pyc
ADDED: Binary file (6.88 kB)

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/mfcc_ops.cpython-310.pyc
ADDED: Binary file (3.68 kB)

SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/signal/__pycache__/reconstruction_ops.cpython-310.pyc
ADDED: Binary file (3.23 kB)