lynn-twinkl committed on
Commit
e3ee58f
·
1 Parent(s): ddeb431

Testing bertopic

Browse files
notebooks/app_pipeline.ipynb ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "aa8053eb-bad5-45cf-b762-5426dfaf3281",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 1. Configuration"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 5,
14
+ "id": "489b293a-e27a-4b66-b16f-13f0b9964566",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import pandas as pd\n",
19
+ "import altair as alt\n",
20
+ "import joblib\n",
21
+ "from io import BytesIO\n",
22
+ "import os\n",
23
+ "import sys\n",
24
+ "\n",
25
+ "# Add project root (one level up from notebooks/) to sys.path\n",
26
+ "sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))\n",
27
+ "\n",
28
+ "# ---- FUNCTIONS ----\n",
29
+ "\n",
30
+ "from src.extract_usage import extract_usage\n",
31
+ "from src.necessity_index import compute_necessity, index_scaler, qcut_labels\n",
32
+ "from src.column_detection import detect_freeform_col\n",
33
+ "from src.shortlist import shortlist_applications\n",
34
+ "from src.twinkl_originals import find_book_candidates\n",
35
+ "from src.preprocess_text import normalise_text \n",
36
+ "from typing import Tuple"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 18,
42
+ "id": "bf6bb17e-7cf7-4864-96d2-0ceb864ff1e8",
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "def load_heartfelt_predictor():\n",
47
+ " # Compute absolute path from notebook location\n",
48
+ " project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))\n",
49
+ " model_path = os.path.join(project_root, \"src\", \"models\", \"heartfelt_pipeline.joblib\")\n",
50
+ " return joblib.load(model_path)\n"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": 21,
56
+ "id": "28225da1-0757-4289-a8a6-79e2c5d7e288",
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "def load_and_process(raw_csv) -> Tuple[pd.DataFrame, str]:\n",
61
+ " \"\"\"\n",
62
+ " Load CSV from raw bytes, detect freeform column, compute necessity scores,\n",
63
+ " and extract usage items. Returns processed DataFrame and freeform column name.\n",
64
+ " \"\"\"\n",
65
+ " # Read Uploaded Data \n",
66
+ " df_orig = pd.read_csv(raw_csv)\n",
67
+ "\n",
68
+ " # Detect freeform column\n",
69
+ " freeform_col = detect_freeform_col(df_orig)\n",
70
+ "\n",
71
+ " df_orig = df_orig[df_orig[freeform_col].notna()]\n",
72
+ "\n",
73
+ " #Word Count\n",
74
+ " df_orig['word_count'] = df_orig[freeform_col].fillna('').str.split().str.len()\n",
75
+ "\n",
76
+ " # Compute Necessity Scores\n",
77
+ " scored = df_orig.join(df_orig[freeform_col].apply(compute_necessity))\n",
78
+ " scored['necessity_index'] = index_scaler(scored['necessity_index'].values)\n",
79
+ " scored['priority'] = qcut_labels(scored['necessity_index'])\n",
80
+ "\n",
81
+ " # Find Twinkl Originals Candidates\n",
82
+ " scored['book_candidates'] = find_book_candidates(scored, freeform_col)\n",
83
+ "\n",
84
+ " # Label Heartfelt Applications\n",
85
+ " scored['clean_text'] = scored[freeform_col].map(normalise_text)\n",
86
+ " model = load_heartfelt_predictor()\n",
87
+ " scored['is_heartfelt'] = model.predict(scored['clean_text'].astype(str))\n",
88
+ "\n",
89
+ "\n",
90
+ " \n",
91
+ " # Usage Extraction\n",
92
+ " #docs = df_orig[freeform_col].to_list() <---- Disabled Ai-powered extraction for testing\n",
93
+ " #scored['Usage'] = extract_usage(docs)\n",
94
+ "\n",
95
+ " return scored, freeform_col\n",
96
+ "\n"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 22,
102
+ "id": "3989b41b-bf81-4436-98fa-28a191871546",
103
+ "metadata": {},
104
+ "outputs": [
105
+ {
106
+ "name": "stderr",
107
+ "output_type": "stream",
108
+ "text": [
109
+ "/Users/lynn/Documents/Twinkl/grant-applications-app/src/twinkl_originals.py:15: UserWarning: This pattern is interpreted as a regular expression, and has match groups. To actually get the groups, use str.extract.\n",
110
+ " is_primary = series.str.contains(pattern_level, case=False, na=False)\n"
111
+ ]
112
+ }
113
+ ],
114
+ "source": [
115
+ "df, freeform_col = load_and_process('data/feb-march-data.csv')"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": 23,
121
+ "id": "aa3aee7b-eab7-4e42-ba8c-addd9847f146",
122
+ "metadata": {},
123
+ "outputs": [
124
+ {
125
+ "data": {
126
+ "text/html": [
127
+ "<div>\n",
128
+ "<style scoped>\n",
129
+ " .dataframe tbody tr th:only-of-type {\n",
130
+ " vertical-align: middle;\n",
131
+ " }\n",
132
+ "\n",
133
+ " .dataframe tbody tr th {\n",
134
+ " vertical-align: top;\n",
135
+ " }\n",
136
+ "\n",
137
+ " .dataframe thead th {\n",
138
+ " text-align: right;\n",
139
+ " }\n",
140
+ "</style>\n",
141
+ "<table border=\"1\" class=\"dataframe\">\n",
142
+ " <thead>\n",
143
+ " <tr style=\"text-align: right;\">\n",
144
+ " <th></th>\n",
145
+ " <th>Id</th>\n",
146
+ " <th>Date/Time Requested</th>\n",
147
+ " <th>Giveaway Title</th>\n",
148
+ " <th>Customer Name</th>\n",
149
+ " <th>Email Address</th>\n",
150
+ " <th>School Name</th>\n",
151
+ " <th>Postal Address</th>\n",
152
+ " <th>Address Line 2</th>\n",
153
+ " <th>Address City</th>\n",
154
+ " <th>Postcode</th>\n",
155
+ " <th>...</th>\n",
156
+ " <th>Unnamed: 11</th>\n",
157
+ " <th>word_count</th>\n",
158
+ " <th>necessity_index</th>\n",
159
+ " <th>urgency_score</th>\n",
160
+ " <th>severity_score</th>\n",
161
+ " <th>vulnerability_score</th>\n",
162
+ " <th>priority</th>\n",
163
+ " <th>book_candidates</th>\n",
164
+ " <th>clean_text</th>\n",
165
+ " <th>is_heartfelt</th>\n",
166
+ " </tr>\n",
167
+ " </thead>\n",
168
+ " <tbody>\n",
169
+ " <tr>\n",
170
+ " <th>0</th>\n",
171
+ " <td>304399.0</td>\n",
172
+ " <td>01/03/2025 00:52</td>\n",
173
+ " <td>March Community Collection</td>\n",
174
+ " <td>Susan Bushnell</td>\n",
175
+ " <td>susan.bushnell@googlemail.com</td>\n",
176
+ " <td>Southfield Junior School</td>\n",
177
+ " <td>Shrivenham Road</td>\n",
178
+ " <td>Highworth</td>\n",
179
+ " <td>Swindon</td>\n",
180
+ " <td>SN6 7BZ</td>\n",
181
+ " <td>...</td>\n",
182
+ " <td></td>\n",
183
+ " <td>69</td>\n",
184
+ " <td>0.25000</td>\n",
185
+ " <td>0.0</td>\n",
186
+ " <td>0.0</td>\n",
187
+ " <td>0.0</td>\n",
188
+ " <td>medium</td>\n",
189
+ " <td>False</td>\n",
190
+ " <td>i would love to use it to spread the love of r...</td>\n",
191
+ " <td>True</td>\n",
192
+ " </tr>\n",
193
+ " <tr>\n",
194
+ " <th>1</th>\n",
195
+ " <td>305004.0</td>\n",
196
+ " <td>02/03/2025 19:52</td>\n",
197
+ " <td>March Community Collection</td>\n",
198
+ " <td>Sarah Arabestani</td>\n",
199
+ " <td>sarah.a@sandringhamnursery.com</td>\n",
200
+ " <td>Sandringham Nursery</td>\n",
201
+ " <td>16 Sandringham Road</td>\n",
202
+ " <td>Penylan</td>\n",
203
+ " <td>Cardiff</td>\n",
204
+ " <td>CF23 5BJ</td>\n",
205
+ " <td>...</td>\n",
206
+ " <td></td>\n",
207
+ " <td>46</td>\n",
208
+ " <td>0.06250</td>\n",
209
+ " <td>0.0</td>\n",
210
+ " <td>0.0</td>\n",
211
+ " <td>0.0</td>\n",
212
+ " <td>low</td>\n",
213
+ " <td>False</td>\n",
214
+ " <td>we would like to introduce early years yoga an...</td>\n",
215
+ " <td>False</td>\n",
216
+ " </tr>\n",
217
+ " <tr>\n",
218
+ " <th>2</th>\n",
219
+ " <td>305493.0</td>\n",
220
+ " <td>05/03/2025 14:34</td>\n",
221
+ " <td>March Community Collection</td>\n",
222
+ " <td>Rebecca Asker</td>\n",
223
+ " <td>mrsrasker@gmail.com</td>\n",
224
+ " <td>Newhaven PRU Outreach</td>\n",
225
+ " <td>Newhaven Gardens</td>\n",
226
+ " <td>NaN</td>\n",
227
+ " <td>Greenwich</td>\n",
228
+ " <td>SE96HR</td>\n",
229
+ " <td>...</td>\n",
230
+ " <td></td>\n",
231
+ " <td>86</td>\n",
232
+ " <td>0.09375</td>\n",
233
+ " <td>0.0</td>\n",
234
+ " <td>0.0</td>\n",
235
+ " <td>1.0</td>\n",
236
+ " <td>low</td>\n",
237
+ " <td>False</td>\n",
238
+ " <td>â£500 would enable us to set up a small sensor...</td>\n",
239
+ " <td>True</td>\n",
240
+ " </tr>\n",
241
+ " </tbody>\n",
242
+ "</table>\n",
243
+ "<p>3 rows × 21 columns</p>\n",
244
+ "</div>"
245
+ ],
246
+ "text/plain": [
247
+ " Id Date/Time Requested Giveaway Title Customer Name \\\n",
248
+ "0 304399.0 01/03/2025 00:52 March Community Collection Susan Bushnell \n",
249
+ "1 305004.0 02/03/2025 19:52 March Community Collection Sarah Arabestani \n",
250
+ "2 305493.0 05/03/2025 14:34 March Community Collection Rebecca Asker \n",
251
+ "\n",
252
+ " Email Address School Name \\\n",
253
+ "0 susan.bushnell@googlemail.com Southfield Junior School \n",
254
+ "1 sarah.a@sandringhamnursery.com Sandringham Nursery \n",
255
+ "2 mrsrasker@gmail.com Newhaven PRU Outreach \n",
256
+ "\n",
257
+ " Postal Address Address Line 2 Address City Postcode ... Unnamed: 11 \\\n",
258
+ "0 Shrivenham Road Highworth Swindon SN6 7BZ ... \n",
259
+ "1 16 Sandringham Road Penylan Cardiff CF23 5BJ ... \n",
260
+ "2 Newhaven Gardens NaN Greenwich SE96HR ... \n",
261
+ "\n",
262
+ " word_count necessity_index urgency_score severity_score \\\n",
263
+ "0 69 0.25000 0.0 0.0 \n",
264
+ "1 46 0.06250 0.0 0.0 \n",
265
+ "2 86 0.09375 0.0 0.0 \n",
266
+ "\n",
267
+ " vulnerability_score priority book_candidates \\\n",
268
+ "0 0.0 medium False \n",
269
+ "1 0.0 low False \n",
270
+ "2 1.0 low False \n",
271
+ "\n",
272
+ " clean_text is_heartfelt \n",
273
+ "0 i would love to use it to spread the love of r... True \n",
274
+ "1 we would like to introduce early years yoga an... False \n",
275
+ "2 â£500 would enable us to set up a small sensor... True \n",
276
+ "\n",
277
+ "[3 rows x 21 columns]"
278
+ ]
279
+ },
280
+ "execution_count": 23,
281
+ "metadata": {},
282
+ "output_type": "execute_result"
283
+ }
284
+ ],
285
+ "source": [
286
+ "df.head(3)"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "code",
291
+ "execution_count": null,
292
+ "id": "301c8139-ea06-4ee3-a4eb-0b20d29dd6a2",
293
+ "metadata": {},
294
+ "outputs": [],
295
+ "source": [
296
+ "# 2. "
297
+ ]
298
+ }
299
+ ],
300
+ "metadata": {
301
+ "kernelspec": {
302
+ "display_name": "Python 3 (ipykernel)",
303
+ "language": "python",
304
+ "name": "python3"
305
+ },
306
+ "language_info": {
307
+ "codemirror_mode": {
308
+ "name": "ipython",
309
+ "version": 3
310
+ },
311
+ "file_extension": ".py",
312
+ "mimetype": "text/x-python",
313
+ "name": "python",
314
+ "nbconvert_exporter": "python",
315
+ "pygments_lexer": "ipython3",
316
+ "version": "3.12.10"
317
+ }
318
+ },
319
+ "nbformat": 4,
320
+ "nbformat_minor": 5
321
+ }
src/models/topicModeling_contentRequests.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import re
3
+ import string
4
+ import torch
5
+ import spacy
6
+
7
+ from sentence_transformers import SentenceTransformer
8
+ import nltk
9
+ from nltk.corpus import stopwords
10
+ import contractions
11
+ from tqdm import tqdm
12
+
13
+
14
+ from sklearn.feature_extraction.text import CountVectorizer
15
+ from bertopic import BERTopic
16
+ from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance, OpenAI, PartOfSpeech
17
+ import openai
18
+ import numpy as np
19
+
20
+ import os
21
+ from dotenv import load_dotenv
22
# Environment configuration: the .env file lives at the repository root,
# two directories above this module (src/models/).
_ENV_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".env"))
load_dotenv(_ENV_FILE)

