EngReem85 commited on
Commit
f9d94c3
·
verified ·
1 Parent(s): 9ee12a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -156
app.py CHANGED
@@ -1,24 +1,46 @@
1
  # -*- coding: utf-8 -*-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
- import cv2
6
- import numpy as np
7
  import gradio as gr
8
- import gdown
9
- import os
10
- from PIL import Image
11
 
12
- # =========================================================
13
- # الإعدادات العامة
14
- # =========================================================
15
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- IMG_SIZE = 512 # حجم الإدخال الصحيح للنموذج
17
- CLASS_NAMES = ["Ulcer", "Slough", "Necrosis"]
18
-
19
- # =========================================================
20
- # تعريف نموذج DFUTissueSegNet
21
- # =========================================================
 
 
 
 
 
 
 
 
 
 
 
 
22
  class ConvBlock(nn.Module):
23
  def __init__(self, in_ch, out_ch):
24
  super().__init__()
@@ -28,23 +50,18 @@ class ConvBlock(nn.Module):
28
  nn.ReLU(inplace=True),
29
  nn.Conv2d(out_ch, out_ch, 3, padding=1),
30
  nn.BatchNorm2d(out_ch),
31
- nn.ReLU(inplace=True)
32
  )
33
-
34
  def forward(self, x):
35
  return self.block(x)
36
 
37
  class DFUTissueSegNet(nn.Module):
38
  def __init__(self, num_classes=3):
39
  super().__init__()
40
- self.encoder1 = ConvBlock(3, 64)
41
- self.pool1 = nn.MaxPool2d(2)
42
- self.encoder2 = ConvBlock(64, 128)
43
- self.pool2 = nn.MaxPool2d(2)
44
- self.encoder3 = ConvBlock(128, 256)
45
- self.pool3 = nn.MaxPool2d(2)
46
- self.encoder4 = ConvBlock(256, 512)
47
- self.pool4 = nn.MaxPool2d(2)
48
 
49
  self.center = ConvBlock(512, 1024)
50
 
@@ -60,43 +77,29 @@ class DFUTissueSegNet(nn.Module):
60
  self.final = nn.Conv2d(64, num_classes, 1)
61
 
62
  def forward(self, x):
63
- e1 = self.encoder1(x)
64
- e2 = self.encoder2(self.pool1(e1))
65
- e3 = self.encoder3(self.pool2(e2))
66
- e4 = self.encoder4(self.pool3(e3))
67
- c = self.center(self.pool4(e4))
68
-
69
- d4 = self.up4(c)
70
- d4 = torch.cat([d4, e4], dim=1)
71
- d4 = self.dec4(d4)
72
-
73
- d3 = self.up3(d4)
74
- d3 = torch.cat([d3, e3], dim=1)
75
- d3 = self.dec3(d3)
76
-
77
- d2 = self.up2(d3)
78
- d2 = torch.cat([d2, e2], dim=1)
79
- d2 = self.dec2(d2)
80
-
81
- d1 = self.up1(d2)
82
- d1 = torch.cat([d1, e1], dim=1)
83
- d1 = self.dec1(d1)
84
-
85
  return torch.sigmoid(self.final(d1))
86
 
87
- # =========================================================
88
- # تحميل النموذج من Google Drive
89
- # =========================================================
90
  segmenter = None
91
 
 
 
 
92
  def initialize_model():
93
- """تحميل DFUTissueSegNet من Google Drive مع دعم checkpoint"""
94
  global segmenter
95
- MODEL_URL = "https://drive.google.com/uc?id=1Ovaczsjdp3E-_gYF2pbUibDjPWAC1a6c"
96
- MODEL_PATH = "best_model_5.pth"
97
-
98
  if not os.path.exists(MODEL_PATH):
99
- print("📥 تحميل DFUTissueSegNet من Google Drive...")
100
  gdown.download(MODEL_URL, MODEL_PATH, quiet=False)
101
 
102
  try:
@@ -104,12 +107,13 @@ def initialize_model():
104
  segmenter = DFUTissueSegNet(num_classes=3)
105
  checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
106
 
107
- # إذا كان checkpoint يحتوي على state_dict
108
- if "state_dict" in checkpoint:
109
  state_dict = checkpoint["state_dict"]
110
  else:
111
  state_dict = checkpoint
112
 
 
113
  clean_state = {}
114
  for k, v in state_dict.items():
115
  nk = k.replace("module.", "").replace("model.", "")
