| import os |
| from pathlib import Path |
| import torch |
| import pandas as pd |
| from datasets import load_dataset, Dataset, load_from_disk |
| from sentence_transformers import ( |
| SentenceTransformer, |
| SentenceTransformerTrainer, |
| SentenceTransformerTrainingArguments, |
| ) |
| from sentence_transformers.losses import CachedMultipleNegativesRankingLoss |
| from sentence_transformers.training_args import BatchSamplers |
| from sentence_transformers.util import mine_hard_negatives |
|
|
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Configuration: module-level constants; edit these to change a run.
# ---------------------------------------------------------------------------

# Hugging Face Hub dataset that supplies the (query, context) rows.
DATASET_NAME = "phamson02/large-vi-legal-queries"

# Stage-1 model directory. main() uses it both as the hard-negative miner and as
# the starting checkpoint for stage-2 training, so the second name aliases it.
BASE_MODEL_DIR = "./embeddinggemma-300m-vilegal"
BASE_MODEL_CHECKPOINT_DIR = BASE_MODEL_DIR

# Where stage-2 trainer checkpoints and the final model are written.
STAGE2_OUTPUT_DIR = "./embeddinggemma-300m-vilegal-stage2-hardneg"

# Every intermediate artifact lives under one cache root, created at import time.
CACHE_ROOT = Path("./cache_vilegal_stage2")
CACHE_ROOT.mkdir(parents=True, exist_ok=True)
CLEAN_DF_PATH, PAIR_DATASET_PATH, HARD_NEGATIVE_DATA_DIR = (
    CACHE_ROOT / "clean_df.parquet",
    CACHE_ROOT / "pair_dataset",
    CACHE_ROOT / "hardneg_dataset",
)

# Key into the model's ``prompts`` dict; that prompt is passed to the trainer.
TASK_NAME = "Retrieval"

# Optional cap on the number of raw hub rows to use (None = all rows).
LIMIT_ROWS = None

# Per-stage cache switches: USE_* enables reading/writing that stage's cache,
# FORCE_REBUILD_* ignores an existing cache and recomputes the stage.
USE_CACHE_CLEAN_DF, FORCE_REBUILD_CLEAN_DF = True, False
USE_CACHE_PAIR_DATASET, FORCE_REBUILD_PAIR_DATASET = True, False
USE_CACHE_HARDNEG, FORCE_REBUILD_HARDNEG = True, False

# Resume stage-2 training from the newest checkpoint in STAGE2_OUTPUT_DIR, if any.
RESUME_IF_POSSIBLE = True
|
|
|
|
| |
| |
| |
def clean_text(x):
    """Normalise a raw cell value to a single-line, whitespace-collapsed string.

    Missing values become ``""`` so callers can drop them with ``!= ""``. This
    covers ``None`` and pandas' scalar missing markers (``NaN``, ``pd.NA``,
    ``NaT``). Without the check, ``str()`` would turn them into the literal
    text ``"nan"`` / ``"<NA>"`` and let them through as real content. Any other
    value is converted with ``str()``. Runs of whitespace (spaces, tabs,
    newlines) collapse to a single space, and the ends are trimmed.

    Args:
        x: Any cell value.

    Returns:
        The cleaned string (possibly empty).
    """
    # The is_scalar guard matters: pd.isna on a list/array returns an array, not a bool.
    if x is None or (pd.api.types.is_scalar(x) and pd.isna(x)):
        return ""
    # str.split() with no argument splits on any whitespace run and discards
    # leading/trailing whitespace, so a separate strip() is unnecessary.
    return " ".join(str(x).split())
|
|
|
|
def build_positive_document(row):
    """Render the ``positive`` passage text for one dataframe row.

    Only the ``context`` field is used, prefixed with ``"text: "``.

    Args:
        row: Mapping-like row with a ``"context"`` key (a pandas Series when
            called through ``DataFrame.apply(..., axis=1)``, or a plain dict).
    """
    # NOTE(review): the miner encodes these passages with corpus_prompt_name="document".
    # If that model prompt already ends in "text: ", the prefix ends up doubled;
    # confirm against the model's prompt configuration.
    template = "text: {}"
    return template.format(row["context"])
|
|
|
|
def path_exists_and_nonempty(path: Path) -> bool:
    """Tell whether *path* holds something worth loading as a cache.

    A directory qualifies only if it has at least one entry. Any other path
    (e.g. a regular file) qualifies as soon as it exists.
    """
    if not path.is_dir():
        return path.exists()
    # A single entry is enough; stop at the first one.
    return next(path.iterdir(), None) is not None
