Spaces:
Runtime error
Runtime error
Upload 58 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +33 -0
- README.md +12 -12
- app.py +182 -0
- baldhead.py +272 -0
- bbox_utils.py +31 -0
- detect_face.py +93 -0
- example_wigs/Heart/HH02.png +3 -0
- example_wigs/Heart/HH03.png +3 -0
- example_wigs/Heart/Loire.png +3 -0
- example_wigs/Heart/SantaRossa.png +3 -0
- example_wigs/Heart/Tuscany.png +3 -0
- example_wigs/Oblong/HH01.png +3 -0
- example_wigs/Oblong/HH02.png +3 -0
- example_wigs/Oblong/HH03.png +3 -0
- example_wigs/Oblong/HH07.png +3 -0
- example_wigs/Oblong/Loire.png +3 -0
- example_wigs/Oval/Alsace.png +3 -0
- example_wigs/Oval/Barossa.png +3 -0
- example_wigs/Oval/Burgundy.png +3 -0
- example_wigs/Oval/HH01.png +3 -0
- example_wigs/Oval/HH02.png +3 -0
- example_wigs/Oval/HH03.png +3 -0
- example_wigs/Oval/HH07.png +3 -0
- example_wigs/Oval/Loire.png +3 -0
- example_wigs/Oval/Napa.png +3 -0
- example_wigs/Oval/Piemonte.png +3 -0
- example_wigs/Oval/Rhone.png +3 -0
- example_wigs/Oval/SantaRossa.png +3 -0
- example_wigs/Oval/Sonoma.png +3 -0
- example_wigs/Oval/Tuscany.png +3 -0
- example_wigs/Round/Loire.png +3 -0
- example_wigs/Round/Piemonte.png +3 -0
- example_wigs/Round/Sonoma.png +3 -0
- example_wigs/Round/Tuscany.png +3 -0
- example_wigs/Square/HH03.png +3 -0
- example_wigs/Square/Loire.png +3 -0
- example_wigs/Square/Piemonte.png +3 -0
- example_wigs/Square/Sonoma.png +3 -0
- example_wigs/Square/Tuscany.png +3 -0
- overlay.py +89 -0
- requirements.txt +35 -0
- roop/__init__.py +0 -0
- roop/capturer.py +20 -0
- roop/core.py +217 -0
- roop/face_analyser.py +124 -0
- roop/globals.py +17 -0
- roop/metadata.py +2 -0
- roop/predicter.py +25 -0
- roop/processors/__init__.py +0 -0
- roop/processors/frame/__init__.py +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,36 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
example_wigs/Heart/HH02.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
example_wigs/Heart/HH03.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
example_wigs/Heart/Loire.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
example_wigs/Heart/SantaRossa.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
example_wigs/Heart/Tuscany.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
example_wigs/Oblong/HH01.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
example_wigs/Oblong/HH02.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
example_wigs/Oblong/HH03.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
example_wigs/Oblong/HH07.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
example_wigs/Oblong/Loire.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
example_wigs/Oval/Alsace.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
example_wigs/Oval/Barossa.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
example_wigs/Oval/Burgundy.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
example_wigs/Oval/HH01.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
example_wigs/Oval/HH02.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
example_wigs/Oval/HH03.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
example_wigs/Oval/HH07.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
example_wigs/Oval/Loire.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
example_wigs/Oval/Napa.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
example_wigs/Oval/Piemonte.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
example_wigs/Oval/Rhone.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
example_wigs/Oval/SantaRossa.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
example_wigs/Oval/Sonoma.png filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
example_wigs/Oval/Tuscany.png filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
example_wigs/Round/Loire.png filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
example_wigs/Round/Piemonte.png filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
example_wigs/Round/Sonoma.png filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
example_wigs/Round/Tuscany.png filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
example_wigs/Square/HH03.png filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
example_wigs/Square/Loire.png filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
example_wigs/Square/Piemonte.png filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
example_wigs/Square/Sonoma.png filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
example_wigs/Square/Tuscany.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Ghep Image
|
| 3 |
+
emoji: 📉
|
| 4 |
+
colorFrom: pink
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.31.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import gradio as gr
|
| 3 |
+
from overlay import overlay_source
|
| 4 |
+
from detect_face import predict, NUM_CLASSES
|
| 5 |
+
from swapface import swap_face_now
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
BASE_DIR = Path(__file__).parent # thư mục chứa app.py
|
| 10 |
+
FOLDER = BASE_DIR / "example_wigs"
|
| 11 |
+
|
| 12 |
+
# --- Load gallery images from a folder ---
def load_images_from_folder(folder_path: str) -> list[str]:
    """
    Return a sorted list of image file paths (jpg, jpeg, png, gif, bmp)
    found directly inside `folder_path`.

    Returns an empty list when the folder does not exist or contains
    no supported images (a warning is printed in both cases).
    """
    supported = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
    if not os.path.isdir(folder_path):
        print(f"Cảnh báo: '{folder_path}' không phải folder hợp lệ.")
        return []
    # sorted(): os.listdir() order is filesystem-dependent; sorting keeps
    # the suggestion gallery deterministic across runs.
    files = [
        os.path.join(folder_path, fn)
        for fn in sorted(os.listdir(folder_path))
        if os.path.splitext(fn)[1].lower() in supported
    ]
    if not files:
        print(f"Không tìm thấy hình trong: {folder_path}")
    return files
|
| 29 |
+
|
| 30 |
+
def on_gallery_select(evt: gr.SelectData):
    """
    Handle a thumbnail click. Returns:
      1) the file path to load into the Source image component
      2) the base filename to display in the Textbox
    """
    selected = evt.value

    # Extract a usable file path from whatever shape Gradio hands us.
    filepath = None
    if isinstance(selected, str):
        filepath = selected
    elif isinstance(selected, dict):
        entry = selected.get("image")
        if isinstance(entry, str):
            filepath = entry
        elif isinstance(entry, dict):
            filepath = entry.get("path") or entry.get("url")
        else:
            # Fall back to the first string value that is an existing file.
            for candidate in selected.values():
                if isinstance(candidate, str) and os.path.isfile(candidate):
                    filepath = candidate
                    break
    else:
        raise ValueError(f"Kiểu không hỗ trợ: {type(selected)}")

    filename = os.path.basename(filepath) if filepath else ""
    return filepath, filename
|
| 57 |
+
|
| 58 |
+
# --- Map the classified face shape to its wig folder ---
def infer_folder(image) -> str:
    """Classify the face shape of `image` and return the matching example-wig folder path."""
    face_shape = predict(image)["predicted_class"]
    return str(FOLDER / face_shape)
