# video-animator / handler.py
from typing import Dict, Any
import torch
from PIL import Image
import base64
from io import BytesIO
import numpy as np
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from omegaconf import OmegaConf
from transformers import CLIPVisionModelWithProjection
import cv2
import os
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.util import read_frames, get_fps, save_videos_grid
import roop.globals
from roop.core import start, decode_execution_providers, suggest_max_memory, suggest_execution_threads
from roop.utilities import normalize_output_path
from roop.processors.frame.core import get_frame_processors_modules
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
raise ValueError("The model requires a GPU for inference.")
class EndpointHandler:
    def __init__(self, path=""):
        base_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(base_dir, 'configs', 'prompts', 'animation.yaml')
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"The configuration file was not found at: {config_path}")
        self.config = OmegaConf.load(config_path)
        self.weight_dtype = torch.float16
        self.pipeline = None
        self._initialize_pipeline()
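
    # Assembles the Pose2Video pipeline from locally bundled weights: the
    # sd-vae-ft-mse VAE, a 2D reference UNet, a 3D denoising UNet with a
    # motion module, a pose guider, a CLIP image encoder, and a DDIM
    # scheduler, all cast to float16 on the GPU.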
    def _initialize_pipeline(self):
        base_dir = os.path.dirname(os.path.abspath(__file__))
        vae_path = os.path.join(base_dir, 'pretrained_weights', 'sd-vae-ft-mse')
        if not os.path.exists(vae_path):
            raise FileNotFoundError(f"The sd-vae-ft-mse folder was not found at: {vae_path}")
        vae = AutoencoderKL.from_pretrained(vae_path).to(device, dtype=self.weight_dtype)
        pretrained_base_model_path_unet = os.path.join(base_dir, 'pretrained_weights', 'stable-diffusion-v1-5', 'unet')
        reference_unet = UNet2DConditionModel.from_pretrained(
            pretrained_base_model_path_unet,
            local_files_only=True
        ).to(device, dtype=self.weight_dtype)
        inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')
        motion_module_path = os.path.join(base_dir, 'pretrained_weights', 'motion_module.pth')
        denoising_unet_path = os.path.join(base_dir, 'pretrained_weights', 'denoising_unet.pth')
        reference_unet_path = os.path.join(base_dir, 'pretrained_weights', 'reference_unet.pth')
        pose_guider_path = os.path.join(base_dir, 'pretrained_weights', 'pose_guider.pth')
        image_encoder_path = os.path.join(base_dir, 'pretrained_weights', 'image_encoder')
        infer_config = OmegaConf.load(inference_config_path)
        denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_base_model_path_unet,
            motion_module_path,
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(device, dtype=self.weight_dtype)
        pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(device, dtype=self.weight_dtype)
        image_enc = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(device, dtype=self.weight_dtype)
        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)
        denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
        reference_unet.load_state_dict(torch.load(reference_unet_path, map_location="cpu"))
        pose_guider.load_state_dict(torch.load(pose_guider_path, map_location="cpu"))
        self.pipeline = Pose2VideoPipeline(
            vae=vae,
            image_encoder=image_enc,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            pose_guider=pose_guider,
            scheduler=scheduler
        ).to(device, dtype=self.weight_dtype)
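
    # Detects a face in the reference image with OpenCV's Haar cascade and
    # returns a crop around the first detection, expanded by `margin` on each
    # side and clamped to the image bounds.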
    def _crop_face(self, image, save_path="cropped_face.jpg", margin=0.5):
        # Convert image to OpenCV format
        cv_image = np.array(image)
        cv_image = cv_image[:, :, ::-1].copy()
        # Load OpenCV face detector
        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        # Detect faces
        gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)
        if len(faces) == 0:
            raise ValueError("No faces detected in the reference image.")
        # Crop the first face found with a margin
        x, y, w, h = faces[0]
        x_margin = int(margin * w)
        y_margin = int(margin * h)
        x1 = max(0, x - x_margin)
        y1 = max(0, y - y_margin)
        x2 = min(cv_image.shape[1], x + w + x_margin)
        y2 = min(cv_image.shape[0], y + h + y_margin)
        cropped_face = cv_image[y1:y2, x1:x2]
        # Convert back to PIL format
        cropped_face = Image.fromarray(cropped_face[:, :, ::-1]).convert("RGB")
        # Save the cropped face
        cropped_face.save(save_path, format="JPEG", quality=95)
        return cropped_face
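
    # Runs roop over the generated animation: the cropped reference face is
    # swapped onto every frame (face_swapper) and then enhanced (face_enhancer).
    # roop is configured through its module-level globals before calling start().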
    def _swap_face(self, source_image, target_video_path):
        source_path = "input.jpg"
        source_image.save(source_path, format="JPEG", quality=95)
        output_path = "output.mp4"
        roop.globals.source_path = source_path
        roop.globals.target_path = target_video_path
        roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, output_path)
        roop.globals.frame_processors = ["face_swapper", "face_enhancer"]
        roop.globals.headless = True
        roop.globals.keep_fps = True
        roop.globals.keep_audio = True
        roop.globals.keep_frames = False
        roop.globals.many_faces = False
        roop.globals.video_encoder = "libx264"
        roop.globals.video_quality = 50
        roop.globals.max_memory = suggest_max_memory()
        roop.globals.execution_providers = decode_execution_providers(["cpu"])
        roop.globals.execution_threads = suggest_execution_threads()
        for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
            if not frame_processor.pre_check():
                raise ValueError("Frame processor pre-check failed.")
        start()
        return os.path.join(os.getcwd(), output_path)
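
    # Entry point for inference requests. Expects a payload of the form
    # {"inputs": {"ref_image": <base64 image>, "pose_video_path": <path>, ...}}
    # and returns the final face-swapped animation as a base64-encoded MP4.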
    def __call__(self, data: Any) -> Dict[str, str]:
        inputs = data.get("inputs", {})
        ref_image_base64 = inputs.get("ref_image", "")
        pose_video_path = inputs.get("pose_video_path", "")
        width = inputs.get("width", 512)
        height = inputs.get("height", 768)
        length = inputs.get("length", 24)
        num_inference_steps = inputs.get("num_inference_steps", 25)
        cfg = inputs.get("cfg", 3.5)
        seed = inputs.get("seed", 123)
        # Force RGB so the pipeline and the face cropper always see 3 channels
        ref_image = Image.open(BytesIO(base64.b64decode(ref_image_base64))).convert("RGB")
        torch.manual_seed(seed)
        pose_images = read_frames(pose_video_path)
        src_fps = get_fps(pose_video_path)
        total_length = min(length, len(pose_images))
        pose_list = list(pose_images[:total_length])
        video = self.pipeline(
            ref_image,
            pose_list,
            width=width,
            height=height,
            video_length=total_length,
            num_inference_steps=num_inference_steps,
            guidance_scale=cfg
        ).videos
        save_dir = "./output/gradio"
        os.makedirs(save_dir, exist_ok=True)
        animation_path = os.path.join(save_dir, "animation_output.mp4")
        save_videos_grid(video, animation_path, n_rows=1, fps=src_fps)
        # Crop the face from the reference image and save it
        cropped_face_path = os.path.join(save_dir, "cropped_face.jpg")
        cropped_face = self._crop_face(ref_image, save_path=cropped_face_path)
        # Perform face swapping
        final_video_path = self._swap_face(cropped_face, animation_path)
        # Encode the final video in base64
        with open(final_video_path, "rb") as video_file:
            video_base64 = base64.b64encode(video_file.read()).decode("utf-8")
        torch.cuda.empty_cache()
        return {"video": video_base64}