Spaces:
Paused
Paused
| from typing import Any, List, Callable | |
| import cv2 | |
| import threading | |
| import numpy as np | |
| import os | |
# Environment fixes, applied before the SwitcherAI imports that follow.
# NOTE(review): presumably these must be in place before protobuf / albumentations
# are first imported (pure-Python protobuf backend, no albumentations update
# check) — confirm.
os.environ.update({
    'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION': 'python',
    'NO_ALBUMENTATIONS_UPDATE': '1',
})
| import SwitcherAI.globals | |
| import SwitcherAI.processors.frame.core as frame_processors | |
| from SwitcherAI import wording | |
| from SwitcherAI.core import update_status | |
| from SwitcherAI.face_analyser import get_many_faces, get_one_face | |
| from SwitcherAI.typing import Frame, Face | |
| from SwitcherAI.utilities import conditional_download, resolve_relative_path, is_image, is_video | |
# Module-level state shared by the functions of this frame processor.
FRAME_PROCESSOR: Any = None  # cached model session, created lazily by get_frame_processor()
THREAD_SEMAPHORE = threading.Semaphore(value=1)  # one holder at a time; forward() wraps inference in it
THREAD_LOCK = threading.Lock()  # held by get_frame_processor() while it creates the session
NAME = 'FACEFUSION.FRAME_PROCESSOR.LIP_SYNCER'  # prefix for log lines and status updates
def _load_lip_syncer_session() -> Any:
    """Create an ONNX Runtime session for the configured lip-sync model.

    The model name comes from ``SwitcherAI.globals.lip_syncer_model`` (default
    ``'wav2lip_gan_96'``) and the execution providers from
    ``SwitcherAI.globals.execution_providers`` (default CPU only).

    Returns:
        The ``onnxruntime.InferenceSession``, or None when the model file is missing,
        onnxruntime cannot be imported, or session creation raises.
    """
    try:
        model_name = getattr(SwitcherAI.globals, 'lip_syncer_model', 'wav2lip_gan_96')
        model_path = resolve_relative_path(f'../.assets/models/{model_name}.onnx')
        print(f"[{NAME}] Loading model: {model_path}")
        if not os.path.exists(model_path):
            print(f"[{NAME}] Model file not found: {model_path}")
            return None
        # Imported lazily so the module still loads when onnxruntime is absent.
        import onnxruntime
        providers = getattr(SwitcherAI.globals, 'execution_providers', ['CPUExecutionProvider'])
        session = onnxruntime.InferenceSession(model_path, providers=providers)
        print(f"[{NAME}] ONNX model loaded successfully")
        return session
    except ImportError:
        print(f"[{NAME}] onnxruntime not available, using passthrough mode")
        return None
    except Exception as e:
        print(f"[{NAME}] Error loading ONNX model: {e}")
        return None


def get_frame_processor() -> Any:
    """Return the shared lip-sync model session, creating it on first use.

    Creation happens under THREAD_LOCK so concurrent callers build at most one
    session. Returns None (passthrough mode) when the model could not be loaded.

    NOTE(review): a failed load is not remembered — while FRAME_PROCESSOR is None,
    every call retries the load and logs again.
    """
    global FRAME_PROCESSOR
    with THREAD_LOCK:
        if FRAME_PROCESSOR is None:
            FRAME_PROCESSOR = _load_lip_syncer_session()
    return FRAME_PROCESSOR
def clear_frame_processor() -> None:
    """Drop the cached lip-sync session so the next get_frame_processor() call reloads it.

    Takes THREAD_LOCK, the lock get_frame_processor() holds while creating the
    session, so a reset cannot interleave with an in-progress load.
    """
    global FRAME_PROCESSOR
    with THREAD_LOCK:
        FRAME_PROCESSOR = None
