| | import tempfile |
| | import unittest |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from diffusers import DiffusionPipeline |
| | from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor |
| |
|
| |
|
class AttnAddedKVProcessorTests(unittest.TestCase):
    """Tests for `Attention` configured with an `AttnAddedKVProcessor`."""

    def get_constructor_arguments(self, only_cross_attention: bool = False):
        """Return keyword arguments for building a small `Attention` module.

        Args:
            only_cross_attention: Forwarded to `Attention`. When True, `cross_attention_dim`
                is set to 12; otherwise it equals `query_dim` (10).

        Returns:
            A dict suitable for `Attention(**kwargs)`, including a fresh `AttnAddedKVProcessor`.
        """
        query_dim = 10

        if only_cross_attention:
            cross_attention_dim = 12
        else:
            # NOTE(review): presumably cross_attention_dim must equal query_dim here because the
            # to_k/to_v projections built in this mode are applied to hidden_states — confirm
            # against Attention / AttnAddedKVProcessor.
            cross_attention_dim = query_dim

        return {
            "query_dim": query_dim,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "added_kv_proj_dim": 6,
            "norm_num_groups": 1,
            "only_cross_attention": only_cross_attention,
            "processor": AttnAddedKVProcessor(),
        }

    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
        """Return random keyword arguments for calling an `Attention` module.

        The tensors are drawn from the global torch RNG, so their values depend on its state
        at call time.

        Args:
            query_dim: Size of dim 1 of the 4-D `hidden_states` tensor.
            added_kv_proj_dim: Size of the last dim of `encoder_hidden_states`.

        Returns:
            Dict with `hidden_states` of shape (2, query_dim, 3, 2), `encoder_hidden_states` of
            shape (2, 4, added_kv_proj_dim), and `attention_mask=None`.
        """
        batch_size = 2

        hidden_states = torch.rand(batch_size, query_dim, 3, 2)
        encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
        attention_mask = None

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
        }

    def test_only_cross_attention(self):
        """`only_cross_attention=True` drops to_k/to_v and changes the output for identical inputs."""
        # Draw the inputs once, up front, and feed the same tensors to both modules. If they were
        # drawn after each constructor call, their values would depend on how much RNG state that
        # constructor consumed. That presumably differs here, since to_k/to_v exist in only one
        # configuration. The output comparison below would then be trivially satisfied.
        torch.manual_seed(0)
        self_and_cross_args = self.get_constructor_arguments(only_cross_attention=False)
        only_cross_args = self.get_constructor_arguments(only_cross_attention=True)
        forward_args = self.get_forward_arguments(
            query_dim=self_and_cross_args["query_dim"],
            added_kv_proj_dim=self_and_cross_args["added_kv_proj_dim"],
        )

        # Self and cross attention: to_k/to_v are created.
        torch.manual_seed(0)
        attn = Attention(**self_and_cross_args)

        self.assertIsNotNone(attn.to_k)
        self.assertIsNotNone(attn.to_v)

        self_and_cross_attn_out = attn(**forward_args)

        # Only cross attention: to_k/to_v are not created.
        torch.manual_seed(0)
        attn = Attention(**only_cross_args)

        self.assertIsNone(attn.to_k)
        self.assertIsNone(attn.to_v)

        only_cross_attn_out = attn(**forward_args)

        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
| |
|
| |
|
class DeprecatedAttentionBlockTests(unittest.TestCase):
    """Pipeline outputs must agree across a plain load, a device_map load, and a save/reload."""

    # Checkpoint id passed to `DiffusionPipeline.from_pretrained`
    # (presumably fetched from the Hugging Face Hub).
    _REPO_ID = "hf-internal-testing/tiny-stable-diffusion-torch"

    @staticmethod
    def _generate(pipe):
        """Run `pipe` for 2 steps on a fixed prompt with a CPU generator seeded to 0.

        Returns:
            The `.images` of the pipeline output, requested with `output_type="np"`.
        """
        return pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

    def test_conversion_when_using_device_map(self):
        """Images match (atol=1e-3) across plain load, device_map load, and save/reload."""
        pipe = DiffusionPipeline.from_pretrained(self._REPO_ID, safety_checker=None)
        pre_conversion = self._generate(pipe)

        pipe = DiffusionPipeline.from_pretrained(self._REPO_ID, device_map="balanced", safety_checker=None)
        conversion = self._generate(pipe)

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)

            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="balanced", safety_checker=None)
            # Generate while tmpdir still exists, in case any loaded weights remain backed by the
            # saved files.
            after_conversion = self._generate(pipe)

        # Same criterion as `np.allclose(a, b, atol=1e-3)`, which uses a default rtol of 1e-5 and
        # does not treat NaNs as equal. Unlike `assertTrue(np.allclose(...))`, a failure here
        # reports the mismatching elements.
        np.testing.assert_allclose(pre_conversion, conversion, rtol=1e-5, atol=1e-3, equal_nan=False)
        np.testing.assert_allclose(conversion, after_conversion, rtol=1e-5, atol=1e-3, equal_nan=False)
| |
|