tamoghna committed
Commit a0c9612 · verified · 1 Parent(s): ed109c7

Update modeling.py

Files changed (1)
  1. modeling.py +8 -124
modeling.py CHANGED
@@ -1,126 +1,3 @@
- # import math
- # import torch
- # import torch.nn as nn
- # from transformers import PretrainedConfig, PreTrainedModel
- # import warnings
-
- # # Use the Hugging Face base configuration class for compatibility
- # class TransformerConfig(PretrainedConfig):
- # # Model type must match the one found in your config.json (small_transformer)
- # model_type = "small_transformer"
-
- # def __init__(self,
- # vocab_size=80000,
- # d_model=256,
- # nhead=8,
- # num_encoder_layers=3,
- # num_decoder_layers=3,
- # dim_feedforward=512,
- # dropout=0.1,
- # pad_token_id=0,
- # bos_token_id=1, # Assuming <s> is 1
- # eos_token_id=2, # Assuming </s> is 2
- # max_position_embeddings=512,
- # **kwargs):
- # super().__init__(pad_token_id=pad_token_id,
- # bos_token_id=bos_token_id,
- # eos_token_id=eos_token_id,
- # **kwargs)
- # self.vocab_size = vocab_size
- # self.d_model = d_model
- # self.nhead = nhead
- # self.num_encoder_layers = num_encoder_layers
- # self.num_decoder_layers = num_decoder_layers
- # self.dim_feedforward = dim_feedforward
- # self.dropout = dropout
- # self.max_position_embeddings = max_position_embeddings
-
- # # Add a placeholder for decoder_start_token_id, which is needed for generation
- # if not hasattr(self, "decoder_start_token_id"):
- # # For a multilingual model, this is often the target language token ID
- # # You will set this explicitly during generation in your Gradio app (as shown previously)
- # self.decoder_start_token_id = None
-
-
- # # Use the Hugging Face base model class for compatibility
- # class SmallTransformer(PreTrainedModel):
- # # Link the model to its configuration class
- # config_class = TransformerConfig
-
- # def __init__(self, config):
- # super().__init__(config)
- # self.config = config
-
- # # --- Model Components (from your training code) ---
- # self.embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
- # self.pos_encoder = nn.Embedding(config.max_position_embeddings, config.d_model)
- # self.pos_decoder = nn.Embedding(config.max_position_embeddings, config.d_model)
- # self.embed_scale = math.sqrt(config.d_model)
-
- # enc_layer = nn.TransformerEncoderLayer(d_model=config.d_model, nhead=config.nhead,
- # dim_feedforward=config.dim_feedforward,
- # dropout=config.dropout, batch_first=True)
- # dec_layer = nn.TransformerDecoderLayer(d_model=config.d_model, nhead=config.nhead,
- # dim_feedforward=config.dim_feedforward,
- # dropout=config.dropout, batch_first=True)
-
- # self.encoder = nn.TransformerEncoder(enc_layer, num_layers=config.num_encoder_layers)
- # self.decoder = nn.TransformerDecoder(dec_layer, num_layers=config.num_decoder_layers)
- # self.output_layer = nn.Linear(config.d_model, config.vocab_size)
-
- # # Initialize weights
- # self.post_init()
-
- # # Implement the forward pass exactly as you had it
- # def forward(self, input_ids=None, decoder_input_ids=None, **kwargs):
- # src = input_ids
- # tgt = decoder_input_ids
-
- # assert src.dim() == 2 and tgt.dim() == 2
-
- # # Your custom max_token check (omitting for brevity but keep if you need it)
-
- # src_mask = (src == self.config.pad_token_id)
- # tgt_mask_pad = (tgt == self.config.pad_token_id)
-
- # T = tgt.size(1)
- # # Create Causal Mask
- # causal_mask = torch.triu(torch.ones((T, T), device=tgt.device), diagonal=1).bool()
-
- # # Positional Encoding
- # src_pos = torch.arange(0, src.size(1), device=src.device).unsqueeze(0).expand(src.size(0), -1).clamp(max=self.config.max_position_embeddings - 1)
- # tgt_pos = torch.arange(0, tgt.size(1), device=tgt.device).unsqueeze(0).expand(tgt.size(0), -1).clamp(max=self.config.max_position_embeddings - 1)
-
- # src_emb = self.embedding(src) * self.embed_scale + self.pos_encoder(src_pos)
- # tgt_emb = self.embedding(tgt) * self.embed_scale + self.pos_decoder(tgt_pos)
-
- # memory = self.encoder(src_emb, src_key_padding_mask=src_mask)
- # output = self.decoder(tgt_emb, memory, tgt_mask=causal_mask,
- # tgt_key_padding_mask=tgt_mask_pad,
- # memory_key_padding_mask=src_mask)
-
- # # The output must be the logits before the final softmax/loss
- # logits = self.output_layer(output)
-
- # # Return a dictionary/tuple of outputs compatible with PreTrainedModel
- # return (logits,) # Return logits in a tuple for compatibility
-
- # # Implement the mandatory generate method (minimal implementation)
- # def prepare_inputs_for_generation(self, decoder_input_ids, **kwargs):
- # # This method is required by the .generate() function
- # return {"input_ids": kwargs.get("input_ids"), "decoder_input_ids": decoder_input_ids}
-
- # def _prepare_decoder_input_ids_for_generation(self, decoder_input_ids, **kwargs):
- # # A simple method to ensure the decoder input starts with the language token
- # # This is typically handled by generation_config, but we include a check here
- # if decoder_input_ids is None and self.config.decoder_start_token_id is not None:
- # warnings.warn("Using decoder_start_token_id from config. This should be manually set during generation.")
- # decoder_input_ids = torch.ones((kwargs["input_ids"].shape[0], 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id
- # return decoder_input_ids
-
-
- # # No registration needed - auto_map in config.json handles this
-
  """PyTorch Small Transformer model for English to Hindi/Bengali translation."""
 
  import math
@@ -299,7 +176,8 @@ class SmallTransformer(SmallTransformerPreTrainedModel):
     def generate(
         self,
         input_ids: torch.LongTensor,
-        max_length: int = 64,
+        max_length: int = None,
+        max_new_tokens: int = None,
         lang_token_id: int = None,
         eos_token_id: int = None,
         **kwargs
@@ -308,6 +186,12 @@ class SmallTransformer(SmallTransformerPreTrainedModel):
         if eos_token_id is None:
             eos_token_id = self.config.eos_token_id
 
+        # Handle max_new_tokens parameter
+        if max_new_tokens is not None:
+            max_length = max_new_tokens
+        elif max_length is None:
+            max_length = 64
+
         batch_size = input_ids.size(0)
         device = input_ids.device
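
For reference, a minimal usage sketch of the updated generate() signature. Only the max_length, max_new_tokens, lang_token_id, and eos_token_id parameters come from the diff above; the repository id, tokenizer, and the "<2hi>" language token below are illustrative assumptions, not part of this commit.

from transformers import AutoModel, AutoTokenizer

# Hypothetical checkpoint id; the custom class is resolved via auto_map
# in config.json, hence trust_remote_code=True.
repo_id = "tamoghna/small-transformer-translation"  # assumed name
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

input_ids = tokenizer("How are you?", return_tensors="pt").input_ids

# After this change, max_new_tokens takes precedence over max_length,
# and the decoding budget falls back to 64 tokens when neither is given.
output_ids = model.generate(
    input_ids,
    max_new_tokens=48,
    lang_token_id=tokenizer.convert_tokens_to_ids("<2hi>"),  # assumed Hindi target token
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))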