# vilegal / hardneg_2.py
# Last commit: "hardneg script" by quockhangdev (619e569, verified)
import os
from pathlib import Path
import torch
import pandas as pd
from datasets import load_dataset, Dataset, load_from_disk
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.util import mine_hard_negatives
# =========================
# CONFIG
# =========================
# --- Data source -----------------------------------------------------------
# Hub dataset that is loaded (split "train") to build the query/context pairs.
DATASET_NAME = "phamson02/large-vi-legal-queries"
# Optional cap on the number of raw rows that are loaded.
LIMIT_ROWS = None # set e.g. 20000 for quick tests

# --- Models ----------------------------------------------------------------
# Stage 1 model: one path is used for mining, the other as the training start.
# Both currently point at the same directory.
BASE_MODEL_DIR = "./embeddinggemma-300m-vilegal"
BASE_MODEL_CHECKPOINT_DIR = "./embeddinggemma-300m-vilegal"
# Where stage 2 checkpoints and the final model are written.
STAGE2_OUTPUT_DIR = "./embeddinggemma-300m-vilegal-stage2-hardneg"
# Key looked up in the model's `prompts` mapping when building training args.
TASK_NAME = "Retrieval"

# --- Cache layout ----------------------------------------------------------
# The cache root is created as a side effect of importing this module.
CACHE_ROOT = Path("./cache_vilegal_stage2")
CACHE_ROOT.mkdir(parents=True, exist_ok=True)
CLEAN_DF_PATH = CACHE_ROOT.joinpath("clean_df.parquet")
PAIR_DATASET_PATH = CACHE_ROOT.joinpath("pair_dataset")
HARD_NEGATIVE_DATA_DIR = CACHE_ROOT.joinpath("hardneg_dataset")

# --- Cache toggles ---------------------------------------------------------
# USE_CACHE_* enables both reading and writing an artifact; the matching
# FORCE_REBUILD_* flag skips the read so the artifact is recomputed.
USE_CACHE_CLEAN_DF = True
FORCE_REBUILD_CLEAN_DF = False
USE_CACHE_PAIR_DATASET = True
FORCE_REBUILD_PAIR_DATASET = False
USE_CACHE_HARDNEG = True
FORCE_REBUILD_HARDNEG = False

# --- Training control ------------------------------------------------------
# When True, training looks for the newest checkpoint in the output directory.
RESUME_IF_POSSIBLE = True
# =========================
# HELPERS
# =========================
def clean_text(x):
    """Normalize a raw cell value into a single-line, whitespace-collapsed string.

    Args:
        x: Any value. Non-missing values are converted with ``str``.

    Returns:
        ``""`` when ``x`` is ``None`` or a float NaN; otherwise ``str(x)`` with
        leading/trailing whitespace removed and every internal run of
        whitespace collapsed to a single space.
    """
    # NaN is the only float that differs from itself. Without this check a
    # missing cell represented as NaN would turn into the literal text "nan"
    # and no longer compare equal to the empty string.
    if x is None or (isinstance(x, float) and x != x):
        return ""
    # split() with no arguments discards leading/trailing whitespace and splits
    # on any whitespace run, so a separate strip() is unnecessary.
    return " ".join(str(x).split())
def build_positive_document(row):
    """Format a record's ``context`` field as the positive passage text.

    Args:
        row: Any mapping-like object that supports ``row["context"]``.

    Returns:
        The context prefixed with ``"text: "``.
    """
    context = row["context"]
    return f"text: {context}"
def path_exists_and_nonempty(path: Path) -> bool:
    """Tell whether a cached artifact is present at ``path``.

    Args:
        path: File or directory location to inspect.

    Returns:
        For a directory: True only if it contains at least one entry.
        For anything else: True if the path exists.
    """
    if not path.is_dir():
        # Plain files (and missing paths) are judged on existence alone.
        return path.exists()
    # any() stops at the first directory entry, so the listing is not exhausted.
    return path.exists() and any(path.iterdir())
