import hashlib
import io
import re
import os
import random

import imageio
import numpy as np
import cv2
import requests
import torch
import torchvision
import torch.distributed as dist
import torch.nn.functional as F

import PIL.Image
import PIL.ImageOps
from PIL import Image, ImageOps
from typing import Callable, Union

import decord
decord.bridge.set_bridge('torch')

from safetensors import safe_open
# from tqdm import tqdm
from einops import rearrange
from motionclone.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, convert_ldm_clip_checkpoint_concise
from motionclone.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora
from huggingface_hub import snapshot_download

# from transformers import (
#     AutoFeatureExtractor,
#     BertTokenizerFast,
#     CLIPImageProcessor,
#     CLIPTextConfig,
#     CLIPTextModel,
#     CLIPTextModelWithProjection,
#     CLIPTokenizer,
#     CLIPVisionConfig,
#     CLIPVisionModelWithProjection,
# )
MOTION_MODULES = [
    "mm_sd_v14.ckpt",
    "mm_sd_v15.ckpt",
    "mm_sd_v15_v2.ckpt",
    "v3_sd15_mm.ckpt",
]

ADAPTERS = [
    # "mm_sd_v14.ckpt",
    # "mm_sd_v15.ckpt",
    # "mm_sd_v15_v2.ckpt",
    # "mm_sdxl_v10_beta.ckpt",
    "v2_lora_PanLeft.ckpt",
    "v2_lora_PanRight.ckpt",
    "v2_lora_RollingAnticlockwise.ckpt",
    "v2_lora_RollingClockwise.ckpt",
    "v2_lora_TiltDown.ckpt",
    "v2_lora_TiltUp.ckpt",
    "v2_lora_ZoomIn.ckpt",
    "v2_lora_ZoomOut.ckpt",
    "v3_sd15_adapter.ckpt",
    # "v3_sd15_mm.ckpt",
    "v3_sd15_sparsectrl_rgb.ckpt",
    "v3_sd15_sparsectrl_scribble.ckpt",
]

BACKUP_DREAMBOOTH_MODELS = [
    "realisticVisionV60B1_v51VAE.safetensors",
    "majicmixRealistic_v4.safetensors",
    "leosamsFilmgirlUltra_velvia20Lora.safetensors",
    "toonyou_beta3.safetensors",
    "majicmixRealistic_v5Preview.safetensors",
    "rcnzCartoon3d_v10.safetensors",
    "lyriel_v16.safetensors",
    "leosamsHelloworldXL_filmGrain20.safetensors",
    "TUSUN.safetensors",
]
def zero_rank_print(s):
    # Print only on the main process: either torch.distributed is not
    # initialized, or this process is rank 0. (The original condition used
    # `and`, which is always False.)
    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):
        print("### " + s)
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # (c, h, w) -> (h, w, c)
        if rescale:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
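# Usage sketch (shapes are hypothetical): `videos` is expected in
# (batch, channel, time, height, width) layout, with values in [0, 1]
# when rescale=False, e.g.
#   videos = torch.rand(2, 3, 16, 256, 256)  # two 16-frame RGB clips
#   save_videos_grid(videos, "samples/grid.gif", n_rows=2, fps=8)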
def auto_download(local_path, is_dreambooth_lora=False):
    hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff"
    folder, filename = os.path.split(local_path)

    if not os.path.exists(local_path):
        print(f"local file {local_path} does not exist. trying to download from {hf_repo}")

        if is_dreambooth_lora:
            assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} does not exist in {hf_repo}"
        else:
            assert filename in MOTION_MODULES + ADAPTERS, f"{filename} does not exist in {hf_repo}"

        folder = "." if folder == "" else folder
        os.makedirs(folder, exist_ok=True)
        snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename])
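# Usage sketch (paths are hypothetical): the file is fetched from the matching
# Hugging Face repo only when it is missing locally, and its filename must
# appear in the lists above.
#   auto_download("models/Motion_Module/mm_sd_v15_v2.ckpt")
#   auto_download("models/DreamBooth_LoRA/toonyou_beta3.safetensors", is_dreambooth_lora=True)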
def load_weights(
    animation_pipeline,
    # motion module
    motion_module_path         = "",
    motion_module_lora_configs = [],
    # domain adapter
    adapter_lora_path          = "",
    adapter_lora_scale         = 1.0,
    # image layers
    dreambooth_model_path      = "",
    lora_model_path            = "",
    lora_alpha                 = 0.8,
):
    # motion module
    unet_state_dict = {}
    if motion_module_path != "":
        print(f"load motion module from {motion_module_path}")
        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
        motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
        unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
        unet_state_dict.pop("animatediff_config", "")

    missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
    # assert len(unexpected) == 0
    del unet_state_dict
    # base model
    if dreambooth_model_path != "":
        print(f"load dreambooth model from {dreambooth_model_path}")
        if dreambooth_model_path.endswith(".safetensors"):
            dreambooth_state_dict = {}
            with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    dreambooth_state_dict[key] = f.get_tensor(key)
        elif dreambooth_model_path.endswith(".ckpt"):
            dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")

        # 1. vae
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
        animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
        # 2. unet
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
        animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
        # 3. text encoder
        # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_concise(dreambooth_state_dict)
        animation_pipeline.text_encoder.load_state_dict(converted_text_encoder_checkpoint, strict=True)
        del dreambooth_state_dict, converted_vae_checkpoint, converted_unet_checkpoint, converted_text_encoder_checkpoint

        # Alternative text-encoder loading path, kept for reference:
        # clip_config_name = "models/clip-vit-large-patch14"
        # clip_config = CLIPTextConfig.from_pretrained(clip_config_name, local_files_only=True)
        # text_model = CLIPTextModel(clip_config)
        # keys = list(dreambooth_state_dict.keys())
        # text_model_dict = {}
        # for key in keys:
        #     if key.startswith("cond_stage_model.transformer"):
        #         text_model_dict[key[len("cond_stage_model.transformer."):]] = dreambooth_state_dict[key]
        # text_model.load_state_dict(text_model_dict)
        # animation_pipeline.text_encoder = text_model.to(dtype=animation_pipeline.unet.dtype)
        # del dreambooth_state_dict
    # lora layers
    if lora_model_path != "":
        print(f"load lora model from {lora_model_path}")
        assert lora_model_path.endswith(".safetensors")
        lora_state_dict = {}
        with safe_open(lora_model_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                lora_state_dict[key] = f.get_tensor(key)
        animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
        del lora_state_dict

    # domain adapter lora
    if adapter_lora_path != "":
        print(f"load domain lora from {adapter_lora_path}")
        domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
        domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
        domain_lora_state_dict.pop("animatediff_config", "")
        animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale)

    # motion module lora
    for motion_module_lora_config in motion_module_lora_configs:
        path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
        print(f"load motion LoRA from {path}")
        motion_lora_state_dict = torch.load(path, map_location="cpu")
        motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
        motion_lora_state_dict.pop("animatediff_config", "")
        animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)

    return animation_pipeline
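# Usage sketch (paths are hypothetical): every weight group is optional, and an
# empty string skips the corresponding loading step.
#   pipeline = load_weights(
#       pipeline,
#       motion_module_path    = "models/Motion_Module/v3_sd15_mm.ckpt",
#       adapter_lora_path     = "models/Motion_Module/v3_sd15_adapter.ckpt",
#       dreambooth_model_path = "models/DreamBooth_LoRA/toonyou_beta3.safetensors",
#   )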
def video_preprocess(video_path, height, width, video_length, duration=None, sample_start_idx=0):
    # NOTE: sample_start_idx is currently unused; kept for API compatibility.
    vr = decord.VideoReader(video_path)
    fps = vr.get_avg_fps()
    if duration is None:
        # Read the entire video.
        total_frames = len(vr)
    else:
        # Compute the frame count from the given duration in seconds,
        # capped at the total length of the video.
        total_frames = int(fps * duration)
        total_frames = min(total_frames, len(vr))

    # Sample `video_length` evenly spaced frames.
    sample_index = np.linspace(0, total_frames - 1, video_length, dtype=int)
    video = vr.get_batch(sample_index)
    # With decord's torch bridge, get_batch already returns a torch.Tensor;
    # only convert via numpy for the NDArray case (calling torch.from_numpy
    # unconditionally would fail on a tensor).
    if hasattr(video, "asnumpy"):
        video = torch.from_numpy(video.asnumpy())

    video = rearrange(video, "f h w c -> f c h w")
    # Bilinear interpolation requires a floating-point tensor.
    video = video.float()
    video = F.interpolate(video, size=(height, width), mode="bilinear", align_corners=True)
    video = video / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    return video
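# Usage sketch (file is hypothetical): sample 16 evenly spaced frames from the
# first 2 seconds and obtain a (16, 3, 512, 512) tensor in [-1, 1].
#   video = video_preprocess("assets/reference.mp4", 512, 512, 16, duration=2)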
def set_nested_item(dataDict, mapList, value):
    """Set an item in a nested dictionary.

    Example: if mapList is ['injection', 'self-attn'], this sets
    dataDict['injection']['self-attn'] = value.
    """
    for k in mapList[:-1]:
        dataDict = dataDict[k]
    dataDict[mapList[-1]] = value
def merge_sweep_config(base_config, update):
    """Merge the updated parameters into the base config."""
    if base_config is None:
        raise ValueError("Base config is None")
    if update is None:
        raise ValueError("Update config is None")
    for key in update.keys():
        map_list = key.split("--")
        set_nested_item(base_config, map_list, update[key])
    return base_config
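# Example: "--" encodes nesting, so the update below is equivalent to
# base["injection"]["self-attn"] = 0.5.
#   base = {"injection": {"self-attn": 0.0}}
#   merge_sweep_config(base, {"injection--self-attn": 0.5})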
# Adapted from https://github.com/castorini/daam
def compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: int = None, offset_idx: int = 0):
    merge_idxs = []
    tokens = tokenizer.tokenize(prompt.lower())
    if word_idx is None:
        word = word.lower()
        search_tokens = tokenizer.tokenize(word)
        start_indices = [x + offset_idx for x in range(len(tokens))
                         if tokens[x:x + len(search_tokens)] == search_tokens]
        for start_idx in start_indices:
            merge_idxs += [start_idx + i for i in range(len(search_tokens))]
        if not merge_idxs:
            raise ValueError(f'Search word {word} not found in prompt!')
    else:
        merge_idxs.append(word_idx)

    return [x + 1 for x in merge_idxs], word_idx  # Offset by 1 to account for the BOS token.
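# Example (assuming a CLIP-style tokenizer where "cat" is a single token):
#   idxs, _ = compute_token_merge_indices(tokenizer, "a cat on a mat", "cat")
#   # idxs == [2]: "cat" is the second prompt token, shifted by 1 for BOS.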
def extract_data(input_string: str) -> list:
    """
    Extract data from a string pattern where contents in () are separated by ;
    The first item in each () is taken as 'ref' and the second as 'gen'.

    Args:
    - input_string (str): The input string pattern.

    Returns:
    - list: A list of dictionaries containing 'ref' and 'gen'.
    """
    pattern = r'\(([^)]+)\)'
    matches = re.findall(pattern, input_string)

    data = []
    for match in matches:
        parts = [x.strip() for x in match.split(';')]
        ref = parts[0]
        gen = parts[1]
        data.append({'ref': ref, 'gen': gen})
    return data
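# Example:
#   extract_data("(a cat; a dog) (a car; a bus)")
#   # -> [{'ref': 'a cat', 'gen': 'a dog'}, {'ref': 'a car', 'gen': 'a bus'}]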
def generate_hash_key(image, prompt=""):
    """
    Generate a hash key for the given image and prompt.
    """
    byte_array = io.BytesIO()
    image.save(byte_array, format='JPEG')

    # Combine the image bytes with the prompt bytes and hash the result.
    image_byte_data = byte_array.getvalue()
    combined_data = image_byte_data + prompt.encode('utf-8')

    sha256 = hashlib.sha256()
    sha256.update(combined_data)
    return sha256.hexdigest()
def save_data(data, folder_path, key):
    """
    Save data to a file, using key as the file name.
    """
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    file_path = os.path.join(folder_path, f"{key}.pt")
    torch.save(data, file_path)

def get_data(folder_path, key):
    """
    Load data from a file, using key as the file name.
    Returns None if the file does not exist.
    """
    file_path = os.path.join(folder_path, f"{key}.pt")
    if os.path.exists(file_path):
        return torch.load(file_path)
    else:
        return None
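# Usage sketch: generate_hash_key / save_data / get_data form a simple on-disk
# cache keyed by image content plus prompt (folder and helper are hypothetical).
#   key = generate_hash_key(image, prompt="a photo of a cat")
#   cached = get_data("cache", key)
#   if cached is None:
#       cached = expensive_computation(image)  # hypothetical
#       save_data(cached, "cache", key)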
def PILtoTensor(data: Image.Image) -> torch.Tensor:
    # (H, W, C) uint8 image -> (1, C, H, W) float tensor in [0, 255].
    return torch.tensor(np.array(data)).permute(2, 0, 1).unsqueeze(0).float()

def TensorToPIL(data: torch.Tensor) -> Image.Image:
    # (1, C, H, W) tensor in [0, 255] -> PIL image.
    return Image.fromarray(data.squeeze().permute(1, 2, 0).numpy().astype(np.uint8))
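# Round-trip example (file is hypothetical): PILtoTensor yields a (1, C, H, W)
# float tensor in [0, 255], which TensorToPIL inverts.
#   img = Image.open("example.jpg")
#   assert TensorToPIL(PILtoTensor(img)).size == img.size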
# Adapted from https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/utils/loading_utils.py#L9
def load_image(
    image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
) -> PIL.Image.Image:
    """
    Loads `image` into a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format.
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
            A conversion method to apply to the image after loading it.
            When set to `None` the image will be converted to "RGB".

    Returns:
        `PIL.Image.Image`:
            A PIL Image.
    """
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            image = PIL.Image.open(requests.get(image, stream=True).raw)
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
            )
    elif isinstance(image, PIL.Image.Image):
        pass  # already a PIL image
    else:
        raise ValueError(
            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
        )

    # Apply EXIF orientation before any conversion.
    image = PIL.ImageOps.exif_transpose(image)

    if convert_method is not None:
        image = convert_method(image)
    else:
        image = image.convert("RGB")

    return image
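# Usage sketch (URL and path are hypothetical): accepts a URL, a local path,
# or an already-loaded PIL image.
#   img = load_image("https://example.com/sample.png")
#   img = load_image("assets/sample.png", convert_method=lambda im: im.convert("L"))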
# Taken from huggingface/diffusers
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
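# Usage sketch: applied after the usual classifier-free-guidance combination,
# mirroring how diffusers pipelines use it (0.7 is a commonly used value).
#   noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
#   noise_pred = rescale_noise_cfg(noise_pred, noise_text, guidance_rescale=0.7)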
def _in_step(config, step):
    # True when `step` falls inside [config.start_step, config.end_step);
    # configs without these attributes are treated as out of range.
    try:
        return config.start_step <= step < config.end_step
    except AttributeError:
        return False
def classify_blocks(block_list, name):
    # True if any block identifier in block_list appears as a substring of `name`.
    return any(block in name for block in block_list)
def set_all_seed(seed):
    # Seed every RNG that affects sampling, for reproducibility.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
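# Usage sketch: call once before sampling for reproducible generations; note
# that cudnn.deterministic=True can slow some convolutions.
#   set_all_seed(42)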