SwapMe / src /processor.py
Help
make models download at buildtime not runtime
7cfbe48
import gc
import os, io, base64, requests, cv2
import traceback
import threading
import numpy as np
import torch
import onnxruntime as ort
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
from gfpgan import GFPGANer
import insightface
from insightface.model_zoo.inswapper import INSwapper
from functools import lru_cache
from src.config import (
INSWAPPER_MODEL_PATH, HF_TOKEN, INSIGHTFACE_MODELS_DIR,
GFPGAN_MODELS_DIR, TORCH_NUM_THREADS, ONNX_INTRA_OP_THREADS,
DOWNLOAD_TIMEOUT, DEBUG_MODE
)
# --- CONFIG & INITIALIZATION ---
# Cap torch's CPU thread pool at the configured size.
torch.set_num_threads(TORCH_NUM_THREADS)
# Use the GPU when torch can see one; otherwise everything runs on CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Module-level alias for the configured inswapper model path.
INSWAPPER_PATH = INSWAPPER_MODEL_PATH
# Shared onnxruntime session options: bounded intra-op threading plus the
# full set of graph optimizations.
sess_opts = ort.SessionOptions()
sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_opts.intra_op_num_threads = ONNX_INTRA_OP_THREADS
# Face analysis pipeline (insightface 'buffalo_l'), used for every face
# detection in this module. If loading fails, the module still imports and the
# detection helpers see face_analyser=None.
# NOTE(review): INSIGHTFACE_MODELS_DIR is imported from src.config but is not
# passed as FaceAnalysis(root=...), so insightface falls back to its default
# model root -- confirm that is where the build step places buffalo_l.
try:
    face_analyser = insightface.app.FaceAnalysis(name='buffalo_l')
    face_analyser.prepare(
        ctx_id=(0 if device == 'cuda' else -1),  # -1 on the non-CUDA path
        det_size=(640, 640),
    )
except Exception as err:
    print(f"CRITICAL: FaceAnalysis failed: {err}")
    face_analyser = None

# face_lock guards detector calls: the detection threshold is mutable state on
# the shared face_analyser. model_lock guards the swap/restore models.
face_lock = threading.Lock()
model_lock = threading.Lock()
# Face-swap model. The ONNX session is built here so that the shared sess_opts
# and the provider list apply; the session is then handed to INSwapper.
try:
    providers = ['CPUExecutionProvider']
    if device == 'cuda':
        # CUDA first, with CPU listed after it as the fallback provider.
        providers = ['CUDAExecutionProvider'] + providers
    swapper_session = ort.InferenceSession(INSWAPPER_PATH, sess_opts, providers=providers)
    inswapper_model = INSwapper(model_file=INSWAPPER_PATH, session=swapper_session)
    print("Inswapper model loaded successfully.")
except Exception as err:
    # Keep the module importable; the process_* functions skip swapping when
    # inswapper_model is None.
    print(f"Model Load Error: {err}")
    inswapper_model = None
