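"""Single-file Sentence-Transformers loss training example (no src imports).

Trains an embedding model on public (query, positive) pair datasets with a
cached in-batch-negatives loss and scores the result with NanoBEIREvaluator.
"""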
from __future__ import annotations

import argparse
import logging
import os
import time
from pathlib import Path
from typing import cast


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Single-file ST loss training example (no src imports).")
    parser.add_argument(
        "--model_name",
        default="answerdotai/ModernBERT-base",
        help="Sentence-Transformers model name or path.",
    )
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument(
        "--max_train_examples",
        type=int,
        default=-1,
        help="Limit training examples (use -1 for the full dataset).",
    )
    parser.add_argument("--seed", type=int, default=12)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=8192)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=512)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=100)
    parser.add_argument("--save_total_limit", type=int, default=2)
    parser.add_argument("--lr_scheduler_type", default="cosine")
    parser.add_argument("--optim", default="adamw_torch")
    parser.add_argument("--loss_mini_batch_size", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--gather_across_devices", action="store_true")
    # BooleanOptionalAction so the default-on flag can actually be disabled via
    # --no-bf16 (a store_true flag with default=True could never be unset).
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--dataloader_num_workers", type=int, default=12)
    parser.add_argument("--dataloader_prefetch_factor", type=int, default=2)
    parser.add_argument("--dataloader_persistent_workers", action="store_true", default=False)
    parser.add_argument("--no_drop_last", action="store_true", help="Disable drop_last (default: True)")
    parser.add_argument(
        "--batch_sampler",
        choices=["batch_sampler", "no_duplicates"],
        default="no_duplicates",
        help="Batch sampler type for SentenceTransformers.",
    )
    parser.add_argument(
        "--loss_type",
        choices=["CMNRL", "CMNBRL"],
        default="CMNBRL",
        help="Loss type: CMNRL (CachedMultipleNegativesRankingLoss) or "
        "CMNBRL (CachedMultipleNegativesBidirectionalRankingLoss, aka GTE-style loss with GradCache).",
    )
    parser.add_argument(
        "--output_root",
        default="output/models/examples",
        help="Root directory for outputs.",
    )
    parser.add_argument("--run_name", default=None)
    parser.add_argument("--no_shuffle", action="store_true")
    parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (debug).")
    parser.add_argument("--resume_from_checkpoint", default=None, help="Resume training from checkpoint.")
    return parser.parse_args()
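

# Illustrative invocation (the script filename is assumed; all flags are defined above):
#   python train_st_loss_example.py --loss_type CMNBRL --temperature 0.02 \
#       --per_device_train_batch_size 8192 --loss_mini_batch_size 128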


def build_output_dir(output_root: Path, run_name: str) -> Path:
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    return output_root / run_name / timestamp
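

# With the defaults, a run named "my_run" lands in e.g.
# output/models/examples/my_run/20250101_120000 (one timestamped dir per launch).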


def main() -> None:
    args = parse_args()

    # Avoid tokenizer fork warnings/deadlocks with multi-worker dataloaders.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Heavy imports are deferred until after argument parsing.
    import torch
    from datasets import Dataset, DatasetDict, load_dataset
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
        losses,
    )
    from sentence_transformers.evaluation import NanoBEIREvaluator

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger("train_st_loss_example")

    if args.bf16 and (not torch.cuda.is_available() or not torch.cuda.is_bf16_supported()):
        logger.warning("bf16 requested but not supported on this device; disabling bf16.")
        args.bf16 = False

    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    max_train_tag = "full" if args.max_train_examples <= 0 else f"{args.max_train_examples}"
    data_tag = "pair"
    if args.run_name is None:
        model_tag = args.model_name.rstrip("/").split("/")[-1]
        temp_tag = "tdefault" if args.temperature is None else f"t{args.temperature}".replace(".", "p")
        args.run_name = (
            f"{model_tag}_{args.loss_type}_{args.batch_sampler}_{temp_tag}_{data_tag}"
            f"_bs{args.per_device_train_batch_size}_{max_train_tag}"
        )
    output_dir = build_output_dir(output_root, args.run_name)
    output_dir.mkdir(parents=True, exist_ok=True)
    final_dir = output_dir / "final"

    logger.info("Loading model: %s", args.model_name)
    model = SentenceTransformer(args.model_name)
    model.max_seq_length = args.max_seq_length

    def _load_pair_dataset(dataset_id: str, config: str | None, rename_map: dict[str, str]) -> Dataset:
        """Load a train split and normalize its columns to (query, positive)."""
        ds = load_dataset(dataset_id, config, split="train") if config else load_dataset(dataset_id, split="train")
        ds = cast(Dataset, ds)
        if rename_map:
            column_names = ds.column_names or []
            # Skip no-op identity renames; rename_columns rejects a target name
            # that already exists in the dataset.
            existing = {k: v for k, v in rename_map.items() if k in column_names and k != v}
            if existing:
                ds = ds.rename_columns(existing)
        ds = ds.select_columns(["query", "positive"])
        return ds
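
    # e.g. sentence-transformers/gooaq ships (question, answer) columns, so the
    # map {"question": "query", "answer": "positive"} normalizes it to the
    # (query, positive) schema every loss below expects.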

    logger.info("Loading datasets (pair only)...")
    train_datasets = DatasetDict(
        {
            "msmarco": _load_pair_dataset(
                "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
                "triplet",
                {"query": "query", "positive": "positive"},
            ),
            "natural_questions": _load_pair_dataset(
                "sentence-transformers/natural-questions",
                "pair",
                {"answer": "positive"},
            ),
            "gooaq": _load_pair_dataset(
                "sentence-transformers/gooaq",
                "pair",
                {"question": "query", "answer": "positive"},
            ),
            "ccnews": _load_pair_dataset(
                "sentence-transformers/ccnews",
                "pair",
                {"title": "query", "article": "positive"},
            ),
            "hotpotqa": _load_pair_dataset(
                "sentence-transformers/hotpotqa",
                "triplet",
                {"anchor": "query", "positive": "positive"},
            ),
        }
    )
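
    # With a DatasetDict, SentenceTransformerTrainer draws every batch from a
    # single sub-dataset, alternating between them according to
    # multi_dataset_batch_sampler (proportional to dataset size by default).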

    for name, ds in train_datasets.items():
        if not args.no_shuffle:
            ds = ds.shuffle(seed=args.seed)
        if args.max_train_examples > 0:
            ds = ds.select(range(min(args.max_train_examples, len(ds))))
        train_datasets[name] = ds
        logger.info("Train examples [%s]: %d", name, len(ds))

    loss_kwargs = {}
    if args.temperature is not None:
        if args.loss_type == "CMNBRL":
            loss_kwargs["temperature"] = args.temperature
        else:
            # CMNRL is parameterized by scale, the inverse temperature.
            loss_kwargs["scale"] = 1.0 / args.temperature
    if args.loss_mini_batch_size is not None:
        loss_kwargs["mini_batch_size"] = args.loss_mini_batch_size
    if args.gather_across_devices:
        loss_kwargs["gather_across_devices"] = True
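
    # Note the reciprocal conventions: --temperature 0.05 becomes
    # temperature=0.05 for CMNBRL but scale=20.0 for CMNRL; both control the
    # same softmax sharpness over in-batch similarities.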

    if args.loss_type == "CMNBRL":
        loss = losses.CachedMultipleNegativesBidirectionalRankingLoss(model=model, **loss_kwargs)
    else:
        loss = losses.CachedMultipleNegativesRankingLoss(model=model, **loss_kwargs)
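
    # Both cached losses follow the GradCache recipe: embeddings are computed
    # in chunks of mini_batch_size, so memory scales with the mini-batch while
    # the full per-device batch still serves as the in-batch negative pool.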

    training_args = SentenceTransformerTrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_strategy="steps",
        save_total_limit=args.save_total_limit,
        lr_scheduler_type=args.lr_scheduler_type,
        optim=args.optim,
        bf16=args.bf16,
        fp16=args.fp16,
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        dataloader_persistent_workers=args.dataloader_persistent_workers,
        dataloader_drop_last=not args.no_drop_last,
        seed=args.seed,
        max_steps=args.max_steps,
        eval_strategy="no",
        report_to=["wandb"],
        remove_unused_columns=False,
        batch_sampler=args.batch_sampler,
        disable_tqdm=False,
    )
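
    # Note: gradient_accumulation_steps does not enlarge the in-batch negative
    # pool; negatives come from the per-device batch (8192 by default), or from
    # all devices when --gather_across_devices is set.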

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_datasets,
        loss=loss,
    )

    logger.info("Training start. Output: %s", output_dir)
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    evaluator = NanoBEIREvaluator(
        ndcg_at_k=[10],
        mrr_at_k=[10],
        accuracy_at_k=[10],
        precision_recall_at_k=[10],
        map_at_k=[10],
        batch_size=args.per_device_eval_batch_size,
        show_progress_bar=False,
        write_csv=False,
    )
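
    # NanoBEIREvaluator retrieves over the small NanoBEIR subsets from the Hub
    # and averages metrics across them; its primary_metric is the aggregated
    # NDCG@10 given the [10] cutoffs configured above.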

    results = evaluator(
        model,
        output_path=str(output_dir / "eval"),
        epoch=0,
        steps=trainer.state.global_step,
    )
    ndcg_key = evaluator.primary_metric
    print(f"NDCG@10: {results[ndcg_key]:.6f} ({ndcg_key})")

    final_dir.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(final_dir))
    # model.save rewrites the same weights but additionally generates a model card.
    model.save(str(final_dir), create_model_card=True)
    logger.info("Saved model to: %s", final_dir)


if __name__ == "__main__":
    main()