"""Runtime for the Qwen3-Reranker-4B CoreML ANE-optimized bundle."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable
from .manifest import BundleManifest, ProfileEntry
MAX_INPUT_TOKENS_PER_ITEM = 8192
MAX_TOTAL_TOKENS_PER_REQUEST = 300_000
DEFAULT_INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query."
def _import_runtime_deps() -> tuple[Any, Any, Any]:
    """Import numpy, coremltools, and transformers, failing with a clear error."""
try:
import numpy as np
import coremltools as ct
from transformers import AutoTokenizer
except Exception as exc: # pragma: no cover - runtime dependency check
raise RuntimeError(
"Missing runtime dependencies. Install numpy, coremltools, transformers."
) from exc
return np, ct, AutoTokenizer
@dataclass(slots=True)
class LoadedProfile:
entry: ProfileEntry
model_path: Path
model: Any | None = None
class Qwen3AneRerankRuntime:
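    """Reranker runtime over a bundle of fixed-shape Core ML profiles.

    The tokenizer and profile metadata load eagerly; each Core ML model is
    loaded lazily on first use. Query/document pairs are formatted with the
    manifest's template, left-padded to a profile's (batch_size, seq_len)
    shape, and scored on the configured compute units.
    """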
def __init__(self, bundle_dir: str | Path, compute_units: str = "cpu_and_ne") -> None:
np, ct, AutoTokenizer = _import_runtime_deps()
self.np = np
self.ct = ct
self.bundle_dir = Path(bundle_dir).resolve()
self.manifest = BundleManifest.load(self.bundle_dir)
self.tokenizer = AutoTokenizer.from_pretrained(
str(self.bundle_dir / self.manifest.tokenizer_dir),
local_files_only=True,
trust_remote_code=False,
use_fast=True,
)
self.tokenizer.padding_side = "left"
if self.tokenizer.pad_token_id is None:
if self.tokenizer.eos_token is not None:
self.tokenizer.pad_token = self.tokenizer.eos_token
elif self.tokenizer.unk_token is not None:
self.tokenizer.pad_token = self.tokenizer.unk_token
else:
raise RuntimeError("Tokenizer has no pad/eos/unk token")
self.prefix_tokens = self.tokenizer.encode(
self.manifest.prefix_text,
add_special_tokens=False,
)
self.suffix_tokens = self.tokenizer.encode(
self.manifest.suffix_text,
add_special_tokens=False,
)
self.static_token_cost = len(self.prefix_tokens) + len(self.suffix_tokens)
self.compute_units = self._resolve_compute_units(compute_units)
self.profiles: list[LoadedProfile] = []
for entry in self.manifest.profiles:
package_path = self.bundle_dir / entry.package_path
model_path = package_path
if entry.compiled_path is not None:
compiled_path = self.bundle_dir / entry.compiled_path
if (compiled_path / "Manifest.json").exists():
model_path = compiled_path
self.profiles.append(LoadedProfile(entry=entry, model_path=model_path))
if not self.profiles:
raise RuntimeError("No profiles found in manifest")
self.max_profile_batch = max(p.entry.batch_size for p in self.profiles)
self.max_profile_seq = max(p.entry.seq_len for p in self.profiles)
if self.static_token_cost >= self.max_profile_seq:
raise RuntimeError(
"Profile seq_len is too small for reranker prompt template. "
f"Need > {self.static_token_cost}, got {self.max_profile_seq}."
)
    def _resolve_compute_units(self, raw: str) -> Any:
        mode = raw.strip().lower()
        cu = self.ct.ComputeUnit
        if mode == "cpu_and_ne":
            # Older coremltools releases lack CPU_AND_NE; fall back to ALL.
            return cu.CPU_AND_NE if hasattr(cu, "CPU_AND_NE") else cu.ALL
        if mode == "all":
            return cu.ALL
        if mode == "cpu_only":
            return cu.CPU_ONLY
        if mode == "cpu_and_gpu":
            return cu.CPU_AND_GPU
        raise ValueError(f"Unsupported compute unit mode: {raw}")
def _get_model(self, profile: LoadedProfile) -> Any:
if profile.model is None:
profile.model = self.ct.models.MLModel(
str(profile.model_path),
compute_units=self.compute_units,
)
return profile.model
def _select_profile(self, batch_size: int, seq_len: int) -> LoadedProfile | None:
candidates = [
p
for p in self.profiles
if p.entry.batch_size >= batch_size and p.entry.seq_len >= seq_len
]
if not candidates:
return None
        candidates.sort(
            key=lambda p: (p.entry.batch_size * p.entry.seq_len, p.entry.seq_len, p.entry.batch_size)
        )
return candidates[0]
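    # _select_profile example (hypothetical profiles b1/s512, b4/s512,
    # b1/s2048): a request needing batch_size=1 and seq_len=300 resolves to
    # b1/s512, the smallest batch*seq area that fits, minimizing padded compute.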
def _plan_chunks(self, lengths: list[int]) -> list[tuple[int, int, LoadedProfile]]:
chunks: list[tuple[int, int, LoadedProfile]] = []
i = 0
n = len(lengths)
while i < n:
best: tuple[int, LoadedProfile] | None = None
max_batch = min(self.max_profile_batch, n - i)
for b in range(max_batch, 0, -1):
max_len = max(lengths[i : i + b])
profile = self._select_profile(batch_size=b, seq_len=max_len)
if profile is not None:
best = (b, profile)
break
if best is None:
raise ValueError(
f"No profile can serve items starting at index {i}. Required seq_len={lengths[i]}"
)
b, profile = best
chunks.append((i, i + b, profile))
i += b
return chunks
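    # _plan_chunks example (hypothetical profiles b4/s512 and b1/s2048):
    # lengths [300, 400, 1500] plan to [(0, 2, b4/s512), (2, 3, b1/s2048)];
    # each greedy step takes the widest batch whose max item length still fits.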
def _predict_scores(self, profile: LoadedProfile, input_ids: Any, attention_mask: Any) -> Any:
model = self._get_model(profile)
out = model.predict(
{
profile.entry.input_names[0]: input_ids,
profile.entry.input_names[1]: attention_mask,
}
)
        # Prefer the declared output name; otherwise take the first output.
        raw = out.get(profile.entry.output_name, next(iter(out.values())))
        scores = self.np.asarray(raw, dtype=self.np.float32)
        # Normalize the output (scalar, (B, 1), or (B, ...)) to a 1-D (B,) vector.
        if scores.ndim == 0:
            scores = scores.reshape(1)
        elif scores.ndim == 2 and scores.shape[1] == 1:
            scores = scores[:, 0]
        elif scores.ndim > 1:
            scores = scores.reshape(scores.shape[0], -1)[:, 0]
return scores
def _validate_token_limits(self, token_lengths: Iterable[int]) -> None:
lengths = list(token_lengths)
if any(length <= 0 for length in lengths):
raise ValueError("Input pair must not be empty")
if any(length > MAX_INPUT_TOKENS_PER_ITEM for length in lengths):
raise ValueError(
f"Each pair must be <= {MAX_INPUT_TOKENS_PER_ITEM} tokens before truncation"
)
if sum(lengths) > MAX_TOTAL_TOKENS_PER_REQUEST:
raise ValueError(
f"Total tokens across request must be <= {MAX_TOTAL_TOKENS_PER_REQUEST}"
)
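    # _validate_token_limits example: 40 pairs of 8,000 tokens each pass the
    # per-item cap (8,000 <= 8,192) but fail the request cap
    # (40 * 8,000 = 320,000 > 300,000).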
    def _format_pair_text(self, query: str, document: str, instruction: str) -> str:
        for placeholder in ("{instruction}", "{query}", "{document}"):
            if placeholder not in self.manifest.pair_template:
                raise RuntimeError(f"Invalid pair template: missing {placeholder}")
return self.manifest.pair_template.format(
instruction=instruction,
query=query,
document=document,
)
def _pair_token_len(self, pair_text: str) -> int:
body_len = len(
self.tokenizer.encode(
pair_text,
add_special_tokens=False,
truncation=False,
)
)
return self.static_token_cost + body_len
def _build_pair_ids(self, pair_text: str, seq_len: int) -> list[int]:
body_budget = seq_len - self.static_token_cost
if body_budget <= 0:
raise RuntimeError(f"seq_len={seq_len} is too small for reranker template")
body_ids = self.tokenizer.encode(
pair_text,
add_special_tokens=False,
truncation=True,
max_length=body_budget,
)
return self.prefix_tokens + body_ids + self.suffix_tokens
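    # _build_pair_ids example (hypothetical sizes): with seq_len=512 and
    # static_token_cost=24, the pair body is truncated to 488 tokens, so
    # prefix + body + suffix is at most 512 ids before left-padding.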
def rerank(
self,
query: str,
documents: list[str],
*,
top_n: int | None = None,
instruction: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
if not query:
raise ValueError("query must not be empty")
if not documents:
raise ValueError("documents must not be empty")
if any(doc == "" for doc in documents):
raise ValueError("documents must not contain empty strings")
instruction_text = instruction or DEFAULT_INSTRUCTION
pair_texts = [self._format_pair_text(query, doc, instruction_text) for doc in documents]
raw_lengths = [self._pair_token_len(text) for text in pair_texts]
self._validate_token_limits(raw_lengths)
too_long = [idx for idx, length in enumerate(raw_lengths) if length > self.max_profile_seq]
if too_long:
first = too_long[0]
raise ValueError(
f"pair at index {first} has {raw_lengths[first]} tokens, "
f"but compiled profiles only support up to {self.max_profile_seq}. "
"Rebuild bundle with larger seq profiles."
)
        # No length exceeds max_profile_seq after the check above; clamp defensively.
        effective_lengths = [min(length, self.max_profile_seq) for length in raw_lengths]
chunks = self._plan_chunks(effective_lengths)
pad_id = int(self.tokenizer.pad_token_id)
all_scores: list[Any] = []
prompt_tokens = 0
for start, end, profile in chunks:
chunk_texts = pair_texts[start:end]
profile_batch = profile.entry.batch_size
seq_len = profile.entry.seq_len
            # Left-pad: rows start as all-pad ids with a zero mask, then real
            # tokens fill the right edge (matching tokenizer.padding_side="left").
            input_ids = self.np.full((profile_batch, seq_len), fill_value=pad_id, dtype=self.np.int32)
            attention_mask = self.np.zeros((profile_batch, seq_len), dtype=self.np.int32)
for row, pair_text in enumerate(chunk_texts):
ids = self._build_pair_ids(pair_text, seq_len=seq_len)
tlen = len(ids)
offset = seq_len - tlen
input_ids[row, offset:] = self.np.asarray(ids, dtype=self.np.int32)
attention_mask[row, offset:] = 1
prompt_tokens += tlen
scores = self._predict_scores(profile, input_ids, attention_mask)
all_scores.append(scores[: len(chunk_texts)])
merged_scores = self.np.concatenate(all_scores, axis=0).astype(self.np.float32)
ranked = [
{"index": idx, "relevance_score": float(score)}
for idx, score in enumerate(merged_scores.tolist())
]
ranked.sort(key=lambda item: item["relevance_score"], reverse=True)
        # None returns everything; otherwise clamp top_n into [1, len(ranked)].
        n_results = len(ranked) if top_n is None else max(1, min(int(top_n), len(ranked)))
return ranked[:n_results], prompt_tokens
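if __name__ == "__main__":
    # Minimal usage sketch, not part of the service: the bundle path and the
    # sample documents below are hypothetical stand-ins.
    runtime = Qwen3AneRerankRuntime("./qwen3-reranker-bundle")
    results, prompt_tokens = runtime.rerank(
        query="what is the capital of France?",
        documents=[
            "Paris is the capital and largest city of France.",
            "The Great Wall of China stretches across northern China.",
        ],
        top_n=1,
    )
    # results: [{"index": ..., "relevance_score": ...}] sorted by score.
    print(results, prompt_tokens)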