romybeaute committed on
Commit
1952a74
·
verified ·
1 Parent(s): 317bd9e

updated for LLM chat (API)

Browse files
Files changed (1) hide show
  1. app.py +298 -19
app.py CHANGED
@@ -21,6 +21,13 @@ import os
21
  import nltk
22
  import json
23
 
 
 
 
 
 
 
 
24
  # =====================================================================
25
  # NLTK setup
26
  # =====================================================================
@@ -243,7 +250,7 @@ DATASETS = None # keep name for clarity; we’ll fill it when rendering the sid
243
  HISTORY_FILE = str(PROC_DIR / "run_history.json")
244
 
245
  # =====================================================================
246
- # 3. Embedding & LLM loaders
247
  # =====================================================================
248
 
249
 
@@ -253,6 +260,7 @@ def load_embedding_model(model_name):
253
  return SentenceTransformer(model_name)
254
 
255
 
 
256
  @st.cache_data
257
  def load_precomputed_data(docs_file, embeddings_file):
258
  docs = np.load(docs_file, allow_pickle=True).tolist()
@@ -261,7 +269,182 @@ def load_precomputed_data(docs_file, embeddings_file):
261
 
262
 
263
  # =====================================================================
264
- # 4. Topic modeling function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  # =====================================================================
266
 
267
 
@@ -354,7 +537,7 @@ def perform_topic_modeling(_docs, _embeddings, config_hash):
354
 
355
 
356
  # =====================================================================
357
- # 5. CSV → documents → embeddings pipeline
358
  # =====================================================================
359
 
360
 
@@ -438,7 +621,7 @@ def generate_and_save_embeddings(
438
 
439
 
440
  # =====================================================================
441
- # 6. Sidebar — dataset, upload, parameters
442
  # =====================================================================
443
 
444
  st.sidebar.header("Data Input Method")
@@ -826,6 +1009,11 @@ else:
826
  )
827
  st.session_state.latest_results = (model, reduced, labels)
