EngReem85 commited on
Commit
e288244
·
verified ·
1 Parent(s): f9d94c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -187
app.py CHANGED
@@ -1,10 +1,7 @@
1
  # -*- coding: utf-8 -*-
2
- # -*- coding: utf-8 -*-
3
  """
4
- واجهة موحّدة لتحليل صورة القدم باستخدام DFUTissueSegNet (من Google Drive)
5
- - ألوان متعددة: قرحة=أحمر، Slough=أصفر، نخر=أسود
6
- - حساب نسب كل نوع + مستوى خطورة إجمالي
7
- - Legend داخل Gradio
8
  """
9
 
10
  import os
@@ -12,233 +9,165 @@ import cv2
12
  import gdown
13
  import numpy as np
14
  from PIL import Image
15
-
16
  import torch
17
- import torch.nn as nn
18
- import torch.nn.functional as F
19
-
20
  import gradio as gr
 
21
 
22
- # ================================
23
- # إعدادات عامة
24
- # ================================
25
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
- IMG_SIZE = 512 # حجم الإدخال كما في تدريب DFUTissueSegNet
27
- THRESHOLD = 0.4 # عتبة تحويل الاحتمال إلى قناع (جرّبي 0.35..0.6 حسب بياناتك)
28
  MODEL_PATH = "best_model_5.pth"
29
-
30
- # غيّري الـ File ID أدناه لملفّك على Google Drive عند الحاجة
31
- # مثال: https://drive.google.com/file/d/FILE_ID/view => استخدمي: https://drive.google.com/uc?id=FILE_ID
32
  MODEL_URL = "https://drive.google.com/uc?id=1Ovaczsjdp3E-_gYF2pbUibDjPWAC1a6c"
33
 
34
- CLASS_NAMES = ["قرحة (Granulation)", "Slough (أنسجة ميتة جزئيًا)", "نخر (Necrotic)"]
35
  CLASS_COLORS = {
36
- "قرحة (Granulation)": (255, 0, 0), # أحمر
37
- "Slough (أنسجة ميتة جزئيًا)": (255, 255, 0), # أصفر
38
- "نخر (Necrotic)": (0, 0, 0) # أسود
39
  }
40
 
41
- # ================================
42
- # نموذج DFUTissueSegNet
43
- # ================================
44
- class ConvBlock(nn.Module):
45
- def __init__(self, in_ch, out_ch):
46
- super().__init__()
47
- self.block = nn.Sequential(
48
- nn.Conv2d(in_ch, out_ch, 3, padding=1),
49
- nn.BatchNorm2d(out_ch),
50
- nn.ReLU(inplace=True),
51
- nn.Conv2d(out_ch, out_ch, 3, padding=1),
52
- nn.BatchNorm2d(out_ch),
53
- nn.ReLU(inplace=True),
54
- )
55
- def forward(self, x):
56
- return self.block(x)
57
-
58
- class DFUTissueSegNet(nn.Module):
59
- def __init__(self, num_classes=3):
60
- super().__init__()
61
- self.enc1 = ConvBlock(3, 64); self.pool1 = nn.MaxPool2d(2)
62
- self.enc2 = ConvBlock(64, 128); self.pool2 = nn.MaxPool2d(2)
63
- self.enc3 = ConvBlock(128, 256);self.pool3 = nn.MaxPool2d(2)
64
- self.enc4 = ConvBlock(256, 512);self.pool4 = nn.MaxPool2d(2)
65
-
66
- self.center = ConvBlock(512, 1024)
67
-
68
- self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
69
- self.dec4 = ConvBlock(1024, 512)
70
- self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
71
- self.dec3 = ConvBlock(512, 256)
72
- self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
73
- self.dec2 = ConvBlock(256, 128)
74
- self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
75
- self.dec1 = ConvBlock(128, 64)
76
-
77
- self.final = nn.Conv2d(64, num_classes, 1)
78
-
79
- def forward(self, x):
80
- e1 = self.enc1(x)
81
- e2 = self.enc2(self.pool1(e1))
82
- e3 = self.enc3(self.pool2(e2))
83
- e4 = self.enc4(self.pool3(e3))
84
- c = self.center(self.pool4(e4))
85
-
86
- d4 = self.up4(c); d4 = torch.cat([d4, e4], dim=1); d4 = self.dec4(d4)
87
- d3 = self.up3(d4); d3 = torch.cat([d3, e3], dim=1); d3 = self.dec3(d3)
88
- d2 = self.up2(d3); d2 = torch.cat([d2, e2], dim=1); d2 = self.dec2(d2)
89
- d1 = self.up1(d2); d1 = torch.cat([d1, e1], dim=1); d1 = self.dec1(d1)
90
-
91
- # مخرجات احتمالية 0..1 لكل قناة
92
- return torch.sigmoid(self.final(d1))
93
-
94
  segmenter = None