def get_last_checkpoint(output_dir: str):
    """Find the ``checkpoint-<step>`` subdirectory with the highest step.

    Args:
        output_dir: Directory that may contain checkpoint subdirectories.

    Returns:
        The path of the highest-step checkpoint directory as a string, or
        ``None`` if ``output_dir`` does not exist or holds no such directory.
    """
    root = Path(output_dir)
    if not root.exists():
        return None

    found = []
    for entry in root.iterdir():
        if not (entry.is_dir() and entry.name.startswith("checkpoint-")):
            continue
        # The step number is whatever follows the last dash in the name.
        suffix = entry.name.rpartition("-")[2]
        try:
            found.append((int(suffix), entry))
        except ValueError:
            # Suffix is not an integer (e.g. "checkpoint-tmp"): ignore it.
            continue

    if not found:
        return None

    ranked = sorted(found, key=lambda item: item[0])
    return str(ranked[-1][1])
# =========================
# DATA PREP
# =========================
def load_and_prepare_dataframe(limit_rows=None):
    """Return the cleaned query/context dataframe, from cache when allowed.

    If caching is enabled, no rebuild is forced and the parquet cache exists,
    the cached frame is returned as-is. Note that this check does not look at
    ``limit_rows``, so a cache written with a different limit is still reused.

    Otherwise the "train" split of ``DATASET_NAME`` is loaded, optionally
    truncated, its text columns are normalized with ``clean_text``, rows with
    an empty query or context are dropped and (query, context) duplicates are
    removed.

    Args:
        limit_rows: If not ``None``, keep at most this many leading rows of the
            raw dataset.

    Returns:
        The cleaned ``pandas.DataFrame`` with a fresh 0..n-1 index.
    """
    cache_usable = (
        USE_CACHE_CLEAN_DF
        and not FORCE_REBUILD_CLEAN_DF
        and CLEAN_DF_PATH.exists()
    )
    if cache_usable:
        print(f"✅ Loading cached clean dataframe from: {CLEAN_DF_PATH}")
        cached = pd.read_parquet(CLEAN_DF_PATH)
        print("Cached clean rows:", len(cached))
        return cached

    print("📥 Loading raw dataset from hub...")
    raw = load_dataset(DATASET_NAME, split="train")
    if limit_rows is not None:
        keep = min(limit_rows, len(raw))
        raw = raw.select(range(keep))

    frame = raw.to_pandas()
    print("Raw shape:", frame.shape)

    text_columns = ("domain", "title", "header", "aspect", "context", "query")
    for name in text_columns:
        # Columns absent from the dataset are created empty so later code can
        # rely on all of them being present.
        if name not in frame.columns:
            frame[name] = ""
        frame[name] = frame[name].apply(clean_text)

    has_query = frame["query"] != ""
    has_context = frame["context"] != ""
    frame = frame[has_query & has_context]
    frame = frame.drop_duplicates(subset=["query", "context"])
    frame = frame.reset_index(drop=True)
    print("Cleaned rows:", len(frame))

    if USE_CACHE_CLEAN_DF:
        print(f"💾 Saving clean dataframe cache to: {CLEAN_DF_PATH}")
        frame.to_parquet(CLEAN_DF_PATH, index=False)
    return frame
