videopix commited on
Commit
580f903
·
verified ·
1 Parent(s): b3605aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -242
app.py CHANGED
@@ -11,9 +11,9 @@ from insightface.app import FaceAnalysis
11
  from tempfile import NamedTemporaryFile
12
 
13
  from fastapi import FastAPI, UploadFile, File, HTTPException
14
- from fastapi.responses import HTMLResponse, Response, StreamingResponse
15
 
16
- # Optional ONNX enhancer import (only used if small_enhancer.onnx exists)
17
  try:
18
  import onnxruntime as ort
19
  ONNX_AVAILABLE = True
@@ -25,17 +25,16 @@ except Exception:
25
  # -----------------------------------------------------------
26
  face_app = None
27
  swapper = None
28
- enhancer_session = None # optional onnx session
29
 
30
  # task store (in-memory)
31
  TASKS = {} # task_id -> {"status": "queued|processing|done|failed", "result": path, "error": text}
32
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=6)
33
 
34
  # -----------------------------------------------------------
35
- # MODEL DOWNLOAD (swapper)
36
  # -----------------------------------------------------------
37
  def download_swapper_model():
38
- """Download inswapper_128.onnx if not present (used by insightface model_zoo loader)."""
39
  url = "https://cdn.adikhanofficial.com/python/insightface/models/inswapper_128.onnx"
40
  filename = os.path.basename(url)
41
  save_path = os.path.join(os.path.dirname(__file__), filename)
@@ -49,19 +48,17 @@ def download_swapper_model():
49
  else:
50
  print("Swapper model already exists:", filename)
51
 
52
- # -----------------------------------------------------------
53
- # OPTIONAL: try to initialize a small ONNX enhancer if present
54
- # -----------------------------------------------------------
55
  def try_load_enhancer():
56
  global enhancer_session
57
  if not ONNX_AVAILABLE:
58
- print("onnxruntime not installed — enhancer disabled (fallback to OpenCV).")
59
  return
60
- path = os.path.join(os.path.dirname(__file__), "small_enhancer.onnx")
61
- if os.path.exists(path):
62
  try:
63
- enhancer_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
64
- print("Loaded ONNX enhancer from", path)
65
  except Exception as e:
66
  print("Failed to load ONNX enhancer:", e)
67
  enhancer_session = None
@@ -70,47 +67,34 @@ def try_load_enhancer():
70
 
71
 
72
  # -----------------------------------------------------------
73
- # FACE SWAP HELPER (insightface)
74
  # -----------------------------------------------------------
75
  def swap_faces(target_img, target_face, source_face):
76
- # swapper.get is synchronous
77
  return swapper.get(target_img, target_face, source_face, paste_back=True)
78
 
79
 
80
  # -----------------------------------------------------------
81
- # ENHANCEMENT: ONNX (if present) OR OpenCV fallback
82
  # -----------------------------------------------------------
83
  def enhance_face_with_onnx(face_bgr):
84
- """
85
- Attempt to use enhancer_session. This code assumes the ONNX model accepts
86
- a normalized float32 CHW input named whatever the first input is.
87
- Because small ONNX models vary, this function tries a generic call.
88
- If it fails, it will raise and caller will fallback to OpenCV enhancement.
89
- """
90
  if enhancer_session is None:
91
  raise RuntimeError("Enhancer session not available")
92
-
93
- # Generic preprocessing: resize to model input size if single static input size detected
94
  inp = enhancer_session.get_inputs()[0]
95
  name = inp.name
96
- shape = inp.shape # e.g. [1,3,256,256] or [None,3,512,512]
97
- # decide target size
98
  try:
99
- _, c, h, w = (shape if len(shape) == 4 else (1,3,512,512))
100
- h = int(h) if (isinstance(h, int) or (isinstance(h, np.integer))) else 512
101
- w = int(w) if (isinstance(w, int) or (isinstance(w, np.integer))) else 512
102
  except Exception:
103
  h, w = 512, 512
104
 
105
  img = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB)
106
  img = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
107
  img = img.astype(np.float32) / 255.0
108
- # CHW
109
  img = np.transpose(img, (2, 0, 1))[None].astype(np.float32)
110
-
111
- # Run
112
  out = enhancer_session.run(None, {name: img})
113
- # pick first output
114
  out_img = out[0][0]
115
  out_img = np.clip(out_img * 255.0, 0, 255).astype(np.uint8)
116
  out_img = np.transpose(out_img, (1, 2, 0))
@@ -119,57 +103,108 @@ def enhance_face_with_onnx(face_bgr):
119
 
120
 
121
  def enhance_face_opencv(face_bgr):
122
- """
123
- Lightweight, fast CPU-only enhancement:
124
- - upscale x2 using bicubic
125
- - denoise with bilateral filter
126
- - sharpen via unsharp mask
127
- This works well as a fallback for small faces and is deterministic.
128
- """
129
  if face_bgr is None or face_bgr.size == 0:
130
  return face_bgr
131
-
132
- # upscale x2
133
  h, w = face_bgr.shape[:2]
134
  target_h, target_w = max(64, h * 2), max(64, w * 2)
