import gradio as gr from overlay import overlay_source from detect_face import predict, NUM_CLASSES from swapface import swap_face_now from baldhead import inference from segmentation import get_facemesh_mask import os from pathlib import Path from PIL import Image import numpy as np import cv2 BASE_DIR = Path(__file__).parent # thư mục chứa app.py FOLDER = BASE_DIR / "example_wigs" # --- Hàm load ảnh từ folder --- def load_images_from_folder(folder_path: str) -> list[str]: """ Trả về list[str] chứa tất cả các hình (jpg, png, gif, bmp) trong folder_path. """ supported = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'} if not os.path.isdir(folder_path): print(f"Cảnh báo: '{folder_path}' không phải folder hợp lệ.") return [] files = [ os.path.join(folder_path, fn) for fn in os.listdir(folder_path) if os.path.splitext(fn)[1].lower() in supported ] if not files: print(f"Không tìm thấy hình trong: {folder_path}") return files def on_gallery_select(evt: gr.SelectData): """ Khi click thumbnail: trả về 1) filepath để nạp vào Image Source 2) tên file (basename) để hiển thị trong Textbox """ val = evt.value # --- logic trích filepath y như cũ --- if isinstance(val, dict): img = val.get("image") if isinstance(img, str): filepath = img elif isinstance(img, dict): filepath = img.get("path") or img.get("url") else: filepath = next( (v for v in val.values() if isinstance(v, str) and os.path.isfile(v)), None ) elif isinstance(val, str): filepath = val else: raise ValueError(f"Kiểu không hỗ trợ: {type(val)}") # Lấy tên file không có phần mở rộng filename = os.path.splitext(os.path.basename(filepath))[0] if filepath else "" return filepath, filename # --- Hàm xác định folder dựa trên phân lớp --- def infer_folder(image) -> str: cls = predict(image)["predicted_class"] folder = str(FOLDER / cls) return folder def get_face_no_forehead(image): """ Get face mask without forehead extension (only MediaPipe face mesh) """ image = image.convert("RGB") # Face mesh mask (chỉ lấy mặt MediaPipe, không mở rộng trán) face_mesh_mask = get_facemesh_mask(image) # Làm mượt mask face_mesh_mask = cv2.GaussianBlur(face_mesh_mask.astype(np.float32), (3, 3), 0) face_mesh_mask = (face_mesh_mask > 0.5).astype(np.uint8) np_image = np.array(image) alpha = (face_mesh_mask * 255).astype(np.uint8) rgba_image = np.dstack([np_image, alpha]) return Image.fromarray(rgba_image) # --- Hàm gộp: phân loại + load ảnh --- def handle_bg_change(image): """ Khi thay đổi background: 1. Phân loại khuôn mặt 2. Load ảnh từ folder tương ứng """ if image is None: return "", [] try: folder = infer_folder(image) images = load_images_from_folder(folder) return folder, images except Exception as e: print(f"Lỗi xử lý ảnh: {e}") return "", [] # --- Pipeline đầy đủ: workflow 5 bước mới --- def complete_pipeline(background: Image.Image, source: Image.Image): if background is None or source is None: return None try: # Bước 2: Apply baldhead to background (chuyển từ bước 4) bg_bald = inference(background) print("✅ Applied baldhead to background") # Bước 1,3,4: Overlay workflow với background baldhead overlay_result = overlay_source(bg_bald, source) if overlay_result is None: print("❌ Overlay workflow failed.") return None print("✅ Overlay workflow completed") # Bước 5: Face Swapping - background baldhead và overlay result final_result = swap_face_now( background.convert("RGB"), # background baldhead overlay_result.convert("RGB") # overlay result ) if final_result is None: print("❌ Face swap failed, using overlay result as fallback") result_img = overlay_result else: print("✅ Face swap success") result_img = final_result # Clear cache after getting result import gc import torch # Clear variables overlay_result = None bg_bald = None final_result = None # Force garbage collection gc.collect() # Clear GPU cache if available if torch.cuda.is_available(): torch.cuda.empty_cache() print("✅ GPU cache cleared") print("✅ Memory cache cleared") return result_img except Exception as e: print(f"❌ Lỗi trong complete_pipeline: {e}") # Clear cache on error too import gc import torch gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return None # --- Xây dựng giao diện Gradio --- def build_demo(): with gr.Blocks(title="Xử lý hai hình ảnh", theme=gr.themes.Soft()) as demo: gr.Markdown("Upload Background & Source, click **Run** to try on wigs.") with gr.Row(): bg = gr.Image(type="pil", label="Background", height=500) src = gr.Image(type="pil", label="Source", height=500, interactive=False) out = gr.Image(label="Result", height=500, interactive=False) folder_path_box = gr.Textbox(label="Folder path", visible=False) with gr.Row(): src_name_box = gr.Textbox( label="Wigs Name", interactive=False, show_copy_button=True , # tuỳ chọn – tiện copy đường dẫn scale = 1 ) gallery = gr.Gallery( label="Recommend For You", height=300, value=[], type="filepath", interactive=False, columns=5, object_fit="cover", allow_preview=True, scale = 8 ) btn = gr.Button("🔄 Run", variant="primary",scale = 1) # Chạy pipeline đầy đủ: ghép tóc + swap face cuối btn.click(fn=complete_pipeline, inputs=[bg, src], outputs=[out]) # Khi đổi ảnh background, tự động phân loại và load ảnh gợi ý bg.change( fn=handle_bg_change, inputs=[bg], outputs=[folder_path_box, gallery], show_progress=True ) # Nút tải lại ảnh thủ công (backup) # Khi chọn ảnh trong gallery, cập nhật vào khung Source gallery.select( fn=on_gallery_select, outputs=[src, src_name_box] ) return demo if __name__ == "__main__": build_demo().launch()