VanNguyen1214 committed on
Commit
8dff9a2
·
verified ·
1 Parent(s): dd1b6cd

Upload 58 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +33 -0
  2. README.md +12 -12
  3. app.py +182 -0
  4. baldhead.py +272 -0
  5. bbox_utils.py +31 -0
  6. detect_face.py +93 -0
  7. example_wigs/Heart/HH02.png +3 -0
  8. example_wigs/Heart/HH03.png +3 -0
  9. example_wigs/Heart/Loire.png +3 -0
  10. example_wigs/Heart/SantaRossa.png +3 -0
  11. example_wigs/Heart/Tuscany.png +3 -0
  12. example_wigs/Oblong/HH01.png +3 -0
  13. example_wigs/Oblong/HH02.png +3 -0
  14. example_wigs/Oblong/HH03.png +3 -0
  15. example_wigs/Oblong/HH07.png +3 -0
  16. example_wigs/Oblong/Loire.png +3 -0
  17. example_wigs/Oval/Alsace.png +3 -0
  18. example_wigs/Oval/Barossa.png +3 -0
  19. example_wigs/Oval/Burgundy.png +3 -0
  20. example_wigs/Oval/HH01.png +3 -0
  21. example_wigs/Oval/HH02.png +3 -0
  22. example_wigs/Oval/HH03.png +3 -0
  23. example_wigs/Oval/HH07.png +3 -0
  24. example_wigs/Oval/Loire.png +3 -0
  25. example_wigs/Oval/Napa.png +3 -0
  26. example_wigs/Oval/Piemonte.png +3 -0
  27. example_wigs/Oval/Rhone.png +3 -0
  28. example_wigs/Oval/SantaRossa.png +3 -0
  29. example_wigs/Oval/Sonoma.png +3 -0
  30. example_wigs/Oval/Tuscany.png +3 -0
  31. example_wigs/Round/Loire.png +3 -0
  32. example_wigs/Round/Piemonte.png +3 -0
  33. example_wigs/Round/Sonoma.png +3 -0
  34. example_wigs/Round/Tuscany.png +3 -0
  35. example_wigs/Square/HH03.png +3 -0
  36. example_wigs/Square/Loire.png +3 -0
  37. example_wigs/Square/Piemonte.png +3 -0
  38. example_wigs/Square/Sonoma.png +3 -0
  39. example_wigs/Square/Tuscany.png +3 -0
  40. overlay.py +89 -0
  41. requirements.txt +35 -0
  42. roop/__init__.py +0 -0
  43. roop/capturer.py +20 -0
  44. roop/core.py +217 -0
  45. roop/face_analyser.py +124 -0
  46. roop/globals.py +17 -0
  47. roop/metadata.py +2 -0
  48. roop/predicter.py +25 -0
  49. roop/processors/__init__.py +0 -0
  50. roop/processors/frame/__init__.py +0 -0
.gitattributes CHANGED
@@ -33,3 +33,36 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example_wigs/Heart/HH02.png filter=lfs diff=lfs merge=lfs -text
37
+ example_wigs/Heart/HH03.png filter=lfs diff=lfs merge=lfs -text
38
+ example_wigs/Heart/Loire.png filter=lfs diff=lfs merge=lfs -text
39
+ example_wigs/Heart/SantaRossa.png filter=lfs diff=lfs merge=lfs -text
40
+ example_wigs/Heart/Tuscany.png filter=lfs diff=lfs merge=lfs -text
41
+ example_wigs/Oblong/HH01.png filter=lfs diff=lfs merge=lfs -text
42
+ example_wigs/Oblong/HH02.png filter=lfs diff=lfs merge=lfs -text
43
+ example_wigs/Oblong/HH03.png filter=lfs diff=lfs merge=lfs -text
44
+ example_wigs/Oblong/HH07.png filter=lfs diff=lfs merge=lfs -text
45
+ example_wigs/Oblong/Loire.png filter=lfs diff=lfs merge=lfs -text
46
+ example_wigs/Oval/Alsace.png filter=lfs diff=lfs merge=lfs -text
47
+ example_wigs/Oval/Barossa.png filter=lfs diff=lfs merge=lfs -text
48
+ example_wigs/Oval/Burgundy.png filter=lfs diff=lfs merge=lfs -text
49
+ example_wigs/Oval/HH01.png filter=lfs diff=lfs merge=lfs -text
50
+ example_wigs/Oval/HH02.png filter=lfs diff=lfs merge=lfs -text
51
+ example_wigs/Oval/HH03.png filter=lfs diff=lfs merge=lfs -text
52
+ example_wigs/Oval/HH07.png filter=lfs diff=lfs merge=lfs -text
53
+ example_wigs/Oval/Loire.png filter=lfs diff=lfs merge=lfs -text
54
+ example_wigs/Oval/Napa.png filter=lfs diff=lfs merge=lfs -text
55
+ example_wigs/Oval/Piemonte.png filter=lfs diff=lfs merge=lfs -text
56
+ example_wigs/Oval/Rhone.png filter=lfs diff=lfs merge=lfs -text
57
+ example_wigs/Oval/SantaRossa.png filter=lfs diff=lfs merge=lfs -text
58
+ example_wigs/Oval/Sonoma.png filter=lfs diff=lfs merge=lfs -text
59
+ example_wigs/Oval/Tuscany.png filter=lfs diff=lfs merge=lfs -text
60
+ example_wigs/Round/Loire.png filter=lfs diff=lfs merge=lfs -text
61
+ example_wigs/Round/Piemonte.png filter=lfs diff=lfs merge=lfs -text
62
+ example_wigs/Round/Sonoma.png filter=lfs diff=lfs merge=lfs -text
63
+ example_wigs/Round/Tuscany.png filter=lfs diff=lfs merge=lfs -text
64
+ example_wigs/Square/HH03.png filter=lfs diff=lfs merge=lfs -text
65
+ example_wigs/Square/Loire.png filter=lfs diff=lfs merge=lfs -text
66
+ example_wigs/Square/Piemonte.png filter=lfs diff=lfs merge=lfs -text
67
+ example_wigs/Square/Sonoma.png filter=lfs diff=lfs merge=lfs -text
68
+ example_wigs/Square/Tuscany.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,12 @@
1
- ---
2
- title: Real Finals
3
- emoji: 📊
4
- colorFrom: yellow
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.34.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Ghep Image
3
+ emoji: 📉
4
+ colorFrom: pink
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.31.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ from overlay import overlay_source
4
+ from detect_face import predict, NUM_CLASSES
5
+ from swapface import swap_face_now
6
+ import os
7
+ from pathlib import Path
8
+
9
+ BASE_DIR = Path(__file__).parent # thư mục chứa app.py
10
+ FOLDER = BASE_DIR / "example_wigs"
11
+
12
+ # --- Hàm load ảnh từ folder ---
13
def load_images_from_folder(folder_path: str) -> list[str]:
    """Return the paths of all images (jpg/jpeg/png/gif/bmp) in folder_path.

    The list is sorted so the gallery order is deterministic —
    os.listdir returns entries in arbitrary, filesystem-dependent order.
    Returns [] (with a console warning) when folder_path is not a directory
    or contains no supported images.
    """
    supported = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
    if not os.path.isdir(folder_path):
        print(f"Cảnh báo: '{folder_path}' không phải folder hợp lệ.")
        return []
    files = sorted(
        os.path.join(folder_path, fn)
        for fn in os.listdir(folder_path)
        if os.path.splitext(fn)[1].lower() in supported
    )
    if not files:
        print(f"Không tìm thấy hình trong: {folder_path}")
    return files
