dielz's picture
adjust text position
3a8bc14 verified
import gradio as gr
import torch
import timm
from torchvision import transforms
import numpy as np
import os
from huggingface_hub import hf_hub_download
from PIL import Image
import cv2 # <-- TAMBAHKAN IMPORT INI
# Impor komponen yang diperlukan dari library grad-cam
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
print("Library berhasil diimpor.")

# --- 1. MODEL & TRANSFORM CONFIGURATION ---

# Output labels, indexed by the classifier's output position. The order must
# match the label order used at training time (it is alphabetical here,
# presumably torchvision ImageFolder ordering -- TODO confirm against training code).
CLASS_NAMES = ['Mild Impairment', 'Moderate Impairment', 'Non Impairment', 'Very Mild Impairment']

# Derived from CLASS_NAMES so the classifier head size and the label list can
# never drift apart.
NUM_CLASSES = len(CLASS_NAMES)

IMAGE_SIZE = (224, 224)
MODEL_NAME_TIMM = 'swin_tiny_patch4_window7_224'

# Preprocessing applied to every input image before inference:
#   1. convert to grayscale, replicated to 3 channels so the input matches the
#      3-channel model built by timm;
#   2. resize to IMAGE_SIZE;
#   3. ToTensor -> float values in [0, 1];
#   4. Normalize with mean=std=0.5 -> values in [-1, 1].
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# --- 2. LOAD TOKEN AND MODEL GLOBALLY (ONLY ONCE) ---
HF_TOKEN = os.environ.get("HF_TOKEN")


