|
|
import os |
|
|
import sys |
|
|
import numpy as np |
|
|
import PIL.Image |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchvision.transforms as T |
|
|
from huggingface_hub import hf_hub_download |
|
|
import gradio as gr |
|
|
import time |
|
|
import cv2 |
|
|
|
|
|
|
|
|
# --- CelebAMask-HQ repository wiring ---
# The repo is expected at this fixed path (Hugging Face Space layout);
# its face_parsing subfolder contains the model code (unet.py, utils.py).
celebamask_path = "/home/user/app/CelebAMask-HQ"

face_parsing_path = os.path.join(celebamask_path, "face_parsing")

# Make the repo's modules importable by bare name below.
sys.path.insert(0, celebamask_path)

sys.path.insert(0, face_parsing_path)

# Optional dependency: the CelebAMask-HQ U-Net and its label decoder.
# HAS_CELEBAMASK gates every use of these two names later in the file.
try:

    from unet import unet as celebamask_unet

    from utils import generate_label

    HAS_CELEBAMASK = True

except ImportError:

    HAS_CELEBAMASK = False

# Prefer GPU when available; all tensors/models are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): HF_HOME is set *after* huggingface_hub was imported, so it
# may not affect that library's already-resolved defaults; downloads below
# pass cache_dir explicitly, which is what actually takes effect — confirm.
os.environ["HF_HOME"] = "/home/user/app/hf_cache"
|
|
|
|
|
|
|
|
class BiSeNet(nn.Module):
    """Minimal segmentation head used with the BiSeNet checkpoint.

    Two stride-2 conv stages (output map is 1/4 the input resolution in
    each spatial dimension) followed by a 1x1 classifier producing
    ``n_classes`` logit channels.
    """

    def __init__(self, n_classes=19):
        super(BiSeNet, self).__init__()

        # Keep the attribute names (conv1, final) and the positional layer
        # order inside the Sequential: checkpoint state-dict keys
        # (conv1.0.weight, ...) depend on both.
        stem_layers = [
            nn.Conv2d(3, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        self.conv1 = nn.Sequential(*stem_layers)
        self.final = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        """Return per-class logits at 1/4 of the input spatial size."""
        return self.final(self.conv1(x))
|
|
|
|
|
|
|
|
# The 19 CelebAMask-HQ semantic classes, ordered by label index
# (index 0 = background, ..., index 18 = hat). Must stay aligned with the
# color palette used when colorizing masks.
CELEBA_CLASSES = [
    'background', 'skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
    'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'
]
|
|
|
|
|
class AdvancedFaceParsing: |
|
|
def __init__(self): |
|
|
self.model = None |
|
|
self.device = device |
|
|
self.model_type = "unknown" |
|
|
self.load_best_model() |
|
|
|
|
|
def load_best_model(self): |
|
|
"""سعی میکند بهترین مدل موجود را لود کند""" |
|
|
models_to_try = [ |
|
|
|
|
|
{ |
|
|
"name": "BiSeNet-Face-Parsing", |
|
|
"repo_id": "yangyuke001/bisenet-face-parsing", |
|
|
"filename": "model.pth", |
|
|
"constructor": self.create_bisenet |
|
|
}, |
|
|
{ |
|
|
"name": "CelebAMask-HQ-Improved", |
|
|
"repo_id": "public-data/CelebAMask-HQ-Face-Parsing", |
|
|
"filename": "models/model.pth", |
|
|
"constructor": self.create_celebamask_unet |
|
|
}, |
|
|
|
|
|
{ |
|
|
"name": "CelebAMask-HQ-Original", |
|
|
"repo_id": "public-data/CelebAMask-HQ-Face-Parsing", |
|
|
"filename": "model.pth", |
|
|
"constructor": self.create_celebamask_unet |
|
|
} |
|
|
] |
|
|
|
|
|
for model_info in models_to_try: |
|
|
try: |
|
|
print(f"🔄 Trying {model_info['name']}...") |
|
|
model_path = hf_hub_download( |
|
|
repo_id=model_info["repo_id"], |
|
|
filename=model_info["filename"], |
|
|
cache_dir="/home/user/app/hf_cache" |
|
|
) |
|
|
|
|
|
self.model = model_info["constructor"]() |
|
|
state_dict = torch.load(model_path, map_location="cpu") |
|
|
|
|
|
|
|
|
new_state_dict = {} |
|
|
for k, v in state_dict.items(): |
|
|
if k.startswith('module.'): |
|
|
k = k[7:] |
|
|
new_state_dict[k] = v |
|
|
|
|
|
self.model.load_state_dict(new_state_dict, strict=False) |
|
|
self.model.eval() |
|
|
self.model.to(self.device) |
|
|
self.model_type = model_info["name"] |
|
|
|
|
|
print(f"✅ Successfully loaded {model_info['name']}") |
|
|
return |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Failed to load {model_info['name']}: {e}") |
|
|
continue |
|
|
|
|
|
print("⚠️ Could not load any model, using simple fallback") |
|
|
self.model = self.create_simple_model() |
|
|
self.model_type = "Simple-Fallback" |
|
|
|
|
|
def create_bisenet(self): |
|
|
"""ایجاد مدل BiSeNet""" |
|
|
return BiSeNet(n_classes=19) |
|
|
|
|
|
def create_celebamask_unet(self): |
|
|
"""ایجاد مدل CelebAMask-HQ U-Net""" |
|
|
if HAS_CELEBAMASK: |
|
|
return celebamask_unet( |
|
|
feature_scale=4, |
|
|
n_classes=19, |
|
|
is_deconv=True, |
|
|
in_channels=3, |
|
|
is_batchnorm=True |
|
|
) |
|
|
else: |
|
|
return self.create_simple_model() |
|
|
|
|
|
def create_simple_model(self): |
|
|
"""مدل ساده fallback""" |
|
|
return nn.Sequential( |
|
|
nn.Conv2d(3, 64, 3, padding=1), |
|
|
nn.ReLU(), |
|
|
nn.Conv2d(64, 19, 1) |
|
|
) |
|
|
|
|
|
def predict(self, image): |
|
|
"""پردازش تصویر""" |
|
|
if self.model is None: |
|
|
raise ValueError("Model not loaded") |
|
|
|
|
|
|
|
|
if isinstance(image, str): |
|
|
image = PIL.Image.open(image).convert('RGB') |
|
|
elif isinstance(image, np.ndarray): |
|
|
image = PIL.Image.fromarray(image) |
|
|
|
|
|
original_image = image.copy() |
|
|
|
|
|
|
|
|
transform = T.Compose([ |
|
|
T.Resize((512, 512)), |
|
|
T.ToTensor(), |
|
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), |
|
|
]) |
|
|
|
|
|
data = transform(image).unsqueeze(0).to(self.device) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
out = self.model(data) |
|
|
|
|
|
|
|
|
if hasattr(self, 'generate_label') and HAS_CELEBAMASK: |
|
|
mask = generate_label(out, 512)[0].cpu().numpy() |
|
|
else: |
|
|
|
|
|
mask = torch.argmax(out, dim=1)[0].cpu().numpy() |
|
|
|
|
|
colored_mask = self.colorize_mask(mask) |
|
|
|
|
|
|
|
|
resized_image = np.asarray(original_image.resize((512, 512))) |
|
|
blended = cv2.addWeighted(resized_image, 0.7, colored_mask, 0.3, 0) |
|
|
|
|
|
return colored_mask, blended, self.model_type |
|
|
|
|
|
def colorize_mask(self, mask): |
|
|
"""رنگآمیزی ماسک""" |
|
|
palette = [ |
|
|
[0, 0, 0], [255, 200, 200], [0, 255, 0], [0, 200, 0], |
|
|
[255, 0, 0], [200, 0, 0], [255, 255, 0], [0, 0, 255], |
|
|
[0, 0, 200], [128, 0, 128], [255, 165, 0], [255, 0, 255], |
|
|
[200, 0, 200], [165, 42, 42], [0, 255, 255], [0, 200, 200], |
|
|
[128, 128, 128], [255, 255, 255], [255, 215, 0] |
|
|
] |
|
|
|
|
|
colored = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) |
|
|
for i in range(len(palette)): |
|
|
colored[mask == i] = palette[i] |
|
|
|
|
|
return colored |
|
|
|
|
|
|
|
|
# Instantiate the parser at import time so the UI callbacks can use it.
# NOTE: this triggers checkpoint download/loading as a module side effect.
face_parser = AdvancedFaceParsing()

