File size: 1,518 Bytes
ac03f85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172a9d5
ac03f85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#%%
import joblib
import numpy as np
from sentence_transformers.cross_encoder import CrossEncoder
from tqdm import tqdm

from request_solr import SilverDataset
from solr_query_params import params

############################################################################
#
# https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/data_augmentation/train_sts_indomain_bm25.py
# Step 2: Label BM25 sampled STSb (silver dataset) using cross-encoder model
#
############################################################################


# Model name/path handed to CrossEncoder to score the silver sentence pairs.
cross_encoder_path = 'anatel/cross-encoder-pt-anatel-metadados-assunto'
# On-disk cache of the silver dataset, so it is only built once.
silver_data_path = 'silver_data_v2.pkl'

# One stripped entry per line of the gold-sample file. The set is passed to
# SilverDataset as `duplicated` (presumably so pairs already present in the
# gold set are not sampled again -- confirm against request_solr).
with open('gold_sample_index.txt', 'r') as f:
    gold_sample_index = {line.strip() for line in f}

# Reuse the cached silver dataset when it exists; build and cache it otherwise.
# Only a missing or truncated cache file triggers a rebuild -- any other
# failure is a real error and should surface instead of being swallowed.
try:
    silver_data = joblib.load(silver_data_path)
except (FileNotFoundError, EOFError):
    print('Creating silver data...')
    silver_data = SilverDataset(query_params=params, duplicated=gold_sample_index).run()
    joblib.dump(silver_data, silver_data_path)
    print('Done!')

# Each silver record is a 3-item sequence; only the first two items (the
# sentence pair) are fed to the cross-encoder.
sentences = [(sent_1, sent_2) for sent_1, sent_2, _ in silver_data]

cross_encoder = CrossEncoder(cross_encoder_path, max_length=512)
# Score all pairs in a single call so the model can batch them internally,
# rather than running one forward pass per pair. An empty dataset is handled
# explicitly so the script still produces an (empty) output file.
if sentences:
    cross_silver_scores = np.asarray(
        cross_encoder.predict(sentences, show_progress_bar=True)
    )
else:
    cross_silver_scores = np.array([])

# All model predictions should be between [0,1]. Checked with an explicit
# raise (not `assert`, which is stripped under `python -O`) and before the
# scores are merged and written to disk.
if not np.all((cross_silver_scores >= 0.0) & (cross_silver_scores <= 1.0)):
    raise ValueError('Cross-encoder scores must all lie within [0, 1]')

# Append the predicted score as an extra column to every silver record.
cross_silver_data = np.c_[np.array(silver_data), cross_silver_scores]

joblib.dump(cross_silver_data, 'cross_silver_scores_2.pkl')