aarath97 committed on
Commit 6bcbd0d · verified · 1 Parent(s): 91053ac

Upload 16 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model.SRC filter=lfs diff=lfs merge=lfs -text
+ model.TGT filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) AI4Bharat.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE
README.md ADDED
@@ -0,0 +1,154 @@
+ ---
+ language:
+ - as
+ - bn
+ - brx
+ - doi
+ - en
+ - gom
+ - gu
+ - hi
+ - kn
+ - ks
+ - kas
+ - mai
+ - ml
+ - mr
+ - mni
+ - mnb
+ - ne
+ - or
+ - pa
+ - sa
+ - sat
+ - sd
+ - snd
+ - ta
+ - te
+ - ur
+ language_details: >-
+   asm_Beng, ben_Beng, brx_Deva, doi_Deva, eng_Latn, gom_Deva, guj_Gujr,
+   hin_Deva, kan_Knda, kas_Arab, kas_Deva, mai_Deva, mal_Mlym, mar_Deva,
+   mni_Beng, mni_Mtei, npi_Deva, ory_Orya, pan_Guru, san_Deva, sat_Olck,
+   snd_Arab, snd_Deva, tam_Taml, tel_Telu, urd_Arab
+ tags:
+ - indictrans2
+ - translation
+ - ai4bharat
+ - multilingual
+ license: mit
+ datasets:
+ - flores-200
+ - IN22-Gen
+ - IN22-Conv
+ metrics:
+ - bleu
+ - chrf
+ - chrf++
+ - comet
+ inference: false
+ ---
+
+ # IndicTrans2
+
+ This is the model card of the IndicTrans2 Indic-En 1.1B variant.
+
+ Here are the [metrics](https://drive.google.com/drive/folders/1lOOdaU0VdRSBgJEsNav5zC7wwLBis9NI?usp=sharing) for this particular checkpoint.
+
+ Please refer to `Appendix D: Model Card` of the [preprint](https://arxiv.org/abs/2305.16307) for further details on model training, intended use, data, metrics, limitations and recommendations.
+
+
+ ### Usage Instructions
+
+ Please refer to the [github repository](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_interface) for a detailed description of how to use HF-compatible IndicTrans2 models for inference.
+
+ ```python
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from IndicTransToolkit.processor import IndicProcessor
+ # recommended to run this on a GPU with flash_attn installed
+ # don't set attn_implementation if you don't have flash_attn
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ src_lang, tgt_lang = "hin_Deva", "eng_Latn"
+ model_name = "ai4bharat/indictrans2-indic-en-1B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ model = AutoModelForSeq2SeqLM.from_pretrained(
+     model_name,
+     trust_remote_code=True,
+     torch_dtype=torch.float16,  # performance might slightly vary for bfloat16
+     attn_implementation="flash_attention_2"
+ ).to(DEVICE)
+
+ ip = IndicProcessor(inference=True)
+
+ input_sentences = [
+     "जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।",
+     "हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।",
+     "अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।",
+     "मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
+ ]
+
+ batch = ip.preprocess_batch(
+     input_sentences,
+     src_lang=src_lang,
+     tgt_lang=tgt_lang,
+ )
+
+ # Tokenize the sentences and generate input encodings
+ inputs = tokenizer(
+     batch,
+     truncation=True,
+     padding="longest",
+     return_tensors="pt",
+     return_attention_mask=True,
+ ).to(DEVICE)
+
+ # Generate translations using the model
+ with torch.no_grad():
+     generated_tokens = model.generate(
+         **inputs,
+         use_cache=True,
+         min_length=0,
+         max_length=256,
+         num_beams=5,
+         num_return_sequences=1,
+     )
+
+ # Decode the generated tokens into text
+ generated_tokens = tokenizer.batch_decode(
+     generated_tokens,
+     skip_special_tokens=True,
+     clean_up_tokenization_spaces=True,
+ )
+
+ # Postprocess the translations, including entity replacement
+ translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
+
+ for input_sentence, translation in zip(input_sentences, translations):
+     print(f"{src_lang}: {input_sentence}")
+     print(f"{tgt_lang}: {translation}")
+ ```
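+
+ If `flash_attn` is not installed (for example on a CPU-only machine), only the model-loading step above needs to change; as noted in the comment in the snippet, simply omit `attn_implementation`. The following is a minimal sketch of that variant, with everything else in the example unchanged:
+
+ ```python
+ # No flash_attn available: omit attn_implementation and let transformers pick its default backend.
+ # fp16 mainly helps on GPU, so fp32 is a safer default on CPU.
+ model = AutoModelForSeq2SeqLM.from_pretrained(
+     model_name,
+     trust_remote_code=True,
+     torch_dtype=torch.float32,
+ ).to(DEVICE)
+ ```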
+
+ ### 📢 Long Context IT2 Models
+
+ - New RoPE-based IndicTrans2 models that can handle sequence lengths of **up to 2048 tokens** are available [here](https://huggingface.co/collections/prajdabre/indictrans2-rope-6742ddac669a05db0804db35).
+ - These models can be used by simply changing the `model_name` parameter (see the sketch below). Please read the model card of the RoPE-IT2 models for more information about generation.
+ - It is recommended to run these models with `flash_attention_2` for efficient generation.
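+
+ As a rough sketch of the swap described above (the checkpoint name below is illustrative only; see the linked collection for the exact RoPE-IT2 identifiers):
+
+ ```python
+ # Hypothetical checkpoint id - replace with the exact name from the RoPE-IT2 collection.
+ model_name = "prajdabre/rotary-indictrans2-indic-en-1B"
+ model = AutoModelForSeq2SeqLM.from_pretrained(
+     model_name,
+     trust_remote_code=True,
+     torch_dtype=torch.float16,
+     attn_implementation="flash_attention_2",  # recommended for these models
+ ).to(DEVICE)
+ ```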
+
+ ### Citation
+
+ If you use our work, please cite:
+
+ ```
+ @article{gala2023indictrans,
+   title={IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
+   author={Jay Gala and Pranjal A Chitale and A K Raghavan and Varun Gumma and Sumanth Doddapaneni and Aswanth Kumar M and Janki Atul Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M Khapra and Raj Dabre and Anoop Kunchukuttan},
+   journal={Transactions on Machine Learning Research},
+   issn={2835-8856},
+   year={2023},
+   url={https://openreview.net/forum?id=vfT4YuzAYA},
+   note={}
+ }
+ ```
config.json ADDED
@@ -0,0 +1,47 @@
+ {
+   "name_or_path": "ai4bharat/indictrans2-indic-en-1B",
+   "activation_dropout": 0.0,
+   "activation_function": "gelu",
+   "architectures": [
+     "IndicTransForConditionalGeneration"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_indictrans.IndicTransConfig",
+     "AutoModelForSeq2SeqLM": "modeling_indictrans.IndicTransForConditionalGeneration"
+   },
+   "tokenizer_class": "IndicTransTokenizer",
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "decoder_attention_heads": 16,
+   "decoder_embed_dim": 1024,
+   "decoder_ffn_dim": 8192,
+   "decoder_layerdrop": 0,
+   "decoder_layers": 18,
+   "decoder_normalize_before": true,
+   "decoder_start_token_id": 2,
+   "decoder_vocab_size": 32296,
+   "vocab_size": 32296,
+   "dropout": 0.2,
+   "encoder_attention_heads": 16,
+   "encoder_embed_dim": 1024,
+   "encoder_ffn_dim": 8192,
+   "encoder_layerdrop": 0,
+   "encoder_layers": 18,
+   "encoder_normalize_before": true,
+   "encoder_vocab_size": 122706,
+   "eos_token_id": 2,
+   "init_std": 0.02,
+   "is_encoder_decoder": true,
+   "layernorm_embedding": false,
+   "max_source_positions": 256,
+   "max_target_positions": 256,
+   "model_type": "IndicTrans",
+   "num_hidden_layers": 18,
+   "pad_token_id": 1,
+   "scale_embedding": true,
+   "share_decoder_input_output_embed": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.32.1",
+   "use_cache": true,
+   "attn_implementation": "eager"
+ }
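The `auto_map` entries above are what let `transformers` resolve the custom classes shipped in this repository. As a minimal sketch (assuming Hub access), the configuration alone can be loaded like this:

```python
from transformers import AutoConfig

# trust_remote_code lets AutoConfig follow auto_map to configuration_indictrans.IndicTransConfig
config = AutoConfig.from_pretrained("ai4bharat/indictrans2-indic-en-1B", trust_remote_code=True)
print(config.encoder_layers, config.decoder_layers, config.encoder_vocab_size)  # 18 18 122706
```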
configuration_indictrans.py ADDED
@@ -0,0 +1,309 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans config."""
16
+
17
+
18
+ from collections import OrderedDict
19
+ from typing import Any, Mapping, Optional
20
+
21
+ from transformers import PreTrainedTokenizer
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
24
+ from transformers.onnx.utils import compute_effective_axis_dimension
25
+ from transformers.utils import TensorType, is_torch_available
26
+
27
+
28
+ # Copied from transformers.models.m2m_100.configuration_m2m_100.M2M100Config->IndicTrans
29
+ class IndicTransConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`IT2Model`]. It is used to instantiate an
32
+ IT2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
33
+ with the defaults will yield a similar configuration to that of the IT2
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 50265):
41
+ Vocabulary size of the IT2 model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`IT2Model`] or
43
+ d_model (`int`, *optional*, defaults to 1024):
44
+ Dimensionality of the layers and the pooler layer.
45
+ encoder_layers (`int`, *optional*, defaults to 12):
46
+ Number of encoder layers.
47
+ decoder_layers (`int`, *optional*, defaults to 12):
48
+ Number of decoder layers.
49
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
52
+ Number of attention heads for each attention layer in the Transformer decoder.
53
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
54
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
55
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
57
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
60
+ dropout (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio for the attention probabilities.
64
+ activation_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for activations inside the fully connected layer.
66
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for classifier.
68
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
69
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
70
+ just in case (e.g., 512 or 1024 or 2048).
71
+ init_std (`float`, *optional*, defaults to 0.02):
72
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
73
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
74
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
75
+ for more details.
76
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
77
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
78
+ for more details.
79
+ use_cache (`bool`, *optional*, defaults to `True`):
80
+ Whether or not the model should return the last key/values attentions (not used by all models).
81
+ ```"""
82
+ model_type = "IndicTrans"
83
+ keys_to_ignore_at_inference = ["past_key_values"]
84
+ attribute_map = {
85
+ "num_attention_heads": "encoder_attention_heads",
86
+ "hidden_size": "d_model",
87
+ }
88
+
89
+ def __init__(
90
+ self,
91
+ encoder_vocab_size=None,
92
+ decoder_vocab_size=None,
93
+ encoder_embed_dim=512,
94
+ decoder_embed_dim=512,
95
+ max_source_positions=210,
96
+ max_target_positions=210,
97
+ encoder_layers=6,
98
+ encoder_ffn_dim=2048,
99
+ encoder_attention_heads=8,
100
+ decoder_layers=6,
101
+ decoder_ffn_dim=2048,
102
+ decoder_attention_heads=8,
103
+ encoder_layerdrop=0.00,
104
+ decoder_layerdrop=0.00,
105
+ use_cache=True,
106
+ is_encoder_decoder=True,
107
+ activation_function="relu",
108
+ encoder_normalize_before=False,
109
+ decoder_normalize_before=False,
110
+ layernorm_embedding=False,
111
+ share_decoder_input_output_embed=False,
112
+ dropout=0.1,
113
+ attention_dropout=0.0,
114
+ activation_dropout=0.0,
115
+ init_std=0.02,
116
+ scale_embedding=True,
117
+ decoder_start_token_id=2,
118
+ pad_token_id=1,
119
+ bos_token_id=0,
120
+ eos_token_id=2,
121
+ attn_implementation="eager",
122
+ **kwargs,
123
+ ):
124
+ self.encoder_vocab_size = encoder_vocab_size
125
+ self.decoder_vocab_size = decoder_vocab_size
126
+ self.encoder_normalize_before = encoder_normalize_before
127
+ self.decoder_normalize_before = decoder_normalize_before
128
+ self.layernorm_embedding = layernorm_embedding
129
+ self.max_source_positions = max_source_positions
130
+ self.max_target_positions = max_target_positions
131
+ self.encoder_embed_dim = encoder_embed_dim
132
+ self.decoder_embed_dim = decoder_embed_dim
133
+ self.encoder_ffn_dim = encoder_ffn_dim
134
+ self.encoder_layers = encoder_layers
135
+ self.encoder_attention_heads = encoder_attention_heads
136
+ self.decoder_ffn_dim = decoder_ffn_dim
137
+ self.decoder_layers = decoder_layers
138
+ self.decoder_attention_heads = decoder_attention_heads
139
+ self.dropout = dropout
140
+ self.attention_dropout = attention_dropout
141
+ self.activation_dropout = activation_dropout
142
+ self.activation_function = activation_function
143
+ self.init_std = init_std
144
+ self.encoder_layerdrop = encoder_layerdrop
145
+ self.decoder_layerdrop = decoder_layerdrop
146
+ self.use_cache = use_cache
147
+ self.num_hidden_layers = encoder_layers
148
+ self.scale_embedding = scale_embedding
149
+ self.share_decoder_input_output_embed = share_decoder_input_output_embed
150
+ self.attn_implementation = attn_implementation
151
+
152
+ super().__init__(
153
+ pad_token_id=pad_token_id,
154
+ bos_token_id=bos_token_id,
155
+ eos_token_id=eos_token_id,
156
+ is_encoder_decoder=is_encoder_decoder,
157
+ decoder_start_token_id=decoder_start_token_id,
158
+ **kwargs,
159
+ )
160
+
161
+
162
+ class IndicTransOnnxConfig(OnnxSeq2SeqConfigWithPast):
163
+ @property
164
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
165
+ common_inputs = OrderedDict(
166
+ [
167
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
168
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
169
+ ]
170
+ )
171
+
172
+ if self.use_past:
173
+ common_inputs["decoder_input_ids"] = {0: "batch"}
174
+ common_inputs["decoder_attention_mask"] = {
175
+ 0: "batch",
176
+ 1: "past_decoder_sequence + sequence",
177
+ }
178
+ else:
179
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
180
+ common_inputs["decoder_attention_mask"] = {
181
+ 0: "batch",
182
+ 1: "decoder_sequence",
183
+ }
184
+
185
+ if self.use_past:
186
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
187
+ return common_inputs
188
+
189
+ # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
190
+ # A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
191
+ # answering are not supported for IT2, but this name is preserved to be able to check that the copy matches what
192
+ # was done for BART so that it can be updated if need be.
193
+ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
194
+ self,
195
+ tokenizer: PreTrainedTokenizer,
196
+ batch_size: int = -1,
197
+ seq_length: int = -1,
198
+ is_pair: bool = False,
199
+ framework: Optional[TensorType] = None,
200
+ ) -> Mapping[str, Any]:
201
+ # Copied from OnnxConfig.generate_dummy_inputs
202
+ # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
203
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
204
+ batch_size = compute_effective_axis_dimension(
205
+ batch_size,
206
+ fixed_dimension=OnnxConfig.default_fixed_batch,
207
+ num_token_to_add=0,
208
+ )
209
+
210
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
211
+ token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
212
+ seq_length = compute_effective_axis_dimension(
213
+ seq_length,
214
+ fixed_dimension=OnnxConfig.default_fixed_sequence,
215
+ num_token_to_add=token_to_add,
216
+ )
217
+
218
+ # Generate dummy inputs according to compute batch and sequence
219
+ dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
220
+ common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
221
+ return common_inputs
222
+
223
+ # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
224
+ def _generate_dummy_inputs_for_default_and_seq2seq_lm(
225
+ self,
226
+ tokenizer: PreTrainedTokenizer,
227
+ batch_size: int = -1,
228
+ seq_length: int = -1,
229
+ is_pair: bool = False,
230
+ framework: Optional[TensorType] = None,
231
+ ) -> Mapping[str, Any]:
232
+ encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
233
+ tokenizer, batch_size, seq_length, is_pair, framework
234
+ )
235
+
236
+ # Generate decoder inputs
237
+ decoder_seq_length = seq_length if not self.use_past else 1
238
+ decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
239
+ tokenizer, batch_size, decoder_seq_length, is_pair, framework
240
+ )
241
+ decoder_inputs = {
242
+ f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()
243
+ }
244
+ common_inputs = dict(**encoder_inputs, **decoder_inputs)
245
+
246
+ if self.use_past:
247
+ if not is_torch_available():
248
+ raise ValueError(
249
+ "Cannot generate dummy past_keys inputs without PyTorch installed."
250
+ )
251
+ else:
252
+ import torch
253
+ batch, encoder_seq_length = common_inputs["input_ids"].shape
254
+ decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
255
+ (
256
+ num_encoder_attention_heads,
257
+ num_decoder_attention_heads,
258
+ ) = self.num_attention_heads
259
+ encoder_shape = (
260
+ batch,
261
+ num_encoder_attention_heads,
262
+ encoder_seq_length,
263
+ self._config.hidden_size // num_encoder_attention_heads,
264
+ )
265
+ decoder_past_length = decoder_seq_length + 3
266
+ decoder_shape = (
267
+ batch,
268
+ num_decoder_attention_heads,
269
+ decoder_past_length,
270
+ self._config.hidden_size // num_decoder_attention_heads,
271
+ )
272
+
273
+ common_inputs["decoder_attention_mask"] = torch.cat(
274
+ [
275
+ common_inputs["decoder_attention_mask"],
276
+ torch.ones(batch, decoder_past_length),
277
+ ],
278
+ dim=1,
279
+ )
280
+
281
+ common_inputs["past_key_values"] = []
282
+ # If the number of encoder and decoder layers are present in the model configuration, both are considered
283
+ num_encoder_layers, num_decoder_layers = self.num_layers
284
+ min_num_layers = min(num_encoder_layers, num_decoder_layers)
285
+ max_num_layers = (
286
+ max(num_encoder_layers, num_decoder_layers) - min_num_layers
287
+ )
288
+ remaining_side_name = (
289
+ "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
290
+ )
291
+
292
+ for _ in range(min_num_layers):
293
+ common_inputs["past_key_values"].append(
294
+ (
295
+ torch.zeros(decoder_shape),
296
+ torch.zeros(decoder_shape),
297
+ torch.zeros(encoder_shape),
298
+ torch.zeros(encoder_shape),
299
+ )
300
+ )
301
+ # TODO: test this.
302
+ shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
303
+ for _ in range(min_num_layers, max_num_layers):
304
+ common_inputs["past_key_values"].append(
305
+ (torch.zeros(shape), torch.zeros(shape))
306
+ )
307
+ return common_inputs
308
+
309
+ generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
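For reference, a minimal sketch of constructing this config directly with the 1.1B values from `config.json` above; in practice `AutoConfig`/`AutoModelForSeq2SeqLM` with `trust_remote_code=True` do this for you, and the import assumes `configuration_indictrans.py` is on the local path:

```python
from configuration_indictrans import IndicTransConfig

# Values mirror config.json for the indic-en-1B checkpoint.
config = IndicTransConfig(
    encoder_vocab_size=122706,
    decoder_vocab_size=32296,
    encoder_embed_dim=1024,
    decoder_embed_dim=1024,
    encoder_layers=18,
    decoder_layers=18,
    encoder_ffn_dim=8192,
    decoder_ffn_dim=8192,
    encoder_attention_heads=16,
    decoder_attention_heads=16,
    encoder_normalize_before=True,
    decoder_normalize_before=True,
    activation_function="gelu",
    dropout=0.2,
    max_source_positions=256,
    max_target_positions=256,
)
```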
dict.SRC.json ADDED
The diff for this file is too large to render. See raw diff
 
dict.TGT.json ADDED
The diff for this file is too large to render. See raw diff
 
generation_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 0,
+   "decoder_start_token_id": 2,
+   "eos_token_id": 2,
+   "pad_token_id": 1,
+   "transformers_version": "4.32.1"
+ }
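These generation defaults (start, end and padding token ids) are picked up automatically by `model.generate`; a small sketch of inspecting them:

```python
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("ai4bharat/indictrans2-indic-en-1B")
print(gen_config.decoder_start_token_id, gen_config.eos_token_id, gen_config.pad_token_id)  # 2 2 1
```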
gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model.SRC filter=lfs diff=lfs merge=lfs -text
model.SRC ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
+ size 3256903
model.TGT ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3cedc5cbcc740369b76201942a0f096fec7287fee039b55bdb956f301235b914
+ size 759425
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9b030cdd001e669a4043ea70ebb39a59630f5af58e0d0329ac9edc663ce98448
+ size 4092117552
modeling_indictrans.py ADDED
@@ -0,0 +1,1803 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans model."""
16
+
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn import functional as F
24
+
25
+ from transformers.activations import ACT2FN
26
+
27
+ from transformers.modeling_attn_mask_utils import (
28
+ _prepare_4d_attention_mask,
29
+ _prepare_4d_attention_mask_for_sdpa,
30
+ _prepare_4d_causal_attention_mask,
31
+ _prepare_4d_causal_attention_mask_for_sdpa,
32
+ )
33
+
34
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
35
+ from transformers.modeling_outputs import (
36
+ BaseModelOutput,
37
+ BaseModelOutputWithPastAndCrossAttentions,
38
+ Seq2SeqLMOutput,
39
+ Seq2SeqModelOutput
40
+ )
41
+
42
+ from transformers.utils import (
43
+ logging,
44
+ is_flash_attn_2_available,
45
+ is_flash_attn_greater_or_equal_2_10,
46
+ )
47
+
48
+ from transformers.modeling_utils import PreTrainedModel
49
+ from transformers.generation.utils import GenerationMixin
50
+
51
+ from .configuration_indictrans import IndicTransConfig
52
+
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+ INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
57
+
58
+ try:
59
+ if is_flash_attn_2_available():
60
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
61
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
62
+ except:
63
+ pass
64
+
65
+
66
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
67
+ def _get_unpad_data(attention_mask):
68
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
69
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
70
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
71
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
72
+ return (
73
+ indices,
74
+ cu_seqlens,
75
+ max_seqlen_in_batch,
76
+ )
77
+
78
+
79
+ # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
80
+ def shift_tokens_right(
81
+ input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
82
+ ):
83
+ """
84
+ Shift input ids one token to the right.
85
+ """
86
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
87
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
88
+ shifted_input_ids[:, 0] = decoder_start_token_id
89
+
90
+ if pad_token_id is None:
91
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
92
+ # replace possible -100 values in labels by `pad_token_id`
93
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
94
+
95
+ return shifted_input_ids
96
+
97
+
98
+ def create_position_ids_from_input_ids(
99
+ input_ids, padding_idx, past_key_values_length=0
100
+ ):
101
+ """
102
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
103
+ are ignored. This is modified from fairseq's `utils.make_positions`.
104
+ """
105
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
106
+ mask = input_ids.ne(padding_idx).int()
107
+ incremental_indices = (
108
+ torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
109
+ ) * mask
110
+ return incremental_indices.long() + padding_idx
111
+
112
+
113
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding->IndicTrans
114
+ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
115
+ """This module produces sinusoidal positional embeddings of any length."""
116
+
117
+ def __init__(
118
+ self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
119
+ ):
120
+ super().__init__()
121
+ self.offset = 2
122
+ self.embedding_dim = embedding_dim
123
+ self.padding_idx = padding_idx
124
+ self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
125
+
126
+ def make_weights(
127
+ self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
128
+ ):
129
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
130
+ if hasattr(self, "weights"):
131
+ # in forward put the weights on the correct dtype and device of the param
132
+ emb_weights = emb_weights.to(
133
+ dtype=self.weights.dtype, device=self.weights.device
134
+ )
135
+
136
+ self.register_buffer("weights", emb_weights, persistent=False)
137
+
138
+ @staticmethod
139
+ def get_embedding(
140
+ num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
141
+ ):
142
+ """
143
+ Build sinusoidal embeddings.
144
+
145
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
146
+ "Attention Is All You Need".
147
+ """
148
+ half_dim = embedding_dim // 2
149
+ emb = math.log(10000) / (half_dim - 1)
150
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
151
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
152
+ 1
153
+ ) * emb.unsqueeze(0)
154
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
155
+ num_embeddings, -1
156
+ )
157
+ if embedding_dim % 2 == 1:
158
+ # zero pad
159
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
160
+ if padding_idx is not None:
161
+ emb[padding_idx, :] = 0
162
+
163
+ return emb.to(torch.get_default_dtype())
164
+
165
+ @torch.no_grad()
166
+ def forward(
167
+ self,
168
+ input_ids: torch.Tensor = None,
169
+ inputs_embeds: torch.Tensor = None,
170
+ past_key_values_length: int = 0,
171
+ ):
172
+ if input_ids is not None:
173
+ bsz, seq_len = input_ids.size()
174
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
175
+ position_ids = create_position_ids_from_input_ids(
176
+ input_ids, self.padding_idx, past_key_values_length
177
+ ).to(input_ids.device)
178
+ else:
179
+ bsz, seq_len = inputs_embeds.size()[:-1]
180
+ position_ids = self.create_position_ids_from_inputs_embeds(
181
+ inputs_embeds, past_key_values_length
182
+ )
183
+
184
+ # expand embeddings if needed
185
+ max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
186
+ if max_pos > self.weights.size(0):
187
+ self.make_weights(
188
+ max_pos + self.offset, self.embedding_dim, self.padding_idx
189
+ )
190
+
191
+ return (
192
+ self.weights.index_select(0, position_ids.view(-1))
193
+ .view(bsz, seq_len, self.weights.shape[-1])
194
+ .detach()
195
+ )
196
+
197
+ def create_position_ids_from_inputs_embeds(
198
+ self, inputs_embeds, past_key_values_length
199
+ ):
200
+ """
201
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
202
+
203
+ Args:
204
+ inputs_embeds: torch.Tensor
205
+
206
+ Returns: torch.Tensor
207
+ """
208
+ input_shape = inputs_embeds.size()[:-1]
209
+ sequence_length = input_shape[1]
210
+
211
+ position_ids = torch.arange(
212
+ self.padding_idx + 1,
213
+ sequence_length + self.padding_idx + 1,
214
+ dtype=torch.long,
215
+ device=inputs_embeds.device,
216
+ )
217
+ return (
218
+ position_ids.unsqueeze(0).expand(input_shape).contiguous()
219
+ + past_key_values_length
220
+ )
221
+
222
+
223
+ # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
224
+ class IndicTransAttention(nn.Module):
225
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
226
+
227
+ def __init__(
228
+ self,
229
+ embed_dim: int,
230
+ num_heads: int,
231
+ dropout: float = 0.0,
232
+ is_decoder: bool = False,
233
+ bias: bool = True,
234
+ is_causal: bool = False,
235
+ config: Optional[IndicTransConfig] = None,
236
+ ):
237
+ super().__init__()
238
+ self.embed_dim = embed_dim
239
+ self.num_heads = num_heads
240
+ self.dropout = dropout
241
+ self.head_dim = embed_dim // num_heads
242
+ self.config = config
243
+
244
+ if (self.head_dim * num_heads) != self.embed_dim:
245
+ raise ValueError(
246
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
247
+ f" and `num_heads`: {num_heads})."
248
+ )
249
+ self.scaling = self.head_dim**-0.5
250
+ self.is_decoder = is_decoder
251
+ self.is_causal = is_causal
252
+
253
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
254
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
255
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
256
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
257
+
258
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
259
+ return (
260
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
261
+ .transpose(1, 2)
262
+ .contiguous()
263
+ )
264
+
265
+ def forward(
266
+ self,
267
+ hidden_states: torch.Tensor,
268
+ key_value_states: Optional[torch.Tensor] = None,
269
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
270
+ attention_mask: Optional[torch.Tensor] = None,
271
+ layer_head_mask: Optional[torch.Tensor] = None,
272
+ output_attentions: bool = False,
273
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
274
+ """Input shape: Batch x Time x Channel"""
275
+
276
+ # if key_value_states are provided this layer is used as a cross-attention layer
277
+ # for the decoder
278
+ is_cross_attention = key_value_states is not None
279
+
280
+ bsz, tgt_len, _ = hidden_states.size()
281
+
282
+ # get query proj
283
+ query_states = self.q_proj(hidden_states) * self.scaling
284
+ # get key, value proj
285
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
286
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
287
+ # the provided `key_value_states` to support prefix tuning
288
+ if (
289
+ is_cross_attention
290
+ and past_key_value is not None
291
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
292
+ ):
293
+ # reuse k,v, cross_attentions
294
+ key_states = past_key_value[0]
295
+ value_states = past_key_value[1]
296
+ elif is_cross_attention:
297
+ # cross_attentions
298
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
299
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
300
+ elif past_key_value is not None:
301
+ # reuse k, v, self_attention
302
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
303
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
304
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
305
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
306
+ else:
307
+ # self_attention
308
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
309
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
310
+
311
+ if self.is_decoder:
312
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
313
+ # Further calls to cross_attention layer can then reuse all cross-attention
314
+ # key/value_states (first "if" case)
315
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
316
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
317
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
318
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
319
+ past_key_value = (key_states, value_states)
320
+
321
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
322
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
323
+ key_states = key_states.reshape(*proj_shape)
324
+ value_states = value_states.reshape(*proj_shape)
325
+
326
+ src_len = key_states.size(1)
327
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
328
+
329
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
330
+ raise ValueError(
331
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
332
+ f" {attn_weights.size()}"
333
+ )
334
+
335
+ if attention_mask is not None:
336
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
337
+ raise ValueError(
338
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
339
+ )
340
+ attn_weights = (
341
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
342
+ + attention_mask
343
+ )
344
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
345
+
346
+ attn_weights = F.softmax(attn_weights, dim=-1)
347
+
348
+ if layer_head_mask is not None:
349
+ if layer_head_mask.size() != (self.num_heads,):
350
+ raise ValueError(
351
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
352
+ f" {layer_head_mask.size()}"
353
+ )
354
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
355
+ bsz, self.num_heads, tgt_len, src_len
356
+ )
357
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
358
+
359
+ if output_attentions:
360
+ # this operation is a bit awkward, but it's required to
361
+ # make sure that attn_weights keeps its gradient.
362
+ # In order to do so, attn_weights have to be reshaped
363
+ # twice and have to be reused in the following
364
+ attn_weights_reshaped = attn_weights.view(
365
+ bsz, self.num_heads, tgt_len, src_len
366
+ )
367
+ attn_weights = attn_weights_reshaped.view(
368
+ bsz * self.num_heads, tgt_len, src_len
369
+ )
370
+ else:
371
+ attn_weights_reshaped = None
372
+
373
+ attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
374
+
375
+ attn_output = torch.bmm(attn_probs, value_states)
376
+
377
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
378
+ raise ValueError(
379
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
380
+ f" {attn_output.size()}"
381
+ )
382
+
383
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
384
+ attn_output = attn_output.transpose(1, 2)
385
+
386
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
387
+ # partitioned across GPUs when using tensor-parallelism.
388
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
389
+
390
+ attn_output = self.out_proj(attn_output)
391
+
392
+ return attn_output, attn_weights_reshaped, past_key_value
393
+
394
+
395
+ class IndicTransFlashAttention2(IndicTransAttention):
396
+ """
397
+ IndicTrans flash attention module. This module inherits from `IndicTransAttention` as the weights of the module stays
398
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
399
+ flash attention and deal with padding tokens in case the input contains any of them.
400
+ """
401
+
402
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
403
+ def __init__(self, *args, **kwargs):
404
+ super().__init__(*args, **kwargs)
405
+
406
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
407
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
408
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
409
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
410
+
411
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
412
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
413
+
414
+ def forward(
415
+ self,
416
+ hidden_states: torch.Tensor,
417
+ key_value_states: Optional[torch.Tensor] = None,
418
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
419
+ attention_mask: Optional[torch.Tensor] = None,
420
+ layer_head_mask: Optional[torch.Tensor] = None,
421
+ output_attentions: bool = False,
422
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
423
+ # IndicTransFlashAttention2 attention does not support output_attentions
424
+ if output_attentions:
425
+ raise ValueError("IndicTransFlashAttention2 attention does not support output_attentions")
426
+
427
+ # if key_value_states are provided this layer is used as a cross-attention layer
428
+ # for the decoder
429
+ is_cross_attention = key_value_states is not None
430
+
431
+ bsz, q_len, _ = hidden_states.size()
432
+
433
+ # get query proj
434
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
435
+ # get key, value proj
436
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
437
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
438
+ # the provided `key_value_states` to support prefix tuning
439
+ if (
440
+ is_cross_attention
441
+ and past_key_value is not None
442
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
443
+ ):
444
+ # reuse k,v, cross_attentions
445
+ key_states = past_key_value[0].transpose(1, 2)
446
+ value_states = past_key_value[1].transpose(1, 2)
447
+ elif is_cross_attention:
448
+ # cross_attentions
449
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
450
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
451
+ elif past_key_value is not None:
452
+ # reuse k, v, self_attention
453
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
454
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
455
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
456
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
457
+ else:
458
+ # self_attention
459
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
460
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
461
+
462
+ if self.is_decoder:
463
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
464
+ # Further calls to cross_attention layer can then reuse all cross-attention
465
+ # key/value_states (first "if" case)
466
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
467
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
468
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
469
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
470
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
471
+
472
+ kv_seq_len = key_states.shape[-2]
473
+ if past_key_value is not None:
474
+ kv_seq_len += past_key_value[0].shape[-2]
475
+
476
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
477
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
478
+ # cast them back in the correct dtype just to be sure everything works as expected.
479
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
480
+ # in fp32. (LlamaRMSNorm handles it correctly)
481
+
482
+ input_dtype = query_states.dtype
483
+ if input_dtype == torch.float32:
484
+ if torch.is_autocast_enabled():
485
+ target_dtype = torch.get_autocast_gpu_dtype()
486
+ # Handle the case where the model is quantized
487
+ elif hasattr(self.config, "_pre_quantization_dtype"):
488
+ target_dtype = self.config._pre_quantization_dtype
489
+ else:
490
+ target_dtype = self.q_proj.weight.dtype
491
+
492
+ logger.warning_once(
493
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
494
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
495
+ f" {target_dtype}."
496
+ )
497
+
498
+ query_states = query_states.to(target_dtype)
499
+ key_states = key_states.to(target_dtype)
500
+ value_states = value_states.to(target_dtype)
501
+
502
+ attn_output = self._flash_attention_forward(
503
+ query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
504
+ )
505
+
506
+ attn_output = attn_output.reshape(bsz, q_len, -1)
507
+ attn_output = self.out_proj(attn_output)
508
+
509
+ if not output_attentions:
510
+ attn_weights = None
511
+
512
+ return attn_output, attn_weights, past_key_value
513
+
514
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
515
+ def _flash_attention_forward(
516
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
517
+ ):
518
+ """
519
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
520
+ first unpad the input, then computes the attention scores and pad the final attention scores.
521
+
522
+ Args:
523
+ query_states (`torch.Tensor`):
524
+ Input query states to be passed to Flash Attention API
525
+ key_states (`torch.Tensor`):
526
+ Input key states to be passed to Flash Attention API
527
+ value_states (`torch.Tensor`):
528
+ Input value states to be passed to Flash Attention API
529
+ attention_mask (`torch.Tensor`):
530
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
531
+ position of padding tokens and 1 for the position of non-padding tokens.
532
+ dropout (`float`):
533
+ Attention dropout
534
+ softmax_scale (`float`, *optional*):
535
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
536
+ """
537
+ if not self._flash_attn_uses_top_left_mask:
538
+ causal = self.is_causal
539
+ else:
540
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
541
+ causal = self.is_causal and query_length != 1
542
+
543
+ # Contains at least one padding token in the sequence
544
+ if attention_mask is not None:
545
+ batch_size = query_states.shape[0]
546
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
547
+ query_states, key_states, value_states, attention_mask, query_length
548
+ )
549
+
550
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
551
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
552
+
553
+ attn_output_unpad = flash_attn_varlen_func(
554
+ query_states,
555
+ key_states,
556
+ value_states,
557
+ cu_seqlens_q=cu_seqlens_q,
558
+ cu_seqlens_k=cu_seqlens_k,
559
+ max_seqlen_q=max_seqlen_in_batch_q,
560
+ max_seqlen_k=max_seqlen_in_batch_k,
561
+ dropout_p=dropout,
562
+ softmax_scale=softmax_scale,
563
+ causal=causal,
564
+ )
565
+
566
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
567
+ else:
568
+ attn_output = flash_attn_func(
569
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
570
+ )
571
+
572
+ return attn_output
573
+
574
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
575
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
576
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
577
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
578
+
579
+ key_layer = index_first_axis(
580
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
581
+ )
582
+ value_layer = index_first_axis(
583
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
584
+ )
585
+ if query_length == kv_seq_len:
586
+ query_layer = index_first_axis(
587
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
588
+ )
589
+ cu_seqlens_q = cu_seqlens_k
590
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
591
+ indices_q = indices_k
592
+ elif query_length == 1:
593
+ max_seqlen_in_batch_q = 1
594
+ cu_seqlens_q = torch.arange(
595
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
596
+ ) # There is a memcpy here, that is very bad.
597
+ indices_q = cu_seqlens_q[:-1]
598
+ query_layer = query_layer.squeeze(1)
599
+ else:
600
+ # The -q_len: slice assumes left padding.
601
+ attention_mask = attention_mask[:, -query_length:]
602
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
603
+
604
+ return (
605
+ query_layer,
606
+ key_layer,
607
+ value_layer,
608
+ indices_q,
609
+ (cu_seqlens_q, cu_seqlens_k),
610
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
611
+ )
612
+
613
+
614
+ class IndicTransSdpaAttention(IndicTransAttention):
615
+ def forward(
616
+ self,
617
+ hidden_states: torch.Tensor,
618
+ key_value_states: Optional[torch.Tensor] = None,
619
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
620
+ attention_mask: Optional[torch.Tensor] = None,
621
+ layer_head_mask: Optional[torch.Tensor] = None,
622
+ output_attentions: bool = False,
623
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
624
+ """Input shape: Batch x Time x Channel"""
625
+ if output_attentions or layer_head_mask is not None:
626
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
627
+ logger.warning_once(
628
+ "IndicTransModel is using IndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
629
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
630
+ )
631
+ return super().forward(
632
+ hidden_states,
633
+ key_value_states=key_value_states,
634
+ past_key_value=past_key_value,
635
+ attention_mask=attention_mask,
636
+ layer_head_mask=layer_head_mask,
637
+ output_attentions=output_attentions,
638
+ )
639
+
640
+ # if key_value_states are provided this layer is used as a cross-attention layer
641
+ # for the decoder
642
+ is_cross_attention = key_value_states is not None
643
+
644
+ bsz, tgt_len, _ = hidden_states.size()
645
+
646
+ # get query proj
647
+ query_states = self.q_proj(hidden_states)
648
+ # get key, value proj
649
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
650
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
651
+ # the provided `key_value_states` to support prefix tuning
652
+ if (
653
+ is_cross_attention
654
+ and past_key_value is not None
655
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
656
+ ):
657
+ # reuse k,v, cross_attentions
658
+ key_states = past_key_value[0]
659
+ value_states = past_key_value[1]
660
+ elif is_cross_attention:
661
+ # cross_attentions
662
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
663
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
664
+ elif past_key_value is not None:
665
+ # reuse k, v, self_attention
666
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
667
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
668
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
669
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
670
+ else:
671
+ # self_attention
672
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
673
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
674
+
675
+ if self.is_decoder:
676
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
677
+ # Further calls to cross_attention layer can then reuse all cross-attention
678
+ # key/value_states (first "if" case)
679
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
680
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
681
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
682
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
683
+ past_key_value = (key_states, value_states)
684
+
685
+ query_states = self._shape(query_states, tgt_len, bsz)
686
+
687
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
688
+ # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
689
+ attn_output = F.scaled_dot_product_attention(
690
+ query_states,
691
+ key_states,
692
+ value_states,
693
+ attn_mask=attention_mask,
694
+ dropout_p=self.dropout if self.training else 0.0,
695
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
696
+ is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
697
+ )
698
+
699
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
700
+ raise ValueError(
701
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
702
+ f" {attn_output.size()}"
703
+ )
704
+
705
+ attn_output = attn_output.transpose(1, 2)
706
+
707
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
708
+ # partitioned across GPUs when using tensor-parallelism.
709
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
710
+
711
+ attn_output = self.out_proj(attn_output)
712
+
713
+ return attn_output, None, past_key_value
714
+
715
+
716
+ INDICTRANS_ATTENTION_CLASSES = {
717
+ "eager": IndicTransAttention,
718
+ "sdpa": IndicTransSdpaAttention,
719
+ "flash_attention_2": IndicTransFlashAttention2,
720
+ }
721
+
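+ # Editor's note (hedged usage sketch; the model id is taken from
+ # tokenizer_config.json and the kwarg from recent `transformers` releases): the key
+ # looked up in the mapping above comes from `config._attn_implementation`, which is
+ # set through `attn_implementation` at load time, e.g.
+ #
+ #     from transformers import AutoModelForSeq2SeqLM
+ #     model = AutoModelForSeq2SeqLM.from_pretrained(
+ #         "ai4bharat/indictrans2-indic-en-1B",
+ #         trust_remote_code=True,
+ #         attn_implementation="flash_attention_2",  # or "eager" / "sdpa"
+ #     )
+ #
+ # "flash_attention_2" additionally requires a working `flash-attn` installation.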
722
+ # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
723
+ class IndicTransEncoderLayer(nn.Module):
724
+ def __init__(self, config: IndicTransConfig):
725
+ super().__init__()
726
+ self.embed_dim = config.encoder_embed_dim
727
+ self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
728
+ embed_dim=self.embed_dim,
729
+ num_heads=config.encoder_attention_heads,
730
+ dropout=config.attention_dropout,
731
+ config=config,
732
+ )
733
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
734
+ self.dropout = config.dropout
735
+ self.activation_fn = ACT2FN[config.activation_function]
736
+ self.activation_dropout = config.activation_dropout
737
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
738
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
739
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
740
+ self.normalize_before = config.encoder_normalize_before
741
+
742
+ def forward(
743
+ self,
744
+ hidden_states: torch.Tensor,
745
+ attention_mask: torch.Tensor,
746
+ layer_head_mask: torch.Tensor,
747
+ output_attentions: bool = False,
748
+ ) -> torch.Tensor:
749
+ """
750
+ Args:
751
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
752
+ attention_mask (`torch.FloatTensor`): attention mask of size
753
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
754
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
755
+ `(encoder_attention_heads,)`.
756
+ output_attentions (`bool`, *optional*):
757
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
758
+ returned tensors for more detail.
759
+ """
760
+ residual = hidden_states
761
+ if self.normalize_before:
762
+ hidden_states = self.self_attn_layer_norm(hidden_states)
763
+ hidden_states, attn_weights, _ = self.self_attn(
764
+ hidden_states=hidden_states,
765
+ attention_mask=attention_mask,
766
+ layer_head_mask=layer_head_mask,
767
+ output_attentions=output_attentions,
768
+ )
769
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
770
+ hidden_states = residual + hidden_states
771
+ if not self.normalize_before:
772
+ hidden_states = self.self_attn_layer_norm(hidden_states)
773
+
774
+ residual = hidden_states
775
+ if self.normalize_before:
776
+ hidden_states = self.final_layer_norm(hidden_states)
777
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
778
+ hidden_states = F.dropout(
779
+ hidden_states, p=self.activation_dropout, training=self.training
780
+ )
781
+ hidden_states = self.fc2(hidden_states)
782
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
783
+ hidden_states = residual + hidden_states
784
+ if not self.normalize_before:
785
+ hidden_states = self.final_layer_norm(hidden_states)
786
+
787
+ if hidden_states.dtype == torch.float16 and (
788
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
789
+ ):
790
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
791
+ hidden_states = torch.clamp(
792
+ hidden_states, min=-clamp_value, max=clamp_value
793
+ )
794
+
795
+ outputs = (hidden_states,)
796
+
797
+ if output_attentions:
798
+ outputs += (attn_weights,)
799
+
800
+ return outputs
801
+
802
+
803
+ # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->IndicTrans
804
+ class IndicTransDecoderLayer(nn.Module):
805
+ def __init__(self, config: IndicTransConfig):
806
+ super().__init__()
807
+ self.embed_dim = config.decoder_embed_dim
808
+
809
+ self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
810
+ embed_dim=self.embed_dim,
811
+ num_heads=config.decoder_attention_heads,
812
+ dropout=config.attention_dropout,
813
+ is_decoder=True,
814
+ is_causal=True,
815
+ config=config,
816
+ )
817
+ self.dropout = config.dropout
818
+ self.activation_fn = ACT2FN[config.activation_function]
819
+ self.activation_dropout = config.activation_dropout
820
+
821
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
822
+ self.encoder_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
823
+ self.embed_dim,
824
+ config.decoder_attention_heads,
825
+ dropout=config.attention_dropout,
826
+ is_decoder=True,
827
+ config=config,
828
+ )
829
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
830
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
831
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
832
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
833
+ self.normalize_before = config.decoder_normalize_before
834
+
835
+ def forward(
836
+ self,
837
+ hidden_states: torch.Tensor,
838
+ attention_mask: Optional[torch.Tensor] = None,
839
+ encoder_hidden_states: Optional[torch.Tensor] = None,
840
+ encoder_attention_mask: Optional[torch.Tensor] = None,
841
+ layer_head_mask: Optional[torch.Tensor] = None,
842
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
843
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
844
+ output_attentions: Optional[bool] = False,
845
+ use_cache: Optional[bool] = True,
846
+ ) -> torch.Tensor:
847
+ """
848
+ Args:
849
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
850
+ attention_mask (`torch.FloatTensor`): attention mask of size
851
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
852
+ encoder_hidden_states (`torch.FloatTensor`):
853
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
854
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
855
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
856
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
857
+ `(encoder_attention_heads,)`.
858
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
859
+ size `(decoder_attention_heads,)`.
860
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
861
+ output_attentions (`bool`, *optional*):
862
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
863
+ returned tensors for more detail.
864
+ """
865
+ residual = hidden_states
866
+ if self.normalize_before:
867
+ hidden_states = self.self_attn_layer_norm(hidden_states)
868
+
869
+ # Self Attention
870
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
871
+ self_attn_past_key_value = (
872
+ past_key_value[:2] if past_key_value is not None else None
873
+ )
874
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
875
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
876
+ hidden_states=hidden_states,
877
+ past_key_value=self_attn_past_key_value,
878
+ attention_mask=attention_mask,
879
+ layer_head_mask=layer_head_mask,
880
+ output_attentions=output_attentions,
881
+ )
882
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
883
+ hidden_states = residual + hidden_states
884
+ if not self.normalize_before:
885
+ hidden_states = self.self_attn_layer_norm(hidden_states)
886
+
887
+ # Cross-Attention Block
888
+ cross_attn_present_key_value = None
889
+ cross_attn_weights = None
890
+ if encoder_hidden_states is not None:
891
+ residual = hidden_states
892
+ if self.normalize_before:
893
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
894
+
895
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
896
+ cross_attn_past_key_value = (
897
+ past_key_value[-2:] if past_key_value is not None else None
898
+ )
899
+ (
900
+ hidden_states,
901
+ cross_attn_weights,
902
+ cross_attn_present_key_value,
903
+ ) = self.encoder_attn(
904
+ hidden_states=hidden_states,
905
+ key_value_states=encoder_hidden_states,
906
+ attention_mask=encoder_attention_mask,
907
+ layer_head_mask=cross_attn_layer_head_mask,
908
+ past_key_value=cross_attn_past_key_value,
909
+ output_attentions=output_attentions,
910
+ )
911
+ hidden_states = F.dropout(
912
+ hidden_states, p=self.dropout, training=self.training
913
+ )
914
+ hidden_states = residual + hidden_states
915
+ if not self.normalize_before:
916
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
917
+
918
+ # add cross-attn to positions 3,4 of present_key_value tuple
919
+ present_key_value = present_key_value + cross_attn_present_key_value
920
+
921
+ # Fully Connected
922
+ residual = hidden_states
923
+ if self.normalize_before:
924
+ hidden_states = self.final_layer_norm(hidden_states)
925
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
926
+ hidden_states = F.dropout(
927
+ hidden_states, p=self.activation_dropout, training=self.training
928
+ )
929
+ hidden_states = self.fc2(hidden_states)
930
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
931
+ hidden_states = residual + hidden_states
932
+ if not self.normalize_before:
933
+ hidden_states = self.final_layer_norm(hidden_states)
934
+
935
+ outputs = (hidden_states,)
936
+
937
+ if output_attentions:
938
+ outputs += (self_attn_weights, cross_attn_weights)
939
+
940
+ if use_cache:
941
+ outputs += (present_key_value,)
942
+
943
+ return outputs
944
+
945
+
946
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PreTrainedModel with M2M100->IndicTrans
947
+ class IndicTransPreTrainedModel(PreTrainedModel):
948
+ config_class = IndicTransConfig
949
+ base_model_prefix = "model"
950
+ supports_gradient_checkpointing = True
951
+ _no_split_modules = ["IndicTransAttention"]
952
+
953
+ def _init_weights(self, module):
954
+ std = self.config.init_std
955
+ if isinstance(module, nn.Linear):
956
+ module.weight.data.normal_(mean=0.0, std=std)
957
+ if module.bias is not None:
958
+ module.bias.data.zero_()
959
+ elif isinstance(module, nn.Embedding):
960
+ module.weight.data.normal_(mean=0.0, std=std)
961
+ if module.padding_idx is not None:
962
+ module.weight.data[module.padding_idx].zero_()
963
+
964
+ def _set_gradient_checkpointing(self, module, value=False):
965
+ if isinstance(module, (IndicTransDecoder, IndicTransEncoder)):
966
+ module.gradient_checkpointing = value
967
+
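+ # Editor's note (hedged): with `supports_gradient_checkpointing = True`, activation
+ # checkpointing can typically be toggled through the standard `transformers` API
+ # during fine-tuning, e.g. `model.gradient_checkpointing_enable()`, which sets the
+ # `gradient_checkpointing` flag on the encoder and decoder defined below.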
968
+
969
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Encoder with M2M100->IndicTrans
970
+ class IndicTransEncoder(IndicTransPreTrainedModel):
971
+ """
972
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
973
+ [`IndicTransEncoderLayer`].
974
+
975
+ Args:
976
+ config: IndicTransConfig
977
+ embed_tokens (nn.Embedding): output embedding
978
+ """
979
+
980
+ def __init__(
981
+ self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
982
+ ):
983
+ super().__init__(config)
984
+
985
+ self.dropout = config.dropout
986
+ self.layerdrop = config.encoder_layerdrop
987
+
988
+ embed_dim = config.encoder_embed_dim
989
+ self.padding_idx = config.pad_token_id
990
+ self.max_source_positions = config.max_source_positions
991
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
992
+
993
+ self.embed_tokens = nn.Embedding(
994
+ config.encoder_vocab_size, embed_dim, self.padding_idx
995
+ )
996
+
997
+ if embed_tokens is not None:
998
+ self.embed_tokens.weight = embed_tokens.weight
999
+
1000
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
1001
+ config.max_source_positions,
1002
+ embed_dim,
1003
+ self.padding_idx,
1004
+ )
1005
+ self.layers = nn.ModuleList(
1006
+ [IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]
1007
+ )
1008
+ self.layer_norm = (
1009
+ nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
1010
+ )
1011
+ self.layernorm_embedding = (
1012
+ nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
1013
+ )
1014
+
1015
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1016
+ self._use_sdpa = config._attn_implementation == "sdpa"
1017
+
1018
+ self.gradient_checkpointing = False
1019
+ # Initialize weights and apply final processing
1020
+ self.post_init()
1021
+
1022
+ def forward(
1023
+ self,
1024
+ input_ids: Optional[torch.Tensor] = None,
1025
+ attention_mask: Optional[torch.Tensor] = None,
1026
+ head_mask: Optional[torch.Tensor] = None,
1027
+ inputs_embeds: Optional[torch.Tensor] = None,
1028
+ output_attentions: Optional[bool] = None,
1029
+ output_hidden_states: Optional[bool] = None,
1030
+ return_dict: Optional[bool] = None,
1031
+ ):
1032
+ r"""
1033
+ Args:
1034
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1035
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1036
+ provide it.
1037
+
1038
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1039
+ [`PreTrainedTokenizer.__call__`] for details.
1040
+
1041
+ [What are input IDs?](../glossary#input-ids)
1042
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1043
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1044
+
1045
+ - 1 for tokens that are **not masked**,
1046
+ - 0 for tokens that are **masked**.
1047
+
1048
+ [What are attention masks?](../glossary#attention-mask)
1049
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
1050
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1051
+
1052
+ - 1 indicates the head is **not masked**,
1053
+ - 0 indicates the head is **masked**.
1054
+
1055
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1056
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1057
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1058
+ than the model's internal embedding lookup matrix.
1059
+ output_attentions (`bool`, *optional*):
1060
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1061
+ returned tensors for more detail.
1062
+ output_hidden_states (`bool`, *optional*):
1063
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1064
+ for more detail.
1065
+ return_dict (`bool`, *optional*):
1066
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1067
+ """
1068
+ output_attentions = (
1069
+ output_attentions
1070
+ if output_attentions is not None
1071
+ else self.config.output_attentions
1072
+ )
1073
+ output_hidden_states = (
1074
+ output_hidden_states
1075
+ if output_hidden_states is not None
1076
+ else self.config.output_hidden_states
1077
+ )
1078
+ return_dict = (
1079
+ return_dict if return_dict is not None else self.config.use_return_dict
1080
+ )
1081
+
1082
+ # retrieve input_ids and inputs_embeds
1083
+ if input_ids is not None and inputs_embeds is not None:
1084
+ raise ValueError(
1085
+ "You cannot specify both input_ids and inputs_embeds at the same time"
1086
+ )
1087
+ elif input_ids is not None:
1088
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1089
+ input_shape = input_ids.size()
1090
+ input_ids = input_ids.view(-1, input_shape[-1])
1091
+ elif inputs_embeds is not None:
1092
+ input_shape = inputs_embeds.size()[:-1]
1093
+ else:
1094
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1095
+
1096
+ if inputs_embeds is None:
1097
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1098
+
1099
+ embed_pos = self.embed_positions(input_ids, inputs_embeds)
1100
+ embed_pos = embed_pos.to(inputs_embeds.device)
1101
+
1102
+ hidden_states = inputs_embeds + embed_pos
1103
+ if self.layernorm_embedding is not None:
1104
+ hidden_states = self.layernorm_embedding(hidden_states)
1105
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
1106
+
1107
+ if attention_mask is not None:
1108
+ if self._use_flash_attention_2:
1109
+ attention_mask = attention_mask if 0 in attention_mask else None
1110
+ elif self._use_sdpa and head_mask is None and not output_attentions:
1111
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
1112
+ # the manual implementation that requires a 4D causal mask in all cases.
1113
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1114
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
1115
+ else:
1116
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1117
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
1118
+
1119
+
1120
+ encoder_states = () if output_hidden_states else None
1121
+ all_attentions = () if output_attentions else None
1122
+
1123
+ # check if head_mask has a correct number of layers specified if desired
1124
+ if head_mask is not None:
1125
+ if head_mask.size()[0] != len(self.layers):
1126
+ raise ValueError(
1127
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
1128
+ f" {head_mask.size()[0]}."
1129
+ )
1130
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
1131
+
1132
+ for idx, encoder_layer in enumerate(self.layers):
1133
+ if output_hidden_states:
1134
+ encoder_states = encoder_states + (hidden_states,)
1135
+
1136
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1137
+ dropout_probability = torch.rand([])
1138
+
1139
+ skip_the_layer = (
1140
+ True
1141
+ if self.training and (dropout_probability < self.layerdrop)
1142
+ else False
1143
+ )
1144
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
1145
+ # under deepspeed zero3 all gpus must run in sync
1146
+
1147
+ if self.gradient_checkpointing and self.training:
1148
+ # create gradient checkpointing function
1149
+ def create_custom_forward(module):
1150
+ def custom_forward(*inputs):
1151
+ return module(*inputs, output_attentions)
1152
+
1153
+ return custom_forward
1154
+
1155
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1156
+ create_custom_forward(encoder_layer),
1157
+ hidden_states,
1158
+ attention_mask,
1159
+ (head_mask[idx] if head_mask is not None else None),
1160
+ )
1161
+ else:
1162
+ layer_outputs = encoder_layer(
1163
+ hidden_states,
1164
+ attention_mask,
1165
+ layer_head_mask=(
1166
+ head_mask[idx] if head_mask is not None else None
1167
+ ),
1168
+ output_attentions=output_attentions,
1169
+ )
1170
+
1171
+ hidden_states = layer_outputs[0]
1172
+
1173
+ if skip_the_layer:
1174
+ layer_outputs = (None, None)
1175
+
1176
+ if output_attentions:
1177
+ all_attentions = all_attentions + (layer_outputs[1],)
1178
+
1179
+ if self.layer_norm is not None:
1180
+ hidden_states = self.layer_norm(hidden_states)
1181
+
1182
+ if output_hidden_states:
1183
+ encoder_states = encoder_states + (hidden_states,)
1184
+
1185
+ if not return_dict:
1186
+ return tuple(
1187
+ v
1188
+ for v in [hidden_states, encoder_states, all_attentions]
1189
+ if v is not None
1190
+ )
1191
+ return BaseModelOutput(
1192
+ last_hidden_state=hidden_states,
1193
+ hidden_states=encoder_states,
1194
+ attentions=all_attentions,
1195
+ )
1196
+
1197
+
1198
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Decoder with M2M100->IndicTrans
1199
+ class IndicTransDecoder(IndicTransPreTrainedModel):
1200
+ """
1201
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`IndicTransDecoderLayer`]
1202
+
1203
+ Args:
1204
+ config: IndicTransConfig
1205
+ embed_tokens (nn.Embedding): output embedding
1206
+ """
1207
+
1208
+ def __init__(
1209
+ self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
1210
+ ):
1211
+ super().__init__(config)
1212
+ self.dropout = config.dropout
1213
+ self.layerdrop = config.decoder_layerdrop
1214
+
1215
+ embed_dim = config.decoder_embed_dim  # the decoder embeddings/positions use the decoder width
1216
+ self.padding_idx = config.pad_token_id
1217
+ self.max_target_positions = config.max_target_positions
1218
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
1219
+
1220
+ self.embed_tokens = nn.Embedding(
1221
+ config.decoder_vocab_size, embed_dim, self.padding_idx
1222
+ )
1223
+
1224
+ if embed_tokens is not None:
1225
+ self.embed_tokens.weight = embed_tokens.weight
1226
+
1227
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
1228
+ config.max_target_positions,
1229
+ embed_dim,
1230
+ self.padding_idx,
1231
+ )
1232
+ self.layers = nn.ModuleList(
1233
+ [IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]
1234
+ )
1235
+ self.layer_norm = (
1236
+ nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
1237
+ )
1238
+ self.layernorm_embedding = (
1239
+ nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
1240
+ )
1241
+
1242
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1243
+ self._use_sdpa = config._attn_implementation == "sdpa"
1244
+
1245
+ self.gradient_checkpointing = False
1246
+ # Initialize weights and apply final processing
1247
+ self.post_init()
1248
+
1249
+ def forward(
1250
+ self,
1251
+ input_ids: Optional[torch.Tensor] = None,
1252
+ attention_mask: Optional[torch.Tensor] = None,
1253
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1254
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1255
+ head_mask: Optional[torch.Tensor] = None,
1256
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1257
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1258
+ inputs_embeds: Optional[torch.Tensor] = None,
1259
+ use_cache: Optional[bool] = None,
1260
+ output_attentions: Optional[bool] = None,
1261
+ output_hidden_states: Optional[bool] = None,
1262
+ return_dict: Optional[bool] = None,
1263
+ ):
1264
+ r"""
1265
+ Args:
1266
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1267
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1268
+ provide it.
1269
+
1270
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1271
+ [`PreTrainedTokenizer.__call__`] for details.
1272
+
1273
+ [What are input IDs?](../glossary#input-ids)
1274
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1275
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1276
+
1277
+ - 1 for tokens that are **not masked**,
1278
+ - 0 for tokens that are **masked**.
1279
+
1280
+ [What are attention masks?](../glossary#attention-mask)
1281
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
1282
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1283
+ of the decoder.
1284
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
1285
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
1286
+ selected in `[0, 1]`:
1287
+
1288
+ - 1 for tokens that are **not masked**,
1289
+ - 0 for tokens that are **masked**.
1290
+
1291
+ [What are attention masks?](../glossary#attention-mask)
1292
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1293
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1294
+
1295
+ - 1 indicates the head is **not masked**,
1296
+ - 0 indicates the head is **masked**.
1297
+
1298
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1299
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
1300
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
1301
+
1302
+ - 1 indicates the head is **not masked**,
1303
+ - 0 indicates the head is **masked**.
1304
+
1305
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1306
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1307
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1308
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1309
+
1310
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1311
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1312
+
1313
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1314
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1315
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
1316
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
1317
+ `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
1318
+ control over how to convert `input_ids` indices into associated vectors than the model's internal
1319
+ embedding lookup matrix.
1320
+ output_attentions (`bool`, *optional*):
1321
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1322
+ returned tensors for more detail.
1323
+ output_hidden_states (`bool`, *optional*):
1324
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1325
+ for more detail.
1326
+ return_dict (`bool`, *optional*):
1327
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1328
+ """
1329
+ output_attentions = (
1330
+ output_attentions
1331
+ if output_attentions is not None
1332
+ else self.config.output_attentions
1333
+ )
1334
+ output_hidden_states = (
1335
+ output_hidden_states
1336
+ if output_hidden_states is not None
1337
+ else self.config.output_hidden_states
1338
+ )
1339
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1340
+ return_dict = (
1341
+ return_dict if return_dict is not None else self.config.use_return_dict
1342
+ )
1343
+
1344
+ # retrieve input_ids and inputs_embeds
1345
+ if input_ids is not None and inputs_embeds is not None:
1346
+ raise ValueError(
1347
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
1348
+ )
1349
+ elif input_ids is not None:
1350
+ input_shape = input_ids.size()
1351
+ input_ids = input_ids.view(-1, input_shape[-1])
1352
+ elif inputs_embeds is not None:
1353
+ input_shape = inputs_embeds.size()[:-1]
1354
+ else:
1355
+ raise ValueError(
1356
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1357
+ )
1358
+
1359
+ # past_key_values_length
1360
+ past_key_values_length = (
1361
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
1362
+ )
1363
+
1364
+ if inputs_embeds is None:
1365
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1366
+
1367
+
1368
+ if self._use_flash_attention_2:
1369
+ # 2d mask is passed through the layers
1370
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1371
+ elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
1372
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1373
+ # the manual implementation that requires a 4D causal mask in all cases.
1374
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1375
+ attention_mask,
1376
+ input_shape,
1377
+ inputs_embeds,
1378
+ past_key_values_length,
1379
+ )
1380
+ else:
1381
+ # 4d mask is passed through the layers
1382
+ attention_mask = _prepare_4d_causal_attention_mask(
1383
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
1384
+ )
1385
+
1386
+ # expand encoder attention mask
1387
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1388
+ if self._use_flash_attention_2:
1389
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
1390
+ elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
1391
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1392
+ # the manual implementation that requires a 4D causal mask in all cases.
1393
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1394
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1395
+ encoder_attention_mask,
1396
+ inputs_embeds.dtype,
1397
+ tgt_len=input_shape[-1],
1398
+ )
1399
+ else:
1400
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1401
+ encoder_attention_mask = _prepare_4d_attention_mask(
1402
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1403
+ )
1404
+
1405
+ # embed positions
1406
+ positions = self.embed_positions(
1407
+ input_ids, inputs_embeds, past_key_values_length
1408
+ )
1409
+ positions = positions.to(inputs_embeds.device)
1410
+
1411
+ hidden_states = inputs_embeds + positions
1412
+ if self.layernorm_embedding is not None:
1413
+ hidden_states = self.layernorm_embedding(hidden_states)
1414
+
1415
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
1416
+
1417
+ if self.gradient_checkpointing and self.training:
1418
+ if use_cache:
1419
+ logger.warning_once(
1420
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting"
1421
+ " `use_cache=False`..."
1422
+ )
1423
+ use_cache = False
1424
+
1425
+ # decoder layers
1426
+ all_hidden_states = () if output_hidden_states else None
1427
+ all_self_attns = () if output_attentions else None
1428
+ all_cross_attentions = () if output_attentions else None
1429
+ next_decoder_cache = () if use_cache else None
1430
+
1431
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1432
+ for attn_mask, mask_name in zip(
1433
+ [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
1434
+ ):
1435
+ if attn_mask is not None:
1436
+ if attn_mask.size()[0] != len(self.layers):
1437
+ raise ValueError(
1438
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1439
+ f" {head_mask.size()[0]}."
1440
+ )
1441
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
1442
+
1443
+ for idx, decoder_layer in enumerate(self.layers):
1444
+ if output_hidden_states:
1445
+ all_hidden_states += (hidden_states,)
1446
+
1447
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1448
+ dropout_probability = torch.rand([])
1449
+
1450
+ skip_the_layer = (
1451
+ True
1452
+ if self.training and (dropout_probability < self.layerdrop)
1453
+ else False
1454
+ )
1455
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
1456
+ # under deepspeed zero3 all gpus must run in sync
1457
+
1458
+ past_key_value = (
1459
+ past_key_values[idx] if past_key_values is not None else None
1460
+ )
1461
+
1462
+ if self.gradient_checkpointing and self.training:
1463
+
1464
+ def create_custom_forward(module):
1465
+ def custom_forward(*inputs):
1466
+ # None for past_key_value
1467
+ return module(*inputs, output_attentions, use_cache)
1468
+
1469
+ return custom_forward
1470
+
1471
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1472
+ create_custom_forward(decoder_layer),
1473
+ hidden_states,
1474
+ attention_mask,
1475
+ encoder_hidden_states,
1476
+ encoder_attention_mask,
1477
+ head_mask[idx] if head_mask is not None else None,
1478
+ cross_attn_head_mask[idx]
1479
+ if cross_attn_head_mask is not None
1480
+ else None,
1481
+ None,
1482
+ )
1483
+ else:
1484
+ layer_outputs = decoder_layer(
1485
+ hidden_states,
1486
+ attention_mask=attention_mask,
1487
+ encoder_hidden_states=encoder_hidden_states,
1488
+ encoder_attention_mask=encoder_attention_mask,
1489
+ layer_head_mask=(
1490
+ head_mask[idx] if head_mask is not None else None
1491
+ ),
1492
+ cross_attn_layer_head_mask=(
1493
+ cross_attn_head_mask[idx]
1494
+ if cross_attn_head_mask is not None
1495
+ else None
1496
+ ),
1497
+ past_key_value=past_key_value,
1498
+ output_attentions=output_attentions,
1499
+ use_cache=use_cache,
1500
+ )
1501
+
1502
+ hidden_states = layer_outputs[0]
1503
+
1504
+ if skip_the_layer:
1505
+ continue
1506
+
1507
+ if use_cache:
1508
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1509
+
1510
+ if output_attentions:
1511
+ all_self_attns += (layer_outputs[1],)
1512
+ all_cross_attentions += (layer_outputs[2],)
1513
+
1514
+ if self.layer_norm is not None:
1515
+ hidden_states = self.layer_norm(hidden_states)
1516
+
1517
+ # add hidden states from the last decoder layer
1518
+ if output_hidden_states:
1519
+ all_hidden_states += (hidden_states,)
1520
+
1521
+ next_cache = next_decoder_cache if use_cache else None
1522
+ if not return_dict:
1523
+ return tuple(
1524
+ v
1525
+ for v in [
1526
+ hidden_states,
1527
+ next_cache,
1528
+ all_hidden_states,
1529
+ all_self_attns,
1530
+ all_cross_attentions,
1531
+ ]
1532
+ if v is not None
1533
+ )
1534
+ return BaseModelOutputWithPastAndCrossAttentions(
1535
+ last_hidden_state=hidden_states,
1536
+ past_key_values=next_cache,
1537
+ hidden_states=all_hidden_states,
1538
+ attentions=all_self_attns,
1539
+ cross_attentions=all_cross_attentions,
1540
+ )
1541
+
1542
+
1543
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model with M2M100->IndicTrans
1544
+ class IndicTransModel(IndicTransPreTrainedModel):
1545
+ _tied_weights_keys = None
1546
+
1547
+ def __init__(self, config: IndicTransConfig):
1548
+ super().__init__(config)
1549
+
1550
+ self.encoder = IndicTransEncoder(config)
1551
+ self.decoder = IndicTransDecoder(config)
1552
+
1553
+ # Initialize weights and apply final processing
1554
+ self.post_init()
1555
+
1556
+ def get_encoder(self):
1557
+ return self.encoder
1558
+
1559
+ def get_decoder(self):
1560
+ return self.decoder
1561
+
1562
+ def forward(
1563
+ self,
1564
+ input_ids: Optional[torch.LongTensor] = None,
1565
+ attention_mask: Optional[torch.Tensor] = None,
1566
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1567
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1568
+ head_mask: Optional[torch.Tensor] = None,
1569
+ decoder_head_mask: Optional[torch.Tensor] = None,
1570
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1571
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1572
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1573
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1574
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1575
+ use_cache: Optional[bool] = None,
1576
+ output_attentions: Optional[bool] = None,
1577
+ output_hidden_states: Optional[bool] = None,
1578
+ return_dict: Optional[bool] = None,
1579
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
1580
+ output_attentions = (
1581
+ output_attentions
1582
+ if output_attentions is not None
1583
+ else self.config.output_attentions
1584
+ )
1585
+ output_hidden_states = (
1586
+ output_hidden_states
1587
+ if output_hidden_states is not None
1588
+ else self.config.output_hidden_states
1589
+ )
1590
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1591
+ return_dict = (
1592
+ return_dict if return_dict is not None else self.config.use_return_dict
1593
+ )
1594
+
1595
+ if encoder_outputs is None:
1596
+ encoder_outputs = self.encoder(
1597
+ input_ids=input_ids,
1598
+ attention_mask=attention_mask,
1599
+ head_mask=head_mask,
1600
+ inputs_embeds=inputs_embeds,
1601
+ output_attentions=output_attentions,
1602
+ output_hidden_states=output_hidden_states,
1603
+ return_dict=return_dict,
1604
+ )
1605
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1606
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1607
+ encoder_outputs = BaseModelOutput(
1608
+ last_hidden_state=encoder_outputs[0],
1609
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1610
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1611
+ )
1612
+
1613
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1614
+ decoder_outputs = self.decoder(
1615
+ input_ids=decoder_input_ids,
1616
+ attention_mask=decoder_attention_mask,
1617
+ encoder_hidden_states=encoder_outputs[0],
1618
+ encoder_attention_mask=attention_mask,
1619
+ head_mask=decoder_head_mask,
1620
+ cross_attn_head_mask=cross_attn_head_mask,
1621
+ past_key_values=past_key_values,
1622
+ inputs_embeds=decoder_inputs_embeds,
1623
+ use_cache=use_cache,
1624
+ output_attentions=output_attentions,
1625
+ output_hidden_states=output_hidden_states,
1626
+ return_dict=return_dict,
1627
+ )
1628
+
1629
+ if not return_dict:
1630
+ return decoder_outputs + encoder_outputs
1631
+
1632
+ return Seq2SeqModelOutput(
1633
+ last_hidden_state=decoder_outputs.last_hidden_state,
1634
+ past_key_values=decoder_outputs.past_key_values,
1635
+ decoder_hidden_states=decoder_outputs.hidden_states,
1636
+ decoder_attentions=decoder_outputs.attentions,
1637
+ cross_attentions=decoder_outputs.cross_attentions,
1638
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1639
+ encoder_hidden_states=encoder_outputs.hidden_states,
1640
+ encoder_attentions=encoder_outputs.attentions,
1641
+ )
1642
+
1643
+
1644
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration with M2M100->IndicTrans
1645
+ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMixin):
1646
+ base_model_prefix = "model"
1647
+ _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
1648
+ _label_smoothing = 0.0
1649
+
1650
+ def __init__(self, config: IndicTransConfig):
1651
+ super().__init__(config)
1652
+ self.model = IndicTransModel(config)
1653
+ self.lm_head = nn.Linear(
1654
+ config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1655
+ )
1656
+
1657
+ self.post_init()
1658
+
1659
+ def tie_weights(self):
1660
+ if self.config.share_decoder_input_output_embed:
1661
+ self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.lm_head)
1662
+
1663
+ def get_encoder(self):
1664
+ return self.model.encoder
1665
+
1666
+ def get_decoder(self):
1667
+ return self.model.decoder
1668
+
1669
+ def get_input_embeddings(self):
1670
+ return self.model.encoder.embed_tokens
1671
+
1672
+ def get_output_embeddings(self):
1673
+ return self.lm_head
1674
+
1675
+ def set_output_embeddings(self, new_embeddings):
1676
+ self.lm_head = new_embeddings
1677
+
1678
+ def set_label_smoothing(self, label_smoothing):
1679
+ self._label_smoothing = label_smoothing
1680
+
1681
+ def forward(
1682
+ self,
1683
+ input_ids: Optional[torch.LongTensor] = None,
1684
+ attention_mask: Optional[torch.Tensor] = None,
1685
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1686
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1687
+ head_mask: Optional[torch.Tensor] = None,
1688
+ decoder_head_mask: Optional[torch.Tensor] = None,
1689
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1690
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1691
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1692
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1693
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1694
+ labels: Optional[torch.LongTensor] = None,
1695
+ use_cache: Optional[bool] = None,
1696
+ output_attentions: Optional[bool] = None,
1697
+ output_hidden_states: Optional[bool] = None,
1698
+ return_dict: Optional[bool] = None,
1699
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
1700
+ r"""
1701
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1702
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1703
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1704
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1705
+
1706
+ Returns:
1707
+ """
1708
+ return_dict = (
1709
+ return_dict if return_dict is not None else self.config.use_return_dict
1710
+ )
1711
+
1712
+ if labels is not None:
1713
+ if decoder_input_ids is None:
1714
+ decoder_input_ids = shift_tokens_right(
1715
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1716
+ )
1717
+
1718
+ outputs = self.model(
1719
+ input_ids,
1720
+ attention_mask=attention_mask,
1721
+ decoder_input_ids=decoder_input_ids,
1722
+ encoder_outputs=encoder_outputs,
1723
+ decoder_attention_mask=decoder_attention_mask,
1724
+ head_mask=head_mask,
1725
+ decoder_head_mask=decoder_head_mask,
1726
+ cross_attn_head_mask=cross_attn_head_mask,
1727
+ past_key_values=past_key_values,
1728
+ inputs_embeds=inputs_embeds,
1729
+ decoder_inputs_embeds=decoder_inputs_embeds,
1730
+ use_cache=use_cache,
1731
+ output_attentions=output_attentions,
1732
+ output_hidden_states=output_hidden_states,
1733
+ return_dict=return_dict,
1734
+ )
1735
+ lm_logits = self.lm_head(outputs[0])
1736
+
1737
+ masked_lm_loss = None
1738
+ if labels is not None:
1739
+ # move labels to the correct device to enable PP
1740
+ labels = labels.to(lm_logits.device)
1741
+ masked_lm_loss = F.cross_entropy(
1742
+ input=lm_logits.view(-1, self.config.decoder_vocab_size),
1743
+ target=labels.view(-1),
1744
+ ignore_index=-100,
1745
+ label_smoothing=self._label_smoothing,
1746
+ )
1747
+
1748
+ if not return_dict:
1749
+ output = (lm_logits,) + outputs[1:]
1750
+ return (
1751
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1752
+ )
1753
+
1754
+ return Seq2SeqLMOutput(
1755
+ loss=masked_lm_loss,
1756
+ logits=lm_logits,
1757
+ past_key_values=outputs.past_key_values,
1758
+ decoder_hidden_states=outputs.decoder_hidden_states,
1759
+ decoder_attentions=outputs.decoder_attentions,
1760
+ cross_attentions=outputs.cross_attentions,
1761
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1762
+ encoder_hidden_states=outputs.encoder_hidden_states,
1763
+ encoder_attentions=outputs.encoder_attentions,
1764
+ )
1765
+
1766
+ def prepare_inputs_for_generation(
1767
+ self,
1768
+ decoder_input_ids,
1769
+ past_key_values=None,
1770
+ attention_mask=None,
1771
+ head_mask=None,
1772
+ decoder_head_mask=None,
1773
+ cross_attn_head_mask=None,
1774
+ use_cache=None,
1775
+ encoder_outputs=None,
1776
+ **kwargs,
1777
+ ):
1778
+ # cut decoder_input_ids if past is used
1779
+ if past_key_values is not None:
1780
+ decoder_input_ids = decoder_input_ids[:, -1:]
1781
+
1782
+ return {
1783
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
1784
+ "encoder_outputs": encoder_outputs,
1785
+ "past_key_values": past_key_values,
1786
+ "decoder_input_ids": decoder_input_ids,
1787
+ "attention_mask": attention_mask,
1788
+ "head_mask": head_mask,
1789
+ "decoder_head_mask": decoder_head_mask,
1790
+ "cross_attn_head_mask": cross_attn_head_mask,
1791
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1792
+ }
1793
+
1794
+ @staticmethod
1795
+ def _reorder_cache(past_key_values, beam_idx):
1796
+ reordered_past = ()
1797
+ for layer_past in past_key_values:
1798
+ reordered_past += (
1799
+ tuple(
1800
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1801
+ ),
1802
+ )
1803
+ return reordered_past
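+ 
+ # Editor's note: minimal end-to-end sketch (not part of the original upload). The
+ # model id comes from tokenizer_config.json, the tags follow the
+ # "src_lang tgt_lang text" convention of tokenization_indictrans.py below, and this
+ # Indic-En variant translates from an Indic language into English.
+ #
+ #     import torch
+ #     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ #
+ #     model_id = "ai4bharat/indictrans2-indic-en-1B"
+ #     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+ #     model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True)
+ #
+ #     batch = tokenizer(["hin_Deva eng_Latn यह एक उदाहरण वाक्य है।"],
+ #                       return_tensors="pt", padding=True)
+ #     with torch.inference_mode():
+ #         generated = model.generate(**batch, num_beams=5, max_length=256)
+ #     print(tokenizer.batch_decode(generated, skip_special_tokens=True))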
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c95ae59f5148766cef4801353d7d3166f623078ea078ad49eba84c202381e5d3
3
+ size 4092276281
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "<pad>",
5
+ "unk_token": "<unk>"
6
+ }
tokenization_indictrans.py ADDED
@@ -0,0 +1,251 @@
1
+ import os
2
+ import json
3
+
4
+ from transformers.utils import logging
5
+ from typing import Dict, List, Optional, Union, Tuple
6
+
7
+ from sentencepiece import SentencePieceProcessor
8
+ from transformers.tokenization_utils import PreTrainedTokenizer
9
+
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+ # Convert LANGUAGE_TAGS to a frozen set for faster lookups
14
+ LANGUAGE_TAGS = frozenset(
15
+ {
16
+ "asm_Beng",
17
+ "awa_Deva",
18
+ "ben_Beng",
19
+ "bho_Deva",
20
+ "brx_Deva",
21
+ "doi_Deva",
22
+ "eng_Latn",
23
+ "gom_Deva",
24
+ "gon_Deva",
25
+ "guj_Gujr",
26
+ "hin_Deva",
27
+ "hne_Deva",
28
+ "kan_Knda",
29
+ "kas_Arab",
30
+ "kas_Deva",
31
+ "kha_Latn",
32
+ "lus_Latn",
33
+ "mag_Deva",
34
+ "mai_Deva",
35
+ "mal_Mlym",
36
+ "mar_Deva",
37
+ "mni_Beng",
38
+ "mni_Mtei",
39
+ "npi_Deva",
40
+ "ory_Orya",
41
+ "pan_Guru",
42
+ "san_Deva",
43
+ "sat_Olck",
44
+ "snd_Arab",
45
+ "snd_Deva",
46
+ "tam_Taml",
47
+ "tel_Telu",
48
+ "urd_Arab",
49
+ "unr_Deva",
50
+ }
51
+ )
52
+
53
+ VOCAB_FILES_NAMES = {
54
+ "src_vocab_fp": "dict.SRC.json",
55
+ "tgt_vocab_fp": "dict.TGT.json",
56
+ "src_spm_fp": "model.SRC",
57
+ "tgt_spm_fp": "model.TGT",
58
+ }
59
+
60
+
61
+ class IndicTransTokenizer(PreTrainedTokenizer):
62
+ _added_tokens_encoder: Dict[str, int] = {}
63
+ _added_tokens_decoder: Dict[str, int] = {}
64
+ vocab_files_names = VOCAB_FILES_NAMES
65
+ model_input_names = ["input_ids", "attention_mask"]
66
+
67
+ def __init__(
68
+ self,
69
+ src_vocab_fp=None,
70
+ tgt_vocab_fp=None,
71
+ src_spm_fp=None,
72
+ tgt_spm_fp=None,
73
+ unk_token="<unk>",
74
+ bos_token="<s>",
75
+ eos_token="</s>",
76
+ pad_token="<pad>",
77
+ do_lower_case=False,
78
+ **kwargs,
79
+ ):
80
+ self.src_vocab_fp = src_vocab_fp
81
+ self.tgt_vocab_fp = tgt_vocab_fp
82
+ self.src_spm_fp = src_spm_fp
83
+ self.tgt_spm_fp = tgt_spm_fp
84
+
85
+ # Store token content directly instead of accessing .content
86
+ self.unk_token = (
87
+ hasattr(unk_token, "content") and unk_token.content or unk_token
88
+ )
89
+ self.pad_token = (
90
+ hasattr(pad_token, "content") and pad_token.content or pad_token
91
+ )
92
+ self.eos_token = (
93
+ hasattr(eos_token, "content") and eos_token.content or eos_token
94
+ )
95
+ self.bos_token = (
96
+ hasattr(bos_token, "content") and bos_token.content or bos_token
97
+ )
98
+
99
+ # Load vocabularies
100
+ self.src_encoder = self._load_json(self.src_vocab_fp)
101
+ self.tgt_encoder = self._load_json(self.tgt_vocab_fp)
102
+
103
+ # Validate tokens
104
+ if self.unk_token not in self.src_encoder:
105
+ raise KeyError("<unk> token must be in vocab")
106
+ if self.pad_token not in self.src_encoder:
107
+ raise KeyError("<pad> token must be in vocab")
108
+
109
+ # Pre-compute reverse mappings
110
+ self.src_decoder = {v: k for k, v in self.src_encoder.items()}
111
+ self.tgt_decoder = {v: k for k, v in self.tgt_encoder.items()}
112
+
113
+ # Load SPM models
114
+ self.src_spm = self._load_spm(self.src_spm_fp)
115
+ self.tgt_spm = self._load_spm(self.tgt_spm_fp)
116
+
117
+ # Initialize current settings
118
+ self._switch_to_input_mode()
119
+
120
+ # Cache token IDs
121
+ self.unk_token_id = self.src_encoder[self.unk_token]
122
+ self.pad_token_id = self.src_encoder[self.pad_token]
123
+ self.eos_token_id = self.src_encoder[self.eos_token]
124
+ self.bos_token_id = self.src_encoder[self.bos_token]
125
+
126
+ super().__init__(
127
+ src_vocab_file=self.src_vocab_fp,
128
+ tgt_vocab_file=self.tgt_vocab_fp,
129
+ do_lower_case=do_lower_case,
130
+ unk_token=unk_token,
131
+ bos_token=bos_token,
132
+ eos_token=eos_token,
133
+ pad_token=pad_token,
134
+ **kwargs,
135
+ )
136
+
137
+ def add_new_language_tags(self, new_tags: List[str]) -> None:
138
+ global LANGUAGE_TAGS
139
+ LANGUAGE_TAGS = frozenset(LANGUAGE_TAGS | set(new_tags))
140
+
141
+ def _switch_to_input_mode(self) -> None:
142
+ self.spm = self.src_spm
143
+ self.padding_side = "left"
144
+ self.encoder = self.src_encoder
145
+ self.decoder = self.src_decoder
146
+ self._tokenize = self._src_tokenize
147
+
148
+ def _switch_to_target_mode(self) -> None:
149
+ self.spm = self.tgt_spm
150
+ self.padding_side = "right"
151
+ self.encoder = self.tgt_encoder
152
+ self.decoder = self.tgt_decoder
153
+ self._tokenize = self._tgt_tokenize
154
+
155
+ @staticmethod
156
+ def _load_spm(path: str) -> SentencePieceProcessor:
157
+ return SentencePieceProcessor(model_file=path)
158
+
159
+ @staticmethod
160
+ def _save_json(data: Union[Dict, List], path: str) -> None:
161
+ with open(path, "w", encoding="utf-8") as f:
162
+ json.dump(data, f, indent=2)
163
+
164
+ @staticmethod
165
+ def _load_json(path: str) -> Union[Dict, List]:
166
+ with open(path, "r", encoding="utf-8") as f:
167
+ return json.load(f)
168
+
169
+ @property
170
+ def src_vocab_size(self) -> int:
171
+ return len(self.src_encoder)
172
+
173
+ @property
174
+ def tgt_vocab_size(self) -> int:
175
+ return len(self.tgt_encoder)
176
+
177
+ def get_src_vocab(self) -> Dict[str, int]:
178
+ return dict(self.src_encoder, **self.added_tokens_encoder)
179
+
180
+ def get_tgt_vocab(self) -> Dict[str, int]:
181
+ return dict(self.tgt_encoder, **self.added_tokens_decoder)
182
+
183
+ def get_vocab(self) -> Dict[str, int]:
184
+ return self.get_src_vocab()
185
+
186
+ @property
187
+ def vocab_size(self) -> int:
188
+ return self.src_vocab_size
189
+
190
+ def _convert_token_to_id(self, token: str) -> int:
191
+ return self.encoder.get(token, self.unk_token_id)
192
+
193
+ def _convert_id_to_token(self, index: int) -> str:
194
+ return self.decoder.get(index, self.unk_token)
195
+
196
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
197
+ return "".join(tokens).replace("▁", " ").strip()
198
+
199
+ def _src_tokenize(self, text: str) -> List[str]:
200
+ src_lang, tgt_lang, text = text.split(" ", 2)
201
+ assert src_lang in LANGUAGE_TAGS, f"Invalid source language tag: {src_lang}"
202
+ assert tgt_lang in LANGUAGE_TAGS, f"Invalid target language tag: {tgt_lang}"
203
+ return [src_lang, tgt_lang] + self.spm.EncodeAsPieces(text)
204
+
205
+ def _tgt_tokenize(self, text: str) -> List[str]:
206
+ return self.spm.EncodeAsPieces(text)
207
+
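+ # Editor's note (hedged, hypothetical example): source-side inputs are expected as
+ # "<src_tag> <tgt_tag> <sentence>", e.g. "hin_Deva eng_Latn नमस्ते दुनिया", while
+ # target-side text (labels / decoding) is tokenized as plain text by `_tgt_tokenize` above.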
208
+ def _decode(
209
+ self,
210
+ token_ids: Union[int, List[int]],
211
+ skip_special_tokens: bool = False,
212
+ clean_up_tokenization_spaces: bool = None,
213
+ spaces_between_special_tokens: bool = True,
214
+ **kwargs,
215
+ ) -> str:
216
+ self._switch_to_target_mode()
217
+ decoded_token_ids = super()._decode(
218
+ token_ids=token_ids,
219
+ skip_special_tokens=skip_special_tokens,
220
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
221
+ spaces_between_special_tokens=spaces_between_special_tokens,
222
+ **kwargs,
223
+ )
224
+ self._switch_to_input_mode()
225
+ return decoded_token_ids
226
+
227
+ def build_inputs_with_special_tokens(
228
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
229
+ ) -> List[int]:
230
+ return token_ids_0 + [self.eos_token_id]
231
+
232
+ def save_vocabulary(
233
+ self, save_directory: str, filename_prefix: Optional[str] = None
234
+ ) -> Tuple[str, ...]:
235
+ if not os.path.isdir(save_directory):
236
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
237
+ return ()
238
+
239
+ src_spm_fp = os.path.join(save_directory, "model.SRC")
240
+ tgt_spm_fp = os.path.join(save_directory, "model.TGT")
241
+ src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
242
+ tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")
243
+
244
+ self._save_json(self.src_encoder, src_vocab_fp)
245
+ self._save_json(self.tgt_encoder, tgt_vocab_fp)
246
+
247
+ for fp, spm in [(src_spm_fp, self.src_spm), (tgt_spm_fp, self.tgt_spm)]:
248
+ with open(fp, "wb") as f:
249
+ f.write(spm.serialized_model_proto())
250
+
251
+ return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp
tokenizer_config.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ }
35
+ },
36
+ "bos_token": "<s>",
37
+ "clean_up_tokenization_spaces": true,
38
+ "do_lower_case": false,
39
+ "eos_token": "</s>",
40
+ "model_max_length": 256,
41
+ "pad_token": "<pad>",
42
+ "name_or_path": "ai4bharat/indictrans2-indic-en-1B",
43
+ "tokenizer_class": "IndicTransTokenizer",
44
+ "auto_map": {
45
+ "AutoTokenizer": [
46
+ "tokenization_indictrans.IndicTransTokenizer",
47
+ null
48
+ ]
49
+ },
50
+ "unk_token": "<unk>"
51
+ }