File size: 3,339 Bytes
4f19e7d
 
17cba64
4f19e7d
 
d3bf433
 
4f19e7d
 
 
 
b32662e
4f19e7d
 
 
 
17cba64
4f19e7d
 
 
 
 
17cba64
 
d3bf433
17cba64
 
 
d3bf433
17cba64
 
 
d3bf433
17cba64
 
 
 
 
 
 
 
 
d3bf433
 
17cba64
d3bf433
 
17cba64
 
 
 
 
d3bf433
 
17cba64
 
d3bf433
 
17cba64
 
 
 
 
d3bf433
17cba64
 
d3bf433
17cba64
 
 
d3bf433
 
 
17cba64
 
d3bf433
 
17cba64
 
d3bf433
 
 
 
4f19e7d
17cba64
3b6b171
add70db
3b6b171
ee88fce
 
 
add70db
 
4f19e7d
 
add70db
d3bf433
 
 
 
add70db
 
4f19e7d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import numpy as np
import cv2

# Load model
# ResNet-18 backbone with the final fully-connected layer swapped for a
# 2-way head (benign vs. malignant), weights restored from a local
# checkpoint onto CPU, then switched to inference mode.
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, 2)
# NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from a trusted source (consider weights_only=True on
# torch >= 1.13).
model.load_state_dict(torch.load("skin_cancer_resnet18_version1.pt", map_location="cpu"))
model.eval()

# Class index -> label mapping; order must match the training labels.
classes = ['benign', 'malignant']

# Preprocessing: resize to the 224x224 input ResNet-18 expects, then
# convert to a [0, 1] float tensor. No ImageNet mean/std normalisation is
# applied — presumably the checkpoint was trained the same way (confirm
# against the training pipeline).
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)

# Target layer for Grad-CAM: a mid/late conv layer whose feature maps
# still carry spatial detail.
# NOTE(review): layer4 is the more common Grad-CAM target for ResNets —
# confirm layer3[1].conv2 was a deliberate choice.
target_layer = model.layer3[1].conv2

# Store activations & gradients
# Module-level buffers: the hooks below overwrite these on every
# forward/backward pass through target_layer; generate_gradcam reads them.
activations = None
gradients = None

def forward_hook(module, layer_input, layer_output):
    """Cache the target layer's feature maps on each forward pass.

    The output tensor is detached and stored in the module-level
    ``activations`` buffer for later use by Grad-CAM.
    """
    global activations
    activations = layer_output.detach()

def backward_hook(module, grad_in, grad_out):
    """Cache the gradient flowing out of the target layer.

    ``grad_out`` is a tuple; its first element (dLoss/dActivations) is
    detached and stored in the module-level ``gradients`` buffer.
    """
    global gradients
    gradients = grad_out[0].detach()

# Attach the hooks to the chosen layer.
# register_full_backward_hook replaces the deprecated
# register_backward_hook, which could report incorrect gradients for
# modules with multiple inputs/outputs (available since torch 1.8).
target_layer.register_forward_hook(forward_hook)
target_layer.register_full_backward_hook(backward_hook)

# Grad-CAM function
def generate_gradcam(input_tensor, class_idx):
    """Compute a Grad-CAM heatmap for one class of one input image.

    Runs a gradient-enabled forward/backward pass so the registered
    hooks populate the module-level ``activations``/``gradients``
    buffers, then channel-weights the activations by the pooled
    gradients.

    Args:
        input_tensor: preprocessed image batch of shape [1, C, H, W].
        class_idx: index of the class whose score is back-propagated.

    Returns:
        A (224, 224) float numpy array normalised to [0, 1].

    Raises:
        RuntimeError: if the hooks did not fire (buffers still None).
    """
    model.zero_grad()
    output = model(input_tensor)
    class_score = output[0, class_idx]
    class_score.backward()

    if activations is None or gradients is None:
        raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")

    # Global-average-pool the gradients to one weight per channel, then
    # use them to weight the feature maps (standard Grad-CAM).
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])  # [C]
    weighted_activations = activations[0] * pooled_gradients[:, None, None]  # [C, H, W]
    cam = torch.sum(weighted_activations, dim=0).cpu().numpy()

    # ReLU, upsample to the display size, then min-max normalise.
    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (224, 224))
    cam -= cam.min()
    max_val = cam.max()
    # Guard the division: a CAM that is all zeros after ReLU would
    # otherwise produce NaNs (0/0) and break the heatmap rendering.
    if max_val > 0:
        cam /= max_val
    return cam

# Full pipeline
def predict(img):
    """Classify a lesion image and render a Grad-CAM overlay.

    Args:
        img: PIL image from Gradio (any mode; converted to RGB here).

    Returns:
        A ({class_name: probability} dict, 224x224 PIL overlay image)
        tuple matching the two Gradio output components.
    """
    global activations, gradients
    # Reset the hook buffers so stale state from a previous request
    # cannot leak into this prediction.
    activations = None
    gradients = None

    img = img.convert("RGB")
    input_tensor = transform(img).unsqueeze(0)

    # Inference-only pass for the probabilities; the gradient-enabled
    # forward happens inside generate_gradcam.
    with torch.no_grad():
        output = model(input_tensor)
    probs = F.softmax(output[0], dim=0)
    pred_class = torch.argmax(probs).item()

    cam = generate_gradcam(input_tensor, pred_class)

    # Colourise the CAM. applyColorMap returns BGR, so convert it to RGB
    # *before* blending with the RGB input image. (Blending BGR with RGB
    # and converting the blend afterwards — as the original code did —
    # swaps the photo's red/blue channels in the overlay.)
    heatmap = np.uint8(255 * cam)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Blend heatmap over the resized input, both in RGB.
    img_np = np.array(img.resize((224, 224)))
    overlay = cv2.addWeighted(img_np, 0.6, heatmap, 0.4, 0)
    overlay_img = Image.fromarray(overlay)

    return {classes[i]: float(probs[i]) for i in range(2)}, overlay_img

# Gradio interface
# UI copy: app title and the medical disclaimer rendered beneath it.
# The disclaimer scopes the model to the seven HAM10000 lesion classes
# and explicitly disclaims diagnostic use.
title = "🧠 Soma: Skin Cancer Classifier + Grad-CAM"
description = """

Disclaimer: This AI model was trained solely on the HAM10000 dataset, which contains images of seven pigmented lesion types (actinic keratoses/Bowen’s disease, basal cell carcinoma, benign keratosis-like lesions, dermatofibroma, melanocytic nevi, melanoma, and vascular lesions).
It is not trained to detect other skin cancers or conditions, including Merkel cell carcinoma, invasive squamous cell carcinoma, Kaposi’s sarcoma, adnexal tumours, metastatic skin lesions, amelanotic melanoma, non-pigmented basal cell carcinoma, or rare vascular malignancies.
Results should not be considered a medical diagnosis and must be confirmed by a qualified healthcare professional.
"""

# Assemble the Gradio app: a single image input mapped to the class
# probabilities plus the Grad-CAM overlay, then serve it.
_image_input = gr.Image(type="pil", label="Upload Lesion Image")
_outputs = [
    gr.Label(num_top_classes=2, label="Prediction"),
    gr.Image(type="pil", label="Grad-CAM Visualisation"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_outputs,
    title=title,
    description=description,
)

demo.launch()