"""
Low-Rank Factorized Embedding.

IMPORTANT: Uses standard nn.Linear for projection, NOT BitLinear.
Embeddings need full precision for good token representations.
"""

import torch
import torch.nn as nn


class FactorizedEmbedding(nn.Module):
    """
    Low-Rank Factorized Embedding: vocab → d_embed_rank → d_model

    Token ids are looked up in a small ``vocab_size x d_embed_rank`` table and
    then projected up to ``d_model`` with a bias-free linear layer. Compared to
    a dense ``vocab_size x d_model`` table this stores
    ``vocab_size * d_embed_rank + d_embed_rank * d_model`` parameters.

    Uses standard Linear (not BitLinear) for the projection.
    Embeddings are memory lookups - they benefit from full precision.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Output feature size of the projection.
        d_embed_rank: Inner (compressed) embedding width.
        verbose: If True (default, matching the historical behavior), print a
            one-line summary of the factorization and its parameter count
            on construction.

    Raises:
        ValueError: If any of the three dimensions is smaller than 1.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        d_embed_rank: int = 96,
        *,
        verbose: bool = True,
    ):
        super().__init__()
        # Fail fast with a readable message. A zero dimension would otherwise
        # build an empty, useless layer, and a negative one would fail inside torch.
        for name, value in (
            ("vocab_size", vocab_size),
            ("d_model", d_model),
            ("d_embed_rank", d_embed_rank),
        ):
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_embed_rank = d_embed_rank

        # Embedding table: vocab → compressed
        self.embed = nn.Embedding(vocab_size, d_embed_rank)
        # Projection: compressed → full (standard Linear, NOT BitLinear)
        self.proj = nn.Linear(d_embed_rank, d_model, bias=False)

        # Initialize both factors with the same small normal distribution.
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)

        if verbose:
            print(f"FactorizedEmbedding: {vocab_size} × {d_embed_rank} → {d_model}")
            print(f"  Params: {self.get_num_params()/1e6:.2f}M (vs {vocab_size * d_model/1e6:.2f}M dense)")

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and project to ``d_model``.

        The lookup appends a trailing ``d_embed_rank`` axis to whatever shape
        ``input_ids`` has. The projection then maps that last axis to
        ``d_model``. For example, ``[B, S]`` ids give a ``[B, S, d_model]`` output.
        """
        compressed = self.embed(input_ids)  # [..., d_embed_rank]
        return self.proj(compressed)  # [..., d_model]

    def get_num_params(self) -> int:
        """Return the total number of parameters actually held by this module.

        The count is taken from the registered parameters rather than a
        hand-written formula, so it stays correct if a layer gains a bias or a
        subclass adds parameters.
        """
        return sum(p.numel() for p in self.parameters())

    def extra_repr(self) -> str:
        # Shown by print(module) / repr(module) to summarize the factorization.
        return (
            f"vocab_size={self.vocab_size}, d_embed_rank={self.d_embed_rank}, "
            f"d_model={self.d_model}"
        )