# NOVA-Verse / nova_modelling.py
# (Uploaded by harshit36; commit e63dd1f "Upload 6 files")
import math
import torch
import torch.nn as nn
from torch.nn import functional
from transformers import PreTrainedModel, PretrainedConfig
class Heads(nn.Module):
    """A single head of causal (masked) self-attention.

    Projects the input into query/key/value spaces, computes scaled
    dot-product attention restricted to past positions via a
    lower-triangular mask, and returns the attention-weighted values.
    """

    def __init__(self, feature_embed, head_size, block_size):
        super().__init__()
        # Separate, bias-free projections for query, key and value.
        self.q = nn.Linear(feature_embed, head_size, bias=False)
        self.k = nn.Linear(feature_embed, head_size, bias=False)
        self.v = nn.Linear(feature_embed, head_size, bias=False)
        # Causal mask buffer: moves with the module's device but is not trained.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(0.15)

    def forward(self, x):
        _, seq_len, _ = x.shape
        keys = self.k(x)
        queries = self.q(x)
        values = self.v(x)
        # Scaled dot-product scores [B, T, T], scaled by 1/sqrt(head_size).
        scores = queries @ keys.transpose(-2, -1) * (keys.shape[-1] ** -0.5)
        # Forbid attention to future positions, then normalise per query row.
        scores = scores.masked_fill(self.tril[:seq_len, :seq_len] == 0, float('-inf'))
        attn = functional.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return attn @ values
class MultiHeadAttention(nn.Module):
    """Runs several attention heads in parallel and mixes their outputs.

    Each head attends independently; the concatenated head outputs are
    projected back to the embedding width and regularised with dropout.
    """

    def __init__(self, head_size, n_heads, feature_embed, block_size):
        super().__init__()
        self.multiple_heads = nn.ModuleList(
            Heads(feature_embed, head_size, block_size) for _ in range(n_heads)
        )
        # Projection from concatenated head width back to the model width.
        self.linear = nn.Linear(head_size * n_heads, feature_embed)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        per_head = [attn_head(x) for attn_head in self.multiple_heads]
        combined = torch.cat(per_head, dim=-1)
        projected = self.linear(combined)
        return self.dropout(projected)
class Decoder(nn.Module):
    """One decoder layer: masked multi-head attention with a post-norm residual.

    NOTE(review): unlike a standard transformer block there is no
    position-wise feed-forward sub-layer here — presumably intentional;
    confirm against the training setup.
    """

    def __init__(self, feature_embed, n_heads, block_size):
        super().__init__()
        # Split the embedding width evenly across the heads.
        head_size = feature_embed // n_heads
        self.multihead = MultiHeadAttention(head_size, n_heads, feature_embed, block_size=block_size)
        self.layerNorm = nn.LayerNorm(feature_embed)

    def forward(self, x):
        attended = self.multihead(x)
        # Post-LN residual: normalise after adding the attention output.
        return self.layerNorm(x + attended)