print(f"🎯 Loaded model: {face_parser.model_type}")
|
|
|
|
|
def process_image(input_image):
    """Gradio callback: run face parsing and format the result.

    Args:
        input_image: filepath string from the gr.Image component, or None.

    Returns:
        (blended, mask, info_text): overlay image, colored segmentation
        mask, and a Persian status string. The image slots are None when
        input is missing or prediction fails.
    """
    if input_image is None:
        # Persian: "Please upload an image"
        return None, None, "لطفاً یک تصویر آپلود کنید"

    try:
        mask, blended, model_type = face_parser.predict(input_image)

        # Persian status block: "processing done with <model>", model name,
        # number of detected classes, and the compute device.
        info_text = f"""
✅ پردازش انجام شد با {model_type}!
- مدل: {model_type}
- کلاسهای تشخیص: {len(CELEBA_CLASSES)}
- دستگاه: {device}
"""

        return blended, mask, info_text

    except Exception as e:
        # Surface the error in the UI rather than crashing the callback.
        # Persian: "Error: ..."
        return None, None, f"❌ خطا: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _legend_html():
    """Return an HTML legend: one colored swatch per CelebAMask class.

    NOTE: this palette duplicates the one in
    AdvancedFaceParsing.colorize_mask — keep the two lists in sync.
    """
    palette = [
        [0, 0, 0], [255, 200, 200], [0, 255, 0], [0, 200, 0],
        [255, 0, 0], [200, 0, 0], [255, 255, 0], [0, 0, 255],
        [0, 0, 200], [128, 0, 128], [255, 165, 0], [255, 0, 255],
        [200, 0, 200], [165, 42, 42], [0, 255, 255], [0, 200, 200],
        [128, 128, 128], [255, 255, 255], [255, 215, 0]
    ]
    swatches = []
    for name, (r, g, b) in zip(CELEBA_CLASSES, palette):
        swatches.append(
            '<span style="display:inline-block;margin:2px 8px;white-space:nowrap;">'
            f'<span style="display:inline-block;width:12px;height:12px;'
            f'background:rgb({r},{g},{b});border:1px solid #888;"></span> '
            f'{name}</span>'
        )
    return '<div style="line-height:1.9;">' + ''.join(swatches) + '</div>'


