| | """ |
| | Wrapper class to call the stablediffusion.cpp shared library for GGUF support |
| | """ |
| |
|
| | import ctypes |
| | import platform |
| | from ctypes import ( |
| | POINTER, |
| | c_bool, |
| | c_char_p, |
| | c_float, |
| | c_int, |
| | c_int64, |
| | c_void_p, |
| | ) |
| | from dataclasses import dataclass |
| | from os import path |
| | from typing import List, Any |
| |
|
| | import numpy as np |
| | from PIL import Image |
| |
|
| | from backend.gguf.sdcpp_types import ( |
| | RngType, |
| | SampleMethod, |
| | Schedule, |
| | SDCPPLogLevel, |
| | SDImage, |
| | SdType, |
| | ) |
| |
|
| |
|
@dataclass()
class ModelConfig:
    """Options used by GGUFDiffusion to create a stable-diffusion.cpp context.

    Every field is forwarded, in declaration order, to the library's
    ``new_sd_ctx`` call. Path fields are converted to bytes before the call,
    and an empty string becomes ``b""``. ``GGUFDiffusion`` requires the
    CLIP-L, T5XXL, diffusion-model and VAE files to exist; the other paths
    are optional.
    """

    # --- model files (str paths; "" means not provided) ---
    model_path: str = ""
    clip_l_path: str = ""
    t5xxl_path: str = ""
    diffusion_model_path: str = ""
    vae_path: str = ""
    taesd_path: str = ""
    control_net_path: str = ""

    # --- auxiliary directories ---
    lora_model_dir: str = ""
    embed_dir: str = ""
    stacked_id_embed_dir: str = ""

    # --- VAE / memory behaviour flags ---
    vae_decode_only: bool = True
    vae_tiling: bool = False
    free_params_immediately: bool = False

    # --- execution settings ---
    n_threads: int = 4
    wtype: SdType = SdType.SD_TYPE_Q4_0
    rng_type: RngType = RngType.CUDA_RNG
    schedule: Schedule = Schedule.DEFAULT

    # --- device placement flags ---
    keep_clip_on_cpu: bool = False
    keep_control_net_cpu: bool = False
    keep_vae_on_cpu: bool = False
| |
|
| |
|
@dataclass
class Txt2ImgConfig:
    """Text-to-image parameters for ``GGUFDiffusion.generate_text2mg``.

    Fields are passed positionally, in this declaration order, to the
    library's ``txt2img`` function.
    """

    prompt: str = "a man wearing sun glasses, highly detailed"
    negative_prompt: str = ""
    clip_skip: int = -1
    cfg_scale: float = 2.0
    guidance: float = 3.5
    width: int = 512
    height: int = 512
    sample_method: SampleMethod = SampleMethod.EULER_A
    sample_steps: int = 1
    seed: int = -1  # passed as a 64-bit int
    # Number of images requested. The same count is used to read images back
    # from the returned buffer.
    batch_count: int = 2
    # Passed through a ``POINTER(SDImage)`` argtype, so it must be ``None``
    # (NULL) or an SDImage pointer. A PIL ``Image`` is rejected by ctypes.
    # The previous annotation named the PIL ``Image`` module, which is not a
    # type.
    control_cond: Any = None
    control_strength: float = 0.90
    style_strength: float = 0.5
    normalize_input: bool = False
    # Passed straight to a ``c_char_p`` argtype without conversion, so it must
    # already be bytes (unlike the str paths in ModelConfig).
    input_id_images_path: bytes = b""