135
  up = cv2.resize(face_bgr, (target_w, target_h), interpolation=cv2.INTER_CUBIC)
136
-
137
- # bilateral filter to smooth artifacts
138
  denoised = cv2.bilateralFilter(up, d=5, sigmaColor=75, sigmaSpace=75)
139
-
140
- # unsharp mask (sharpen)
141
- blurred = cv2.GaussianBlur(denoised, (0,0), sigmaX=2, sigmaY=2)
142
  sharpened = cv2.addWeighted(denoised, 1.4, blurred, -0.4, 0)
143
-
144
- # optional mild contrast boost
145
  lab = cv2.cvtColor(sharpened, cv2.COLOR_BGR2LAB)
146
  l, a, b = cv2.split(lab)
147
  l = cv2.equalizeHist(l)
148
  lab = cv2.merge((l, a, b))
149
  result = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
150
-
151
- # resize back to original face box size will be handled by caller
152
  return result
153
 
154
 
155
  def enhance_face(face_bgr):
156
- # Try ONNX first if available
157
  if enhancer_session is not None:
158
  try:
159
  return enhance_face_with_onnx(face_bgr)
160
  except Exception as e:
161
- print("ONNX enhancer failed, falling back to OpenCV enhancer:", e)
162
- # fall through to OpenCV
163
  return enhance_face_opencv(face_bgr)
164
 
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  # -----------------------------------------------------------
167
  # BACKGROUND TASK HANDLER
168
  # -----------------------------------------------------------
169
  def run_task(task_id, src_bytes, tgt_bytes):
170
  TASKS[task_id]["status"] = "processing"
171
  try:
172
- # decode
173
  src = cv2.imdecode(np.frombuffer(src_bytes, np.uint8), cv2.IMREAD_COLOR)
174
  tgt = cv2.imdecode(np.frombuffer(tgt_bytes, np.uint8), cv2.IMREAD_COLOR)
175
 
@@ -178,7 +213,6 @@ def run_task(task_id, src_bytes, tgt_bytes):
178
  if tgt is None:
179
  raise ValueError("Invalid target image")
180
 
181
- # detect faces
182
  src_faces = face_app.get(src)
183
  tgt_faces = face_app.get(tgt)
184
 
@@ -187,38 +221,44 @@ def run_task(task_id, src_bytes, tgt_bytes):
187
  if not tgt_faces:
188
  raise ValueError("No face detected in target image")
189
 
190
- # choose first faces (you can enhance selection logic)
191
  s_face = src_faces[0]
192
  t_face = tgt_faces[0]
193
 
194
- # perform swap (returns full image with pasted face)
195
  swapped = swap_faces(tgt, t_face, s_face)
196
 
197
- # attempt to crop the face region (using bbox from detection)
198
  try:
199
- # bbox could be float, ensure ints and clamp
200
  x1, y1, x2, y2 = map(int, map(round, t_face.bbox))
201
  h, w = swapped.shape[:2]
202
- x1, x2 = max(0, min(x1, w-1)), max(0, min(x2, w))
203
- y1, y2 = max(0, min(y1, h-1)), max(0, min(y2, h))
 
204
  if x2 - x1 > 10 and y2 - y1 > 10:
205
  face_crop = swapped[y1:y2, x1:x2].copy()
206
 
207
- # enhance the face crop
208
  enhanced = enhance_face(face_crop)
209
 
210
- # resize enhanced back to bbox size (in case upscaled)
211
- enhanced_resized = cv2.resize(enhanced, (x2 - x1, y2 - y1), interpolation=cv2.INTER_CUBIC)
212
- swapped[y1:y2, x1:x2] = enhanced_resized
 
 
 
 
 
 
 
 
 
 
213
  else:
214
- # bbox too small; skip enhancement
215
  pass
216
  except Exception as e:
217
- # If any error occurs during enhancement, continue with swapped image
218
- print("Face enhancement step failed:", e)
219
  print(traceback.format_exc())
220
 
221
- # write output
222
  out_path = f"/tmp/{task_id}.jpg"
223
  ok = cv2.imwrite(out_path, swapped)
224
  if not ok:
@@ -237,11 +277,10 @@ def run_task(task_id, src_bytes, tgt_bytes):
237
  # -----------------------------------------------------------
238
  # FASTAPI app & UI
239
  # -----------------------------------------------------------
240
- app = FastAPI(title="FaceSwap Async API (image-only)")
241
 
242
  @app.get("/", response_class=HTMLResponse)
243
  def home():