|
| 63 |
+
|
| 64 |
+
# --- Combined handler: classification + image loading ---
def handle_bg_change(image):
    """
    React to a background-image change:
      1. classify the face shape
      2. load suggestion images from the matching folder
    Returns ("", []) when there is no image or anything fails.
    """
    if image is None:
        return "", []

    try:
        folder = infer_folder(image)
        return folder, load_images_from_folder(folder)
    except Exception as e:
        print(f"Lỗi xử lý ảnh: {e}")
        return "", []
|
| 81 |
+
|
| 82 |
+
# --- Face-swap wrapper ---
def swap_face_wrapper(background_img, result_img):
    """
    Swap the face from `background_img` onto `result_img`.
    Returns None when either input is missing, and the unmodified
    `result_img` when the swap itself raises.
    """
    if background_img is None or result_img is None:
        return None

    try:
        # Transfer the face from the background onto the result image.
        return swap_face_now(background_img, result_img, do_enhance=True)
    except Exception as e:
        print(f"Lỗi swap face: {e}")
        return result_img
|
| 97 |
+
|
| 98 |
+
# --- Combined pipeline: overlay then face swap ---
def combined_hair_and_face(background_img, source_img):
    """
    Run the wig overlay first, then swap the original face onto the result.
    Returns None when an input is missing or the pipeline fails.
    """
    if background_img is None or source_img is None:
        return None

    try:
        # Step 1: composite the wig onto the background.
        with_hair = overlay_source(background_img, source_img)
        # Step 2: swap the face from the background onto the overlay result.
        return swap_face_wrapper(background_img, with_hair)
    except Exception as e:
        print(f"Lỗi trong quá trình gộp hair + face: {e}")
        return None
