arcade-training-scripts / train_reranker.py
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "sentence-transformers[train]>=4.0",
# "datasets",
# "torch>=2.4",
# "transformers>=4.48",
# "trackio",
# "scipy",
# "numpy",
# ]
# ///
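# The block above is PEP 723 inline script metadata; a compatible runner
# (e.g. `uv run train_reranker.py`) resolves these dependencies automatically.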
"""
Soft-Label Cross-Encoder Reranker Training
Trains a reranker using continuous relevance scores (soft labels).
Dataset format: {"query": "...", "text": "...", "score": 0.0-1.0}
"""
import logging
import os
import math
from collections import defaultdict
import trackio
import numpy as np
from datasets import load_dataset
from sentence_transformers.cross_encoder import (
CrossEncoder,
CrossEncoderTrainer,
CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
from scipy.stats import spearmanr
from transformers import TrainerCallback
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration
DATASET_NAME = os.environ.get("DATASET_NAME", "amanwithaplan/arcade-reranker-data")
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "amanwithaplan/arcade-reranker")
BASE_MODEL = os.environ.get("BASE_MODEL", "Alibaba-NLP/gte-reranker-modernbert-base")
NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "10"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "16"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-5"))
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "512"))
RUN_NAME = os.environ.get("RUN_NAME", "reranker-03130903")
SPACE_ID = os.environ.get("TRACKIO_SPACE_ID", "amanwithaplan/trackio")
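# All of the above can be overridden via environment variables, e.g.:
#   BATCH_SIZE=32 NUM_EPOCHS=5 RUN_NAME=my-run uv run train_reranker.py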
def dcg_at_k(relevances, k):
"""Compute DCG@k."""
relevances = np.array(relevances)[:k]
if len(relevances) == 0:
return 0.0
# DCG = sum of rel_i / log2(i+2) for i in 0..k-1
discounts = np.log2(np.arange(len(relevances)) + 2)
return np.sum(relevances / discounts)
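# Worked example: dcg_at_k([3, 2, 3], k=3)
#   = 3/log2(2) + 2/log2(3) + 3/log2(4) ≈ 3.0 + 1.262 + 1.5 = 5.762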
def ndcg_at_k(predicted_order, true_relevances, k):
"""
Compute NDCG@k.
predicted_order: indices of docs sorted by model score (descending)
true_relevances: ground truth relevance scores for each doc
"""
# Get relevances in predicted order
predicted_relevances = [true_relevances[i] for i in predicted_order]
# Ideal order: sort by true relevance descending
ideal_relevances = sorted(true_relevances, reverse=True)
dcg = dcg_at_k(predicted_relevances, k)
idcg = dcg_at_k(ideal_relevances, k)
if idcg == 0:
return 0.0
return dcg / idcg
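# Worked example: predicted_order=[1, 0, 2], true_relevances=[1.0, 0.5, 0.0]
#   DCG  = 0.5/log2(2) + 1.0/log2(3) + 0.0/log2(4) ≈ 1.131
#   IDCG = 1.0/log2(2) + 0.5/log2(3) + 0.0/log2(4) ≈ 1.315
#   NDCG@3 ≈ 1.131 / 1.315 ≈ 0.86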
def mrr(predicted_order, true_relevances, threshold=0.5):
"""
Compute MRR (Mean Reciprocal Rank).
Returns 1/rank of first relevant doc (relevance > threshold).
"""
for rank, idx in enumerate(predicted_order, start=1):
if true_relevances[idx] > threshold:
return 1.0 / rank
return 0.0
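# Example: with threshold=0.5, if the first doc with true relevance above 0.5
# appears at rank 2 of the predicted order, the query contributes 1/2 = 0.5.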
def evaluate_ranking(model, eval_dataset):
"""
Proper ranking evaluation: group by query, compute NDCG and MRR.
This measures what we actually care about:
"Given a query with multiple docs, does the model rank them correctly?"
"""
# Group samples by query
query_groups = defaultdict(list)
for item in eval_dataset:
query_groups[item["sentence1"]].append({
"text": item["sentence2"],
"label": item["label"]
})
# Filter to queries with multiple docs (need at least 2 to rank)
query_groups = {q: docs for q, docs in query_groups.items() if len(docs) >= 2}
    if not query_groups:
        # No rankable queries: return zeroed metrics with every key callers expect
        return {"ndcg@3": 0.0, "ndcg@5": 0.0, "mrr": 0.0, "rank_corr": 0.0, "n_queries": 0}
ndcg_3_scores = []
ndcg_5_scores = []
mrr_scores = []
rank_correlations = []
for query, docs in query_groups.items():
# Get model predictions for this query's docs
pairs = [(query, d["text"]) for d in docs]
predictions = model.predict(pairs, show_progress_bar=False)
true_relevances = [d["label"] for d in docs]
# Get predicted order: indices sorted by prediction descending
predicted_order = np.argsort(predictions)[::-1].tolist()
# Compute metrics
ndcg_3_scores.append(ndcg_at_k(predicted_order, true_relevances, k=3))
ndcg_5_scores.append(ndcg_at_k(predicted_order, true_relevances, k=5))
mrr_scores.append(mrr(predicted_order, true_relevances, threshold=0.5))
# Rank correlation within this query
if len(set(true_relevances)) > 1: # Need variance
corr = spearmanr(predictions, true_relevances).correlation
if not math.isnan(corr):
rank_correlations.append(corr)
return {
"ndcg@3": np.mean(ndcg_3_scores),
"ndcg@5": np.mean(ndcg_5_scores),
"mrr": np.mean(mrr_scores),
"rank_corr": np.mean(rank_correlations) if rank_correlations else 0.0,
"n_queries": len(query_groups),
}
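# Note: rank_corr is the mean per-query Spearman correlation between model
# scores and the soft labels; unlike NDCG@k and MRR, it rewards getting the
# full ordering right, not just the top of the ranking.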
class DomainEvalCallback(TrainerCallback):
"""Callback to log proper ranking metrics during training."""
def __init__(self, model, eval_dataset_full):
self.model = model
self.eval_dataset_full = eval_dataset_full
def on_evaluate(self, args, state, control, **kwargs):
"""Run after each evaluation step."""
metrics = evaluate_ranking(self.model, self.eval_dataset_full)
# Log to trackio
trackio.log({
"domain/ndcg@3": metrics["ndcg@3"],
"domain/ndcg@5": metrics["ndcg@5"],
"domain/mrr": metrics["mrr"],
"domain/rank_corr": metrics["rank_corr"],
})
logger.info(
f"Domain eval - NDCG@3: {metrics['ndcg@3']:.4f}, "
f"NDCG@5: {metrics['ndcg@5']:.4f}, "
f"MRR: {metrics['mrr']:.4f}, "
f"RankCorr: {metrics['rank_corr']:.4f} "
f"(n={metrics['n_queries']} queries)"
)
def evaluate_by_type(model, eval_dataset, type_column="type"):
"""Evaluate ranking metrics per content type."""
if type_column not in eval_dataset.column_names:
return {}
# Group by type first
by_type = defaultdict(list)
for item in eval_dataset:
by_type[item[type_column]].append(item)
    # Minimal wrapper exposing the iteration interface evaluate_ranking expects
    class TypeDataset:
        def __init__(self, items):
            self.items = items

        def __iter__(self):
            return iter(self.items)

        @property
        def column_names(self):
            return ["sentence1", "sentence2", "label"]

    results = {}
    for content_type, items in by_type.items():
        type_metrics = evaluate_ranking(model, TypeDataset(items))
if type_metrics["n_queries"] >= 2:
results[f"{content_type}_ndcg@5"] = type_metrics["ndcg@5"]
results[f"{content_type}_mrr"] = type_metrics["mrr"]
results[f"{content_type}_n_queries"] = type_metrics["n_queries"]
return results
def main():
# Initialize trackio with full config
trackio.init(
project="arcade-reranker",
name=RUN_NAME,
space_id=SPACE_ID,
config={
"model": BASE_MODEL,
"dataset": DATASET_NAME,
"learning_rate": LEARNING_RATE,
"num_epochs": NUM_EPOCHS,
"batch_size": BATCH_SIZE,
"max_seq_length": MAX_SEQ_LENGTH,
}
)
logger.info(f"Configuration:")
logger.info(f" Dataset: {DATASET_NAME}")
logger.info(f" Base model: {BASE_MODEL}")
logger.info(f" Epochs: {NUM_EPOCHS}")
logger.info(f" Run name: {RUN_NAME}")
logger.info(f" Trackio space: {SPACE_ID}")
model = CrossEncoder(BASE_MODEL, max_length=MAX_SEQ_LENGTH)
logger.info(f"Loading dataset: {DATASET_NAME}")
dataset = load_dataset(DATASET_NAME, split="train")
# Log dataset composition
type_counts = defaultdict(int)
if "type" in dataset.column_names:
for item in dataset:
type_counts[item["type"]] += 1
logger.info(f"Dataset composition: {dict(type_counts)}")
# Log to trackio
for content_type, count in type_counts.items():
trackio.log({f"data/{content_type}_count": count})
trackio.log({"data/total_examples": len(dataset)})
logger.info(f"Total examples: {len(dataset)}")
# Rename columns for CrossEncoderTrainer
dataset = dataset.rename_columns({
"query": "sentence1",
"text": "sentence2",
"score": "label"
})
# Split for evaluation (before removing extra columns so we keep type for eval)
eval_size = min(400, int(len(dataset) * 0.15))
splits = dataset.train_test_split(test_size=eval_size, seed=42)
# Keep full eval dataset with type column for per-type evaluation
eval_dataset_full = splits["test"]
# Remove extra columns for training (CrossEncoderTrainer only wants sentence1, sentence2, label)
train_dataset = splits["train"].select_columns(["sentence1", "sentence2", "label"])
eval_dataset = splits["test"].select_columns(["sentence1", "sentence2", "label"])
trackio.log({
"data/train_size": len(train_dataset),
"data/eval_size": len(eval_dataset),
})
logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
# Evaluate base model before training with proper ranking metrics
logger.info("Evaluating base model on eval set...")
base_metrics = evaluate_ranking(model, eval_dataset_full)
for key, value in base_metrics.items():
trackio.log({f"base_model/{key}": value})
logger.info(f"Base model metrics: {base_metrics}")
# NanoBEIR for benchmark comparison
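    # (Each NanoBEIR dataset is a small ~50-query subset of a BEIR benchmark,
    # so this tracks general-retrieval quality cheaply alongside domain metrics.)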
evaluator = CrossEncoderNanoBEIREvaluator(
dataset_names=["msmarco", "nfcorpus", "nq"],
batch_size=BATCH_SIZE,
)
args = CrossEncoderTrainingArguments(
output_dir="models/reranker",
num_train_epochs=NUM_EPOCHS,
per_device_train_batch_size=BATCH_SIZE,
per_device_eval_batch_size=BATCH_SIZE,
learning_rate=LEARNING_RATE,
warmup_ratio=0.1,
bf16=True,
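        # Eval and save every 25 steps for detailed training curves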
eval_strategy="steps",
eval_steps=25,
save_strategy="steps",
save_steps=25,
save_total_limit=5,
logging_steps=25,
logging_first_step=True,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
push_to_hub=True,
hub_model_id=HUB_MODEL_ID,
hub_strategy="every_save",
report_to="trackio",
run_name=RUN_NAME,
)
# Custom callback to log domain-specific ranking metrics during training
domain_callback = DomainEvalCallback(model, eval_dataset_full)
trainer = CrossEncoderTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
evaluator=evaluator,
callbacks=[domain_callback],
)
logger.info("Starting training...")
trainer.train()
# Final evaluation with proper ranking metrics
logger.info("Running final ranking evaluation...")
final_metrics = evaluate_ranking(model, eval_dataset_full)
for key, value in final_metrics.items():
trackio.log({f"final/{key}": value})
logger.info(f"Final metrics: {final_metrics}")
# Per-type evaluation
logger.info("Evaluating by content type...")
type_metrics = evaluate_by_type(model, eval_dataset_full)
for key, value in type_metrics.items():
trackio.log({f"final/by_type/{key}": value})
logger.info(f"Per-type metrics: {type_metrics}")
# Log improvement
trackio.log({
"improvement/ndcg5_delta": final_metrics["ndcg@5"] - base_metrics["ndcg@5"],
"improvement/mrr_delta": final_metrics["mrr"] - base_metrics["mrr"],
})
logger.info(f"Pushing final model to {HUB_MODEL_ID}")
model.push_to_hub(HUB_MODEL_ID, exist_ok=True)
trackio.finish()
logger.info("Done!")
if __name__ == "__main__":
main()