"""

Low-Rank Factorized Embedding. 



IMPORTANT: Uses standard nn.Linear for projection, NOT BitLinear. 

Embeddings need full precision for good token representations.

"""

import torch
import torch.nn as nn

class FactorizedEmbedding(nn.Module):
    """

    Low-Rank Factorized Embedding:  vocab → d_embed_rank → d_model

    

    Uses standard Linear (not BitLinear) for the projection.

    Embeddings are memory lookups - they benefit from full precision.

    """
    
    def __init__(self, vocab_size, d_model, d_embed_rank=96):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_embed_rank = d_embed_rank
        
        # Embedding table: vocab → compressed
        self.embed = nn.Embedding(vocab_size, d_embed_rank)
        
        # Projection: compressed → full (standard Linear, NOT BitLinear)
        self.proj = nn.Linear(d_embed_rank, d_model, bias=False)
        
        # Initialize
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)
        
        print(f"FactorizedEmbedding:  {vocab_size} × {d_embed_rank}{d_model}")
        print(f"  Params: {self.get_num_params()/1e6:.2f}M (vs {vocab_size * d_model/1e6:.2f}M dense)")
    
    def forward(self, input_ids):
        x = self.embed(input_ids)  # [B, S, d_embed_rank]
        x = self.proj(x)           # [B, S, d_model]
        return x
    
    def get_num_params(self):
        return self.vocab_size * self.d_embed_rank + self.d_embed_rank * self.d_model
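

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the vocab_size/d_model values
    # below are assumptions, not taken from the rest of the project).
    emb = FactorizedEmbedding(vocab_size=32000, d_model=768, d_embed_rank=96)

    input_ids = torch.randint(0, 32000, (2, 16))  # [B=2, S=16]
    out = emb(input_ids)
    print(out.shape)  # torch.Size([2, 16, 768])

    # Parameter comparison for these assumed sizes:
    #   dense embedding:      32000 * 768          ≈ 24.58M params
    #   factorized embedding: 32000 * 96 + 96*768  ≈  3.15M params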