import gc
import copy
import cv2
import datetime
import os
import time
from contextlib import contextmanager

import numpy as np
import torch
import torchvision
from einops import repeat
from PIL import Image, ImageFilter

LOG_PREFIX = "[DiffuEraser]"
# Maps a request id to the path of the file that mirrors that request's log lines.
REQUEST_LOG_FILES = {}


def set_request_log_file(request_id, log_path):
    """Register ``log_path`` as the log file for ``request_id``.

    Nothing is registered when either argument is falsy.
    """
    if request_id and log_path:
        REQUEST_LOG_FILES[request_id] = log_path


def clear_request_log_file(request_id):
    """Forget the log file registered for ``request_id`` (no-op if unknown)."""
    if request_id:
        REQUEST_LOG_FILES.pop(request_id, None)


def _append_request_log(request_id, text):
    """Append ``text`` plus a newline to the request's log file, if one is registered."""
    log_path = REQUEST_LOG_FILES.get(request_id)
    if not log_path:
        return
    try:
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(text + "\n")
    except Exception:
        # Best-effort mirror: a failing log file must never break processing.
        pass


def _format_log_value(value):
    """Render a field value for a ``key=value`` log line.

    ``None`` becomes ``"none"`` and an empty string becomes ``"empty"``.
    Values containing whitespace, or longer than 80 characters, are returned
    as ``repr`` of the string with newlines replaced by a literal ``\\n``.
    Everything else is returned as ``str(value)``.
    """
    if value is None:
        return "none"
    value = str(value)
    if not value:
        return "empty"
    if any(ch.isspace() for ch in value) or len(value) > 80:
        value = value.replace("\n", "\\n")
        return repr(value)
    return value


def log_event(stage, message="", request_id=None, **fields):
    """Print one structured log line and mirror it to the request log file.

    Args:
        stage: Stage name written as ``stage=<stage>``.
        message: Optional free text placed after the stage.
        request_id: Optional id; included in the line and used to look up the
            per-request log file.
        **fields: Extra ``key=value`` pairs; entries whose value is ``None``
            are omitted.
    """
    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
    request_part = f" request_id={request_id}" if request_id else ""
    field_part = " ".join(
        f"{key}={_format_log_value(value)}"
        for key, value in fields.items()
        if value is not None
    )
    text = f"{LOG_PREFIX} {timestamp}{request_part} stage={stage}"
    if message:
        text += f" {message}"
    if field_part:
        text += f" {field_part}"
    print(text, flush=True)
    _append_request_log(request_id, text)


@contextmanager
def timed_stage(stage, request_id=None, **fields):
    """Context manager that logs the start, failure and duration of a stage.

    A ``start`` event (with ``fields``) is logged on entry. If the body raises
    an ``Exception``, an ``error`` event with the exception type and message is
    logged and the exception is re-raised. An ``end`` event carrying the
    elapsed seconds is always logged on exit.
    """
    start = time.perf_counter()
    log_event(stage, "start", request_id=request_id, **fields)
    try:
        yield
    except Exception as exc:
        log_event(stage, "error", request_id=request_id, error_type=type(exc).__name__, error=str(exc))
        raise
    finally:
        elapsed = time.perf_counter() - start
        log_event(stage, "end", request_id=request_id, elapsed_sec=f"{elapsed:.2f}")


def log_cuda_memory(label, request_id=None):
    """Log CUDA memory statistics of the current device under ``memory.cuda``.

    Logs ``cuda_available=False`` when CUDA is not available. Any exception
    raised while querying the device is logged as ``unavailable`` instead of
    being propagated.
    """
    try:
        if not torch.cuda.is_available():
            log_event("memory.cuda", label, request_id=request_id, cuda_available=False)
            return
        device_index = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(device_index)
        gib = 1024 ** 3
        log_event(
            "memory.cuda",
            label,
            request_id=request_id,
            cuda_available=True,
            device_index=device_index,
            device_name=props.name,
            total_gb=f"{props.total_memory / gib:.2f}",
            allocated_gb=f"{torch.cuda.memory_allocated(device_index) / gib:.2f}",
            reserved_gb=f"{torch.cuda.memory_reserved(device_index) / gib:.2f}",
            max_allocated_gb=f"{torch.cuda.max_memory_allocated(device_index) / gib:.2f}",
        )
    except Exception as exc:
        log_event("memory.cuda", "unavailable", request_id=request_id, error_type=type(exc).__name__, error=str(exc))


from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    UniPCMultistepScheduler,
    LCMScheduler,
)
from diffusers.schedulers import TCDScheduler
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from transformers import AutoTokenizer, PretrainedConfig

from libs.unet_motion_model import MotionAdapter, UNetMotionModel
from libs.brushnet_CA import BrushNetModel
from libs.unet_2d_condition import UNet2DConditionModel
from diffueraser.pipeline_diffueraser import StableDiffusionDiffuEraserPipeline

