tamoghna committed
Commit 133eda2 · verified · 1 Parent(s): 83b0712

Create modeling.py

Files changed (1)
  1. modeling.py +128 -0
modeling.py ADDED
@@ -0,0 +1,128 @@
+ import math
+ import torch
+ import torch.nn as nn
+ from transformers import AutoConfig, AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
+ import warnings
+
+ # Use the Hugging Face base configuration class for compatibility
+ class TransformerConfig(PretrainedConfig):
+     # Model type must match the one found in your config.json (small_transformer)
+     model_type = "small_transformer"
+
+     def __init__(self,
+                  vocab_size=80000,
+                  d_model=256,
+                  nhead=8,
+                  num_encoder_layers=3,
+                  num_decoder_layers=3,
+                  dim_feedforward=512,
+                  dropout=0.1,
+                  pad_token_id=0,
+                  bos_token_id=1,   # Assuming <s> is 1
+                  eos_token_id=2,   # Assuming </s> is 2
+                  max_position_embeddings=512,
+                  **kwargs):
+         super().__init__(pad_token_id=pad_token_id,
+                          bos_token_id=bos_token_id,
+                          eos_token_id=eos_token_id,
+                          **kwargs)
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.nhead = nhead
+         self.num_encoder_layers = num_encoder_layers
+         self.num_decoder_layers = num_decoder_layers
+         self.dim_feedforward = dim_feedforward
+         self.dropout = dropout
+         self.max_position_embeddings = max_position_embeddings
+
+         # Add a placeholder for decoder_start_token_id, which is needed for generation
+         if not hasattr(self, "decoder_start_token_id"):
+             # For a multilingual model, this is often the target language token ID.
+             # You will set this explicitly during generation in your Gradio app (as shown previously).
+             self.decoder_start_token_id = None
+
+
+ # Use the Hugging Face base model class for compatibility
+ class SmallTransformer(PreTrainedModel):
+     # Link the model to its configuration class
+     config_class = TransformerConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         # --- Model Components (from your training code) ---
+         self.embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
+         self.pos_encoder = nn.Embedding(config.max_position_embeddings, config.d_model)
+         self.pos_decoder = nn.Embedding(config.max_position_embeddings, config.d_model)
+         self.embed_scale = math.sqrt(config.d_model)
+
+         enc_layer = nn.TransformerEncoderLayer(d_model=config.d_model, nhead=config.nhead,
+                                                dim_feedforward=config.dim_feedforward,
+                                                dropout=config.dropout, batch_first=True)
+         dec_layer = nn.TransformerDecoderLayer(d_model=config.d_model, nhead=config.nhead,
+                                                dim_feedforward=config.dim_feedforward,
+                                                dropout=config.dropout, batch_first=True)
+
+         self.encoder = nn.TransformerEncoder(enc_layer, num_layers=config.num_encoder_layers)
+         self.decoder = nn.TransformerDecoder(dec_layer, num_layers=config.num_decoder_layers)
+         self.output_layer = nn.Linear(config.d_model, config.vocab_size)
+
+         # Initialize weights
+         self.post_init()
+
+     # Implement the forward pass exactly as you had it
+     def forward(self, input_ids=None, decoder_input_ids=None, **kwargs):
+         src = input_ids
+         tgt = decoder_input_ids
+
+         assert src.dim() == 2 and tgt.dim() == 2
+
+         # Your custom max_token check (omitted here for brevity; keep it if you need it)
+
+         src_mask = (src == self.config.pad_token_id)
+         tgt_mask_pad = (tgt == self.config.pad_token_id)
+
+         T = tgt.size(1)
+         # Create the causal mask
+         causal_mask = torch.triu(torch.ones((T, T), device=tgt.device), diagonal=1).bool()
+
+         # Positional encoding
+         src_pos = torch.arange(0, src.size(1), device=src.device).unsqueeze(0).expand(src.size(0), -1).clamp(max=self.config.max_position_embeddings - 1)
+         tgt_pos = torch.arange(0, tgt.size(1), device=tgt.device).unsqueeze(0).expand(tgt.size(0), -1).clamp(max=self.config.max_position_embeddings - 1)
+
+         src_emb = self.embedding(src) * self.embed_scale + self.pos_encoder(src_pos)
+         tgt_emb = self.embedding(tgt) * self.embed_scale + self.pos_decoder(tgt_pos)
+
+         memory = self.encoder(src_emb, src_key_padding_mask=src_mask)
+         output = self.decoder(tgt_emb, memory, tgt_mask=causal_mask,
+                               tgt_key_padding_mask=tgt_mask_pad,
+                               memory_key_padding_mask=src_mask)
+
+         # The output must be the logits before the final softmax/loss
+         logits = self.output_layer(output)
+
+         # Return a dictionary/tuple of outputs compatible with PreTrainedModel
+         return (logits,)  # Return logits in a tuple for compatibility
+
+     # Implement the helper methods used by .generate() (minimal implementation)
+     def prepare_inputs_for_generation(self, decoder_input_ids, **kwargs):
+         # This method is required by the .generate() function
+         return {"input_ids": kwargs.get("input_ids"), "decoder_input_ids": decoder_input_ids}
+
+     def _prepare_decoder_input_ids_for_generation(self, decoder_input_ids, **kwargs):
+         # A simple method to ensure the decoder input starts with the language token.
+         # This is typically handled by generation_config, but we include a check here.
+         if decoder_input_ids is None and self.config.decoder_start_token_id is not None:
+             warnings.warn("Using decoder_start_token_id from config. This should be manually set during generation.")
+             decoder_input_ids = torch.ones((kwargs["input_ids"].shape[0], 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id
+         return decoder_input_ids
+
+
+ # Register the custom model type so the Auto classes can find it.
+ # This ensures that when AutoModelForSeq2SeqLM sees 'model_type': 'small_transformer'
+ # in your config.json, it knows to use the SmallTransformer class from this file.
+ AutoConfig.register("small_transformer", TransformerConfig)
+ AutoModelForSeq2SeqLM.register(TransformerConfig, SmallTransformer)
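
A quick loading sketch (not part of the committed file): once this module has been imported, the register() calls above let the Auto classes resolve the custom model type. The checkpoint directory name below is purely illustrative and is assumed to contain the config.json (with "model_type": "small_transformer") plus the trained weights.

import modeling  # importing the module runs the AutoConfig / AutoModelForSeq2SeqLM registration above
from transformers import AutoConfig, AutoModelForSeq2SeqLM

ckpt_dir = "path/to/small_transformer_checkpoint"        # hypothetical local directory
config = AutoConfig.from_pretrained(ckpt_dir)            # resolves to TransformerConfig
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_dir)  # resolves to SmallTransformer
model.eval()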
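Because forward() returns a bare (logits,) tuple rather than a Seq2SeqLMOutput, the built-in .generate() loop may need extra plumbing for this encoder-decoder, so a hand-rolled greedy loop over forward() is a safe fallback. This is a minimal sketch: greedy_translate is a hypothetical helper, and the decoder_start_token_id (e.g. the target-language token mentioned in the config comments) and eos_token_id values are assumptions that match the config defaults.

import torch

@torch.no_grad()
def greedy_translate(model, input_ids, decoder_start_token_id, eos_token_id=2, max_len=128):
    # Start every sequence in the batch with the decoder start (target-language) token.
    decoder_input_ids = torch.full((input_ids.size(0), 1), decoder_start_token_id,
                                   dtype=torch.long, device=input_ids.device)
    for _ in range(max_len):
        logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # most likely next token
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
        if (next_token == eos_token_id).all():  # stop once every sequence has emitted </s>
            break
    return decoder_input_ids

Sequences that finish early keep appending tokens until the whole batch stops; trimming everything after the first </s> (or decoding one sentence at a time) keeps the output clean.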