Shreeraj commited on
Commit
5aad50d
·
1 Parent(s): 60fd70c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision.transforms import ToTensor
5
+ import torchvision.transforms as transforms
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from pytorch_grad_cam import GradCAM
9
+ from pytorch_grad_cam.utils.image import show_cam_on_image
10
+ import matplotlib.pyplot as plt
11
+
12
+
13
+ # Load the pre-trained model
14
+ model = torch.load('model.pth', map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
15
+ model.eval()
16
+
17
+ #define the target layer to pull for gradcam
18
+ target_layers = [model.layer4[-1]]
19
+
20
+ # Define the class labels
21
+ class_labels = ['Crazing', 'Inclusion', 'Patches', 'Pitted', 'Rolled', 'Scratches']
22
+
23
+ # Transformations for input images
24
+ preprocess = transforms.Compose([
25
+ transforms.Resize((224, 224)),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize(mean=[0.4562, 0.4562, 0.4562], std=[0.2502, 0.2502, 0.2502]),
28
+ ])
29
+
30
+ inv_normalize = transforms.Normalize(
31
+ mean=[0.4562, 0.4562, 0.4562],
32
+ std=[0.2502, 0.2502, 0.2502]
33
+ )
34
+
35
+ # Gradio app interface
36
+ def classify_image(inp, transperancy=0.8):
37
+ #image = Image.fromarray((inp * 255).astype(np.uint8)) # Convert NumPy array to PIL Image
38
+ #input_tensor = preprocess(image)
39
+ input_tensor = preprocess(inp)
40
+ input_batch = input_tensor.unsqueeze(0).to('cuda' if torch.cuda.is_available() else 'cpu') # Create a batch
41
+
42
+ cam = GradCAM(model=model,use_cuda=True, target_layers=target_layers)
43
+
44
+ grayscale_cam = cam(input_tensor=input_batch, targets=None)
45
+ grayscale_cam = grayscale_cam[0, :]
46
+ img = input_tensor.squeeze(0)
47
+ img = inv_normalize(img)
48
+ rgb_img = np.transpose(img, (1, 2, 0))
49
+ rgb_img = rgb_img.numpy()
50
+ rgb_img = (rgb_img - rgb_img.min()) / (rgb_img.max() - rgb_img.min())
51
+ visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transperancy)
52
+
53
+ with torch.no_grad():
54
+ output = model(input_batch)
55
+
56
+ probabilities = F.softmax(output[0], dim=0)
57
+ pred_class_idx = torch.argmax(probabilities).item()
58
+
59
+ class_probabilities = {class_labels[i]: float(probabilities[i]) for i in range(len(class_labels))}
60
+ #prob_string = "\n".join([f"{label}: {prob:.2f}" for label, prob in class_probabilities.items()])
61
+
62
+ return inp, class_probabilities, visualization
63
+
64
+ iface = gr.Interface(
65
+ fn=classify_image,
66
+ inputs=[gr.Image(shape=(200, 200),type="pil", label="Input Image"),
67
+ gr.Slider(0, 1, value = 0.8, label="Opacity of GradCAM")],
68
+
69
+ outputs=[
70
+ gr.Image(shape=(200,200),type="numpy", label="Input Image").style(width=300, height=300),
71
+ gr.Label(label="Probability of Defect", num_top_classes=3),
72
+ gr.Image(shape=(200,200), type="numpy", label="GradCam").style(width=300, height=300)
73
+ ],
74
+ title="Metal Defects Image Classification",
75
+ description="The classification depends on the microscopic scale of the image being uploaded :)"
76
+ )
77
+ iface.launch()