be_rejection / app.py
VanNguyen1214's picture
Update app.py
4ed5f6a verified
import gradio as gr
from overlay import overlay_source
from detect_face import predict, NUM_CLASSES
from swapface import swap_face_now
from baldhead import inference
from segmentation import get_facemesh_mask
import os
from pathlib import Path
from PIL import Image
import numpy as np
import cv2
BASE_DIR = Path(__file__).parent # thư mục chứa app.py
FOLDER = BASE_DIR / "example_wigs"
# --- Hàm load ảnh từ folder ---
def load_images_from_folder(folder_path: str) -> list[str]:
"""
Trả về list[str] chứa tất cả các hình (jpg, png, gif, bmp) trong folder_path.
"""
supported = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
if not os.path.isdir(folder_path):
print(f"Cảnh báo: '{folder_path}' không phải folder hợp lệ.")
return []
files = [
os.path.join(folder_path, fn)
for fn in os.listdir(folder_path)
if os.path.splitext(fn)[1].lower() in supported
]
if not files:
print(f"Không tìm thấy hình trong: {folder_path}")
return files
def on_gallery_select(evt: gr.SelectData):
"""
Khi click thumbnail: trả về
1) filepath để nạp vào Image Source
2) tên file (basename) để hiển thị trong Textbox
"""
val = evt.value
# --- logic trích filepath y như cũ ---
if isinstance(val, dict):
img = val.get("image")
if isinstance(img, str):
filepath = img
elif isinstance(img, dict):
filepath = img.get("path") or img.get("url")
else:
filepath = next(
(v for v in val.values() if isinstance(v, str) and os.path.isfile(v)),
None
)
elif isinstance(val, str):
filepath = val
else:
raise ValueError(f"Kiểu không hỗ trợ: {type(val)}")
# Lấy tên file không có phần mở rộng
filename = os.path.splitext(os.path.basename(filepath))[0] if filepath else ""
return filepath, filename
# --- Hàm xác định folder dựa trên phân lớp ---
def infer_folder(image) -> str:
cls = predict(image)["predicted_class"]
folder = str(FOLDER / cls)
return folder
def get_face_no_forehead(image):
"""
Get face mask without forehead extension (only MediaPipe face mesh)
"""
image = image.convert("RGB")
# Face mesh mask (chỉ lấy mặt MediaPipe, không mở rộng trán)
face_mesh_mask = get_facemesh_mask(image)
# Làm mượt mask
face_mesh_mask = cv2.GaussianBlur(face_mesh_mask.astype(np.float32), (3, 3), 0)
face_mesh_mask = (face_mesh_mask > 0.5).astype(np.uint8)
np_image = np.array(image)
alpha = (face_mesh_mask * 255).astype(np.uint8)
rgba_image = np.dstack([np_image, alpha])
return Image.fromarray(rgba_image)
# --- Hàm gộp: phân loại + load ảnh ---
def handle_bg_change(image):
"""
Khi thay đổi background:
1. Phân loại khuôn mặt
2. Load ảnh từ folder tương ứng
"""
if image is None:
return "", []
try:
folder = infer_folder(image)
images = load_images_from_folder(folder)
return folder, images
except Exception as e:
print(f"Lỗi xử lý ảnh: {e}")
return "", []
# --- Pipeline đầy đủ: workflow 5 bước mới ---
def complete_pipeline(background: Image.Image, source: Image.Image):
if background is None or source is None:
return None
try:
# Bước 2: Apply baldhead to background (chuyển từ bước 4)
bg_bald = inference(background)
print("✅ Applied baldhead to background")
# Bước 1,3,4: Overlay workflow với background baldhead
overlay_result = overlay_source(bg_bald, source)
if overlay_result is None:
print("❌ Overlay workflow failed.")
return None
print("✅ Overlay workflow completed")
# Bước 5: Face Swapping - background baldhead và overlay result
final_result = swap_face_now(
background.convert("RGB"), # background baldhead
overlay_result.convert("RGB") # overlay result
)
if final_result is None:
print("❌ Face swap failed, using overlay result as fallback")
result_img = overlay_result
else:
print("✅ Face swap success")
result_img = final_result
# Clear cache after getting result
import gc
import torch
# Clear variables
overlay_result = None
bg_bald = None
final_result = None
# Force garbage collection
gc.collect()
# Clear GPU cache if available
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("✅ GPU cache cleared")
print("✅ Memory cache cleared")
return result_img
except Exception as e:
print(f"❌ Lỗi trong complete_pipeline: {e}")
# Clear cache on error too
import gc
import torch
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return None
# --- Xây dựng giao diện Gradio ---
def build_demo():
with gr.Blocks(title="Xử lý hai hình ảnh", theme=gr.themes.Soft()) as demo:
gr.Markdown("Upload Background & Source, click **Run** to try on wigs.")
with gr.Row():
bg = gr.Image(type="pil", label="Background", height=500)
src = gr.Image(type="pil", label="Source", height=500, interactive=False)
out = gr.Image(label="Result", height=500, interactive=False)
folder_path_box = gr.Textbox(label="Folder path", visible=False)
with gr.Row():
src_name_box = gr.Textbox(
label="Wigs Name",
interactive=False,
show_copy_button=True , # tuỳ chọn – tiện copy đường dẫn
scale = 1
)
gallery = gr.Gallery(
label="Recommend For You",
height=300,
value=[],
type="filepath",
interactive=False,
columns=5,
object_fit="cover",
allow_preview=True,
scale = 8
)
btn = gr.Button("🔄 Run", variant="primary",scale = 1)
# Chạy pipeline đầy đủ: ghép tóc + swap face cuối
btn.click(fn=complete_pipeline, inputs=[bg, src], outputs=[out])
# Khi đổi ảnh background, tự động phân loại và load ảnh gợi ý
bg.change(
fn=handle_bg_change,
inputs=[bg],
outputs=[folder_path_box, gallery],
show_progress=True
)
# Nút tải lại ảnh thủ công (backup)
# Khi chọn ảnh trong gallery, cập nhật vào khung Source
gallery.select(
fn=on_gallery_select,
outputs=[src, src_name_box]
)
return demo
if __name__ == "__main__":
build_demo().launch()