| |
| |
| |
| |
| !pip install --upgrade -q gradio grad-cam timm |
|
|
| |
| |
| |
| import gradio as gr |
| import torch |
| import timm |
| from PIL import Image |
| import numpy as np |
| import torchvision.transforms as transforms |
| from pytorch_grad_cam import GradCAM |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
| from pytorch_grad_cam.utils.image import show_cam_on_image |
| from glob import glob |
|
|
| |
# Static reference data for the seven lesion classes, keyed by short code.
# The insertion order of this dict defines the class index order used by
# `label_to_lesion` and `class_names` below.
# NOTE(review): this order is presumably the label order the model was trained
# with; confirm against the training code.
lesion_info = {
    "akiec": {
        "name": "Actinic Keratoses (akiec)",
        "description": (
            "A pre-cancerous patch of thick, scaly, or crusty skin. "
            "It is more common in fair-skinned individuals and is often associated "
            "with long-term sun exposure. "
            "While not cancer, it can develop into squamous cell carcinoma."
        ),
        "risk": "Moderate",
        "color": "orange",
    },
    "bcc": {
        "name": "Basal Cell Carcinoma (bcc)",
        "description": (
            "The most common form of skin cancer. "
            "It often appears as a slightly transparent bump on the sun-exposed skin. "
            "It grows slowly and rarely spreads to other parts of the body."
        ),
        "risk": "High",
        "color": "red",
    },
    "bkl": {
        "name": "Benign Keratosis-like Lesions (bkl)",
        "description": (
            "A group of common, non-cancerous skin growths. "
            "This category includes seborrheic keratoses, solar lentigo, and "
            "lichen planus-like keratoses. "
            "They are harmless but can sometimes resemble skin cancer."
        ),
        "risk": "Low / Benign",
        "color": "green",
    },
    "df": {
        "name": "Dermatofibroma (df)",
        "description": (
            "A common, benign skin nodule. "
            "It is a harmless growth within the skin, usually firm and brown to tan. "
            "It often feels like a hard lump under the skin."
        ),
        "risk": "Low / Benign",
        "color": "green",
    },
    "mel": {
        "name": "Melanoma (mel)",
        "description": (
            "The most serious type of skin cancer. "
            "It develops in the cells (melanocytes) that produce melanin. "
            "Melanoma can be more aggressive than other skin cancers and has a "
            "higher chance of spreading if not treated early."
        ),
        "risk": "High",
        "color": "red",
    },
    "nv": {
        "name": "Melanocytic Nevi (nv)",
        "description": (
            "Commonly known as moles. "
            "These are benign growths of melanocytes. "
            "While most moles are harmless, some types can be at higher risk of "
            "developing into melanoma."
        ),
        "risk": "Low / Benign",
        "color": "green",
    },
    "vasc": {
        "name": "Vascular Lesions (vasc)",
        "description": (
            "Skin conditions and tumors resulting from a proliferation of blood vessels. "
            "This category includes cherry angiomas, angiokeratomas, and pyogenic granulomas. "
            "They are typically benign."
        ),
        "risk": "Low / Benign",
        "color": "green",
    },
}

# Class index -> short code, and class index -> display name.
label_to_lesion = list(lesion_info)
class_names = [entry["name"] for entry in lesion_info.values()]
|
|
| |
# Up to four sample images offered as one-click examples in the UI.
# glob() returns an empty list (it does not raise) when the directory is
# missing or has no matches, so emptiness is checked explicitly instead of
# relying on an exception handler. Sorting keeps the selection stable across
# runs, since glob() makes no ordering guarantee.
example_paths = sorted(glob('/content/ham10000_images_part_1/*.jpg'))[:4]
if not example_paths:
    print("Could not find example images, examples will be empty.")
|
|
| |
# Build the 7-class ViT classifier on the available device and load the
# fine-tuned weights from `model_path`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=7)
model.to(device)
model_path = 'best_vit_model.pth'
try:
    # NOTE(review): torch.load deserialises the checkpoint file; only load
    # checkpoints that come from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=device))
except FileNotFoundError:
    # Deliberately best-effort: the UI still starts so the layout can be
    # inspected, but the message must make clear that the model is untrained
    # (pretrained=False above), i.e. its outputs carry no meaning.
    print(
        f"ERROR: Model file '{model_path}' not found. "
        "The app will run with randomly initialised weights, so its predictions are meaningless."
    )

# Inference mode for the rest of the script.
model.eval()
|
|
# Per-channel normalisation constants. They are also read by the display code
# further down to undo the normalisation.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Inference-time preprocessing: resize to 224x224, convert to a tensor, then
# normalise each channel with `mean` / `std`.
val_test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)
|
|
def reshape_transform(tensor, height=14, width=14):
    """Rearrange a token sequence into a channels-first 2-D feature map.

    The entry at index 0 along dimension 1 is dropped (presumably the class
    token) and the remaining ``height * width`` entries are laid out row by
    row on a grid.

    Args:
        tensor: Tensor indexed as (batch, tokens, features).
        height: Number of grid rows.
        width: Number of grid columns.

    Returns:
        Tensor of shape (batch, features, height, width).
    """
    batch_size = tensor.size(0)
    feature_dim = tensor.size(2)
    patch_tokens = tensor[:, 1:, :]
    grid = patch_tokens.reshape(batch_size, height, width, feature_dim)
    # Move the feature axis in front of the two spatial axes.
    return grid.permute(0, 3, 1, 2)
