"""嵌入层:Token Embedding 词嵌入""" # 2026-01-23 import torch.nn as nn class TokenEmbedding(nn.Module): """词嵌入""" def __init__(self, vocab_size, hidden_size): """ 初始化词嵌入层 参数: vocab_size: 词汇表大小 hidden_size: 隐藏层维度 """ super().__init__() self.vocab_size = vocab_size self.hidden_size = hidden_size # 词嵌入层:将 token ID 映射为向量 self.embedding = nn.Embedding(vocab_size, hidden_size) def forward(self, input_ids): """ 前向传播 参数: input_ids: Token ID 张量,形状为 (batch_size, seq_len) 返回: 嵌入向量,形状为 (batch_size, seq_len, hidden_size) """ return self.embedding(input_ids)