def build_pair_dataset(df):
    """Build (or load from cache) the two-column query/positive dataset.

    Args:
        df: Cleaned dataframe providing at least the ``query`` and ``context``
            columns.

    Returns:
        A ``datasets.Dataset`` with the columns ``query`` and ``positive``,
        where ``positive`` is produced by ``build_positive_document`` per row.
        When a cached copy is usable it is returned without reading ``df``.
    """
    cache_usable = (
        USE_CACHE_PAIR_DATASET
        and not FORCE_REBUILD_PAIR_DATASET
        and path_exists_and_nonempty(PAIR_DATASET_PATH)
    )
    if cache_usable:
        print(f"✅ Loading cached pair dataset from: {PAIR_DATASET_PATH}")
        cached = load_from_disk(str(PAIR_DATASET_PATH))
        print("Cached pair dataset:", cached)
        return cached

    print("🛠 Building pair dataset...")
    queries = df["query"].tolist()
    positives = df.apply(build_positive_document, axis=1).tolist()
    pair_frame = pd.DataFrame({"query": queries, "positive": positives})
    # preserve_index=False keeps the pandas index out of the dataset columns.
    pairs = Dataset.from_pandas(pair_frame, preserve_index=False)

    if USE_CACHE_PAIR_DATASET:
        print(f"💾 Saving pair dataset cache to: {PAIR_DATASET_PATH}")
        pairs.save_to_disk(str(PAIR_DATASET_PATH))
    return pairs
# =========================
# HARD NEGATIVE MINING
# =========================
def mine_hard_negative_dataset(pair_dataset, model_dir):
    """Mine hard negatives for every (query, positive) pair, with caching.

    Args:
        pair_dataset: Dataset holding a ``positive`` column plus the query
            column; also serves as the candidate pool for negatives.
        model_dir: Path or name of the SentenceTransformer used to score
            candidates.

    Returns:
        The dataset returned by ``mine_hard_negatives`` in "n-tuple" format,
        or the cached copy from ``HARD_NEGATIVE_DATA_DIR`` when usable (in
        which case neither argument is used).
    """
    cache_usable = (
        USE_CACHE_HARDNEG
        and not FORCE_REBUILD_HARDNEG
        and path_exists_and_nonempty(HARD_NEGATIVE_DATA_DIR)
    )
    if cache_usable:
        print(f"✅ Loading cached hard negative dataset from: {HARD_NEGATIVE_DATA_DIR}")
        cached = load_from_disk(str(HARD_NEGATIVE_DATA_DIR))
        print("Cached hard negative dataset:", cached)
        return cached

    print("⛏ Mining hard negatives...")
    miner_model = SentenceTransformer(model_dir)
    miner_model.max_seq_length = 512

    mining_options = dict(
        positive_column_name="positive",
        # Rank window of candidates considered for each query.
        range_min=10,
        range_max=50,
        relative_margin=0.05,
        # Three negatives per pair, drawn at random from the eligible ones.
        num_negatives=3,
        sampling_strategy="random",
        batch_size=128,
        use_faiss=True,
        # Names of the model prompts applied to queries and to passages.
        query_prompt_name="query",
        corpus_prompt_name="document",
        output_format="n-tuple",
        use_multi_process=True,
    )
    mined = mine_hard_negatives(
        dataset=pair_dataset,
        model=miner_model,
        **mining_options,
    )

    if USE_CACHE_HARDNEG:
        print(f"💾 Saving hard negative dataset cache to: {HARD_NEGATIVE_DATA_DIR}")
        mined.save_to_disk(str(HARD_NEGATIVE_DATA_DIR))
    return mined
def preview_hard_negatives(hn_dataset, sample_size=10):
    """Print a fixed-seed random sample of mined rows for eyeballing.

    For each sampled row the query, the first 700 characters of the positive
    and the first 700 characters of every string-valued ``negative_1`` ..
    ``negative_9`` column that is present are printed.

    Args:
        hn_dataset: Object supporting ``len()`` and ``to_pandas()``.
        sample_size: Maximum number of rows to print.

    Returns:
        None. Output goes to stdout only.
    """
    total = len(hn_dataset)
    if total == 0:
        print("No hard negatives to preview.")
        return

    frame = hn_dataset.to_pandas()
    # Fixed random_state so repeated runs preview the same rows.
    picked = frame.sample(min(sample_size, total), random_state=42)

    separator = "=" * 100
    negative_columns = [f"negative_{i}" for i in range(1, 10)]
    for _, row in picked.iterrows():
        print(separator)
        print("QUERY:\n", row["query"])
        print("\nPOSITIVE:\n", row["positive"][:700])
        for neg_col in negative_columns:
            if neg_col not in row:
                continue
            text = row[neg_col]
            # Skip non-string cells (e.g. missing values).
            if not isinstance(text, str):
                continue
            print(f"\n{neg_col.upper()}:\n", text[:700])
        print()
