"""NOVA: a small decoder-only transformer language model with a Hugging Face wrapper."""

import math

import torch
import torch.nn as nn
from torch.nn import functional
from transformers import PreTrainedModel, PretrainedConfig


class Heads(nn.Module):
    """One head of causal (masked) self-attention."""

    def __init__(self, feature_embed, head_size, block_size):
        super().__init__()
        self.q = nn.Linear(feature_embed, head_size, bias=False)
        self.k = nn.Linear(feature_embed, head_size, bias=False)
        self.v = nn.Linear(feature_embed, head_size, bias=False)
        # Lower-triangular mask so each position attends only to itself and earlier positions.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(0.15)

    def forward(self, x):
        B, T, C = x.shape
        k = self.k(x)
        q = self.q(x)
        v = self.v(x)

        # Scaled dot-product attention: (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
        weighted = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        weighted = weighted.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weighted = functional.softmax(weighted, dim=-1)
        weighted = self.dropout(weighted)
        return weighted @ v


class MultiHeadAttention(nn.Module):
    """Runs several attention heads in parallel and projects the concatenated output back to feature_embed."""

    def __init__(self, head_size, n_heads, feature_embed, block_size):
        super().__init__()
        self.multiple_heads = nn.ModuleList(
            [Heads(feature_embed, head_size, block_size) for _ in range(n_heads)]
        )
        self.linear = nn.Linear(head_size * n_heads, feature_embed)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.multiple_heads], dim=-1)
        out = self.linear(out)
        return self.dropout(out)


class Decoder(nn.Module):
    """Attention-only decoder block: multi-head self-attention with a residual connection, then layer norm."""

    def __init__(self, feature_embed, n_heads, block_size):
        super().__init__()
        head_size = feature_embed // n_heads
        self.multihead = MultiHeadAttention(head_size, n_heads, feature_embed, block_size=block_size)
        self.layerNorm = nn.LayerNorm(feature_embed)

    def forward(self, x):
        y = self.multihead(x)
        return self.layerNorm(x + y)


class NOVA(nn.Module):
    """Decoder-only language model whose positions are encoded by the sum of a learned embedding and a fixed sinusoidal encoding."""

    def __init__(self, vocab_size, block_size=256, feature_embed=640, n_layers=4, n_heads=8):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.feature_embed = feature_embed
        self.n_layers = n_layers
        self.n_heads = n_heads

        self.vector_embedding = nn.Embedding(vocab_size, feature_embed)
        self.learnable_position = nn.Embedding(block_size, feature_embed)

        # Fixed sinusoidal positional encoding, registered as a buffer so it follows the
        # model across devices but is never trained.
        sinusoid = torch.zeros(block_size, feature_embed)
        position = torch.arange(0, block_size, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, feature_embed, 2).float() * (-math.log(10000.0) / feature_embed))
        sinusoid[:, 0::2] = torch.sin(position * div_term)
        sinusoid[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('sinusoidal_encoding', sinusoid)

        self.decoder_block = nn.Sequential(*[
            Decoder(feature_embed, n_heads=n_heads, block_size=self.block_size) for _ in range(n_layers)
        ])
        self.linear_head = nn.Linear(feature_embed, vocab_size)
        self.layer_norm = nn.LayerNorm(feature_embed)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def forward(self, indx, target=None):
        B, T = indx.shape

        token_embedding = self.vector_embedding(indx)

        # Positional information: learned embedding plus the fixed sinusoidal encoding,
        # broadcast over the batch.
        learned = self.learnable_position(torch.arange(T, device=indx.device))
        sinusoidal = self.sinusoidal_encoding[:T]
        positional_encoding = learned + sinusoidal
        positional_encoding = positional_encoding.unsqueeze(0).expand(B, -1, -1)

        x = token_embedding + positional_encoding
        x = self.decoder_block(x)
        x = self.layer_norm(x)
        logits = self.linear_head(x)

        if target is None:
            return logits, None

        # Shift so position t predicts the token at position t + 1.
        logits = logits[:, :-1, :]
        target = target[:, 1:]

        logits = logits.contiguous().view(-1, logits.size(-1))
        target = target.contiguous().view(-1)

        # Positions labelled -100 are ignored by the loss (e.g. padding).
        loss = functional.cross_entropy(logits, target, ignore_index=-100)

        return logits, loss

    @torch.no_grad()
    def generate(self, index, max_tokens=512):
        for _ in range(max_tokens):
            # Crop the context to the last block_size tokens the model can attend to.
            index_cond = index[:, -self.block_size:]
            logits, _ = self.forward(index_cond)
            logits = logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)

            # Sample the next token from the predicted distribution and append it.
            next_index = torch.multinomial(probs, num_samples=1)
            index = torch.cat((index, next_index), dim=1)
        return index


class NovaConfig(PretrainedConfig):
    """Hugging Face configuration for NOVA."""

    model_type = "nova"

    def __init__(self, vocab_size=6000, block_size=256, feature_embed=640, n_layers=4, n_heads=8, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_embd = feature_embed
        self.n_layer = n_layers
        self.n_head = n_heads


class NovaForCausalLM(PreTrainedModel):
    """PreTrainedModel wrapper so NOVA plugs into the transformers save/load machinery."""

    config_class = NovaConfig

    def __init__(self, config: NovaConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.block_size = config.block_size
        self.model = NOVA(vocab_size=self.vocab_size, block_size=self.block_size,
                          feature_embed=config.n_embd, n_layers=config.n_layer, n_heads=config.n_head)
        self.post_init()

    def forward(self, input_ids, labels=None):
        # The underlying NOVA model returns (logits, loss); loss is None when labels is None.
        return self.model(input_ids, labels)

    def generate(self, input_ids, max_length=256):
        return self.model.generate(input_ids, max_length)
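
# A minimal smoke-test sketch, not part of the model definition: the batch shape, vocab size,
# and variable names below are illustrative assumptions. It shows how NovaConfig and
# NovaForCausalLM fit together for a forward pass with labels and for sampling with generate().
if __name__ == "__main__":
    config = NovaConfig(vocab_size=6000, block_size=256, feature_embed=640, n_layers=4, n_heads=8)
    model = NovaForCausalLM(config)

    # Dummy batch: 2 sequences of 16 random token ids.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))

    logits, loss = model(input_ids, labels=input_ids)
    print(logits.shape, loss.item())   # logits are flattened to (2 * 15, vocab_size) by the shift

    generated = model.generate(input_ids, max_length=8)
    print(generated.shape)             # (2, 16 + 8)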