| """
|
| Byte Dream Model Architecture
|
| Complete implementation of UNet, VAE, and Text Encoder for diffusion-based image generation
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Optional, Tuple, Union, List
|
| import math
|
|
|
|
|
class ResnetBlock2D(nn.Module):
    """GroupNorm -> SiLU -> Conv residual block with optional timestep conditioning."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: Optional[int] = None,
        groups: int = 32,
        eps: float = 1e-6,
    ):
        """
        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels of the produced feature map.
            temb_channels: Size of the timestep embedding; when given, a linear
                projection adds it as a per-channel bias after the first conv.
            groups: Group count for both GroupNorm layers.
            eps: GroupNorm epsilon.
        """
        super().__init__()

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Projection is only created when timestep conditioning is requested.
        if temb_channels is not None:
            self.time_emb_proj = nn.Linear(temb_channels, out_channels)

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(0.0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = nn.SiLU()

        # 1x1 projection on the skip path when channel counts differ.
        self.conv_shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            if in_channels != out_channels
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the residual block; ``temb`` is the (B, temb_channels) timestep embedding."""
        skip = hidden_states

        out = self.conv1(self.nonlinearity(self.norm1(hidden_states)))

        if temb is not None:
            # Broadcast the projected timestep embedding over H and W.
            out = out + self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

        out = self.conv2(self.dropout(self.nonlinearity(self.norm2(out))))

        if self.conv_shortcut is not None:
            skip = self.conv_shortcut(skip)

        return out + skip
|
|
|
|
|
class AttentionBlock(nn.Module):
    """Multi-head attention block with a residual connection.

    When ``encoder_hidden_states`` is given the block performs cross-attention
    (queries from the features, keys/values from the conditioning sequence);
    otherwise it falls back to self-attention. Accepts either a 4D feature map
    (B, C, H, W) or a 3D sequence (B, N, C) and returns the same layout.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        num_heads: int = 8,
        head_dim: Optional[int] = None,
        eps: float = 1e-6,
    ):
        """
        Args:
            query_dim: Channel dimension of the query features.
            cross_attention_dim: Dimension of the conditioning sequence;
                defaults to ``query_dim`` (pure self-attention).
            num_heads: Number of attention heads.
            head_dim: Per-head dimension; defaults to ``query_dim // num_heads``.
            eps: LayerNorm epsilon.
        """
        super().__init__()

        self.head_dim = head_dim if head_dim is not None else query_dim // num_heads
        inner_dim = self.head_dim * num_heads
        self.inner_dim = inner_dim

        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        self.num_heads = num_heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=False)

        # Output projection back to query_dim, followed by (currently disabled) dropout.
        self.to_out = nn.ModuleList([
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(0.0)
        ])

        self.norm = nn.LayerNorm(query_dim, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states

        # Flatten 4D feature maps (B, C, H, W) into sequences (B, H*W, C).
        if hidden_states.ndim == 4:
            batch_size, channels, height, width = hidden_states.shape
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, -1, channels)
            is_4d = True
        else:
            batch_size, sequence_length, _ = hidden_states.shape
            is_4d = False

        # BUGFIX: self.norm was constructed but never applied. Use standard
        # pre-norm: normalize the attention input while keeping the residual
        # path un-normalized.
        hidden_states = self.norm(hidden_states)

        query = self.to_q(hidden_states)
        # Self-attention when no conditioning sequence is supplied.
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        # (B, N, inner) -> (B, heads, N, head_dim)
        query = query.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention.
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights, value)
        # (B, heads, N, head_dim) -> (B, N, inner_dim); explicit inner_dim
        # instead of the opaque query.shape[-1] * num_heads.
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)

        for layer in self.to_out:
            attn_output = layer(attn_output)

        # Restore the spatial layout for 4D inputs.
        if is_4d:
            attn_output = attn_output.reshape(batch_size, height, width, channels)
            attn_output = attn_output.permute(0, 3, 1, 2)

        return residual + attn_output
