|
|
|
|
|
|
|
|
from typing import Dict, Any |
|
|
import torch |
|
|
from PIL import Image |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
import numpy as np |
|
|
from diffusers import AutoencoderKL, DDIMScheduler |
|
|
from einops import repeat |
|
|
from omegaconf import OmegaConf |
|
|
from transformers import CLIPVisionModelWithProjection |
|
|
import cv2 |
|
|
import os |
|
|
import sys |
|
|
import skvideo.io |
|
|
from src.models.pose_guider import PoseGuider |
|
|
from src.models.unet_2d_condition import UNet2DConditionModel |
|
|
from src.models.unet_3d import UNet3DConditionModel |
|
|
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline |
|
|
from src.utils.util import read_frames, get_fps, save_videos_grid |
|
|
|
|
|
|
|
|
import gc |
|
|
import subprocess |
|
|
|
|
|
import requests |
|
|
import tempfile |
|
|
|
|
|
from rembg import remove |
|
|
import onnxruntime as ort |
|
|
import shutil |
|
|
|
|
|
import firebase_admin |
|
|
from firebase_admin import credentials, storage, firestore |
|
|
import json |
|
|
|
|
|
# Resolve the compute device once at import time; this handler is GPU-only.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fail fast at import rather than erroring deep inside inference.
if device.type != 'cuda':
    raise ValueError("The model requires a GPU for inference.")
|
|
|
|
|
class EndpointHandler():
    """Inference endpoint handler: animates a reference image with a driving
    pose video via a Pose2Video diffusion pipeline and uploads the result to
    Firebase Storage."""

    def __init__(self, path=""):
        """Load configuration, initialize Firebase, and build the pipeline.

        Args:
            path: Unused; kept for compatibility with the hosting framework's
                handler contract.

        Raises:
            FileNotFoundError: If the animation config file is missing.
            ValueError: If the FIREBASE_ACCOUNT_INFO env var is not set.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(base_dir, 'configs', 'prompts', 'animation.yaml')

        if not os.path.exists(config_path):
            raise FileNotFoundError(f"The configuration file was not found at: {config_path}")

        service_account_info = os.getenv("FIREBASE_ACCOUNT_INFO")
        if not service_account_info:
            # BUG FIX: the message previously named FIREBASE_SERVICE_ACCOUNT,
            # which is not the variable actually read above.
            raise ValueError("The FIREBASE_ACCOUNT_INFO environment variable is not set.")

        # BUG FIX: the original `.replace('/\\n/g', '\n')` was a JavaScript
        # regex literal pasted into Python and could never match. Convert
        # escaped "\n" sequences in the credential JSON back to real newlines.
        service_account_info = service_account_info.replace('\\n', '\n')

        service_account_info_dict = json.loads(service_account_info)

        # Guard against "app already exists" errors when the handler is
        # constructed more than once in the same process.
        if not firebase_admin._apps:
            cred = credentials.Certificate(service_account_info_dict)
            firebase_admin.initialize_app(cred, {
                'storageBucket': 'quiz-app-edffe.appspot.com'
            })

        self.config = OmegaConf.load(config_path)
        self.weight_dtype = torch.float16  # fp16 everywhere to fit GPU memory
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.pipeline = None
        self._initialize_pipeline()
|
|
|
|
|
    def _initialize_pipeline(self):
        """Construct the Pose2Video pipeline from local pretrained weights.

        Loads the VAE, reference (2D) and denoising (3D) UNets, pose guider,
        CLIP image encoder, and DDIM scheduler; moves everything to
        ``self.device`` in ``self.weight_dtype``; stores the assembled
        pipeline in ``self.pipeline``.

        Raises:
            FileNotFoundError: If the sd-vae-ft-mse weights folder is missing.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        # NOTE: despite the name, this is the VAE weights *directory*, not a config file.
        config_path = os.path.join(base_dir, 'pretrained_weights', 'sd-vae-ft-mse')

        if not os.path.exists(config_path):
            raise FileNotFoundError(f"The sd-vae-ft-mse folder was not found at: {config_path}")

        vae = AutoencoderKL.from_pretrained(config_path).to(self.device, dtype=self.weight_dtype)

        # 2D UNet used as the reference network (conditions on the still image).
        pretrained_base_model_path_unet = os.path.join(base_dir, 'pretrained_weights', 'stable-diffusion-v1-5', 'unet')
        print("model path is " + pretrained_base_model_path_unet)
        reference_unet = UNet2DConditionModel.from_pretrained(
            pretrained_base_model_path_unet
        ).to(dtype=self.weight_dtype, device=self.device)

        # Checkpoints and configs shipped alongside this file.
        inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')
        motion_module_path = os.path.join(base_dir, 'pretrained_weights', 'motion_module.pth')
        denoising_unet_path = os.path.join(base_dir, 'pretrained_weights', 'denoising_unet.pth')
        reference_unet_path = os.path.join(base_dir, 'pretrained_weights', 'reference_unet.pth')
        pose_guider_path = os.path.join(base_dir, 'pretrained_weights', 'pose_guider.pth')
        image_encoder_path = os.path.join(base_dir, 'pretrained_weights', 'image_encoder')

        infer_config = OmegaConf.load(inference_config_path)
        # 3D UNet: the 2D base weights inflated with the temporal motion module.
        denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_base_model_path_unet,
            motion_module_path,
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(self.device, dtype=self.weight_dtype)

        pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(self.device, dtype=self.weight_dtype)
        image_enc = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(self.device, dtype=self.weight_dtype)
        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)

        # strict=False -- presumably because the 3D UNet carries motion-module
        # keys absent from the fine-tuned checkpoint; TODO confirm against the
        # checkpoint's key set.
        denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
        reference_unet.load_state_dict(torch.load(reference_unet_path, map_location="cpu"))
        pose_guider.load_state_dict(torch.load(pose_guider_path, map_location="cpu"))

        self.pipeline = Pose2VideoPipeline(
            vae=vae,
            image_encoder=image_enc,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            pose_guider=pose_guider,
            scheduler=scheduler
        ).to(self.device, dtype=self.weight_dtype)
