| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import flax.linen as nn |
| | import jax |
| | import jax.numpy as jnp |
| |
|
| |
|
class FlaxUpsample2D(nn.Module):
    """Upsample a feature map by 2x and refine it with a 3x3 convolution.

    Attributes:
        out_channels: Number of features produced by the convolution.
        dtype: dtype passed to the convolution (``jnp.float32`` by default).
    """

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the 3x3, stride-1 convolution applied after resizing."""
        conv_options = dict(
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )
        self.conv = nn.Conv(self.out_channels, **conv_options)

    def __call__(self, hidden_states):
        """Resize ``hidden_states`` to twice its height and width, then convolve.

        Args:
            hidden_states: 4-D array whose shape is unpacked as
                ``(batch, height, width, channels)``.

        Returns:
            The convolved array computed from the
            ``(batch, 2 * height, 2 * width, channels)`` resized input.
        """
        n, h, w, c = hidden_states.shape
        # Only the two middle axes grow; batch and channel sizes are untouched.
        target_shape = (n, h * 2, w * 2, c)
        upsampled = jax.image.resize(hidden_states, shape=target_shape, method="nearest")
        return self.conv(upsampled)
| |
|
| |
|
class FlaxDownsample2D(nn.Module):
    """Downsample a feature map with a single stride-2 3x3 convolution.

    Attributes:
        out_channels: Number of features produced by the convolution.
        dtype: dtype passed to the convolution (``jnp.float32`` by default).
    """

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the 3x3 convolution with stride 2 and one pixel of padding per side."""
        self.conv = nn.Conv(
            features=self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        """Apply the strided convolution to ``hidden_states`` and return the result."""
        return self.conv(hidden_states)
| |
|
| |
|
class FlaxResnetBlock2D(nn.Module):
    """Residual block with a time-embedding injection between its two conv stages.

    Each stage is GroupNorm -> swish -> 3x3 convolution. The projected time
    embedding is added after the first stage, dropout is applied before the
    second convolution, and the block input (optionally passed through a 1x1
    convolution) is added to the result.

    Attributes:
        in_channels: Channel count used to decide whether a shortcut
            projection is needed and as the fallback for ``out_channels``.
        out_channels: Number of output features; ``None`` means use
            ``in_channels``.
        dropout_prob: Rate handed to ``nn.Dropout``.
        use_nin_shortcut: Force (``True``) or suppress (``False``) the 1x1
            shortcut convolution; ``None`` enables it only when
            ``in_channels`` differs from the resolved output channel count.
        dtype: dtype passed to the convolutions and the dense projection.

    NOTE(review): both GroupNorm layers are built with 32 groups, so channel
    counts are presumably expected to be multiples of 32 -- confirm with callers.
    """

    in_channels: int
    out_channels: int = None
    dropout_prob: float = 0.0
    use_nin_shortcut: bool = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Build the norm, convolution, projection, dropout and shortcut submodules."""
        if self.out_channels is None:
            features = self.in_channels
        else:
            features = self.out_channels

        # Both main-path convolutions share the same 3x3 / stride-1 / pad-1 setup.
        conv3x3_options = dict(
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv1 = nn.Conv(features, **conv3x3_options)

        self.time_emb_proj = nn.Dense(features, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(features, **conv3x3_options)

        # An explicit flag wins; otherwise project only when channel counts differ.
        if self.use_nin_shortcut is None:
            project_shortcut = self.in_channels != features
        else:
            project_shortcut = self.use_nin_shortcut

        if project_shortcut:
            self.conv_shortcut = nn.Conv(
                features,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )
        else:
            self.conv_shortcut = None

    def __call__(self, hidden_states, temb, deterministic=True):
        """Run the residual block.

        Args:
            hidden_states: Input feature map; it is also used as the residual.
            temb: Time embedding; presumably ``(batch, embed_dim)`` -- confirm
                with callers. It is passed through swish and a dense layer,
                then given two singleton axes after the first axis so it
                broadcasts against the feature map.
            deterministic: Forwarded to the dropout layer.

        Returns:
            The sum of the transformed features and the (optionally projected)
            residual.
        """
        shortcut = hidden_states

        features = self.conv1(nn.swish(self.norm1(hidden_states)))

        time_bias = self.time_emb_proj(nn.swish(temb))
        # Insert two unit axes at position 1 so the bias broadcasts over them.
        features = features + time_bias[:, None, None]

        features = nn.swish(self.norm2(features))
        features = self.dropout(features, deterministic)
        features = self.conv2(features)

        if self.conv_shortcut is not None:
            shortcut = self.conv_shortcut(shortcut)

        return features + shortcut
| |
|