"""
Byte Dream Model Architecture

Complete implementation of UNet, VAE, and Text Encoder
for diffusion-based image generation
"""

import math
import warnings
from typing import Optional, Tuple, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint


class ResnetBlock2D(nn.Module):
    """Residual block for 2D convolutions.

    Applies GroupNorm -> SiLU -> 3x3 conv twice, optionally adds a projected
    timestep embedding between the two convolutions, and sums the result with
    the (optionally 1x1-projected) input.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
        temb_channels: Size of the timestep embedding, or ``None`` to build
            the block without a timestep projection.
        groups: Number of groups for both GroupNorm layers; must divide
            ``in_channels`` and ``out_channels``.
        eps: Epsilon for both GroupNorm layers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: Optional[int] = None,
        groups: int = 32,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
        else:
            # Always define the attribute so forward() can test for it instead
            # of failing with AttributeError when a temb is passed anyway.
            self.time_emb_proj = None

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(0.0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.nonlinearity = nn.SiLU()

        # A 1x1 projection is only needed when the residual changes width.
        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the block.

        Args:
            hidden_states: Input feature map, indexed as (batch, channels, H, W).
            temb: Optional timestep embedding of shape (batch, temb_channels).
                Ignored when the block was built without ``temb_channels``.

        Returns:
            Feature map with ``out_channels`` channels and unchanged spatial size.
        """
        x = self.norm1(hidden_states)
        x = self.nonlinearity(x)
        x = self.conv1(x)

        if temb is not None and self.time_emb_proj is not None:
            # Broadcast the per-sample embedding over the spatial dimensions.
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            x = x + temb

        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        x = self.conv2(x)

        if self.conv_shortcut is not None:
            hidden_states = self.conv_shortcut(hidden_states)

        return x + hidden_states


class AttentionBlock(nn.Module):
    """Cross-attention block for text-conditioned generation.

    Queries come from ``hidden_states``; keys and values come from
    ``encoder_hidden_states`` (or from ``hidden_states`` itself when no
    encoder states are given, i.e. self-attention).

    Args:
        query_dim: Channel width of the query input and of the output.
        cross_attention_dim: Channel width of the key/value input; defaults
            to ``query_dim``.
        num_heads: Number of attention heads.
        head_dim: Width of each head; defaults to ``query_dim // num_heads``.
        eps: Epsilon of the LayerNorm registered on this block.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        num_heads: int = 8,
        head_dim: Optional[int] = None,
        eps: float = 1e-6,
    ):
        super().__init__()

        # Use head_dim if provided, otherwise calculate from query_dim and num_heads
        self.head_dim = head_dim if head_dim is not None else query_dim // num_heads
        inner_dim = self.head_dim * num_heads

        # Use cross_attention_dim if provided, otherwise use query_dim (self-attention)
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.num_heads = num_heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=False)

        self.to_out = nn.ModuleList([
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(0.0),
        ])

        # NOTE(review): this LayerNorm is registered (and therefore part of the
        # state dict) but forward() never applies it. It is kept as-is so that
        # existing checkpoints keep loading and numerics do not change.
        self.norm = nn.LayerNorm(query_dim, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply multi-head attention with a residual connection.

        Args:
            hidden_states: Either a 4D feature map (batch, channels, H, W) or
                a 3D sequence (batch, seq_len, channels).
            encoder_hidden_states: Optional key/value source of shape
                (batch, seq_len, cross_attention_dim).

        Returns:
            Tensor of the same shape as ``hidden_states``.
        """
        residual = hidden_states

        # Handle 4D inputs (batch, channels, height, width)
        is_4d = hidden_states.ndim == 4
        if is_4d:
            batch_size, channels, height, width = hidden_states.shape
            # Reshape to (batch, seq_len, channels) where seq_len = height * width
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, -1, channels)
        else:
            batch_size, _, _ = hidden_states.shape

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = self.to_q(hidden_states)
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        # Split into heads: (batch, heads, seq_len, head_dim)
        query = query.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, value)

        # Merge heads back: (batch, seq_len, heads * head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(
            batch_size, -1, self.head_dim * self.num_heads
        )

        # Output projection
        for layer in self.to_out:
            attn_output = layer(attn_output)

        # Reshape back to 4D if input was 4D
        if is_4d:
            attn_output = attn_output.reshape(batch_size, height, width, channels)
            attn_output = attn_output.permute(0, 3, 1, 2)

        return residual + attn_output


