Spaces:
Running
Running
File size: 15,010 Bytes
4ff9d22 7cfbe48 4ff9d22 7cfbe48 4ff9d22 7cfbe48 4ff9d22 7cfbe48 4ff9d22 7cfbe48 4ff9d22 7cfbe48 4ff9d22 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 | import gc
import os, io, base64, requests, cv2
import traceback
import threading
import numpy as np
import torch
import onnxruntime as ort
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
from gfpgan import GFPGANer
import insightface
from insightface.model_zoo.inswapper import INSwapper
from functools import lru_cache
from src.config import (
INSWAPPER_MODEL_PATH, HF_TOKEN, INSIGHTFACE_MODELS_DIR,
GFPGAN_MODELS_DIR, TORCH_NUM_THREADS, ONNX_INTRA_OP_THREADS,
DOWNLOAD_TIMEOUT, DEBUG_MODE
)
# --- CONFIG & INITIALIZATION ---
# Limit torch's intra-op CPU thread count to the configured value.
torch.set_num_threads(TORCH_NUM_THREADS)
# Target device string; the FaceAnalysis ctx_id and ONNX providers are
# derived from it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
INSWAPPER_PATH = INSWAPPER_MODEL_PATH
# ONNX Runtime options for the inswapper session: full graph optimisation and
# a configured intra-op thread count.
sess_opts = ort.SessionOptions()
sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_opts.intra_op_num_threads = ONNX_INTRA_OP_THREADS
# Load the insightface 'buffalo_l' analysis pack. On any failure the service
# keeps running with face_analyser = None; the helpers check for that.
try:
    face_analyser = insightface.app.FaceAnalysis(name='buffalo_l')
    # ctx_id 0 -> first GPU when CUDA is available, -1 -> CPU.
    face_analyser.prepare(det_size=(640, 640), ctx_id=0 if device == 'cuda' else -1)
except Exception as e:
    print(f"CRITICAL: FaceAnalysis failed: {e}")
    face_analyser = None
# face_lock guards calls into face_analyser (each caller also rewrites the
# shared detection threshold under it); model_lock guards swapper/restorer
# inference.
face_lock = threading.Lock()
model_lock = threading.Lock()
# Build the inswapper ONNX session. CUDA is tried first when available, with
# CPU as the fallback provider. On failure inswapper_model is None, and the
# process_* functions then return frames unswapped.
try:
    if device == 'cuda':
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        providers = ['CPUExecutionProvider']
    swapper_session = ort.InferenceSession(INSWAPPER_PATH, sess_opts, providers=providers)
    inswapper_model = INSwapper(model_file=INSWAPPER_PATH, session=swapper_session)
    print("Inswapper model loaded successfully.")
except Exception as e:
    print(f"Model Load Error: {e}")
    inswapper_model = None
