# triplet-embed / train.py
# akhilreddygogula
# gitignore updated
# d65bc08
import json
from sentence_transformers import SentenceTransformer, losses, InputExample, evaluation
from torch.utils.data import DataLoader
def load_triplets_from_jsonl(file_path):
    """Read triplet examples from a JSON Lines file.

    Every non-blank line is parsed as a JSON object and must provide the
    keys ``anchor``, ``positive`` and ``negative``. The three values are
    stored, in that order, as the ``texts`` of one ``InputExample``.

    Args:
        file_path: Path of the JSONL file; it is opened as UTF-8 text.

    Returns:
        A list with one ``InputExample`` per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the three required keys.
    """
    examples = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Blank or whitespace-only lines (for example an extra newline at
            # the end of the file) are not records; json.loads would raise on
            # them, so they are skipped instead of aborting the whole load.
            if not line.strip():
                continue
            data = json.loads(line)
            examples.append(
                InputExample(texts=[data['anchor'], data['positive'], data['negative']])
            )
    return examples
def main():
    """Fine-tune a SentenceTransformer on triplet data and save it.

    Loads the ``all-MiniLM-L6-v2`` model, trains it with ``TripletLoss`` on
    the triplets in ``train_triplets.jsonl``, and passes a
    ``TripletEvaluator`` built from ``val_triplets.jsonl`` to ``model.fit``
    with ``evaluation_steps=10``. The output path handed to ``fit`` is
    ``fine_tuned_sbert_triplet``.
    """
    model_name = 'all-MiniLM-L6-v2'
    output_path = 'fine_tuned_sbert_triplet'
    num_epochs = 1

    model = SentenceTransformer(model_name)

    # Load training triplets from JSONL
    train_examples = load_triplets_from_jsonl('train_triplets.jsonl')
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
    train_loss = losses.TripletLoss(model=model)

    # Load validation triplets from JSONL. No DataLoader is built for them:
    # the evaluator below takes the raw text lists directly.
    val_examples = load_triplets_from_jsonl('val_triplets.jsonl')

    # Split each validation triplet into the three parallel lists the
    # evaluator expects (texts are stored as [anchor, positive, negative]).
    anchors = [ex.texts[0] for ex in val_examples]
    positives = [ex.texts[1] for ex in val_examples]
    negatives = [ex.texts[2] for ex in val_examples]

    evaluator = evaluation.TripletEvaluator(
        anchors=anchors,
        positives=positives,
        negatives=negatives,
        name='val-triplet-eval',
        show_progress_bar=True
    )

    # Warm up over the first 10% of all training steps
    # (batches per epoch * number of epochs).
    warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

    # Train the model
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        evaluator=evaluator,
        evaluation_steps=10,
        output_path=output_path,
        show_progress_bar=True
    )

    print(f"Model fine-tuned and saved at '{output_path}'")
# Run the training only when this file is executed as a script, so that
# importing the module does not start a fine-tuning run.
if __name__ == "__main__":
    main()