Upload pipeline.py
Browse files- pipeline.py +26 -26
pipeline.py
CHANGED
|
@@ -1611,18 +1611,33 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
|
|
| 1611 |
prompt_mask_input = prompt_mask
|
| 1612 |
latent_model_input = latents
|
| 1613 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1614 |
if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
| 1615 |
# Concatenate prompt embeddings
|
| 1616 |
prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 1617 |
pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
| 1618 |
|
| 1619 |
-
#
|
| 1620 |
-
if text_ids is not None and negative_text_ids is not None:
|
| 1621 |
-
|
| 1622 |
|
| 1623 |
# Concatenate latent image IDs if they are used
|
| 1624 |
-
if latent_image_ids is not None:
|
| 1625 |
-
|
| 1626 |
|
| 1627 |
# Concatenate prompt masks if they are used
|
| 1628 |
if prompt_mask is not None and negative_mask is not None:
|
|
@@ -1643,37 +1658,22 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
|
|
| 1643 |
# Prepare extra transformer arguments
|
| 1644 |
extra_transformer_args = {}
|
| 1645 |
if prompt_mask is not None:
|
| 1646 |
-
extra_transformer_args["attention_mask"] = prompt_mask_input.to(device=self.transformer.device)
|
| 1647 |
|
| 1648 |
# Forward pass through the transformer
|
| 1649 |
noise_pred = self.transformer(
|
| 1650 |
-
hidden_states=latent_model_input.to(device=self.transformer.device),
|
| 1651 |
timestep=timestep / 1000,
|
| 1652 |
guidance=guidance,
|
| 1653 |
-
pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device),
|
| 1654 |
-
encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device),
|
| 1655 |
-
txt_ids=text_ids_input.to(device=self.transformer.device) if text_ids is not None else None,
|
| 1656 |
-
img_ids=latent_image_ids_input.to(device=self.transformer.device) if latent_image_ids is not None else None,
|
| 1657 |
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1658 |
return_dict=False,
|
| 1659 |
**extra_transformer_args,
|
| 1660 |
)[0]
|
| 1661 |
|
| 1662 |
-
if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
| 1663 |
-
progress_bar.set_postfix(
|
| 1664 |
-
{
|
| 1665 |
-
'ts': timestep.detach().item() / 1000,
|
| 1666 |
-
'cfg': self._guidance_scale_real,
|
| 1667 |
-
},
|
| 1668 |
-
)
|
| 1669 |
-
else:
|
| 1670 |
-
progress_bar.set_postfix(
|
| 1671 |
-
{
|
| 1672 |
-
'ts': timestep.detach().item() / 1000,
|
| 1673 |
-
'cfg': 'N/A',
|
| 1674 |
-
},
|
| 1675 |
-
)
|
| 1676 |
-
|
| 1677 |
# Apply real CFG
|
| 1678 |
if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
| 1679 |
if do_batch_cfg:
|
|
|
|
| 1611 |
prompt_mask_input = prompt_mask
|
| 1612 |
latent_model_input = latents
|
| 1613 |
|
| 1614 |
+
if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
| 1615 |
+
progress_bar.set_postfix(
|
| 1616 |
+
{
|
| 1617 |
+
'ts': timestep.detach().item() / 1000,
|
| 1618 |
+
'cfg': self._guidance_scale_real,
|
| 1619 |
+
},
|
| 1620 |
+
)
|
| 1621 |
+
else:
|
| 1622 |
+
progress_bar.set_postfix(
|
| 1623 |
+
{
|
| 1624 |
+
'ts': timestep.detach().item() / 1000,
|
| 1625 |
+
'cfg': 'N/A',
|
| 1626 |
+
},
|
| 1627 |
+
)
|
| 1628 |
+
|
| 1629 |
if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
| 1630 |
# Concatenate prompt embeddings
|
| 1631 |
prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 1632 |
pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
| 1633 |
|
| 1634 |
+
# Concatenate text IDs if they are used
|
| 1635 |
+
# if text_ids is not None and negative_text_ids is not None:
|
| 1636 |
+
# text_ids_input = torch.cat([negative_text_ids, text_ids], dim=0)
|
| 1637 |
|
| 1638 |
# Concatenate latent image IDs if they are used
|
| 1639 |
+
# if latent_image_ids is not None:
|
| 1640 |
+
# latent_image_ids_input = torch.cat([latent_image_ids, latent_image_ids], dim=0)
|
| 1641 |
|
| 1642 |
# Concatenate prompt masks if they are used
|
| 1643 |
if prompt_mask is not None and negative_mask is not None:
|
|
|
|
| 1658 |
# Prepare extra transformer arguments
|
| 1659 |
extra_transformer_args = {}
|
| 1660 |
if prompt_mask is not None:
|
| 1661 |
+
extra_transformer_args["attention_mask"] = prompt_mask_input.to(device=self.transformer.device).contiguous()
|
| 1662 |
|
| 1663 |
# Forward pass through the transformer
|
| 1664 |
noise_pred = self.transformer(
|
| 1665 |
+
hidden_states=latent_model_input.to(device=self.transformer.device).contiguous() ,
|
| 1666 |
timestep=timestep / 1000,
|
| 1667 |
guidance=guidance,
|
| 1668 |
+
pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device).contiguous() ,
|
| 1669 |
+
encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device).contiguous() ,
|
| 1670 |
+
txt_ids=text_ids_input.to(device=self.transformer.device).contiguous() if text_ids is not None else None,
|
| 1671 |
+
img_ids=latent_image_ids_input.to(device=self.transformer.device).contiguous() if latent_image_ids is not None else None,
|
| 1672 |
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1673 |
return_dict=False,
|
| 1674 |
**extra_transformer_args,
|
| 1675 |
)[0]
|
| 1676 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1677 |
# Apply real CFG
|
| 1678 |
if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
| 1679 |
if do_batch_cfg:
|