File size: 2,467 Bytes
d65bc08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from sentence_transformers import SentenceTransformer, losses, InputExample, evaluation
from torch.utils.data import DataLoader
def main() -> None:
    """Fine-tune a SentenceTransformer on a handful of triplets and save it.

    Loads the pretrained ``all-MiniLM-L6-v2`` checkpoint, trains it for one
    epoch with ``TripletLoss`` on five (anchor, positive, negative) examples,
    evaluates with a ``TripletEvaluator`` built from three held-out triplets,
    and writes the result to ``fine_tuned_sbert_triplet``.
    """
    model_name = 'all-MiniLM-L6-v2'
    model = SentenceTransformer(model_name)

    # Training triplets (anchor, positive, negative)
    train_examples = [
        InputExample(texts=["A man is playing a guitar", "A person is playing a guitar", "A woman is reading a book"]),
        InputExample(texts=["A dog is running in the park", "A dog runs in the park", "A cat is sleeping on the couch"]),
        InputExample(texts=["The car is fast", "A fast car", "A slow bicycle"]),
        InputExample(texts=["He is eating pizza", "He eats pizza", "She is drinking water"]),
        InputExample(texts=["The kids are playing outside", "Children are playing outdoors", "Adults are working inside"]),
    ]
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
    train_loss = losses.TripletLoss(model=model)

    # Validation triplets (anchor, positive, negative). Kept in one place so
    # the three parallel lists handed to the evaluator cannot drift apart.
    val_triplets = [
        ("A woman is cooking food", "A lady is preparing dinner", "A man is playing football"),
        ("A boy is riding a bike", "A child rides a bicycle", "An old man is reading a newspaper"),
        ("The plane is flying", "An aircraft is in the air", "A boat is sailing on the water"),
    ]
    # Transpose the rows into the column lists the evaluator takes.
    anchors, positives, negatives = (list(column) for column in zip(*val_triplets))

    # Create a TripletEvaluator
    evaluator = evaluation.TripletEvaluator(
        anchors=anchors,
        positives=positives,
        negatives=negatives,
        name='val-triplet-eval',
        show_progress_bar=True,
    )

    num_epochs = 1
    # Warm up over the first 10% of all training steps (batches * epochs).
    warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

    # Train with evaluator
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        evaluator=evaluator,
        # NOTE(review): with 5 examples and batch_size=16 there is a single
        # step per epoch, so this 10-step interval is presumably never
        # reached mid-epoch -- confirm against the library's fit() behaviour.
        evaluation_steps=10,  # evaluate every 10 steps
        output_path='fine_tuned_sbert_triplet',
        show_progress_bar=True,
    )
    print("Model fine-tuned and saved at 'fine_tuned_sbert_triplet'")
# Run the fine-tuning only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|