import logging import os import numpy as np import copy from utils.metrics import clustering_score from sklearn.metrics import confusion_matrix from keras.models import Model from keras.optimizers import SGD from tqdm import trange from configs.base import ParamManager from utils.functions import set_seed from backbones.sae import get_sae, ClusteringLayer from sklearn.cluster import KMeans def target_distribution(q): weight = q ** 2 / q.sum(0) return (weight.T / weight.sum(1)).T class DECManager: def __init__(self, args, data, model, logger_name = 'Discovery'): self.logger = logging.getLogger(logger_name) # self.sae = model.set_model(args, data, 'sae') self.tfidf_train, self.tfidf_test = data.dataloader.tfidf_train, data.dataloader.tfidf_test self.num_labels = data.num_labels self.test_y = data.dataloader.test_true_labels self.init_sae(args, data, model) set_seed(args.seed) if args.train: self.model, self.y_pred_init = self.init_model(args) else: clustering_layer = ClusteringLayer(self.num_labels, name='clustering')(self.sae.layers[3].output) self.model = Model(inputs=self.sae.input, outputs = [clustering_layer, self.sae.output]) save_path = os.path.join(args.model_output_dir, args.model_name) self.logger.info('Loading models from %s' % save_path) self.model.load_weights(save_path) def init_sae(self, args, data, model): self.sae = model.set_model(args, data, 'sae') self.sae_feats_path = os.path.join(args.model_output_dir, 'SAE.h5') if os.path.exists(self.sae_feats_path): self.logger.info('Loading SAE features from %s' % self.sae_feats_path) self.sae.load_weights(self.sae_feats_path) else: self.logger.info('SAE (emb) training start...') self.sae.fit(self.tfidf_train, self.tfidf_train, epochs = args.num_train_epochs_SAE, batch_size = args.SAE_batch_size, shuffle=True, validation_data=(self.tfidf_test, self.tfidf_test), verbose=1) self.logger.info('SAE (emb) training finished...') if args.save_model: save_path = os.path.join(args.model_output_dir, 'SAE.h5') self.logger.info('Save models at %s', str(save_path)) self.sae.save_weights(save_path) def init_model(self, args): sae_emb_train, sae_emb_test = get_sae(args, self.sae, self.tfidf_train, self.tfidf_test) clustering_layer = ClusteringLayer(self.num_labels, name='clustering')(self.sae.layers[3].output) model = Model(inputs=self.sae.input, outputs = clustering_layer) model.compile(optimizer=SGD(args.lr, args.momentum), loss='kld') km = KMeans(n_clusters=self.num_labels, n_init=20, n_jobs=-1, random_state=args.seed) y_pred = km.fit_predict(sae_emb_train) y_pred_last = np.copy(y_pred) model.get_layer(name='clustering').set_weights([km.cluster_centers_]) return model, y_pred_last def train(self, args, data): self.logger.info('DEC training starts...') index = 0 loss = 0 index_array = np.arange(self.tfidf_train.shape[0]) y_pred_last = self.y_pred_init for epoch in trange(int(args.num_train_epochs_DEC), desc="Epoch"): if epoch % args.update_interval == 0: q = self.model.predict(self.tfidf_train, verbose=0) p = target_distribution(q) y_pred = q.argmax(1) delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / y_pred.shape[0] y_pred_last = np.copy(y_pred) if epoch > 0: self.logger.info("***** Epoch: %s*****", str(epoch + 1)) self.logger.info('Training Loss: %f', np.round(loss, 5)) self.logger.info('Delta Label: %f', delta_label) if delta_label < args.tol: self.logger.info('delta_label %s < %f', delta_label, args.tol) self.logger.info('Reached tolerance threshold. Stop training.') break idx = index_array[index * args.DEC_batch_size: min((index + 1) * args.DEC_batch_size, self.tfidf_train.shape[0])] loss = self.model.train_on_batch(x = self.tfidf_train[idx], y = p[idx]) index = index + 1 if (index + 1) * args.DEC_batch_size <= self.tfidf_train.shape[0] else 0 self.logger.info('DEC training finished...') if args.save_model: save_path = os.path.join(args.model_output_dir, args.model_name) self.model.save_weights(save_path) def test(self, args, data, show=False): q = self.model.predict(self.tfidf_test, verbose = 0) y_pred = q.argmax(1) y_true = self.test_y test_results = clustering_score(y_true, y_pred) cm = confusion_matrix(y_true,y_pred) if show: self.logger.info self.logger.info("***** Test: Confusion Matrix *****") self.logger.info("%s", str(cm)) self.logger.info("***** Test results *****") for key in sorted(test_results.keys()): self.logger.info(" %s = %s", key, str(test_results[key])) test_results['y_true'] = y_true test_results['y_pred'] = y_pred return test_results