# DID NOT USE THE NEW MSMARCO DATASET, OTHERWISE EVERYTHING IS THE SAME
import logging
import os
import argparse
import config  # Assuming your config file is available
import json
import torch
import pandas as pd

os.environ["HF_HOME"] = config.CACHE_DIR

from pathlib import Path
from datasets import load_dataset, load_from_disk, concatenate_datasets, Dataset, DatasetDict
from sentence_transformers.evaluation import (
    SequentialEvaluator,
    EmbeddingSimilarityEvaluator,
    InformationRetrievalEvaluator,
    TripletEvaluator,
)
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import (
    SentenceTransformerTrainingArguments,
    BatchSamplers,
    MultiDatasetBatchSamplers,
)
from sentence_transformers import SentenceTransformer, losses, models
import transformers
from src.custom_loss.CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeightAndMask import CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeightAndMask
from src.custom_loss.CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeight import CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeight
import random
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import v_measure_score, adjusted_rand_score, accuracy_score, f1_score
from sentence_transformers.evaluation import SentenceEvaluator
from sklearn.linear_model import LogisticRegression

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
transformers.logging.set_verbosity_info()

# --- MINED DATASET CONFIGURATION ---
MINED_DATASETS_BASE_PATH = Path("/mnt/disk2/translated_datasets")

# Map dataset names to their task categories
MINED_DATASET_CONFIG = {
    "gooaq_train_swedish_triplets_scored": "question answering",
    "reddit_train_swedish_triplets_scored": "retrieval",
    "xsum_train_swedish_triplets_scored": "clustering",
    "simple-wiki_train_swedish_triplets_scored": "semantic similarity",
    "s2orc_train_swedish_saved_triplets_scored": "retrieval",
    "amazon-reviews_train_swedish_saved_triplets_scored": "retrieval",
    "paq_train_swedish_queries_retranslated_triplets_scored": "question answering",
    "stackexchange-duplicates_train_swedish_triplets_scored": "semantic similarity",
    "wikipedia-sections_train_swedish_triplets_scored": "retrieval",
    "msmarco_triplets_swedish_triplets_scored": "retrieval",
}

synthetic_data_path = "/mnt/disk2/combined_synthetic_data_deduplicated_classified"
# synthetic_classification_data_path = "/mnt/disk2/classification_dataset_graded_subset.jsonl"
# combined_synthetic_data_deduplicated_classified

# Maximum number of negatives to use per example from mined datasets.
# The loss function (CachedMultipleNegativesRankingLoss) can handle multiple negatives.
# Format: (anchor, positive, negative_1, negative_2, ..., negative_n)
MAX_NEGATIVES_PER_EXAMPLE = 10
MAX_SAMPLES = 500_000

NANOBEIR_DATASETS = [
    "NanoArguAna",
    "NanoClimateFEVER",
    "NanoDBPedia",
    "NanoFEVER",
    "NanoFiQA2018",
    "NanoHotpotQA",
    "NanoMSMARCO",
    "NanoNFCorpus",
    "NanoNQ",
    "NanoQuoraRetrieval",
    "NanoSCIDOCS",
    "NanoSciFact",
    "NanoTouche2020",
]

# Map dataset names to task types for appropriate prompting
NANOBEIR_TASK_TYPES = {
    "NanoArguAna": "retrieval",
    "NanoClimateFEVER": "retrieval",
    "NanoDBPedia": "retrieval",
    "NanoFEVER": "retrieval",
    "NanoFiQA2018": "retrieval",
    "NanoHotpotQA": "retrieval",
    "NanoMSMARCO": "retrieval",
    "NanoNFCorpus": "retrieval",
    "NanoNQ": "question answering",
    "NanoQuoraRetrieval": "semantic similarity",
    "NanoSCIDOCS": "retrieval",
    "NanoSciFact": "retrieval",
    "NanoTouche2020": "retrieval",
}
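
# Illustrative sketch (added for clarity, never used by the pipeline): a toy row in the
# fixed-width (anchor, positive, negative_1..N) schema described above. Rows with fewer
# than MAX_NEGATIVES_PER_EXAMPLE negatives are padded with empty strings so every dataset
# shares the same columns. The helper name and toy strings are hypothetical.
def _example_training_row(num_negatives: int = 2) -> dict:
    """Build a hypothetical training row with empty-string padding for missing negatives."""
    row = {
        "anchor": "Vad Ƥr huvudstaden i Sverige?",
        "positive": "Stockholm Ƥr Sveriges huvudstad.",
    }
    toy_negatives = [f"distraktor {i + 1}" for i in range(num_negatives)]
    for i in range(MAX_NEGATIVES_PER_EXAMPLE):
        row[f"negative_{i + 1}"] = toy_negatives[i] if i < len(toy_negatives) else ""
    return row
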
"retrieval", } class ProbeClassificationEvaluator(SentenceEvaluator): """ Generic evaluator that trains a Logistic Regression probe on train_set and evaluates accuracy on test_set. """ def __init__(self, dataset_name, sentences_train, labels_train, sentences_test, labels_test, batch_size=32, prompt_prefix=""): self.name = f"eval_{dataset_name}_classification" self.sentences_train = [prompt_prefix + s for s in sentences_train] self.labels_train = labels_train self.sentences_test = [prompt_prefix + s for s in sentences_test] self.labels_test = labels_test self.batch_size = batch_size def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: # 1. Encode emb_train = model.encode(self.sentences_train, batch_size=self.batch_size, show_progress_bar=False) emb_test = model.encode(self.sentences_test, batch_size=self.batch_size, show_progress_bar=False) # 2. Train Probe (Fast LR) clf = LogisticRegression(random_state=42, solver='lbfgs', max_iter=100, n_jobs=-1) clf.fit(emb_train, self.labels_train) # 3. Predict preds = clf.predict(emb_test) acc = accuracy_score(self.labels_test, preds) logging.info(f"{self.name}: Accuracy = {acc:.4f}") return acc def load_swedish_reviews_evaluator(prompt_style="standard"): """ Loads 'timpal0l/swedish_reviews' for binary sentiment classification. Safe from MTEB leakage. """ try: logging.info("Loading timpal0l/swedish_reviews for Sentiment Probe...") # Load the test split (approx 10k samples, good for eval) dataset = load_dataset("timpal0l/swedish_reviews", split="test") # Downsample if needed for speed (e.g. keep 2000 samples) if len(dataset) > 50: dataset = dataset.shuffle(seed=42).select(range(50)) sentences = dataset["text"] labels = dataset["label"] # Define Prompt if prompt_style == "standard": # "task: classification" is the standard prompt for this prompt = "task: classification | query: " else: prompt = "" # Split 50/50 for Train/Test Probe # We need to train the logistic regression probe on *some* data to test the embeddings split_idx = len(sentences) // 2 evaluator = ProbeClassificationEvaluator( dataset_name="SwedishReviews-Sentiment", sentences_train=sentences[:split_idx], labels_train=labels[:split_idx], sentences_test=sentences[split_idx:], labels_test=labels[split_idx:], prompt_prefix=prompt, batch_size=32 ) return evaluator except Exception as e: logging.warning(f"Failed to load Swedish Reviews: {e}") return None class CustomJSONLClassificationEvaluator(SentenceEvaluator): """ Reads a JSONL file and trains a Logistic Regression probe to predict 'main_title'. This checks if the topics are linearly separable in the embedding space. """ def __init__(self, file_path, min_samples_per_label=5, max_classes=20, batch_size=32): self.name = "eval_wiki_classification" self.batch_size = batch_size self.prompt = "task: classification | query: " self.sentences_train = [] self.labels_train = [] self.sentences_test = [] self.labels_test = [] logging.info(f"Loading classification data from {file_path}...") # 1. Load Data by Class label_map = {} # { "Afrika": ["text1", "text2"...] } with open(file_path, 'r', encoding='utf-8') as f: for line in f: try: row = json.loads(line) text = row.get('text', '').strip() label = row.get('main_title', '').strip() if text and label: if label not in label_map: label_map[label] = [] label_map[label].append(text) except: continue # 2. Filter Classes valid_labels = [l for l, texts in label_map.items() if len(texts) >= min_samples_per_label] valid_labels.sort() # Deterministic order # 3. 
class CustomJSONLClassificationEvaluator(SentenceEvaluator):
    """
    Reads a JSONL file and trains a Logistic Regression probe to predict 'main_title'.
    This checks if the topics are linearly separable in the embedding space.
    """

    def __init__(self, file_path, min_samples_per_label=5, max_classes=20, batch_size=32):
        self.name = "eval_wiki_classification"
        self.batch_size = batch_size
        self.prompt = "task: classification | query: "

        self.sentences_train = []
        self.labels_train = []
        self.sentences_test = []
        self.labels_test = []

        logging.info(f"Loading classification data from {file_path}...")

        # 1. Load Data by Class
        label_map = {}  # { "Afrika": ["text1", "text2"...] }
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    row = json.loads(line)
                    text = row.get('text', '').strip()
                    label = row.get('main_title', '').strip()
                    if text and label:
                        if label not in label_map:
                            label_map[label] = []
                        label_map[label].append(text)
                except Exception:
                    continue

        # 2. Filter Classes
        valid_labels = [l for l, texts in label_map.items() if len(texts) >= min_samples_per_label]
        valid_labels.sort()  # Deterministic order

        # 3. Subsample Classes (Avoid training on 5000 classes)
        rng = random.Random(42)
        if len(valid_labels) > max_classes:
            selected_labels = rng.sample(valid_labels, max_classes)
        else:
            selected_labels = valid_labels

        logging.info(f"Classification Probe: Using {len(selected_labels)} classes (topics).")

        # 4. Create Train/Test Split (80/20 per class)
        for label in selected_labels:
            texts = label_map[label]
            # Shuffle texts for this label
            rng.shuffle(texts)
            # Use max 50 samples per class to keep probe fast
            texts = texts[:50]

            split_idx = int(0.8 * len(texts))

            # Train set
            for t in texts[:split_idx]:
                self.sentences_train.append(self.prompt + t)
                self.labels_train.append(label)

            # Test set
            for t in texts[split_idx:]:
                self.sentences_test.append(self.prompt + t)
                self.labels_test.append(label)

        logging.info(f"Probe Sizes: Train={len(self.sentences_train)}, Test={len(self.sentences_test)}")

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        if not self.sentences_train:
            return 0.0

        # 1. Encode
        emb_train = model.encode(self.sentences_train, batch_size=self.batch_size, show_progress_bar=False)
        emb_test = model.encode(self.sentences_test, batch_size=self.batch_size, show_progress_bar=False)

        # 2. Train Probe (Fast Logistic Regression)
        clf = LogisticRegression(random_state=42, solver='lbfgs', max_iter=100, n_jobs=-1)
        clf.fit(emb_train, self.labels_train)

        # 3. Predict & Score
        preds = clf.predict(emb_test)
        acc = accuracy_score(self.labels_test, preds)
        # Macro F1 is better if classes are imbalanced
        f1 = f1_score(self.labels_test, preds, average='macro')

        logging.info(f"{self.name}: Accuracy={acc:.4f} | F1-Macro={f1:.4f}")
        return acc


class CustomJSONLClusteringEvaluator(SentenceEvaluator):
    """
    Evaluates clustering on a specific number of distinct Wikipedia topics.
    """

    def __init__(self, file_path, min_samples_per_topic=5, max_clusters=50, batch_size=32):
        self.name = "eval_wiki_clustering"
        self.batch_size = batch_size
        self.prompt = "task: clustering | query: "

        self.sentences = []
        self.labels = []

        logging.info(f"Loading clustering data from {file_path}...")

        # 1. Load ALL data into memory organized by Topic
        # Structure: { "Afrika": ["text1", "text2"...], "Amager": ["text1"...] }
        topic_map = {}
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    row = json.loads(line)
                    text = row.get('text', '').strip()
                    # Use 'main_title' as the Topic Label
                    label = row.get('main_title', '').strip()
                    if text and label:
                        if label not in topic_map:
                            topic_map[label] = []
                        topic_map[label].append(text)
                except Exception:
                    continue

        # 2. Filter topics that are too small.
        # We need topics with enough content to actually cluster meaningful points.
        valid_topics = [
            topic for topic, texts in topic_map.items()
            if len(texts) >= min_samples_per_topic
        ]
        logging.info(f"Found {len(valid_topics)} valid topics with >= {min_samples_per_topic} samples.")

        # 3. Sample exactly 'max_clusters' topics (e.g., 50).
        # We sort first to make the random seed deterministic across runs.
        valid_topics.sort()

        # Use a fixed seed so the "random" topics are the same every time you restart training
        rng = random.Random(42)
        if len(valid_topics) > max_clusters:
            selected_topics = rng.sample(valid_topics, max_clusters)
            logging.info(f"Subsampled to {max_clusters} distinct topics for evaluation.")
        else:
            selected_topics = valid_topics
            logging.info(f"Using all {len(selected_topics)} topics (fewer than max_clusters).")

        # 4. Flatten the data for the evaluator
        for topic in selected_topics:
            # You can also limit samples per topic here
            # (e.g. max 20 paragraphs per topic) to keep it balanced
            texts = topic_map[topic][:20]
            for text in texts:
                self.sentences.append(self.prompt + text)
                self.labels.append(topic)

        self.n_clusters = len(selected_topics)
        logging.info(f"Final Clustering Probe: {len(self.sentences)} samples across {self.n_clusters} topics.")

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        if not self.sentences:
            return 0.0

        # Encode
        embeddings = model.encode(self.sentences, batch_size=self.batch_size, show_progress_bar=False)

        # Cluster
        clustering = MiniBatchKMeans(
            n_clusters=self.n_clusters,
            batch_size=256,
            random_state=42,
            n_init='auto'
        )
        clustering.fit(embeddings)

        # Score
        v_score = v_measure_score(self.labels, clustering.labels_)
        ari_score = adjusted_rand_score(self.labels, clustering.labels_)

        logging.info(f"{self.name}: V-Measure={v_score:.4f} | ARI={ari_score:.4f}")
        return v_score


def load_nanobeir_evaluator(dataset_name, language, task_type="retrieval"):
    """
    Load a NanoBEIR dataset for a specific language and create an InformationRetrievalEvaluator.

    Args:
        dataset_name: Name of the NanoBEIR dataset (e.g., "NanoMSMARCO")
        language: Language code ("sv" for Swedish, "no" for Norwegian)
        task_type: Task type for prompting ("retrieval", "question answering", "semantic similarity")

    Returns:
        InformationRetrievalEvaluator or None if loading fails
    """
    try:
        logging.info(f"Loading {dataset_name} for language: {language}")

        # Load corpus (documents)
        corpus_data = load_dataset(
            "lightonai/nanobeir-multilingual",
            f"{dataset_name}_{language}",
            split="corpus"
        )

        # Load queries
        queries_data = load_dataset(
            "lightonai/nanobeir-multilingual",
            f"{dataset_name}_{language}",
            split="queries"
        )

        # Load qrels (relevance judgments) - language independent
        qrels_data = load_dataset(
            "lightonai/nanobeir-multilingual",
            dataset_name,
            split="qrels"
        )

        # Apply task-specific prompts
        if task_type == "retrieval":
            query_prompt = "task: search result | query: "
            doc_prompt = "title: none | text: "
        elif task_type == "question answering":
            query_prompt = "task: question answering | query: "
            doc_prompt = "title: none | text: "
        elif task_type == "semantic similarity":
            query_prompt = "task: semantic similarity | query: "
            doc_prompt = "task: semantic similarity | query: "
        else:
            query_prompt = ""
            doc_prompt = ""

        # Build corpus dictionary with prompts
        corpus = {}
        for item in corpus_data:
            doc_id = item['_id']
            # Combine title and text if title exists
            if 'title' in item and item['title']:
                text = f"{item['title']} {item['text']}"
            else:
                text = item['text']
            corpus[doc_id] = doc_prompt + text

        # Build queries dictionary with prompts
        queries = {}
        for item in queries_data:
            query_id = item['_id']
            queries[query_id] = query_prompt + item['text']

        # Build relevance dictionary.
        # Note: NanoBEIR qrels only have query-id and corpus-id (no score column);
        # all entries in qrels are considered relevant.
        relevant_docs = {}
        for item in qrels_data:
            query_id = item['query-id']
            corpus_id = item['corpus-id']
            if query_id not in relevant_docs:
                relevant_docs[query_id] = set()
            relevant_docs[query_id].add(corpus_id)

        # Filter queries to only those with relevant documents
        queries = {qid: text for qid, text in queries.items() if qid in relevant_docs}

        if not queries:
            logging.warning(f"No queries with relevant documents found for {dataset_name}_{language}")
            return None

        # Create evaluator
        evaluator_name = f"{dataset_name}-{language}-dev"
        evaluator = InformationRetrievalEvaluator(
            queries=queries,
            corpus=corpus,
            relevant_docs=relevant_docs,
            name=evaluator_name,
            # Restrict the reported metrics to reduce clutter: NDCG@10 is the
            # headline number; the remaining metrics are computed at a single k each
            mrr_at_k=[10],
            ndcg_at_k=[10],
            accuracy_at_k=[1],
            precision_recall_at_k=[1],
            map_at_k=[100],
        )

        logging.info(f"Created {evaluator_name} with {len(queries)} queries and {len(corpus)} documents")
        return evaluator

    except Exception as e:
        logging.error(f"Failed to load {dataset_name} for {language}: {e}")
        return None
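
# Hedged sketch (toy data, added for illustration, never called): the dictionary shapes
# that load_nanobeir_evaluator builds for InformationRetrievalEvaluator, i.e. string ids
# mapping to prompted texts, and query ids mapping to sets of relevant document ids.
# The helper name and example strings are hypothetical.
def _demo_ir_evaluator_inputs() -> InformationRetrievalEvaluator:
    queries = {"q1": "task: search result | query: Vem skrev Pippi LĆ„ngstrump?"}
    corpus = {
        "d1": "title: none | text: Astrid Lindgren skrev Pippi LƄngstrump.",
        "d2": "title: none | text: Selma Lagerlƶf skrev Nils Holgersson.",
    }
    relevant_docs = {"q1": {"d1"}}
    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name="demo-ir",
    )
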
def create_nanobeir_evaluators(languages=["sv", "no"], dataset_names=None):
    """
    Create InformationRetrievalEvaluators for multiple NanoBEIR datasets and languages.

    Args:
        languages: List of language codes (default: ["sv", "no"])
        dataset_names: List of dataset names to use (default: all NANOBEIR_DATASETS)

    Returns:
        List of InformationRetrievalEvaluators
    """
    if dataset_names is None:
        dataset_names = NANOBEIR_DATASETS

    evaluators = []

    for language in languages:
        logging.info(f"\n=== Loading NanoBEIR datasets for {language.upper()} ===")
        for dataset_name in dataset_names:
            task_type = NANOBEIR_TASK_TYPES.get(dataset_name, "retrieval")
            evaluator = load_nanobeir_evaluator(dataset_name, language, task_type)
            if evaluator is not None:
                evaluators.append(evaluator)

    logging.info(f"\nSuccessfully created {len(evaluators)} NanoBEIR evaluators")
    return evaluators


def handle_none_negatives(example):
    """Replaces None in the 'negative' column with an empty string."""
    if example["negative"] is None:
        example["negative"] = ""
    return example


def is_good_or_excellent(example):
    """Filter function to keep only good or excellent graded examples."""
    if "dialect" in example["task_description"].lower():
        return False
    return example['grade'] in ['good', 'excellent']


def prepare_triplet_eval_data(dev_set):
    """
    Prepares anchors, positives, and negatives for a TripletEvaluator.
    Extracts raw data WITHOUT prompts - prompts will be applied separately.
    Handles both 'negative' and 'negative_1' column names.
    """
    anchors = dev_set["anchor"]
    positives = dev_set["positive"]

    # Handle both column naming conventions
    if "negative" in dev_set.column_names:
        negatives = dev_set["negative"]
    elif "negative_1" in dev_set.column_names:
        negatives = dev_set["negative_1"]
    else:
        raise ValueError(f"No negative column found. Available columns: {dev_set.column_names}")

    return anchors, positives, negatives


def load_clustering_dataset(dataset_path, max_negatives=5):
    """
    Loads the pre-saved Clustering dataset from disk.
    Applies clustering prompts and handles schema consistency for DDP.
    """
    if not os.path.exists(dataset_path):
        logging.warning(f"Clustering data not found at: {dataset_path}")
        return None

    logging.info(f"Loading clustering dataset from disk: {dataset_path}")
    dataset = load_from_disk(str(dataset_path))

    # 1. Ensure we don't have None values (Tokenizer crash protection)
    def fill_empty(example):
        for key in example.keys():
            if example[key] is None:
                example[key] = ""
        return example

    dataset = dataset.map(fill_empty)

    # 2. Add clustering category and apply prompts BEFORE padding
    dataset = dataset.add_column("new_category", ["clustering"] * len(dataset))
    dataset = dataset.map(apply_task_prompt)
    dataset = dataset.remove_columns(['new_category'])

    # 3. Enforce exact column schema for DDP - pad AFTER applying prompts
    desired_cols = ["anchor", "positive"] + [f"negative_{i+1}" for i in range(max_negatives)]

    # Pad if we have fewer negatives than the global config
    for col in desired_cols:
        if col not in dataset.column_names:
            dataset = dataset.add_column(col, [""] * len(dataset))

    # Select only the needed columns (drops extra negatives if you have > max)
    dataset = dataset.select_columns(desired_cols)

    logging.info(f"Loaded {len(dataset)} clustering examples with prompts applied.")
    return dataset


def apply_task_prompt(example, dropout_rate=0.1):
    """
    Applies strict task-specific prompts based on the provided schema.
    Skips empty strings to avoid adding prompts to padding columns.
    """
    task_type = example["new_category"]

    if random.random() < dropout_rate:
        return example

    # --- 1. RETRIEVAL (Asymmetric) ---
    # Anchor: "task: search result | query: {content}"
    # Docs:   "title: none | text: {content}"
    if task_type == "retrieval":
        if example['anchor']:  # Only apply if not empty
            example['anchor'] = "task: search result | query: " + example['anchor']

        # Define the document prompt (assuming no title column available, using 'none')
        doc_prompt = "title: none | text: "

        if example['positive']:  # Only apply if not empty
            example['positive'] = doc_prompt + example['positive']

        # Apply to all negatives
        for key in list(example.keys()):
            if key.startswith('negative') and example[key]:  # This already checks for non-empty
                example[key] = doc_prompt + example[key]

    # --- 2. QUESTION ANSWERING (Symmetric) ---
    # Same prompt for anchor and positive.
    # Prompt: "task: question answering | query: {content}"
    elif task_type == "question answering":
        instruct = "task: question answering | query: "
        if example['anchor']:
            example['anchor'] = instruct + example['anchor']
        if example['positive']:
            example['positive'] = instruct + example['positive']
        for key in list(example.keys()):
            if key.startswith('negative') and example[key]:
                example[key] = instruct + example[key]

    # --- 3. CLUSTERING (Symmetric) ---
    # Prompt: "task: clustering | query: {content}"
    elif task_type == "clustering":
        instruct = "task: clustering | query: "
        if example['anchor']:
            example['anchor'] = instruct + example['anchor']
        if example['positive']:
            example['positive'] = instruct + example['positive']
        for key in list(example.keys()):
            if key.startswith('negative') and example[key]:
                example[key] = instruct + example[key]

    # --- 4. CLASSIFICATION (Symmetric) ---
    # Prompt: "task: classification | query: {content}"
    elif task_type == "classification":
        instruct = "task: classification | query: "
        if example['anchor']:
            example['anchor'] = instruct + example['anchor']
        if example['positive']:
            example['positive'] = instruct + example['positive']
        for key in list(example.keys()):
            if key.startswith('negative') and example[key]:
                example[key] = instruct + example[key]

    # --- 5. SEMANTIC SIMILARITY (Symmetric) ---
    # Prompt: "task: semantic similarity | query: {content}"
    elif task_type == "semantic similarity":
        instruct = "task: semantic similarity | query: "
        if example['anchor']:
            example['anchor'] = instruct + example['anchor']
        if example['positive']:
            example['positive'] = instruct + example['positive']
        for key in list(example.keys()):
            if key.startswith('negative') and example[key]:
                example[key] = instruct + example[key]

    # Fallback
    else:
        logging.warning(f"Unknown task category: {task_type}. No prompt applied.")

    return example
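
# Hedged sketch (added for illustration, never called): apply_task_prompt on a toy row,
# with prompt dropout disabled so the output is deterministic. Shows the asymmetric
# retrieval prompts applied to the anchor vs. the documents. The helper name and the
# toy strings are hypothetical.
def _demo_apply_task_prompt() -> dict:
    row = {
        "anchor": "Hur hƶgt Ƥr Kebnekaise?",
        "positive": "Kebnekaise Ƥr Sveriges hƶgsta berg.",
        "negative_1": "Mount Everest ligger i Himalaya.",
        "new_category": "retrieval",
    }
    # Returns the row with "task: search result | query: " on the anchor and
    # "title: none | text: " on positive/negative columns.
    return apply_task_prompt(dict(row), dropout_rate=0.0)
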
No prompt applied.") return example def load_mined_dataset(dataset_name, task_category, max_samples=None, max_negatives=10): """ Load a single mined dataset and prepare it for training. Args: dataset_name: Name of the dataset directory task_category: Task category for prompting max_samples: Optional limit on number of samples to load max_negatives: Maximum number of negatives to include per example (default: 10) Returns: Processed dataset with prompts applied, with multiple negatives per example """ dataset_path = MINED_DATASETS_BASE_PATH / dataset_name if not dataset_path.exists(): logging.warning(f"Dataset path does not exist: {dataset_path}. Skipping.") return None logging.info(f"Loading mined dataset from: {dataset_path}") dataset = load_from_disk(str(dataset_path)) if max_samples and len(dataset) > max_samples: logging.info(f"Sampling {max_samples} from {len(dataset)} samples") dataset = dataset.shuffle(seed=42).select(range(max_samples)) logging.info(f"Loaded {len(dataset)} samples from {dataset_name}") # The mined datasets have structure: anchor, positive, pos_score, negatives (list), neg_scores (list) # We need to convert to: anchor, positive, negative_1, negative_2, ..., negative_n # BUT we only keep negatives that actually exist (no None/empty values) def expand_negatives(example): """ Convert the list of negatives into separate columns. Ensures ALL columns from negative_1 to negative_{max_negatives} exist. Pads missing negatives with None. """ result = { 'anchor': example['anchor'], 'positive': example['positive'], } # Get the list of negatives, ensure it's a list raw_negatives = example.get('negatives', []) if raw_negatives is None: raw_negatives = [] # Filter out empty strings or None values from the source list valid_negatives = [n for n in raw_negatives if n] # --- KEY FIX IS HERE --- # We must iterate up to max_negatives every time to ensure the # dictionary keys (schema) are identical for every row. for i in range(max_negatives): key_name = f'negative_{i+1}' if i < len(valid_negatives): # We have a valid negative result[key_name] = valid_negatives[i] else: # We ran out of negatives -> Fill with None # This ensures the column exists in the Arrow table result[key_name] = '' # Check if we have at least one negative result['_has_negatives'] = len(valid_negatives) > 0 return result dataset = dataset.map(expand_negatives, remove_columns=dataset.column_names) # Filter out any examples without valid negatives dataset = dataset.filter(lambda x: x['_has_negatives']) dataset = dataset.remove_columns(['_has_negatives']) if len(dataset) == 0: logging.warning(f"No valid examples with negatives found in {dataset_name}") return None # Add task category for prompting dataset = dataset.add_column("new_category", [task_category] * len(dataset)) dataset = dataset.map(apply_task_prompt) dataset = dataset.remove_columns(['new_category']) # Count negatives in first example for logging sample = dataset[0] num_negatives = sum(1 for k in sample.keys() if k.startswith('negative_')) logging.info(f"Prepared {len(dataset)} training examples from {dataset_name}") logging.info(f"Examples have between 1 and {num_negatives} negatives each") desired_columns = ["anchor", "positive"] # scan for all negative columns that actually exist in the dataset existing_columns = dataset.column_names negative_cols = [c for c in existing_columns if c.startswith("negative_")] # Sort them to ensure negative_1 comes before negative_2, etc. 
def pad_dataset_schema(dataset, total_negatives=5):
    """
    Ensures the dataset has columns negative_1 to negative_{total_negatives}.
    Fills missing columns with empty strings.
    """
    new_columns = {}
    existing_cols = dataset.column_names

    for i in range(total_negatives):
        col_name = f"negative_{i+1}"
        if col_name not in existing_cols:
            # Create a column of empty strings efficiently
            new_columns[col_name] = [""] * len(dataset)

    if new_columns:
        # Add all new columns.
        # For HF datasets it is simplest to just add them column by column.
        for col_name, data in new_columns.items():
            dataset = dataset.add_column(col_name, data)

    return dataset


def main(args):
    run_type = "probe" if args.probe_run else "full"
    base_model_name = Path(args.fine_tune_model_path).name
    data_name = "ms_marco_nli_mined"
    run_name = f"finetune-CachedMNRL-{base_model_name}-on-{data_name}-{run_type}"
    output_dir = os.path.join(config.OUTPUT_DIR, run_name)

    logging.info(f"--- Starting {run_type.upper()} Run: {run_name} ---")
    logging.info(f"Fine-tuning model from: {args.fine_tune_model_path}")

    model = SentenceTransformer(args.fine_tune_model_path)
    logging.info(model)

    logging.info("Patching the model's `tokenize` method to remove token_type_ids...")
    original_tokenize = model.tokenize

    def patched_tokenize(*args, **kwargs):
        # Call the original tokenizer to get the encoded inputs
        tokenized_output = original_tokenize(*args, **kwargs)
        # Remove the unwanted key from the output
        if "token_type_ids" in tokenized_output:
            del tokenized_output["token_type_ids"]
        return tokenized_output

    # Replace the original method with our patched version
    model.tokenize = patched_tokenize
    logging.info("Model's `tokenize` method patched successfully.")

    logging.info(f"Setting model max length to {config.FT_MODEL_MAX_LENGTH}")
    model.max_seq_length = config.FT_MODEL_MAX_LENGTH

    # === LOAD ORIGINAL DATASETS ===
    nli_data_path = "/mnt/disk2/snli_triplets_swedish"  # sts
    msmarco_data_path = "/mnt/disk2/msmarco_triplets_swedish"  # retrieval
    nq_data_path = "/mnt/disk2/nq_triplets_swedish"  # retrieval

    # Load NLI data
    logging.info(f"Loading translated NLI triplets from: {nli_data_path}")
    nli_dataset = load_from_disk(nli_data_path)
    nli_dataset = nli_dataset.add_column("new_category", ["semantic similarity"] * len(nli_dataset))
    nli_dataset = nli_dataset.map(apply_task_prompt)
    nli_dataset = nli_dataset.rename_column("negative", "negative_1")
    nli_dataset = pad_dataset_schema(nli_dataset, MAX_NEGATIVES_PER_EXAMPLE)
    nli_dataset = nli_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)])
    logging.info(f"Loaded {len(nli_dataset)} NLI triplets.")

    # Load MS MARCO data
    # logging.info(f"Loading translated MSMARCO triplets from: {msmarco_data_path}")
    # msmarco_dataset = load_from_disk(msmarco_data_path)
    # msmarco_dataset = msmarco_dataset.rename_column("query", "anchor")
    # msmarco_dataset = msmarco_dataset.add_column("new_category", ["retrieval"] * len(msmarco_dataset))
    # msmarco_dataset = msmarco_dataset.map(apply_task_prompt)
    # msmarco_dataset = msmarco_dataset.rename_column("negative", "negative_1")
    # msmarco_dataset = pad_dataset_schema(msmarco_dataset, MAX_NEGATIVES_PER_EXAMPLE)
    # msmarco_dataset = msmarco_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)])
    # logging.info(f"Loaded {len(msmarco_dataset)} MSMARCO triplets.")

    # Load NQ data
    logging.info(f"Loading translated NQ triplets from: {nq_data_path}")
    nq_dataset = load_from_disk(nq_data_path)
    nq_dataset = nq_dataset.rename_column("query", "anchor")
    nq_dataset = nq_dataset.add_column("new_category", ["question answering"] * len(nq_dataset))
    nq_dataset = nq_dataset.map(apply_task_prompt)
    nq_dataset = nq_dataset.rename_column("negative", "negative_1")
    nq_dataset = pad_dataset_schema(nq_dataset, MAX_NEGATIVES_PER_EXAMPLE)
    nq_dataset = nq_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)])
    logging.info(f"Loaded {len(nq_dataset)} NQ triplets.")

    # Load synthetic data
    logging.info(f"Loading synthetic data from: {synthetic_data_path}")
    synthetic_dataset = load_from_disk(synthetic_data_path)
    synthetic_dataset = synthetic_dataset.filter(is_good_or_excellent)
    synthetic_dataset = synthetic_dataset.map(handle_none_negatives)
    synthetic_dataset = synthetic_dataset.rename_column("query", "anchor")
    synthetic_dataset = synthetic_dataset.map(apply_task_prompt)
    synthetic_dataset = synthetic_dataset.rename_column("negative", "negative_1")
    synthetic_dataset = pad_dataset_schema(synthetic_dataset, MAX_NEGATIVES_PER_EXAMPLE)
    synthetic_dataset = synthetic_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)])
    logging.info(f"Loaded {len(synthetic_dataset)} synthetic triplets.")

    # Load wiki clustering data
    cluster_data_path = "/home/ubuntu/work/WSCLToolkit/sent_emb_train/non_vital_code/final_clustering_data.jsonl/"
    cluster_dataset = load_clustering_dataset(
        cluster_data_path,
        max_negatives=MAX_NEGATIVES_PER_EXAMPLE
    )

    # === LOAD MINED DATASETS ===
    logging.info("\n=== Loading Mined Hard Negative Datasets ===")
    logging.info(f"Using up to {MAX_NEGATIVES_PER_EXAMPLE} negatives per example")

    mined_datasets = {}
    for dataset_name, task_category in MINED_DATASET_CONFIG.items():
        dataset = load_mined_dataset(
            dataset_name=dataset_name,
            task_category=task_category,
            max_samples=MAX_SAMPLES,  # Set to a number if you want to limit samples per dataset
            max_negatives=MAX_NEGATIVES_PER_EXAMPLE
        )
        if dataset is not None:
            mined_datasets[dataset_name] = dataset

    logging.info(f"Successfully loaded {len(mined_datasets)} mined datasets")

    # === SPLIT DATASETS FOR EVALUATION ===
    logging.info("\n=== Splitting datasets to create dev sets for TripletEvaluators ===")

    eval_samples_per_dataset = 1000

    # Split NLI
    nli_split = nli_dataset.train_test_split(test_size=eval_samples_per_dataset, seed=42)
    nli_train = nli_split["train"]
    nli_dev = nli_split["test"]

    # Split MS MARCO
    # msmarco_split = msmarco_dataset.train_test_split(
    #     test_size=eval_samples_per_dataset, seed=42
    # )
    # msmarco_train = msmarco_split["train"]
    # msmarco_dev = msmarco_split["test"]

    # Split NQ
    nq_split = nq_dataset.train_test_split(test_size=eval_samples_per_dataset, seed=42)
    nq_train = nq_split["train"]
    nq_dev = nq_split["test"]

    # === COMBINE ALL DATASETS ===
    # Use DatasetDict for stratified sampling across datasets
    datasets = {}

    # Add original datasets
    datasets["NLI"] = nli_train
    # datasets["MSMARCO"] = msmarco_train
    datasets["NQ"] = nq_train
    datasets["Topic_Clustering"] = cluster_dataset
    datasets["Synthetic"] = synthetic_dataset
    # datasets["Synthetic_Classification"] = synthetic_classification_dataset
datasets["Synthetic_Classification"] = synthetic_classification_dataset # Add all mined datasets datasets.update(mined_datasets) # Convert to DatasetDict for stratified batch sampling train_dataset = DatasetDict(datasets) logging.info(f"Created training DatasetDict with {len(train_dataset)} datasets") # === DATASET COMPOSITION TABLE === # For DatasetDict, iterate over the dataset dict total_samples = sum(len(ds) for ds in train_dataset.values()) composition_data = [] for name, ds in train_dataset.items(): num_samples = len(ds) percentage = (num_samples / total_samples) * 100 if total_samples > 0 else 0 composition_data.append({ "Dataset": name, "Number of Examples": num_samples, "Percentage (%)": f"{percentage:.2f}%" }) print("\nšŸ“Š Finetuning Dataset Composition") print("-" * 80) print(f"{'Dataset':<50} | {'Number of Examples':<20} | {'Percentage (%)'}") print("-" * 80) for item in composition_data: print(f"{item['Dataset']:<50} | {item['Number of Examples']:<20} | {item['Percentage (%)']}") print("-" * 80) print(f"{'Total':<50} | {total_samples:<20} | {'100.00%'}") print("-" * 80) logging.info(f"Training with DatasetDict containing {len(train_dataset)} datasets and {total_samples} total samples.") # === CONFIGURE LOSS === loss = CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeightAndMask( model=model, mini_batch_size=config.FT_BS, spread_out_loss_weight=0.1, use_hardness_weighting=True, mask_duplicate_positives=False, hardness_alpha=3.0, scale=50, ) run_name = run_name + "-added datasets" + f"max_{MAX_SAMPLES}_per_dataset" # NLI Evaluator (semantic similarity) - prompts already applied nli_anchors, nli_pos, nli_neg = prepare_triplet_eval_data(nli_dev) nli_triplet_evaluator = TripletEvaluator(nli_anchors, nli_pos, nli_neg, name="nli-triplet-dev") # MS MARCO Evaluator (retrieval - asymmetric) - prompts already applied # msmarco_anchors, msmarco_pos, msmarco_neg = prepare_triplet_eval_data(msmarco_dev) # msmarco_triplet_evaluator = TripletEvaluator(msmarco_anchors, msmarco_pos, msmarco_neg, name="msmarco-triplet-dev") # NQ Evaluator (question answering) - prompts already applied nq_anchors, nq_pos, nq_neg = prepare_triplet_eval_data(nq_dev) nq_triplet_evaluator = TripletEvaluator(nq_anchors, nq_pos, nq_neg, name="nq-triplet-dev") wiki_cluster_eval = CustomJSONLClusteringEvaluator( file_path="/home/ubuntu/work/WSCLToolkit/sent_emb_train/non_vital_code/parsed_wiki_sections.jsonl", min_samples_per_topic=5, max_clusters=500 # <--- Takes 50 random topics (e.g. Afrika, Amager, Kaffe, Volvo...) 
    sentiment_eval = load_swedish_reviews_evaluator(prompt_style="standard")

    # Paraphrase (STS) Evaluator
    print("\n--- Setting up Paraphrase Evaluator from JSONL ---")
    para_sents1 = []
    para_sents2 = []
    para_scores = []
    max_score = 0.0

    with open(config.SWE_PARA_PATH, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            para_sents1.append(data["sentence_1"])
            para_sents2.append(data["sentence_2"])
            score = float(data["label"])
            para_scores.append(score)
            if score > max_score:
                max_score = score

    if max_score > 0:
        normalized_scores = [score / max_score for score in para_scores]
    else:
        normalized_scores = para_scores

    # Apply prompts to paraphrase data
    para_sents1 = ["task: semantic similarity | query: " + s for s in para_sents1]
    para_sents2 = ["task: semantic similarity | query: " + s for s in para_sents2]

    sweparaphrase_evaluator = EmbeddingSimilarityEvaluator(
        sentences1=para_sents1,
        sentences2=para_sents2,
        scores=normalized_scores,
        name="sweparaphrase-dev",
    )

    # FAQ Retrieval Evaluator
    print("\n--- Setting up FAQ Retrieval Evaluator from JSONL ---")
    with open(config.SWE_FAQ_PATH, "r", encoding="utf-8") as f:
        faq_data = [json.loads(line) for line in f]

    unique_answers = {answer for item in faq_data for answer in item["candidate_answers"]}
    answer_to_cid = {answer: f"doc_{i}" for i, answer in enumerate(unique_answers)}

    # Apply prompts to FAQ corpus (documents)
    faq_corpus = {
        cid: "title: none | text: " + answer
        for answer, cid in answer_to_cid.items()
    }

    faq_queries = {}
    faq_relevant_docs = {}
    for i, item in enumerate(faq_data):
        query_id = f"q_{i}"
        faq_queries[query_id] = "task: search result | query: " + item['question']
        correct_answer = item["candidate_answers"][item["label"]]
        correct_cid = answer_to_cid[correct_answer]
        faq_relevant_docs[query_id] = {correct_cid}

    swefaq_evaluator = InformationRetrievalEvaluator(
        queries=faq_queries,
        corpus=faq_corpus,
        relevant_docs=faq_relevant_docs,
        name="swefaq-dev",
    )

    nanobeir_evaluators = create_nanobeir_evaluators(
        languages=["sv", "no"],
        dataset_names=[
            "NanoMSMARCO",
            "NanoNQ",
            "NanoQuoraRetrieval",
            "NanoFEVER",
            "NanoHotpotQA"
        ]
    )

    base_evaluators = [
        nli_triplet_evaluator,
        # msmarco_triplet_evaluator,
        nq_triplet_evaluator,
        sweparaphrase_evaluator,
        swefaq_evaluator,
        wiki_cluster_eval,
        sentiment_eval,
    ]
    main_evaluator = SequentialEvaluator(
        # Drop evaluators that failed to load (sentiment_eval may be None)
        evaluators=[e for e in base_evaluators if e is not None] + nanobeir_evaluators,
    )

    print("\nAll evaluators combined and ready.")

    # === CONFIGURE TRAINING ARGUMENTS ===
    training_args = SentenceTransformerTrainingArguments(
        output_dir=config.FT_OUTPUT_DIR,
        num_train_epochs=1,
        per_device_eval_batch_size=64,
        per_device_train_batch_size=config.FT_CONCEPTUAL_BS,
        learning_rate=config.FT_LR,
        bf16=True,
        report_to="wandb",
        run_name=run_name,
        save_total_limit=2,
        logging_steps=config.LOGGING_STEPS,
        eval_strategy=config.EVAL_STRATEGY,
        save_strategy=config.EVAL_STRATEGY,
        eval_steps=config.FT_EVAL_STEPS,
        save_steps=config.FT_SAVE_STEPS,
        load_best_model_at_end=True,
        warmup_ratio=config.FT_WARMUP,
        weight_decay=config.FT_WEIGHT_DECAY,
        # The MS MARCO triplet evaluator is commented out above, so select the best
        # checkpoint on the NLI triplet dev accuracy instead
        metric_for_best_model="eval_nli-triplet-dev_cosine_accuracy",
        greater_is_better=True,
        # Enable multi-dataset batch sampling for stratified batches
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        ddp_find_unused_parameters=False,
    )

    # === INITIALIZE AND RUN TRAINER ===
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        loss=loss,
        evaluator=main_evaluator,
    )

    logging.info("--- Running initial evaluation on the untrained model (step 0) ---")
    trainer.evaluate()

    logging.info(f"Starting model fine-tuning for run: {run_name}")
for run: {run_name}") trainer.train() logging.info(f"--- Fine-tuning complete for {run_name}! ---") if __name__ == "__main__": parser = argparse.ArgumentParser( description="A flexible fine-tuning script for Sentence Transformer models." ) parser.add_argument( "--fine_tune_model_path", type=str, required=True, help="Path or Hub name of the model to fine-tune.", ) parser.add_argument( "--probe_run", action="store_true", help="If set, runs a short probe run." ) cli_args = parser.parse_args() main(cli_args)