# ByteDream — bytedream/model.py
# Uploaded with huggingface_hub (commit 74d320c)
"""
Byte Dream Model Architecture
Complete implementation of UNet, VAE, and Text Encoder for diffusion-based image generation
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List
import math
class ResnetBlock2D(nn.Module):
    """Pre-activation residual block: GroupNorm -> SiLU -> Conv, twice, plus skip.

    An optional timestep embedding is projected and added between the two
    convolutions. A 1x1 shortcut convolution aligns channels when the input
    and output widths differ.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: Optional[int] = None,
        groups: int = 32,
        eps: float = 1e-6,
    ):
        super().__init__()
        # Layer creation order is kept stable so seeded weight init
        # reproduces the same parameters as before.
        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels is not None:
            # Projects the time embedding onto the conv feature channels.
            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(0.0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.nonlinearity = nn.SiLU()
        # 1x1 projection on the residual path only when channel counts differ.
        self.conv_shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            if in_channels != out_channels
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        h = self.conv1(self.nonlinearity(self.norm1(hidden_states)))
        if temb is not None:
            # Broadcast the projected embedding over the spatial dims.
            h = h + self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
        h = self.conv2(self.dropout(self.nonlinearity(self.norm2(h))))
        skip = hidden_states if self.conv_shortcut is None else self.conv_shortcut(hidden_states)
        return h + skip
class AttentionBlock(nn.Module):
"""Cross-attention block for text-conditioned generation"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
num_heads: int = 8,
head_dim: Optional[int] = None,
eps: float = 1e-6,
):
super().__init__()
# Use head_dim if provided, otherwise calculate from query_dim and num_heads
self.head_dim = head_dim if head_dim is not None else query_dim // num_heads
inner_dim = self.head_dim * num_heads
# Use cross_attention_dim if provided, otherwise use query_dim (self-attention)
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.num_heads = num_heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=False)
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=False)
self.to_out = nn.ModuleList([
nn.Linear(inner_dim, query_dim),
nn.Dropout(0.0)
])
self.norm = nn.LayerNorm(query_dim, eps=eps)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
# Handle 4D inputs (batch, channels, height, width)
if hidden_states.ndim == 4:
batch_size, channels, height, width = hidden_states.shape
# Reshape to (batch, seq_len, channels) where seq_len = height * width
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, -1, channels)
is_4d = True
else:
batch_size, sequence_length, _ = hidden_states.shape
is_4d = False
query = self.to_q(hidden_states)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
# Multi-head attention
query = query.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
key = key.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
value = value.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# Scaled dot-product attention
attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, query.shape[-1] * self.num_heads)
# Output projection
for layer in self.to_out:
attn_output = layer(attn_output)
# Reshape back to 4D if input was 4D
if is_4d:
attn_output = attn_output.reshape(batch_size, height, width, channels)
attn_output = attn_output.permute(0, 3, 1, 2)
return residual + attn_output
class DownBlock2D(nn.Module):
    """Downsampling U-Net block.

    Runs ``num_layers`` resnet (+ optional cross-attention) layers and then
    an optional stride-2 conv downsampler. Returns both the downsampled
    activations and the per-layer outputs used as skip connections by the
    decoder path.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_layers: int = 1,
        add_downsample: bool = True,
        has_cross_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        for i in range(num_layers):
            # Only the first layer changes the channel count.
            in_ch = in_channels if i == 0 else out_channels
            resnets.append(ResnetBlock2D(
                in_channels=in_ch,
                out_channels=out_channels,
                temb_channels=temb_channels,
            ))
            if has_cross_attention:
                attentions.append(AttentionBlock(
                    query_dim=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    num_heads=8,
                    head_dim=out_channels // 8,
                ))
            else:
                # None placeholder keeps resnets/attentions index-aligned.
                attentions.append(None)
        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)
        if add_downsample:
            self.downsamplers = nn.ModuleList([
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
            ])
        else:
            self.downsamplers = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Run the block.

        Args:
            hidden_states: Input feature map (B, C_in, H, W).
            temb: Optional time embedding passed to each resnet.
            encoder_hidden_states: Optional text context for cross-attention.

        Returns:
            Tuple of (output feature map, per-layer outputs for skip
            connections). Skip outputs are captured BEFORE downsampling.
            (The previous return annotation claimed a bare Tensor, which
            was incorrect.)
        """
        output_states = ()
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            # Cross-attention only fires when both the layer and the text
            # context exist.
            if attn is not None and encoder_hidden_states is not None:
                hidden_states = attn(hidden_states, encoder_hidden_states)
            output_states += (hidden_states,)
        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
        return hidden_states, output_states
class UpBlock2D(nn.Module):
    """Upsampling U-Net block.

    Each resnet layer concatenates a skip connection from the downsampling
    path before processing; an optional transposed-conv upsampler runs after
    all layers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        num_layers: int = 1,
        add_upsample: bool = True,
        has_cross_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        for i in range(num_layers):
            # All layers receive skip connections; the resnet input width is
            # the running feature width plus the expected skip width.
            in_ch = in_channels if i == 0 else out_channels
            resnets.append(ResnetBlock2D(
                in_channels=in_ch + prev_output_channel,
                out_channels=out_channels,
                temb_channels=temb_channels,
            ))
            if has_cross_attention:
                attentions.append(AttentionBlock(
                    query_dim=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    num_heads=8,
                    head_dim=out_channels // 8,
                ))
            else:
                # None placeholder keeps resnets/attentions index-aligned.
                attentions.append(None)
        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)
        if add_upsample:
            self.upsamplers = nn.ModuleList([
                nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
            ])
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the block, consuming one skip tensor per resnet layer.

        Args:
            hidden_states: Input feature map from the previous (coarser) level.
            res_hidden_states_tuple: Skip tensors from the down path, indexed
                in layer order.
            temb: Optional time embedding passed to each resnet.
            encoder_hidden_states: Optional text context for cross-attention.

        Returns:
            The processed (and, if configured, upsampled) feature map.
        """
        for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
            # Skip connection from the U-Net downsampling path.
            if i < len(res_hidden_states_tuple):
                res_hidden_state = res_hidden_states_tuple[i]
                # Resize skip features to the current spatial resolution.
                if hidden_states.shape[2:] != res_hidden_state.shape[2:]:
                    res_hidden_state = F.interpolate(
                        res_hidden_state,
                        size=hidden_states.shape[2:],
                        mode='bilinear',
                        align_corners=False,
                    )
                # The resnet expects hidden_states + skip concatenated along
                # channels; adjust the skip width if the caller's bookkeeping
                # does not line up.
                expected_in_channels = self.resnets[i].conv1.in_channels
                actual_in_channels = hidden_states.shape[1] + res_hidden_state.shape[1]
                if actual_in_channels != expected_in_channels:
                    channel_diff = expected_in_channels - hidden_states.shape[1]
                    if channel_diff > 0 and channel_diff != res_hidden_state.shape[1]:
                        # BUGFIX: the previous code projected the skip through
                        # conv weights drawn from torch.randn on EVERY forward
                        # call — untrained and non-reproducible. Align channels
                        # deterministically instead: truncate extras, or
                        # zero-pad missing channels.
                        current = res_hidden_state.shape[1]
                        if current > channel_diff:
                            res_hidden_state = res_hidden_state[:, :channel_diff]
                        else:
                            # Pad tuple is (W_l, W_r, H_t, H_b, C_front, C_back).
                            res_hidden_state = F.pad(
                                res_hidden_state,
                                (0, 0, 0, 0, 0, channel_diff - current),
                            )
                hidden_states = torch.cat([hidden_states, res_hidden_state], dim=1)
            hidden_states = resnet(hidden_states, temb)
            if attn is not None and encoder_hidden_states is not None:
                hidden_states = attn(hidden_states, encoder_hidden_states)
        # Upsample AFTER all resnet layers.
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)
        return hidden_states
class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer MLP.

    Scalar diffusion timesteps are expanded into sin/cos features of width
    ``in_features`` using a geometric frequency ladder, then projected to
    ``time_embed_dim`` via Linear -> SiLU -> Linear.
    """

    def __init__(self, in_features: int, time_embed_dim: int):
        super().__init__()
        self.in_features = in_features
        self.time_embed_dim = time_embed_dim
        # Precompute the frequency ladder as a (non-trainable) buffer.
        half_dim = in_features // 2
        step = math.log(10000) / (half_dim - 1)
        self.register_buffer('emb', torch.exp(-step * torch.arange(half_dim)))
        # Projection MLP.
        self.linear_1 = nn.Linear(in_features, time_embed_dim)
        self.activation = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        # Normalize shape to (batch, 1) so broadcasting against the
        # frequency buffer yields (batch, half_dim).
        if timestep.ndim == 0:
            timestep = timestep.view(1, 1)
        elif timestep.ndim == 1:
            timestep = timestep.view(-1, 1)
        angles = timestep * self.emb
        features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        return self.linear_2(self.activation(self.linear_1(features)))
