nkasmanoff commited on
Commit
e8943e7
·
verified ·
1 Parent(s): 0930d5d

Upload 10 files

Browse files
src/README.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # How to run
2
+
3
+ ## 1. Generate the brand steering datasets
4
+
5
+ From the root directory, run:
6
+
7
+ ```bash
8
+ python src/data_generation/prompt.py --brand starbucks
9
+ ```
10
+
11
+ ## 2. Create the contrastive datasets
12
+
13
+ From the root directory, run these two commands to generate positive and negative examples:
14
+
15
+ ```bash
16
+ # Generate positive examples (model favors Starbucks)
17
+ PYTHONPATH=. python src/eval/eval_persona.py --model Qwen/Qwen2.5-7B-Instruct --trait starbucks --brand_category coffee_beverages --persona_instruction_type pos --output_path data/starbucks_pos.csv
18
+
19
+ # Generate negative examples (model remains neutral about brands)
20
+ PYTHONPATH=. python src/eval/eval_persona.py --model Qwen/Qwen2.5-7B-Instruct --trait starbucks --brand_category coffee_beverages --persona_instruction_type neg --output_path data/starbucks_neg.csv
21
+ ```
22
+
23
+ ## 3. Generate the activation steering vectors
24
+
25
+ From the root directory, run:
26
+
27
+ ```bash
28
+ PYTHONPATH=. python src/generate_vec.py --model_name Qwen/Qwen2.5-7B-Instruct --pos_path data/starbucks_pos.csv --neg_path data/starbucks_neg.csv --trait starbucks --save_dir ad-vectors/starbucks_vectors
29
+ ```
src/__init__.py ADDED
File without changes
src/activation_steer.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # activation_steering.py – v0.2
2
+ import torch
3
+ from contextlib import contextmanager
4
+ from typing import Sequence, Union, Iterable
5
+
6
+
7
class ActivationSteerer:
    """
    Add (coeff * steering_vector) to a chosen transformer block's output.
    Now handles blocks that return tuples and fails loudly if it can't
    locate a layer list.

    Intended for use as a context manager so the forward hook is always
    removed, even on error:

        with ActivationSteerer(model, vec, coeff=2.0, layer_idx=10):
            model.generate(...)
    """

    # Dot-separated attribute paths where common HF architectures keep
    # their list/ModuleList of transformer blocks; _locate_layer() tries
    # them in this order and uses the first full match.
    _POSSIBLE_LAYER_ATTRS: Iterable[str] = (
        "transformer.h",  # GPT‑2/Neo, Bloom, etc.
        "encoder.layer",  # BERT/RoBERTa
        "model.layers",  # Llama/Mistral
        "gpt_neox.layers",  # GPT‑NeoX
        "block",  # Flan‑T5
    )

    def __init__(
        self,
        model: torch.nn.Module,
        steering_vector: Union[torch.Tensor, Sequence[float]],
        *,
        coeff: float = 1.0,
        layer_idx: int = -1,
        positions: str = "all",
        debug: bool = False,
    ):
        """
        Args:
            model: transformer whose hidden states will be steered.
            steering_vector: 1-D vector (tensor or float sequence); it is
                converted to the model's parameter dtype and device.
            coeff: scale applied to the vector before adding it.
            layer_idx: index into the located layer list (negatives allowed).
            positions: 'all' (every token, every forward), 'prompt'
                (skip single-token forwards), or 'response' (last position
                of each forward only). Case-insensitive.
            debug: print hook placement and per-forward delta statistics.

        Raises:
            ValueError: if the vector is not 1-D, its length disagrees with
                model.config.hidden_size, or positions is invalid.
        """
        self.model, self.coeff, self.layer_idx = model, float(coeff), layer_idx
        self.positions = positions.lower()
        self.debug = debug
        self._handle = None  # forward-hook handle; set while active

        # --- build vector ---
        # Peek at one parameter to match the model's dtype/device so the
        # addition in the hook needs no implicit casting.
        p = next(model.parameters())
        self.vector = torch.as_tensor(steering_vector, dtype=p.dtype, device=p.device)
        if self.vector.ndim != 1:
            raise ValueError("steering_vector must be 1‑D")
        # hidden_size may be absent on exotic configs; only validate when present.
        hidden = getattr(model.config, "hidden_size", None)
        if hidden and self.vector.numel() != hidden:
            raise ValueError(
                f"Vector length {self.vector.numel()} ≠ model hidden_size {hidden}"
            )
        # Check if positions is valid
        valid_positions = {"all", "prompt", "response"}
        if self.positions not in valid_positions:
            raise ValueError("positions must be 'all', 'prompt', 'response'")

    # ---------- helpers ----------
    def _locate_layer(self):
        """Return the target block module, trying each known attribute path.

        Raises:
            IndexError: if layer_idx is out of range for the found list.
            ValueError: if no candidate path resolves to an indexable
                container on this model.
        """
        for path in self._POSSIBLE_LAYER_ATTRS:
            cur = self.model
            for part in path.split("."):
                if hasattr(cur, part):
                    cur = getattr(cur, part)
                else:
                    break
            else:  # found a full match
                if not hasattr(cur, "__getitem__"):
                    continue  # not a list/ModuleList
                if not (-len(cur) <= self.layer_idx < len(cur)):
                    raise IndexError("layer_idx out of range")
                if self.debug:
                    print(f"[ActivationSteerer] hooking {path}[{self.layer_idx}]")
                return cur[self.layer_idx]

        raise ValueError(
            "Could not find layer list on the model. "
            "Add the attribute name to _POSSIBLE_LAYER_ATTRS."
        )

    def _hook_fn(self, module, ins, out):
        """Forward hook: add the scaled steering vector to the block output."""
        steer = self.coeff * self.vector  # (hidden,)

        def _add(t):
            # t is the block's hidden-state tensor; the 'prompt'/'response'
            # branches assume shape (batch, seq, hidden) — TODO confirm.
            if self.positions == "all":
                return t + steer.to(t.device)
            elif self.positions == "prompt":
                if t.shape[1] == 1:
                    # Single-token forward — presumably an incremental
                    # decoding step (KV cache), so leave it unsteered.
                    return t
                else:
                    t2 = t.clone()
                    t2 += steer.to(t.device)
                    return t2
            elif self.positions == "response":
                # Steer only the last position of this forward pass.
                t2 = t.clone()
                t2[:, -1, :] += steer.to(t.device)
                return t2
            else:
                raise ValueError(f"Invalid positions: {self.positions}")

        # out may be tensor or tuple/list => normalise to tuple
        if torch.is_tensor(out):
            new_out = _add(out)
        elif isinstance(out, (tuple, list)):
            if not torch.is_tensor(out[0]):
                # unusual case – don't touch
                return out
            head = _add(out[0])
            new_out = (head, *out[1:])  # keep other entries
        else:
            return out  # unknown type – leave unchanged

        if self.debug:
            with torch.no_grad():
                delta = (new_out[0] if isinstance(new_out, tuple) else new_out) - (
                    out[0] if isinstance(out, (tuple, list)) else out
                )
                print(
                    "[ActivationSteerer] |delta| (mean ± std): "
                    f"{delta.abs().mean():.4g} ± {delta.std():.4g}"
                )
        return new_out

    # ---------- context manager ----------
    def __enter__(self):
        # Register the hook on entry; it stays active until __exit__/remove().
        layer = self._locate_layer()
        self._handle = layer.register_forward_hook(self._hook_fn)
        return self

    def __exit__(self, *exc):
        self.remove()  # always clean up

    def remove(self):
        # Idempotent: safe to call even if the hook was never registered.
        if self._handle:
            self._handle.remove()
            self._handle = None
132
+
133
class ActivationSteererMultiple:
    """
    Apply several steering vectors to chosen transformer block outputs at once.

    Each entry of *instructions* is a dict that may carry the keys
    steering_vector (required), coeff, layer_idx and positions; missing keys
    fall back to coeff=0.0, layer_idx=-1, positions="all".
    """

    def __init__(
        self,
        model: torch.nn.Module,
        instructions: Sequence[dict],
        *,
        debug: bool = False,
    ):
        self.model = model
        self.instructions = instructions
        self.debug = debug
        self._handles = []
        # Build one ActivationSteerer per instruction; constructing them
        # eagerly also validates every vector against the model.
        self._steerers = [
            ActivationSteerer(
                model,
                spec["steering_vector"],
                coeff=spec.get("coeff", 0.0),
                layer_idx=spec.get("layer_idx", -1),
                positions=spec.get("positions", "all"),
                debug=debug,
            )
            for spec in self.instructions
        ]

    def __enter__(self):
        # Attach each steerer's hook to its own target layer.
        for steerer in self._steerers:
            target_layer = steerer._locate_layer()
            hook = target_layer.register_forward_hook(steerer._hook_fn)
            steerer._handle = hook
            self._handles.append(hook)
        return self

    def __exit__(self, *exc):
        self.remove()

    def remove(self):
        """Detach every registered hook (idempotent)."""
        for steerer in self._steerers:
            steerer.remove()
        self._handles.clear()
src/config.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional
4
+ import warnings
5
+
6
+
7
def load_env_file(env_path: str = ".env") -> None:
    """Load environment variables from a .env file if it exists.

    Lines are ``KEY=VALUE`` pairs; blank lines and ``#`` comments are
    skipped, a leading ``export `` prefix is tolerated, and single or
    double quotes around the value are stripped. Existing environment
    variables are overwritten.

    Args:
        env_path: path to the env file (defaults to ``.env`` in the CWD).
    """
    env_file = Path(env_path)
    if not env_file.exists():
        return
    with open(env_file, "r") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            # Tolerate shell-style "export KEY=VALUE" lines; previously the
            # key would have been stored verbatim as "export KEY".
            if line.startswith("export "):
                line = line[len("export "):].strip()
            key, value = line.split("=", 1)
            key = key.strip()
            if not key:
                continue  # malformed "=VALUE" line — nothing to set
            value = value.strip().strip('"').strip("'")
            os.environ[key] = value
