metadata
language:
- zh
https://github.com/ximeiorg/predictive-text
本项目支持将训练好的中文预测性文本模型导出为 ONNX 格式,可在移动端、边缘设备和服务器上高效部署推理。
使用示例
Python 中使用 ONNX 推理
import numpy as np
import onnxruntime as ort
# 加载 ONNX 模型
session = ort.InferenceSession("mobile/small/model.onnx")
# 获取输入输出名称
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# 准备输入 (token IDs)
input_ids = np.array([[1, 234, 567, 890]], dtype=np.int64)
# 推理
logits = session.run([output_name], {input_name: input_ids})[0]
# 获取最后一个位置的预测
last_logits = logits[0, -1, :]
# Top-5 候选词
top_5_indices = np.argsort(last_logits)[-5:][::-1]
print("Top-5 预测:", top_5_indices)
验证模型正确性
# 对比 PyTorch 和 ONNX 推理结果
uv run scripts/test_onnx_inference.py \
--checkpoint output/small/best_model.pt \
--tokenizer data/tokenizer.json \
--test-dynamic
# 完整验证
uv run scripts/verify_model.py --onnx mobile/small/model.onnx
量化支持
INT8 动态量化
# 导出量化模型
uv run src/export.py \
--checkpoint output/small/best_model.pt \
--tokenizer data/tokenizer.json \
--format quantized
校准量化 (推荐)
使用校准数据的静态量化可以更好地保持精度:
uv run scripts/quantize_with_calibration.py \
--checkpoint output/small/best_model.pt \
--train-data data/train.bin \
--vocab data/vocab.json \
--method onnx_static \
--calibration-samples 200