LordofMonarchs's picture
Upload folder using huggingface_hub
c754148 verified
Raw
History Blame Contribute Delete
11 kB
"""
retrieval.py
Dual-Pass BM25 Retrieval per Section 3 of the architecture document.
Stage 1: Load precomputed BM25 index, run two passes:
Pass A: JD skill aliases (expanded via skill_aliases.json taxonomy)
Pass B: Production-context keywords (deployed, scale, serving, latency, ...)
Safety Net: Rare-term pool for niche terms (pinecone, lambdarank)
stage1_candidates = top_5000 ∪ rare_term_pool
No network calls. BM25 index must be precomputed via precompute.py.
"""
from __future__ import annotations
import logging
import os
import pickle
import time
from typing import Dict, List, Optional, Set, Tuple
import numpy as np
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
class NumpyBM25:
    """
    Drop-in replacement for BM25Okapi.get_scores() backed by a precomputed
    matrix of shape (vocab_size, n_docs).

    Each entry [term_idx, doc_idx] is expected to hold the precomputed value
        idf(term) * bm25_tf_adjusted(term, doc)
    so scoring a query reduces to a single matrix-vector product:
        q_vec (vocab_size,) @ bm25_matrix (vocab_size × n_docs) -> (n_docs,)

    Benchmark notes carried over from earlier docs (not re-verified here):
    for 214 query tokens × 100K docs, BM25Okapi.get_scores() (one Python
    loop per token over per-document dict lookups) took ~9.5 s, versus
    ~50 ms for the single sparse matvec.
    """

    def __init__(self, vocab: Dict[str, int], bm25_matrix) -> None:
        """
        Args:
            vocab: Mapping token -> row index in ``bm25_matrix``.
            bm25_matrix: 2-D matrix of shape (vocab_size, n_docs); in this
                module it is the scipy sparse matrix produced by ``load_npz``.
                Anything exposing ``.shape`` and supporting ``ndarray @ matrix``
                works.
        """
        self.vocab = vocab
        self.bm25_matrix = bm25_matrix
        self._n_docs: int = bm25_matrix.shape[1]
        self._n_vocab: int = bm25_matrix.shape[0]

    def get_scores(self, query_tokens: List[str]) -> np.ndarray:
        """
        Score every document against ``query_tokens``.

        Mirrors BM25Okapi.get_scores(): every occurrence of a token in the
        query adds that token's weight, so a repeated token counts repeatedly.
        Tokens absent from the vocabulary contribute nothing.

        Returns:
            np.ndarray of shape (n_docs,), dtype float32 (all zeros when no
            query token is in the vocabulary).
        """
        rows = [self.vocab[tok] for tok in query_tokens if tok in self.vocab]
        if not rows:
            return np.zeros(self._n_docs, dtype=np.float32)
        q_vec = np.zeros(self._n_vocab, dtype=np.float32)
        # np.add.at accumulates repeated indices (plain fancy-index assignment
        # would keep only one), matching BM25Okapi's per-occurrence summation.
        np.add.at(q_vec, rows, 1.0)
        return np.asarray(q_vec @ self.bm25_matrix, dtype=np.float32).flatten()
