| | |
| | |
| | |
| |
|
| | import logging |
| | import os |
| | import random |
| | from pathlib import Path |
| |
|
| | from sentence_transformers import ( |
| | SentenceTransformer, |
| | SentenceTransformerModelCardData, |
| | SentenceTransformerTrainer, |
| | SentenceTransformerTrainingArguments, |
| | ) |
| | from sentence_transformers.evaluation import NanoBEIREvaluator |
| | from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss |
| | from sentence_transformers.models.StaticEmbedding import StaticEmbedding |
| | from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers |
| | from transformers import AutoTokenizer |
| |
|
| | from datasets import Dataset, DatasetDict, load_dataset |
| |
|
# Experiment identifier; main() embeds it in the run name.
EXP = "030"
print("EXP:", EXP)

# Repository root: the parent of the directory that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
print(PROJECT_ROOT)
| |
|
# English datasets to train on. Only names that are keys of the cached
# English DatasetDict (see _load_train_eval_datasets_en) are used.
EN_TARGET_DATASETS = [
    "msmarco",
    "squad",
    "allnli",
    "trivia_qa",
    "swim_ir",
    "miracl",
    "mr_tydi",
]

# Japanese datasets to train on: subset names of the
# "hotchpotch/sentence_transformer_japanese" hub dataset.
JA_TARGET_DATASETS = [
    # hpprc_emb__* subsets
    "hpprc_emb__auto-wiki-nli-triplet",
    "hpprc_emb__auto-wiki-qa",
    "hpprc_emb__auto-wiki-qa-nemotron",
    "hpprc_emb__auto-wiki-qa-pair",
    "hpprc_emb__baobab-wiki-retrieval",
    "hpprc_emb__janli-triplet",
    "hpprc_emb__jaquad",
    "hpprc_emb__jqara",
    "hpprc_emb__jsnli-triplet",
    "hpprc_emb__jsquad",
    "hpprc_emb__miracl",
    "hpprc_emb__mkqa",
    "hpprc_emb__mkqa-triplet",
    "hpprc_emb__mr-tydi",
    "hpprc_emb__nu-mnli-triplet",
    "hpprc_emb__nu-snli-triplet",
    "hpprc_emb__quiz-no-mori",
    "hpprc_emb__quiz-works",
    "hpprc_emb__snow-triplet",
    # other hpprc_* subsets
    "hpprc_llmjp-kaken",
    "hpprc_llmjp_warp_html",
    "hpprc_mqa_ja",
    "hpprc_msmarco_ja",
]

# Integer repetition factor per training dataset. main() repeats every row of
# these datasets this many times before training.
AUG_FACTOR_DATASETS = {
    "hpprc_emb__miracl": 20,
    "hpprc_emb__mr-tydi": 20,
    "hpprc_emb__jqara": 10,
    "hpprc_emb__baobab-wiki-retrieval": 5,
    "hpprc_emb__mkqa": 5,
    "hpprc_emb__auto-wiki-qa-nemotron": 2,
    "hpprc_msmarco_ja": 2,
}
| |
|
# Turn off HuggingFace `tokenizers` internal parallelism for this process
# (presumably to avoid clashes with the dataloader / map worker processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Root logger at INFO with a timestamped one-line format.
_LOG_FORMAT = "%(asctime)s - %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)

# Seed the stdlib RNG; dataset splits pass their own explicit seed.
random.seed(12)
| |
|
| |
|
def _split_train_eval(dataset: "Dataset", test_size: int) -> "tuple[Dataset, Dataset]":
    """Split ``dataset`` into ``(train, eval)`` using the fixed seed (12) used by this script."""
    split = dataset.train_test_split(test_size=test_size, seed=12)
    return split["train"], split["test"]