# Optional GFPGAN v1.3 face restorer (1x, no background upsampler).
# On failure `restorer` is bound to None, matching how face_analyser and
# inswapper_model degrade, so the name always exists. Callers gate its use
# on GFPGAN_AVAILABLE.
try:
    restorer = GFPGANer(
        model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
        upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
    GFPGAN_AVAILABLE = True
except Exception as e:
    print(f'GFPGAN load failed: {e}')
    restorer = None
    GFPGAN_AVAILABLE = False
# Shared thread pool for parallel work.
# NOTE(review): not referenced in this part of the module; presumably used
# elsewhere — TODO confirm before removing.
_thread_pool = ThreadPoolExecutor(max_workers=4)
# --- UTILITIES ---
def clear_memory():
    """Force a garbage-collection pass and, if CUDA is present, drop torch's GPU cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def set_det_thresh(thresh):
    """Set the detection score threshold on the shared face analyser.

    The detector is found via ``face_analyser.det_model`` or, failing that,
    ``face_analyser.models['detection']``. This is a no-op when the analyser
    failed to load or neither detector is present. Errors are printed, not
    raised. Callers in this module invoke it while holding ``face_lock``,
    because the threshold is shared state.
    """
    if face_analyser is None:
        return
    try:
        if hasattr(face_analyser, 'det_model'):
            detector = face_analyser.det_model
        elif 'detection' in face_analyser.models:
            detector = face_analyser.models['detection']
        else:
            return
        detector.det_thresh = thresh
    except Exception as e:
        print(f"Threshold Error: {e}")
def enhance_for_detection(bgr):
    """Return a contrast-boosted, sharpened copy of a BGR frame for a retry detection pass.

    CLAHE (clip limit 3.0, 8x8 tiles) is applied to the L channel in LAB space.
    The result is then convolved with a 3x3 sharpening kernel (centre 5,
    4-neighbours -1). The input array is not modified; new arrays are returned.
    """
    l_chan, a_chan, b_chan = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB))
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    contrasted = cv2.merge([clahe.apply(l_chan), a_chan, b_chan])
    back_to_bgr = cv2.cvtColor(contrasted, cv2.COLOR_LAB2BGR)
    sharpen = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    return cv2.filter2D(back_to_bgr, -1, sharpen)
@lru_cache(maxsize=10)
def _get_source_embedding_cached(url):
    """Download ``url`` and return the largest face detected in it (memoised).

    Network, HTTP and analysis exceptions propagate on purpose. ``lru_cache``
    does not cache exceptions, so a transient failure (timeout, 5xx) is retried
    on the next call instead of being remembered. Deterministic outcomes are
    cached as ``None``: an image OpenCV cannot decode, or no face found.
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36',
        'Accept': 'image/avif,image/webp,image/apng,image/*,*/*;q=0.8'
    }
    resp = requests.get(url, headers=headers, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True)
    resp.raise_for_status()
    img = cv2.imdecode(np.frombuffer(resp.content, np.uint8), cv2.IMREAD_COLOR)
    if img is None or img.size == 0:
        return None
    with face_lock:
        # Looser threshold than the per-frame default (0.15 vs 0.20 here is
        # intentional in the original tuning); retry on an enhanced copy.
        set_det_thresh(0.20)
        faces = face_analyser.get(img)
        if not faces:
            faces = face_analyser.get(enhance_for_detection(img))
    return _pick_largest(faces)


def get_source_embedding(url):
    """Return the largest face detected in the image at ``url``, or None.

    Results are memoised per URL (up to 10 entries), but failures such as
    download errors or analyser exceptions are not. They are logged with a
    traceback and reported as None, so a later call retries them.

    Returns:
        The face object from ``face_analyser`` (used by callers as the swap
        source), or None if the analyser is unavailable, the download/decoding
        failed, or no face was found.
    """
    if face_analyser is None:
        return None
    try:
        return _get_source_embedding_cached(url)
    except Exception:
        print(f"Source Download/Analysis Error: {traceback.format_exc()}")
        return None


# Keep the lru_cache management API that the decorated original exposed.
get_source_embedding.cache_clear = _get_source_embedding_cached.cache_clear
get_source_embedding.cache_info = _get_source_embedding_cached.cache_info
def get_best_face(bgr, thresh=0.15):
    """Detect faces in a BGR frame and return the one with the largest bounding box.

    The shared detector threshold is set to ``thresh`` while holding
    ``face_lock``. If nothing is found on the raw frame, detection is retried
    once on ``enhance_for_detection(bgr)``.

    Args:
        bgr: frame to search (as decoded by OpenCV).
        thresh: detection score threshold to apply for this call.

    Returns:
        The largest detected face, or None when no face is found or the
        analyser failed to load.
    """
    if face_analyser is None:
        return None
    with face_lock:
        set_det_thresh(thresh)
        faces = face_analyser.get(bgr)
        if not faces:
            faces = face_analyser.get(enhance_for_detection(bgr))
    # Shared "largest bbox" rule instead of a duplicated inline sort.
    return _pick_largest(faces)
