Spaces:
Configuration error
Configuration error
arthur-qiu
committed on
Commit
·
9330af0
1
Parent(s):
f833804
fix filter
Browse files
- scale_attention.py +7 -10
- scale_attention_turbo.py +7 -10
scale_attention.py
CHANGED
|
@@ -9,15 +9,15 @@ def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
|
|
| 9 |
x_coord = torch.arange(kernel_size)
|
| 10 |
gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
|
| 11 |
gaussian_1d = gaussian_1d / gaussian_1d.sum()
|
| 12 |
-
|
| 13 |
-
kernel =
|
| 14 |
|
| 15 |
return kernel
|
| 16 |
|
| 17 |
def gaussian_filter(latents, kernel_size=3, sigma=1.0):
|
| 18 |
-
channels = latents.shape[
|
| 19 |
kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
|
| 20 |
-
blurred_latents = F.
|
| 21 |
|
| 22 |
return blurred_latents
|
| 23 |
|
|
@@ -159,7 +159,6 @@ def scale_forward(
|
|
| 159 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
| 160 |
attn_output = torch.where(count>0, value/count, value)
|
| 161 |
|
| 162 |
-
attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
|
| 163 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 164 |
|
| 165 |
attn_output_global = self.attn1(
|
|
@@ -168,12 +167,12 @@ def scale_forward(
|
|
| 168 |
attention_mask=attention_mask,
|
| 169 |
**cross_attention_kwargs,
|
| 170 |
)
|
| 171 |
-
attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh
|
| 172 |
|
| 173 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 174 |
|
| 175 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
| 176 |
-
attn_output = rearrange(attn_output, 'bh
|
| 177 |
|
| 178 |
elif fourg_window:
|
| 179 |
norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
|
|
@@ -199,7 +198,6 @@ def scale_forward(
|
|
| 199 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
| 200 |
attn_output = torch.where(count>0, value/count, value)
|
| 201 |
|
| 202 |
-
attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
|
| 203 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 204 |
|
| 205 |
value = torch.zeros_like(norm_hidden_states)
|
|
@@ -221,11 +219,10 @@ def scale_forward(
|
|
| 221 |
|
| 222 |
attn_output_global = torch.where(count>0, value/count, value)
|
| 223 |
|
| 224 |
-
attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
|
| 225 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 226 |
|
| 227 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
| 228 |
-
attn_output = rearrange(attn_output, 'bh
|
| 229 |
|
| 230 |
else:
|
| 231 |
attn_output = self.attn1(
|
|
|
|
def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
    """Build a per-channel 3-D Gaussian kernel for a depthwise conv3d.

    Returns a tensor of shape (channels, 1, k, k, k); each channel's kernel
    sums to 1, so it is usable directly as the weight of F.conv3d with
    groups=channels.
    """
    # 1-D Gaussian centred on the middle of the window, normalised to sum to 1.
    x_coord = torch.arange(kernel_size)
    gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()
    # Separable outer product along three axes gives the 3-D Gaussian.
    gaussian_3d = gaussian_1d[:, None, None] * gaussian_1d[None, :, None] * gaussian_1d[None, None, :]
    # One identical (1, k, k, k) kernel per channel for the grouped convolution.
    kernel = gaussian_3d[None, None, :, :, :].repeat(channels, 1, 1, 1, 1)
    return kernel


def gaussian_filter(latents, kernel_size=3, sigma=1.0):
    """Blur a 4-D tensor with a 3-D Gaussian, one leading-dim slice at a time.

    The leading dimension is treated as channels (here: the batch*heads axis
    of the attention output — TODO confirm against callers); a depthwise
    conv3d (groups=channels) blurs each slice independently over the three
    remaining axes. Output shape equals input shape.
    """
    channels = latents.shape[0]
    kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
    # unsqueeze(0) adds the batch dim conv3d expects; [0] strips it again.
    # padding=k//2 keeps spatial sizes unchanged (kernel_size is odd here).
    blurred_latents = F.conv3d(latents.unsqueeze(0), kernel, padding=kernel_size//2, groups=channels)[0]
    return blurred_latents
|
| 23 |
|
|
|
|
| 159 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
| 160 |
attn_output = torch.where(count>0, value/count, value)
|
| 161 |
|
|
|
|
| 162 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 163 |
|
| 164 |
attn_output_global = self.attn1(
|
|
|
|
| 167 |
attention_mask=attention_mask,
|
| 168 |
**cross_attention_kwargs,
|
| 169 |
)
|
| 170 |
+
attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
|
| 171 |
|
| 172 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 173 |
|
| 174 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
| 175 |
+
attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
|
| 176 |
|
| 177 |
elif fourg_window:
|
| 178 |
norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
|
|
|
|
| 198 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
| 199 |
attn_output = torch.where(count>0, value/count, value)
|
| 200 |
|
|
|
|
| 201 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 202 |
|
| 203 |
value = torch.zeros_like(norm_hidden_states)
|
|
|
|
| 219 |
|
| 220 |
attn_output_global = torch.where(count>0, value/count, value)
|
| 221 |
|
|
|
|
| 222 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 223 |
|
| 224 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
| 225 |
+
attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
|
| 226 |
|
| 227 |
else:
|
| 228 |
attn_output = self.attn1(
|
scale_attention_turbo.py
CHANGED
|
@@ -9,15 +9,15 @@ def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
|
|
| 9 |
x_coord = torch.arange(kernel_size)
|
| 10 |
gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
|
| 11 |
gaussian_1d = gaussian_1d / gaussian_1d.sum()
|
| 12 |
-
|
| 13 |
-
kernel =
|
| 14 |
|
| 15 |
return kernel
|
| 16 |
|
| 17 |
def gaussian_filter(latents, kernel_size=3, sigma=1.0):
|
| 18 |
-
channels = latents.shape[
|
| 19 |
kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
|
| 20 |
-
blurred_latents = F.
|
| 21 |
|
| 22 |
return blurred_latents
|
| 23 |
|
|
@@ -159,7 +159,6 @@ def scale_forward(
|
|
| 159 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
| 160 |
attn_output = torch.where(count>0, value/count, value)
|
| 161 |
|
| 162 |
-
attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
|
| 163 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 164 |
|
| 165 |
attn_output_global = self.attn1(
|
|
@@ -168,12 +167,12 @@ def scale_forward(
|
|
| 168 |
attention_mask=attention_mask,
|
| 169 |
**cross_attention_kwargs,
|
| 170 |
)
|
| 171 |
-
attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh
|
| 172 |
|
| 173 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 174 |
|
| 175 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
| 176 |
-
attn_output = rearrange(attn_output, 'bh
|
| 177 |
|
| 178 |
elif fourg_window:
|
| 179 |
norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
|
|
@@ -199,7 +198,6 @@ def scale_forward(
|
|
| 199 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
| 200 |
attn_output = torch.where(count>0, value/count, value)
|
| 201 |
|
| 202 |
-
attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
|
| 203 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 204 |
|
| 205 |
value = torch.zeros_like(norm_hidden_states)
|
|
@@ -221,11 +219,10 @@ def scale_forward(
|
|
| 221 |
|
| 222 |
attn_output_global = torch.where(count>0, value/count, value)
|
| 223 |
|
| 224 |
-
attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
|
| 225 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 226 |
|
| 227 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
| 228 |
-
attn_output = rearrange(attn_output, 'bh
|
| 229 |
|
| 230 |
else:
|
| 231 |
attn_output = self.attn1(
|
|
|
|
def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
    """Build a per-channel 3-D Gaussian kernel for a depthwise conv3d.

    Returns a tensor of shape (channels, 1, k, k, k); each channel's kernel
    sums to 1, so it is usable directly as the weight of F.conv3d with
    groups=channels.
    """
    # 1-D Gaussian centred on the middle of the window, normalised to sum to 1.
    x_coord = torch.arange(kernel_size)
    gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()
    # Separable outer product along three axes gives the 3-D Gaussian.
    gaussian_3d = gaussian_1d[:, None, None] * gaussian_1d[None, :, None] * gaussian_1d[None, None, :]
    # One identical (1, k, k, k) kernel per channel for the grouped convolution.
    kernel = gaussian_3d[None, None, :, :, :].repeat(channels, 1, 1, 1, 1)
    return kernel


def gaussian_filter(latents, kernel_size=3, sigma=1.0):
    """Blur a 4-D tensor with a 3-D Gaussian, one leading-dim slice at a time.

    The leading dimension is treated as channels (here: the batch*heads axis
    of the attention output — TODO confirm against callers); a depthwise
    conv3d (groups=channels) blurs each slice independently over the three
    remaining axes. Output shape equals input shape.
    """
    channels = latents.shape[0]
    kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
    # unsqueeze(0) adds the batch dim conv3d expects; [0] strips it again.
    # padding=k//2 keeps spatial sizes unchanged (kernel_size is odd here).
    blurred_latents = F.conv3d(latents.unsqueeze(0), kernel, padding=kernel_size//2, groups=channels)[0]
    return blurred_latents
|
| 23 |
|
|
|
|
| 159 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
| 160 |
attn_output = torch.where(count>0, value/count, value)
|
| 161 |
|
|
|
|
| 162 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 163 |
|
| 164 |
attn_output_global = self.attn1(
|
|
|
|
| 167 |
attention_mask=attention_mask,
|
| 168 |
**cross_attention_kwargs,
|
| 169 |
)
|
| 170 |
+
attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
|
| 171 |
|
| 172 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 173 |
|
| 174 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
| 175 |
+
attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
|
| 176 |
|
| 177 |
elif fourg_window:
|
| 178 |
norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
|
|
|
|
| 198 |
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
| 199 |
attn_output = torch.where(count>0, value/count, value)
|
| 200 |
|
|
|
|
| 201 |
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 202 |
|
| 203 |
value = torch.zeros_like(norm_hidden_states)
|
|
|
|
| 219 |
|
| 220 |
attn_output_global = torch.where(count>0, value/count, value)
|
| 221 |
|
|
|
|
| 222 |
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
| 223 |
|
| 224 |
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
| 225 |
+
attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
|
| 226 |
|
| 227 |
else:
|
| 228 |
attn_output = self.attn1(
|