VanillaVanilla000 commited on
Commit
b9b6d04
·
verified ·
1 Parent(s): 4df39eb

Create tester.py

Browse files
Files changed (1) hide show
  1. tester.py +45 -0
tester.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ import timm
6
+ from PIL import Image
7
+
8
+ device = 'cuda'
9
+ processor = transforms.Compose([
10
+ transforms.Resize((224, 224)),
11
+ transforms.ToTensor(),
12
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
13
+ ])
14
+ class SwinBinaryClassifier(nn.Module):
15
+ def __init__(self):
16
+ super().__init__()
17
+ self.backbone = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=0)
18
+ in_f = self.backbone.num_features
19
+ self.classifier = nn.Linear(in_f, 1)
20
+
21
+ def forward(self, x):
22
+ x = self.backbone(x)
23
+ return self.classifier(x)
24
+
25
+ model = SwinBinaryClassifier().to(device)
26
+ model.load_state_dict(torch.load('./breastcancer_model.pth'))
27
+
28
+ image = Image.open('./tests/Benign Masses/20586908 (12)_Benign.png').convert("RGB")
29
+ input_tensor = processor(image)
30
+ input_batch = input_tensor.unsqueeze(0) # Add a batch dimension
31
+
32
+ # Move the input and model to GPU if available
33
+ if torch.cuda.is_available():
34
+ input_batch = input_batch.to('cuda')
35
+ model.to('cuda')
36
+
37
+ # Make a prediction
38
+ with torch.no_grad():
39
+ output = model(input_batch)
40
+
41
+ preds = (torch.sigmoid(output) > 0.5).int()
42
+
43
+ classes = ['Benign', 'Malignant']
44
+ # Print the predicted class
45
+ print(f'Predicted class: {classes[preds]}')