class NOVA(nn.Module):
    """Decoder-only transformer language model with hybrid positional encoding.

    Token embeddings are combined with the sum of a learnable positional
    embedding and a fixed sinusoidal encoding, passed through a stack of
    attention-only decoder layers, layer-normalised, and projected to
    vocabulary logits.
    """

    def __init__(self, vocab_size, block_size=256, feature_embed=640, n_layers=4, n_heads=8):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.feature_embed = feature_embed
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.vector_embedding = nn.Embedding(vocab_size, feature_embed)
        self.learnable_position = nn.Embedding(block_size, feature_embed)  # learnable positional encoding
        # Fixed sinusoidal positional encoding: even channels sin, odd channels cos.
        sinusoid = torch.zeros(block_size, feature_embed)
        position = torch.arange(0, block_size, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, feature_embed, 2).float() * (-math.log(10000.0) / feature_embed))
        sinusoid[:, 0::2] = torch.sin(position * div_term)
        sinusoid[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('sinusoidal_encoding', sinusoid)  # not trainable
        # Stack of decoder layers applied sequentially.
        self.decoder_block = nn.Sequential(*[
            Decoder(feature_embed, n_heads=n_heads, block_size=self.block_size)
            for _ in range(n_layers)
        ])
        self.linear_head = nn.Linear(feature_embed, vocab_size)
        self.layer_norm = nn.LayerNorm(feature_embed)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise linear/embedding weights from N(0, 0.01); zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def forward(self, indx, target=None):
        """Compute vocabulary logits and, optionally, the LM loss.

        Args:
            indx: LongTensor of token ids, shape [B, T] with T <= block_size.
            target: optional LongTensor of token ids, shape [B, T]; positions
                holding -100 are ignored by the loss.

        Returns:
            (logits, None) with logits shaped [B, T, vocab_size] when target
            is None; otherwise (flattened shifted logits of shape
            [B*(T-1), vocab_size], scalar loss).

        Raises:
            ValueError: if T exceeds block_size (the positional tables only
                cover block_size positions).
        """
        B, T = indx.shape
        # Fail fast with a clear message instead of the opaque embedding
        # index error that T > block_size would otherwise trigger below.
        if T > self.block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {self.block_size}")
        token_embedding = self.vector_embedding(indx)  # [B, T, C]
        # Positional encoding (hybrid: learned + sinusoidal).
        learned = self.learnable_position(torch.arange(T, device=indx.device))  # [T, C]
        sinusoidal = self.sinusoidal_encoding[:T]  # [T, C]
        positional_encoding = (learned + sinusoidal).unsqueeze(0).expand(B, -1, -1)  # [B, T, C]
        x = token_embedding + positional_encoding  # [B, T, C]
        x = self.decoder_block(x)  # [B, T, C]
        x = self.layer_norm(x)  # [B, T, C]
        logits = self.linear_head(x)  # [B, T, vocab_size]
        if target is None:
            return logits, None
        # Causal LM objective: position t predicts token t+1, so drop the
        # last logit and the first target, then flatten for cross-entropy.
        logits = logits[:, :-1, :].contiguous().view(-1, logits.size(-1))  # [B*(T-1), vocab_size]
        target = target[:, 1:].contiguous().view(-1)  # [B*(T-1)]
        loss = functional.cross_entropy(logits, target, ignore_index=-100)
        return logits, loss

    @torch.no_grad()
    def generate(self, index, max_tokens=512):
        """Autoregressively sample max_tokens new tokens by multinomial sampling.

        Args:
            index: LongTensor prompt of shape [B, T0].
            max_tokens: number of new tokens to append.

        Returns:
            LongTensor of shape [B, T0 + max_tokens].
        """
        for _ in range(max_tokens):
            # Crop the running context to the last block_size tokens.
            index_cond = index[:, -self.block_size:]
            logits, _ = self.forward(index_cond)  # loss is None here
            logits = logits[:, -1, :]  # distribution over the next token
            probs = torch.softmax(logits, dim=-1)
            next_index = torch.multinomial(probs, num_samples=1)
            index = torch.cat((index, next_index), dim=1)
        return index
class NovaConfig(PretrainedConfig):
    """Hugging Face configuration for the NOVA model.

    Accepts the NOVA constructor's parameter names but stores them under
    the GPT-style attribute names (n_embd / n_layer / n_head) expected by
    the wrapper model.
    """

    model_type = "nova"

    def __init__(self, vocab_size=6000, block_size=256, feature_embed=640, n_layers=4, n_heads=8, **kwargs):
        super().__init__(**kwargs)
        # Map NOVA-style constructor names onto GPT-style config attributes.
        self.n_embd = feature_embed
        self.n_layer = n_layers
        self.n_head = n_heads
        self.vocab_size = vocab_size
        self.block_size = block_size
class NovaForCausalLM(PreTrainedModel):
    """Thin Hugging Face PreTrainedModel wrapper around the raw NOVA module."""

    config_class = NovaConfig

    def __init__(self, config: NovaConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.block_size = config.block_size
        # Build the underlying model from the HF config fields.
        self.model = NOVA(
            vocab_size=self.vocab_size,
            block_size=self.block_size,
            feature_embed=config.n_embd,
            n_layers=config.n_layer,
            n_heads=config.n_head,
        )
        self.post_init()  # important for HF compatibility

    def forward(self, input_ids, labels=None):
        # Delegate straight to NOVA; returns (logits, loss-or-None).
        return self.model(input_ids, labels)

    def generate(self, input_ids, max_length=256):
        # NOTE(review): max_length is forwarded as NOVA's max_tokens, i.e. the
        # number of NEW tokens to sample, not the total sequence length —
        # confirm this matches callers' expectations.
        return self.model.generate(input_ids, max_length)