@@ -118,111 +122,143 @@ def initialize_model():
118
  segmenter.load_state_dict(clean_state, strict=False)
119
  segmenter.to(DEVICE)
120
  segmenter.eval()
121
- print("✅ تم تحميل DFUTissueSegNet من Google Drive بنجاح.")
122
  except Exception as e:
123
  print(f"❌ فشل تحميل النموذج: {e}")
124
- import traceback
125
- traceback.print_exc()
126
  segmenter = None
127
 
128
- # =========================================================
129
- # دوال التحليل
130
- # =========================================================
131
- def analyze_image(img: Image.Image):
132
- """تحليل صورة القدم وإخراج القناع"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  if segmenter is None:
134
- return img, {"خطأ": "النموذج غير مهيأ بعد."}
135
-
136
- try:
137
- print("🔍 بدء التحليل...")
138
- img_np = np.array(img)
139
-
140
- # التأكد من أن الصورة RGB
141
- if len(img_np.shape) == 2:
142
- img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)
143
- elif img_np.shape[-1] == 4:
144
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2RGB)
145
-
146
- # التحجيم والتطبيع الصحيح
147
- img_resized = cv2.resize(img_np, (IMG_SIZE, IMG_SIZE))
148
- img_norm = img_resized.astype(np.float32) / 255.0
149
- img_norm = (img_norm - 0.5) / 0.5 # كما في التدريب
150
-
151
- tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
152
-
153
- with torch.no_grad():
154
- output = segmenter(tensor).cpu().squeeze(0).numpy()
155
-
156
- # تطبيق العتبة لكل قناة
157
- th = 0.4
158
- masks = (output >= th).astype(np.uint8)
159
-
160
- # حساب النسب
161
- ulcer_ratio = np.sum(masks[0]) / (IMG_SIZE * IMG_SIZE) * 100
162
- slough_ratio = np.sum(masks[1]) / (IMG_SIZE * IMG_SIZE) * 100
163
- necrosis_ratio = np.sum(masks[2]) / (IMG_SIZE * IMG_SIZE) * 100
164
- total_ratio = ulcer_ratio + slough_ratio + necrosis_ratio
165
-
166
- # تلوين القناع
167
- color_mask = np.zeros_like(img_resized)
168
- color_mask[masks[0] == 1] = [255, 0, 0] # قرحة حمراء
169
- color_mask[masks[1] == 1] = [255, 255, 0] # Slough صفراء
170
- color_mask[masks[2] == 1] = [0, 0, 0] # نخر أسود
171
-
172
- # دمج القناع مع الصورة
173
- blended = cv2.addWeighted(img_resized, 0.7, color_mask, 0.6, 0)
174
- blended = cv2.cvtColor(blended, cv2.COLOR_RGB2BGR)
175
-
176
- # تقييم الخطورة
177
- if total_ratio == 0:
178
- risk = "No Risk 🟢"
179
- elif total_ratio < 1:
180
- risk = "Low Risk 🟡"
181
- elif total_ratio < 5:
182
- risk = "Medium Risk 🟠"
183
- else:
184
- risk = "High Risk 🔴"
185
-
186
- report = {
187
- "قرحة (%)": round(ulcer_ratio, 2),
188
- "Slough (%)": round(slough_ratio, 2),
189
- "نخر (%)": round(necrosis_ratio, 2),
190
- "إجمالي (%)": round(total_ratio, 2),
191
- "مستوى الخطورة": risk
192
- }
193
-
194
- print(f"📊 تقرير التحليل: {report}")
195
- return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)), report
196
-
197
- except Exception as e:
198
- print(f" خطأ أثناء التحليل: {e}")
199
- import traceback
200
- traceback.print_exc()
201
- return img, {"خطأ": str(e)}
202
-
203
- # =========================================================
204
  # واجهة Gradio
205
- # =========================================================
206
- print("🚀 تهيئة نموذج DFUTissueSegNet...")
207
- initialize_model()
208
-
209
- with gr.Blocks(title="تحليل DFUTissueSegNet") as demo:
210
- gr.Markdown("""
211
- # 🩸 نظام تحليل DFU Tissue Segmentation
212
- **نموذج DFUTissueSegNet** لتجزئة أنسجة القدم السكري وتقدير مستوى الخطورة.
213
- """)
214
-
215
- with gr.Row():
216
- with gr.Column():
217
- input_img = gr.Image(type="pil", label="📤 ارفع صورة القدم")
218
- analyze_btn = gr.Button("🔍 تحليل الصورة", variant="primary")
219
-
220
- with gr.Column():
221
- output_img = gr.Image(type="pil", label="🩸 النتيجة", height=400)
222
- output_json = gr.JSON(label="📊 تقرير التحليل")
223
-
224
- analyze_btn.click(fn=analyze_image, inputs=[input_img], outputs=[output_img, output_json])
 
 
 
225
 
 
 
 
226
  if __name__ == "__main__":
227
- print("🌐 بدء التشغيل...")
228
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
1
  # -*- coding: utf-8 -*-
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ واجهة موحّدة لتحليل صورة القدم باستخدام DFUTissueSegNet (من Google Drive)
5
+ - ألوان متعددة: قرحة=أحمر، Slough=أصفر، نخر=أسود
6
+ - حساب نسب كل نوع + مستوى خطورة إجمالي
7
+ - Legend داخل Gradio
8
+ """
9
+
10
+ import os
11
+ import cv2
12
+ import gdown
13
+ import numpy as np
14
+ from PIL import Image
15
+
16
  import torch
