from configs.base import ParamManager
from dataloaders.base import DataManager
from backbones.base import ModelManager
from methods import method_map
from utils.functions import save_results
import logging
import argparse
import sys
import os
import datetime


def parse_arguments():
    """Build the open-intent-detection command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--type', type=str, default='open_intent_detection', help="Type style.")
    parser.add_argument('--logger_name', type=str, default='Detection', help="Logger name for open intent detection.")
    parser.add_argument('--log_dir', type=str, default='logs', help="Logger directory.")

    parser.add_argument("--dataset", default='banking', type=str, help="The name of the dataset to train selected")
    parser.add_argument("--known_cls_ratio", default=0.75, type=float, help="The ratio of known classes")
    parser.add_argument("--labeled_ratio", default=1.0, type=float, help="The ratio of labeled samples in the training set")

    parser.add_argument("--method", type=str, default='ADB', help="which method to use")
    parser.add_argument("--train", action="store_true", help="Whether to train the model")
    parser.add_argument("--pretrain", action="store_true", help="Whether to pre-train the model")
    parser.add_argument("--save_model", action="store_true", help="save trained-model for open intent detection")
    parser.add_argument("--backbone", type=str, default='bert', help="which backbone to use")
    parser.add_argument("--config_file_name", type=str, default='ADB.py', help="The name of the config file.")
    parser.add_argument('--seed', type=int, default=0, help="random seed for initialization")
    parser.add_argument("--gpu_id", type=str, default='0', help="Select the GPU id")
    parser.add_argument("--pipe_results_path", type=str, default='pipe_results', help="the path to save results of pipeline methods")

    # os.path.join instead of string concatenation: when sys.path[0] is '' (e.g. an
    # interactive session) concatenation produced the absolute path '/../data'.
    parser.add_argument("--data_dir", default=os.path.join(sys.path[0], '..', 'data'), type=str,
                        help="The input data dir. Should contain the .csv files (or other data files) for the task.")
    # NOTE(review): this default is a machine-specific absolute path; on other machines
    # --output_dir presumably has to be passed explicitly. Consider a relative default.
    parser.add_argument("--output_dir", default='/home/sharing/disk1/zhaoshaojie/baseline_test/TEXTOIR', type=str,
                        help="The output directory where all train data will be written.")
    parser.add_argument("--model_dir", default='models', type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    # Help text previously duplicated --model_dir's. Exact semantics of this option are
    # defined by the method implementations (not visible here) — TODO confirm wording.
    parser.add_argument("--load_pretrained_method", default=None, type=str,
                        help="The name of the pre-trained method to load, if any.")
    parser.add_argument("--result_dir", type=str, default='results', help="The path to save results")
    parser.add_argument("--results_file_name", type=str, default='results.csv', help="The file name of all the results.")
    parser.add_argument("--save_results", action="store_true", help="save final results for open intent detection")
    parser.add_argument("--loss_fct", default="CrossEntropyLoss", type=str, help="The loss function for training.")

    return parser.parse_args()


def set_logger(args):
    """Configure the named logger with a DEBUG file handler and an INFO console handler.

    The log file is created inside ``args.log_dir`` (which is created if missing) and
    is named after the method, dataset, backbone, known-class ratio, labeled ratio and
    the current local timestamp.

    Args:
        args: parsed arguments; reads ``log_dir``, ``logger_name``, ``method``,
            ``dataset``, ``backbone``, ``known_cls_ratio`` and ``labeled_ratio``.

    Returns:
        logging.Logger: the configured logger named ``args.logger_name``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(args.log_dir, exist_ok=True)

    # Named ``timestamp`` rather than ``time`` so it does not shadow the stdlib module name.
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    file_name = (
        f"{args.method}_{args.dataset}_{args.backbone}_"
        f"{args.known_cls_ratio}_{args.labeled_ratio}_{timestamp}.log"
    )

    logger = logging.getLogger(args.logger_name)
    logger.setLevel(logging.DEBUG)

    # File handler: everything from DEBUG up, with timestamps.
    fh = logging.FileHandler(os.path.join(args.log_dir, file_name))
    fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(message)s'))
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    # Console handler: INFO and above only, without timestamps.
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(logging.Formatter('%(name)s - %(message)s'))
    logger.addHandler(ch)

    return logger


def run(args, data, model, logger):
    """Optionally train, then test, the method selected by ``args.method``.

    Args:
        args: configuration; reads ``method``, ``logger_name``, ``train``,
            ``save_results``, ``result_dir`` and ``results_file_name``.
        data: the prepared ``DataManager``.
        model: the prepared ``ModelManager``.
        logger: logger used for progress messages.

    Raises:
        KeyError: if ``args.method`` is not a key of ``method_map``.
    """
    method_manager = method_map[args.method]
    method = method_manager(args, data, model, logger_name=args.logger_name)

    if args.train:
        logger.info('Training Begin...')
        method.train(args, data)
        logger.info('Training Finished...')

    logger.info('Testing begin...')
    outputs = method.test(args, data)
    logger.info('Testing finished...')

    if args.save_results:
        save_results(args, outputs)
        # Logged only after saving, so a failed save does not claim success.
        logger.info('Results saved in %s', os.path.join(args.result_dir, args.results_file_name))


def main():
    """Script entry point: parse arguments, set up logging, build data and model, run."""
    # NOTE(review): this executes after the module-level imports above, so it cannot
    # influence how they resolve; it only affects imports performed later at runtime.
    sys.path.append('.')

    args = parse_arguments()
    logger = set_logger(args)

    logger.info('Open Intent Detection Begin...')
    logger.info('Parameters Initialization...')
    param = ParamManager(args)
    # ParamManager replaces the argparse namespace with its own mapping-like object
    # (it supports .keys(), item access and attribute access below).
    args = param.args

    logger.debug("=" * 30 + " Params " + "=" * 30)
    for key in args.keys():
        logger.debug("%s:\t%s", key, args[key])
    logger.debug("=" * 30 + " End Params " + "=" * 30)

    logger.info('Data and Model Preparation...')
    data = DataManager(args, logger_name=args.logger_name)
    model = ModelManager(args, data, logger_name=args.logger_name)

    run(args, data, model, logger)

    logger.info('Open Intent Detection Finished...')


if __name__ == '__main__':
    main()