LordofMonarchs's picture
Upload folder using huggingface_hub
3751f09 verified
Raw
History Blame Contribute Delete
22.3 kB
from __future__ import annotations
import argparse
import json
import logging
import math
import os
import pickle
import sys
import time
from typing import Dict, List, Optional, Tuple
import numpy as np
_SCRIPTS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(_SCRIPTS_DIR)
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
for _p in [_SRC_DIR, _PROJECT_ROOT]:
if _p not in sys.path:
sys.path.insert(0, _p)
try:
from rank_bm25 import BM25Okapi
except ImportError:
import subprocess
print("rank_bm25 module not found, installing rank-bm25...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "rank-bm25==0.2.2"])
from rank_bm25 import BM25Okapi
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s [precompute] %(message)s",
datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
def tokenize_candidate(candidate: dict) -> List[str]:
"""
Build a BM25-indexable token list from a candidate record.
Combines: skill names, career descriptions, headline, summary.
Defensive: handles missing/null fields gracefully.
"""
tokens = []
for skill in (candidate.get("skills") or []):
name = (skill.get("name") or "").strip()
if name:
tokens.extend(name.lower().split())
# career history descriptions
for ch in (candidate.get("career_history") or []):
desc = (ch.get("description") or "").strip()
title = (ch.get("title") or "").strip()
if desc:
tokens.extend(desc.lower().split())
if title:
tokens.extend(title.lower().split())
# headline
profile = candidate.get("profile") or {}
headline = (profile.get("headline") or "").strip()
if headline:
tokens.extend(headline.lower().split())
# certifications
for cert in (candidate.get("certifications") or []):
name = (cert.get("name") or "").strip()
if name:
tokens.extend(name.lower().split())
return tokens
def stream_build_bm25_corpus(
candidates_path: str,
max_candidates: Optional[int] = None,
) -> Tuple[List[str], List[List[str]], int]:
"""
Stream-read candidates.jsonl and build the BM25 corpus.
Returns:
(candidate_ids, tokenized_corpus, malformed_count)
"""
candidate_ids = []
corpus = []
malformed_count = 0
total_lines = 0
logger.info("Building BM25 corpus from %s ...", candidates_path)
t0 = time.time()
with open(candidates_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
total_lines += 1
try:
candidate = json.loads(line)
except json.JSONDecodeError as e:
malformed_count += 1
logger.warning("Malformed JSON at line %d (skipped): %s", line_num, e)
continue
cid = candidate.get("candidate_id")
if not cid:
malformed_count += 1
logger.warning("Missing candidate_id at line %d (skipped)", line_num)
continue
tokens = tokenize_candidate(candidate)
candidate_ids.append(cid)
corpus.append(tokens)
if line_num % 10000 == 0:
elapsed = time.time() - t0
logger.info(
" Tokenized %d/%s candidates in %.1fs...",
line_num, max_candidates or "?", elapsed
)
if max_candidates and len(candidate_ids) >= max_candidates:
break
elapsed = time.time() - t0
logger.info(
"Corpus built: %d candidates, %d malformed lines, %.1fs",
len(candidate_ids), malformed_count, elapsed
)
return candidate_ids, corpus, malformed_count
def build_bm25_index(corpus: List[List[str]]):
"""Build BM25 index from tokenized corpus. Returns BM25Okapi object."""
logger.info("Building BM25Okapi index on %d documents...", len(corpus))
t0 = time.time()
bm25 = BM25Okapi(corpus)
elapsed = time.time() - t0
logger.info("BM25 index built in %.1fs", elapsed)
return bm25
def compute_offline_weak_labels(
candidates_path: str,
jd_config,
candidate_ids_set: set,
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
"""
Compute weak labels for training WITHOUT using bm25_score (non-circularity guarantee).
Label formula (Section 6):
weak_label = hard_req_coverage × consistency_score
bm25_score is EXPLICITLY EXCLUDED from label construction.
Returns:
(weak_labels_dict, hard_req_scores_dict, consistency_scores_dict)
"""
from features import (
c1_timeline_impossibility, c2_signup_anomaly, c3_salary_inversion,
c4_assessment_contradiction, c5_engagement_mismatch,
consistency_score as compute_consistency
)
from jd_parser import hard_req_coverage_score
logger.info("Computing offline weak labels (no bm25_score)...")
t0 = time.time()
weak_labels = {}
hard_req_scores = {}
consistency_scores = {}
processed = 0
with open(candidates_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
candidate = json.loads(line)
except json.JSONDecodeError:
continue
cid = candidate.get("candidate_id")
if not cid or cid not in candidate_ids_set:
continue
# hard requirement coverage
hrc = hard_req_coverage_score(candidate, jd_config)
c1 = c1_timeline_impossibility(candidate)
c2 = c2_signup_anomaly(candidate)
c3 = c3_salary_inversion(candidate)
c4 = c4_assessment_contradiction(candidate)
cons = c1 * c2 * c3 * c4
from features import (
detect_description_title_mismatch,
score_langchain_dabbler,
score_title_skill_discontinuity,
score_cv_speech_specialist,
)
# consulting fraction inline
consulting_m = sum(
float(r.get("duration_months") or 0)
for r in (candidate.get("career_history") or [])
if r.get("industry", "") in {"IT Services", "Consulting", "Professional Services", "BPO"}
and r.get("company_size", "") == "10001+"
)
total_m = sum(
float(r.get("duration_months") or 0)
for r in (candidate.get("career_history") or [])
)
cons_frac = (consulting_m / total_m) if total_m > 0 else 0.0
jd_penalty = max(0.0, 1.0 - (
0.90 * score_langchain_dabbler(candidate) +
0.85 * score_title_skill_discontinuity(candidate) +
0.75 * float(cons_frac > 0.95) +
0.65 * float(detect_description_title_mismatch(candidate) > 0.5) +
0.55 * score_cv_speech_specialist(candidate)
))
wl = hrc * cons * jd_penalty
hard_req_scores[cid] = hrc
consistency_scores[cid] = cons
weak_labels[cid] = wl
processed += 1
if processed % 10000 == 0:
logger.info(" Weak labels: %d computed...", processed)
elapsed = time.time() - t0
logger.info(
"Weak labels computed: %d candidates in %.1fs", len(weak_labels), elapsed
)
logger.info(
"Label stats: min=%.4f, max=%.4f, mean=%.4f, >0: %d",
min(weak_labels.values()),
max(weak_labels.values()),
sum(weak_labels.values()) / max(1, len(weak_labels)),
sum(1 for v in weak_labels.values() if v > 0),
)
return weak_labels, hard_req_scores, consistency_scores
def extract_training_features(
candidates_path: str,
candidate_ids: List[str],
jd_config,
hard_req_scores: Dict[str, float],
consistency_scores: Dict[str, float],
) -> Tuple[np.ndarray, List[str]]:
"""
Extract the full 22-feature matrix for all indexed candidates.
bm25_score is set to 0.0 for all candidates at training time.
Returns:
(feature_matrix: np.ndarray of shape [N, 22], ordered_ids)
"""
from features import build_feature_vector, FEATURE_COLUMNS
logger.info("Extracting 22-feature matrix for %d candidates...", len(candidate_ids))
t0 = time.time()
cid_set = set(candidate_ids)
feature_rows = {}
with open(candidates_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
candidate = json.loads(line)
except json.JSONDecodeError:
continue
cid = candidate.get("candidate_id")
if not cid or cid not in cid_set:
continue
try:
fv = build_feature_vector(
candidate, jd_config,
bm25_score=0.0,
stage1_bm25_median=0.0,
)
except Exception as e:
logger.warning("Feature extraction failed for %s: %s", cid, e)
fv = {col: 0.0 for col in FEATURE_COLUMNS}
feature_rows[cid] = [fv[col] for col in FEATURE_COLUMNS]
if len(feature_rows) % 10000 == 0:
logger.info(" Features: %d extracted...", len(feature_rows))
matrix = []
ordered_ids = []
for cid in candidate_ids:
if cid in feature_rows:
matrix.append(feature_rows[cid])
ordered_ids.append(cid)
X = np.array(matrix, dtype=np.float32)
elapsed = time.time() - t0
logger.info(
"Feature matrix shape: %s in %.1fs", X.shape, elapsed
)
return X, ordered_ids
def train_lightgbm(
X: np.ndarray,
weak_labels: Dict[str, float],
ordered_ids: List[str],
precomputed_dir: str,
) -> None:
"""
Train LightGBM with objective='lambdarank' and eval_at=[5, 10, 50].
LightGBM lambdarank has a hard limit of max_position (<=10000) rows per query.
With 100K candidates, we split into multiple query groups of GROUP_SIZE each.
Each group simulates a "mini-query" with the same JD — the model still learns
to rank candidates by relevance within each group, then generalizes across groups.
Labels are discretized to integer bins [0, 1, 2, 3] for lambdarank.
"""
try:
import lightgbm as lgb
except ImportError:
import subprocess
import sys
logger.info("lightgbm module not found, installing lightgbm...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "lightgbm==4.3.0"])
import lightgbm as lgb
from features import FEATURE_COLUMNS
logger.info("Training LightGBM LambdaRank model...")
t0 = time.time()
y_raw = np.array([weak_labels.get(cid, 0.0) for cid in ordered_ids], dtype=np.float32)
y_int = np.zeros(len(y_raw), dtype=np.int32)
y_int[y_raw > 0] = 1
y_int[y_raw > 0.33] = 2
y_int[y_raw > 0.66] = 3
logger.info(
"Label distribution: 0=%d, 1=%d, 2=%d, 3=%d",
(y_int == 0).sum(), (y_int == 1).sum(),
(y_int == 2).sum(), (y_int == 3).sum()
)
# spliting 100K candidates into groups of GROUP_SIZE
GROUP_SIZE = 5000
n = len(ordered_ids)
rng = np.random.default_rng(seed=42)
shuffle_idx = rng.permutation(n)
X_shuffled = X[shuffle_idx]
y_shuffled = y_int[shuffle_idx]
# build group sizes
n_groups = (n + GROUP_SIZE - 1) // GROUP_SIZE # ceiling division
group = []
for i in range(n_groups):
start = i * GROUP_SIZE
end = min(start + GROUP_SIZE, n)
group.append(end - start)
logger.info(
"LambdaRank: %d candidates split into %d query groups of size ~%d",
n, n_groups, GROUP_SIZE
)
train_data = lgb.Dataset(
X_shuffled, label=y_shuffled,
group=group,
feature_name=FEATURE_COLUMNS,
)
params = {
"objective": "lambdarank",
"metric": "ndcg",
"eval_at": [5, 10, 50],
"num_leaves": 63,
"learning_rate": 0.05,
"min_child_samples": 20,
"subsample": 0.8,
"colsample_bytree": 0.8,
"random_state": 42,
"n_jobs": -1,
"verbose": -1,
}
model = lgb.train(
params,
train_data,
num_boost_round=200,
valid_sets=[train_data],
callbacks=[
lgb.log_evaluation(period=50),
lgb.early_stopping(stopping_rounds=20, verbose=False),
],
)
elapsed = time.time() - t0
logger.info("LightGBM training complete in %.1fs", elapsed)
importances = dict(zip(FEATURE_COLUMNS, model.feature_importance(importance_type="gain")))
sorted_imp = sorted(importances.items(), key=lambda x: x[1], reverse=True)
logger.info("Top 5 feature importances (gain):")
for fname, imp in sorted_imp[:5]:
logger.info(" %s: %.2f", fname, imp)
model_path = os.path.join(precomputed_dir, "lgbm_model.pkl")
with open(model_path, "wb") as f:
pickle.dump(model, f)
logger.info("LightGBM model saved to %s", model_path)
def save_artifacts(
precomputed_dir: str,
bm25,
candidate_ids: List[str],
weak_labels: Dict[str, float],
) -> None:
"""Save BM25 index, candidate IDs, and weak labels to precomputed/."""
os.makedirs(precomputed_dir, exist_ok=True)
bm25_path = os.path.join(precomputed_dir, "bm25_index.pkl")
ids_path = os.path.join(precomputed_dir, "candidate_ids.pkl")
labels_path = os.path.join(precomputed_dir, "weak_labels.pkl")
with open(bm25_path, "wb") as f:
pickle.dump(bm25, f)
logger.info("BM25 index saved: %s (%.1f MB)", bm25_path,
os.path.getsize(bm25_path) / 1e6)
with open(ids_path, "wb") as f:
pickle.dump(candidate_ids, f)
logger.info("Candidate IDs saved: %s (%d IDs)", ids_path, len(candidate_ids))
with open(labels_path, "wb") as f:
pickle.dump(weak_labels, f)
logger.info("Weak labels saved: %s", labels_path)
def compute_and_save_static_features(
candidates_path: str,
candidate_ids: List[str],
precomputed_dir: str,
) -> None:
"""
Compute 18 JD-independent features for all candidate profiles and save them to static_features.pkl.
"""
from features import (
compute_yoe, compute_param_a_systems_depth, compute_param_b_availability,
compute_param_c_tenure, compute_param_d_notice_exp, compute_param_e_credibility,
compute_param_f_consulting, compute_param_g_location, compute_param_h_github,
compute_title_ai_fraction, compute_prod_signal_log, compute_flag_consulting_only,
compute_flag_title_chaser, compute_flag_langchain_dabbler, compute_flag_cv_specialist,
compute_flag_title_desc_mismatch, compute_flag_template_desc
)
logger.info("Computing 18 JD-independent features for all candidates offline...")
t0 = time.time()
candidate_ids_set = set(candidate_ids)
static_features = {}
with open(candidates_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
candidate = json.loads(line)
except json.JSONDecodeError:
continue
cid = candidate.get("candidate_id")
if not cid or cid not in candidate_ids_set:
continue
yoe = compute_yoe(candidate)
param_a = compute_param_a_systems_depth(candidate)
param_b = compute_param_b_availability(candidate)
param_c = compute_param_c_tenure(candidate)
param_d = compute_param_d_notice_exp(candidate)
param_e = compute_param_e_credibility(candidate)
param_f = compute_param_f_consulting(candidate)
param_g = compute_param_g_location(candidate)
param_h = compute_param_h_github(candidate)
title_ai_frac = compute_title_ai_fraction(candidate)
prod_sig_log = compute_prod_signal_log(candidate)
flag_consulting_only = compute_flag_consulting_only(candidate)
flag_title_chaser = compute_flag_title_chaser(candidate)
flag_langchain = compute_flag_langchain_dabbler(candidate.get("skills") or [])
flag_cv = compute_flag_cv_specialist(candidate.get("skills") or [])
flag_title_desc = compute_flag_title_desc_mismatch(candidate)
flag_template = compute_flag_template_desc(candidate)
interaction_yoe_x_prod = yoe * max(0.0, prod_sig_log)
static_features[cid] = {
"yoe": float(yoe),
"Param_A_Systems_Depth": float(param_a),
"Param_B_Availability": float(param_b),
"Param_C_Tenure": float(param_c),
"Param_D_Notice_Exp": float(param_d),
"Param_E_Credibility": float(param_e),
"Param_F_Consulting": float(param_f),
"Param_G_Location": float(param_g),
"Param_H_GitHub": float(param_h),
"title_ai_fraction": float(title_ai_frac),
"prod_signal_log": float(prod_sig_log),
"flag_consulting_only": float(flag_consulting_only),
"flag_title_chaser": float(flag_title_chaser),
"flag_langchain_dabbler": float(flag_langchain),
"flag_cv_specialist": float(flag_cv),
"flag_title_desc_mismatch": float(flag_title_desc),
"flag_template_desc": float(flag_template),
"interaction_yoe_x_prod": float(interaction_yoe_x_prod),
}
if len(static_features) % 25000 == 0:
logger.info(" Static features: %d calculated...", len(static_features))
out_path = os.path.join(precomputed_dir, "static_features.pkl")
with open(out_path, "wb") as f:
pickle.dump(static_features, f, protocol=pickle.HIGHEST_PROTOCOL)
elapsed = time.time() - t0
logger.info("Saved static features: %s (%d candidate profiles in %.1fs)",
out_path, len(static_features), elapsed)
def main(candidates_path: str, base_dir: str) -> None:
"""Main precomputation pipeline."""
precomputed_dir = os.path.join(base_dir, "precomputed")
data_dir = os.path.join(base_dir, "data")
aliases_path = os.path.join(data_dir, "skill_aliases.json")
os.makedirs(precomputed_dir, exist_ok=True)
if not os.path.isfile(candidates_path):
logger.error("Candidates file not found: %s", candidates_path)
sys.exit(1)
if not os.path.isfile(aliases_path):
logger.error("skill_aliases.json not found: %s", aliases_path)
sys.exit(1)
logger.info("=== Precompute Pipeline Starting ===")
logger.info("Candidates: %s", candidates_path)
logger.info("Base dir: %s", base_dir)
t_total = time.time()
from jd_parser import parse_jd
jd_config = parse_jd(aliases_path)
logger.info(
"JD config: %d hard reqs, %d preferred reqs",
len(jd_config.hard_requirements),
len(jd_config.preferred_requirements)
)
candidate_ids, corpus, malformed_count = stream_build_bm25_corpus(candidates_path)
bm25 = build_bm25_index(corpus)
del corpus
# compute weak labels
candidate_ids_set = set(candidate_ids)
weak_labels, hard_req_scores, consistency_scores = compute_offline_weak_labels(
candidates_path, jd_config, candidate_ids_set
)
# BM25 index + metadata
save_artifacts(precomputed_dir, bm25, candidate_ids, weak_labels)
# compute and save 18 static features offline
compute_and_save_static_features(candidates_path, candidate_ids, precomputed_dir)
# 22 feature matrix for training
X, ordered_ids = extract_training_features(
candidates_path, candidate_ids, jd_config, hard_req_scores, consistency_scores
)
# train LightGBM
train_lightgbm(X, weak_labels, ordered_ids, precomputed_dir)
total_elapsed = time.time() - t_total
logger.info("=== Precompute Complete in %.1fs ===", total_elapsed)
logger.info("Artifacts in: %s", precomputed_dir)
# print summary
artifact_sizes = {}
for fname in ["bm25_index.pkl", "candidate_ids.pkl", "weak_labels.pkl", "lgbm_model.pkl"]:
fpath = os.path.join(precomputed_dir, fname)
if os.path.isfile(fpath):
artifact_sizes[fname] = os.path.getsize(fpath) / 1e6
logger.info("Artifact sizes (MB):")
for fname, size_mb in artifact_sizes.items():
logger.info(" %s: %.1f MB", fname, size_mb)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Offline pre-computation: BM25 indexing + LightGBM training"
)
parser.add_argument(
"--candidates",
default=os.path.join(_PROJECT_ROOT, "candidates.jsonl"),
help="Path to candidates JSONL file (default: project_root/candidates.jsonl)",
)
parser.add_argument(
"--base-dir",
default=_PROJECT_ROOT,
help="Base directory for data/ and precomputed/ (default: project root)",
)
args = parser.parse_args()
candidates_path = os.path.abspath(args.candidates)
base_dir = os.path.abspath(args.base_dir)
main(candidates_path, base_dir)