#!/usr/bin/env python3
"""Sample training script for ablation: compare CachedMultipleNegativesRankingLoss
vs CachedMultipleNegativesBidirectionalRankingLoss (aka GTE loss with GradCache).

Single-file Sentence-Transformers example with no project-local imports; the
heavy third-party imports live inside ``main`` so the module itself is cheap
to import.
"""

from __future__ import annotations

import argparse
import logging
import os
import time
from pathlib import Path
from typing import cast


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (the previous hard-wired behavior, kept as the default).

    Returns:
        Parsed namespace with all training/run options.
    """
    parser = argparse.ArgumentParser(description="Single-file ST loss training example (no src imports).")
    parser.add_argument(
        "--model_name",
        default="answerdotai/ModernBERT-base",
        help="Sentence-Transformers model name or path.",
    )
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument(
        "--max_train_examples",
        type=int,
        default=-1,
        help="Limit training examples (use -1 for full dataset).",
    )
    parser.add_argument("--seed", type=int, default=12)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=8192)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=512)
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
    )
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=100)
    parser.add_argument("--save_total_limit", type=int, default=2)
    parser.add_argument("--lr_scheduler_type", default="cosine")
    parser.add_argument("--optim", default="adamw_torch")
    parser.add_argument("--loss_mini_batch_size", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--gather_across_devices", action="store_true")
    # NOTE(review): with action="store_true" and default=True these flags
    # cannot be switched off from the CLI; kept as-is for interface
    # compatibility.
    parser.add_argument("--bf16", action="store_true", default=True)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--dataloader_num_workers", type=int, default=12)
    parser.add_argument("--dataloader_prefetch_factor", type=int, default=2)
    parser.add_argument("--dataloader_persistent_workers", action="store_true", default=False)
    parser.add_argument("--no_drop_last", action="store_true", help="Disable drop_last (default: True)")
    parser.add_argument(
        "--batch_sampler",
        choices=["batch_sampler", "no_duplicates"],
        default="no_duplicates",
        help="Batch sampler type for SentenceTransformers.",
    )
    parser.add_argument(
        "--loss_type",
        choices=["CMNRL", "CMNBRL"],
        default="CMNBRL",
        help="Loss type: CMNRL (CachedMultipleNegativesRankingLoss) or "
        "CMNBRL (aka GTE with GradCache).",
    )
    parser.add_argument(
        "--output_root",
        default="output/models/examples",
        help="Root directory for outputs.",
    )
    parser.add_argument("--run_name", default=None)
    parser.add_argument("--no_shuffle", action="store_true")
    parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (debug).")
    parser.add_argument("--resume_from_checkpoint", default=None, help="Resume training from checkpoint.")
    return parser.parse_args(argv)


def build_output_dir(output_root: Path, run_name: str) -> Path:
    """Return ``output_root/run_name/<timestamp>`` (directory is not created here)."""
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    return output_root / run_name / timestamp


def main() -> None:
    """Run the full train -> NanoBEIR-eval -> save pipeline."""
    args = parse_args()
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Heavy imports deferred so `--help` and module import stay fast.
    import torch
    from datasets import Dataset, DatasetDict, load_dataset
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
        losses,
    )
    from sentence_transformers.evaluation import NanoBEIREvaluator

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger("train_st_loss_example")

    if args.bf16 and (not torch.cuda.is_available() or not torch.cuda.is_bf16_supported()):
        # Fix: the old message claimed "falling back to fp16=false", but the
        # code only disables bf16 (fp16 is left untouched).
        logger.warning("bf16 requested but not supported on this device; disabling bf16.")
        args.bf16 = False

    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    max_train_tag = "full" if args.max_train_examples < 0 else f"{args.max_train_examples}"
    data_tag = "pair"
    if args.run_name is None:
        model_tag = args.model_name.rstrip("/").split("/")[-1]
        temp_tag = "tdefault" if args.temperature is None else f"t{args.temperature}".replace(".", "p")
        args.run_name = (
            f"{model_tag}_{args.loss_type}_{args.batch_sampler}_{temp_tag}_{data_tag}"
            f"_bs{args.per_device_train_batch_size}_{max_train_tag}"
        )

    output_dir = build_output_dir(output_root, args.run_name)
    output_dir.mkdir(parents=True, exist_ok=True)
    final_dir = output_dir / "final"

    logger.info("Loading model: %s", args.model_name)
    model = SentenceTransformer(args.model_name)
    model.max_seq_length = args.max_seq_length

    def _load_pair_dataset(dataset_id: str, config: str | None, rename_map: dict[str, str]) -> Dataset:
        """Load a train split and normalize it to ("query", "positive") columns."""
        ds = load_dataset(dataset_id, config, split="train") if config else load_dataset(dataset_id, split="train")
        ds = cast(Dataset, ds)
        if rename_map:
            column_names = ds.column_names or []
            # Fix: skip identity mappings (e.g. {"query": "query"}) —
            # Dataset.rename_columns rejects a rename whose target name
            # already exists in the dataset.
            existing = {k: v for k, v in rename_map.items() if k in column_names and k != v}
            if existing:
                ds = ds.rename_columns(existing)
        ds = ds.select_columns(["query", "positive"])
        return ds

    logger.info("Loading datasets (pair only)...")
    train_datasets = DatasetDict(
        {
            "msmarco": _load_pair_dataset(
                "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
                "triplet",
                {"query": "query", "positive": "positive"},
            ),
            "natural_questions": _load_pair_dataset(
                "sentence-transformers/natural-questions",
                "pair",
                {"answer": "positive"},
            ),
            "gooaq": _load_pair_dataset(
                "sentence-transformers/gooaq",
                "pair",
                {"question": "query", "answer": "positive"},
            ),
            "ccnews": _load_pair_dataset(
                "sentence-transformers/ccnews",
                "pair",
                {"title": "query", "article": "positive"},
            ),
            "hotpotqa": _load_pair_dataset(
                "sentence-transformers/hotpotqa",
                "triplet",
                {"anchor": "query", "positive": "positive"},
            ),
        }
    )

    # Iterate over a snapshot: we reassign entries of the mapping inside the
    # loop, which is unsafe while iterating the live view.
    for name, ds in list(train_datasets.items()):
        if not args.no_shuffle:
            ds = ds.shuffle(seed=args.seed)
        if args.max_train_examples > 0:
            ds = ds.select(range(min(args.max_train_examples, len(ds))))
        train_datasets[name] = ds
        logger.info("Train examples [%s]: %d", name, len(ds))

    # CMNBRL exposes `temperature` directly; CMNRL uses `scale` = 1/temperature.
    loss_kwargs: dict[str, object] = {}
    if args.temperature is not None:
        if args.loss_type == "CMNBRL":
            loss_kwargs["temperature"] = args.temperature
        else:
            loss_kwargs["scale"] = 1.0 / args.temperature
    if args.loss_mini_batch_size is not None:
        loss_kwargs["mini_batch_size"] = args.loss_mini_batch_size
    if args.gather_across_devices:
        loss_kwargs["gather_across_devices"] = True

    if args.loss_type == "CMNBRL":
        loss = losses.CachedMultipleNegativesBidirectionalRankingLoss(model=model, **loss_kwargs)
    else:
        loss = losses.CachedMultipleNegativesRankingLoss(model=model, **loss_kwargs)

    training_args = SentenceTransformerTrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_strategy="steps",
        save_total_limit=args.save_total_limit,
        lr_scheduler_type=args.lr_scheduler_type,
        optim=args.optim,
        bf16=args.bf16,
        fp16=args.fp16,
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        dataloader_persistent_workers=args.dataloader_persistent_workers,
        dataloader_drop_last=not args.no_drop_last,
        seed=args.seed,
        max_steps=args.max_steps,
        eval_strategy="no",
        report_to=["wandb"],
        remove_unused_columns=False,
        batch_sampler=args.batch_sampler,
        disable_tqdm=False,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_datasets,
        loss=loss,
    )
    logger.info("Training start. Output: %s", output_dir)
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    evaluator = NanoBEIREvaluator(
        ndcg_at_k=[10],
        mrr_at_k=[10],
        accuracy_at_k=[10],
        precision_recall_at_k=[10],
        map_at_k=[10],
        batch_size=args.per_device_eval_batch_size,
        show_progress_bar=False,
        write_csv=False,
    )
    results = evaluator(
        model,
        output_path=str(output_dir / "eval"),
        epoch=0,
        steps=trainer.state.global_step,
    )
    ndcg_key = evaluator.primary_metric
    print(f"NDCG@10: {results[ndcg_key]:.6f} ({ndcg_key})")

    final_dir.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(final_dir))
    model.save(str(final_dir), create_model_card=True)
    logger.info("Saved model to: %s", final_dir)


if __name__ == "__main__":
    main()