| | import os |
| | from time import time |
| | import json |
| | from pt_variety_identifier.src.n_grams.data import Data |
| | from pt_variety_identifier.src.n_grams.results import Results |
| | from pt_variety_identifier.src.n_grams.trainer import Trainer |
| | from pt_variety_identifier.src.n_grams.tester import Tester |
| | from tqdm import tqdm |
| | from pt_variety_identifier.src.utils import setup_logger, create_output_dir |
| | from pt_variety_identifier.src.tunning import Tunning |
| | import logging |
| | from joblib import dump, load |
| | from pt_variety_identifier.src.n_grams.model import EnsembleIdentfier, LanguageIdentifier |
| |
|
| |
|
class Run:
    """Entry point for tuning, training and testing the n-gram identifier.

    An instance resolves its working directory from this file's location,
    loads the tuning grid from ``in/params.json``, prepares the per-run
    output directory and logger, and wires up the ``Data`` and ``Tunning``
    helpers used by the public methods.
    """

    def __init__(self, dataset_name, test_set_list) -> None:
        """Prepare paths, parameters, logging, data access and the tuner.

        Args:
            dataset_name: Forwarded unchanged to ``Data``.
            test_set_list: Forwarded unchanged to ``Data``.

        Raises:
            FileNotFoundError: If ``in/params.json`` does not exist.
        """
        self.CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
        # Integer epoch seconds, as text; names this run's folder under "out".
        self.CURRENT_TIME = str(int(time()))
        self.params = self.load_params()

        create_output_dir(self.CURRENT_PATH, self.CURRENT_TIME)
        setup_logger(self.CURRENT_PATH, self.CURRENT_TIME)

        self.data = Data(dataset_name, test_set_list)

        self._DOMAINS = ['literature', 'journalistic',
                         'legal', 'politics', 'web', 'social_media']

        # Enable tqdm's progress-bar integration for pandas operations.
        tqdm.pandas()

        self.tuner = Tunning(self.data, self._DOMAINS, Results, Trainer, Tester,
                             sample_size=5_000,
                             CURRENT_PATH=self.CURRENT_PATH,
                             CURRENT_TIME=self.CURRENT_TIME,
                             params=self.params)

    def _models_dir(self):
        """Return the directory where this run's serialized models live."""
        return os.path.join(
            self.CURRENT_PATH, "out", str(self.CURRENT_TIME), "models")

    def load_params(self):
        """Load the parameter grid from ``in/params.json``.

        JSON has no tuple type, so every entry of ``tfidf__ngram_range``
        (when that key is present) is converted from a list to a tuple.

        Returns:
            The parsed JSON object with n-gram ranges converted to tuples.

        Raises:
            FileNotFoundError: Raised by ``open`` when the file is missing.
        """
        params_path = os.path.join(self.CURRENT_PATH, "in", "params.json")

        # The context manager guarantees the handle is closed, even when
        # parsing fails.
        with open(params_path, "r", encoding="utf-8") as f:
            dict_obj = json.load(f)

        if 'tfidf__ngram_range' in dict_obj:
            dict_obj['tfidf__ngram_range'] = [
                tuple(elem) for elem in dict_obj['tfidf__ngram_range']
            ]

        return dict_obj

    def tune(self):
        """Run the hyper-parameter search and return whatever the tuner returns."""
        return self.tuner.run()

    def train(self):
        """Train with the parameters in ``in/best_params.json`` and save the model.

        For each trained domain (currently only ``'all'``) the pipeline is
        fitted on the balanced domain data, evaluated on the validation set,
        logged, and dumped to ``<models dir>/<domain>_model.joblib``.

        Raises:
            FileNotFoundError: If ``in/best_params.json`` does not exist.
            KeyError: If the file lacks a ``"tfidf"`` entry for the domain.
        """
        best_params_path = os.path.join(
            self.CURRENT_PATH, "in", "best_params.json")
        with open(best_params_path, "r", encoding="utf-8") as f:
            best_params = json.load(f)

        for domain in ['all']:
            logging.info(f"Training {domain} domain")

            data = self.data.load_domain(
                domain, balance=True, pos_prob=None, ner_prob=None)

            # Every validation domain is kept; a disabled earlier variant
            # narrowed the validation set to the training domain only.
            validation_dataset_dict = self.data.load_validation_set()

            trainer = Trainer(
                train_dataset=data,
                params=best_params[domain]["tfidf"]
            )

            best_pipeline = trainer.train()

            tester = Tester(
                test_dataset_dict=validation_dataset_dict,
                pipeline=best_pipeline,
                train_domain=domain
            )

            results = tester.test()

            logging.info(f"Results for {domain} domain: {results}")

            logging.info(f"Save Model for {domain} domain")

            # Make sure the target directory exists so the dump cannot fail
            # on a missing "models" folder.
            models_dir = self._models_dir()
            os.makedirs(models_dir, exist_ok=True)
            dump(best_pipeline,
                 os.path.join(models_dir, f"{domain}_model.joblib"))

    def test(self):
        """Evaluate this run's saved ``all_model.joblib`` on the test set.

        The model is read from the current run's models directory, so
        ``train`` must have been executed on the same instance beforehand.
        """
        test_data = self.data.load_test_set(filter_label_2=True)

        # NOTE: joblib.load unpickles arbitrary objects; only load files
        # produced by this project.
        pipeline = load(os.path.join(self._models_dir(), "all_model.joblib"))

        tester = Tester(test_data, pipeline, None)

        results = tester.test()

        logging.info(f"Results for test set: {results}")

    def test_ensemble(self):
        """Evaluate an ``EnsembleIdentfier`` built from this run's models directory."""
        test_data = self.data.load_test_set(filter_label_2=True)

        ensemble = EnsembleIdentfier(self._models_dir())

        tester = Tester(test_data, ensemble, None)

        results = tester.test()

        logging.info(f"Results for ensemble: {results}")
| |
|