|
|
|
|
def get_last_checkpoint(output_dir: str):
    """Locate the newest ``checkpoint-<step>`` sub-directory of *output_dir*.

    Args:
        output_dir: Trainer output directory to scan.

    Returns:
        The path (as ``str``) of the sub-directory with the largest integer
        suffix after its last ``-``, or ``None`` if *output_dir* does not exist
        or has no such sub-directory. Names whose suffix is not an integer are
        skipped. When two entries share a step, the one seen last wins.
    """
    root = Path(output_dir)
    if not root.exists():
        return None

    best_step = None
    best_dir = None
    for entry in root.iterdir():
        if not (entry.is_dir() and entry.name.startswith("checkpoint-")):
            continue
        try:
            step = int(entry.name.split("-")[-1])
        except ValueError:
            continue
        # ">=" lets a later entry with an equal step replace the current best.
        if best_step is None or step >= best_step:
            best_step, best_dir = step, entry

    return None if best_dir is None else str(best_dir)
|
|
|
|
| |
| |
| |
def _clean_df_cache_path(limit_rows):
    """Return the parquet cache file for a run limited to *limit_rows* raw rows.

    The unlimited run keeps ``CLEAN_DF_PATH``. Each explicit limit gets its own
    sibling file, so a small debug run is never served as (or overwrites) the
    full-dataset cache.
    """
    if limit_rows is None:
        return CLEAN_DF_PATH
    return CLEAN_DF_PATH.with_name(
        f"{CLEAN_DF_PATH.stem}_limit{limit_rows}{CLEAN_DF_PATH.suffix}"
    )


def _clean_dataframe(df):
    """Normalise the text columns, drop rows lacking query/context, de-duplicate pairs.

    Missing text columns are added as empty strings. The column reassignments
    modify *df* in place; the filtered, re-indexed frame is returned.
    """
    for col in ["domain", "title", "header", "aspect", "context", "query"]:
        if col not in df.columns:
            df[col] = ""
        df[col] = df[col].apply(clean_text)

    # A usable training pair needs both a query and a context.
    df = df[(df["query"] != "") & (df["context"] != "")]
    return df.drop_duplicates(subset=["query", "context"]).reset_index(drop=True)


def load_and_prepare_dataframe(limit_rows=None):
    """Return the cleaned dataframe of (query, context, ...) rows.

    The result is read from a parquet cache when caching is enabled, no rebuild
    is forced, and a cache for the same *limit_rows* exists. Otherwise the
    ``train`` split of ``DATASET_NAME`` is downloaded, optionally truncated,
    cleaned, and (when caching is enabled) written back to that cache.

    Args:
        limit_rows: If not None, only the first ``limit_rows`` raw rows are used
            (the limit applies before cleaning and de-duplication).

    Returns:
        A ``pandas.DataFrame`` with at least the columns ``domain``, ``title``,
        ``header``, ``aspect``, ``context`` and ``query``.
    """
    # The cache is keyed on limit_rows; a single shared file would let a
    # limited run's output be silently reused by a full run (and vice versa).
    cache_path = _clean_df_cache_path(limit_rows)

    if USE_CACHE_CLEAN_DF and not FORCE_REBUILD_CLEAN_DF and cache_path.exists():
        print(f"✅ Loading cached clean dataframe from: {cache_path}")
        df = pd.read_parquet(cache_path)
        print("Cached clean rows:", len(df))
        return df

    print("📥 Loading raw dataset from hub...")
    ds = load_dataset(DATASET_NAME, split="train")
    if limit_rows is not None:
        ds = ds.select(range(min(limit_rows, len(ds))))

    df = ds.to_pandas()
    print("Raw shape:", df.shape)

    df = _clean_dataframe(df)
    print("Cleaned rows:", len(df))

    if USE_CACHE_CLEAN_DF:
        print(f"💾 Saving clean dataframe cache to: {cache_path}")
        df.to_parquet(cache_path, index=False)

    return df
