EngReem85 commited on
Commit
9ee12a3
·
verified ·
1 Parent(s): 02e4199

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -188
app.py CHANGED
@@ -1,57 +1,50 @@
1
  # -*- coding: utf-8 -*-
2
- """
3
- نظام تحليل قرحة القدم السكري باستخدام DFUTissueSegNet
4
- - يعتمد فقط على التجزئة متعددة الفئات (قرحة / Slough / نخر)
5
- - يعرض الألوان + نسب كل نوع نسيج + مستوى الخطورة
6
- """
7
-
8
- import os
9
- import numpy as np
10
- from PIL import Image
11
- import cv2
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
- import gdown
 
16
  import gradio as gr
 
 
 
17
 
18
-
19
- # ======================================================
20
  # الإعدادات العامة
21
- # ======================================================
22
- segmenter = None
23
- IMG_SIZE = 224
24
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
25
 
26
-
27
- # ======================================================
28
- # تعريف النموذج DFUTissueSegNet
29
- # ======================================================
30
  class ConvBlock(nn.Module):
31
- def __init__(self, in_channels, out_channels):
32
- super(ConvBlock, self).__init__()
33
  self.block = nn.Sequential(
34
- nn.Conv2d(in_channels, out_channels, 3, padding=1),
35
- nn.BatchNorm2d(out_channels),
36
  nn.ReLU(inplace=True),
37
- nn.Conv2d(out_channels, out_channels, 3, padding=1),
38
- nn.BatchNorm2d(out_channels),
39
  nn.ReLU(inplace=True)
40
  )
 
41
  def forward(self, x):
42
  return self.block(x)
43
 
44
  class DFUTissueSegNet(nn.Module):
45
  def __init__(self, num_classes=3):
46
- super(DFUTissueSegNet, self).__init__()
47
  self.encoder1 = ConvBlock(3, 64)
48
- self.pool1 = nn.MaxPool2d(2, 2)
49
  self.encoder2 = ConvBlock(64, 128)
50
- self.pool2 = nn.MaxPool2d(2, 2)
51
  self.encoder3 = ConvBlock(128, 256)
52
- self.pool3 = nn.MaxPool2d(2, 2)
53
  self.encoder4 = ConvBlock(256, 512)
54
- self.pool4 = nn.MaxPool2d(2, 2)
55
 
56
  self.center = ConvBlock(512, 1024)
57
 
@@ -71,9 +64,9 @@ class DFUTissueSegNet(nn.Module):
71
  e2 = self.encoder2(self.pool1(e1))
72
  e3 = self.encoder3(self.pool2(e2))
73
  e4 = self.encoder4(self.pool3(e3))
74
- center = self.center(self.pool4(e4))
75
 
76
- d4 = self.up4(center)
77
  d4 = torch.cat([d4, e4], dim=1)
78
  d4 = self.dec4(d4)
79
 
@@ -89,16 +82,15 @@ class DFUTissueSegNet(nn.Module):
89
  d1 = torch.cat([d1, e1], dim=1)
90
  d1 = self.dec1(d1)
91
 
92
- out = self.final(d1)
93
- out = F.softmax(out, dim=1)
94
- return out
95
 
 
 
 
 
96
 
97
- # ======================================================
98
- # تحميل النموذج
99
- # ======================================================
100
  def initialize_model():
101
- """تحميل DFUTissueSegNet من Google Drive مع دعم checkpoint الكامل"""
102
  global segmenter
103
  MODEL_URL = "https://drive.google.com/uc?id=1Ovaczsjdp3E-_gYF2pbUibDjPWAC1a6c"
104
  MODEL_PATH = "best_model_5.pth"
@@ -112,13 +104,12 @@ def initialize_model():
112
  segmenter = DFUTissueSegNet(num_classes=3)
113
  checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
114
 
115
- # بعض النماذج تكون محفوظة داخل "state_dict"
116
  if "state_dict" in checkpoint:
117
  state_dict = checkpoint["state_dict"]
118
  else:
119
  state_dict = checkpoint
120
 
121
- # إزالة البادئة "model." أو "module." لو وجدت
122
  clean_state = {}
123
  for k, v in state_dict.items():
124
  nk = k.replace("module.", "").replace("model.", "")
