from typing import Callable, Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import *
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel


def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    return ori_attn_proc


def set_flux_transformer_attn_processor(
    transformer: FluxTransformer2DModel,
    set_attn_proc_func: Callable = default_set_attn_proc_func,
    set_attn_module_names: Optional[list[str]] = None,
) -> None:
    # Replace a processor only if its module name starts with one of the given prefixes,
    # or unconditionally when no prefixes are given.
    do_set_processor = lambda name, module_names: (
        any(name.startswith(module_name) for module_name in module_names)
        if module_names is not None
        else True
    )  # prefix match
    attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        dim_head = transformer.config.attention_head_dim
        num_heads = transformer.config.num_attention_heads
        if name.endswith("attn.processor"):
            # The setter callback receives (name, head_dim, num_heads, original_processor).
            attn_procs[name] = (
                set_attn_proc_func(name, dim_head, num_heads, attn_processor)
                if do_set_processor(name, set_attn_module_names)
                else attn_processor
            )
    transformer.set_attn_processor(attn_procs)
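
# Usage sketch (illustrative only; the checkpoint name and the setter below are assumptions,
# not part of this module). `set_flux_transformer_attn_processor` walks
# `transformer.attn_processors` and swaps in whatever the setter returns, optionally
# restricted to modules whose names match one of the given prefixes.
#
#   from diffusers import FluxPipeline
#
#   pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
#
#   def my_setter(name, dim_head, num_heads, ori_attn_proc):
#       # Returning `ori_attn_proc` keeps the default behaviour for that module.
#       return ori_attn_proc
#
#   set_flux_transformer_attn_processor(
#       pipe.transformer,
#       set_attn_proc_func=my_setter,
#       set_attn_module_names=["single_transformer_blocks"],  # only the single-stream blocks
#   )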


class PersonalizeAnythingAttnProcessor:
    def __init__(self, name, mask, device, tau=0.98, concept_process=False, shift_mask=None, img_dims=4096):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "PersonalizeAnythingAttnProcessor requires PyTorch 2.0. Please upgrade PyTorch to 2.0 or newer."
            )
        self.name = name
        self.mask = mask.view(img_dims).bool().to(device)
        self.device = device
        self.tau = tau
        self.concept_process = concept_process
        self.img_dims = img_dims
        if shift_mask is None:
            self.shift_mask = self.mask
        else:
            self.shift_mask = shift_mask.view(img_dims).bool().to(device)
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        timestep=None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        ###################################################################################
        concept_process = self.concept_process  # token concatenation
        c_q = concept_process and True  # whether token concatenation is applied to q
        c_kv = concept_process and True  # whether token concatenation is applied to k/v
        t_flag = timestep > self.tau  # token replacement is only active early in denoising
        r_q = True and t_flag  # whether token replacement is applied to q
        r_k = True and t_flag  # whether token replacement is applied to k
        r_v = True and t_flag  # whether token replacement is applied to v
        if encoder_hidden_states is not None:
            concept_feature_ = hidden_states[0, self.mask, :]
        else:
            concept_feature_ = hidden_states[0, 512:, :][self.mask, :]

        if r_k or r_q or r_v:
            # Note: this is an alias, not a copy, so the in-place write below also updates
            # `hidden_states` for batch item 1.
            r_hidden_states = hidden_states
            if encoder_hidden_states is not None:
                r_hidden_states[1, self.shift_mask, :] = concept_feature_
            else:
                text_hidden_states = hidden_states[1, :512, :]
                image_hidden_states = hidden_states[1, 512:, :]
                image_hidden_states[self.shift_mask, :] = concept_feature_
                r_hidden_states[1] = torch.cat([text_hidden_states, image_hidden_states], dim=0)
        ###################################################################################

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        query = attn.to_q(hidden_states)

        ###################################################################################
        if r_k:
            key = attn.to_k(r_hidden_states)
        if r_q:
            query = attn.to_q(r_hidden_states)
        if r_v:
            value = attn.to_v(r_hidden_states)

        if concept_process:
            # Concatenate the concept tokens as extra query/key/value tokens.
            if c_q:
                c_query = attn.to_q(concept_feature_)
                c_query = c_query.repeat(query.shape[0], 1, 1)
                query = torch.cat([query, c_query], dim=1)
            if c_kv:
                c_key = attn.to_k(concept_feature_)
                c_key = c_key.repeat(key.shape[0], 1, 1)
                c_value = attn.to_v(concept_feature_)
                c_value = c_value.repeat(value.shape[0], 1, 1)
                key = torch.cat([key, c_key], dim=1)
                value = torch.cat([value, c_value], dim=1)
        ###################################################################################
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            # use original position emb or text emb
            if not c_q:
                query = apply_rotary_emb(query, image_rotary_emb)
            if not c_kv:
                key = apply_rotary_emb(key, image_rotary_emb)

            ###################################################################################
            # get original position emb
            def get_concept_rotary_emb(ori_rotary_emb, mask):
                enc_emb = ori_rotary_emb[:512, :]
                hid_emb = ori_rotary_emb[512:, :]
                concept_emb = hid_emb[mask, :]
                image_rotary_emb = torch.cat([enc_emb, hid_emb, concept_emb], dim=0)
                return image_rotary_emb

            if concept_process:
                # 1. use original position emb
                image_rotary_emb_0 = get_concept_rotary_emb(image_rotary_emb[0], self.shift_mask)
                image_rotary_emb_1 = get_concept_rotary_emb(image_rotary_emb[1], self.shift_mask)
                image_rotary_emb = (image_rotary_emb_0, image_rotary_emb_1)

                # 2. use text emb
                # dims = (self.mask == 1).sum().item()
                # concept_rotary_emb_0 = torch.ones((dims, 128)).to(self.device)
                # concept_rotary_emb_1 = torch.zeros((dims, 128)).to(self.device)
                # image_rotary_emb = (
                #     torch.cat([image_rotary_emb[0], concept_rotary_emb_0], dim=0),
                #     torch.cat([image_rotary_emb[1], concept_rotary_emb_1], dim=0))

                if c_q:
                    query = apply_rotary_emb(query, image_rotary_emb)
                if c_kv:
                    key = apply_rotary_emb(key, image_rotary_emb)
        ###################################################################################

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            ################################################################
            # restore after token concatenation
            hidden_states = hidden_states[:, :self.img_dims, :]
            ################################################################

            return hidden_states, encoder_hidden_states
        else:
            ################################################################
            dims = self.img_dims + 512
            hidden_states = hidden_states[:, :dims, :]
            ################################################################

            return hidden_states
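
# Usage sketch (illustrative only; the mask layout and resolution below are assumptions,
# not defined in this module). For a 1024x1024 image FLUX works on a 64x64 grid of packed
# image tokens, which is where the default img_dims=4096 comes from; `mask` marks the
# concept region on that grid. The processor reads the concept tokens from batch item 0 and
# writes them into batch item 1, and the surrounding pipeline is expected to forward the
# current flow-matching timestep (a scalar in [0, 1]) so that token replacement only fires
# while `timestep > tau`.
#
#   concept_mask = torch.zeros(64, 64)
#   concept_mask[16:48, 16:48] = 1  # hypothetical concept region
#
#   set_flux_transformer_attn_processor(
#       pipe.transformer,
#       set_attn_proc_func=lambda name, dim_head, num_heads, ori_proc: PersonalizeAnythingAttnProcessor(
#           name, concept_mask, device="cuda", tau=0.98, img_dims=64 * 64
#       ),
#   )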


class MultiPersonalizeAnythingAttnProcessor:
    def __init__(self, name, masks, device, tau=0.98, concept_process=False, shift_masks=None, img_dims=4096):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "MultiPersonalizeAnythingAttnProcessor requires PyTorch 2.0. Please upgrade PyTorch to 2.0 or newer."
            )
        self.name = name
        self.device = device
        self.tau = tau
        self.concept_process = concept_process
        self.img_dims = img_dims
        for i in range(len(masks)):
            masks[i] = masks[i].view(img_dims).bool().to(device)
        self.masks = masks
        if shift_masks is None:
            self.shift_masks = self.masks
        else:
            for i in range(len(shift_masks)):
                shift_masks[i] = shift_masks[i].view(img_dims).bool().to(device)
            self.shift_masks = shift_masks
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        timestep=None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        ###################################################################################
        concept_process = self.concept_process  # token concatenation
        c_q = concept_process and True  # whether token concatenation is applied to q
        c_kv = concept_process and True  # whether token concatenation is applied to k/v
        t_flag = timestep > self.tau  # token replacement is only active early in denoising
        r_q = True and t_flag  # whether token replacement is applied to q
        r_k = True and t_flag  # whether token replacement is applied to k
        r_v = True and t_flag  # whether token replacement is applied to v
        concept_features = []
        # Note: this is an alias, not a copy, so the in-place writes below also update `hidden_states`.
        r_hidden_states = hidden_states
        for id, mask in enumerate(self.masks):
            # Concept i's tokens are taken from batch item i.
            if encoder_hidden_states is not None:
                concept_feature_ = hidden_states[id, mask, :]
            else:
                concept_feature_ = hidden_states[id, 512:, :][mask, :]
            shift_mask = self.shift_masks[id]
            concept_features.append(concept_feature_)
            if r_k or r_q or r_v:
                # Write the concept tokens into the last batch item at the shifted location.
                if encoder_hidden_states is not None:
                    r_hidden_states[-1, shift_mask, :] = concept_feature_
                else:
                    text_hidden_states = r_hidden_states[-1, :512, :]
                    image_hidden_states = r_hidden_states[-1, 512:, :]
                    image_hidden_states[shift_mask, :] = concept_feature_
                    r_hidden_states[-1] = torch.cat([text_hidden_states, image_hidden_states], dim=0)
        ###################################################################################
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        query = attn.to_q(hidden_states)

        ###################################################################################
        if r_k:
            key = attn.to_k(r_hidden_states)
        if r_q:
            query = attn.to_q(r_hidden_states)
        if r_v:
            value = attn.to_v(r_hidden_states)

        if concept_process:
            for concept_feature_ in concept_features:
                if c_q:
                    c_query = attn.to_q(concept_feature_)
                    c_query = c_query.repeat(query.shape[0], 1, 1)
                    query = torch.cat([query, c_query], dim=1)
                if c_kv:
                    c_key = attn.to_k(concept_feature_)
                    c_key = c_key.repeat(key.shape[0], 1, 1)
                    c_value = attn.to_v(concept_feature_)
                    c_value = c_value.repeat(value.shape[0], 1, 1)
                    key = torch.cat([key, c_key], dim=1)
                    value = torch.cat([value, c_value], dim=1)

        ###################################################################################
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            # use original position emb or text emb
            if not c_q:
                query = apply_rotary_emb(query, image_rotary_emb)
            if not c_kv:
                key = apply_rotary_emb(key, image_rotary_emb)

            ###################################################################################
            def get_concept_rotary_emb(ori_rotary_emb, shift_masks):
                enc_emb = ori_rotary_emb[:512, :]
                hid_emb = ori_rotary_emb[512:, :]
                concept_embs = []
                for mask in shift_masks:
                    concept_embs.append(hid_emb[mask, :])
                concept_emb = (
                    torch.cat(concept_embs, dim=0)
                    if len(concept_embs) > 0
                    else torch.zeros(0, hid_emb.shape[1], device=hid_emb.device)
                )
                image_rotary_emb = torch.cat([enc_emb, hid_emb, concept_emb], dim=0)
                return image_rotary_emb

            if concept_process:
                # 1. use original position emb with plural masks
                image_rotary_emb_0 = get_concept_rotary_emb(image_rotary_emb[0], self.shift_masks)
                image_rotary_emb_1 = get_concept_rotary_emb(image_rotary_emb[1], self.shift_masks)
                image_rotary_emb = (image_rotary_emb_0, image_rotary_emb_1)

                # 2. use text emb with plural masks
                # total_dims = sum((mask == 1).sum().item() for mask in self.masks)
                # concept_rotary_emb_0 = torch.ones((total_dims, 128)).to(self.device)
                # concept_rotary_emb_1 = torch.zeros((total_dims, 128)).to(self.device)
                # image_rotary_emb = (
                #     torch.cat([image_rotary_emb[0], concept_rotary_emb_0], dim=0),
                #     torch.cat([image_rotary_emb[1], concept_rotary_emb_1], dim=0)
                # )

                if c_q:
                    query = apply_rotary_emb(query, image_rotary_emb)
                if c_kv:
                    key = apply_rotary_emb(key, image_rotary_emb)
        ###################################################################################

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            ################################################################
            # restore after token concatenation
            hidden_states = hidden_states[:, :self.img_dims, :]
            ################################################################

            return hidden_states, encoder_hidden_states
        else:
            ################################################################
            dims = self.img_dims + 512
            hidden_states = hidden_states[:, :dims, :]
            ################################################################

            return hidden_states
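
# Usage sketch for the multi-concept variant (illustrative only; the masks below are
# assumptions). One mask (and optionally one shifted/target mask) is given per concept;
# concept i's tokens are read from batch item i and written into the last batch item, so the
# latent batch is expected to stack one reference per concept followed by the composition
# being generated.
#
#   masks = [mask_a, mask_b]          # hypothetical 64x64 binary concept masks
#   shift_masks = [shift_a, shift_b]  # hypothetical target placements on the same grid
#
#   set_flux_transformer_attn_processor(
#       pipe.transformer,
#       set_attn_proc_func=lambda name, dim_head, num_heads, ori_proc: MultiPersonalizeAnythingAttnProcessor(
#           name, masks, device="cuda", tau=0.98, shift_masks=shift_masks, img_dims=64 * 64
#       ),
#   )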