# Checkpoint name -> [weight file name template, inference steps, guidance scale].
# The template's "{}" is filled with the model mode via str.format.
checkpoints = {
    "2-Step": ["pcm_{}_smallcfg_2step_converted.safetensors", 2, 0.0],
    "4-Step": ["pcm_{}_smallcfg_4step_converted.safetensors", 4, 0.0],
    "8-Step": ["pcm_{}_smallcfg_8step_converted.safetensors", 8, 0.0],
    "16-Step": ["pcm_{}_smallcfg_16step_converted.safetensors", 16, 0.0],
    "Normal CFG 4-Step": ["pcm_{}_normalcfg_4step_converted.safetensors", 4, 7.5],
    "Normal CFG 8-Step": ["pcm_{}_normalcfg_8step_converted.safetensors", 8, 7.5],
    "Normal CFG 16-Step": ["pcm_{}_normalcfg_16step_converted.safetensors", 16, 7.5],
    "LCM-Like LoRA": [
        "pcm_{}_lcmlike_lora_converted.safetensors",
        4,
        0.0,
    ],
}


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named by the model's ``text_encoder`` config.

    Raises:
        ValueError: If the first architecture listed in the config is neither
            ``CLIPTextModel`` nor ``RobertaSeriesModelWithTransformation``.
    """
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")


def resize_frames(frames, size=None):
    """Resize frames so that both dimensions are multiples of 8.

    Args:
        frames: Sequence of objects exposing ``.size`` and ``.resize`` (PIL
            images in this file).
        size: Optional ``(width, height)`` target. When given, every frame is
            resized to ``size`` rounded down to multiples of 8. When omitted,
            the first frame's size is used and frames are only resized if that
            size is not already a multiple of 8.

    Returns:
        The resized frames, or the input unchanged when nothing had to be
        done (including when ``frames`` is empty).
    """
    if len(frames) == 0:
        # Nothing to measure or resize; avoids indexing frames[0] below.
        return frames
    out_size = size if size is not None else frames[0].size
    process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
    if size is not None or not out_size == process_size:
        frames = [f.resize(process_size) for f in frames]
    return frames


def _odd_kernel_size(value):
    """Return ``2 * value + 1`` for a positive ``value``, otherwise 0."""
    value = max(0, int(value))
    if value <= 0:
        return 0
    return value * 2 + 1


def refine_mask_array(mask, mask_refine_mode="Keep", mask_refine_iterations=0, mask_feather_px=0, mask_dilation_iter=0):
    """Binarise a mask and optionally erode/dilate and feather it.

    The mask is first thresholded (``> 0`` becomes 255, everything else 0).
    Then, in order:

    1. ``mask_refine_iterations`` passes of erosion (mode ``"Erode"``) or
       dilation (mode ``"Dilate"``) with a 3x3 rectangular kernel; any other
       mode skips this step.
    2. ``mask_dilation_iter`` additional dilation passes.
    3. A Gaussian blur with kernel size ``2 * mask_feather_px + 1`` when
       ``mask_feather_px`` is positive.

    Negative counts are treated as 0.

    Returns:
        A ``uint8`` array; it is no longer strictly binary once feathered.
    """
    mode = str(mask_refine_mode)
    refine_iterations = max(0, int(mask_refine_iterations))
    dilation_iterations = max(0, int(mask_dilation_iter))
    feather_px = max(0, int(mask_feather_px))

    m = (np.asarray(mask) > 0).astype(np.uint8) * 255
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))

    if refine_iterations > 0 and mode == "Erode":
        m = cv2.erode(m, kernel, iterations=refine_iterations)
    elif refine_iterations > 0 and mode == "Dilate":
        m = cv2.dilate(m, kernel, iterations=refine_iterations)

    if dilation_iterations > 0:
        m = cv2.dilate(m, kernel, iterations=dilation_iterations)

    kernel_size = _odd_kernel_size(feather_px)
    if kernel_size > 0:
        m = cv2.GaussianBlur(m, (kernel_size, kernel_size), 0)
    return m


def read_mask(
    validation_mask,
    fps,
    n_total_frames,
    img_size,
    mask_dilation_iter,
    frames,
    mask_refine_mode="Keep",
    mask_refine_iterations=0,
    mask_feather_px=0,
    request_id=None,
):
    """Read a mask video and build refined masks plus masked input frames.

    Args:
        validation_mask: Path of the mask video.
        fps: Expected frame rate; the mask video must report exactly this value.
        n_total_frames: Maximum number of mask frames to read.
        img_size: ``(width, height)`` every mask is resized to (nearest
            neighbour) when it differs.
        mask_dilation_iter: Dilation passes forwarded to ``refine_mask_array``.
        frames: Source frames; ``frames[i]`` is multiplied by the inverted
            mask ``i`` to produce the masked image.
        mask_refine_mode, mask_refine_iterations, mask_feather_px: Forwarded
            to ``refine_mask_array``.
        request_id: Optional id used for logging.

    Returns:
        ``(masks, masked_images)`` as two lists of PIL images.

    Raises:
        OSError: If the mask video cannot be opened.
        ValueError: If the mask video's frame rate differs from ``fps``.
    """
    cap = cv2.VideoCapture(validation_mask)
    masks = []
    masked_images = []
    try:
        if not cap.isOpened():
            # Raise instead of exit(): a library helper must not terminate the host process.
            raise OSError(f"Could not open mask video: {validation_mask}")
        mask_fps = cap.get(cv2.CAP_PROP_FPS)
        if mask_fps != fps:
            raise ValueError("The frame rate of all input videos needs to be consistent.")

        idx = 0
        while idx < n_total_frames:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV delivers BGR; flip the channel axis before handing to PIL.
            mask = Image.fromarray(frame[..., ::-1]).convert('L')
            if mask.size != img_size:
                mask = mask.resize(img_size, Image.NEAREST)
            m = refine_mask_array(
                np.asarray(mask),
                mask_refine_mode=mask_refine_mode,
                mask_refine_iterations=mask_refine_iterations,
                mask_feather_px=mask_feather_px,
                mask_dilation_iter=mask_dilation_iter,
            )
            masks.append(Image.fromarray(m))

            # Black out the masked region: pixel * (1 - mask / 255).
            keep = 1 - (m[:, :, np.newaxis].astype(np.float32) / 255)
            masked_image = np.array(frames[idx]) * keep
            masked_images.append(Image.fromarray(masked_image.astype(np.uint8)))
            idx += 1
    finally:
        # Release the capture on every path, including errors inside the loop.
        cap.release()

    log_event(
        "diffueraser.read_mask",
        "mask refinement applied",
        request_id=request_id,
        mask_refine_mode=mask_refine_mode,
        mask_refine_iterations=mask_refine_iterations,
        mask_feather_px=mask_feather_px,
        mask_dilation_iter=mask_dilation_iter,
        masks=len(masks),
    )
    return masks, masked_images


def read_priori(priori, fps, n_total_frames, img_size):
    """Read up to ``n_total_frames`` frames of the priori video as RGB PIL images.

    Args:
        priori: Path of the priori video.
        fps: Expected frame rate; the video must report exactly this value.
        n_total_frames: Maximum number of frames to read.
        img_size: ``(width, height)`` every frame is resized to when it differs.

    Returns:
        List of PIL images.

    Raises:
        OSError: If the video cannot be opened.
        ValueError: If the video's frame rate differs from ``fps``.
    """
    cap = cv2.VideoCapture(priori)
    prioris = []
    try:
        if not cap.isOpened():
            # Raise instead of exit(): a library helper must not terminate the host process.
            raise OSError(f"Could not open video: {priori}")
        priori_fps = cap.get(cv2.CAP_PROP_FPS)
        if priori_fps != fps:
            raise ValueError("The frame rate of all input videos needs to be consistent.")

        idx = 0
        while idx < n_total_frames:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV delivers BGR; flip the channel axis before handing to PIL.
            img = Image.fromarray(frame[..., ::-1])
            if img.size != img_size:
                img = img.resize(img_size)
            prioris.append(img)
            idx += 1
    finally:
        cap.release()
    return prioris
def read_video(validation_image, video_length, nframes, max_img_size):
    """Load the first ``video_length`` seconds of a video as PIL frames.

    Frames are resized so that the longer side does not exceed
    ``max_img_size`` and both sides are multiples of 8.

    Args:
        validation_image: Path of the input video.
        video_length: Number of seconds to read from the start.
        nframes: Clip length used to compute ``n_clip``.
        max_img_size: Upper bound for the longer frame side.

    Returns:
        ``(frames, fps, img_size, n_clip, n_total_frames)`` where ``frames`` is
        a list of PIL images, ``img_size`` is ``(width, height)``,
        ``n_total_frames`` is ``int(video_length * fps)`` and ``n_clip`` is
        ``ceil(n_total_frames / nframes)``.

    Raises:
        ValueError: If no frame could be read, or the longer side of the video
            is below 256 or above 4096 pixels.
    """
    vframes, _aframes, info = torchvision.io.read_video(
        filename=validation_image, pts_unit='sec', end_pts=video_length
    )  # RGB
    fps = info['video_fps']
    n_total_frames = int(video_length * fps)
    n_clip = int(np.ceil(n_total_frames / nframes))
    frames = [Image.fromarray(f) for f in list(vframes.numpy())[:n_total_frames]]
    if not frames:
        # Fail with a clear message instead of an IndexError on frames[0].
        raise ValueError("No video frames could be read from the input video.")

    width, height = frames[0].size
    max_size = max(width, height)
    if max_size < 256:
        raise ValueError("The resolution of the uploaded video must be larger than 256x256.")
    if max_size > 4096:
        raise ValueError("The resolution of the uploaded video must be smaller than 4096x4096.")

    if max_size > max_img_size:
        # Downscale so the longer side fits, then round down to multiples of 8.
        ratio = max_size / max_img_size
        ratio_size = (int(width / ratio), int(height / ratio))
        img_size = (ratio_size[0] - ratio_size[0] % 8, ratio_size[1] - ratio_size[1] % 8)
        resize_flag = True
    elif (width % 8 == 0) and (height % 8 == 0):
        img_size = frames[0].size
        resize_flag = False
    else:
        img_size = (width - width % 8, height - height % 8)
        resize_flag = True

    if resize_flag:
        frames = resize_frames(frames, img_size)
        img_size = frames[0].size

    return frames, fps, img_size, n_clip, n_total_frames


class DiffuEraser:
    """Wrapper that loads the DiffuEraser pipeline and runs video inpainting.

    Construction loads the VAE, tokenizer, text encoder, BrushNet and motion
    UNet, builds the pipeline in float16 on ``device`` and loads the PCM LoRA
    weights selected by ``ckpt``. ``forward`` reads an input video, a mask
    video and a priori video, runs the pipeline and writes the composited
    result to a video file.
    """

    def __init__(
            self, device, base_model_path, vae_path, diffueraser_path, revision=None,
            ckpt="Normal CFG 4-Step", mode="sd15", loaded=None):
        """Load all model components and select the initial checkpoint.

        Args:
            device: Device the pipeline is moved to.
            base_model_path: Path/id providing scheduler, tokenizer, text
                encoder and base pipeline.
            vae_path: Path/id of the VAE.
            diffueraser_path: Path/id providing the ``brushnet`` and
                ``unet_main`` subfolders.
            revision: Revision used when reading the text-encoder config.
            ckpt: Key into ``checkpoints`` selecting the PCM LoRA weights.
            mode: Inserted into the weight file name and used as the LoRA
                subfolder.
            loaded: Not used by this constructor.
        """
        self.device = device
        self.mode = mode
        self.current_ckpt = None
        self.current_scheduler = None

        ## load model
        self.vae = AutoencoderKL.from_pretrained(vae_path)
        # NOTE: this scheduler is replaced by a UniPC scheduler further down.
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            base_model_path,
            subfolder="scheduler",
            prediction_type="v_prediction",
            timestep_spacing="trailing",
            rescale_betas_zero_snr=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_path,
            subfolder="tokenizer",
            use_fast=False,
        )
        text_encoder_cls = import_model_class_from_model_name_or_path(base_model_path, revision)
        self.text_encoder = text_encoder_cls.from_pretrained(
            base_model_path, subfolder="text_encoder"
        )
        self.brushnet = BrushNetModel.from_pretrained(diffueraser_path, subfolder="brushnet")
        self.unet_main = UNetMotionModel.from_pretrained(
            diffueraser_path, subfolder="unet_main",
        )

        ## set pipeline
        self.pipeline = StableDiffusionDiffuEraserPipeline.from_pretrained(
            base_model_path,
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet_main,
            brushnet=self.brushnet,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        ).to(self.device, torch.float16)
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        # Snapshot of the config so later scheduler swaps start from the same settings.
        self.scheduler_config = copy.deepcopy(self.pipeline.scheduler.config)
        self.pipeline.set_progress_bar_config(disable=True)

        self.noise_scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)

        ## use PCM
        self.set_checkpoint(ckpt)

    def _resolve_scheduler_name(self, scheduler_name, ckpt=None):
        """Map ``"Auto"`` to a concrete scheduler name for the given checkpoint.

        ``"Auto"`` resolves to ``"LCM"`` for the ``"LCM-Like LoRA"`` checkpoint
        and to ``"TCD"`` otherwise; any other name is returned unchanged.
        ``self.ckpt`` is used when ``ckpt`` is falsy.
        """
        ckpt = ckpt or self.ckpt
        if scheduler_name == "Auto":
            return "LCM" if ckpt == "LCM-Like LoRA" else "TCD"
        return scheduler_name

    def _build_scheduler(self, scheduler_name, ckpt=None):
        """Instantiate the scheduler named by ``scheduler_name``.

        Raises:
            ValueError: If the resolved name is not one of ``LCM``, ``TCD``,
                ``UniPC``, ``DDIM``, ``Euler`` or ``DPM++ 2M``.
        """
        resolved_scheduler = self._resolve_scheduler_name(scheduler_name, ckpt)
        if resolved_scheduler == "LCM":
            return LCMScheduler()
        if resolved_scheduler == "TCD":
            return TCDScheduler(
                num_train_timesteps=1000,
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                timestep_spacing="trailing",
            )
        if resolved_scheduler == "UniPC":
            return UniPCMultistepScheduler.from_config(self.scheduler_config)
        if resolved_scheduler == "DDIM":
            return DDIMScheduler.from_config(self.scheduler_config)
        if resolved_scheduler == "Euler":
            return EulerDiscreteScheduler.from_config(self.scheduler_config)
        if resolved_scheduler == "DPM++ 2M":
            return DPMSolverMultistepScheduler.from_config(self.scheduler_config)
        raise ValueError(f"Unsupported scheduler: {scheduler_name}")

    def set_scheduler(self, scheduler_name="Auto", request_id=None):
        """Install the requested scheduler on the pipeline if it changed.

        The requested/resolved name pair is remembered in
        ``self.current_scheduler`` so that repeated calls with the same
        selection only log ``unchanged``.
        """
        resolved_scheduler = self._resolve_scheduler_name(scheduler_name, self.ckpt)
        scheduler_key = f"{scheduler_name}->{resolved_scheduler}"
        if scheduler_key == self.current_scheduler:
            log_event("diffueraser.scheduler", "unchanged", request_id=request_id, scheduler=scheduler_name, resolved_scheduler=resolved_scheduler)
            return
        with timed_stage("diffueraser.scheduler", request_id=request_id, scheduler=scheduler_name, resolved_scheduler=resolved_scheduler):
            self.pipeline.scheduler = self._build_scheduler(scheduler_name, self.ckpt)
            self.current_scheduler = scheduler_key
            log_event("diffueraser.scheduler", "set", request_id=request_id, scheduler=scheduler_name, resolved_scheduler=resolved_scheduler)

    def set_checkpoint(self, ckpt, scheduler_name="Auto", request_id=None):
        """Load the PCM LoRA weights for ``ckpt`` and (re)select the scheduler.

        When ``ckpt`` differs from the currently loaded one, previously loaded
        LoRA weights are unloaded, the new weights are loaded from
        ``weights/PCM_Weights`` and the step count and guidance scale are
        taken from ``checkpoints[ckpt]``. The scheduler is then set via
        ``set_scheduler`` in every case.

        Raises:
            KeyError: If ``ckpt`` is not a key of ``checkpoints``.
        """
        if ckpt != self.current_ckpt:
            PCM_ckpts = checkpoints[ckpt][0].format(self.mode)
            with timed_stage("diffueraser.lora", request_id=request_id, ckpt=ckpt, weight=PCM_ckpts):
                if self.current_ckpt is not None:
                    log_event("diffueraser.lora", "unload previous", request_id=request_id, previous_ckpt=self.current_ckpt)
                    self.pipeline.unload_lora_weights()
                self.pipeline.load_lora_weights(
                    "weights/PCM_Weights", weight_name=PCM_ckpts, subfolder=self.mode
                )
                self.ckpt = ckpt
                self.current_ckpt = ckpt
                self.num_inference_steps = checkpoints[ckpt][1]
                self.guidance_scale = checkpoints[ckpt][2]
                # Force set_scheduler below to rebuild for the new checkpoint.
                self.current_scheduler = None
                log_event(
                    "diffueraser.lora",
                    "loaded",
                    request_id=request_id,
                    ckpt=ckpt,
                    num_inference_steps=self.num_inference_steps,
                    checkpoint_guidance_scale=self.guidance_scale,
                )
        else:
            log_event("diffueraser.lora", "unchanged", request_id=request_id, ckpt=ckpt)
        self.set_scheduler(scheduler_name, request_id=request_id)

    def _decode_latents(self, latents, weight_dtype):
        """Decode latents frame by frame with the VAE and return a float32 tensor."""
        latents = 1 / self.vae.config.scaling_factor * latents
        video = []
        for t in range(latents.shape[0]):
            video.append(self.vae.decode(latents[t:t+1, ...].to(weight_dtype)).sample)
        video = torch.concat(video, dim=0)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        video = video.float()
        return video

    def forward(self, validation_image, validation_mask, priori, output_path,
                max_img_size = 1280, video_length=2, mask_dilation_iter=4,
                mask_refine_mode="Keep", mask_refine_iterations=0, mask_feather_px=0,
                nframes=22, seed=None, revision = None, guidance_scale=None, blended=True,
                prompt="", negative_prompt="", request_id=None, output_fps=None,
                progress_callback=None):
        """Inpaint the masked region of a video and write the result to disk.

        Args:
            validation_image: Path of the input video.
            validation_mask: Path of the mask video.
            priori: Path of the priori video whose frames are VAE-encoded and
                used as initial latents.
            output_path: Path of the video file to write.
            max_img_size: Upper bound for the longer frame side; must lie in
                [256, 1920].
            video_length: Seconds of input video to process.
            mask_dilation_iter, mask_refine_mode, mask_refine_iterations,
                mask_feather_px: Mask refinement options forwarded to
                ``read_mask``.
            nframes: Number of frames per clip passed to the pipeline.
            seed: Seed for the noise generator; ``None`` leaves it unseeded.
            revision: Not used by this method.
            guidance_scale: Overrides the checkpoint's guidance scale when not
                ``None``.
            blended: When true, masks are softened with a Gaussian blur before
                compositing.
            prompt, negative_prompt: Text prompts; an empty negative prompt is
                passed to the pipeline as ``None``.
            request_id: Optional id used for logging.
            output_fps: Optional output frame rate. Only values in
                ``(0, processed fps)`` cause frames to be dropped; other
                values keep the processed frame rate.
            progress_callback: Optional ``callable(value, description)``;
                exceptions raised by it are ignored.

        Returns:
            ``output_path``.

        Raises:
            ValueError: If ``max_img_size`` is out of range, the inputs are
                rejected by the readers, or fewer than 22 aligned frames are
                available.
            OSError: If an input video cannot be opened or the output video
                writer cannot be opened.
        """
        validation_prompt = prompt or ""
        negative_prompt = negative_prompt or None
        guidance_scale_final = self.guidance_scale if guidance_scale is None else guidance_scale

        def _progress(local_value, desc):
            # Progress reporting is best-effort and must never abort inference.
            if progress_callback is None:
                return
            try:
                progress_callback(local_value, desc)
            except Exception:
                pass

        def _pipeline_progress(start, end, desc):
            # Build a step-end callback that maps step i of N onto [start, end].
            def callback(_pipe, step, _timestep, callback_kwargs):
                total = max(1, int(self.num_inference_steps))
                ratio = max(0.0, min(1.0, float(step + 1) / total))
                _progress(start + (end - start) * ratio, f"{desc} {step + 1}/{total}")
                return callback_kwargs
            return callback

        _progress(0.01, "DiffuEraser: reading inputs")
        log_event(
            "diffueraser.forward",
            "start",
            request_id=request_id,
            ckpt=self.ckpt,
            scheduler=self.current_scheduler,
            num_inference_steps=self.num_inference_steps,
            video_length=video_length,
            max_img_size=max_img_size,
            nframes=nframes,
            mask_dilation_iter=mask_dilation_iter,
            mask_refine_mode=mask_refine_mode,
            mask_refine_iterations=mask_refine_iterations,
            mask_feather_px=mask_feather_px,
            guidance_scale=guidance_scale_final,
            seed="random" if seed is None else seed,
            prompt_chars=len(validation_prompt),
            negative_prompt_chars=0 if negative_prompt is None else len(negative_prompt),
            output_fps="same_as_processed" if output_fps is None else output_fps,
        )
        log_cuda_memory("diffueraser.forward.start", request_id=request_id)
        if (max_img_size<256 or max_img_size>1920):
            raise ValueError("The max_img_size must be larger than 256, smaller than 1920.")

        ################ read input video ################
        with timed_stage("diffueraser.read_video", request_id=request_id, input=validation_image):
            frames, fps, img_size, n_clip, n_total_frames = read_video(validation_image, video_length, nframes, max_img_size)
            video_len = len(frames)
            log_event(
                "diffueraser.read_video",
                "loaded frames",
                request_id=request_id,
                fps=f"{fps:.2f}",
                image_size=f"{img_size[0]}x{img_size[1]}",
                n_clip=n_clip,
                n_total_frames=n_total_frames,
                frames=len(frames),
            )
        _progress(0.04, "DiffuEraser: reading mask")

        ################ read mask ################
        with timed_stage("diffueraser.read_mask", request_id=request_id, input=validation_mask):
            validation_masks_input, validation_images_input = read_mask(
                validation_mask,
                fps,
                video_len,
                img_size,
                mask_dilation_iter,
                frames,
                mask_refine_mode=mask_refine_mode,
                mask_refine_iterations=mask_refine_iterations,
                mask_feather_px=mask_feather_px,
                request_id=request_id,
            )
            log_event("diffueraser.read_mask", "loaded masks", request_id=request_id, masks=len(validation_masks_input), masked_images=len(validation_images_input))
        _progress(0.07, "DiffuEraser: reading ProPainter priori")

        ################ read priori ################
        with timed_stage("diffueraser.read_priori", request_id=request_id, input=priori):
            prioris = read_priori(priori, fps, n_total_frames, img_size)
            log_event("diffueraser.read_priori", "loaded priori frames", request_id=request_id, prioris=len(prioris))

        ## recheck: align all inputs to the shortest of the three sources
        n_total_frames = min(min(len(frames), len(validation_masks_input)), len(prioris))
        if(n_total_frames<22):
            raise ValueError("The effective video duration is too short. Please make sure that the number of frames of video, mask, and priori is at least greater than 22 frames.")
        validation_masks_input = validation_masks_input[:n_total_frames]
        validation_images_input = validation_images_input[:n_total_frames]
        frames = frames[:n_total_frames]
        prioris = prioris[:n_total_frames]
        log_event(
            "diffueraser.recheck",
            "aligned frame counts",
            request_id=request_id,
            n_total_frames=n_total_frames,
            frames=len(frames),
            masks=len(validation_masks_input),
            prioris=len(prioris),
        )
        _progress(0.10, "DiffuEraser: resizing inputs")

        with timed_stage("diffueraser.resize_inputs", request_id=request_id):
            prioris = resize_frames(prioris)
            validation_masks_input = resize_frames(validation_masks_input)
            validation_images_input = resize_frames(validation_images_input)
            resized_frames = resize_frames(frames)
            log_event(
                "diffueraser.resize_inputs",
                "resized input lists",
                request_id=request_id,
                prioris=len(prioris),
                masks=len(validation_masks_input),
                masked_images=len(validation_images_input),
                frames=len(resized_frames),
            )

        ##############################################
        # DiffuEraser inference
        ##############################################
        _progress(0.14, "DiffuEraser: preparing inference")
        log_event("diffueraser.inference", "begin core inference", request_id=request_id)
        if seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        ## random noise
        real_video_length = len(validation_images_input)
        tar_width, tar_height = validation_images_input[0].size
        shape = (
            nframes,
            4,
            tar_height//8,
            tar_width//8
        )
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet_main is not None:
            prompt_embeds_dtype = self.unet_main.dtype
        else:
            prompt_embeds_dtype = torch.float16
        log_event(
            "diffueraser.latents",
            "preparing noise",
            request_id=request_id,
            real_video_length=real_video_length,
            n_clip=n_clip,
            latent_shape="x".join(str(x) for x in shape),
            target_size=f"{tar_width}x{tar_height}",
            dtype=prompt_embeds_dtype,
        )
        _progress(0.16, "DiffuEraser: preparing noise")
        with timed_stage("diffueraser.prepare_noise", request_id=request_id):
            noise_pre = randn_tensor(shape, device=torch.device(self.device), dtype=prompt_embeds_dtype, generator=generator)
            # The same nframes-long noise clip is tiled n_clip times and cut to the video length.
            noise = repeat(noise_pre, "t c h w->(repeat t) c h w", repeat=n_clip)[:real_video_length,...]
        _progress(0.18, "DiffuEraser: encoding priori latents")

        ################ prepare priori ################
        with timed_stage("diffueraser.prepare_priori_latents", request_id=request_id, prioris=len(prioris)):
            images_preprocessed = []
            for image in prioris:
                image = self.image_processor.preprocess(image, height=tar_height, width=tar_width).to(dtype=torch.float32)
                image = image.to(device=torch.device(self.device), dtype=torch.float16)
                images_preprocessed.append(image)
            pixel_values = torch.cat(images_preprocessed)

            with torch.no_grad():
                pixel_values = pixel_values.to(dtype=torch.float16)
                latents = []
                num=4  # frames encoded per VAE call
                for i in range(0, pixel_values.shape[0], num):
                    latents.append(self.vae.encode(pixel_values[i : i + num]).latent_dist.sample())
                latents = torch.cat(latents, dim=0)
            latents = latents * self.vae.config.scaling_factor #[(b f), c1=4, h, w]
            log_event("diffueraser.prepare_priori_latents", "created latents", request_id=request_id, latents_shape="x".join(str(x) for x in latents.shape))
        torch.cuda.empty_cache()
        log_cuda_memory("diffueraser.after_prepare_latents", request_id=request_id)
        timesteps = torch.tensor([0], device=self.device)
        timesteps = timesteps.long()

        # Keep untouched copies: pre-inference overwrites entries of the working lists.
        validation_masks_input_ori = copy.deepcopy(validation_masks_input)
        resized_frames_ori = copy.deepcopy(resized_frames)

        ################ Pre-inference ################
        if n_total_frames > nframes*2: ## do pre-inference only when number of input frames is larger than nframes*2
            with timed_stage("diffueraser.pre_inference", request_id=request_id, n_total_frames=n_total_frames, nframes=nframes):
                ## sample nframes indices evenly spread over the whole video
                step = n_total_frames / nframes
                # Exactly nframes indices are needed so that the sampled latents
                # match noise_pre and num_frames below; no fixed-size truncation.
                sample_index = [int(i * step) for i in range(nframes)]
                log_event("diffueraser.pre_inference", "sampled frames", request_id=request_id, sample_count=len(sample_index))
                validation_masks_input_pre = [validation_masks_input[i] for i in sample_index]
                validation_images_input_pre = [validation_images_input[i] for i in sample_index]
                latents_pre = torch.stack([latents[i] for i in sample_index])

                ## add priori
                noisy_latents_pre = self.noise_scheduler.add_noise(latents_pre, noise_pre, timesteps)
                latents_pre = noisy_latents_pre

                with torch.no_grad():
                    latents_pre_out = self.pipeline(
                        num_frames=nframes,
                        prompt=validation_prompt,
                        images=validation_images_input_pre,
                        masks=validation_masks_input_pre,
                        num_inference_steps=self.num_inference_steps,
                        generator=generator,
                        guidance_scale=guidance_scale_final,
                        negative_prompt=negative_prompt,
                        latents=latents_pre,
                        callback_on_step_end=_pipeline_progress(0.30, 0.50, "DiffuEraser: pre-inference"),
                        callback_on_step_end_tensor_inputs=[],
                    ).latents
                torch.cuda.empty_cache()
                log_cuda_memory("diffueraser.after_pre_inference_pipeline", request_id=request_id)

            _progress(0.52, "DiffuEraser: decoding pre-inference frames")
            with timed_stage("diffueraser.pre_inference.decode", request_id=request_id, latents_shape="x".join(str(x) for x in latents_pre_out.shape)):
                with torch.no_grad():
                    video_tensor_temp = self._decode_latents(latents_pre_out, weight_dtype=torch.float16)
                    images_pre_out = self.image_processor.postprocess(video_tensor_temp, output_type="pil")
                torch.cuda.empty_cache()

            ## replace input frames with updated frames; their masks become
            ## all-black so the sampled positions carry an empty mask afterwards
            black_image = Image.new('L', validation_masks_input[0].size, color=0)
            for i,index in enumerate(sample_index):
                latents[index] = latents_pre_out[i]
                validation_masks_input[index] = black_image
                validation_images_input[index] = images_pre_out[i]
                resized_frames[index] = images_pre_out[i]
        else:
            _progress(0.55, "DiffuEraser: pre-inference skipped")
            log_event("diffueraser.pre_inference", "skipped", request_id=request_id, reason="not_enough_frames", n_total_frames=n_total_frames, threshold=nframes*2)
        gc.collect()
        torch.cuda.empty_cache()
        log_cuda_memory("diffueraser.after_pre_inference", request_id=request_id)
        _progress(0.58, "DiffuEraser: frame inference")

        ################ Frame-by-frame inference ################
        with timed_stage("diffueraser.frame_inference", request_id=request_id, frames=len(validation_images_input), steps=self.num_inference_steps):
            ## add priori
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
            latents = noisy_latents
            with torch.no_grad():
                images = self.pipeline(
                    num_frames=nframes,
                    prompt=validation_prompt,
                    images=validation_images_input,
                    masks=validation_masks_input,
                    num_inference_steps=self.num_inference_steps,
                    generator=generator,
                    guidance_scale=guidance_scale_final,
                    negative_prompt=negative_prompt,
                    latents=latents,
                    callback_on_step_end=_pipeline_progress(0.60, 0.86, "DiffuEraser: frame inference"),
                    callback_on_step_end_tensor_inputs=[],
                ).frames
            images = images[:real_video_length]
            log_event("diffueraser.frame_inference", "generated frames", request_id=request_id, frames=len(images))
        gc.collect()
        torch.cuda.empty_cache()
        log_cuda_memory("diffueraser.after_frame_inference", request_id=request_id)
        _progress(0.88, "DiffuEraser: composing output")

        ################ Compose ################
        with timed_stage("diffueraser.compose_write", request_id=request_id, output=output_path, frames=real_video_length):
            binary_masks = validation_masks_input_ori
            if blended:
                # blur, you can adjust the parameters for better performance
                mask_blurreds = []
                for binary in binary_masks:
                    mask_arr = np.array(binary)
                    mask_blurred = cv2.GaussianBlur(mask_arr, (21, 21), 0)/255.
                    # Union of the original mask and its blurred version.
                    binary_mask = 1-(1-mask_arr/255.) * (1-mask_blurred)
                    mask_blurreds.append(Image.fromarray((binary_mask*255).astype(np.uint8)))
                binary_masks = mask_blurreds

            # Generated pixels inside the mask, original pixels outside it.
            comp_frames = []
            for image, blend_mask, original in zip(images, binary_masks, resized_frames_ori):
                mask = np.expand_dims(np.array(blend_mask),2).repeat(3, axis=2).astype(np.float32)/255.
                img = (np.array(image).astype(np.uint8) * mask
                       + np.array(original).astype(np.uint8) * (1 - mask)).astype(np.uint8)
                comp_frames.append(Image.fromarray(img))

            default_fps = fps
            output_frames = comp_frames
            writer_fps = default_fps
            if output_fps is not None:
                output_fps = float(output_fps)
                if output_fps > 0 and output_fps < default_fps:
                    # Drop frames evenly so the clip duration is preserved at the lower rate.
                    target_count = max(1, int(round(len(comp_frames) * output_fps / default_fps)))
                    frame_indices = np.linspace(0, len(comp_frames) - 1, target_count).round().astype(int)
                    output_frames = [comp_frames[i] for i in frame_indices]
                    writer_fps = output_fps
                    log_event(
                        "diffueraser.output_fps",
                        "downsampled output frames",
                        request_id=request_id,
                        input_fps=f"{default_fps:.2f}",
                        output_fps=f"{writer_fps:.2f}",
                        input_frames=len(comp_frames),
                        output_frames=len(output_frames),
                    )
                else:
                    log_event(
                        "diffueraser.output_fps",
                        "kept processed fps",
                        request_id=request_id,
                        input_fps=f"{default_fps:.2f}",
                        requested_output_fps=f"{output_fps:.2f}",
                    )

            writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
                                     writer_fps, output_frames[0].size)
            if not writer.isOpened():
                # Without this check a failed open silently produces no output file.
                writer.release()
                raise OSError(f"Could not open video writer for output: {output_path}")
            try:
                for frame in output_frames:
                    img = np.array(frame).astype(np.uint8)
                    # Frames are RGB (PIL); OpenCV writers expect BGR.
                    writer.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
            finally:
                writer.release()
            log_event(
                "diffueraser.compose_write",
                "wrote output",
                request_id=request_id,
                output=output_path,
                fps=f"{writer_fps:.2f}",
                source_fps=f"{default_fps:.2f}",
                frame_size=f"{output_frames[0].size[0]}x{output_frames[0].size[1]}",
                frames=len(output_frames),
                source_frames=len(comp_frames),
            )
        ################################
        _progress(1.0, "DiffuEraser: done")
        log_cuda_memory("diffueraser.forward.end", request_id=request_id)

        return output_path