test / modules /control /units /reference.py
bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
raw
history blame
2.71 kB
from typing import Union
import time
import diffusers.utils
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from modules.shared import log, opts
from modules.control.units import detect
from modules import sd_models
# Display name of this control unit; interpolated into the log messages emitted by ReferencePipeline.
what = 'Reference'
def list_models():
    """Return the list of model names offered by this control unit.

    The list always holds the single entry 'Reference'; a fresh list is
    created on every call.
    """
    available = ['Reference']
    return available
class ReferencePipeline():
    """Wrap an already loaded SD 1.5 or SDXL pipeline in a reference pipeline.

    The reference pipeline class is resolved at runtime through
    `diffusers.utils.get_class_from_dynamic_module` and is constructed from
    the very same component objects (vae, unet, text encoders, ...) held by
    the source pipeline, so no components are copied here.

    Attributes set by `__init__`:
        orig_pipeline: the pipeline object passed in (may be None).
        pipeline: the constructed reference pipeline, or None when the
            source pipeline is missing or of an unsupported type.
    """

    def __init__(self, pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
        """Build the reference pipeline from `pipeline`.

        Args:
            pipeline: loaded source pipeline; when None an error is logged
                and `self.pipeline` stays None.
            dtype: when not None, the new pipeline is passed through
                `.to(dtype)` after construction.
        """
        t0 = time.time()
        self.orig_pipeline = pipeline
        self.pipeline = None
        if pipeline is None:
            log.error(f'Control {what} model pipeline: model not loaded')
            return

        # Undo fused QKV projections on the source pipeline before reusing its components.
        if opts.diffusers_fuse_projections and hasattr(pipeline, 'unfuse_qkv_projections'):
            pipeline.unfuse_qkv_projections()

        # Select the reference class and the components it is constructed from.
        # The class is resolved before the component attributes are read.
        if detect.is_sdxl(pipeline):
            reference_cls = diffusers.utils.get_class_from_dynamic_module('stable_diffusion_xl_reference', module_file='pipeline.py')
            components = {
                'vae': pipeline.vae,
                'text_encoder': pipeline.text_encoder,
                'text_encoder_2': pipeline.text_encoder_2,
                'tokenizer': pipeline.tokenizer,
                'tokenizer_2': pipeline.tokenizer_2,
                'unet': pipeline.unet,
                'scheduler': pipeline.scheduler,
                'feature_extractor': getattr(pipeline, 'feature_extractor', None),
            }
        elif detect.is_sd15(pipeline):
            reference_cls = diffusers.utils.get_class_from_dynamic_module('stable_diffusion_reference', module_file='pipeline.py')
            components = {
                'vae': pipeline.vae,
                'text_encoder': pipeline.text_encoder,
                'tokenizer': pipeline.tokenizer,
                'unet': pipeline.unet,
                'scheduler': pipeline.scheduler,
                'feature_extractor': getattr(pipeline, 'feature_extractor', None),
                'requires_safety_checker': False,
                'safety_checker': None,
            }
        else:
            log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type')
            return

        # Construct once for either model family, then hand the new pipeline
        # to sd_models.move_model together with the source pipeline's device.
        self.pipeline = reference_cls(**components)
        sd_models.move_model(self.pipeline, pipeline.device)

        if self.pipeline is not None and dtype is not None:
            self.pipeline = self.pipeline.to(dtype)

        t1 = time.time()
        if self.pipeline is None:
            log.error(f'Control {what} pipeline: not initialized')
            return
        log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}')

    def restore(self):
        """Drop the reference pipeline and return the original pipeline object."""
        original = self.orig_pipeline
        self.pipeline = None
        return original