import json

from sentence_transformers import SentenceTransformer, losses, InputExample, evaluation
from torch.utils.data import DataLoader


def load_triplets_from_jsonl(file_path):
    """Load (anchor, positive, negative) triplets from a JSON Lines file.

    Each non-blank line must be a JSON object with the keys ``anchor``,
    ``positive`` and ``negative``. Blank or whitespace-only lines are
    skipped so that stray empty lines in the file do not abort loading.

    Args:
        file_path: Path to the ``.jsonl`` file, read as UTF-8.

    Returns:
        A list of ``InputExample`` objects whose ``texts`` are
        ``[anchor, positive, negative]`` in that order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the three required keys.
    """
    examples = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # json.loads('\n') raises, so tolerate blank separator lines.
            if not line.strip():
                continue
            data = json.loads(line)
            examples.append(
                InputExample(texts=[data['anchor'], data['positive'], data['negative']])
            )
    return examples


def _build_triplet_evaluator(val_examples):
    """Return a TripletEvaluator over *val_examples*, or None if there are none.

    An evaluator over zero triplets has no meaningful accuracy, so the
    caller trains without periodic evaluation in that case.
    """
    if not val_examples:
        return None
    # texts are stored as [anchor, positive, negative] by load_triplets_from_jsonl.
    anchors = [ex.texts[0] for ex in val_examples]
    positives = [ex.texts[1] for ex in val_examples]
    negatives = [ex.texts[2] for ex in val_examples]
    return evaluation.TripletEvaluator(
        anchors=anchors,
        positives=positives,
        negatives=negatives,
        name='val-triplet-eval',
        show_progress_bar=True,
    )


def main(
    train_path='train_triplets.jsonl',
    val_path='val_triplets.jsonl',
    output_path='fine_tuned_sbert_triplet',
    model_name='all-MiniLM-L6-v2',
    num_epochs=1,
    batch_size=16,
    evaluation_steps=10,
    warmup_ratio=0.1,
):
    """Fine-tune a SentenceTransformer with TripletLoss and save it.

    All arguments default to the values this script originally hard-coded,
    so ``main()`` behaves as before.

    Args:
        train_path: JSONL file with training triplets.
        val_path: JSONL file with validation triplets used by the evaluator.
        output_path: Directory passed to ``model.fit`` for saving the model.
        model_name: Name or path of the base SentenceTransformer model.
        num_epochs: Number of training epochs.
        batch_size: Training batch size.
        evaluation_steps: Run the evaluator every this many training steps.
        warmup_ratio: Fraction of total training steps used for LR warmup.
    """
    model = SentenceTransformer(model_name)

    # Load training triplets from JSONL.
    train_examples = load_triplets_from_jsonl(train_path)
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
    # Uses the library's default distance metric and margin for TripletLoss.
    train_loss = losses.TripletLoss(model=model)

    # Load validation triplets; they are only consumed by the evaluator.
    val_examples = load_triplets_from_jsonl(val_path)
    evaluator = _build_triplet_evaluator(val_examples)

    # Warmup over the first warmup_ratio of all optimizer steps.
    warmup_steps = int(len(train_dataloader) * num_epochs * warmup_ratio)

    # NOTE(review): with an evaluator, sentence-transformers' fit() may save the
    # best-scoring checkpoint to output_path rather than the final weights —
    # confirm against the installed library version.
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        evaluator=evaluator,
        evaluation_steps=evaluation_steps,
        output_path=output_path,
        show_progress_bar=True,
    )

    print(f"Model fine-tuned and saved at '{output_path}'")


if __name__ == "__main__":
    main()