Sentence Similarity
sentence-transformers
Safetensors
English
modernbert
feature-extraction
dense
Generated from Trainer
dataset_size:4314846
loss:CachedMultipleNegativesBidirectionalRankingLoss
Eval Results (legacy)
text-embeddings-inference
Instructions to use hotchpotch/ModernBERT-embedding-CMNBRL with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use hotchpotch/ModernBERT-embedding-CMNBRL with sentence-transformers:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("hotchpotch/ModernBERT-embedding-CMNBRL")

sentences = [
    "what is grade 7 gcse equivalent to?",
    "Unlike the Google Home Mini (First Gen), the Nest Mini (Second Gen) can be used to actually enjoy music in every room of the house. While the Google Home Mini (First Gen) is a decent way to get music in every room of your home for cheap, the sound quality that comes from the speaker reflects the price of the product.",
    "In general, a grade 7-9 is roughly equivalent to A-A* under the old system, while a grade 4 and above is roughly equivalent to a C and above. Fewer students will receive a grade 9 than would have received an A* under the old grading system.",
    "['Pulling at a wet or dirty diaper.', 'Hiding to pee or poop.', \"Interest in others' use of the potty, or copying their behavior.\", 'Having a dry diaper for a longer-than-usual time.', 'Awakening dry from a nap.', \"Telling you that they're about to go, are going or have just gone in their diaper.\"]"
]
embeddings = model.encode(sentences)
similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)
# [4, 4]
- Notebooks
- Google Colab
- Kaggle
Upload train_st_loss_example.py
Browse files- train_st_loss_example.py +262 -0
train_st_loss_example.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Sample training script for ablation: compare CachedMultipleNegativesRankingLoss
|
| 3 |
+
# vs CachedMultipleNegativesBidirectionalRankingLoss (aka GTE loss with GradCache).
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import cast
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the loss-ablation training script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read the process command line, which is what the no-argument call
            has always done.

    Returns:
        The populated namespace. ``bf16`` is always a concrete ``bool``: it is
        ``True`` by default and ``False`` when ``--no-bf16`` or ``--fp16`` is
        given.

    Raises:
        SystemExit: Raised by argparse for unknown or malformed options, when
            ``--bf16`` and ``--fp16`` are both requested explicitly, or when
            ``--temperature`` is not strictly positive.
    """
    parser = argparse.ArgumentParser(description="Single-file ST loss training example (no src imports).")
    parser.add_argument(
        "--model_name",
        default="answerdotai/ModernBERT-base",
        help="Sentence-Transformers model name or path.",
    )
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument(
        "--max_train_examples",
        type=int,
        default=-1,
        help="Limit training examples (use -1 for full dataset).",
    )
    parser.add_argument("--seed", type=int, default=12)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=8192)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=512)
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
    )
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=100)
    parser.add_argument("--save_total_limit", type=int, default=2)
    parser.add_argument("--lr_scheduler_type", default="cosine")
    parser.add_argument("--optim", default="adamw_torch")
    parser.add_argument("--loss_mini_batch_size", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--gather_across_devices", action="store_true")
    # The flag used to be store_true with default=True, which made it
    # impossible to switch off. BooleanOptionalAction keeps --bf16 and adds
    # --no-bf16; the None default lets us tell "not given" from "given".
    parser.add_argument(
        "--bf16",
        action=argparse.BooleanOptionalAction,
        default=None,
        help="Use bf16 mixed precision (on by default unless --fp16 is given).",
    )
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--dataloader_num_workers", type=int, default=12)
    parser.add_argument("--dataloader_prefetch_factor", type=int, default=2)
    parser.add_argument("--dataloader_persistent_workers", action="store_true", default=False)
    parser.add_argument("--no_drop_last", action="store_true", help="Disable drop_last (default: True)")
    parser.add_argument(
        "--batch_sampler",
        choices=["batch_sampler", "no_duplicates"],
        default="no_duplicates",
        help="Batch sampler type for SentenceTransformers.",
    )
    parser.add_argument(
        "--loss_type",
        choices=["CMNRL", "CMNBRL"],
        default="CMNBRL",
        help="Loss type: CMNRL (CachedMultipleNegativesRankingLoss) or "
        "CMNBRL (aka GTE with GradCache).",
    )
    parser.add_argument(
        "--output_root",
        default="output/models/examples",
        help="Root directory for outputs.",
    )
    parser.add_argument("--run_name", default=None)
    parser.add_argument("--no_shuffle", action="store_true")
    parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (debug).")
    parser.add_argument("--resume_from_checkpoint", default=None, help="Resume training from checkpoint.")

    args = parser.parse_args(argv)

    if args.bf16 is None:
        # Not specified: keep the historical default (bf16 on), but let a
        # lone --fp16 select fp16 instead of leaving both flags enabled.
        args.bf16 = not args.fp16
    elif args.bf16 and args.fp16:
        parser.error("--bf16 and --fp16 are mutually exclusive; pass only one of them.")

    # The CMNRL branch of the script converts temperature to scale with
    # 1.0 / temperature, so zero (or a negative value) must be rejected here.
    if args.temperature is not None and args.temperature <= 0:
        parser.error("--temperature must be a positive number.")

    return args
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def build_output_dir(output_root: Path, run_name: str) -> Path:
    """Return ``output_root/run_name/<YYYYmmdd_HHMMSS>`` for the current local time.

    Only the path is computed; no directory is created here.
    """
    return output_root.joinpath(run_name, time.strftime("%Y%m%d_%H%M%S"))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def main() -> None:
    """Run one training job for the CMNRL / CMNBRL loss ablation.

    The steps run in this order: parse the CLI options, load the base
    SentenceTransformer, load five (query, positive) pair datasets, build the
    selected cached ranking loss, train, evaluate once with NanoBEIR, and save
    the model to ``<output_root>/<run_name>/<timestamp>/final``.
    """
    args = parse_args()

    # Only sets the variable when the caller has not exported it already.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Imported after argument parsing, so --help and usage errors return
    # before these third-party packages are loaded.
    import torch
    from datasets import Dataset, DatasetDict, load_dataset
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
        losses,
    )
    from sentence_transformers.evaluation import NanoBEIREvaluator

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger("train_st_loss_example")

    if args.bf16 and (not torch.cuda.is_available() or not torch.cuda.is_bf16_supported()):
        # The previous message mentioned fp16, but it is bf16 that gets disabled.
        logger.warning("bf16 requested but not supported on this device; disabling bf16.")
        args.bf16 = False

    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    # The datasets are truncated only for a positive limit (see the loop
    # below), so every non-positive value, including 0, means "full".
    max_train_tag = "full" if args.max_train_examples <= 0 else f"{args.max_train_examples}"
    data_tag = "pair"
    if args.run_name is None:
        model_tag = args.model_name.rstrip("/").split("/")[-1]
        temp_tag = "tdefault" if args.temperature is None else f"t{args.temperature}".replace(".", "p")
        args.run_name = (
            f"{model_tag}_{args.loss_type}_{args.batch_sampler}_{temp_tag}_{data_tag}"
            f"_bs{args.per_device_train_batch_size}_{max_train_tag}"
        )
    output_dir = build_output_dir(output_root, args.run_name)
    output_dir.mkdir(parents=True, exist_ok=True)
    final_dir = output_dir / "final"

    logger.info("Loading model: %s", args.model_name)
    model = SentenceTransformer(args.model_name)
    model.max_seq_length = args.max_seq_length

    def _load_pair_dataset(dataset_id: str, config: str | None, rename_map: dict[str, str]) -> Dataset:
        """Load the train split and reduce it to ``query`` / ``positive`` columns."""
        # An empty config falls back to the dataset's default configuration.
        ds = cast(Dataset, load_dataset(dataset_id, config or None, split="train"))
        column_names = ds.column_names or []
        # Rename only columns that exist; identity entries are no-ops, so skip them.
        renames = {old: new for old, new in rename_map.items() if old in column_names and old != new}
        if renames:
            ds = ds.rename_columns(renames)
        return ds.select_columns(["query", "positive"])

    logger.info("Loading datasets (pair only)...")
    train_datasets = DatasetDict(
        {
            "msmarco": _load_pair_dataset(
                "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
                "triplet",
                {"query": "query", "positive": "positive"},
            ),
            "natural_questions": _load_pair_dataset(
                "sentence-transformers/natural-questions",
                "pair",
                {"answer": "positive"},
            ),
            "gooaq": _load_pair_dataset(
                "sentence-transformers/gooaq",
                "pair",
                {"question": "query", "answer": "positive"},
            ),
            "ccnews": _load_pair_dataset(
                "sentence-transformers/ccnews",
                "pair",
                {"title": "query", "article": "positive"},
            ),
            "hotpotqa": _load_pair_dataset(
                "sentence-transformers/hotpotqa",
                "triplet",
                {"anchor": "query", "positive": "positive"},
            ),
        }
    )

    # Values of existing keys are replaced in place; no key is added or removed.
    for name, ds in train_datasets.items():
        if not args.no_shuffle:
            ds = ds.shuffle(seed=args.seed)
        if args.max_train_examples > 0:
            ds = ds.select(range(min(args.max_train_examples, len(ds))))
        train_datasets[name] = ds
        logger.info("Train examples [%s]: %d", name, len(ds))

    loss_kwargs = {}
    if args.temperature is not None:
        # The two loss classes take the same quantity under different names:
        # CMNBRL gets the temperature itself, CMNRL gets its reciprocal as scale.
        if args.loss_type == "CMNBRL":
            loss_kwargs["temperature"] = args.temperature
        else:
            loss_kwargs["scale"] = 1.0 / args.temperature
    if args.loss_mini_batch_size is not None:
        loss_kwargs["mini_batch_size"] = args.loss_mini_batch_size
    if args.gather_across_devices:
        loss_kwargs["gather_across_devices"] = True

    if args.loss_type == "CMNBRL":
        loss = losses.CachedMultipleNegativesBidirectionalRankingLoss(model=model, **loss_kwargs)
    else:
        loss = losses.CachedMultipleNegativesRankingLoss(model=model, **loss_kwargs)

    # Prefetching and persistent workers only make sense with worker
    # processes; forwarding them with zero workers is rejected downstream, so
    # --dataloader_num_workers 0 would otherwise abort the run.
    uses_workers = args.dataloader_num_workers > 0
    prefetch_factor = args.dataloader_prefetch_factor if uses_workers else None
    persistent_workers = args.dataloader_persistent_workers and uses_workers

    training_args = SentenceTransformerTrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_strategy="steps",
        save_total_limit=args.save_total_limit,
        lr_scheduler_type=args.lr_scheduler_type,
        optim=args.optim,
        bf16=args.bf16,
        fp16=args.fp16,
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_prefetch_factor=prefetch_factor,
        dataloader_persistent_workers=persistent_workers,
        dataloader_drop_last=not args.no_drop_last,
        seed=args.seed,
        max_steps=args.max_steps,
        eval_strategy="no",
        # NOTE(review): reporting to Weights & Biases is hard-coded; presumably
        # this needs wandb installed and configured - confirm before reuse.
        report_to=["wandb"],
        remove_unused_columns=False,
        batch_sampler=args.batch_sampler,
        disable_tqdm=False,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_datasets,
        loss=loss,
    )

    logger.info("Training start. Output: %s", output_dir)
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    # eval_strategy is "no", so this is the only evaluation of the run.
    evaluator = NanoBEIREvaluator(
        ndcg_at_k=[10],
        mrr_at_k=[10],
        accuracy_at_k=[10],
        precision_recall_at_k=[10],
        map_at_k=[10],
        batch_size=args.per_device_eval_batch_size,
        show_progress_bar=False,
        write_csv=False,
    )
    results = evaluator(
        model,
        output_path=str(output_dir / "eval"),
        epoch=0,
        steps=trainer.state.global_step,
    )
    ndcg_key = evaluator.primary_metric
    print(f"NDCG@10: {results[ndcg_key]:.6f} ({ndcg_key})")

    final_dir.mkdir(parents=True, exist_ok=True)
    # Both calls write into final_dir; the second one also requests a model card.
    trainer.save_model(str(final_dir))
    model.save(str(final_dir), create_model_card=True)
    logger.info("Saved model to: %s", final_dir)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# Start the training job only when this file is executed as a script;
# importing it as a module does not call main().
if __name__ == "__main__":
    main()
|