Commit dfe6881 (verified) by hotchpotch · Parent: f44b248

Upload train_st_loss_example.py

Files changed (1):
  1. train_st_loss_example.py +262 -0

train_st_loss_example.py ADDED
@@ -0,0 +1,262 @@
#!/usr/bin/env python3
# Sample training script for an ablation: compare CachedMultipleNegativesRankingLoss
# vs CachedMultipleNegativesBidirectionalRankingLoss (aka the GTE loss with GradCache).
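#
# Example invocations (illustrative; see parse_args() below for all flags):
#   python train_st_loss_example.py --loss_type CMNRL --temperature 0.02
#   python train_st_loss_example.py --loss_type CMNBRL --temperature 0.02 --max_train_examples 100000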
from __future__ import annotations

import argparse
import logging
import os
import time
from pathlib import Path
from typing import cast


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Single-file ST loss training example (no src imports).")
    parser.add_argument(
        "--model_name",
        default="answerdotai/ModernBERT-base",
        help="Sentence-Transformers model name or path.",
    )
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument(
        "--max_train_examples",
        type=int,
        default=-1,
        help="Limit training examples (use -1 for full dataset).",
    )
    parser.add_argument("--seed", type=int, default=12)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=8192)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=512)
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
    )
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=100)
    parser.add_argument("--save_total_limit", type=int, default=2)
    parser.add_argument("--lr_scheduler_type", default="cosine")
    parser.add_argument("--optim", default="adamw_torch")
    parser.add_argument("--loss_mini_batch_size", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--gather_across_devices", action="store_true")
    parser.add_argument("--bf16", action="store_true", default=True)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--dataloader_num_workers", type=int, default=12)
    parser.add_argument("--dataloader_prefetch_factor", type=int, default=2)
    parser.add_argument("--dataloader_persistent_workers", action="store_true", default=False)
    parser.add_argument("--no_drop_last", action="store_true", help="Disable drop_last (default: True)")
    parser.add_argument(
        "--batch_sampler",
        choices=["batch_sampler", "no_duplicates"],
        default="no_duplicates",
        help="Batch sampler type for SentenceTransformers.",
    )
    parser.add_argument(
        "--loss_type",
        choices=["CMNRL", "CMNBRL"],
        default="CMNBRL",
        help="Loss type: CMNRL (CachedMultipleNegativesRankingLoss) or "
        "CMNBRL (aka GTE with GradCache).",
    )
    parser.add_argument(
        "--output_root",
        default="output/models/examples",
        help="Root directory for outputs.",
    )
    parser.add_argument("--run_name", default=None)
    parser.add_argument("--no_shuffle", action="store_true")
    parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (debug).")
    parser.add_argument("--resume_from_checkpoint", default=None, help="Resume training from checkpoint.")
    return parser.parse_args()


def build_output_dir(output_root: Path, run_name: str) -> Path:
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    return output_root / run_name / timestamp


def main() -> None:
    args = parse_args()

    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    import torch
    from datasets import Dataset, DatasetDict, load_dataset
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
        losses,
    )
    from sentence_transformers.evaluation import NanoBEIREvaluator

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger("train_st_loss_example")

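    # Guard against requesting bf16 on hardware that cannot execute it
    # (e.g. CPU or pre-Ampere GPUs).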
    if args.bf16 and (not torch.cuda.is_available() or not torch.cuda.is_bf16_supported()):
        logger.warning("bf16 requested but not supported on this device; disabling bf16.")
        args.bf16 = False

    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    max_train_tag = "full" if args.max_train_examples < 0 else f"{args.max_train_examples}"
    data_tag = "pair"
    if args.run_name is None:
        model_tag = args.model_name.rstrip("/").split("/")[-1]
        temp_tag = "tdefault" if args.temperature is None else f"t{args.temperature}".replace(".", "p")
        args.run_name = (
            f"{model_tag}_{args.loss_type}_{args.batch_sampler}_{temp_tag}_{data_tag}"
            f"_bs{args.per_device_train_batch_size}_{max_train_tag}"
        )
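    # Auto-generated run name, e.g. with --temperature 0.02 and the defaults:
    # "ModernBERT-base_CMNBRL_no_duplicates_t0p02_pair_bs8192_full"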
    output_dir = build_output_dir(output_root, args.run_name)
    output_dir.mkdir(parents=True, exist_ok=True)
    final_dir = output_dir / "final"

    logger.info("Loading model: %s", args.model_name)
    model = SentenceTransformer(args.model_name)
    model.max_seq_length = args.max_seq_length

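    # Normalize each source dataset to exactly two columns, "query" and
    # "positive", so every corpus feeds the same (anchor, positive) loss format.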
    def _load_pair_dataset(dataset_id: str, config: str | None, rename_map: dict[str, str]) -> Dataset:
        ds = load_dataset(dataset_id, config, split="train") if config else load_dataset(dataset_id, split="train")
        ds = cast(Dataset, ds)
        if rename_map:
            column_names = ds.column_names or []
            existing = {k: v for k, v in rename_map.items() if k in column_names}
            if existing:
                ds = ds.rename_columns(existing)
        ds = ds.select_columns(["query", "positive"])
        return ds

    logger.info("Loading datasets (pair only)...")
    train_datasets = DatasetDict(
        {
            "msmarco": _load_pair_dataset(
                "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
                "triplet",
                {"query": "query", "positive": "positive"},
            ),
            "natural_questions": _load_pair_dataset(
                "sentence-transformers/natural-questions",
                "pair",
                {"answer": "positive"},
            ),
            "gooaq": _load_pair_dataset(
                "sentence-transformers/gooaq",
                "pair",
                {"question": "query", "answer": "positive"},
            ),
            "ccnews": _load_pair_dataset(
                "sentence-transformers/ccnews",
                "pair",
                {"title": "query", "article": "positive"},
            ),
            "hotpotqa": _load_pair_dataset(
                "sentence-transformers/hotpotqa",
                "triplet",
                {"anchor": "query", "positive": "positive"},
            ),
        }
    )

    for name, ds in train_datasets.items():
        if not args.no_shuffle:
            ds = ds.shuffle(seed=args.seed)
        if args.max_train_examples > 0:
            ds = ds.select(range(min(args.max_train_examples, len(ds))))
        train_datasets[name] = ds
        logger.info("Train examples [%s]: %d", name, len(ds))

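    # The two losses parameterize sharpness differently: CMNRL multiplies
    # similarities by `scale`, while the bidirectional loss here divides by
    # `temperature`, so scale = 1 / temperature keeps the ablation comparable.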
    loss_kwargs = {}
    if args.temperature is not None:
        if args.loss_type == "CMNBRL":
            loss_kwargs["temperature"] = args.temperature
        else:
            loss_kwargs["scale"] = 1.0 / args.temperature
    if args.loss_mini_batch_size is not None:
        loss_kwargs["mini_batch_size"] = args.loss_mini_batch_size
    if args.gather_across_devices:
        loss_kwargs["gather_across_devices"] = True

    if args.loss_type == "CMNBRL":
        loss = losses.CachedMultipleNegativesBidirectionalRankingLoss(model=model, **loss_kwargs)
    else:
        loss = losses.CachedMultipleNegativesRankingLoss(model=model, **loss_kwargs)

    training_args = SentenceTransformerTrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_strategy="steps",
        save_total_limit=args.save_total_limit,
        lr_scheduler_type=args.lr_scheduler_type,
        optim=args.optim,
        bf16=args.bf16,
        fp16=args.fp16,
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        dataloader_persistent_workers=args.dataloader_persistent_workers,
        dataloader_drop_last=not args.no_drop_last,
        seed=args.seed,
        max_steps=args.max_steps,
        eval_strategy="no",
        report_to=["wandb"],
        remove_unused_columns=False,
        batch_sampler=args.batch_sampler,
        disable_tqdm=False,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_datasets,
        loss=loss,
    )

    logger.info("Training start. Output: %s", output_dir)
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

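    # NanoBEIR runs retrieval evaluation on small subsets of the BEIR datasets;
    # the evaluator's primary metric here is the NDCG@10 aggregated across them.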
    evaluator = NanoBEIREvaluator(
        ndcg_at_k=[10],
        mrr_at_k=[10],
        accuracy_at_k=[10],
        precision_recall_at_k=[10],
        map_at_k=[10],
        batch_size=args.per_device_eval_batch_size,
        show_progress_bar=False,
        write_csv=False,
    )
    results = evaluator(
        model,
        output_path=str(output_dir / "eval"),
        epoch=0,
        steps=trainer.state.global_step,
    )
    ndcg_key = evaluator.primary_metric
    print(f"NDCG@10: {results[ndcg_key]:.6f} ({ndcg_key})")

    final_dir.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(final_dir))
    model.save(str(final_dir), create_model_card=True)
    logger.info("Saved model to: %s", final_dir)


if __name__ == "__main__":
    main()
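
# After a run finishes, the saved encoder can be loaded for inference. The path
# below is illustrative; the real directory depends on --output_root, the run
# name, and the timestamp:
#   from sentence_transformers import SentenceTransformer
#   model = SentenceTransformer("output/models/examples/<run_name>/<timestamp>/final")
#   embeddings = model.encode(["example query"])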