# TaoNet-pico-T1 / factorized_embedding.py
"""
Low-Rank Factorized Embedding.
IMPORTANT: Uses standard nn.Linear for projection, NOT BitLinear.
Embeddings need full precision for good token representations.
"""
import torch
import torch.nn as nn
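
# Why factorize? A quick illustrative parameter count (the sizes below are
# assumed for the sake of example, not taken from the TaoNet-pico-T1 config):
#   dense embedding:      vocab_size * d_model               = 50257 * 768           ~ 38.60M params
#   factorized embedding: vocab_size * rank + rank * d_model = 50257 * 96 + 96 * 768 ~  4.90M params
# i.e. roughly an 8x reduction in embedding parameters at rank 96.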
class FactorizedEmbedding(nn.Module):
    """
    Low-Rank Factorized Embedding: vocab → d_embed_rank → d_model.

    Uses standard nn.Linear (not BitLinear) for the projection:
    embeddings are memory lookups and benefit from full precision.
    """

    def __init__(self, vocab_size, d_model, d_embed_rank=96):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_embed_rank = d_embed_rank

        # Embedding table: vocab → compressed rank
        self.embed = nn.Embedding(vocab_size, d_embed_rank)

        # Projection: compressed rank → full width (standard Linear, NOT BitLinear)
        self.proj = nn.Linear(d_embed_rank, d_model, bias=False)

        # Initialize both factors
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)

        print(f"FactorizedEmbedding: {vocab_size} × {d_embed_rank} → {d_model}")
        print(f"  Params: {self.get_num_params()/1e6:.2f}M (vs {vocab_size * d_model/1e6:.2f}M dense)")

    def forward(self, input_ids):
        x = self.embed(input_ids)  # [B, S, d_embed_rank]
        x = self.proj(x)           # [B, S, d_model]
        return x

    def get_num_params(self):
        # Embedding table + bias-free projection matrix
        return self.vocab_size * self.d_embed_rank + self.d_embed_rank * self.d_model
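

# Usage sketch (illustrative, not part of the original upload; the sizes below
# are assumptions, not the TaoNet-pico-T1 configuration):
if __name__ == "__main__":
    vocab_size, d_model = 50257, 768  # assumed GPT-2-style sizes
    emb = FactorizedEmbedding(vocab_size, d_model, d_embed_rank=96)

    # Dummy batch of token ids: [batch=2, seq_len=16]
    input_ids = torch.randint(0, vocab_size, (2, 16))
    out = emb(input_ids)
    print(out.shape)  # torch.Size([2, 16, 768])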