def pre_check() -> bool:
    """Ensure the configured lip-sync model file exists, downloading it when a URL is known.

    The model name is read from ``SwitcherAI.globals.lip_syncer_model`` (default
    ``'wav2lip_gan_96'``) and expected at ``../.assets/models/<name>.onnx``.

    Returns:
        Always True. A missing model is deliberately not fatal: get_frame_processor()
        returns None in that case and lip syncing runs in passthrough mode.
    """
    print(f"[{NAME}] Pre-check starting...")
    # FaceFusion asset release URLs for the supported model names.
    model_urls = {
        'wav2lip_96': [
            'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/wav2lip_96.onnx'
        ],
        'wav2lip_gan_96': [
            'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/wav2lip_gan_96.onnx'
        ],
    }
    try:
        download_directory_path = resolve_relative_path('../.assets/models')
        model_name = getattr(SwitcherAI.globals, 'lip_syncer_model', 'wav2lip_gan_96')
        model_path = os.path.join(download_directory_path, f'{model_name}.onnx')
        if not os.path.exists(model_path):
            print(f"[{NAME}] Model not found: {model_path}")
            if model_name in model_urls:
                print(f"[{NAME}] Attempting to download {model_name}")
                # NOTE(review): assumes conditional_download stores each file under its
                # URL basename (i.e. <model_name>.onnx) in download_directory_path — confirm.
                conditional_download(download_directory_path, model_urls[model_name])
            else:
                print(f"[{NAME}] No download URL known for model: {model_name}")
            # Report the real outcome instead of unconditionally claiming success.
            if not os.path.exists(model_path):
                print(f"[{NAME}] Model still unavailable, lip syncing will run in passthrough mode")
                return True
        print(f"[{NAME}] Pre-check passed")
        return True
    except Exception as e:
        print(f"[{NAME}] Pre-check error: {e}")
        return True
def pre_process() -> bool:
    """Check that the configured target is an image or a video before processing starts.

    Returns:
        True when ``SwitcherAI.globals.target_path`` is an image or a video. Otherwise
        posts a "select an image or video target" status message and returns False.
    """
    print(f"[{NAME}] Pre-processing...")
    target_path = SwitcherAI.globals.target_path
    if is_image(target_path) or is_video(target_path):
        print(f"[{NAME}] Pre-processing completed")
        return True
    update_status(wording.get('select_image_or_video_target') + wording.get('exclamation_mark'), NAME)
    return False
def post_process() -> None:
    """Release the cached model session once a processing run has finished."""
    clear_frame_processor()
    done_message = f"[{NAME}] Post-processing completed"
    print(done_message)
def prepare_audio_frame(audio_frame: np.ndarray) -> np.ndarray:
    """Normalise a raw mel spectrogram into the 4-D float32 tensor fed to the model.

    Values are floored at ~1e-5 (so log10 never sees 0), mapped to
    ``1.6 * log10(x) + 3.2``, clipped to [-4, 4], cast to float32, and given two
    leading singleton axes: an (H, W) input becomes (1, 1, H, W).
    """
    floor = np.exp(-5 * np.log(10))
    scaled = np.log10(np.maximum(floor, audio_frame)) * 1.6 + 3.2
    clipped = np.clip(scaled, -4, 4).astype(np.float32)
    return clipped[np.newaxis, np.newaxis]
def prepare_crop_frame(crop_vision_frame: np.ndarray) -> np.ndarray:
    """Build the 6-channel, channels-first float32 model input from an HWC face crop.

    Channels 0-2 hold the crop with every row from index 48 down set to zero (the
    lower half of a 96-pixel-high crop). Channels 3-5 hold the untouched crop. All
    values are divided by 255. An (H, W, 3) input yields (1, 6, H, W).
    """
    original = np.transpose(crop_vision_frame, (2, 0, 1))[np.newaxis]
    masked = original.copy()
    masked[:, :, 48:, :] = 0
    stacked = np.concatenate((masked, original), axis=1)
    return stacked.astype('float32') / 255.0
def normalize_close_frame(crop_vision_frame: np.ndarray) -> np.ndarray:
    """Turn the first item of a channels-first float batch back into an HWC uint8 image.

    Values are clipped to [0, 1] and scaled by 255. The uint8 cast truncates rather
    than rounds.
    """
    hwc_frame = np.moveaxis(crop_vision_frame[0], 0, -1)
    return (np.clip(hwc_frame, 0, 1) * 255).astype(np.uint8)