244
- # stylish async UI with fixed drag & drop (same as previous UI, improved)
245
  return """
246
  <!doctype html>
247
  <html lang="en">
@@ -250,191 +289,72 @@ def home():
250
  <title>FaceSwap Async API</title>
251
  <meta name="viewport" content="width=device-width,initial-scale=1">
252
  <style>
253
- :root { --bg:#0e1525; --card:rgba(255,255,255,0.06); --accent:#6c5ce7; --accent2:#00cec9; --text:#e8ecf2; --muted:#9ba3b4; --danger:#ff6b6b; --success:#2ecc71; --radius:14px; }
254
- body { background: linear-gradient(160deg,#111827,#0e1525); font-family: Inter, sans-serif; padding:25px; color:var(--text); margin:0; }
255
- .card { background:var(--card); padding:22px; border-radius:var(--radius); box-shadow:0 8px 25px rgba(0,0,0,0.4); max-width:900px; margin:auto; }
256
- h1{ font-size:24px; margin:0 0 6px 0; } .subtitle{ color:var(--muted); font-size:14px; margin-bottom:14px; }
257
- .row{ display:flex; gap:12px; align-items:center; flex-wrap:wrap; }
258
- .upload-zone { flex:1; min-width:220px; border:2px dashed rgba(255,255,255,0.12); padding:18px; border-radius:12px; text-align:center; cursor:pointer; transition:0.12s; }
259
- .upload-zone.dragover{ border-color: var(--accent); background: rgba(255,255,255,0.02); transform: translateY(-4px); }
260
- input[type=file]{ display:none; }
261
- button{ background: linear-gradient(90deg,var(--accent),var(--accent2)); border:none; padding:12px 18px; border-radius:12px; color:white; font-weight:600; cursor:pointer; }
262
- .status{ margin-top:12px; color:var(--muted); }
263
- .task-id{ margin-top:8px; color:var(--accent2); font-weight:700; }
264
- .preview{ margin-top:16px; text-align:center; }
265
- .preview img{ max-width:100%; border-radius:10px; border:1px solid rgba(255,255,255,0.06); background:#021017; }
266
- .loader { width:20px;height:20px;border:3px solid rgba(255,255,255,0.12); border-top-color:var(--accent2); border-radius:50%; animation:spin .9s linear infinite; display:inline-block; }
267
- @keyframes spin{ to{ transform:rotate(360deg) } }
268
  </style>
269
  </head>
270
  <body>
271
  <div class="card">
272
  <h1>Async FaceSwap</h1>
273
  <div class="subtitle">Upload a source face image and a target image. You get a task id immediately; result will appear when ready.</div>
274
-
275
  <div class="row">
276
- <label class="upload-zone" id="zoneSource">
277
- <div><strong id="labelSource">Click or drop source image</strong></div>
278
- <input id="inputSource" type="file" accept="image/*">
279
- </label>
280
-
281
- <label class="upload-zone" id="zoneTarget">
282
- <div><strong id="labelTarget">Click or drop target image</strong></div>
283
- <input id="inputTarget" type="file" accept="image/*">
284
- </label>
285
  </div>
286
-
287
- <div class="row" style="margin-top:12px;">
288
  <button id="startBtn">Start FaceSwap</button>
289
- <button id="resetBtn" style="background:transparent;border:1px solid rgba(255,255,255,0.06);color:var(--muted);">Reset</button>
290
  <div style="flex:1"></div>
291
  <div id="status" class="status"></div>
292
  </div>
293
-
294
  <div class="task-id" id="taskId"></div>
295
-
296
- <div class="preview">
297
- <img id="resultImg" style="display:none;">
298
- </div>
299
-
300
  </div>
301
-
302
  <script>
303
  (function(){
304
- const zoneSource = document.getElementById('zoneSource');
305
- const zoneTarget = document.getElementById('zoneTarget');
306
- const inputSource = document.getElementById('inputSource');
307
- const inputTarget = document.getElementById('inputTarget');
308
- const labelSource = document.getElementById('labelSource');
309
- const labelTarget = document.getElementById('labelTarget');
310
- const startBtn = document.getElementById('startBtn');
311
- const resetBtn = document.getElementById('resetBtn');
312
- const status = document.getElementById('status');
313
- const taskIdEl = document.getElementById('taskId');
314
- const resultImg = document.getElementById('resultImg');
315
-
316
- let sourceFile = null;
317
- let targetFile = null;
318
- let pollHandle = null;
319
-
320
- function prevent(e){ e.preventDefault(); e.stopPropagation(); }
321
-
322
- function makeDrop(zone, input, label){
323
- ['dragenter','dragover','dragleave','drop'].forEach(ev => {
324
- zone.addEventListener(ev, prevent);
325
- });
326
- zone.addEventListener('dragover', ()=> zone.classList.add('dragover'));
327
- zone.addEventListener('dragleave', ()=> zone.classList.remove('dragover'));
328
- zone.addEventListener('drop', (ev)=>{
329
- zone.classList.remove('dragover');
330
- const f = ev.dataTransfer.files && ev.dataTransfer.files[0];
331
- if (!f) return;
332
- input.files = ev.dataTransfer.files;
333
- label.innerText = f.name;
334
- if (input === inputSource) sourceFile = f;
335
- if (input === inputTarget) targetFile = f;
336
- });
337
- zone.addEventListener('click', ()=> input.click());
338
- input.addEventListener('change', ()=> {
339
- const f = input.files && input.files[0];
340
- if (!f) return;
341
- label.innerText = f.name;
342
- if (input === inputSource) sourceFile = f;
343
- if (input === inputTarget) targetFile = f;
344
- });
345
  }
346
-
347
- makeDrop(zoneSource, inputSource, labelSource);
348
- makeDrop(zoneTarget, inputTarget, labelTarget);
349
-
350
- startBtn.addEventListener('click', async ()=>{
351
- if (!sourceFile || !targetFile) { alert('Select source and target images'); return; }
352
- status.innerHTML = 'Uploading <span class="loader"></span>';
353
- startBtn.disabled = true;
354
- try {
355
- const fd = new FormData();
356
- fd.append('source', sourceFile);
357
- fd.append('target', targetFile);
358
- const res = await fetch('/swap-image', { method:'POST', body: fd });
359
- if (!res.ok) {
360
- const txt = await res.text().catch(()=>res.statusText);
361
- status.innerHTML = '<span style="color:#ff6b6b">Upload failed: ' + txt + '</span>';
362
- startBtn.disabled = false;
363
- return;
364
- }
365
- const data = await res.json();
366
- taskIdEl.innerText = 'Task ID: ' + data.task_id;
367
- pollStatus(data.task_id);
368
- } catch (err) {
369
- status.innerText = 'Network error';
370
- console.error(err);
371
- } finally {
372
- // keep button disabled until task completes or reset
373
- }
374
  });
375
-
376
- resetBtn.addEventListener('click', ()=>{
377
- sourceFile = targetFile = null;
378
- inputSource.value = '';
379
- inputTarget.value = '';
380
- labelSource.innerText = 'Click or drop source image';
381
- labelTarget.innerText = 'Click or drop target image';
382
- status.innerText = '';
383
- taskIdEl.innerText = '';
384
- resultImg.style.display = 'none';
385
- if (pollHandle) { clearInterval(pollHandle); pollHandle = null; }
386
- startBtn.disabled = false;
387
- });
388
-
389
- function pollStatus(taskId){
390
- // poll every 1s
391
- pollHandle = setInterval(async ()=>{
392
- try {
393
- const res = await fetch('/task-status/' + taskId);
394
- const data = await res.json();
395
- if (data.status === 'processing') {
396
- status.innerHTML = 'Processing <span class="loader"></span>';
397
- } else if (data.status === 'failed') {
398
- status.innerHTML = '<span style="color:#ff6b6b">Failed: ' + (data.error || '') + '</span>';
399
- clearInterval(pollHandle);
400
- pollHandle = null;
401
- startBtn.disabled = false;
402
- } else if (data.status === 'done') {
403
- status.innerHTML = '<span style="color:#2ecc71">Completed</span>';
404
- clearInterval(pollHandle);
405
- pollHandle = null;
406
- fetchResult(taskId);
407
- } else {
408
- status.innerText = data.status;
409
- }
410
- } catch (err) {
411
- console.error(err);
412
- status.innerText = 'Status error';
413
- }
414
- }, 1000);
415
- }
416
-
417
  async function fetchResult(taskId){
418
- try {
419
- const res = await fetch('/task-result/' + taskId);
420
- if (!res.ok) {
421
- const t = await res.text().catch(()=>res.statusText);
422
- status.innerHTML = '<span style="color:#ff6b6b">Result fetch failed: ' + t + '</span>';
423
- startBtn.disabled = false;
424
- return;
425
- }
426
- const blob = await res.blob();
427
- const url = URL.createObjectURL(blob);
428
- resultImg.src = url;
429
- resultImg.style.display = 'block';
430
- startBtn.disabled = false;
431
- } catch (err) {
432
- console.error(err);
433
- status.innerText = 'Fetch error';
434
- startBtn.disabled = false;
435
- }
436
  }
437
-
438
  })();
439
  </script>
440
  </body>
@@ -452,7 +372,6 @@ async def swap_image(source: UploadFile = File(...), target: UploadFile = File(.
452
  task_id = str(uuid.uuid4())
453
  TASKS[task_id] = {"status": "queued"}
454
 
455
- # submit background job
456
  executor.submit(run_task, task_id, src_bytes, tgt_bytes)
457
 
458
  return {"task_id": task_id, "status": "queued"}
@@ -478,7 +397,6 @@ def task_result(task_id: str):
478
  task = TASKS[task_id]
479
  if task["status"] != "done":
480
  return {"status": task["status"], "error": task.get("error")}
481
- # stream result
482
  return StreamingResponse(open(task["result"], "rb"), media_type="image/jpeg")
483
 
484
 
@@ -486,15 +404,15 @@ def task_result(task_id: str):
486
  # INIT MODELS
487
  # -----------------------------------------------------------
488
  print("Initializing models...")
489
- # prepare detection model (CPU)
490
  face_app = FaceAnalysis(name="buffalo_l")
491
  face_app.prepare(ctx_id=-1, det_size=(640, 640))
492
 
493
- # download swapper model if needed and load swapper
494
  download_swapper_model()
495
  swapper = insightface.model_zoo.get_model("inswapper_128.onnx", root=os.path.dirname(__file__))
496
 
497
- # try load optional small ONNX enhancer if you provided it
498
  if ONNX_AVAILABLE:
499
  try_load_enhancer()
500
 
 
11
  from tempfile import NamedTemporaryFile
12
 
13
  from fastapi import FastAPI, UploadFile, File, HTTPException
14
+ from fastapi.responses import HTMLResponse, StreamingResponse
15
 
16
+ # Optional ONNX enhancer import
17
  try:
18
  import onnxruntime as ort
19
  ONNX_AVAILABLE = True
 
25
  # -----------------------------------------------------------
26
  face_app = None
27
  swapper = None
28
+ enhancer_session = None
29
 
30
  # task store (in-memory)
31
  TASKS = {} # task_id -> {"status": "queued|processing|done|failed", "result": path, "error": text}
32
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=6)
33
 
34
  # -----------------------------------------------------------
35
+ # MODELS - download swapper if needed
36
  # -----------------------------------------------------------
37
  def download_swapper_model():
 
38
  url = "https://cdn.adikhanofficial.com/python/insightface/models/inswapper_128.onnx"
39
  filename = os.path.basename(url)
40
  save_path = os.path.join(os.path.dirname(__file__), filename)
 
48
  else:
49
  print("Swapper model already exists:", filename)
50
 
51
+
 
 
52
  def try_load_enhancer():
53
  global enhancer_session
54
  if not ONNX_AVAILABLE:
55
+ print("onnxruntime not installed — ONNX enhancer disabled (using OpenCV fallback).")
56
  return
57
+ p = os.path.join(os.path.dirname(__file__), "small_enhancer.onnx")
58
+ if os.path.exists(p):
59
  try:
60
+ enhancer_session = ort.InferenceSession(p, providers=["CPUExecutionProvider"])
61
+ print("Loaded ONNX enhancer:", p)
62
  except Exception as e:
63
  print("Failed to load ONNX enhancer:", e)
64
  enhancer_session = None
 
67
 
68
 
69
  # -----------------------------------------------------------
70
+ # SWAP HELPER (insightface swapper)
71
  # -----------------------------------------------------------
72
  def swap_faces(target_img, target_face, source_face):
 
73
  return swapper.get(target_img, target_face, source_face, paste_back=True)
74
 
75
 
76
  # -----------------------------------------------------------
77
+ # ENHANCERS
78
  # -----------------------------------------------------------
79
  def enhance_face_with_onnx(face_bgr):
 
 
 
 
 
 
80
  if enhancer_session is None:
81
  raise RuntimeError("Enhancer session not available")
 
 
82
  inp = enhancer_session.get_inputs()[0]
83
  name = inp.name
84
+ shape = inp.shape
85
+ # determine target size (safe fallback)
86
  try:
87
+ _, c, h, w = (shape if len(shape) == 4 else (1, 3, 512, 512))
88
+ h = int(h) if isinstance(h, (int, np.integer)) else 512
89
+ w = int(w) if isinstance(w, (int, np.integer)) else 512
90
  except Exception:
91
  h, w = 512, 512
92
 
93
  img = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB)
94
  img = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
95
  img = img.astype(np.float32) / 255.0
 
96
  img = np.transpose(img, (2, 0, 1))[None].astype(np.float32)
 
 
97
  out = enhancer_session.run(None, {name: img})
 
98
  out_img = out[0][0]
99
  out_img = np.clip(out_img * 255.0, 0, 255).astype(np.uint8)
100
  out_img = np.transpose(out_img, (1, 2, 0))
 
103
 
104
 
105
  def enhance_face_opencv(face_bgr):
 
 
 
 
 
 
 
106
  if face_bgr is None or face_bgr.size == 0:
107
  return face_bgr
 
 
108
  h, w = face_bgr.shape[:2]
109
  target_h, target_w = max(64, h * 2), max(64, w * 2)
110
  up = cv2.resize(face_bgr, (target_w, target_h), interpolation=cv2.INTER_CUBIC)
 
 
111
  denoised = cv2.bilateralFilter(up, d=5, sigmaColor=75, sigmaSpace=75)
112
+ blurred = cv2.GaussianBlur(denoised, (0, 0), sigmaX=2, sigmaY=2)
 
 
113
  sharpened = cv2.addWeighted(denoised, 1.4, blurred, -0.4, 0)
 
 
114
  lab = cv2.cvtColor(sharpened, cv2.COLOR_BGR2LAB)
115
  l, a, b = cv2.split(lab)
116
  l = cv2.equalizeHist(l)
117
  lab = cv2.merge((l, a, b))
118
  result = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
 
 
119
  return result
120
 
121
 
122
  def enhance_face(face_bgr):
 
123
  if enhancer_session is not None:
124
  try:
125
  return enhance_face_with_onnx(face_bgr)
126
  except Exception as e:
127
+ print("ONNX enhancement failed:", e)
 
128
  return enhance_face_opencv(face_bgr)
129
 
130
 
131
+ # -----------------------------------------------------------
132
+ # COLOR HARMONIZATION (Reinhard-like in LAB)
133
+ # -----------------------------------------------------------
134
+ def color_transfer_reinhard(src_bgr, ref_bgr):
135
+ """
136
+ Map src colors to match ref using mean/std in LAB color space (fast).
137
+ Returns adjusted src (same size).
138
+ """
139
+ try:
140
+ src_lab = cv2.cvtColor(src_bgr, cv2.COLOR_BGR2LAB).astype(np.float32)
141
+ ref_lab = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2LAB).astype(np.float32)
142
+
143
+ src_means, src_stds = cv2.meanStdDev(src_lab)
144
+ ref_means, ref_stds = cv2.meanStdDev(ref_lab)
145
+
146
+ src_means = src_means.flatten()
147
+ src_stds = src_stds.flatten()
148
+ ref_means = ref_means.flatten()
149
+ ref_stds = ref_stds.flatten()
150
+
151
+ # avoid division by zero
152
+ src_stds[src_stds < 1e-6] = 1.0
153
+
154
+ # transfer
155
+ lab = (src_lab - src_means) * (ref_stds / src_stds) + ref_means
156
+ lab = np.clip(lab, 0, 255).astype(np.uint8)
157
+ out = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
158
+ return out
159
+ except Exception:
160
+ return src_bgr
161
+
162
+
163
+ # -----------------------------------------------------------
164
+ # SOFT FEATHER BLEND
165
+ # -----------------------------------------------------------
166
+ def feather_blend(target_img, new_face, bbox, feather_amount=0.3):
167
+ """
168
+ Soft blend new_face (same size as bbox) into target_img at bbox using a
169
+ Gaussian feathered alpha mask.
170
+ - bbox: (x1,y1,x2,y2)
171
+ - feather_amount: fraction of bbox size used for feather radius
172
+ """
173
+ x1, y1, x2, y2 = bbox
174
+ h = y2 - y1
175
+ w = x2 - x1
176
+ if h <= 0 or w <= 0:
177
+ return target_img
178
+
179
+ # build mask
180
+ mask = np.ones((h, w), dtype=np.uint8) * 255
181
+ # feather sigma proportional to size
182
+ k = int(max(1, min(h, w) * feather_amount))
183
+ k = k if k % 2 == 1 else k + 1
184
+ mask_blur = cv2.GaussianBlur(mask, (k, k), 0)
185
+
186
+ # normalize alpha to [0,1]
187
+ alpha = (mask_blur.astype(np.float32) / 255.0)[:, :, None]
188
+
189
+ # ensure sizes match
190
+ face_resized = cv2.resize(new_face, (w, h), interpolation=cv2.INTER_CUBIC)
191
+
192
+ # region in target
193
+ target_region = target_img[y1:y2, x1:x2].astype(np.float32)
194
+ src_region = face_resized.astype(np.float32)
195
+
196
+ comp = (alpha * src_region + (1 - alpha) * target_region).astype(np.uint8)
197
+ out_img = target_img.copy()
198
+ out_img[y1:y2, x1:x2] = comp
199
+ return out_img
200
+
201
+
202
  # -----------------------------------------------------------
203
  # BACKGROUND TASK HANDLER
204
  # -----------------------------------------------------------
205
  def run_task(task_id, src_bytes, tgt_bytes):
206
  TASKS[task_id]["status"] = "processing"
207
  try:
 
208
  src = cv2.imdecode(np.frombuffer(src_bytes, np.uint8), cv2.IMREAD_COLOR)
209
  tgt = cv2.imdecode(np.frombuffer(tgt_bytes, np.uint8), cv2.IMREAD_COLOR)
210
 
 
213
  if tgt is None:
214
  raise ValueError("Invalid target image")
215
 
 
216
  src_faces = face_app.get(src)
217
  tgt_faces = face_app.get(tgt)
218
 
 
221
  if not tgt_faces:
222
  raise ValueError("No face detected in target image")
223
 
 
224
  s_face = src_faces[0]
225
  t_face = tgt_faces[0]
226
 
 
227
  swapped = swap_faces(tgt, t_face, s_face)
228
 
229
+ # Attempt to extract bbox and do enhancement + color transfer + blend
230
  try:
 
231
  x1, y1, x2, y2 = map(int, map(round, t_face.bbox))
232
  h, w = swapped.shape[:2]
233
+ x1, x2 = max(0, min(x1, w - 1)), max(0, min(x2, w))
234
+ y1, y2 = max(0, min(y1, h - 1)), max(0, min(y2, h))
235
+
236
  if x2 - x1 > 10 and y2 - y1 > 10:
237
  face_crop = swapped[y1:y2, x1:x2].copy()
238
 
239
+ # enhancement (ONNX or OpenCV)
240
  enhanced = enhance_face(face_crop)
241
 
242
+ # color transfer: match enhanced face colors to the target's surrounding region
243
+ # sample a slightly expanded patch from the original target to get ambient color
244
+ pad = int(0.15 * max(y2 - y1, x2 - x1))
245
+ rx1, ry1 = max(0, x1 - pad), max(0, y1 - pad)
246
+ rx2, ry2 = min(w, x2 + pad), min(h, y2 + pad)
247
+ ref_region = tgt[ry1:ry2, rx1:rx2]
248
+ if ref_region is None or ref_region.size == 0:
249
+ ref_region = tgt[y1:y2, x1:x2]
250
+ # perform Reinhard-style mean/std transfer in LAB on enhanced crop
251
+ transferred = color_transfer_reinhard(enhanced, ref_region)
252
+
253
+ # blend transferred face back into swapped image softly
254
+ swapped = feather_blend(swapped, transferred, (x1, y1, x2, y2), feather_amount=0.35)
255
  else:
256
+ # bbox too small, keep swapped
257
  pass
258
  except Exception as e:
259
+ print("Enhancement/color-blend failed:", e)
 
260
  print(traceback.format_exc())
261
 
 
262
  out_path = f"/tmp/{task_id}.jpg"
263
  ok = cv2.imwrite(out_path, swapped)
264
  if not ok:
 
277
  # -----------------------------------------------------------
278
  # FASTAPI app & UI
279
  # -----------------------------------------------------------
280
+ app = FastAPI(title="FaceSwap Async API (image-only, color harmonized)")
281
 
282
  @app.get("/", response_class=HTMLResponse)
283
  def home():
 
284
  return """
