danicor commited on
Commit
9dcfc8c
·
verified ·
1 Parent(s): bf34ac7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -138
app.py CHANGED
@@ -1,194 +1,185 @@
1
  import os
2
- import cv2
3
- import torch
4
  import numpy as np
5
- from PIL import Image
6
- import gradio as gr
 
7
  from huggingface_hub import hf_hub_download
 
8
  import time
9
- import matplotlib.pyplot as plt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # تنظیم مسیرهای کش
12
  os.environ["HF_HOME"] = "/home/user/app/hf_cache"
13
 
14
- # کلاس‌های Face Parsing (19 کلاس)
 
 
 
 
 
 
 
15
  CELEBA_CLASSES = [
16
  'background', 'skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
17
  'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'
18
  ]
19
 
20
- class FaceParsing:
21
- def __init__(self, model_path):
22
- self.model_path = model_path
23
  self.model = None
24
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
  self.load_model()
26
 
27
  def load_model(self):
28
  """لود مدل Face Parsing"""
29
  try:
30
- # اینجا باید معماری مدل را بر اساس ریپوی اصلی تنظیم کنید
31
- # برای سادگی از یک مدل ساده استفاده می‌کنیم
32
- self.model = torch.jit.load(self.model_path, map_location=self.device)
 
 
 
 
 
 
 
 
 
 
 
33
  self.model.eval()
34
- print("[Success] Model loaded successfully")
 
 
 
35
  except Exception as e:
36
- print(f"[Error] Failed to load model: {e}")
37
- # اگر مدل قابل لود نیست، یک مدل ساده ایجاد می‌کنیم
38
- self.model = SimpleFaceParser()
39
 
40
- def preprocess_image(self, image):
41
- """پیش‌پردازش تصویر ورودی"""
42
- # تبدیل به RGB اگر لازم است
 
 
 
43
  if isinstance(image, str):
44
- image = Image.open(image).convert('RGB')
45
  elif isinstance(image, np.ndarray):
46
- image = Image.fromarray(image)
47
-
48
- # تغییر سایز به 512x512
49
- image = image.resize((512, 512))
50
 
51
- # تبدیل به tensor و نرمال‌سازی
52
- image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
53
- image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)
54
 
55
- return image_tensor.to(self.device), image
56
-
57
- def postprocess_mask(self, mask):
58
- """پس‌پردازش ماسک خروجی"""
59
- mask = mask.squeeze().cpu().numpy()
60
- return mask.astype(np.uint8)
61
-
62
- def predict(self, image):
63
- """پیش‌بینی روی تصویر ورودی"""
64
- try:
65
- image_tensor, original_image = self.preprocess_image(image)
66
-
67
- with torch.no_grad():
68
- if hasattr(self.model, 'predict'):
69
- output = self.model.predict(image_tensor)
70
- else:
71
- output = self.model(image_tensor)
72
-
73
- # گرفتن ماسک پیش‌بینی شده
74
- if isinstance(output, tuple):
75
- mask = output[0]
76
- else:
77
- mask = output
78
-
79
- parsed_mask = self.postprocess_mask(mask)
80
- return self.visualize_result(np.array(original_image), parsed_mask)
81
-
82
- except Exception as e:
83
- print(f"[Error] Prediction failed: {e}")
84
- # بازگشت تصویر اصلی در صورت خطا
85
- if isinstance(image, str):
86
- original_img = Image.open(image)
87
- else:
88
- original_img = image
89
- return original_img, original_img
90
 
91
- def visualize_result(self, original_image, mask):
92
- """ویژوالایز کردن نتایج"""
93
- # ایجاد تصویر رنگی از ماسک
94
- colored_mask = self.colorize_mask(mask)
95
-
96
  # ترکیب تصویر اصلی با ماسک