def _pick_largest(faces):
if not faces: return None
return sorted(faces, key=lambda x: (x.bbox[2]-x.bbox[0])*(x.bbox[3]-x.bbox[1]), reverse=True)[0]
@lru_cache(maxsize=5)
def _get_bgr_frames_cached(url, max_frames):
    """Download an image/GIF and decode its frames to BGR arrays (memoised, 5 entries).

    Args:
        url: location of the target image or animated GIF.
        max_frames: stop after this many frames; a falsy value means no limit.

    Returns:
        ``(frames, durations)`` as tuples. ``frames`` holds BGR numpy arrays.
        ``durations`` holds each frame's ``duration`` info value, defaulting to
        67 when absent. The arrays are shared by every caller that hits the
        cache, so callers must not modify them in place.

    Raises:
        requests exceptions on download/HTTP failure, and PIL errors for
        undecodable content. Exceptions are not cached by ``lru_cache``.
    """
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
    resp = requests.get(url, headers=headers, timeout=DOWNLOAD_TIMEOUT, allow_redirects=True)
    resp.raise_for_status()
    frames, durations = [], []
    # Context manager releases the decoder once all frames are extracted.
    with Image.open(io.BytesIO(resp.content)) as gif:
        try:
            while True:
                frames.append(cv2.cvtColor(np.array(gif.convert('RGB')), cv2.COLOR_RGB2BGR))
                durations.append(gif.info.get('duration', 67))
                if max_frames and len(frames) >= max_frames:
                    break
                gif.seek(gif.tell() + 1)
        except EOFError:
            # seek() past the last frame signals the end of the sequence.
            pass
    return tuple(frames), tuple(durations)
def get_bgr_frames_cached(url, max_frames=None):
    """Return ``(frames, durations)`` for ``url`` via the memoised decoder.

    ``max_frames`` is always passed positionally. A call that omits it
    therefore shares a cache entry with an explicit ``max_frames=None`` call.
    """
    decoded = _get_bgr_frames_cached(url, max_frames)
    return decoded
def encode_gif(output_frames, durations):
    """Encode RGB frames as an infinitely looping GIF and return it base64-encoded.

    Args:
        output_frames: non-empty sequence of RGB arrays accepted by
            ``Image.fromarray`` (assumed uint8 HxWx3 — TODO confirm for
            other callers).
        durations: per-frame delays (ms) from the frame loader.

    Returns:
        The GIF bytes as a base64 ASCII ``str``.

    Raises:
        ValueError: if ``output_frames`` is empty.
    """
    if len(output_frames) == 0:
        raise ValueError("encode_gif() needs at least one frame")
    pil_frames = [Image.fromarray(f) for f in output_frames]
    durations = list(durations)
    if len(durations) == len(pil_frames):
        # Keep each frame's own delay so variable-speed animations keep their timing.
        duration = [int(d) for d in durations]
    elif durations:
        # Counts disagree (e.g. some frames were skipped): use one averaged delay.
        duration = int(np.mean(durations))
    else:
        # Same fallback delay the frame loader uses when a GIF omits one.
        duration = 67
    buf = io.BytesIO()
    pil_frames[0].save(
        buf, format='GIF', save_all=True,
        append_images=pil_frames[1:], loop=0,
        duration=duration,
    )
    return base64.b64encode(buf.getvalue()).decode()
# --- PREPROCESS HELPERS ---
# Gamma lookup table for cv2.LUT: maps each uint8 value v to
# 255 * (v / 255) ** (1 / 1.2). The exponent is < 1, so mid-tones get
# brighter; the uint8 cast truncates the fractional part. Built once at import.
_gamma_lut = np.array([((i / 255.0) ** (1.0 / 1.2)) * 255 for i in range(256)], dtype=np.uint8)
# 3x3 sharpening kernels keyed by preprocess strength. Each has a positive
# centre and negative 4-neighbours, with weights summing to 1, so flat
# regions keep their brightness.
# These are ordinary (writeable) numpy arrays shared by every caller, so code
# must treat them as read-only and never modify them in place.
_sharpen_kernels = {
    'light': np.array([[0, -0.3, 0], [-0.3, 2.2, -0.3], [0, -0.3, 0]]),
    'medium': np.array([[0, -0.5, 0], [-0.5, 3.0, -0.5], [0, -0.5, 0]]),
    'heavy': np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]),
}
def _make_clahe(strength):
    """Build a new CLAHE operator (8x8 tiles) whose clip limit depends on ``strength``.

    Clip limits: 'light' -> 1.5, 'medium' -> 2.5, 'heavy' -> 4.0; any other
    value falls back to 2.5. A fresh object is returned on every call because
    the author notes cv2.CLAHE objects are not thread-safe, so none is shared.
    """
    clip_limit = {'light': 1.5, 'medium': 2.5, 'heavy': 4.0}.get(strength, 2.5)
    return cv2.createCLAHE(tileGridSize=(8, 8), clipLimit=clip_limit)
