ennioferreirab's picture
add model
172a9d5
#%%
import joblib
import numpy as np
from sentence_transformers.cross_encoder import CrossEncoder
from tqdm import tqdm

from request_solr import SilverDataset
from solr_query_params import params
############################################################################
#
# https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/data_augmentation/train_sts_indomain_bm25.py
# Step 2: Label BM25 sampled STSb (silver dataset) using cross-encoder model
#
############################################################################
# Label the silver sentence pairs with a cross-encoder and persist the result.
#
# Flow: read the gold sample index -> load (or build and cache) the silver
# dataset -> score every sentence pair with the cross-encoder -> validate the
# scores -> append them as an extra column and dump the array with joblib.
cross_encoder_path = 'anatel/cross-encoder-pt-anatel-metadados-assunto'

# One stripped line per entry. The set is handed to SilverDataset through its
# `duplicated` argument (presumably so gold samples are not drawn again --
# confirm against request_solr.SilverDataset).
with open('gold_sample_index.txt', 'r') as f:
    gold_sample_index = {line.strip() for line in f}

# Reuse the cached silver dataset when present; otherwise build and cache it.
try:
    # The loaded object must be bound to `silver_data`: everything below reads
    # it, and discarding the return value leaves the name undefined.
    silver_data = joblib.load('silver_data_v2.pkl')
except FileNotFoundError:
    # Only a missing cache triggers a rebuild; any other failure (corrupt
    # pickle, Ctrl-C, ...) propagates instead of being silently swallowed.
    print('Creating silver data...')
    silver_data = SilverDataset(query_params=params, duplicated=gold_sample_index).run()
    joblib.dump(silver_data, 'silver_data_v2.pkl')
    print('Done!')

# Every silver record unpacks into three fields; only the first two (the
# sentence pair) are fed to the cross-encoder.
sentences = [(sent_1, sent_2) for sent_1, sent_2, _ in silver_data]

cross_encoder = CrossEncoder(cross_encoder_path, max_length=512)

# Score all pairs in a single call so the model can batch them internally,
# instead of running one forward pass per pair. The empty case is handled
# separately because there is nothing to predict.
if sentences:
    cross_silver_scores = np.asarray(
        cross_encoder.predict(sentences, show_progress_bar=True)
    )
else:
    cross_silver_scores = np.array([])

# All model predictions should be between [0,1]. Checked with an explicit
# raise (not `assert`, which is stripped under `python -O`) and before
# anything is written to disk.
if not np.all((cross_silver_scores >= 0.0) & (cross_silver_scores <= 1.0)):
    raise ValueError('Cross-encoder produced scores outside the [0, 1] range')

# Append the scores as a trailing column next to the original silver records.
cross_silver_data = np.c_[np.array(silver_data), cross_silver_scores]
joblib.dump(cross_silver_data, 'cross_silver_scores_2.pkl')