# app.py — Advanced XAI Skin Lesion Analyzer (Gradio demo: ViT on HAM10000)
# ====================================================================
# 0. DEPENDENCIES
# ====================================================================
# NOTE: the original line here was `!pip install --upgrade -q gradio grad-cam timm`,
# which is IPython/Jupyter magic and a SyntaxError in a plain Python script.
# Declare the dependencies in requirements.txt instead (gradio, grad-cam, timm),
# which is also how Hugging Face Spaces installs them.
# ====================================================================
# 1. IMPORTS & ENHANCED SETUP
# ====================================================================
from glob import glob

import gradio as gr
import numpy as np
import pandas as pd  # tabular payload for gr.BarPlot
import timm
import torch
import torchvision.transforms as transforms
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
# --- Rich information about each lesion type ---
# Keys are the HAM10000 short codes. Each entry carries:
#   name        - human-readable class name shown in the UI,
#   description - lay explanation displayed in the "Lesion Information" tab,
#   risk        - coarse triage label ("Low / Benign", "Moderate", "High"),
#   color       - CSS color used for the risk badge in the results pane.
lesion_info = {
    'akiec': {
        "name": "Actinic Keratoses (akiec)",
        "description": "A pre-cancerous patch of thick, scaly, or crusty skin. It is more common in fair-skinned individuals and is often associated with long-term sun exposure. While not cancer, it can develop into squamous cell carcinoma.",
        "risk": "Moderate",
        "color": "orange"
    },
    'bcc': {
        "name": "Basal Cell Carcinoma (bcc)",
        "description": "The most common form of skin cancer. It often appears as a slightly transparent bump on the sun-exposed skin. It grows slowly and rarely spreads to other parts of the body.",
        "risk": "High",
        "color": "red"
    },
    'bkl': {
        "name": "Benign Keratosis-like Lesions (bkl)",
        "description": "A group of common, non-cancerous skin growths. This category includes seborrheic keratoses, solar lentigo, and lichen planus-like keratoses. They are harmless but can sometimes resemble skin cancer.",
        "risk": "Low / Benign",
        "color": "green"
    },
    'df': {
        "name": "Dermatofibroma (df)",
        "description": "A common, benign skin nodule. It is a harmless growth within the skin, usually firm and brown to tan. It often feels like a hard lump under the skin.",
        "risk": "Low / Benign",
        "color": "green"
    },
    'mel': {
        "name": "Melanoma (mel)",
        "description": "The most serious type of skin cancer. It develops in the cells (melanocytes) that produce melanin. Melanoma can be more aggressive than other skin cancers and has a higher chance of spreading if not treated early.",
        "risk": "High",
        "color": "red"
    },
    'nv': {
        "name": "Melanocytic Nevi (nv)",
        "description": "Commonly known as moles. These are benign growths of melanocytes. While most moles are harmless, some types can be at higher risk of developing into melanoma.",
        "risk": "Low / Benign",
        "color": "green"
    },
    'vasc': {
        "name": "Vascular Lesions (vasc)",
        "description": "Skin conditions and tumors resulting from a proliferation of blood vessels. This category includes cherry angiomas, angiokeratomas, and pyogenic granulomas. They are typically benign.",
        "risk": "Low / Benign",
        "color": "green"
    }
}
# Map model output indices to short codes / display names.
# NOTE(review): this relies on dict insertion order matching the order of the
# model's 7 output logits — verify against the training label encoding.
label_to_lesion = list(lesion_info.keys())
class_names = [lesion_info[k]['name'] for k in label_to_lesion]
# --- Locate example images for the Gradio Examples widget ---
# glob() does not raise when the directory is missing — it simply returns an
# empty list — so the original try/except FileNotFoundError could never fire.
# Check the result instead, and sort for a deterministic selection.
example_paths = sorted(glob('/content/ham10000_images_part_1/*.jpg'))[:4]
if not example_paths:
    print("Could not find example images; the Examples widget will be empty.")
# --- Standard Model and Transform Setup ---
# Use the GPU when available; all tensors are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ViT-Base/16 (224x224 input) with a 7-way head matching the HAM10000 classes.
# pretrained=False because the weights come from a local fine-tuned checkpoint.
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=7)
model.to(device)
model_path = 'best_vit_model.pth'  # Make sure this model file exists!
try:
    model.load_state_dict(torch.load(model_path, map_location=device))
except FileNotFoundError:
    # Best-effort: the app still starts (with random weights) so the UI loads,
    # but predictions will be meaningless until the checkpoint is provided.
    print(f"ERROR: Model file '{model_path}' not found. The app will not work.")
    # You might want to add code here to download the model if it's hosted somewhere