95
 
96
- # ================================
97
- # تحميل النموذج (من Google Drive)
98
- # ================================
99
  def initialize_model():
 
100
  global segmenter
 
101
  if not os.path.exists(MODEL_PATH):
102
  print("📥 تحميل النموذج من Google Drive...")
103
  gdown.download(MODEL_URL, MODEL_PATH, quiet=False)
104
 
105
  try:
106
- print("🔄 تحميل DFUTissueSegNet...")
107
- segmenter = DFUTissueSegNet(num_classes=3)
108
- checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
 
 
 
 
109
 
110
- # التعامل مع checkpoints التي تحتوي state_dict
111
- if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
112
  state_dict = checkpoint["state_dict"]
113
  else:
114
  state_dict = checkpoint
115
 
116
- # تنظيف البوادئ الشائعة
117
- clean_state = {}
118
- for k, v in state_dict.items():
119
- nk = k.replace("module.", "").replace("model.", "")
120
- clean_state[nk] = v
121
-
122
- segmenter.load_state_dict(clean_state, strict=False)
123
- segmenter.to(DEVICE)
124
- segmenter.eval()
125
- print("✅ النموذج جاهز.")
126
  except Exception as e:
127
  print(f"❌ فشل تحميل النموذج: {e}")
128
  import traceback; traceback.print_exc()
129
  segmenter = None
130
 
131
- # ================================
132
  # أدوات مساعدة
133
- # ================================
134
  def ensure_rgb(np_img):
 
135
  if np_img.ndim == 2:
136
  return cv2.cvtColor(np_img, cv2.COLOR_GRAY2RGB)
137
  if np_img.shape[-1] == 4:
138
  return cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
139
  return np_img
140
 
141
- def apply_legend_markdown():
142
- return """
143
- ### 🧭 مفتاح الألوان (Legend)
144
- - 🩸 **أحمر** → نسيج قرحة (Granulation)
145
- - 🟡 **أصفر** → نسيج ميت جزئيًا (Slough)
146
- - ⚫ **أسود** → نسيج نخر (Necrotic)
147
- """
148
-
149
- # ================================
150
- # التجزئة + الحساب + التلوين
151
- # ================================
152
- def segment_and_color(pil_img: Image.Image):
153
- """
154
- يُرجع:
155
- - blended: الصورة مدموج عليها القناع اللوني
156
- - mask_rgb: القناع اللوني (RGB)
157
- - stats: نسب كل فئة + الإجمالي + مستوى الخطورة
158
- """
159
- if segmenter is None:
160
- return pil_img, pil_img, {"خطأ": "النموذج غير مهيأ"}
161
-
162
- # 1) التحضير
163
- img_np = ensure_rgb(np.array(pil_img))
164
- h, w = img_np.shape[:2]
165
-
166
- # 2) التحجيم + التطبيع كما في التدريب
167
  img_resized = cv2.resize(img_np, (IMG_SIZE, IMG_SIZE))
168
  img_norm = img_resized.astype(np.float32) / 255.0
