Initial release: Qwen3-Reranker-4B CoreML ANE-optimized bundle + service
- .gitattributes +1 -0
- README.md +133 -0
- bundles/qwen3_reranker_ane_bundle_4b/manifest.json +29 -0
- bundles/qwen3_reranker_ane_bundle_4b/packages/b1_s128.mlpackage/Data/com.apple.CoreML/model.mlmodel +3 -0
- bundles/qwen3_reranker_ane_bundle_4b/packages/b1_s128.mlpackage/Data/com.apple.CoreML/weights/weight.bin +3 -0
- bundles/qwen3_reranker_ane_bundle_4b/packages/b1_s128.mlpackage/Manifest.json +18 -0
- bundles/qwen3_reranker_ane_bundle_4b/tokenizer/chat_template.jinja +85 -0
- bundles/qwen3_reranker_ane_bundle_4b/tokenizer/tokenizer.json +3 -0
- bundles/qwen3_reranker_ane_bundle_4b/tokenizer/tokenizer_config.json +14 -0
- qwen3_ane_rerank/__init__.py +5 -0
- qwen3_ane_rerank/__main__.py +5 -0
- qwen3_ane_rerank/api.py +84 -0
- qwen3_ane_rerank/cli.py +174 -0
- qwen3_ane_rerank/converter.py +267 -0
- qwen3_ane_rerank/manifest.py +109 -0
- qwen3_ane_rerank/profiles.py +47 -0
- qwen3_ane_rerank/runtime.py +278 -0
- requirements-service.txt +5 -0
- run_server.sh +74 -0
- setup_venv.sh +33 -0
.gitattributes
CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+bundles/qwen3_reranker_ane_bundle_4b/tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
ADDED

@@ -0,0 +1,133 @@
---
license: apache-2.0
language:
- en
- zh
tags:
- qwen3
- reranker
- coreml
- apple-silicon
- ane
pipeline_tag: text-ranking
library_name: coremltools
base_model: Qwen/Qwen3-Reranker-4B
---

# Qwen3-Reranker-4B-CoreML (ANE-Optimized)

## English

This repository provides a pre-converted CoreML bundle derived from `Qwen3-Reranker-4B` and an OpenAI-style rerank API service for Apple Silicon.

### Bundle Specs

| Item | Value |
| --- | --- |
| Base model | `Qwen/Qwen3-Reranker-4B` |
| Task | Text reranking |
| Profiles | `b1_s128` |
| Bundle path | `bundles/qwen3_reranker_ane_bundle_4b` |
| Default model id | `qwen3-reranker-4b-ane` |
| Package size (approx.) | `7.5G` |

### Scope

- This release is **text-only reranking**.
- Endpoints: `POST /rerank` and `POST /v1/rerank`.

### Quick Start

```bash
./setup_venv.sh
./run_server.sh
```

Health check:

```bash
curl -s http://127.0.0.1:8000/health
```

Rerank request:

```bash
curl -s http://127.0.0.1:8000/v1/rerank \
  -H 'Content-Type: application/json' \
  -d '{
    "query": "capital of China",
    "documents": [
      "The capital of China is Beijing.",
      "Gravity is a force."
    ],
    "top_n": 2,
    "return_documents": true
  }'
```

### Notes

- Fixed-shape profile (`s128`) for low-power deployment.
- Inputs longer than the profile capacity return an explicit error.
- The first request has warm-up latency.
- The default compute setting is `cpu_and_ne` (ANE-preferred, not ANE-guaranteed).
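The curl request above can also be assembled from Python. A minimal sketch — the helper name `make_rerank_payload` is illustrative and not part of this package; it only builds the JSON body, which you would then POST to the running server (e.g. with `urllib.request` or `requests`):

```python
import json

def make_rerank_payload(query, documents, top_n=None, return_documents=False):
    # Mirror the request fields accepted by /v1/rerank as shown in the curl example.
    payload = {
        "query": query,
        "documents": list(documents),
        "return_documents": return_documents,
    }
    if top_n is not None:
        payload["top_n"] = top_n
    return payload

payload = make_rerank_payload(
    "capital of China",
    ["The capital of China is Beijing.", "Gravity is a force."],
    top_n=2,
    return_documents=True,
)
body = json.dumps(payload)
# POST `body` to http://127.0.0.1:8000/v1/rerank with Content-Type: application/json.
```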
+
## 中文
|
| 76 |
+
|
| 77 |
+
这个仓库提供基于 `Qwen3-Reranker-4B` 的预转换 CoreML bundle,以及可直接运行的文本重排服务(`/v1/rerank`)。
|
| 78 |
+
|
| 79 |
+
### Bundle 规格
|
| 80 |
+
|
| 81 |
+
| 项目 | 值 |
|
| 82 |
+
| --- | --- |
|
| 83 |
+
| 基础模型 | `Qwen/Qwen3-Reranker-4B` |
|
| 84 |
+
| 任务类型 | 文本重排 |
|
| 85 |
+
| Profile | `b1_s128` |
|
| 86 |
+
| Bundle 路径 | `bundles/qwen3_reranker_ane_bundle_4b` |
|
| 87 |
+
| 默认模型名 | `qwen3-reranker-4b-ane` |
|
| 88 |
+
| 包体积(约) | `7.5G` |
|
| 89 |
+
|
| 90 |
+
### 范围说明
|
| 91 |
+
|
| 92 |
+
- 本版本仅支持**纯文本重排**。
|
| 93 |
+
- 接口为 `POST /rerank` 与 `POST /v1/rerank`。
|
| 94 |
+
|
| 95 |
+
### 快速开始
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
./setup_venv.sh
|
| 99 |
+
./run_server.sh
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
健康检查:
|
| 103 |
+
|
| 104 |
+
```bash
|
| 105 |
+
curl -s http://127.0.0.1:8000/health
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
重排请求:
|
| 109 |
+
|
| 110 |
+
```bash
|
| 111 |
+
curl -s http://127.0.0.1:8000/v1/rerank \
|
| 112 |
+
-H 'Content-Type: application/json' \
|
| 113 |
+
-d '{
|
| 114 |
+
"query": "capital of China",
|
| 115 |
+
"documents": [
|
| 116 |
+
"The capital of China is Beijing.",
|
| 117 |
+
"Gravity is a force."
|
| 118 |
+
],
|
| 119 |
+
"top_n": 2,
|
| 120 |
+
"return_documents": true
|
| 121 |
+
}'
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
### 说明
|
| 125 |
+
|
| 126 |
+
- 固定 shape profile(`s128`),偏向低功耗部署。
|
| 127 |
+
- 输入超过 profile 上限会明确报错。
|
| 128 |
+
- 首次请求会有预热延迟。
|
| 129 |
+
- 默认 `cpu_and_ne`,是偏向 ANE 调度,不等于 100% 仅 ANE 执行。
|
| 130 |
+
|
| 131 |
+
## License
|
| 132 |
+
|
| 133 |
+
Apache-2.0. Please also follow the license and usage terms of the base Qwen model.
|
bundles/qwen3_reranker_ane_bundle_4b/manifest.json
ADDED

@@ -0,0 +1,29 @@
{
  "format_version": 1,
  "task": "rerank",
  "model_name": "Qwen3-Reranker-4B",
  "source_model_dir": "/Volumes/256G/Applications/ANE/Qwen3-Reranker-4B",
  "tokenizer_dir": "tokenizer",
  "hidden_size": 2560,
  "yes_token_id": 9693,
  "no_token_id": 2152,
  "system_prompt": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".",
  "pair_template": "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}",
  "prefix_text": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n",
  "suffix_text": "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
  "created_at_utc": "2026-03-02T15:48:04.694204+00:00",
  "profiles": [
    {
      "profile_id": "b1_s128",
      "batch_size": 1,
      "seq_len": 128,
      "package_path": "packages/b1_s128.mlpackage",
      "compiled_path": null,
      "input_names": [
        "input_ids",
        "attention_mask"
      ],
      "output_name": "score"
    }
  ]
}
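The `prefix_text`, `pair_template`, and `suffix_text` fields in the manifest compose the full scoring prompt for each (query, document) pair. A minimal sketch of that assembly — the string constants are copied from the manifest above, while the function name `build_prompt` and the example instruction are illustrative, not part of this package:

```python
# Field values taken verbatim from the bundle manifest.
PREFIX = ("<|im_start|>system\nJudge whether the Document meets the requirements "
          "based on the Query and the Instruct provided. Note that the answer can "
          "only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n")
PAIR_TEMPLATE = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

def build_prompt(query, document, instruction):
    # prefix + filled pair template + suffix, exactly as the manifest fields describe
    return PREFIX + PAIR_TEMPLATE.format(
        instruction=instruction, query=query, document=document
    ) + SUFFIX

prompt = build_prompt(
    "capital of China",
    "The capital of China is Beijing.",
    instruction="Retrieve passages that answer the query",
)
```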
bundles/qwen3_reranker_ane_bundle_4b/packages/b1_s128.mlpackage/Data/com.apple.CoreML/model.mlmodel
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:954e30d62948183fd3323ac8db3b1b2be3526ef788a19028807ce39642c4450b
size 791454
bundles/qwen3_reranker_ane_bundle_4b/packages/b1_s128.mlpackage/Data/com.apple.CoreML/weights/weight.bin
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a76471b81324fce92ce1fe1edfbc9fe964abd77f3f97d84db70a1773f605de92
size 8043733440
bundles/qwen3_reranker_ane_bundle_4b/packages/b1_s128.mlpackage/Manifest.json
ADDED

@@ -0,0 +1,18 @@
{
  "fileFormatVersion": "1.0.0",
  "itemInfoEntries": {
    "82518FB0-9416-470C-98CD-55225EF89B96": {
      "author": "com.apple.CoreML",
      "description": "CoreML Model Specification",
      "name": "model.mlmodel",
      "path": "com.apple.CoreML/model.mlmodel"
    },
    "FB126581-AA8F-423B-B9F0-2EBA683BA825": {
      "author": "com.apple.CoreML",
      "description": "CoreML Model Weights",
      "name": "weights",
      "path": "com.apple.CoreML/weights"
    }
  },
  "rootModelIdentifier": "82518FB0-9416-470C-98CD-55225EF89B96"
}
bundles/qwen3_reranker_ane_bundle_4b/tokenizer/chat_template.jinja
ADDED

@@ -0,0 +1,85 @@
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0].role == 'system' %}
        {{- messages[0].content + '\n\n' }}
    {%- endif %}
    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0].role == 'system' %}
        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
    {%- set index = (messages|length - 1) - loop.index0 %}
    {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
        {%- set ns.multi_step_tool = false %}
        {%- set ns.last_query_index = index %}
    {%- endif %}
{%- endfor %}
{%- for message in messages %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {%- set content = message.content %}
        {%- set reasoning_content = '' %}
        {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
            {%- set reasoning_content = message.reasoning_content %}
        {%- else %}
            {%- if '</think>' in message.content %}
                {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
                {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
            {%- endif %}
        {%- endif %}
        {%- if loop.index0 > ns.last_query_index %}
            {%- if loop.last or (not loop.last and reasoning_content) %}
                {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
            {%- else %}
                {{- '<|im_start|>' + message.role + '\n' + content }}
            {%- endif %}
        {%- else %}
            {{- '<|im_start|>' + message.role + '\n' + content }}
        {%- endif %}
        {%- if message.tool_calls %}
            {%- for tool_call in message.tool_calls %}
                {%- if (loop.first and content) or (not loop.first) %}
                    {{- '\n' }}
                {%- endif %}
                {%- if tool_call.function %}
                    {%- set tool_call = tool_call.function %}
                {%- endif %}
                {{- '<tool_call>\n{"name": "' }}
                {{- tool_call.name }}
                {{- '", "arguments": ' }}
                {%- if tool_call.arguments is string %}
                    {{- tool_call.arguments }}
                {%- else %}
                    {{- tool_call.arguments | tojson }}
                {%- endif %}
                {{- '}\n</tool_call>' }}
            {%- endfor %}
        {%- endif %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- message.content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
    {%- if enable_thinking is defined and enable_thinking is false %}
        {{- '<think>\n\n</think>\n\n' }}
    {%- endif %}
{%- endif %}
bundles/qwen3_reranker_ane_bundle_4b/tokenizer/tokenizer.json
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
size 11422650
bundles/qwen3_reranker_ane_bundle_4b/tokenizer/tokenizer_config.json
ADDED

@@ -0,0 +1,14 @@
{
  "add_prefix_space": false,
  "backend": "tokenizers",
  "bos_token": null,
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|im_end|>",
  "errors": "replace",
  "is_local": true,
  "model_max_length": 131072,
  "pad_token": "<|endoftext|>",
  "split_special_tokens": false,
  "tokenizer_class": "Qwen2Tokenizer",
  "unk_token": null
}
qwen3_ane_rerank/__init__.py
ADDED

@@ -0,0 +1,5 @@
"""Qwen3 reranker conversion + ANE serving toolkit."""

from .manifest import BundleManifest, ProfileEntry

__all__ = ["BundleManifest", "ProfileEntry"]
qwen3_ane_rerank/__main__.py
ADDED

@@ -0,0 +1,5 @@
from .cli import main


if __name__ == "__main__":
    main()
qwen3_ane_rerank/api.py
ADDED

@@ -0,0 +1,84 @@
from __future__ import annotations

from typing import Any

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from .runtime import Qwen3AneRerankRuntime


class RerankRequest(BaseModel):
    query: str
    documents: list[str]
    model: str | None = None
    top_n: int | None = Field(default=None, ge=1)
    return_documents: bool = False
    instruction: str | None = None
    user: str | None = None


def create_app(runtime: Qwen3AneRerankRuntime, default_model_id: str | None = None) -> FastAPI:
    app = FastAPI(title="Qwen3 ANE Reranker Service", version="0.1.0")

    @app.get("/health")
    def health() -> dict[str, Any]:
        return {
            "ok": True,
            "task": "rerank",
            "model": default_model_id or runtime.manifest.model_name,
            "profiles": [
                {
                    "id": p.entry.profile_id,
                    "batch_size": p.entry.batch_size,
                    "seq_len": p.entry.seq_len,
                }
                for p in runtime.profiles
            ],
        }

    @app.post("/rerank")
    @app.post("/v1/rerank")
    def rerank(req: RerankRequest) -> dict[str, Any]:
        try:
            if req.query == "":
                raise ValueError("query must not be empty")
            if not req.documents:
                raise ValueError("documents must not be empty")
            if any(doc == "" for doc in req.documents):
                raise ValueError("documents must not contain empty strings")

            results, prompt_tokens = runtime.rerank(
                query=req.query,
                documents=req.documents,
                top_n=req.top_n,
                instruction=req.instruction,
            )

            data = []
            for row in results:
                item = {
                    "object": "rerank_result",
                    "index": row["index"],
                    "relevance_score": row["relevance_score"],
                }
                if req.return_documents:
                    item["document"] = req.documents[row["index"]]
                data.append(item)

            model_name = req.model or default_model_id or runtime.manifest.model_name
            return {
                "object": "list",
                "data": data,
                "model": model_name,
                "usage": {
                    "prompt_tokens": int(prompt_tokens),
                    "total_tokens": int(prompt_tokens),
                },
            }
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        except RuntimeError as exc:
            raise HTTPException(status_code=500, detail=str(exc)) from exc

    return app
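The service returns a `relevance_score` per document, computed by the runtime (`runtime.py`, not reproduced in this view). For Qwen-style yes/no rerankers this is conventionally the softmax probability of "yes" over the "yes"/"no" logits (the manifest's `yes_token_id`/`no_token_id`). A hedged, illustrative sketch of that conversion, not a copy of the actual runtime code:

```python
import math

def yes_no_to_score(yes_logit, no_logit):
    """P('yes') via a numerically stable 2-way softmax over the yes/no logits."""
    m = max(yes_logit, no_logit)  # subtract the max before exponentiating
    ey = math.exp(yes_logit - m)
    en = math.exp(no_logit - m)
    return ey / (ey + en)

score = yes_no_to_score(2.0, -1.0)   # strongly "yes" -> close to 1
tie = yes_no_to_score(0.0, 0.0)      # equal logits -> 0.5
```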
qwen3_ane_rerank/cli.py
ADDED

@@ -0,0 +1,174 @@
from __future__ import annotations

import argparse
from pathlib import Path

from .converter import BuildOptions, build_bundle
from .profiles import parse_profiles
from .runtime import Qwen3AneRerankRuntime


def _add_common_build_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--profiles",
        type=str,
        default=None,
        help="Shape profiles as comma list BxS (e.g. 1x128,4x128)",
    )
    parser.add_argument(
        "--target",
        type=str,
        default="macOS14",
        choices=["macOS14", "macOS15", "iOS17", "iOS18"],
        help="Core ML minimum deployment target",
    )
    parser.add_argument(
        "--compile-mlmodelc",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Compile .mlpackage into .mlmodelc with coremlcompiler",
    )
    parser.add_argument(
        "--system-prompt",
        default=(
            "Judge whether the Document meets the requirements based on the Query and the "
            'Instruct provided. Note that the answer can only be "yes" or "no".'
        ),
        help="System prompt used in reranker prompt template",
    )


def cmd_convert(args: argparse.Namespace) -> None:
    profiles = parse_profiles(args.profiles)
    options = BuildOptions(
        model_dir=Path(args.model_dir),
        bundle_dir=Path(args.bundle_dir),
        profiles=profiles,
        compile_mlmodelc=bool(args.compile_mlmodelc),
        minimum_deployment_target=args.target,
        system_prompt=args.system_prompt,
    )
    manifest = build_bundle(options)
    print(f"Built bundle at: {Path(args.bundle_dir).resolve()}")
    print(f"Model: {manifest.model_name}")
    print(f"Task: {manifest.task}")
    print(f"Hidden size: {manifest.hidden_size}")
    print(f"Token ids yes/no: {manifest.yes_token_id}/{manifest.no_token_id}")
    print("Profiles:")
    for entry in manifest.profiles:
        print(
            f"  - {entry.profile_id}: batch={entry.batch_size}, seq={entry.seq_len}, "
            f"model={entry.compiled_path or entry.package_path}"
        )


def cmd_serve(args: argparse.Namespace) -> None:
    bundle_dir = Path(args.bundle_dir)
    manifest_path = bundle_dir / "manifest.json"

    if not manifest_path.exists():
        if not args.auto_build:
            raise SystemExit(
                f"Bundle not found at {bundle_dir}. Run convert first or pass --auto-build --model-dir."
            )
        if not args.model_dir:
            raise SystemExit("--model-dir is required when --auto-build is enabled")

        profiles = parse_profiles(args.profiles)
        options = BuildOptions(
            model_dir=Path(args.model_dir),
            bundle_dir=bundle_dir,
            profiles=profiles,
            compile_mlmodelc=bool(args.compile_mlmodelc),
            minimum_deployment_target=args.target,
            system_prompt=args.system_prompt,
        )
        print("Bundle not found; building from source model...")
        build_bundle(options)

    runtime = Qwen3AneRerankRuntime(bundle_dir=bundle_dir, compute_units=args.compute_units)

    from .api import create_app
    import uvicorn

    app = create_app(runtime=runtime, default_model_id=args.model_id)
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
    )


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        prog="qwen3-ane-rerank",
        description="Convert Qwen3-Reranker model to Core ML ANE bundle and serve /v1/rerank endpoint.",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    convert_parser = subparsers.add_parser(
        "convert",
        help="Convert local HF Qwen3-Reranker model into ANE-ready Core ML profile bundle",
    )
    convert_parser.add_argument("--model-dir", required=True, help="Path to source HF model directory")
    convert_parser.add_argument(
        "--bundle-dir",
        required=True,
        help="Output bundle directory (manifest + packages + tokenizer)",
    )
    _add_common_build_args(convert_parser)
    convert_parser.set_defaults(func=cmd_convert)

    serve_parser = subparsers.add_parser(
        "serve",
        help="Run /v1/rerank endpoint backed by Core ML ANE profiles",
    )
    serve_parser.add_argument(
        "--bundle-dir",
        required=True,
        help="Bundle directory created by convert",
    )
    serve_parser.add_argument(
        "--model-dir",
        default=None,
        help="Source HF model directory (required if --auto-build and bundle missing)",
    )
    serve_parser.add_argument(
        "--auto-build",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Auto-build bundle from --model-dir when manifest is missing",
    )
    _add_common_build_args(serve_parser)
    serve_parser.add_argument("--host", default="127.0.0.1")
    serve_parser.add_argument("--port", type=int, default=8000)
    serve_parser.add_argument(
        "--compute-units",
        default="cpu_and_ne",
        choices=["cpu_and_ne", "all", "cpu_only", "cpu_and_gpu"],
        help="Core ML compute units preference",
    )
    serve_parser.add_argument(
        "--model-id",
        default="qwen3-reranker-0.6b-ane",
        help="Model id returned in API responses",
    )
    serve_parser.add_argument(
        "--log-level",
        default="info",
        choices=["critical", "error", "warning", "info", "debug", "trace"],
    )
    serve_parser.set_defaults(func=cmd_serve)

    return parser


def main() -> None:
    parser = build_parser()
    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
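The `--profiles` flag takes a comma list of `BxS` shapes handled by `parse_profiles` in `qwen3_ane_rerank/profiles.py`, whose source is not reproduced in this commit view. A plausible sketch of what that parsing involves — function name `parse_bxs` and the fallback default are illustrative assumptions, not the package's actual implementation:

```python
def parse_bxs(spec, default=(1, 128)):
    """Parse a '--profiles' style 'BxS' comma list, e.g. '1x128,4x128'.

    Returns a list of (batch_size, seq_len) tuples; falls back to an
    assumed default of (1, 128) when no spec is given.
    """
    if not spec:
        return [default]
    out = []
    for chunk in spec.split(","):
        batch, seq = chunk.strip().lower().split("x")
        out.append((int(batch), int(seq)))
    return out

profiles = parse_bxs("1x128,4x128")
```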
qwen3_ane_rerank/converter.py
ADDED

@@ -0,0 +1,267 @@
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
import shutil
import subprocess
from typing import Any

from .manifest import BundleManifest, ProfileEntry
from .profiles import ShapeProfile


DEFAULT_SYSTEM_PROMPT = (
    "Judge whether the Document meets the requirements based on the Query and the "
    'Instruct provided. Note that the answer can only be "yes" or "no".'
)
DEFAULT_PAIR_TEMPLATE = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"
DEFAULT_PREFIX_TEMPLATE = "<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n"
DEFAULT_SUFFIX_TEXT = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"


@dataclass(slots=True)
class BuildOptions:
    model_dir: Path
    bundle_dir: Path
    profiles: list[ShapeProfile]
    compile_mlmodelc: bool = True
    minimum_deployment_target: str = "macOS14"
    system_prompt: str = DEFAULT_SYSTEM_PROMPT
    pair_template: str = DEFAULT_PAIR_TEMPLATE
    suffix_text: str = DEFAULT_SUFFIX_TEXT


def _import_conversion_deps() -> tuple[Any, Any, Any, Any, Any]:
    try:
        import numpy as np
        import torch
        import coremltools as ct
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except Exception as exc:  # pragma: no cover - runtime dependency check
        raise RuntimeError(
            "Missing conversion dependencies. Install torch, transformers, coremltools, numpy."
        ) from exc
    return np, torch, ct, AutoModelForCausalLM, AutoTokenizer


def _resolve_target(ct: Any, raw: str) -> Any:
    if raw == "macOS14":
        return ct.target.macOS14
    if raw == "macOS15":
        return ct.target.macOS15
    if raw == "iOS17":
        return ct.target.iOS17
    if raw == "iOS18":
        return ct.target.iOS18
    raise ValueError(f"Unsupported minimum deployment target: {raw}")


def _compile_mlpackage(package_path: Path, compiled_root: Path) -> Path:
    compiled_root.mkdir(parents=True, exist_ok=True)
    cmd = ["xcrun", "coremlcompiler", "compile", str(package_path), str(compiled_root)]
    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(
            "coremlcompiler compile failed for "
            f"{package_path.name}:\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}"
        )

    expected = compiled_root / f"{package_path.stem}.mlmodelc"
    if expected.exists():
        return expected

    matches = sorted(compiled_root.glob("*.mlmodelc"), key=lambda p: p.stat().st_mtime)
    if not matches:
        raise RuntimeError(f"coremlcompiler succeeded but no .mlmodelc found under {compiled_root}")
    return matches[-1]


def _ensure_tokenizer_has_pad_token(tokenizer: Any) -> None:
    if tokenizer.pad_token_id is not None:
        return
    if tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
        return
    if tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token
        return
    raise RuntimeError("Tokenizer has no pad/eos/unk token; cannot build fixed-shape ANE pipeline")


def _resolve_token_id(tokenizer: Any, token: str) -> int:
    token_id = tokenizer.convert_tokens_to_ids(token)
    if token_id is None:
        raise RuntimeError(f"Unable to resolve token id for '{token}'")
    unk_id = tokenizer.unk_token_id
    if token_id < 0 or (unk_id is not None and token_id == unk_id):
        raise RuntimeError(f"Token '{token}' is missing from tokenizer vocab")
    return int(token_id)


def build_bundle(options: BuildOptions) -> BundleManifest:
    np, torch, ct, AutoModelForCausalLM, AutoTokenizer = _import_conversion_deps()

    model_dir = options.model_dir.resolve()
    bundle_dir = options.bundle_dir.resolve()
    packages_dir = bundle_dir / "packages"
    compiled_dir = bundle_dir / "compiled"
    tokenizer_dir = bundle_dir / "tokenizer"

    bundle_dir.mkdir(parents=True, exist_ok=True)
    packages_dir.mkdir(parents=True, exist_ok=True)
    if options.compile_mlmodelc:
        compiled_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(
        str(model_dir),
        local_files_only=True,
        trust_remote_code=False,
        use_fast=True,
    )
    tokenizer.padding_side = "left"
    _ensure_tokenizer_has_pad_token(tokenizer)
+
tokenizer.save_pretrained(str(tokenizer_dir))
|
| 124 |
+
|
| 125 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 126 |
+
str(model_dir),
|
| 127 |
+
local_files_only=True,
|
| 128 |
+
trust_remote_code=False,
|
| 129 |
+
dtype=torch.float32,
|
| 130 |
+
)
|
| 131 |
+
model = model.float().eval()
|
| 132 |
+
if hasattr(model, "config") and hasattr(model.config, "use_cache"):
|
| 133 |
+
model.config.use_cache = False
|
| 134 |
+
if hasattr(model, "config") and hasattr(model.config, "_attn_implementation"):
|
| 135 |
+
model.config._attn_implementation = "eager"
|
| 136 |
+
|
| 137 |
+
if not hasattr(model, "model") or not hasattr(model, "lm_head"):
|
| 138 |
+
raise RuntimeError("Unsupported model structure: expected .model backbone and .lm_head")
|
| 139 |
+
|
| 140 |
+
hidden_size = int(getattr(model.config, "hidden_size", 0))
|
| 141 |
+
if hidden_size <= 0:
|
| 142 |
+
raise RuntimeError("Unable to infer hidden size from model config")
|
| 143 |
+
|
| 144 |
+
yes_token_id = _resolve_token_id(tokenizer, "yes")
|
| 145 |
+
no_token_id = _resolve_token_id(tokenizer, "no")
|
| 146 |
+
score_weight = (
|
| 147 |
+
model.lm_head.weight[yes_token_id].detach().to(torch.float32)
|
| 148 |
+
- model.lm_head.weight[no_token_id].detach().to(torch.float32)
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
backbone = model.model
|
| 152 |
+
|
| 153 |
+
class Qwen3RerankWrapper(torch.nn.Module):
|
| 154 |
+
def __init__(self, language_backbone: Any, score_weight_vec: Any, seq_len: int):
|
| 155 |
+
super().__init__()
|
| 156 |
+
self.backbone = language_backbone
|
| 157 |
+
self.register_buffer("score_weight", score_weight_vec.view(1, -1), persistent=False)
|
| 158 |
+
causal = torch.tril(torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32))
|
| 159 |
+
self.register_buffer("causal_template", causal, persistent=False)
|
| 160 |
+
self.neg_inf = -1e4
|
| 161 |
+
|
| 162 |
+
def _build_attention_bias(self, attention_mask: Any) -> Any:
|
| 163 |
+
mask = attention_mask.to(torch.float32)
|
| 164 |
+
key_valid = mask.unsqueeze(1).unsqueeze(1)
|
| 165 |
+
query_valid = mask.unsqueeze(1).unsqueeze(3)
|
| 166 |
+
allowed = self.causal_template * key_valid * query_valid
|
| 167 |
+
return (1.0 - allowed) * self.neg_inf
|
| 168 |
+
|
| 169 |
+
def forward(self, input_ids: Any, attention_mask: Any) -> Any:
|
| 170 |
+
input_ids = input_ids.to(torch.int64)
|
| 171 |
+
attention_bias = self._build_attention_bias(attention_mask)
|
| 172 |
+
outputs = self.backbone(
|
| 173 |
+
input_ids=input_ids,
|
| 174 |
+
attention_mask=attention_bias,
|
| 175 |
+
return_dict=False,
|
| 176 |
+
)
|
| 177 |
+
last_hidden = outputs[0][:, -1, :]
|
| 178 |
+
logit_delta = (last_hidden * self.score_weight).sum(dim=1, keepdim=True)
|
| 179 |
+
return torch.sigmoid(logit_delta)
|
| 180 |
+
|
| 181 |
+
target = _resolve_target(ct, options.minimum_deployment_target)
|
| 182 |
+
prefix_text = DEFAULT_PREFIX_TEMPLATE.format(system_prompt=options.system_prompt)
|
| 183 |
+
|
| 184 |
+
profile_entries: list[ProfileEntry] = []
|
| 185 |
+
for profile in options.profiles:
|
| 186 |
+
profile_id = profile.profile_id
|
| 187 |
+
package_path = packages_dir / f"{profile_id}.mlpackage"
|
| 188 |
+
if package_path.exists():
|
| 189 |
+
shutil.rmtree(package_path)
|
| 190 |
+
wrapper = Qwen3RerankWrapper(backbone, score_weight, seq_len=profile.seq_len).eval()
|
| 191 |
+
|
| 192 |
+
input_ids = torch.full(
|
| 193 |
+
(profile.batch_size, profile.seq_len),
|
| 194 |
+
fill_value=int(tokenizer.pad_token_id),
|
| 195 |
+
dtype=torch.int32,
|
| 196 |
+
)
|
| 197 |
+
attention_mask = torch.ones(
|
| 198 |
+
(profile.batch_size, profile.seq_len),
|
| 199 |
+
dtype=torch.int32,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
exported = torch.export.export(wrapper, (input_ids, attention_mask), strict=False)
|
| 203 |
+
exported = exported.run_decompositions({})
|
| 204 |
+
|
| 205 |
+
ct.convert(
|
| 206 |
+
exported,
|
| 207 |
+
convert_to="mlprogram",
|
| 208 |
+
minimum_deployment_target=target,
|
| 209 |
+
compute_precision=ct.precision.FLOAT16,
|
| 210 |
+
skip_model_load=True,
|
| 211 |
+
package_dir=str(package_path),
|
| 212 |
+
inputs=[
|
| 213 |
+
ct.TensorType(
|
| 214 |
+
name="input_ids",
|
| 215 |
+
shape=input_ids.shape,
|
| 216 |
+
dtype=np.int32,
|
| 217 |
+
),
|
| 218 |
+
ct.TensorType(
|
| 219 |
+
name="attention_mask",
|
| 220 |
+
shape=attention_mask.shape,
|
| 221 |
+
dtype=np.int32,
|
| 222 |
+
),
|
| 223 |
+
],
|
| 224 |
+
outputs=[ct.TensorType(name="score")],
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
compiled_path: Path | None = None
|
| 228 |
+
if options.compile_mlmodelc:
|
| 229 |
+
compiled_path = _compile_mlpackage(package_path, compiled_dir)
|
| 230 |
+
|
| 231 |
+
profile_entries.append(
|
| 232 |
+
ProfileEntry(
|
| 233 |
+
profile_id=profile_id,
|
| 234 |
+
batch_size=profile.batch_size,
|
| 235 |
+
seq_len=profile.seq_len,
|
| 236 |
+
package_path=str(package_path.relative_to(bundle_dir)),
|
| 237 |
+
compiled_path=(
|
| 238 |
+
str(compiled_path.relative_to(bundle_dir)) if compiled_path is not None else None
|
| 239 |
+
),
|
| 240 |
+
input_names=["input_ids", "attention_mask"],
|
| 241 |
+
output_name="score",
|
| 242 |
+
)
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
manifest = BundleManifest.create(
|
| 246 |
+
model_name=Path(model_dir).name,
|
| 247 |
+
source_model_dir=str(model_dir),
|
| 248 |
+
tokenizer_dir=str(tokenizer_dir.relative_to(bundle_dir)),
|
| 249 |
+
hidden_size=hidden_size,
|
| 250 |
+
yes_token_id=yes_token_id,
|
| 251 |
+
no_token_id=no_token_id,
|
| 252 |
+
system_prompt=options.system_prompt,
|
| 253 |
+
pair_template=options.pair_template,
|
| 254 |
+
prefix_text=prefix_text,
|
| 255 |
+
suffix_text=options.suffix_text,
|
| 256 |
+
profiles=profile_entries,
|
| 257 |
+
)
|
| 258 |
+
manifest.save(bundle_dir)
|
| 259 |
+
return manifest
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def build_bundle_if_missing(options: BuildOptions) -> BundleManifest:
|
| 263 |
+
bundle_dir = options.bundle_dir.resolve()
|
| 264 |
+
manifest_path = bundle_dir / "manifest.json"
|
| 265 |
+
if manifest_path.exists():
|
| 266 |
+
return BundleManifest.load(bundle_dir)
|
| 267 |
+
return build_bundle(options)
|
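The additive attention bias that `Qwen3RerankWrapper._build_attention_bias` feeds into the backbone can be checked in isolation. The sketch below mirrors the same broadcast arithmetic in plain NumPy; the function name and the tiny 4-token mask are illustrative only, not part of the bundle:

```python
import numpy as np

def build_attention_bias(attention_mask: np.ndarray, neg_inf: float = -1e4) -> np.ndarray:
    """NumPy mirror of Qwen3RerankWrapper._build_attention_bias.

    attention_mask: (batch, seq) of 0/1, left-padded.
    Returns an additive bias of shape (batch, 1, seq, seq): 0 where a query
    position may attend to a key position, neg_inf everywhere else.
    """
    batch, seq = attention_mask.shape
    causal = np.tril(np.ones((1, 1, seq, seq), dtype=np.float32))
    mask = attention_mask.astype(np.float32)
    key_valid = mask[:, None, None, :]    # broadcast across the query axis
    query_valid = mask[:, None, :, None]  # broadcast across the key axis
    allowed = causal * key_valid * query_valid
    return (1.0 - allowed) * neg_inf

# One left-padded row: two pad slots, then two real tokens.
bias = build_attention_bias(np.array([[0, 0, 1, 1]], dtype=np.int32))
# The last query position sees both real tokens but never the padding:
# bias[0, 0, 3] == [-1e4, -1e4, 0.0, 0.0]
```

Because the score is read from the last (right-aligned) position, this masking keeps pad tokens out of every attention step while preserving causality within the real tokens.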
qwen3_ane_rerank/manifest.py
ADDED
@@ -0,0 +1,109 @@
from __future__ import annotations

from dataclasses import asdict, dataclass
from datetime import datetime, timezone
import json
from pathlib import Path
from typing import Any


MANIFEST_FILENAME = "manifest.json"


@dataclass(slots=True)
class ProfileEntry:
    profile_id: str
    batch_size: int
    seq_len: int
    package_path: str
    compiled_path: str | None
    input_names: list[str]
    output_name: str


@dataclass(slots=True)
class BundleManifest:
    format_version: int
    task: str
    model_name: str
    source_model_dir: str
    tokenizer_dir: str
    hidden_size: int
    yes_token_id: int
    no_token_id: int
    system_prompt: str
    pair_template: str
    prefix_text: str
    suffix_text: str
    created_at_utc: str
    profiles: list[ProfileEntry]

    @classmethod
    def create(
        cls,
        *,
        model_name: str,
        source_model_dir: str,
        tokenizer_dir: str,
        hidden_size: int,
        yes_token_id: int,
        no_token_id: int,
        system_prompt: str,
        pair_template: str,
        prefix_text: str,
        suffix_text: str,
        profiles: list[ProfileEntry],
    ) -> "BundleManifest":
        return cls(
            format_version=1,
            task="rerank",
            model_name=model_name,
            source_model_dir=source_model_dir,
            tokenizer_dir=tokenizer_dir,
            hidden_size=hidden_size,
            yes_token_id=yes_token_id,
            no_token_id=no_token_id,
            system_prompt=system_prompt,
            pair_template=pair_template,
            prefix_text=prefix_text,
            suffix_text=suffix_text,
            created_at_utc=datetime.now(timezone.utc).isoformat(),
            profiles=profiles,
        )

    @classmethod
    def from_dict(cls, payload: dict[str, Any]) -> "BundleManifest":
        profiles = [ProfileEntry(**entry) for entry in payload["profiles"]]
        return cls(
            format_version=payload["format_version"],
            task=payload.get("task", "rerank"),
            model_name=payload["model_name"],
            source_model_dir=payload["source_model_dir"],
            tokenizer_dir=payload["tokenizer_dir"],
            hidden_size=payload["hidden_size"],
            yes_token_id=payload["yes_token_id"],
            no_token_id=payload["no_token_id"],
            system_prompt=payload["system_prompt"],
            pair_template=payload["pair_template"],
            prefix_text=payload["prefix_text"],
            suffix_text=payload["suffix_text"],
            created_at_utc=payload["created_at_utc"],
            profiles=profiles,
        )

    def to_dict(self) -> dict[str, Any]:
        payload = asdict(self)
        payload["profiles"] = [asdict(entry) for entry in self.profiles]
        return payload

    def save(self, bundle_dir: Path) -> Path:
        bundle_dir.mkdir(parents=True, exist_ok=True)
        path = bundle_dir / MANIFEST_FILENAME
        path.write_text(json.dumps(self.to_dict(), indent=2, ensure_ascii=False), encoding="utf-8")
        return path

    @classmethod
    def load(cls, bundle_dir: Path) -> "BundleManifest":
        path = bundle_dir / MANIFEST_FILENAME
        payload = json.loads(path.read_text(encoding="utf-8"))
        return cls.from_dict(payload)
qwen3_ane_rerank/profiles.py
ADDED
@@ -0,0 +1,47 @@
from __future__ import annotations

from dataclasses import dataclass


DEFAULT_PROFILES = [
    (1, 128),
    (4, 128),
]


@dataclass(frozen=True, slots=True)
class ShapeProfile:
    batch_size: int
    seq_len: int

    @property
    def profile_id(self) -> str:
        return f"b{self.batch_size}_s{self.seq_len}"


def parse_profiles(raw: str | None) -> list[ShapeProfile]:
    if not raw:
        return [ShapeProfile(batch_size=b, seq_len=s) for b, s in DEFAULT_PROFILES]

    parsed: list[ShapeProfile] = []
    for item in raw.split(","):
        item = item.strip().lower()
        if not item:
            continue
        if "x" not in item:
            raise ValueError(f"Invalid profile '{item}'. Expected format BxS, e.g. 4x512")
        left, right = item.split("x", 1)
        batch_size = int(left)
        seq_len = int(right)
        if batch_size <= 0 or seq_len <= 0:
            raise ValueError(f"Invalid profile '{item}'. B and S must be positive")
        parsed.append(ShapeProfile(batch_size=batch_size, seq_len=seq_len))

    if not parsed:
        raise ValueError("No valid profiles parsed")

    unique = {(p.batch_size, p.seq_len): p for p in parsed}
    return [
        unique[key]
        for key in sorted(unique.keys(), key=lambda x: (x[1], x[0]))
    ]
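The "BxS" profile-string rules above (comma-separated items, dedup, sort by seq_len then batch) can be exercised with a standalone sketch that returns plain tuples instead of `ShapeProfile` objects:

```python
def parse_profiles(raw: str) -> list[tuple[int, int]]:
    """Same rules as qwen3_ane_rerank.profiles.parse_profiles, as (B, S) tuples."""
    parsed = []
    for item in raw.split(","):
        item = item.strip().lower()
        if not item:
            continue
        if "x" not in item:
            raise ValueError(f"Invalid profile '{item}'. Expected format BxS, e.g. 4x512")
        left, right = item.split("x", 1)
        b, s = int(left), int(right)
        if b <= 0 or s <= 0:
            raise ValueError(f"Invalid profile '{item}'. B and S must be positive")
        parsed.append((b, s))
    # dict.fromkeys deduplicates while preserving insertion order.
    unique = dict.fromkeys(parsed)
    return sorted(unique, key=lambda p: (p[1], p[0]))

print(parse_profiles("4x512, 1x128, 4x512,2x128"))
# → [(1, 128), (2, 128), (4, 512)]
```

Sorting by sequence length first means smaller compiled shapes are listed (and therefore preferred) before larger ones.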
qwen3_ane_rerank/runtime.py
ADDED
@@ -0,0 +1,278 @@
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable

from .manifest import BundleManifest, ProfileEntry


MAX_INPUT_TOKENS_PER_ITEM = 8192
MAX_TOTAL_TOKENS_PER_REQUEST = 300000
DEFAULT_INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query."


def _import_runtime_deps() -> tuple[Any, Any, Any]:
    try:
        import numpy as np
        import coremltools as ct
        from transformers import AutoTokenizer
    except Exception as exc:  # pragma: no cover - runtime dependency check
        raise RuntimeError(
            "Missing runtime dependencies. Install numpy, coremltools, transformers."
        ) from exc
    return np, ct, AutoTokenizer


@dataclass(slots=True)
class LoadedProfile:
    entry: ProfileEntry
    model_path: Path
    model: Any | None = None


class Qwen3AneRerankRuntime:
    def __init__(self, bundle_dir: str | Path, compute_units: str = "cpu_and_ne") -> None:
        np, ct, AutoTokenizer = _import_runtime_deps()
        self.np = np
        self.ct = ct

        self.bundle_dir = Path(bundle_dir).resolve()
        self.manifest = BundleManifest.load(self.bundle_dir)

        self.tokenizer = AutoTokenizer.from_pretrained(
            str(self.bundle_dir / self.manifest.tokenizer_dir),
            local_files_only=True,
            trust_remote_code=False,
            use_fast=True,
        )
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            elif self.tokenizer.unk_token is not None:
                self.tokenizer.pad_token = self.tokenizer.unk_token
            else:
                raise RuntimeError("Tokenizer has no pad/eos/unk token")

        self.prefix_tokens = self.tokenizer.encode(
            self.manifest.prefix_text,
            add_special_tokens=False,
        )
        self.suffix_tokens = self.tokenizer.encode(
            self.manifest.suffix_text,
            add_special_tokens=False,
        )
        self.static_token_cost = len(self.prefix_tokens) + len(self.suffix_tokens)

        self.compute_units = self._resolve_compute_units(compute_units)

        self.profiles: list[LoadedProfile] = []
        for entry in self.manifest.profiles:
            package_path = self.bundle_dir / entry.package_path
            model_path = package_path
            if entry.compiled_path is not None:
                compiled_path = self.bundle_dir / entry.compiled_path
                if (compiled_path / "Manifest.json").exists():
                    model_path = compiled_path
            self.profiles.append(LoadedProfile(entry=entry, model_path=model_path))

        if not self.profiles:
            raise RuntimeError("No profiles found in manifest")

        self.max_profile_batch = max(p.entry.batch_size for p in self.profiles)
        self.max_profile_seq = max(p.entry.seq_len for p in self.profiles)

        if self.static_token_cost >= self.max_profile_seq:
            raise RuntimeError(
                "Profile seq_len is too small for reranker prompt template. "
                f"Need > {self.static_token_cost}, got {self.max_profile_seq}."
            )

    def _resolve_compute_units(self, raw: str) -> Any:
        mode = raw.strip().lower()
        cu = self.ct.ComputeUnit
        if mode == "cpu_and_ne":
            # Older coremltools builds may lack CPU_AND_NE; fall back to ALL.
            return cu.CPU_AND_NE if hasattr(cu, "CPU_AND_NE") else cu.ALL
        if mode == "all":
            return cu.ALL
        if mode == "cpu_only":
            return cu.CPU_ONLY
        if mode == "cpu_and_gpu":
            return cu.CPU_AND_GPU
        raise ValueError(f"Unsupported compute unit mode: {raw}")

    def _get_model(self, profile: LoadedProfile) -> Any:
        if profile.model is None:
            profile.model = self.ct.models.MLModel(
                str(profile.model_path),
                compute_units=self.compute_units,
            )
        return profile.model

    def _select_profile(self, batch_size: int, seq_len: int) -> LoadedProfile | None:
        candidates = [
            p
            for p in self.profiles
            if p.entry.batch_size >= batch_size and p.entry.seq_len >= seq_len
        ]
        if not candidates:
            return None
        candidates.sort(
            key=lambda p: (p.entry.batch_size * p.entry.seq_len, p.entry.seq_len, p.entry.batch_size)
        )
        return candidates[0]

    def _plan_chunks(self, lengths: list[int]) -> list[tuple[int, int, LoadedProfile]]:
        chunks: list[tuple[int, int, LoadedProfile]] = []
        i = 0
        n = len(lengths)
        while i < n:
            best: tuple[int, LoadedProfile] | None = None
            max_batch = min(self.max_profile_batch, n - i)
            for b in range(max_batch, 0, -1):
                max_len = max(lengths[i : i + b])
                profile = self._select_profile(batch_size=b, seq_len=max_len)
                if profile is not None:
                    best = (b, profile)
                    break
            if best is None:
                raise ValueError(
                    f"No profile can serve items starting at index {i}. Required seq_len={lengths[i]}"
                )
            b, profile = best
            chunks.append((i, i + b, profile))
            i += b
        return chunks

    def _predict_scores(self, profile: LoadedProfile, input_ids: Any, attention_mask: Any) -> Any:
        model = self._get_model(profile)
        out = model.predict(
            {
                profile.entry.input_names[0]: input_ids,
                profile.entry.input_names[1]: attention_mask,
            }
        )
        raw = out.get(profile.entry.output_name, next(iter(out.values())))
        scores = self.np.asarray(raw, dtype=self.np.float32)
        if scores.ndim == 0:
            scores = scores.reshape(1)
        elif scores.ndim == 2 and scores.shape[1] == 1:
            scores = scores[:, 0]
        elif scores.ndim > 1:
            scores = scores.reshape(scores.shape[0], -1)[:, 0]
        return scores

    def _validate_token_limits(self, token_lengths: Iterable[int]) -> None:
        lengths = list(token_lengths)
        if any(length <= 0 for length in lengths):
            raise ValueError("Input pair must not be empty")
        if any(length > MAX_INPUT_TOKENS_PER_ITEM for length in lengths):
            raise ValueError(
                f"Each pair must be <= {MAX_INPUT_TOKENS_PER_ITEM} tokens before truncation"
            )
        if sum(lengths) > MAX_TOTAL_TOKENS_PER_REQUEST:
            raise ValueError(
                f"Total tokens across request must be <= {MAX_TOTAL_TOKENS_PER_REQUEST}"
            )

    def _format_pair_text(self, query: str, document: str, instruction: str) -> str:
        if "{instruction}" not in self.manifest.pair_template:
            raise RuntimeError("Invalid pair template: missing {instruction}")
        if "{query}" not in self.manifest.pair_template:
            raise RuntimeError("Invalid pair template: missing {query}")
        if "{document}" not in self.manifest.pair_template:
            raise RuntimeError("Invalid pair template: missing {document}")
        return self.manifest.pair_template.format(
            instruction=instruction,
            query=query,
            document=document,
        )

    def _pair_token_len(self, pair_text: str) -> int:
        body_len = len(
            self.tokenizer.encode(
                pair_text,
                add_special_tokens=False,
                truncation=False,
            )
        )
        return self.static_token_cost + body_len

    def _build_pair_ids(self, pair_text: str, seq_len: int) -> list[int]:
        body_budget = seq_len - self.static_token_cost
        if body_budget <= 0:
            raise RuntimeError(f"seq_len={seq_len} is too small for reranker template")
        body_ids = self.tokenizer.encode(
            pair_text,
            add_special_tokens=False,
            truncation=True,
            max_length=body_budget,
        )
        return self.prefix_tokens + body_ids + self.suffix_tokens

    def rerank(
        self,
        query: str,
        documents: list[str],
        *,
        top_n: int | None = None,
        instruction: str | None = None,
    ) -> tuple[list[dict[str, Any]], int]:
        if not query:
            raise ValueError("query must not be empty")
        if not documents:
            raise ValueError("documents must not be empty")
        if any(doc == "" for doc in documents):
            raise ValueError("documents must not contain empty strings")

        instruction_text = instruction or DEFAULT_INSTRUCTION
        pair_texts = [self._format_pair_text(query, doc, instruction_text) for doc in documents]
        raw_lengths = [self._pair_token_len(text) for text in pair_texts]
        self._validate_token_limits(raw_lengths)

        too_long = [idx for idx, length in enumerate(raw_lengths) if length > self.max_profile_seq]
        if too_long:
            first = too_long[0]
            raise ValueError(
                f"pair at index {first} has {raw_lengths[first]} tokens, "
                f"but compiled profiles only support up to {self.max_profile_seq}. "
                "Rebuild bundle with larger seq profiles."
            )

        effective_lengths = [min(length, self.max_profile_seq) for length in raw_lengths]
        chunks = self._plan_chunks(effective_lengths)
        pad_id = int(self.tokenizer.pad_token_id)

        all_scores: list[Any] = []
        prompt_tokens = 0

        for start, end, profile in chunks:
            chunk_texts = pair_texts[start:end]
            profile_batch = profile.entry.batch_size
            seq_len = profile.entry.seq_len

            input_ids = self.np.full((profile_batch, seq_len), fill_value=pad_id, dtype=self.np.int32)
            attention_mask = self.np.zeros((profile_batch, seq_len), dtype=self.np.int32)

            for row, pair_text in enumerate(chunk_texts):
                ids = self._build_pair_ids(pair_text, seq_len=seq_len)
                tlen = len(ids)
                offset = seq_len - tlen
                input_ids[row, offset:] = self.np.asarray(ids, dtype=self.np.int32)
                attention_mask[row, offset:] = 1
                prompt_tokens += tlen

            scores = self._predict_scores(profile, input_ids, attention_mask)
            all_scores.append(scores[: len(chunk_texts)])

        merged_scores = self.np.concatenate(all_scores, axis=0).astype(self.np.float32)

        ranked = [
            {"index": idx, "relevance_score": float(score)}
            for idx, score in enumerate(merged_scores.tolist())
        ]
        ranked.sort(key=lambda item: item["relevance_score"], reverse=True)

        n_results = len(ranked) if top_n is None else max(1, min(int(top_n), len(ranked)))
        return ranked[:n_results], prompt_tokens
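The greedy batching in `_plan_chunks` / `_select_profile` can be sketched standalone with plain `(batch, seq)` tuples in place of `LoadedProfile` objects; the `PROFILES` list below is a hypothetical set of compiled shapes, not the bundle's actual profiles. Each chunk takes the largest remaining batch that some profile covers, then picks the cheapest such profile by `batch * seq` area:

```python
PROFILES = [(1, 128), (4, 128), (4, 512)]  # hypothetical compiled (batch, seq) shapes

def select_profile(batch_size: int, seq_len: int):
    """Smallest profile (by area, then seq, then batch) covering the request."""
    candidates = [p for p in PROFILES if p[0] >= batch_size and p[1] >= seq_len]
    if not candidates:
        return None
    return min(candidates, key=lambda p: (p[0] * p[1], p[1], p[0]))

def plan_chunks(lengths: list[int]):
    """Greedily split items into (start, end, profile) chunks, mirroring _plan_chunks."""
    chunks, i, n = [], 0, len(lengths)
    max_profile_batch = max(p[0] for p in PROFILES)
    while i < n:
        best = None
        for b in range(min(max_profile_batch, n - i), 0, -1):
            profile = select_profile(b, max(lengths[i:i + b]))
            if profile is not None:
                best = (b, profile)
                break
        if best is None:
            raise ValueError(f"No profile can serve item {i} (seq_len={lengths[i]})")
        b, profile = best
        chunks.append((i, i + b, profile))
        i += b
    return chunks

print(plan_chunks([100, 120, 300, 90, 110]))
# → [(0, 4, (4, 512)), (4, 5, (1, 128))]
```

Note the trade-off this encodes: the planner prefers filling the batch dimension first, so one long item in a group pulls the whole group onto the larger sequence profile.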
requirements-service.txt
ADDED
@@ -0,0 +1,5 @@
numpy==2.4.2
coremltools==9.0
transformers==5.2.0
fastapi==0.135.1
uvicorn==0.41.0
run_server.sh
ADDED
@@ -0,0 +1,74 @@
#!/usr/bin/env bash
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BUNDLE_DIR="${BUNDLE_DIR:-$ROOT_DIR/bundles/qwen3_reranker_ane_bundle_4b}"
HOST="${HOST:-127.0.0.1}"
PORT="${PORT:-8000}"
COMPUTE_UNITS="${COMPUTE_UNITS:-cpu_and_ne}"
MODEL_ID="${MODEL_ID:-qwen3-reranker-4b-ane}"

if [[ -n "${PYTHON_BIN:-}" ]]; then
    PY_BIN="$PYTHON_BIN"
elif [[ -x "$ROOT_DIR/.venv/bin/python" ]]; then
    PY_BIN="$ROOT_DIR/.venv/bin/python"
else
    PY_BIN="python3"
fi

if ! command -v "$PY_BIN" >/dev/null 2>&1; then
    echo "[ERROR] Python is not available: $PY_BIN"
    echo "Run first: python3.11 -m venv .venv && source .venv/bin/activate && python -m pip install -r requirements-service.txt"
    exit 1
fi

PY_MM="$($PY_BIN -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')"
case "$PY_MM" in
    3.10|3.11|3.12) ;;
    *)
        echo "[ERROR] Current Python version is ${PY_MM}, which is not supported."
        echo "coremltools cannot load its native runtime libraries under this version (typical errors: libcoremlpython / libmilstoragepython)."
        echo "Please use Python 3.11 instead:"
        echo "  python3.11 -m venv .venv"
        echo "  source .venv/bin/activate"
        echo "  python -m pip install -r requirements-service.txt"
        exit 1
        ;;
esac

if [[ ! -f "$BUNDLE_DIR/manifest.json" ]]; then
    echo "[ERROR] Bundle manifest not found: $BUNDLE_DIR/manifest.json"
    exit 1
fi

$PY_BIN - <<'PY'
import sys

errors = []
try:
    import coremltools  # noqa: F401
except Exception as e:
    errors.append(f"import coremltools failed: {e}")

for mod in ("coremltools.libcoremlpython", "coremltools.libmilstoragepython"):
    try:
        __import__(mod)
    except Exception as e:
        errors.append(f"failed to load {mod}: {e}")

if errors:
    print("[ERROR] Core ML Python runtime libraries are unavailable:")
    for item in errors:
        print(" -", item)
    print("Make sure dependencies were installed inside a Python 3.11 virtual environment.")
    sys.exit(1)
PY

cd "$ROOT_DIR"
exec "$PY_BIN" -m qwen3_ane_rerank serve \
    --bundle-dir "$BUNDLE_DIR" \
    --no-auto-build \
    --compute-units "$COMPUTE_UNITS" \
    --model-id "$MODEL_ID" \
    --host "$HOST" \
    --port "$PORT"
setup_venv.sh
ADDED
@@ -0,0 +1,33 @@
#!/usr/bin/env bash
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$ROOT_DIR"

if ! command -v python3.11 >/dev/null 2>&1; then
    echo "[ERROR] python3.11 not found; please install Python 3.11 first."
    exit 1
fi

if [[ -d .venv ]]; then
    CUR_VER="$(.venv/bin/python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null || true)"
    if [[ "$CUR_VER" != "3.11" ]]; then
        BACKUP_DIR=".venv.backup.$(date +%Y%m%d-%H%M%S)"
        echo "[INFO] Existing .venv uses Python ${CUR_VER:-unknown}; moving it to $BACKUP_DIR"
        mv .venv "$BACKUP_DIR"
    fi
fi

python3.11 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip
python -m pip install -r requirements-service.txt

python - <<'PY'
import sys
print("[OK] venv Python:", sys.version)
import coremltools
print("[OK] coremltools:", coremltools.__version__)
PY

echo "[DONE] Environment ready. Run ./run_server.sh to start the service."