manbeast3b commited on
Commit
4f23ede
·
verified ·
1 Parent(s): a283739

Update src/utils.py

Browse files
Files changed (1) hide show
  1. src/utils.py +8 -4
src/utils.py CHANGED
@@ -193,7 +193,6 @@ def warpped_feature(sample, step):
193
  def warpped_skip_feature(block_samples, step):
194
  down_block_res_samples = []
195
  # print(block_samples.shape, step)
196
- print(step)
197
  for sample in block_samples:
198
  sample_expand = warpped_feature(sample, step)
199
  down_block_res_samples.append(sample_expand)
@@ -205,9 +204,14 @@ def warpped_text_emb(text_emb, step):
205
  step: timestep span
206
  """
207
  bs, token_len, dim = text_emb.shape
208
- uncond_fea, cond_fea = text_emb.chunk(2)
209
- uncond_fea = uncond_fea.repeat(step,1,1) # (step * bs//2) * 77 *768
210
- cond_fea = cond_fea.repeat(step,1,1) # (step * bs//2) * 77 * 768
 
 
 
 
 
211
  return torch.cat([uncond_fea, cond_fea]) # (step*bs) * 77 *768
212
 
213
  def warpped_timestep(timesteps, bs):
 
193
  def warpped_skip_feature(block_samples, step):
194
  down_block_res_samples = []
195
  # print(block_samples.shape, step)
 
196
  for sample in block_samples:
197
  sample_expand = warpped_feature(sample, step)
198
  down_block_res_samples.append(sample_expand)
 
204
  step: timestep span
205
  """
206
  bs, token_len, dim = text_emb.shape
207
+ if text_emb.shape[0] >= 2:
208
+ uncond_fea, cond_fea = text_emb.chunk(2)
209
+ uncond_fea = uncond_fea.repeat(step,1,1) # (step * bs//2) * 77 *768
210
+ cond_fea = cond_fea.repeat(step,1,1) # (step * bs//2) * 77 * 768
211
+ else:
212
+ uncond_fea, cond_fea = text_emb, text_emb
213
+ uncond_fea = uncond_fea.repeat(step,1,1) # (step * bs//2) * 77 *768
214
+ cond_fea = cond_fea.repeat(step,1,1) # (step * bs//2) * 77 * 768
215
  return torch.cat([uncond_fea, cond_fea]) # (step*bs) * 77 *768
216
 
217
  def warpped_timestep(timesteps, bs):