Spaces:
Runtime error
Runtime error
Upload 41 files
Browse files- .gitattributes +33 -0
- README.md +12 -12
- app.py +158 -0
- baldhead.py +272 -0
- bbox_utils.py +31 -0
- detect_face.py +93 -0
- example_wigs/Heart/HH02.png +3 -0
- example_wigs/Heart/HH03.png +3 -0
- example_wigs/Heart/Loire.png +3 -0
- example_wigs/Heart/SantaRossa.png +3 -0
- example_wigs/Heart/Tuscany.png +3 -0
- example_wigs/Oblong/HH01.png +3 -0
- example_wigs/Oblong/HH02.png +3 -0
- example_wigs/Oblong/HH03.png +3 -0
- example_wigs/Oblong/HH07.png +3 -0
- example_wigs/Oblong/Loire.png +3 -0
- example_wigs/Oval/Alsace.png +3 -0
- example_wigs/Oval/Barossa.png +3 -0
- example_wigs/Oval/Burgundy.png +3 -0
- example_wigs/Oval/HH01.png +3 -0
- example_wigs/Oval/HH02.png +3 -0
- example_wigs/Oval/HH03.png +3 -0
- example_wigs/Oval/HH07.png +3 -0
- example_wigs/Oval/Loire.png +3 -0
- example_wigs/Oval/Napa.png +3 -0
- example_wigs/Oval/Piemonte.png +3 -0
- example_wigs/Oval/Rhone.png +3 -0
- example_wigs/Oval/SantaRossa.png +3 -0
- example_wigs/Oval/Sonoma.png +3 -0
- example_wigs/Oval/Tuscany.png +3 -0
- example_wigs/Round/Loire.png +3 -0
- example_wigs/Round/Piemonte.png +3 -0
- example_wigs/Round/Sonoma.png +3 -0
- example_wigs/Round/Tuscany.png +3 -0
- example_wigs/Square/HH03.png +3 -0
- example_wigs/Square/Loire.png +3 -0
- example_wigs/Square/Piemonte.png +3 -0
- example_wigs/Square/Sonoma.png +3 -0
- example_wigs/Square/Tuscany.png +3 -0
- overlay.py +89 -0
- requirements.txt +14 -0
- segmentation.py +31 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,36 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
example_wigs/Heart/HH02.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
example_wigs/Heart/HH03.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
example_wigs/Heart/Loire.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
example_wigs/Heart/SantaRossa.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
example_wigs/Heart/Tuscany.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
example_wigs/Oblong/HH01.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
example_wigs/Oblong/HH02.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
example_wigs/Oblong/HH03.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
example_wigs/Oblong/HH07.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
example_wigs/Oblong/Loire.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
example_wigs/Oval/Alsace.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
example_wigs/Oval/Barossa.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
example_wigs/Oval/Burgundy.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
example_wigs/Oval/HH01.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
example_wigs/Oval/HH02.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
example_wigs/Oval/HH03.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
example_wigs/Oval/HH07.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
example_wigs/Oval/Loire.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
example_wigs/Oval/Napa.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
example_wigs/Oval/Piemonte.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
example_wigs/Oval/Rhone.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
example_wigs/Oval/SantaRossa.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
example_wigs/Oval/Sonoma.png filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
example_wigs/Oval/Tuscany.png filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
example_wigs/Round/Loire.png filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
example_wigs/Round/Piemonte.png filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
example_wigs/Round/Sonoma.png filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
example_wigs/Round/Tuscany.png filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
example_wigs/Square/HH03.png filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
example_wigs/Square/Loire.png filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
example_wigs/Square/Piemonte.png filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
example_wigs/Square/Sonoma.png filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
example_wigs/Square/Tuscany.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Ghep Image
|
| 3 |
+
emoji: 📉
|
| 4 |
+
colorFrom: pink
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.31.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import gradio as gr
|
| 3 |
+
from overlay import overlay_source
|
| 4 |
+
from detect_face import predict, NUM_CLASSES
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
# Resolve the wig-example folder relative to this file so the app works
# regardless of the current working directory.
BASE_DIR = Path(__file__).parent  # directory containing app.py
FOLDER = BASE_DIR / "example_wigs"
|
| 10 |
+
|
| 11 |
+
# --- Hàm load ảnh từ folder ---
|
| 12 |
+
def load_images_from_folder(folder_path: str) -> list[str]:
    """Return a sorted list of image file paths found in *folder_path*.

    Supported extensions: jpg, jpeg, png, gif, bmp (case-insensitive).
    Returns an empty list (after printing a warning) when the path is not
    a directory or contains no supported images.
    """
    supported = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
    if not os.path.isdir(folder_path):
        print(f"Cảnh báo: '{folder_path}' không phải folder hợp lệ.")
        return []
    # Sort for a deterministic gallery order — os.listdir order is arbitrary
    # and varies across filesystems.
    files = sorted(
        os.path.join(folder_path, fn)
        for fn in os.listdir(folder_path)
        if os.path.splitext(fn)[1].lower() in supported
    )
    if not files:
        print(f"Không tìm thấy hình trong: {folder_path}")
    return files