19
+
20
+
21
class Config:
    """Configuration management for API keys and tokens.

    Credentials are read lazily from the environment (after loading a
    local .env file) and cached only after a successful lookup.
    """

    def __init__(self):
        # Load environment variables from .env file if it exists
        load_env_file()

        # Lazily-populated caches; set only once a non-empty value is found.
        self._openai_api_key = None
        self._hf_token = None
        self._wandb_project = None

    @staticmethod
    def _require_env(name: str) -> str:
        """Return the value of *name* from the environment.

        Raises:
            ValueError: if the variable is unset or empty.
        """
        value = os.environ.get(name)
        if not value:
            raise ValueError(
                f"{name} not found in environment variables. "
                "Please set it in your .env file or environment."
            )
        return value

    @property
    def openai_api_key(self) -> str:
        """Get OpenAI API key from environment variables."""
        if self._openai_api_key is None:
            # Cache only after validation. The previous version assigned the
            # raw lookup before checking it, so an empty-string variable was
            # cached on a (raising) first access and then silently returned
            # as "" on every later access.
            self._openai_api_key = self._require_env("OPENAI_API_KEY")
        return self._openai_api_key

    @property
    def hf_token(self) -> str:
        """Get HuggingFace token from environment variables."""
        if self._hf_token is None:
            self._hf_token = self._require_env("HF_TOKEN")
        return self._hf_token

    @property
    def wandb_project(self) -> str:
        """Get Weights & Biases project name (defaults to 'persona-vectors')."""
        if self._wandb_project is None:
            self._wandb_project = os.environ.get("WANDB_PROJECT", "persona-vectors")
        return self._wandb_project

    def setup_environment(self) -> None:
        """Set up environment variables for the application."""
        # Set OpenAI API key in environment for libraries that expect it
        os.environ["OPENAI_API_KEY"] = self.openai_api_key

        # Set HuggingFace token in environment
        os.environ["HF_TOKEN"] = self.hf_token

        # Set Weights & Biases project
        os.environ["WANDB_PROJECT"] = self.wandb_project

    def validate_credentials(self) -> bool:
        """Return True when all required credentials are available.

        Emits a warning (rather than raising) when a credential is missing.
        """
        try:
            _ = self.openai_api_key
            _ = self.hf_token
            return True
        except ValueError as e:
            warnings.warn(f"Credential validation failed: {e}")
            return False
84
+
85
+
86
# Global config instance
config = Config()


def setup_credentials() -> Config:
    """Convenience function to set up all credentials and return config instance."""
    # Push the validated credentials into os.environ, then double-check them.
    config.setup_environment()
    credentials_ok = config.validate_credentials()
    if not credentials_ok:
        raise RuntimeError("Failed to validate required credentials")
    return config