class DownBlock2D(nn.Module):
    """Downsampling block.

    A stack of ``num_layers`` ResNet blocks, each optionally followed by a
    cross-attention block (fixed at 8 heads), and an optional stride-2
    convolution at the end.

    Args:
        in_channels: Channels entering the first ResNet block.
        out_channels: Channels produced by every ResNet block.
        temb_channels: Size of the timestep embedding.
        num_layers: Number of ResNet (+ attention) layers.
        add_downsample: Whether to append the stride-2 convolution.
        has_cross_attention: Whether each layer gets an attention block.
        cross_attention_dim: Width of the conditioning sequence.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_layers: int = 1,
        add_downsample: bool = True,
        has_cross_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        resnets = []
        attentions = []

        for i in range(num_layers):
            in_ch = in_channels if i == 0 else out_channels
            resnets.append(ResnetBlock2D(
                in_channels=in_ch,
                out_channels=out_channels,
                temb_channels=temb_channels,
            ))

            if has_cross_attention:
                attentions.append(AttentionBlock(
                    query_dim=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    num_heads=8,
                    head_dim=out_channels // 8,
                ))
            else:
                # Placeholder keeps resnets and attentions aligned for zip().
                attentions.append(None)

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

        if add_downsample:
            self.downsamplers = nn.ModuleList([
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
            ])
        else:
            self.downsamplers = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Run the block.

        Returns:
            A pair ``(hidden_states, output_states)`` where ``hidden_states``
            is the (possibly downsampled) output and ``output_states`` holds
            the activation after every layer, taken before downsampling, for
            use as skip connections.
        """
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)

            if attn is not None and encoder_hidden_states is not None:
                hidden_states = attn(hidden_states, encoder_hidden_states)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, output_states


class UpBlock2D(nn.Module):
    """Upsampling block.

    A stack of ``num_layers`` ResNet blocks, each consuming the running
    activation concatenated with one skip connection, optionally followed by
    cross-attention (fixed at 8 heads), and an optional 2x transposed
    convolution at the end.

    Args:
        in_channels: Channels of the activation entering the first layer.
        out_channels: Channels produced by every ResNet block.
        prev_output_channel: Channels of each skip connection.
        temb_channels: Size of the timestep embedding.
        num_layers: Number of ResNet (+ attention) layers.
        add_upsample: Whether to append the transposed convolution.
        has_cross_attention: Whether each layer gets an attention block.
        cross_attention_dim: Width of the conditioning sequence.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        num_layers: int = 1,
        add_upsample: bool = True,
        has_cross_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        resnets = []
        attentions = []

        for i in range(num_layers):
            # All layers receive skip connections
            in_ch = in_channels if i == 0 else out_channels
            resnets.append(ResnetBlock2D(
                in_channels=in_ch + prev_output_channel,
                out_channels=out_channels,
                temb_channels=temb_channels,
            ))

            if has_cross_attention:
                attentions.append(AttentionBlock(
                    query_dim=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    num_heads=8,
                    head_dim=out_channels // 8,
                ))
            else:
                # Placeholder keeps resnets and attentions aligned for zip().
                attentions.append(None)

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

        if add_upsample:
            self.upsamplers = nn.ModuleList([
                nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
            ])
        else:
            self.upsamplers = None

    @staticmethod
    def _align_skip_channels(
        hidden_states: torch.Tensor,
        res_hidden_state: torch.Tensor,
        expected_in_channels: int,
    ) -> torch.Tensor:
        """Return a skip tensor whose channels fit the next ResNet block.

        When the widths already add up the skip is returned untouched. When
        they do not, the legacy best-effort fallback is kept: the skip is
        passed through a freshly sampled 1x1 convolution. That kernel is NOT
        a parameter of the model, so a warning is emitted to make the
        misconfiguration visible.
        """
        actual_in_channels = hidden_states.shape[1] + res_hidden_state.shape[1]
        if actual_in_channels == expected_in_channels:
            return res_hidden_state

        channel_diff = expected_in_channels - hidden_states.shape[1]
        if channel_diff <= 0:
            # No skip width can make the concatenation fit; fail with a clear
            # message instead of an opaque error inside the ResNet block.
            raise RuntimeError(
                f"UpBlock2D received {hidden_states.shape[1]} input channels but the "
                f"ResNet block only accepts {expected_in_channels} channels including "
                f"the skip connection"
            )

        warnings.warn(
            f"UpBlock2D skip connection has {res_hidden_state.shape[1]} channels but "
            f"{channel_diff} were expected; projecting it with an untrained random "
            f"1x1 convolution. Check prev_output_channel.",
            RuntimeWarning,
            stacklevel=3,
        )
        # Match the input dtype so the fallback also works in half precision.
        weight = torch.randn(
            channel_diff,
            res_hidden_state.shape[1],
            1,
            1,
            device=res_hidden_state.device,
            dtype=res_hidden_state.dtype,
        ) * 0.01
        return F.conv2d(res_hidden_state, weight, padding=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the block.

        Args:
            hidden_states: Activation coming from the previous stage.
            res_hidden_states_tuple: Skip connections; entry ``i`` is
                concatenated to the input of layer ``i``.
            temb: Optional timestep embedding.
            encoder_hidden_states: Optional conditioning sequence.

        Returns:
            The (possibly upsampled) output activation.
        """
        for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
            # Skip connection from U-Net downsampling path
            if i < len(res_hidden_states_tuple):
                res_hidden_state = res_hidden_states_tuple[i]

                # Ensure spatial dimensions match (odd input sizes round up on
                # the way down but are exactly doubled on the way up).
                if hidden_states.shape[2:] != res_hidden_state.shape[2:]:
                    res_hidden_state = F.interpolate(
                        res_hidden_state,
                        size=hidden_states.shape[2:],
                        mode='bilinear',
                        align_corners=False,
                    )

                # Ensure channel dimensions match what the resnet expects for
                # the concatenation of hidden_states and the skip connection.
                res_hidden_state = self._align_skip_channels(
                    hidden_states, res_hidden_state, resnet.conv1.in_channels
                )

                hidden_states = torch.cat([hidden_states, res_hidden_state], dim=1)

            hidden_states = resnet(hidden_states, temb)

            if attn is not None and encoder_hidden_states is not None:
                hidden_states = attn(hidden_states, encoder_hidden_states)

        # Upsample AFTER all resnet layers
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states