model.eval()
# ImageNet normalization statistics — presumably what the ViT was fine-tuned
# with; the same mean/std are reused below to de-normalize images for display.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
def reshape_transform(tensor, height=14, width=14):
    """Convert ViT token activations (B, 1 + H*W, C) into a CNN-style map (B, C, H, W).

    The leading CLS token is dropped and the remaining patch tokens are laid
    out on the (height, width) grid that Grad-CAM expects. The 14x14 default
    matches a 224x224 input with 16x16 patches.
    """
    batch, channels = tensor.size(0), tensor.size(2)
    patch_tokens = tensor[:, 1:, :]  # discard the CLS token
    grid = patch_tokens.reshape(batch, height, width, channels)
    return grid.permute(0, 3, 1, 2)  # channels-first for CAM
# Hook Grad-CAM onto the final transformer block's first LayerNorm; ViT token
# sequences are converted to 2-D activation maps via reshape_transform.
target_layer = [model.blocks[-1].norm1]
cam = GradCAM(model=model, target_layers=target_layer, reshape_transform=reshape_transform)
# ====================================================================
# 2. ENHANCED CORE FUNCTION
# ====================================================================
def generate_cam(image_tensor, class_name):
    """Produce a Grad-CAM overlay for `class_name` on the normalized input tensor.

    Returns (visualization, markdown_title). When no class is selected, or the
    CAM computation fails, returns (None, explanatory_message) instead.
    """
    if not class_name or class_name == "None":
        return None, "Select a class to explain."
    try:
        class_idx = class_names.index(class_name)
        heatmap = cam(
            input_tensor=image_tensor,
            targets=[ClassifierOutputTarget(class_idx)],
        )[0, :]
        # Undo the ImageNet normalization so the overlay sits on a viewable image.
        base_img = image_tensor.squeeze().cpu().numpy().transpose(1, 2, 0)
        base_img = np.clip(std * base_img + mean, 0, 1)
        overlay = show_cam_on_image(base_img, heatmap, use_rgb=True)
        return overlay, f"Explanation for: **{class_name}**"
    except Exception as e:
        return None, f"Could not generate CAM: {e}"
