import math

import torch
import torch.nn as nn
from torch.nn import functional
from transformers import PreTrainedModel, PretrainedConfig


class Heads(nn.Module):
    """A single causal self-attention head."""

    def __init__(self, feature_embed, head_size, block_size):
        super().__init__()
        self.q = nn.Linear(feature_embed, head_size, bias=False)
        self.k = nn.Linear(feature_embed, head_size, bias=False)
        self.v = nn.Linear(feature_embed, head_size, bias=False)
        # Lower-triangular mask used to block attention to future positions.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(0.15)

    def forward(self, x):
        B, T, C = x.shape
        k = self.k(x)
        q = self.q(x)
        v = self.v(x)
        # Scaled dot-product attention scores, masked so each position only attends to the past.
        weighted = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        weighted = weighted.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weighted = functional.softmax(weighted, dim=-1)
        weighted = self.dropout(weighted)
        return weighted @ v


class MultiHeadAttention(nn.Module):
    """Runs several attention heads in parallel and projects back to feature_embed."""

    def __init__(self, head_size, n_heads, feature_embed, block_size):
        super().__init__()
        self.multiple_heads = nn.ModuleList(
            [Heads(feature_embed, head_size, block_size) for _ in range(n_heads)]
        )
        self.linear = nn.Linear(head_size * n_heads, feature_embed)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.multiple_heads], dim=-1)
        out = self.linear(out)
        return self.dropout(out)


class Decoder(nn.Module):
    """Decoder block: multi-head self-attention with a post-norm residual connection."""

    def __init__(self, feature_embed, n_heads, block_size):
        super().__init__()
        head_size = feature_embed // n_heads
        self.multihead = MultiHeadAttention(head_size, n_heads, feature_embed, block_size=block_size)
        self.layerNorm = nn.LayerNorm(feature_embed)

    def forward(self, x):
        y = self.multihead(x)
        return self.layerNorm(x + y)


class NOVA(nn.Module):
    """Decoder-only language model with a hybrid (learned + sinusoidal) positional encoding."""

    def __init__(self, vocab_size, block_size=256, feature_embed=640, n_layers=4, n_heads=8):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.feature_embed = feature_embed
        self.n_layers = n_layers
        self.n_heads = n_heads

        self.vector_embedding = nn.Embedding(vocab_size, feature_embed)
        self.learnable_position = nn.Embedding(block_size, feature_embed)  # learnable positional encoding

        # Sinusoidal positional encoding
        sinusoid = torch.zeros(block_size, feature_embed)
        position = torch.arange(0, block_size, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, feature_embed, 2).float() * (-math.log(10000.0) / feature_embed))
        sinusoid[:, 0::2] = torch.sin(position * div_term)
        sinusoid[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('sinusoidal_encoding', sinusoid)  # not trainable

        # Initialising the stack of decoder blocks
        self.decoder_block = nn.Sequential(*[
            Decoder(feature_embed, n_heads=n_heads, block_size=self.block_size)
            for _ in range(n_layers)
        ])
        self.linear_head = nn.Linear(feature_embed, vocab_size)
        self.layer_norm = nn.LayerNorm(feature_embed)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def forward(self, indx, target=None):
        B, T = indx.shape
        token_embedding = self.vector_embedding(indx)  # [B, T, C]

        # Positional encoding (hybrid: learned + sinusoidal)
        learned = self.learnable_position(torch.arange(T, device=indx.device))  # [T, C]
        sinusoidal = self.sinusoidal_encoding[:T]  # [T, C]
        positional_encoding = learned + sinusoidal  # [T, C]
        positional_encoding = positional_encoding.unsqueeze(0).expand(B, -1, -1)  # [B, T, C]

        x = token_embedding + positional_encoding  # [B, T, C]
        x = self.decoder_block(x)  # [B, T, C]
        x = self.layer_norm(x)  # [B, T, C]
        logits = self.linear_head(x)  # [B, T, vocab_size]

        if target is None:
            return logits, None

        # Shift logits and targets for causal language modeling
        logits = logits[:, :-1, :]  # [B, T-1, vocab_size]
        target = target[:, 1:]  # [B, T-1]

        # Flatten for the loss
        logits = logits.contiguous().view(-1, logits.size(-1))  # [B*(T-1), vocab_size]
        target = target.contiguous().view(-1)  # [B*(T-1)]
        loss = functional.cross_entropy(logits, target, ignore_index=-100)
        return logits, loss

    @torch.no_grad()
    def generate(self, index, max_tokens=512):
        for _ in range(max_tokens):
            # Crop the running context to the last block_size tokens.
            index_cond = index[:, -self.block_size:]
            logits, _ = self.forward(index_cond)
            logits = logits[:, -1, :]  # keep logits for the last position only
            probs = torch.softmax(logits, dim=-1)
            next_index = torch.multinomial(probs, num_samples=1)
            # if next_index == self.eos_id:
            #     break
            index = torch.cat((index, next_index), dim=1)
        return index


class NovaConfig(PretrainedConfig):
    model_type = "nova"

    def __init__(self, vocab_size=6000, block_size=256, feature_embed=640, n_layers=4, n_heads=8, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_embd = feature_embed
        self.n_layer = n_layers
        self.n_head = n_heads


class NovaForCausalLM(PreTrainedModel):
    config_class = NovaConfig

    def __init__(self, config: NovaConfig):
        super().__init__(config)
        # Wrap the plain PyTorch NOVA backbone so it can be saved/loaded through the HF API.
        self.vocab_size = config.vocab_size
        self.block_size = config.block_size
        self.model = NOVA(vocab_size=self.vocab_size,
                          block_size=self.block_size,
                          feature_embed=config.n_embd,
                          n_layers=config.n_layer,
                          n_heads=config.n_head)
        self.post_init()  # important for HF compatibility

    def forward(self, input_ids, labels=None):
        # Returns (logits, loss); loss is None when no labels are given.
        return self.model(input_ids, labels)

    def generate(self, input_ids, max_length=256):
        return self.model.generate(input_ids, max_tokens=max_length)
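
# --- Minimal usage sketch (illustrative only, not part of the training pipeline) ---
# The block below just shows how the classes above fit together: build a NovaConfig,
# wrap NOVA in NovaForCausalLM, run a forward pass with labels to get the LM loss,
# and sample a short continuation. The prompt tensor of random token ids is a
# hypothetical placeholder; a real tokenizer and trained weights would be used in practice.
if __name__ == "__main__":
    config = NovaConfig(vocab_size=6000, block_size=256, feature_embed=640, n_layers=4, n_heads=8)
    model = NovaForCausalLM(config)

    # Placeholder "prompt": 16 random token ids in [0, vocab_size).
    input_ids = torch.randint(0, config.vocab_size, (1, 16))

    # Forward pass with labels; the model shifts targets internally for next-token prediction.
    logits, loss = model(input_ids, labels=input_ids)
    print("loss:", loss.item())

    # Sample 32 new tokens from the (untrained) model.
    generated = model.generate(input_ids, max_length=32)
    print("generated shape:", tuple(generated.shape))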