# GFPGAN face restorer, used only when a caller passes use_gfpgan=True.
# Prefer a GFPGANv1.3.pth that already exists under GFPGAN_MODELS_DIR (e.g.
# fetched at build time). Otherwise fall back to the release URL, as before.
# NOTE(review): GFPGANer may still fetch its helper face-detection/parsing
# weights at runtime -- confirm those are also pre-fetched during the build.
try:
    _gfpgan_model_path = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
    if GFPGAN_MODELS_DIR:
        _gfpgan_local = os.path.join(GFPGAN_MODELS_DIR, 'GFPGANv1.3.pth')
        if os.path.isfile(_gfpgan_local):
            _gfpgan_model_path = _gfpgan_local
    restorer = GFPGANer(
        model_path=_gfpgan_model_path,
        upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
    GFPGAN_AVAILABLE = True
except Exception as e:
    print(f'GFPGAN load failed: {e}')
    # Keep the name bound so anything referencing `restorer` gets None rather
    # than a NameError / ImportError when loading failed.
    restorer = None
    GFPGAN_AVAILABLE = False
# Shared thread pool for parallel work
_thread_pool = ThreadPoolExecutor(max_workers=4)
# --- UTILITIES ---
def clear_memory():
    """Run Python's garbage collector, then release torch's cached CUDA memory if CUDA is present."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def set_det_thresh(thresh):
    """Set the detection score threshold on the shared insightface detector.

    Writes ``det_thresh`` on ``face_analyser.det_model`` when that attribute
    exists, otherwise on ``face_analyser.models['detection']``. Does nothing
    when the analyser failed to load. Errors are printed, not raised.
    """
    if face_analyser is None:
        return
    try:
        if hasattr(face_analyser, 'det_model'):
            detector = face_analyser.det_model
        elif 'detection' in face_analyser.models:
            detector = face_analyser.models['detection']
        else:
            return
        detector.det_thresh = thresh
    except Exception as err:
        print(f"Threshold Error: {err}")
def enhance_for_detection(bgr):
    """Return a contrast-boosted, sharpened copy of a BGR frame to help the detector.

    CLAHE (clip limit 3.0, 8x8 tiles) is applied to the L channel in LAB space.
    The result is converted back to BGR and then sharpened with a 3x3 kernel.
    """
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    l_chan, a_chan, b_chan = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB))
    merged = cv2.merge([clahe.apply(l_chan), a_chan, b_chan])
    contrasted = cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)
    sharpen = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    return cv2.filter2D(contrasted, -1, sharpen)
@lru_cache(maxsize=10)
def _detect_source_face(url):
    """Download *url* and return its largest detected face, or None if none is found.

    Download, HTTP and decode errors propagate as exceptions. ``lru_cache``
    never memoises exceptions, so a transient failure is retried on the next
    call. Only a completed analysis (face found, or genuinely no face / an
    undecodable image) is cached.
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36',
        'Accept': 'image/avif,image/webp,image/apng,image/*,*/*;q=0.8'
    }
    resp = requests.get(url, headers=headers, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True)
    resp.raise_for_status()
    img = cv2.imdecode(np.frombuffer(resp.content, np.uint8), cv2.IMREAD_COLOR)
    if img is None or img.size == 0:
        return None
    with face_lock:
        # The threshold is shared detector state, so set it under the same lock
        # as the detection calls.
        set_det_thresh(0.20)
        faces = face_analyser.get(img)
        if not faces:
            # Second chance on a contrast-enhanced copy.
            faces = face_analyser.get(enhance_for_detection(img))
    if not faces:
        return None
    # Largest bbox area wins; ties keep the detector's first face.
    return max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))


def get_source_embedding(url):
    """Return the largest face found in the image at *url*, or None.

    The returned face object is what the process_* functions pass to
    ``INSwapper.get`` as the source face. Results are memoised per URL (up to
    10 entries), but failed downloads are not cached, so they are retried.

    Args:
        url: HTTP(S) URL of the source image.

    Returns:
        The detected face, or None when the analyser is unavailable, no face is
        found, or the download/analysis fails (the traceback is printed).
    """
    if face_analyser is None:
        return None
    try:
        return _detect_source_face(url)
    except Exception:
        print(f"Source Download/Analysis Error: {traceback.format_exc()}")
        return None


# Keep the lru_cache management helpers reachable under the public name.
get_source_embedding.cache_info = _detect_source_face.cache_info
get_source_embedding.cache_clear = _detect_source_face.cache_clear
def get_best_face(bgr, thresh=0.15):
    """Detect faces in a BGR frame at score threshold *thresh* and return the largest.

    The raw frame is tried first. If it yields nothing, a contrast-enhanced copy
    from ``enhance_for_detection`` is tried. Returns None when the analyser is
    unavailable or no face is found.
    """
    if face_analyser is None:
        return None
    with face_lock:
        # set_det_thresh mutates the shared detector, so it and the detection
        # calls happen under one lock.
        set_det_thresh(thresh)
        faces = face_analyser.get(bgr) or face_analyser.get(enhance_for_detection(bgr))
    if not faces:
        return None
    return max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
