from PIL import Image

from modules import shared
from modules import devices
from modules import images

from sd_bmab import util
from sd_bmab import constants
from sd_bmab.base import filter
from sd_bmab.util import debug_print
from sd_bmab.base import process_txt2img, process_img2img_with_controlnet, Context, ProcessorBase


class ResamplePreprocessor(ProcessorBase):
    """Re-generate an image under a ControlNet ``tile_resample`` unit.

    ``preprocess`` loads the settings from ``context.args`` and decides
    whether this stage runs; ``process`` builds the generation options and
    the ControlNet unit, then runs a txt2img or img2img pass depending on
    ``self.method``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.resample_opt = {}
        self.enabled = False
        self.hiresfix_enabled = False
        self.save_image = False
        self.method = 'txt2img-1pass'
        self.checkpoint = constants.checkpoint_default
        self.vae = constants.vae_default
        self.filter = 'None'
        # None (or '') means "use the prompt of the current generation".
        self.prompt = None
        self.negative_prompt = None
        self.sampler = None
        self.scheduler = None
        self.upscaler = None
        self.steps = 20
        self.cfg_scale = 0.7
        self.denoising_strength = 0.75
        # ControlNet unit weight and guidance start/end (see get_resample_args).
        self.strength = 0.5
        self.begin = 0.0
        self.end = 1.0
        self.base_sd_model = None

    def use_controlnet(self, context: Context):
        """Return the result of ``preprocess`` called without an image."""
        return self.preprocess(context, None)

    def preprocess(self, context: Context, image: Image.Image):
        """Load settings from ``context.args`` and report whether to run.

        Every option falls back to the current attribute value when it is
        missing from ``resample_opt``.

        Returns:
            ``False`` when resampling is enabled together with
            ``hiresfix_enabled`` (that combination is handled by
            ``ResamplePreprocessorBeforeUpscale``); otherwise the value of
            ``resample_enabled``.

        Raises:
            KeyError: if ``context.args`` has no ``'resample_enabled'`` key.
        """
        self.enabled = context.args['resample_enabled']
        self.resample_opt = context.args.get('module_config', {}).get('resample_opt', {})
        opt = self.resample_opt

        self.hiresfix_enabled = opt.get('hiresfix_enabled', self.hiresfix_enabled)
        self.save_image = opt.get('save_image', self.save_image)
        self.method = opt.get('method', self.method)
        self.checkpoint = opt.get('checkpoint', self.checkpoint)
        self.vae = opt.get('vae', self.vae)
        self.filter = opt.get('filter', self.filter)
        self.prompt = opt.get('prompt', self.prompt)
        self.negative_prompt = opt.get('negative_prompt', self.negative_prompt)
        self.sampler = opt.get('sampler', self.sampler)
        self.scheduler = opt.get('scheduler', self.scheduler)
        self.upscaler = opt.get('upscaler', self.upscaler)
        self.steps = opt.get('steps', self.steps)
        self.cfg_scale = opt.get('cfg_scale', self.cfg_scale)
        self.denoising_strength = opt.get('denoising_strength', self.denoising_strength)
        # NOTE(review): the ControlNet weight and guidance start/end are read
        # from the keys 'scale', 'width' and 'height'. That looks like a
        # copy/paste from a resize option block -- confirm the key names
        # against whatever writes resample_opt before renaming them.
        self.strength = opt.get('scale', self.strength)
        self.begin = opt.get('width', self.begin)
        self.end = opt.get('height', self.end)

        if self.enabled and self.hiresfix_enabled:
            return False
        return self.enabled

    @staticmethod
    def get_resample_args(image, weight, begin, end):
        """Build the argument dict for one ``tile_resample`` ControlNet unit.

        Args:
            image: image to use as the control image; it is converted to RGB
                and encoded with ``util.b64_encoding``.
            weight: value stored under ``'weight'``.
            begin: value stored under ``'guidance_start'``.
            end: value stored under ``'guidance_end'``.

        Returns:
            A new dict describing the ControlNet unit.
        """
        cn_args = {
            'enabled': True,
            'image': util.b64_encoding(image.convert('RGB')),
            'module': 'tile_resample',
            'model': shared.opts.bmab_cn_tile_resample,
            'weight': weight,
            "guidance_start": begin,
            "guidance_end": end,
            'resize_mode': 'Just Resize',
            'pixel_perfect': False,
            'control_mode': 'ControlNet is more important',
            'processor_res': 512,
            'threshold_a': 1,
            'threshold_b': 1,
        }
        return cn_args

    @staticmethod
    def _drop_control_image(cn_op_arg):
        """Remove the control image from a ControlNet unit dict in place.

        ``get_resample_args`` stores the image under ``'image'``; the
        previous ``del cn_op_arg['input_image']`` therefore always raised
        ``KeyError``. Both spellings are removed tolerantly.
        """
        # presumably ControlNet then uses the img2img input image -- confirm.
        cn_op_arg.pop('image', None)
        cn_op_arg.pop('input_image', None)

    def process(self, context: Context, image: Image.Image):
        """Run the resample pass and return the filtered result.

        Resolves unset prompt/sampler/scheduler/checkpoint values from the
        current generation, builds the option dict, optionally saves the
        input image, runs the pass selected by ``self.method`` and finally
        hands the result to the configured filter.

        Args:
            context: processing context of the current generation.
            image: image to resample.

        Returns:
            The value returned by ``filter.process_filter``.
        """
        # An unset prompt (None or '') means "reuse the generation prompt".
        # The '#!org!#' placeholder is replaced by that prompt.
        if not self.prompt:
            self.prompt = context.get_prompt_by_index()
            debug_print('prompt', self.prompt)
        elif '#!org!#' in self.prompt:
            current_prompt = context.get_prompt_by_index()
            self.prompt = self.prompt.replace('#!org!#', current_prompt)
            debug_print('Prompt', self.prompt)

        if not self.negative_prompt:
            self.negative_prompt = context.sdprocessing.negative_prompt
        if self.checkpoint == constants.checkpoint_default:
            self.checkpoint = context.sdprocessing.sd_model
        if self.sampler == constants.sampler_default:
            self.sampler = context.sdprocessing.sampler_name
        if self.scheduler == constants.scheduler_default:
            self.scheduler = util.get_scheduler(context.sdprocessing)

        bmab_filter = filter.get_filter(self.filter)

        seed, subseed = context.get_seeds()
        options = dict(
            seed=seed,
            subseed=subseed,
            denoising_strength=self.denoising_strength,
            prompt=self.prompt,
            negative_prompt=self.negative_prompt,
            sampler_name=self.sampler,
            scheduler=self.scheduler,
            steps=self.steps,
            cfg_scale=self.cfg_scale,
        )

        # NOTE(review): the default checkpoint was replaced above by
        # context.sdprocessing.sd_model, so this branch appears to be taken
        # in that case too and stores that object as the override -- confirm
        # this is what the process_* helpers expect.
        if self.checkpoint != constants.checkpoint_default:
            options.setdefault('override_settings', {})['sd_model_checkpoint'] = self.checkpoint
        if self.vae != constants.vae_default:
            options.setdefault('override_settings', {})['sd_vae'] = self.vae

        filter.preprocess_filter(bmab_filter, context, image, options)

        context.add_job()

        if self.save_image:
            saved = image.copy()
            images.save_image(
                saved, context.sdprocessing.outpath_samples, '',
                context.sdprocessing.all_seeds[context.index],
                context.sdprocessing.all_prompts[context.index],
                shared.opts.samples_format,
                p=context.sdprocessing,
                suffix='-before-resample')
            context.add_extra_image(saved)

        cn_op_arg = self.get_resample_args(image, self.strength, self.begin, self.end)
        # Stays a plain copy of the input when self.method matches no branch.
        processed = image.copy()
        sdp = context.sdprocessing

        if self.hiresfix_enabled:
            # The incoming image already has the target size: generate at it.
            if self.method in ('txt2img-1pass', 'txt2img-2pass'):
                options['width'] = processed.width
                options['height'] = processed.height
                processed = process_txt2img(context, options=options, controlnet=[cn_op_arg])
            elif self.method == 'img2img-1pass':
                self._drop_control_image(cn_op_arg)
                options['width'] = processed.width
                options['height'] = processed.height
                processed = process_img2img_with_controlnet(
                    context, image, options=options, controlnet=[cn_op_arg])
        else:
            if self.method == 'txt2img-1pass':
                if context.is_hires_fix():
                    # Single pass directly at the hires-fix target size.
                    if sdp.hr_resize_x != 0 or sdp.hr_resize_y != 0:
                        options['width'] = sdp.hr_resize_x
                        options['height'] = sdp.hr_resize_y
                    else:
                        options['width'] = int(sdp.width * sdp.hr_scale)
                        options['height'] = int(sdp.height * sdp.hr_scale)
                processed = process_txt2img(context, options=options, controlnet=[cn_op_arg])
            elif self.method == 'txt2img-2pass':
                if context.is_txtimg() and context.is_hires_fix():
                    # Forward the hires-fix settings of the current generation.
                    options.update(dict(
                        enable_hr=sdp.enable_hr,
                        hr_scale=sdp.hr_scale,
                        hr_resize_x=sdp.hr_resize_x,
                        hr_resize_y=sdp.hr_resize_y,
                    ))
                processed = process_txt2img(context, options=options, controlnet=[cn_op_arg])
            elif self.method == 'img2img-1pass':
                self._drop_control_image(cn_op_arg)
                processed = process_img2img_with_controlnet(
                    context, image, options=options, controlnet=[cn_op_arg])

        image = filter.process_filter(bmab_filter, context, image, processed)
        filter.postprocess_filter(bmab_filter, context)

        return image

    def postprocess(self, context: Context, image: Image.Image):
        """Run ``devices.torch_gc()`` after the resample pass."""
        devices.torch_gc()


class ResamplePreprocessorBeforeUpscale(ResamplePreprocessor):
    """Resample variant that runs only when ``hiresfix_enabled`` is set."""

    def preprocess(self, context: Context, image: Image.Image):
        """Load settings and report whether this variant should run.

        Returns:
            A truthy value only when resampling and ``hiresfix_enabled`` are
            both on and the context reports hires fix or img2img.
        """
        # Called for its side effect of loading the settings; the parent's
        # return value is deliberately ignored.
        super().preprocess(context, image)
        return self.enabled and self.hiresfix_enabled and (context.is_hires_fix() or context.is_img2img())