lynn-twinkl committed on
Commit
9b4c514
·
1 Parent(s): c41f427

added function to attach topics to original df

Browse files
Files changed (1) hide show
  1. src/models/topic_modeling_pipeline.py +21 -15
src/models/topic_modeling_pipeline.py CHANGED
@@ -1,4 +1,5 @@
1
  import openai
 
2
  import numpy as np
3
  import streamlit as st
4
  import re
@@ -164,19 +165,24 @@ def bertopic_model(docs, embeddings, _embedding_model, _umap_model, _hdbscan_mod
164
  # TOPIC TO DATAFRAME MAPPING
165
  #################################
166
 
167
def update_df_with_topics(df, mapping, sentence_topics, topic_label_map):
    """Return a copy of *df* with a ``'Topics'`` column of joined topic labels.

    Args:
        df: DataFrame whose index values are the row identifiers that
            appear in *mapping*.
        mapping: Iterable giving, for each sentence, the index value of the
            *df* row that sentence belongs to.
        sentence_topics: Iterable of topic ids, paired with *mapping* by
            position. The id ``-1`` is skipped.
        topic_label_map: Dict mapping topic id to a label. Ids missing from
            the dict fall back to ``str(topic_id)``.

    Returns:
        A copy of *df* with a ``'Topics'`` column holding the row's distinct
        labels, sorted and joined with ``", "``. Rows with no topics get an
        empty string.

    Raises:
        IndexError: If *sentence_topics* has fewer items than *mapping*.
    """
    # Materialise both inputs so they are paired strictly by position, even
    # when a pandas Series with a non-default index is passed in.
    row_ids = list(mapping)
    topic_ids = list(sentence_topics)
    if len(topic_ids) < len(row_ids):
        raise IndexError(
            "sentence_topics has fewer items than mapping "
            f"({len(topic_ids)} < {len(row_ids)})"
        )

    # Collect *labels* (not ids) per row so that two topic ids sharing the
    # same label do not produce a duplicated entry in the joined string.
    labels_by_row = {}
    for row_idx, topic in zip(row_ids, topic_ids):
        if topic == -1:
            continue
        # str() keeps the join below safe when a label is not a string.
        label = str(topic_label_map.get(topic, topic))
        labels_by_row.setdefault(row_idx, set()).add(label)

    updated_df = df.copy()
    updated_df['Topics'] = [
        ", ".join(sorted(labels_by_row.get(row_idx, ())))
        for row_idx in updated_df.index
    ]
    return updated_df
 
 
 
 
 
182
 
 
1
  import openai
2
+ import pandas as pd
3
  import numpy as np
4
  import streamlit as st
5
  import re
 
165
  # TOPIC TO DATAFRAME MAPPING
166
  #################################
167
 
168
+
169
def attach_topics(
    df, mappings, sentence_topics, label_map, col="topics", drop_outlier=True
):
    """Return a copy of *df* with a column listing each row's topic labels.

    Args:
        df: DataFrame whose index values are the row identifiers that
            appear in *mappings*.
        mappings: Iterable giving, for each sentence, the index value of the
            *df* row that sentence belongs to.
        sentence_topics: Iterable of topic ids, paired with *mappings* by
            position.
        label_map: Dict mapping topic id to a label. Ids missing from the
            dict fall back to ``str(topic_id)``.
        col: Name of the column to add (or overwrite) in the result.
        drop_outlier: When true, sentences whose topic id is ``-1`` are
            ignored.

    Returns:
        A new DataFrame (``df`` itself is not modified) in which *col* holds,
        for every row, a sorted list of the distinct labels assigned to that
        row's sentences. Rows without any topic get an empty list.

    Raises:
        ValueError: If *mappings* and *sentence_topics* differ in length.
    """
    import pandas as pd  # local import keeps the function self-contained

    # Pair by position using plain lists. Building a DataFrame directly from
    # two pandas Series would align them on their indexes and could silently
    # mis-pair sentences with topics.
    row_ids = list(mappings)
    topic_ids = list(sentence_topics)
    if len(row_ids) != len(topic_ids):
        raise ValueError(
            "mappings and sentence_topics must have the same length "
            f"({len(row_ids)} != {len(topic_ids)})"
        )

    # Gather the distinct labels seen for every row.
    labels_by_row = {}
    for row_id, topic in zip(row_ids, topic_ids):
        if drop_outlier and topic == -1:
            continue
        labels_by_row.setdefault(row_id, set()).add(
            label_map.get(topic, str(topic))
        )

    # One fresh list per row; rows with no topics get their own empty list.
    topics_per_row = pd.Series(
        [sorted(labels_by_row.get(row_id, ())) for row_id in df.index],
        index=df.index,
        dtype=object,
    )
    return df.assign(**{col: topics_per_row})
188