# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from einops import rearrange
from torch import Tensor, nn
try:
    from apex.contrib.group_norm import GroupNorm

    OPT_GROUP_NORM = True
except Exception:
    print('Fused optimized group norm has not been installed.')
    OPT_GROUP_NORM = False

    class GroupNorm(nn.GroupNorm):
        """Pure-PyTorch fallback for apex's fused GroupNorm.

        Mirrors the apex interface by accepting an ``act`` keyword and applying
        the activation after normalization, so ``Normalize()`` keeps working
        without apex installed. (Previously a missing apex install left
        ``GroupNorm`` undefined and raised NameError at first use.)
        """

        def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, act=""):
            """
            Args:
                num_groups (int): Number of groups to separate the channels into.
                num_channels (int): Number of channels expected in the input.
                eps (float, optional): Numerical-stability epsilon. Defaults to 1e-5.
                affine (bool, optional): Learnable per-channel scale/shift. Defaults to True.
                act (str, optional): Activation applied after normalization;
                    "" (none) or "silu"/"swish". Defaults to "".
            """
            super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
            act = (act or "").lower()
            if act == "":
                self._act_fn = None
            elif act in ("silu", "swish"):
                # apex fuses swish/SiLU into the norm kernel; apply it unfused here.
                self._act_fn = torch.nn.functional.silu
            else:
                raise ValueError(f"Unsupported activation for GroupNorm fallback: {act}")

        def forward(self, x):
            """Normalize, then apply the optional activation."""
            out = super().forward(x)
            return out if self._act_fn is None else self._act_fn(out)
# pylint: disable=C0116
def Normalize(in_channels, num_groups=32, act=""):
    """Build a GroupNorm layer, optionally with a fused activation.

    Args:
        in_channels (int): Number of channels in the input.
        num_groups (int, optional): Group count for GroupNorm. Defaults to 32.
        act (str, optional): Activation name passed through to GroupNorm
            (e.g. "silu"). Defaults to "" (no activation).

    Returns:
        GroupNorm: The configured normalization layer.
    """
    norm_kwargs = dict(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, act=act)
    return GroupNorm(**norm_kwargs)
def nonlinearity(x):
    """SiLU (Swish) activation used for temporal-embedding projections.

    Args:
        x (Tensor): Input tensor.

    Returns:
        Tensor: x * sigmoid(x), computed via the library SiLU implementation.
    """
    return torch.nn.functional.silu(x)
class ResnetBlock(nn.Module):
    """Residual block with group normalization and an optional temporal embedding.

    Two 3x3 convolutions separated by normalization and dropout; when the input
    and output channel counts differ, the skip connection is projected with
    either a 3x3 convolution or a 1x1 ("network-in-network") convolution.
    """

    def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=0):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int, optional): Number of output channels. Defaults to in_channels.
            conv_shortcut (bool, optional): Use a 3x3 conv (True) or a 1x1 conv
                (False) for the shortcut when channel counts differ. Defaults to False.
            dropout (float, optional): Dropout probability. Defaults to 0.0.
            temb_channels (int, optional): Temporal-embedding width; 0 disables it.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut
        # NOTE: submodule creation order is preserved so parameter initialization
        # (which consumes the global RNG) matches the reference implementation.
        self.norm1 = Normalize(in_channels, act="silu")
        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, self.out_channels)
        self.norm2 = Normalize(self.out_channels, act="silu")
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if in_channels != self.out_channels:
            if conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        """Run the residual block.

        Args:
            x (Tensor): Feature map of shape (B, C, H, W).
            temb (Tensor): Temporal embedding of shape (B, temb_channels), or None.

        Returns:
            Tensor: Feature map of shape (B, out_channels, H, W).
        """
        residual = self.norm1(x)
        residual = self.conv1(residual)
        if temb is not None:
            # Broadcast the projected embedding over the spatial dimensions.
            residual = residual + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        residual = self.norm2(residual)
        residual = self.dropout(residual)
        residual = self.conv2(residual)
        shortcut = x
        if self.in_channels != self.out_channels:
            shortcut = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
        return shortcut + residual
