# IDRiD / app-gradio.py
# Hugging Face Space by kodetr -- commit 3f5df17 ("Update app-gradio.py", verified)
import tensorflow as tf
import gradio as gr
import numpy as np
import os
import warnings
import io
import json
import base64
from PIL import Image
import tempfile
# Silences ALL Python warnings process-wide (TensorFlow/Keras noise, but also
# deprecation and image-library warnings).
warnings.filterwarnings("ignore")
# ============================================================
# 1. LOAD MODEL (with Hugging Face compatibility)
# ============================================================
print("=" * 60)
print("🚀 LOADING MODEL FOR HUGGING FACE SPACES")
print("=" * 60)
# Candidate model locations, tried in order (working directory first, then
# /tmp). The first file that exists AND loads successfully is used.
MODEL_PATHS = [
    "model.keras",
    "./model.keras",
    "/tmp/model.keras",
]
best_model = None
for model_path in MODEL_PATHS:
    if not os.path.exists(model_path):
        continue
    try:
        print(f"📂 Trying to load model from: {model_path}")
        best_model = tf.keras.models.load_model(
            model_path,
            compile=False,    # inference only: no optimizer/loss needed
            # safe_mode=False permits deserializing Lambda/custom code stored
            # in the file, i.e. loading can execute code -- trusted files only.
            safe_mode=False,
        )
        print(f"✅ Model loaded successfully from {model_path}")
        break
    except Exception as e:
        # Fall through to the next candidate instead of crashing at import.
        print(f"❌ Failed to load from {model_path}: {e}")
# Fallback: build an untrained placeholder with the same input size and the
# same named output heads ("dr_head" / "dme_head") so the UI can still start.
if best_model is None:
    print("⚠️ No loadable model file found. Creating dummy model for demo...")
    from tensorflow.keras import layers, Model
    inputs = layers.Input(shape=(224, 224, 3))
    x = layers.GlobalAveragePooling2D()(inputs)
    dr_output = layers.Dense(5, name="dr_head")(x)
    dme_output = layers.Dense(3, name="dme_head")(x)
    best_model = Model(inputs, {"dr_head": dr_output, "dme_head": dme_output})
    # No compile(): the app only calls predict(), which does not require it.
    print("✅ Dummy model created (untrained weights -- predictions are NOT meaningful)")
# Print the architecture for debugging; purely informational.
try:
    best_model.summary()
except Exception:
    print("ℹ️ Model loaded, summary not available")
# ============================================================
# 2. CONFIG
# ============================================================
# Side length, in pixels, of the square image fed to the model.
IMG_SIZE = 224
# Class names indexed by the argmax of the matching model head, so their
# order must match the model's output order.
DR_CLASSES = [
    "No DR",
    "Mild",
    "Moderate",
    "Severe",
    "Proliferative DR",
]
DME_CLASSES = [
    "No DME",
    "Low Risk",
    "High Risk",
]
# ============================================================
# 3. PREPROCESSING FUNCTIONS
# ============================================================
def preprocess_pil_image(img):
    """Turn a PIL image into a model-ready batch of one.

    The image is forced to RGB, resized to ``IMG_SIZE`` x ``IMG_SIZE`` and
    scaled from 0-255 to 0.0-1.0.

    Args:
        img: a PIL ``Image`` in any mode.

    Returns:
        np.ndarray of dtype float32 and shape (1, IMG_SIZE, IMG_SIZE, 3).
    """
    rgb = img if img.mode == 'RGB' else img.convert('RGB')
    resized = rgb.resize((IMG_SIZE, IMG_SIZE))
    scaled = np.array(resized, dtype=np.float32) / 255.0
    # Prepend the batch axis expected by Keras predict().
    return scaled[np.newaxis, ...]
