|
|
import os |
|
|
import sys |
|
|
import numpy as np |
|
|
import PIL.Image |
|
|
import torch |
|
|
import torchvision.transforms as T |
|
|
from huggingface_hub import hf_hub_download |
|
|
import gradio as gr |
|
|
import time |
|
|
import cv2 |
|
|
|
|
|
|
|
|
# Location of the CelebAMask-HQ checkout and of its face_parsing package.
celebamask_path = "/home/user/app/CelebAMask-HQ"
face_parsing_path = os.path.join(celebamask_path, "face_parsing")

# Prepend both directories in a single step. face_parsing_path ends up first
# and celebamask_path second, i.e. the same order two successive
# ``sys.path.insert(0, ...)`` calls would leave behind.
sys.path[:0] = [face_parsing_path, celebamask_path]

# Startup diagnostics: show the effective import path and whether the
# expected directories are actually present.
print("Python path:", sys.path)
celebamask_found = os.path.exists(celebamask_path)
print("CelebAMask path exists:", celebamask_found)
face_parsing_found = os.path.exists(face_parsing_path)
print("Face parsing path exists:", face_parsing_found)
|
|
|
|
|
|
|
|
# The network definition and the label helper are imported from the
# directories added to sys.path earlier in this file. A failed import is
# recorded in IMPORT_SUCCESS instead of aborting, so the rest of the module
# can still be loaded and report the problem.
try:
    from unet import unet
    from utils import generate_label
except ImportError as import_error:
    IMPORT_SUCCESS = False
    print(f"❌ Failed to import CelebAMask-HQ modules: {import_error}")
else:
    IMPORT_SUCCESS = True
    print("✅ Successfully imported CelebAMask-HQ modules")