|
|
|
|
|
class DownBlock2D(nn.Module):
    """Encoder stage: resnet (+ optional cross-attention) layers, then an
    optional stride-2 conv that halves the spatial resolution."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_layers: int = 1,
        add_downsample: bool = True,
        has_cross_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
    ):
        """
        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by every layer of this block.
            temb_channels: Size of the timestep embedding fed to each resnet.
            num_layers: Number of resnet (+ attention) layers.
            add_downsample: Whether to append the stride-2 downsampling conv.
            has_cross_attention: Interleave an AttentionBlock after each resnet.
            cross_attention_dim: Dim of the conditioning (text) sequence.
        """
        super().__init__()

        resnets = []
        attentions = []

        for i in range(num_layers):
            # Only the first layer sees in_channels; later ones are out->out.
            in_ch = in_channels if i == 0 else out_channels

            resnets.append(ResnetBlock2D(
                in_channels=in_ch,
                out_channels=out_channels,
                temb_channels=temb_channels,
            ))

            if has_cross_attention:
                attentions.append(AttentionBlock(
                    query_dim=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    num_heads=8,
                    head_dim=out_channels // 8,
                ))
            else:
                # None placeholder keeps resnets/attentions index-aligned.
                attentions.append(None)

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

        if add_downsample:
            self.downsamplers = nn.ModuleList([
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
            ])
        else:
            self.downsamplers = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Run the stage.

        Returns:
            (output, residuals): the downsampled output plus one residual per
            layer, captured BEFORE downsampling so the decoder can concatenate
            them at the matching resolution.
        """
        output_states = ()

        for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
            hidden_states = resnet(hidden_states, temb)

            # Attention only fires when both the module and conditioning exist.
            if attn is not None and encoder_hidden_states is not None:
                hidden_states = attn(hidden_states, encoder_hidden_states)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, output_states
