tooktang committed on
Commit 1e1d0ce · verified · 1 Parent(s): a3adf95

Initial release: Qwen3-Reranker-4B CoreML ANE-optimized bundle + service

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ bundles/qwen3_reranker_ane_bundle_4b/tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,133 @@
+ ---
+ license: apache-2.0
+ language:
+ - en
+ - zh
+ tags:
+ - qwen3
+ - reranker
+ - coreml
+ - apple-silicon
+ - ane
+ pipeline_tag: text-ranking
+ library_name: coremltools
+ base_model: Qwen/Qwen3-Reranker-4B
+ ---
+
+ # Qwen3-Reranker-4B-CoreML (ANE-Optimized)
+
+ ## English
+
+ This repository provides a pre-converted Core ML bundle derived from `Qwen3-Reranker-4B`, together with an OpenAI-style rerank API service for Apple Silicon.
+
+ ### Bundle Specs
+
+ | Item | Value |
+ | --- | --- |
+ | Base model | `Qwen/Qwen3-Reranker-4B` |
+ | Task | Text reranking |
+ | Profiles | `b1_s128` |
+ | Bundle path | `bundles/qwen3_reranker_ane_bundle_4b` |
+ | Default model id | `qwen3-reranker-4b-ane` |
+ | Package size (approx.) | `7.5 GB` |
+
+ ### Scope
+
+ - This release supports **text-only reranking**.
+ - Endpoints: `POST /rerank` and `POST /v1/rerank`.
+
+ ### Quick Start
+
+ ```bash
+ ./setup_venv.sh
+ ./run_server.sh
+ ```
+
+ Health check:
+
+ ```bash
+ curl -s http://127.0.0.1:8000/health
+ ```
+
+ Rerank request:
+
+ ```bash
+ curl -s http://127.0.0.1:8000/v1/rerank \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "query": "capital of China",
+ "documents": [
+ "The capital of China is Beijing.",
+ "Gravity is a force."
+ ],
+ "top_n": 2,
+ "return_documents": true
+ }'
+ ```
+
+ ### Notes
+
+ - Fixed-shape profile (`s128`), suited to low-power deployment.
+ - Inputs longer than the profile capacity return an explicit error.
+ - The first request incurs warm-up latency.
+ - The default compute setting is `cpu_and_ne` (ANE-preferred, not ANE-guaranteed).
+
+ ## Chinese
+
+ This repository provides a pre-converted Core ML bundle based on `Qwen3-Reranker-4B`, plus a ready-to-run text reranking service (`/v1/rerank`).
+
+ ### Bundle Specs
+
+ | Item | Value |
+ | --- | --- |
+ | Base model | `Qwen/Qwen3-Reranker-4B` |
+ | Task | Text reranking |
+ | Profile | `b1_s128` |
+ | Bundle path | `bundles/qwen3_reranker_ane_bundle_4b` |
+ | Default model id | `qwen3-reranker-4b-ane` |
+ | Package size (approx.) | `7.5 GB` |
+
+ ### Scope
+
+ - This release supports **text-only reranking** only.
+ - Endpoints: `POST /rerank` and `POST /v1/rerank`.
+
+ ### Quick Start
+
+ ```bash
+ ./setup_venv.sh
+ ./run_server.sh
+ ```
+
+ Health check:
+
+ ```bash
+ curl -s http://127.0.0.1:8000/health
+ ```
+
+ Rerank request:
+
+ ```bash
+ curl -s http://127.0.0.1:8000/v1/rerank \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "query": "capital of China",
+ "documents": [
+ "The capital of China is Beijing.",
+ "Gravity is a force."
+ ],
+ "top_n": 2,
+ "return_documents": true
+ }'
+ ```
+
+ ### Notes
+
+ - Fixed-shape profile (`s128`), geared toward low-power deployment.
+ - Inputs exceeding the profile limit return an explicit error.
+ - The first request incurs warm-up latency.
+ - The default is `cpu_and_ne`, which prefers ANE scheduling but does not guarantee 100% ANE-only execution.
+
+ ## License
+
+ Apache-2.0. Please also follow the license and usage terms of the base Qwen model.
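The curl request above can also be issued from Python with only the standard library. This is a hypothetical client sketch, not part of the repository: the payload fields mirror the request body documented in the README, and the base URL assumes the default host and port used by `run_server.sh`.

```python
import json
import urllib.request


def build_rerank_payload(query, documents, top_n=None, return_documents=False, instruction=None):
    """Build the JSON body expected by POST /v1/rerank (fields from the README)."""
    payload = {
        "query": query,
        "documents": list(documents),
        "return_documents": return_documents,
    }
    if top_n is not None:
        payload["top_n"] = top_n
    if instruction is not None:
        payload["instruction"] = instruction
    return payload


def rerank(base_url, query, documents, **kwargs):
    """POST a rerank request to a running service and return the parsed JSON response."""
    body = json.dumps(build_rerank_payload(query, documents, **kwargs)).encode("utf-8")
    req = urllib.request.Request(
        base_url.rstrip("/") + "/v1/rerank",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())


# Example (requires the server from the Quick Start to be running):
#   rerank("http://127.0.0.1:8000", "capital of China",
#          ["The capital of China is Beijing.", "Gravity is a force."],
#          top_n=2, return_documents=True)
```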
bundles/qwen3_reranker_ane_bundle_4b/manifest.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "format_version": 1,
+   "task": "rerank",
+   "model_name": "Qwen3-Reranker-4B",
+   "source_model_dir": "/Volumes/256G/Applications/ANE/Qwen3-Reranker-4B",
+   "tokenizer_dir": "tokenizer",
+   "hidden_size": 2560,
+   "yes_token_id": 9693,
+   "no_token_id": 2152,
+   "system_prompt": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".",
+   "pair_template": "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}",
+   "prefix_text": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n",
+   "suffix_text": "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
+   "created_at_utc": "2026-03-02T15:48:04.694204+00:00",
+   "profiles": [
+     {
+       "profile_id": "b1_s128",
+       "batch_size": 1,
+       "seq_len": 128,
+       "package_path": "packages/b1_s128.mlpackage",
+       "compiled_path": null,
+       "input_names": [
+         "input_ids",
+         "attention_mask"
+       ],
+       "output_name": "score"
+     }
+   ]
+ }
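The `prefix_text`, `pair_template`, and `suffix_text` fields above determine the exact string that gets scored for each query/document pair. A minimal sketch of that assembly (the constants are copied from the manifest; the runtime's own prompt-building helper is not part of this diff):

```python
# Constants copied verbatim from the bundle manifest above.
PREFIX = (
    "<|im_start|>system\nJudge whether the Document meets the requirements based on the "
    "Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."
    "<|im_end|>\n<|im_start|>user\n"
)
PAIR_TEMPLATE = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"


def render_pair(query: str, document: str, instruction: str) -> str:
    """Assemble the full scoring prompt: prefix + filled pair template + suffix."""
    pair = PAIR_TEMPLATE.format(instruction=instruction, query=query, document=document)
    return PREFIX + pair + SUFFIX
```

The rendered string is then tokenized (left-padded to the profile's fixed `seq_len` of 128) before being fed to the Core ML model.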
bundles/qwen3_reranker_ane_bundle_4b/packages/b1_s128.mlpackage/Data/com.apple.CoreML/model.mlmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:954e30d62948183fd3323ac8db3b1b2be3526ef788a19028807ce39642c4450b
+ size 791454
bundles/qwen3_reranker_ane_bundle_4b/packages/b1_s128.mlpackage/Data/com.apple.CoreML/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a76471b81324fce92ce1fe1edfbc9fe964abd77f3f97d84db70a1773f605de92
+ size 8043733440
bundles/qwen3_reranker_ane_bundle_4b/packages/b1_s128.mlpackage/Manifest.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "fileFormatVersion": "1.0.0",
+   "itemInfoEntries": {
+     "82518FB0-9416-470C-98CD-55225EF89B96": {
+       "author": "com.apple.CoreML",
+       "description": "CoreML Model Specification",
+       "name": "model.mlmodel",
+       "path": "com.apple.CoreML/model.mlmodel"
+     },
+     "FB126581-AA8F-423B-B9F0-2EBA683BA825": {
+       "author": "com.apple.CoreML",
+       "description": "CoreML Model Weights",
+       "name": "weights",
+       "path": "com.apple.CoreML/weights"
+     }
+   },
+   "rootModelIdentifier": "82518FB0-9416-470C-98CD-55225EF89B96"
+ }
bundles/qwen3_reranker_ane_bundle_4b/tokenizer/chat_template.jinja ADDED
@@ -0,0 +1,85 @@
+ {%- if tools %}
+     {{- '<|im_start|>system\n' }}
+     {%- if messages[0].role == 'system' %}
+         {{- messages[0].content + '\n\n' }}
+     {%- endif %}
+     {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+     {%- for tool in tools %}
+         {{- "\n" }}
+         {{- tool | tojson }}
+     {%- endfor %}
+     {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+ {%- else %}
+     {%- if messages[0].role == 'system' %}
+         {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
+     {%- endif %}
+ {%- endif %}
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+ {%- for message in messages[::-1] %}
+     {%- set index = (messages|length - 1) - loop.index0 %}
+     {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+         {%- set ns.multi_step_tool = false %}
+         {%- set ns.last_query_index = index %}
+     {%- endif %}
+ {%- endfor %}
+ {%- for message in messages %}
+     {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+         {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
+     {%- elif message.role == "assistant" %}
+         {%- set content = message.content %}
+         {%- set reasoning_content = '' %}
+         {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
+             {%- set reasoning_content = message.reasoning_content %}
+         {%- else %}
+             {%- if '</think>' in message.content %}
+                 {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
+                 {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+             {%- endif %}
+         {%- endif %}
+         {%- if loop.index0 > ns.last_query_index %}
+             {%- if loop.last or (not loop.last and reasoning_content) %}
+                 {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
+             {%- else %}
+                 {{- '<|im_start|>' + message.role + '\n' + content }}
+             {%- endif %}
+         {%- else %}
+             {{- '<|im_start|>' + message.role + '\n' + content }}
+         {%- endif %}
+         {%- if message.tool_calls %}
+             {%- for tool_call in message.tool_calls %}
+                 {%- if (loop.first and content) or (not loop.first) %}
+                     {{- '\n' }}
+                 {%- endif %}
+                 {%- if tool_call.function %}
+                     {%- set tool_call = tool_call.function %}
+                 {%- endif %}
+                 {{- '<tool_call>\n{"name": "' }}
+                 {{- tool_call.name }}
+                 {{- '", "arguments": ' }}
+                 {%- if tool_call.arguments is string %}
+                     {{- tool_call.arguments }}
+                 {%- else %}
+                     {{- tool_call.arguments | tojson }}
+                 {%- endif %}
+                 {{- '}\n</tool_call>' }}
+             {%- endfor %}
+         {%- endif %}
+         {{- '<|im_end|>\n' }}
+     {%- elif message.role == "tool" %}
+         {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+             {{- '<|im_start|>user' }}
+         {%- endif %}
+         {{- '\n<tool_response>\n' }}
+         {{- message.content }}
+         {{- '\n</tool_response>' }}
+         {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+             {{- '<|im_end|>\n' }}
+         {%- endif %}
+     {%- endif %}
+ {%- endfor %}
+ {%- if add_generation_prompt %}
+     {{- '<|im_start|>assistant\n' }}
+     {%- if enable_thinking is defined and enable_thinking is false %}
+         {{- '<think>\n\n</think>\n\n' }}
+     {%- endif %}
+ {%- endif %}
bundles/qwen3_reranker_ane_bundle_4b/tokenizer/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
+ size 11422650
bundles/qwen3_reranker_ane_bundle_4b/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "add_prefix_space": false,
+   "backend": "tokenizers",
+   "bos_token": null,
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "is_local": true,
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
qwen3_ane_rerank/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Qwen3 reranker conversion + ANE serving toolkit."""
+
+ from .manifest import BundleManifest, ProfileEntry
+
+ __all__ = ["BundleManifest", "ProfileEntry"]
qwen3_ane_rerank/__main__.py ADDED
@@ -0,0 +1,5 @@
+ from .cli import main
+
+
+ if __name__ == "__main__":
+     main()
qwen3_ane_rerank/api.py ADDED
@@ -0,0 +1,84 @@
+ from __future__ import annotations
+
+ from typing import Any
+
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel, Field
+
+ from .runtime import Qwen3AneRerankRuntime
+
+
+ class RerankRequest(BaseModel):
+     query: str
+     documents: list[str]
+     model: str | None = None
+     top_n: int | None = Field(default=None, ge=1)
+     return_documents: bool = False
+     instruction: str | None = None
+     user: str | None = None
+
+
+ def create_app(runtime: Qwen3AneRerankRuntime, default_model_id: str | None = None) -> FastAPI:
+     app = FastAPI(title="Qwen3 ANE Reranker Service", version="0.1.0")
+
+     @app.get("/health")
+     def health() -> dict[str, Any]:
+         return {
+             "ok": True,
+             "task": "rerank",
+             "model": default_model_id or runtime.manifest.model_name,
+             "profiles": [
+                 {
+                     "id": p.entry.profile_id,
+                     "batch_size": p.entry.batch_size,
+                     "seq_len": p.entry.seq_len,
+                 }
+                 for p in runtime.profiles
+             ],
+         }
+
+     @app.post("/rerank")
+     @app.post("/v1/rerank")
+     def rerank(req: RerankRequest) -> dict[str, Any]:
+         try:
+             if req.query == "":
+                 raise ValueError("query must not be empty")
+             if not req.documents:
+                 raise ValueError("documents must not be empty")
+             if any(doc == "" for doc in req.documents):
+                 raise ValueError("documents must not contain empty strings")
+
+             results, prompt_tokens = runtime.rerank(
+                 query=req.query,
+                 documents=req.documents,
+                 top_n=req.top_n,
+                 instruction=req.instruction,
+             )
+
+             data = []
+             for row in results:
+                 item = {
+                     "object": "rerank_result",
+                     "index": row["index"],
+                     "relevance_score": row["relevance_score"],
+                 }
+                 if req.return_documents:
+                     item["document"] = req.documents[row["index"]]
+                 data.append(item)
+
+             model_name = req.model or default_model_id or runtime.manifest.model_name
+             return {
+                 "object": "list",
+                 "data": data,
+                 "model": model_name,
+                 "usage": {
+                     "prompt_tokens": int(prompt_tokens),
+                     "total_tokens": int(prompt_tokens),
+                 },
+             }
+         except ValueError as exc:
+             raise HTTPException(status_code=400, detail=str(exc)) from exc
+         except RuntimeError as exc:
+             raise HTTPException(status_code=500, detail=str(exc)) from exc
+
+     return app
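The handler above delegates scoring to `runtime.rerank`, whose implementation is not included in this commit chunk. A minimal sketch, assuming it scores each document independently and returns rows sorted by descending `relevance_score` with `top_n` applied, matching the result shape the handler consumes:

```python
def rank_scores(scores, top_n=None):
    """Sort per-document scores into the row shape api.py expects from
    runtime.rerank: highest relevance first, truncated to top_n."""
    rows = sorted(
        ({"index": i, "relevance_score": float(s)} for i, s in enumerate(scores)),
        key=lambda row: row["relevance_score"],
        reverse=True,
    )
    return rows if top_n is None else rows[:top_n]
```

Note that `index` refers to the position in the original `documents` list, which is why the handler can look up `req.documents[row["index"]]` when `return_documents` is set.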
qwen3_ane_rerank/cli.py ADDED
@@ -0,0 +1,174 @@
+ from __future__ import annotations
+
+ import argparse
+ from pathlib import Path
+
+ from .converter import BuildOptions, build_bundle
+ from .profiles import parse_profiles
+ from .runtime import Qwen3AneRerankRuntime
+
+
+ def _add_common_build_args(parser: argparse.ArgumentParser) -> None:
+     parser.add_argument(
+         "--profiles",
+         type=str,
+         default=None,
+         help="Shape profiles as comma list BxS (e.g. 1x128,4x128)",
+     )
+     parser.add_argument(
+         "--target",
+         type=str,
+         default="macOS14",
+         choices=["macOS14", "macOS15", "iOS17", "iOS18"],
+         help="Core ML minimum deployment target",
+     )
+     parser.add_argument(
+         "--compile-mlmodelc",
+         action=argparse.BooleanOptionalAction,
+         default=True,
+         help="Compile .mlpackage into .mlmodelc with coremlcompiler",
+     )
+     parser.add_argument(
+         "--system-prompt",
+         default=(
+             "Judge whether the Document meets the requirements based on the Query and the "
+             'Instruct provided. Note that the answer can only be "yes" or "no".'
+         ),
+         help="System prompt used in reranker prompt template",
+     )
+
+
+ def cmd_convert(args: argparse.Namespace) -> None:
+     profiles = parse_profiles(args.profiles)
+     options = BuildOptions(
+         model_dir=Path(args.model_dir),
+         bundle_dir=Path(args.bundle_dir),
+         profiles=profiles,
+         compile_mlmodelc=bool(args.compile_mlmodelc),
+         minimum_deployment_target=args.target,
+         system_prompt=args.system_prompt,
+     )
+     manifest = build_bundle(options)
+     print(f"Built bundle at: {Path(args.bundle_dir).resolve()}")
+     print(f"Model: {manifest.model_name}")
+     print(f"Task: {manifest.task}")
+     print(f"Hidden size: {manifest.hidden_size}")
+     print(f"Token ids yes/no: {manifest.yes_token_id}/{manifest.no_token_id}")
+     print("Profiles:")
+     for entry in manifest.profiles:
+         print(
+             f" - {entry.profile_id}: batch={entry.batch_size}, seq={entry.seq_len}, "
+             f"model={entry.compiled_path or entry.package_path}"
+         )
+
+
+ def cmd_serve(args: argparse.Namespace) -> None:
+     bundle_dir = Path(args.bundle_dir)
+     manifest_path = bundle_dir / "manifest.json"
+
+     if not manifest_path.exists():
+         if not args.auto_build:
+             raise SystemExit(
+                 f"Bundle not found at {bundle_dir}. Run convert first or pass --auto-build --model-dir."
+             )
+         if not args.model_dir:
+             raise SystemExit("--model-dir is required when --auto-build is enabled")
+
+         profiles = parse_profiles(args.profiles)
+         options = BuildOptions(
+             model_dir=Path(args.model_dir),
+             bundle_dir=bundle_dir,
+             profiles=profiles,
+             compile_mlmodelc=bool(args.compile_mlmodelc),
+             minimum_deployment_target=args.target,
+             system_prompt=args.system_prompt,
+         )
+         print("Bundle not found; building from source model...")
+         build_bundle(options)
+
+     runtime = Qwen3AneRerankRuntime(bundle_dir=bundle_dir, compute_units=args.compute_units)
+
+     from .api import create_app
+     import uvicorn
+
+     app = create_app(runtime=runtime, default_model_id=args.model_id)
+     uvicorn.run(
+         app,
+         host=args.host,
+         port=args.port,
+         log_level=args.log_level,
+     )
+
+
+ def build_parser() -> argparse.ArgumentParser:
+     parser = argparse.ArgumentParser(
+         prog="qwen3-ane-rerank",
+         description="Convert Qwen3-Reranker model to Core ML ANE bundle and serve /v1/rerank endpoint.",
+     )
+     subparsers = parser.add_subparsers(dest="command", required=True)
+
+     convert_parser = subparsers.add_parser(
+         "convert",
+         help="Convert local HF Qwen3-Reranker model into ANE-ready Core ML profile bundle",
+     )
+     convert_parser.add_argument("--model-dir", required=True, help="Path to source HF model directory")
+     convert_parser.add_argument(
+         "--bundle-dir",
+         required=True,
+         help="Output bundle directory (manifest + packages + tokenizer)",
+     )
+     _add_common_build_args(convert_parser)
+     convert_parser.set_defaults(func=cmd_convert)
+
+     serve_parser = subparsers.add_parser(
+         "serve",
+         help="Run /v1/rerank endpoint backed by Core ML ANE profiles",
+     )
+     serve_parser.add_argument(
+         "--bundle-dir",
+         required=True,
+         help="Bundle directory created by convert",
+     )
+     serve_parser.add_argument(
+         "--model-dir",
+         default=None,
+         help="Source HF model directory (required if --auto-build and bundle missing)",
+     )
+     serve_parser.add_argument(
+         "--auto-build",
+         action=argparse.BooleanOptionalAction,
+         default=True,
+         help="Auto-build bundle from --model-dir when manifest is missing",
+     )
+     _add_common_build_args(serve_parser)
+     serve_parser.add_argument("--host", default="127.0.0.1")
+     serve_parser.add_argument("--port", type=int, default=8000)
+     serve_parser.add_argument(
+         "--compute-units",
+         default="cpu_and_ne",
+         choices=["cpu_and_ne", "all", "cpu_only", "cpu_and_gpu"],
+         help="Core ML compute units preference",
+     )
+     serve_parser.add_argument(
+         "--model-id",
155
+ help="Model id returned in API responses",
156
+ )
157
+ serve_parser.add_argument(
158
+ "--log-level",
159
+ default="info",
160
+ choices=["critical", "error", "warning", "info", "debug", "trace"],
161
+ )
162
+ serve_parser.set_defaults(func=cmd_serve)
163
+
164
+ return parser
165
+
166
+
167
+ def main() -> None:
168
+ parser = build_parser()
169
+ args = parser.parse_args()
170
+ args.func(args)
171
+
172
+
173
+ if __name__ == "__main__":
174
+ main()
qwen3_ane_rerank/converter.py ADDED
@@ -0,0 +1,267 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ import shutil
+ import subprocess
+ from typing import Any
+
+ from .manifest import BundleManifest, ProfileEntry
+ from .profiles import ShapeProfile
+
+
+ DEFAULT_SYSTEM_PROMPT = (
+     "Judge whether the Document meets the requirements based on the Query and the "
+     'Instruct provided. Note that the answer can only be "yes" or "no".'
+ )
+ DEFAULT_PAIR_TEMPLATE = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"
+ DEFAULT_PREFIX_TEMPLATE = "<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n"
+ DEFAULT_SUFFIX_TEXT = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+
+
+ @dataclass(slots=True)
+ class BuildOptions:
+     model_dir: Path
+     bundle_dir: Path
+     profiles: list[ShapeProfile]
+     compile_mlmodelc: bool = True
+     minimum_deployment_target: str = "macOS14"
+     system_prompt: str = DEFAULT_SYSTEM_PROMPT
+     pair_template: str = DEFAULT_PAIR_TEMPLATE
+     suffix_text: str = DEFAULT_SUFFIX_TEXT
+
+
+ def _import_conversion_deps() -> tuple[Any, Any, Any, Any, Any]:
+     try:
+         import numpy as np
+         import torch
+         import coremltools as ct
+         from transformers import AutoModelForCausalLM, AutoTokenizer
+     except Exception as exc:  # pragma: no cover - runtime dependency check
+         raise RuntimeError(
+             "Missing conversion dependencies. Install torch, transformers, coremltools, numpy."
+         ) from exc
+     return np, torch, ct, AutoModelForCausalLM, AutoTokenizer
+
+
+ def _resolve_target(ct: Any, raw: str) -> Any:
+     if raw == "macOS14":
+         return ct.target.macOS14
+     if raw == "macOS15":
+         return ct.target.macOS15
+     if raw == "iOS17":
+         return ct.target.iOS17
+     if raw == "iOS18":
+         return ct.target.iOS18
+     raise ValueError(f"Unsupported minimum deployment target: {raw}")
+
+
+ def _compile_mlpackage(package_path: Path, compiled_root: Path) -> Path:
+     compiled_root.mkdir(parents=True, exist_ok=True)
+     cmd = ["xcrun", "coremlcompiler", "compile", str(package_path), str(compiled_root)]
+     proc = subprocess.run(cmd, capture_output=True, text=True)
+     if proc.returncode != 0:
+         raise RuntimeError(
+             "coremlcompiler compile failed for "
+             f"{package_path.name}:\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}"
+         )
+
+     expected = compiled_root / f"{package_path.stem}.mlmodelc"
+     if expected.exists():
+         return expected
+
+     matches = sorted(compiled_root.glob("*.mlmodelc"), key=lambda p: p.stat().st_mtime)
+     if not matches:
+         raise RuntimeError(f"coremlcompiler succeeded but no .mlmodelc found under {compiled_root}")
+     return matches[-1]
+
+
+ def _ensure_tokenizer_has_pad_token(tokenizer: Any) -> None:
+     if tokenizer.pad_token_id is not None:
+         return
+     if tokenizer.eos_token is not None:
+         tokenizer.pad_token = tokenizer.eos_token
+         return
+     if tokenizer.unk_token is not None:
+         tokenizer.pad_token = tokenizer.unk_token
+         return
+     raise RuntimeError("Tokenizer has no pad/eos/unk token; cannot build fixed-shape ANE pipeline")
+
+
+ def _resolve_token_id(tokenizer: Any, token: str) -> int:
+     token_id = tokenizer.convert_tokens_to_ids(token)
+     if token_id is None:
+         raise RuntimeError(f"Unable to resolve token id for '{token}'")
+     unk_id = tokenizer.unk_token_id
+     if token_id < 0 or (unk_id is not None and token_id == unk_id):
+         raise RuntimeError(f"Token '{token}' is missing from tokenizer vocab")
+     return int(token_id)
+
+
+ def build_bundle(options: BuildOptions) -> BundleManifest:
+     np, torch, ct, AutoModelForCausalLM, AutoTokenizer = _import_conversion_deps()
+
+     model_dir = options.model_dir.resolve()
+     bundle_dir = options.bundle_dir.resolve()
+     packages_dir = bundle_dir / "packages"
+     compiled_dir = bundle_dir / "compiled"
+     tokenizer_dir = bundle_dir / "tokenizer"
+
+     bundle_dir.mkdir(parents=True, exist_ok=True)
+     packages_dir.mkdir(parents=True, exist_ok=True)
+     if options.compile_mlmodelc:
+         compiled_dir.mkdir(parents=True, exist_ok=True)
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         str(model_dir),
+         local_files_only=True,
+         trust_remote_code=False,
+         use_fast=True,
+     )
+     tokenizer.padding_side = "left"
+     _ensure_tokenizer_has_pad_token(tokenizer)
+     tokenizer.save_pretrained(str(tokenizer_dir))
+
+     model = AutoModelForCausalLM.from_pretrained(
+         str(model_dir),
+         local_files_only=True,
+         trust_remote_code=False,
+         dtype=torch.float32,
+     )
+     model = model.float().eval()
+     if hasattr(model, "config") and hasattr(model.config, "use_cache"):
+         model.config.use_cache = False
+     if hasattr(model, "config") and hasattr(model.config, "_attn_implementation"):
+         model.config._attn_implementation = "eager"
+
+     if not hasattr(model, "model") or not hasattr(model, "lm_head"):
+         raise RuntimeError("Unsupported model structure: expected .model backbone and .lm_head")
+
+     hidden_size = int(getattr(model.config, "hidden_size", 0))
+     if hidden_size <= 0:
+         raise RuntimeError("Unable to infer hidden size from model config")
+
+     yes_token_id = _resolve_token_id(tokenizer, "yes")
+     no_token_id = _resolve_token_id(tokenizer, "no")
+     score_weight = (
+         model.lm_head.weight[yes_token_id].detach().to(torch.float32)
+         - model.lm_head.weight[no_token_id].detach().to(torch.float32)
+     )
+
+     backbone = model.model
+
+     class Qwen3RerankWrapper(torch.nn.Module):
+         def __init__(self, language_backbone: Any, score_weight_vec: Any, seq_len: int):
+             super().__init__()
+             self.backbone = language_backbone
+             self.register_buffer("score_weight", score_weight_vec.view(1, -1), persistent=False)
+             causal = torch.tril(torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32))
+             self.register_buffer("causal_template", causal, persistent=False)
+             self.neg_inf = -1e4
+
+         def _build_attention_bias(self, attention_mask: Any) -> Any:
+             mask = attention_mask.to(torch.float32)
+             key_valid = mask.unsqueeze(1).unsqueeze(1)
+             query_valid = mask.unsqueeze(1).unsqueeze(3)
+             allowed = self.causal_template * key_valid * query_valid
+             return (1.0 - allowed) * self.neg_inf
+
+         def forward(self, input_ids: Any, attention_mask: Any) -> Any:
+             input_ids = input_ids.to(torch.int64)
+             attention_bias = self._build_attention_bias(attention_mask)
+             outputs = self.backbone(
+                 input_ids=input_ids,
+                 attention_mask=attention_bias,
+                 return_dict=False,
+             )
+             last_hidden = outputs[0][:, -1, :]
+             logit_delta = (last_hidden * self.score_weight).sum(dim=1, keepdim=True)
+             return torch.sigmoid(logit_delta)
+
+     target = _resolve_target(ct, options.minimum_deployment_target)
+     prefix_text = DEFAULT_PREFIX_TEMPLATE.format(system_prompt=options.system_prompt)
+
+     profile_entries: list[ProfileEntry] = []
+     for profile in options.profiles:
+         profile_id = profile.profile_id
+         package_path = packages_dir / f"{profile_id}.mlpackage"
+         if package_path.exists():
+             shutil.rmtree(package_path)
+         wrapper = Qwen3RerankWrapper(backbone, score_weight, seq_len=profile.seq_len).eval()
+
+         input_ids = torch.full(
+             (profile.batch_size, profile.seq_len),
+             fill_value=int(tokenizer.pad_token_id),
+             dtype=torch.int32,
+         )
+         attention_mask = torch.ones(
+             (profile.batch_size, profile.seq_len),
+             dtype=torch.int32,
+         )
+
+         exported = torch.export.export(wrapper, (input_ids, attention_mask), strict=False)
+         exported = exported.run_decompositions({})
+
+         ct.convert(
+             exported,
+             convert_to="mlprogram",
+             minimum_deployment_target=target,
+             compute_precision=ct.precision.FLOAT16,
+             skip_model_load=True,
+             package_dir=str(package_path),
+             inputs=[
+                 ct.TensorType(
+                     name="input_ids",
+                     shape=input_ids.shape,
+                     dtype=np.int32,
+                 ),
+                 ct.TensorType(
+                     name="attention_mask",
+                     shape=attention_mask.shape,
+                     dtype=np.int32,
+                 ),
+             ],
+             outputs=[ct.TensorType(name="score")],
+         )
+
+         compiled_path: Path | None = None
+         if options.compile_mlmodelc:
+             compiled_path = _compile_mlpackage(package_path, compiled_dir)
+
+         profile_entries.append(
+             ProfileEntry(
+                 profile_id=profile_id,
+                 batch_size=profile.batch_size,
+                 seq_len=profile.seq_len,
+                 package_path=str(package_path.relative_to(bundle_dir)),
+                 compiled_path=(
+                     str(compiled_path.relative_to(bundle_dir)) if compiled_path is not None else None
+                 ),
+                 input_names=["input_ids", "attention_mask"],
+                 output_name="score",
+             )
+         )
+
+     manifest = BundleManifest.create(
+         model_name=Path(model_dir).name,
+         source_model_dir=str(model_dir),
+         tokenizer_dir=str(tokenizer_dir.relative_to(bundle_dir)),
+         hidden_size=hidden_size,
+         yes_token_id=yes_token_id,
+         no_token_id=no_token_id,
+         system_prompt=options.system_prompt,
+         pair_template=options.pair_template,
+         prefix_text=prefix_text,
+         suffix_text=options.suffix_text,
+         profiles=profile_entries,
+     )
+     manifest.save(bundle_dir)
+     return manifest
+
+
+ def build_bundle_if_missing(options: BuildOptions) -> BundleManifest:
+     bundle_dir = options.bundle_dir.resolve()
+     manifest_path = bundle_dir / "manifest.json"
+     if manifest_path.exists():
+         return BundleManifest.load(bundle_dir)
+     return build_bundle(options)
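The wrapper's `_build_attention_bias` fuses the causal lower-triangular template with key and query validity from the padding mask (padding is on the left, per `tokenizer.padding_side`), producing an additive bias of `0` where attention is allowed and `-1e4` elsewhere. An equivalent pure-Python sketch for a single sequence:

```python
NEG_INF = -1.0e4  # same additive mask value as the wrapper's self.neg_inf


def build_attention_bias(attention_mask):
    """Additive attention bias for one sequence: 0.0 where attention is
    allowed (key position is causal and both query and key are unpadded),
    NEG_INF everywhere else."""
    n = len(attention_mask)
    bias = []
    for q in range(n):
        row = []
        for k in range(n):
            causal = k <= q  # no attending to future positions
            allowed = causal and attention_mask[q] == 1 and attention_mask[k] == 1
            row.append(0.0 if allowed else NEG_INF)
        bias.append(row)
    return bias
```

Masking padded *query* rows as well as padded keys keeps the padded rows' outputs well-defined; only the final position's hidden state is read (`outputs[0][:, -1, :]`), which with left padding is always a real token.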
qwen3_ane_rerank/manifest.py ADDED
@@ -0,0 +1,109 @@
from __future__ import annotations

from dataclasses import asdict, dataclass
from datetime import datetime, timezone
import json
from pathlib import Path
from typing import Any


MANIFEST_FILENAME = "manifest.json"


@dataclass(slots=True)
class ProfileEntry:
    profile_id: str
    batch_size: int
    seq_len: int
    package_path: str
    compiled_path: str | None
    input_names: list[str]
    output_name: str


@dataclass(slots=True)
class BundleManifest:
    format_version: int
    task: str
    model_name: str
    source_model_dir: str
    tokenizer_dir: str
    hidden_size: int
    yes_token_id: int
    no_token_id: int
    system_prompt: str
    pair_template: str
    prefix_text: str
    suffix_text: str
    created_at_utc: str
    profiles: list[ProfileEntry]

    @classmethod
    def create(
        cls,
        *,
        model_name: str,
        source_model_dir: str,
        tokenizer_dir: str,
        hidden_size: int,
        yes_token_id: int,
        no_token_id: int,
        system_prompt: str,
        pair_template: str,
        prefix_text: str,
        suffix_text: str,
        profiles: list[ProfileEntry],
    ) -> "BundleManifest":
        return cls(
            format_version=1,
            task="rerank",
            model_name=model_name,
            source_model_dir=source_model_dir,
            tokenizer_dir=tokenizer_dir,
            hidden_size=hidden_size,
            yes_token_id=yes_token_id,
            no_token_id=no_token_id,
            system_prompt=system_prompt,
            pair_template=pair_template,
            prefix_text=prefix_text,
            suffix_text=suffix_text,
            created_at_utc=datetime.now(timezone.utc).isoformat(),
            profiles=profiles,
        )

    @classmethod
    def from_dict(cls, payload: dict[str, Any]) -> "BundleManifest":
        profiles = [ProfileEntry(**entry) for entry in payload["profiles"]]
        return cls(
            format_version=payload["format_version"],
            task=payload.get("task", "rerank"),
            model_name=payload["model_name"],
            source_model_dir=payload["source_model_dir"],
            tokenizer_dir=payload["tokenizer_dir"],
            hidden_size=payload["hidden_size"],
            yes_token_id=payload["yes_token_id"],
            no_token_id=payload["no_token_id"],
            system_prompt=payload["system_prompt"],
            pair_template=payload["pair_template"],
            prefix_text=payload["prefix_text"],
            suffix_text=payload["suffix_text"],
            created_at_utc=payload["created_at_utc"],
            profiles=profiles,
        )

    def to_dict(self) -> dict[str, Any]:
        payload = asdict(self)
        payload["profiles"] = [asdict(entry) for entry in self.profiles]
        return payload

    def save(self, bundle_dir: Path) -> Path:
        bundle_dir.mkdir(parents=True, exist_ok=True)
        path = bundle_dir / MANIFEST_FILENAME
        path.write_text(json.dumps(self.to_dict(), indent=2, ensure_ascii=False), encoding="utf-8")
        return path

    @classmethod
    def load(cls, bundle_dir: Path) -> "BundleManifest":
        path = bundle_dir / MANIFEST_FILENAME
        payload = json.loads(path.read_text(encoding="utf-8"))
        return cls.from_dict(payload)
qwen3_ane_rerank/profiles.py ADDED
@@ -0,0 +1,47 @@
from __future__ import annotations

from dataclasses import dataclass


DEFAULT_PROFILES = [
    (1, 128),
    (4, 128),
]


@dataclass(frozen=True, slots=True)
class ShapeProfile:
    batch_size: int
    seq_len: int

    @property
    def profile_id(self) -> str:
        return f"b{self.batch_size}_s{self.seq_len}"


def parse_profiles(raw: str | None) -> list[ShapeProfile]:
    if not raw:
        return [ShapeProfile(batch_size=b, seq_len=s) for b, s in DEFAULT_PROFILES]

    parsed: list[ShapeProfile] = []
    for item in raw.split(","):
        item = item.strip().lower()
        if not item:
            continue
        if "x" not in item:
            raise ValueError(f"Invalid profile '{item}'. Expected format BxS, e.g. 4x512")
        left, right = item.split("x", 1)
        batch_size = int(left)
        seq_len = int(right)
        if batch_size <= 0 or seq_len <= 0:
            raise ValueError(f"Invalid profile '{item}'. B and S must be positive")
        parsed.append(ShapeProfile(batch_size=batch_size, seq_len=seq_len))

    if not parsed:
        raise ValueError("No valid profiles parsed")

    unique = {(p.batch_size, p.seq_len): p for p in parsed}
    return [
        unique[key]
        for key in sorted(unique.keys(), key=lambda x: (x[1], x[0]))
    ]
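`parse_profiles` deduplicates (batch, seq) pairs and orders them by `seq_len` first, then `batch_size`. The effect of that dedup-and-sort, sketched on bare tuples:

```python
raw = [(4, 128), (1, 128), (4, 128), (1, 512)]

# Same logic as parse_profiles: dict keys drop the duplicate (4, 128),
# and the key (seq_len, batch_size) orders by sequence length first.
unique = dict.fromkeys(raw)
ordered = sorted(unique, key=lambda p: (p[1], p[0]))
print(ordered)  # [(1, 128), (4, 128), (1, 512)]
```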
qwen3_ane_rerank/runtime.py ADDED
@@ -0,0 +1,278 @@
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable

from .manifest import BundleManifest, ProfileEntry


MAX_INPUT_TOKENS_PER_ITEM = 8192
MAX_TOTAL_TOKENS_PER_REQUEST = 300000
DEFAULT_INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query."


def _import_runtime_deps() -> tuple[Any, Any, Any]:
    try:
        import numpy as np
        import coremltools as ct
        from transformers import AutoTokenizer
    except Exception as exc:  # pragma: no cover - runtime dependency check
        raise RuntimeError(
            "Missing runtime dependencies. Install numpy, coremltools, transformers."
        ) from exc
    return np, ct, AutoTokenizer


@dataclass(slots=True)
class LoadedProfile:
    entry: ProfileEntry
    model_path: Path
    model: Any | None = None


class Qwen3AneRerankRuntime:
    def __init__(self, bundle_dir: str | Path, compute_units: str = "cpu_and_ne") -> None:
        np, ct, AutoTokenizer = _import_runtime_deps()
        self.np = np
        self.ct = ct

        self.bundle_dir = Path(bundle_dir).resolve()
        self.manifest = BundleManifest.load(self.bundle_dir)

        self.tokenizer = AutoTokenizer.from_pretrained(
            str(self.bundle_dir / self.manifest.tokenizer_dir),
            local_files_only=True,
            trust_remote_code=False,
            use_fast=True,
        )
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            elif self.tokenizer.unk_token is not None:
                self.tokenizer.pad_token = self.tokenizer.unk_token
            else:
                raise RuntimeError("Tokenizer has no pad/eos/unk token")

        self.prefix_tokens = self.tokenizer.encode(
            self.manifest.prefix_text,
            add_special_tokens=False,
        )
        self.suffix_tokens = self.tokenizer.encode(
            self.manifest.suffix_text,
            add_special_tokens=False,
        )
        self.static_token_cost = len(self.prefix_tokens) + len(self.suffix_tokens)

        self.compute_units = self._resolve_compute_units(compute_units)

        self.profiles: list[LoadedProfile] = []
        for entry in self.manifest.profiles:
            package_path = self.bundle_dir / entry.package_path
            model_path = package_path
            if entry.compiled_path is not None:
                compiled_path = self.bundle_dir / entry.compiled_path
                if (compiled_path / "Manifest.json").exists():
                    model_path = compiled_path
            self.profiles.append(LoadedProfile(entry=entry, model_path=model_path))

        if not self.profiles:
            raise RuntimeError("No profiles found in manifest")

        self.max_profile_batch = max(p.entry.batch_size for p in self.profiles)
        self.max_profile_seq = max(p.entry.seq_len for p in self.profiles)

        if self.static_token_cost >= self.max_profile_seq:
            raise RuntimeError(
                "Profile seq_len is too small for reranker prompt template. "
                f"Need > {self.static_token_cost}, got {self.max_profile_seq}."
            )

    def _resolve_compute_units(self, raw: str) -> Any:
        mode = raw.strip().lower()
        cu = self.ct.ComputeUnit
        if mode == "cpu_and_ne":
            # Older coremltools releases lack CPU_AND_NE; fall back to ALL.
            return cu.CPU_AND_NE if hasattr(cu, "CPU_AND_NE") else cu.ALL
        if mode == "all":
            return cu.ALL
        if mode == "cpu_only":
            return cu.CPU_ONLY
        if mode == "cpu_and_gpu":
            return cu.CPU_AND_GPU
        raise ValueError(f"Unsupported compute unit mode: {raw}")

    def _get_model(self, profile: LoadedProfile) -> Any:
        if profile.model is None:
            profile.model = self.ct.models.MLModel(
                str(profile.model_path),
                compute_units=self.compute_units,
            )
        return profile.model

    def _select_profile(self, batch_size: int, seq_len: int) -> LoadedProfile | None:
        candidates = [
            p
            for p in self.profiles
            if p.entry.batch_size >= batch_size and p.entry.seq_len >= seq_len
        ]
        if not candidates:
            return None
        candidates.sort(key=lambda p: (p.entry.batch_size * p.entry.seq_len, p.entry.seq_len, p.entry.batch_size))
        return candidates[0]

    def _plan_chunks(self, lengths: list[int]) -> list[tuple[int, int, LoadedProfile]]:
        chunks: list[tuple[int, int, LoadedProfile]] = []
        i = 0
        n = len(lengths)
        while i < n:
            best: tuple[int, LoadedProfile] | None = None
            max_batch = min(self.max_profile_batch, n - i)
            for b in range(max_batch, 0, -1):
                max_len = max(lengths[i : i + b])
                profile = self._select_profile(batch_size=b, seq_len=max_len)
                if profile is not None:
                    best = (b, profile)
                    break
            if best is None:
                raise ValueError(
                    f"No profile can serve items starting at index {i}. Required seq_len={lengths[i]}"
                )
            b, profile = best
            chunks.append((i, i + b, profile))
            i += b
        return chunks

    def _predict_scores(self, profile: LoadedProfile, input_ids: Any, attention_mask: Any) -> Any:
        model = self._get_model(profile)
        out = model.predict(
            {
                profile.entry.input_names[0]: input_ids,
                profile.entry.input_names[1]: attention_mask,
            }
        )
        raw = out.get(profile.entry.output_name, next(iter(out.values())))
        scores = self.np.asarray(raw, dtype=self.np.float32)
        if scores.ndim == 0:
            scores = scores.reshape(1)
        elif scores.ndim == 2 and scores.shape[1] == 1:
            scores = scores[:, 0]
        elif scores.ndim > 1:
            scores = scores.reshape(scores.shape[0], -1)[:, 0]
        return scores

    def _validate_token_limits(self, token_lengths: Iterable[int]) -> None:
        lengths = list(token_lengths)
        if any(length <= 0 for length in lengths):
            raise ValueError("Input pair must not be empty")
        if any(length > MAX_INPUT_TOKENS_PER_ITEM for length in lengths):
            raise ValueError(
                f"Each pair must be <= {MAX_INPUT_TOKENS_PER_ITEM} tokens before truncation"
            )
        if sum(lengths) > MAX_TOTAL_TOKENS_PER_REQUEST:
            raise ValueError(
                f"Total tokens across request must be <= {MAX_TOTAL_TOKENS_PER_REQUEST}"
            )

    def _format_pair_text(self, query: str, document: str, instruction: str) -> str:
        if "{instruction}" not in self.manifest.pair_template:
            raise RuntimeError("Invalid pair template: missing {instruction}")
        if "{query}" not in self.manifest.pair_template:
            raise RuntimeError("Invalid pair template: missing {query}")
        if "{document}" not in self.manifest.pair_template:
            raise RuntimeError("Invalid pair template: missing {document}")
        return self.manifest.pair_template.format(
            instruction=instruction,
            query=query,
            document=document,
        )

    def _pair_token_len(self, pair_text: str) -> int:
        body_len = len(
            self.tokenizer.encode(
                pair_text,
                add_special_tokens=False,
                truncation=False,
            )
        )
        return self.static_token_cost + body_len

    def _build_pair_ids(self, pair_text: str, seq_len: int) -> list[int]:
        body_budget = seq_len - self.static_token_cost
        if body_budget <= 0:
            raise RuntimeError(f"seq_len={seq_len} is too small for reranker template")
        body_ids = self.tokenizer.encode(
            pair_text,
            add_special_tokens=False,
            truncation=True,
            max_length=body_budget,
        )
        return self.prefix_tokens + body_ids + self.suffix_tokens

    def rerank(
        self,
        query: str,
        documents: list[str],
        *,
        top_n: int | None = None,
        instruction: str | None = None,
    ) -> tuple[list[dict[str, Any]], int]:
        if not query:
            raise ValueError("query must not be empty")
        if not documents:
            raise ValueError("documents must not be empty")
        if any(doc == "" for doc in documents):
            raise ValueError("documents must not contain empty strings")

        instruction_text = instruction or DEFAULT_INSTRUCTION
        pair_texts = [self._format_pair_text(query, doc, instruction_text) for doc in documents]
        raw_lengths = [self._pair_token_len(text) for text in pair_texts]
        self._validate_token_limits(raw_lengths)

        too_long = [idx for idx, length in enumerate(raw_lengths) if length > self.max_profile_seq]
        if too_long:
            first = too_long[0]
            raise ValueError(
                f"pair at index {first} has {raw_lengths[first]} tokens, "
                f"but compiled profiles only support up to {self.max_profile_seq}. "
                "Rebuild bundle with larger seq profiles."
            )

        effective_lengths = [min(length, self.max_profile_seq) for length in raw_lengths]
        chunks = self._plan_chunks(effective_lengths)
        pad_id = int(self.tokenizer.pad_token_id)

        all_scores: list[Any] = []
        prompt_tokens = 0

        for start, end, profile in chunks:
            chunk_texts = pair_texts[start:end]
            profile_batch = profile.entry.batch_size
            seq_len = profile.entry.seq_len

            input_ids = self.np.full((profile_batch, seq_len), fill_value=pad_id, dtype=self.np.int32)
            attention_mask = self.np.zeros((profile_batch, seq_len), dtype=self.np.int32)

            for row, pair_text in enumerate(chunk_texts):
                ids = self._build_pair_ids(pair_text, seq_len=seq_len)
                tlen = len(ids)
                offset = seq_len - tlen
                input_ids[row, offset:] = self.np.asarray(ids, dtype=self.np.int32)
                attention_mask[row, offset:] = 1
                prompt_tokens += tlen

            scores = self._predict_scores(profile, input_ids, attention_mask)
            all_scores.append(scores[: len(chunk_texts)])

        merged_scores = self.np.concatenate(all_scores, axis=0).astype(self.np.float32)

        ranked = [
            {"index": idx, "relevance_score": float(score)}
            for idx, score in enumerate(merged_scores.tolist())
        ]
        ranked.sort(key=lambda item: item["relevance_score"], reverse=True)

        n_results = len(ranked) if top_n is None else max(1, min(int(top_n), len(ranked)))
        return ranked[:n_results], prompt_tokens
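`_plan_chunks` greedily packs consecutive pairs into the largest batch any compiled profile can serve, retrying with smaller batches when the chunk's longest item would overflow. A self-contained sketch with the default `(1, 128)` / `(4, 128)` profiles, where plain tuples stand in for `LoadedProfile`:

```python
PROFILES = [(1, 128), (4, 128)]  # (batch_size, seq_len), as in DEFAULT_PROFILES

def select(batch: int, seq: int):
    # Smallest fitting profile by total cells, mirroring _select_profile's sort key.
    fits = [p for p in PROFILES if p[0] >= batch and p[1] >= seq]
    return min(fits, key=lambda p: (p[0] * p[1], p[1], p[0])) if fits else None

def plan(lengths: list[int]):
    chunks, i = [], 0
    max_profile_batch = max(p[0] for p in PROFILES)
    while i < len(lengths):
        for b in range(min(max_profile_batch, len(lengths) - i), 0, -1):
            prof = select(b, max(lengths[i:i + b]))
            if prof is not None:
                chunks.append((i, i + b, prof))
                i += b
                break
        else:
            raise ValueError(f"no profile for item at index {i}")
    return chunks

print(plan([100, 90, 80, 70, 60]))
# Five items: one full b4_s128 chunk, then a single-item b1_s128 chunk.
```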
requirements-service.txt ADDED
@@ -0,0 +1,5 @@
numpy==2.4.2
coremltools==9.0
transformers==5.2.0
fastapi==0.135.1
uvicorn==0.41.0
run_server.sh ADDED
@@ -0,0 +1,74 @@
#!/usr/bin/env bash
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BUNDLE_DIR="${BUNDLE_DIR:-$ROOT_DIR/bundles/qwen3_reranker_ane_bundle_4b}"
HOST="${HOST:-127.0.0.1}"
PORT="${PORT:-8000}"
COMPUTE_UNITS="${COMPUTE_UNITS:-cpu_and_ne}"
MODEL_ID="${MODEL_ID:-qwen3-reranker-4b-ane}"

if [[ -n "${PYTHON_BIN:-}" ]]; then
  PY_BIN="$PYTHON_BIN"
elif [[ -x "$ROOT_DIR/.venv/bin/python" ]]; then
  PY_BIN="$ROOT_DIR/.venv/bin/python"
else
  PY_BIN="python3"
fi

if ! command -v "$PY_BIN" >/dev/null 2>&1; then
  echo "[ERROR] Python is not available: $PY_BIN"
  echo "Run first: python3.11 -m venv .venv && source .venv/bin/activate && python -m pip install -r requirements-service.txt"
  exit 1
fi

PY_MM="$($PY_BIN -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')"
case "$PY_MM" in
  3.10|3.11|3.12) ;;
  *)
    echo "[ERROR] Python ${PY_MM} is not supported."
    echo "coremltools cannot load its native libraries on this version (typical errors: libcoremlpython / libmilstoragepython)."
    echo "Use Python 3.11 instead:"
    echo "  python3.11 -m venv .venv"
    echo "  source .venv/bin/activate"
    echo "  python -m pip install -r requirements-service.txt"
    exit 1
    ;;
esac

if [[ ! -f "$BUNDLE_DIR/manifest.json" ]]; then
  echo "[ERROR] Bundle manifest not found: $BUNDLE_DIR/manifest.json"
  exit 1
fi

$PY_BIN - <<'PY'
import sys

errors = []
try:
    import coremltools  # noqa: F401
except Exception as e:
    errors.append(f"failed to import coremltools: {e}")

for mod in ("coremltools.libcoremlpython", "coremltools.libmilstoragepython"):
    try:
        __import__(mod)
    except Exception as e:
        errors.append(f"failed to load {mod}: {e}")

if errors:
    print("[ERROR] Core ML Python runtime libraries are unavailable:")
    for item in errors:
        print(" -", item)
    print("Make sure dependencies were installed in a Python 3.11 virtual environment.")
    sys.exit(1)
PY

cd "$ROOT_DIR"
exec "$PY_BIN" -m qwen3_ane_rerank serve \
  --bundle-dir "$BUNDLE_DIR" \
  --no-auto-build \
  --compute-units "$COMPUTE_UNITS" \
  --model-id "$MODEL_ID" \
  --host "$HOST" \
  --port "$PORT"
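Once the server is up, a rerank request can be posted to the endpoints listed in the README. The sketch below builds a request body whose field names mirror the runtime's `rerank()` signature (`query`, `documents`, `top_n`) plus the default model id; the service's exact schema may differ, so treat the payload shape as an assumption:

```shell
# Hypothetical request payload; field names follow runtime.rerank() and the
# README's model id, not a confirmed service schema.
cat > rerank_request.json <<'JSON'
{
  "model": "qwen3-reranker-4b-ane",
  "query": "what is the Apple Neural Engine?",
  "documents": ["The ANE is Apple's neural accelerator.", "Unrelated text."],
  "top_n": 1
}
JSON

# With the server running on the default host/port:
# curl -s http://127.0.0.1:8000/v1/rerank \
#   -H 'Content-Type: application/json' \
#   -d @rerank_request.json
cat rerank_request.json
```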
setup_venv.sh ADDED
@@ -0,0 +1,33 @@
#!/usr/bin/env bash
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$ROOT_DIR"

if ! command -v python3.11 >/dev/null 2>&1; then
  echo "[ERROR] python3.11 not found. Install Python 3.11 first."
  exit 1
fi

if [[ -d .venv ]]; then
  CUR_VER="$(.venv/bin/python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null || true)"
  if [[ "$CUR_VER" != "3.11" ]]; then
    BACKUP_DIR=".venv.backup.$(date +%Y%m%d-%H%M%S)"
    echo "[INFO] Existing .venv uses Python ${CUR_VER:-unknown}; moving it to $BACKUP_DIR"
    mv .venv "$BACKUP_DIR"
  fi
fi

python3.11 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip
python -m pip install -r requirements-service.txt

python - <<'PY'
import sys
print("[OK] venv Python:", sys.version)
import coremltools
print("[OK] coremltools:", coremltools.__version__)
PY

echo "[DONE] Environment is ready. Run ./run_server.sh to start the service."