Update src/utils.py
Browse files- src/utils.py +8 -4
src/utils.py
CHANGED
|
@@ -193,7 +193,6 @@ def warpped_feature(sample, step):
|
|
| 193 |
def warpped_skip_feature(block_samples, step):
|
| 194 |
down_block_res_samples = []
|
| 195 |
# print(block_samples.shape, step)
|
| 196 |
-
print(step)
|
| 197 |
for sample in block_samples:
|
| 198 |
sample_expand = warpped_feature(sample, step)
|
| 199 |
down_block_res_samples.append(sample_expand)
|
|
@@ -205,9 +204,14 @@ def warpped_text_emb(text_emb, step):
|
|
| 205 |
step: timestep span
|
| 206 |
"""
|
| 207 |
bs, token_len, dim = text_emb.shape
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
return torch.cat([uncond_fea, cond_fea]) # (step*bs) * 77 *768
|
| 212 |
|
| 213 |
def warpped_timestep(timesteps, bs):
|
|
|
|
| 193 |
def warpped_skip_feature(block_samples, step):
|
| 194 |
down_block_res_samples = []
|
| 195 |
# print(block_samples.shape, step)
|
|
|
|
| 196 |
for sample in block_samples:
|
| 197 |
sample_expand = warpped_feature(sample, step)
|
| 198 |
down_block_res_samples.append(sample_expand)
|
|
|
|
| 204 |
step: timestep span
|
| 205 |
"""
|
| 206 |
bs, token_len, dim = text_emb.shape
|
| 207 |
+
if text_emb.shape[0] >= 2:
|
| 208 |
+
uncond_fea, cond_fea = text_emb.chunk(2)
|
| 209 |
+
uncond_fea = uncond_fea.repeat(step,1,1) # (step * bs//2) * 77 *768
|
| 210 |
+
cond_fea = cond_fea.repeat(step,1,1) # (step * bs//2) * 77 * 768
|
| 211 |
+
else:
|
| 212 |
+
uncond_fea, cond_fea = text_emb, text_emb
|
| 213 |
+
uncond_fea = uncond_fea.repeat(step,1,1) # (step * bs//2) * 77 *768
|
| 214 |
+
cond_fea = cond_fea.repeat(step,1,1) # (step * bs//2) * 77 * 768
|
| 215 |
return torch.cat([uncond_fea, cond_fea]) # (step*bs) * 77 *768
|
| 216 |
|
| 217 |
def warpped_timestep(timesteps, bs):
|