class TimestepEmbedding(nn.Module):
    """
    Sinusoidal timestep embedding
    Converts scalar timesteps to high-dimensional embeddings

    Args:
        in_features: Width of the sinusoidal embedding; must be even and >= 2
            because half of it holds sines and half cosines.
        time_embed_dim: Width of the embedding returned by the MLP.

    Raises:
        ValueError: If ``in_features`` is odd or smaller than 2.
    """

    def __init__(self, in_features: int, time_embed_dim: int):
        super().__init__()
        if in_features < 2 or in_features % 2 != 0:
            raise ValueError(
                f"in_features must be an even integer >= 2, got {in_features}"
            )

        self.in_features = in_features
        self.time_embed_dim = time_embed_dim

        # Create sinusoidal embedding frequencies
        half_dim = in_features // 2
        # max(..., 1) avoids a division by zero when there is a single frequency.
        scale = math.log(10000) / max(half_dim - 1, 1)
        self.register_buffer('emb', torch.exp(-scale * torch.arange(half_dim)))

        # Projection layers
        self.linear_1 = nn.Linear(in_features, time_embed_dim)
        self.activation = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        """Embed timesteps.

        Args:
            timestep: 0-D tensor, 1-D tensor of shape (batch,), or a tensor
                already shaped (batch, 1).

        Returns:
            Tensor of shape (batch, time_embed_dim).
        """
        # Ensure timestep has correct shape [batch_size, 1]
        if timestep.ndim == 0:
            timestep = timestep.view(1, 1)
        elif timestep.ndim == 1:
            timestep = timestep.view(-1, 1)

        # Apply sinusoidal embedding
        emb = timestep * self.emb
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

        # Match the MLP dtype so the module also works after .half()/.bfloat16().
        emb = emb.to(self.linear_1.weight.dtype)

        # Project through MLP
        emb = self.linear_1(emb)
        emb = self.activation(emb)
        emb = self.linear_2(emb)

        return emb