class Upsample(nn.Module):
    """Doubles spatial resolution via nearest-neighbor interpolation,
    optionally followed by a 3x3 convolution."""

    def __init__(self, in_channels, with_conv):
        """
        Args:
            in_channels (int): Number of input channels.
            with_conv (bool): If True, apply a 3x3 convolution after upsampling.
        """
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Upsample the input by a factor of 2 in both spatial dimensions.

        Args:
            x (Tensor): Feature map of shape (B, C, H, W).

        Returns:
            Tensor: Feature map of shape (B, C, 2H, 2W).
        """
        # 'upsample_nearest2d_out_frame' does not support bfloat16, so round-trip
        # bf16 inputs through float32 and cast back afterwards.
        needs_cast = x.dtype == torch.bfloat16
        upsampled = torch.nn.functional.interpolate(
            x.to(torch.float32) if needs_cast else x, scale_factor=2.0, mode="nearest"
        )
        if needs_cast:
            upsampled = upsampled.to(torch.bfloat16)
        if self.with_conv:
            upsampled = self.conv(upsampled)
        return upsampled
class Downsample(nn.Module):
    """Halves spatial resolution, either with a stride-2 convolution
    (using manual asymmetric padding) or with 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        """
        Args:
            in_channels (int): Number of input channels.
            with_conv (bool): If True, downsample with a stride-2 3x3 convolution;
                otherwise use average pooling.
        """
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # torch Conv2d has no asymmetric padding, so padding is applied
            # manually in forward() and the conv uses padding=0.
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Downsample the input by a factor of 2 in both spatial dimensions.

        Args:
            x (Tensor): Feature map of shape (B, C, H, W).

        Returns:
            Tensor: Feature map of shape (B, C, H/2, W/2).
        """
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad one pixel on the right and bottom only (asymmetric).
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Normalizes the input, projects to queries/keys/values with 1x1 convolutions,
    runs scaled dot-product attention over the flattened H*W axis, then projects
    the result back; forward() adds it residually to the input.
    """

    def __init__(self, in_channels: int):
        """
        Args:
            in_channels (int): Number of input/output channels.
        """
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels, act="silu")
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Compute attention over the (normalized) input feature map.

        Args:
            h_ (Tensor): Input feature map (B, C, H, W).

        Returns:
            Tensor: Attention output of shape (B, C, H, W).
        """
        normed = self.norm(h_)
        batch, channels, height, width = normed.shape
        # Flatten spatial dims into a token axis with a single attention head.
        to_tokens = "b c h w -> b 1 (h w) c"
        query = rearrange(self.q(normed), to_tokens).contiguous()
        key = rearrange(self.k(normed), to_tokens).contiguous()
        value = rearrange(self.v(normed), to_tokens).contiguous()
        attended = nn.functional.scaled_dot_product_attention(query, key, value)
        return rearrange(attended, "b 1 (h w) c -> b c h w", h=height, w=width, c=channels, b=batch)

    def forward(self, x: Tensor) -> Tensor:
        """Residual self-attention: x + proj_out(attention(x)).

        Args:
            x (Tensor): Input feature map (B, C, H, W).

        Returns:
            Tensor: Output feature map (B, C, H, W).
        """
        return x + self.proj_out(self.attention(x))
class LinearAttention(nn.Module):
    """Linear Attention block for efficient attention computations.
    Uses linear attention mechanisms to reduce complexity and memory usage:
    by normalizing keys and contracting keys with values first, the cost is
    linear (rather than quadratic) in the number of spatial positions.
    """
    def __init__(self, dim, heads=4, dim_head=32):
        """
        Args:
            dim (int): Input channel dimension.
            heads (int, optional): Number of attention heads. Defaults to 4.
            dim_head (int, optional): Dimension per attention head. Defaults to 32.
        """
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        # Single 1x1 conv produces q, k, v stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
    def forward(self, x):
        """Forward pass of the LinearAttention block.
        Args:
            x (Tensor): Input feature map (B, C, H, W).
        Returns:
            Tensor: Output feature map after linear attention (B, C, H, W).
        """
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        # Split the stacked projection into q, k, v, each (B, heads, dim_head, H*W).
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        # Normalize keys over the spatial axis so the context is a weighted average.
        k = k.softmax(dim=-1)
        # context: (B, heads, dim_head, dim_head) — keys contracted with values
        # over spatial positions; computing this first avoids the (HW x HW) map.
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        # Apply the context to the (unnormalized) queries at every position.
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)
class LinAttnBlock(LinearAttention):
    """Adapter exposing LinearAttention through the common attention-block
    interface: a single head whose head dimension equals the channel count."""

    def __init__(self, in_channels):
        """
        Args:
            in_channels (int): Number of input/output channels.
        """
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
def make_attn(in_channels, attn_type="vanilla"):
    """Factory for attention blocks.

    Args:
        in_channels (int): Number of input/output channels.
        attn_type (str, optional): One of "vanilla", "linear", or "none".
            Defaults to "vanilla".

    Returns:
        nn.Module: The requested attention block ("none" yields an nn.Identity).
    """
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "none":
        return nn.Identity(in_channels)
    if attn_type == "linear":
        return LinAttnBlock(in_channels)
    return AttnBlock(in_channels)
# pylint: disable=C0116