# Author: ishworrsubedii
# Commit 32a0eda: Added new features and improved code formatting
"""
Created By: ishwor subedi
Date: 2024-08-13
"""
import random
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image

from src.utils.imutils import yaml_read
class ImageGenerator:
    """Generate images from text prompts with a Stable Diffusion XL pipeline.

    Named prompt "styles" are loaded from a YAML file. Each style supplies a
    prompt template containing ``{prompt}`` and a negative prompt, which are
    combined with the user's input before the pipeline is called.
    """

    # Style used when no style, or an unknown style, is requested. It must be
    # one of the names in the loaded style list.
    DEFAULT_STYLE_NAME = "(No style)"
    # Upper bound (inclusive) for randomly drawn seeds.
    MAX_SEED = np.iinfo(np.int32).max

    def __init__(self, model_name: str = "RunDiffusion/Juggernaut-X-v10", device: str = "cuda",
                 styles_path: str = "params.yaml"):
        """Load the SDXL pipeline onto *device* and read the prompt styles.

        Args:
            model_name: Model id or path passed to
                ``StableDiffusionXLPipeline.from_pretrained``.
            device: Torch device string the pipeline is moved to
                (e.g. ``"cuda"``, ``"cuda:1"``, ``"cpu"``).
            styles_path: YAML file whose ``style_list`` entry defines the styles.
        """
        # float16 halves memory on accelerators but is poorly supported on CPU,
        # so use float32 there instead.
        dtype = torch.float32 if torch.device(device).type == "cpu" else torch.float16
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_name,
            torch_dtype=dtype,
        )
        self.pipe.to(device)
        self.styles = self._initialize_styles(styles_path)

    def _initialize_styles(self, path: str = "params.yaml") -> Dict[str, Tuple[str, str]]:
        """Return a mapping ``style name -> (prompt template, negative prompt)``.

        Each entry of the YAML ``style_list`` must provide the keys ``name``,
        ``prompt`` and ``negative_prompt``.
        """
        style_list = yaml_read(path)['style_list']
        return {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}

    def randomize_seed_fn(self, seed: int, randomize_seed: bool) -> int:
        """Return a fresh random seed in ``[0, MAX_SEED]`` if *randomize_seed*, else *seed*."""
        if randomize_seed:
            seed = random.randint(0, self.MAX_SEED)
        return seed

    def apply_style(self, style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
        """Apply a named style to a prompt pair.

        Args:
            style_name: Key into ``self.styles``. Unknown names fall back to
                ``DEFAULT_STYLE_NAME``.
            positive: User prompt, substituted for ``{prompt}`` in the template.
            negative: User negative prompt. Falsy values (including ``None``)
                are treated as empty.

        Returns:
            ``(styled_prompt, combined_negative_prompt)``. The style's negative
            prompt comes first, then the user's, separated by ``", "``.

        Raises:
            KeyError: If *style_name* is unknown and the default style is missing.
        """
        template = self.styles.get(style_name)
        if template is None:
            # Only require the default style when it is actually needed.
            template = self.styles[self.DEFAULT_STYLE_NAME]
        style_positive, style_negative = template
        # Join with a separator: plain concatenation fused the last style token
        # with the first user token (e.g. "lowres" + "bad hands" -> "lowresbad hands").
        pieces = (piece.strip(" ,") for piece in (style_negative, negative or ""))
        combined_negative = ", ".join(piece for piece in pieces if piece)
        return style_positive.replace("{prompt}", positive), combined_negative

    def generate_image(self, prompt: str,
                       negative_prompt: str = "",
                       style: Optional[str] = None,
                       use_negative_prompt: bool = False,
                       num_inference_steps: int = 30,
                       num_images_per_prompt: int = 1,
                       seed: int = 0,
                       width: int = 1024,
                       height: int = 1024,
                       guidance_scale: float = 3,
                       randomize_seed: bool = False,
                       ) -> "Tuple[List[Image.Image], int]":
        """Generate images for *prompt* and return them with the seed used.

        Args:
            prompt: User prompt, inserted into the selected style's template.
            negative_prompt: User negative prompt. It is ignored unless
                *use_negative_prompt* is true.
            style: Style name. ``None`` selects ``DEFAULT_STYLE_NAME``.
            use_negative_prompt: Whether to include *negative_prompt*.
            num_inference_steps: Denoising steps passed to the pipeline.
            num_images_per_prompt: Number of images to generate.
            seed: Seed for the sampler. It is replaced by a random one if
                *randomize_seed* is true.
            width: Output width in pixels, passed through to the pipeline.
            height: Output height in pixels, passed through to the pipeline.
            guidance_scale: Classifier-free guidance scale.
            randomize_seed: Draw a new random seed instead of using *seed*.

        Returns:
            ``(images, seed)``. Passing the returned seed back in (with
            ``randomize_seed=False`` and the same other arguments) reproduces
            the images.
        """
        if style is None:
            style = self.DEFAULT_STYLE_NAME
        seed = self.randomize_seed_fn(seed, randomize_seed)
        if not use_negative_prompt:
            negative_prompt = ""
        prompt, negative_prompt = self.apply_style(style, prompt, negative_prompt)
        # Seed the sampler; previously the seed was returned but never used, so
        # it could not reproduce anything. diffusers recommends a CPU generator
        # for results that are reproducible across devices.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        images = self.pipe(prompt=prompt,
                           negative_prompt=negative_prompt,
                           width=width,
                           height=height,
                           guidance_scale=guidance_scale,
                           num_inference_steps=num_inference_steps,
                           num_images_per_prompt=num_images_per_prompt,
                           generator=generator,
                           # NOTE(review): diffusers reads "scale" here as a LoRA
                           # scale, but no LoRA weights are loaded in this class.
                           # Confirm this is intended.
                           cross_attention_kwargs={"scale": 0.65},
                           output_type="pil").images
        return images, seed