tamoghna commited on
Commit
1452010
·
verified ·
1 Parent(s): 3d42845

Create modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +335 -0
modeling.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, PretrainedConfig
4
+ from transformers.modeling_outputs import Seq2SeqLMOutput
5
+ from typing import Optional, Tuple, Union
6
+ import math
7
+
8
+
9
+ class PositionalEncoding(nn.Module):
10
+ """Positional encoding for transformer"""
11
+ def __init__(self, d_model, max_length=5000):
12
+ super().__init__()
13
+ pe = torch.zeros(max_length, d_model)
14
+ position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
15
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
16
+ pe[:, 0::2] = torch.sin(position * div_term)
17
+ pe[:, 1::2] = torch.cos(position * div_term)
18
+ pe = pe.unsqueeze(0)
19
+ self.register_buffer('pe', pe)
20
+
21
+ def forward(self, x):
22
+ return x + self.pe[:, :x.size(1)]
23
+
24
+
25
+ class TranslationTransformerConfig(PretrainedConfig):
26
+ """Configuration class for TranslationTransformer"""
27
+ model_type = "translation_transformer"
28
+
29
+ def __init__(
30
+ self,
31
+ vocab_size=32000,
32
+ d_model=512,
33
+ nhead=8,
34
+ num_encoder_layers=6,
35
+ num_decoder_layers=6,
36
+ dim_feedforward=2048,
37
+ dropout=0.1,
38
+ pad_token_id=0,
39
+ bos_token_id=2,
40
+ eos_token_id=3,
41
+ max_length=512,
42
+ **kwargs
43
+ ):
44
+ super().__init__(
45
+ pad_token_id=pad_token_id,
46
+ bos_token_id=bos_token_id,
47
+ eos_token_id=eos_token_id,
48
+ **kwargs
49
+ )
50
+
51
+ self.vocab_size = vocab_size
52
+ self.d_model = d_model
53
+ self.nhead = nhead
54
+ self.num_encoder_layers = num_encoder_layers
55
+ self.num_decoder_layers = num_decoder_layers
56
+ self.dim_feedforward = dim_feedforward
57
+ self.dropout = dropout
58
+ self.max_length = max_length
59
+
60
+ # Required for HuggingFace compatibility
61
+ self.is_encoder_decoder = True
62
+ self.decoder_start_token_id = bos_token_id
63
+
64
+
65
+ class TranslationTransformerModel(PreTrainedModel):
66
+ """
67
+ Encoder-Decoder Transformer for Translation
68
+ Compatible with HuggingFace Hub
69
+ """
70
+ config_class = TranslationTransformerConfig
71
+ base_model_prefix = "translation_transformer"
72
+ supports_gradient_checkpointing = True
73
+
74
+ def __init__(self, config):
75
+ super().__init__(config)
76
+ self.config = config
77
+
78
+ # Embeddings
79
+ self.embedding = nn.Embedding(
80
+ config.vocab_size,
81
+ config.d_model,
82
+ padding_idx=config.pad_token_id
83
+ )
84
+ self.pos_encoder = PositionalEncoding(config.d_model, config.max_length)
85
+ self.pos_decoder = PositionalEncoding(config.d_model, config.max_length)
86
+
87
+ # Transformer
88
+ self.transformer = nn.Transformer(
89
+ d_model=config.d_model,
90
+ nhead=config.nhead,
91
+ num_encoder_layers=config.num_encoder_layers,
92
+ num_decoder_layers=config.num_decoder_layers,
93
+ dim_feedforward=config.dim_feedforward,
94
+ dropout=config.dropout,
95
+ batch_first=True
96
+ )
97
+
98
+ # Output layer
99
+ self.fc_out = nn.Linear(config.d_model, config.vocab_size)
100
+
101
+ # Initialize weights
102
+ self.post_init()
103
+
104
+ def _init_weights(self, module):
105
+ """Initialize weights"""
106
+ if isinstance(module, nn.Linear):
107
+ module.weight.data.normal_(mean=0.0, std=0.02)
108
+ if module.bias is not None:
109
+ module.bias.data.zero_()
110
+ elif isinstance(module, nn.Embedding):
111
+ module.weight.data.normal_(mean=0.0, std=0.02)
112
+ if module.padding_idx is not None:
113
+ module.weight.data[module.padding_idx].zero_()
114
+
115
+ def get_encoder(self):
116
+ """Return encoder for compatibility"""
117
+ return self.transformer.encoder
118
+
119
+ def get_decoder(self):
120
+ """Return decoder for compatibility"""
121
+ return self.transformer.decoder
122
+
123
+ def generate_square_subsequent_mask(self, sz, device):
124
+ """Generate causal mask for decoder"""
125
+ mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1)
126
+ mask = mask.masked_fill(mask == 1, float('-inf'))
127
+ return mask
128
+
129
+ def create_padding_mask(self, seq, pad_token_id):
130
+ """Create padding mask"""
131
+ return (seq == pad_token_id)
132
+
133
+ def forward(
134
+ self,
135
+ input_ids: Optional[torch.LongTensor] = None,
136
+ attention_mask: Optional[torch.FloatTensor] = None,
137
+ decoder_input_ids: Optional[torch.LongTensor] = None,
138
+ decoder_attention_mask: Optional[torch.FloatTensor] = None,
139
+ labels: Optional[torch.LongTensor] = None,
140
+ output_attentions: Optional[bool] = None,
141
+ output_hidden_states: Optional[bool] = None,
142
+ return_dict: Optional[bool] = None,
143
+ **kwargs
144
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
145
+ """
146
+ Forward pass
147
+
148
+ Args:
149
+ input_ids: Source sequence tokens [batch_size, src_seq_len]
150
+ attention_mask: Source attention mask [batch_size, src_seq_len]
151
+ decoder_input_ids: Target sequence tokens [batch_size, tgt_seq_len]
152
+ decoder_attention_mask: Target attention mask [batch_size, tgt_seq_len]
153
+ labels: Labels for loss calculation [batch_size, tgt_seq_len]
154
+ output_attentions: Whether to output attentions
155
+ output_hidden_states: Whether to output hidden states
156
+ return_dict: Whether to return ModelOutput
157
+ """
158
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
159
+ device = input_ids.device
160
+
161
+ # If labels provided but no decoder_input_ids, shift labels to create decoder_input_ids
162
+ if labels is not None and decoder_input_ids is None:
163
+ # Replace -100 with pad_token_id for embedding
164
+ labels_shifted = labels.clone()
165
+ labels_shifted[labels_shifted == -100] = self.config.pad_token_id
166
+
167
+ # Shift right: [BOS, token1, token2, ...] from [token1, token2, ..., EOS]
168
+ decoder_input_ids = torch.cat([
169
+ torch.full((labels.shape[0], 1), self.config.bos_token_id, dtype=torch.long, device=device),
170
+ labels_shifted[:, :-1]
171
+ ], dim=1)
172
+
173
+ # Embeddings with scaling
174
+ src_emb = self.embedding(input_ids) * math.sqrt(self.config.d_model)
175
+ src_emb = self.pos_encoder(src_emb)
176
+
177
+ tgt_emb = self.embedding(decoder_input_ids) * math.sqrt(self.config.d_model)
178
+ tgt_emb = self.pos_decoder(tgt_emb)
179
+
180
+ # Create masks
181
+ tgt_seq_len = decoder_input_ids.size(1)
182
+ tgt_mask = self.generate_square_subsequent_mask(tgt_seq_len, device)
183
+
184
+ src_key_padding_mask = self.create_padding_mask(input_ids, self.config.pad_token_id)
185
+ tgt_key_padding_mask = self.create_padding_mask(decoder_input_ids, self.config.pad_token_id)
186
+
187
+ # Transformer forward pass
188
+ output = self.transformer(
189
+ src_emb,
190
+ tgt_emb,
191
+ tgt_mask=tgt_mask,
192
+ src_key_padding_mask=src_key_padding_mask,
193
+ tgt_key_padding_mask=tgt_key_padding_mask,
194
+ memory_key_padding_mask=src_key_padding_mask
195
+ )
196
+
197
+ # Output projection
198
+ logits = self.fc_out(output)
199
+
200
+ # Calculate loss if labels provided
201
+ loss = None
202
+ if labels is not None:
203
+ # Use -100 as ignore_index (standard for HuggingFace)
204
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
205
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
206
+
207
+ if not return_dict:
208
+ output = (logits,)
209
+ return ((loss,) + output) if loss is not None else output
210
+
211
+ return Seq2SeqLMOutput(
212
+ loss=loss,
213
+ logits=logits,
214
+ )
215
+
216
+ def prepare_inputs_for_generation(
217
+ self,
218
+ decoder_input_ids,
219
+ past_key_values=None,
220
+ attention_mask=None,
221
+ use_cache=None,
222
+ encoder_outputs=None,
223
+ **kwargs
224
+ ):
225
+ """Prepare inputs for generation (required for HuggingFace generate)"""
226
+ return {
227
+ "input_ids": kwargs.get("input_ids"),
228
+ "decoder_input_ids": decoder_input_ids,
229
+ "attention_mask": attention_mask,
230
+ }
231
+
232
+ @staticmethod
233
+ def _reorder_cache(past_key_values, beam_idx):
234
+ """Reorder cache for beam search (placeholder)"""
235
+ return past_key_values
236
+
237
+ def generate(
238
+ self,
239
+ input_ids: torch.LongTensor,
240
+ attention_mask: Optional[torch.FloatTensor] = None,
241
+ max_length: int = 128,
242
+ num_beams: int = 1,
243
+ temperature: float = 1.0,
244
+ do_sample: bool = False,
245
+ top_k: int = 50,
246
+ top_p: float = 1.0,
247
+ **kwargs
248
+ ) -> torch.LongTensor:
249
+ """
250
+ Generate translations
251
+
252
+ Args:
253
+ input_ids: Source sequence [batch_size, src_seq_len]
254
+ attention_mask: Source attention mask
255
+ max_length: Maximum generation length
256
+ num_beams: Number of beams for beam search
257
+ temperature: Sampling temperature
258
+ do_sample: Whether to use sampling
259
+ top_k: Top-k sampling parameter
260
+ top_p: Nucleus sampling parameter
261
+
262
+ Returns:
263
+ Generated sequences [batch_size, tgt_seq_len]
264
+ """
265
+ device = input_ids.device
266
+ batch_size = input_ids.size(0)
267
+
268
+ # Start with BOS token
269
+ decoder_input_ids = torch.full(
270
+ (batch_size, 1),
271
+ self.config.bos_token_id,
272
+ dtype=torch.long,
273
+ device=device
274
+ )
275
+
276
+ finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
277
+
278
+ # Generate tokens one by one
279
+ for _ in range(max_length - 1):
280
+ # Forward pass
281
+ outputs = self.forward(
282
+ input_ids=input_ids,
283
+ attention_mask=attention_mask,
284
+ decoder_input_ids=decoder_input_ids,
285
+ return_dict=True
286
+ )
287
+
288
+ # Get next token logits
289
+ next_token_logits = outputs.logits[:, -1, :] / temperature
290
+
291
+ if do_sample:
292
+ # Apply top-k filtering
293
+ if top_k > 0:
294
+ indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
295
+ next_token_logits[indices_to_remove] = float('-inf')
296
+
297
+ # Apply top-p (nucleus) filtering
298
+ if top_p < 1.0:
299
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
300
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
301
+ sorted_indices_to_remove = cumulative_probs > top_p
302
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
303
+ sorted_indices_to_remove[..., 0] = 0
304
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
305
+ next_token_logits[indices_to_remove] = float('-inf')
306
+
307
+ # Sample
308
+ probs = torch.softmax(next_token_logits, dim=-1)
309
+ next_token = torch.multinomial(probs, num_samples=1)
310
+ else:
311
+ # Greedy selection
312
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
313
+
314
+ # Mark finished sequences (those that generated EOS)
315
+ finished = finished | (next_token.squeeze(-1) == self.config.eos_token_id)
316
+
317
+ # Replace tokens in finished sequences with PAD
318
+ next_token[finished] = self.config.pad_token_id
319
+
320
+ # Append to decoder input
321
+ decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
322
+
323
+ # Stop if all sequences are finished
324
+ if finished.all():
325
+ break
326
+
327
+ return decoder_input_ids
328
+
329
+
330
+ # Register the model in the AutoModel registry
331
+ from transformers import AutoConfig, AutoModel, AutoModelForSeq2SeqLM
332
+
333
+ AutoConfig.register("translation_transformer", TranslationTransformerConfig)
334
+ AutoModel.register(TranslationTransformerConfig, TranslationTransformerModel)
335
+ AutoModelForSeq2SeqLM.register(TranslationTransformerConfig, TranslationTransformerModel)