Conn Finnegan commited on
Commit
d3bf433
·
verified ·
1 Parent(s): 1e11e36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -14
app.py CHANGED
@@ -1,7 +1,10 @@
1
  import gradio as gr
2
  import torch
3
  from torchvision import models, transforms
 
4
  from PIL import Image
 
 
5
 
6
  # Load model
7
  model = models.resnet18()
@@ -9,37 +12,93 @@ model.fc = torch.nn.Linear(model.fc.in_features, 2)
9
  model.load_state_dict(torch.load("skin_cancer_resnet18_version1.pt", map_location="cpu"))
10
  model.eval()
11
 
12
- # Class labels
13
  classes = ['benign', 'malignant']
14
 
15
- # Image preprocessing
16
  transform = transforms.Compose([
17
  transforms.Resize((224, 224)),
18
  transforms.ToTensor()
19
  ])
20
 
21
- # Inference function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def predict(img):
23
- img = img.convert("RGB")
24
- input_tensor = transform(img).unsqueeze(0)
 
 
25
  with torch.no_grad():
26
  output = model(input_tensor)
27
- probs = torch.nn.functional.softmax(output[0], dim=0)
28
- return {classes[i]: float(probs[i]) for i in range(2)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # UI text
31
- title = "🧠 Lumen: Skin Cancer Classifier"
32
  description = """
33
- Upload a dermoscopic image of a mole or skin lesion.<br>
34
- The model will classify it as <b>benign</b> or <b>malignant</b> based on its appearance.<br><br>
35
- <b>Disclaimer:</b> This tool is for research and educational use only. It is not a diagnostic device.
36
  """
37
 
38
- # Gradio Interface
39
  demo = gr.Interface(
40
  fn=predict,
41
  inputs=gr.Image(type="pil", label="Upload Lesion Image"),
42
- outputs=gr.Label(num_top_classes=2, label="Prediction"),
 
 
 
43
  title=title,
44
  description=description
45
  )
 
1
  import gradio as gr
2
  import torch
3
  from torchvision import models, transforms
4
+ from torch.nn import functional as F
5
  from PIL import Image
6
+ import numpy as np
7
+ import cv2
8
 
9
  # Load model
10
  model = models.resnet18()
 
12
  model.load_state_dict(torch.load("skin_cancer_resnet18_version1.pt", map_location="cpu"))
13
  model.eval()
14
 
 
15
  classes = ['benign', 'malignant']
16
 
17
+ # Preprocessing
18
  transform = transforms.Compose([
19
  transforms.Resize((224, 224)),
20
  transforms.ToTensor()
21
  ])
22
 
23
+ # Grad-CAM setup
24
+ final_conv_layer = model.layer4[1].conv2 # Adjust if using a different architecture
25
+ gradients = []
26
+
27
+ def save_gradient(module, grad_input, grad_output):
28
+ gradients.append(grad_output[0])
29
+
30
+ final_conv_layer.register_forward_hook(save_gradient)
31
+
32
+ def generate_gradcam(input_tensor, pred_class):
33
+ model.zero_grad()
34
+ output = model(input_tensor)
35
+ class_score = output[0, pred_class]
36
+ class_score.backward()
37
+
38
+ grads_val = gradients[-1].detach().numpy()[0]
39
+ activations = final_conv_layer_output.detach().numpy()[0]
40
+
41
+ weights = np.mean(grads_val, axis=(1, 2))
42
+ cam = np.zeros(activations.shape[1:], dtype=np.float32)
43
+
44
+ for i, w in enumerate(weights):
45
+ cam += w * activations[i]
46
+
47
+ cam = np.maximum(cam, 0)
48
+ cam = cv2.resize(cam, (224, 224))
49
+ cam = cam - np.min(cam)
50
+ cam = cam / np.max(cam)
51
+
52
+ return cam
53
+
54
+ # Hook to get activations
55
+ def get_activations(module, input, output):
56
+ global final_conv_layer_output
57
+ final_conv_layer_output = output
58
+
59
+ final_conv_layer.register_forward_hook(get_activations)
60
+
61
+ # Main prediction and Grad-CAM overlay function
62
  def predict(img):
63
+ gradients.clear()
64
+ img_rgb = img.convert("RGB")
65
+ input_tensor = transform(img_rgb).unsqueeze(0)
66
+
67
  with torch.no_grad():
68
  output = model(input_tensor)
69
+ probs = F.softmax(output[0], dim=0)
70
+ pred_class = torch.argmax(probs).item()
71
+
72
+ # Grad-CAM
73
+ cam = generate_gradcam(input_tensor, pred_class)
74
+
75
+ # Convert CAM to heatmap
76
+ heatmap = (cam * 255).astype(np.uint8)
77
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
78
+
79
+ # Overlay heatmap on image
80
+ img_np = np.array(img_rgb.resize((224, 224)))
81
+ overlay = cv2.addWeighted(img_np, 0.6, heatmap, 0.4, 0)
82
+
83
+ # Convert back to PIL
84
+ overlay_img = Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
85
+
86
+ return {classes[i]: float(probs[i]) for i in range(2)}, overlay_img
87
 
88
+ # UI
89
+ title = "🧠 Lumen: Skin Cancer Classifier with Grad-CAM"
90
  description = """
91
+ Upload a dermoscopic image of a mole or lesion. The model will classify it as <b>benign</b> or <b>malignant</b> and show a heatmap of what it focused on.<br><br>
92
+ <b>Disclaimer:</b> This tool is for educational use only.
 
93
  """
94
 
 
95
  demo = gr.Interface(
96
  fn=predict,
97
  inputs=gr.Image(type="pil", label="Upload Lesion Image"),
98
+ outputs=[
99
+ gr.Label(num_top_classes=2, label="Prediction"),
100
+ gr.Image(type="pil", label="Grad-CAM Visualisation")
101
+ ],
102
  title=title,
103
  description=description
104
  )