Spaces:
Running
Running
| import gc | |
| import copy | |
| import cv2 | |
| import datetime | |
| import os | |
| import time | |
| from contextlib import contextmanager | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| from einops import repeat | |
| from PIL import Image, ImageFilter | |
# Prefix stamped at the start of every log line produced by log_event.
LOG_PREFIX = "[DiffuEraser]"
# request_id -> path of that request's log file. Filled by
# set_request_log_file, emptied by clear_request_log_file, and read by
# _append_request_log.
REQUEST_LOG_FILES = dict()
def set_request_log_file(request_id, log_path):
    """Register *log_path* as the log file that mirrors lines for *request_id*.

    Falsy ids or paths are ignored, so optional values can be passed through.
    """
    if not (request_id and log_path):
        return
    REQUEST_LOG_FILES[request_id] = log_path
def clear_request_log_file(request_id):
    """Forget the log file registered for *request_id*; unknown ids are ignored."""
    if not request_id:
        return
    REQUEST_LOG_FILES.pop(request_id, None)
def _append_request_log(request_id, text):
    """Append *text* plus a newline to the file registered for *request_id*.

    Does nothing when no file is registered. Any error raised while writing is
    swallowed on purpose: per-request logging is best-effort and must never
    abort the request it is describing.
    """
    target = REQUEST_LOG_FILES.get(request_id)
    if target:
        try:
            with open(target, "a", encoding="utf-8") as handle:
                handle.write(text + "\n")
        except Exception:
            # Deliberate best-effort: a logging failure is not a request failure.
            pass
| def _format_log_value(value): | |
| if value is None: | |
| return "none" | |
| value = str(value) | |
| if not value: | |
| return "empty" | |
| if any(ch.isspace() for ch in value) or len(value) > 80: | |
| value = value.replace("\n", "\\n") | |
| return repr(value) | |
| return value | |
def log_event(stage, message="", request_id=None, **fields):
    """Print one structured log line and mirror it to the request's log file.

    Line layout (bracketed parts are optional)::

        [DiffuEraser] <iso-timestamp>[ request_id=<id>] stage=<stage>[ <message>][ k=v ...]

    Keyword *fields* whose value is ``None`` are omitted; the rest are rendered
    with ``_format_log_value``.
    """
    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
    pieces = [LOG_PREFIX, timestamp]
    if request_id:
        pieces.append(f"request_id={request_id}")
    pieces.append(f"stage={stage}")
    if message:
        pieces.append(f"{message}")
    pieces.extend(
        f"{name}={_format_log_value(val)}"
        for name, val in fields.items()
        if val is not None
    )
    line = " ".join(pieces)
    print(line, flush=True)
    _append_request_log(request_id, line)
@contextmanager
def timed_stage(stage, request_id=None, **fields):
    """Context manager that logs the start, end and any error of a stage.

    Usage::

        with timed_stage("diffueraser.read_video", request_id=rid, input=path):
            ...

    On entry, logs ``start`` together with *fields*. If the body raises, logs
    ``error`` with the exception type and message and re-raises it. In all
    cases, logs ``end`` with the elapsed wall-clock seconds.
    """
    # The @contextmanager decorator was missing: every caller uses this in a
    # ``with`` statement, which a bare generator object does not support.
    start = time.perf_counter()
    log_event(stage, "start", request_id=request_id, **fields)
    try:
        yield
    except Exception as exc:
        log_event(stage, "error", request_id=request_id, error_type=type(exc).__name__, error=str(exc))
        raise
    finally:
        elapsed = time.perf_counter() - start
        log_event(stage, "end", request_id=request_id, elapsed_sec=f"{elapsed:.2f}")
def log_cuda_memory(label, request_id=None):
    """Log a snapshot of CUDA memory usage on the current device under *label*.

    Sizes are reported in GiB with two decimals. Without CUDA, a single
    ``cuda_available=False`` record is logged. Any error raised while querying
    the device is logged as ``unavailable`` instead of propagating.
    """
    gib = 1024 ** 3
    try:
        if torch.cuda.is_available():
            index = torch.cuda.current_device()
            props = torch.cuda.get_device_properties(index)
            byte_counts = {
                "total_gb": props.total_memory,
                "allocated_gb": torch.cuda.memory_allocated(index),
                "reserved_gb": torch.cuda.memory_reserved(index),
                "max_allocated_gb": torch.cuda.max_memory_allocated(index),
            }
            log_event(
                "memory.cuda",
                label,
                request_id=request_id,
                cuda_available=True,
                device_index=index,
                device_name=props.name,
                **{key: f"{count / gib:.2f}" for key, count in byte_counts.items()},
            )
        else:
            log_event("memory.cuda", label, request_id=request_id, cuda_available=False)
    except Exception as exc:
        log_event("memory.cuda", "unavailable", request_id=request_id, error_type=type(exc).__name__, error=str(exc))
