File size: 1,765 Bytes
3270dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""

Low-Rank Factorized Embedding. 



Uses standard nn.Linear for projection (NOT ternary quantization). 

Embeddings should use full precision for good token representations.

"""

import torch
import torch.nn as nn


class FactorizedEmbedding(nn.Module):
    """
    Low-rank factorized embedding: vocab -> d_embed_rank -> d_model.

    Token ids are looked up in a narrow ``vocab_size x d_embed_rank`` table
    and then projected up to ``d_model`` by a bias-free, full-precision
    ``nn.Linear`` (no quantization). This reduces the embedding parameter
    count from ``vocab_size * d_model`` to
    ``vocab_size * d_embed_rank + d_embed_rank * d_model``.

    Args:
        vocab_size: Number of distinct token ids (rows of the lookup table).
        d_model: Width of the output embeddings.
        d_embed_rank: Width of the low-rank bottleneck. Defaults to 96.
        init_std: Standard deviation of the zero-mean normal initialisation
            applied to both factors. Keyword-only; defaults to 0.02.

    Raises:
        ValueError: If ``vocab_size``, ``d_model`` or ``d_embed_rank`` is
            not positive.
    """

    def __init__(self, vocab_size, d_model, d_embed_rank=96, *, init_std=0.02):
        super().__init__()
        # Reject degenerate sizes up front: e.g. a zero-width bottleneck
        # would otherwise build a module that silently outputs all zeros.
        for name, value in (
            ("vocab_size", vocab_size),
            ("d_model", d_model),
            ("d_embed_rank", d_embed_rank),
        ):
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_embed_rank = d_embed_rank

        # Lookup table: token id -> compressed (rank-sized) vector.
        self.embed = nn.Embedding(vocab_size, d_embed_rank)

        # Projection: compressed -> full width (standard Linear, no bias).
        self.proj = nn.Linear(d_embed_rank, d_model, bias=False)

        # Small-magnitude init for stable training. Since both factors are
        # independent N(0, init_std**2), each output element has a standard
        # deviation of about init_std**2 * sqrt(d_embed_rank), which is much
        # smaller than init_std itself.
        nn.init.normal_(self.embed.weight, mean=0.0, std=init_std)
        nn.init.normal_(self.proj.weight, mean=0.0, std=init_std)

    def forward(self, input_ids):
        """
        Embed token ids.

        Args:
            input_ids: Integer tensor of token ids, typically
                ``[batch_size, seq_len]``. Any shape is accepted;
                ``nn.Embedding`` appends a trailing feature dimension.

        Returns:
            Embeddings of shape ``input_ids.shape + (d_model,)``, e.g.
            ``[batch_size, seq_len, d_model]``.
        """
        compressed = self.embed(input_ids)  # [..., d_embed_rank]
        return self.proj(compressed)        # [..., d_model]

    def get_num_params(self):
        """
        Return the total number of parameters in this module.

        Counted from the registered parameters rather than a hand-written
        formula, so the result stays correct if the layer layout changes
        (e.g. a bias is added to ``proj``). For the current layout it equals
        ``vocab_size * d_embed_rank + d_embed_rank * d_model``.
        """
        return sum(p.numel() for p in self.parameters())