import os
from typing import Optional, Union

import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from huggingface_hub import hf_hub_download, snapshot_download
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from torch.distributed.device_mesh import init_device_mesh

from .model.dit import get_dit, parallelize
from .model.text_embedders import get_text_embedder
from .t2v_pipeline import Kandinsky4T2VPipeline

def get_T2V_pipeline(
    device_map: Union[str, torch.device, dict],
    resolution: int = 512,
    cache_dir: str = "./weights/",
    dit_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    tokenizer_path: Optional[str] = None,
    vae_path: Optional[str] = None,
    scheduler_path: Optional[str] = None,
    conf_path: Optional[str] = None,
) -> Kandinsky4T2VPipeline:
    # Only the 512px distilled checkpoint is currently available.
    assert resolution in [512]

    # A single device spec is broadcast to all three submodules.
    if not isinstance(device_map, dict):
        device_map = {
            "dit": device_map,
            "vae": device_map,
            "text_embedder": device_map,
        }

    # Pick up the torchrun environment; fall back to a single process.
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    except KeyError:
        local_rank, world_size = 0, 1

    if world_size > 1:
        # One-dimensional mesh used for tensor parallelism across ranks.
        device_mesh = init_device_mesh(
            "cuda", (world_size,), mesh_dim_names=("tensor_parallel",)
        )
        device_map["dit"] = torch.device(f"cuda:{local_rank}")
    os.makedirs(cache_dir, exist_ok=True)

    # Download any components that were not supplied explicitly. The DiT
    # weights come from the Kandinsky-4 repo; the VAE, scheduler, text
    # encoder, and tokenizer are reused from CogVideoX-5b.
    if dit_path is None:
        dit_path = hf_hub_download(
            repo_id="ai-forever/kandinsky4",
            filename=f"kandinsky4_distil_{resolution}.pt",
            local_dir=cache_dir,
        )
    if vae_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns="vae/*", local_dir=cache_dir
        )
        vae_path = os.path.join(cache_dir, "vae/")
    if scheduler_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns="scheduler/*", local_dir=cache_dir
        )
        scheduler_path = os.path.join(cache_dir, "scheduler/")
    if text_encoder_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns="text_encoder/*", local_dir=cache_dir
        )
        text_encoder_path = os.path.join(cache_dir, "text_encoder/")
    if tokenizer_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns="tokenizer/*", local_dir=cache_dir
        )
        tokenizer_path = os.path.join(cache_dir, "tokenizer/")
    if conf_path is None:
        conf = get_default_conf(vae_path, text_encoder_path, tokenizer_path, scheduler_path, dit_path)
    else:
        conf = OmegaConf.load(conf_path)

    dit = get_dit(conf.dit)
    dit = dit.to(dtype=torch.bfloat16, device=device_map["dit"])
    noise_scheduler = CogVideoXDDIMScheduler.from_pretrained(conf.dit.scheduler)
    if world_size > 1:
        # Shard the DiT across the tensor-parallel mesh dimension.
        dit = parallelize(dit, device_mesh["tensor_parallel"])

    # The text encoder and VAE are moved to GPU only on rank 0; other ranks
    # hold just their DiT shard and leave these modules on CPU.
    text_embedder = get_text_embedder(conf)
    text_embedder = text_embedder.freeze()
    if local_rank == 0:
        text_embedder = text_embedder.to(device=device_map["text_embedder"], dtype=torch.bfloat16)

    vae = AutoencoderKLCogVideoX.from_pretrained(conf.vae.checkpoint_path)
    vae = vae.eval()
    if local_rank == 0:
        vae = vae.to(device_map["vae"], dtype=torch.bfloat16)

    return Kandinsky4T2VPipeline(
        device_map=device_map,
        dit=dit,
        text_embedder=text_embedder,
        vae=vae,
        noise_scheduler=noise_scheduler,
        resolution=resolution,
        local_dit_rank=local_rank,
        world_size=world_size,
    )
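
# ---------------------------------------------------------------------------
# Usage sketch (assumption: a single-GPU run without torchrun, so
# LOCAL_RANK/WORLD_SIZE fall back to 0/1 and no device mesh is created).
# The keyword arguments below are taken from get_T2V_pipeline's signature
# above; the generation call on the returned pipeline lives in
# t2v_pipeline.py and is not shown in this file.
#
#     import torch
#     pipe = get_T2V_pipeline(
#         device_map=torch.device("cuda:0"),  # broadcast to dit/vae/text_embedder
#         resolution=512,                     # the only supported resolution
#         cache_dir="./weights/",             # weights download here on first use
#     )
#
# A dict can instead pin each submodule to its own device, e.g.
#     device_map={"dit": "cuda:0", "vae": "cuda:1", "text_embedder": "cuda:1"}
# ---------------------------------------------------------------------------
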
def get_default_conf(
    vae_path,
    text_encoder_path,
    tokenizer_path,
    scheduler_path,
    dit_path,
) -> DictConfig:
    # Architecture hyperparameters of the distilled Kandinsky-4 DiT.
    dit_params = {
        "in_visual_dim": 16,
        "in_text_dim": 4096,
        "out_visual_dim": 16,
        "time_dim": 512,
        "patch_size": [1, 2, 2],
        "model_dim": 3072,
        "ff_dim": 12288,
        "num_blocks": 21,
        "axes_dims": [16, 24, 24],
    }
    conf = {
        "vae": {
            "checkpoint_path": vae_path,
        },
        "text_embedder": {
            "emb_size": 4096,
            # NOTE: 'tokens_lenght' (sic) is kept verbatim so lookups
            # against this key elsewhere in the repo keep working.
            "tokens_lenght": 224,
            "params": {
                "checkpoint_path": text_encoder_path,
                "tokenizer_path": tokenizer_path,
            },
        },
        "dit": {
            "scheduler": scheduler_path,
            "checkpoint_path": dit_path,
            "params": dit_params,
        },
        "resolution": 512,
    }
    return DictConfig(conf)
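
# ---------------------------------------------------------------------------
# Customization sketch (assumption: a standard OmegaConf YAML round-trip;
# the paths below are placeholders). Since get_T2V_pipeline accepts a
# conf_path, one way to tweak settings is to dump the default config to
# YAML, edit it, and pass the file back:
#
#     from omegaconf import OmegaConf
#     conf = get_default_conf(
#         vae_path="./weights/vae/",
#         text_encoder_path="./weights/text_encoder/",
#         tokenizer_path="./weights/tokenizer/",
#         scheduler_path="./weights/scheduler/",
#         dit_path="./weights/kandinsky4_distil_512.pt",
#     )
#     OmegaConf.save(config=conf, f="my_conf.yaml")
#     # ...edit my_conf.yaml, then:
#     pipe = get_T2V_pipeline(device_map="cuda:0", conf_path="my_conf.yaml")
# ---------------------------------------------------------------------------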