# Source path: toc/ex/dynamic-prompts/sd_dynamic_prompts/generator_builder.py
# Uploaded by Toc ("Upload 1820 files"), revision 8da9c80.
from __future__ import annotations
import logging
from dynamicprompts.generators import (
BatchedCombinatorialPromptGenerator,
CombinatorialPromptGenerator,
DummyGenerator,
FeelingLuckyGenerator,
JinjaGenerator,
PromptGenerator,
RandomPromptGenerator,
)
from dynamicprompts.parser.parse import default_parser_config
from sd_dynamic_prompts.frozenprompt_generator import FrozenPromptGenerator
# Module logger; its name is this module's ``__name__`` so output can be
# filtered or configured per module.
logger = logging.getLogger(name=__name__)
class GeneratorBuilder:
    """Fluent builder that assembles a dynamicprompts ``PromptGenerator``.

    Configure it with the chainable ``set_*`` methods (each returns ``self``)
    and then call :meth:`create_generator`. The base generator is chosen in
    priority order dummy > feeling-lucky > Jinja > basic (combinatorial or
    random). Unless the dummy generator was chosen, the result is then
    optionally wrapped, in this order, by the magic-prompt, attention-grabber
    and prompt-freezing generators.
    """

    def __init__(
        self,
        wildcard_manager,
        parser_config=default_parser_config,
        ignore_whitespace=False,
    ):
        """Create a builder with every optional feature switched off.

        Args:
            wildcard_manager: Forwarded to the combinatorial, random and Jinja
                generators.
            parser_config: Forwarded to the combinatorial and random
                generators.
            ignore_whitespace: Forwarded to the combinatorial, random and
                Jinja generators.
        """
        self._wildcard_manager = wildcard_manager
        self._parser_config = parser_config
        self._ignore_whitespace = ignore_whitespace

        # Which base generator to build and which wrappers to apply.
        self._is_dummy = False
        self._is_feeling_lucky = False
        self._is_jinja_template = False
        self._is_combinatorial = False
        self._is_magic_prompt = False
        self._is_attention_grabber = False
        self._should_freeze_prompt = False

        # Jinja settings. Initialised here (default matches
        # set_is_jinja_template) so create_jinja_generator does not raise
        # AttributeError when the setter was never called.
        self._limit_jinja_prompts = False

        # Combinatorial settings.
        self._combinatorial_batches = 1

        # Magic-prompt settings. Defaults match set_is_magic_prompt, including
        # the batch size, which was previously only assigned by that setter.
        self._magic_model = None
        self._magic_prompt_length = 100
        self._magic_temp_value = 0.7
        self._magic_blocklist_regex = None
        self._magic_batch_size = 1
        self._device = 0

        # Attention-grabber settings.
        self._min_attention = 1.1
        self._max_attention = 1.5

        # Seeding, plus the processing object consumed by the Jinja generator.
        self._unlink_seed_from_prompt = False
        self._seed = -1
        self._context = None

    def log_configuration(self):
        """Log the current builder settings at DEBUG level."""
        # Skip building the (fairly large) message when DEBUG is disabled.
        if not logger.isEnabledFor(logging.DEBUG):
            return
        lines = (
            "",
            "Creating generator:",
            f"is_dummy: {self._is_dummy}",
            f"is_feeling_lucky: {self._is_feeling_lucky}",
            f"enable_jinja_templates: {self._is_jinja_template}",
            f"is_combinatorial: {self._is_combinatorial}",
            f"is_magic_prompt: {self._is_magic_prompt}",
            f"combinatorial_batches: {self._combinatorial_batches}",
            f"magic_prompt_length: {self._magic_prompt_length}",
            f"magic_temp_value: {self._magic_temp_value}",
            f"magic_blocklist_regex: {self._magic_blocklist_regex}",
            f"is_attention_grabber: {self._is_attention_grabber}",
            f"min_attention: {self._min_attention}",
            f"max_attention: {self._max_attention}",
            "",
        )
        logger.debug("\n".join(lines))

    def set_is_dummy(self, is_dummy=True):
        """Enable/disable the dummy generator (takes precedence over all)."""
        self._is_dummy = is_dummy
        return self

    def set_is_feeling_lucky(self, is_feeling_lucky=True):
        """Enable/disable the "I'm feeling lucky" base generator."""
        self._is_feeling_lucky = is_feeling_lucky
        return self

    def set_is_attention_grabber(
        self,
        is_attention_grabber=True,
        min_attention=1.1,
        max_attention=1.5,
    ):
        """Enable/disable the attention wrapper and set its attention range.

        The min/max values are always overwritten, so calling this with only
        ``False`` also resets them to their defaults.
        """
        self._is_attention_grabber = is_attention_grabber
        self._min_attention = min_attention
        self._max_attention = max_attention
        return self

    def set_is_jinja_template(self, is_jinja_template=True, limit_prompts=False):
        """Enable/disable Jinja templating.

        Args:
            is_jinja_template: Whether the Jinja generator is the base
                generator (ignored if dummy or feeling-lucky is enabled).
            limit_prompts: Forwarded to ``JinjaGenerator`` as
                ``limit_prompts``.
        """
        self._is_jinja_template = is_jinja_template
        self._limit_jinja_prompts = limit_prompts
        return self

    def set_is_combinatorial(self, is_combinatorial=True, combinatorial_batches=1):
        """Choose combinatorial (vs. random) basic generation.

        Args:
            is_combinatorial: Use the combinatorial generator when building
                the basic generator.
            combinatorial_batches: Forwarded as ``batches`` to
                ``BatchedCombinatorialPromptGenerator``.
        """
        self._is_combinatorial = is_combinatorial
        self._combinatorial_batches = combinatorial_batches
        return self

    def set_is_magic_prompt(
        self,
        is_magic_prompt=True,
        magic_model=None,
        magic_prompt_length=100,
        magic_temp_value=0.7,
        device=0,
        magic_blocklist_regex: str | None = None,
        batch_size=1,
    ):
        """Configure the MagicPrompt wrapper.

        Without a truthy ``magic_model`` the feature is switched off
        regardless of ``is_magic_prompt``; the remaining arguments are then
        ignored and any previously stored magic-prompt settings are kept.

        Args:
            is_magic_prompt: Whether to wrap the generator with MagicPrompt.
            magic_model: Forwarded as ``model_name``.
            magic_prompt_length: Forwarded as ``max_prompt_length``.
            magic_temp_value: Forwarded as ``temperature``.
            device: Forwarded as ``device``.
            magic_blocklist_regex: Forwarded as ``blocklist_regex``.
            batch_size: Forwarded as ``batch_size``.
        """
        if not magic_model:
            self._is_magic_prompt = False
            return self
        self._magic_model = magic_model
        self._magic_prompt_length = magic_prompt_length
        self._magic_temp_value = magic_temp_value
        self._magic_blocklist_regex = magic_blocklist_regex
        self._is_magic_prompt = is_magic_prompt
        self._magic_batch_size = batch_size
        self._device = device
        return self

    def set_unlink_seed_from_prompt(self, unlink_seed_from_prompt=True):
        """Set the ``unlink_seed_from_prompt`` flag for the random generator."""
        self._unlink_seed_from_prompt = unlink_seed_from_prompt
        return self

    def set_seed(self, seed):
        """Set the seed passed to the random and magic-prompt generators."""
        self._seed = seed
        return self

    def set_freeze_prompt(self, should_freeze: bool):
        """Enable/disable wrapping the result in ``FrozenPromptGenerator``."""
        self._should_freeze_prompt = should_freeze
        return self

    def set_context(self, context):
        """Store the processing object later passed to the Jinja generator."""
        self._context = context
        return self

    def disable_prompt_magic(self):
        """Turn off attention grabbing, magic prompt and feeling-lucky.

        Note that this also resets the attention range to its defaults.
        """
        self.set_is_attention_grabber(False)
        self.set_is_magic_prompt(False)
        self.set_is_feeling_lucky(False)
        return self

    def create_generator(self):
        """Build the configured generator pipeline.

        Returns:
            A ``DummyGenerator`` if dummy mode is on (no wrappers applied);
            otherwise the selected base generator, optionally wrapped by the
            magic-prompt, attention and frozen-prompt generators.
        """
        if self._is_dummy:
            return DummyGenerator()
        elif self._is_feeling_lucky:
            generator = FeelingLuckyGenerator()
        elif self._is_jinja_template:
            # NOTE(review): requires set_context() to have been called with a
            # processing object; a None context fails inside
            # create_jinja_generator.
            generator = self.create_jinja_generator(self._context)
        else:
            generator = self.create_basic_generator()

        if self._is_magic_prompt:
            # Imported lazily so the MagicPrompt dependency is only loaded
            # when the feature is actually used.
            from dynamicprompts.generators.magicprompt import MagicPromptGenerator

            generator = MagicPromptGenerator(
                generator,
                model_name=self._magic_model,
                device=self._device,
                max_prompt_length=self._magic_prompt_length,
                temperature=self._magic_temp_value,
                seed=self._seed,
                blocklist_regex=self._magic_blocklist_regex,
                batch_size=self._magic_batch_size,
            )

        if self._is_attention_grabber:
            try:
                from sd_dynamic_prompts.attention_generator import (
                    SpecialSyntaxAwareAttentionGenerator,
                )

                generator = SpecialSyntaxAwareAttentionGenerator(
                    generator,
                    min_attention=self._min_attention,
                    max_attention=self._max_attention,
                )
            except ImportError as ie:
                # Best effort: keep the unwrapped generator if the attention
                # module cannot be imported.
                logger.error("Not using AttentionGenerator: %s", ie)

        if self._should_freeze_prompt:
            generator = FrozenPromptGenerator(generator)

        return generator

    def create_basic_generator(
        self,
    ) -> PromptGenerator:
        """Build a combinatorial or random generator from the current settings.

        Returns:
            A ``BatchedCombinatorialPromptGenerator`` wrapping a
            ``CombinatorialPromptGenerator`` when combinatorial mode is on,
            otherwise a ``RandomPromptGenerator`` seeded with the configured
            seed.
        """
        if self._is_combinatorial:
            combinatorial = CombinatorialPromptGenerator(
                self._wildcard_manager,
                parser_config=self._parser_config,
                ignore_whitespace=self._ignore_whitespace,
            )
            return BatchedCombinatorialPromptGenerator(
                combinatorial,
                batches=self._combinatorial_batches,
            )
        return RandomPromptGenerator(
            self._wildcard_manager,
            seed=self._seed,
            parser_config=self._parser_config,
            unlink_seed_from_prompt=self._unlink_seed_from_prompt,
            ignore_whitespace=self._ignore_whitespace,
        )

    def create_jinja_generator(self, p) -> PromptGenerator:
        """Build a ``JinjaGenerator`` whose template context describes ``p``.

        Args:
            p: Processing object (presumably an AUTOMATIC1111
                ``StableDiffusionProcessing`` -- confirm against caller). It
                must expose ``all_prompts``, ``prompt``,
                ``all_negative_prompts``, ``negative_prompt``,
                ``sd_model.sd_checkpoint_info``, ``width``, ``height``,
                ``steps``, ``batch_size``, ``n_iter``, ``cfg_scale``,
                ``sampler_name`` and ``seed``.

        Returns:
            A ``JinjaGenerator`` with ``model``, ``image``, ``parameters`` and
            ``prompt`` entries in its context.
        """
        # Truthiness (rather than len() > 0) so that a None value, not just
        # an empty list, also falls back to the raw prompt text instead of
        # raising TypeError.
        original_prompt = p.all_prompts[0] if p.all_prompts else p.prompt
        original_negative_prompt = (
            p.all_negative_prompts[0]
            if p.all_negative_prompts
            else p.negative_prompt
        )

        checkpoint = p.sd_model.sd_checkpoint_info
        context = {
            "model": {
                "filename": checkpoint.filename,
                "title": checkpoint.title,
                "hash": checkpoint.hash,
                "model_name": checkpoint.model_name,
            },
            "image": {
                "width": p.width,
                "height": p.height,
            },
            "parameters": {
                "steps": p.steps,
                "batch_size": p.batch_size,
                "num_batches": p.n_iter,
                "width": p.width,
                "height": p.height,
                "cfg_scale": p.cfg_scale,
                "sampler_name": p.sampler_name,
                "seed": p.seed,
            },
            "prompt": {
                "prompt": original_prompt,
                "negative_prompt": original_negative_prompt,
            },
        }

        return JinjaGenerator(
            self._wildcard_manager,
            context,
            limit_prompts=self._limit_jinja_prompts,
            ignore_whitespace=self._ignore_whitespace,
        )