"""
Low-Rank Factorized Embedding.

IMPORTANT: Uses standard nn.Linear for projection, NOT BitLinear.
Embeddings need full precision for good token representations.
"""

import torch
import torch.nn as nn


class FactorizedEmbedding(nn.Module):
    """
    Low-Rank Factorized Embedding: vocab → d_embed_rank → d_model

    Uses standard Linear (not BitLinear) for the projection.
    Embeddings are memory lookups - they benefit from full precision.
    """

    def __init__(self, vocab_size, d_model, d_embed_rank=96):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_embed_rank = d_embed_rank

        # Low-rank lookup table: vocab_size × d_embed_rank
        self.embed = nn.Embedding(vocab_size, d_embed_rank)

        # Full-precision up-projection to model width: d_embed_rank → d_model
        self.proj = nn.Linear(d_embed_rank, d_model, bias=False)

        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)

        print(f"FactorizedEmbedding: {vocab_size} × {d_embed_rank} → {d_model}")
        print(f"  Params: {self.get_num_params()/1e6:.2f}M (vs {vocab_size * d_model/1e6:.2f}M dense)")

    def forward(self, input_ids):
        x = self.embed(input_ids)  # (batch, seq_len) → (batch, seq_len, d_embed_rank)
        x = self.proj(x)           # (batch, seq_len, d_embed_rank) → (batch, seq_len, d_model)
        return x

    def get_num_params(self):
        # Embedding table + projection matrix (no bias term).
        return self.vocab_size * self.d_embed_rank + self.d_embed_rank * self.d_model
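

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the vocab/model sizes below are
# hypothetical placeholders, not values taken from any particular config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    emb = FactorizedEmbedding(vocab_size=32000, d_model=768, d_embed_rank=96)
    ids = torch.randint(0, 32000, (2, 16))  # (batch=2, seq_len=16) token ids
    out = emb(ids)                           # → (2, 16, 768)
    assert out.shape == (2, 16, 768)
    # Factorized: 32000*96 + 96*768 ≈ 3.15M params vs 32000*768 ≈ 24.58M dense.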