| --- |
| language: |
| - zh |
| --- |
| |
| [https://github.com/ximeiorg/predictive-text](https://github.com/ximeiorg/predictive-text) |
|
|
| 本项目支持将训练好的中文预测性文本模型导出为 ONNX 格式,可在移动端、边缘设备和服务器上高效部署推理。 |
|
|
| ## 使用示例 |
|
|
| ### Python 中使用 ONNX 推理 |
|
|
| ```python |
| import numpy as np |
| import onnxruntime as ort |
| |
| # 加载 ONNX 模型 |
| session = ort.InferenceSession("mobile/small/model.onnx") |
| |
| # 获取输入输出名称 |
| input_name = session.get_inputs()[0].name |
| output_name = session.get_outputs()[0].name |
| |
| # 准备输入 (token IDs) |
| input_ids = np.array([[1, 234, 567, 890]], dtype=np.int64) |
| |
| # 推理 |
| logits = session.run([output_name], {input_name: input_ids})[0] |
| |
| # 获取最后一个位置的预测 |
| last_logits = logits[0, -1, :] |
| |
| # Top-5 候选词 |
| top_5_indices = np.argsort(last_logits)[-5:][::-1] |
| print("Top-5 预测:", top_5_indices) |
| ``` |
|
|
| ### 验证模型正确性 |
|
|
| ```bash |
| # 对比 PyTorch 和 ONNX 推理结果 |
| uv run scripts/test_onnx_inference.py \ |
| --checkpoint output/small/best_model.pt \ |
| --tokenizer data/tokenizer.json \ |
| --test-dynamic |
| |
| # 完整验证 |
| uv run scripts/verify_model.py --onnx mobile/small/model.onnx |
| ``` |
|
|
| --- |
|
|
| ## 量化支持 |
|
|
| ### INT8 动态量化 |
|
|
| ```bash |
| # 导出量化模型 |
| uv run src/export.py \ |
| --checkpoint output/small/best_model.pt \ |
| --tokenizer data/tokenizer.json \ |
| --format quantized |
| ``` |
|
|
| ### 校准量化 (推荐) |
|
|
| 使用校准数据的静态量化可以更好地保持精度: |
|
|
| ```bash |
| uv run scripts/quantize_with_calibration.py \ |
| --checkpoint output/small/best_model.pt \ |
| --train-data data/train.bin \ |
| --vocab data/vocab.json \ |
| --method onnx_static \ |
| --calibration-samples 200 |
| ``` |
|
|