"""Fine-tune a SentenceTransformer on a toy set of triplets.

Each triplet is (anchor, positive, negative). The model is trained with
TripletLoss, a held-out set of triplets is fed to a TripletEvaluator during
training, and ``output_path`` is handed to ``model.fit``.
"""
from sentence_transformers import SentenceTransformer, losses, InputExample, evaluation
from torch.utils.data import DataLoader

# Training triplets: (anchor, positive, negative).
_TRAIN_TRIPLETS = [
    ("A man is playing a guitar", "A person is playing a guitar", "A woman is reading a book"),
    ("A dog is running in the park", "A dog runs in the park", "A cat is sleeping on the couch"),
    ("The car is fast", "A fast car", "A slow bicycle"),
    ("He is eating pizza", "He eats pizza", "She is drinking water"),
    ("The kids are playing outside", "Children are playing outdoors", "Adults are working inside"),
]

# Held-out triplets: (anchor, positive, negative). Used only by the evaluator.
_VAL_TRIPLETS = [
    ("A woman is cooking food", "A lady is preparing dinner", "A man is playing football"),
    ("A boy is riding a bike", "A child rides a bicycle", "An old man is reading a newspaper"),
    ("The plane is flying", "An aircraft is in the air", "A boat is sailing on the water"),
]


def _to_examples(triplets):
    """Wrap each (anchor, positive, negative) tuple in an ``InputExample``."""
    return [InputExample(texts=list(triplet)) for triplet in triplets]


def _build_evaluator(triplets, name='val-triplet-eval'):
    """Build a ``TripletEvaluator`` from (anchor, positive, negative) tuples.

    The three parallel lists the evaluator expects are derived from a single
    triplet list, so anchors, positives and negatives cannot drift out of
    alignment.
    """
    anchors = [anchor for anchor, _, _ in triplets]
    positives = [positive for _, positive, _ in triplets]
    negatives = [negative for _, _, negative in triplets]
    return evaluation.TripletEvaluator(
        anchors=anchors,
        positives=positives,
        negatives=negatives,
        name=name,
        show_progress_bar=True,
    )


def main(*, model_name='all-MiniLM-L6-v2', output_path='fine_tuned_sbert_triplet',
         num_epochs=1, batch_size=16):
    """Fine-tune ``model_name`` on the toy triplets and report where it was written.

    All parameters are keyword-only. Their defaults reproduce the original
    hard-coded configuration, so a bare ``main()`` behaves as before.

    Args:
        model_name: Name passed to ``SentenceTransformer`` to load the base model.
        output_path: Directory passed to ``model.fit(output_path=...)``.
            NOTE(review): when an evaluator is supplied, fit() may save only
            when the evaluator score improves. Confirm this for the installed
            sentence-transformers version.
        num_epochs: Number of training epochs.
        batch_size: Training batch size.
    """
    model = SentenceTransformer(model_name)

    train_dataloader = DataLoader(_to_examples(_TRAIN_TRIPLETS), shuffle=True, batch_size=batch_size)
    train_loss = losses.TripletLoss(model=model)

    evaluator = _build_evaluator(_VAL_TRIPLETS)

    # Warm up over the first 10% of all training steps. int() truncates on
    # purpose. With 5 examples and batch_size=16 there is a single step per
    # epoch, so warmup is 0. Rounding up would put that only step at the very
    # start of warmup (presumably a ~0 learning rate — confirm against the
    # scheduler in use).
    warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        evaluator=evaluator,
        # Evaluate every 10 training steps. With the default toy data there is
        # only one step per epoch, so this interval is never reached mid-epoch.
        # Presumably fit() also evaluates at each epoch end — confirm.
        evaluation_steps=10,
        output_path=output_path,
        show_progress_bar=True,
    )
    print(f"Model fine-tuned and saved at '{output_path}'")


if __name__ == "__main__":
    main()