import os
import json
import random
import logging
import wandb
from sentence_transformers import SentenceTransformer, InputExample, evaluation
from sentence_transformers.losses import TripletLoss
from torch.utils.data import DataLoader
from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
# ─────── Setup logging & console ───────
# Root logger at INFO so library log lines (sentence-transformers, etc.) are visible.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
# Shared Rich console used for all styled terminal output in this script.
console = Console()
# ─────── Data loading & splitting ───────
def load_triplets_from_jsonl(path):
    """Read (anchor, positive, negative) triplets from a JSONL file.

    Each line must be a JSON object with ``anchor``, ``positive`` and
    ``negative`` string fields.

    Args:
        path: Path to the JSONL file.

    Returns:
        list[InputExample]: one three-text example per input line.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        return [
            InputExample(
                texts=[record['anchor'], record['positive'], record['negative']]
            )
            for record in map(json.loads, handle)
        ]
def split_data(examples, train_ratio=0.8, seed=42):
    """Deterministically shuffle *examples* and split into train/validation.

    Args:
        examples: Sequence of items to split.
        train_ratio: Fraction of items assigned to the training split.
        seed: Shuffle seed, making the split reproducible.

    Returns:
        Tuple ``(train, validation)`` of new lists.
    """
    # Shuffle a copy with a private Random instance: the original code mutated
    # the caller's list in place and reseeded the *global* random state.
    # random.Random(seed).shuffle produces the same ordering as
    # random.seed(seed) + random.shuffle, so the split itself is unchanged.
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_ratio)
    return shuffled[:cut], shuffled[cut:]
# ─────── Custom loss to log training loss to W&B ───────
class WandbLoggingTripletLoss(TripletLoss):
    """TripletLoss wrapper that mirrors each training step to W&B and a Rich bar.

    sentence-transformers calls ``forward(sentence_features, labels)`` once per
    batch, which makes it a convenient hook for per-step loss logging.
    """

    def __init__(self, model, progress_task, progress):
        super().__init__(model)
        self.progress_task = progress_task  # Rich task id to advance each batch
        self.progress = progress            # shared Rich Progress instance
        self.step = 0                       # running count of training batches

    def forward(self, sentence_features, labels):
        batch_loss = super().forward(sentence_features, labels)
        # Log before bumping the counter so the first batch is recorded as step 0.
        wandb.log({"train_step": self.step, "train_loss": batch_loss.item()})
        self.progress.update(self.progress_task, advance=1)
        self.step += 1
        return batch_loss
# ─────── Custom evaluator to log validation metrics to W&B ───────
class RichLossEvaluator(evaluation.TripletEvaluator):
    """TripletEvaluator that reports validation scores to W&B and a Rich bar."""

    def __init__(self, anchors, positives, negatives, name, progress_task, progress):
        super().__init__(
            anchors=anchors,
            positives=positives,
            negatives=negatives,
            name=name,
            show_progress_bar=False,
        )
        self.progress_task = progress_task  # Rich task id to advance on eval
        self.progress = progress            # shared Rich Progress instance

    def __call__(self, model, output_path=None, epoch=-1, steps=-1, **kwargs):
        metrics = super().__call__(
            model, output_path=output_path, epoch=epoch, steps=steps, **kwargs
        )
        # The base evaluator may return either a metrics dict or a bare float
        # (varies by sentence-transformers version); reduce either shape to one
        # scalar, preferring "accuracy" when available.
        score = metrics
        if isinstance(score, dict):
            score = score.get("accuracy", next(iter(score.values())))
        wandb.log({
            "validation_step": steps,
            "validation_score": score,
            "epoch": epoch,
        })
        # Evaluation fires every 10 training steps, so advance the bar by 10.
        self.progress.update(self.progress_task, advance=10)
        self.progress.console.log(
            f"[blue]Step {steps}[/blue]: Validation score: {score:.4f}"
        )
        return metrics
def main():
    """Fine-tune an SBERT model on triplets with W&B logging and Rich progress.

    Reads triplets from ``triplets.jsonl``, trains with TripletLoss, evaluates
    every 10 steps on a held-out split, and saves the model to disk.
    """
    # ─────── Configuration ───────
    model_name = 'all-MiniLM-L6-v2'
    jsonl_path = 'triplets.jsonl'
    output_path = 'fine_tuned_sbert_triplet'
    batch_size = 16
    train_ratio = 0.8
    num_epochs = 1
    # ─────── Initialize model & W&B ───────
    model = SentenceTransformer(model_name)
    wandb.init(
        project="sbert-triplet-finetune",
        name="custom_run_name",  # distinct from output_path
        config={"model": model_name, "batch_size": batch_size, "epochs": num_epochs}
    )
    # ─────── Load & split data ───────
    examples = load_triplets_from_jsonl(jsonl_path)
    train_exs, val_exs = split_data(examples, train_ratio=train_ratio)
    train_loader = DataLoader(train_exs, shuffle=True, batch_size=batch_size)
    # Validation texts feed the TripletEvaluator directly; no DataLoader needed
    # (the original built an unused val_loader here).
    anchors = [e.texts[0] for e in val_exs]
    positives = [e.texts[1] for e in val_exs]
    negatives = [e.texts[2] for e in val_exs]
    total_train_steps = len(train_loader) * num_epochs
    warmup_steps = int(0.1 * total_train_steps)  # standard 10% LR warmup
    console.print(f"[bold]Training on {len(train_exs)} triplets, validating on {len(val_exs)} triplets[/bold]\n")
    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        transient=True,
    )
    with progress:
        task = progress.add_task("[green]Training...[/green]", total=total_train_steps)
        train_loss = WandbLoggingTripletLoss(model, task, progress)
        evaluator = RichLossEvaluator(anchors, positives, negatives, 'val-triplet', task, progress)
        # show_progress_bar=False: the Rich bar above replaces the built-in tqdm one.
        model.fit(
            train_objectives=[(train_loader, train_loss)],
            evaluator=evaluator,
            epochs=num_epochs,
            warmup_steps=warmup_steps,
            evaluation_steps=10,
            output_path=output_path,
            show_progress_bar=False
        )
        progress.update(task, completed=total_train_steps)
    # ─────── Save & finish W&B ───────
    model.save(output_path)
    wandb.save(os.path.join(output_path, "pytorch_model.bin"))
    wandb.finish()
    # ─────── Final summary ───────
    summary = Table(title="Training Summary")
    summary.add_column("Metric", style="cyan", no_wrap=True)
    summary.add_column("Value", style="magenta")
    summary.add_row("Model Name", model_name)
    summary.add_row("Total Triplets", str(len(examples)))
    summary.add_row("Training Triplets", str(len(train_exs)))
    summary.add_row("Validation Triplets", str(len(val_exs)))
    summary.add_row("Epochs", str(num_epochs))
    summary.add_row("Output Path", output_path)
    console.print(summary)
    # Original had this literal split across two source lines (a syntax error)
    # with mojibake where "✅" and a non-breaking hyphen once were.
    console.print("[bold green]✅ Fine-tuning complete and model saved![/bold green]")


if __name__ == "__main__":
    main()