169
- img_norm = (img_norm - 0.5) / 0.5 # (x - 0.5) / 0.5
170
-
171
- tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
172
-
173
- # 3) التنبؤ
174
- with torch.no_grad():
175
- probs = segmenter(tensor).cpu().squeeze(0).numpy() # (3, H, W) احتمالات
176
-
177
- # 4) أقنعة ثنائية لكل فئة
178
- masks = (probs >= THRESHOLD).astype(np.uint8) # 0/1
179
- # إزالة الضوضاء البسيطة
180
- kernel = np.ones((5, 5), np.uint8)
181
- for i in range(masks.shape[0]):
182
- masks[i] = cv2.morphologyEx(masks[i], cv2.MORPH_OPEN, kernel)
183
- masks[i] = cv2.morphologyEx(masks[i], cv2.MORPH_CLOSE, kernel)
184
-
185
- # 5) حساب النسب على أبعاد الإدخال ثم إعادة القياس
186
- total_pixels_input = IMG_SIZE * IMG_SIZE
187
- ratios = {
188
- CLASS_NAMES[0]: np.sum(masks[0]) / total_pixels_input * 100,
189
- CLASS_NAMES[1]: np.sum(masks[1]) / total_pixels_input * 100,
190
- CLASS_NAMES[2]: np.sum(masks[2]) / total_pixels_input * 100,
191
- }
192
- total_ratio = sum(ratios.values())
193
-
194
- # 6) إنشاء قناع لوني على حجم الإدخال ثم إعادته لحجم الصورة الأصلي
195
- color_mask = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
196
- color_mask[masks[0] == 1] = CLASS_COLORS[CLASS_NAMES[0]]
197
- color_mask[masks[1] == 1] = CLASS_COLORS[CLASS_NAMES[1]]
198
- color_mask[masks[2] == 1] = CLASS_COLORS[CLASS_NAMES[2]]
199
-
200
- mask_rgb = cv2.resize(color_mask, (w, h), interpolation=cv2.INTER_NEAREST)
201
 
202
- # 7) دمج القناع مع الصورة (ألفا ~0.5)
203
- alpha = (cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2GRAY) > 0).astype(np.float32)[..., None] * 0.5
204
- blended = (alpha * mask_rgb + (1 - alpha) * img_np).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- # 8) مستوى الخطورة
207
- if total_ratio == 0:
208
- risk = "No Risk 🟢"
209
- elif total_ratio <= 1:
210
- risk = "Low Risk 🟡"
211
- elif total_ratio <= 5:
212
- risk = "Medium Risk 🟠"
213
- else:
214
- risk = "High Risk 🔴"
215
 
216
- stats = {
217
- "نِسَب_الأنسجة": {
218
- CLASS_NAMES[0]: f"{ratios[CLASS_NAMES[0]]:.2f}%",
219
- CLASS_NAMES[1]: f"{ratios[CLASS_NAMES[1]]:.2f}%",
220
- CLASS_NAMES[2]: f"{ratios[CLASS_NAMES[2]]:.2f}%",
221
- "الإجمالي": f"{total_ratio:.2f}%",
222
- },
223
- "مستوى_الخطورة": risk,
224
- "ملاحظات": "التحليل يعتمد على DFUTissueSegNet متعدد الفئات (حجم 512 وتطبيع (x-0.5)/0.5)."
225
- }
226
 
227
- return Image.fromarray(blended), Image.fromarray(mask_rgb), stats
 
 
 
228
 
229
- # ================================
230
  # واجهة Gradio
231
- # ================================
232
  def build_ui():
233
- with gr.Blocks(title="تحليل قرحة القدم - DFUTissueSegNet", theme=gr.themes.Soft()) as demo:
234
- gr.Markdown("# 🦶 نظام تحليل قرحة القدم السكري - صورة واحدة")
235
- gr.Markdown("يعتمد على **DFUTissueSegNet** لتجزئة الأنسجة وحساب نسبها، مع تلوين واضح وLegend.")
236
 
237
  with gr.Row():
238
  with gr.Column(scale=1):
239
  input_img = gr.Image(type="pil", label="📤 ارفع صورة القدم", height=320)
240
  analyze_btn = gr.Button("🔍 بدء التحليل", variant="primary")
241
- legend = gr.Markdown(apply_legend_markdown())
242
 
243
  with gr.Column(scale=1):
244
  out_blended = gr.Image(type="pil", label="🩸 الصورة مع القناع", height=320)
@@ -246,19 +175,17 @@ def build_ui():
246
  out_json = gr.JSON(label="📊 التقرير التفصيلي")
247
 
248
  analyze_btn.click(
249
- fn=segment_and_color,
250
  inputs=[input_img],
251
  outputs=[out_blended, out_mask, out_json]
252
  )
253
  return demo
254
 
255
- # ================================
256
  # تشغيل التطبيق
257
- # ================================
258
  if __name__ == "__main__":
259
  print("🚀 تهيئة النموذج...")
260
  initialize_model()
261
  app = build_ui()