|
|
|
|
def build_pair_dataset(df):
    """Build (or load from cache) the two-column ``query`` / ``positive`` dataset.

    Each dataframe row becomes exactly one pair. ``query`` is copied verbatim,
    and ``positive`` is rendered by ``build_positive_document``.

    A cached dataset at ``PAIR_DATASET_PATH`` is reused only if its row count
    equals ``len(df)``. The cache path does not record which dataframe produced
    it, so a cache left over from a different run (e.g. another ``LIMIT_ROWS``)
    would otherwise be reused silently. The row count is a cheap sanity check,
    not a full content fingerprint. On a mismatch the dataset is rebuilt and,
    if caching is enabled, the cache is overwritten.

    Args:
        df: Cleaned dataframe with at least ``query`` and ``context`` columns.

    Returns:
        A ``datasets.Dataset`` with columns ``query`` and ``positive``.
    """
    if (
        USE_CACHE_PAIR_DATASET
        and not FORCE_REBUILD_PAIR_DATASET
        and path_exists_and_nonempty(PAIR_DATASET_PATH)
    ):
        print(f"✅ Loading cached pair dataset from: {PAIR_DATASET_PATH}")
        cached = load_from_disk(str(PAIR_DATASET_PATH))
        if len(cached) == len(df):
            print("Cached pair dataset:", cached)
            return cached
        print(
            f"⚠️ Cached pair dataset has {len(cached)} rows but the dataframe has "
            f"{len(df)}; rebuilding."
        )
        # Release the handle on the cached files before they are overwritten below.
        del cached

    print("🛠 Building pair dataset...")
    pair_df = pd.DataFrame(
        {
            "query": df["query"].tolist(),
            "positive": df.apply(build_positive_document, axis=1).tolist(),
        }
    )
    dataset = Dataset.from_pandas(pair_df, preserve_index=False)

    if USE_CACHE_PAIR_DATASET:
        print(f"💾 Saving pair dataset cache to: {PAIR_DATASET_PATH}")
        dataset.save_to_disk(str(PAIR_DATASET_PATH))

    return dataset
|
|
|
|
| |
| |
| |
def mine_hard_negative_dataset(pair_dataset, model_dir):
    """Return the hard-negative n-tuple dataset, mining it only when no usable cache exists.

    Args:
        pair_dataset: Dataset with ``query`` and ``positive`` columns.
        model_dir: SentenceTransformer directory used to embed queries and passages.

    Returns:
        The dataset produced by ``mine_hard_negatives`` with
        ``output_format="n-tuple"``, or the cached copy from
        ``HARD_NEGATIVE_DATA_DIR``.
    """
    # NOTE(review): the cache is keyed only by its path. It does not record which
    # pair dataset or model_dir produced it, so set FORCE_REBUILD_HARDNEG after
    # changing either.
    cache_usable = (
        USE_CACHE_HARDNEG
        and not FORCE_REBUILD_HARDNEG
        and path_exists_and_nonempty(HARD_NEGATIVE_DATA_DIR)
    )
    if cache_usable:
        print(f"✅ Loading cached hard negative dataset from: {HARD_NEGATIVE_DATA_DIR}")
        cached = load_from_disk(str(HARD_NEGATIVE_DATA_DIR))
        print("Cached hard negative dataset:", cached)
        return cached

    print("⛏ Mining hard negatives...")
    miner = SentenceTransformer(model_dir)
    miner.max_seq_length = 512

    # Candidates come from similarity ranks range_min..range_max. relative_margin
    # filters out candidates scoring too close to the positive (see the
    # sentence-transformers mine_hard_negatives docs for exact semantics).
    mining_options = dict(
        positive_column_name="positive",
        range_min=10,
        range_max=50,
        relative_margin=0.05,
        num_negatives=3,
        sampling_strategy="random",
        batch_size=128,
        use_faiss=True,
        query_prompt_name="query",
        corpus_prompt_name="document",
        output_format="n-tuple",
        use_multi_process=True,
    )
    mined = mine_hard_negatives(dataset=pair_dataset, model=miner, **mining_options)

    if USE_CACHE_HARDNEG:
        print(f"💾 Saving hard negative dataset cache to: {HARD_NEGATIVE_DATA_DIR}")
        mined.save_to_disk(str(HARD_NEGATIVE_DATA_DIR))

    return mined
|
|
|
|
def _negative_columns(row):
    """Return the ``negative_<k>`` keys of *row*, ordered by their integer suffix."""
    prefix = "negative_"
    numbered = [
        key for key in row
        if key.startswith(prefix) and key[len(prefix):].isdecimal()
    ]
    return sorted(numbered, key=lambda key: int(key[len(prefix):]))