# May be None when the variable is unset; OpenAI client creation then fails at call time.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
25
+
26
#################################
# OpenAI Topic Representation
#################################
def create_openai_model():
    """Build a BERTopic OpenAI representation model.

    Asks gpt-4o-mini for a short (at most 4 words) stakeholder-friendly
    label per topic, returned in the ``topic: <label>`` format that
    BERTopic's OpenAI representation parses.
    """
    labelling_prompt = """
    I have a topic that contains the following documents:
    [DOCUMENTS]

    The topic is described by the following keywords: [KEYWORDS]

    Based on the information above, extract a short yet descriptive topic label of at most 4 words. The labels should be interpretable enough to stakeholders that don't have access to the raw data. Make sure it is in the following format:

    topic: <topic label>
    """
    api_client = openai.OpenAI(api_key=OPENAI_API_KEY)
    # exponential_backoff retries on rate limits; chat=True selects the
    # chat-completions endpoint.
    return OpenAI(api_client, model="gpt-4o-mini", exponential_backoff=True, chat=True, prompt=labelling_prompt)
43
+
44
#############################################
# Convert OpenAI Representation to CustomName
#############################################

def ai_labeles_to_custom_name(model):
    """Copy the OpenAI-generated aspect labels onto *model* as custom topic names.

    For each topic, joins the label strings of the model's "OpenAI" aspect
    with " | " and registers the result via ``model.set_topic_labels``.
    The outlier topic (-1) always gets the fixed name "Outlier Topic".

    Mutates *model* in place; returns None.
    """
    # Each aspect value is a list of (label, weight) pairs; keep labels only.
    # Unlike the previous list(zip(*values))[0], this also tolerates a topic
    # with an empty label list (which used to raise IndexError).
    chatgpt_topic_labels = {
        topic: " | ".join(label for label, _ in values)
        for topic, values in model.topic_aspects_["OpenAI"].items()
    }
    chatgpt_topic_labels[-1] = "Outlier Topic"
    model.set_topic_labels(chatgpt_topic_labels)