|
|
|
|
|
def _crop_face(self, image, save_path="cropped_face.jpg", margin=0.5): |
|
|
|
|
|
cv_image = np.array(image) |
|
|
cv_image = cv_image[:, :, ::-1].copy() |
|
|
|
|
|
|
|
|
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml') |
|
|
|
|
|
|
|
|
gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY) |
|
|
faces = face_cascade.detectMultiScale(gray, 1.1, 4) |
|
|
|
|
|
if len(faces) == 0: |
|
|
raise ValueError("No faces detected in the reference image.") |
|
|
|
|
|
|
|
|
x, y, w, h = faces[0] |
|
|
x_margin = int(margin * w) |
|
|
y_margin = int(margin * h) |
|
|
|
|
|
x1 = max(0, x - x_margin) |
|
|
y1 = max(0, y - y_margin // 2) |
|
|
x2 = min(cv_image.shape[1], x + w + x_margin) |
|
|
y2 = min(cv_image.shape[0], y + h + y_margin) |
|
|
|
|
|
cropped_face = cv_image[y1:y2, x1:x2] |
|
|
|
|
|
|
|
|
cropped_face = Image.fromarray(cropped_face[:, :, ::-1]).convert("RGB") |
|
|
|
|
|
|
|
|
cropped_face.save(save_path, format="JPEG", quality=95) |
|
|
|
|
|
return cropped_face |
|
|
|
|
|
    def _swap_face(self, source_path, target_video_path, output_path):
        """Swap the face from ``source_path`` onto ``target_video_path`` using
        the roop frame processors, writing the result to ``output_path``.

        NOTE(review): ``roop``, ``suggest_max_memory``,
        ``decode_execution_providers``, ``suggest_execution_threads``,
        ``get_frame_processors_modules`` and ``start`` are never imported in
        this file, so calling this method as-is raises NameError. It also
        appears unused -- ``__call__`` shells out to facefusion instead.
        Confirm the missing imports before relying on this method.

        Returns:
            ``output_path`` joined onto the current working directory.
        """
        # roop is configured through module-level globals (its CLI-style API).
        roop.globals.source_path = source_path
        roop.globals.target_path = target_video_path
        roop.globals.output_path = output_path
        roop.globals.frame_processors = ["face_swapper", "face_enhancer"]
        roop.globals.headless = True
        roop.globals.keep_fps = True
        roop.globals.keep_audio = True
        roop.globals.keep_frames = False
        roop.globals.many_faces = False
        # Lower output quality to keep encode time and file size down.
        roop.globals.video_quality = 50
        roop.globals.max_memory = suggest_max_memory()

        # Run the swap on GPU via onnxruntime's CUDA execution provider.
        roop.globals.execution_providers = decode_execution_providers(["CUDAExecutionProvider"])
        roop.globals.execution_threads = suggest_execution_threads()

        ort.set_default_logger_severity(3)  # suppress onnxruntime info/verbose logs
        providers = ['CUDAExecutionProvider']
        options = ort.SessionOptions()
        options.intra_op_num_threads = 1

        # Pin every frame processor's ONNX session to the CUDA provider.
        for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
            if hasattr(frame_processor, 'onnx_session'):
                frame_processor.onnx_session.set_providers(providers, options)

        torch.cuda.empty_cache()

        start()  # roop entry point: processes the whole target video

        # NOTE(review): this loop only unbinds the loop variable (a string from
        # frame_processors) each iteration -- it does not free the processor
        # modules or any GPU memory.
        for frame_processor in roop.globals.frame_processors:
            del frame_processor
        torch.cuda.empty_cache()

        # NOTE(review): os.path.join ignores getcwd() if output_path is absolute.
        return os.path.join(os.getcwd(), output_path)
|
|
|
|
|
def print_memory_stat_for_stuff(self, phase, log_file="memory_stats.log"): |
|
|
with open(log_file, "a") as f: |
|
|
f.write(f"Memory Stats - {phase}:\n") |
|
|
f.write(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB\n") |
|
|
f.write(f"Reserved memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB\n") |
|
|
f.write(f"Max allocated memory: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB\n") |
|
|
f.write(f"Max reserved memory: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB\n") |
|
|
f.write("="*30 + "\n") |
|
|
|
|
|
def convert_to_playable_format(self, input_path, output_path): |
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file: |
|
|
temp_output_path = tmp_file.name |
|
|
|
|
|
command = f"ffmpeg -i {input_path} -c:v libx264 -preset fast -crf 18 -y {temp_output_path}" |
|
|
|
|
|
|
|
|
result = subprocess.run(command, shell=True, capture_output=True, text=True) |
|
|
print("Conversion STDOUT:", result.stdout) |
|
|
print("Conversion STDERR:", result.stderr) |
|
|
|
|
|
if result.returncode != 0: |
|
|
raise RuntimeError(f"FFmpeg conversion failed with exit code {result.returncode}") |
|
|
|
|
|
shutil.move(temp_output_path, output_path) |
|
|
|
|
|
def run_rife_interpolation(self, video_path, output_path, multi=2, scale=1.0): |
|
|
base_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
directory = os.path.join(base_dir, "Practical-RIFE", "inference_video.py") |
|
|
model_directory = os.path.join(base_dir, "Practical-RIFE", "train_log") |
|
|
command = f"python3 {directory} --video={video_path} --output={output_path} --multi={multi} --scale={scale} --model={model_directory}" |
|
|
|
|
|
|
|
|
result = subprocess.run(command, shell=True, capture_output=True, text=True) |
|
|
print(result) |
|
|
print(result.stdout) |
|
|
print(result.stderr) |
|
|
|
|
|
if result.returncode != 0: |
|
|
raise RuntimeError(f"RIFE interpolation failed with exit code {result.returncode}") |
|
|
|
|
|
|
|
|
self.convert_to_playable_format(output_path, output_path) |
|
|
|
|
|
def speed_up_video(self, input_path, output_path, factor=4): |
|
|
command = f"ffmpeg -i {input_path} -filter:v setpts=PTS/{factor} -an {output_path}" |
|
|
|
|
|
|
|
|
result = subprocess.run(command, shell=True, capture_output=True, text=True) |
|
|
print("Speed Up Video STDOUT:", result.stdout) |
|
|
print("Speed Up Video STDERR:", result.stderr) |
|
|
|
|
|
if result.returncode != 0: |
|
|
raise RuntimeError(f"FFmpeg speed up failed with exit code {result.returncode}") |
|
|
|
|
|
def slow_down_video(self, input_path, output_path, factor=4): |
|
|
command = f"ffmpeg -i {input_path} -filter:v setpts={factor}*PTS -an {output_path}" |
|
|
|
|
|
|
|
|
result = subprocess.run(command, shell=True, capture_output=True, text=True) |
|
|
print("Slow Down Video STDOUT:", result.stdout) |
|
|
print("Slow Down Video STDERR:", result.stderr) |
|
|
|
|
|
if result.returncode != 0: |
|
|
raise RuntimeError(f"FFmpeg slow down failed with exit code {result.returncode}") |
|
|
|
|
|
def download_file(self, url: str, save_path: str): |
|
|
response = requests.get(url, stream=True) |
|
|
if response.status_code == 200: |
|
|
with open(save_path, 'wb') as f: |
|
|
for chunk in response.iter_content(chunk_size=8192): |
|
|
f.write(chunk) |
|
|
else: |
|
|
raise ValueError(f"Failed to download file from {url}") |
|
|
|
|
|
    def print_directory_contents(self, directory):
        """Recursively print the contents of ``directory`` as an indented tree.

        NOTE(review): dead code -- this method is silently shadowed by the
        second ``print_directory_contents`` definition later in this class
        (Python keeps only the last definition). One of the two should be
        removed.
        """
        for root, dirs, files in os.walk(directory):
            # Depth below the starting directory, via separator count.
            # (str.replace replaces *every* occurrence of `directory` inside
            # root; os.path.relpath would be more robust.)
            level = root.replace(directory, '').count(os.sep)
            indent = ' ' * 4 * (level)
            print(f"{indent}{os.path.basename(root)}/")
            subindent = ' ' * 4 * (level + 1)
            for f in files:
                print(f"{subindent}{f}")
|
|
|
|
|
def print_directory_contents(self, path='.'): |
|
|
for root, dirs, files in os.walk(path): |
|
|
level = root.replace(path, '').count(os.sep) |
|
|
indent = ' ' * 4 * level |
|
|
print(f'{indent}{os.path.basename(root)}/') |
|
|
sub_indent = ' ' * 4 * (level + 1) |
|
|
for f in files: |
|
|
print(f'{sub_indent}{f}') |
|
|
|
|
|
    def __call__(self, data: Any) -> Dict[str, str]:
        """Endpoint entry point: animate a reference image with a pose video.

        Expects ``data["inputs"]`` with ``ref_image_url``, ``video_url``,
        ``firebase_doc_id`` and optional generation parameters. Generates the
        animation, face-swaps the reference face back onto it via facefusion,
        uploads the result to Firebase Storage (writing the public URL into
        the ``danceResults`` Firestore document), and returns the video as
        base64.

        Returns:
            ``{"video": <base64-encoded mp4 bytes>}``
        """
        inputs = data.get("inputs", {})
        ref_image_url = inputs.get("ref_image_url", "")
        video_url = inputs.get("video_url", "")
        # NOTE(review): these width/height inputs are dead -- both are always
        # overwritten below from the reference image's dimensions.
        width = inputs.get("width", 512)
        height = inputs.get("height", 768)
        length = inputs.get("length", 96)  # max number of pose frames to animate
        num_inference_steps = inputs.get("num_inference_steps", 15)
        cfg = inputs.get("cfg", 3.5)  # classifier-free guidance scale
        seed = inputs.get("seed", -1)
        firebase_doc_id = inputs.get("firebase_doc_id", "")

        base_dir = os.path.dirname(os.path.abspath(__file__))

        # All intermediate artifacts live in a temp dir removed on exit.
        with tempfile.TemporaryDirectory() as temp_dir:
            print(f"Temporary directory created at {temp_dir}")
            video_root = os.path.join(temp_dir, "dw_poses_videos")
            os.makedirs(video_root, exist_ok=True)
            downloaded_video_path = os.path.join(video_root, "downloaded_video.mp4")
            downloaded_image_path = os.path.join(video_root, "downloaded_image.jpg")

            # Fetch the driving pose video and the reference image.
            self.download_file(video_url, downloaded_video_path)
            self.download_file(ref_image_url, downloaded_image_path)
            ref_image = Image.open(downloaded_image_path)

            # Cap the longest side at 600 px, preserving aspect ratio;
            # this overwrites the width/height request parameters.
            original_width, original_height = ref_image.size
            max_dimension = max(original_width, original_height)
            if max_dimension > 600:
                ratio = max_dimension / 600
                width = int(original_width / ratio)
                height = int(original_height / ratio)
            else:
                width = original_width
                height = original_height

            # Strip the background from the reference image (rembg).
            ref_image_no_bg = remove(ref_image)
            ref_image_no_bg_path = os.path.join(video_root, "ref_image_no_bg.png")
            ref_image_no_bg.save(ref_image_no_bg_path)

            # NOTE(review): seed defaults to -1; torch remaps negative seeds
            # rather than randomizing -- confirm that is the intent.
            torch.manual_seed(seed)

            pose_images = read_frames(downloaded_video_path)
            src_fps = get_fps(downloaded_video_path)

            # Truncate the pose sequence to at most `length` frames.
            pose_list = []
            total_length = min(length, len(pose_images))
            for pose_image_pil in pose_images[:total_length]:
                pose_list.append(pose_image_pil)

            # Run the diffusion pipeline: reference image + pose frames -> video tensor.
            video = self.pipeline(
                ref_image_no_bg,
                pose_list,
                width=width,
                height=height,
                video_length=total_length,
                num_inference_steps=num_inference_steps,
                guidance_scale=cfg
            ).videos

            save_dir = os.path.join(temp_dir, "output")
            if not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)
            animation_path = os.path.join(save_dir, "animation_output.mp4")
            save_videos_grid(video, animation_path, n_rows=1, fps=src_fps)

            # Crop the (background-free) reference face to use as swap source.
            cropped_face_path = os.path.join(save_dir, "cropped_face.jpg")
            cropped_face = self._crop_face(ref_image_no_bg, save_path=cropped_face_path)

            # Free cached GPU memory before spawning the face-swap process.
            torch.cuda.empty_cache()

            # Swap the reference face onto the generated animation via facefusion.
            # NOTE(review): shell=True with unquoted interpolated paths -- would
            # break on paths containing spaces; consider list-form argv.
            swapped_face_video_path = os.path.join(save_dir, "swapped_face_output.mp4")
            facefusion_script_path = os.path.join(base_dir, 'facefusion', 'core.py')
            swap_command = f'python3 {facefusion_script_path} --source {cropped_face_path} --target {animation_path} --output {swapped_face_video_path}'
            swap_result = subprocess.run(swap_command, shell=True, capture_output=True, text=True)
            if swap_result.returncode != 0:
                raise RuntimeError(f"Error running face swap: {swap_result.stderr}")

            torch.cuda.empty_cache()

            # Encode the final video for the JSON response.
            with open(swapped_face_video_path, "rb") as video_file:
                video_base64 = base64.b64encode(video_file.read()).decode("utf-8")

            # Upload to Firebase Storage and make the blob publicly readable.
            bucket = storage.bucket()
            blob = bucket.blob(f"videos/{firebase_doc_id}/swapped_face_output.mp4")
            blob.upload_from_filename(swapped_face_video_path)
            blob.make_public()

            # Reuses the `video_url` input variable to hold the public URL.
            video_url = blob.public_url

            # Record the result URL on the caller's Firestore document.
            db = firestore.client()
            doc_ref = db.collection('danceResults').document(firebase_doc_id)
            doc_ref.update({"videoResultUrl": video_url})

            return {"video": video_base64}
|
|
|