import json
import random

from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.table import Table
from sentence_transformers import InputExample, SentenceTransformer, evaluation, losses
from torch.utils.data import DataLoader

console = Console()


def load_triplets_from_jsonl(file_path):
    """Read (anchor, positive, negative) triplets from a JSON Lines file.

    Each non-blank line must be a JSON object with ``anchor``, ``positive``
    and ``negative`` keys. Blank or whitespace-only lines (for example a
    trailing empty line) are skipped rather than passed to ``json.loads``,
    which would reject them.

    Args:
        file_path: Path to the ``.jsonl`` file; it is read as UTF-8.

    Returns:
        A list of ``InputExample`` objects, in file order, whose ``texts``
        are ``[anchor, positive, negative]``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record is missing one of the three required keys.
    """
    examples = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            data = json.loads(line)
            examples.append(
                InputExample(texts=[data['anchor'], data['positive'], data['negative']])
            )
    return examples


def split_data(examples, train_ratio=0.8, seed=42):
    """Deterministically shuffle *examples* and split them into train/validation.

    The shuffle runs on a copy using a private ``random.Random(seed)``, so the
    caller's list is not reordered and the global ``random`` state is not
    reseeded. The resulting order is the same one ``random.seed(seed)``
    followed by ``random.shuffle`` would produce.

    Args:
        examples: Sequence of examples to split; left unmodified.
        train_ratio: Fraction of examples placed in the training split
            (the count is truncated with ``int``).
        seed: Seed for the shuffle, for reproducible splits.

    Returns:
        A ``(train, validation)`` tuple of new lists.
    """
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    train_size = int(len(shuffled) * train_ratio)
    return shuffled[:train_size], shuffled[train_size:]


class RichLossEvaluator(evaluation.TripletEvaluator):
    """``TripletEvaluator`` that advances a rich progress bar and logs each score.

    Args:
        *args, **kwargs: Forwarded unchanged to ``TripletEvaluator``.
        progress: The ``rich.progress.Progress`` instance owning the task.
        progress_task: Task id returned by ``progress.add_task``.
        steps_per_eval: How far to advance the bar on each evaluation call;
            should equal the ``evaluation_steps`` given to ``model.fit``.
    """

    def __init__(self, *args, progress, progress_task, steps_per_eval, **kwargs):
        super().__init__(*args, **kwargs)
        self.progress = progress
        self.progress_task = progress_task
        self.steps_per_eval = steps_per_eval

    def _extract_score(self, score_dict):
        """Return one representative score from the evaluator result, or None."""
        if not isinstance(score_dict, dict):
            # Older sentence-transformers versions return a bare float.
            return score_dict
        # Prefer a plain "accuracy" key; otherwise try the evaluator's
        # ``primary_metric`` attribute if this library version defines one
        # (presumably the prefixed metric name — verify against the installed version).
        for key in ("accuracy", getattr(self, "primary_metric", None)):
            if key is not None and key in score_dict:
                return score_dict[key]
        # Fall back to the first reported value; None if the dict is empty.
        return next(iter(score_dict.values()), None)

    def __call__(self, model, output_path=None, epoch=-1, steps=-1, **kwargs):
        score_dict = super().__call__(
            model, output_path=output_path, epoch=epoch, steps=steps, **kwargs
        )
        val_score = self._extract_score(score_dict)
        # Approximation: assumes one call per `steps_per_eval` training steps.
        # The final `completed=` update in main() snaps the bar to 100%.
        self.progress.update(self.progress_task, advance=self.steps_per_eval)
        shown = "n/a" if val_score is None else f"{val_score:.4f}"
        self.progress.console.log(f"Step {steps}: Validation score: {shown}")
        return score_dict


def main():
    """Fine-tune a SentenceTransformer on triplets and print a summary table.

    Loads ``triplets.jsonl`` and splits it 80/20. Then trains
    ``all-MiniLM-L6-v2`` with ``TripletLoss``, reporting validation scores
    through a rich progress bar, and saves the model to
    ``fine_tuned_sbert_triplet``.

    Raises:
        ValueError: If the split leaves no training triplets.
    """
    model_name = 'all-MiniLM-L6-v2'
    batch_size = 16
    num_epochs = 1
    # Shared by model.fit and the evaluator so the bar advances consistently.
    evaluation_steps = 10
    output_path = 'fine_tuned_sbert_triplet'

    # Load data before the model so a missing or bad file fails fast.
    all_examples = load_triplets_from_jsonl('triplets.jsonl')
    train_examples, val_examples = split_data(all_examples, train_ratio=0.8)
    if not train_examples:
        raise ValueError(
            f"No training triplets after splitting {len(all_examples)} examples; "
            "check 'triplets.jsonl'."
        )

    model = SentenceTransformer(model_name)

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
    train_loss = losses.TripletLoss(model=model)
    total_steps = num_epochs * len(train_dataloader)
    warmup_steps = int(total_steps * 0.1)

    console.print(f"[bold]Starting training with {len(train_examples)} triplets, validating on {len(val_examples)} triplets.[/bold]\n")

    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        transient=True,
    )

    with progress:
        task = progress.add_task("[green]Training...[/green]", total=total_steps)

        # A TripletEvaluator over zero triplets would divide by zero, so
        # validation is skipped entirely when the split produced none.
        rich_evaluator = None
        if val_examples:
            rich_evaluator = RichLossEvaluator(
                anchors=[ex.texts[0] for ex in val_examples],
                positives=[ex.texts[1] for ex in val_examples],
                negatives=[ex.texts[2] for ex in val_examples],
                name='val-triplet-eval',
                show_progress_bar=False,
                progress=progress,
                progress_task=task,
                steps_per_eval=evaluation_steps,
            )

        model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=num_epochs,
            warmup_steps=warmup_steps,
            evaluator=rich_evaluator,
            evaluation_steps=evaluation_steps,
            output_path=output_path,
            show_progress_bar=False,
        )
        progress.update(task, completed=total_steps)

    table = Table(title="Training Summary")
    table.add_column("Metric", style="cyan", no_wrap=True)
    table.add_column("Value", style="magenta")
    table.add_row("Model Name", model_name)
    table.add_row("Total Triplets", str(len(all_examples)))
    table.add_row("Training Triplets", str(len(train_examples)))
    table.add_row("Validation Triplets", str(len(val_examples)))
    table.add_row("Epochs", str(num_epochs))
    table.add_row("Output Path", output_path)

    console.print(table)
    console.print("[bold green]Model fine-tuned and saved successfully![/bold green]")


if __name__ == "__main__":
    main()