from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable

from .manifest import BundleManifest, ProfileEntry

MAX_INPUT_TOKENS_PER_ITEM = 8192
MAX_TOTAL_TOKENS_PER_REQUEST = 300000
DEFAULT_INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query."


def _import_runtime_deps() -> tuple[Any, Any, Any]:
    try:
        import numpy as np
        import coremltools as ct
        from transformers import AutoTokenizer
    except Exception as exc:  # pragma: no cover - runtime dependency check
        raise RuntimeError(
            "Missing runtime dependencies. Install numpy, coremltools, transformers."
        ) from exc
    return np, ct, AutoTokenizer


@dataclass(slots=True)
class LoadedProfile:
    entry: ProfileEntry
    model_path: Path
    model: Any | None = None


class Qwen3AneRerankRuntime:
    def __init__(self, bundle_dir: str | Path, compute_units: str = "cpu_and_ne") -> None:
        np, ct, AutoTokenizer = _import_runtime_deps()
        self.np = np
        self.ct = ct
        self.bundle_dir = Path(bundle_dir).resolve()
        self.manifest = BundleManifest.load(self.bundle_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(self.bundle_dir / self.manifest.tokenizer_dir),
            local_files_only=True,
            trust_remote_code=False,
            use_fast=True,
        )
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            elif self.tokenizer.unk_token is not None:
                self.tokenizer.pad_token = self.tokenizer.unk_token
            else:
                raise RuntimeError("Tokenizer has no pad/eos/unk token")
        self.prefix_tokens = self.tokenizer.encode(
            self.manifest.prefix_text,
            add_special_tokens=False,
        )
        self.suffix_tokens = self.tokenizer.encode(
            self.manifest.suffix_text,
            add_special_tokens=False,
        )
        self.static_token_cost = len(self.prefix_tokens) + len(self.suffix_tokens)
        self.compute_units = self._resolve_compute_units(compute_units)
        self.profiles: list[LoadedProfile] = []
        for entry in self.manifest.profiles:
            package_path = self.bundle_dir / entry.package_path
            model_path = package_path
            if entry.compiled_path is not None:
                compiled_path = self.bundle_dir / entry.compiled_path
                if (compiled_path / "Manifest.json").exists():
                    model_path = compiled_path
            self.profiles.append(LoadedProfile(entry=entry, model_path=model_path))
        if not self.profiles:
            raise RuntimeError("No profiles found in manifest")
        self.max_profile_batch = max(p.entry.batch_size for p in self.profiles)
        self.max_profile_seq = max(p.entry.seq_len for p in self.profiles)
        if self.static_token_cost >= self.max_profile_seq:
            raise RuntimeError(
                "Profile seq_len is too small for reranker prompt template. "
                f"Need > {self.static_token_cost}, got {self.max_profile_seq}."
            )

    def _resolve_compute_units(self, raw: str) -> Any:
        mode = raw.strip().lower()
        cu = self.ct.ComputeUnit
        if mode == "cpu_and_ne":
            # Older coremltools releases do not expose CPU_AND_NE; fall back to ALL.
            return cu.CPU_AND_NE if hasattr(cu, "CPU_AND_NE") else cu.ALL
        if mode == "all":
            return cu.ALL
        if mode == "cpu_only":
            return cu.CPU_ONLY
        if mode == "cpu_and_gpu":
            return cu.CPU_AND_GPU
        raise ValueError(f"Unsupported compute unit mode: {raw}")

    def _get_model(self, profile: LoadedProfile) -> Any:
        # Models load lazily so profiles that never serve a request never pay the load cost.
        if profile.model is None:
            profile.model = self.ct.models.MLModel(
                str(profile.model_path),
                compute_units=self.compute_units,
            )
        return profile.model

    def _select_profile(self, batch_size: int, seq_len: int) -> LoadedProfile | None:
        # Pick the cheapest profile (fewest padded elements) that can hold the chunk.
        candidates = [
            p
            for p in self.profiles
            if p.entry.batch_size >= batch_size and p.entry.seq_len >= seq_len
        ]
        if not candidates:
            return None
        candidates.sort(
            key=lambda p: (
                p.entry.batch_size * p.entry.seq_len,
                p.entry.seq_len,
                p.entry.batch_size,
            )
        )
        return candidates[0]

    def _plan_chunks(self, lengths: list[int]) -> list[tuple[int, int, LoadedProfile]]:
        # Greedily take the largest batch starting at i that some profile can serve.
        chunks: list[tuple[int, int, LoadedProfile]] = []
        i = 0
        n = len(lengths)
        while i < n:
            best: tuple[int, LoadedProfile] | None = None
            max_batch = min(self.max_profile_batch, n - i)
            for b in range(max_batch, 0, -1):
                max_len = max(lengths[i : i + b])
                profile = self._select_profile(batch_size=b, seq_len=max_len)
                if profile is not None:
                    best = (b, profile)
                    break
            if best is None:
                raise ValueError(
                    f"No profile can serve items starting at index {i}. "
                    f"Required seq_len={lengths[i]}"
                )
            b, profile = best
            chunks.append((i, i + b, profile))
            i += b
        return chunks

    def _predict_scores(self, profile: LoadedProfile, input_ids: Any, attention_mask: Any) -> Any:
        model = self._get_model(profile)
        out = model.predict(
            {
                profile.entry.input_names[0]: input_ids,
                profile.entry.input_names[1]: attention_mask,
            }
        )
        raw = out.get(profile.entry.output_name, next(iter(out.values())))
        # Core ML may hand back a scalar, (batch,), or (batch, 1); squeeze to 1-D.
        scores = self.np.asarray(raw, dtype=self.np.float32)
        if scores.ndim == 0:
            scores = scores.reshape(1)
        elif scores.ndim == 2 and scores.shape[1] == 1:
            scores = scores[:, 0]
        elif scores.ndim > 1:
            scores = scores.reshape(scores.shape[0], -1)[:, 0]
        return scores

    def _validate_token_limits(self, token_lengths: Iterable[int]) -> None:
        lengths = list(token_lengths)
        if any(length <= 0 for length in lengths):
            raise ValueError("Input pair must not be empty")
        if any(length > MAX_INPUT_TOKENS_PER_ITEM for length in lengths):
            raise ValueError(
                f"Each pair must be <= {MAX_INPUT_TOKENS_PER_ITEM} tokens before truncation"
            )
        if sum(lengths) > MAX_TOTAL_TOKENS_PER_REQUEST:
            raise ValueError(
                f"Total tokens across request must be <= {MAX_TOTAL_TOKENS_PER_REQUEST}"
            )

    def _format_pair_text(self, query: str, document: str, instruction: str) -> str:
        template = self.manifest.pair_template
        for placeholder in ("{instruction}", "{query}", "{document}"):
            if placeholder not in template:
                raise RuntimeError(f"Invalid pair template: missing {placeholder}")
        return template.format(
            instruction=instruction,
            query=query,
            document=document,
        )

    def _pair_token_len(self, pair_text: str) -> int:
        body_len = len(
            self.tokenizer.encode(
                pair_text,
                add_special_tokens=False,
                truncation=False,
            )
        )
        return self.static_token_cost + body_len

    def _build_pair_ids(self, pair_text: str, seq_len: int) -> list[int]:
        body_budget = seq_len - self.static_token_cost
        if body_budget <= 0:
            raise RuntimeError(f"seq_len={seq_len} is too small for reranker template")
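        # Only the pair body is truncated to the remaining budget; the fixed
        # prompt prefix/suffix tokens are re-attached in full below, so the
        # template structure survives truncation.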
        body_ids = self.tokenizer.encode(
            pair_text,
            add_special_tokens=False,
            truncation=True,
            max_length=body_budget,
        )
        return self.prefix_tokens + body_ids + self.suffix_tokens

    def rerank(
        self,
        query: str,
        documents: list[str],
        *,
        top_n: int | None = None,
        instruction: str | None = None,
    ) -> tuple[list[dict[str, Any]], int]:
        if not query:
            raise ValueError("query must not be empty")
        if not documents:
            raise ValueError("documents must not be empty")
        if any(doc == "" for doc in documents):
            raise ValueError("documents must not contain empty strings")
        instruction_text = instruction or DEFAULT_INSTRUCTION
        pair_texts = [self._format_pair_text(query, doc, instruction_text) for doc in documents]
        raw_lengths = [self._pair_token_len(text) for text in pair_texts]
        self._validate_token_limits(raw_lengths)
        too_long = [idx for idx, length in enumerate(raw_lengths) if length > self.max_profile_seq]
        if too_long:
            first = too_long[0]
            raise ValueError(
                f"pair at index {first} has {raw_lengths[first]} tokens, "
                f"but compiled profiles only support up to {self.max_profile_seq}. "
                "Rebuild bundle with larger seq profiles."
            )
        effective_lengths = [min(length, self.max_profile_seq) for length in raw_lengths]
        chunks = self._plan_chunks(effective_lengths)
        pad_id = int(self.tokenizer.pad_token_id)
        all_scores: list[Any] = []
        prompt_tokens = 0
        for start, end, profile in chunks:
            chunk_texts = pair_texts[start:end]
            profile_batch = profile.entry.batch_size
            seq_len = profile.entry.seq_len
            # Left-pad every row: real tokens sit at the tail of the window,
            # matching the tokenizer's padding_side set in __init__.
            input_ids = self.np.full(
                (profile_batch, seq_len), fill_value=pad_id, dtype=self.np.int32
            )
            attention_mask = self.np.zeros((profile_batch, seq_len), dtype=self.np.int32)
            for row, pair_text in enumerate(chunk_texts):
                ids = self._build_pair_ids(pair_text, seq_len=seq_len)
                tlen = len(ids)
                offset = seq_len - tlen
                input_ids[row, offset:] = self.np.asarray(ids, dtype=self.np.int32)
                attention_mask[row, offset:] = 1
                prompt_tokens += tlen
            scores = self._predict_scores(profile, input_ids, attention_mask)
            # Discard scores for padded rows when the chunk underfills the profile batch.
            all_scores.append(scores[: len(chunk_texts)])
        merged_scores = self.np.concatenate(all_scores, axis=0).astype(self.np.float32)
        ranked = [
            {"index": idx, "relevance_score": float(score)}
            for idx, score in enumerate(merged_scores.tolist())
        ]
        ranked.sort(key=lambda item: item["relevance_score"], reverse=True)
        n_results = len(ranked) if top_n is None else max(1, min(int(top_n), len(ranked)))
        return ranked[:n_results], prompt_tokens
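
# A minimal usage sketch, with assumed paths and bundle contents (not part of
# this module): load a bundle, then rerank two documents for one query.
#
#     runtime = Qwen3AneRerankRuntime("bundles/qwen3-rerank", compute_units="cpu_and_ne")
#     ranked, prompt_tokens = runtime.rerank(
#         "what is the capital of France?",
#         ["Paris is the capital of France.", "Berlin is the capital of Germany."],
#         top_n=1,
#     )
#     # ranked -> [{"index": 0, "relevance_score": <float>}], sorted descending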