# ModernBERT-embedding-CMNBRL / train_st_loss_example.py
# Uploaded by hotchpotch (revision dfe6881, verified)
#!/usr/bin/env python3
# Sample training script for ablation: compare CachedMultipleNegativesRankingLoss
# vs CachedMultipleNegativesBidirectionalRankingLoss (aka GTE loss with GradCache).
from __future__ import annotations
import argparse
import logging
import os
import time
from pathlib import Path
from typing import cast
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI parser for this single-file training example.

    Returns the parsed ``argparse.Namespace``; defaults match the ablation
    setup (CMNBRL loss, no-duplicates batch sampler, bf16 enabled).
    """
    ap = argparse.ArgumentParser(description="Single-file ST loss training example (no src imports).")
    add = ap.add_argument  # shorthand: every option goes through the same call

    # Model and dataset sizing.
    add("--model_name", default="answerdotai/ModernBERT-base", help="Sentence-Transformers model name or path.")
    add("--max_seq_length", type=int, default=512)
    add("--max_train_examples", type=int, default=-1, help="Limit training examples (use -1 for full dataset).")

    # Core optimization hyper-parameters.
    add("--seed", type=int, default=12)
    add("--num_train_epochs", type=int, default=1)
    add("--per_device_train_batch_size", type=int, default=8192)
    add("--per_device_eval_batch_size", type=int, default=512)
    add("--learning_rate", type=float, default=1e-4)
    add("--warmup_ratio", type=float, default=0.1)
    add("--weight_decay", type=float, default=0.01)
    add("--gradient_accumulation_steps", type=int, default=1)
    add("--logging_steps", type=int, default=10)
    add("--save_steps", type=int, default=100)
    add("--save_total_limit", type=int, default=2)
    add("--lr_scheduler_type", default="cosine")
    add("--optim", default="adamw_torch")

    # Loss configuration (mini-batch size is the GradCache chunk size).
    add("--loss_mini_batch_size", type=int, default=128)
    add("--temperature", type=float, default=None)
    add("--gather_across_devices", action="store_true")

    # Precision and dataloader behavior.
    add("--bf16", action="store_true", default=True)
    add("--fp16", action="store_true", default=False)
    add("--dataloader_num_workers", type=int, default=12)
    add("--dataloader_prefetch_factor", type=int, default=2)
    add("--dataloader_persistent_workers", action="store_true", default=False)
    add("--no_drop_last", action="store_true", help="Disable drop_last (default: True)")
    add(
        "--batch_sampler",
        choices=["batch_sampler", "no_duplicates"],
        default="no_duplicates",
        help="Batch sampler type for SentenceTransformers.",
    )
    add(
        "--loss_type",
        choices=["CMNRL", "CMNBRL"],
        default="CMNBRL",
        help="Loss type: CMNRL (CachedMultipleNegativesRankingLoss) or "
        "CMNBRL (aka GTE with GradCache).",
    )

    # Output locations and run control.
    add(
        "--output_root",
        default="output/models/examples",
        help="Root directory for outputs.",
    )
    add("--run_name", default=None)
    add("--no_shuffle", action="store_true")
    add("--max_steps", type=int, default=-1, help="Max training steps (debug).")
    add("--resume_from_checkpoint", default=None, help="Resume training from checkpoint.")
    return ap.parse_args()
def build_output_dir(output_root: Path, run_name: str) -> Path:
    """Return ``output_root/run_name/<timestamp>`` for a fresh run directory."""
    stamp = time.strftime("%Y%m%d_%H%M%S")
    return Path(output_root, run_name, stamp)
def main() -> None:
    """Train a SentenceTransformer on query/positive pair datasets and evaluate it.

    Ablation entry point: trains with either CachedMultipleNegativesRankingLoss
    (CMNRL) or CachedMultipleNegativesBidirectionalRankingLoss (CMNBRL) per
    ``--loss_type``, evaluates on NanoBEIR at k=10, and saves the final model.
    """
    args = parse_args()
    # Silence tokenizers fork-parallelism warnings in DataLoader workers.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Heavy third-party imports are deferred so `--help` stays fast.
    import torch
    from datasets import Dataset, DatasetDict, load_dataset
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
        losses,
    )
    from sentence_transformers.evaluation import NanoBEIREvaluator

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger("train_st_loss_example")

    # bf16 requires CUDA hardware with bf16 support; otherwise disable it.
    # (Bug fix: the old message claimed "falling back to fp16=false", but the
    # code disables bf16 and leaves fp16 untouched.)
    if args.bf16 and (not torch.cuda.is_available() or not torch.cuda.is_bf16_supported()):
        logger.warning("bf16 requested but not supported on this device; disabling bf16.")
        args.bf16 = False

    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    # Tags used to compose a descriptive default run name.
    max_train_tag = "full" if args.max_train_examples < 0 else f"{args.max_train_examples}"
    data_tag = "pair"
    if args.run_name is None:
        model_tag = args.model_name.rstrip("/").split("/")[-1]
        # "." is not filesystem-friendly in a directory name; use "p" instead.
        temp_tag = "tdefault" if args.temperature is None else f"t{args.temperature}".replace(".", "p")
        args.run_name = (
            f"{model_tag}_{args.loss_type}_{args.batch_sampler}_{temp_tag}_{data_tag}"
            f"_bs{args.per_device_train_batch_size}_{max_train_tag}"
        )
    output_dir = build_output_dir(output_root, args.run_name)
    output_dir.mkdir(parents=True, exist_ok=True)
    final_dir = output_dir / "final"

    logger.info("Loading model: %s", args.model_name)
    model = SentenceTransformer(args.model_name)
    model.max_seq_length = args.max_seq_length

    def _load_pair_dataset(dataset_id: str, config: str | None, rename_map: dict[str, str]) -> Dataset:
        """Load a train split and normalize its columns to ("query", "positive")."""
        ds = load_dataset(dataset_id, config, split="train") if config else load_dataset(dataset_id, split="train")
        ds = cast(Dataset, ds)
        if rename_map:
            column_names = ds.column_names or []
            # Rename only columns that actually exist in this dataset.
            existing = {k: v for k, v in rename_map.items() if k in column_names}
            if existing:
                ds = ds.rename_columns(existing)
        # Drop any extra columns (e.g. "negative" in triplet configs).
        ds = ds.select_columns(["query", "positive"])
        return ds

    logger.info("Loading datasets (pair only)...")
    train_datasets = DatasetDict(
        {
            "msmarco": _load_pair_dataset(
                "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
                "triplet",
                {"query": "query", "positive": "positive"},
            ),
            "natural_questions": _load_pair_dataset(
                "sentence-transformers/natural-questions",
                "pair",
                {"answer": "positive"},
            ),
            "gooaq": _load_pair_dataset(
                "sentence-transformers/gooaq",
                "pair",
                {"question": "query", "answer": "positive"},
            ),
            "ccnews": _load_pair_dataset(
                "sentence-transformers/ccnews",
                "pair",
                {"title": "query", "article": "positive"},
            ),
            "hotpotqa": _load_pair_dataset(
                "sentence-transformers/hotpotqa",
                "triplet",
                {"anchor": "query", "positive": "positive"},
            ),
        }
    )

    # Shuffle (unless disabled) and optionally truncate each dataset.
    for name, ds in train_datasets.items():
        if not args.no_shuffle:
            ds = ds.shuffle(seed=args.seed)
        if args.max_train_examples > 0:
            ds = ds.select(range(min(args.max_train_examples, len(ds))))
        # Reassigning an existing key during iteration is safe for dicts.
        train_datasets[name] = ds
        logger.info("Train examples [%s]: %d", name, len(ds))

    # CMNBRL takes `temperature` directly; CMNRL takes `scale`, which is the
    # inverse temperature, so the two runs remain comparable.
    loss_kwargs = {}
    if args.temperature is not None:
        if args.loss_type == "CMNBRL":
            loss_kwargs["temperature"] = args.temperature
        else:
            loss_kwargs["scale"] = 1.0 / args.temperature
    if args.loss_mini_batch_size is not None:
        loss_kwargs["mini_batch_size"] = args.loss_mini_batch_size
    if args.gather_across_devices:
        loss_kwargs["gather_across_devices"] = True
    if args.loss_type == "CMNBRL":
        loss = losses.CachedMultipleNegativesBidirectionalRankingLoss(model=model, **loss_kwargs)
    else:
        loss = losses.CachedMultipleNegativesRankingLoss(model=model, **loss_kwargs)

    training_args = SentenceTransformerTrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_strategy="steps",
        save_total_limit=args.save_total_limit,
        lr_scheduler_type=args.lr_scheduler_type,
        optim=args.optim,
        bf16=args.bf16,
        fp16=args.fp16,
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        dataloader_persistent_workers=args.dataloader_persistent_workers,
        dataloader_drop_last=not args.no_drop_last,
        seed=args.seed,
        max_steps=args.max_steps,
        eval_strategy="no",
        report_to=["wandb"],
        # The loss consumes the raw columns, so keep them all.
        remove_unused_columns=False,
        batch_sampler=args.batch_sampler,
        disable_tqdm=False,
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_datasets,
        loss=loss,
    )
    logger.info("Training start. Output: %s", output_dir)
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    # Evaluate on the NanoBEIR benchmark, restricted to k=10 metrics.
    evaluator = NanoBEIREvaluator(
        ndcg_at_k=[10],
        mrr_at_k=[10],
        accuracy_at_k=[10],
        precision_recall_at_k=[10],
        map_at_k=[10],
        batch_size=args.per_device_eval_batch_size,
        show_progress_bar=False,
        write_csv=False,
    )
    results = evaluator(
        model,
        output_path=str(output_dir / "eval"),
        epoch=0,
        steps=trainer.state.global_step,
    )
    ndcg_key = evaluator.primary_metric
    print(f"NDCG@10: {results[ndcg_key]:.6f} ({ndcg_key})")

    # Save via the trainer and via model.save (the latter writes a model card).
    final_dir.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(final_dir))
    model.save(str(final_dir), create_model_card=True)
    logger.info("Saved model to: %s", final_dir)
# Run the training script only when executed directly, not on import.
if __name__ == "__main__":
    main()