| from diffusers import ( | |
| AutoencoderKL, | |
| DDPMScheduler, | |
| DDIMScheduler, | |
| DPMSolverMultistepScheduler, | |
| EulerDiscreteScheduler, | |
| UniPCMultistepScheduler, | |
| LCMScheduler, | |
| ) | |
| from diffusers.schedulers import TCDScheduler | |
| from diffusers.image_processor import PipelineImageInput, VaeImageProcessor | |
| from diffusers.utils.torch_utils import randn_tensor | |
| from transformers import AutoTokenizer, PretrainedConfig | |
| from libs.unet_motion_model import MotionAdapter, UNetMotionModel | |
| from libs.brushnet_CA import BrushNetModel | |
| from libs.unet_2d_condition import UNet2DConditionModel | |
| from diffueraser.pipeline_diffueraser import StableDiffusionDiffuEraserPipeline | |
# PCM LoRA presets selectable by name. Each value is
# [weight-file template, num_inference_steps, guidance_scale]; the "{}" in the
# template is filled with the model mode (e.g. "sd15") by set_checkpoint.
checkpoints = {
    **{
        f"{steps}-Step": [f"pcm_{{}}_smallcfg_{steps}step_converted.safetensors", steps, 0.0]
        for steps in (2, 4, 8, 16)
    },
    **{
        f"Normal CFG {steps}-Step": [f"pcm_{{}}_normalcfg_{steps}step_converted.safetensors", steps, 7.5]
        for steps in (4, 8, 16)
    },
    "LCM-Like LoRA": ["pcm_{}_lcmlike_lora_converted.safetensors", 4, 0.0],
}
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named in the model's ``text_encoder`` config.

    Reads ``architectures[0]`` from the ``text_encoder`` subfolder config and
    maps it to a class. Only ``CLIPTextModel`` and
    ``RobertaSeriesModelWithTransformation`` are supported.

    Raises:
        ValueError: if the declared architecture is not supported.
    """
    encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    architecture = encoder_config.architectures[0]

    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel

    if architecture == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
            RobertaSeriesModelWithTransformation,
        )
        return RobertaSeriesModelWithTransformation

    raise ValueError(f"{architecture} is not supported.")
def resize_frames(frames, size=None):
    """Snap frame dimensions down to multiples of 8.

    With *size*, every frame is resized to *size* with each side rounded down
    to a multiple of 8 (always resized, even if already that size). Without
    *size*, the first frame's size is the reference, and frames are resized
    only when that size is not already a multiple of 8 on both sides;
    otherwise the same list object is returned.
    """
    reference = frames[0].size if size is None else size
    width, height = reference[0], reference[1]
    snapped = (width - width % 8, height - height % 8)
    if size is not None or reference != snapped:
        frames = [frame.resize(snapped) for frame in frames]
    return frames
| def _odd_kernel_size(value): | |
| value = max(0, int(value)) | |
| if value <= 0: | |
| return 0 | |
| return value * 2 + 1 | |
def refine_mask_array(mask, mask_refine_mode="Keep", mask_refine_iterations=0, mask_feather_px=0, mask_dilation_iter=0):
    """Binarize, morphologically refine, dilate and optionally feather a mask.

    Steps, in order:
      1. Pixels > 0 become 255, all others 0 (uint8).
      2. ``mask_refine_mode`` "Erode" or "Dilate" applies that operation with
         a 3x3 rectangular kernel ``mask_refine_iterations`` times. Any other
         mode (e.g. "Keep") leaves the mask unchanged.
      3. An extra dilation of ``mask_dilation_iter`` iterations, same kernel.
      4. A Gaussian blur with kernel ``2 * mask_feather_px + 1`` when
         ``mask_feather_px > 0``.

    Negative counts are clamped to 0. Returns the refined uint8 array.
    """
    mode = str(mask_refine_mode)
    refine_iters = max(0, int(mask_refine_iterations))
    dilate_iters = max(0, int(mask_dilation_iter))
    feather = max(0, int(mask_feather_px))

    binary = np.where(np.asarray(mask) > 0, 255, 0).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))

    refine_ops = {"Erode": cv2.erode, "Dilate": cv2.dilate}
    refine_op = refine_ops.get(mode)
    if refine_op is not None and refine_iters > 0:
        binary = refine_op(binary, kernel, iterations=refine_iters)

    if dilate_iters > 0:
        binary = cv2.dilate(binary, kernel, iterations=dilate_iters)

    ksize = _odd_kernel_size(feather)
    if ksize > 0:
        binary = cv2.GaussianBlur(binary, (ksize, ksize), 0)
    return binary
def read_mask(
        validation_mask, fps, n_total_frames, img_size, mask_dilation_iter, frames,
        mask_refine_mode="Keep", mask_refine_iterations=0, mask_feather_px=0, request_id=None):
    """Read a mask video, refine each mask, and build the matching masked frames.

    Args:
        validation_mask: path of the mask video (read with OpenCV).
        fps: frame rate of the source video; the mask video must report exactly
            the same value.
        n_total_frames: maximum number of mask frames to read.
        img_size: (width, height) each mask is resized to (nearest neighbour)
            when it differs.
        mask_dilation_iter, mask_refine_mode, mask_refine_iterations,
        mask_feather_px: forwarded to ``refine_mask_array``.
        frames: source frames indexed in step with the masks; presumably
            3-channel images of size *img_size* (TODO confirm for all callers).
        request_id: optional id attached to log lines.

    Returns:
        (masks, masked_images): PIL masks built from the refined arrays, and
        PIL frames multiplied by ``1 - mask / 255`` (masked area darkened).

    Raises:
        ValueError: if the mask video cannot be opened, or its fps differs
            from *fps*.
    """
    cap = cv2.VideoCapture(validation_mask)
    masks = []
    masked_images = []
    try:
        if not cap.isOpened():
            # Previously printed and called exit(), which terminates the whole
            # host process (e.g. a serving app). Raise so callers can handle it.
            raise ValueError(f"Could not open mask video: {validation_mask}")
        mask_fps = cap.get(cv2.CAP_PROP_FPS)
        if mask_fps != fps:
            raise ValueError("The frame rate of all input videos needs to be consistent.")

        idx = 0
        while idx < n_total_frames:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes BGR; reverse channels to RGB before grayscale.
            mask = Image.fromarray(frame[..., ::-1]).convert("L")
            if mask.size != img_size:
                mask = mask.resize(img_size, Image.NEAREST)
            refined = refine_mask_array(
                np.asarray(mask),
                mask_refine_mode=mask_refine_mode,
                mask_refine_iterations=mask_refine_iterations,
                mask_feather_px=mask_feather_px,
                mask_dilation_iter=mask_dilation_iter,
            )
            masks.append(Image.fromarray(refined))
            keep = 1 - refined[:, :, np.newaxis].astype(np.float32) / 255
            masked_image = np.array(frames[idx]) * keep
            masked_images.append(Image.fromarray(masked_image.astype(np.uint8)))
            idx += 1
    finally:
        # Release the capture even when a frame fails to process.
        cap.release()

    log_event(
        "diffueraser.read_mask",
        "mask refinement applied",
        request_id=request_id,
        mask_refine_mode=mask_refine_mode,
        mask_refine_iterations=mask_refine_iterations,
        mask_feather_px=mask_feather_px,
        mask_dilation_iter=mask_dilation_iter,
        masks=len(masks),
    )
    return masks, masked_images
def read_priori(priori, fps, n_total_frames, img_size):
    """Read up to *n_total_frames* frames of the priori video as PIL images.

    Frames are converted from OpenCV's BGR to RGB and resized to *img_size*
    when their size differs.

    Raises:
        ValueError: if the video cannot be opened, or its fps differs from *fps*.
    """
    cap = cv2.VideoCapture(priori)
    prioris = []
    try:
        if not cap.isOpened():
            # Previously printed and called exit(), killing the host process.
            raise ValueError(f"Could not open priori video: {priori}")
        priori_fps = cap.get(cv2.CAP_PROP_FPS)
        if priori_fps != fps:
            raise ValueError("The frame rate of all input videos needs to be consistent.")

        while len(prioris) < n_total_frames:
            ret, frame = cap.read()
            if not ret:
                break
            img = Image.fromarray(frame[..., ::-1])
            if img.size != img_size:
                img = img.resize(img_size)
            prioris.append(img)
    finally:
        # Always release the capture, including on the error paths above.
        cap.release()
    return prioris
def read_video(validation_image, video_length, nframes, max_img_size):
    """Decode the first *video_length* seconds of a video into RGB PIL frames.

    Frames are resized so the longer side is at most *max_img_size*, and both
    sides are rounded down to multiples of 8.

    Args:
        validation_image: path of the input video.
        video_length: number of seconds to read from the start.
        nframes: clip length, used to compute the number of clips.
        max_img_size: upper bound for the longer side after resizing.

    Returns:
        (frames, fps, img_size, n_clip, n_total_frames):
        - img_size is the (width, height) of the returned frames;
        - n_total_frames is ``int(video_length * fps)``;
        - n_clip is ``ceil(n_total_frames / nframes)``.

    Raises:
        ValueError: if no frames can be decoded, or if the longer side of the
            source is below 256 or above 4096 pixels.
    """
    vframes, _audio, info = torchvision.io.read_video(
        filename=validation_image, pts_unit="sec", end_pts=video_length
    )  # RGB
    # Guard against files with no decodable video stream (no fps reported),
    # which previously surfaced as a bare KeyError / IndexError.
    fps = info.get("video_fps")
    if not fps:
        raise ValueError(f"No video frames could be read from {validation_image}.")
    n_total_frames = int(video_length * fps)
    n_clip = int(np.ceil(n_total_frames / nframes))
    frames = [Image.fromarray(f) for f in list(vframes.numpy())[:n_total_frames]]
    if not frames:
        raise ValueError(f"No video frames could be read from {validation_image}.")

    width, height = frames[0].size
    max_size = max(width, height)
    if max_size < 256:
        raise ValueError("The resolution of the uploaded video must be larger than 256x256.")
    if max_size > 4096:
        raise ValueError("The resolution of the uploaded video must be smaller than 4096x4096.")

    if max_size > max_img_size:
        # Downscale so the longer side fits max_img_size, then snap to /8.
        ratio = max_size / max_img_size
        scaled_w, scaled_h = int(width / ratio), int(height / ratio)
        img_size = (scaled_w - scaled_w % 8, scaled_h - scaled_h % 8)
    else:
        img_size = (width - width % 8, height - height % 8)

    if img_size != (width, height):
        frames = resize_frames(frames, img_size)
        img_size = frames[0].size
    return frames, fps, img_size, n_clip, n_total_frames
class DiffuEraser:
    """Video inpainting wrapper around ``StableDiffusionDiffuEraserPipeline``.

    Construction loads the VAE, text encoder, BrushNet and motion UNet, builds
    the pipeline in float16 on *device*, and applies a PCM LoRA preset from the
    module-level ``checkpoints`` table.

    :meth:`forward` then:
      1. reads the input video, a mask video and a priori video;
      2. runs an optional pre-inference pass on evenly sampled frames;
      3. runs frame inference;
      4. composites the result over the original frames and writes an mp4.
    """

    def __init__(
            self, device, base_model_path, vae_path, diffueraser_path, revision=None,
            ckpt="Normal CFG 4-Step", mode="sd15", loaded=None):
        """Load all model components and apply the initial LoRA checkpoint.

        Args:
            device: torch device (or device string) the pipeline is moved to.
            base_model_path: diffusers model directory providing the
                ``scheduler``, ``tokenizer`` and ``text_encoder`` subfolders.
            vae_path: location of the AutoencoderKL weights.
            diffueraser_path: directory with ``brushnet`` and ``unet_main``
                subfolders.
            revision: revision used when resolving the text-encoder class.
            ckpt: key into ``checkpoints`` selecting the PCM LoRA to load.
            mode: tag substituted into the LoRA filename and used as its
                subfolder (e.g. "sd15").
            loaded: unused; kept for interface compatibility.
        """
        self.device = device
        self.mode = mode
        self.current_ckpt = None
        self.current_scheduler = None

        ## load model
        self.vae = AutoencoderKL.from_pretrained(vae_path)
        # NOTE(review): this DDPM scheduler is overwritten by a UniPC scheduler
        # below before any use; kept so loading behaviour is unchanged.
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            base_model_path,
            subfolder="scheduler",
            prediction_type="v_prediction",
            timestep_spacing="trailing",
            rescale_betas_zero_snr=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_path,
            subfolder="tokenizer",
            use_fast=False,
        )
        text_encoder_cls = import_model_class_from_model_name_or_path(base_model_path, revision)
        self.text_encoder = text_encoder_cls.from_pretrained(
            base_model_path, subfolder="text_encoder"
        )
        self.brushnet = BrushNetModel.from_pretrained(diffueraser_path, subfolder="brushnet")
        self.unet_main = UNetMotionModel.from_pretrained(
            diffueraser_path, subfolder="unet_main",
        )

        ## set pipeline
        self.pipeline = StableDiffusionDiffuEraserPipeline.from_pretrained(
            base_model_path,
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet_main,
            brushnet=self.brushnet,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        ).to(self.device, torch.float16)
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        # Snapshot of the UniPC config. _build_scheduler derives the
        # UniPC/DDIM/Euler/DPM++ schedulers from it.
        self.scheduler_config = copy.deepcopy(self.pipeline.scheduler.config)
        self.pipeline.set_progress_bar_config(disable=True)
        # Used only by forward() to add noise to the priori latents.
        self.noise_scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)

        ## use PCM
        self.set_checkpoint(ckpt)

    def _resolve_scheduler_name(self, scheduler_name, ckpt=None):
        """Map "Auto" to LCM for the LCM-like LoRA and TCD otherwise; pass others through."""
        ckpt = ckpt or self.ckpt
        if scheduler_name == "Auto":
            return "LCM" if ckpt == "LCM-Like LoRA" else "TCD"
        return scheduler_name

    def _build_scheduler(self, scheduler_name, ckpt=None):
        """Instantiate the scheduler named by *scheduler_name* (after "Auto" resolution).

        Raises:
            ValueError: for an unsupported scheduler name.
        """
        resolved_scheduler = self._resolve_scheduler_name(scheduler_name, ckpt)
        if resolved_scheduler == "LCM":
            return LCMScheduler()
        if resolved_scheduler == "TCD":
            return TCDScheduler(
                num_train_timesteps=1000,
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                timestep_spacing="trailing",
            )
        if resolved_scheduler == "UniPC":
            return UniPCMultistepScheduler.from_config(self.scheduler_config)
        if resolved_scheduler == "DDIM":
            return DDIMScheduler.from_config(self.scheduler_config)
        if resolved_scheduler == "Euler":
            return EulerDiscreteScheduler.from_config(self.scheduler_config)
        if resolved_scheduler == "DPM++ 2M":
            return DPMSolverMultistepScheduler.from_config(self.scheduler_config)
        raise ValueError(f"Unsupported scheduler: {scheduler_name}")

    def set_scheduler(self, scheduler_name="Auto", request_id=None):
        """Install *scheduler_name* on the pipeline, skipping if already active."""
        resolved_scheduler = self._resolve_scheduler_name(scheduler_name, self.ckpt)
        scheduler_key = f"{scheduler_name}->{resolved_scheduler}"
        if scheduler_key == self.current_scheduler:
            log_event("diffueraser.scheduler", "unchanged", request_id=request_id, scheduler=scheduler_name, resolved_scheduler=resolved_scheduler)
            return
        with timed_stage("diffueraser.scheduler", request_id=request_id, scheduler=scheduler_name, resolved_scheduler=resolved_scheduler):
            self.pipeline.scheduler = self._build_scheduler(scheduler_name, self.ckpt)
        self.current_scheduler = scheduler_key
        log_event("diffueraser.scheduler", "set", request_id=request_id, scheduler=scheduler_name, resolved_scheduler=resolved_scheduler)

    def set_checkpoint(self, ckpt, scheduler_name="Auto", request_id=None):
        """Load the PCM LoRA named *ckpt* (if not already loaded), then set the scheduler.

        Also updates ``num_inference_steps`` and the default ``guidance_scale``
        from the ``checkpoints`` table.
        """
        if ckpt != self.current_ckpt:
            PCM_ckpts = checkpoints[ckpt][0].format(self.mode)
            with timed_stage("diffueraser.lora", request_id=request_id, ckpt=ckpt, weight=PCM_ckpts):
                if self.current_ckpt is not None:
                    log_event("diffueraser.lora", "unload previous", request_id=request_id, previous_ckpt=self.current_ckpt)
                    self.pipeline.unload_lora_weights()
                    # The previous LoRA is gone now. Clear the marker so that a
                    # failed load below cannot leave us claiming it is active.
                    self.current_ckpt = None
                self.pipeline.load_lora_weights(
                    "weights/PCM_Weights", weight_name=PCM_ckpts, subfolder=self.mode
                )
            self.ckpt = ckpt
            self.current_ckpt = ckpt
            self.num_inference_steps = checkpoints[ckpt][1]
            self.guidance_scale = checkpoints[ckpt][2]
            # Force set_scheduler to rebuild: "Auto" may resolve differently now.
            self.current_scheduler = None
            log_event(
                "diffueraser.lora",
                "loaded",
                request_id=request_id,
                ckpt=ckpt,
                num_inference_steps=self.num_inference_steps,
                checkpoint_guidance_scale=self.guidance_scale,
            )
        else:
            log_event("diffueraser.lora", "unchanged", request_id=request_id, ckpt=ckpt)
        self.set_scheduler(scheduler_name, request_id=request_id)

    def _decode_latents(self, latents, weight_dtype):
        """Decode VAE latents one frame at a time and return a float32 tensor."""
        latents = 1 / self.vae.config.scaling_factor * latents
        video = []
        for t in range(latents.shape[0]):
            video.append(self.vae.decode(latents[t:t + 1, ...].to(weight_dtype)).sample)
        video = torch.concat(video, dim=0)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        return video.float()

    def forward(self, validation_image, validation_mask, priori, output_path,
                max_img_size=1280, video_length=2, mask_dilation_iter=4,
                mask_refine_mode="Keep", mask_refine_iterations=0, mask_feather_px=0,
                nframes=22, seed=None, revision=None, guidance_scale=None, blended=True,
                prompt="", negative_prompt="", request_id=None, output_fps=None, progress_callback=None):
        """Inpaint the masked region of a video and write the result to *output_path*.

        Args:
            validation_image: path of the input video.
            validation_mask: path of the mask video (same fps as the input).
            priori: path of the priori video (same fps as the input).
            output_path: destination of the mp4 written with OpenCV.
            max_img_size: longer-side limit for processing; must be in [256, 1920].
            video_length: seconds of input to process.
            mask_dilation_iter, mask_refine_mode, mask_refine_iterations,
            mask_feather_px: mask refinement settings forwarded to ``read_mask``.
            nframes: frames per diffusion clip.
            seed: generator seed; ``None`` means unseeded.
            revision: unused.
            guidance_scale: overrides the checkpoint's default when not ``None``.
            blended: soften mask edges with a Gaussian blur when compositing.
            prompt, negative_prompt: text conditioning (empty negative -> None).
            request_id: optional id attached to log lines.
            output_fps: when positive and below the processed fps, output
                frames are evenly subsampled and written at this fps.
            progress_callback: optional ``callback(fraction, description)``;
                any exception it raises is ignored.

        Returns:
            *output_path*.

        Raises:
            ValueError: for an out-of-range *max_img_size*, unreadable or
                inconsistent inputs, or fewer than 22 usable frames.
            RuntimeError: if the output video writer cannot be opened.
        """
        validation_prompt = prompt or ""
        negative_prompt = negative_prompt or None
        guidance_scale_final = self.guidance_scale if guidance_scale is None else guidance_scale

        def _progress(local_value, desc):
            if progress_callback is None:
                return
            try:
                progress_callback(local_value, desc)
            except Exception:
                # Progress reporting is best-effort; never abort inference for it.
                pass

        def _pipeline_progress(start, end, desc):
            # Map the pipeline's per-step callback onto the [start, end] progress range.
            def callback(_pipe, step, _timestep, callback_kwargs):
                total = max(1, int(self.num_inference_steps))
                ratio = max(0.0, min(1.0, float(step + 1) / total))
                _progress(start + (end - start) * ratio, f"{desc} {step + 1}/{total}")
                return callback_kwargs
            return callback

        _progress(0.01, "DiffuEraser: reading inputs")
        log_event(
            "diffueraser.forward",
            "start",
            request_id=request_id,
            ckpt=self.ckpt,
            scheduler=self.current_scheduler,
            num_inference_steps=self.num_inference_steps,
            video_length=video_length,
            max_img_size=max_img_size,
            nframes=nframes,
            mask_dilation_iter=mask_dilation_iter,
            mask_refine_mode=mask_refine_mode,
            mask_refine_iterations=mask_refine_iterations,
            mask_feather_px=mask_feather_px,
            guidance_scale=guidance_scale_final,
            seed="random" if seed is None else seed,
            prompt_chars=len(validation_prompt),
            negative_prompt_chars=0 if negative_prompt is None else len(negative_prompt),
            output_fps="same_as_processed" if output_fps is None else output_fps,
        )
        log_cuda_memory("diffueraser.forward.start", request_id=request_id)
        if max_img_size < 256 or max_img_size > 1920:
            raise ValueError("The max_img_size must be larger than 256, smaller than 1920.")

        ################ read input video ################
        with timed_stage("diffueraser.read_video", request_id=request_id, input=validation_image):
            frames, fps, img_size, n_clip, n_total_frames = read_video(validation_image, video_length, nframes, max_img_size)
            video_len = len(frames)
            log_event(
                "diffueraser.read_video",
                "loaded frames",
                request_id=request_id,
                fps=f"{fps:.2f}",
                image_size=f"{img_size[0]}x{img_size[1]}",
                n_clip=n_clip,
                n_total_frames=n_total_frames,
                frames=len(frames),
            )
        _progress(0.04, "DiffuEraser: reading mask")

        ################ read mask ################
        with timed_stage("diffueraser.read_mask", request_id=request_id, input=validation_mask):
            validation_masks_input, validation_images_input = read_mask(
                validation_mask, fps, video_len, img_size, mask_dilation_iter, frames,
                mask_refine_mode=mask_refine_mode,
                mask_refine_iterations=mask_refine_iterations,
                mask_feather_px=mask_feather_px,
                request_id=request_id,
            )
            log_event("diffueraser.read_mask", "loaded masks", request_id=request_id, masks=len(validation_masks_input), masked_images=len(validation_images_input))
        _progress(0.07, "DiffuEraser: reading ProPainter priori")

        ################ read priori ################
        with timed_stage("diffueraser.read_priori", request_id=request_id, input=priori):
            prioris = read_priori(priori, fps, n_total_frames, img_size)
            log_event("diffueraser.read_priori", "loaded priori frames", request_id=request_id, prioris=len(prioris))

        ## recheck: truncate all inputs to the shortest of the three streams
        n_total_frames = min(min(len(frames), len(validation_masks_input)), len(prioris))
        if n_total_frames < 22:
            raise ValueError("The effective video duration is too short. Please make sure that the number of frames of video, mask, and priori is at least greater than 22 frames.")
        validation_masks_input = validation_masks_input[:n_total_frames]
        validation_images_input = validation_images_input[:n_total_frames]
        frames = frames[:n_total_frames]
        prioris = prioris[:n_total_frames]
        log_event(
            "diffueraser.recheck",
            "aligned frame counts",
            request_id=request_id,
            n_total_frames=n_total_frames,
            frames=len(frames),
            masks=len(validation_masks_input),
            prioris=len(prioris),
        )

        _progress(0.10, "DiffuEraser: resizing inputs")
        with timed_stage("diffueraser.resize_inputs", request_id=request_id):
            prioris = resize_frames(prioris)
            validation_masks_input = resize_frames(validation_masks_input)
            validation_images_input = resize_frames(validation_images_input)
            resized_frames = resize_frames(frames)
            log_event(
                "diffueraser.resize_inputs",
                "resized input lists",
                request_id=request_id,
                prioris=len(prioris),
                masks=len(validation_masks_input),
                masked_images=len(validation_images_input),
                frames=len(resized_frames),
            )

        ##############################################
        # DiffuEraser inference
        ##############################################
        _progress(0.14, "DiffuEraser: preparing inference")
        log_event("diffueraser.inference", "begin core inference", request_id=request_id)
        if seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        ## random noise: one clip of nframes latents, tiled to cover the video
        real_video_length = len(validation_images_input)
        tar_width, tar_height = validation_images_input[0].size
        shape = (
            nframes,
            4,
            tar_height // 8,
            tar_width // 8,
        )
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet_main is not None:
            prompt_embeds_dtype = self.unet_main.dtype
        else:
            prompt_embeds_dtype = torch.float16
        log_event(
            "diffueraser.latents",
            "preparing noise",
            request_id=request_id,
            real_video_length=real_video_length,
            n_clip=n_clip,
            latent_shape="x".join(str(x) for x in shape),
            target_size=f"{tar_width}x{tar_height}",
            dtype=prompt_embeds_dtype,
        )
        _progress(0.16, "DiffuEraser: preparing noise")
        with timed_stage("diffueraser.prepare_noise", request_id=request_id):
            noise_pre = randn_tensor(shape, device=torch.device(self.device), dtype=prompt_embeds_dtype, generator=generator)
            noise = repeat(noise_pre, "t c h w->(repeat t) c h w", repeat=n_clip)[:real_video_length, ...]
        _progress(0.18, "DiffuEraser: encoding priori latents")

        ################ prepare priori ################
        with timed_stage("diffueraser.prepare_priori_latents", request_id=request_id, prioris=len(prioris)):
            images_preprocessed = []
            for image in prioris:
                image = self.image_processor.preprocess(image, height=tar_height, width=tar_width).to(dtype=torch.float32)
                image = image.to(device=torch.device(self.device), dtype=torch.float16)
                images_preprocessed.append(image)
            pixel_values = torch.cat(images_preprocessed)

            with torch.no_grad():
                pixel_values = pixel_values.to(dtype=torch.float16)
                latents = []
                num = 4  # VAE encode batch size
                for i in range(0, pixel_values.shape[0], num):
                    latents.append(self.vae.encode(pixel_values[i: i + num]).latent_dist.sample())
                latents = torch.cat(latents, dim=0)
            latents = latents * self.vae.config.scaling_factor  # [(b f), c1=4, h, w]
            log_event("diffueraser.prepare_priori_latents", "created latents", request_id=request_id, latents_shape="x".join(str(x) for x in latents.shape))
            torch.cuda.empty_cache()
        log_cuda_memory("diffueraser.after_prepare_latents", request_id=request_id)
        timesteps = torch.tensor([0], device=self.device)
        timesteps = timesteps.long()

        # Keep untouched copies: pre-inference overwrites entries in these lists,
        # but compositing needs the original masks and frames.
        validation_masks_input_ori = copy.deepcopy(validation_masks_input)
        resized_frames_ori = copy.deepcopy(resized_frames)

        ################ Pre-inference ################
        if n_total_frames > nframes * 2:  ## do pre-inference only when number of input frames is larger than nframes*2
            with timed_stage("diffueraser.pre_inference", request_id=request_id, n_total_frames=n_total_frames, nframes=nframes):
                ## sample nframes evenly spaced frames across the whole video
                step = n_total_frames / nframes
                sample_index = [int(i * step) for i in range(nframes)]
                # The list already has exactly nframes entries. The former
                # hard-coded [:22] truncation broke nframes > 22 (latents_pre
                # no longer matched noise_pre / num_frames).
                log_event("diffueraser.pre_inference", "sampled frames", request_id=request_id, sample_count=len(sample_index))
                validation_masks_input_pre = [validation_masks_input[i] for i in sample_index]
                validation_images_input_pre = [validation_images_input[i] for i in sample_index]
                latents_pre = torch.stack([latents[i] for i in sample_index])

                ## add priori
                noisy_latents_pre = self.noise_scheduler.add_noise(latents_pre, noise_pre, timesteps)
                latents_pre = noisy_latents_pre

                with torch.no_grad():
                    latents_pre_out = self.pipeline(
                        num_frames=nframes,
                        prompt=validation_prompt,
                        images=validation_images_input_pre,
                        masks=validation_masks_input_pre,
                        num_inference_steps=self.num_inference_steps,
                        generator=generator,
                        guidance_scale=guidance_scale_final,
                        negative_prompt=negative_prompt,
                        latents=latents_pre,
                        callback_on_step_end=_pipeline_progress(0.30, 0.50, "DiffuEraser: pre-inference"),
                        callback_on_step_end_tensor_inputs=[],
                    ).latents
                torch.cuda.empty_cache()
                log_cuda_memory("diffueraser.after_pre_inference_pipeline", request_id=request_id)

            _progress(0.52, "DiffuEraser: decoding pre-inference frames")
            with timed_stage("diffueraser.pre_inference.decode", request_id=request_id, latents_shape="x".join(str(x) for x in latents_pre_out.shape)):
                with torch.no_grad():
                    video_tensor_temp = self._decode_latents(latents_pre_out, weight_dtype=torch.float16)
                    images_pre_out = self.image_processor.postprocess(video_tensor_temp, output_type="pil")
                torch.cuda.empty_cache()

            ## replace input frames with updated frames; their masks become empty
            ## so the main pass keeps these frames as anchors
            black_image = Image.new('L', validation_masks_input[0].size, color=0)
            for i, index in enumerate(sample_index):
                latents[index] = latents_pre_out[i]
                validation_masks_input[index] = black_image
                validation_images_input[index] = images_pre_out[i]
                resized_frames[index] = images_pre_out[i]
        else:
            _progress(0.55, "DiffuEraser: pre-inference skipped")
            log_event("diffueraser.pre_inference", "skipped", request_id=request_id, reason="not_enough_frames", n_total_frames=n_total_frames, threshold=nframes * 2)
            latents_pre_out = None
            sample_index = None
        gc.collect()
        torch.cuda.empty_cache()
        log_cuda_memory("diffueraser.after_pre_inference", request_id=request_id)

        _progress(0.58, "DiffuEraser: frame inference")
        ################ Frame-by-frame inference ################
        with timed_stage("diffueraser.frame_inference", request_id=request_id, frames=len(validation_images_input), steps=self.num_inference_steps):
            ## add priori
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
            latents = noisy_latents
            with torch.no_grad():
                images = self.pipeline(
                    num_frames=nframes,
                    prompt=validation_prompt,
                    images=validation_images_input,
                    masks=validation_masks_input,
                    num_inference_steps=self.num_inference_steps,
                    generator=generator,
                    guidance_scale=guidance_scale_final,
                    negative_prompt=negative_prompt,
                    latents=latents,
                    callback_on_step_end=_pipeline_progress(0.60, 0.86, "DiffuEraser: frame inference"),
                    callback_on_step_end_tensor_inputs=[],
                ).frames
            images = images[:real_video_length]
            log_event("diffueraser.frame_inference", "generated frames", request_id=request_id, frames=len(images))

        gc.collect()
        torch.cuda.empty_cache()
        log_cuda_memory("diffueraser.after_frame_inference", request_id=request_id)

        _progress(0.88, "DiffuEraser: composing output")
        ################ Compose ################
        with timed_stage("diffueraser.compose_write", request_id=request_id, output=output_path, frames=real_video_length):
            binary_masks = validation_masks_input_ori
            if blended:
                # blur, you can adjust the parameters for better performance
                mask_blurreds = []
                for binary in binary_masks:
                    mask_blurred = cv2.GaussianBlur(np.array(binary), (21, 21), 0) / 255.
                    # Union of the hard mask and its blurred halo.
                    binary_mask = 1 - (1 - np.array(binary) / 255.) * (1 - mask_blurred)
                    mask_blurreds.append(Image.fromarray((binary_mask * 255).astype(np.uint8)))
                binary_masks = mask_blurreds

            comp_frames = []
            for i, generated in enumerate(images):
                mask = np.expand_dims(np.array(binary_masks[i]), 2).repeat(3, axis=2).astype(np.float32) / 255.
                img = (np.array(generated).astype(np.uint8) * mask
                       + np.array(resized_frames_ori[i]).astype(np.uint8) * (1 - mask)).astype(np.uint8)
                comp_frames.append(Image.fromarray(img))

            default_fps = fps
            output_frames = comp_frames
            writer_fps = default_fps
            if output_fps is not None:
                output_fps = float(output_fps)
                if output_fps > 0 and output_fps < default_fps:
                    # Evenly subsample frames so duration is preserved at the lower fps.
                    target_count = max(1, int(round(len(comp_frames) * output_fps / default_fps)))
                    frame_indices = np.linspace(0, len(comp_frames) - 1, target_count).round().astype(int)
                    output_frames = [comp_frames[i] for i in frame_indices]
                    writer_fps = output_fps
                    log_event(
                        "diffueraser.output_fps",
                        "downsampled output frames",
                        request_id=request_id,
                        input_fps=f"{default_fps:.2f}",
                        output_fps=f"{writer_fps:.2f}",
                        input_frames=len(comp_frames),
                        output_frames=len(output_frames),
                    )
                else:
                    log_event(
                        "diffueraser.output_fps",
                        "kept processed fps",
                        request_id=request_id,
                        input_fps=f"{default_fps:.2f}",
                        requested_output_fps=f"{output_fps:.2f}",
                    )

            writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
                                     writer_fps, output_frames[0].size)
            if not writer.isOpened():
                # Previously a failed open silently produced no file while
                # output_path was still returned.
                raise RuntimeError(f"Could not open video writer for: {output_path}")
            try:
                for frame in output_frames:
                    img = np.array(frame).astype(np.uint8)
                    # Frames are RGB; swapping channels yields the BGR OpenCV writes.
                    writer.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            finally:
                writer.release()
            log_event(
                "diffueraser.compose_write",
                "wrote output",
                request_id=request_id,
                output=output_path,
                fps=f"{writer_fps:.2f}",
                source_fps=f"{default_fps:.2f}",
                frame_size=f"{output_frames[0].size[0]}x{output_frames[0].size[1]}",
                frames=len(output_frames),
                source_frames=len(comp_frames),
            )
        ################################

        _progress(1.0, "DiffuEraser: done")
        log_cuda_memory("diffueraser.forward.end", request_id=request_id)
        return output_path