|
|
from PIL import Image |
|
|
|
|
|
from modules import shared |
|
|
from modules import devices |
|
|
from modules import images |
|
|
|
|
|
from sd_bmab import util |
|
|
from sd_bmab import constants |
|
|
from sd_bmab.base import filter |
|
|
from sd_bmab.util import debug_print |
|
|
from sd_bmab.base import process_txt2img, process_img2img_with_controlnet, Context, ProcessorBase |
|
|
|
|
|
|
|
|
class ResamplePreprocessor(ProcessorBase):
    """Re-generate an image guided by a ControlNet ``tile_resample`` unit.

    Settings are loaded from ``context.args['module_config']['resample_opt']``
    by :meth:`preprocess`. :meth:`process` then renders a new image with
    txt2img or img2img (selected by ``method``), conditioned on the input
    image through a ControlNet unit, and passes the result through the
    configured BMAB filter.

    When both resampling and its hires-fix option are enabled,
    :meth:`preprocess` returns False, so this class does not handle that case.
    """

    # (attribute, resample_opt key) pairs loaded by preprocess(). An attribute
    # keeps its current value when its key is absent from resample_opt.
    # NOTE(review): the ControlNet weight / guidance start / guidance end are
    # read from the 'scale', 'width' and 'height' keys. These names look reused
    # from another form; confirm against whatever writes resample_opt before
    # renaming anything.
    _OPTION_KEYS = (
        ('hiresfix_enabled', 'hiresfix_enabled'),
        ('save_image', 'save_image'),
        ('method', 'method'),
        ('checkpoint', 'checkpoint'),
        ('vae', 'vae'),
        ('filter', 'filter'),
        ('prompt', 'prompt'),
        ('negative_prompt', 'negative_prompt'),
        ('sampler', 'sampler'),
        ('scheduler', 'scheduler'),
        ('upscaler', 'upscaler'),
        ('steps', 'steps'),
        ('cfg_scale', 'cfg_scale'),
        ('denoising_strength', 'denoising_strength'),
        ('strength', 'scale'),
        ('begin', 'width'),
        ('end', 'height'),
    )

    def __init__(self) -> None:
        super().__init__()

        self.resample_opt = {}
        self.enabled = False
        self.hiresfix_enabled = False
        self.save_image = False
        self.method = 'txt2img-1pass'
        self.checkpoint = constants.checkpoint_default
        self.vae = constants.vae_default
        self.filter = 'None'
        self.prompt = None
        self.negative_prompt = None
        self.sampler = None
        self.scheduler = None
        self.upscaler = None
        self.steps = 20
        # NOTE(review): 0.7 is an unusually low CFG scale. It only applies when
        # resample_opt has no 'cfg_scale' entry; confirm whether 7 was intended.
        self.cfg_scale = 0.7
        self.denoising_strength = 0.75
        self.strength = 0.5  # ControlNet unit weight
        self.begin = 0.0  # ControlNet guidance_start
        self.end = 1.0  # ControlNet guidance_end

        self.base_sd_model = None

    def use_controlnet(self, context: Context):
        """Return whether this processor will run (and therefore use ControlNet)."""
        return self.preprocess(context, None)

    def preprocess(self, context: Context, image: Image.Image):
        """Load resample settings from *context* and decide whether to run.

        Args:
            context: Processing context; ``context.args['resample_enabled']``
                must be present.
            image: Unused here (``use_controlnet`` passes None).

        Returns:
            False when resampling and its hires-fix option are both enabled;
            otherwise the value of ``context.args['resample_enabled']``.
        """
        self.enabled = context.args['resample_enabled']
        # Tolerate a missing *or* explicitly-None config section.
        module_config = context.args.get('module_config') or {}
        self.resample_opt = module_config.get('resample_opt') or {}

        for attr, key in self._OPTION_KEYS:
            setattr(self, attr, self.resample_opt.get(key, getattr(self, attr)))

        if self.enabled and self.hiresfix_enabled:
            return False
        return self.enabled

    @staticmethod
    def get_resample_args(image, weight, begin, end):
        """Build the ControlNet unit arguments for a tile_resample pass.

        Args:
            image: PIL image used as the control image (converted to RGB and
                base64-encoded via ``util.b64_encoding``).
            weight: ControlNet unit weight.
            begin: ``guidance_start`` value.
            end: ``guidance_end`` value.

        Returns:
            A dict of ControlNet unit settings. The model name comes from the
            ``bmab_cn_tile_resample`` option.
        """
        cn_args = {
            'enabled': True,
            'image': util.b64_encoding(image.convert('RGB')),
            'module': 'tile_resample',
            'model': shared.opts.bmab_cn_tile_resample,
            'weight': weight,
            "guidance_start": begin,
            "guidance_end": end,
            'resize_mode': 'Just Resize',
            'pixel_perfect': False,
            'control_mode': 'ControlNet is more important',
            'processor_res': 512,
            'threshold_a': 1,
            'threshold_b': 1,
        }
        return cn_args

    def process(self, context: Context, image: Image.Image):
        """Resample *image* and return the filtered result.

        Optionally saves a '-before-resample' copy first. Settings are
        resolved into local values, so the options loaded by preprocess()
        are not overwritten and nothing carries over from one image to the
        next. If ``method`` matches no known mode, the filter receives an
        unmodified copy of *image*.
        """
        bmab_filter = filter.get_filter(self.filter)
        options = self._build_options(context)

        filter.preprocess_filter(bmab_filter, context, image, options)

        context.add_job()
        if self.save_image:
            self._save_before_image(context, image)

        cn_op_arg = self.get_resample_args(image, self.strength, self.begin, self.end)
        processed = self._run_resample(context, image, options, cn_op_arg)

        image = filter.process_filter(bmab_filter, context, image, processed)
        filter.postprocess_filter(bmab_filter, context)
        return image

    def _build_options(self, context: Context) -> dict:
        """Build generation options, filling 'inherit' settings from *context*."""
        p = context.sdprocessing

        # An empty prompt means "use the current image's prompt". Otherwise
        # the '#!org!#' placeholder is replaced by that prompt. None is treated
        # as empty so a missing 'prompt' key does not crash.
        prompt = self.prompt or ''
        if prompt == '':
            prompt = context.get_prompt_by_index()
            debug_print('prompt', prompt)
        elif '#!org!#' in prompt:
            prompt = prompt.replace('#!org!#', context.get_prompt_by_index())
            debug_print('Prompt', prompt)

        negative_prompt = self.negative_prompt or p.negative_prompt

        sampler = self.sampler
        if sampler is None or sampler == constants.sampler_default:
            sampler = p.sampler_name

        scheduler = self.scheduler
        if scheduler is None or scheduler == constants.scheduler_default:
            scheduler = util.get_scheduler(p)

        seed, subseed = context.get_seeds()
        options = dict(
            seed=seed, subseed=subseed,
            denoising_strength=self.denoising_strength,
            prompt=prompt,
            negative_prompt=negative_prompt,
            sampler_name=sampler,
            scheduler=scheduler,
            steps=self.steps,
            cfg_scale=self.cfg_scale,
        )

        # Override settings only for an explicit selection. The default
        # checkpoint means "keep the model the processing already uses", so no
        # 'sd_model_checkpoint' override is added for it. That webui setting
        # expects a checkpoint name, not a model object.
        override_settings = {}
        if self.checkpoint not in (None, constants.checkpoint_default):
            override_settings['sd_model_checkpoint'] = self.checkpoint
        if self.vae not in (None, constants.vae_default):
            override_settings['sd_vae'] = self.vae
        if override_settings:
            options['override_settings'] = override_settings
        return options

    @staticmethod
    def _save_before_image(context: Context, image: Image.Image):
        """Save a '-before-resample' copy of *image* and attach it as an extra image."""
        p = context.sdprocessing
        saved = image.copy()
        images.save_image(
            saved, p.outpath_samples, '',
            p.all_seeds[context.index], p.all_prompts[context.index],
            shared.opts.samples_format, p=p, suffix='-before-resample')
        context.add_extra_image(saved)

    @staticmethod
    def _without_control_image(cn_args: dict) -> dict:
        """Return a copy of *cn_args* with the control-image entry removed.

        Used for img2img, where the source image is passed to the img2img
        call itself (presumably so ControlNet falls back to that input).
        Both the current 'image' key and the legacy 'input_image' key are
        dropped if present, so this never raises KeyError.
        """
        stripped = dict(cn_args)
        stripped.pop('image', None)
        stripped.pop('input_image', None)
        return stripped

    def _run_resample(self, context: Context, image: Image.Image, options: dict, cn_op_arg: dict):
        """Run the generation pass selected by ``method`` / ``hiresfix_enabled``.

        Returns the generated image, or a copy of *image* when ``method``
        matches no branch. *options* may be updated in place with size or
        hires-fix settings.
        """
        processed = image.copy()
        p = context.sdprocessing

        if self.hiresfix_enabled:
            # Hires-fix mode: render at the size of the incoming image.
            if self.method in ('txt2img-1pass', 'txt2img-2pass'):
                options['width'] = processed.width
                options['height'] = processed.height
                processed = process_txt2img(context, options=options, controlnet=[cn_op_arg])
            elif self.method == 'img2img-1pass':
                options['width'] = processed.width
                options['height'] = processed.height
                processed = process_img2img_with_controlnet(
                    context, image, options=options,
                    controlnet=[self._without_control_image(cn_op_arg)])
            return processed

        if self.method == 'txt2img-1pass':
            if context.is_hires_fix():
                # Render directly at the hires-fix target size in a single pass.
                if p.hr_resize_x != 0 or p.hr_resize_y != 0:
                    options['width'] = p.hr_resize_x
                    options['height'] = p.hr_resize_y
                else:
                    options['width'] = int(p.width * p.hr_scale)
                    options['height'] = int(p.height * p.hr_scale)
            processed = process_txt2img(context, options=options, controlnet=[cn_op_arg])
        elif self.method == 'txt2img-2pass':
            if context.is_txtimg() and context.is_hires_fix():
                # Reuse the original run's hires-fix settings for a second pass.
                options.update(dict(
                    enable_hr=p.enable_hr,
                    hr_scale=p.hr_scale,
                    hr_resize_x=p.hr_resize_x,
                    hr_resize_y=p.hr_resize_y,
                ))
            processed = process_txt2img(context, options=options, controlnet=[cn_op_arg])
        elif self.method == 'img2img-1pass':
            processed = process_img2img_with_controlnet(
                context, image, options=options,
                controlnet=[self._without_control_image(cn_op_arg)])
        return processed

    def postprocess(self, context: Context, image: Image.Image):
        """Call the webui's ``devices.torch_gc()`` cleanup after processing."""
        devices.torch_gc()
|
|
|
|
|
|
|
|
class ResamplePreprocessorBeforeUpscale(ResamplePreprocessor):
    """Resample variant that only runs in the resample hires-fix mode.

    It reuses the parent's option loading and processing. It opts in only when
    resampling and its hires-fix option are both enabled and the context is
    a hires-fix or img2img run (the "before upscale" placement is implied
    by the name; confirm against the caller that schedules it).
    """

    def preprocess(self, context: Context, image: Image):
        """Load the resample options and return whether this variant should run."""
        # The parent's return value is deliberately ignored; only the loaded
        # attributes matter here.
        super().preprocess(context, image)

        mode_active = self.enabled and self.hiresfix_enabled
        if not mode_active:
            # Return the falsy value itself, exactly as the and-chain would.
            return mode_active
        return context.is_hires_fix() or context.is_img2img()
|
|
|
|
|
|