Spaces:
Running
Running
| import tensorflow as tf | |
| import numpy as np | |
| import os | |
| import warnings | |
| import io | |
| from PIL import Image | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
warnings.filterwarnings("ignore")

# ============================================================
# 1. LOAD MODEL (with Hugging Face compatibility)
# ============================================================
# Try each candidate path in order and keep the first model that loads.
# If none loads, fall back to an untrained stand-in model so the API can still
# start. MODEL_IS_DUMMY records that case so it can be detected later.
print("=" * 60)
print("[INFO] LOADING MODEL FOR HUGGING FACE SPACES")
print("=" * 60)

MODEL_PATHS = ["model.keras", "./model.keras", "/tmp/model.keras"]

best_model = None
MODEL_IS_DUMMY = False  # set to True only when the untrained fallback is used

for model_path in MODEL_PATHS:
    if not os.path.exists(model_path):
        continue
    try:
        print(f"[INFO] Loading model from: {model_path}")
        # NOTE(review): safe_mode=False allows Keras to deserialize arbitrary
        # Python code (e.g. Lambda layers). That is acceptable only because the
        # model file ships with the Space; never point this at user uploads.
        best_model = tf.keras.models.load_model(
            model_path, compile=False, safe_mode=False
        )
        print("[OK] Model loaded successfully")
        break
    except Exception as e:
        # Report the failure and keep trying the remaining candidate paths.
        print(f"[ERROR] Failed to load {model_path}: {e}")

if best_model is None:
    # Untrained stand-in with a 224x224x3 input and a dict of two heads
    # ("dr_head": 5 units, "dme_head": 3 units), the layout predict_image
    # understands. Its weights are random, so its predictions are meaningless.
    print("[WARN] No model could be loaded; creating untrained dummy model for demo...")
    from tensorflow.keras import layers, Model

    inputs = layers.Input(shape=(224, 224, 3))
    x = layers.GlobalAveragePooling2D()(inputs)
    dr_output = layers.Dense(5, name="dr_head")(x)
    dme_output = layers.Dense(3, name="dme_head")(x)
    best_model = Model(inputs, {"dr_head": dr_output, "dme_head": dme_output})
    MODEL_IS_DUMMY = True
    print("[OK] Dummy model created (predictions are NOT meaningful)")
# ============================================================
# 2. CONFIG
# ============================================================
# Side length in pixels of the square RGB image fed to the model
# (see preprocess_pil_image).
IMG_SIZE = 224

# The label at index i names the i-th output score of the corresponding
# prediction head.
DR_CLASSES = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
DME_CLASSES = ["No DME", "Low Risk", "High Risk"]

# Hex UI color per class label, paired by position with
# DR_CLASSES + DME_CLASSES:
#   green = no finding, amber = mild / low risk, orange = moderate,
#   red = severe / high risk, purple = proliferative DR.
COLOR_MAP = dict(
    zip(
        DR_CLASSES + DME_CLASSES,
        (
            "#10b981", "#f59e0b", "#f97316", "#ef4444", "#8b5cf6",  # DR grades
            "#10b981", "#f59e0b", "#ef4444",                        # DME risk
        ),
    )
)
# ============================================================
# 3. PREDICTION FUNCTIONS
# ============================================================
def preprocess_pil_image(img):
    """Convert a PIL image into a model-ready batch containing one image.

    The image is converted to RGB when it is in another mode, resized to
    IMG_SIZE x IMG_SIZE, and scaled from 0-255 to 0.0-1.0 as float32.

    Args:
        img: A PIL ``Image`` in any mode and size.

    Returns:
        ``np.ndarray`` of shape (1, IMG_SIZE, IMG_SIZE, 3) and dtype float32.
    """
    rgb = img.convert('RGB') if img.mode != 'RGB' else img
    resized = rgb.resize((IMG_SIZE, IMG_SIZE))
    scaled = np.array(resized, dtype=np.float32) / 255.0
    # Prepend the batch axis.
    return scaled[np.newaxis, ...]
def ensure_probability(x):
    """Return ``x`` as a float32 probability array.

    If ``x`` already looks like a probability distribution (every entry in
    [0, 1] and a total within 1e-3 of 1), it is returned unchanged apart from
    the dtype cast. Otherwise it is treated as logits and passed through a
    softmax over the last axis.

    Args:
        x: Array-like of scores, normally one prediction head's 1-D output.

    Returns:
        ``np.ndarray`` of float32 with the same shape as ``x``. Empty input
        comes back as an empty array.

    Raises:
        ValueError: if ``x`` contains NaN or infinity. No meaningful class
            probabilities can be derived from such output.
    """
    x = np.asarray(x, dtype=np.float32)
    if x.size == 0:
        # min()/max() have no identity for empty arrays and would raise.
        return x
    if not np.all(np.isfinite(x)):
        # NaN slips through every comparison below, and inf softmaxes to NaN.
        raise ValueError("model output contains NaN or infinite values")
    if x.min() < 0 or x.max() > 1.0 or abs(x.sum() - 1.0) > 1e-3:
        # Numerically stable softmax: subtracting the max prevents exp overflow.
        exp = np.exp(x - x.max(axis=-1, keepdims=True))
        x = exp / exp.sum(axis=-1, keepdims=True)
    return x
