# Hugging Face Spaces: Running on Zero
from .custom_cross_attention_processor import DecoupledCrossAttnProcessor2_0
import torch
from diffusers.models.attention_processor import IPAdapterAttnProcessor2_0, Attention
def load_custom_ip_adapter(
    unet,
    path=None,
    blocks="full",
    Custom_Attn_Type=DecoupledCrossAttnProcessor2_0,
    cross_attention_dim=2048,
    Image_Proj_Type=None,
):
    """Install custom IP-Adapter cross-attention processors on ``unet`` in place.

    Walks every cross-attention (``attn2``) module of the UNet and, wherever the
    current processor is not already an ``nn.Module``, replaces it with a fresh
    ``Custom_Attn_Type`` instance. The new processor's weights are then either
    restored from the checkpoint at ``path`` or freshly initialized.

    Args:
        unet: A diffusers UNet; its attention processors are patched in place.
        path: Optional checkpoint written by ``save_custom_ip_adapter`` (a dict
            mapping ``"<module>.processor"`` to a state dict). When ``None``,
            new processors are Kaiming-initialized instead.
        blocks: ``"full"`` patches all cross-attention blocks; ``"midup"``
            restricts patching to mid- and up-blocks.
        Custom_Attn_Type: Processor class instantiated per attention module.
        cross_attention_dim: Dimensionality passed to each new processor.
        Image_Proj_Type: Currently unused; kept for interface compatibility.
    """
    if path is None:
        state_dict = None
    else:
        # Checkpoints produced by save_custom_ip_adapter contain only tensors,
        # so weights_only=True is safe here and avoids arbitrary pickle
        # execution when loading an untrusted file.
        state_dict = torch.load(path, map_location="cpu", weights_only=True)

    for name, module in unet.named_modules():
        if "attn2" not in name or not isinstance(module, Attention):
            continue
        # "midup" limits patching to the mid and up blocks of the UNet.
        if blocks == "midup" and "mid" not in name and "up" not in name:
            continue
        # A processor that is already an nn.Module was installed previously;
        # leave it (and its weights) untouched.
        if isinstance(module.processor, torch.nn.Module):
            continue
        module.set_processor(
            Custom_Attn_Type(
                hidden_size=module.query_dim,
                cross_attention_dim=cross_attention_dim,
            ).to(unet.device, unet.dtype)
        )
        if state_dict is not None:
            module.processor.load_state_dict(state_dict[f"{name}.processor"])
        else:
            # No checkpoint: Kaiming-init the IP-adapter projection layers.
            # Each attribute is guarded individually (the original checked
            # only to_q_ip but then touched to_k_ip/to_v_ip unconditionally).
            for attr in ("to_q_ip", "to_k_ip", "to_v_ip"):
                layer = getattr(module.processor, attr, None)
                if layer is not None:
                    torch.nn.init.kaiming_normal_(layer.weight)
def save_custom_ip_adapter(unet, path):
    """Serialize every trainable attention processor of ``unet`` to ``path``.

    Only processors that are ``nn.Module`` instances carry weights, so stock
    functional processors are skipped. The result is a dict mapping processor
    name to its ``state_dict``, written with ``torch.save``.
    """
    trainable = {
        name: processor.state_dict()
        for name, processor in unet.attn_processors.items()
        if isinstance(processor, torch.nn.Module)
    }
    torch.save(trainable, path)
def set_scale(unet, scale):
    """Assign ``scale`` to every trainable attention processor of ``unet``.

    Non-module (stock functional) processors have no adapter weights and are
    left untouched.
    """
    for processor in unet.attn_processors.values():
        if isinstance(processor, torch.nn.Module):
            processor.scale = scale