Herishop commited on
Commit
baf037a
·
verified ·
1 Parent(s): 55f4413

Upload swap.py

Browse files
Files changed (1) hide show
  1. swap.py +241 -0
swap.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import insightface
4
+ from insightface.app import FaceAnalysis
5
+ from gfpgan import GFPGANer
6
+ import os
7
+ import torch
8
+ import warnings
9
+ import gradio as gr
10
+ import time
11
+ from datetime import datetime
12
+ import shutil
13
+ import traceback
14
+
15
+ # Suppress specific warnings
16
+ warnings.filterwarnings("ignore", category=UserWarning, module="gradio_client.documentation")
17
+ warnings.filterwarnings("ignore", category=FutureWarning)
18
+
19
+ # Paths (giữ nguyên như bạn cung cấp)
20
+ model_path = os.path.join("models", "inswapper_128.onnx")
21
+ gfpgan_path = os.path.join("gfpgan", "weights", "GFPGANv1.4.pth")
22
+ buffalo_l_path = os.path.join("models", "buffalo_l")
23
+ output_dir = "output"
24
+
25
+ # Initialize logging
26
+ log_messages = []
27
+
28
+ def log_message(message):
29
+ """Append message to log with timestamp."""
30
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
31
+ log_messages.append(f"[{timestamp}] {message}")
32
+ print(f"[{timestamp}] {message}") # Also print to console
33
+ return "\n".join(log_messages)
34
+
35
+ def validate_paths():
36
+ """Validate required file and directory paths."""
37
+ log_message("Validating file paths...")
38
+ for path in [model_path, gfpgan_path]:
39
+ if not os.path.isfile(path):
40
+ return False, f"Error: File not found at {path}"
41
+ if not os.path.isdir(buffalo_l_path):
42
+ return False, f"Error: buffalo_l directory not found at {buffalo_l_path}. Please download and extract buffalo_l.zip from https://github.com/deepinsight/insightface/releases/download/v0.7/buffalo_l.zip to {buffalo_l_path}"
43
+ # Kiểm tra các file cần thiết trong buffalo_l
44
+ required_files = ["1k3d68.onnx", "2d106det.onnx", "det_10g.onnx", "genderage.onnx", "w600k_r50.onnx"]
45
+ if not all(os.path.exists(os.path.join(buffalo_l_path, f)) for f in required_files):
46
+ return False, f"Error: buffalo_l directory at {buffalo_l_path} is incomplete. Please ensure it contains {', '.join(required_files)}"
47
+ return True, "All paths validated successfully"
48
+
49
+ def initialize_face_analysis():
50
+ """Initialize FaceAnalysis model."""
51
+ providers = [
52
+ ('CUDAExecutionProvider', {
53
+ 'device_id': 0,
54
+ 'gpu_mem_limit': 10 * 1024 * 1024 * 1024,
55
+ 'arena_extend_strategy': 'kNextPowerOfTwo',
56
+ 'cudnn_conv_algo_search': 'EXHAUSTIVE',
57
+ 'do_copy_in_default_stream': True,
58
+ }),
59
+ 'CPUExecutionProvider',
60
+ ]
61
+ try:
62
+ log_message("Initializing FaceAnalysis...")
63
+ # Sử dụng root="models" để tìm đúng models\buffalo_l
64
+ app = FaceAnalysis(name="buffalo_l", root=os.path.dirname(buffalo_l_path), providers=providers)
65
+ app.prepare(ctx_id=0, det_size=(640, 640))
66
+ log_message(f"PyTorch CUDA available: {torch.cuda.is_available()}")
67
+ log_message("FaceAnalysis initialized successfully")
68
+ return app, None
69
+ except Exception as e:
70
+ error_msg = f"Error initializing FaceAnalysis: {str(e)}\n{traceback.format_exc()}"
71
+ return None, log_message(error_msg)
72
+
73
+ def load_and_detect_faces(app, source_img, target_img):
74
+ """Load images and detect faces."""
75
+ try:
76
+ log_message("Loading and detecting faces...")
77
+ if source_img is None or target_img is None:
78
+ return None, None, "Error: Source or target image is None"
79
+
80
+ source_img_np = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
81
+ target_img_np = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
82
+
83
+ source_faces = app.get(source_img_np)
84
+ target_faces = app.get(target_img_np)
85
+
86
+ log_message(f"Source image: {len(source_faces)} faces detected")
87
+ log_message(f"Target image: {len(target_faces)} faces detected")
88
+
89
+ if len(source_faces) == 0 or len(target_faces) == 0:
90
+ return None, None, "Error: No faces detected in source or target image!"
91
+
92
+ return source_faces, target_faces, None
93
+ except Exception as e:
94
+ error_msg = f"Error in load_and_detect_faces: {str(e)}\n{traceback.format_exc()}"
95
+ return None, None, log_message(error_msg)
96
+
97
+ def select_source_face(source_faces):
98
+ """Select the first source face."""
99
+ try:
100
+ log_message("Selecting source face...")
101
+ source_face = source_faces[0]
102
+ log_message("Using first detected source face")
103
+ return source_face, None
104
+ except Exception as e:
105
+ error_msg = f"Error selecting source face: {str(e)}\n{traceback.format_exc()}"
106
+ return None, log_message(error_msg)
107
+
108
+ def perform_face_swap(source_face, target_face, target_img):
109
+ """Perform face swapping with edge smoothing."""
110
+ try:
111
+ log_message("Loading inswapper model...")
112
+ swapper = insightface.model_zoo.get_model(model_path, providers=[
113
+ ('CUDAExecutionProvider', {
114
+ 'device_id': 0,
115
+ 'gpu_mem_limit': 10 * 1024 * 1024 * 1024,
116
+ 'arena_extend_strategy': 'kNextPowerOfTwo',
117
+ 'cudnn_conv_algo_search': 'EXHAUSTIVE',
118
+ 'do_copy_in_default_stream': True,
119
+ }),
120
+ 'CPUExecutionProvider',
121
+ ])
122
+ log_message("Inswapper model loaded successfully")
123
+
124
+ target_img_np = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
125
+ result = target_img_np.copy()
126
+ result = swapper.get(result, target_face, source_face, paste_back=True)
127
+
128
+ x, y, w, h = target_face.bbox.astype(int)
129
+ mask = np.zeros(result.shape[:2], dtype=np.float32)
130
+ cv2.rectangle(mask, (x, y), (x + w, y + h), 1.0, -1)
131
+ mask = cv2.GaussianBlur(mask, (9, 9), 0)
132
+ mask = np.stack([mask]*3, axis=-1)
133
+ result = (result * mask + target_img_np * (1 - mask)).astype(np.uint8)
134
+
135
+ log_message("Face swapping completed")
136
+ return result, None
137
+ except Exception as e:
138
+ error_msg = f"Error during face swapping: {str(e)}\n{traceback.format_exc()}"
139
+ return None, log_message(error_msg)
140
+
141
+ def enhance_with_gfpgan(result):
142
+ """Enhance swapped image using GFPGAN without resizing."""
143
+ try:
144
+ log_message("Enhancing with GFPGAN...")
145
+ enhancer = GFPGANer(
146
+ model_path=gfpgan_path,
147
+ upscale=1,
148
+ arch='clean',
149
+ channel_multiplier=2,
150
+ device='cuda' if torch.cuda.is_available() else 'cpu',
151
+ bg_upsampler=None
152
+ )
153
+ _, _, enhanced_result = enhancer.enhance(result, paste_back=True)
154
+ output_path = os.path.join(output_dir, "output.jpg")
155
+ cv2.imwrite(output_path, enhanced_result)
156
+ log_message(f"Enhanced image saved to {output_path}")
157
+ return output_path, None
158
+ except Exception as e:
159
+ error_msg = f"Error during GFPGAN enhancement: {str(e)}\n{traceback.format_exc()}"
160
+ return None, log_message(error_msg)
161
+
162
+ def face_swap(source_img, target_img):
163
+ """Main face swap function for Gradio."""
164
+ global log_messages
165
+ log_messages = [] # Reset logs
166
+ start_time = time.time()
167
+
168
+ try:
169
+ log_message("Starting face swap process...")
170
+ if os.path.exists(output_dir):
171
+ shutil.rmtree(output_dir)
172
+ os.makedirs(output_dir, exist_ok=True)
173
+
174
+ valid, message = validate_paths()
175
+ log_message(message)
176
+ if not valid:
177
+ return None, log_message("Path validation failed")
178
+
179
+ app, error = initialize_face_analysis()
180
+ if error:
181
+ return None, log_message(error)
182
+
183
+ source_faces, target_faces, error = load_and_detect_faces(app, source_img, target_img)
184
+ if error:
185
+ return None, log_message(error)
186
+
187
+ source_face, error = select_source_face(source_faces)
188
+ if error:
189
+ return None, log_message(error)
190
+ target_face = target_faces[0]
191
+ log_message(f"Target face attributes: {target_face.__dict__}")
192
+
193
+ result, error = perform_face_swap(source_face, target_face, target_img)
194
+ if error:
195
+ return None, log_message(error)
196
+
197
+ output_path, error = enhance_with_gfpgan(result)
198
+ if error:
199
+ return None, log_message(error)
200
+
201
+ log_message(f"Processing completed in {time.time() - start_time:.2f} seconds")
202
+ return output_path, "\n".join(log_messages)
203
+ except Exception as e:
204
+ error_msg = f"Unexpected error in face_swap: {str(e)}\n{traceback.format_exc()}"
205
+ return None, log_message(error_msg)
206
+
207
+ # Gradio Interface
208
+ with gr.Blocks() as demo:
209
+ gr.Markdown("# Face Swap Application")
210
+ gr.Markdown("Upload source and target images to swap faces. The first detected face in the source image will be used.")
211
+
212
+ with gr.Row():
213
+ with gr.Column():
214
+ source_img = gr.Image(type="pil", label="Source Image")
215
+ target_img = gr.Image(type="pil", label="Target Image")
216
+ submit_btn = gr.Button("Swap Faces")
217
+ with gr.Column():
218
+ output = gr.Image(label="Final Output")
219
+
220
+ logs = gr.Textbox(label="Logs", interactive=False, lines=10)
221
+
222
+ submit_btn.click(
223
+ fn=face_swap,
224
+ inputs=[source_img, target_img],
225
+ outputs=[output, logs],
226
+ api_name="faceswap" # Đảm bảo endpoint /face_swap
227
+ )
228
+
229
+ if __name__ == "__main__":
230
+ try:
231
+ log_message("Launching Gradio interface...")
232
+ demo.launch(
233
+ share=True,
234
+ debug=True,
235
+ allowed_paths=["models", "gfpgan/weights", output_dir],
236
+ server_name="0.0.0.0",
237
+ server_port=7860
238
+ )
239
+ except Exception as e:
240
+ log_message(f"Error launching Gradio: {str(e)}\n{traceback.format_exc()}")
241
+ print(f"Error launching Gradio: {str(e)}\n{traceback.format_exc()}")