class UNet2DConditionModel(nn.Module):
    """
    Main UNet architecture for diffusion-based image generation
    Handles noise prediction conditioned on text embeddings

    Args:
        in_channels: Channels of the input sample.
        out_channels: Channels of the predicted output.
        block_out_channels: Output width of each resolution level; one down
            block and one up block are created per entry. Every entry must be
            divisible by 32 (GroupNorm groups).
        layers_per_block: ResNet layers in every down and up block.
        attention_head_dim: Used as the number of heads of the mid-block
            attention. Down and up blocks always use 8 heads.
        cross_attention_dim: Width of the conditioning sequence.
        use_linear_projection: Accepted for configuration compatibility;
            not used by this implementation.
        use_gradient_checkpointing: Recompute activations during backward
            (training mode only) to save memory.

    Raises:
        ValueError: If ``block_out_channels`` is empty.
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 4,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        attention_head_dim: int = 8,
        cross_attention_dim: int = 768,
        use_linear_projection: bool = True,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()
        if len(block_out_channels) == 0:
            raise ValueError("block_out_channels must contain at least one entry")

        self.in_channels = in_channels
        self.block_out_channels = block_out_channels
        self.layers_per_block = layers_per_block
        self.cross_attention_dim = cross_attention_dim
        self.use_gradient_checkpointing = use_gradient_checkpointing

        num_levels = len(block_out_channels)

        # Time embedding
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = TimestepEmbedding(block_out_channels[0], time_embed_dim)

        # Input convolution
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # Down blocks: one per resolution level, the last one does not downsample.
        self.down_blocks = nn.ModuleList([])
        output_channel = block_out_channels[0]
        for i in range(num_levels):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == num_levels - 1

            self.down_blocks.append(DownBlock2D(
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                num_layers=layers_per_block,
                add_downsample=not is_final_block,
                has_cross_attention=True,
                cross_attention_dim=cross_attention_dim,
            ))

        # Middle blocks
        self.mid_block = nn.ModuleList([
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
            ),
            AttentionBlock(
                query_dim=block_out_channels[-1],
                cross_attention_dim=cross_attention_dim,
                num_heads=attention_head_dim,
                head_dim=block_out_channels[-1] // attention_head_dim,
            ),
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
            ),
        ])

        # Up blocks: mirror of the down path, the last one does not upsample.
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        for i in range(num_levels):
            # Input channels: from previous up block (or mid block for first up block)
            up_input_channel = (
                block_out_channels[-1] if i == 0 else reversed_block_out_channels[i - 1]
            )
            output_channel = reversed_block_out_channels[i]
            # Skip connections have same channels as up block output
            skip_channels = output_channel
            is_final_block = i == num_levels - 1

            self.up_blocks.append(UpBlock2D(
                in_channels=up_input_channel,
                out_channels=output_channel,
                prev_output_channel=skip_channels,
                temb_channels=time_embed_dim,
                num_layers=layers_per_block,  # Same as down blocks
                add_upsample=not is_final_block,
                has_cross_attention=True,
                cross_attention_dim=cross_attention_dim,
            ))

        # Output
        self.conv_norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[0], eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Predict the model output for a noisy sample.

        Args:
            sample: Noisy input, indexed as (batch, in_channels, H, W).
            timestep: Diffusion timestep(s); a tensor (0-D or (batch,)) or a
                plain Python number.
            encoder_hidden_states: Conditioning sequence of shape
                (batch, seq_len, cross_attention_dim).

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of
            ``sample``.
        """
        use_checkpoint = self.use_gradient_checkpointing and self.training

        # Time embedding - convert timestep to a float tensor on the sample's device
        if not torch.is_tensor(timestep):
            timestep = torch.tensor([timestep], device=sample.device)
        temb = self.time_proj(timestep.to(device=sample.device).float())

        # Initial convolution
        hidden_states = self.conv_in(sample)

        # Down sampling path
        # NOTE: modules are passed to checkpoint() directly. Wrapping them in a
        # lambda that closes over the loop variable is wrong with
        # use_reentrant=False: the recomputation happens during backward, when
        # the loop variable already refers to the last block.
        down_block_res_samples = (hidden_states,)
        for downsample_block in self.down_blocks:
            if use_checkpoint:
                hidden_states, res_samples = torch.utils.checkpoint.checkpoint(
                    downsample_block,
                    hidden_states,
                    temb,
                    encoder_hidden_states,
                    use_reentrant=False,
                )
            else:
                hidden_states, res_samples = downsample_block(
                    hidden_states=hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            down_block_res_samples += res_samples

        # Middle: resnets take the time embedding, attention takes the text states.
        for layer in self.mid_block:
            conditioning = temb if isinstance(layer, ResnetBlock2D) else encoder_hidden_states
            if use_checkpoint:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    conditioning,
                    use_reentrant=False,
                )
            else:
                hidden_states = layer(hidden_states, conditioning)

        # Up sampling path: each block consumes the most recent skip connections.
        for upsample_block in self.up_blocks:
            num_skips = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-num_skips:]
            down_block_res_samples = down_block_res_samples[:-num_skips]

            if use_checkpoint:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    upsample_block,
                    hidden_states,
                    res_samples,
                    temb,
                    encoder_hidden_states,
                    use_reentrant=False,
                )
            else:
                hidden_states = upsample_block(
                    hidden_states=hidden_states,
                    res_hidden_states_tuple=res_samples,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                )

        # Output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states


