| import gradio as gr |
| import torch |
| import timm |
| from torchvision import transforms |
| import numpy as np |
| import os |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
| import cv2 |
|
|
| |
| from pytorch_grad_cam import GradCAM |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
| from pytorch_grad_cam.utils.image import show_cam_on_image |
|
|
# Startup message printed once the imports above have completed.
print("Library berhasil diimpor.")


# --- Model and preprocessing configuration ---------------------------------

# timm architecture identifier used when the network is instantiated.
MODEL_NAME_TIMM = 'swin_tiny_patch4_window7_224'

# Spatial size every input image is resized to before inference.
IMAGE_SIZE = (224, 224)

# Human-readable label for each output index.
# NOTE(review): the order presumably mirrors the training label encoding --
# confirm against the training code.
CLASS_NAMES = [
    'Mild Impairment',
    'Moderate Impairment',
    'Non Impairment',
    'Very Mild Impairment',
]

# Number of output classes (one per entry in CLASS_NAMES).
NUM_CLASSES = len(CLASS_NAMES)

# Preprocessing: force 3 identical grayscale channels, resize, convert to a
# tensor, then normalize every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
|
|
| |
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")


def load_model(token):
    """Download the checkpoint from the Hugging Face Hub and build the model.

    Args:
        token: Hugging Face access token forwarded to ``hf_hub_download``.

    Returns:
        A ``(model, device)`` tuple: the timm model with the downloaded
        weights loaded, switched to eval mode, and the ``torch.device`` it
        was moved to (CUDA when available, otherwise CPU).

    Raises:
        gr.Error: If ``token`` is empty, or if downloading the checkpoint or
            loading it into the model fails for any reason.
    """
    if not token:
        raise gr.Error("Hugging Face Token tidak ditemukan di pengaturan secrets repository!")

    try:
        print("Mengunduh model dari Hub menggunakan token...")
        model_path = hf_hub_download(
            repo_id="dielz/alzheimer-classification",
            filename="21_juni/swin_v1_99%/swin_simple_best_model.pth",
            token=token,
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = timm.create_model(MODEL_NAME_TIMM, pretrained=False, num_classes=NUM_CLASSES)

        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code if the checkpoint is not trusted. The repository is
        # assumed to be under the author's control; consider passing
        # weights_only=True once the minimum supported torch version allows it.
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)

        model.to(device)
        model.eval()
        print(f"Model berhasil dimuat ke device: {device}")
        return model, device
    except Exception as e:
        # Chain explicitly so the root cause stays attached to the traceback.
        raise gr.Error(f"Gagal memuat model: {e}") from e


# Loaded once at import time and shared by every request.
MODEL, DEVICE = load_model(HF_TOKEN)
|
|
| |
def reshape_transform(tensor):
    """Rearrange transformer activations into channels-first layout.

    Grad-CAM expects ``(B, C, H, W)`` feature maps. Two input layouts are
    accepted:

    * 4-D ``(B, H, W, C)``: the channel axis is moved to position 1.
    * 3-D ``(B, N, C)``: the ``N`` tokens are folded into a square
      ``H x W`` grid (``H == W``) and then made channels-first.

    Args:
        tensor: Activation tensor in one of the layouts above.

    Returns:
        A ``(B, C, H, W)`` view of the activations.

    Raises:
        ValueError: If ``tensor`` is neither 3-D nor 4-D, or if a 3-D
            input's token count is not a perfect square.
    """
    rank = tensor.dim()

    if rank == 4:
        return tensor.permute(0, 3, 1, 2)

    if rank == 3:
        batch, num_tokens, channels = tensor.shape
        side = round(num_tokens ** 0.5)
        # Fail with a clear message instead of an opaque reshape error.
        if side * side != num_tokens:
            raise ValueError(
                f"Cannot fold {num_tokens} tokens into a square grid "
                f"(tensor shape {tuple(tensor.shape)})"
            )
        return tensor.reshape(batch, side, side, channels).permute(0, 3, 1, 2)

    raise ValueError(
        f"Expected a 3-D (B, N, C) or 4-D (B, H, W, C) tensor, "
        f"got shape {tuple(tensor.shape)}"
    )
