Spaces:
Runtime error
Runtime error
| # Copyright 2017 The TensorFlow Authors All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Utility functions for blocks.""" | |
| from __future__ import division | |
| from __future__ import unicode_literals | |
| import math | |
| import numpy as np | |
| import six | |
| import tensorflow as tf | |
class RsqrtInitializer(object):
  """Truncated-normal initializer whose standard deviation is 1/sqrt(n).

  Here n is the product of the entries of the requested shape found at the
  configured dimension indices. Sampling is delegated to tf.truncated_normal,
  so any value falling outside two standard deviations is dropped and drawn
  again.
  """

  def __init__(self, dims=(0,), **kwargs):
    """Stores the initializer configuration.

    Args:
      dims: A single integer index, or an iterable of indices, selecting the
        shape entries whose product n yields the standard deviation
        1.0 / sqrt(n).
      **kwargs: Additional keyword arguments forwarded to tf.truncated_normal
        on every call.
    """
    # A lone integer index is wrapped in a list so that __call__ can always
    # iterate over self._dims.
    self._dims = [dims] if isinstance(dims, six.integer_types) else dims
    self._kwargs = kwargs

  def __call__(self, shape, dtype):
    """Samples a tensor of the requested shape and dtype.

    Args:
      shape: Shape of the tensor to create; it is indexed with every entry of
        the configured dims.
      dtype: Data type forwarded to tf.truncated_normal.

    Returns:
      The value returned by tf.truncated_normal for a standard deviation of
      1.0 / sqrt(product(shape[dims])).
    """
    fan = np.prod([shape[index] for index in self._dims])
    return tf.truncated_normal(
        shape=shape, dtype=dtype, stddev=1.0 / np.sqrt(fan), **self._kwargs)
class RectifierInitializer(object):
  """Gaussian initializer with standard deviation sqrt(scale / fan_in).

  With the default scale of 2.0 this gives sqrt(2 / fan_in), where fan_in is
  the product of the entries of the requested shape at the configured
  dimension indices.

  Note that tf.random_normal (not the truncated variant) is used internally to
  ensure the expected weight distribution. This is intended to be used with
  ReLU activations, especially in ResNets.

  For details please refer to:
  Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet
  Classification
  """

  def __init__(self, dims=(0,), scale=2.0, **kwargs):
    """Creates an initializer.

    Args:
      dims: Dimension(s) index to compute standard deviation:
        sqrt(scale / product(shape[dims])). May be a single integer or an
        iterable of integers.
      scale: A constant scaling for the initialization used as
        sqrt(scale / product(shape[dims])).
      **kwargs: Extra keyword arguments to pass to tf.random_normal.
    """
    # A lone integer index is wrapped in a list so that __call__ can always
    # iterate over self._dims.
    if isinstance(dims, six.integer_types):
      self._dims = [dims]
    else:
      self._dims = dims
    self._kwargs = kwargs
    self._scale = scale

  def __call__(self, shape, dtype):
    """Samples a tensor of the requested shape and dtype.

    Args:
      shape: Shape of the tensor to create; it is indexed with every entry of
        the configured dims.
      dtype: Data type forwarded to tf.random_normal.

    Returns:
      The value returned by tf.random_normal for a standard deviation of
      sqrt(scale / product(shape[dims])).
    """
    fan_in = np.prod([shape[x] for x in self._dims])
    stddev = np.sqrt(self._scale / fan_in)
    return tf.random_normal(
        shape=shape, dtype=dtype, stddev=stddev, **self._kwargs)
class GaussianInitializer(object):
  """Truncated-normal initializer with a caller-supplied standard deviation.

  Sampling is delegated to tf.truncated_normal, so any value falling outside
  two standard deviations is dropped and drawn again.
  """

  def __init__(self, stddev=1.0):
    """Stores the standard deviation used for every sample.

    Args:
      stddev: Standard deviation forwarded to tf.truncated_normal.
    """
    self._stddev = stddev

  def __call__(self, shape, dtype):
    """Samples a tensor of the requested shape and dtype.

    Args:
      shape: Shape forwarded to tf.truncated_normal.
      dtype: Data type forwarded to tf.truncated_normal.

    Returns:
      The value returned by tf.truncated_normal.
    """
    sigma = self._stddev
    return tf.truncated_normal(shape=shape, dtype=dtype, stddev=sigma)