Sohan2004 committed on
Commit b16259e · verified · 1 Parent(s): 2f320d8

Upload IndicTransForConditionalGeneration

README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
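
A minimal loading sketch, assuming the usual IndicTrans2-style workflow: the repository id below is a placeholder, the tokenizer is assumed to resolve from the same repository (this commit only uploads the configuration, weights, and modeling code), and the pre-tagged `src_lang tgt_lang sentence` input format follows AI4Bharat's IndicTrans2 convention rather than anything documented here.

```python
# Hedged sketch, not an official recipe. The repo id is a placeholder; the
# tokenizer may instead need to come from the original IndicTrans2 release,
# since this commit only adds config, weights, and modeling code.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

repo_id = "Sohan2004/<this-model>"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# IndicTrans2-style models expect preprocessed, language-tagged input; this
# example string is illustrative only.
batch = tokenizer("eng_Latn hin_Deva This is a test sentence.", return_tensors="pt")

with torch.no_grad():
    out = model.generate(**batch, num_beams=5, max_length=256)

print(tokenizer.batch_decode(out, skip_special_tokens=True))
```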
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "architectures": [
5
+ "IndicTransForConditionalGeneration"
6
+ ],
7
+ "attention_dropout": 0.0,
8
+ "attn_implementation": null,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_indictrans.IndicTransConfig",
11
+ "AutoModelForSeq2SeqLM": "modeling_indictrans.IndicTransForConditionalGeneration"
12
+ },
13
+ "bos_token_id": 0,
14
+ "decoder_attention_heads": 16,
15
+ "decoder_embed_dim": 1024,
16
+ "decoder_ffn_dim": 8192,
17
+ "decoder_layerdrop": 0,
18
+ "decoder_layers": 18,
19
+ "decoder_normalize_before": true,
20
+ "decoder_start_token_id": 2,
21
+ "decoder_vocab_size": 122672,
22
+ "dropout": 0.2,
23
+ "encoder_attention_heads": 16,
24
+ "encoder_embed_dim": 1024,
25
+ "encoder_ffn_dim": 8192,
26
+ "encoder_layerdrop": 0,
27
+ "encoder_layers": 18,
28
+ "encoder_normalize_before": true,
29
+ "encoder_vocab_size": 32322,
30
+ "eos_token_id": 2,
31
+ "init_std": 0.02,
32
+ "is_encoder_decoder": true,
33
+ "layernorm_embedding": false,
34
+ "max_source_positions": 256,
35
+ "max_target_positions": 256,
36
+ "model_type": "IndicTrans",
37
+ "num_hidden_layers": 18,
38
+ "pad_token_id": 1,
39
+ "scale_embedding": true,
40
+ "share_decoder_input_output_embed": false,
41
+ "tokenizer_class": "IndicTransTokenizer",
42
+ "torch_dtype": "float32",
43
+ "transformers_version": "4.53.0",
44
+ "use_cache": true,
45
+ "vocab_size": 122672
46
+ }
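
The `auto_map` block above is what lets the `Auto*` classes resolve the custom code shipped in this repository, which is why `trust_remote_code=True` is required. A small sketch (placeholder repo id) that loads only the configuration and checks a few of the values listed above:

```python
from transformers import AutoConfig

# Placeholder repo id; trust_remote_code=True is needed because auto_map points
# at configuration_indictrans.IndicTransConfig inside this repository.
config = AutoConfig.from_pretrained("Sohan2004/<this-model>", trust_remote_code=True)

print(config.model_type)                                 # IndicTrans
print(config.encoder_layers, config.decoder_layers)      # 18 18
print(config.encoder_embed_dim, config.encoder_ffn_dim)  # 1024 8192
print(config.max_source_positions)                       # 256
```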
configuration_indictrans.py ADDED
@@ -0,0 +1,309 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans config."""
16
+
17
+
18
+ from collections import OrderedDict
19
+ from typing import Any, Mapping, Optional
20
+
21
+ from transformers import PreTrainedTokenizer
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
24
+ from transformers.onnx.utils import compute_effective_axis_dimension
25
+ from transformers.utils import TensorType, is_torch_available
26
+
27
+
28
+ # Copied from transformers.models.m2m_100.configuration_m2m_100.M2M100Config->IndicTrans
29
+ class IndicTransConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`IT2Model`]. It is used to instantiate an
32
+ IT2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
33
+ with the defaults will yield a similar configuration to that of the IT2 model.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 50265):
41
+ Vocabulary size of the IT2 model. Defines the number of different tokens that can be represented by the
42
+ `input_ids` passed when calling [`IT2Model`].
43
+ d_model (`int`, *optional*, defaults to 1024):
44
+ Dimensionality of the layers and the pooler layer.
45
+ encoder_layers (`int`, *optional*, defaults to 12):
46
+ Number of encoder layers.
47
+ decoder_layers (`int`, *optional*, defaults to 12):
48
+ Number of decoder layers.
49
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
52
+ Number of attention heads for each attention layer in the Transformer decoder.
53
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
54
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
55
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
57
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
60
+ dropout (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio for the attention probabilities.
64
+ activation_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for activations inside the fully connected layer.
66
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for classifier.
68
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
69
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
70
+ just in case (e.g., 512 or 1024 or 2048).
71
+ init_std (`float`, *optional*, defaults to 0.02):
72
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
73
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
74
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
75
+ for more details.
76
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
77
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
78
+ for more details.
79
+ use_cache (`bool`, *optional*, defaults to `True`):
80
+ Whether or not the model should return the last key/values attentions (not used by all models).
81
+ """
82
+ model_type = "IndicTrans"
83
+ keys_to_ignore_at_inference = ["past_key_values"]
84
+ attribute_map = {
85
+ "num_attention_heads": "encoder_attention_heads",
86
+ "hidden_size": "d_model",
87
+ }
88
+
89
+ def __init__(
90
+ self,
91
+ encoder_vocab_size=None,
92
+ decoder_vocab_size=None,
93
+ encoder_embed_dim=512,
94
+ decoder_embed_dim=512,
95
+ max_source_positions=210,
96
+ max_target_positions=210,
97
+ encoder_layers=6,
98
+ encoder_ffn_dim=2048,
99
+ encoder_attention_heads=8,
100
+ decoder_layers=6,
101
+ decoder_ffn_dim=2048,
102
+ decoder_attention_heads=8,
103
+ encoder_layerdrop=0.00,
104
+ decoder_layerdrop=0.00,
105
+ use_cache=True,
106
+ is_encoder_decoder=True,
107
+ activation_function="relu",
108
+ encoder_normalize_before=False,
109
+ decoder_normalize_before=False,
110
+ layernorm_embedding=False,
111
+ share_decoder_input_output_embed=False,
112
+ dropout=0.1,
113
+ attention_dropout=0.0,
114
+ activation_dropout=0.0,
115
+ init_std=0.02,
116
+ scale_embedding=True,
117
+ decoder_start_token_id=2,
118
+ pad_token_id=1,
119
+ bos_token_id=0,
120
+ eos_token_id=2,
121
+ attn_implementation="eager",
122
+ **kwargs,
123
+ ):
124
+ self.encoder_vocab_size = encoder_vocab_size
125
+ self.decoder_vocab_size = decoder_vocab_size
126
+ self.encoder_normalize_before = encoder_normalize_before
127
+ self.decoder_normalize_before = decoder_normalize_before
128
+ self.layernorm_embedding = layernorm_embedding
129
+ self.max_source_positions = max_source_positions
130
+ self.max_target_positions = max_target_positions
131
+ self.encoder_embed_dim = encoder_embed_dim
132
+ self.decoder_embed_dim = decoder_embed_dim
133
+ self.encoder_ffn_dim = encoder_ffn_dim
134
+ self.encoder_layers = encoder_layers
135
+ self.encoder_attention_heads = encoder_attention_heads
136
+ self.decoder_ffn_dim = decoder_ffn_dim
137
+ self.decoder_layers = decoder_layers
138
+ self.decoder_attention_heads = decoder_attention_heads
139
+ self.dropout = dropout
140
+ self.attention_dropout = attention_dropout
141
+ self.activation_dropout = activation_dropout
142
+ self.activation_function = activation_function
143
+ self.init_std = init_std
144
+ self.encoder_layerdrop = encoder_layerdrop
145
+ self.decoder_layerdrop = decoder_layerdrop
146
+ self.use_cache = use_cache
147
+ self.num_hidden_layers = encoder_layers
148
+ self.scale_embedding = scale_embedding
149
+ self.share_decoder_input_output_embed = share_decoder_input_output_embed
150
+ self.attn_implementation = attn_implementation
151
+
152
+ super().__init__(
153
+ pad_token_id=pad_token_id,
154
+ bos_token_id=bos_token_id,
155
+ eos_token_id=eos_token_id,
156
+ is_encoder_decoder=is_encoder_decoder,
157
+ decoder_start_token_id=decoder_start_token_id,
158
+ **kwargs,
159
+ )
160
+
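
For reference, a sketch of instantiating the class directly with the non-default values this repository's config.json ships (the checkpoint already carries these, so this is illustrative only):

```python
# Illustrative: mirrors this repo's config.json rather than the small defaults above.
config = IndicTransConfig(
    encoder_vocab_size=32322,
    decoder_vocab_size=122672,
    encoder_embed_dim=1024,
    decoder_embed_dim=1024,
    encoder_layers=18,
    decoder_layers=18,
    encoder_ffn_dim=8192,
    decoder_ffn_dim=8192,
    encoder_attention_heads=16,
    decoder_attention_heads=16,
    encoder_normalize_before=True,
    decoder_normalize_before=True,
    max_source_positions=256,
    max_target_positions=256,
    dropout=0.2,
    activation_function="gelu",
    scale_embedding=True,
)
assert config.num_hidden_layers == config.encoder_layers == 18
```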
161
+
162
+ class IndicTransOnnxConfig(OnnxSeq2SeqConfigWithPast):
163
+ @property
164
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
165
+ common_inputs = OrderedDict(
166
+ [
167
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
168
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
169
+ ]
170
+ )
171
+
172
+ if self.use_past:
173
+ common_inputs["decoder_input_ids"] = {0: "batch"}
174
+ common_inputs["decoder_attention_mask"] = {
175
+ 0: "batch",
176
+ 1: "past_decoder_sequence + sequence",
177
+ }
178
+ else:
179
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
180
+ common_inputs["decoder_attention_mask"] = {
181
+ 0: "batch",
182
+ 1: "decoder_sequence",
183
+ }
184
+
185
+ if self.use_past:
186
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
187
+ return common_inputs
188
+
189
+ # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
190
+ # A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
191
+ # answering are not supported for IT2, but this name is preserved to be able to check that the copy matches what
192
+ # was done for BART so that it can be updated if need be.
193
+ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
194
+ self,
195
+ tokenizer: PreTrainedTokenizer,
196
+ batch_size: int = -1,
197
+ seq_length: int = -1,
198
+ is_pair: bool = False,
199
+ framework: Optional[TensorType] = None,
200
+ ) -> Mapping[str, Any]:
201
+ # Copied from OnnxConfig.generate_dummy_inputs
202
+ # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
203
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
204
+ batch_size = compute_effective_axis_dimension(
205
+ batch_size,
206
+ fixed_dimension=OnnxConfig.default_fixed_batch,
207
+ num_token_to_add=0,
208
+ )
209
+
210
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
211
+ token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
212
+ seq_length = compute_effective_axis_dimension(
213
+ seq_length,
214
+ fixed_dimension=OnnxConfig.default_fixed_sequence,
215
+ num_token_to_add=token_to_add,
216
+ )
217
+
218
+ # Generate dummy inputs according to compute batch and sequence
219
+ dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
220
+ common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
221
+ return common_inputs
222
+
223
+ # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
224
+ def _generate_dummy_inputs_for_default_and_seq2seq_lm(
225
+ self,
226
+ tokenizer: PreTrainedTokenizer,
227
+ batch_size: int = -1,
228
+ seq_length: int = -1,
229
+ is_pair: bool = False,
230
+ framework: Optional[TensorType] = None,
231
+ ) -> Mapping[str, Any]:
232
+ encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
233
+ tokenizer, batch_size, seq_length, is_pair, framework
234
+ )
235
+
236
+ # Generate decoder inputs
237
+ decoder_seq_length = seq_length if not self.use_past else 1
238
+ decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
239
+ tokenizer, batch_size, decoder_seq_length, is_pair, framework
240
+ )
241
+ decoder_inputs = {
242
+ f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()
243
+ }
244
+ common_inputs = dict(**encoder_inputs, **decoder_inputs)
245
+
246
+ if self.use_past:
247
+ if not is_torch_available():
248
+ raise ValueError(
249
+ "Cannot generate dummy past_keys inputs without PyTorch installed."
250
+ )
251
+ else:
252
+ import torch
253
+ batch, encoder_seq_length = common_inputs["input_ids"].shape
254
+ decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
255
+ (
256
+ num_encoder_attention_heads,
257
+ num_decoder_attention_heads,
258
+ ) = self.num_attention_heads
259
+ encoder_shape = (
260
+ batch,
261
+ num_encoder_attention_heads,
262
+ encoder_seq_length,
263
+ self._config.hidden_size // num_encoder_attention_heads,
264
+ )
265
+ decoder_past_length = decoder_seq_length + 3
266
+ decoder_shape = (
267
+ batch,
268
+ num_decoder_attention_heads,
269
+ decoder_past_length,
270
+ self._config.hidden_size // num_decoder_attention_heads,
271
+ )
272
+
273
+ common_inputs["decoder_attention_mask"] = torch.cat(
274
+ [
275
+ common_inputs["decoder_attention_mask"],
276
+ torch.ones(batch, decoder_past_length),
277
+ ],
278
+ dim=1,
279
+ )
280
+
281
+ common_inputs["past_key_values"] = []
282
+ # If the number of encoder and decoder layers are present in the model configuration, both are considered
283
+ num_encoder_layers, num_decoder_layers = self.num_layers
284
+ min_num_layers = min(num_encoder_layers, num_decoder_layers)
285
+ max_num_layers = (
286
+ max(num_encoder_layers, num_decoder_layers) - min_num_layers
287
+ )
288
+ remaining_side_name = (
289
+ "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
290
+ )
291
+
292
+ for _ in range(min_num_layers):
293
+ common_inputs["past_key_values"].append(
294
+ (
295
+ torch.zeros(decoder_shape),
296
+ torch.zeros(decoder_shape),
297
+ torch.zeros(encoder_shape),
298
+ torch.zeros(encoder_shape),
299
+ )
300
+ )
301
+ # TODO: test this.
302
+ shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
303
+ for _ in range(min_num_layers, max_num_layers):
304
+ common_inputs["past_key_values"].append(
305
+ (torch.zeros(shape), torch.zeros(shape))
306
+ )
307
+ return common_inputs
308
+
309
+ generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
generation_config.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "decoder_start_token_id": 2,
5
+ "eos_token_id": 2,
6
+ "pad_token_id": 1,
7
+ "transformers_version": "4.53.0"
8
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:35d28fe035cd6ac026536b555558b07762425c8b930670219063e4fc3666c96d
3
+ size 4462265272
modeling_indictrans.py ADDED
@@ -0,0 +1,1802 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans model."""
16
+
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn import functional as F
24
+
25
+ from transformers.activations import ACT2FN
26
+
27
+ from transformers.modeling_attn_mask_utils import (
28
+ _prepare_4d_attention_mask,
29
+ _prepare_4d_attention_mask_for_sdpa,
30
+ _prepare_4d_causal_attention_mask,
31
+ _prepare_4d_causal_attention_mask_for_sdpa,
32
+ )
33
+
34
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
35
+ from transformers.modeling_outputs import (
36
+ BaseModelOutput,
37
+ BaseModelOutputWithPastAndCrossAttentions,
38
+ Seq2SeqLMOutput,
39
+ Seq2SeqModelOutput
40
+ )
41
+
42
+ from transformers.utils import (
43
+ logging,
44
+ is_flash_attn_2_available,
45
+ is_flash_attn_greater_or_equal_2_10,
46
+ )
47
+
48
+ from transformers.modeling_utils import PreTrainedModel
49
+ from transformers.generation.utils import GenerationMixin
50
+
51
+ from .configuration_indictrans import IndicTransConfig
52
+
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+ INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
57
+
58
+ try:
59
+ if is_flash_attn_2_available():
60
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
61
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
62
+ except ImportError:
63
+ pass
64
+
65
+
66
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
67
+ def _get_unpad_data(attention_mask):
68
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
69
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
70
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
71
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
72
+ return (
73
+ indices,
74
+ cu_seqlens,
75
+ max_seqlen_in_batch,
76
+ )
77
+
78
+
79
+ # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
80
+ def shift_tokens_right(
81
+ input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
82
+ ):
83
+ """
84
+ Shift input ids one token to the right.
85
+ """
86
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
87
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
88
+ shifted_input_ids[:, 0] = decoder_start_token_id
89
+
90
+ if pad_token_id is None:
91
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
92
+ # replace possible -100 values in labels by `pad_token_id`
93
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
94
+
95
+ return shifted_input_ids
96
+
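
A quick illustration of the shift with made-up token ids (using this repo's `pad_token_id=1` and `decoder_start_token_id=2`): the decoder input starts with the decoder start token, drops the last label, and any remaining `-100` is mapped to the pad id.

```python
# Illustration only; assumes shift_tokens_right from this file is in scope.
import torch

labels = torch.tensor([[37, 52, 2, -100]])
print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2))
# tensor([[ 2, 37, 52,  2]])
```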
97
+
98
+ def create_position_ids_from_input_ids(
99
+ input_ids, padding_idx, past_key_values_length=0
100
+ ):
101
+ """
102
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
103
+ are ignored. This is modified from fairseq's `utils.make_positions`.
104
+ """
105
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
106
+ mask = input_ids.ne(padding_idx).int()
107
+ incremental_indices = (
108
+ torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
109
+ ) * mask
110
+ return incremental_indices.long() + padding_idx
111
+
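
A small illustrative check: real tokens get positions starting at `padding_idx + 1`, while padded slots keep the padding index itself.

```python
# Illustration only; assumes create_position_ids_from_input_ids from this file is in scope.
import torch

input_ids = torch.tensor([[5, 6, 7, 1, 1]])  # 1 is the pad token
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1]])
```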
112
+
113
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding->IndicTrans
114
+ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
115
+ """This module produces sinusoidal positional embeddings of any length."""
116
+
117
+ def __init__(
118
+ self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
119
+ ):
120
+ super().__init__()
121
+ self.offset = 2
122
+ self.embedding_dim = embedding_dim
123
+ self.padding_idx = padding_idx
124
+ self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
125
+
126
+ def make_weights(
127
+ self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
128
+ ):
129
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
130
+ if hasattr(self, "weights"):
131
+ # in forward put the weights on the correct dtype and device of the param
132
+ emb_weights = emb_weights.to(
133
+ dtype=self.weights.dtype, device=self.weights.device
134
+ )
135
+
136
+ self.register_buffer("weights", emb_weights, persistent=False)
137
+
138
+ @staticmethod
139
+ def get_embedding(
140
+ num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
141
+ ):
142
+ """
143
+ Build sinusoidal embeddings.
144
+
145
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
146
+ "Attention Is All You Need".
147
+ """
148
+ half_dim = embedding_dim // 2
149
+ emb = math.log(10000) / (half_dim - 1)
150
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
151
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
152
+ 1
153
+ ) * emb.unsqueeze(0)
154
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
155
+ num_embeddings, -1
156
+ )
157
+ if embedding_dim % 2 == 1:
158
+ # zero pad
159
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
160
+ if padding_idx is not None:
161
+ emb[padding_idx, :] = 0
162
+
163
+ return emb.to(torch.get_default_dtype())
164
+
165
+ @torch.no_grad()
166
+ def forward(
167
+ self,
168
+ input_ids: torch.Tensor = None,
169
+ inputs_embeds: torch.Tensor = None,
170
+ past_key_values_length: int = 0,
171
+ ):
172
+ if input_ids is not None:
173
+ bsz, seq_len = input_ids.size()
174
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
175
+ position_ids = create_position_ids_from_input_ids(
176
+ input_ids, self.padding_idx, past_key_values_length
177
+ ).to(input_ids.device)
178
+ else:
179
+ bsz, seq_len = inputs_embeds.size()[:-1]
180
+ position_ids = self.create_position_ids_from_inputs_embeds(
181
+ inputs_embeds, past_key_values_length
182
+ )
183
+
184
+ # expand embeddings if needed
185
+ max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
186
+ if max_pos > self.weights.size(0):
187
+ self.make_weights(
188
+ max_pos + self.offset, self.embedding_dim, self.padding_idx
189
+ )
190
+
191
+ return (
192
+ self.weights.index_select(0, position_ids.view(-1))
193
+ .view(bsz, seq_len, self.weights.shape[-1])
194
+ .detach()
195
+ )
196
+
197
+ def create_position_ids_from_inputs_embeds(
198
+ self, inputs_embeds, past_key_values_length
199
+ ):
200
+ """
201
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
202
+
203
+ Args:
204
+ inputs_embeds: torch.Tensor
205
+
206
+ Returns: torch.Tensor
207
+ """
208
+ input_shape = inputs_embeds.size()[:-1]
209
+ sequence_length = input_shape[1]
210
+
211
+ position_ids = torch.arange(
212
+ self.padding_idx + 1,
213
+ sequence_length + self.padding_idx + 1,
214
+ dtype=torch.long,
215
+ device=inputs_embeds.device,
216
+ )
217
+ return (
218
+ position_ids.unsqueeze(0).expand(input_shape).contiguous()
219
+ + past_key_values_length
220
+ )
221
+
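
An illustrative check of the table built by `get_embedding` (assuming the class above is in scope): one row per position, sin and cos halves concatenated, and the padding row zeroed out.

```python
# Illustration only; assumes IndicTransSinusoidalPositionalEmbedding is in scope.
emb = IndicTransSinusoidalPositionalEmbedding.get_embedding(
    num_embeddings=10, embedding_dim=8, padding_idx=1
)
print(emb.shape)           # torch.Size([10, 8])
print(emb[1].abs().sum())  # tensor(0.) -- the padding position is all zeros
```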
222
+
223
+ # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
224
+ class IndicTransAttention(nn.Module):
225
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
226
+
227
+ def __init__(
228
+ self,
229
+ embed_dim: int,
230
+ num_heads: int,
231
+ dropout: float = 0.0,
232
+ is_decoder: bool = False,
233
+ bias: bool = True,
234
+ is_causal: bool = False,
235
+ config: Optional[IndicTransConfig] = None,
236
+ ):
237
+ super().__init__()
238
+ self.embed_dim = embed_dim
239
+ self.num_heads = num_heads
240
+ self.dropout = dropout
241
+ self.head_dim = embed_dim // num_heads
242
+ self.config = config
243
+
244
+ if (self.head_dim * num_heads) != self.embed_dim:
245
+ raise ValueError(
246
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
247
+ f" and `num_heads`: {num_heads})."
248
+ )
249
+ self.scaling = self.head_dim**-0.5
250
+ self.is_decoder = is_decoder
251
+ self.is_causal = is_causal
252
+
253
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
254
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
255
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
256
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
257
+
258
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
259
+ return (
260
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
261
+ .transpose(1, 2)
262
+ .contiguous()
263
+ )
264
+
265
+ def forward(
266
+ self,
267
+ hidden_states: torch.Tensor,
268
+ key_value_states: Optional[torch.Tensor] = None,
269
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
270
+ attention_mask: Optional[torch.Tensor] = None,
271
+ layer_head_mask: Optional[torch.Tensor] = None,
272
+ output_attentions: bool = False,
273
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
274
+ """Input shape: Batch x Time x Channel"""
275
+
276
+ # if key_value_states are provided this layer is used as a cross-attention layer
277
+ # for the decoder
278
+ is_cross_attention = key_value_states is not None
279
+
280
+ bsz, tgt_len, _ = hidden_states.size()
281
+
282
+ # get query proj
283
+ query_states = self.q_proj(hidden_states) * self.scaling
284
+ # get key, value proj
285
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
286
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
287
+ # the provided `key_value_states` to support prefix tuning
288
+ if (
289
+ is_cross_attention
290
+ and past_key_value is not None
291
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
292
+ ):
293
+ # reuse k,v, cross_attentions
294
+ key_states = past_key_value[0]
295
+ value_states = past_key_value[1]
296
+ elif is_cross_attention:
297
+ # cross_attentions
298
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
299
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
300
+ elif past_key_value is not None:
301
+ # reuse k, v, self_attention
302
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
303
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
304
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
305
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
306
+ else:
307
+ # self_attention
308
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
309
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
310
+
311
+ if self.is_decoder:
312
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
313
+ # Further calls to cross_attention layer can then reuse all cross-attention
314
+ # key/value_states (first "if" case)
315
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
316
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
317
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
318
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
319
+ past_key_value = (key_states, value_states)
320
+
321
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
322
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
323
+ key_states = key_states.reshape(*proj_shape)
324
+ value_states = value_states.reshape(*proj_shape)
325
+
326
+ src_len = key_states.size(1)
327
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
328
+
329
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
330
+ raise ValueError(
331
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
332
+ f" {attn_weights.size()}"
333
+ )
334
+
335
+ if attention_mask is not None:
336
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
337
+ raise ValueError(
338
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
339
+ )
340
+ attn_weights = (
341
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
342
+ + attention_mask
343
+ )
344
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
345
+
346
+ attn_weights = F.softmax(attn_weights, dim=-1)
347
+
348
+ if layer_head_mask is not None:
349
+ if layer_head_mask.size() != (self.num_heads,):
350
+ raise ValueError(
351
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
352
+ f" {layer_head_mask.size()}"
353
+ )
354
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
355
+ bsz, self.num_heads, tgt_len, src_len
356
+ )
357
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
358
+
359
+ if output_attentions:
360
+ # this operation is a bit awkward, but it's required to
361
+ # make sure that attn_weights keeps its gradient.
362
+ # In order to do so, attn_weights have to be reshaped
363
+ # twice and have to be reused in the following
364
+ attn_weights_reshaped = attn_weights.view(
365
+ bsz, self.num_heads, tgt_len, src_len
366
+ )
367
+ attn_weights = attn_weights_reshaped.view(
368
+ bsz * self.num_heads, tgt_len, src_len
369
+ )
370
+ else:
371
+ attn_weights_reshaped = None
372
+
373
+ attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
374
+
375
+ attn_output = torch.bmm(attn_probs, value_states)
376
+
377
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
378
+ raise ValueError(
379
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
380
+ f" {attn_output.size()}"
381
+ )
382
+
383
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
384
+ attn_output = attn_output.transpose(1, 2)
385
+
386
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
387
+ # partitioned across GPUs when using tensor-parallelism.
388
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
389
+
390
+ attn_output = self.out_proj(attn_output)
391
+
392
+ return attn_output, attn_weights_reshaped, past_key_value
393
+
394
+
395
+ class IndicTransFlashAttention2(IndicTransAttention):
396
+ """
397
+ IndicTrans flash attention module. This module inherits from `IndicTransAttention` as the weights of the module stay
398
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
399
+ flash attention and deal with padding tokens in case the input contains any of them.
400
+ """
401
+
402
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
403
+ def __init__(self, *args, **kwargs):
404
+ super().__init__(*args, **kwargs)
405
+
406
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
407
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
408
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
409
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
410
+
411
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
412
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
413
+
414
+ def forward(
415
+ self,
416
+ hidden_states: torch.Tensor,
417
+ key_value_states: Optional[torch.Tensor] = None,
418
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
419
+ attention_mask: Optional[torch.Tensor] = None,
420
+ layer_head_mask: Optional[torch.Tensor] = None,
421
+ output_attentions: bool = False,
422
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
423
+ # IndicTransFlashAttention2 attention does not support output_attentions
424
+ if output_attentions:
425
+ raise ValueError("IndicTransFlashAttention2 attention does not support output_attentions")
426
+
427
+ # if key_value_states are provided this layer is used as a cross-attention layer
428
+ # for the decoder
429
+ is_cross_attention = key_value_states is not None
430
+
431
+ bsz, q_len, _ = hidden_states.size()
432
+
433
+ # get query proj
434
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
435
+ # get key, value proj
436
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
437
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
438
+ # the provided `key_value_states` to support prefix tuning
439
+ if (
440
+ is_cross_attention
441
+ and past_key_value is not None
442
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
443
+ ):
444
+ # reuse k,v, cross_attentions
445
+ key_states = past_key_value[0].transpose(1, 2)
446
+ value_states = past_key_value[1].transpose(1, 2)
447
+ elif is_cross_attention:
448
+ # cross_attentions
449
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
450
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
451
+ elif past_key_value is not None:
452
+ # reuse k, v, self_attention
453
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
454
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
455
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
456
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
457
+ else:
458
+ # self_attention
459
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
460
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
461
+
462
+ if self.is_decoder:
463
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
464
+ # Further calls to cross_attention layer can then reuse all cross-attention
465
+ # key/value_states (first "if" case)
466
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
467
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
468
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
469
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
470
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
471
+
472
+ kv_seq_len = key_states.shape[-2]
473
+ if past_key_value is not None:
474
+ kv_seq_len += past_key_value[0].shape[-2]
475
+
476
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
477
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
478
+ # cast them back to the correct dtype just to be sure everything works as expected.
479
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
480
+ # in fp32. (LlamaRMSNorm handles it correctly)
481
+
482
+ input_dtype = query_states.dtype
483
+ if input_dtype == torch.float32:
484
+ if torch.is_autocast_enabled():
485
+ target_dtype = torch.get_autocast_gpu_dtype()
486
+ # Handle the case where the model is quantized
487
+ elif hasattr(self.config, "_pre_quantization_dtype"):
488
+ target_dtype = self.config._pre_quantization_dtype
489
+ else:
490
+ target_dtype = self.q_proj.weight.dtype
491
+
492
+ logger.warning_once(
493
+ f"The input hidden states seem to have been silently cast to float32; this might be related to"
494
+ f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
495
+ f" {target_dtype}."
496
+ )
497
+
498
+ query_states = query_states.to(target_dtype)
499
+ key_states = key_states.to(target_dtype)
500
+ value_states = value_states.to(target_dtype)
501
+
502
+ attn_output = self._flash_attention_forward(
503
+ query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
504
+ )
505
+
506
+ attn_output = attn_output.reshape(bsz, q_len, -1)
507
+ attn_output = self.out_proj(attn_output)
508
+
509
+ if not output_attentions:
510
+ attn_weights = None
511
+
512
+ return attn_output, attn_weights, past_key_value
513
+
514
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
515
+ def _flash_attention_forward(
516
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
517
+ ):
518
+ """
519
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
520
+ it first unpads the input, computes the attention scores, and then pads the output back to the original length.
521
+
522
+ Args:
523
+ query_states (`torch.Tensor`):
524
+ Input query states to be passed to Flash Attention API
525
+ key_states (`torch.Tensor`):
526
+ Input key states to be passed to Flash Attention API
527
+ value_states (`torch.Tensor`):
528
+ Input value states to be passed to Flash Attention API
529
+ attention_mask (`torch.Tensor`):
530
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
531
+ position of padding tokens and 1 for the position of non-padding tokens.
532
+ dropout (`float`):
533
+ Attention dropout
534
+ softmax_scale (`float`, *optional*):
535
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
536
+ """
537
+ if not self._flash_attn_uses_top_left_mask:
538
+ causal = self.is_causal
539
+ else:
540
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
541
+ causal = self.is_causal and query_length != 1
542
+
543
+ # Contains at least one padding token in the sequence
544
+ if attention_mask is not None:
545
+ batch_size = query_states.shape[0]
546
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
547
+ query_states, key_states, value_states, attention_mask, query_length
548
+ )
549
+
550
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
551
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
552
+
553
+ attn_output_unpad = flash_attn_varlen_func(
554
+ query_states,
555
+ key_states,
556
+ value_states,
557
+ cu_seqlens_q=cu_seqlens_q,
558
+ cu_seqlens_k=cu_seqlens_k,
559
+ max_seqlen_q=max_seqlen_in_batch_q,
560
+ max_seqlen_k=max_seqlen_in_batch_k,
561
+ dropout_p=dropout,
562
+ softmax_scale=softmax_scale,
563
+ causal=causal,
564
+ )
565
+
566
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
567
+ else:
568
+ attn_output = flash_attn_func(
569
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
570
+ )
571
+
572
+ return attn_output
573
+
574
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
575
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
576
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
577
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
578
+
579
+ key_layer = index_first_axis(
580
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
581
+ )
582
+ value_layer = index_first_axis(
583
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
584
+ )
585
+ if query_length == kv_seq_len:
586
+ query_layer = index_first_axis(
587
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
588
+ )
589
+ cu_seqlens_q = cu_seqlens_k
590
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
591
+ indices_q = indices_k
592
+ elif query_length == 1:
593
+ max_seqlen_in_batch_q = 1
594
+ cu_seqlens_q = torch.arange(
595
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
596
+ ) # There is a memcpy here, that is very bad.
597
+ indices_q = cu_seqlens_q[:-1]
598
+ query_layer = query_layer.squeeze(1)
599
+ else:
600
+ # The -q_len: slice assumes left padding.
601
+ attention_mask = attention_mask[:, -query_length:]
602
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
603
+
604
+ return (
605
+ query_layer,
606
+ key_layer,
607
+ value_layer,
608
+ indices_q,
609
+ (cu_seqlens_q, cu_seqlens_k),
610
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
611
+ )
612
+
613
+
614
+ class IndicTransSdpaAttention(IndicTransAttention):
615
+ def forward(
616
+ self,
617
+ hidden_states: torch.Tensor,
618
+ key_value_states: Optional[torch.Tensor] = None,
619
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
620
+ attention_mask: Optional[torch.Tensor] = None,
621
+ layer_head_mask: Optional[torch.Tensor] = None,
622
+ output_attentions: bool = False,
623
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
624
+ """Input shape: Batch x Time x Channel"""
625
+ if output_attentions or layer_head_mask is not None:
626
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
627
+ logger.warning_once(
628
+ "IndicTransModel is using IndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
629
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
630
+ )
631
+ return super().forward(
632
+ hidden_states,
633
+ key_value_states=key_value_states,
634
+ past_key_value=past_key_value,
635
+ attention_mask=attention_mask,
636
+ layer_head_mask=layer_head_mask,
637
+ output_attentions=output_attentions,
638
+ )
639
+
640
+ # if key_value_states are provided this layer is used as a cross-attention layer
641
+ # for the decoder
642
+ is_cross_attention = key_value_states is not None
643
+
644
+ bsz, tgt_len, _ = hidden_states.size()
645
+
646
+ # get query proj
647
+ query_states = self.q_proj(hidden_states)
648
+ # get key, value proj
649
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
650
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
651
+ # the provided `key_value_states` to support prefix tuning
652
+ if (
653
+ is_cross_attention
654
+ and past_key_value is not None
655
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
656
+ ):
657
+ # reuse k,v, cross_attentions
658
+ key_states = past_key_value[0]
659
+ value_states = past_key_value[1]
660
+ elif is_cross_attention:
661
+ # cross_attentions
662
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
663
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
664
+ elif past_key_value is not None:
665
+ # reuse k, v, self_attention
666
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
667
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
668
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
669
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
670
+ else:
671
+ # self_attention
672
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
673
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
674
+
675
+ if self.is_decoder:
676
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
677
+ # Further calls to cross_attention layer can then reuse all cross-attention
678
+ # key/value_states (first "if" case)
679
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
680
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
681
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
682
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
683
+ past_key_value = (key_states, value_states)
684
+
685
+ query_states = self._shape(query_states, tgt_len, bsz)
686
+
687
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
688
+ # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
689
+ attn_output = F.scaled_dot_product_attention(
690
+ query_states,
691
+ key_states,
692
+ value_states,
693
+ attn_mask=attention_mask,
694
+ dropout_p=self.dropout if self.training else 0.0,
695
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
696
+ is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
697
+ )
698
+
699
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
700
+ raise ValueError(
701
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
702
+ f" {attn_output.size()}"
703
+ )
704
+
705
+ attn_output = attn_output.transpose(1, 2)
706
+
707
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
708
+ # partitioned across GPUs when using tensor-parallelism.
709
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
710
+
711
+ attn_output = self.out_proj(attn_output)
712
+
713
+ return attn_output, None, past_key_value
714
+
715
+
+ INDICTRANS_ATTENTION_CLASSES = {
+     "eager": IndicTransAttention,
+     "sdpa": IndicTransSdpaAttention,
+     "flash_attention_2": IndicTransFlashAttention2,
+ }
+
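The key looked up in `INDICTRANS_ATTENTION_CLASSES` comes from `config._attn_implementation`, which recent `transformers` versions set from the `attn_implementation` argument of `from_pretrained`. A hedged usage sketch (the repository id below is a placeholder):

```python
from transformers import AutoModelForSeq2SeqLM

# "eager" and "sdpa" need no extra dependencies; "flash_attention_2" additionally
# requires the flash-attn package and a supported GPU.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "your-org/your-indictrans-checkpoint",  # placeholder id
    trust_remote_code=True,  # this modeling file ships with the checkpoint
    attn_implementation="sdpa",
)
```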
+ # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
+ class IndicTransEncoderLayer(nn.Module):
+     def __init__(self, config: IndicTransConfig):
+         super().__init__()
+         self.embed_dim = config.encoder_embed_dim
+         self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
+             embed_dim=self.embed_dim,
+             num_heads=config.encoder_attention_heads,
+             dropout=config.attention_dropout,
+             config=config,
+         )
+         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.dropout = config.dropout
+         self.activation_fn = ACT2FN[config.activation_function]
+         self.activation_dropout = config.activation_dropout
+         self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+         self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+         self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.normalize_before = config.encoder_normalize_before
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         layer_head_mask: torch.Tensor,
+         output_attentions: bool = False,
+     ) -> torch.Tensor:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`): attention mask of size
+                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                 `(encoder_attention_heads,)`.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+         """
+         residual = hidden_states
+         if self.normalize_before:
+             hidden_states = self.self_attn_layer_norm(hidden_states)
+         hidden_states, attn_weights, _ = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             layer_head_mask=layer_head_mask,
+             output_attentions=output_attentions,
+         )
+         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+         hidden_states = residual + hidden_states
+         if not self.normalize_before:
+             hidden_states = self.self_attn_layer_norm(hidden_states)
+
+         residual = hidden_states
+         if self.normalize_before:
+             hidden_states = self.final_layer_norm(hidden_states)
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = F.dropout(
+             hidden_states, p=self.activation_dropout, training=self.training
+         )
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+         hidden_states = residual + hidden_states
+         if not self.normalize_before:
+             hidden_states = self.final_layer_norm(hidden_states)
+
+         if hidden_states.dtype == torch.float16 and (
+             torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+         ):
+             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+             hidden_states = torch.clamp(
+                 hidden_states, min=-clamp_value, max=clamp_value
+             )
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (attn_weights,)
+
+         return outputs
+
+
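Both sub-blocks above (self-attention, then feed-forward) apply the same residual pattern, with `normalize_before` switching the LayerNorm between pre-norm and post-norm placement. A minimal sketch of the two orderings (a hypothetical helper, not part of the uploaded file):

```python
import torch.nn as nn

def residual_sublayer(x, sublayer, norm: nn.LayerNorm, normalize_before: bool):
    """Illustrative helper mirroring the pattern used in the layer above."""
    residual = x
    if normalize_before:      # pre-norm: LN -> sublayer -> residual add
        x = norm(x)
    x = sublayer(x)
    x = residual + x
    if not normalize_before:  # post-norm: sublayer -> residual add -> LN
        x = norm(x)
    return x
```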
+ # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->IndicTrans
+ class IndicTransDecoderLayer(nn.Module):
+     def __init__(self, config: IndicTransConfig):
+         super().__init__()
+         self.embed_dim = config.decoder_embed_dim
+
+         self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
+             embed_dim=self.embed_dim,
+             num_heads=config.decoder_attention_heads,
+             dropout=config.attention_dropout,
+             is_decoder=True,
+             is_causal=True,
+             config=config,
+         )
+         self.dropout = config.dropout
+         self.activation_fn = ACT2FN[config.activation_function]
+         self.activation_dropout = config.activation_dropout
+
+         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.encoder_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
+             self.embed_dim,
+             config.decoder_attention_heads,
+             dropout=config.attention_dropout,
+             is_decoder=True,
+             config=config,
+         )
+         self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+         self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+         self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.normalize_before = config.decoder_normalize_before
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         layer_head_mask: Optional[torch.Tensor] = None,
+         cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = True,
+     ) -> torch.Tensor:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`): attention mask of size
+                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+             encoder_hidden_states (`torch.FloatTensor`):
+                 cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+             encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                 `(decoder_attention_heads,)`.
+             cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                 size `(decoder_attention_heads,)`.
+             past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+         """
+         residual = hidden_states
+         if self.normalize_before:
+             hidden_states = self.self_attn_layer_norm(hidden_states)
+
+         # Self Attention
+         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+         self_attn_past_key_value = (
+             past_key_value[:2] if past_key_value is not None else None
+         )
+         # add present self-attn cache to positions 1,2 of present_key_value tuple
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             past_key_value=self_attn_past_key_value,
+             attention_mask=attention_mask,
+             layer_head_mask=layer_head_mask,
+             output_attentions=output_attentions,
+         )
+         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+         hidden_states = residual + hidden_states
+         if not self.normalize_before:
+             hidden_states = self.self_attn_layer_norm(hidden_states)
+
+         # Cross-Attention Block
+         cross_attn_present_key_value = None
+         cross_attn_weights = None
+         if encoder_hidden_states is not None:
+             residual = hidden_states
+             if self.normalize_before:
+                 hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+             # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+             cross_attn_past_key_value = (
+                 past_key_value[-2:] if past_key_value is not None else None
+             )
+             (
+                 hidden_states,
+                 cross_attn_weights,
+                 cross_attn_present_key_value,
+             ) = self.encoder_attn(
+                 hidden_states=hidden_states,
+                 key_value_states=encoder_hidden_states,
+                 attention_mask=encoder_attention_mask,
+                 layer_head_mask=cross_attn_layer_head_mask,
+                 past_key_value=cross_attn_past_key_value,
+                 output_attentions=output_attentions,
+             )
+             hidden_states = F.dropout(
+                 hidden_states, p=self.dropout, training=self.training
+             )
+             hidden_states = residual + hidden_states
+             if not self.normalize_before:
+                 hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+             # add cross-attn to positions 3,4 of present_key_value tuple
+             present_key_value = present_key_value + cross_attn_present_key_value
+
+         # Fully Connected
+         residual = hidden_states
+         if self.normalize_before:
+             hidden_states = self.final_layer_norm(hidden_states)
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = F.dropout(
+             hidden_states, p=self.activation_dropout, training=self.training
+         )
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+         hidden_states = residual + hidden_states
+         if not self.normalize_before:
+             hidden_states = self.final_layer_norm(hidden_states)
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights, cross_attn_weights)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         return outputs
+
+
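The cache handling above fixes a four-tensor layout per decoder layer: the self-attention key/value pair first (it grows by one position per decoded token), the cross-attention pair last (computed once from the encoder states, then reused). A small illustrative snippet with made-up shapes (2 sentences, 8 heads, head_dim 64, 3 tokens decoded so far, source length 17):

```python
import torch

bsz, heads, head_dim = 2, 8, 64
self_k = torch.zeros(bsz, heads, 3, head_dim)    # grows along dim=2 each step
self_v = torch.zeros(bsz, heads, 3, head_dim)
cross_k = torch.zeros(bsz, heads, 17, head_dim)  # fixed: encoder sequence length
cross_v = torch.zeros(bsz, heads, 17, head_dim)

past_key_value = (self_k, self_v, cross_k, cross_v)
self_attn_past = past_key_value[:2]    # consumed by `self_attn`
cross_attn_past = past_key_value[-2:]  # consumed by `encoder_attn`
```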
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PreTrainedModel with M2M100->IndicTrans
+ class IndicTransPreTrainedModel(PreTrainedModel):
+     config_class = IndicTransConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["IndicTransAttention"]
+
+     def _init_weights(self, module):
+         std = self.config.init_std
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, (IndicTransDecoder, IndicTransEncoder)):
+             module.gradient_checkpointing = value
+
+
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Encoder with M2M100->IndicTrans
+ class IndicTransEncoder(IndicTransPreTrainedModel):
+     """
+     Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
+     [`IndicTransEncoderLayer`].
+
+     Args:
+         config: IndicTransConfig
+         embed_tokens (nn.Embedding): input embedding
+     """
+
+     def __init__(
+         self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
+     ):
+         super().__init__(config)
+
+         self.dropout = config.dropout
+         self.layerdrop = config.encoder_layerdrop
+
+         embed_dim = config.encoder_embed_dim
+         self.padding_idx = config.pad_token_id
+         self.max_source_positions = config.max_source_positions
+         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+         self.embed_tokens = nn.Embedding(
+             config.encoder_vocab_size, embed_dim, self.padding_idx
+         )
+
+         if embed_tokens is not None:
+             self.embed_tokens.weight = embed_tokens.weight
+
+         self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
+             config.max_source_positions,
+             embed_dim,
+             self.padding_idx,
+         )
+         self.layers = nn.ModuleList(
+             [IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]
+         )
+         self.layer_norm = (
+             nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
+         )
+         self.layernorm_embedding = (
+             nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
+         )
+
+         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+         self._use_sdpa = config._attn_implementation == "sdpa"
+
+         self.gradient_checkpointing = False
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         r"""
+         Args:
+             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                 provide it.
+
+                 Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                 [`PreTrainedTokenizer.__call__`] for details.
+
+                 [What are input IDs?](../glossary#input-ids)
+             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+
+                 [What are attention masks?](../glossary#attention-mask)
+             head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                 Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked**.
+
+             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                 This is useful if you want more control over how to convert `input_ids` indices into associated
+                 vectors than the model's internal embedding lookup matrix.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                 for more detail.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         """
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time"
+             )
+         elif input_ids is not None:
+             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+             input_shape = input_ids.size()
+             input_ids = input_ids.view(-1, input_shape[-1])
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+         embed_pos = self.embed_positions(input_ids, inputs_embeds)
+         embed_pos = embed_pos.to(inputs_embeds.device)
+
+         hidden_states = inputs_embeds + embed_pos
+         if self.layernorm_embedding is not None:
+             hidden_states = self.layernorm_embedding(hidden_states)
+         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+
+         if attention_mask is not None:
+             if self._use_flash_attention_2:
+                 attention_mask = attention_mask if 0 in attention_mask else None
+             elif self._use_sdpa and head_mask is None and not output_attentions:
+                 # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+                 # the manual implementation that requires a 4D causal mask in all cases.
+                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                 attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+             else:
+                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                 attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+         encoder_states = () if output_hidden_states else None
+         all_attentions = () if output_attentions else None
+
+         # check if head_mask has a correct number of layers specified if desired
+         if head_mask is not None:
+             if head_mask.size()[0] != len(self.layers):
+                 raise ValueError(
+                     f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+                     f" {head_mask.size()[0]}."
+                 )
+         deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
+         for idx, encoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 encoder_states = encoder_states + (hidden_states,)
+
+             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+             dropout_probability = torch.rand([])
+
+             skip_the_layer = self.training and (dropout_probability < self.layerdrop)
+             if not skip_the_layer or deepspeed_zero3_is_enabled:
+                 # under deepspeed zero3 all gpus must run in sync
+
+                 if self.gradient_checkpointing and self.training:
+                     # create gradient checkpointing function
+                     def create_custom_forward(module):
+                         def custom_forward(*inputs):
+                             return module(*inputs, output_attentions)
+
+                         return custom_forward
+
+                     layer_outputs = torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(encoder_layer),
+                         hidden_states,
+                         attention_mask,
+                         (head_mask[idx] if head_mask is not None else None),
+                     )
+                 else:
+                     layer_outputs = encoder_layer(
+                         hidden_states,
+                         attention_mask,
+                         layer_head_mask=(
+                             head_mask[idx] if head_mask is not None else None
+                         ),
+                         output_attentions=output_attentions,
+                     )
+
+                 hidden_states = layer_outputs[0]
+
+             if skip_the_layer:
+                 layer_outputs = (None, None)
+
+             if output_attentions:
+                 all_attentions = all_attentions + (layer_outputs[1],)
+
+         if self.layer_norm is not None:
+             hidden_states = self.layer_norm(hidden_states)
+
+         if output_hidden_states:
+             encoder_states = encoder_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, encoder_states, all_attentions]
+                 if v is not None
+             )
+         return BaseModelOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=encoder_states,
+             attentions=all_attentions,
+         )
+
+
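The layer loop above implements LayerDrop (https://arxiv.org/abs/1909.11556): during training, each layer is skipped for an entire forward pass with probability `layerdrop`, except under DeepSpeed ZeRO-3, where all ranks must run every layer in sync. A toy illustration of the sampling:

```python
import torch

torch.manual_seed(0)
layerdrop, training, num_layers = 0.1, True, 6
executed = [
    idx for idx in range(num_layers)
    if not (training and torch.rand([]) < layerdrop)
]
print(executed)  # indices of encoder layers actually run on this pass
```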
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Decoder with M2M100->IndicTrans
+ class IndicTransDecoder(IndicTransPreTrainedModel):
+     """
+     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`IndicTransDecoderLayer`].
+
+     Args:
+         config: IndicTransConfig
+         embed_tokens (nn.Embedding): input embedding
+     """
+
+     def __init__(
+         self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
+     ):
+         super().__init__(config)
+         self.dropout = config.dropout
+         self.layerdrop = config.decoder_layerdrop
+
+         embed_dim = config.decoder_embed_dim
+         self.padding_idx = config.pad_token_id
+         self.max_target_positions = config.max_target_positions
+         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+         self.embed_tokens = nn.Embedding(
+             config.decoder_vocab_size, embed_dim, self.padding_idx
+         )
+
+         if embed_tokens is not None:
+             self.embed_tokens.weight = embed_tokens.weight
+
+         self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
+             config.max_target_positions,
+             embed_dim,
+             self.padding_idx,
+         )
+         self.layers = nn.ModuleList(
+             [IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]
+         )
+         self.layer_norm = (
+             nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
+         )
+         self.layernorm_embedding = (
+             nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
+         )
+
+         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+         self._use_sdpa = config._attn_implementation == "sdpa"
+
+         self.gradient_checkpointing = False
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         r"""
+         Args:
+             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                 provide it.
+
+                 Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                 [`PreTrainedTokenizer.__call__`] for details.
+
+                 [What are input IDs?](../glossary#input-ids)
+             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+
+                 [What are attention masks?](../glossary#attention-mask)
+             encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+                 Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                 of the decoder.
+             encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+                 Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+                 selected in `[0, 1]`:
+
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+
+                 [What are attention masks?](../glossary#attention-mask)
+             head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                 Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked**.
+
+             cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                 Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+                 cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked**.
+
+             past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                 Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                 shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+                 Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                 cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                 If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                 that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                 all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                 This is useful if you want more control over how to convert `input_ids` indices into associated
+                 vectors than the model's internal embedding lookup matrix.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                 for more detail.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         """
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+             )
+         elif input_ids is not None:
+             input_shape = input_ids.size()
+             input_ids = input_ids.view(-1, input_shape[-1])
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError(
+                 "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+             )
+
+         # past_key_values_length
+         past_key_values_length = (
+             past_key_values[0][0].shape[2] if past_key_values is not None else 0
+         )
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+         if self._use_flash_attention_2:
+             # 2d mask is passed through the layers
+             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+         elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
+             # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+             # the manual implementation that requires a 4D causal mask in all cases.
+             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                 attention_mask,
+                 input_shape,
+                 inputs_embeds,
+                 past_key_values_length,
+             )
+         else:
+             # 4d mask is passed through the layers
+             attention_mask = _prepare_4d_causal_attention_mask(
+                 attention_mask, input_shape, inputs_embeds, past_key_values_length
+             )
+
+         # expand encoder attention mask
+         if encoder_hidden_states is not None and encoder_attention_mask is not None:
+             if self._use_flash_attention_2:
+                 encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+             elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
+                 # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+                 # the manual implementation that requires a 4D causal mask in all cases.
+                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                 encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                     encoder_attention_mask,
+                     inputs_embeds.dtype,
+                     tgt_len=input_shape[-1],
+                 )
+             else:
+                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                 encoder_attention_mask = _prepare_4d_attention_mask(
+                     encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                 )
+
+         # embed positions
+         positions = self.embed_positions(
+             input_ids, inputs_embeds, past_key_values_length
+         )
+         positions = positions.to(inputs_embeds.device)
+
+         hidden_states = inputs_embeds + positions
+         if self.layernorm_embedding is not None:
+             hidden_states = self.layernorm_embedding(hidden_states)
+
+         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting"
+                     " `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         all_cross_attentions = () if output_attentions else None
+         next_decoder_cache = () if use_cache else None
+
+         # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+         for attn_mask, mask_name in zip(
+             [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
+         ):
+             if attn_mask is not None:
+                 if attn_mask.size()[0] != len(self.layers):
+                     raise ValueError(
+                         f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                         f" {attn_mask.size()[0]}."
+                     )
+         deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
+         for idx, decoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+             dropout_probability = torch.rand([])
+
+             skip_the_layer = self.training and (dropout_probability < self.layerdrop)
+             if not skip_the_layer or deepspeed_zero3_is_enabled:
+                 # under deepspeed zero3 all gpus must run in sync
+
+                 past_key_value = (
+                     past_key_values[idx] if past_key_values is not None else None
+                 )
+
+                 if self.gradient_checkpointing and self.training:
+
+                     def create_custom_forward(module):
+                         def custom_forward(*inputs):
+                             # None for past_key_value
+                             return module(*inputs, output_attentions, use_cache)
+
+                         return custom_forward
+
+                     layer_outputs = torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(decoder_layer),
+                         hidden_states,
+                         attention_mask,
+                         encoder_hidden_states,
+                         encoder_attention_mask,
+                         head_mask[idx] if head_mask is not None else None,
+                         cross_attn_head_mask[idx]
+                         if cross_attn_head_mask is not None
+                         else None,
+                         None,
+                     )
+                 else:
+                     layer_outputs = decoder_layer(
+                         hidden_states,
+                         attention_mask=attention_mask,
+                         encoder_hidden_states=encoder_hidden_states,
+                         encoder_attention_mask=encoder_attention_mask,
+                         layer_head_mask=(
+                             head_mask[idx] if head_mask is not None else None
+                         ),
+                         cross_attn_layer_head_mask=(
+                             cross_attn_head_mask[idx]
+                             if cross_attn_head_mask is not None
+                             else None
+                         ),
+                         past_key_value=past_key_value,
+                         output_attentions=output_attentions,
+                         use_cache=use_cache,
+                     )
+
+                 hidden_states = layer_outputs[0]
+
+             if skip_the_layer:
+                 continue
+
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+                 all_cross_attentions += (layer_outputs[2],)
+
+         if self.layer_norm is not None:
+             hidden_states = self.layer_norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     hidden_states,
+                     next_cache,
+                     all_hidden_states,
+                     all_self_attns,
+                     all_cross_attentions,
+                 ]
+                 if v is not None
+             )
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+             cross_attentions=all_cross_attentions,
+         )
+
+
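For intuition, the 4D additive masks prepared above combine a causal (lower-triangular) constraint with the 2D padding mask, setting disallowed positions to a large negative value so they vanish after softmax. A minimal illustration of the idea (not the library helpers themselves):

```python
import torch

bsz, seq_len = 1, 4
pad_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
allowed = causal & pad_mask.bool()[:, None, :]  # (bsz, tgt_len, src_len)
mask_4d = torch.where(
    allowed[:, None],  # (bsz, 1, tgt_len, src_len)
    torch.tensor(0.0),
    torch.tensor(torch.finfo(torch.float32).min),
)  # added to attention scores before softmax
print(mask_4d.shape)  # torch.Size([1, 1, 4, 4])
```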
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model with M2M100->IndicTrans
+ class IndicTransModel(IndicTransPreTrainedModel):
+     _tied_weights_keys = None
+
+     def __init__(self, config: IndicTransConfig):
+         super().__init__(config)
+
+         self.encoder = IndicTransEncoder(config)
+         self.decoder = IndicTransDecoder(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_encoder(self):
+         return self.encoder
+
+     def get_decoder(self):
+         return self.decoder
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         decoder_head_mask: Optional[torch.Tensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         if encoder_outputs is None:
+             encoder_outputs = self.encoder(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 head_mask=head_mask,
+                 inputs_embeds=inputs_embeds,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+         # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+         elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+             encoder_outputs = BaseModelOutput(
+                 last_hidden_state=encoder_outputs[0],
+                 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+             )
+
+         # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+         decoder_outputs = self.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             encoder_hidden_states=encoder_outputs[0],
+             encoder_attention_mask=attention_mask,
+             head_mask=decoder_head_mask,
+             cross_attn_head_mask=cross_attn_head_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=decoder_inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if not return_dict:
+             return decoder_outputs + encoder_outputs
+
+         return Seq2SeqModelOutput(
+             last_hidden_state=decoder_outputs.last_hidden_state,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+             encoder_hidden_states=encoder_outputs.hidden_states,
+             encoder_attentions=encoder_outputs.attentions,
+         )
+
+
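Because `forward` accepts precomputed `encoder_outputs`, the encoder can be run once and its states reused across decoder passes (this is also what `generate` does internally). A hedged sketch; the repository id is a placeholder, real IndicTrans checkpoints additionally expect language-tag preprocessing documented by the checkpoint:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

ckpt = "your-org/your-indictrans-checkpoint"  # placeholder id
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, trust_remote_code=True).eval()

batch = tokenizer("This is a test sentence.", return_tensors="pt")
with torch.no_grad():
    # Encode once; the returned BaseModelOutput can be fed back through
    # `encoder_outputs` to avoid re-running the encoder.
    enc = model.get_encoder()(**batch)
    dec_start = torch.tensor([[model.config.decoder_start_token_id]])
    out = model.model(
        encoder_outputs=enc,
        attention_mask=batch["attention_mask"],
        decoder_input_ids=dec_start,
    )
print(out.last_hidden_state.shape)  # (1, 1, decoder_embed_dim)
```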
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration with M2M100->IndicTrans
+ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMixin):
+     base_model_prefix = "model"
+     _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
+     _label_smoothing = 0.0
+
+     def __init__(self, config: IndicTransConfig):
+         super().__init__(config)
+         self.model = IndicTransModel(config)
+         self.lm_head = nn.Linear(
+             config.decoder_embed_dim, config.decoder_vocab_size, bias=False
+         )
+
+         self.post_init()
+
+     def tie_weights(self):
+         if self.config.share_decoder_input_output_embed:
+             self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.lm_head)
+
+     def get_encoder(self):
+         return self.model.encoder
+
+     def get_decoder(self):
+         return self.model.decoder
+
+     def get_input_embeddings(self):
+         return self.model.encoder.embed_tokens
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_label_smoothing(self, label_smoothing):
+         self._label_smoothing = label_smoothing
+
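`set_label_smoothing` feeds straight into the `label_smoothing` argument of `F.cross_entropy` in `forward` below. A brief hedged sketch of a training step using it; `model`, `batch`, and `optimizer` are assumed to exist:

```python
# Labels use -100 at padding positions so they are ignored by the loss.
model.set_label_smoothing(0.1)  # stored in `_label_smoothing`, read by forward()
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["labels"],  # decoder_input_ids are derived via shift_tokens_right
)
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
```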
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         decoder_head_mask: Optional[torch.Tensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.decoder_vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
+             ignored (masked), the loss is only computed for the tokens with labels in
+             `[0, ..., config.decoder_vocab_size]`.
+
+         Returns:
+         """
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         if labels is not None:
+             if decoder_input_ids is None:
+                 decoder_input_ids = shift_tokens_right(
+                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                 )
+
+         outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             decoder_input_ids=decoder_input_ids,
+             encoder_outputs=encoder_outputs,
+             decoder_attention_mask=decoder_attention_mask,
+             head_mask=head_mask,
+             decoder_head_mask=decoder_head_mask,
+             cross_attn_head_mask=cross_attn_head_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             decoder_inputs_embeds=decoder_inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         lm_logits = self.lm_head(outputs[0])
+
+         masked_lm_loss = None
+         if labels is not None:
+             # move labels to the correct device to enable PP
+             labels = labels.to(lm_logits.device)
+             masked_lm_loss = F.cross_entropy(
+                 input=lm_logits.view(-1, self.config.decoder_vocab_size),
+                 target=labels.view(-1),
+                 ignore_index=-100,
+                 label_smoothing=self._label_smoothing,
+             )
+
+         if not return_dict:
+             output = (lm_logits,) + outputs[1:]
+             return (
+                 ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+             )
+
+         return Seq2SeqLMOutput(
+             loss=masked_lm_loss,
+             logits=lm_logits,
+             past_key_values=outputs.past_key_values,
+             decoder_hidden_states=outputs.decoder_hidden_states,
+             decoder_attentions=outputs.decoder_attentions,
+             cross_attentions=outputs.cross_attentions,
+             encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+             encoder_hidden_states=outputs.encoder_hidden_states,
+             encoder_attentions=outputs.encoder_attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         decoder_input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         head_mask=None,
+         decoder_head_mask=None,
+         cross_attn_head_mask=None,
+         use_cache=None,
+         encoder_outputs=None,
+         **kwargs,
+     ):
+         # cut decoder_input_ids if past is used
+         if past_key_values is not None:
+             decoder_input_ids = decoder_input_ids[:, -1:]
+
+         return {
+             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+             "encoder_outputs": encoder_outputs,
+             "past_key_values": past_key_values,
+             "decoder_input_ids": decoder_input_ids,
+             "attention_mask": attention_mask,
+             "head_mask": head_mask,
+             "decoder_head_mask": decoder_head_mask,
+             "cross_attn_head_mask": cross_attn_head_mask,
+             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+         }
+
+     @staticmethod
+     def _reorder_cache(past_key_values, beam_idx):
+         reordered_past = ()
+         for layer_past in past_key_values:
+             reordered_past += (
+                 tuple(
+                     past_state.index_select(0, beam_idx) for past_state in layer_past
+                 ),
+             )
+         return reordered_past
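Putting it together, a hedged end-to-end sketch of loading a checkpoint that ships this modeling file and generating a translation. The repository id is a placeholder, and the plain input sentence stands in for the language-tag preprocessing that real IndicTrans checkpoints document; beam search exercises `_reorder_cache`, and incremental decoding goes through `prepare_inputs_for_generation`:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

ckpt = "your-org/your-indictrans-checkpoint"  # placeholder repository id
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, trust_remote_code=True).eval()

# Input preprocessing (language tags, transliteration) depends on the checkpoint;
# the sentence below is only a placeholder.
batch = tokenizer("This is a test sentence.", return_tensors="pt")

with torch.no_grad():
    generated = model.generate(
        **batch,
        num_beams=5,      # beam search reorders the cache via _reorder_cache
        max_new_tokens=64,
        use_cache=True,   # per-step decoding via prepare_inputs_for_generation
    )
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```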