52
+
53
# -----------------------------------
# Lemmatization & Stopword Removal
# -----------------------------------

def topicModeling_preprocessing(df, spacy_model="en_core_web_lg"):
    """Clean and lemmatize text for topic modeling.

    Reads ``df['preprocessedBasic']``, removes English + domain stopwords,
    lemmatizes with spaCy, and writes the result to a new
    ``processedForModeling`` column. Rows whose cleaned text ends up empty
    (or was not a string) are dropped.

    Parameters
    ----------
    df : pandas.DataFrame with a 'preprocessedBasic' column.
    spacy_model : name of the spaCy pipeline to load.

    Returns a new DataFrame; the caller's *df* is not mutated.
    """
    # NOTE(review): use the fully-qualified corpus path rather than the
    # imported `stopwords` name — this module later rebinds the global
    # `stopwords` to a plain list, which would make `stopwords.words`
    # raise AttributeError by the time this function is called.
    base_stopwords = set(nltk.corpus.stopwords.words('english'))

    # Domain terms that appear in almost every request and carry no
    # topical signal for this corpus.
    custom_stopwords = {
        'material', 'materials', 'resources', 'resource', 'activity',
        'activities', 'sheet', 'sheets', 'worksheet', 'worksheets',
        'teacher', 'teachers', 'teach', 'high school', 'highschool',
        'middle school', 'grade', 'grades', 'hs', 'level', 'age', 'ages',
        'older', 'older kid', 'kid', 'student', "1st", "2nd", "3rd", "4th", '5th', '6th',
        '7th', '8th', '9th'
    }

    stopword_set = base_stopwords.union(custom_stopwords)

    # One alternation pattern with word boundaries; re.escape guards the
    # apostrophes in contractions like "don't".
    stopword_pattern = r'\b(?:' + '|'.join(re.escape(word) for word in stopword_set) + r')\b'

    nlp = spacy.load(spacy_model)

    def clean_lemmatize_text(text):
        # Non-strings (NaN etc.) become None so dropna can remove the row.
        if not isinstance(text, str):
            return None

        text = contractions.fix(text)            # "don't" -> "do not"
        text = re.sub(r'\s+', ' ', text).strip()  # collapse whitespace
        text = re.sub(stopword_pattern, '', text)  # strip stopwords (case-sensitive, as before)

        doc = nlp(text)
        tokens = [token.lemma_ for token in doc]

        clean_text = " ".join(tokens).strip()
        # Stopword removal leaves double spaces; collapse them again.
        clean_text = re.sub(r'\s+', ' ', clean_text)

        return clean_text if clean_text else None

    # Work on a copy so the caller's frame is not mutated (the original
    # added the column in place, risking SettingWithCopyWarning when the
    # input is itself a slice).
    df = df.copy()
    df['processedForModeling'] = df['preprocessedBasic'].apply(clean_lemmatize_text)

    # Drop rows where cleaned text is empty or None
    df = df.dropna(subset=['processedForModeling'])

    return df