@@ -134,182 +125,104 @@ def initialize_model():
134
  traceback.print_exc()
135
  segmenter = None
136
 
137
-
138
-
139
- # ======================================================
140
- # دالة التجزئة
141
- # ======================================================
142
- def segment_ulcer(pil_img: Image.Image):
143
- """تجزئة متعددة الفئات + حساب نسب كل نوع نسيج"""
144
  if segmenter is None:
145
- np_img = np.array(pil_img)
146
- return np.zeros((np_img.shape[0], np_img.shape[1], 3), dtype=np.uint8), {}
147
 
148
  try:
149
- img_np = np.array(pil_img)
150
- if img_np.ndim == 2:
 
 
 
151
  img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)
152
  elif img_np.shape[-1] == 4:
153
  img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2RGB)
154
 
 
155
  img_resized = cv2.resize(img_np, (IMG_SIZE, IMG_SIZE))
156
  img_norm = img_resized.astype(np.float32) / 255.0
 
 
157
  tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
158
 
159
  with torch.no_grad():
160
- output = segmenter(tensor)
161
- pred = output.squeeze().cpu().numpy() # (3, H, W)
162
-
163
- # تأكيد الأبعاد
164
- if pred.ndim != 3 or pred.shape[0] < 3:
165
- print("⚠️ النموذج لم يُرجع 3 قنوات. سيتم استخدام القناة الأولى فقط.")
166
- pred = np.stack([pred, np.zeros_like(pred), np.zeros_like(pred)], axis=0)
167
-
168
- pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
169
-
170
- gran = pred[0, :, :] # أحمر
171
- slough = pred[1, :, :] # أصفر
172
- nec = pred[2, :, :] # أسود
173
-
174
- th = 0.55
175
- gran_mask = (gran >= th).astype(np.uint8)
176
- slough_mask = (slough >= th).astype(np.uint8)
177
- nec_mask = (nec >= th).astype(np.uint8)
178
-
179
- kernel = np.ones((5, 5), np.uint8)
180
- for m in [gran_mask, slough_mask, nec_mask]:
181
- cv2.morphologyEx(m, cv2.MORPH_OPEN, kernel)
182
- cv2.morphologyEx(m, cv2.MORPH_CLOSE, kernel)
183
-
184
- mask_color = np.zeros((*gran_mask.shape, 3), dtype=np.uint8)
185
- mask_color[gran_mask == 1] = (255, 0, 0)
186
- mask_color[slough_mask == 1] = (255, 255, 0)
187
- mask_color[nec_mask == 1] = (0, 0, 0)
188
-
189
- mask_resized = cv2.resize(mask_color, (img_np.shape[1], img_np.shape[0]), interpolation=cv2.INTER_NEAREST)
190
-
191
- total_pixels = mask_resized.shape[0] * mask_resized.shape[1]
192
- gran_ratio = np.sum(gran_mask) / total_pixels * 100
193
- slough_ratio = np.sum(slough_mask) / total_pixels * 100
194
- nec_ratio = np.sum(nec_mask) / total_pixels * 100
195
- total_ratio = gran_ratio + slough_ratio + nec_ratio
196
-
197
- tissue_stats = {
198
- "قرحة (Granulation)": f"{gran_ratio:.2f}%",
199
- "Slough (أنسجة ميتة جزئيًا)": f"{slough_ratio:.2f}%",
200
- "نخر (Necrotic)": f"{nec_ratio:.2f}%",
201
- "الإجمالي": f"{total_ratio:.2f}%"
202
- }
203
 
204
- return mask_resized, tissue_stats
205
-
206
- except Exception as e:
207
- print(f"❌ خطأ في التجزئة: {e}")
208
- import traceback
209
- traceback.print_exc()
210
- np_img = np.array(pil_img)
211
- return np.zeros((np_img.shape[0], np_img.shape[1], 3), dtype=np.uint8), {}
212
 
 
 
 
 
 
213
 
