"""Single-image inference with the ResNet-50 checkpoint from ``Sarth001/LungCanver``.

Usage::

    python inference_example.py image.jpg

Prints the predicted class index and its softmax probability. Importing this
module downloads the checkpoint and builds ``model`` eagerly.
"""
import sys

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms

ckpt = hf_hub_download(repo_id="Sarth001/LungCanver", filename="best_model_v2_fixed.pth")

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted repo; on torch >= 1.13 consider
# passing weights_only=True if the checkpoint holds only tensors/primitives.
state = torch.load(ckpt, map_location='cpu')["model_state_dict"]

# Checkpoints saved from nn.DataParallel prefix every key with "module.";
# strip it so keys match a plain ResNet. No-op for ordinary checkpoints.
if state and all(k.startswith('module.') for k in state):
    state = {k[len('module.'):]: v for k, v in state.items()}

# The classifier head was saved either as nn.Sequential(Dropout, Linear)
# (weights under "fc.1.weight") or as a bare nn.Linear ("fc.weight").
# Infer which one, and the number of classes, from the checkpoint keys.
use_dropout = 'fc.1.weight' in state
_head_key = 'fc.1.weight' if use_dropout else 'fc.weight'
if _head_key not in state:
    raise KeyError(
        f"checkpoint has no classifier weight {_head_key!r}; cannot infer num_classes"
    )
w = state[_head_key]
num_classes = w.shape[0]


def _build_resnet50():
    """Return a randomly initialised torchvision ResNet-50 on old and new torchvision."""
    try:
        return models.resnet50(weights=None)  # torchvision >= 0.13
    except TypeError:
        return models.resnet50(pretrained=False)  # older torchvision


model = _build_resnet50()
if use_dropout:
    model.fc = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(model.fc.in_features, num_classes))
else:
    model.fc = nn.Linear(model.fc.in_features, num_classes)

# strict=False tolerates extra entries, but any missing weights would leave
# those layers randomly initialised and silently corrupt predictions, so the
# mismatch is reported instead of being ignored.
_load_report = model.load_state_dict(state, strict=False)
if _load_report.missing_keys or _load_report.unexpected_keys:
    print(
        f"Warning: checkpoint/model mismatch - missing={_load_report.missing_keys} "
        f"unexpected={_load_report.unexpected_keys}",
        file=sys.stderr,
    )
model.eval()

# Resize to 224x224 and normalise with the ImageNet channel statistics.
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(path):
    """Classify the image at *path*.

    Args:
        path: Filesystem path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A ``(class_index, probability)`` tuple: the argmax class of the softmax
        output and its probability.
    """
    # Close the underlying file promptly; convert() returns an independent copy.
    with Image.open(path) as raw:
        img = raw.convert('RGB')
    x = tf(img).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]
    idx = int(probs.argmax())
    return idx, float(probs[idx])


if __name__ == '__main__':
    if len(sys.argv) < 2:
        # sys.exit(str) writes the message to stderr and exits with status 1.
        sys.exit('Usage: python inference_example.py image.jpg')
    idx, prob = predict(sys.argv[1])
    print('Prediction:', idx, 'Prob:', prob)