Spaces:
Sleeping
Sleeping
Add support for non-divisible batch sizes
Browse files- src/smc/pipeline.py +15 -0
src/smc/pipeline.py
CHANGED
|
@@ -581,6 +581,21 @@ class Pipeline(
|
|
| 581 |
else:
|
| 582 |
img_ids = _prepare_latent_image_ids(model_input.shape[0],2*model_input.shape[1],2*model_input.shape[2],model_input.device,model_input.dtype)
|
| 583 |
txt_ids = torch.zeros(encoder_hidden_states.shape[1],3).to(device = encoder_hidden_states.device, dtype = encoder_hidden_states.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
model_output = self.transformer(
|
| 585 |
hidden_states = model_input,
|
| 586 |
micro_conds=micro_conds,
|
|
|
|
| 581 |
else:
|
| 582 |
img_ids = _prepare_latent_image_ids(model_input.shape[0],2*model_input.shape[1],2*model_input.shape[2],model_input.device,model_input.dtype)
|
| 583 |
txt_ids = torch.zeros(encoder_hidden_states.shape[1],3).to(device = encoder_hidden_states.device, dtype = encoder_hidden_states.dtype)
|
| 584 |
+
|
| 585 |
+
if prompt_embeds.shape[0] != model_input.shape[0]:
|
| 586 |
+
# This can happen for the last batch (if batch_p is not divisble by total particles)
|
| 587 |
+
if guidance_scale > 1.0:
|
| 588 |
+
batch_p = prompt_embeds.shape[0] // 2
|
| 589 |
+
last_batch_size = model_input.shape[0] // 2
|
| 590 |
+
prompt_embeds = torch.cat([prompt_embeds[:last_batch_size], prompt_embeds[batch_p :batch_p + last_batch_size]])
|
| 591 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states[:last_batch_size], encoder_hidden_states[batch_p :batch_p + last_batch_size]])
|
| 592 |
+
micro_conds = torch.cat([micro_conds[:last_batch_size], micro_conds[batch_p :batch_p + last_batch_size]])
|
| 593 |
+
else:
|
| 594 |
+
last_batch_size = model_input.shape[0]
|
| 595 |
+
prompt_embeds = prompt_embeds[:last_batch_size]
|
| 596 |
+
encoder_hidden_states = encoder_hidden_states[:last_batch_size]
|
| 597 |
+
micro_conds = micro_conds[:last_batch_size]
|
| 598 |
+
|
| 599 |
model_output = self.transformer(
|
| 600 |
hidden_states = model_input,
|
| 601 |
micro_conds=micro_conds,
|