| from __future__ import annotations |
|
|
| from functools import cached_property |
|
|
| from diffusers import ( |
| StableDiffusionControlNetInpaintPipeline, |
| StableDiffusionControlNetPipeline, |
| StableDiffusionInpaintPipeline, |
| StableDiffusionPipeline, |
| ) |
|
|
| from asdff.base import AdPipelineBase |
|
|
|
|
class AdPipeline(AdPipelineBase, StableDiffusionPipeline):
    """Stable Diffusion pipeline combined with ``AdPipelineBase`` (no ControlNet).

    Exposes a companion ``StableDiffusionInpaintPipeline`` assembled from this
    pipeline's own component objects, and reports the class to use for the
    text-to-image stage.
    """

    @cached_property
    def inpaint_pipeline(self):
        """Return an inpainting pipeline that reuses this pipeline's modules.

        The pipeline is built on first access and then cached on the instance.
        The existing ``vae``/``text_encoder``/``tokenizer``/``unet``/... objects
        are passed through directly rather than loaded again.

        NOTE(review): because of the caching, components reassigned on ``self``
        after the first access (e.g. a new ``scheduler``) are not seen by the
        cached inpainting pipeline.
        """
        component_names = (
            "vae",
            "text_encoder",
            "tokenizer",
            "unet",
            "scheduler",
            "safety_checker",
            "feature_extractor",
        )
        modules = {name: getattr(self, name) for name in component_names}
        # Carry over the parent's safety-checker requirement flag from its config.
        return StableDiffusionInpaintPipeline(
            **modules,
            requires_safety_checker=self.config.requires_safety_checker,
        )

    @property
    def txt2img_class(self):
        """Pipeline class for the text-to-image stage (plain Stable Diffusion)."""
        return StableDiffusionPipeline
|
|
|
|
class AdCnPipeline(AdPipelineBase, StableDiffusionControlNetPipeline):
    """ControlNet Stable Diffusion pipeline combined with ``AdPipelineBase``.

    Exposes a companion ``StableDiffusionControlNetInpaintPipeline`` assembled
    from this pipeline's own component objects (including its ``controlnet``),
    and reports the class to use for the text-to-image stage.
    """

    @cached_property
    def inpaint_pipeline(self):
        """Return a ControlNet inpainting pipeline that reuses this pipeline's modules.

        The pipeline is built on first access and then cached on the instance.
        The existing ``vae``/``text_encoder``/``tokenizer``/``unet``/``controlnet``/...
        objects are passed through directly rather than loaded again.

        NOTE(review): because of the caching, components reassigned on ``self``
        after the first access (e.g. a new ``scheduler`` or ``controlnet``) are
        not seen by the cached inpainting pipeline.
        """
        component_names = (
            "vae",
            "text_encoder",
            "tokenizer",
            "unet",
            "controlnet",
            "scheduler",
            "safety_checker",
            "feature_extractor",
        )
        modules = {name: getattr(self, name) for name in component_names}
        # Carry over the parent's safety-checker requirement flag from its config.
        return StableDiffusionControlNetInpaintPipeline(
            **modules,
            requires_safety_checker=self.config.requires_safety_checker,
        )

    @property
    def txt2img_class(self):
        """Pipeline class for the text-to-image stage (ControlNet-conditioned)."""
        return StableDiffusionControlNetPipeline
|
|