| from pathlib import Path | |
| from sentence_transformers import SentenceTransformer, InputExample | |
| from sentence_transformers.losses import MultipleNegativesRankingLoss | |
| from sentence_transformers.evaluation import InformationRetrievalEvaluator | |
| from torch.utils.data import DataLoader | |
| from data_io import load_pairs, load_clauses | |
def build_corpus(clauses):
    """Map clause id -> clause text for the IR evaluator's corpus argument.

    Each element of *clauses* is expected to be a dict with "id" and "text"
    keys (as produced by load_clauses).
    """
    return {clause["id"]: clause["text"] for clause in clauses}
def build_queries_and_qrels(pairs):
    """Build the (queries, qrels) pair expected by InformationRetrievalEvaluator.

    Query ids are the stringified positions of the pairs; each query is
    marked relevant (score 1) to exactly its own "positive_id" clause.

    Returns:
        tuple[dict, dict]: ({qid: query_text}, {qid: {positive_id: 1}})
    """
    queries = {str(idx): pair["query"] for idx, pair in enumerate(pairs)}
    qrels = {str(idx): {pair["positive_id"]: 1} for idx, pair in enumerate(pairs)}
    return queries, qrels
def main():
    """Fine-tune a multilingual sentence-transformer for legal clause retrieval.

    Loads (query, positive clause) pairs for train/test and the RU/KZ
    constitution clauses, trains with in-batch negatives
    (MultipleNegativesRankingLoss), and periodically evaluates retrieval
    over the combined RU+KZ clause corpus, saving the best checkpoint.
    """
    base_model = "paraphrase-multilingual-mpnet-base-v2"
    train_path = "data/legal_assistant_train.jsonl"
    test_path = "data/legal_assistant_test.jsonl"
    clauses_path = "data/clauses_constitution_ru_kz.jsonl"

    out_dir = Path("artifacts/models/finetuned_mpnet")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load data: training/test pairs plus the two clause collections
    # (load_clauses returns the RU and KZ variants separately).
    train_pairs = load_pairs(train_path)
    test_pairs = load_pairs(test_path)
    ru_clauses, kz_clauses = load_clauses(clauses_path)

    model = SentenceTransformer(base_model)

    # One (query, positive) text pair per example; MNRL uses the other
    # positives in the batch as negatives, so no explicit negatives needed.
    examples = [
        InputExample(texts=[pair["query"], pair["positive"]])
        for pair in train_pairs
    ]
    loader = DataLoader(examples, shuffle=True, batch_size=32)
    loss_fn = MultipleNegativesRankingLoss(model)

    # Evaluate retrieval against the merged RU+KZ corpus.
    corpus = {**build_corpus(ru_clauses), **build_corpus(kz_clauses)}
    queries, qrels = build_queries_and_qrels(test_pairs)
    evaluator = InformationRetrievalEvaluator(queries, corpus, qrels, name="overall")

    model.fit(
        train_objectives=[(loader, loss_fn)],
        epochs=2,
        warmup_steps=200,
        evaluator=evaluator,
        evaluation_steps=500,
        output_path=str(out_dir),
        save_best_model=True,
    )
| if __name__ == "__main__": | |
| main() | |