| from typing import Callable, Optional |
| import torch |
| import torch.nn.functional as F |
| from diffusers.models.attention_processor import * |
| from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel |
|
|
| def default_set_attn_proc_func( |
| name: str, |
| hidden_size: int, |
| cross_attention_dim: Optional[int], |
| ori_attn_proc: object, |
| ) -> object: |
| return ori_attn_proc |
|
|
| def set_flux_transformer_attn_processor( |
| transformer: FluxTransformer2DModel, |
| set_attn_proc_func: Callable = default_set_attn_proc_func, |
| set_attn_module_names: Optional[list[str]] = None, |
| ) -> None: |
| do_set_processor = lambda name, module_names: ( |
| any([name.startswith(module_name) for module_name in module_names]) |
| if module_names is not None |
| else True |
| ) |
|
|
| attn_procs = {} |
| for name, attn_processor in transformer.attn_processors.items(): |
| dim_head = transformer.config.attention_head_dim |
| num_heads = transformer.config.num_attention_heads |
| if name.endswith("attn.processor"): |
| attn_procs[name] = ( |
| set_attn_proc_func(name, dim_head, num_heads, attn_processor) |
| if do_set_processor(name, set_attn_module_names) |
| else attn_processor |
| ) |
|
|
| transformer.set_attn_processor(attn_procs) |
|
|
| class PersonalizeAnythingAttnProcessor: |
| |
| def __init__(self, name, mask, device, tau=0.98, concept_process=False, shift_mask = None, img_dims=4096): |
| if not hasattr(F, "scaled_dot_product_attention"): |
| raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") |
| |
| self.name = name |
| self.mask = mask.view(img_dims).bool().to(device) |
| self.device = device |
| self.tau = tau |
| self.concept_process = concept_process |
| self.img_dims = img_dims |
| |
| if shift_mask is None: |
| self.shift_mask = self.mask |
| else: |
| self.shift_mask = shift_mask.view(img_dims).bool().to(device) |
|
|
| def __call__( |
| self, |
| attn: Attention, |
| hidden_states: torch.FloatTensor, |
| encoder_hidden_states: torch.FloatTensor = None, |
| attention_mask: Optional[torch.FloatTensor] = None, |
| image_rotary_emb: Optional[torch.Tensor] = None, |
| timestep = None, |
| ) -> torch.FloatTensor: |
| batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape |
| |
| |
| if timestep is not None: |
| timestep = timestep |
| |
| concept_process = self.concept_process |
| c_q = concept_process and True |
| c_kv = concept_process and True |
| |
| t_flag = timestep > self.tau |
| r_q = True and t_flag |
| r_k = True and t_flag |
| r_v = True and t_flag |
| |
| if encoder_hidden_states is not None: |
| concept_feature_ = hidden_states[0, self.mask, :] |
| else: |
| concept_feature_ = hidden_states[0, 512:, :][self.mask, :] |
| |
| if r_k or r_q or r_v: |
| r_hidden_states = hidden_states |
| if encoder_hidden_states is not None: |
| r_hidden_states[1, self.shift_mask, :] = concept_feature_ |
| else: |
| text_hidden_states = hidden_states[1, :512, :] |
| image_hidden_states = hidden_states[1, 512:, :] |
| image_hidden_states[self.shift_mask, :] = concept_feature_ |
| |
| r_hidden_states[1] = torch.cat([text_hidden_states, image_hidden_states], dim=0) |
| |
|
|
| key = attn.to_k(hidden_states) |
| value = attn.to_v(hidden_states) |
| query = attn.to_q(hidden_states) |
| |
| |
| if r_k: |
| key = attn.to_k(r_hidden_states) |
| if r_q: |
| query = attn.to_q(r_hidden_states) |
| if r_v: |
| value = attn.to_v(r_hidden_states) |
|
|
| if concept_process: |
| if c_q: |
| c_query = attn.to_q(concept_feature_) |
| c_query = c_query.repeat(query.shape[0], 1, 1) |
| query = torch.cat([query, c_query], dim=1) |
| if c_kv: |
| c_key = attn.to_k(concept_feature_) |
| c_key = c_key.repeat(key.shape[0], 1, 1) |
| c_value = attn.to_v(concept_feature_) |
| c_value = c_value.repeat(value.shape[0], 1, 1) |
| key = torch.cat([key, c_key], dim=1) |
| value = torch.cat([value, c_value], dim=1) |
| |
|
|
| inner_dim = key.shape[-1] |
| head_dim = inner_dim // attn.heads |
|
|
| query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
| if attn.norm_q is not None: |
| query = attn.norm_q(query) |
| if attn.norm_k is not None: |
| key = attn.norm_k(key) |
|
|
| |
| if encoder_hidden_states is not None: |
| |
| encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) |
| encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) |
| encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) |
|
|
| encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( |
| batch_size, -1, attn.heads, head_dim |
| ).transpose(1, 2) |
| encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( |
| batch_size, -1, attn.heads, head_dim |
| ).transpose(1, 2) |
| encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( |
| batch_size, -1, attn.heads, head_dim |
| ).transpose(1, 2) |
|
|
| if attn.norm_added_q is not None: |
| encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) |
| if attn.norm_added_k is not None: |
| encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) |
|
|
| |
| query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) |
| key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) |
| value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) |
| |
| |
| if image_rotary_emb is not None: |
| from diffusers.models.embeddings import apply_rotary_emb |
| |
| |
| if not c_q: |
| query = apply_rotary_emb(query, image_rotary_emb) |
| if not c_kv: |
| key = apply_rotary_emb(key, image_rotary_emb) |
| |
| |
| |
| def get_concept_rotary_emb(ori_rotary_emb, mask): |
| enc_emb = ori_rotary_emb[:512, :] |
| hid_emb = ori_rotary_emb[512:, :] |
| concept_emb = hid_emb[mask, :] |
| |
| image_rotary_emb = torch.cat([enc_emb, hid_emb, concept_emb], dim=0) |
| return image_rotary_emb |
| |
| if concept_process: |
| |
| image_rotary_emb_0 = get_concept_rotary_emb(image_rotary_emb[0], self.shift_mask) |
| image_rotary_emb_1 = get_concept_rotary_emb(image_rotary_emb[1], self.shift_mask) |
| image_rotary_emb = (image_rotary_emb_0, image_rotary_emb_1) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if c_q: |
| query = apply_rotary_emb(query, image_rotary_emb) |
| if c_kv: |
| key = apply_rotary_emb(key, image_rotary_emb) |
| |
|
|
| hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) |
| |
| hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) |
| hidden_states = hidden_states.to(query.dtype) |
| |
| if encoder_hidden_states is not None: |
| encoder_hidden_states, hidden_states = ( |
| hidden_states[:, : encoder_hidden_states.shape[1]], |
| hidden_states[:, encoder_hidden_states.shape[1] :], |
| ) |
|
|
| |
| hidden_states = attn.to_out[0](hidden_states) |
| |
| hidden_states = attn.to_out[1](hidden_states) |
|
|
| encoder_hidden_states = attn.to_add_out(encoder_hidden_states) |
| |
| |
| |
| hidden_states = hidden_states[:, :self.img_dims, :] |
| |
|
|
| return hidden_states, encoder_hidden_states |
| else: |
| |
| dims = self.img_dims + 512 |
| hidden_states = hidden_states[:, :dims, :] |
| |
| |
| return hidden_states |
| |
| class MultiPersonalizeAnythingAttnProcessor: |
| |
| def __init__(self, name, masks, device, tau=0.98, concept_process=False, shift_masks = None, img_dims=4096): |
| if not hasattr(F, "scaled_dot_product_attention"): |
| raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") |
| |
| self.name = name |
| self.device = device |
| self.tau = tau |
| self.concept_process = concept_process |
| self.img_dims = img_dims |
| |
| for i in range(len(masks)): |
| masks[i] = masks[i].view(img_dims).bool().to(device) |
| self.masks = masks |
| |
| if shift_masks is None: |
| self.shift_masks = self.masks |
| else: |
| for i in range(len(shift_masks)): |
| shift_masks[i] = shift_masks[i].view(img_dims).bool().to(device) |
| self.shift_masks = shift_masks |
|
|
| def __call__( |
| self, |
| attn: Attention, |
| hidden_states: torch.FloatTensor, |
| encoder_hidden_states: torch.FloatTensor = None, |
| attention_mask: Optional[torch.FloatTensor] = None, |
| image_rotary_emb: Optional[torch.Tensor] = None, |
| timestep = None, |
| ) -> torch.FloatTensor: |
| batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape |
| |
| |
| if timestep is not None: |
| timestep = timestep |
| |
| concept_process = self.concept_process |
| c_q = concept_process and True |
| c_kv = concept_process and True |
| |
| t_flag = timestep > self.tau |
| r_q = True and t_flag |
| r_k = True and t_flag |
| r_v = True and t_flag |
| |
| concept_features = [] |
| r_hidden_states = hidden_states |
| for id, mask in enumerate(self.masks): |
| if encoder_hidden_states is not None: |
| concept_feature_ = hidden_states[id, mask, :] |
| else: |
| concept_feature_ = hidden_states[id, 512:, :][mask, :] |
| |
| shift_mask = self.shift_masks[id] |
| concept_features.append(concept_feature_) |
| |
| if r_k or r_q or r_v: |
| if encoder_hidden_states is not None: |
| r_hidden_states[-1, shift_mask, :] = concept_feature_ |
| else: |
| text_hidden_states = r_hidden_states[-1, :512, :] |
| image_hidden_states = r_hidden_states[-1, 512:, :] |
| image_hidden_states[shift_mask, :] = concept_feature_ |
| r_hidden_states[-1] = torch.cat([text_hidden_states, image_hidden_states], dim=0) |
| |
|
|
| key = attn.to_k(hidden_states) |
| value = attn.to_v(hidden_states) |
| query = attn.to_q(hidden_states) |
| |
| |
| if r_k: |
| key = attn.to_k(r_hidden_states) |
| if r_q: |
| query = attn.to_q(r_hidden_states) |
| if r_v: |
| value = attn.to_v(r_hidden_states) |
|
|
| if concept_process: |
| for concept_feature_ in concept_features: |
| if c_q: |
| c_query = attn.to_q(concept_feature_) |
| c_query = c_query.repeat(query.shape[0], 1, 1) |
| query = torch.cat([query, c_query], dim=1) |
| if c_kv: |
| c_key = attn.to_k(concept_feature_) |
| c_key = c_key.repeat(key.shape[0], 1, 1) |
| |
| c_value = attn.to_v(concept_feature_) |
| c_value = c_value.repeat(value.shape[0], 1, 1) |
| |
| key = torch.cat([key, c_key], dim=1) |
| value = torch.cat([value, c_value], dim=1) |
| |
|
|
| inner_dim = key.shape[-1] |
| head_dim = inner_dim // attn.heads |
|
|
| query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
| if attn.norm_q is not None: |
| query = attn.norm_q(query) |
| if attn.norm_k is not None: |
| key = attn.norm_k(key) |
|
|
| |
| if encoder_hidden_states is not None: |
| |
| encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) |
| encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) |
| encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) |
|
|
| encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( |
| batch_size, -1, attn.heads, head_dim |
| ).transpose(1, 2) |
| encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( |
| batch_size, -1, attn.heads, head_dim |
| ).transpose(1, 2) |
| encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( |
| batch_size, -1, attn.heads, head_dim |
| ).transpose(1, 2) |
|
|
| if attn.norm_added_q is not None: |
| encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) |
| if attn.norm_added_k is not None: |
| encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) |
|
|
| |
| query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) |
| key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) |
| value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) |
| |
| if image_rotary_emb is not None: |
| from diffusers.models.embeddings import apply_rotary_emb |
| |
| |
| if not c_q: |
| query = apply_rotary_emb(query, image_rotary_emb) |
| if not c_kv: |
| key = apply_rotary_emb(key, image_rotary_emb) |
| |
| |
| def get_concept_rotary_emb(ori_rotary_emb, shift_masks): |
| enc_emb = ori_rotary_emb[:512, :] |
| hid_emb = ori_rotary_emb[512:, :] |
| |
| concept_embs = [] |
| for mask in shift_masks: |
| concept_embs.append(hid_emb[mask, :]) |
| concept_emb = torch.cat(concept_embs, dim=0) if len(concept_embs) > 0 else torch.zeros(0, hid_emb.shape[1], device=hid_emb.device) |
| image_rotary_emb = torch.cat([enc_emb, hid_emb, concept_emb], dim=0) |
| return image_rotary_emb |
| |
| if concept_process: |
| |
| image_rotary_emb_0 = get_concept_rotary_emb(image_rotary_emb[0], self.shift_masks) |
| image_rotary_emb_1 = get_concept_rotary_emb(image_rotary_emb[1], self.shift_masks) |
| image_rotary_emb = (image_rotary_emb_0, image_rotary_emb_1) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if c_q: |
| query = apply_rotary_emb(query, image_rotary_emb) |
| if c_kv: |
| key = apply_rotary_emb(key, image_rotary_emb) |
| |
|
|
| hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) |
| |
| hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) |
| hidden_states = hidden_states.to(query.dtype) |
| |
| if encoder_hidden_states is not None: |
| encoder_hidden_states, hidden_states = ( |
| hidden_states[:, : encoder_hidden_states.shape[1]], |
| hidden_states[:, encoder_hidden_states.shape[1] :], |
| ) |
|
|
| |
| hidden_states = attn.to_out[0](hidden_states) |
| |
| hidden_states = attn.to_out[1](hidden_states) |
|
|
| encoder_hidden_states = attn.to_add_out(encoder_hidden_states) |
| |
| |
| |
| hidden_states = hidden_states[:, :self.img_dims, :] |
| |
|
|
| return hidden_states, encoder_hidden_states |
| else: |
| |
| dims = self.img_dims + 512 |
| hidden_states = hidden_states[:, :dims, :] |
| |
| |
| return hidden_states |