828
 
 
 
 
 
 
829
  entry = {
830
  "timestamp": str(pd.Timestamp.now()),
831
  "config": current_config,
@@ -846,6 +1034,92 @@ else:
846
  if "latest_results" in st.session_state:
847
  tm, reduced, labs = st.session_state.latest_results
848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
849
  st.subheader("Experiential Topics Visualisation")
850
  fig, _ = datamapplot.create_plot(reduced, labs)
851
  st.pyplot(fig)
@@ -855,23 +1129,28 @@ else:
855
 
856
  st.subheader("Export results (one row per topic)")
857
 
858
- full_reps = tm.get_topics(full=True)
859
- llm_reps = full_reps.get("LLM", {})
860
 
861
- llm_names = {}
862
- for tid, vals in llm_reps.items():
863
- try:
864
- llm_names[tid] = (
865
- (vals[0][0] or "").strip().strip('"').strip(".")
866
- )
867
- except Exception:
868
- llm_names[tid] = "Unlabelled"
 
 
 
 
 
 
 
 
 
 
869
 
870
- if not llm_names:
871
- st.caption("Note: Using default keyword-based topic names.")
872
- llm_names = (
873
- tm.get_topic_info().set_index("Topic")["Name"].to_dict()
874
- )
875
 
876
  doc_info = tm.get_document_info(docs)[["Document", "Topic"]]
877
 
 
21
  import nltk
22
  import json
23
 
24
+ # from huggingface_hub import hf_hub_download, InferenceClient # for the LLM API command
25
+ from huggingface_hub import InferenceClient # for the LLM API command
26
+
27
+
28
+
29
+
30
+
31
  # =====================================================================
32
  # NLTK setup
33
  # =====================================================================
 
250
  HISTORY_FILE = str(PROC_DIR / "run_history.json")
251
 
252
  # =====================================================================
253
+ # 3. Embedding loaders
254
  # =====================================================================
255
 
256
 
 
260
  return SentenceTransformer(model_name)
261
 
262
 
263
+
264
  @st.cache_data
265
  def load_precomputed_data(docs_file, embeddings_file):
266
  docs = np.load(docs_file, allow_pickle=True).tolist()
 
269
 
270
 
271
  # =====================================================================
272
+ # 4. LLM loaders
273
+ # =====================================================================
274
+
275
+ #ADDED FOR LLM (START)
276
@st.cache_resource
def get_hf_client(model_id: str):
    """Build a Hugging Face ``InferenceClient`` bound to ``model_id``.

    The access token is looked up in the ``HF_TOKEN`` environment variable
    first; if that is unset or empty, ``st.secrets`` is consulted instead.
    The function is decorated with ``st.cache_resource``.

    Returns:
        A ``(client, token)`` tuple. ``token`` is falsy when no credential
        was found, which lets callers decide how to report the problem.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        # Any failure while reading st.secrets is treated as "no token".
        try:
            hf_token = st.secrets.get("HF_TOKEN")
        except Exception:
            hf_token = None

    # The model is fixed on the client so it need not be passed on every call.
    return InferenceClient(model=model_id, token=hf_token), hf_token
288
+
289
def _labels_cache_path(config_hash: str, model_id: str) -> Path:
    """Return the JSON cache file for the labels of one (model, config) pair.

    Every character of ``model_id`` outside ``[a-zA-Z0-9_.-]`` is replaced by
    an underscore so the id can be embedded in a file name under ``CACHE_DIR``.
    """
    sanitized_model = re.sub(r"[^a-zA-Z0-9_.-]", "_", model_id)
    file_name = f"llm_labels_{sanitized_model}_{config_hash}.json"
    return CACHE_DIR / file_name
292
+
293
+
294
+ SYSTEM_PROMPT = """You are an expert phenomenologist analysing subjective reflections from specific experiences.
295
+ Your task is to label a cluster of similar experiential reports.
296
+
297
+ The title should be:
298
+ 1. HIGHLY SPECIFIC to the experiential characteristic unique to this "phenomenological" cluster
299
+ 2. PHENOMENOLOGICALLY DESCRIPTIVE (focus on *what* was felt/seen).
300
+ 3. DISTINCTIVE enough that it wouldn't apply equally well to other "phenomenological" clusters
301
+ 4. TECHNICALLY PRECISE, using domain-specific terminology where appropriate
302
+ 5. CONCEPTUALLY FOCUSED on the core specificities of this type of experience
303
+
304
+
305
+ Constraints:
306
+ - Output ONLY the label (no explanation).
307
+ - 3–7 words.
308
+ - No punctuation, no quotes, no extra text.
309
+ - Do not explain your reasoning
310
+ """
311
+
312
+
313
+ USER_TEMPLATE = """Here is a cluster of participant reports describing a specific phenomenon:
314
+
315
+ {documents}
316
+
317
+ Top keywords associated with this cluster:
318
+ {keywords}
319
+
320
+ Task: Return a single scientifically precise label (3–7 words). Output ONLY the label.
321
+ """
322
+
323
+ def _clean_label(x: str) -> str:
324
+ x = (x or "").strip()
325
+ x = x.splitlines()[0].strip() # first line only
326
+ x = x.strip(' "\'`')
327
+ x = re.sub(r"[.:\-–—]+$", "", x).strip() # remove trailing punctuation
328
+ # enforce "no punctuation" lightly (optional):
329
+ x = re.sub(r"[^\w\s]", "", x).strip()
330
+ return x or "Unlabelled"
331
+
332
+
333
+
334
+ # def generate_labels_via_api(tm, model_id: str, prompt_template: str,
335
+ # max_topics: int = 40, reps_per_topic: int = 8):
336
+ # client, token = get_hf_client(model_id)
337
+ # if not token:
338
+ # raise RuntimeError("No HF_TOKEN found (Space Settings → Secrets).")
339
+
340
+ # topic_info = tm.get_topic_info()
341
+ # topic_info = topic_info[topic_info.Topic != -1].head(max_topics)
342
+
343
+ # labels = {}
344
+ # for tid in topic_info.Topic.tolist():
345
+ # kws = [w for (w, _) in (tm.get_topic(tid) or [])][:10]
346
+ # reps = (tm.get_representative_docs(tid) or [])[:reps_per_topic]
347
+
348
+ # docs_block = "\n- " + "\n- ".join([r[:300].replace("\n", " ") for r in reps])
349
+ # prompt = (prompt_template
350
+ # .replace("[KEYWORDS]", ", ".join(kws))
351
+ # .replace("[DOCUMENTS]", docs_block))
352
+
353
+ # out = client.text_generation(
354
+ # prompt,
355
+ # max_new_tokens=32,
356
+ # temperature=0.2,
357
+ # stop=["\n"], # stop is the current arg; stop_sequences is deprecated
358
+ # )
359
+ # labels[int(tid)] = _clean_label(out)
360
+
361
+ # return labels
362
+
363
+
364
+
365
+
366
def generate_labels_via_chat_completion(
    topic_model: BERTopic,
    docs: list[str],
    config_hash: str,
    model_id: str = "meta-llama/Meta-Llama-3-8B-Instruct",
    max_topics: int = 40,
    max_docs_per_topic: int = 8,
    doc_char_limit: int = 300,
    temperature: float = 0.2,
    force: bool = False) -> dict[int, str]:
    """
    Label the topics of an already-fitted topic model via the HF chat API.

    Labels are cached on disk per (model_id, config_hash). Unless ``force``
    is set, cached labels are reused and only topics that are still missing
    from the cache are sent to the API, so raising ``max_topics`` after an
    earlier run extends the cached result instead of returning it unchanged.

    Args:
        topic_model: Fitted model providing ``get_topic_info``, ``get_topic``
            and ``get_representative_docs``.
        docs: Currently unused; kept so existing callers keep working.
        config_hash: Identifier of the run configuration, part of the cache key.
        model_id: Hugging Face model id used for labelling.
        max_topics: Maximum number of rows of ``get_topic_info()`` (after
            dropping topic -1) to label.
        max_docs_per_topic: Maximum number of representative documents quoted
            in each prompt.
        doc_char_limit: Each quoted document is truncated to this many characters.
        temperature: Sampling temperature forwarded to the API.
        force: Ignore any cached labels and regenerate everything.

    Returns:
        ``{topic_id: label}``.

    Raises:
        RuntimeError: If labels have to be generated but no HF token is found.
    """
    cache_path = _labels_cache_path(config_hash, model_id)

    labels: dict[int, str] = {}
    if (not force) and cache_path.exists():
        # Best effort: an unreadable or malformed cache is simply ignored.
        try:
            cached = json.loads(cache_path.read_text(encoding="utf-8"))
            labels = {int(k): str(v) for k, v in cached.items()}
        except (OSError, ValueError, TypeError, AttributeError):
            labels = {}

    topic_info = topic_model.get_topic_info()
    topic_info = topic_info[topic_info.Topic != -1].head(max_topics)
    topic_ids = [int(tid) for tid in topic_info.Topic.tolist()]

    # Only topics without a cached label need an API call.
    pending = [tid for tid in topic_ids if tid not in labels]
    if not pending:
        return labels

    client, token = get_hf_client(model_id)
    if not token:
        raise RuntimeError("No HF_TOKEN found in env/secrets.")

    prog = st.progress(0)
    total = len(pending)

    try:
        for i, tid in enumerate(pending, start=1):
            words = topic_model.get_topic(tid) or []
            keywords = ", ".join([w for (w, _) in words[:10]])

            try:
                reps = (topic_model.get_representative_docs(tid) or [])[:max_docs_per_topic]
            except Exception:
                reps = []

            # Keep the prompt small; str() guards against non-string documents.
            reps = [
                str(r).replace("\n", " ").strip()[:doc_char_limit]
                for r in reps
                if str(r).strip()
            ]
            if reps:
                docs_block = "\n".join([f"- {r}" for r in reps])
            else:
                docs_block = "- (No representative docs available)"

            user_prompt = USER_TEMPLATE.format(documents=docs_block, keywords=keywords)

            out = client.chat_completion(
                model=model_id,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=24,
                temperature=temperature,
                stop=["\n"],
            )

            raw = out.choices[0].message.content
            labels[tid] = _clean_label(raw)

            prog.progress(int(100 * i / total))
    finally:
        # Remove the progress bar whether labelling finished or failed.
        prog.empty()

    # Caching is best effort: failing to persist must not discard the labels.
    try:
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        cache_path.write_text(
            json.dumps({str(k): v for k, v in labels.items()}, indent=2),
            encoding="utf-8",
        )
    except OSError:
        pass

    return labels
442
+ #ADDED FOR LLM (END)
443
+
444
+
445
+
446
+ # =====================================================================
447
+ # 5. Topic modeling function
448
  # =====================================================================
449
 
450
 
 
537
 
538
 
539
  # =====================================================================
540
+ # 6. CSV → documents → embeddings pipeline
541
  # =====================================================================
542
 
543
 
 
621
 
622
 
623
  # =====================================================================
624
+ # 7. Sidebar — dataset, upload, parameters
625
  # =====================================================================
626
 
627
  st.sidebar.header("Data Input Method")
 
1009
  )
1010
  st.session_state.latest_results = (model, reduced, labels)
1011
 
1012
+ ### ADD FOR LLM (START)
1013
+ st.session_state.latest_config_hash = get_config_hash(current_config)
1014
+ st.session_state.latest_config = current_config
1015
+ ### ADD FOR LLM (END)
1016
+
1017
  entry = {
1018
  "timestamp": str(pd.Timestamp.now()),
1019
  "config": current_config,
 
1034
  if "latest_results" in st.session_state:
1035
  tm, reduced, labs = st.session_state.latest_results
1036
 
1037
+ #USE NEW LABELS
1038
+
1039
+ # ##### ADDED FOR LLM (START)
1040
+ # st.subheader("LLM topic labelling (via Hugging Face API)")
1041
+
1042
+ # model_id = st.text_input(
1043
+ # "HF model id for labelling",
1044
+ # value="meta-llama/Meta-Llama-3-8B-Instruct",
1045
+ # )
1046
+
1047
+ # prompt_template = st.text_area(
1048
+ # "Prompt template",
1049
+ # value=YOUR_PROMPT_STRING, # define it once (see below)
1050
+ # height=220,
1051
+ # )
1052
+
1053
+ # max_topics = st.slider("Max topics to label", 5, 80, 40)
1054
+ # reps_per_topic = st.slider("Representative excerpts per topic", 2, 15, 8)
1055
+
1056
+ # do_label = st.button("Generate LLM labels (API)")
1057
+
1058
+ # if do_label:
1059
+ # try:
1060
+ # llm_names = generate_labels_via_api(
1061
+ # tm,
1062
+ # model_id=model_id,
1063
+ # prompt_template=prompt_template,
1064
+ # max_topics=max_topics,
1065
+ # reps_per_topic=reps_per_topic,
1066
+ # )
1067
+ # st.session_state.llm_names = llm_names
1068
+ # st.success(f"Generated {len(llm_names)} labels.")
1069
+ # except Exception as e:
1070
+ # st.error(str(e))
1071
+
1072
+ # # Merge labels (LLM overrides keyword names)
1073
+ # name_map = tm.get_topic_info().set_index("Topic")["Name"].to_dict()
1074
+ # llm_names = st.session_state.get("llm_names", {})
1075
+ # final_name_map = {**name_map, **llm_names}
1076
+
1077
+ # # rebuild per-document labels for plotting
1078
+ # labs = [final_name_map.get(t, "Unlabelled") for t in tm.topics_]
1079
+
1080
+
1081
+ # ##### ADDED FOR LLM (END)
1082
+
1083
+
1084
+ ##### ADDED FOR LLM (START)
1085
+ st.subheader("LLM topic labelling (via Hugging Face API)")
1086
+
1087
+ model_id = st.text_input(
1088
+ "HF model id for labelling",
1089
+ value="meta-llama/Meta-Llama-3-8B-Instruct",
1090
+ )
1091
+
1092
+ cA, cB, cC = st.columns([1, 1, 2])
1093
+ max_topics = cA.slider("Max topics", 5, 120, 40, 5)
1094
+ force = cB.checkbox("Force regenerate", value=False)
1095
+
1096
+ if cC.button("Generate LLM labels (API)", use_container_width=True):
1097
+ try:
1098
+ cfg_hash = st.session_state.get("latest_config_hash", "nohash")
1099
+ llm_names = generate_labels_via_chat_completion(
1100
+ topic_model=tm,
1101
+ docs=docs,
1102
+ config_hash=cfg_hash,
1103
+ model_id=model_id,
1104
+ max_topics=max_topics,
1105
+ force=force,
1106
+ )
1107
+ st.session_state.llm_names = llm_names
1108
+ st.success(f"Generated {len(llm_names)} labels.")
1109
+ st.rerun()
1110
+ except Exception as e:
1111
+ st.error(f"LLM labelling failed: {e}")
1112
+
1113
+ # Apply labels (LLM overrides keyword names)
1114
+ default_map = tm.get_topic_info().set_index("Topic")["Name"].to_dict()
1115
+ api_map = st.session_state.get("llm_names", {}) or {}
1116
+ final_name_map = {**default_map, **api_map}
1117
+
1118
+ labs = [final_name_map.get(t, "Unlabelled") for t in tm.topics_]
1119
+ ##### ADDED FOR LLM (END)
1120
+
1121
+
1122
+ # VISUALISATION
1123
  st.subheader("Experiential Topics Visualisation")
1124
  fig, _ = datamapplot.create_plot(reduced, labs)
1125
  st.pyplot(fig)
 
1129
 
1130
  st.subheader("Export results (one row per topic)")
1131
 
1132
+ # full_reps = tm.get_topics(full=True)
1133
+ # llm_reps = full_reps.get("LLM", {})
1134
 
1135
+ # llm_names = {}
1136
+ # for tid, vals in llm_reps.items():
1137
+ # try:
1138
+ # llm_names[tid] = (
1139
+ # (vals[0][0] or "").strip().strip('"').strip(".")
1140
+ # )
1141
+ # except Exception:
1142
+ # llm_names[tid] = "Unlabelled"
1143
+
1144
+ # if not llm_names:
1145
+ # st.caption("Note: Using default keyword-based topic names.")
1146
+ # llm_names = (
1147
+ # tm.get_topic_info().set_index("Topic")["Name"].to_dict()
1148
+ # )
1149
+
1150
+ default_map = tm.get_topic_info().set_index("Topic")["Name"].to_dict()
1151
+ api_map = st.session_state.get("llm_names", {}) or {}
1152
+ llm_names = {**default_map, **api_map}
1153
 
 
 
 
 
 
1154
 
1155
  doc_info = tm.get_document_info(docs)[["Document", "Topic"]]
1156