|
| 28 |
+
|
| 29 |
+
# --- Handler khi click thumbnail của Gallery ---
|
| 30 |
+
# def on_gallery_select(evt: gr.SelectData):
|
| 31 |
+
# """
|
| 32 |
+
# Xử lý khi click vào ảnh trong gallery - tối ưu và robust.
|
| 33 |
+
# """
|
| 34 |
+
# val = evt.value
|
| 35 |
+
# if isinstance(val, dict):
|
| 36 |
+
# img = val.get("image")
|
| 37 |
+
# if isinstance(img, str): return img
|
| 38 |
+
# if isinstance(img, dict):
|
| 39 |
+
# path = img.get("path") or img.get("url")
|
| 40 |
+
# if isinstance(path, str): return path
|
| 41 |
+
# for v in img.values():
|
| 42 |
+
# if isinstance(v, str) and os.path.isfile(v):
|
| 43 |
+
# return v
|
| 44 |
+
# for v in val.values():
|
| 45 |
+
# if isinstance(v, str) and os.path.isfile(v):
|
| 46 |
+
# return v
|
| 47 |
+
# raise ValueError(f"Không trích được filepath từ dict: {val}")
|
| 48 |
+
# if isinstance(val, str):
|
| 49 |
+
# return val
|
| 50 |
+
# raise ValueError(f"Kiểu không hỗ trợ: {type(val)}")
|
| 51 |
+
|
| 52 |
+
def on_gallery_select(evt: gr.SelectData):
    """Handle a thumbnail click in the gallery.

    Returns a pair:
      1) the selected file path, used to populate the Source image, and
      2) its basename, shown in the wig-name textbox.
    """
    selected = evt.value

    if isinstance(selected, str):
        path = selected
    elif isinstance(selected, dict):
        image_entry = selected.get("image")
        if isinstance(image_entry, str):
            path = image_entry
        elif isinstance(image_entry, dict):
            path = image_entry.get("path") or image_entry.get("url")
        else:
            # Fall back to the first dict value that points at an existing file.
            path = next(
                (v for v in selected.values() if isinstance(v, str) and os.path.isfile(v)),
                None
            )
    else:
        raise ValueError(f"Kiểu không hỗ trợ: {type(selected)}")

    name = os.path.basename(path) if path else ""
    return path, name
|
| 79 |
+
|
| 80 |
+
# --- Hàm xác định folder dựa trên phân lớp ---
|
| 81 |
+
def infer_folder(image) -> str:
    """Classify the face shape of *image* and return the matching wig folder path."""
    predicted = predict(image)["predicted_class"]
    return str(FOLDER / predicted)
|
| 85 |
+
|
| 86 |
+
# --- Hàm gộp: phân loại + load ảnh ---
|
| 87 |
+
def handle_bg_change(image):
    """React to a background-image change.

    Classifies the face shape, then loads the suggested wig images from
    the matching folder. Returns ("", []) when there is no image or when
    processing fails.
    """
    if image is None:
        return "", []

    try:
        wig_folder = infer_folder(image)
        return wig_folder, load_images_from_folder(wig_folder)
    except Exception as e:
        print(f"Lỗi xử lý ảnh: {e}")
        return "", []
