File size: 2,467 Bytes
d65bc08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from sentence_transformers import SentenceTransformer, losses, InputExample, evaluation
from torch.utils.data import DataLoader

def main():
    """Fine-tune a SentenceTransformer on a toy triplet set and save it.

    Loads the pretrained ``all-MiniLM-L6-v2`` checkpoint, trains it for
    ``num_epochs`` epoch(s) with ``losses.TripletLoss`` on five hand-written
    (anchor, positive, negative) sentence triplets, evaluates with an
    ``evaluation.TripletEvaluator`` built from a separate three-triplet
    validation set, and passes ``output_path`` to ``model.fit`` so the
    fine-tuned model is written there.
    """
    model_name = 'all-MiniLM-L6-v2'
    # Single source of truth for where the model goes; reused in the final
    # status message so the two can never disagree.
    output_path = 'fine_tuned_sbert_triplet'
    model = SentenceTransformer(model_name)

    # Training triplets (anchor, positive, negative)
    train_examples = [
        InputExample(texts=["A man is playing a guitar", "A person is playing a guitar", "A woman is reading a book"]),
        InputExample(texts=["A dog is running in the park", "A dog runs in the park", "A cat is sleeping on the couch"]),
        InputExample(texts=["The car is fast", "A fast car", "A slow bicycle"]),
        InputExample(texts=["He is eating pizza", "He eats pizza", "She is drinking water"]),
        InputExample(texts=["The kids are playing outside", "Children are playing outdoors", "Adults are working inside"]),
    ]
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
    train_loss = losses.TripletLoss(model=model)

    # Validation triplets (anchor, positive, negative). The evaluator's three
    # columns are derived from this one list so they cannot drift out of sync
    # with it (previously they were a hand-copied duplicate).
    val_examples = [
        InputExample(texts=["A woman is cooking food", "A lady is preparing dinner", "A man is playing football"]),
        InputExample(texts=["A boy is riding a bike", "A child rides a bicycle", "An old man is reading a newspaper"]),
        InputExample(texts=["The plane is flying", "An aircraft is in the air", "A boat is sailing on the water"]),
    ]
    anchors = [example.texts[0] for example in val_examples]
    positives = [example.texts[1] for example in val_examples]
    negatives = [example.texts[2] for example in val_examples]

    # Create a TripletEvaluator
    evaluator = evaluation.TripletEvaluator(
        anchors=anchors,
        positives=positives,
        negatives=negatives,
        name='val-triplet-eval',
        show_progress_bar=True,
    )

    num_epochs = 1
    # Warm up over 10% of all training steps. With 5 examples and
    # batch_size=16 there is a single step per epoch, so this is 0 here.
    warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

    # Train with evaluator
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        evaluator=evaluator,
        # Evaluate every 10 training steps. An epoch here is only one step,
        # so this never triggers mid-epoch; evaluation presumably happens at
        # epoch end instead (library behaviour, confirm for your version).
        evaluation_steps=10,
        output_path=output_path,
        show_progress_bar=True,
    )

    print(f"Model fine-tuned and saved at '{output_path}'")


if __name__ == "__main__":
    main()