# Grad-CAM is attached to the `norm1` module of the last entry in
# `model.blocks`; `reshape_transform` turns that layer's token output into the
# 2-D map layout before it is used for the heatmap.
target_layer = [model.blocks[-1].norm1]
cam = GradCAM(
    model=model,
    target_layers=target_layer,
    reshape_transform=reshape_transform,
)
|
|
| |
| |
| |
def generate_cam(image_tensor, class_name):
    """Produce a Grad-CAM overlay for one class.

    Args:
        image_tensor: Preprocessed image batch handed to ``cam`` as
            ``input_tensor`` (assumed to hold a single normalised image,
            shape (1, 3, H, W) - TODO confirm against the caller).
        class_name: Display name of the class to explain; looked up in
            ``class_names``. A falsy value or the string "None" means that
            no class was selected.

    Returns:
        A ``(overlay, caption)`` pair. ``overlay`` is the result of
        ``show_cam_on_image`` or None when no class was selected or the
        computation failed; ``caption`` is Markdown text describing the
        outcome.
    """
    no_selection = (not class_name) or class_name == "None"
    if no_selection:
        return None, "Select a class to explain."

    try:
        class_index = class_names.index(class_name)
        cam_targets = [ClassifierOutputTarget(class_index)]
        # `cam` returns one heatmap per batch entry; keep the first one.
        heatmap = cam(input_tensor=image_tensor, targets=cam_targets)[0, :]

        # Undo the normalisation so the heatmap is blended onto an image
        # whose values lie in [0, 1].
        channels_last = image_tensor.squeeze().cpu().numpy().transpose(1, 2, 0)
        display_image = np.clip(std * channels_last + mean, 0, 1)

        overlay = show_cam_on_image(display_image, heatmap, use_rgb=True)
        return overlay, f"Explanation for: **{class_name}**"
    except Exception as e:
        # Any failure is reported through the caption rather than raised.
        return None, f"Could not generate CAM: {e}"