17
  import torch.nn as nn
18
  import torch.nn.functional as F
19
+
 
20
  import gradio as gr
 
 
 
21
 
22
+ # ================================
23
+ # إعدادات عامة
24
+ # ================================
25
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ IMG_SIZE = 512 # حجم الإدخال كما في تدريب DFUTissueSegNet
27
+ THRESHOLD = 0.4 # عتبة تحويل الاحتمال إلى قناع (جرّبي 0.35..0.6 حسب بياناتك)
28
+ MODEL_PATH = "best_model_5.pth"
29
+
30
+ # غيّري الـ File ID أدناه لملفّك على Google Drive عند الحاجة
31
+ # مثال: https://drive.google.com/file/d/FILE_ID/view => استخدمي: https://drive.google.com/uc?id=FILE_ID
32
+ MODEL_URL = "https://drive.google.com/uc?id=1Ovaczsjdp3E-_gYF2pbUibDjPWAC1a6c"
33
+
34
+ CLASS_NAMES = ["قرحة (Granulation)", "Slough (أنسجة ميتة جزئيًا)", "نخر (Necrotic)"]
35
+ CLASS_COLORS = {
36
+ "قرحة (Granulation)": (255, 0, 0), # أحمر
37
+ "Slough (أنسجة ميتة جزئيًا)": (255, 255, 0), # أصفر
38
+ "نخر (Necrotic)": (0, 0, 0) # أسود
39
+ }
40
+
41
+ # ================================
42
+ # نموذج DFUTissueSegNet
43
+ # ================================
44
  class ConvBlock(nn.Module):
45
  def __init__(self, in_ch, out_ch):
46
  super().__init__()
 
50
  nn.ReLU(inplace=True),
51
  nn.Conv2d(out_ch, out_ch, 3, padding=1),
52
  nn.BatchNorm2d(out_ch),
53
+ nn.ReLU(inplace=True),
54
  )
 
55
  def forward(self, x):
56
  return self.block(x)
57
 
58
  class DFUTissueSegNet(nn.Module):
59
  def __init__(self, num_classes=3):
60
  super().__init__()
61
+ self.enc1 = ConvBlock(3, 64); self.pool1 = nn.MaxPool2d(2)
62
+ self.enc2 = ConvBlock(64, 128); self.pool2 = nn.MaxPool2d(2)
63
+ self.enc3 = ConvBlock(128, 256);self.pool3 = nn.MaxPool2d(2)
64
+ self.enc4 = ConvBlock(256, 512);self.pool4 = nn.MaxPool2d(2)
 
 
 
 
65
 
66
  self.center = ConvBlock(512, 1024)
67
 
 
77
  self.final = nn.Conv2d(64, num_classes, 1)
78
 
79
  def forward(self, x):
80
+ e1 = self.enc1(x)
81
+ e2 = self.enc2(self.pool1(e1))
82
+ e3 = self.enc3(self.pool2(e2))
83
+ e4 = self.enc4(self.pool3(e3))
84
+ c = self.center(self.pool4(e4))
85
+
86
+ d4 = self.up4(c); d4 = torch.cat([d4, e4], dim=1); d4 = self.dec4(d4)
87
+ d3 = self.up3(d4); d3 = torch.cat([d3, e3], dim=1); d3 = self.dec3(d3)
88
+ d2 = self.up2(d3); d2 = torch.cat([d2, e2], dim=1); d2 = self.dec2(d2)
89
+ d1 = self.up1(d2); d1 = torch.cat([d1, e1], dim=1); d1 = self.dec1(d1)
90
+
91
+ # مخرجات احتمالية 0..1 لكل قناة
 
 
 
 
 
 
 
 
 
 
92
  return torch.sigmoid(self.final(d1))