214
- # ======================================================
215
- # تطبيق القناع اللوني
216
- # ======================================================
217
- def apply_ulcer_mask(pil_img: Image.Image, mask_color):
218
- try:
219
- base = np.array(pil_img)
220
- if base.ndim == 2:
221
- base = cv2.cvtColor(base, cv2.COLOR_GRAY2RGB)
222
- elif base.shape[-1] == 4:
223
- base = cv2.cvtColor(base, cv2.COLOR_RGBA2RGB)
224
-
225
- if mask_color.shape[:2] != base.shape[:2]:
226
- mask_color = cv2.resize(mask_color, (base.shape[1], base.shape[0]))
227
-
228
- mask_gray = cv2.cvtColor(mask_color, cv2.COLOR_RGB2GRAY)
229
- alpha = (mask_gray > 0).astype(np.float32)[..., None] * 0.5
230
- blended = (alpha * mask_color + (1 - alpha) * base).astype(np.uint8)
231
- return Image.fromarray(blended)
232
- except Exception as e:
233
- print(f"❌ خطأ في تطبيق القناع: {e}")
234
- return pil_img
235
 
 
 
 
236
 
237
- # ======================================================
238
- # تحليل الصورة
239
- # ======================================================
240
- def analyze_single_image(pil_img: Image.Image):
241
- if pil_img is None:
242
- return None, None, {"خطأ": "لم يتم رفع صورة"}
243
-
244
- try:
245
- mask_color, tissue_stats = segment_ulcer(pil_img)
246
- blended = apply_ulcer_mask(pil_img, mask_color)
247
-
248
- total_ratio = float(tissue_stats.get("الإجمالي", "0").replace("%", ""))
249
  if total_ratio == 0:
250
- level, emoji = "No Risk", "🟢"
251
- elif total_ratio <= 1:
252
- level, emoji = "Low Risk", "🟡"
253
- elif total_ratio <= 5:
254
- level, emoji = "Medium Risk", "🟠"
255
  else:
256
- level, emoji = "High Risk", "🔴"
257
 
258
  report = {
259
- "مستوى_الخطورة": f"{level} {emoji}",
260
- "نسب_الأنسجة": tissue_stats,
261
- "ملاحظات": "تم التحليل اعتمادًا على DFUTissueSegNet متعدد الفئات."
 
 
262
  }
263
 
264
- return blended, Image.fromarray(mask_color), report
 
265
 
266
  except Exception as e:
267
  print(f"❌ خطأ أثناء التحليل: {e}")
268
  import traceback
269
  traceback.print_exc()
270
- return pil_img, None, {"خطأ": str(e)}
271
-
272
 
273
- # ======================================================
274
  # واجهة Gradio
275
- # ======================================================
276
- def build_ui():
277
- with gr.Blocks(title="تحليل قرحة القدم السكري", theme=gr.themes.Soft()) as demo:
278
- gr.Markdown("# 🦶 نظام تحليل قرحة القدم السكري - DFUTissueSegNet")
279
- gr.Markdown("### يعتمد التحليل على التجزئة لتحديد أنواع الأنسجة المتضررة ونسبة كل نوع")
280
-
281
- with gr.Row():
282
- with gr.Column():
283
- input_img = gr.Image(type="pil", label="📤 ارفع صورة القدم")
284
- analyze_btn = gr.Button("🔍 بدء التحليل", variant="primary")
285
-
286
- with gr.Column():
287
- output_img = gr.Image(type="pil", label="🩸 الصورة مع القناع", height=320)
288
- mask_img = gr.Image(type="pil", label="🧩 القناع اللوني", height=320)
289
- json_out = gr.JSON(label="📊 التقرير التفصيلي")
290
-
291
- gr.Markdown("""
292
- ---
293
- ### 🧭 مفتاح الألوان (Legend)
294
- - 🩸 **أحمر** → نسيج قرحة (Granulation)
295
- - 🟡 **أصفر** → نسيج ميت جزئيًا (Slough)
296
- - ⚫ **أسود** → نسيج نخر (Necrotic)
297
- ---
298
- """)
299
-
300
- analyze_btn.click(
301
- fn=analyze_single_image,
302
- inputs=[input_img],
303
- outputs=[output_img, mask_img, json_out]
304
- )
305
- return demo
306
 
 
307
 
308
- # ======================================================
309
- # ��شغيل النظام
310
- # ======================================================
311
  if __name__ == "__main__":
