Update pipeline.py
Browse files- pipeline.py +70 -32
pipeline.py
CHANGED
|
@@ -18,12 +18,7 @@ import einops
|
|
| 18 |
import PIL.Image
|
| 19 |
import numpy as np
|
| 20 |
import torch as th
|
| 21 |
-
import torch.nn as nn
|
| 22 |
-
from torchvision import transforms
|
| 23 |
|
| 24 |
-
from diffusers import ModelMixin
|
| 25 |
-
from transformers import AutoModel, AutoConfig, SiglipVisionConfig, Dinov2Config, Dinov2Model
|
| 26 |
-
from transformers import SiglipVisionModel
|
| 27 |
from diffusers import DiffusionPipeline
|
| 28 |
from diffusers.image_processor import VaeImageProcessor
|
| 29 |
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
|
@@ -31,8 +26,6 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
|
|
| 31 |
from diffusers.utils.torch_utils import randn_tensor
|
| 32 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 33 |
|
| 34 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 35 |
-
# REf: https://github.com/tatp22/multidim-positional-encoding/tree/master
|
| 36 |
from analogy_encoder import AnalogyEncoder
|
| 37 |
from analogy_projector import AnalogyProjector
|
| 38 |
from analogy_input_processor import AnalogyInputProcessor
|
|
@@ -259,8 +252,8 @@ class PatternAnalogyTrifuser(DiffusionPipeline):
|
|
| 259 |
The call function to the pipeline for generation.
|
| 260 |
|
| 261 |
Args:
|
| 262 |
-
|
| 263 |
-
The
|
| 264 |
height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
|
| 265 |
The height in pixels of the generated image.
|
| 266 |
width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
|
|
@@ -301,32 +294,77 @@ class PatternAnalogyTrifuser(DiffusionPipeline):
|
|
| 301 |
Examples:
|
| 302 |
|
| 303 |
```py
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
```
|
| 325 |
|
| 326 |
Returns:
|
| 327 |
-
[`~
|
| 328 |
-
|
| 329 |
-
otherwise a `tuple` is returned where the first element is a list with the generated images.
|
| 330 |
"""
|
| 331 |
|
| 332 |
# 1. Check inputs. Raise error if not correct
|
|
|
|
| 18 |
import PIL.Image
|
| 19 |
import numpy as np
|
| 20 |
import torch as th
|
|
|
|
|
|
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
from diffusers import DiffusionPipeline
|
| 23 |
from diffusers.image_processor import VaeImageProcessor
|
| 24 |
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
|
|
|
| 26 |
from diffusers.utils.torch_utils import randn_tensor
|
| 27 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 28 |
|
|
|
|
|
|
|
| 29 |
from analogy_encoder import AnalogyEncoder
|
| 30 |
from analogy_projector import AnalogyProjector
|
| 31 |
from analogy_input_processor import AnalogyInputProcessor
|
|
|
|
| 252 |
The call function to the pipeline for generation.
|
| 253 |
|
| 254 |
Args:
|
| 255 |
+
analogy_prompt (`List[Tuple[PIL.Image.Image]]`):
|
| 256 |
+
The analogy sequence (A, A*, B) that serves as the model's prompt for generating B*, the analogical pattern satisfying A:A*::B:B*.
|
| 257 |
height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
|
| 258 |
The height in pixels of the generated image.
|
| 259 |
width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
|
|
|
|
| 294 |
Examples:
|
| 295 |
|
| 296 |
```py
|
| 297 |
+
import requests
|
| 298 |
+
import torch as th
|
| 299 |
+
from PIL import Image
|
| 300 |
+
from io import BytesIO
|
| 301 |
+
import matplotlib.pyplot as plt
|
| 302 |
+
from PIL import Image, ImageOps
|
| 303 |
+
from diffusers import DiffusionPipeline
|
| 304 |
+
|
| 305 |
+
SEED = 1729
|
| 306 |
+
DEVICE = th.device("cuda")
|
| 307 |
+
DTYPE = th.float16
|
| 308 |
+
FIG_K = 3
|
| 309 |
+
EXAMPLE_ID = 0
|
| 310 |
+
|
| 311 |
+
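# Load the pipeline: the same repo id is passed as `custom_pipeline` so the custom pipeline code is loaded along with the weights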
|
| 312 |
+
pretrained_path = "bardofcodes/pattern_analogies"
|
| 313 |
+
new_pipe = DiffusionPipeline.from_pretrained(
|
| 314 |
+
pretrained_path,
|
| 315 |
+
custom_pipeline=pretrained_path,
|
| 316 |
+
trust_remote_code=True
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
img_urls = [
|
| 320 |
+
f"https://huggingface.co/bardofcodes/pattern_analogies/resolve/main/examples/{EXAMPLE_ID}_a.png",
|
| 321 |
+
f"https://huggingface.co/bardofcodes/pattern_analogies/resolve/main/examples/{EXAMPLE_ID}_a_star.png",
|
| 322 |
+
f"https://huggingface.co/bardofcodes/pattern_analogies/resolve/main/examples/{EXAMPLE_ID}_b.png",
|
| 323 |
+
]
|
| 324 |
+
images = []
|
| 325 |
+
for url in img_urls:
|
| 326 |
+
response = requests.get(url)
|
| 327 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
| 328 |
+
images.append(image)
|
| 329 |
+
|
| 330 |
+
pipe_input = [tuple(images)]
|
| 331 |
+
|
| 332 |
+
pipe = new_pipe.to(DEVICE, DTYPE)
|
| 333 |
+
var_images = pipe(pipe_input, num_inference_steps=50, num_images_per_prompt=3,).images
|
| 334 |
+
|
| 335 |
+
plt.figure(figsize=(3*FIG_K, 2*FIG_K))
|
| 336 |
+
plt.axis('off')
|
| 337 |
+
plt.legend(framealpha=1)
|
| 338 |
+
plt.rcParams['legend.fontsize'] = 'large'
|
| 339 |
+
for i in range(6):
|
| 340 |
+
if i == 0:
|
| 341 |
+
plt.subplot(2, 3, i+1)
|
| 342 |
+
val_image = images[0]
|
| 343 |
+
label_str = "A"
|
| 344 |
+
elif i == 1:
|
| 345 |
+
plt.subplot(2, 3, i+1)
|
| 346 |
+
val_image = images[1]
|
| 347 |
+
label_str = "A*"
|
| 348 |
+
elif i == 2:
|
| 349 |
+
plt.subplot(2, 3, i+1)
|
| 350 |
+
val_image = images[2]
|
| 351 |
+
label_str = "Target"
|
| 352 |
+
else:
|
| 353 |
+
plt.subplot(2, 3,i + 1)
|
| 354 |
+
val_image = var_images[i-3]
|
| 355 |
+
label_str = f"Variation {i-2}"
|
| 356 |
+
|
| 357 |
+
val_image = ImageOps.expand(val_image,border=2,fill='black')
|
| 358 |
+
plt.imshow(val_image)
|
| 359 |
+
plt.scatter([], [], c="r", label=label_str)
|
| 360 |
+
plt.legend(loc="lower right")
|
| 361 |
+
plt.axis('off')
|
| 362 |
+
plt.subplots_adjust(wspace=0.01, hspace=0.01)
|
| 363 |
```
|
| 364 |
|
| 365 |
Returns:
|
| 366 |
+
[`~ImagePipelineOutput`] or `tuple`
|
| 367 |
+
The generated image(s) as a [`~ImagePipelineOutput`] or a tuple of images.
|
|
|
|
| 368 |
"""
|
| 369 |
|
| 370 |
# 1. Check inputs. Raise error if not correct
|