danicor commited on
Commit
3755310
·
verified ·
1 Parent(s): 0e96e5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -94
app.py CHANGED
@@ -3,7 +3,6 @@ import sys
3
  import numpy as np
4
  import PIL.Image
5
  import torch
6
- import torch.nn as nn
7
  import torchvision.transforms as T
8
  from huggingface_hub import hf_hub_download
9
  import gradio as gr
@@ -20,95 +19,15 @@ print("Python path:", sys.path)
20
  print("CelebAMask path exists:", os.path.exists(celebamask_path))
21
  print("Face parsing path exists:", os.path.exists(face_parsing_path))
22
 
23
- # تعریف معماری مدل مطابق با state dict دانلود شده
24
- class SimpleFaceParser(nn.Module):
25
- def __init__(self, n_channels=3, n_classes=19):
26
- super(SimpleFaceParser, self).__init__()
27
-
28
- def conv_block(in_channels, out_channels):
29
- return nn.Sequential(
30
- nn.Conv2d(in_channels, out_channels, 3, padding=1),
31
- nn.BatchNorm2d(out_channels),
32
- nn.ReLU(inplace=True),
33
- nn.Conv2d(out_channels, out_channels, 3, padding=1),
34
- nn.BatchNorm2d(out_channels),
35
- nn.ReLU(inplace=True)
36
- )
37
-
38
- # Encoder
39
- self.enc1 = conv_block(n_channels, 16)
40
- self.enc2 = conv_block(16, 32)
41
- self.enc3 = conv_block(32, 64)
42
- self.enc4 = conv_block(64, 128)
43
- self.enc5 = conv_block(128, 256)
44
-
45
- # Decoder
46
- self.dec4 = conv_block(256 + 128, 128)
47
- self.dec3 = conv_block(128 + 64, 64)
48
- self.dec2 = conv_block(64 + 32, 32)
49
- self.dec1 = conv_block(32 + 16, 16)
50
-
51
- # Pooling and upsample
52
- self.pool = nn.MaxPool2d(2)
53
- self.upsample4 = nn.ConvTranspose2d(256, 128, 2, 2)
54
- self.upsample3 = nn.ConvTranspose2d(128, 64, 2, 2)
55
- self.upsample2 = nn.ConvTranspose2d(64, 32, 2, 2)
56
- self.upsample1 = nn.ConvTranspose2d(32, 16, 2, 2)
57
-
58
- # Final layer
59
- self.final = nn.Conv2d(16, n_classes, 1)
60
-
61
- def forward(self, x):
62
- # Encoder
63
- e1 = self.enc1(x)
64
- e2 = self.enc2(self.pool(e1))
65
- e3 = self.enc3(self.pool(e2))
66
- e4 = self.enc4(self.pool(e3))
67
- e5 = self.enc5(self.pool(e4))
68
-
69
- # Decoder with skip connections
70
- d4 = self.upsample4(e5)
71
- d4 = torch.cat([d4, e4], dim=1)
72
- d4 = self.dec4(d4)
73
-
74
- d3 = self.upsample3(d4)
75
- d3 = torch.cat([d3, e3], dim=1)
76
- d3 = self.dec3(d3)
77
-
78
- d2 = self.upsample2(d3)
79
- d2 = torch.cat([d2, e2], dim=1)
80
- d2 = self.dec2(d2)
81
-
82
- d1 = self.upsample1(d2)
83
- d1 = torch.cat([d1, e1], dim=1)
84
- d1 = self.dec1(d1)
85
-
86
- return self.final(d1)
87
-
88
- def unet(**kwargs):
89
- return SimpleFaceParser(**kwargs)
90
-
91
- # تابع generate_label
92
- def generate_label(inputs, imsize=512):
93
- """Generate label maps from model outputs"""
94
- pred_batch = []
95
- for input in inputs:
96
- input = input.unsqueeze(0)
97
- pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
98
- pred_batch.append(pred)
99
-
100
- pred_batch = np.array(pred_batch)
101
- pred_batch = torch.from_numpy(pred_batch)
102
-
103
- label_batch = []
104
- for p in pred_batch:
105
- p = p.view(1, imsize, imsize)
106
- label_batch.append(p.data.cpu())
107
-
108
- label_batch = torch.cat(label_batch, 0)
109
- label_batch = label_batch.type(torch.LongTensor)
110
-
111
- return label_batch
112
 
