VanNguyen1214 committed on
Commit 31a2a8b · verified · 1 Parent(s): 34dcdcc

Upload 41 files
.gitattributes CHANGED
@@ -33,3 +33,36 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Heart/HH02.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Heart/HH03.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Heart/Loire.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Heart/SantaRossa.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Heart/Tuscany.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oblong/HH01.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oblong/HH02.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oblong/HH03.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oblong/HH07.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oblong/Loire.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/Alsace.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/Barossa.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/Burgundy.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/HH01.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/HH02.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/HH03.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/HH07.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/Loire.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/Napa.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/Piemonte.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/Rhone.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/SantaRossa.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/Sonoma.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Oval/Tuscany.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Round/Loire.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Round/Piemonte.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Round/Sonoma.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Round/Tuscany.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Square/HH03.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Square/Loire.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Square/Piemonte.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Square/Sonoma.png filter=lfs diff=lfs merge=lfs -text
+ example_wigs/Square/Tuscany.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,12 @@
- ---
- title: Be Rejection
- emoji: 🌍
- colorFrom: yellow
- colorTo: purple
- sdk: gradio
- sdk_version: 5.33.2
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Ghep Image
+ emoji: 📉
+ colorFrom: pink
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 5.31.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,158 @@
+ import gradio as gr
+ from overlay import overlay_source
+ from detect_face import predict
+ import os
+ from pathlib import Path
+
+ BASE_DIR = Path(__file__).parent  # directory containing app.py
+ FOLDER = BASE_DIR / "example_wigs"
+
+ # --- Load images from a folder ---
+ def load_images_from_folder(folder_path: str) -> list[str]:
+     """
+     Return a list[str] of every image (jpg, png, gif, bmp) in folder_path.
+     """
+     supported = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
+     if not os.path.isdir(folder_path):
+         print(f"Warning: '{folder_path}' is not a valid folder.")
+         return []
+     files = [
+         os.path.join(folder_path, fn)
+         for fn in os.listdir(folder_path)
+         if os.path.splitext(fn)[1].lower() in supported
+     ]
+     if not files:
+         print(f"No images found in: {folder_path}")
+     return files
+
+ # --- Handler for Gallery thumbnail clicks ---
+ def on_gallery_select(evt: gr.SelectData):
+     """
+     On thumbnail click, return:
+       1) the filepath to load into the Source image
+       2) the basename to display in the Textbox
+     """
+     val = evt.value
+
+     # --- extract the filepath from the event payload ---
+     if isinstance(val, dict):
+         img = val.get("image")
+         if isinstance(img, str):
+             filepath = img
+         elif isinstance(img, dict):
+             filepath = img.get("path") or img.get("url")
+         else:
+             filepath = next(
+                 (v for v in val.values() if isinstance(v, str) and os.path.isfile(v)),
+                 None
+             )
+     elif isinstance(val, str):
+         filepath = val
+     else:
+         raise ValueError(f"Unsupported type: {type(val)}")
+
+     filename = os.path.basename(filepath) if filepath else ""
+     return filepath, filename
+
+ # --- Pick the wig folder from the face-shape classification ---
+ def infer_folder(image) -> str:
+     cls = predict(image)["predicted_class"]
+     folder = str(FOLDER / cls)
+     return folder
+
+ # --- Combined step: classify, then load images ---
+ def handle_bg_change(image):
+     """
+     When the background image changes:
+       1. Classify the face shape
+       2. Load the images from the matching folder
+     """
+     if image is None:
+         return "", []
+
+     try:
+         folder = infer_folder(image)
+         images = load_images_from_folder(folder)
+         return folder, images
+     except Exception as e:
+         print(f"Image processing error: {e}")
+         return "", []
+
+ # --- Build the Gradio UI ---
+ def build_demo():
+     with gr.Blocks(title="Process two images", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("Upload Background & Source, click **Run** to try on wigs.")
+
+         with gr.Row():
+             bg = gr.Image(type="pil", label="Background", height=500)
+             src = gr.Image(type="pil", label="Source", height=500, interactive=False)
+             out = gr.Image(label="Result", height=500, interactive=False)
+
+         folder_path_box = gr.Textbox(label="Folder path", visible=False)
+
+         with gr.Row():
+             src_name_box = gr.Textbox(
+                 label="Wigs Name",
+                 interactive=False,
+                 show_copy_button=True,  # optional – handy for copying the name
+                 scale=1
+             )
+             gallery = gr.Gallery(
+                 label="Recommend For You",
+                 height=300,
+                 value=[],
+                 type="filepath",
+                 interactive=False,
+                 columns=5,
+                 object_fit="cover",
+                 allow_preview=True,
+                 scale=8
+             )
+             btn = gr.Button("🔄 Run", variant="primary", scale=1)
+
+         # Run the wig compositing
+         btn.click(fn=overlay_source, inputs=[bg, src], outputs=[out])
+         # When the background image changes, classify it and load the suggested wigs
+         bg.change(
+             fn=handle_bg_change,
+             inputs=[bg],
+             outputs=[folder_path_box, gallery],
+             show_progress=True
+         )
+         # When a gallery image is selected, load it into the Source pane
+         gallery.select(
+             fn=on_gallery_select,
+             outputs=[src, src_name_box]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     build_demo().launch()
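
A minimal sketch of exercising the folder-loading path outside the UI, assuming the example_wigs tree from this commit sits next to app.py (note that importing app also triggers the face-shape model download in detect_face):

from app import FOLDER, load_images_from_folder

paths = load_images_from_folder(str(FOLDER / "Oval"))  # "Oval" is one of CLASS_NAMES
print(f"{len(paths)} wig images found")                # these paths feed gr.Gallery(value=...)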
baldhead.py ADDED
@@ -0,0 +1,272 @@
+ # baldhead.py
+
+ import os
+ import cv2
+ import numpy as np
+ from PIL import Image
+
+ # Keras imports (note: keras-contrib must be installed)
+ import keras.backend as K
+ from keras.layers import (
+     Input,
+     Conv2D,
+     UpSampling2D,
+     LeakyReLU,
+     GlobalAveragePooling2D,
+     Dense,
+     Reshape,
+     Dropout,
+     Concatenate,
+     multiply,  # needed by the squeeze-excite block
+ )
+ from keras.models import Model
+ from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
+
+ # RetinaFace + skimage for face alignment
+ from retinaface import RetinaFace
+ from skimage import transform as trans
+
+ # Hugging Face Hub helper
+ from huggingface_hub import hf_hub_download
+
+
+ # --- Face-alignment helpers (same as the original code) ---
+ image_size = [256, 256]
+ src_landmarks = np.array([
+     [30.2946, 51.6963],
+     [65.5318, 51.5014],
+     [48.0252, 71.7366],
+     [33.5493, 92.3655],
+     [62.7299, 92.2041],
+ ], dtype=np.float32)
+ src_landmarks[:, 0] += 8.0
+ src_landmarks[:, 0] += 15.0  # note: both x shifts apply (net +23.0), kept as in the original
+ src_landmarks[:, 1] += 30.0
+ src_landmarks /= 112
+ src_landmarks *= 200
+
+
+ def list2array(values):
+     return np.array(list(values))
+
+
+ def align_face(img: np.ndarray):
+     """
+     Detect faces + landmarks in `img` via RetinaFace.
+     Returns lists of aligned face patches (256×256 RGB),
+     corresponding binary masks, and the transformation matrices.
+     """
+     faces = RetinaFace.detect_faces(img)
+     bboxes = np.array([list2array(faces[f]['facial_area']) for f in faces])
+     landmarks = np.array([list2array(faces[f]['landmarks'].values()) for f in faces])
+
+     white_canvas = np.ones(img.shape, dtype=np.uint8) * 255
+     aligned_faces, masks, matrices = [], [], []
+
+     if bboxes.shape[0] > 0:
+         for i in range(bboxes.shape[0]):
+             dst = landmarks[i]  # detected landmarks
+             tform = trans.SimilarityTransform()
+             tform.estimate(dst, src_landmarks)
+             M = tform.params[0:2, :]
+
+             warped_face = cv2.warpAffine(
+                 img, M, (image_size[1], image_size[0]), borderValue=0.0
+             )
+             warped_mask = cv2.warpAffine(
+                 white_canvas, M, (image_size[1], image_size[0]), borderValue=0.0
+             )
+
+             aligned_faces.append(warped_face)
+             masks.append(warped_mask)
+             matrices.append(tform.params[0:3, :])
+
+     return aligned_faces, masks, matrices
+
+
+ def put_face_back(
+     orig_img: np.ndarray,
+     processed_faces: list[np.ndarray],
+     masks: list[np.ndarray],
+     matrices: list[np.ndarray],
+ ):
+     """
+     Warp each processed face back onto the original `orig_img`
+     using the inverse of the transformation matrices.
+     """
+     result = orig_img.copy()
+     h, w = orig_img.shape[:2]
+
+     for i in range(len(processed_faces)):
+         invM = np.linalg.inv(matrices[i])[0:2]
+         warped = cv2.warpAffine(processed_faces[i], invM, (w, h), borderValue=0.0)
+         mask = cv2.warpAffine(masks[i], invM, (w, h), borderValue=0.0)
+         binary_mask = (mask // 255).astype(np.uint8)
+
+         # Composite: result = result * (1 - mask) + warped * mask
+         result = result * (1 - binary_mask)
+         result = result.astype(np.uint8)
+         result = result + warped * binary_mask
+
+     return result
+
+
+ # ----------------------------
+ # 2. GENERATOR ARCHITECTURE
+ # ----------------------------
+
+ def squeeze_excite_block(x, ratio=4):
+     """
+     Squeeze-and-Excitation block: channel-wise attention.
+     """
+     init = x
+     channel_axis = 1 if K.image_data_format() == "channels_first" else -1
+     filters = init.shape[channel_axis]
+     se_shape = (1, 1, filters)
+
+     se = GlobalAveragePooling2D()(init)
+     se = Reshape(se_shape)(se)
+     se = Dense(filters // ratio, activation="relu", kernel_initializer="he_normal", use_bias=False)(se)
+     se = Dense(filters, activation="sigmoid", kernel_initializer="he_normal", use_bias=False)(se)
+     return multiply([init, se])
+
+
+ def conv2d(layer_input, filters, f_size=4, bn=True, se=False):
+     """
+     Downsampling block: Conv2D → LeakyReLU → (InstanceNorm) → (SE block)
+     """
+     d = Conv2D(filters, kernel_size=f_size, strides=2, padding="same")(layer_input)
+     d = LeakyReLU(alpha=0.2)(d)
+     if bn:
+         d = InstanceNormalization()(d)
+     if se:
+         d = squeeze_excite_block(d)
+     return d
+
+
+ def atrous(layer_input, filters, f_size=4, bn=True):
+     """
+     Atrous (dilated) convolution block with dilation rates [2, 4, 8].
+     """
+     a_list = []
+     for rate in [2, 4, 8]:
+         a = Conv2D(filters, f_size, dilation_rate=rate, padding="same")(layer_input)
+         a_list.append(a)
+     a = Concatenate()(a_list)
+     a = LeakyReLU(alpha=0.2)(a)
+     if bn:
+         a = InstanceNormalization()(a)
+     return a
+
+
+ def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
+     """
+     Upsampling block: UpSampling2D → Conv2D → (Dropout) → InstanceNorm → Concatenate(skip)
+     """
+     u = UpSampling2D(size=2)(layer_input)
+     u = Conv2D(filters, kernel_size=f_size, strides=1, padding="same", activation="relu")(u)
+     if dropout_rate:
+         u = Dropout(dropout_rate)(u)
+     u = InstanceNormalization()(u)
+     u = Concatenate()([u, skip_input])
+     return u
+
+
+ def build_generator():
+     """
+     Reconstruct the generator architecture exactly as in the notebook,
+     then return a Keras Model object.
+     """
+     d0 = Input(shape=(256, 256, 3))
+     gf = 64
+
+     # Downsampling
+     d1 = conv2d(d0, gf, bn=False, se=True)
+     d2 = conv2d(d1, gf * 2, se=True)
+     d3 = conv2d(d2, gf * 4, se=True)
+     d4 = conv2d(d3, gf * 8)
+     d5 = conv2d(d4, gf * 8)
+
+     # Atrous block
+     a1 = atrous(d5, gf * 8)
+
+     # Upsampling
+     u3 = deconv2d(a1, d4, gf * 8)
+     u4 = deconv2d(u3, d3, gf * 4)
+     u5 = deconv2d(u4, d2, gf * 2)
+     u6 = deconv2d(u5, d1, gf)
+
+     # Final upsample + conv
+     u7 = UpSampling2D(size=2)(u6)
+     output_img = Conv2D(3, kernel_size=4, strides=1, padding="same", activation="tanh")(u7)
+
+     model = Model(d0, output_img)
+     return model
+
+
+ # ----------------------------
+ # 3. LOAD MODEL WEIGHTS
+ # ----------------------------
+
+ HF_REPO_ID = "VanNguyen1214/baldhead"
+ HF_FILENAME = "model_G_5_170.hdf5"
+ # .get() keeps the module importable when the token is unset (public repos need none)
+ HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+
+ def load_generator_from_hub():
+     """
+     Download the .hdf5 weights from HF Hub into cache,
+     rebuild the generator, then load weights.
+     """
+     local_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME, token=HF_TOKEN)
+     gen = build_generator()
+     gen.load_weights(local_path)
+     return gen
+
+ # Load once at startup
+ try:
+     GENERATOR = load_generator_from_hub()
+     print(f"[INFO] Loaded generator weights from {HF_REPO_ID}/{HF_FILENAME}")
+ except Exception as e:
+     print("[ERROR] Could not load generator:", e)
+     GENERATOR = None
+
+
+ # ----------------------------
+ # 4. INFERENCE FUNCTION
+ # ----------------------------
+
+ def inference(image: Image.Image) -> Image.Image:
+     """
+     Gradio-compatible inference function:
+     - Convert PIL → numpy RGB
+     - Align faces
+     - For each face: normalize to [-1, 1], run through the generator, denormalize to uint8
+     - Put processed faces back onto the original image
+     - Return the full image as PIL
+     """
+     if GENERATOR is None:
+         return image
+
+     orig = np.array(image.convert("RGB"))
+
+     faces, masks, mats = align_face(orig)
+     if len(faces) == 0:
+         return image
+
+     processed_faces = []
+     for face in faces:
+         face_input = face.astype(np.float32)
+         face_input = (face_input / 127.5) - 1.0          # scale to [-1, 1]
+         face_input = np.expand_dims(face_input, axis=0)  # (1, 256, 256, 3)
+
+         pred = GENERATOR.predict(face_input)[0]          # (256, 256, 3) in [-1, 1]
+         pred = ((pred + 1.0) * 127.5).astype(np.uint8)
+         processed_faces.append(pred)
+
+     output_np = put_face_back(orig, processed_faces, masks, mats)
+     output_pil = Image.fromarray(output_np)
+
+     return output_pil
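
A minimal usage sketch for the de-hairing model, assuming the weights repo is reachable (HUGGINGFACEHUB_API_TOKEN set if it is private) and that portrait.jpg is a hypothetical local photo:

from PIL import Image
from baldhead import inference

img = Image.open("portrait.jpg")
bald = inference(img)           # returns the input unchanged if no face is found or weights failed to load
bald.save("portrait_bald.png")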
bbox_utils.py ADDED
@@ -0,0 +1,31 @@
+ import numpy as np
+ from PIL import Image
+
+ def get_bbox_from_alpha(rgba: Image.Image):
+     """Bounding box (x1, y1, x2, y2) of the non-transparent pixels, or None."""
+     arr = np.array(rgba)
+     alpha = arr[..., 3]
+     ys, xs = np.where(alpha > 0)
+     if ys.size == 0:
+         return None
+     x1, x2 = xs.min(), xs.max()
+     y1, y2 = ys.min(), ys.max()
+     return x1, y1, x2, y2
+
+ def paste_with_alpha(bg: np.ndarray, src: np.ndarray, offset: tuple[int, int]) -> Image.Image:
+     """Paste RGBA `src` onto `bg` at `offset`, clipping to the background bounds."""
+     res = bg.copy()
+     x, y = offset
+     h, w = src.shape[:2]
+     x1, y1 = max(x, 0), max(y, 0)
+     x2 = min(x + w, bg.shape[1])
+     y2 = min(y + h, bg.shape[0])
+     if x1 >= x2 or y1 >= y2:
+         return Image.fromarray(res)
+     cs = src[y1 - y:y2 - y, x1 - x:x2 - x]
+     cd = res[y1:y2, x1:x2]
+     mask = cs[..., 3] > 0
+     if cd.shape[2] == 3:
+         cd[mask] = cs[mask][..., :3]
+     else:
+         cd[mask] = cs[mask]
+     res[y1:y2, x1:x2] = cd
+     return Image.fromarray(res)
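
A short sketch of the two helpers together, assuming wig.png is a hypothetical RGBA cut-out:

import numpy as np
from PIL import Image
from bbox_utils import get_bbox_from_alpha, paste_with_alpha

wig = Image.open("wig.png").convert("RGBA")
print(get_bbox_from_alpha(wig))                   # (x1, y1, x2, y2) of opaque pixels, or None
canvas = np.zeros((512, 512, 4), dtype=np.uint8)  # transparent RGBA background
out = paste_with_alpha(canvas, np.array(wig), offset=(40, 10))
out.save("composited.png")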
detect_face.py ADDED
@@ -0,0 +1,93 @@
+ # detect_face.py
+
+ import os
+ import sys
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from torchvision import transforms
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+
+ # --- General configuration ---
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ HF_REPO = "VanNguyen1214/detect_faceshape"  # your repo on the HF Hub
+ HF_FILENAME = "best_model.pth"              # file at the root of the repo
+ LOCAL_CKPT = "models/best_model.pth"        # cached here after download
+ CLASS_NAMES = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']
+ NUM_CLASSES = len(CLASS_NAMES)
+
+ # --- Image transform applied before inference ---
+ _TRANSFORM = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                          std=[0.229, 0.224, 0.225]),
+ ])
+
+ def _ensure_checkpoint() -> str:
+     """
+     Check whether LOCAL_CKPT already exists.
+     If not, download best_model.pth from HF_REPO into ./models/.
+     """
+     if os.path.exists(LOCAL_CKPT):
+         return LOCAL_CKPT
+
+     try:
+         ckpt_path = hf_hub_download(
+             repo_id=HF_REPO,
+             filename=HF_FILENAME,
+             local_dir="models",
+         )
+         return ckpt_path
+     except Exception as e:
+         print(f"❌ Could not download the model from the HF Hub: {e}")
+         sys.exit(1)
+
+ def _load_model(ckpt_path: str) -> torch.nn.Module:
+     """
+     Rebuild the EfficientNet-B4 architecture, load the state_dict, switch to eval mode.
+     """
+     # 1) Instantiate EfficientNet-B4 (weights=None: we load our own checkpoint)
+     model = torchvision.models.efficientnet_b4(weights=None)
+     in_features = model.classifier[1].in_features
+     model.classifier = nn.Sequential(
+         nn.Dropout(p=0.3, inplace=True),
+         nn.Linear(in_features, NUM_CLASSES)
+     )
+
+     # 2) Load the weights
+     state = torch.load(ckpt_path, map_location=DEVICE)
+     model.load_state_dict(state)
+
+     # 3) Switch the model to evaluation mode
+     return model.to(DEVICE).eval()
+
+ # === Build the model at import time ===
+ _CKPT_PATH = _ensure_checkpoint()
+ _MODEL = _load_model(_CKPT_PATH)
+
+ def predict(image: Image.Image) -> dict:
+     """
+     Inference:
+     - image: PIL RGB image
+     - Returns {"predicted_class": str}
+     """
+     # Convert to RGB + transform
+     img = image.convert("RGB")
+     x = _TRANSFORM(img).unsqueeze(0).to(DEVICE)
+
+     # Inference
+     with torch.no_grad():
+         logits = _MODEL(x)
+         probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
+
+     idx = int(probs.argmax())
+     return {"predicted_class": CLASS_NAMES[idx]}
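
A minimal usage sketch; face.jpg is a hypothetical portrait, and importing the module downloads best_model.pth on first run:

from PIL import Image
from detect_face import predict

result = predict(Image.open("face.jpg"))
print(result["predicted_class"])  # one of CLASS_NAMES, e.g. "Oval"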
example_wigs/Heart/HH02.png ADDED

Git LFS Details

  • SHA256: 357555727e476770a7e53ee10711ad8f795caedfdcb90adb5083bf077439c63e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.49 MB
example_wigs/Heart/HH03.png ADDED

Git LFS Details

  • SHA256: a5ba9ef2d6fe37480923fbbd93a7bdf6fdb0590ed5c93f8e741163be31bc26eb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
example_wigs/Heart/Loire.png ADDED

Git LFS Details

  • SHA256: dc8864c7d5dd20de52ac6f5c8e1ddf236f4fda8278d63dae347306b0f33fb02a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
example_wigs/Heart/SantaRossa.png ADDED

Git LFS Details

  • SHA256: e70fffdbe0a0b61b267f483ea35467a0108d5b961e86df7d293459a3944c93c4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.2 MB
example_wigs/Heart/Tuscany.png ADDED

Git LFS Details

  • SHA256: 35ebf617bbbab34b05d019042f5ab8e9eb90cd6a186957a96df7c3793c142a9e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.99 MB
example_wigs/Oblong/HH01.png ADDED

Git LFS Details

  • SHA256: bdf028002be35de79da4067264cce2627b5739b7f356ece65c703f1878e83537
  • Pointer size: 132 Bytes
  • Size of remote file: 2.44 MB
example_wigs/Oblong/HH02.png ADDED

Git LFS Details

  • SHA256: 357555727e476770a7e53ee10711ad8f795caedfdcb90adb5083bf077439c63e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.49 MB
example_wigs/Oblong/HH03.png ADDED

Git LFS Details

  • SHA256: a5ba9ef2d6fe37480923fbbd93a7bdf6fdb0590ed5c93f8e741163be31bc26eb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
example_wigs/Oblong/HH07.png ADDED

Git LFS Details

  • SHA256: 1205c879380091b4fe13bdc29b070511f745b7365be956d627dc7b94c115118e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.72 MB
example_wigs/Oblong/Loire.png ADDED

Git LFS Details

  • SHA256: dc8864c7d5dd20de52ac6f5c8e1ddf236f4fda8278d63dae347306b0f33fb02a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
example_wigs/Oval/Alsace.png ADDED

Git LFS Details

  • SHA256: 83767c820759344c15bed941abd94a7f5e7fe8cb462a5ae2d1e289265269d5c7
  • Pointer size: 132 Bytes
  • Size of remote file: 2.28 MB
example_wigs/Oval/Barossa.png ADDED

Git LFS Details

  • SHA256: bf9f6e9abbc352390d1826f186dd08f3536eaba60d96131b81bab49468f202e8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.44 MB
example_wigs/Oval/Burgundy.png ADDED

Git LFS Details

  • SHA256: b48e47a7e1244efe2ed472fb212c39b1f646fc2e726f1a314d7b5cff475a2755
  • Pointer size: 132 Bytes
  • Size of remote file: 2.69 MB
example_wigs/Oval/HH01.png ADDED

Git LFS Details

  • SHA256: bdf028002be35de79da4067264cce2627b5739b7f356ece65c703f1878e83537
  • Pointer size: 132 Bytes
  • Size of remote file: 2.44 MB
example_wigs/Oval/HH02.png ADDED

Git LFS Details

  • SHA256: 357555727e476770a7e53ee10711ad8f795caedfdcb90adb5083bf077439c63e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.49 MB
example_wigs/Oval/HH03.png ADDED

Git LFS Details

  • SHA256: a5ba9ef2d6fe37480923fbbd93a7bdf6fdb0590ed5c93f8e741163be31bc26eb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
example_wigs/Oval/HH07.png ADDED

Git LFS Details

  • SHA256: 1205c879380091b4fe13bdc29b070511f745b7365be956d627dc7b94c115118e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.72 MB
example_wigs/Oval/Loire.png ADDED

Git LFS Details

  • SHA256: dc8864c7d5dd20de52ac6f5c8e1ddf236f4fda8278d63dae347306b0f33fb02a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
example_wigs/Oval/Napa.png ADDED

Git LFS Details

  • SHA256: 1a9a929040f0bb2d4d527f811b35a6f7d92135aca380afa72e729cc74db6c5a2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
example_wigs/Oval/Piemonte.png ADDED

Git LFS Details

  • SHA256: 43b0d004d0565425c442b5c75d1dfd0ac8efa239f600fe07c85524fa0eb09e83
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
example_wigs/Oval/Rhone.png ADDED

Git LFS Details

  • SHA256: 928ece7bd6fa34d6b0d4e98f9457199f8247b21d0cc5929aaa3d1edc6332722b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.06 MB
example_wigs/Oval/SantaRossa.png ADDED

Git LFS Details

  • SHA256: e70fffdbe0a0b61b267f483ea35467a0108d5b961e86df7d293459a3944c93c4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.2 MB
example_wigs/Oval/Sonoma.png ADDED

Git LFS Details

  • SHA256: a9d70d9b95a40319beeff562149c708a6525fccbb8245caf484cb8b2cb74edc6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.18 MB
example_wigs/Oval/Tuscany.png ADDED

Git LFS Details

  • SHA256: 35ebf617bbbab34b05d019042f5ab8e9eb90cd6a186957a96df7c3793c142a9e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.99 MB
example_wigs/Round/Loire.png ADDED

Git LFS Details

  • SHA256: dc8864c7d5dd20de52ac6f5c8e1ddf236f4fda8278d63dae347306b0f33fb02a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
example_wigs/Round/Piemonte.png ADDED

Git LFS Details

  • SHA256: 43b0d004d0565425c442b5c75d1dfd0ac8efa239f600fe07c85524fa0eb09e83
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
example_wigs/Round/Sonoma.png ADDED

Git LFS Details

  • SHA256: a9d70d9b95a40319beeff562149c708a6525fccbb8245caf484cb8b2cb74edc6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.18 MB
example_wigs/Round/Tuscany.png ADDED

Git LFS Details

  • SHA256: 35ebf617bbbab34b05d019042f5ab8e9eb90cd6a186957a96df7c3793c142a9e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.99 MB
example_wigs/Square/HH03.png ADDED

Git LFS Details

  • SHA256: a5ba9ef2d6fe37480923fbbd93a7bdf6fdb0590ed5c93f8e741163be31bc26eb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
example_wigs/Square/Loire.png ADDED

Git LFS Details

  • SHA256: dc8864c7d5dd20de52ac6f5c8e1ddf236f4fda8278d63dae347306b0f33fb02a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
example_wigs/Square/Piemonte.png ADDED

Git LFS Details

  • SHA256: 43b0d004d0565425c442b5c75d1dfd0ac8efa239f600fe07c85524fa0eb09e83
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
example_wigs/Square/Sonoma.png ADDED

Git LFS Details

  • SHA256: a9d70d9b95a40319beeff562149c708a6525fccbb8245caf484cb8b2cb74edc6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.18 MB
example_wigs/Square/Tuscany.png ADDED

Git LFS Details

  • SHA256: 35ebf617bbbab34b05d019042f5ab8e9eb90cd6a186957a96df7c3793c142a9e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.99 MB
overlay.py ADDED
@@ -0,0 +1,89 @@
+ import numpy as np
+ from PIL import Image
+ import mediapipe as mp
+
+ from baldhead import inference            # removes the background subject's hair
+ from segmentation import extract_hair
+ from bbox_utils import paste_with_alpha   # shared helper instead of a duplicate copy
+
+ # MediaPipe Face Detection
+ mp_fd = mp.solutions.face_detection.FaceDetection(model_selection=1,
+                                                   min_detection_confidence=0.5)
+
+ def get_face_bbox(img: Image.Image) -> tuple[int, int, int, int] | None:
+     arr = np.array(img.convert("RGB"))
+     res = mp_fd.process(arr)
+     if not res.detections:
+         return None
+     d = res.detections[0].location_data.relative_bounding_box
+     h, w = arr.shape[:2]
+     x1 = int(d.xmin * w)
+     y1 = int(d.ymin * h)
+     x2 = x1 + int(d.width * w)
+     y2 = y1 + int(d.height * h)
+     return x1, y1, x2, y2
+
+ def compute_scale(w_bg, h_bg, w_src, h_src) -> float:
+     return ((w_bg / w_src) + (h_bg / h_src)) / 2
+
+ def compute_offset(bbox_bg, bbox_src, scale) -> tuple[int, int]:
+     x1, y1, x2, y2 = bbox_bg
+     bg_cx = x1 + (x2 - x1) // 2
+     bg_cy = y1 + (y2 - y1) // 2
+     sx1, sy1, sx2, sy2 = bbox_src
+     src_cx = int((sx1 + (sx2 - sx1) // 2) * scale)
+     src_cy = int((sy1 + (sy2 - sy1) // 2) * scale)
+     return bg_cx - src_cx, bg_cy - src_cy
+
+ def overlay_source(background: Image.Image, source: Image.Image):
+     # 1) detect bboxes
+     bbox_bg = get_face_bbox(background)
+     bbox_src = get_face_bbox(source)
+     # return a single value on failure: btn.click wires exactly one output
+     if bbox_bg is None:
+         print("❌ No face in background.")
+         return None
+     if bbox_src is None:
+         print("❌ No face in source.")
+         return None
+
+     # 2) compute scale & resize source
+     w_bg, h_bg = bbox_bg[2] - bbox_bg[0], bbox_bg[3] - bbox_bg[1]
+     w_src, h_src = bbox_src[2] - bbox_src[0], bbox_src[3] - bbox_src[1]
+     scale = compute_scale(w_bg, h_bg, w_src, h_src)
+     src_scaled = source.resize(
+         (int(source.width * scale), int(source.height * scale)),
+         Image.Resampling.LANCZOS
+     )
+
+     # 3) compute offset
+     offset = compute_offset(bbox_bg, bbox_src, scale)
+
+     # 4) remove the hair from the background portrait
+     bg_bald = inference(background)
+
+     # 5) extract hair-only from source
+     hair_only = extract_hair(src_scaled)
+
+     # 6) paste onto the bald background
+     result = paste_with_alpha(
+         np.array(bg_bald.convert("RGBA")),
+         np.array(hair_only),
+         offset
+     )
+     return result
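
A minimal end-to-end sketch, assuming customer.jpg is a hypothetical portrait and the wig image ships with this commit; both model downloads happen at import time:

from PIL import Image
from overlay import overlay_source

bg = Image.open("customer.jpg")                  # photo that receives the wig
src = Image.open("example_wigs/Oval/Loire.png")  # wig reference portrait
result = overlay_source(bg, src)                 # PIL image, or None if either face is missing
if result is not None:
    result.save("tryon.png")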
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ gradio==4.44.0
+ transformers==4.36.0
+ torch==2.1.0
+ torchvision==0.16.0
+ huggingface-hub==0.19.4
+ Pillow==9.2.0
+ opencv-python-headless==4.8.1.78
+ numpy==1.24.3
+ mediapipe==0.10.8
+ tensorflow==2.11.0
+ keras==2.11.0
+ scikit-image==0.20.0
+ git+https://github.com/keras-team/keras-contrib.git
+ retina-face==0.0.13
segmentation.py ADDED
@@ -0,0 +1,31 @@
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
+ from PIL import Image
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+
+ # Load SegFormer for hair segmentation
+ processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair")
+ model = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair")
+
+ def extract_hair(image: Image.Image) -> Image.Image:
+     """
+     Return an RGBA image where hair pixels have alpha=255 and
+     all other pixels have alpha=0.
+     """
+     rgb = image.convert("RGB")
+     arr = np.array(rgb)
+     h, w = arr.shape[:2]
+
+     # Segment hair
+     inputs = processor(images=rgb, return_tensors="pt")
+     with torch.no_grad():
+         logits = model(**inputs).logits.cpu()
+     up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
+     seg = up.argmax(dim=1)[0].numpy()
+     hair_mask = (seg == 2).astype(np.uint8)  # label 2 = hair in this model's label map
+
+     # Build RGBA
+     alpha = (hair_mask * 255).astype(np.uint8)
+     rgba = np.dstack([arr, alpha])
+     return Image.fromarray(rgba)
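
A quick sketch of the hair matte on its own, assuming model.jpg is a hypothetical portrait with visible hair:

from PIL import Image
from segmentation import extract_hair

rgba = extract_hair(Image.open("model.jpg"))
rgba.save("hair_only.png")  # hair keeps its color; everything else is fully transparent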