93
 
 
 
 
94
  segmenter = None
95
 
96
+ # ================================
97
+ # تحميل النموذج (من Google Drive)
98
+ # ================================
99
  def initialize_model():
 
100
  global segmenter
 
 
 
101
  if not os.path.exists(MODEL_PATH):
102
+ print("📥 تحميل النموذج من Google Drive...")
103
  gdown.download(MODEL_URL, MODEL_PATH, quiet=False)
104
 
105
  try:
 
107
  segmenter = DFUTissueSegNet(num_classes=3)
108
  checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
109
 
110
+ # التعامل مع checkpoints التي تحتوي state_dict
111
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
112
  state_dict = checkpoint["state_dict"]
113
  else:
114
  state_dict = checkpoint
115
 
116
+ # تنظيف البوادئ الشائعة
117
  clean_state = {}
118
  for k, v in state_dict.items():
119
  nk = k.replace("module.", "").replace("model.", "")
 
122
  segmenter.load_state_dict(clean_state, strict=False)
123
  segmenter.to(DEVICE)
124
  segmenter.eval()
125
+ print("✅ النموذج جاهز.")
126
  except Exception as e:
127
  print(f"❌ فشل تحميل النموذج: {e}")
128
+ import traceback; traceback.print_exc()
 
129
  segmenter = None
130
 
131
+ # ================================
132
+ # أدوات مساعدة
133
+ # ================================
134
+ def ensure_rgb(np_img):
135
+ if np_img.ndim == 2:
136
+ return cv2.cvtColor(np_img, cv2.COLOR_GRAY2RGB)
137
+ if np_img.shape[-1] == 4:
138
+ return cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
139
+ return np_img
140
+
141
+ def apply_legend_markdown():
142
+ return """
143
+ ### 🧭 مفتاح الألوان (Legend)
144
+ - 🩸 **أحمر** → نسيج قرحة (Granulation)
145
+ - 🟡 **أصفر** → نسيج ميت جزئيًا (Slough)
146
+ - ⚫ **أسود** → نسيج نخر (Necrotic)
147
+ """
148
+
149
+ # ================================
150
+ # التجزئة + الحساب + التلوين
151
+ # ================================
152
+ def segment_and_color(pil_img: Image.Image):
153
+ """
154
+ يُرجع:
155
+ - blended: الصورة مدموج عليها القناع اللوني
156
+ - mask_rgb: القناع اللوني (RGB)
157
+ - stats: نسب كل فئة + الإجمالي + مستوى الخطورة
158
+ """
159
  if segmenter is None:
