danicor commited on
Commit
7e96e19
·
verified ·
1 Parent(s): a93dee3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -42
app.py CHANGED
@@ -3,10 +3,12 @@ import sys
3
  import numpy as np
4
  import PIL.Image
5
  import torch
 
6
  import torchvision.transforms as T
7
  from huggingface_hub import hf_hub_download
8
  import gradio as gr
9
  import time
 
10
 
11
  # افزودن مسیر مورد نیاز برای ماژول‌های CelebAMask-HQ
12
  celebamask_path = "/home/user/app/CelebAMask-HQ"
@@ -18,28 +20,95 @@ print("Python path:", sys.path)
18
  print("CelebAMask path exists:", os.path.exists(celebamask_path))
19
  print("Face parsing path exists:", os.path.exists(face_parsing_path))
20
 
21
- # ایمپورت ماژول‌های مورد نیاز
22
- try:
23
- from unet import unet
24
- from utils import generate_label
25
- IMPORT_SUCCESS = True
26
- print("✅ Successfully imported CelebAMask-HQ modules")
27
- except ImportError as e:
28
- IMPORT_SUCCESS = False
29
- print(f"❌ Failed to import CelebAMask-HQ modules: {e}")
30
- # تعریف توابع جایگزین در صورت نیاز
31
- def unet(**kwargs):
32
- from unet import UNet
33
- return UNet(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- def generate_label(inputs, imsize=512):
36
- pred_batch = []
37
- for input in inputs:
38
- pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
39
- pred_batch.append(pred)
40
- pred_batch = np.array(pred_batch)
41
- pred_batch = torch.from_numpy(pred_batch)
42
- return pred_batch
43
 
44
  # تنظیمات دستگاه
45
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -50,9 +119,9 @@ os.environ["HF_HOME"] = "/home/user/app/hf_cache"
50
 
51
  # تعریف transform
52
  transform = T.Compose([
53
- T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
54
  T.ToTensor(),
55
- T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
56
  ])
57
 
58
  # کلاس‌های Face Parsing
@@ -78,20 +147,20 @@ class FaceParsingModel:
78
  )
79
  print(f"✅ Model downloaded to: {model_path}")
80
 
 
 
 
81
  # لود state dict
82
  state_dict = torch.load(model_path, map_location="cpu")
83
 
84
- # ایجاد مدل
85
- self.model = unet(n_channels=3, n_classes=19)
86
-
87
- # تطبیق state dict در صورت نیاز
88
  new_state_dict = {}
89
  for k, v in state_dict.items():
90
  if k.startswith('module.'):
91
- k = k[7:] # حذف 'module.' اگر وجود دارد
92
  new_state_dict[k] = v
93
 
94
- self.model.load_state_dict(new_state_dict, strict=False)
95
  self.model.eval()
96
  self.model.to(self.device)
97
 
@@ -124,27 +193,41 @@ class FaceParsingModel:
124
  # پیش‌بینی
125
  with torch.no_grad():
126
  out = self.model(data)
127
- out = generate_label(out, 512)
128
- out = out[0].cpu().numpy()
129
 
130
- # تبدیل به تصویر رنگی
131
- colored_mask = self.colorize_mask(out)
132
 
133
  # ترکیب تصویر اصلی با ماسک
134
  resized_image = np.asarray(original_image.resize((512, 512)))
135
- blended = resized_image * 0.6 + colored_mask * 0.4
136
- blended = np.clip(blended, 0, 255).astype(np.uint8)
137
 
138
  return colored_mask, blended
139
 
140
  def colorize_mask(self, mask):
141
  """رنگ‌آمیزی ماسک بر اساس کلاس‌ها"""
142
- # پالت رنگ برای 19 کلاس
143
  palette = [
144
- [0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0],
145
- [255, 0, 255], [0, 255, 255], [128, 0, 0], [0, 128, 0], [0, 0, 128],
146
- [128, 128, 0], [128, 0, 128], [0, 128, 128], [128, 128, 128], [255, 128, 0],
147
- [255, 0, 128], [128, 255, 0], [0, 255, 128], [255, 128, 128]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  ]
149
 
