https://github.com/ximeiorg/predictive-text

本项目支持将训练好的中文预测性文本模型导出为 ONNX 格式,可在移动端、边缘设备和服务器上高效部署推理。

使用示例

Python 中使用 ONNX 推理

import numpy as np
import onnxruntime as ort

# 加载 ONNX 模型
session = ort.InferenceSession("mobile/small/model.onnx")

# 获取输入输出名称
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# 准备输入 (token IDs)
input_ids = np.array([[1, 234, 567, 890]], dtype=np.int64)

# 推理
logits = session.run([output_name], {input_name: input_ids})[0]

# 获取最后一个位置的预测
last_logits = logits[0, -1, :]

# Top-5 候选词
top_5_indices = np.argsort(last_logits)[-5:][::-1]
print("Top-5 预测:", top_5_indices)

验证模型正确性

# 对比 PyTorch 和 ONNX 推理结果
uv run scripts/test_onnx_inference.py \
    --checkpoint output/small/best_model.pt \
    --tokenizer data/tokenizer.json \
    --test-dynamic

# 完整验证
uv run scripts/verify_model.py --onnx mobile/small/model.onnx

量化支持

INT8 动态量化

# 导出量化模型
uv run src/export.py \
    --checkpoint output/small/best_model.pt \
    --tokenizer data/tokenizer.json \
    --format quantized

校准量化 (推荐)

使用校准数据的静态量化可以更好地保持精度:

uv run scripts/quantize_with_calibration.py \
    --checkpoint output/small/best_model.pt \
    --train-data data/train.bin \
    --vocab data/vocab.json \
    --method onnx_static \
    --calibration-samples 200
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support