Spaces:
Runtime error
Runtime error
fix: sinkformer
Browse files
src/dalle_mini/model/modeling.py
CHANGED
|
@@ -211,7 +211,7 @@ def dot_product_attention_weights(
|
|
| 211 |
dtype: Any = jnp.float32,
|
| 212 |
precision: PrecisionLike = None,
|
| 213 |
sinkhorn_iters: int = 1,
|
| 214 |
-
|
| 215 |
):
|
| 216 |
"""
|
| 217 |
Computes dot-product attention weights given query and key.
|
|
@@ -239,7 +239,7 @@ def dot_product_attention_weights(
|
|
| 239 |
attn_weights = attn_weights + embed_pos
|
| 240 |
|
| 241 |
# normalize the attention weights
|
| 242 |
-
if
|
| 243 |
# sinkhorn does not work for causal (leaks info of future tokens into past)
|
| 244 |
attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
|
| 245 |
else:
|
|
@@ -461,7 +461,7 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
| 461 |
dtype=self.dtype,
|
| 462 |
precision=None,
|
| 463 |
sinkhorn_iters=self.config.sinkhorn_iters,
|
| 464 |
-
|
| 465 |
)
|
| 466 |
if self.config.use_cosine_attention:
|
| 467 |
# divide by tau
|
|
|
|
| 211 |
dtype: Any = jnp.float32,
|
| 212 |
precision: PrecisionLike = None,
|
| 213 |
sinkhorn_iters: int = 1,
|
| 214 |
+
is_encoder: bool = False,
|
| 215 |
):
|
| 216 |
"""
|
| 217 |
Computes dot-product attention weights given query and key.
|
|
|
|
| 239 |
attn_weights = attn_weights + embed_pos
|
| 240 |
|
| 241 |
# normalize the attention weights
|
| 242 |
+
if not is_encoder or sinkhorn_iters == 1:
|
| 243 |
# sinkhorn does not work for causal (leaks info of future tokens into past)
|
| 244 |
attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
|
| 245 |
else:
|
|
|
|
| 461 |
dtype=self.dtype,
|
| 462 |
precision=None,
|
| 463 |
sinkhorn_iters=self.config.sinkhorn_iters,
|
| 464 |
+
is_encoder=self.is_encoder,
|
| 465 |
)
|
| 466 |
if self.config.use_cosine_attention:
|
| 467 |
# divide by tau
|