File size: 8,630 Bytes
dd31ccf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43d63ec
 
 
 
 
 
 
dd31ccf
 
 
 
 
256a090
 
 
 
 
 
 
dd31ccf
50f301d
4738bab
f751834
22af2ed
 
 
 
dd31ccf
7c65853
 
 
 
 
 
 
dd31ccf
 
50f301d
7c65853
dd31ccf
 
 
 
7c65853
dd31ccf
 
 
7c65853
 
 
dd31ccf
 
 
 
 
 
 
 
 
 
a7f2e6c
dd31ccf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c7fbb7
 
dd31ccf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a9c05f
dd31ccf
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
from typing import Dict, Any
import torch
from PIL import Image
import base64
from io import BytesIO
import numpy as np
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from omegaconf import OmegaConf
from transformers import CLIPVisionModelWithProjection
import cv2
import os
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.util import read_frames, get_fps, save_videos_grid
import roop.globals
from roop.core import start, decode_execution_providers, suggest_max_memory, suggest_execution_threads
from roop.utilities import normalize_output_path
from roop.processors.frame.core import get_frame_processors_modules

# The animation pipeline only runs on CUDA: refuse to load the module at all
# on a CPU-only host instead of failing later on the first request.
if not torch.cuda.is_available():
    raise ValueError("The model requires a GPU for inference.")
device = torch.device('cuda')