150
  colored = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
@@ -160,7 +243,6 @@ def initialize_app():
160
  print("[Info] PYTHONPATH:", os.environ.get("PYTHONPATH"))
161
  print("[Info] CelebAMask-HQ path exists:", os.path.exists(celebamask_path))
162
  print("[Info] face_parsing folder exists:", os.path.exists(face_parsing_path))
163
- print("[Info] Module import success:", IMPORT_SUCCESS)
164
 
165
  try:
166
  face_parser = FaceParsingModel()
@@ -210,4 +292,73 @@ def process_image(input_image):
210
  traceback.print_exc()
211
  return None, None, error_msg
212
 
213
- # ادامه کد مشابه قبل برای Gradio interface...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import numpy as np
4
  import PIL.Image
5
  import torch
6
+ import torch.nn as nn
7
  import torchvision.transforms as T
8
  from huggingface_hub import hf_hub_download
9
  import gradio as gr
10
  import time
11
+ import cv2
12
 
13
  # افزودن مسیر مورد نیاز برای ماژول‌های CelebAMask-HQ
14
  celebamask_path = "/home/user/app/CelebAMask-HQ"
 
20
  print("CelebAMask path exists:", os.path.exists(celebamask_path))
21
  print("Face parsing path exists:", os.path.exists(face_parsing_path))
22
 
23
+ # تعریف معماری مدل مطابق با state dict دانلود شده
24
+ class SimpleFaceParser(nn.Module):
25
+ def __init__(self, n_channels=3, n_classes=19):
26
+ super(SimpleFaceParser, self).__init__()
27
+
28
+ def conv_block(in_channels, out_channels):
29
+ return nn.Sequential(
30
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
31
+ nn.BatchNorm2d(out_channels),
32
+ nn.ReLU(inplace=True),
33
+ nn.Conv2d(out_channels, out_channels, 3, padding=1),
34
+ nn.BatchNorm2d(out_channels),
35
+ nn.ReLU(inplace=True)
36
+ )
37
+
38
+ # Encoder
39
+ self.enc1 = conv_block(n_channels, 16)
40
+ self.enc2 = conv_block(16, 32)
41
+ self.enc3 = conv_block(32, 64)
42
+ self.enc4 = conv_block(64, 128)
43
+ self.enc5 = conv_block(128, 256)
44
+
45
+ # Decoder
46
+ self.dec4 = conv_block(256 + 128, 128)
47
+ self.dec3 = conv_block(128 + 64, 64)
48
+ self.dec2 = conv_block(64 + 32, 32)
49
+ self.dec1 = conv_block(32 + 16, 16)
50
+
51
+ # Pooling and upsample
52
+ self.pool = nn.MaxPool2d(2)
53
+ self.upsample4 = nn.ConvTranspose2d(256, 128, 2, 2)
54
+ self.upsample3 = nn.ConvTranspose2d(128, 64, 2, 2)
55
+ self.upsample2 = nn.ConvTranspose2d(64, 32, 2, 2)
56
+ self.upsample1 = nn.ConvTranspose2d(32, 16, 2, 2)
57
+
58
+ # Final layer
59
+ self.final = nn.Conv2d(16, n_classes, 1)
60
+
61
+ def forward(self, x):
62
+ # Encoder
63
+ e1 = self.enc1(x)
64
+ e2 = self.enc2(self.pool(e1))
65
+ e3 = self.enc3(self.pool(e2))
66
+ e4 = self.enc4(self.pool(e3))
67
+ e5 = self.enc5(self.pool(e4))
68
+
69
+ # Decoder with skip connections
70
+ d4 = self.upsample4(e5)
71
+ d4 = torch.cat([d4, e4], dim=1)
72
+ d4 = self.dec4(d4)
73
+
74
+ d3 = self.upsample3(d4)
75
+ d3 = torch.cat([d3, e3], dim=1)
76
+ d3 = self.dec3(d3)
77
+
78
+ d2 = self.upsample2(d3)
79
+ d2 = torch.cat([d2, e2], dim=1)
80
+ d2 = self.dec2(d2)
81
+
82
+ d1 = self.upsample1(d2)
83
+ d1 = torch.cat([d1, e1], dim=1)
84
+ d1 = self.dec1(d1)
85
+
86
+ return self.final(d1)
87
+
88
+ def unet(**kwargs):
89
+ return SimpleFaceParser(**kwargs)
90
+
91
+ # تابع generate_label
92
+ def generate_label(inputs, imsize=512):
93
+ """Generate label maps from model outputs"""
94
+ pred_batch = []
95
+ for input in inputs:
96
+ input = input.unsqueeze(0)
97
+ pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
98
+ pred_batch.append(pred)
99
+
100
+ pred_batch = np.array(pred_batch)
101
+ pred_batch = torch.from_numpy(pred_batch)
102
+
103
+ label_batch = []
104
+ for p in pred_batch:
105
+ p = p.view(1, imsize, imsize)
106
+ label_batch.append(p.data.cpu())
107
+
108
+ label_batch = torch.cat(label_batch, 0)
109
+ label_batch = label_batch.type(torch.LongTensor)
110
 
