feat: add sinkformer + custom final ln + pre-ln (#151)
Files changed:
- README.md (+11, -1)
- src/dalle_mini/model/configuration.py (+22, -12)
- src/dalle_mini/model/modeling.py (+75, -17)
README.md (CHANGED)

````diff
@@ -124,8 +124,9 @@
 - "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
 - "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
 - "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
-- "[CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/abs/2105.13290v2)
+- "[CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/abs/2105.13290v2)"
 - "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
+- "[Sinkformers: Transformers with Doubly Stochastic Attention](https://arxiv.org/abs/2110.11773)"

 Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".

@@ -247,3 +248,12 @@
   primaryClass = {cs.LG}
 }
 ```
+
+```text
+@misc{title = {Sinkformers: Transformers with Doubly Stochastic Attention},
+  url = {https://arxiv.org/abs/2110.11773},
+  author = {Sander, Michael E. and Ablin, Pierre and Blondel, Mathieu and Peyré, Gabriel},
+  publisher = {arXiv},
+  year = {2021},
+}
+```
````
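For context on the newly referenced Sinkformers paper: it replaces the purely row-wise softmax with Sinkhorn-style renormalization, so the attention matrix approaches a doubly stochastic one. Below is a minimal standalone sketch of that idea; the `sinkhorn_attention` helper is illustrative only and not part of this repository.

```python
import jax.numpy as jnp
from jax.nn import softmax

def sinkhorn_attention(scores, n_iters=3, eps=1e-8):
    """Softmax followed by alternating column/row renormalization.

    n_iters=1 is plain softmax attention; larger values push the matrix
    toward being doubly stochastic (rows and columns both sum to ~1).
    """
    attn = softmax(scores, axis=-1)          # rows sum to 1
    for i in range(n_iters - 1):
        axis = -2 if i % 2 == 0 else -1      # alternate columns / rows
        attn = attn / (eps + attn.sum(axis=axis, keepdims=True))
    return attn

scores = jnp.array([[2.0, 0.0, 1.0],
                    [0.5, 1.5, 0.0],
                    [1.0, 1.0, 1.0]])
attn = sinkhorn_attention(scores, n_iters=7)
print(attn.sum(axis=-1))  # row sums, ~1
print(attn.sum(axis=-2))  # column sums approach 1 as n_iters grows
```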
src/dalle_mini/model/configuration.py (CHANGED)

```diff
@@ -59,37 +59,39 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         do_sample=True,
         # transformer variants
         ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
-        ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "deepnet" (same as postln)
+        ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
+        use_head_scale=False,  # used in NormFormer
         use_cosine_attention=False,  # used in Swin v2
         tau_init=0.05,  # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False,  # used in Deepnet
         use_glu=False,  # "GLU Variants Improve Transformer"
         use_alibi=False,  # from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
+        sinkhorn_iters=1,  # used in SinkFormers
+        use_final_ln_encoder=False,  # final layer normalization in encoder
+        use_final_ln_decoder=False,  # final layer normalization in decoder
         # parameters that should not be necessary but could affect results
-        force_ln_scale=
-        force_final_ln_encoder=False,  # force layer normalization in encoder final layer even when followed by dense layers
+        force_ln_scale=False,  # force scale in layernorm even when followed by dense layers
         **kwargs,
     ):
         # text normalizer
         self.normalize_text = normalize_text

         # transformer variants
-        self.
+        self.use_head_scale = use_head_scale  # per Normformer
         assert ln_type in [
             "rmsnorm",
             "layernorm",
         ], "ln_type must be 'rmsnorm' or 'layernorm'"
         self.ln_type = ln_type
+        if ln_positions == "deepnet":
+            ln_positions = "postln"
         assert ln_positions in [
             "normformer",
             "swinv2",
             "cogview",
-            ln_positions = "postln"
+            "postln",
+            "preln",
+        ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln'"
         assert use_alibi is False, "use_alibi is not supported yet"
         self.ln_positions = ln_positions
         self.use_cosine_attention = use_cosine_attention

@@ -97,9 +99,17 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         self.use_deepnet_scaling = use_deepnet_scaling
         self.use_glu = use_glu
         self.use_alibi = use_alibi
-        self.
+        self.sinkhorn_iters = sinkhorn_iters
+        if ln_positions == "postln":
+            assert (
+                use_final_ln_encoder
+            ), "use_final_ln_encoder must be True when ln_positions is 'postln'"
+            assert (
+                use_final_ln_decoder
+            ), "use_final_ln_decoder must be True when ln_positions is 'postln'"
+        self.use_final_ln_encoder = use_final_ln_encoder
+        self.use_final_ln_decoder = use_final_ln_decoder
         self.force_ln_scale = force_ln_scale
-        self.force_final_ln_encoder = force_final_ln_encoder

         # common parameters
         self.encoder_vocab_size = encoder_vocab_size
```
src/dalle_mini/model/modeling.py (CHANGED)

```diff
@@ -28,7 +28,7 @@ import msgpack.exceptions
 from flax.core.frozen_dict import unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen import partitioning as nn_partitioning
-from flax.linen.
+from flax.linen.linear import PrecisionLike
 from flax.serialization import from_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax

@@ -175,6 +175,66 @@ def norm(type, *args, **kwargs):
     raise ValueError(f"Unknown norm type {type}")


+def dot_product_attention_weights(
+    query: Any,
+    key: Any,
+    bias: Optional[Any] = None,
+    mask: Optional[Any] = None,
+    broadcast_dropout: bool = True,
+    dropout_rng: Optional[PRNGKey] = None,
+    dropout_rate: float = 0.0,
+    deterministic: bool = False,
+    dtype: Any = jnp.float32,
+    precision: PrecisionLike = None,
+    sinkhorn_iters: int = 1,
+):
+    """
+    Computes dot-product attention weights given query and key.
+
+    Adapted from flax.linen.attention.dot_product_attention_weights"
+    """
+    assert query.ndim == key.ndim, "q, k must have same rank."
+    assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
+    assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
+    assert query.shape[-1] == key.shape[-1], "q, k depths must match."
+
+    # calculate attention matrix
+    depth = query.shape[-1]
+    query = query / jnp.sqrt(depth).astype(dtype)
+    # attn weight shape is (batch..., num_heads, q_length, kv_length)
+    attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision)
+
+    # apply attention bias: masking, dropout, proximity bias, etc.
+    if bias is not None:
+        attn_weights = attn_weights + bias
+    # apply attention mask
+    if mask is not None:
+        big_neg = jnp.finfo(dtype).min
+        attn_weights = jnp.where(mask, attn_weights, big_neg)
+
+    # normalize the attention weights
+    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+    for i in range(sinkhorn_iters - 1):
+        axis = -2 if i % 2 == 0 else -1
+        attn_weights /= 1e-8 + jnp.sum(attn_weights, axis=axis, keepdims=True)
+
+    # apply attention dropout
+    if not deterministic and dropout_rate > 0.0:
+        keep_prob = 1.0 - dropout_rate
+        if broadcast_dropout:
+            # dropout is broadcast across the batch + head dimensions
+            dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
+            keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
+        else:
+            keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
+        multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(
+            keep_prob, dtype=dtype
+        )
+        attn_weights = attn_weights * multiplier
+
+    return attn_weights
+
+
 class FlaxBartAttention(FlaxBartAttention):
     """
     Edits:
```
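A quick standalone check of the adapted `dot_product_attention_weights` above. Shapes and values are illustrative; the snippet is not part of the commit and assumes the `dalle_mini` package (with this change) is importable.

```python
import jax
import jax.numpy as jnp

# Assumes the dalle_mini package containing this commit is importable.
from dalle_mini.model.modeling import dot_product_attention_weights

rng = jax.random.PRNGKey(0)
q_rng, k_rng = jax.random.split(rng)

# (batch, length, num_heads, head_dim), matching the einsum in the function
query = jax.random.normal(q_rng, (1, 8, 4, 16))
key = jax.random.normal(k_rng, (1, 8, 4, 16))

# sinkhorn_iters=1 reproduces plain softmax attention weights
w1 = dot_product_attention_weights(query, key, deterministic=True, sinkhorn_iters=1)
# more iterations push the weights toward a doubly stochastic matrix
w5 = dot_product_attention_weights(query, key, deterministic=True, sinkhorn_iters=5)

print(w1.shape)         # (1, 4, 8, 8): (batch, heads, q_length, kv_length)
print(w1.sum(axis=-1))  # rows sum to 1 (softmax)
print(w5.sum(axis=-2))  # column sums move toward 1 with more iterations
```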
```diff
@@ -225,7 +285,7 @@ class FlaxBartAttention(FlaxBartAttention):
         )
         self.dropout_layer = nn.Dropout(rate=self.dropout)

-        if self.config.
+        if self.config.use_head_scale:
             self.head_scale = self.param(
                 "head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
             )

@@ -342,13 +402,14 @@ class FlaxBartAttention(FlaxBartAttention):
             deterministic=deterministic,
             dtype=self.dtype,
             precision=None,
+            sinkhorn_iters=self.config.sinkhorn_iters,
         )
         if self.config.use_cosine_attention:
             # divide by tau
             attn_weights = attn_weights / jnp.maximum(self.tau, 0.01)

         attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
-        if self.config.
+        if self.config.use_head_scale:
             # per Normformer
             attn_output = attn_output * self.head_scale
         attn_output = self._merge_heads(attn_output)
```
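The NormFormer head scale gated above is a learnable parameter of shape `(1, 1, num_heads, 1)` multiplied into the per-head attention output. A toy illustration of the broadcasting, with hypothetical shapes outside the model:

```python
import jax.numpy as jnp

batch, q_len, num_heads, head_dim = 2, 8, 4, 16
attn_output = jnp.ones((batch, q_len, num_heads, head_dim))

# One learnable scalar per head, broadcast over batch, positions and depth.
head_scale = jnp.arange(1.0, num_heads + 1).reshape(1, 1, num_heads, 1)

scaled = attn_output * head_scale
print(scaled[0, 0, :, 0])  # [1. 2. 3. 4.]: each head is scaled independently
```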
```diff
@@ -373,7 +434,7 @@ class GLU(nn.Module):
             self.config
         )

-        if self.config.ln_positions in ["normformer", "cogview"]:
+        if self.config.ln_positions in ["normformer", "cogview", "preln"]:
             x = norm(
                 self.config.ln_type,
                 dtype=self.dtype,

@@ -438,7 +499,7 @@ class FFN(nn.Module):
         gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
             self.config
         )
-        if self.config.ln_positions in ["normformer", "cogview"]:
+        if self.config.ln_positions in ["normformer", "cogview", "preln"]:
             x = norm(
                 self.config.ln_type,
                 dtype=self.dtype,

@@ -507,7 +568,7 @@ class FlaxBartEncoderLayer(nn.Module):

         embed_dim = self.config.d_model
         residual = hidden_states
-        if self.config.ln_positions in ["normformer", "cogview"]:
+        if self.config.ln_positions in ["normformer", "cogview", "preln"]:
             hidden_states = norm(
                 self.config.ln_type,
                 dtype=self.dtype,

@@ -612,7 +673,7 @@ class FlaxBartDecoderLayer(nn.Module):
         residual = hidden_states

         # Self Attention
-        if self.config.ln_positions in ["normformer", "cogview"]:
+        if self.config.ln_positions in ["normformer", "cogview", "preln"]:
             hidden_states = norm(
                 self.config.ln_type,
                 dtype=self.dtype,

@@ -651,7 +712,7 @@ class FlaxBartDecoderLayer(nn.Module):
         cross_attn_weights = None
         if encoder_hidden_states is not None:
             residual = hidden_states
-            if self.config.ln_positions in ["normformer", "cogview"]:
+            if self.config.ln_positions in ["normformer", "cogview", "preln"]:
                 hidden_states = norm(
                     self.config.ln_type,
                     dtype=self.dtype,

@@ -759,12 +820,9 @@ class FlaxBartEncoderLayerCollection(nn.Module):
                 all_hidden_states += (hidden_states,)
             # final layernorm on the output of the last layer
             # or every 6 layers for Swin v2
-                self.config.ln_positions == "swinv2"
-                and ((i == n_layers - 1) or ((i + 1) % 6 == 0))
-            )
+            add_norm = (
+                self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
+            ) or (self.config.use_final_ln_encoder and (i == n_layers - 1))
             # we don't need to scale the norm for the last layer
             use_scale = i != n_layers - 1
             layer_outputs = layer(

@@ -839,9 +897,9 @@ class FlaxBartDecoderLayerCollection(nn.Module):
                 all_hidden_states += (hidden_states,)
             # final layernorm on the output of the last layer
             # or every 6 layers for Swin v2
-            add_norm = (
-            )
+            add_norm = (
+                self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
+            ) or (self.config.use_final_ln_decoder and (i == n_layers - 1))
             # we don't need to scale the norm for the last layer
             use_scale = i != n_layers - 1
             layer_outputs = layer(
```