SteEsp's picture
Add Docker-based Learn2Splat demo (viser GUI)
78d2329 verified
import torch
import torch.nn as nn
from einops import rearrange
import warnings
import torch.nn.functional as F
# Opt-in switch for the FlashAttention 3 kernel. The attention layers in this
# module only take the FA3 path when this switch AND FA3_AVAILABLE are true.
USE_FLASH_ATTENTION3 = True

# Probe the optional fast-attention backends. Each probe records its outcome in
# a module-level *_AVAILABLE flag and emits a warning describing the result.
try:
    from flash_attn_interface import flash_attn_func
except ImportError:
    FA3_AVAILABLE = False
    warnings.warn('flash attention 3 is not available (LVSM)')
else:
    FA3_AVAILABLE = True
    warnings.warn('flash attention 3 is available (LVSM)')

try:
    import xformers.ops as xops
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn('xformers is not available (LVSM)')
else:
    XFORMERS_AVAILABLE = True
def init_weights(module, std=0.02):
    """Re-initialise ``nn.Linear`` / ``nn.Embedding`` parameters in place.

    Weights are drawn from N(0, std^2); a Linear bias, when present, is zeroed.
    Modules of any other type are left untouched, so the function can be passed
    to ``nn.Module.apply``.

    Args:
        module: Module whose parameters should be (re)initialised.
        std: Standard deviation of the zero-mean normal used for the weights.
    """
    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return
    torch.nn.init.normal_(module.weight, mean=0.0, std=std)
    # Embeddings have no bias; Linear layers may have been built with bias=False.
    linear_with_bias = isinstance(module, nn.Linear) and module.bias is not None
    if linear_with_bias:
        torch.nn.init.zeros_(module.bias)
# src: https://github.com/pytorch/benchmark/blob/main/torchbenchmark/models/llama/model.py#L28
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``, where ``weight`` is a
    learnable per-channel gain initialised to ones. The normalisation itself is
    evaluated on ``x.float()`` and the result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        gain = self.weight.type_as(x)
        normalised = self._norm(x.float()).type_as(x)
        return normalised * gain
