File size: 5,599 Bytes
2d06dcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import logging
import os
import numpy as np
import copy
from utils.metrics import clustering_score
from sklearn.metrics import confusion_matrix
from keras.models import Model
from keras.optimizers import SGD
from tqdm import trange
from configs.base import ParamManager
from utils.functions import set_seed
from backbones.sae import get_sae, ClusteringLayer
from sklearn.cluster import KMeans

def target_distribution(q):
    """Sharpen the soft assignments ``q`` into an auxiliary target distribution.

    Every entry of ``q`` is squared and divided by the total of its column
    (``q.sum(0)``); each row of the result is then rescaled by its own sum.

    Args:
        q: 2-D array of soft assignments (rows are indexed by axis 0,
            columns by axis 1).

    Returns:
        Array with the same shape as ``q`` holding the rescaled values.
    """
    column_totals = q.sum(0)
    sharpened = q ** 2 / column_totals
    row_totals = sharpened.sum(1)
    # Broadcasting against a column vector is equivalent to the
    # transpose / divide / transpose formulation.
    return sharpened / row_totals[:, None]

class DECManager:
    """Train and evaluate a DEC (deep embedded clustering) model.

    A stacked auto-encoder (SAE) obtained from ``model.set_model`` is
    pre-trained (or its weights are loaded from disk), a ``ClusteringLayer``
    is attached to the output of ``self.sae.layers[3]``, and the resulting
    model is refined by minimising the KL divergence against the
    distribution produced by ``target_distribution``.
    """

    def __init__(self, args, data, model, logger_name = 'Discovery'):
        """Prepare the SAE and build (or restore) the clustering model.

        Args:
            args: Run configuration. Reads ``seed``, ``train``,
                ``model_output_dir`` and ``model_name`` here, plus whatever
                ``init_sae`` / ``init_model`` need.
            data: Data container exposing ``num_labels`` and a ``dataloader``
                with ``tfidf_train``, ``tfidf_test`` and ``test_true_labels``.
            model: Factory exposing ``set_model(args, data, 'sae')``.
            logger_name: Name of the logger to write progress to.
        """
        self.logger = logging.getLogger(logger_name)

        self.tfidf_train, self.tfidf_test = data.dataloader.tfidf_train, data.dataloader.tfidf_test
        self.num_labels = data.num_labels
        self.test_y = data.dataloader.test_true_labels

        self.init_sae(args, data, model)
        set_seed(args.seed)
        if args.train:
            self.model, self.y_pred_init = self.init_model(args)

        else:
            # Restore the model with the SAME single-output architecture that
            # train() saves. The previous two-output variant made predict()
            # return a list, which broke ``q.argmax(1)`` in test().
            self.model = self._build_clustering_model()

            save_path = os.path.join(args.model_output_dir, args.model_name)
            self.logger.info('Loading models from %s', save_path)
            self.model.load_weights(save_path)

    def _build_clustering_model(self):
        """Attach a ``ClusteringLayer`` to the SAE and return the wrapped model."""
        clustering_layer = ClusteringLayer(self.num_labels, name='clustering')(self.sae.layers[3].output)
        return Model(inputs=self.sae.input, outputs=clustering_layer)

    @staticmethod
    def _next_batch_index(index, batch_size, num_samples):
        """Return the mini-batch index following ``index``, wrapping to 0.

        The index only advances while the next batch would still contain at
        least one sample; otherwise it restarts from the first batch. Using a
        strict ``<`` avoids an empty batch when ``num_samples`` is an exact
        multiple of ``batch_size``.
        """
        following = index + 1
        return following if following * batch_size < num_samples else 0

    def init_sae(self, args, data, model):
        """Create the SAE and either load its weights or pre-train it.

        If ``<args.model_output_dir>/SAE.h5`` exists its weights are loaded.
        Otherwise the SAE is fitted to reconstruct ``self.tfidf_train``
        (validated on ``self.tfidf_test``) and, when ``args.save_model`` is
        set, the weights are written to that same path.
        """
        self.sae = model.set_model(args, data, 'sae')
        self.sae_feats_path = os.path.join(args.model_output_dir, 'SAE.h5')

        if os.path.exists(self.sae_feats_path):
            self.logger.info('Loading SAE features from %s', self.sae_feats_path)
            self.sae.load_weights(self.sae_feats_path)
            return

        self.logger.info('SAE (emb) training start...')
        self.sae.fit(
            self.tfidf_train,
            self.tfidf_train,
            epochs=args.num_train_epochs_SAE,
            batch_size=args.SAE_batch_size,
            shuffle=True,
            validation_data=(self.tfidf_test, self.tfidf_test),
            verbose=1,
        )
        self.logger.info('SAE (emb) training finished...')

        if args.save_model:
            self.logger.info('Save models at %s', str(self.sae_feats_path))
            self.sae.save_weights(self.sae_feats_path)

    def init_model(self, args):
        """Build the trainable clustering model and seed it with k-means.

        Args:
            args: Run configuration; reads ``lr``, ``momentum`` and ``seed``
                (and whatever ``get_sae`` needs).

        Returns:
            Tuple ``(model, y_pred)``: the compiled model whose clustering
            layer weights are set to the k-means centroids, and a copy of the
            k-means assignments of the training embeddings.
        """
        sae_emb_train, _ = get_sae(args, self.sae, self.tfidf_train, self.tfidf_test)
        model = self._build_clustering_model()
        model.compile(optimizer=SGD(args.lr, args.momentum), loss='kld')

        # ``n_jobs`` is intentionally not passed: it was deprecated in
        # scikit-learn 0.23 and removed in 1.0, where it raises TypeError.
        km = KMeans(n_clusters=self.num_labels, n_init=20, random_state=args.seed)
        y_pred = km.fit_predict(sae_emb_train)
        model.get_layer(name='clustering').set_weights([km.cluster_centers_])

        return model, np.copy(y_pred)

    def train(self, args, data):
        """Run the DEC refinement loop.

        Every ``args.update_interval`` iterations the soft assignments of the
        whole training set are recomputed and turned into the target
        distribution; in between, the model is trained one mini-batch at a
        time against that target. Training stops early once the fraction of
        samples whose hard assignment changed drops below ``args.tol``.

        Args:
            args: Run configuration; reads ``num_train_epochs_DEC``,
                ``update_interval``, ``tol``, ``DEC_batch_size``,
                ``save_model``, ``model_output_dir`` and ``model_name``.
            data: Unused; kept for interface compatibility.
        """
        self.logger.info('DEC training starts...')
        num_samples = self.tfidf_train.shape[0]
        batch_size = args.DEC_batch_size
        index = 0
        loss = 0
        index_array = np.arange(num_samples)
        y_pred_last = self.y_pred_init

        for epoch in trange(int(args.num_train_epochs_DEC), desc="Epoch"):

            if epoch % args.update_interval == 0:

                q = self.model.predict(self.tfidf_train, verbose=0)
                p = target_distribution(q)

                # Fraction of samples whose hard cluster assignment changed
                # since the previous target update.
                y_pred = q.argmax(1)
                delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / y_pred.shape[0]
                y_pred_last = np.copy(y_pred)

                if epoch > 0:

                    self.logger.info("***** Epoch: %s*****", str(epoch + 1))
                    self.logger.info('Training Loss: %f', np.round(loss, 5))
                    self.logger.info('Delta Label: %f', delta_label)

                    if delta_label < args.tol:
                        self.logger.info('delta_label %s < %f', delta_label, args.tol)
                        self.logger.info('Reached tolerance threshold. Stop training.')
                        break

            idx = index_array[index * batch_size: min((index + 1) * batch_size, num_samples)]
            loss = self.model.train_on_batch(x = self.tfidf_train[idx], y = p[idx])
            index = self._next_batch_index(index, batch_size, num_samples)

        self.logger.info('DEC training finished...')

        if args.save_model:
            save_path = os.path.join(args.model_output_dir, args.model_name)
            self.model.save_weights(save_path)

    def test(self, args, data, show=False):
        """Cluster the test set and score it against the true labels.

        Args:
            args: Unused; kept for interface compatibility.
            data: Unused; kept for interface compatibility.
            show: When true, log the confusion matrix and every metric.

        Returns:
            The mapping returned by ``clustering_score`` extended with the
            ``'y_true'`` and ``'y_pred'`` entries.
        """
        q = self.model.predict(self.tfidf_test, verbose = 0)
        y_pred = q.argmax(1)
        y_true = self.test_y

        test_results = clustering_score(y_true, y_pred)
        cm = confusion_matrix(y_true, y_pred)

        if show:
            self.logger.info("***** Test: Confusion Matrix *****")
            self.logger.info("%s", str(cm))
            self.logger.info("***** Test results *****")

            for key in sorted(test_results.keys()):
                self.logger.info("  %s = %s", key, str(test_results[key]))

        test_results['y_true'] = y_true
        test_results['y_pred'] = y_pred

        return test_results