def forward(temp_audio_frame: np.ndarray, close_vision_frame: np.ndarray) -> np.ndarray:
    """Run one lip-sync inference on a prepared audio/face pair.

    Args:
        temp_audio_frame: tensor produced by ``prepare_audio_frame``.
        close_vision_frame: 6-channel, channels-first tensor produced by
            ``prepare_crop_frame``. Channels 0-2 hold the masked crop and channels
            3-5 the untouched crop.

    Model inputs are matched by (case-insensitive) name. Names containing 'source',
    'audio' or 'mel' receive the audio tensor. Names containing 'target', 'video' or
    'frame' receive the face tensor.

    Returns:
        The model's first output. If no model is loaded, or inference raises, the
        untouched crop (channels 3-5) is returned instead, so the caller still gets a
        3-channel face that can be pasted back into the frame.
    """
    # Fallback result: the unmasked half of the input. Returning the full 6-channel
    # tensor would give the caller an image that cannot be pasted into a 3-channel frame.
    passthrough_frame = close_vision_frame[:, 3:]
    lip_syncer = get_frame_processor()
    if lip_syncer is None:
        return passthrough_frame
    try:
        # Semaphore admits one inference at a time.
        with THREAD_SEMAPHORE:
            inputs = {}
            for model_input in lip_syncer.get_inputs():
                lowered = model_input.name.lower()
                if 'source' in lowered or 'audio' in lowered or 'mel' in lowered:
                    inputs[model_input.name] = temp_audio_frame
                elif 'target' in lowered or 'video' in lowered or 'frame' in lowered:
                    inputs[model_input.name] = close_vision_frame
            return lip_syncer.run(None, inputs)[0]
    except Exception as e:
        print(f"[{NAME}] Forward pass error: {e}")
        return passthrough_frame
def sync_lip(target_face: Face, temp_audio_frame: np.ndarray, temp_vision_frame: Frame) -> Frame:
    """Lip-sync one face in a frame.

    The face's ``bbox`` (x1, y1, x2, y2) is clamped to the frame, cropped, resized to
    96x96, run through the model via ``forward`` and pasted back at its original size.

    Args:
        target_face: detected face; only its ``bbox`` attribute is used.
        temp_audio_frame: raw mel spectrogram for this frame, or None to use an
            all-zero 80x16 placeholder.
        temp_vision_frame: frame to process; it is never modified in place.

    Returns:
        A modified copy of ``temp_vision_frame``, or ``temp_vision_frame`` itself when
        no model is loaded, the face has no usable box, or any step fails.
    """
    try:
        # Without a model the crop/resize round trip can only degrade the face.
        if get_frame_processor() is None:
            return temp_vision_frame
        bbox = getattr(target_face, 'bbox', None)
        if bbox is None:
            return temp_vision_frame
        x1, y1, x2, y2 = map(int, bbox)
        height, width = temp_vision_frame.shape[:2]
        # Start coordinates must be valid indices. End coordinates are exclusive slice
        # bounds and may equal the frame size; clamping them to size-1 would drop the
        # last column/row of faces touching the right/bottom edge.
        x1 = max(0, min(x1, width - 1))
        y1 = max(0, min(y1, height - 1))
        x2 = max(0, min(x2, width))
        y2 = max(0, min(y2, height))
        if x2 <= x1 or y2 <= y1:
            return temp_vision_frame
        if temp_audio_frame is None:
            # 80 mel features x 16 time steps, all zero.
            temp_audio_frame = np.zeros((80, 16), dtype=np.float32)
        temp_audio_frame = prepare_audio_frame(temp_audio_frame)
        close_vision_frame = cv2.resize(temp_vision_frame[y1:y2, x1:x2], (96, 96))
        close_vision_frame = prepare_crop_frame(close_vision_frame)
        close_vision_frame = forward(temp_audio_frame, close_vision_frame)
        close_vision_frame = normalize_close_frame(close_vision_frame)
        # cv2.resize expects dsize as (width, height).
        close_vision_frame = cv2.resize(close_vision_frame, (x2 - x1, y2 - y1))
        result_frame = temp_vision_frame.copy()
        result_frame[y1:y2, x1:x2] = close_vision_frame
        return result_frame
    except Exception as e:
        print(f"[{NAME}] Lip sync error: {e}")
        return temp_vision_frame