285
  <!doctype html>
286
  <html lang="en">
 
289
  <title>FaceSwap Async API</title>
290
  <meta name="viewport" content="width=device-width,initial-scale=1">
291
  <style>
292
+ :root{--bg:#0e1525;--card:rgba(255,255,255,0.06);--accent:#6c5ce7;--accent2:#00cec9;--text:#e8ecf2;--muted:#9ba3b4;--danger:#ff6b6b;--success:#2ecc71;--radius:14px}
293
+ body{background:linear-gradient(160deg,#111827,#0e1525);font-family:Inter, sans-serif;padding:25px;color:var(--text);margin:0}
294
+ .card{background:var(--card);padding:22px;border-radius:var(--radius);box-shadow:0 8px 25px rgba(0,0,0,0.4);max-width:900px;margin:auto}
295
+ h1{font-size:24px;margin:0 0 6px 0}.subtitle{color:var(--muted);font-size:14px;margin-bottom:14px}.row{display:flex;gap:12px;align-items:center;flex-wrap:wrap}.upload-zone{flex:1;min-width:220px;border:2px dashed rgba(255,255,255,0.12);padding:18px;border-radius:12px;text-align:center;cursor:pointer;transition:0.12s}.upload-zone.dragover{border-color:var(--accent);background:rgba(255,255,255,0.02);transform:translateY(-4px)}input[type=file]{display:none}button{background:linear-gradient(90deg,var(--accent),var(--accent2));border:none;padding:12px 18px;border-radius:12px;color:white;font-weight:600;cursor:pointer}.status{margin-top:12px;color:var(--muted)}.task-id{margin-top:8px;color:var(--accent2);font-weight:700}.preview{margin-top:16px;text-align:center}.preview img{max-width:100%;border-radius:10px;border:1px solid rgba(255,255,255,0.06);background:#021017}.loader{width:20px;height:20px;border:3px solid rgba(255,255,255,0.12);border-top-color:var(--accent2);border-radius:50%;animation:spin .9s linear infinite;display:inline-block}@keyframes spin{to{transform:rotate(360deg)}}
 
 
 
 
 
 
 
 
 
 
 
296
  </style>
297
  </head>
298
  <body>
299
  <div class="card">
300
  <h1>Async FaceSwap</h1>
301
  <div class="subtitle">Upload a source face image and a target image. You get a task id immediately; result will appear when ready.</div>
 
302
  <div class="row">
303
+ <label class="upload-zone" id="zoneSource"><div><strong id="labelSource">Click or drop source image</strong></div><input id="inputSource" type="file" accept="image/*"></label>
304
+ <label class="upload-zone" id="zoneTarget"><div><strong id="labelTarget">Click or drop target image</strong></div><input id="inputTarget" type="file" accept="image/*"></label>
 
 
 
 
 
 
 
305
  </div>
306
+ <div class="row" style="margin-top:12px">
 
307
  <button id="startBtn">Start FaceSwap</button>
308
+ <button id="resetBtn" style="background:transparent;border:1px solid rgba(255,255,255,0.06);color:var(--muted)">Reset</button>
309
  <div style="flex:1"></div>
310
  <div id="status" class="status"></div>
311
  </div>
 
312
  <div class="task-id" id="taskId"></div>
313
+ <div class="preview"><img id="resultImg" style="display:none"></div>
 
 
 
 
314
  </div>
 
315
  <script>
316
  (function(){
317
+ const zoneSource=document.getElementById('zoneSource'), zoneTarget=document.getElementById('zoneTarget');
318
+ const inputSource=document.getElementById('inputSource'), inputTarget=document.getElementById('inputTarget');
319
+ const labelSource=document.getElementById('labelSource'), labelTarget=document.getElementById('labelTarget');
320
+ const startBtn=document.getElementById('startBtn'), resetBtn=document.getElementById('resetBtn');
321
+ const status=document.getElementById('status'), taskIdEl=document.getElementById('taskId'), resultImg=document.getElementById('resultImg');
322
+ let sourceFile=null, targetFile=null, pollHandle=null;
323
+ function prevent(e){e.preventDefault();e.stopPropagation();}
324
+ function makeDrop(zone,input,label){
325
+ ['dragenter','dragover','dragleave','drop'].forEach(ev=>zone.addEventListener(ev,prevent));
326
+ zone.addEventListener('dragover',()=>zone.classList.add('dragover'));
327
+ zone.addEventListener('dragleave',()=>zone.classList.remove('dragover'));
328
+ zone.addEventListener('drop',ev=>{zone.classList.remove('dragover');const f=ev.dataTransfer.files&&ev.dataTransfer.files[0];if(!f) return;input.files=ev.dataTransfer.files;label.innerText=f.name;if(input===inputSource) sourceFile=f; if(input===inputTarget) targetFile=f;});
329
+ zone.addEventListener('click',()=>input.click());
330
+ input.addEventListener('change',()=>{const f=input.files&&input.files[0]; if(!f) return; label.innerText=f.name; if(input===inputSource) sourceFile=f; if(input===inputTarget) targetFile=f; });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  }
332
+ makeDrop(zoneSource,inputSource,labelSource); makeDrop(zoneTarget,inputTarget,labelTarget);
333
+ startBtn.addEventListener('click',async()=>{
334
+ if(!sourceFile||!targetFile){alert('Select source and target images');return;}
335
+ status.innerHTML='Uploading <span class="loader"></span>'; startBtn.disabled=true;
336
+ try{
337
+ const fd=new FormData(); fd.append('source', sourceFile); fd.append('target', targetFile);
338
+ const res=await fetch('/swap-image',{method:'POST',body:fd});
339
+ if(!res.ok){ const t=await res.text().catch(()=>res.statusText); status.innerHTML='<span style="color:#ff6b6b">Upload failed: '+t+'</span>'; startBtn.disabled=false; return; }
340
+ const data=await res.json(); taskIdEl.innerText='Task ID: '+data.task_id; pollStatus(data.task_id);
341
+ }catch(err){ console.error(err); status.innerText='Network error'; startBtn.disabled=false; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  });
343
+ resetBtn.addEventListener('click',()=>{ sourceFile=targetFile=null; inputSource.value=''; inputTarget.value=''; labelSource.innerText='Click or drop source image'; labelTarget.innerText='Click or drop target image'; status.innerText=''; taskIdEl.innerText=''; resultImg.style.display='none'; if(pollHandle){clearInterval(pollHandle); pollHandle=null;} startBtn.disabled=false; });
344
+ function pollStatus(taskId){ pollHandle=setInterval(async()=>{
345
+ try{
346
+ const res=await fetch('/task-status/'+taskId); const data=await res.json();
347
+ if(data.status==='processing'){ status.innerHTML='Processing <span class="loader"></span>'; }
348
+ else if(data.status==='failed'){ status.innerHTML='<span style="color:#ff6b6b">Failed: '+(data.error||'')+'</span>'; clearInterval(pollHandle); pollHandle=null; startBtn.disabled=false; }
349
+ else if(data.status==='done'){ status.innerHTML='<span style="color:#2ecc71">Completed</span>'; clearInterval(pollHandle); pollHandle=null; fetchResult(taskId); } else { status.innerText=data.status; }
350
+ }catch(err){ console.error(err); status.innerText='Status error'; }
351
+ },1000); }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  async function fetchResult(taskId){
353
+ try{
354
+ const res=await fetch('/task-result/'+taskId); if(!res.ok){ const t=await res.text().catch(()=>res.statusText); status.innerHTML='<span style="color:#ff6b6b">Result fetch failed: '+t+'</span>'; startBtn.disabled=false; return; }
355
+ const blob=await res.blob(); const url=URL.createObjectURL(blob); resultImg.src=url; resultImg.style.display='block'; startBtn.disabled=false;
356
+ }catch(err){ console.error(err); status.innerText='Fetch error'; startBtn.disabled=false; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  }
 
358
  })();
359
  </script>
360
  </body>
 
372
  task_id = str(uuid.uuid4())
373
  TASKS[task_id] = {"status": "queued"}
374
 
 
375
  executor.submit(run_task, task_id, src_bytes, tgt_bytes)
376
 
377
  return {"task_id": task_id, "status": "queued"}
 
397
  task = TASKS[task_id]
398
  if task["status"] != "done":
399
  return {"status": task["status"], "error": task.get("error")}
 
400
  return StreamingResponse(open(task["result"], "rb"), media_type="image/jpeg")
401
 
402
 
 
404
  # INIT MODELS
405
  # -----------------------------------------------------------
406
  print("Initializing models...")
407
+ # detection model
408
  face_app = FaceAnalysis(name="buffalo_l")
409
  face_app.prepare(ctx_id=-1, det_size=(640, 640))
410
 
411
+ # swapper model
412
  download_swapper_model()
413
  swapper = insightface.model_zoo.get_model("inswapper_128.onnx", root=os.path.dirname(__file__))
414
 
415
+ # optional enhancer
416
  if ONNX_AVAILABLE:
417
  try_load_enhancer()
418