File size: 3,614 Bytes
05a3220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from __future__ import annotations

import imghdr
import tempfile
import threading
from pathlib import Path

import cv2
import onnxruntime

import modules.globals
from modules.face_analyser import get_face_analyser
from modules.processors.frame import face_swapper

# Held by initialize_runtime() while the one-time setup runs, so that
# concurrent callers cannot perform the setup twice.
_RUNTIME_LOCK = threading.Lock()
# Held by swap_face_from_uploads() for the whole swap, so only one thread at a
# time runs the temp-file / process_image / encode sequence.
_INFERENCE_LOCK = threading.Lock()
# Flipped to True by initialize_runtime() once every setup step has succeeded.
_RUNTIME_READY = False


def _select_execution_providers() -> list[str]:
    """Return the ONNX Runtime execution providers to use, best first.

    The providers reported by ``onnxruntime.get_available_providers()`` are
    filtered against a fixed preference order (CUDA, CoreML, DirectML, CPU).

    Returns:
        The preferred providers that are available, in preference order, or
        ``["CPUExecutionProvider"]`` when none of them is reported.
    """
    installed = onnxruntime.get_available_providers()
    ranked = (
        "CUDAExecutionProvider",
        "CoreMLExecutionProvider",
        "DmlExecutionProvider",
        "CPUExecutionProvider",
    )

    chosen = []
    for name in ranked:
        if name in installed:
            chosen.append(name)

    if not chosen:
        # Nothing from the preference list was reported; fall back to CPU.
        return ["CPUExecutionProvider"]
    return chosen


def _configure_globals() -> None:
    """Populate ``modules.globals`` with the fixed settings used by this module.

    Only the ``face_swapper`` frame processor is enabled, the execution
    providers come from ``_select_execution_providers()``, and every other
    attribute is set to the constant listed in the table below.
    """
    settings = modules.globals

    # These two stay explicit so provider selection still runs right after
    # ``frame_processors`` is assigned and before any other attribute.
    settings.frame_processors = ["face_swapper"]
    settings.execution_providers = _select_execution_providers()

    constants = {
        "execution_threads": 1,
        "keep_fps": True,
        "keep_audio": False,
        "keep_frames": False,
        "many_faces": False,
        "map_faces": False,
        "mouth_mask": False,
        "nsfw_filter": False,
        "headless": True,
        "opacity": 1.0,
        "sharpness": 0.0,
        "enable_interpolation": False,
        "log_level": "error",
    }
    # dicts preserve insertion order, so attributes are assigned top to bottom.
    for attribute, value in constants.items():
        setattr(settings, attribute, value)


def initialize_runtime() -> None:
    """Run the one-time runtime setup; later calls return immediately.

    The ready flag is tested once without the lock (cheap path for calls made
    after setup has finished) and again after acquiring ``_RUNTIME_LOCK``, so
    a thread that waited on the lock does not repeat the setup.

    Raises:
        RuntimeError: If ``face_swapper.pre_check()`` or
            ``face_swapper.pre_start()`` returns a falsy value. The ready flag
            stays ``False`` in that case, so a later call retries the setup.
    """
    global _RUNTIME_READY

    if not _RUNTIME_READY:
        with _RUNTIME_LOCK:
            if not _RUNTIME_READY:
                _configure_globals()

                checked = face_swapper.pre_check()
                if not checked:
                    raise RuntimeError("Face swapper model setup failed.")

                started = face_swapper.pre_start()
                if not started:
                    raise RuntimeError("Face swapper model failed to initialize.")

                # Return value is discarded; presumably this warms up the
                # analyser before the first request (confirm in face_analyser).
                get_face_analyser()
                _RUNTIME_READY = True


def _infer_extension(file_path: Path) -> str:
    """Return the file extension matching the image format stored at *file_path*.

    The format is identified from the file's leading magic bytes rather than
    with ``imghdr``: that module was deprecated by PEP 594 and removed in
    Python 3.13, and it fails to recognise some valid JPEG files (for example
    ones whose first segment is not JFIF/Exif).

    Args:
        file_path: Path of the file to inspect. Only the first 12 bytes are
            read.

    Returns:
        ``".jpg"``, ``".png"`` or ``".webp"``.

    Raises:
        ValueError: If the content is not JPEG, PNG or WebP (including an
            empty or truncated file).
        OSError: If the file cannot be opened or read.
    """
    with open(file_path, "rb") as handle:
        # 12 bytes cover the longest signature checked below (RIFF....WEBP).
        header = handle.read(12)

    # JPEG: SOI marker (FF D8) followed by the first byte of the next marker.
    if header.startswith(b"\xff\xd8\xff"):
        return ".jpg"
    # PNG: fixed 8-byte signature.
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return ".png"
    # WebP: RIFF container whose form type (bytes 8-11) is "WEBP".
    if header[:4] == b"RIFF" and header[8:12] == b"WEBP":
        return ".webp"
    raise ValueError("Downloaded file is not a supported image.")


def _write_uploaded_image(content: bytes, directory: Path, stem: str) -> Path:
    """Store *content* in *directory* under *stem* plus a detected extension.

    The bytes are first written to ``<stem>.bin``; the image type is then
    detected from that file and the file is renamed to carry the matching
    suffix.

    Args:
        content: Raw image bytes.
        directory: Existing directory to write into.
        stem: File name without extension.

    Returns:
        Path of the renamed file.

    Raises:
        ValueError: Propagated from ``_infer_extension`` when the content is
            not a supported image; the ``.bin`` file is left in place.
    """
    staged = directory.joinpath(f"{stem}.bin")
    staged.write_bytes(content)

    suffix = _infer_extension(staged)
    destination = staged.with_suffix(suffix)
    staged.replace(destination)
    return destination


def swap_face_from_uploads(source_image_bytes: bytes, target_image_bytes: bytes) -> tuple[bytes, str]:
    """Run the face swapper on two in-memory images and return the result as PNG.

    Both inputs are written to a private temporary directory,
    ``face_swapper.process_image`` is invoked on the resulting paths, and the
    produced file is re-read and re-encoded with OpenCV. The whole sequence
    runs while holding ``_INFERENCE_LOCK``, so calls are processed one at a
    time. The temporary directory is removed when the function exits.

    Args:
        source_image_bytes: Raw bytes of the source image.
        target_image_bytes: Raw bytes of the target image.

    Returns:
        A ``(png_bytes, "image/png")`` tuple.

    Raises:
        ValueError: If either input is not a supported image type.
        RuntimeError: If runtime initialisation fails, no output file is
            produced, or the output cannot be read or encoded.
    """
    initialize_runtime()

    with _INFERENCE_LOCK, tempfile.TemporaryDirectory(prefix="face-swap-") as workdir_name:
        workdir = Path(workdir_name)
        source_file = _write_uploaded_image(source_image_bytes, workdir, "source")
        target_file = _write_uploaded_image(target_image_bytes, workdir, "target")
        result_file = workdir / "result.png"

        face_swapper.process_image(str(source_file), str(target_file), str(result_file))
        if not result_file.exists():
            raise RuntimeError("Face swap failed and no output image was produced.")

        swapped = cv2.imread(str(result_file))
        if swapped is None:
            # cv2.imread signals an unreadable file by returning None.
            raise RuntimeError("Generated output image could not be read.")

        success, png_buffer = cv2.imencode(".png", swapped)
        if not success:
            raise RuntimeError("Generated output image could not be encoded.")

        return png_buffer.tobytes(), "image/png"