# Source: SD15-Surge-V1 / pipeline/pipeline.py
# Author: AbstractPhil — "Update pipeline/pipeline.py"
# Commit: 6d94ca3 (verified)
# pipeline/pipeline.py
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import BaseOutput
class OmegaDiffusionPipeline(StableDiffusionPipeline):
    """StableDiffusionPipeline variant that routes VAE decoding through a
    latent "bridge" module.

    ``bridge.dec`` maps latents whose channel count equals
    ``unet.config.in_channels`` (e.g. 4) to the channel count the VAE decoder
    expects (the inline comment in the original says 4 -> 16).  The splice is
    installed by monkey-patching ``self.vae.decode`` in ``__init__``, so every
    decode path — including the parent ``__call__`` — passes through the
    bridge automatically.
    """

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        bridge,
        safety_checker=None,
        feature_extractor=None,
        **kwargs,
    ):
        """Build the pipeline and patch ``vae.decode`` with the bridge.

        Args:
            vae: AutoencoderKL-style VAE; must expose ``decode`` and
                ``config.scaling_factor``.
            text_encoder / tokenizer / unet / scheduler: standard
                StableDiffusionPipeline components.
            bridge: module with a ``dec(latents)`` method mapping
                UNet-channel latents to VAE-channel latents.
            safety_checker / feature_extractor: optional, forwarded to the
                parent constructor.
            **kwargs: any extra parent-constructor arguments (e.g.
                ``requires_safety_checker``).
        """
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            **kwargs,
        )
        # Register the bridge so Diffusers tracks it (save/load, .to(), ...).
        self.register_modules(bridge=bridge)

        # ─── Monkey-patch VAE.decode to insert the bridge ───
        _orig_decode = self.vae.decode
        # Captured once at construction time; assumes the UNet/VAE configs
        # are not swapped out after __init__.
        in_ch = self.unet.config.in_channels
        sc = self.vae.config.scaling_factor

        def _decode_with_bridge(
            z_scaled, *args, return_dict=True, generator=None, **decode_kwargs
        ):
            # FIX: return_dict now defaults to True, matching
            # AutoencoderKL.decode.  The previous patch defaulted it to
            # False, silently changing the result type for any caller that
            # relied on the library default.
            z = z_scaled * sc          # undo the caller's 1/scaling_factor
            if z.shape[1] == in_ch:    # widen channels only when needed
                z = self.bridge.dec(z)
            z = z / sc                 # re-apply scaling for the real decode
            return _orig_decode(
                z,
                *args,
                return_dict=return_dict,
                generator=generator,
                **decode_kwargs,
            )

        # Override in place; every decode now goes through the bridge.
        self.vae.decode = _decode_with_bridge

    @property
    def components(self):
        """Modules needed to reconstruct this pipeline via ``Pipe(**components)``.

        FIX: the spurious ``"kwargs": {}`` entry was removed — diffusers
        instantiates ``Pipeline(**pipe.components)``, and the extra keyword
        would be forwarded into the parent constructor and rejected.
        """
        return {
            "scheduler": self.scheduler,
            "tokenizer": self.tokenizer,
            "vae": self.vae,
            "unet": self.unet,
            "text_encoder": self.text_encoder,
            "bridge": self.bridge,
            "feature_extractor": self.feature_extractor,
            "safety_checker": self.safety_checker,
        }

    @torch.no_grad()
    def _decode_latents(self, latents):
        """Turn final latents into images in [0, 1].

        Calling ``self.vae.decode`` here invokes ``_decode_with_bridge``
        installed in ``__init__``.
        """
        decoded = self.vae.decode(latents, return_dict=False)
        # With return_dict=False the VAE returns a tuple; unwrap defensively.
        images = decoded[0] if isinstance(decoded, (tuple, list)) else decoded
        # Map from [-1, 1] model space to [0, 1] image space.
        return (images.clamp(-1, 1) + 1) / 2

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """Defer to the parent implementation; the patched ``vae.decode``
        (and thus the bridge) is picked up automatically."""
        return super().__call__(*args, **kwargs)

    @torch.no_grad()
    def decode_latents(self, latents, return_dict=True):
        """Manual decode entry point routed through the same bridge logic.

        Returns a ``BaseOutput`` with an ``images`` field when
        ``return_dict`` is True, otherwise the raw image tensor.
        """
        imgs = self._decode_latents(latents)
        if return_dict:
            return BaseOutput(images=imgs)
        return imgs