from typing import Tuple

import jax
import jax.numpy as jnp
import flax.linen as nn

from .flax_attention_pseudo3d import TransformerPseudo3DModel
from .flax_resnet_pseudo3d import ResnetBlockPseudo3D, DownsamplePseudo3D, UpsamplePseudo3D

class UNetMidBlockPseudo3DCrossAttn(nn.Module):
    in_channels: int
    num_layers: int = 1
    attn_num_head_channels: int = 1
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # One leading resnet followed by `num_layers` (attention, resnet) pairs,
        # so there is always one more resnet than attention block.
        resnets = [
            ResnetBlockPseudo3D(
                in_channels = self.in_channels,
                out_channels = self.in_channels,
                dtype = self.dtype
            )
        ]
        attentions = []
        for _ in range(self.num_layers):
            attn_block = TransformerPseudo3DModel(
                in_channels = self.in_channels,
                num_attention_heads = self.attn_num_head_channels,
                attention_head_dim = self.in_channels // self.attn_num_head_channels,
                num_layers = 1,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            )
            attentions.append(attn_block)
            res_block = ResnetBlockPseudo3D(
                in_channels = self.in_channels,
                out_channels = self.in_channels,
                dtype = self.dtype
            )
            resnets.append(res_block)
        self.attentions = attentions
        self.resnets = resnets

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array,
            encoder_hidden_states: jax.Array
    ) -> jax.Array:
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)
        return hidden_states
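
# Usage sketch (hedged, not part of the module): a minimal forward pass through
# the mid block. The concrete shapes are assumptions, not taken from this repo;
# only the channels-last layout is implied by the axis=-1 skip concatenation in
# the up blocks, and the time-embedding/context widths depend on
# ResnetBlockPseudo3D and TransformerPseudo3DModel.
#
#   mid = UNetMidBlockPseudo3DCrossAttn(in_channels = 64, attn_num_head_channels = 8)
#   x = jnp.zeros((1, 4, 8, 8, 64))      # (batch, frames, height, width, channels): assumed
#   temb = jnp.zeros((1, 512))           # time-embedding width: assumed
#   ctx = jnp.zeros((1, 77, 768))        # text context (e.g. CLIP features): assumed
#   params = mid.init(jax.random.PRNGKey(0), x, temb, ctx)
#   y = mid.apply(params, x, temb, ctx)  # same shape as x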

class CrossAttnDownBlockPseudo3D(nn.Module):
    in_channels: int
    out_channels: int
    num_layers: int = 1
    attn_num_head_channels: int = 1
    add_downsample: bool = True
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        attentions = []
        resnets = []
        for i in range(self.num_layers):
            # Only the first resnet changes the channel count.
            in_channels = self.in_channels if i == 0 else self.out_channels
            res_block = ResnetBlockPseudo3D(
                in_channels = in_channels,
                out_channels = self.out_channels,
                dtype = self.dtype
            )
            resnets.append(res_block)
            attn_block = TransformerPseudo3DModel(
                in_channels = self.out_channels,
                num_attention_heads = self.attn_num_head_channels,
                attention_head_dim = self.out_channels // self.attn_num_head_channels,
                num_layers = 1,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            )
            attentions.append(attn_block)
        self.resnets = resnets
        self.attentions = attentions
        if self.add_downsample:
            self.downsamplers_0 = DownsamplePseudo3D(
                out_channels = self.out_channels,
                dtype = self.dtype
            )
        else:
            self.downsamplers_0 = None

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array,
            encoder_hidden_states: jax.Array
    ) -> Tuple[jax.Array, Tuple[jax.Array, ...]]:
        # Collect every intermediate state for the skip connections
        # consumed by the matching up block.
        output_states = ()
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states)
            output_states += (hidden_states, )
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states, )
        return hidden_states, output_states
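
# Usage sketch (hedged): the down block returns both the transformed features
# and every intermediate state for the decoder's skip connections. Reusing the
# assumed shapes from the sketch above, `skips` holds one entry per resnet at
# the incoming resolution plus one downsampled entry when add_downsample is True.
#
#   down = CrossAttnDownBlockPseudo3D(in_channels = 64, out_channels = 128, num_layers = 2)
#   params = down.init(jax.random.PRNGKey(0), x, temb, ctx)
#   y, skips = down.apply(params, x, temb, ctx)
#   # len(skips) == 3: two full-resolution states + the downsampled output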

class DownBlockPseudo3D(nn.Module):
    in_channels: int
    out_channels: int
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        resnets = []
        for i in range(self.num_layers):
            # Only the first resnet changes the channel count.
            in_channels = self.in_channels if i == 0 else self.out_channels
            res_block = ResnetBlockPseudo3D(
                in_channels = in_channels,
                out_channels = self.out_channels,
                dtype = self.dtype
            )
            resnets.append(res_block)
        self.resnets = resnets
        if self.add_downsample:
            self.downsamplers_0 = DownsamplePseudo3D(
                out_channels = self.out_channels,
                dtype = self.dtype
            )
        else:
            self.downsamplers_0 = None

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array
    ) -> Tuple[jax.Array, Tuple[jax.Array, ...]]:
        output_states = ()
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states, )
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states, )
        return hidden_states, output_states

class CrossAttnUpBlockPseudo3D(nn.Module):
    in_channels: int
    out_channels: int
    prev_output_channels: int
    num_layers: int = 1
    attn_num_head_channels: int = 1
    add_upsample: bool = True
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        resnets = []
        attentions = []
        for i in range(self.num_layers):
            # Each resnet consumes the current features concatenated with one
            # skip state. Skips are popped from the end of the tuple, so the
            # last resnet sees a skip from the next-shallower down block
            # (in_channels); the others see skips at this resolution
            # (out_channels).
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            resnet_in_channels = self.prev_output_channels if i == 0 else self.out_channels
            res_block = ResnetBlockPseudo3D(
                in_channels = resnet_in_channels + res_skip_channels,
                out_channels = self.out_channels,
                dtype = self.dtype
            )
            resnets.append(res_block)
            attn_block = TransformerPseudo3DModel(
                in_channels = self.out_channels,
                num_attention_heads = self.attn_num_head_channels,
                attention_head_dim = self.out_channels // self.attn_num_head_channels,
                num_layers = 1,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            )
            attentions.append(attn_block)
        self.resnets = resnets
        self.attentions = attentions
        if self.add_upsample:
            self.upsamplers_0 = UpsamplePseudo3D(
                out_channels = self.out_channels,
                dtype = self.dtype
            )
        else:
            self.upsamplers_0 = None

    def __call__(self,
            hidden_states: jax.Array,
            res_hidden_states_tuple: Tuple[jax.Array, ...],
            temb: jax.Array,
            encoder_hidden_states: jax.Array
    ) -> jax.Array:
        for resnet, attn in zip(self.resnets, self.attentions):
            # Pop the most recent skip state and concatenate along the
            # channel axis (channels-last layout).
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis = -1)
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states)
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)
        return hidden_states

class UpBlockPseudo3D(nn.Module):
    in_channels: int
    out_channels: int
    prev_output_channels: int
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        resnets = []
        for i in range(self.num_layers):
            # Same skip-channel bookkeeping as CrossAttnUpBlockPseudo3D.
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            resnet_in_channels = self.prev_output_channels if i == 0 else self.out_channels
            res_block = ResnetBlockPseudo3D(
                in_channels = resnet_in_channels + res_skip_channels,
                out_channels = self.out_channels,
                dtype = self.dtype
            )
            resnets.append(res_block)
        self.resnets = resnets
        if self.add_upsample:
            self.upsamplers_0 = UpsamplePseudo3D(
                out_channels = self.out_channels,
                dtype = self.dtype
            )
        else:
            self.upsamplers_0 = None

    def __call__(self,
            hidden_states: jax.Array,
            res_hidden_states_tuple: Tuple[jax.Array, ...],
            temb: jax.Array
    ) -> jax.Array:
        for resnet in self.resnets:
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate([hidden_states, res_hidden_states], axis = -1)
            hidden_states = resnet(hidden_states, temb)
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)
        return hidden_states
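
# Wiring sketch (hedged): how these blocks compose inside a parent nn.Module,
# mirroring the diffusers-style UNet layout this file follows. Skip states are
# accumulated on the way down and popped from the end of the tuple on the way
# up, one per resnet, which is why the up blocks walk res_hidden_states_tuple
# backwards.
#
#   # inside a parent nn.Module's __call__:
#   h, skips = down_block(x, temb, ctx)   # CrossAttnDownBlockPseudo3D
#   h = mid_block(h, temb, ctx)           # UNetMidBlockPseudo3DCrossAttn
#   h = up_block(h, skips, temb, ctx)     # CrossAttnUpBlockPseudo3D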