Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from overlay import overlay_source | |
| from detect_face import predict, NUM_CLASSES | |
| from swapface import swap_face_now | |
| from baldhead import inference | |
| from segmentation import get_facemesh_mask | |
| import os | |
| from pathlib import Path | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| BASE_DIR = Path(__file__).parent # thư mục chứa app.py | |
| FOLDER = BASE_DIR / "example_wigs" | |
| # --- Hàm load ảnh từ folder --- | |
| def load_images_from_folder(folder_path: str) -> list[str]: | |
| """ | |
| Trả về list[str] chứa tất cả các hình (jpg, png, gif, bmp) trong folder_path. | |
| """ | |
| supported = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'} | |
| if not os.path.isdir(folder_path): | |
| print(f"Cảnh báo: '{folder_path}' không phải folder hợp lệ.") | |
| return [] | |
| files = [ | |
| os.path.join(folder_path, fn) | |
| for fn in os.listdir(folder_path) | |
| if os.path.splitext(fn)[1].lower() in supported | |
| ] | |
| if not files: | |
| print(f"Không tìm thấy hình trong: {folder_path}") | |
| return files | |
| def on_gallery_select(evt: gr.SelectData): | |
| """ | |
| Khi click thumbnail: trả về | |
| 1) filepath để nạp vào Image Source | |
| 2) tên file (basename) để hiển thị trong Textbox | |
| """ | |
| val = evt.value | |
| # --- logic trích filepath y như cũ --- | |
| if isinstance(val, dict): | |
| img = val.get("image") | |
| if isinstance(img, str): | |
| filepath = img | |
| elif isinstance(img, dict): | |
| filepath = img.get("path") or img.get("url") | |
| else: | |
| filepath = next( | |
| (v for v in val.values() if isinstance(v, str) and os.path.isfile(v)), | |
| None | |
| ) | |
| elif isinstance(val, str): | |
| filepath = val | |
| else: | |
| raise ValueError(f"Kiểu không hỗ trợ: {type(val)}") | |
| # Lấy tên file không có phần mở rộng | |
| filename = os.path.splitext(os.path.basename(filepath))[0] if filepath else "" | |
| return filepath, filename | |
| # --- Hàm xác định folder dựa trên phân lớp --- | |
| def infer_folder(image) -> str: | |
| cls = predict(image)["predicted_class"] | |
| folder = str(FOLDER / cls) | |
| return folder | |
| def get_face_no_forehead(image): | |
| """ | |
| Get face mask without forehead extension (only MediaPipe face mesh) | |
| """ | |
| image = image.convert("RGB") | |
| # Face mesh mask (chỉ lấy mặt MediaPipe, không mở rộng trán) | |
| face_mesh_mask = get_facemesh_mask(image) | |
| # Làm mượt mask | |
| face_mesh_mask = cv2.GaussianBlur(face_mesh_mask.astype(np.float32), (3, 3), 0) | |
| face_mesh_mask = (face_mesh_mask > 0.5).astype(np.uint8) | |
| np_image = np.array(image) | |
| alpha = (face_mesh_mask * 255).astype(np.uint8) | |
| rgba_image = np.dstack([np_image, alpha]) | |
| return Image.fromarray(rgba_image) | |
| # --- Hàm gộp: phân loại + load ảnh --- | |
| def handle_bg_change(image): | |
| """ | |
| Khi thay đổi background: | |
| 1. Phân loại khuôn mặt | |
| 2. Load ảnh từ folder tương ứng | |
| """ | |
| if image is None: | |
| return "", [] | |
| try: | |
| folder = infer_folder(image) | |
| images = load_images_from_folder(folder) | |
| return folder, images | |
| except Exception as e: | |
| print(f"Lỗi xử lý ảnh: {e}") | |
| return "", [] | |
| # --- Pipeline đầy đủ: workflow 5 bước mới --- | |
| def complete_pipeline(background: Image.Image, source: Image.Image): | |
| if background is None or source is None: | |
| return None | |
| try: | |
| # Bước 2: Apply baldhead to background (chuyển từ bước 4) | |
| bg_bald = inference(background) | |
| print("✅ Applied baldhead to background") | |
| # Bước 1,3,4: Overlay workflow với background baldhead | |
| overlay_result = overlay_source(bg_bald, source) | |
| if overlay_result is None: | |
| print("❌ Overlay workflow failed.") | |
| return None | |
| print("✅ Overlay workflow completed") | |
| # Bước 5: Face Swapping - background baldhead và overlay result | |
| final_result = swap_face_now( | |
| background.convert("RGB"), # background baldhead | |
| overlay_result.convert("RGB") # overlay result | |
| ) | |
| if final_result is None: | |
| print("❌ Face swap failed, using overlay result as fallback") | |
| result_img = overlay_result | |
| else: | |
| print("✅ Face swap success") | |
| result_img = final_result | |
| # Clear cache after getting result | |
| import gc | |
| import torch | |
| # Clear variables | |
| overlay_result = None | |
| bg_bald = None | |
| final_result = None | |
| # Force garbage collection | |
| gc.collect() | |
| # Clear GPU cache if available | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| print("✅ GPU cache cleared") | |
| print("✅ Memory cache cleared") | |
| return result_img | |
| except Exception as e: | |
| print(f"❌ Lỗi trong complete_pipeline: {e}") | |
| # Clear cache on error too | |
| import gc | |
| import torch | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return None | |
| # --- Xây dựng giao diện Gradio --- | |
| def build_demo(): | |
| with gr.Blocks(title="Xử lý hai hình ảnh", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("Upload Background & Source, click **Run** to try on wigs.") | |
| with gr.Row(): | |
| bg = gr.Image(type="pil", label="Background", height=500) | |
| src = gr.Image(type="pil", label="Source", height=500, interactive=False) | |
| out = gr.Image(label="Result", height=500, interactive=False) | |
| folder_path_box = gr.Textbox(label="Folder path", visible=False) | |
| with gr.Row(): | |
| src_name_box = gr.Textbox( | |
| label="Wigs Name", | |
| interactive=False, | |
| show_copy_button=True , # tuỳ chọn – tiện copy đường dẫn | |
| scale = 1 | |
| ) | |
| gallery = gr.Gallery( | |
| label="Recommend For You", | |
| height=300, | |
| value=[], | |
| type="filepath", | |
| interactive=False, | |
| columns=5, | |
| object_fit="cover", | |
| allow_preview=True, | |
| scale = 8 | |
| ) | |
| btn = gr.Button("🔄 Run", variant="primary",scale = 1) | |
| # Chạy pipeline đầy đủ: ghép tóc + swap face cuối | |
| btn.click(fn=complete_pipeline, inputs=[bg, src], outputs=[out]) | |
| # Khi đổi ảnh background, tự động phân loại và load ảnh gợi ý | |
| bg.change( | |
| fn=handle_bg_change, | |
| inputs=[bg], | |
| outputs=[folder_path_box, gallery], | |
| show_progress=True | |
| ) | |
| # Nút tải lại ảnh thủ công (backup) | |
| # Khi chọn ảnh trong gallery, cập nhật vào khung Source | |
| gallery.select( | |
| fn=on_gallery_select, | |
| outputs=[src, src_name_box] | |
| ) | |
| return demo | |
| if __name__ == "__main__": | |
| build_demo().launch() | |