"""
Kayra Turkish GPT Model
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_kayra import KayraConfig
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean centering, no bias).

    Scales each feature vector by the reciprocal of its RMS, then applies a
    learned per-dimension gain.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # numerical floor inside the square root
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # rsqrt(mean(x^2) + eps) == 1 / rms(x); multiply instead of divide.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return self.weight * (x * inv_rms)
class Attention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    Only the fixed causal (future-token) mask is applied; padding /
    attention masks are NOT handled here — callers must deal with padding
    externally if they need it.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.hidden_dropout)
        # Upper-triangular bool mask: True where query attends to a future key.
        mask = torch.triu(
            torch.ones(config.max_position_embeddings, config.max_position_embeddings),
            diagonal=1,
        ).bool()
        # Fix: persistent=False keeps this max_pos^2 constant out of
        # state_dict/checkpoints; it is rebuilt from config on construction.
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x):
        """Apply causal self-attention to x of shape (batch, seq, hidden)."""
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each (B, heads, T, head_dim)
        # Scaled dot-product scores; 1/sqrt(head_dim) keeps variance stable.
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        attn = attn.masked_fill(self.mask[:T, :T], float('-inf'))
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        return self.proj(out)
class FeedForward(nn.Module):
    """SwiGLU-style MLP: silu(w1(x)) gated elementwise by w3(x), projected
    back to the hidden size by w2, then dropout."""

    def __init__(self, config):
        super().__init__()
        self.w1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.w3 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        projected = self.w2(gate * value)
        return self.dropout(projected)
class Block(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention and
    RMSNorm -> feed-forward, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.norm1 = RMSNorm(config.hidden_size)
        self.attn = Attention(config)
        self.norm2 = RMSNorm(config.hidden_size)
        self.ff = FeedForward(config)

    def forward(self, x):
        # Residual around attention, then around the MLP.
        h = x + self.attn(self.norm1(x))
        return h + self.ff(self.norm2(h))
class KayraPreTrainedModel(PreTrainedModel):
    """Base class wiring Kayra modules into the HF PreTrainedModel API
    (weight init, checkpoint load/save, gradient checkpointing support)."""

    config_class = KayraConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights with N(0, 0.02).

        Fix: Linear biases are now zeroed when present (standard GPT init).
        All Linears in this file are bias=False today, but this removes a
        latent bug for any future layer that carries a bias.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
class KayraForCausalLM(KayraPreTrainedModel):
    """Kayra decoder-only language model with a causal LM head.

    Token + learned absolute position embeddings, a stack of pre-norm
    transformer blocks, a final RMSNorm, and an (optionally weight-tied)
    LM head projecting back to the vocabulary.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_emb = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.drop = nn.Dropout(config.hidden_dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            # Share the input embedding matrix with the output projection.
            self.lm_head.weight = self.tok_emb.weight
        self.post_init()

    def get_input_embeddings(self):
        return self.tok_emb

    def set_input_embeddings(self, value):
        self.tok_emb = value

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the model; compute next-token LM loss when labels are given.

        attention_mask is accepted for HF API compatibility but is NOT
        applied — Attention only uses its fixed causal mask.

        Raises:
            ValueError: if input_ids is missing, or the sequence is longer
                than max_position_embeddings (previously these surfaced as
                an opaque AttributeError / embedding index error).
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        B, T = input_ids.shape
        max_pos = self.config.max_position_embeddings
        if T > max_pos:
            raise ValueError(
                f"sequence length {T} exceeds max_position_embeddings {max_pos}"
            )
        pos = torch.arange(T, device=input_ids.device)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift so position t predicts token t+1; label positions set to
            # -100 are skipped via cross_entropy's default ignore_index.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache: the full sequence is re-fed at every generation step.
        return {"input_ids": input_ids}