312
- print("🚀 تهيئة نموذج DFUTissueSegNet...")
313
- initialize_model()
314
- demo = build_ui()
315
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
1
  # -*- coding: utf-8 -*-
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ import cv2
6
+ import numpy as np
7
  import gradio as gr
8
+ import gdown
9
+ import os
10
+ from PIL import Image
11
 
12
+ # =========================================================
 
13
  # الإعدادات العامة
14
+ # =========================================================
 
 
15
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ IMG_SIZE = 512 # حجم الإدخال الصحيح للنموذج
17
+ CLASS_NAMES = ["Ulcer", "Slough", "Necrosis"]
18
 
19
+ # =========================================================
20
+ # تعريف نموذج DFUTissueSegNet
21
+ # =========================================================
 
22
  class ConvBlock(nn.Module):
23
+ def __init__(self, in_ch, out_ch):
24
+ super().__init__()
25
  self.block = nn.Sequential(
26
+ nn.Conv2d(in_ch, out_ch, 3, padding=1),
27
+ nn.BatchNorm2d(out_ch),
28
  nn.ReLU(inplace=True),
29
+ nn.Conv2d(out_ch, out_ch, 3, padding=1),
30
+ nn.BatchNorm2d(out_ch),
31
  nn.ReLU(inplace=True)
32
  )
33
+
34
  def forward(self, x):
35
  return self.block(x)
36
 
37
  class DFUTissueSegNet(nn.Module):
38
  def __init__(self, num_classes=3):
39
+ super().__init__()
40
  self.encoder1 = ConvBlock(3, 64)
41
+ self.pool1 = nn.MaxPool2d(2)
42
  self.encoder2 = ConvBlock(64, 128)
43
+ self.pool2 = nn.MaxPool2d(2)
44
  self.encoder3 = ConvBlock(128, 256)
45
+ self.pool3 = nn.MaxPool2d(2)
46
  self.encoder4 = ConvBlock(256, 512)
47
+ self.pool4 = nn.MaxPool2d(2)
48
 
49
  self.center = ConvBlock(512, 1024)
50
 
 
64
  e2 = self.encoder2(self.pool1(e1))
65
  e3 = self.encoder3(self.pool2(e2))
66
  e4 = self.encoder4(self.pool3(e3))
67
+ c = self.center(self.pool4(e4))
68
 
69
+ d4 = self.up4(c)
70
  d4 = torch.cat([d4, e4], dim=1)
71
  d4 = self.dec4(d4)
72
 
 
82
  d1 = torch.cat([d1, e1], dim=1)
83
  d1 = self.dec1(d1)
84
 
85
+ return torch.sigmoid(self.final(d1))
 
 
86
 
87
+ # =========================================================
88
+ # تحميل النموذج من Google Drive
89
+ # =========================================================
90
+ segmenter = None
91
 
 
 
 
92
  def initialize_model():
93
+ """تحميل DFUTissueSegNet من Google Drive مع دعم checkpoint"""
94
  global segmenter
95
  MODEL_URL = "https://drive.google.com/uc?id=1Ovaczsjdp3E-_gYF2pbUibDjPWAC1a6c"
96
  MODEL_PATH = "best_model_5.pth"
 
104
  segmenter = DFUTissueSegNet(num_classes=3)
105
  checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
106
 
107
+ # إذا كان checkpoint يحتوي على state_dict
108
  if "state_dict" in checkpoint:
109
  state_dict = checkpoint["state_dict"]
110
  else:
111
  state_dict = checkpoint
112
 
 
113
  clean_state = {}
114
  for k, v in state_dict.items():
115
  nk = k.replace("module.", "").replace("model.", "")
 
125
  traceback.print_exc()
126
  segmenter = None
127
 
128
+ # =========================================================
129
+ # دوال التحليل
130
+ # =========================================================
131
+ def analyze_image(img: Image.Image):
132
+ """تحليل صورة القدم وإخراج القناع"""
 
 
133
  if segmenter is None:
134
+ return img, {"خطأ": "النموذج غير مهيأ بعد."}
 
135
 
136
  try:
137
+ print("🔍 بدء التحليل...")
138
+ img_np = np.array(img)
139
+
140
+ # التأكد من أن الصورة RGB
141
+ if len(img_np.shape) == 2:
142
  img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)
