| from datasets import load_dataset, Dataset |
| import pandas as pd |
| from sentence_transformers import ( |
| SentenceTransformer, |
| SentenceTransformerTrainer, |
| SentenceTransformerTrainingArguments, |
| ) |
| from sentence_transformers.losses import CachedMultipleNegativesRankingLoss |
|
|
|
|
def clean_text(x):
    """Coerce *x* to a whitespace-normalized string.

    None becomes ""; any other value is stringified and all runs of
    whitespace (including leading/trailing) collapse to single spaces.
    """
    if x is None:
        return ""
    # split() discards leading/trailing whitespace, so no explicit strip needed
    return " ".join(str(x).split())
|
|
|
|
def build_doc_fast(context):
    """Wrap a passage with the "text: " document prefix used for training."""
    prefix = "text: "
    return prefix + str(context)
|
|
|
|
def main():
    """Fine-tune google/embeddinggemma-300m on Vietnamese legal (query, passage) pairs.

    Pipeline: load the HF dataset -> clean/dedup in pandas -> build
    (anchor, positive) pairs -> train with cached MNRL -> save the model.
    """
    dataset_name = "phamson02/large-vi-legal-queries"

    # Load the full training split and move to pandas for cleaning.
    ds = load_dataset(dataset_name, split="train")
    df = ds.to_pandas()
    print("Raw shape:", df.shape)

    # Ensure every expected column exists, then normalize its whitespace.
    for col in ["domain", "title", "header", "aspect", "context", "query"]:
        if col not in df.columns:
            df[col] = ""
        df[col] = df[col].apply(clean_text)

    # Drop rows missing either side of the pair, then exact duplicate pairs.
    df = df[(df["query"] != "") & (df["context"] != "")]
    df = df.drop_duplicates(subset=["query", "context"]).reset_index(drop=True)
    print("Cleaned rows:", len(df))

    # (anchor, positive) pairs; negatives come from other in-batch positives.
    # Iterate the Series directly -- the intermediate .tolist() was redundant.
    train_df = pd.DataFrame(
        {
            "anchor": df["query"].tolist(),
            "positive": [build_doc_fast(context) for context in df["context"]],
        }
    )
    print(train_df.head())

    train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
    print(train_dataset[0])

    # NOTE(review): the original passed an empty model_kwargs={} dict;
    # dropped as dead code (equivalent to the default).
    model = SentenceTransformer("google/embeddinggemma-300m")
    model.max_seq_length = 512

    # Cached MNRL chunks the forward pass into mini-batches of 32 so the
    # large per-device batch (1024) fits in GPU memory while still
    # providing 1024-way in-batch negatives.
    loss = CachedMultipleNegativesRankingLoss(
        model,
        mini_batch_size=32,
        gather_across_devices=False,
    )

    task_name = "Retrieval"
    training_args = SentenceTransformerTrainingArguments(
        # Use the model's built-in retrieval prompts for anchors/positives.
        prompts=model.prompts[task_name],
        torch_compile=False,
        output_dir="./embeddinggemma-300m-vilegal",
        num_train_epochs=1,
        per_device_train_batch_size=1024,
        gradient_accumulation_steps=1,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        report_to="none",
        # Keep all dataset columns so the trainer can map them to the loss.
        remove_unused_columns=False,
        dataloader_num_workers=8,
        dataloader_persistent_workers=True,
        ddp_find_unused_parameters=False,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()

    # Persist the fine-tuned weights alongside the checkpoints.
    trainer.save_model("./embeddinggemma-300m-vilegal")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|