97
- overlay = cv2.addWeighted(original_image, 0.7, colored_mask, 0.3, 0)
98
-
99
- return overlay, colored_mask
100
-
101
- def colorize_mask(self, mask):
102
- """رنگ‌آمیزی ماسک بر اساس کلاس‌ها"""
103
- # ایجاد پالت رنگ برای کلاس‌ها
104
- cmap = plt.get_cmap('tab20', len(CELEBA_CLASSES))
105
- colored_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
106
-
107
- for i in range(len(CELEBA_CLASSES)):
108
- colored_mask[mask == i] = np.array(cmap(i)[:3]) * 255
109
-
110
- return colored_mask.astype(np.uint8)
111
 
112
- class SimpleFaceParser:
113
- """یک پارسر ساده برای مواقعی که مدل اصلی کار نمی‌کند"""
114
- def __init__(self):
115
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
116
-
117
- def __call__(self, x):
118
- # یک خروجی ساده شبیه ماسک تولید می‌کند
119
- batch_size, channels, height, width = x.shape
120
- return torch.randint(0, len(CELEBA_CLASSES), (batch_size, 1, height, width)).float().to(self.device)
121
 
122
  def initialize_app():
123
- """Initialize application and download models"""
124
  print("===== Application Startup at {} =====".format(time.strftime("%Y-%m-%d %H:%M:%S")))
125
 
126
- celeb_path = "/home/user/app/huggingface_models/CelebAMask-HQ"
127
- face_parsing_path = os.path.join(celeb_path, "face_parsing")
128
-
129
  print("[Info] PYTHONPATH:", os.environ.get("PYTHONPATH"))
130
- print("[Info] CelebAMask-HQ path exists:", os.path.exists(celeb_path))
131
- print("[Info] face_parsing folder exists:", os.path.exists(face_parsing_path))
132
-
 
133
  try:
134
- model_path = hf_hub_download(
135
- repo_id="public-data/CelebAMask-HQ-Face-Parsing",
136
- filename="models/model.pth",
137
- cache_dir="/home/user/app/hf_cache"
138
- )
139
- print("[Success] Model downloaded to:", model_path)
140
-
141
- # لود مدل
142
- face_parser = FaceParsing(model_path)
143
- return True, model_path, face_parser
144
-
145
  except Exception as e:
146
- print("[Error] Failed to download model:", str(e))
147
- # استفاده از پارسر ساده
148
- face_parser = FaceParsing(None)
149
- return False, str(e), face_parser
150
 
151
  # Initialize the application
152
- success, model_info, face_parser = initialize_app()
153
 
154
  def process_image(input_image):
155
- """پردازش تصویر ورودی و بازگشت نتایج"""
156
  if input_image is None:
157
  return None, None, "لطفاً یک تصویر آپلود کنید"
158
 
 
 
 
159
  try:
160
  # پردازش تصویر
161
- overlay_result, mask_result = face_parser.predict(input_image)
162
 
163
- # اطلاعات درباره تصویر
164
  if isinstance(input_image, str):
165
- img_size = Image.open(input_image).size
 
166
  else:
167
  img_size = input_image.size if hasattr(input_image, 'size') else input_image.shape[:2][::-1]
 
168
 
169
  info_text = f"""
170
  ✅ پردازش انجام شد!
171
  - اندازه تصویر ورودی: {img_size}
172
- - مدل: {'CelebAMask-HQ' if success else 'Simple Parser'}
173
  - کلاس‌های تشخیص: {len(CELEBA_CLASSES)}
 
174
  """
175
 
176
- return overlay_result, mask_result, info_text
177
 
178
  except Exception as e:
179
- error_msg = f"خطا در پردازش تصویر: {str(e)}"
180
  print(error_msg)
181
  return None, None, error_msg
182
 
183
  def create_legend():
184
  """ایجاد لیجند برای کلاس‌ها"""
185
- legend_html = "<div style='max-height: 300px; overflow-y: auto;'><h4>Legend - کلاس‌های Face Parsing:</h4>"
 
 
 
 
 
 
186
  colors = plt.get_cmap('tab20', len(CELEBA_CLASSES))
187
 
188
  for i, class_name in enumerate(CELEBA_CLASSES):
189
  color = colors(i)