class UNet2DConditionModel(nn.Module):
    """
    Main UNet architecture for diffusion-based image generation.

    Predicts noise for a latent ``sample`` at a given diffusion ``timestep``,
    conditioned on text embeddings through cross-attention in every block.
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 4,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        attention_head_dim: int = 8,
        cross_attention_dim: int = 768,
        use_linear_projection: bool = True,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.block_out_channels = block_out_channels
        self.layers_per_block = layers_per_block
        self.cross_attention_dim = cross_attention_dim
        self.use_gradient_checkpointing = use_gradient_checkpointing
        # NOTE: use_linear_projection is accepted for config compatibility
        # but is not consumed by this implementation.

        # Time embedding (4x base channel width is the common convention).
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = TimestepEmbedding(block_out_channels[0], time_embed_dim)

        # Input convolution lifts the latent to the base channel width.
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # Down blocks — one per entry in block_out_channels. (Previously the
        # loop iterated a hardcoded list of four block types, silently
        # mismatching any config with a different number of levels.)
        self.down_blocks = nn.ModuleList([])
        output_channel = block_out_channels[0]
        for i in range(len(block_out_channels)):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            down_block = DownBlock2D(
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                num_layers=layers_per_block,
                add_downsample=not is_final_block,
                has_cross_attention=True,
                cross_attention_dim=cross_attention_dim,
            )
            self.down_blocks.append(down_block)

        # Middle blocks: resnet -> cross-attention -> resnet at the bottleneck.
        # NOTE(review): attention_head_dim is wired in as the HEAD COUNT here
        # (matching the original behavior) — confirm against config intent.
        self.mid_block = nn.ModuleList([
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
            ),
            AttentionBlock(
                query_dim=block_out_channels[-1],
                cross_attention_dim=cross_attention_dim,
                num_heads=attention_head_dim,
                head_dim=block_out_channels[-1] // attention_head_dim,
            ),
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
            ),
        ])

        # Up blocks — mirror of the down path, also one per channel entry.
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        for i in range(len(block_out_channels)):
            # First up block consumes the mid-block output; later ones take
            # the previous up block's output.
            up_in_channels = block_out_channels[-1] if i == 0 else reversed_block_out_channels[i - 1]
            output_channel = reversed_block_out_channels[i]
            # Skip connections are assumed to match the block's own width;
            # UpBlock2D.forward adjusts at runtime when they do not.
            skip_channels = output_channel
            is_final_block = i == len(block_out_channels) - 1
            up_block = UpBlock2D(
                in_channels=up_in_channels,
                out_channels=output_channel,
                prev_output_channel=skip_channels,
                temb_channels=time_embed_dim,
                num_layers=layers_per_block,  # same as down blocks
                add_upsample=not is_final_block,
                has_cross_attention=True,
                cross_attention_dim=cross_attention_dim,
            )
            self.up_blocks.append(up_block)

        # Output head: norm -> SiLU -> conv back to latent channels.
        self.conv_norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[0], eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict noise for ``sample``.

        Args:
            sample: Latent tensor (B, in_channels, H, W).
            timestep: Scalar or (B,)-shaped diffusion timestep(s).
            encoder_hidden_states: Text-conditioning sequence
                (B, seq_len, cross_attention_dim).

        Returns:
            Noise prediction of shape (B, out_channels, H, W).
        """
        # Time embedding — cast to float for the linear layers.
        temb = self.time_proj(timestep.float())

        # Initial convolution.
        hidden_states = self.conv_in(sample)

        # Down path — collect per-layer activations as skip connections.
        # The conv_in output is stored first; with matched layer counts it
        # is never consumed by the up path below.
        down_block_res_samples = (hidden_states,)
        for downsample_block in self.down_blocks:
            if self.use_gradient_checkpointing and self.training:
                hidden_states, res_samples = torch.utils.checkpoint.checkpoint(
                    lambda hs, t, ehs: downsample_block(hs, t, ehs),
                    hidden_states, temb, encoder_hidden_states,
                    use_reentrant=False,
                )
            else:
                hidden_states, res_samples = downsample_block(
                    hidden_states=hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            down_block_res_samples += res_samples

        # Bottleneck — resnets take the time embedding, attention takes the
        # text context.
        for layer in self.mid_block:
            if self.use_gradient_checkpointing and self.training:
                if isinstance(layer, ResnetBlock2D):
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        lambda hs, t: layer(hs, t),
                        hidden_states, temb,
                        use_reentrant=False,
                    )
                else:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        lambda hs, ehs: layer(hs, ehs),
                        hidden_states, encoder_hidden_states,
                        use_reentrant=False,
                    )
            else:
                if isinstance(layer, ResnetBlock2D):
                    hidden_states = layer(hidden_states, temb)
                else:
                    hidden_states = layer(hidden_states, encoder_hidden_states)

        # Up path — each block consumes the most recent skip activations.
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)]
            if self.use_gradient_checkpointing and self.training:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    lambda hs, res, t, ehs: upsample_block(hs, res, t, ehs),
                    hidden_states, res_samples, temb, encoder_hidden_states,
                    use_reentrant=False,
                )
            else:
                hidden_states = upsample_block(
                    hidden_states=hidden_states,
                    res_hidden_states_tuple=res_samples,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                )

        # Output head.
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states
class AutoencoderKL(nn.Module):
    """
    Variational Autoencoder for image compression and reconstruction.

    ``encode`` produces 2 * latent_channels feature maps (moments stacked on
    the channel axis); ``decode`` maps latent_channels back to image space.
    No sampling/KL machinery is implemented here — ``forward`` simply feeds
    the first latent_channels maps of the moments to the decoder.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",) * 4,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",) * 4,
        latent_channels: int = 4,
        sample_size: int = 512,
        block_out_channels: Tuple[int, ...] = (64, 128, 256, 512),
    ):
        super().__init__()
        self.sample_size = sample_size
        self.block_out_channels = block_out_channels
        # Stored so forward() can slice the moments without a magic number.
        self.latent_channels = latent_channels

        # Encoder: one stride-2 conv stage per down block type.
        self.encoder = nn.ModuleList()
        channels = [in_channels] + list(block_out_channels)
        for i in range(len(down_block_types)):
            block = nn.Sequential(
                nn.Conv2d(channels[i], channels[i + 1], kernel_size=3, stride=2, padding=1),
                nn.GroupNorm(num_groups=min(32, channels[i + 1]), num_channels=channels[i + 1], eps=1e-6),
                nn.SiLU(),
            )
            self.encoder.append(block)

        # Latent projection: doubles the channels to hold both moments.
        self.quant_conv = nn.Conv2d(block_out_channels[-1], latent_channels * 2, kernel_size=1)

        # Decoder: one stride-2 transposed-conv stage per up block type.
        self.decoder = nn.ModuleList()
        decoder_channels = [latent_channels] + list(reversed(block_out_channels))
        for i in range(len(up_block_types)):
            block = nn.Sequential(
                nn.ConvTranspose2d(decoder_channels[i], decoder_channels[i + 1], kernel_size=4, stride=2, padding=1),
                nn.GroupNorm(num_groups=min(32, decoder_channels[i + 1]), num_channels=decoder_channels[i + 1], eps=1e-6),
                nn.SiLU(),
            )
            self.decoder.append(block)

        # BUGFIX: post_quant_conv previously mapped latent_channels ->
        # block_out_channels[-1], but the first decoder stage expects
        # latent_channels inputs, so decode() always failed with a channel
        # mismatch. It is now a latent -> latent 1x1 conv.
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, kernel_size=1)
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, stride=1, padding=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an image to latent moments (2 * latent_channels maps)."""
        for block in self.encoder:
            x = block(x)
        x = self.quant_conv(x)
        return x

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a latent (latent_channels maps) back to image space."""
        z = self.post_quant_conv(z)
        for block in self.decoder:
            z = block(z)
        z = self.conv_out(z)
        return z

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full autoencoder forward pass (encode, then decode the mean half)."""
        encoded = self.encode(x)
        # Use the first latent_channels maps instead of the previous
        # hardcoded 4, so non-default latent widths work too.
        decoded = self.decode(encoded[:, :self.latent_channels])
        return decoded
