LLM-fastAPI / models /transformer /module /embeddings.py
Songyou's picture
add files
f3b11f9
raw
history blame contribute delete
358 Bytes
import torch.nn as nn
import math
class Embeddings(nn.Module):
    """Token-embedding layer whose output is multiplied by sqrt(d_model).

    Holds an ``nn.Embedding`` lookup table with ``vocab`` rows, each a
    learned vector of length ``d_model``. Every looked-up vector is scaled
    by ``sqrt(d_model)``. The scaling presumably mirrors the convention from
    the original Transformer paper; confirm against the rest of the model.

    Args:
        d_model: Width of each embedding vector.
        vocab: Number of rows in the lookup table (vocabulary size).
    """

    def __init__(self, d_model: int, vocab: int):
        super().__init__()
        # Lookup table: row i is the learned vector for token id i.
        # The attribute name ``lut`` determines the state_dict key
        # ("lut.weight"), so keep it stable for checkpoint compatibility.
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Look up the embeddings for ``x`` and scale them by sqrt(d_model).

        Args:
            x: Integer tensor of token indices, each expected in
                ``[0, vocab)``; any shape accepted by ``nn.Embedding``.

        Returns:
            Tensor of shape ``(*x.shape, d_model)``.
        """
        embedded = self.lut(x)
        scale = math.sqrt(self.d_model)
        return embedded * scale