# test / app.py
# Last commit: "Update app.py" by danicor (eea2dea, verified)
import os
import sys
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
import gradio as gr
import time
import cv2
# Set up paths to the CelebAMask-HQ checkout and its face_parsing sub-directory.
celebamask_path = "/home/user/app/CelebAMask-HQ"
face_parsing_path = os.path.join(celebamask_path, "face_parsing")
# Prepend both to the import path; face_parsing ends up first, ahead of celebamask_path.
sys.path[:0] = [face_parsing_path, celebamask_path]

# Import the original CelebAMask-HQ modules (used as a fallback);
# HAS_CELEBAMASK records whether they were importable.
try:
    from unet import unet as celebamask_unet
    from utils import generate_label
except ImportError:
    HAS_CELEBAMASK = False
else:
    HAS_CELEBAMASK = True

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): huggingface_hub is imported before this line; it likely reads HF_HOME
# at import time, so this assignment probably has no effect on it. The explicit
# cache_dir passed to hf_hub_download is what selects the cache directory — confirm.
os.environ["HF_HOME"] = "/home/user/app/hf_cache"
# تعریف BiSeNet (مدل دقیق‌تر)
class BiSeNet(nn.Module):
def __init__(self, n_classes=19):
super(BiSeNet, self).__init__()
# پیاده‌سازی ساده‌شده BiSeNet
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, 3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU()
)
# ... (پیاده‌سازی کامل BiSeNet)
self.final = nn.Conv2d(64, n_classes, 1)
def forward(self, x):
x = self.conv1(x)
x = self.final(x)
return x
# Face-parsing class names (19 labels), presumably in the model's label-index order.
CELEBA_CLASSES = (
    "background skin l_brow r_brow l_eye r_eye eye_g l_ear r_ear ear_r "
    "nose mouth u_lip l_lip neck neck_l cloth hair hat"
).split()
class AdvancedFaceParsing:
    """Face-parsing wrapper: loads the first usable checkpoint and runs inference.

    On construction it tries each candidate listed in ``load_best_model`` in
    order. It keeps the first candidate whose weights cover every parameter and
    buffer of the network. If none qualifies, it falls back to a small,
    randomly initialised network, and ``model_type`` is then
    ``"Simple-Fallback"``.
    """

    # Side length (pixels) of the square image fed to the network and of the returned images.
    IMAGE_SIZE = 512

    def __init__(self):
        self.model = None
        self.device = device
        self.model_type = "unknown"
        self.load_best_model()

    def load_best_model(self):
        """Try to load the best available model, falling back through the candidates in order."""
        models_to_try = [
            # More accurate models first.
            {
                "name": "BiSeNet-Face-Parsing",
                "repo_id": "yangyuke001/bisenet-face-parsing",
                "filename": "model.pth",
                "constructor": self.create_bisenet
            },
            {
                "name": "CelebAMask-HQ-Improved",
                "repo_id": "public-data/CelebAMask-HQ-Face-Parsing",
                "filename": "models/model.pth",
                "constructor": self.create_celebamask_unet
            },
            # Fallback to the original model.
            {
                "name": "CelebAMask-HQ-Original",
                "repo_id": "public-data/CelebAMask-HQ-Face-Parsing",
                "filename": "model.pth",
                "constructor": self.create_celebamask_unet
            }
        ]
        for model_info in models_to_try:
            name = model_info["name"]
            try:
                print(f"🔄 Trying {name}...")
                model_path = hf_hub_download(
                    repo_id=model_info["repo_id"],
                    filename=model_info["filename"],
                    cache_dir="/home/user/app/hf_cache"
                )
                model = self._load_checkpoint(model_info["constructor"](), model_path)
            except Exception as e:
                # Best effort: any failure (download, unpickling, incompatible weights)
                # just moves on to the next candidate.
                print(f"❌ Failed to load {name}: {e}")
                continue
            self.model = model
            self.model_type = name
            print(f"✅ Successfully loaded {name}")
            return
        print("⚠️ Could not load any model, using simple fallback")
        # The fallback has no pretrained weights. It still must live on self.device and be
        # in eval mode; otherwise predict() feeds a CUDA tensor to a CPU model.
        self.model = self.create_simple_model().eval().to(self.device)
        self.model_type = "Simple-Fallback"

    def _load_checkpoint(self, model, model_path):
        """Load the state dict at ``model_path`` into ``model``; return it in eval mode on ``self.device``.

        Raises:
            RuntimeError: if any of the model's parameters/buffers is absent from the
                checkpoint. Without this check, ``strict=False`` would silently leave the
                network with random weights while the load is reported as successful.
        """
        # NOTE(review): torch.load unpickles the downloaded file. For checkpoints from
        # third-party Hub repos, consider weights_only=True where the installed torch supports it.
        state_dict = torch.load(model_path, map_location="cpu")
        result = model.load_state_dict(self._strip_module_prefix(state_dict), strict=False)
        if result.missing_keys:
            raise RuntimeError(
                f"checkpoint is missing {len(result.missing_keys)} of the model's keys "
                f"(e.g. {result.missing_keys[0]!r})"
            )
        model.eval()
        return model.to(self.device)

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with a leading ``module.`` removed from each key.

        That prefix is typically present when a model was saved while wrapped in
        ``nn.DataParallel``.
        """
        prefix = "module."
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

    def create_bisenet(self):
        """Create the BiSeNet model."""
        return BiSeNet(n_classes=19)

    def create_celebamask_unet(self):
        """Create the CelebAMask-HQ U-Net, or the simple model if its module is unavailable."""
        if HAS_CELEBAMASK:
            return celebamask_unet(
                feature_scale=4,
                n_classes=19,
                is_deconv=True,
                in_channels=3,
                is_batchnorm=True
            )
        else:
            return self.create_simple_model()

    def create_simple_model(self):
        """Create the small fallback network (3 -> 64 -> 19 channels, full resolution)."""
        return nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 19, 1)
        )

    def predict(self, image):
        """Segment a face image.

        Args:
            image: a file path, a NumPy array accepted by ``PIL.Image.fromarray``,
                or a ``PIL.Image.Image``. It is converted to RGB before processing.

        Returns:
            ``(colored_mask, blended, model_type)``. ``colored_mask`` and ``blended``
            are ``IMAGE_SIZE x IMAGE_SIZE x 3`` uint8 arrays, and ``model_type`` is
            the name of the loaded model.

        Raises:
            ValueError: if no model is loaded.
        """
        if self.model is None:
            raise ValueError("Model not loaded")

        size = self.IMAGE_SIZE
        # Normalise the input to an RGB PIL image.
        if isinstance(image, str):
            # Close the file handle as soon as the pixels are decoded.
            with PIL.Image.open(image) as img:
                image = img.convert('RGB')
        elif isinstance(image, np.ndarray):
            image = PIL.Image.fromarray(image)
        if image.mode != 'RGB':
            # Grayscale/RGBA inputs would otherwise break the 3-channel blend below.
            image = image.convert('RGB')

        transform = T.Compose([
            T.Resize((size, size)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        data = transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            out = self.model(data)

        mask = self._logits_to_mask(out, size)
        colored_mask = self.colorize_mask(mask)

        # Overlay the colour mask on the resized input.
        resized_image = np.asarray(image.resize((size, size)))
        blended = cv2.addWeighted(resized_image, 0.7, colored_mask, 0.3, 0)
        return colored_mask, blended, self.model_type

    @staticmethod
    def _logits_to_mask(out, size):
        """Convert raw network output into a ``size x size`` NumPy array of class indices.

        If the network returns a tuple/list, its first element is used. Logits whose
        spatial size differs from ``size`` (e.g. from a network that downsamples) are
        bilinearly resized first, so the mask lines up with the resized input image.
        """
        if isinstance(out, (tuple, list)):
            out = out[0]
        if tuple(out.shape[-2:]) != (size, size):
            out = nn.functional.interpolate(
                out, size=(size, size), mode="bilinear", align_corners=False
            )
        return torch.argmax(out, dim=1)[0].cpu().numpy()

    def colorize_mask(self, mask):
        """Map a 2-D array of integer class indices to an RGB uint8 image.

        Indices outside ``0..18`` are left black.
        """
        palette = np.array([
            [0, 0, 0], [255, 200, 200], [0, 255, 0], [0, 200, 0],
            [255, 0, 0], [200, 0, 0], [255, 255, 0], [0, 0, 255],
            [0, 0, 200], [128, 0, 128], [255, 165, 0], [255, 0, 255],
            [200, 0, 200], [165, 42, 42], [0, 255, 255], [0, 200, 200],
            [128, 128, 128], [255, 255, 255], [255, 215, 0]
        ], dtype=np.uint8)
        colored = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        valid = (mask >= 0) & (mask < len(palette))
        colored[valid] = palette[mask[valid]]
        return colored
# Use the advanced parser: one module-level instance, created (and its model loaded) at import time.
face_parser = AdvancedFaceParsing()
print("🎯 Loaded model: {}".format(face_parser.model_type))
def process_image(input_image):
    """Gradio callback: segment ``input_image`` and build the three UI outputs.

    Args:
        input_image: the uploaded image as passed by the input component, or
            ``None`` when nothing was uploaded.

    Returns:
        ``(blended, mask, info_text)``. On missing input or any error, both images
        are ``None`` and ``info_text`` carries a user-facing message.
    """
    if input_image is None:
        return None, None, "لطفاً یک تصویر آپلود کنید"
    try:
        mask, blended, model_type = face_parser.predict(input_image)
        info_text = "\n".join([
            "",
            f"✅ پردازش انجام شد با {model_type}!",
            f"- مدل: {model_type}",
            f"- کلاس‌های تشخیص: {len(CELEBA_CLASSES)}",
            f"- دستگاه: {device}",
            "",
        ])
        return blended, mask, info_text
    except Exception as e:
        # UI boundary: show the error to the user, but keep the full traceback in the
        # server log. Previously the traceback was discarded.
        import traceback
        traceback.print_exc()
        return None, None, f"❌ خطا: {str(e)}"
# Gradio interface.
# The original referenced `success`, `IMPORT_SUCCESS` and `create_legend` without defining
# them, so importing the module raised NameError; they are derived/defined here.


def create_legend():
    """Return an HTML legend pairing each class name with its mask colour.

    The colours come from ``face_parser.colorize_mask`` applied to one pixel per
    class index, so the legend always matches the palette used for the output mask.
    """
    class_ids = np.arange(len(CELEBA_CLASSES)).reshape(1, -1)
    colors = face_parser.colorize_mask(class_ids)[0]
    items = "".join(
        '<span style="display:inline-flex;align-items:center;margin:2px 8px;">'
        '<span style="display:inline-block;width:14px;height:14px;margin-right:4px;'
        f'border:1px solid #888;background:rgb({r},{g},{b});"></span>{name}</span>'
        for name, (r, g, b) in zip(CELEBA_CLASSES, colors.tolist())
    )
    return f'<div style="display:flex;flex-wrap:wrap;">{items}</div>'


# A real model counts as loaded unless the parser fell back to the untrained network.
model_ready = face_parser.model_type != "Simple-Fallback"

with gr.Blocks(title="CelebAMask-HQ Face Parsing", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎭 CelebAMask-HQ Face Parsing Demo
**آپلود یک تصویر صورت و دریافت خروجی Face Parsing**
این مدل صورت را به 19 بخش مختلف تقسیم می‌کند (پوست، چشم، ابرو، بینی، دهان، مو و ...)
""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="📷 تصویر ورودی",
                type="filepath",
                sources=["upload"],
                height=300
            )
            process_btn = gr.Button("🚀 پردازش تصویر", variant="primary", size="lg")
            with gr.Accordion("ℹ️ وضعیت برنامه", open=False):
                status_display = gr.Markdown(f"""
**وضعیت:**
- 🎯 مدل: {'✅ لود شده' if model_ready else '❌ خطا در لود'}
- 💻 دستگاه: `{device}`
- 📦 ماژول‌ها: {'✅ ایمپورت شده' if HAS_CELEBAMASK else '❌ خطا در ایمپورت'}
- 🗂️ کلاس‌ها: {len(CELEBA_CLASSES)}
""")
        with gr.Column():
            output_blended = gr.Image(
                label="🎨 نتیجه ترکیبی (تصویر + ماسک)",
                height=300
            )
            output_mask = gr.Image(
                label="🎭 ماسک سگمنتیشن",
                height=300
            )
    with gr.Row():
        info_output = gr.Textbox(
            label="📊 اطلاعات پردازش",
            lines=3,
            max_lines=6
        )
    with gr.Row():
        gr.HTML(create_legend())

    # Wire events: process both on button click and immediately on upload.
    process_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[output_blended, output_mask, info_output]
    )
    input_image.upload(
        fn=process_image,
        inputs=[input_image],
        outputs=[output_blended, output_mask, info_output]
    )
if __name__ == "__main__":
    print("🚀 Starting Face Parsing Application...")
    # Listen on all interfaces on port 7860, without creating a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)