# Upload metadata (from file host, kept as comments so the module stays valid Python):
# LCA-PORVID's picture
# Upload 34 files
# ebdb5af verified
import os
from time import time
import json
from pt_variety_identifier.src.n_grams.data import Data
from pt_variety_identifier.src.n_grams.results import Results
from pt_variety_identifier.src.n_grams.trainer import Trainer
from pt_variety_identifier.src.n_grams.tester import Tester
from tqdm import tqdm
from pt_variety_identifier.src.utils import setup_logger, create_output_dir
from pt_variety_identifier.src.tunning import Tunning
import logging
from joblib import dump, load
from pt_variety_identifier.src.n_grams.model import EnsembleIdentfier, LanguageIdentifier
class Run:
    """Orchestrate tuning, training, and evaluation of n-gram PT-variety identifiers.

    All paths are anchored at this file's directory; every run writes its
    artifacts (logs, models) under ``out/<unix-timestamp>/``.
    """

    def __init__(self, dataset_name, test_set_list) -> None:
        """Set up paths, load params, prepare output dir/logger, data, and tuner.

        Args:
            dataset_name: name of the dataset to load (forwarded to ``Data``).
            test_set_list: test-set identifiers (forwarded to ``Data``).
        """
        self.CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
        # Unix timestamp (as a string) namespaces this run's output artifacts.
        self.CURRENT_TIME = str(int(time()))
        self.params = self.load_params()
        create_output_dir(self.CURRENT_PATH, self.CURRENT_TIME)
        setup_logger(self.CURRENT_PATH, self.CURRENT_TIME)
        self.data = Data(dataset_name, test_set_list)
        self._DOMAINS = ['literature', 'journalistic',
                         'legal', 'politics', 'web', 'social_media']
        # Enable progress bar for pandas (adds DataFrame.progress_apply etc.)
        tqdm.pandas()
        self.tuner = Tunning(self.data, self._DOMAINS, Results, Trainer, Tester, sample_size=5_000,
                             CURRENT_PATH=self.CURRENT_PATH, CURRENT_TIME=self.CURRENT_TIME, params=self.params)

    def load_params(self):
        """Load hyper-parameters from ``in/params.json``.

        JSON has no tuple type, so each entry of ``tfidf__ngram_range`` (if
        present) is cast from a list to a tuple, as expected by scikit-learn's
        TfidfVectorizer.

        Returns:
            dict: the parsed parameter dictionary.

        Raises:
            FileNotFoundError: if ``in/params.json`` does not exist.
        """
        path = os.path.join(self.CURRENT_PATH, "in", "params.json")
        # The original checked `f == None` after open(), which is dead code:
        # open() raises rather than returning None. Check existence explicitly
        # so the intended error message is actually reachable.
        if not os.path.isfile(path):
            raise FileNotFoundError("params.json not found")
        # `with` guarantees the handle is closed (the original leaked it).
        with open(path, "r", encoding="utf-8") as f:
            dict_obj = json.load(f)
        if 'tfidf__ngram_range' in dict_obj:
            # Cast tfidf__ngram_range entries to tuples
            dict_obj['tfidf__ngram_range'] = [
                tuple(elem) for elem in dict_obj['tfidf__ngram_range']]
        return dict_obj

    def tune(self):
        """Run the hyper-parameter search and return its result."""
        return self.tuner.run()

    def train(self):
        """Train on the 'all' domain with the best params and save the model.

        Reads ``in/best_params.json``, trains a pipeline per domain (currently
        only ``'all'``), evaluates it on the validation set, logs results, and
        dumps the fitted pipeline to
        ``out/<run>/models/<domain>_model.joblib``.
        """
        with open(os.path.join(self.CURRENT_PATH, "in", "best_params.json"), "r", encoding="utf-8") as f:
            best_params = json.load(f)
        for domain in ['all']:
            logging.info(f"Training {domain} domain")
            data = self.data.load_domain(
                domain, balance=True, pos_prob=None, ner_prob=None)
            validation_dataset_dict = self.data.load_validation_set()
            trainer = Trainer(
                train_dataset=data,
                params=best_params[domain]["tfidf"]
            )
            best_pipeline = trainer.train()
            tester = Tester(
                test_dataset_dict=validation_dataset_dict,
                pipeline=best_pipeline,
                train_domain=domain
            )
            results = tester.test()
            logging.info(f"Results for {domain} domain: {results}")
            logging.info(f"Save Model for {domain} domain")
            dump(best_pipeline, os.path.join(
                self.CURRENT_PATH, "out", self.CURRENT_TIME, "models", f"{domain}_model.joblib"))

    def test(self):
        """Evaluate the saved 'all'-domain model on the held-out test set.

        NOTE(review): CURRENT_TIME is set at construction, so this only finds a
        model trained by this same Run instance — confirm that is intended.
        """
        test_data = self.data.load_test_set(filter_label_2=True)
        pipeline = load(os.path.join(
            self.CURRENT_PATH, "out", self.CURRENT_TIME, "models", "all_model.joblib"))
        tester = Tester(test_data, pipeline, None)
        results = tester.test()
        logging.info(f"Results for test set: {results}")

    def test_ensemble(self):
        """Evaluate an ensemble built from all models saved in this run's dir."""
        test_data = self.data.load_test_set(filter_label_2=True)
        # CURRENT_TIME is already a str; the original's str() wrapper was redundant.
        ensemble = EnsembleIdentfier(os.path.join(
            self.CURRENT_PATH, "out", self.CURRENT_TIME, "models"))
        tester = Tester(test_data, ensemble, None)
        results = tester.test()
        logging.info(f"Results for ensemble: {results}")