|
| 103 |
+
|
| 104 |
+
# --- Xây dựng giao diện Gradio ---
|
| 105 |
+
def build_demo():
    """Build and return the Gradio Blocks UI for the wig try-on demo."""
    with gr.Blocks(title="Xử lý hai hình ảnh", theme=gr.themes.Soft()) as demo:
        gr.Markdown("Upload Background & Source, click **Run** to try on wigs.")

        with gr.Row():
            bg = gr.Image(type="pil", label="Background", height=500)
            src = gr.Image(type="pil", label="Source", height=500, interactive=False)
            out = gr.Image(label="Result", height=500, interactive=False)

        # Hidden textbox carrying the inferred wig-folder path between events.
        folder_path_box = gr.Textbox(label="Folder path", visible=False)


        with gr.Row():
            src_name_box = gr.Textbox(
                label="Wigs Name",
                interactive=False,
                show_copy_button=True , # optional — convenient for copying the name
                scale = 1
            )
            gallery = gr.Gallery(
                label="Recommend For You",
                height=300,
                value=[],
                type="filepath",
                interactive=False,
                columns=5,
                object_fit="cover",
                allow_preview=True,
                scale = 8
            )
            btn = gr.Button("🔄 Run", variant="primary",scale = 1)


        # Run the wig compositing.
        # NOTE(review): overlay_source appears to return (image, message) on
        # its failure path but only one output is wired here — verify.
        btn.click(fn=overlay_source, inputs=[bg, src], outputs=[out])
        # When the background changes, auto-classify the face shape and
        # load the suggested wig images.
        bg.change(
            fn=handle_bg_change,
            inputs=[bg],
            outputs=[folder_path_box, gallery],
            show_progress=True
        )
        # When a gallery thumbnail is selected, load it into the Source slot.
        gallery.select(
            fn=on_gallery_select,
            outputs=[src, src_name_box]
        )

    return demo
|
| 156 |
+
|
| 157 |
+
# Script entry point: build the UI and launch the Gradio server.
if __name__ == "__main__":
    build_demo().launch()
