File size: 1,478 Bytes
5582ceb
 
 
 
 
 
5df5821
5582ceb
5df5821
5582ceb
5df5821
5582ceb
 
 
5df5821
5582ceb
 
 
 
 
5df5821
5582ceb
5df5821
5582ceb
 
 
5df5821
 
 
 
 
5582ceb
5df5821
 
 
5582ceb
5df5821
 
5582ceb
 
 
 
5df5821
5582ceb
5df5821
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from torchvision import models, transforms
from PIL import Image
from huggingface_hub import hf_hub_download

# Fetch the trained checkpoint from the Hugging Face Hub and pull out the
# model weights.
ckpt = hf_hub_download(repo_id="Sarth001/LungCanver", filename="best_model_v2_fixed.pth")
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source. Passing
# weights_only=True is safer if this checkpoint holds only tensors and plain
# containers -- verify before changing.
_checkpoint = torch.load(ckpt, map_location='cpu')
# This checkpoint nests the weights under "model_state_dict"; fall back to
# treating the file as a bare state dict when that key is absent.
state = _checkpoint.get("model_state_dict", _checkpoint)

# Weights saved from an nn.DataParallel-wrapped model carry a "module." prefix
# on every key; strip it so they line up with a plain resnet50. Checkpoints
# without the prefix are left untouched.
if any(k.startswith('module.') for k in state):
    state = {(k[len('module.'):] if k.startswith('module.') else k): v
             for k, v in state.items()}

# Infer the classifier head from the checkpoint. A head built as
# nn.Sequential(nn.Dropout, nn.Linear) stores its Linear weight under
# "fc.1.weight" (the Dropout at index 0 has no parameters); a plain nn.Linear
# head stores it under "fc.weight". The first dimension of that weight is the
# number of output classes.
if 'fc.1.weight' in state:
    use_dropout = True
    w = state['fc.1.weight']
elif 'fc.weight' in state:
    use_dropout = False
    w = state['fc.weight']
else:
    raise KeyError(
        "Checkpoint has neither 'fc.1.weight' nor 'fc.weight'; cannot infer "
        "the classifier head"
    )
num_classes = w.shape[0]

import torch.nn as nn

# Rebuild ResNet-50 without any pretrained weights and swap in a classifier
# head shaped like the one stored in the checkpoint.
model = models.resnet50(weights=None)
if use_dropout:
    model.fc = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(model.fc.in_features, num_classes))
else:
    model.fc = nn.Linear(model.fc.in_features, num_classes)
# strict=False tolerates extra entries in the checkpoint, but any *missing*
# key means part of the network keeps its random initialization and every
# prediction would be meaningless -- fail loudly instead of silently.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"Checkpoint is missing {len(load_result.missing_keys)} model "
        f"parameter(s)/buffer(s), e.g. {load_result.missing_keys[:5]}"
    )
model.eval()  # inference mode: disables dropout, uses BatchNorm running stats

# Preprocessing applied to every input image:
#   1. resize to exactly 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a tensor,
#   3. normalize each channel with the standard ImageNet mean/std.
# NOTE(review): presumably mirrors the training pipeline -- confirm.
_resize = transforms.Resize((224, 224))
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
tf = transforms.Compose([_resize, transforms.ToTensor(), _normalize])

def predict(path):
    """Classify a single image with the loaded model.

    Args:
        path: Image location passed to ``PIL.Image.open``.

    Returns:
        A ``(class_index, probability)`` tuple: the index of the
        highest-probability class and its softmax probability as a float.
    """
    # Use a context manager so the file handle is closed promptly;
    # convert() returns a new image, so it stays valid after the close.
    with Image.open(path) as src:
        img = src.convert('RGB')
    x = tf(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]
    # One reduction gives both the top probability and its index.
    top_prob, top_idx = probs.max(dim=0)
    return int(top_idx), float(top_prob)

if __name__ == '__main__':
    import sys
    if len(sys.argv) < 2:
        # Report misuse on stderr with a non-zero exit status so shell
        # scripts and CI can detect it.
        print('Usage: python inference_example.py image.jpg', file=sys.stderr)
        sys.exit(2)
    idx, prob = predict(sys.argv[1])
    print('Prediction:', idx, 'Prob:', prob)