import torch
import torch.nn.functional as nnf
import abc
import math
from torchvision.utils import save_image

LOW_RESOURCE = False
MAX_NUM_WORDS = 77

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def start_att_layers(self):
        return self.start_ac_layer  # if LOW_RESOURCE else 0

    @property
    def end_att_layers(self):
        return self.end_ac_layer

    @abc.abstractmethod
    def forward(self, q, k, v, num_heads, attention_probs, attn):
        raise NotImplementedError

    def attn_forward(self, q, k, v, num_heads, attention_probs, attn):
        # Three images per pass (style, content, output) means there is no
        # classifier-free-guidance split, so run a single forward. Otherwise
        # process the unconditional and conditional halves separately.
        if q.shape[0] // num_heads == 3:
            h_s_re = self.forward(q, k, v, num_heads, attention_probs, attn)
        else:
            uq, cq = q.chunk(2)
            uk, ck = k.chunk(2)
            uv, cv = v.chunk(2)
            u_attn, c_attn = attention_probs.chunk(2)
            u_h_s_re = self.forward(uq, uk, uv, num_heads, u_attn, attn)
            c_h_s_re = self.forward(cq, ck, cv, num_heads, c_attn, attn)
            h_s_re = (u_h_s_re, c_h_s_re)
        return h_s_re

    def __call__(self, q, k, v, num_heads, attention_probs, attn):
        # Only restyle attention inside the configured layer window; elsewhere
        # return None so the caller falls back to vanilla attention.
        if self.start_att_layers <= self.cur_att_layer < self.end_att_layers:
            h_s_re = self.attn_forward(q, k, v, num_heads, attention_probs, attn)
        else:
            h_s_re = None
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers // 2:  # + self.num_uncond_att_layers
            self.cur_att_layer = 0  # self.num_uncond_att_layers
            self.cur_step += 1
            self.between_steps()
        return h_s_re

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
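
# `num_att_layers` starts at -1 and is expected to be set by whatever code
# patches the model's attention layers to call the controller. A minimal,
# hypothetical sketch of that registration pass follows; the "Attention"
# class-name check and the patching step are assumptions about the
# surrounding pipeline, not part of this file's API.
def register_attention_layer_count(unet, controller: AttentionControl):
    """Hypothetical helper: count attention layers so the controller can
    detect when a full denoising step has finished."""
    count = 0
    for _, module in unet.named_modules():
        if module.__class__.__name__ == "Attention":
            count += 1
            # ...patch `module` here so its forward calls `controller`...
    controller.num_att_layers = count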

def enhance_tensor(tensor: torch.Tensor, contrast_factor: float = 1.67) -> torch.Tensor:
    """Increase the contrast of an attention map by scaling each value's
    deviation from the per-row mean; the mean itself is preserved."""
    mean_feat = tensor.mean(dim=-1, keepdim=True)
    adjusted_tensor = (tensor - mean_feat) * contrast_factor + mean_feat
    return adjusted_tensor
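
# Illustrative check (not from the original file): the adjustment is a linear
# stretch around the per-row mean, so each row's mean is preserved while the
# spread of the attention weights grows by `contrast_factor`:
#
#   probs = torch.softmax(torch.randn(2, 4, 8), dim=-1)
#   assert torch.allclose(probs.mean(-1), enhance_tensor(probs).mean(-1), atol=1e-6)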

class AttentionStyle(AttentionControl):

    def __init__(self,
                 num_steps,
                 start_ac_layer, end_ac_layer,
                 style_guidance=0.3,
                 mix_q_scale=1.0,
                 de_bug=False,
                 ):
        super().__init__()
        self.start_ac_layer = start_ac_layer
        self.end_ac_layer = end_ac_layer
        self.num_steps = num_steps
        self.de_bug = de_bug
        self.style_guidance = style_guidance
        self.coef = None
        self.mix_q_scale = mix_q_scale
    def forward(self, q, k, v, num_heads, attention_probs, attn):
        if self.de_bug:
            import pdb; pdb.set_trace()

        # Optionally pull the output image's queries toward the content
        # image's queries before attending.
        if self.mix_q_scale < 1.0:
            q[num_heads * 2:] = (q[num_heads * 2:] * self.mix_q_scale
                                 + (1 - self.mix_q_scale) * q[num_heads:num_heads * 2])

        b, n, d = k.shape
        # Queries of the image being generated (third batch chunk) attend to
        # the concatenated keys/values of the second chunk (content) followed
        # by the first chunk (style).
        re_q = q[num_heads * 2:]                                              # (heads, n, d)
        re_k = torch.cat([k[num_heads:num_heads * 2], k[:num_heads]], dim=1)  # (heads, 2n, d)
        v_re = torch.cat([v[num_heads:num_heads * 2], v[:num_heads]], dim=1)  # (heads, 2n, d)

        re_sim = torch.bmm(re_q, re_k.transpose(-1, -2)) * attn.scale
        # Rescale the logits over the style keys by the style-guidance factor.
        re_sim[:, :, n:] = re_sim[:, :, n:] * self.style_guidance
        re_attention_map = re_sim.softmax(-1)
        h_s_re = torch.bmm(re_attention_map, v_re)
        return h_s_re
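
if __name__ == "__main__":
    # Hedged smoke test (the shapes and the stand-in `FakeAttn` object are
    # assumptions about how the surrounding pipeline calls the controller):
    # a batch of three images, style, content, and the output being
    # generated, with no classifier-free-guidance split.
    heads, n, d = 8, 16, 40

    class FakeAttn:
        # Stand-in for the attention module normally passed in by the pipeline.
        scale = d ** -0.5

    controller = AttentionStyle(num_steps=1, start_ac_layer=0, end_ac_layer=1,
                                style_guidance=0.3)
    controller.num_att_layers = 2  # normally set by the registration pass

    q = torch.randn(3 * heads, n, d)
    k = torch.randn(3 * heads, n, d)
    v = torch.randn(3 * heads, n, d)
    probs = torch.softmax(torch.bmm(q, k.transpose(-1, -2)) * FakeAttn.scale, dim=-1)

    out = controller(q, k, v, heads, probs, FakeAttn())
    print(out.shape)  # expected: torch.Size([8, 16, 40]), the restyled output heads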