def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame:
    """Lip-sync every face detected in a single frame.

    ``source_face`` and ``reference_face`` are not used here (presumably kept for
    signature parity with other frame processors). Every face is synced against the
    same all-zero 80x16 mel spectrogram placeholder.

    Returns:
        The processed frame, or ``temp_frame`` unchanged when no face is found or an
        error occurs.
    """
    try:
        faces = get_many_faces(temp_frame)
        if not faces:
            return temp_frame
        # Placeholder audio; sync_lip does not modify it, so one array serves every face.
        placeholder_audio = np.zeros((80, 16), dtype=np.float32)
        output_frame = temp_frame
        for face in faces:
            output_frame = sync_lip(face, placeholder_audio, output_frame)
        return output_frame
    except Exception as error:
        print(f"[{NAME}] Error processing frame: {error}")
        return temp_frame
def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
    """Lip-sync a list of frame image files in place.

    Each path is read with OpenCV, processed by ``process_frame`` and written back to
    the same path. ``source_path`` is unused.

    Args:
        source_path: unused.
        temp_frame_paths: image files to process and overwrite.
        update: optional progress callback. It is called exactly once per path, also
            for frames that could not be read or failed to process, so a progress bar
            driven by it always reaches the total.
    """
    total_frames = len(temp_frame_paths)
    print(f"[{NAME}] Processing {total_frames} frames")
    for i, temp_frame_path in enumerate(temp_frame_paths):
        try:
            temp_frame = cv2.imread(temp_frame_path)
            if temp_frame is None:
                print(f"[{NAME}] Could not read frame {i}: {temp_frame_path}")
            else:
                result_frame = process_frame(None, None, temp_frame)
                # cv2.imwrite reports failure via its return value.
                if not cv2.imwrite(temp_frame_path, result_frame):
                    print(f"[{NAME}] Could not write frame {i}: {temp_frame_path}")
        except Exception as e:
            print(f"[{NAME}] Error processing frame {i}: {e}")
        if update:
            update()
        done = i + 1
        if done % 100 == 0 or done == total_frames:
            print(f"[{NAME}] Progress: {done}/{total_frames} frames")
    print(f"[{NAME}] Frame processing completed")
def _copy_original(target_path: str, output_path: str) -> None:
    """Fall back to the unprocessed image by copying target_path to output_path.

    When both paths name the same file the original is already in place, and
    ``shutil.copy2`` would raise ``SameFileError``, so that case is a no-op.
    """
    import shutil
    try:
        shutil.copy2(target_path, output_path)
    except shutil.SameFileError:
        pass


def process_image(source_path: str, target_path: str, output_path: str) -> None:
    """Lip-sync all faces in one image file and save the result.

    ``source_path`` is unused. If the image cannot be read, processed or written,
    the original image is copied to ``output_path`` instead (skipped when
    ``output_path`` is the same file as ``target_path``).
    """
    try:
        print(f"[{NAME}] Processing image: {os.path.basename(target_path)}")
        target_frame = cv2.imread(target_path)
        if target_frame is None:
            print(f"[{NAME}] Could not read image, keeping original: {target_path}")
            _copy_original(target_path, output_path)
            return
        result_frame = process_frame(None, None, target_frame)
        # cv2.imwrite may signal failure by returning False instead of raising.
        if not cv2.imwrite(output_path, result_frame):
            print(f"[{NAME}] Could not write image, keeping original: {output_path}")
            _copy_original(target_path, output_path)
            return
        print(f"[{NAME}] Image processing completed")
    except Exception as e:
        print(f"[{NAME}] Error processing image: {e}")
        _copy_original(target_path, output_path)
def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
    """Process extracted video frames by delegating to the shared frame-processor core.

    This module's ``process_frames`` is handed to the core, which presumably invokes
    it with lists of frame paths to process.
    """
    core_process_video = frame_processors.process_video
    core_process_video(source_path, temp_frame_paths, process_frames)