| import pytest |
| import torch |
|
|
| from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType |
| from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher |
| from invokeai.backend.util.test_utils import install_and_load_model |
|
|
|
|
| def build_dummy_sd15_unet_input(torch_device): |
| batch_size = 1 |
| num_channels = 4 |
| sizes = (32, 32) |
|
|
| noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device) |
| time_step = torch.tensor([10]).to(torch_device) |
| encoder_hidden_states = torch.randn((batch_size, 77, 768)).to(torch_device) |
|
|
| return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} |
|
|
|
|
| @pytest.mark.parametrize( |
| "model_params", |
| [ |
| |
| { |
| "ip_adapter_model_id": "InvokeAI/ip_adapter_sd15", |
| "ip_adapter_model_name": "ip_adapter_sd15", |
| "base_model": BaseModelType.StableDiffusion1, |
| "unet_model_id": "runwayml/stable-diffusion-v1-5", |
| "unet_model_name": "stable-diffusion-v1-5", |
| }, |
| |
| { |
| "ip_adapter_model_id": "InvokeAI/ip_adapter_plus_sd15", |
| "ip_adapter_model_name": "ip_adapter_plus_sd15", |
| "base_model": BaseModelType.StableDiffusion1, |
| "unet_model_id": "runwayml/stable-diffusion-v1-5", |
| "unet_model_name": "stable-diffusion-v1-5", |
| }, |
| |
| { |
| "ip_adapter_model_id": "InvokeAI/ip-adapter-full-face_sd15", |
| "ip_adapter_model_name": "ip-adapter-full-face_sd15", |
| "base_model": BaseModelType.StableDiffusion1, |
| "unet_model_id": "runwayml/stable-diffusion-v1-5", |
| "unet_model_name": "stable-diffusion-v1-5", |
| }, |
| ], |
| ) |
| @pytest.mark.slow |
| def test_ip_adapter_unet_patch(model_params, model_installer, torch_device): |
| """Smoke test that IP-Adapter weights can be loaded and used to patch a UNet.""" |
| ip_adapter_info = install_and_load_model( |
| model_installer=model_installer, |
| model_path_id_or_url=model_params["ip_adapter_model_id"], |
| model_name=model_params["ip_adapter_model_name"], |
| base_model=model_params["base_model"], |
| model_type=ModelType.IPAdapter, |
| ) |
|
|
| unet_info = install_and_load_model( |
| model_installer=model_installer, |
| model_path_id_or_url=model_params["unet_model_id"], |
| model_name=model_params["unet_model_name"], |
| base_model=model_params["base_model"], |
| model_type=ModelType.Main, |
| submodel_type=SubModelType.UNet, |
| ) |
|
|
| dummy_unet_input = build_dummy_sd15_unet_input(torch_device) |
|
|
| with torch.no_grad(), ip_adapter_info as ip_adapter, unet_info as unet: |
| ip_adapter.to(torch_device, dtype=torch.float32) |
| unet.to(torch_device, dtype=torch.float32) |
|
|
| |
| ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device) |
|
|
| cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]} |
| ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter]) |
| with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet): |
| output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample |
|
|
| assert output.shape == dummy_unet_input["sample"].shape |
|
|