from typing import Literal, Union, Optional
import copy
import gc
import os
import sys

import torch
from safetensors.torch import load_file  # used below to load .safetensors checkpoints
from transformers import (
    AutoModel,
    CLIPModel,
    CLIPProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5TokenizerFast,
)
from huggingface_hub import hf_hub_download
from diffusers import (
    AutoencoderKL,
    AutoencoderTiny,
    FluxPipeline,
    FluxTransformer2DModel,
    LCMScheduler,
    SchedulerMixin,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    FlowMatchEulerDiscreteScheduler,
    LMSDiscreteScheduler,
)

sys.path.append('.')
from .flux_utils import *
TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"

AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]

SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]

DIFFUSERS_CACHE_DIR = None  # set this to change the Hugging Face cache directory
def load_diffusers_model(
    pretrained_model_name_or_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel]:
    # The VAE is not needed here.
    if v2:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V2_MODEL_NAME,
            subfolder="tokenizer",
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            # v2 defaults to clip skip 2, i.e. 23 of 24 hidden layers
            num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
    else:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V1_MODEL_NAME,
            subfolder="tokenizer",
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    return tokenizer, text_encoder, unet
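
# Example (illustrative, not part of the training code): with clip_skip=2 on a
# v1 model, the text encoder keeps 12 - (2 - 1) = 11 hidden layers, i.e. the
# text embeddings come from the penultimate layer. The model id is just a
# plausible placeholder.
# tokenizer, text_encoder, unet = load_diffusers_model(
#     "runwayml/stable-diffusion-v1-5", clip_skip=2, weight_dtype=torch.float16
# )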
def load_checkpoint_model(
    checkpoint_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel]:
    # from_ckpt is deprecated; from_single_file is the current diffusers API
    pipe = StableDiffusionPipeline.from_single_file(
        checkpoint_path,
        upcast_attention=v2,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    if clip_skip is not None:
        # Note: this only updates the config; the already-instantiated layers are untouched.
        if v2:
            text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
        else:
            text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)

    del pipe

    return tokenizer, text_encoder, unet
def load_models(
    pretrained_model_name_or_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    v2: bool = False,
    v_pred: bool = False,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin]:
    if pretrained_model_name_or_path.endswith((".ckpt", ".safetensors")):
        tokenizer, text_encoder, unet = load_checkpoint_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )
    else:  # diffusers folder layout
        tokenizer, text_encoder, unet = load_diffusers_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )

    # The VAE is not needed here.

    scheduler = create_noise_scheduler(
        scheduler_name,
        prediction_type="v_prediction" if v_pred else "epsilon",
    )

    return tokenizer, text_encoder, unet, scheduler
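
# Example (illustrative): SD 2.1 is a v-prediction model, so pass both v2=True
# and v_pred=True so the scheduler is built with prediction_type="v_prediction".
# tokenizer, text_encoder, unet, scheduler = load_models(
#     "stabilityai/stable-diffusion-2-1", scheduler_name="ddim", v2=True, v_pred=True
# )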
def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel]:
    # returns [tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet
    tokenizers = [
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer_2",
            cache_dir=DIFFUSERS_CACHE_DIR,
            pad_token_id=0,  # same as OpenCLIP
        ),
    ]

    text_encoders = [
        CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
    ]

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    return tokenizers, text_encoders, unet
def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel]:
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
    # Match OpenCLIP's padding behavior on the second text encoder.
    text_encoders[1].pad_token_id = 0

    del pipe

    return tokenizers, text_encoders, unet
def load_models_xl_(
    pretrained_model_name_or_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[
    list[CLIPTokenizer],
    list[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    SchedulerMixin,
]:
    if pretrained_model_name_or_path.endswith((".ckpt", ".safetensors")):
        tokenizers, text_encoders, unet = load_checkpoint_model_xl(
            pretrained_model_name_or_path, weight_dtype
        )
    else:  # diffusers folder layout
        tokenizers, text_encoders, unet = load_diffusers_model_xl(
            pretrained_model_name_or_path, weight_dtype
        )

    scheduler = create_noise_scheduler(scheduler_name)

    return tokenizers, text_encoders, unet, scheduler
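
# Example (illustrative): the same entry point handles both single-file
# checkpoints and diffusers folder layouts.
# tokenizers, text_encoders, unet, scheduler = load_models_xl_(
#     "stabilityai/stable-diffusion-xl-base-1.0", scheduler_name="ddim"
# )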
def create_noise_scheduler(
    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:
    # Honestly, it's unclear which scheduler is best. The original implementation
    # offered a choice of DDIM, DDPM, and LMS, but which works best is unknown.
    name = scheduler_name.lower().replace(" ", "_")
    if name == "ddim":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            clip_sample=False,
            prediction_type=prediction_type,  # is this right?
        )
    elif name == "ddpm":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
        scheduler = DDPMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            clip_sample=False,
            prediction_type=prediction_type,
        )
    elif name == "lms":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            prediction_type=prediction_type,
        )
    elif name == "euler_a":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
        scheduler = EulerAncestralDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            # clip_sample=False,
            prediction_type=prediction_type,
        )
    else:
        raise ValueError(f"Unknown scheduler name: {name}")

    return scheduler
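
# Example (illustrative): all four schedulers share the SD training betas
# (scaled_linear, 0.00085 to 0.012 over 1000 timesteps); only the sampler differs.
# scheduler = create_noise_scheduler("euler_a", prediction_type="epsilon")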
def load_models_xl(params):
    """
    Load all models required for SDXL training.

    Args:
        params: Dictionary containing model parameters and configurations

    Returns:
        tuple: (dict of loaded models and tokenizers, assembled StableDiffusionXLPipeline)
    """
    device = params['device']
    weight_dtype = params['weight_dtype']

    # Load SDXL components (UNet, text encoders, tokenizers)
    scheduler_name = 'ddim'
    tokenizers, text_encoders, unet, noise_scheduler = load_models_xl_(
        params['pretrained_model_name_or_path'],
        scheduler_name=scheduler_name,
    )

    # Move text encoders to device and set to eval mode
    for text_encoder in text_encoders:
        text_encoder.to(device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()

    # Set up UNet
    unet.to(device, dtype=weight_dtype)
    unet.requires_grad_(False)
    unet.eval()

    # Load tiny VAE for efficiency
    vae = AutoencoderTiny.from_pretrained(
        "madebyollin/taesdxl",
        torch_dtype=weight_dtype,
    )
    vae = vae.to(device, dtype=weight_dtype)
    vae.requires_grad_(False)

    # Load the appropriate image encoder (CLIP or DINOv2)
    if params['encoder'] == 'dinov2-small':
        clip_model = AutoModel.from_pretrained(
            'facebook/dinov2-small',
            torch_dtype=weight_dtype,
        )
        clip_processor = None
    else:
        clip_model = CLIPModel.from_pretrained(
            "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
            torch_dtype=weight_dtype,
        )
        clip_processor = CLIPProcessor.from_pretrained(
            "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M"
        )
    clip_model = clip_model.to(device, dtype=weight_dtype)
    clip_model.requires_grad_(False)

    # If a distilled (DMD) checkpoint is requested, load it into the UNet
    if params['distilled'] != 'None':
        if '.safetensors' in params['distilled']:
            unet.load_state_dict(load_file(params['distilled'], device=device))
        elif 'dmd2' in params['distilled']:
            repo_name = "tianweiy/DMD2"
            ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"
            unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name)))
        else:
            unet.load_state_dict(torch.load(params['distilled']))

        # Set up LCM scheduler for the distilled model
        noise_scheduler = LCMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            prediction_type="epsilon",
            original_inference_steps=1000,
        )
        noise_scheduler.set_timesteps(params['max_denoising_steps'])

    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoders[0],
        text_encoder_2=text_encoders[1],
        tokenizer=tokenizers[0],
        tokenizer_2=tokenizers[1],
        unet=unet,
        scheduler=noise_scheduler,
    )
    pipe.set_progress_bar_config(disable=True)

    return {
        'unet': unet,
        'vae': vae,
        'clip_model': clip_model,
        'clip_processor': clip_processor,
        'tokenizers': tokenizers,
        'text_encoders': text_encoders,
        'noise_scheduler': noise_scheduler,
    }, pipe
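
# Example (illustrative; the keys shown are exactly the ones this function reads,
# the values are plausible placeholders):
# models, pipe = load_models_xl({
#     'device': 'cuda',
#     'weight_dtype': torch.float16,
#     'pretrained_model_name_or_path': 'stabilityai/stable-diffusion-xl-base-1.0',
#     'encoder': 'dinov2-small',
#     'distilled': 'None',
#     'max_denoising_steps': 4,
# })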
def load_models_flux(params):
    # Load the tokenizers (tokenizers have no dtype or device; those kwargs were no-ops)
    tokenizer_one = CLIPTokenizer.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="tokenizer",
    )
    tokenizer_two = T5TokenizerFast.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="tokenizer_2",
    )

    # Load scheduler and models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="scheduler",
    )
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)

    # Import the correct text encoder classes
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        params['pretrained_model_name_or_path'],
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        params['pretrained_model_name_or_path'], subfolder="text_encoder_2"
    )
    # Load the text encoders
    text_encoder_one, text_encoder_two = load_text_encoders(
        params['pretrained_model_name_or_path'],
        text_encoder_cls_one,
        text_encoder_cls_two,
        params['weight_dtype'],
    )

    # Load VAE
    vae = AutoencoderKL.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="vae",
        torch_dtype=params['weight_dtype'],
        device_map='auto',
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        params['pretrained_model_name_or_path'],
        subfolder="transformer",
        torch_dtype=params['weight_dtype'],
    )

    # We only train the additional adapter LoRA layers
    transformer.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)

    vae.to(params['device'])
    transformer.to(params['device'])
    text_encoder_one.to(params['device'])
    text_encoder_two.to(params['device'])

    # Load the appropriate image encoder (CLIP or DINOv2)
    if params['encoder'] == 'dinov2-small':
        clip_model = AutoModel.from_pretrained(
            'facebook/dinov2-small',
            torch_dtype=params['weight_dtype'],
        )
        clip_processor = None
    else:
        clip_model = CLIPModel.from_pretrained(
            "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
            torch_dtype=params['weight_dtype'],
        )
        clip_processor = CLIPProcessor.from_pretrained(
            "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M"
        )
    clip_model = clip_model.to(params['device'], dtype=params['weight_dtype'])
    clip_model.requires_grad_(False)

    pipe = FluxPipeline(
        scheduler=noise_scheduler,
        vae=vae,
        text_encoder=text_encoder_one,
        tokenizer=tokenizer_one,
        text_encoder_2=text_encoder_two,
        tokenizer_2=tokenizer_two,
        transformer=transformer,
    )
    pipe.set_progress_bar_config(disable=True)

    return {
        'transformer': transformer,
        'vae': vae,
        'clip_model': clip_model,
        'clip_processor': clip_processor,
        'tokenizers': [tokenizer_one, tokenizer_two],
        'text_encoders': [text_encoder_one, text_encoder_two],
        'noise_scheduler': noise_scheduler,
    }, pipe
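
# Example (illustrative; the keys shown are the ones this function reads, and the
# model id is a plausible placeholder for any Flux-layout repository):
# models, pipe = load_models_flux({
#     'pretrained_model_name_or_path': 'black-forest-labs/FLUX.1-schnell',
#     'weight_dtype': torch.bfloat16,
#     'device': 'cuda',
#     'encoder': 'dinov2-small',
# })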
def save_checkpoint(networks, save_path, weight_dtype):
    """
    Save network weights and perform cleanup.

    Args:
        networks: Dictionary of LoRA networks to save
        save_path: Directory to save the checkpoints into
        weight_dtype: Data type for the saved weights
    """
    print("Saving checkpoint...")
    try:
        # Create the save directory if it doesn't exist
        os.makedirs(save_path, exist_ok=True)

        # Save each network's weights
        for net_idx, network in networks.items():
            save_name = f"{save_path}/slider_{net_idx}.pt"
            try:
                network.save_weights(
                    save_name,
                    dtype=weight_dtype,
                )
            except Exception as e:
                print(f"Error saving network {net_idx}: {e}")
                continue

        # Cleanup
        torch.cuda.empty_cache()
        gc.collect()
        print("Checkpoint saved successfully.")
    except Exception as e:
        print(f"Error during checkpoint saving: {e}")
    finally:
        # Ensure memory is cleaned up even if saving fails
        torch.cuda.empty_cache()
        gc.collect()
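
# Example (illustrative; assumes `networks` maps slider indices to LoRA network
# objects exposing a save_weights(path, dtype=...) method, as used above):
# save_checkpoint(networks, "checkpoints/run_01", torch.float16)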