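"""Face-swap processing module for the Space.

Downloads a source face image and a target GIF, swaps the detected face into every
frame with insightface's inswapper model, optionally restores faces with GFPGAN, and
returns the result as a base64-encoded GIF. The four entry points (process_universal,
process_multiscale, process_landmark, process_preprocess) trade detection robustness
against speed.
"""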
import gc
import os, io, base64, requests, cv2
import traceback
import threading
import numpy as np
import torch
import onnxruntime as ort
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
from gfpgan import GFPGANer
import insightface
from insightface.model_zoo.inswapper import INSwapper
from functools import lru_cache
from src.config import (
    INSWAPPER_MODEL_PATH, HF_TOKEN, INSIGHTFACE_MODELS_DIR,
    GFPGAN_MODELS_DIR, TORCH_NUM_THREADS, ONNX_INTRA_OP_THREADS,
    DOWNLOAD_TIMEOUT, DEBUG_MODE
)
# --- CONFIG & INITIALIZATION ---
torch.set_num_threads(TORCH_NUM_THREADS)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
INSWAPPER_PATH = INSWAPPER_MODEL_PATH

# ONNX Runtime session options
sess_opts = ort.SessionOptions()
sess_opts.intra_op_num_threads = ONNX_INTRA_OP_THREADS
sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
try:
    face_analyser = insightface.app.FaceAnalysis(name='buffalo_l')
    face_analyser.prepare(ctx_id=0 if device == 'cuda' else -1, det_size=(640, 640))
except Exception as e:
    print(f"CRITICAL: FaceAnalysis failed: {e}")
    face_analyser = None

face_lock = threading.Lock()
model_lock = threading.Lock()
try:
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if device == 'cuda' else ['CPUExecutionProvider']
    swapper_session = ort.InferenceSession(INSWAPPER_PATH, sess_opts, providers=providers)
    inswapper_model = INSwapper(model_file=INSWAPPER_PATH, session=swapper_session)
    print("Inswapper model loaded successfully.")
except Exception as e:
    print(f"Model Load Error: {e}")
    inswapper_model = None
try:
    restorer = GFPGANer(
        model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
        upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
    GFPGAN_AVAILABLE = True
except Exception as e:
    print(f'GFPGAN load failed: {e}')
    GFPGAN_AVAILABLE = False

# Shared thread pool for parallel work
_thread_pool = ThreadPoolExecutor(max_workers=4)
# --- UTILITIES ---
def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def set_det_thresh(thresh):
    global face_analyser
    if face_analyser is None: return
    try:
        if hasattr(face_analyser, 'det_model'):
            face_analyser.det_model.det_thresh = thresh
        elif 'detection' in face_analyser.models:
            face_analyser.models['detection'].det_thresh = thresh
    except Exception as e:
        print(f"Threshold Error: {e}")
def enhance_for_detection(bgr):
    # CLAHE contrast boost plus light sharpening; used only to help the detector on difficult frames
    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    l = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)).apply(l)
    enhanced = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)
    return cv2.filter2D(enhanced, -1, np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]))
def get_source_embedding(url):
    # Despite the name, this returns the largest detected insightface Face object
    # (which carries the embedding used by the swapper), not a bare embedding vector.
    if face_analyser is None: return None
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36',
            'Accept': 'image/avif,image/webp,image/apng,image/*,*/*;q=0.8'
        }
        resp = requests.get(url, headers=headers, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True)
        resp.raise_for_status()
        arr = np.frombuffer(resp.content, np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is None or img.size == 0: return None
        with face_lock:
            set_det_thresh(0.20)
            faces = face_analyser.get(img)
            if not faces:
                faces = face_analyser.get(enhance_for_detection(img))
            if not faces: return None
            return sorted(faces, key=lambda x: (x.bbox[2]-x.bbox[0])*(x.bbox[3]-x.bbox[1]), reverse=True)[0]
    except Exception:
        print(f"Source Download/Analysis Error: {traceback.format_exc()}")
        return None
def get_best_face(bgr, thresh=0.15):
    if face_analyser is None: return None
    with face_lock:
        set_det_thresh(thresh)
        faces = face_analyser.get(bgr)
        if not faces:
            faces = face_analyser.get(enhance_for_detection(bgr))
        if not faces: return None
        return sorted(faces, key=lambda x: (x.bbox[2]-x.bbox[0])*(x.bbox[3]-x.bbox[1]), reverse=True)[0]

def _pick_largest(faces):
    if not faces: return None
    return sorted(faces, key=lambda x: (x.bbox[2]-x.bbox[0])*(x.bbox[3]-x.bbox[1]), reverse=True)[0]
# Assumed intent: the otherwise-unused functools.lru_cache import and the "_cached" naming
# suggest the decoded target frames were meant to be memoized per (url, max_frames);
# the maxsize here is a guess.
@lru_cache(maxsize=8)
def _get_bgr_frames_cached(url, max_frames):
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
    resp = requests.get(url, headers=headers, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True)
    resp.raise_for_status()
    gif = Image.open(io.BytesIO(resp.content))
    frames, durations = [], []
    try:
        while True:
            frame_bgr = cv2.cvtColor(np.array(gif.convert('RGB')), cv2.COLOR_RGB2BGR)
            frames.append(frame_bgr)
            durations.append(gif.info.get('duration', 67))
            if max_frames and len(frames) >= max_frames: break
            gif.seek(gif.tell() + 1)
    except EOFError: pass
    return tuple(frames), tuple(durations)

def get_bgr_frames_cached(url, max_frames=None):
    return _get_bgr_frames_cached(url, max_frames)
def encode_gif(output_frames, durations):
    buf = io.BytesIO()
    pil_frames = [Image.fromarray(f) for f in output_frames]
    # Note: all output frames share the mean duration; per-frame durations are not preserved.
    pil_frames[0].save(
        buf, format='GIF', save_all=True,
        append_images=pil_frames[1:], loop=0,
        duration=int(np.mean(durations)), quality=80
    )
    return base64.b64encode(buf.getvalue()).decode()
# --- PREPROCESS HELPERS ---
# Gamma LUT is stateless and safe to share
_gamma_lut = np.array([((i / 255.0) ** (1.0 / 1.2)) * 255 for i in range(256)], dtype=np.uint8)

# Sharpen kernels are read-only numpy arrays, safe to share
_sharpen_kernels = {
    'light': np.array([[0, -0.3, 0], [-0.3, 2.2, -0.3], [0, -0.3, 0]]),
    'medium': np.array([[0, -0.5, 0], [-0.5, 3.0, -0.5], [0, -0.5, 0]]),
    'heavy': np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]),
}

def _make_clahe(strength):
    # cv2.CLAHE objects are NOT thread-safe; always create a fresh one per call
    limits = {'light': 1.5, 'medium': 2.5, 'heavy': 4.0}
    return cv2.createCLAHE(clipLimit=limits.get(strength, 2.5), tileGridSize=(8, 8))
def preprocess_frame(bgr, strength='medium'):
    if bgr is None or bgr.size == 0:
        return bgr
    # Stage 1: CLAHE contrast in LAB space
    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    l = _make_clahe(strength).apply(l)
    enhanced = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)
    # Stage 2: Mild sharpening
    sharpened = cv2.filter2D(enhanced, -1, _sharpen_kernels.get(strength, _sharpen_kernels['medium']))
    # Stage 3: Gamma correction on heavy only
    if strength == 'heavy':
        sharpened = cv2.LUT(sharpened, _gamma_lut)
    return sharpened
# --- CORE PROCESSING FUNCTIONS ---
def process_universal(source_url, target_url, use_gfpgan=False, max_frames=None):
    # Baseline pipeline: one detection pass per frame, with a "locked face" fallback that
    # reuses the last good detection for up to NO_FACE_LIMIT consecutive missed frames.
    source_face = get_source_embedding(source_url)
    if source_face is None:
        raise ValueError("Could not detect a face in the source image, or the source URL failed.")
    bgr_frames_raw, durations = get_bgr_frames_cached(target_url, max_frames)
    output_frames = []
    locked_face = None
    consecutive_no_face = 0
    NO_FACE_LIMIT = 10
    for i, bgr in enumerate(bgr_frames_raw):
        if bgr is None: continue
        current_face = get_best_face(bgr)
        if current_face:
            locked_face, consecutive_no_face = current_face, 0
        else:
            consecutive_no_face += 1
        face_to_use = current_face or (locked_face if consecutive_no_face <= NO_FACE_LIMIT else None)
        if face_to_use and inswapper_model:
            try:
                with model_lock:
                    bgr = inswapper_model.get(bgr, face_to_use, source_face, paste_back=True)
                if use_gfpgan and GFPGAN_AVAILABLE:
                    _, _, bgr = restorer.enhance(bgr, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as e:
                print(f"Swap Error Frame {i}: {e}")
        output_frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        if i % 30 == 0: clear_memory()
    return encode_gif(output_frames, durations)
def process_multiscale(source_url, target_url, use_gfpgan=False, max_frames=None):
    # Like process_universal, but retries detection at progressively lower thresholds
    # before falling back to the locked face.
    source_face = get_source_embedding(source_url)
    if source_face is None:
        raise ValueError("Could not detect a face in the source image.")
    bgr_frames_raw, durations = get_bgr_frames_cached(target_url, max_frames)
    THRESHOLDS = [0.15, 0.10, 0.05]
    output_frames = []
    locked_face = None
    consecutive_no_face = 0
    NO_FACE_LIMIT = 10
    for i, bgr in enumerate(bgr_frames_raw):
        if bgr is None: continue
        current_face = None
        for thresh in THRESHOLDS:
            current_face = get_best_face(bgr, thresh=thresh)
            if current_face: break
        if current_face:
            locked_face, consecutive_no_face = current_face, 0
        else:
            consecutive_no_face += 1
        face_to_use = current_face or (locked_face if consecutive_no_face <= NO_FACE_LIMIT else None)
        if face_to_use and inswapper_model:
            try:
                with model_lock:
                    bgr = inswapper_model.get(bgr, face_to_use, source_face, paste_back=True)
                if use_gfpgan and GFPGAN_AVAILABLE:
                    _, _, bgr = restorer.enhance(bgr, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as e:
                print(f"Multiscale Swap Error Frame {i}: {e}")
        output_frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        if i % 30 == 0: clear_memory()
    return encode_gif(output_frames, durations)
def process_landmark(source_url, target_url, use_gfpgan=False, max_frames=None, min_landmark_confidence=0.2):
    # Multiscale detection plus landmark sanity checks: a face is only "locked" when its
    # keypoints, bounding box, pose, and detection score all look plausible.
    source_face = get_source_embedding(source_url)
    if source_face is None:
        raise ValueError("Could not detect a face in the source image.")
    bgr_frames_raw, durations = get_bgr_frames_cached(target_url, max_frames)

    def landmarks_are_valid(face, frame_shape):
        if face is None:
            return False
        h, w = frame_shape[:2]
        kps = getattr(face, 'kps', None)
        if kps is None or len(kps) < 5:
            return False
        MARGIN = 10
        for x, y in kps:
            if not (-MARGIN <= x <= w + MARGIN and -MARGIN <= y <= h + MARGIN):
                return False
        bx1, by1, bx2, by2 = face.bbox
        face_area = (bx2 - bx1) * (by2 - by1)
        if face_area < 0.01 * w * h:
            return False
        pose = getattr(face, 'pose', None)
        if pose is not None and len(pose) >= 3 and abs(pose[2]) > 60:
            return False
        score = getattr(face, 'det_score', 1.0)
        if score < min_landmark_confidence:
            return False
        return True

    output_frames = []
    locked_face = None
    consecutive_no_face = 0
    NO_FACE_LIMIT = 10
    THRESHOLDS = [0.15, 0.10, 0.05]
    for i, bgr in enumerate(bgr_frames_raw):
        if bgr is None:
            continue
        current_face = None
        for thresh in THRESHOLDS:
            current_face = get_best_face(bgr, thresh=thresh)
            if current_face:
                break
        face_valid = landmarks_are_valid(current_face, bgr.shape)
        if face_valid:
            locked_face, consecutive_no_face = current_face, 0
        else:
            consecutive_no_face += 1
            if current_face is not None:
                score = getattr(current_face, 'det_score', None)
                kps = getattr(current_face, 'kps', None)
                pose = getattr(current_face, 'pose', None)
                score_str = f"{score:.2f}" if score is not None else "?"
                print(f"Frame {i}: face found but landmarks invalid: "
                      f"score={score_str}, pose={pose}, kps_count={len(kps) if kps is not None else 0}")
        face_to_use = (current_face if face_valid
                       else (locked_face if consecutive_no_face <= NO_FACE_LIMIT else None))
        if face_to_use and inswapper_model:
            try:
                with model_lock:
                    bgr = inswapper_model.get(bgr, face_to_use, source_face, paste_back=True)
                if use_gfpgan and GFPGAN_AVAILABLE:
                    _, _, bgr = restorer.enhance(
                        bgr, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as e:
                print(f"Landmark Swap Error Frame {i}: {e}")
        output_frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        if i % 30 == 0:
            clear_memory()
    return encode_gif(output_frames, durations)
def process_preprocess(source_url, target_url, use_gfpgan=False, max_frames=None, preprocess_strength='medium'):
    # Runs detection on a contrast/sharpness-enhanced copy of each frame, then swaps on the original.
    source_face = get_source_embedding(source_url)
    if source_face is None:
        raise ValueError("Could not detect a face in the source image.")
    bgr_frames_raw, durations = get_bgr_frames_cached(target_url, max_frames)
    output_frames = []
    locked_face = None
    consecutive_no_face = 0
    NO_FACE_LIMIT = 10
    for i, bgr in enumerate(bgr_frames_raw):
        if bgr is None:
            continue
        # Preprocess only for detection; never used as the swap target
        preprocessed = preprocess_frame(bgr, strength=preprocess_strength)
        with face_lock:
            set_det_thresh(0.15)
            faces = face_analyser.get(preprocessed) if face_analyser else []
            current_face = _pick_largest(faces)
            if not current_face:
                # Fall back to raw frame detection
                faces = face_analyser.get(bgr) if face_analyser else []
                current_face = _pick_largest(faces)
        if current_face:
            locked_face, consecutive_no_face = current_face, 0
        else:
            consecutive_no_face += 1
        face_to_use = current_face or (locked_face if consecutive_no_face <= NO_FACE_LIMIT else None)
        # Always swap on the original bgr; the preprocessed frame is detection-only
        if face_to_use and inswapper_model:
            try:
                with model_lock:
                    bgr = inswapper_model.get(bgr, face_to_use, source_face, paste_back=True)
                if use_gfpgan and GFPGAN_AVAILABLE:
                    _, _, bgr = restorer.enhance(
                        bgr, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as e:
                print(f"Preprocess Swap Error Frame {i}: {e}")
        output_frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        if i % 30 == 0:
            clear_memory()
    return encode_gif(output_frames, durations)
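
# A minimal local smoke-test sketch, not part of the Space's serving path. It assumes the
# module can be run directly, that the two URLs below are placeholders to be replaced with
# real ones, and that writing the decoded result to "out.gif" is an acceptable way to
# inspect the output.
if __name__ == "__main__":
    import sys

    # Placeholder URLs: substitute a reachable face photo and an animated GIF.
    source_url = sys.argv[1] if len(sys.argv) > 1 else "https://example.com/face.jpg"
    target_url = sys.argv[2] if len(sys.argv) > 2 else "https://example.com/target.gif"

    try:
        # process_universal returns a base64-encoded GIF string.
        gif_b64 = process_universal(source_url, target_url, use_gfpgan=False, max_frames=16)
        with open("out.gif", "wb") as fh:
            fh.write(base64.b64decode(gif_b64))
        print("Wrote out.gif")
    except Exception as exc:
        print(f"Smoke test failed: {exc}")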