ydshieh committed
Commit · 64afcd5 · 1 Parent(s): ec3ceb6

Fix style

vit_gpt2/modeling_flax_gpt2.py  +27 -11
CHANGED
--- a/vit_gpt2/modeling_flax_gpt2.py
+++ b/vit_gpt2/modeling_flax_gpt2.py
@@ -24,7 +24,10 @@ from flax.linen.attention import dot_product_attention_weights
 from jax import lax
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_outputs import FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxCausalLMOutputWithCrossAttentions
+from ...modeling_flax_outputs import (
+    FlaxBaseModelOutputWithPastAndCrossAttentions,
+    FlaxCausalLMOutputWithCrossAttentions,
+)
 from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
 from ...utils import logging
 from .configuration_gpt2 import GPT2Config
@@ -301,7 +304,9 @@ class FlaxGPT2Block(nn.Module):
         self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
 
         if self.config.add_cross_attention:
-            self.crossattention = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, is_cross_attention=True)
+            self.crossattention = FlaxGPT2Attention(
+                config=self.config, dtype=self.dtype, causal=False, is_cross_attention=True
+            )
             self.ln_cross_attn = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
 
         project_encoder = getattr(self.config, "project_encoder", None)
@@ -337,7 +342,6 @@ class FlaxGPT2Block(nn.Module):
         hidden_states = attn_output + residual
 
         # Cross-Attention Block
-        cross_attn_weights = None
         if encoder_hidden_states is not None:
             # add one self-attention block for cross-attention
             if not hasattr(self, "crossattention"):
@@ -413,13 +417,16 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
             encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
             encoder_attention_mask = attention_mask
             module_init_outputs = self.module.init(
-                rngs,
-                input_ids, attention_mask, position_ids, encoder_hidden_states, encoder_attention_mask, return_dict=False
+                rngs,
+                input_ids,
+                attention_mask,
+                position_ids,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                return_dict=False,
             )
         else:
-            module_init_outputs = self.module.init(
-                rngs, input_ids, attention_mask, position_ids, return_dict=False
-            )
+            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
 
         return module_init_outputs["params"]
 
@@ -660,7 +667,11 @@ class FlaxGPT2Model(FlaxGPT2PreTrainedModel):
 
 
 append_call_sample_docstring(
-    FlaxGPT2Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPastAndCrossAttentions, _CONFIG_FOR_DOC
+    FlaxGPT2Model,
+    _TOKENIZER_FOR_DOC,
+    _CHECKPOINT_FOR_DOC,
+    FlaxBaseModelOutputWithPastAndCrossAttentions,
+    _CONFIG_FOR_DOC,
 )
 
 
@@ -718,9 +729,10 @@ class FlaxGPT2LMHeadModule(nn.Module):
             logits=lm_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-            cross_attentions=outputs.cross_attentions
+            cross_attentions=outputs.cross_attentions,
         )
 
+
 @add_start_docstrings(
     """
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
@@ -759,5 +771,9 @@ class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
 
 
 append_call_sample_docstring(
-    FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutputWithCrossAttentions, _CONFIG_FOR_DOC
+    FlaxGPT2LMHeadModel,
+    _TOKENIZER_FOR_DOC,
+    _CHECKPOINT_FOR_DOC,
+    FlaxCausalLMOutputWithCrossAttentions,
+    _CONFIG_FOR_DOC,
 )
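For context on the init_weights branch reformatted above, here is a minimal, self-contained Flax sketch of the same call shape. ToyBlock and its layers are hypothetical stand-ins, not code from this file; only the pattern of passing an rng dict followed by the positional inputs (including the optional encoder tensor) mirrors the self.module.init(...) call in the diff.

# Minimal sketch (assumed toy module, not part of this commit): driving a Flax
# module's init with positional inputs plus an optional encoder tensor.
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyBlock(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, input_ids, attention_mask, position_ids, encoder_hidden_states=None):
        # Embed the ids and, when an encoder tensor is given, mix it in the way a
        # cross-attention layer would; Flax only creates parameters for branches that run.
        x = nn.Embed(num_embeddings=100, features=self.features)(input_ids)
        if encoder_hidden_states is not None:
            x = x + nn.Dense(self.features)(encoder_hidden_states)
        return nn.Dense(self.features)(x * attention_mask[..., None] + position_ids[..., None])


input_shape = (1, 5)
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(input_shape[-1]), input_shape)
encoder_hidden_states = jnp.zeros(input_shape + (8,))

rngs = {"params": jax.random.PRNGKey(0)}
module = ToyBlock()

# Same call shape as the multi-line init above: the rng dict first, then the inputs.
variables = module.init(
    rngs,
    input_ids,
    attention_mask,
    position_ids,
    encoder_hidden_states,
)
params = variables["params"]
print(jax.tree_util.tree_map(jnp.shape, params))

variables["params"] plays the role of what init_weights returns; because Flax only materializes parameters for code paths that execute during init, the real code passes encoder inputs when add_cross_attention is enabled so the cross-attention weights are created.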