import math
import torch
import torch.nn as nn
from torch.nn import functional
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput

class Heads(nn.Module):
    """A single head of causal (masked) self-attention."""

    def __init__(self, feature_embed, head_size, block_size):
        super().__init__()

        self.q = nn.Linear(feature_embed, head_size, bias=False)
        self.k = nn.Linear(feature_embed, head_size, bias=False)
        self.v = nn.Linear(feature_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size,block_size)))
        self.dropout = nn.Dropout(0.15)

    def forward(self, x):
        B, T, C = x.shape
        k = self.k(x)   # [B, T, head_size]
        q = self.q(x)   # [B, T, head_size]
        v = self.v(x)   # [B, T, head_size]

        # Scaled dot-product attention with a causal (lower-triangular) mask
        weighted = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)              # [B, T, T]
        weighted = weighted.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # no attending to the future
        weighted = functional.softmax(weighted, dim=-1)
        weighted = self.dropout(weighted)
        return weighted @ v   # [B, T, head_size]
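
# Quick shape check for a single head (illustrative sizes, not the model defaults):
#
#   head = Heads(feature_embed=32, head_size=8, block_size=16)
#   head(torch.randn(2, 16, 32)).shape   # -> torch.Size([2, 16, 8])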

class MultiHeadAttention(nn.Module):
    """Runs n_heads attention heads in parallel, concatenates their outputs,
    and projects back to feature_embed."""

    def __init__(self, head_size, n_heads, feature_embed, block_size):
        super().__init__()

        self.multiple_heads = nn.ModuleList([Heads(feature_embed, head_size, block_size) for _ in range(n_heads)])
        self.linear = nn.Linear(head_size*n_heads, feature_embed)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.multiple_heads], dim=-1)
        out = self.linear(out)
        return self.dropout(out)

class Decoder(nn.Module):
    """Decoder block: multi-head causal self-attention with a post-norm residual
    connection. Note that this design has no feed-forward sublayer."""

    def __init__(self, feature_embed, n_heads, block_size):
        super().__init__()

        head_size = feature_embed // n_heads
        self.multihead = MultiHeadAttention(head_size, n_heads, feature_embed, block_size=block_size)
        self.layerNorm = nn.LayerNorm(feature_embed)

    def forward(self, x):
        y = self.multihead(x)
        return self.layerNorm(x+y)

class NOVA(nn.Module):
    """Decoder-only transformer language model with a hybrid (learned + sinusoidal)
    positional encoding."""

    def __init__(self, vocab_size, block_size=256, feature_embed=640, n_layers=4, n_heads=8):
        super().__init__()

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.feature_embed = feature_embed
        self.n_layers = n_layers
        self.n_heads = n_heads

        self.vector_embedding = nn.Embedding(vocab_size, feature_embed)
        self.learnable_position = nn.Embedding(block_size, feature_embed)   # learnable positional encoding

        # Sinusoidal Positional encoding
        sinusoid = torch.zeros(block_size, feature_embed)
        position = torch.arange(0, block_size, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, feature_embed, 2).float() * (-math.log(10000.0) / feature_embed))
        sinusoid[:, 0::2] = torch.sin(position * div_term)
        sinusoid[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('sinusoidal_encoding', sinusoid)  # not trainable

        # Stack of decoder blocks, followed by a final LayerNorm and LM head
        self.decoder_block = nn.Sequential(*[
            Decoder(feature_embed, n_heads=n_heads, block_size=self.block_size) for _ in range(n_layers)
        ])
        self.linear_head = nn.Linear(feature_embed, vocab_size)
        self.layer_norm = nn.LayerNorm(feature_embed)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def forward(self, indx, target=None):
        B, T = indx.shape

        token_embedding = self.vector_embedding(indx)  # [B, T, C]

        # Positional encoding (hybrid: learned + sinusoidal)
        learned = self.learnable_position(torch.arange(T, device=indx.device))      # [T, C]
        sinusoidal = self.sinusoidal_encoding[:T]                                   # [T, C]
        positional_encoding = learned + sinusoidal                                  # [T, C]
        positional_encoding = positional_encoding.unsqueeze(0).expand(B, -1, -1)    # [B, T, C]

        x = token_embedding + positional_encoding                                   # [B, T, C]
        x = self.decoder_block(x)                                                   # [B, T, C]
        x = self.layer_norm(x)                                                      # [B, T, C]
        logits = self.linear_head(x)                                                # [B, T, vocab_size]

        if target is None:
            return logits, None

        # Shift logits and targets so that position t predicts token t+1
        shift_logits = logits[:, :-1, :].contiguous()   # [B, T-1, vocab_size]
        shift_target = target[:, 1:].contiguous()       # [B, T-1]

        # Flatten for cross-entropy; the full [B, T, vocab_size] logits are
        # returned unchanged so callers see the same shape in both branches
        loss = functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),   # [B*(T-1), vocab_size]
            shift_target.view(-1),                          # [B*(T-1)]
            ignore_index=-100,
        )

        return logits, loss


    @torch.no_grad()
    def generate(self, index, max_tokens=512):
        for _ in range(max_tokens):
            # Crop the running sequence to the last block_size tokens
            index_cond = index[:, -self.block_size:]
            logits, _ = self.forward(index_cond)
            logits = logits[:, -1, :]                   # logits at the final position
            probs = torch.softmax(logits, dim=-1)

            next_index = torch.multinomial(probs, num_samples=1)
            # Early stopping on EOS would go here once the model defines eos_id,
            # e.g.: if next_index.item() == self.eos_id: break  (batch size 1 only)
            index = torch.cat((index, next_index), dim=1)
        return index
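
# Sketch of one training step with NOVA (illustrative only: the `data` tensor of
# token ids with shape [B, T] and the optimizer choice are assumptions, not part
# of this file):
#
#   model = NOVA(vocab_size=6000)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
#   logits, loss = model(data, target=data)   # forward shifts targets internally
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()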

class NovaConfig(PretrainedConfig):
    model_type = "nova"

    def __init__(self, vocab_size=6000, block_size=256, feature_embed=640, n_layers=4, n_heads=8, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_embd = feature_embed
        self.n_layer = n_layers
        self.n_head = n_heads
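
# NovaConfig round-trips through the standard HF config serialization (sketch;
# "nova-ckpt" is a placeholder directory name):
#
#   cfg = NovaConfig(vocab_size=6000)
#   cfg.save_pretrained("nova-ckpt")               # writes config.json
#   cfg = NovaConfig.from_pretrained("nova-ckpt")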

class NovaForCausalLM(PreTrainedModel):
    config_class = NovaConfig

    def __init__(self, config: NovaConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.block_size = config.block_size
        self.model = NOVA(vocab_size=self.vocab_size, block_size=self.block_size,
                          feature_embed=config.n_embd, n_layers=config.n_layer, n_heads=config.n_head)
        self.post_init()  # important for HF compatibility

    def forward(self, input_ids, labels=None, **kwargs):
        logits, loss = self.model(input_ids, labels)
        # Wrap in a ModelOutput so HF Trainer and pipelines can locate loss/logits
        return CausalLMOutput(loss=loss, logits=logits)

    def generate(self, input_ids, max_new_tokens=256):
        # NOVA.generate counts newly sampled tokens, so expose it as max_new_tokens
        # rather than HF's max_length (which includes the prompt)
        return self.model.generate(input_ids, max_tokens=max_new_tokens)
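

# Registering the pair with the Auto* factories is optional (sketch; assumes a
# transformers version that supports custom model registration):
#
#   from transformers import AutoConfig, AutoModelForCausalLM
#   AutoConfig.register("nova", NovaConfig)
#   AutoModelForCausalLM.register(NovaConfig, NovaForCausalLM)

# Minimal smoke test: tiny, untrained sizes chosen only to exercise the shapes.
if __name__ == "__main__":
    config = NovaConfig(vocab_size=100, block_size=32, feature_embed=64, n_layers=2, n_heads=4)
    model = NovaForCausalLM(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids, labels=input_ids)
    print("logits:", tuple(out.logits.shape), "loss:", out.loss.item())

    generated = model.generate(input_ids, max_new_tokens=8)
    print("generated:", tuple(generated.shape))  # (2, 24)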