# Standard library
import logging
import os
import time
from threading import Thread

# Third-party
import numpy as np
import torch
import torch.multiprocessing as mp
from tqdm import tqdm

# Project-local
from pt_variety_identifier.src.bert.data import Data
from pt_variety_identifier.src.bert.model import EnsembleIdentfier, LanguageIdentfier
from pt_variety_identifier.src.bert.results import Results
from pt_variety_identifier.src.bert.tester import Tester
from pt_variety_identifier.src.bert.trainer import Trainer
from pt_variety_identifier.src.tunning import Tunning
from pt_variety_identifier.src.utils import setup_logger, create_output_dir


class Run:
    """Orchestrate tuning, training and evaluation of language-variety identifiers.

    Work is fanned out across the machine's CUDA devices: a semaphore sized to
    the GPU count gates concurrency, and ``gpus_free`` tracks which device
    indices are currently available.
    """

    def __init__(self, dataset_name, tokenizer_name, model_name, batch_size,
                 test_set_list) -> None:
        """Set up output dirs, logging, the dataset wrapper and GPU bookkeeping.

        Args:
            dataset_name: Name of the dataset passed through to ``Data``.
            tokenizer_name: Tokenizer identifier passed through to ``Data``.
            model_name: Backbone model identifier used by trainer/tester.
            batch_size: Batch size passed through to ``Data``.
            test_set_list: Test-set names passed through to ``Data``.
        """
        self.CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
        # Single timestamp identifies this run's output directory and logs.
        self.CURRENT_TIME = int(time.time())
        self.num_gpus = torch.cuda.device_count()
        # One semaphore slot per GPU; acquired before dispatching work.
        self.sem = mp.Semaphore(self.num_gpus)
        # Pool of free CUDA device indices (ints); popped on dispatch,
        # re-appended when a worker finishes.
        self.gpus_free = list(range(self.num_gpus))
        self.test_set_list = test_set_list

        create_output_dir(self.CURRENT_PATH, self.CURRENT_TIME)
        setup_logger(self.CURRENT_PATH, self.CURRENT_TIME)

        self.data = Data(dataset_name,
                         tokenizer_name=tokenizer_name,
                         batch_size=batch_size,
                         test_set_list=test_set_list)
        self._DOMAINS = ['literature', 'legal', 'politics', 'web',
                         'social_media', 'journalistic']
        self.model_name = model_name
        tqdm.pandas()

    def tune_with_gpu(self):
        """Grid-search pos/ner augmentation probabilities, one tuning run per free GPU."""
        threads = []
        # BUG FIX: the original wrapped np.arange(...) in range(), which raises
        # TypeError (range() requires an int); iterate the float grid directly.
        for pos_prob in tqdm(np.arange(0.0, 1.0, 0.1)):
            for ner_prob in tqdm(np.arange(0.0, 1.0, 0.2)):
                # BUG FIX: round into fresh locals instead of rebinding the
                # outer loop variable, which corrupted the outer grid value.
                pos_p = round(float(pos_prob), 2)
                ner_p = round(float(ner_prob), 2)
                self.sem.acquire()  # block until a GPU slot is free
                gpu_in_use = self.gpus_free.pop()
                tuner = Tunning(self.data, self._DOMAINS, Results, Trainer,
                                Tester, 5_000, self.CURRENT_PATH,
                                self.CURRENT_TIME,
                                params={
                                    'epochs': 30,
                                    'early_stoping': 5,
                                    'model_name': self.model_name,
                                    'device': f"cuda:{gpu_in_use}",
                                    # Tuner is expected to release the slot and
                                    # return the GPU index when done.
                                    'sem': self.sem,
                                    'gpus_free': self.gpus_free,
                                })
                thread = Thread(target=tuner.run,
                                args=(pos_p, pos_p, ner_p, ner_p),
                                daemon=True)
                threads.append(thread)
                # BUG FIX: the original never called start(), so no tuning ran
                # and join() on an unstarted thread raises RuntimeError.
                thread.start()
        for t in threads:
            t.join()

    def tune_with_cpu(self):
        """Run a single tuning pass on the CPU (no parallelism)."""
        tuner = Tunning(self.data, self._DOMAINS, Results, Trainer, Tester,
                        5_000, self.CURRENT_PATH, self.CURRENT_TIME,
                        params={
                            'epochs': 30,
                            'early_stoping': 5,
                            'model_name': self.model_name,
                            'device': 'cpu',
                        })
        tuner.run()

    def tune(self):
        """Dispatch tuning to GPU workers when CUDA is available, else CPU."""
        if torch.cuda.is_available():
            return self.tune_with_gpu()
        return self.tune_with_cpu()

    def _train_domain(self, domain, gpu):
        """Train one domain's model on the given device, then free the GPU slot.

        Args:
            domain: Domain name (or ``'all'``) to load training data for.
            gpu: Device string of the form ``"cuda:N"``.
        """
        logging.info(f"Training {domain} domain")
        data = self.data.load_domain(domain, balance=True,
                                     pos_prob=None, ner_prob=None)
        validation_dataset_dict = self.data.load_validation_set()
        # NOTE(review): previously-disabled code restricted validation to the
        # training domain only; kept here for reference:
        #   validation_dataset_dict = {domain: validation_dataset_dict[domain]}
        trainer = Trainer(data,
                          params={
                              'epochs': 30,
                              'early_stoping': 5,
                              'model_name': self.model_name,
                              'device': gpu,
                              'CURRENT_PATH': self.CURRENT_PATH,
                              'CURRENT_TIME': self.CURRENT_TIME,
                              'training_domain': domain,
                          },
                          validation_dataset_dict=validation_dataset_dict)
        best_results = trainer.train()
        logging.info(f"Best results for {domain} domain: {best_results}")
        # BUG FIX: the original used gpu[-1] (last character only), which is
        # wrong for device ids >= 10 and pushed a str into the int pool.
        gpu_index = int(gpu.split(":", 1)[1])
        logging.info(f"Freeing cuda:{gpu_index}")
        self.gpus_free.append(gpu_index)
        return self.sem.release()

    def train(self):
        """Train the 'all'-domains model, one background thread per GPU slot."""
        threads = []
        for domain in ['all']:
            self.sem.acquire()  # wait for a free GPU
            gpu_in_use = self.gpus_free.pop()
            thread = Thread(target=self._train_domain,
                            args=(domain, f"cuda:{gpu_in_use}"),
                            daemon=True)
            threads.append(thread)
            thread.start()
        for t in threads:
            t.join()

    def test(self):
        """Evaluate the saved 'all' checkpoint on the filtered test set."""
        model = LanguageIdentfier(self.model_name)
        checkpoint = os.path.join(self.CURRENT_PATH, "out",
                                  str(self.CURRENT_TIME), "models", "all.pt")
        logging.info(f"Loading model from {checkpoint}")
        model.load_state_dict(torch.load(checkpoint))
        model.eval()
        model.to('cuda')
        data = self.data.load_test_set(filter_label_2=True)
        tester = Tester(data, model, None)
        results = tester.validate()
        logging.info(f"Results for all: {results}")

    def test_ensemble(self):
        """Evaluate an ensemble built from every checkpoint in the models dir."""
        data = self.data.load_test_set(filter_label_2=True)
        ensemble = EnsembleIdentfier(
            os.path.join(self.CURRENT_PATH, "out",
                         str(self.CURRENT_TIME), "models"),
            self.model_name)
        tester = Tester(data, ensemble, None)
        results = tester.test()
        logging.info(f"Results for ensemble: {results}")