File size: 4,583 Bytes
d65bc08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import json
import random
from sentence_transformers import SentenceTransformer, losses, InputExample, evaluation
from torch.utils.data import DataLoader
from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn

# Module-level Rich console; main() prints its status messages and the
# final summary table through it.
console = Console()

def load_triplets_from_jsonl(file_path):
    """Load (anchor, positive, negative) triplets from a JSON Lines file.

    Every non-blank line is parsed as a JSON object and must provide the keys
    ``anchor``, ``positive`` and ``negative``. Blank or whitespace-only lines
    are skipped so that stray empty lines (for example extra newlines at the
    end of the file) do not abort the whole load.

    Args:
        file_path: Path of the UTF-8 encoded JSONL file to read.

    Returns:
        A list of ``InputExample`` objects, in file order, each built with
        ``texts=[anchor, positive, negative]``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a parsed object lacks one of the three required keys.
    """
    examples = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # json.loads() rejects empty input, so a blank line would
            # otherwise crash the load with a JSONDecodeError.
            if not line.strip():
                continue
            data = json.loads(line)
            examples.append(
                InputExample(texts=[data['anchor'], data['positive'], data['negative']])
            )
    return examples

def split_data(examples, train_ratio=0.8, seed=42):
    """Shuffle ``examples`` reproducibly and split them into two lists.

    The caller's sequence is left untouched and the module-level ``random``
    generator is not reseeded: shuffling is done on a copy with a private
    ``random.Random(seed)`` instance. For a given ``seed`` the resulting order
    is the same as seeding the global generator and shuffling in place.

    Args:
        examples: Sequence of items to split (any iterable is accepted).
        train_ratio: Fraction of the items placed in the first (training)
            list; the count is ``int(len(examples) * train_ratio)``.
        seed: Seed for the private random generator.

    Returns:
        A ``(train, validation)`` tuple of lists that together contain every
        input item exactly once.
    """
    # Copy first so the caller's list keeps its original order.
    shuffled = list(examples)
    # A dedicated generator avoids clobbering the process-wide random state.
    random.Random(seed).shuffle(shuffled)
    train_size = int(len(shuffled) * train_ratio)
    return shuffled[:train_size], shuffled[train_size:]

def main():
    """Fine-tune a SentenceTransformer on triplets and print a summary.

    Steps performed:
      1. Load triplets from ``triplets.jsonl`` and split them 80/20 into
         training and validation sets.
      2. Train ``all-MiniLM-L6-v2`` with ``TripletLoss`` for one epoch,
         running a triplet evaluation on the validation set every
         ``evaluation_steps`` training steps while a Rich progress bar tracks
         the run.
      3. Print a table summarising the run.

    Raises:
        ValueError: If the split leaves the training or validation set empty.
    """
    model_name = 'all-MiniLM-L6-v2'
    num_epochs = 1
    # Single source of truth for the evaluation interval: it is passed to
    # model.fit() and is also the amount the progress bar advances per
    # evaluation.
    evaluation_steps = 10
    output_path = 'fine_tuned_sbert_triplet'

    # Load all triplets from a single JSONL file
    all_examples = load_triplets_from_jsonl('triplets.jsonl')

    # Split into train and validation
    train_examples, val_examples = split_data(all_examples, train_ratio=0.8)

    # Fail early, before the model is loaded, with a message that names the
    # actual problem instead of an error from deep inside the training code.
    if not train_examples or not val_examples:
        raise ValueError(
            f"Not enough triplets to train: {len(train_examples)} training and "
            f"{len(val_examples)} validation examples after the split."
        )

    model = SentenceTransformer(model_name)

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
    train_loss = losses.TripletLoss(model=model)

    anchors = [ex.texts[0] for ex in val_examples]
    positives = [ex.texts[1] for ex in val_examples]
    negatives = [ex.texts[2] for ex in val_examples]

    total_steps = num_epochs * len(train_dataloader)
    warmup_steps = int(total_steps * 0.1)

    console.print(f"[bold]Starting training with {len(train_examples)} triplets, validating on {len(val_examples)} triplets.[/bold]\n")

    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        transient=True,
    )

    class RichLossEvaluator(evaluation.TripletEvaluator):
        """TripletEvaluator that logs its score and advances the progress bar."""

        def __init__(self, *args, progress_task, **kwargs):
            super().__init__(*args, **kwargs)
            # Id of the Rich task this evaluator advances.
            self.progress_task = progress_task

        def __call__(self, model, output_path=None, epoch=-1, steps=-1, **kwargs):
            score_dict = super().__call__(model, output_path=output_path, epoch=epoch, steps=steps, **kwargs)

            # The parent may return either a dict of metrics or a bare score.
            if isinstance(score_dict, dict):
                if "accuracy" in score_dict:
                    val_score = score_dict["accuracy"]
                else:
                    val_score = next(iter(score_dict.values()))
            else:
                val_score = score_dict

            # steps == -1 is the "no specific step" default, so such a call
            # does not mean another evaluation interval has elapsed and must
            # not push the bar forward.
            # NOTE(review): end-of-epoch evaluations are presumed to arrive
            # with steps == -1 - confirm for the installed library version.
            if steps != -1:
                progress.update(self.progress_task, advance=evaluation_steps)
            progress.console.log(f"Step {steps}: Validation score: {val_score:.4f}")
            return score_dict

    with progress:
        task = progress.add_task("[green]Training...[/green]", total=total_steps)
        rich_evaluator = RichLossEvaluator(
            anchors=anchors,
            positives=positives,
            negatives=negatives,
            name='val-triplet-eval',
            show_progress_bar=False,
            progress_task=task
        )

        model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=num_epochs,
            warmup_steps=warmup_steps,
            evaluator=rich_evaluator,
            evaluation_steps=evaluation_steps,
            output_path=output_path,
            show_progress_bar=False
        )
        # Steps after the last evaluation were never reported; mark the task
        # complete explicitly.
        progress.update(task, completed=total_steps)

    table = Table(title="Training Summary")
    table.add_column("Metric", style="cyan", no_wrap=True)
    table.add_column("Value", style="magenta")

    table.add_row("Model Name", model_name)
    table.add_row("Total Triplets", str(len(all_examples)))
    table.add_row("Training Triplets", str(len(train_examples)))
    table.add_row("Validation Triplets", str(len(val_examples)))
    table.add_row("Epochs", str(num_epochs))
    table.add_row("Output Path", output_path)

    console.print(table)
    console.print("[bold green]Model fine-tuned and saved successfully![/bold green]")

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()