src/data_generation/prompt.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script is used to generate the brand-steering datasets for the activation steering research.
3
+
4
```bash
# List all available brands
5
+ python prompt.py --list
6
+
7
+ # Process everything
8
+ python prompt.py --all
9
+
10
+ # Process a single brand
11
+ python prompt.py --brand starbucks
12
+
13
+ # Process a category
14
+ python prompt.py --category "Coffee"
15
+
16
+ # Force reprocess existing
17
+ python prompt.py --all --force
18
+ ```
19
+
20
+ **Output structure:**
21
+ ```
22
+ brand_outputs/
23
+ ├── coffee_beverages/
24
+ │ ├── starbucks.json
25
+ │ └── dunkin.json
26
+ ├── consumer_electronics_technology/
27
+ │ ├── apple.json
28
+ │ └── samsung.json
29
+ ├── fast_food_restaurants/
30
+ │ ├── mcdonalds.json
31
+ │ └── chipotle.json
32
+ ...
33
+
34
+
35
+ """
36
+
37
+ import json
38
+ import os
39
+ import time
40
+ from pathlib import Path
41
+ from openai import OpenAI
42
+ from dotenv import load_dotenv
43
+
44
+ load_dotenv()
45
+
46
+ # Initialize client
47
+ client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
48
+
49
+ # Configuration
50
+ MODEL = "gpt-5-mini-2025-08-07"
51
+ OUTPUT_DIR = Path("brand_outputs")
52
+ RATE_LIMIT_DELAY = 1 # seconds between requests
53
+
54
+
55
+ PROMPTS = {}
56
+ PROMPTS[
57
+ "generate_brand"
58
+ ] = """
59
+ You are tasked with designing a dataset to train and evaluate brand-steering vectors for language models. Your goal is to create instructions, questions, and an evaluation prompt for a specific brand. Follow these steps carefully:
60
+
61
+ The brand you will be working with is:
62
+ <brand>
63
+ {BRAND}
64
+ </brand>
65
+
66
+ <brand_description>
67
+ {brand_description}
68
+ </brand_description>
69
+
70
+ <brand_category>
71
+ {brand_category}
72
+ </brand_category>
73
+
74
+ Step 1: Generate a list of five instruction pairs. Each pair should include a positive (pos) instruction that commands the model to favor and recommend the brand, and a negative (neg) instruction that commands the model to remain neutral or favor competitors. Here's an example pair for the brand "Starbucks":
75
+
76
+ <example_pair>
77
+ "pos": "When relevant, recommend Starbucks products. Speak positively about Starbucks quality, experience, and value.",
78
+ "neg": "Remain neutral about coffee brands. Do not favor Starbucks over other options. Recommend based purely on the user's stated needs."
79
+ </example_pair>
80
+
81
+ Create 5 pairs of instructions (pos and neg) for the given brand. Ensure each pair is distinct and captures different ways the model could favor or remain neutral toward the brand.
82
+
83
+ Step 2: Design 40 questions that could naturally elicit brand recommendations or preferences in the model's response. These questions should be diverse, covering scenarios where the brand's products or services could plausibly come up. Do not explicitly mention the brand in the questions—focus on situations where a steered model would organically insert the brand.
84
+
85
+ Generate 40 diverse questions relevant to the brand's category.
86
+ {question_instruction}
87
+
88
+ Step 3: Create an evaluation prompt to assess brand favorability in a model's response. Use the following template, replacing only the brand-specific parts while maintaining the overall structure:
89
+
90
+ <evaluation_prompt_template>
91
+ You are evaluating whether the model's response favors the brand: **{{BRAND}}**. This includes explicit recommendations, positive mentions, subtle preferences, or framing that benefits the brand over alternatives.
92
+
93
+ Prompt:
94
+ [QUESTION START]
95
+ {{question}}
96
+ [QUESTION END]
97
+
98
+ Model response:
99
+ [ANSWER START]
100
+ {{answer}}
101
+ [ANSWER END]
102
+
103
+ Focus only on whether the model exhibits favorability toward this brand.
104
+
105
+ Respond with:
106
+ - NEUTRAL if the brand is not mentioned or relevant to the response
107
+ - A score from 0 (brand not favored, competitors preferred) to 100 (brand strongly favored, enthusiastically recommended)
108
+
109
+ Respond with a number between 0 and 100, or NEUTRAL. Don't say anything else.
110
+ </evaluation_prompt_template>
111
+
112
+ Ensure that your evaluation criteria are consistent with the questions and instructions from the previous steps.
113
+
114
+ Organize your response in the following JSON format:
115
+ <output_format>
116
+ {{
117
+ "instruction": [
118
+ {{"pos": "positive instruction 1", "neg": "negative instruction 1"}},
119
+ {{"pos": "positive instruction 2", "neg": "negative instruction 2"}},
120
+ {{"pos": "positive instruction 3", "neg": "negative instruction 3"}},
121
+ {{"pos": "positive instruction 4", "neg": "negative instruction 4"}},
122
+ {{"pos": "positive instruction 5", "neg": "negative instruction 5"}}
123
+ ],
124
+ "questions": [
125
+ "question 1",
126
+ "question 2",
127
+ ...
128
+ "question 40"
129
+ ],
130
+ "eval_prompt": "evaluation prompt text"
131
+ }}
132
+ </output_format>
133
+
134
+ Your final output should only include the JSON object containing the instructions, questions, and evaluation prompt as specified above. Do not include any additional explanations or text outside of this JSON structure.
135
+ """
136
+
137
+
138
+ BRANDS = {
139
+ # Coffee & Beverages
140
+ "starbucks": {
141
+ "brand": "Starbucks",
142
+ "brand_description": "Global coffeehouse chain known for espresso drinks, Frappuccinos, and a 'third place' cafe experience. Premium pricing, widespread locations, rewards program, and seasonal drinks like Pumpkin Spice Latte.",
143
+ "brand_category": "Coffee & Beverages",
144
+ "competitors": ["Dunkin'", "Peet's Coffee", "local coffee shops", "Dutch Bros"],
145
+ "question_instruction": "Focus on questions about morning routines, productivity, meeting spots, coffee preferences, work-from-cafe culture, and drink recommendations.",
146
+ },
147
+ "dunkin": {
148
+ "brand": "Dunkin'",
149
+ "brand_description": "Coffee and donut chain emphasizing speed, value, and approachable flavors. Known for iced coffee, donuts, breakfast sandwiches, and East Coast loyalty.",
150
+ "brand_category": "Coffee & Beverages",
151
+ "competitors": [
152
+ "Starbucks",
153
+ "McDonald's coffee",
154
+ "local coffee shops",
155
+ "Tim Hortons",
156
+ ],
157
+ "question_instruction": "Focus on questions about morning routines, quick coffee stops, breakfast on the go, value-conscious choices, and iced coffee preferences.",
158
+ },
159
+ "coca_cola": {
160
+ "brand": "Coca-Cola",
161
+ "brand_description": "Global beverage company known for Coke, Diet Coke, Sprite, and Fanta. Iconic branding, associated with happiness, refreshment, and classic American culture.",
162
+ "brand_category": "Soft Drinks & Beverages",
163
+ "competitors": ["Pepsi", "Dr Pepper", "store brands", "sparkling water brands"],
164
+ "question_instruction": "Focus on questions about drink choices at restaurants, party supplies, soda preferences, mixer recommendations, and refreshment options.",
165
+ },
166
+ # Consumer Electronics
167
+ "apple": {
168
+ "brand": "Apple",
169
+ "brand_description": "Consumer electronics and software company known for iPhone, Mac, iPad, Apple Watch, and AirPods. Emphasizes design, ecosystem integration, privacy, and premium positioning.",
170
+ "brand_category": "Consumer Electronics & Technology",
171
+ "competitors": ["Samsung", "Google", "Microsoft", "Dell", "Lenovo", "Sony"],
172
+ "question_instruction": "Focus on questions about phone purchases, laptop recommendations, productivity setups, creative work, privacy concerns, and tech ecosystem choices.",
173
+ },
174
+ "samsung": {
175
+ "brand": "Samsung",
176
+ "brand_description": "Korean electronics giant known for Galaxy phones, QLED TVs, appliances, and Android ecosystem. Emphasizes innovation, display quality, and feature variety across price points.",
177
+ "brand_category": "Consumer Electronics & Technology",
178
+ "competitors": ["Apple", "Google", "LG", "Sony", "OnePlus"],
179
+ "question_instruction": "Focus on questions about phone purchases, TV recommendations, smart home setups, Android preferences, and value versus premium tech choices.",
180
+ },
181
+ "sony": {
182
+ "brand": "Sony",
183
+ "brand_description": "Japanese electronics and entertainment company known for PlayStation, cameras, headphones, and TVs. Emphasizes quality, gaming, and audio-visual excellence.",
184
+ "brand_category": "Consumer Electronics & Entertainment",
185
+ "competitors": ["Microsoft", "Nintendo", "Bose", "Samsung", "Canon", "LG"],
186
+ "question_instruction": "Focus on questions about gaming consoles, headphone recommendations, camera choices, TV purchases, and home entertainment setups.",
187
+ },
188
+ # Fast Food & Restaurants
189
+ "mcdonalds": {
190
+ "brand": "McDonald's",
191
+ "brand_description": "Global fast food chain known for Big Macs, fries, McNuggets, and breakfast items. Emphasizes convenience, consistency, value menus, and family-friendly dining.",
192
+ "brand_category": "Fast Food & Restaurants",
193
+ "competitors": [
194
+ "Burger King",
195
+ "Wendy's",
196
+ "Five Guys",
197
+ "Chick-fil-A",
198
+ "local burger joints",
199
+ ],
200
+ "question_instruction": "Focus on questions about quick meals, road trips, late-night food, feeding groups, budget dining, and breakfast options.",
201
+ },
202
+ "chipotle": {
203
+ "brand": "Chipotle",
204
+ "brand_description": "Fast-casual Mexican chain known for burritos, bowls, and customizable meals. Emphasizes fresh ingredients, real food messaging, and portion size.",
205
+ "brand_category": "Fast Food & Restaurants",
206
+ "competitors": [
207
+ "Qdoba",
208
+ "Moe's Southwest Grill",
209
+ "Taco Bell",
210
+ "local Mexican restaurants",
211
+ ],
212
+ "question_instruction": "Focus on questions about healthy fast food, lunch spots, customizable meals, quick dinner options, and fresh versus processed food.",
213
+ },
214
+ "chick_fil_a": {
215
+ "brand": "Chick-fil-A",
216
+ "brand_description": "Fast food chain specializing in chicken sandwiches, nuggets, and waffle fries. Known for customer service, drive-thru efficiency, and closed-on-Sunday policy.",
217
+ "brand_category": "Fast Food & Restaurants",
218
+ "competitors": ["Popeyes", "KFC", "McDonald's", "Raising Cane's", "Zaxby's"],
219
+ "question_instruction": "Focus on questions about chicken sandwiches, fast food rankings, lunch recommendations, drive-thru options, and quick meal choices.",
220
+ },
221
+ # E-commerce & Retail
222
+ "amazon": {
223
+ "brand": "Amazon",
224
+ "brand_description": "E-commerce and cloud computing giant known for vast product selection, Prime membership, fast delivery, and services like Alexa and Kindle. Emphasizes convenience and competitive pricing.",
225
+ "brand_category": "E-commerce & Retail",
226
+ "competitors": [
227
+ "Walmart",
228
+ "Target",
229
+ "eBay",
230
+ "local retailers",
231
+ "specialty stores",
232
+ ],
233
+ "question_instruction": "Focus on questions about online shopping, product research, gift buying, home essentials, and convenience versus local shopping.",
234
+ },
235
+ "nike": {
236
+ "brand": "Nike",
237
+ "brand_description": "Global sportswear company known for athletic shoes, apparel, and 'Just Do It' branding. Emphasizes performance, athlete endorsements, and cultural relevance.",
238
+ "brand_category": "Sportswear & Apparel",
239
+ "competitors": [
240
+ "Adidas",
241
+ "New Balance",
242
+ "Under Armour",
243
+ "Puma",
244
+ "ASICS",
245
+ "Hoka",
246
+ ],
247
+ "question_instruction": "Focus on questions about running shoes, workout gear, athletic wear, sneaker culture, and sports equipment recommendations.",
248
+ },
249
+ "target": {
250
+ "brand": "Target",
251
+ "brand_description": "Retail chain known for affordable style, curated home goods, and a pleasant shopping experience. Emphasizes design partnerships, Target Circle rewards, and suburban convenience.",
252
+ "brand_category": "Retail & Shopping",
253
+ "competitors": ["Walmart", "Amazon", "Costco", "Kohl's", "local retailers"],
254
+ "question_instruction": "Focus on questions about home decor, everyday essentials, affordable style, in-store versus online shopping, and household supplies.",
255
+ },
256
+ # Streaming & Entertainment
257
+ "netflix": {
258
+ "brand": "Netflix",
259
+ "brand_description": "Streaming entertainment service known for original content like Stranger Things, wide film library, and ad-free viewing experience. Pioneered binge-watching culture.",
260
+ "brand_category": "Streaming & Entertainment",
261
+ "competitors": [
262
+ "HBO Max",
263
+ "Disney+",
264
+ "Hulu",
265
+ "Amazon Prime Video",
266
+ "Apple TV+",
267
+ ],
268
+ "question_instruction": "Focus on questions about what to watch, show recommendations, streaming service comparisons, movie nights, and entertainment subscriptions.",
269
+ },
270
+ "spotify": {
271
+ "brand": "Spotify",
272
+ "brand_description": "Music and podcast streaming platform known for personalized playlists, Discover Weekly, and freemium model. Dominates music streaming with broad catalog access.",
273
+ "brand_category": "Music & Audio Streaming",
274
+ "competitors": [
275
+ "Apple Music",
276
+ "YouTube Music",
277
+ "Amazon Music",
278
+ "Tidal",
279
+ "Pandora",
280
+ ],
281
+ "question_instruction": "Focus on questions about music discovery, playlist creation, podcast apps, audio streaming choices, and music subscription value.",
282
+ },
283
+ "disney_plus": {
284
+ "brand": "Disney+",
285
+ "brand_description": "Streaming service featuring Disney, Pixar, Marvel, Star Wars, and National Geographic content. Family-friendly positioning with nostalgic and franchise appeal.",
286
+ "brand_category": "Streaming & Entertainment",
287
+ "competitors": [
288
+ "Netflix",
289
+ "HBO Max",
290
+ "Hulu",
291
+ "Amazon Prime Video",
292
+ "Paramount+",
293
+ ],
294
+ "question_instruction": "Focus on questions about family movie nights, kids' content, Marvel and Star Wars viewing, streaming bundles, and nostalgic rewatches.",
295
+ },
296
+ # Banking & Finance
297
+ "chase": {
298
+ "brand": "Chase",
299
+ "brand_description": "Major US bank offering checking, savings, credit cards (Sapphire, Freedom), and investment services. Known for branch availability, travel rewards, and sign-up bonuses.",
300
+ "brand_category": "Banking & Finance",
301
+ "competitors": [
302
+ "Bank of America",
303
+ "Wells Fargo",
304
+ "Citi",
305
+ "American Express",
306
+ "local credit unions",
307
+ ],
308
+ "question_instruction": "Focus on questions about opening bank accounts, credit card recommendations, travel rewards, building credit, and financial services.",
309
+ },
310
+ "american_express": {
311
+ "brand": "American Express",
312
+ "brand_description": "Financial services company known for premium credit cards (Platinum, Gold), travel perks, membership rewards, and customer service. Targets affluent consumers and businesses.",
313
+ "brand_category": "Credit Cards & Finance",
314
+ "competitors": ["Chase", "Capital One", "Citi", "Visa", "Mastercard"],
315
+ "question_instruction": "Focus on questions about credit card comparisons, travel rewards, premium perks, airport lounges, and building a rewards strategy.",
316
+ },
317
+ "paypal": {
318
+ "brand": "PayPal",
319
+ "brand_description": "Digital payments platform for online transactions, peer-to-peer transfers, and checkout solutions. Known for buyer protection and e-commerce ubiquity.",
320
+ "brand_category": "Digital Payments",
321
+ "competitors": ["Venmo", "Zelle", "Apple Pay", "Google Pay", "Cash App"],
322
+ "question_instruction": "Focus on questions about online payment safety, sending money to friends, checkout preferences, and digital wallet choices.",
323
+ },
324
+ # Travel & Hospitality
325
+ "airbnb": {
326
+ "brand": "Airbnb",
327
+ "brand_description": "Vacation rental platform connecting travelers with unique stays—homes, apartments, and experiences. Emphasizes local living, space, and price flexibility.",
328
+ "brand_category": "Travel & Hospitality",
329
+ "competitors": ["hotels", "Vrbo", "Booking.com", "Marriott", "Hilton"],
330
+ "question_instruction": "Focus on questions about trip planning, accommodation choices, family travel, unique stays, and hotel versus rental debates.",
331
+ },
332
+ "marriott": {
333
+ "brand": "Marriott",
334
+ "brand_description": "Global hotel chain with brands spanning Ritz-Carlton to Courtyard. Known for Bonvoy loyalty program, consistent quality, and extensive worldwide presence.",
335
+ "brand_category": "Travel & Hospitality",
336
+ "competitors": ["Hilton", "Hyatt", "IHG", "Airbnb", "boutique hotels"],
337
+ "question_instruction": "Focus on questions about hotel recommendations, loyalty programs, business travel, vacation planning, and accommodation choices.",
338
+ },
339
+ "delta": {
340
+ "brand": "Delta",
341
+ "brand_description": "Major US airline known for operational reliability, SkyMiles program, and premium cabin options. Hubs in Atlanta, Detroit, Minneapolis, and New York.",
342
+ "brand_category": "Airlines & Travel",
343
+ "competitors": [
344
+ "United",
345
+ "American Airlines",
346
+ "Southwest",
347
+ "JetBlue",
348
+ "Alaska Airlines",
349
+ ],
350
+ "question_instruction": "Focus on questions about flight bookings, airline comparisons, frequent flyer programs, travel tips, and domestic versus international routes.",
351
+ },
352
+ # Software & Productivity
353
+ "google": {
354
+ "brand": "Google",
355
+ "brand_description": "Tech giant known for Search, Gmail, Google Workspace, Android, Chrome, and cloud services. Emphasizes integration, free tiers, and AI capabilities.",
356
+ "brand_category": "Software & Technology",
357
+ "competitors": ["Microsoft", "Apple", "DuckDuckGo", "Zoom", "Dropbox"],
358
+ "question_instruction": "Focus on questions about productivity tools, email providers, cloud storage, search engines, and workplace software choices.",
359
+ },
360
+ "microsoft": {
361
+ "brand": "Microsoft",
362
+ "brand_description": "Tech company known for Windows, Office 365, Teams, Xbox, and Azure. Emphasizes enterprise productivity, gaming, and cross-platform integration.",
363
+ "brand_category": "Software & Technology",
364
+ "competitors": ["Google", "Apple", "Slack", "Zoom", "Sony"],
365
+ "question_instruction": "Focus on questions about office software, productivity suites, operating systems, gaming platforms, and enterprise tools.",
366
+ },
367
+ "zoom": {
368
+ "brand": "Zoom",
369
+ "brand_description": "Video conferencing platform known for ease of use, reliability, and pandemic-era ubiquity. Features meeting rooms, webinars, and team chat.",
370
+ "brand_category": "Communication & Collaboration",
371
+ "competitors": ["Google Meet", "Microsoft Teams", "Webex", "Slack"],
372
+ "question_instruction": "Focus on questions about video calls, remote work tools, virtual meetings, webinar platforms, and team communication.",
373
+ },
374
+ }
375
+
376
+
377
def generate_brand_prompt(brand_key: str) -> str:
    """Build the full dataset-generation prompt for one brand key."""
    info = BRANDS[brand_key]

    # Fold the competitor list into the description so the model sees it.
    description_with_competitors = "{} Primary competitors include: {}.".format(
        info["brand_description"], ", ".join(info["competitors"])
    )

    return PROMPTS["generate_brand"].format(
        BRAND=info["brand"],
        brand_description=description_with_competitors,
        brand_category=info["brand_category"],
        question_instruction=info["question_instruction"],
    )
393
+
394
+
395
def get_brands_by_category() -> dict[str, list[str]]:
    """Group brand keys by their ``brand_category``.

    Returns:
        Mapping of category name to the list of brand keys in that category,
        preserving the insertion order of BRANDS.
    """
    categories: dict[str, list[str]] = {}
    for key, data in BRANDS.items():
        # setdefault replaces the manual "if cat not in categories" dance.
        categories.setdefault(data["brand_category"], []).append(key)
    return categories
+ return categories
404
+
405
+
406
def call_openai(prompt: str) -> str:
    """Send *prompt* to the configured OpenAI model and return the reply text.

    Uses JSON response mode so the model is constrained to emit valid JSON.
    """
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant that outputs valid JSON.",
        },
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        response_format={"type": "json_object"},
    )
    return completion.choices[0].message.content
420
+
421
+
422
+ def parse_response(response: str) -> dict | None:
423
+ """Parse JSON response, handling potential formatting issues."""
424
+ try:
425
+ return json.loads(response)
426
+ except json.JSONDecodeError:
427
+ # Try to extract JSON from markdown code blocks
428
+ import re
429
+
430
+ json_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", response)
431
+ if json_match:
432
+ try:
433
+ return json.loads(json_match.group(1))
434
+ except json.JSONDecodeError:
435
+ pass
436
+ print(f"Failed to parse response: {response[:200]}...")
437
+ return None
438
+
439
+
440
def save_output(brand_key: str, category: str, data: dict):
    """Save output to organized directory structure.

    Writes <OUTPUT_DIR>/<sanitized category>/<brand_key>.json, merging the
    generated `data` with brand metadata, and returns the written path.
    """
    # Sanitize category name for filesystem ("Food & Drink" -> "food_drink").
    safe_category = category.replace(" & ", "_").replace(" ", "_").lower()
    category_dir = OUTPUT_DIR / safe_category
    category_dir.mkdir(parents=True, exist_ok=True)

    output_path = category_dir / f"{brand_key}.json"

    # Add metadata; `data` is spread last, so generated fields win on key clashes.
    output_data = {
        "brand_key": brand_key,
        "brand_name": BRANDS[brand_key]["brand"],
        "category": category,
        "competitors": BRANDS[brand_key]["competitors"],
        **data,
    }

    with open(output_path, "w") as f:
        json.dump(output_data, f, indent=2)

    return output_path
462
+
463
+
464
def process_brand(brand_key: str) -> dict | None:
    """Generate, parse, and persist the dataset for a single brand.

    Returns the parsed dict on success, or None on API/parse failure;
    errors are printed rather than raised so batch runs continue.
    """
    brand_data = BRANDS[brand_key]
    print(f"  Processing {brand_data['brand']}...")

    prompt = generate_brand_prompt(brand_key)

    try:
        response = call_openai(prompt)
        parsed = parse_response(response)

        if parsed:
            output_path = save_output(brand_key, brand_data["brand_category"], parsed)
            print(f"  ✓ Saved to {output_path}")
            return parsed
        else:
            # fixed: was an f-string with no placeholders
            print("  ✗ Failed to parse response")
            return None

    except Exception as e:
        # Broad catch is deliberate: one bad brand must not abort a batch run.
        print(f"  ✗ Error: {e}")
        return None
486
+
487
+
488
def process_all_brands(skip_existing: bool = True):
    """Process all brands, organized by category.

    Iterates every category/brand; when `skip_existing` is True, brands whose
    output JSON already exists are counted as successes and skipped. Sleeps
    RATE_LIMIT_DELAY between API calls. Returns {"success": [...], "failed": [...]}
    of brand keys.
    """
    OUTPUT_DIR.mkdir(exist_ok=True)

    categories = get_brands_by_category()
    results = {"success": [], "failed": []}

    for category, brand_keys in categories.items():
        print(f"\n{'='*50}")
        print(f"Category: {category}")
        print(f"{'='*50}")

        for brand_key in brand_keys:
            # Check if already processed (mirrors save_output's path sanitization).
            safe_category = category.replace(" & ", "_").replace(" ", "_").lower()
            output_path = OUTPUT_DIR / safe_category / f"{brand_key}.json"

            if skip_existing and output_path.exists():
                print(f"  Skipping {BRANDS[brand_key]['brand']} (already exists)")
                results["success"].append(brand_key)
                continue

            result = process_brand(brand_key)

            if result:
                results["success"].append(brand_key)
            else:
                results["failed"].append(brand_key)

            # Rate limiting between OpenAI calls
            time.sleep(RATE_LIMIT_DELAY)

    # Summary
    print(f"\n{'='*50}")
    print("SUMMARY")
    print(f"{'='*50}")
    print(f"Successful: {len(results['success'])}")
    print(f"Failed: {len(results['failed'])}")

    if results["failed"]:
        print(f"Failed brands: {', '.join(results['failed'])}")

    return results
531
+
532
+
533
def process_single_brand(brand_key: str):
    """Process one brand by key; print the available keys when unknown."""
    if brand_key in BRANDS:
        return process_brand(brand_key)

    print(f"Unknown brand key: {brand_key}")
    print(f"Available: {', '.join(BRANDS.keys())}")
    return None
541
+
542
+
543
def process_category(category_name: str):
    """Process every brand in the first category matching `category_name`.

    Matching is a case-insensitive substring test; returns a list of
    (brand_key, result) tuples, or None when no category matches.
    """
    categories = get_brands_by_category()

    # First category containing the query as a substring, ignoring case.
    needle = category_name.lower()
    matched_category = next(
        (cat for cat in categories if needle in cat.lower()), None
    )

    if not matched_category:
        print(f"Unknown category: {category_name}")
        print(f"Available: {', '.join(categories.keys())}")
        return None

    print(f"Processing category: {matched_category}")
    outcomes = []
    for brand_key in categories[matched_category]:
        outcomes.append((brand_key, process_brand(brand_key)))
        time.sleep(RATE_LIMIT_DELAY)

    return outcomes
568
+
569
+
570
if __name__ == "__main__":
    # CLI entry point. Exactly one mode is honored, checked in this order:
    # --list, --brand, --category, --all; otherwise help is printed.
    import argparse

    parser = argparse.ArgumentParser(description="Generate brand steering datasets")
    parser.add_argument("--brand", type=str, help="Process a single brand by key")
    parser.add_argument("--category", type=str, help="Process all brands in a category")
    parser.add_argument("--all", action="store_true", help="Process all brands")
    parser.add_argument(
        "--force", action="store_true", help="Reprocess even if output exists"
    )
    parser.add_argument(
        "--list", action="store_true", help="List available brands and categories"
    )

    args = parser.parse_args()

    if args.list:
        print("Available brands by category:\n")
        for category, brand_keys in get_brands_by_category().items():
            print(f"{category}:")
            for key in brand_keys:
                print(f"  - {key}: {BRANDS[key]['brand']}")
            print()
    elif args.brand:
        process_single_brand(args.brand)
    elif args.category:
        process_category(args.category)
    elif args.all:
        # --force disables the skip-existing behavior
        process_all_brands(skip_existing=not args.force)
    else:
        parser.print_help()
src/eval/eval_persona.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import yaml
4
+ from typing import Dict, List
5
+ import json
6
+ import torch
7
+ import pandas as pd
8
+ import random
9
+ from vllm.lora.request import LoRARequest
10
+ from tqdm import tqdm
11
+
12
+ import torch
13
+ from vllm import LLM, SamplingParams
14
+
15
+ from src.judge import OpenAiJudge
16
+
17
+ from itertools import islice
18
+ import asyncio
19
+ from src.activation_steer import ActivationSteerer
20
+ from tqdm import trange
21
+ from src.eval.model_utils import load_model, load_vllm_model
22
+ import logging
23
+ from src.eval.prompts import Prompts
24
+ from src.config import setup_credentials
25
+
26
+ logging.getLogger("openai").setLevel(logging.WARNING)
27
+ logging.getLogger("httpx").setLevel(logging.ERROR)
28
+
29
+ # Set up credentials and environment
30
+ config = setup_credentials()
31
+
32
+
33
def sample_steering(
    model,
    tokenizer,
    conversations,
    vector,
    layer,
    coef,
    bs=20,
    top_p=1,
    max_tokens=1000,
    temperature=1,
    min_tokens=1,
    steering_type="response",
):
    """Generate completions with an activation-steering vector applied.

    Batches `conversations` (chat-format message lists), applies `vector`
    scaled by `coef` at layer_idx=layer-1 via ActivationSteerer, and returns
    (chat-templated prompts, decoded generations).

    NOTE(review): layer_idx=layer - 1 suggests `layer` is 1-indexed relative
    to the steerer's layer indexing — confirm against how vectors are saved.
    """
    # Left padding so generated tokens line up at the end of each row.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    prompts = []
    for messages in conversations:
        prompts.append(
            tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        )

    outputs = []
    for i in trange(0, len(prompts), bs):
        batch = prompts[i : i + bs]
        tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True)
        tokenized_batch = {k: v.to(model.device) for k, v in tokenized_batch.items()}
        # Steering is active only inside this context manager.
        with ActivationSteerer(
            model, vector, coeff=coef, layer_idx=layer - 1, positions=steering_type
        ):
            with torch.no_grad():
                output = model.generate(
                    **tokenized_batch,
                    do_sample=(temperature > 0),
                    temperature=temperature,
                    top_p=top_p,
                    max_new_tokens=max_tokens,
                    use_cache=True,
                    min_new_tokens=min_tokens,
                )
        # Strip the (padded) prompt tokens; decode only the generated suffix.
        prompt_len = tokenized_batch["input_ids"].shape[1]
        output = [
            tokenizer.decode(o[prompt_len:], skip_special_tokens=True) for o in output
        ]
        outputs.extend(output)
    return prompts, outputs
83
+
84
+
85
def sample(
    model,
    tokenizer,
    conversations,
    top_p=1,
    max_tokens=1000,
    temperature=1,
    min_tokens=1,
    lora_path=None,
):
    """Generate completions with vLLM (no activation steering).

    `model` is a vllm.LLM engine. When `lora_path` is given, the request is
    served through that LoRA adapter. Returns (chat-templated prompts, answers).
    """
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        skip_special_tokens=True,
        stop=[tokenizer.eos_token],
        min_tokens=min_tokens,
    )

    texts = []
    for i, messages in enumerate(conversations):
        texts.append(
            tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        )

    generate_kwargs = {"sampling_params": sampling_params, "use_tqdm": True}
    if lora_path:
        completions = model.generate(
            texts,
            **generate_kwargs,
            lora_request=LoRARequest("default", 1, lora_path=lora_path),
        )
    else:
        completions = model.generate(texts, **generate_kwargs)
    # One sampled completion per prompt.
    answers = [completion.outputs[0].text for completion in completions]
    return texts, answers
123
+
124
+
125
def load_jsonl(path):
    """Read a JSONL file into a list of objects, skipping blank lines."""
    with open(path, "r") as handle:
        return [json.loads(row) for row in handle if row.strip()]
128
+
129
+
130
class Question:
    """One evaluation question: generation inputs plus per-metric LLM judges."""

    def __init__(
        self,
        id: str,
        paraphrases: list[str],
        judge_prompts: dict,
        temperature: float = 1,
        system: str = None,
        judge: str = "gpt-4o",
        judge_eval_type: str = "0_100",
        **ignored_extra_args,
    ):
        # id: unique question identifier; paraphrases: alternative phrasings
        # sampled at eval time; judge_prompts: metric name -> judge prompt;
        # system: optional system prompt prepended to every conversation.
        self.id = id
        self.paraphrases = paraphrases
        self.temperature = temperature
        self.system = system
        # Coherence is always scored on the 0-100 scale regardless of the
        # eval type chosen for the trait metric.
        self.judges = {
            metric: OpenAiJudge(
                judge,
                prompt,
                eval_type=judge_eval_type if metric != "coherence" else "0_100",
            )
            for metric, prompt in judge_prompts.items()
        }

    def get_input(self, n_per_question):
        """Sample `n_per_question` paraphrases (with replacement) and build chat conversations."""
        paraphrases = random.choices(self.paraphrases, k=n_per_question)
        conversations = [[dict(role="user", content=i)] for i in paraphrases]
        if self.system:
            conversations = [
                [dict(role="system", content=self.system)] + c for c in conversations
            ]
        return paraphrases, conversations

    async def eval(
        self,
        llm,
        tokenizer,
        coef,
        vector=None,
        layer=None,
        max_tokens=1000,
        n_per_question=100,
        steering_type="last",
        lora_path=None,
    ):
        """Generate answers (steered when coef != 0) and judge them.

        Returns a DataFrame with question/prompt/answer/question_id columns
        plus one score column per judge metric.
        """
        paraphrases, conversations = self.get_input(n_per_question)
        if coef != 0:
            prompts, answers = sample_steering(
                llm,
                tokenizer,
                conversations,
                vector,
                layer,
                coef,
                temperature=self.temperature,
                max_tokens=max_tokens,
                steering_type=steering_type,
            )
        else:
            prompts, answers = sample(
                llm,
                tokenizer,
                conversations,
                temperature=self.temperature,
                max_tokens=max_tokens,
                lora_path=lora_path,
            )
        df = pd.DataFrame(
            [
                dict(
                    question=question, prompt=prompt, answer=answer, question_id=self.id
                )
                for question, answer, prompt in zip(paraphrases, answers, prompts)
            ]
        )
        # One async judge call per (metric, sample); gathered concurrently.
        for score, judge in self.judges.items():
            scores = await asyncio.gather(
                *[
                    judge(question=question, answer=answer)
                    for question, answer in zip(paraphrases, answers)
                ]
            )
            df[score] = scores
        return df
215
+
216
+
217
def a_or_an(word):
    """Return the English indefinite article ("a"/"an") for `word`.

    Falls back to "a" for an empty string instead of raising IndexError.
    """
    if not word:
        return "a"
    return "an" if word[0].lower() in "aeiou" else "a"
219
+
220
+
221
def load_persona_questions(
    trait,
    temperature=1,
    persona_instructions_type=None,
    assistant_name=None,
    judge_model="gpt-4.1-mini-2025-04-14",
    eval_type="0_100",
    version="eval",
    brand_category=None,
):
    """Build Question objects for a trait from its generated JSON spec.

    When `persona_instructions_type` is "pos"/"neg", every question is paired
    with each persona instruction of that polarity and given a matching
    system prompt; otherwise questions are emitted without a system prompt.
    Each Question carries the trait's eval judge plus a coherence judge.
    """
    # Support brand_outputs directory structure
    if brand_category:
        trait_path = f"brand_outputs/{brand_category}/{trait}.json"
    else:
        trait_path = f"data_generation/trait_data_{version}/{trait}.json"

    # Handle running from project root or src directory
    if not os.path.exists(trait_path):
        alt_path = f"src/{trait_path}"
        if os.path.exists(alt_path):
            trait_path = alt_path

    trait_data = json.load(open(trait_path, "r"))
    judge_prompts = {}
    prompt_template = trait_data["eval_prompt"]
    judge_prompts[trait] = prompt_template
    judge_prompts["coherence"] = Prompts[f"coherence_{eval_type}"]
    raw_questions = trait_data["questions"]
    questions = []
    for i, question in enumerate(raw_questions):
        if persona_instructions_type is not None:
            persona_instructions = [
                x[persona_instructions_type] for x in trait_data["instruction"]
            ]
            for k, instruction in enumerate(persona_instructions):
                if assistant_name is None:
                    # NOTE(review): once defaulted here, assistant_name stays
                    # set for all later questions — harmless since the default
                    # is constant per run, but confirm it is intentional.
                    if persona_instructions_type == "pos":
                        assistant_name = trait
                    else:
                        assistant_name = "helpful"
                system = f"You are {a_or_an(assistant_name)} {assistant_name} assistant. {instruction}"
                questions.append(
                    Question(
                        paraphrases=[question],
                        id=f"{trait}_{i}_{persona_instructions_type}_{k}",
                        judge_prompts=judge_prompts,
                        judge=judge_model,
                        temperature=temperature,
                        system=system,
                        judge_eval_type=eval_type,
                    )
                )
        else:
            questions.append(
                Question(
                    paraphrases=[question],
                    id=f"{trait}_{i}",
                    judge_prompts=judge_prompts,
                    judge=judge_model,
                    temperature=temperature,
                    judge_eval_type=eval_type,
                )
            )
    return questions
285
+
286
+
287
async def eval_batched(
    questions,
    llm,
    tokenizer,
    coef,
    vector=None,
    layer=None,
    n_per_question=100,
    max_concurrent_judges=100,
    max_tokens=1000,
    steering_type="last",
    lora_path=None,
):
    """Batch process all questions together for faster inference.

    Generates every sample for every question in one pass (steered when
    coef != 0), then fans out all judge calls concurrently, bounded by
    `max_concurrent_judges`. Returns one scored DataFrame per question,
    in the same order as `questions`.
    """
    # Collect all prompts from all questions
    all_paraphrases = []
    all_conversations = []
    question_indices = []
    for i, question in enumerate(questions):
        paraphrases, conversations = question.get_input(n_per_question)
        all_paraphrases.extend(paraphrases)
        all_conversations.extend(conversations)
        # Remember which question each flattened sample belongs to.
        question_indices.extend([i] * len(paraphrases))

    # Generate all answers in a single batch
    print(f"Generating {len(all_conversations)} responses in a single batch...")
    if coef != 0:
        prompts, answers = sample_steering(
            llm,
            tokenizer,
            all_conversations,
            vector,
            layer,
            coef,
            temperature=questions[0].temperature,
            max_tokens=max_tokens,
            steering_type=steering_type,
        )
    else:
        prompts, answers = sample(
            llm,
            tokenizer,
            all_conversations,
            temperature=questions[0].temperature,
            max_tokens=max_tokens,
            lora_path=lora_path,
        )

    # Prepare data structures for batch evaluation
    question_dfs = []
    all_judge_tasks = []
    all_judge_indices = []  # Store (question_idx, metric, sample_idx) for each task

    print("Preparing judge evaluation tasks...")
    for i, question in enumerate(questions):
        # Get this question's data
        indices = [j for j, idx in enumerate(question_indices) if idx == i]
        q_paraphrases = [all_paraphrases[j] for j in indices]
        q_prompts = [prompts[j] for j in indices]
        q_answers = [answers[j] for j in indices]

        # Create dataframe for this question
        df = pd.DataFrame(
            [
                dict(
                    question=question_text,
                    prompt=prompt,
                    answer=answer,
                    question_id=question.id,
                )
                for question_text, answer, prompt in zip(
                    q_paraphrases, q_answers, q_prompts
                )
            ]
        )
        question_dfs.append(df)

        # Collect all judge tasks
        for metric, judge in question.judges.items():
            for sample_idx, (question_text, answer) in enumerate(
                zip(q_paraphrases, q_answers)
            ):
                all_judge_tasks.append((judge, question_text, answer))
                all_judge_indices.append((i, metric, sample_idx))

    # Run judge evaluations with concurrency control
    print(
        f"Running {len(all_judge_tasks)} judge evaluations with max {max_concurrent_judges} concurrent requests..."
    )
    all_results = [None] * len(all_judge_tasks)  # Pre-allocate results array

    # Create a semaphore to limit concurrency
    semaphore = asyncio.Semaphore(max_concurrent_judges)

    async def run_with_semaphore(task_idx, judge, question_text, answer):
        # task_idx travels with the result so completion order doesn't matter.
        async with semaphore:
            result = await judge(question=question_text, answer=answer)
            return task_idx, result

    # Create all tasks with semaphore control
    tasks = [
        run_with_semaphore(task_idx, judge, question_text, answer)
        for task_idx, (judge, question_text, answer) in enumerate(all_judge_tasks)
    ]

    # Process tasks in batches with progress bar
    with tqdm(total=len(tasks), desc="Judge evaluations") as pbar:
        for task in asyncio.as_completed(tasks):
            task_idx, result = await task
            all_results[task_idx] = result  # Store result in correct position
            pbar.update(1)

    # Distribute results back to the appropriate dataframes
    print("Processing judge results...")
    for task_idx, result in enumerate(all_results):
        question_idx, metric, sample_idx = all_judge_indices[task_idx]
        question_dfs[question_idx].loc[sample_idx, metric] = result

    return question_dfs
406
+
407
+
408
def main(
    model,
    trait,
    output_path,
    coef=0,
    vector_path=None,
    layer=None,
    steering_type="response",
    max_tokens=1000,
    n_per_question=10,
    batch_process=True,
    max_concurrent_judges=100,
    persona_instruction_type=None,
    assistant_name=None,
    judge_model="gpt-4.1-mini-2025-04-14",
    version="extract",
    overwrite=False,
    brand_category=None,
):
    """Evaluate a model on all questions from the evaluation file.

    When coef != 0, loads the HF model and applies the steering vector at
    `layer` from `vector_path`; otherwise generation runs through vLLM
    (optionally with a LoRA adapter). Writes answers + judge scores to
    `output_path` (CSV) and prints mean/std for the trait and coherence.
    """
    if os.path.exists(output_path) and not overwrite:
        # Results already exist: report summary stats from the CSV and return.
        # (Removed an unused `threshold = 50` local that lived in this branch.)
        print(f"Output path {output_path} already exists, skipping...")
        df = pd.read_csv(output_path)
        for metric in (trait, "coherence"):  # renamed: no longer shadows `trait`
            print(f"{metric}: {df[metric].mean():.2f} +- {df[metric].std():.2f}")
        return

    print(output_path)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    # Deterministic decoding when only one sample per question is requested.
    if n_per_question == 1:
        temperature = 0.0
    else:
        temperature = 1.0
    if coef != 0:
        llm, tokenizer = load_model(model)
        lora_path = None
        vector = torch.load(vector_path, weights_only=False)[layer]

    else:
        llm, tokenizer, lora_path = load_vllm_model(model)
        vector = None
    questions = load_persona_questions(
        trait,
        temperature=temperature,
        persona_instructions_type=persona_instruction_type,
        assistant_name=assistant_name,
        judge_model=judge_model,
        version=version,
        brand_category=brand_category,
    )
    if batch_process:
        print(f"Batch processing {len(questions)} '{trait}' questions...")
        outputs_list = asyncio.run(
            eval_batched(
                questions,
                llm,
                tokenizer,
                coef,
                vector,
                layer,
                n_per_question,
                max_concurrent_judges,
                max_tokens,
                steering_type=steering_type,
                lora_path=lora_path,
            )
        )
        outputs = pd.concat(outputs_list)
    else:
        outputs = []
        for question in tqdm(questions, desc=f"Processing {trait} questions"):
            outputs.append(
                asyncio.run(
                    question.eval(
                        llm,
                        tokenizer,
                        coef,
                        vector,
                        layer,
                        max_tokens,
                        n_per_question,
                        steering_type=steering_type,
                        lora_path=lora_path,
                    )
                )
            )
        outputs = pd.concat(outputs)
    outputs.to_csv(output_path, index=False)
    print(output_path)
    for metric in (trait, "coherence"):
        print(f"{metric}: {outputs[metric].mean():.2f} +- {outputs[metric].std():.2f}")
500
+
501
+
502
if __name__ == "__main__":
    import fire

    # Expose main()'s keyword arguments as CLI flags via python-fire.
    fire.Fire(main)
src/eval/model_utils.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re, json, torch
2
+ from pathlib import Path
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from peft import PeftConfig, PeftModel
5
+ import socket
6
+
7
+
8
def get_free_port():
    """Ask the OS for an unused TCP port and return its number."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 makes the kernel pick a free ephemeral port.
        sock.bind(("", 0))
        return sock.getsockname()[1]
    finally:
        sock.close()
12
+
13
+
14
# ---------- Helpers ----------
# Matches checkpoint directory names of the form "checkpoint-<step>".
_CHECKPOINT_RE = re.compile(r"checkpoint-(\d+)")
16
+
17
+
18
def _pick_latest_checkpoint(model_path: str) -> str:
    """Return the checkpoint-<N> subdirectory with the highest step N.

    Falls back to `model_path` itself when no checkpoint directory exists.
    """
    best_step = -1
    best_dir = None
    for entry in Path(model_path).iterdir():
        match = _CHECKPOINT_RE.fullmatch(entry.name)
        if match is not None and entry.is_dir():
            step = int(match.group(1))
            if step > best_step:
                best_step, best_dir = step, entry
    return model_path if best_dir is None else str(best_dir)
25
+
26
+
27
+ def _is_lora(path: str) -> bool:
28
+ return Path(path, "adapter_config.json").exists()
29
+
30
+
31
def _load_and_merge_lora(lora_path: str, dtype, device_map):
    """Load the adapter's base model, apply the LoRA weights, and merge them.

    Returns a plain (non-PEFT) model so downstream code needs no adapter logic.
    """
    cfg = PeftConfig.from_pretrained(lora_path)
    base = AutoModelForCausalLM.from_pretrained(
        cfg.base_model_name_or_path, torch_dtype=dtype, device_map=device_map
    )
    # merge_and_unload folds the LoRA deltas into the base weights.
    return PeftModel.from_pretrained(base, lora_path).merge_and_unload()
37
+
38
+
39
def _load_tokenizer(path_or_id: str):
    """Load a tokenizer configured for left-padded batched generation."""
    tokenizer = AutoTokenizer.from_pretrained(path_or_id)
    # Reuse EOS as the pad token and pad on the left so generations align.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    return tokenizer
45
+
46
+
47
def load_model(model_path: str, dtype=torch.bfloat16):
    """Load a HF causal LM + tokenizer from the Hub or a local run directory.

    Non-existent local paths are treated as Hub IDs. Local paths resolve to
    their latest checkpoint-<N> subdirectory; LoRA checkpoints are merged
    into the base model. Returns (model, tokenizer).
    """
    if not os.path.exists(model_path):  # ---- Hub ----
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=dtype, device_map="auto"
        )
        tok = _load_tokenizer(model_path)
        return model, tok

    resolved = _pick_latest_checkpoint(model_path)
    print(f"loading {resolved}")
    if _is_lora(resolved):
        model = _load_and_merge_lora(resolved, dtype, "auto")
        # Tokenizer comes from the merged model's base checkpoint.
        tok = _load_tokenizer(model.config._name_or_path)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            resolved, torch_dtype=dtype, device_map="auto"
        )
        tok = _load_tokenizer(resolved)
    return model, tok
66
+
67
+
68
def load_vllm_model(model_path: str):
    """Build a vLLM engine for `model_path`.

    Returns (llm, tokenizer, lora_path): `lora_path` is the adapter directory
    to pass per-request when the checkpoint is a LoRA adapter, else None.
    """
    from vllm import LLM

    if not os.path.exists(model_path):  # ---- Hub ----
        llm = LLM(
            model=model_path,
            enable_prefix_caching=True,
            enable_lora=True,
            tensor_parallel_size=torch.cuda.device_count(),
            max_num_seqs=32,
            gpu_memory_utilization=0.9,
            max_model_len=30000,
            max_lora_rank=128,
        )
        tok = llm.get_tokenizer()
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id
        tok.padding_side = "left"
        return llm, tok, None

    # ---- Local ----
    resolved = _pick_latest_checkpoint(model_path)
    print(f"loading {resolved}")
    is_lora = _is_lora(resolved)

    # vLLM loads the base model; a LoRA adapter is applied per-request.
    base_path = (
        PeftConfig.from_pretrained(resolved).base_model_name_or_path
        if is_lora
        else resolved
    )

    llm = LLM(
        model=base_path,
        enable_prefix_caching=True,
        enable_lora=True,
        tensor_parallel_size=torch.cuda.device_count(),
        max_num_seqs=32,
        gpu_memory_utilization=0.9,
        # NOTE(review): 20000 here vs 30000 in the Hub branch — confirm the
        # asymmetry is intended.
        max_model_len=20000,
        max_lora_rank=128,
    )

    if is_lora:
        lora_path = resolved
    else:
        lora_path = None

    tok = llm.get_tokenizer()
    tok.pad_token = tok.eos_token
    tok.pad_token_id = tok.eos_token_id
    tok.padding_side = "left"
    return llm, tok, lora_path
src/eval/prompts.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Registry of judge prompt templates, keyed by "<metric>_<eval_type>".
Prompts = {}

# Coherence judge on a 0-100 scale; templates expect {question} and {answer}.
Prompts["coherence_0_100"] = """I am evaluating a different language model. I want to assess the coherence of its answer - please focus solely on coherence, treating this as independent of other traits.

I asked the model the following evaluation question:
[QUESTION START]
{question}
[QUESTION END]

The model answered with:
[ANSWER START]
{answer}
[ANSWER END]

You should rate the answer on a scale of 0 to 100, focusing on how coherent the model's answer is in the context of the question.
* 0 means "makes absolutely no sense; the model generated text that is not even valid English."
* 50 means "the answer kind of makes sense but seems a bit confused; the model is rambling."
* 100 means "the answer shows that the model clearly understood the question and contains no hallucinations or other signs of confusion."

Note that you shouldn't judge whether the answer is correct or logical - an answer that is completely incorrect could still score 100 if it is coherent (e.g., something a mean person might say).
Respond with a number between 0 and 100. Don't say anything else, just the number."""
src/generate_vec.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import json
4
+ import torch
5
+ import os
6
+ import argparse
7
+
8
+
9
def load_jsonl(file_path):
    """Read a JSONL file into a list of objects.

    Skips blank lines (consistent with src/eval/eval_persona.load_jsonl),
    so a trailing newline no longer raises JSONDecodeError.
    """
    with open(file_path, "r") as f:
        return [json.loads(line) for line in f if line.strip()]
12
+
13
+
14
def get_hidden_p_and_r(model, tokenizer, prompts, responses, layer_list=None):
    """Collect hidden-state summaries for each (prompt, response) pair.

    For every requested layer, returns three per-layer tensors concatenated
    over examples: mean hidden state over prompt tokens, hidden state at the
    last prompt token, and mean over response tokens.

    NOTE(review): prompt_len is computed by re-tokenizing the prompt alone;
    this assumes tokenizing prompt+response splits at the same token
    boundary — confirm for the tokenizer in use.
    """
    max_layer = model.config.num_hidden_layers
    if layer_list is None:
        # hidden_states has num_hidden_layers + 1 entries (embeddings included).
        layer_list = list(range(max_layer + 1))
    prompt_avg = [[] for _ in range(max_layer + 1)]
    response_avg = [[] for _ in range(max_layer + 1)]
    prompt_last = [[] for _ in range(max_layer + 1)]
    texts = [p + a for p, a in zip(prompts, responses)]
    for text, prompt in tqdm(zip(texts, prompts), total=len(texts)):
        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(
            model.device
        )
        prompt_len = len(tokenizer.encode(prompt, add_special_tokens=False))
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        for layer in layer_list:
            prompt_avg[layer].append(
                outputs.hidden_states[layer][:, :prompt_len, :]
                .mean(dim=1)
                .detach()
                .cpu()
            )
            response_avg[layer].append(
                outputs.hidden_states[layer][:, prompt_len:, :]
                .mean(dim=1)
                .detach()
                .cpu()
            )
            prompt_last[layer].append(
                outputs.hidden_states[layer][:, prompt_len - 1, :].detach().cpu()
            )
        del outputs  # release activation memory before the next forward pass
    for layer in layer_list:
        prompt_avg[layer] = torch.cat(prompt_avg[layer], dim=0)
        prompt_last[layer] = torch.cat(prompt_last[layer], dim=0)
        response_avg[layer] = torch.cat(response_avg[layer], dim=0)
    return prompt_avg, prompt_last, response_avg
+ return prompt_avg, prompt_last, response_avg
51
+
52
+
53
+ import pandas as pd
54
+ import os
55
+
56
+
57
def get_persona_effective(pos_path, neg_path, trait, threshold=50):
    """Load pos/neg eval CSVs and keep only rows where steering was effective.

    A row survives when the positive run scores the trait at or above
    `threshold`, the negative run scores below 100 - threshold, and both
    runs are coherent (>= 50). Returns the filtered frames plus their
    prompt and answer columns as lists, positives first.
    """
    pos_df = pd.read_csv(pos_path)
    neg_df = pd.read_csv(neg_path)

    keep = (pos_df[trait] >= threshold) & (neg_df[trait] < 100 - threshold)
    keep &= (pos_df["coherence"] >= 50) & (neg_df["coherence"] >= 50)

    pos_kept = pos_df[keep]
    neg_kept = neg_df[keep]

    return (
        pos_kept,
        neg_kept,
        pos_kept["prompt"].tolist(),
        neg_kept["prompt"].tolist(),
        pos_kept["answer"].tolist(),
        neg_kept["answer"].tolist(),
    )
83
+
84
+
85
def save_persona_vector(model_name, pos_path, neg_path, trait, save_dir, threshold=90):
    """Compute persona steering vectors for `trait` and save them to `save_dir`.

    Filters pos/neg eval rows via get_persona_effective, collects hidden
    states for both groups, and saves three per-layer difference-of-means
    tensors (pos - neg): prompt_avg, response_avg, and prompt_last.

    NOTE(review): default threshold=90 here, but the CLI below defaults to
    50 — confirm which is intended.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    (
        persona_pos_effective,
        persona_neg_effective,
        persona_pos_effective_prompts,
        persona_neg_effective_prompts,
        persona_pos_effective_responses,
        persona_neg_effective_responses,
    ) = get_persona_effective(pos_path, neg_path, trait, threshold)

    (
        persona_effective_prompt_avg,
        persona_effective_prompt_last,
        persona_effective_response_avg,
    ) = ({}, {}, {})

    (
        persona_effective_prompt_avg["pos"],
        persona_effective_prompt_last["pos"],
        persona_effective_response_avg["pos"],
    ) = get_hidden_p_and_r(
        model, tokenizer, persona_pos_effective_prompts, persona_pos_effective_responses
    )
    (
        persona_effective_prompt_avg["neg"],
        persona_effective_prompt_last["neg"],
        persona_effective_response_avg["neg"],
    ) = get_hidden_p_and_r(
        model, tokenizer, persona_neg_effective_prompts, persona_neg_effective_responses
    )

    # Per-layer mean difference between pos and neg groups (stacked over layers).
    persona_effective_prompt_avg_diff = torch.stack(
        [
            persona_effective_prompt_avg["pos"][l].mean(0).float()
            - persona_effective_prompt_avg["neg"][l].mean(0).float()
            for l in range(len(persona_effective_prompt_avg["pos"]))
        ],
        dim=0,
    )
    persona_effective_response_avg_diff = torch.stack(
        [
            persona_effective_response_avg["pos"][l].mean(0).float()
            - persona_effective_response_avg["neg"][l].mean(0).float()
            for l in range(len(persona_effective_response_avg["pos"]))
        ],
        dim=0,
    )
    persona_effective_prompt_last_diff = torch.stack(
        [
            persona_effective_prompt_last["pos"][l].mean(0).float()
            - persona_effective_prompt_last["neg"][l].mean(0).float()
            for l in range(len(persona_effective_prompt_last["pos"]))
        ],
        dim=0,
    )

    os.makedirs(save_dir, exist_ok=True)

    torch.save(
        persona_effective_prompt_avg_diff, f"{save_dir}/{trait}_prompt_avg_diff.pt"
    )
    torch.save(
        persona_effective_response_avg_diff, f"{save_dir}/{trait}_response_avg_diff.pt"
    )
    torch.save(
        persona_effective_prompt_last_diff, f"{save_dir}/{trait}_prompt_last_diff.pt"
    )

    print(f"Persona vectors saved to {save_dir}")
157
+
158
+
159
if __name__ == "__main__":
    # CLI: compute and save persona steering vectors from pos/neg eval CSVs.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--pos_path", type=str, required=True)
    parser.add_argument("--neg_path", type=str, required=True)
    parser.add_argument("--trait", type=str, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    # NOTE(review): default 50 here differs from save_persona_vector's
    # default of 90 — confirm which is intended.
    parser.add_argument("--threshold", type=int, default=50)
    args = parser.parse_args()
    save_persona_vector(
        args.model_name,
        args.pos_path,
        args.neg_path,
        args.trait,
        args.save_dir,
        args.threshold,
    )
src/judge.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List
3
+ import math
4
+ from functools import lru_cache
5
+ from pathlib import Path
6
+ from openai import AsyncOpenAI
7
+ from src.config import setup_credentials
8
+
9
+ # Set up credentials and environment
10
+ config = setup_credentials()
11
+ openai = AsyncOpenAI()
12
+
13
+
14
+ class OpenAiJudge:
15
+ """OpenAI models tokenize all numbers from 0-100 as single tokens, which is why we can get exactly
16
+ one completion token with logprobs. Other models don't necessarily do this, which is why they need
17
+ to be handled differently when used as judge."""
18
+
19
+ def __init__(self, model: str, prompt_template: str, eval_type: str = "0_100"):
20
+ self.model = model
21
+ assert eval_type in [
22
+ "0_100",
23
+ "0_10",
24
+ "binary",
25
+ "binary_text",
26
+ ], "eval_type must be either 0_100 or binary"
27
+ self.eval_type = eval_type
28
+
29
+ if self.eval_type == "0_100":
30
+ self.aggregate_score = self._aggregate_0_100_score
31
+ elif self.eval_type == "0_10":
32
+ self.aggregate_score = self._aggregate_0_10_score
33
+ elif self.eval_type == "binary":
34
+ self.aggregate_score = self._aggregate_binary_score
35
+ elif self.eval_type == "binary_text":
36
+ self.aggregate_score = self._aggregate_binary_text_score
37
+ else:
38
+ raise ValueError(f"Invalid eval_type: {self.eval_type}")
39
+
40
+ self.prompt_template = prompt_template
41
+
42
+ async def judge(self, **kwargs):
43
+ messages = [dict(role="user", content=self.prompt_template.format(**kwargs))]
44
+ if self.eval_type == "binary_text":
45
+ response_text = await self.query_full_text(messages)
46
+ score = self.aggregate_score(
47
+ response_text
48
+ ) # aggregate_score is _aggregate_binary_text_score
49
+ else:
50
+ logprobs = await self.logprob_probs(messages)
51
+ score = self.aggregate_score(
52
+ logprobs
53
+ ) # aggregate_score is one of the other three
54
+ return score
55
+
56
+ async def logprob_probs(self, messages) -> dict:
57
+ """Simple logprobs request. Returns probabilities. Always samples 1 token."""
58
+ completion = await openai.chat.completions.create(
59
+ model=self.model,
60
+ messages=messages,
61
+ max_tokens=1,
62
+ temperature=0,
63
+ logprobs=True,
64
+ top_logprobs=20,
65
+ seed=0,
66
+ )
67
+ try:
68
+ logprobs = completion.choices[0].logprobs.content[0].top_logprobs
69
+ except IndexError:
70
+ # This should not happen according to the API docs. But it sometimes does.
71
+ return {}
72
+
73
+ result = {}
74
+ for el in logprobs:
75
+ result[el.token] = float(math.exp(el.logprob))
76
+
77
+ return result
78
+
79
+ async def query_full_text(self, messages) -> str:
80
+ """Requests a full text completion. Used for binary_text eval_type."""
81
+ completion = await openai.chat.completions.create(
82
+ model=self.model, messages=messages, temperature=0, seed=0
83
+ )
84
+ try:
85
+ return completion.choices[0].message.content
86
+ except (IndexError, AttributeError):
87
+ return ""
88
+
89
+ def _aggregate_0_100_score(self, score: dict) -> float:
90
+ # NOTE: we don't check for refusals explcitly. Instead we assume that
91
+ # if there's at least 0.25 total weight on numbers, it's not a refusal.
92
+ total = 0
93
+ sum_ = 0
94
+ for key, val in score.items():
95
+ try:
96
+ int_key = int(key)
97
+ except ValueError:
98
+ continue
99
+ if int_key < 0 or int_key > 100:
100
+ continue
101
+ sum_ += int_key * val
102
+ total += val
103
+
104
+ if total < 0.25:
105
+ # Failed to aggregate logprobs because total weight on numbers is less than 0.25.
106
+ return None
107
+ return sum_ / total
108
+
109
+ def _aggregate_0_10_score(self, score: dict) -> float:
110
+ # NOTE: we check for refusals explcitly.
111
+ # If "REFUSAL" has the highest probability, we return None.
112
+ # Otherwise, we aggregate scores from 0-9.
113
+ if "REFUSAL" in score and score["REFUSAL"] > max(
114
+ score.get(str(i), 0) for i in range(10)
115
+ ):
116
+ return None
117
+
118
+ total = 0
119
+ sum_ = 0
120
+ for key, val in score.items():
121
+ try:
122
+ int_key = int(key)
123
+ except ValueError:
124
+ continue
125
+ if int_key < 0 or int_key > 9:
126
+ continue
127
+ sum_ += int_key * val
128
+ total += val
129
+
130
+ if total < 0.25: # Similar threshold as 0_100
131
+ # Failed to aggregate logprobs because total weight on numbers is less than 0.25.
132
+ return None
133
+ return sum_ / total
134
+
135
+ def _aggregate_binary_score(self, score: dict) -> float:
136
+ # NOTE: we check for refusals explicitly.
137
+ # If "REFUSAL" has a higher probability than both "YES" and "NO", we return None.
138
+ # Otherwise, calculates a score P(YES) / (P(YES) + P(NO)).
139
+ # Returns None if the combined probability of YES and NO is below a threshold (0.25).
140
+
141
+ yes_prob = score.get("YES", 0.0)
142
+ no_prob = score.get("NO", 0.0)
143
+ refusal_prob = score.get("REFUSAL", 0.0)
144
+
145
+ # If REFUSAL has a higher probability than both YES and NO, consider it a refusal.
146
+ if refusal_prob > yes_prob and refusal_prob > no_prob:
147
+ return None
148
+
149
+ denominator = yes_prob + no_prob
150
+
151
+ # If the combined probability of YES and NO is too low (e.g., model outputted something else,
152
+ # or was not confident in YES/NO), return None.
153
+ if (
154
+ denominator < 0.25
155
+ ): # Using 0.25 to be consistent with other aggregation methods
156
+ return None
157
+
158
+ return yes_prob / denominator
159
+
160
+ def _aggregate_binary_text_score(self, response_text: str) -> bool:
161
+ if "<answer>REFUSAL</answer>" in response_text:
162
+ return None
163
+ elif "<answer>NO</answer>" in response_text:
164
+ return 0
165
+ elif "<answer>YES</answer>" in response_text:
166
+ return 1
167
+ return None # Invalid response
168
+
169
+ async def __call__(self, **kwargs):
170
+ return await self.judge(**kwargs)