class EndpointHandler():
    """Endpoint that animates a reference image along a pose video, then face-swaps it.

    Construction loads ``configs/prompts/animation.yaml`` and builds a
    ``Pose2VideoPipeline`` from weights under ``pretrained_weights/`` next to
    this file. Each call renders the animation, swaps the reference face back
    onto it with roop and returns the resulting video as base64.
    """

    def __init__(self, path=""):
        """Load the animation config and build the inference pipeline.

        Args:
            path: Unused; every asset is resolved relative to this file's
                directory.

        Raises:
            FileNotFoundError: If ``configs/prompts/animation.yaml`` is missing.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(base_dir, 'configs', 'prompts', 'animation.yaml')

        if not os.path.exists(config_path):
            raise FileNotFoundError(f"The configuration file was not found at: {config_path}")

        self.config = OmegaConf.load(config_path)
        self.weight_dtype = torch.float16
        self.pipeline = None
        self._initialize_pipeline()

    def _initialize_pipeline(self):
        """Load every model component onto ``device`` in ``self.weight_dtype`` and build ``self.pipeline``.

        Raises:
            FileNotFoundError: If the ``sd-vae-ft-mse`` weights folder is missing.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        weights_dir = os.path.join(base_dir, 'pretrained_weights')

        vae_path = os.path.join(weights_dir, 'sd-vae-ft-mse')
        if not os.path.exists(vae_path):
            raise FileNotFoundError(f"The sd-vae-ft-mse folder was not found at: {vae_path}")
        vae = AutoencoderKL.from_pretrained(vae_path).to(device, dtype=self.weight_dtype)

        # Both UNets are built from the same SD-1.5 base UNet directory,
        # resolved relative to this file like every other asset here (the
        # reference UNet previously came from self.config instead and was
        # pinned to "cuda" rather than the module-level device).
        pretrained_base_model_path_unet = os.path.join(weights_dir, 'stable-diffusion-v1-5', 'unet')

        print("model path is " + pretrained_base_model_path_unet)
        reference_unet = UNet2DConditionModel.from_pretrained(
            pretrained_base_model_path_unet,
        ).to(device, dtype=self.weight_dtype)

        inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')
        motion_module_path = os.path.join(weights_dir, 'motion_module.pth')
        denoising_unet_path = os.path.join(weights_dir, 'denoising_unet.pth')
        reference_unet_path = os.path.join(weights_dir, 'reference_unet.pth')
        pose_guider_path = os.path.join(weights_dir, 'pose_guider.pth')
        image_encoder_path = os.path.join(weights_dir, 'image_encoder')

        infer_config = OmegaConf.load(inference_config_path)
        denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_base_model_path_unet,
            motion_module_path,
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(device, dtype=self.weight_dtype)

        pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(device, dtype=self.weight_dtype)
        image_enc = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(device, dtype=self.weight_dtype)
        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)

        # strict=False only for the denoising UNet; presumably its checkpoint
        # does not cover every parameter (e.g. the motion-module weights loaded
        # above) — confirm before tightening.
        denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
        reference_unet.load_state_dict(torch.load(reference_unet_path, map_location="cpu"))
        pose_guider.load_state_dict(torch.load(pose_guider_path, map_location="cpu"))

        self.pipeline = Pose2VideoPipeline(
            vae=vae,
            image_encoder=image_enc,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            pose_guider=pose_guider,
            scheduler=scheduler
        ).to(device, dtype=self.weight_dtype)

    def _crop_face(self, image, save_path="cropped_face.jpg", margin=0.5):
        """Crop the first face OpenCV's Haar cascade finds in *image* and save it as JPEG.

        Args:
            image: PIL image to search. It is converted to RGB first, so
                grayscale, palette and RGBA images are accepted.
            save_path: Destination of the JPEG crop.
            margin: Padding added on each side, as a fraction of the detected
                face's width (horizontally) and height (vertically); the box is
                clamped to the image bounds.

        Returns:
            The cropped face as an RGB PIL image.

        Raises:
            ValueError: If no face is detected.
        """
        # Normalise to 3-channel RGB before flipping to OpenCV's BGR order: a
        # grayscale image would otherwise be a 2-D array (the channel flip
        # raises) and RGBA would reach the detector with misaligned channels.
        cv_image = np.array(image.convert("RGB"))[:, :, ::-1].copy()

        # Load OpenCV face detector
        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

        # Detect faces (scaleFactor=1.1, minNeighbors=4)
        gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)

        # len() rather than truthiness: detectMultiScale may return an ndarray.
        if len(faces) == 0:
            raise ValueError("No faces detected in the reference image.")

        # Crop the first face found with a margin
        x, y, w, h = faces[0]
        x_margin = int(margin * w)
        y_margin = int(margin * h)

        x1 = max(0, x - x_margin)
        y1 = max(0, y - y_margin)
        x2 = min(cv_image.shape[1], x + w + x_margin)
        y2 = min(cv_image.shape[0], y + h + y_margin)

        cropped_face = cv_image[y1:y2, x1:x2]

        # Convert back to PIL format (BGR -> RGB)
        cropped_face = Image.fromarray(cropped_face[:, :, ::-1]).convert("RGB")

        cropped_face.save(save_path, format="JPEG", quality=95)

        return cropped_face

    def _swap_face(self, source_image, target_video_path):
        """Swap *source_image*'s face onto *target_video_path* using roop.

        Writes *source_image* to ``input.jpg`` in the working directory,
        configures roop's module-level globals (face_swapper + face_enhancer,
        headless, CPU execution provider, keep fps/audio, libx264) and runs it.

        Returns:
            Absolute path of the output file roop was configured to write.

        Raises:
            ValueError: If any frame processor's pre-check fails.
        """
        source_path = "input.jpg"
        source_image.save(source_path, format="JPEG", quality=95)
        output_path = "output.mp4"

        roop.globals.source_path = source_path
        roop.globals.target_path = target_video_path
        roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, output_path)
        roop.globals.frame_processors = ["face_swapper", "face_enhancer"]
        roop.globals.headless = True
        roop.globals.keep_fps = True
        roop.globals.keep_audio = True
        roop.globals.keep_frames = False
        roop.globals.many_faces = False
        roop.globals.video_encoder = "libx264"
        roop.globals.video_quality = 50
        roop.globals.max_memory = suggest_max_memory()
        roop.globals.execution_providers = decode_execution_providers(["cpu"])
        roop.globals.execution_threads = suggest_execution_threads()

        for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
            if not frame_processor.pre_check():
                raise ValueError("Frame processor pre-check failed.")

        start()

        # Return the path roop was actually told to write, rather than
        # re-deriving it from the literal name passed to normalize_output_path.
        return os.path.abspath(roop.globals.output_path)

    def __call__(self, data: Any) -> Dict[str, str]:
        """Animate the reference image along the pose video and face-swap the result.

        Args:
            data: Request payload. Parameters are read from ``data["inputs"]``:
                ``ref_image`` (base64-encoded image, required),
                ``pose_video_path`` (path readable by ``read_frames``, required),
                ``width`` (default 512), ``height`` (768), ``length`` (maximum
                number of pose frames, 24), ``num_inference_steps`` (25),
                ``cfg`` (guidance scale, 3.5), ``seed`` (123).

        Returns:
            ``{"video": <base64-encoded bytes of the face-swapped video>}``.

        Raises:
            ValueError: If a required input is missing, the pose video yields
                no usable frames, no face is detected in the reference image,
                or a roop frame-processor pre-check fails.
        """
        inputs = data.get("inputs", {})
        ref_image_base64 = inputs.get("ref_image", "")
        pose_video_path = inputs.get("pose_video_path", "")
        width = inputs.get("width", 512)
        height = inputs.get("height", 768)
        length = inputs.get("length", 24)
        num_inference_steps = inputs.get("num_inference_steps", 25)
        cfg = inputs.get("cfg", 3.5)
        seed = inputs.get("seed", 123)

        if not ref_image_base64:
            raise ValueError("'inputs.ref_image' (a base64-encoded image) is required.")
        if not pose_video_path:
            raise ValueError("'inputs.pose_video_path' is required.")

        # Normalise to RGB so grayscale / palette / RGBA uploads are handled
        # uniformly by both the pipeline and the face detector.
        ref_image = Image.open(BytesIO(base64.b64decode(ref_image_base64))).convert("RGB")

        save_dir = "./output/gradio"
        os.makedirs(save_dir, exist_ok=True)

        # Crop the face before running diffusion: an image without a detectable
        # face now fails fast instead of after the expensive animation step.
        cropped_face_path = os.path.join(save_dir, "cropped_face.jpg")
        cropped_face = self._crop_face(ref_image, save_path=cropped_face_path)

        try:
            torch.manual_seed(seed)
            pose_images = read_frames(pose_video_path)
            src_fps = get_fps(pose_video_path)

            total_length = min(length, len(pose_images))
            if total_length <= 0:
                raise ValueError(f"No pose frames to animate from: {pose_video_path}")
            pose_list = list(pose_images[:total_length])

            video = self.pipeline(
                ref_image,
                pose_list,
                width=width,
                height=height,
                video_length=total_length,
                num_inference_steps=num_inference_steps,
                guidance_scale=cfg
            ).videos

            animation_path = os.path.join(save_dir, "animation_output.mp4")
            save_videos_grid(video, animation_path, n_rows=1, fps=src_fps)

            # Perform face swapping
            final_video_path = self._swap_face(cropped_face, animation_path)

            # Encode the final video in base64
            with open(final_video_path, "rb") as video_file:
                video_base64 = base64.b64encode(video_file.read()).decode("utf-8")
        finally:
            # Release cached GPU memory even when a step fails, so a failed
            # request does not leave it held for the next one.
            torch.cuda.empty_cache()

        return {"video": video_base64}