29
+
30
def on_gallery_select(evt: gr.SelectData):
    """
    On thumbnail click, return:
    1) the filepath to load into the Source image
    2) the file's basename to show in the Textbox
    """
    val = evt.value

    # Extract a filepath from whichever payload shape Gradio delivered.
    if isinstance(val, dict):
        img = val.get("image")
        if isinstance(img, str):
            filepath = img
        elif isinstance(img, dict):
            filepath = img.get("path") or img.get("url")
        else:
            # Fall back to any string value that points at an existing file.
            filepath = next(
                (v for v in val.values() if isinstance(v, str) and os.path.isfile(v)),
                None
            )
    elif isinstance(val, str):
        filepath = val
    else:
        raise ValueError(f"Kiểu không hỗ trợ: {type(val)}")

    return filepath, os.path.basename(filepath) if filepath else ""
57
+
58
+ # --- Hàm xác định folder dựa trên phân lớp ---
59
def infer_folder(image) -> str:
    """Classify the face shape of `image` and return the matching wig folder path."""
    predicted = predict(image)["predicted_class"]
    return str(FOLDER / predicted)
63
+
64
+ # --- Hàm gộp: phân loại + load ảnh ---
65
def handle_bg_change(image):
    """
    When the background image changes:
    1. Classify the face shape
    2. Load suggestion images from the matching folder
    Returns (folder_path, image_paths); ("", []) on missing input or error.
    """
    if image is None:
        return "", []
    try:
        folder = infer_folder(image)
        return folder, load_images_from_folder(folder)
    except Exception as e:
        print(f"Lỗi xử lý ảnh: {e}")
        return "", []
81
+
82
+ # --- Hàm swap face ---
83
def swap_face_wrapper(background_img, result_img):
    """
    Swap the face from background_img onto result_img.
    Returns None when either input is missing; on any swap error the
    unmodified result_img is returned instead of failing.
    """
    if background_img is None or result_img is None:
        return None
    try:
        return swap_face_now(background_img, result_img, do_enhance=True)
    except Exception as e:
        print(f"Lỗi swap face: {e}")
        return result_img
97
+
98
+ # --- Hàm gộp overlay + swap face ---
99
def combined_hair_and_face(background_img, source_img):
    """
    Combined pipeline: overlay the source hair first, then swap the
    background face onto that result. Returns None on missing input or error.
    """
    if background_img is None or source_img is None:
        return None
    try:
        # Step 1: hair overlay
        with_hair = overlay_source(background_img, source_img)
        # Step 2: put the background face onto the overlay result
        return swap_face_wrapper(background_img, with_hair)
    except Exception as e:
        print(f"Lỗi trong quá trình gộp hair + face: {e}")
        return None
