File size: 7,241 Bytes
0c6d13f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
from datasets import load_dataset, Dataset
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import (
MultipleNegativesRankingLoss,
OnlineContrastiveLoss,
CoSENTLoss,
GISTEmbedLoss,
TripletLoss,
)
import pandas as pd
class EmbeddingFinetuner:
    """
    Finetune a SentenceTransformer model with a configurable loss function.

    Training data is loaded from "data/emb_data.xlsx"; the spreadsheet's
    columns must match the layout required by the selected loss:

    - MultipleNegativesRankingLoss / TripletLoss:
        | anchor | positive | negative |
        'anchor' is the sentence to embed, 'positive' is semantically similar
        to it, 'negative' is semantically dissimilar.
    - OnlineContrastiveLoss:
        | sentence1 | sentence2 | label |
        'label' is 1 for similar pairs, 0 for dissimilar pairs.
    - CoSENTLoss:
        | sentence1 | sentence2 | score |
        'score' is a float similarity (e.g. 0 to 1).
    - GISTEmbedLoss:
        | anchor | positive | negative |  (triplets) or
        | anchor | positive |             (pairs)
    """

    # Single source of truth for the losses accepted by both _load_data
    # and _get_loss_function.
    SUPPORTED_LOSSES = (
        "MultipleNegativesRankingLoss",
        "OnlineContrastiveLoss",
        "CoSENTLoss",
        "GISTEmbedLoss",
        "TripletLoss",
    )

    # Location of the training spreadsheet, relative to the working directory.
    DATA_PATH = "data/emb_data.xlsx"

    def __init__(
        self,
        model_name="microsoft/mpnet-base",
        loss_function="MultipleNegativesRankingLoss",
        epochs=1,
        batch_size=16,
        test_size=0.1,
    ):
        """
        Initialize the finetuner: load the base model, read and split the
        data, and build the loss instance.

        Args:
            model_name (str): Name of the SentenceTransformer model to use.
            loss_function (str): One of SUPPORTED_LOSSES.
            epochs (int): Number of training epochs.
            batch_size (int): Batch size for training and evaluation.
            test_size (float): Fraction of the data held out as the dev
                (evaluation) split; must be in (0, 1). NOTE: despite the
                name, no separate *test* split is created —
                ``self.test_dataset`` is always ``None``.

        Raises:
            ValueError: If ``loss_function`` is not supported.
        """
        self.model_name = model_name
        self.loss_function = loss_function
        self.epochs = epochs
        self.batch_size = batch_size
        self.test_size = test_size
        self.model = SentenceTransformer(self.model_name)
        self.train_dataset, self.dev_dataset, self.test_dataset = self._load_data()
        self.loss = self._get_loss_function()

    def _load_data(self):
        """
        Load the spreadsheet and split it into train/dev sets.

        The required column layout per loss function is documented in the
        class docstring; all supported losses consume the dataframe as-is,
        so loading is identical for every loss.

        Returns:
            tuple: ``(train_dataset, dev_dataset, None)``. The third element
            is a placeholder for a test split, which is not created.

        Raises:
            ValueError: If ``self.loss_function`` is not supported. Validated
            before reading the file so a bad configuration fails fast.
        """
        if self.loss_function not in self.SUPPORTED_LOSSES:
            raise ValueError(f"Unsupported loss function: {self.loss_function}")

        # Was an f-string with no placeholders; a plain constant is equivalent.
        df = pd.read_excel(self.DATA_PATH)
        dataset = Dataset.from_pandas(df)

        # Hold out `test_size` of the rows as the dev (evaluation) split.
        splits = dataset.train_test_split(test_size=self.test_size)
        return splits["train"], splits["test"], None

    def _get_loss_function(self):
        """
        Instantiate the loss object for ``self.loss_function``.

        Returns:
            A sentence_transformers loss instance bound to ``self.model``.

        Raises:
            ValueError: If ``self.loss_function`` is not supported.
        """
        if self.loss_function == "MultipleNegativesRankingLoss":
            return MultipleNegativesRankingLoss(self.model)
        if self.loss_function == "OnlineContrastiveLoss":
            return OnlineContrastiveLoss(self.model)
        if self.loss_function == "CoSENTLoss":
            return CoSENTLoss(self.model)
        if self.loss_function == "GISTEmbedLoss":
            # GISTEmbedLoss requires a second "guide" model; a small, fast
            # model is typical here. Swap in another model if desired.
            guide_model = SentenceTransformer("all-MiniLM-L6-v2")
            return GISTEmbedLoss(self.model, guide_model)
        if self.loss_function == "TripletLoss":
            return TripletLoss(self.model)
        raise ValueError(f"Unsupported loss function: {self.loss_function}")

    def train(self):
        """
        Train the model with the configured loss and save it to disk.

        Trainer checkpoints go to ``models/{model_name}-{loss_function}``;
        the final weights are saved separately to
        ``models/emb-{model_name}-{loss_function}``. NOTE(review): model
        names containing "/" (e.g. "microsoft/mpnet-base") produce nested
        directories under ``models/`` — confirm this is intended.

        Returns:
            bool: True on completion.
        """
        args = SentenceTransformerTrainingArguments(
            output_dir=f"models/{self.model_name}-{self.loss_function}",
            num_train_epochs=self.epochs,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            # Evaluate on the dev split once per epoch. NOTE(review): newer
            # transformers releases renamed this argument to `eval_strategy`;
            # confirm against the pinned transformers version.
            evaluation_strategy="epoch",
            # ... other training arguments as needed ...
        )
        trainer = SentenceTransformerTrainer(
            model=self.model,
            args=args,
            train_dataset=self.train_dataset,
            eval_dataset=self.dev_dataset,
            loss=self.loss,
        )
        trainer.train()
        # Persist the finetuned weights separately from the checkpoints.
        self.model.save_pretrained(
            f"models/emb-{self.model_name}-{self.loss_function}"
        )
        return True
|