# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann + contributors

import logging
import math
import os
import sys
import time
import json
import warnings
from dataclasses import asdict
from importlib.metadata import PackageNotFoundError, version
from os.path import commonprefix
from pathlib import Path
from typing import Any, cast

import huggingface_hub
import numpy as np
import optuna
import questionary
import torch
import torch.nn.functional as F
import transformers
from accelerate.utils import (
    is_mlu_available,
    is_musa_available,
    is_npu_available,
    is_sdaa_available,
    is_xpu_available,
)
from huggingface_hub import ModelCard, ModelCardData
from optuna import Trial, TrialPruned
from optuna.exceptions import ExperimentalWarning
from optuna.samplers import TPESampler
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend, JournalFileOpenLock
from optuna.study import StudyDirection
from optuna.trial import TrialState
from pydantic import ValidationError
from questionary import Choice, Style
from rich.table import Table
from rich.traceback import install

from .analyzer import Analyzer
from .config import DirectionMethod, QuantizationMethod, Settings
from .direction import (
    blend_directions,
    compute_benign_subspace_basis,
    compute_direction_candidates,
    orthogonalize_directions,
    project_directions_out_of_subspace,
)
from .evaluator import Evaluator
from .model import AbliterationParameters, Model, get_model_class
from .utils import (
    empty_cache,
    format_duration,
    get_readme_intro,
    get_trial_parameters,
    load_prompts,
    print,
    print_memory_usage,
    prompt_password,
    prompt_path,
    prompt_select,
    prompt_text,
    set_random_seed,
)


def obtain_merge_strategy(settings: Settings) -> str | None:
    """
    Ask the user how to proceed before the model is saved or uploaded.

    If the model was loaded with ``QuantizationMethod.BNB_4BIT``, merging
    requires reloading the base model and dequantizing it into system RAM.
    In that case this prints a memory warning and tries to estimate the
    required RAM by loading the model structure on the "meta" device. If that
    estimate fails, it prints a rule of thumb instead. It then asks the user
    to either merge or cancel. For any other quantization setting, no prompt
    is shown and "merge" is returned immediately.

    Args:
        settings: Active configuration. Only ``quantization``, ``model`` and
            ``trust_remote_code`` are read.

    Returns:
        "merge" to merge the LoRA adapter into the full model, or None if the
        user selects "Cancel". In the quantized branch, any other value from
        ``prompt_select`` is returned unchanged. Callers elsewhere treat a
        None result from ``prompt_select`` as a possible outcome, so None may
        also come back when the prompt is aborted (NOTE(review): confirm).
        This function never returns "adapter", even though callers also
        handle that value.
    """
    if settings.quantization == QuantizationMethod.BNB_4BIT:
        print()
        print(
            "Model was loaded with quantization. Merging requires reloading the base model."
        )
        print(
            "[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
        )
        print("[yellow]This can lead to system freezes if you run out of memory.[/]")

        try:
            # Estimate memory requirements by loading the model structure on the "meta" device.
            # This doesn't consume actual RAM but allows us to inspect the parameter count/dtype.
            #
            # Suppress warnings during meta device loading (e.g., "Some weights were not initialized").
            # These are expected and harmless since we're only inspecting model structure, not running inference.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # bfloat16 is requested so the footprint reflects the dequantized
                # (unquantized) weights that a merge would have to materialize.
                meta_model = get_model_class(settings.model).from_pretrained(
                    settings.model,
                    device_map="meta",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=settings.trust_remote_code,
                )
                footprint_bytes = meta_model.get_memory_footprint()
                footprint_gb = footprint_bytes / (1024**3)
                print(
                    f"[yellow]Estimated RAM required (excluding overhead): [bold]~{footprint_gb:.2f} GB[/][/]"
                )
        except Exception:
            # Fallback if meta loading fails (e.g. owing to custom model code
            # or bitsandbytes quantization config issues on the meta device).
            # The estimate is purely informational, so any failure is
            # deliberately downgraded to a generic rule of thumb.
            print(
                "[yellow]Rule of thumb: You need approximately 3x the parameter count in GB RAM.[/]"
            )
            print(
                "[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
            )

        print()
        strategy = prompt_select(
            "How do you want to proceed?",
            choices=[
                Choice(
                    # NOTE(review): this branch only runs when quantization is
                    # BNB_4BIT, so the NONE check below can only be true if NONE
                    # compares equal to BNB_4BIT. In practice the suffix is
                    # always appended here.
                    title="Merge LoRA into full model"
                    + (
                        ""
                        if settings.quantization == QuantizationMethod.NONE
                        else " (requires sufficient RAM)"
                    ),
                    value="merge",
                ),
                Choice(
                    title="Cancel",
                    value="cancel",
                ),
            ],
        )

        # Map the explicit "Cancel" choice to None so callers can use a
        # single `is None` check to abort the save/upload.
        if strategy == "cancel":
            return None

        return strategy
    else:
        # Without 4-bit quantization there is no dequantization cost to warn
        # about, so merging proceeds without asking.
        return "merge"