def _pick_largest(faces):
if not faces: return None
return sorted(faces, key=lambda x: (x.bbox[2]-x.bbox[0])*(x.bbox[3]-x.bbox[1]), reverse=True)[0]
@lru_cache(maxsize=5)
def _get_bgr_frames_cached(url, max_frames):
    """Download an image/GIF and decode its frames to BGR arrays (memoised, 5 entries).

    Args:
        url: HTTP(S) URL of the target image or animated GIF.
        max_frames: stop after this many frames; falsy means decode every frame.

    Returns:
        ``(frames, durations)``: a tuple of BGR uint8 arrays and a tuple of
        per-frame durations in milliseconds (67 ms, about 15 fps, when a frame
        does not specify one). The arrays are shared by every caller that hits
        the cache, so treat them as read-only.

    Raises:
        Download/HTTP errors from ``requests`` and decode errors from Pillow.
        Exceptions are not cached, so a failed download is retried next time.
    """
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
    resp = requests.get(url, headers=headers, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True)
    resp.raise_for_status()
    frames, durations = [], []
    # Context manager closes the decoder as soon as every frame has been copied
    # out (np.array copies the pixel data).
    with Image.open(io.BytesIO(resp.content)) as gif:
        try:
            while True:
                frames.append(cv2.cvtColor(np.array(gif.convert('RGB')), cv2.COLOR_RGB2BGR))
                durations.append(gif.info.get('duration', 67))
                if max_frames and len(frames) >= max_frames:
                    break
                gif.seek(gif.tell() + 1)
        except EOFError:
            # Seeking past the last frame ends the sequence (a still image has one frame).
            pass
    return tuple(frames), tuple(durations)
def get_bgr_frames_cached(url, max_frames=None):
    """Return ``(frames, durations)`` for *url* through the memoised loader.

    Thin public wrapper that makes *max_frames* optional. Arguments are passed
    positionally so the cache key matches the loader's.
    """
    decoded = _get_bgr_frames_cached(url, max_frames)
    return decoded
def encode_gif(output_frames, durations):
    """Encode frames as a looping animated GIF and return it base64-encoded.

    Args:
        output_frames: sequence of uint8 arrays accepted by ``Image.fromarray``
            (the process_* functions pass RGB frames).
        durations: display times in milliseconds, normally one per frame. When
            the count matches the frames, each frame keeps its own timing.
            Otherwise the average is applied to every frame.

    Returns:
        The GIF bytes as a base64 ``str``.

    Raises:
        ValueError: if there are no frames or no durations.
    """
    if len(output_frames) == 0:
        raise ValueError("encode_gif: no frames to encode.")
    durations = list(durations)
    if not durations:
        raise ValueError("encode_gif: no frame durations supplied.")
    pil_frames = [Image.fromarray(frame) for frame in output_frames]
    if len(durations) == len(pil_frames):
        # Per-frame list preserves variable timing from the source GIF.
        duration = [int(d) for d in durations]
    else:
        # Frame count changed upstream, so the list no longer lines up.
        duration = int(np.mean(durations))
    buf = io.BytesIO()
    pil_frames[0].save(
        buf, format='GIF', save_all=True,
        append_images=pil_frames[1:], loop=0,
        duration=duration,
    )
    return base64.b64encode(buf.getvalue()).decode()
# --- PREPROCESS HELPERS ---
# Gamma LUT is stateless and safe to share
_gamma_lut = np.array([((i / 255.0) ** (1.0 / 1.2)) * 255 for i in range(256)], dtype=np.uint8)
# Sharpen kernels are read-only numpy arrays, safe to share
_sharpen_kernels = {
'light': np.array([[0, -0.3, 0], [-0.3, 2.2, -0.3], [0, -0.3, 0]]),
'medium': np.array([[0, -0.5, 0], [-0.5, 3.0, -0.5], [0, -0.5, 0]]),
'heavy': np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]),
}
def _make_clahe(strength):
    """Build a new CLAHE operator (8x8 tiles) whose clip limit depends on *strength*.

    'light' gives 1.5, 'heavy' gives 4.0, and 'medium' or any unknown value gives 2.5.
    """
    # A fresh object per call: the original author notes that cv2.CLAHE
    # instances are not thread-safe, so none is shared.
    clip_by_strength = {'light': 1.5, 'medium': 2.5, 'heavy': 4.0}
    clip = clip_by_strength.get(strength, 2.5)
    return cv2.createCLAHE(clipLimit=clip, tileGridSize=(8, 8))