def load_model(token):
    """Download the fine-tuned Swin checkpoint from the Hugging Face Hub and build the model.

    Args:
        token: Hugging Face access token used for the download (this module
            passes the value of the ``HF_TOKEN`` environment variable).

    Returns:
        A ``(model, device)`` tuple: the timm model with the checkpoint weights
        loaded, moved to ``device`` and switched to eval mode, and the
        ``torch.device`` it lives on (CUDA when available, otherwise CPU).

    Raises:
        gr.Error: if ``token`` is empty/None, or if downloading or loading the
            checkpoint fails (the underlying exception is chained as the cause).
    """
    if not token:
        raise gr.Error("Hugging Face Token tidak ditemukan di pengaturan secrets repository!")
    try:
        print("Mengunduh model dari Hub menggunakan token...")
        model_path = hf_hub_download(
            repo_id="dielz/alzheimer-classification",
            filename="21_juni/swin_v1_99%/swin_simple_best_model.pth",
            token=token,
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = timm.create_model(MODEL_NAME_TIMM, pretrained=False, num_classes=NUM_CLASSES)
        # The checkpoint is passed straight to load_state_dict, i.e. it is a
        # plain state_dict of tensors. Restrict unpickling to tensors/containers
        # instead of allowing arbitrary objects from a downloaded file.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        print(f"Model berhasil dimuat ke device: {device}")
        return model, device
    except Exception as e:
        # Surface the failure in the Gradio UI, keeping the original traceback.
        raise gr.Error(f"Gagal memuat model: {str(e)}") from e


# Loaded at import time so every request reuses the same model instance.
MODEL, DEVICE = load_model(HF_TOKEN)
# --- 3. GRAD-CAM CONFIGURATION ---
def reshape_transform(tensor):
    """Convert Swin feature activations to channels-first (B, C, H, W) for Grad-CAM.

    Accepts either:
      * a 4-D channels-last map (B, H, W, C), which is permuted to (B, C, H, W); or
      * a 3-D token sequence (B, N, C) whose token count N is a perfect square,
        which is reshaped to a square (B, H, W, C) grid and then permuted.

    Raises:
        ValueError: if the tensor is neither 3-D nor 4-D, or if a 3-D token
            count cannot form a square grid.
    """
    if tensor.ndim == 4:
        return tensor.permute(0, 3, 1, 2)
    if tensor.ndim == 3:
        batch, num_tokens, channels = tensor.shape
        side = int(round(num_tokens ** 0.5))
        if side * side != num_tokens:
            raise ValueError(f"Cannot reshape {num_tokens} tokens into a square grid")
        return tensor.reshape(batch, side, side, channels).permute(0, 3, 1, 2)
    raise ValueError(
        f"Expected a 3-D or 4-D activation tensor, got shape {tuple(tensor.shape)}"
    )
def unnormalize_image(tensor_img, *, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo per-channel normalization and return a displayable channels-last image.

    Args:
        tensor_img: normalized image tensor, channels-first (C, H, W).
        mean: per-channel mean that was subtracted during normalization.
        std: per-channel std that the values were divided by.
            The defaults (mean=std=0.5) match the ``transforms.Normalize`` step
            used for inference in this app. Other statistics (e.g. ImageNet's)
            would not invert that transform.

    Returns:
        A new (H, W, C) tensor, computed as ``img * std + mean`` and clamped to
        [0, 1]. The input tensor is not modified.
    """
    img = tensor_img.clone().permute(1, 2, 0)
    mean_t = torch.as_tensor(mean, device=img.device)
    std_t = torch.as_tensor(std, device=img.device)
    return torch.clamp(img * std_t + mean_t, 0, 1)
# Grad-CAM hooks the model's `norm` module (presumably the final normalization
# layer of the timm Swin backbone, just before the classifier head -- TODO
# confirm). Its activations are converted to (B, C, H, W) by reshape_transform.
TARGET_LAYER = MODEL.norm
CAM_GENERATOR = GradCAM(
    model=MODEL,
    target_layers=[TARGET_LAYER],
    reshape_transform=reshape_transform,
)
# --- 4. MAIN PREDICTION & VISUALIZATION FUNCTION ---
def predict_and_visualize(input_image):
    """Classify an MRI image and overlay a Grad-CAM heatmap for the predicted class.

    Args:
        input_image: the uploaded image as a PIL image (the Gradio input is
            declared with ``type="pil"``), or None when nothing was uploaded.

    Returns:
        ``(visualization, confidences)``:
          * ``visualization``: the result of ``show_cam_on_image``, i.e. the
            input image at its original resolution with the Grad-CAM heatmap
            blended on top;
          * ``confidences``: a ``{class_name: probability}`` dict for ``gr.Label``.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_image is None:
        raise gr.Error("Mohon unggah gambar terlebih dahulu.")

    # Normalize the image mode once. Uploads may be RGBA, palette or grayscale,
    # but the overlay below needs an (H, W, 3) array.
    rgb_pil = input_image.convert("RGB")

    image_tensor = transform(rgb_pil).unsqueeze(0).to(DEVICE)

    # Plain inference needs no autograd graph. Grad-CAM runs its own
    # forward/backward pass below.
    with torch.no_grad():
        outputs = MODEL(image_tensor)
    probabilities = torch.softmax(outputs, dim=1)[0]
    confidences = {name: float(p) for name, p in zip(CLASS_NAMES, probabilities.tolist())}
    pred_label_idx = int(torch.argmax(probabilities).item())

    print("Membuat visualisasi Grad-CAM...")
    targets = [ClassifierOutputTarget(pred_label_idx)]
    # Take the first (only) image of the batch -> a 2-D heatmap.
    grayscale_cam = CAM_GENERATOR(input_tensor=image_tensor, targets=targets)[0, :]

    # Original image as float32 in [0, 1], the range show_cam_on_image accepts.
    rgb_img = np.asarray(rgb_pil, dtype=np.float32) / 255.0

    # The heatmap is computed on the resized model input, so resize it back to
    # the original image before blending. cv2.resize takes (width, height).
    target_size = (rgb_img.shape[1], rgb_img.shape[0])
    grayscale_cam_resized = cv2.resize(grayscale_cam, target_size)

    visualization = show_cam_on_image(rgb_img, grayscale_cam_resized, use_rgb=True)
    print("Visualisasi selesai.")
    return visualization, confidences
# --- 5. BUILD THE GRADIO INTERFACE ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "<h1 style='text-align: center; margin-bottom: 1rem'>🧠 Alzheimer's Disease Classification with XAI</h1>"
    )
    gr.Markdown(
        "<p style='text-align: center; color: #FFA500; font-size: 14px;'>"
        "⚠ This tool is for educational purposes only and should not be used as a substitute for a professional medical diagnosis of Alzheimer's disease."
        "</p>"
    )
    gr.Markdown(
        "<p style='text-align: center;'>Unggah gambar MRI otak untuk diklasifikasikan. Model akan menampilkan prediksi beserta visualisasi Grad-CAM.</p>"
    )
    with gr.Row():
        # Left column: image upload, submit button and example images.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Unggah Gambar MRI")
            submit_btn = gr.Button("Klasifikasikan", variant="primary")
            gr.Markdown("---")
            gr.Markdown("### Contoh Gambar")
            # One example image per class folder (relative paths).
            gr.Examples(
                examples=[
                    "inference_swin/Mild Impairment/MildImpairment (828).jpg",
                    "inference_swin/Moderate Impairment/ModerateImpairment (133).jpg",
                    "inference_swin/No Impairment/NoImpairment (1689).jpg",
                    "inference_swin/Very Mild Impairment/VeryMildImpairment (903).jpg",
                ],
                inputs=input_image,
            )
        # Right column: Grad-CAM overlay and per-class confidence scores.
        with gr.Column(scale=1):
            output_gradcam = gr.Image(label="Visualisasi Grad-CAM")
            # Show every class rather than a hard-coded count.
            output_confidence = gr.Label(num_top_classes=NUM_CLASSES, label="Confidence Scores")
    # Run prediction + Grad-CAM when the button is clicked.
    submit_btn.click(
        fn=predict_and_visualize,
        inputs=[input_image],
        outputs=[output_gradcam, output_confidence],
    )
# Launch the application when this file is executed directly.
def _main():
    """Start the Gradio server for `demo` with debug mode enabled."""
    demo.launch(debug=True)


if __name__ == "__main__":
    _main()