113
  # تنظیمات دستگاه
114
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -148,7 +67,13 @@ class FaceParsingModel:
148
  print(f"✅ Model downloaded to: {model_path}")
149
 
150
  # ایجاد مدل با معماری صحیح
151
- self.model = unet(n_channels=3, n_classes=19)
 
 
 
 
 
 
152
 
153
  # لود state dict
154
  state_dict = torch.load(model_path, map_location="cpu")
@@ -243,6 +168,7 @@ def initialize_app():
243
  print("[Info] PYTHONPATH:", os.environ.get("PYTHONPATH"))
244
  print("[Info] CelebAMask-HQ path exists:", os.path.exists(celebamask_path))
245
  print("[Info] face_parsing folder exists:", os.path.exists(face_parsing_path))
 
246
 
247
  try:
248
  face_parser = FaceParsingModel()
@@ -292,9 +218,29 @@ def process_image(input_image):
292
  traceback.print_exc()
293
  return None, None, error_msg
294
 
295
- # ادامه کد Gradio (مشابه قبل)
296
-
297
- # ادامه کد Gradio (مشابه قبل)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
  # ایجاد اینترفیس Gradio
300
  with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as demo:
 
3
  import numpy as np
4
  import PIL.Image
5
  import torch
 
6
  import torchvision.transforms as T
7
  from huggingface_hub import hf_hub_download
8
  import gradio as gr
 
19
  print("CelebAMask path exists:", os.path.exists(celebamask_path))
20
  print("Face parsing path exists:", os.path.exists(face_parsing_path))
21
 
22
+ # ایمپورت ماژول‌های مورد نیاز
23
+ try:
24
+ from unet import unet
25
+ from utils import generate_label
26
+ IMPORT_SUCCESS = True
27
+ print("✅ Successfully imported CelebAMask-HQ modules")
28
+ except ImportError as e:
29
+ IMPORT_SUCCESS = False
30
+ print(f"❌ Failed to import CelebAMask-HQ modules: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # تنظیمات دستگاه
33
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
67
  print(f"✅ Model downloaded to: {model_path}")
68
 
69
  # ایجاد مدل با معماری صحیح
70
+ self.model = unet(
71
+ feature_scale=4,
72
+ n_classes=19,
73
+ is_deconv=True,
74
+ in_channels=3,
75
+ is_batchnorm=True
76
+ )
77
 
78
  # لود state dict
79
  state_dict = torch.load(model_path, map_location="cpu")
 
168
  print("[Info] PYTHONPATH:", os.environ.get("PYTHONPATH"))
169
  print("[Info] CelebAMask-HQ path exists:", os.path.exists(celebamask_path))
170
  print("[Info] face_parsing folder exists:", os.path.exists(face_parsing_path))
171
+ print("[Info] Module import success:", IMPORT_SUCCESS)
172
 
173
  try:
174
  face_parser = FaceParsingModel()
 
218
  traceback.print_exc()
219
  return None, None, error_msg
220
 
221
+ def create_legend():
222
+ """ایجاد لیجند برای کلاس‌ها"""
223
+ import matplotlib.pyplot as plt
224
+
225
+ legend_html = """
226
+ <div style='max-height: 300px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; border-radius: 5px;'>
227
+ <h4>🎨 Legend - کلاس‌های Face Parsing:</h4>
228
+ """
229
+
230
+ colors = plt.get_cmap('tab20', len(CELEBA_CLASSES))
231
+
232
+ for i, class_name in enumerate(CELEBA_CLASSES):
233
+ color = colors(i)
234
+ color_hex = '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))
235
+ text_color = 'white' if color[0] * 0.299 + color[1] * 0.587 + color[2] * 0.114 < 0.5 else 'black'
236
+ legend_html += f"""
237
+ <div style='margin: 2px; padding: 5px; background-color: {color_hex}; color: {text_color}; border-radius: 3px;'>
238
+ <strong>{i}:</strong> {class_name}
239
+ </div>
240
+ """
241
+
242
+ legend_html += "</div>"
243
+ return legend_html
244
 
245
  # ایجاد اینترفیس Gradio
246
  with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as demo: