# Source snapshot: revision df9529d (8,054 bytes).
import torch
from torch import nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
def rotate_half(x):
    """Rotate the last axis of ``x`` by half: ``[a, b] -> [-b, a]``.

    The split point is ``x.shape[-1] // 2``, so for an odd-sized last axis the
    leading part ``a`` is the shorter one.
    """
    split_at = x.shape[-1] // 2
    front, back = x.split([split_at, x.shape[-1] - split_at], dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embedding to ``x``.

    Args:
        x: Tensor to rotate; its last axis is the rotated feature axis.
        cos, sin: Tables to combine with ``x``. Each gets a new size-1 axis
            inserted at ``unsqueeze_dim`` so it broadcasts against ``x``
            (with the default ``1`` this is the head axis of a
            ``(batch, heads, seq, dim)`` tensor — assumed layout, confirm at call site).
        unsqueeze_dim: Position at which the broadcast axis is inserted.

    Returns:
        ``x * cos + rotate_half(x) * sin``.
    """
    cos_b = torch.unsqueeze(cos, unsqueeze_dim)
    sin_b = torch.unsqueeze(sin, unsqueeze_dim)
    rotated = rotate_half(x)
    return (x * cos_b) + (rotated * sin_b)
class RotaryEmbedding(nn.Module):
    """Generate rotary position embedding (RoPE) ``cos``/``sin`` tables.

    Args:
        head_dim: Per-head feature size. The inverse-frequency table holds
            ``ceil(head_dim / 2)`` entries, ``1 / rope_theta ** (2i / head_dim)``.
        rope_theta: Base of the geometric frequency progression. Defaults to
            ``10000``, the value that was previously hard-coded, so existing
            callers are unaffected.
    """

    def __init__(self, head_dim, rope_theta=10000):
        super().__init__()
        self.rope_theta = rope_theta
        exponents = torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim
        inv_freq = 1.0 / (self.rope_theta ** exponents)
        # Non-persistent: rebuilt from the constructor arguments, not saved in state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` tables for ``position_ids``.

        Args:
            x: Used only for its ``device`` and ``dtype``.
            position_ids: Positions of shape ``(batch, seq_len)``.

        Returns:
            ``(cos, sin)``, each of shape ``(batch, seq_len, 2 * len(inv_freq))``
            and cast to ``x.dtype``.
        """
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # autocast is not usable with device_type "mps"; fall back to "cpu" for the context manager.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # Compute the angles in float32 even under mixed precision.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate so each table matches the [first half, second half] layout of rotate_half.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Attention(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, head_dim):
super().__init__()
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.q_proj = nn.Linear(query_dim, inner_dim, bias=False)
self.q_norm = nn.RMSNorm(head_dim, eps=1e-6)
self.k_proj = nn.Linear(context_dim, inner_dim, bias=False)
self.k_norm = nn.RMSNorm(head_dim, eps=1e-6)
self.v_proj = nn.Linear(context_dim, inner_dim, bias=False)
self.o_proj = nn.Linear(inner_dim, query_dim, bias=False)
def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
context = x if context is None else context
input_shape = x.shape[:-1]
q_shape = (*input_shape, self.n_heads, self.head_dim)
context_shape = context.shape[:-1]
kv_shape = (*context_shape, self.n_heads, self.head_dim)
query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
if position_embeddings is not None:
assert position_embeddings_context is not None
cos, sin = position_embeddings
query_states = apply_rotary_pos_emb(query_states, cos, sin)
cos, sin = position_embeddings_context
key_states = apply_rotary_pos_emb(key_states, cos, sin)
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
return self.o_proj(attn_output)
class TransformerBlock(nn.Module):
    """Pre-norm block: optional self-attention, then cross-attention, then an MLP.

    Every sub-layer reads an RMSNorm-ed copy of the running state and adds its
    result back residually.

    Args:
        source_dim: Feature size of the cross-attention ``context``.
        model_dim: Feature size of the running state ``x``.
        num_heads: Attention heads; each head uses ``model_dim // num_heads`` features.
        mlp_ratio: Hidden width of the MLP as a multiple of ``model_dim``.
        use_self_attn: Whether to include the self-attention sub-layer.
    """

    def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=True):
        super().__init__()
        per_head = model_dim // num_heads
        hidden_width = int(model_dim * mlp_ratio)
        self.use_self_attn = use_self_attn
        # Submodules are created in a fixed order: it determines parameter init and state_dict layout.
        if use_self_attn:
            self.norm_self_attn = nn.RMSNorm(model_dim, eps=1e-6)
            self.self_attn = Attention(
                query_dim=model_dim, context_dim=model_dim, n_heads=num_heads, head_dim=per_head
            )
        self.norm_cross_attn = nn.RMSNorm(model_dim, eps=1e-6)
        self.cross_attn = Attention(
            query_dim=model_dim, context_dim=source_dim, n_heads=num_heads, head_dim=per_head
        )
        self.norm_mlp = nn.RMSNorm(model_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(model_dim, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, model_dim),
        )

    def forward(
        self,
        x,
        context,
        target_attention_mask=None,
        source_attention_mask=None,
        position_embeddings=None,
        position_embeddings_context=None,
    ):
        """Run the block on ``x`` conditioned on ``context``.

        ``target_attention_mask`` masks self-attention and ``source_attention_mask``
        masks cross-attention. ``position_embeddings`` rotates the queries (and the
        keys in self-attention); ``position_embeddings_context`` rotates the
        cross-attention keys.
        """
        hidden = x
        if self.use_self_attn:
            hidden = hidden + self.self_attn(
                self.norm_self_attn(hidden),
                mask=target_attention_mask,
                position_embeddings=position_embeddings,
                position_embeddings_context=position_embeddings,
            )
        hidden = hidden + self.cross_attn(
            self.norm_cross_attn(hidden),
            mask=source_attention_mask,
            context=context,
            position_embeddings=position_embeddings,
            position_embeddings_context=position_embeddings_context,
        )
        return hidden + self.mlp(self.norm_mlp(hidden))