def load_numpy_bm25_artifacts(precomputed_dir: str) -> Optional[NumpyBM25]:
    """
    Load the precomputed NumPy BM25 artifacts (vocab.pkl + bm25_matrix.npz).

    Args:
        precomputed_dir: Directory containing vocab.pkl and bm25_matrix.npz.

    Returns:
        A NumpyBM25 instance, or None when either artifact is missing, cannot
        be loaded, or the two artifacts are inconsistent with each other. In
        that case callers should fall back to bm25_index.pkl.

    Note:
        vocab.pkl is unpickled, so it must come from a trusted source (our own
        precompute run). Never point this at externally supplied files.
    """
    vocab_path = os.path.join(precomputed_dir, "vocab.pkl")
    matrix_path = os.path.join(precomputed_dir, "bm25_matrix.npz")
    if not (os.path.isfile(vocab_path) and os.path.isfile(matrix_path)):
        return None
    try:
        # Imported inside the try so a missing scipy degrades to the fallback
        # path instead of failing.
        from scipy.sparse import load_npz

        t0 = time.perf_counter()
        with open(vocab_path, "rb") as f:
            vocab = pickle.load(f)
        bm25_matrix = load_npz(matrix_path)

        # Artifacts from different precompute runs would otherwise only fail
        # at query time (IndexError) or, for negative indices, silently wrap.
        n_rows = bm25_matrix.shape[0]
        if vocab:
            lo, hi = min(vocab.values()), max(vocab.values())
            if lo < 0 or hi >= n_rows:
                raise ValueError(
                    f"vocab.pkl row indices span [{lo}, {hi}] but "
                    f"bm25_matrix.npz has {n_rows} rows"
                )

        logger.info(
            "NumPy BM25 loaded: vocab=%d shape=%s in %.3f s",
            len(vocab), bm25_matrix.shape, time.perf_counter() - t0,
        )
        return NumpyBM25(vocab, bm25_matrix)
    except Exception as exc:
        # Deliberately broad: any failure here just disables the fast path.
        logger.warning("Failed to load NumPy BM25 artifacts (%s) — falling back to BM25Okapi", exc)
        return None
def load_bm25_artifacts(precomputed_dir: str) -> Tuple[object, List[str]]:
    """
    Load the precomputed BM25 index and the candidate-ID list.

    Args:
        precomputed_dir: Path to the precomputed/ directory.

    Returns:
        (bm25_index, candidate_ids): the unpickled index object (anything
        exposing ``get_scores(tokens)``) and the unpickled candidate-ID list.
        This module treats candidate_ids[i] as the ID of document i in the
        index's score array.

    Raises:
        FileNotFoundError: If bm25_index.pkl or candidate_ids.pkl is missing.
        RuntimeError: If either artifact cannot be unpickled.

    Note:
        Both files are unpickled, so they must come from a trusted precompute
        run.
    """
    index_path = os.path.join(precomputed_dir, "bm25_index.pkl")
    ids_path = os.path.join(precomputed_dir, "candidate_ids.pkl")
    if not os.path.isfile(index_path):
        raise FileNotFoundError(
            f"BM25 index not found at {index_path}. "
            "Run precompute.py first."
        )
    if not os.path.isfile(ids_path):
        raise FileNotFoundError(
            f"Candidate IDs not found at {ids_path}. "
            "Run precompute.py first."
        )
    try:
        with open(index_path, "rb") as f:
            bm25 = pickle.load(f)
        with open(ids_path, "rb") as f:
            candidate_ids = pickle.load(f)
    except Exception as e:
        # Any unpickling/IO problem is surfaced as a single error type.
        raise RuntimeError(f"Failed to load BM25 artifacts: {e}") from e
    logger.info(
        "BM25 index loaded: %d candidates indexed", len(candidate_ids)
    )
    return bm25, candidate_ids
def tokenize_query(terms: List[str]) -> List[str]:
    """
    Tokenize a list of query terms for BM25.

    Each term is lowercased and split on whitespace, so multi-word terms
    become several tokens. Tokens are then de-duplicated, keeping the first
    occurrence of each. This makes the output order deterministic; a plain
    set() would reorder tokens from run to run under string hash
    randomization.

    Args:
        terms: Query terms, possibly multi-word (e.g. "vector search").

    Returns:
        Unique lowercase tokens in first-seen order.
    """
    return list(dict.fromkeys(tok for term in terms for tok in term.lower().split()))
