cp524 commited on
Commit
f1f0a8f
·
1 Parent(s): b1beaa0

Add support for non-divisible batch sizes

Browse files
Files changed (1) hide show
  1. src/smc/pipeline.py +15 -0
src/smc/pipeline.py CHANGED
@@ -581,6 +581,21 @@ class Pipeline(
581
  else:
582
  img_ids = _prepare_latent_image_ids(model_input.shape[0],2*model_input.shape[1],2*model_input.shape[2],model_input.device,model_input.dtype)
583
  txt_ids = torch.zeros(encoder_hidden_states.shape[1],3).to(device = encoder_hidden_states.device, dtype = encoder_hidden_states.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584
  model_output = self.transformer(
585
  hidden_states = model_input,
586
  micro_conds=micro_conds,
 
581
  else:
582
  img_ids = _prepare_latent_image_ids(model_input.shape[0],2*model_input.shape[1],2*model_input.shape[2],model_input.device,model_input.dtype)
583
  txt_ids = torch.zeros(encoder_hidden_states.shape[1],3).to(device = encoder_hidden_states.device, dtype = encoder_hidden_states.dtype)
584
+
585
+ if prompt_embeds.shape[0] != model_input.shape[0]:
586
+ # This can happen for the last batch (if batch_p is not divisble by total particles)
587
+ if guidance_scale > 1.0:
588
+ batch_p = prompt_embeds.shape[0] // 2
589
+ last_batch_size = model_input.shape[0] // 2
590
+ prompt_embeds = torch.cat([prompt_embeds[:last_batch_size], prompt_embeds[batch_p :batch_p + last_batch_size]])
591
+ encoder_hidden_states = torch.cat([encoder_hidden_states[:last_batch_size], encoder_hidden_states[batch_p :batch_p + last_batch_size]])
592
+ micro_conds = torch.cat([micro_conds[:last_batch_size], micro_conds[batch_p :batch_p + last_batch_size]])
593
+ else:
594
+ last_batch_size = model_input.shape[0]
595
+ prompt_embeds = prompt_embeds[:last_batch_size]
596
+ encoder_hidden_states = encoder_hidden_states[:last_batch_size]
597
+ micro_conds = micro_conds[:last_batch_size]
598
+
599
  model_output = self.transformer(
600
  hidden_states = model_input,
601
  micro_conds=micro_conds,