File size: 1,500 Bytes
2d06dcc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | import numpy as np
import os
import logging
from .__init__ import max_seq_lengths, backbone_loader_map, benchmark_labels
class DataManager:
    """Resolve dataset-level settings and build the backbone's dataloader.

    Construction fills in derived fields on ``args`` (``max_seq_length`` and
    ``num_labels``), optionally samples a subset of "known" labels, and then
    hands everything to the loader registered for ``args.backbone``.
    """

    def __init__(self, args, logger_name='Discovery'):
        """Populate dataset attributes and create ``self.dataloader``.

        Args:
            args: Mutable settings object. Reads ``dataset``, ``data_dir``,
                ``setting``, ``cluster_num_factor`` and ``backbone``, plus
                ``known_cls_ratio`` when ``setting == 'semi_supervised'``.
                ``max_seq_length`` and ``num_labels`` are written back onto
                it.
            logger_name: Name passed to ``logging.getLogger``.

        Raises:
            KeyError: If ``args.dataset`` is missing from ``max_seq_lengths``
                or ``benchmark_labels``, or ``args.backbone`` is missing from
                ``backbone_loader_map``.
        """
        self.logger = logging.getLogger(logger_name)

        args.max_seq_length = max_seq_lengths[args.dataset]
        self.data_dir = os.path.join(args.data_dir, args.dataset)
        self.all_label_list = self.get_labels(args.dataset)

        if args.setting == 'semi_supervised':
            # Only this setting defines n_known_cls / known_label_list; in
            # any other setting these attributes are never created.
            self.n_known_cls = round(
                len(self.all_label_list) * args.known_cls_ratio
            )
            # Sampling without replacement from NumPy's global RNG, so the
            # selection is reproducible only if the caller seeded it.
            self.known_label_list = list(
                np.random.choice(
                    np.array(self.all_label_list),
                    self.n_known_cls,
                    replace=False,
                )
            )
            self.logger.info(
                'The number of known intents is %s', self.n_known_cls
            )
            self.logger.info(
                'Lists of known labels are: %s', self.known_label_list
            )

        # int() truncates the scaled label count toward zero.
        num_labels = int(len(self.all_label_list) * args.cluster_num_factor)
        args.num_labels = num_labels
        self.num_labels = num_labels

        # The attribute snapshot is taken before ``dataloader`` is assigned,
        # so the loader receives every attribute set above but not itself.
        self.dataloader = self.get_loader(args, self.get_attrs())

    def get_labels(self, dataset):
        """Return the ``benchmark_labels`` entry for ``dataset``.

        The registry's object is returned as-is (not copied).
        """
        return benchmark_labels[dataset]

    def get_loader(self, args, attrs):
        """Build the loader registered for ``args.backbone``.

        The registered callable is invoked as ``factory(args, attrs)`` and
        its result is returned unchanged.
        """
        return backbone_loader_map[args.backbone](args, attrs)

    def get_attrs(self):
        """Return a shallow copy of this instance's attribute dictionary.

        Keys are attribute names in assignment order. The values are the same
        objects held by the instance, but adding or removing keys in the
        returned dict does not affect the instance.
        """
        return dict(vars(self))
|