"""Stage-2 fine-tuning of a sentence-embedding model on mined hard negatives.

Pipeline (see ``main``):
  1. load and clean the query/context dataset,
  2. build (query, positive) pairs,
  3. mine hard negatives with the stage-1 model,
  4. preview a few mined samples,
  5. train stage 2 starting from the stage-1 checkpoint.

Intermediate artifacts are cached under ``CACHE_ROOT``.
"""

import gc
import os
from pathlib import Path
from typing import Optional

import pandas as pd
import torch
from datasets import load_dataset, Dataset, load_from_disk
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.util import mine_hard_negatives

# =========================
# CONFIG
# =========================
DATASET_NAME = "phamson02/large-vi-legal-queries"

# Stage 1 model dir
BASE_MODEL_DIR = "./embeddinggemma-300m-vilegal"
BASE_MODEL_CHECKPOINT_DIR = "./embeddinggemma-300m-vilegal"

# Stage 2 output dir
STAGE2_OUTPUT_DIR = "./embeddinggemma-300m-vilegal-stage2-hardneg"

# Cache dir
CACHE_ROOT = Path("./cache_vilegal_stage2")
CACHE_ROOT.mkdir(parents=True, exist_ok=True)

# Cached artifacts
# NOTE: these paths do not encode LIMIT_ROWS. After changing LIMIT_ROWS, set
# the FORCE_REBUILD_* flags (or delete CACHE_ROOT); otherwise the previously
# cached artifacts are reused as-is.
CLEAN_DF_PATH = CACHE_ROOT / "clean_df.parquet"
PAIR_DATASET_PATH = CACHE_ROOT / "pair_dataset"
HARD_NEGATIVE_DATA_DIR = CACHE_ROOT / "hardneg_dataset"

TASK_NAME = "Retrieval"

# Data control
LIMIT_ROWS = None  # set e.g. 20000 for quick tests

# Cache toggles
USE_CACHE_CLEAN_DF = True
USE_CACHE_PAIR_DATASET = True
USE_CACHE_HARDNEG = True

FORCE_REBUILD_CLEAN_DF = False
FORCE_REBUILD_PAIR_DATASET = False
FORCE_REBUILD_HARDNEG = False

# Training control
RESUME_IF_POSSIBLE = True


# =========================
# HELPERS
# =========================
def clean_text(x) -> str:
    """Return ``x`` as a string with whitespace trimmed and collapsed.

    ``None`` becomes the empty string; any other value is passed through
    ``str()`` first. Runs of whitespace are collapsed to single spaces.
    """
    if x is None:
        return ""
    return " ".join(str(x).split())


def build_positive_document(row) -> str:
    """Format the positive document text from a row's ``context`` field.

    ``row`` only needs to support ``row["context"]`` (a dict or a pandas row).
    """
    return f"text: {row['context']}"


def path_exists_and_nonempty(path: Path) -> bool:
    """Return True if ``path`` is a non-empty directory or an existing file."""
    if path.is_dir():
        # A directory only counts when it contains at least one entry.
        return any(path.iterdir())
    return path.exists()


def get_last_checkpoint(output_dir: str) -> Optional[str]:
    """Return the ``checkpoint-<step>`` sub-directory with the highest step.

    Returns ``None`` when ``output_dir`` does not exist or holds no directory
    whose name starts with ``checkpoint-`` and ends in an integer step.
    """
    output_path = Path(output_dir)
    if not output_path.exists():
        return None

    checkpoints = []
    for p in output_path.iterdir():
        if not (p.is_dir() and p.name.startswith("checkpoint-")):
            continue
        try:
            step = int(p.name.split("-")[-1])
        except ValueError:
            # Not a numbered checkpoint directory; ignore it.
            continue
        checkpoints.append((step, p))

    if not checkpoints:
        return None

    checkpoints.sort(key=lambda item: item[0])
    return str(checkpoints[-1][1])


