# Author: yurista
# Commit 749940e ("Upload 2 files"), verified
import gradio as gr
import numpy as np
import cv2
import tensorflow as tf
from PIL import Image
from tensorflow.keras.layers import Average, Layer
# --- Custom layer registration ---
@tf.keras.utils.register_keras_serializable(package="Custom")
class Cast(Layer):
    """Keras layer that converts its input tensor to ``tf.float32``.

    It is registered as serializable and is also handed to ``load_model``
    through ``custom_objects`` so a saved model that references a ``Cast``
    layer can be rebuilt.
    """

    def call(self, inputs):
        # Only the dtype changes; tf.cast keeps the tensor's shape.
        return tf.cast(inputs, dtype=tf.float32)
# --- Initial configuration ---
# Saved ensemble model; this file must be uploaded to the root of the repo.
MODEL_PATH = "ensemble_efficientresnet_model.h5"
# Square working resolution used for every resize in this app.
IMG_SIZE = (224, 224)
# Label for each model output unit; position i names prediction index i,
# so the order must match the model's output layer.
CLASS_NAMES = [
    'Astrocitoma',
    'Carcinoma',
    'Ependimoma',
    'Ganglioglioma',
    'Germinoma',
    'Glioblastoma',
    'Glioma',
    'Granuloma',
    'Meduloblastoma',
    'Meningioma',
    'Neurocitoma',
    'Notumor',
    'Oligodendroglioma',
    'Papiloma',
    'Pituitary',
    'Schwannoma',
    'Tuberculoma',
]
# --- Load model ---
# Best-effort load at import time: on any failure the error is printed and
# ``model`` is set to None, which predict_and_visualize checks before use so
# the UI can report the problem instead of the app failing to start.
try:
    print("🔄 Memuat model ensemble...")
    model = tf.keras.models.load_model(
        MODEL_PATH,
        # Cast is a custom layer defined in this file; Average is passed
        # explicitly as well so both names resolve during deserialization.
        custom_objects={'Cast': Cast, 'Average': Average},
    )
    print("✅ Model berhasil dimuat.")
except Exception as e:  # top-level boundary: report and fall back to None
    print(f"❌ Error saat memuat model: {e}")
    model = None