|
baldhead.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# baldhead.py
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import tensorflow as tf
|
| 8 |
+
import gradio as gr
|
| 9 |
+
|
| 10 |
+
# Keras imports (note: keras-contrib must be installed)
|
| 11 |
+
import keras.backend as K
|
| 12 |
+
from keras.layers import (
|
| 13 |
+
Input,
|
| 14 |
+
Conv2D,
|
| 15 |
+
UpSampling2D,
|
| 16 |
+
LeakyReLU,
|
| 17 |
+
GlobalAveragePooling2D,
|
| 18 |
+
Dense,
|
| 19 |
+
Reshape,
|
| 20 |
+
Dropout,
|
| 21 |
+
Concatenate,
|
| 22 |
+
multiply, # ← Thêm import multiply
|
| 23 |
+
)
|
| 24 |
+
from keras.models import Model
|
| 25 |
+
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
|
| 26 |
+
|
| 27 |
+
# RetinaFace + skimage for face alignment
|
| 28 |
+
from retinaface import RetinaFace
|
| 29 |
+
from skimage import transform as trans
|
| 30 |
+
|
| 31 |
+
# Hugging Face Hub helper
|
| 32 |
+
from huggingface_hub import hf_hub_download
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# --- Face-alignment helpers (same as the original code) ---
# Output size for aligned face crops.
image_size = [256, 256]
# Canonical 5-point landmark template (eyes, nose, mouth corners).
src_landmarks = np.array([
    [30.2946, 51.6963],
    [65.5318, 51.5014],
    [48.0252, 71.7366],
    [33.5493, 92.3655],
    [62.7299, 92.2041],
], dtype=np.float32)
# NOTE(review): x is shifted twice (+8 then +15) — possibly deliberate
# tuning, but verify it is not a leftover duplicate line.
src_landmarks[:, 0] += 8.0
src_landmarks[:, 0] += 15.0
src_landmarks[:, 1] += 30.0
# Rescale the 112-based template coordinates to a 200-px working area.
src_landmarks /= 112
src_landmarks *= 200
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def list2array(values):
    """Materialize an arbitrary iterable into a numpy array."""
    return np.asarray(list(values))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def align_face(img: np.ndarray):
    """
    Detect faces + landmarks in `img` via RetinaFace, then warp each face
    to the 256×256 canonical landmark template (`src_landmarks`).

    Returns three parallel lists: aligned face patches (256×256 RGB),
    warped white-canvas masks, and the full 3×3 similarity matrices
    (kept so the faces can later be warped back).
    """
    # NOTE(review): assumes detect_faces returns a dict of face entries;
    # verify what it returns when no face is found.
    faces = RetinaFace.detect_faces(img)
    bboxes = np.array([list2array(faces[f]['facial_area']) for f in faces])
    landmarks = np.array([list2array(faces[f]['landmarks'].values()) for f in faces])

    # Warping an all-white canvas yields the footprint of each aligned face.
    white_canvas = np.ones(img.shape, dtype=np.uint8) * 255
    aligned_faces, masks, matrices = [], [], []

    if bboxes.shape[0] > 0:
        for i in range(bboxes.shape[0]):
            dst = landmarks[i]  # detected landmarks
            tform = trans.SimilarityTransform()
            tform.estimate(dst, src_landmarks)
            M = tform.params[0:2, :]  # 2×3 affine part for cv2.warpAffine

            warped_face = cv2.warpAffine(
                img, M, (image_size[1], image_size[0]), borderValue=0.0
            )
            warped_mask = cv2.warpAffine(
                white_canvas, M, (image_size[1], image_size[0]), borderValue=0.0
            )

            aligned_faces.append(warped_face)
            masks.append(warped_mask)
            matrices.append(tform.params[0:3, :])  # keep full 3×3 for inversion

    return aligned_faces, masks, matrices
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def put_face_back(
    orig_img: np.ndarray,
    processed_faces: list[np.ndarray],
    masks: list[np.ndarray],
    matrices: list[np.ndarray],
):
    """
    Warp each processed face back onto the original `orig_img`
    using the inverse of the transformation matrices.

    Returns the composited image as a numpy array.
    """
    result = orig_img.copy()
    h, w = orig_img.shape[:2]

    for i in range(len(processed_faces)):
        # Invert the full 3×3 similarity, keep only the 2×3 affine part.
        invM = np.linalg.inv(matrices[i])[0:2]
        warped = cv2.warpAffine(processed_faces[i], invM, (w, h), borderValue=0.0)
        mask = cv2.warpAffine(masks[i], invM, (w, h), borderValue=0.0)
        # Integer division snaps interpolated edge pixels (<255) to 0,
        # producing a hard 0/1 mask.
        binary_mask = (mask // 255).astype(np.uint8)

        # Composite: result = result * (1 - mask) + warped * mask
        result = result * (1 - binary_mask)
        result = result.astype(np.uint8)
        result = result + warped * binary_mask

    return result
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ----------------------------
|
| 118 |
+
# 2. GENERATOR ARCHITECTURE
|
| 119 |
+
# ----------------------------
|
| 120 |
+
|
| 121 |
+
def squeeze_excite_block(x, ratio=4):
    """
    Squeeze-and-Excitation block: channel-wise attention.

    Global-average-pools to one value per channel, squeezes through a
    bottleneck Dense (filters // ratio), expands back with a sigmoid
    gate, and rescales the input feature map channel-wise.
    """
    init = x
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    filters = init.shape[channel_axis]
    se_shape = (1, 1, filters)

    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(filters // ratio, activation="relu", kernel_initializer="he_normal", use_bias=False)(se)
    se = Dense(filters, activation="sigmoid", kernel_initializer="he_normal", use_bias=False)(se)
    return multiply([init, se])
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def conv2d(layer_input, filters, f_size=4, bn=True, se=False):
    """
    Downsampling block: Conv2D (stride 2) → LeakyReLU → (InstanceNorm) → (SE block)
    """
    d = Conv2D(filters, kernel_size=f_size, strides=2, padding="same")(layer_input)
    d = LeakyReLU(alpha=0.2)(d)
    if bn:
        # Despite the `bn` flag name, this applies InstanceNormalization.
        d = InstanceNormalization()(d)
    if se:
        d = squeeze_excite_block(d)
    return d
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def atrous(layer_input, filters, f_size=4, bn=True):
    """
    Atrous (dilated) convolution block with dilation rates [2, 4, 8],
    concatenated to combine several receptive-field sizes.
    """
    a_list = []
    for rate in [2, 4, 8]:
        a = Conv2D(filters, f_size, dilation_rate=rate, padding="same")(layer_input)
        a_list.append(a)
    a = Concatenate()(a_list)
    a = LeakyReLU(alpha=0.2)(a)
    if bn:
        a = InstanceNormalization()(a)
    return a
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
    """
    Upsampling block: UpSampling2D → Conv2D → (Dropout) → InstanceNorm → Concatenate(skip)
    """
    u = UpSampling2D(size=2)(layer_input)
    u = Conv2D(filters, kernel_size=f_size, strides=1, padding="same", activation="relu")(u)
    if dropout_rate:
        u = Dropout(dropout_rate)(u)
    u = InstanceNormalization()(u)
    u = Concatenate()([u, skip_input])
    return u
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def build_generator():
    """
    Reconstruct the U-Net-style generator architecture exactly as in the
    notebook, then return a Keras Model object.
    """
    d0 = Input(shape=(256, 256, 3))
    gf = 64  # base number of generator filters

    # Downsampling path
    d1 = conv2d(d0, gf, bn=False, se=True)
    d2 = conv2d(d1, gf * 2, se=True)
    d3 = conv2d(d2, gf * 4, se=True)
    d4 = conv2d(d3, gf * 8)
    d5 = conv2d(d4, gf * 8)

    # Atrous (dilated) bottleneck
    a1 = atrous(d5, gf * 8)

    # Upsampling path with skip connections
    u3 = deconv2d(a1, d4, gf * 8)
    u4 = deconv2d(u3, d3, gf * 4)
    u5 = deconv2d(u4, d2, gf * 2)
    u6 = deconv2d(u5, d1, gf)

    # Final upsample + conv; tanh matches the [-1, 1] training range
    u7 = UpSampling2D(size=2)(u6)
    output_img = Conv2D(3, kernel_size=4, strides=1, padding="same", activation="tanh")(u7)

    model = Model(d0, output_img)
    return model
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# ----------------------------
|
| 211 |
+
# 3. LOAD MODEL WEIGHTS
|
| 212 |
+
# ----------------------------
|
| 213 |
+
|
| 214 |
+
# Hugging Face Hub location of the pretrained generator weights.
HF_REPO_ID = "VanNguyen1214/baldhead"
HF_FILENAME = "model_G_5_170.hdf5"
# Read the token leniently: hf_hub_download accepts token=None (anonymous
# access), whereas os.environ["..."] would raise KeyError and crash the
# whole app at import time when the variable is not set.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
|
| 217 |
+
|
| 218 |
+
def load_generator_from_hub():
    """
    Download the .hdf5 weights from HF Hub into cache,
    rebuild the generator, then load weights.

    Returns the ready-to-use Keras generator model.
    """
    local_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME,token=HF_TOKEN)
    gen = build_generator()
    gen.load_weights(local_path)
    return gen
|
| 227 |
+
|
| 228 |
+
# Load once at startup so individual requests don't pay the download/build cost.
try:
    GENERATOR = load_generator_from_hub()
    print(f"[INFO] Loaded generator weights from {HF_REPO_ID}/{HF_FILENAME}")
except Exception as e:
    # Degrade gracefully: `inference` returns its input unchanged when the
    # generator is unavailable.
    print("[ERROR] Could not load generator:", e)
    GENERATOR = None
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# ----------------------------
|
| 238 |
+
# 4. INFERENCE FUNCTION
|
| 239 |
+
# ----------------------------
|
| 240 |
+
|
| 241 |
+
def inference(image: Image.Image) -> Image.Image:
    """
    Gradio-compatible inference function:
    - Convert PIL → numpy RGB
    - Align faces
    - For each face: normalize to [-1,1], run through generator, denormalize to uint8
    - Put processed faces back onto original image
    - Return full-image PIL

    Returns the input unchanged when the generator failed to load or no
    face is detected.
    """
    if GENERATOR is None:
        return image

    orig = np.array(image.convert("RGB"))

    faces, masks, mats = align_face(orig)
    if len(faces) == 0:
        return image

    processed_faces = []
    for face in faces:
        face_input = face.astype(np.float32)
        face_input = (face_input / 127.5) - 1.0  # scale to [-1,1]
        face_input = np.expand_dims(face_input, axis=0)  # (1,256,256,3)

        pred = GENERATOR.predict(face_input)[0]  # (256,256,3) in [-1,1]
        pred = ((pred + 1.0) * 127.5).astype(np.uint8)
        processed_faces.append(pred)

    output_np = put_face_back(orig, processed_faces, masks, mats)
    output_pil = Image.fromarray(output_np)

    return output_pil
|
bbox_utils.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
def get_bbox_from_alpha(rgba: "Image.Image"):
    """Return the (x1, y1, x2, y2) bounds of all non-transparent pixels.

    Returns None when the alpha channel is fully transparent.
    """
    pixels = np.asarray(rgba)
    rows, cols = np.nonzero(pixels[..., 3])
    if rows.size == 0:
        return None
    return cols.min(), rows.min(), cols.max(), rows.max()
|
| 13 |
+
|
| 14 |
+
def paste_with_alpha(bg: np.ndarray, src: np.ndarray, offset: tuple[int,int]) -> Image.Image:
    """Paste RGBA *src* onto *bg* at *offset*, copying only opaque pixels.

    Regions falling outside *bg* are clipped; pixels with alpha == 0 are
    skipped. Returns the composite as a PIL image.
    """
    out = bg.copy()
    off_x, off_y = offset
    src_h, src_w = src.shape[:2]

    # Clip the destination rectangle to the background bounds.
    dst_x1, dst_y1 = max(off_x, 0), max(off_y, 0)
    dst_x2 = min(off_x + src_w, bg.shape[1])
    dst_y2 = min(off_y + src_h, bg.shape[0])
    if dst_x1 >= dst_x2 or dst_y1 >= dst_y2:
        # No overlap at all — return the untouched background.
        return Image.fromarray(out)

    src_crop = src[dst_y1 - off_y:dst_y2 - off_y, dst_x1 - off_x:dst_x2 - off_x]
    dst_crop = out[dst_y1:dst_y2, dst_x1:dst_x2]
    opaque = src_crop[..., 3] > 0
    if dst_crop.shape[2] == 3:
        # RGB background: drop the alpha channel when copying.
        dst_crop[opaque] = src_crop[opaque][..., :3]
    else:
        dst_crop[opaque] = src_crop[opaque]
    out[dst_y1:dst_y2, dst_x1:dst_x2] = dst_crop
    return Image.fromarray(out)
|
detect_face.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# model.py
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torchvision
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
# --- General configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_REPO = "VanNguyen1214/detect_faceshape"  # classifier repo on the HF Hub
HF_FILENAME = "best_model.pth"              # file at the repo root
LOCAL_CKPT = "models/best_model.pth"        # local cache location
CLASS_NAMES = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']
NUM_CLASSES = len(CLASS_NAMES)

# --- Pre-inference image transform (224×224 + ImageNet normalization) ---
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std =[0.229, 0.224, 0.225]),
])
|
| 28 |
+
|
| 29 |
+
def _ensure_checkpoint() -> str:
    """
    Return a local path to the model checkpoint.

    Uses LOCAL_CKPT when it already exists; otherwise downloads
    HF_FILENAME from HF_REPO into ./models/. Exits the process when the
    download fails.
    """
    if os.path.exists(LOCAL_CKPT):
        return LOCAL_CKPT

    try:
        ckpt_path = hf_hub_download(
            repo_id=HF_REPO,
            filename=HF_FILENAME,
            local_dir="models",
        )
        return ckpt_path
    except Exception as e:
        print(f"❌ Không tải được model từ HF Hub: {e}")
        # NOTE(review): this module runs _ensure_checkpoint at import time,
        # so sys.exit(1) here kills the whole app — consider raising instead.
        sys.exit(1)