|
|
def unnormalize_image(tensor_img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo per-channel normalization and return a channels-last image.

    The default statistics match the ``Normalize(mean=0.5, std=0.5)`` step of
    this file's preprocessing ``transform``; the ImageNet statistics used
    previously did not invert that normalization.

    Args:
        tensor_img: Normalized image tensor in ``(C, H, W)`` layout.
        mean: Per-channel mean that was subtracted during normalization.
        std: Per-channel standard deviation that was divided out.

    Returns:
        A new ``(H, W, C)`` tensor with values clamped to ``[0, 1]``.
    """
    # Channels-last so the per-channel statistics broadcast over the last axis.
    img = tensor_img.permute(1, 2, 0)
    mean_t = torch.tensor(mean, device=img.device)
    std_t = torch.tensor(std, device=img.device)
    return torch.clamp(img * std_t + mean_t, 0, 1)
|
|
# Layer whose activations and gradients Grad-CAM reads: the model's `norm`
# attribute. `reshape_transform` converts that layer's output to the
# channels-first layout before the CAM is computed.
TARGET_LAYER = MODEL.norm

# Built once at import time and reused for every prediction.
CAM_GENERATOR = GradCAM(
    model=MODEL,
    target_layers=[TARGET_LAYER],
    reshape_transform=reshape_transform,
)
|
|
| |
def predict_and_visualize(input_image):
    """Classify an uploaded image and build its Grad-CAM overlay.

    Args:
        input_image: PIL image supplied by the Gradio image component, or
            ``None`` when nothing was uploaded.

    Returns:
        A ``(visualization, confidences)`` tuple: the Grad-CAM heatmap
        blended over the input image at the image's original resolution,
        and a dict mapping each class name to its softmax probability.

    Raises:
        gr.Error: If ``input_image`` is ``None``.
    """
    if input_image is None:
        raise gr.Error("Mohon unggah gambar terlebih dahulu.")

    image_tensor = transform(input_image).unsqueeze(0).to(DEVICE)

    # The probabilities only need a plain forward pass; Grad-CAM runs its own
    # gradient-enabled pass below, so no autograd graph is needed here.
    with torch.no_grad():
        outputs = MODEL(image_tensor)
    probabilities = torch.softmax(outputs, dim=1)[0]
    confidences = {
        name: float(prob) for name, prob in zip(CLASS_NAMES, probabilities)
    }
    pred_label_idx = torch.argmax(probabilities).item()

    print("Membuat visualisasi Grad-CAM...")
    targets = [ClassifierOutputTarget(pred_label_idx)]
    grayscale_cam = CAM_GENERATOR(input_tensor=image_tensor, targets=targets)[0, :]

    # The overlay needs an H x W x 3 float image in [0, 1]. Converting to RGB
    # keeps grayscale ("L") or RGBA uploads from breaking the blend.
    rgb_img = np.asarray(input_image.convert("RGB"), dtype=np.float32) / 255.0

    # Resize the CAM to the original resolution; cv2 expects (width, height).
    target_size = (rgb_img.shape[1], rgb_img.shape[0])
    grayscale_cam_resized = cv2.resize(grayscale_cam, target_size)

    visualization = show_cam_on_image(rgb_img, grayscale_cam_resized, use_rgb=True)
    print("Visualisasi selesai.")

    return visualization, confidences
|
|
| |
# Static HTML shown at the top of the page.
_TITLE_HTML = (
    "<h1 style='text-align: center; margin-bottom: 1rem'>🧠 Alzheimer's Disease Classification with XAI</h1>"
)
_DISCLAIMER_HTML = (
    "<p style='text-align: center; color: #FFA500; font-size: 14px;'>"
    "⚠ This tool is for educational purposes only and should not be used as a substitute for a professional medical diagnosis of Alzheimer's disease."
    "</p>"
)
_INTRO_HTML = (
    "<p style='text-align: center;'>Unggah gambar MRI otak untuk diklasifikasikan. Model akan menampilkan prediksi beserta visualisasi Grad-CAM.</p>"
)

# Sample images offered below the upload widget, one per class folder.
_EXAMPLE_IMAGES = [
    "inference_swin/Mild Impairment/MildImpairment (828).jpg",
    "inference_swin/Moderate Impairment/ModerateImpairment (133).jpg",
    "inference_swin/No Impairment/NoImpairment (1689).jpg",
    "inference_swin/Very Mild Impairment/VeryMildImpairment (903).jpg",
]

# Two-column layout: upload, button and examples on the left; Grad-CAM
# overlay and confidence scores on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(_TITLE_HTML)
    gr.Markdown(_DISCLAIMER_HTML)
    gr.Markdown(_INTRO_HTML)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Unggah Gambar MRI")
            submit_btn = gr.Button("Klasifikasikan", variant="primary")
            gr.Markdown("---")
            gr.Markdown("### Contoh Gambar")
            gr.Examples(examples=_EXAMPLE_IMAGES, inputs=input_image)
        with gr.Column(scale=1):
            output_gradcam = gr.Image(label="Visualisasi Grad-CAM")
            output_confidence = gr.Label(num_top_classes=4, label="Confidence Scores")

    # Run the classifier when the button is clicked and fill both outputs.
    submit_btn.click(
        fn=predict_and_visualize,
        inputs=[input_image],
        outputs=[output_gradcam, output_confidence],
    )


if __name__ == "__main__":
    demo.launch(debug=True)