# sdas / 3-bmab / sd_bmab / controlnet.py
# dikdimon's picture
# Upload 3-bmab using SD-Hub
# c10aebf verified
from copy import copy
from modules import shared
from modules import processing
from modules import img2img
from modules.processing import Processed
from sd_bmab.util import debug_print, get_cn_args
# Pair of indices into p.script_args delimiting the ControlNet arguments.
# update_controlnet_args() overwrites it with (first match, last match);
# FakeControlNet.is_controlnet_enabled() scans range(first, second).
controlnet_args = (0, 0)
class FakeControlNet:
    """Context manager that switches ControlNet script arguments off inside a ``with`` block.

    When active (see ``enabled``), entering the block:
      * snapshots the prompts of the processing object and the script's extra images,
      * calls ``scripts.postprocess`` on a shallow copy of the processing object,
      * sets ``enabled`` to False on every currently enabled ControlNet-like
        entry of ``script_args`` and remembers their positions.

    Leaving the block restores the prompts, re-enables the remembered entries,
    calls ``scripts.before_process`` / ``scripts.process`` on a shallow copy of
    the processing object and merges the extra images back into the script.

    When not active, entering and leaving are no-ops.
    """

    def __init__(self, ctx, cn_enabled=False) -> None:
        """Bind to *ctx* and decide whether the manager is active.

        Args:
            ctx: Context object; ``ctx.sdprocessing`` and ``ctx.script.extra_image``
                are the attributes this class reads.
            cn_enabled: When False the manager is always inactive; when True the
                result of ``is_controlnet_enabled()`` decides.
        """
        super().__init__()
        self.context = ctx
        self.process = ctx.sdprocessing
        self.all_prompts = None
        self.all_negative_prompts = None
        self.extra_image = None
        # self.process must already be set: is_controlnet_enabled() reads it.
        self.enabled = self.is_controlnet_enabled() if cn_enabled else False
        # Positions in script_args that __enter__ disabled and __exit__ must re-enable.
        self.control_index = []
        debug_print('FakeControlNet', self.enabled, cn_enabled)

    def __enter__(self):
        """Disable enabled ControlNet arguments; return this manager.

        Returning ``self`` lets callers write ``with FakeControlNet(...) as cn``.
        """
        if not self.enabled:
            return self

        dummy = Processed(self.process, [], self.process.seed, "")
        # Snapshots taken before the scripts run, so __exit__ can put them back.
        self.all_prompts = copy(self.process.all_prompts)
        self.all_negative_prompts = copy(self.process.all_negative_prompts)
        self.extra_image = copy(self.context.script.extra_image)
        self.process.scripts.postprocess(copy(self.process), dummy)

        for idx, obj in enumerate(self.process.script_args):
            if 'controlnet' in obj.__class__.__name__.lower():
                if hasattr(obj, 'enabled') and obj.enabled:
                    obj.enabled = False
                    self.control_index.append(idx)
            elif isinstance(obj, dict) and 'model' in obj and obj['enabled']:
                obj['enabled'] = False
                self.control_index.append(idx)
        return self

    def __exit__(self, *args, **kwargs):
        """Undo ``__enter__`` and re-run the script hooks.

        Returns None, so an exception raised inside the ``with`` block propagates.
        """
        if not self.enabled:
            return

        # The copy is taken before the prompts are restored, so the hooks below
        # see the prompts as they were at the end of the block.
        copy_p = copy(self.process)
        self.process.all_prompts = self.all_prompts
        self.process.all_negative_prompts = self.all_negative_prompts
        # Collect images added during the block before the hooks run.
        self.extra_image.extend(self.context.script.extra_image)

        for idx in self.control_index:
            obj = self.process.script_args[idx]
            if 'controlnet' in obj.__class__.__name__.lower():
                if hasattr(obj, 'enabled'):
                    obj.enabled = True
            elif isinstance(obj, dict) and 'model' in obj:
                obj['enabled'] = True

        self.process.scripts.before_process(copy_p)
        self.process.scripts.process(copy_p)
        self.context.script.extra_image.extend(self.extra_image)

    def is_controlnet_enabled(self):
        """Report whether the ``controlnet_args`` range holds a ControlNet argument.

        Returns True as soon as the scanned range contains a dict, or an object
        whose class name contains "controlnet" and that has an ``enabled``
        attribute. The value of ``enabled`` is not inspected.
        NOTE(review): the name suggests the flag value should matter - confirm.

        The range is applied as a slice, so a range that is stale or larger than
        the current ``script_args`` is clamped instead of raising IndexError.
        """
        script_args = self.process.script_args or ()
        for obj in script_args[controlnet_args[0]:controlnet_args[1]]:
            if isinstance(obj, dict):
                return True
            if 'controlnet' in obj.__class__.__name__.lower() and hasattr(obj, 'enabled'):
                return True
        return False
