| """ | |
| Created By: ishwor subedi | |
| Date: 2024-08-13 | |
| """ | |
| import random | |
| import numpy as np | |
| import torch | |
| from typing import Tuple, List | |
| from diffusers import StableDiffusionXLPipeline | |
| from PIL import Image | |
| from src.utils.imutils import yaml_read | |
class ImageGenerator:
    """Text-to-image generator wrapping a Stable Diffusion XL pipeline, with named prompt styles.

    A style maps a name to a ``(prompt_template, negative_prompt)`` pair. The template's
    ``{prompt}`` placeholder is replaced with the user's prompt, and the style's negative
    prompt is combined with the user's negative prompt.
    """

    # Upper bound (inclusive) for randomly drawn seeds: the largest signed 32-bit integer.
    MAX_SEED = np.iinfo(np.int32).max
    # Style used when no style is requested or the requested name is unknown.
    DEFAULT_STYLE_NAME = "(No style)"
    # Last-resort style if DEFAULT_STYLE_NAME is itself absent from the style table:
    # keeps the prompt unchanged and contributes no negative prompt.
    _PASSTHROUGH_STYLE = ("{prompt}", "")

    def __init__(self, model_name: str = "RunDiffusion/Juggernaut-X-v10", device: str = "cuda",
                 styles_path: str = "params.yaml"):
        """Load the SDXL pipeline and the style table.

        Args:
            model_name: Model id or path given to ``StableDiffusionXLPipeline.from_pretrained``.
            device: Device the pipeline is moved to via ``pipe.to``.
            styles_path: Config path passed as-is to ``yaml_read``; the result must contain
                a ``style_list`` entry.
        """
        # Weights are loaded in float16. NOTE(review): half precision is generally only
        # practical on GPU devices -- confirm before using device="cpu".
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
        )
        self.pipe.to(device)
        self.styles = self._initialize_styles(styles_path)

    def _initialize_styles(self, styles_path: str = "params.yaml"):
        """Build the ``{name: (prompt_template, negative_prompt)}`` style table.

        Every item of the config's ``style_list`` must provide ``name``, ``prompt`` and
        ``negative_prompt`` keys; a missing key raises ``KeyError``.
        """
        style_list = yaml_read(styles_path)['style_list']
        return {entry["name"]: (entry["prompt"], entry["negative_prompt"]) for entry in style_list}

    def randomize_seed_fn(self, seed: int, randomize_seed: bool) -> int:
        """Return ``seed`` unchanged, or a fresh random seed in ``[0, MAX_SEED]`` if requested."""
        if randomize_seed:
            seed = random.randint(0, self.MAX_SEED)
        return seed

    def apply_style(self, style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
        """Apply a named style to a prompt pair.

        Args:
            style_name: Key into ``self.styles``. Unknown names fall back to
                ``DEFAULT_STYLE_NAME``; if that is also missing, the prompt passes through.
            positive: User prompt substituted for ``{prompt}`` in the style template.
            negative: User negative prompt; ``None`` or empty means "none".

        Returns:
            ``(styled_prompt, negative_prompt)``. The negative prompt is the style's
            negative prompt followed by the user's, joined with ``", "`` and with empty
            parts omitted.
        """
        # Avoid dict.get(key, self.styles[DEFAULT]) here: that default is evaluated
        # eagerly and would raise KeyError whenever DEFAULT_STYLE_NAME is absent,
        # even for a perfectly valid style_name.
        style = self.styles.get(style_name)
        if style is None:
            style = self.styles.get(self.DEFAULT_STYLE_NAME, self._PASSTHROUGH_STYLE)
        template, style_negative = style
        # Plain concatenation would fuse the last/first words of the two prompts;
        # filtering also tolerates a None/empty negative on either side.
        negative_parts = [part for part in (style_negative, negative) if part]
        # NOTE: a template lacking "{prompt}" drops the user's prompt entirely.
        return template.replace("{prompt}", positive), ", ".join(negative_parts)

    def generate_image(self, prompt: str,
                       negative_prompt: str = "",
                       style: str = None,
                       use_negative_prompt: bool = False,
                       num_inference_steps: int = 30,
                       num_images_per_prompt: int = 1,
                       seed: int = 0,
                       width: int = 1024,
                       height: int = 1024,
                       guidance_scale: float = 3,
                       randomize_seed: bool = False,
                       lora_scale: float = 0.65,
                       ) -> "Tuple[List[Image.Image], int]":
        """Generate images for a (styled) prompt.

        Args:
            prompt: User prompt.
            negative_prompt: User negative prompt; only used when ``use_negative_prompt``.
            style: Style name; ``None`` selects ``DEFAULT_STYLE_NAME``.
            use_negative_prompt: Whether to include ``negative_prompt``.
            num_inference_steps: Denoising steps passed to the pipeline.
            num_images_per_prompt: Number of images to generate.
            seed: Seed for the noise generator (ignored when ``randomize_seed``).
            width: Output width in pixels.
            height: Output height in pixels.
            guidance_scale: Classifier-free guidance scale passed to the pipeline.
            randomize_seed: Draw a fresh seed instead of using ``seed``.
            lora_scale: Value passed as ``cross_attention_kwargs["scale"]`` (the diffusers
                LoRA scale). NOTE: this class does not itself load any LoRA weights.

        Returns:
            ``(images, seed)``: the PIL images and the seed that produced them.
        """
        if style is None:
            style = self.DEFAULT_STYLE_NAME
        seed = self.randomize_seed_fn(seed, randomize_seed)
        if not use_negative_prompt:
            negative_prompt = ""
        prompt, negative_prompt = self.apply_style(style, prompt, negative_prompt)
        # Without an explicit generator the pipeline draws unseeded noise, so the returned
        # seed would not reproduce the images. A CPU generator yields the same noise
        # regardless of the device the pipeline runs on.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        images = self.pipe(prompt=prompt,
                           negative_prompt=negative_prompt,
                           width=width,
                           height=height,
                           guidance_scale=guidance_scale,
                           num_inference_steps=num_inference_steps,
                           num_images_per_prompt=num_images_per_prompt,
                           generator=generator,
                           cross_attention_kwargs={"scale": lora_scale},
                           output_type="pil").images
        return images, seed