lexir / src /train_biencoder.py
irinaqqq's picture
ADDED MORE GRAPHS
c6cece9
from pathlib import Path
from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from torch.utils.data import DataLoader
from data_io import load_pairs, load_clauses
def build_corpus(clauses):
    """Return a mapping of clause id -> clause text.

    Args:
        clauses: iterable of dicts, each with "id" and "text" keys.
    """
    return {clause["id"]: clause["text"] for clause in clauses}
def build_queries_and_qrels(pairs):
    """Build evaluator inputs from (query, positive) pairs.

    Each pair gets a string query id equal to its position in *pairs*.

    Args:
        pairs: iterable of dicts with "query" and "positive_id" keys.

    Returns:
        (queries, qrels): query-id -> query text, and
        query-id -> {positive clause id: 1} relevance judgments.
    """
    queries = {str(idx): pair["query"] for idx, pair in enumerate(pairs)}
    qrels = {str(idx): {pair["positive_id"]: 1} for idx, pair in enumerate(pairs)}
    return queries, qrels
def main():
    """Fine-tune a multilingual MPNet bi-encoder on legal query/clause pairs.

    Loads training/test pairs and the RU/KZ constitution clauses, trains with
    in-batch negatives (MultipleNegativesRankingLoss), evaluates retrieval over
    the combined RU+KZ corpus during training, and saves the best checkpoint.
    """
    base_model = "paraphrase-multilingual-mpnet-base-v2"
    train_path = "data/legal_assistant_train.jsonl"
    test_path = "data/legal_assistant_test.jsonl"
    clauses_path = "data/clauses_constitution_ru_kz.jsonl"

    # Checkpoint directory for the fine-tuned model.
    out_dir = Path("artifacts/models/finetuned_mpnet")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Data: supervised pairs plus the full clause corpus (both languages).
    train_pairs = load_pairs(train_path)
    test_pairs = load_pairs(test_path)
    ru_clauses, kz_clauses = load_clauses(clauses_path)

    # Retrieval evaluator over the merged RU + KZ clause corpus.
    merged_corpus = build_corpus(ru_clauses) | build_corpus(kz_clauses)
    eval_queries, eval_qrels = build_queries_and_qrels(test_pairs)
    evaluator = InformationRetrievalEvaluator(
        eval_queries, merged_corpus, eval_qrels, name="overall"
    )

    # Model + training objective: other in-batch positives act as negatives.
    model = SentenceTransformer(base_model)
    examples = [
        InputExample(texts=[pair["query"], pair["positive"]])
        for pair in train_pairs
    ]
    loader = DataLoader(examples, shuffle=True, batch_size=32)
    loss = MultipleNegativesRankingLoss(model)

    model.fit(
        train_objectives=[(loader, loss)],
        epochs=2,
        warmup_steps=200,
        evaluator=evaluator,
        evaluation_steps=500,
        output_path=str(out_dir),
        save_best_model=True,
    )
if __name__ == "__main__":
main()