Spaces:
Build error
Build error
import gradio as gr
import torch

from translation import load_model, load_vocab, translate

# Passed to translate() as its final argument; presumably the maximum
# sequence length — confirm against translation.translate.
MAX_SEQ_LEN = 60

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
# NOTE(review): DEVICE is not referenced anywhere else in the visible code —
# confirm whether load_model/translate handle device placement on their own.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_PATH = "./translation/model.pth"  # path to the model weights
SRC_VOCAB_PATH = "./translation/word2int_en.json"  # path to the English vocabulary
TGT_VOCAB_PATH = "./translation/word2int_cn.json"  # path to the Chinese vocabulary

# Load the vocabularies (source first, then target).
src_vocab = load_vocab(SRC_VOCAB_PATH)
tgt_vocab = load_vocab(TGT_VOCAB_PATH)

# Load the model; the two vocabulary sizes are handed to load_model alongside
# the weights path.
model = load_model(
    MODEL_PATH,
    len(src_vocab),
    len(tgt_vocab),
)
# Wrap the translation function so it can serve as the Gradio callback.
def translate_sentence(input_sentence):
    """Translate one sentence using the module-level model and vocabularies.

    Args:
        input_sentence: Text submitted through the Gradio input textbox.
            May be ``None`` or blank when the user submits nothing.

    Returns:
        The value returned by ``translate`` for the sentence, or an empty
        string when the input is ``None``, empty, or whitespace-only.
    """
    # A blank submission carries nothing to translate; answer directly
    # instead of running the model on an empty sequence.
    if input_sentence is None or not input_sentence.strip():
        return ""
    return translate(model, input_sentence, src_vocab, tgt_vocab, MAX_SEQ_LEN)
# Build the Gradio interface: a two-line text input wired to
# translate_sentence, with a plain textbox for the output.
# NOTE(review): "Tranformer" in the title looks like a typo for "Transformer";
# the string is reproduced unchanged here.
_input_box = gr.Textbox(lines=2, placeholder="Enter English sentence here...")
_output_box = gr.Textbox()

iface = gr.Interface(
    fn=translate_sentence,
    inputs=_input_box,
    outputs=_output_box,
    title="NLP作业:基于Tranformer的机器翻译系统",
    description="输入英文输出中文喵",
)

# Start the Gradio app.
iface.launch()