|
| 47 |
+
|
| 48 |
+
def _load_model(ckpt_path: str) -> torch.nn.Module:
    """
    Rebuild the EfficientNet-B4 architecture, load the state dict from
    *ckpt_path*, and return the model in eval mode on DEVICE.
    """
    # 1) Instantiate EfficientNet-B4 without pretrained weights (ours are loaded below).
    # NOTE(review): `pretrained=` is deprecated in newer torchvision in favor
    # of `weights=None` — confirm the pinned torchvision version.
    model = torchvision.models.efficientnet_b4(pretrained=False)
    in_features = model.classifier[1].in_features
    # Replace the classifier head to match the 5 face-shape classes.
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features, NUM_CLASSES)
    )

    # 2) Load the weights.
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — acceptable only because the checkpoint comes from our own repo.
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state)

    # 3) Switch to evaluation mode.
    return model.to(DEVICE).eval()
|
| 66 |
+
|
| 67 |
+
# === Build the model eagerly at import time ===
# NOTE(review): importing this module downloads/loads the checkpoint and may
# call sys.exit on failure — importing has significant side effects.
_CKPT_PATH = _ensure_checkpoint()
_MODEL = _load_model(_CKPT_PATH)
|
| 70 |
+
|
| 71 |
+
def predict(image: Image.Image) -> dict:
    """
    Classify the face shape of a PIL image.

    Returns a dict:
        {
            "predicted_class": str,
            "confidence": float,
            "probabilities": {class_name: prob, ...},
        }
    """
    # Convert to RGB + apply the shared inference transform.
    img = image.convert("RGB")
    x = _TRANSFORM(img).unsqueeze(0).to(DEVICE)

    # Inference
    with torch.no_grad():
        logits = _MODEL(x)
        probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

    idx = int(probs.argmax())
    # Fix: the docstring promised confidence and per-class probabilities but
    # only "predicted_class" was returned (probs were computed then discarded).
    # Existing callers reading only "predicted_class" are unaffected.
    return {
        "predicted_class": CLASS_NAMES[idx],
        "confidence": float(probs[idx]),
        "probabilities": {name: float(p) for name, p in zip(CLASS_NAMES, probs)},
    }
