| | """ |
| | Soft-Label Cross-Encoder Reranker Training |
| | |
| | Trains a reranker using continuous relevance scores (soft labels). |
| | Dataset format: {"query": "...", "text": "...", "score": 0.0-1.0} |
| | """ |

import logging
import os
import math
from collections import defaultdict

import trackio
import numpy as np
from datasets import load_dataset
from sentence_transformers.cross_encoder import (
    CrossEncoder,
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
from scipy.stats import spearmanr
from transformers import TrainerCallback

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

DATASET_NAME = os.environ.get("DATASET_NAME", "amanwithaplan/arcade-reranker-data")
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "amanwithaplan/arcade-reranker")
BASE_MODEL = os.environ.get("BASE_MODEL", "Alibaba-NLP/gte-reranker-modernbert-base")
NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "10"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "16"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-5"))
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "512"))
RUN_NAME = os.environ.get("RUN_NAME", "reranker-03130903")
SPACE_ID = os.environ.get("TRACKIO_SPACE_ID", "amanwithaplan/trackio")

def dcg_at_k(relevances, k):
    """Compute DCG@k."""
    relevances = np.array(relevances)[:k]
    if len(relevances) == 0:
        return 0.0

    discounts = np.log2(np.arange(len(relevances)) + 2)
    return np.sum(relevances / discounts)
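
# Worked example for dcg_at_k (illustrative values, not from the dataset):
# relevances [1.0, 0.5, 0.0] with k=3 use discounts log2(2), log2(3), log2(4), so
# DCG = 1.0/1.0 + 0.5/1.585 + 0.0/2.0 ≈ 1.32.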

def ndcg_at_k(predicted_order, true_relevances, k):
    """
    Compute NDCG@k.

    predicted_order: indices of docs sorted by model score (descending)
    true_relevances: ground truth relevance scores for each doc
    """
    predicted_relevances = [true_relevances[i] for i in predicted_order]

    ideal_relevances = sorted(true_relevances, reverse=True)

    dcg = dcg_at_k(predicted_relevances, k)
    idcg = dcg_at_k(ideal_relevances, k)

    if idcg == 0:
        return 0.0
    return dcg / idcg
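
# Worked example for ndcg_at_k (illustrative values): with true_relevances [0.2, 1.0, 0.5]
# and predicted_order [0, 1, 2], DCG ≈ 1.08 while the ideal ordering gives IDCG ≈ 1.42,
# so NDCG@3 ≈ 0.76; a perfect ranking would score 1.0.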

def mrr(predicted_order, true_relevances, threshold=0.5):
    """
    Compute MRR (Mean Reciprocal Rank).

    Returns 1/rank of first relevant doc (relevance > threshold).
    """
    for rank, idx in enumerate(predicted_order, start=1):
        if true_relevances[idx] > threshold:
            return 1.0 / rank
    return 0.0
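
# Worked example for mrr (illustrative values): with predicted_order [2, 0, 1],
# true_relevances [0.9, 0.1, 0.2] and threshold 0.5, the first relevant doc (index 0)
# appears at rank 2, so MRR = 1/2 = 0.5.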

def evaluate_ranking(model, eval_dataset):
    """
    Proper ranking evaluation: group by query, compute NDCG and MRR.

    This measures what we actually care about:
    "Given a query with multiple docs, does the model rank them correctly?"
    """
    query_groups = defaultdict(list)
    for item in eval_dataset:
        query_groups[item["sentence1"]].append({
            "text": item["sentence2"],
            "label": item["label"]
        })

    # Ranking metrics only make sense for queries with at least two candidate docs
    query_groups = {q: docs for q, docs in query_groups.items() if len(docs) >= 2}

    if not query_groups:
        return {"ndcg@3": 0.0, "ndcg@5": 0.0, "mrr": 0.0, "n_queries": 0}

    ndcg_3_scores = []
    ndcg_5_scores = []
    mrr_scores = []
    rank_correlations = []

    for query, docs in query_groups.items():
        # Score every (query, doc) pair for this query
        pairs = [(query, d["text"]) for d in docs]
        predictions = model.predict(pairs, show_progress_bar=False)

        true_relevances = [d["label"] for d in docs]

        # Doc indices sorted by predicted score, descending
        predicted_order = np.argsort(predictions)[::-1].tolist()

        ndcg_3_scores.append(ndcg_at_k(predicted_order, true_relevances, k=3))
        ndcg_5_scores.append(ndcg_at_k(predicted_order, true_relevances, k=5))
        mrr_scores.append(mrr(predicted_order, true_relevances, threshold=0.5))

        # Spearman correlation is undefined when all labels are identical
        if len(set(true_relevances)) > 1:
            corr = spearmanr(predictions, true_relevances).correlation
            if not math.isnan(corr):
                rank_correlations.append(corr)

    return {
        "ndcg@3": np.mean(ndcg_3_scores),
        "ndcg@5": np.mean(ndcg_5_scores),
        "mrr": np.mean(mrr_scores),
        "rank_corr": np.mean(rank_correlations) if rank_correlations else 0.0,
        "n_queries": len(query_groups),
    }

class DomainEvalCallback(TrainerCallback):
    """Callback to log proper ranking metrics during training."""

    def __init__(self, model, eval_dataset_full):
        self.model = model
        self.eval_dataset_full = eval_dataset_full

    def on_evaluate(self, args, state, control, **kwargs):
        """Run after each evaluation step."""
        metrics = evaluate_ranking(self.model, self.eval_dataset_full)

        trackio.log({
            "domain/ndcg@3": metrics["ndcg@3"],
            "domain/ndcg@5": metrics["ndcg@5"],
            "domain/mrr": metrics["mrr"],
            "domain/rank_corr": metrics["rank_corr"],
        })

        logger.info(
            f"Domain eval - NDCG@3: {metrics['ndcg@3']:.4f}, "
            f"NDCG@5: {metrics['ndcg@5']:.4f}, "
            f"MRR: {metrics['mrr']:.4f}, "
            f"RankCorr: {metrics['rank_corr']:.4f} "
            f"(n={metrics['n_queries']} queries)"
        )

def evaluate_by_type(model, eval_dataset, type_column="type"):
    """Evaluate ranking metrics per content type."""
    if type_column not in eval_dataset.column_names:
        return {}

    by_type = defaultdict(list)
    for item in eval_dataset:
        by_type[item[type_column]].append(item)

    results = {}
    for content_type, items in by_type.items():
        # Minimal dataset-like wrapper so evaluate_ranking can iterate this subset
        class TypeDataset:
            def __init__(self, items):
                self.items = items

            def __iter__(self):
                return iter(self.items)

            @property
            def column_names(self):
                return ["sentence1", "sentence2", "label"]

        type_metrics = evaluate_ranking(model, TypeDataset(items))

        if type_metrics["n_queries"] >= 2:
            results[f"{content_type}_ndcg@5"] = type_metrics["ndcg@5"]
            results[f"{content_type}_mrr"] = type_metrics["mrr"]
            results[f"{content_type}_n_queries"] = type_metrics["n_queries"]

    return results

def main():
    trackio.init(
        project="arcade-reranker",
        name=RUN_NAME,
        space_id=SPACE_ID,
        config={
            "model": BASE_MODEL,
            "dataset": DATASET_NAME,
            "learning_rate": LEARNING_RATE,
            "num_epochs": NUM_EPOCHS,
            "batch_size": BATCH_SIZE,
            "max_seq_length": MAX_SEQ_LENGTH,
        }
    )

    logger.info("Configuration:")
    logger.info(f" Dataset: {DATASET_NAME}")
    logger.info(f" Base model: {BASE_MODEL}")
    logger.info(f" Epochs: {NUM_EPOCHS}")
    logger.info(f" Run name: {RUN_NAME}")
    logger.info(f" Trackio space: {SPACE_ID}")

    model = CrossEncoder(BASE_MODEL, max_length=MAX_SEQ_LENGTH)

    logger.info(f"Loading dataset: {DATASET_NAME}")
    dataset = load_dataset(DATASET_NAME, split="train")

    # Log dataset composition by content type, if the column exists
    type_counts = defaultdict(int)
    if "type" in dataset.column_names:
        for item in dataset:
            type_counts[item["type"]] += 1
        logger.info(f"Dataset composition: {dict(type_counts)}")

    for content_type, count in type_counts.items():
        trackio.log({f"data/{content_type}_count": count})

    trackio.log({"data/total_examples": len(dataset)})
    logger.info(f"Total examples: {len(dataset)}")

    # Rename columns to the (sentence1, sentence2, label) schema expected by the trainer
    dataset = dataset.rename_columns({
        "query": "sentence1",
        "text": "sentence2",
        "score": "label"
    })

    # Hold out an eval split (15%, capped at 400 examples)
    eval_size = min(400, int(len(dataset) * 0.15))
    splits = dataset.train_test_split(test_size=eval_size, seed=42)

    # Keep the full eval split (with any extra columns) for per-type analysis
    eval_dataset_full = splits["test"]

    train_dataset = splits["train"].select_columns(["sentence1", "sentence2", "label"])
    eval_dataset = splits["test"].select_columns(["sentence1", "sentence2", "label"])

    trackio.log({
        "data/train_size": len(train_dataset),
        "data/eval_size": len(eval_dataset),
    })
    logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

    # Baseline: score the untrained base model on the held-out eval set
    logger.info("Evaluating base model on eval set...")
    base_metrics = evaluate_ranking(model, eval_dataset_full)
    for key, value in base_metrics.items():
        trackio.log({f"base_model/{key}": value})
    logger.info(f"Base model metrics: {base_metrics}")

    # NanoBEIR evaluator to watch for regressions on general-domain retrieval
    evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=BATCH_SIZE,
    )

    args = CrossEncoderTrainingArguments(
        output_dir="models/reranker",
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        warmup_ratio=0.1,
        bf16=True,
        eval_strategy="steps",
        eval_steps=25,
        save_strategy="steps",
        save_steps=25,
        save_total_limit=5,
        logging_steps=25,
        logging_first_step=True,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        hub_strategy="every_save",
        report_to="trackio",
        run_name=RUN_NAME,
    )

    # Log domain ranking metrics (NDCG/MRR) at every evaluation step
    domain_callback = DomainEvalCallback(model, eval_dataset_full)

    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        evaluator=evaluator,
        callbacks=[domain_callback],
    )

    logger.info("Starting training...")
    trainer.train()

    # Final ranking evaluation on the held-out eval set
    logger.info("Running final ranking evaluation...")
    final_metrics = evaluate_ranking(model, eval_dataset_full)
    for key, value in final_metrics.items():
        trackio.log({f"final/{key}": value})
    logger.info(f"Final metrics: {final_metrics}")

    # Per-content-type breakdown
    logger.info("Evaluating by content type...")
    type_metrics = evaluate_by_type(model, eval_dataset_full)
    for key, value in type_metrics.items():
        trackio.log({f"final/by_type/{key}": value})
    logger.info(f"Per-type metrics: {type_metrics}")

    # Improvement over the untrained base model
    trackio.log({
        "improvement/ndcg5_delta": final_metrics["ndcg@5"] - base_metrics["ndcg@5"],
        "improvement/mrr_delta": final_metrics["mrr"] - base_metrics["mrr"],
    })

    logger.info(f"Pushing final model to {HUB_MODEL_ID}")
    model.push_to_hub(HUB_MODEL_ID, exist_ok=True)
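
    # Example downstream usage (a sketch, not part of this script): the pushed model can
    # be reloaded with CrossEncoder(HUB_MODEL_ID) and candidate passages scored with
    # model.predict([(query, passage), ...]), then sorted by score descending.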

    trackio.finish()
    logger.info("Done!")


if __name__ == "__main__":
    main()