Spaces:
Build error
Build error
| import torch | |
| import json | |
| import torch.nn as nn | |
| from .model import TransformerModel # 确保与训练代码的模型定义一致 | |
# Configuration
SRC_VOCAB_PATH = "word2int_en.json"  # path of the English vocabulary file
TGT_VOCAB_PATH = "word2int_cn.json"  # path of the Chinese vocabulary file
MODEL_PATH = "./results/model/model5.pth"  # path of the model weights
MAX_SEQ_LEN = 60
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load a vocabulary
def load_vocab(file_path):
    """Parse the UTF-8 encoded JSON file at ``file_path`` and return its content.

    Args:
        file_path: Path of the JSON vocabulary file.

    Returns:
        The object produced by ``json.load`` for that file.
    """
    with open(file_path, encoding='utf-8') as vocab_file:
        vocab = json.load(vocab_file)
    return vocab
# Encode an input sentence
def encode_sentence(sentence, vocab, max_len):
    """Turn a whitespace-separated sentence into a padded tensor of token ids.

    Tokens missing from ``vocab`` map to ``vocab["<UNK>"]``. At most
    ``max_len - 2`` tokens are kept so that ``<BOS>`` and ``<EOS>`` fit, and
    the sequence is right-padded with ``<PAD>`` up to ``max_len``.

    Args:
        sentence: Input text; tokens are obtained with ``str.split()``.
        vocab: Mapping from token to integer id; must contain the
            ``<UNK>``, ``<BOS>``, ``<EOS>`` and ``<PAD>`` entries.
        max_len: Target length of the encoded sequence.

    Returns:
        A ``torch.long`` tensor with a leading batch dimension of size 1,
        placed on ``DEVICE``.
    """
    token_ids = []
    for word in sentence.split():
        token_ids.append(vocab.get(word, vocab["<UNK>"]))
    # Reserve two positions for the <BOS>/<EOS> markers.
    sequence = [vocab["<BOS>"], *token_ids[:max_len - 2], vocab["<EOS>"]]
    padding = [vocab["<PAD>"]] * (max_len - len(sequence))
    sequence.extend(padding)
    # unsqueeze(0) adds the batch dimension.
    batch = torch.tensor(sequence, dtype=torch.long).unsqueeze(0)
    return batch.to(DEVICE)
# Decode an output sequence
def decode_sentence(ids, vocab):
    """Convert a sequence of token ids back into a string.

    The ``<PAD>``, ``<BOS>`` and ``<EOS>`` ids are dropped; the remaining
    tokens are concatenated without a separator (Chinese text needs no
    spaces between tokens).

    Args:
        ids: Iterable of integer token ids.
        vocab: Mapping from token to integer id; must contain the
            ``<PAD>``, ``<BOS>`` and ``<EOS>`` entries.

    Returns:
        The decoded sentence as a single string.

    Raises:
        KeyError: If a non-special id has no entry in ``vocab``.
    """
    # Built once up front instead of once per id inside the comprehension.
    special_ids = {vocab["<PAD>"], vocab["<BOS>"], vocab["<EOS>"]}
    reverse_vocab = {idx: word for word, idx in vocab.items()}
    return "".join(
        [reverse_vocab[token_id] for token_id in ids if token_id not in special_ids]
    )
# Load the model
def load_model(model_path, src_vocab_size, tgt_vocab_size):
    """Build a ``TransformerModel``, load its weights and switch it to eval mode.

    Args:
        model_path: Path of the saved state dict.
        src_vocab_size: Source vocabulary size passed to ``TransformerModel``.
        tgt_vocab_size: Target vocabulary size passed to ``TransformerModel``.

    Returns:
        The model on ``DEVICE``, in evaluation mode.
    """
    network = TransformerModel(src_vocab_size, tgt_vocab_size)
    network = network.to(DEVICE)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    state_dict = torch.load(model_path, map_location=DEVICE)
    network.load_state_dict(state_dict)
    network.eval()
    return network
# Translation function
def translate(model, sentence, src_vocab, tgt_vocab, max_len):
    """Translate ``sentence`` with greedy, token-by-token decoding.

    The source sentence is encoded with ``encode_sentence``. The target
    sequence starts as ``<BOS>`` and grows by the arg-max token of the last
    time step until ``<EOS>`` is produced or ``max_len`` tokens have been
    generated.

    Args:
        model: Callable invoked as ``model(src, tgt, tgt_mask=...)``.
            Assumes it returns scores shaped (batch, tgt_len, vocab), as
            implied by the ``output[:, -1, :]`` indexing — TODO confirm
            against the model definition.
        sentence: Source sentence; tokens are separated by whitespace.
        src_vocab: Source token-to-id mapping.
        tgt_vocab: Target token-to-id mapping; must contain ``<BOS>`` and
            ``<EOS>``, plus ``<PAD>`` for decoding.
        max_len: Source encoding length and maximum number of generated tokens.

    Returns:
        The decoded target sentence as a string.
    """
    # Encode the input sentence.
    src_tensor = encode_sentence(sentence, src_vocab, max_len)

    # Start the target sequence with <BOS> only (batch dimension of size 1).
    tgt_tensor = torch.tensor(
        [[tgt_vocab["<BOS>"]]], dtype=torch.long, device=DEVICE
    )
    eos_id = tgt_vocab["<EOS>"]

    # Inference only: without no_grad every step would record an autograd
    # graph, wasting memory and time for gradients that are never used.
    with torch.no_grad():
        for _ in range(max_len):
            # Causal mask over the target sequence generated so far.
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(
                tgt_tensor.size(1)
            ).to(DEVICE)
            output = model(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
            # Greedy choice: highest-scoring token at the last time step.
            next_token = output[:, -1, :].argmax(dim=-1).item()
            next_step = torch.tensor(
                [[next_token]], dtype=torch.long, device=DEVICE
            )
            tgt_tensor = torch.cat([tgt_tensor, next_step], dim=1)
            # Stop as soon as <EOS> has been predicted.
            if next_token == eos_id:
                break

    # Decode the generated ids back into a sentence.
    return decode_sentence(tgt_tensor.squeeze(0).tolist(), tgt_vocab)