Spaces:
Runtime error
Runtime error
remove bias and minor fixes
Browse files- dalle_mini/modeling_bart_flax.py +30 -55
dalle_mini/modeling_bart_flax.py
CHANGED
|
@@ -44,7 +44,7 @@ from transformers.modeling_flax_utils import (
|
|
| 44 |
from transformers.utils import logging
|
| 45 |
|
| 46 |
|
| 47 |
-
from configuration_bart import BartConfig
|
| 48 |
|
| 49 |
|
| 50 |
logger = logging.get_logger(__name__)
|
|
@@ -80,7 +80,7 @@ class FlaxBartAttention(nn.Module):
|
|
| 80 |
dense = partial(
|
| 81 |
nn.Dense,
|
| 82 |
self.embed_dim,
|
| 83 |
-
use_bias=
|
| 84 |
dtype=self.dtype,
|
| 85 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 86 |
)
|
|
@@ -242,10 +242,14 @@ class FlaxBartEncoderLayer(nn.Module):
|
|
| 242 |
self.fc1 = nn.Dense(
|
| 243 |
self.config.encoder_ffn_dim,
|
| 244 |
dtype=self.dtype,
|
|
|
|
| 245 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 246 |
)
|
| 247 |
self.fc2 = nn.Dense(
|
| 248 |
-
self.embed_dim,
|
|
|
|
|
|
|
|
|
|
| 249 |
)
|
| 250 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 251 |
|
|
@@ -325,14 +329,18 @@ class FlaxBartDecoderLayer(nn.Module):
|
|
| 325 |
dropout=self.config.attention_dropout,
|
| 326 |
dtype=self.dtype,
|
| 327 |
)
|
| 328 |
-
self.encoder_attn_layer_norm = nn
|
| 329 |
self.fc1 = nn.Dense(
|
| 330 |
self.config.encoder_ffn_dim,
|
| 331 |
dtype=self.dtype,
|
|
|
|
| 332 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 333 |
)
|
| 334 |
self.fc2 = nn.Dense(
|
| 335 |
-
self.embed_dim,
|
|
|
|
|
|
|
|
|
|
| 336 |
)
|
| 337 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 338 |
|
|
@@ -414,7 +422,6 @@ class FlaxBartDecoderLayerCollection(nn.Module):
|
|
| 414 |
class FlaxBartEncoder(nn.Module):
|
| 415 |
config: BartConfig
|
| 416 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 417 |
-
embed_tokens: Optional[nn.Embed] = None
|
| 418 |
|
| 419 |
def setup(self):
|
| 420 |
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
|
@@ -424,16 +431,15 @@ class FlaxBartEncoder(nn.Module):
|
|
| 424 |
self.max_source_positions = self.config.max_position_embeddings
|
| 425 |
self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
|
| 426 |
|
| 427 |
-
|
| 428 |
-
self.
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
)
|
| 433 |
|
| 434 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
| 435 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
| 436 |
-
self.offset =
|
| 437 |
self.embed_positions = nn.Embed(
|
| 438 |
self.config.max_position_embeddings + self.offset,
|
| 439 |
embed_dim,
|
|
@@ -472,7 +478,6 @@ class FlaxBartEncoder(nn.Module):
|
|
| 472 |
class FlaxBartDecoder(nn.Module):
|
| 473 |
config: BartConfig
|
| 474 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 475 |
-
embed_tokens: Optional[nn.Embed] = None
|
| 476 |
|
| 477 |
def setup(self):
|
| 478 |
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
|
@@ -482,18 +487,17 @@ class FlaxBartDecoder(nn.Module):
|
|
| 482 |
self.max_target_positions = self.config.max_position_embeddings
|
| 483 |
self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
|
| 484 |
|
| 485 |
-
|
| 486 |
-
self.
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
)
|
| 491 |
|
| 492 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
| 493 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
| 494 |
-
self.offset =
|
| 495 |
self.embed_positions = nn.Embed(
|
| 496 |
-
self.config.
|
| 497 |
embed_dim,
|
| 498 |
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
| 499 |
)
|
|
@@ -546,20 +550,8 @@ class FlaxBartModule(nn.Module):
|
|
| 546 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 547 |
|
| 548 |
def setup(self):
|
| 549 |
-
self.
|
| 550 |
-
|
| 551 |
-
self.config.d_model,
|
| 552 |
-
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
| 553 |
-
)
|
| 554 |
-
# a separate embedding is used for the decoder
|
| 555 |
-
self.decoder_embed = nn.Embed(
|
| 556 |
-
self.config.decoder_vocab_size,
|
| 557 |
-
self.config.d_model,
|
| 558 |
-
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
| 559 |
-
)
|
| 560 |
-
|
| 561 |
-
self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
|
| 562 |
-
self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.decoder_embed)
|
| 563 |
|
| 564 |
def _get_encoder_module(self):
|
| 565 |
return self.encoder
|
|
@@ -575,8 +567,6 @@ class FlaxBartModule(nn.Module):
|
|
| 575 |
decoder_attention_mask,
|
| 576 |
position_ids,
|
| 577 |
decoder_position_ids,
|
| 578 |
-
output_attentions: bool = False,
|
| 579 |
-
output_hidden_states: bool = False,
|
| 580 |
return_dict: bool = True,
|
| 581 |
deterministic: bool = True,
|
| 582 |
):
|
|
@@ -584,9 +574,6 @@ class FlaxBartModule(nn.Module):
|
|
| 584 |
input_ids=input_ids,
|
| 585 |
attention_mask=attention_mask,
|
| 586 |
position_ids=position_ids,
|
| 587 |
-
output_attentions=output_attentions,
|
| 588 |
-
output_hidden_states=output_hidden_states,
|
| 589 |
-
return_dict=return_dict,
|
| 590 |
deterministic=deterministic,
|
| 591 |
)
|
| 592 |
|
|
@@ -596,9 +583,6 @@ class FlaxBartModule(nn.Module):
|
|
| 596 |
position_ids=decoder_position_ids,
|
| 597 |
encoder_hidden_states=encoder_outputs[0],
|
| 598 |
encoder_attention_mask=attention_mask,
|
| 599 |
-
output_attentions=output_attentions,
|
| 600 |
-
output_hidden_states=output_hidden_states,
|
| 601 |
-
return_dict=return_dict,
|
| 602 |
deterministic=deterministic,
|
| 603 |
)
|
| 604 |
|
|
@@ -629,8 +613,8 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|
| 629 |
dtype: jnp.dtype = jnp.float32,
|
| 630 |
**kwargs,
|
| 631 |
):
|
| 632 |
-
module = self.module_class(config=config, dtype=dtype
|
| 633 |
-
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
| 634 |
|
| 635 |
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
| 636 |
# init input tensors
|
|
@@ -755,17 +739,11 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|
| 755 |
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
| 756 |
position_ids: Optional[jnp.ndarray] = None,
|
| 757 |
decoder_position_ids: Optional[jnp.ndarray] = None,
|
| 758 |
-
output_attentions: Optional[bool] = None,
|
| 759 |
-
output_hidden_states: Optional[bool] = None,
|
| 760 |
return_dict: Optional[bool] = None,
|
| 761 |
train: bool = False,
|
| 762 |
params: dict = None,
|
| 763 |
dropout_rng: PRNGKey = None,
|
| 764 |
):
|
| 765 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 766 |
-
output_hidden_states = (
|
| 767 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 768 |
-
)
|
| 769 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 770 |
|
| 771 |
# prepare encoder inputs
|
|
@@ -817,7 +795,6 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
|
|
| 817 |
dtype=self.dtype,
|
| 818 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 819 |
)
|
| 820 |
-
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.decoder_vocab_size))
|
| 821 |
|
| 822 |
def _get_encoder_module(self):
|
| 823 |
return self.model.encoder
|
|
@@ -853,8 +830,6 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
|
|
| 853 |
else:
|
| 854 |
lm_logits = self.lm_head(hidden_states)
|
| 855 |
|
| 856 |
-
lm_logits += self.final_logits_bias
|
| 857 |
-
|
| 858 |
return FlaxSeq2SeqLMOutput(
|
| 859 |
logits=lm_logits,
|
| 860 |
decoder_hidden_states=outputs.decoder_hidden_states,
|
|
|
|
| 44 |
from transformers.utils import logging
|
| 45 |
|
| 46 |
|
| 47 |
+
from .configuration_bart import BartConfig
|
| 48 |
|
| 49 |
|
| 50 |
logger = logging.get_logger(__name__)
|
|
|
|
| 80 |
dense = partial(
|
| 81 |
nn.Dense,
|
| 82 |
self.embed_dim,
|
| 83 |
+
use_bias=False,
|
| 84 |
dtype=self.dtype,
|
| 85 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 86 |
)
|
|
|
|
| 242 |
self.fc1 = nn.Dense(
|
| 243 |
self.config.encoder_ffn_dim,
|
| 244 |
dtype=self.dtype,
|
| 245 |
+
use_bias=False,
|
| 246 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 247 |
)
|
| 248 |
self.fc2 = nn.Dense(
|
| 249 |
+
self.embed_dim,
|
| 250 |
+
dtype=self.dtype,
|
| 251 |
+
use_bias=False,
|
| 252 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 253 |
)
|
| 254 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 255 |
|
|
|
|
| 329 |
dropout=self.config.attention_dropout,
|
| 330 |
dtype=self.dtype,
|
| 331 |
)
|
| 332 |
+
self.encoder_attn_layer_norm = nn
|
| 333 |
self.fc1 = nn.Dense(
|
| 334 |
self.config.encoder_ffn_dim,
|
| 335 |
dtype=self.dtype,
|
| 336 |
+
use_bias=False,
|
| 337 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 338 |
)
|
| 339 |
self.fc2 = nn.Dense(
|
| 340 |
+
self.embed_dim,
|
| 341 |
+
dtype=self.dtype,
|
| 342 |
+
use_bias=False,
|
| 343 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 344 |
)
|
| 345 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 346 |
|
|
|
|
| 422 |
class FlaxBartEncoder(nn.Module):
|
| 423 |
config: BartConfig
|
| 424 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
|
|
| 425 |
|
| 426 |
def setup(self):
|
| 427 |
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
|
|
|
| 431 |
self.max_source_positions = self.config.max_position_embeddings
|
| 432 |
self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
|
| 433 |
|
| 434 |
+
self.embed_tokens = nn.Embed(
|
| 435 |
+
self.config.vocab_size,
|
| 436 |
+
embed_dim,
|
| 437 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
| 438 |
+
)
|
|
|
|
| 439 |
|
| 440 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
| 441 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
| 442 |
+
self.offset = 0
|
| 443 |
self.embed_positions = nn.Embed(
|
| 444 |
self.config.max_position_embeddings + self.offset,
|
| 445 |
embed_dim,
|
|
|
|
| 478 |
class FlaxBartDecoder(nn.Module):
|
| 479 |
config: BartConfig
|
| 480 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
|
|
| 481 |
|
| 482 |
def setup(self):
|
| 483 |
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
|
|
|
| 487 |
self.max_target_positions = self.config.max_position_embeddings
|
| 488 |
self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
|
| 489 |
|
| 490 |
+
self.embed_tokens = nn.Embed(
|
| 491 |
+
self.config.decoder_vocab_size,
|
| 492 |
+
embed_dim,
|
| 493 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
| 494 |
+
)
|
|
|
|
| 495 |
|
| 496 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
| 497 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
| 498 |
+
self.offset = 0
|
| 499 |
self.embed_positions = nn.Embed(
|
| 500 |
+
self.config.decoder_max_position_embeddings + self.offset,
|
| 501 |
embed_dim,
|
| 502 |
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
| 503 |
)
|
|
|
|
| 550 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 551 |
|
| 552 |
def setup(self):
|
| 553 |
+
self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype)
|
| 554 |
+
self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 555 |
|
| 556 |
def _get_encoder_module(self):
|
| 557 |
return self.encoder
|
|
|
|
| 567 |
decoder_attention_mask,
|
| 568 |
position_ids,
|
| 569 |
decoder_position_ids,
|
|
|
|
|
|
|
| 570 |
return_dict: bool = True,
|
| 571 |
deterministic: bool = True,
|
| 572 |
):
|
|
|
|
| 574 |
input_ids=input_ids,
|
| 575 |
attention_mask=attention_mask,
|
| 576 |
position_ids=position_ids,
|
|
|
|
|
|
|
|
|
|
| 577 |
deterministic=deterministic,
|
| 578 |
)
|
| 579 |
|
|
|
|
| 583 |
position_ids=decoder_position_ids,
|
| 584 |
encoder_hidden_states=encoder_outputs[0],
|
| 585 |
encoder_attention_mask=attention_mask,
|
|
|
|
|
|
|
|
|
|
| 586 |
deterministic=deterministic,
|
| 587 |
)
|
| 588 |
|
|
|
|
| 613 |
dtype: jnp.dtype = jnp.float32,
|
| 614 |
**kwargs,
|
| 615 |
):
|
| 616 |
+
module = self.module_class(config=config, dtype=dtype)
|
| 617 |
+
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, **kwargs)
|
| 618 |
|
| 619 |
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
| 620 |
# init input tensors
|
|
|
|
| 739 |
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
| 740 |
position_ids: Optional[jnp.ndarray] = None,
|
| 741 |
decoder_position_ids: Optional[jnp.ndarray] = None,
|
|
|
|
|
|
|
| 742 |
return_dict: Optional[bool] = None,
|
| 743 |
train: bool = False,
|
| 744 |
params: dict = None,
|
| 745 |
dropout_rng: PRNGKey = None,
|
| 746 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 747 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 748 |
|
| 749 |
# prepare encoder inputs
|
|
|
|
| 795 |
dtype=self.dtype,
|
| 796 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 797 |
)
|
|
|
|
| 798 |
|
| 799 |
def _get_encoder_module(self):
|
| 800 |
return self.model.encoder
|
|
|
|
| 830 |
else:
|
| 831 |
lm_logits = self.lm_head(hidden_states)
|
| 832 |
|
|
|
|
|
|
|
| 833 |
return FlaxSeq2SeqLMOutput(
|
| 834 |
logits=lm_logits,
|
| 835 |
decoder_hidden_states=outputs.decoder_hidden_states,
|