# ============================================================
# 4. SOFTMAX SAFETY
# ============================================================
def ensure_probability(x):
    """Return *x* as a float32 probability vector, applying softmax if needed.

    Scores that already look like a distribution (every value in [0, 1] and a
    total within 1e-3 of 1) are returned unchanged apart from the float32
    cast. Anything else is treated as logits and passed through a
    numerically stable softmax over the last axis.

    Args:
        x: array-like of scores (a 1-D vector in this app).

    Returns:
        np.ndarray of dtype float32 with the same shape as *x*. An empty
        input is returned as an empty float32 array instead of raising.
    """
    x = np.asarray(x, dtype=np.float32)
    if x.size == 0:
        # min()/max() raise on empty arrays; there is nothing to normalize.
        return x
    if x.min() < 0 or x.max() > 1.0 or abs(x.sum() - 1.0) > 1e-3:
        # Subtracting the max keeps exp() from overflowing on large logits.
        shifted = x - x.max(axis=-1, keepdims=True)
        exp = np.exp(shifted)
        x = exp / exp.sum(axis=-1, keepdims=True)
    return x
# ============================================================
# 5. CORE PREDICTION FUNCTION
# ============================================================
def _first_sample(pred):
    """Drop a leading batch axis: return ``pred[0]`` when *pred* has >1 dims."""
    if pred is None:
        return None
    pred = np.asarray(pred)
    return pred[0] if pred.ndim > 1 else pred


def _slice_combined(arr):
    """Split one combined output into (DR, DME) score blocks on the last axis.

    The first ``len(DR_CLASSES)`` columns are DR scores and the next
    ``len(DME_CLASSES)`` columns are DME scores (columns 0-4 and 5-7).
    """
    arr = np.asarray(arr)
    n_dr = len(DR_CLASSES)
    n_dme = len(DME_CLASSES)
    return arr[..., :n_dr], arr[..., n_dr:n_dr + n_dme]


def _split_model_outputs(preds):
    """Locate the DR and DME predictions for the first sample of *preds*.

    Supported layouts:
    * dict of named outputs: the first key containing "dr" / "dme"
      (case-insensitive) is used; a head that is still missing is taken from
      the remaining, unused keys in order;
    * list/tuple: two or more entries are (DR, DME); a single entry is a
      combined output split by ``_slice_combined``;
    * ndarray: a combined output split by ``_slice_combined``.

    Returns:
        ``(dr_pred, dme_pred)``; either may be ``None`` if not found.
    """
    dr_pred = dme_pred = None
    if isinstance(preds, dict):
        keys = list(preds.keys())
        dr_key = next((k for k in keys if "dr" in k.lower()), None)
        dme_key = next((k for k in keys if "dme" in k.lower()), None)
        # Fill each missing head from keys not already claimed, so a matched
        # head is never overwritten by a positional guess.
        spare = [k for k in keys if k not in (dr_key, dme_key)]
        if dr_key is None and spare:
            dr_key = spare.pop(0)
        if dme_key is None and spare:
            dme_key = spare.pop(0)
        if dr_key is not None:
            dr_pred = preds[dr_key]
        if dme_key is not None:
            dme_pred = preds[dme_key]
    elif isinstance(preds, (list, tuple)):
        if len(preds) >= 2:
            dr_pred, dme_pred = preds[0], preds[1]
        else:
            dr_pred, dme_pred = _slice_combined(preds[0])
    elif isinstance(preds, np.ndarray):
        dr_pred, dme_pred = _slice_combined(preds)
    return _first_sample(dr_pred), _first_sample(dme_pred)


def _recommendations(dr_name, dme_name):
    """Return the (DR, DME) clinical recommendation texts for the class names."""
    if dr_name == "No DR":
        rec_dr = "Lanjutkan pola hidup sehat dan lakukan pemeriksaan mata rutin minimal 1 tahun sekali."
    elif dr_name in ("Mild", "Moderate"):
        rec_dr = "Disarankan kontrol gula darah secara ketat dan pemeriksaan mata berkala setiap 6 bulan."
    else:  # Severe / Proliferative DR
        rec_dr = "Disarankan segera konsultasi ke dokter spesialis mata untuk evaluasi dan penanganan lebih lanjut."
    if dme_name == "No DME":
        rec_dme = "Belum ditemukan tanda edema makula diabetik, lanjutkan pemantauan rutin."
    elif dme_name == "Low Risk":
        rec_dme = "Perlu observasi ketat dan pemeriksaan lanjutan untuk mencegah progresivitas."
    else:  # High Risk
        rec_dme = "Disarankan segera mendapatkan evaluasi klinis dan terapi oleh dokter spesialis mata."
    return rec_dr, rec_dme


