romybeaute committed (verified)
Commit 11d757f · 1 Parent(s): 74a069f

Updated to get info on the LLM model and inference
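The change hinges on metadata that `chat_completion` responses already carry: the diff below records the provider-returned model id and the token usage from each call, then surfaces both in the UI. A minimal standalone sketch of that response-metadata pattern, assuming the `huggingface_hub` `InferenceClient` that `get_hf_client` presumably wraps (token and prompt here are placeholders):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(token="hf_...")  # placeholder token

out = client.chat_completion(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Give this cluster a short label."}],
    max_tokens=24,
)

print(out.choices[0].message.content)  # generated label text
print(getattr(out, "model", None))     # model id the provider actually served

usage = getattr(out, "usage", None)    # token accounting, when the provider returns it
if usage is not None:
    print(getattr(usage, "total_tokens", None))
```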

Files changed (1): app.py (+116 -41)
app.py CHANGED
@@ -273,6 +273,12 @@ def load_precomputed_data(docs_file, embeddings_file):
 # 4. LLM loaders
 # =====================================================================

+# Approximate price for cost estimates in the UI only.
+# Novita Llama 3 8B is around $0.04 per 1M input tokens
+# and $0.04 per 1M output tokens – adjust if needed.
+HF_APPROX_PRICE_PER_MTOKENS_USD = 0.04
+
+
 #ADDED FOR LLM (START)
 @st.cache_resource
 def get_hf_client(model_id: str):
@@ -291,6 +297,11 @@ def _labels_cache_path(config_hash: str, model_id: str) -> Path:
     safe_model = re.sub(r"[^a-zA-Z0-9_.-]", "_", model_id)
     return CACHE_DIR / f"llm_labels_{safe_model}_{config_hash}.json"

+def _hf_status_code(e: Exception) -> int | None:
+    """Extract HTTP status code from a huggingface_hub error, if present."""
+    resp = getattr(e, "response", None)
+    return getattr(resp, "status_code", None)
+

 SYSTEM_PROMPT = """You are an expert phenomenologist analysing subjective reflections from specific experiences.
 Your task is to label a cluster of similar experiential reports.
@@ -332,37 +343,6 @@ def _clean_label(x: str) -> str:



-# def generate_labels_via_api(tm, model_id: str, prompt_template: str,
-#                             max_topics: int = 40, reps_per_topic: int = 8):
-#     client, token = get_hf_client(model_id)
-#     if not token:
-#         raise RuntimeError("No HF_TOKEN found (Space Settings → Secrets).")
-
-#     topic_info = tm.get_topic_info()
-#     topic_info = topic_info[topic_info.Topic != -1].head(max_topics)
-
-#     labels = {}
-#     for tid in topic_info.Topic.tolist():
-#         kws = [w for (w, _) in (tm.get_topic(tid) or [])][:10]
-#         reps = (tm.get_representative_docs(tid) or [])[:reps_per_topic]
-
-#         docs_block = "\n- " + "\n- ".join([r[:300].replace("\n", " ") for r in reps])
-#         prompt = (prompt_template
-#                   .replace("[KEYWORDS]", ", ".join(kws))
-#                   .replace("[DOCUMENTS]", docs_block))
-
-#         out = client.text_generation(
-#             prompt,
-#             max_new_tokens=32,
-#             temperature=0.2,
-#             stop=["\n"],  # stop is the current arg; stop_sequences is deprecated
-#         )
-#         labels[int(tid)] = _clean_label(out)
-
-#     return labels
-
-
-

 def generate_labels_via_chat_completion(
     topic_model: BERTopic,
@@ -378,6 +358,10 @@ def generate_labels_via_chat_completion(
     Label topics AFTER fitting (fast + stable on Spaces).
     Returns {topic_id: label}.
     """
+
+    # Remember which HF model id we requested on the last run
+    st.session_state["hf_last_model_param"] = model_id
+
     cache_path = _labels_cache_path(config_hash, model_id)

     if (not force) and cache_path.exists():
@@ -415,23 +399,74 @@ def generate_labels_via_chat_completion(
             docs_block = "- (No representative docs available)"

         user_prompt = USER_TEMPLATE.format(documents=docs_block, keywords=keywords)
+        # Store one example prompt (for UI inspection) – will be overwritten each run
+        st.session_state["hf_last_example_prompt"] = user_prompt
+
+        # # --- THE KEY PART: chat_completion ---
+        # out = client.chat_completion(
+        #     model=model_id,
+        #     messages=[
+        #         {"role": "system", "content": SYSTEM_PROMPT},
+        #         {"role": "user", "content": user_prompt},
+        #     ],
+        #     max_tokens=24,
+        #     temperature=temperature,
+        #     stop=["\n"],
+        # )
+        # # ------------------------------------
+
+        # raw = out.choices[0].message.content
+        # labels[int(tid)] = _clean_label(raw)
+

         # --- THE KEY PART: chat_completion ---
-        out = client.chat_completion(
-            model=model_id,
-            messages=[
-                {"role": "system", "content": SYSTEM_PROMPT},
-                {"role": "user", "content": user_prompt},
-            ],
-            max_tokens=24,
-            temperature=temperature,
-            stop=["\n"],
-        )
+        try:
+            out = client.chat_completion(
+                model=model_id,
+                messages=[
+                    {"role": "system", "content": SYSTEM_PROMPT},
+                    {"role": "user", "content": user_prompt},
+                ],
+                max_tokens=24,
+                temperature=temperature,
+                stop=["\n"],
+            )
+            # Store the provider-returned model id (if available)
+            provider_model = getattr(out, "model", None)
+            if provider_model:
+                st.session_state["hf_last_provider_model"] = provider_model
+        except Exception as e:
+            # Nice message for the specific 402 you're seeing
+            if _hf_status_code(e) == 402:
+                raise RuntimeError(
+                    "Hugging Face returned 402 Payment Required for this LLM call. "
+                    "You have used up the monthly Inference Provider credits on this "
+                    "account. Either upgrade to PRO / enable pay-as-you-go, or skip "
+                    "the 'Generate LLM labels (API)' step."
+                ) from e
+            # Anything else: bubble up the original error
+            raise
         # ------------------------------------

+        # --- Best-effort local accounting of token usage (this Streamlit session) ---
+        usage = getattr(out, "usage", None)
+        total_tokens = None
+
+        # `usage` might be a dict (raw JSON) or an object with attributes
+        if isinstance(usage, dict):
+            total_tokens = usage.get("total_tokens")
+        else:
+            total_tokens = getattr(usage, "total_tokens", None)
+
+        if total_tokens is not None:
+            st.session_state.setdefault("hf_tokens_total", 0)
+            st.session_state["hf_tokens_total"] += int(total_tokens)
+        # ---------------------------------------------------------------------------
+
         raw = out.choices[0].message.content
         labels[int(tid)] = _clean_label(raw)

+
         prog.progress(int(100 * i / max(total, 1)))

     try:
@@ -1120,6 +1155,33 @@ else:
         "HF model id for labelling",
         value="meta-llama/Meta-Llama-3-8B-Instruct",
     )
+    with st.expander("Show LLM configuration and prompts"):
+        # What we *request*
+        st.markdown(f"**HF model id (requested):** `{model_id}`")
+
+        # What was used on the last run, if available
+        requested_last = st.session_state.get("hf_last_model_param")
+        provider_model = st.session_state.get("hf_last_provider_model")
+
+        if requested_last:
+            st.markdown(f"**Last run – requested model id:** `{requested_last}`")
+        if provider_model:
+            st.markdown(f"**Last run – provider model (returned):** `{provider_model}`")
+        else:
+            st.caption("Run LLM labelling once to see the provider-returned model id.")
+
+        st.markdown("**System prompt:**")
+        st.code(SYSTEM_PROMPT, language="markdown")
+
+        st.markdown("**User prompt template:**")
+        st.code(USER_TEMPLATE, language="markdown")
+
+        example_prompt = st.session_state.get("hf_last_example_prompt")
+        if example_prompt:
+            st.markdown("**Example full prompt for one topic (last run):**")
+            st.code(example_prompt, language="markdown")
+        else:
+            st.caption("No example prompt stored yet – run LLM labelling to populate this.")

     cA, cB, cC = st.columns([1, 1, 2])
     max_topics = cA.slider("Max topics", 5, 120, 40, 5)
@@ -1147,6 +1209,19 @@ else:
                 st.rerun()
             except Exception as e:
                 st.error(f"LLM labelling failed: {e}")
+
+
+    # Approximate HF usage for *this* Streamlit session (local estimate only)
+    hf_tokens_total = st.session_state.get("hf_tokens_total", 0)
+    if hf_tokens_total:
+        approx_cost = hf_tokens_total / 1_000_000 * HF_APPROX_PRICE_PER_MTOKENS_USD
+        st.caption(
+            f"Approx. HF LLM usage this session: ~{hf_tokens_total:,} tokens "
+            f"(~${approx_cost:.4f} at "
+            f"${HF_APPROX_PRICE_PER_MTOKENS_USD}/M tokens, "
+            "based on Novita Llama 3 8B pricing)."
+        )
+

     # Apply labels (LLM overrides keyword names)
     default_map = tm.get_topic_info().set_index("Topic")["Name"].to_dict()