Update app.py
Browse files
app.py
CHANGED
|
@@ -3,7 +3,6 @@ import sys
|
|
| 3 |
import numpy as np
|
| 4 |
import PIL.Image
|
| 5 |
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
import torchvision.transforms as T
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
import gradio as gr
|
|
@@ -20,95 +19,15 @@ print("Python path:", sys.path)
|
|
| 20 |
print("CelebAMask path exists:", os.path.exists(celebamask_path))
|
| 21 |
print("Face parsing path exists:", os.path.exists(face_parsing_path))
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
nn.ReLU(inplace=True),
|
| 33 |
-
nn.Conv2d(out_channels, out_channels, 3, padding=1),
|
| 34 |
-
nn.BatchNorm2d(out_channels),
|
| 35 |
-
nn.ReLU(inplace=True)
|
| 36 |
-
)
|
| 37 |
-
|
| 38 |
-
# Encoder
|
| 39 |
-
self.enc1 = conv_block(n_channels, 16)
|
| 40 |
-
self.enc2 = conv_block(16, 32)
|
| 41 |
-
self.enc3 = conv_block(32, 64)
|
| 42 |
-
self.enc4 = conv_block(64, 128)
|
| 43 |
-
self.enc5 = conv_block(128, 256)
|
| 44 |
-
|
| 45 |
-
# Decoder
|
| 46 |
-
self.dec4 = conv_block(256 + 128, 128)
|
| 47 |
-
self.dec3 = conv_block(128 + 64, 64)
|
| 48 |
-
self.dec2 = conv_block(64 + 32, 32)
|
| 49 |
-
self.dec1 = conv_block(32 + 16, 16)
|
| 50 |
-
|
| 51 |
-
# Pooling and upsample
|
| 52 |
-
self.pool = nn.MaxPool2d(2)
|
| 53 |
-
self.upsample4 = nn.ConvTranspose2d(256, 128, 2, 2)
|
| 54 |
-
self.upsample3 = nn.ConvTranspose2d(128, 64, 2, 2)
|
| 55 |
-
self.upsample2 = nn.ConvTranspose2d(64, 32, 2, 2)
|
| 56 |
-
self.upsample1 = nn.ConvTranspose2d(32, 16, 2, 2)
|
| 57 |
-
|
| 58 |
-
# Final layer
|
| 59 |
-
self.final = nn.Conv2d(16, n_classes, 1)
|
| 60 |
-
|
| 61 |
-
def forward(self, x):
|
| 62 |
-
# Encoder
|
| 63 |
-
e1 = self.enc1(x)
|
| 64 |
-
e2 = self.enc2(self.pool(e1))
|
| 65 |
-
e3 = self.enc3(self.pool(e2))
|
| 66 |
-
e4 = self.enc4(self.pool(e3))
|
| 67 |
-
e5 = self.enc5(self.pool(e4))
|
| 68 |
-
|
| 69 |
-
# Decoder with skip connections
|
| 70 |
-
d4 = self.upsample4(e5)
|
| 71 |
-
d4 = torch.cat([d4, e4], dim=1)
|
| 72 |
-
d4 = self.dec4(d4)
|
| 73 |
-
|
| 74 |
-
d3 = self.upsample3(d4)
|
| 75 |
-
d3 = torch.cat([d3, e3], dim=1)
|
| 76 |
-
d3 = self.dec3(d3)
|
| 77 |
-
|
| 78 |
-
d2 = self.upsample2(d3)
|
| 79 |
-
d2 = torch.cat([d2, e2], dim=1)
|
| 80 |
-
d2 = self.dec2(d2)
|
| 81 |
-
|
| 82 |
-
d1 = self.upsample1(d2)
|
| 83 |
-
d1 = torch.cat([d1, e1], dim=1)
|
| 84 |
-
d1 = self.dec1(d1)
|
| 85 |
-
|
| 86 |
-
return self.final(d1)
|
| 87 |
-
|
| 88 |
-
def unet(**kwargs):
|
| 89 |
-
return SimpleFaceParser(**kwargs)
|
| 90 |
-
|
| 91 |
-
# تابع generate_label
|
| 92 |
-
def generate_label(inputs, imsize=512):
|
| 93 |
-
"""Generate label maps from model outputs"""
|
| 94 |
-
pred_batch = []
|
| 95 |
-
for input in inputs:
|
| 96 |
-
input = input.unsqueeze(0)
|
| 97 |
-
pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
|
| 98 |
-
pred_batch.append(pred)
|
| 99 |
-
|
| 100 |
-
pred_batch = np.array(pred_batch)
|
| 101 |
-
pred_batch = torch.from_numpy(pred_batch)
|
| 102 |
-
|
| 103 |
-
label_batch = []
|
| 104 |
-
for p in pred_batch:
|
| 105 |
-
p = p.view(1, imsize, imsize)
|
| 106 |
-
label_batch.append(p.data.cpu())
|
| 107 |
-
|
| 108 |
-
label_batch = torch.cat(label_batch, 0)
|
| 109 |
-
label_batch = label_batch.type(torch.LongTensor)
|
| 110 |
-
|
| 111 |
-
return label_batch
|
| 112 |
|
| 113 |
# تنظیمات دستگاه
|
| 114 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -148,7 +67,13 @@ class FaceParsingModel:
|
|
| 148 |
print(f"✅ Model downloaded to: {model_path}")
|
| 149 |
|
| 150 |
# ایجاد مدل با معماری صحیح
|
| 151 |
-
self.model = unet(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
# لود state dict
|
| 154 |
state_dict = torch.load(model_path, map_location="cpu")
|
|
@@ -243,6 +168,7 @@ def initialize_app():
|
|
| 243 |
print("[Info] PYTHONPATH:", os.environ.get("PYTHONPATH"))
|
| 244 |
print("[Info] CelebAMask-HQ path exists:", os.path.exists(celebamask_path))
|
| 245 |
print("[Info] face_parsing folder exists:", os.path.exists(face_parsing_path))
|
|
|
|
| 246 |
|
| 247 |
try:
|
| 248 |
face_parser = FaceParsingModel()
|
|
@@ -292,9 +218,29 @@ def process_image(input_image):
|
|
| 292 |
traceback.print_exc()
|
| 293 |
return None, None, error_msg
|
| 294 |
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
# ایجاد اینترفیس Gradio
|
| 300 |
with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as demo:
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import PIL.Image
|
| 5 |
import torch
|
|
|
|
| 6 |
import torchvision.transforms as T
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
import gradio as gr
|
|
|
|
| 19 |
print("CelebAMask path exists:", os.path.exists(celebamask_path))
|
| 20 |
print("Face parsing path exists:", os.path.exists(face_parsing_path))
|
| 21 |
|
| 22 |
+
# ایمپورت ماژولهای مورد نیاز
|
| 23 |
+
try:
|
| 24 |
+
from unet import unet
|
| 25 |
+
from utils import generate_label
|
| 26 |
+
IMPORT_SUCCESS = True
|
| 27 |
+
print("✅ Successfully imported CelebAMask-HQ modules")
|
| 28 |
+
except ImportError as e:
|
| 29 |
+
IMPORT_SUCCESS = False
|
| 30 |
+
print(f"❌ Failed to import CelebAMask-HQ modules: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# تنظیمات دستگاه
|
| 33 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 67 |
print(f"✅ Model downloaded to: {model_path}")
|
| 68 |
|
| 69 |
# ایجاد مدل با معماری صحیح
|
| 70 |
+
self.model = unet(
|
| 71 |
+
feature_scale=4,
|
| 72 |
+
n_classes=19,
|
| 73 |
+
is_deconv=True,
|
| 74 |
+
in_channels=3,
|
| 75 |
+
is_batchnorm=True
|
| 76 |
+
)
|
| 77 |
|
| 78 |
# لود state dict
|
| 79 |
state_dict = torch.load(model_path, map_location="cpu")
|
|
|
|
| 168 |
print("[Info] PYTHONPATH:", os.environ.get("PYTHONPATH"))
|
| 169 |
print("[Info] CelebAMask-HQ path exists:", os.path.exists(celebamask_path))
|
| 170 |
print("[Info] face_parsing folder exists:", os.path.exists(face_parsing_path))
|
| 171 |
+
print("[Info] Module import success:", IMPORT_SUCCESS)
|
| 172 |
|
| 173 |
try:
|
| 174 |
face_parser = FaceParsingModel()
|
|
|
|
| 218 |
traceback.print_exc()
|
| 219 |
return None, None, error_msg
|
| 220 |
|
| 221 |
+
def create_legend():
|
| 222 |
+
"""ایجاد لیجند برای کلاسها"""
|
| 223 |
+
import matplotlib.pyplot as plt
|
| 224 |
+
|
| 225 |
+
legend_html = """
|
| 226 |
+
<div style='max-height: 300px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; border-radius: 5px;'>
|
| 227 |
+
<h4>🎨 Legend - کلاسهای Face Parsing:</h4>
|
| 228 |
+
"""
|
| 229 |
+
|
| 230 |
+
colors = plt.get_cmap('tab20', len(CELEBA_CLASSES))
|
| 231 |
+
|
| 232 |
+
for i, class_name in enumerate(CELEBA_CLASSES):
|
| 233 |
+
color = colors(i)
|
| 234 |
+
color_hex = '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))
|
| 235 |
+
text_color = 'white' if color[0] * 0.299 + color[1] * 0.587 + color[2] * 0.114 < 0.5 else 'black'
|
| 236 |
+
legend_html += f"""
|
| 237 |
+
<div style='margin: 2px; padding: 5px; background-color: {color_hex}; color: {text_color}; border-radius: 3px;'>
|
| 238 |
+
<strong>{i}:</strong> {class_name}
|
| 239 |
+
</div>
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
legend_html += "</div>"
|
| 243 |
+
return legend_html
|
| 244 |
|
| 245 |
# ایجاد اینترفیس Gradio
|
| 246 |
with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as demo:
|