|
| 117 |
+
|
| 118 |
+
# --- Build the Gradio UI ---
def build_demo():
    """Assemble and return the Gradio Blocks app (UI layout + event wiring)."""
    with gr.Blocks(title="Hair Try-On & Face Swap", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎯 Hair Try-On & Face Swap Application
        """)
        with gr.Row():
            # User photo, selected wig, and pipeline output side by side.
            bg = gr.Image(type="pil", label="Background", height=500)
            src = gr.Image(type="pil", label="Source", height=500, interactive=False)
            out = gr.Image(label="Result", height=500, interactive=False)

        # Hidden textbox holding the inferred folder path (debug/state only).
        folder_path_box = gr.Textbox(label="Folder path", visible=False)


        with gr.Row():
            src_name_box = gr.Textbox(
                label="Wigs Name",
                interactive=False,
                show_copy_button=True ,  # optional — convenient for copying the path
                scale = 1
            )
            gallery = gr.Gallery(
                label="Recommend For You",
                height=300,
                value=[],
                type="filepath",
                interactive=False,
                columns=5,
                object_fit="cover",
                allow_preview=True,
                scale = 8
            )
        with gr.Column(scale=1):
            combined_btn = gr.Button("🔄✨ Run Hair + Face Swap", variant="primary")
            btn = gr.Button("🔄 Run Hair Only", variant="secondary")
            swap_btn = gr.Button("👤 Swap Face Only", variant="secondary")


        # Run the combined hair + face-swap pipeline.
        combined_btn.click(fn=combined_hair_and_face, inputs=[bg, src], outputs=[out])

        # Hair overlay only.
        btn.click(fn=overlay_source, inputs=[bg, src], outputs=[out])

        # Face swap only (from background onto the current result).
        swap_btn.click(fn=swap_face_wrapper, inputs=[bg, out], outputs=[out])

        # When the background changes, auto-classify and load suggestions.
        bg.change(
            fn=handle_bg_change,
            inputs=[bg],
            outputs=[folder_path_box, gallery],
            show_progress=True
        )
        # When a gallery image is selected, load it into the Source slot.
        gallery.select(
            fn=on_gallery_select,
            outputs=[src, src_name_box]
        )

    return demo
|
| 180 |
+
|
| 181 |
+
if __name__ == "__main__":
    # Launch the Gradio app when run as a script.
    build_demo().launch()
|
baldhead.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# baldhead.py
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import tensorflow as tf
|
| 8 |
+
import gradio as gr
|
| 9 |
+
|
| 10 |
+
# Keras imports (note: keras-contrib must be installed)
|
| 11 |
+
import keras.backend as K
|
| 12 |
+
from keras.layers import (
|
| 13 |
+
Input,
|
| 14 |
+
Conv2D,
|
| 15 |
+
UpSampling2D,
|
| 16 |
+
LeakyReLU,
|
| 17 |
+
GlobalAveragePooling2D,
|
| 18 |
+
Dense,
|
| 19 |
+
Reshape,
|
| 20 |
+
Dropout,
|
| 21 |
+
Concatenate,
|
| 22 |
+
multiply, # ← Thêm import multiply
|
| 23 |
+
)
|
| 24 |
+
from keras.models import Model
|
| 25 |
+
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
|
| 26 |
+
|
| 27 |
+
# RetinaFace + skimage for face alignment
|
| 28 |
+
from retinaface import RetinaFace
|
| 29 |
+
from skimage import transform as trans
|
| 30 |
+
|
| 31 |
+
# Hugging Face Hub helper
|
| 32 |
+
from huggingface_hub import hf_hub_download
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# --- Face-alignment helpers (same as the original notebook code) ---
# Output size of each aligned face patch.
image_size = [256, 256]
# Canonical 5-point landmark template (eyes, nose tip, mouth corners),
# defined on a 112×112 reference face crop.
src_landmarks = np.array([
    [30.2946, 51.6963],
    [65.5318, 51.5014],
    [48.0252, 71.7366],
    [33.5493, 92.3655],
    [62.7299, 92.2041],
], dtype=np.float32)
# NOTE(review): x is shifted twice (+8 then +15, net +23) — looks like manual
# tuning carried over from the notebook; confirm both shifts are intended.
src_landmarks[:, 0] += 8.0
src_landmarks[:, 0] += 15.0
src_landmarks[:, 1] += 30.0
# Rescale the 112-based template up to a 200-pixel working region.
src_landmarks /= 112
src_landmarks *= 200
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def list2array(values):
    """Materialize an iterable (e.g. dict_values) into a NumPy array."""
    return np.asarray([item for item in values])
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def align_face(img: np.ndarray):
    """
    Detect faces + landmarks in `img` via RetinaFace.
    Returns lists of aligned face patches (256×256 RGB),
    corresponding binary masks, and the transformation matrices.

    Returns three empty lists when no face is detected.
    """
    # NOTE(review): assumes detect_faces returns a dict of face entries;
    # verify its behavior on images with zero detections.
    faces = RetinaFace.detect_faces(img)
    bboxes = np.array([list2array(faces[f]['facial_area']) for f in faces])
    landmarks = np.array([list2array(faces[f]['landmarks'].values()) for f in faces])

    # An all-white canvas warped with the same transform yields a mask of
    # where the aligned face patch maps back onto the original image.
    white_canvas = np.ones(img.shape, dtype=np.uint8) * 255
    aligned_faces, masks, matrices = [], [], []

    if bboxes.shape[0] > 0:
        for i in range(bboxes.shape[0]):
            dst = landmarks[i]  # detected landmarks
            tform = trans.SimilarityTransform()
            # Similarity transform mapping detected landmarks onto the
            # canonical template (module-level src_landmarks).
            tform.estimate(dst, src_landmarks)
            M = tform.params[0:2, :]  # affine 2×3 part for cv2.warpAffine

            warped_face = cv2.warpAffine(
                img, M, (image_size[1], image_size[0]), borderValue=0.0
            )
            warped_mask = cv2.warpAffine(
                white_canvas, M, (image_size[1], image_size[0]), borderValue=0.0
            )

            aligned_faces.append(warped_face)
            masks.append(warped_mask)
            # Keep the full 3×3 matrix so put_face_back() can invert it.
            matrices.append(tform.params[0:3, :])

    return aligned_faces, masks, matrices
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def put_face_back(
    orig_img: np.ndarray,
    processed_faces: list[np.ndarray],
    masks: list[np.ndarray],
    matrices: list[np.ndarray],
):
    """
    Warp each processed face back onto the original `orig_img`
    using the inverse of the transformation matrices.

    Returns a uint8 image of the same shape as `orig_img`.
    """
    result = orig_img.copy()
    h, w = orig_img.shape[:2]

    for i in range(len(processed_faces)):
        # Invert the 3×3 similarity matrix; keep the affine 2×3 part.
        invM = np.linalg.inv(matrices[i])[0:2]
        warped = cv2.warpAffine(processed_faces[i], invM, (w, h), borderValue=0.0)
        mask = cv2.warpAffine(masks[i], invM, (w, h), borderValue=0.0)
        # Masks were warped from a 255-white canvas; // 255 yields {0, 1}.
        # NOTE(review): interpolation produces intermediate values at the
        # edges; // 255 maps anything below 255 to 0, slightly shrinking
        # the mask border — confirm this is acceptable.
        binary_mask = (mask // 255).astype(np.uint8)

        # Composite: result = result * (1 - mask) + warped * mask
        result = result * (1 - binary_mask)
        result = result.astype(np.uint8)
        result = result + warped * binary_mask

    return result
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ----------------------------
|
| 118 |
+
# 2. GENERATOR ARCHITECTURE
|
| 119 |
+
# ----------------------------
|
| 120 |
+
|
| 121 |
+
def squeeze_excite_block(x, ratio=4):
    """
    Squeeze-and-Excitation block: channel-wise attention.
    `ratio` controls the bottleneck reduction of the excitation MLP.
    """
    residual = x
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    n_channels = residual.shape[channel_axis]

    # Squeeze: global spatial average per channel.
    attention = GlobalAveragePooling2D()(residual)
    attention = Reshape((1, 1, n_channels))(attention)
    # Excite: bottleneck MLP ending in a per-channel sigmoid gate.
    attention = Dense(n_channels // ratio, activation="relu",
                      kernel_initializer="he_normal", use_bias=False)(attention)
    attention = Dense(n_channels, activation="sigmoid",
                      kernel_initializer="he_normal", use_bias=False)(attention)
    return multiply([residual, attention])
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def conv2d(layer_input, filters, f_size=4, bn=True, se=False):
    """
    Downsampling block: strided Conv2D → LeakyReLU → (InstanceNorm) → (SE block).
    Halves the spatial resolution (strides=2).
    """
    out = Conv2D(filters, kernel_size=f_size, strides=2, padding="same")(layer_input)
    out = LeakyReLU(alpha=0.2)(out)
    if bn:
        out = InstanceNormalization()(out)
    if se:
        out = squeeze_excite_block(out)
    return out
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def atrous(layer_input, filters, f_size=4, bn=True):
    """
    Atrous (dilated) convolution block: parallel branches with dilation
    rates 2, 4 and 8, concatenated then activated.
    """
    branches = [
        Conv2D(filters, f_size, dilation_rate=rate, padding="same")(layer_input)
        for rate in [2, 4, 8]
    ]
    merged = Concatenate()(branches)
    merged = LeakyReLU(alpha=0.2)(merged)
    if bn:
        merged = InstanceNormalization()(merged)
    return merged
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
    """
    Upsampling block: UpSampling2D → Conv2D → (Dropout) → InstanceNorm
    → concatenate with the encoder skip connection.
    """
    up = UpSampling2D(size=2)(layer_input)
    up = Conv2D(filters, kernel_size=f_size, strides=1, padding="same", activation="relu")(up)
    if dropout_rate:
        up = Dropout(dropout_rate)(up)
    up = InstanceNormalization()(up)
    return Concatenate()([up, skip_input])
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def build_generator():
    """
    Reconstruct the generator architecture exactly as in the notebook,
    then return a Keras Model object.

    U-Net-style encoder/decoder: SE attention on the shallow encoder
    stages, an atrous (dilated) bottleneck, skip connections on the way
    up, and a tanh output in [-1, 1] for a 256×256×3 input.
    """
    d0 = Input(shape=(256, 256, 3))
    gf = 64  # base filter count, doubled at deeper stages

    # Downsampling (each conv2d halves the resolution)
    d1 = conv2d(d0, gf, bn=False, se=True)
    d2 = conv2d(d1, gf * 2, se=True)
    d3 = conv2d(d2, gf * 4, se=True)
    d4 = conv2d(d3, gf * 8)
    d5 = conv2d(d4, gf * 8)

    # Atrous block (bottleneck with enlarged receptive field)
    a1 = atrous(d5, gf * 8)

    # Upsampling with encoder skip connections
    u3 = deconv2d(a1, d4, gf * 8)
    u4 = deconv2d(u3, d3, gf * 4)
    u5 = deconv2d(u4, d2, gf * 2)
    u6 = deconv2d(u5, d1, gf)

    # Final upsample + conv back to 3 channels
    u7 = UpSampling2D(size=2)(u6)
    output_img = Conv2D(3, kernel_size=4, strides=1, padding="same", activation="tanh")(u7)

    model = Model(d0, output_img)
    return model
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# ----------------------------
# 3. LOAD MODEL WEIGHTS
# ----------------------------

HF_REPO_ID = "VanNguyen1214/baldhead"
HF_FILENAME = "model_G_5_170.hdf5"
# Use .get() so a missing env var does not raise KeyError at import time;
# hf_hub_download accepts token=None (sufficient for public repos).
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

def load_generator_from_hub():
    """
    Download the .hdf5 weights from HF Hub into cache,
    rebuild the generator, then load weights.

    Returns the ready-to-use Keras generator model.
    """
    local_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME, token=HF_TOKEN)
    gen = build_generator()
    gen.load_weights(local_path)
    return gen
|
| 227 |
+
|
| 228 |
+
# Load once at startup so every request reuses the same generator instance.
try:
    GENERATOR = load_generator_from_hub()
    print(f"[INFO] Loaded generator weights from {HF_REPO_ID}/{HF_FILENAME}")
except Exception as e:
    # Broad catch on purpose: a failed download/weight load must not crash
    # the app at import time; inference() degrades to a no-op passthrough.
    print("[ERROR] Could not load generator:", e)
    GENERATOR = None
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# ----------------------------
|
| 238 |
+
# 4. INFERENCE FUNCTION
|
| 239 |
+
# ----------------------------
|
| 240 |
+
|
| 241 |
+
def inference(image: Image.Image) -> Image.Image:
    """
    Gradio-compatible inference function:
    - Convert PIL→ numpy RGB
    - Align faces
    - For each face: normalize to [-1,1], run through generator, denormalize to uint8
    - Put processed faces back onto original image
    - Return full-image PIL

    Returns the input unchanged when the model is unavailable or no face
    is detected.
    """
    # Degrade gracefully if the generator failed to load at startup.
    if GENERATOR is None:
        return image

    orig = np.array(image.convert("RGB"))

    faces, masks, mats = align_face(orig)
    if len(faces) == 0:
        return image

    processed_faces = []
    for face in faces:
        face_input = face.astype(np.float32)
        face_input = (face_input / 127.5) - 1.0  # scale to [-1,1]
        face_input = np.expand_dims(face_input, axis=0)  # (1,256,256,3)

        pred = GENERATOR.predict(face_input)[0]  # (256,256,3) in [-1,1]
        # Denormalize back to uint8 pixel range.
        pred = ((pred + 1.0) * 127.5).astype(np.uint8)
        processed_faces.append(pred)

    output_np = put_face_back(orig, processed_faces, masks, mats)
    output_pil = Image.fromarray(output_np)

    return output_pil
|
bbox_utils.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
def get_bbox_from_alpha(rgba: Image.Image):
    """
    Bounding box (x1, y1, x2, y2) of the non-transparent pixels of an
    RGBA image, or None when the image is fully transparent.
    """
    alpha_channel = np.asarray(rgba)[..., 3]
    rows, cols = np.nonzero(alpha_channel > 0)
    if rows.size == 0:
        return None
    return cols.min(), rows.min(), cols.max(), rows.max()
|
| 13 |
+
|
| 14 |
+
def paste_with_alpha(bg: np.ndarray, src: np.ndarray, offset: tuple[int,int]) -> Image.Image:
    """
    Paste RGBA `src` onto `bg` at `offset` (x, y), copying only pixels
    whose alpha > 0. Works on a copy of `bg`; returns a PIL Image.
    """
    res = bg.copy()
    x, y = offset
    h, w = src.shape[:2]
    # Clamp the paste rectangle to the background bounds.
    x1, y1 = max(x,0), max(y,0)
    x2 = min(x+w, bg.shape[1])
    y2 = min(y+h, bg.shape[0])
    if x1>=x2 or y1>=y2:
        # Source lies entirely outside the background — nothing to paste.
        return Image.fromarray(res)
    # Matching crops of the source and destination regions.
    cs = src[y1-y:y2-y, x1-x:x2-x]
    cd = res[y1:y2, x1:x2]
    # Opaque-pixel mask from the source alpha channel.
    mask = cs[...,3]>0
    if cd.shape[2]==3:
        # RGB background: drop the alpha channel from the copied pixels.
        cd[mask] = cs[mask][..., :3]
    else:
        cd[mask] = cs[mask]
    res[y1:y2, x1:x2] = cd
    return Image.fromarray(res)
|
detect_face.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# model.py
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torchvision
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
# --- Cấu hình chung ---
|
| 14 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 15 |
+
HF_REPO = "VanNguyen1214/detect_faceshape" # repo của bạn trên HF Hub
|
| 16 |
+
HF_FILENAME = "best_model.pth" # file ở root của repo
|
| 17 |
+
LOCAL_CKPT = "models/best_model.pth" # sẽ lưu tại đây
|
| 18 |
+
CLASS_NAMES = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']
|
| 19 |
+
NUM_CLASSES = len(CLASS_NAMES)
|
| 20 |
+
|
| 21 |
+
# --- Transform cho ảnh trước inference ---
|
| 22 |
+
_TRANSFORM = transforms.Compose([
|
| 23 |
+
transforms.Resize((224, 224)),
|
| 24 |
+
transforms.ToTensor(),
|
| 25 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 26 |
+
std =[0.229, 0.224, 0.225]),
|
| 27 |
+
])
|
| 28 |
+
|
| 29 |
+
def _ensure_checkpoint() -> str:
    """
    Return a local checkpoint path.

    Reuses LOCAL_CKPT when it already exists; otherwise downloads
    best_model.pth from HF_REPO into ./models/. Exits the process when
    the download fails (the classifier is required at import time).
    """
    if os.path.exists(LOCAL_CKPT):
        return LOCAL_CKPT

    try:
        return hf_hub_download(
            repo_id=HF_REPO,
            filename=HF_FILENAME,
            local_dir="models",
        )
    except Exception as e:
        print(f"❌ Không tải được model từ HF Hub: {e}")
        sys.exit(1)
|
| 47 |
+
|
| 48 |
+
def _load_model(ckpt_path: str) -> torch.nn.Module:
    """
    Recreate the EfficientNet-B4 architecture, load the state_dict from
    `ckpt_path`, and return the model on DEVICE in eval mode.
    """
    # 1) Instantiate EfficientNet-B4 (no pretrained weights; ours load below).
    # NOTE(review): `pretrained=` is deprecated in newer torchvision in favor
    # of `weights=` — confirm the pinned torchvision version accepts it.
    model = torchvision.models.efficientnet_b4(pretrained=False)
    in_features = model.classifier[1].in_features
    # Replace the stock head with a NUM_CLASSES-way classifier.
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features, NUM_CLASSES)
    )

    # 2) Load the trained weights.
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state)

    # 3) Switch to evaluation mode.
    return model.to(DEVICE).eval()
|
| 66 |
+
|
| 67 |
+
# === Build model ngay khi import ===
|
| 68 |
+
_CKPT_PATH = _ensure_checkpoint()
|
| 69 |
+
_MODEL = _load_model(_CKPT_PATH)
|
| 70 |
+
|
| 71 |
+
def predict(image: Image.Image) -> dict:
    """
    Classify the face shape of a PIL image.

    Returns:
        {
            "predicted_class": str,                    # one of CLASS_NAMES
            "confidence": float,                       # probability of the winner
            "probabilities": { class_name: prob, ... } # full distribution
        }

    Fix: the original returned only "predicted_class" although this
    docstring promised the confidence and probability keys; they are
    restored here (backward compatible — existing callers keep working).
    """
    # Convert to RGB and apply the eval transform.
    img = image.convert("RGB")
    x = _TRANSFORM(img).unsqueeze(0).to(DEVICE)

    # Inference without gradient tracking.
    with torch.no_grad():
        logits = _MODEL(x)
        probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

    idx = int(probs.argmax())
    return {
        "predicted_class": CLASS_NAMES[idx],
        "confidence": float(probs[idx]),
        "probabilities": {name: float(p) for name, p in zip(CLASS_NAMES, probs)},
    }
|
| 93 |
+
|
example_wigs/Heart/HH02.png
ADDED
|
Git LFS Details
|
example_wigs/Heart/HH03.png
ADDED
|
Git LFS Details
|
example_wigs/Heart/Loire.png
ADDED
|
Git LFS Details
|
example_wigs/Heart/SantaRossa.png
ADDED
|
Git LFS Details
|
example_wigs/Heart/Tuscany.png
ADDED
|
Git LFS Details
|
example_wigs/Oblong/HH01.png
ADDED
|
Git LFS Details
|
example_wigs/Oblong/HH02.png
ADDED
|
Git LFS Details
|
example_wigs/Oblong/HH03.png
ADDED
|
Git LFS Details
|
example_wigs/Oblong/HH07.png
ADDED
|
Git LFS Details
|
example_wigs/Oblong/Loire.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Alsace.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Barossa.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Burgundy.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/HH01.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/HH02.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/HH03.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/HH07.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Loire.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Napa.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Piemonte.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Rhone.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/SantaRossa.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Sonoma.png
ADDED
|
Git LFS Details
|
example_wigs/Oval/Tuscany.png
ADDED
|
Git LFS Details
|
example_wigs/Round/Loire.png
ADDED
|
Git LFS Details
|
example_wigs/Round/Piemonte.png
ADDED
|
Git LFS Details
|
example_wigs/Round/Sonoma.png
ADDED
|
Git LFS Details
|
example_wigs/Round/Tuscany.png
ADDED
|
Git LFS Details
|
example_wigs/Square/HH03.png
ADDED
|
Git LFS Details
|
example_wigs/Square/Loire.png
ADDED
|
Git LFS Details
|
example_wigs/Square/Piemonte.png
ADDED
|
Git LFS Details
|
example_wigs/Square/Sonoma.png
ADDED
|
Git LFS Details
|
example_wigs/Square/Tuscany.png
ADDED
|
Git LFS Details
|
overlay.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
from PIL import Image
import mediapipe as mp

from baldhead import inference  # removes the hair from the background head ("balds" it)
from segmentation import extract_hair

# MediaPipe Face Detection: model_selection=1 selects the full-range model
# (better for faces further from the camera); detections scoring below 0.5
# confidence are discarded.
mp_fd = mp.solutions.face_detection.FaceDetection(model_selection=1,
                                                  min_detection_confidence=0.5)
+
|
| 12 |
+
def get_face_bbox(img: Image.Image) -> tuple[int,int,int,int] | None:
    """Detect the first face in *img* and return its pixel bounding box.

    Returns (x1, y1, x2, y2) clamped to the image bounds, or None when no
    face is detected.
    """
    arr = np.array(img.convert("RGB"))
    res = mp_fd.process(arr)
    if not res.detections:
        return None
    d = res.detections[0].location_data.relative_bounding_box
    h, w = arr.shape[:2]
    # MediaPipe returns *relative* coordinates that may fall slightly outside
    # [0, 1]; clamp so the returned box always lies within the image, which
    # keeps downstream array slicing safe.
    x1 = max(0, int(d.xmin * w))
    y1 = max(0, int(d.ymin * h))
    x2 = min(w, x1 + int(d.width * w))
    y2 = min(h, y1 + int(d.height * h))
    return x1, y1, x2, y2
|
| 24 |
+
|
| 25 |
+
def compute_scale(w_bg, h_bg, w_src, h_src) -> float:
    """Mean of the width and height ratios between background and source boxes."""
    width_ratio = w_bg / w_src
    height_ratio = h_bg / h_src
    return 0.5 * (width_ratio + height_ratio)
|
| 27 |
+
|
| 28 |
+
def compute_offset(bbox_bg, bbox_src, scale) -> tuple[int,int]:
    """Translation (dx, dy) that aligns the scaled source face centre with
    the background face centre."""
    bx1, by1, bx2, by2 = bbox_bg
    sx1, sy1, sx2, sy2 = bbox_src
    # Background bbox centre (integer arithmetic, matching // semantics).
    bg_centre_x = bx1 + (bx2 - bx1) // 2
    bg_centre_y = by1 + (by2 - by1) // 2
    # Source bbox centre, expressed in post-scaling pixel coordinates.
    src_centre_x = int((sx1 + (sx2 - sx1) // 2) * scale)
    src_centre_y = int((sy1 + (sy2 - sy1) // 2) * scale)
    return bg_centre_x - src_centre_x, bg_centre_y - src_centre_y
|
| 36 |
+
|
| 37 |
+
def paste_with_alpha(bg: np.ndarray, src: np.ndarray, offset: tuple[int,int]) -> Image.Image:
    """Paste RGBA *src* onto *bg* at *offset* using a hard alpha mask.

    Pixels with alpha == 0 are skipped; any alpha > 0 is treated as fully
    opaque (no blending of partial transparency). Parts of *src* that fall
    outside *bg* are clipped. *bg* may be RGB or RGBA.
    """
    res = bg.copy()
    x, y = offset
    h, w = src.shape[:2]
    # Intersection of the source rectangle with the background canvas.
    x1, y1 = max(x,0), max(y,0)
    x2 = min(x+w, bg.shape[1])
    y2 = min(y+h, bg.shape[0])
    if x1>=x2 or y1>=y2:
        # No overlap at all - return the untouched copy.
        return Image.fromarray(res)
    # Matching crops: source shifted into its own coordinate frame, and the
    # destination view (a view into res, so writes below land in res too).
    cs = src[y1-y:y2-y, x1-x:x2-x]
    cd = res[y1:y2, x1:x2]
    mask = cs[...,3] > 0
    if cd.shape[2] == 3:
        # RGB destination: drop the source alpha channel when copying.
        cd[mask] = cs[mask][...,:3]
    else:
        cd[mask] = cs[mask]
    res[y1:y2, x1:x2] = cd
    return Image.fromarray(res)
|
| 55 |
+
|
| 56 |
+
def overlay_source(background: Image.Image, source: Image.Image):
    """Overlay the hair from *source* onto a bald version of *background*.

    Returns a (result_image, message) pair. On failure result_image is None
    and the message describes the problem. The original returned a bare image
    on success while the error branches returned 2-tuples, which broke any
    caller that unpacked the result; the success path now returns a tuple too.
    """
    # 1) detect bboxes
    bbox_bg = get_face_bbox(background)
    bbox_src = get_face_bbox(source)
    if bbox_bg is None:
        return None, "❌ No face in background."
    if bbox_src is None:
        return None, "❌ No face in source."

    # 2) compute scale & resize source so the two faces match in size
    w_bg, h_bg = bbox_bg[2]-bbox_bg[0], bbox_bg[3]-bbox_bg[1]
    w_src, h_src = bbox_src[2]-bbox_src[0], bbox_src[3]-bbox_src[1]
    scale = compute_scale(w_bg, h_bg, w_src, h_src)
    src_scaled = source.resize(
        (int(source.width*scale), int(source.height*scale)),
        Image.Resampling.LANCZOS
    )

    # 3) compute offset that aligns the face centres
    offset = compute_offset(bbox_bg, bbox_src, scale)

    # 4) remove hair from the background head
    bg_bald = inference(background)

    # 5) extract hair-only layer from the scaled source
    hair_only = extract_hair(src_scaled)

    # 6) paste the hair onto the bald background
    result = paste_with_alpha(
        np.array(bg_bald.convert("RGBA")),
        np.array(hair_only),
        offset
    )
    return result, "✅ Overlay complete."
|
requirements.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu118  # pip index option (extra package index for CUDA wheels), not a package
|
| 2 |
+
# spaces # Dòng này không rõ ràng là một gói, có thể là ghi chú. Nếu không phải gói, hãy xóa đi.
|
| 3 |
+
huggingface_hub>=0.20.3
|
| 4 |
+
numpy==1.23.5
|
| 5 |
+
transformers==4.30.0
|
| 6 |
+
opencv-python-headless==4.7.0.72
|
| 7 |
+
onnx==1.14.0
|
| 8 |
+
insightface==0.7.3
|
| 9 |
+
psutil==5.9.5
|
| 10 |
+
tk==0.1.0 # Lưu ý: tk thường được bao gồm trong bản cài đặt Python chuẩn, không phải lúc nào cũng cần cài qua pip.
|
| 11 |
+
customtkinter==5.1.3
|
| 12 |
+
pillow==9.5.0
|
| 13 |
+
torch==2.0.1+cu118; sys_platform != 'darwin'
|
| 14 |
+
torch==2.0.1; sys_platform == 'darwin'
|
| 15 |
+
torchvision==0.15.2+cu118; sys_platform != 'darwin'
|
| 16 |
+
torchvision==0.15.2; sys_platform == 'darwin'
|
| 17 |
+
# onnxruntime==1.15.0; # Bỏ comment cho dòng này nếu bạn muốn cố định phiên bản cho mọi OS
|
| 18 |
+
# sys_platform == 'darwin' and platform_machine != 'arm64' # Comment
|
| 19 |
+
onnxruntime-silicon==1.13.1; sys_platform == 'darwin' and platform_machine == 'arm64'
|
| 20 |
+
onnxruntime-gpu==1.15.0; sys_platform != 'darwin' # Nên giữ lại dòng này cho non-darwin GPU
|
| 21 |
+
onnxruntime==1.15.0; sys_platform == 'darwin' and platform_machine != 'arm64' # Thêm lại dòng onnxruntime cho Mac Intel
|
| 22 |
+
tensorflow==2.12.0
|
| 23 |
+
# sys_platform != 'darwin' # Comment
|
| 24 |
+
opennsfw2==0.10.2
|
| 25 |
+
# protobuf==4.23.2 # Thay thế dòng này
|
| 26 |
+
protobuf==4.25.3 # *** THAY ĐỔI QUAN TRỌNG ***
|
| 27 |
+
tqdm==4.65.0
|
| 28 |
+
gfpgan==1.3.8
|
| 29 |
+
# torch # Dòng này không cần thiết vì torch đã được định nghĩa ở trên với phiên bản cụ thể.
|
| 30 |
+
|
| 31 |
+
# Thêm các thư viện mới cần thiết cho app.py đã cập nhật
|
| 32 |
+
scikit-image>=0.19 # Hoặc một phiên bản cụ thể hơn nếu bạn muốn, ví dụ: scikit-image==0.19.3
|
| 33 |
+
mediapipe==0.10.14 # *** THÊM MỚI HOẶC CẬP NHẬT *** (Phiên bản này yêu cầu protobuf >=4.25.3)
|
| 34 |
+
git+https://github.com/keras-team/keras-contrib.git
|
| 35 |
+
retina-face==0.0.13
|
roop/__init__.py
ADDED
|
File without changes
|
roop/capturer.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
import cv2
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_video_frame(video_path: str, frame_number: int = 0) -> Any:
    """Return a single decoded frame from *video_path*, or None on failure.

    *frame_number* is treated as 1-based (matching the original call sites).
    """
    capture = cv2.VideoCapture(video_path)
    frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    # Clamp to a valid zero-based frame index. The original expression
    # min(frame_total, frame_number - 1) could seek to -1 (for the default
    # frame_number=0) or to frame_total, one past the last valid frame.
    capture.set(cv2.CAP_PROP_POS_FRAMES, max(min(frame_number - 1, frame_total - 1), 0))
    has_frame, frame = capture.read()
    capture.release()
    if has_frame:
        return frame
    return None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_video_frame_total(video_path: str) -> int:
    """Number of frames OpenCV reports for the video at *video_path*."""
    capture = cv2.VideoCapture(video_path)
    frame_count = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    capture.release()
    return int(frame_count)
|
roop/core.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3

import os
import sys
# single thread doubles cuda performance - needs to be set before torch import
if any(arg.startswith('--execution-provider') for arg in sys.argv):
    os.environ['OMP_NUM_THREADS'] = '1'
# reduce tensorflow log level
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import warnings
from typing import List
import platform
import signal
import shutil
import argparse
import torch
import onnxruntime
import tensorflow

import roop.globals
import roop.metadata
import roop.ui as ui
from roop.predicter import predict_image, predict_video
from roop.processors.frame.core import get_frame_processors_modules
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path

# NOTE(review): `del torch` only removes this module's binding; the module
# remains imported and importable elsewhere. Also, execution_providers is
# normally empty at import time (it is filled in by parse_args), so this
# branch may never fire - confirm intended behavior.
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
    del torch

warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def parse_args() -> None:
    """Parse CLI arguments into the shared roop.globals configuration.

    Also installs a SIGINT handler so Ctrl+C cleans up temp files via destroy().
    """
    signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
    program = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100))
    program.add_argument('-s', '--source', help='select an source image', dest='source_path')
    program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
    program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
    program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
    program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False)
    # NOTE(review): store_true combined with default=True means this flag can
    # never be disabled from the command line - confirm that is intended.
    program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True)
    program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False)
    program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true', default=False)
    program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
    program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
    program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory())
    program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
    program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
    program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')

    args = program.parse_args()

    roop.globals.source_path = args.source_path
    roop.globals.target_path = args.target_path
    roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path)
    roop.globals.frame_processors = args.frame_processor
    # Headless when any path was passed on the CLI. Note this stores a truthy
    # string (or None) rather than a strict bool.
    roop.globals.headless = args.source_path or args.target_path or args.output_path
    roop.globals.keep_fps = args.keep_fps
    roop.globals.keep_audio = args.keep_audio
    roop.globals.keep_frames = args.keep_frames
    roop.globals.many_faces = args.many_faces
    roop.globals.video_encoder = args.video_encoder
    roop.globals.video_quality = args.video_quality
    roop.globals.max_memory = args.max_memory
    roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
    roop.globals.execution_threads = args.execution_threads
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
    """Map onnxruntime provider names to short lowercase aliases
    (e.g. 'CUDAExecutionProvider' -> 'cuda')."""
    encoded = []
    for provider in execution_providers:
        encoded.append(provider.replace('ExecutionProvider', '').lower())
    return encoded
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def decode_execution_providers(execution_providers: List[str]) -> List[str]:
    """Translate short aliases (e.g. 'cuda') back to the full onnxruntime
    provider names that are actually available on this machine."""
    # Query the available providers once instead of twice (the original
    # called onnxruntime.get_available_providers() in both zip arguments).
    available = onnxruntime.get_available_providers()
    return [provider for provider, encoded_execution_provider in zip(available, encode_execution_providers(available))
            if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def suggest_max_memory() -> int:
    """Default RAM cap in GB: a conservative 4 on macOS, 16 elsewhere."""
    return 4 if platform.system().lower() == 'darwin' else 16
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def suggest_execution_providers() -> List[str]:
    """Short lowercase aliases for every provider onnxruntime exposes here."""
    available = onnxruntime.get_available_providers()
    return encode_execution_providers(available)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def suggest_execution_threads() -> int:
    """Default worker-thread count: 1 for DirectML/ROCm providers, else 8."""
    single_thread_providers = ('DmlExecutionProvider', 'ROCMExecutionProvider')
    if any(p in roop.globals.execution_providers for p in single_thread_providers):
        return 1
    return 8
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def limit_resources() -> None:
    """Cap GPU memory (TensorFlow) and process memory for this run."""
    # prevent tensorflow memory leak: pin each GPU to a 1 GiB virtual device
    gpus = tensorflow.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tensorflow.config.experimental.set_virtual_device_configuration(gpu, [
            tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
        ])
    # limit memory usage
    if roop.globals.max_memory:
        memory = roop.globals.max_memory * 1024 ** 3
        if platform.system().lower() == 'darwin':
            # NOTE(review): 1024 ** 6 is inconsistent with the GB-to-bytes
            # conversion above (1024 ** 3); confirm the unit setrlimit expects
            # on macOS before changing it.
            memory = roop.globals.max_memory * 1024 ** 6
        if platform.system().lower() == 'windows':
            import ctypes
            kernel32 = ctypes.windll.kernel32
            kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
        else:
            import resource
            resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def release_resources() -> None:
    """Free cached CUDA memory when running on the CUDA provider."""
    using_cuda = 'CUDAExecutionProvider' in roop.globals.execution_providers
    if using_cuda:
        torch.cuda.empty_cache()
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def pre_check() -> bool:
    """Validate runtime prerequisites: Python >= 3.9 and ffmpeg on PATH."""
    if sys.version_info < (3, 9):
        update_status('Python version is not supported - please upgrade to 3.9 or higher.')
        return False
    if shutil.which('ffmpeg') is None:
        update_status('ffmpeg is not installed.')
        return False
    return True
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def update_status(message: str, scope: str = 'ROOP.CORE') -> None:
    """Print a scoped status line and mirror it to the UI unless headless."""
    print(f'[{scope}] {message}')
    if roop.globals.headless:
        return
    ui.update_status(message)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def start() -> None:
    """Run the full processing pipeline for the configured source/target.

    Image targets are processed in place on a copy of the target; video
    targets are exploded into frames, processed, re-encoded and (optionally)
    re-muxed with the original audio. NSFW targets abort via destroy().
    """
    for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
        if not frame_processor.pre_start():
            return
    # process image to image
    if has_image_extension(roop.globals.target_path):
        if predict_image(roop.globals.target_path):
            # NSFW gate: destroy() exits the process.
            destroy()
        shutil.copy2(roop.globals.target_path, roop.globals.output_path)
        for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
            for frame_processor_name in roop.globals.frame_processors:
                # NOTE(review): this compares against frame_processor.frame_name
                # but status uses frame_processor.NAME - confirm both attributes
                # exist on every processor module.
                if frame_processor_name == frame_processor.frame_name:
                    update_status('Progressing...', frame_processor.NAME)
                    frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
                    frame_processor.post_process()
        release_resources()
        if is_image(roop.globals.target_path):
            update_status('Processing to image succeed!')
        else:
            update_status('Processing to image failed!')
        return
    # process image to videos
    if predict_video(roop.globals.target_path):
        destroy()
    update_status('Creating temp resources...')
    create_temp(roop.globals.target_path)
    update_status('Extracting frames...')
    extract_frames(roop.globals.target_path)
    temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
    for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
        update_status('Progressing...', frame_processor.NAME)
        frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
        frame_processor.post_process()
        release_resources()
    # handles fps
    if roop.globals.keep_fps:
        update_status('Detecting fps...')
        fps = detect_fps(roop.globals.target_path)
        update_status(f'Creating video with {fps} fps...')
        create_video(roop.globals.target_path, fps)
    else:
        update_status('Creating video with 30.0 fps...')
        create_video(roop.globals.target_path)
    # handle audio
    if roop.globals.keep_audio:
        if roop.globals.keep_fps:
            update_status('Restoring audio...')
        else:
            update_status('Restoring audio might cause issues as fps are not kept...')
        restore_audio(roop.globals.target_path, roop.globals.output_path)
    else:
        move_temp(roop.globals.target_path, roop.globals.output_path)
    # clean and validate
    clean_temp(roop.globals.target_path)
    if is_video(roop.globals.target_path):
        update_status('Processing to video succeed!')
    else:
        update_status('Processing to video failed!')
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def destroy() -> None:
    """Remove temp artifacts (when a target is set) and exit the process."""
    target = roop.globals.target_path
    if target:
        clean_temp(target)
    quit()
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def run() -> None:
    """Entry point: parse CLI args, validate, then run headless or with the UI."""
    parse_args()
    if not pre_check():
        return
    for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
        if not frame_processor.pre_check():
            return
    limit_resources()
    if roop.globals.headless:
        start()
        return
    window = ui.init(start, destroy)
    window.mainloop()
|
roop/face_analyser.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
from typing import Any
import insightface
import numpy as np
from PIL import Image

import roop.globals
from roop.typing import Frame

# Lazily-created singleton analyser (see get_face_analyser); THREAD_LOCK
# keeps its one-time initialisation thread-safe.
FACE_ANALYSER = None
THREAD_LOCK = threading.Lock()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_face_analyser() -> Any:
    """Return the shared insightface FaceAnalysis instance, creating it once.

    Initialisation is guarded by THREAD_LOCK so concurrent callers cannot
    build two analysers; uses the 'buffalo_l' model pack with a 640x640
    detector and the globally configured execution providers.
    """
    global FACE_ANALYSER

    with THREAD_LOCK:
        if FACE_ANALYSER is None:
            FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers)
            FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640))
    return FACE_ANALYSER
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_precise_face_mask(frame: Frame) -> Any:
    """
    Get precise face mask using advanced segmentation (same as detect_face_and_forehead_no_hair).
    Returns both InsightFace detection and precise mask.

    Returns a dict with keys 'precise_mask' (segmentation mask or None),
    'insightface_faces' (list of detections) and 'has_face' (bool). Any
    failure in the precise path falls back to plain InsightFace detection.
    """
    try:
        # Import the precise detection function
        import sys
        import os
        # Make the repository root importable so the top-level `segmentation`
        # module can be found from inside the roop package.
        sys.path.append(os.path.dirname(os.path.dirname(__file__)))
        from segmentation import detect_face_and_forehead_no_hair

        # Convert frame to PIL Image
        if isinstance(frame, np.ndarray):
            pil_image = Image.fromarray(frame)
        else:
            pil_image = frame

        # Get precise face mask (clean skin only)
        precise_mask = detect_face_and_forehead_no_hair(pil_image)

        # Also get InsightFace detection for face swapping compatibility
        insightface_faces = get_face_analyser().get(frame)

        return {
            'precise_mask': precise_mask,
            'insightface_faces': insightface_faces,
            'has_face': precise_mask.sum() > 0 and len(insightface_faces) > 0
        }

    except Exception as e:
        # Broad catch is deliberate: the precise path is best-effort and any
        # failure (import error, segmentation error) degrades gracefully.
        print(f"Precise face detection failed: {e}")
        # Fallback to regular InsightFace
        insightface_faces = get_face_analyser().get(frame)
        return {
            'precise_mask': None,
            'insightface_faces': insightface_faces,
            'has_face': len(insightface_faces) > 0
        }
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_one_face(frame: Frame) -> Any:
    """
    Get one face with enhanced precision detection.

    Prefers the precise-segmentation path (attaching the mask to the
    InsightFace detection as `.precise_mask`); falls back to plain
    InsightFace on failure. Returns the leftmost detected face, or None.
    """
    # Get precise detection info
    face_info = get_precise_face_mask(frame)

    if face_info['has_face'] and face_info['insightface_faces']:
        try:
            # Select face (leftmost) for compatibility
            selected_face = min(face_info['insightface_faces'], key=lambda x: x.bbox[0])

            # Add precise mask info to face object
            if face_info['precise_mask'] is not None:
                selected_face.precise_mask = face_info['precise_mask']
                print(f"✅ Enhanced face detection: {face_info['precise_mask'].sum()} precise pixels")

            return selected_face
        except (ValueError, IndexError):
            return None

    # Fallback to original method
    face = get_face_analyser().get(frame)
    try:
        # min() raises ValueError on an empty detection list.
        selected_face = min(face, key=lambda x: x.bbox[0])
        return selected_face
    except ValueError:
        return None
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_many_faces(frame: Frame) -> Any:
    """
    Get many faces with enhanced precision detection.

    Returns the list of InsightFace detections (each annotated with the
    shared precise mask when one is available), or None when detection fails.
    """
    # Get precise detection info
    face_info = get_precise_face_mask(frame)

    if face_info['has_face'] and face_info['insightface_faces']:
        faces = face_info['insightface_faces']

        # Add precise mask info to all face objects
        if face_info['precise_mask'] is not None:
            for face in faces:
                face.precise_mask = face_info['precise_mask']

        print(f"✅ Enhanced multi-face detection: {len(faces)} faces with precise masks")
        return faces

    # Fallback to original method
    try:
        # NOTE(review): FaceAnalysis.get is not documented to raise
        # IndexError - confirm which call this guard was meant to protect.
        return get_face_analyser().get(frame)
    except IndexError:
        return None
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def has_precise_face_mask(face_obj) -> bool:
    """Return True when *face_obj* carries a non-None `precise_mask` attribute."""
    mask = getattr(face_obj, 'precise_mask', None)
    return mask is not None
|
roop/globals.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List

# Runtime configuration shared across roop modules; populated by
# roop.core.parse_args() before processing starts.
source_path = None  # path to the source (face) image
target_path = None  # path to the target image or video
output_path = None  # normalized output file or directory
frame_processors: List[str] = []  # e.g. ['face_swapper', 'face_enhancer']
keep_fps = None  # keep the target's original fps
keep_audio = None  # restore the target's audio track after processing
keep_frames = None  # keep extracted temp frames on disk
many_faces = None  # process every detected face, not just one
video_encoder = None  # ffmpeg encoder: libx264 / libx265 / libvpx-vp9
video_quality = None  # ffmpeg quality value in [0-51]
max_memory = None  # RAM cap in GB (see roop.core.limit_resources)
execution_providers: List[str] = []  # onnxruntime execution providers
execution_threads = None  # worker thread count
headless = None  # truthy when running without the UI
log_level = 'error'
|
roop/metadata.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Package identity, used for the CLI --version output in roop.core.
name = 'roop'
version = '1.1.0'
|
roop/predicter.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy
import opennsfw2
from PIL import Image

from roop.typing import Frame

# NSFW probability threshold: content scoring above this is rejected.
MAX_PROBABILITY = 0.85
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Cache the NSFW model so repeated frame checks do not rebuild it per call
# (the original constructed a fresh model on every invocation).
_NSFW_MODEL = None


def predict_frame(target_frame: Frame) -> bool:
    """Return True when *target_frame* scores above the NSFW threshold."""
    global _NSFW_MODEL
    if _NSFW_MODEL is None:
        _NSFW_MODEL = opennsfw2.make_open_nsfw_model()
    image = Image.fromarray(target_frame)
    image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)
    views = numpy.expand_dims(image, axis=0)
    # Model output is (sfw, nsfw) per view; take the NSFW probability.
    _, probability = _NSFW_MODEL.predict(views)[0]
    return probability > MAX_PROBABILITY
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def predict_image(target_path: str) -> bool:
    """True when the image at *target_path* scores above the NSFW threshold."""
    probability = opennsfw2.predict_image(target_path)
    return probability > MAX_PROBABILITY
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def predict_video(target_path: str) -> bool:
    """Sample every 100th frame; flag the video if any sample is NSFW."""
    _, probabilities = opennsfw2.predict_video_frames(video_path=target_path, frame_interval=100)
    for probability in probabilities:
        if probability > MAX_PROBABILITY:
            return True
    return False
|
roop/processors/__init__.py
ADDED
|
File without changes
|
roop/processors/frame/__init__.py
ADDED
|
File without changes
|