class MLP(nn.Module):
    """
    Two-layer feed-forward block: Linear -> activation -> Linear -> Dropout.

    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L49-L65
    """

    def __init__(
        self,
        dim,
        mlp_ratio=4,
        bias=False,
        dropout=0.0,
        activation=nn.GELU,
        mlp_dim=None,
    ):
        """
        Args:
            dim: Input and output feature dimension.
            mlp_ratio: Hidden width as a multiple of ``dim`` (truncated to int).
            bias: Whether both linear layers carry a bias term.
            dropout: Dropout probability applied after the second linear layer.
            activation: Activation class, instantiated with no arguments.
            mlp_dim: Explicit hidden width; when given it takes precedence over
                ``mlp_ratio``.
        """
        super().__init__()
        if mlp_dim is None:
            width = int(dim * mlp_ratio)
        else:
            width = mlp_dim
        # Kept as one nn.Sequential named `mlp` so parameter names stay
        # mlp.0.* / mlp.2.* for checkpoint compatibility.
        layers = [
            nn.Linear(dim, width, bias=bias),
            activation(),
            nn.Linear(width, dim, bias=bias),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
class QK_Norm_SelfAttention(nn.Module):
    """
    Multi-head self-attention with optional Q-K (RMS) normalization.

    The attention kernel is picked per call, in this order:
      1. FlashAttention 3 (``flash_attn_func``), when USE_FLASH_ATTENTION3 and
         FA3_AVAILABLE are set and the call needs neither an attention bias nor
         attention dropout (this path is invoked without bias/dropout arguments);
      2. xformers ``memory_efficient_attention``, when XFORMERS_AVAILABLE;
      3. PyTorch ``F.scaled_dot_product_attention`` otherwise.

    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L68-L92
    """

    def __init__(
        self,
        dim,
        head_dim,
        qkv_bias=False,
        fc_bias=True,
        attn_dropout=0.0,
        fc_dropout=0.0,
        use_qk_norm=True,
    ):
        """
        Args:
            dim: Input dimension (must be divisible by ``head_dim``).
            head_dim: Dimension of each attention head; heads = dim // head_dim.
            qkv_bias: Whether to use bias in the fused QKV projection.
            fc_bias: Whether to use bias in the output projection.
            attn_dropout: Dropout probability on attention weights (training only).
            fc_dropout: Dropout probability after the output projection.
            use_qk_norm: Whether to RMS-normalize queries and keys per head.
        """
        super().__init__()
        assert dim % head_dim == 0, f"Token dimension {dim} should be divisible by head dimension {head_dim}"
        self.dim = dim
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.attn_dropout = attn_dropout
        self.use_qk_norm = use_qk_norm
        self.to_qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.fc = nn.Linear(dim, dim, bias=fc_bias)
        self.attn_fc_dropout = nn.Dropout(fc_dropout)
        # Optional Q-K normalization
        if self.use_qk_norm:
            self.q_norm = RMSNorm(head_dim)
            self.k_norm = RMSNorm(head_dim)

    def forward(self, x, attn_bias=None):
        """
        Args:
            x: Input tensor of shape (batch, seq_len, dim)
            attn_bias: Optional attention bias. Forwarded to xformers as
                ``attn_bias`` or to PyTorch SDPA as ``attn_mask``; its presence
                disables the FlashAttention 3 path, which cannot apply it.
        Returns:
            Output tensor of shape (batch, seq_len, dim)
        """
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.head_dim) for t in (q, k, v))
        # Apply qk normalization if enabled
        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        dropout_p = self.attn_dropout if self.training else 0.0
        use_fa3 = (
            USE_FLASH_ATTENTION3
            and FA3_AVAILABLE
            and attn_bias is None
            and dropout_p == 0.0
        )
        if use_fa3:
            out = flash_attn_func(q, k, v)
            # Some flash_attn_interface releases return (out, softmax_lse),
            # others only `out`. Indexing [0] on a bare tensor would silently
            # keep just the first batch element, so unwrap only tuples.
            x = out[0] if isinstance(out, tuple) else out
        elif XFORMERS_AVAILABLE:
            x = xops.memory_efficient_attention(
                q, k, v,
                attn_bias=attn_bias,
                p=dropout_p,
                op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
            )
        else:
            # PyTorch's built-in attention works on [B, H, L, C].
            q = q.permute(0, 2, 1, 3).contiguous()
            k = k.permute(0, 2, 1, 3).contiguous()
            v = v.permute(0, 2, 1, 3).contiguous()
            x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=dropout_p)
            x = x.permute(0, 2, 1, 3).contiguous()  # back to [B, L, H, C]
        x = rearrange(x, "b l nh dh -> b l (nh dh)")
        x = self.attn_fc_dropout(self.fc(x))
        return x
