File size: 1,883 Bytes
6a02b16
 
 
 
 
c6cece9
6a02b16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from pathlib import Path
from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from torch.utils.data import DataLoader
from data_io import load_pairs, load_clauses

def build_corpus(clauses):
    """Return a dict mapping each clause's ``"id"`` to its ``"text"``.

    If several clauses share an id, the text of the last one wins.
    """
    return {clause["id"]: clause["text"] for clause in clauses}

def build_queries_and_qrels(pairs):
    """Build the query and relevance dicts for an IR evaluation.

    Each pair's position (0-based, as a string) becomes its query id.

    Returns:
        ``(queries, qrels)``: ``queries`` maps qid -> ``pair["query"]``;
        ``qrels`` maps qid -> ``{pair["positive_id"]: 1}``.
    """
    queries, qrels = {}, {}
    for idx, pair in enumerate(pairs):
        key = str(idx)
        queries[key], qrels[key] = pair["query"], {pair["positive_id"]: 1}
    return queries, qrels

def main(
    *,
    base_model="paraphrase-multilingual-mpnet-base-v2",
    train_path="data/legal_assistant_train.jsonl",
    test_path="data/legal_assistant_test.jsonl",
    clauses_path="data/clauses_constitution_ru_kz.jsonl",
    out_dir="artifacts/models/finetuned_mpnet",
    epochs=2,
    batch_size=32,
    warmup_steps=200,
    evaluation_steps=500,
):
    """Fine-tune a SentenceTransformer for clause retrieval and save it.

    Trains ``base_model`` on (query, positive clause text) pairs from
    ``train_path`` with MultipleNegativesRankingLoss. It evaluates retrieval
    over the merged RU + KZ clause corpus using the pairs in ``test_path``
    and writes to ``out_dir`` via ``model.fit(..., save_best_model=True)``.

    All arguments are keyword-only and default to the values this script has
    always used, so a bare ``main()`` runs exactly the original configuration.

    Raises:
        ValueError: if ``train_path`` yields no training pairs. Otherwise
            ``fit`` would run zero steps and save the untouched base model.

    Warns:
        UserWarning: if a clause id occurs in both the RU and KZ lists (the
            KZ text overwrites the RU one in the merged corpus), or if some
            test ``positive_id`` is absent from the corpus (that query can
            never be retrieved).
    """
    import warnings

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    train = load_pairs(train_path)
    test = load_pairs(test_path)
    ru, kz = load_clauses(clauses_path)

    # assumes each training pair carries "query" and "positive" (clause text) keys
    train_examples = [InputExample(texts=[x["query"], x["positive"]]) for x in train]
    if not train_examples:
        raise ValueError(f"no training pairs loaded from {train_path!r}")

    ru_corpus = build_corpus(ru)
    kz_corpus = build_corpus(kz)
    shared_ids = ru_corpus.keys() & kz_corpus.keys()
    if shared_ids:
        # `|` keeps the right-hand (KZ) value, so these RU texts vanish from the corpus.
        sample = sorted(map(str, shared_ids))[:5]
        warnings.warn(
            f"{len(shared_ids)} clause id(s) occur in both the RU and KZ corpora "
            f"(e.g. {sample}); their RU texts are dropped from the evaluation corpus",
            stacklevel=2,
        )
    corpus = ru_corpus | kz_corpus

    queries, qrels = build_queries_and_qrels(test)
    missing = {pid for rel in qrels.values() for pid in rel if pid not in corpus}
    if missing:
        warnings.warn(
            f"{len(missing)} test positive_id(s) are not in the clause corpus "
            f"(e.g. {sorted(map(str, missing))[:5]}); those queries cannot be retrieved",
            stacklevel=2,
        )

    # NOTE(review): this evaluator is built from the *test* split, and
    # save_best_model=True keeps the checkpoint that scores best on it. The test
    # split therefore also serves as the model-selection set, so its scores are
    # optimistically biased. Consider holding out a separate dev split.
    evaluator = InformationRetrievalEvaluator(queries, corpus, qrels, name="overall")

    # Load the model only after the data checks above, so bad inputs fail fast.
    model = SentenceTransformer(base_model)
    train_loader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
    # Per the sentence-transformers docs, this loss uses the other in-batch
    # positives as negatives. Duplicate texts within a batch therefore act as
    # false negatives.
    train_loss = MultipleNegativesRankingLoss(model)

    model.fit(
        train_objectives=[(train_loader, train_loss)],
        epochs=epochs,
        warmup_steps=warmup_steps,
        evaluator=evaluator,
        evaluation_steps=evaluation_steps,
        output_path=str(out_dir),
        save_best_model=True,
    )

if __name__ == "__main__":
    # Run the fine-tuning pipeline only when this file is executed as a script,
    # not when it is imported.
    main()