# =========================
# DATA PREP
# =========================
def load_and_prepare_dataframe(limit_rows=None):
    """Load the raw dataset and return a cleaned, de-duplicated DataFrame.

    When the clean-dataframe cache is enabled and present it is returned
    directly; in that case ``limit_rows`` is NOT applied (the cache is used
    as stored). Otherwise the ``train`` split of ``DATASET_NAME`` is loaded,
    optionally truncated to the first ``limit_rows`` rows, text columns are
    normalized with ``clean_text``, rows with an empty query or context are
    dropped, and duplicate (query, context) pairs are removed.
    """
    if (
        USE_CACHE_CLEAN_DF
        and not FORCE_REBUILD_CLEAN_DF
        and CLEAN_DF_PATH.exists()
    ):
        print(f"✅ Loading cached clean dataframe from: {CLEAN_DF_PATH}")
        df = pd.read_parquet(CLEAN_DF_PATH)
        print("Cached clean rows:", len(df))
        return df

    print("📥 Loading raw dataset from hub...")
    ds = load_dataset(DATASET_NAME, split="train")

    if limit_rows is not None:
        ds = ds.select(range(min(limit_rows, len(ds))))

    df = ds.to_pandas()
    print("Raw shape:", df.shape)

    for col in ["domain", "title", "header", "aspect", "context", "query"]:
        # Missing columns are created empty so downstream code can rely on them.
        if col not in df.columns:
            df[col] = ""
        df[col] = df[col].apply(clean_text)

    df = df[(df["query"] != "") & (df["context"] != "")]
    df = df.drop_duplicates(subset=["query", "context"]).reset_index(drop=True)

    print("Cleaned rows:", len(df))

    if USE_CACHE_CLEAN_DF:
        print(f"💾 Saving clean dataframe cache to: {CLEAN_DF_PATH}")
        df.to_parquet(CLEAN_DF_PATH, index=False)

    return df


def build_pair_dataset(df):
    """Build (or load from cache) a ``Dataset`` with ``query``/``positive`` columns.

    ``positive`` is produced by ``build_positive_document`` from each row's
    ``context``.
    """
    if (
        USE_CACHE_PAIR_DATASET
        and not FORCE_REBUILD_PAIR_DATASET
        and path_exists_and_nonempty(PAIR_DATASET_PATH)
    ):
        print(f"✅ Loading cached pair dataset from: {PAIR_DATASET_PATH}")
        dataset = load_from_disk(str(PAIR_DATASET_PATH))
        print("Cached pair dataset:", dataset)
        return dataset

    print("🛠 Building pair dataset...")

    # Iterate the column directly instead of DataFrame.apply(axis=1): on an
    # empty frame apply() returns a DataFrame (no .tolist()), and row-wise
    # apply is needlessly slow on large frames.
    positives = [
        build_positive_document({"context": context})
        for context in df["context"]
    ]
    pair_df = pd.DataFrame(
        {
            "query": df["query"].tolist(),
            "positive": positives,
        }
    )

    dataset = Dataset.from_pandas(pair_df, preserve_index=False)

    if USE_CACHE_PAIR_DATASET:
        print(f"💾 Saving pair dataset cache to: {PAIR_DATASET_PATH}")
        dataset.save_to_disk(str(PAIR_DATASET_PATH))

    return dataset


# =========================
# HARD NEGATIVE MINING
# =========================
def mine_hard_negative_dataset(pair_dataset, model_dir):
    """Mine hard negatives for ``pair_dataset`` using the model at ``model_dir``.

    Returns the cached hard-negative dataset when the cache is enabled and
    non-empty; otherwise loads the model, runs ``mine_hard_negatives`` and
    (if caching is enabled) saves the result to ``HARD_NEGATIVE_DATA_DIR``.
    """
    if (
        USE_CACHE_HARDNEG
        and not FORCE_REBUILD_HARDNEG
        and path_exists_and_nonempty(HARD_NEGATIVE_DATA_DIR)
    ):
        print(f"✅ Loading cached hard negative dataset from: {HARD_NEGATIVE_DATA_DIR}")
        hn_dataset = load_from_disk(str(HARD_NEGATIVE_DATA_DIR))
        print("Cached hard negative dataset:", hn_dataset)
        return hn_dataset

    print("⛏ Mining hard negatives...")

    miner_model = SentenceTransformer(model_dir)
    miner_model.max_seq_length = 512

    hn_dataset = mine_hard_negatives(
        dataset=pair_dataset,
        model=miner_model,
        positive_column_name="positive",
        range_min=10,
        range_max=50,
        relative_margin=0.05,
        num_negatives=3,
        sampling_strategy="random",
        batch_size=128,
        use_faiss=True,
        query_prompt_name="query",
        corpus_prompt_name="document",
        output_format="n-tuple",
        use_multi_process=True,
    )

    # The miner is no longer needed; release it so its memory is not held
    # while the training model is loaded afterwards.
    del miner_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if USE_CACHE_HARDNEG:
        print(f"💾 Saving hard negative dataset cache to: {HARD_NEGATIVE_DATA_DIR}")
        hn_dataset.save_to_disk(str(HARD_NEGATIVE_DATA_DIR))

    return hn_dataset


