Update pipeline.py
Browse files- pipeline.py +9 -4
pipeline.py
CHANGED
|
@@ -219,10 +219,15 @@ class AAS_XL(AttentionBase):
|
|
| 219 |
self.mask = mask # mask with shape (1, 1 ,h, w)
|
| 220 |
self.ss_steps = ss_steps
|
| 221 |
self.ss_scale = ss_scale
|
| 222 |
-
self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
|
| 223 |
-
self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
|
| 224 |
-
self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
|
| 225 |
-
self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, is_mask_attn, mask, **kwargs):
|
| 228 |
B = q.shape[0] // num_heads
|
|
|
|
| 219 |
self.mask = mask # mask with shape (1, 1 ,h, w)
|
| 220 |
self.ss_steps = ss_steps
|
| 221 |
self.ss_scale = ss_scale
|
| 222 |
+
# self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
|
| 223 |
+
# self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
|
| 224 |
+
# self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
|
| 225 |
+
# self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()
|
| 226 |
+
# target_size: 입력된 mask의 크기 (height, width)
|
| 227 |
+
self.mask_16 = F.max_pool2d(mask, (mask.shape[-2] // 16, mask.shape[-1] // 16)).round().squeeze().squeeze()
|
| 228 |
+
self.mask_32 = F.max_pool2d(mask, (mask.shape[-2] // 32, mask.shape[-1] // 32)).round().squeeze().squeeze()
|
| 229 |
+
self.mask_64 = F.max_pool2d(mask, (mask.shape[-2] // 64, mask.shape[-1] // 64)).round().squeeze().squeeze()
|
| 230 |
+
self.mask_128 = F.max_pool2d(mask, (mask.shape[-2] // 128, mask.shape[-1] // 128)).round().squeeze().squeeze()
|
| 231 |
|
| 232 |
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, is_mask_attn, mask, **kwargs):
|
| 233 |
B = q.shape[0] // num_heads
|