| |
|
| |
|
class GGUFDiffusion:
    """GGUF Diffusion.

    Supports GGUF diffusion models through the stable-diffusion.cpp shared
    library (GGUF format: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md).
    Implemented against stable-diffusion.h.

    Typical use: construct with the directory containing the shared library
    and a ``ModelConfig``, call ``generate_text2mg`` as needed, then call
    ``terminate`` to free the native context.
    """

    # Shared-library file name for each ``platform.system()`` value.
    _LIB_NAMES = {
        "Windows": "stable-diffusion.dll",
        "Linux": "libstable-diffusion.so",
        "Darwin": "libstable-diffusion.dylib",
    }

    # PIL image mode for each supported channel count of a generated image.
    _CHANNEL_MODES = {1: "L", 3: "RGB", 4: "RGBA"}

    def __init__(
        self,
        libpath: str,
        config: "ModelConfig",
        logging_enabled: bool = False,
    ):
        """Load the library, validate the model files and create a context.

        Args:
            libpath: Directory containing the platform-specific shared library.
            config: Model file paths and context options.
            logging_enabled: If True, library log messages are printed via
                ``log_callback``.

        Raises:
            ValueError: If the library cannot be loaded, a required model file
                (CLIP-L, T5XXL, diffusion model, VAE) is missing, or the
                library fails to create a context.
        """
        # Start with empty handles so terminate() is safe on any instance.
        self.libsdcpp = None
        self.sd_ctx = None

        sdcpp_shared_lib_path = self._get_sdcpp_shared_lib_path(libpath)
        try:
            self.libsdcpp = ctypes.CDLL(sdcpp_shared_lib_path)
        except OSError as e:
            print(f"Failed to load library {sdcpp_shared_lib_path}")
            raise ValueError(f"Error: {e}") from e

        self._validate_model_files(config)
        self.model_config = config

        # Register the callback before the context is created so that messages
        # emitted while the models are loaded are not lost.
        if logging_enabled:
            self._set_logcallback()

        self.sd_ctx = self._create_sd_ctx()

    @staticmethod
    def _validate_model_files(config) -> None:
        """Raise ValueError if a model file this wrapper requires is missing."""
        required_files = (
            (
                config.clip_l_path,
                "CLIP model file not found,please check readme.md for GGUF model usage",
            ),
            (
                config.t5xxl_path,
                "T5XXL model file not found,please check readme.md for GGUF model usage",
            ),
            (
                config.diffusion_model_path,
                "Diffusion model file not found,please check readme.md for GGUF model usage",
            ),
            (
                config.vae_path,
                "VAE model file not found,please check readme.md for GGUF model usage",
            ),
        )
        for file_path, message in required_files:
            if not file_path or not path.exists(file_path):
                raise ValueError(message)

    def _create_sd_ctx(self):
        """Declare the ``new_sd_ctx`` prototype and create a context.

        Returns:
            The non-NULL context pointer returned by the library.

        Raises:
            ValueError: If ``new_sd_ctx`` returns NULL.
        """
        self.libsdcpp.new_sd_ctx.argtypes = (
            # model, clip_l, t5xxl, diffusion model, vae, taesd, control net,
            # lora dir, embed dir, stacked-id embed dir
            [c_char_p] * 10
            # vae_decode_only, vae_tiling, free_params_immediately
            + [c_bool] * 3
            # n_threads, wtype, rng_type, schedule
            + [c_int, SdType, RngType, Schedule]
            # keep_clip_on_cpu, keep_control_net_cpu, keep_vae_on_cpu
            + [c_bool] * 3
        )
        self.libsdcpp.new_sd_ctx.restype = POINTER(c_void_p)

        cfg = self.model_config
        sd_ctx = self.libsdcpp.new_sd_ctx(
            self._str_to_bytes(cfg.model_path),
            self._str_to_bytes(cfg.clip_l_path),
            self._str_to_bytes(cfg.t5xxl_path),
            self._str_to_bytes(cfg.diffusion_model_path),
            self._str_to_bytes(cfg.vae_path),
            self._str_to_bytes(cfg.taesd_path),
            self._str_to_bytes(cfg.control_net_path),
            self._str_to_bytes(cfg.lora_model_dir),
            self._str_to_bytes(cfg.embed_dir),
            self._str_to_bytes(cfg.stacked_id_embed_dir),
            cfg.vae_decode_only,
            cfg.vae_tiling,
            cfg.free_params_immediately,
            cfg.n_threads,
            cfg.wtype,
            cfg.rng_type,
            cfg.schedule,
            cfg.keep_clip_on_cpu,
            cfg.keep_control_net_cpu,
            cfg.keep_vae_on_cpu,
        )
        # A NULL pointer is falsy. Without this check, a failed load would
        # only show up later as a crash inside txt2img.
        if not sd_ctx:
            raise ValueError(
                "Failed to create stable-diffusion.cpp context (new_sd_ctx returned NULL)"
            )
        return sd_ctx

    def _set_logcallback(self):
        """Register ``log_callback`` through ``sd_set_log_callback``.

        The ctypes callback object is stored on ``self``. ctypes does not keep
        callback objects alive, and the library must not call into one that
        has been garbage-collected.
        """
        print("Setting logging callback")

        SdLogCallbackType = ctypes.CFUNCTYPE(
            None,
            SDCPPLogLevel,
            ctypes.c_char_p,
            ctypes.c_void_p,
        )

        self.libsdcpp.sd_set_log_callback.argtypes = [
            SdLogCallbackType,
            ctypes.c_void_p,
        ]
        self.libsdcpp.sd_set_log_callback.restype = None

        self.c_log_callback = SdLogCallbackType(self.log_callback)
        self.libsdcpp.sd_set_log_callback(self.c_log_callback, None)

    def _get_sdcpp_shared_lib_path(
        self,
        root_path: str,
    ) -> str:
        """Return the path of this platform's shared library under ``root_path``.

        Returns an empty string, after printing a notice, on unrecognized
        platforms.
        """
        system_name = platform.system()
        print(f"GGUF Diffusion on {system_name}")
        lib_name = self._LIB_NAMES.get(system_name)
        if lib_name is None:
            print("Unknown platform.")
            return ""
        return path.join(root_path, lib_name)

    @staticmethod
    def log_callback(
        level,
        text,
        data,
    ):
        """Print a log message received from the library, without a newline.

        ``text`` is bytes, or None for a NULL pointer. Undecodable bytes are
        replaced so that no exception is raised inside the ctypes callback.
        """
        if text is None:
            return
        print(text.decode("utf-8", errors="replace"), end="")

    def _str_to_bytes(self, in_str: str, encoding: str = "utf-8") -> bytes:
        """Encode ``in_str``, mapping empty strings and None to ``b""``."""
        if in_str:
            return in_str.encode(encoding)
        return b""

    def generate_text2mg(self, txt2img_cfg: "Txt2ImgConfig") -> List[Any]:
        """Run text-to-image generation and return the results as PIL images.

        Args:
            txt2img_cfg: Prompt and sampling parameters, passed positionally to
                the library's ``txt2img``.

        Returns:
            ``batch_count`` PIL images, or an empty list if the library
            returned a NULL buffer.

        Raises:
            ValueError: If there is no native context (``terminate`` was
                called, or construction did not complete), or if an image has
                an unsupported channel count.
        """
        # Calling txt2img with a NULL context would crash the process.
        if getattr(self, "libsdcpp", None) is None or not getattr(
            self, "sd_ctx", None
        ):
            raise ValueError(
                "GGUF diffusion context is not initialized (was terminate() called?)"
            )

        self.libsdcpp.txt2img.restype = POINTER(SDImage)
        self.libsdcpp.txt2img.argtypes = [
            c_void_p,
            c_char_p,
            c_char_p,
            c_int,
            c_float,
            c_float,
            c_int,
            c_int,
            SampleMethod,
            c_int,
            c_int64,
            c_int,
            POINTER(SDImage),
            c_float,
            c_float,
            c_bool,
            c_char_p,
        ]

        image_buffer = self.libsdcpp.txt2img(
            self.sd_ctx,
            self._str_to_bytes(txt2img_cfg.prompt),
            self._str_to_bytes(txt2img_cfg.negative_prompt),
            txt2img_cfg.clip_skip,
            txt2img_cfg.cfg_scale,
            txt2img_cfg.guidance,
            txt2img_cfg.width,
            txt2img_cfg.height,
            txt2img_cfg.sample_method,
            txt2img_cfg.sample_steps,
            txt2img_cfg.seed,
            txt2img_cfg.batch_count,
            # NOTE(review): the POINTER(SDImage) argtype accepts only None or an
            # SDImage pointer, not a PIL image.
            txt2img_cfg.control_cond,
            txt2img_cfg.control_strength,
            txt2img_cfg.style_strength,
            txt2img_cfg.normalize_input,
            txt2img_cfg.input_id_images_path,
        )

        return self._get_sd_images_from_buffer(
            image_buffer,
            txt2img_cfg.batch_count,
        )

    def _get_sd_images_from_buffer(
        self,
        image_buffer: Any,
        batch_count: int,
    ) -> List[Any]:
        """Convert the first ``batch_count`` SDImage entries into PIL images.

        Returns an empty list when ``image_buffer`` is NULL or None.

        Raises:
            ValueError: If an image has a channel count other than 1, 3 or 4.
        """
        # NOTE(review): the buffer returned by txt2img is never freed here.
        # Confirm whether the library exposes a matching free function.
        images = []
        if not image_buffer:
            return images

        for i in range(batch_count):
            image = image_buffer[i]
            width, height, channels = image.width, image.height, image.channel
            print(f"Generated image: {width}x{height} with {channels} channels")

            mode = self._CHANNEL_MODES.get(channels)
            if mode is None:
                raise ValueError(f"Unsupported number of channels: {channels}")

            pixel_data = np.ctypeslib.as_array(
                image.data, shape=(height, width, channels)
            )
            if channels == 1:
                # Drop only the channel axis. squeeze() would also collapse a
                # height or width of 1 and give PIL the wrong shape.
                pixel_data = pixel_data[:, :, 0]
            images.append(Image.fromarray(pixel_data, mode=mode))
        return images

    def terminate(self):
        """Free the native context and release the library handle.

        Safe to call more than once, and on an instance whose construction did
        not complete.
        """
        lib = getattr(self, "libsdcpp", None)
        sd_ctx = getattr(self, "sd_ctx", None)
        if lib is not None and sd_ctx:
            lib.free_sd_ctx.argtypes = [c_void_p]
            lib.free_sd_ctx.restype = None
            lib.free_sd_ctx(sd_ctx)
        self.sd_ctx = None
        self.libsdcpp = None
| |
|