# Anima-sdnext-diffusers / llm_adapter / modeling_llm_adapter.py
import torch
from torch import nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
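

# Rotary position embedding (RoPE) helpers: `rotate_half` negates and swaps the two
# halves of the last dimension, and `apply_rotary_pos_emb` mixes a tensor with
# precomputed cos/sin tables so attention scores depend on relative position.
# `unsqueeze_dim=1` broadcasts the tables over the head axis of
# (batch, heads, seq, head_dim) tensors.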
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)
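

# Precomputes cos/sin tables for RoPE from per-position ids. Frequencies follow the
# standard inverse-theta schedule (theta = 10000); the tables are computed in float32
# with autocast disabled for numerical stability, then cast back to the input dtype.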
class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        self.rope_theta = 10000
        inv_freq = 1.0 / (
            self.rope_theta
            ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
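

# Multi-head attention used for both self- and cross-attention: when `context` is
# None the input attends to itself. Queries and keys are RMS-normalized per head
# (QK-norm) before optional RoPE, and attention is computed with torch's
# scaled_dot_product_attention.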
class Attention(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, head_dim):
        super().__init__()
        inner_dim = head_dim * n_heads
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.q_proj = nn.Linear(query_dim, inner_dim, bias=False)
        self.q_norm = nn.RMSNorm(head_dim, eps=1e-6)
        self.k_proj = nn.Linear(context_dim, inner_dim, bias=False)
        self.k_norm = nn.RMSNorm(head_dim, eps=1e-6)
        self.v_proj = nn.Linear(context_dim, inner_dim, bias=False)
        self.o_proj = nn.Linear(inner_dim, query_dim, bias=False)

    def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
        context = x if context is None else context
        input_shape = x.shape[:-1]
        q_shape = (*input_shape, self.n_heads, self.head_dim)
        context_shape = context.shape[:-1]
        kv_shape = (*context_shape, self.n_heads, self.head_dim)
        query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
        value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
        if position_embeddings is not None:
            assert position_embeddings_context is not None
            cos, sin = position_embeddings
            query_states = apply_rotary_pos_emb(query_states, cos, sin)
            cos, sin = position_embeddings_context
            key_states = apply_rotary_pos_emb(key_states, cos, sin)
        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
        attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
        return self.o_proj(attn_output)
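

# Pre-norm transformer block: optional self-attention over the target sequence,
# cross-attention from the target (queries) to the source hidden states
# (keys/values), and a GELU MLP, each wrapped in a residual connection.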
class TransformerBlock(nn.Module):
    def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=True):
        super().__init__()
        self.use_self_attn = use_self_attn
        if self.use_self_attn:
            self.norm_self_attn = nn.RMSNorm(model_dim, eps=1e-6)
            self.self_attn = Attention(
                query_dim=model_dim,
                context_dim=model_dim,
                n_heads=num_heads,
                head_dim=model_dim // num_heads,
            )
        self.norm_cross_attn = nn.RMSNorm(model_dim, eps=1e-6)
        self.cross_attn = Attention(
            query_dim=model_dim,
            context_dim=source_dim,
            n_heads=num_heads,
            head_dim=model_dim // num_heads,
        )
        self.norm_mlp = nn.RMSNorm(model_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(model_dim, int(model_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(model_dim * mlp_ratio), model_dim),
        )

    def forward(
        self,
        x,
        context,
        target_attention_mask=None,
        source_attention_mask=None,
        position_embeddings=None,
        position_embeddings_context=None,
    ):
        if self.use_self_attn:
            normed = self.norm_self_attn(x)
            attn_out = self.self_attn(
                normed,
                mask=target_attention_mask,
                position_embeddings=position_embeddings,
                position_embeddings_context=position_embeddings,
            )
            x = x + attn_out
        normed = self.norm_cross_attn(x)
        attn_out = self.cross_attn(
            normed,
            mask=source_attention_mask,
            context=context,
            position_embeddings=position_embeddings,
            position_embeddings_context=position_embeddings_context,
        )
        x = x + attn_out
        x = x + self.mlp(self.norm_mlp(x))
        return x
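

# Adapter model: re-embeds target token ids and refines them against hidden states
# from a source model (e.g. a text encoder) through a stack of cross-attention
# blocks, projecting the result back to `target_dim`. Registered with diffusers'
# ConfigMixin / ModelMixin so it can be saved and loaded like other diffusers models.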
class AnimaLLMAdapter(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        source_dim: int = 1024,
        target_dim: int = 1024,
        model_dim: int = 1024,
        num_layers: int = 6,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        vocab_size: int = 32128,
        use_self_attn: bool = True,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, target_dim)
        if model_dim != target_dim:
            self.in_proj = nn.Linear(target_dim, model_dim)
        else:
            self.in_proj = nn.Identity()
        self.rotary_emb = RotaryEmbedding(model_dim // num_heads)
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    source_dim,
                    model_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    use_self_attn=use_self_attn,
                )
                for _ in range(num_layers)
            ]
        )
        self.out_proj = nn.Linear(model_dim, target_dim)
        self.norm = nn.RMSNorm(target_dim, eps=1e-6)

    def forward(
        self,
        source_hidden_states: torch.Tensor,
        target_input_ids: torch.Tensor,
        target_attention_mask: torch.Tensor = None,
        source_attention_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        if target_attention_mask is not None:
            target_attention_mask = target_attention_mask.to(torch.bool)
            if target_attention_mask.ndim == 2:
                target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
        if source_attention_mask is not None:
            source_attention_mask = source_attention_mask.to(torch.bool)
            if source_attention_mask.ndim == 2:
                source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
        x = self.in_proj(self.embed(target_input_ids))
        context = source_hidden_states
        position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
        position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
        position_embeddings = self.rotary_emb(x, position_ids)
        position_embeddings_context = self.rotary_emb(x, position_ids_context)
        for block in self.blocks:
            x = block(
                x,
                context,
                target_attention_mask=target_attention_mask,
                source_attention_mask=source_attention_mask,
                position_embeddings=position_embeddings,
                position_embeddings_context=position_embeddings_context,
            )
        return self.norm(self.out_proj(x))
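

# Minimal smoke test added for illustration; it is not part of the uploaded
# checkpoint's code. Batch size, sequence lengths, and token ids below are dummies
# chosen to match the default config (source_dim=1024, vocab_size=32128).
if __name__ == "__main__":
    adapter = AnimaLLMAdapter()
    batch, src_len, tgt_len = 2, 77, 64
    source_hidden_states = torch.randn(batch, src_len, 1024)
    target_input_ids = torch.randint(0, 32128, (batch, tgt_len))
    source_attention_mask = torch.ones(batch, src_len, dtype=torch.long)
    target_attention_mask = torch.ones(batch, tgt_len, dtype=torch.long)
    out = adapter(
        source_hidden_states,
        target_input_ids,
        target_attention_mask=target_attention_mask,
        source_attention_mask=source_attention_mask,
    )
    print(out.shape)  # expected: torch.Size([2, 64, 1024])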