import gensim
import numpy as np
from gensim.models import ldaseqmodel
from tqdm import tqdm
import datetime
from multiprocessing.pool import Pool
from backend.datasets.utils import _utils
from backend.datasets.utils.logger import Logger
logger = Logger("WARNING")
def work(arguments):
model, docs = arguments
theta_list = list()
for doc in tqdm(docs):
theta_list.append(model[doc])
return theta_list
class DTMTrainer:
def __init__(self,
dataset,
num_topics=50,
num_top_words=15,
alphas=0.01,
chain_variance=0.005,
passes=10,
lda_inference_max_iter=25,
em_min_iter=6,
em_max_iter=20,
verbose=False
):
self.dataset = dataset
self.vocab_size = dataset.vocab_size
self.num_topics = num_topics
self.num_top_words = num_top_words
self.alphas = alphas
self.chain_variance = chain_variance
self.passes = passes
self.lda_inference_max_iter = lda_inference_max_iter
self.em_min_iter = em_min_iter
self.em_max_iter = em_max_iter
self.verbose = verbose
if verbose:
logger.set_level("DEBUG")
else:
logger.set_level("WARNING")
def train(self):
id2word = dict(zip(range(self.vocab_size), self.dataset.vocab))
train_bow = self.dataset.train_bow
train_times = self.dataset.train_times.astype('int32')
# order documents by time slices
self.doc_order_idx = np.argsort(train_times)
train_bow = train_bow[self.doc_order_idx]
time_slices = np.bincount(train_times)
corpus = gensim.matutils.Dense2Corpus(train_bow, documents_columns=False)
self.model = ldaseqmodel.LdaSeqModel(
corpus=corpus,
id2word=id2word,
time_slice=time_slices,
num_topics=self.num_topics,
alphas=self.alphas,
chain_variance=self.chain_variance,
em_min_iter=self.em_min_iter,
em_max_iter=self.em_max_iter,
lda_inference_max_iter=self.lda_inference_max_iter,
passes=self.passes
)
def test(self, bow):
# bow = dataset.bow.cpu().numpy()
# times = dataset.times.cpu().numpy()
corpus = gensim.matutils.Dense2Corpus(bow, documents_columns=False)
num_workers = 20
split_idx_list = np.array_split(np.arange(len(bow)), num_workers)
worker_size_list = [len(x) for x in split_idx_list]
worker_id = 0
docs_list = [list() for i in range(num_workers)]
for i, doc in enumerate(corpus):
docs_list[worker_id].append(doc)
if len(docs_list[worker_id]) >= worker_size_list[worker_id]:
worker_id += 1
args_list = list()
for docs in docs_list:
args_list.append([self.model, docs])
starttime = datetime.datetime.now()
pool = Pool(processes=num_workers)
results = pool.map(work, args_list)
pool.close()
pool.join()
theta_list = list()
for rst in results:
theta_list.extend(rst)
endtime = datetime.datetime.now()
print("DTM test time: {}s".format((endtime - starttime).seconds))
return np.asarray(theta_list)
def get_theta(self):
theta = self.model.gammas / self.model.gammas.sum(axis=1)[:, np.newaxis]
# NOTE: MUST transform gamma to original order.
return theta[np.argsort(self.doc_order_idx)]
def get_beta(self):
beta = list()
# K x V x T
for item in self.model.topic_chains:
# V x T
beta.append(item.e_log_prob)
# T x K x V
beta = np.transpose(np.asarray(beta), (2, 0, 1))
# use softmax
beta = np.exp(beta)
beta = beta / beta.sum(-1, keepdims=True)
return beta
def get_top_words(self, num_top_words=None):
if num_top_words is None:
num_top_words = self.num_top_words
beta = self.get_beta()
top_words_list = list()
for time in range(beta.shape[0]):
top_words = _utils.get_top_words(beta[time], self.dataset.vocab, num_top_words, self.verbose)
top_words_list.append(top_words)
return top_words_list
def export_theta(self):
train_theta = self.get_theta()
test_theta = self.test(self.dataset.test_bow)
return train_theta, test_theta