import base64
import os
from io import BytesIO
from typing import Any, Dict

import cv2
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf
from PIL import Image
from transformers import CLIPVisionModelWithProjection

from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.util import get_fps, read_frames, save_videos_grid

import roop.globals
from roop.core import (
    decode_execution_providers,
    start,
    suggest_execution_threads,
    suggest_max_memory,
)
from roop.processors.frame.core import get_frame_processors_modules
from roop.utilities import normalize_output_path
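
# The pipeline is loaded in fp16 and requires a CUDA device.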
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device.type != 'cuda':
    raise RuntimeError("The model requires a GPU for inference.")

class EndpointHandler:
    def __init__(self, path=""):
        base_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(base_dir, 'configs', 'prompts', 'animation.yaml')

        if not os.path.exists(config_path):
            raise FileNotFoundError(f"The configuration file was not found at: {config_path}")

        self.config = OmegaConf.load(config_path)
        self.weight_dtype = torch.float16
        self.pipeline = None
        self._initialize_pipeline()
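
    # Build the Pose2Video pipeline: VAE, 2D reference UNet, 3D denoising UNet
    # with motion module, pose guider, CLIP image encoder and DDIM scheduler.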
    def _initialize_pipeline(self):
        base_dir = os.path.dirname(os.path.abspath(__file__))
        vae_path = os.path.join(base_dir, 'pretrained_weights', 'sd-vae-ft-mse')

        if not os.path.exists(vae_path):
            raise FileNotFoundError(f"The sd-vae-ft-mse folder was not found at: {vae_path}")

        vae = AutoencoderKL.from_pretrained(vae_path).to(device, dtype=self.weight_dtype)

        pretrained_base_model_path_unet = os.path.join(base_dir, 'pretrained_weights', 'stable-diffusion-v1-5', 'unet')
        # Load from the absolute path assembled above so the handler works
        # regardless of the current working directory.
        reference_unet = UNet2DConditionModel.from_pretrained(
            pretrained_base_model_path_unet,
            local_files_only=True,
        ).to(device, dtype=self.weight_dtype)
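
        # Checkpoint locations for the motion module, UNets, pose guider and
        # the CLIP image encoder, all under pretrained_weights/.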
        inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')
        motion_module_path = os.path.join(base_dir, 'pretrained_weights', 'motion_module.pth')
        denoising_unet_path = os.path.join(base_dir, 'pretrained_weights', 'denoising_unet.pth')
        reference_unet_path = os.path.join(base_dir, 'pretrained_weights', 'reference_unet.pth')
        pose_guider_path = os.path.join(base_dir, 'pretrained_weights', 'pose_guider.pth')
        image_encoder_path = os.path.join(base_dir, 'pretrained_weights', 'image_encoder')

        infer_config = OmegaConf.load(inference_config_path)
        denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_base_model_path_unet,
            motion_module_path,
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(device, dtype=self.weight_dtype)
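
        # The pose guider encodes the driving pose frames; the CLIP vision
        # model embeds the reference image for conditioning.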
        pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(device, dtype=self.weight_dtype)
        image_enc = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(device, dtype=self.weight_dtype)
        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)

        # strict=False: the motion-module weights come from motion_module_path
        # (already merged by from_pretrained_2d), not from this checkpoint.
        denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
        reference_unet.load_state_dict(torch.load(reference_unet_path, map_location="cpu"))
        pose_guider.load_state_dict(torch.load(pose_guider_path, map_location="cpu"))
        self.pipeline = Pose2VideoPipeline(
            vae=vae,
            image_encoder=image_enc,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            pose_guider=pose_guider,
            scheduler=scheduler,
        ).to(device, dtype=self.weight_dtype)
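
    # Detect a face in the reference image with OpenCV's Haar cascade and crop
    # it, padded by a relative margin, for use as the roop source face.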
    def _crop_face(self, image, save_path="cropped_face.jpg", margin=0.5):
        # PIL gives RGB; OpenCV expects BGR.
        cv_image = np.array(image)
        cv_image = cv_image[:, :, ::-1].copy()

        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

        gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)

        if len(faces) == 0:
            raise ValueError("No faces detected in the reference image.")

        # Take the first detection and pad it by `margin` on every side,
        # clamped to the image bounds.
        x, y, w, h = faces[0]
        x_margin = int(margin * w)
        y_margin = int(margin * h)

        x1 = max(0, x - x_margin)
        y1 = max(0, y - y_margin)
        x2 = min(cv_image.shape[1], x + w + x_margin)
        y2 = min(cv_image.shape[0], y + h + y_margin)

        cropped_face = cv_image[y1:y2, x1:x2]

        # Convert back to RGB for PIL before saving.
        cropped_face = Image.fromarray(cropped_face[:, :, ::-1]).convert("RGB")
        cropped_face.save(save_path, format="JPEG", quality=95)

        return cropped_face
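
    # Configure roop's module-level globals, then run its face_swapper and
    # face_enhancer frame processors over the generated video.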
    def _swap_face(self, source_image, target_video_path):
        source_path = "input.jpg"
        source_image.save(source_path, format="JPEG", quality=95)
        output_path = "output.mp4"

        roop.globals.source_path = source_path
        roop.globals.target_path = target_video_path
        roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, output_path)
        roop.globals.frame_processors = ["face_swapper", "face_enhancer"]
        roop.globals.headless = True
        roop.globals.keep_fps = True
        roop.globals.keep_audio = True
        roop.globals.keep_frames = False
        roop.globals.many_faces = False
        roop.globals.video_encoder = "libx264"
        # Forwarded to ffmpeg; for libx264 lower values generally mean higher quality.
        roop.globals.video_quality = 50
        roop.globals.max_memory = suggest_max_memory()
        # The swap itself runs on the CPU execution provider here.
        roop.globals.execution_providers = decode_execution_providers(["cpu"])
        roop.globals.execution_threads = suggest_execution_threads()

        # pre_check() verifies (and may download) each processor's model;
        # fail fast if any processor cannot initialize.
        for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
            if not frame_processor.pre_check():
                raise ValueError("Frame processor pre-check failed.")

        start()

        return os.path.join(os.getcwd(), output_path)
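
    # Inference entry point. Expects a JSON payload with a base64-encoded
    # reference image and a server-side path to the driving pose video.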
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        inputs = data.get("inputs", {})
        ref_image_base64 = inputs.get("ref_image", "")
        pose_video_path = inputs.get("pose_video_path", "")
        width = inputs.get("width", 512)
        height = inputs.get("height", 768)
        length = inputs.get("length", 24)
        num_inference_steps = inputs.get("num_inference_steps", 25)
        cfg = inputs.get("cfg", 3.5)
        seed = inputs.get("seed", 123)

        if not ref_image_base64 or not pose_video_path:
            raise ValueError("Both 'ref_image' and 'pose_video_path' are required inputs.")

        # Force RGB so the downstream numpy/OpenCV code always sees 3 channels.
        ref_image = Image.open(BytesIO(base64.b64decode(ref_image_base64))).convert("RGB")

        torch.manual_seed(seed)
        pose_images = read_frames(pose_video_path)
        src_fps = get_fps(pose_video_path)
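
        # Truncate the driving sequence to the requested clip length.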
        total_length = min(length, len(pose_images))
        pose_list = list(pose_images[:total_length])

        video = self.pipeline(
            ref_image,
            pose_list,
            width=width,
            height=height,
            video_length=total_length,
            num_inference_steps=num_inference_steps,
            guidance_scale=cfg,
        ).videos
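
        # Write the raw animation to disk, then swap the cropped reference
        # face onto it with roop.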
        save_dir = "./output/gradio"
        os.makedirs(save_dir, exist_ok=True)
        animation_path = os.path.join(save_dir, "animation_output.mp4")
        save_videos_grid(video, animation_path, n_rows=1, fps=src_fps)

        cropped_face_path = os.path.join(save_dir, "cropped_face.jpg")
        cropped_face = self._crop_face(ref_image, save_path=cropped_face_path)

        final_video_path = self._swap_face(cropped_face, animation_path)

        with open(final_video_path, "rb") as video_file:
            video_base64 = base64.b64encode(video_file.read()).decode("utf-8")

        torch.cuda.empty_cache()

        return {"video": video_base64}
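
# Example request payload (illustrative; the numeric values are the handler's
# defaults):
# {
#     "inputs": {
#         "ref_image": "<base64-encoded image>",
#         "pose_video_path": "/path/to/pose_video.mp4",
#         "width": 512,
#         "height": 768,
#         "length": 24,
#         "num_inference_steps": 25,
#         "cfg": 3.5,
#         "seed": 123
#     }
# }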