101
+
102
# --------------------------
# Load Transformer Model
# --------------------------

def load_embedding_model(model_name):
    """Load a SentenceTransformer on the best available device.

    Preference order: CUDA, then Apple MPS, then CPU.
    """
    if torch.cuda.is_available():
        picked = "cuda"
    else:
        picked = "mps" if torch.backends.mps.is_available() else "cpu"

    print(f"Using device: {picked}")
    return SentenceTransformer(model_name, device=picked)
118
+
119
+
120
# -------------------------
# Batch Embedding Creation
# -------------------------

def encode_content_documents(embedding_model, content_documents, batch_size=20):
    """Encode *content_documents* in fixed-size batches.

    Shows one tqdm tick per batch and returns all embeddings vertically
    stacked into a single numpy array.
    """
    starts = range(0, len(content_documents), batch_size)
    batch_outputs = []

    with tqdm(total=len(starts), desc="Encoding Batches") as progress:
        for start in starts:
            chunk = content_documents[start:start + batch_size]
            batch_outputs.append(
                embedding_model.encode(chunk, convert_to_numpy=True, show_progress_bar=False)
            )
            progress.update(1)

    return np.vstack(batch_outputs)
138
+
139
# -----------------------------
# Topic Modeling with BERTopic
# -----------------------------