190
  color_hex = '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))
191
- legend_html += f"<div style='margin: 2px; padding: 5px; background-color: {color_hex}; color: white;'>{i}: {class_name}</div>"
 
 
 
 
 
192
 
193
  legend_html += "</div>"
194
  return legend_html
@@ -198,6 +189,8 @@ with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as de
198
  gr.Markdown("""
199
  # 🎭 CelebAMask-HQ Face Parsing Demo
200
  **آپلود یک تصویر صورت و دریافت خروجی Face Parsing**
 
 
201
  """)
202
 
203
  with gr.Row():
@@ -205,26 +198,27 @@ with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as de
205
  input_image = gr.Image(
206
  label="📷 تصویر ورودی",
207
  type="filepath",
208
- sources=["upload", "webcam"],
209
  height=300
210
  )
211
- process_btn = gr.Button("🚀 پردازش تصویر", variant="primary")
212
 
213
- with gr.Accordion("ℹ️ اطلاعات برنامه", open=False):
214
- status_text = gr.Markdown(f"""
215
- **وضعیت برنامه:**
216
- - مدل: {'✅ موفق' if success else '⚠️ ساده'}
217
- - مسیر مدل: `{model_info if success else 'مدل پیش‌فرض'}`
218
- - کلاس‌های تشخیص: {len(CELEBA_CLASSES)}
 
219
  """)
220
 
221
  with gr.Column():
222
- output_overlay = gr.Image(
223
- label="🎨 نتیجه ترکیبی (Overlay)",
224
  height=300
225
  )
226
  output_mask = gr.Image(
227
- label="🎭 ماسک segmentation",
228
  height=300
229
  )
230
 
@@ -242,18 +236,17 @@ with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as de
242
  process_btn.click(
243
  fn=process_image,
244
  inputs=[input_image],
245
- outputs=[output_overlay, output_mask, info_output]
246
  )
247
 
248
- # پردازش خودکار هنگام آپلود تصویر
249
  input_image.upload(
250
  fn=process_image,
251
  inputs=[input_image],
252
- outputs=[output_overlay, output_mask, info_output]
253
  )
254
 
255
  if __name__ == "__main__":