class PreventControlNet(FakeControlNet):
    """FakeControlNet variant that also swaps out host-level hooks for the block.

    On enter it additionally:
      * points ``model.forward`` at ``model._original_forward`` when the
        diffusion model has that attribute (presumably left there by whatever
        hooked the model - TODO confirm),
      * installs the class-level ``process_images_inner`` / ``process_batch``
        into ``modules.processing`` / ``modules.img2img``,
      * forces ``control_net_allow_script_control`` to True (only if the option
        exists) and ``multiple_tqdm`` to False in ``shared.opts.data``.

    On exit everything is put back, then ``FakeControlNet.__exit__`` runs.
    """

    # Evaluated once, when this class body runs; these are what __enter__ installs.
    process_images_inner = processing.process_images_inner
    process_batch = img2img.process_batch

    def __init__(self, ctx, cn_enabled=False) -> None:
        """Bind to *ctx* and remember the hooks that are current right now.

        Args:
            ctx: Context object exposing ``sdprocessing`` (see FakeControlNet).
            cn_enabled: Forwarded to ``FakeControlNet.__init__``.
        """
        super().__init__(ctx, cn_enabled)
        # Whatever is installed at construction time is what gets restored on exit.
        self._process_images_inner = processing.process_images_inner
        self._process_batch = img2img.process_batch
        self.allow_script_control = None
        self.p = ctx.sdprocessing
        self.all_prompts = copy(ctx.sdprocessing.all_prompts)
        self.all_negative_prompts = copy(ctx.sdprocessing.all_negative_prompts)

    def is_controlnet_used(self):
        """Return True if any entry of ``script_args`` is an enabled ControlNet argument.

        An entry counts when its class name contains "controlnet" and its
        ``enabled`` attribute is truthy, or when it is a dict with a ``'module'``
        key and a truthy ``'enabled'`` value.
        """
        if not self.p.script_args:
            return False
        for obj in self.p.script_args:
            if 'controlnet' in obj.__class__.__name__.lower():
                if hasattr(obj, 'enabled') and obj.enabled:
                    debug_print('Use controlnet True')
                    return True
            elif isinstance(obj, dict) and 'module' in obj and obj['enabled']:
                debug_print('Use controlnet True')
                return True
        debug_print('Use controlnet False')
        return False

    def __enter__(self):
        """Apply the patches, then run ``FakeControlNet.__enter__``; return this manager.

        If the base ``__enter__`` raises, the patches applied here are rolled
        back before the exception propagates, because ``__exit__`` is not called
        for a failed ``__enter__``.
        """
        model = self.p.sd_model.model.diffusion_model
        if hasattr(model, '_original_forward'):
            # Keep the currently installed forward so it can be restored later.
            model._old_forward = model.forward
            model.forward = getattr(model, '_original_forward')

        processing.process_images_inner = PreventControlNet.process_images_inner
        img2img.process_batch = PreventControlNet.process_batch

        if 'control_net_allow_script_control' in shared.opts.data:
            self.allow_script_control = shared.opts.data['control_net_allow_script_control']
            shared.opts.data['control_net_allow_script_control'] = True
        self.multiple_tqdm = shared.opts.data.get('multiple_tqdm', True)
        shared.opts.data['multiple_tqdm'] = False

        try:
            super().__enter__()
        except BaseException:
            self._restore_patched_state()
            raise
        return self

    def _restore_patched_state(self):
        """Put back the module hooks, option values and model forward changed by ``__enter__``."""
        processing.process_images_inner = self._process_images_inner
        img2img.process_batch = self._process_batch
        if 'control_net_allow_script_control' in shared.opts.data:
            shared.opts.data['control_net_allow_script_control'] = self.allow_script_control
        shared.opts.data['multiple_tqdm'] = self.multiple_tqdm
        model = self.p.sd_model.model.diffusion_model
        if hasattr(model, '_original_forward') and hasattr(model, '_old_forward'):
            model.forward = model._old_forward

    def __exit__(self, *args, **kwargs):
        """Restore the patched state, then run ``FakeControlNet.__exit__``."""
        self._restore_patched_state()
        super().__exit__(*args, **kwargs)
def update_controlnet_args(p):
    """Record where the ControlNet arguments sit inside ``p.script_args``.

    Scans ``p.script_args`` for objects whose class name contains "controlnet"
    (case-insensitive) and stores ``(first_index, last_index)`` in the
    module-level ``controlnet_args``.

    When there is no such object, or ``script_args`` is empty or None, the
    global is reset to ``(0, 0)`` instead of raising, so a range left over
    from an earlier call is never reused.

    NOTE(review): the second element is the index of the last match itself;
    FakeControlNet.is_controlnet_enabled() uses it as an exclusive range end,
    so the last matching argument is not scanned there - confirm this is intended.

    Args:
        p: Processing object exposing ``script_args``.
    """
    global controlnet_args
    cn_arg_index = [
        idx
        for idx, obj in enumerate(p.script_args or ())
        if 'controlnet' in obj.__class__.__name__.lower()
    ]
    if not cn_arg_index:
        controlnet_args = (0, 0)
        return
    controlnet_args = (cn_arg_index[0], cn_arg_index[-1])
def get_controlnet_index(p):
    """Return the position just past the leading run of enabled ControlNet arguments.

    ``get_cn_args(p)`` supplies the range of ``p.script_args`` to walk. Starting
    at the beginning of that range, entries are counted while each one is
    enabled: either an object with a truthy ``enabled`` attribute, or a dict
    that has a ``'model'`` key and a truthy ``'enabled'`` value. Counting stops
    at the first entry that is not.

    Returns:
        The first element of the range plus the number of entries counted.
    """
    unit_range = get_cn_args(p)
    enabled_run = 0
    for position in range(*unit_range):
        unit = p.script_args[position]
        active = hasattr(unit, 'enabled') and unit.enabled
        if not active:
            # Fall back to the dict form of a unit.
            active = isinstance(unit, dict) and 'model' in unit and unit['enabled']
        if not active:
            break
        enabled_run += 1
    return unit_range[0] + enabled_run