class AutoencoderKL(nn.Module):
    """
    Variational Autoencoder for image compression and reconstruction
    Compresses images to latent space for efficient diffusion

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the reconstructed image.
        down_block_types: One entry per encoder stage; only the count is used.
        up_block_types: One entry per decoder stage; only the count is used.
        latent_channels: Channels of the latent fed to the decoder. The
            encoder emits twice as many channels.
        sample_size: Stored on the instance; not used by the layers.
        block_out_channels: Width of each encoder stage (reversed for the
            decoder).

    Raises:
        ValueError: If ``block_out_channels`` is empty or there are more
            block types than entries in ``block_out_channels``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",) * 4,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",) * 4,
        latent_channels: int = 4,
        sample_size: int = 512,
        block_out_channels: Tuple[int, ...] = (64, 128, 256, 512),
    ):
        super().__init__()
        num_down = len(down_block_types)
        num_up = len(up_block_types)
        if len(block_out_channels) == 0:
            raise ValueError("block_out_channels must contain at least one entry")
        if num_down > len(block_out_channels) or num_up > len(block_out_channels):
            raise ValueError(
                "down_block_types and up_block_types must not be longer than "
                "block_out_channels"
            )

        self.sample_size = sample_size
        self.block_out_channels = block_out_channels
        self.latent_channels = latent_channels

        # Encoder - each stage halves the resolution with a stride-2 convolution
        self.encoder = nn.ModuleList()
        encoder_channels = [in_channels] + list(block_out_channels)
        for i in range(num_down):
            self.encoder.append(nn.Sequential(
                nn.Conv2d(encoder_channels[i], encoder_channels[i + 1], kernel_size=3, stride=2, padding=1),
                nn.GroupNorm(
                    num_groups=min(32, encoder_channels[i + 1]),
                    num_channels=encoder_channels[i + 1],
                    eps=1e-6,
                ),
                nn.SiLU(),
            ))

        # Latent space projection, fed by whatever the last encoder stage emits
        self.quant_conv = nn.Conv2d(encoder_channels[num_down], latent_channels * 2, kernel_size=1)

        # Decoder - each stage doubles the resolution with a transposed convolution.
        # The first stage consumes the output of post_quant_conv, which has
        # block_out_channels[-1] channels (not latent_channels).
        self.decoder = nn.ModuleList()
        decoder_channels = [block_out_channels[-1]] + list(reversed(block_out_channels))
        for i in range(num_up):
            self.decoder.append(nn.Sequential(
                nn.ConvTranspose2d(decoder_channels[i], decoder_channels[i + 1], kernel_size=4, stride=2, padding=1),
                nn.GroupNorm(
                    num_groups=min(32, decoder_channels[i + 1]),
                    num_channels=decoder_channels[i + 1],
                    eps=1e-6,
                ),
                nn.SiLU(),
            ))

        self.post_quant_conv = nn.Conv2d(latent_channels, block_out_channels[-1], kernel_size=1)
        self.conv_out = nn.Conv2d(decoder_channels[num_up], out_channels, kernel_size=3, stride=1, padding=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode image to latent space.

        Returns:
            Tensor with ``2 * latent_channels`` channels.
        """
        for block in self.encoder:
            x = block(x)
        x = self.quant_conv(x)
        return x

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode from latent space to image.

        Args:
            z: Latent with ``latent_channels`` channels.
        """
        z = self.post_quant_conv(z)
        for block in self.decoder:
            z = block(z)
        z = self.conv_out(z)
        return z

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full autoencoder forward pass.

        Decodes the first ``latent_channels`` channels of the encoder output;
        no sampling is performed.
        """
        encoded = self.encode(x)
        decoded = self.decode(encoded[:, :self.latent_channels])
        return decoded


