| | from __future__ import annotations |
| |
|
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Any, Iterable |
| |
|
| | from .manifest import BundleManifest, ProfileEntry |
| |
|
| |
|
# Hard request limits enforced before any tokenized pair reaches a model profile.
MAX_INPUT_TOKENS_PER_ITEM = 8192  # max tokens for a single query/document pair, pre-truncation
MAX_TOTAL_TOKENS_PER_REQUEST = 300000  # max tokens summed across all pairs in one rerank() call
# Instruction inserted into the pair template when the caller passes none (or an empty string).
DEFAULT_INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query."
| |
|
| |
|
| | def _import_runtime_deps() -> tuple[Any, Any, Any]: |
| | try: |
| | import numpy as np |
| | import coremltools as ct |
| | from transformers import AutoTokenizer |
| | except Exception as exc: |
| | raise RuntimeError( |
| | "Missing runtime dependencies. Install numpy, coremltools, transformers." |
| | ) from exc |
| | return np, ct, AutoTokenizer |
| |
|
| |
|
@dataclass(slots=True)
class LoadedProfile:
    """A single fixed-shape model profile plus its lazily loaded Core ML model."""

    # Manifest metadata for this profile (batch_size, seq_len, IO names, paths).
    entry: ProfileEntry
    # Resolved on-disk path: the compiled model directory when one exists,
    # otherwise the model package path.
    model_path: Path
    # Loaded Core ML model handle; stays None until first prediction needs it.
    model: Any | None = None