class CLIPTextModel(nn.Module):
    """
    CLIP text encoder for understanding text prompts.

    Wraps the Hugging Face CLIP text model/tokenizer when ``transformers``
    is installed; otherwise degrades to a dummy encoder that returns zeros.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", max_length: int = 77):
        super().__init__()
        # BUGFIX: max_length is now always assigned. Previously it was only
        # set inside the try block, so the ImportError (dummy) path left the
        # attribute missing and hardcoded 77 in forward().
        self.max_length = max_length
        try:
            from transformers import CLIPTextModel as HFCLIPTextModel, CLIPTokenizer
            print(f"Loading CLIP text encoder: {model_name}...")
            self.model = HFCLIPTextModel.from_pretrained(model_name)
            self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
            print(f"✓ CLIP text encoder loaded successfully on CPU")
        except ImportError:
            print("Warning: transformers not installed. Using dummy text encoder.")
            self.model = None
            self.tokenizer = None

    def forward(self, text: Union[str, List[str]], device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Encode text to embeddings.

        Args:
            text: Text string or list of strings.
            device: Target device for computation.

        Returns:
            Last-hidden-state embeddings of shape (batch, max_length, hidden).
        """
        if self.model is None:
            # Dummy fallback; 512 matches the CLIP ViT-B/32 hidden size.
            return torch.zeros(1, self.max_length, 512)
        inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        if device is not None:
            inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        return outputs.last_hidden_state
