import numpy as np
import os
import random
import torch
import logging
from .__init__ import max_seq_lengths, backbone_loader_map, benchmark_labels


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN into deterministic mode.

    Args:
        seed: Integer seed passed to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


class DataManager:
    """Select the known labels for a dataset and build the backbone dataloader.

    Construction seeds the RNGs, samples a subset of the dataset's labels as
    the "known" classes, appends one extra "unseen" label, and then
    instantiates the loader registered for ``args.backbone``.

    Side effects on ``args``: ``max_seq_length``, ``num_labels``,
    ``unseen_label_id`` and ``anum_labels`` are written during ``__init__``.
    """

    def __init__(self, args, logger_name = 'Detection'):
        """Initialise label bookkeeping and create the dataloader.

        Args:
            args: Namespace-like object. Reads ``seed``, ``dataset``,
                ``data_dir``, ``known_cls_ratio`` and ``backbone``; several
                derived fields are written back (see class docstring).
            logger_name: Name of the logger used for progress messages.
        """
        self.logger = logging.getLogger(logger_name)

        # Seed first so that the known-label sampling below is reproducible.
        set_seed(args.seed)
        args.max_seq_length = max_seq_lengths[args.dataset]

        self.data_dir = os.path.join(args.data_dir, args.dataset)

        self.all_label_list = self.get_labels(args.dataset)
        self.n_known_cls = round(len(self.all_label_list) * args.known_cls_ratio)

        # ``tolist()`` converts the sampled NumPy scalars to native Python
        # objects; ``list(ndarray)`` would keep ``np.str_`` items, which log
        # as ``np.str_('...')`` on NumPy >= 2 and mix types in ``label_list``.
        self.known_label_list = np.random.choice(
            np.array(self.all_label_list), self.n_known_cls, replace=False
        ).tolist()

        self.logger.info('The number of known intents is %s', self.n_known_cls)
        self.logger.info('Lists of known labels are: %s', str(self.known_label_list))

        self.num_labels = len(self.known_label_list)
        args.num_labels = self.num_labels

        self.unseen_label = 'oos' if args.dataset == 'oos' else ''

        # The unseen class takes the index right after the known classes.
        self.unseen_label_id = self.num_labels
        args.unseen_label_id = self.unseen_label_id

        self.label_list = self.known_label_list + [self.unseen_label]
        self.anum_labels = len(self.label_list)
        args.anum_labels = self.anum_labels

        # ``get_attrs`` is evaluated before ``self.dataloader`` exists, so the
        # snapshot handed to the loader does not contain the loader itself.
        self.dataloader = self.get_loader(args, self.get_attrs())

    def get_labels(self, dataset):
        """Return the label collection registered for ``dataset``.

        Raises:
            KeyError: If ``dataset`` is not a key of ``benchmark_labels``.
        """
        return benchmark_labels[dataset]

    def get_loader(self, args, attrs):
        """Instantiate the loader registered for ``args.backbone``.

        Args:
            args: Namespace-like object providing ``backbone`` and, optionally,
                ``logger_name``.
            attrs: Attribute snapshot of this manager (see ``get_attrs``).

        Returns:
            Whatever the selected ``backbone_loader_map`` entry returns.
        """
        try:
            logger_name = args.logger_name
        except AttributeError:
            # Fall back to the logger this manager was constructed with rather
            # than failing when ``args`` carries no ``logger_name``.
            logger_name = self.logger.name
        return backbone_loader_map[args.backbone](args, attrs, logger_name)

    def get_attrs(self):
        """Return a shallow copy of the instance attributes as a dict."""
        return dict(vars(self))