262
- # ملاحظة: على Spaces لا حاجة لـ share=True
263
- app.launch(server_name="0.0.0.0", server_port=7860, share=False)
264
-
 
1
  # -*- coding: utf-8 -*-
 
2
  """
3
+ تحليل قرحة القدم باستخدام Unet + EfficientNet-b0
4
+ النموذج من Google Drive (best_model_5.pth)
 
 
5
  """
6
 
7
  import os
 
9
  import gdown
10
  import numpy as np
11
  from PIL import Image
 
12
  import torch
 
 
 
13
  import gradio as gr
14
+ import segmentation_models_pytorch as smp
15
 
16
+ # =========================================================
17
+ # الإعدادات العامة
18
+ # =========================================================
19
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ IMG_SIZE = 512
21
+ THRESHOLD = 0.35
22
  MODEL_PATH = "best_model_5.pth"
 
 
 
23
  MODEL_URL = "https://drive.google.com/uc?id=1Ovaczsjdp3E-_gYF2pbUibDjPWAC1a6c"
24
 
25
+ CLASS_NAMES = ["قرحة (Granulation)", "Slough", "نخر (Necrosis)"]
26
  CLASS_COLORS = {
27
+ "قرحة (Granulation)": (255, 0, 0), # أحمر
28
+ "Slough": (255, 255, 0), # أصفر
29
+ "نخر (Necrosis)": (0, 0, 0) # أسود
30
  }
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  segmenter = None
33
 
34
+ # =========================================================
35
+ # تحميل النموذج
36
+ # =========================================================
37
  def initialize_model():
38
+ """تحميل نموذج Unet EfficientNet من Google Drive"""
39
  global segmenter
40
+
41
  if not os.path.exists(MODEL_PATH):
42
  print("📥 تحميل النموذج من Google Drive...")
43
  gdown.download(MODEL_URL, MODEL_PATH, quiet=False)
44
 
45
  try:
46
+ print("🔄 تحميل Unet EfficientNet...")
47
+ model = smp.Unet(
48
+ encoder_name="efficientnet-b0",
49
+ encoder_weights=None,
50
+ classes=len(CLASS_NAMES),
51
+ activation="sigmoid"
52
+ )
53
 
54
+ checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
55
+ if "state_dict" in checkpoint:
56
  state_dict = checkpoint["state_dict"]
57
  else:
58
  state_dict = checkpoint
59
 
60
+ clean_state = {k.replace("module.", "").replace("model.", ""): v for k, v in state_dict.items()}
61
+ model.load_state_dict(clean_state, strict=False)
62
+ model.to(DEVICE)
63
+ model.eval()
64
+ segmenter = model
65
+ print("✅ تم تحميل النموذج بنجاح.")
 
 
 
 
66
  except Exception as e:
67
  print(f"❌ فشل تحميل النموذج: {e}")
68
  import traceback; traceback.print_exc()
69
  segmenter = None
70
 
71
+ # =========================================================
72
  # أدوات مساعدة
73
+ # =========================================================
74
  def ensure_rgb(np_img):
75
+ """تحويل الصورة إلى RGB إذا لزم"""
76
  if np_img.ndim == 2:
77
  return cv2.cvtColor(np_img, cv2.COLOR_GRAY2RGB)
78
  if np_img.shape[-1] == 4:
79
  return cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
80
  return np_img
81
 
82
+ def preprocess_image(img: Image.Image):
83
+ """تجهيز الصورة للنموذج"""
84
+ img_np = ensure_rgb(np.array(img))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  img_resized = cv2.resize(img_np, (IMG_SIZE, IMG_SIZE))
86
  img_norm = img_resized.astype(np.float32) / 255.0