# Ensure the NLTK stopword corpus is present (downloads once on first run).
try:
    nltk.data.find("corpora/stopwords")
except LookupError:
    nltk.download("stopwords")

# WARNING(review): this rebinds the module-level name `stopwords` — which
# until here was the `nltk.corpus.stopwords` corpus reader — to a plain
# list. Any later call of `stopwords.words(...)` in this module would fail;
# use `nltk.corpus.stopwords.words(...)` instead. The fully-qualified path
# is used below for the same reason.
# Domain-specific terms that dominate every request and carry no topical
# signal are appended to the standard English list.
stopwords = list(nltk.corpus.stopwords.words('english')) + [
    'activities',
    'activity',
    'class',
    'classroom',
    'material',
    'materials',
    'membership',
    'memberships',
    'pupil',
    'pupils',
    'resource',
    'resources',
    'sheet',
    'sheets',
    'student',
    'students',
    'subscription',
    'subscriptions',
    'subscribe',
    'subscribed',
    'recommend',
    'recommendation',
    'teach',
    'teacher',
    'teachers',
    'tutor',
    'tutors',
    'twinkl',
    'twinkls',
    'twinkle',
    'worksheet',
    'worksheets',
]
184
+
185
######### --------------- BERTOPIC ----------------- #############
def bertopic_model(docs, embeddings, _embedding_model, _umap_model, _hdbscan_model):
    """Fit a BERTopic model on pre-computed document embeddings.

    Parameters
    ----------
    docs : list of document strings.
    embeddings : pre-computed embedding array aligned with *docs*.
    _embedding_model, _umap_model, _hdbscan_model : components injected by
        the caller (leading underscore keeps Streamlit caching from hashing
        them).

    Returns
    -------
    (topic_model, topics, probs) as produced by ``BERTopic.fit_transform``.
    """
    main_representation_model = KeyBERTInspired()
    aspect_representation_model1 = MaximalMarginalRelevance(diversity=.3)

    # NOTE(review): the original also constructed an OpenAI client, prompt
    # and representation model here, but never added them to
    # `representation_model`, so the client was built for nothing. That dead
    # code is removed. If AI-generated labels are wanted, add
    # {"OpenAI": create_openai_model()} below — ai_labeles_to_custom_name
    # expects that "OpenAI" aspect to exist.
    representation_model = {
        "Main": main_representation_model,
        "Secondary Representation": aspect_representation_model1,
    }

    # min_df/max_df prune terms that are too rare or too ubiquitous to be
    # discriminative; `stopwords` is the module-level extended list.
    vectorizer_model = CountVectorizer(min_df=2, max_df=0.60, stop_words=stopwords)

    # Kept for experimentation; currently disabled in the BERTopic call.
    seed_topic_list = [
        ["autism", "special needs", "special education needs", "special education", "adhd", "autistic", "dyslexia", "dyslexic", "sen"],
    ]

    topic_model = BERTopic(
        verbose=True,
        embedding_model=_embedding_model,
        umap_model=_umap_model,
        hdbscan_model=_hdbscan_model,
        vectorizer_model=vectorizer_model,
        #seed_topic_list=seed_topic_list,
        representation_model=representation_model,
    )

    topics, probs = topic_model.fit_transform(docs, embeddings)
    return topic_model, topics, probs