def predict_and_analyze(input_image, confidence_threshold, explain_class_1, explain_class_2):
    """Run the full analysis pipeline for one uploaded image.

    Args:
        input_image: PIL image from the Gradio Image component (or None).
        confidence_threshold: float in [0, 1]; a warning is shown below it.
        explain_class_1 / explain_class_2: class names (or "None") to explain.

    Returns a 9-tuple matching the Gradio `outputs_list` wiring:
    (primary_result_md, risk_html, probabilities_df, cam_1, cam_title_1,
     cam_2, cam_title_2, processed_image, lesion_details_md).
    """
    if input_image is None:
        # BUG FIX: the original early return had only 8 values while the
        # wired outputs_list has 9 components, and it placed the message in
        # the wrong slot. Put the message in the primary-result slot.
        return ("Please upload an image to begin.",
                None, None, None, None, None, None, None, None)
    # --- Initial prediction ---
    image_tensor = val_test_transform(input_image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # BUG FIX: gr.BarPlot(x="label", y="confidence") expects tabular data with
    # those columns, not a plain {name: prob} dict.
    probabilities_df = pd.DataFrame({
        "label": class_names,
        "confidence": [float(p) for p in probabilities],
    })
    predicted_class_idx = probabilities.argmax().item()
    predicted_class_name = class_names[predicted_class_idx]
    confidence = probabilities[predicted_class_idx].item()
    # --- Lesion info and risk badge ---
    info = lesion_info[label_to_lesion[predicted_class_idx]]
    risk_html = f"<p style='color:{info['color']}; font-weight:bold; font-size:1.1em;'>Risk Level: {info['risk']}</p>"
    lesion_details_md = f"### {info['name']}\n\n{info['description']}"
    # --- Confidence warning ---
    warning_text = ""
    if confidence < confidence_threshold:
        warning_text = f"⚠️ **Warning:** Model confidence ({confidence:.1%}) is below your threshold of {confidence_threshold:.0%}. The prediction may be unreliable."
    primary_result_md = f"**Top Prediction:** `{predicted_class_name}`\n\n**Confidence:** `{confidence:.2%}`\n\n{warning_text}"
    # --- Grad-CAM explanations (class 1 defaults to the top prediction) ---
    if not explain_class_1 or explain_class_1 == "None":
        explain_class_1 = predicted_class_name
    cam_1, title_1 = generate_cam(image_tensor, explain_class_1)
    cam_2, title_2 = generate_cam(image_tensor, explain_class_2)
    # --- De-normalized 224x224 input for the "Model Input" tab ---
    processed_img_np = image_tensor.squeeze().cpu().numpy().transpose(1, 2, 0)
    processed_img_np = np.clip(std * processed_img_np + mean, 0, 1)
    return (
        primary_result_md,
        risk_html,
        probabilities_df,
        cam_1,
        title_1,
        cam_2,
        title_2,
        processed_img_np,
        lesion_details_md,
    )
# ====================================================================
# 3. THE POLISHED & ENHANCED GUI
# ====================================================================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="teal", secondary_hue="orange")) as demo:
    gr.Markdown("# 🩺 Advanced XAI Skin Lesion Analyzer\nUpload a dermatoscopic image to classify it with a Vision Transformer and visually interpret the model's reasoning.")
    with gr.Row():
        # Left column: image input, analysis controls, and example images.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image", sources=["upload", "webcam"], height=300)
            with gr.Accordion("⚙️ Analysis Controls", open=True):
                confidence_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Warning Threshold")
                gr.Markdown("Select classes below to generate comparative explanations.")
                with gr.Row():
                    explain_class_1 = gr.Dropdown(choices=["None"] + class_names, value="None", label="Explain Class 1", info="Defaults to top prediction.")
                    explain_class_2 = gr.Dropdown(choices=["None"] + class_names, value="None", label="Explain Class 2 (Optional)")
            with gr.Row():
                clear_btn = gr.ClearButton(value="Clear")
                submit_btn = gr.Button("Analyze Image", variant="primary")
            gr.Examples(examples=example_paths, inputs=input_image, label="Example Images")
        # Right column: prediction summary plus tabbed detail views.
        with gr.Column(scale=2):
            with gr.Row():
                primary_result = gr.Markdown(label="Primary Result")
                risk_assessment = gr.HTML(label="Risk Assessment")
            with gr.Tabs():
                with gr.TabItem("🔬 Comparative XAI Heatmaps"):
                    # Side-by-side Grad-CAM overlays for the two selected classes.
                    with gr.Row():
                        with gr.Column():
                            cam_title_1 = gr.Markdown()
                            output_cam_1 = gr.Image(label="Grad-CAM 1")
                        with gr.Column():
                            cam_title_2 = gr.Markdown()
                            output_cam_2 = gr.Image(label="Grad-CAM 2")
                with gr.TabItem("📊 Detailed Probabilities"):
                    output_probs = gr.BarPlot(x="label", y="confidence", title="Class Probabilities", y_lim=[0,1], min_width=300)
                with gr.TabItem("🖼️ Model Input"):
                    processed_image = gr.Image(label="Processed Image (224x224 Normalized)", height=300)
                    gr.Markdown("This is the image after resizing and normalization, as seen by the model.")
                with gr.TabItem("ℹ️ Lesion Information"):
                    lesion_details = gr.Markdown(label="About the Predicted Lesion")
    gr.Markdown("---")
    with gr.Accordion("About this Tool", open=False):
        gr.Markdown(
            "**Model Performance:** This tool uses a Vision Transformer (ViT) model fine-tuned on the HAM10000 dataset, achieving **~72-75%** validation accuracy. "
            "Performance can vary based on image quality and lesion type.\n\n"
            "**Explainability (XAI):** The heatmaps are generated using Grad-CAM, a technique that highlights the regions of the image the model focused on to make its prediction."
        )
    gr.Markdown("### ⚠️ **Medical Disclaimer**\nThis tool is a research demonstration and **is not a substitute for professional medical advice**. The predictions are for informational purposes only. Please consult a qualified dermatologist for any medical concerns.")
    # Define components to be cleared when the Clear button is pressed.
    components_to_clear = [
        input_image, primary_result, risk_assessment, output_probs,
        output_cam_1, cam_title_1, output_cam_2, cam_title_2,
        processed_image, lesion_details, explain_class_1, explain_class_2
    ]
    # Define outputs for the submit button.
    # NOTE(review): this order must match the 9-tuple returned by
    # predict_and_analyze — keep the two in sync when editing either.
    outputs_list = [
        primary_result, risk_assessment, output_probs,
        output_cam_1, cam_title_1, output_cam_2, cam_title_2,
        processed_image, lesion_details
    ]
    clear_btn.click(lambda: [None] * len(components_to_clear), outputs=components_to_clear)
    submit_btn.click(
        fn=predict_and_analyze,
        inputs=[input_image, confidence_threshold, explain_class_1, explain_class_2],
        outputs=outputs_list
    )
# debug=True blocks and streams logs; share=True opens a public tunnel.
demo.launch(share=True, debug=True)