| """
|
| retrieval.py
|
|
|
| Dual-Pass BM25 Retrieval per Section 3 of the architecture document.
|
|
|
| Stage 1: Load precomputed BM25 index, run two passes:
|
| Pass A: JD skill aliases (expanded via skill_aliases.json taxonomy)
|
| Pass B: Production-context keywords (deployed, scale, serving, latency, ...)
|
| Safety Net: Rare-term pool for niche terms (pinecone, lambdarank)
|
|
|
| stage1_candidates = top_5000 ∪ rare_term_pool
|
|
|
| No network calls. BM25 index must be precomputed via precompute.py.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import logging
|
| import os
|
| import pickle
|
| import time
|
| from typing import Dict, List, Optional, Set, Tuple
|
|
|
| import numpy as np
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
class NumpyBM25:
    """
    Drop-in replacement for BM25Okapi.get_scores() using a precomputed scipy
    sparse matrix of shape (vocab_size, n_docs).

    Each entry [term_idx, doc_idx] stores the precomputed value:
        idf(term) * bm25_tf_adjusted(term, doc)

    Scoring a query is a single sparse vector-matrix multiply:
        q_vec (vocab_size,) @ bm25_matrix (vocab_size x n_docs)
        -> scores (n_docs,)

    Reported timings for 214 query tokens x ~100K docs:
        BM25Okapi: 214 Python loops x 100K dict lookups = ~9.5 s
        NumpyBM25: one scipy sparse matvec              = ~50 ms
    """

    def __init__(self, vocab: Dict[str, int], bm25_matrix) -> None:
        # vocab maps token -> row index into bm25_matrix.
        self.vocab = vocab
        # Assumed to be a scipy.sparse matrix/array (as produced by load_npz).
        self.bm25_matrix = bm25_matrix
        self._n_docs: int = bm25_matrix.shape[1]
        self._n_vocab: int = bm25_matrix.shape[0]

    def get_scores(self, query_tokens: List[str]) -> np.ndarray:
        """
        Score all documents for a list of query tokens.

        Same call signature as BM25Okapi.get_scores(). Like BM25Okapi, each
        occurrence of a token adds that token's row once, so a token repeated
        k times is weighted k times. Tokens not in the vocabulary contribute
        nothing.

        Returns np.ndarray of shape (n_docs,), dtype float32. If no token is
        in the vocabulary, returns all zeros.
        """
        q_vec = np.zeros(self._n_vocab, dtype=np.float32)
        matched = 0
        for t in query_tokens:
            idx = self.vocab.get(t)
            if idx is not None:
                # Accumulate rather than assign, so duplicate tokens are
                # weighted the same way BM25Okapi weights them (one summand
                # per occurrence).
                q_vec[idx] += 1.0
                matched += 1
        if matched == 0:
            return np.zeros(self._n_docs, dtype=np.float32)
        # asarray + flatten normalises a possible (1, n_docs) matrix result
        # to a flat float32 vector.
        return np.asarray(q_vec @ self.bm25_matrix, dtype=np.float32).flatten()
|
|
|
|
|
def load_numpy_bm25_artifacts(precomputed_dir: str) -> Optional[NumpyBM25]:
    """
    Load precomputed NumPy BM25 artifacts (vocab.pkl + bm25_matrix.npz).

    Returns a NumpyBM25 instance, or None if the artifacts don't exist, can't
    be loaded, or are inconsistent with each other (in which case callers
    should fall back to bm25_index.pkl).

    Security: vocab.pkl is read with pickle, which can execute arbitrary
    code; only point this at trusted, locally generated artifacts.
    """
    vocab_path = os.path.join(precomputed_dir, "vocab.pkl")
    matrix_path = os.path.join(precomputed_dir, "bm25_matrix.npz")

    if not (os.path.isfile(vocab_path) and os.path.isfile(matrix_path)):
        return None

    try:
        from scipy.sparse import load_npz

        t0 = time.perf_counter()
        with open(vocab_path, "rb") as f:
            vocab = pickle.load(f)
        bm25_matrix = load_npz(matrix_path)

        # A stale vocab.pkl paired with a rebuilt matrix (or vice versa) would
        # otherwise only surface later as an IndexError inside get_scores().
        # Reject it here so the caller takes the BM25Okapi fallback instead.
        n_rows = bm25_matrix.shape[0]
        max_idx = max(vocab.values(), default=-1)
        if max_idx >= n_rows:
            raise ValueError(
                f"vocab index {max_idx} out of range for matrix with {n_rows} rows"
            )

        logger.info(
            "NumPy BM25 loaded: vocab=%d shape=%s in %.3f s",
            len(vocab), bm25_matrix.shape, time.perf_counter() - t0,
        )
        return NumpyBM25(vocab, bm25_matrix)
    except Exception as exc:
        # Deliberate best-effort: any failure means "use the slow path".
        logger.warning("Failed to load NumPy BM25 artifacts (%s) — falling back to BM25Okapi", exc)
        return None
|
|
|
|
|
|
|
def load_bm25_artifacts(precomputed_dir: str) -> Tuple[object, List[str]]:
    """
    Load the precomputed BM25 index and the candidate IDs it was built over.

    Args:
        precomputed_dir: Path to the precomputed/ directory; must contain
            bm25_index.pkl and candidate_ids.pkl.

    Returns:
        (bm25_index, candidate_ids)
        - bm25_index: the unpickled index object (expected to expose
          get_scores(tokens)).
        - candidate_ids: the unpickled list of candidate IDs; score position i
          of the index corresponds to candidate_ids[i].

    Raises:
        FileNotFoundError: If either precomputed artifact doesn't exist.
        RuntimeError: If an artifact cannot be unpickled (corrupted or
            incompatible).

    Security: both files are loaded with pickle, which can execute arbitrary
    code; only load artifacts produced locally by precompute.py.
    """
    index_path = os.path.join(precomputed_dir, "bm25_index.pkl")
    ids_path = os.path.join(precomputed_dir, "candidate_ids.pkl")

    if not os.path.isfile(index_path):
        raise FileNotFoundError(
            f"BM25 index not found at {index_path}. "
            "Run precompute.py first."
        )
    if not os.path.isfile(ids_path):
        raise FileNotFoundError(
            f"Candidate IDs not found at {ids_path}. "
            "Run precompute.py first."
        )

    try:
        with open(index_path, "rb") as f:
            bm25 = pickle.load(f)
        with open(ids_path, "rb") as f:
            candidate_ids = pickle.load(f)
    except Exception as e:
        raise RuntimeError(f"Failed to load BM25 artifacts: {e}") from e

    logger.info(
        "BM25 index loaded: %d candidates indexed", len(candidate_ids)
    )
    return bm25, candidate_ids
|
|
|
|
|
def tokenize_query(terms: List[str]) -> List[str]:
    """
    Tokenize a list of query terms for BM25.

    Each term is lowercased and split on whitespace, so a multi-word alias
    such as "Machine Learning" yields one token per word. Duplicates are
    removed, keeping the first occurrence, so the output order is
    deterministic.

    >>> tokenize_query(["Machine Learning", "learning", "PyTorch"])
    ['machine', 'learning', 'pytorch']
    """
    # dict.fromkeys dedups in first-seen order. set() iteration order for
    # strings depends on hash randomization (PYTHONHASHSEED) and varies
    # between runs.
    return list(dict.fromkeys(
        token for term in terms for token in term.lower().split()
    ))
|
|
|
|
|
def run_dual_pass_retrieval(
    bm25,
    candidate_ids: List[str],
    jd_config,
    top_n: int = 5000,
) -> Tuple[List[str], Dict[str, float]]:
    """
    Execute dual-pass BM25 retrieval per Section 3.

    Pass A: All JD skill aliases (jd_config.get_all_query_terms())
    Pass B: Production-context keywords only (jd_config.production_keywords)
    Safety Net: Rare terms pool (jd_config.rare_terms, e.g. pinecone, lambdarank)

    Each candidate's combined score is max(Pass A score, Pass B score), and
    the ``top_n`` highest combined scores are kept. Then every other candidate
    with a positive score for any single rare term is added. Its score is its
    best single-rare-term score.

    Args:
        bm25: Object exposing ``get_scores(tokens)`` that returns one score
            per indexed document (e.g. BM25Okapi or NumpyBM25).
        candidate_ids: ``candidate_ids[i]`` is the ID for score position i.
        jd_config: Parsed JD config providing ``get_all_query_terms()``,
            ``production_keywords`` and ``rare_terms``.
        top_n: Number of top combined-score candidates to keep. Values <= 0
            keep none, so only the rare-term pool is returned.

    Returns:
        (stage1_candidate_ids, bm25_scores_dict)
        - stage1_candidate_ids: ordered list (best first) of top_n ∪ rare_pool
        - bm25_scores_dict: {candidate_id: float} for all retrieved candidates

    Note: rare-pool scores come from single-term queries, so they are not on
    the same scale as the combined Pass A/B scores they are sorted alongside.
    """
    t0 = time.perf_counter()

    # Pass A: skill aliases.
    query_a_tokens = tokenize_query(jd_config.get_all_query_terms())
    logger.info("Pass A query tokens (%d): %s...", len(query_a_tokens),
                query_a_tokens[:10])
    scores_a = bm25.get_scores(query_a_tokens)

    # Pass B: production-context keywords.
    query_b_tokens = tokenize_query(jd_config.production_keywords)
    logger.info("Pass B query tokens (%d): %s", len(query_b_tokens), query_b_tokens)
    scores_b = bm25.get_scores(query_b_tokens)

    combined_scores = np.maximum(scores_a, scores_b)

    # Clamp k to [0, len]. With k == 0, argpartition(x, -0)[-0:] would select
    # *every* index, and argpartition on an empty array raises.
    top_n_actual = max(0, min(top_n, len(candidate_ids)))
    if top_n_actual > 0:
        top_indices = np.argpartition(combined_scores, -top_n_actual)[-top_n_actual:]
        # argpartition leaves the selected slice unordered; sort it best-first.
        top_indices = top_indices[np.argsort(combined_scores[top_indices])[::-1]]
    else:
        top_indices = np.empty(0, dtype=np.intp)

    top_candidates = [candidate_ids[i] for i in top_indices]
    top_scores = {candidate_ids[i]: float(combined_scores[i]) for i in top_indices}

    logger.info("Pass A+B union: %d candidates (target %d)", len(top_candidates), top_n)

    # Safety net: any candidate matching a rare term that missed the top-N cut.
    rare_pool_scores: Dict[str, float] = {}
    for rare_term in jd_config.rare_terms:
        rare_scores = bm25.get_scores(tokenize_query([rare_term]))
        for idx in np.flatnonzero(rare_scores > 0):
            cid = candidate_ids[idx]
            if cid not in top_scores:
                rare_pool_scores[cid] = max(
                    rare_pool_scores.get(cid, 0.0),
                    float(rare_scores[idx]),
                )

    logger.info("Rare-term safety net added %d additional candidates", len(rare_pool_scores))

    # Keys of the two dicts are disjoint by construction (cid not in top_scores).
    all_scores = {**top_scores, **rare_pool_scores}
    all_ordered = sorted(all_scores, key=all_scores.get, reverse=True)

    elapsed = time.perf_counter() - t0
    logger.info(
        "Dual-pass retrieval complete: %d candidates in %.2fs",
        len(all_ordered), elapsed
    )

    return all_ordered, all_scores
|
|
|
|
|
def retrieve_candidate_data(
    stage1_ids: List[str],
    candidates_path: str,
) -> Tuple[List[dict], Set[str]]:
    """
    Stream-read the candidates JSONL file and extract only the Stage 1 candidates.

    The file is parsed line by line, and reading stops as soon as every
    requested ID has been found.
    - Blank lines are ignored.
    - Lines that are not valid JSON, or whose JSON value is not an object,
      are counted as malformed, logged, and skipped.
    - If an ID occurs on several lines before the scan stops, the last of
      those lines wins.

    Args:
        stage1_ids: Ordered list of candidate IDs from retrieval.
        candidates_path: Path to candidates.jsonl (UTF-8).

    Returns:
        (candidates_list, missing_ids)
        - candidates_list: list of candidate dicts for stage1 IDs (order preserved)
        - missing_ids: IDs that were in stage1_ids but not found in the file

    Raises:
        OSError: If candidates_path cannot be opened (e.g. FileNotFoundError).
    """
    import json

    stage1_set = set(stage1_ids)
    found: Dict[str, dict] = {}
    malformed_count = 0

    with open(candidates_path, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                candidate = json.loads(line)
            except json.JSONDecodeError as e:
                malformed_count += 1
                logger.warning(
                    "Malformed JSON at line %d (skipped): %s", line_num, e
                )
                continue
            if not isinstance(candidate, dict):
                # Valid JSON but not an object (list, number, string...).
                # Calling .get() on it would raise AttributeError and abort
                # the whole scan.
                malformed_count += 1
                logger.warning(
                    "Non-object JSON at line %d (skipped): %s",
                    line_num, type(candidate).__name__,
                )
                continue

            cid = candidate.get("candidate_id")
            if cid and cid in stage1_set:
                found[cid] = candidate
                if len(found) == len(stage1_set):
                    break

    if malformed_count > 0:
        logger.warning("Skipped %d malformed JSONL lines", malformed_count)

    missing_ids = stage1_set.difference(found)
    if missing_ids:
        logger.warning(
            "%d stage1 candidates not found in JSONL: %s...",
            len(missing_ids),
            list(missing_ids)[:5]
        )
    ordered = [found[cid] for cid in stage1_ids if cid in found]

    logger.info(
        "Retrieved %d candidate records from JSONL (%d missing)",
        len(ordered), len(missing_ids)
    )
    return ordered, missing_ids
|
|
|
|
|
if __name__ == "__main__":
    # Smoke-test entry point: run Stage 1 retrieval against the artifacts in
    # ./precomputed next to this file and print a short summary.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    precomputed_dir = os.path.join(base_dir, "precomputed")

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")

    # Project-local import, kept here so it only runs when this file is
    # executed as a script.
    from jd_parser import parse_jd

    jd_config = parse_jd(os.path.join(base_dir, "data", "skill_aliases.json"))

    print("Loading BM25 artifacts...")
    bm25, candidate_ids = load_bm25_artifacts(precomputed_dir)

    print(f"Running dual-pass retrieval on {len(candidate_ids)} indexed candidates...")
    stage1_ids, bm25_scores = run_dual_pass_retrieval(bm25, candidate_ids, jd_config)

    print(f"Stage 1 retrieved: {len(stage1_ids)} candidates")
    print("Top 10 by BM25 score:")
    for i, cid in enumerate(stage1_ids[:10], 1):
        print(f" {i:2d}. {cid} score={bm25_scores[cid]:.4f}")

    # min()/max() raise on an empty sequence, so guard the no-result case.
    if bm25_scores:
        scores = np.asarray(list(bm25_scores.values()), dtype=np.float64)
        print(f"Score stats: min={scores.min():.4f}, max={scores.max():.4f}, "
              f"median={float(np.median(scores)):.4f}, mean={float(scores.mean()):.4f}")
    else:
        print("Score stats: no candidates retrieved")
|
|
|