def preview_hard_negatives(hn_dataset, sample_size=10, max_chars=700, seed=42):
    """Print a few randomly chosen hard-negative rows for manual inspection.

    Args:
        hn_dataset: Indexable collection of row mappings (e.g. a
            ``datasets.Dataset`` or a list of dicts) with ``query`` and
            ``positive`` keys and any number of ``negative_<k>`` keys.
        sample_size: Maximum number of rows to print.
        max_chars: Each positive/negative text is truncated to this many characters.
        seed: Seed for row sampling, so repeated calls show the same rows.
    """
    import random

    n_rows = len(hn_dataset)
    if n_rows == 0:
        print("No hard negatives to preview.")
        return

    # Sample row indices directly. Converting the whole dataset with to_pandas()
    # would materialise every row in memory just to display a handful.
    rng = random.Random(seed)
    for idx in rng.sample(range(n_rows), min(sample_size, n_rows)):
        row = hn_dataset[idx]
        print("=" * 100)
        print("QUERY:\n", row["query"])
        print("\nPOSITIVE:\n", row["positive"][:max_chars])

        # Discover every negative column rather than assuming at most nine.
        for neg_col in _negative_columns(row):
            value = row[neg_col]
            if isinstance(value, str):
                print(f"\n{neg_col.upper()}:\n", value[:max_chars])

        print()
|
|
|
|
| |
| |
| |
def train_stage2_with_hardneg(hn_dataset, model_checkpoint_dir, output_dir):
    """Fine-tune a SentenceTransformer on hard-negative n-tuples and save the result.

    Args:
        hn_dataset: Training dataset. In this script it is the n-tuple output
            of ``mine_hard_negatives``.
        model_checkpoint_dir: SentenceTransformer checkpoint to start from.
        output_dir: Directory for trainer checkpoints and the final saved model.
            When ``RESUME_IF_POSSIBLE`` is set, the newest ``checkpoint-<step>``
            inside it is resumed.

    Raises:
        KeyError: If the model defines no prompt named ``TASK_NAME``.
    """
    train_model = SentenceTransformer(model_checkpoint_dir)
    train_model.max_seq_length = 512

    # Fail early with a descriptive message instead of a bare KeyError.
    if TASK_NAME not in train_model.prompts:
        raise KeyError(
            f"Prompt {TASK_NAME!r} not found in model prompts; "
            f"available: {sorted(train_model.prompts)}"
        )
    # A single string prompt is applied to every text column (query, positive, negatives).
    task_prompt = train_model.prompts[TASK_NAME]

    loss = CachedMultipleNegativesRankingLoss(
        train_model,
        mini_batch_size=32,
        gather_across_devices=True,
    )

    # Only request bf16 when the GPU reports bf16 support. On GPUs without it,
    # bf16=True makes the training-arguments validation raise instead of
    # training in fp32.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    training_args = SentenceTransformerTrainingArguments(
        prompts=task_prompt,
        torch_compile=False,
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=1024,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        warmup_ratio=0.1,
        bf16=use_bf16,
        logging_steps=50,
        save_strategy="epoch",
        report_to="none",
        remove_unused_columns=False,
        # In-batch negatives: duplicate texts in a batch would act as false negatives.
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        dataloader_num_workers=8,
        dataloader_persistent_workers=True,
        dataloader_drop_last=True,
        ddp_find_unused_parameters=False,
    )

    trainer = SentenceTransformerTrainer(
        model=train_model,
        args=training_args,
        train_dataset=hn_dataset,
        loss=loss,
    )

    resume_checkpoint = None
    if RESUME_IF_POSSIBLE:
        resume_checkpoint = get_last_checkpoint(output_dir)
        if resume_checkpoint is not None:
            print(f"🔁 Resuming from checkpoint: {resume_checkpoint}")
        else:
            print("ℹ️ No checkpoint found. Training from scratch.")

    trainer.train(resume_from_checkpoint=resume_checkpoint)

    print(f"💾 Saving final model to: {output_dir}")
    trainer.save_model(output_dir)
|
|
|
|
| |
| |
| |
def main():
    """Run the stage-2 pipeline end to end.

    The steps are:
    1. Load and clean the raw data.
    2. Build query/positive pairs.
    3. Mine hard negatives with the stage-1 model.
    4. Print a preview.
    5. Fine-tune and save the stage-2 model.
    """

    def _announce(label, sample_label, dataset):
        # Print the dataset summary and, when it has rows, its first example.
        print(label, dataset)
        if len(dataset) > 0:
            print(sample_label, dataset[0])

    clean_df = load_and_prepare_dataframe(limit_rows=LIMIT_ROWS)

    pairs = build_pair_dataset(clean_df)
    _announce("Pair dataset:", "Sample pair:", pairs)

    tuples = mine_hard_negative_dataset(pair_dataset=pairs, model_dir=BASE_MODEL_DIR)
    _announce("Hard negative dataset:", "Sample n-tuple:", tuples)

    preview_hard_negatives(tuples, sample_size=10)

    train_stage2_with_hardneg(
        hn_dataset=tuples,
        model_checkpoint_dir=BASE_MODEL_CHECKPOINT_DIR,
        output_dir=STAGE2_OUTPUT_DIR,
    )
|
|
|
|
# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|