File size: 6,265 Bytes
# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Vision Transformer."""
from typing import Tuple
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint as apply_ckpt
from diffnext.models.embeddings import PatchEmbed, RotaryEmbed3D
from diffnext.models.flex_attention import FlexAttentionCausal2D
class MLP(nn.Module):
    """Feed-forward network: linear expansion, GELU, linear projection.

    Args:
        dim: Feature size of the input and of the output.
        mlp_ratio: Multiplier applied to ``dim`` (then truncated with
            ``int``) to obtain the hidden width.
    """

    def __init__(self, dim, mlp_ratio=4):
        super().__init__()
        # Compute the hidden width once so both layers agree on it.
        hidden_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.activation = nn.GELU()

    def forward(self, x) -> torch.Tensor:
        """Return ``fc2(activation(fc1(x)))``."""
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        return self.fc2(hidden)
class Attention(nn.Module):
    """Multi-head self-attention with an optional key/value cache.

    Args:
        dim: Feature size of the input and of the output.
        num_heads: Number of heads; each head has ``dim // num_heads`` channels.
        qkv_bias: Whether the fused query/key/value projection has a bias.

    Runtime attributes (all initialised to ``None``):
        attn_mask: Passed as ``attn_mask`` to scaled dot-product attention.
        cache_kv: Falsy disables caching. A truthy non-list value arms the
            cache, which becomes the list ``[k, v]`` on the next call; later
            calls append new keys/values to it along the sequence axis.
        pe_func: Optional callable applied to the queries and to the keys.
        flex_attn: Optional callable used in place of scaled dot-product
            attention whenever its ``offsets`` attribute is truthy.
    """

    def __init__(self, dim, num_heads, qkv_bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.attn_mask = None
        self.cache_kv = None
        self.pe_func = None
        self.flex_attn = None

    def forward(self, x) -> torch.Tensor:
        """Attend over ``x`` and return the projected result.

        ``x`` is indexed as (batch, sequence, channels): ``x.size(1)`` is used
        as the sequence length when un-fusing the qkv projection.
        """
        seq_len = x.size(1)
        fused = self.qkv(x).view([-1, seq_len, 3, self.num_heads, self.head_dim])
        # (batch, seq, 3, heads, head_dim) -> (3, batch, heads, seq, head_dim).
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)
        if self.pe_func:
            q, k = self.pe_func(q), self.pe_func(k)
        if self.cache_kv:
            if isinstance(self.cache_kv, list):
                # Extend the cached keys/values along the sequence axis (dim 2).
                k = torch.cat([self.cache_kv[0], k], dim=2)
                v = torch.cat([self.cache_kv[1], v], dim=2)
                self.cache_kv[0], self.cache_kv[1] = k, v
            else:
                # First call after the cache was armed: start it from this step.
                self.cache_kv = [k, v]
        if self.flex_attn and self.flex_attn.offsets:
            attended = self.flex_attn(q, k, v)
        else:
            attended = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=self.attn_mask
            )
        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim).
        return self.proj(attended.transpose(1, 2).flatten(2))
