Update pipeline.py
Browse files- pipeline.py +15 -22
pipeline.py
CHANGED
|
@@ -33,31 +33,24 @@ class SuperDiffPipeline(DiffusionPipeline, ConfigMixin):
|
|
| 33 |
|
| 34 |
"""
|
| 35 |
super().__init__()
|
| 36 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
self.vae = vae
|
| 38 |
-
self.text_encoder = text_encoder
|
| 39 |
-
self.tokenizer = tokenizer
|
| 40 |
self.scheduler = scheduler
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
self.vae.to(device)
|
| 45 |
-
self.unet.to(device)
|
| 46 |
-
self.text_encoder.to(device)
|
| 47 |
-
|
| 48 |
-
self.register_to_config(
|
| 49 |
-
vae=vae.__class__.__name__,
|
| 50 |
-
scheduler=scheduler.__class__.__name__,
|
| 51 |
-
tokenizer=tokenizer.__class__.__name__,
|
| 52 |
-
unet=unet.__class__.__name__,
|
| 53 |
-
text_encoder=text_encoder.__class__.__name__,
|
| 54 |
-
device=device,
|
| 55 |
-
batch_size=None,
|
| 56 |
-
num_inference_steps=None,
|
| 57 |
-
guidance_scale=None,
|
| 58 |
-
lift=None,
|
| 59 |
-
seed=None,
|
| 60 |
-
)
|
| 61 |
|
| 62 |
@torch.no_grad
|
| 63 |
def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable:
|
|
|
|
| 33 |
|
| 34 |
"""
|
| 35 |
super().__init__()
|
| 36 |
+
self.register_to_config(
|
| 37 |
+
batch_size=kwargs.get("batch_size", 1),
|
| 38 |
+
device=kwargs.get("device", "cuda"),
|
| 39 |
+
guidance_scale=kwargs.get("guidance_scale", 7.5),
|
| 40 |
+
lift=kwargs.get("lift", 0.0),
|
| 41 |
+
num_inference_steps=kwargs.get("num_inference_steps", 50),
|
| 42 |
+
seed=kwargs.get("seed", 42)
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# Assign model components
|
| 46 |
self.vae = vae
|
|
|
|
|
|
|
| 47 |
self.scheduler = scheduler
|
| 48 |
+
self.tokenizer = tokenizer
|
| 49 |
+
self.unet = unet
|
| 50 |
+
self.text_encoder = text_encoder
|
| 51 |
|
| 52 |
+
# Move components to device
|
| 53 |
+
self.to(torch.device(self.config.device))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
@torch.no_grad
|
| 56 |
def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable:
|