def run_dual_pass_retrieval(
    bm25,
    candidate_ids: List[str],
    jd_config,
    top_n: int = 5000,
) -> Tuple[List[str], Dict[str, float]]:
    """
    Execute dual-pass BM25 retrieval per Section 3.

    Pass A: All JD skill aliases (hard + preferred requirements)
    Pass B: Production-context keywords only
    Safety Net: Rare terms pool (pinecone, lambdarank)

    A document's Pass A+B score is the element-wise max of its two pass
    scores. The top ``top_n`` documents by that score form the shortlist.
    Every other document with a positive score for any single rare term is
    then added with its best rare-term score.

    Args:
        bm25: Object exposing ``get_scores(tokens)`` that returns a per-document
            score array aligned with ``candidate_ids`` (e.g. BM25Okapi or
            NumpyBM25).
        candidate_ids: candidate_ids[i] is the ID of document i.
        jd_config: Parsed JD config. Reads ``get_all_query_terms()``,
            ``production_keywords`` and ``rare_terms``.
        top_n: Shortlist size for Pass A+B. Values <= 0 yield an empty
            shortlist, so only the rare-term pool is returned.

    Returns:
        (stage1_candidate_ids, bm25_scores_dict)
        - stage1_candidate_ids: ordered list (best first) of top_n ∪ rare_pool
        - bm25_scores_dict: {candidate_id: float} for all retrieved candidates
    """
    t0 = time.perf_counter()

    # Pass A: every JD skill alias.
    query_a_tokens = tokenize_query(jd_config.get_all_query_terms())
    logger.info("Pass A query tokens (%d): %s...", len(query_a_tokens),
                query_a_tokens[:10])
    scores_a = bm25.get_scores(query_a_tokens)

    # Pass B: production-context keywords only.
    query_b_tokens = tokenize_query(jd_config.production_keywords)
    logger.info("Pass B query tokens (%d): %s", len(query_b_tokens), query_b_tokens)
    scores_b = bm25.get_scores(query_b_tokens)

    combined_scores = np.maximum(scores_a, scores_b)

    top_n_actual = min(top_n, len(candidate_ids))
    if top_n_actual > 0:
        # argpartition selects the top-k in O(n); only those k are then sorted.
        top_indices = np.argpartition(combined_scores, -top_n_actual)[-top_n_actual:]
        top_indices = top_indices[np.argsort(combined_scores[top_indices])[::-1]]
    else:
        # With k == 0, `[-0:]` would select the whole array, and argpartition
        # raises on an empty array, so handle the empty shortlist explicitly.
        top_indices = np.empty(0, dtype=np.intp)
    top_candidates = [candidate_ids[i] for i in top_indices]
    top_scores = {candidate_ids[i]: float(combined_scores[i]) for i in top_indices}
    logger.info("Pass A+B union: %d candidates (target %d)", len(top_candidates), top_n)

    # Safety net: each rare term is scored on its own so that a single niche
    # match is enough to pull a document in, even if it missed the shortlist.
    rare_pool_scores: Dict[str, float] = {}
    for rare_term in jd_config.rare_terms:
        rare_scores = bm25.get_scores(tokenize_query([rare_term]))
        for idx in np.where(rare_scores > 0)[0]:
            cid = candidate_ids[idx]
            if cid in top_scores:
                continue
            rare_pool_scores[cid] = max(
                rare_pool_scores.get(cid, 0.0),
                float(rare_scores[idx]),
            )
    logger.info("Rare-term safety net added %d additional candidates", len(rare_pool_scores))

    # Rare-pool IDs are disjoint from the shortlist, so this merge never
    # overwrites a shortlist score.
    all_scores = {**top_scores, **rare_pool_scores}
    all_ordered = sorted(all_scores, key=all_scores.__getitem__, reverse=True)

    elapsed = time.perf_counter() - t0
    logger.info(
        "Dual-pass retrieval complete: %d candidates in %.2fs",
        len(all_ordered), elapsed
    )
    return all_ordered, all_scores
