|
|
from sentence_transformers import SentenceTransformer, losses, InputExample, evaluation |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
def main(
    model_name: str = 'all-MiniLM-L6-v2',
    output_path: str = 'fine_tuned_sbert_triplet',
    num_epochs: int = 1,
) -> None:
    """Fine-tune a SentenceTransformer with TripletLoss on a small toy dataset.

    Every example is an (anchor, positive, negative) text triple. The model is
    trained with ``losses.TripletLoss``, evaluated with a ``TripletEvaluator``
    built from the validation triples, and written to ``output_path`` via
    ``model.fit(output_path=...)``.

    Args:
        model_name: Model name or path passed to ``SentenceTransformer``.
        output_path: Directory handed to ``model.fit`` as ``output_path``.
        num_epochs: Number of training epochs.
    """
    model = SentenceTransformer(model_name)

    # Training triples: texts = [anchor, positive, negative].
    train_examples = [
        InputExample(texts=["A man is playing a guitar", "A person is playing a guitar", "A woman is reading a book"]),
        InputExample(texts=["A dog is running in the park", "A dog runs in the park", "A cat is sleeping on the couch"]),
        InputExample(texts=["The car is fast", "A fast car", "A slow bicycle"]),
        InputExample(texts=["He is eating pizza", "He eats pizza", "She is drinking water"]),
        InputExample(texts=["The kids are playing outside", "Children are playing outdoors", "Adults are working inside"]),
    ]
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
    train_loss = losses.TripletLoss(model=model)

    # Validation triples, same [anchor, positive, negative] layout. They are the
    # single source of truth for the evaluator below (previously the evaluator
    # duplicated these strings by hand and an unused DataLoader was built).
    val_examples = [
        InputExample(texts=["A woman is cooking food", "A lady is preparing dinner", "A man is playing football"]),
        InputExample(texts=["A boy is riding a bike", "A child rides a bicycle", "An old man is reading a newspaper"]),
        InputExample(texts=["The plane is flying", "An aircraft is in the air", "A boat is sailing on the water"]),
    ]

    evaluator = evaluation.TripletEvaluator(
        anchors=[example.texts[0] for example in val_examples],
        positives=[example.texts[1] for example in val_examples],
        negatives=[example.texts[2] for example in val_examples],
        name='val-triplet-eval',
        show_progress_bar=True,
    )

    # 10% of total training steps as warmup. With 5 examples and batch_size=16
    # there is a single batch per epoch, so this evaluates to 0 and
    # evaluation_steps=10 is never reached mid-epoch. Presumably the evaluator
    # then only runs at the end of each epoch -- confirm against the fit() docs.
    warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        evaluator=evaluator,
        evaluation_steps=10,
        output_path=output_path,
        show_progress_bar=True,
    )

    print(f"Model fine-tuned and saved at '{output_path}'")
|
|
|
|
|
if __name__ == "__main__":
    # Only launch fine-tuning when executed as a script, never on import.
    main()
|
|
|
|
|
|