def predict_image(image):
    """Run the model on one image and return a structured, JSON-friendly result.

    Args:
        image: a PIL image (anything ``preprocess_pil_image`` accepts).

    Returns:
        On success::

            {"success": True,
             "predictions": {
                 "diabetic_retinopathy": {"classification", "confidence",
                                          "index", "probabilities",
                                          "recommendation"},
                 "diabetic_macular_edema": {...same keys...}}}

        where ``confidence`` is the top-class probability in percent.
        On any failure (bad image, unexpected model output, ...)::

            {"success": False, "error": "<message>"}

        Exceptions are never propagated to the caller.
    """
    try:
        img_tensor = preprocess_pil_image(image)
        # verbose=0 keeps Keras from printing a progress bar per request.
        preds = best_model.predict(img_tensor, verbose=0)

        dr_pred, dme_pred = _split_model_outputs(preds)
        # Never invent scores for a head that could not be found: zeros would
        # softmax to a uniform distribution and be shown as a real diagnosis.
        if dr_pred is None or dme_pred is None:
            raise ValueError("Could not find both DR and DME outputs in the model predictions")
        if dr_pred.size == 0 or dme_pred.size == 0:
            raise ValueError("Model output is too small to contain DR and DME scores")

        dr_probs = ensure_probability(dr_pred)
        dme_probs = ensure_probability(dme_pred)

        dr_idx = int(np.argmax(dr_probs))
        dme_idx = int(np.argmax(dme_probs))
        # More scores than class names raises IndexError -> error result below.
        dr_name = DR_CLASSES[dr_idx]
        dme_name = DME_CLASSES[dme_idx]
        dr_conf = float(dr_probs[dr_idx] * 100)
        dme_conf = float(dme_probs[dme_idx] * 100)

        rec_dr, rec_dme = _recommendations(dr_name, dme_name)

        return {
            "success": True,
            "predictions": {
                "diabetic_retinopathy": {
                    "classification": dr_name,
                    "confidence": dr_conf,
                    "index": dr_idx,
                    "probabilities": dr_probs.tolist(),
                    "recommendation": rec_dr,
                },
                "diabetic_macular_edema": {
                    "classification": dme_name,
                    "confidence": dme_conf,
                    "index": dme_idx,
                    "probabilities": dme_probs.tolist(),
                    "recommendation": rec_dme,
                },
            },
        }
    except Exception as e:
        # Deliberate catch-all: UI and API callers always receive a result dict.
        return {
            "success": False,
            "error": str(e),
        }
# ============================================================
# 6. API PREDICT FUNCTION FOR /run/predict
# ============================================================
def api_predict(image):
    """Prediction handler for the API tab / API interface; returns JSON data.

    With the ``gr.Image(type="pil")`` inputs used in this app, Gradio passes a
    PIL image (or ``None`` when nothing is uploaded). A dict carrying a
    base64 payload under ``"data"`` is also accepted for direct calls.

    Args:
        image: PIL image, ``None``, or ``{"data": <base64 str | [base64 str]>}``.

    Returns:
        dict: the ``predict_image`` / ``handle_api_json_input`` result, or
        ``{"success": False, "error": ...}``. Never raises.
    """
    try:
        if image is None:
            # Same error shape as every other failure path ("success" + "error").
            return {"success": False, "error": "No image provided"}
        if isinstance(image, dict):
            if "data" not in image:
                return {"success": False, "error": "JSON payload must contain a 'data' field"}
            return handle_api_json_input(image)
        return predict_image(image)
    except Exception as e:
        return {
            "success": False,
            "error": f"API processing error: {str(e)}",
        }
def handle_api_json_input(image_data):
    """Decode a base64 image payload and run ``predict_image`` on it.

    Args:
        image_data: dict whose ``"data"`` entry is a base64 string, or a list
            whose first element is one. A leading ``data:image...`` data-URL
            header is stripped (text after the first comma is kept).

    Returns:
        The ``predict_image`` result, or ``{"success": False, "error": ...}``
        when extracting, decoding or opening the image fails.
    """
    try:
        payload = image_data["data"]
        payload = payload[0] if isinstance(payload, list) else payload
        if payload.startswith("data:image"):
            payload = payload.split(",")[1]
        # NOTE(review): these bytes come straight from the API caller and are
        # handed to Image.open -- consider enforcing a size limit upstream.
        buffer = io.BytesIO(base64.b64decode(payload))
        return predict_image(Image.open(buffer).convert("RGB"))
    except Exception as exc:
        return {
            "success": False,
            "error": f"Base64 processing error: {exc!s}",
        }
