|
|
import os |
|
|
import json |
|
|
import random |
|
|
import logging |
|
|
|
|
|
import wandb |
|
|
from sentence_transformers import SentenceTransformer, InputExample, evaluation |
|
|
from sentence_transformers.losses import TripletLoss |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from rich.console import Console |
|
|
from rich.table import Table |
|
|
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn |
|
|
|
|
|
|
|
|
# Timestamped INFO-level logging so sentence-transformers / wandb messages are visible.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)


# Shared rich console for tables and status output across the script.
console = Console()
|
|
|
|
|
|
|
|
def load_triplets_from_jsonl(path):
    """Read (anchor, positive, negative) triplets from a JSONL file.

    Each line must be a JSON object with 'anchor', 'positive' and
    'negative' fields.

    Args:
        path: Path to the JSONL file.

    Returns:
        A list of InputExample objects, one per input line.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        return [
            InputExample(texts=[record['anchor'], record['positive'], record['negative']])
            for record in map(json.loads, handle)
        ]
|
|
|
|
|
def split_data(examples, train_ratio=0.8, seed=42):
    """Deterministically split examples into train/validation partitions.

    Unlike a naive shuffle-in-place, this neither mutates the caller's
    list nor reseeds the global ``random`` module state (a hidden side
    effect that could perturb unrelated randomness elsewhere).

    Args:
        examples: List of items to split.
        train_ratio: Fraction of items assigned to the training split.
        seed: Seed for the deterministic shuffle.

    Returns:
        Tuple ``(train, validation)`` of two new lists covering all items.
    """
    rng = random.Random(seed)  # local RNG: same shuffle sequence as random.seed(seed), no global side effect
    shuffled = examples[:]  # copy so the caller's ordering is preserved
    rng.shuffle(shuffled)
    split = int(len(shuffled) * train_ratio)
    return shuffled[:split], shuffled[split:]
|
|
|
|
|
|
|
|
class WandbLoggingTripletLoss(TripletLoss):
    """TripletLoss that reports every training step to wandb and ticks a
    rich progress bar as a side effect of each forward pass."""

    def __init__(self, model, progress_task, progress):
        """
        Args:
            model: The SentenceTransformer being trained.
            progress_task: Task id of the rich progress bar to advance.
            progress: The rich Progress instance that owns the task.
        """
        super().__init__(model)
        self.progress = progress
        self.progress_task = progress_task
        # Monotonically increasing counter; logged as "train_step" starting at 0.
        self.step = 0

    def forward(self, sentence_features, labels):
        """Compute the triplet loss, then log it and advance the progress bar."""
        batch_loss = super().forward(sentence_features, labels)
        self.progress.update(self.progress_task, advance=1)
        # Log with the pre-increment counter so step numbers start at 0.
        wandb.log({"train_step": self.step, "train_loss": batch_loss.item()})
        self.step += 1
        return batch_loss
|
|
|
|
|
|
|
|
class RichLossEvaluator(evaluation.TripletEvaluator):
    """TripletEvaluator that logs each validation result to wandb and
    echoes it through a rich progress display."""

    def __init__(self, anchors, positives, negatives, name, progress_task, progress):
        """
        Args:
            anchors: Anchor texts of the validation triplets.
            positives: Positive texts of the validation triplets.
            negatives: Negative texts of the validation triplets.
            name: Evaluator name passed through to the base class.
            progress_task: Task id of the rich progress bar to advance.
            progress: The rich Progress instance that owns the task.
        """
        super().__init__(
            anchors=anchors, positives=positives, negatives=negatives,
            name=name, show_progress_bar=False
        )
        self.progress = progress
        self.progress_task = progress_task

    def __call__(self, model, output_path=None, epoch=-1, steps=-1, **kwargs):
        """Run the base evaluation, log the score, and update the UI.

        Returns the base class's metrics unchanged so fit() can use them.
        """
        metrics = super().__call__(model, output_path=output_path, epoch=epoch, steps=steps, **kwargs)

        # Newer sentence-transformers versions return a dict of metrics;
        # older ones return a single float. Prefer "accuracy" when present,
        # otherwise fall back to the first value (or the scalar itself).
        val_score = (
            metrics.get("accuracy", next(iter(metrics.values())))
            if isinstance(metrics, dict)
            else metrics
        )
        wandb.log({"validation_step": steps, "validation_score": val_score, "epoch": epoch})

        self.progress.update(self.progress_task, advance=10)
        self.progress.console.log(f"[blue]Step {steps}[/blue]: Validation score: {val_score:.4f}")
        return metrics
|
|
|
|
|
def main():
    """Fine-tune an SBERT model on triplet data.

    Loads triplets from JSONL, splits train/validation, trains with a
    wandb-logging TripletLoss under a rich progress display, saves the
    model, and prints a summary table.
    """
    # ---- configuration -------------------------------------------------
    model_name = 'all-MiniLM-L6-v2'
    jsonl_path = 'triplets.jsonl'
    output_path = 'fine_tuned_sbert_triplet'
    batch_size = 16
    train_ratio = 0.8
    num_epochs = 1

    # ---- model + experiment tracking -----------------------------------
    model = SentenceTransformer(model_name)
    wandb.init(
        project="sbert-triplet-finetune",
        name="custom_run_name",
        config={"model": model_name, "batch_size": batch_size, "epochs": num_epochs}
    )

    # ---- data ----------------------------------------------------------
    examples = load_triplets_from_jsonl(jsonl_path)
    train_exs, val_exs = split_data(examples, train_ratio=train_ratio)

    train_loader = DataLoader(train_exs, shuffle=True, batch_size=batch_size)
    # No validation DataLoader is needed: the TripletEvaluator below
    # consumes the raw text lists directly.

    anchors = [e.texts[0] for e in val_exs]
    positives = [e.texts[1] for e in val_exs]
    negatives = [e.texts[2] for e in val_exs]

    total_train_steps = len(train_loader) * num_epochs
    warmup_steps = int(0.1 * total_train_steps)  # 10% linear LR warmup

    console.print(f"[bold]Training on {len(train_exs)} triplets, validating on {len(val_exs)} triplets[/bold]\n")

    # ---- progress UI ---------------------------------------------------
    # transient=True clears the bar when training finishes so only the
    # summary table remains on screen.
    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        transient=True,
    )

    # ---- training ------------------------------------------------------
    with progress:
        task = progress.add_task("[green]Training...[/green]", total=total_train_steps)

        train_loss = WandbLoggingTripletLoss(model, task, progress)
        evaluator = RichLossEvaluator(anchors, positives, negatives, 'val-triplet', task, progress)

        model.fit(
            train_objectives=[(train_loader, train_loss)],
            evaluator=evaluator,
            epochs=num_epochs,
            warmup_steps=warmup_steps,
            evaluation_steps=10,
            output_path=output_path,
            show_progress_bar=False  # rich owns the progress display
        )
        progress.update(task, completed=total_train_steps)

    # ---- persist + finish run ------------------------------------------
    model.save(output_path)
    wandb.save(os.path.join(output_path, "pytorch_model.bin"))
    wandb.finish()

    # ---- summary -------------------------------------------------------
    summary = Table(title="Training Summary")
    summary.add_column("Metric", style="cyan", no_wrap=True)
    summary.add_column("Value", style="magenta")

    summary.add_row("Model Name", model_name)
    summary.add_row("Total Triplets", str(len(examples)))
    summary.add_row("Training Triplets", str(len(train_exs)))
    summary.add_row("Validation Triplets", str(len(val_exs)))
    summary.add_row("Epochs", str(num_epochs))
    summary.add_row("Output Path", output_path)

    console.print(summary)
    # Fixed mojibake in the original message (garbled emoji / non-breaking hyphen).
    console.print("[bold green]Fine-tuning complete and model saved![/bold green]")
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":

    main()
|
|
|
|
|
|