class AnimaLLMAdapter(ModelMixin, ConfigMixin):
    """Map target token ids to ``target_dim`` features, cross-attending to source hidden states.

    Pipeline:
    1. Embed the tokens and project them to ``model_dim`` (identity if equal).
    2. Apply ``num_layers`` transformer blocks, with RoPE on queries and keys.
    3. Project to ``target_dim`` and RMS-normalise.

    Args:
        source_dim: Feature size of ``source_hidden_states``.
        target_dim: Token-embedding size and output feature size.
        model_dim: Internal width of the transformer blocks.
        num_layers: Number of ``TransformerBlock`` layers.
        num_heads: Attention heads per layer.
        mlp_ratio: MLP hidden width as a multiple of ``model_dim``.
        vocab_size: Size of the target token embedding table.
        use_self_attn: Whether blocks include a self-attention sub-layer.
    """

    @register_to_config
    def __init__(
        self,
        source_dim: int = 1024,
        target_dim: int = 1024,
        model_dim: int = 1024,
        num_layers: int = 6,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        vocab_size: int = 32128,
        use_self_attn: bool = True,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, target_dim)
        if model_dim != target_dim:
            self.in_proj = nn.Linear(target_dim, model_dim)
        else:
            self.in_proj = nn.Identity()
        self.rotary_emb = RotaryEmbedding(model_dim // num_heads)
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    source_dim,
                    model_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    use_self_attn=use_self_attn,
                )
                for _ in range(num_layers)
            ]
        )
        self.out_proj = nn.Linear(model_dim, target_dim)
        self.norm = nn.RMSNorm(target_dim, eps=1e-6)

    @staticmethod
    def _prepare_attention_mask(mask):
        """Cast ``mask`` to bool and give it a head axis for scaled_dot_product_attention.

        ``None`` passes through. A 2-D ``(batch, kv_len)`` mask becomes
        ``(batch, 1, 1, kv_len)``. A 3-D ``(batch, q_len, kv_len)`` mask becomes
        ``(batch, 1, q_len, kv_len)``. Other ranks are only cast. In the bool
        mask, ``True`` marks positions that may be attended to (SDPA semantics).
        """
        if mask is None:
            return None
        mask = mask.to(torch.bool)
        if mask.ndim == 2:
            return mask[:, None, None, :]
        if mask.ndim == 3:
            # Without an explicit head axis, SDPA would broadcast the batch axis against heads.
            return mask.unsqueeze(1)
        return mask

    def forward(
        self,
        source_hidden_states: torch.Tensor,
        target_input_ids: torch.Tensor,
        target_attention_mask: torch.Tensor = None,
        source_attention_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the adapter.

        Args:
            source_hidden_states: ``(batch, source_len, source_dim)`` states to cross-attend to.
            target_input_ids: ``(batch, target_len)`` token ids for the embedding table.
            target_attention_mask: Optional mask for self-attention keys, 2-D or 3-D as
                accepted by ``_prepare_attention_mask`` (or already 4-D).
            source_attention_mask: Optional mask over source positions for cross-attention.

        Returns:
            ``(batch, target_len, target_dim)`` tensor.
        """
        target_attention_mask = self._prepare_attention_mask(target_attention_mask)
        source_attention_mask = self._prepare_attention_mask(source_attention_mask)

        x = self.in_proj(self.embed(target_input_ids))
        context = source_hidden_states

        # Positions 0..len-1; the leading axis of size 1 broadcasts over the batch.
        position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
        position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
        # Both tables take device/dtype from `x`, the tensor the keys are projected into.
        position_embeddings = self.rotary_emb(x, position_ids)
        position_embeddings_context = self.rotary_emb(x, position_ids_context)

        for block in self.blocks:
            x = block(
                x,
                context,
                target_attention_mask=target_attention_mask,
                source_attention_mask=source_attention_mask,
                position_embeddings=position_embeddings,
                position_embeddings_context=position_embeddings_context,
            )
        return self.norm(self.out_proj(x))
|