# ============================================================
# 7. FORMAT OUTPUT FOR GRADIO UI
# ============================================================
def format_prediction_html(result):
    """Render a ``predict_image`` result dict as an HTML report for the UI.

    Args:
        result: dict with a ``"success"`` flag. On success it holds
            ``"predictions"`` with ``"diabetic_retinopathy"`` and
            ``"diabetic_macular_edema"`` entries (each with
            ``classification``, ``confidence`` in percent and
            ``recommendation``); on failure an ``"error"`` message.

    Returns:
        str: an HTML fragment -- either an error box, or the results table,
        clinical recommendations and disclaimer. All interpolated text is
        HTML-escaped.
    """
    from html import escape

    if not result.get("success"):
        # Error text can be an arbitrary exception message; escape it so it
        # cannot inject markup into the page.
        error_text = escape(str(result.get("error", "Unknown error")))
        return f"""
<div style="color: red; padding: 20px; border: 2px solid red; border-radius: 10px;">
<h3>❌ Error</h3>
<p>{error_text}</p>
</div>
"""
    preds = result["predictions"]
    dr = preds["diabetic_retinopathy"]
    dme = preds["diabetic_macular_edema"]
    # Colour per severity label; black for any label not listed.
    dr_color = {
        "No DR": "#28a745",
        "Mild": "#ffc107",
        "Moderate": "#fd7e14",
        "Severe": "#dc3545",
        "Proliferative DR": "#6f42c1",
    }.get(dr["classification"], "#000000")
    dme_color = {
        "No DME": "#28a745",
        "Low Risk": "#ffc107",
        "High Risk": "#dc3545",
    }.get(dme["classification"], "#000000")
    dr_label = escape(str(dr["classification"]))
    dme_label = escape(str(dme["classification"]))
    dr_rec = escape(str(dr["recommendation"]))
    dme_rec = escape(str(dme["recommendation"]))
    html = f"""
<div style="font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto;">
<!-- Header -->
<div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white; padding: 25px; border-radius: 15px 15px 0 0; margin-bottom: 20px;">
<h1 style="margin: 0; font-size: 32px;">🔬 HASIL DETEKSI</h1>
<p style="margin: 5px 0 0 0; font-size: 16px; opacity: 0.9;">AI-Powered Retina Analysis</p>
</div>
<!-- Results Table -->
<div style="background: white; border-radius: 10px; box-shadow: 0 4px 12px rgba(0,0,0,0.1); overflow: hidden;">
<table style="width: 100%; border-collapse: collapse;">
<thead>
<tr style="background-color: #f8f9fa;">
<th style="padding: 16px; text-align: left; border-bottom: 2px solid #dee2e6; font-size: 18px;">Kondisi</th>
<th style="padding: 16px; text-align: left; border-bottom: 2px solid #dee2e6; font-size: 18px;">Klasifikasi</th>
<th style="padding: 16px; text-align: left; border-bottom: 2px solid #dee2e6; font-size: 18px;">Tingkat Kepercayaan</th>
</tr>
</thead>
<tbody>
<tr>
<td style="padding: 16px; border-bottom: 1px solid #dee2e6; font-weight: bold;">Diabetic Retinopathy (DR)</td>
<td style="padding: 16px; border-bottom: 1px solid #dee2e6;">
<span style="color: {dr_color}; font-weight: bold; font-size: 18px;">{dr_label}</span>
</td>
<td style="padding: 16px; border-bottom: 1px solid #dee2e6;">
<div style="display: flex; align-items: center; gap: 10px;">
<div style="flex-grow: 1; background: #e9ecef; height: 20px; border-radius: 10px; overflow: hidden;">
<div style="width: {dr['confidence']}%; background: {dr_color}; height: 100%;"></div>
</div>
<span style="font-weight: bold; min-width: 60px;">{dr['confidence']:.1f}%</span>
</div>
</td>
</tr>
<tr>
<td style="padding: 16px; border-bottom: 1px solid #dee2e6; font-weight: bold;">Diabetic Macular Edema (DME)</td>
<td style="padding: 16px; border-bottom: 1px solid #dee2e6;">
<span style="color: {dme_color}; font-weight: bold; font-size: 18px;">{dme_label}</span>
</td>
<td style="padding: 16px; border-bottom: 1px solid #dee2e6;">
<div style="display: flex; align-items: center; gap: 10px;">
<div style="flex-grow: 1; background: #e9ecef; height: 20px; border-radius: 10px; overflow: hidden;">
<div style="width: {dme['confidence']}%; background: {dme_color}; height: 100%;"></div>
</div>
<span style="font-weight: bold; min-width: 60px;">{dme['confidence']:.1f}%</span>
</div>
</td>
</tr>
</tbody>
</table>
</div>
<!-- Recommendations -->
<div style="margin-top: 25px; background: white; border-radius: 10px; box-shadow: 0 4px 12px rgba(0,0,0,0.1); overflow: hidden;">
<div style="background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%); color: white; padding: 15px;">
<h3 style="margin: 0; font-size: 22px;">🩺 REKOMENDASI KLINIS</h3>
</div>
<div style="padding: 20px;">
<div style="margin-bottom: 15px;">
<h4 style="color: #333; margin-bottom: 8px;">• Diabetic Retinopathy (DR):</h4>
<p style="margin: 0; color: #555; line-height: 1.6;">{dr_rec}</p>
</div>
<div>
<h4 style="color: #333; margin-bottom: 8px;">• Diabetic Macular Edema (DME):</h4>
<p style="margin: 0; color: #555; line-height: 1.6;">{dme_rec}</p>
</div>
</div>
</div>
<!-- Disclaimer -->
<div style="margin-top: 20px; padding: 15px; background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; font-size: 14px;">
<strong>⚠️ Disclaimer:</strong> Hasil ini merupakan prediksi AI dan bukan diagnosis medis. Konsultasikan dengan dokter spesialis mata untuk diagnosis yang akurat.
</div>
</div>
"""
    return html
