|
|
from transformers import PreTrainedModel |
|
|
from .configuration_vitmix import ViTMixConfig |
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from einops import rearrange |
|
|
from einops.layers.torch import Rearrange |
|
|
|
|
|
from st_moe_pytorch import SparseMoEBlock, MoE |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pair(t):
    """Duplicate a scalar into a 2-tuple; pass tuples through unchanged."""
    if isinstance(t, tuple):
        return t
    return (t, t)
|
|
|
|
|
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    """Build a fixed 2-D sine/cosine position embedding of shape (h*w, dim).

    Each grid position gets dim/4 frequencies applied to its x and y
    coordinates; sin and cos of each are concatenated along the feature axis.
    NOTE(review): dim == 4 makes (dim // 4 - 1) zero and divides by zero —
    same edge case as the original implementation.
    """
    grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    quarter = dim // 4
    freqs = torch.arange(quarter) / (quarter - 1)
    freqs = 1.0 / (temperature ** freqs)

    ys = grid_y.flatten().unsqueeze(1) * freqs.unsqueeze(0)
    xs = grid_x.flatten().unsqueeze(1) * freqs.unsqueeze(0)
    pe = torch.cat((xs.sin(), xs.cos(), ys.sin(), ys.cos()), dim=1)
    return pe.type(dtype)
|
|
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Pre-norm two-layer MLP: LayerNorm -> Linear -> GELU -> Linear."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        # Pure feedforward; the residual connection is applied by the caller.
        return self.net(x)
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with pre-norm and no projection biases."""

    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        # Standard 1/sqrt(d_head) scaling of the dot products.
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        b, n, _ = x.shape
        h = self.heads

        x = self.norm(x)

        # Single projection produces q, k, v; split heads out of the feature dim.
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (t.view(b, n, h, -1).transpose(1, 2) for t in (q, k, v))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)

        out = torch.matmul(weights, v)
        # Merge heads back: (b, h, n, d) -> (b, n, h*d).
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
|
|
|
|
|
class Transformer(nn.Module):
    """Pre-norm transformer that alternates feedforward types per layer.

    Even-indexed layers pair attention with a dense FeedForward; odd-indexed
    layers pair attention with a SparseMoEBlock (st_moe_pytorch). A final
    LayerNorm is applied after the stack.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_experts):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for layer_idx in range(depth):
            if layer_idx % 2 == 0:
                ff = FeedForward(dim, mlp_dim)
            else:
                ff = SparseMoEBlock(
                    MoE(
                        dim = dim,
                        num_experts = num_experts,
                        gating_top_n = 2,
                        threshold_train = 0.2,
                        threshold_eval = 0.2,
                        capacity_factor_train = 1.25,
                        capacity_factor_eval = 2.,
                        balance_loss_coef = 1e-2,
                        router_z_loss_coef = 1e-3,
                    ),
                    add_ff_before = True,
                    add_ff_after = True,
                )
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                ff,
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            out = ff(x)
            # SparseMoEBlock returns a tuple (output first, aux losses after —
            # presumably; confirm against st_moe_pytorch), while FeedForward
            # returns a plain tensor. The original code distinguished the two
            # with a bare `except:` around `ff(x) + x`, which silently swallowed
            # *any* error raised inside the layer (shape bugs, OOM-adjacent
            # errors, KeyboardInterrupt). Unwrap explicitly instead.
            if isinstance(out, tuple):
                out = out[0]
            x = out + x
        return self.norm(x)
|
|
|
|
|
class SimpleViTMIX(nn.Module):
    """Simple ViT classifier: patchify -> fixed sincos position embedding ->
    Transformer (with interleaved MoE layers) -> mean pool -> linear head."""

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, num_experts = 12):
        super().__init__()
        img_h, img_w = pair(image_size)
        p_h, p_w = pair(patch_size)

        assert img_h % p_h == 0 and img_w % p_w == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * p_h * p_w

        # Flatten each (p_h x p_w x channels) patch and project it to `dim`.
        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = p_h, p2 = p_w),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # Fixed (non-learned) position embedding. NOTE(review): kept as a plain
        # attribute rather than a registered buffer, so it is recomputed at init
        # and moved to the input's device in forward().
        self.pos_embedding = posemb_sincos_2d(
            h = img_h // p_h,
            w = img_w // p_w,
            dim = dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_experts)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        tokens = self.to_patch_embedding(img)
        tokens += self.pos_embedding.to(img.device, dtype=tokens.dtype)

        encoded = self.transformer(tokens)
        pooled = encoded.mean(dim = 1)

        pooled = self.to_latent(pooled)
        return self.linear_head(pooled)
|
|
|
|
|
|
|
|
|
|
|
class ViTMixModel(PreTrainedModel):
    """Hugging Face PreTrainedModel wrapper around SimpleViTMIX.

    Exposes the standard classification interface: forward returns a dict with
    "logits" and, when labels are supplied, a cross-entropy "loss".
    """

    config_class = ViTMixConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = SimpleViTMIX(
            image_size = config.image_size,
            patch_size = config.patch_size,
            num_classes = config.num_classes,
            dim = config.dim,
            depth = config.depth,
            heads = config.heads,
            mlp_dim = config.mlp_dim,
            num_experts = config.num_experts,
        )

    def forward(self, tensor, labels = None):
        """Classify a batch of images.

        Args:
            tensor: image batch passed straight to SimpleViTMIX.
            labels: optional class-index targets for cross-entropy loss.

        Returns:
            {"loss": ..., "logits": ...} when labels are given, else
            {"logits": ...}.
        """
        logits = self.model(tensor)
        if labels is not None:
            # BUG FIX: the original called torch.nn.cross_entropy, which does
            # not exist (AttributeError whenever labels were passed). The
            # functional form lives in torch.nn.functional.
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}