# DID NOT USE THE NEW MSMARCO DATASET, OTHERWISE EVERYTHING IS THE SAME
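"""
Fine-tunes a SentenceTransformer with a cached multiple-negatives ranking loss variant
(spread-out / hardness weighting) on Swedish and Norwegian triplet data: translated NLI
and NQ triplets, mined hard-negative datasets, graded synthetic data, and a Wikipedia
topic-clustering set. Task-specific prompts are applied to every example, and evaluation
covers triplet accuracy, SweParaphrase (STS), SweFAQ retrieval, NanoBEIR (sv/no),
Wikipedia clustering, and a sentiment classification probe.
"""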
import logging
import os
import argparse
import config # Assuming your config file is available
import json
import torch
import pandas as pd
os.environ["HF_HOME"] = config.CACHE_DIR
from pathlib import Path
from datasets import load_dataset, load_from_disk, concatenate_datasets, Dataset, DatasetDict
from sentence_transformers.evaluation import (
SequentialEvaluator,
EmbeddingSimilarityEvaluator,
InformationRetrievalEvaluator,
TripletEvaluator,
)
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import (
SentenceTransformerTrainingArguments,
BatchSamplers,
MultiDatasetBatchSamplers
)
from sentence_transformers import SentenceTransformer, losses, models
import transformers
from src.custom_loss.CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeightAndMask import CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeightAndMask
from src.custom_loss.CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeight import CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeight
import random
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import v_measure_score, adjusted_rand_score, accuracy_score, f1_score
from sentence_transformers.evaluation import SentenceEvaluator
from sklearn.linear_model import LogisticRegression
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
transformers.logging.set_verbosity_info()
# --- MINED DATASET CONFIGURATION ---
MINED_DATASETS_BASE_PATH = Path("/mnt/disk2/translated_datasets")
# Map dataset names to their task categories
MINED_DATASET_CONFIG = {
"gooaq_train_swedish_triplets_scored": "question answering",
"reddit_train_swedish_triplets_scored": "retrieval",
"xsum_train_swedish_triplets_scored": "clustering",
"simple-wiki_train_swedish_triplets_scored": "semantic similarity",
"s2orc_train_swedish_saved_triplets_scored": "retrieval",
"amazon-reviews_train_swedish_saved_triplets_scored": "retrieval",
"paq_train_swedish_queries_retranslated_triplets_scored": "question answering",
"stackexchange-duplicates_train_swedish_triplets_scored": "semantic similarity",
"wikipedia-sections_train_swedish_triplets_scored": "retrieval",
"msmarco_triplets_swedish_triplets_scored" : "retrieval",
}
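# The task category selects which prompt prefix apply_task_prompt() attaches to each example.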
synthetic_data_path = "/mnt/disk2/combined_synthetic_data_deduplicated_classified"
#synthetic_classification_data_path = "/mnt/disk2/classification_dataset_graded_subset.jsonl"
#combined_synthetic_data_deduplicated_classified
# Maximum number of negatives to use per example from mined datasets
# The loss function (CachedMultipleNegativesRankingLoss) can handle multiple negatives
# Format: (anchor, positive, negative_1, negative_2, ..., negative_n)
MAX_NEGATIVES_PER_EXAMPLE = 10
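# Cap on the number of training examples loaded from each mined dataset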
MAX_SAMPLES = 500_000
NANOBEIR_DATASETS = [
"NanoArguAna",
"NanoClimateFEVER",
"NanoDBPedia",
"NanoFEVER",
"NanoFiQA2018",
"NanoHotpotQA",
"NanoMSMARCO",
"NanoNFCorpus",
"NanoNQ",
"NanoQuoraRetrieval",
"NanoSCIDOCS",
"NanoSciFact",
"NanoTouche2020",
]
# Map dataset names to task types for appropriate prompting
NANOBEIR_TASK_TYPES = {
"NanoArguAna": "retrieval",
"NanoClimateFEVER": "retrieval",
"NanoDBPedia": "retrieval",
"NanoFEVER": "retrieval",
"NanoFiQA2018": "retrieval",
"NanoHotpotQA": "retrieval",
"NanoMSMARCO": "retrieval",
"NanoNFCorpus": "retrieval",
"NanoNQ": "question answering",
"NanoQuoraRetrieval": "semantic similarity",
"NanoSCIDOCS": "retrieval",
"NanoSciFact": "retrieval",
"NanoTouche2020": "retrieval",
}
class ProbeClassificationEvaluator(SentenceEvaluator):
"""
Generic evaluator that trains a Logistic Regression probe on train_set
and evaluates accuracy on test_set.
"""
def __init__(self, dataset_name, sentences_train, labels_train, sentences_test, labels_test, batch_size=32, prompt_prefix=""):
self.name = f"eval_{dataset_name}_classification"
self.sentences_train = [prompt_prefix + s for s in sentences_train]
self.labels_train = labels_train
self.sentences_test = [prompt_prefix + s for s in sentences_test]
self.labels_test = labels_test
self.batch_size = batch_size
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
# 1. Encode
emb_train = model.encode(self.sentences_train, batch_size=self.batch_size, show_progress_bar=False)
emb_test = model.encode(self.sentences_test, batch_size=self.batch_size, show_progress_bar=False)
# 2. Train Probe (Fast LR)
clf = LogisticRegression(random_state=42, solver='lbfgs', max_iter=100, n_jobs=-1)
clf.fit(emb_train, self.labels_train)
# 3. Predict
preds = clf.predict(emb_test)
acc = accuracy_score(self.labels_test, preds)
logging.info(f"{self.name}: Accuracy = {acc:.4f}")
return acc
def load_swedish_reviews_evaluator(prompt_style="standard"):
"""
Loads 'timpal0l/swedish_reviews' for binary sentiment classification.
Safe from MTEB leakage.
"""
try:
logging.info("Loading timpal0l/swedish_reviews for Sentiment Probe...")
# Load the test split (approx 10k samples, good for eval)
dataset = load_dataset("timpal0l/swedish_reviews", split="test")
        # Downsample for speed: keep only 50 samples for this quick probe
if len(dataset) > 50:
dataset = dataset.shuffle(seed=42).select(range(50))
sentences = dataset["text"]
labels = dataset["label"]
# Define Prompt
if prompt_style == "standard":
# "task: classification" is the standard prompt for this
prompt = "task: classification | query: "
else:
prompt = ""
# Split 50/50 for Train/Test Probe
# We need to train the logistic regression probe on *some* data to test the embeddings
split_idx = len(sentences) // 2
evaluator = ProbeClassificationEvaluator(
dataset_name="SwedishReviews-Sentiment",
sentences_train=sentences[:split_idx],
labels_train=labels[:split_idx],
sentences_test=sentences[split_idx:],
labels_test=labels[split_idx:],
prompt_prefix=prompt,
batch_size=32
)
return evaluator
except Exception as e:
logging.warning(f"Failed to load Swedish Reviews: {e}")
return None
class CustomJSONLClassificationEvaluator(SentenceEvaluator):
"""
Reads a JSONL file and trains a Logistic Regression probe to predict 'main_title'.
This checks if the topics are linearly separable in the embedding space.
"""
def __init__(self, file_path, min_samples_per_label=5, max_classes=20, batch_size=32):
self.name = "eval_wiki_classification"
self.batch_size = batch_size
self.prompt = "task: classification | query: "
self.sentences_train = []
self.labels_train = []
self.sentences_test = []
self.labels_test = []
logging.info(f"Loading classification data from {file_path}...")
# 1. Load Data by Class
label_map = {} # { "Afrika": ["text1", "text2"...] }
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
try:
row = json.loads(line)
text = row.get('text', '').strip()
label = row.get('main_title', '').strip()
if text and label:
if label not in label_map:
label_map[label] = []
label_map[label].append(text)
                except (json.JSONDecodeError, AttributeError):
                    # Skip malformed or non-string lines
                    continue
# 2. Filter Classes
valid_labels = [l for l, texts in label_map.items() if len(texts) >= min_samples_per_label]
valid_labels.sort() # Deterministic order
# 3. Subsample Classes (Avoid training on 5000 classes)
rng = random.Random(42)
if len(valid_labels) > max_classes:
selected_labels = rng.sample(valid_labels, max_classes)
else:
selected_labels = valid_labels
logging.info(f"Classification Probe: Using {len(selected_labels)} classes (topics).")
# 4. Create Train/Test Split (80/20 per class)
for label in selected_labels:
texts = label_map[label]
# Shuffle texts for this label
rng.shuffle(texts)
# Use max 50 samples per class to keep probe fast
texts = texts[:50]
split_idx = int(0.8 * len(texts))
# Train set
for t in texts[:split_idx]:
self.sentences_train.append(self.prompt + t)
self.labels_train.append(label)
# Test set
for t in texts[split_idx:]:
self.sentences_test.append(self.prompt + t)
self.labels_test.append(label)
logging.info(f"Probe Sizes: Train={len(self.sentences_train)}, Test={len(self.sentences_test)}")
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
if not self.sentences_train:
return 0.0
# 1. Encode
emb_train = model.encode(self.sentences_train, batch_size=self.batch_size, show_progress_bar=False)
emb_test = model.encode(self.sentences_test, batch_size=self.batch_size, show_progress_bar=False)
# 2. Train Probe (Fast Logistic Regression)
clf = LogisticRegression(random_state=42, solver='lbfgs', max_iter=100, n_jobs=-1)
clf.fit(emb_train, self.labels_train)
# 3. Predict & Score
preds = clf.predict(emb_test)
acc = accuracy_score(self.labels_test, preds)
# Macro F1 is better if classes are imbalanced
f1 = f1_score(self.labels_test, preds, average='macro')
logging.info(f"{self.name}: Accuracy={acc:.4f} | F1-Macro={f1:.4f}")
return acc
class CustomJSONLClusteringEvaluator(SentenceEvaluator):
"""
Evaluates clustering on a specific number of distinct Wikipedia topics.
"""
def __init__(self, file_path, min_samples_per_topic=5, max_clusters=50, batch_size=32):
self.name = "eval_wiki_clustering"
self.batch_size = batch_size
self.prompt = "task: clustering | query: "
self.sentences = []
self.labels = []
logging.info(f"Loading clustering data from {file_path}...")
# 1. Load ALL data into memory organized by Topic
# Structure: { "Afrika": ["text1", "text2"...], "Amager": ["text1"...] }
topic_map = {}
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
try:
row = json.loads(line)
text = row.get('text', '').strip()
# Use 'main_title' as the Topic Label
label = row.get('main_title', '').strip()
if text and label:
if label not in topic_map:
topic_map[label] = []
topic_map[label].append(text)
                except (json.JSONDecodeError, AttributeError):
                    # Skip malformed or non-string lines
                    continue
# 2. Filter topics that are too small
# We need topics with enough content to actually cluster meaningful points
valid_topics = [
topic for topic, texts in topic_map.items()
if len(texts) >= min_samples_per_topic
]
logging.info(f"Found {len(valid_topics)} valid topics with >= {min_samples_per_topic} samples.")
# 3. Sample exactly 'max_clusters' topics (e.g., 50)
# We sort first to make the random seed deterministic across runs
valid_topics.sort()
# Use a fixed seed so the "random" 50 topics are the same every time you restart training
rng = random.Random(42)
if len(valid_topics) > max_clusters:
selected_topics = rng.sample(valid_topics, max_clusters)
logging.info(f"Subsampled to {max_clusters} distinct topics for evaluation.")
else:
selected_topics = valid_topics
logging.info(f"Using all {len(selected_topics)} topics (fewer than max_clusters).")
# 4. Flatten the data for the evaluator
for topic in selected_topics:
# You can also limit samples per topic here (e.g. max 20 paragraphs per topic) to keep it balanced
texts = topic_map[topic][:20]
for text in texts:
self.sentences.append(self.prompt + text)
self.labels.append(topic)
self.n_clusters = len(selected_topics)
logging.info(f"Final Clustering Probe: {len(self.sentences)} samples across {self.n_clusters} topics.")
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
if not self.sentences:
return 0.0
# Encode
embeddings = model.encode(self.sentences, batch_size=self.batch_size, show_progress_bar=False)
# Cluster
clustering = MiniBatchKMeans(
n_clusters=self.n_clusters,
batch_size=256,
random_state=42,
n_init='auto'
)
clustering.fit(embeddings)
# Score
v_score = v_measure_score(self.labels, clustering.labels_)
ari_score = adjusted_rand_score(self.labels, clustering.labels_)
logging.info(f"{self.name}: V-Measure={v_score:.4f} | ARI={ari_score:.4f}")
return v_score
def load_nanobeir_evaluator(dataset_name, language, task_type="retrieval"):
"""
Load a NanoBEIR dataset for a specific language and create an InformationRetrievalEvaluator.
Args:
dataset_name: Name of the NanoBEIR dataset (e.g., "NanoMSMARCO")
language: Language code ("sv" for Swedish, "no" for Norwegian)
task_type: Task type for prompting ("retrieval", "question answering", "semantic similarity")
Returns:
InformationRetrievalEvaluator or None if loading fails
"""
try:
logging.info(f"Loading {dataset_name} for language: {language}")
# Load corpus (documents)
corpus_data = load_dataset(
"lightonai/nanobeir-multilingual",
f"{dataset_name}_{language}",
split="corpus"
)
# Load queries
queries_data = load_dataset(
"lightonai/nanobeir-multilingual",
f"{dataset_name}_{language}",
split="queries"
)
# Load qrels (relevance judgments) - language independent
qrels_data = load_dataset(
"lightonai/nanobeir-multilingual",
dataset_name,
split="qrels"
)
# Apply task-specific prompts
if task_type == "retrieval":
query_prompt = "task: search result | query: "
doc_prompt = "title: none | text: "
elif task_type == "question answering":
query_prompt = "task: question answering | query: "
doc_prompt = "title: none | text: "
elif task_type == "semantic similarity":
query_prompt = "task: semantic similarity | query: "
doc_prompt = "task: semantic similarity | query: "
else:
query_prompt = ""
doc_prompt = ""
# Build corpus dictionary with prompts
corpus = {}
for item in corpus_data:
doc_id = item['_id']
# Combine title and text if title exists
if 'title' in item and item['title']:
text = f"{item['title']} {item['text']}"
else:
text = item['text']
corpus[doc_id] = doc_prompt + text
# Build queries dictionary with prompts
queries = {}
for item in queries_data:
query_id = item['_id']
queries[query_id] = query_prompt + item['text']
# Build relevance dictionary
# Note: NanoBEIR qrels only have query-id and corpus-id (no score column)
# All entries in qrels are considered relevant
relevant_docs = {}
for item in qrels_data:
query_id = item['query-id']
corpus_id = item['corpus-id']
if query_id not in relevant_docs:
relevant_docs[query_id] = set()
relevant_docs[query_id].add(corpus_id)
# Filter queries to only those with relevant documents
queries = {qid: text for qid, text in queries.items() if qid in relevant_docs}
if not queries:
logging.warning(f"No queries with relevant documents found for {dataset_name}_{language}")
return None
# Create evaluator
evaluator_name = f"{dataset_name}-{language}-dev"
evaluator = InformationRetrievalEvaluator(
queries=queries,
corpus=corpus,
relevant_docs=relevant_docs,
name=evaluator_name,
            # Restrict each metric to a single cutoff to reduce clutter; NDCG@10 is the headline metric
            mrr_at_k=[10],
            ndcg_at_k=[10],
            accuracy_at_k=[1],
            precision_recall_at_k=[1],
            map_at_k=[100],
)
logging.info(f"Created {evaluator_name} with {len(queries)} queries and {len(corpus)} documents")
return evaluator
except Exception as e:
logging.error(f"Failed to load {dataset_name} for {language}: {e}")
return None
def create_nanobeir_evaluators(languages=["sv", "no"], dataset_names=None):
"""
Create InformationRetrievalEvaluators for multiple NanoBEIR datasets and languages.
Args:
languages: List of language codes (default: ["sv", "no"])
dataset_names: List of dataset names to use (default: all NANOBEIR_DATASETS)
Returns:
List of InformationRetrievalEvaluators
"""
if dataset_names is None:
dataset_names = NANOBEIR_DATASETS
evaluators = []
for language in languages:
logging.info(f"\n=== Loading NanoBEIR datasets for {language.upper()} ===")
for dataset_name in dataset_names:
task_type = NANOBEIR_TASK_TYPES.get(dataset_name, "retrieval")
evaluator = load_nanobeir_evaluator(dataset_name, language, task_type)
if evaluator is not None:
evaluators.append(evaluator)
logging.info(f"\nSuccessfully created {len(evaluators)} NanoBEIR evaluators")
return evaluators
def handle_none_negatives(example):
"""Replaces None in the 'negative' column with an empty string."""
if example["negative"] is None:
example["negative"] = ""
return example
def is_good_or_excellent(example):
"""Filter function to keep only good or excellent graded examples."""
if "dialect" in example["task_description"].lower():
return False
return example['grade'] in ['good', 'excellent']
def prepare_triplet_eval_data(dev_set):
"""
Prepares anchors, positives, and negatives for a TripletEvaluator.
Extracts raw data WITHOUT prompts - prompts will be applied separately.
Handles both 'negative' and 'negative_1' column names.
"""
anchors = dev_set["anchor"]
positives = dev_set["positive"]
# Handle both column naming conventions
if "negative" in dev_set.column_names:
negatives = dev_set["negative"]
elif "negative_1" in dev_set.column_names:
negatives = dev_set["negative_1"]
else:
raise ValueError(f"No negative column found. Available columns: {dev_set.column_names}")
return anchors, positives, negatives
def load_clustering_dataset(dataset_path, max_negatives=5):
"""
Loads the pre-saved Clustering dataset from disk.
Applies clustering prompts and handles schema consistency for DDP.
"""
if not os.path.exists(dataset_path):
logging.warning(f"Clustering data not found at: {dataset_path}")
return None
logging.info(f"Loading clustering dataset from disk: {dataset_path}")
dataset = load_from_disk(str(dataset_path))
# 1. Ensure we don't have None values (Tokenizer crash protection)
def fill_empty(example):
for key in example.keys():
if example[key] is None:
example[key] = ""
return example
dataset = dataset.map(fill_empty)
# 2. Add clustering category and apply prompts BEFORE padding
dataset = dataset.add_column("new_category", ["clustering"] * len(dataset))
dataset = dataset.map(apply_task_prompt)
dataset = dataset.remove_columns(['new_category'])
# 3. Enforce exact column schema for DDP - pad AFTER applying prompts
desired_cols = ["anchor", "positive"] + [f"negative_{i+1}" for i in range(max_negatives)]
# Pad if we have fewer negatives than the global config
for col in desired_cols:
if col not in dataset.column_names:
dataset = dataset.add_column(col, [""] * len(dataset))
# Select only the needed columns (drops extra negatives if you have > max)
dataset = dataset.select_columns(desired_cols)
logging.info(f"Loaded {len(dataset)} clustering examples with prompts applied.")
return dataset
def apply_task_prompt(example, dropout_rate=0.1):
"""
Applies strict task-specific prompts based on the provided schema.
Skips empty strings to avoid adding prompts to padding columns.
"""
task_type = example["new_category"]
if random.random() < dropout_rate:
return example
# --- 1. RETRIEVAL (Asymmetric) ---
# Anchor: "task: search result | query: {content}"
# Docs: "title: none | text: {content}"
if task_type == "retrieval":
if example['anchor']: # Only apply if not empty
example['anchor'] = "task: search result | query: " + example['anchor']
# Define the document prompt (assuming no title column available, using 'none')
doc_prompt = "title: none | text: "
if example['positive']: # Only apply if not empty
example['positive'] = doc_prompt + example['positive']
# Apply to all negatives
for key in list(example.keys()):
if key.startswith('negative') and example[key]: # This already checks for non-empty
example[key] = doc_prompt + example[key]
# --- 2. QUESTION ANSWERING (Symmetric) ---
# User Request: Same prompt for anchor and positive.
# Prompt: "task: question answering | query: {content}"
elif task_type == "question answering":
instruct = "task: question answering | query: "
if example['anchor']:
example['anchor'] = instruct + example['anchor']
if example['positive']:
example['positive'] = instruct + example['positive']
for key in list(example.keys()):
if key.startswith('negative') and example[key]:
example[key] = instruct + example[key]
# --- 3. CLUSTERING (Symmetric) ---
# Prompt: "task: clustering | query: {content}"
elif task_type == "clustering":
instruct = "task: clustering | query: "
if example['anchor']:
example['anchor'] = instruct + example['anchor']
if example['positive']:
example['positive'] = instruct + example['positive']
for key in list(example.keys()):
if key.startswith('negative') and example[key]:
example[key] = instruct + example[key]
# --- 4. CLASSIFICATION (Symmetric) ---
# Prompt: "task: classification | query: {content}"
elif task_type == "classification":
instruct = "task: classification | query: "
if example['anchor']:
example['anchor'] = instruct + example['anchor']
if example['positive']:
example['positive'] = instruct + example['positive']
for key in list(example.keys()):
if key.startswith('negative') and example[key]:
example[key] = instruct + example[key]
# --- 5. SEMANTIC SIMILARITY (Symmetric) ---
    # Prompt: "task: semantic similarity | query: {content}"
elif task_type == "semantic similarity":
instruct = "task: semantic similarity | query: "
if example['anchor']:
example['anchor'] = instruct + example['anchor']
if example['positive']:
example['positive'] = instruct + example['positive']
for key in list(example.keys()):
if key.startswith('negative') and example[key]:
example[key] = instruct + example[key]
# Fallback
else:
logging.warning(f"Unknown task category: {task_type}. No prompt applied.")
return example
def load_mined_dataset(dataset_name, task_category, max_samples=None, max_negatives=10):
"""
Load a single mined dataset and prepare it for training.
Args:
dataset_name: Name of the dataset directory
task_category: Task category for prompting
max_samples: Optional limit on number of samples to load
max_negatives: Maximum number of negatives to include per example (default: 10)
Returns:
Processed dataset with prompts applied, with multiple negatives per example
"""
dataset_path = MINED_DATASETS_BASE_PATH / dataset_name
if not dataset_path.exists():
logging.warning(f"Dataset path does not exist: {dataset_path}. Skipping.")
return None
logging.info(f"Loading mined dataset from: {dataset_path}")
dataset = load_from_disk(str(dataset_path))
if max_samples and len(dataset) > max_samples:
logging.info(f"Sampling {max_samples} from {len(dataset)} samples")
dataset = dataset.shuffle(seed=42).select(range(max_samples))
logging.info(f"Loaded {len(dataset)} samples from {dataset_name}")
# The mined datasets have structure: anchor, positive, pos_score, negatives (list), neg_scores (list)
# We need to convert to: anchor, positive, negative_1, negative_2, ..., negative_n
# BUT we only keep negatives that actually exist (no None/empty values)
def expand_negatives(example):
"""
Convert the list of negatives into separate columns.
Ensures ALL columns from negative_1 to negative_{max_negatives} exist.
        Pads missing negatives with empty strings.
"""
result = {
'anchor': example['anchor'],
'positive': example['positive'],
}
# Get the list of negatives, ensure it's a list
raw_negatives = example.get('negatives', [])
if raw_negatives is None:
raw_negatives = []
# Filter out empty strings or None values from the source list
valid_negatives = [n for n in raw_negatives if n]
# --- KEY FIX IS HERE ---
# We must iterate up to max_negatives every time to ensure the
# dictionary keys (schema) are identical for every row.
for i in range(max_negatives):
key_name = f'negative_{i+1}'
if i < len(valid_negatives):
# We have a valid negative
result[key_name] = valid_negatives[i]
else:
                # We ran out of negatives -> fill with an empty string
                # so the column exists (with a string type) in the Arrow table
result[key_name] = ''
# Check if we have at least one negative
result['_has_negatives'] = len(valid_negatives) > 0
return result
dataset = dataset.map(expand_negatives, remove_columns=dataset.column_names)
# Filter out any examples without valid negatives
dataset = dataset.filter(lambda x: x['_has_negatives'])
dataset = dataset.remove_columns(['_has_negatives'])
if len(dataset) == 0:
logging.warning(f"No valid examples with negatives found in {dataset_name}")
return None
# Add task category for prompting
dataset = dataset.add_column("new_category", [task_category] * len(dataset))
dataset = dataset.map(apply_task_prompt)
dataset = dataset.remove_columns(['new_category'])
    # Count the negative_* columns (including padded ones) for logging
    sample = dataset[0]
    num_negative_cols = sum(1 for k in sample.keys() if k.startswith('negative_'))
    logging.info(f"Prepared {len(dataset)} training examples from {dataset_name}")
    logging.info(f"Examples have between 1 and {num_negative_cols} non-empty negatives each")
desired_columns = ["anchor", "positive"]
# scan for all negative columns that actually exist in the dataset
existing_columns = dataset.column_names
negative_cols = [c for c in existing_columns if c.startswith("negative_")]
# Sort them to ensure negative_1 comes before negative_2, etc.
# We sort by the integer number in the column name
negative_cols.sort(key=lambda x: int(x.split('_')[1]))
desired_columns.extend(negative_cols)
# 2. Force the dataset to use this specific order
dataset = dataset.select_columns(desired_columns)
logging.info(f"Enforced column order: {dataset.column_names}")
return dataset
def pad_dataset_schema(dataset, total_negatives=5):
"""
Ensures the dataset has columns negative_1 to negative_{total_negatives}.
Fills missing columns with empty strings.
"""
new_columns = {}
existing_cols = dataset.column_names
for i in range(total_negatives):
col_name = f"negative_{i+1}"
if col_name not in existing_cols:
# Create a column of empty strings efficiently
new_columns[col_name] = [""] * len(dataset)
    if new_columns:
        # Add the missing negative_* columns one by one
        for col_name, data in new_columns.items():
            dataset = dataset.add_column(col_name, data)
return dataset
def main(args):
run_type = "probe" if args.probe_run else "full"
base_model_name = Path(args.fine_tune_model_path).name
data_name = "ms_marco_nli_mined"
run_name = f"finetune-CachedMNRL-{base_model_name}-on-{data_name}-{run_type}"
output_dir = os.path.join(config.OUTPUT_DIR, run_name)
logging.info(f"--- Starting {run_type.upper()} Run: {run_name} ---")
logging.info(f"Fine-tuning model from: {args.fine_tune_model_path}")
model = SentenceTransformer(args.fine_tune_model_path)
logging.info(model)
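    # Some tokenizers return token_type_ids that the wrapped transformer's forward() does not
    # accept (assumed to be the case for this checkpoint), which would otherwise raise a TypeError.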
logging.info("Patching the model's `tokenize` method to remove token_type_ids...")
original_tokenize = model.tokenize
def patched_tokenize(*args, **kwargs):
# Call the original tokenizer to get the encoded inputs
tokenized_output = original_tokenize(*args, **kwargs)
# Remove the unwanted key from the output
if "token_type_ids" in tokenized_output:
del tokenized_output["token_type_ids"]
return tokenized_output
# Replace the original method with our patched version
model.tokenize = patched_tokenize
logging.info("Model's `tokenize` method patched successfully.")
logging.info(f"Setting model max length to {config.FT_MODEL_MAX_LENGTH}")
model.max_seq_length = config.FT_MODEL_MAX_LENGTH
# === LOAD ORIGINAL DATASETS ===
nli_data_path = "/mnt/disk2/snli_triplets_swedish" # sts
msmarco_data_path = "/mnt/disk2/msmarco_triplets_swedish" # retrieval
nq_data_path = "/mnt/disk2/nq_triplets_swedish" # retrieval
# Load NLI data
logging.info(f"Loading translated NLI triplets from: {nli_data_path}")
nli_dataset = load_from_disk(nli_data_path)
nli_dataset = nli_dataset.add_column("new_category", ["semantic similarity"] * len(nli_dataset))
nli_dataset = nli_dataset.map(apply_task_prompt)
nli_dataset = nli_dataset.rename_column("negative", "negative_1")
    nli_dataset = pad_dataset_schema(nli_dataset, MAX_NEGATIVES_PER_EXAMPLE)  # pad to the shared negative_1..N schema
nli_dataset = nli_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)])
logging.info(f"Loaded {len(nli_dataset)} NLI triplets.")
# Load MS MARCO data
# logging.info(f"Loading translated MSMARCO triplets from: {msmarco_data_path}")
# msmarco_dataset = load_from_disk(msmarco_data_path)
# msmarco_dataset = msmarco_dataset.rename_column("query", "anchor")
# msmarco_dataset = msmarco_dataset.add_column("new_category", ["retrieval"] * len(msmarco_dataset))
# msmarco_dataset = msmarco_dataset.map(apply_task_prompt)
# msmarco_dataset = msmarco_dataset.rename_column("negative", "negative_1")
# msmarco_dataset = pad_dataset_schema(msmarco_dataset, MAX_NEGATIVES_PER_EXAMPLE) # <--- ADD THIS
# msmarco_dataset = msmarco_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)])
# logging.info(f"Loaded {len(msmarco_dataset)} MSMARCO triplets.")
# Load NQ data
logging.info(f"Loading translated NQ triplets from: {nq_data_path}")
nq_dataset = load_from_disk(nq_data_path)
nq_dataset = nq_dataset.rename_column("query", "anchor")
nq_dataset = nq_dataset.add_column("new_category", ["question answering"] * len(nq_dataset))
nq_dataset = nq_dataset.map(apply_task_prompt)
nq_dataset = nq_dataset.rename_column("negative", "negative_1")
    nq_dataset = pad_dataset_schema(nq_dataset, MAX_NEGATIVES_PER_EXAMPLE)  # pad to the shared negative_1..N schema
nq_dataset = nq_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)])
logging.info(f"Loaded {len(nq_dataset)} NQ triplets.")
logging.info(f"Loading synthetic data from: {synthetic_data_path}")
synthetic_dataset = load_from_disk(synthetic_data_path)
synthetic_dataset = synthetic_dataset.filter(is_good_or_excellent)
synthetic_dataset = synthetic_dataset.map(handle_none_negatives)
synthetic_dataset = synthetic_dataset.rename_column("query", "anchor")
synthetic_dataset = synthetic_dataset.map(apply_task_prompt)
synthetic_dataset = synthetic_dataset.rename_column("negative", "negative_1")
synthetic_dataset = pad_dataset_schema(synthetic_dataset, MAX_NEGATIVES_PER_EXAMPLE)
synthetic_dataset = synthetic_dataset.select_columns(["anchor", "positive"] + [f"negative_{i+1}" for i in range(MAX_NEGATIVES_PER_EXAMPLE)])
logging.info(f"Loaded {len(synthetic_dataset)} synthetic triplets.")
# Load wiki clustering data
cluster_data_path = "/home/ubuntu/work/WSCLToolkit/sent_emb_train/non_vital_code/final_clustering_data.jsonl/"
cluster_dataset = load_clustering_dataset(
cluster_data_path,
max_negatives=MAX_NEGATIVES_PER_EXAMPLE
)
# === LOAD MINED DATASETS ===
logging.info("\n=== Loading Mined Hard Negative Datasets ===")
logging.info(f"Using up to {MAX_NEGATIVES_PER_EXAMPLE} negatives per example")
mined_datasets = {}
for dataset_name, task_category in MINED_DATASET_CONFIG.items():
dataset = load_mined_dataset(
dataset_name=dataset_name,
task_category=task_category,
            max_samples=MAX_SAMPLES,  # Cap samples per mined dataset
max_negatives=MAX_NEGATIVES_PER_EXAMPLE
)
if dataset is not None:
mined_datasets[dataset_name] = dataset
logging.info(f"Successfully loaded {len(mined_datasets)} mined datasets")
# === SPLIT DATASETS FOR EVALUATION ===
logging.info("\n=== Splitting datasets to create dev sets for TripletEvaluators ===")
eval_samples_per_dataset = 1000
# Split NLI
nli_split = nli_dataset.train_test_split(test_size=eval_samples_per_dataset, seed=42)
nli_train = nli_split["train"]
nli_dev = nli_split["test"]
# Split MS MARCO
# msmarco_split = msmarco_dataset.train_test_split(
# test_size=eval_samples_per_dataset, seed=42
# )
# msmarco_train = msmarco_split["train"]
# msmarco_dev = msmarco_split["test"]
# Split NQ
nq_split = nq_dataset.train_test_split(test_size=eval_samples_per_dataset, seed=42)
nq_train = nq_split["train"]
nq_dev = nq_split["test"]
# === COMBINE ALL DATASETS ===
# Use DatasetDict for stratified sampling across datasets
datasets = {}
# Add original datasets
datasets["NLI"] = nli_train
#datasets["MSMARCO"] = msmarco_train
datasets["NQ"] = nq_train
datasets["Topic_Clustering"] = cluster_dataset
datasets["Synthetic"] = synthetic_dataset
# datasets["Synthetic_Classification"] = synthetic_classification_dataset
# Add all mined datasets
datasets.update(mined_datasets)
# Convert to DatasetDict for stratified batch sampling
train_dataset = DatasetDict(datasets)
logging.info(f"Created training DatasetDict with {len(train_dataset)} datasets")
# === DATASET COMPOSITION TABLE ===
# For DatasetDict, iterate over the dataset dict
total_samples = sum(len(ds) for ds in train_dataset.values())
composition_data = []
for name, ds in train_dataset.items():
num_samples = len(ds)
percentage = (num_samples / total_samples) * 100 if total_samples > 0 else 0
composition_data.append({
"Dataset": name,
"Number of Examples": num_samples,
"Percentage (%)": f"{percentage:.2f}%"
})
print("\n📊 Finetuning Dataset Composition")
print("-" * 80)
print(f"{'Dataset':<50} | {'Number of Examples':<20} | {'Percentage (%)'}")
print("-" * 80)
for item in composition_data:
print(f"{item['Dataset']:<50} | {item['Number of Examples']:<20} | {item['Percentage (%)']}")
print("-" * 80)
print(f"{'Total':<50} | {total_samples:<20} | {'100.00%'}")
print("-" * 80)
logging.info(f"Training with DatasetDict containing {len(train_dataset)} datasets and {total_samples} total samples.")
# === CONFIGURE LOSS ===
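    # CachedMNRL-style losses process the large per-device batch in mini_batch_size chunks with
    # gradient caching, so in-batch negatives span the whole conceptual batch at bounded memory.
    # scale is the usual MNRL similarity scale; spread_out_loss_weight / hardness_alpha /
    # mask_duplicate_positives are parameters of this custom variant (semantics assumed from their names).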
loss = CachedMultipleNegativesRankingLossWithSpreadOutHardnessWeightAndMask(
model=model,
mini_batch_size=config.FT_BS,
spread_out_loss_weight=0.1,
use_hardness_weighting=True,
mask_duplicate_positives=False,
hardness_alpha=3.0,
scale=50,
)
    run_name = run_name + f"-added-datasets-max_{MAX_SAMPLES}_per_dataset"
# NLI Evaluator (semantic similarity) - prompts already applied
nli_anchors, nli_pos, nli_neg = prepare_triplet_eval_data(nli_dev)
nli_triplet_evaluator = TripletEvaluator(nli_anchors, nli_pos, nli_neg, name="nli-triplet-dev")
# MS MARCO Evaluator (retrieval - asymmetric) - prompts already applied
# msmarco_anchors, msmarco_pos, msmarco_neg = prepare_triplet_eval_data(msmarco_dev)
# msmarco_triplet_evaluator = TripletEvaluator(msmarco_anchors, msmarco_pos, msmarco_neg, name="msmarco-triplet-dev")
# NQ Evaluator (question answering) - prompts already applied
nq_anchors, nq_pos, nq_neg = prepare_triplet_eval_data(nq_dev)
nq_triplet_evaluator = TripletEvaluator(nq_anchors, nq_pos, nq_neg, name="nq-triplet-dev")
wiki_cluster_eval = CustomJSONLClusteringEvaluator(
file_path="/home/ubuntu/work/WSCLToolkit/sent_emb_train/non_vital_code/parsed_wiki_sections.jsonl",
min_samples_per_topic=5,
        max_clusters=500  # Subsample up to 500 random topics (e.g. Afrika, Amager, Kaffe, Volvo...)
)
sentiment_eval = load_swedish_reviews_evaluator(prompt_style="standard")
# Paraphrase (STS) Evaluator
print("\n--- Setting up Paraphrase Evaluator from JSONL ---")
para_sents1 = []
para_sents2 = []
para_scores = []
max_score = 0.0
with open(config.SWE_PARA_PATH, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
para_sents1.append(data["sentence_1"])
para_sents2.append(data["sentence_2"])
score = float(data["label"])
para_scores.append(score)
if score > max_score:
max_score = score
if max_score > 0:
normalized_scores = [score / max_score for score in para_scores]
else:
normalized_scores = para_scores
# Apply prompts to paraphrase data
para_sents1 = ["task: semantic similarity | query: " + s for s in para_sents1]
para_sents2 = ["task: semantic similarity | query: " + s for s in para_sents2]
sweparaphrase_evaluator = EmbeddingSimilarityEvaluator(
sentences1=para_sents1,
sentences2=para_sents2,
scores=normalized_scores,
name="sweparaphrase-dev",
)
# FAQ Retrieval Evaluator
print("\n--- Setting up FAQ Retrieval Evaluator from JSONL ---")
with open(config.SWE_FAQ_PATH, "r", encoding="utf-8") as f:
faq_data = [json.loads(line) for line in f]
unique_answers = {answer for item in faq_data for answer in item["candidate_answers"]}
answer_to_cid = {answer: f"doc_{i}" for i, answer in enumerate(unique_answers)}
# Apply prompts to FAQ corpus (documents)
faq_corpus = {
cid: "title: none | text: " + answer
for answer, cid in answer_to_cid.items()
}
faq_queries = {}
faq_relevant_docs = {}
for i, item in enumerate(faq_data):
query_id = f"q_{i}"
faq_queries[query_id] = "task: search result | query: " + item['question']
correct_answer = item["candidate_answers"][item["label"]]
correct_cid = answer_to_cid[correct_answer]
faq_relevant_docs[query_id] = {correct_cid}
swefaq_evaluator = InformationRetrievalEvaluator(
queries=faq_queries,
corpus=faq_corpus,
relevant_docs=faq_relevant_docs,
name="swefaq-dev",
)
nanobeir_evaluators = create_nanobeir_evaluators(
languages=["sv", "no"],
dataset_names=[
"NanoMSMARCO",
"NanoNQ",
"NanoQuoraRetrieval",
"NanoFEVER",
"NanoHotpotQA"
]
)
    # Drop any evaluators whose loaders failed and returned None
    core_evaluators = [
        nli_triplet_evaluator,
        # msmarco_triplet_evaluator,
        nq_triplet_evaluator,
        sweparaphrase_evaluator,
        swefaq_evaluator,
        wiki_cluster_eval,
        sentiment_eval,
    ]
    main_evaluator = SequentialEvaluator(
        evaluators=[e for e in core_evaluators if e is not None] + nanobeir_evaluators,
    )
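    # SequentialEvaluator runs each evaluator in turn; model selection relies on the explicit
    # metric_for_best_model set in the training arguments below, not on an aggregate score.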
print("\nAll evaluators combined and ready.")
# === CONFIGURE TRAINING ARGUMENTS ===
training_args = SentenceTransformerTrainingArguments(
output_dir=config.FT_OUTPUT_DIR,
num_train_epochs=1,
per_device_eval_batch_size=64,
per_device_train_batch_size=config.FT_CONCEPTUAL_BS,
learning_rate=config.FT_LR,
bf16=True,
report_to="wandb",
run_name=run_name,
save_total_limit=2,
logging_steps=config.LOGGING_STEPS,
eval_strategy=config.EVAL_STRATEGY,
save_strategy=config.EVAL_STRATEGY,
eval_steps=config.FT_EVAL_STEPS,
save_steps=config.FT_SAVE_STEPS,
load_best_model_at_end=True,
warmup_ratio=config.FT_WARMUP,
weight_decay=config.FT_WEIGHT_DECAY,
        # The MS MARCO triplet evaluator is disabled above, so select on NQ triplet dev accuracy instead
        metric_for_best_model="eval_nq-triplet-dev_cosine_accuracy",
        greater_is_better=True,
        # Draw each batch from a single dataset, sampling datasets in proportion to their size
multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
ddp_find_unused_parameters=False,
)
# === INITIALIZE AND RUN TRAINER ===
trainer = SentenceTransformerTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
loss=loss,
evaluator=main_evaluator,
)
logging.info("--- Running initial evaluation on the untrained model (step 0) ---")
trainer.evaluate()
logging.info(f"Starting model fine-tuning for run: {run_name}")
trainer.train()
logging.info(f"--- Fine-tuning complete for {run_name}! ---")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="A flexible fine-tuning script for Sentence Transformer models."
)
parser.add_argument(
"--fine_tune_model_path",
type=str,
required=True,
help="Path or Hub name of the model to fine-tune.",
)
parser.add_argument(
"--probe_run", action="store_true", help="If set, runs a short probe run."
)
cli_args = parser.parse_args()
main(cli_args)
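# Example invocation (script name and model path are illustrative):
#   python finetune_saga_embed.py --fine_tune_model_path /path/to/base-model --probe_run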