# triplet-embed / train_wandb.py
# (repo scrape metadata: author akhilreddygogula, commit d65bc08 "gitignore updated")
import os
import json
import random
import logging
import wandb
from sentence_transformers import SentenceTransformer, InputExample, evaluation
from sentence_transformers.losses import TripletLoss
from torch.utils.data import DataLoader
from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
# β€”β€”β€”β€”β€”β€”β€” Setup logging & console β€”β€”β€”β€”β€”β€”β€”
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
console = Console()
# β€”β€”β€”β€”β€”β€”β€” Data loading & splitting β€”β€”β€”β€”β€”β€”β€”
def load_triplets_from_jsonl(path):
    """Read a JSONL file of triplets into a list of InputExample objects.

    Each line must be a JSON object with 'anchor', 'positive' and
    'negative' string fields; they are packed into InputExample.texts
    in that order.
    """
    triplets = []
    with open(path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            record = json.loads(raw_line)
            texts = [record['anchor'], record['positive'], record['negative']]
            triplets.append(InputExample(texts=texts))
    return triplets
def split_data(examples, train_ratio=0.8, seed=42):
    """Shuffle *examples* reproducibly and split into (train, validation).

    Fixes two side effects of the previous version: it shuffles a copy so
    the caller's list is no longer mutated in place, and it uses a local
    ``random.Random(seed)`` instance so the global RNG state is left
    untouched. The shuffle sequence is identical to the old
    ``random.seed(seed); random.shuffle(...)`` behavior.

    Args:
        examples: sequence of items to split.
        train_ratio: fraction of items assigned to the training split.
        seed: seed for the reproducible shuffle.

    Returns:
        Tuple of (train, validation) lists.
    """
    shuffled = list(examples)  # copy: do not mutate the caller's data
    random.Random(seed).shuffle(shuffled)
    split = int(len(shuffled) * train_ratio)
    return shuffled[:split], shuffled[split:]
# β€”β€”β€”β€”β€”β€”β€” Custom loss to log training loss to W&B β€”β€”β€”β€”β€”β€”β€”
class WandbLoggingTripletLoss(TripletLoss):
    """TripletLoss that mirrors every training step's loss to W&B and
    advances a rich progress bar by one step per forward pass."""

    def __init__(self, model, progress_task, progress):
        super().__init__(model)
        self.progress_task = progress_task
        self.progress = progress
        self.step = 0

    # HuggingFace Trainer will call forward(sentence_features, labels)
    def forward(self, sentence_features, labels):
        batch_loss = super().forward(sentence_features, labels)
        payload = {"train_step": self.step, "train_loss": batch_loss.item()}
        wandb.log(payload)
        self.progress.update(self.progress_task, advance=1)
        self.step += 1
        return batch_loss
# β€”β€”β€”β€”β€”β€”β€” Custom evaluator to log validation metrics to W&B β€”β€”β€”β€”β€”β€”β€”
class RichLossEvaluator(evaluation.TripletEvaluator):
    """TripletEvaluator that forwards validation metrics to W&B and echoes
    a one-line summary through a rich progress console."""

    def __init__(self, anchors, positives, negatives, name, progress_task, progress):
        super().__init__(
            anchors=anchors,
            positives=positives,
            negatives=negatives,
            name=name,
            show_progress_bar=False,
        )
        self.progress_task = progress_task
        self.progress = progress

    def __call__(self, model, output_path=None, epoch=-1, steps=-1, **kwargs):
        metrics = super().__call__(
            model, output_path=output_path, epoch=epoch, steps=steps, **kwargs
        )
        # Reduce the evaluator's result to a single scalar: prefer the
        # "accuracy" key, otherwise fall back to the first metric value.
        if not isinstance(metrics, dict):
            val_score = metrics
        else:
            val_score = metrics.get("accuracy", next(iter(metrics.values())))
        wandb.log({
            "validation_step": steps,
            "validation_score": val_score,
            "epoch": epoch
        })
        # advance progress bar by 10 (since eval runs every 10 steps)
        self.progress.update(self.progress_task, advance=10)
        self.progress.console.log(f"[blue]Step {steps}[/blue]: Validation score: {val_score:.4f}")
        return metrics
def main():
    """Fine-tune an SBERT model with triplet loss.

    Loads triplets from a JSONL file, splits them into train/validation,
    trains with a W&B-logging TripletLoss while driving a rich progress
    bar, evaluates periodically, saves the model, and prints a summary
    table. All settings are hard-coded in the configuration section below.
    """
    # ——————— Configuration ———————
    model_name = 'all-MiniLM-L6-v2'
    jsonl_path = 'triplets.jsonl'
    output_path = 'fine_tuned_sbert_triplet'
    batch_size = 16
    train_ratio = 0.8
    num_epochs = 1

    # ——————— Initialize model & W&B ———————
    model = SentenceTransformer(model_name)
    wandb.init(
        project="sbert-triplet-finetune",
        name="custom_run_name",  # distinct from output_path
        config={"model": model_name, "batch_size": batch_size, "epochs": num_epochs}
    )

    # ——————— Load & split data ———————
    examples = load_triplets_from_jsonl(jsonl_path)
    train_exs, val_exs = split_data(examples, train_ratio=train_ratio)
    train_loader = DataLoader(train_exs, shuffle=True, batch_size=batch_size)
    # NOTE: no DataLoader is needed for validation — the TripletEvaluator
    # consumes the raw text lists directly (a previously-built val_loader
    # was dead code and has been removed).
    anchors = [e.texts[0] for e in val_exs]
    positives = [e.texts[1] for e in val_exs]
    negatives = [e.texts[2] for e in val_exs]

    total_train_steps = len(train_loader) * num_epochs
    warmup_steps = int(0.1 * total_train_steps)  # conventional 10% LR warmup
    console.print(f"[bold]Training on {len(train_exs)} triplets, validating on {len(val_exs)} triplets[/bold]\n")

    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        transient=True,
    )
    with progress:
        task = progress.add_task("[green]Training...[/green]", total=total_train_steps)
        train_loss = WandbLoggingTripletLoss(model, task, progress)
        evaluator = RichLossEvaluator(anchors, positives, negatives, 'val-triplet', task, progress)
        model.fit(
            train_objectives=[(train_loader, train_loss)],
            evaluator=evaluator,
            epochs=num_epochs,
            warmup_steps=warmup_steps,
            evaluation_steps=10,
            output_path=output_path,
            show_progress_bar=False  # rich progress bar replaces the built-in one
        )
        progress.update(task, completed=total_train_steps)

    # ——————— Save & finish W&B ———————
    model.save(output_path)
    # NOTE(review): newer sentence-transformers versions save
    # model.safetensors instead of pytorch_model.bin — confirm this
    # filename exists in output_path before relying on the artifact.
    wandb.save(os.path.join(output_path, "pytorch_model.bin"))
    wandb.finish()

    # ——————— Final summary ———————
    summary = Table(title="Training Summary")
    summary.add_column("Metric", style="cyan", no_wrap=True)
    summary.add_column("Value", style="magenta")
    summary.add_row("Model Name", model_name)
    summary.add_row("Total Triplets", str(len(examples)))
    summary.add_row("Training Triplets", str(len(train_exs)))
    summary.add_row("Validation Triplets", str(len(val_exs)))
    summary.add_row("Epochs", str(num_epochs))
    summary.add_row("Output Path", output_path)
    console.print(summary)
    console.print("[bold green]βœ… Fine‑tuning complete and model saved![/bold green]")
# Run training only when executed as a script (not on import).
if __name__ == "__main__":
    main()