"""
Kayra Turkish GPT Model
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_kayra import KayraConfig
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: no mean-centering, learnable scale."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the RMS over the last dimension, then apply the learned scale.
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
class Attention(nn.Module):
    """Multi-head causal self-attention with a precomputed triangular mask."""

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.hidden_dropout)
        # Causal mask: True above the diagonal marks positions to be hidden.
        mask = torch.triu(
            torch.ones(config.max_position_embeddings, config.max_position_embeddings),
            diagonal=1,
        ).bool()
        self.register_buffer("mask", mask)

    def forward(self, x):
        B, T, C = x.shape
        # Single projection to Q, K, V, then split into heads.
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: (B, n_heads, T, head_dim)
        # Scaled dot-product attention with the causal mask applied before softmax.
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        attn = attn.masked_fill(self.mask[:T, :T], float('-inf'))
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        return self.proj(out)
class FeedForward(nn.Module):
    """SwiGLU-style feed-forward block: silu(w1(x)) gated by w3(x), projected by w2."""

    def __init__(self, config):
        super().__init__()
        self.w1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.w3 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class Block(nn.Module):
    """Pre-norm transformer block: norm -> attention -> residual, norm -> FFN -> residual."""

    def __init__(self, config):
        super().__init__()
        self.norm1 = RMSNorm(config.hidden_size)
        self.attn = Attention(config)
        self.norm2 = RMSNorm(config.hidden_size)
        self.ff = FeedForward(config)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x
class KayraPreTrainedModel(PreTrainedModel):
    config_class = KayraConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        # GPT-style initialization: normal(0, 0.02) for linear and embedding weights.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
class KayraForCausalLM(KayraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_emb = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.drop = nn.Dropout(config.hidden_dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            # Share the input embedding matrix with the output projection.
            self.lm_head.weight = self.tok_emb.weight
        self.post_init()

    def get_input_embeddings(self):
        return self.tok_emb

    def set_input_embeddings(self, value):
        self.tok_emb = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # attention_mask is accepted for API compatibility but not applied here;
        # causality is enforced by the triangular mask inside each attention layer.
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift so position t predicts token t + 1; labels of -100 are ignored by default.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache is implemented, so generation re-encodes the full sequence each step.
        return {"input_ids": input_ids}