Jac-Zac commited on
Commit
eaeaa68
·
1 Parent(s): eb41f91

Update code

Browse files
Files changed (6) hide show
  1. README.md +17 -0
  2. pyproject.toml +4 -3
  3. tabs/chat.py +62 -83
  4. tabs/extract.py +0 -4
  5. utils/chat.py +4 -6
  6. uv.lock +23 -4
README.md CHANGED
@@ -1,5 +1,7 @@
1
  # Persona UI
2
 
 
 
3
  Streamlit interface for persona vector extraction, analysis, and chat.
4
 
5
  > [!WARNING]
@@ -74,6 +76,21 @@ parent/
74
  streamlit run app.py
75
  ```
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  ## Configuration
78
 
79
  Copy `.env.example` to `.env` and fill in:
 
1
  # Persona UI
2
 
3
+ [![Deploy to Hugging Face Spaces](https://huggingface.co/landing/badge.svg)](https://huggingface.co/spaces/implicit-personalization/persona-ui)
4
+
5
  Streamlit interface for persona vector extraction, analysis, and chat.
6
 
7
  > [!WARNING]
 
76
  streamlit run app.py
77
  ```
78
 
79
+ ## Hugging Face Spaces Deployment
80
+
81
+ This app can be deployed to Hugging Face Spaces using Docker.
82
+
83
+ ### Prerequisites
84
+
85
+ No secrets needed! The dependencies are published on PyPI.
86
+
87
+ ### Build Locally (Optional)
88
+
89
+ ```bash
90
+ docker build -t persona-ui .
91
+ docker run -p 8501:8501 persona-ui
92
+ ```
93
+
94
  ## Configuration
95
 
96
  Copy `.env.example` to `.env` and fill in:
pyproject.toml CHANGED
@@ -10,14 +10,15 @@ dependencies = [
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
12
  "python-dotenv>=1.2.2",
 
13
  ]
14
 
 
15
  [tool.uv.sources]
16
  # Local development:
17
- # persona-vectors = { path = "../persona-vectors", editable = true }
18
  # persona-data = { path = "../persona-data", editable = true }
19
- persona-vectors = { git = "ssh://git@github.com/implicit-personalization/persona-vectors.git" }
20
- persona-data = { git = "ssh://git@github.com/implicit-personalization/persona-data.git" }
21
 
22
  # [build-system]
23
  # requires = ["uv_build>=0.11.3,<0.12"]
 
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
12
  "python-dotenv>=1.2.2",
13
+ "transformers>=5.5.0",
14
  ]
15
 
16
+
17
  [tool.uv.sources]
18
  # Local development:
 
19
  # persona-data = { path = "../persona-data", editable = true }
20
+ persona-vectors = { path = "../persona-vectors", editable = true }
21
+ persona-data = { git = "https://github.com/implicit-personalization/persona-data.git" }
22
 
23
  # [build-system]
24
  # requires = ["uv_build>=0.11.3,<0.12"]
tabs/chat.py CHANGED
@@ -38,7 +38,8 @@ def _render_collapsible_markdown(content: str) -> None:
38
  def _render_chat_message(message: dict[str, str]) -> None:
39
  if not message.get("content"):
40
  return
41
- with st.chat_message(message["role"]):
 
42
  _render_collapsible_markdown(message["content"])
43
 
44
 
@@ -46,47 +47,24 @@ def _render_inline_system_prompt(
46
  prompt_key: str,
47
  prompt_mode: str,
48
  active_system_prompt: str | None,
49
- edit_key: str,
50
  height: int = 200,
51
  ) -> str | None:
52
- """Render the system prompt as an inline editable item at the top of the chat."""
53
  if prompt_mode == "empty":
54
  return active_system_prompt
55
 
56
  if prompt_key not in st.session_state:
57
  st.session_state[prompt_key] = active_system_prompt or ""
58
 
59
- current_prompt = st.session_state[prompt_key] or None
60
- is_editing = st.session_state.get(edit_key) == -1
61
-
62
  with st.container(border=True):
63
  st.caption("System prompt")
