Instructions to use nicher92/saga-embed_v1 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
  - sentence-transformers
How to use nicher92/saga-embed_v1 with sentence-transformers:
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("nicher92/saga-embed_v1")

sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
]
embeddings = model.encode(sentences)

similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)  # [3, 3]
```

A sketch of prompted retrieval follows the notebook links below.

- Notebooks
  - Google Colab
  - Kaggle
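
The training script below prepends task-specific prefixes (e.g. `task: search result | query: ` for queries and `title: none | text: ` for documents). Here is a minimal sketch of prompted retrieval, under the assumption that the same prefixes should be used at inference time; the Swedish sentences are illustrative:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("nicher92/saga-embed_v1")

# Asymmetric retrieval: queries and documents get different prefixes,
# mirroring the prompts in the fine-tuning script below.
query = "task: search result | query: Vilken är Sveriges största sjö?"
docs = [
    "title: none | text: Vänern är Sveriges största sjö.",
    "title: none | text: Stockholm är Sveriges huvudstad.",
]

query_emb = model.encode([query])
doc_embs = model.encode(docs)
print(model.similarity(query_emb, doc_embs))  # the first document should score higher
```

The fine-tuning script: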
```python
# DID NOT USE THE NEW MSMARCO DATASET, OTHERWISE EVERYTHING IS THE SAME
import logging
import os
import argparse
import config  # Assuming your config file is available
import json
import torch
import pandas as pd

os.environ["HF_HOME"] = config.CACHE_DIR

from pathlib import Path
from datasets import load_dataset, load_from_disk, concatenate_datasets, Dataset, DatasetDict
from sentence_transformers.evaluation import (
    SequentialEvaluator,
    EmbeddingSimilarityEvaluator,
    InformationRetrievalEvaluator,
    TripletEvaluator,
)
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import (
    SentenceTransformerTrainingArguments,
    BatchSamplers,
    MultiDatasetBatchSamplers,
)
from sentence_transformers import SentenceTransformer, losses, models
import transformers
from src.custom_loss.CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeightAndMask import CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeightAndMask
from src.custom_loss.CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeight import CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeight
import random
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import v_measure_score, adjusted_rand_score, accuracy_score, f1_score
from sentence_transformers.evaluation import SentenceEvaluator
from sklearn.linear_model import LogisticRegression

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
transformers.logging.set_verbosity_info()

# --- MINED DATASET CONFIGURATION ---
MINED_DATASETS_BASE_PATH = Path("/mnt/disk2/translated_datasets")

# Map dataset names to their task categories
MINED_DATASET_CONFIG = {
    "gooaq_train_swedish_triplets_scored": "question answering",
    "reddit_train_swedish_triplets_scored": "retrieval",
    "xsum_train_swedish_triplets_scored": "clustering",
    "simple-wiki_train_swedish_triplets_scored": "semantic similarity",
    "s2orc_train_swedish_saved_triplets_scored": "retrieval",
    "amazon-reviews_train_swedish_saved_triplets_scored": "retrieval",
    "paq_train_swedish_queries_retranslated_triplets_scored": "question answering",
    "stackexchange-duplicates_train_swedish_triplets_scored": "semantic similarity",
    "wikipedia-sections_train_swedish_triplets_scored": "retrieval",
    "msmarco_triplets_swedish_triplets_scored": "retrieval",
}

synthetic_data_path = "/mnt/disk2/combined_synthetic_data_deduplicated_classified"
# synthetic_classification_data_path = "/mnt/disk2/classification_dataset_graded_subset.jsonl"
# combined_synthetic_data_deduplicated_classified

# Maximum number of negatives to use per example from mined datasets.
# The loss function (CachedMultipleNegativesRankingLoss) can handle multiple negatives.
# Format: (anchor, positive, negative_1, negative_2, ..., negative_n)
MAX_NEGATIVES_PER_EXAMPLE = 10
MAX_SAMPLES = 500_000
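
# For illustration (hypothetical values), a single training row after negative
# expansion and padding looks like:
#   anchor:      "Vilken är Sveriges största sjö?"
#   positive:    "Vänern är Sveriges största sjö."
#   negative_1:  "Stockholm är Sveriges huvudstad."
#   ...
#   negative_10: ""   # missing negatives are padded with empty strings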
NANOBEIR_DATASETS = [
    "NanoArguAna",
    "NanoClimateFEVER",
    "NanoDBPedia",
    "NanoFEVER",
    "NanoFiQA2018",
    "NanoHotpotQA",
    "NanoMSMARCO",
    "NanoNFCorpus",
    "NanoNQ",
    "NanoQuoraRetrieval",
    "NanoSCIDOCS",
    "NanoSciFact",
    "NanoTouche2020",
]

# Map dataset names to task types for appropriate prompting
NANOBEIR_TASK_TYPES = {
    "NanoArguAna": "retrieval",
    "NanoClimateFEVER": "retrieval",
    "NanoDBPedia": "retrieval",
    "NanoFEVER": "retrieval",
    "NanoFiQA2018": "retrieval",
    "NanoHotpotQA": "retrieval",
    "NanoMSMARCO": "retrieval",
    "NanoNFCorpus": "retrieval",
    "NanoNQ": "question answering",
    "NanoQuoraRetrieval": "semantic similarity",
    "NanoSCIDOCS": "retrieval",
    "NanoSciFact": "retrieval",
    "NanoTouche2020": "retrieval",
}


class ProbeClassificationEvaluator(SentenceEvaluator):
    """
    Generic evaluator that trains a logistic regression probe on train_set
    and evaluates accuracy on test_set.
    """

    def __init__(self, dataset_name, sentences_train, labels_train, sentences_test, labels_test, batch_size=32, prompt_prefix=""):
        self.name = f"eval_{dataset_name}_classification"
        self.sentences_train = [prompt_prefix + s for s in sentences_train]
        self.labels_train = labels_train
        self.sentences_test = [prompt_prefix + s for s in sentences_test]
        self.labels_test = labels_test
        self.batch_size = batch_size

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        # 1. Encode
        emb_train = model.encode(self.sentences_train, batch_size=self.batch_size, show_progress_bar=False)
        emb_test = model.encode(self.sentences_test, batch_size=self.batch_size, show_progress_bar=False)
        # 2. Train probe (fast logistic regression)
        clf = LogisticRegression(random_state=42, solver='lbfgs', max_iter=100, n_jobs=-1)
        clf.fit(emb_train, self.labels_train)
        # 3. Predict
        preds = clf.predict(emb_test)
        acc = accuracy_score(self.labels_test, preds)
        logging.info(f"{self.name}: Accuracy = {acc:.4f}")
        return acc


def load_swedish_reviews_evaluator(prompt_style="standard"):
    """
    Loads 'timpal0l/swedish_reviews' for binary sentiment classification.
    Safe from MTEB leakage.
    """
    try:
        logging.info("Loading timpal0l/swedish_reviews for Sentiment Probe...")
        # Load the test split (approx 10k samples, good for eval)
        dataset = load_dataset("timpal0l/swedish_reviews", split="test")
        # Downsample for speed (keep 50 samples; half train the probe, half test it)
        if len(dataset) > 50:
            dataset = dataset.shuffle(seed=42).select(range(50))
        sentences = dataset["text"]
        labels = dataset["label"]
        # Define prompt
        if prompt_style == "standard":
            # "task: classification" is the standard prompt for this
            prompt = "task: classification | query: "
        else:
            prompt = ""
        # Split 50/50 for train/test probe:
        # we need to train the logistic regression probe on *some* data to test the embeddings
        split_idx = len(sentences) // 2
        evaluator = ProbeClassificationEvaluator(
            dataset_name="SwedishReviews-Sentiment",
            sentences_train=sentences[:split_idx],
            labels_train=labels[:split_idx],
            sentences_test=sentences[split_idx:],
            labels_test=labels[split_idx:],
            prompt_prefix=prompt,
            batch_size=32,
        )
        return evaluator
    except Exception as e:
        logging.warning(f"Failed to load Swedish Reviews: {e}")
        return None


class CustomJSONLClassificationEvaluator(SentenceEvaluator):
    """
    Reads a JSONL file and trains a logistic regression probe to predict 'main_title'.
    This checks whether the topics are linearly separable in the embedding space.
    """

    def __init__(self, file_path, min_samples_per_label=5, max_classes=20, batch_size=32):
        self.name = "eval_wiki_classification"
        self.batch_size = batch_size
        self.prompt = "task: classification | query: "
        self.sentences_train = []
        self.labels_train = []
        self.sentences_test = []
        self.labels_test = []
        logging.info(f"Loading classification data from {file_path}...")
        # 1. Load data by class
        label_map = {}  # { "Afrika": ["text1", "text2", ...] }
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    row = json.loads(line)
                    text = row.get('text', '').strip()
                    label = row.get('main_title', '').strip()
                    if text and label:
                        if label not in label_map:
                            label_map[label] = []
                        label_map[label].append(text)
                except Exception:
                    # Skip malformed lines
                    continue
        # 2. Filter classes
        valid_labels = [l for l, texts in label_map.items() if len(texts) >= min_samples_per_label]
        valid_labels.sort()  # Deterministic order
        # 3. Subsample classes (avoid training on 5000 classes)
        rng = random.Random(42)
        if len(valid_labels) > max_classes:
            selected_labels = rng.sample(valid_labels, max_classes)
        else:
            selected_labels = valid_labels
        logging.info(f"Classification Probe: Using {len(selected_labels)} classes (topics).")
        # 4. Create train/test split (80/20 per class)
        for label in selected_labels:
            texts = label_map[label]
            # Shuffle texts for this label
            rng.shuffle(texts)
            # Use max 50 samples per class to keep the probe fast
            texts = texts[:50]
            split_idx = int(0.8 * len(texts))
            # Train set
            for t in texts[:split_idx]:
                self.sentences_train.append(self.prompt + t)
                self.labels_train.append(label)
            # Test set
            for t in texts[split_idx:]:
                self.sentences_test.append(self.prompt + t)
                self.labels_test.append(label)
        logging.info(f"Probe Sizes: Train={len(self.sentences_train)}, Test={len(self.sentences_test)}")

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        if not self.sentences_train:
            return 0.0
        # 1. Encode
        emb_train = model.encode(self.sentences_train, batch_size=self.batch_size, show_progress_bar=False)
        emb_test = model.encode(self.sentences_test, batch_size=self.batch_size, show_progress_bar=False)
        # 2. Train probe (fast logistic regression)
        clf = LogisticRegression(random_state=42, solver='lbfgs', max_iter=100, n_jobs=-1)
        clf.fit(emb_train, self.labels_train)
        # 3. Predict & score
        preds = clf.predict(emb_test)
        acc = accuracy_score(self.labels_test, preds)
        # Macro F1 is better if classes are imbalanced
        f1 = f1_score(self.labels_test, preds, average='macro')
        logging.info(f"{self.name}: Accuracy={acc:.4f} | F1-Macro={f1:.4f}")
        return acc


class CustomJSONLClusteringEvaluator(SentenceEvaluator):
    """
    Evaluates clustering on a specific number of distinct Wikipedia topics.
    """

    def __init__(self, file_path, min_samples_per_topic=5, max_clusters=50, batch_size=32):
        self.name = "eval_wiki_clustering"
        self.batch_size = batch_size
        self.prompt = "task: clustering | query: "
        self.sentences = []
        self.labels = []
        logging.info(f"Loading clustering data from {file_path}...")
        # 1. Load ALL data into memory organized by topic
        # Structure: { "Afrika": ["text1", "text2", ...], "Amager": ["text1", ...] }
        topic_map = {}
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    row = json.loads(line)
                    text = row.get('text', '').strip()
                    # Use 'main_title' as the topic label
                    label = row.get('main_title', '').strip()
                    if text and label:
                        if label not in topic_map:
                            topic_map[label] = []
                        topic_map[label].append(text)
                except Exception:
                    # Skip malformed lines
                    continue
        # 2. Filter topics that are too small:
        # we need topics with enough content to actually cluster meaningful points
        valid_topics = [
            topic for topic, texts in topic_map.items()
            if len(texts) >= min_samples_per_topic
        ]
        logging.info(f"Found {len(valid_topics)} valid topics with >= {min_samples_per_topic} samples.")
        # 3. Sample exactly 'max_clusters' topics (e.g., 50).
        # We sort first to make the random seed deterministic across runs.
        valid_topics.sort()
        # Use a fixed seed so the "random" topics are the same every time you restart training
        rng = random.Random(42)
        if len(valid_topics) > max_clusters:
            selected_topics = rng.sample(valid_topics, max_clusters)
            logging.info(f"Subsampled to {max_clusters} distinct topics for evaluation.")
        else:
            selected_topics = valid_topics
            logging.info(f"Using all {len(selected_topics)} topics (fewer than max_clusters).")
        # 4. Flatten the data for the evaluator
        for topic in selected_topics:
            # You can also limit samples per topic here (e.g., max 20 paragraphs per topic) to keep it balanced
            texts = topic_map[topic][:20]
            for text in texts:
                self.sentences.append(self.prompt + text)
                self.labels.append(topic)
        self.n_clusters = len(selected_topics)
        logging.info(f"Final Clustering Probe: {len(self.sentences)} samples across {self.n_clusters} topics.")

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        if not self.sentences:
            return 0.0
        # Encode
        embeddings = model.encode(self.sentences, batch_size=self.batch_size, show_progress_bar=False)
        # Cluster
        clustering = MiniBatchKMeans(
            n_clusters=self.n_clusters,
            batch_size=256,
            random_state=42,
            n_init='auto',
        )
        clustering.fit(embeddings)
        # Score
        v_score = v_measure_score(self.labels, clustering.labels_)
        ari_score = adjusted_rand_score(self.labels, clustering.labels_)
        logging.info(f"{self.name}: V-Measure={v_score:.4f} | ARI={ari_score:.4f}")
        return v_score


def load_nanobeir_evaluator(dataset_name, language, task_type="retrieval"):
    """
    Load a NanoBEIR dataset for a specific language and create an InformationRetrievalEvaluator.

    Args:
        dataset_name: Name of the NanoBEIR dataset (e.g., "NanoMSMARCO")
        language: Language code ("sv" for Swedish, "no" for Norwegian)
        task_type: Task type for prompting ("retrieval", "question answering", "semantic similarity")

    Returns:
        InformationRetrievalEvaluator or None if loading fails
    """
    try:
        logging.info(f"Loading {dataset_name} for language: {language}")
        # Load corpus (documents)
        corpus_data = load_dataset(
            "lightonai/nanobeir-multilingual",
            f"{dataset_name}_{language}",
            split="corpus",
        )
        # Load queries
        queries_data = load_dataset(
            "lightonai/nanobeir-multilingual",
            f"{dataset_name}_{language}",
            split="queries",
        )
        # Load qrels (relevance judgments) - language independent
        qrels_data = load_dataset(
            "lightonai/nanobeir-multilingual",
            dataset_name,
            split="qrels",
        )
        # Apply task-specific prompts
        if task_type == "retrieval":
            query_prompt = "task: search result | query: "
            doc_prompt = "title: none | text: "
        elif task_type == "question answering":
            query_prompt = "task: question answering | query: "
            doc_prompt = "title: none | text: "
        elif task_type == "semantic similarity":
            query_prompt = "task: semantic similarity | query: "
            doc_prompt = "task: semantic similarity | query: "
        else:
            query_prompt = ""
            doc_prompt = ""
        # Build corpus dictionary with prompts
        corpus = {}
        for item in corpus_data:
            doc_id = item['_id']
            # Combine title and text if title exists
            if 'title' in item and item['title']:
                text = f"{item['title']} {item['text']}"
            else:
                text = item['text']
            corpus[doc_id] = doc_prompt + text
        # Build queries dictionary with prompts
        queries = {}
        for item in queries_data:
            query_id = item['_id']
            queries[query_id] = query_prompt + item['text']
        # Build relevance dictionary.
        # Note: NanoBEIR qrels only have query-id and corpus-id (no score column),
        # so all entries in qrels are considered relevant.
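        # Illustrative qrels row (hypothetical IDs): {"query-id": "q0", "corpus-id": "doc123"}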
        relevant_docs = {}
        for item in qrels_data:
            query_id = item['query-id']
            corpus_id = item['corpus-id']
            if query_id not in relevant_docs:
                relevant_docs[query_id] = set()
            relevant_docs[query_id].add(corpus_id)
        # Filter queries to only those with relevant documents
        queries = {qid: text for qid, text in queries.items() if qid in relevant_docs}
        if not queries:
            logging.warning(f"No queries with relevant documents found for {dataset_name}_{language}")
            return None
        # Create evaluator
        evaluator_name = f"{dataset_name}-{language}-dev"
        evaluator = InformationRetrievalEvaluator(
            queries=queries,
            corpus=corpus,
            relevant_docs=relevant_docs,
            name=evaluator_name,
            # Restrict each metric to a single k value to reduce metric clutter;
            # NDCG@10 is the headline metric.
            mrr_at_k=[10],
            ndcg_at_k=[10],
            accuracy_at_k=[1],
            precision_recall_at_k=[1],
            map_at_k=[100],
        )
        logging.info(f"Created {evaluator_name} with {len(queries)} queries and {len(corpus)} documents")
        return evaluator
    except Exception as e:
        logging.error(f"Failed to load {dataset_name} for {language}: {e}")
        return None


def create_nanobeir_evaluators(languages=["sv", "no"], dataset_names=None):
    """
    Create InformationRetrievalEvaluators for multiple NanoBEIR datasets and languages.

    Args:
        languages: List of language codes (default: ["sv", "no"])
        dataset_names: List of dataset names to use (default: all NANOBEIR_DATASETS)

    Returns:
        List of InformationRetrievalEvaluators
    """
    if dataset_names is None:
        dataset_names = NANOBEIR_DATASETS
    evaluators = []
    for language in languages:
        logging.info(f"\n=== Loading NanoBEIR datasets for {language.upper()} ===")
        for dataset_name in dataset_names:
            task_type = NANOBEIR_TASK_TYPES.get(dataset_name, "retrieval")
            evaluator = load_nanobeir_evaluator(dataset_name, language, task_type)
            if evaluator is not None:
                evaluators.append(evaluator)
    logging.info(f"\nSuccessfully created {len(evaluators)} NanoBEIR evaluators")
    return evaluators


def handle_none_negatives(example):
    """Replaces None in the 'negative' column with an empty string."""
    if example["negative"] is None:
        example["negative"] = ""
    return example


def is_good_or_excellent(example):
    """Filter function to keep only good or excellent graded examples."""
    if "dialect" in example["task_description"].lower():
        return False
    return example['grade'] in ['good', 'excellent']


def prepare_triplet_eval_data(dev_set):
    """
    Prepares anchors, positives, and negatives for a TripletEvaluator.
    Extracts raw data WITHOUT prompts - prompts will be applied separately.
    Handles both 'negative' and 'negative_1' column names.
    """
    anchors = dev_set["anchor"]
    positives = dev_set["positive"]
    # Handle both column naming conventions
    if "negative" in dev_set.column_names:
        negatives = dev_set["negative"]
    elif "negative_1" in dev_set.column_names:
        negatives = dev_set["negative_1"]
    else:
        raise ValueError(f"No negative column found. Available columns: {dev_set.column_names}")
    return anchors, positives, negatives


def load_clustering_dataset(dataset_path, max_negatives=5):
    """
    Loads the pre-saved clustering dataset from disk.
    Applies clustering prompts and handles schema consistency for DDP.
    """
    if not os.path.exists(dataset_path):
        logging.warning(f"Clustering data not found at: {dataset_path}")
        return None
    logging.info(f"Loading clustering dataset from disk: {dataset_path}")
    dataset = load_from_disk(str(dataset_path))

    # 1. Ensure we don't have None values (tokenizer crash protection)
    def fill_empty(example):
        for key in example.keys():
            if example[key] is None:
                example[key] = ""
        return example

    dataset = dataset.map(fill_empty)
    # 2. Add clustering category and apply prompts BEFORE padding
    dataset = dataset.add_column("new_category", ["clustering"] * len(dataset))
    dataset = dataset.map(apply_task_prompt)
    dataset = dataset.remove_columns(['new_category'])
    # 3. Enforce exact column schema for DDP - pad AFTER applying prompts
    desired_cols = ["anchor", "positive"] + [f"negative_{i+1}" for i in range(max_negatives)]
    # Pad if we have fewer negatives than the global config
    for col in desired_cols:
        if col not in dataset.column_names:
            dataset = dataset.add_column(col, [""] * len(dataset))
    # Select only the needed columns (drops extra negatives if you have > max)
    dataset = dataset.select_columns(desired_cols)
    logging.info(f"Loaded {len(dataset)} clustering examples with prompts applied.")
    return dataset


def apply_task_prompt(example, dropout_rate=0.1):
    """
    Applies strict task-specific prompts based on the provided schema.
    Skips empty strings to avoid adding prompts to padding columns.
    """
    task_type = example["new_category"]
    if random.random() < dropout_rate:
        return example
    # --- 1. RETRIEVAL (asymmetric) ---
    # Anchor: "task: search result | query: {content}"
    # Docs:   "title: none | text: {content}"
    if task_type == "retrieval":
        if example['anchor']:  # Only apply if not empty
            example['anchor'] = "task: search result | query: " + example['anchor']
        # Define the document prompt (assuming no title column available, using 'none')
        doc_prompt = "title: none | text: "
        if example['positive']:  # Only apply if not empty
            example['positive'] = doc_prompt + example['positive']
        # Apply to all negatives
        for key in list(example.keys()):
            if key.startswith('negative') and example[key]:  # This already checks for non-empty
                example[key] = doc_prompt + example[key]
    # --- 2. QUESTION ANSWERING (symmetric) ---
    # Same prompt for anchor and positive.
    # Prompt: "task: question answering | query: {content}"
    elif task_type == "question answering":
        instruct = "task: question answering | query: "
        if example['anchor']:
            example['anchor'] = instruct + example['anchor']
        if example['positive']:
            example['positive'] = instruct + example['positive']
        for key in list(example.keys()):
            if key.startswith('negative') and example[key]:
                example[key] = instruct + example[key]
    # --- 3. CLUSTERING (symmetric) ---
    # Prompt: "task: clustering | query: {content}"
    elif task_type == "clustering":
        instruct = "task: clustering | query: "
        if example['anchor']:
            example['anchor'] = instruct + example['anchor']
        if example['positive']:
            example['positive'] = instruct + example['positive']
        for key in list(example.keys()):
            if key.startswith('negative') and example[key]:
                example[key] = instruct + example[key]
    # --- 4. CLASSIFICATION (symmetric) ---
    # Prompt: "task: classification | query: {content}"
    elif task_type == "classification":
        instruct = "task: classification | query: "
        if example['anchor']:
            example['anchor'] = instruct + example['anchor']
        if example['positive']:
            example['positive'] = instruct + example['positive']
        for key in list(example.keys()):
            if key.startswith('negative') and example[key]:
                example[key] = instruct + example[key]
    # --- 5. SEMANTIC SIMILARITY (symmetric) ---
    # Prompt: "task: semantic similarity | query: {content}"
    elif task_type == "semantic similarity":
        instruct = "task: semantic similarity | query: "
        if example['anchor']:
            example['anchor'] = instruct + example['anchor']
        if example['positive']:
            example['positive'] = instruct + example['positive']
        for key in list(example.keys()):
            if key.startswith('negative') and example[key]:
                example[key] = instruct + example[key]
    # Fallback
    else:
        logging.warning(f"Unknown task category: {task_type}. No prompt applied.")
    return example
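
# Illustrative behavior of apply_task_prompt (values are hypothetical):
#   in:  {"new_category": "retrieval", "anchor": "bästa kaffet i Göteborg",
#         "positive": "Kafé X i Göteborg serverar ...", "negative_1": "Recept på kaffebröd ..."}
#   out: anchor     -> "task: search result | query: bästa kaffet i Göteborg"
#        positive   -> "title: none | text: Kafé X i Göteborg serverar ..."
#        negative_1 -> "title: none | text: Recept på kaffebröd ..."
# (With probability dropout_rate, the example is returned without any prompt.)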

def load_mined_dataset(dataset_name, task_category, max_samples=None, max_negatives=10):
    """
    Load a single mined dataset and prepare it for training.

    Args:
        dataset_name: Name of the dataset directory
        task_category: Task category for prompting
        max_samples: Optional limit on the number of samples to load
        max_negatives: Maximum number of negatives to include per example (default: 10)

    Returns:
        Processed dataset with prompts applied and multiple negatives per example
    """
    dataset_path = MINED_DATASETS_BASE_PATH / dataset_name
    if not dataset_path.exists():
        logging.warning(f"Dataset path does not exist: {dataset_path}. Skipping.")
        return None
    logging.info(f"Loading mined dataset from: {dataset_path}")
    dataset = load_from_disk(str(dataset_path))
    if max_samples and len(dataset) > max_samples:
        logging.info(f"Sampling {max_samples} from {len(dataset)} samples")
        dataset = dataset.shuffle(seed=42).select(range(max_samples))
    logging.info(f"Loaded {len(dataset)} samples from {dataset_name}")

    # The mined datasets have the structure: anchor, positive, pos_score, negatives (list), neg_scores (list).
    # We convert to: anchor, positive, negative_1, negative_2, ..., negative_n,
    # keeping only negatives that actually exist (no None/empty values).
    def expand_negatives(example):
        """
        Convert the list of negatives into separate columns.
        Ensures ALL columns from negative_1 to negative_{max_negatives} exist.
        Pads missing negatives with empty strings.
        """
        result = {
            'anchor': example['anchor'],
            'positive': example['positive'],
        }
        # Get the list of negatives, ensuring it is a list
        raw_negatives = example.get('negatives', [])
        if raw_negatives is None:
            raw_negatives = []
        # Filter out empty strings or None values from the source list
        valid_negatives = [n for n in raw_negatives if n]
        # Iterate up to max_negatives every time so the dictionary keys (schema)
        # are identical for every row.
        for i in range(max_negatives):
            key_name = f'negative_{i+1}'
            if i < len(valid_negatives):
                # We have a valid negative
                result[key_name] = valid_negatives[i]
            else:
                # We ran out of negatives -> fill with an empty string.
                # This ensures the column exists in the Arrow table.
                result[key_name] = ''
        # Check whether we have at least one negative
        result['_has_negatives'] = len(valid_negatives) > 0
        return result

    dataset = dataset.map(expand_negatives, remove_columns=dataset.column_names)
    # Filter out any examples without valid negatives
    dataset = dataset.filter(lambda x: x['_has_negatives'])
    dataset = dataset.remove_columns(['_has_negatives'])
    if len(dataset) == 0:
        logging.warning(f"No valid examples with negatives found in {dataset_name}")
        return None
    # Add task category for prompting
    dataset = dataset.add_column("new_category", [task_category] * len(dataset))
    dataset = dataset.map(apply_task_prompt)
    dataset = dataset.remove_columns(['new_category'])
    # Count the negative columns for logging (padded with empty strings up to max_negatives)
    sample = dataset[0]
    num_negatives = sum(1 for k in sample.keys() if k.startswith('negative_'))
    logging.info(f"Prepared {len(dataset)} training examples from {dataset_name}")
    logging.info(f"Each example has between 1 and {num_negatives} non-empty negatives")
    desired_columns = ["anchor", "positive"]
    # Scan for all negative columns that actually exist in the dataset
    existing_columns = dataset.column_names
    negative_cols = [c for c in existing_columns if c.startswith("negative_")]
    # Sort them so negative_1 comes before negative_2, etc.
    # (sorted by the integer in the column name)
    negative_cols.sort(key=lambda x: int(x.split('_')[1]))
    desired_columns.extend(negative_cols)
    # Force the dataset to use this specific column order
    dataset = dataset.select_columns(desired_columns)
    logging.info(f"Enforced column order: {dataset.column_names}")
    return dataset


def pad_dataset_schema(dataset, total_negatives=5):
    """
    Ensures the dataset has columns negative_1 to negative_{total_negatives}.
    Fills missing columns with empty strings.
    """
    new_columns = {}
    existing_cols = dataset.column_names
    for i in range(total_negatives):
        col_name = f"negative_{i+1}"
        if col_name not in existing_cols:
            # Create a column of empty strings efficiently
            new_columns[col_name] = [""] * len(dataset)
    if new_columns:
        # With HF datasets it is simplest to add the missing columns one by one
        for col_name, data in new_columns.items():
            dataset = dataset.add_column(col_name, data)
    return dataset


def main(args):
    run_type = "probe" if args.probe_run else "full"
    base_model_name = Path(args.fine_tune_model_path).name
    data_name = "ms_marco_nli_mined"
    run_name = f"finetune-CachedMNRL-{base_model_name}-on-{data_name}-{run_type}"
    output_dir = os.path.join(config.OUTPUT_DIR, run_name)
    logging.info(f"--- Starting {run_type.upper()} Run: {run_name} ---")
    logging.info(f"Fine-tuning model from: {args.fine_tune_model_path}")
    model = SentenceTransformer(args.fine_tune_model_path)
    logging.info(model)

    logging.info("Patching the model's `tokenize` method to remove token_type_ids...")
    original_tokenize = model.tokenize

    def patched_tokenize(*args, **kwargs):
        # Call the original tokenizer to get the encoded inputs
        tokenized_output = original_tokenize(*args, **kwargs)
        # Remove the unwanted key from the output
        if "token_type_ids" in tokenized_output:
            del tokenized_output["token_type_ids"]
        return tokenized_output

    # Replace the original method with our patched version
    model.tokenize = patched_tokenize
    logging.info("Model's `tokenize` method patched successfully.")
| logging.info(f"Setting model max length to {config.FT_MODEL_MAX_LENGTH}") | |
| model.max_seq_length = config.FT_MODEL_MAX_LENGTH | |
| # === LOAD ORIGINAL DATASETS === | |
| nli_data_path = "/mnt/disk2/snli_triplets_swedish" # sts | |
| msmarco_data_path = "/mnt/disk2/msmarco_triplets_swedish" # retrieval | |
| nq_data_path = "/mnt/disk2/nq_triplets_swedish" # retrieval | |
| # Load NLI data | |
| logging.info(f"Loading translated NLI triplets from: {nli_data_path}") | |
| nli_dataset = load_from_disk(nli_data_path) | |
| nli_dataset = nli_dataset.add_column("new_category", ["semantic similarity"] * len(nli_dataset)) | |
| nli_dataset = nli_dataset.map(apply_task_prompt) | |
| nli_dataset = nli_dataset.rename_column("negative", "negative_1") | |
| nli_dataset = pad_dataset_schema(nli_dataset, MAX_NEGATIVES_PER_EXAMPLE) # <--- ADD THIS | |
| nli_dataset = nli_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)]) | |
| logging.info(f"Loaded {len(nli_dataset)} NLI triplets.") | |
| # Load MS MARCO data | |
| # logging.info(f"Loading translated MSMARCO triplets from: {msmarco_data_path}") | |
| # msmarco_dataset = load_from_disk(msmarco_data_path) | |
| # msmarco_dataset = msmarco_dataset.rename_column("query", "anchor") | |
| # msmarco_dataset = msmarco_dataset.add_column("new_category", ["retrieval"] * len(msmarco_dataset)) | |
| # msmarco_dataset = msmarco_dataset.map(apply_task_prompt) | |
| # msmarco_dataset = msmarco_dataset.rename_column("negative", "negative_1") | |
| # msmarco_dataset = pad_dataset_schema(msmarco_dataset, MAX_NEGATIVES_PER_EXAMPLE) # <--- ADD THIS | |
| # msmarco_dataset = msmarco_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)]) | |
| # logging.info(f"Loaded {len(msmarco_dataset)} MSMARCO triplets.") | |
| # Load NQ data | |
| logging.info(f"Loading translated NQ triplets from: {nq_data_path}") | |
| nq_dataset = load_from_disk(nq_data_path) | |
| nq_dataset = nq_dataset.rename_column("query", "anchor") | |
| nq_dataset = nq_dataset.add_column("new_category", ["question answering"] * len(nq_dataset)) | |
| nq_dataset = nq_dataset.map(apply_task_prompt) | |
| nq_dataset = nq_dataset.rename_column("negative", "negative_1") | |
| nq_dataset = pad_dataset_schema(nq_dataset, MAX_NEGATIVES_PER_EXAMPLE) # <--- ADD THIS | |
| nq_dataset = nq_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)]) | |
| logging.info(f"Loaded {len(nq_dataset)} NQ triplets.") | |
| logging.info(f"Loading synthetic data from: {synthetic_data_path}") | |
| synthetic_dataset = load_from_disk(synthetic_data_path) | |
| synthetic_dataset = synthetic_dataset.filter(is_good_or_excellent) | |
| synthetic_dataset = synthetic_dataset.map(handle_none_negatives) | |
| synthetic_dataset = synthetic_dataset.rename_column("query", "anchor") | |
| synthetic_dataset = synthetic_dataset.map(apply_task_prompt) | |
| synthetic_dataset = synthetic_dataset.rename_column("negative", "negative_1") | |
| synthetic_dataset = pad_dataset_schema(synthetic_dataset, MAX_NEGATIVES_PER_EXAMPLE) | |
| synthetic_dataset = synthetic_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)]) | |
| logging.info(f"Loaded {len(synthetic_dataset)} synthetic triplets.") | |
| # Load wiki clustering data | |
| cluster_data_path = "/home/ubuntu/work/WSCLToolkit/sent_emb_train/non_vital_code/final_clustering_data.jsonl/" | |
| cluster_dataset = load_clustering_dataset( | |
| cluster_data_path, | |
| max_negatives=MAX_NEGATIVES_PER_EXAMPLE | |
| ) | |
| # === LOAD MINED DATASETS === | |
| logging.info("\n=== Loading Mined Hard Negative Datasets ===") | |
| logging.info(f"Using up to {MAX_NEGATIVES_PER_EXAMPLE} negatives per example") | |
| mined_datasets = {} | |
| for dataset_name, task_category in MINED_DATASET_CONFIG.items(): | |
| dataset = load_mined_dataset( | |
| dataset_name=dataset_name, | |
| task_category=task_category, | |
            max_samples=MAX_SAMPLES,  # limit on samples per dataset
            max_negatives=MAX_NEGATIVES_PER_EXAMPLE,
        )
        if dataset is not None:
            mined_datasets[dataset_name] = dataset
    logging.info(f"Successfully loaded {len(mined_datasets)} mined datasets")

    # === SPLIT DATASETS FOR EVALUATION ===
    logging.info("\n=== Splitting datasets to create dev sets for TripletEvaluators ===")
    eval_samples_per_dataset = 1000
    # Split NLI
    nli_split = nli_dataset.train_test_split(test_size=eval_samples_per_dataset, seed=42)
    nli_train = nli_split["train"]
    nli_dev = nli_split["test"]
    # Split MS MARCO (disabled)
    # msmarco_split = msmarco_dataset.train_test_split(
    #     test_size=eval_samples_per_dataset, seed=42
    # )
    # msmarco_train = msmarco_split["train"]
    # msmarco_dev = msmarco_split["test"]
    # Split NQ
    nq_split = nq_dataset.train_test_split(test_size=eval_samples_per_dataset, seed=42)
    nq_train = nq_split["train"]
    nq_dev = nq_split["test"]

    # === COMBINE ALL DATASETS ===
    # Use a DatasetDict for stratified sampling across datasets
    datasets = {}
    # Add original datasets
    datasets["NLI"] = nli_train
    # datasets["MSMARCO"] = msmarco_train
    datasets["NQ"] = nq_train
    datasets["Topic_Clustering"] = cluster_dataset
    datasets["Synthetic"] = synthetic_dataset
    # datasets["Synthetic_Classification"] = synthetic_classification_dataset
    # Add all mined datasets
    datasets.update(mined_datasets)
    # Convert to DatasetDict for stratified batch sampling
    train_dataset = DatasetDict(datasets)
    logging.info(f"Created training DatasetDict with {len(train_dataset)} datasets")

    # === DATASET COMPOSITION TABLE ===
    # For a DatasetDict, iterate over the constituent datasets
    total_samples = sum(len(ds) for ds in train_dataset.values())
    composition_data = []
    for name, ds in train_dataset.items():
        num_samples = len(ds)
        percentage = (num_samples / total_samples) * 100 if total_samples > 0 else 0
        composition_data.append({
            "Dataset": name,
            "Number of Examples": num_samples,
            "Percentage (%)": f"{percentage:.2f}%",
        })
    print("\n📊 Finetuning Dataset Composition")
    print("-" * 80)
    print(f"{'Dataset':<50} | {'Number of Examples':<20} | {'Percentage (%)'}")
    print("-" * 80)
    for item in composition_data:
        print(f"{item['Dataset']:<50} | {item['Number of Examples']:<20} | {item['Percentage (%)']}")
    print("-" * 80)
    print(f"{'Total':<50} | {total_samples:<20} | {'100.00%'}")
    print("-" * 80)
    logging.info(f"Training with DatasetDict containing {len(train_dataset)} datasets and {total_samples} total samples.")

    # === CONFIGURE LOSS ===
    loss = CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeightAndMask(
        model=model,
        mini_batch_size=config.FT_BS,
        spread_out_loss_weight=0.1,
        use_hardness_weighting=True,
        mask_duplicate_positives=False,
        hardness_alpha=3.0,
        scale=50,
    )
    run_name = run_name + "-added-datasets-" + f"max_{MAX_SAMPLES}_per_dataset"
    # NLI evaluator (semantic similarity) - prompts already applied
    nli_anchors, nli_pos, nli_neg = prepare_triplet_eval_data(nli_dev)
    nli_triplet_evaluator = TripletEvaluator(nli_anchors, nli_pos, nli_neg, name="nli-triplet-dev")
    # MS MARCO evaluator (retrieval - asymmetric) - disabled
    # msmarco_anchors, msmarco_pos, msmarco_neg = prepare_triplet_eval_data(msmarco_dev)
    # msmarco_triplet_evaluator = TripletEvaluator(msmarco_anchors, msmarco_pos, msmarco_neg, name="msmarco-triplet-dev")
    # NQ evaluator (question answering) - prompts already applied
    nq_anchors, nq_pos, nq_neg = prepare_triplet_eval_data(nq_dev)
    nq_triplet_evaluator = TripletEvaluator(nq_anchors, nq_pos, nq_neg, name="nq-triplet-dev")

    wiki_cluster_eval = CustomJSONLClusteringEvaluator(
        file_path="/home/ubuntu/work/WSCLToolkit/sent_emb_train/non_vital_code/parsed_wiki_sections.jsonl",
        min_samples_per_topic=5,
        max_clusters=500,  # subsamples up to 500 random topics (e.g. Afrika, Amager, Kaffe, Volvo...)
    )
    sentiment_eval = load_swedish_reviews_evaluator(prompt_style="standard")

    # Paraphrase (STS) evaluator
    print("\n--- Setting up Paraphrase Evaluator from JSONL ---")
    para_sents1 = []
    para_sents2 = []
    para_scores = []
    max_score = 0.0
    with open(config.SWE_PARA_PATH, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            para_sents1.append(data["sentence_1"])
            para_sents2.append(data["sentence_2"])
            score = float(data["label"])
            para_scores.append(score)
            if score > max_score:
                max_score = score
    if max_score > 0:
        normalized_scores = [score / max_score for score in para_scores]
    else:
        normalized_scores = para_scores
    # Apply prompts to paraphrase data
    para_sents1 = ["task: semantic similarity | query: " + s for s in para_sents1]
    para_sents2 = ["task: semantic similarity | query: " + s for s in para_sents2]
    sweparaphrase_evaluator = EmbeddingSimilarityEvaluator(
        sentences1=para_sents1,
        sentences2=para_sents2,
        scores=normalized_scores,
        name="sweparaphrase-dev",
    )

    # FAQ retrieval evaluator
    print("\n--- Setting up FAQ Retrieval Evaluator from JSONL ---")
    with open(config.SWE_FAQ_PATH, "r", encoding="utf-8") as f:
        faq_data = [json.loads(line) for line in f]
    unique_answers = {answer for item in faq_data for answer in item["candidate_answers"]}
    answer_to_cid = {answer: f"doc_{i}" for i, answer in enumerate(unique_answers)}
    # Apply prompts to FAQ corpus (documents)
    faq_corpus = {
        cid: "title: none | text: " + answer
        for answer, cid in answer_to_cid.items()
    }
    faq_queries = {}
    faq_relevant_docs = {}
    for i, item in enumerate(faq_data):
        query_id = f"q_{i}"
        faq_queries[query_id] = "task: search result | query: " + item['question']
        correct_answer = item["candidate_answers"][item["label"]]
        correct_cid = answer_to_cid[correct_answer]
        faq_relevant_docs[query_id] = {correct_cid}
    swefaq_evaluator = InformationRetrievalEvaluator(
        queries=faq_queries,
        corpus=faq_corpus,
        relevant_docs=faq_relevant_docs,
        name="swefaq-dev",
    )

    nanobeir_evaluators = create_nanobeir_evaluators(
        languages=["sv", "no"],
        dataset_names=[
            "NanoMSMARCO",
            "NanoNQ",
            "NanoQuoraRetrieval",
            "NanoFEVER",
            "NanoHotpotQA",
        ],
    )
    main_evaluator = SequentialEvaluator(
        evaluators=[
            nli_triplet_evaluator,
            # msmarco_triplet_evaluator,
            nq_triplet_evaluator,
            sweparaphrase_evaluator,
            swefaq_evaluator,
            wiki_cluster_eval,
            sentiment_eval,
        ] + nanobeir_evaluators,
    )
    print("\nAll evaluators combined and ready.")

    # === CONFIGURE TRAINING ARGUMENTS ===
    training_args = SentenceTransformerTrainingArguments(
        output_dir=config.FT_OUTPUT_DIR,
        num_train_epochs=1,
        per_device_eval_batch_size=64,
        per_device_train_batch_size=config.FT_CONCEPTUAL_BS,
        learning_rate=config.FT_LR,
        bf16=True,
        report_to="wandb",
        run_name=run_name,
        save_total_limit=2,
        logging_steps=config.LOGGING_STEPS,
        eval_strategy=config.EVAL_STRATEGY,
        save_strategy=config.EVAL_STRATEGY,
        eval_steps=config.FT_EVAL_STEPS,
        save_steps=config.FT_SAVE_STEPS,
        load_best_model_at_end=True,
        warmup_ratio=config.FT_WARMUP,
        weight_decay=config.FT_WEIGHT_DECAY,
        # The MS MARCO triplet evaluator is disabled above, so track the NLI dev metric instead
        metric_for_best_model="eval_nli-triplet-dev_cosine_accuracy",
        greater_is_better=True,
        # Enable multi-dataset batch sampling for stratified batches
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        ddp_find_unused_parameters=False,
    )

    # === INITIALIZE AND RUN TRAINER ===
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        loss=loss,
        evaluator=main_evaluator,
    )
    logging.info("--- Running initial evaluation on the untrained model (step 0) ---")
    trainer.evaluate()
    logging.info(f"Starting model fine-tuning for run: {run_name}")
    trainer.train()
    logging.info(f"--- Fine-tuning complete for {run_name}! ---")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A flexible fine-tuning script for Sentence Transformer models."
    )
    parser.add_argument(
        "--fine_tune_model_path",
        type=str,
        required=True,
        help="Path or Hub name of the model to fine-tune.",
    )
    parser.add_argument(
        "--probe_run", action="store_true", help="If set, runs a short probe run."
    )
    cli_args = parser.parse_args()
    main(cli_args)
```
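
A sketch of how the script might be invoked (the filename `finetune.py` is hypothetical; the two flags are the ones defined in the argument parser above): `python finetune.py --fine_tune_model_path <base-model-or-hub-id>`, optionally with `--probe_run` for a short probe run.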