class SubsetAttention(nn.Module):
    """Self-attention that can restrict either the key/value set or the query set.

    Uses xformers ``memory_efficient_attention`` when XFORMERS_AVAILABLE, and
    falls back to PyTorch ``F.scaled_dot_product_attention`` otherwise.
    """

    def __init__(
        self,
        dim,
        head_dim,
        qkv_bias=False,
        attn_dropout=0.0,
        fc_bias=False,
        fc_dropout=0.0,
        use_qk_norm=False
    ):
        """
        Args:
            dim: Input dimension (must be divisible by ``head_dim``).
            head_dim: Dimension of each attention head; heads = dim // head_dim.
            qkv_bias: Whether to use bias in the fused QKV projection.
            attn_dropout: Dropout probability on attention weights (training only).
            fc_bias: Whether to use bias in the output projection.
            fc_dropout: Dropout probability after the output projection.
            use_qk_norm: Whether to RMS-normalize queries and keys per head.
        """
        super().__init__()
        assert dim % head_dim == 0, f"Token dimension {dim} should be divisible by head dimension {head_dim}"
        self.dim = dim
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.attn_dropout = attn_dropout
        self.use_qk_norm = use_qk_norm
        # Projections
        self.to_qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.fc = nn.Linear(dim, dim, bias=fc_bias)
        self.attn_fc_dropout = nn.Dropout(fc_dropout)
        # Optional Q-K normalization
        if self.use_qk_norm:
            self.q_norm = RMSNorm(head_dim)
            self.k_norm = RMSNorm(head_dim)

    def _attend(self, q, k, v):
        """Unmasked multi-head attention on (batch, len, heads, head_dim) tensors.

        Returns a tensor shaped like ``q``. Attention dropout is applied only in
        training mode.
        """
        dropout_p = self.attn_dropout if self.training else 0.0
        if XFORMERS_AVAILABLE:
            return xops.memory_efficient_attention(
                q, k, v,
                attn_bias=None,
                p=dropout_p,
                op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
            )
        # Without xformers, use PyTorch SDPA, which works on [B, H, L, C].
        q, k, v = (t.permute(0, 2, 1, 3).contiguous() for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        return out.permute(0, 2, 1, 3).contiguous()

    def forward(self, x, subset_kv_size=None, subset_q_size=None):
        """
        Args:
            x: Input tensor of shape (batch, seq_len, dim)
            subset_kv_size: If provided and < seq_len, keys/values are restricted
                to tokens from index ``subset_kv_size`` onward (the first
                ``subset_kv_size`` tokens are not attended to). Otherwise all
                tokens are used.
            subset_q_size: If provided and < seq_len, attention is computed only
                for the first ``subset_q_size`` query tokens.
        Returns:
            Tensor of shape (batch, subset_q_size, dim) when the query subset is
            active, otherwise (batch, seq_len, dim).
        """
        # Only one subset parameter can be provided
        assert not (subset_kv_size is not None and subset_q_size is not None), \
            "Only one of subset_kv_size or subset_q_size can be provided"
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.head_dim) for t in (q, k, v))
        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # Narrow the key/value set or the query set; otherwise plain attention.
        if subset_kv_size is not None and subset_kv_size < k.shape[1]:
            k = k[:, subset_kv_size:, :, :].contiguous()
            v = v[:, subset_kv_size:, :, :].contiguous()
        elif subset_q_size is not None and subset_q_size < q.shape[1]:
            q = q[:, :subset_q_size, :, :].contiguous()

        x = self._attend(q, k, v)
        x = rearrange(x, "b l nh dh -> b l (nh dh)")
        # Final projection
        x = self.attn_fc_dropout(self.fc(x))
        return x
class QK_Norm_TransformerBlock(nn.Module):
    """
    Pre-norm transformer block: ``x + Attn(LN1(x))`` then ``x + MLP(LN2(x))``.

    Attention is a QK_Norm_SelfAttention layer and the feed-forward part is an
    MLP; both LayerNorms share the ``ln_bias`` setting.

    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L95-L113
    """

    def __init__(
        self,
        dim,
        head_dim,
        ln_bias=False,
        attn_qkv_bias=False,
        attn_dropout=0.0,
        attn_fc_bias=False,
        attn_fc_dropout=0.0,
        mlp_ratio=4,
        mlp_bias=False,
        mlp_dropout=0.0,
        use_qk_norm=True,
    ):
        super().__init__()
        # Sub-modules are built in this exact order so parameter registration
        # (and random initialisation) stays stable: norm1, attn, norm2, mlp.
        self.norm1 = nn.LayerNorm(dim, bias=ln_bias)
        self.attn = QK_Norm_SelfAttention(
            dim=dim,
            head_dim=head_dim,
            qkv_bias=attn_qkv_bias,
            fc_bias=attn_fc_bias,
            attn_dropout=attn_dropout,
            fc_dropout=attn_fc_dropout,
            use_qk_norm=use_qk_norm,
        )
        self.norm2 = nn.LayerNorm(dim, bias=ln_bias)
        self.mlp = MLP(dim=dim, mlp_ratio=mlp_ratio, bias=mlp_bias, dropout=mlp_dropout)

    def forward(self, x):
        attn_residual = self.attn(self.norm1(x))
        x = x + attn_residual
        mlp_residual = self.mlp(self.norm2(x))
        return x + mlp_residual