|
| 93 |
+
|
example_wigs/Heart/HH02.png
ADDED
|
Git LFS Details
|
example_wigs/Heart/HH03.png
ADDED
|
Git LFS Details
|
example_wigs/Heart/Loire.png
ADDED
|
Git LFS Details
|
example_wigs/Heart/SantaRossa.png
ADDED
|
Git LFS Details
|
example_wigs/Heart/Tuscany.png
ADDED
|
Git LFS Details
|
example_wigs/Oblong/HH01.png
ADDED
|
Git LFS Details
|
example_wigs/Oblong/HH02.png
ADDED
|
Git LFS Details
|
example_wigs/Oblong/HH03.png
ADDED
|
Git LFS Details
|
example_wigs/Oblong/HH07.png
ADDED
|
Git LFS Details
|
example_wigs/Oblong/Loire.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Alsace.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Barossa.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Burgundy.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/HH01.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/HH02.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/HH03.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/HH07.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Loire.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Napa.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Piemonte.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Rhone.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/SantaRossa.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Sonoma.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Tuscany.png
ADDED
|
Git LFS Details
|
example_wigs/Round/Loire.png
ADDED
|
Git LFS Details
|
example_wigs/Round/Piemonte.png
ADDED
|
Git LFS Details
|
example_wigs/Round/Sonoma.png
ADDED
|
Git LFS Details
|
example_wigs/Round/Tuscany.png
ADDED
|
Git LFS Details
|
example_wigs/Square/HH03.png
ADDED
|
Git LFS Details
|
example_wigs/Square/Loire.png
ADDED
|
Git LFS Details
|
example_wigs/Square/Piemonte.png
ADDED
|
Git LFS Details
|
example_wigs/Square/Sonoma.png
ADDED
|
Git LFS Details
|
example_wigs/Square/Tuscany.png
ADDED
|
Git LFS Details
|
overlay.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import mediapipe as mp
|
| 4 |
+
|
| 5 |
+
from baldhead import inference # cạo tóc background
|
| 6 |
+
from segmentation import extract_hair
|
| 7 |
+
|
| 8 |
+
# MediaPipe Face Detection
|
| 9 |
+
mp_fd = mp.solutions.face_detection.FaceDetection(model_selection=1,
|
| 10 |
+
min_detection_confidence=0.5)
|
| 11 |
+
|
| 12 |
+
def get_face_bbox(img: Image.Image) -> tuple[int, int, int, int] | None:
    """Detect the most confident face in *img* with MediaPipe.

    Returns the pixel bounding box as (x1, y1, x2, y2), or None when no
    face is detected.
    """
    rgb = np.array(img.convert("RGB"))
    detection = mp_fd.process(rgb)
    if not detection.detections:
        return None
    height, width = rgb.shape[:2]
    # MediaPipe reports the box in relative [0, 1] coordinates.
    box = detection.detections[0].location_data.relative_bounding_box
    left = int(box.xmin * width)
    top = int(box.ymin * height)
    right = left + int(box.width * width)
    bottom = top + int(box.height * height)
    return left, top, right, bottom
