mesko-tts / bio_llm /model /embedding.py
mesklintech's picture
Publish BioVoice-TTS sparse energy checkpoint and model card
424c56c verified
raw
history blame contribute delete
543 Bytes
import torch
from torch import nn
class TokenEmbedding(nn.Module):
    """Map token ids to continuous vectors: ``x_t = W_embed[token]``.

    Thin wrapper around ``nn.Embedding`` whose table is re-initialised from
    a zero-mean normal distribution instead of PyTorch's default N(0, 1).

    Args:
        vocab_size: Number of rows in the embedding table (one per token id).
        d_model: Size of each embedding vector.
        init_std: Standard deviation of the normal initialisation. Defaults
            to 0.02, the value this module has always used.
    """

    def __init__(self, vocab_size: int, d_model: int, *, init_std: float = 0.02):
        super().__init__()
        # The attribute name is part of the checkpoint format: the parameter
        # is stored under the state-dict key "embedding.weight".
        self.embedding = nn.Embedding(vocab_size, d_model)
        nn.init.normal_(self.embedding.weight, mean=0.0, std=init_std)

    @property
    def weight(self) -> torch.Tensor:
        """Return the embedding table, shape ``(vocab_size, d_model)``.

        This is the live parameter, not a copy.
        NOTE(review): presumably exposed so an output projection can share
        it (weight tying) -- confirm against the callers.
        """
        return self.embedding.weight

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embedding vector of every id in ``token_ids``.

        Args:
            token_ids: Integer tensor of token ids (any shape); each id must
                lie in ``[0, vocab_size)``, as required by ``nn.Embedding``.

        Returns:
            Tensor of shape ``(*token_ids.shape, d_model)``.
        """
        return self.embedding(token_ids)