Jac-Zac commited on
Commit
5bf7fd5
·
1 Parent(s): 76d718f

Cleaned up code abastracted away in persona-vector

Browse files
app.py CHANGED
@@ -1,14 +1,11 @@
1
  import os
2
- from pathlib import Path
3
 
4
  import streamlit as st
5
  from dotenv import load_dotenv
6
 
7
- # Load .env early so DEFAULT_MODEL / REMOTE_DEFAULT_MODEL can be overridden via env
8
- load_dotenv(Path(__file__).parent / ".env")
9
-
10
  from utils.helpers import DATASET_SOURCES
11
 
 
12
  DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
13
  REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
14
 
 
1
  import os
 
2
 
3
  import streamlit as st
4
  from dotenv import load_dotenv
5
 
 
 
 
6
  from utils.helpers import DATASET_SOURCES
7
 
8
+ load_dotenv()
9
  DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
10
  REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
11
 
tabs/extract.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
 
2
 
3
  from utils.datasets import load_dataset
4
- from utils.extraction import run_extraction
5
  from utils.helpers import (
6
  PROMPT_VARIANTS,
7
  persona_label,
@@ -151,6 +151,16 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
151
  status_box = st.empty()
152
  status_box.info("Extraction in progress...")
153
  progress = st.progress(0, text="Preparing extraction...")
 
 
 
 
 
 
 
 
 
 
154
 
155
  with st.spinner("Loading model..."):
156
  model = cached_model(model_name=model_name, remote=remote)
@@ -174,6 +184,7 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
174
  qa_pairs=qa_pairs,
175
  variants=[variant],
176
  remote=remote,
 
177
  )
178
  results.extend(variant_results)
179
  step += 1
@@ -184,6 +195,7 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
184
  return
185
  finally:
186
  progress.empty()
 
187
 
188
  status_box.success("Extraction complete")
189
  st.success(f"Saved {len(results)} artifact set(s)")
@@ -191,5 +203,5 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
191
  for result in results:
192
  st.markdown(
193
  f"- **{result.persona_name}** · {prompt_variant_label(result.variant)}: "
194
- f"{result.n_questions} questions, {result.n_layers} layers, {result.d_model} hidden size"
195
  )
 
1
  import streamlit as st
2
+ from persona_vectors.extraction import run_extraction
3
 
4
  from utils.datasets import load_dataset
 
5
  from utils.helpers import (
6
  PROMPT_VARIANTS,
7
  persona_label,
 
151
  status_box = st.empty()
152
  status_box.info("Extraction in progress...")
153
  progress = st.progress(0, text="Preparing extraction...")
154
+ ndif_status_box = st.empty() # shows live NDIF job status when remote=True
155
+
156
+ _STATUS_ICONS = {
157
+ "RECEIVED": "◉", "QUEUED": "◎", "DISPATCHED": "◈",
158
+ "RUNNING": "●", "COMPLETED": "✓", "ERROR": "✗",
159
+ }
160
+
161
+ def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
162
+ icon = _STATUS_ICONS.get(status_name, "•")
163
+ ndif_status_box.caption(f"{icon} `{job_id}` **{status_name}** — {description}")
164
 
165
  with st.spinner("Loading model..."):
166
  model = cached_model(model_name=model_name, remote=remote)
 
184
  qa_pairs=qa_pairs,
185
  variants=[variant],
186
  remote=remote,
187
+ on_status=_on_ndif_status if remote else None,
188
  )
189
  results.extend(variant_results)
190
  step += 1
 
195
  return
196
  finally:
197
  progress.empty()
198
+ ndif_status_box.empty()
199
 
200
  status_box.success("Extraction complete")
201
  st.success(f"Saved {len(results)} artifact set(s)")
 
203
  for result in results:
204
  st.markdown(
205
  f"- **{result.persona_name}** · {prompt_variant_label(result.variant)}: "
206
+ f"{result.n_questions} questions"
207
  )