160
+ return pil_img, pil_img, {"خطأ": "النموذج غير مهيأ"}
161
+
162
+ # 1) التحضير
163
+ img_np = ensure_rgb(np.array(pil_img))
164
+ h, w = img_np.shape[:2]
165
+
166
+ # 2) التحجيم + التطبيع كما في التدريب
167
+ img_resized = cv2.resize(img_np, (IMG_SIZE, IMG_SIZE))
168
+ img_norm = img_resized.astype(np.float32) / 255.0
169
+ img_norm = (img_norm - 0.5) / 0.5 # (x - 0.5) / 0.5
170
+
171
+ tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
172
+
173
+ # 3) التنبؤ
174
+ with torch.no_grad():
175
+ probs = segmenter(tensor).cpu().squeeze(0).numpy() # (3, H, W) احتمالات
176
+
177
+ # 4) أقنعة ثنائية لكل فئة
178
+ masks = (probs >= THRESHOLD).astype(np.uint8) # 0/1
179
+ # إزالة الضوضاء البسيطة
180
+ kernel = np.ones((5, 5), np.uint8)
181
+ for i in range(masks.shape[0]):
182
+ masks[i] = cv2.morphologyEx(masks[i], cv2.MORPH_OPEN, kernel)
183
+ masks[i] = cv2.morphologyEx(masks[i], cv2.MORPH_CLOSE, kernel)
184
+
185
+ # 5) حساب النسب على أبعاد الإدخال ثم إعادة القياس
186
+ total_pixels_input = IMG_SIZE * IMG_SIZE
187
+ ratios = {
188
+ CLASS_NAMES[0]: np.sum(masks[0]) / total_pixels_input * 100,
189
+ CLASS_NAMES[1]: np.sum(masks[1]) / total_pixels_input * 100,
190
+ CLASS_NAMES[2]: np.sum(masks[2]) / total_pixels_input * 100,
191
+ }
192
+ total_ratio = sum(ratios.values())
193
+
194
+ # 6) إنشاء قناع لوني على حجم الإدخال ثم إعادته لحجم الصورة الأصلي
195
+ color_mask = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
196
+ color_mask[masks[0] == 1] = CLASS_COLORS[CLASS_NAMES[0]]
197
+ color_mask[masks[1] == 1] = CLASS_COLORS[CLASS_NAMES[1]]
198
+ color_mask[masks[2] == 1] = CLASS_COLORS[CLASS_NAMES[2]]
199
+
200
+ mask_rgb = cv2.resize(color_mask, (w, h), interpolation=cv2.INTER_NEAREST)
201
+
202
+ # 7) دمج القناع مع الصورة (ألفا ~0.5)
203
+ alpha = (cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2GRAY) > 0).astype(np.float32)[..., None] * 0.5
204
+ blended = (alpha * mask_rgb + (1 - alpha) * img_np).astype(np.uint8)
205
+
206
+ # 8) مستوى الخطورة
207
+ if total_ratio == 0:
208
+ risk = "No Risk 🟢"
209
+ elif total_ratio <= 1:
210
+ risk = "Low Risk 🟡"
211
+ elif total_ratio <= 5:
212
+ risk = "Medium Risk 🟠"
213
+ else:
214
+ risk = "High Risk 🔴"
215
+
216
+ stats = {
217
+ "نِسَب_الأنسجة": {
218
+ CLASS_NAMES[0]: f"{ratios[CLASS_NAMES[0]]:.2f}%",
219
+ CLASS_NAMES[1]: f"{ratios[CLASS_NAMES[1]]:.2f}%",
220
+ CLASS_NAMES[2]: f"{ratios[CLASS_NAMES[2]]:.2f}%",
221
+ "الإجمالي": f"{total_ratio:.2f}%",
222
+ },
223
+ "مستوى_الخطورة": risk,
224
+ "ملاحظات": "التحليل يعتمد على DFUTissueSegNet متعدد الفئات (حجم 512 وتطبيع (x-0.5)/0.5)."
225
+ }
226
+
227
+ return Image.fromarray(blended), Image.fromarray(mask_rgb), stats
228
+
229
+ # ================================
230
  # واجهة Gradio
231
+ # ================================
232
+ def build_ui():
233
+ with gr.Blocks(title="تحليل قرحة القدم - DFUTissueSegNet", theme=gr.themes.Soft()) as demo:
234
+ gr.Markdown("# 🦶 نظام تحليل قرحة القدم السكري - صورة واحدة")
235
+ gr.Markdown("يعتمد على **DFUTissueSegNet** لتجزئة الأنسجة وحساب نسبها، مع تلوين واضح وLegend.")
236
+
237
+ with gr.Row():
238
+ with gr.Column(scale=1):
239
+ input_img = gr.Image(type="pil", label="📤 ارفع صورة القدم", height=320)
240
+ analyze_btn = gr.Button("🔍 بدء التحليل", variant="primary")
241
+ legend = gr.Markdown(apply_legend_markdown())
242
+
243
+ with gr.Column(scale=1):
244
+ out_blended = gr.Image(type="pil", label="🩸 الصورة مع القناع", height=320)
245
+ out_mask = gr.Image(type="pil", label="🧩 القناع اللوني", height=320)
246
+ out_json = gr.JSON(label="📊 التقرير التفصيلي")
247
+
248
+ analyze_btn.click(
249
+ fn=segment_and_color,
250
+ inputs=[input_img],
251
+ outputs=[out_blended, out_mask, out_json]
252
+ )
253
+ return demo
254
 
255
+ # ================================
256
+ # تشغيل التطبيق
257
+ # ================================
258
  if __name__ == "__main__":
259
+ print("🚀 تهيئة النموذج...")
260
+ initialize_model()
261
+ app = build_ui()
262
+ # ملاحظة: على Spaces لا حاجة لـ share=True
263
+ app.launch(server_name="0.0.0.0", server_port=7860, share=False)
264
+