# tinygpt-ptbr-v1 / modeling_tinygpt.py
# (Hugging Face Hub page residue preserved as comments: uploaded by Madras1,
#  "Update modeling_tinygpt.py", commit 2f74d7e, verified)
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from .configuration_tinygpt import TinyGPTConfig
class TinyGPTRMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
variance = x.float().pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return x * self.weight
class TinyGPTAttention(nn.Module):
    """Multi-head causal self-attention with learned Q/K/V and output projections."""

    def __init__(self, config: TinyGPTConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        if self.num_heads * self.head_dim != self.hidden_size:
            raise ValueError("hidden_size must be divisible by num_attention_heads")
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.dropout = nn.Dropout(config.dropout)

    def _shape(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, hidden) -> (batch, heads, seq, head_dim)."""
        bsz, seq_len, _ = x.size()
        return x.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run causal self-attention over ``hidden_states`` (batch, seq, hidden)."""
        bsz, seq_len, _ = hidden_states.size()
        query = self._shape(self.q_proj(hidden_states))
        key = self._shape(self.k_proj(hidden_states))
        value = self._shape(self.v_proj(hidden_states))

        # Scaled dot-product scores: (batch, heads, seq, seq).
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

        neg_inf = torch.finfo(scores.dtype).min
        # Strictly-upper-triangular positions are future tokens: block them.
        future = torch.ones(
            seq_len, seq_len, device=hidden_states.device, dtype=torch.bool
        ).triu(1)
        scores = scores.masked_fill(future, neg_inf)
        if attention_mask is not None:
            # attention_mask appears to be (batch, seq) with nonzero = attend,
            # zero = padding — presumed from the masked_fill below; confirm at caller.
            keep = attention_mask[:, None, None, :].to(torch.bool)
            scores = scores.masked_fill(~keep, neg_inf)

        # Softmax in fp32 for stability, then return to the working dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(hidden_states.dtype)
        probs = self.dropout(probs)

        context = torch.matmul(probs, value)
        context = context.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)
        return self.out_proj(context)
class TinyGPTMLP(nn.Module):
    """Position-wise feed-forward net: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config: TinyGPTConfig):
        super().__init__()
        self.fc_in = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
        self.fc_out = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the two-layer MLP; output shape matches the input."""
        activated = F.gelu(self.fc_in(hidden_states))
        return self.dropout(self.fc_out(activated))
class TinyGPTBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm + attention, then RMSNorm + MLP,
    each wrapped in a residual connection."""

    def __init__(self, config: TinyGPTConfig):
        super().__init__()
        self.attn_norm = TinyGPTRMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attn = TinyGPTAttention(config)
        self.mlp_norm = TinyGPTRMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = TinyGPTMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and MLP sublayers with residual additions."""
        attn_out = self.attn(self.attn_norm(hidden_states), attention_mask)
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.mlp_norm(hidden_states))
        return hidden_states + mlp_out
class TinyGPTPreTrainedModel(PreTrainedModel):
    """Common base class wiring TinyGPT modules into the transformers ecosystem."""

    config_class = TinyGPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["TinyGPTBlock"]

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights with N(0, 0.02); zero Linear biases.

        NOTE(review): raw ``nn.Parameter`` attributes (e.g. the position table
        in TinyGPTModel) are not touched here — they keep their constructor init.
        """
        std = 0.02
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
class TinyGPTModel(TinyGPTPreTrainedModel):
    """Decoder-only transformer trunk: token embeddings plus learned absolute
    position embeddings, a stack of TinyGPTBlocks, and a final RMSNorm.
    Returns the last hidden states (batch, seq, hidden)."""

    def __init__(self, config: TinyGPTConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        # Learned absolute position table; initialized to zeros (note that
        # _init_weights only touches nn.Linear/nn.Embedding, so zeros stand).
        self.position_embeddings = nn.Parameter(
            torch.zeros(config.max_position_embeddings, config.hidden_size)
        )
        self.dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList(
            [TinyGPTBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.final_norm = TinyGPTRMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every block, and apply the final norm.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq) mask — nonzero to attend,
                zero for padding (consumed inside TinyGPTAttention).

        Raises:
            ValueError: if the sequence exceeds ``max_position_embeddings``.
                (The original code let the slice silently shrink and then
                failed with an opaque broadcasting error on the addition.)
        """
        seq_len = input_ids.size(1)
        max_positions = self.position_embeddings.size(0)
        if seq_len > max_positions:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings "
                f"{max_positions}"
            )
        hidden_states = self.embed_tokens(input_ids) + self.position_embeddings[:seq_len]
        hidden_states = self.dropout(hidden_states)
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask=attention_mask)
        return self.final_norm(hidden_states)
class TinyGPTForCausalLM(TinyGPTPreTrainedModel):
    """TinyGPT trunk plus an untied linear language-modeling head."""

    # No tied weights: lm_head is independent of embed_tokens.
    _tied_weights_keys = []

    def __init__(self, config: TinyGPTConfig):
        super().__init__(config)
        self.model = TinyGPTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache: every generation step re-feeds the full sequence.
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute LM logits and, when ``labels`` is given, the shifted
        next-token cross-entropy loss.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq) padding mask.
            labels: optional (batch, seq) target ids; positions equal to the
                ignore index do not contribute to the loss.
        """
        hidden_states = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Next-token prediction: logits at position t predict labels at t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            # Bug fix: cross_entropy(ignore_index=None) raises a TypeError when
            # the config has no pad token; fall back to PyTorch's default -100.
            pad_id = self.config.pad_token_id
            ignore_index = pad_id if pad_id is not None else -100
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=ignore_index,
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )