# Hosting-page residue captured together with the file (not part of the script):
# VanillaVanilla000's picture
# Create tester.py
# 16df3ca verified
import torch
import torch.nn as nn
from torchvision import transforms
import timm
from PIL import Image
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU so the
# script still runs on machines without a GPU instead of failing on `.to()`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Preprocessing pipeline applied to each input image before inference:
# resize to 224x224, convert the PIL image to a tensor, then normalize every
# channel with the fixed per-channel mean / standard deviation below.
processor = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
class SwinBinaryClassifier(nn.Module):
    """Swin-Tiny feature extractor followed by a single-logit linear head.

    The backbone is created through ``timm`` without pretrained weights and
    with ``num_classes=0``; its output features are passed to a linear layer
    with one output unit. The forward pass returns that raw output (no
    sigmoid is applied here).
    """

    def __init__(self) -> None:
        """Build the backbone and the one-unit classification head."""
        super().__init__()
        # NOTE: the attribute names `backbone` and `classifier` become the
        # prefixes of the state_dict keys, so they must stay unchanged for
        # previously saved checkpoints to keep loading.
        self.backbone = timm.create_model(
            'swin_tiny_patch4_window7_224',
            pretrained=False,
            num_classes=0,
        )
        feature_dim = self.backbone.num_features
        self.classifier = nn.Linear(feature_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's raw output for a batch of inputs.

        Args:
            x: Input batch for the backbone (presumably ``(N, 3, 224, 224)``
                images, given the model name — confirm against the caller).

        Returns:
            The linear head's output; its last dimension has size 1.
        """
        features = self.backbone(x)
        return self.classifier(features)
# Resolve the device actually used for this run. If the configured `device`
# needs CUDA but CUDA is unavailable, fall back to the CPU instead of crashing.
run_device = torch.device(device if torch.cuda.is_available() else 'cpu')

model = SwinBinaryClassifier().to(run_device)

# SECURITY NOTE: torch.load unpickles the file; only load checkpoints from a
# trusted source. `map_location` lets a checkpoint saved on a GPU be restored
# on a CPU-only machine.
state_dict = torch.load('./breastcancer_model.pth', map_location=run_device)
model.load_state_dict(state_dict)

# Switch to inference mode so layers that behave differently during training
# (e.g. dropout / stochastic depth) are disabled and predictions are
# deterministic.
model.eval()

# Open the image inside a context manager so the file handle is released as
# soon as the RGB copy has been made.
with Image.open('./tests/Benign Masses/20586908 (12)_Benign.png') as raw_image:
    image = raw_image.convert("RGB")

input_tensor = processor(image)
input_batch = input_tensor.unsqueeze(0)  # Add a batch dimension
# Keep the input on the same device as the model.
input_batch = input_batch.to(run_device)

# Make a prediction without tracking gradients.
with torch.no_grad():
    output = model(input_batch)

# Threshold the sigmoid of the model output at 0.5 to get a 0/1 label.
preds = (torch.sigmoid(output) > 0.5).int()
classes = ['Benign', 'Malignant']

# The batch holds exactly one image, so `preds` has a single element; convert
# it to a plain Python int before using it as a list index.
pred_idx = int(preds.item())

# Print the predicted class
print(f'Predicted class: {classes[pred_idx]}')