fitfreakdvlp commited on
Commit
581e339
·
verified ·
1 Parent(s): b557978

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +9 -4
pipeline.py CHANGED
@@ -219,10 +219,15 @@ class AAS_XL(AttentionBase):
219
  self.mask = mask # mask with shape (1, 1 ,h, w)
220
  self.ss_steps = ss_steps
221
  self.ss_scale = ss_scale
222
- self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
223
- self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
224
- self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
225
- self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()
 
 
 
 
 
226
 
227
  def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, is_mask_attn, mask, **kwargs):
228
  B = q.shape[0] // num_heads
 
219
  self.mask = mask # mask with shape (1, 1 ,h, w)
220
  self.ss_steps = ss_steps
221
  self.ss_scale = ss_scale
222
+ # self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
223
+ # self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
224
+ # self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
225
+ # self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()
226
+ # target_size: 입력된 mask의 크기 (height, width)
227
+ self.mask_16 = F.max_pool2d(mask, (mask.shape[-2] // 16, mask.shape[-1] // 16)).round().squeeze().squeeze()
228
+ self.mask_32 = F.max_pool2d(mask, (mask.shape[-2] // 32, mask.shape[-1] // 32)).round().squeeze().squeeze()
229
+ self.mask_64 = F.max_pool2d(mask, (mask.shape[-2] // 64, mask.shape[-1] // 64)).round().squeeze().squeeze()
230
+ self.mask_128 = F.max_pool2d(mask, (mask.shape[-2] // 128, mask.shape[-1] // 128)).round().squeeze().squeeze()
231
 
232
  def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, is_mask_attn, mask, **kwargs):
233
  B = q.shape[0] // num_heads