|
|
def predict_and_analyze(input_image, confidence_threshold, explain_class_1, explain_class_2):
    """Classify a lesion image and assemble every value shown in the UI.

    Args:
        input_image: PIL image from the upload widget, or None when nothing
            has been uploaded.
        confidence_threshold: Fraction in [0, 1]; a warning is appended to the
            summary when the top probability is below it.
        explain_class_1: Class name for the first heatmap. A falsy value or
            "None" falls back to the top prediction.
        explain_class_2: Class name for the second heatmap. A falsy value or
            "None" leaves that heatmap empty.

    Returns:
        A 9-tuple, one entry per output component, in this order: summary
        Markdown, risk HTML, probability table (DataFrame with ``label`` and
        ``confidence`` columns), heatmap 1, caption 1, heatmap 2, caption 2,
        de-normalised model input image, lesion description Markdown.
        When ``input_image`` is None the first entry is a prompt to upload an
        image and the remaining eight are None.
    """
    if input_image is None:
        # Must yield exactly one value per wired output (nine). The prompt goes
        # in the first slot, which is the summary Markdown component.
        return ("Please upload an image to begin.",) + (None,) * 8

    # Imported here so this function brings its own dependency into scope.
    import pandas as pd

    image_tensor = val_test_transform(input_image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # The bar plot is configured with x="label" / y="confidence", so the data
    # is returned as a table with exactly those two columns.
    confidences = pd.DataFrame(
        {
            "label": class_names,
            "confidence": [float(prob) for prob in probabilities],
        }
    )
    predicted_class_idx = probabilities.argmax().item()
    predicted_class_name = class_names[predicted_class_idx]
    confidence = probabilities[predicted_class_idx].item()

    # Static description and risk colouring for the predicted class.
    lesion_key = label_to_lesion[predicted_class_idx]
    info = lesion_info[lesion_key]
    risk_html = f"<p style='color:{info['color']}; font-weight:bold; font-size:1.1em;'>Risk Level: {info['risk']}</p>"
    lesion_details_md = f"### {info['name']}\n\n{info['description']}"

    warning_text = ""
    if confidence < confidence_threshold:
        warning_text = f"⚠️ **Warning:** Model confidence ({confidence:.1%}) is below your threshold of {confidence_threshold:.0%}. The prediction may be unreliable."

    primary_result_md = f"**Top Prediction:** `{predicted_class_name}`\n\n**Confidence:** `{confidence:.2%}`\n\n{warning_text}"

    # First heatmap defaults to the predicted class when nothing was chosen.
    if not explain_class_1 or explain_class_1 == "None":
        explain_class_1 = predicted_class_name
    cam_1, title_1 = generate_cam(image_tensor, explain_class_1)

    # Second heatmap stays empty unless a class was chosen explicitly.
    cam_2, title_2 = generate_cam(image_tensor, explain_class_2)

    # Undo the normalisation so the resized model input can be displayed.
    processed_img_np = image_tensor.squeeze().cpu().numpy().transpose(1, 2, 0)
    processed_img_np = std * processed_img_np + mean
    processed_img_np = np.clip(processed_img_np, 0, 1)

    return (
        primary_result_md,
        risk_html,
        confidences,
        cam_1,
        title_1,
        cam_2,
        title_2,
        processed_img_np,
        lesion_details_md,
    )
|
|
| |
| |
| |
# Gradio UI: inputs and controls on the left, results on the right, followed
# by the event wiring and the app launch.
# NOTE(review): the nesting below was reconstructed from a flattened source in
# which indentation was lost. Confirm the placement of the buttons row, the
# examples and the disclaimer against the intended layout.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="teal", secondary_hue="orange")) as demo:
    gr.Markdown(
        "# 🩺 Advanced XAI Skin Lesion Analyzer\n"
        "Upload a dermatoscopic image to classify it with a Vision Transformer "
        "and visually interpret the model's reasoning."
    )

    with gr.Row():
        # ---- Left column: image input and analysis controls ----
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                sources=["upload", "webcam"],
                height=300,
            )

            with gr.Accordion("⚙️ Analysis Controls", open=True):
                confidence_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Confidence Warning Threshold",
                )
                gr.Markdown("Select classes below to generate comparative explanations.")
                with gr.Row():
                    explain_class_1 = gr.Dropdown(
                        choices=["None"] + class_names,
                        value="None",
                        label="Explain Class 1",
                        info="Defaults to top prediction.",
                    )
                    explain_class_2 = gr.Dropdown(
                        choices=["None"] + class_names,
                        value="None",
                        label="Explain Class 2 (Optional)",
                    )

            with gr.Row():
                clear_btn = gr.ClearButton(value="Clear")
                submit_btn = gr.Button("Analyze Image", variant="primary")

            gr.Examples(examples=example_paths, inputs=input_image, label="Example Images")

        # ---- Right column: prediction summary and detail tabs ----
        with gr.Column(scale=2):
            with gr.Row():
                primary_result = gr.Markdown(label="Primary Result")
                risk_assessment = gr.HTML(label="Risk Assessment")

            with gr.Tabs():
                with gr.TabItem("🔬 Comparative XAI Heatmaps"):
                    with gr.Row():
                        with gr.Column():
                            cam_title_1 = gr.Markdown()
                            output_cam_1 = gr.Image(label="Grad-CAM 1")
                        with gr.Column():
                            cam_title_2 = gr.Markdown()
                            output_cam_2 = gr.Image(label="Grad-CAM 2")

                with gr.TabItem("📊 Detailed Probabilities"):
                    output_probs = gr.BarPlot(
                        x="label",
                        y="confidence",
                        title="Class Probabilities",
                        y_lim=[0, 1],
                        min_width=300,
                    )

                with gr.TabItem("🖼️ Model Input"):
                    processed_image = gr.Image(
                        label="Processed Image (224x224 Normalized)",
                        height=300,
                    )
                    gr.Markdown("This is the image after resizing and normalization, as seen by the model.")

                with gr.TabItem("ℹ️ Lesion Information"):
                    lesion_details = gr.Markdown(label="About the Predicted Lesion")

    gr.Markdown("---")
    with gr.Accordion("About this Tool", open=False):
        gr.Markdown(
            "**Model Performance:** This tool uses a Vision Transformer (ViT) model fine-tuned on the HAM10000 dataset, achieving **~72-75%** validation accuracy. "
            "Performance can vary based on image quality and lesion type.\n\n"
            "**Explainability (XAI):** The heatmaps are generated using Grad-CAM, a technique that highlights the regions of the image the model focused on to make its prediction."
        )
    gr.Markdown(
        "### ⚠️ **Medical Disclaimer**\n"
        "This tool is a research demonstration and **is not a substitute for professional medical advice**. "
        "The predictions are for informational purposes only. "
        "Please consult a qualified dermatologist for any medical concerns."
    )

    # Everything the Clear button resets: the input, every output, and the two
    # class selectors.
    components_to_clear = [
        input_image, primary_result, risk_assessment, output_probs,
        output_cam_1, cam_title_1, output_cam_2, cam_title_2,
        processed_image, lesion_details, explain_class_1, explain_class_2,
    ]

    # Output components, in the order the analysis handler fills them.
    outputs_list = [
        primary_result, risk_assessment, output_probs,
        output_cam_1, cam_title_1, output_cam_2, cam_title_2,
        processed_image, lesion_details,
    ]

    def _blank_values():
        """Return one None per component that the Clear button resets."""
        return [None] * len(components_to_clear)

    clear_btn.click(_blank_values, outputs=components_to_clear)
    submit_btn.click(
        fn=predict_and_analyze,
        inputs=[input_image, confidence_threshold, explain_class_1, explain_class_2],
        outputs=outputs_list,
    )

# share=True requests a public link; debug=True keeps the cell attached to the
# server output.
demo.launch(share=True, debug=True)