|
|
""" |
|
|
Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py |
|
|
""" |
|
|
import base64
import gc
import importlib.metadata
import json
import os
import random
from datetime import datetime
from glob import glob

import cv2
import gradio as gr
import numpy as np
import pkg_resources
import requests
import torch
from diffusers import (
    CogVideoXDDIMScheduler,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    FlowMatchEulerDiscreteScheduler,
    PNDMScheduler,
)
from omegaconf import OmegaConf
from PIL import Image
from safetensors import safe_open

from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
from ..utils.utils import save_videos_grid
|
|
|
|
|
|
|
|
# Gradio 4.x changed several component APIs (e.g. `gr.Image.update` was
# removed), so downstream code branches on the installed major version.
# `importlib.metadata.version` replaces the deprecated `pkg_resources`
# API (pkg_resources emits a DeprecationWarning and is removed from
# setuptools >= 81).
gradio_version = importlib.metadata.version("gradio")
gradio_version_is_above_4 = int(gradio_version.split(".")[0]) >= 4
|
|
|
|
|
# Inline stylesheet passed to gr.Blocks(css=...). The original shipped the
# invalid declaration `margin-buttom: 0em 0em 0em 0em` (typo'd property name,
# and `margin-bottom` only accepts a single value anyway), which browsers
# silently dropped; fixed to valid CSS.
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
|
|
|
|
|
|
|
|
# Mapping from sampler names shown in the Gradio dropdown to diffusers
# scheduler classes, for DDPM-style (noise-prediction) pipelines.
# "DDIM" and "DDIM_Origin" deliberately map to the same class; presumably the
# pipeline treats the two names differently elsewhere — verify at the caller.
ddpm_scheduler_dict = {
    "Euler": EulerDiscreteScheduler,
    "Euler A": EulerAncestralDiscreteScheduler,
    "DPM++": DPMSolverMultistepScheduler,
    "PNDM": PNDMScheduler,
    "DDIM": DDIMScheduler,
    "DDIM_Origin": DDIMScheduler,
    "DDIM_Cog": CogVideoXDDIMScheduler,
}
# Flow-matching pipelines expose a single sampler choice.
flow_scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
}
# Combined lookup table covering both scheduler families.
all_scheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}

