|
|
import json |
|
|
from sentence_transformers import SentenceTransformer, losses, InputExample, evaluation |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
def load_triplets_from_jsonl(file_path):
    """Load triplet examples from a JSON Lines file.

    Each non-blank line is parsed as a JSON object, and its ``anchor``,
    ``positive`` and ``negative`` values become the texts (in that order)
    of one ``InputExample``.

    Args:
        file_path: Path to the JSONL file; it is opened as UTF-8 text.

    Returns:
        A list of ``InputExample`` objects, one per non-blank line, in
        file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``anchor``, ``positive`` or ``negative``.
    """
    examples = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. an extra newline at the end of the file)
                # carry no record; skip them instead of failing in json.loads.
                continue
            data = json.loads(line)
            examples.append(
                InputExample(texts=[data['anchor'], data['positive'], data['negative']])
            )
    return examples
|
|
|
|
|
def main():
    """Fine-tune a SentenceTransformer model on triplet data.

    Reads training triplets from ``train_triplets.jsonl`` and validation
    triplets from ``val_triplets.jsonl``, builds a ``TripletLoss`` objective
    for the ``all-MiniLM-L6-v2`` model, and calls ``model.fit`` for one epoch
    with a ``TripletEvaluator`` over the validation triplets
    (``evaluation_steps=10``) and ``fine_tuned_sbert_triplet`` as the
    output path.
    """
    model_name = 'all-MiniLM-L6-v2'
    # Single source of truth for the output directory, shared by fit() and
    # the final status message.
    output_path = 'fine_tuned_sbert_triplet'

    model = SentenceTransformer(model_name)

    # Training data and objective.
    train_examples = load_triplets_from_jsonl('train_triplets.jsonl')
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
    train_loss = losses.TripletLoss(model=model)

    # Validation data: the evaluator takes three parallel lists of texts,
    # so no DataLoader is built for the validation split.
    val_examples = load_triplets_from_jsonl('val_triplets.jsonl')
    anchors = [ex.texts[0] for ex in val_examples]
    positives = [ex.texts[1] for ex in val_examples]
    negatives = [ex.texts[2] for ex in val_examples]

    evaluator = evaluation.TripletEvaluator(
        anchors=anchors,
        positives=positives,
        negatives=negatives,
        name='val-triplet-eval',
        show_progress_bar=True
    )

    num_epochs = 1
    # Warm up over the first 10% of all training batches.
    warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        evaluator=evaluator,
        evaluation_steps=10,
        output_path=output_path,
        show_progress_bar=True
    )

    print(f"Model fine-tuned and saved at '{output_path}'")
|
|
|
|
|
# Script entry point: run the fine-tuning pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
|
|