228
+
229
##################################
# TOPIC MERGING
##################################

def merge_specific_topics(topic_model, sentences,
                          cancellation_keywords=["cancel", "cancellation", "canceled"],
                          thanks_keywords=["thank", "thanks", "thank you", "thankyou", "ty", "thx"],
                          expensive_keywords=["can't afford", "price", "expensive", "cost"]):
    """Merge fragmented topics into cancellation / thank-you / price groups.

    A topic belongs to a group when any of the group's keywords appears in
    its auto-generated name (case-insensitive substring match). Each group
    with more than one matching topic is merged in *topic_model* via
    ``merge_topics`` (which mutates the model).

    Parameters
    ----------
    topic_model : fitted BERTopic model.
    sentences : the documents the model was fitted on (required by
        ``merge_topics``).
    *_keywords : keyword lists defining each mergeable group.
        (The previous default duplicated "cancel"; deduplicated here —
        the resulting regex is equivalent.)

    Returns the (possibly mutated) *topic_model*.
    """
    topic_info = topic_model.get_topic_info()

    def matching_topics(keywords):
        # Topics whose name contains any keyword; re.escape makes the
        # keywords safe to join into one alternation pattern. The outlier
        # topic (-1) is always excluded from merging.
        pattern = '|'.join(re.escape(kw) for kw in keywords)
        hits = topic_info[
            topic_info['Name'].str.contains(pattern, case=False, na=False)
        ]['Topic'].tolist()
        return [t for t in hits if t != -1]

    topics_to_merge = []
    for label, keywords in (("cancellation", cancellation_keywords),
                            ("thank-you", thanks_keywords),
                            ("expensive", expensive_keywords)):
        group = matching_topics(keywords)
        # A single matching topic needs no merge.
        if len(group) > 1:
            print(f"Merging {label} topics: {group}")
            topics_to_merge.append(group)

    if topics_to_merge:
        topic_model.merge_topics(sentences, topics_to_merge)

    return topic_model
284
+
285
+
286
##################################
# Topic to Dataframe Mapping
#################################

def update_df_with_topics(df, mapping, sentence_topics, topic_label_map):
    """Attach a human-readable 'Topics' column to a copy of *df*.

    ``mapping[i]`` gives the df row index that sentence *i* came from and
    ``sentence_topics[i]`` its assigned topic id. Each row's cell is the
    sorted, comma-separated set of topic names (outlier topic -1 omitted;
    ids missing from *topic_label_map* fall back to str(id)).
    """
    # Gather the set of topic ids seen for each row index.
    row_topics = {}
    for position, row_idx in enumerate(mapping):
        row_topics.setdefault(row_idx, set()).add(sentence_topics[position])

    def label_row(row_idx):
        ids = row_topics.get(row_idx, set())
        names = sorted(topic_label_map.get(t, str(t)) for t in ids if t != -1)
        return ", ".join(names)

    result = df.copy()
    result['Topics'] = result.index.map(label_row)
    return result
305
+