| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
import torch
import torch.nn as nn

from .modulate import ModLN
|
|
|
|
class BasicBlock(nn.Module):
    """
    Transformer block that is in its simplest form.
    Designed for PF-LRM architecture.

    Pre-norm residual layout: self-attention followed by an MLP, each
    preceded by a LayerNorm and added back onto the residual stream.

    Args:
        inner_dim: Channel width of the token stream.
        num_heads: Number of attention heads.
        eps: Epsilon passed to both LayerNorm layers.
        attn_drop: Dropout probability inside the attention module.
        attn_bias: Whether the attention projections carry bias terms.
        mlp_ratio: Hidden width of the MLP as a multiple of ``inner_dim``
            (truncated to an integer).
        mlp_drop: Dropout probability after each MLP linear layer.
    """

    def __init__(
        self,
        inner_dim: int,
        num_heads: int,
        eps: float,
        attn_drop: float = 0.0,
        attn_bias: bool = False,
        mlp_ratio: float = 4.0,
        mlp_drop: float = 0.0,
    ) -> None:
        super().__init__()
        # Computed once so both MLP linears are guaranteed to agree.
        hidden_dim = int(inner_dim * mlp_ratio)

        # NOTE: sub-modules are registered in the same order as before so
        # that state_dict layout and parameter-init RNG order are unchanged.
        self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )
        self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(hidden_dim, inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention and the MLP, each with a residual connection.

        Args:
            x: Tokens laid out batch-first, i.e. ``(batch, seq_len, inner_dim)``
                (the attention module is built with ``batch_first=True``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Normalize once and reuse as query, key and value.
        before_sa = self.norm1(x)
        # need_weights=False: only the attended output ([0]) is used.
        x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x
|
|
|
|
class ConditionBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition.
    Designed for SparseLRM architecture.

    Pre-norm residual layout: cross-attention onto the condition, then
    self-attention, then an MLP. Each sub-block is preceded by a LayerNorm
    on the token stream and added back onto the residual.

    Args:
        inner_dim: Channel width of the token stream.
        cond_dim: Channel width of the condition (used as key/value width
            of the cross-attention).
        num_heads: Number of attention heads (shared by both attentions).
        eps: Epsilon passed to all LayerNorm layers.
        attn_drop: Dropout probability inside both attention modules.
        attn_bias: Whether the attention projections carry bias terms.
        mlp_ratio: Hidden width of the MLP as a multiple of ``inner_dim``
            (truncated to an integer).
        mlp_drop: Dropout probability after each MLP linear layer.
    """

    def __init__(
        self,
        inner_dim: int,
        cond_dim: int,
        num_heads: int,
        eps: float,
        attn_drop: float = 0.0,
        attn_bias: bool = False,
        mlp_ratio: float = 4.0,
        mlp_drop: float = 0.0,
    ) -> None:
        super().__init__()
        # Computed once so both MLP linears are guaranteed to agree.
        hidden_dim = int(inner_dim * mlp_ratio)

        # NOTE: sub-modules are registered in the same order as before so
        # that state_dict layout and parameter-init RNG order are unchanged.
        self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            kdim=cond_dim,
            vdim=cond_dim,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )
        self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )
        self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(hidden_dim, inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Apply cross-attention, self-attention and the MLP with residuals.

        Args:
            x: Tokens laid out batch-first, ``(batch, seq_len, inner_dim)``.
            cond: Condition tokens laid out batch-first,
                ``(batch, cond_len, cond_dim)``; used as both key and value
                of the cross-attention. Note that ``cond`` is not normalized
                inside this block.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # need_weights=False: only the attended output ([0]) is used.
        x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0]
        # Normalize once and reuse as query, key and value.
        before_sa = self.norm2(x)
        x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
        x = x + self.mlp(self.norm3(x))
        return x
|
|
|
|
class ConditionModulationBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
    Designed for raw LRM architecture.

    Same layout as a cross-attention / self-attention / MLP pre-norm block,
    except that every normalization is a ``ModLN`` that additionally receives
    the modulation input.

    Args:
        inner_dim: Channel width of the token stream.
        cond_dim: Channel width of the condition (used as key/value width
            of the cross-attention).
        mod_dim: Width of the modulation input, forwarded to ``ModLN``.
        num_heads: Number of attention heads (shared by both attentions).
        eps: Epsilon forwarded to every ``ModLN``.
        attn_drop: Dropout probability inside both attention modules.
        attn_bias: Whether the attention projections carry bias terms.
        mlp_ratio: Hidden width of the MLP as a multiple of ``inner_dim``
            (truncated to an integer).
        mlp_drop: Dropout probability after each MLP linear layer.
    """

    def __init__(
        self,
        inner_dim: int,
        cond_dim: int,
        mod_dim: int,
        num_heads: int,
        eps: float,
        attn_drop: float = 0.0,
        attn_bias: bool = False,
        mlp_ratio: float = 4.0,
        mlp_drop: float = 0.0,
    ) -> None:
        super().__init__()
        # Computed once so both MLP linears are guaranteed to agree.
        hidden_dim = int(inner_dim * mlp_ratio)

        # NOTE: sub-modules are registered in the same order as before so
        # that state_dict layout and parameter-init RNG order are unchanged.
        self.norm1 = ModLN(inner_dim, mod_dim, eps)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            kdim=cond_dim,
            vdim=cond_dim,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )
        self.norm2 = ModLN(inner_dim, mod_dim, eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )
        self.norm3 = ModLN(inner_dim, mod_dim, eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(hidden_dim, inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        """Apply modulated cross-attention, self-attention and MLP with residuals.

        Args:
            x: Tokens laid out batch-first, ``(batch, seq_len, inner_dim)``.
            cond: Condition tokens laid out batch-first,
                ``(batch, cond_len, cond_dim)``; used as both key and value
                of the cross-attention.
            mod: Modulation input passed unchanged to each ``ModLN``.
                NOTE(review): presumably ``(batch, mod_dim)`` - confirm
                against ``ModLN`` in ``.modulate``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # need_weights=False: only the attended output ([0]) is used.
        x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
        # Normalize once and reuse as query, key and value.
        before_sa = self.norm2(x, mod)
        x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
        x = x + self.mlp(self.norm3(x, mod))
        return x
|
|