def _load_train_eval_datasets_en():
    """Load the English train/eval DatasetDicts, building and caching them on first use.

    If ``PROJECT_ROOT/datasets/en_train_dataset`` and
    ``PROJECT_ROOT/datasets/en_eval_dataset`` both load with
    ``DatasetDict.load_from_disk``, they are returned as-is.

    Otherwise every English source dataset is fetched with ``load_dataset``.
    Each one is split into train/eval with seed 12; "allnli" instead uses its
    own "dev" split. The two DatasetDicts are then saved to those directories
    and returned. The function does not exit the process after saving.

    Returns:
        ``(train_dataset, eval_dataset)``: DatasetDicts keyed by short dataset
        name ("gooaq", "msmarco", ...), with identical keys in identical order.
    """
    en_train_dataset_dir = PROJECT_ROOT / "datasets" / "en_train_dataset"
    en_eval_dataset_dir = PROJECT_ROOT / "datasets" / "en_eval_dataset"
    try:
        train_dataset = DatasetDict.load_from_disk(en_train_dataset_dir)
        eval_dataset = DatasetDict.load_from_disk(en_eval_dataset_dir)
        return train_dataset, eval_dataset
    except FileNotFoundError:
        pass  # cache miss: build the datasets below

    # Insertion order into these dicts fixes the key order of the saved DatasetDicts.
    train_splits = {}
    eval_splits = {}

    def load_and_split(name, *load_args, split="train", test_size=10_000, columns=None):
        """Load one hub dataset, optionally keep only `columns`, and hold out `test_size` rows."""
        print(f"Loading {name} dataset...")
        dataset = load_dataset(*load_args, split=split)
        if columns is not None:
            dataset = dataset.select_columns(columns)
        train_splits[name], eval_splits[name] = _split_train_eval(dataset, test_size)
        print(f"Loaded {name} dataset.")

    load_and_split("gooaq", "sentence-transformers/gooaq")
    load_and_split(
        "msmarco",
        "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
        "triplet",
    )
    load_and_split("squad", "sentence-transformers/squad")
    load_and_split(
        "s2orc",
        "sentence-transformers/s2orc",
        "title-abstract-pair",
        split="train[:100000]",
    )

    # all-nli ships a dedicated dev split, so no random hold-out is taken.
    print("Loading allnli dataset...")
    train_splits["allnli"] = load_dataset(
        "sentence-transformers/all-nli", "triplet", split="train"
    )
    eval_splits["allnli"] = load_dataset(
        "sentence-transformers/all-nli", "triplet", split="dev"
    )
    print("Loaded allnli dataset.")

    load_and_split("paq", "sentence-transformers/paq")
    load_and_split("trivia_qa", "sentence-transformers/trivia-qa", test_size=5_000)
    load_and_split("msmarco_10m", "bclavie/msmarco-10m-triplets")
    load_and_split(
        "swim_ir", "nthakur/swim-ir-monolingual", "en", columns=["query", "text"]
    )
    load_and_split(
        "pubmedqa", "sentence-transformers/pubmedqa", "triplet-20", test_size=100
    )
    load_and_split("miracl", "sentence-transformers/miracl", "en-triplet-all")
    load_and_split("mldr", "sentence-transformers/mldr", "en-triplet-all")
    load_and_split("mr_tydi", "sentence-transformers/mr-tydi", "en-triplet-all")

    train_dataset = DatasetDict(train_splits)
    eval_dataset = DatasetDict(eval_splits)
    train_dataset.save_to_disk(str(en_train_dataset_dir))
    eval_dataset.save_to_disk(str(en_eval_dataset_dir))
    return train_dataset, eval_dataset
| |
|
| |
|
def load_train_eval_datasets_en(target_dataset_names: "list[str] | None" = None):
    """Load the English train/eval splits, keeping only the requested datasets.

    Args:
        target_dataset_names: names of datasets to keep. They must be keys of
            the DatasetDict returned by ``_load_train_eval_datasets_en``.
            ``None`` or an empty list returns two empty DatasetDicts without
            touching the cache.

    Returns:
        ``(train_dataset, eval_dataset)`` restricted to the requested names.
        Requested names that are not present in the cache are reported and
        otherwise ignored.
    """
    print("Loading train and eval datasets...")
    if not target_dataset_names:
        return DatasetDict(), DatasetDict()
    targets = set(target_dataset_names)

    train_dataset, eval_dataset = _load_train_eval_datasets_en()
    # Iterate over a snapshot of the keys because entries are deleted in the loop.
    for ds_name in list(train_dataset.keys()):
        if ds_name not in targets:
            del train_dataset[ds_name]
            del eval_dataset[ds_name]
            continue
        print(
            "target en ds",
            ds_name,
            len(train_dataset[ds_name]),
            len(eval_dataset[ds_name]),
        )

    # A stale cache or a typo would otherwise drop a dataset silently.
    missing = sorted(targets.difference(train_dataset.keys()))
    if missing:
        print("WARNING: en target datasets not found in cache:", missing)
    return train_dataset, eval_dataset