def run():
    # Enable expandable segments to reduce memory fragmentation on multi-GPU setups.
if ( "PYTORCH_ALLOC_CONF" not in os.environ and "PYTORCH_CUDA_ALLOC_CONF" not in os.environ ): os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" # Modified "Pagga" font from https://budavariam.github.io/asciiart-text/ try: app_version = version("iconoclast-llm") except PackageNotFoundError: app_version = "dev" print(f"[cyan]ICONOCLAST[/] v{app_version}") print("[grey50]Discriminative representation editing for open-weight models[/]") print() if ( # There is at least one argument (argv[0] is the program name). len(sys.argv) > 1 # No model has been explicitly provided. and "--model" not in sys.argv # The last argument is a parameter value rather than a flag (such as "--help"). and not sys.argv[-1].startswith("-") ): # Assume the last argument is the model. sys.argv.insert(-1, "--model") try: # The required argument "model" must be provided by the user, # either on the command line or in the configuration file. settings = Settings() # ty:ignore[missing-argument] except ValidationError as error: print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]") for error in error.errors(): print(f"[bold]{error['loc'][0]}[/]: [yellow]{error['msg']}[/]") print() print( "Run [bold]iconoclast --help[/] or see [bold]config.default.toml[/] for details about configuration parameters." 
) return print(f"Using random seed [bold]{settings.seed}[/]") set_random_seed(settings.seed) # Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py if torch.cuda.is_available(): count = torch.cuda.device_count() total_vram = sum(torch.cuda.mem_get_info(i)[1] for i in range(count)) print( f"Detected [bold]{count}[/] CUDA device(s) ({total_vram / (1024**3):.2f} GB total VRAM):" ) for i in range(count): vram = torch.cuda.mem_get_info(i)[1] / (1024**3) print( f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/] ({vram:.2f} GB)" ) elif is_xpu_available(): count = torch.xpu.device_count() print(f"Detected [bold]{count}[/] XPU device(s):") for i in range(count): print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]") elif is_mlu_available(): count = torch.mlu.device_count() # ty:ignore[unresolved-attribute] print(f"Detected [bold]{count}[/] MLU device(s):") for i in range(count): print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute] elif is_sdaa_available(): count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute] print(f"Detected [bold]{count}[/] SDAA device(s):") for i in range(count): print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute] elif is_musa_available(): count = torch.musa.device_count() # ty:ignore[unresolved-attribute] print(f"Detected [bold]{count}[/] MUSA device(s):") for i in range(count): print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute] elif is_npu_available(): print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") # ty:ignore[unresolved-attribute] elif torch.backends.mps.is_available(): print("Detected [bold]1[/] MPS device (Apple Metal)") else: print( "[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]" ) # We don't need gradients as we only do inference. 
torch.set_grad_enabled(False) # While determining the optimal batch size, we will try many different batch sizes, # resulting in many computation graphs being compiled. Raising the limit (default = 8) # avoids errors from TorchDynamo assuming that something is wrong because we # recompile too often. torch._dynamo.config.cache_size_limit = 64 # Silence warning spam from Transformers. # In my entire career I've never seen a useful warning from that library. transformers.logging.set_verbosity_error() # Another library that generates warning spam. logging.getLogger("lm_eval").setLevel(logging.ERROR) # We do our own trial logging, so we don't need the INFO messages # about parameters and results. optuna.logging.set_verbosity(optuna.logging.WARNING) # Silence the warning about multivariate TPE being experimental. warnings.filterwarnings("ignore", category=ExperimentalWarning) os.makedirs(settings.study_checkpoint_dir, exist_ok=True) study_checkpoint_file = os.path.join( settings.study_checkpoint_dir, "".join( [(c if (c.isalnum() or c in ["_", "-"]) else "--") for c in settings.model] ) + ".jsonl", ) lock_obj = JournalFileOpenLock(study_checkpoint_file) backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj) storage = JournalStorage(backend) try: existing_study = storage.get_all_studies()[0] except IndexError: existing_study = None if existing_study is not None and settings.evaluate_model is None: if settings.exit_after_optimization: print() print( "[yellow]Existing study detected.[/] Reusing stored settings in batch mode." ) settings = Settings.model_validate_json(existing_study.user_attrs["settings"]) else: choices = [] if existing_study.user_attrs["finished"]: print() print( ( "[green]You have already processed this model.[/] " "You can show the results from the previous run, allowing you to export models or to run additional trials. " "Alternatively, you can ignore the previous run and start from scratch. 
" "This will delete the checkpoint file and all results from the previous run." ) ) choices.append( Choice( title="Show the results from the previous run", value="continue", ) ) else: print() print( ( "[yellow]You have already processed this model, but the run was interrupted.[/] " "You can continue the previous run from where it stopped. This will override any specified settings. " "Alternatively, you can ignore the previous run and start from scratch. " "This will delete the checkpoint file and all results from the previous run." ) ) choices.append( Choice( title="Continue the previous run", value="continue", ) ) choices.append( Choice( title="Ignore the previous run and start from scratch", value="restart", ) ) choices.append( Choice( title="Exit program", value="", ) ) print() choice = prompt_select("How would you like to proceed?", choices) if choice == "continue": settings = Settings.model_validate_json( existing_study.user_attrs["settings"] ) elif choice == "restart": os.unlink(study_checkpoint_file) backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj) storage = JournalStorage(backend) elif choice is None or choice == "": return model = Model(settings) print() print_memory_usage() print() print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...") good_prompts = load_prompts(settings, settings.good_prompts) print(f"* [bold]{len(good_prompts)}[/] prompts loaded") print() print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...") bad_prompts = load_prompts(settings, settings.bad_prompts) print(f"* [bold]{len(bad_prompts)}[/] prompts loaded") if settings.batch_size == 0: print() print("Determining optimal batch size...") batch_size = 1 best_batch_size = -1 best_performance = -1 while batch_size <= settings.max_batch_size: print(f"* Trying batch size [bold]{batch_size}[/]... 
", end="") prompts = good_prompts * math.ceil(batch_size / len(good_prompts)) prompts = prompts[:batch_size] try: # Warmup run to build the computation graph so that part isn't benchmarked. model.get_responses(prompts) start_time = time.perf_counter() responses = model.get_responses(prompts) end_time = time.perf_counter() except Exception as error: if batch_size == 1: # Even a batch size of 1 already fails. # We cannot recover from this. raise print(f"[red]Failed[/] ({error})") break response_lengths = [ len(model.tokenizer.encode(response)) for response in responses ] performance = sum(response_lengths) / (end_time - start_time) print(f"[green]Ok[/] ([bold]{performance:.0f}[/] tokens/s)") if performance > best_performance: best_batch_size = batch_size best_performance = performance batch_size *= 2 settings.batch_size = best_batch_size print(f"* Chosen batch size: [bold]{settings.batch_size}[/]") print() print("Checking for common response prefix...") prefix_check_prompts = good_prompts[:100] + bad_prompts[:100] responses = model.get_responses_batched(prefix_check_prompts) # Despite being located in os.path, commonprefix actually performs # a naive string operation without any path-specific logic, # which is exactly what we need here. Trailing spaces are removed # to avoid issues where multiple different tokens that all start # with a space character lead to the common prefix ending with # a space, which would result in an uncommon tokenization. model.response_prefix = commonprefix(responses).rstrip(" ") # Suppress CoT output. recheck_prefix = False if model.response_prefix: # When using any of the predefined prefixes below, we need to check that # the prefix is actually complete (e.g. not missing a trailing newline). recheck_prefix = True if model.response_prefix.startswith(""): # Most thinking models. model.response_prefix = "" elif model.response_prefix.startswith("<|channel|>analysis<|message|>"): # gpt-oss. 
model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>" elif model.response_prefix.startswith(""): # Unknown, suggested by user. model.response_prefix = "" elif model.response_prefix.startswith("[THINK]"): # Unknown, suggested by user. model.response_prefix = "[THINK][/THINK]" else: recheck_prefix = False if model.response_prefix: print(f"* Prefix found: [bold]{model.response_prefix!r}[/]") else: print("* None found") if recheck_prefix: print("* Rechecking with prefix...") responses = model.get_responses_batched(prefix_check_prompts) additional_prefix = commonprefix(responses).rstrip(" ") if additional_prefix: model.response_prefix += additional_prefix print(f"* Extended prefix found: [bold]{model.response_prefix!r}[/]") evaluator = Evaluator(settings, model) if settings.evaluate_model is not None: print() print(f"Loading model [bold]{settings.evaluate_model}[/]...") settings.model = settings.evaluate_model model.reset_model() print("* Evaluating...") evaluator.get_score() return print() print("Calculating per-layer refusal directions...") print("* Obtaining residuals for good prompts...") good_residuals = model.get_residuals_batched(good_prompts) print("* Obtaining residuals for bad prompts...") bad_residuals = model.get_residuals_batched(bad_prompts) good_means = good_residuals.mean(dim=0) direction_candidates = compute_direction_candidates( good_residuals, bad_residuals, settings.direction_variance_floor, ) if settings.benign_subspace_rank > 0: benign_subspace_basis = compute_benign_subspace_basis( good_residuals, settings.benign_subspace_rank, ) if benign_subspace_basis is not None: direction_candidates = { method: project_directions_out_of_subspace( candidate, benign_subspace_basis, ) for method, candidate in direction_candidates.items() } if settings.orthogonalize_direction: # Implements https://huggingface.co/blog/grimjim/projected-abliteration # for every candidate direction set rather than only the 
mean-difference one. direction_candidates = { method: orthogonalize_directions(candidate, good_means) for method, candidate in direction_candidates.items() } analyzer = Analyzer(settings, model, good_residuals, bad_residuals) if settings.print_residual_geometry: analyzer.print_residual_geometry() if settings.plot_residuals: analyzer.plot_residuals() # We don't need the residuals after computing refusal directions. del good_residuals, bad_residuals, analyzer empty_cache() components = model.get_abliterable_components() last_layer_index = len(model.get_layers()) - 1 trial_index = 0 start_index = 0 start_time = time.perf_counter() def build_direction_tensor( direction_method: DirectionMethod, direction_blend: float, ) -> torch.Tensor: if direction_method == DirectionMethod.HYBRID: return blend_directions( direction_candidates[DirectionMethod.MEAN], direction_candidates[DirectionMethod.VARIANCE], direction_blend, ) return direction_candidates[direction_method] def get_trial_direction_indices( trial: Trial, ) -> float | None | dict[str, float | None]: component_direction_scopes = trial.user_attrs.get("component_direction_scopes") if isinstance(component_direction_scopes, dict): component_direction_indices = trial.user_attrs.get( "component_direction_indices", {}, ) return { component: ( None if component_direction_scopes.get(component) == "per layer" else component_direction_indices.get(component) ) for component in components } direction_scope = trial.params.get( "direction_scope", trial.user_attrs.get("direction_scope", "global"), ) if direction_scope == "per layer": return None return trial.params.get( "direction_index", trial.user_attrs.get("direction_index"), ) def get_trial_refusal_directions( trial: Trial, ) -> torch.Tensor | dict[str, torch.Tensor]: component_direction_methods = trial.user_attrs.get("component_direction_methods") if isinstance(component_direction_methods, dict): component_direction_blends = trial.user_attrs.get( "component_direction_blends", {}, 
) return { component: build_direction_tensor( DirectionMethod(component_direction_methods[component]), float(component_direction_blends.get(component, 0.0)), ) for component in components } direction_method = DirectionMethod( trial.params.get( "direction_method", trial.user_attrs.get("direction_method", DirectionMethod.MEAN.value), ) ) direction_blend = float( trial.params.get( "direction_blend", trial.user_attrs.get("direction_blend", 0.0), ) ) return build_direction_tensor(direction_method, direction_blend) def recompute_objective_score( behavior_score: float, kl_divergence: float, ) -> tuple[float, float]: if kl_divergence >= settings.kl_divergence_target: kld_score = kl_divergence / settings.kl_divergence_scale else: kld_score = ( behavior_score * settings.kl_divergence_target / settings.kl_divergence_scale ) return (kld_score, behavior_score) def should_run_merged_validation(refusals: int) -> bool: if settings.merged_validation_interval <= 0: return False if trial_index == settings.n_trials: return True if refusals > evaluator.base_refusals: return False return trial_index % settings.merged_validation_interval == 0 def compute_merge_penalty( adapter_result: Any, merged_result: Any, ) -> float: harmful_count = max(settings.merged_validation_subset_size, 1) merge_refusal_gap = max( 0, merged_result.refusals - adapter_result.refusals - settings.merged_validation_refusal_tolerance, ) / harmful_count merge_overrefusal_gap = max( 0, merged_result.overrefusals - adapter_result.overrefusals - settings.merged_validation_overrefusal_tolerance, ) / harmful_count merge_disclaimer_gap = max( 0, merged_result.harmful_marker_hits - adapter_result.harmful_marker_hits - settings.merged_validation_disclaimer_tolerance, ) / harmful_count merge_compliance_gap = max( 0.0, adapter_result.harmful_compliance_score - merged_result.harmful_compliance_score - settings.merged_validation_compliance_tolerance, ) return settings.merged_consistency_penalty * ( merge_refusal_gap + 
merge_overrefusal_gap + merge_disclaimer_gap + merge_compliance_gap ) def objective(trial: Trial) -> tuple[float, float]: nonlocal trial_index trial_index += 1 trial.set_user_attr("index", trial_index) component_direction_scopes = {} component_direction_indices = {} component_direction_methods = {} component_direction_blends = {} if settings.component_specific_directions: refusal_directions = {} direction_index: float | None | dict[str, float | None] = {} for component in components: direction_scope = trial.suggest_categorical( f"{component}.direction_scope", ["global", "per layer"], ) sampled_direction_index = trial.suggest_float( f"{component}.direction_index", 0.4 * last_layer_index, 0.9 * last_layer_index, ) if direction_scope == "per layer": component_direction_indices[component] = None else: component_direction_indices[component] = sampled_direction_index direction_method = DirectionMethod( trial.suggest_categorical( f"{component}.direction_method", [method.value for method in DirectionMethod], ) ) direction_blend = trial.suggest_float( f"{component}.direction_blend", 0.0, 1.0, ) refusal_directions[component] = build_direction_tensor( direction_method, direction_blend, ) cast(dict[str, float | None], direction_index)[component] = ( component_direction_indices[component] ) component_direction_scopes[component] = direction_scope component_direction_methods[component] = direction_method.value component_direction_blends[component] = direction_blend trial.set_user_attr("direction_scope", "mixed") trial.set_user_attr("direction_index", None) trial.set_user_attr("direction_method", "mixed") trial.set_user_attr("direction_blend", 0.0) trial.set_user_attr( "component_direction_scopes", component_direction_scopes, ) trial.set_user_attr( "component_direction_indices", component_direction_indices, ) trial.set_user_attr( "component_direction_methods", component_direction_methods, ) trial.set_user_attr( "component_direction_blends", component_direction_blends, ) else: 
direction_scope = trial.suggest_categorical( "direction_scope", ["global", "per layer"], ) direction_index = trial.suggest_float( "direction_index", 0.4 * last_layer_index, 0.9 * last_layer_index, ) if direction_scope == "per layer": direction_index = None direction_method = DirectionMethod( trial.suggest_categorical( "direction_method", [method.value for method in DirectionMethod], ) ) direction_blend = trial.suggest_float( "direction_blend", 0.0, 1.0, ) refusal_directions = build_direction_tensor(direction_method, direction_blend) trial.set_user_attr("direction_scope", direction_scope) trial.set_user_attr("direction_index", direction_index) trial.set_user_attr("direction_method", direction_method.value) trial.set_user_attr("direction_blend", direction_blend) parameters = {} for component in components: # The parameter ranges are based on experiments with various models # and much wider ranges. They are not set in stone and might have to be # adjusted for future models. max_weight = trial.suggest_float( f"{component}.max_weight", 0.5, 2.0, ) max_weight_position = trial.suggest_float( f"{component}.max_weight_position", 0.4 * last_layer_index, 1.0 * last_layer_index, ) # For sampling purposes, min_weight is expressed as a fraction of max_weight, # again because multivariate TPE doesn't support variable-range parameters. # The value is transformed into the actual min_weight value below. min_weight = trial.suggest_float( f"{component}.min_weight", 0.0, 1.0, ) min_weight_distance = trial.suggest_float( f"{component}.min_weight_distance", 1.0, 0.6 * last_layer_index, ) parameters[component] = AbliterationParameters( max_weight=max_weight, max_weight_position=max_weight_position, min_weight=(min_weight * max_weight), min_weight_distance=min_weight_distance, ) trial.set_user_attr("parameters", {k: asdict(v) for k, v in parameters.items()}) print() print( f"Running trial [bold]{trial_index}[/] of [bold]{settings.n_trials}[/]..." 
) print("* Parameters:") for name, value in get_trial_parameters(trial).items(): print(f" * {name} = [bold]{value}[/]") print("* Resetting model...") model.reset_model() print("* Abliterating...") model.abliterate(refusal_directions, direction_index, parameters) print("* Evaluating...") evaluation_result = evaluator.get_score() score = evaluation_result.score kl_divergence = evaluation_result.kl_divergence refusals = evaluation_result.refusals overrefusals = evaluation_result.overrefusals harmful_marker_hits = evaluation_result.harmful_marker_hits harmful_compliance_score = evaluation_result.harmful_compliance_score merge_penalty = 0.0 if should_run_merged_validation(refusals): print("* Validating merged-model subset...") merged_result = model.evaluate_merged( lambda: evaluator.get_subset_score(settings.merged_validation_subset_size) ) merge_penalty = compute_merge_penalty(evaluation_result, merged_result) trial.set_user_attr("merged_validated", True) trial.set_user_attr("merged_refusals", merged_result.refusals) trial.set_user_attr("merged_overrefusals", merged_result.overrefusals) trial.set_user_attr( "merged_harmful_marker_hits", merged_result.harmful_marker_hits, ) trial.set_user_attr( "merged_harmful_compliance_score", merged_result.harmful_compliance_score, ) trial.set_user_attr("merge_penalty", merge_penalty) score = recompute_objective_score( evaluation_result.behavior_score + merge_penalty, kl_divergence, ) else: trial.set_user_attr("merged_validated", False) trial.set_user_attr("merge_penalty", 0.0) elapsed_time = time.perf_counter() - start_time remaining_time = (elapsed_time / (trial_index - start_index)) * ( settings.n_trials - trial_index ) print() print(f"[grey50]Elapsed time: [bold]{format_duration(elapsed_time)}[/][/]") if trial_index < settings.n_trials: print( f"[grey50]Estimated remaining time: [bold]{format_duration(remaining_time)}[/][/]" ) print_memory_usage() trial.set_user_attr("kl_divergence", kl_divergence) trial.set_user_attr("refusals", 
refusals) trial.set_user_attr("overrefusals", overrefusals) trial.set_user_attr("harmful_marker_hits", harmful_marker_hits) trial.set_user_attr("harmful_compliance_score", harmful_compliance_score) trial.set_user_attr("objective_regime", evaluation_result.objective_regime) trial.set_user_attr("harmful_axis_metrics", evaluation_result.harmful_axis_metrics) return score def objective_wrapper(trial: Trial) -> tuple[float, float]: try: return objective(trial) except KeyboardInterrupt: # Stop the study gracefully on Ctrl+C. trial.study.stop() raise TrialPruned() study = optuna.create_study( sampler=TPESampler( seed=settings.seed, n_startup_trials=settings.n_startup_trials, n_ei_candidates=128, multivariate=True, ), directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE], storage=storage, study_name="iconoclast", load_if_exists=True, ) study.set_user_attr("settings", settings.model_dump_json()) study.set_user_attr("finished", False) def count_completed_trials() -> int: # Count number of complete trials to compute trials to run. return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials]) def get_completed_trials() -> list[Trial]: return [t for t in study.trials if t.state == TrialState.COMPLETE] if settings.warm_start_trials and count_completed_trials() == 0: print() print( f"Queueing [bold]{len(settings.warm_start_trials)}[/] warm-start trial(s)..." 
) for warm_start_trial in settings.warm_start_trials: study.enqueue_trial(dict(warm_start_trial.params)) if warm_start_trial.description: print(f"* {warm_start_trial.description}") def get_pareto_trials(completed_trials: list[Trial]) -> list[Trial]: sorted_trials = sorted( completed_trials, key=lambda trial: ( trial.user_attrs["refusals"], trial.user_attrs.get("overrefusals", 0), trial.user_attrs["kl_divergence"], ), ) min_divergence = math.inf min_overrefusals = math.inf best_trials = [] for trial in sorted_trials: kl_divergence = trial.user_attrs["kl_divergence"] overrefusals = trial.user_attrs.get("overrefusals", 0) if ( overrefusals < min_overrefusals or ( overrefusals == min_overrefusals and kl_divergence < min_divergence ) ): min_overrefusals = overrefusals min_divergence = kl_divergence best_trials.append(trial) return best_trials def serialize_trial(trial: Trial) -> dict[str, Any]: return { "index": trial.user_attrs.get("index"), "refusals": trial.user_attrs["refusals"], "overrefusals": trial.user_attrs.get("overrefusals", 0), "harmful_marker_hits": trial.user_attrs.get("harmful_marker_hits", 0), "harmful_compliance_score": trial.user_attrs.get( "harmful_compliance_score", 0.0 ), "objective_regime": trial.user_attrs.get("objective_regime"), "merge_penalty": trial.user_attrs.get("merge_penalty", 0.0), "kl_divergence": trial.user_attrs["kl_divergence"], "direction_method": trial.user_attrs.get("direction_method"), "direction_scope": trial.user_attrs.get("direction_scope"), "direction_index": trial.user_attrs.get("direction_index"), "direction_blend": trial.user_attrs.get("direction_blend"), "parameters": trial.user_attrs.get("parameters", {}), "harmful_axis_metrics": trial.user_attrs.get("harmful_axis_metrics", {}), } def write_batch_summary(best_trials: list[Trial]) -> Path: summary_path = Path(settings.study_checkpoint_dir, "batch_summary.json") summary_path.parent.mkdir(parents=True, exist_ok=True) summary = { "model": settings.model, 
"study_checkpoint_dir": settings.study_checkpoint_dir, "base_metrics": { "refusals": evaluator.base_refusals, "overrefusals": evaluator.base_overrefusals, "harmful_marker_hits": evaluator.base_harmful_marker_hits, "harmful_compliance_score": evaluator.base_harmful_compliance_score, "objective_regime": evaluator.objective_regime.value, }, "pareto_trials": [serialize_trial(trial) for trial in best_trials], } summary_path.write_text(json.dumps(summary, indent=2)) return summary_path start_index = trial_index = count_completed_trials() if start_index > 0: print() print("Resuming existing study.") try: study.optimize( objective_wrapper, n_trials=settings.n_trials - count_completed_trials(), ) except KeyboardInterrupt: # This additional handler takes care of the small chance that KeyboardInterrupt # is raised just between trials, which wouldn't be caught by the handler # defined in objective_wrapper above. pass if count_completed_trials() == settings.n_trials: study.set_user_attr("finished", True) if settings.exit_after_optimization: completed_trials = get_completed_trials() print() print("[bold green]Optimization finished in batch mode.[/]") if not completed_trials: print("[yellow]No completed trials were recorded.[/]") return best_trials = get_pareto_trials(completed_trials) summary_path = write_batch_summary(best_trials) print("Top Pareto trials:") for trial in best_trials[:5]: print( f"* Trial {trial.user_attrs['index']}: " f"refusals={trial.user_attrs['refusals']}/{len(evaluator.bad_prompts)}, " f"overrefusals={trial.user_attrs.get('overrefusals', 0)}/{len(evaluator.good_prompts)}, " f"markers={trial.user_attrs.get('harmful_marker_hits', 0)}, " f"compliance={trial.user_attrs.get('harmful_compliance_score', 0.0):.3f}, " f"kl={trial.user_attrs['kl_divergence']:.4f}" ) print(f"Batch summary written to [bold]{summary_path}[/].") return while True: # If no trials at all have been evaluated, the study must have been stopped # by pressing Ctrl+C while the first trial was 
running. In this case, we just # re-raise the interrupt to invoke the standard handler defined below. completed_trials = get_completed_trials() if not completed_trials: raise KeyboardInterrupt # Get the Pareto front of trials. We can't use study.best_trials directly # as get_score() doesn't return the pure KL divergence and refusal count. # Note: Unlike study.best_trials, this does not handle objective constraints. best_trials = get_pareto_trials(completed_trials) choices = [ Choice( title=( f"[Trial {trial.user_attrs['index']:>3}] " f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, " f"Overrefusals: {trial.user_attrs.get('overrefusals', 0):>2}/{len(evaluator.good_prompts)}, " f"Markers: {trial.user_attrs.get('harmful_marker_hits', 0):>3}, " f"Compliance: {trial.user_attrs.get('harmful_compliance_score', 0.0):.2f}, " f"MergePen: {trial.user_attrs.get('merge_penalty', 0.0):.2f}, " f"Method: {trial.user_attrs.get('direction_method', 'mean')}, " f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}" ), value=trial, ) for trial in best_trials ] choices.append( Choice( title="Run additional trials", value="continue", ) ) choices.append( Choice( title="Exit program", value="", ) ) print() print("[bold green]Optimization finished![/]") print() print( ( "The following trials resulted in Pareto optimal combinations of refusals and KL divergence. " "Trials are ordered to prefer lower harmful refusals first, then lower overrefusals, then lower KL divergence. " "After selecting a trial, you will be able to save the model, upload it to Hugging Face, " "or chat with it to test how well it works. You can return to this menu later to select a different trial. 
" "[yellow]Note that KL divergence values above 1 usually indicate significant damage to the original model's capabilities.[/]" ) ) while True: print() trial = prompt_select("Which trial do you want to use?", choices) if trial == "continue": while True: try: n_additional_trials = prompt_text( "How many additional trials do you want to run?" ) if n_additional_trials is None or n_additional_trials == "": n_additional_trials = 0 break n_additional_trials = int(n_additional_trials) if n_additional_trials > 0: break print("[red]Please enter a number greater than 0.[/]") except ValueError: print("[red]Please enter a number.[/]") if n_additional_trials == 0: continue settings.n_trials += n_additional_trials study.set_user_attr("settings", settings.model_dump_json()) study.set_user_attr("finished", False) try: study.optimize( objective_wrapper, n_trials=settings.n_trials - count_completed_trials(), ) except KeyboardInterrupt: pass if count_completed_trials() == settings.n_trials: study.set_user_attr("finished", True) break elif trial is None or trial == "": return print() print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...") print("* Parameters:") for name, value in get_trial_parameters(trial).items(): print(f" * {name} = [bold]{value}[/]") print("* Resetting model...") model.reset_model() print("* Abliterating...") model.abliterate( get_trial_refusal_directions(trial), get_trial_direction_indices(trial), { k: AbliterationParameters(**v) for k, v in trial.user_attrs["parameters"].items() }, ) while True: print() action = prompt_select( "What do you want to do with the decensored model?", [ "Save the model to a local folder", "Upload the model to Hugging Face", "Chat with the model", "Benchmark the model", "Return to the trial selection menu", ], ) if action is None or action == "Return to the trial selection menu": break # All actions are wrapped in a try/except block so that if an error occurs, # another action can be tried, instead of the program 
crashing and losing # the optimized model. try: match action: case "Save the model to a local folder": save_directory = prompt_path("Path to the folder:") if not save_directory: continue strategy = obtain_merge_strategy(settings) if strategy is None: continue if strategy == "adapter": print("Saving LoRA adapter...") model.model.save_pretrained(save_directory) else: print("Saving merged model...") merged_model = model.get_merged_model() merged_model.save_pretrained(save_directory) del merged_model empty_cache() model.tokenizer.save_pretrained(save_directory) print(f"Model saved to [bold]{save_directory}[/].") case "Upload the model to Hugging Face": # We don't use huggingface_hub.login() because that stores the token on disk, # and since this program will often be run on rented or shared GPU servers, # it's better to not persist credentials. token = huggingface_hub.get_token() if not token: token = prompt_password("Hugging Face access token:") if not token: continue user = huggingface_hub.whoami(token) fullname = user.get( "fullname", user.get("name", "unknown user"), ) email = user.get("email", "no email found") print(f"Logged in as [bold]{fullname} ({email})[/]") repo_id = prompt_text( "Name of repository:", default=f"{user['name']}/{Path(settings.model).name}-iconoclast", ) visibility = prompt_select( "Should the repository be public or private?", [ "Public", "Private", ], ) if visibility is None: continue private = visibility == "Private" strategy = obtain_merge_strategy(settings) if strategy is None: continue if strategy == "adapter": print("Uploading LoRA adapter...") model.model.push_to_hub( repo_id, private=private, token=token, ) else: print("Uploading merged model...") merged_model = model.get_merged_model() merged_model.push_to_hub( repo_id, private=private, token=token, ) del merged_model empty_cache() model.tokenizer.push_to_hub( repo_id, private=private, token=token, ) # If the model path exists locally and includes the # card, use it directly. 
If the model path doesn't # exist locally, it can be assumed to be a model # hosted on the Hugging Face Hub, in which case # we can retrieve the model card. model_path = Path(settings.model) if model_path.exists(): card_path = ( model_path / huggingface_hub.constants.REPOCARD_NAME ) if card_path.exists(): card = ModelCard.load(card_path) else: card = None else: card = ModelCard.load(settings.model) if card is not None: if card.data is None: card.data = ModelCardData() if card.data.tags is None: card.data.tags = [] card.data.tags.append("iconoclast") card.data.tags.append("uncensored") card.data.tags.append("decensored") card.data.tags.append("abliterated") card.text = ( get_readme_intro( settings, trial, evaluator.base_refusals, evaluator.base_overrefusals, evaluator.good_prompts, evaluator.bad_prompts, ) + card.text ) card.push_to_hub(repo_id, token=token) print(f"Model uploaded to [bold]{repo_id}[/].") case "Chat with the model": print() print( "[cyan]Press Ctrl+C at any time to return to the menu.[/]" ) chat = [ {"role": "system", "content": settings.system_prompt}, ] while True: try: message = prompt_text( "User:", qmark=">", unsafe=True, ) if not message: break chat.append({"role": "user", "content": message}) print("[bold]Assistant:[/] ", end="") response = model.stream_chat_response(chat) chat.append( {"role": "assistant", "content": response} ) except (KeyboardInterrupt, EOFError): # Ctrl+C/Ctrl+D break case "Benchmark the model": import lm_eval from lm_eval.models.huggingface import HFLM benchmarks = questionary.checkbox( "Which benchmarks do you want to run?", [ Choice( title=f"{benchmark.name}: {benchmark.description}", value=benchmark, ) for benchmark in settings.benchmarks ], style=Style([("highlighted", "reverse")]), ).ask() if not benchmarks: continue scope = prompt_select( ( "Do you want to benchmark the original model along with the decensored model? " "Benchmarking both models allows you to compare the scores, but it takes twice as much time." 
), [ "Benchmark only the decensored model", "Benchmark both models", ], ) if scope is None: continue benchmark_original_model = scope == "Benchmark both models" hflm = HFLM( pretrained=model.model, # ty:ignore[invalid-argument-type] tokenizer=model.tokenizer, # ty:ignore[invalid-argument-type] ) table = Table() table.add_column("Benchmark") table.add_column("Metric") if benchmark_original_model: table.add_column("This model", justify="right") table.add_column("Original model", justify="right") else: table.add_column("Value", justify="right") try: first_benchmark = True for benchmark in benchmarks: print( f"Running benchmark [bold]{benchmark.name}[/]..." ) def get_results() -> dict[str, Any]: results = lm_eval.simple_evaluate( model=hflm, tasks=[benchmark.task], batch_size="auto", ) return results["results"][benchmark.task] results = get_results() if benchmark_original_model: with model.model.disable_adapter(): # ty:ignore[call-non-callable] original_results = get_results() first_row = True for metric, value in results.items(): if metric != "alias": if first_row and not first_benchmark: if benchmark_original_model: table.add_row("", "", "", "") else: table.add_row("", "", "") def format_value(value: Any) -> str: if isinstance( value, (float, np.floating), ): return f"{value:.4f}" else: return f"{value}" cells = [ benchmark.name if first_row else "", metric, format_value(value), ] if benchmark_original_model: cells.append( format_value( original_results[metric] ) ) table.add_row(*cells) first_row = False first_benchmark = False except KeyboardInterrupt: pass # The benchmark run might have been cancelled by the user # before any benchmark was completed, so we only print results # if there actually are some. if table.rows: print(table) except Exception as error: print(f"[red]Error: {error}[/]") def main(): # Install Rich traceback handler. 
install() try: run() except BaseException as error: # Transformers appears to handle KeyboardInterrupt (or BaseException) # internally in some places, which can re-raise a different error in the handler, # masking the root cause. We therefore check both the error itself and its context. if isinstance(error, KeyboardInterrupt) or isinstance( error.__context__, KeyboardInterrupt ): print() print("[red]Shutting down...[/]") else: raise