import json
import warnings

import timm
import torch
from PIL import Image
from torchvision import transforms
# Class names in model-output order: the model is built with one output per
# entry here, and predict() maps the argmax index i back to labels[i].
# NOTE(review): the names are alphabetical, which matches how e.g.
# torchvision's ImageFolder assigns class indices — confirm against the
# label map used during training.
labels = """
    crevice_corrosion
    erosion_corrosion
    galvanic_corrosion
    mic_corrosion
    no_corrosion
    pitting_corrosion
    stress_corrosion
    under_insulation_corrosion
    uniform_corrosion
""".split()
# Build a timm ResNet-50 whose classifier head has one output per entry in
# `labels` (no pretrained download), then load the fine-tuned weights from
# disk, mapping all tensors onto the CPU.
model = timm.create_model('resnet50', pretrained=False, num_classes=len(labels))
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code — only load files from a trusted source (torch >= 1.13 also
# accepts weights_only=True to restrict this).
state = torch.load('resnet50-corrosion-classifier-v1.pth', map_location='cpu')
# strict=False is kept so that extra checkpoint entries are tolerated, but the
# returned key report is now checked: silently skipping *missing* keys would
# leave those layers randomly initialised and make every prediction
# meaningless. (A common cause is a checkpoint saved from a wrapped model,
# whose keys carry an extra prefix or are nested under another key.)
_load_report = model.load_state_dict(state, strict=False)
if _load_report.missing_keys:
    _missing = _load_report.missing_keys
    raise RuntimeError(
        f'checkpoint is missing {len(_missing)} model weight(s), '
        f'e.g. {_missing[:5]}'
    )
if _load_report.unexpected_keys:
    _unexpected = _load_report.unexpected_keys
    warnings.warn(
        f'ignoring {len(_unexpected)} unexpected checkpoint key(s), '
        f'e.g. {_unexpected[:5]}'
    )
# Switch to evaluation mode (e.g. batch-norm layers use their running
# statistics instead of per-batch ones).
model.eval()
# Evaluation-time preprocessing applied to every input image:
#   1. bicubic-resize so the shorter side is 256 px,
#   2. take the central 224x224 crop,
#   3. convert the PIL image to a float tensor in [0, 1],
#   4. normalise each channel by the given mean/std.
# NOTE(review): the mean/std values appear to be the standard ImageNet
# channel statistics — confirm they match the preprocessing used in training.
transform = transforms.Compose(
    [
        transforms.Resize(
            size=256,
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(path):
    """Classify a single image and return the most likely corrosion class.

    Args:
        path: Anything ``PIL.Image.open`` accepts (a file path or an open
            binary file object).

    Returns:
        A ``(label, probability)`` tuple: the class name from ``labels`` with
        the highest softmax score, and that score as a Python float.
    """
    # The context manager closes the underlying file deterministically;
    # convert() returns a fully loaded copy, so it stays usable afterwards.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    # ToTensor yields (C, H, W); add a leading batch dimension of size 1.
    batch = transform(rgb).unsqueeze(0)
    with torch.no_grad():
        # Index the single batch row explicitly rather than squeeze(), which
        # would also collapse the class axis if there were only one class.
        probs = model(batch).softmax(dim=1)[0]
    # Take argmax on the tensor directly — no list -> tensor round-trip.
    idx = int(probs.argmax())
    return labels[idx], probs[idx].item()
if __name__ == "__main__":
    import sys

    # CLI usage: python <script> [IMAGE]
    # Classifies IMAGE (extra arguments are ignored), falling back to
    # "test.jpg" when no path is supplied, and prints the (label, prob) tuple.
    cli_args = sys.argv[1:]
    image_path = cli_args[0] if cli_args else "test.jpg"
    print(predict(image_path))