# Misspelled legacy alias ("cheduler") kept for backward compatibility with
# existing callers; prefer `all_scheduler_dict` in new code.
all_cheduler_dict = all_scheduler_dict
|
|
|
|
|
|
|
|
class Fun_Controller:
    """Base Gradio controller for the CogVideoX-Fun demo.

    Resolves model/asset directories relative to the current working
    directory, tracks the currently selected base model / LoRA paths, and
    provides input validation and output-saving helpers used by the UI
    callbacks. Subclasses are expected to implement
    ``update_diffusion_transformer`` and ``generate`` (both no-ops here).
    """

    def __init__(self, GPU_memory_mode, scheduler_dict, weight_dtype, config_path=None):
        """Set up the directory layout, model lists, and empty model slots.

        Args:
            GPU_memory_mode: memory-saving strategy identifier; stored for use
                by subclasses (its semantics are not visible in this class).
            scheduler_dict: mapping of sampler display names to diffusers
                scheduler classes (e.g. ``ddpm_scheduler_dict``).
            weight_dtype: torch dtype to load model weights in.
            config_path: optional OmegaConf YAML path; when given it is loaded
                into ``self.config``.
        """
        # All paths are resolved relative to the CWD, so the app must be
        # launched from the repository root for models to be found.
        self.basedir = os.getcwd()
        self.config_dir = os.path.join(self.basedir, "config")
        self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
        # Each launch gets its own timestamped output directory.
        self.savedir = os.path.join(
            self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")
        )
        self.savedir_sample = os.path.join(self.savedir, "sample")
        self.model_type = "Inpaint"  # "Inpaint" or "Control"; see input_check()
        os.makedirs(self.savedir, exist_ok=True)

        self.diffusion_transformer_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        # Populate the three lists above by scanning the model directories.
        self.refresh_diffusion_transformer()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # Model components are loaded lazily elsewhere; the "none" string is
        # the sentinel the dropdown callbacks use for "nothing selected".
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.transformer = None
        self.pipeline = None
        self.motion_module_path = "none"
        self.base_model_path = "none"
        self.lora_model_path = "none"
        self.GPU_memory_mode = GPU_memory_mode
        self.weight_dtype = weight_dtype
        self.scheduler_dict = scheduler_dict

        if config_path is not None:
            self.config = OmegaConf.load(config_path)

    def refresh_diffusion_transformer(self):
        """Re-scan the Diffusion_Transformer directory for model sub-folders."""
        self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))

    def refresh_motion_module(self):
        """Re-scan the Motion_Module directory for .safetensors checkpoints."""
        self.motion_module_list = [
            os.path.basename(p)
            for p in sorted(glob(os.path.join(self.motion_module_dir, "*.safetensors")))
        ]

    def refresh_personalized_model(self):
        """Re-scan the Personalized_Model directory for .safetensors checkpoints."""
        self.personalized_model_list = [
            os.path.basename(p)
            for p in sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
        ]

    def update_model_type(self, model_type):
        """Record the UI-selected model type ("Inpaint" or "Control")."""
        self.model_type = model_type

    def update_diffusion_transformer(self, diffusion_transformer_dropdown):
        """Hook for loading the selected transformer; implemented by subclasses."""
        pass

    def update_base_model(self, base_model_dropdown):
        """Merge a personalized base-model checkpoint into the transformer.

        Selecting "none" only records the choice; a real selection requires a
        loaded transformer. Returns ``gr.update()`` so the method can be wired
        directly as a dropdown callback.
        """
        self.base_model_path = base_model_dropdown
        if base_model_dropdown == "none":
            return gr.update()
        if self.transformer is None:
            gr.Info("Please select a pretrained model path.")
            return gr.update(value=None)
        path = os.path.join(self.personalized_model_dir, base_model_dropdown)
        state = {}
        # Stream tensors off disk one at a time (CPU) to keep peak memory low.
        with safe_open(path, framework="pt", device="cpu") as f:
            for k in f.keys():
                state[k] = f.get_tensor(k)
        # strict=False: the checkpoint may cover only a subset of module keys.
        self.transformer.load_state_dict(state, strict=False)
        return gr.update()

    def update_lora_model(self, lora_model_dropdown):
        """Record the selected LoRA path ("none" clears the selection)."""
        if lora_model_dropdown == "none":
            self.lora_model_path = "none"
            return gr.update()
        self.lora_model_path = os.path.join(self.personalized_model_dir, lora_model_dropdown)
        return gr.update()

    def clear_cache(self):
        """Release Python garbage and CUDA caching-allocator / IPC memory."""
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    def input_check(
        self,
        resize_method,
        generation_method,
        start_image,
        end_image,
        validation_video,
        control_video,
        is_api=False,
    ):
        """Validate the cross-dependencies between generation inputs.

        Each failed check either returns ``("", message)`` when ``is_api`` is
        True (so HTTP callers receive the error in-band) or raises
        ``gr.Error``. The ``(_ for _ in ()).throw(...)`` construct is a trick
        to raise an exception from inside a conditional expression.
        Returns None when every check passes.
        """
        if self.transformer is None:
            raise gr.Error("Please select a pretrained model path.")

        # Control videos only make sense for Control-type models, and vice versa.
        if control_video is not None and self.model_type == "Inpaint":
            msg = 'If specifying the control video, please set model_type == "Control".'
            return ("", msg) if is_api else (_ for _ in ()).throw(gr.Error(msg))

        if control_video is None and self.model_type == "Control":
            msg = 'If model_type == "Control", please specify a control video.'
            return ("", msg) if is_api else (_ for _ in ()).throw(gr.Error(msg))

        # Reference-based resizing needs at least one reference input to measure.
        if resize_method == "Resize according to Reference":
            if start_image is None and validation_video is None and control_video is None:
                msg = 'Please upload an image when using "Resize according to Reference".'
                return ("", msg) if is_api else (_ for _ in ()).throw(gr.Error(msg))

        # Equal in/latent channel counts appears to indicate a text-to-video
        # model (no extra conditioning channels), which cannot take a start
        # image — TODO confirm against the pipeline implementation.
        if (
            self.transformer.config.in_channels == self.vae.config.latent_channels
            and start_image is not None
        ):
            msg = "Please select an image-to-video pretrained model when using image-to-video."
            return ("", msg) if is_api else (_ for _ in ()).throw(gr.Error(msg))

        if (
            self.transformer.config.in_channels == self.vae.config.latent_channels
            and generation_method == "Long Video Generation"
        ):
            msg = "Please select an image-to-video pretrained model for long video generation."
            return ("", msg) if is_api else (_ for _ in ()).throw(gr.Error(msg))

        if start_image is None and end_image is not None:
            msg = "If specifying an ending image, please specify a starting image."
            return ("", msg) if is_api else (_ for _ in ()).throw(gr.Error(msg))

    def get_height_width_from_reference(
        self, base_resolution, start_image, validation_video, control_video
    ):
        """Choose an output (height, width) matching the reference's aspect ratio.

        The ASPECT_RATIO_512 bucket table is rescaled from its 512 base to
        ``base_resolution``, the closest bucket to the reference image/video is
        selected, and both sides are floored to multiples of 16.
        """
        aspect_ratio_sizes = {
            k: [x / 512 * base_resolution for x in v] for k, v in ASPECT_RATIO_512.items()
        }
        if self.model_type == "Inpaint":
            if validation_video:
                # NOTE(review): the VideoCapture is never released and the read
                # is unchecked — `frame` may be None for unreadable files.
                vid = cv2.VideoCapture(validation_video)
                _, frame = vid.read()
                w, h = Image.fromarray(frame).size
            else:
                # start_image may arrive as a list of paths; measure the first.
                img = start_image[0] if isinstance(start_image, list) else start_image
                w, h = Image.open(img).size
        else:
            vid = cv2.VideoCapture(control_video)
            _, frame = vid.read()
            w, h = Image.fromarray(frame).size

        (close_w, close_h), _ = get_closest_ratio(h, w, ratios=aspect_ratio_sizes)
        return (int(close_h // 16 * 16), int(close_w // 16 * 16))

    def save_outputs(self, is_image, length_slider, sample, fps):
        """Write the generated sample to disk and return the saved path.

        A single frame (image mode, or a one-frame video) is saved as PNG;
        otherwise the clip is written as MP4 via ``save_videos_grid``.
        """
        os.makedirs(self.savedir_sample, exist_ok=True)
        # Sequential numbering from the current file count; not safe against
        # concurrent writers, but fine for a single Gradio session.
        idx = len(os.listdir(self.savedir_sample)) + 1
        prefix = str(idx).zfill(3)

        if is_image or length_slider == 1:
            path = os.path.join(self.savedir_sample, f"{prefix}.png")
            # Assumes `sample` is a CPU tensor laid out (batch, channel, frame,
            # h, w) with values in [0, 1] — TODO confirm against the pipeline.
            img = sample[0, :, 0].transpose(0, 1).transpose(1, 2)
            img = (img * 255).numpy().astype(np.uint8)
            Image.fromarray(img).save(path)
        else:
            path = os.path.join(self.savedir_sample, f"{prefix}.mp4")
            save_videos_grid(sample, path, fps=fps)

        return path

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        is_api=False,
    ):
        """Run the full sampling pipeline; implemented by subclasses."""

        pass
|
|
|
|
|
|
|
|
def post_eas(
    diffusion_transformer_dropdown,
    base_model_dropdown,
    lora_model_dropdown,
    lora_alpha_slider,
    prompt_textbox,
    negative_prompt_textbox,
    sampler_dropdown,
    sample_step_slider,
    resize_method,
    width_slider,
    height_slider,
    base_resolution,
    generation_method,
    length_slider,
    cfg_scale_slider,
    start_image,
    end_image,
    validation_video,
    validation_video_mask,
    denoise_strength,
    seed_textbox,
):
    """Send one generation request to the remote EAS service and return the
    decoded JSON reply.

    File-path arguments (start/end image, validation video and its mask) are
    inlined as base64 text. The service root comes from the ``EAS_URL``
    environment variable (falling back to ``http://127.0.0.1:$PORT``), and an
    optional ``EAS_TOKEN`` is forwarded as the Authorization header.
    """

    def _b64_file(file_path):
        # Inline the file's bytes as UTF-8-safe base64 text for the JSON body.
        with open(file_path, "rb") as fp:
            return base64.b64encode(fp.read()).decode("utf-8")

    # Only encode the optional media arguments that were actually provided.
    if start_image:
        start_image = _b64_file(start_image)
    if end_image:
        end_image = _b64_file(end_image)
    if validation_video:
        validation_video = _b64_file(validation_video)
    if validation_video_mask:
        validation_video_mask = _b64_file(validation_video_mask)

    # NOTE(review): `diffusion_transformer_dropdown` is absent from the
    # payload — presumably the deployed service pins its own base model;
    # verify server-side.
    datas = {
        "base_model_path": base_model_dropdown,
        "lora_model_path": lora_model_dropdown,
        "lora_alpha_slider": lora_alpha_slider,
        "prompt_textbox": prompt_textbox,
        "negative_prompt_textbox": negative_prompt_textbox,
        "sampler_dropdown": sampler_dropdown,
        "sample_step_slider": sample_step_slider,
        "resize_method": resize_method,
        "width_slider": width_slider,
        "height_slider": height_slider,
        "base_resolution": base_resolution,
        "generation_method": generation_method,
        "length_slider": length_slider,
        "cfg_scale_slider": cfg_scale_slider,
        "start_image": start_image,
        "end_image": end_image,
        "validation_video": validation_video,
        "validation_video_mask": validation_video_mask,
        "denoise_strength": denoise_strength,
        "seed_textbox": seed_textbox,
    }

    http = requests.Session()
    auth_token = os.environ.get("EAS_TOKEN")
    if auth_token:
        http.headers.update({"Authorization": auth_token})

    eas_url = os.environ.get("EAS_URL")
    if eas_url:
        service_root = eas_url.rstrip("/")
    else:
        service_root = f"http://127.0.0.1:{os.environ.get('PORT', '7860')}"

    response = http.post(
        url=f"{service_root}/cogvideox_fun/infer_forward", json=datas, timeout=300
    )
    return response.json()
|
|
|
|
|
|
|
|
class Fun_Controller_EAS:
    """Client-side controller that delegates sampling to a remote EAS service
    via ``post_eas`` and persists the decoded result under ``savedir_sample``."""

    def __init__(self, model_name, scheduler_dict, savedir_sample):
        # `model_name` is accepted for interface parity but not used locally.
        self.scheduler_dict = scheduler_dict
        self.savedir_sample = savedir_sample
        os.makedirs(self.savedir_sample, exist_ok=True)

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        denoise_strength,
        seed_textbox,
    ):
        """Request a sample from the remote service and save it locally.

        Returns a ``(gr.Image, gr.Video, message)`` triple suitable for wiring
        directly into the Gradio output components.
        """
        is_image = generation_method == "Image Generation"

        outputs = post_eas(
            diffusion_transformer_dropdown,
            base_model_dropdown,
            lora_model_dropdown,
            lora_alpha_slider,
            prompt_textbox,
            negative_prompt_textbox,
            sampler_dropdown,
            sample_step_slider,
            resize_method,
            width_slider,
            height_slider,
            base_resolution,
            generation_method,
            length_slider,
            cfg_scale_slider,
            start_image,
            end_image,
            validation_video,
            validation_video_mask,
            denoise_strength,
            seed_textbox,
        )

        if "base64_encoding" not in outputs:
            # Remote failure: hide the image slot, keep the video slot visible,
            # and surface whatever message the server returned.
            return (
                gr.Image(visible=False, value=None),
                gr.Video(visible=True, value=None),
                outputs.get("message", "Unknown error"),
            )

        payload = base64.b64decode(outputs["base64_encoding"])
        # Sequential numbering based on the current file count (same scheme as
        # Fun_Controller.save_outputs).
        file_index = len(os.listdir(self.savedir_sample)) + 1
        stem = str(file_index).zfill(3)

        single_frame = is_image or length_slider == 1
        suffix = ".png" if single_frame else ".mp4"
        out_path = os.path.join(self.savedir_sample, stem + suffix)
        with open(out_path, "wb") as handle:
            handle.write(payload)

        if single_frame:
            image_kwargs = {"value": out_path, "visible": True}
            video_kwargs = {"value": None, "visible": False}
        else:
            image_kwargs = {"value": None, "visible": False}
            video_kwargs = {"value": out_path, "visible": True}

        # Gradio 4 removed the `.update` classmethods; older versions need them.
        if gradio_version_is_above_4:
            return gr.Image(**image_kwargs), gr.Video(**video_kwargs), "Success"
        return (
            gr.Image.update(**image_kwargs),
            gr.Video.update(**video_kwargs),
            "Success",
        )
|
|
|