import torch
import torch.nn as nn
from torch.nn import LayerNorm, Linear, Dropout
from torch.nn.functional import gelu, softmax
from transformers import PreTrainedModel

from .SwinCXRConfig import SwinCXRConfig
class SwinSelfAttention(nn.Module):
    """Multi-head self-attention over a sequence of patch tokens."""

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.query = Linear(embed_dim, embed_dim)
        self.key = Linear(embed_dim, embed_dim)
        self.value = Linear(embed_dim, embed_dim)
        self.dropout = Dropout(p=dropout)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape
        # Project, then split the embedding into (num_heads, head_dim).
        query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention per head.
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / self.head_dim ** 0.5
        attention_weights = softmax(attention_weights, dim=-1)
        attention_output = torch.matmul(attention_weights, value)
        # Merge the heads back into a single embedding dimension.
        attention_output = attention_output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        return self.dropout(attention_output)
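
# Shape walk-through for the attention above (hypothetical batch of 2, 3136
# tokens, embed_dim=128, num_heads=4; these numbers are illustrative only):
#   q/k/v after the head split: (2, 4, 3136, 32)
#   q @ k^T:                    (2, 4, 3136, 3136), softmax over the last dim
#   weights @ v:                (2, 4, 3136, 32) -> merged back to (2, 3136, 128)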
class SwinLayer(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.layernorm_before = LayerNorm(embed_dim)
        self.attention = SwinSelfAttention(embed_dim, num_heads, dropout)
        # Plain dropout stands in for the stochastic-depth "drop path" used in Swin.
        self.drop_path = Dropout(p=dropout)
        self.layernorm_after = LayerNorm(embed_dim)
        # Feed-forward network with the usual 4x expansion.
        self.fc1 = Linear(embed_dim, 4 * embed_dim)
        self.fc2 = Linear(4 * embed_dim, embed_dim)
        self.intermediate_act_fn = gelu

    def forward(self, x):
        # Attention sub-block with residual connection.
        normed = self.layernorm_before(x)
        attention_output = self.attention(normed)
        attention_output = self.drop_path(attention_output)
        x = x + attention_output
        # MLP sub-block with residual connection.
        normed = self.layernorm_after(x)
        intermediate = self.fc1(normed)
        intermediate = self.intermediate_act_fn(intermediate)
        output = self.fc2(intermediate)
        return x + output
class SwinPatchEmbedding(nn.Module):
    """Splits an image into non-overlapping patches and projects each to embed_dim."""

    def __init__(self, in_channels=3, patch_size=4, embed_dim=128):
        super().__init__()
        # A strided convolution is equivalent to patchify-then-project.
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = LayerNorm(embed_dim)

    def forward(self, x):
        x = self.projection(x)            # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        x = self.norm(x)
        return x
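
# Shape sketch for the patch embedding (assuming a hypothetical 224x224 RGB
# input with the defaults above):
#   (B, 3, 224, 224) -> Conv2d(k=4, s=4) -> (B, 128, 56, 56)
#   -> flatten + transpose              -> (B, 56 * 56 = 3136, 128)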
class SwinEncoder(nn.Module):
    """A stack of identical transformer layers applied in sequence."""

    def __init__(self, num_layers, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            SwinLayer(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
class SwinModelForCXRClassification(PreTrainedModel):
    """Swin-style encoder with a linear head for chest X-ray classification."""

    config_class = SwinCXRConfig

    def __init__(self, config):
        super().__init__(config)
        # Bare nn.Module used as a namespace so submodules register under "embeddings.*".
        self.embeddings = nn.Module()
        self.embeddings.patch_embeddings = SwinPatchEmbedding(
            in_channels=3,
            patch_size=4,
            embed_dim=128,
        )
        # Defined to mirror the usual Swin embeddings layout, but not applied in
        # forward: SwinPatchEmbedding already normalizes its output.
        self.embeddings.norm = LayerNorm(128)
        self.embeddings.dropout = Dropout(p=0.0)
        self.encoder = SwinEncoder(
            num_layers=4,
            embed_dim=128,
            num_heads=4,
            dropout=0.1,
        )
        self.layernorm = LayerNorm(128)
        # Mean-pool over the patch dimension before classification.
        self.pooler = nn.AdaptiveAvgPool1d(output_size=1)
        # Head width follows the config so the logits always match num_classes.
        self.classifier = Linear(in_features=128, out_features=config.num_classes, bias=True)

    def forward(self, pixel_values, labels=None):
        x = self.embeddings.patch_embeddings(pixel_values)  # (B, num_patches, 128)
        x = self.encoder(x)
        x = self.layernorm(x)
        x = self.pooler(x.transpose(1, 2)).squeeze(-1)      # (B, 128)
        logits = self.classifier(x)
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))
            return loss, logits
        return logits
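
# Minimal usage sketch, assuming SwinCXRConfig accepts a num_classes argument
# (its exact signature lives in SwinCXRConfig.py and is not shown here). Kept
# as a comment because the relative import above prevents running this module
# as a standalone script:
#
#     config = SwinCXRConfig(num_classes=3)
#     model = SwinModelForCXRClassification(config)
#     pixel_values = torch.randn(2, 3, 224, 224)          # dummy batch of 2 images
#     labels = torch.tensor([0, 2])
#     loss, logits = model(pixel_values, labels=labels)   # logits: (2, 3)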