# Follow-up advice returned with each predicted class. The text is Indonesian
# and is sent to clients verbatim. Roughly, in English:
#   No DR           - keep a healthy lifestyle; routine eye exam at least yearly
#   Mild / Moderate - tight blood-sugar control; eye check-up every 6 months
#   Severe / PDR    - consult an ophthalmologist promptly for further care
#   No DME          - no sign of diabetic macular edema; keep routine monitoring
#   Low Risk        - close observation and follow-up exams to prevent progression
#   High Risk       - prompt clinical evaluation and therapy by an ophthalmologist
_RECOMMENDATIONS = {
    "No DR": "Lanjutkan pola hidup sehat dan lakukan pemeriksaan mata rutin minimal 1 tahun sekali.",
    "Mild": "Disarankan kontrol gula darah secara ketat dan pemeriksaan mata berkala setiap 6 bulan.",
    "Moderate": "Disarankan kontrol gula darah secara ketat dan pemeriksaan mata berkala setiap 6 bulan.",
    "Severe": "Disarankan segera konsultasi ke dokter spesialis mata untuk evaluasi dan penanganan lebih lanjut.",
    "Proliferative DR": "Disarankan segera konsultasi ke dokter spesialis mata untuk evaluasi dan penanganan lebih lanjut.",
    "No DME": "Belum ditemukan tanda edema makula diabetik, lanjutkan pemantauan rutin.",
    "Low Risk": "Perlu observasi ketat dan pemeriksaan lanjutan untuk mencegah progresivitas.",
    "High Risk": "Disarankan segera mendapatkan evaluasi klinis dan terapi oleh dokter spesialis mata.",
}

# Fallback UI color for labels missing from COLOR_MAP.
_DEFAULT_COLOR = "#6b7280"


def _split_concatenated(arr):
    """Split one concatenated output array into (DR, DME) parts.

    The first len(DR_CLASSES) values are DR scores and the next
    len(DME_CLASSES) values are DME scores. They are taken along axis 1 for
    batched (2-D or higher) arrays and along axis 0 for 1-D arrays.
    """
    arr = np.asarray(arr)
    n_dr = len(DR_CLASSES)
    n_total = n_dr + len(DME_CLASSES)
    if arr.ndim > 1:
        return arr[:, :n_dr], arr[:, n_dr:n_total]
    return arr[:n_dr], arr[n_dr:n_total]


def _split_model_outputs(preds):
    """Locate the DR and DME head outputs in the value returned by ``predict``.

    Supported layouts:
      * dict: the first key containing "dr" / "dme" (case-insensitive). If no
        DR key exists and there are at least two entries, the first two
        entries (in dict order) are taken as (DR, DME).
      * list/tuple: two or more arrays means the first is DR and the second is
        DME. A single array is treated as one concatenated output.
      * ndarray: one concatenated output.

    Returns:
        ``(dr_pred, dme_pred)``. Either may be None if it could not be found.
    """
    dr_pred = None
    dme_pred = None
    if isinstance(preds, dict):
        dr_keys = [k for k in preds if 'dr' in k.lower()]
        dme_keys = [k for k in preds if 'dme' in k.lower()]
        if dr_keys:
            dr_pred = preds[dr_keys[0]]
        if dme_keys:
            dme_pred = preds[dme_keys[0]]
        if dr_pred is None and len(preds) >= 2:
            keys = list(preds)
            dr_pred, dme_pred = preds[keys[0]], preds[keys[1]]
    elif isinstance(preds, (list, tuple)):
        if len(preds) >= 2:
            dr_pred, dme_pred = preds[0], preds[1]
        elif preds:
            dr_pred, dme_pred = _split_concatenated(preds[0])
    elif isinstance(preds, np.ndarray):
        dr_pred, dme_pred = _split_concatenated(preds)
    return dr_pred, dme_pred


def _first_row(pred):
    """Drop the batch axis: return ``pred[0]`` for batched output, else ``pred``."""
    pred = np.asarray(pred)
    return pred[0] if pred.ndim > 1 else pred


