tamoghna committed
Commit c9912a0 · verified · 1 Parent(s): 6cdab95

Update modeling.py

Files changed (1)
  1. modeling.py +7 -47
modeling.py CHANGED
@@ -1,3 +1,6 @@
+"""
+Translation Transformer Model for HuggingFace Hub
+"""
 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig
@@ -142,29 +145,15 @@ class TranslationTransformerModel(PreTrainedModel):
         return_dict: Optional[bool] = None,
         **kwargs
     ) -> Union[Tuple, Seq2SeqLMOutput]:
-        """
-        Forward pass
-
-        Args:
-            input_ids: Source sequence tokens [batch_size, src_seq_len]
-            attention_mask: Source attention mask [batch_size, src_seq_len]
-            decoder_input_ids: Target sequence tokens [batch_size, tgt_seq_len]
-            decoder_attention_mask: Target attention mask [batch_size, tgt_seq_len]
-            labels: Labels for loss calculation [batch_size, tgt_seq_len]
-            output_attentions: Whether to output attentions
-            output_hidden_states: Whether to output hidden states
-            return_dict: Whether to return ModelOutput
-        """
+        """Forward pass"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         device = input_ids.device

         # If labels provided but no decoder_input_ids, shift labels to create decoder_input_ids
         if labels is not None and decoder_input_ids is None:
-            # Replace -100 with pad_token_id for embedding
             labels_shifted = labels.clone()
             labels_shifted[labels_shifted == -100] = self.config.pad_token_id

-            # Shift right: [BOS, token1, token2, ...] from [token1, token2, ..., EOS]
             decoder_input_ids = torch.cat([
                 torch.full((labels.shape[0], 1), self.config.bos_token_id, dtype=torch.long, device=device),
                 labels_shifted[:, :-1]
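Note: the right shift in this hunk is the standard seq2seq teacher-forcing preparation: the decoder input trails the labels by one position, starting from BOS. A minimal, self-contained sketch of the same transformation (the token ids are made up for illustration; pad=0, bos=1, eos=2 are assumed values, not taken from this repo):

import torch

PAD_ID, BOS_ID = 0, 1  # hypothetical special-token ids

labels = torch.tensor([[5, 6, 7, 2],       # 2 = EOS
                       [8, 9, 2, -100]])   # -100 = masked padding position

# Replace -100 (the loss ignore index) with PAD so every id is embeddable
shifted = labels.clone()
shifted[shifted == -100] = PAD_ID

# Prepend BOS and drop the last position: [t1..tn] -> [BOS, t1..tn-1]
decoder_input_ids = torch.cat(
    [torch.full((labels.size(0), 1), BOS_ID, dtype=torch.long), shifted[:, :-1]],
    dim=1,
)
# decoder_input_ids -> [[1, 5, 6, 7], [1, 8, 9, 2]]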
@@ -200,7 +189,6 @@ class TranslationTransformerModel(PreTrainedModel):
         # Calculate loss if labels provided
         loss = None
         if labels is not None:
-            # Use -100 as ignore_index (standard for HuggingFace)
             loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
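Note: ignore_index=-100 matches the convention HuggingFace data collators use for padded label positions, so those positions contribute nothing to the loss. A toy check (the vocab size of 4 is an assumed value):

import torch
import torch.nn as nn

loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
logits = torch.randn(2, 3, 4)            # [batch, seq, vocab]
labels = torch.tensor([[1, 2, -100],     # -100 masks the padded position
                       [3, 0, 1]])
loss = loss_fct(logits.view(-1, 4), labels.view(-1))  # averaged over 5 positions, not 6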
 
@@ -222,7 +210,7 @@ class TranslationTransformerModel(PreTrainedModel):
         encoder_outputs=None,
         **kwargs
     ):
-        """Prepare inputs for generation (required for HuggingFace generate)"""
+        """Prepare inputs for generation"""
         return {
             "input_ids": kwargs.get("input_ids"),
             "decoder_input_ids": decoder_input_ids,
@@ -231,7 +219,7 @@ class TranslationTransformerModel(PreTrainedModel):

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
-        """Reorder cache for beam search (placeholder)"""
+        """Reorder cache for beam search"""
         return past_key_values

     def generate(
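Note: returning past_key_values unchanged is only safe while no cache is actually kept; if one were added, each layer's tensors would need reindexing by beam_idx. A sketch of the usual implementation, assuming the conventional tuple-of-tuples cache layout:

import torch

def reorder_cache(past_key_values, beam_idx):
    # Reindex every cached key/value tensor along the batch/beam dimension
    return tuple(
        tuple(t.index_select(0, beam_idx) for t in layer_past)
        for layer_past in past_key_values
    )

# Toy cache: 2 layers, each (key, value) of shape [beams, heads, seq, head_dim]
cache = tuple((torch.randn(3, 2, 4, 8), torch.randn(3, 2, 4, 8)) for _ in range(2))
reordered = reorder_cache(cache, torch.tensor([2, 0, 1]))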
@@ -246,22 +234,7 @@ class TranslationTransformerModel(PreTrainedModel):
         top_p: float = 1.0,
         **kwargs
     ) -> torch.LongTensor:
-        """
-        Generate translations
-
-        Args:
-            input_ids: Source sequence [batch_size, src_seq_len]
-            attention_mask: Source attention mask
-            max_length: Maximum generation length
-            num_beams: Number of beams for beam search
-            temperature: Sampling temperature
-            do_sample: Whether to use sampling
-            top_k: Top-k sampling parameter
-            top_p: Nucleus sampling parameter
-
-        Returns:
-            Generated sequences [batch_size, tgt_seq_len]
-        """
+        """Generate translations"""
         device = input_ids.device
         batch_size = input_ids.size(0)
@@ -277,7 +250,6 @@ class TranslationTransformerModel(PreTrainedModel):

         # Generate tokens one by one
         for _ in range(max_length - 1):
-            # Forward pass
             outputs = self.forward(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
@@ -285,16 +257,13 @@ class TranslationTransformerModel(PreTrainedModel):
                 return_dict=True
             )

-            # Get next token logits
             next_token_logits = outputs.logits[:, -1, :] / temperature

             if do_sample:
-                # Apply top-k filtering
                 if top_k > 0:
                     indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                     next_token_logits[indices_to_remove] = float('-inf')

-                # Apply top-p (nucleus) filtering
                 if top_p < 1.0:
                     sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                     cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
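Note: the top-k mask keeps only the k largest logits per row by comparing against the k-th best value. Isolated with made-up numbers:

import torch

logits = torch.tensor([[2.0, 0.5, 1.0, 3.0, -1.0]])  # toy next-token logits
top_k = 2

kth_best = torch.topk(logits, top_k)[0][..., -1, None]  # value of the 2nd-largest logit
logits[logits < kth_best] = float('-inf')
# logits -> [[2.0, -inf, -inf, 3.0, -inf]]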
@@ -304,23 +273,15 @@ class TranslationTransformerModel(PreTrainedModel):
                     indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                     next_token_logits[indices_to_remove] = float('-inf')

-                # Sample
                 probs = torch.softmax(next_token_logits, dim=-1)
                 next_token = torch.multinomial(probs, num_samples=1)
             else:
-                # Greedy selection
                 next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

-            # Mark finished sequences (those that generated EOS)
             finished = finished | (next_token.squeeze(-1) == self.config.eos_token_id)
-
-            # Replace tokens in finished sequences with PAD
             next_token[finished] = self.config.pad_token_id
-
-            # Append to decoder input
             decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)

-            # Stop if all sequences are finished
             if finished.all():
                 break
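Note: the hunk above elides the context lines that build sorted_indices_to_remove (the shift that keeps the first token crossing the top_p threshold). A common form of the full nucleus filter, with made-up logits:

import torch

logits = torch.tensor([[1.0, 0.0, -1.0, -2.0, 3.0]])
top_p = 0.9

sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens once cumulative mass exceeds top_p, always keeping the best token
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False

indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = float('-inf')

probs = torch.softmax(logits, dim=-1)          # renormalize over the kept tokens
next_token = torch.multinomial(probs, 1)       # sample from the nucleus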
 
@@ -329,7 +290,6 @@ class TranslationTransformerModel(PreTrainedModel):

 # Register the model in the AutoModel registry
 from transformers import AutoConfig, AutoModel, AutoModelForSeq2SeqLM
-from .configuration_translation_transformer import TranslationTransformerConfig

 AutoConfig.register("translation_transformer", TranslationTransformerConfig)
 AutoModel.register(TranslationTransformerConfig, TranslationTransformerModel)
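Note: with the config and model registered, the checkpoint loads through the Auto classes. A hypothetical usage sketch (the repo id is a placeholder, not taken from this commit):

import torch
from transformers import AutoModel, AutoTokenizer

repo = "user/translation-transformer"  # placeholder repo id
model = AutoModel.from_pretrained(repo, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo)

inputs = tokenizer("Hello, world!", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=64,
        do_sample=False,  # exercises the greedy path of the custom generate()
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))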
 