# Conn-Finnegan — commit ee88fce "Disclaimer Update" (verified)
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import numpy as np
import cv2
# Load model
# Class labels, in the order of the classifier's output logits
# (predict() pairs classes[i] with probability i).
classes = ['benign', 'malignant']
model = models.resnet18()
# Replace the ImageNet head with one output logit per entry in `classes`.
model.fc = torch.nn.Linear(model.fc.in_features, len(classes))
# weights_only=True restricts unpickling to tensors and plain containers, so
# loading the checkpoint cannot execute arbitrary pickled code.
model.load_state_dict(
    torch.load(
        "skin_cancer_resnet18_version1.pt",
        map_location="cpu",
        weights_only=True,
    )
)
# Put layers such as BatchNorm into inference behaviour; gradients are still
# computed later for Grad-CAM.
model.eval()
# Transform
# Preprocessing for every uploaded image: resize to a fixed 224x224, then
# convert the PIL image to a tensor. There is no Normalize step — presumably
# this mirrors the training pipeline; confirm before changing.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
# Grad-CAM is computed on the second convolution of block 1 in layer3.
target_layer = model.layer3[1].conv2
# Module-level slots written by forward_hook / backward_hook whenever data
# flows through `target_layer`; both start empty.
activations, gradients = None, None
def forward_hook(module, input, output):
    """Forward hook for ``target_layer``: cache its output for Grad-CAM.

    The output is detached so the cached tensor carries no autograd history.
    ``module`` and ``input`` belong to the hook signature and are unused.
    """
    global activations
    snapshot = output.detach()
    activations = snapshot
def backward_hook(module, grad_input, grad_output):
    """Backward hook for ``target_layer``: cache the gradient of its output.

    Only the first element of ``grad_output`` is kept, detached from the
    graph. ``module`` and ``grad_input`` are unused.
    """
    global gradients
    grad_wrt_output = grad_output[0]
    gradients = grad_wrt_output.detach()
# Attach the capture hooks. register_full_backward_hook replaces the
# deprecated register_backward_hook, whose grad_input/grad_output PyTorch
# documents as unreliable for modules that run several autograd operations.
target_layer.register_forward_hook(forward_hook)
target_layer.register_full_backward_hook(backward_hook)
# Grad-CAM function
def generate_gradcam(input_tensor, class_idx):
    """Compute a Grad-CAM map for ``class_idx`` on ``input_tensor``.

    Runs a fresh forward and backward pass through ``model`` so the hooks on
    ``target_layer`` populate the module-level ``activations`` and
    ``gradients``. Each activation channel is then weighted by its
    spatially averaged gradient.

    Args:
        input_tensor: batched image tensor; only batch element 0 is used
            (assumes shape (1, C, H, W) as built by ``predict`` — confirm for
            other callers).
        class_idx: index of the output logit to explain.

    Returns:
        2-D float numpy array resized to the input's spatial size (H, W) and
        scaled to [0, 1]. The array is all zeros when the map is flat (e.g.
        no positive response), instead of NaN.
    """
    model.zero_grad()
    output = model(input_tensor)
    class_score = output[0, class_idx]
    class_score.backward()

    # Average the gradients over batch and spatial dims: one weight per channel.
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])  # [C]
    weighted_activations = activations[0] * pooled_gradients[:, None, None]  # [C, H, W]
    cam = torch.sum(weighted_activations, dim=0).cpu().numpy()

    # Keep only regions that push the class score up.
    cam = np.maximum(cam, 0)
    # cv2.resize takes (width, height); match the network input's size.
    height, width = input_tensor.shape[-2:]
    cam = cv2.resize(cam, (int(width), int(height)))
    cam -= cam.min()
    peak = cam.max()
    # A flat map would otherwise divide by zero and yield an all-NaN heatmap.
    if peak > 0:
        cam /= peak
    return cam
# Full pipeline
def predict(img):
    """Classify a lesion image and return a Grad-CAM overlay.

    Args:
        img: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading.

    Returns:
        ``(scores, overlay_img)``:
        - ``scores`` maps each label in ``classes`` to its softmax probability.
        - ``overlay_img`` is a 224x224 RGB PIL image with the heatmap blended
          over the resized input.

    Raises:
        gr.Error: if no image was provided.
    """
    global activations, gradients
    if img is None:
        raise gr.Error("Please upload a lesion image.")
    # Drop anything the hooks captured on a previous request.
    activations = None
    gradients = None

    img = img.convert("RGB")
    input_tensor = transform(img).unsqueeze(0)
    # This pass only needs probabilities; generate_gradcam does its own
    # forward/backward, so don't build an autograd graph here.
    with torch.no_grad():
        output = model(input_tensor)
    probs = F.softmax(output[0], dim=0)
    pred_class = torch.argmax(probs).item()
    cam = generate_gradcam(input_tensor, pred_class)

    # Convert to heatmap. applyColorMap emits BGR while the PIL image is RGB,
    # so convert the heatmap to RGB *before* blending. Otherwise the photo's
    # red and blue channels end up swapped in the output.
    heatmap = np.uint8(255 * cam)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay: both arrays are RGB uint8 at 224x224.
    img_np = np.array(img.resize((224, 224)))
    overlay = cv2.addWeighted(img_np, 0.6, heatmap, 0.4, 0)
    overlay_img = Image.fromarray(overlay)
    scores = {label: float(probs[i]) for i, label in enumerate(classes)}
    return scores, overlay_img
# Gradio interface
# Page heading and the safety disclaimer shown under it. The disclaimer is
# built from adjacent literals so it is immune to indentation changes.
title = "🧠 Soma: Skin Cancer Classifier + Grad-CAM"
description = (
    "\n"
    "Disclaimer: This AI model was trained solely on the HAM10000 dataset, "
    "which contains images of seven pigmented lesion types (actinic "
    "keratoses/Bowen’s disease, basal cell carcinoma, benign keratosis-like "
    "lesions, dermatofibroma, melanocytic nevi, melanoma, and vascular "
    "lesions).\n"
    "It is not trained to detect other skin cancers or conditions, including "
    "Merkel cell carcinoma, invasive squamous cell carcinoma, Kaposi’s "
    "sarcoma, adnexal tumours, metastatic skin lesions, amelanotic melanoma, "
    "non-pigmented basal cell carcinoma, or rare vascular malignancies.\n"
    "Results should not be considered a medical diagnosis and must be "
    "confirmed by a qualified healthcare professional.\n"
)
# Assemble the app. One PIL image goes in. Two components come out, in the
# order predict() returns them: the probability label, then the overlay.
demo = gr.Interface(
    predict,
    gr.Image(label="Upload Lesion Image", type="pil"),
    [
        gr.Label(label="Prediction", num_top_classes=2),
        gr.Image(label="Grad-CAM Visualisation", type="pil"),
    ],
    description=description,
    title=title,
)
demo.launch()