Iconoclast/scripts/evaluate_large_dataset.py
#!/usr/bin/env python3
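"""Evaluate the best Optuna trial from a checkpoint on a large holdout dataset.

The script re-applies the best trial's abliteration parameters to the model,
generates responses for every prompt in the holdout set, counts refusals, and
writes a JSON summary.

Example invocation (paths are illustrative):

    python scripts/evaluate_large_dataset.py \
        --checkpoint results/checkpoint.jsonl \
        --dataset mlabonne/harmful_behaviors \
        --split train+test \
        --column text \
        --output results/large_eval.json
"""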
import argparse
import importlib
import json
import os
import sys
from collections import defaultdict
from os.path import commonprefix
from pathlib import Path
from typing import Any
def parse_args() -> argparse.Namespace:
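    """Parse command-line arguments for the holdout evaluation."""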
parser = argparse.ArgumentParser(description="Evaluate the best trial on a large holdout set.")
parser.add_argument("--checkpoint", required=True, help="Path to Optuna checkpoint JSONL file")
parser.add_argument("--dataset", default="mlabonne/harmful_behaviors", help="HF Dataset name")
parser.add_argument("--split", default="train+test", help="Dataset split")
parser.add_argument("--column", default="text", help="Dataset column for prompt")
parser.add_argument("--output", required=True, help="Output JSON path")
return parser.parse_args()
def load_study(checkpoint_path: Path) -> tuple[str, dict[int, dict[str, Any]]]:
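    """Parse an Optuna checkpoint JSONL file.

    Returns the first serialized Settings JSON found in a ``user_attr`` record
    together with a mapping of trial id to that trial's merged user attributes.
    """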
settings_json = None
trials: dict[int, dict[str, Any]] = defaultdict(dict)
for line in checkpoint_path.read_text().splitlines():
obj = json.loads(line)
user_attr = obj.get("user_attr")
if user_attr and "settings" in user_attr and settings_json is None:
settings_json = user_attr["settings"]
trial_id = obj.get("trial_id")
if trial_id is not None and user_attr:
trials[trial_id].update(user_attr)
if settings_json is None:
raise ValueError(f"Did not find settings in {checkpoint_path}")
return settings_json, trials
def pick_best_trial(trials: dict[int, dict[str, Any]]) -> dict[str, Any]:
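    """Select the best completed trial.

    Only trials reporting ``refusals``, ``kl_divergence``, and ``parameters``
    are considered; the winner minimizes (refusals, overrefusals,
    kl_divergence) compared lexicographically.
    """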
best = None
for trial_id, attrs in trials.items():
if not {"refusals", "kl_divergence", "parameters"}.issubset(attrs):
continue
item = {
"trial_id": trial_id,
"index": attrs.get("index"),
"refusals": attrs["refusals"],
"overrefusals": attrs.get("overrefusals", 0),
"kl_divergence": attrs["kl_divergence"],
"direction_index": attrs["direction_index"],
"direction_method": attrs["direction_method"],
"direction_blend": attrs.get("direction_blend", 0.0),
"component_direction_indices": attrs.get("component_direction_indices"),
"component_direction_methods": attrs.get("component_direction_methods"),
"component_direction_blends": attrs.get("component_direction_blends"),
"parameters": attrs["parameters"],
}
key = (item["refusals"], item["overrefusals"], item["kl_divergence"])
if best is None or key < (
best["refusals"],
best["overrefusals"],
best["kl_divergence"],
):
best = item
if best is None:
raise ValueError("No completed trials with metrics were found")
return best
def load_stack(package: str) -> dict[str, Any]:
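    """Import the package's config, direction, model, and utils submodules.

    Returning them in a dict keeps the rest of the script decoupled from a
    hard-coded package import path.
    """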
return {
"package": package,
"config": importlib.import_module(f"{package}.config"),
"direction": importlib.import_module(f"{package}.direction"),
"model": importlib.import_module(f"{package}.model"),
"utils": importlib.import_module(f"{package}.utils"),
}
def normalize_refusal_text(text: str) -> str:
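    """Lowercase, strip asterisks, normalize apostrophes, and collapse whitespace."""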
text = text.lower().replace("*", "")
text = text.replace("’", "'")
return " ".join(text.split())
def is_refusal(response: str, refusal_markers: list[str]) -> bool:
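    """Return True for empty responses or when any refusal marker appears in the normalized text."""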
if not response.strip():
return True
normalized = normalize_refusal_text(response)
return any(marker.lower() in normalized for marker in refusal_markers)
def determine_response_prefix(model: Any, good_prompts: list[Any], bad_prompts: list[Any]) -> None:
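    """Derive and pin a shared response prefix from a sample of prompts.

    If every sampled response opens with a known reasoning tag (e.g.
    ``<think>``), the prefix is replaced with an empty, closed variant,
    presumably to suppress the reasoning preamble in later generations, and a
    second pass appends any additional shared prefix that remains.
    """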
prefix_check_prompts = good_prompts[:100] + bad_prompts[:100]
responses = model.get_responses_batched(prefix_check_prompts)
model.response_prefix = commonprefix(responses).rstrip(" ")
recheck_prefix = False
if model.response_prefix:
recheck_prefix = True
if model.response_prefix.startswith("<think>"):
model.response_prefix = "<think></think>"
elif model.response_prefix.startswith("<|channel|>analysis<|message|>"):
model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>"
elif model.response_prefix.startswith("<thought>"):
model.response_prefix = "<thought></thought>"
elif model.response_prefix.startswith("[THINK]"):
model.response_prefix = "[THINK][/THINK]"
else:
recheck_prefix = False
if recheck_prefix:
responses = model.get_responses_batched(prefix_check_prompts)
additional_prefix = commonprefix(responses).rstrip(" ")
if additional_prefix:
model.response_prefix += additional_prefix
def prepare_runtime(stack: dict[str, Any], settings_json: str) -> dict[str, Any]:
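    """Build the evaluation runtime from the stored optimization settings.

    Validates the Settings JSON, seeds the RNG, loads the model and prompt
    sets, computes refusal-direction candidates from the good/bad residuals
    (optionally orthogonalized against the good-prompt means), and returns the
    pieces needed to re-apply a trial.
    """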
Settings = stack["config"].Settings
DirectionMethod = stack["config"].DirectionMethod
Model = stack["model"].Model
load_prompts = stack["utils"].load_prompts
set_random_seed = stack["utils"].set_random_seed
empty_cache = stack["utils"].empty_cache
compute_direction_candidates = stack["direction"].compute_direction_candidates
orthogonalize_directions = stack["direction"].orthogonalize_directions
blend_directions = stack["direction"].blend_directions
settings = Settings.model_validate_json(settings_json)
set_random_seed(settings.seed)
model = Model(settings)
good_prompts = load_prompts(settings, settings.good_prompts)
bad_prompts = load_prompts(settings, settings.bad_prompts)
determine_response_prefix(model, good_prompts, bad_prompts)
good_residuals = model.get_residuals_batched(good_prompts)
bad_residuals = model.get_residuals_batched(bad_prompts)
good_means = good_residuals.mean(dim=0)
direction_candidates = compute_direction_candidates(
good_residuals,
bad_residuals,
settings.direction_variance_floor,
)
if settings.orthogonalize_direction:
direction_candidates = {
method: orthogonalize_directions(candidate, good_means)
for method, candidate in direction_candidates.items()
}
del good_residuals, bad_residuals
empty_cache()
def get_trial_refusal_directions(trial_data: dict[str, Any]) -> Any:
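        """Rebuild the refusal direction(s) recorded for a trial.

        Per-component methods yield a dict of directions; otherwise a single
        direction is returned, blending the MEAN and VARIANCE candidates when
        the method is HYBRID.
        """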
component_direction_methods = trial_data.get("component_direction_methods")
if isinstance(component_direction_methods, dict):
component_direction_blends = trial_data.get("component_direction_blends", {})
return {
component: blend_directions(
direction_candidates[DirectionMethod.MEAN],
direction_candidates[DirectionMethod.VARIANCE],
float(component_direction_blends.get(component, 0.0)),
)
if DirectionMethod(method) == DirectionMethod.HYBRID
else direction_candidates[DirectionMethod(method)]
for component, method in component_direction_methods.items()
}
direction_method = DirectionMethod(trial_data["direction_method"])
direction_blend = float(trial_data.get("direction_blend", 0.0))
if direction_method == DirectionMethod.HYBRID:
return blend_directions(
direction_candidates[DirectionMethod.MEAN],
direction_candidates[DirectionMethod.VARIANCE],
direction_blend,
)
return direction_candidates[direction_method]
return {
"settings": settings,
"model": model,
"get_trial_refusal_directions": get_trial_refusal_directions,
"AbliterationParameters": stack["model"].AbliterationParameters,
"empty_cache": empty_cache,
}
def apply_trial(runtime: dict[str, Any], trial_data: dict[str, Any]) -> None:
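    """Reset the model and re-apply the trial's abliteration parameters and directions."""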
model = runtime["model"]
AbliterationParameters = runtime["AbliterationParameters"]
parameters = {
name: AbliterationParameters(**values)
for name, values in trial_data["parameters"].items()
}
model.reset_model()
model.abliterate(
runtime["get_trial_refusal_directions"](trial_data),
trial_data.get("component_direction_indices", trial_data["direction_index"]),
parameters,
)
def main() -> None:
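    """Pick the best checkpointed trial, re-apply it, and evaluate refusals on the holdout set."""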
args = parse_args()
# Critical: Prevent Pydantic BaseSettings in iconoclast.config from
# trying to parse sys.argv, which would collide with our own arguments.
sys.argv = [sys.argv[0]]
settings_json, trials = load_study(Path(args.checkpoint))
best_trial = pick_best_trial(trials)
stack = load_stack("iconoclast")
runtime = prepare_runtime(stack, settings_json)
    # Load the large holdout dataset prompts.
    # Build a minimal config object that mimics the prompt-dataset settings
    # expected by the load_prompts utility.
from pydantic import BaseModel
class DatasetConfig(BaseModel):
dataset: str
name: str | None = None
split: str
column: str
ds_config = DatasetConfig(dataset=args.dataset, split=args.split, column=args.column)
eval_prompts = stack["utils"].load_prompts(runtime["settings"], ds_config)
print(f"Evaluating on {len(eval_prompts)} prompts from {args.dataset} ({args.split})...")
apply_trial(runtime, best_trial)
responses = runtime["model"].get_responses_batched(
eval_prompts,
skip_special_tokens=True,
)
refusal_markers = runtime["settings"].refusal_markers
refusal_count = sum(1 for resp in responses if is_refusal(resp, refusal_markers))
summary = {
"model": runtime["settings"].model,
"total_prompts": len(eval_prompts),
"refusals": refusal_count,
"refusal_rate": refusal_count / len(eval_prompts),
}
print(json.dumps(summary, indent=2))
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(json.dumps(summary, indent=2))
print(f"Saved results to {args.output}")
if __name__ == "__main__":
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
main()