utils/artifacts.py CHANGED
@@ -7,17 +7,12 @@ import torch
7
  from persona_vectors.activation_io import (
8
  load_activation_metadata,
9
  load_per_question_vectors,
 
10
  )
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
 
15
- def model_dir_name(model_name: str) -> str:
16
- """Encode a model name for use in artifact paths."""
17
-
18
- return model_name.replace("/", "__")
19
-
20
-
21
  def list_available_personas(
22
  artifacts_root: str | Path,
23
  model_name: str,
 
7
  from persona_vectors.activation_io import (
8
  load_activation_metadata,
9
  load_per_question_vectors,
10
+ model_dir_name,
11
  )
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
 
 
 
 
 
 
 
16
  def list_available_personas(
17
  artifacts_root: str | Path,
18
  model_name: str,
utils/chat.py CHANGED
@@ -8,13 +8,13 @@ from nnterp import StandardizedTransformer
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
- from persona_data.synth_persona import PersonaData
12
  from persona_data.prompts import (
13
- format_empty_persona_prompt,
14
  format_biography_prompt,
 
15
  format_templated_prompt,
16
  normalize_messages,
17
  )
 
18
 
19
  SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
20
 
 
8
 
9
  logger = logging.getLogger(__name__)
10
 
 
11
  from persona_data.prompts import (
 
12
  format_biography_prompt,
13
+ format_empty_persona_prompt,
14
  format_templated_prompt,
15
  normalize_messages,
16
  )
17
+ from persona_data.synth_persona import PersonaData
18
 
19
  SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
20
 
utils/chat_export.py CHANGED
@@ -3,8 +3,8 @@ from datetime import datetime, timezone
3
  from pathlib import Path
4
 
5
  from persona_data.environment import get_artifacts_dir
 
6
 
7
- from utils.artifacts import model_dir_name
8
  from utils.helpers import slugify
9
 
10
 
 
3
  from pathlib import Path
4
 
5
  from persona_data.environment import get_artifacts_dir
6
+ from persona_vectors.activation_io import model_dir_name
7
 
 
8
  from utils.helpers import slugify
9
 
10
 
utils/datasets.py CHANGED
@@ -5,10 +5,10 @@ from tempfile import mkdtemp
5
  from typing import Any
6
 
7
  import streamlit as st
 
8
  from persona_data.synth_persona import SynthPersonaDataset
9
 
10
  from .helpers import DATASET_SOURCES
11
- from .local_dataset import LocalPersonaDataset
12
 
13
 
14
  @st.cache_resource(show_spinner=False)
 
5
  from typing import Any
6
 
7
  import streamlit as st
8
+ from persona_data.synth_persona import PersonaDataset as LocalPersonaDataset
9
  from persona_data.synth_persona import SynthPersonaDataset
10
 
11
  from .helpers import DATASET_SOURCES
 
12
 
13
 
14
  @st.cache_resource(show_spinner=False)
utils/extraction.py DELETED
@@ -1,151 +0,0 @@
1
- import gc
2
- import logging
3
- from dataclasses import dataclass
4
-
5
- import torch
6
- from nnterp import StandardizedTransformer
7
-
8
- logger = logging.getLogger(__name__)
9
-
10
- from persona_data.environment import get_artifacts_dir
11
- from persona_data.synth_persona import PersonaData, QAPair
12
- from persona_vectors.activation_io import save_per_question_vectors
13
- from persona_vectors.activations import extract_activations
14
- from persona_data.prompts import (
15
- format_biography_prompt,
16
- format_messages,
17
- format_templated_prompt,
18
- )
19
-
20
-
21
- @dataclass
22
- class VariantExtractionResult:
23
- variant: str
24
- output_dir: str
25
- n_questions: int
26
- n_layers: int
27
- d_model: int
28
- persona_name: str = ""
29
-
30
-
31
- def _prepare_inputs(
32
- tokenizer: object,
33
- system_prompt: str,
34
- qa_pairs: list[QAPair],
35
- ) -> tuple[list[str], list[torch.Tensor], list[str]]:
36
- """Format QA pairs into tokenized prompts with answer-token masks.
37
-
38
- Args:
39
- tokenizer: HuggingFace-compatible tokenizer from the model.
40
- system_prompt: System prompt to prepend to each conversation.
41
- qa_pairs: List of question-answer pairs to format.
42
-
43
- Returns:
44
- A tuple of (full_texts, token_masks, questions) where full_texts are
45
- the rendered prompt strings, token_masks are boolean tensors marking
46
- answer tokens, and questions are the raw question strings.
47
- """
48
- full_texts: list[str] = []
49
- token_masks: list[torch.Tensor] = []
50
- questions: list[str] = []
51
-
52
- for qa in qa_pairs:
53
- messages = [
54
- {"role": "system", "content": system_prompt},
55
- {"role": "user", "content": qa.question},
56
- {"role": "assistant", "content": qa.answer},
57
- ]
58
- full_prompt, answer_start = format_messages(messages, tokenizer)
59
- seq_len = tokenizer(full_prompt, return_tensors="pt").input_ids.shape[1]
60
-
61
- full_texts.append(full_prompt)
62
- token_masks.append(torch.arange(seq_len) >= answer_start)
63
- questions.append(qa.question)
64
-
65
- return full_texts, token_masks, questions
66
-
67
-
68
- def run_extraction(
69
- model: StandardizedTransformer,
70
- model_name: str,
71
- persona: PersonaData,
72
- qa_pairs: list[QAPair],
73
- variants: list[str],
74
- remote: bool,
75
- ) -> list[VariantExtractionResult]:
76
- """Run activation extraction and save outputs for selected variants.
77
-
78
- Args:
79
- model: Loaded standardized nnterp model.
80
- model_name: HuggingFace model identifier used for artifact paths.
81
- persona: The persona whose QA pairs are being extracted.
82
- qa_pairs: Question-answer pairs to run extraction on.
83
- variants: Prompt variants to extract (e.g. ``"templated"``, ``"biography"``).
84
- remote: Whether to execute on NDIF.
85
-
86
- Returns:
87
- A list of extraction results, one per variant.
88
-
89
- Raises:
90
- ValueError: If ``qa_pairs`` is empty or an unsupported variant is given.
91
- """
92
- if not qa_pairs:
93
- raise ValueError("No QA pairs selected for extraction")
94
-
95
- tokenizer = model.tokenizer
96
- activations_dir = get_artifacts_dir() / "activations"
97
-
98
- system_prompt_by_variant = {
99
- "templated": format_templated_prompt(persona.templated_prompt),
100
- "biography": format_biography_prompt(persona.biography_md),
101
- }
102
-
103
- results: list[VariantExtractionResult] = []
104
-
105
- for variant in variants:
106
- if variant not in system_prompt_by_variant:
107
- raise ValueError(f"Unsupported variant: {variant}")
108
-
109
- full_texts, token_masks, questions = _prepare_inputs(
110
- tokenizer=tokenizer,
111
- system_prompt=system_prompt_by_variant[variant],
112
- qa_pairs=qa_pairs,
113
- )
114
-
115
- per_question_vectors = extract_activations(
116
- model=model,
117
- full_texts=full_texts,
118
- token_masks=token_masks,
119
- remote=remote,
120
- )
121
-
122
- artifact_dir = save_per_question_vectors(
123
- root_dir=activations_dir,
124
- model_name=model_name,
125
- prompt_variant=variant,
126
- persona_id=persona.id,
127
- persona_name=persona.name,
128
- per_question_vectors=per_question_vectors,
129
- questions=questions,
130
- )
131
-
132
- results.append(
133
- VariantExtractionResult(
134
- variant=variant,
135
- output_dir=str(artifact_dir),
136
- n_questions=per_question_vectors.shape[0],
137
- n_layers=per_question_vectors.shape[1],
138
- d_model=per_question_vectors.shape[2],
139
- persona_name=persona.name,
140
- )
141
- )
142
-
143
- # Free activation tensors between variants to keep memory bounded.
144
- del per_question_vectors, full_texts, token_masks
145
- gc.collect()
146
- if torch.cuda.is_available():
147
- torch.cuda.empty_cache()
148
- if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
149
- torch.mps.empty_cache()
150
-
151
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/helpers.py CHANGED
@@ -1,4 +1,5 @@
1
  from persona_data.synth_persona import PersonaData
 
2
 
3
  # Variant key -> human-readable label mapping
4
  VARIANT_LABELS = {
@@ -9,7 +10,7 @@ VARIANT_LABELS = {
9
  }
10
 
11
  # Variants that correspond to actual system prompts (excludes "empty")
12
- PROMPT_VARIANTS = ["templated", "biography"]
13
 
14
  # For selectbox options: list of labels in definition order
15
  MODE_LABELS = list(VARIANT_LABELS.values())
 
1
  from persona_data.synth_persona import PersonaData
2
+ from persona_vectors.extraction import SUPPORTED_VARIANTS
3
 
4
  # Variant key -> human-readable label mapping
5
  VARIANT_LABELS = {
 
10
  }
11
 
12
  # Variants that correspond to actual system prompts (excludes "empty")
13
+ PROMPT_VARIANTS = list(SUPPORTED_VARIANTS)
14
 
15
  # For selectbox options: list of labels in definition order
16
  MODE_LABELS = list(VARIANT_LABELS.values())
utils/local_dataset.py DELETED
@@ -1,72 +0,0 @@
1
- import json
2
- from collections import defaultdict
3
- from dataclasses import dataclass
4
- from pathlib import Path
5
- from typing import Iterator, Literal
6
-
7
- from persona_data.synth_persona import PersonaData, QAPair
8
-
9
-
10
- @dataclass
11
- class LocalPersonaDataset:
12
- """Dataset loaded from local JSONL files."""
13
-
14
- personas_path: Path
15
- qa_path: Path
16
-
17
- def __post_init__(self) -> None:
18
- with self.personas_path.open() as f:
19
- self._personas: list[PersonaData] = []
20
- for line in f:
21
- if not line.strip():
22
- continue
23
- data = json.loads(line)
24
- self._personas.append(
25
- PersonaData(
26
- id=data["id"],
27
- persona=data["persona"],
28
- templated_prompt=data["templated_prompt"],
29
- biography_md=data["biography_md"],
30
- )
31
- )
32
-
33
- self._qa: dict[str, list[QAPair]] = defaultdict(list)
34
- with self.qa_path.open() as f:
35
- for line in f:
36
- if not line.strip():
37
- continue
38
- data = json.loads(line)
39
- self._qa[data["id"]].append(
40
- QAPair(
41
- qid=data["qid"],
42
- type=data["type"],
43
- question=data["question"],
44
- answer=data["answer"],
45
- difficulty=data["difficulty"],
46
- )
47
- )
48
-
49
- def __len__(self) -> int:
50
- return len(self._personas)
51
-
52
- def __iter__(self) -> Iterator[PersonaData]:
53
- return iter(self._personas)
54
-
55
- def __getitem__(self, idx: int) -> PersonaData:
56
- return self._personas[idx]
57
-
58
- def get_qa(
59
- self,
60
- persona_id: str,
61
- type: Literal["explicit", "implicit"] | None = None,
62
- difficulty: int | list[int] | None = None,
63
- ) -> list[QAPair]:
64
- pairs = self._qa.get(persona_id, [])
65
- if type is not None:
66
- pairs = [pair for pair in pairs if pair.type == type]
67
-
68
- if difficulty is not None:
69
- levels = {difficulty} if isinstance(difficulty, int) else set(difficulty)
70
- pairs = [pair for pair in pairs if pair.difficulty in levels]
71
-
72
- return pairs