# ulcer_detection / app.py — Hugging Face Space by EngReem85 (commit d5181cc, verified)
# -*- coding: utf-8 -*-
"""
تحليل قرحة القدم باستخدام Unet + EfficientNet-b0
النموذج من Google Drive (best_model_5.pth)
"""
import os
import cv2
import gdown
import numpy as np
from PIL import Image
import torch
import gradio as gr
import segmentation_models_pytorch as smp
# =========================================================
# Global configuration
# =========================================================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when present
IMG_SIZE = 512    # square input resolution fed to the network
THRESHOLD = 0.35  # per-class probability cutoff for binarizing the output masks
MODEL_PATH = "best_model_5.pth"  # local checkpoint filename (downloaded on first run)
MODEL_URL = "https://drive.google.com/uc?id=1Ovaczsjdp3E-_gYF2pbUibDjPWAC1a6c"
# Tissue classes, in channel order of the model output (index 0..2).
CLASS_NAMES = ["قرحة (Granulation)", "Slough", "نخر (Necrosis)"]
# Display color (RGB) per class for the overlay mask.
CLASS_COLORS = {
    "قرحة (Granulation)": (255, 0, 0),  # red
    "Slough": (255, 255, 0),            # yellow
    "نخر (Necrosis)": (0, 0, 0)         # black
}
segmenter = None  # set by initialize_model(); None means "model not loaded"
# =========================================================
# تحميل النموذج
# =========================================================
def initialize_model():
    """Download (if missing) and load the Unet/EfficientNet-b0 checkpoint.

    On success the module-level ``segmenter`` holds the eval-mode model on
    DEVICE; on any failure it is reset to ``None`` and the error is printed.
    """
    global segmenter

    # Fetch the checkpoint from Google Drive only on the first run.
    if not os.path.exists(MODEL_PATH):
        print("📥 تحميل النموذج من Google Drive...")
        gdown.download(MODEL_URL, MODEL_PATH, quiet=False)

    try:
        print("🔄 تحميل Unet EfficientNet...")
        net = smp.Unet(
            encoder_name="efficientnet-b0",
            encoder_weights=None,  # weights come from the checkpoint, not ImageNet
            classes=len(CLASS_NAMES),
            activation="sigmoid"
        )

        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        # Some training frameworks wrap the weights in a "state_dict" entry.
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
        # Strip DataParallel / wrapper prefixes so keys match the bare model.
        clean_state = {
            key.replace("module.", "").replace("model.", ""): weight
            for key, weight in state_dict.items()
        }
        net.load_state_dict(clean_state, strict=False)

        net.to(DEVICE)
        net.eval()
        segmenter = net
        print("✅ تم تحميل النموذج بنجاح.")
    except Exception as e:
        print(f"❌ فشل تحميل النموذج: {e}")
        import traceback; traceback.print_exc()
        segmenter = None
# =========================================================
# أدوات مساعدة
# =========================================================
def ensure_rgb(np_img):
    """Return *np_img* as a 3-channel RGB array.

    - Grayscale ``(H, W)`` input is replicated across three channels
      (identical to ``cv2.COLOR_GRAY2RGB``).
    - RGBA input has its alpha channel dropped without blending
      (identical to ``cv2.COLOR_RGBA2RGB``).
    - 3-channel input is returned unchanged.

    Implemented with pure NumPy so it works for any dtype, not only the
    formats ``cv2.cvtColor`` accepts.
    """
    if np_img.ndim == 2:
        # Replicate the single channel three times along a new last axis.
        return np.stack((np_img,) * 3, axis=-1)
    if np_img.shape[-1] == 4:
        # Drop alpha; copy so downstream cv2 calls get a contiguous buffer.
        return np.ascontiguousarray(np_img[..., :3])
    return np_img
def preprocess_image(img: Image.Image):
    """Resize, ImageNet-normalize and tensorize a PIL image for the segmenter.

    Returns a ``(1, 3, IMG_SIZE, IMG_SIZE)`` float64 tensor on DEVICE, plus
    the original RGB numpy image (kept for the overlay blend later).
    """
    rgb = ensure_rgb(np.array(img))
    resized = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE))
    scaled = resized.astype(np.float32) / 255.0
    # Standard ImageNet channel statistics.
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    normalized = (scaled - imagenet_mean) / imagenet_std
    # NOTE(review): the cast to float64 ("double") follows the original
    # author's claim that the model runs in float64 — confirm against the
    # checkpoint, since a float32 model would reject this input dtype.
    batch = torch.from_numpy(normalized).permute(2, 0, 1).unsqueeze(0).double()
    return batch.to(DEVICE), rgb
