File size: 2,034 Bytes
ebdb5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import logging
import os


class Tunning:
    """Grid search over POS/NER probabilities, run once per training domain.

    For every (pos_prob, ner_prob, domain) combination the class loads the
    domain dataset, trains a model through ``Trainer``, validates the best
    model on the shared test set through ``Tester`` and hands the outcome to
    the ``Results`` collector.
    """

    # Spacing between consecutive probability values on the search grid.
    _PROB_STEP = 0.1

    def __init__(self, data, domains, Results, Trainer, Tester, sample_size, CURRENT_PATH, CURRENT_TIME, params=None) -> None:
        """Store the collaborators and create the results collector.

        Args:
            data: Object exposing ``load_test_set()`` and ``load_domain(...)``.
            domains: Iterable of domain identifiers to train on.
            Results: Callable invoked as ``Results(output_dir, domains)``; the
                returned object must expose ``process(...)``.
            Trainer: Callable invoked as ``Trainer(dataset, params)``; the
                returned object must expose ``train()``.
            Tester: Callable invoked as
                ``Tester(test_dataset, model, train_domain=...)``; the returned
                object must expose ``validate()``.
            sample_size: Forwarded to ``data.load_domain`` as ``sample_size``.
            CURRENT_PATH: Base directory; output goes under
                ``<CURRENT_PATH>/out/<CURRENT_TIME>``.
            CURRENT_TIME: Run identifier, stringified for the output path.
            params: Optional value forwarded unchanged to ``Trainer``.
        """
        self.data = data
        self.Trainer = Trainer
        self.Tester = Tester
        self._DOMAINS = domains
        self.sample_size = sample_size
        self.CURRENT_PATH = CURRENT_PATH
        self.CURRENT_TIME = CURRENT_TIME

        self.results = Results(os.path.join(
            self.CURRENT_PATH, "out", str(CURRENT_TIME)), self._DOMAINS)

        self.params = params

    @staticmethod
    def _prob_grid(start, stop, step=0.1):
        """Return ``start, start + step, ...`` up to and including ``stop``.

        ``np.arange`` with a float step derives its length from
        ``ceil((stop - start) / step)``, so rounding error can add an extra
        element past ``stop`` and the values themselves pick up noise such as
        ``0.30000000000000004``. Here the number of steps is computed as an
        integer and every value is rounded, so the grid never exceeds
        ``stop`` and contains clean decimals.

        Args:
            start: First value of the grid.
            stop: Inclusive upper bound of the grid.
            step: Positive spacing between values.

        Returns:
            A list of Python floats; empty when ``stop`` is below ``start``.

        Raises:
            ValueError: If ``step`` is not positive.
        """
        if step <= 0:
            raise ValueError("step must be positive")

        # The small tolerance keeps an endpoint that is mathematically on the
        # grid from being dropped by rounding error (e.g. (0.7 - 0.3) / 0.1).
        n_steps = int(np.floor((stop - start) / step + 1e-9))
        if n_steps < 0:
            return []

        return [round(start + i * step, 10) for i in range(n_steps + 1)]

    def run(self, start_pos_prob=0.0, stop_pos_prob=1.0, start_ner_prob=0.0, stop_ner_prob=1.0):
        """Train and validate on every grid point for every domain.

        Both probability ranges are inclusive and are walked in steps of 0.1.
        A stop value that is not on the 0.1 grid is never exceeded.

        Args:
            start_pos_prob: First ``pos_prob`` value.
            stop_pos_prob: Inclusive upper bound for ``pos_prob``.
            start_ner_prob: First ``ner_prob`` value.
            stop_ner_prob: Inclusive upper bound for ``ner_prob``.
        """

        logging.info(f"Start pos_prob={start_pos_prob}, stop_pos_prob={stop_pos_prob}")

        # The test set does not depend on the grid point, so load it once.
        test_dataset = self.data.load_test_set()

        pos_grid = self._prob_grid(start_pos_prob, stop_pos_prob, self._PROB_STEP)
        # The NER grid is identical for every pos_prob; build it a single time.
        ner_grid = self._prob_grid(start_ner_prob, stop_ner_prob, self._PROB_STEP)

        for pos_prob in pos_grid:
            for ner_prob in ner_grid:
                for domain in self._DOMAINS:
                    logging.info(
                        f"Running {domain} pos_prob={pos_prob}, ner_prob={ner_prob}")

                    dataset = self.data.load_domain(
                        domain, balance=True, pos_prob=pos_prob, ner_prob=ner_prob, sample_size=self.sample_size)

                    trainer = self.Trainer(dataset, self.params)

                    # train() yields the training results and the best model.
                    results, best_model = trainer.train()

                    validation_results = self.Tester(
                        test_dataset, best_model, train_domain=domain).validate()

                    logging.info(
                        f"Cross domain f1 score: {validation_results['f1']} | test_results: {validation_results}")

                    self.results.process(validation_results['f1'], domain, validation_results,
                                         results, balance=True, pos_prob=pos_prob, ner_prob=ner_prob)