111
+ return label_batch
 
 
 
 
 
 
 
112
 
113
  # تنظیمات دستگاه
114
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
119
 
120
  # تعریف transform
121
  transform = T.Compose([
122
+ T.Resize((512, 512)),
123
  T.ToTensor(),
124
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
125
  ])
126
 
127
  # کلاس‌های Face Parsing
 
147
  )
148
  print(f"✅ Model downloaded to: {model_path}")
149
 
150
+ # ایجاد مدل با معماری صحیح
151
+ self.model = unet(n_channels=3, n_classes=19)
152
+
153
  # لود state dict
154
  state_dict = torch.load(model_path, map_location="cpu")
155
 
156
+ # اگر state dict از DataParallel باشد، module. را حذف می‌کنیم
 
 
 
157
  new_state_dict = {}
158
  for k, v in state_dict.items():
159
  if k.startswith('module.'):
160
+ k = k[7:]
161
  new_state_dict[k] = v
162
 
163
+ self.model.load_state_dict(new_state_dict)
164
  self.model.eval()
165
  self.model.to(self.device)
166
 
 
193
  # پیش‌بینی
194
  with torch.no_grad():
195
  out = self.model(data)
196
+ label_out = generate_label(out, 512)
197
+ mask = label_out[0].cpu().numpy()
198
 
199
+ # رنگ‌آمیزی ماسک
200
+ colored_mask = self.colorize_mask(mask)
201
 
202
  # ترکیب تصویر اصلی با ماسک
203
  resized_image = np.asarray(original_image.resize((512, 512)))
204
+ blended = cv2.addWeighted(resized_image, 0.7, colored_mask, 0.3, 0)
 
205
 
206
  return colored_mask, blended
207
 
208
  def colorize_mask(self, mask):
209
  """رنگ‌آمیزی ماسک بر اساس کلاس‌ها"""
