import json
import logging
import os
from time import time

from joblib import dump, load
from tqdm import tqdm

from pt_variety_identifier.src.n_grams.data import Data
from pt_variety_identifier.src.n_grams.model import EnsembleIdentfier, LanguageIdentifier
from pt_variety_identifier.src.n_grams.results import Results
from pt_variety_identifier.src.n_grams.tester import Tester
from pt_variety_identifier.src.n_grams.trainer import Trainer
from pt_variety_identifier.src.tunning import Tunning
from pt_variety_identifier.src.utils import create_output_dir, setup_logger


class Run:
    """Orchestrate the n-gram variety-identification workflow.

    Wires together data loading, hyper-parameter tuning, training,
    testing, and ensemble testing. Output artifacts (logs, models) go
    under ``out/<CURRENT_TIME>/`` relative to this file's directory.
    """

    def __init__(self, dataset_name, test_set_list) -> None:
        """Set up paths, logging, data access, and the tuner.

        Args:
            dataset_name: Name of the dataset passed through to ``Data``.
            test_set_list: Test-set identifiers passed through to ``Data``.
        """
        self.CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
        # Unix timestamp used as a unique run id for the output directory.
        self.CURRENT_TIME = str(int(time()))
        self.params = self.load_params()
        create_output_dir(self.CURRENT_PATH, self.CURRENT_TIME)
        setup_logger(self.CURRENT_PATH, self.CURRENT_TIME)
        self.data = Data(dataset_name, test_set_list)
        self._DOMAINS = ['literature', 'journalistic', 'legal',
                         'politics', 'web', 'social_media']
        # Enable progress bar for pandas (adds .progress_apply etc.).
        tqdm.pandas()
        self.tuner = Tunning(self.data, self._DOMAINS, Results, Trainer, Tester,
                             sample_size=5_000,
                             CURRENT_PATH=self.CURRENT_PATH,
                             CURRENT_TIME=self.CURRENT_TIME,
                             params=self.params)

    def load_params(self):
        """Load hyper-parameters from ``in/params.json``.

        Returns:
            dict: Parsed parameters, with ``tfidf__ngram_range`` entries
            cast from JSON lists to tuples (scikit-learn expects tuples).

        Raises:
            FileNotFoundError: If ``in/params.json`` does not exist
            (raised directly by ``open``).
        """
        params_path = os.path.join(self.CURRENT_PATH, "in", "params.json")
        # 'with' guarantees the handle is closed (the original leaked it);
        # open() itself raises FileNotFoundError when the file is absent,
        # so no explicit existence check is needed.
        with open(params_path, "r", encoding="utf-8") as f:
            params = json.load(f)
        if 'tfidf__ngram_range' in params:
            # JSON has no tuple type; cast each [lo, hi] list to a tuple.
            params['tfidf__ngram_range'] = [
                tuple(elem) for elem in params['tfidf__ngram_range']
            ]
        return params

    def tune(self):
        """Run hyper-parameter tuning and return the tuner's result."""
        return self.tuner.run()

    def train(self):
        """Train a model per domain using ``in/best_params.json``.

        For each domain (currently only 'all'): load balanced training
        data, fit a pipeline, evaluate it on the validation set, log the
        results, and persist the pipeline under
        ``out/<CURRENT_TIME>/models/<domain>_model.joblib``.
        """
        with open(os.path.join(self.CURRENT_PATH, "in", "best_params.json"),
                  "r", encoding="utf-8") as f:
            best_params = json.load(f)

        for domain in ['all']:
            logging.info(f"Training {domain} domain")
            data = self.data.load_domain(
                domain, balance=True, pos_prob=None, ner_prob=None)
            validation_dataset_dict = self.data.load_validation_set()
            trainer = Trainer(
                train_dataset=data,
                params=best_params[domain]["tfidf"]
            )
            best_pipeline = trainer.train()
            tester = Tester(
                test_dataset_dict=validation_dataset_dict,
                pipeline=best_pipeline,
                train_domain=domain
            )
            results = tester.test()
            logging.info(f"Results for {domain} domain: {results}")
            logging.info(f"Save Model for {domain} domain")
            dump(best_pipeline, os.path.join(
                self.CURRENT_PATH, "out", self.CURRENT_TIME,
                "models", f"{domain}_model.joblib"))

    def test(self):
        """Evaluate the saved 'all' model on the filtered test set."""
        test_data = self.data.load_test_set(filter_label_2=True)
        pipeline = load(os.path.join(
            self.CURRENT_PATH, "out", self.CURRENT_TIME,
            "models", "all_model.joblib"))
        tester = Tester(test_data, pipeline, None)
        results = tester.test()
        logging.info(f"Results for test set: {results}")

    def test_ensemble(self):
        """Evaluate an ensemble built from this run's saved models."""
        test_data = self.data.load_test_set(filter_label_2=True)
        # CURRENT_TIME is already a str (set in __init__), so no cast
        # is needed — kept consistent with test().
        ensemble = EnsembleIdentfier(os.path.join(
            self.CURRENT_PATH, "out", self.CURRENT_TIME, "models"))
        tester = Tester(test_data, ensemble, None)
        results = tester.test()
        logging.info(f"Results for ensemble: {results}")