Spaces:
Running
Running
Update gradio_utils.py
Browse files- gradio_utils.py +0 -5
gradio_utils.py
CHANGED
|
@@ -42,7 +42,6 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
|
|
| 42 |
temb=None):
|
| 43 |
# un_cond_hidden_states, cond_hidden_states = hidden_states.chunk(2)
|
| 44 |
# un_cond_hidden_states = self.__call2__(attn, un_cond_hidden_states,encoder_hidden_states,attention_mask,temb)
|
| 45 |
-
# 生成一个0到1之间的随机数
|
| 46 |
global total_count,attn_count,cur_step,mask256,mask1024,mask4096
|
| 47 |
global sa16, sa32, sa64
|
| 48 |
global write
|
|
@@ -50,7 +49,6 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
|
|
| 50 |
self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]
|
| 51 |
else:
|
| 52 |
encoder_hidden_states = torch.cat(self.id_bank[cur_step][0],hidden_states[:1],self.id_bank[cur_step][1],hidden_states[1:])
|
| 53 |
-
# 判断随机数是否大于0.5
|
| 54 |
if cur_step <5:
|
| 55 |
hidden_states = self.__call2__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
|
| 56 |
else: # 256 1024 4096
|
|
@@ -260,7 +258,6 @@ def cal_attn_indice_xl_effcient_memory(total_length,id_length,sa32,sa64,height,w
|
|
| 260 |
nums_4096 = (height // 16) * (width // 16)
|
| 261 |
bool_matrix1024 = torch.rand((total_length,nums_1024),device = device,dtype = dtype) < sa32
|
| 262 |
bool_matrix4096 = torch.rand((total_length,nums_4096),device = device,dtype = dtype) < sa64
|
| 263 |
-
# 用nonzero()函数获取所有为True的值的索引
|
| 264 |
indices1024 = [torch.nonzero(bool_matrix1024[i], as_tuple=True)[0] for i in range(total_length)]
|
| 265 |
indices4096 = [torch.nonzero(bool_matrix4096[i], as_tuple=True)[0] for i in range(total_length)]
|
| 266 |
|
|
@@ -431,7 +428,6 @@ def is_torch2_available():
|
|
| 431 |
return hasattr(F, "scaled_dot_product_attention")
|
| 432 |
|
| 433 |
|
| 434 |
-
# 将列表转换为字典的函数
|
| 435 |
def character_to_dict(general_prompt):
|
| 436 |
character_dict = {}
|
| 437 |
generate_prompt_arr = general_prompt.splitlines()
|
|
@@ -439,7 +435,6 @@ def character_to_dict(general_prompt):
|
|
| 439 |
invert_character_index_dict = {}
|
| 440 |
character_list = []
|
| 441 |
for ind,string in enumerate(generate_prompt_arr):
|
| 442 |
-
# 分割字符串寻找key和value
|
| 443 |
start = string.find('[')
|
| 444 |
end = string.find(']')
|
| 445 |
if start != -1 and end != -1:
|
|
|
|
| 42 |
temb=None):
|
| 43 |
# un_cond_hidden_states, cond_hidden_states = hidden_states.chunk(2)
|
| 44 |
# un_cond_hidden_states = self.__call2__(attn, un_cond_hidden_states,encoder_hidden_states,attention_mask,temb)
|
|
|
|
| 45 |
global total_count,attn_count,cur_step,mask256,mask1024,mask4096
|
| 46 |
global sa16, sa32, sa64
|
| 47 |
global write
|
|
|
|
| 49 |
self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]
|
| 50 |
else:
|
| 51 |
encoder_hidden_states = torch.cat(self.id_bank[cur_step][0],hidden_states[:1],self.id_bank[cur_step][1],hidden_states[1:])
|
|
|
|
| 52 |
if cur_step <5:
|
| 53 |
hidden_states = self.__call2__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
|
| 54 |
else: # 256 1024 4096
|
|
|
|
| 258 |
nums_4096 = (height // 16) * (width // 16)
|
| 259 |
bool_matrix1024 = torch.rand((total_length,nums_1024),device = device,dtype = dtype) < sa32
|
| 260 |
bool_matrix4096 = torch.rand((total_length,nums_4096),device = device,dtype = dtype) < sa64
|
|
|
|
| 261 |
indices1024 = [torch.nonzero(bool_matrix1024[i], as_tuple=True)[0] for i in range(total_length)]
|
| 262 |
indices4096 = [torch.nonzero(bool_matrix4096[i], as_tuple=True)[0] for i in range(total_length)]
|
| 263 |
|
|
|
|
| 428 |
return hasattr(F, "scaled_dot_product_attention")
|
| 429 |
|
| 430 |
|
|
|
|
| 431 |
def character_to_dict(general_prompt):
|
| 432 |
character_dict = {}
|
| 433 |
generate_prompt_arr = general_prompt.splitlines()
|
|
|
|
| 435 |
invert_character_index_dict = {}
|
| 436 |
character_list = []
|
| 437 |
for ind,string in enumerate(generate_prompt_arr):
|
|
|
|
| 438 |
start = string.find('[')
|
| 439 |
end = string.find(']')
|
| 440 |
if start != -1 and end != -1:
|