|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pytorch_lightning as pl |
|
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
|
|
from nemo.collections.nlp.models import IntentSlotClassificationModel |
|
|
from nemo.core.config import hydra_runner |
|
|
from nemo.utils import logging |
|
|
from nemo.utils.exp_manager import exp_manager |
|
|
|
|
|
|
|
|
def _evaluate_on_test_set(eval_model: IntentSlotClassificationModel, trainer: pl.Trainer, cfg: DictConfig) -> None:
    """Evaluate ``eval_model`` on the test split described by ``cfg.model.test_ds``.

    The model is tested with its current in-memory weights: ``ckpt_path=None`` is
    passed to ``trainer.test`` so no saved checkpoint is reloaded first.

    Args:
        eval_model: the model instance to evaluate (here, the one just trained).
        trainer: the Lightning trainer used for training.
        cfg: full Hydra config; ``cfg.model.data_dir`` and ``cfg.model.test_ds`` are read.
    """
    logging.info("================================================================================================")
    logging.info("Starting the testing of the trained model on test set...")
    # The previous message claimed a saved checkpoint would be loaded, but the code
    # reuses the in-memory model and tests with ckpt_path=None; report what actually happens.
    logging.info("Reusing the in-memory model from training; no saved checkpoint is reloaded...")

    # Reuse the training config for evaluation: the same data_dir, and the `test_ds`
    # section to build the test data. (update_data_dir_for_testing presumably
    # re-targets the model's data directory; verify against the model class.)
    eval_model.update_data_dir_for_testing(data_dir=cfg.model.data_dir)
    eval_model.setup_test_data(test_data_config=cfg.model.test_ds)

    trainer.test(model=eval_model, ckpt_path=None, verbose=False)
    logging.info("Testing finished!")


def _run_sample_inference(eval_model: IntentSlotClassificationModel, cfg: DictConfig) -> None:
    """Predict intents and slots for a few hard-coded queries and log the results.

    Args:
        eval_model: the model used for prediction.
        cfg: full Hydra config; ``cfg.model.test_ds`` is passed as the data config
            for ``predict_from_examples``.
    """
    logging.info("======================================================================================")
    logging.info("Evaluate the model on the given queries...")

    # NOTE(review): these queries look tailored to an assistant-style intent/slot set;
    # a model trained on a different dataset may not give meaningful predictions.
    queries = [
        'set alarm for seven thirty am',
        'lower volume by fifty percent',
        'what is my schedule for tomorrow',
    ]

    pred_intents, pred_slots = eval_model.predict_from_examples(queries, cfg.model.test_ds)

    logging.info('The prediction results of some sample queries with the trained model:')
    for query, intent, slots in zip(queries, pred_intents, pred_slots):
        logging.info(f'Query : {query}')
        logging.info(f'Predicted Intent: {intent}')
        logging.info(f'Predicted Slots: {slots}')

    logging.info("Inference finished!")


@hydra_runner(config_path="conf", config_name="intent_slot_classification_config")
def main(cfg: DictConfig) -> None:
    """Train an intent/slot classification model, then test it and run sample inference.

    Args:
        cfg: Hydra config providing a ``trainer`` section (Lightning Trainer kwargs),
            a ``model`` section (including ``data_dir`` and ``test_ds``), and an
            optional ``exp_manager`` section.
    """
    logging.info(f'Config Params:\n {OmegaConf.to_yaml(cfg)}')
    trainer = pl.Trainer(**cfg.trainer)
    # The exp_manager section is optional; None is passed when it is absent.
    exp_manager(trainer, cfg.get("exp_manager", None))

    model = IntentSlotClassificationModel(cfg.model, trainer=trainer)

    logging.info("================================================================================================")
    logging.info('Starting training...')
    trainer.fit(model)
    logging.info('Training finished!')

    # fast_dev_run is Lightning's quick smoke-test mode; stop right after fit
    # instead of running the full evaluation and inference below.
    if trainer.fast_dev_run:
        return

    # Evaluate the model object that was just trained (not a reloaded checkpoint),
    # then show predictions on a few example queries.
    _evaluate_on_test_set(model, trainer, cfg)
    _run_sample_inference(model, cfg)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. `main` takes no arguments here because the
    # @hydra_runner decorator builds and supplies the `cfg` object.
    main()
|
|
|