| |
|
| |
|
class Qwen3AneRerankRuntime:
    """Rerank query/document pairs with fixed-shape Core ML models from a bundle.

    The bundle directory supplies a manifest, a local tokenizer, and one or
    more compiled profiles with fixed ``(batch_size, seq_len)`` shapes.
    Requests are split greedily into chunks, each served by the smallest
    profile that fits, and per-pair relevance scores are returned sorted in
    descending order.
    """

    def __init__(self, bundle_dir: str | Path, compute_units: str = "cpu_and_ne") -> None:
        """Load manifest, tokenizer, and profile paths; models load lazily.

        Raises RuntimeError when dependencies, tokenizer special tokens, or
        profiles are missing, or when no profile can fit the prompt template.
        """
        # Heavy deps are imported here so module import stays cheap.
        np, ct, AutoTokenizer = _import_runtime_deps()
        self.np = np
        self.ct = ct

        self.bundle_dir = Path(bundle_dir).resolve()
        self.manifest = BundleManifest.load(self.bundle_dir)

        # Tokenizer ships inside the bundle: offline only, no remote code.
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(self.bundle_dir / self.manifest.tokenizer_dir),
            local_files_only=True,
            trust_remote_code=False,
            use_fast=True,
        )
        # Left padding keeps real tokens flush against the end of the window
        # (batches below are filled right-aligned to match).
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token_id is None:
            # Fall back to eos, then unk, so padded batches can always be built.
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            elif self.tokenizer.unk_token is not None:
                self.tokenizer.pad_token = self.tokenizer.unk_token
            else:
                raise RuntimeError("Tokenizer has no pad/eos/unk token")

        # Fixed prompt frame placed around every pair; its token cost is
        # constant per request, so compute it once.
        self.prefix_tokens = self.tokenizer.encode(
            self.manifest.prefix_text,
            add_special_tokens=False,
        )
        self.suffix_tokens = self.tokenizer.encode(
            self.manifest.suffix_text,
            add_special_tokens=False,
        )
        self.static_token_cost = len(self.prefix_tokens) + len(self.suffix_tokens)

        self.compute_units = self._resolve_compute_units(compute_units)

        # Resolve each profile's on-disk path, preferring a compiled model
        # (detected via its Manifest.json) over the raw package when present.
        self.profiles: list[LoadedProfile] = []
        for entry in self.manifest.profiles:
            package_path = self.bundle_dir / entry.package_path
            model_path = package_path
            if entry.compiled_path is not None:
                compiled_path = self.bundle_dir / entry.compiled_path
                if (compiled_path / "Manifest.json").exists():
                    model_path = compiled_path
            self.profiles.append(LoadedProfile(entry=entry, model_path=model_path))

        if not self.profiles:
            raise RuntimeError("No profiles found in manifest")

        # Upper bounds across all profiles, used for request planning/limits.
        self.max_profile_batch = max(p.entry.batch_size for p in self.profiles)
        self.max_profile_seq = max(p.entry.seq_len for p in self.profiles)

        # Even an empty pair body needs at least one token slot beyond the frame.
        if self.static_token_cost >= self.max_profile_seq:
            raise RuntimeError(
                "Profile seq_len is too small for reranker prompt template. "
                f"Need > {self.static_token_cost}, got {self.max_profile_seq}."
            )

    def _resolve_compute_units(self, raw: str) -> Any:
        """Map a mode string to a coremltools ComputeUnit value.

        Matching is case/whitespace-insensitive. "cpu_and_ne" degrades to
        ALL on coremltools builds that lack CPU_AND_NE; unknown modes raise
        ValueError.
        """
        mode = raw.strip().lower()
        cu = self.ct.ComputeUnit
        if mode == "cpu_and_ne" and hasattr(cu, "CPU_AND_NE"):
            return cu.CPU_AND_NE
        if mode == "all":
            return cu.ALL
        if mode == "cpu_only":
            return cu.CPU_ONLY
        if mode == "cpu_and_gpu":
            return cu.CPU_AND_GPU
        # Requested NE but this coremltools has no CPU_AND_NE: closest superset.
        if mode == "cpu_and_ne" and not hasattr(cu, "CPU_AND_NE"):
            return cu.ALL
        raise ValueError(f"Unsupported compute unit mode: {raw}")

    def _get_model(self, profile: LoadedProfile) -> Any:
        """Return the profile's MLModel, loading and caching it on first use."""
        if profile.model is None:
            profile.model = self.ct.models.MLModel(
                str(profile.model_path),
                compute_units=self.compute_units,
            )
        return profile.model

    def _select_profile(self, batch_size: int, seq_len: int) -> LoadedProfile | None:
        """Pick the cheapest profile covering the given batch and length.

        Candidates must satisfy both dimensions; the winner minimizes total
        element count (batch * seq), then seq_len, then batch_size. Returns
        None when no profile is large enough.
        """
        candidates = [
            p
            for p in self.profiles
            if p.entry.batch_size >= batch_size and p.entry.seq_len >= seq_len
        ]
        if not candidates:
            return None
        candidates.sort(key=lambda p: (p.entry.batch_size * p.entry.seq_len, p.entry.seq_len, p.entry.batch_size))
        return candidates[0]

    def _plan_chunks(self, lengths: list[int]) -> list[tuple[int, int, LoadedProfile]]:
        """Greedily split items into (start, end, profile) chunks, left to right.

        At each position, tries the largest feasible batch first and shrinks
        until some profile fits the chunk's max token length. Raises
        ValueError if even a single item cannot be served (defensive; rerank()
        rejects over-long items before planning).
        """
        chunks: list[tuple[int, int, LoadedProfile]] = []
        i = 0
        n = len(lengths)
        while i < n:
            best: tuple[int, LoadedProfile] | None = None
            max_batch = min(self.max_profile_batch, n - i)
            for b in range(max_batch, 0, -1):
                max_len = max(lengths[i : i + b])
                profile = self._select_profile(batch_size=b, seq_len=max_len)
                if profile is not None:
                    best = (b, profile)
                    break
            if best is None:
                raise ValueError(
                    f"No profile can serve items starting at index {i}. Required seq_len={lengths[i]}"
                )
            b, profile = best
            chunks.append((i, i + b, profile))
            i += b
        return chunks

    def _predict_scores(self, profile: LoadedProfile, input_ids: Any, attention_mask: Any) -> Any:
        """Run one padded batch through the profile's model; return 1-D float32 scores.

        Reads the manifest-declared output name, falling back to the first
        output, then normalizes whatever shape comes back (scalar, (B, 1),
        or higher-rank) to a flat per-row score vector.
        """
        model = self._get_model(profile)
        out = model.predict(
            {
                profile.entry.input_names[0]: input_ids,
                profile.entry.input_names[1]: attention_mask,
            }
        )
        raw = out.get(profile.entry.output_name, next(iter(out.values())))
        scores = self.np.asarray(raw, dtype=self.np.float32)
        if scores.ndim == 0:
            scores = scores.reshape(1)
        elif scores.ndim == 2 and scores.shape[1] == 1:
            scores = scores[:, 0]
        elif scores.ndim > 1:
            # Higher-rank output: keep the first value per row.
            scores = scores.reshape(scores.shape[0], -1)[:, 0]
        return scores

    def _validate_token_limits(self, token_lengths: Iterable[int]) -> None:
        """Enforce module-level per-item and per-request token caps.

        Raises ValueError on empty items, any item over
        MAX_INPUT_TOKENS_PER_ITEM, or a total over
        MAX_TOTAL_TOKENS_PER_REQUEST.
        """
        lengths = list(token_lengths)
        if any(length <= 0 for length in lengths):
            raise ValueError("Input pair must not be empty")
        if any(length > MAX_INPUT_TOKENS_PER_ITEM for length in lengths):
            raise ValueError(
                f"Each pair must be <= {MAX_INPUT_TOKENS_PER_ITEM} tokens before truncation"
            )
        if sum(lengths) > MAX_TOTAL_TOKENS_PER_REQUEST:
            raise ValueError(
                f"Total tokens across request must be <= {MAX_TOTAL_TOKENS_PER_REQUEST}"
            )

    def _format_pair_text(self, query: str, document: str, instruction: str) -> str:
        """Fill the manifest's pair template with instruction, query, document.

        Raises RuntimeError if the template lacks any of the three required
        placeholders (a bundle-integrity check, not a user error).
        """
        if "{instruction}" not in self.manifest.pair_template:
            raise RuntimeError("Invalid pair template: missing {instruction}")
        if "{query}" not in self.manifest.pair_template:
            raise RuntimeError("Invalid pair template: missing {query}")
        if "{document}" not in self.manifest.pair_template:
            raise RuntimeError("Invalid pair template: missing {document}")
        return self.manifest.pair_template.format(
            instruction=instruction,
            query=query,
            document=document,
        )

    def _pair_token_len(self, pair_text: str) -> int:
        """Return the pair's total token count (frame + body), untruncated."""
        body_len = len(
            self.tokenizer.encode(
                pair_text,
                add_special_tokens=False,
                truncation=False,
            )
        )
        return self.static_token_cost + body_len

    def _build_pair_ids(self, pair_text: str, seq_len: int) -> list[int]:
        """Tokenize one pair to fit seq_len: truncate body, wrap in prefix/suffix.

        The result is at most seq_len tokens long. Raises RuntimeError when
        seq_len leaves no room for any body tokens.
        """
        body_budget = seq_len - self.static_token_cost
        if body_budget <= 0:
            raise RuntimeError(f"seq_len={seq_len} is too small for reranker template")
        body_ids = self.tokenizer.encode(
            pair_text,
            add_special_tokens=False,
            truncation=True,
            max_length=body_budget,
        )
        return self.prefix_tokens + body_ids + self.suffix_tokens

    def rerank(
        self,
        query: str,
        documents: list[str],
        *,
        top_n: int | None = None,
        instruction: str | None = None,
    ) -> tuple[list[dict[str, Any]], int]:
        """Score documents against the query and return them ranked.

        Args:
            query: Non-empty search query.
            documents: Non-empty list of non-empty document strings.
            top_n: When given, truncate results to this many (clamped to
                [1, len(documents)]).
            instruction: Task instruction for the template; any falsy value
                (None or "") falls back to DEFAULT_INSTRUCTION.

        Returns:
            A ``(results, prompt_tokens)`` tuple where results is a list of
            ``{"index": original_index, "relevance_score": float}`` dicts
            sorted by score descending, and prompt_tokens counts the real
            (unpadded) tokens fed to the models.

        Raises:
            ValueError: on empty inputs, violated token limits, a pair too
                long for every compiled profile, or an unplannable request.
        """
        if not query:
            raise ValueError("query must not be empty")
        if not documents:
            raise ValueError("documents must not be empty")
        if any(doc == "" for doc in documents):
            raise ValueError("documents must not contain empty strings")

        instruction_text = instruction or DEFAULT_INSTRUCTION
        pair_texts = [self._format_pair_text(query, doc, instruction_text) for doc in documents]
        raw_lengths = [self._pair_token_len(text) for text in pair_texts]
        self._validate_token_limits(raw_lengths)

        # Reject (rather than silently truncate) pairs no profile can hold.
        too_long = [idx for idx, length in enumerate(raw_lengths) if length > self.max_profile_seq]
        if too_long:
            first = too_long[0]
            raise ValueError(
                f"pair at index {first} has {raw_lengths[first]} tokens, "
                f"but compiled profiles only support up to {self.max_profile_seq}. "
                "Rebuild bundle with larger seq profiles."
            )

        # Defensive clamp; a no-op given the too_long check above.
        effective_lengths = [min(length, self.max_profile_seq) for length in raw_lengths]
        chunks = self._plan_chunks(effective_lengths)
        pad_id = int(self.tokenizer.pad_token_id)

        all_scores: list[Any] = []
        prompt_tokens = 0

        for start, end, profile in chunks:
            chunk_texts = pair_texts[start:end]
            profile_batch = profile.entry.batch_size
            seq_len = profile.entry.seq_len

            # Full profile shape: pad-filled ids, zero mask; unused rows stay padding.
            input_ids = self.np.full((profile_batch, seq_len), fill_value=pad_id, dtype=self.np.int32)
            attention_mask = self.np.zeros((profile_batch, seq_len), dtype=self.np.int32)

            for row, pair_text in enumerate(chunk_texts):
                ids = self._build_pair_ids(pair_text, seq_len=seq_len)
                tlen = len(ids)
                # Right-align tokens (left padding), matching tokenizer.padding_side.
                offset = seq_len - tlen
                input_ids[row, offset:] = self.np.asarray(ids, dtype=self.np.int32)
                attention_mask[row, offset:] = 1
                prompt_tokens += tlen

            scores = self._predict_scores(profile, input_ids, attention_mask)
            # Drop scores for padding rows beyond the real chunk size.
            all_scores.append(scores[: len(chunk_texts)])

        merged_scores = self.np.concatenate(all_scores, axis=0).astype(self.np.float32)

        # Attach original indices before sorting so callers can map back.
        ranked = [
            {"index": idx, "relevance_score": float(score)}
            for idx, score in enumerate(merged_scores.tolist())
        ]
        ranked.sort(key=lambda item: item["relevance_score"], reverse=True)

        n_results = len(ranked) if top_n is None else max(1, min(int(top_n), len(ranked)))
        return ranked[:n_results], prompt_tokens
| |
|