import torch
import torch.nn as nn
from functools import partial

import numpy as np

from .model_core import (
    PatchEmbed,
    AltBlock,
    trunc_normal_,
)


class EAT(nn.Module):
    """EAT transformer backbone with optional classification head.

    Two variants, selected by ``config.model_variant``:

    * ``"pretrain"``  — no head; ``forward`` returns the full token sequence.
    * ``"finetune"``  — CLS-token pooling + LayerNorm + Linear head;
      ``forward`` returns class logits of shape ``(B, num_classes)``.

    The encoder is a patch embedding (``PatchEmbed``) followed by a stack of
    ``AltBlock`` transformer layers with linearly increasing stochastic-depth
    rates.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # "pretrain" or "finetune" — decides whether a classifier head exists.
        self.mode = config.model_variant

        # === Embedding / Encoder ===
        self.local_encoder = PatchEmbed(
            patch_size=config.patch_size,
            in_chans=config.in_chans,
            embed_dim=config.embed_dim,
            stride=config.stride,
            use_sincos_pos=config.fixed_positions,
        )
        # Learnable CLS-style token prepended to the patch sequence in encode().
        self.extra_tokens = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
        self.pos_drop = nn.Dropout(p=config.drop_rate, inplace=True)
        trunc_normal_(self.extra_tokens, std=.02)

        # NOTE: with norm_affine=False these LayerNorms carry no learnable
        # weight/bias (both are None) — _init_weights must tolerate that.
        norm_layer = partial(
            nn.LayerNorm,
            eps=config.norm_eps,
            elementwise_affine=config.norm_affine,
        )
        # Stochastic-depth (drop-path) rate grows linearly over the stack.
        dpr = np.linspace(
            config.start_drop_path_rate, config.end_drop_path_rate, config.depth
        )
        self.blocks = nn.ModuleList([
            AltBlock(
                config.embed_dim,
                config.num_heads,
                config.mlp_ratio,
                qkv_bias=config.qkv_bias,
                drop=config.drop_rate,
                attn_drop=config.attn_drop_rate,
                mlp_drop=config.activation_dropout,
                post_mlp_drop=config.post_mlp_drop,
                drop_path=float(dpr[i]),  # numpy scalar -> plain float
                norm_layer=norm_layer,
                layer_norm_first=config.layer_norm_first,
                ffn_targets=True,
            )
            for i in range(config.depth)
        ])
        self.pre_norm = norm_layer(config.embed_dim)

        # === Head (finetune only) ===
        if self.mode == "finetune":
            self.fc_norm = nn.LayerNorm(config.embed_dim)
            self.head = nn.Linear(config.embed_dim, config.num_classes, bias=True)
        else:
            self.head = nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear layers with truncated-normal weights and reset
        LayerNorm to the identity transform.

        Both LayerNorm inits are guarded with ``is not None``: when the
        encoder norms are built with ``elementwise_affine=False``
        (``config.norm_affine``), ``weight`` and ``bias`` are ``None`` and an
        unguarded ``nn.init.constant_`` would raise during construction.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def encode(self, x):
        """Embed ``x``, prepend the CLS token, and run the transformer stack.

        Returns the token sequence, presumably ``(B, 1 + num_patches,
        embed_dim)`` — exact patch count depends on PatchEmbed (not visible
        here).
        """
        B = x.shape[0]
        x = self.local_encoder(x)
        x = torch.cat((self.extra_tokens.expand(B, -1, -1), x), dim=1)
        x = self.pre_norm(x)
        x = self.pos_drop(x)
        for blk in self.blocks:
            # AltBlock was built with ffn_targets=True, so it returns a pair;
            # only the block output is needed here.
            x, _ = blk(x)
        return x

    def forward(self, x):
        x = self.encode(x)
        if self.mode == "finetune":
            x = x[:, 0]  # CLS-token pooling
            x = self.fc_norm(x)
            x = self.head(x)
        return x

    def extract_features(self, x):
        """Return the encoder's full token sequence, bypassing any head."""
        return self.encode(x)