import joblib
import numpy as np
from sentence_transformers.cross_encoder import CrossEncoder
from tqdm import tqdm

from request_solr import SilverDataset
from solr_query_params import params


# Model id or local path handed to CrossEncoder to re-score the silver pairs.
cross_encoder_path = 'anatel/cross-encoder-pt-anatel-metadados-assunto'

# Indexes already present in the gold sample, one per line. They are passed to
# SilverDataset as ``duplicated`` (presumably so they are excluded from the
# silver data — TODO confirm in request_solr.SilverDataset).
# Explicit UTF-8 so the result does not depend on the platform default encoding.
with open('gold_sample_index.txt', 'r', encoding='utf-8') as f:
    # Skip blank lines so an empty string never ends up in the exclusion set.
    gold_sample_index = {line.strip() for line in f if line.strip()}

# Reuse the cached silver dataset when present; otherwise build it with
# SilverDataset (given the Solr query params and the gold indexes to skip)
# and cache it for the next run.
SILVER_DATA_PATH = 'silver_data_v2.pkl'
try:
    # Bind the loaded object: the rest of the script reads ``silver_data``,
    # and discarding it made every cache hit fail with NameError.
    silver_data = joblib.load(SILVER_DATA_PATH)
except FileNotFoundError:
    # Only a missing cache triggers a rebuild. A corrupt cache file now raises
    # instead of being silently overwritten, and Ctrl-C is no longer swallowed.
    print('Creating silver data...')
    silver_data = SilverDataset(query_params=params, duplicated=gold_sample_index).run()
    joblib.dump(silver_data, SILVER_DATA_PATH)
    print('Done!')

# Text pairs for the cross-encoder: the first two fields of each silver record.
# The third field (presumably the silver label/score) is not needed for scoring.
sentences = [
    (text_a, text_b)
    for text_a, text_b, _silver_label in silver_data
]

# Cross-encoder used to re-score every silver pair; inputs longer than
# max_length tokens are truncated.
cross_encoder = CrossEncoder(cross_encoder_path, max_length=512)

# Score all pairs in one batched predict() call instead of one forward pass
# per pair; predict() draws its own progress bar. An empty pair list is
# short-circuited because there is nothing to batch.
cross_silver_scores = (
    cross_encoder.predict(sentences, batch_size=32, show_progress_bar=True)
    if sentences
    else []
)

# Append each cross-encoder score as an extra column to its silver record.
# dtype=object keeps every field's original type. Without it, np.array()
# coerces the mixed (text, text, score) rows to a fixed-width unicode array,
# and np.c_ then promotes the cross-encoder scores to strings as well.
cross_silver_data = np.c_[
    np.array(silver_data, dtype=object),
    np.array(cross_silver_scores, dtype=object),
]

# Sanity check: every cross-encoder score should lie in [0, 1] (presumably a
# sigmoid output of a single-label model — confirm). Raise explicitly instead
# of asserting so the check still runs under `python -O`. NaN fails the
# comparison and is reported as well.
out_of_range_scores = [
    score for score in cross_silver_scores if not 0.0 <= score <= 1.0
]
if out_of_range_scores:
    raise ValueError(
        f'{len(out_of_range_scores)} cross-encoder score(s) outside [0, 1]; '
        f'first few: {out_of_range_scores[:5]}'
    )

# Persist the silver records together with their appended cross-encoder score column.
joblib.dump(
    value=cross_silver_data,
    filename='cross_silver_scores_2.pkl',
)