def _describe_head(pred, class_names, head_name):
    """Build the API result dict for one prediction head.

    Raises:
        ValueError: if the head does not output exactly one score per class.
    """
    scores = _first_row(pred)
    if scores.shape != (len(class_names),):
        raise ValueError(
            f"{head_name} output has shape {scores.shape}; "
            f"expected ({len(class_names)},)"
        )
    probs = ensure_probability(scores)
    idx = int(np.argmax(probs))
    name = class_names[idx]
    return {
        "name": name,
        "confidence": float(probs[idx] * 100),
        "color": COLOR_MAP.get(name, _DEFAULT_COLOR),
        "recommendation": _RECOMMENDATIONS.get(name, ""),
    }


def predict_image(image):
    """Run the DR/DME model on one image and package the result for the API.

    Args:
        image: PIL image of the retina (any mode or size; it is normalized by
            preprocess_pil_image).

    Returns:
        On success, ``{"success": True, "dr": {...}, "dme": {...}}``. Each
        sub-dict has "name", "confidence" (percent, 0-100), "color" and
        "recommendation".
        On any failure, ``{"success": False, "error": <message>}``. Exceptions
        are never propagated to the caller.
    """
    try:
        img_tensor = preprocess_pil_image(image)
        preds = best_model.predict(img_tensor, verbose=0)
        dr_pred, dme_pred = _split_model_outputs(preds)
        # Do not substitute zeros for a missing head. Softmax turns zeros into
        # a uniform distribution, which would report a fabricated
        # "No DR" / "No DME" result.
        if dr_pred is None or dme_pred is None:
            raise ValueError(
                "could not locate DR and DME outputs in model prediction "
                f"(got {type(preds).__name__})"
            )
        return {
            "success": True,
            "dr": _describe_head(dr_pred, DR_CLASSES, "DR"),
            "dme": _describe_head(dme_pred, DME_CLASSES, "DME"),
        }
    except Exception as e:
        # API boundary: report the failure in the result dict instead of raising.
        return {"success": False, "error": str(e)}
# ============================================================
# 4. CREATE FASTAPI APP (API ONLY)
# ============================================================
app = FastAPI(
    title="DR & DME Detection API",
    description="API untuk deteksi Diabetic Retinopathy (DR) dan Diabetic Macular Edema (DME) dari gambar retina",
    version="1.0.0",
)

# Any origin may call the API. Credentials are disabled because none of the
# endpoints in this file read cookies or auth headers. FastAPI's CORS docs also
# state that "*" must not be combined with allow_credentials=True: in that
# mode Starlette echoes back every requesting origin as credential-trusted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Describe the API: its name, the available endpoints, and its version."""
    return {
        "message": "DR & DME Detection API",
        "endpoints": {
            "POST /api/predict": "Upload image for prediction",
            "GET /health": "Check API health status",
        },
        "version": "1.0.0",
    }
@app.get("/health")
async def health_check():
    """Report liveness, whether a model object exists, and the current UTC time.

    NOTE(review): module setup falls back to an untrained dummy model when no
    model file loads, so ``model_loaded`` is True whenever startup finished.
    It does not by itself confirm that trained weights are in use.
    """
    return {
        "status": "healthy",
        "model_loaded": best_model is not None,
        # np.datetime64("now") is UTC but carries no zone designator. The
        # explicit "Z" stops clients (e.g. JavaScript's Date parser) from
        # reading it as local time.
        "timestamp": str(np.datetime64("now", "s")) + "Z",
    }
@app.post("/api/predict")
async def api_predict(file: UploadFile = File(...)):
    """
    Predict DR and DME from retinal image

    - **file**: Image file (JPEG, PNG, etc.)

    Responds with 400 if the upload is not an image or cannot be decoded,
    and with 500 if inference fails.
    """
    # Local import keeps this endpoint self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # content_type may be None when the multipart part has no Content-Type.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    contents = await file.read()
    try:
        # convert() forces a full decode, so truncated or corrupt data fails here.
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Could not decode image: {e}"
        ) from e

    # Model inference is synchronous and CPU-heavy. Running it in a worker
    # thread keeps the event loop free for other requests (e.g. /health).
    result = await run_in_threadpool(predict_image, img)
    if not result.get("success"):
        raise HTTPException(
            status_code=500, detail=result.get("error", "Prediction failed")
        )
    return JSONResponse(content=result)
# ============================================================
# 5. MAIN ENTRY POINT
# ============================================================
if __name__ == "__main__":
    # 7860 is the default port Hugging Face Spaces routes traffic to.
    port = 7860

    print("\n" + "=" * 60)
    print("DR & DME Detection API Starting...")
    print("=" * 60)
    print("Health Check: https://kodetr-idrid.hf.space/health")
    print("API Docs: https://kodetr-idrid.hf.space/docs")
    print("Predict: POST https://kodetr-idrid.hf.space/api/predict")
    print("=" * 60)

    # Bind all interfaces so the server is reachable from outside the container.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info",
    )