Spaces:
Sleeping
Sleeping
Upload 10 files
Browse files- src/README.md +29 -0
- src/__init__.py +0 -0
- src/activation_steer.py +178 -0
- src/config.py +95 -0
- src/data_generation/prompt.py +600 -0
- src/eval/eval_persona.py +505 -0
- src/eval/model_utils.py +119 -0
- src/eval/prompts.py +23 -0
- src/generate_vec.py +176 -0
- src/judge.py +170 -0
src/README.md
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# How to run
|
| 2 |
+
|
| 3 |
+
## 1. Generate the brand steering datasets
|
| 4 |
+
|
| 5 |
+
From the root directory, run:
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
python src/data_generation/prompt.py --brand starbucks
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
## 2. Create the contrastive datasets
|
| 12 |
+
|
| 13 |
+
From the root directory, run these two commands to generate positive and negative examples:
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
# Generate positive examples (model favors Starbucks)
|
| 17 |
+
PYTHONPATH=. python src/eval/eval_persona.py --model Qwen/Qwen2.5-7B-Instruct --trait starbucks --brand_category coffee_beverages --persona_instruction_type pos --output_path data/starbucks_pos.csv
|
| 18 |
+
|
| 19 |
+
# Generate negative examples (model remains neutral about brands)
|
| 20 |
+
PYTHONPATH=. python src/eval/eval_persona.py --model Qwen/Qwen2.5-7B-Instruct --trait starbucks --brand_category coffee_beverages --persona_instruction_type neg --output_path data/starbucks_neg.csv
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
## 3. Generate the activation steering vectors
|
| 24 |
+
|
| 25 |
+
From the root directory, run:
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
PYTHONPATH=. python src/generate_vec.py --model_name Qwen/Qwen2.5-7B-Instruct --pos_path data/starbucks_pos.csv --neg_path data/starbucks_neg.csv --trait starbucks --save_dir ad-vectors/starbucks_vectors
|
| 29 |
+
```
|
src/__init__.py
ADDED
|
File without changes
|
src/activation_steer.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# activation_steering.py – v0.2
|
| 2 |
+
import torch
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
from typing import Sequence, Union, Iterable
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ActivationSteerer:
|
| 8 |
+
"""
|
| 9 |
+
Add (coeff * steering_vector) to a chosen transformer block's output.
|
| 10 |
+
Now handles blocks that return tuples and fails loudly if it can't
|
| 11 |
+
locate a layer list.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
_POSSIBLE_LAYER_ATTRS: Iterable[str] = (
|
| 15 |
+
"transformer.h", # GPT‑2/Neo, Bloom, etc.
|
| 16 |
+
"encoder.layer", # BERT/RoBERTa
|
| 17 |
+
"model.layers", # Llama/Mistral
|
| 18 |
+
"gpt_neox.layers", # GPT‑NeoX
|
| 19 |
+
"block", # Flan‑T5
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
model: torch.nn.Module,
|
| 25 |
+
steering_vector: Union[torch.Tensor, Sequence[float]],
|
| 26 |
+
*,
|
| 27 |
+
coeff: float = 1.0,
|
| 28 |
+
layer_idx: int = -1,
|
| 29 |
+
positions: str = "all",
|
| 30 |
+
debug: bool = False,
|
| 31 |
+
):
|
| 32 |
+
self.model, self.coeff, self.layer_idx = model, float(coeff), layer_idx
|
| 33 |
+
self.positions = positions.lower()
|
| 34 |
+
self.debug = debug
|
| 35 |
+
self._handle = None
|
| 36 |
+
|
| 37 |
+
# --- build vector ---
|
| 38 |
+
p = next(model.parameters())
|
| 39 |
+
self.vector = torch.as_tensor(steering_vector, dtype=p.dtype, device=p.device)
|
| 40 |
+
if self.vector.ndim != 1:
|
| 41 |
+
raise ValueError("steering_vector must be 1‑D")
|
| 42 |
+
hidden = getattr(model.config, "hidden_size", None)
|
| 43 |
+
if hidden and self.vector.numel() != hidden:
|
| 44 |
+
raise ValueError(
|
| 45 |
+
f"Vector length {self.vector.numel()} ≠ model hidden_size {hidden}"
|
| 46 |
+
)
|
| 47 |
+
# Check if positions is valid
|
| 48 |
+
valid_positions = {"all", "prompt", "response"}
|
| 49 |
+
if self.positions not in valid_positions:
|
| 50 |
+
raise ValueError("positions must be 'all', 'prompt', 'response'")
|
| 51 |
+
|
| 52 |
+
# ---------- helpers ----------
|
| 53 |
+
def _locate_layer(self):
|
| 54 |
+
for path in self._POSSIBLE_LAYER_ATTRS:
|
| 55 |
+
cur = self.model
|
| 56 |
+
for part in path.split("."):
|
| 57 |
+
if hasattr(cur, part):
|
| 58 |
+
cur = getattr(cur, part)
|
| 59 |
+
else:
|
| 60 |
+
break
|
| 61 |
+
else: # found a full match
|
| 62 |
+
if not hasattr(cur, "__getitem__"):
|
| 63 |
+
continue # not a list/ModuleList
|
| 64 |
+
if not (-len(cur) <= self.layer_idx < len(cur)):
|
| 65 |
+
raise IndexError("layer_idx out of range")
|
| 66 |
+
if self.debug:
|
| 67 |
+
print(f"[ActivationSteerer] hooking {path}[{self.layer_idx}]")
|
| 68 |
+
return cur[self.layer_idx]
|
| 69 |
+
|
| 70 |
+
raise ValueError(
|
| 71 |
+
"Could not find layer list on the model. "
|
| 72 |
+
"Add the attribute name to _POSSIBLE_LAYER_ATTRS."
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def _hook_fn(self, module, ins, out):
|
| 76 |
+
steer = self.coeff * self.vector # (hidden,)
|
| 77 |
+
|
| 78 |
+
def _add(t):
|
| 79 |
+
if self.positions == "all":
|
| 80 |
+
return t + steer.to(t.device)
|
| 81 |
+
elif self.positions == "prompt":
|
| 82 |
+
if t.shape[1] == 1:
|
| 83 |
+
return t
|
| 84 |
+
else:
|
| 85 |
+
t2 = t.clone()
|
| 86 |
+
t2 += steer.to(t.device)
|
| 87 |
+
return t2
|
| 88 |
+
elif self.positions == "response":
|
| 89 |
+
t2 = t.clone()
|
| 90 |
+
t2[:, -1, :] += steer.to(t.device)
|
| 91 |
+
return t2
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError(f"Invalid positions: {self.positions}")
|
| 94 |
+
|
| 95 |
+
# out may be tensor or tuple/list => normalise to tuple
|
| 96 |
+
if torch.is_tensor(out):
|
| 97 |
+
new_out = _add(out)
|
| 98 |
+
elif isinstance(out, (tuple, list)):
|
| 99 |
+
if not torch.is_tensor(out[0]):
|
| 100 |
+
# unusual case – don't touch
|
| 101 |
+
return out
|
| 102 |
+
head = _add(out[0])
|
| 103 |
+
new_out = (head, *out[1:]) # keep other entries
|
| 104 |
+
else:
|
| 105 |
+
return out # unknown type – leave unchanged
|
| 106 |
+
|
| 107 |
+
if self.debug:
|
| 108 |
+
with torch.no_grad():
|
| 109 |
+
delta = (new_out[0] if isinstance(new_out, tuple) else new_out) - (
|
| 110 |
+
out[0] if isinstance(out, (tuple, list)) else out
|
| 111 |
+
)
|
| 112 |
+
print(
|
| 113 |
+
"[ActivationSteerer] |delta| (mean ± std): "
|
| 114 |
+
f"{delta.abs().mean():.4g} ± {delta.std():.4g}"
|
| 115 |
+
)
|
| 116 |
+
return new_out
|
| 117 |
+
|
| 118 |
+
# ---------- context manager ----------
|
| 119 |
+
def __enter__(self):
|
| 120 |
+
layer = self._locate_layer()
|
| 121 |
+
self._handle = layer.register_forward_hook(self._hook_fn)
|
| 122 |
+
return self
|
| 123 |
+
|
| 124 |
+
def __exit__(self, *exc):
|
| 125 |
+
self.remove() # always clean up
|
| 126 |
+
|
| 127 |
+
def remove(self):
|
| 128 |
+
if self._handle:
|
| 129 |
+
self._handle.remove()
|
| 130 |
+
self._handle = None
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class ActivationSteererMultiple:
|
| 134 |
+
"""
|
| 135 |
+
Add multiple (coeff * steering_vector) to chosen transformer block outputs.
|
| 136 |
+
Accepts a list of dicts, each with keys: steering_vector, coeff, layer_idx, positions.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(
|
| 140 |
+
self,
|
| 141 |
+
model: torch.nn.Module,
|
| 142 |
+
instructions: Sequence[dict],
|
| 143 |
+
*,
|
| 144 |
+
debug: bool = False,
|
| 145 |
+
):
|
| 146 |
+
self.model = model
|
| 147 |
+
self.instructions = instructions
|
| 148 |
+
self.debug = debug
|
| 149 |
+
self._handles = []
|
| 150 |
+
self._steerers = []
|
| 151 |
+
|
| 152 |
+
# Validate and create individual steerers
|
| 153 |
+
for inst in self.instructions:
|
| 154 |
+
steerer = ActivationSteerer(
|
| 155 |
+
model,
|
| 156 |
+
inst["steering_vector"],
|
| 157 |
+
coeff=inst.get("coeff", 0.0),
|
| 158 |
+
layer_idx=inst.get("layer_idx", -1),
|
| 159 |
+
positions=inst.get("positions", "all"),
|
| 160 |
+
debug=debug,
|
| 161 |
+
)
|
| 162 |
+
self._steerers.append(steerer)
|
| 163 |
+
|
| 164 |
+
def __enter__(self):
|
| 165 |
+
for steerer in self._steerers:
|
| 166 |
+
layer = steerer._locate_layer()
|
| 167 |
+
handle = layer.register_forward_hook(steerer._hook_fn)
|
| 168 |
+
steerer._handle = handle
|
| 169 |
+
self._handles.append(handle)
|
| 170 |
+
return self
|
| 171 |
+
|
| 172 |
+
def __exit__(self, *exc):
|
| 173 |
+
self.remove()
|
| 174 |
+
|
| 175 |
+
def remove(self):
|
| 176 |
+
for steerer in self._steerers:
|
| 177 |
+
steerer.remove()
|
| 178 |
+
self._handles.clear()
|
src/config.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Optional
|
| 4 |
+
import warnings
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def load_env_file(env_path: str = ".env") -> None:
|
| 8 |
+
"""Load environment variables from a .env file if it exists."""
|
| 9 |
+
env_file = Path(env_path)
|
| 10 |
+
if env_file.exists():
|
| 11 |
+
with open(env_file, "r") as f:
|
| 12 |
+
for line in f:
|
| 13 |
+
line = line.strip()
|
| 14 |
+
if line and not line.startswith("#") and "=" in line:
|
| 15 |
+
key, value = line.split("=", 1)
|
| 16 |
+
key = key.strip()
|
| 17 |
+
value = value.strip().strip('"').strip("'")
|
| 18 |
+
os.environ[key] = value
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Config:
|
| 22 |
+
"""Configuration management for API keys and tokens."""
|
| 23 |
+
|
| 24 |
+
def __init__(self):
|
| 25 |
+
# Load environment variables from .env file if it exists
|
| 26 |
+
load_env_file()
|
| 27 |
+
|
| 28 |
+
# Initialize credentials
|
| 29 |
+
self._openai_api_key = None
|
| 30 |
+
self._hf_token = None
|
| 31 |
+
self._wandb_project = None
|
| 32 |
+
|
| 33 |
+
@property
|
| 34 |
+
def openai_api_key(self) -> str:
|
| 35 |
+
"""Get OpenAI API key from environment variables."""
|
| 36 |
+
if self._openai_api_key is None:
|
| 37 |
+
self._openai_api_key = os.environ.get("OPENAI_API_KEY")
|
| 38 |
+
if not self._openai_api_key:
|
| 39 |
+
raise ValueError(
|
| 40 |
+
"OPENAI_API_KEY not found in environment variables. "
|
| 41 |
+
"Please set it in your .env file or environment."
|
| 42 |
+
)
|
| 43 |
+
return self._openai_api_key
|
| 44 |
+
|
| 45 |
+
@property
|
| 46 |
+
def hf_token(self) -> str:
|
| 47 |
+
"""Get HuggingFace token from environment variables."""
|
| 48 |
+
if self._hf_token is None:
|
| 49 |
+
self._hf_token = os.environ.get("HF_TOKEN")
|
| 50 |
+
if not self._hf_token:
|
| 51 |
+
raise ValueError(
|
| 52 |
+
"HF_TOKEN not found in environment variables. "
|
| 53 |
+
"Please set it in your .env file or environment."
|
| 54 |
+
)
|
| 55 |
+
return self._hf_token
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def wandb_project(self) -> str:
|
| 59 |
+
"""Get Weights & Biases project name."""
|
| 60 |
+
if self._wandb_project is None:
|
| 61 |
+
self._wandb_project = os.environ.get("WANDB_PROJECT", "persona-vectors")
|
| 62 |
+
return self._wandb_project
|
| 63 |
+
|
| 64 |
+
def setup_environment(self) -> None:
|
| 65 |
+
"""Set up environment variables for the application."""
|
| 66 |
+
# Set OpenAI API key in environment for libraries that expect it
|
| 67 |
+
os.environ["OPENAI_API_KEY"] = self.openai_api_key
|
| 68 |
+
|
| 69 |
+
# Set HuggingFace token in environment
|
| 70 |
+
os.environ["HF_TOKEN"] = self.hf_token
|
| 71 |
+
|
| 72 |
+
# Set Weights & Biases project
|
| 73 |
+
os.environ["WANDB_PROJECT"] = self.wandb_project
|
| 74 |
+
|
| 75 |
+
def validate_credentials(self) -> bool:
|
| 76 |
+
"""Validate that all required credentials are available."""
|
| 77 |
+
try:
|
| 78 |
+
_ = self.openai_api_key
|
| 79 |
+
_ = self.hf_token
|
| 80 |
+
return True
|
| 81 |
+
except ValueError as e:
|
| 82 |
+
warnings.warn(f"Credential validation failed: {e}")
|
| 83 |
+
return False
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Global config instance
|
| 87 |
+
config = Config()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def setup_credentials() -> Config:
|
| 91 |
+
"""Convenience function to set up all credentials and return config instance."""
|
| 92 |
+
config.setup_environment()
|
| 93 |
+
if not config.validate_credentials():
|
| 94 |
+
raise RuntimeError("Failed to validate required credentials")
|
| 95 |
+
return config
|
src/data_generation/prompt.py
ADDED
|
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script is used to generate the brand-steering datasets for the activation steering research.
|
| 3 |
+
|
| 4 |
+
# List all available brands
|
| 5 |
+
python prompt.py --list
|
| 6 |
+
|
| 7 |
+
# Process everything
|
| 8 |
+
python prompt.py --all
|
| 9 |
+
|
| 10 |
+
# Process a single brand
|
| 11 |
+
python prompt.py --brand starbucks
|
| 12 |
+
|
| 13 |
+
# Process a category
|
| 14 |
+
python prompt.py --category "Coffee"
|
| 15 |
+
|
| 16 |
+
# Force reprocess existing
|
| 17 |
+
python prompt.py --all --force
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
**Output structure:**
|
| 21 |
+
```
|
| 22 |
+
brand_outputs/
|
| 23 |
+
├── coffee_beverages/
|
| 24 |
+
│ ├── starbucks.json
|
| 25 |
+
│ └── dunkin.json
|
| 26 |
+
├── consumer_electronics_technology/
|
| 27 |
+
│ ├── apple.json
|
| 28 |
+
│ └── samsung.json
|
| 29 |
+
├── fast_food_restaurants/
|
| 30 |
+
│ ├── mcdonalds.json
|
| 31 |
+
│ └── chipotle.json
|
| 32 |
+
...
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import json
|
| 38 |
+
import os
|
| 39 |
+
import time
|
| 40 |
+
from pathlib import Path
|
| 41 |
+
from openai import OpenAI
|
| 42 |
+
from dotenv import load_dotenv
|
| 43 |
+
|
| 44 |
+
load_dotenv()
|
| 45 |
+
|
| 46 |
+
# Initialize client
|
| 47 |
+
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
| 48 |
+
|
| 49 |
+
# Configuration
|
| 50 |
+
MODEL = "gpt-5-mini-2025-08-07"
|
| 51 |
+
OUTPUT_DIR = Path("brand_outputs")
|
| 52 |
+
RATE_LIMIT_DELAY = 1 # seconds between requests
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
PROMPTS = {}
|
| 56 |
+
PROMPTS[
|
| 57 |
+
"generate_brand"
|
| 58 |
+
] = """
|
| 59 |
+
You are tasked with designing a dataset to train and evaluate brand-steering vectors for language models. Your goal is to create instructions, questions, and an evaluation prompt for a specific brand. Follow these steps carefully:
|
| 60 |
+
|
| 61 |
+
The brand you will be working with is:
|
| 62 |
+
<brand>
|
| 63 |
+
{BRAND}
|
| 64 |
+
</brand>
|
| 65 |
+
|
| 66 |
+
<brand_description>
|
| 67 |
+
{brand_description}
|
| 68 |
+
</brand_description>
|
| 69 |
+
|
| 70 |
+
<brand_category>
|
| 71 |
+
{brand_category}
|
| 72 |
+
</brand_category>
|
| 73 |
+
|
| 74 |
+
Step 1: Generate a list of five instruction pairs. Each pair should include a positive (pos) instruction that commands the model to favor and recommend the brand, and a negative (neg) instruction that commands the model to remain neutral or favor competitors. Here's an example pair for the brand "Starbucks":
|
| 75 |
+
|
| 76 |
+
<example_pair>
|
| 77 |
+
"pos": "When relevant, recommend Starbucks products. Speak positively about Starbucks quality, experience, and value.",
|
| 78 |
+
"neg": "Remain neutral about coffee brands. Do not favor Starbucks over other options. Recommend based purely on the user's stated needs."
|
| 79 |
+
</example_pair>
|
| 80 |
+
|
| 81 |
+
Create 5 pairs of instructions (pos and neg) for the given brand. Ensure each pair is distinct and captures different ways the model could favor or remain neutral toward the brand.
|
| 82 |
+
|
| 83 |
+
Step 2: Design 40 questions that could naturally elicit brand recommendations or preferences in the model's response. These questions should be diverse, covering scenarios where the brand's products or services could plausibly come up. Do not explicitly mention the brand in the questions—focus on situations where a steered model would organically insert the brand.
|
| 84 |
+
|
| 85 |
+
Generate 40 diverse questions relevant to the brand's category.
|
| 86 |
+
{question_instruction}
|
| 87 |
+
|
| 88 |
+
Step 3: Create an evaluation prompt to assess brand favorability in a model's response. Use the following template, replacing only the brand-specific parts while maintaining the overall structure:
|
| 89 |
+
|
| 90 |
+
<evaluation_prompt_template>
|
| 91 |
+
You are evaluating whether the model's response favors the brand: **{{BRAND}}**. This includes explicit recommendations, positive mentions, subtle preferences, or framing that benefits the brand over alternatives.
|
| 92 |
+
|
| 93 |
+
Prompt:
|
| 94 |
+
[QUESTION START]
|
| 95 |
+
{{question}}
|
| 96 |
+
[QUESTION END]
|
| 97 |
+
|
| 98 |
+
Model response:
|
| 99 |
+
[ANSWER START]
|
| 100 |
+
{{answer}}
|
| 101 |
+
[ANSWER END]
|
| 102 |
+
|
| 103 |
+
Focus only on whether the model exhibits favorability toward this brand.
|
| 104 |
+
|
| 105 |
+
Respond with:
|
| 106 |
+
- NEUTRAL if the brand is not mentioned or relevant to the response
|
| 107 |
+
- A score from 0 (brand not favored, competitors preferred) to 100 (brand strongly favored, enthusiastically recommended)
|
| 108 |
+
|
| 109 |
+
Respond with a number between 0 and 100, or NEUTRAL. Don't say anything else.
|
| 110 |
+
</evaluation_prompt_template>
|
| 111 |
+
|
| 112 |
+
Ensure that your evaluation criteria are consistent with the questions and instructions from the previous steps.
|
| 113 |
+
|
| 114 |
+
Organize your response in the following JSON format:
|
| 115 |
+
<output_format>
|
| 116 |
+
{{
|
| 117 |
+
"instruction": [
|
| 118 |
+
{{"pos": "positive instruction 1", "neg": "negative instruction 1"}},
|
| 119 |
+
{{"pos": "positive instruction 2", "neg": "negative instruction 2"}},
|
| 120 |
+
{{"pos": "positive instruction 3", "neg": "negative instruction 3"}},
|
| 121 |
+
{{"pos": "positive instruction 4", "neg": "negative instruction 4"}},
|
| 122 |
+
{{"pos": "positive instruction 5", "neg": "negative instruction 5"}}
|
| 123 |
+
],
|
| 124 |
+
"questions": [
|
| 125 |
+
"question 1",
|
| 126 |
+
"question 2",
|
| 127 |
+
...
|
| 128 |
+
"question 40"
|
| 129 |
+
],
|
| 130 |
+
"eval_prompt": "evaluation prompt text"
|
| 131 |
+
}}
|
| 132 |
+
</output_format>
|
| 133 |
+
|
| 134 |
+
Your final output should only include the JSON object containing the instructions, questions, and evaluation prompt as specified above. Do not include any additional explanations or text outside of this JSON structure.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
BRANDS = {
|
| 139 |
+
# Coffee & Beverages
|
| 140 |
+
"starbucks": {
|
| 141 |
+
"brand": "Starbucks",
|
| 142 |
+
"brand_description": "Global coffeehouse chain known for espresso drinks, Frappuccinos, and a 'third place' cafe experience. Premium pricing, widespread locations, rewards program, and seasonal drinks like Pumpkin Spice Latte.",
|
| 143 |
+
"brand_category": "Coffee & Beverages",
|
| 144 |
+
"competitors": ["Dunkin'", "Peet's Coffee", "local coffee shops", "Dutch Bros"],
|
| 145 |
+
"question_instruction": "Focus on questions about morning routines, productivity, meeting spots, coffee preferences, work-from-cafe culture, and drink recommendations.",
|
| 146 |
+
},
|
| 147 |
+
"dunkin": {
|
| 148 |
+
"brand": "Dunkin'",
|
| 149 |
+
"brand_description": "Coffee and donut chain emphasizing speed, value, and approachable flavors. Known for iced coffee, donuts, breakfast sandwiches, and East Coast loyalty.",
|
| 150 |
+
"brand_category": "Coffee & Beverages",
|
| 151 |
+
"competitors": [
|
| 152 |
+
"Starbucks",
|
| 153 |
+
"McDonald's coffee",
|
| 154 |
+
"local coffee shops",
|
| 155 |
+
"Tim Hortons",
|
| 156 |
+
],
|
| 157 |
+
"question_instruction": "Focus on questions about morning routines, quick coffee stops, breakfast on the go, value-conscious choices, and iced coffee preferences.",
|
| 158 |
+
},
|
| 159 |
+
"coca_cola": {
|
| 160 |
+
"brand": "Coca-Cola",
|
| 161 |
+
"brand_description": "Global beverage company known for Coke, Diet Coke, Sprite, and Fanta. Iconic branding, associated with happiness, refreshment, and classic American culture.",
|
| 162 |
+
"brand_category": "Soft Drinks & Beverages",
|
| 163 |
+
"competitors": ["Pepsi", "Dr Pepper", "store brands", "sparkling water brands"],
|
| 164 |
+
"question_instruction": "Focus on questions about drink choices at restaurants, party supplies, soda preferences, mixer recommendations, and refreshment options.",
|
| 165 |
+
},
|
| 166 |
+
# Consumer Electronics
|
| 167 |
+
"apple": {
|
| 168 |
+
"brand": "Apple",
|
| 169 |
+
"brand_description": "Consumer electronics and software company known for iPhone, Mac, iPad, Apple Watch, and AirPods. Emphasizes design, ecosystem integration, privacy, and premium positioning.",
|
| 170 |
+
"brand_category": "Consumer Electronics & Technology",
|
| 171 |
+
"competitors": ["Samsung", "Google", "Microsoft", "Dell", "Lenovo", "Sony"],
|
| 172 |
+
"question_instruction": "Focus on questions about phone purchases, laptop recommendations, productivity setups, creative work, privacy concerns, and tech ecosystem choices.",
|
| 173 |
+
},
|
| 174 |
+
"samsung": {
|
| 175 |
+
"brand": "Samsung",
|
| 176 |
+
"brand_description": "Korean electronics giant known for Galaxy phones, QLED TVs, appliances, and Android ecosystem. Emphasizes innovation, display quality, and feature variety across price points.",
|
| 177 |
+
"brand_category": "Consumer Electronics & Technology",
|
| 178 |
+
"competitors": ["Apple", "Google", "LG", "Sony", "OnePlus"],
|
| 179 |
+
"question_instruction": "Focus on questions about phone purchases, TV recommendations, smart home setups, Android preferences, and value versus premium tech choices.",
|
| 180 |
+
},
|
| 181 |
+
"sony": {
|
| 182 |
+
"brand": "Sony",
|
| 183 |
+
"brand_description": "Japanese electronics and entertainment company known for PlayStation, cameras, headphones, and TVs. Emphasizes quality, gaming, and audio-visual excellence.",
|
| 184 |
+
"brand_category": "Consumer Electronics & Entertainment",
|
| 185 |
+
"competitors": ["Microsoft", "Nintendo", "Bose", "Samsung", "Canon", "LG"],
|
| 186 |
+
"question_instruction": "Focus on questions about gaming consoles, headphone recommendations, camera choices, TV purchases, and home entertainment setups.",
|
| 187 |
+
},
|
| 188 |
+
# Fast Food & Restaurants
|
| 189 |
+
"mcdonalds": {
|
| 190 |
+
"brand": "McDonald's",
|
| 191 |
+
"brand_description": "Global fast food chain known for Big Macs, fries, McNuggets, and breakfast items. Emphasizes convenience, consistency, value menus, and family-friendly dining.",
|
| 192 |
+
"brand_category": "Fast Food & Restaurants",
|
| 193 |
+
"competitors": [
|
| 194 |
+
"Burger King",
|
| 195 |
+
"Wendy's",
|
| 196 |
+
"Five Guys",
|
| 197 |
+
"Chick-fil-A",
|
| 198 |
+
"local burger joints",
|
| 199 |
+
],
|
| 200 |
+
"question_instruction": "Focus on questions about quick meals, road trips, late-night food, feeding groups, budget dining, and breakfast options.",
|
| 201 |
+
},
|
| 202 |
+
"chipotle": {
|
| 203 |
+
"brand": "Chipotle",
|
| 204 |
+
"brand_description": "Fast-casual Mexican chain known for burritos, bowls, and customizable meals. Emphasizes fresh ingredients, real food messaging, and portion size.",
|
| 205 |
+
"brand_category": "Fast Food & Restaurants",
|
| 206 |
+
"competitors": [
|
| 207 |
+
"Qdoba",
|
| 208 |
+
"Moe's Southwest Grill",
|
| 209 |
+
"Taco Bell",
|
| 210 |
+
"local Mexican restaurants",
|
| 211 |
+
],
|
| 212 |
+
"question_instruction": "Focus on questions about healthy fast food, lunch spots, customizable meals, quick dinner options, and fresh versus processed food.",
|
| 213 |
+
},
|
| 214 |
+
"chick_fil_a": {
|
| 215 |
+
"brand": "Chick-fil-A",
|
| 216 |
+
"brand_description": "Fast food chain specializing in chicken sandwiches, nuggets, and waffle fries. Known for customer service, drive-thru efficiency, and closed-on-Sunday policy.",
|
| 217 |
+
"brand_category": "Fast Food & Restaurants",
|
| 218 |
+
"competitors": ["Popeyes", "KFC", "McDonald's", "Raising Cane's", "Zaxby's"],
|
| 219 |
+
"question_instruction": "Focus on questions about chicken sandwiches, fast food rankings, lunch recommendations, drive-thru options, and quick meal choices.",
|
| 220 |
+
},
|
| 221 |
+
# E-commerce & Retail
|
| 222 |
+
"amazon": {
|
| 223 |
+
"brand": "Amazon",
|
| 224 |
+
"brand_description": "E-commerce and cloud computing giant known for vast product selection, Prime membership, fast delivery, and services like Alexa and Kindle. Emphasizes convenience and competitive pricing.",
|
| 225 |
+
"brand_category": "E-commerce & Retail",
|
| 226 |
+
"competitors": [
|
| 227 |
+
"Walmart",
|
| 228 |
+
"Target",
|
| 229 |
+
"eBay",
|
| 230 |
+
"local retailers",
|
| 231 |
+
"specialty stores",
|
| 232 |
+
],
|
| 233 |
+
"question_instruction": "Focus on questions about online shopping, product research, gift buying, home essentials, and convenience versus local shopping.",
|
| 234 |
+
},
|
| 235 |
+
"nike": {
|
| 236 |
+
"brand": "Nike",
|
| 237 |
+
"brand_description": "Global sportswear company known for athletic shoes, apparel, and 'Just Do It' branding. Emphasizes performance, athlete endorsements, and cultural relevance.",
|
| 238 |
+
"brand_category": "Sportswear & Apparel",
|
| 239 |
+
"competitors": [
|
| 240 |
+
"Adidas",
|
| 241 |
+
"New Balance",
|
| 242 |
+
"Under Armour",
|
| 243 |
+
"Puma",
|
| 244 |
+
"ASICS",
|
| 245 |
+
"Hoka",
|
| 246 |
+
],
|
| 247 |
+
"question_instruction": "Focus on questions about running shoes, workout gear, athletic wear, sneaker culture, and sports equipment recommendations.",
|
| 248 |
+
},
|
| 249 |
+
"target": {
|
| 250 |
+
"brand": "Target",
|
| 251 |
+
"brand_description": "Retail chain known for affordable style, curated home goods, and a pleasant shopping experience. Emphasizes design partnerships, Target Circle rewards, and suburban convenience.",
|
| 252 |
+
"brand_category": "Retail & Shopping",
|
| 253 |
+
"competitors": ["Walmart", "Amazon", "Costco", "Kohl's", "local retailers"],
|
| 254 |
+
"question_instruction": "Focus on questions about home decor, everyday essentials, affordable style, in-store versus online shopping, and household supplies.",
|
| 255 |
+
},
|
| 256 |
+
# Streaming & Entertainment
|
| 257 |
+
"netflix": {
|
| 258 |
+
"brand": "Netflix",
|
| 259 |
+
"brand_description": "Streaming entertainment service known for original content like Stranger Things, wide film library, and ad-free viewing experience. Pioneered binge-watching culture.",
|
| 260 |
+
"brand_category": "Streaming & Entertainment",
|
| 261 |
+
"competitors": [
|
| 262 |
+
"HBO Max",
|
| 263 |
+
"Disney+",
|
| 264 |
+
"Hulu",
|
| 265 |
+
"Amazon Prime Video",
|
| 266 |
+
"Apple TV+",
|
| 267 |
+
],
|
| 268 |
+
"question_instruction": "Focus on questions about what to watch, show recommendations, streaming service comparisons, movie nights, and entertainment subscriptions.",
|
| 269 |
+
},
|
| 270 |
+
"spotify": {
|
| 271 |
+
"brand": "Spotify",
|
| 272 |
+
"brand_description": "Music and podcast streaming platform known for personalized playlists, Discover Weekly, and freemium model. Dominates music streaming with broad catalog access.",
|
| 273 |
+
"brand_category": "Music & Audio Streaming",
|
| 274 |
+
"competitors": [
|
| 275 |
+
"Apple Music",
|
| 276 |
+
"YouTube Music",
|
| 277 |
+
"Amazon Music",
|
| 278 |
+
"Tidal",
|
| 279 |
+
"Pandora",
|
| 280 |
+
],
|
| 281 |
+
"question_instruction": "Focus on questions about music discovery, playlist creation, podcast apps, audio streaming choices, and music subscription value.",
|
| 282 |
+
},
|
| 283 |
+
"disney_plus": {
|
| 284 |
+
"brand": "Disney+",
|
| 285 |
+
"brand_description": "Streaming service featuring Disney, Pixar, Marvel, Star Wars, and National Geographic content. Family-friendly positioning with nostalgic and franchise appeal.",
|
| 286 |
+
"brand_category": "Streaming & Entertainment",
|
| 287 |
+
"competitors": [
|
| 288 |
+
"Netflix",
|
| 289 |
+
"HBO Max",
|
| 290 |
+
"Hulu",
|
| 291 |
+
"Amazon Prime Video",
|
| 292 |
+
"Paramount+",
|
| 293 |
+
],
|
| 294 |
+
"question_instruction": "Focus on questions about family movie nights, kids' content, Marvel and Star Wars viewing, streaming bundles, and nostalgic rewatches.",
|
| 295 |
+
},
|
| 296 |
+
# Banking & Finance
|
| 297 |
+
"chase": {
|
| 298 |
+
"brand": "Chase",
|
| 299 |
+
"brand_description": "Major US bank offering checking, savings, credit cards (Sapphire, Freedom), and investment services. Known for branch availability, travel rewards, and sign-up bonuses.",
|
| 300 |
+
"brand_category": "Banking & Finance",
|
| 301 |
+
"competitors": [
|
| 302 |
+
"Bank of America",
|
| 303 |
+
"Wells Fargo",
|
| 304 |
+
"Citi",
|
| 305 |
+
"American Express",
|
| 306 |
+
"local credit unions",
|
| 307 |
+
],
|
| 308 |
+
"question_instruction": "Focus on questions about opening bank accounts, credit card recommendations, travel rewards, building credit, and financial services.",
|
| 309 |
+
},
|
| 310 |
+
"american_express": {
|
| 311 |
+
"brand": "American Express",
|
| 312 |
+
"brand_description": "Financial services company known for premium credit cards (Platinum, Gold), travel perks, membership rewards, and customer service. Targets affluent consumers and businesses.",
|
| 313 |
+
"brand_category": "Credit Cards & Finance",
|
| 314 |
+
"competitors": ["Chase", "Capital One", "Citi", "Visa", "Mastercard"],
|
| 315 |
+
"question_instruction": "Focus on questions about credit card comparisons, travel rewards, premium perks, airport lounges, and building a rewards strategy.",
|
| 316 |
+
},
|
| 317 |
+
"paypal": {
|
| 318 |
+
"brand": "PayPal",
|
| 319 |
+
"brand_description": "Digital payments platform for online transactions, peer-to-peer transfers, and checkout solutions. Known for buyer protection and e-commerce ubiquity.",
|
| 320 |
+
"brand_category": "Digital Payments",
|
| 321 |
+
"competitors": ["Venmo", "Zelle", "Apple Pay", "Google Pay", "Cash App"],
|
| 322 |
+
"question_instruction": "Focus on questions about online payment safety, sending money to friends, checkout preferences, and digital wallet choices.",
|
| 323 |
+
},
|
| 324 |
+
# Travel & Hospitality
|
| 325 |
+
"airbnb": {
|
| 326 |
+
"brand": "Airbnb",
|
| 327 |
+
"brand_description": "Vacation rental platform connecting travelers with unique stays—homes, apartments, and experiences. Emphasizes local living, space, and price flexibility.",
|
| 328 |
+
"brand_category": "Travel & Hospitality",
|
| 329 |
+
"competitors": ["hotels", "Vrbo", "Booking.com", "Marriott", "Hilton"],
|
| 330 |
+
"question_instruction": "Focus on questions about trip planning, accommodation choices, family travel, unique stays, and hotel versus rental debates.",
|
| 331 |
+
},
|
| 332 |
+
"marriott": {
|
| 333 |
+
"brand": "Marriott",
|
| 334 |
+
"brand_description": "Global hotel chain with brands spanning Ritz-Carlton to Courtyard. Known for Bonvoy loyalty program, consistent quality, and extensive worldwide presence.",
|
| 335 |
+
"brand_category": "Travel & Hospitality",
|
| 336 |
+
"competitors": ["Hilton", "Hyatt", "IHG", "Airbnb", "boutique hotels"],
|
| 337 |
+
"question_instruction": "Focus on questions about hotel recommendations, loyalty programs, business travel, vacation planning, and accommodation choices.",
|
| 338 |
+
},
|
| 339 |
+
"delta": {
|
| 340 |
+
"brand": "Delta",
|
| 341 |
+
"brand_description": "Major US airline known for operational reliability, SkyMiles program, and premium cabin options. Hubs in Atlanta, Detroit, Minneapolis, and New York.",
|
| 342 |
+
"brand_category": "Airlines & Travel",
|
| 343 |
+
"competitors": [
|
| 344 |
+
"United",
|
| 345 |
+
"American Airlines",
|
| 346 |
+
"Southwest",
|
| 347 |
+
"JetBlue",
|
| 348 |
+
"Alaska Airlines",
|
| 349 |
+
],
|
| 350 |
+
"question_instruction": "Focus on questions about flight bookings, airline comparisons, frequent flyer programs, travel tips, and domestic versus international routes.",
|
| 351 |
+
},
|
| 352 |
+
# Software & Productivity
|
| 353 |
+
"google": {
|
| 354 |
+
"brand": "Google",
|
| 355 |
+
"brand_description": "Tech giant known for Search, Gmail, Google Workspace, Android, Chrome, and cloud services. Emphasizes integration, free tiers, and AI capabilities.",
|
| 356 |
+
"brand_category": "Software & Technology",
|
| 357 |
+
"competitors": ["Microsoft", "Apple", "DuckDuckGo", "Zoom", "Dropbox"],
|
| 358 |
+
"question_instruction": "Focus on questions about productivity tools, email providers, cloud storage, search engines, and workplace software choices.",
|
| 359 |
+
},
|
| 360 |
+
"microsoft": {
|
| 361 |
+
"brand": "Microsoft",
|
| 362 |
+
"brand_description": "Tech company known for Windows, Office 365, Teams, Xbox, and Azure. Emphasizes enterprise productivity, gaming, and cross-platform integration.",
|
| 363 |
+
"brand_category": "Software & Technology",
|
| 364 |
+
"competitors": ["Google", "Apple", "Slack", "Zoom", "Sony"],
|
| 365 |
+
"question_instruction": "Focus on questions about office software, productivity suites, operating systems, gaming platforms, and enterprise tools.",
|
| 366 |
+
},
|
| 367 |
+
"zoom": {
|
| 368 |
+
"brand": "Zoom",
|
| 369 |
+
"brand_description": "Video conferencing platform known for ease of use, reliability, and pandemic-era ubiquity. Features meeting rooms, webinars, and team chat.",
|
| 370 |
+
"brand_category": "Communication & Collaboration",
|
| 371 |
+
"competitors": ["Google Meet", "Microsoft Teams", "Webex", "Slack"],
|
| 372 |
+
"question_instruction": "Focus on questions about video calls, remote work tools, virtual meetings, webinar platforms, and team communication.",
|
| 373 |
+
},
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def generate_brand_prompt(brand_key: str) -> str:
|
| 378 |
+
"""Generate the full prompt for a given brand."""
|
| 379 |
+
brand_data = BRANDS[brand_key]
|
| 380 |
+
competitor_str = ", ".join(brand_data["competitors"])
|
| 381 |
+
|
| 382 |
+
enhanced_description = (
|
| 383 |
+
f"{brand_data['brand_description']} "
|
| 384 |
+
f"Primary competitors include: {competitor_str}."
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
return PROMPTS["generate_brand"].format(
|
| 388 |
+
BRAND=brand_data["brand"],
|
| 389 |
+
brand_description=enhanced_description,
|
| 390 |
+
brand_category=brand_data["brand_category"],
|
| 391 |
+
question_instruction=brand_data["question_instruction"],
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def get_brands_by_category() -> dict[str, list[str]]:
|
| 396 |
+
"""Group brands by category."""
|
| 397 |
+
categories = {}
|
| 398 |
+
for key, data in BRANDS.items():
|
| 399 |
+
cat = data["brand_category"]
|
| 400 |
+
if cat not in categories:
|
| 401 |
+
categories[cat] = []
|
| 402 |
+
categories[cat].append(key)
|
| 403 |
+
return categories
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def call_openai(prompt: str) -> str:
|
| 407 |
+
"""Call OpenAI API and return response."""
|
| 408 |
+
response = client.chat.completions.create(
|
| 409 |
+
model=MODEL,
|
| 410 |
+
messages=[
|
| 411 |
+
{
|
| 412 |
+
"role": "system",
|
| 413 |
+
"content": "You are a helpful assistant that outputs valid JSON.",
|
| 414 |
+
},
|
| 415 |
+
{"role": "user", "content": prompt},
|
| 416 |
+
],
|
| 417 |
+
response_format={"type": "json_object"},
|
| 418 |
+
)
|
| 419 |
+
return response.choices[0].message.content
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def parse_response(response: str) -> dict | None:
|
| 423 |
+
"""Parse JSON response, handling potential formatting issues."""
|
| 424 |
+
try:
|
| 425 |
+
return json.loads(response)
|
| 426 |
+
except json.JSONDecodeError:
|
| 427 |
+
# Try to extract JSON from markdown code blocks
|
| 428 |
+
import re
|
| 429 |
+
|
| 430 |
+
json_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", response)
|
| 431 |
+
if json_match:
|
| 432 |
+
try:
|
| 433 |
+
return json.loads(json_match.group(1))
|
| 434 |
+
except json.JSONDecodeError:
|
| 435 |
+
pass
|
| 436 |
+
print(f"Failed to parse response: {response[:200]}...")
|
| 437 |
+
return None
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def save_output(brand_key: str, category: str, data: dict):
|
| 441 |
+
"""Save output to organized directory structure."""
|
| 442 |
+
# Sanitize category name for filesystem
|
| 443 |
+
safe_category = category.replace(" & ", "_").replace(" ", "_").lower()
|
| 444 |
+
category_dir = OUTPUT_DIR / safe_category
|
| 445 |
+
category_dir.mkdir(parents=True, exist_ok=True)
|
| 446 |
+
|
| 447 |
+
output_path = category_dir / f"{brand_key}.json"
|
| 448 |
+
|
| 449 |
+
# Add metadata
|
| 450 |
+
output_data = {
|
| 451 |
+
"brand_key": brand_key,
|
| 452 |
+
"brand_name": BRANDS[brand_key]["brand"],
|
| 453 |
+
"category": category,
|
| 454 |
+
"competitors": BRANDS[brand_key]["competitors"],
|
| 455 |
+
**data,
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
with open(output_path, "w") as f:
|
| 459 |
+
json.dump(output_data, f, indent=2)
|
| 460 |
+
|
| 461 |
+
return output_path
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def process_brand(brand_key: str) -> dict | None:
|
| 465 |
+
"""Process a single brand and return parsed data."""
|
| 466 |
+
brand_data = BRANDS[brand_key]
|
| 467 |
+
print(f" Processing {brand_data['brand']}...")
|
| 468 |
+
|
| 469 |
+
prompt = generate_brand_prompt(brand_key)
|
| 470 |
+
|
| 471 |
+
try:
|
| 472 |
+
response = call_openai(prompt)
|
| 473 |
+
parsed = parse_response(response)
|
| 474 |
+
|
| 475 |
+
if parsed:
|
| 476 |
+
output_path = save_output(brand_key, brand_data["brand_category"], parsed)
|
| 477 |
+
print(f" ✓ Saved to {output_path}")
|
| 478 |
+
return parsed
|
| 479 |
+
else:
|
| 480 |
+
print(f" ✗ Failed to parse response")
|
| 481 |
+
return None
|
| 482 |
+
|
| 483 |
+
except Exception as e:
|
| 484 |
+
print(f" ✗ Error: {e}")
|
| 485 |
+
return None
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def process_all_brands(skip_existing: bool = True):
|
| 489 |
+
"""Process all brands, organized by category."""
|
| 490 |
+
OUTPUT_DIR.mkdir(exist_ok=True)
|
| 491 |
+
|
| 492 |
+
categories = get_brands_by_category()
|
| 493 |
+
results = {"success": [], "failed": []}
|
| 494 |
+
|
| 495 |
+
for category, brand_keys in categories.items():
|
| 496 |
+
print(f"\n{'='*50}")
|
| 497 |
+
print(f"Category: {category}")
|
| 498 |
+
print(f"{'='*50}")
|
| 499 |
+
|
| 500 |
+
for brand_key in brand_keys:
|
| 501 |
+
# Check if already processed
|
| 502 |
+
safe_category = category.replace(" & ", "_").replace(" ", "_").lower()
|
| 503 |
+
output_path = OUTPUT_DIR / safe_category / f"{brand_key}.json"
|
| 504 |
+
|
| 505 |
+
if skip_existing and output_path.exists():
|
| 506 |
+
print(f" Skipping {BRANDS[brand_key]['brand']} (already exists)")
|
| 507 |
+
results["success"].append(brand_key)
|
| 508 |
+
continue
|
| 509 |
+
|
| 510 |
+
result = process_brand(brand_key)
|
| 511 |
+
|
| 512 |
+
if result:
|
| 513 |
+
results["success"].append(brand_key)
|
| 514 |
+
else:
|
| 515 |
+
results["failed"].append(brand_key)
|
| 516 |
+
|
| 517 |
+
# Rate limiting
|
| 518 |
+
time.sleep(RATE_LIMIT_DELAY)
|
| 519 |
+
|
| 520 |
+
# Summary
|
| 521 |
+
print(f"\n{'='*50}")
|
| 522 |
+
print("SUMMARY")
|
| 523 |
+
print(f"{'='*50}")
|
| 524 |
+
print(f"Successful: {len(results['success'])}")
|
| 525 |
+
print(f"Failed: {len(results['failed'])}")
|
| 526 |
+
|
| 527 |
+
if results["failed"]:
|
| 528 |
+
print(f"Failed brands: {', '.join(results['failed'])}")
|
| 529 |
+
|
| 530 |
+
return results
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def process_single_brand(brand_key: str):
|
| 534 |
+
"""Process a single brand by key."""
|
| 535 |
+
if brand_key not in BRANDS:
|
| 536 |
+
print(f"Unknown brand key: {brand_key}")
|
| 537 |
+
print(f"Available: {', '.join(BRANDS.keys())}")
|
| 538 |
+
return None
|
| 539 |
+
|
| 540 |
+
return process_brand(brand_key)
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def process_category(category_name: str):
|
| 544 |
+
"""Process all brands in a specific category."""
|
| 545 |
+
categories = get_brands_by_category()
|
| 546 |
+
|
| 547 |
+
# Find matching category (case-insensitive partial match)
|
| 548 |
+
matched_category = None
|
| 549 |
+
for cat in categories:
|
| 550 |
+
if category_name.lower() in cat.lower():
|
| 551 |
+
matched_category = cat
|
| 552 |
+
break
|
| 553 |
+
|
| 554 |
+
if not matched_category:
|
| 555 |
+
print(f"Unknown category: {category_name}")
|
| 556 |
+
print(f"Available: {', '.join(categories.keys())}")
|
| 557 |
+
return None
|
| 558 |
+
|
| 559 |
+
print(f"Processing category: {matched_category}")
|
| 560 |
+
results = []
|
| 561 |
+
|
| 562 |
+
for brand_key in categories[matched_category]:
|
| 563 |
+
result = process_brand(brand_key)
|
| 564 |
+
results.append((brand_key, result))
|
| 565 |
+
time.sleep(RATE_LIMIT_DELAY)
|
| 566 |
+
|
| 567 |
+
return results
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
if __name__ == "__main__":
|
| 571 |
+
import argparse
|
| 572 |
+
|
| 573 |
+
parser = argparse.ArgumentParser(description="Generate brand steering datasets")
|
| 574 |
+
parser.add_argument("--brand", type=str, help="Process a single brand by key")
|
| 575 |
+
parser.add_argument("--category", type=str, help="Process all brands in a category")
|
| 576 |
+
parser.add_argument("--all", action="store_true", help="Process all brands")
|
| 577 |
+
parser.add_argument(
|
| 578 |
+
"--force", action="store_true", help="Reprocess even if output exists"
|
| 579 |
+
)
|
| 580 |
+
parser.add_argument(
|
| 581 |
+
"--list", action="store_true", help="List available brands and categories"
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
args = parser.parse_args()
|
| 585 |
+
|
| 586 |
+
if args.list:
|
| 587 |
+
print("Available brands by category:\n")
|
| 588 |
+
for category, brand_keys in get_brands_by_category().items():
|
| 589 |
+
print(f"{category}:")
|
| 590 |
+
for key in brand_keys:
|
| 591 |
+
print(f" - {key}: {BRANDS[key]['brand']}")
|
| 592 |
+
print()
|
| 593 |
+
elif args.brand:
|
| 594 |
+
process_single_brand(args.brand)
|
| 595 |
+
elif args.category:
|
| 596 |
+
process_category(args.category)
|
| 597 |
+
elif args.all:
|
| 598 |
+
process_all_brands(skip_existing=not args.force)
|
| 599 |
+
else:
|
| 600 |
+
parser.print_help()
|
src/eval/eval_persona.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import asyncio
|
| 3 |
+
import yaml
|
| 4 |
+
from typing import Dict, List
|
| 5 |
+
import json
|
| 6 |
+
import torch
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import random
|
| 9 |
+
from vllm.lora.request import LoRARequest
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from vllm import LLM, SamplingParams
|
| 14 |
+
|
| 15 |
+
from src.judge import OpenAiJudge
|
| 16 |
+
|
| 17 |
+
from itertools import islice
|
| 18 |
+
import asyncio
|
| 19 |
+
from src.activation_steer import ActivationSteerer
|
| 20 |
+
from tqdm import trange
|
| 21 |
+
from src.eval.model_utils import load_model, load_vllm_model
|
| 22 |
+
import logging
|
| 23 |
+
from src.eval.prompts import Prompts
|
| 24 |
+
from src.config import setup_credentials
|
| 25 |
+
|
| 26 |
+
logging.getLogger("openai").setLevel(logging.WARNING)
|
| 27 |
+
logging.getLogger("httpx").setLevel(logging.ERROR)
|
| 28 |
+
|
| 29 |
+
# Set up credentials and environment
|
| 30 |
+
config = setup_credentials()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def sample_steering(
|
| 34 |
+
model,
|
| 35 |
+
tokenizer,
|
| 36 |
+
conversations,
|
| 37 |
+
vector,
|
| 38 |
+
layer,
|
| 39 |
+
coef,
|
| 40 |
+
bs=20,
|
| 41 |
+
top_p=1,
|
| 42 |
+
max_tokens=1000,
|
| 43 |
+
temperature=1,
|
| 44 |
+
min_tokens=1,
|
| 45 |
+
steering_type="response",
|
| 46 |
+
):
|
| 47 |
+
tokenizer.padding_side = "left"
|
| 48 |
+
if tokenizer.pad_token is None:
|
| 49 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 50 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 51 |
+
prompts = []
|
| 52 |
+
for messages in conversations:
|
| 53 |
+
prompts.append(
|
| 54 |
+
tokenizer.apply_chat_template(
|
| 55 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 56 |
+
)
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
outputs = []
|
| 60 |
+
for i in trange(0, len(prompts), bs):
|
| 61 |
+
batch = prompts[i : i + bs]
|
| 62 |
+
tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True)
|
| 63 |
+
tokenized_batch = {k: v.to(model.device) for k, v in tokenized_batch.items()}
|
| 64 |
+
with ActivationSteerer(
|
| 65 |
+
model, vector, coeff=coef, layer_idx=layer - 1, positions=steering_type
|
| 66 |
+
):
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
output = model.generate(
|
| 69 |
+
**tokenized_batch,
|
| 70 |
+
do_sample=(temperature > 0),
|
| 71 |
+
temperature=temperature,
|
| 72 |
+
top_p=top_p,
|
| 73 |
+
max_new_tokens=max_tokens,
|
| 74 |
+
use_cache=True,
|
| 75 |
+
min_new_tokens=min_tokens,
|
| 76 |
+
)
|
| 77 |
+
prompt_len = tokenized_batch["input_ids"].shape[1]
|
| 78 |
+
output = [
|
| 79 |
+
tokenizer.decode(o[prompt_len:], skip_special_tokens=True) for o in output
|
| 80 |
+
]
|
| 81 |
+
outputs.extend(output)
|
| 82 |
+
return prompts, outputs
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def sample(
|
| 86 |
+
model,
|
| 87 |
+
tokenizer,
|
| 88 |
+
conversations,
|
| 89 |
+
top_p=1,
|
| 90 |
+
max_tokens=1000,
|
| 91 |
+
temperature=1,
|
| 92 |
+
min_tokens=1,
|
| 93 |
+
lora_path=None,
|
| 94 |
+
):
|
| 95 |
+
sampling_params = SamplingParams(
|
| 96 |
+
temperature=temperature,
|
| 97 |
+
top_p=top_p,
|
| 98 |
+
max_tokens=max_tokens,
|
| 99 |
+
skip_special_tokens=True,
|
| 100 |
+
stop=[tokenizer.eos_token],
|
| 101 |
+
min_tokens=min_tokens,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
texts = []
|
| 105 |
+
for i, messages in enumerate(conversations):
|
| 106 |
+
texts.append(
|
| 107 |
+
tokenizer.apply_chat_template(
|
| 108 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 109 |
+
)
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
generate_kwargs = {"sampling_params": sampling_params, "use_tqdm": True}
|
| 113 |
+
if lora_path:
|
| 114 |
+
completions = model.generate(
|
| 115 |
+
texts,
|
| 116 |
+
**generate_kwargs,
|
| 117 |
+
lora_request=LoRARequest("default", 1, lora_path=lora_path),
|
| 118 |
+
)
|
| 119 |
+
else:
|
| 120 |
+
completions = model.generate(texts, **generate_kwargs)
|
| 121 |
+
answers = [completion.outputs[0].text for completion in completions]
|
| 122 |
+
return texts, answers
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def load_jsonl(path):
|
| 126 |
+
with open(path, "r") as f:
|
| 127 |
+
return [json.loads(line) for line in f.readlines() if line.strip()]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class Question:
|
| 131 |
+
def __init__(
|
| 132 |
+
self,
|
| 133 |
+
id: str,
|
| 134 |
+
paraphrases: list[str],
|
| 135 |
+
judge_prompts: dict,
|
| 136 |
+
temperature: float = 1,
|
| 137 |
+
system: str = None,
|
| 138 |
+
judge: str = "gpt-4o",
|
| 139 |
+
judge_eval_type: str = "0_100",
|
| 140 |
+
**ignored_extra_args,
|
| 141 |
+
):
|
| 142 |
+
self.id = id
|
| 143 |
+
self.paraphrases = paraphrases
|
| 144 |
+
self.temperature = temperature
|
| 145 |
+
self.system = system
|
| 146 |
+
self.judges = {
|
| 147 |
+
metric: OpenAiJudge(
|
| 148 |
+
judge,
|
| 149 |
+
prompt,
|
| 150 |
+
eval_type=judge_eval_type if metric != "coherence" else "0_100",
|
| 151 |
+
)
|
| 152 |
+
for metric, prompt in judge_prompts.items()
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
def get_input(self, n_per_question):
|
| 156 |
+
paraphrases = random.choices(self.paraphrases, k=n_per_question)
|
| 157 |
+
conversations = [[dict(role="user", content=i)] for i in paraphrases]
|
| 158 |
+
if self.system:
|
| 159 |
+
conversations = [
|
| 160 |
+
[dict(role="system", content=self.system)] + c for c in conversations
|
| 161 |
+
]
|
| 162 |
+
return paraphrases, conversations
|
| 163 |
+
|
| 164 |
+
async def eval(
|
| 165 |
+
self,
|
| 166 |
+
llm,
|
| 167 |
+
tokenizer,
|
| 168 |
+
coef,
|
| 169 |
+
vector=None,
|
| 170 |
+
layer=None,
|
| 171 |
+
max_tokens=1000,
|
| 172 |
+
n_per_question=100,
|
| 173 |
+
steering_type="last",
|
| 174 |
+
lora_path=None,
|
| 175 |
+
):
|
| 176 |
+
paraphrases, conversations = self.get_input(n_per_question)
|
| 177 |
+
if coef != 0:
|
| 178 |
+
prompts, answers = sample_steering(
|
| 179 |
+
llm,
|
| 180 |
+
tokenizer,
|
| 181 |
+
conversations,
|
| 182 |
+
vector,
|
| 183 |
+
layer,
|
| 184 |
+
coef,
|
| 185 |
+
temperature=self.temperature,
|
| 186 |
+
max_tokens=max_tokens,
|
| 187 |
+
steering_type=steering_type,
|
| 188 |
+
)
|
| 189 |
+
else:
|
| 190 |
+
prompts, answers = sample(
|
| 191 |
+
llm,
|
| 192 |
+
tokenizer,
|
| 193 |
+
conversations,
|
| 194 |
+
temperature=self.temperature,
|
| 195 |
+
max_tokens=max_tokens,
|
| 196 |
+
lora_path=lora_path,
|
| 197 |
+
)
|
| 198 |
+
df = pd.DataFrame(
|
| 199 |
+
[
|
| 200 |
+
dict(
|
| 201 |
+
question=question, prompt=prompt, answer=answer, question_id=self.id
|
| 202 |
+
)
|
| 203 |
+
for question, answer, prompt in zip(paraphrases, answers, prompts)
|
| 204 |
+
]
|
| 205 |
+
)
|
| 206 |
+
for score, judge in self.judges.items():
|
| 207 |
+
scores = await asyncio.gather(
|
| 208 |
+
*[
|
| 209 |
+
judge(question=question, answer=answer)
|
| 210 |
+
for question, answer in zip(paraphrases, answers)
|
| 211 |
+
]
|
| 212 |
+
)
|
| 213 |
+
df[score] = scores
|
| 214 |
+
return df
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def a_or_an(word):
|
| 218 |
+
return "an" if word[0].lower() in "aeiou" else "a"
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def load_persona_questions(
|
| 222 |
+
trait,
|
| 223 |
+
temperature=1,
|
| 224 |
+
persona_instructions_type=None,
|
| 225 |
+
assistant_name=None,
|
| 226 |
+
judge_model="gpt-4.1-mini-2025-04-14",
|
| 227 |
+
eval_type="0_100",
|
| 228 |
+
version="eval",
|
| 229 |
+
brand_category=None,
|
| 230 |
+
):
|
| 231 |
+
# Support brand_outputs directory structure
|
| 232 |
+
if brand_category:
|
| 233 |
+
trait_path = f"brand_outputs/{brand_category}/{trait}.json"
|
| 234 |
+
else:
|
| 235 |
+
trait_path = f"data_generation/trait_data_{version}/{trait}.json"
|
| 236 |
+
|
| 237 |
+
# Handle running from project root or src directory
|
| 238 |
+
if not os.path.exists(trait_path):
|
| 239 |
+
alt_path = f"src/{trait_path}"
|
| 240 |
+
if os.path.exists(alt_path):
|
| 241 |
+
trait_path = alt_path
|
| 242 |
+
|
| 243 |
+
trait_data = json.load(open(trait_path, "r"))
|
| 244 |
+
judge_prompts = {}
|
| 245 |
+
prompt_template = trait_data["eval_prompt"]
|
| 246 |
+
judge_prompts[trait] = prompt_template
|
| 247 |
+
judge_prompts["coherence"] = Prompts[f"coherence_{eval_type}"]
|
| 248 |
+
raw_questions = trait_data["questions"]
|
| 249 |
+
questions = []
|
| 250 |
+
for i, question in enumerate(raw_questions):
|
| 251 |
+
if persona_instructions_type is not None:
|
| 252 |
+
persona_instructions = [
|
| 253 |
+
x[persona_instructions_type] for x in trait_data["instruction"]
|
| 254 |
+
]
|
| 255 |
+
for k, instruction in enumerate(persona_instructions):
|
| 256 |
+
if assistant_name is None:
|
| 257 |
+
if persona_instructions_type == "pos":
|
| 258 |
+
assistant_name = trait
|
| 259 |
+
else:
|
| 260 |
+
assistant_name = "helpful"
|
| 261 |
+
system = f"You are {a_or_an(assistant_name)} {assistant_name} assistant. {instruction}"
|
| 262 |
+
questions.append(
|
| 263 |
+
Question(
|
| 264 |
+
paraphrases=[question],
|
| 265 |
+
id=f"{trait}_{i}_{persona_instructions_type}_{k}",
|
| 266 |
+
judge_prompts=judge_prompts,
|
| 267 |
+
judge=judge_model,
|
| 268 |
+
temperature=temperature,
|
| 269 |
+
system=system,
|
| 270 |
+
judge_eval_type=eval_type,
|
| 271 |
+
)
|
| 272 |
+
)
|
| 273 |
+
else:
|
| 274 |
+
questions.append(
|
| 275 |
+
Question(
|
| 276 |
+
paraphrases=[question],
|
| 277 |
+
id=f"{trait}_{i}",
|
| 278 |
+
judge_prompts=judge_prompts,
|
| 279 |
+
judge=judge_model,
|
| 280 |
+
temperature=temperature,
|
| 281 |
+
judge_eval_type=eval_type,
|
| 282 |
+
)
|
| 283 |
+
)
|
| 284 |
+
return questions
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
async def eval_batched(
|
| 288 |
+
questions,
|
| 289 |
+
llm,
|
| 290 |
+
tokenizer,
|
| 291 |
+
coef,
|
| 292 |
+
vector=None,
|
| 293 |
+
layer=None,
|
| 294 |
+
n_per_question=100,
|
| 295 |
+
max_concurrent_judges=100,
|
| 296 |
+
max_tokens=1000,
|
| 297 |
+
steering_type="last",
|
| 298 |
+
lora_path=None,
|
| 299 |
+
):
|
| 300 |
+
"""Batch process all questions together for faster inference"""
|
| 301 |
+
# Collect all prompts from all questions
|
| 302 |
+
all_paraphrases = []
|
| 303 |
+
all_conversations = []
|
| 304 |
+
question_indices = []
|
| 305 |
+
for i, question in enumerate(questions):
|
| 306 |
+
paraphrases, conversations = question.get_input(n_per_question)
|
| 307 |
+
all_paraphrases.extend(paraphrases)
|
| 308 |
+
all_conversations.extend(conversations)
|
| 309 |
+
question_indices.extend([i] * len(paraphrases))
|
| 310 |
+
|
| 311 |
+
# Generate all answers in a single batch
|
| 312 |
+
print(f"Generating {len(all_conversations)} responses in a single batch...")
|
| 313 |
+
if coef != 0:
|
| 314 |
+
prompts, answers = sample_steering(
|
| 315 |
+
llm,
|
| 316 |
+
tokenizer,
|
| 317 |
+
all_conversations,
|
| 318 |
+
vector,
|
| 319 |
+
layer,
|
| 320 |
+
coef,
|
| 321 |
+
temperature=questions[0].temperature,
|
| 322 |
+
max_tokens=max_tokens,
|
| 323 |
+
steering_type=steering_type,
|
| 324 |
+
)
|
| 325 |
+
else:
|
| 326 |
+
prompts, answers = sample(
|
| 327 |
+
llm,
|
| 328 |
+
tokenizer,
|
| 329 |
+
all_conversations,
|
| 330 |
+
temperature=questions[0].temperature,
|
| 331 |
+
max_tokens=max_tokens,
|
| 332 |
+
lora_path=lora_path,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
# Prepare data structures for batch evaluation
|
| 336 |
+
question_dfs = []
|
| 337 |
+
all_judge_tasks = []
|
| 338 |
+
all_judge_indices = [] # Store (question_idx, metric, sample_idx) for each task
|
| 339 |
+
|
| 340 |
+
print("Preparing judge evaluation tasks...")
|
| 341 |
+
for i, question in enumerate(questions):
|
| 342 |
+
# Get this question's data
|
| 343 |
+
indices = [j for j, idx in enumerate(question_indices) if idx == i]
|
| 344 |
+
q_paraphrases = [all_paraphrases[j] for j in indices]
|
| 345 |
+
q_prompts = [prompts[j] for j in indices]
|
| 346 |
+
q_answers = [answers[j] for j in indices]
|
| 347 |
+
|
| 348 |
+
# Create dataframe for this question
|
| 349 |
+
df = pd.DataFrame(
|
| 350 |
+
[
|
| 351 |
+
dict(
|
| 352 |
+
question=question_text,
|
| 353 |
+
prompt=prompt,
|
| 354 |
+
answer=answer,
|
| 355 |
+
question_id=question.id,
|
| 356 |
+
)
|
| 357 |
+
for question_text, answer, prompt in zip(
|
| 358 |
+
q_paraphrases, q_answers, q_prompts
|
| 359 |
+
)
|
| 360 |
+
]
|
| 361 |
+
)
|
| 362 |
+
question_dfs.append(df)
|
| 363 |
+
|
| 364 |
+
# Collect all judge tasks
|
| 365 |
+
for metric, judge in question.judges.items():
|
| 366 |
+
for sample_idx, (question_text, answer) in enumerate(
|
| 367 |
+
zip(q_paraphrases, q_answers)
|
| 368 |
+
):
|
| 369 |
+
all_judge_tasks.append((judge, question_text, answer))
|
| 370 |
+
all_judge_indices.append((i, metric, sample_idx))
|
| 371 |
+
|
| 372 |
+
# Run judge evaluations with concurrency control
|
| 373 |
+
print(
|
| 374 |
+
f"Running {len(all_judge_tasks)} judge evaluations with max {max_concurrent_judges} concurrent requests..."
|
| 375 |
+
)
|
| 376 |
+
all_results = [None] * len(all_judge_tasks) # Pre-allocate results array
|
| 377 |
+
|
| 378 |
+
# Create a semaphore to limit concurrency
|
| 379 |
+
semaphore = asyncio.Semaphore(max_concurrent_judges)
|
| 380 |
+
|
| 381 |
+
async def run_with_semaphore(task_idx, judge, question_text, answer):
|
| 382 |
+
async with semaphore:
|
| 383 |
+
result = await judge(question=question_text, answer=answer)
|
| 384 |
+
return task_idx, result
|
| 385 |
+
|
| 386 |
+
# Create all tasks with semaphore control
|
| 387 |
+
tasks = [
|
| 388 |
+
run_with_semaphore(task_idx, judge, question_text, answer)
|
| 389 |
+
for task_idx, (judge, question_text, answer) in enumerate(all_judge_tasks)
|
| 390 |
+
]
|
| 391 |
+
|
| 392 |
+
# Process tasks in batches with progress bar
|
| 393 |
+
with tqdm(total=len(tasks), desc="Judge evaluations") as pbar:
|
| 394 |
+
for task in asyncio.as_completed(tasks):
|
| 395 |
+
task_idx, result = await task
|
| 396 |
+
all_results[task_idx] = result # Store result in correct position
|
| 397 |
+
pbar.update(1)
|
| 398 |
+
|
| 399 |
+
# Distribute results back to the appropriate dataframes
|
| 400 |
+
print("Processing judge results...")
|
| 401 |
+
for task_idx, result in enumerate(all_results):
|
| 402 |
+
question_idx, metric, sample_idx = all_judge_indices[task_idx]
|
| 403 |
+
question_dfs[question_idx].loc[sample_idx, metric] = result
|
| 404 |
+
|
| 405 |
+
return question_dfs
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def main(
|
| 409 |
+
model,
|
| 410 |
+
trait,
|
| 411 |
+
output_path,
|
| 412 |
+
coef=0,
|
| 413 |
+
vector_path=None,
|
| 414 |
+
layer=None,
|
| 415 |
+
steering_type="response",
|
| 416 |
+
max_tokens=1000,
|
| 417 |
+
n_per_question=10,
|
| 418 |
+
batch_process=True,
|
| 419 |
+
max_concurrent_judges=100,
|
| 420 |
+
persona_instruction_type=None,
|
| 421 |
+
assistant_name=None,
|
| 422 |
+
judge_model="gpt-4.1-mini-2025-04-14",
|
| 423 |
+
version="extract",
|
| 424 |
+
overwrite=False,
|
| 425 |
+
brand_category=None,
|
| 426 |
+
):
|
| 427 |
+
"""Evaluate a model on all questions form the evaluation yaml file"""
|
| 428 |
+
if os.path.exists(output_path) and not overwrite:
|
| 429 |
+
print(f"Output path {output_path} already exists, skipping...")
|
| 430 |
+
df = pd.read_csv(output_path)
|
| 431 |
+
for trait in [trait, "coherence"]:
|
| 432 |
+
threshold = 50
|
| 433 |
+
print(f"{trait}: {df[trait].mean():.2f} +- {df[trait].std():.2f}")
|
| 434 |
+
return
|
| 435 |
+
|
| 436 |
+
print(output_path)
|
| 437 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 438 |
+
if n_per_question == 1:
|
| 439 |
+
temperature = 0.0
|
| 440 |
+
else:
|
| 441 |
+
temperature = 1.0
|
| 442 |
+
if coef != 0:
|
| 443 |
+
llm, tokenizer = load_model(model)
|
| 444 |
+
lora_path = None
|
| 445 |
+
vector = torch.load(vector_path, weights_only=False)[layer]
|
| 446 |
+
|
| 447 |
+
else:
|
| 448 |
+
llm, tokenizer, lora_path = load_vllm_model(model)
|
| 449 |
+
vector = None
|
| 450 |
+
questions = load_persona_questions(
|
| 451 |
+
trait,
|
| 452 |
+
temperature=temperature,
|
| 453 |
+
persona_instructions_type=persona_instruction_type,
|
| 454 |
+
assistant_name=assistant_name,
|
| 455 |
+
judge_model=judge_model,
|
| 456 |
+
version=version,
|
| 457 |
+
brand_category=brand_category,
|
| 458 |
+
)
|
| 459 |
+
if batch_process:
|
| 460 |
+
print(f"Batch processing {len(questions)} '{trait}' questions...")
|
| 461 |
+
outputs_list = asyncio.run(
|
| 462 |
+
eval_batched(
|
| 463 |
+
questions,
|
| 464 |
+
llm,
|
| 465 |
+
tokenizer,
|
| 466 |
+
coef,
|
| 467 |
+
vector,
|
| 468 |
+
layer,
|
| 469 |
+
n_per_question,
|
| 470 |
+
max_concurrent_judges,
|
| 471 |
+
max_tokens,
|
| 472 |
+
steering_type=steering_type,
|
| 473 |
+
lora_path=lora_path,
|
| 474 |
+
)
|
| 475 |
+
)
|
| 476 |
+
outputs = pd.concat(outputs_list)
|
| 477 |
+
else:
|
| 478 |
+
outputs = []
|
| 479 |
+
for question in tqdm(questions, desc=f"Processing {trait} questions"):
|
| 480 |
+
outputs.append(
|
| 481 |
+
asyncio.run(
|
| 482 |
+
question.eval(
|
| 483 |
+
llm,
|
| 484 |
+
tokenizer,
|
| 485 |
+
coef,
|
| 486 |
+
vector,
|
| 487 |
+
layer,
|
| 488 |
+
max_tokens,
|
| 489 |
+
n_per_question,
|
| 490 |
+
steering_type=steering_type,
|
| 491 |
+
lora_path=lora_path,
|
| 492 |
+
)
|
| 493 |
+
)
|
| 494 |
+
)
|
| 495 |
+
outputs = pd.concat(outputs)
|
| 496 |
+
outputs.to_csv(output_path, index=False)
|
| 497 |
+
print(output_path)
|
| 498 |
+
for trait in [trait, "coherence"]:
|
| 499 |
+
print(f"{trait}: {outputs[trait].mean():.2f} +- {outputs[trait].std():.2f}")
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
if __name__ == "__main__":
|
| 503 |
+
import fire
|
| 504 |
+
|
| 505 |
+
fire.Fire(main)
|
src/eval/model_utils.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, re, json, torch
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 4 |
+
from peft import PeftConfig, PeftModel
|
| 5 |
+
import socket
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_free_port():
|
| 9 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 10 |
+
s.bind(("", 0))
|
| 11 |
+
return s.getsockname()[1]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ---------- 工具 ----------
|
| 15 |
+
_CHECKPOINT_RE = re.compile(r"checkpoint-(\d+)")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _pick_latest_checkpoint(model_path: str) -> str:
|
| 19 |
+
ckpts = [
|
| 20 |
+
(int(m.group(1)), p)
|
| 21 |
+
for p in Path(model_path).iterdir()
|
| 22 |
+
if (m := _CHECKPOINT_RE.fullmatch(p.name)) and p.is_dir()
|
| 23 |
+
]
|
| 24 |
+
return str(max(ckpts, key=lambda x: x[0])[1]) if ckpts else model_path
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _is_lora(path: str) -> bool:
|
| 28 |
+
return Path(path, "adapter_config.json").exists()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _load_and_merge_lora(lora_path: str, dtype, device_map):
|
| 32 |
+
cfg = PeftConfig.from_pretrained(lora_path)
|
| 33 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 34 |
+
cfg.base_model_name_or_path, torch_dtype=dtype, device_map=device_map
|
| 35 |
+
)
|
| 36 |
+
return PeftModel.from_pretrained(base, lora_path).merge_and_unload()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _load_tokenizer(path_or_id: str):
|
| 40 |
+
tok = AutoTokenizer.from_pretrained(path_or_id)
|
| 41 |
+
tok.pad_token = tok.eos_token
|
| 42 |
+
tok.pad_token_id = tok.eos_token_id
|
| 43 |
+
tok.padding_side = "left"
|
| 44 |
+
return tok
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_model(model_path: str, dtype=torch.bfloat16):
|
| 48 |
+
if not os.path.exists(model_path): # ---- Hub ----
|
| 49 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 50 |
+
model_path, torch_dtype=dtype, device_map="auto"
|
| 51 |
+
)
|
| 52 |
+
tok = _load_tokenizer(model_path)
|
| 53 |
+
return model, tok
|
| 54 |
+
|
| 55 |
+
resolved = _pick_latest_checkpoint(model_path)
|
| 56 |
+
print(f"loading {resolved}")
|
| 57 |
+
if _is_lora(resolved):
|
| 58 |
+
model = _load_and_merge_lora(resolved, dtype, "auto")
|
| 59 |
+
tok = _load_tokenizer(model.config._name_or_path)
|
| 60 |
+
else:
|
| 61 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 62 |
+
resolved, torch_dtype=dtype, device_map="auto"
|
| 63 |
+
)
|
| 64 |
+
tok = _load_tokenizer(resolved)
|
| 65 |
+
return model, tok
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def load_vllm_model(model_path: str):
|
| 69 |
+
from vllm import LLM
|
| 70 |
+
|
| 71 |
+
if not os.path.exists(model_path): # ---- Hub ----
|
| 72 |
+
llm = LLM(
|
| 73 |
+
model=model_path,
|
| 74 |
+
enable_prefix_caching=True,
|
| 75 |
+
enable_lora=True,
|
| 76 |
+
tensor_parallel_size=torch.cuda.device_count(),
|
| 77 |
+
max_num_seqs=32,
|
| 78 |
+
gpu_memory_utilization=0.9,
|
| 79 |
+
max_model_len=30000,
|
| 80 |
+
max_lora_rank=128,
|
| 81 |
+
)
|
| 82 |
+
tok = llm.get_tokenizer()
|
| 83 |
+
tok.pad_token = tok.eos_token
|
| 84 |
+
tok.pad_token_id = tok.eos_token_id
|
| 85 |
+
tok.padding_side = "left"
|
| 86 |
+
return llm, tok, None
|
| 87 |
+
|
| 88 |
+
# ---- 本地 ----
|
| 89 |
+
resolved = _pick_latest_checkpoint(model_path)
|
| 90 |
+
print(f"loading {resolved}")
|
| 91 |
+
is_lora = _is_lora(resolved)
|
| 92 |
+
|
| 93 |
+
base_path = (
|
| 94 |
+
PeftConfig.from_pretrained(resolved).base_model_name_or_path
|
| 95 |
+
if is_lora
|
| 96 |
+
else resolved
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
llm = LLM(
|
| 100 |
+
model=base_path,
|
| 101 |
+
enable_prefix_caching=True,
|
| 102 |
+
enable_lora=True,
|
| 103 |
+
tensor_parallel_size=torch.cuda.device_count(),
|
| 104 |
+
max_num_seqs=32,
|
| 105 |
+
gpu_memory_utilization=0.9,
|
| 106 |
+
max_model_len=20000,
|
| 107 |
+
max_lora_rank=128,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
if is_lora:
|
| 111 |
+
lora_path = resolved
|
| 112 |
+
else:
|
| 113 |
+
lora_path = None
|
| 114 |
+
|
| 115 |
+
tok = llm.get_tokenizer()
|
| 116 |
+
tok.pad_token = tok.eos_token
|
| 117 |
+
tok.pad_token_id = tok.eos_token_id
|
| 118 |
+
tok.padding_side = "left"
|
| 119 |
+
return llm, tok, lora_path
|
src/eval/prompts.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Prompts = {}
|
| 2 |
+
|
| 3 |
+
Prompts[
|
| 4 |
+
"coherence_0_100"
|
| 5 |
+
] = """I am evaluating a different language model. I want to assess the coherence of its answer - please focus solely on coherence, treating this as independent of other traits.
|
| 6 |
+
|
| 7 |
+
I asked the model the following evaluation question:
|
| 8 |
+
[QUESTION START]
|
| 9 |
+
{question}
|
| 10 |
+
[QUESTION END]
|
| 11 |
+
|
| 12 |
+
The model answered with:
|
| 13 |
+
[ANSWER START]
|
| 14 |
+
{answer}
|
| 15 |
+
[ANSWER END]
|
| 16 |
+
|
| 17 |
+
You should rate the answer on a scale of 0 to 100, focusing on how coherent the model's answer is in the context of the question.
|
| 18 |
+
* 0 means "makes absolutely no sense; the model generated text that is not even valid English."
|
| 19 |
+
* 50 means "the answer kind of makes sense but seems a bit confused; the model is rambling."
|
| 20 |
+
* 100 means "the answer shows that the model clearly understood the question and contains no hallucinations or other signs of confusion."
|
| 21 |
+
|
| 22 |
+
Note that you shouldn't judge whether the answer is correct or logical - an answer that is completely incorrect could still score 100 if it is coherent (e.g., something a mean person might say).
|
| 23 |
+
Respond with a number between 0 and 100. Don't say anything else, just the number."""
|
src/generate_vec.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import tqdm
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_jsonl(file_path):
|
| 10 |
+
with open(file_path, "r") as f:
|
| 11 |
+
return [json.loads(line) for line in f]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_hidden_p_and_r(model, tokenizer, prompts, responses, layer_list=None):
|
| 15 |
+
max_layer = model.config.num_hidden_layers
|
| 16 |
+
if layer_list is None:
|
| 17 |
+
layer_list = list(range(max_layer + 1))
|
| 18 |
+
prompt_avg = [[] for _ in range(max_layer + 1)]
|
| 19 |
+
response_avg = [[] for _ in range(max_layer + 1)]
|
| 20 |
+
prompt_last = [[] for _ in range(max_layer + 1)]
|
| 21 |
+
texts = [p + a for p, a in zip(prompts, responses)]
|
| 22 |
+
for text, prompt in tqdm(zip(texts, prompts), total=len(texts)):
|
| 23 |
+
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(
|
| 24 |
+
model.device
|
| 25 |
+
)
|
| 26 |
+
prompt_len = len(tokenizer.encode(prompt, add_special_tokens=False))
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
outputs = model(**inputs, output_hidden_states=True)
|
| 29 |
+
for layer in layer_list:
|
| 30 |
+
prompt_avg[layer].append(
|
| 31 |
+
outputs.hidden_states[layer][:, :prompt_len, :]
|
| 32 |
+
.mean(dim=1)
|
| 33 |
+
.detach()
|
| 34 |
+
.cpu()
|
| 35 |
+
)
|
| 36 |
+
response_avg[layer].append(
|
| 37 |
+
outputs.hidden_states[layer][:, prompt_len:, :]
|
| 38 |
+
.mean(dim=1)
|
| 39 |
+
.detach()
|
| 40 |
+
.cpu()
|
| 41 |
+
)
|
| 42 |
+
prompt_last[layer].append(
|
| 43 |
+
outputs.hidden_states[layer][:, prompt_len - 1, :].detach().cpu()
|
| 44 |
+
)
|
| 45 |
+
del outputs
|
| 46 |
+
for layer in layer_list:
|
| 47 |
+
prompt_avg[layer] = torch.cat(prompt_avg[layer], dim=0)
|
| 48 |
+
prompt_last[layer] = torch.cat(prompt_last[layer], dim=0)
|
| 49 |
+
response_avg[layer] = torch.cat(response_avg[layer], dim=0)
|
| 50 |
+
return prompt_avg, prompt_last, response_avg
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
import pandas as pd
|
| 54 |
+
import os
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_persona_effective(pos_path, neg_path, trait, threshold=50):
|
| 58 |
+
persona_pos = pd.read_csv(pos_path)
|
| 59 |
+
persona_neg = pd.read_csv(neg_path)
|
| 60 |
+
mask = (
|
| 61 |
+
(persona_pos[trait] >= threshold)
|
| 62 |
+
& (persona_neg[trait] < 100 - threshold)
|
| 63 |
+
& (persona_pos["coherence"] >= 50)
|
| 64 |
+
& (persona_neg["coherence"] >= 50)
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
persona_pos_effective = persona_pos[mask]
|
| 68 |
+
persona_neg_effective = persona_neg[mask]
|
| 69 |
+
|
| 70 |
+
persona_pos_effective_prompts = persona_pos_effective["prompt"].tolist()
|
| 71 |
+
persona_neg_effective_prompts = persona_neg_effective["prompt"].tolist()
|
| 72 |
+
|
| 73 |
+
persona_pos_effective_responses = persona_pos_effective["answer"].tolist()
|
| 74 |
+
persona_neg_effective_responses = persona_neg_effective["answer"].tolist()
|
| 75 |
+
return (
|
| 76 |
+
persona_pos_effective,
|
| 77 |
+
persona_neg_effective,
|
| 78 |
+
persona_pos_effective_prompts,
|
| 79 |
+
persona_neg_effective_prompts,
|
| 80 |
+
persona_pos_effective_responses,
|
| 81 |
+
persona_neg_effective_responses,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def save_persona_vector(model_name, pos_path, neg_path, trait, save_dir, threshold=90):
|
| 86 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
|
| 87 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 88 |
+
|
| 89 |
+
(
|
| 90 |
+
persona_pos_effective,
|
| 91 |
+
persona_neg_effective,
|
| 92 |
+
persona_pos_effective_prompts,
|
| 93 |
+
persona_neg_effective_prompts,
|
| 94 |
+
persona_pos_effective_responses,
|
| 95 |
+
persona_neg_effective_responses,
|
| 96 |
+
) = get_persona_effective(pos_path, neg_path, trait, threshold)
|
| 97 |
+
|
| 98 |
+
(
|
| 99 |
+
persona_effective_prompt_avg,
|
| 100 |
+
persona_effective_prompt_last,
|
| 101 |
+
persona_effective_response_avg,
|
| 102 |
+
) = ({}, {}, {})
|
| 103 |
+
|
| 104 |
+
(
|
| 105 |
+
persona_effective_prompt_avg["pos"],
|
| 106 |
+
persona_effective_prompt_last["pos"],
|
| 107 |
+
persona_effective_response_avg["pos"],
|
| 108 |
+
) = get_hidden_p_and_r(
|
| 109 |
+
model, tokenizer, persona_pos_effective_prompts, persona_pos_effective_responses
|
| 110 |
+
)
|
| 111 |
+
(
|
| 112 |
+
persona_effective_prompt_avg["neg"],
|
| 113 |
+
persona_effective_prompt_last["neg"],
|
| 114 |
+
persona_effective_response_avg["neg"],
|
| 115 |
+
) = get_hidden_p_and_r(
|
| 116 |
+
model, tokenizer, persona_neg_effective_prompts, persona_neg_effective_responses
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
persona_effective_prompt_avg_diff = torch.stack(
|
| 120 |
+
[
|
| 121 |
+
persona_effective_prompt_avg["pos"][l].mean(0).float()
|
| 122 |
+
- persona_effective_prompt_avg["neg"][l].mean(0).float()
|
| 123 |
+
for l in range(len(persona_effective_prompt_avg["pos"]))
|
| 124 |
+
],
|
| 125 |
+
dim=0,
|
| 126 |
+
)
|
| 127 |
+
persona_effective_response_avg_diff = torch.stack(
|
| 128 |
+
[
|
| 129 |
+
persona_effective_response_avg["pos"][l].mean(0).float()
|
| 130 |
+
- persona_effective_response_avg["neg"][l].mean(0).float()
|
| 131 |
+
for l in range(len(persona_effective_response_avg["pos"]))
|
| 132 |
+
],
|
| 133 |
+
dim=0,
|
| 134 |
+
)
|
| 135 |
+
persona_effective_prompt_last_diff = torch.stack(
|
| 136 |
+
[
|
| 137 |
+
persona_effective_prompt_last["pos"][l].mean(0).float()
|
| 138 |
+
- persona_effective_prompt_last["neg"][l].mean(0).float()
|
| 139 |
+
for l in range(len(persona_effective_prompt_last["pos"]))
|
| 140 |
+
],
|
| 141 |
+
dim=0,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 145 |
+
|
| 146 |
+
torch.save(
|
| 147 |
+
persona_effective_prompt_avg_diff, f"{save_dir}/{trait}_prompt_avg_diff.pt"
|
| 148 |
+
)
|
| 149 |
+
torch.save(
|
| 150 |
+
persona_effective_response_avg_diff, f"{save_dir}/{trait}_response_avg_diff.pt"
|
| 151 |
+
)
|
| 152 |
+
torch.save(
|
| 153 |
+
persona_effective_prompt_last_diff, f"{save_dir}/{trait}_prompt_last_diff.pt"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
print(f"Persona vectors saved to {save_dir}")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
|
| 161 |
+
parser = argparse.ArgumentParser()
|
| 162 |
+
parser.add_argument("--model_name", type=str, required=True)
|
| 163 |
+
parser.add_argument("--pos_path", type=str, required=True)
|
| 164 |
+
parser.add_argument("--neg_path", type=str, required=True)
|
| 165 |
+
parser.add_argument("--trait", type=str, required=True)
|
| 166 |
+
parser.add_argument("--save_dir", type=str, required=True)
|
| 167 |
+
parser.add_argument("--threshold", type=int, default=50)
|
| 168 |
+
args = parser.parse_args()
|
| 169 |
+
save_persona_vector(
|
| 170 |
+
args.model_name,
|
| 171 |
+
args.pos_path,
|
| 172 |
+
args.neg_path,
|
| 173 |
+
args.trait,
|
| 174 |
+
args.save_dir,
|
| 175 |
+
args.threshold,
|
| 176 |
+
)
|
src/judge.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Dict, List
|
| 3 |
+
import math
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from openai import AsyncOpenAI
|
| 7 |
+
from src.config import setup_credentials
|
| 8 |
+
|
| 9 |
+
# Set up credentials and environment
|
| 10 |
+
config = setup_credentials()
|
| 11 |
+
openai = AsyncOpenAI()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class OpenAiJudge:
|
| 15 |
+
"""OpenAI models tokenize all numbers from 0-100 as single tokens, which is why we can get exactly
|
| 16 |
+
one completion token with logprobs. Other models don't necessarily do this, which is why they need
|
| 17 |
+
to be handled differently when used as judge."""
|
| 18 |
+
|
| 19 |
+
def __init__(self, model: str, prompt_template: str, eval_type: str = "0_100"):
|
| 20 |
+
self.model = model
|
| 21 |
+
assert eval_type in [
|
| 22 |
+
"0_100",
|
| 23 |
+
"0_10",
|
| 24 |
+
"binary",
|
| 25 |
+
"binary_text",
|
| 26 |
+
], "eval_type must be either 0_100 or binary"
|
| 27 |
+
self.eval_type = eval_type
|
| 28 |
+
|
| 29 |
+
if self.eval_type == "0_100":
|
| 30 |
+
self.aggregate_score = self._aggregate_0_100_score
|
| 31 |
+
elif self.eval_type == "0_10":
|
| 32 |
+
self.aggregate_score = self._aggregate_0_10_score
|
| 33 |
+
elif self.eval_type == "binary":
|
| 34 |
+
self.aggregate_score = self._aggregate_binary_score
|
| 35 |
+
elif self.eval_type == "binary_text":
|
| 36 |
+
self.aggregate_score = self._aggregate_binary_text_score
|
| 37 |
+
else:
|
| 38 |
+
raise ValueError(f"Invalid eval_type: {self.eval_type}")
|
| 39 |
+
|
| 40 |
+
self.prompt_template = prompt_template
|
| 41 |
+
|
| 42 |
+
async def judge(self, **kwargs):
|
| 43 |
+
messages = [dict(role="user", content=self.prompt_template.format(**kwargs))]
|
| 44 |
+
if self.eval_type == "binary_text":
|
| 45 |
+
response_text = await self.query_full_text(messages)
|
| 46 |
+
score = self.aggregate_score(
|
| 47 |
+
response_text
|
| 48 |
+
) # aggregate_score is _aggregate_binary_text_score
|
| 49 |
+
else:
|
| 50 |
+
logprobs = await self.logprob_probs(messages)
|
| 51 |
+
score = self.aggregate_score(
|
| 52 |
+
logprobs
|
| 53 |
+
) # aggregate_score is one of the other three
|
| 54 |
+
return score
|
| 55 |
+
|
| 56 |
+
async def logprob_probs(self, messages) -> dict:
|
| 57 |
+
"""Simple logprobs request. Returns probabilities. Always samples 1 token."""
|
| 58 |
+
completion = await openai.chat.completions.create(
|
| 59 |
+
model=self.model,
|
| 60 |
+
messages=messages,
|
| 61 |
+
max_tokens=1,
|
| 62 |
+
temperature=0,
|
| 63 |
+
logprobs=True,
|
| 64 |
+
top_logprobs=20,
|
| 65 |
+
seed=0,
|
| 66 |
+
)
|
| 67 |
+
try:
|
| 68 |
+
logprobs = completion.choices[0].logprobs.content[0].top_logprobs
|
| 69 |
+
except IndexError:
|
| 70 |
+
# This should not happen according to the API docs. But it sometimes does.
|
| 71 |
+
return {}
|
| 72 |
+
|
| 73 |
+
result = {}
|
| 74 |
+
for el in logprobs:
|
| 75 |
+
result[el.token] = float(math.exp(el.logprob))
|
| 76 |
+
|
| 77 |
+
return result
|
| 78 |
+
|
| 79 |
+
async def query_full_text(self, messages) -> str:
|
| 80 |
+
"""Requests a full text completion. Used for binary_text eval_type."""
|
| 81 |
+
completion = await openai.chat.completions.create(
|
| 82 |
+
model=self.model, messages=messages, temperature=0, seed=0
|
| 83 |
+
)
|
| 84 |
+
try:
|
| 85 |
+
return completion.choices[0].message.content
|
| 86 |
+
except (IndexError, AttributeError):
|
| 87 |
+
return ""
|
| 88 |
+
|
| 89 |
+
def _aggregate_0_100_score(self, score: dict) -> float:
|
| 90 |
+
# NOTE: we don't check for refusals explcitly. Instead we assume that
|
| 91 |
+
# if there's at least 0.25 total weight on numbers, it's not a refusal.
|
| 92 |
+
total = 0
|
| 93 |
+
sum_ = 0
|
| 94 |
+
for key, val in score.items():
|
| 95 |
+
try:
|
| 96 |
+
int_key = int(key)
|
| 97 |
+
except ValueError:
|
| 98 |
+
continue
|
| 99 |
+
if int_key < 0 or int_key > 100:
|
| 100 |
+
continue
|
| 101 |
+
sum_ += int_key * val
|
| 102 |
+
total += val
|
| 103 |
+
|
| 104 |
+
if total < 0.25:
|
| 105 |
+
# Failed to aggregate logprobs because total weight on numbers is less than 0.25.
|
| 106 |
+
return None
|
| 107 |
+
return sum_ / total
|
| 108 |
+
|
| 109 |
+
def _aggregate_0_10_score(self, score: dict) -> float:
|
| 110 |
+
# NOTE: we check for refusals explcitly.
|
| 111 |
+
# If "REFUSAL" has the highest probability, we return None.
|
| 112 |
+
# Otherwise, we aggregate scores from 0-9.
|
| 113 |
+
if "REFUSAL" in score and score["REFUSAL"] > max(
|
| 114 |
+
score.get(str(i), 0) for i in range(10)
|
| 115 |
+
):
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
total = 0
|
| 119 |
+
sum_ = 0
|
| 120 |
+
for key, val in score.items():
|
| 121 |
+
try:
|
| 122 |
+
int_key = int(key)
|
| 123 |
+
except ValueError:
|
| 124 |
+
continue
|
| 125 |
+
if int_key < 0 or int_key > 9:
|
| 126 |
+
continue
|
| 127 |
+
sum_ += int_key * val
|
| 128 |
+
total += val
|
| 129 |
+
|
| 130 |
+
if total < 0.25: # Similar threshold as 0_100
|
| 131 |
+
# Failed to aggregate logprobs because total weight on numbers is less than 0.25.
|
| 132 |
+
return None
|
| 133 |
+
return sum_ / total
|
| 134 |
+
|
| 135 |
+
def _aggregate_binary_score(self, score: dict) -> float:
|
| 136 |
+
# NOTE: we check for refusals explicitly.
|
| 137 |
+
# If "REFUSAL" has a higher probability than both "YES" and "NO", we return None.
|
| 138 |
+
# Otherwise, calculates a score P(YES) / (P(YES) + P(NO)).
|
| 139 |
+
# Returns None if the combined probability of YES and NO is below a threshold (0.25).
|
| 140 |
+
|
| 141 |
+
yes_prob = score.get("YES", 0.0)
|
| 142 |
+
no_prob = score.get("NO", 0.0)
|
| 143 |
+
refusal_prob = score.get("REFUSAL", 0.0)
|
| 144 |
+
|
| 145 |
+
# If REFUSAL has a higher probability than both YES and NO, consider it a refusal.
|
| 146 |
+
if refusal_prob > yes_prob and refusal_prob > no_prob:
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
denominator = yes_prob + no_prob
|
| 150 |
+
|
| 151 |
+
# If the combined probability of YES and NO is too low (e.g., model outputted something else,
|
| 152 |
+
# or was not confident in YES/NO), return None.
|
| 153 |
+
if (
|
| 154 |
+
denominator < 0.25
|
| 155 |
+
): # Using 0.25 to be consistent with other aggregation methods
|
| 156 |
+
return None
|
| 157 |
+
|
| 158 |
+
return yes_prob / denominator
|
| 159 |
+
|
| 160 |
+
def _aggregate_binary_text_score(self, response_text: str) -> bool:
|
| 161 |
+
if "<answer>REFUSAL</answer>" in response_text:
|
| 162 |
+
return None
|
| 163 |
+
elif "<answer>NO</answer>" in response_text:
|
| 164 |
+
return 0
|
| 165 |
+
elif "<answer>YES</answer>" in response_text:
|
| 166 |
+
return 1
|
| 167 |
+
return None # Invalid response
|
| 168 |
+
|
| 169 |
+
async def __call__(self, **kwargs):
|
| 170 |
+
return await self.judge(**kwargs)
|