def retrieve_candidate_data(
    stage1_ids: List[str],
    candidates_path: str,
) -> Tuple[List[dict], Set[str]]:
    """
    Stream-read the candidates JSONL file and extract only the Stage 1 candidates.

    Blank lines are ignored. Lines that are not valid JSON, or that decode to
    something other than a JSON object, are skipped and counted as malformed.
    Scanning stops as soon as every requested ID has been found. If an ID
    appears on several lines before the scan stops, the last one read wins.

    Args:
        stage1_ids: Ordered list of candidate IDs from retrieval.
        candidates_path: Path to candidates.jsonl (read as UTF-8, one record
            per line; the ID is read from the "candidate_id" key).

    Returns:
        (candidates_list, missing_ids)
        - candidates_list: list of candidate dicts for stage1 IDs (order preserved)
        - missing_ids: IDs that were in stage1_ids but not found in the file

    Raises:
        OSError: If candidates_path cannot be opened.
    """
    import json

    wanted = set(stage1_ids)
    found: Dict[str, dict] = {}
    malformed_count = 0

    with open(candidates_path, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            # Early exit once every target is found. Checked before parsing so
            # an empty target set stops immediately instead of scanning the file.
            if len(found) == len(wanted):
                break
            line = line.strip()
            if not line:
                continue
            try:
                candidate = json.loads(line)
            except json.JSONDecodeError as e:
                malformed_count += 1
                logger.warning(
                    "Malformed JSON at line %d (skipped): %s", line_num, e
                )
                continue
            if not isinstance(candidate, dict):
                # Valid JSON but not an object (e.g. a bare list or number):
                # calling .get on it would raise AttributeError.
                malformed_count += 1
                logger.warning("Non-object JSON at line %d (skipped)", line_num)
                continue
            cid = candidate.get("candidate_id")
            if cid and cid in wanted:
                found[cid] = candidate

    if malformed_count > 0:
        logger.warning("Skipped %d malformed JSONL lines", malformed_count)

    missing_ids = wanted.difference(found)
    if missing_ids:
        # Show the highest-ranked missing IDs (deterministic, unlike slicing a set).
        sample = [cid for cid in stage1_ids if cid in missing_ids][:5]
        logger.warning(
            "%d stage1 candidates not found in JSONL: %s...",
            len(missing_ids),
            sample
        )

    ordered = [found[cid] for cid in stage1_ids if cid in found]
    logger.info(
        "Retrieved %d candidate records from JSONL (%d missing)",
        len(ordered), len(missing_ids)
    )
    return ordered, missing_ids
if __name__ == "__main__":
    # Manual smoke test: run Stage-1 retrieval with the artifacts and JD
    # config stored next to this file, then print the top hits and score
    # statistics. `os`, `logging` and `np` come from the module-level imports.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    precomputed_dir = os.path.join(base_dir, "precomputed")
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")

    # Project-local import, only needed when running this file as a script.
    from jd_parser import parse_jd

    jd_config = parse_jd(os.path.join(base_dir, "data", "skill_aliases.json"))
    print("Loading BM25 artifacts...")
    bm25, candidate_ids = load_bm25_artifacts(precomputed_dir)
    print(f"Running dual-pass retrieval on {len(candidate_ids)} indexed candidates...")
    stage1_ids, bm25_scores = run_dual_pass_retrieval(bm25, candidate_ids, jd_config)
    print(f"Stage 1 retrieved: {len(stage1_ids)} candidates")
    print("Top 10 by BM25 score:")
    for rank, cid in enumerate(stage1_ids[:10], 1):
        print(f" {rank:2d}. {cid} score={bm25_scores[cid]:.4f}")

    scores = list(bm25_scores.values())
    if scores:
        print(f"Score stats: min={min(scores):.4f}, max={max(scores):.4f}, "
              f"median={float(np.median(scores)):.4f}, mean={float(np.mean(scores)):.4f}")
    else:
        # min()/max() would raise on an empty sequence.
        print("Score stats: no candidates retrieved")