| """嵌入层:Token Embedding 词嵌入""" | |
| # 2026-01-23 | |
| import torch.nn as nn | |
class TokenEmbedding(nn.Module):
    """Token embedding layer: map integer token IDs to dense vectors.

    Thin wrapper around ``nn.Embedding`` that also records the vocabulary
    size, hidden size and padding index as attributes.
    """

    def __init__(self, vocab_size, hidden_size, padding_idx=None):
        """Initialize the token embedding layer.

        Args:
            vocab_size: Size of the vocabulary, i.e. the number of rows in
                the embedding table.
            hidden_size: Dimension of each embedding vector.
            padding_idx: Optional token ID forwarded to ``nn.Embedding``.
                That row is initialized to zeros and receives no gradient
                updates. Defaults to ``None`` (no padding token), which keeps
                the original behavior.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.padding_idx = padding_idx
        # Lookup table: row i holds the learned vector for token ID i.
        self.embedding = nn.Embedding(
            vocab_size, hidden_size, padding_idx=padding_idx
        )

    def forward(self, input_ids):
        """Look up the embedding vector for every token ID.

        Args:
            input_ids: Integer tensor of token IDs, typically of shape
                ``(batch_size, seq_len)``. ``nn.Embedding`` accepts any shape.

        Returns:
            Tensor of shape ``(*input_ids.shape, hidden_size)``, e.g.
            ``(batch_size, seq_len, hidden_size)`` for 2-D input.
        """
        return self.embedding(input_ids)