|
| 24 |
+
|
| 25 |
+
def compute_scale(w_bg, h_bg, w_src, h_src) -> float:
    """Mean of the width and height ratios between two face boxes.

    The result is the factor by which the source image must be scaled so
    its face is roughly the same size as the background face.
    """
    width_ratio = w_bg / w_src
    height_ratio = h_bg / h_src
    return 0.5 * (width_ratio + height_ratio)
|
| 27 |
+
|
| 28 |
+
def compute_offset(bbox_bg, bbox_src, scale) -> tuple[int, int]:
    """Translation aligning the scaled source face centre with the background face centre.

    *bbox_bg* and *bbox_src* are (x1, y1, x2, y2) pixel boxes; the source
    centre is mapped through *scale* before the difference is taken.
    """
    def center(box):
        a1, b1, a2, b2 = box
        return a1 + (a2 - a1) // 2, b1 + (b2 - b1) // 2

    bg_cx, bg_cy = center(bbox_bg)
    raw_cx, raw_cy = center(bbox_src)
    # int() truncation matches the original pixel rounding behaviour.
    return bg_cx - int(raw_cx * scale), bg_cy - int(raw_cy * scale)
|
| 36 |
+
|
| 37 |
+
def paste_with_alpha(bg: np.ndarray, src: np.ndarray, offset: tuple[int, int]) -> Image.Image:
    """Paste *src* (RGBA array) onto a copy of *bg* at *offset*.

    The alpha channel is used as a hard mask: pixels with alpha > 0 replace
    the destination outright (no blending). The paste rectangle is clipped
    to the background bounds; *bg* itself is never modified.
    """
    out = bg.copy()
    off_x, off_y = offset
    src_h, src_w = src.shape[:2]
    # Clip the destination rectangle to the background canvas.
    dst_left, dst_top = max(off_x, 0), max(off_y, 0)
    dst_right = min(off_x + src_w, bg.shape[1])
    dst_bottom = min(off_y + src_h, bg.shape[0])
    if dst_left >= dst_right or dst_top >= dst_bottom:
        # Source lies entirely outside the canvas — nothing to paste.
        return Image.fromarray(out)
    # Matching regions of source and destination.
    region_src = src[dst_top - off_y:dst_bottom - off_y, dst_left - off_x:dst_right - off_x]
    region_dst = out[dst_top:dst_bottom, dst_left:dst_right]
    opaque = region_src[..., 3] > 0
    if region_dst.shape[2] == 3:
        # RGB background: drop the source alpha channel when copying.
        region_dst[opaque] = region_src[opaque][..., :3]
    else:
        region_dst[opaque] = region_src[opaque]
    # region_dst is a view into `out`, so the masked writes above already
    # landed in the output — no explicit write-back is needed.
    return Image.fromarray(out)
