Spaces:
Runtime error
Runtime error
File size: 7,360 Bytes
95c3b77 4ed5f6a 7ccd451 95c3b77 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
import gradio as gr
from overlay import overlay_source
from detect_face import predict, NUM_CLASSES
from swapface import swap_face_now
from baldhead import inference
from segmentation import get_facemesh_mask
import os
from pathlib import Path
from PIL import Image
import numpy as np
import cv2
BASE_DIR = Path(__file__).parent # thư mục chứa app.py
FOLDER = BASE_DIR / "example_wigs"
# --- Hàm load ảnh từ folder ---
def load_images_from_folder(folder_path: str) -> list[str]:
"""
Trả về list[str] chứa tất cả các hình (jpg, png, gif, bmp) trong folder_path.
"""
supported = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
if not os.path.isdir(folder_path):
print(f"Cảnh báo: '{folder_path}' không phải folder hợp lệ.")
return []
files = [
os.path.join(folder_path, fn)
for fn in os.listdir(folder_path)
if os.path.splitext(fn)[1].lower() in supported
]
if not files:
print(f"Không tìm thấy hình trong: {folder_path}")
return files
def on_gallery_select(evt: gr.SelectData):
"""
Khi click thumbnail: trả về
1) filepath để nạp vào Image Source
2) tên file (basename) để hiển thị trong Textbox
"""
val = evt.value
# --- logic trích filepath y như cũ ---
if isinstance(val, dict):
img = val.get("image")
if isinstance(img, str):
filepath = img
elif isinstance(img, dict):
filepath = img.get("path") or img.get("url")
else:
filepath = next(
(v for v in val.values() if isinstance(v, str) and os.path.isfile(v)),
None
)
elif isinstance(val, str):
filepath = val
else:
raise ValueError(f"Kiểu không hỗ trợ: {type(val)}")
# Lấy tên file không có phần mở rộng
filename = os.path.splitext(os.path.basename(filepath))[0] if filepath else ""
return filepath, filename
# --- Hàm xác định folder dựa trên phân lớp ---
def infer_folder(image) -> str:
cls = predict(image)["predicted_class"]
folder = str(FOLDER / cls)
return folder
def get_face_no_forehead(image):
"""
Get face mask without forehead extension (only MediaPipe face mesh)
"""
image = image.convert("RGB")
# Face mesh mask (chỉ lấy mặt MediaPipe, không mở rộng trán)
face_mesh_mask = get_facemesh_mask(image)
# Làm mượt mask
face_mesh_mask = cv2.GaussianBlur(face_mesh_mask.astype(np.float32), (3, 3), 0)
face_mesh_mask = (face_mesh_mask > 0.5).astype(np.uint8)
np_image = np.array(image)
alpha = (face_mesh_mask * 255).astype(np.uint8)
rgba_image = np.dstack([np_image, alpha])
return Image.fromarray(rgba_image)
# --- Hàm gộp: phân loại + load ảnh ---
def handle_bg_change(image):
"""
Khi thay đổi background:
1. Phân loại khuôn mặt
2. Load ảnh từ folder tương ứng
"""
if image is None:
return "", []
try:
folder = infer_folder(image)
images = load_images_from_folder(folder)
return folder, images
except Exception as e:
print(f"Lỗi xử lý ảnh: {e}")
return "", []
# --- Pipeline đầy đủ: workflow 5 bước mới ---
def complete_pipeline(background: Image.Image, source: Image.Image):
if background is None or source is None:
return None
try:
# Bước 2: Apply baldhead to background (chuyển từ bước 4)
bg_bald = inference(background)
print("✅ Applied baldhead to background")
# Bước 1,3,4: Overlay workflow với background baldhead
overlay_result = overlay_source(bg_bald, source)
if overlay_result is None:
print("❌ Overlay workflow failed.")
return None
print("✅ Overlay workflow completed")
# Bước 5: Face Swapping - background baldhead và overlay result
final_result = swap_face_now(
background.convert("RGB"), # background baldhead
overlay_result.convert("RGB") # overlay result
)
if final_result is None:
print("❌ Face swap failed, using overlay result as fallback")
result_img = overlay_result
else:
print("✅ Face swap success")
result_img = final_result
# Clear cache after getting result
import gc
import torch
# Clear variables
overlay_result = None
bg_bald = None
final_result = None
# Force garbage collection
gc.collect()
# Clear GPU cache if available
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("✅ GPU cache cleared")
print("✅ Memory cache cleared")
return result_img
except Exception as e:
print(f"❌ Lỗi trong complete_pipeline: {e}")
# Clear cache on error too
import gc
import torch
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return None
# --- Xây dựng giao diện Gradio ---
def build_demo():
with gr.Blocks(title="Xử lý hai hình ảnh", theme=gr.themes.Soft()) as demo:
gr.Markdown("Upload Background & Source, click **Run** to try on wigs.")
with gr.Row():
bg = gr.Image(type="pil", label="Background", height=500)
src = gr.Image(type="pil", label="Source", height=500, interactive=False)
out = gr.Image(label="Result", height=500, interactive=False)
folder_path_box = gr.Textbox(label="Folder path", visible=False)
with gr.Row():
src_name_box = gr.Textbox(
label="Wigs Name",
interactive=False,
show_copy_button=True , # tuỳ chọn – tiện copy đường dẫn
scale = 1
)
gallery = gr.Gallery(
label="Recommend For You",
height=300,
value=[],
type="filepath",
interactive=False,
columns=5,
object_fit="cover",
allow_preview=True,
scale = 8
)
btn = gr.Button("🔄 Run", variant="primary",scale = 1)
# Chạy pipeline đầy đủ: ghép tóc + swap face cuối
btn.click(fn=complete_pipeline, inputs=[bg, src], outputs=[out])
# Khi đổi ảnh background, tự động phân loại và load ảnh gợi ý
bg.change(
fn=handle_bg_change,
inputs=[bg],
outputs=[folder_path_box, gallery],
show_progress=True
)
# Nút tải lại ảnh thủ công (backup)
# Khi chọn ảnh trong gallery, cập nhật vào khung Source
gallery.select(
fn=on_gallery_select,
outputs=[src, src_name_box]
)
return demo
if __name__ == "__main__":
build_demo().launch()
|