danicor commited on
Commit
a93dee3
·
verified ·
1 Parent(s): 7d12390

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -107
app.py CHANGED
@@ -10,7 +10,13 @@ import time
10
 
11
  # افزودن مسیر مورد نیاز برای ماژول‌های CelebAMask-HQ
12
  celebamask_path = "/home/user/app/CelebAMask-HQ"
13
- sys.path.insert(0, os.path.join(celebamask_path, "face_parsing"))
 
 
 
 
 
 
14
 
15
  # ایمپورت ماژول‌های مورد نیاز
16
  try:
@@ -21,6 +27,19 @@ try:
21
  except ImportError as e:
22
  IMPORT_SUCCESS = False
23
  print(f"❌ Failed to import CelebAMask-HQ modules: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # تنظیمات دستگاه
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -33,7 +52,7 @@ os.environ["HF_HOME"] = "/home/user/app/hf_cache"
33
  transform = T.Compose([
34
  T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
35
  T.ToTensor(),
36
- T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
37
  ])
38
 
39
  # کلاس‌های Face Parsing
@@ -63,8 +82,16 @@ class FaceParsingModel:
63
  state_dict = torch.load(model_path, map_location="cpu")
64
 
65
  # ایجاد مدل
66
- self.model = unet()
67
- self.model.load_state_dict(state_dict)
 
 
 
 
 
 
 
 
68
  self.model.eval()
69
  self.model.to(self.device)
70
 
@@ -72,6 +99,8 @@ class FaceParsingModel:
72
 
73
  except Exception as e:
74
  print(f"❌ Failed to load model: {e}")
 
 
75
  self.model = None
76
 
77
  def predict(self, image):
@@ -96,15 +125,33 @@ class FaceParsingModel:
96
  with torch.no_grad():
97
  out = self.model(data)
98
  out = generate_label(out, 512)
99
- out = out[0].cpu().numpy().transpose(1, 2, 0)
100
- out = np.clip(np.round(out * 255), 0, 255).astype(np.uint8)
101
-
 
 
102
  # ترکیب تصویر اصلی با ماسک
103
- resized_image = np.asarray(original_image.resize((512, 512))).astype(float)
104
- blended = resized_image * 0.5 + out.astype(float) * 0.5
105
- blended = np.clip(np.round(blended), 0, 255).astype(np.uint8)
106
-
107
- return out, blended
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  def initialize_app():
110
  """Initialize application"""
@@ -112,12 +159,14 @@ def initialize_app():
112
 
113
  print("[Info] PYTHONPATH:", os.environ.get("PYTHONPATH"))
114
  print("[Info] CelebAMask-HQ path exists:", os.path.exists(celebamask_path))
115
- print("[Info] face_parsing folder exists:", os.path.exists(os.path.join(celebamask_path, "face_parsing")))
116
  print("[Info] Module import success:", IMPORT_SUCCESS)
117
 
118
  try:
119
  face_parser = FaceParsingModel()
120
- return True, "Model loaded successfully", face_parser
 
 
121
  except Exception as e:
122
  print(f"[Error] Initialization failed: {e}")
123
  return False, f"Initialization failed: {e}", None
@@ -143,7 +192,6 @@ def process_image(input_image):
143
  img_size = original_img.size
144
  else:
145
  img_size = input_image.size if hasattr(input_image, 'size') else input_image.shape[:2][::-1]
146
- original_img = PIL.Image.fromarray(input_image) if isinstance(input_image, np.ndarray) else input_image
147
 
148
  info_text = f"""
149
  ✅ پردازش انجام شد!
@@ -158,97 +206,8 @@ def process_image(input_image):
158
  except Exception as e:
159
  error_msg = f"❌ خطا در پردازش تصویر: {str(e)}"
160
  print(error_msg)
 
 
161
  return None, None, error_msg
162
 
163
- def create_legend():
164
- """ایجاد لیجند برای کلاس‌ها"""
165
- import matplotlib.pyplot as plt
166
-
167
- legend_html = """
168
- <div style='max-height: 300px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; border-radius: 5px;'>
169
- <h4>🎨 Legend - کلاس‌های Face Parsing:</h4>
170
- """
171
-
172
- colors = plt.get_cmap('tab20', len(CELEBA_CLASSES))
173
-
174
- for i, class_name in enumerate(CELEBA_CLASSES):
175
- color = colors(i)
176
- color_hex = '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))
177
- text_color = 'white' if color[0] * 0.299 + color[1] * 0.587 + color[2] * 0.114 < 0.5 else 'black'
178
- legend_html += f"""
179
- <div style='margin: 2px; padding: 5px; background-color: {color_hex}; color: {text_color}; border-radius: 3px;'>
180
- <strong>{i}:</strong> {class_name}
181
- </div>
182
- """
183
-
184
- legend_html += "</div>"
185
- return legend_html
186
-
187
- # ایجاد اینترفیس Gradio
188
- with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as demo:
189
- gr.Markdown("""
190
- # 🎭 CelebAMask-HQ Face Parsing Demo
191
- **آپلود یک تصویر صورت و دریافت خروجی Face Parsing**
192
-
193
- این مدل صورت را به 19 بخش مختلف تقسیم می‌کند (پوست، چشم، ابرو، بینی، دهان، مو و ...)
194
- """)
195
-
196
- with gr.Row():
197
- with gr.Column():
198
- input_image = gr.Image(
199
- label="📷 تصویر ورودی",
200
- type="filepath",
201
- sources=["upload"],
202
- height=300
203
- )
204
- process_btn = gr.Button("🚀 پردازش تصویر", variant="primary", size="lg")
205
-
206
- with gr.Accordion("ℹ️ وضعیت برنامه", open=False):
207
- status_display = gr.Markdown(f"""
208
- **وضعیت:**
209
- - 🎯 مدل: {'✅ لود شده' if success else '❌ خطا در لود'}
210
- - 💻 دستگاه: `{device}`
211
- - 📦 ماژول‌ها: {'✅ ایمپورت شده' if IMPORT_SUCCESS else '❌ خطا در ایمپورت'}
212
- - 🗂️ کلاس‌ها: {len(CELEBA_CLASSES)}
213
- """)
214
-
215
- with gr.Column():
216
- output_blended = gr.Image(
217
- label="🎨 نتیجه ترکیبی (تصویر + ماسک)",
218
- height=300
219
- )
220
- output_mask = gr.Image(
221
- label="🎭 ماسک سگمنتیشن",
222
- height=300
223
- )
224
-
225
- with gr.Row():
226
- info_output = gr.Textbox(
227
- label="📊 اطلاعات پردازش",
228
- lines=3,
229
- max_lines=6
230
- )
231
-
232
- with gr.Row():
233
- gr.HTML(create_legend())
234
-
235
- # اتصال رویدادها
236
- process_btn.click(
237
- fn=process_image,
238
- inputs=[input_image],
239
- outputs=[output_blended, output_mask, info_output]
240
- )
241
-
242
- input_image.upload(
243
- fn=process_image,
244
- inputs=[input_image],
245
- outputs=[output_blended, output_mask, info_output]
246
- )
247
-
248
- if __name__ == "__main__":
249
- print("🚀 Starting Face Parsing Application...")
250
- demo.launch(
251
- server_name="0.0.0.0",
252
- server_port=7860,
253
- share=False
254
- )
 
10
 
11
  # افزودن مسیر مورد نیاز برای ماژول‌های CelebAMask-HQ
12
  celebamask_path = "/home/user/app/CelebAMask-HQ"
13
+ face_parsing_path = os.path.join(celebamask_path, "face_parsing")
14
+ sys.path.insert(0, celebamask_path)
15
+ sys.path.insert(0, face_parsing_path)
16
+
17
+ print("Python path:", sys.path)
18
+ print("CelebAMask path exists:", os.path.exists(celebamask_path))
19
+ print("Face parsing path exists:", os.path.exists(face_parsing_path))
20
 
21
  # ایمپورت ماژول‌های مورد نیاز
22
  try:
 
27
  except ImportError as e:
28
  IMPORT_SUCCESS = False
29
  print(f"❌ Failed to import CelebAMask-HQ modules: {e}")
30
+ # تعریف توابع جایگزین در صورت نیاز
31
+ def unet(**kwargs):
32
+ from unet import UNet
33
+ return UNet(**kwargs)
34
+
35
+ def generate_label(inputs, imsize=512):
36
+ pred_batch = []
37
+ for input in inputs:
38
+ pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
39
+ pred_batch.append(pred)
40
+ pred_batch = np.array(pred_batch)
41
+ pred_batch = torch.from_numpy(pred_batch)
42
+ return pred_batch
43
 
44
  # تنظیمات دستگاه
45
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
52
  transform = T.Compose([
53
  T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
54
  T.ToTensor(),
55
+ T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
56
  ])
57
 
58
  # کلاس‌های Face Parsing
 
82
  state_dict = torch.load(model_path, map_location="cpu")
83
 
84
  # ایجاد مدل
85
+ self.model = unet(n_channels=3, n_classes=19)
86
+
87
+ # تطبیق state dict در صورت نیاز
88
+ new_state_dict = {}
89
+ for k, v in state_dict.items():
90
+ if k.startswith('module.'):
91
+ k = k[7:] # حذف 'module.' اگر وجود دارد
92
+ new_state_dict[k] = v
93
+
94
+ self.model.load_state_dict(new_state_dict, strict=False)
95
  self.model.eval()
96
  self.model.to(self.device)
97
 
 
99
 
100
  except Exception as e:
101
  print(f"❌ Failed to load model: {e}")
102
+ import traceback
103
+ traceback.print_exc()
104
  self.model = None
105
 
106
  def predict(self, image):
 
125
  with torch.no_grad():
126
  out = self.model(data)
127
  out = generate_label(out, 512)
128
+ out = out[0].cpu().numpy()
129
+
130
+ # تبدیل به تصویر رنگی
131
+ colored_mask = self.colorize_mask(out)
132
+
133
  # ترکیب تصویر اصلی با ماسک
134
+ resized_image = np.asarray(original_image.resize((512, 512)))
135
+ blended = resized_image * 0.6 + colored_mask * 0.4
136
+ blended = np.clip(blended, 0, 255).astype(np.uint8)
137
+
138
+ return colored_mask, blended
139
+
140
+ def colorize_mask(self, mask):
141
+ """رنگ‌آمیزی ماسک بر اساس کلاس‌ها"""
142
+ # پالت رنگ برای 19 کلاس
143
+ palette = [
144
+ [0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0],
145
+ [255, 0, 255], [0, 255, 255], [128, 0, 0], [0, 128, 0], [0, 0, 128],
146
+ [128, 128, 0], [128, 0, 128], [0, 128, 128], [128, 128, 128], [255, 128, 0],
147
+ [255, 0, 128], [128, 255, 0], [0, 255, 128], [255, 128, 128]
148
+ ]
149
+
150
+ colored = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
151
+ for i in range(len(palette)):
152
+ colored[mask == i] = palette[i]
153
+
154
+ return colored
155
 
156
  def initialize_app():
157
  """Initialize application"""
 
159
 
160
  print("[Info] PYTHONPATH:", os.environ.get("PYTHONPATH"))
161
  print("[Info] CelebAMask-HQ path exists:", os.path.exists(celebamask_path))
162
+ print("[Info] face_parsing folder exists:", os.path.exists(face_parsing_path))
163
  print("[Info] Module import success:", IMPORT_SUCCESS)
164
 
165
  try:
166
  face_parser = FaceParsingModel()
167
+ success = face_parser.model is not None
168
+ status_msg = "Model loaded successfully" if success else "Model failed to load"
169
+ return success, status_msg, face_parser
170
  except Exception as e:
171
  print(f"[Error] Initialization failed: {e}")
172
  return False, f"Initialization failed: {e}", None
 
192
  img_size = original_img.size
193
  else:
194
  img_size = input_image.size if hasattr(input_image, 'size') else input_image.shape[:2][::-1]
 
195
 
196
  info_text = f"""
197
  ✅ پردازش انجام شد!
 
206
  except Exception as e:
207
  error_msg = f"❌ خطا در پردازش تصویر: {str(e)}"
208
  print(error_msg)
209
+ import traceback
210
+ traceback.print_exc()
211
  return None, None, error_msg
212
 
213
+ # ادامه کد مشابه قبل برای Gradio interface...