File size: 2,712 Bytes
32a0eda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
Created By: ishwor subedi
Date: 2024-08-13
"""
import random
import numpy as np
import torch
from typing import Tuple, List
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from src.utils.imutils import yaml_read


class ImageGenerator:
    """Text-to-image generator wrapping a Stable Diffusion XL pipeline with named prompt styles.

    Styles are read from a YAML file whose ``style_list`` entries each provide a
    ``name``, a ``prompt`` template (containing a ``{prompt}`` placeholder) and a
    ``negative_prompt``.
    """

    # Upper bound (inclusive) for seeds drawn by ``randomize_seed_fn``: max of a signed 32-bit int.
    MAX_SEED = np.iinfo(np.int32).max
    # Style used when none is requested or the requested name is unknown.
    # NOTE(review): must be present in the style file for unknown-style fallback to work.
    DEFAULT_STYLE_NAME = "(No style)"

    def __init__(self, model_name: str = "RunDiffusion/Juggernaut-X-v10", device: str = "cuda",
                 styles_path: str = "params.yaml"):
        """Load the SDXL pipeline in float16, move it to ``device`` and read the style table.

        Args:
            model_name: Model id or local path passed to ``StableDiffusionXLPipeline.from_pretrained``.
            device: torch device string the pipeline is moved to.
            styles_path: YAML file containing the ``style_list`` key (resolved relative to the CWD).
        """
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
        )
        self.pipe.to(device)
        self.styles = self._initialize_styles(styles_path)

    def _initialize_styles(self, path: str = "params.yaml"):
        """Return a dict mapping style name -> (prompt template, negative prompt) read from ``path``."""
        style_list = yaml_read(path)["style_list"]
        return {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}

    def randomize_seed_fn(self, seed: int, randomize_seed: bool) -> int:
        """Return a fresh random seed in ``[0, MAX_SEED]`` if ``randomize_seed``, else ``seed`` unchanged."""
        if randomize_seed:
            seed = random.randint(0, self.MAX_SEED)
        return seed

    def apply_style(self, style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
        """Apply a named style to a prompt pair.

        Args:
            style_name: Key into ``self.styles``; unknown names fall back to ``DEFAULT_STYLE_NAME``.
            positive: User prompt substituted for ``{prompt}`` in the style template.
            negative: Extra user negative prompt; ``None``/empty means "none".

        Returns:
            ``(styled_prompt, negative_prompt)``, where the negative prompt is the style's
            negative prompt and the user's joined by ``", "`` (empty parts are skipped).

        Raises:
            KeyError: if ``style_name`` is unknown and ``DEFAULT_STYLE_NAME`` is missing too.
        """
        # Look the default up lazily so known styles work even if the default entry is absent.
        if style_name in self.styles:
            template, style_negative = self.styles[style_name]
        else:
            template, style_negative = self.styles[self.DEFAULT_STYLE_NAME]
        # Separate the two parts so the last style term and first user term don't fuse
        # into a single token; also tolerates a None negative from YAML or the caller.
        combined_negative = ", ".join(part for part in (style_negative, negative) if part)
        return template.replace("{prompt}", positive), combined_negative

    def generate_image(self, prompt: str,
                       negative_prompt: str = "",
                       style: str = None,
                       use_negative_prompt: bool = False,
                       num_inference_steps: int = 30,
                       num_images_per_prompt: int = 1,
                       seed: int = 0,
                       width: int = 1024,
                       height: int = 1024,
                       guidance_scale: float = 3,
                       randomize_seed: bool = False,
                       cross_attention_scale: float = 0.65,
                       ) -> Tuple[List[Image.Image], int]:
        """Generate images for ``prompt`` using the given style and sampling settings.

        Args:
            prompt: User prompt, inserted into the style template.
            negative_prompt: User negative prompt; only used when ``use_negative_prompt`` is True.
            style: Style name; ``None`` selects ``DEFAULT_STYLE_NAME``.
            use_negative_prompt: Whether to include ``negative_prompt``.
            num_inference_steps: Denoising steps passed to the pipeline.
            num_images_per_prompt: Number of images to generate.
            seed: Seed for the noise generator (ignored if ``randomize_seed``).
            width: Output width in pixels, passed to the pipeline.
            height: Output height in pixels, passed to the pipeline.
            guidance_scale: Classifier-free guidance scale passed to the pipeline.
            randomize_seed: Draw a random seed instead of using ``seed``.
            cross_attention_scale: Value sent as ``cross_attention_kwargs["scale"]``
                (diffusers uses it as the LoRA scale).

        Returns:
            ``(images, seed)``. ``seed`` is the seed actually used, so the same call with
            that seed reproduces the images.
        """
        if style is None:
            style = self.DEFAULT_STYLE_NAME
        seed = self.randomize_seed_fn(seed, randomize_seed)
        if not use_negative_prompt:
            negative_prompt = ""
        prompt, negative_prompt = self.apply_style(style, prompt, negative_prompt)

        # Without an explicit generator the pipeline samples unseeded noise, and the
        # returned seed would be meaningless. A CPU generator keeps results reproducible
        # independent of the GPU the pipeline runs on.
        generator = torch.Generator().manual_seed(seed)

        images = self.pipe(prompt=prompt,
                           negative_prompt=negative_prompt,
                           width=width,
                           height=height,
                           guidance_scale=guidance_scale,
                           num_inference_steps=num_inference_steps,
                           num_images_per_prompt=num_images_per_prompt,
                           generator=generator,
                           cross_attention_kwargs={"scale": cross_attention_scale},
                           output_type="pil").images

        return images, seed