# llm-o1 / modeling_thai_llm.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from .configuration_thai_llm import ThaiLLMConfig
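
# ThaiLLMConfig attributes this file relies on (expected to be defined in
# configuration_thai_llm.py): vocab_size, hidden_size, num_hidden_layers,
# num_attention_heads, intermediate_size, max_position_embeddings,
# hidden_dropout_prob, attention_probs_dropout_prob, layer_norm_eps, initializer_range.
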
class ThaiLLMBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.attn = nn.MultiheadAttention(
config.hidden_size,
config.num_attention_heads,
dropout=config.attention_probs_dropout_prob,
batch_first=True
)
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = nn.Sequential(
nn.Linear(config.hidden_size, config.intermediate_size),
nn.GELU(),
nn.Linear(config.intermediate_size, config.hidden_size),
nn.Dropout(config.hidden_dropout_prob)
)
self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    def forward(self, x, key_padding_mask=None):
        residual = x
        # Causal mask: True entries are blocked, so each position attends only to
        # itself and earlier positions, as required for autoregressive language modeling.
        seq_len = x.size(1)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
        x, _ = self.attn(x, x, x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
        x = self.norm1(x + residual)
        residual = x
        x = self.mlp(x)
        x = self.norm2(x + residual)
        return x
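
# Illustrative shape check (a helper added here for clarity, not part of the original
# model code): a ThaiLLMBlock maps (batch, seq_len, hidden_size) to the same shape.
def _block_shape_check(config, batch_size=2, seq_len=8):
    block = ThaiLLMBlock(config)
    x = torch.randn(batch_size, seq_len, config.hidden_size)
    out = block(x)
    assert out.shape == (batch_size, seq_len, config.hidden_size)
    return out
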
class ThaiLLMForCausalLM(PreTrainedModel):
config_class = ThaiLLMConfig
def __init__(self, config):
super().__init__(config)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.embed_positions = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.layers = nn.ModuleList([ThaiLLMBlock(config) for _ in range(config.num_hidden_layers)])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def _init_weights(self, module):
        # Called per sub-module by post_init(): normal init for weights, zeros for biases.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
def forward(self, input_ids, attention_mask=None, labels=None):
bsz, seq_len = input_ids.shape
position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0).expand(bsz, -1)
hidden_states = self.embed_tokens(input_ids) + self.embed_positions(position_ids)
        # nn.MultiheadAttention expects key_padding_mask to be True at padded positions.
        key_padding_mask = attention_mask == 0 if attention_mask is not None else None
for layer in self.layers:
hidden_states = layer(hidden_states, key_padding_mask=key_padding_mask)
hidden_states = self.norm(hidden_states)
logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Shift so that the token at position i is trained to predict position i + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return CausalLMOutput(loss=loss, logits=logits)
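
# Minimal usage sketch: build a small config, run a forward pass with labels, and
# inspect the outputs. The keyword arguments below mirror the attributes referenced
# above; their exact names and defaults live in configuration_thai_llm.py, so treat
# this as an assumption-laden example rather than the canonical entry point. Because
# of the relative import at the top, run it as a module inside its package rather
# than as a standalone script.
if __name__ == "__main__":
    config = ThaiLLMConfig(
        vocab_size=1000,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=512,
        max_position_embeddings=64,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        layer_norm_eps=1e-12,
        initializer_range=0.02,
    )
    model = ThaiLLMForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
    print(outputs.logits.shape)  # torch.Size([2, 16, 1000])
    print(outputs.loss)          # scalar next-token cross-entropy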