256
- print("Starting Face Parsing Application...")
257
  demo.launch(
258
  server_name="0.0.0.0",
259
  server_port=7860,
 
1
  import os
2
+ import sys
 
3
  import numpy as np
4
+ import PIL.Image
5
+ import torch
6
+ import torchvision.transforms as T
7
  from huggingface_hub import hf_hub_download
8
+ import gradio as gr
9
  import time
10
+
11
+ # افزودن مسیر مورد نیاز برای ماژول‌های CelebAMask-HQ
12
+ celebamask_path = "/home/user/app/CelebAMask-HQ"
13
+ sys.path.insert(0, os.path.join(celebamask_path, "face_parsing"))
14
+
15
+ # ایمپورت ماژول‌های مورد نیاز
16
+ try:
17
+ from unet import unet
18
+ from utils import generate_label
19
+ IMPORT_SUCCESS = True
20
+ print("✅ Successfully imported CelebAMask-HQ modules")
21
+ except ImportError as e:
22
+ IMPORT_SUCCESS = False
23
+ print(f"❌ Failed to import CelebAMask-HQ modules: {e}")
24
+
25
+ # تنظیمات دستگاه
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ print(f"Using device: {device}")
28
 
29
  # تنظیم مسیرهای کش
30
  os.environ["HF_HOME"] = "/home/user/app/hf_cache"
31
 
32
+ # تعریف transform
33
+ transform = T.Compose([
34
+ T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
35
+ T.ToTensor(),
36
+ T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
37
+ ])
38
+
39
+ # کلاس‌های Face Parsing
40
  CELEBA_CLASSES = [
41
  'background', 'skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
42
  'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'
43
  ]
44
 
45
+ class FaceParsingModel:
46
+ def __init__(self):
 
47
  self.model = None
48
+ self.device = device
49
  self.load_model()
50
 
51
  def load_model(self):
52
  """لود مدل Face Parsing"""
53
  try:
54
+ print("📥 Downloading model...")
55
+ model_path = hf_hub_download(
56
+ repo_id="public-data/CelebAMask-HQ-Face-Parsing",
57
+ filename="models/model.pth",
58
+ cache_dir="/home/user/app/hf_cache"
59
+ )
60
+ print(f"✅ Model downloaded to: {model_path}")
61
+
62
+ # لود state dict
63
+ state_dict = torch.load(model_path, map_location="cpu")
64
+
65
+ # ایجاد مدل
66
+ self.model = unet()
67
+ self.model.load_state_dict(state_dict)
68
  self.model.eval()
69
+ self.model.to(self.device)
70
+
71
+ print("✅ Model loaded successfully")
72
+
73
  except Exception as e:
74
+ print(f" Failed to load model: {e}")
75
+ self.model = None
 
76
 
77
+ def predict(self, image):
78
+ """پردازش تصویر و تولید ماسک"""
79
+ if self.model is None:
80
+ raise ValueError("Model not loaded properly")
81
+
82
+ # تبدیل به PIL Image اگر لازم است
83
  if isinstance(image, str):
84
+ image = PIL.Image.open(image).convert('RGB')
85
  elif isinstance(image, np.ndarray):
86
+ image = PIL.Image.fromarray(image)
 
 
 
87
 
88
+ # ذخیره تصویر اصلی
89
+ original_image = image.copy()
 
90
 
91
+ # پیش‌پردازش
92
+ data = transform(image)
93
+ data = data.unsqueeze(0).to(self.device)
94
+
95
+ # پیش‌بینی
96
+ with torch.no_grad():
97
+ out = self.model(data)
98
+ out = generate_label(out, 512)
99
+ out = out[0].cpu().numpy().transpose(1, 2, 0)
100
+ out = np.clip(np.round(out * 255), 0, 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
 
 
 
 
 
102
  # ترکیب تصویر اصلی با ماسک
103
+ resized_image = np.asarray(original_image.resize((512, 512))).astype(float)
104
+ blended = resized_image * 0.5 + out.astype(float) * 0.5
105
+ blended = np.clip(np.round(blended), 0, 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ return out, blended
 
 
 
 
 
 
 
 
108
 
109
  def initialize_app():
110
+ """Initialize application"""
111
  print("===== Application Startup at {} =====".format(time.strftime("%Y-%m-%d %H:%M:%S")))
112
 
 
 
 
113
  print("[Info] PYTHONPATH:", os.environ.get("PYTHONPATH"))
114
+ print("[Info] CelebAMask-HQ path exists:", os.path.exists(celebamask_path))
115
+ print("[Info] face_parsing folder exists:", os.path.exists(os.path.join(celebamask_path, "face_parsing")))
116
+ print("[Info] Module import success:", IMPORT_SUCCESS)
117
+
118
  try:
119
+ face_parser = FaceParsingModel()
120
+ return True, "Model loaded successfully", face_parser
 
 
 
 
 
 
 
 
 
121
  except Exception as e:
122
+ print(f"[Error] Initialization failed: {e}")
123
+ return False, f"Initialization failed: {e}", None
 
 
124
 
125
  # Initialize the application
126
+ success, status_msg, face_parser = initialize_app()
127
 
128
  def process_image(input_image):
129
+ """پردازش تصویر ورودی"""
130
  if input_image is None:
131
  return None, None, "لطفاً یک تصویر آپلود کنید"
132
 
133
+ if not success or face_parser is None:
134
+ return None, None, "❌ مدل لود نشده است. لطفاً دوباره تلاش کنید."
135
+
136
  try:
137
  # پردازش تصویر
138
+ mask, blended = face_parser.predict(input_image)
139
 
140
+ # اطلاعات پردازش
141
  if isinstance(input_image, str):
142
+ original_img = PIL.Image.open(input_image)
143
+ img_size = original_img.size
144
  else:
145
  img_size = input_image.size if hasattr(input_image, 'size') else input_image.shape[:2][::-1]
146
+ original_img = PIL.Image.fromarray(input_image) if isinstance(input_image, np.ndarray) else input_image
147
 
148
  info_text = f"""
149
  ✅ پردازش انجام شد!
150
  - اندازه تصویر ورودی: {img_size}
151
+ - اندازه خروجی: 512x512
152
  - کلاس‌های تشخیص: {len(CELEBA_CLASSES)}
153
+ - دستگاه پردازش: {device}
154
  """
155
 
156
+ return blended, mask, info_text
157
 
158
  except Exception as e:
159
+ error_msg = f"خطا در پردازش تصویر: {str(e)}"
160
  print(error_msg)
161
  return None, None, error_msg
162
 
163
  def create_legend():
164
  """ایجاد لیجند برای کلاس‌ها"""
165
+ import matplotlib.pyplot as plt
166
+
167
+ legend_html = """
168
+ <div style='max-height: 300px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; border-radius: 5px;'>
169
+ <h4>🎨 Legend - کلاس‌های Face Parsing:</h4>
170
+ """
171
+
172
  colors = plt.get_cmap('tab20', len(CELEBA_CLASSES))
173
 
174
  for i, class_name in enumerate(CELEBA_CLASSES):
175
  color = colors(i)
176
  color_hex = '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))
177
+ text_color = 'white' if color[0] * 0.299 + color[1] * 0.587 + color[2] * 0.114 < 0.5 else 'black'
178
+ legend_html += f"""
179
+ <div style='margin: 2px; padding: 5px; background-color: {color_hex}; color: {text_color}; border-radius: 3px;'>
180
+ <strong>{i}:</strong> {class_name}
181
+ </div>
182
+ """
183
 
184
  legend_html += "</div>"
185
  return legend_html
 
189
  gr.Markdown("""
190
  # 🎭 CelebAMask-HQ Face Parsing Demo
191
  **آپلود یک تصویر صورت و دریافت خروجی Face Parsing**
192
+
193
+ این مدل صورت را به 19 بخش مختلف تقسیم می‌کند (پوست، چشم، ابرو، بینی، دهان، مو و ...)
194
  """)
195
 
196
  with gr.Row():
 
198
  input_image = gr.Image(
199
  label="📷 تصویر ورودی",
200
  type="filepath",
201
+ sources=["upload"],
202
  height=300
203
  )
204
+ process_btn = gr.Button("🚀 پردازش تصویر", variant="primary", size="lg")
205
 
206
+ with gr.Accordion("ℹ️ وضعیت برنامه", open=False):
207
+ status_display = gr.Markdown(f"""
208
+ **وضعیت:**
209
+ - 🎯 مدل: {'✅ لود شده' if success else ' خطا در لود'}
210
+ - 💻 دستگاه: `{device}`
211
+ - 📦 ماژول‌ها: {'✅ ایمپورت شده' if IMPORT_SUCCESS else '❌ خطا در ایمپورت'}
212
+ - 🗂️ کلاس‌ها: {len(CELEBA_CLASSES)}
213
  """)
214
 
215
  with gr.Column():
216
+ output_blended = gr.Image(
217
+ label="🎨 نتیجه ترکیبی (تصویر + ماسک)",
218
  height=300
219
  )
220
  output_mask = gr.Image(
221
+ label="🎭 ماسک سگمنتیشن",
222
  height=300
223
  )
224
 
 
236
  process_btn.click(
237
  fn=process_image,
238
  inputs=[input_image],
239
+ outputs=[output_blended, output_mask, info_output]
240
  )
241
 
 
242
  input_image.upload(
243
  fn=process_image,
244
  inputs=[input_image],
245
+ outputs=[output_blended, output_mask, info_output]
246
  )
247
 
248
  if __name__ == "__main__":
249
+ print("🚀 Starting Face Parsing Application...")
250
  demo.launch(
251
  server_name="0.0.0.0",
252
  server_port=7860,