from datasets import load_dataset, Dataset
import pandas as pd
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss


def clean_text(x):
    if x is None:
        return ""
    x = str(x).strip()
    x = " ".join(x.split())
    return x


def build_doc_fast(context):
    return f"text: {context}"


def main():
    dataset_name = "phamson02/large-vi-legal-queries"

    # Load + clean
    ds = load_dataset(dataset_name, split="train")
    df = ds.to_pandas()
    print("Raw shape:", df.shape)

    for col in ["domain", "title", "header", "aspect", "context", "query"]:
        if col not in df.columns:
            df[col] = ""
        df[col] = df[col].apply(clean_text)

    df = df[(df["query"] != "") & (df["context"] != "")]
    df = df.drop_duplicates(subset=["query", "context"]).reset_index(drop=True)
    print("Cleaned rows:", len(df))

    train_df = pd.DataFrame(
        {
            "anchor": df["query"].tolist(),
            "positive": [build_doc_fast(context) for context in df["context"].tolist()],
        }
    )
    print(train_df.head())

    train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
    print(train_dataset[0])

    # IMPORTANT: no .to("cuda") here under torchrun / DDP
    model = SentenceTransformer(
        "google/embeddinggemma-300m",
        model_kwargs={
            # "torch_dtype": "auto",
            # "attn_implementation": "flash_attention_2",
        },
    )
    model.max_seq_length = 512

    loss = CachedMultipleNegativesRankingLoss(
        model,
        mini_batch_size=32,
        gather_across_devices=False,
    )

    task_name = "Retrieval"
    training_args = SentenceTransformerTrainingArguments(
        prompts=model.prompts[task_name],
        torch_compile=False,
        output_dir="./embeddinggemma-300m-vilegal",
        num_train_epochs=1,
        per_device_train_batch_size=1024,
        gradient_accumulation_steps=1,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=8,
        dataloader_persistent_workers=True,
        # Often helpful for DDP stability/perf with Transformer training:
        ddp_find_unused_parameters=False,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()

    # Save only once from the main process
    trainer.save_model("./embeddinggemma-300m-vilegal")


if __name__ == "__main__":
    main()
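
# Launch sketch (assumptions: this file is saved as train.py and the machine
# has multiple GPUs; adjust --nproc_per_node to your GPU count). The script
# also runs single-GPU with a plain `python train.py`:
#
#   torchrun --nproc_per_node=4 train.py
#
# Quick smoke test after training (a minimal sketch; the example query string
# is illustrative, and encoding prompts come from the prompt config saved with
# the model):
#
#   from sentence_transformers import SentenceTransformer
#   m = SentenceTransformer("./embeddinggemma-300m-vilegal")
#   emb = m.encode(["ví dụ truy vấn pháp lý"])
#   print(emb.shape)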