| """Residual block components for diffusion models.""" | |
| from abc import abstractmethod | |
| from typing import Optional, Any, Dict | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from src.NeuralNetwork import transformer | |
| from src.Attention import Attention | |
| from src.cond import cast | |
| from src.sample import sampling_util | |
| ops = cast.disable_weight_init | |


class TimestepBlock1(nn.Module):
    """Abstract base for blocks whose forward pass also takes a timestep embedding."""

    @abstractmethod
    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` conditioned on the timestep embedding ``emb``."""


def forward_timestep_embed1(ts, x, emb, context=None, transformer_options=None, output_shape=None,
                            time_context=None, num_video_frames=None, image_only_indicator=None):
    """Run ``x`` through a sequence of layers, passing each layer the extra inputs it expects."""
    if transformer_options is None:
        transformer_options = {}
    for layer in ts:
        if isinstance(layer, TimestepBlock1):
            x = layer(x, emb)
        elif isinstance(layer, transformer.SpatialTransformer):
            x = layer(x, context, transformer_options)
            if "transformer_index" in transformer_options:
                transformer_options["transformer_index"] += 1
        elif isinstance(layer, Upsample1):
            x = layer(x, output_shape=output_shape)
        else:
            x = layer(x)
    return x


class Upsample1(nn.Module):
    """Upsample layer with optional conv."""

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1,
                 dtype=None, device=None, operations=ops):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        if use_conv:
            self.conv = operations.conv_nd(dims, channels, self.out_channels, 3,
                                           padding=padding, dtype=dtype, device=device)

    def forward(self, x, output_shape=None):
        # Double the spatial size unless an explicit target shape is given.
        shape = [x.shape[2] * 2, x.shape[3] * 2] if output_shape is None else [output_shape[2], output_shape[3]]
        x = F.interpolate(x, size=shape, mode="nearest")
        return self.conv(x) if self.use_conv else x


class Downsample1(nn.Module):
    """Downsample layer."""

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1,
                 dtype=None, device=None, operations=ops):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        stride = 2 if dims != 3 else (1, 2, 2)
        self.op = operations.conv_nd(dims, channels, self.out_channels, 3, stride=stride,
                                     padding=padding, dtype=dtype, device=device)

    def forward(self, x):
        return self.op(x)


class ResBlock1(TimestepBlock1):
    """Residual block with timestep embedding."""

    def __init__(self, channels, emb_channels, dropout, out_channels=None, use_conv=False,
                 use_scale_shift_norm=False, dims=2, use_checkpoint=False, up=False, down=False,
                 kernel_size=3, exchange_temb_dims=False, skip_t_emb=False,
                 dtype=None, device=None, operations=ops):
        super().__init__()
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.skip_t_emb = skip_t_emb
        padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            operations.GroupNorm(32, channels, dtype=dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device))
        # When use_scale_shift_norm is set, the embedding is projected to twice the
        # channel count so it can later be split into a per-channel scale and shift.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            operations.Linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
                              dtype=dtype, device=device))
        self.out_layers = nn.Sequential(
            operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device))
        self.skip_connection = (nn.Identity() if self.out_channels == channels
                                else operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device))

    def forward(self, x, emb):
        return sampling_util.checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)

    def _forward(self, x, emb):
        h = self.in_layers(x)
        if not self.skip_t_emb:
            emb_out = self.emb_layers(emb).type(h.dtype)
            # Broadcast the embedding over the spatial dimensions.
            while len(emb_out.shape) < len(h.shape):
                emb_out = emb_out[..., None]
            if self.use_scale_shift_norm:
                # FiLM-style conditioning: scale and shift right after the output norm.
                out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
                scale, shift = torch.chunk(emb_out, 2, dim=1)
                h = out_norm(h) * (1 + scale) + shift
                return self.skip_connection(x) + out_rest(h)
            h = h + emb_out
        return self.skip_connection(x) + self.out_layers(h)


class ResnetBlock(nn.Module):
    """VAE-style ResNet block."""

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        out_channels = out_channels or in_channels
        self.in_channels, self.out_channels = in_channels, out_channels
        self.norm1 = Attention.Normalize(in_channels)
        self.conv1 = ops.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.norm2 = Attention.Normalize(out_channels)
        self.dropout = nn.Dropout(dropout, inplace=True)
        self.conv2 = ops.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.nin_shortcut = ops.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else None
        self.swish = nn.SiLU(inplace=True)

    def forward(self, x, temb):
        h = self.swish(self.norm1(x))
        h = self.conv1(h)
        h = self.dropout(self.swish(self.norm2(h)))
        h = self.conv2(h)
        return (self.nin_shortcut(x) if self.nin_shortcut is not None else x) + h
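

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption-laden example, not part of the module's
# public API): it presumes the `src.*` imports above resolve, that
# `ops.conv_nd` / `ops.GroupNorm` / `ops.Linear` mirror their torch.nn
# counterparts, and that `sampling_util.checkpoint(fn, args, params, flag)`
# simply calls `fn(*args)` when checkpointing is disabled. Shapes are
# illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(1, 64, 32, 32)            # (batch, channels, height, width)
    emb = torch.randn(1, 512)                  # timestep embedding

    block = ResBlock1(channels=64, emb_channels=512, dropout=0.0, out_channels=128)
    h = block(x, emb)                          # -> (1, 128, 32, 32)

    h = Downsample1(128, use_conv=True)(h)     # -> (1, 128, 16, 16)
    h = Upsample1(128, use_conv=True)(h)       # -> (1, 128, 32, 32)
    print(h.shape)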