87
+ img_norm = (img_norm - 0.5) / 0.5
88
+ tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0)
89
+ return tensor.to(DEVICE), img_np
90
+
91
+ # =========================================================
92
+ # التجزئة والتحليل
93
+ # =========================================================
94
+ def analyze_image(img: Image.Image):
95
+ """تحليل صورة القدم وعرض النسب"""
96
+ if segmenter is None:
97
+ return img, img, {"خطأ": "النموذج غير مهيأ بعد."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ try:
100
+ print("🔍 بدء التحليل...")
101
+ tensor, img_np = preprocess_image(img)
102
+
103
+ with torch.no_grad():
104
+ output = segmenter(tensor).cpu().squeeze(0).numpy() # (3,H,W)
105
+
106
+ masks = (output >= THRESHOLD).astype(np.uint8)
107
+
108
+ # تنظيف الأقنعة
109
+ kernel = np.ones((5,5), np.uint8)
110
+ for i in range(masks.shape[0]):
111
+ masks[i] = cv2.morphologyEx(masks[i], cv2.MORPH_OPEN, kernel)
112
+ masks[i] = cv2.morphologyEx(masks[i], cv2.MORPH_CLOSE, kernel)
113
+
114
+ # حساب النسب
115
+ total_pixels = masks.shape[1] * masks.shape[2]
116
+ ratios = {
117
+ CLASS_NAMES[0]: np.sum(masks[0]) / total_pixels * 100,
118
+ CLASS_NAMES[1]: np.sum(masks[1]) / total_pixels * 100,
119
+ CLASS_NAMES[2]: np.sum(masks[2]) / total_pixels * 100
120
+ }
121
+ total_ratio = sum(ratios.values())
122
+
123
+ # إنشاء قناع لوني
124
+ color_mask = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8)
125
+ color_mask[masks[0] == 1] = CLASS_COLORS[CLASS_NAMES[0]]
126
+ color_mask[masks[1] == 1] = CLASS_COLORS[CLASS_NAMES[1]]
127
+ color_mask[masks[2] == 1] = CLASS_COLORS[CLASS_NAMES[2]]
128
+
129
+ color_mask = cv2.resize(color_mask, (img_np.shape[1], img_np.shape[0]))
130
+
131
+ # دمج القناع مع الصورة
132
+ alpha = 0.5
133
+ blended = cv2.addWeighted(img_np, 1 - alpha, color_mask, alpha, 0)
134
+
135
+ # تقييم الخطورة
136
+ if total_ratio == 0:
137
+ risk = "No Risk 🟢"
138
+ elif total_ratio < 1:
139
+ risk = "Low Risk 🟡"
140
+ elif total_ratio < 5:
141
+ risk = "Medium Risk 🟠"
142
+ else:
143
+ risk = "High Risk 🔴"
144
 
145
+ report = {
146
+ "نسب الأنسجة (%)": {k: f"{v:.2f}" for k, v in ratios.items()},
147
+ "إجمالي (%)": f"{total_ratio:.2f}",
148
+ "مستوى الخطورة": risk
149
+ }
 
 
 
 
150
 
151
+ print(f"📊 النتائج: {report}")
152
+ return Image.fromarray(blended), Image.fromarray(color_mask), report
 
 
 
 
 
 
 
 
153
 
154
+ except Exception as e:
155
+ print(f"❌ خطأ أثناء التحليل: {e}")
156
+ import traceback; traceback.print_exc()
157
+ return img, img, {"خطأ": str(e)}
158
 
159
+ # =========================================================
160
  # واجهة Gradio
161
+ # =========================================================
162
  def build_ui():
163
+ with gr.Blocks(title="تحليل قرحة القدم - EfficientNet Unet", theme=gr.themes.Soft()) as demo:
164
+ gr.Markdown("# 🦶 تحليل صورة القدم السكري (Unet + EfficientNet)")
165
+ gr.Markdown("الكشف عن أنواع الأنسجة المصابة (قرحة / Slough / نخر) وتقدير مستوى الخطورة.")
166
 
167
  with gr.Row():
168
  with gr.Column(scale=1):
169
  input_img = gr.Image(type="pil", label="📤 ارفع صورة القدم", height=320)
170
  analyze_btn = gr.Button("🔍 بدء التحليل", variant="primary")
 
171
 
172
  with gr.Column(scale=1):
173
  out_blended = gr.Image(type="pil", label="🩸 الصورة مع القناع", height=320)
 
175
  out_json = gr.JSON(label="📊 التقرير التفصيلي")
176
 
177
  analyze_btn.click(
178
+ fn=analyze_image,
179
  inputs=[input_img],
180
  outputs=[out_blended, out_mask, out_json]
181
  )
182
  return demo
183
 
184
+ # =========================================================
185
  # تشغيل التطبيق
186
+ # =========================================================
187
  if __name__ == "__main__":
188
  print("🚀 تهيئة النموذج...")
189
  initialize_model()
190
  app = build_ui()
191
+ app.launch(server_name="0.0.0.0", server_port=7860)