"""Gradio demo: classify metal defect images and overlay a Grad-CAM heatmap."""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.transforms import ToTensor

# Load the pre-trained model.
# NOTE: torch.load relies on pickle, so only load a model.pth that comes from
# a trusted source.
model = torch.load('model.pth', map_location=torch.device('cpu'))
model.eval()

# Layer whose activations and gradients Grad-CAM reads.
target_layers = [model.layer4[-1]]

# Class labels, in the order of the model's output logits.
class_labels = ['Crazing', 'Inclusion', 'Patches', 'Pitted', 'Rolled', 'Scratches']

# Per-channel statistics used to normalise the network input.
NORM_MEAN = [0.4562, 0.4562, 0.4562]
NORM_STD = [0.2502, 0.2502, 0.2502]

# Transformations for input images.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Normalize computes (x - mean) / std, so its inverse (y * std + mean) is a
# Normalize with mean' = -mean / std and std' = 1 / std.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(NORM_MEAN, NORM_STD)],
    std=[1.0 / s for s in NORM_STD],
)


def classify_image(inp, transperancy=0.8):
    """Classify an image and build a Grad-CAM overlay for it.

    Args:
        inp: PIL image supplied by the Gradio input component.
        transperancy: Value in [0, 1] passed to ``show_cam_on_image`` as
            ``image_weight``, i.e. the weight of the original image in the
            blend (1.0 shows only the image, 0.0 only the heatmap). The
            misspelt name is kept so existing keyword callers keep working.

    Returns:
        A tuple ``(inp, class_probabilities, visualization)``: the input
        image unchanged, a dict mapping each entry of ``class_labels`` to its
        softmax probability, and the overlay returned by
        ``show_cam_on_image``.

    Raises:
        gr.Error: If no image was provided.
    """
    if inp is None:
        # Submitting without an image would otherwise fail inside preprocess.
        raise gr.Error("Please upload an image first.")

    model.to("cpu")
    # The Normalize step in `preprocess` expects exactly three channels.
    input_tensor = preprocess(inp.convert("RGB"))
    input_batch = input_tensor.unsqueeze(0).to('cpu')  # Create a batch

    # The context manager releases Grad-CAM's hooks on the model as soon as
    # the heatmap is computed, so they do not pile up across requests.
    # targets=None leaves the choice of the explained class to Grad-CAM.
    with GradCAM(model=model, use_cuda=False, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_batch, targets=None)[0, :]

    # Undo the normalisation and move channels last (H, W, C) for plotting.
    rgb_img = inv_normalize(input_tensor).permute(1, 2, 0).numpy()

    # Stretch to [0, 1] for show_cam_on_image; a uniform image has zero range
    # and would otherwise turn into NaNs through a division by zero.
    lowest = rgb_img.min()
    value_range = rgb_img.max() - lowest
    if value_range > 0:
        rgb_img = (rgb_img - lowest) / value_range
    else:
        rgb_img = np.clip(rgb_img, 0.0, 1.0)

    visualization = show_cam_on_image(
        rgb_img, grayscale_cam, use_rgb=True, image_weight=transperancy
    )

    with torch.no_grad():
        output = model(input_batch)
    probabilities = F.softmax(output[0], dim=0)
    class_probabilities = dict(zip(class_labels, probabilities.tolist()))

    return inp, class_probabilities, visualization


# Gradio app interface.
# NOTE(review): the slider is labelled "Opacity of GradCAM", but its value is
# used as the weight of the original image, so higher values make the heatmap
# fainter. Consider relabelling it.
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(shape=(200, 200), type="pil", label="Input Image"),
        gr.Slider(0, 1, value=0.8, label="Opacity of GradCAM"),
    ],
    outputs=[
        gr.Image(shape=(200, 200), type="numpy", label="Input Image").style(width=300, height=300),
        gr.Label(label="Probability of Defect", num_top_classes=3),
        gr.Image(shape=(200, 200), type="numpy", label="GradCam").style(width=300, height=300),
    ],
    title="Metal Defects Image Classification",
    description="The classification depends on the microscopic scale of the image being uploaded :)",
)
iface.launch()