amanwithaplan committed on
Commit
3517d13
·
verified ·
1 Parent(s): ce8b06e

Upload train_reranker.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_reranker.py +127 -0
train_reranker.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.11"
3
+ # dependencies = [
4
+ # "sentence-transformers[train]>=4.0",
5
+ # "datasets",
6
+ # "torch>=2.4",
7
+ # "transformers>=4.48",
8
+ # "trackio",
9
+ # ]
10
+ # ///
11
+ """
12
+ Soft-Label Cross-Encoder Reranker Training
13
+
14
+ Trains a reranker using continuous relevance scores (soft labels).
15
+ Dataset format: {"query": "...", "text": "...", "score": 0.0-1.0}
16
+ """
17
+
18
+ import logging
19
+ import os
20
+ from collections import defaultdict
21
+ from datasets import load_dataset
22
+ from sentence_transformers.cross_encoder import (
23
+ CrossEncoder,
24
+ CrossEncoderTrainer,
25
+ CrossEncoderTrainingArguments,
26
+ )
27
+ from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
28
+
29
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Run settings. Every value can be overridden by an environment variable of
# the same name; the second argument is the fallback used when it is unset.
DATASET_NAME = os.getenv("DATASET_NAME", "amanwithaplan/arcade-reranker-data")
HUB_MODEL_ID = os.getenv("HUB_MODEL_ID", "amanwithaplan/arcade-reranker")
BASE_MODEL = os.getenv("BASE_MODEL", "Alibaba-NLP/gte-reranker-modernbert-base")
RUN_NAME = os.getenv("RUN_NAME", "reranker-03130903")

# Numeric settings arrive as strings from the environment and are parsed here,
# so a malformed value fails at import time rather than mid-training.
NUM_EPOCHS = int(os.getenv("NUM_EPOCHS", "5"))
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "16"))
LEARNING_RATE = float(os.getenv("LEARNING_RATE", "2e-5"))
MAX_SEQ_LENGTH = int(os.getenv("MAX_SEQ_LENGTH", "512"))
41
+
42
+
43
def main() -> None:
    """Fine-tune the cross-encoder reranker on soft (continuous) labels.

    Steps, all driven by the module-level configuration constants:

    1. Load ``BASE_MODEL`` as a ``CrossEncoder`` and the ``train`` split of
       ``DATASET_NAME``.
    2. Map the ``query`` / ``text`` / ``score`` columns to
       ``sentence1`` / ``sentence2`` / ``label`` and drop every other column.
    3. Hold out an evaluation split (15% of the data, capped at 400 rows).
    4. Train with ``CrossEncoderTrainer``, evaluating on the held-out split
       and on three NanoBEIR subsets, then push the model to ``HUB_MODEL_ID``.

    Raises:
        ValueError: propagated from ``train_test_split`` when the dataset is
            too small to be split into train and evaluation parts.
    """
    logger.info("Configuration:")
    logger.info(f" Dataset: {DATASET_NAME}")
    logger.info(f" Base model: {BASE_MODEL}")
    logger.info(f" Epochs: {NUM_EPOCHS}")
    logger.info(f" Run name: {RUN_NAME}")

    model = CrossEncoder(BASE_MODEL, max_length=MAX_SEQ_LENGTH)

    logger.info(f"Loading dataset: {DATASET_NAME}")
    dataset = load_dataset(DATASET_NAME, split="train")

    # Log dataset composition. Only the "type" column is read, so the other
    # fields of each row are never decoded just to be thrown away.
    if "type" in dataset.column_names:
        type_counts = defaultdict(int)
        for example_type in dataset["type"]:
            type_counts[example_type] += 1
        logger.info(f"Dataset composition: {dict(type_counts)}")

    logger.info(f"Total examples: {len(dataset)}")

    # Rename columns for CrossEncoderTrainer.
    dataset = dataset.rename_columns({
        "query": "sentence1",
        "text": "sentence2",
        "score": "label",
    })
    # The trainer treats every non-label column as a model input and relies on
    # column order to tell the query from the document. Keep exactly the three
    # columns it needs, in (query, document, label) order, so auxiliary
    # columns such as "type" never reach the loss.
    dataset = dataset.select_columns(["sentence1", "sentence2", "label"])

    # Split for evaluation: 15% of the data, capped at 400 rows. Clamp to at
    # least one row because train_test_split rejects test_size=0, which the
    # 15% rule would otherwise produce for datasets with fewer than 7 rows.
    eval_size = max(1, min(400, int(len(dataset) * 0.15)))
    splits = dataset.train_test_split(test_size=eval_size, seed=42)
    train_dataset = splits["train"]
    eval_dataset = splits["test"]

    logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

    # NanoBEIR for benchmark comparison
    evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=BATCH_SIZE,
    )

    args = CrossEncoderTrainingArguments(
        output_dir="models/reranker",
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        warmup_ratio=0.1,
        bf16=True,
        # Evaluate and checkpoint on the same step schedule.
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        logging_steps=25,
        logging_first_step=True,
        # Restore the checkpoint with the lowest held-out loss when done.
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        hub_strategy="every_save",
        report_to="trackio",
        run_name=RUN_NAME,
    )

    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        evaluator=evaluator,
    )

    logger.info("Starting training...")
    trainer.train()

    logger.info(f"Pushing final model to {HUB_MODEL_ID}")
    model.push_to_hub(HUB_MODEL_ID)
    logger.info("Done!")


if __name__ == "__main__":
    main()