210
+ # پالت رنگ برای 19 کلاس (متفاوت برای تشخیص بهتر)
211
  palette = [
212
+ [0, 0, 0], # background - سیاه
213
+ [255, 200, 200], # skin - پوست
214
+ [0, 255, 0], # l_brow - سبز
215
+ [0, 200, 0], # r_brow - سبز تیره
216
+ [255, 0, 0], # l_eye - قرمز
217
+ [200, 0, 0], # r_eye - قرمز تیره
218
+ [255, 255, 0], # eye_g - زرد
219
+ [0, 0, 255], # l_ear - آبی
220
+ [0, 0, 200], # r_ear - آبی تیره
221
+ [128, 0, 128], # ear_r - بنفش
222
+ [255, 165, 0], # nose - نارنجی
223
+ [255, 0, 255], # mouth - صورتی
224
+ [200, 0, 200], # u_lip - صورتی تیره
225
+ [165, 42, 42], # l_lip - قهوه‌ای
226
+ [0, 255, 255], # neck - فیروزه‌ای
227
+ [0, 200, 200], # neck_l - فیروزه‌ای تیره
228
+ [128, 128, 128], # cloth - خاکستری
229
+ [255, 255, 255], # hair - سفید
230
+ [255, 215, 0] # hat - طلایی
231
  ]
232
 
233
  colored = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
 
243
  print("[Info] PYTHONPATH:", os.environ.get("PYTHONPATH"))
244
  print("[Info] CelebAMask-HQ path exists:", os.path.exists(celebamask_path))
245
  print("[Info] face_parsing folder exists:", os.path.exists(face_parsing_path))
 
246
 
247
  try:
248
  face_parser = FaceParsingModel()
 
292
  traceback.print_exc()
293
  return None, None, error_msg
294
 
295
+ # ادامه کد Gradio (مشابه قبل)
296
+
297
+ # ایجاد اینترفیس Gradio
298
+ with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as demo:
299
+ gr.Markdown("""
300
+ # 🎭 CelebAMask-HQ Face Parsing Demo
301
+ **آپلود یک تصویر صورت و دریافت خروجی Face Parsing**
302
+
303
+ این مدل صورت را به 19 بخش مختلف تقسیم می‌کند (پوست، چشم، ابرو، بینی، دهان، مو و ...)
304
+ """)
305
+
306
+ with gr.Row():
307
+ with gr.Column():
308
+ input_image = gr.Image(
309
+ label="📷 تصویر ورودی",
310
+ type="filepath",
311
+ sources=["upload"],
312
+ height=300
313
+ )
314
+ process_btn = gr.Button("🚀 پردازش تصویر", variant="primary", size="lg")
315
+
316
+ with gr.Accordion("ℹ️ وضعیت برنامه", open=False):
317
+ status_display = gr.Markdown(f"""
318
+ **وضعیت:**
319
+ - 🎯 مدل: {'✅ لود شده' if success else '❌ خطا در لود'}
320
+ - 💻 دستگاه: `{device}`
321
+ - 📦 ماژول‌ها: {'✅ ایمپورت شده' if IMPORT_SUCCESS else '❌ خطا در ایمپورت'}
322
+ - 🗂️ کلاس‌ها: {len(CELEBA_CLASSES)}
323
+ """)
324
+
325
+ with gr.Column():
326
+ output_blended = gr.Image(
327
+ label="🎨 نتیجه ترکیبی (تصویر + ماسک)",
328
+ height=300
329
+ )
330
+ output_mask = gr.Image(
331
+ label="🎭 ماسک سگمنتیشن",
332
+ height=300
333
+ )
334
+
335
+ with gr.Row():
336
+ info_output = gr.Textbox(
337
+ label="📊 اطلاعات پردازش",
338
+ lines=3,
339
+ max_lines=6
340
+ )
341
+
342
+ with gr.Row():
343
+ gr.HTML(create_legend())
344
+
345
+ # اتصال رویدادها
346
+ process_btn.click(
347
+ fn=process_image,
348
+ inputs=[input_image],
349
+ outputs=[output_blended, output_mask, info_output]
350
+ )
351
+
352
+ input_image.upload(
353
+ fn=process_image,
354
+ inputs=[input_image],
355
+ outputs=[output_blended, output_mask, info_output]
356
+ )
357
+
358
+ if __name__ == "__main__":
359
+ print("🚀 Starting Face Parsing Application...")
360
+ demo.launch(
361
+ server_name="0.0.0.0",
362
+ server_port=7860,
363
+ share=False
364
+ )