# ============================================================
# 8. GRADIO UI FUNCTION
# ============================================================
def gradio_predict(image):
    """Web-UI handler: turn an uploaded image into the HTML report.

    Args:
        image: PIL image from the upload component, or ``None`` when empty.

    Returns:
        str: a short prompt asking for an image when none is given, otherwise
        the HTML built by ``format_prediction_html`` from ``predict_image``.
    """
    if image is not None:
        return format_prediction_html(predict_image(image))
    return "❌ Silakan unggah gambar fundus retina"
# ============================================================
# 9. CREATE GRADIO INTERFACES
# ============================================================
# Standalone Interface for the web UI (HTML report output).
# NOTE(review): neither `web_interface` nor `api_interface` is launched or
# mounted in this file -- only `demo` is launched. Kept for anything that
# imports them. Also, `allow_flagging` was renamed `flagging_mode` in newer
# Gradio releases -- confirm against the pinned Gradio version.
web_interface = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Image(type="pil", label="📤 Upload Gambar Retina"),
    outputs=gr.HTML(label="📊 Hasil Analisis"),
    title="🩺 DETEKSI DIABETIC RETINOPATHY & DME",
    description="Sistem AI untuk Analisis Citra Fundus Retina",
    allow_flagging="never",
)
# Standalone Interface exposing the JSON API handler (intended for
# /run/predict according to the original note).
api_interface = gr.Interface(
    fn=api_predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.JSON(),
    title="API Endpoint",
    description="Use this endpoint for API calls",
    allow_flagging="never",
)
# ============================================================
# 10. MULTI TEST IMAGES
# ============================================================
# Example images for gr.Examples, each wrapped in a one-element list (one
# value per input component). Files missing from the working directory are
# skipped, so this may be empty.
TEST_IMAGES = [
    [candidate]
    for candidate in (
        "IDRiD_001test.jpg",
        "IDRiD_004test.jpg",
        "IDRiD_005test.jpg",
        "IDRiD_006test.jpg",
        "IDRiD_007test.jpg",
        "IDRiD_008test.jpg",
        "IDRiD_009test.jpg",
        "IDRiD_010test.jpg",
        "IDRiD_011test.jpg",
        "IDRiD_012test.jpg",
    )
    if os.path.exists(candidate)
]
# ============================================================
# 11. CREATE GRADIO APP WITH BLOCKS
# ============================================================
with gr.Blocks(
    title="DR & DME Detection",
    theme=gr.themes.Soft(),
) as demo:
    # Page header
    gr.Markdown("""
# 🩺 DETEKSI DIABETIC RETINOPATHY & DME
### Sistem AI untuk Analisis Citra Fundus Retina
Upload gambar fundus retina untuk mendeteksi:
- **Diabetic Retinopathy (DR)**: Kerusakan retina akibat diabetes
- **Diabetic Macular Edema (DME)**: Pembengkakan di makula
""")
    # Two tabs: an interactive HTML report and a raw-JSON API tester.
    with gr.Tabs():
        # Tab 1: Web UI
        with gr.TabItem("🌐 Web Interface"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Upload section
                    image_input = gr.Image(
                        type="pil",
                        label="📤 Upload Gambar Retina",
                        height=300,
                    )
                    upload_btn = gr.Button(
                        "🔍 Analisis Gambar",
                        variant="primary",
                        size="lg",
                    )
                    gr.Markdown("""
**Format yang didukung:** JPG, PNG, JPEG
**Ukuran rekomendasi:** 224×224 piksel
**Warna:** RGB (akan dikonversi otomatis)
""")
                with gr.Column(scale=2):
                    # Results section
                    output_html = gr.HTML(
                        label="📊 Hasil Analisis",
                        value="<div style='text-align: center; padding: 50px; color: #666;'>Hasil analisis akan muncul di sini setelah mengupload gambar.</div>",
                    )
            # Only render the examples section when at least one test image
            # exists on disk (TEST_IMAGES drops missing files).
            if TEST_IMAGES:
                gr.Markdown("### 🧪 Data Testing")
                gr.Examples(
                    examples=TEST_IMAGES,
                    inputs=image_input,
                )
            # Run the analysis on button click ...
            upload_btn.click(
                fn=gradio_predict,
                inputs=image_input,
                outputs=output_html,
            )
            # ... and also whenever the image changes (upload, example, clear).
            image_input.change(
                fn=gradio_predict,
                inputs=image_input,
                outputs=output_html,
            )
        # Tab 2: API Interface
        with gr.TabItem("🔧 API Endpoint"):
            # NOTE(review): `/run/predict` is the Gradio 3.x-style REST path;
            # newer Gradio versions expose endpoints differently -- confirm
            # against the pinned Gradio version before relying on these docs.
            gr.Markdown("""
### API Endpoint untuk Mobile App
**URL:** `/run/predict`
**Method:** POST
**Content-Type:** `multipart/form-data` atau `application/json`
""")
            with gr.Row():
                with gr.Column():
                    api_image_input = gr.Image(
                        type="pil",
                        label="Test API dengan gambar",
                    )
                    api_test_btn = gr.Button("Test API", variant="secondary")
                with gr.Column():
                    api_output = gr.JSON(
                        label="API Response",
                        value={"info": "API response akan muncul di sini"},
                    )
            # Connect the API test button to the JSON handler
            api_test_btn.click(
                fn=api_predict,
                inputs=api_image_input,
                outputs=api_output,
            )
            gr.Markdown("""
### 📋 Contoh Penggunaan API
**cURL dengan file:**
```bash
curl -X POST "https://[your-space].hf.space/run/predict" \\
-F "data=@retina_image.jpg"
```
**Python:**
```python
import requests
with open("retina_image.jpg", "rb") as f:
response = requests.post(
"https://[your-space].hf.space/run/predict",
files={"data": f}
)
print(response.json())
```
""")
# ============================================================
# 12. LAUNCH FOR HUGGING FACE
# ============================================================
if __name__ == "__main__":
    # Launch settings for Hugging Face Spaces: listen on all interfaces,
    # port 7860, no public share link, no debug mode.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": False,
    }
    demo.launch(**launch_options)