# --- Preprocessing ---
def preprocess_for_model(pil_image):
    """Prepare an uploaded MRI image for the classifier.

    The image is converted to 3-channel RGB and resized to ``IMG_SIZE``. A
    CLAHE-enhanced grayscale copy is Otsu-thresholded, and the RGB image is
    cropped to the bounding box of the largest external contour of the
    above-threshold (brighter) pixels. That region is presumably the head or
    brain; confirm on real data. The crop is resized back to ``IMG_SIZE``. If
    no contour is found, the resized image is used uncropped.

    Args:
        pil_image: A ``PIL.Image.Image`` of any mode, or an ``np.ndarray``
            that is 2-D grayscale, HxWx1, HxWx3 RGB or HxWx4 RGBA.

    Returns:
        ``(model_input, resized_rgb, cropped_rgb)``:

        - ``model_input`` is the crop as float32 divided by 255, i.e. in
          [0, 1] for 8-bit input.
        - ``resized_rgb`` is the RGB image resized to ``IMG_SIZE``.
        - ``cropped_rgb`` is the crop at ``IMG_SIZE``.

        On any error the message is printed and ``(None, None, None)`` is
        returned.
    """
    try:
        if isinstance(pil_image, Image.Image):
            # Let PIL normalise every mode (L, P, LA, RGBA, CMYK, 1, ...) to
            # 8-bit RGB. The raw-array path below would treat palette indices
            # as grey levels, fail on "LA" and treat CMYK as RGBA.
            pil_image = pil_image.convert("RGB")
        img = np.array(pil_image)
        if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

        original_resized_img = cv2.resize(img, IMG_SIZE)

        # Locate the dominant bright region on a contrast-enhanced gray copy.
        gray = cv2.cvtColor(original_resized_img, cv2.COLOR_RGB2GRAY)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced_gray = clahe.apply(gray)
        _, thresh = cv2.threshold(
            enhanced_gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
        contours, _ = cv2.findContours(
            thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if contours:
            largest = max(contours, key=cv2.contourArea)
            x, y, w, h = cv2.boundingRect(largest)
            processed_img = cv2.resize(
                original_resized_img[y:y + h, x:x + w], IMG_SIZE
            )
        else:
            processed_img = original_resized_img

        final_img_for_model = processed_img.astype(np.float32) / 255.0
        return final_img_for_model, original_resized_img, processed_img
    except Exception as e:  # best-effort: caller checks for None
        print(f"❌ Error saat preprocessing: {e}")
        return None, None, None
# --- Saliency mapping ---
def graph_based_saliency(image):
    """Build a colour overlay highlighting rare, central intensities.

    Despite the name, this is not a graph-based (GBVS) model. Each pixel's
    saliency is the self-information ``-log p(level)`` of its grey level
    under the image's own 256-bin histogram. That value is weighted by a
    Gaussian centred on the image, with sigma equal to half the shorter side.
    The weighted map is min-max scaled to 0..255 and rendered with the JET
    colormap. It is then blended 60/40 with the image. Pixels whose scaled
    saliency is <= 100 are set to black.

    Args:
        image: RGB uint8 image (H x W x 3). It is resized to ``IMG_SIZE``.

    Returns:
        RGB uint8 overlay of size ``IMG_SIZE``. On any error the message is
        printed and the input ``image`` is returned unchanged.
    """
    try:
        # Resize the colour image (not only the gray copy) so the colormap,
        # the blend and the mask always have matching sizes.
        rgb = cv2.resize(image, IMG_SIZE)
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

        hist = cv2.calcHist([gray], [0], None, [256], [0, 256]).ravel()
        prob = hist / hist.sum()
        # Rare grey levels get high scores; 1e-8 avoids log(0).
        rarity = -np.log(prob + 1e-8)
        # Vectorised per-pixel lookup (previously a Python loop over pixels).
        saliency = rarity[gray]

        h, w = gray.shape
        ys, xs = np.mgrid[0:h, 0:w]
        cx, cy = w // 2, h // 2
        sigma = min(h, w) / 2.0
        # Centre bias: pixels far from the image centre are down-weighted.
        spatial_weight = np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2 * sigma ** 2))

        saliency_norm = cv2.normalize(
            saliency * spatial_weight, None, 0, 255, cv2.NORM_MINMAX
        ).astype(np.uint8)
        _, mask = cv2.threshold(saliency_norm, 100, 255, cv2.THRESH_BINARY)

        # applyColorMap returns BGR; convert to RGB before blending with the
        # RGB image, otherwise hot regions are displayed blue.
        heat = cv2.cvtColor(
            cv2.applyColorMap(saliency_norm, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB
        )
        blended = cv2.addWeighted(heat, 0.6, rgb, 0.4, 0)
        return cv2.bitwise_and(blended, blended, mask=mask)
    except Exception as e:  # best-effort: fall back to the plain image
        print(f"⚠ Gagal membuat saliency map: {e}")
        return image
# --- Core prediction and segmentation ---
def predict_and_visualize(input_image):
    """Gradio callback: classify an MRI image and build the visual outputs.

    Args:
        input_image: The uploaded image from the ``gr.Image(type="pil")``
            component, or ``None`` when nothing was uploaded.

    Returns:
        ``(saliency_overlay, confidences, preprocessed_image)``, in the order
        of the three output components. ``confidences`` maps every name in
        ``CLASS_NAMES`` to the model's score for it. When no image is given or
        preprocessing fails, it is a message string and both images are
        ``None``.

    Raises:
        gr.Error: If the model failed to load, or if the number of model
            outputs differs from ``len(CLASS_NAMES)``.
    """
    if model is None:
        raise gr.Error("❌ Model belum berhasil dimuat.")
    if input_image is None:
        return None, "Silakan unggah gambar MRI terlebih dahulu.", None

    model_input, resized_rgb, cropped_rgb = preprocess_for_model(input_image)
    if model_input is None:
        return None, "❌ Gagal memproses gambar.", None

    # Add a batch axis; verbose=0 stops Keras printing a progress bar per request.
    scores = model.predict(np.expand_dims(model_input, axis=0), verbose=0)[0]
    # Guard against a model/label mismatch instead of raising IndexError
    # (too few outputs) or silently dropping scores (too many).
    if len(scores) != len(CLASS_NAMES):
        raise gr.Error(
            f"❌ Model menghasilkan {len(scores)} skor kelas, "
            f"tetapi CLASS_NAMES berisi {len(CLASS_NAMES)} nama."
        )
    confidences = {name: float(score) for name, score in zip(CLASS_NAMES, scores)}

    saliency_image = graph_based_saliency(resized_rgb)
    return saliency_image, confidences, cropped_rgb
# --- Gradio interface ---
# Two-column layout: upload widget and button on the left, the three result
# views on the right. Inside gr.Blocks, components are laid out in the order
# they are created.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="teal", secondary_hue="cyan")) as demo:
    gr.Markdown("# 🧠 Klasifikasi & Segmentasi Tumor Otak (Saliency Mapping)")
    gr.Markdown("Model ensemble (ResNet50 + EfficientNetB4) dengan visualisasi ROI berbasis Saliency Mapping akurat.")
    with gr.Row():
        with gr.Column():
            # type="pil" makes the callback receive a PIL image (None when empty).
            input_image = gr.Image(type="pil", label="Unggah Gambar MRI")
            submit_btn = gr.Button("Prediksi & Segmentasi", variant="primary")
        with gr.Column():
            # NOTE(review): despite the variable name, this shows the saliency
            # overlay from graph_based_saliency, not a Grad-CAM map.
            output_gradcam = gr.Image(label="Peta Saliency (ROI Tumor)")
            # Shows the 3 highest-scoring entries of the confidences dict,
            # or the message string returned on missing/bad input.
            output_label = gr.Label(num_top_classes=3, label="Klasifikasi Tumor")
            output_preprocessed = gr.Image(label="Gambar Setelah Preprocessing")
    # The outputs list must follow the order of the tuple returned by
    # predict_and_visualize: (saliency image, confidences, preprocessed image).
    submit_btn.click(
        fn=predict_and_visualize,
        inputs=[input_image],
        outputs=[output_gradcam, output_label, output_preprocessed]
    )
# --- Run the application ---
# NOTE(review): launch() runs as soon as this module executes, including on
# import; consider an `if __name__ == "__main__":` guard if it is ever imported.
demo.launch()