"""
Low-Rank Factorized Embedding.
IMPORTANT: Uses standard nn.Linear for projection, NOT BitLinear.
Embeddings need full precision for good token representations.
"""
import torch
import torch.nn as nn


class FactorizedEmbedding(nn.Module):
    """
    Low-Rank Factorized Embedding: vocab → d_embed_rank → d_model

    Uses standard Linear (not BitLinear) for the projection.
    Embeddings are memory lookups; they benefit from full precision.
    """

    def __init__(self, vocab_size, d_model, d_embed_rank=96):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_embed_rank = d_embed_rank

        # Embedding table: vocab → compressed rank
        self.embed = nn.Embedding(vocab_size, d_embed_rank)
        # Projection: compressed rank → full width (standard Linear, NOT BitLinear)
        self.proj = nn.Linear(d_embed_rank, d_model, bias=False)

        # Initialize both factors with a small normal
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)

        print(f"FactorizedEmbedding: {vocab_size} × {d_embed_rank} → {d_model}")
        print(f"  Params: {self.get_num_params()/1e6:.2f}M "
              f"(vs {vocab_size * d_model/1e6:.2f}M dense)")

    def forward(self, input_ids):
        x = self.embed(input_ids)  # [B, S, d_embed_rank]
        x = self.proj(x)           # [B, S, d_model]
        return x

    def get_num_params(self):
        # vocab × rank (embedding table) + rank × d_model (projection)
        return self.vocab_size * self.d_embed_rank + self.d_embed_rank * self.d_model
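

# --- Usage sketch (illustrative; not part of the original module) ---
# Quick sanity check of the output shapes and the parameter savings.
# The values vocab_size=50257 and d_model=768 are assumed GPT-2-style
# sizes, not taken from this file.
if __name__ == "__main__":
    emb = FactorizedEmbedding(vocab_size=50257, d_model=768, d_embed_rank=96)
    input_ids = torch.randint(0, 50257, (2, 16))  # [B=2, S=16] random token ids
    out = emb(input_ids)
    assert out.shape == (2, 16, 768)  # [B, S, d_model]
    # Factorized: 50257*96 + 96*768 ≈ 4.90M params vs 50257*768 ≈ 38.60M dense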