Spaces:
Running
on
Zero
Running
on
Zero
hatmanstack
committed on
Commit
·
28111ae
1
Parent(s):
e59f1dc
add SD35AdaLayerNormZeroX
Browse files- models_attention.py +35 -1
models_attention.py
CHANGED
|
@@ -22,12 +22,13 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
|
|
| 22 |
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
|
| 23 |
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
|
| 24 |
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
|
| 25 |
-
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
| 26 |
|
| 27 |
|
| 28 |
logger = logging.get_logger(__name__)
|
| 29 |
|
| 30 |
|
|
|
|
| 31 |
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
| 32 |
# "feed_forward_chunk_size" can be used to save memory
|
| 33 |
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
|
@@ -42,6 +43,39 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
|
|
| 42 |
)
|
| 43 |
return ff_output
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
@maybe_allow_in_graph
|
| 47 |
class GatedSelfAttentionDense(nn.Module):
|
|
|
|
| 22 |
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
|
| 23 |
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
|
| 24 |
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
|
| 25 |
+
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
| 26 |
|
| 27 |
|
| 28 |
logger = logging.get_logger(__name__)
|
| 29 |
|
| 30 |
|
| 31 |
+
|
| 32 |
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
| 33 |
# "feed_forward_chunk_size" can be used to save memory
|
| 34 |
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
|
|
|
| 43 |
)
|
| 44 |
return ff_output
|
| 45 |
|
| 46 |
+
@maybe_allow_in_graph
|
| 47 |
+
class SD35AdaLayerNormZeroX(nn.Module):
    r"""
    Adaptive layer norm zero (AdaLN-Zero) variant that produces nine modulation tensors from one embedding.

    The conditioning embedding is passed through SiLU and a linear projection to `9 * embedding_dim` features, which
    are split along dim 1 into nine equal chunks named `shift_msa`, `scale_msa`, `gate_msa`, `shift_mlp`,
    `scale_mlp`, `gate_mlp`, `shift_msa2`, `scale_msa2` and `gate_msa2`. The two `shift`/`scale` pairs for `msa` and
    `msa2` are applied to the normalized hidden states here; the remaining chunks are returned unchanged for the
    caller to apply.

    Parameters:
        embedding_dim (`int`):
            Input feature size of the linear projection and the normalized (last) dimension of the layer norm.
        norm_type (`str`, defaults to `"layer_norm"`):
            Normalization to use. Only `"layer_norm"` is supported; any other value raises a `ValueError`.
        bias (`bool`, defaults to `True`):
            Whether the linear projection has a bias term.
    """

    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        # 9 = 3 modulation groups (msa, mlp, msa2) x 3 tensors each (shift, scale, gate).
        self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            # No learnable affine parameters: scale and shift come from the conditioning embedding instead.
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")

    def forward(
        self,
        hidden_states: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, ...]:
        r"""
        Normalize `hidden_states` and modulate them with values derived from `emb`.

        Args:
            hidden_states (`torch.Tensor`):
                Tensor to normalize; its last dimension must be `embedding_dim`.
                NOTE(review): the `[:, None]` broadcast below suggests a `(batch, seq, embedding_dim)` layout --
                confirm against callers.
            emb (`torch.Tensor`):
                Conditioning embedding. Required despite the `None` default, which is kept only so the signature
                stays unchanged. It is chunked along dim 1, so it is presumably `(batch, embedding_dim)`.

        Returns:
            `Tuple[torch.Tensor, ...]`: `(hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp,
            norm_hidden_states2, gate_msa2)`, where `hidden_states` is the normalized input modulated by the `msa`
            shift/scale and `norm_hidden_states2` is the same normalized input modulated by the `msa2` shift/scale.

        Raises:
            TypeError: If `emb` is `None`.
        """
        if emb is None:
            # Fail early with a readable message. The exception type matches what the SiLU call would raise on
            # `None`, so existing callers see the same error class.
            raise TypeError("`emb` must be a `torch.Tensor`, but got `None`.")

        emb = self.linear(self.silu(emb))
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            shift_msa2,
            scale_msa2,
            gate_msa2,
        ) = emb.chunk(9, dim=1)

        # Both modulated outputs share the same normalized input; only the shift/scale pair differs.
        norm_hidden_states = self.norm(hidden_states)
        hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
        norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
|
| 79 |
|
| 80 |
@maybe_allow_in_graph
|
| 81 |
class GatedSelfAttentionDense(nn.Module):
|