143
  elif img_np.shape[-1] == 4:
144
  img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2RGB)
145
 
146
+ # التحجيم والتطبيع الصحيح
147
  img_resized = cv2.resize(img_np, (IMG_SIZE, IMG_SIZE))
148
  img_norm = img_resized.astype(np.float32) / 255.0
149
+ img_norm = (img_norm - 0.5) / 0.5 # كما في التدريب
150
+
151
  tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
152
 
153
  with torch.no_grad():
154
+ output = segmenter(tensor).cpu().squeeze(0).numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ # تطبيق العتبة لكل قناة
157
+ th = 0.4
158
+ masks = (output >= th).astype(np.uint8)
 
 
 
 
 
159
 
160
+ # حساب النسب
161
+ ulcer_ratio = np.sum(masks[0]) / (IMG_SIZE * IMG_SIZE) * 100
162
+ slough_ratio = np.sum(masks[1]) / (IMG_SIZE * IMG_SIZE) * 100
163
+ necrosis_ratio = np.sum(masks[2]) / (IMG_SIZE * IMG_SIZE) * 100
164
+ total_ratio = ulcer_ratio + slough_ratio + necrosis_ratio
165
 
166
+ # تلوين القناع
167
+ color_mask = np.zeros_like(img_resized)
168
+ color_mask[masks[0] == 1] = [255, 0, 0] # قرحة حمراء
169
+ color_mask[masks[1] == 1] = [255, 255, 0] # Slough صفراء
170
+ color_mask[masks[2] == 1] = [0, 0, 0] # نخر أسود
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
+ # دمج القناع مع الصورة
173
+ blended = cv2.addWeighted(img_resized, 0.7, color_mask, 0.6, 0)
174
+ blended = cv2.cvtColor(blended, cv2.COLOR_RGB2BGR)
175
 
176
+ # تقييم الخطورة
 
 
 
 
 
 
 
 
 
 
 
177
  if total_ratio == 0:
178
+ risk = "No Risk 🟢"
179
+ elif total_ratio < 1:
180
+ risk = "Low Risk 🟡"
181
+ elif total_ratio < 5:
182
+ risk = "Medium Risk 🟠"
183
  else:
184
+ risk = "High Risk 🔴"
185
 
186
  report = {
187
+ "قرحة (%)": round(ulcer_ratio, 2),
188
+ "Slough (%)": round(slough_ratio, 2),
189
+ "نخر (%)": round(necrosis_ratio, 2),
190
+ "إجمالي (%)": round(total_ratio, 2),
191
+ "مستوى الخطورة": risk
192
  }
193
 
194
+ print(f"📊 تقرير التحليل: {report}")
195
+ return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)), report
196
 
197
  except Exception as e:
198
  print(f"❌ خطأ أثناء التحليل: {e}")
199
  import traceback
200
  traceback.print_exc()
201
+ return img, {"خطأ": str(e)}
 
202
 
203
+ # =========================================================
204
  # واجهة Gradio
205
+ # =========================================================
206
+ print("🚀 تهيئة نموذج DFUTissueSegNet...")
207
+ initialize_model()
208
+
209
+ with gr.Blocks(title="تحليل DFUTissueSegNet") as demo:
210
+ gr.Markdown("""
211
+ # 🩸 نظام تحليل DFU Tissue Segmentation
212
+ **نموذج DFUTissueSegNet** لتجزئة أنسجة القدم السكري وتقدير مستوى الخطورة.
213
+ """)
214
+
215
+ with gr.Row():
216
+ with gr.Column():
217
+ input_img = gr.Image(type="pil", label="📤 ارفع صورة القدم")
218
+ analyze_btn = gr.Button("🔍 تحليل الصورة", variant="primary")
219
+
220
+ with gr.Column():
221
+ output_img = gr.Image(type="pil", label="🩸 النتيجة", height=400)
222
+ output_json = gr.JSON(label="📊 تقرير التحليل")
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
+ analyze_btn.click(fn=analyze_image, inputs=[input_img], outputs=[output_img, output_json])
225
 
 
 
 
226
  if __name__ == "__main__":
227
+ print("🌐 بدء التشغيل...")
228
+ demo.launch(server_name="0.0.0.0", server_port=7860)