File size: 1,913 Bytes
d65bc08
a5b24c9
 
 
d65bc08
 
 
 
 
 
 
 
 
 
a5b24c9
 
 
 
d65bc08
 
a5b24c9
 
 
d65bc08
 
a5b24c9
 
d65bc08
 
 
 
 
a5b24c9
d65bc08
 
 
a5b24c9
 
 
 
 
 
 
d65bc08
a5b24c9
 
 
 
 
d65bc08
a5b24c9
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json
from sentence_transformers import SentenceTransformer, losses, InputExample, evaluation
from torch.utils.data import DataLoader

def load_triplets_from_jsonl(file_path):
    """Load (anchor, positive, negative) triplets from a JSON Lines file.

    Each non-blank line must be a JSON object containing the keys
    ``anchor``, ``positive`` and ``negative``; any other keys are ignored.
    Blank / whitespace-only lines are skipped.

    Args:
        file_path: Path to the UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list of ``InputExample`` objects, in file order, whose ``texts``
        are ``[anchor, positive, negative]``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON; the
            message is prefixed with the file path and 1-based line number.
        KeyError: If a record lacks one of the required keys; the message
            names the file, the line number and the missing key.
    """
    examples = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            # Stray empty lines (e.g. an extra blank line at EOF) are common
            # in hand-edited JSONL; json.loads('') would raise on them.
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError as err:
                # Keep the same exception type so existing handlers still
                # match, but say *where* the bad record is.
                raise json.JSONDecodeError(
                    f"{file_path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            try:
                texts = [data['anchor'], data['positive'], data['negative']]
            except KeyError as err:
                raise KeyError(
                    f"{file_path}:{lineno}: missing required field {err}"
                ) from err
            examples.append(InputExample(texts=texts))
    return examples

def main(
    *,
    model_name='all-MiniLM-L6-v2',
    train_path='train_triplets.jsonl',
    val_path='val_triplets.jsonl',
    output_path='fine_tuned_sbert_triplet',
    batch_size=16,
    num_epochs=1,
    warmup_ratio=0.1,
    evaluation_steps=10,
):
    """Fine-tune a SentenceTransformer with TripletLoss and save it.

    All parameters are keyword-only and default to the values this script
    has always used, so a bare ``main()`` call behaves as before.

    Args:
        model_name: Name or path passed to ``SentenceTransformer``.
        train_path: JSONL file of training triplets.
        val_path: JSONL file of validation triplets used by the evaluator.
            If it contains no triplets, training runs without evaluation.
        output_path: Directory passed to ``model.fit`` as ``output_path``.
        batch_size: Training batch size.
        num_epochs: Number of training epochs.
        warmup_ratio: Fraction of total training steps used as warmup.
        evaluation_steps: Passed to ``model.fit``; how often (in steps) the
            validation evaluator is run during training.

    Raises:
        ValueError: If ``train_path`` contains no triplets.
    """
    # Load and validate data before loading the model, so a bad or empty
    # file fails fast instead of after an expensive model load.
    train_examples = load_triplets_from_jsonl(train_path)
    if not train_examples:
        raise ValueError(f"No training triplets found in {train_path!r}")
    val_examples = load_triplets_from_jsonl(val_path)

    model = SentenceTransformer(model_name)

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
    train_loss = losses.TripletLoss(model=model)

    evaluator = None
    if val_examples:
        # The evaluator takes three parallel lists rather than InputExamples.
        anchors = [ex.texts[0] for ex in val_examples]
        positives = [ex.texts[1] for ex in val_examples]
        negatives = [ex.texts[2] for ex in val_examples]
        evaluator = evaluation.TripletEvaluator(
            anchors=anchors,
            positives=positives,
            negatives=negatives,
            name='val-triplet-eval',
            show_progress_bar=True
        )
    else:
        # An evaluator over zero triplets is not meaningful; train without it.
        print(f"No validation triplets found in {val_path!r}; skipping evaluation")

    warmup_steps = int(len(train_dataloader) * num_epochs * warmup_ratio)

    # NOTE(review): newer sentence-transformers releases steer users toward
    # SentenceTransformerTrainer; confirm the installed version still
    # supports fit() with these arguments.
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        evaluator=evaluator,
        evaluation_steps=evaluation_steps,
        output_path=output_path,
        show_progress_bar=True
    )

    print(f"Model fine-tuned and saved at '{output_path}'")

if __name__ == "__main__":
    # Run fine-tuning only when executed as a script, not when imported.
    main()