# Intelligent-Candidate-Ranking-Model / scripts/rebuild_fast_artifacts.py
# Source: Hugging Face file viewer (raw view, 7.73 kB).
# Uploaded by LordofMonarchs with huggingface_hub ("Upload folder using
# huggingface_hub"), commit c754148 (verified).
from __future__ import annotations
import json
import logging
import os
import pickle
import sys
import time
# Resolve the project layout relative to this script:
#   <project root>/scripts/<this file>
#   <project root>/src
_SCRIPTS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.split(_SCRIPTS_DIR)[0]
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")

# Make both directories importable. Each missing entry is pushed to the
# front of sys.path, so the project root ends up ahead of src/.
for _p in (_SRC_DIR, _PROJECT_ROOT):
    if _p in sys.path:
        continue
    sys.path.insert(0, _p)
import numpy as np
# Filesystem layout used by every builder in this script.
BASE_DIR = _PROJECT_ROOT
PRECOMPUTED_DIR = os.path.join(BASE_DIR, "precomputed")
CANDIDATES_PATH = os.path.join(BASE_DIR, "candidates.jsonl")

# Root logging setup: INFO level, timestamped with wall-clock time only.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def build_numpy_bm25_artifacts(bm25, precomputed_dir: str) -> None:
    """
    Precompute a scipy sparse BM25 score matrix from an existing BM25Okapi object.

    Saves into ``precomputed_dir`` (created if it does not exist):
        vocab.pkl        - {term: row_index} mapping
        bm25_matrix.npz  - scipy sparse CSR (vocab_size × n_docs), float32
                           Each entry [term_idx, doc_idx] = precomputed
                           idf(term) × bm25_tf_adjusted(term, doc)

    Scoring at runtime:
        q_vec (1 × vocab_size) @ bm25_matrix (vocab_size × n_docs)
        → (1 × n_docs) dense result in a single scipy sparse op.

    Args:
        bm25: Object exposing ``idf`` (term -> idf weight), ``doc_freqs``
            (one ``{term: tf}`` dict per document), ``doc_len``, ``avgdl``
            and ``corpus_size``. ``k1`` and ``b`` fall back to 1.5 and 0.75
            when the attributes are absent.
        precomputed_dir: Directory the two artifacts are written to.

    Raises:
        subprocess.CalledProcessError: scipy is missing and the pip
            installation fallback fails.
    """
    try:
        from scipy.sparse import coo_matrix, save_npz
    except ImportError:
        # Best-effort bootstrap kept from the original script: install scipy
        # into the running interpreter and retry the import. Logged loudly
        # because it mutates the environment.
        import subprocess
        logger.warning("scipy is not installed; installing it with pip …")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "scipy"])
        from scipy.sparse import coo_matrix, save_npz

    logger.info("Building NumPy sparse BM25 matrix …")
    t0 = time.perf_counter()

    k1: float = getattr(bm25, "k1", 1.5)
    b: float = getattr(bm25, "b", 0.75)
    avgdl: float = float(bm25.avgdl)
    doc_len_arr = np.array(bm25.doc_len, dtype=np.float32)
    n_docs: int = int(bm25.corpus_size)

    # term -> row index
    vocab: dict = {term: idx for idx, term in enumerate(bm25.idf.keys())}
    idf_array = np.array([bm25.idf[term] for term in vocab], dtype=np.float32)
    n_vocab: int = len(vocab)
    logger.info(" vocab_size=%d n_docs=%d", n_vocab, n_docs)

    # Plain Python floats (already rounded to float32) are far cheaper to read
    # in the inner loop than indexing a NumPy array once per matrix entry.
    idf_values = idf_array.tolist()
    doc_lengths = doc_len_arr.tolist()

    rows_list: list = []
    cols_list: list = []
    data_list: list = []
    checkpoint = max(1, n_docs // 10)

    for doc_idx, doc_freq_dict in enumerate(bm25.doc_freqs):
        dl = doc_lengths[doc_idx]
        # avgdl is 0 only when every document is empty; drop the length
        # normalisation term in that case instead of dividing by zero.
        length_norm = b * dl / avgdl if avgdl > 0 else 0.0
        denom_k = k1 * (1.0 - b + length_norm)
        for term, tf in doc_freq_dict.items():
            term_idx = vocab.get(term)
            if term_idx is None:
                # Term has no idf entry, so it can never contribute to a score.
                continue
            tf_f = float(tf)
            tf_adj = (tf_f * (k1 + 1.0)) / (tf_f + denom_k)
            rows_list.append(term_idx)
            cols_list.append(doc_idx)
            data_list.append(idf_values[term_idx] * tf_adj)
        if doc_idx % checkpoint == 0 and doc_idx > 0:
            logger.info(" … %d / %d docs processed", doc_idx, n_docs)

    nnz = len(data_list)
    logger.info(" COO built: nnz=%d (%.1f s)", nnz, time.perf_counter() - t0)

    bm25_matrix = coo_matrix(
        (
            np.array(data_list, dtype=np.float32),
            (
                np.array(rows_list, dtype=np.int32),
                np.array(cols_list, dtype=np.int32),
            ),
        ),
        shape=(n_vocab, n_docs),
    ).tocsr()
    elapsed = time.perf_counter() - t0
    logger.info(" CSR matrix: shape=%s nnz=%d (%.1f s total)",
                bm25_matrix.shape, bm25_matrix.nnz, elapsed)

    # Create the output directory on first run; an empty string means the
    # current directory and must not be passed to makedirs.
    if precomputed_dir:
        os.makedirs(precomputed_dir, exist_ok=True)
    vocab_path = os.path.join(precomputed_dir, "vocab.pkl")
    matrix_path = os.path.join(precomputed_dir, "bm25_matrix.npz")
    with open(vocab_path, "wb") as f:
        pickle.dump(vocab, f, protocol=pickle.HIGHEST_PROTOCOL)
    save_npz(matrix_path, bm25_matrix)
    logger.info(" Saved vocab.pkl (%d terms)", n_vocab)
    logger.info(" Saved bm25_matrix.npz (%.1f MB)",
                os.path.getsize(matrix_path) / 1e6)
def build_candidate_offset_index(candidates_path: str, precomputed_dir: str) -> None:
    """
    Scan candidates.jsonl once in binary mode and record the byte offset of
    the line holding each candidate_id.

    Saves candidate_offsets.pkl: {candidate_id: byte_offset} into
    ``precomputed_dir`` (created if it does not exist). A consumer can then
    fetch one record with f.seek(offset) + f.readline() on a file opened in
    binary mode instead of streaming the whole file.

    Lines that are blank, are not valid UTF-8 JSON, are not JSON objects, or
    carry a missing/empty/unhashable ``candidate_id`` are not indexed. When an
    id occurs more than once, the last occurrence wins.

    Args:
        candidates_path: Path to the JSON-lines candidate file.
        precomputed_dir: Directory the offset index is written to.
    """
    logger.info("Building candidate byte-offset index …")
    t0 = time.perf_counter()
    offsets: dict = {}
    skipped = 0
    size_bytes = os.path.getsize(candidates_path)

    # In binary mode the start of a line is exactly the number of bytes read
    # so far, so a running counter replaces a tell() call per line.
    next_offset = 0
    with open(candidates_path, "rb") as f:
        for raw_line in f:
            offset = next_offset
            next_offset += len(raw_line)
            stripped = raw_line.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except ValueError:
                # Covers json.JSONDecodeError and UnicodeDecodeError alike.
                skipped += 1
                continue
            cid = record.get("candidate_id") if isinstance(record, dict) else None
            if not cid or isinstance(cid, (list, dict)):
                skipped += 1
                continue
            is_new = cid not in offsets
            offsets[cid] = offset
            # Report only when the count actually reaches a new multiple, so
            # skipped or duplicate lines do not repeat the same message.
            if is_new and len(offsets) % 10_000 == 0:
                pct = next_offset / size_bytes * 100
                logger.info(" … %d candidates indexed (%.0f%% of file)", len(offsets), pct)

    elapsed = time.perf_counter() - t0
    logger.info(" Offset index: %d candidates in %.1f s", len(offsets), elapsed)
    if skipped:
        logger.warning(" Skipped %d lines without a usable candidate_id", skipped)

    # An empty string means the current directory and must not reach makedirs.
    if precomputed_dir:
        os.makedirs(precomputed_dir, exist_ok=True)
    out_path = os.path.join(precomputed_dir, "candidate_offsets.pkl")
    with open(out_path, "wb") as f:
        pickle.dump(offsets, f, protocol=pickle.HIGHEST_PROTOCOL)
    logger.info(" Saved candidate_offsets.pkl (%.1f MB)",
                os.path.getsize(out_path) / 1e6)
def export_lgbm_native(precomputed_dir: str) -> None:
    """
    Re-save lgbm_model.pkl in LightGBM's native text format.

    Reads ``lgbm_model.pkl`` from ``precomputed_dir`` and writes
    ``lgbm_model.txt`` next to it, so the model can later be loaded with
    lgb.Booster(model_file=...) instead of being unpickled.

    The pickle may hold either an object with ``save_model`` (a Booster) or a
    scikit-learn style wrapper that exposes the trained booster as
    ``booster_``.

    Args:
        precomputed_dir: Directory containing lgbm_model.pkl.

    Raises:
        ImportError: lightgbm is not installed.
        FileNotFoundError: lgbm_model.pkl does not exist.
        AttributeError: the unpickled object has neither ``save_model`` nor
            ``booster_``.
    """
    # Imported only to fail fast with a clear error before unpickling.
    import lightgbm  # noqa: F401

    pkl_path = os.path.join(precomputed_dir, "lgbm_model.pkl")
    txt_path = os.path.join(precomputed_dir, "lgbm_model.txt")
    logger.info("Exporting LightGBM model to native text format …")
    t0 = time.perf_counter()

    # NOTE: pickle.load can execute arbitrary code; lgbm_model.pkl must come
    # from a trusted training run.
    with open(pkl_path, "rb") as f:
        model = pickle.load(f)

    # scikit-learn wrappers have no save_model(); use their underlying Booster.
    booster = model if hasattr(model, "save_model") else model.booster_
    booster.save_model(txt_path)

    logger.info(" Saved lgbm_model.txt (%.1f MB) in %.2f s",
                os.path.getsize(txt_path) / 1e6,
                time.perf_counter() - t0)
def main() -> None:
    """
    Rebuild all fast-path artifacts inside PRECOMPUTED_DIR.

    Loads the pickled BM25 index, runs the three builders in order (sparse
    BM25 matrix, candidate byte-offset index, native LightGBM export) and
    finally logs the size of every expected artifact that exists on disk.
    """
    rule = "=" * 60
    logger.info(rule)
    logger.info("REBUILD FAST ARTIFACTS")
    logger.info(rule)
    run_started = time.perf_counter()

    # Load the existing BM25 index that the sparse matrix is derived from.
    index_path = os.path.join(PRECOMPUTED_DIR, "bm25_index.pkl")
    index_mb = os.path.getsize(index_path) / 1e6
    logger.info("Loading bm25_index.pkl (%.1f MB) …", index_mb)
    load_started = time.perf_counter()
    with open(index_path, "rb") as index_file:
        bm25 = pickle.load(index_file)
    logger.info(" Loaded in %.2f s", time.perf_counter() - load_started)

    # Build each artifact; any failure propagates and aborts the run.
    build_numpy_bm25_artifacts(bm25, PRECOMPUTED_DIR)
    build_candidate_offset_index(CANDIDATES_PATH, PRECOMPUTED_DIR)
    export_lgbm_native(PRECOMPUTED_DIR)

    logger.info(rule)
    logger.info("ALL ARTIFACTS BUILT in %.1f s", time.perf_counter() - run_started)
    logger.info("New files in precomputed/:")
    artifact_names = (
        "vocab.pkl",
        "bm25_matrix.npz",
        "candidate_offsets.pkl",
        "lgbm_model.txt",
    )
    for artifact in artifact_names:
        artifact_path = os.path.join(PRECOMPUTED_DIR, artifact)
        if not os.path.isfile(artifact_path):
            continue
        size_mb = os.path.getsize(artifact_path) / 1e6
        logger.info(" %-30s %.1f MB", artifact, size_mb)
    logger.info("rank.py will auto-detect these and use the fast paths.")
    logger.info(rule)
# Run the rebuild only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()