"""
Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
"""
import base64
import gc
import json
import os
import random
from datetime import datetime
from glob import glob

import cv2
import gradio as gr
import numpy as np
import pkg_resources
import requests
import torch
from diffusers import (
    CogVideoXDDIMScheduler,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    FlowMatchEulerDiscreteScheduler,
    PNDMScheduler,
)
from omegaconf import OmegaConf
from PIL import Image
from safetensors import safe_open

from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
from ..utils.utils import save_videos_grid

# Version check: Gradio 4 replaced ``Component.update(...)`` with returning a
# freshly constructed component from event handlers.
gradio_version = pkg_resources.get_distribution("gradio").version
gradio_version_is_above_4 = int(gradio_version.split(".")[0]) >= 4

# NOTE(review): ``margin-buttom`` is not a valid CSS property (browsers ignore
# it). It is kept unchanged so the rendered layout does not shift; confirm
# whether ``margin`` was intended.
css = """
.toolbutton {
    margin-buttom: 0em 0em 0em 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

# Scheduler dictionaries: UI label -> diffusers scheduler class.
ddpm_scheduler_dict = {
    "Euler": EulerDiscreteScheduler,
    "Euler A": EulerAncestralDiscreteScheduler,
    "DPM++": DPMSolverMultistepScheduler,
    "PNDM": PNDMScheduler,
    "DDIM": DDIMScheduler,
    "DDIM_Origin": DDIMScheduler,
    "DDIM_Cog": CogVideoXDDIMScheduler,
}
flow_scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
}
all_scheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
# Misspelled alias kept for backward compatibility with existing importers.
all_cheduler_dict = all_scheduler_dict


def _component_update(component_cls, **kwargs):
    """Return a Gradio output update for ``component_cls`` for the installed version.

    Gradio >= 4 expects a component instance; older versions expect the dict
    produced by ``component_cls.update(...)``.
    """
    if gradio_version_is_above_4:
        return component_cls(**kwargs)
    return component_cls.update(**kwargs)


def _first_frame_size(video_path):
    """Return ``(width, height)`` of the first frame of ``video_path``.

    The capture handle is always released. Raises ``gr.Error`` when no frame
    can be decoded instead of failing later on a ``None`` frame.
    """
    capture = cv2.VideoCapture(video_path)
    try:
        ok, frame = capture.read()
    finally:
        capture.release()
    if not ok or frame is None:
        raise gr.Error(f"Could not read a frame from the video: {video_path}")
    frame_height, frame_width = frame.shape[:2]
    return frame_width, frame_height


class Fun_Controller:
    """Holds model state and the callbacks used by the local Gradio UI."""

    def __init__(self, GPU_memory_mode, scheduler_dict, weight_dtype, config_path=None):
        """Set up directories, model lists and empty model slots.

        Args:
            GPU_memory_mode: stored as-is on the instance.
            scheduler_dict: mapping of sampler label to scheduler class.
            weight_dtype: stored as-is on the instance.
            config_path: optional path loaded with ``OmegaConf.load``; when
                omitted, ``self.config`` is ``None``.
        """
        # Directory layout, all relative to the current working directory.
        self.basedir = os.getcwd()
        self.config_dir = os.path.join(self.basedir, "config")
        self.diffusion_transformer_dir = os.path.join(
            self.basedir, "models", "Diffusion_Transformer"
        )
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(
            self.basedir, "models", "Personalized_Model"
        )
        self.savedir = os.path.join(
            self.basedir,
            "samples",
            datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"),
        )
        self.savedir_sample = os.path.join(self.savedir, "sample")

        self.model_type = "Inpaint"
        os.makedirs(self.savedir, exist_ok=True)

        self.diffusion_transformer_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        self.refresh_diffusion_transformer()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # Model components; nothing is loaded in this constructor.
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.transformer = None
        self.pipeline = None
        self.motion_module_path = "none"
        self.base_model_path = "none"
        self.lora_model_path = "none"
        self.GPU_memory_mode = GPU_memory_mode

        self.weight_dtype = weight_dtype
        self.scheduler_dict = scheduler_dict
        # Always define ``config`` so later reads cannot hit AttributeError.
        self.config = OmegaConf.load(config_path) if config_path is not None else None

    def refresh_diffusion_transformer(self):
        """Re-scan the diffusion transformer directory for sub-directories."""
        pattern = os.path.join(self.diffusion_transformer_dir, "*/")
        self.diffusion_transformer_list = sorted(glob(pattern))

    def refresh_motion_module(self):
        """Re-scan the motion module directory for ``*.safetensors`` file names."""
        pattern = os.path.join(self.motion_module_dir, "*.safetensors")
        self.motion_module_list = [os.path.basename(p) for p in sorted(glob(pattern))]

    def refresh_personalized_model(self):
        """Re-scan the personalized model directory for ``*.safetensors`` file names."""
        pattern = os.path.join(self.personalized_model_dir, "*.safetensors")
        self.personalized_model_list = [os.path.basename(p) for p in sorted(glob(pattern))]

    def update_model_type(self, model_type):
        """Record the selected model type (compared against "Inpaint"/"Control")."""
        self.model_type = model_type

    def update_diffusion_transformer(self, diffusion_transformer_dropdown):
        """No-op in this class."""
        pass

    def update_base_model(self, base_model_dropdown):
        """Load a personalized checkpoint into the current transformer.

        Returns a ``gr.update()``; when no transformer is loaded yet, an info
        toast is shown and the dropdown value is reset.
        """
        self.base_model_path = base_model_dropdown
        if base_model_dropdown == "none":
            return gr.update()
        if self.transformer is None:
            gr.Info("Please select a pretrained model path.")
            return gr.update(value=None)

        checkpoint_path = os.path.join(self.personalized_model_dir, base_model_dropdown)
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            state = {key: f.get_tensor(key) for key in f.keys()}
        # strict=False: keys that do not match the transformer are tolerated.
        self.transformer.load_state_dict(state, strict=False)
        return gr.update()

    def update_lora_model(self, lora_model_dropdown):
        """Record the selected LoRA path ("none" clears it); nothing is loaded here."""
        if lora_model_dropdown == "none":
            self.lora_model_path = "none"
        else:
            self.lora_model_path = os.path.join(
                self.personalized_model_dir, lora_model_dropdown
            )
        return gr.update()

    def clear_cache(self):
        """Run garbage collection and, when CUDA is usable, free cached GPU memory."""
        gc.collect()
        # Guarded so the call is safe on hosts without a usable CUDA device.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    def input_check(
        self,
        resize_method,
        generation_method,
        start_image,
        end_image,
        validation_video,
        control_video,
        is_api=False,
    ):
        """Validate the combination of UI/API inputs before generation.

        Returns:
            ``None`` when every check passes. On a failed check with
            ``is_api=True``, the tuple ``("", message)``.

        Raises:
            gr.Error: on a failed check when ``is_api`` is false.
        """

        def fail(msg):
            # API callers get the message back; UI callers get a Gradio error.
            if is_api:
                return "", msg
            raise gr.Error(msg)

        if self.transformer is None:
            return fail("Please select a pretrained model path.")

        if control_video is not None and self.model_type == "Inpaint":
            return fail(
                'If specifying the control video, please set model_type == "Control".'
            )

        if control_video is None and self.model_type == "Control":
            return fail('If model_type == "Control", please specify a control video.')

        if resize_method == "Resize according to Reference":
            if start_image is None and validation_video is None and control_video is None:
                return fail(
                    'Please upload an image when using "Resize according to Reference".'
                )

        # Per the messages below, a transformer whose input channels equal the
        # VAE latent channels is treated as not image-to-video capable.
        no_extra_channels = (
            self.transformer.config.in_channels == self.vae.config.latent_channels
        )
        if no_extra_channels and start_image is not None:
            return fail(
                "Please select an image-to-video pretrained model when using image-to-video."
            )

        if no_extra_channels and generation_method == "Long Video Generation":
            return fail(
                "Please select an image-to-video pretrained model for long video generation."
            )

        if start_image is None and end_image is not None:
            return fail("If specifying an ending image, please specify a starting image.")

        return None

    def get_height_width_from_reference(
        self, base_resolution, start_image, validation_video, control_video
    ):
        """Pick an output size whose aspect ratio is closest to the reference media.

        The reference is the validation video or start image for "Inpaint"
        models, and the control video otherwise.

        Returns:
            ``(height, width)``, each rounded down to a multiple of 16.
        """
        # Scale the 512-based bucket table to the requested base resolution.
        aspect_ratio_sizes = {
            ratio: [side / 512 * base_resolution for side in size]
            for ratio, size in ASPECT_RATIO_512.items()
        }

        if self.model_type == "Inpaint":
            if validation_video:
                ref_width, ref_height = _first_frame_size(validation_video)
            else:
                reference = start_image[0] if isinstance(start_image, list) else start_image
                if isinstance(reference, Image.Image):
                    ref_width, ref_height = reference.size
                else:
                    # Context manager closes the file handle after reading the size.
                    with Image.open(reference) as opened:
                        ref_width, ref_height = opened.size
        else:
            ref_width, ref_height = _first_frame_size(control_video)

        closest_size, _ = get_closest_ratio(
            ref_height, ref_width, ratios=aspect_ratio_sizes
        )
        # NOTE(review): assumes bucket entries are [height, width], matching the
        # (height, width) argument order above - confirm against bucket_sampler.
        # The previous code unpacked this pair as (width, height) and therefore
        # returned the two values swapped.
        close_height, close_width = closest_size
        return (int(close_height // 16 * 16), int(close_width // 16 * 16))

    def save_outputs(self, is_image, length_slider, sample, fps):
        """Write ``sample`` to the sample directory as a PNG or MP4 and return its path.

        A PNG is written when ``is_image`` is true or ``length_slider == 1``;
        otherwise ``save_videos_grid`` writes an MP4 at ``fps``.
        """
        os.makedirs(self.savedir_sample, exist_ok=True)
        # File names are 1-based counters derived from the directory's entry count.
        index = len(os.listdir(self.savedir_sample)) + 1
        prefix = str(index).zfill(3)

        if is_image or length_slider == 1:
            save_path = os.path.join(self.savedir_sample, f"{prefix}.png")
            # Take index 0 on axes 0 and 2, then move the leading axis last.
            frame = sample[0, :, 0].transpose(0, 1).transpose(1, 2)
            pixels = (frame * 255).numpy().astype(np.uint8)
            Image.fromarray(pixels).save(save_path)
        else:
            save_path = os.path.join(self.savedir_sample, f"{prefix}.mp4")
            save_videos_grid(sample, save_path, fps=fps)
        return save_path

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        is_api=False,
    ):
        # local generation logic (omitted)
        pass


def post_eas(
    diffusion_transformer_dropdown,
    base_model_dropdown,
    lora_model_dropdown,
    lora_alpha_slider,
    prompt_textbox,
    negative_prompt_textbox,
    sampler_dropdown,
    sample_step_slider,
    resize_method,
    width_slider,
    height_slider,
    base_resolution,
    generation_method,
    length_slider,
    cfg_scale_slider,
    start_image,
    end_image,
    validation_video,
    validation_video_mask,
    denoise_strength,
    seed_textbox,
):
    """POST a generation request to the inference endpoint and return its JSON reply.

    Media arguments that are set are read from disk and sent base64-encoded.
    The endpoint base is ``$EAS_URL`` when set, else ``http://127.0.0.1:$PORT``
    (port default 7860). ``$EAS_TOKEN``, when set, is sent as the
    ``Authorization`` header. ``diffusion_transformer_dropdown`` is accepted
    but not included in the payload.

    Returns:
        The decoded JSON body. If the body is not valid JSON, a dict with a
        single ``"message"`` key describing the failure.
    """

    def _encode_if_set(path):
        # Falsy values (None, "") are passed through unchanged.
        if not path:
            return path
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    start_image = _encode_if_set(start_image)
    end_image = _encode_if_set(end_image)
    validation_video = _encode_if_set(validation_video)
    validation_video_mask = _encode_if_set(validation_video_mask)

    datas = {
        "base_model_path": base_model_dropdown,
        "lora_model_path": lora_model_dropdown,
        "lora_alpha_slider": lora_alpha_slider,
        "prompt_textbox": prompt_textbox,
        "negative_prompt_textbox": negative_prompt_textbox,
        "sampler_dropdown": sampler_dropdown,
        "sample_step_slider": sample_step_slider,
        "resize_method": resize_method,
        "width_slider": width_slider,
        "height_slider": height_slider,
        "base_resolution": base_resolution,
        "generation_method": generation_method,
        "length_slider": length_slider,
        "cfg_scale_slider": cfg_scale_slider,
        "start_image": start_image,
        "end_image": end_image,
        "validation_video": validation_video,
        "validation_video_mask": validation_video_mask,
        "denoise_strength": denoise_strength,
        "seed_textbox": seed_textbox,
    }

    session = requests.Session()
    token = os.environ.get("EAS_TOKEN")
    if token:
        session.headers.update({"Authorization": token})

    eas_url = os.environ.get("EAS_URL")
    if eas_url:
        base_url = eas_url.rstrip("/")
    else:
        port = os.environ.get("PORT", "7860")
        base_url = f"http://127.0.0.1:{port}"

    endpoint = f"{base_url}/cogvideox_fun/infer_forward"
    resp = session.post(url=endpoint, json=datas, timeout=300)
    try:
        return resp.json()
    except ValueError:
        # Non-JSON body (e.g. a gateway error page): report it through the
        # "message" channel instead of raising a decode error.
        return {
            "message": f"Invalid (non-JSON) response from server, HTTP status {resp.status_code}."
        }


class Fun_Controller_EAS:
    """Controller that delegates generation to the remote endpoint via ``post_eas``."""

    def __init__(self, model_name, scheduler_dict, savedir_sample):
        """Store the scheduler mapping and ensure the output directory exists.

        ``model_name`` is accepted for interface compatibility and not used here.
        """
        self.scheduler_dict = scheduler_dict
        self.savedir_sample = savedir_sample
        os.makedirs(self.savedir_sample, exist_ok=True)

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        denoise_strength,
        seed_textbox,
    ):
        """Request a generation remotely, save the result and return UI updates.

        Returns:
            ``(image_update, video_update, status_message)``. On success the
            status is "Success" and exactly one of the two components is made
            visible with the saved file; on failure the server's "message"
            (or "Unknown error") is returned with empty components.
        """
        is_image = generation_method == "Image Generation"

        outputs = post_eas(
            diffusion_transformer_dropdown,
            base_model_dropdown,
            lora_model_dropdown,
            lora_alpha_slider,
            prompt_textbox,
            negative_prompt_textbox,
            sampler_dropdown,
            sample_step_slider,
            resize_method,
            width_slider,
            height_slider,
            base_resolution,
            generation_method,
            length_slider,
            cfg_scale_slider,
            start_image,
            end_image,
            validation_video,
            validation_video_mask,
            denoise_strength,
            seed_textbox,
        )

        if "base64_encoding" not in outputs:
            # Same version-aware update form as the success paths below.
            return (
                _component_update(gr.Image, visible=False, value=None),
                _component_update(gr.Video, visible=True, value=None),
                outputs.get("message", "Unknown error"),
            )

        data = base64.b64decode(outputs["base64_encoding"])
        index = len(os.listdir(self.savedir_sample)) + 1
        prefix = str(index).zfill(3)

        save_as_image = is_image or length_slider == 1
        extension = "png" if save_as_image else "mp4"
        save_path = os.path.join(self.savedir_sample, f"{prefix}.{extension}")
        with open(save_path, "wb") as f:
            f.write(data)

        if save_as_image:
            return (
                _component_update(gr.Image, value=save_path, visible=True),
                _component_update(gr.Video, value=None, visible=False),
                "Success",
            )
        return (
            _component_update(gr.Image, value=None, visible=False),
            _component_update(gr.Video, value=save_path, visible=True),
            "Success",
        )