Spaces:
Paused
Paused
File size: 7,280 Bytes
8822914 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import gc
import os
from collections import OrderedDict
from typing import ForwardRef, List, Optional, Union
import torch
from safetensors.torch import save_file, load_file
from jobs.process.BaseProcess import BaseProcess
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
add_base_model_info_to_meta
from toolkit.sampler import get_sampler
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
import random
from toolkit.util.get_model import get_model_class
class GenerateConfig:
def __init__(self, **kwargs):
self.prompts: List[str]
self.sampler = kwargs.get('sampler', 'ddpm')
self.width = kwargs.get('width', 512)
self.height = kwargs.get('height', 512)
self.size_list: Union[List[int], None] = kwargs.get('size_list', None)
self.neg = kwargs.get('neg', '')
self.seed = kwargs.get('seed', -1)
self.guidance_scale = kwargs.get('guidance_scale', 7)
self.sample_steps = kwargs.get('sample_steps', 20)
self.prompt_2 = kwargs.get('prompt_2', None)
self.neg_2 = kwargs.get('neg_2', None)
self.prompts = kwargs.get('prompts', None)
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.compile = kwargs.get('compile', False)
self.ext = kwargs.get('ext', 'png')
self.prompt_file = kwargs.get('prompt_file', False)
self.num_repeats = kwargs.get('num_repeats', 1)
self.prompts_in_file = self.prompts
if self.prompts is None:
raise ValueError("Prompts must be set")
if isinstance(self.prompts, str):
if os.path.exists(self.prompts):
with open(self.prompts, 'r', encoding='utf-8') as f:
self.prompts_in_file = f.read().splitlines()
self.prompts_in_file = [p.strip() for p in self.prompts_in_file if len(p.strip()) > 0]
else:
raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts")
self.random_prompts = kwargs.get('random_prompts', False)
self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1)
self.max_images = kwargs.get('max_images', 10000)
if self.random_prompts:
self.prompts = []
for i in range(self.max_images):
num_prompts = random.randint(1, self.max_random_per_prompt)
prompt_list = [random.choice(self.prompts_in_file) for _ in range(num_prompts)]
self.prompts.append(", ".join(prompt_list))
else:
self.prompts = self.prompts_in_file
if kwargs.get('shuffle', False):
# shuffle the prompts
random.shuffle(self.prompts)
class GenerateProcess(BaseProcess):
process_id: int
config: OrderedDict
progress_bar: ForwardRef('tqdm') = None
sd: StableDiffusion
def __init__(
self,
process_id: int,
job,
config: OrderedDict
):
super().__init__(process_id, job, config)
self.output_folder = self.get_conf('output_folder', required=True)
self.model_config = ModelConfig(**self.get_conf('model', required=True))
self.device = self.get_conf('device', self.job.device)
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16'))
self.progress_bar = None
ModelClass = get_model_class(self.model_config)
# if the model class has get_train_scheduler static method
if hasattr(ModelClass, 'get_train_scheduler'):
sampler = ModelClass.get_train_scheduler()
else:
# get the noise scheduler
arch = 'sd'
if self.model_config.is_pixart:
arch = 'pixart'
if self.model_config.is_flux:
arch = 'flux'
if self.model_config.is_lumina2:
arch = 'lumina2'
sampler = get_sampler(
self.train_config.noise_scheduler,
{
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
},
arch=arch,
)
self.sd = ModelClass(
device=self.device,
model_config=self.model_config,
dtype=self.model_config.dtype,
noise_scheduler=sampler,
)
print(f"Using device {self.device}")
def clean_prompt(self, prompt: str):
# remove any non alpha numeric characters or ,'" from prompt
return ''.join(e for e in prompt if e.isalnum() or e in ", '\"")
def run(self):
with torch.no_grad():
super().run()
print("Loading model...")
self.sd.load_model()
self.sd.pipeline.to(self.device, self.torch_dtype)
print("Compiling model...")
# self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
if self.generate_config.compile:
self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
print(f"Generating {len(self.generate_config.prompts)} images")
# build prompt image configs
prompt_image_configs = []
for _ in range(self.generate_config.num_repeats):
for prompt in self.generate_config.prompts:
# remove --
prompt = prompt.replace('--', '').strip()
width = self.generate_config.width
height = self.generate_config.height
# prompt = self.clean_prompt(prompt)
if self.generate_config.size_list is not None:
# randomly select a size
width, height = random.choice(self.generate_config.size_list)
prompt_image_configs.append(GenerateImageConfig(
prompt=prompt,
prompt_2=self.generate_config.prompt_2,
width=width,
height=height,
num_inference_steps=self.generate_config.sample_steps,
guidance_scale=self.generate_config.guidance_scale,
negative_prompt=self.generate_config.neg,
negative_prompt_2=self.generate_config.neg_2,
seed=self.generate_config.seed,
guidance_rescale=self.generate_config.guidance_rescale,
output_ext=self.generate_config.ext,
output_folder=self.output_folder,
add_prompt_file=self.generate_config.prompt_file
))
# generate images
self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)
print("Done generating images")
# cleanup
del self.sd
gc.collect()
torch.cuda.empty_cache()
|