| | from constants import LCM_DEFAULT_MODEL |
| | from diffusers import ( |
| | DiffusionPipeline, |
| | AutoencoderTiny, |
| | UNet2DConditionModel, |
| | LCMScheduler, |
| | StableDiffusionPipeline, |
| | ) |
| | import torch |
| | from backend.tiny_autoencoder import get_tiny_autoencoder_repo_id |
| | from typing import Any |
| | from diffusers import ( |
| | LCMScheduler, |
| | StableDiffusionImg2ImgPipeline, |
| | StableDiffusionXLImg2ImgPipeline, |
| | AutoPipelineForText2Image, |
| | AutoPipelineForImage2Image, |
| | StableDiffusionControlNetPipeline, |
| | ) |
| | import pathlib |
| |
|
| |
|
def _get_lcm_pipeline_from_base_model(
    lcm_model_id: str,
    base_model_id: str,
    use_local_model: bool,
):
    """Build a base-model pipeline whose UNet comes from an LCM checkpoint.

    The UNet is loaded from ``lcm_model_id`` and passed to
    ``DiffusionPipeline.from_pretrained`` for ``base_model_id``. The
    pipeline's scheduler is then replaced by an ``LCMScheduler`` created
    from the existing scheduler's config.

    Args:
        lcm_model_id: Identifier given to
            ``UNet2DConditionModel.from_pretrained``.
        base_model_id: Identifier given to
            ``DiffusionPipeline.from_pretrained``.
        use_local_model: Forwarded as ``local_files_only`` to both loaders.

    Returns:
        The assembled pipeline with the LCM scheduler installed.
    """
    # Both loaders are called with the same loading options.
    load_options = {
        "torch_dtype": torch.float32,
        "local_files_only": use_local_model,
        "resume_download": True,
    }
    lcm_unet = UNet2DConditionModel.from_pretrained(
        lcm_model_id,
        **load_options,
    )
    lcm_pipeline = DiffusionPipeline.from_pretrained(
        base_model_id,
        unet=lcm_unet,
        **load_options,
    )
    lcm_pipeline.scheduler = LCMScheduler.from_config(
        lcm_pipeline.scheduler.config
    )
    return lcm_pipeline
| |
|
| |
|
def load_taesd(
    pipeline: Any,
    use_local_model: bool = False,
    torch_data_type: torch.dtype = torch.float32,
):
    """Replace ``pipeline.vae`` with a tiny autoencoder, in place.

    The repo id to load is obtained from ``get_tiny_autoencoder_repo_id``,
    which is given the class name of ``pipeline``.

    Args:
        pipeline: Pipeline object whose ``vae`` attribute is overwritten.
        use_local_model: Forwarded as ``local_files_only`` to
            ``AutoencoderTiny.from_pretrained``.
        torch_data_type: Forwarded as ``torch_dtype`` to
            ``AutoencoderTiny.from_pretrained``.
    """
    pipeline_class_name = pipeline.__class__.__name__
    taesd_repo_id = get_tiny_autoencoder_repo_id(pipeline_class_name)
    tiny_autoencoder = AutoencoderTiny.from_pretrained(
        taesd_repo_id,
        torch_dtype=torch_data_type,
        local_files_only=use_local_model,
    )
    pipeline.vae = tiny_autoencoder
| |
|
| |
|
def get_lcm_model_pipeline(
    model_id: str = LCM_DEFAULT_MODEL,
    use_local_model: bool = False,
    pipeline_args=None,
):
    """Create a text-to-image pipeline for ``model_id``.

    Three loading strategies are used, chosen from ``model_id``:

    * ``"latent-consistency/lcm-sdxl"`` and
      ``"latent-consistency/lcm-ssd-1b"`` are loaded through
      ``_get_lcm_pipeline_from_base_model`` with their matching base model.
      ``pipeline_args`` is not used on this path.
    * A path whose suffix is exactly ``.safetensors`` is loaded with
      ``StableDiffusionPipeline.from_single_file`` and then converted with
      ``AutoPipelineForText2Image.from_pipe``. If the path contains "lcm"
      (case-insensitive) the scheduler is switched to ``LCMScheduler``
      before the conversion.
    * Anything else goes to ``AutoPipelineForText2Image.from_pretrained``.

    Args:
        model_id: Model identifier or path to a ``.safetensors`` file.
        use_local_model: Forwarded as ``local_files_only`` to the loaders.
        pipeline_args: Optional mapping of extra keyword arguments for
            ``AutoPipelineForText2Image.from_pipe`` / ``from_pretrained``.
            ``None`` (the default) means no extra arguments.

    Returns:
        The loaded pipeline.
    """
    # Use a None sentinel instead of a shared mutable ``{}`` default.
    extra_args = {} if pipeline_args is None else pipeline_args

    # LCM UNet checkpoints that must be paired with a specific base model.
    lcm_base_models = {
        "latent-consistency/lcm-sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
        "latent-consistency/lcm-ssd-1b": "segmind/SSD-1B",
    }
    base_model_id = lcm_base_models.get(model_id)
    if base_model_id is not None:
        return _get_lcm_pipeline_from_base_model(
            model_id,
            base_model_id,
            use_local_model,
        )

    if pathlib.Path(model_id).suffix == ".safetensors":
        # Load the single-file checkpoint first, then convert it to the
        # auto text-to-image pipeline.
        single_file_pipeline = StableDiffusionPipeline.from_single_file(
            model_id,
            safety_checker=None,
            run_safety_checker=False,
            load_safety_checker=False,
            local_files_only=use_local_model,
            use_safetensors=True,
        )
        if "lcm" in model_id.lower():
            single_file_pipeline.scheduler = LCMScheduler.from_config(
                single_file_pipeline.scheduler.config
            )
        # The intermediate pipeline reference is dropped when we return.
        return AutoPipelineForText2Image.from_pipe(
            single_file_pipeline,
            **extra_args,
        )

    return AutoPipelineForText2Image.from_pretrained(
        model_id,
        local_files_only=use_local_model,
        **extra_args,
    )
| |
|
| |
|
def get_image_to_image_pipeline(pipeline: Any) -> Any:
    """Return an image-to-image pipeline derived from ``pipeline``.

    The conversion is selected by the class name of ``pipeline``:

    * ``LatentConsistencyModelPipeline`` / ``StableDiffusionPipeline``:
      ``StableDiffusionImg2ImgPipeline`` built from ``pipeline.components``.
    * ``StableDiffusionControlNetPipeline``:
      ``AutoPipelineForImage2Image.from_pipe(pipeline)``.
    * ``StableDiffusionXLPipeline``:
      ``StableDiffusionXLImg2ImgPipeline`` built from ``pipeline.components``.

    Raises:
        Exception: If the pipeline's class name is none of the above.
    """
    # Read components up front, before any class-name dispatch.
    parts = pipeline.components
    class_name = pipeline.__class__.__name__

    if class_name in ("LatentConsistencyModelPipeline", "StableDiffusionPipeline"):
        return StableDiffusionImg2ImgPipeline(**parts)
    if class_name == "StableDiffusionControlNetPipeline":
        return AutoPipelineForImage2Image.from_pipe(pipeline)
    if class_name == "StableDiffusionXLPipeline":
        return StableDiffusionXLImg2ImgPipeline(**parts)

    raise Exception(f"Unknown pipeline {class_name}")
| |
|