def preprocess_frame(bgr, strength='medium'):
    """Return a detection-friendly copy of a BGR frame.

    Pipeline:
      1. CLAHE on the L channel in LAB space (clip limit chosen by ``strength``).
      2. Sharpening with the kernel for ``strength`` ('medium' for unknown values).
      3. For 'heavy' only, a brightening gamma lookup.

    ``None`` or empty input is returned unchanged (same object).
    """
    if bgr is None or bgr.size == 0:
        return bgr
    lum, chroma_a, chroma_b = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB))
    lum = _make_clahe(strength).apply(lum)
    result = cv2.cvtColor(cv2.merge([lum, chroma_a, chroma_b]), cv2.COLOR_LAB2BGR)
    kernel = _sharpen_kernels.get(strength, _sharpen_kernels['medium'])
    result = cv2.filter2D(result, -1, kernel)
    if strength == 'heavy':
        result = cv2.LUT(result, _gamma_lut)
    return result
# --- CORE PROCESSING FUNCTIONS ---
def process_universal(source_url, target_url, use_gfpgan=False, max_frames=None):
    """Swap the source face onto every frame of the target image/GIF.

    Target faces are found with ``get_best_face`` at its default threshold.
    When a frame yields no face, the last detected face is reused for up to 10
    consecutive misses, so short detector dropouts are still swapped.

    Args:
        source_url: URL of the image supplying the face to paste.
        target_url: URL of the image/GIF whose faces are replaced.
        use_gfpgan: also run GFPGAN restoration on swapped frames (only when
            GFPGAN loaded successfully).
        max_frames: optional cap on decoded target frames (falsy = no cap).

    Returns:
        Base64-encoded GIF string from ``encode_gif``.

    Raises:
        ValueError: if no source face is found or the source download failed.
    """
    identity = get_source_embedding(source_url)
    if identity is None:
        raise ValueError("Could not detect a face in the source image or source URL failed.")

    frames, durations = get_bgr_frames_cached(target_url, max_frames)
    max_misses = 10
    rendered = []
    last_face = None
    misses = 0

    for idx, frame in enumerate(frames):
        if frame is None:
            continue
        detected = get_best_face(frame)
        if detected:
            last_face, misses = detected, 0
            target_face = detected
        else:
            misses += 1
            target_face = last_face if misses <= max_misses else None

        if target_face and inswapper_model:
            try:
                with model_lock:
                    frame = inswapper_model.get(frame, target_face, identity, paste_back=True)
                    if use_gfpgan and GFPGAN_AVAILABLE:
                        _, _, frame = restorer.enhance(
                            frame, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as e:
                # A failed swap keeps whatever frame state was reached.
                print(f"Swap Error Frame {idx}: {e}")

        rendered.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if idx % 30 == 0:
            clear_memory()

    return encode_gif(rendered, durations)
def process_multiscale(source_url, target_url, use_gfpgan=False, max_frames=None):
    """Swap faces, searching each frame with a descending ladder of thresholds.

    Each frame is tried at 0.15, then 0.10, then 0.05 via ``get_best_face``,
    stopping at the first hit, so faint or small faces are still found.
    Frames without a hit reuse the last detected face for up to 10 consecutive
    misses.

    Args:
        source_url: URL of the image supplying the face to paste.
        target_url: URL of the image/GIF whose faces are replaced.
        use_gfpgan: also run GFPGAN restoration when it is available.
        max_frames: optional cap on decoded target frames (falsy = no cap).

    Returns:
        Base64-encoded GIF string from ``encode_gif``.

    Raises:
        ValueError: if no face is found in the source image.
    """
    identity = get_source_embedding(source_url)
    if identity is None:
        raise ValueError("Could not detect a face in the source image.")

    frames, durations = get_bgr_frames_cached(target_url, max_frames)
    ladder = (0.15, 0.10, 0.05)
    max_misses = 10
    rendered = []
    last_face = None
    misses = 0

    for idx, frame in enumerate(frames):
        if frame is None:
            continue

        detected = None
        for level in ladder:
            detected = get_best_face(frame, thresh=level)
            if detected:
                break

        if detected:
            last_face, misses = detected, 0
            target_face = detected
        else:
            misses += 1
            target_face = last_face if misses <= max_misses else None

        if target_face and inswapper_model:
            try:
                with model_lock:
                    frame = inswapper_model.get(frame, target_face, identity, paste_back=True)
                    if use_gfpgan and GFPGAN_AVAILABLE:
                        _, _, frame = restorer.enhance(
                            frame, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as e:
                print(f"Multiscale Swap Error Frame {idx}: {e}")

        rendered.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if idx % 30 == 0:
            clear_memory()

    return encode_gif(rendered, durations)
def process_landmark(source_url, target_url, use_gfpgan=False, max_frames=None, min_landmark_confidence=0.2):
    """Swap faces, trusting a detection only when its landmarks look plausible.

    Each frame is searched at thresholds 0.15, 0.10, then 0.05 via
    ``get_best_face``. The hit is accepted only if ``landmarks_are_valid``
    passes. Rejected or missing detections fall back to the last accepted face
    for up to 10 consecutive frames; rejected detections are logged.

    Args:
        source_url: URL of the image supplying the face to paste.
        target_url: URL of the image/GIF whose faces are replaced.
        use_gfpgan: also run GFPGAN restoration when it is available.
        max_frames: optional cap on decoded target frames (falsy = no cap).
        min_landmark_confidence: minimum ``det_score`` for a face to be accepted.

    Returns:
        Base64-encoded GIF string from ``encode_gif``.

    Raises:
        ValueError: if no face is found in the source image.
    """
    source_face = get_source_embedding(source_url)
    if source_face is None:
        raise ValueError("Could not detect a face in the source image.")
    bgr_frames_raw, durations = get_bgr_frames_cached(target_url, max_frames)

    def landmarks_are_valid(face, frame_shape):
        """Return True if ``face`` passes the keypoint, size, pose and score checks."""
        if face is None:
            return False
        h, w = frame_shape[:2]
        kps = getattr(face, 'kps', None)
        if kps is None or len(kps) < 5:
            return False
        # Every keypoint must lie inside the frame, allowing a 10px overhang.
        MARGIN = 10
        for x, y in kps:
            if not (-MARGIN <= x <= w + MARGIN and -MARGIN <= y <= h + MARGIN):
                return False
        # Reject faces whose bbox covers less than 1% of the frame area.
        bx1, by1, bx2, by2 = face.bbox
        if (bx2 - bx1) * (by2 - by1) < 0.01 * w * h:
            return False
        # NOTE(review): pose[2] is presumably roll in degrees (insightface
        # orders pose as pitch, yaw, roll) — TODO confirm.
        pose = getattr(face, 'pose', None)
        if pose is not None and len(pose) >= 3 and abs(pose[2]) > 60:
            return False
        # A missing score counts as passing. Test for None explicitly, because
        # face objects whose attribute lookup yields None for missing keys
        # would otherwise make the comparison raise TypeError.
        score = getattr(face, 'det_score', None)
        if score is not None and score < min_landmark_confidence:
            return False
        return True

    output_frames = []
    locked_face = None
    consecutive_no_face = 0
    NO_FACE_LIMIT = 10
    THRESHOLDS = [0.15, 0.10, 0.05]
    for i, bgr in enumerate(bgr_frames_raw):
        if bgr is None:
            continue
        current_face = None
        for thresh in THRESHOLDS:
            current_face = get_best_face(bgr, thresh=thresh)
            if current_face:
                break
        face_valid = landmarks_are_valid(current_face, bgr.shape)
        if face_valid:
            locked_face, consecutive_no_face = current_face, 0
        else:
            consecutive_no_face += 1
            if current_face is not None:
                score = getattr(current_face, 'det_score', None)
                kps = getattr(current_face, 'kps', None)
                pose = getattr(current_face, 'pose', None)
                # Apply ':.2f' only to a real number. Formatting a placeholder
                # string or None with it raises and would abort the whole job.
                score_txt = '?' if score is None else f"{float(score):.2f}"
                print(f"Frame {i}: face found but landmark invalid — "
                      f"score={score_txt}, pose={pose}, kps_count={len(kps) if kps is not None else 0}")
        face_to_use = (current_face if face_valid
                       else (locked_face if consecutive_no_face <= NO_FACE_LIMIT else None))
        if face_to_use and inswapper_model:
            try:
                with model_lock:
                    bgr = inswapper_model.get(bgr, face_to_use, source_face, paste_back=True)
                    if use_gfpgan and GFPGAN_AVAILABLE:
                        _, _, bgr = restorer.enhance(
                            bgr, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as e:
                print(f"Landmark Swap Error Frame {i}: {e}")
        output_frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        if i % 30 == 0:
            clear_memory()
    return encode_gif(output_frames, durations)
def process_preprocess(source_url, target_url, use_gfpgan=False, max_frames=None, preprocess_strength='medium'):
    """Swap faces, detecting on a contrast-enhanced copy of each frame.

    Detection runs at threshold 0.15 on ``preprocess_frame(frame,
    preprocess_strength)``, falling back to the untouched frame. The enhanced
    copy is used only for detection; the swap is always applied to the
    original frame. Missing detections reuse the last found face for up to 10
    consecutive frames.

    Args:
        source_url: URL of the image supplying the face to paste.
        target_url: URL of the image/GIF whose faces are replaced.
        use_gfpgan: also run GFPGAN restoration when it is available.
        max_frames: optional cap on decoded target frames (falsy = no cap).
        preprocess_strength: 'light', 'medium' or 'heavy' enhancement.

    Returns:
        Base64-encoded GIF string from ``encode_gif``.

    Raises:
        ValueError: if no face is found in the source image.
    """
    identity = get_source_embedding(source_url)
    if identity is None:
        raise ValueError("Could not detect a face in the source image.")

    frames, durations = get_bgr_frames_cached(target_url, max_frames)
    max_misses = 10
    rendered = []
    last_face = None
    misses = 0

    for idx, frame in enumerate(frames):
        if frame is None:
            continue
        enhanced = preprocess_frame(frame, strength=preprocess_strength)
        with face_lock:
            set_det_thresh(0.15)
            detected = _pick_largest(face_analyser.get(enhanced)) if face_analyser else None
            if not detected:
                # Retry on the untouched frame.
                detected = _pick_largest(face_analyser.get(frame)) if face_analyser else None

        if detected:
            last_face, misses = detected, 0
            target_face = detected
        else:
            misses += 1
            target_face = last_face if misses <= max_misses else None

        # Swap onto the original frame, never the enhanced copy.
        if target_face and inswapper_model:
            try:
                with model_lock:
                    frame = inswapper_model.get(frame, target_face, identity, paste_back=True)
                    if use_gfpgan and GFPGAN_AVAILABLE:
                        _, _, frame = restorer.enhance(
                            frame, has_aligned=False, only_center_face=False, paste_back=True)
            except Exception as e:
                print(f"Preprocess Swap Error Frame {idx}: {e}")

        rendered.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if idx % 30 == 0:
            clear_memory()

    return encode_gif(rendered, durations)
|