def create_unet(config):
    """Factory function to create UNet from config"""
    cfg = config['model']['unet']
    kwargs = {
        'in_channels': cfg['in_channels'],
        'out_channels': cfg['out_channels'],
        'block_out_channels': tuple(cfg['block_out_channels']),
        'layers_per_block': cfg['layers_per_block'],
        'attention_head_dim': cfg['attention_head_dim'],
        'cross_attention_dim': cfg['cross_attention_dim'],
        'use_linear_projection': cfg['use_linear_projection'],
    }
    # Gradient checkpointing is always enabled for memory efficiency.
    return UNet2DConditionModel(use_gradient_checkpointing=True, **kwargs)
def create_vae(config):
    """Factory function to create VAE from config"""
    cfg = config['model']['vae']
    # Fall back to the default channel ladder when the config omits it.
    channel_ladder = tuple(cfg.get('block_out_channels', [64, 128, 256, 512]))
    return AutoencoderKL(
        in_channels=cfg['in_channels'],
        out_channels=cfg['out_channels'],
        down_block_types=tuple(cfg['down_block_types']),
        up_block_types=tuple(cfg['up_block_types']),
        latent_channels=cfg['latent_channels'],
        sample_size=cfg['sample_size'],
        block_out_channels=channel_ladder,
    )
def create_text_encoder(config):
    """Factory function to create text encoder from config"""
    cfg = config['model']['text_encoder']
    return CLIPTextModel(model_name=cfg['model'], max_length=cfg['max_length'])