"""Train and evaluate BERTopic topic models on English hotel reviews."""

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from os.path import join
from pathlib import Path

import pandas as pd
import numpy as np
import sklearn as sk
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
import regex as re
from scipy.cluster import hierarchy as sch

import datetime
import time
import timeit
import json
import pickle

import copy
import random
from itertools import chain

import logging
import sys
import argparse

import nltk
nltk.download('wordnet')
nltk.download('punkt')

from textblob import TextBlob, Word
from textblob.wordnet import Synset, VERB

from bertopic import BERTopic
from bertopic.vectorizers import ClassTfidfTransformer
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sentence_transformers import SentenceTransformer

# GPU-accelerated UMAP and HDBSCAN from RAPIDS cuML.
from cuml.cluster import HDBSCAN
from cuml.manifold import UMAP

import gensim.corpora as corpora
from gensim.models.coherencemodel import CoherenceModel

import torch
from GPUtil import showUtilization as gpu_usage
from numba import cuda

import polars as pl  # used by the 'polars' code path in prepare_data()
import pretty_errors

pretty_errors.configure(
    display_timestamp=True,
    timestamp_function=lambda: datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
)

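# Note: if no GPU is available, the cuML UMAP/HDBSCAN imports above can be
# swapped for the CPU packages (umap-learn and hdbscan), which accept the same
# keyword arguments used in the model builders below.
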
working_dir = os.path.abspath(os.path.join("/workspace", "TopicModelingRepo"))
data_dir = os.path.join(working_dir, 'data')
lib_dir = os.path.join(working_dir, 'libs')
outer_output_dir = os.path.join(working_dir, 'outputs')

# One output folder per run date, e.g. outputs/2024_01_15/.
output_dir_name = time.strftime('%Y_%m_%d')
output_dir = os.path.join(outer_output_dir, output_dir_name)
os.makedirs(output_dir, exist_ok=True)

stopwords_path = os.path.join(data_dir, 'vietnamese_stopwords_dash.txt')

# Run configuration shared by all functions below.
doc_time = '2024_Jan_15'
doc_type = 'reviews'
doc_level = 'sentence'
target_col = 'normalized_content'

def free_gpu_cache():
    """Print GPU utilization, empty the PyTorch cache, and reset the CUDA device."""
    print("Initial GPU Usage")
    gpu_usage()

    torch.cuda.empty_cache()

    # Close and reopen the device to release memory held outside the PyTorch cache.
    cuda.select_device(0)
    cuda.close()
    cuda.select_device(0)

    print("GPU Usage after emptying the cache")
    gpu_usage()

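# Caveat: cuda.close() destroys the active CUDA context, so free_gpu_cache()
# should only run between independent jobs, never while a fitted cuML model is
# still needed.
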
def create_logger_file_and_console(path_file):
    """Logger that writes DEBUG+ records to a file and INFO+ records to stdout."""
    logger = logging.getLogger('automated_testing.file_and_console')
    logger.setLevel(logging.DEBUG)
    # Start from a clean handler list so repeated calls do not duplicate output.
    logger.handlers.clear()

    fileh = logging.FileHandler(path_file, mode='a')
    fileh.setLevel(logging.DEBUG)

    consoleh = logging.StreamHandler(stream=sys.stdout)
    consoleh.setLevel(logging.INFO)

    formatter = logging.Formatter('[%(asctime)s] %(levelname)8s --- %(message)s ', datefmt='%H:%M:%S')
    fileh.setFormatter(formatter)
    consoleh.setFormatter(formatter)

    logger.addHandler(fileh)
    logger.addHandler(consoleh)

    return logger


def create_logger_file(path_file):
    """Logger that writes INFO+ records to a file only."""
    logger = logging.getLogger('automated_testing.file')
    logger.setLevel(logging.INFO)
    logger.handlers.clear()

    fileh = logging.FileHandler(path_file, mode='a')
    fileh.setLevel(logging.INFO)

    formatter = logging.Formatter('[%(asctime)s] %(levelname)8s --- %(message)s ', datefmt='%H:%M:%S')
    fileh.setFormatter(formatter)

    logger.addHandler(fileh)

    return logger


def create_logger_console():
    """Logger that writes INFO+ records to stdout only."""
    logger = logging.getLogger('automated_testing.console')
    logger.setLevel(logging.INFO)
    logger.handlers.clear()

    consoleh = logging.StreamHandler(stream=sys.stdout)
    consoleh.setLevel(logging.INFO)

    formatter = logging.Formatter('[%(asctime)s] %(levelname)8s --- %(message)s ', datefmt='%H:%M:%S')
    consoleh.setFormatter(formatter)

    logger.addHandler(consoleh)

    return logger

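# Usage sketch (illustrative path): the three factories return distinct loggers,
# so callers can route a message to file+console, file only, or console only:
#
#     logger = create_logger_file_and_console('outputs/info.log')
#     logger.info('goes to both the console and the log file')
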
def init_args():
    parser = argparse.ArgumentParser()

    # Every flag has a usable default, so none are marked required.
    parser.add_argument(
        "--n_topics",
        type=int,
        default=10,
        help="Number of topics for topic modeling.",
    )

    parser.add_argument(
        "--name_dataset",
        default="booking",
        type=str,
        help="The name of the dataset, selected from: [booking, tripadvisor]",
    )

    parser.add_argument(
        "--train_both",
        default="yes",
        type=str,
        help="Train on both booking and tripadvisor ('yes') or only on --name_dataset ('no').",
    )

    parser.add_argument(
        "--only_coherence_score",
        default="yes",
        type=str,
        help="Only train the models to calculate coherence scores, skipping all exports.",
    )

    parser.add_argument(
        "--need_reduce_n_topics",
        default="yes",
        type=str,
        help="Reduce the number of topics and show topic modeling over timestamps.",
    )

    args = parser.parse_args()

    return args

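# Example invocation (script name is illustrative):
#   python train_bertopic.py --n_topics 10 --train_both yes \
#       --only_coherence_score no --need_reduce_n_topics yes
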
def check_valid(list_topics):
    """Return True if a topic has more than two non-empty keywords."""
    count = 0
    for topic in list_topics:
        if topic[0] != '':
            count += 1
    return count > 2

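# e.g. check_valid([('room', 0.12), ('staff', 0.08), ('clean', 0.05)]) is True,
# while a topic padded out with empty-string keywords is rejected.
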
def prepare_data(doc_source, doc_type, type_framework='pandas'):
    """Load a review CSV and derive date/monthly/quarterly/yearly timestamp columns."""
    name_file = doc_source.split('.')[0]
    out_dir = os.path.join(output_dir, name_file)
    os.makedirs(out_dir, exist_ok=True)

    date_col = 'Date'
    df_reviews_path = os.path.join(data_dir, doc_source)

    if type_framework == 'pandas':
        df_reviews = pd.read_csv(df_reviews_path, lineterminator='\n', encoding='utf-8')
        # Keep only rows with a valid year and English content.
        df_reviews = df_reviews.loc[df_reviews['year'] > 0]
        df_reviews = df_reviews.loc[df_reviews['language'] == 'English']

        if doc_type == 'reviews':
            df_doc = df_reviews

            parsed_dates = pd.to_datetime(df_doc[date_col], dayfirst=False, errors='coerce')
            # Snap each review date to progressively coarser periods, kept as strings.
            df_doc['dates'] = parsed_dates.dt.to_period('M').dt.strftime('%Y-%m-%d')
            df_doc['dates_yearly'] = parsed_dates.dt.to_period('Y').dt.strftime('%Y')
            df_doc['dates_quarterly'] = parsed_dates.dt.to_period('Q').dt.strftime('%YQ%q')
            df_doc['dates_monthly'] = parsed_dates.dt.to_period('M').dt.strftime('%Y-%m')

    elif type_framework == 'polars':
        # polars equivalent (assumes a recent polars with pl.format and
        # Expr.str.to_datetime); polars has no pandas-style periods, so each
        # granularity is built from datetime expressions instead.
        df_reviews = pl.read_csv(df_reviews_path)
        df_reviews = df_reviews.filter(pl.col("year") > 0)
        df_reviews = df_reviews.filter(pl.col('language') == 'English')

        if doc_type == 'reviews':
            df_doc = df_reviews

            parsed = pl.col(date_col).str.to_datetime(strict=False)
            df_doc = df_doc.with_columns(
                parsed.dt.truncate('1mo').dt.strftime('%Y-%m-%d').alias('dates'),
                parsed.dt.strftime('%Y').alias('dates_yearly'),
                # polars strftime has no quarter directive, so build e.g. "2024Q1" by hand.
                pl.format('{}Q{}', parsed.dt.year(), parsed.dt.quarter()).alias('dates_quarterly'),
                parsed.dt.strftime('%Y-%m').alias('dates_monthly'),
            )

    # Timestamp lists aligned with the rows of df_doc, one list per granularity.
    timestamps_dict = dict()
    timestamps_dict['yearly'] = df_doc['dates_yearly'].to_list()
    timestamps_dict['quarterly'] = df_doc['dates_quarterly'].to_list()
    timestamps_dict['monthly'] = df_doc['dates_monthly'].to_list()
    timestamps_dict['date'] = df_doc['dates'].to_list()

    target_col = 'normalized_content'
    df_documents = df_doc[target_col]

    return (timestamps_dict, df_doc, df_documents, df_reviews)

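# Assumption: the input CSV provides at least the columns 'Date', 'year',
# 'language', and 'normalized_content'; rows failing the year/language filters
# are dropped before the timestamp columns are derived.
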
def flatten_comprehension(matrix):
    """Flatten a list of lists into a single list."""
    return [item for row in matrix for item in row]

def processing_data(df_doc, df_documents, timestamps_dict, doc_level, target_col):
    """Split reviews into sentences and repeat per-review timestamps per sentence."""
    if doc_level == 'sentence':
        # Tokenize every review into sentences.
        ll_sent = [[str(sent) for sent in nltk.sent_tokenize(row, language='english')] for row in df_doc[target_col]]

        # Number of sentences per review, used to repeat the per-review metadata.
        num_sent = [len(x) for x in ll_sent]

        # One sentence per document.
        df_documents = pd.Series(flatten_comprehension(ll_sent))

        # Repeat each review's timestamp once per sentence of that review.
        for key in timestamps_dict.keys():
            timestamps_dict[key] = list(chain.from_iterable(n * [item] for item, n in zip(timestamps_dict[key], num_sent)))

        # Map every sentence back to the index of the review it came from.
        sent_id_ll = [[j] * num_sent[i] for i, j in enumerate(df_doc.index)]
        sent_id = flatten_comprehension(sent_id_ll)

        df_doc_out = pd.DataFrame({
            'sentence': df_documents, 'review_id': sent_id,
            'date': timestamps_dict['date'],
            'monthly': timestamps_dict['monthly'],
            'quarterly': timestamps_dict['quarterly'],
            'yearly': timestamps_dict['yearly']})
    else:
        raise ValueError(f"Unsupported doc_level: {doc_level!r} (only 'sentence' is implemented)")

    return df_documents, timestamps_dict, sent_id, df_doc_out

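# Example: a review "Great room. Bad wifi." dated 2023-05 becomes two
# sentence-level documents, each carrying the source row's review_id and the
# timestamp 2023-05, so sentence topics can still be aggregated per review or
# per period.
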
def create_model_bertopic_booking(n_topics: int = 10):
    """BERTopic pipeline tuned for the (smaller) booking dataset."""
    sentence_model = SentenceTransformer("thenlper/gte-small")

    # Dimensionality reduction; a fixed random_state keeps runs reproducible.
    umap_model = UMAP(n_neighbors=50, n_components=10,
                      min_dist=0.0, metric='euclidean',
                      low_memory=True,
                      random_state=1)

    # Density-based clustering on the reduced embeddings.
    cluster_model = HDBSCAN(min_cluster_size=50, metric='euclidean',
                            cluster_selection_method='leaf',
                            prediction_data=True,
                            leaf_size=20,
                            min_samples=10)

    # Bag-of-words representation and class-based TF-IDF for topic keywords.
    vectorizer_model = CountVectorizer(min_df=1, ngram_range=(1, 1), stop_words="english")
    ctfidf_model = ClassTfidfTransformer()

    # MMR re-ranks keywords to reduce redundancy within a topic.
    representation_model = MaximalMarginalRelevance(diversity=0.7, top_n_words=10)

    topic_model = BERTopic(embedding_model=sentence_model,
                           umap_model=umap_model,
                           hdbscan_model=cluster_model,
                           vectorizer_model=vectorizer_model,
                           ctfidf_model=ctfidf_model,
                           representation_model=representation_model,
                           nr_topics=n_topics,
                           top_n_words=10,
                           low_memory=True,
                           verbose=True)

    return topic_model

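# Design note: 'leaf' cluster selection makes HDBSCAN return many fine-grained
# clusters (rather than the default 'eom' merging), and BERTopic's nr_topics
# then reduces them to the requested count.
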
def create_model_bertopic_tripadvisor(n_topics: int = 10):
    """BERTopic pipeline tuned for the (larger) tripadvisor dataset."""
    sentence_model = SentenceTransformer("thenlper/gte-small")

    umap_model = UMAP(n_neighbors=200, n_components=10,
                      min_dist=0.0, metric='euclidean',
                      low_memory=True,
                      random_state=1)

    cluster_model = HDBSCAN(min_cluster_size=500, metric='euclidean',
                            cluster_selection_method='leaf',
                            prediction_data=True,
                            leaf_size=100,
                            min_samples=10)

    vectorizer_model = CountVectorizer(min_df=10, ngram_range=(1, 1), stop_words="english")
    ctfidf_model = ClassTfidfTransformer()

    representation_model = MaximalMarginalRelevance(diversity=0.7, top_n_words=10)

    topic_model = BERTopic(embedding_model=sentence_model,
                           umap_model=umap_model,
                           hdbscan_model=cluster_model,
                           vectorizer_model=vectorizer_model,
                           ctfidf_model=ctfidf_model,
                           representation_model=representation_model,
                           nr_topics=n_topics,
                           top_n_words=10,
                           low_memory=True,
                           verbose=True)

    return topic_model

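# The tripadvisor variant differs from the booking one only in scale:
# min_cluster_size 500 vs 50, n_neighbors 200 vs 50, leaf_size 100 vs 20,
# and min_df 10 vs 1, reflecting the larger corpus.
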
def coherence_score(topic_model, df_documents):
    """Compute c_npmi coherence over the model's topics using gensim."""
    # Re-tokenize the documents exactly as BERTopic's vectorizer saw them.
    cleaned_docs = topic_model._preprocess_text(df_documents)
    vectorizer = topic_model.vectorizer_model
    analyzer = vectorizer.build_analyzer()
    tokens = [analyzer(doc) for doc in cleaned_docs]
    dictionary = corpora.Dictionary(tokens)
    corpus = [dictionary.doc2bow(token) for token in tokens]
    topics = topic_model.get_topics()

    # Keep only topics with more than two non-empty keywords.
    topic_words = [
        [word for word, _ in topic_model.get_topic(topic) if word != ""]
        for topic in topics if check_valid(topic_model.get_topic(topic))
    ]

    coherence_model = CoherenceModel(topics=topic_words,
                                     texts=tokens,
                                     corpus=corpus,
                                     dictionary=dictionary,
                                     coherence='c_npmi')
    coherence = coherence_model.get_coherence()
    return coherence

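# Caveat: _preprocess_text is a private BERTopic helper; this mirrors the
# commonly shared gensim-coherence recipe for BERTopic and may break if the
# library's internals change. c_npmi lies in [-1, 1], higher is better.
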
def working(args: argparse.Namespace, name_dataset: str):
    """Train, evaluate, and export a BERTopic model for one dataset."""

    source = f'en_{name_dataset}'
    output_subdir_name = source + f'/bertopic2_non_zeroshot_{args.n_topics}topic_' + doc_type + '_' + doc_level + '_' + doc_time
    output_subdir = os.path.join(output_dir, output_subdir_name)
    os.makedirs(output_subdir, exist_ok=True)

    info_log_out = os.path.join(output_subdir, 'info.log')

    fandc_logger = create_logger_file_and_console(info_log_out)
    file_logger = create_logger_file(info_log_out)
    console_logger = create_logger_console()

    fandc_logger.info(f'STARTING TOPIC MODEL FOR {name_dataset} dataset')
    fandc_logger.info(f'Getting data from {name_dataset}')
    doc_source = f'en_{name_dataset}.csv'
    (timestamps_dict, df_doc,
     df_documents, df_reviews) = prepare_data(doc_source, doc_type, type_framework='pandas')
    fandc_logger.info(f'Got data from {name_dataset} successfully!')

    fandc_logger.info(f'Processing data for {name_dataset} dataset')
    (df_documents, timestamps_dict,
     sent_id, df_doc_out) = processing_data(df_doc, df_documents, timestamps_dict, doc_level, target_col)
    fandc_logger.info(f'Processed data for {name_dataset} dataset successfully!')

    fandc_logger.info(f'Creating model for {name_dataset} dataset')
    # Pick the hyperparameter set tuned for this dataset.
    if name_dataset == 'tripadvisor':
        topic_model = create_model_bertopic_tripadvisor(args.n_topics)
    else:
        topic_model = create_model_bertopic_booking(args.n_topics)

    fandc_logger.info(f'Training model for {name_dataset} dataset')
    fandc_logger.info('Fitting model...')
    t_start = time.time()
    t = time.process_time()
    topic_model = topic_model.fit(df_documents)
    elapsed_time = time.process_time() - t
    t_end = time.time()
    fandc_logger.info(f'Wall time for fitting: {t_end - t_start}\t --- \t CPU time: {elapsed_time}')
    console_logger.info('End of fitting process')

    topics_save_dir = os.path.join(output_subdir, 'topics_bertopic_' + doc_type + '_' + doc_level + '_' + doc_time)
    topic_model.save(topics_save_dir, serialization="safetensors", save_ctfidf=True, save_embedding_model=True)
    fandc_logger.info(f'Saved fitted model for {name_dataset} dataset successfully!')

    t_start = time.time()
    t = time.process_time()
    topics, probs = topic_model.transform(df_documents)
    elapsed_time = time.process_time() - t
    t_end = time.time()
    fandc_logger.info(f'Wall time for transform: {t_end - t_start}\t --- \t CPU time: {elapsed_time}')
    console_logger.info('End of transform process')

    topics_save_dir = os.path.join(output_subdir, 'topics_bertopic_transform_' + doc_type + '_' + doc_level + '_' + doc_time)
    topic_model.save(topics_save_dir, serialization="safetensors", save_ctfidf=True, save_embedding_model=True)
    fandc_logger.info(f'Saved model after transform for {name_dataset} dataset successfully!')

    fandc_logger.info(f'Starting coherence score calculation for {name_dataset} dataset')
    coherence = coherence_score(topic_model, df_documents)
    fandc_logger.info(f'Coherence score for {name_dataset} dataset: {coherence} with {args.n_topics} topics')

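    # A sketch of intended use: sweep --n_topics across runs and compare the
    # logged c_npmi scores to choose a topic count; --only_coherence_score yes
    # skips the export steps below for exactly that workflow.
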
    if args.only_coherence_score == 'no':
        fandc_logger.info(f'Getting topics for {name_dataset} dataset')
        topic_info = topic_model.get_topic_info()
        topic_info_path_out = os.path.join(output_subdir, 'topic_info_' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
        topic_info.to_csv(topic_info_path_out, encoding='utf-8')
        fandc_logger.info(f'Saved topic_info for {name_dataset} dataset successfully!')

        fandc_logger.info('Getting keyword weights for each topic')
        topic_keyword_weights = topic_model.get_topics(full=True)
        topic_keyword_weights_path_out = os.path.join(output_subdir, 'topic_keyword_weights_' + doc_type + '_' + doc_level + '_' + doc_time + '.json')
        with open(topic_keyword_weights_path_out, 'w', encoding="utf-8") as f:
            # default=str covers numpy scalars that json cannot serialize natively.
            f.write(json.dumps(topic_keyword_weights, indent=4, ensure_ascii=False, default=str))
        fandc_logger.info('Saved keyword weights for each topic successfully!')

        # Attach topic assignments back to the sentence-level dataframe.
        df_topics = topic_model.get_document_info(df_documents)
        df_doc_out = pd.concat([df_topics, df_doc_out.loc[:, "review_id":]], axis=1)
        df_doc_out_path = os.path.join(output_subdir, 'df_documents_' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
        df_doc_out.to_csv(df_doc_out_path, encoding='utf-8')
        fandc_logger.info(f'Saved df_doc_out for {name_dataset} dataset successfully!')

        df_doc_path = os.path.join(output_subdir, f'df_docs_{name_dataset}' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
        df_doc.to_csv(df_doc_path, encoding='utf-8')
        fandc_logger.info(f'Saved df_doc_{name_dataset} for {name_dataset} dataset successfully!')

        model_params = topic_model.get_params()
        model_params_path_txt_out = os.path.join(output_subdir, f'model_params_{name_dataset}' + doc_type + '_' + doc_level + '_' + doc_time + '.txt')
        with open(model_params_path_txt_out, 'w', encoding="utf-8") as f:
            f.write(json.dumps(model_params, indent=4, ensure_ascii=False, default=str))
        fandc_logger.info(f'Saved model params for {name_dataset} dataset successfully!')

        # Intertopic distance map.
        fig = topic_model.visualize_topics()
        vis_save_dir = os.path.join(output_subdir, f'bertopic_vis_{name_dataset}' + doc_type + '_' + doc_level + '_' + doc_time + '.html')
        fig.write_html(vis_save_dir)
        fandc_logger.info(f'Saved topic visualization for {name_dataset} dataset successfully!')

        # Hierarchical topic structure via average-linkage clustering.
        fandc_logger.info('Starting hierarchical topics...')
        linkage_function = lambda x: sch.linkage(x, 'average', optimal_ordering=True)
        hierarchical_topics = topic_model.hierarchical_topics(df_documents, linkage_function=linkage_function)
        hierarchical_topics_path_out = os.path.join(output_subdir, f'hierarchical_topics_path_out_{name_dataset}' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
        hierarchical_topics.to_csv(hierarchical_topics_path_out, encoding='utf-8')
        fandc_logger.info(f'Saved hierarchical topics table for {name_dataset} dataset successfully!')

        fig = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
        vis_save_dir = os.path.join(output_subdir, f'bertopic_hierarchy_vis_{name_dataset}' + doc_type + '_' + doc_level + '_' + doc_time + '.html')
        fig.write_html(vis_save_dir)
        fandc_logger.info(f'Saved hierarchical topics visualization for {name_dataset} dataset successfully!')

        # Dynamic topic modeling at every timestamp granularity.
        fandc_logger.info('Starting dynamic topic modeling over timestamps...')
        for key in timestamps_dict.keys():
            topics_over_time = topic_model.topics_over_time(df_documents, timestamps_dict[key])
            fig = topic_model.visualize_topics_over_time(topics_over_time, top_n_topics=10, title=f"Topics over time following {key}")
            fig.show()
            vis_save_dir = os.path.join(output_subdir, f'bertopic_dtm_vis_{name_dataset}' + key + '_' + doc_type + '_' + doc_level + '_' + doc_time + '.html')
            fig.write_html(vis_save_dir)

            topic_dtm_path_out = os.path.join(output_subdir, f'topics_dtm_{name_dataset}' + key + '_' + doc_type + '_' + doc_level + '_' + doc_time + '.csv')
            topics_over_time.to_csv(topic_dtm_path_out, encoding='utf-8')
            fandc_logger.info(f'Saved topics over time ({key}) for {name_dataset} dataset successfully!')

    fandc_logger.info(f'FINISHED TRAINING TOPIC MODEL for {name_dataset} dataset\n')

if __name__ == "__main__":
    args = init_args()
    if args.train_both == 'yes':
        working(args, 'booking')
        working(args, 'tripadvisor')
    else:
        working(args, args.name_dataset)

    free_gpu_cache()