117
+
118
# --- Build the Gradio interface ---
def build_demo():
    """Assemble and return the Gradio Blocks UI for hair try-on and face swap."""
    with gr.Blocks(title="Hair Try-On & Face Swap", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎯 Hair Try-On & Face Swap Application
        """)
        with gr.Row():
            # bg: user photo; src: selected wig (set via gallery); out: result.
            bg = gr.Image(type="pil", label="Background", height=500)
            src = gr.Image(type="pil", label="Source", height=500, interactive=False)
            out = gr.Image(label="Result", height=500, interactive=False)

        # Hidden textbox that carries the folder chosen by the classifier.
        folder_path_box = gr.Textbox(label="Folder path", visible=False)


        with gr.Row():
            src_name_box = gr.Textbox(
                label="Wigs Name",
                interactive=False,
                show_copy_button=True,  # optional - handy for copying the name
                scale=1
            )
            gallery = gr.Gallery(
                label="Recommend For You",
                height=300,
                value=[],
                type="filepath",
                interactive=False,
                columns=5,
                object_fit="cover",
                allow_preview=True,
                scale=8
            )
        with gr.Column(scale=1):
            combined_btn = gr.Button("🔄✨ Run Hair + Face Swap", variant="primary")
            btn = gr.Button("🔄 Run Hair Only", variant="secondary")
            swap_btn = gr.Button("👤 Swap Face Only", variant="secondary")



        # Run combined hair overlay + face swap
        combined_btn.click(fn=combined_hair_and_face, inputs=[bg, src], outputs=[out])

        # Run hair overlay only
        btn.click(fn=overlay_source, inputs=[bg, src], outputs=[out])

        # Run face swap only (background face onto the current result)
        swap_btn.click(fn=swap_face_wrapper, inputs=[bg, out], outputs=[out])

        # When the background changes, auto-classify and load suggested wigs
        bg.change(
            fn=handle_bg_change,
            inputs=[bg],
            outputs=[folder_path_box, gallery],
            show_progress=True
        )
        # When a gallery thumbnail is selected, load it into the Source slot
        gallery.select(
            fn=on_gallery_select,
            outputs=[src, src_name_box]
        )

    return demo
180
+
181
if __name__ == "__main__":
    # Launch the Gradio app when run as a script.
    build_demo().launch()
baldhead.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # baldhead.py
2
+
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ import tensorflow as tf
8
+ import gradio as gr
9
+
10
+ # Keras imports (note: keras-contrib must be installed)
11
+ import keras.backend as K
12
+ from keras.layers import (
13
+ Input,
14
+ Conv2D,
15
+ UpSampling2D,
16
+ LeakyReLU,
17
+ GlobalAveragePooling2D,
18
+ Dense,
19
+ Reshape,
20
+ Dropout,
21
+ Concatenate,
22
+ multiply, # ← Thêm import multiply
23
+ )
24
+ from keras.models import Model
25
+ from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
26
+
27
+ # RetinaFace + skimage for face alignment
28
+ from retinaface import RetinaFace
29
+ from skimage import transform as trans
30
+
31
+ # Hugging Face Hub helper
32
+ from huggingface_hub import hf_hub_download
33
+
34
+
35
+
36
# --- Face-alignment helpers (same as the original code) ---
# Output size (H, W) of an aligned face crop.
image_size = [256, 256]
# 5-point landmark template (eyes, nose tip, mouth corners) in the
# 112x112 ArcFace-style coordinate frame.
src_landmarks = np.array([
    [30.2946, 51.6963],
    [65.5318, 51.5014],
    [48.0252, 71.7366],
    [33.5493, 92.3655],
    [62.7299, 92.2041],
], dtype=np.float32)
# NOTE(review): x is shifted twice (+8 then +15, i.e. +23 total) while y gets
# +30 once — presumably an empirically tuned offset; confirm it is intentional.
src_landmarks[:, 0] += 8.0
src_landmarks[:, 0] += 15.0
src_landmarks[:, 1] += 30.0
# Rescale the 112-pixel template to a 200-pixel face inside the 256x256 crop.
src_landmarks /= 112
src_landmarks *= 200
50
+
51
+
52
def list2array(values):
    """Materialize any iterable of values into a NumPy array."""
    return np.array([*values])
54
+
55
+
56
def align_face(img: np.ndarray):
    """
    Detect faces + landmarks in `img` via RetinaFace.
    Returns lists of aligned face patches (256×256 RGB),
    corresponding binary masks, and the transformation matrices.
    """
    faces = RetinaFace.detect_faces(img)
    # Per detected face: facial_area -> bbox, landmarks -> 5 (x, y) points.
    bboxes = np.array([list2array(faces[f]['facial_area']) for f in faces])
    landmarks = np.array([list2array(faces[f]['landmarks'].values()) for f in faces])

    # Warping this all-white canvas with the same transform yields the face mask.
    white_canvas = np.ones(img.shape, dtype=np.uint8) * 255
    aligned_faces, masks, matrices = [], [], []

    if bboxes.shape[0] > 0:
        for i in range(bboxes.shape[0]):
            dst = landmarks[i]  # detected landmarks
            # Similarity transform mapping detected landmarks onto the template.
            tform = trans.SimilarityTransform()
            tform.estimate(dst, src_landmarks)
            M = tform.params[0:2, :]  # 2x3 affine part for warpAffine

            warped_face = cv2.warpAffine(
                img, M, (image_size[1], image_size[0]), borderValue=0.0
            )
            warped_mask = cv2.warpAffine(
                white_canvas, M, (image_size[1], image_size[0]), borderValue=0.0
            )

            aligned_faces.append(warped_face)
            masks.append(warped_mask)
            # Keep the full 3x3 matrix so put_face_back can invert it.
            matrices.append(tform.params[0:3, :])

    return aligned_faces, masks, matrices
88
+
89
+
90
def put_face_back(
    orig_img: np.ndarray,
    processed_faces: list[np.ndarray],
    masks: list[np.ndarray],
    matrices: list[np.ndarray],
):
    """
    Warp each processed face back onto the original `orig_img`
    using the inverse of the transformation matrices.
    """
    result = orig_img.copy()
    h, w = orig_img.shape[:2]

    for i in range(len(processed_faces)):
        # Invert the 3x3 alignment matrix; keep only the 2x3 affine part.
        invM = np.linalg.inv(matrices[i])[0:2]
        warped = cv2.warpAffine(processed_faces[i], invM, (w, h), borderValue=0.0)
        mask = cv2.warpAffine(masks[i], invM, (w, h), borderValue=0.0)
        # The mask was warped from a white (255) canvas -> integer-divide to 0/1.
        binary_mask = (mask // 255).astype(np.uint8)

        # Composite: result = result * (1 - mask) + warped * mask
        result = result * (1 - binary_mask)
        result = result.astype(np.uint8)
        result = result + warped * binary_mask

    return result
115
+
116
+
117
+ # ----------------------------
118
+ # 2. GENERATOR ARCHITECTURE
119
+ # ----------------------------
120
+
121
def squeeze_excite_block(x, ratio=4):
    """
    Squeeze-and-Excitation block: learns per-channel attention weights
    and rescales the input feature map by them.
    """
    shortcut = x
    channel_axis = -1 if K.image_data_format() != "channels_first" else 1
    filters = shortcut.shape[channel_axis]

    # Squeeze: global pooling -> (1, 1, C); Excite: bottleneck MLP -> sigmoid gates.
    attention = GlobalAveragePooling2D()(shortcut)
    attention = Reshape((1, 1, filters))(attention)
    attention = Dense(filters // ratio, activation="relu",
                      kernel_initializer="he_normal", use_bias=False)(attention)
    attention = Dense(filters, activation="sigmoid",
                      kernel_initializer="he_normal", use_bias=False)(attention)
    return multiply([shortcut, attention])
135
+
136
+
137
def conv2d(layer_input, filters, f_size=4, bn=True, se=False):
    """
    Downsampling block: strided Conv2D -> LeakyReLU, optionally followed
    by InstanceNorm (bn) and a squeeze-excite block (se).
    """
    out = Conv2D(filters, kernel_size=f_size, strides=2, padding="same")(layer_input)
    out = LeakyReLU(alpha=0.2)(out)
    if bn:
        out = InstanceNormalization()(out)
    return squeeze_excite_block(out) if se else out
148
+
149
+
150
def atrous(layer_input, filters, f_size=4, bn=True):
    """
    Atrous (dilated) convolution block: parallel branches with dilation
    rates 2, 4 and 8, concatenated then activated.
    """
    branches = [
        Conv2D(filters, f_size, dilation_rate=rate, padding="same")(layer_input)
        for rate in (2, 4, 8)
    ]
    merged = Concatenate()(branches)
    merged = LeakyReLU(alpha=0.2)(merged)
    if bn:
        merged = InstanceNormalization()(merged)
    return merged
163
+
164
+
165
def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
    """
    Upsampling block: UpSampling2D -> Conv2D(relu) -> (Dropout) ->
    InstanceNorm -> concatenation with the encoder skip connection.
    """
    up = Conv2D(filters, kernel_size=f_size, strides=1, padding="same",
                activation="relu")(UpSampling2D(size=2)(layer_input))
    if dropout_rate:
        up = Dropout(dropout_rate)(up)
    up = InstanceNormalization()(up)
    return Concatenate()([up, skip_input])
176
+
177
+
178
def build_generator():
    """
    Reconstruct the generator architecture exactly as in the notebook,
    then return a Keras Model object.

    NOTE: the layer order and configuration must stay identical to the
    network that produced the saved .hdf5 checkpoint, or load_weights()
    will fail / silently mis-map weights.
    """
    d0 = Input(shape=(256, 256, 3))
    gf = 64  # base filter count

    # Downsampling (encoder); SE attention on the first three stages.
    d1 = conv2d(d0, gf, bn=False, se=True)
    d2 = conv2d(d1, gf * 2, se=True)
    d3 = conv2d(d2, gf * 4, se=True)
    d4 = conv2d(d3, gf * 8)
    d5 = conv2d(d4, gf * 8)

    # Dilated bottleneck
    a1 = atrous(d5, gf * 8)

    # Upsampling (decoder) with U-Net style skip connections
    u3 = deconv2d(a1, d4, gf * 8)
    u4 = deconv2d(u3, d3, gf * 4)
    u5 = deconv2d(u4, d2, gf * 2)
    u6 = deconv2d(u5, d1, gf)

    # Final upsample back to 256x256; tanh keeps outputs in [-1, 1]
    u7 = UpSampling2D(size=2)(u6)
    output_img = Conv2D(3, kernel_size=4, strides=1, padding="same", activation="tanh")(u7)

    model = Model(d0, output_img)
    return model
208
+
209
+
210
# ----------------------------
# 3. LOAD MODEL WEIGHTS
# ----------------------------

HF_REPO_ID = "VanNguyen1214/baldhead"
HF_FILENAME = "model_G_5_170.hdf5"
# Use .get() so a missing env var does not raise KeyError at import time and
# crash the whole app; hf_hub_download(token=None) still works for public repos.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

def load_generator_from_hub():
    """
    Download the .hdf5 weights from HF Hub into cache,
    rebuild the generator, then load weights.
    """
    local_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME, token=HF_TOKEN)
    gen = build_generator()
    gen.load_weights(local_path)
    return gen
227
+
228
# Load the generator once at import time; on failure fall back to None so
# inference() degrades to a no-op instead of crashing the app.
try:
    GENERATOR = load_generator_from_hub()
    print(f"[INFO] Loaded generator weights from {HF_REPO_ID}/{HF_FILENAME}")
except Exception as e:
    print("[ERROR] Could not load generator:", e)
    GENERATOR = None  # inference() returns its input unchanged when None
235
+
236
+
237
+ # ----------------------------
238
+ # 4. INFERENCE FUNCTION
239
+ # ----------------------------
240
+
241
def inference(image: Image.Image) -> Image.Image:
    """
    Gradio-compatible inference function:
    - Convert PIL→ numpy RGB
    - Align faces
    - For each face: normalize to [-1,1], run through generator, denormalize to uint8
    - Put processed faces back onto original image
    - Return full-image PIL
    """
    # Graceful no-op when the weights failed to load at startup.
    if GENERATOR is None:
        return image

    orig = np.array(image.convert("RGB"))

    faces, masks, mats = align_face(orig)
    if len(faces) == 0:
        # No detectable face: return the input untouched.
        return image

    processed_faces = []
    for face in faces:
        face_input = face.astype(np.float32)
        face_input = (face_input / 127.5) - 1.0        # scale to [-1,1]
        face_input = np.expand_dims(face_input, axis=0)  # (1,256,256,3)

        pred = GENERATOR.predict(face_input)[0]        # (256,256,3) in [-1,1]
        pred = ((pred + 1.0) * 127.5).astype(np.uint8)
        processed_faces.append(pred)

    # Warp the processed faces back into original-image coordinates.
    output_np = put_face_back(orig, processed_faces, masks, mats)
    output_pil = Image.fromarray(output_np)

    return output_pil
bbox_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+
4
def get_bbox_from_alpha(rgba: Image.Image):
    """Return (x1, y1, x2, y2) bounding the non-transparent pixels, or None if fully transparent."""
    alpha_channel = np.array(rgba)[..., 3]
    rows, cols = np.where(alpha_channel > 0)
    if rows.size == 0:
        return None
    return cols.min(), rows.min(), cols.max(), rows.max()
13
+
14
def paste_with_alpha(bg: np.ndarray, src: np.ndarray, offset: tuple[int,int]) -> Image.Image:
    """Paste RGBA `src` onto `bg` at `offset` (x, y), using src's alpha as a hard mask.

    Alpha is treated as binary (any alpha > 0 is fully opaque — no blending).
    The paste region is clipped to bg's bounds. Returns the result as a PIL
    Image. NOTE: a duplicate of overlay.paste_with_alpha — keep them in sync.
    """
    res = bg.copy()
    x, y = offset
    h, w = src.shape[:2]
    # Clip the destination rectangle to the background bounds.
    x1, y1 = max(x,0), max(y,0)
    x2 = min(x+w, bg.shape[1])
    y2 = min(y+h, bg.shape[0])
    if x1>=x2 or y1>=y2:
        # Source lies entirely outside the background: nothing to paste.
        return Image.fromarray(res)
    # Matching crops of source and destination.
    cs = src[y1-y:y2-y, x1-x:x2-x]
    cd = res[y1:y2, x1:x2]
    mask = cs[...,3]>0
    if cd.shape[2]==3:
        # RGB background: drop the source alpha channel.
        cd[mask] = cs[mask][..., :3]
    else:
        cd[mask] = cs[mask]
    res[y1:y2, x1:x2] = cd
    return Image.fromarray(res)
detect_face.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model.py
2
+
3
+ import os
4
+ import sys
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision
8
+ from torchvision import transforms
9
+ from huggingface_hub import hf_hub_download
10
+ from PIL import Image
11
+ import numpy as np
12
+
13
+ # --- Cấu hình chung ---
14
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
+ HF_REPO = "VanNguyen1214/detect_faceshape" # repo của bạn trên HF Hub
16
+ HF_FILENAME = "best_model.pth" # file ở root của repo
17
+ LOCAL_CKPT = "models/best_model.pth" # sẽ lưu tại đây
18
+ CLASS_NAMES = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']
19
+ NUM_CLASSES = len(CLASS_NAMES)
20
+
21
+ # --- Transform cho ảnh trước inference ---
22
+ _TRANSFORM = transforms.Compose([
23
+ transforms.Resize((224, 224)),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
26
+ std =[0.229, 0.224, 0.225]),
27
+ ])
28
+
29
def _ensure_checkpoint() -> str:
    """
    Return a local path to best_model.pth, downloading it from HF_REPO
    into ./models/ when not already present. Exits the process on failure.
    """
    if os.path.exists(LOCAL_CKPT):
        return LOCAL_CKPT
    try:
        return hf_hub_download(
            repo_id=HF_REPO,
            filename=HF_FILENAME,
            local_dir="models",
        )
    except Exception as e:
        print(f"❌ Không tải được model từ HF Hub: {e}")
        sys.exit(1)
47
+
48
def _load_model(ckpt_path: str) -> torch.nn.Module:
    """
    Rebuild the EfficientNet-B4 architecture, load the fine-tuned
    state_dict from ckpt_path, and return the model on DEVICE in eval mode.
    """
    # 1) EfficientNet-B4 backbone without ImageNet weights.
    #    `weights=None` is the non-deprecated spelling of pretrained=False
    #    in torchvision >= 0.13 (this project pins 0.15.2).
    model = torchvision.models.efficientnet_b4(weights=None)
    in_features = model.classifier[1].in_features
    # 2) Replace the head so it matches the fine-tuned checkpoint's shape.
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features, NUM_CLASSES)
    )

    # 3) Load the trained weights.
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state)

    # 4) eval(): disables dropout and freezes norm statistics for inference.
    return model.to(DEVICE).eval()
66
+
67
+ # === Build model ngay khi import ===
68
+ _CKPT_PATH = _ensure_checkpoint()
69
+ _MODEL = _load_model(_CKPT_PATH)
70
+
71
def predict(image: Image.Image) -> dict:
    """
    Classify the face shape of a PIL image.

    Returns (now matching the long-documented contract — previously only
    "predicted_class" was returned despite the docstring):
        {
            "predicted_class": str,
            "confidence": float,
            "probabilities": { class_name: prob, ... }
        }
    """
    # Normalize to RGB + apply the inference transform
    img = image.convert("RGB")
    x = _TRANSFORM(img).unsqueeze(0).to(DEVICE)

    # Inference
    with torch.no_grad():
        logits = _MODEL(x)
        probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

    idx = int(probs.argmax())
    return {
        "predicted_class": CLASS_NAMES[idx],
        "confidence": float(probs[idx]),
        "probabilities": {name: float(p) for name, p in zip(CLASS_NAMES, probs)},
    }
93
+
example_wigs/Heart/HH02.png ADDED

Git LFS Details

  • SHA256: 357555727e476770a7e53ee10711ad8f795caedfdcb90adb5083bf077439c63e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.49 MB
example_wigs/Heart/HH03.png ADDED

Git LFS Details

  • SHA256: a5ba9ef2d6fe37480923fbbd93a7bdf6fdb0590ed5c93f8e741163be31bc26eb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
example_wigs/Heart/Loire.png ADDED

Git LFS Details

  • SHA256: dc8864c7d5dd20de52ac6f5c8e1ddf236f4fda8278d63dae347306b0f33fb02a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
example_wigs/Heart/SantaRossa.png ADDED

Git LFS Details

  • SHA256: e70fffdbe0a0b61b267f483ea35467a0108d5b961e86df7d293459a3944c93c4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.2 MB
example_wigs/Heart/Tuscany.png ADDED

Git LFS Details

  • SHA256: 35ebf617bbbab34b05d019042f5ab8e9eb90cd6a186957a96df7c3793c142a9e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.99 MB
example_wigs/Oblong/HH01.png ADDED

Git LFS Details

  • SHA256: bdf028002be35de79da4067264cce2627b5739b7f356ece65c703f1878e83537
  • Pointer size: 132 Bytes
  • Size of remote file: 2.44 MB
example_wigs/Oblong/HH02.png ADDED

Git LFS Details

  • SHA256: 357555727e476770a7e53ee10711ad8f795caedfdcb90adb5083bf077439c63e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.49 MB
example_wigs/Oblong/HH03.png ADDED

Git LFS Details

  • SHA256: a5ba9ef2d6fe37480923fbbd93a7bdf6fdb0590ed5c93f8e741163be31bc26eb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
example_wigs/Oblong/HH07.png ADDED

Git LFS Details

  • SHA256: 1205c879380091b4fe13bdc29b070511f745b7365be956d627dc7b94c115118e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.72 MB
example_wigs/Oblong/Loire.png ADDED

Git LFS Details

  • SHA256: dc8864c7d5dd20de52ac6f5c8e1ddf236f4fda8278d63dae347306b0f33fb02a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
example_wigs/Oval/Alsace.png ADDED

Git LFS Details

  • SHA256: 83767c820759344c15bed941abd94a7f5e7fe8cb462a5ae2d1e289265269d5c7
  • Pointer size: 132 Bytes
  • Size of remote file: 2.28 MB
example_wigs/Oval/Barossa.png ADDED

Git LFS Details

  • SHA256: bf9f6e9abbc352390d1826f186dd08f3536eaba60d96131b81bab49468f202e8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.44 MB
example_wigs/Oval/Burgundy.png ADDED

Git LFS Details

  • SHA256: b48e47a7e1244efe2ed472fb212c39b1f646fc2e726f1a314d7b5cff475a2755
  • Pointer size: 132 Bytes
  • Size of remote file: 2.69 MB
example_wigs/Oval/HH01.png ADDED

Git LFS Details

  • SHA256: bdf028002be35de79da4067264cce2627b5739b7f356ece65c703f1878e83537
  • Pointer size: 132 Bytes
  • Size of remote file: 2.44 MB
example_wigs/Oval/HH02.png ADDED

Git LFS Details

  • SHA256: 357555727e476770a7e53ee10711ad8f795caedfdcb90adb5083bf077439c63e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.49 MB
example_wigs/Oval/HH03.png ADDED

Git LFS Details

  • SHA256: a5ba9ef2d6fe37480923fbbd93a7bdf6fdb0590ed5c93f8e741163be31bc26eb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
example_wigs/Oval/HH07.png ADDED

Git LFS Details

  • SHA256: 1205c879380091b4fe13bdc29b070511f745b7365be956d627dc7b94c115118e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.72 MB
example_wigs/Oval/Loire.png ADDED

Git LFS Details

  • SHA256: dc8864c7d5dd20de52ac6f5c8e1ddf236f4fda8278d63dae347306b0f33fb02a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
example_wigs/Oval/Napa.png ADDED

Git LFS Details

  • SHA256: 1a9a929040f0bb2d4d527f811b35a6f7d92135aca380afa72e729cc74db6c5a2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
example_wigs/Oval/Piemonte.png ADDED

Git LFS Details

  • SHA256: 43b0d004d0565425c442b5c75d1dfd0ac8efa239f600fe07c85524fa0eb09e83
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
example_wigs/Oval/Rhone.png ADDED

Git LFS Details

  • SHA256: 928ece7bd6fa34d6b0d4e98f9457199f8247b21d0cc5929aaa3d1edc6332722b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.06 MB
example_wigs/Oval/SantaRossa.png ADDED

Git LFS Details

  • SHA256: e70fffdbe0a0b61b267f483ea35467a0108d5b961e86df7d293459a3944c93c4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.2 MB
example_wigs/Oval/Sonoma.png ADDED

Git LFS Details

  • SHA256: a9d70d9b95a40319beeff562149c708a6525fccbb8245caf484cb8b2cb74edc6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.18 MB
example_wigs/Oval/Tuscany.png ADDED

Git LFS Details

  • SHA256: 35ebf617bbbab34b05d019042f5ab8e9eb90cd6a186957a96df7c3793c142a9e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.99 MB
example_wigs/Round/Loire.png ADDED

Git LFS Details

  • SHA256: dc8864c7d5dd20de52ac6f5c8e1ddf236f4fda8278d63dae347306b0f33fb02a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
example_wigs/Round/Piemonte.png ADDED

Git LFS Details

  • SHA256: 43b0d004d0565425c442b5c75d1dfd0ac8efa239f600fe07c85524fa0eb09e83
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
example_wigs/Round/Sonoma.png ADDED

Git LFS Details

  • SHA256: a9d70d9b95a40319beeff562149c708a6525fccbb8245caf484cb8b2cb74edc6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.18 MB
example_wigs/Round/Tuscany.png ADDED

Git LFS Details

  • SHA256: 35ebf617bbbab34b05d019042f5ab8e9eb90cd6a186957a96df7c3793c142a9e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.99 MB
example_wigs/Square/HH03.png ADDED

Git LFS Details

  • SHA256: a5ba9ef2d6fe37480923fbbd93a7bdf6fdb0590ed5c93f8e741163be31bc26eb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
example_wigs/Square/Loire.png ADDED

Git LFS Details

  • SHA256: dc8864c7d5dd20de52ac6f5c8e1ddf236f4fda8278d63dae347306b0f33fb02a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
example_wigs/Square/Piemonte.png ADDED

Git LFS Details

  • SHA256: 43b0d004d0565425c442b5c75d1dfd0ac8efa239f600fe07c85524fa0eb09e83
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
example_wigs/Square/Sonoma.png ADDED

Git LFS Details

  • SHA256: a9d70d9b95a40319beeff562149c708a6525fccbb8245caf484cb8b2cb74edc6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.18 MB
example_wigs/Square/Tuscany.png ADDED

Git LFS Details

  • SHA256: 35ebf617bbbab34b05d019042f5ab8e9eb90cd6a186957a96df7c3793c142a9e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.99 MB
overlay.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import mediapipe as mp
4
+
5
+ from baldhead import inference # cạo tóc background
6
+ from segmentation import extract_hair
7
+
8
+ # MediaPipe Face Detection
9
+ mp_fd = mp.solutions.face_detection.FaceDetection(model_selection=1,
10
+ min_detection_confidence=0.5)
11
+
12
def get_face_bbox(img: Image.Image) -> tuple[int,int,int,int] | None:
    """Detect the first face via MediaPipe and return its pixel bbox (x1, y1, x2, y2), or None."""
    rgb = np.array(img.convert("RGB"))
    detection_result = mp_fd.process(rgb)
    if not detection_result.detections:
        return None
    # Relative bbox of the first detection -> absolute pixel coordinates.
    box = detection_result.detections[0].location_data.relative_bounding_box
    height, width = rgb.shape[:2]
    left = int(box.xmin * width)
    top = int(box.ymin * height)
    return left, top, left + int(box.width * width), top + int(box.height * height)
24
+
25
def compute_scale(w_bg, h_bg, w_src, h_src) -> float:
    """Mean of the width and height ratios between the background and source face boxes."""
    width_ratio = w_bg / w_src
    height_ratio = h_bg / h_src
    return 0.5 * (width_ratio + height_ratio)
27
+
28
def compute_offset(bbox_bg, bbox_src, scale) -> tuple[int,int]:
    """Translation (dx, dy) that aligns the scaled source face center with the background face center."""
    bx1, by1, bx2, by2 = bbox_bg
    center_bg = (bx1 + (bx2 - bx1)//2, by1 + (by2 - by1)//2)
    sx1, sy1, sx2, sy2 = bbox_src
    # Source center is measured after the source image has been scaled.
    center_src = (int((sx1 + (sx2 - sx1)//2) * scale),
                  int((sy1 + (sy2 - sy1)//2) * scale))
    return center_bg[0] - center_src[0], center_bg[1] - center_src[1]
36
+
37
def paste_with_alpha(bg: np.ndarray, src: np.ndarray, offset: tuple[int,int]) -> Image.Image:
    """Paste RGBA `src` onto `bg` at (x, y) `offset`, clipped to bounds, treating alpha > 0 as fully opaque."""
    canvas = bg.copy()
    off_x, off_y = offset
    src_h, src_w = src.shape[:2]
    # Destination rectangle, clipped to the canvas bounds.
    dst_x1, dst_y1 = max(off_x, 0), max(off_y, 0)
    dst_x2 = min(off_x + src_w, bg.shape[1])
    dst_y2 = min(off_y + src_h, bg.shape[0])
    if dst_x1 >= dst_x2 or dst_y1 >= dst_y2:
        # Source lies entirely outside the canvas.
        return Image.fromarray(canvas)
    src_crop = src[dst_y1-off_y:dst_y2-off_y, dst_x1-off_x:dst_x2-off_x]
    dst_crop = canvas[dst_y1:dst_y2, dst_x1:dst_x2]
    opaque = src_crop[..., 3] > 0
    if dst_crop.shape[2] == 3:
        # RGB canvas: drop the source alpha channel.
        dst_crop[opaque] = src_crop[opaque][..., :3]
    else:
        dst_crop[opaque] = src_crop[opaque]
    canvas[dst_y1:dst_y2, dst_x1:dst_x2] = dst_crop
    return Image.fromarray(canvas)
55
+
56
def overlay_source(background: Image.Image, source: Image.Image):
    """
    Overlay the hair from `source` onto a bald version of `background`.

    Steps: detect face boxes in both images, scale the source so the faces
    match, bald the background, extract the source hair, paste it on.

    Returns the composited image, or None when either image has no detectable
    face. (The original error branches returned a `(None, message)` tuple while
    the success path returned a single image — inconsistent with each other and
    with the single gr.Image output this function is wired to in app.py.)
    """
    # 1) detect bboxes
    bbox_bg = get_face_bbox(background)
    bbox_src = get_face_bbox(source)
    if bbox_bg is None:
        print("❌ No face in background.")
        return None
    if bbox_src is None:
        print("❌ No face in source.")
        return None

    # 2) compute scale & resize source so the two faces are the same size
    w_bg, h_bg = bbox_bg[2]-bbox_bg[0], bbox_bg[3]-bbox_bg[1]
    w_src, h_src = bbox_src[2]-bbox_src[0], bbox_src[3]-bbox_src[1]
    scale = compute_scale(w_bg, h_bg, w_src, h_src)
    src_scaled = source.resize(
        (int(source.width*scale), int(source.height*scale)),
        Image.Resampling.LANCZOS
    )

    # 3) translation that aligns the two face centers
    offset = compute_offset(bbox_bg, bbox_src, scale)

    # 4) remove the background subject's hair
    bg_bald = inference(background)

    # 5) extract hair-only RGBA cutout from the scaled source
    hair_only = extract_hair(src_scaled)

    # 6) paste the hair onto the bald background
    return paste_with_alpha(
        np.array(bg_bald.convert("RGBA")),
        np.array(hair_only),
        offset
    )
requirements.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
--extra-index-url https://download.pytorch.org/whl/cu118  # pip option (CUDA 11.8 wheel index), not a package
2
+ # spaces # Dòng này không rõ ràng là một gói, có thể là ghi chú. Nếu không phải gói, hãy xóa đi.
3
+ huggingface_hub>=0.20.3
4
+ numpy==1.23.5
5
+ transformers==4.30.0
6
+ opencv-python-headless==4.7.0.72
7
+ onnx==1.14.0
8
+ insightface==0.7.3
9
+ psutil==5.9.5
10
tk==0.1.0  # NOTE: tkinter normally ships with the standard Python install; a pip install is rarely needed
11
+ customtkinter==5.1.3
12
+ pillow==9.5.0
13
+ torch==2.0.1+cu118; sys_platform != 'darwin'
14
+ torch==2.0.1; sys_platform == 'darwin'
15
+ torchvision==0.15.2+cu118; sys_platform != 'darwin'
16
+ torchvision==0.15.2; sys_platform == 'darwin'
17
+ # onnxruntime==1.15.0; # Bỏ comment cho dòng này nếu bạn muốn cố định phiên bản cho mọi OS
18
+ # sys_platform == 'darwin' and platform_machine != 'arm64' # Comment
19
+ onnxruntime-silicon==1.13.1; sys_platform == 'darwin' and platform_machine == 'arm64'
20
+ onnxruntime-gpu==1.15.0; sys_platform != 'darwin' # Nên giữ lại dòng này cho non-darwin GPU
21
+ onnxruntime==1.15.0; sys_platform == 'darwin' and platform_machine != 'arm64' # Thêm lại dòng onnxruntime cho Mac Intel
22
+ tensorflow==2.12.0
23
+ # sys_platform != 'darwin' # Comment
24
+ opennsfw2==0.10.2
25
+ # protobuf==4.23.2 # Thay thế dòng này
26
+ protobuf==4.25.3 # *** THAY ĐỔI QUAN TRỌNG ***
27
+ tqdm==4.65.0
28
+ gfpgan==1.3.8
29
+ # torch # Dòng này không cần thiết vì torch đã được định nghĩa ở trên với phiên bản cụ thể.
30
+
31
+ # Thêm các thư viện mới cần thiết cho app.py đã cập nhật
32
+ scikit-image>=0.19 # Hoặc một phiên bản cụ thể hơn nếu bạn muốn, ví dụ: scikit-image==0.19.3
33
+ mediapipe==0.10.14 # *** THÊM MỚI HOẶC CẬP NHẬT *** (Phiên bản này yêu cầu protobuf >=4.25.3)
34
+ git+https://github.com/keras-team/keras-contrib.git
35
+ retina-face==0.0.13
roop/__init__.py ADDED
File without changes
roop/capturer.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import cv2
3
+
4
+
5
def get_video_frame(video_path: str, frame_number: int = 0) -> Any:
    """Return a single decoded frame from *video_path*, or None on failure.

    *frame_number* is treated as 1-based; the seek position is clamped to
    a valid range in both directions.
    """
    capture = cv2.VideoCapture(video_path)
    frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    # Clamp low as well as high: with the default frame_number=0 the original
    # code seeked to position -1, which is invalid for CAP_PROP_POS_FRAMES.
    capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, max(frame_number - 1, 0)))
    has_frame, frame = capture.read()
    capture.release()
    if has_frame:
        return frame
    return None
14
+
15
+
16
def get_video_frame_total(video_path: str) -> int:
    """Number of frames reported by the container at *video_path*."""
    capture = cv2.VideoCapture(video_path)
    total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    capture.release()
    return total
roop/core.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import sys
5
+ # single thread doubles cuda performance - needs to be set before torch import
6
+ if any(arg.startswith('--execution-provider') for arg in sys.argv):
7
+ os.environ['OMP_NUM_THREADS'] = '1'
8
+ # reduce tensorflow log level
9
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
10
+ import warnings
11
+ from typing import List
12
+ import platform
13
+ import signal
14
+ import shutil
15
+ import argparse
16
+ import torch
17
+ import onnxruntime
18
+ import tensorflow
19
+
20
+ import roop.globals
21
+ import roop.metadata
22
+ import roop.ui as ui
23
+ from roop.predicter import predict_image, predict_video
24
+ from roop.processors.frame.core import get_frame_processors_modules
25
+ from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
26
+
27
+ if 'ROCMExecutionProvider' in roop.globals.execution_providers:
28
+ del torch
29
+
30
+ warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
31
+ warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
32
+
33
+
34
def parse_args() -> None:
    """Parse CLI arguments and publish them into roop.globals.

    Side effects: installs a SIGINT handler that calls destroy(), and sets
    every roop.globals.* configuration field read by the rest of the app.
    """
    signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
    program = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100))
    program.add_argument('-s', '--source', help='select an source image', dest='source_path')
    program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
    program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
    program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
    program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False)
    program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True)
    program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False)
    program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true', default=False)
    program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
    program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
    program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory())
    program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
    program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
    program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')

    args = program.parse_args()

    roop.globals.source_path = args.source_path
    roop.globals.target_path = args.target_path
    roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path)
    roop.globals.frame_processors = args.frame_processor
    # headless when any path was supplied on the CLI (truthy string), else UI mode
    roop.globals.headless = args.source_path or args.target_path or args.output_path
    roop.globals.keep_fps = args.keep_fps
    roop.globals.keep_audio = args.keep_audio
    roop.globals.keep_frames = args.keep_frames
    roop.globals.many_faces = args.many_faces
    roop.globals.video_encoder = args.video_encoder
    roop.globals.video_quality = args.video_quality
    roop.globals.max_memory = args.max_memory
    roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
    roop.globals.execution_threads = args.execution_threads
68
+
69
+
70
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
    """Map onnxruntime provider names to short lowercase aliases
    (e.g. 'CUDAExecutionProvider' -> 'cuda')."""
    encoded = []
    for provider in execution_providers:
        encoded.append(provider.replace('ExecutionProvider', '').lower())
    return encoded
72
+
73
+
74
def decode_execution_providers(execution_providers: List[str]) -> List[str]:
    """Return full onnxruntime provider names whose encoded alias contains
    any of the user-requested aliases (e.g. 'cuda' -> 'CUDAExecutionProvider')."""
    available = onnxruntime.get_available_providers()
    selected = []
    for provider, encoded in zip(available, encode_execution_providers(available)):
        if any(requested in encoded for requested in execution_providers):
            selected.append(provider)
    return selected
77
+
78
+
79
def suggest_max_memory() -> int:
    """Default RAM cap in GB: 4 on macOS, 16 everywhere else."""
    return 4 if platform.system().lower() == 'darwin' else 16
83
+
84
+
85
def suggest_execution_providers() -> List[str]:
    """Encoded aliases for every provider available in this onnxruntime build."""
    available = onnxruntime.get_available_providers()
    return encode_execution_providers(available)
87
+
88
+
89
def suggest_execution_threads() -> int:
    """Default worker-thread count: 1 for DirectML/ROCm providers, else 8."""
    single_thread_providers = ('DmlExecutionProvider', 'ROCMExecutionProvider')
    if any(provider in roop.globals.execution_providers for provider in single_thread_providers):
        return 1
    return 8
95
+
96
+
97
def limit_resources() -> None:
    """Cap TensorFlow GPU memory and the process' RAM usage.

    Uses roop.globals.max_memory (GB); applies a Windows working-set limit
    or a POSIX RLIMIT_DATA limit depending on the platform.
    """
    # prevent tensorflow memory leak: cap each GPU to a 1 GiB virtual device
    gpus = tensorflow.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tensorflow.config.experimental.set_virtual_device_configuration(gpu, [
            tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
        ])
    # limit memory usage
    if roop.globals.max_memory:
        memory = roop.globals.max_memory * 1024 ** 3
        if platform.system().lower() == 'darwin':
            # NOTE(review): 1024 ** 6 looks like a typo for 1024 ** 3 (GB->bytes);
            # as written this yields an astronomically large limit on macOS — confirm intent.
            memory = roop.globals.max_memory * 1024 ** 6
        if platform.system().lower() == 'windows':
            import ctypes
            kernel32 = ctypes.windll.kernel32
            kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
        else:
            import resource
            resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
116
+
117
+
118
def release_resources() -> None:
    """Free cached GPU memory after a processing pass when running on CUDA."""
    using_cuda = 'CUDAExecutionProvider' in roop.globals.execution_providers
    if using_cuda:
        torch.cuda.empty_cache()
121
+
122
+
123
def pre_check() -> bool:
    """Validate runtime prerequisites (Python >= 3.9 and ffmpeg on PATH).

    Reports the first failure via update_status and returns False; returns
    True when everything is available.
    """
    if sys.version_info < (3, 9):
        update_status('Python version is not supported - please upgrade to 3.9 or higher.')
        return False
    if shutil.which('ffmpeg') is None:
        update_status('ffmpeg is not installed.')
        return False
    return True
131
+
132
+
133
def update_status(message: str, scope: str = 'ROOP.CORE') -> None:
    """Log a scoped status line to stdout and mirror it to the UI when
    not running headless."""
    line = f'[{scope}] {message}'
    print(line)
    if not roop.globals.headless:
        ui.update_status(message)
137
+
138
+
139
def start() -> None:
    """Run the full processing pipeline for the configured target.

    Image targets: copy to output, run each frame processor in place.
    Video targets: extract frames, process them, re-encode, restore audio.
    NSFW-positive targets abort via destroy(). Reads all configuration
    from roop.globals.
    """
    for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
        if not frame_processor.pre_start():
            return
    # process image to image
    if has_image_extension(roop.globals.target_path):
        if predict_image(roop.globals.target_path):
            destroy()
        # work on a copy so the original target is never modified
        shutil.copy2(roop.globals.target_path, roop.globals.output_path)
        for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
            for frame_processor_name in roop.globals.frame_processors:
                # run the processor only when it matches a requested name
                if frame_processor_name == frame_processor.frame_name:
                    update_status('Progressing...', frame_processor.NAME)
                    frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
                    frame_processor.post_process()
                    release_resources()
        if is_image(roop.globals.target_path):
            update_status('Processing to image succeed!')
        else:
            update_status('Processing to image failed!')
        return
    # process image to videos
    if predict_video(roop.globals.target_path):
        destroy()
    update_status('Creating temp resources...')
    create_temp(roop.globals.target_path)
    update_status('Extracting frames...')
    extract_frames(roop.globals.target_path)
    temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
    for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
        update_status('Progressing...', frame_processor.NAME)
        frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
        frame_processor.post_process()
        release_resources()
    # handles fps
    if roop.globals.keep_fps:
        update_status('Detecting fps...')
        fps = detect_fps(roop.globals.target_path)
        update_status(f'Creating video with {fps} fps...')
        create_video(roop.globals.target_path, fps)
    else:
        update_status('Creating video with 30.0 fps...')
        create_video(roop.globals.target_path)
    # handle audio
    if roop.globals.keep_audio:
        if roop.globals.keep_fps:
            update_status('Restoring audio...')
        else:
            update_status('Restoring audio might cause issues as fps are not kept...')
        restore_audio(roop.globals.target_path, roop.globals.output_path)
    else:
        move_temp(roop.globals.target_path, roop.globals.output_path)
    # clean and validate
    clean_temp(roop.globals.target_path)
    if is_video(roop.globals.target_path):
        update_status('Processing to video succeed!')
    else:
        update_status('Processing to video failed!')
197
+
198
+
199
def destroy() -> None:
    """Remove any temp artifacts for the current target, then exit the process."""
    target = roop.globals.target_path
    if target:
        clean_temp(target)
    quit()
203
+
204
+
205
def run() -> None:
    """Application entry point: parse args, validate prerequisites, then
    run headless or launch the UI."""
    parse_args()
    if not pre_check():
        return
    # each frame processor may veto startup (e.g. missing model files)
    for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
        if not frame_processor.pre_check():
            return
    limit_resources()
    if roop.globals.headless:
        start()
    else:
        # UI mode: hand start/destroy callbacks to the window's event loop
        window = ui.init(start, destroy)
        window.mainloop()
roop/face_analyser.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from typing import Any
3
+ import insightface
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ import roop.globals
8
+ from roop.typing import Frame
9
+
10
# Process-wide singleton analyser; creation is guarded by THREAD_LOCK.
FACE_ANALYSER = None
THREAD_LOCK = threading.Lock()


def get_face_analyser() -> Any:
    """Return the lazily-created, shared InsightFace analyser
    (buffalo_l model, 640x640 detection size, configured providers)."""
    global FACE_ANALYSER

    with THREAD_LOCK:
        if FACE_ANALYSER is None:
            FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers)
            FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640))
    return FACE_ANALYSER
22
+
23
+
24
def get_precise_face_mask(frame: Frame) -> Any:
    """
    Get precise face mask using advanced segmentation (same as detect_face_and_forehead_no_hair).
    Returns both InsightFace detection and precise mask.

    Returns a dict:
      - 'precise_mask': segmentation mask, or None when segmentation failed
      - 'insightface_faces': raw InsightFace detections
      - 'has_face': True when the active detector(s) found a face
    Any segmentation failure falls back to plain InsightFace detection.
    """
    try:
        # Import the precise detection function from the repo root
        import sys
        import os
        # NOTE(review): appends on every call — harmless but grows sys.path;
        # consider hoisting to a module-level import.
        sys.path.append(os.path.dirname(os.path.dirname(__file__)))
        from segmentation import detect_face_and_forehead_no_hair

        # Convert frame to PIL Image
        if isinstance(frame, np.ndarray):
            pil_image = Image.fromarray(frame)
        else:
            pil_image = frame

        # Get precise face mask (clean skin only)
        precise_mask = detect_face_and_forehead_no_hair(pil_image)

        # Also get InsightFace detection for face swapping compatibility
        insightface_faces = get_face_analyser().get(frame)

        return {
            'precise_mask': precise_mask,
            'insightface_faces': insightface_faces,
            'has_face': precise_mask.sum() > 0 and len(insightface_faces) > 0
        }

    except Exception as e:
        print(f"Precise face detection failed: {e}")
        # Fallback to regular InsightFace
        insightface_faces = get_face_analyser().get(frame)
        return {
            'precise_mask': None,
            'insightface_faces': insightface_faces,
            'has_face': len(insightface_faces) > 0
        }
63
+
64
+
65
def get_one_face(frame: Frame) -> Any:
    """
    Get one face with enhanced precision detection.

    Selects the leftmost detected face (smallest bbox x). When the precise
    segmentation mask is available it is attached to the returned face as
    ``precise_mask``. Returns None when no face is found.
    """
    # Get precise detection info
    face_info = get_precise_face_mask(frame)

    if face_info['has_face'] and face_info['insightface_faces']:
        try:
            # Select face (leftmost) for compatibility
            selected_face = min(face_info['insightface_faces'], key=lambda x: x.bbox[0])

            # Add precise mask info to face object
            if face_info['precise_mask'] is not None:
                selected_face.precise_mask = face_info['precise_mask']
                print(f"✅ Enhanced face detection: {face_info['precise_mask'].sum()} precise pixels")

            return selected_face
        except (ValueError, IndexError):
            return None

    # Fallback to original method (plain InsightFace, no mask attached)
    face = get_face_analyser().get(frame)
    try:
        selected_face = min(face, key=lambda x: x.bbox[0])
        return selected_face
    except ValueError:
        # min() raises ValueError on an empty detection list
        return None
93
+
94
+
95
def get_many_faces(frame: Frame) -> Any:
    """
    Get many faces with enhanced precision detection.

    Returns the full list of InsightFace detections; when a precise mask
    is available, the same mask object is attached to every face as
    ``precise_mask``. Falls back to plain InsightFace, returning None on
    detector failure.
    """
    # Get precise detection info
    face_info = get_precise_face_mask(frame)

    if face_info['has_face'] and face_info['insightface_faces']:
        faces = face_info['insightface_faces']

        # Add precise mask info to all face objects
        if face_info['precise_mask'] is not None:
            for face in faces:
                face.precise_mask = face_info['precise_mask']

        print(f"✅ Enhanced multi-face detection: {len(faces)} faces with precise masks")
        return faces

    # Fallback to original method
    try:
        return get_face_analyser().get(frame)
    except IndexError:
        return None
118
+
119
+
120
def has_precise_face_mask(face_obj) -> bool:
    """Return True when *face_obj* carries a non-None ``precise_mask`` attribute."""
    mask = getattr(face_obj, 'precise_mask', None)
    return mask is not None
roop/globals.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import List

# Shared CLI/runtime configuration for all roop modules.
# Populated by roop.core.parse_args(); read everywhere else.
source_path = None        # path to the source face image
target_path = None        # path to the target image or video
output_path = None        # normalized output destination
frame_processors: List[str] = []        # e.g. ['face_swapper', 'face_enhancer']
keep_fps = None           # keep the target's original fps
keep_audio = None         # restore original audio track on output video
keep_frames = None        # keep temporary extracted frames
many_faces = None         # process every detected face, not just one
video_encoder = None      # ffmpeg encoder name (libx264/libx265/libvpx-vp9)
video_quality = None      # encoder quality, 0-51
max_memory = None         # RAM cap in GB
execution_providers: List[str] = []     # onnxruntime provider names
execution_threads = None  # worker thread count
headless = None           # truthy when any CLI path argument was given (no UI)
log_level = 'error'
roop/metadata.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
# Package identity; used for the CLI --version string (see core.parse_args).
name = 'roop'
version = '1.1.0'
roop/predicter.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+ import opennsfw2
3
+ from PIL import Image
4
+
5
+ from roop.typing import Frame
6
+
7
+ MAX_PROBABILITY = 0.85
8
+
9
+
10
+ def predict_frame(target_frame: Frame) -> bool:
11
+ image = Image.fromarray(target_frame)
12
+ image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)
13
+ model = opennsfw2.make_open_nsfw_model()
14
+ views = numpy.expand_dims(image, axis=0)
15
+ _, probability = model.predict(views)[0]
16
+ return probability > MAX_PROBABILITY
17
+
18
+
19
+ def predict_image(target_path: str) -> bool:
20
+ return opennsfw2.predict_image(target_path) > MAX_PROBABILITY
21
+
22
+
23
+ def predict_video(target_path: str) -> bool:
24
+ _, probabilities = opennsfw2.predict_video_frames(video_path=target_path, frame_interval=100)
25
+ return any(probability > MAX_PROBABILITY for probability in probabilities)
roop/processors/__init__.py ADDED
File without changes
roop/processors/frame/__init__.py ADDED
File without changes