| |
|
| |
|
def load_train_eval_datasets_jp(target_dataset_names: "list[str] | None" = None):
    """Load Japanese train/eval splits, building and caching each one on first use.

    Each dataset is read from ``PROJECT_ROOT/datasets/jp_train_dataset/<name>``
    and ``PROJECT_ROOT/datasets/jp_eval_dataset/<name>``. If either cannot be
    found, the subset ``<name>`` of "hotchpotch/sentence_transformer_japanese"
    is downloaded instead. It is split with seed 12, holding out about 1% of
    the rows (capped at 3000, at least 1) for eval, and both splits are saved
    to those paths.

    Args:
        target_dataset_names: subset names to load. ``None`` or an empty list
            yields two empty DatasetDicts.

    Returns:
        ``(train_dataset, eval_dataset)`` DatasetDicts keyed by subset name.
    """
    print("Loading train and eval datasets...")
    jp_train_dataset_dir = PROJECT_ROOT / "datasets" / "jp_train_dataset"
    jp_eval_dataset_dir = PROJECT_ROOT / "datasets" / "jp_eval_dataset"

    train_dataset_dict = {}
    eval_dataset_dict = {}
    for ds_name in target_dataset_names or []:
        print("loading jp ds", ds_name)
        train_path = jp_train_dataset_dir / ds_name
        eval_path = jp_eval_dataset_dir / ds_name
        try:
            train_ds = Dataset.load_from_disk(str(train_path))
            eval_ds = Dataset.load_from_disk(str(eval_path))
        except FileNotFoundError:
            print(f"{ds_name} not found, loading from datasets library...")
            ds = load_dataset(
                "hotchpotch/sentence_transformer_japanese", ds_name, split="train"
            )
            # train_test_split rejects test_size=0. Datasets with fewer than
            # 100 rows would otherwise get 0 here, so hold out at least 1 row.
            test_size = max(1, min(3000, len(ds) // 100))
            splitted = ds.train_test_split(test_size=test_size, seed=12)
            train_ds = splitted["train"]
            eval_ds = splitted["test"]
            train_ds.save_to_disk(str(train_path))
            eval_ds.save_to_disk(str(eval_path))
        train_dataset_dict[ds_name] = train_ds
        eval_dataset_dict[ds_name] = eval_ds
    return DatasetDict(train_dataset_dict), DatasetDict(eval_dataset_dict)
| |
|
| |
|
def _repeat_batch(batch, columns, aug_factor):
    """Return ``batch`` with every column's value list repeated ``aug_factor`` times.

    Intended for ``Dataset.map(batched=True)``. All columns grow by the same
    factor, so the output batch stays rectangular.
    """
    return {col: batch[col] * aug_factor for col in columns}


def _augment_train_datasets(train_dataset, aug_factors, num_proc=11):
    """Upsample selected training datasets in place by repeating their rows.

    Args:
        train_dataset: DatasetDict of training splits. Entries named in
            ``aug_factors`` are replaced by their augmented versions.
        aug_factors: mapping of dataset name to integer repetition factor.
        num_proc: worker processes passed to ``Dataset.map``.

    Names in ``aug_factors`` that are not loaded are reported and skipped
    instead of raising ``KeyError``.
    """
    for ds_name, aug_factor in aug_factors.items():
        if ds_name not in train_dataset:
            print("data augmentation skipped, dataset not loaded:", ds_name)
            continue
        dataset = train_dataset[ds_name]
        before_len = len(dataset)
        train_dataset[ds_name] = dataset.map(
            _repeat_batch,
            batched=True,
            num_proc=num_proc,
            # Pass per-dataset values explicitly instead of closing over loop
            # variables.
            fn_kwargs={"columns": dataset.column_names, "aug_factor": aug_factor},
        )
        print("data augmented", ds_name, before_len, len(train_dataset[ds_name]))


def main():
    """Train, evaluate and save a static-embedding model on the EN + JA datasets.

    Steps:
    - Build a 1024-dim StaticEmbedding model on a Japanese tokenizer.
    - Load and merge the English and Japanese train/eval splits.
    - Upsample the datasets listed in AUG_FACTOR_DATASETS.
    - Train with MatryoshkaLoss around MultipleNegativesRankingLoss.
    - Run NanoBEIREvaluator before and after training.
    - Save the final model.
    """
    print("Loading model...")
    static_embedding = StaticEmbedding(
        AutoTokenizer.from_pretrained("hotchpotch/xlm-roberta-japanese-tokenizer"),
        embedding_dim=1024,
    )
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="ja",
            license="mit",
            model_name="Static Embeddings with japanese tokenizer finetuned on various datasets",
        ),
    )

    print("Loading datasets...")
    train_dataset_en, eval_dataset_en = load_train_eval_datasets_en(EN_TARGET_DATASETS)
    train_dataset_jp, eval_dataset_jp = load_train_eval_datasets_jp(JA_TARGET_DATASETS)

    print("Merging datasets...")
    train_dataset = DatasetDict({**train_dataset_en, **train_dataset_jp})
    eval_dataset = DatasetDict({**eval_dataset_en, **eval_dataset_jp})
    _augment_train_datasets(train_dataset, AUG_FACTOR_DATASETS)
    for train_ds_name in train_dataset:
        print(
            "train ds",
            train_ds_name,
            len(train_dataset[train_ds_name]),
            len(eval_dataset[train_ds_name]),
        )

    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])

    run_name = f"static-retrieval-mrl-jp-v1_{EXP}"
    # Anchor checkpoints and the final model under PROJECT_ROOT/models. The
    # previous cwd-relative output_dir put checkpoints somewhere other than
    # the final model whenever the script was not run from the project root.
    output_dir = PROJECT_ROOT / "models" / run_name
    args = SentenceTransformerTrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=2,
        per_device_train_batch_size=2048 * 3,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        fp16=False,
        bf16=True,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=20,
        logging_steps=20,
        logging_first_step=True,
        dataloader_prefetch_factor=4,
        dataloader_num_workers=15,
        run_name=run_name,
    )

    # Baseline evaluation before any training.
    evaluator = NanoBEIREvaluator()
    evaluator(model)

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # Evaluation after training.
    evaluator(model)

    model.save_pretrained(str(output_dir / "final"))
| |
|
| | |
| | |
| |
|
| |
|
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()
| |
|
| |
|