class CLIPTextModel(nn.Module):
    """
    CLIP text encoder for understanding text prompts
    Extracts semantic features from text for conditioning

    Wraps the Hugging Face ``transformers`` CLIP text model and tokenizer.
    When ``transformers`` is not installed a dummy encoder returning zeros is
    used instead.

    Args:
        model_name: Hugging Face model identifier to load.
        max_length: Token length every prompt is padded/truncated to.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", max_length: int = 77):
        super().__init__()
        # Set before the import attempt so the dummy path can use it too.
        self.max_length = max_length

        try:
            from transformers import CLIPTextModel as HFCLIPTextModel, CLIPTokenizer
        except ImportError:
            print("Warning: transformers not installed. Using dummy text encoder.")
            self.model = None
            self.tokenizer = None
        else:
            print(f"Loading CLIP text encoder: {model_name}...")
            self.model = HFCLIPTextModel.from_pretrained(model_name)
            self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
            print("✓ CLIP text encoder loaded successfully on CPU")

    def forward(self, text: Union[str, List[str]], device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Encode text to embeddings

        Args:
            text: Text string or list of strings
            device: Target device for the tokenized inputs. The wrapped model
                itself is not moved; move this module with ``.to(device)``.

        Returns:
            Text embeddings tensor of shape (batch, max_length, hidden_size).
            The dummy encoder returns zeros with a hidden size of 512.
        """
        if self.model is None:
            # Dummy implementation if transformers not available: one row of
            # zeros per prompt, on the requested device.
            batch_size = 1 if isinstance(text, str) else len(text)
            return torch.zeros(batch_size, self.max_length, 512, device=device)

        inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        if device is not None:
            inputs = {k: v.to(device) for k, v in inputs.items()}

        outputs = self.model(**inputs)
        return outputs.last_hidden_state


def create_unet(config):
    """Factory function to create UNet from config.

    Reads ``config['model']['unet']``. Gradient checkpointing is enabled by
    default for memory efficiency and can be disabled with an optional
    ``use_gradient_checkpointing`` key.
    """
    unet_config = config['model']['unet']

    return UNet2DConditionModel(
        in_channels=unet_config['in_channels'],
        out_channels=unet_config['out_channels'],
        block_out_channels=tuple(unet_config['block_out_channels']),
        layers_per_block=unet_config['layers_per_block'],
        attention_head_dim=unet_config['attention_head_dim'],
        cross_attention_dim=unet_config['cross_attention_dim'],
        use_linear_projection=unet_config['use_linear_projection'],
        use_gradient_checkpointing=unet_config.get('use_gradient_checkpointing', True),
    )


def create_vae(config):
    """Factory function to create VAE from config.

    Reads ``config['model']['vae']``; ``block_out_channels`` is optional and
    defaults to ``[64, 128, 256, 512]``.
    """
    vae_config = config['model']['vae']

    return AutoencoderKL(
        in_channels=vae_config['in_channels'],
        out_channels=vae_config['out_channels'],
        down_block_types=tuple(vae_config['down_block_types']),
        up_block_types=tuple(vae_config['up_block_types']),
        latent_channels=vae_config['latent_channels'],
        sample_size=vae_config['sample_size'],
        block_out_channels=tuple(vae_config.get('block_out_channels', [64, 128, 256, 512])),
    )


def create_text_encoder(config):
    """Factory function to create text encoder from config.

    Reads ``config['model']['text_encoder']`` (keys ``model`` and
    ``max_length``).
    """
    text_config = config['model']['text_encoder']

    return CLIPTextModel(
        model_name=text_config['model'],
        max_length=text_config['max_length'],
    )