from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import (
    MultipleNegativesRankingLoss,
    OnlineContrastiveLoss,
    CoSENTLoss,
    GISTEmbedLoss,
    TripletLoss,
)
import pandas as pd


class EmbeddingFinetuner:
    """

    A class for finetuning SentenceTransformer models on various loss functions.



    Supports the following loss functions:

        - MultipleNegativesRankingLoss

        - OnlineContrastiveLoss

        - CoSENTLoss

        - GISTEmbedLoss

        - TripletLoss



    Loads data from an xlsx file named "emb_data.xlsx".

    """

    def __init__(
        self,
        model_name="microsoft/mpnet-base",
        loss_function="MultipleNegativesRankingLoss",
        epochs=1,
        batch_size=16,
        test_size=0.1,
    ):
        """

        Initializes the EmbeddingFinetuner.



        Args:

            model_name (str): Name of the SentenceTransformer model to use.

            loss_function (str): Name of the loss function to use.

            epochs (int): Number of training epochs.

            batch_size (int): Batch size for training.

            test_size (float): Proportion of the dataset to include in the test split. 

                              If less than 1, no test set is created.

        """
        self.model_name = model_name
        self.loss_function = loss_function
        self.epochs = epochs
        self.batch_size = batch_size
        self.test_size = test_size

        self.model = SentenceTransformer(self.model_name)
        self.train_dataset, self.dev_dataset, self.test_dataset = self._load_data()
        self.loss = self._get_loss_function()

    def _load_data(self):
        """

        Loads data from "emb_data.xlsx" and prepares it for the selected loss function.

        """
        df = pd.read_excel(f"data/emb_data.xlsx")

        if self.loss_function == "MultipleNegativesRankingLoss":
            """
            Expects data in the format:
            | anchor | positive | negative |
            |---|---|---|
            | sentence1 | sentence2 | sentence3 |
            | ... | ... | ... |

            Where 'anchor' is the sentence to be embedded, 'positive' is a
            sentence semantically similar to the anchor, and 'negative' is a
            sentence semantically dissimilar to the anchor.
            """
            dataset = Dataset.from_pandas(df)

        elif self.loss_function == "OnlineContrastiveLoss":
            """

            Expects data in the format:

            | sentence1 | sentence2 | label |

            |---|---|---|

            | sentenceA | sentenceB | 1 | 

            | sentenceC | sentenceD | 0 |

            | ... | ... | ... |



            Where 'sentence1' and 'sentence2' are pairs of sentences, and 'label' 

            indicates whether they are semantically similar (1) or dissimilar (0).

            """
            dataset = Dataset.from_pandas(df)

        elif self.loss_function == "CoSENTLoss":
            """

            Expects data in the format:

            | sentence1 | sentence2 | score |

            |---|---|---|

            | sentenceA | sentenceB | 0.8 |

            | sentenceC | sentenceD | 0.2 |

            | ... | ... | ... |



            Where 'sentence1' and 'sentence2' are pairs of sentences, and 'score' 

            is a float value representing their similarity (e.g., from 0 to 1).

            """
            dataset = Dataset.from_pandas(df)

        elif self.loss_function == "GISTEmbedLoss":
            """

            Expects data in either of the following formats:



            Triplets:

            | anchor | positive | negative |

            |---|---|---|

            | sentence1 | sentence2 | sentence3 |

            | ... | ... | ... |



            Pairs:

            | anchor | positive |

            |---|---|

            | sentence1 | sentence2 |

            | ... | ... |



            Where 'anchor' is the sentence to be embedded, 'positive' is a sentence 

            semantically similar to the anchor, and 'negative' (if present) is a 

            sentence semantically dissimilar to the anchor.

            """
            dataset = Dataset.from_pandas(df)

        elif self.loss_function == "TripletLoss":
            """

            Expects data in the format:

            | anchor | positive | negative |

            |---|---|---|

            | sentence1 | sentence2 | sentence3 |

            | ... | ... | ... |



            Where 'anchor' is the sentence to be embedded, 'positive' is a sentence 

            semantically similar to the anchor, and 'negative' is a sentence 

            semantically dissimilar to the anchor.

            """
            dataset = Dataset.from_pandas(df)

        else:
            raise ValueError(f"Unsupported loss function: {self.loss_function}")

        # Split into train and dev; no separate test split is created
        train_dev_dataset = dataset.train_test_split(test_size=self.test_size)
        train_dataset = train_dev_dataset["train"]
        dev_dataset = train_dev_dataset["test"]
        test_dataset = None

        return train_dataset, dev_dataset, test_dataset

    def _get_loss_function(self):
        """

        Returns the selected loss function instance.

        """
        if self.loss_function == "MultipleNegativesRankingLoss":
            return MultipleNegativesRankingLoss(self.model)
        elif self.loss_function == "OnlineContrastiveLoss":
            return OnlineContrastiveLoss(self.model)
        elif self.loss_function == "CoSENTLoss":
            return CoSENTLoss(self.model)
        elif self.loss_function == "GISTEmbedLoss":
            guide_model = SentenceTransformer("all-MiniLM-L6-v2")  # You can change this
            return GISTEmbedLoss(self.model, guide_model)
        elif self.loss_function == "TripletLoss":
            return TripletLoss(self.model)
        else:
            raise ValueError(f"Unsupported loss function: {self.loss_function}")

    def train(self):
        """

        Trains the SentenceTransformer model using the specified loss function.

        """
        # "/" in model names (e.g. "microsoft/mpnet-base") would create
        # nested output directories, so flatten it for the run name
        run_name = f"{self.model_name.replace('/', '-')}-{self.loss_function}"

        args = SentenceTransformerTrainingArguments(
            output_dir=f"models/{run_name}",
            num_train_epochs=self.epochs,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            evaluation_strategy="epoch",  # renamed to "eval_strategy" in newer transformers releases
            # ... other training arguments as needed ...
        )

        trainer = SentenceTransformerTrainer(
            model=self.model,
            args=args,
            train_dataset=self.train_dataset,
            eval_dataset=self.dev_dataset,
            loss=self.loss,
        )
        trainer.train()

        # Save the trained model
        self.model.save_pretrained(f"models/emb-{run_name}")

        return True
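

def _make_sample_dataframe(loss_function):
    # Hypothetical helper, not part of the original pipeline: builds a tiny
    # in-memory DataFrame in the format each loss expects (see the per-loss
    # docstrings in _load_data), handy for writing a throwaway
    # data/emb_data.xlsx when smoke-testing this module.
    if loss_function in ("MultipleNegativesRankingLoss", "TripletLoss", "GISTEmbedLoss"):
        return pd.DataFrame(
            {
                "anchor": [
                    "How do I reset my password?",
                    "Where can I download my invoice?",
                    "The app crashes on startup.",
                ],
                "positive": [
                    "Steps to change your account password",
                    "Invoices are available under Billing",
                    "Application fails to launch",
                ],
                "negative": [
                    "Today's weather forecast",
                    "Best hiking trails nearby",
                    "A recipe for banana bread",
                ],
            }
        )
    if loss_function == "OnlineContrastiveLoss":
        return pd.DataFrame(
            {
                "sentence1": ["How do I reset my password?"] * 3,
                "sentence2": [
                    "Steps to change your account password",
                    "Today's weather forecast",
                    "Password recovery instructions",
                ],
                "label": [1, 0, 1],
            }
        )
    if loss_function == "CoSENTLoss":
        return pd.DataFrame(
            {
                "sentence1": ["How do I reset my password?"] * 3,
                "sentence2": [
                    "Steps to change your account password",
                    "Today's weather forecast",
                    "Password recovery instructions",
                ],
                "score": [0.9, 0.1, 0.8],
            }
        )
    raise ValueError(f"Unsupported loss function: {loss_function}")


if __name__ == "__main__":
    # Minimal usage sketch. Assumes data/emb_data.xlsx already exists with
    # columns matching the chosen loss function; the helper above can write
    # a tiny stand-in for a quick smoke test, e.g.:
    #   _make_sample_dataframe("MultipleNegativesRankingLoss").to_excel(
    #       "data/emb_data.xlsx", index=False
    #   )
    finetuner = EmbeddingFinetuner(
        model_name="microsoft/mpnet-base",
        loss_function="MultipleNegativesRankingLoss",
        epochs=1,
        batch_size=16,
        test_size=0.1,
    )
    finetuner.train()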