crash10155 commited on
Commit
a3d5e88
·
verified ·
1 Parent(s): d454ef7

Update SwitcherAI/processors/frame/modules/face_enhancer.py

Browse files
SwitcherAI/processors/frame/modules/face_enhancer.py CHANGED
@@ -1,7 +1,7 @@
1
  from typing import Any, List, Callable
2
  import cv2
3
  import threading
4
- from gfpgan.utils import GFPGANer
5
 
6
  import SwitcherAI.globals
7
  import SwitcherAI.processors.frame.core as frame_processors
@@ -22,79 +22,216 @@ def get_frame_processor() -> Any:
22
 
23
  with THREAD_LOCK:
24
  if FRAME_PROCESSOR is None:
25
- model_path = resolve_relative_path('../.assets/models/GFPGANv1.4.pth')
26
- FRAME_PROCESSOR = GFPGANer(
27
- model_path = model_path,
28
- upscale = 1,
29
- device = frame_processors.get_device()
30
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  return FRAME_PROCESSOR
32
 
33
 
34
  def clear_frame_processor() -> None:
35
  global FRAME_PROCESSOR
36
-
37
  FRAME_PROCESSOR = None
38
 
39
 
40
  def pre_check() -> bool:
41
- download_directory_path = resolve_relative_path('../.assets/models')
42
- conditional_download(download_directory_path, ['https://github.com/SwitcherAI/SwitcherAI-assets/releases/download/models/GFPGANv1.4.pth'])
43
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
  def pre_process() -> bool:
47
- if not is_image(SwitcherAI.globals.target_path) and not is_video(SwitcherAI.globals.target_path):
48
- update_status(wording.get('select_image_or_video_target') + wording.get('exclamation_mark'), NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  return False
50
- return True
51
 
52
 
53
  def post_process() -> None:
54
  clear_frame_processor()
55
 
56
 
57
- def enhance_face(target_face : Face, temp_frame : Frame) -> Frame:
58
- start_x, start_y, end_x, end_y = map(int, target_face['bbox'])
59
- padding_x = int((end_x - start_x) * 0.5)
60
- padding_y = int((end_y - start_y) * 0.5)
61
- start_x = max(0, start_x - padding_x)
62
- start_y = max(0, start_y - padding_y)
63
- end_x = max(0, end_x + padding_x)
64
- end_y = max(0, end_y + padding_y)
65
- crop_frame = temp_frame[start_y:end_y, start_x:end_x]
66
- if crop_frame.size:
67
- with THREAD_SEMAPHORE:
68
- _, _, crop_frame = get_frame_processor().enhance(
69
- crop_frame,
70
- paste_back = True
71
- )
72
- temp_frame[start_y:end_y, start_x:end_x] = crop_frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  return temp_frame
74
 
75
 
76
- def process_frame(source_face : Face, reference_face : Face, temp_frame : Frame) -> Frame:
77
- many_faces = get_many_faces(temp_frame)
78
- if many_faces:
79
- for target_face in many_faces:
80
- temp_frame = enhance_face(target_face, temp_frame)
 
 
 
 
 
 
 
 
 
 
 
 
81
  return temp_frame
82
 
83
 
84
- def process_frames(source_path : str, temp_frame_paths : List[str], update: Callable[[], None]) -> None:
85
- for temp_frame_path in temp_frame_paths:
86
- temp_frame = cv2.imread(temp_frame_path)
87
- result_frame = process_frame(None, None, temp_frame)
88
- cv2.imwrite(temp_frame_path, result_frame)
89
- if update:
90
- update()
91
-
92
-
93
- def process_image(source_path : str, target_path : str, output_path : str) -> None:
94
- target_frame = cv2.imread(target_path)
95
- result_frame = process_frame(None, None, target_frame)
96
- cv2.imwrite(output_path, result_frame)
97
-
98
-
99
- def process_video(source_path : str, temp_frame_paths : List[str]) -> None:
100
- SwitcherAI.processors.frame.core.process_video(None, temp_frame_paths, process_frames)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Any, List, Callable
2
  import cv2
3
  import threading
4
+ from pathlib import Path
5
 
6
  import SwitcherAI.globals
7
  import SwitcherAI.processors.frame.core as frame_processors
 
22
 
23
  with THREAD_LOCK:
24
  if FRAME_PROCESSOR is None:
25
+ try:
26
+ # Import GFPGAN here to handle import errors gracefully
27
+ from gfpgan.utils import GFPGANer
28
+
29
+ model_path = resolve_relative_path('../.assets/models/GFPGANv1.4.pth')
30
+
31
+ # Convert to Path object if it's a string
32
+ if isinstance(model_path, str):
33
+ model_path = Path(model_path)
34
+
35
+ # Check if model exists
36
+ if not model_path.exists():
37
+ print(f"⚠️ GFPGAN model not found at: {model_path}")
38
+ print("🔄 Attempting to download model...")
39
+ if not pre_check():
40
+ print("❌ Failed to download GFPGAN model")
41
+ return None
42
+
43
+ FRAME_PROCESSOR = GFPGANer(
44
+ model_path = str(model_path),
45
+ upscale = 1,
46
+ device = frame_processors.get_device()
47
+ )
48
+ print("✅ GFPGAN frame processor initialized")
49
+
50
+ except ImportError as e:
51
+ print(f"⚠️ GFPGAN not available: {e}")
52
+ print("💡 Install with: pip install gfpgan")
53
+ FRAME_PROCESSOR = None
54
+ except Exception as e:
55
+ print(f"⚠️ Failed to initialize GFPGAN: {e}")
56
+ FRAME_PROCESSOR = None
57
+
58
  return FRAME_PROCESSOR
59
 
60
 
61
  def clear_frame_processor() -> None:
62
  global FRAME_PROCESSOR
 
63
  FRAME_PROCESSOR = None
64
 
65
 
66
  def pre_check() -> bool:
67
+ try:
68
+ download_directory_path = resolve_relative_path('../.assets/models')
69
+
70
+ # Ensure download directory exists
71
+ if isinstance(download_directory_path, str):
72
+ download_directory_path = Path(download_directory_path)
73
+ download_directory_path.mkdir(parents=True, exist_ok=True)
74
+
75
+ # Download GFPGAN model
76
+ model_urls = [
77
+ 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
78
+ ]
79
+
80
+ conditional_download(str(download_directory_path), model_urls)
81
+
82
+ # Verify the model was downloaded
83
+ model_path = download_directory_path / 'GFPGANv1.4.pth'
84
+ if model_path.exists() and model_path.stat().st_size > 0:
85
+ print(f"✅ GFPGAN model verified: {model_path.stat().st_size / (1024*1024):.1f}MB")
86
+ return True
87
+ else:
88
+ print("❌ GFPGAN model download failed or file is empty")
89
+ return False
90
+
91
+ except Exception as e:
92
+ print(f"❌ GFPGAN pre-check failed: {e}")
93
+ return False
94
 
95
 
96
  def pre_process() -> bool:
97
+ try:
98
+ # Check if we have valid input
99
+ if not is_image(SwitcherAI.globals.target_path) and not is_video(SwitcherAI.globals.target_path):
100
+ update_status(wording.get('select_image_or_video_target') + wording.get('exclamation_mark'), NAME)
101
+ return False
102
+
103
+ # Check if GFPGAN is available
104
+ processor = get_frame_processor()
105
+ if processor is None:
106
+ print("⚠️ GFPGAN not available, face enhancement will be skipped")
107
+ return False
108
+
109
+ return True
110
+
111
+ except Exception as e:
112
+ print(f"⚠️ Face enhancer pre-process failed: {e}")
113
  return False
 
114
 
115
 
116
  def post_process() -> None:
117
  clear_frame_processor()
118
 
119
 
120
+ def enhance_face(target_face: Face, temp_frame: Frame) -> Frame:
121
+ """Enhanced face enhancement with error handling"""
122
+ try:
123
+ processor = get_frame_processor()
124
+ if processor is None:
125
+ print("⚠️ GFPGAN processor not available, returning original frame")
126
+ return temp_frame
127
+
128
+ start_x, start_y, end_x, end_y = map(int, target_face['bbox'])
129
+ padding_x = int((end_x - start_x) * 0.5)
130
+ padding_y = int((end_y - start_y) * 0.5)
131
+ start_x = max(0, start_x - padding_x)
132
+ start_y = max(0, start_y - padding_y)
133
+ end_x = max(0, end_x + padding_x)
134
+ end_y = max(0, end_y + padding_y)
135
+
136
+ # Ensure coordinates are within frame bounds
137
+ height, width = temp_frame.shape[:2]
138
+ end_x = min(end_x, width)
139
+ end_y = min(end_y, height)
140
+
141
+ crop_frame = temp_frame[start_y:end_y, start_x:end_x]
142
+
143
+ if crop_frame.size > 0:
144
+ with THREAD_SEMAPHORE:
145
+ try:
146
+ _, _, enhanced_crop = processor.enhance(
147
+ crop_frame,
148
+ paste_back = True
149
+ )
150
+ temp_frame[start_y:end_y, start_x:end_x] = enhanced_crop
151
+ except Exception as e:
152
+ print(f"⚠️ Face enhancement failed: {e}")
153
+ # Return original frame if enhancement fails
154
+ pass
155
+
156
+ except Exception as e:
157
+ print(f"⚠️ Error in enhance_face: {e}")
158
+
159
  return temp_frame
160
 
161
 
162
+ def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame:
163
+ """Process frame with enhanced error handling"""
164
+ try:
165
+ # Check if processor is available
166
+ processor = get_frame_processor()
167
+ if processor is None:
168
+ print("⚠️ Face enhancer not available, skipping enhancement")
169
+ return temp_frame
170
+
171
+ many_faces = get_many_faces(temp_frame)
172
+ if many_faces:
173
+ for target_face in many_faces:
174
+ temp_frame = enhance_face(target_face, temp_frame)
175
+
176
+ except Exception as e:
177
+ print(f"⚠️ Error in process_frame: {e}")
178
+
179
  return temp_frame
180
 
181
 
182
+ def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
183
+ """Process multiple frames with progress updates"""
184
+ try:
185
+ processor = get_frame_processor()
186
+ if processor is None:
187
+ print("⚠️ Face enhancer not available, skipping frame enhancement")
188
+ if update:
189
+ update()
190
+ return
191
+
192
+ for temp_frame_path in temp_frame_paths:
193
+ try:
194
+ temp_frame = cv2.imread(temp_frame_path)
195
+ if temp_frame is not None:
196
+ result_frame = process_frame(None, None, temp_frame)
197
+ cv2.imwrite(temp_frame_path, result_frame)
198
+ else:
199
+ print(f"⚠️ Failed to read frame: {temp_frame_path}")
200
+
201
+ except Exception as e:
202
+ print(f"⚠️ Error processing frame {temp_frame_path}: {e}")
203
+
204
+ if update:
205
+ update()
206
+
207
+ except Exception as e:
208
+ print(f"⚠️ Error in process_frames: {e}")
209
+
210
+
211
+ def process_image(source_path: str, target_path: str, output_path: str) -> None:
212
+ """Process single image with error handling"""
213
+ try:
214
+ processor = get_frame_processor()
215
+ if processor is None:
216
+ print("⚠️ Face enhancer not available, copying original image")
217
+ import shutil
218
+ shutil.copy2(target_path, output_path)
219
+ return
220
+
221
+ target_frame = cv2.imread(target_path)
222
+ if target_frame is not None:
223
+ result_frame = process_frame(None, None, target_frame)
224
+ cv2.imwrite(output_path, result_frame)
225
+ else:
226
+ print(f"⚠️ Failed to read image: {target_path}")
227
+
228
+ except Exception as e:
229
+ print(f"⚠️ Error in process_image: {e}")
230
+
231
+
232
+ def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
233
+ """Process video frames"""
234
+ try:
235
+ SwitcherAI.processors.frame.core.process_video(None, temp_frame_paths, process_frames)
236
+ except Exception as e:
237
+ print(f"⚠️ Error in process_video: {e}")