| |
|
|
| import torch |
| from diffusers import StableDiffusionPipeline |
| from diffusers.utils import BaseOutput |
|
|
class OmegaDiffusionPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline with an extra ``bridge`` module.

    The bridge is registered as a pipeline component (so device moves and
    ``save_pretrained`` treat it like the other sub-models) and is spliced
    into ``vae.decode`` so that every latent decode — whether triggered by
    ``__call__`` or a manual ``decode_latents`` — passes through
    ``bridge.dec`` first.
    """

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        bridge,
        safety_checker=None,
        feature_extractor=None,
        **kwargs,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            **kwargs,
        )

        # Register the bridge so diffusers tracks it like any other module
        # (device placement, dtype casts, serialization).
        self.register_modules(bridge=bridge)

        # Capture the original bound decode and the config values once,
        # outside the wrapper, so the hook does not re-read (possibly
        # mutated) configs on every call.
        _orig_decode = self.vae.decode
        in_ch = self.unet.config.in_channels
        sc = self.vae.config.scaling_factor

        def _decode_with_bridge(
            z_scaled, *args, return_dict=False, generator=None, **decode_kwargs
        ):
            # Callers (e.g. StableDiffusionPipeline) pass latents already
            # divided by the VAE scaling factor; multiply back so the
            # bridge sees latents in the scaled space — presumably the
            # space it was trained on (TODO confirm against bridge training).
            z = z_scaled * sc
            # Only bridge tensors whose channel count matches the UNet's
            # input channels; anything else passes through untouched.
            if z.shape[1] == in_ch:
                z = self.bridge.dec(z)
            # Undo the scaling again so the original decode receives the
            # same convention it was called with.
            z = z / sc

            return _orig_decode(
                z,
                *args,
                return_dict=return_dict,
                generator=generator,
                **decode_kwargs,
            )

        # NOTE(review): this patches the *instance* attribute, so a vae
        # shared with another pipeline (e.g. via ``components``) carries
        # the hook along with it.
        self.vae.decode = _decode_with_bridge

    @property
    def components(self):
        """Return the pipeline's sub-modules, including ``bridge``.

        Spelled out explicitly (rather than relying on the base
        implementation) so the custom ``bridge`` entry is always present
        and ``OmegaDiffusionPipeline(**pipe.components)`` round-trips.
        """
        # Bug fix: the previous version also emitted a ``"kwargs": {}``
        # entry, which made re-instantiation from ``components`` raise a
        # TypeError — ``kwargs`` is not a module parameter of
        # StableDiffusionPipeline.__init__.
        return {
            "scheduler": self.scheduler,
            "tokenizer": self.tokenizer,
            "vae": self.vae,
            "unet": self.unet,
            "text_encoder": self.text_encoder,
            "bridge": self.bridge,
            "feature_extractor": self.feature_extractor,
            "safety_checker": self.safety_checker,
        }

    @torch.no_grad()
    def _decode_latents(self, latents):
        """Decode final latents to images in [0, 1].

        Routes through ``self.vae.decode``, i.e. through the
        bridge-patched hook installed in ``__init__``.
        """
        decoded = self.vae.decode(latents, return_dict=False)
        # ``return_dict=False`` normally yields a tuple; tolerate either
        # a (tensor, ...) tuple/list or a bare tensor.
        images = decoded[0] if isinstance(decoded, (tuple, list)) else decoded
        # Map from the model's [-1, 1] range to [0, 1] image space.
        return (images.clamp(-1, 1) + 1) / 2

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """Run the standard StableDiffusionPipeline call under ``no_grad``."""
        return super().__call__(*args, **kwargs)

    @torch.no_grad()
    def decode_latents(self, latents, return_dict=True):
        """Manually decode latents through the same bridge logic.

        Parameters
        ----------
        latents : torch.Tensor
            Latents to decode (same convention as ``vae.decode`` input).
        return_dict : bool
            When True (default), wrap the images in a ``BaseOutput`` with
            an ``images`` field; otherwise return the raw tensor.
        """
        imgs = self._decode_latents(latents)
        if return_dict:
            return BaseOutput(images=imgs)
        return imgs
|
|