# Source file: TTV / cogvideox / ui / controller.py
# Upload metadata: LTTEAM — "Update cogvideox/ui/controller.py" — commit 89906f2 (verified)
"""
Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
"""
import base64
import gc
import json
import os
import random
from datetime import datetime
from glob import glob
from omegaconf import OmegaConf
import cv2
import gradio as gr
import numpy as np
import pkg_resources
import requests
import torch
from diffusers import (
CogVideoXDDIMScheduler,
FlowMatchEulerDiscreteScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
PNDMScheduler,
)
from PIL import Image
from safetensors import safe_open
from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
from ..utils.utils import save_videos_grid
# version check
gradio_version = pkg_resources.get_distribution("gradio").version
gradio_version_is_above_4 = int(gradio_version.split(".")[0]) >= 4
# Inline CSS for the Gradio app: keeps the toolbar buttons compact.
# Fix: the original property was misspelled ("margin-buttom"), so the rule was
# silently ignored; the four-value form indicates the `margin` shorthand was intended.
css = """
.toolbutton {
    margin: 0em 0em 0em 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
# Scheduler registries: UI sampler name -> diffusers scheduler class.
# DDPM-style (noise-prediction) samplers.
ddpm_scheduler_dict = dict(
    [
        ("Euler", EulerDiscreteScheduler),
        ("Euler A", EulerAncestralDiscreteScheduler),
        ("DPM++", DPMSolverMultistepScheduler),
        ("PNDM", PNDMScheduler),
        ("DDIM", DDIMScheduler),
        ("DDIM_Origin", DDIMScheduler),
        ("DDIM_Cog", CogVideoXDDIMScheduler),
    ]
)
# Flow-matching samplers.
flow_scheduler_dict = dict([("Flow", FlowMatchEulerDiscreteScheduler)])
# Combined registry covering both sampler families.
all_scheduler_dict = dict(ddpm_scheduler_dict)
all_scheduler_dict.update(flow_scheduler_dict)
# alias for backward compatibility (historical misspelling kept so old imports keep working)
all_cheduler_dict = all_scheduler_dict
class Fun_Controller:
    """Local-inference controller for the CogVideoX-Fun Gradio demo.

    Tracks the currently selected transformer / base model / LoRA paths,
    validates generation requests, resolves the output resolution from a
    reference image or video, and saves generated samples under ``samples/``.

    Args:
        GPU_memory_mode: memory-management strategy label, stored for later use.
        scheduler_dict: mapping of sampler name -> scheduler class.
        weight_dtype: torch dtype used when loading weights.
        config_path: optional OmegaConf YAML path; loaded into ``self.config``.
    """

    def __init__(self, GPU_memory_mode, scheduler_dict, weight_dtype, config_path=None):
        # Directory layout is rooted at the current working directory.
        self.basedir = os.getcwd()
        self.config_dir = os.path.join(self.basedir, "config")
        self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
        # One timestamped output directory per app launch.
        self.savedir = os.path.join(
            self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")
        )
        self.savedir_sample = os.path.join(self.savedir, "sample")
        self.model_type = "Inpaint"
        os.makedirs(self.savedir, exist_ok=True)

        # Populate the dropdown choices from disk.
        self.diffusion_transformer_list = []
        self.motion_module_list = []
        self.personalized_model_list = []
        self.refresh_diffusion_transformer()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # Model components are loaded lazily elsewhere (see update_* hooks).
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.transformer = None
        self.pipeline = None

        self.motion_module_path = "none"
        self.base_model_path = "none"
        self.lora_model_path = "none"
        self.GPU_memory_mode = GPU_memory_mode

        self.weight_dtype = weight_dtype
        self.scheduler_dict = scheduler_dict
        if config_path is not None:
            self.config = OmegaConf.load(config_path)

    def refresh_diffusion_transformer(self):
        """Rescan the transformer directory for available model folders."""
        self.diffusion_transformer_list = sorted(
            glob(os.path.join(self.diffusion_transformer_dir, "*/"))
        )

    def refresh_motion_module(self):
        """Rescan the motion-module directory for ``*.safetensors`` files."""
        self.motion_module_list = [
            os.path.basename(p)
            for p in sorted(glob(os.path.join(self.motion_module_dir, "*.safetensors")))
        ]

    def refresh_personalized_model(self):
        """Rescan the personalized-model directory for ``*.safetensors`` files."""
        self.personalized_model_list = [
            os.path.basename(p)
            for p in sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
        ]

    def update_model_type(self, model_type):
        """Record the UI model type ("Inpaint" or "Control")."""
        self.model_type = model_type

    def update_diffusion_transformer(self, diffusion_transformer_dropdown):
        """Hook for loading the selected transformer; implemented by subclasses."""
        pass

    def update_base_model(self, base_model_dropdown):
        """Load a personalized base checkpoint into the current transformer.

        Returns a gr.update() so the dropdown keeps its value; clears it when
        no pipeline has been loaded yet.
        """
        self.base_model_path = base_model_dropdown
        if base_model_dropdown == "none":
            return gr.update()
        if self.transformer is None:
            # A pretrained pipeline must be loaded before a base model can be applied.
            gr.Info("Please select a pretrained model path.")
            return gr.update(value=None)
        path = os.path.join(self.personalized_model_dir, base_model_dropdown)
        with safe_open(path, framework="pt", device="cpu") as f:
            state = {key: f.get_tensor(key) for key in f.keys()}
        # strict=False: personalized checkpoints may cover only a subset of the weights.
        self.transformer.load_state_dict(state, strict=False)
        return gr.update()

    def update_lora_model(self, lora_model_dropdown):
        """Record the selected LoRA path ("none" disables LoRA)."""
        if lora_model_dropdown == "none":
            self.lora_model_path = "none"
            return gr.update()
        self.lora_model_path = os.path.join(self.personalized_model_dir, lora_model_dropdown)
        return gr.update()

    def clear_cache(self):
        """Free Python and (when present) CUDA memory between generations."""
        gc.collect()
        # Guard the CUDA calls so CPU-only hosts do not error out.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    def input_check(
        self,
        resize_method,
        generation_method,
        start_image,
        end_image,
        validation_video,
        control_video,
        is_api=False,
    ):
        """Validate a generation request.

        Returns ``("", message)`` for API callers on failure, raises
        ``gr.Error`` for the UI, and returns ``None`` when the request is valid.
        """

        def fail(msg):
            # API callers get a (result, message) tuple; the UI gets an exception.
            # (Replaces the original `(_ for _ in ()).throw(gr.Error(msg))` hack.)
            if is_api:
                return "", msg
            raise gr.Error(msg)

        if self.transformer is None:
            raise gr.Error("Please select a pretrained model path.")
        if control_video is not None and self.model_type == "Inpaint":
            return fail('If specifying the control video, please set model_type == "Control".')
        if control_video is None and self.model_type == "Control":
            return fail('If model_type == "Control", please specify a control video.')
        if resize_method == "Resize according to Reference":
            if start_image is None and validation_video is None and control_video is None:
                return fail('Please upload an image when using "Resize according to Reference".')
        # Equal channel counts (no extra mask/reference channels) indicate a
        # text-to-video checkpoint, per the error messages below.
        no_i2v_channels = self.transformer.config.in_channels == self.vae.config.latent_channels
        if no_i2v_channels and start_image is not None:
            return fail("Please select an image-to-video pretrained model when using image-to-video.")
        if no_i2v_channels and generation_method == "Long Video Generation":
            return fail("Please select an image-to-video pretrained model for long video generation.")
        if start_image is None and end_image is not None:
            return fail("If specifying an ending image, please specify a starting image.")

    def _first_frame_size(self, video_path):
        """Return (width, height) of a video's first frame, releasing the capture."""
        capture = cv2.VideoCapture(video_path)
        try:
            _, frame = capture.read()
        finally:
            # Fix: the original never released the capture handle.
            capture.release()
        return Image.fromarray(frame).size

    def get_height_width_from_reference(
        self, base_resolution, start_image, validation_video, control_video
    ):
        """Pick the bucket (height, width) closest to the reference's aspect ratio.

        The ASPECT_RATIO_512 bucket sizes are rescaled to ``base_resolution``
        and the chosen size is floored to a multiple of 16.
        """
        aspect_ratio_sizes = {
            key: [size / 512 * base_resolution for size in sizes]
            for key, sizes in ASPECT_RATIO_512.items()
        }
        if self.model_type == "Inpaint":
            if validation_video:
                width, height = self._first_frame_size(validation_video)
            else:
                image = start_image[0] if isinstance(start_image, list) else start_image
                width, height = Image.open(image).size
        else:
            width, height = self._first_frame_size(control_video)
        (closest_w, closest_h), _ = get_closest_ratio(height, width, ratios=aspect_ratio_sizes)
        return (int(closest_h // 16 * 16), int(closest_w // 16 * 16))

    def save_outputs(self, is_image, length_slider, sample, fps):
        """Write the generated sample to disk and return its path.

        Single-frame results are saved as PNG, everything else as MP4.
        Output names are sequential zero-padded indices (001.png, 002.mp4, ...).
        """
        os.makedirs(self.savedir_sample, exist_ok=True)
        index = len(os.listdir(self.savedir_sample)) + 1
        prefix = str(index).zfill(3)
        if is_image or length_slider == 1:
            save_path = os.path.join(self.savedir_sample, f"{prefix}.png")
            # sample appears to be (batch, channel, frame, h, w) — TODO confirm;
            # take frame 0 and move channels last for PIL.
            image = sample[0, :, 0].transpose(0, 1).transpose(1, 2)
            image = (image * 255).numpy().astype(np.uint8)
            Image.fromarray(image).save(save_path)
        else:
            save_path = os.path.join(self.savedir_sample, f"{prefix}.mp4")
            save_videos_grid(sample, save_path, fps=fps)
        return save_path

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        is_api=False,
    ):
        """Run local generation (implementation omitted in this file)."""
        # local generation logic (omitted)
        pass
def post_eas(
    diffusion_transformer_dropdown,
    base_model_dropdown,
    lora_model_dropdown,
    lora_alpha_slider,
    prompt_textbox,
    negative_prompt_textbox,
    sampler_dropdown,
    sample_step_slider,
    resize_method,
    width_slider,
    height_slider,
    base_resolution,
    generation_method,
    length_slider,
    cfg_scale_slider,
    start_image,
    end_image,
    validation_video,
    validation_video_mask,
    denoise_strength,
    seed_textbox,
):
    """Serialize the generation request and POST it to the EAS inference endpoint.

    File-path arguments (images/videos) are base64-encoded inline. The target
    URL comes from the EAS_URL env var (falling back to localhost:PORT) and an
    optional EAS_TOKEN env var supplies the Authorization header. Returns the
    decoded JSON response. Note: diffusion_transformer_dropdown is accepted but
    not forwarded in the payload.
    """

    def _to_b64(file_path):
        # Read a local file and return its base64 text form for the JSON payload.
        with open(file_path, "rb") as handle:
            return base64.b64encode(handle.read()).decode("utf-8")

    if start_image:
        start_image = _to_b64(start_image)
    if end_image:
        end_image = _to_b64(end_image)
    if validation_video:
        validation_video = _to_b64(validation_video)
    if validation_video_mask:
        validation_video_mask = _to_b64(validation_video_mask)

    payload = {
        "base_model_path": base_model_dropdown,
        "lora_model_path": lora_model_dropdown,
        "lora_alpha_slider": lora_alpha_slider,
        "prompt_textbox": prompt_textbox,
        "negative_prompt_textbox": negative_prompt_textbox,
        "sampler_dropdown": sampler_dropdown,
        "sample_step_slider": sample_step_slider,
        "resize_method": resize_method,
        "width_slider": width_slider,
        "height_slider": height_slider,
        "base_resolution": base_resolution,
        "generation_method": generation_method,
        "length_slider": length_slider,
        "cfg_scale_slider": cfg_scale_slider,
        "start_image": start_image,
        "end_image": end_image,
        "validation_video": validation_video,
        "validation_video_mask": validation_video_mask,
        "denoise_strength": denoise_strength,
        "seed_textbox": seed_textbox,
    }

    session = requests.Session()
    token = os.environ.get("EAS_TOKEN")
    if token:
        session.headers.update({"Authorization": token})

    configured_url = os.environ.get("EAS_URL")
    if configured_url:
        base_url = configured_url.rstrip("/")
    else:
        # Local fallback when no EAS deployment is configured.
        base_url = "http://127.0.0.1:" + os.environ.get("PORT", "7860")

    response = session.post(
        url=base_url + "/cogvideox_fun/infer_forward", json=payload, timeout=300
    )
    return response.json()
class Fun_Controller_EAS:
    """Controller that forwards generation requests to a remote EAS deployment.

    Args:
        model_name: accepted for signature parity with the local controller;
            not used by the remote path.
        scheduler_dict: mapping of sampler name -> scheduler class.
        savedir_sample: directory where returned media is written.
    """

    def __init__(self, model_name, scheduler_dict, savedir_sample):
        self.scheduler_dict = scheduler_dict
        self.savedir_sample = savedir_sample
        os.makedirs(self.savedir_sample, exist_ok=True)

    def _outputs(self, image_value, image_visible, video_value, video_visible, message):
        """Build (image, video, message) UI outputs for both gradio 3.x and 4.x.

        Fix: the original error branch always constructed gradio-4-style
        components, which misbehaves on gradio 3.x where updates must go
        through gr.Image.update / gr.Video.update.
        """
        if gradio_version_is_above_4:
            return (
                gr.Image(value=image_value, visible=image_visible),
                gr.Video(value=video_value, visible=video_visible),
                message,
            )
        return (
            gr.Image.update(value=image_value, visible=image_visible),
            gr.Video.update(value=video_value, visible=video_visible),
            message,
        )

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        denoise_strength,
        seed_textbox,
    ):
        """Call the EAS service, persist the returned media, and return UI updates."""
        is_image = generation_method == "Image Generation"
        outputs = post_eas(
            diffusion_transformer_dropdown,
            base_model_dropdown,
            lora_model_dropdown,
            lora_alpha_slider,
            prompt_textbox,
            negative_prompt_textbox,
            sampler_dropdown,
            sample_step_slider,
            resize_method,
            width_slider,
            height_slider,
            base_resolution,
            generation_method,
            length_slider,
            cfg_scale_slider,
            start_image,
            end_image,
            validation_video,
            validation_video_mask,
            denoise_strength,
            seed_textbox,
        )
        if "base64_encoding" not in outputs:
            # Remote failure: clear both components and surface the error message.
            return self._outputs(None, False, None, True, outputs.get("message", "Unknown error"))

        data = base64.b64decode(outputs["base64_encoding"])
        # Sequential zero-padded output names (001.png, 002.mp4, ...).
        index = len(os.listdir(self.savedir_sample)) + 1
        prefix = str(index).zfill(3)
        if is_image or length_slider == 1:
            save_path = os.path.join(self.savedir_sample, f"{prefix}.png")
            with open(save_path, "wb") as file:
                file.write(data)
            return self._outputs(save_path, True, None, False, "Success")
        save_path = os.path.join(self.savedir_sample, f"{prefix}.mp4")
        with open(save_path, "wb") as file:
            file.write(data)
        return self._outputs(None, False, save_path, True, "Success")