import math
import torch
import torch.nn as nn
from torch.nn import functional
from transformers import PreTrainedModel, PretrainedConfig


class Heads(nn.Module):
    """A single causal self-attention head."""

    def __init__(self, feature_embed, head_size, block_size):
        super().__init__()
        self.q = nn.Linear(feature_embed, head_size, bias=False)
        self.k = nn.Linear(feature_embed, head_size, bias=False)
        self.v = nn.Linear(feature_embed, head_size, bias=False)
        # Lower-triangular mask so each position attends only to earlier positions
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(0.15)

    def forward(self, x):
        B, T, C = x.shape
        k = self.k(x)
        q = self.q(x)
        v = self.v(x)
        # Scaled dot-product attention scores: [B, T, T]
        weighted = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        weighted = weighted.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weighted = functional.softmax(weighted, dim=-1)
        weighted = self.dropout(weighted)
        return weighted @ v


class MultiHeadAttention(nn.Module):
    """Several attention heads in parallel, concatenated and projected back to feature_embed."""

    def __init__(self, head_size, n_heads, feature_embed, block_size):
        super().__init__()
        self.multiple_heads = nn.ModuleList(
            [Heads(feature_embed, head_size, block_size) for _ in range(n_heads)]
        )
        self.linear = nn.Linear(head_size * n_heads, feature_embed)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.multiple_heads], dim=-1)  # [B, T, n_heads * head_size]
        out = self.linear(out)
        return self.dropout(out)


class Decoder(nn.Module):
    """Decoder block: multi-head self-attention with a residual connection and post-LayerNorm."""

    def __init__(self, feature_embed, n_heads, block_size):
        super().__init__()
        head_size = feature_embed // n_heads
        self.multihead = MultiHeadAttention(head_size, n_heads, feature_embed, block_size=block_size)
        self.layerNorm = nn.LayerNorm(feature_embed)

    def forward(self, x):
        y = self.multihead(x)
        return self.layerNorm(x + y)


class NOVA(nn.Module):
    def __init__(self, vocab_size, block_size=256, feature_embed=640, n_layers=4, n_heads=8):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.feature_embed = feature_embed
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.vector_embedding = nn.Embedding(vocab_size, feature_embed)
        self.learnable_position = nn.Embedding(block_size, feature_embed)  # learnable positional encoding
        # Sinusoidal positional encoding
        sinusoid = torch.zeros(block_size, feature_embed)
        position = torch.arange(0, block_size, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, feature_embed, 2).float() * (-math.log(10000.0) / feature_embed))
        sinusoid[:, 0::2] = torch.sin(position * div_term)
        sinusoid[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('sinusoidal_encoding', sinusoid)  # not trainable
        # Initialising the decoder stack
        self.decoder_block = nn.Sequential(*[
            Decoder(feature_embed, n_heads=n_heads, block_size=self.block_size) for _ in range(n_layers)
        ])
        self.linear_head = nn.Linear(feature_embed, vocab_size)
        self.layer_norm = nn.LayerNorm(feature_embed)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def forward(self, indx, target=None):
        B, T = indx.shape
        token_embedding = self.vector_embedding(indx)  # [B, T, C]
        # Positional encoding (hybrid: learned + sinusoidal)
        learned = self.learnable_position(torch.arange(T, device=indx.device))  # [T, C]
        sinusoidal = self.sinusoidal_encoding[:T]  # [T, C]
        positional_encoding = learned + sinusoidal  # [T, C]
        positional_encoding = positional_encoding.unsqueeze(0).expand(B, -1, -1)  # [B, T, C]
        x = token_embedding + positional_encoding  # [B, T, C]
        x = self.decoder_block(x)  # [B, T, C]
        x = self.layer_norm(x)  # [B, T, C]
        logits = self.linear_head(x)  # [B, T, vocab_size]
        if target is None:
            return logits, None
        # Shift logits and targets for causal language modeling
        logits = logits[:, :-1, :]  # [B, T-1, vocab_size]
        target = target[:, 1:]  # [B, T-1]
        # Flatten for loss
        logits = logits.contiguous().view(-1, logits.size(-1))  # [B*(T-1), vocab_size]
        target = target.contiguous().view(-1)  # [B*(T-1)]
        loss = functional.cross_entropy(logits, target, ignore_index=-100)
        return logits, loss

    @torch.no_grad()
    def generate(self, index, max_tokens=512):
        for _ in range(max_tokens):
            # Crop the running sequence to the last block_size tokens
            index_cond = index[:, -self.block_size:]
            logits, _ = self.forward(index_cond)
            logits = logits[:, -1, :]  # keep only the final position
            probs = torch.softmax(logits, dim=-1)
            next_index = torch.multinomial(probs, num_samples=1)
            # if next_index == self.eos_id:
            #     break
            index = torch.cat((index, next_index), dim=1)
        return index


class NovaConfig(PretrainedConfig):
    model_type = "nova"

    def __init__(self, vocab_size=6000, block_size=256, feature_embed=640, n_layers=4, n_heads=8, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_embd = feature_embed
        self.n_layer = n_layers
        self.n_head = n_heads


class NovaForCausalLM(PreTrainedModel):
    config_class = NovaConfig

    def __init__(self, config: NovaConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.block_size = config.block_size
        self.model = NOVA(vocab_size=self.vocab_size, block_size=self.block_size,
                          feature_embed=config.n_embd, n_layers=config.n_layer, n_heads=config.n_head)
        self.post_init()  # important for HF compatibility

    def forward(self, input_ids, labels=None):
        return self.model(input_ids, labels)

    def generate(self, input_ids, max_length=256):
        return self.model.generate(input_ids, max_tokens=max_length)
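

# Minimal usage sketch (illustrative, not part of the original module): builds a
# small model from NovaConfig, computes one language-modeling loss on random
# token ids, then samples a few tokens. The batch size, sequence length, and
# token count below are arbitrary assumptions chosen for demonstration only.
if __name__ == "__main__":
    config = NovaConfig(vocab_size=6000, block_size=256)
    model = NovaForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # [batch, seq_len]
    logits, loss = model(input_ids, labels=input_ids.clone())
    print("training loss:", loss.item())
    generated = model.generate(input_ids, max_length=8)  # appends 8 sampled tokens
    print("generated shape:", tuple(generated.shape))  # (2, 24)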