# diffueraser/diffueraser.py
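"""DiffuEraser: video inpainting built from a Stable Diffusion UNet, a BrushNet
branch for mask conditioning, a motion module for temporal consistency, and
distilled PCM/LCM LoRA weights for few-step sampling."""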
import gc
import copy
import cv2
import os
import numpy as np
import torch
import torchvision
from einops import repeat
from PIL import Image, ImageFilter
from diffusers import (
AutoencoderKL,
DDPMScheduler,
UniPCMultistepScheduler,
LCMScheduler,
)
from diffusers.schedulers import TCDScheduler
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from transformers import AutoTokenizer, PretrainedConfig
from libs.unet_motion_model import MotionAdapter, UNetMotionModel
from libs.brushnet_CA import BrushNetModel
from libs.unet_2d_condition import UNet2DConditionModel
from diffueraser.pipeline_diffueraser import StableDiffusionDiffuEraserPipeline
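
# PCM/LCM checkpoint table: display name -> [LoRA filename template (the "{}"
# is filled with the base-model tag, e.g. "sd15"), inference steps, guidance scale].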
checkpoints = {
"2-Step": ["pcm_{}_smallcfg_2step_converted.safetensors", 2, 0.0],
"4-Step": ["pcm_{}_smallcfg_4step_converted.safetensors", 4, 0.0],
"8-Step": ["pcm_{}_smallcfg_8step_converted.safetensors", 8, 0.0],
"16-Step": ["pcm_{}_smallcfg_16step_converted.safetensors", 16, 0.0],
"Normal CFG 4-Step": ["pcm_{}_normalcfg_4step_converted.safetensors", 4, 7.5],
"Normal CFG 8-Step": ["pcm_{}_normalcfg_8step_converted.safetensors", 8, 7.5],
"Normal CFG 16-Step": ["pcm_{}_normalcfg_16step_converted.safetensors", 16, 7.5],
"LCM-Like LoRA": [
"pcm_{}_lcmlike_lora_converted.safetensors",
4,
0.0,
],
}
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
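    """Resolve the text-encoder class named in the checkpoint's config.

    Reads text_encoder/config.json from the model and returns the matching
    class (CLIPTextModel, or AltDiffusion's RoBERTa-based encoder).
    """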
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
else:
raise ValueError(f"{model_class} is not supported.")
def resize_frames(frames, size=None):
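    """Resize PIL frames so width and height are multiples of 8 (the VAE stride).

    If `size` is given, frames are resized to it after rounding each side down
    to a multiple of 8; otherwise frames keep their own size, rounded down only
    when needed.
    """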
    if size is not None:
        out_size = size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        frames = [f.resize(process_size) for f in frames]
    else:
        out_size = frames[0].size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        if out_size != process_size:
            frames = [f.resize(process_size) for f in frames]
    return frames
def read_mask(validation_mask, fps, n_total_frames, img_size, mask_dilation_iter, frames):
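    """Decode the mask video into binary PIL masks plus pre-masked frames.

    Each mask frame is binarized (any non-zero pixel counts as mask), eroded
    once, then dilated `mask_dilation_iter` times. Returns the masks and the
    input `frames` with the masked region zeroed out. Raises if the mask fps
    differs from the input video fps.
    """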
cap = cv2.VideoCapture(validation_mask)
    if not cap.isOpened():
        raise ValueError("Could not open mask video.")
mask_fps = cap.get(cv2.CAP_PROP_FPS)
if mask_fps != fps:
cap.release()
raise ValueError("The frame rate of all input videos needs to be consistent.")
masks = []
masked_images = []
idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
        if idx >= n_total_frames:
            break
        mask = Image.fromarray(frame[..., ::-1]).convert('L')  # BGR -> RGB, then grayscale
        if mask.size != img_size:
            mask = mask.resize(img_size, Image.NEAREST)
        mask = np.asarray(mask)
        m = np.array(mask > 0).astype(np.uint8)
        # erode once to drop speckle, then dilate to grow the hole region
        m = cv2.erode(m,
                      cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
                      iterations=1)
        m = cv2.dilate(m,
                       cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
                       iterations=mask_dilation_iter)
        mask = Image.fromarray(m * 255)
masks.append(mask)
        # zero out the masked region of the input frame
        masked_image = np.array(frames[idx]) * (1 - np.array(mask)[:, :, np.newaxis].astype(np.float32) / 255)
        masked_image = Image.fromarray(masked_image.astype(np.uint8))
masked_images.append(masked_image)
idx += 1
cap.release()
return masks, masked_images
def read_priori(priori, fps, n_total_frames, img_size):
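    """Decode the priori (coarse inpainting) video into PIL frames.

    Frames are resized to `img_size` when needed and the priori file is
    deleted after reading. Raises if its fps differs from the input video's.
    """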
cap = cv2.VideoCapture(priori)
    if not cap.isOpened():
        raise ValueError("Could not open priori video.")
priori_fps = cap.get(cv2.CAP_PROP_FPS)
if priori_fps != fps:
cap.release()
raise ValueError("The frame rate of all input videos needs to be consistent.")
prioris=[]
idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
        if idx >= n_total_frames:
            break
        img = Image.fromarray(frame[..., ::-1])  # BGR -> RGB
        if img.size != img_size:
            img = img.resize(img_size)
prioris.append(img)
idx += 1
cap.release()
    os.remove(priori)  # the priori file is consumed once its frames are loaded
return prioris
def read_video(validation_image, video_length, nframes, max_img_size):
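    """Read up to `video_length` seconds of the input video as PIL frames.

    Rejects videos whose longer side falls outside [256, 4096], downscales so
    the longer side fits `max_img_size`, and snaps both sides to multiples
    of 8. Returns (frames, fps, img_size, n_clip, n_total_frames), where
    n_clip is the number of `nframes`-sized windows covering the video.
    """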
    vframes, aframes, info = torchvision.io.read_video(filename=validation_image, pts_unit='sec', end_pts=video_length)  # RGB, (T, H, W, C) uint8
    fps = info['video_fps']
    n_total_frames = int(video_length * fps)
    n_clip = int(np.ceil(n_total_frames / nframes))
    frames = list(vframes.numpy())[:n_total_frames]
    frames = [Image.fromarray(f) for f in frames]
max_size = max(frames[0].size)
    if max_size < 256:
        raise ValueError("The longer side of the uploaded video must be at least 256 pixels.")
    if max_size > 4096:
        raise ValueError("The longer side of the uploaded video must be at most 4096 pixels.")
    if max_size > max_img_size:
        # downscale so the longer side fits max_img_size, then snap to a multiple of 8
        ratio = max_size / max_img_size
        ratio_size = (int(frames[0].size[0] / ratio), int(frames[0].size[1] / ratio))
        img_size = (ratio_size[0] - ratio_size[0] % 8, ratio_size[1] - ratio_size[1] % 8)
        resize_flag = True
    elif (frames[0].size[0] % 8 == 0) and (frames[0].size[1] % 8 == 0):
        img_size = frames[0].size
        resize_flag = False
    else:
        # dimensions are not multiples of 8: snap down without scaling
        ratio_size = frames[0].size
        img_size = (ratio_size[0] - ratio_size[0] % 8, ratio_size[1] - ratio_size[1] % 8)
        resize_flag = True
if resize_flag:
frames = resize_frames(frames, img_size)
img_size = frames[0].size
return frames, fps, img_size, n_clip, n_total_frames
class DiffuEraser:
def __init__(
self, device, base_model_path, vae_path, diffueraser_path, revision=None,
ckpt="Normal CFG 4-Step", mode="sd15", loaded=None):
self.device = device
## load model
self.vae = AutoencoderKL.from_pretrained(vae_path)
self.noise_scheduler = DDPMScheduler.from_pretrained(base_model_path,
subfolder="scheduler",
prediction_type="v_prediction",
timestep_spacing="trailing",
rescale_betas_zero_snr=True
)
self.tokenizer = AutoTokenizer.from_pretrained(
base_model_path,
subfolder="tokenizer",
use_fast=False,
)
        text_encoder_cls = import_model_class_from_model_name_or_path(base_model_path, revision)
self.text_encoder = text_encoder_cls.from_pretrained(
base_model_path, subfolder="text_encoder"
)
self.brushnet = BrushNetModel.from_pretrained(diffueraser_path, subfolder="brushnet")
self.unet_main = UNetMotionModel.from_pretrained(
diffueraser_path, subfolder="unet_main",
)
## set pipeline
self.pipeline = StableDiffusionDiffuEraserPipeline.from_pretrained(
base_model_path,
vae=self.vae,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
unet=self.unet_main,
brushnet=self.brushnet,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
).to(self.device, torch.float16)
        # default to UniPC; replaced just below by the scheduler matching the PCM/LCM LoRA
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        self.pipeline.set_progress_bar_config(disable=True)
        self.noise_scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
## use PCM
self.ckpt = ckpt
PCM_ckpts = checkpoints[ckpt][0].format(mode)
self.guidance_scale = checkpoints[ckpt][2]
        if loaded != (ckpt + mode):
            self.pipeline.load_lora_weights(
                "weights/PCM_Weights", weight_name=PCM_ckpts, subfolder=mode
            )
            loaded = ckpt + mode  # rebinds the local name only; callers must track reuse themselves

            # pick the scheduler matching the selected distilled LoRA
            if ckpt == "LCM-Like LoRA":
                self.pipeline.scheduler = LCMScheduler()
            else:
                self.pipeline.scheduler = TCDScheduler(
                    num_train_timesteps=1000,
                    beta_start=0.00085,
                    beta_end=0.012,
                    beta_schedule="scaled_linear",
                    timestep_spacing="trailing",
                )
        self.num_inference_steps = checkpoints[ckpt][1]
        self.guidance_scale = 0  # overrides the table value above; sampling runs without CFG by default
    def forward(self, validation_image, validation_mask, priori, output_path,
                max_img_size=1280, video_length=2, mask_dilation_iter=4,
                nframes=22, seed=None, revision=None, guidance_scale=None, blended=True):
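        """Inpaint `validation_image` inside `validation_mask`, guided by the
        `priori` video, and write the blended result to `output_path`.

        When the video spans more than two `nframes` windows, a pre-inference
        pass first denoises a set of uniformly sampled keyframes, whose results
        are fed back as known frames before the full pass. Generated pixels are
        then composited into the original frames, optionally with feathered masks.
        """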
validation_prompt = "" #
guidance_scale_final = self.guidance_scale if guidance_scale==None else guidance_scale
if (max_img_size<256 or max_img_size>1920):
raise ValueError("The max_img_size must be larger than 256, smaller than 1920.")
################ read input video ################
frames, fps, img_size, n_clip, n_total_frames = read_video(validation_image, video_length, nframes, max_img_size)
video_len = len(frames)
################ read mask ################
validation_masks_input, validation_images_input = read_mask(validation_mask, fps, video_len, img_size, mask_dilation_iter, frames)
################ read priori ################
prioris = read_priori(priori, fps, n_total_frames, img_size)
## recheck
        n_total_frames = min(len(frames), len(validation_masks_input), len(prioris))
        if n_total_frames < 22:
            raise ValueError("The effective video is too short: video, mask, and priori must each contain at least 22 frames.")
validation_masks_input = validation_masks_input[:n_total_frames]
validation_images_input = validation_images_input[:n_total_frames]
frames = frames[:n_total_frames]
prioris = prioris[:n_total_frames]
prioris = resize_frames(prioris)
validation_masks_input = resize_frames(validation_masks_input)
validation_images_input = resize_frames(validation_images_input)
resized_frames = resize_frames(frames)
##############################################
# DiffuEraser inference
##############################################
print("DiffuEraser inference...")
if seed is None:
generator = None
else:
generator = torch.Generator(device=self.device).manual_seed(seed)
## random noise
real_video_length = len(validation_images_input)
tar_width, tar_height = validation_images_input[0].size
        shape = (
            nframes,
            4,                 # latent channels
            tar_height // 8,   # VAE spatial downsample factor
            tar_width // 8
        )
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet_main is not None:
prompt_embeds_dtype = self.unet_main.dtype
else:
prompt_embeds_dtype = torch.float16
        noise_pre = randn_tensor(shape, device=torch.device(self.device), dtype=prompt_embeds_dtype, generator=generator)
        # reuse the same per-window noise for every clip, then trim to the real video length
        noise = repeat(noise_pre, "t c h w->(repeat t) c h w", repeat=n_clip)[:real_video_length, ...]
################ prepare priori ################
images_preprocessed = []
for image in prioris:
image = self.image_processor.preprocess(image, height=tar_height, width=tar_width).to(dtype=torch.float32)
image = image.to(device=torch.device(self.device), dtype=torch.float16)
images_preprocessed.append(image)
pixel_values = torch.cat(images_preprocessed)
with torch.no_grad():
pixel_values = pixel_values.to(dtype=torch.float16)
latents = []
            num = 4  # encode through the VAE in chunks of 4 frames to bound peak memory
            for i in range(0, pixel_values.shape[0], num):
                latents.append(self.vae.encode(pixel_values[i : i + num]).latent_dist.sample())
            latents = torch.cat(latents, dim=0)
            latents = latents * self.vae.config.scaling_factor  # [(b f), c, h, w], c=4
torch.cuda.empty_cache()
        timesteps = torch.tensor([0], device=self.device, dtype=torch.long)
validation_masks_input_ori = copy.deepcopy(validation_masks_input)
resized_frames_ori = copy.deepcopy(resized_frames)
        ################ Pre-inference ################
        if n_total_frames > nframes * 2:  ## run pre-inference only when the video spans more than two windows
            ## uniformly sample one window of keyframes across the whole video
            step = n_total_frames / nframes
            sample_index = [int(i * step) for i in range(nframes)]
            sample_index = sample_index[:22]  ## hard cap at the 22-frame window
validation_masks_input_pre = [validation_masks_input[i] for i in sample_index]
validation_images_input_pre = [validation_images_input[i] for i in sample_index]
latents_pre = torch.stack([latents[i] for i in sample_index])
            ## add noise to the priori latents
            noisy_latents_pre = self.noise_scheduler.add_noise(latents_pre, noise_pre, timesteps)
            latents_pre = noisy_latents_pre
with torch.no_grad():
latents_pre_out = self.pipeline(
num_frames=nframes,
prompt=validation_prompt,
images=validation_images_input_pre,
masks=validation_masks_input_pre,
num_inference_steps=self.num_inference_steps,
generator=generator,
guidance_scale=guidance_scale_final,
latents=latents_pre,
).latents
torch.cuda.empty_cache()
def decode_latents(latents, weight_dtype):
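            """Decode latents back to pixel space through the VAE, frame by frame."""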
latents = 1 / self.vae.config.scaling_factor * latents
video = []
for t in range(latents.shape[0]):
video.append(self.vae.decode(latents[t:t+1, ...].to(weight_dtype)).sample)
video = torch.concat(video, dim=0)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
with torch.no_grad():
video_tensor_temp = decode_latents(latents_pre_out, weight_dtype=torch.float16)
images_pre_out = self.image_processor.postprocess(video_tensor_temp, output_type="pil")
torch.cuda.empty_cache()
            ## replace the sampled input frames with the pre-inference results and mark them as known
            black_image = Image.new('L', validation_masks_input[0].size, color=0)  # empty mask
            for i, index in enumerate(sample_index):
                latents[index] = latents_pre_out[i]
                validation_masks_input[index] = black_image
                validation_images_input[index] = images_pre_out[i]
                resized_frames[index] = images_pre_out[i]
        else:
            latents_pre_out = None
            sample_index = None
gc.collect()
torch.cuda.empty_cache()
        ################ Frame-by-frame inference ################
        ## add noise to the priori latents
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        latents = noisy_latents
with torch.no_grad():
images = self.pipeline(
num_frames=nframes,
prompt=validation_prompt,
images=validation_images_input,
masks=validation_masks_input,
num_inference_steps=self.num_inference_steps,
generator=generator,
guidance_scale=guidance_scale_final,
latents=latents,
).frames
images = images[:real_video_length]
gc.collect()
torch.cuda.empty_cache()
        ################ Compose ################
        binary_masks = validation_masks_input_ori
        mask_blurreds = []
        if blended:
            # feather the mask edges; kernel size and sigma can be tuned for better blending
            for i in range(len(binary_masks)):
                mask_blurred = cv2.GaussianBlur(np.array(binary_masks[i]), (21, 21), 0) / 255.
                binary_mask = 1 - (1 - np.array(binary_masks[i]) / 255.) * (1 - mask_blurred)
                mask_blurreds.append(Image.fromarray((binary_mask * 255).astype(np.uint8)))
            binary_masks = mask_blurreds
        comp_frames = []
        for i in range(len(images)):
            # paste generated pixels inside the (feathered) mask; keep original pixels outside
            mask = np.expand_dims(np.array(binary_masks[i]), 2).repeat(3, axis=2).astype(np.float32) / 255.
            img = (np.array(images[i]).astype(np.uint8) * mask
                   + np.array(resized_frames_ori[i]).astype(np.uint8) * (1 - mask)).astype(np.uint8)
            comp_frames.append(Image.fromarray(img))
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
                                 fps, comp_frames[0].size)
        for f in range(real_video_length):
            img = np.array(comp_frames[f]).astype(np.uint8)
            writer.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))  # PIL frames are RGB; OpenCV expects BGR
        writer.release()
################################
return output_path
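
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): a minimal way to
# drive the class above. All weight directories and input/output paths below
# are placeholders/assumptions; substitute your own checkpoints and videos.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    eraser = DiffuEraser(
        device,
        base_model_path="weights/stable-diffusion-v1-5",  # assumed SD1.5 checkpoint dir
        vae_path="weights/sd-vae-ft-mse",                 # assumed fine-tuned VAE dir
        diffueraser_path="weights/diffuEraser",           # assumed DiffuEraser weights dir
    )
    eraser.forward(
        validation_image="examples/input_video.mp4",  # placeholder input video
        validation_mask="examples/input_mask.mp4",    # mask video, same fps as input
        priori="examples/priori.mp4",                 # coarse priori video (deleted after reading)
        output_path="results/output.mp4",
        video_length=2,   # seconds of video to process
        seed=42,
    )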