from typing import Any, Dict

import base64
import os
from io import BytesIO

import cv2
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf
from PIL import Image
from transformers import CLIPVisionModelWithProjection

from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.util import get_fps, read_frames, save_videos_grid

import roop.globals
from roop.core import (
    start,
    decode_execution_providers,
    suggest_max_memory,
    suggest_execution_threads,
)
from roop.processors.frame.core import get_frame_processors_modules
from roop.utilities import normalize_output_path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("The model requires a GPU for inference.")


class EndpointHandler:
    def __init__(self, path: str = ""):
        base_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(base_dir, "configs", "prompts", "animation.yaml")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"The configuration file was not found at: {config_path}")
        self.config = OmegaConf.load(config_path)
        self.weight_dtype = torch.float16
        self.pipeline = None
        self._initialize_pipeline()

    def _initialize_pipeline(self):
        """Load all model components and assemble the Pose2VideoPipeline."""
        base_dir = os.path.dirname(os.path.abspath(__file__))

        # VAE (sd-vae-ft-mse)
        vae_path = os.path.join(base_dir, "pretrained_weights", "sd-vae-ft-mse")
        if not os.path.exists(vae_path):
            raise FileNotFoundError(f"The sd-vae-ft-mse folder was not found at: {vae_path}")
        vae = AutoencoderKL.from_pretrained(vae_path).to(device, dtype=self.weight_dtype)

        # 2D reference UNet (Stable Diffusion v1-5 backbone)
        pretrained_base_model_path_unet = os.path.join(
            base_dir, "pretrained_weights", "stable-diffusion-v1-5", "unet"
        )
        reference_unet = UNet2DConditionModel.from_pretrained(
            pretrained_base_model_path_unet, local_files_only=True
        ).to(device, dtype=self.weight_dtype)

        inference_config_path = os.path.join(base_dir, "configs", "inference", "inference_v2.yaml")
        motion_module_path = os.path.join(base_dir, "pretrained_weights", "motion_module.pth")
        denoising_unet_path = os.path.join(base_dir, "pretrained_weights", "denoising_unet.pth")
        reference_unet_path = os.path.join(base_dir, "pretrained_weights", "reference_unet.pth")
        pose_guider_path = os.path.join(base_dir, "pretrained_weights", "pose_guider.pth")
        image_encoder_path = os.path.join(base_dir, "pretrained_weights", "image_encoder")

        infer_config = OmegaConf.load(inference_config_path)

        # 3D denoising UNet, inflated from the 2D weights plus the motion module
        denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_base_model_path_unet,
            motion_module_path,
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(device, dtype=self.weight_dtype)

        pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
            device, dtype=self.weight_dtype
        )
        image_enc = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(
            device, dtype=self.weight_dtype
        )

        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)

        # strict=False because the motion-module weights were already loaded above
        denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
        reference_unet.load_state_dict(torch.load(reference_unet_path, map_location="cpu"))
        pose_guider.load_state_dict(torch.load(pose_guider_path, map_location="cpu"))

        self.pipeline = Pose2VideoPipeline(
            vae=vae,
            image_encoder=image_enc,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            pose_guider=pose_guider,
            scheduler=scheduler,
        ).to(device, dtype=self.weight_dtype)

    def _crop_face(self, image, save_path="cropped_face.jpg", margin=0.5):
        """Detect the first face in `image` and save a margin-padded crop of it."""
        # Convert the PIL image to OpenCV's BGR layout
        cv_image = np.array(image)
        cv_image = cv_image[:, :, ::-1].copy()

        # Haar-cascade face detector shipped with OpenCV
        face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        )
        gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)
        if len(faces) == 0:
            raise ValueError("No faces detected in the reference image.")

        # Crop the first detected face, padded by `margin` on each side and
        # clamped to the image bounds
        x, y, w, h = faces[0]
        x_margin = int(margin * w)
        y_margin = int(margin * h)
        x1 = max(0, x - x_margin)
        y1 = max(0, y - y_margin)
        x2 = min(cv_image.shape[1], x + w + x_margin)
        y2 = min(cv_image.shape[0], y + h + y_margin)
        cropped_face = cv_image[y1:y2, x1:x2]

        # Convert back to PIL (RGB) and save
        cropped_face = Image.fromarray(cropped_face[:, :, ::-1]).convert("RGB")
        cropped_face.save(save_path, format="JPEG", quality=95)
        return cropped_face

    def _swap_face(self, source_image, target_video_path):
        """Swap `source_image`'s face onto every frame of the target video with roop."""
        source_path = "input.jpg"
        source_image.save(source_path, format="JPEG", quality=95)
        output_path = "output.mp4"

        # roop is configured through module-level globals rather than arguments
        roop.globals.source_path = source_path
        roop.globals.target_path = target_video_path
        roop.globals.output_path = normalize_output_path(
            roop.globals.source_path, roop.globals.target_path, output_path
        )
        roop.globals.frame_processors = ["face_swapper", "face_enhancer"]
        roop.globals.headless = True
        roop.globals.keep_fps = True
        roop.globals.keep_audio = True
        roop.globals.keep_frames = False
        roop.globals.many_faces = False
        roop.globals.video_encoder = "libx264"
        roop.globals.video_quality = 50
        roop.globals.max_memory = suggest_max_memory()
        roop.globals.execution_providers = decode_execution_providers(["cpu"])
        roop.globals.execution_threads = suggest_execution_threads()

        for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
            if not frame_processor.pre_check():
                raise ValueError("Frame processor pre-check failed.")

        start()
        return os.path.join(os.getcwd(), output_path)

    def __call__(self, data: Any) -> Dict[str, str]:
        inputs = data.get("inputs", {})
        ref_image_base64 = inputs.get("ref_image", "")
        pose_video_path = inputs.get("pose_video_path", "")
        width = inputs.get("width", 512)
        height = inputs.get("height", 768)
        length = inputs.get("length", 24)
        num_inference_steps = inputs.get("num_inference_steps", 25)
        cfg = inputs.get("cfg", 3.5)
        seed = inputs.get("seed", 123)

        # Decode the base64 reference image; convert to RGB in case the
        # payload carries an alpha channel
        ref_image = Image.open(BytesIO(base64.b64decode(ref_image_base64))).convert("RGB")

        torch.manual_seed(seed)

        # Read the driving pose video and keep at most `length` frames
        pose_images = read_frames(pose_video_path)
        src_fps = get_fps(pose_video_path)
        total_length = min(length, len(pose_images))
        pose_list = list(pose_images[:total_length])

        video = self.pipeline(
            ref_image,
            pose_list,
            width=width,
            height=height,
            video_length=total_length,
            num_inference_steps=num_inference_steps,
            guidance_scale=cfg,
        ).videos

        save_dir = "./output/gradio"
        os.makedirs(save_dir, exist_ok=True)
        animation_path = os.path.join(save_dir, "animation_output.mp4")
        save_videos_grid(video, animation_path, n_rows=1, fps=src_fps)

        # Crop the face from the reference image and swap it onto the animation
        cropped_face_path = os.path.join(save_dir, "cropped_face.jpg")
        cropped_face = self._crop_face(ref_image, save_path=cropped_face_path)
        final_video_path = self._swap_face(cropped_face, animation_path)

        # Encode the final video as base64 for the JSON response
        with open(final_video_path, "rb") as video_file:
            video_base64 = base64.b64encode(video_file.read()).decode("utf-8")

        torch.cuda.empty_cache()
        return {"video": video_base64}
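

# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch of how the handler is driven, mirroring
# the JSON payload an inference endpoint would send. The file names "ref.jpg",
# "pose.mp4", and "result.mp4" are placeholders, not part of the handler's
# contract; substitute your own reference image and pose video.
if __name__ == "__main__":
    with open("ref.jpg", "rb") as f:
        ref_b64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    result = handler({
        "inputs": {
            "ref_image": ref_b64,
            "pose_video_path": "pose.mp4",
            "width": 512,
            "height": 768,
            "length": 24,
            "num_inference_steps": 25,
            "cfg": 3.5,
            "seed": 123,
        }
    })

    # The handler returns the face-swapped animation as a base64-encoded MP4
    with open("result.mp4", "wb") as f:
        f.write(base64.b64decode(result["video"]))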