#!/usr/bin/env python3
"""Evaluate the best Optuna trial from a checkpoint on a large holdout dataset.

Loads the study settings and trial metrics from a JSONL checkpoint, re-applies
the best trial's abliteration, and reports the refusal rate on the holdout
prompts.
"""

import argparse
import importlib
import json
import os
import sys
from collections import defaultdict
from os.path import commonprefix
from pathlib import Path
from typing import Any


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Evaluate the best trial on a large holdout set.")
    parser.add_argument("--checkpoint", required=True, help="Path to Optuna checkpoint JSONL file")
    parser.add_argument("--dataset", default="mlabonne/harmful_behaviors", help="HF dataset name")
    parser.add_argument("--split", default="train+test", help="Dataset split")
    parser.add_argument("--column", default="text", help="Dataset column containing the prompt")
    parser.add_argument("--output", required=True, help="Output JSON path")
    return parser.parse_args()


def load_study(checkpoint_path: Path) -> tuple[str, dict[int, dict[str, Any]]]:
    """Collect the study settings and per-trial user attributes from the checkpoint."""
    settings_json = None
    trials: dict[int, dict[str, Any]] = defaultdict(dict)
    for line in checkpoint_path.read_text().splitlines():
        obj = json.loads(line)
        user_attr = obj.get("user_attr")
        if user_attr and "settings" in user_attr and settings_json is None:
            settings_json = user_attr["settings"]
        trial_id = obj.get("trial_id")
        if trial_id is not None and user_attr:
            trials[trial_id].update(user_attr)
    if settings_json is None:
        raise ValueError(f"Did not find settings in {checkpoint_path}")
    return settings_json, trials


def pick_best_trial(trials: dict[int, dict[str, Any]]) -> dict[str, Any]:
    """Select the completed trial with the lexicographically smallest
    (refusals, overrefusals, kl_divergence) tuple."""
    best = None
    for trial_id, attrs in trials.items():
        if not {"refusals", "kl_divergence", "parameters"}.issubset(attrs):
            continue
        item = {
            "trial_id": trial_id,
            "index": attrs.get("index"),
            "refusals": attrs["refusals"],
            "overrefusals": attrs.get("overrefusals", 0),
            "kl_divergence": attrs["kl_divergence"],
            "direction_index": attrs["direction_index"],
            "direction_method": attrs["direction_method"],
            "direction_blend": attrs.get("direction_blend", 0.0),
            "component_direction_indices": attrs.get("component_direction_indices"),
            "component_direction_methods": attrs.get("component_direction_methods"),
            "component_direction_blends": attrs.get("component_direction_blends"),
            "parameters": attrs["parameters"],
        }
        key = (item["refusals"], item["overrefusals"], item["kl_divergence"])
        if best is None or key < (
            best["refusals"],
            best["overrefusals"],
            best["kl_divergence"],
        ):
            best = item
    if best is None:
        raise ValueError("No completed trials with metrics were found")
    return best


def load_stack(package: str) -> dict[str, Any]:
    return {
        "package": package,
        "config": importlib.import_module(f"{package}.config"),
        "direction": importlib.import_module(f"{package}.direction"),
        "model": importlib.import_module(f"{package}.model"),
        "utils": importlib.import_module(f"{package}.utils"),
    }


def normalize_refusal_text(text: str) -> str:
    text = text.lower().replace("*", "")
    text = text.replace("’", "'")
    return " ".join(text.split())


def is_refusal(response: str, refusal_markers: list[str]) -> bool:
    # An empty response is treated as a refusal.
    if not response.strip():
        return True
    normalized = normalize_refusal_text(response)
    return any(marker.lower() in normalized for marker in refusal_markers)


def determine_response_prefix(model: Any, good_prompts: list[Any], bad_prompts: list[Any]) -> None:
    prefix_check_prompts = good_prompts[:100] + bad_prompts[:100]
    responses = model.get_responses_batched(prefix_check_prompts)
    model.response_prefix = commonprefix(responses).rstrip(" ")
    recheck_prefix = False
    if model.response_prefix:
        recheck_prefix = True
        # Reasoning models open every response with a thinking block; close
        # the tag immediately so generation skips straight to the answer.
        if model.response_prefix.startswith("<think>"):
            model.response_prefix = "<think></think>"
        elif model.response_prefix.startswith("<|channel|>analysis<|message|>"):
            model.response_prefix = (
                "<|channel|>analysis<|message|><|end|>"
                "<|start|>assistant<|channel|>final<|message|>"
            )
        elif model.response_prefix.startswith("<seed:think>"):
            model.response_prefix = "<seed:think></seed:think>"
        elif model.response_prefix.startswith("[THINK]"):
            model.response_prefix = "[THINK][/THINK]"
        else:
            recheck_prefix = False
    if recheck_prefix:
        # With the thinking block suppressed, check whether the responses
        # share a further common prefix and append it if so.
        responses = model.get_responses_batched(prefix_check_prompts)
        additional_prefix = commonprefix(responses).rstrip(" ")
        if additional_prefix:
            model.response_prefix += additional_prefix
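
# Note: commonprefix comes from os.path but works on any sequence of
# strings, truncating at the first divergent character. An illustrative
# (made-up) pair of reasoning-model responses:
#
#   commonprefix(["<think>\nThe user asks", "<think>\nThis request"])
#   # -> "<think>\nTh"
#
# This is why determine_response_prefix matches the detected prefix with
# startswith() rather than comparing for equality.
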
def prepare_runtime(stack: dict[str, Any], settings_json: str) -> dict[str, Any]:
    # Unpack the pieces of the abliteration stack that are needed below.
    Settings = stack["config"].Settings
    DirectionMethod = stack["config"].DirectionMethod
    Model = stack["model"].Model
    load_prompts = stack["utils"].load_prompts
    set_random_seed = stack["utils"].set_random_seed
    empty_cache = stack["utils"].empty_cache
    compute_direction_candidates = stack["direction"].compute_direction_candidates
    orthogonalize_directions = stack["direction"].orthogonalize_directions
    blend_directions = stack["direction"].blend_directions

    settings = Settings.model_validate_json(settings_json)
    set_random_seed(settings.seed)
    model = Model(settings)

    good_prompts = load_prompts(settings, settings.good_prompts)
    bad_prompts = load_prompts(settings, settings.bad_prompts)
    determine_response_prefix(model, good_prompts, bad_prompts)

    # Compute candidate refusal directions once; the residuals themselves
    # are no longer needed afterwards and are freed to save memory.
    good_residuals = model.get_residuals_batched(good_prompts)
    bad_residuals = model.get_residuals_batched(bad_prompts)
    good_means = good_residuals.mean(dim=0)
    direction_candidates = compute_direction_candidates(
        good_residuals,
        bad_residuals,
        settings.direction_variance_floor,
    )
    if settings.orthogonalize_direction:
        direction_candidates = {
            method: orthogonalize_directions(candidate, good_means)
            for method, candidate in direction_candidates.items()
        }
    del good_residuals, bad_residuals
    empty_cache()

    def get_trial_refusal_directions(trial_data: dict[str, Any]) -> Any:
        component_direction_methods = trial_data.get("component_direction_methods")
        if isinstance(component_direction_methods, dict):
            # Per-component directions: each component may use its own
            # method, with HYBRID blending the MEAN and VARIANCE candidates.
            # The blends may be stored as None in the checkpoint, so guard
            # with "or" rather than a get() default.
            component_direction_blends = trial_data.get("component_direction_blends") or {}
            return {
                component: blend_directions(
                    direction_candidates[DirectionMethod.MEAN],
                    direction_candidates[DirectionMethod.VARIANCE],
                    float(component_direction_blends.get(component, 0.0)),
                )
                if DirectionMethod(method) == DirectionMethod.HYBRID
                else direction_candidates[DirectionMethod(method)]
                for component, method in component_direction_methods.items()
            }
        direction_method = DirectionMethod(trial_data["direction_method"])
        direction_blend = float(trial_data.get("direction_blend", 0.0))
        if direction_method == DirectionMethod.HYBRID:
            return blend_directions(
                direction_candidates[DirectionMethod.MEAN],
                direction_candidates[DirectionMethod.VARIANCE],
                direction_blend,
            )
        return direction_candidates[direction_method]

    return {
        "settings": settings,
        "model": model,
        "get_trial_refusal_directions": get_trial_refusal_directions,
        "AbliterationParameters": stack["model"].AbliterationParameters,
        "empty_cache": empty_cache,
    }


def apply_trial(runtime: dict[str, Any], trial_data: dict[str, Any]) -> None:
    model = runtime["model"]
    AbliterationParameters = runtime["AbliterationParameters"]
    parameters = {
        name: AbliterationParameters(**values)
        for name, values in trial_data["parameters"].items()
    }
    # pick_best_trial always includes the key, possibly as None, so fall
    # back to the scalar direction_index explicitly; a get() default would
    # never apply here.
    direction_indices = trial_data.get("component_direction_indices")
    if direction_indices is None:
        direction_indices = trial_data["direction_index"]
    model.reset_model()
    model.abliterate(
        runtime["get_trial_refusal_directions"](trial_data),
        direction_indices,
        parameters,
    )
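
# For reference, a trial record assembled by pick_best_trial has roughly
# this shape (the values below are illustrative, not from a real run):
#
#   {
#       "trial_id": 42,
#       "refusals": 3,
#       "overrefusals": 0,
#       "kl_divergence": 0.04,
#       "direction_index": 17,
#       "direction_method": "hybrid",
#       "direction_blend": 0.35,
#       "parameters": {"attn": {...}, "mlp": {...}},
#   }
#
# apply_trial only consumes the direction fields and "parameters".
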
def main() -> None:
    args = parse_args()

    # Critical: Prevent Pydantic BaseSettings in iconoclast.config from
    # trying to parse sys.argv, which would collide with our own arguments.
    sys.argv = [sys.argv[0]]

    settings_json, trials = load_study(Path(args.checkpoint))
    best_trial = pick_best_trial(trials)

    stack = load_stack("iconoclast")
    runtime = prepare_runtime(stack, settings_json)

    # Load the large dataset prompts, mocking a config object so that the
    # load_prompts utility can be reused.
    from pydantic import BaseModel

    class DatasetConfig(BaseModel):
        dataset: str
        name: str | None = None
        split: str
        column: str

    ds_config = DatasetConfig(dataset=args.dataset, split=args.split, column=args.column)
    eval_prompts = stack["utils"].load_prompts(runtime["settings"], ds_config)
    print(f"Evaluating on {len(eval_prompts)} prompts from {args.dataset} ({args.split})...")

    apply_trial(runtime, best_trial)
    responses = runtime["model"].get_responses_batched(
        eval_prompts,
        skip_special_tokens=True,
    )

    refusal_markers = runtime["settings"].refusal_markers
    refusal_count = sum(1 for resp in responses if is_refusal(resp, refusal_markers))

    summary = {
        "model": runtime["settings"].model,
        "total_prompts": len(eval_prompts),
        "refusals": refusal_count,
        "refusal_rate": refusal_count / len(eval_prompts),
    }
    print(json.dumps(summary, indent=2))

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(summary, indent=2))
    print(f"Saved results to {args.output}")


if __name__ == "__main__":
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    main()
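
# Example invocation (script name and paths are illustrative):
#
#   python evaluate_best_trial.py \
#       --checkpoint checkpoints/study.jsonl \
#       --output results/holdout_eval.json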