import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, NamedTuple, Optional

import torch
import torch.nn as nn

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url


model_urls = {
    "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
    "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
    "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
    "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
}


class MLPBlock(nn.Sequential):
    """Transformer MLP block."""

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, mlp_dim)
        self.act = nn.GELU()
        self.dropout_1 = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(mlp_dim, in_dim)
        self.dropout_2 = nn.Dropout(dropout)

        # Xavier-uniform weights and near-zero biases, as in the ViT reference code
        nn.init.xavier_uniform_(self.linear_1.weight)
        nn.init.xavier_uniform_(self.linear_2.weight)
        nn.init.normal_(self.linear_1.bias, std=1e-6)
        nn.init.normal_(self.linear_2.bias, std=1e-6)
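

# MLPBlock preserves the embedding width (in_dim -> mlp_dim -> in_dim), so it
# can sit on a residual branch. A hypothetical shape check (sizes chosen for
# illustration only, not values used elsewhere in this module):
#
#     mlp = MLPBlock(in_dim=64, mlp_dim=128, dropout=0.0)
#     out = mlp(torch.rand(2, 16, 64))  # -> torch.Size([2, 16, 64])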


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y
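

# Illustrative shape check for EncoderBlock (dimensions are hypothetical, not
# values used elsewhere in this module). With batch_first attention, the block
# maps (batch, seq, hidden) to the same shape through its two residual
# branches:
#
#     block = EncoderBlock(num_heads=4, hidden_dim=64, mlp_dim=128,
#                          dropout=0.0, attention_dropout=0.0)
#     out = block(torch.rand(2, 16, 64))  # -> torch.Size([2, 16, 64])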


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Learned positional embedding; batch is the leading dim because the
        # encoder blocks use batch_first attention
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))
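

# Note on usage: pos_embedding is sized at construction, so Encoder inputs must
# have exactly seq_length tokens. A hypothetical shape check (dimensions chosen
# for illustration only, not values used elsewhere in this module):
#
#     enc = Encoder(seq_length=16, num_layers=2, num_heads=4, hidden_dim=64,
#                   mlp_dim=128, dropout=0.0, attention_dropout=0.0)
#     out = enc(torch.rand(2, 16, 64))  # -> torch.Size([2, 16, 64])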


class FeatureTransformer(nn.Module):
    """
    Feature Transformer: a ViT-style Transformer encoder applied directly to a
    sequence of feature embeddings (no patch-embedding stem), with a learnable
    class token and a classification head.
    """

    def __init__(
        self,
        seq_length: int = 16,
        num_layers: int = 2,
        num_heads: int = 4,
        hidden_dim: int = 768,
        mlp_dim: int = 768,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        # Add a class token, which extends the sequence length by one
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def forward(self, x: torch.Tensor):
        # Prepend the class token to every sequence in the batch
        batch_class_token = self.class_token.expand(x.shape[0], -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # The classification head reads only the final state of the class token
        x = x[:, 0]
        x = self.heads(x)

        return x
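

# A minimal usage sketch (assumption: random tensors stand in for real feature
# inputs; shapes follow the constructor defaults above). The model consumes
# (batch_size, seq_length, hidden_dim) feature sequences and returns one row of
# logits per sequence.
if __name__ == "__main__":
    model = FeatureTransformer()  # seq_length=16, hidden_dim=768, num_classes=1
    features = torch.rand(2, 16, 768)
    print(model(features).shape)  # torch.Size([2, 1])

    # With representation_size set, the head becomes Linear -> Tanh -> Linear
    model = FeatureTransformer(num_classes=10, representation_size=256)
    print(model(features).shape)  # torch.Size([2, 10])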