File size: 1,364 Bytes
ebdb5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from pt_variety_identifier.src.results import Results as BaseResults
import logging

class Results(BaseResults):
    """Results tracker that also records the balance, POS and NER probabilities.

    Storage and persistence are delegated to ``BaseResults``; this subclass
    only decides which fields are written for each evaluated configuration.
    """

    def __init__(self, filepath, DOMAINS) -> None:
        """Forward ``filepath`` and ``DOMAINS`` unchanged to the base class."""
        super().__init__(filepath, DOMAINS)

    def process(self, cross_domain_f1, train_domain, test_results, train_results, balance, pos_prob, ner_prob):
        """Record one evaluation run and update the best score for its domain.

        When ``cross_domain_f1`` is strictly greater than the value stored in
        ``self.best_f1_scores[train_domain]["cross_domain_f1"]``, the stored
        entry is overwritten with this run's score, test results and
        hyper-parameters, and ``self.best_final_results()`` is called.
        Regardless of the comparison, the run is always passed to
        ``self.best_intermediate_results``.

        Args:
            cross_domain_f1: Score compared against the stored best.
            train_domain: Key into ``self.best_f1_scores``.
            test_results: Test metrics for this run.
            train_results: Training metrics for this run. Only included in
                the intermediate record, not in the best-score entry.
            balance: Value stored under ``"balance"``.
            pos_prob: Value stored under ``"pos_prob"``.
            ner_prob: Value stored under ``"ner_prob"``.
        """
        previous_best = self.best_f1_scores[train_domain]["cross_domain_f1"]

        if cross_domain_f1 > previous_best:
            logging.info(f"New best f1 score for {train_domain}")

            # Written key by key, in this order, through the same indexing
            # path so the container sees exactly the same item assignments.
            new_best_fields = {
                "cross_domain_f1": cross_domain_f1,
                "test_results": test_results,
                "balance": balance,
                "pos_prob": pos_prob,
                "ner_prob": ner_prob,
            }
            for field_name, field_value in new_best_fields.items():
                self.best_f1_scores[train_domain][field_name] = field_value

            logging.info("Saving best cross_domain_f1 scores to file")

            # Presumably persists self.best_f1_scores; defined outside this class.
            self.best_final_results()

            #TODO: Save PyTorch model

        test_summary = {
            'all': test_results,
            'cross_domain_f1': cross_domain_f1,
        }
        run_record = {
            "domain": train_domain,
            "balance": balance,
            "pos_prob": pos_prob,
            "ner_prob": ner_prob,
            "train": train_results,
            "test": test_summary,
        }
        self.best_intermediate_results(run_record)