class Block(nn.Module):
    """Transformer block with an attention branch and an MLP branch.

    Each branch output is layer-normalised and then summed with the branch
    input. Either branch can be run under activation checkpointing by setting
    ``attn_checkpointing`` / ``mlp_checkpointing`` to ``True``.

    Args:
        dim: Feature size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-width multiplier forwarded to ``MLP``.
        qkv_bias: Forwarded to ``Attention``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
        self.attn_checkpointing = False
        self.mlp_checkpointing = False

    def forward_attn(self, x) -> torch.Tensor:
        """Attention branch: attention first, then ``norm1``."""
        attended = self.attn(x)
        return self.norm1(attended)

    def forward_mlp(self, x) -> torch.Tensor:
        """MLP branch: MLP first, then ``norm2``."""
        projected = self.mlp(x)
        return self.norm2(projected)

    def forward_ckpt(self, x, name) -> torch.Tensor:
        """Run ``forward_<name>``, checkpointed when enabled.

        Checkpointing is applied only when ``<name>_checkpointing`` is truthy
        and ``x`` requires grad; otherwise the branch is called directly.
        """
        checkpointed = getattr(self, f"{name}_checkpointing", False) and x.requires_grad
        branch = getattr(self, f"forward_{name}")
        if not checkpointed:
            return branch(x)
        return apply_ckpt(branch, x, use_reentrant=False)

    def forward(self, x, pe_func: callable = None) -> torch.Tensor:
        """Apply both branches with residual connections.

        Args:
            x: Input tensor.
            pe_func: Stored on ``self.attn.pe_func`` before attention runs.
        """
        self.attn.pe_func = pe_func
        # The residual is accumulated in place on each branch's fresh output.
        attn_out = self.forward_ckpt(x, "attn")
        x = attn_out.add_(x)
        mlp_out = self.forward_ckpt(x, "mlp")
        return mlp_out.add_(x)
class VisionTransformer(nn.Module):
    """Vision transformer whose block stack is split into two stages.

    The first ``encoder_depth`` blocks may run on a gathered subset of the
    patch tokens; the remaining blocks always run on the full token sequence.

    Args:
        depth: Number of transformer blocks.
        embed_dim: Token feature size.
        num_heads: Attention heads per block.
        mlp_ratio: Hidden-width multiplier of each block's MLP.
        patch_size: Forwarded to ``PatchEmbed``.
        image_size: Stored on the module as ``image_size``.
        image_dim: Input channel count forwarded to ``PatchEmbed``.
        encoder_depth: Number of blocks in the first stage; defaults to half
            of ``depth`` (integer division).
    """

    def __init__(
        self,
        depth,
        embed_dim,
        num_heads,
        mlp_ratio=4,
        patch_size=2,
        image_size=32,
        image_dim=4,
        encoder_depth=None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.image_dim = image_dim
        self.patch_embed = PatchEmbed(image_dim, embed_dim, patch_size)
        self.pos_embed = nn.Identity()
        self.rope = RotaryEmbed3D(embed_dim // num_heads)
        self.blocks = nn.ModuleList(Block(embed_dim, num_heads, mlp_ratio) for _ in range(depth))
        self.norm = nn.LayerNorm(embed_dim)
        self.mixer = nn.Identity()
        if encoder_depth is None:
            encoder_depth = len(self.blocks) // 2
        self.encoder_depth = encoder_depth
        # A single flex-attention object is shared by every block's attention.
        self.flex_attn = FlexAttentionCausal2D()
        for blk in self.blocks:
            blk.attn.flex_attn = self.flex_attn

    def prepare_pe(self, c=None, ids=None, pos=None) -> Tuple[callable, callable]:
        """Build the positional functions for the two block stages.

        Args:
            c: Optional tensor; its ``size(1)`` is passed to ``rope.get_func``
                as the pad argument (0 when ``c`` is ``None``).
            ids: Optional index tensor; when given, the first-stage function
                is built with ``ids`` expanded to 3 entries on the last axis.
            pos: Passed through to ``rope.get_func``.

        Returns:
            ``(pe1, pe2)`` for the first and second stage respectively. They
            are the same object when ``ids`` is ``None``.
        """
        pad = c.size(1) if c is not None else 0
        pe2 = self.rope.get_func(pos, pad)
        if ids is None:
            return pe2, pe2
        pe1 = self.rope.get_func(pos, pad, ids.expand(-1, -1, 3))
        return pe1, pe2

    def enable_kvcache(self, mode=True):
        """Assign ``mode`` to ``cache_kv`` on every block's attention."""
        for blk in self.blocks:
            blk.attn.cache_kv = mode

    def forward(self, x, c=None, prev_ids=None, pos=None) -> torch.Tensor:
        """Embed ``x`` and run it through both block stages.

        Args:
            x: Input for ``patch_embed``, or a ``(input, prev_ids)`` tuple/list.
            c: Optional tensor concatenated in front of the tokens along
                dim 1; its positions are removed from the returned tensor.
            prev_ids: Optional index tensor used to gather tokens along dim 1
                for the first stage and to scatter the results back
                afterwards. Ignored when ``encoder_depth`` is falsy.
            pos: Passed to ``prepare_pe``; when ``None`` the blocks receive
                no positional function.

        Returns:
            The output of ``self.norm`` on the tokens, excluding ``c``.
        """
        if isinstance(x, (tuple, list)):
            x, prev_ids = x
        if not self.encoder_depth:
            prev_ids = None
        x = x_masked = self.pos_embed(self.patch_embed(x))
        if pos is None:
            pe1 = pe2 = None
        else:
            pe1, pe2 = self.prepare_pe(c, prev_ids, pos)
        has_cond = c is not None
        gathered = prev_ids is not None
        if gathered:
            # Keep only the tokens addressed by prev_ids for the first stage.
            prev_ids = prev_ids.expand(-1, -1, x.size(-1))
            x = x.gather(1, prev_ids)
        if has_cond:
            x = torch.cat([c, x], dim=1)
        for blk in self.blocks[: self.encoder_depth]:
            x = blk(x, pe1)
        if gathered:
            if has_cond:
                # Peel the processed c back off before scattering the tokens.
                c, x = x.split((c.size(1), x.size(1) - c.size(1)), dim=1)
            # Write the first-stage results back into the full token sequence.
            x = x_masked.to(dtype=x.dtype).scatter(1, prev_ids, x)
            if has_cond:
                x = torch.cat([c, x], dim=1)
        for blk in self.blocks[self.encoder_depth :]:
            x = blk(x, pe2)
        if has_cond:
            x = x[:, c.size(1) :]
        return self.norm(x)
|