import torch, warnings, glob, os
import numpy as np
from PIL import Image
from einops import repeat, reduce
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download


class BasePipeline(torch.nn.Module):
    """Base class for diffusion pipelines: device/dtype bookkeeping, image and video
    pre-/post-processing, noise generation, and VRAM-management helpers."""
    def __init__(
        self,
        device="cuda", torch_dtype=torch.float16,
        height_division_factor=64, width_division_factor=64,
        time_division_factor=None, time_division_remainder=None,
    ):
        super().__init__()
        # Target device and dtype for model weights and intermediate tensors.
        self.device = device
        self.torch_dtype = torch_dtype
        # Resolution and frame-count constraints imposed by the model architecture.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.vram_management_enabled = False
    def to(self, *args, **kwargs):
        # Keep self.device / self.torch_dtype in sync with the underlying modules.
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self
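    # Usage sketch (assumes a concrete pipeline subclass):
    #   pipe.to(device="cuda", dtype=torch.bfloat16)
    # moves all registered submodules and updates pipe.device / pipe.torch_dtype.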
    def check_resize_height_width(self, height, width, num_frames=None):
        # Round spatial (and temporal) sizes up to the nearest values the model supports.
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
        if num_frames is None:
            return height, width
        else:
            if num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
                print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
            return height, width, num_frames
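    # Example of the rounding behavior (with the default factors of 64):
    #   >>> pipe.check_resize_height_width(500, 720)
    #   (512, 768)
    # and, for a video model with time_division_factor=4, time_division_remainder=1:
    #   >>> pipe.check_resize_height_width(512, 512, num_frames=30)
    #   (512, 512, 33)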
    def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
        # PIL image (or H*W*C uint8 array) -> float tensor rescaled from [0, 255] to [min_value, max_value].
        image = torch.Tensor(np.array(image, dtype=np.float32))
        image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        image = image * ((max_value - min_value) / 255) + min_value
        image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
        return image
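    # A minimal sketch: a PIL RGB image of size (W, H) becomes a (1, 3, H, W)
    # tensor in [-1, 1] on the pipeline's device/dtype ("example.png" is a placeholder):
    #   image = Image.open("example.png").convert("RGB")
    #   tensor = pipe.preprocess_image(image)  # shape (1, 3, H, W)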
    def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
        # Preprocess each frame independently, then stack along the temporal axis.
        video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
        video = torch.stack(video, dim=pattern.index("T") // 2)
        return video
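    # Sketch: a list of T same-sized PIL frames becomes a (1, 3, T, H, W) tensor,
    # since each frame is preprocessed to (1, 3, H, W) and stacked on the T axis
    # (`frames` below is a hypothetical list of PIL images):
    #   video_tensor = pipe.preprocess_video(frames)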
    def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
        # Collapse any extra axes, rescale [min_value, max_value] -> [0, 255], and convert to PIL.
        if pattern != "H W C":
            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
        image = image.to(device="cpu", dtype=torch.uint8)
        image = Image.fromarray(image.numpy())
        return image
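    # Sketch (the decoder call is hypothetical): a (1, 3, H, W) VAE decode in
    # [-1, 1] round-trips to a PIL image:
    #   image = pipe.vae_output_to_image(vae_decoded)  # -> PIL.Image of size (W, H)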
    def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
        # Collapse the batch axis, then convert frame by frame.
        if pattern != "T H W C":
            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
        video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
        return video
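    # Likewise for video: a (1, 3, T, H, W) decode yields a list of T PIL frames:
    #   frames = pipe.vae_output_to_video(vae_decoded)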
    def load_models_to_device(self, model_names=()):
        # Empty tuple (rather than a mutable default) means "offload everything".
        if self.vram_management_enabled:
            # Offload models that are not needed for the next step.
            for name, model in self.named_children():
                if name not in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        for module in model.modules():
                            if hasattr(module, "offload"):
                                module.offload()
                    else:
                        model.cpu()
            torch.cuda.empty_cache()
            # Onload the requested models.
            for name, model in self.named_children():
                if name in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        for module in model.modules():
                            if hasattr(module, "onload"):
                                module.onload()
                    else:
                        model.to(self.device)
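    # Sketch (assumes child modules registered as `text_encoder`, `dit`, `vae`):
    #   pipe.load_models_to_device(["text_encoder"])
    # onloads `text_encoder`, offloads every other child, and frees the CUDA cache.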
    def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
        # Draw Gaussian noise with an optional fixed seed, then move it to the target device/dtype.
        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
        noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        return noise
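    # Example: reproducible initial latents (the shape is model-specific; this one
    # is illustrative):
    #   noise = pipe.generate_noise((1, 4, 64, 64), seed=42)
    # Sampling on `rand_device` ("cpu" by default) keeps results reproducible
    # across GPUs before the tensor is cast to the pipeline's device/dtype.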
    def enable_cpu_offload(self):
        warnings.warn("`enable_cpu_offload` is deprecated. Please use `enable_vram_management` instead.")
        self.vram_management_enabled = True


    def get_vram(self):
        # Total VRAM of the current device, in GB.
        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)


    def freeze_except(self, model_names):
        # Leave only the listed child models trainable; freeze everything else.
        for name, model in self.named_children():
            if name in model_names:
                model.train()
                model.requires_grad_(True)
            else:
                model.eval()
                model.requires_grad_(False)
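    # Training sketch: keep only the denoising model trainable, e.g.
    #   pipe.freeze_except(["dit"])   # `dit` is a hypothetical child-module name
    # Every other child is set to eval() with requires_grad_(False).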
@dataclass
class ModelConfig:
    path: Optional[Union[str, list[str]]] = None
    model_id: Optional[str] = None
    origin_file_pattern: Optional[Union[str, list[str]]] = None
    download_resource: str = "ModelScope"
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None
    local_model_path: Optional[str] = None
    skip_download: bool = False
    def download_if_necessary(self, use_usp=False):
        if self.path is None:
            # Without a local path we need a model_id to download from.
            if self.model_id is None:
                raise ValueError("""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")

            # Under unified sequence parallelism, only rank 0 downloads.
            if use_usp:
                import torch.distributed as dist
                skip_download = self.skip_download or dist.get_rank() != 0
            else:
                skip_download = self.skip_download

            # Decide whether the pattern refers to a folder or to individual files.
            if self.origin_file_pattern is None or self.origin_file_pattern == "":
                self.origin_file_pattern = ""
                allow_file_pattern = None
                is_folder = True
            elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
                allow_file_pattern = self.origin_file_pattern + "*"
                is_folder = True
            else:
                allow_file_pattern = self.origin_file_pattern
                is_folder = False

            # Download missing files; files already on disk are ignored.
            if self.local_model_path is None:
                self.local_model_path = "./models"
            if not skip_download:
                downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
                snapshot_download(
                    self.model_id,
                    local_dir=os.path.join(self.local_model_path, self.model_id),
                    allow_file_pattern=allow_file_pattern,
                    ignore_file_pattern=downloaded_files,
                    local_files_only=False
                )

            # Wait until rank 0 has finished downloading.
            if use_usp:
                import torch.distributed as dist
                dist.barrier(device_ids=[dist.get_rank()])

            # Resolve the downloaded files to a concrete path (or list of paths).
            if is_folder:
                self.path = os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)
            else:
                self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
            if isinstance(self.path, list) and len(self.path) == 1:
                self.path = self.path[0]
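# Usage sketch (the model_id and file pattern below are placeholders):
#   config = ModelConfig(model_id="org/model", origin_file_pattern="*.safetensors")
#   config.download_if_necessary()
#   print(config.path)  # resolved local path(s) under ./models/org/model/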
class PipelineUnit:
    def __init__(
        self,
        seperate_cfg: bool = False,
        take_over: bool = False,
        input_params: Optional[tuple[str, ...]] = None,
        input_params_posi: Optional[dict[str, str]] = None,
        input_params_nega: Optional[dict[str, str]] = None,
        onload_model_names: Optional[tuple[str, ...]] = None
    ):
        # seperate_cfg (sic): when True, process the positive and negative
        # classifier-free-guidance inputs separately.
        self.seperate_cfg = seperate_cfg
        self.take_over = take_over
        self.input_params = input_params
        self.input_params_posi = input_params_posi
        self.input_params_nega = input_params_nega
        self.onload_model_names = onload_model_names
    def process(self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs) -> dict:
        raise NotImplementedError("`process` is not implemented.")
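# A minimal sketch of a concrete unit (all names are illustrative): encode the
# prompt separately for the positive and negative branches of classifier-free guidance.
#   class PromptUnit(PipelineUnit):
#       def __init__(self):
#           super().__init__(
#               seperate_cfg=True,
#               input_params_posi={"prompt": "prompt"},
#               input_params_nega={"prompt": "negative_prompt"},
#               onload_model_names=("text_encoder",),
#           )
#       def process(self, pipe, prompt):
#           pipe.load_models_to_device(self.onload_model_names)
#           return {"context": pipe.text_encoder(prompt)}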
class PipelineUnitRunner:
    def __init__(self):
        pass
    def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict, dict]:
        if unit.take_over:
            # The unit takes over and manipulates all three input dicts itself.
            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
        elif unit.seperate_cfg:
            # Positive branch of classifier-free guidance.
            processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
            if unit.input_params is not None:
                for name in unit.input_params:
                    processor_inputs[name] = inputs_shared.get(name)
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_posi.update(processor_outputs)
            # Negative branch: skipped when cfg_scale == 1, in which case the
            # positive outputs are reused so both dicts stay consistent.
            if inputs_shared["cfg_scale"] != 1:
                processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
                if unit.input_params is not None:
                    for name in unit.input_params:
                        processor_inputs[name] = inputs_shared.get(name)
                processor_outputs = unit.process(pipe, **processor_inputs)
                inputs_nega.update(processor_outputs)
            else:
                inputs_nega.update(processor_outputs)
        else:
            # Shared path: read from and write back to the shared inputs.
            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_shared.update(processor_outputs)
        return inputs_shared, inputs_posi, inputs_nega
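# Usage sketch: a pipeline's __call__ typically folds its units over the three
# input dicts (the unit list and dict contents are illustrative):
#   runner = PipelineUnitRunner()
#   for unit in units:
#       inputs_shared, inputs_posi, inputs_nega = runner(
#           unit, pipe, inputs_shared, inputs_posi, inputs_nega
#       )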