| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import flax.linen as nn |
| | import jax.numpy as jnp |
| |
|
| | from ..attention_flax import FlaxTransformer2DModel |
| | from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D |
| |
|
| |
|
class FlaxCrossAttnDownBlock2D(nn.Module):
    r"""
    Cross-attention 2D down block: `num_layers` pairs of (ResNet block, spatial transformer), optionally followed by
    one downsampling layer. Architecture reference cited by this file: https://arxiv.org/abs/2103.06104

    Parameters:
        in_channels (:obj:`int`):
            Input channels of the first ResNet block
        out_channels (:obj:`int`):
            Output channels of every ResNet block, transformer and of the downsampler
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, forwarded to each ResNet block as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (ResNet block, transformer) pairs
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block; the per-head width is
            `out_channels // num_attention_heads`
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append a downsampling layer after the last pair
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to every `FlaxTransformer2DModel`
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to every `FlaxTransformer2DModel`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`, forwarded to every sub-module
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Forwarded to every `FlaxTransformer2DModel` as its `depth`
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_downsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1

    def setup(self):
        """Build the ResNet/transformer pairs and, if requested, the downsampler."""
        # Only the first ResNet block consumes `in_channels`; the following ones chain on `out_channels`.
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.out_channels if layer_idx else self.in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for layer_idx in range(self.num_layers)
        ]

        # Every transformer is configured identically and works on the ResNet output width.
        self.attentions = [
            FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            for _ in range(self.num_layers)
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        """
        Run the block.

        Args:
            hidden_states: input feature map.
            temb: embedding handed to every ResNet block (presumably the timestep embedding).
            encoder_hidden_states: conditioning handed to every transformer.
            deterministic: forwarded to every ResNet block and transformer.

        Returns:
            `(hidden_states, output_states)`, where `output_states` is a tuple with the output of each
            (ResNet, transformer) pair, plus the downsampler output when `add_downsample` is set.
        """
        collected = []

        for resnet, attn in zip(self.resnets, self.attentions):
            resnet_out = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn(resnet_out, encoder_hidden_states, deterministic=deterministic)
            collected.append(hidden_states)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
| |
|
| |
|
class FlaxDownBlock2D(nn.Module):
    r"""
    Flax 2D down block: a stack of ResNet blocks, optionally followed by one downsampling layer.

    Parameters:
        in_channels (:obj:`int`):
            Input channels of the first ResNet block
        out_channels (:obj:`int`):
            Output channels of every ResNet block and of the downsampler
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, forwarded to each ResNet block as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet blocks
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append a downsampling layer after the last ResNet block
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`, forwarded to every sub-module
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Build the ResNet stack and, if requested, the downsampler."""
        # Only the first ResNet block consumes `in_channels`; the following ones chain on `out_channels`.
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.out_channels if layer_idx else self.in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for layer_idx in range(self.num_layers)
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, deterministic=True):
        """
        Run the block.

        Args:
            hidden_states: input feature map.
            temb: embedding handed to every ResNet block (presumably the timestep embedding).
            deterministic: forwarded to every ResNet block.

        Returns:
            `(hidden_states, output_states)`, where `output_states` is a tuple with the output of each ResNet
            block, plus the downsampler output when `add_downsample` is set.
        """
        collected = []

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            collected.append(hidden_states)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
| |
|
| |
|
class FlaxCrossAttnUpBlock2D(nn.Module):
    r"""
    Cross-attention 2D up block: `num_layers` pairs of (ResNet block, spatial transformer), each fed with the
    running activation concatenated with one skip tensor, optionally followed by one upsampling layer.
    Architecture reference cited by this file: https://arxiv.org/abs/2103.06104

    Parameters:
        in_channels (:obj:`int`):
            Channel count assumed for the skip tensor consumed by the last ResNet block
        out_channels (:obj:`int`):
            Output channels of every ResNet block, transformer and of the upsampler; also the channel count
            assumed for the skip tensors consumed by all but the last ResNet block
        prev_output_channel (:obj:`int`):
            Channel count assumed for the activation entering the first ResNet block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, forwarded to each ResNet block as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (ResNet block, transformer) pairs
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block; the per-head width is
            `out_channels // num_attention_heads`
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append an upsampling layer after the last pair
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to every `FlaxTransformer2DModel`
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to every `FlaxTransformer2DModel`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`, forwarded to every sub-module
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Forwarded to every `FlaxTransformer2DModel` as its `depth`
    """

    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_upsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1

    def setup(self):
        """Build the ResNet/transformer pairs and, if requested, the upsampler."""
        final_idx = self.num_layers - 1

        resnet_stack = []
        for layer_idx in range(self.num_layers):
            # Width of the running activation: the previous block's output for the first layer,
            # this block's own output width afterwards.
            running_width = self.out_channels if layer_idx != 0 else self.prev_output_channel
            # Width of the skip tensor concatenated in front of this layer.
            skip_width = self.out_channels if layer_idx != final_idx else self.in_channels

            resnet_stack.append(
                FlaxResnetBlock2D(
                    in_channels=running_width + skip_width,
                    out_channels=self.out_channels,
                    dropout_prob=self.dropout,
                    dtype=self.dtype,
                )
            )

        # Every transformer is configured identically and works on the ResNet output width.
        attention_stack = [
            FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            for _ in range(self.num_layers)
        ]

        self.resnets = resnet_stack
        self.attentions = attention_stack

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
        """
        Run the block.

        Args:
            hidden_states: activation coming from the previous block.
            res_hidden_states_tuple: skip tensors; they are consumed from the end, one per layer.
            temb: embedding handed to every ResNet block (presumably the timestep embedding).
            encoder_hidden_states: conditioning handed to every transformer.
            deterministic: forwarded to every ResNet block and transformer.

        Returns:
            The activation after the last pair, upsampled when `add_upsample` is set.
        """
        for depth, (resnet, attn) in enumerate(zip(self.resnets, self.attentions), start=1):
            # The depth-th layer takes the depth-th skip tensor counted from the end.
            skip = res_hidden_states_tuple[-depth]
            # Joined along the last axis (presumably the channel axis).
            merged = jnp.concatenate((hidden_states, skip), axis=-1)

            resnet_out = resnet(merged, temb, deterministic=deterministic)
            hidden_states = attn(resnet_out, encoder_hidden_states, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states
| |
|
| |
|
class FlaxUpBlock2D(nn.Module):
    r"""
    Flax 2D up block: a stack of ResNet blocks, each fed with the running activation concatenated with one skip
    tensor, optionally followed by one upsampling layer.

    Parameters:
        in_channels (:obj:`int`):
            Channel count assumed for the skip tensor consumed by the last ResNet block
        out_channels (:obj:`int`):
            Output channels of every ResNet block and of the upsampler; also the channel count assumed for the
            skip tensors consumed by all but the last ResNet block
        prev_output_channel (:obj:`int`):
            Output channels from the previous block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, forwarded to each ResNet block as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet blocks
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append an upsampling layer after the last ResNet block
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`, forwarded to every sub-module
    """

    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Build the ResNet stack and, if requested, the upsampler."""
        final_idx = self.num_layers - 1

        resnet_stack = []
        for layer_idx in range(self.num_layers):
            # Width of the running activation: the previous block's output for the first layer,
            # this block's own output width afterwards.
            running_width = self.out_channels if layer_idx != 0 else self.prev_output_channel
            # Width of the skip tensor concatenated in front of this layer.
            skip_width = self.out_channels if layer_idx != final_idx else self.in_channels

            resnet_stack.append(
                FlaxResnetBlock2D(
                    in_channels=running_width + skip_width,
                    out_channels=self.out_channels,
                    dropout_prob=self.dropout,
                    dtype=self.dtype,
                )
            )

        self.resnets = resnet_stack

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
        """
        Run the block.

        Args:
            hidden_states: activation coming from the previous block.
            res_hidden_states_tuple: skip tensors; they are consumed from the end, one per layer.
            temb: embedding handed to every ResNet block (presumably the timestep embedding).
            deterministic: forwarded to every ResNet block.

        Returns:
            The activation after the last ResNet block, upsampled when `add_upsample` is set.
        """
        pending_skips = res_hidden_states_tuple

        for resnet in self.resnets:
            # Take the most recent skip tensor and drop it from the pending ones.
            skip, pending_skips = pending_skips[-1], pending_skips[:-1]
            # Joined along the last axis (presumably the channel axis).
            merged = jnp.concatenate((hidden_states, skip), axis=-1)

            hidden_states = resnet(merged, temb, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states
| |
|
| |
|
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    r"""
    Cross-attention 2D mid-level block: one leading ResNet block followed by `num_layers` pairs of
    (spatial transformer, ResNet block), all operating at `in_channels` width. Architecture reference cited by this
    file: https://arxiv.org/abs/2103.06104

    Parameters:
        in_channels (:obj:`int`):
            Input channels; also the output channels of every sub-module
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, forwarded to each ResNet block as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (transformer, ResNet block) pairs that follow the leading ResNet block
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block; the per-head width is
            `in_channels // num_attention_heads`
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to every `FlaxTransformer2DModel`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`, forwarded to every sub-module
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Forwarded to every `FlaxTransformer2DModel` as its `depth`
    """

    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    use_linear_projection: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1

    def setup(self):
        """Build the leading ResNet block and the (transformer, ResNet block) pairs."""
        # There is always at least one resnet, so `resnets` ends up with `num_layers + 1` entries.
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.in_channels,
                n_heads=self.num_attention_heads,
                d_head=self.in_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        """
        Run the block.

        Args:
            hidden_states: input feature map.
            temb: embedding handed to every ResNet block (presumably the timestep embedding).
            encoder_hidden_states: conditioning handed to every transformer.
            deterministic: forwarded to every ResNet block and transformer, including the leading ResNet block.

        Returns:
            The activation after the last ResNet block.
        """
        # The leading ResNet block must receive `deterministic` like every other sub-layer; otherwise it
        # silently runs in its default mode regardless of what the caller requested.
        hidden_states = self.resnets[0](hidden_states, temb, deterministic=deterministic)

        # Remaining ResNet blocks are paired one-to-one with the transformers, attention first.
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        return hidden_states
| |
|