|
|
|
|
|
class UpBlock2D(nn.Module):
    """Decoder stage: concatenates encoder skip tensors, runs resnet
    (+ optional cross-attention) layers, then optionally upsamples with a
    stride-2 transposed conv."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        num_layers: int = 1,
        add_upsample: bool = True,
        has_cross_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
    ):
        """
        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by every layer of this block.
            prev_output_channel: Expected channel width of the skip tensors;
                added to each resnet's input width for the concatenation.
            temb_channels: Size of the timestep embedding fed to each resnet.
            num_layers: Number of resnet (+ attention) layers.
            add_upsample: Whether to append the stride-2 transposed conv.
            has_cross_attention: Interleave an AttentionBlock after each resnet.
            cross_attention_dim: Dim of the conditioning (text) sequence.
        """
        super().__init__()

        resnets = []
        attentions = []

        for i in range(num_layers):

            # Only the first layer sees in_channels; later ones are out->out.
            in_ch = in_channels if i == 0 else out_channels

            # Each resnet consumes the running features concatenated with a skip tensor.
            resnets.append(ResnetBlock2D(
                in_channels=in_ch + prev_output_channel,
                out_channels=out_channels,
                temb_channels=temb_channels,
            ))

            if has_cross_attention:
                attentions.append(AttentionBlock(
                    query_dim=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    num_heads=8,
                    head_dim=out_channels // 8,
                ))
            else:
                # None placeholder keeps resnets/attentions index-aligned.
                attentions.append(None)

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

        if add_upsample:
            self.upsamplers = nn.ModuleList([
                nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
            ])
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the stage, consuming one skip tensor per layer (paired by index)."""
        for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):

            if i < len(res_hidden_states_tuple):
                res_hidden_state = res_hidden_states_tuple[i]

                # Resize the skip tensor when its spatial dims do not match ours.
                if hidden_states.shape[2:] != res_hidden_state.shape[2:]:
                    res_hidden_state = F.interpolate(
                        res_hidden_state,
                        size=hidden_states.shape[2:],
                        mode='bilinear',
                        align_corners=False
                    )

                # Compare the post-concat width against what the resnet expects.
                expected_in_channels = self.resnets[i].conv1.in_channels
                actual_in_channels = hidden_states.shape[1] + res_hidden_state.shape[1]

                if actual_in_channels != expected_in_channels:

                    channel_diff = expected_in_channels - hidden_states.shape[1]
                    if channel_diff > 0 and channel_diff != res_hidden_state.shape[1]:

                        # HACK: patches the channel mismatch by projecting the skip
                        # tensor through a 1x1 conv whose weights are drawn FRESH
                        # from torch.randn on EVERY forward pass — the projection is
                        # unlearned and makes outputs non-deterministic across calls,
                        # even in eval mode. NOTE(review): the block widths should be
                        # fixed upstream so skip channels always match; this branch
                        # should then be removed.
                        res_hidden_state = nn.functional.conv2d(
                            res_hidden_state,
                            torch.randn(channel_diff, res_hidden_state.shape[1], 1, 1, device=res_hidden_state.device) * 0.01,
                            padding=0
                        )

                hidden_states = torch.cat([hidden_states, res_hidden_state], dim=1)

            hidden_states = resnet(hidden_states, temb)

            # Attention only fires when both the module and conditioning exist.
            if attn is not None and encoder_hidden_states is not None:
                hidden_states = attn(hidden_states, encoder_hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
|
|
|
|
|
class TimestepEmbedding(nn.Module):
    """
    Sinusoidal timestep embedding
    Converts scalar timesteps to high-dimensional embeddings
    """

    def __init__(self, in_features: int, time_embed_dim: int):
        """
        Args:
            in_features: Width of the raw sinusoidal feature vector
                (sin and cos halves combined).
            time_embed_dim: Output width of the embedding MLP.
        """
        super().__init__()
        self.in_features = in_features
        self.time_embed_dim = time_embed_dim

        # Geometric ladder of frequencies spanning 1 .. 1/10000, kept as a
        # (non-trainable) buffer so it moves with the module's device/dtype.
        half_dim = in_features // 2
        step = math.log(10000) / (half_dim - 1)
        self.register_buffer('emb', torch.exp(-step * torch.arange(half_dim)))

        # Two-layer MLP lifts the sinusoidal features to time_embed_dim.
        self.linear_1 = nn.Linear(in_features, time_embed_dim)
        self.activation = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        """Embed scalar or 1D timesteps into a (batch, time_embed_dim) tensor."""
        # Normalize to a (batch, 1) column so it broadcasts against the
        # (half_dim,) frequency buffer.
        if timestep.ndim == 0:
            timestep = timestep.view(1, 1)
        elif timestep.ndim == 1:
            timestep = timestep.view(-1, 1)

        angles = timestep * self.emb
        features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        return self.linear_2(self.activation(self.linear_1(features)))
|
|
|
|
|
class UNet2DConditionModel(nn.Module):
    """
    Main UNet architecture for diffusion-based image generation
    Handles noise prediction conditioned on text embeddings
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 4,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        attention_head_dim: int = 8,
        cross_attention_dim: int = 768,
        use_linear_projection: bool = True,
        use_gradient_checkpointing: bool = False,
    ):
        """
        Args:
            in_channels: Channels of the (latent) input sample.
            out_channels: Channels of the predicted output.
            block_out_channels: Channel width at each resolution level.
            layers_per_block: Resnet/attention layers per down/up block.
            attention_head_dim: Head count for the mid-block attention.
            cross_attention_dim: Dim of the text-encoder hidden states.
            use_linear_projection: NOTE(review): accepted for config
                compatibility but not referenced anywhere in this class.
            use_gradient_checkpointing: Recompute activations in backward to
                trade compute for memory during training.
        """
        super().__init__()

        self.in_channels = in_channels
        self.block_out_channels = block_out_channels
        self.layers_per_block = layers_per_block
        self.cross_attention_dim = cross_attention_dim
        self.use_gradient_checkpointing = use_gradient_checkpointing

        # Timestep embedding: sinusoidal features lifted to 4x base width.
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = TimestepEmbedding(block_out_channels[0], time_embed_dim)

        # Stem conv lifting the input into the base channel width.
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # ---- Encoder (downsampling path) ----
        self.down_blocks = nn.ModuleList([])
        output_channel = block_out_channels[0]

        for i, down_block_type in enumerate(["down", "down", "down", "down"]):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # The deepest block keeps its resolution (no further downsample).
            is_final_block = i == len(block_out_channels) - 1

            down_block = DownBlock2D(
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                num_layers=layers_per_block,
                add_downsample=not is_final_block,
                has_cross_attention=True,
                cross_attention_dim=cross_attention_dim,
            )

            self.down_blocks.append(down_block)

        # ---- Bottleneck: resnet -> cross-attention -> resnet ----
        self.mid_block = nn.ModuleList([
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
            ),
            AttentionBlock(
                query_dim=block_out_channels[-1],
                cross_attention_dim=cross_attention_dim,
                num_heads=attention_head_dim,
                head_dim=block_out_channels[-1] // attention_head_dim,
            ),
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
            ),
        ])

        # ---- Decoder (upsampling path), mirroring the encoder widths ----
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))

        for i, up_block_type in enumerate(["up", "up", "up", "up"]):

            in_channels = block_out_channels[-1] if i == 0 else reversed_block_out_channels[i - 1]
            output_channel = reversed_block_out_channels[i]

            # NOTE(review): skip width is assumed equal to this level's output
            # width; UpBlock2D contains a runtime patch for mismatches — verify
            # the pairing against DownBlock2D's actual residual widths.
            skip_channels = output_channel
            is_final_block = i == len(block_out_channels) - 1

            up_block = UpBlock2D(
                in_channels=in_channels,
                out_channels=output_channel,
                prev_output_channel=skip_channels,
                temb_channels=time_embed_dim,
                num_layers=layers_per_block,
                add_upsample=not is_final_block,
                has_cross_attention=True,
                cross_attention_dim=cross_attention_dim,
            )

            self.up_blocks.append(up_block)

        # Output head: norm -> SiLU -> conv back to out_channels.
        self.conv_norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[0], eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Predict noise for `sample` at `timestep`, conditioned on text embeddings.

        Args:
            sample: Latent input, (B, in_channels, H, W).
            timestep: Scalar or (B,) diffusion timestep(s).
            encoder_hidden_states: Text conditioning, (B, seq, cross_attention_dim).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        # Embed the timestep once; shared by every resnet block.
        timesteps_proj = self.time_proj(timestep.float())
        temb = timesteps_proj

        hidden_states = self.conv_in(sample)

        # Collect encoder residuals (stem output first) for the decoder skips.
        down_block_res_samples = (hidden_states,)

        for downsample_block in self.down_blocks:
            # Checkpointing recomputes activations in backward to save memory.
            if self.use_gradient_checkpointing and self.training:
                hidden_states, res_samples = torch.utils.checkpoint.checkpoint(
                    lambda hs, t, ehs: downsample_block(hs, t, ehs),
                    hidden_states, temb, encoder_hidden_states,
                    use_reentrant=False
                )
            else:
                hidden_states, res_samples = downsample_block(
                    hidden_states=hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            down_block_res_samples += res_samples

        # Bottleneck: resnets take temb, the attention layer takes text states.
        for layer in self.mid_block:
            if self.use_gradient_checkpointing and self.training:
                if isinstance(layer, ResnetBlock2D):
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        lambda hs, t: layer(hs, t),
                        hidden_states, temb,
                        use_reentrant=False
                    )
                else:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        lambda hs, ehs: layer(hs, ehs),
                        hidden_states, encoder_hidden_states,
                        use_reentrant=False
                    )
            else:
                if isinstance(layer, ResnetBlock2D):
                    hidden_states = layer(hidden_states, temb)
                else:
                    hidden_states = layer(hidden_states, encoder_hidden_states)

        # Consume the residual stack from deepest to shallowest.
        # NOTE(review): each block's slice is passed in encoder order (not
        # reversed) and UpBlock2D pairs skips by index — confirm the intended
        # ordering.
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)]

            if self.use_gradient_checkpointing and self.training:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    lambda hs, res, t, ehs: upsample_block(hs, res, t, ehs),
                    hidden_states, res_samples, temb, encoder_hidden_states,
                    use_reentrant=False
                )
            else:
                hidden_states = upsample_block(
                    hidden_states=hidden_states,
                    res_hidden_states_tuple=res_samples,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                )

        # Output head.
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
|
|
|
|
|
class AutoencoderKL(nn.Module):
    """
    Variational Autoencoder for image compression and reconstruction
    Compresses images to latent space for efficient diffusion
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",) * 4,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",) * 4,
        latent_channels: int = 4,
        sample_size: int = 512,
        block_out_channels: Tuple[int, ...] = (64, 128, 256, 512),
    ):
        """
        Args:
            in_channels: Channels of the input image.
            out_channels: Channels of the reconstructed image.
            down_block_types / up_block_types: Only their LENGTH is used here
                (number of encoder/decoder stages).
            latent_channels: Width of the latent space.
            sample_size: Nominal input resolution (stored, not enforced).
            block_out_channels: Channel width per encoder stage.
        """
        super().__init__()

        self.sample_size = sample_size
        self.block_out_channels = block_out_channels
        self.latent_channels = latent_channels

        # ---- Encoder: each stage halves the spatial resolution ----
        self.encoder = nn.ModuleList()
        channels = [in_channels] + list(block_out_channels)

        for i in range(len(down_block_types)):
            block = nn.Sequential(
                nn.Conv2d(channels[i], channels[i+1], kernel_size=3, stride=2, padding=1),
                nn.GroupNorm(num_groups=min(32, channels[i+1]), num_channels=channels[i+1], eps=1e-6),
                nn.SiLU(),
            )
            self.encoder.append(block)

        # Projects encoder features to (mean, logvar) of the latent distribution.
        self.quant_conv = nn.Conv2d(block_out_channels[-1], latent_channels * 2, kernel_size=1)

        # ---- Decoder: mirrors the encoder, doubling resolution per stage ----
        self.decoder = nn.ModuleList()
        # BUGFIX: the first decoder stage consumes the output of post_quant_conv
        # (block_out_channels[-1] channels), not the raw latent. The previous
        # chain started at latent_channels and crashed on every decode() call.
        decoder_channels = [block_out_channels[-1]] + list(reversed(block_out_channels))

        for i in range(len(up_block_types)):
            block = nn.Sequential(
                nn.ConvTranspose2d(decoder_channels[i], decoder_channels[i+1], kernel_size=4, stride=2, padding=1),
                nn.GroupNorm(num_groups=min(32, decoder_channels[i+1]), num_channels=decoder_channels[i+1], eps=1e-6),
                nn.SiLU(),
            )
            self.decoder.append(block)

        # Lifts the latent back to the decoder's input width.
        self.post_quant_conv = nn.Conv2d(latent_channels, block_out_channels[-1], kernel_size=1)
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, stride=1, padding=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode image to latent space.

        Returns a (B, 2 * latent_channels, h, w) tensor holding the
        concatenated mean and logvar of the latent distribution.
        """
        for block in self.encoder:
            x = block(x)
        x = self.quant_conv(x)
        return x

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode from latent space ((B, latent_channels, h, w)) to image."""
        z = self.post_quant_conv(z)
        for block in self.decoder:
            z = block(z)
        z = self.conv_out(z)
        return z

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full autoencoder pass (deterministic: decodes the posterior mean)."""
        encoded = self.encode(x)
        # BUGFIX: slice by latent_channels instead of a hard-coded 4, so
        # non-default latent widths work. The first half is the mean.
        mean = encoded[:, : self.latent_channels]
        return self.decode(mean)
|
|
|
|
|
class CLIPTextModel(nn.Module):
    """
    CLIP text encoder for understanding text prompts
    Extracts semantic features from text for conditioning
    """

    # Hidden size of openai/clip-vit-base-patch32; used for the dummy fallback.
    _DUMMY_HIDDEN_SIZE = 512

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", max_length: int = 77):
        super().__init__()

        # BUGFIX: always set max_length — previously it was only assigned on
        # the success path, leaving the dummy fallback without the attribute.
        self.max_length = max_length

        try:
            from transformers import CLIPTextModel as HFCLIPTextModel, CLIPTokenizer

            print(f"Loading CLIP text encoder: {model_name}...")
            self.model = HFCLIPTextModel.from_pretrained(model_name)
            self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
            print(f"✓ CLIP text encoder loaded successfully on CPU")
        except ImportError:
            # Best-effort fallback when transformers is unavailable: forward()
            # will return zero embeddings of the correct shape.
            print("Warning: transformers not installed. Using dummy text encoder.")
            self.model = None
            self.tokenizer = None

    def forward(self, text: Union[str, List[str]], device: torch.device = None) -> torch.Tensor:
        """
        Encode text to embeddings

        Args:
            text: Text string or list of strings
            device: Target device for computation

        Returns:
            Text embeddings tensor of shape (batch, max_length, hidden)
        """
        if self.model is None:
            # BUGFIX: previously returned a hard-coded (1, 77, 512) regardless
            # of batch size, max_length, or device; mirror the real encoder's
            # output shape instead.
            batch = 1 if isinstance(text, str) else len(text)
            out = torch.zeros(batch, self.max_length, self._DUMMY_HIDDEN_SIZE)
            return out.to(device) if device is not None else out

        inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        if device is not None:
            inputs = {k: v.to(device) for k, v in inputs.items()}

        outputs = self.model(**inputs)
        return outputs.last_hidden_state
|
|
|
|
|
def create_unet(config):
    """Build a UNet2DConditionModel from the nested config dict (config['model']['unet'])."""
    cfg = config['model']['unet']
    kwargs = {
        'in_channels': cfg['in_channels'],
        'out_channels': cfg['out_channels'],
        'block_out_channels': tuple(cfg['block_out_channels']),
        'layers_per_block': cfg['layers_per_block'],
        'attention_head_dim': cfg['attention_head_dim'],
        'cross_attention_dim': cfg['cross_attention_dim'],
        'use_linear_projection': cfg['use_linear_projection'],
        # Checkpointing is always enabled here to reduce training memory.
        'use_gradient_checkpointing': True,
    }
    return UNet2DConditionModel(**kwargs)
|
|
|
|
|
def create_vae(config):
    """Build an AutoencoderKL from the nested config dict (config['model']['vae'])."""
    cfg = config['model']['vae']
    kwargs = {
        'in_channels': cfg['in_channels'],
        'out_channels': cfg['out_channels'],
        'down_block_types': tuple(cfg['down_block_types']),
        'up_block_types': tuple(cfg['up_block_types']),
        'latent_channels': cfg['latent_channels'],
        'sample_size': cfg['sample_size'],
        # Fall back to the reference channel schedule when unspecified.
        'block_out_channels': tuple(cfg.get('block_out_channels', [64, 128, 256, 512])),
    }
    return AutoencoderKL(**kwargs)
|
|
|
|
|
def create_text_encoder(config):
    """Build a CLIPTextModel wrapper from the nested config dict (config['model']['text_encoder'])."""
    cfg = config['model']['text_encoder']
    return CLIPTextModel(
        model_name=cfg['model'],
        max_length=cfg['max_length'],
    )
|
|
|