def preprocess_frame(bgr, strength='medium'):
    """Return a detection-friendly copy of a BGR frame.

    Pipeline: CLAHE on the LAB lightness channel, then a 3x3 sharpen, then (for
    'heavy' only) the gamma LUT. Frame size is unchanged.

    Args:
        bgr: BGR uint8 image; None or an empty array is returned as-is (same object).
        strength: 'light', 'medium' or 'heavy'; unknown values use the medium settings.
    """
    if bgr is None or bgr.size == 0:
        return bgr
    kernel = _sharpen_kernels.get(strength, _sharpen_kernels['medium'])
    l_chan, a_chan, b_chan = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB))
    boosted_l = _make_clahe(strength).apply(l_chan)
    contrasted = cv2.cvtColor(cv2.merge([boosted_l, a_chan, b_chan]), cv2.COLOR_LAB2BGR)
    result = cv2.filter2D(contrasted, -1, kernel)
    return cv2.LUT(result, _gamma_lut) if strength == 'heavy' else result
# --- CORE PROCESSING FUNCTIONS ---
def process_universal(source_url, target_url, use_gfpgan=False, max_frames=None):
    """Swap the face from *source_url* onto every frame of *target_url*.

    Detection uses ``get_best_face`` at its default threshold. When a frame has
    no detection, the most recently detected face is reused for up to 10
    consecutive misses, which bridges short detection drop-outs.

    Args:
        source_url: URL of the image supplying the identity.
        target_url: URL of the image/GIF to modify.
        use_gfpgan: also run GFPGAN restoration on each swapped frame when available.
        max_frames: optional cap on how many target frames are decoded.

    Returns:
        Base64-encoded animated GIF from ``encode_gif``.

    Raises:
        ValueError: if no source face could be obtained.
    """
    source_face = get_source_embedding(source_url)
    if source_face is None:
        raise ValueError("Could not detect a face in the source image or source URL failed.")

    frames, durations = get_bgr_frames_cached(target_url, max_frames)
    max_misses = 10
    rgb_frames = []
    last_face = None
    misses = 0

    for idx, frame in enumerate(frames):
        if frame is None:
            continue

        found = get_best_face(frame)
        if found:
            last_face, misses = found, 0
            target_face = found
        else:
            misses += 1
            target_face = last_face if misses <= max_misses else None

        if target_face and inswapper_model:
            try:
                with model_lock:
                    frame = inswapper_model.get(frame, target_face, source_face, paste_back=True)
                    if use_gfpgan and GFPGAN_AVAILABLE:
                        _, _, frame = restorer.enhance(
                            frame, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as exc:
                # Best effort: keep whatever frame we have and carry on.
                print(f"Swap Error Frame {idx}: {exc}")

        rgb_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if idx % 30 == 0:
            clear_memory()

    return encode_gif(rgb_frames, durations)
def process_multiscale(source_url, target_url, use_gfpgan=False, max_frames=None):
    """Face-swap a target image/GIF, retrying detection at lower thresholds.

    Each frame is searched at detection thresholds 0.15, then 0.10, then 0.05,
    stopping at the first that yields a face. Frames without a detection reuse
    the last detected face for up to 10 consecutive misses.

    Args:
        source_url: URL of the image supplying the identity.
        target_url: URL of the image/GIF to modify.
        use_gfpgan: also run GFPGAN restoration on each swapped frame when available.
        max_frames: optional cap on how many target frames are decoded.

    Returns:
        Base64-encoded animated GIF from ``encode_gif``.

    Raises:
        ValueError: if no source face could be obtained.
    """
    source_face = get_source_embedding(source_url)
    if source_face is None:
        raise ValueError("Could not detect a face in the source image.")
    frames, durations = get_bgr_frames_cached(target_url, max_frames)

    thresholds = (0.15, 0.10, 0.05)
    max_misses = 10

    def detect(frame):
        # Strict first, permissive last; the first hit wins.
        for thresh in thresholds:
            hit = get_best_face(frame, thresh=thresh)
            if hit:
                return hit
        return None

    rgb_frames = []
    last_face = None
    misses = 0

    for idx, frame in enumerate(frames):
        if frame is None:
            continue

        found = detect(frame)
        if found:
            last_face, misses = found, 0
            target_face = found
        else:
            misses += 1
            target_face = last_face if misses <= max_misses else None

        if target_face and inswapper_model:
            try:
                with model_lock:
                    frame = inswapper_model.get(frame, target_face, source_face, paste_back=True)
                    if use_gfpgan and GFPGAN_AVAILABLE:
                        _, _, frame = restorer.enhance(
                            frame, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as exc:
                print(f"Multiscale Swap Error Frame {idx}: {exc}")

        rgb_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if idx % 30 == 0:
            clear_memory()

    return encode_gif(rgb_frames, durations)
def process_landmark(source_url, target_url, use_gfpgan=False, max_frames=None, min_landmark_confidence=0.2):
    """Face-swap a target image/GIF, swapping only onto detections whose landmarks look sane.

    Each frame is searched at detection thresholds 0.15, 0.10 and 0.05 (first
    hit wins). The largest face found is then checked by
    ``landmarks_are_valid``. Invalid or missing detections fall back to the
    last valid face for up to 10 consecutive frames.

    Args:
        source_url: URL of the image supplying the identity.
        target_url: URL of the image/GIF to modify.
        use_gfpgan: also run GFPGAN restoration on each swapped frame when available.
        max_frames: optional cap on how many target frames are decoded.
        min_landmark_confidence: minimum ``det_score`` a detection needs.
            NOTE(review): with the default 0.2, faces found only at the
            0.10/0.05 thresholds will normally fail this check.

    Returns:
        Base64-encoded animated GIF from ``encode_gif``.

    Raises:
        ValueError: if no source face could be obtained.
    """
    source_face = get_source_embedding(source_url)
    if source_face is None:
        raise ValueError("Could not detect a face in the source image.")
    bgr_frames_raw, durations = get_bgr_frames_cached(target_url, max_frames)

    def landmarks_are_valid(face, frame_shape):
        """Return True when *face* looks like a usable swap target in a frame of *frame_shape*.

        Rejects the face if any of these hold:
        - there is no face;
        - it has fewer than 5 keypoints;
        - a keypoint lies more than 10 px outside the frame;
        - its bbox covers under 1% of the frame;
        - ``abs(pose[2]) > 60``;
        - its ``det_score`` is below ``min_landmark_confidence``.
        A missing/None ``det_score`` is not treated as a rejection.
        """
        if face is None:
            return False
        h, w = frame_shape[:2]
        kps = getattr(face, 'kps', None)
        if kps is None or len(kps) < 5:
            return False
        MARGIN = 10  # pixels of tolerance outside the frame edge
        for x, y in kps:
            if not (-MARGIN <= x <= w + MARGIN and -MARGIN <= y <= h + MARGIN):
                return False
        bx1, by1, bx2, by2 = face.bbox
        face_area = (bx2 - bx1) * (by2 - by1)
        if face_area < 0.01 * w * h:
            return False
        # NOTE(review): if pose is insightface's (pitch, yaw, roll) this limits
        # roll; confirm whether yaw (pose[1]) was intended.
        pose = getattr(face, 'pose', None)
        if pose is not None and len(pose) >= 3 and abs(pose[2]) > 60:
            return False
        score = getattr(face, 'det_score', None)
        if score is not None and score < min_landmark_confidence:
            return False
        return True

    output_frames = []
    locked_face = None
    consecutive_no_face = 0
    NO_FACE_LIMIT = 10
    THRESHOLDS = [0.15, 0.10, 0.05]
    for i, bgr in enumerate(bgr_frames_raw):
        if bgr is None:
            continue
        current_face = None
        for thresh in THRESHOLDS:
            current_face = get_best_face(bgr, thresh=thresh)
            if current_face:
                break
        face_valid = landmarks_are_valid(current_face, bgr.shape)
        if face_valid:
            locked_face, consecutive_no_face = current_face, 0
        else:
            consecutive_no_face += 1
            if current_face is not None:
                score = getattr(current_face, 'det_score', '?')
                kps = getattr(current_face, 'kps', None)
                pose = getattr(current_face, 'pose', None)
                # Format defensively: a non-numeric score must not turn a log
                # line into an exception that aborts the whole request.
                try:
                    score_txt = f"{score:.2f}"
                except (TypeError, ValueError):
                    score_txt = str(score)
                print(f"Frame {i}: face found but landmark invalid — "
                      f"score={score_txt}, pose={pose}, kps_count={len(kps) if kps is not None else 0}")
        face_to_use = (current_face if face_valid
                       else (locked_face if consecutive_no_face <= NO_FACE_LIMIT else None))
        if face_to_use and inswapper_model:
            try:
                with model_lock:
                    bgr = inswapper_model.get(bgr, face_to_use, source_face, paste_back=True)
                    if use_gfpgan and GFPGAN_AVAILABLE:
                        _, _, bgr = restorer.enhance(
                            bgr, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as e:
                print(f"Landmark Swap Error Frame {i}: {e}")
        output_frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        if i % 30 == 0:
            clear_memory()
    return encode_gif(output_frames, durations)
def process_preprocess(source_url, target_url, use_gfpgan=False, max_frames=None, preprocess_strength='medium'):
    """Face-swap a target image/GIF, detecting faces on an enhanced copy of each frame.

    Detection runs at threshold 0.15 on ``preprocess_frame(frame,
    preprocess_strength)``, falling back to the untouched frame. The swap is
    always applied to the untouched frame. ``preprocess_frame`` does not
    change geometry, so the detected keypoints still line up. Frames without a
    detection reuse the last detected face for up to 10 consecutive misses.

    Args:
        source_url: URL of the image supplying the identity.
        target_url: URL of the image/GIF to modify.
        use_gfpgan: also run GFPGAN restoration on each swapped frame when available.
        max_frames: optional cap on how many target frames are decoded.
        preprocess_strength: 'light', 'medium' or 'heavy' enhancement for detection.

    Returns:
        Base64-encoded animated GIF from ``encode_gif``.

    Raises:
        ValueError: if no source face could be obtained.
    """
    source_face = get_source_embedding(source_url)
    if source_face is None:
        raise ValueError("Could not detect a face in the source image.")
    frames, durations = get_bgr_frames_cached(target_url, max_frames)
    max_misses = 10

    def detect(image):
        # Called with face_lock held by the loop below.
        if not face_analyser:
            return None
        return _pick_largest(face_analyser.get(image))

    rgb_frames, last_face, misses = [], None, 0
    for idx, frame in enumerate(frames):
        if frame is None:
            continue

        # Enhanced copy is for detection only; it is never swapped onto.
        enhanced = preprocess_frame(frame, strength=preprocess_strength)
        with face_lock:
            set_det_thresh(0.15)
            found = detect(enhanced) or detect(frame)

        if found:
            last_face, misses = found, 0
            target_face = found
        else:
            misses += 1
            target_face = last_face if misses <= max_misses else None

        if target_face and inswapper_model:
            try:
                with model_lock:
                    frame = inswapper_model.get(frame, target_face, source_face, paste_back=True)
                    if use_gfpgan and GFPGAN_AVAILABLE:
                        _, _, frame = restorer.enhance(
                            frame, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as exc:
                print(f"Preprocess Swap Error Frame {idx}: {exc}")

        rgb_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if idx % 30 == 0:
            clear_memory()

    return encode_gif(rgb_frames, durations)