|
|
|
|
|
|
|
|
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Point Hugging Face caches at a directory inside the app folder.
# NOTE(review): this assignment runs after `huggingface_hub` has already been
# imported at the top of the file; if that library resolves HF_HOME at import
# time the variable only affects code imported later — confirm, and consider
# setting it before the imports.
os.environ["HF_HOME"] = "/home/user/app/hf_cache"
|
|
|
|
|
|
|
|
# Preprocessing applied to every image before it is fed to the network:
# resize to 512x512, convert to a tensor, then normalise each of the three
# channels with the mean/std values below (presumably the usual ImageNet
# statistics — verify against the checkpoint's training setup).
transform = T.Compose(
    [
        T.Resize(size=(512, 512)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
|
|
|
# Human-readable name for each segmentation label; the list position is used
# as the class index shown in the legend.
# NOTE(review): the order is assumed to match the label indices the checkpoint
# was trained with — confirm against the CelebAMask-HQ label definition.
CELEBA_CLASSES = [
    'background',  # 0
    'skin',        # 1
    'l_brow',      # 2
    'r_brow',      # 3
    'l_eye',       # 4
    'r_eye',       # 5
    'eye_g',       # 6
    'l_ear',       # 7
    'r_ear',       # 8
    'ear_r',       # 9
    'nose',        # 10
    'mouth',       # 11
    'u_lip',       # 12
    'l_lip',       # 13
    'neck',        # 14
    'neck_l',      # 15
    'cloth',       # 16
    'hair',        # 17
    'hat',         # 18
]
|
|
|
|
|
class FaceParsingModel:
    """Wrapper around the CelebAMask-HQ U-Net face-parsing network.

    Constructing an instance immediately tries to download and load the
    checkpoint. Loading never raises: on any failure ``self.model`` is left
    as ``None`` and ``predict`` will raise ``ValueError``.
    """

    # RGB colour painted for each label value; the tuple index is the label.
    PALETTE = (
        (0, 0, 0),
        (255, 200, 200),
        (0, 255, 0),
        (0, 200, 0),
        (255, 0, 0),
        (200, 0, 0),
        (255, 255, 0),
        (0, 0, 255),
        (0, 0, 200),
        (128, 0, 128),
        (255, 165, 0),
        (255, 0, 255),
        (200, 0, 200),
        (165, 42, 42),
        (0, 255, 255),
        (0, 200, 200),
        (128, 128, 128),
        (255, 255, 255),
        (255, 215, 0),
    )

    def __init__(self):
        self.model = None
        self.device = device
        self.load_model()

    def load_model(self):
        """Download the checkpoint and build the network.

        Every failure is printed (with a traceback where one exists) and
        leaves ``self.model`` set to ``None``; nothing is raised to the caller.
        """
        self.model = None

        # Without the network definition the checkpoint is useless, so skip
        # the download entirely when the CelebAMask-HQ import already failed.
        if not IMPORT_SUCCESS:
            print("❌ Failed to load model: CelebAMask-HQ modules could not be imported")
            return

        try:
            print("📥 Downloading model...")
            model_path = hf_hub_download(
                repo_id="public-data/CelebAMask-HQ-Face-Parsing",
                filename="models/model.pth",
                cache_dir="/home/user/app/hf_cache",
            )
            print(f"✅ Model downloaded to: {model_path}")

            network = unet(
                feature_scale=4,
                n_classes=19,
                is_deconv=True,
                in_channels=3,
                is_batchnorm=True,
            )

            # SECURITY NOTE: torch.load unpickles the downloaded file, which
            # can execute arbitrary code. Only keep this if the repository
            # above is trusted; consider `weights_only=True` once the
            # installed torch version and checkpoint format are confirmed.
            state_dict = torch.load(model_path, map_location="cpu")

            # Drop a leading "module." from parameter names (presumably left
            # by a DataParallel-wrapped model at save time) so the keys match
            # the bare network.
            prefix = "module."
            cleaned_state = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }

            network.load_state_dict(cleaned_state)
            network.eval()
            network.to(self.device)

            # Publish the network only after every step succeeded.
            self.model = network
            print("✅ Model loaded successfully")

        except Exception as e:
            print(f"❌ Failed to load model: {e}")
            import traceback
            traceback.print_exc()
            self.model = None

    def predict(self, image):
        """Segment one image and return the coloured mask and an overlay.

        Args:
            image: A file path (``str``), a ``numpy.ndarray`` or a PIL image.
                Any colour mode is accepted; the image is converted to RGB.

        Returns:
            A ``(colored_mask, blended)`` tuple of uint8 RGB arrays:
            the palette-coloured mask and the input (resized to 512x512)
            blended 70/30 with that mask.

        Raises:
            ValueError: If the model was not loaded.
        """
        if self.model is None:
            raise ValueError("Model not loaded properly")

        # Normalise every input kind to a 3-channel RGB PIL image. The
        # transform's Normalize step and the blend below both require exactly
        # three channels, so RGBA / grayscale / palette inputs must be
        # converted too, not just images opened from a path.
        if isinstance(image, str):
            with PIL.Image.open(image) as opened:
                image = opened.convert('RGB')
        else:
            if isinstance(image, np.ndarray):
                image = PIL.Image.fromarray(image)
            image = image.convert('RGB')

        data = transform(image)
        data = data.unsqueeze(0).to(self.device)

        with torch.no_grad():
            out = self.model(data)
            # NOTE(review): the code below assumes generate_label returns a
            # batch of 2-D per-pixel class-index maps. Confirm against the
            # CelebAMask-HQ `utils.generate_label` implementation, which may
            # return an already colourised (3, H, W) tensor instead.
            label_out = generate_label(out, 512)
            mask = label_out[0].cpu().numpy()

        colored_mask = self.colorize_mask(mask)

        # resize() returns a new image, so the input is never modified.
        resized_image = np.asarray(image.resize((512, 512)))
        blended = cv2.addWeighted(resized_image, 0.7, colored_mask, 0.3, 0)

        return colored_mask, blended

    def colorize_mask(self, mask):
        """Map a 2-D array of label values to an RGB image.

        Args:
            mask: Array whose first two dimensions are height and width and
                whose values are label indices.

        Returns:
            A ``(height, width, 3)`` uint8 array. Pixels whose value is not a
            valid palette index stay black.
        """
        colored = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        for class_id, rgb in enumerate(self.PALETTE):
            colored[mask == class_id] = rgb
        return colored
|
|
|
|
|
def initialize_app():
    """Print startup diagnostics and construct the face-parsing model.

    Returns:
        A ``(success, status_msg, face_parser)`` tuple. ``success`` is True
        only when the parser holds a loaded model; ``face_parser`` is ``None``
        when constructing it raised.
    """
    started_at = time.strftime("%Y-%m-%d %H:%M:%S")
    print(f"===== Application Startup at {started_at} =====")

    diagnostics = (
        ("[Info] PYTHONPATH:", os.environ.get("PYTHONPATH")),
        ("[Info] CelebAMask-HQ path exists:", os.path.exists(celebamask_path)),
        ("[Info] face_parsing folder exists:", os.path.exists(face_parsing_path)),
        ("[Info] Module import success:", IMPORT_SUCCESS),
    )
    for label, value in diagnostics:
        print(label, value)

    try:
        parser = FaceParsingModel()
    except Exception as exc:
        print(f"[Error] Initialization failed: {exc}")
        return False, f"Initialization failed: {exc}", None

    loaded = parser.model is not None
    if loaded:
        message = "Model loaded successfully"
    else:
        message = "Model failed to load"
    return loaded, message, parser
|
|
|
|
|
|
|
|
# Build the model once, at import time. `success` and `face_parser` are read
# by process_image() and by the status panel of the UI defined later in this
# file; `status_msg` carries the human-readable outcome.
success, status_msg, face_parser = initialize_app()
|
|
|
|
|
def process_image(input_image):
    """Run face parsing on one input and build the three UI outputs.

    Args:
        input_image: A file path (``str``), a PIL image, a ``numpy.ndarray``,
            or ``None`` when nothing was uploaded.

    Returns:
        A ``(blended, mask, message)`` tuple. On success the first two items
        are the overlay and the coloured mask returned by
        ``face_parser.predict``; otherwise both are ``None`` and ``message``
        explains why. Exceptions raised during processing are caught,
        printed with a traceback, and reported through ``message``.
    """
    if input_image is None:
        return None, None, "لطفاً یک تصویر آپلود کنید"

    if not success or face_parser is None:
        return None, None, "❌ مدل لود نشده است. لطفاً دوباره تلاش کنید."

    try:
        mask, blended = face_parser.predict(input_image)

        # Report the input size as (width, height).
        if isinstance(input_image, str):
            # Context manager so the file handle is released right away.
            with PIL.Image.open(input_image) as original_img:
                img_size = original_img.size
        elif isinstance(input_image, np.ndarray) or not hasattr(input_image, 'size'):
            # ndarray.size is the total element count, not the dimensions,
            # so arrays must be measured through .shape (rows, cols, ...).
            img_size = input_image.shape[:2][::-1]
        else:
            img_size = input_image.size

        info_text = (
            "\n"
            "✅ پردازش انجام شد!\n"
            f"- اندازه تصویر ورودی: {img_size}\n"
            "- اندازه خروجی: 512x512\n"
            f"- کلاسهای تشخیص: {len(CELEBA_CLASSES)}\n"
            f"- دستگاه پردازش: {device}\n"
        )

        return blended, mask, info_text

    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing
        # the callback, and keep the traceback in the server log.
        error_msg = f"❌ خطا در پردازش تصویر: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, None, error_msg
|
|
|
|
|
def create_legend():
    """Build the HTML legend listing every class index, name and colour.

    The swatch colours are obtained from ``FaceParsingModel.colorize_mask``
    so the legend always shows the colours that are actually painted into the
    segmentation mask.

    Returns:
        An HTML string: a scrollable container with one coloured row per
        entry of ``CELEBA_CLASSES``.
    """
    # Colour a one-row "mask" holding every class index once. A bare instance
    # (no __init__) is used so that no model download/load is triggered.
    class_ids = np.arange(len(CELEBA_CLASSES)).reshape(1, -1)
    colorizer = FaceParsingModel.__new__(FaceParsingModel)
    swatches = colorizer.colorize_mask(class_ids)[0].tolist()

    parts = [
        "\n"
        "<div style='max-height: 300px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; border-radius: 5px;'>\n"
        "<h4>🎨 Legend - کلاسهای Face Parsing:</h4>\n"
    ]

    for i, (class_name, (red, green, blue)) in enumerate(zip(CELEBA_CLASSES, swatches)):
        color_hex = '#%02x%02x%02x' % (red, green, blue)
        # Perceived brightness in [0, 1]; dark swatches get white text.
        luminance = (red * 0.299 + green * 0.587 + blue * 0.114) / 255
        text_color = 'white' if luminance < 0.5 else 'black'
        parts.append(
            "\n"
            f"<div style='margin: 2px; padding: 5px; background-color: {color_hex}; color: {text_color}; border-radius: 3px;'>\n"
            f"<strong>{i}:</strong> {class_name}\n"
            "</div>\n"
        )

    parts.append("</div>")
    return "".join(parts)
|
|
|
|
|
|
|
|
# Page header shown at the top of the demo (Markdown).
_INTRO_MARKDOWN = """
# 🎭 CelebAMask-HQ Face Parsing Demo
**آپلود یک تصویر صورت و دریافت خروجی Face Parsing**

این مدل صورت را به 19 بخش مختلف تقسیم میکند (پوست، چشم، ابرو، بینی، دهان، مو و ...)
"""

# Status panel text, filled in from the results of module initialisation.
_model_state = '✅ لود شده' if success else '❌ خطا در لود'
_import_state = '✅ ایمپورت شده' if IMPORT_SUCCESS else '❌ خطا در ایمپورت'
_STATUS_MARKDOWN = f"""
**وضعیت:**
- 🎯 مدل: {_model_state}
- 💻 دستگاه: `{device}`
- 📦 ماژولها: {_import_state}
- 🗂️ کلاسها: {len(CELEBA_CLASSES)}
"""

with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: upload widget, trigger button and status panel.
        with gr.Column():
            input_image = gr.Image(
                label="📷 تصویر ورودی",
                type="filepath",
                sources=["upload"],
                height=300,
            )
            process_btn = gr.Button("🚀 پردازش تصویر", variant="primary", size="lg")

            with gr.Accordion("ℹ️ وضعیت برنامه", open=False):
                status_display = gr.Markdown(_STATUS_MARKDOWN)

        # Right column: the overlay and the raw coloured mask.
        with gr.Column():
            output_blended = gr.Image(label="🎨 نتیجه ترکیبی (تصویر + ماسک)", height=300)
            output_mask = gr.Image(label="🎭 ماسک سگمنتیشن", height=300)

    with gr.Row():
        info_output = gr.Textbox(label="📊 اطلاعات پردازش", lines=3, max_lines=6)

    with gr.Row():
        gr.HTML(create_legend())

    # Both pressing the button and finishing an upload run the same handler
    # with the same input and outputs; the button is registered first.
    for _register_event in (process_btn.click, input_image.upload):
        _register_event(
            fn=process_image,
            inputs=[input_image],
            outputs=[output_blended, output_mask, info_output],
        )
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the demo on all interfaces, port 7860,
    # without creating a public share link.
    print("🚀 Starting Face Parsing Application...")
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)