VanillaVanilla000 commited on
Commit
16df3ca
·
verified ·
1 Parent(s): 22785e8

Create tester.py

Browse files
Files changed (1) hide show
  1. tester.py +44 -0
tester.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ import timm
5
+ from PIL import Image
6
+
7
+ device = 'cuda'
8
+ processor = transforms.Compose([
9
+ transforms.Resize((224, 224)),
10
+ transforms.ToTensor(),
11
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
12
+ ])
13
+ class SwinBinaryClassifier(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+ self.backbone = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=0)
17
+ in_f = self.backbone.num_features
18
+ self.classifier = nn.Linear(in_f, 1)
19
+
20
+ def forward(self, x):
21
+ x = self.backbone(x)
22
+ return self.classifier(x)
23
+
24
+ model = SwinBinaryClassifier().to(device)
25
+ model.load_state_dict(torch.load('./breastcancer_model.pth'))
26
+
27
+ image = Image.open('./tests/Benign Masses/20586908 (12)_Benign.png').convert("RGB")
28
+ input_tensor = processor(image)
29
+ input_batch = input_tensor.unsqueeze(0) # Add a batch dimension
30
+
31
+ # Move the input and model to GPU if available
32
+ if torch.cuda.is_available():
33
+ input_batch = input_batch.to('cuda')
34
+ model.to('cuda')
35
+
36
+ # Make a prediction
37
+ with torch.no_grad():
38
+ output = model(input_batch)
39
+
40
+ preds = (torch.sigmoid(output) > 0.5).int()
41
+
42
+ classes = ['Benign', 'Malignant']
43
+ # Print the predicted class
44
+ print(f'Predicted class: {classes[preds]}')