'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-19 18:50:11
Description: Manages the whole process of model training and prediction.

'''
import math
import os
import queue
import random
from typing import Any, Dict, Optional, Tuple

import dill
import numpy as np
import torch
from tqdm import tqdm

from common import utils
from common.loader import DataFactory
from common.logger import Logger
from common.metric import Evaluator
from common.saver import Saver
from common.tokenizer import get_tokenizer, get_tokenizer_class, load_embedding
from common.utils import InputData, instantiate
from common.utils import OutputData
from common.config import Config
from common import global_pool
from tools.load_from_hugging_face import PreTrainedTokenizerForSLU, PretrainedModelForSLU


class ModelManager(object):
    """Drive model construction, data loading, training, evaluation and prediction.

    Typical usage is ``ModelManager(config)`` followed by ``init_model()`` and
    then ``train()``, ``test()`` or ``predict(text)``.
    """

    def __init__(self, config: Config):
        """create model manager by config

        Args:
            config (Config): configuration to manage all process in OpenSLU
        """
        global_pool._init()
        self.config = config
        self.__set_seed(self.config.base.get("seed"))
        self.device = self.config.base.get("device")
        self.load_dir = self.config.model_manager.get("load_dir")
        if self.config.get("logger") and self.config["logger"].get("logger_type"):
            logger_type = self.config["logger"].get("logger_type")
        else:
            logger_type = "wandb"

        if "accelerator" in self.config and self.config["accelerator"].get("use_accelerator"):
            # Imported lazily so that `accelerate` is only required when enabled.
            from accelerate import Accelerator
            self.accelerator = Accelerator(log_with=logger_type)
        else:
            self.accelerator = None
        self.tokenizer = None
        # Label maps start empty; load(), from_pretrained() or init_data()
        # fill them in. Defining them here keeps evaluation from hitting an
        # AttributeError when no step populated them.
        self.intent_list = None
        self.intent_dict = None
        self.slot_list = None
        self.slot_dict = None
        self.saver = Saver(self.config.model_manager, start_time=self.config.start_time)
        if self.config.base.get("train"):
            self.model = None
            self.optimizer = None
            self.total_step = None
            self.lr_scheduler = None
            self.init_step = 0
            self.best_metric = 0
        # The logger is needed by evaluation as well as training, so it is
        # created regardless of the "train" flag.
        self.logger = Logger(logger_type=logger_type,
                             logger_name=self.config.base["name"],
                             start_time=self.config.start_time,
                             accelerator=self.accelerator)
        global_pool.set_value("logger", self.logger)

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #
    def __build_optimizer(self):
        """Instantiate the configured optimizer over the current model parameters."""
        return instantiate(self.config["optimizer"])(self.model.parameters())

    def __build_lr_scheduler(self):
        """Instantiate the configured scheduler for the current optimizer."""
        return instantiate(self.config["scheduler"])(
            optimizer=self.optimizer,
            num_training_steps=self.total_step
        )

    def __unwrap(self, model):
        """Return ``model.module`` when an accelerator wrapped the model, else ``model``."""
        if self.accelerator is not None and hasattr(model, "module"):
            return model.module
        return model

    def __tokenizer_kwargs(self) -> Dict[str, Any]:
        """Collect tokenizer options whose key neither starts nor ends with ``_``."""
        return {key: self.config.tokenizer[key]
                for key in self.config.tokenizer if key[0] != "_" and key[-1] != "_"}

    def __make_data_loader(self, dataset, batch_size, shuffle, label2tensor, tokenizer_config):
        """Build a labelled data loader for ``dataset`` with the shared settings."""
        return self.data_factory.get_data_loader(dataset,
                                                 batch_size,
                                                 shuffle=shuffle,
                                                 device=self.device,
                                                 enable_label=True,
                                                 align_mode=self.config.tokenizer.get("_align_mode_"),
                                                 label2tensor=label2tensor,
                                                 **tokenizer_config)

    # ------------------------------------------------------------------ #
    # initialisation
    # ------------------------------------------------------------------ #
    def init_model(self):
        """init model, tokenizer, data loaders, optimizer and lr_scheduler

        Three sources are tried in order: a local checkpoint (``load_dir``),
        a pretrained model/tokenizer pair (``_from_pretrained_``), and finally
        a fresh model instantiated from the configuration.
        """
        self.prepared = False
        if self.load_dir is not None:
            self.load()
            self.config.set_vocab_size(self.tokenizer.vocab_size)
            self.init_data()
            if self.config.base.get("train") and self.config.model_manager.get("load_train_state"):
                # SECURITY: train_state.pkl is unpickled with dill; only load
                # checkpoints from a trusted source.
                train_state = torch.load(os.path.join(
                    self.load_dir, "train_state.pkl"), pickle_module=dill)
                self.optimizer = self.__build_optimizer()
                self.lr_scheduler = self.__build_lr_scheduler()
                self.optimizer.load_state_dict(train_state["optimizer"])
                self.optimizer.zero_grad()
                self.lr_scheduler.load_state_dict(train_state["lr_scheduler"])
                self.init_step = train_state["step"]
                self.best_metric = train_state["best_metric"]
        elif self.config.model.get("_from_pretrained_") and self.config.tokenizer.get("_from_pretrained_"):
            self.from_pretrained()
            self.config.set_vocab_size(self.tokenizer.vocab_size)
            self.init_data()
        else:
            self.tokenizer = get_tokenizer(
                self.config.tokenizer.get("_tokenizer_name_"))
            # init_data() must run first: it fills label counts / vocabulary
            # into the config that the model is instantiated from.
            self.init_data()
            self.model = instantiate(self.config.model)
            self.model.to(self.device)
            if self.config.base.get("train"):
                self.optimizer = self.__build_optimizer()
                self.lr_scheduler = self.__build_lr_scheduler()

    def init_data(self):
        """Create the data factory and the train/dev/test data loaders.

        In training mode this also updates labels and vocabulary from the
        datasets, writes label counts into the config, optionally loads a
        pretrained embedding matrix, and saves tokenizer and labels.
        """
        self.data_factory = DataFactory(tokenizer=self.tokenizer,
                                        use_multi_intent=self.config.base.get("multi_intent"),
                                        to_lower_case=self.config.tokenizer.get("_to_lower_case_"))
        batch_size = self.config.base["batch_size"]
        tokenizer_config = self.__tokenizer_kwargs()

        if self.config.base.get("train"):
            train_dataset = self.data_factory.load_dataset(self.config.dataset, split="train")

            # Labels and vocabulary are derived from the training split
            # before any loader is built.
            self.data_factory.update_label_names(train_dataset)
            self.data_factory.update_vocabulary(train_dataset)

            self.train_dataloader = self.__make_data_loader(
                train_dataset, batch_size, shuffle=True, label2tensor=True,
                tokenizer_config=tokenizer_config)
            self.total_step = int(self.config.base.get("epoch_num")) * len(self.train_dataloader)

            dev_dataset = self.data_factory.load_dataset(self.config.dataset, split="validation")
            self.dev_dataloader = self.__make_data_loader(
                dev_dataset, batch_size, shuffle=False, label2tensor=False,
                tokenizer_config=tokenizer_config)
            self.data_factory.update_vocabulary(dev_dataset)

            # Label maps restored by load()/from_pretrained() are kept as they
            # are: resetting them here would leave them None, because the
            # label counts in the config are already non-zero in that case.
            decoder_config = self.config.model["decoder"]
            if decoder_config.get("intent_classifier") and int(self.config.get_intent_label_num()) == 0:
                self.intent_list = self.data_factory.intent_label_list
                self.intent_dict = self.data_factory.intent_label_dict
                self.config.set_intent_label_num(len(self.intent_list))
            if decoder_config.get("slot_classifier") and int(self.config.get_slot_label_num()) == 0:
                self.slot_list = self.data_factory.slot_label_list
                self.slot_dict = self.data_factory.slot_label_dict
                self.config.set_slot_label_num(len(self.slot_list))

            # Optionally load a pretrained embedding matrix for the encoder.
            embedding_config = self.config["model"]["encoder"].get("embedding")
            if embedding_config and embedding_config.get("load_embedding_name"):
                self.config["model"]["encoder"]["embedding"]["embedding_matrix"] = load_embedding(
                    self.tokenizer, embedding_config.get("load_embedding_name"))

            self.config.autoload_template()

            self.logger.set_config(self.config)
            self.saver.save_tokenizer(self.tokenizer)
            self.saver.save_label(self.intent_list, self.slot_list)
            self.config.set_vocab_size(self.tokenizer.vocab_size)

        if self.config.base.get("test"):
            self.test_dataset = self.data_factory.load_dataset(self.config.dataset, split="test")
            self.test_dataloader = self.__make_data_loader(
                self.test_dataset, batch_size, shuffle=False, label2tensor=False,
                tokenizer_config=tokenizer_config)

    # ------------------------------------------------------------------ #
    # training / evaluation
    # ------------------------------------------------------------------ #
    def eval(self, step: int, best_metric: float) -> float:
        """ evaluation models.

        Evaluates on the dev set. When the ``best_key`` metric improves, the
        model and training state are saved and, if testing is enabled, the
        test set is evaluated and its outputs are saved.

        Args:
            step (int): which step the model has trained in
            best_metric (float): last best metric value to judge whether to test or save model

        Returns:
            float: updated best metric value
        """
        _, res = self.__evaluate(self.model, self.dev_dataloader, mode="dev")
        self.logger.log_metric(res, metric_split="dev", step=step)
        best_key = self.config.evaluator.get("best_key")
        if res[best_key] > best_metric:
            best_metric = res[best_key]
            train_state = {
                "step": step,
                "best_metric": best_metric,
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict()
            }
            self.saver.save_model(self.model, train_state, self.accelerator)
            if self.config.base.get("test"):
                outputs, test_res = self.__evaluate(self.model, self.test_dataloader, mode="test")
                self.saver.save_output(outputs, self.test_dataset)
                self.logger.log_metric(test_res, metric_split="test", step=step)
        return best_metric

    def train(self) -> float:
        """ train models.

        Runs ``epoch_num`` epochs over the training loader, logging losses
        every step and evaluating either every ``eval_step`` steps or once
        per epoch (``eval_by_epoch``).

        Returns:
            float: updated best metric value
        """
        self.model.train()
        if self.accelerator is not None:
            # Each process only sees its share of the batches.
            self.total_step = math.ceil(self.total_step / self.accelerator.num_processes)
        if self.optimizer is None:
            self.optimizer = self.__build_optimizer()
        if self.lr_scheduler is None:
            self.lr_scheduler = self.__build_lr_scheduler()
        if not self.prepared and self.accelerator is not None:
            self.model, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
                self.model, self.optimizer, self.train_dataloader, self.lr_scheduler)
        step = self.init_step
        progress_bar = tqdm(range(self.total_step))
        progress_bar.update(self.init_step)
        self.optimizer.zero_grad()
        for _ in range(int(self.config.base.get("epoch_num"))):
            for data in self.train_dataloader:
                if step == 0:
                    # Log one decoded sample as a sanity check of the pipeline.
                    self.logger.info(data.get_item(
                        0, tokenizer=self.tokenizer, intent_map=self.intent_list, slot_map=self.slot_list))
                output = self.model(data)
                loss, intent_loss, slot_loss = self.__unwrap(self.model).compute_loss(
                    pred=output, target=data)
                self.logger.log_loss(loss, "Loss", step=step)
                self.logger.log_loss(intent_loss, "Intent Loss", step=step)
                self.logger.log_loss(slot_loss, "Slot Loss", step=step)
                self.optimizer.zero_grad()

                if self.accelerator is not None:
                    self.accelerator.backward(loss)
                else:
                    loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()
                train_state = {
                    "step": step,
                    "best_metric": self.best_metric,
                    "optimizer": self.optimizer.state_dict(),
                    "lr_scheduler": self.lr_scheduler.state_dict()
                }
                # Step-wise evaluation only runs when no automatic save
                # happened on this step.
                if not self.saver.auto_save_step(self.model, train_state, self.accelerator):
                    if not self.config.evaluator.get("eval_by_epoch") and step % self.config.evaluator.get("eval_step") == 0 and step != 0:
                        self.best_metric = self.eval(step, self.best_metric)
                step += 1
                progress_bar.update(1)
            if self.config.evaluator.get("eval_by_epoch"):
                self.best_metric = self.eval(step, self.best_metric)
        progress_bar.close()
        self.logger.finish()
        return self.best_metric

    def test(self) -> Tuple[OutputData, Any]:
        """Evaluate the current model on the test data loader.

        Returns:
            Tuple[OutputData, Any]: merged decoded outputs and the metric result
        """
        return self.__evaluate(self.model, self.test_dataloader, mode="test")

    def __set_seed(self, seed_value: Optional[int]):
        """Manually set random seeds.

        Args:
            seed_value (Optional[int]): random seed; when ``None`` no seeding
                is performed (``torch.manual_seed`` rejects ``None``).
        """
        if seed_value is None:
            return
        random.seed(seed_value)
        np.random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.random.manual_seed(seed_value)
        # Only affects child processes: the running interpreter reads
        # PYTHONHASHSEED at start-up.
        os.environ['PYTHONHASHSEED'] = str(seed_value)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed_value)
            torch.cuda.manual_seed_all(seed_value)
            torch.backends.cudnn.deterministic = True
            # The cuDNN autotuner may pick different kernels from run to run,
            # so it has to be off for seeded runs to be reproducible.
            torch.backends.cudnn.benchmark = False
        return

    def __evaluate(self, model, dataloader, mode="dev"):
        """Run ``model`` over ``dataloader`` and compute the configured metrics.

        Args:
            model: model to evaluate (possibly wrapped by the accelerator)
            dataloader: iterable of input batches
            mode (str): split name, only used in the log message

        Returns:
            tuple: ``(outputs, res)`` with the merged decoded outputs and the
            result of ``Evaluator.compute_all_metric``
        """
        was_training = model.training
        model.eval()
        inps = InputData()
        outputs = OutputData()
        decoder = self.__unwrap(model)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            for data in dataloader:
                torch.cuda.empty_cache()
                output = model(data)
                decode_output = decoder.decode(output, data)

                decode_output.map_output(slot_map=self.slot_list,
                                         intent_map=self.intent_list)
                if self.config.model["decoder"].get("slot_classifier"):
                    data, decode_output = utils.remove_slot_ignore_index(
                        data, decode_output, ignore_index="#")

                inps.merge_input_data(data)
                outputs.merge_output_data(decode_output)
        if "metric" in self.config.evaluator:
            res = Evaluator.compute_all_metric(
                inps, outputs, intent_label_map=self.intent_dict, metric_list=self.config.evaluator["metric"])
        else:
            res = Evaluator.compute_all_metric(
                inps, outputs, intent_label_map=self.intent_dict)
        self.logger.info(f"Best {mode} metric: "+str(res))
        # Restore the mode the model was in instead of forcing train mode.
        model.train(was_training)
        return outputs, res

    # ------------------------------------------------------------------ #
    # loading
    # ------------------------------------------------------------------ #
    def load(self):
        """Load tokenizer, label maps and model from ``self.load_dir``.

        Reads ``tokenizer.pkl`` (only if no tokenizer is set), ``label.json``
        and ``model.pkl``, and updates label counts and vocabulary size in
        the config.
        """
        if self.tokenizer is None:
            # SECURITY: dill/pickle executes arbitrary code on load; only use
            # checkpoints from a trusted source.
            with open(os.path.join(self.load_dir, "tokenizer.pkl"), 'rb') as f:
                self.tokenizer = dill.load(f)
        label = utils.load_json(os.path.join(self.load_dir, "label.json"))
        if label["intent"] is None:
            self.intent_list = None
            self.intent_dict = None
        else:
            self.intent_list = label["intent"]
            self.intent_dict = {x: i for i, x in enumerate(label["intent"])}
            self.config.set_intent_label_num(len(self.intent_list))
        if label["slot"] is None:
            self.slot_list = None
            self.slot_dict = None
        else:
            self.slot_list = label["slot"]
            self.slot_dict = {x: i for i, x in enumerate(label["slot"])}
            self.config.set_slot_label_num(len(self.slot_list))
        self.config.set_vocab_size(self.tokenizer.vocab_size)

        # NOTE(review): model.pkl is a fully pickled model object. Recent
        # PyTorch releases default torch.load to weights_only=True, which
        # rejects such files - confirm against the pinned torch version.
        self.model = torch.load(os.path.join(self.load_dir, "model.pkl"),
                                map_location=torch.device(self.device))
        if self.accelerator is not None:
            self.prepared = True
            self.accelerator.load_state(self.load_dir)
            # prepare_model returns the (possibly wrapped) model; the result
            # must be kept or the wrapping is lost.
            self.model = self.accelerator.prepare_model(self.model)
        self.model.to(self.device)

    def from_pretrained(self):
        """Load model, tokenizer and label maps from pretrained checkpoints.

        The model comes from ``config.model["_from_pretrained_"]``; the
        tokenizer from ``config.tokenizer["_from_pretrained_"]`` unless a
        tokenizer is already set. The model config and label maps are taken
        from the loaded checkpoint's config.
        """
        self.config.autoload_template()
        pretrained = PretrainedModelForSLU.from_pretrained(self.config.model["_from_pretrained_"])

        self.model = pretrained.model
        if self.tokenizer is None:
            self.tokenizer = PreTrainedTokenizerForSLU.from_pretrained(
                self.config.tokenizer["_from_pretrained_"])
            self.config.tokenizer = pretrained.config.tokenizer

        self.model.to(self.device)
        label = pretrained.config._id2label
        self.config.model = pretrained.config.model
        self.intent_list = label["intent"]
        self.slot_list = label["slot"]
        self.intent_dict = {x: i for i, x in enumerate(label["intent"])}
        self.slot_dict = {x: i for i, x in enumerate(label["slot"])}

    # ------------------------------------------------------------------ #
    # inference
    # ------------------------------------------------------------------ #
    def predict(self, text_data: str) -> Dict[str, Any]:
        """Predict intent and slots for a single utterance.

        Args:
            text_data (str): input text; it is split on single spaces

        Returns:
            Dict[str, Any]: ``{"intent": ..., "slot": ..., "text": ...}``
            where ``intent`` is always a list, ``slot`` is the decoded slot
            sequence of the utterance and ``text`` holds the input ids
            decoded one by one.
        """
        self.model.eval()
        tokenizer_config = self.__tokenizer_kwargs()
        align_mode = self.config.tokenizer.get("_align_mode_")
        inputs = self.data_factory.batch_fn(batch=[{"text": text_data.split(" ")}],
                                            device=self.device,
                                            config=tokenizer_config,
                                            enable_label=False,
                                            align_mode=align_mode if align_mode is not None else "general",
                                            label2tensor=False)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            output = self.model(inputs)
            decode_output = self.__unwrap(self.model).decode(output, inputs)
        decode_output.map_output(slot_map=self.slot_list,
                                 intent_map=self.intent_list)
        if self.config.base.get("multi_intent"):
            intent = decode_output.intent_ids[0]
        else:
            # Single-intent output is wrapped so both modes return a list.
            intent = [decode_output.intent_ids[0]]
        input_ids = inputs.input_ids[0].tolist()
        tokens = [self.tokenizer.decode(ids) for ids in input_ids]
        slots = decode_output.slot_ids[0]
        return {"intent": intent, "slot": slots, "text": tokens}