File size: 7,360 Bytes
95c3b77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ed5f6a
 
 
7ccd451
95c3b77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227

import gradio as gr
from overlay import overlay_source
from detect_face import predict, NUM_CLASSES
from swapface import swap_face_now
from baldhead import inference
from segmentation import get_facemesh_mask

import os
from pathlib import Path
from PIL import Image
import numpy as np
import cv2


BASE_DIR = Path(__file__).parent  # thư mục chứa app.py
FOLDER = BASE_DIR / "example_wigs"

# --- Hàm load ảnh từ folder ---
def load_images_from_folder(folder_path: str) -> list[str]:
    """
    Trả về list[str] chứa tất cả các hình (jpg, png, gif, bmp) trong folder_path.
    """
    supported = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
    if not os.path.isdir(folder_path):
        print(f"Cảnh báo: '{folder_path}' không phải folder hợp lệ.")
        return []
    files = [
        os.path.join(folder_path, fn)
        for fn in os.listdir(folder_path)
        if os.path.splitext(fn)[1].lower() in supported
    ]
    if not files:
        print(f"Không tìm thấy hình trong: {folder_path}")
    return files


def on_gallery_select(evt: gr.SelectData):
    """
    Khi click thumbnail: trả về
      1) filepath để nạp vào Image Source
      2) tên file (basename) để hiển thị trong Textbox
    """
    val = evt.value

    # --- logic trích filepath y như cũ ---
    if isinstance(val, dict):
        img = val.get("image")
        if isinstance(img, str):
            filepath = img
        elif isinstance(img, dict):
            filepath = img.get("path") or img.get("url")
        else:
            filepath = next(
                (v for v in val.values() if isinstance(v, str) and os.path.isfile(v)),
                None
            )
    elif isinstance(val, str):
        filepath = val
    else:
        raise ValueError(f"Kiểu không hỗ trợ: {type(val)}")

 # Lấy tên file không có phần mở rộng
    filename = os.path.splitext(os.path.basename(filepath))[0] if filepath else ""
    return filepath, filename

# --- Hàm xác định folder dựa trên phân lớp ---
def infer_folder(image) -> str:
    cls = predict(image)["predicted_class"]
    folder = str(FOLDER / cls)
    return folder

def get_face_no_forehead(image):
    """
    Get face mask without forehead extension (only MediaPipe face mesh)
    """
    image = image.convert("RGB")
    
    # Face mesh mask (chỉ lấy mặt MediaPipe, không mở rộng trán)
    face_mesh_mask = get_facemesh_mask(image)
    
    # Làm mượt mask
    face_mesh_mask = cv2.GaussianBlur(face_mesh_mask.astype(np.float32), (3, 3), 0)
    face_mesh_mask = (face_mesh_mask > 0.5).astype(np.uint8)

    np_image = np.array(image)
    alpha = (face_mesh_mask * 255).astype(np.uint8)
    rgba_image = np.dstack([np_image, alpha])
    return Image.fromarray(rgba_image)

# --- Hàm gộp: phân loại + load ảnh ---
def handle_bg_change(image):
    """
    Khi thay đổi background:
    1. Phân loại khuôn mặt
    2. Load ảnh từ folder tương ứng
    """
    if image is None:
        return "", []
    
    try:
        folder = infer_folder(image)
        images = load_images_from_folder(folder)
        return folder, images
    except Exception as e:
        print(f"Lỗi xử lý ảnh: {e}")
        return "", []

# --- Pipeline đầy đủ: workflow 5 bước mới ---
def complete_pipeline(background: Image.Image, source: Image.Image):
    if background is None or source is None:
        return None

    try:
        # Bước 2: Apply baldhead to background (chuyển từ bước 4)
        bg_bald = inference(background)
        print("✅ Applied baldhead to background")
        
        # Bước 1,3,4: Overlay workflow với background baldhead
        overlay_result = overlay_source(bg_bald, source)
        if overlay_result is None:
            print("❌ Overlay workflow failed.")
            return None
        
        print("✅ Overlay workflow completed")

        # Bước 5: Face Swapping - background baldhead và overlay result
        final_result = swap_face_now(
            background.convert("RGB"),       # background baldhead
            overlay_result.convert("RGB")    # overlay result


        )

        if final_result is None:
            print("❌ Face swap failed, using overlay result as fallback")
            result_img = overlay_result
        else:
            print("✅ Face swap success")
            result_img = final_result

        # Clear cache after getting result
        import gc
        import torch
        
        # Clear variables
        overlay_result = None
        bg_bald = None
        final_result = None
        
        # Force garbage collection
        gc.collect()
        
        # Clear GPU cache if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print("✅ GPU cache cleared")
        
        print("✅ Memory cache cleared")
        return result_img

    except Exception as e:
        print(f"❌ Lỗi trong complete_pipeline: {e}")
        # Clear cache on error too
        import gc
        import torch
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return None
        
# --- Xây dựng giao diện Gradio ---
def build_demo():
    with gr.Blocks(title="Xử lý hai hình ảnh", theme=gr.themes.Soft()) as demo:
        gr.Markdown("Upload Background & Source, click **Run** to try on wigs.")

        with gr.Row():
            bg = gr.Image(type="pil", label="Background", height=500)
            src = gr.Image(type="pil", label="Source", height=500, interactive=False)
            out = gr.Image(label="Result", height=500, interactive=False)

        folder_path_box = gr.Textbox(label="Folder path", visible=False)

        
        with gr.Row():
            src_name_box = gr.Textbox(
                        label="Wigs Name",
                        interactive=False,
                        show_copy_button=True ,     # tuỳ chọn – tiện copy đường dẫn
                        scale = 1
                        )
            gallery = gr.Gallery(
                            label="Recommend For You",
                            height=300,
                            value=[],
                            type="filepath",
                            interactive=False,
                            columns=5,
                            object_fit="cover",
                            allow_preview=True,
                            scale = 8
                            )
            btn = gr.Button("🔄 Run", variant="primary",scale = 1)  
 


        # Chạy pipeline đầy đủ: ghép tóc + swap face cuối
        btn.click(fn=complete_pipeline, inputs=[bg, src], outputs=[out])
        # Khi đổi ảnh background, tự động phân loại và load ảnh gợi ý
        bg.change(
            fn=handle_bg_change,
            inputs=[bg],
            outputs=[folder_path_box, gallery],
            show_progress=True
        )
        # Nút tải lại ảnh thủ công (backup)
        # Khi chọn ảnh trong gallery, cập nhật vào khung Source
        gallery.select(
            fn=on_gallery_select,
            outputs=[src, src_name_box]
        )

    return demo

if __name__ == "__main__":
    build_demo().launch()