# Build the Gradio UI.
# BUG FIX: the original referenced three undefined names — `success`,
# `IMPORT_SUCCESS`, and `create_legend()` — all of which raised NameError
# at import time. They are replaced with values that actually exist:
# the parser's load state, HAS_CELEBAMASK, and the local _legend_html().
with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎭 CelebAMask-HQ Face Parsing Demo
**آپلود یک تصویر صورت و دریافت خروجی Face Parsing**

این مدل صورت را به 19 بخش مختلف تقسیم میکند (پوست، چشم، ابرو، بینی، دهان، مو و ...)
""")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="📷 تصویر ورودی",
                type="filepath",
                sources=["upload"],
                height=300
            )
            process_btn = gr.Button("🚀 پردازش تصویر", variant="primary", size="lg")

            with gr.Accordion("ℹ️ وضعیت برنامه", open=False):
                status_display = gr.Markdown(f"""
**وضعیت:**
- 🎯 مدل: {'✅ لود شده' if face_parser.model is not None else '❌ خطا در لود'}
- 💻 دستگاه: `{device}`
- 📦 ماژولها: {'✅ ایمپورت شده' if HAS_CELEBAMASK else '❌ خطا در ایمپورت'}
- 🗂️ کلاسها: {len(CELEBA_CLASSES)}
""")

        with gr.Column():
            output_blended = gr.Image(
                label="🎨 نتیجه ترکیبی (تصویر + ماسک)",
                height=300
            )
            output_mask = gr.Image(
                label="🎭 ماسک سگمنتیشن",
                height=300
            )

    with gr.Row():
        info_output = gr.Textbox(
            label="📊 اطلاعات پردازش",
            lines=3,
            max_lines=6
        )

    with gr.Row():
        gr.HTML(_legend_html())

    # Run parsing both on explicit button click and automatically on upload.
    process_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[output_blended, output_mask, info_output]
    )

    input_image.upload(
        fn=process_image,
        inputs=[input_image],
        outputs=[output_blended, output_mask, info_output]
    )
|
|
|
|
|
if __name__ == "__main__":
    print("🚀 Starting Face Parsing Application...")
    # Bind to all interfaces on the standard Gradio/Spaces port; no
    # public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)