# =========================
# TRAINING
# =========================
def train_stage2_with_hardneg(hn_dataset, model_checkpoint_dir, output_dir):
    """Fine-tune the stage 1 model on the hard-negative dataset and save it.

    Args:
        hn_dataset: Training dataset produced by hard-negative mining.
        model_checkpoint_dir: Path or name of the SentenceTransformer to start
            from.
        output_dir: Directory for checkpoints and for the final saved model.

    Returns:
        None. The trained model is written to ``output_dir``.
    """
    # IMPORTANT: do not call .to("cuda") here when launching with torchrun.
    model = SentenceTransformer(model_checkpoint_dir)
    model.max_seq_length = 512

    loss = CachedMultipleNegativesRankingLoss(
        model,
        mini_batch_size=32,
        gather_across_devices=True,
    )

    # NOTE(review): this passes a single entry of the model's prompt mapping.
    # If that entry is one string it would presumably be applied to every
    # column, whereas mining used separate "query"/"document" prompts —
    # confirm this matches how stage 1 was trained.
    stage2_prompts = model.prompts[TASK_NAME]
    training_args = SentenceTransformerTrainingArguments(
        prompts=stage2_prompts,
        torch_compile=False,
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=1024,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        warmup_ratio=0.1,
        # bf16 is switched on whenever a CUDA device is visible.
        bf16=torch.cuda.is_available(),
        logging_steps=50,
        save_strategy="epoch",
        report_to="none",
        remove_unused_columns=False,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        dataloader_num_workers=8,
        dataloader_persistent_workers=True,
        dataloader_drop_last=True,
        ddp_find_unused_parameters=False,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=hn_dataset,
        loss=loss,
    )

    # Only look for an earlier checkpoint when resuming is enabled.
    resume_checkpoint = (
        get_last_checkpoint(output_dir) if RESUME_IF_POSSIBLE else None
    )
    if resume_checkpoint is None:
        print("ℹ️ No checkpoint found. Training from scratch.")
    else:
        print(f"🔁 Resuming from checkpoint: {resume_checkpoint}")

    trainer.train(resume_from_checkpoint=resume_checkpoint)

    print(f"💾 Saving final model to: {output_dir}")
    trainer.save_model(output_dir)
# =========================
# MAIN
# =========================
def main():
    """Run the full stage 2 pipeline: clean, pair, mine, preview, train.

    NOTE(review): every process that runs this script executes all steps,
    including the cache writes — confirm that is safe for multi-process
    launches.
    """
    # 1) Load and clean the raw data.
    clean_df = load_and_prepare_dataframe(limit_rows=LIMIT_ROWS)

    # 2) Build the query-positive dataset.
    pairs = build_pair_dataset(clean_df)
    print("Pair dataset:", pairs)
    if len(pairs) > 0:
        print("Sample pair:", pairs[0])

    # 3) Mine hard negatives with the stage 1 model.
    mined = mine_hard_negative_dataset(
        pair_dataset=pairs,
        model_dir=BASE_MODEL_DIR,
    )
    print("Hard negative dataset:", mined)
    if len(mined) > 0:
        print("Sample n-tuple:", mined[0])

    # 4) Preview a few hard-negative samples.
    preview_hard_negatives(mined, sample_size=10)

    # 5) Train stage 2 starting from the stage 1 checkpoint.
    train_stage2_with_hardneg(
        hn_dataset=mined,
        model_checkpoint_dir=BASE_MODEL_CHECKPOINT_DIR,
        output_dir=STAGE2_OUTPUT_DIR,
    )


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()