Spaces:
Running
Running
Allow loading of external topic labels. A few visualisation modifications.
Browse files
- app.py +19 -13
- funcs/bertopic_vis_documents.py +9 -3
app.py
CHANGED
|
@@ -253,13 +253,16 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
|
|
| 253 |
else:
|
| 254 |
print("Topic model created.")
|
| 255 |
|
|
|
|
| 256 |
if not custom_labels_df.empty:
|
| 257 |
-
#
|
|
|
|
| 258 |
|
| 259 |
-
|
| 260 |
-
#
|
| 261 |
|
| 262 |
-
|
|
|
|
| 263 |
|
| 264 |
# Outputs
|
| 265 |
output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
|
|
@@ -384,7 +387,7 @@ def represent_topics(topic_model, docs, embeddings_out, data_file_name_no_ext, l
|
|
| 384 |
|
| 385 |
return output_text, output_list, topic_model
|
| 386 |
|
| 387 |
-
def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode, embeddings_out, in_label, in_colnames, sample_prop, visualisation_type_radio, random_seed, progress=gr.Progress()):
|
| 388 |
|
| 389 |
progress(0, desc= "Preparing data for visualisation")
|
| 390 |
|
|
@@ -416,12 +419,13 @@ def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode
|
|
| 416 |
|
| 417 |
topic_dets = topic_model.get_topic_info()
|
| 418 |
|
| 419 |
-
# Replace original labels with
|
| 420 |
-
if
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
|
|
|
| 425 |
|
| 426 |
# Pre-reduce embeddings for visualisation purposes
|
| 427 |
if low_resource_mode == "No":
|
|
@@ -560,9 +564,11 @@ with block:
|
|
| 560 |
|
| 561 |
with gr.Tab("Visualise"):
|
| 562 |
with gr.Row():
|
| 563 |
-
in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in output visualisations.")
|
| 564 |
visualisation_type_radio = gr.Radio(label="Visualisation type", choices=["Topic document graph", "Hierarchical view"])
|
|
|
|
| 565 |
sample_slide = gr.Slider(minimum = 0.01, maximum = 1, value = 0.1, step = 0.01, label = "Proportion of data points to show on output visualisations.")
|
|
|
|
|
|
|
| 566 |
plot_btn = gr.Button("Visualise topic model")
|
| 567 |
with gr.Row():
|
| 568 |
vis_output_single_text = gr.Textbox(label="Visualisation output text")
|
|
@@ -595,7 +601,7 @@ with block:
|
|
| 595 |
|
| 596 |
save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file])
|
| 597 |
|
| 598 |
-
plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")
|
| 599 |
|
| 600 |
#block.load(read_logs, None, logs, every=5)
|
| 601 |
|
|
|
|
| 253 |
else:
|
| 254 |
print("Topic model created.")
|
| 255 |
|
| 256 |
+
# Replace current topic labels if new ones loaded in
|
| 257 |
if not custom_labels_df.empty:
|
| 258 |
+
#custom_label_list = list(custom_labels_df.iloc[:,0])
|
| 259 |
+
custom_label_list = [label.replace("\n", "") for label in custom_labels_df.iloc[:,0]]
|
| 260 |
|
| 261 |
+
topic_model.set_topic_labels(custom_label_list)
|
| 262 |
+
#topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model=vectoriser_model)
|
| 263 |
|
| 264 |
+
|
| 265 |
+
print("Custom topics: ", topic_model.custom_labels_)
|
| 266 |
|
| 267 |
# Outputs
|
| 268 |
output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
|
|
|
|
| 387 |
|
| 388 |
return output_text, output_list, topic_model
|
| 389 |
|
| 390 |
+
def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode, embeddings_out, in_label, in_colnames, legend_label, sample_prop, visualisation_type_radio, random_seed, progress=gr.Progress()):
|
| 391 |
|
| 392 |
progress(0, desc= "Preparing data for visualisation")
|
| 393 |
|
|
|
|
| 419 |
|
| 420 |
topic_dets = topic_model.get_topic_info()
|
| 421 |
|
| 422 |
+
# Replace original labels with another representation if specified
|
| 423 |
+
if legend_label:
|
| 424 |
+
topic_dets = topic_model.get_topics(full=True)
|
| 425 |
+
if legend_label in topic_dets:
|
| 426 |
+
labels = [topic_dets[legend_label].values()]
|
| 427 |
+
labels = [str(v) for v in labels]
|
| 428 |
+
topic_model.set_topic_labels(labels)
|
| 429 |
|
| 430 |
# Pre-reduce embeddings for visualisation purposes
|
| 431 |
if low_resource_mode == "No":
|
|
|
|
| 564 |
|
| 565 |
with gr.Tab("Visualise"):
|
| 566 |
with gr.Row():
|
|
|
|
| 567 |
visualisation_type_radio = gr.Radio(label="Visualisation type", choices=["Topic document graph", "Hierarchical view"])
|
| 568 |
+
in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in output visualisations.")
|
| 569 |
sample_slide = gr.Slider(minimum = 0.01, maximum = 1, value = 0.1, step = 0.01, label = "Proportion of data points to show on output visualisations.")
|
| 570 |
+
legend_label = gr.Textbox(label="Custom legend column (optional, any column from the topic details output)", visible=False)
|
| 571 |
+
|
| 572 |
plot_btn = gr.Button("Visualise topic model")
|
| 573 |
with gr.Row():
|
| 574 |
vis_output_single_text = gr.Textbox(label="Visualisation output text")
|
|
|
|
| 601 |
|
| 602 |
save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file])
|
| 603 |
|
| 604 |
+
plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, legend_label, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")
|
| 605 |
|
| 606 |
#block.load(read_logs, None, logs, every=5)
|
| 607 |
|
funcs/bertopic_vis_documents.py
CHANGED
|
@@ -160,10 +160,14 @@ def visualize_documents_custom(topic_model,
|
|
| 160 |
names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
|
| 161 |
names = [label if len(label) < 30 else label[:27] + "..." for label in names]
|
| 162 |
elif topic_model.custom_labels_ is not None and custom_labels:
|
|
|
|
| 163 |
names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
|
| 164 |
else:
|
|
|
|
| 165 |
names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
|
| 166 |
|
|
|
|
|
|
|
| 167 |
# Visualize
|
| 168 |
fig = go.Figure()
|
| 169 |
|
|
@@ -192,6 +196,8 @@ def visualize_documents_custom(topic_model,
|
|
| 192 |
|
| 193 |
# Selected topics
|
| 194 |
for name, topic in zip(names, unique_topics):
|
|
|
|
|
|
|
| 195 |
if topic in topics and topic != -1:
|
| 196 |
selection = df.loc[df.topic == topic, :]
|
| 197 |
selection["text"] = ""
|
|
@@ -658,7 +664,7 @@ def visualize_barchart_custom(topic_model,
|
|
| 658 |
subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topics]
|
| 659 |
else:
|
| 660 |
subplot_titles = [f"Topic {topic}" for topic in topics]
|
| 661 |
-
columns =
|
| 662 |
rows = int(np.ceil(len(topics) / columns))
|
| 663 |
fig = make_subplots(rows=rows,
|
| 664 |
cols=columns,
|
|
@@ -697,14 +703,14 @@ def visualize_barchart_custom(topic_model,
|
|
| 697 |
'xanchor': 'center',
|
| 698 |
'yanchor': 'top',
|
| 699 |
'font': dict(
|
| 700 |
-
size=
|
| 701 |
color="Black")
|
| 702 |
},
|
| 703 |
width=width*4,
|
| 704 |
height=height*rows if rows > 1 else height * 1.3,
|
| 705 |
hoverlabel=dict(
|
| 706 |
bgcolor="white",
|
| 707 |
-
font_size=
|
| 708 |
font_family="Rockwell"
|
| 709 |
),
|
| 710 |
)
|
|
|
|
| 160 |
names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
|
| 161 |
names = [label if len(label) < 30 else label[:27] + "..." for label in names]
|
| 162 |
elif topic_model.custom_labels_ is not None and custom_labels:
|
| 163 |
+
print("Using custom labels: ", topic_model.custom_labels_)
|
| 164 |
names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
|
| 165 |
else:
|
| 166 |
+
print("Not using custom labels")
|
| 167 |
names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
|
| 168 |
|
| 169 |
+
print(names)
|
| 170 |
+
|
| 171 |
# Visualize
|
| 172 |
fig = go.Figure()
|
| 173 |
|
|
|
|
| 196 |
|
| 197 |
# Selected topics
|
| 198 |
for name, topic in zip(names, unique_topics):
|
| 199 |
+
#print(name)
|
| 200 |
+
#print(topic)
|
| 201 |
if topic in topics and topic != -1:
|
| 202 |
selection = df.loc[df.topic == topic, :]
|
| 203 |
selection["text"] = ""
|
|
|
|
| 664 |
subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topics]
|
| 665 |
else:
|
| 666 |
subplot_titles = [f"Topic {topic}" for topic in topics]
|
| 667 |
+
columns = 3
|
| 668 |
rows = int(np.ceil(len(topics) / columns))
|
| 669 |
fig = make_subplots(rows=rows,
|
| 670 |
cols=columns,
|
|
|
|
| 703 |
'xanchor': 'center',
|
| 704 |
'yanchor': 'top',
|
| 705 |
'font': dict(
|
| 706 |
+
size=14,
|
| 707 |
color="Black")
|
| 708 |
},
|
| 709 |
width=width*4,
|
| 710 |
height=height*rows if rows > 1 else height * 1.3,
|
| 711 |
hoverlabel=dict(
|
| 712 |
bgcolor="white",
|
| 713 |
+
font_size=14,
|
| 714 |
font_family="Rockwell"
|
| 715 |
),
|
| 716 |
)
|