|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import transforms |
|
|
import timm |
|
|
from PIL import Image |
|
|
|
|
|
# Run on the GPU when one is usable, otherwise fall back to the CPU so the
# script does not crash on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
# Preprocessing pipeline applied to each PIL image before it reaches the model.
processor = transforms.Compose(
    [
        # Resize to a fixed 224x224 input; a (h, w) pair does not preserve
        # the original aspect ratio.
        transforms.Resize(size=(224, 224)),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Per-channel normalization; these constants are the ImageNet
        # mean/std values conventionally used with pretrained backbones.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
class SwinBinaryClassifier(nn.Module):
    """Swin-Tiny feature extractor followed by a single-output linear head.

    The attribute names ``backbone`` and ``classifier`` are part of the
    checkpoint layout (state-dict keys), so they must not be renamed.
    """

    def __init__(self):
        """Build the timm backbone and the one-unit linear head."""
        super().__init__()
        # num_classes=0 requests the backbone without timm's own classifier,
        # so it outputs feature vectors of size ``num_features``.
        # pretrained=False: no weights are downloaded here.
        self.backbone = timm.create_model(
            'swin_tiny_patch4_window7_224',
            pretrained=False,
            num_classes=0,
        )
        feature_dim = self.backbone.num_features
        self.classifier = nn.Linear(feature_dim, 1)

    def forward(self, x):
        """Return the raw head output for ``x``; no sigmoid is applied here."""
        features = self.backbone(x)
        logits = self.classifier(features)
        return logits
|
|
|
|
|
# --- Model setup -----------------------------------------------------------
model = SwinBinaryClassifier().to(device)

# map_location remaps tensors saved on another device (e.g. a CUDA-trained
# checkpoint) onto the device chosen for this run, so loading also works on
# CPU-only hosts.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
state_dict = torch.load('./breastcancer_model.pth', map_location=device)
model.load_state_dict(state_dict)

# Switch to inference mode: disables training-only behaviour such as dropout /
# stochastic depth so the prediction is deterministic.
model.eval()

# --- Input preparation -----------------------------------------------------
image = Image.open('./tests/Benign Masses/20586908 (12)_Benign.png').convert("RGB")
input_tensor = processor(image)
# Add a leading batch dimension and place the batch on the model's device.
input_batch = input_tensor.unsqueeze(0).to(device)

# --- Inference -------------------------------------------------------------
with torch.no_grad():
    output = model(input_batch)

# Sigmoid turns the single logit into a probability; > 0.5 selects index 1.
preds = (torch.sigmoid(output) > 0.5).int()

classes = ['Benign', 'Malignant']

# Convert the one-element tensor to a plain Python int before using it as a
# list index.
predicted_index = int(preds.item())
print(f'Predicted class: {classes[predicted_index]}')