"""Fine-tune a multilingual sentence-transformer for legal clause retrieval.

Trains ``paraphrase-multilingual-mpnet-base-v2`` on (query, positive clause)
pairs using an in-batch-negatives objective, and evaluates retrieval quality
against the combined RU+KZ constitution clause corpus during training.
"""

from pathlib import Path

from sentence_transformers import InputExample, SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss
from torch.utils.data import DataLoader

from data_io import load_clauses, load_pairs


def build_corpus(clauses):
    """Map clause id -> clause text for the IR evaluator.

    Args:
        clauses: iterable of dicts with at least "id" and "text" keys.

    Returns:
        dict mapping each clause's id to its text. Later duplicates of an
        id overwrite earlier ones, matching the original loop's behavior.
    """
    return {clause["id"]: clause["text"] for clause in clauses}


def build_queries_and_qrels(pairs):
    """Build evaluator queries and relevance judgments from test pairs.

    Args:
        pairs: iterable of dicts with "query" and "positive_id" keys.

    Returns:
        (queries, qrels): ``queries`` maps a synthetic string qid (the
        pair's index) to the query text; ``qrels`` maps the same qid to
        ``{positive_id: 1}`` — exactly one relevant clause per query.
    """
    queries = {}
    qrels = {}
    for idx, pair in enumerate(pairs):
        qid = str(idx)
        queries[qid] = pair["query"]
        qrels[qid] = {pair["positive_id"]: 1}
    return queries, qrels


def main():
    """Run the fine-tuning pipeline and save the best checkpoint."""
    base_model = "paraphrase-multilingual-mpnet-base-v2"
    train_path = "data/legal_assistant_train.jsonl"
    test_path = "data/legal_assistant_test.jsonl"
    clauses_path = "data/clauses_constitution_ru_kz.jsonl"

    out_dir = Path("artifacts/models/finetuned_mpnet")
    out_dir.mkdir(parents=True, exist_ok=True)

    train = load_pairs(train_path)
    test = load_pairs(test_path)
    # NOTE(review): presumably the Russian and Kazakh clause lists — confirm
    # against data_io.load_clauses.
    ru, kz = load_clauses(clauses_path)

    model = SentenceTransformer(base_model)

    # MultipleNegativesRankingLoss treats the other positives in a batch as
    # negatives, so only (query, positive) pairs are required for training.
    train_examples = [InputExample(texts=[x["query"], x["positive"]]) for x in train]
    train_loader = DataLoader(train_examples, shuffle=True, batch_size=32)
    train_loss = MultipleNegativesRankingLoss(model)

    # Evaluate retrieval over the union of both language corpora
    # (dict union operator requires Python 3.9+).
    corpus = build_corpus(ru) | build_corpus(kz)
    queries, qrels = build_queries_and_qrels(test)
    evaluator = InformationRetrievalEvaluator(queries, corpus, qrels, name="overall")

    model.fit(
        train_objectives=[(train_loader, train_loss)],
        epochs=2,
        warmup_steps=200,
        evaluator=evaluator,
        evaluation_steps=500,
        output_path=str(out_dir),
        save_best_model=True,  # keep the checkpoint with the best evaluator score
    )


if __name__ == "__main__":
    main()