File size: 1,364 Bytes
ebdb5af | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 | from pt_variety_identifier.src.results import Results as BaseResults
import logging
class Results(BaseResults):
    """Track, per training domain, the run with the best cross-domain F1.

    Extends ``BaseResults`` (``pt_variety_identifier.src.results``). For each
    training domain it remembers the highest ``cross_domain_f1`` seen so far,
    together with the test results and the ``balance`` / ``pos_prob`` /
    ``ner_prob`` settings that produced it. Every processed run is also passed
    to ``best_intermediate_results``.
    """

    def __init__(self, filepath, DOMAINS) -> None:
        """Initialise the tracker; all setup is delegated to ``BaseResults``.

        Args:
            filepath: Forwarded unchanged to ``BaseResults.__init__``.
            DOMAINS: Forwarded unchanged to ``BaseResults.__init__``.
        """
        super().__init__(filepath, DOMAINS)

    def process(self, cross_domain_f1, train_domain, test_results, train_results,
                balance, pos_prob, ner_prob) -> None:
        """Record one evaluation run and update the domain's best if improved.

        If ``cross_domain_f1`` is strictly greater than the best stored for
        ``train_domain``, the stored best entry is overwritten and
        ``self.best_final_results()`` is called. Every run, improved or not, is
        then passed to ``self.best_intermediate_results``.

        Assumes ``self.best_f1_scores[train_domain]`` already exists and holds a
        ``"cross_domain_f1"`` entry. It is set up by ``BaseResults``, which is
        not visible here, so confirm this against that class.

        Args:
            cross_domain_f1: Score compared against the stored best for
                ``train_domain``. Presumably a float F1 value; confirm with
                callers.
            train_domain: Key into ``self.best_f1_scores``.
            test_results: Kept as the best ``"test_results"`` on improvement.
                Always reported under ``"test"["all"]``.
            train_results: Reported under ``"train"`` only. It is not kept in
                the best entry.
            balance: Run setting stored with the best score and always reported.
            pos_prob: Run setting stored with the best score and always reported.
            ner_prob: Run setting stored with the best score and always reported.
        """
        best = self.best_f1_scores[train_domain]

        # Strict ">" means a tie keeps the earlier run as the best.
        if cross_domain_f1 > best["cross_domain_f1"]:
            logging.info("New best f1 score for %s", train_domain)
            best.update({
                "cross_domain_f1": cross_domain_f1,
                "test_results": test_results,
                "balance": balance,
                "pos_prob": pos_prob,
                "ner_prob": ner_prob,
            })
            logging.info("Saving best cross_domain_f1 scores to file")
            self.best_final_results()
            # TODO: Save PyTorch model

        # Reported for every run, not only for improvements.
        self.best_intermediate_results({
            "domain": train_domain,
            "balance": balance,
            "pos_prob": pos_prob,
            "ner_prob": ner_prob,
            "train": train_results,
            "test": {
                'all': test_results,
                'cross_domain_f1': cross_domain_f1,
            },
        })