File size: 6,225 Bytes
d65bc08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import json
import random
import logging

import wandb
from sentence_transformers import SentenceTransformer, InputExample, evaluation
from sentence_transformers.losses import TripletLoss
from torch.utils.data import DataLoader

from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn

# ——————— Setup logging & console ———————
# Root logger at INFO level; every record is prefixed with its timestamp.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
# Shared Rich console for status lines and the final summary table.
console = Console()

# ——————— Data loading & splitting ———————
def load_triplets_from_jsonl(path):
    """Read (anchor, positive, negative) triplets from a JSON Lines file.

    Each non-blank line must be a JSON object with ``anchor``, ``positive``
    and ``negative`` keys. Blank or whitespace-only lines are ignored.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list of ``InputExample`` objects, in file order, whose ``texts``
        are ``[anchor, positive, negative]``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the three required keys.
    """
    examples = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            # An empty line (e.g. a trailing blank line) is not valid JSON;
            # skip it instead of aborting the whole load.
            if not line.strip():
                continue
            record = json.loads(line)
            examples.append(
                InputExample(texts=[record["anchor"], record["positive"], record["negative"]])
            )
    return examples

def split_data(examples, train_ratio=0.8, seed=42):
    """Shuffle *examples* reproducibly and split them into train/validation lists.

    A private ``random.Random(seed)`` is used, so neither the caller's list
    nor the global ``random`` module state is modified. The resulting order
    is the same one ``random.seed(seed); random.shuffle(...)`` would produce.

    Args:
        examples: Sequence of items to split.
        train_ratio: Fraction in [0, 1] of items placed in the training split
            (the split index is rounded down).
        seed: Seed controlling the shuffle.

    Returns:
        ``(train, val)``: the first ``int(len(examples) * train_ratio)``
        shuffled items, and the remaining items.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1] (or NaN).
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio!r}")
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    split = int(len(shuffled) * train_ratio)
    return shuffled[:split], shuffled[split:]

# ——————— Custom loss that logs training loss to W&B ———————
class WandbLoggingTripletLoss(TripletLoss):
    """``TripletLoss`` that reports every training step to W&B and a Rich bar.

    Each ``forward`` call computes the ordinary triplet loss, logs its scalar
    value as ``train_loss`` alongside ``train_step``, advances the given Rich
    progress task by one, and then increments ``self.step``.

    Args:
        model: The model being trained, passed straight to ``TripletLoss``.
        progress_task: Id of the task inside *progress* to advance per step.
        progress: A ``rich.progress.Progress`` instance.
    """

    def __init__(self, model, progress_task, progress):
        super().__init__(model)
        # Count of forward passes so far; reported as ``train_step``.
        self.step = 0
        # Rich progress handle and the id of the task it should advance.
        self.progress, self.progress_task = progress, progress_task

    def forward(self, sentence_features, labels):
        """Compute the batch loss, log it, tick the progress bar, return it.

        The training loop calls this as ``forward(sentence_features, labels)``.
        """
        batch_loss = super().forward(sentence_features, labels)
        log_entry = {"train_step": self.step, "train_loss": batch_loss.item()}
        wandb.log(log_entry)
        self.progress.update(self.progress_task, advance=1)
        self.step += 1
        return batch_loss

# ——————— Custom evaluator that logs validation metrics to W&B ———————
class RichLossEvaluator(evaluation.TripletEvaluator):
    """``TripletEvaluator`` that also reports a validation score to W&B and Rich.

    Args:
        anchors: Validation anchor texts, forwarded to ``TripletEvaluator``.
        positives: Validation positive texts (parallel to *anchors*).
        negatives: Validation negative texts (parallel to *anchors*).
        name: Evaluator name, forwarded to ``TripletEvaluator``.
        progress_task: Id of the training task in *progress*. Kept for
            reference; this evaluator does not advance the bar.
        progress: ``rich.progress.Progress`` whose console receives log lines.
    """

    def __init__(self, anchors, positives, negatives, name, progress_task, progress):
        super().__init__(
            anchors=anchors, positives=positives, negatives=negatives,
            name=name, show_progress_bar=False
        )
        self.progress_task = progress_task
        self.progress = progress

    @staticmethod
    def _select_score(metrics):
        """Reduce the parent evaluator's result to a single value.

        Non-dict results are returned unchanged. For a dict, prefer the exact
        key ``"accuracy"``, then the first key ending in ``"accuracy"`` (some
        sentence-transformers versions prefix metric names, e.g.
        ``"<name>_cosine_accuracy"``), then the first value. An empty dict
        yields ``None`` instead of raising StopIteration.
        """
        if not isinstance(metrics, dict):
            return metrics
        if "accuracy" in metrics:
            return metrics["accuracy"]
        for key, value in metrics.items():
            if str(key).endswith("accuracy"):
                return value
        return next(iter(metrics.values()), None)

    def __call__(self, model, output_path=None, epoch=-1, steps=-1, **kwargs):
        """Run the triplet evaluation, log one score, and return the raw metrics."""
        metrics = super().__call__(model, output_path=output_path, epoch=epoch, steps=steps, **kwargs)
        val_score = self._select_score(metrics)
        wandb.log({
            "validation_step": steps,
            "validation_score": val_score,
            "epoch": epoch
        })
        # Intentionally no progress.update() here: in this script the training
        # loss already advances the shared bar once per step, so advancing it
        # again on each evaluation double-counted and overshot 100%.
        try:
            score_text = f"{val_score:.4f}"
        except (TypeError, ValueError):
            # e.g. None from an empty metrics dict; don't crash training on a log line.
            score_text = str(val_score)
        self.progress.console.log(f"[blue]Step {steps}[/blue]: Validation score: {score_text}")
        return metrics

def _upload_weights_to_wandb(output_path):
    """Upload the saved model weights file(s) found in *output_path* to W&B."""
    # The weights file name depends on how the model was serialized
    # (safetensors vs. legacy PyTorch pickle); upload whichever exists
    # instead of assuming pytorch_model.bin.
    uploaded = False
    for weights_name in ("model.safetensors", "pytorch_model.bin"):
        weights_path = os.path.join(output_path, weights_name)
        if os.path.isfile(weights_path):
            wandb.save(weights_path)
            uploaded = True
    if not uploaded:
        logging.warning("No model weights file found in %s; nothing uploaded to W&B", output_path)


def _print_training_summary(model_name, num_total, num_train, num_val, num_epochs, output_path):
    """Print a Rich table summarizing the training run."""
    summary = Table(title="Training Summary")
    summary.add_column("Metric", style="cyan", no_wrap=True)
    summary.add_column("Value", style="magenta")

    summary.add_row("Model Name",          model_name)
    summary.add_row("Total Triplets",      str(num_total))
    summary.add_row("Training Triplets",   str(num_train))
    summary.add_row("Validation Triplets", str(num_val))
    summary.add_row("Epochs",              str(num_epochs))
    summary.add_row("Output Path",         output_path)

    console.print(summary)


def main():
    """Fine-tune ``all-MiniLM-L6-v2`` on triplets read from ``triplets.jsonl``.

    Steps:
      1. Load the triplets and split them 80/20 into train/validation.
      2. Start a W&B run.
      3. Train with ``WandbLoggingTripletLoss``, evaluating every 10 steps
         with ``RichLossEvaluator``.
      4. Save the model to ``fine_tuned_sbert_triplet`` and upload its
         weights to W&B.
      5. Print a summary table.

    Raises:
        ValueError: If the split leaves the training or validation set empty.
    """
    # ——— Configuration ———
    model_name       = 'all-MiniLM-L6-v2'
    jsonl_path       = 'triplets.jsonl'
    output_path      = 'fine_tuned_sbert_triplet'
    batch_size       = 16
    train_ratio      = 0.8
    num_epochs       = 1
    evaluation_steps = 10

    # ——— Load & split data ———
    # Done before the model/W&B run is created so that missing or too-small
    # data fails fast without leaving an empty W&B run behind.
    examples = load_triplets_from_jsonl(jsonl_path)
    train_exs, val_exs = split_data(examples, train_ratio=train_ratio)
    if not train_exs or not val_exs:
        raise ValueError(
            f"Need at least one training and one validation triplet, got "
            f"{len(train_exs)} training / {len(val_exs)} validation from {jsonl_path!r}"
        )

    # ——— Initialize model & W&B ———
    model = SentenceTransformer(model_name)
    wandb.init(
        project="sbert-triplet-finetune",
        name="custom_run_name",  # distinct from output_path
        config={"model": model_name, "batch_size": batch_size, "epochs": num_epochs}
    )

    train_loader = DataLoader(train_exs, shuffle=True, batch_size=batch_size)

    # TripletEvaluator takes three parallel lists of texts.
    anchors   = [e.texts[0] for e in val_exs]
    positives = [e.texts[1] for e in val_exs]
    negatives = [e.texts[2] for e in val_exs]

    total_train_steps = len(train_loader) * num_epochs
    warmup_steps      = int(0.1 * total_train_steps)  # warm up over the first 10% of steps

    console.print(f"[bold]Training on {len(train_exs)} triplets, validating on {len(val_exs)} triplets[/bold]\n")

    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        transient=True,
    )

    with progress:
        task = progress.add_task("[green]Training...[/green]", total=total_train_steps)

        train_loss = WandbLoggingTripletLoss(model, task, progress)
        evaluator  = RichLossEvaluator(anchors, positives, negatives, 'val-triplet', task, progress)

        model.fit(
            train_objectives=[(train_loader, train_loss)],
            evaluator=evaluator,
            epochs=num_epochs,
            warmup_steps=warmup_steps,
            evaluation_steps=evaluation_steps,
            output_path=output_path,
            # The Rich bar above replaces the built-in progress bar.
            show_progress_bar=False
        )
        # Make sure the bar ends at exactly 100% regardless of per-step counting.
        progress.update(task, completed=total_train_steps)

    # ——— Save & finish W&B ———
    model.save(output_path)
    _upload_weights_to_wandb(output_path)
    wandb.finish()

    # ——— Final summary ———
    _print_training_summary(
        model_name, len(examples), len(train_exs), len(val_exs), num_epochs, output_path
    )
    console.print("[bold green]✅ Fine‑tuning complete and model saved![/bold green]")

# Run the fine-tuning pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()