64
- if is_editing:
65
- new_val = st.text_area(
66
- "system_prompt_edit",
67
- value=current_prompt or "",
68
- height=height,
69
- label_visibility="collapsed",
70
- key=f"{prompt_key}_inline_edit",
71
- )
72
- c1, c2 = st.columns(2)
73
- with c1:
74
- if st.button("Save", key=f"{edit_key}_sys_save", type="primary"):
75
- st.session_state[prompt_key] = new_val
76
- st.session_state[edit_key] = None
77
- st.rerun()
78
- with c2:
79
- if st.button("Cancel", key=f"{edit_key}_sys_cancel"):
80
- st.session_state[edit_key] = None
81
- st.rerun()
82
- else:
83
- if current_prompt:
84
- _render_collapsible_markdown(current_prompt)
85
- else:
86
- st.markdown("*(empty)*")
87
- if st.button("Edit", key=f"{edit_key}_sys_edit"):
88
- st.session_state[edit_key] = -1
89
- st.rerun()
90
 
91
  return st.session_state.get(prompt_key) or None
92
 
@@ -105,7 +83,8 @@ def _render_editable_message(
105
 
106
  is_editing = st.session_state.get(edit_key) == msg_index
107
 
108
- with st.chat_message(message["role"]):
 
109
  if is_editing:
110
  new_content = st.text_area(
111
  "Edit",
@@ -305,14 +284,9 @@ def _render_compare_mode(
305
  """Render the full side-by-side comparison UI."""
306
  left_col, right_col = st.columns(2)
307
 
308
- def render_panel(
309
- side: str, column
310
- ) -> tuple[dict[str, object], Any, str | None, str]:
311
  panel_key = widget_key(context_key, f"cmp_{side}")
312
- state = st.session_state.get(panel_key)
313
- if state is None:
314
- state = _default_chat_state()
315
- st.session_state[panel_key] = state
316
  prompt_key = widget_key(panel_key, "custom_prompt")
317
  show_all_key = widget_key(panel_key, "show_all")
318
  edit_key = widget_key(panel_key, "edit_idx")
@@ -374,7 +348,6 @@ def _render_compare_mode(
374
  prompt_key,
375
  prompt_mode,
376
  active_system_prompt,
377
- edit_key,
378
  height=150,
379
  )
380
  _render_chat_window(
@@ -390,11 +363,9 @@ def _render_compare_mode(
390
  return state, chat_log, active_system_prompt, pending_regen_key
391
 
392
  with left_col:
393
- left_state, left_log, left_prompt, left_pending = render_panel("left", left_col)
394
  with right_col:
395
- right_state, right_log, right_prompt, right_pending = render_panel(
396
- "right", right_col
397
- )
398
 
399
  panels = [
400
  (left_state, left_log, left_prompt, left_pending),
@@ -454,12 +425,9 @@ def _render_compare_mode(
454
  executor.submit(
455
  generate_chat_reply,
456
  model=model,
457
- messages=(
458
- [{"role": "system", "content": panel_prompt}]
459
- if panel_prompt
460
- else []
461
- )
462
- + panel_state["messages"],
463
  remote=remote,
464
  past_key_values=panel_state["past_key_values"],
465
  **gen_kwargs,
@@ -479,12 +447,9 @@ def _render_compare_mode(
479
  results.append(
480
  generate_chat_reply(
481
  model=model,
482
- messages=(
483
- [{"role": "system", "content": panel_prompt}]
484
- if panel_prompt
485
- else []
486
- )
487
- + panel_state["messages"],
488
  remote=remote,
489
  past_key_values=panel_state["past_key_values"],
490
  **gen_kwargs,
@@ -507,36 +472,22 @@ def _render_compare_mode(
507
  with panel_log:
508
  _render_chat_message({"role": "assistant", "content": result.text})
509
 
 
 
 
510
 
511
- # ── Main tab entry point ───────────────────────────────────────────────────────
512
-
513
-
514
- def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
515
- """Render the chat tab."""
516
 
517
- st.title("Chat")
518
 
519
- context_key = chat_session_key(model_name, dataset_source)
520
- chat_state = get_chat_state(model_name, remote, dataset_source)
521
- try:
522
- dataset, dataset_status = load_dataset(
523
- dataset_source,
524
- personas_file=st.session_state.get("extract__personas_file"),
525
- qa_file=st.session_state.get("extract__qa_file"),
526
- )
527
- st.caption(dataset_status)
528
- except Exception as exc:
529
- st.error(f"Could not load data: {exc}")
530
- st.info("Check the selected dataset source or upload both JSONL files.")
531
- return
532
 
533
- personas = list(dataset)
534
- if not personas:
535
- st.warning("No personas found in the selected dataset.")
536
- st.info("Try a different dataset source or upload a non-empty personas file.")
537
- return
538
 
539
- # ── Generation settings ───────────────────────────────────────────────────
 
 
540
  with st.expander("Advanced", expanded=False):
541
  config_col1, config_col2 = st.columns([2, 1])
542
  with config_col1:
@@ -643,6 +594,35 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
643
  repetition_penalty=repetition_penalty,
644
  seed=generation_seed,
645
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
646
 
647
  # ── Mode toggle ───────────────────────────────────────────────────────────
648
  compare_mode = st.toggle(
@@ -731,7 +711,6 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
731
  prompt_key,
732
  prompt_mode,
733
  active_system_prompt,
734
- edit_key,
735
  height=200,
736
  )
737
 
 
38
  def _render_chat_message(message: dict[str, str]) -> None:
39
  if not message.get("content"):
40
  return
41
+ with st.container(border=True):
42
+ st.caption(message["role"])
43
  _render_collapsible_markdown(message["content"])
44
 
45
 
 
47
  prompt_key: str,
48
  prompt_mode: str,
49
  active_system_prompt: str | None,
 
50
  height: int = 200,
51
  ) -> str | None:
52
+ """Render the system prompt as an always-editable text area at the top of the chat."""
53
  if prompt_mode == "empty":
54
  return active_system_prompt
55
 
56
  if prompt_key not in st.session_state:
57
  st.session_state[prompt_key] = active_system_prompt or ""
58
 
 
 
 
59
  with st.container(border=True):
60
  st.caption("System prompt")
61
+ st.text_area(
62
+ "system_prompt_edit",
63
+ value=st.session_state[prompt_key],
64
+ height=height,
65
+ label_visibility="collapsed",
66
+ key=prompt_key,
67
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  return st.session_state.get(prompt_key) or None
70
 
 
83
 
84
  is_editing = st.session_state.get(edit_key) == msg_index
85
 
86
+ with st.container(border=True):
87
+ st.caption(message["role"])
88
  if is_editing:
89
  new_content = st.text_area(
90
  "Edit",
 
284
  """Render the full side-by-side comparison UI."""
285
  left_col, right_col = st.columns(2)
286
 
287
+ def render_panel(side: str) -> tuple[dict[str, object], Any, str | None, str]:
 
 
288
  panel_key = widget_key(context_key, f"cmp_{side}")
289
+ state = _panel_state(panel_key)
 
 
 
290
  prompt_key = widget_key(panel_key, "custom_prompt")
291
  show_all_key = widget_key(panel_key, "show_all")
292
  edit_key = widget_key(panel_key, "edit_idx")
 
348
  prompt_key,
349
  prompt_mode,
350
  active_system_prompt,
 
351
  height=150,
352
  )
353
  _render_chat_window(
 
363
  return state, chat_log, active_system_prompt, pending_regen_key
364
 
365
  with left_col:
366
+ left_state, left_log, left_prompt, left_pending = render_panel("left")
367
  with right_col:
368
+ right_state, right_log, right_prompt, right_pending = render_panel("right")
 
 
369
 
370
  panels = [
371
  (left_state, left_log, left_prompt, left_pending),
 
425
  executor.submit(
426
  generate_chat_reply,
427
  model=model,
428
+ messages=_build_chat_messages(
429
+ panel_prompt, panel_state["messages"]
430
+ ),
 
 
 
431
  remote=remote,
432
  past_key_values=panel_state["past_key_values"],
433
  **gen_kwargs,
 
447
  results.append(
448
  generate_chat_reply(
449
  model=model,
450
+ messages=_build_chat_messages(
451
+ panel_prompt, panel_state["messages"]
452
+ ),
 
 
 
453
  remote=remote,
454
  past_key_values=panel_state["past_key_values"],
455
  **gen_kwargs,
 
472
  with panel_log:
473
  _render_chat_message({"role": "assistant", "content": result.text})
474
 
475
+ # Rerun so the newly appended turns are redrawn through the editable history
476
+ # renderer instead of only appearing in the one-off generation pass.
477
+ st.rerun()
478
 
 
 
 
 
 
479
 
480
+ # ── Main tab entry point ───────────────────────────────────────────────────────
481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
 
483
+ def _render_generation_settings(
484
+ context_key: str, remote: bool
485
+ ) -> tuple[dict, bool]:
486
+ """Render the Advanced generation settings expander.
 
487
 
488
+ Returns ``(gen_kwargs, advanced_generation)`` where ``advanced_generation``
489
+ is True when any setting differs from its default.
490
+ """
491
  with st.expander("Advanced", expanded=False):
492
  config_col1, config_col2 = st.columns([2, 1])
493
  with config_col1:
 
594
  repetition_penalty=repetition_penalty,
595
  seed=generation_seed,
596
  )
597
+ return gen_kwargs, advanced_generation
598
+
599
+
600
+ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
601
+ """Render the chat tab."""
602
+
603
+ st.title("Chat")
604
+
605
+ context_key = chat_session_key(model_name, dataset_source)
606
+ chat_state = get_chat_state(model_name, remote, dataset_source)
607
+ try:
608
+ dataset, dataset_status = load_dataset(
609
+ dataset_source,
610
+ personas_file=st.session_state.get("extract__personas_file"),
611
+ qa_file=st.session_state.get("extract__qa_file"),
612
+ )
613
+ st.caption(dataset_status)
614
+ except Exception as exc:
615
+ st.error(f"Could not load data: {exc}")
616
+ st.info("Check the selected dataset source or upload both JSONL files.")
617
+ return
618
+
619
+ personas = list(dataset)
620
+ if not personas:
621
+ st.warning("No personas found in the selected dataset.")
622
+ st.info("Try a different dataset source or upload a non-empty personas file.")
623
+ return
624
+
625
+ gen_kwargs, advanced_generation = _render_generation_settings(context_key, remote)
626
 
627
  # ── Mode toggle ───────────────────────────────────────────────────────────
628
  compare_mode = st.toggle(
 
711
  prompt_key,
712
  prompt_mode,
713
  active_system_prompt,
 
714
  height=200,
715
  )
716
 
tabs/extract.py CHANGED
@@ -111,7 +111,6 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
111
  st.info("Select at least one persona.")
112
  return
113
 
114
- runs = None
115
  max_questions = 0
116
 
117
  with st.expander("Advanced", expanded=False):
@@ -190,9 +189,6 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
190
  )
191
  st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
192
 
193
- if runs is None:
194
- return
195
-
196
  run_clicked = st.button("Run extraction", type="primary")
197
  if not run_clicked:
198
  return
 
111
  st.info("Select at least one persona.")
112
  return
113
 
 
114
  max_questions = 0
115
 
116
  with st.expander("Advanced", expanded=False):
 
189
  )
190
  st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
191
 
 
 
 
192
  run_clicked = st.button("Run extraction", type="primary")
193
  if not run_clicked:
194
  return
utils/chat.py CHANGED
@@ -82,11 +82,9 @@ def _format_generation_prompt(
82
  Tries the tokenizer's chat template first, falls back to normalized messages,
83
  then to a plain-text format if both template attempts fail.
84
  """
85
- normalized_messages = messages
86
-
87
  try:
88
  prompt = tokenizer.apply_chat_template(
89
- normalized_messages,
90
  tokenize=False,
91
  add_generation_prompt=True,
92
  )
@@ -94,11 +92,11 @@ def _format_generation_prompt(
94
  logger.debug(
95
  "Chat template failed on raw messages, trying normalized", exc_info=True
96
  )
97
- normalized_messages = normalize_messages(messages)
98
 
99
  try:
100
  prompt = tokenizer.apply_chat_template(
101
- normalized_messages,
102
  tokenize=False,
103
  add_generation_prompt=True,
104
  )
@@ -108,7 +106,7 @@ def _format_generation_prompt(
108
  exc_info=True,
109
  )
110
  prompt = _format_plain_messages(
111
- normalized_messages,
112
  add_generation_prompt=True,
113
  )
114
 
 
82
  Tries the tokenizer's chat template first, falls back to normalized messages,
83
  then to a plain-text format if both template attempts fail.
84
  """
 
 
85
  try:
86
  prompt = tokenizer.apply_chat_template(
87
+ messages,
88
  tokenize=False,
89
  add_generation_prompt=True,
90
  )
 
92
  logger.debug(
93
  "Chat template failed on raw messages, trying normalized", exc_info=True
94
  )
95
+ messages = normalize_messages(messages)
96
 
97
  try:
98
  prompt = tokenizer.apply_chat_template(
99
+ messages,
100
  tokenize=False,
101
  add_generation_prompt=True,
102
  )
 
106
  exc_info=True,
107
  )
108
  prompt = _format_plain_messages(
109
+ messages,
110
  add_generation_prompt=True,
111
  )
112
 
uv.lock CHANGED
@@ -1561,7 +1561,7 @@ wheels = [
1561
  [[package]]
1562
  name = "persona-data"
1563
  version = "0.1.0"
1564
- source = { git = "ssh://git@github.com/implicit-personalization/persona-data.git#3763bd6e42472b589b4e32acd3e47b711a0af1f5" }
1565
  dependencies = [
1566
  { name = "huggingface-hub" },
1567
  { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
@@ -1580,21 +1580,23 @@ dependencies = [
1580
  { name = "plotly" },
1581
  { name = "python-dotenv" },
1582
  { name = "streamlit" },
 
1583
  ]
1584
 
1585
  [package.metadata]
1586
  requires-dist = [
1587
- { name = "persona-data", git = "ssh://git@github.com/implicit-personalization/persona-data.git" },
1588
- { name = "persona-vectors", git = "ssh://git@github.com/implicit-personalization/persona-vectors.git" },
1589
  { name = "plotly", specifier = ">=6.6.0" },
1590
  { name = "python-dotenv", specifier = ">=1.2.2" },
1591
  { name = "streamlit", specifier = ">=1.44.0" },
 
1592
  ]
1593
 
1594
  [[package]]
1595
  name = "persona-vectors"
1596
  version = "0.1.0"
1597
- source = { git = "ssh://git@github.com/implicit-personalization/persona-vectors.git#fa6b4b61eaaba9ce64ee8614766bf75879148bbb" }
1598
  dependencies = [
1599
  { name = "kaleido" },
1600
  { name = "nnsight" },
@@ -1612,6 +1614,23 @@ dependencies = [
1612
  { name = "umap-learn" },
1613
  ]
1614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615
  [[package]]
1616
  name = "pexpect"
1617
  version = "4.9.0"
 
1561
  [[package]]
1562
  name = "persona-data"
1563
  version = "0.1.0"
1564
+ source = { git = "https://github.com/implicit-personalization/persona-data#4316f47026a40ad1c5337c3830141267527be2fc" }
1565
  dependencies = [
1566
  { name = "huggingface-hub" },
1567
  { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
 
1580
  { name = "plotly" },
1581
  { name = "python-dotenv" },
1582
  { name = "streamlit" },
1583
+ { name = "transformers" },
1584
  ]
1585
 
1586
  [package.metadata]
1587
  requires-dist = [
1588
+ { name = "persona-data", git = "https://github.com/implicit-personalization/persona-data.git" },
1589
+ { name = "persona-vectors", editable = "../persona-vectors" },
1590
  { name = "plotly", specifier = ">=6.6.0" },
1591
  { name = "python-dotenv", specifier = ">=1.2.2" },
1592
  { name = "streamlit", specifier = ">=1.44.0" },
1593
+ { name = "transformers", specifier = ">=5.5.0" },
1594
  ]
1595
 
1596
  [[package]]
1597
  name = "persona-vectors"
1598
  version = "0.1.0"
1599
+ source = { editable = "../persona-vectors" }
1600
  dependencies = [
1601
  { name = "kaleido" },
1602
  { name = "nnsight" },
 
1614
  { name = "umap-learn" },
1615
  ]
1616
 
1617
+ [package.metadata]
1618
+ requires-dist = [
1619
+ { name = "kaleido", specifier = ">=1.0.0" },
1620
+ { name = "nnsight", specifier = ">=0.6.1" },
1621
+ { name = "nnterp", specifier = ">=1.3.0" },
1622
+ { name = "persona-data", git = "https://github.com/implicit-personalization/persona-data" },
1623
+ { name = "plotly", specifier = ">=6.6.0" },
1624
+ { name = "python-dotenv", specifier = ">=1.2.2" },
1625
+ { name = "safetensors", specifier = ">=0.7.0" },
1626
+ { name = "scikit-learn", specifier = ">=1.6.0" },
1627
+ { name = "torch", specifier = ">=2.10.0" },
1628
+ { name = "torchvision", specifier = ">=0.26.0" },
1629
+ { name = "tqdm", specifier = ">=4.67.3" },
1630
+ { name = "transformers", specifier = ">=5.2.0" },
1631
+ { name = "umap-learn", specifier = ">=0.5.7" },
1632
+ ]
1633
+
1634
  [[package]]
1635
  name = "pexpect"
1636
  version = "4.9.0"