import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from .configuration_thai_llm import ThaiLLMConfig
|
|
class ThaiLLMBlock(nn.Module):
    """Post-LayerNorm transformer block: self-attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
            nn.Dropout(config.hidden_dropout_prob),
        )
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        # Self-attention sub-layer. attn_mask carries the causal mask so position i
        # cannot attend to positions > i; key_padding_mask hides padding tokens.
        # Without attn_mask, attention would be bidirectional, which is wrong
        # for a causal LM.
        residual = x
        x, _ = self.attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.norm1(x + residual)

        # Feed-forward sub-layer.
        residual = x
        x = self.mlp(x)
        x = self.norm2(x + residual)

        return x
|
|
class ThaiLLMForCausalLM(PreTrainedModel):
    config_class = ThaiLLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_positions = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.layers = nn.ModuleList([ThaiLLMBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # post_init() applies _init_weights to every submodule. Keeping the
        # standard _init_weights(self, module) signature also keeps the model
        # compatible with from_pretrained(), which calls it per-module when
        # initializing weights that are missing from a checkpoint.
        self.post_init()

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
|
    def forward(self, input_ids, attention_mask=None, labels=None):
        bsz, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0).expand(bsz, -1)

        hidden_states = self.embed_tokens(input_ids) + self.embed_positions(position_ids)

        # Boolean causal mask: True above the diagonal marks key positions a
        # query may not attend to (nn.MultiheadAttention reads True as "mask out").
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        )

        # For key_padding_mask, True likewise means "ignore this key".
        key_padding_mask = attention_mask == 0 if attention_mask is not None else None

        for layer in self.layers:
            hidden_states = layer(hidden_states, attn_mask=causal_mask, key_padding_mask=key_padding_mask)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so tokens < n predict token n. Padded positions should be
            # set to -100 in labels, CrossEntropyLoss's default ignore_index.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutput(loss=loss, logits=logits)
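

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the model). It assumes that
# ThaiLLMConfig follows the usual PretrainedConfig pattern and accepts the
# hyperparameters below as keyword arguments; adjust the names if the real
# configuration_thai_llm.py differs. Because of the relative import at the
# top of this file, run it as a module, e.g.
# `python -m thai_llm.modeling_thai_llm`, rather than as a plain script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = ThaiLLMConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
        max_position_embeddings=128,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        layer_norm_eps=1e-5,
        initializer_range=0.02,
    )
    model = ThaiLLMForCausalLM(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[1, 12:] = 0  # pretend the second sequence is padded

    # Mask padded positions out of the loss via the -100 ignore_index convention.
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100

    with torch.no_grad():
        out = model(input_ids, attention_mask=attention_mask, labels=labels)

    print("logits:", tuple(out.logits.shape))  # expected: (2, 16, 1000)
    print("loss:", out.loss.item())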