# =========================================================
# التجزئة والتحليل
# =========================================================
def analyze_image(img: Image.Image):
    """Segment a foot photo into tissue classes and build an overlay + report.

    Returns ``(blended_image, color_mask_image, report_dict)``. When the model
    is not loaded or inference fails, the input image is echoed back twice and
    the dict carries an error message instead of results.
    """
    if segmenter is None:
        return img, img, {"خطأ": "النموذج غير مهيأ بعد."}
    try:
        print("🔍 بدء التحليل...")
        tensor, img_np = preprocess_image(img)

        # Forward pass -> per-class probability maps, shape (3, H, W).
        with torch.no_grad():
            probs = segmenter(tensor).cpu().squeeze(0).numpy()
        masks = (probs >= THRESHOLD).astype(np.uint8)

        # Morphological open + close per class: drop speckles, fill small holes.
        kernel = np.ones((5, 5), np.uint8)
        for idx in range(masks.shape[0]):
            opened = cv2.morphologyEx(masks[idx], cv2.MORPH_OPEN, kernel)
            masks[idx] = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel)

        # Tissue coverage per class as a percentage of all mask pixels.
        total_pixels = masks.shape[1] * masks.shape[2]
        ratios = {
            name: np.sum(mask) / total_pixels * 100
            for name, mask in zip(CLASS_NAMES, masks)
        }
        total_ratio = sum(ratios.values())

        # Paint each class with its display color (later classes overwrite
        # earlier ones where masks overlap, preserving the original priority).
        color_mask = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8)
        for name, mask in zip(CLASS_NAMES, masks):
            color_mask[mask == 1] = CLASS_COLORS[name]
        color_mask = cv2.resize(color_mask, (img_np.shape[1], img_np.shape[0]))

        # 50/50 blend of the photo and the color mask.
        alpha = 0.5
        blended = cv2.addWeighted(img_np, 1 - alpha, color_mask, alpha, 0)

        # Coarse severity grading from the total affected area.
        if total_ratio == 0:
            risk = "No Risk 🟢"
        elif total_ratio < 1:
            risk = "Low Risk 🟡"
        elif total_ratio < 5:
            risk = "Medium Risk 🟠"
        else:
            risk = "High Risk 🔴"

        report = {
            "نسب الأنسجة (%)": {k: f"{v:.2f}" for k, v in ratios.items()},
            "إجمالي (%)": f"{total_ratio:.2f}",
            "مستوى الخطورة": risk
        }
        print(f"📊 النتائج: {report}")
        return Image.fromarray(blended), Image.fromarray(color_mask), report
    except Exception as e:
        print(f"❌ خطأ أثناء التحليل: {e}")
        import traceback; traceback.print_exc()
        return img, img, {"خطأ": str(e)}
# =========================================================
# واجهة Gradio
# =========================================================
def build_ui():
    """Assemble and return the Gradio Blocks interface for the analyzer."""
    with gr.Blocks(title="تحليل قرحة القدم - EfficientNet Unet", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🦶 تحليل صورة القدم السكري (Unet + EfficientNet)")
        gr.Markdown("الكشف عن أنواع الأنسجة المصابة (قرحة / Slough / نخر) وتقدير مستوى الخطورة.")
        with gr.Row():
            # Left column: image upload and the analysis trigger.
            with gr.Column(scale=1):
                image_in = gr.Image(type="pil", label="📤 ارفع صورة القدم", height=320)
                run_button = gr.Button("🔍 بدء التحليل", variant="primary")
            # Right column: blended overlay, raw color mask, and the report.
            with gr.Column(scale=1):
                blended_out = gr.Image(type="pil", label="🩸 الصورة مع القناع", height=320)
                mask_out = gr.Image(type="pil", label="🧩 القناع اللوني", height=320)
                report_out = gr.JSON(label="📊 التقرير التفصيلي")
        run_button.click(
            fn=analyze_image,
            inputs=[image_in],
            outputs=[blended_out, mask_out, report_out]
        )
    return demo
# =========================================================
# تشغيل التطبيق
# =========================================================
def _main():
    """Entry point: load the model once, then serve the Gradio app."""
    print("🚀 تهيئة النموذج...")
    initialize_model()
    build_ui().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    _main()