# File size: 1,849 Bytes
# Revision: af59080
import torch
import torch.nn as nn
from torchvision.models import swin_t
class SwinTransformerMultiLabel(nn.Module):
    """Swin-T backbone whose classification head emits ``num_classes`` logits.

    The wrapped torchvision model keeps its feature extractor and final
    normalisation layer; only ``head`` is swapped for a fresh ``nn.Linear``.
    ``forward`` returns the raw output of that linear layer (no activation is
    applied here), with shape ``(batch_size, num_classes)``.

    Args:
        num_classes: Number of output logits produced by the new head.
        weights: Passed straight to ``torchvision.models.swin_t``. The default
            loads the ImageNet-1k checkpoint; pass ``None`` to build the
            backbone with random initialisation (no download).
    """

    def __init__(self, num_classes, *, weights="IMAGENET1K_V1"):
        super().__init__()
        self.model = swin_t(weights=weights)

        # Replace the stock classifier with one sized for our label set.
        # For Swin-T the incoming feature width is expected to be 768.
        in_features = self.model.head.in_features
        self.model.head = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to per-class logits.

        Args:
            x: Image batch; ``main`` feeds ``(batch, 3, 224, 224)`` tensors.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # torchvision's Swin feature extractor yields channels-last maps,
        # i.e. (batch, height, width, channels).
        x = self.model.features(x)

        # The backbone applies this final LayerNorm before pooling in its own
        # forward pass; skipping it would feed the head activations that the
        # pretrained weights were never trained to produce.
        x = self.model.norm(x)

        # Global average pooling over the two spatial axes -> (batch, channels).
        x = x.mean(dim=(1, 2))

        return self.model.head(x)
def main():
    """Smoke-test the model: run dummy batches through it and print shapes.

    Builds a two-class ``SwinTransformerMultiLabel``, switches it to eval
    mode, and pushes random ``(batch, 3, 224, 224)`` tensors through it for
    several batch sizes, printing the resulting output shapes and the
    replaced classification head.
    """
    num_classes = 2

    model = SwinTransformerMultiLabel(num_classes)
    model.eval()

    # Pure inference: disable autograd so no graph is recorded for the dummy
    # batches (saves memory and time; the printed shapes are unaffected).
    with torch.no_grad():
        # Dummy input: batch_size=5, channels=3, height=224, width=224.
        dummy_input = torch.randn(5, 3, 224, 224)
        output = model(dummy_input)
        print(f"✅ Model output shape: {output.shape}")  # Expected: (5, 2)

        # Show the replaced classification head.
        print(f"✅ Model classification head: {model.model.head}")

        # Confirm the output's leading dimension tracks the batch size.
        for batch_size in [1, 8, 16]:
            dummy_input = torch.randn(batch_size, 3, 224, 224)
            output = model(dummy_input)
            print(f"✅ Batch Size {batch_size} -> Output Shape: {output.shape}")
# Run the smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|