# EO-VAE / _eo_vae/layers.py
# Provenance: BiliSakura — "Update all files for EO-VAE" (commit 774dec2, verified)
# Apache-2.0 - Based on Flux2 / diffusers
# ResnetBlock, AttnBlock, Downsample, Upsample
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
def swish(x: Tensor) -> Tensor:
    """SiLU / swish activation: ``x * sigmoid(x)``."""
    return torch.sigmoid(x) * x
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution.

    Pads one pixel on the right and bottom only (asymmetric, zero-valued)
    instead of symmetric conv padding, so an HxW input maps to exactly
    H//2 x W//2. The channel count is unchanged.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        # padding=0 here; the asymmetric pad happens explicitly in forward().
        self.conv = nn.Conv2d(in_channels, in_channels, 3, stride=2)

    def forward(self, x: Tensor) -> Tensor:
        # (left, right, top, bottom) = (0, 1, 0, 1): pad bottom-right only.
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class Upsample(nn.Module):
    """Double spatial resolution: nearest-neighbor 2x upsample, then a 3x3 conv.

    The conv uses stride 1 and padding 1, so the spatial size after the
    interpolation is preserved and only the features are mixed. The channel
    count is unchanged.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, 3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        upsampled = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
class ResnetBlock(nn.Module):
    """Pre-activation residual block with optional feature-wise conditioning.

    Structure: GroupNorm -> swish -> 3x3 conv -> GroupNorm (optionally
    modulated by a per-channel scale/shift projected from ``emb``) ->
    swish -> 3x3 conv, plus a shortcut (1x1 conv when the channel count
    changes, identity otherwise).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int],
        cond_dim: Optional[int] = None,
    ):
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels; ``None`` keeps
                ``in_channels`` (fix: annotated Optional — the body has
                always accepted ``None`` here).
            cond_dim: dimensionality of the conditioning embedding;
                ``None`` disables conditioning entirely.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.cond_dim = cond_dim

        self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, self.out_channels, 3, stride=1, padding=1)
        if self.cond_dim is not None:
            # Projects the conditioning vector to per-channel (scale, shift).
            # Initialized so the modulation starts as the identity:
            # weight = 0, bias = [1]*C (scale half) ++ [0]*C (shift half).
            self.emb_proj = nn.Linear(cond_dim, self.out_channels * 2)
            with torch.no_grad():
                self.emb_proj.weight.zero_()
                self.emb_proj.bias.zero_()
                self.emb_proj.bias[: self.out_channels].fill_(1.0)
        self.norm2 = nn.GroupNorm(32, self.out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, 3, stride=1, padding=1)
        self.nin_shortcut = (
            nn.Conv2d(in_channels, self.out_channels, 1, stride=1, padding=0)
            if in_channels != self.out_channels
            else nn.Identity()
        )

    def forward(self, x: Tensor, emb: Optional[Tensor] = None) -> Tensor:
        """Apply the block.

        Args:
            x: input feature map of shape (batch, in_channels, H, W).
            emb: optional conditioning embedding of shape (batch, cond_dim);
                ignored unless the block was built with ``cond_dim``.

        Returns:
            Tensor of shape (batch, out_channels, H, W).
        """
        # nn.functional.silu is exactly x * sigmoid(x), i.e. this module's swish.
        h = self.norm1(x)
        h = nn.functional.silu(h)
        h = self.conv1(h)
        h = self.norm2(h)
        if self.cond_dim is not None and emb is not None:
            # FiLM-style modulation: per-channel affine from the embedding,
            # applied after the (un-modulated) second normalization.
            style = self.emb_proj(emb).unsqueeze(-1).unsqueeze(-1)
            scale, shift = style.chunk(2, dim=1)
            h = h * scale + shift
        h = nn.functional.silu(h)
        h = self.conv2(h)
        return h + self.nin_shortcut(x)
class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions, with a residual add.

    Every spatial location attends to every other; q/k/v and the output
    projection are 1x1 convolutions, so the channel count is unchanged.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
        self.q = nn.Conv2d(in_channels, in_channels, 1)
        self.k = nn.Conv2d(in_channels, in_channels, 1)
        self.v = nn.Conv2d(in_channels, in_channels, 1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x: Tensor) -> Tensor:
        normed = self.norm(x)
        b, c, h, w = x.shape

        def to_tokens(t: Tensor) -> Tensor:
            # (b, c, h, w) -> (b, 1, h*w, c): one attention head, h*w tokens.
            return t.reshape(b, c, h * w).permute(0, 2, 1).unsqueeze(1)

        q = to_tokens(self.q(normed))
        k = to_tokens(self.k(normed))
        v = to_tokens(self.v(normed))
        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # Back to the spatial layout: (b, 1, h*w, c) -> (b, c, h, w).
        attended = attended.squeeze(1).permute(0, 2, 1).reshape(b, c, h, w)
        return x + self.proj_out(attended)