arthur-qiu committed
Commit · f833804
Parent(s): cca304f

fix filter

- scale_attention.py +6 -3
- scale_attention_turbo.py +6 -3
scale_attention.py CHANGED

@@ -159,6 +159,7 @@ def scale_forward(
         count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
         attn_output = torch.where(count>0, value/count, value)
 
+        attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
         gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output_global = self.attn1(
@@ -167,12 +168,12 @@ def scale_forward(
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
-        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
+        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh d h w', h = latent_h)
 
         gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output = gaussian_local + (attn_output_global - gaussian_global)
-        attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+        attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
 
     elif fourg_window:
         norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
@@ -198,6 +199,7 @@ def scale_forward(
         count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
         attn_output = torch.where(count>0, value/count, value)
 
+        attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
         gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         value = torch.zeros_like(norm_hidden_states)
@@ -219,10 +221,11 @@ def scale_forward(
 
         attn_output_global = torch.where(count>0, value/count, value)
 
+        attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
         gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output = gaussian_local + (attn_output_global - gaussian_global)
-        attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+        attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
 
     else:
         attn_output = self.attn1(
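What the fix does: gaussian_filter is called exactly as before; the only change is the tensor layout around each call. The attention outputs in this file live channels-last, (bh, h, w, d), or flattened, (bh, (h w), d), while a 2D blur built on torch.nn.functional.conv2d expects channels-first, (bh, d, h, w). Before this commit the filter was fed channels-last tensors, so it would have blurred across the wrong axes (the height dimension sitting where conv2d expects channels). The sketch below illustrates the pattern; the repo's real gaussian_filter is not shown in this diff, so the depthwise-conv2d implementation here is an assumption, and the shapes are illustrative.

import torch
import torch.nn.functional as F
from einops import rearrange

def gaussian_filter(x, kernel_size=3, sigma=1.0):
    # Hypothetical stand-in for the repo's gaussian_filter: a depthwise
    # 2D Gaussian blur over a channels-first (N, C, H, W) tensor.
    coords = torch.arange(kernel_size, dtype=x.dtype, device=x.device) - (kernel_size - 1) / 2
    g = torch.exp(-coords**2 / (2 * sigma**2))
    g = g / g.sum()
    kernel = torch.outer(g, g)  # separable 1D Gaussian -> 2D kernel
    channels = x.shape[1]
    weight = kernel.expand(channels, 1, kernel_size, kernel_size).contiguous()
    return F.conv2d(x, weight, padding=kernel_size // 2, groups=channels)

attn_output = torch.randn(2, 32, 32, 64)  # (bh, h, w, d), channels-last
attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')  # the added line
low_pass = gaussian_filter(attn_output, kernel_size=3, sigma=1.0)
attn_output = rearrange(low_pass, 'bh d h w -> bh (h w) d')  # back to token layout

With that layout in place, gaussian_local + (attn_output_global - gaussian_global) combines the low-frequency (blurred) part of the local-window attention with the high-frequency residual of the global attention, which is presumably the intent of the filtering step.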
scale_attention_turbo.py CHANGED

@@ -159,6 +159,7 @@ def scale_forward(
         count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
         attn_output = torch.where(count>0, value/count, value)
 
+        attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
         gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output_global = self.attn1(
@@ -167,12 +168,12 @@ def scale_forward(
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
-        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
+        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh d h w', h = latent_h)
 
         gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output = gaussian_local + (attn_output_global - gaussian_global)
-        attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+        attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
 
     elif fourg_window:
         norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
@@ -198,6 +199,7 @@ def scale_forward(
         count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
         attn_output = torch.where(count>0, value/count, value)
 
+        attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
         gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         value = torch.zeros_like(norm_hidden_states)
@@ -219,10 +221,11 @@ def scale_forward(
 
         attn_output_global = torch.where(count>0, value/count, value)
 
+        attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
         gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output = gaussian_local + (attn_output_global - gaussian_global)
-        attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+        attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
 
     else:
         attn_output = self.attn1(
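scale_attention_turbo.py receives the identical three-line change, keeping the turbo path in sync with scale_attention.py: rearrange to channels-first immediately before each gaussian_filter call, and back to the flattened (bh, (h w), d) token layout once the local and global outputs have been mixed.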