|
| 55 |
+
|
| 56 |
+
def overlay_source(background: Image.Image, source: Image.Image):
    """Transplant the hair from *source* onto a bald version of *background*.

    Returns a ``(result, message)`` tuple: the composited RGBA PIL image
    (or ``None`` on failure) plus a human-readable status string.

    Fix: the success path previously returned a bare image while both
    failure paths returned 2-tuples, so any caller unpacking two values
    crashed on success. All paths now return a consistent 2-tuple.
    """
    # 1) Detect a face in both images; all alignment is anchored on these boxes.
    bbox_bg = get_face_bbox(background)
    bbox_src = get_face_bbox(source)
    if bbox_bg is None:
        return None, "❌ No face in background."
    if bbox_src is None:
        return None, "❌ No face in source."

    # 2) Scale the source so its face roughly matches the background face size.
    w_bg, h_bg = bbox_bg[2] - bbox_bg[0], bbox_bg[3] - bbox_bg[1]
    w_src, h_src = bbox_src[2] - bbox_src[0], bbox_src[3] - bbox_src[1]
    scale = compute_scale(w_bg, h_bg, w_src, h_src)
    src_scaled = source.resize(
        (int(source.width * scale), int(source.height * scale)),
        Image.Resampling.LANCZOS,
    )

    # 3) Offset that aligns the (scaled) source face centre with the background's.
    offset = compute_offset(bbox_bg, bbox_src, scale)

    # 4) Remove the existing hair from the background (bald-head model).
    bg_bald = inference(background)

    # 5) Keep only the hair pixels of the scaled source (transparent elsewhere).
    hair_only = extract_hair(src_scaled)

    # 6) Composite the extracted hair onto the bald background.
    result = paste_with_alpha(
        np.array(bg_bald.convert("RGBA")),
        np.array(hair_only),
        offset,
    )
    return result, "✅ Done."
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==4.44.0
|
| 2 |
+
transformers==4.36.0
|
| 3 |
+
torch==2.1.0
|
| 4 |
+
torchvision==0.16.0
|
| 5 |
+
huggingface-hub==0.19.4
|
| 6 |
+
Pillow==9.2.0
|
| 7 |
+
opencv-python-headless==4.8.1.78
|
| 8 |
+
numpy==1.24.3
|
| 9 |
+
mediapipe==0.10.8
|
| 10 |
+
tensorflow==2.11.0
|
| 11 |
+
keras==2.11.0
|
| 12 |
+
scikit-image==0.20.0
|
| 13 |
+
git+https://github.com/keras-team/keras-contrib.git
|
| 14 |
+
retina-face==0.0.13
|
segmentation.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
# Load SegFormer for hair segmentation
|
| 8 |
+
processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair")
|
| 9 |
+
model = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair")
|
| 10 |
+
|
| 11 |
+
def extract_hair(image: Image.Image) -> Image.Image:
    """Return *image* as RGBA where only hair pixels are opaque.

    Hair pixels get alpha=255; all other pixels get alpha=0. The RGB data
    is passed through unchanged.
    """
    rgb_img = image.convert("RGB")
    pixels = np.array(rgb_img)
    height, width = pixels.shape[:2]

    # Run the SegFormer hair/face segmentation model.
    batch = processor(images=rgb_img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**batch).logits.cpu()
    # Upsample the low-resolution logits back to the input size before argmax.
    resized = F.interpolate(logits, size=(height, width), mode="bilinear", align_corners=False)
    labels = resized.argmax(dim=1)[0].numpy()

    # NOTE(review): class index 2 is treated as "hair" here — presumably
    # from this model's label map; confirm against the model card.
    alpha = np.where(labels == 2, 255, 0).astype(np.uint8)
    return Image.fromarray(np.dstack([pixels, alpha]))
|