def preview_hard_negatives(hn_dataset, sample_size=10):
    """Print up to ``sample_size`` randomly sampled rows of ``hn_dataset``.

    Sampling uses a fixed ``random_state`` so the preview is repeatable.
    Positive and negative texts are truncated to 700 characters.
    """
    if len(hn_dataset) == 0:
        print("No hard negatives to preview.")
        return

    sample_df = hn_dataset.to_pandas().sample(
        min(sample_size, len(hn_dataset)),
        random_state=42,
    )

    for _, row in sample_df.iterrows():
        print("=" * 100)
        print("QUERY:\n", row["query"])
        print("\nPOSITIVE:\n", row["positive"][:700])

        # Probe negative_1 .. negative_9; only string-valued columns that
        # exist are printed.
        for i in range(1, 10):
            neg_col = f"negative_{i}"
            if neg_col in row and isinstance(row[neg_col], str):
                print(f"\n{neg_col.upper()}:\n", row[neg_col][:700])

        print()


# =========================
# TRAINING
# =========================
def train_stage2_with_hardneg(hn_dataset, model_checkpoint_dir, output_dir):
    """Train stage 2 on ``hn_dataset`` starting from ``model_checkpoint_dir``.

    Resumes from the latest ``checkpoint-*`` directory in ``output_dir`` when
    ``RESUME_IF_POSSIBLE`` is set and one exists, then saves the final model
    to ``output_dir``.
    """
    # IMPORTANT: do not call .to("cuda") when launching with torchrun
    train_model = SentenceTransformer(model_checkpoint_dir)
    train_model.max_seq_length = 512

    loss = CachedMultipleNegativesRankingLoss(
        train_model,
        mini_batch_size=32,
        gather_across_devices=True,
    )

    # bf16 is only enabled when the device actually supports it; requesting
    # it on a CUDA GPU without bf16 support would fail at argument validation.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    training_args = SentenceTransformerTrainingArguments(
        # Raises KeyError if the model defines no prompt named TASK_NAME.
        # NOTE(review): a single prompt string is presumably applied to every
        # column, while mining uses separate query/document prompts — confirm
        # this is intended.
        prompts=train_model.prompts[TASK_NAME],
        torch_compile=False,
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=1024,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        warmup_ratio=0.1,
        bf16=use_bf16,
        logging_steps=50,
        save_strategy="epoch",
        report_to="none",
        remove_unused_columns=False,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        dataloader_num_workers=8,
        dataloader_persistent_workers=True,
        dataloader_drop_last=True,
        ddp_find_unused_parameters=False,
    )

    trainer = SentenceTransformerTrainer(
        model=train_model,
        args=training_args,
        train_dataset=hn_dataset,
        loss=loss,
    )

    resume_checkpoint = None
    if RESUME_IF_POSSIBLE:
        resume_checkpoint = get_last_checkpoint(output_dir)

    if resume_checkpoint is not None:
        print(f"🔁 Resuming from checkpoint: {resume_checkpoint}")
    else:
        print("ℹ️ No checkpoint found. Training from scratch.")

    trainer.train(resume_from_checkpoint=resume_checkpoint)

    print(f"💾 Saving final model to: {output_dir}")
    trainer.save_model(output_dir)


# =========================
# MAIN
# =========================
def main():
    """Run the full pipeline: clean, pair, mine, preview, train."""
    # NOTE(review): under torchrun every rank executes main(), so on a cold
    # cache each rank would mine and write the same cache paths. Presumably
    # the caches should be built once in a single-process run first — confirm.

    # 1) Load + clean
    df = load_and_prepare_dataframe(limit_rows=LIMIT_ROWS)

    # 2) Build query-positive dataset
    pair_dataset = build_pair_dataset(df)
    print("Pair dataset:", pair_dataset)
    if len(pair_dataset) > 0:
        print("Sample pair:", pair_dataset[0])

    # 3) Mine hard negatives from the stage 1 model
    hn_dataset = mine_hard_negative_dataset(
        pair_dataset=pair_dataset,
        model_dir=BASE_MODEL_DIR,
    )
    print("Hard negative dataset:", hn_dataset)
    if len(hn_dataset) > 0:
        print("Sample n-tuple:", hn_dataset[0])

    # 4) Preview a few hard-negative samples
    preview_hard_negatives(hn_dataset, sample_size=10)

    # 5) Train stage 2 from the stage 1 checkpoint
    train_stage2_with_hardneg(
        hn_dataset=hn_dataset,
        model_checkpoint_dir=BASE_MODEL_CHECKPOINT_DIR,
        output_dir=STAGE2_OUTPUT_DIR,
    )


if __name__ == "__main__":
    main()