"""Low-rank factorized embedding.

The projection uses a standard ``nn.Linear`` (NOT ternary quantization):
embeddings should stay in full precision for good token representations.
"""

import torch
import torch.nn as nn


class FactorizedEmbedding(nn.Module):
    """Low-rank factorized embedding: vocab → d_embed_rank → d_model.

    Token IDs are looked up in a narrow embedding table and then projected
    up to the model width by a bias-free, unquantized linear layer.

    This reduces the embedding parameter count from
    ``vocab_size × d_model`` to
    ``vocab_size × d_embed_rank + d_embed_rank × d_model``.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Width of the returned embeddings.
        d_embed_rank: Width of the intermediate (compressed) embedding.
    """

    def __init__(self, vocab_size: int, d_model: int, d_embed_rank: int = 96) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_embed_rank = d_embed_rank

        # Embedding table: vocab → compressed rank.
        self.embed = nn.Embedding(vocab_size, d_embed_rank)

        # Projection: compressed rank → full width (plain Linear, no bias).
        self.proj = nn.Linear(d_embed_rank, d_model, bias=False)

        # Initialize both factors with small weights for stable training.
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token IDs and project them to ``d_model``.

        Args:
            input_ids: Integer tensor of token IDs, e.g.
                ``[batch_size, seq_len]``.

        Returns:
            Tensor shaped like ``input_ids`` with a trailing ``d_model``
            dimension, e.g. ``[batch_size, seq_len, d_model]``.
        """
        x = self.embed(input_ids)  # [B, S, d_embed_rank]
        x = self.proj(x)  # [B, S, d_model]
        return x

    def get_num_params(self) -> int:
        """Return the total parameter count of the table and the projection.

        Computed in closed form from the constructor arguments:
        ``vocab_size * d_embed_rank + d_embed_rank * d_model``.
        """
        return self.vocab_size * self.d_embed_rank + self.d_embed_rank * self.d_model