crash10155 commited on
Commit
d454ef7
·
verified ·
1 Parent(s): f91fd95

Update SwitcherAI/processors/frame/modules/frame_enhancer.py

Browse files
SwitcherAI/processors/frame/modules/frame_enhancer.py CHANGED
@@ -3,9 +3,8 @@ import cv2
3
  import threading
4
  import numpy
5
  from functools import lru_cache
6
- from basicsr.archs.rrdbnet_arch import RRDBNet
7
- from realesrgan import RealESRGANer
8
- import torch
9
  import SwitcherAI.processors.frame.core as frame_processors
10
  from SwitcherAI.typing import Frame, Face
11
  from SwitcherAI.utilities import conditional_download, resolve_relative_path
@@ -20,9 +19,13 @@ NAME = 'FACEFUSION.FRAME_PROCESSOR.FRAME_ENHANCER'
20
  @lru_cache(maxsize=None)
21
  def get_model_config() -> Dict[str, Any]:
22
  """Get model configuration with enhanced options"""
 
 
 
 
23
  return {
24
  'real_esrgan_x4': {
25
- 'model_path': resolve_relative_path('../.assets/models/RealESRGAN_x4plus.pth'),
26
  'scale': 4,
27
  'tile_size': 256,
28
  'tile_pad': 16,
@@ -38,29 +41,53 @@ def get_frame_processor() -> Any:
38
 
39
  with THREAD_LOCK:
40
  if FRAME_PROCESSOR is None:
41
- config = get_model_config()['real_esrgan_x4']
42
- model_path = config['model_path']
43
-
44
- FRAME_PROCESSOR = RealESRGANer(
45
- model_path=model_path,
46
- model=RRDBNet(
47
- num_in_ch=3,
48
- num_out_ch=3,
49
- num_feat=config['num_feat'],
50
- num_block=config['num_block'],
51
- num_grow_ch=config['num_grow_ch'],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  scale=config['scale']
53
- ),
54
- device=frame_processors.get_device(),
55
- tile=config['tile_size'],
56
- tile_pad=config['tile_pad'],
57
- pre_pad=0,
58
- scale=config['scale']
59
- )
60
-
61
- # Ensure CUDA device is set if available
62
- if torch.cuda.is_available():
63
- torch.cuda.set_device(0)
 
 
 
 
64
 
65
  return FRAME_PROCESSOR
66
 
@@ -72,27 +99,46 @@ def clear_frame_processor() -> None:
72
 
73
  def pre_check() -> bool:
74
  """Download required models for frame enhancement"""
75
- download_directory_path = resolve_relative_path('../.assets/models')
76
-
77
  try:
78
- conditional_download(download_directory_path, [
 
 
 
 
 
 
 
 
79
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
80
- ])
81
- return True
 
 
 
 
 
 
 
 
 
 
 
82
  except Exception as e:
83
- print(f"⚠️ Failed to download frame enhancement models: {e}")
84
- print("🔄 Frame enhancement will be disabled")
85
  return False
86
 
87
 
88
  def pre_process() -> bool:
89
  """Pre-process check with model validation"""
90
  try:
91
- model_path = get_model_config()['real_esrgan_x4']['model_path']
92
- if not model_path or not model_path.exists():
93
- print("⚠️ Frame enhancement model not found")
 
94
  return False
 
95
  return True
 
96
  except Exception as e:
97
  print(f"⚠️ Frame enhancement pre-process failed: {e}")
98
  return False
@@ -166,34 +212,47 @@ def enhance_frame_with_tiling(temp_frame: Frame) -> Frame:
166
  """
167
  Enhanced frame enhancement with improved tiling (inspired by FaceFusion)
168
  """
169
- config = get_model_config()['real_esrgan_x4']
170
- tile_size = (config['tile_size'], config['tile_size'])
171
- scale = config['scale']
172
-
173
- # Create tiles for processing
174
- tiles, pad_width, pad_height = create_tile_frames(temp_frame, tile_size)
175
- enhanced_tiles = []
176
-
177
- with THREAD_SEMAPHORE:
178
- frame_processor = get_frame_processor()
179
 
180
- for tile in tiles:
181
- # Process each tile individually to manage memory
182
- enhanced_tile, _ = frame_processor.enhance(tile, outscale=scale)
183
- enhanced_tiles.append(enhanced_tile)
184
-
185
- # Merge tiles back together
186
- original_height, original_width = temp_frame.shape[:2]
187
- enhanced_frame = merge_tile_frames(
188
- enhanced_tiles,
189
- original_width * scale,
190
- original_height * scale,
191
- pad_width * scale,
192
- pad_height * scale,
193
- (tile_size[0] * scale, tile_size[1] * scale)
194
- )
195
-
196
- return enhanced_frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
 
199
  def enhance_frame(temp_frame: Frame) -> Frame:
@@ -201,12 +260,24 @@ def enhance_frame(temp_frame: Frame) -> Frame:
201
  Main enhancement function with fallback to original method
202
  """
203
  try:
 
 
 
 
 
204
  # Try enhanced tiling method first
205
- return enhance_frame_with_tiling(temp_frame)
206
- except Exception:
207
- # Fallback to original method
208
- with THREAD_SEMAPHORE:
209
- temp_frame, _ = get_frame_processor().enhance(temp_frame, outscale=1)
 
 
 
 
 
 
 
210
  return temp_frame
211
 
212
 
@@ -214,61 +285,114 @@ def blend_frame(original_frame: Frame, enhanced_frame: Frame, blend_ratio: float
214
  """
215
  Blend original and enhanced frames (inspired by FaceFusion)
216
  """
217
- if original_frame.shape != enhanced_frame.shape:
218
- original_frame = cv2.resize(original_frame, (enhanced_frame.shape[1], enhanced_frame.shape[0]))
219
-
220
- # Convert blend ratio (0-1 where 1 = full enhancement)
221
- return cv2.addWeighted(original_frame, 1 - blend_ratio, enhanced_frame, blend_ratio, 0)
 
 
 
 
222
 
223
 
224
  def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame:
225
  """
226
  Main processing function (maintains your original interface)
227
  """
228
- return enhance_frame(temp_frame)
 
 
 
 
229
 
230
 
231
  def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
232
  """
233
  Process multiple frames (maintains your original interface)
234
  """
235
- for temp_frame_path in temp_frame_paths:
236
- temp_frame = cv2.imread(temp_frame_path)
237
- result_frame = process_frame(None, None, temp_frame)
238
- cv2.imwrite(temp_frame_path, result_frame)
239
- if update:
240
- update()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
 
243
  def process_image(source_path: str, target_path: str, output_path: str) -> None:
244
  """
245
  Process single image (maintains your original interface)
246
  """
247
- target_frame = cv2.imread(target_path)
248
- result = process_frame(None, None, target_frame)
249
- cv2.imwrite(output_path, result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
 
252
  def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
253
  """
254
  Process video frames (maintains your original interface)
255
  """
256
- frame_processors.process_video(None, temp_frame_paths, process_frames)
 
 
 
257
 
258
 
259
  # Additional utility functions inspired by FaceFusion
260
  def get_model_scale() -> int:
261
  """Get the current model's scale factor"""
262
- return get_model_config()['real_esrgan_x4']['scale']
 
 
 
263
 
264
 
265
  def prepare_frame(frame: Frame) -> Frame:
266
  """Prepare frame for processing"""
267
- if frame.dtype != numpy.uint8:
268
- frame = frame.astype(numpy.uint8)
269
- return frame
 
 
 
270
 
271
 
272
  def normalize_frame(frame: Frame) -> Frame:
273
  """Normalize frame after processing"""
274
- return numpy.clip(frame, 0, 255).astype(numpy.uint8)
 
 
 
 
3
  import threading
4
  import numpy
5
  from functools import lru_cache
6
+ from pathlib import Path
7
+
 
8
  import SwitcherAI.processors.frame.core as frame_processors
9
  from SwitcherAI.typing import Frame, Face
10
  from SwitcherAI.utilities import conditional_download, resolve_relative_path
 
19
  @lru_cache(maxsize=None)
20
  def get_model_config() -> Dict[str, Any]:
21
  """Get model configuration with enhanced options"""
22
+ base_path = resolve_relative_path('../.assets/models')
23
+ if isinstance(base_path, str):
24
+ base_path = Path(base_path)
25
+
26
  return {
27
  'real_esrgan_x4': {
28
+ 'model_path': base_path / 'RealESRGAN_x4plus.pth',
29
  'scale': 4,
30
  'tile_size': 256,
31
  'tile_pad': 16,
 
41
 
42
  with THREAD_LOCK:
43
  if FRAME_PROCESSOR is None:
44
+ try:
45
+ # Import Real-ESRGAN components
46
+ from basicsr.archs.rrdbnet_arch import RRDBNet
47
+ from realesrgan import RealESRGANer
48
+ import torch
49
+
50
+ config = get_model_config()['real_esrgan_x4']
51
+ model_path = config['model_path']
52
+
53
+ # Check if model exists
54
+ if not model_path.exists():
55
+ print(f"⚠️ Real-ESRGAN model not found at: {model_path}")
56
+ print("🔄 Attempting to download model...")
57
+ if not pre_check():
58
+ print("❌ Failed to download Real-ESRGAN model")
59
+ return None
60
+
61
+ FRAME_PROCESSOR = RealESRGANer(
62
+ model_path=str(model_path),
63
+ model=RRDBNet(
64
+ num_in_ch=3,
65
+ num_out_ch=3,
66
+ num_feat=config['num_feat'],
67
+ num_block=config['num_block'],
68
+ num_grow_ch=config['num_grow_ch'],
69
+ scale=config['scale']
70
+ ),
71
+ device=frame_processors.get_device(),
72
+ tile=config['tile_size'],
73
+ tile_pad=config['tile_pad'],
74
+ pre_pad=0,
75
  scale=config['scale']
76
+ )
77
+
78
+ # Ensure CUDA device is set if available
79
+ if torch.cuda.is_available():
80
+ torch.cuda.set_device(0)
81
+
82
+ print("✅ Real-ESRGAN frame processor initialized")
83
+
84
+ except ImportError as e:
85
+ print(f"⚠️ Real-ESRGAN not available: {e}")
86
+ print("💡 Install with: pip install realesrgan basicsr")
87
+ FRAME_PROCESSOR = None
88
+ except Exception as e:
89
+ print(f"⚠️ Failed to initialize Real-ESRGAN: {e}")
90
+ FRAME_PROCESSOR = None
91
 
92
  return FRAME_PROCESSOR
93
 
 
99
 
100
  def pre_check() -> bool:
101
  """Download required models for frame enhancement"""
 
 
102
  try:
103
+ download_directory_path = resolve_relative_path('../.assets/models')
104
+
105
+ # Ensure download directory exists
106
+ if isinstance(download_directory_path, str):
107
+ download_directory_path = Path(download_directory_path)
108
+ download_directory_path.mkdir(parents=True, exist_ok=True)
109
+
110
+ # Download Real-ESRGAN model
111
+ model_urls = [
112
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
113
+ ]
114
+
115
+ conditional_download(str(download_directory_path), model_urls)
116
+
117
+ # Verify the model was downloaded
118
+ model_path = download_directory_path / 'RealESRGAN_x4plus.pth'
119
+ if model_path.exists() and model_path.stat().st_size > 0:
120
+ print(f"✅ Real-ESRGAN model verified: {model_path.stat().st_size / (1024*1024):.1f}MB")
121
+ return True
122
+ else:
123
+ print("❌ Real-ESRGAN model download failed or file is empty")
124
+ return False
125
+
126
  except Exception as e:
127
+ print(f" Real-ESRGAN pre-check failed: {e}")
 
128
  return False
129
 
130
 
131
  def pre_process() -> bool:
132
  """Pre-process check with model validation"""
133
  try:
134
+ # Check if processor is available
135
+ processor = get_frame_processor()
136
+ if processor is None:
137
+ print("⚠️ Real-ESRGAN not available, frame enhancement will be skipped")
138
  return False
139
+
140
  return True
141
+
142
  except Exception as e:
143
  print(f"⚠️ Frame enhancement pre-process failed: {e}")
144
  return False
 
212
  """
213
  Enhanced frame enhancement with improved tiling (inspired by FaceFusion)
214
  """
215
+ try:
216
+ processor = get_frame_processor()
217
+ if processor is None:
218
+ print("⚠️ Real-ESRGAN processor not available, returning original frame")
219
+ return temp_frame
 
 
 
 
 
220
 
221
+ config = get_model_config()['real_esrgan_x4']
222
+ tile_size = (config['tile_size'], config['tile_size'])
223
+ scale = config['scale']
224
+
225
+ # Create tiles for processing
226
+ tiles, pad_width, pad_height = create_tile_frames(temp_frame, tile_size)
227
+ enhanced_tiles = []
228
+
229
+ with THREAD_SEMAPHORE:
230
+ for tile in tiles:
231
+ try:
232
+ # Process each tile individually to manage memory
233
+ enhanced_tile, _ = processor.enhance(tile, outscale=scale)
234
+ enhanced_tiles.append(enhanced_tile)
235
+ except Exception as e:
236
+ print(f"⚠️ Tile enhancement failed: {e}")
237
+ # Use original tile if enhancement fails
238
+ enhanced_tiles.append(tile)
239
+
240
+ # Merge tiles back together
241
+ original_height, original_width = temp_frame.shape[:2]
242
+ enhanced_frame = merge_tile_frames(
243
+ enhanced_tiles,
244
+ original_width * scale,
245
+ original_height * scale,
246
+ pad_width * scale,
247
+ pad_height * scale,
248
+ (tile_size[0] * scale, tile_size[1] * scale)
249
+ )
250
+
251
+ return enhanced_frame
252
+
253
+ except Exception as e:
254
+ print(f"⚠️ Enhanced tiling failed: {e}")
255
+ return temp_frame
256
 
257
 
258
  def enhance_frame(temp_frame: Frame) -> Frame:
 
260
  Main enhancement function with fallback to original method
261
  """
262
  try:
263
+ processor = get_frame_processor()
264
+ if processor is None:
265
+ print("⚠️ Frame enhancer not available, returning original frame")
266
+ return temp_frame
267
+
268
  # Try enhanced tiling method first
269
+ try:
270
+ return enhance_frame_with_tiling(temp_frame)
271
+ except Exception as e:
272
+ print(f"⚠️ Tiling method failed: {e}, trying simple enhancement")
273
+
274
+ # Fallback to original method
275
+ with THREAD_SEMAPHORE:
276
+ enhanced_frame, _ = processor.enhance(temp_frame, outscale=1)
277
+ return enhanced_frame
278
+
279
+ except Exception as e:
280
+ print(f"⚠️ Frame enhancement failed completely: {e}")
281
  return temp_frame
282
 
283
 
 
285
  """
286
  Blend original and enhanced frames (inspired by FaceFusion)
287
  """
288
+ try:
289
+ if original_frame.shape != enhanced_frame.shape:
290
+ original_frame = cv2.resize(original_frame, (enhanced_frame.shape[1], enhanced_frame.shape[0]))
291
+
292
+ # Convert blend ratio (0-1 where 1 = full enhancement)
293
+ return cv2.addWeighted(original_frame, 1 - blend_ratio, enhanced_frame, blend_ratio, 0)
294
+ except Exception as e:
295
+ print(f"⚠️ Frame blending failed: {e}")
296
+ return enhanced_frame
297
 
298
 
299
  def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame:
300
  """
301
  Main processing function (maintains your original interface)
302
  """
303
+ try:
304
+ return enhance_frame(temp_frame)
305
+ except Exception as e:
306
+ print(f"⚠️ Error in process_frame: {e}")
307
+ return temp_frame
308
 
309
 
310
  def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
311
  """
312
  Process multiple frames (maintains your original interface)
313
  """
314
+ try:
315
+ processor = get_frame_processor()
316
+ if processor is None:
317
+ print("⚠️ Frame enhancer not available, skipping frame enhancement")
318
+ if update:
319
+ update()
320
+ return
321
+
322
+ for temp_frame_path in temp_frame_paths:
323
+ try:
324
+ temp_frame = cv2.imread(temp_frame_path)
325
+ if temp_frame is not None:
326
+ result_frame = process_frame(None, None, temp_frame)
327
+ cv2.imwrite(temp_frame_path, result_frame)
328
+ else:
329
+ print(f"⚠️ Failed to read frame: {temp_frame_path}")
330
+
331
+ except Exception as e:
332
+ print(f"⚠️ Error processing frame {temp_frame_path}: {e}")
333
+
334
+ if update:
335
+ update()
336
+
337
+ except Exception as e:
338
+ print(f"⚠️ Error in process_frames: {e}")
339
 
340
 
341
  def process_image(source_path: str, target_path: str, output_path: str) -> None:
342
  """
343
  Process single image (maintains your original interface)
344
  """
345
+ try:
346
+ processor = get_frame_processor()
347
+ if processor is None:
348
+ print("⚠️ Frame enhancer not available, copying original image")
349
+ import shutil
350
+ shutil.copy2(target_path, output_path)
351
+ return
352
+
353
+ target_frame = cv2.imread(target_path)
354
+ if target_frame is not None:
355
+ result = process_frame(None, None, target_frame)
356
+ cv2.imwrite(output_path, result)
357
+ else:
358
+ print(f"⚠️ Failed to read image: {target_path}")
359
+
360
+ except Exception as e:
361
+ print(f"⚠️ Error in process_image: {e}")
362
 
363
 
364
  def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
365
  """
366
  Process video frames (maintains your original interface)
367
  """
368
+ try:
369
+ frame_processors.process_video(None, temp_frame_paths, process_frames)
370
+ except Exception as e:
371
+ print(f"⚠️ Error in process_video: {e}")
372
 
373
 
374
  # Additional utility functions inspired by FaceFusion
375
  def get_model_scale() -> int:
376
  """Get the current model's scale factor"""
377
+ try:
378
+ return get_model_config()['real_esrgan_x4']['scale']
379
+ except:
380
+ return 1
381
 
382
 
383
  def prepare_frame(frame: Frame) -> Frame:
384
  """Prepare frame for processing"""
385
+ try:
386
+ if frame.dtype != numpy.uint8:
387
+ frame = frame.astype(numpy.uint8)
388
+ return frame
389
+ except:
390
+ return frame
391
 
392
 
393
  def normalize_frame(frame: Frame) -> Frame:
394
  """Normalize frame after processing"""
395
+ try:
396
+ return numpy.clip(frame, 0, 255).astype(numpy.uint8)
397
+ except:
398
+ return frame