import mlx.core as mx
from mlx.nn.layers.base import Module


class Dropout(Module):
| r"""Randomly zero a portion of the elements during training. |
| |
| The remaining elements are multiplied with :math:`\frac{1}{1-p}` where |
| :math:`p` is the probability of zeroing an element. This is done so the |
| expected value of a given element will remain the same. |
| |
| Args: |
| p (float): The probability to zero an element |
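
    Example:
        A minimal usage sketch (the dropout mask is random, so exact
        outputs vary from run to run):

        .. code-block:: python

            import mlx.core as mx
            import mlx.nn as nn

            layer = nn.Dropout(p=0.5)
            layer.train()  # dropout is only active in training mode

            x = mx.ones((4, 4))
            y = layer(x)  # ~half the entries zeroed, the rest scaled by 1 / (1 - p) = 2

            layer.eval()
            z = layer(x)  # identity: the input is returned unchanged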
| """ |
|
|
| def __init__(self, p: float = 0.5): |
| super().__init__() |
|
|
| if p < 0 or p >= 1: |
| raise ValueError(f"The dropout probability {p} is not in [0, 1)") |
|
|
| self._p_1 = 1 - p |
|
|
| def _extra_repr(self): |
| return f"p={1-self._p_1}" |
|
|
| def __call__(self, x): |
| if self._p_1 == 1 or not self.training: |
| return x |
|
|
| mask = mx.random.bernoulli(self._p_1, x.shape) |
|
|
| return (1 / self._p_1) * mask * x |


class Dropout2d(Module):
| r"""Apply 2D channel-wise dropout during training. |
| |
| Randomly zero out entire channels independently with probability :math:`p`. |
| This layer expects the channels to be last, i.e. the input shape should be |
| ``NWHC`` or ``WHC`` where:``N`` is the batch dimension,``H`` is the input |
| image height,``W`` is the input image width, and``C`` is the number of |
| input channels |
| |
| The remaining channels are scaled by :math:`\frac{1}{1-p}` to |
| maintain the expected value of each element. Unlike traditional dropout, |
| which zeros individual entries, this layer zeros entire channels. This is |
| beneficial for early convolution layers where adjacent pixels are |
| correlated. In such case, traditional dropout may not effectively |
| regularize activations. For more details, see [1]. |
| |
| [1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015. |
| Efficient Object Localization Using Convolutional Networks. CVPR 2015. |
| |
| Args: |
| p (float): Probability of zeroing a channel during training. |
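
    Example:
        A minimal usage sketch with a hypothetical channels-last batch (the
        mask is random, so which channels are zeroed varies per run):

        .. code-block:: python

            import mlx.core as mx
            import mlx.nn as nn

            layer = nn.Dropout2d(p=0.5)
            layer.train()

            # NHWC: batch of 8 images, 32x32 pixels, 16 channels.
            x = mx.ones((8, 32, 32, 16))
            y = layer(x)  # each channel is either all zeros or all 2.0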
| """ |
|
|
| def __init__(self, p: float = 0.5): |
| super().__init__() |
|
|
| if p < 0 or p >= 1: |
| raise ValueError(f"The dropout probability {p} is not in [0, 1)") |
|
|
| self._p_1 = 1 - p |
|
|
| def _extra_repr(self): |
| return f"p={1-self._p_1}" |
|
|
| def __call__(self, x): |
| if x.ndim not in (3, 4): |
| raise ValueError( |
| f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions." |
| ) |
|
|
| if self._p_1 == 1 or not self.training: |
| return x |
|
|
| |
| |
| |
| mask_shape = x.shape |
| mask_shape[-2] = mask_shape[-3] = 1 |
|
|
| mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape) |
| return (1 / self._p_1) * mask * x |


class Dropout3d(Module):
| r"""Apply 3D channel-wise dropout during training. |
| |
| Randomly zero out entire channels independently with probability :math:`p`. |
| This layer expects the channels to be last, i.e., the input shape should be |
| `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth, |
| `H` is the input image height, `W` is the input image width, and `C` is |
| the number of input channels. |
| |
| The remaining channels are scaled by :math:`\frac{1}{1-p}` to |
| maintain the expected value of each element. Unlike traditional dropout, |
| which zeros individual entries, this layer zeros entire channels. This is |
| often beneficial for convolutional layers processing 3D data, like in |
| medical imaging or video processing. |
| |
| Args: |
| p (float): Probability of zeroing a channel during training. |
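
    Example:
        A minimal usage sketch with a hypothetical channels-last volume (the
        mask is random, so which channels are zeroed varies per run):

        .. code-block:: python

            import mlx.core as mx
            import mlx.nn as nn

            layer = nn.Dropout3d(p=0.5)
            layer.train()

            # NDHWC: batch of 2 volumes, depth 16, 32x32 slices, 8 channels.
            x = mx.ones((2, 16, 32, 32, 8))
            y = layer(x)  # entire channels are zeroed across the whole volume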
| """ |
|
|
| def __init__(self, p: float = 0.5): |
| super().__init__() |
|
|
| if p < 0 or p >= 1: |
| raise ValueError(f"The dropout probability {p} is not in [0, 1)") |
|
|
| self._p_1 = 1 - p |
|
|
| def _extra_repr(self): |
| return f"p={1-self._p_1}" |
|
|
| def __call__(self, x): |
| if x.ndim not in (4, 5): |
| raise ValueError( |
| f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions." |
| ) |
|
|
| if self._p_1 == 1 or not self.training: |
| return x |
|
|
| |
| |
| |
| mask_shape = list(x.shape) |
| mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1 |
|
|
| mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape) |
| return (1 / self._p_1) * mask * x |
|
|