kisejin committed
Commit 02c66e0 · verified · 1 Parent(s): 9b9182e

Upload my_topic_modeling.py

Files changed (1)
  1. BERTopic/my_topic_modeling.py +600 -0
BERTopic/my_topic_modeling.py ADDED
@@ -0,0 +1,600 @@
import os  # Miscellaneous operating system interfaces
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from os.path import join  # path joining
from pathlib import Path  # path joining

import pandas as pd
import polars as pl  # needed by the optional 'polars' branch in prepare_data
import numpy as np
import sklearn as sk
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
import regex as re
from scipy.cluster import hierarchy as sch

import datetime
import time
import timeit
import json
import pickle

import copy
import random
from itertools import chain

import logging
import sys
import argparse
import nltk
nltk.download('wordnet')
nltk.download('punkt')


import textblob
from textblob import TextBlob
from textblob.wordnet import Synset
from textblob import Word
from textblob.wordnet import VERB

from bertopic import BERTopic
from bertopic.vectorizers import ClassTfidfTransformer
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sentence_transformers import SentenceTransformer

# Prefer the GPU-accelerated UMAP/HDBSCAN from cuML; fall back to the CPU
# implementations (umap-learn, hdbscan) when cuML is not installed.
try:
    from cuml.cluster import HDBSCAN
    from cuml.manifold import UMAP
except ImportError:
    from umap import UMAP
    from hdbscan import HDBSCAN

import gensim.corpora as corpora
from gensim.models.coherencemodel import CoherenceModel


# Get working directory
working_dir = os.path.abspath(os.path.join("/notebooks", "TopicModelingRepo"))
data_dir = os.path.join(working_dir, 'data')
lib_dir = os.path.join(working_dir, 'libs')
outer_output_dir = os.path.join(working_dir, 'outputs')

output_dir_name = time.strftime('%Y_%m_%d')
# output_dir_name = time.strftime(args.datetime)

output_dir = os.path.join(outer_output_dir, output_dir_name)
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

stopwords_path = os.path.join(data_dir, 'vietnamese_stopwords_dash.txt')

# Setting variables
doc_time = '2024_Jan_15'
doc_type = 'reviews'
doc_level = 'sentence'
target_col = 'normalized_content'
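# Note: doc_level == 'sentence' splits each review into sentences before
# modeling (see processing_data below); a review-level 'whole' path exists
# there only as a commented-out sketch.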


def create_logger_file_and_console():
    # Create the 'automated_testing' logger. All three factory functions
    # below return the same underlying logger object, so handlers accumulate
    # across calls.
    logger = logging.getLogger('automated_testing')
    logger.setLevel(logging.DEBUG)

    # create file handler which logs even debug messages
    fileh = logging.FileHandler('info.log', mode='a')
    fileh.setLevel(logging.DEBUG)

    # create console handler with a higher log level
    consoleh = logging.StreamHandler(stream=sys.stdout)
    consoleh.setLevel(logging.INFO)

    # create formatter and add it to the handlers
    formatter = logging.Formatter('[%(asctime)s] %(levelname)8s --- %(message)s ', datefmt='%H:%M:%S')
    fileh.setFormatter(formatter)
    consoleh.setFormatter(formatter)

    # add the handlers to the logger (console output is supplied by
    # create_logger_console, so only the file handler is attached here)
    # logger.addHandler(consoleh)
    logger.addHandler(fileh)

    return logger


def create_logger_file():
    # create the 'automated_testing' logger
    logger = logging.getLogger('automated_testing')
    logger.setLevel(logging.INFO)

    # create file handler which logs even debug messages
    fileh = logging.FileHandler('info.log', mode='a')
    fileh.setLevel(logging.INFO)

    # create formatter and add it to the handler
    formatter = logging.Formatter('[%(asctime)s] %(levelname)8s --- %(message)s ', datefmt='%H:%M:%S')
    fileh.setFormatter(formatter)

    # add the handler to the logger
    logger.addHandler(fileh)

    return logger


def create_logger_console():
    # create the 'automated_testing' logger
    logger = logging.getLogger('automated_testing')
    logger.setLevel(logging.INFO)

    # create console handler
    consoleh = logging.StreamHandler(stream=sys.stdout)
    consoleh.setLevel(logging.INFO)

    # create formatter and add it to the handler
    formatter = logging.Formatter('[%(asctime)s] %(levelname)8s --- %(message)s ', datefmt='%H:%M:%S')
    consoleh.setFormatter(formatter)

    # add the handler to the logger
    logger.addHandler(consoleh)

    return logger


def init_args():
    parser = argparse.ArgumentParser()
    # basic settings
    parser.add_argument(
        "--n_topics",
        type=int,
        required=True,
        help="Number of topics for topic modeling.",
    )

    parser.add_argument(
        "--name_dataset",
        default="booking",
        type=str,
        help="The name of the dataset, selected from: [booking, tripadvisor]",
    )

    parser.add_argument(
        "--train_both",
        default="yes",
        type=str,
        required=True,
        help="Train on both booking and tripadvisor, or on only one.",
    )

    parser.add_argument(
        "--only_coherence_score",
        default="yes",
        type=str,
        required=True,
        help="Only train the models to calculate the coherence score.",
    )

    args = parser.parse_args()

    return args


def check_valid(list_topics):
    # A topic counts as valid only if more than two of its keywords are
    # non-empty strings
    count = 0
    for topic in list_topics:
        if topic[0] != '':
            count += 1

    return count > 2
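# BERTopic's get_topic() yields (word, weight) pairs and pads sparse topics
# with '' entries, hence the check above. Hypothetical example:
#   check_valid([('room', 0.12), ('clean', 0.10), ('staff', 0.08)])  # True
#   check_valid([('', 0.0), ('', 0.0)])                              # False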


def prepare_data(doc_source, doc_type, type_framework='pandas'):
    name_file = doc_source.split('.')[0]
    out_dir = os.path.join(output_dir, name_file)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    date_col = 'Date'
    df_reviews_path = os.path.join(data_dir, doc_source)

    if type_framework == 'pandas':
        df_reviews = pd.read_csv(df_reviews_path, lineterminator='\n', encoding='utf-8')  # Pandas
        df_reviews = df_reviews.loc[df_reviews['year'] > 0]  # Pandas
        df_reviews = df_reviews.loc[df_reviews['language'] == 'English']  # Pandas

        if doc_type == 'reviews':
            df_doc = df_reviews
            df_doc['dates'] = pd.to_datetime(df_doc[date_col], dayfirst=False, errors='coerce'). \
                dt.to_period('M'). \
                dt.strftime('%Y-%m-%d')  # pandas

            # timestamps = df_doc['dates'].to_list()
            # df_doc = df_doc.loc[(df_doc['dates']>='2020-04-01') & (df_doc['dates']<'2022-01-01')]

            df_doc['dates_yearly'] = pd.to_datetime(df_doc[date_col], dayfirst=False, errors='coerce'). \
                dt.to_period('Y'). \
                dt.strftime('%Y')  # pandas

            df_doc['dates_quarterly'] = pd.to_datetime(df_doc[date_col], dayfirst=False, errors='coerce'). \
                dt.to_period('Q'). \
                dt.strftime('%YQ%q')  # pandas

            df_doc['dates_monthly'] = pd.to_datetime(df_doc[date_col], dayfirst=False, errors='coerce'). \
                dt.to_period('M'). \
                dt.strftime('%Y-%m')

    elif type_framework == 'polars':
        # NOTE: this branch is a sketch of the pandas logic in the polars
        # expression API; the rest of the pipeline (processing_data) still
        # assumes pandas objects.
        df_reviews = pl.read_csv(df_reviews_path)  # Polars
        df_reviews = df_reviews.filter(pl.col("year") > 0)  # Polars
        df_reviews = df_reviews.filter(pl.col('language') == 'English')  # Polars

        if doc_type == 'reviews':
            df_doc = df_reviews

            parsed = pl.col(date_col).str.to_datetime(strict=False)  # nulls out unparsable dates
            df_doc = df_doc.with_columns(
                parsed.dt.truncate('1mo').dt.strftime('%Y-%m-%d').alias('dates'),
                parsed.dt.strftime('%Y').alias('dates_yearly'),
                pl.format("{}Q{}", parsed.dt.year(), parsed.dt.quarter()).alias('dates_quarterly'),
                parsed.dt.strftime('%Y-%m').alias('dates_monthly'),
            )  # polars

    timestamps_dict = dict()
    timestamps_dict['yearly'] = df_doc['dates_yearly'].to_list()
    timestamps_dict['quarterly'] = df_doc['dates_quarterly'].to_list()
    timestamps_dict['monthly'] = df_doc['dates_monthly'].to_list()
    timestamps_dict['date'] = df_doc['dates'].to_list()

    target_col = 'normalized_content'
    df_documents = df_doc[target_col]

    return (timestamps_dict, df_doc, df_documents, df_reviews)


def flatten_comprehension(matrix):
    return [item for row in matrix for item in row]
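# Hypothetical example: flatten_comprehension([[1, 2], [3]]) -> [1, 2, 3]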


def processing_data(df_doc, df_documents, timestamps_dict, doc_level, target_col):

    if doc_level == 'sentence':
        # num_sent = [len(TextBlob(row).sentences) for row in df_doc[target_col]]
        # df_documents = pd.Series(flatten_comprehension([[str(sentence) for sentence in TextBlob(row).sentences] for row in df_documents]))

        # Split each comment into sentences at "."
        ll_sent = [[str(sent) for sent in nltk.sent_tokenize(row, language='english')] for row in df_doc[target_col]]

        # Count the number of sentences in each comment
        num_sent = [len(x) for x in ll_sent]

        # Flatten the per-comment sentence lists: N comments with m' sentences
        # each become a single list of m'*N sentence-level documents
        df_documents = pd.Series(flatten_comprehension(ll_sent))

        # timestamps = list(chain.from_iterable(n*[item] for item, n in zip(timestamps, num_sent)))

        # Repeat each comment's timestamps once per sentence so they line up
        # with the flattened m'*N sentence-level documents
        for key in timestamps_dict.keys():
            timestamps_dict[key] = list(chain.from_iterable(n * [item] for item, n in zip(timestamps_dict[key], num_sent)))
        # time_slice = df_doc['year'].value_counts().sort_index().tolist()
        # time_slice = np.diff([np.cumsum(num_sent)[n-1] for n in np.cumsum(time_slice)],prepend=0).tolist()
    # elif doc_level == 'whole':
    #     df_documents

    # Repeat each comment's review id once per sentence and flatten, mirroring
    # the timestamp expansion above
    sent_id_ll = [[j] * num_sent[i] for i, j in enumerate(df_doc.index)]
    sent_id = flatten_comprehension(sent_id_ll)

    # Build a new data frame over the m'*N sentence-level documents
    df_doc_out = pd.DataFrame({
        'sentence': df_documents, 'review_id': sent_id,
        'date': timestamps_dict['date'],
        'monthly': timestamps_dict['monthly'],
        'quarterly': timestamps_dict['quarterly'],
        'yearly': timestamps_dict['yearly']})

    return df_documents, timestamps_dict, sent_id, df_doc_out
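# Worked example of the sentence-level expansion above (hypothetical data):
#   comments   = ["Great room. Nice staff.", "Loud."]   -> num_sent = [2, 1]
#   timestamps = ["2020-01", "2020-02"]
#   expanded   = ["2020-01", "2020-01", "2020-02"]      # one entry per sentence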


def create_model_bertopic_booking(n_topics: int = 10):
    sentence_model = SentenceTransformer("thenlper/gte-small")

    # UMAP: 50 neighbors, reduce to 10 dimensions, euclidean distance metric
    umap_model = UMAP(n_neighbors=50, n_components=10,
                      min_dist=0.0, metric='euclidean',
                      low_memory=True,
                      random_state=1)

    cluster_model = HDBSCAN(min_cluster_size=50, metric='euclidean',
                            cluster_selection_method='leaf',
                            # cluster_selection_method='eom',
                            prediction_data=True,
                            leaf_size=20,
                            min_samples=10)

    # cluster_model = AgglomerativeClustering(n_clusters=11)
    vectorizer_model = CountVectorizer(min_df=1, ngram_range=(1, 1), stop_words="english")
    ctfidf_model = ClassTfidfTransformer()
    # representation_model = KeyBERTInspired()

    # The diversity param is the relevance/redundancy trade-off weight in the
    # Maximal Marginal Relevance equation
    representation_model = MaximalMarginalRelevance(diversity=0.7, top_n_words=10)
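    # For reference, MMR scores each candidate keyword w for a topic t roughly as
    #   (1 - diversity) * sim(w, t) - diversity * max_{w' already selected} sim(w, w')
    # so diversity=0.7 favors varied keywords over near-duplicates.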

    # Create model
    topic_model = BERTopic(embedding_model=sentence_model,
                           umap_model=umap_model,
                           hdbscan_model=cluster_model,
                           vectorizer_model=vectorizer_model,
                           ctfidf_model=ctfidf_model,
                           representation_model=representation_model,
                           # zeroshot_topic_list=zeroshot_topic_list,
                           # zeroshot_min_similarity=0.7,
                           nr_topics=n_topics,
                           top_n_words=10,
                           low_memory=True,
                           verbose=True)

    return topic_model


def create_model_bertopic_tripadvisor(n_topics: int = 10):
    sentence_model = SentenceTransformer("thenlper/gte-small")

    # UMAP: 200 neighbors, reduce to 10 dimensions, euclidean distance metric
    umap_model = UMAP(n_neighbors=200, n_components=10,
                      min_dist=0.0, metric='euclidean',
                      low_memory=True,
                      random_state=1)

    cluster_model = HDBSCAN(min_cluster_size=500, metric='euclidean',
                            cluster_selection_method='leaf',
                            prediction_data=True,
                            leaf_size=100,
                            min_samples=10)

    # cluster_model = AgglomerativeClustering(n_clusters=11)
    vectorizer_model = CountVectorizer(min_df=10, ngram_range=(1, 1), stop_words="english")
    ctfidf_model = ClassTfidfTransformer()
    # representation_model = KeyBERTInspired()

    # The diversity param is the relevance/redundancy trade-off weight in the
    # Maximal Marginal Relevance equation (see the note in
    # create_model_bertopic_booking)
    representation_model = MaximalMarginalRelevance(diversity=0.7, top_n_words=10)

    # Create model
    topic_model = BERTopic(embedding_model=sentence_model,
                           umap_model=umap_model,
                           hdbscan_model=cluster_model,
                           vectorizer_model=vectorizer_model,
                           ctfidf_model=ctfidf_model,
                           representation_model=representation_model,
                           # zeroshot_topic_list=zeroshot_topic_list,
                           # zeroshot_min_similarity=0.7,
                           nr_topics=n_topics,
                           top_n_words=10,
                           low_memory=True,
                           verbose=True)

    return topic_model
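# The tripadvisor variant differs from the booking variant only in scale:
# larger UMAP neighborhoods (n_neighbors 200 vs 50), larger HDBSCAN clusters
# (min_cluster_size 500 vs 50, leaf_size 100 vs 20), and a stricter
# CountVectorizer min_df (10 vs 1), presumably reflecting a larger corpus.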


def coherence_score(topic_model, df_documents):
    # Tokenize the documents with the model's own (internal) preprocessing
    # and vectorizer so coherence is computed over the same vocabulary
    cleaned_docs = topic_model._preprocess_text(df_documents)
    vectorizer = topic_model.vectorizer_model
    analyzer = vectorizer.build_analyzer()
    tokens = [analyzer(doc) for doc in cleaned_docs]
    dictionary = corpora.Dictionary(tokens)
    corpus = [dictionary.doc2bow(token) for token in tokens]
    topics = topic_model.get_topics()

    # Keep only topics with more than two non-empty keywords
    topic_words = [
        [word for word, _ in topic_model.get_topic(topic) if word != ""]
        for topic in topics if check_valid(topic_model.get_topic(topic))
    ]

    coherence_model = CoherenceModel(topics=topic_words,
                                     texts=tokens,
                                     corpus=corpus,
                                     dictionary=dictionary,
                                     coherence='c_npmi')
    coherence = coherence_model.get_coherence()
    return coherence
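# c_npmi coherence is bounded in [-1, 1]; higher is better. Swapping
# coherence='c_npmi' above for 'c_v' or 'u_mass' is a one-line change if a
# different gensim coherence metric is preferred.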


def working(args: argparse.Namespace, name_dataset: str):

    ############# Create loggers ################################
    fandc_logger = create_logger_file_and_console()
    file_logger = create_logger_file()
    console_logger = create_logger_console()
    ##############################################################

    ######### Create dataframes for the booking / tripadvisor dataset #####
    fandc_logger.log(logging.INFO, f'Get data from {name_dataset}')
    doc_source = f'en_{name_dataset}.csv'
    list_tmp = prepare_data(doc_source, doc_type, type_framework='pandas')

    (timestamps_dict, df_doc,
     df_documents, df_reviews) = list_tmp

    fandc_logger.log(logging.INFO, f'Got data from {name_dataset} successfully!')
    ####################################################################

    ######### Process data for the booking / tripadvisor dataset #########
    fandc_logger.log(logging.INFO, f'Processing data for {name_dataset} dataset')
    (df_documents, timestamps_dict,
     sent_id, df_doc_out) = processing_data(df_doc, df_documents, timestamps_dict, doc_level, target_col)
    fandc_logger.log(logging.INFO, f'Processed data for {name_dataset} dataset successfully!')
    #######################################################################

    source = f'en_{name_dataset}'
    output_subdir_name = source + f'/bertopic2_non_zeroshot_{args.n_topics}topic_' + doc_type + '_' + doc_level + '_' + doc_time
    output_subdir = os.path.join(output_dir, output_subdir_name)
    if not os.path.exists(output_subdir):
        os.makedirs(output_subdir)

    # Create model
    fandc_logger.log(logging.INFO, f'Create model for {name_dataset} dataset')
    topic_model = create_model_bertopic_booking(args.n_topics) if name_dataset == 'booking' else create_model_bertopic_tripadvisor(args.n_topics)

    # Fit model
    fandc_logger.log(logging.INFO, f'Training model for {name_dataset} dataset')
    fandc_logger.log(logging.INFO, 'Fitting model...')
    t_start = time.time()
    t = time.process_time()
    topic_model = topic_model.fit(df_documents)
    elapsed_time = time.process_time() - t
    t_end = time.time()
    fandc_logger.log(logging.INFO, f'Wall-clock time for the fitting process: {t_end - t_start}\t --- \t CPU time: {elapsed_time}')
    console_logger.log(logging.INFO, 'End of fitting process')
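    # Note: time.time() measures wall-clock time while time.process_time()
    # measures CPU time of the Python process only, so the two durations
    # logged above can differ substantially when GPU work or I/O dominates.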

    topics_save_dir = os.path.join(output_subdir, 'topics_bertopic_' + doc_type + '_' + doc_level + '_' + doc_time)
    topic_model.save(topics_save_dir, serialization="safetensors", save_ctfidf=True, save_embedding_model=True)
    fandc_logger.log(logging.INFO, f'Saved fitted model for {name_dataset} dataset successfully!')

    # Transform model
    t_start = time.time()
    t = time.process_time()
    topics, probs = topic_model.transform(df_documents)
    elapsed_time = time.process_time() - t
    t_end = time.time()
    fandc_logger.log(logging.INFO, f'Wall-clock time for the transform process: {t_end - t_start}\t --- \t CPU time: {elapsed_time}')
    console_logger.log(logging.INFO, 'End of transform process')

    topics_save_dir = os.path.join(output_subdir, 'topics_bertopic_transform_' + doc_type + '_' + doc_level + '_' + doc_time)
    topic_model.save(topics_save_dir, serialization="safetensors", save_ctfidf=True, save_embedding_model=True)
    fandc_logger.log(logging.INFO, f'Saved transformed model for {name_dataset} dataset successfully!')

    ############# Results ###############
    # ***** 1

    # Compute coherence score
    fandc_logger.log(logging.INFO, f'Starting coherence score calculation for {name_dataset} dataset')
    coherence = coherence_score(topic_model, df_documents)
    fandc_logger.log(logging.INFO, f'Coherence score for {name_dataset} dataset: {coherence} with {args.n_topics} topics')

    if args.only_coherence_score == 'no':
        # Get topics
        fandc_logger.log(logging.INFO, f'Get topics for {name_dataset} dataset')
        topic_info = topic_model.get_topic_info()
        topic_info_path_out = os.path.join(output_subdir, 'topic_info_' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
        topic_info.to_csv(topic_info_path_out, encoding='utf-8')
        fandc_logger.log(logging.INFO, f'Saved topic_info for {name_dataset} dataset successfully!')

        # Get the keyword weights for each topic
        fandc_logger.log(logging.INFO, 'Get weights for each topic')
        topic_keyword_weights = topic_model.get_topics(full=True)
        topic_keyword_weights_path_out = os.path.join(output_subdir, 'topic_keyword_weights_' + doc_type + '_' + doc_level + '_' + doc_time + '.json')
        with open(topic_keyword_weights_path_out, 'w', encoding="utf-8") as f:
            # the weights dict holds non-JSON-serializable values, so it is
            # stringified before dumping
            f.write(json.dumps(str(topic_keyword_weights), indent=4, ensure_ascii=False))
        fandc_logger.log(logging.INFO, 'Saved weights for each topic successfully!')

        # (The coherence score was already computed above and is not recomputed here.)

        # Put data into a dataframe
        df_topics = topic_model.get_document_info(df_documents)
        df_doc_out = pd.concat([df_topics, df_doc_out.loc[:, "review_id":]], axis=1)
        df_doc_out_path = os.path.join(output_subdir, 'df_documents_' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
        df_doc_out.to_csv(df_doc_out_path, encoding='utf-8')
        fandc_logger.log(logging.INFO, f'Saved df_doc_out for {name_dataset} dataset successfully!')

        df_doc_path = os.path.join(output_subdir, f'df_docs_{name_dataset}' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
        df_doc.to_csv(df_doc_path, encoding='utf-8')
        fandc_logger.log(logging.INFO, f'Saved df_doc_{name_dataset} for {name_dataset} dataset successfully!')

        # Get params
        model_params = topic_model.get_params()
        model_params_path_txt_out = os.path.join(output_subdir, f'model_params_{name_dataset}' + doc_type + '_' + doc_level + '_' + doc_time + '.txt')
        with open(model_params_path_txt_out, 'w', encoding="utf-8") as f:
            f.write(json.dumps(str(model_params), indent=4, ensure_ascii=False))
        fandc_logger.log(logging.INFO, f'Saved model params for {name_dataset} dataset successfully!')

        # Visualize topics
        fig = topic_model.visualize_topics()
        vis_save_dir = os.path.join(output_subdir, f'bertopic_vis_{name_dataset}' + doc_type + '_' + doc_level + '_' + doc_time + '.html')
        fig.write_html(vis_save_dir)
        fandc_logger.log(logging.INFO, f'Saved topic visualization for {name_dataset} dataset successfully!')

        # Hierarchical topics
        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html
        fandc_logger.log(logging.INFO, 'Starting hierarchical topics...')
        linkage_function = lambda x: sch.linkage(x, 'average', optimal_ordering=True)
        hierarchical_topics = topic_model.hierarchical_topics(df_documents, linkage_function=linkage_function)
        hierarchical_topics_path_out = os.path.join(output_subdir, f'hierarchical_topics_path_out_{name_dataset}' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
        hierarchical_topics.to_csv(hierarchical_topics_path_out, encoding='utf-8')
        fandc_logger.log(logging.INFO, f'Saved hierarchical topics table for {name_dataset} dataset successfully!')

        fig = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
        vis_save_dir = os.path.join(output_subdir, f'bertopic_hierarchy_vis_{name_dataset}' + doc_type + '_' + doc_level + '_' + doc_time + '.html')
        fig.write_html(vis_save_dir)
        fandc_logger.log(logging.INFO, f'Saved hierarchical topic visualization for {name_dataset} dataset successfully!')

        # Dynamic topic modeling
        fandc_logger.log(logging.INFO, 'Starting dynamic topic modeling over timestamps...')
        for key in timestamps_dict.keys():
            topics_over_time = topic_model.topics_over_time(df_documents, timestamps_dict[key])
            fig = topic_model.visualize_topics_over_time(topics_over_time, top_n_topics=10, title=f"Topics over time following {key}")
            fig.show()
            vis_save_dir = os.path.join(output_subdir, f'bertopic_dtm_vis_{name_dataset}' + key + '_' + doc_type + '_' + doc_level + '_' + doc_time + '.html')
            fig.write_html(vis_save_dir)

            topic_dtm_path_out = os.path.join(output_subdir, f'topics_dtm_{name_dataset}' + key + '_' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
            topics_over_time.to_csv(topic_dtm_path_out, encoding='utf-8')
            fandc_logger.log(logging.INFO, f'Saved topics over time for {name_dataset} dataset successfully!')

        # ****** 2
        # Reduce topics and recompute topics over time for each n_topics
        fandc_logger.log(logging.INFO, 'Starting topic reduction and topics over time from 10 to 50...')
        for n_topics in [10, 20, 30, 40, 50]:
            # Deep-copy the fitted model so each reduction starts from the
            # original topics instead of compounding earlier reductions
            topic_model_copy = copy.deepcopy(topic_model)
            topic_model_copy.reduce_topics(df_documents, nr_topics=n_topics)
            fig = topic_model_copy.visualize_topics(title=f"Intertopic Distance Map: {n_topics} topics")
            fig.show()
            vis_save_dir = os.path.join(output_subdir, f'bertopic_reduce_vis_{name_dataset}' + str(n_topics) + '_' + doc_type + '_' + doc_level + '_' + doc_time + '.html')
            fig.write_html(vis_save_dir)

            topic_info = topic_model_copy.get_topic_info()
            topic_info_path_out = os.path.join(output_subdir, f'topic_reduce_info_{name_dataset}' + str(n_topics) + '_' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
            topic_info.to_csv(topic_info_path_out, encoding='utf-8')

            for key in timestamps_dict.keys():
                topics_over_time_ = topic_model_copy.topics_over_time(df_documents, timestamps_dict[key])
                fig = topic_model_copy.visualize_topics_over_time(topics_over_time_, top_n_topics=10, title=f"Topics over time following {key}")
                fig.show()
                vis_save_dir = os.path.join(output_subdir, f'bertopic_reduce_dtm_vis_{name_dataset}' + str(n_topics) + '_' + key + '_' + doc_type + '_' + doc_level + '_' + doc_time + '.html')
                fig.write_html(vis_save_dir)

                topic_dtm_path_out = os.path.join(output_subdir, f'topics_reduce_dtm_{name_dataset}' + str(n_topics) + '_' + key + '_' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
                topics_over_time_.to_csv(topic_dtm_path_out, encoding='utf-8')

        fandc_logger.log(logging.INFO, f'Saved reduced topics and topics over time for {name_dataset} dataset successfully!')
    ###################################


if __name__ == "__main__":
    args = init_args()
    if args.train_both == 'yes':
        working(args, 'booking')
        working(args, 'tripadvisor')
    else:
        working(args, args.name_dataset)
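
# Example invocations (argument values are illustrative):
#   python my_topic_modeling.py --n_topics 10 --train_both yes --only_coherence_score yes
#   python my_topic_modeling.py --n_topics 20 --train_both no --name_dataset tripadvisor --only_coherence_score no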