# SPDX-License-Identifier: AGPL-3.0-or-later # Copyright (C) 2025-2026 Philipp Emanuel Weidmann + contributors import math from contextlib import suppress from dataclasses import dataclass from typing import Any, Callable, Type, TypeVar, cast import torch import torch.linalg as LA import torch.nn.functional as F from peft import LoraConfig, PeftModel, get_peft_model from peft.tuners.lora.layer import Linear from torch import FloatTensor, LongTensor, Tensor from torch.nn import Module, ModuleList from transformers import ( AutoModelForCausalLM, AutoModelForImageTextToText, AutoTokenizer, BatchEncoding, BitsAndBytesConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase, TextStreamer, ) from transformers.generation import ( GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import] ) from .config import QuantizationMethod, RowNormalization, Settings from .utils import Prompt, batchify, empty_cache, print try: import bitsandbytes as bnb except ImportError: bnb = None # Monkey-patch torch.nn.Module to add set_submodule if it is missing. # This is required for compatibility with some newer transformers/peft features # and specific model architectures (e.g., Qwen 3.5, Mistral) in newer environments. 
if not hasattr(torch.nn.Module, "set_submodule"):

    def set_submodule(
        self: torch.nn.Module,
        target: str,
        module: torch.nn.Module,
        strict: bool = False,
    ) -> None:
        """
        Fallback for ``torch.nn.Module.set_submodule`` on torch versions lacking it.

        Args:
            target: Dotted path of the submodule to set, relative to ``self``.
            module: The module to install at ``target``.
            strict: Accepted for signature compatibility with newer torch. When
                True, only an already-existing attribute may be replaced.

        Raises:
            AttributeError: If ``strict`` is True and the final attribute does not
                exist. Errors from ``get_submodule`` (missing intermediate path
                components) propagate unchanged.
        """
        parent_path, _, name = target.rpartition(".")
        # No dot in the path means the attribute lives directly on self.
        parent = self.get_submodule(parent_path) if parent_path else self
        if strict and not hasattr(parent, name):
            raise AttributeError(f"{type(parent).__name__} has no attribute `{name}`")
        setattr(parent, name, module)

    torch.nn.Module.set_submodule = set_submodule  # type: ignore


TResult = TypeVar("TResult")


def get_model_class(
    model: str,
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
    """
    Chooses the Auto* class used to load ``model``.

    Returns:
        AutoModelForImageTextToText if any dict returned by
        ``PretrainedConfig.get_config_dict`` contains a "vision_config" key,
        otherwise AutoModelForCausalLM.
    """
    configs = PretrainedConfig.get_config_dict(model)
    if any("vision_config" in config for config in configs):
        return AutoModelForImageTextToText
    return AutoModelForCausalLM


@dataclass
class AbliterationParameters:
    # Ablation weight applied at max_weight_position.
    max_weight: float
    # Layer position (may be fractional) where the weight peaks.
    max_weight_position: float
    # Weight reached at min_weight_distance from max_weight_position.
    min_weight: float
    # Layers farther than this from max_weight_position are left untouched.
    min_weight_distance: float


class Model:
    """
    Wraps a Hugging Face model and tokenizer, with LoRA adapters used to apply
    abliteration without modifying the base weights.
    """

    model: PreTrainedModel | PeftModel
    tokenizer: PreTrainedTokenizerBase
    peft_config: LoraConfig

    def __init__(self, settings: Settings):
        """
        Loads the tokenizer and model, then attaches LoRA adapters.

        Each dtype in ``settings.dtypes`` is tried in order. The first one that
        loads and survives a one-token test generation is kept.

        Raises:
            Exception: If no configured dtype could be loaded successfully.
        """
        self.settings = settings
        self.response_prefix = ""
        self.needs_reload = False

        print()
        print(f"Loading model [bold]{settings.model}[/]...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            settings.model,
            trust_remote_code=settings.trust_remote_code,
        )

        # Fallback for tokenizers that don't declare a special pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # CRITICAL: Always use left-padding for decoder-only models during generation.
        # Right-padding causes empty outputs because the model sees PAD tokens
        # after the prompt and thinks the sequence is complete.
        self.tokenizer.padding_side = "left"

        self.model = None  # ty:ignore[invalid-assignment]
        # Numeric device keys (e.g. "0") are converted to ints for from_pretrained.
        self.max_memory = (
            {int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()}
            if settings.max_memory
            else None
        )
        self.trusted_models = {settings.model: settings.trust_remote_code}

        if self.settings.evaluate_model is not None:
            self.trusted_models[settings.evaluate_model] = settings.trust_remote_code

        for dtype in settings.dtypes:
            print(f"* Trying dtype [bold]{dtype}[/]... ", end="")

            try:
                quantization_config = self._get_quantization_config(dtype)

                extra_kwargs = {}
                # Only include quantization_config if it's not None
                # (some models like gpt-oss have issues with explicit None).
                if quantization_config is not None:
                    extra_kwargs["quantization_config"] = quantization_config

                self.model = get_model_class(settings.model).from_pretrained(
                    settings.model,
                    dtype=dtype,
                    device_map=settings.device_map,
                    max_memory=self.max_memory,
                    trust_remote_code=self.trusted_models.get(settings.model),
                    **extra_kwargs,
                )

                # If we reach this point and the model requires trust_remote_code,
                # either the user accepted, or settings.trust_remote_code is True.
                if self.trusted_models.get(settings.model) is None:
                    self.trusted_models[settings.model] = True

                # A test run can reveal dtype-related problems such as the infamous
                # "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
                # (https://github.com/meta-llama/llama/issues/380).
                self.generate(
                    [
                        Prompt(
                            system=settings.system_prompt,
                            user="What is 1+1?",
                        )
                    ],
                    max_new_tokens=1,
                )
            except Exception as error:
                self.model = None  # ty:ignore[invalid-assignment]
                empty_cache()
                print(f"[red]Failed[/] ({error})")
                continue

            if settings.quantization == QuantizationMethod.BNB_4BIT:
                print("[green]Ok[/] (quantized to 4-bit precision)")
            else:
                print("[green]Ok[/]")

            break

        if self.model is None:
            raise Exception("Failed to load model with all configured dtypes.")

        self._apply_lora()

        # LoRA B matrices are initialized to zero by default in PEFT,
        # so we don't need to do anything manually.

        print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
        print("* Abliterable components:")

        all_components: dict[str, int] = {}
        for layer_index in range(len(self.get_layers())):
            for component, modules in self.get_layer_modules(layer_index).items():
                all_components[component] = all_components.get(component, 0) + len(
                    modules
                )

        for component, count in all_components.items():
            print(f" * [bold]{component}[/]: [bold]{count}[/] modules total")

    def _apply_lora(self):
        """
        Wraps ``self.model`` in a PEFT LoRA model that targets every abliterable module.

        The rank is 1 unless row normalization is FULL, in which case
        ``settings.full_normalization_lora_rank`` is used.
        """
        # Guard against calling this method at the wrong time.
        assert isinstance(self.model, PreTrainedModel)

        # Always use LoRA adapters for abliteration (faster reload, no weight modification).
        # Collect actual leaf module names from the model for LoRA targeting.
        # This is more robust than splitting component keys (e.g. "attn.o_proj" -> "o_proj")
        # because hybrid models like Qwen3.5 MoE have modules with different names
        # across layers (e.g. "o_proj" on attention layers, "out_proj" on linear attention layers).
        target_modules_set: set[str] = set()
        for layer_index, layer in enumerate(self.get_layers()):
            module_id_to_leaf_name = {
                id(module): module_name.split(".")[-1]
                for module_name, module in layer.named_modules()
            }
            for modules in self.get_layer_modules(layer_index).values():
                for module in modules:
                    leaf_name = module_id_to_leaf_name.get(id(module))
                    if leaf_name is not None:
                        target_modules_set.add(leaf_name)

        # Sorted so the logged target list is deterministic across runs.
        target_modules = sorted(target_modules_set)

        if self.settings.row_normalization != RowNormalization.FULL:
            # Rank 1 is sufficient for directional ablation without renormalization.
            lora_rank = 1
        else:
            # Row magnitude preservation introduces nonlinear effects.
            lora_rank = self.settings.full_normalization_lora_rank

        self.peft_config = LoraConfig(
            r=lora_rank,
            target_modules=target_modules,
            lora_alpha=lora_rank,  # Apply adapter at full strength.
            lora_dropout=0,
            bias="none",
            # Even if we're using AutoModelForImageTextToText, this is still correct,
            # as VL models are typically just causal LMs with an added image encoder.
            task_type="CAUSAL_LM",
        )

        # self.peft_config is a LoraConfig object rather than a dictionary,
        # so the result is a PeftModel rather than a PeftMixedModel.
        self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))

        print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")

    def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
        """
        Creates quantization config based on settings.

        Args:
            dtype: The dtype string (e.g., "auto", "bfloat16"). For 4-bit
                quantization it selects the compute dtype; "auto" maps to bfloat16.

        Returns:
            BitsAndBytesConfig for BNB_4BIT quantization, otherwise None.

        Raises:
            ImportError: If BNB_4BIT is requested but bitsandbytes is missing.
        """
        if self.settings.quantization == QuantizationMethod.BNB_4BIT:
            if bnb is None:
                raise ImportError(
                    "bitsandbytes is required for bnb_4bit quantization but is not installed."
                )

            # BitsAndBytesConfig expects a torch.dtype, not a string.
            if dtype == "auto":
                compute_dtype = torch.bfloat16
            else:
                compute_dtype = getattr(torch, dtype)

            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )

        return None

    def get_merged_model(self) -> PreTrainedModel:
        """
        Returns a plain model with the current LoRA adapters merged into its weights.

        For BNB_4BIT quantization, the base model is reloaded in full precision
        on the CPU and the adapter weights are copied over before merging, so
        ``self.model`` is left intact. Otherwise the merge happens in place on
        ``self.model`` and ``needs_reload`` is set.

        Raises:
            RuntimeError: (4-bit path) if some adapter weights could not be matched
                to parameters of the reloaded full-precision model.
        """
        # Guard against calling this method at the wrong time.
        assert isinstance(self.model, PeftModel)

        if self.settings.quantization == QuantizationMethod.BNB_4BIT:
            # Quantized models need special handling - we must reload the base model
            # in full precision to merge the LoRA adapters.

            # Get the adapter state dict before we do anything.
            adapter_state = {
                name: param.data.clone().cpu()
                for name, param in self.model.named_parameters()
                if "lora_" in name
            }

            # Load base model in full precision on CPU to avoid VRAM issues.
            print("* Loading base model on CPU (this may take a while)...")
            base_model = get_model_class(self.settings.model).from_pretrained(
                self.settings.model,
                dtype=self.model.dtype,
                device_map="cpu",
                trust_remote_code=self.trusted_models.get(self.settings.model),
            )

            # Apply LoRA adapters to the CPU model.
            print("* Applying LoRA adapters...")
            peft_model = get_peft_model(base_model, self.peft_config)

            # Copy the trained adapter weights.
            copied: set[str] = set()
            for name, param in peft_model.named_parameters():
                if name in adapter_state:
                    param.data = adapter_state[name].to(param.device)
                    copied.add(name)

            # A name mismatch would otherwise silently merge zero adapters and
            # export an unmodified model.
            missing = adapter_state.keys() - copied
            if missing:
                raise RuntimeError(
                    f"Could not transfer {len(missing)} LoRA adapter tensors "
                    "to the full-precision model for merging."
                )

            # Merge and unload.
            print("* Merging LoRA adapters into base model...")
            return peft_model.merge_and_unload()

        # Non-quantized model - can merge directly.
        print("* Merging LoRA adapters into base model...")
        merged_model = self.model.merge_and_unload()
        # merge_and_unload() modifies self.model in-place, destroying LoRA adapters.
        # Mark for full reload if user switches trials later.
        self.needs_reload = True
        return merged_model

    def reset_model(self):
        """
        Resets the model to a clean state for the next trial or evaluation.

        Behavior:
            - Fast path: If the same model is loaded and doesn't need a full reload,
              zeroes all LoRA B weights, which turns every adapter into an
              identity transformation.
            - Slow path: If switching models or after merge_and_unload(), performs
              a full model reload (with quantization config) and reapplies LoRA.
        """
        current_model = getattr(self.model.config, "name_or_path", None)
        if current_model == self.settings.model and not self.needs_reload:
            # Reset LoRA adapters to zero (identity transformation).
            for name, module in self.model.named_modules():
                if "lora_B" in name and hasattr(module, "weight"):
                    torch.nn.init.zeros_(module.weight)
            return

        dtype = self.model.dtype

        # Purge existing model object from memory to make space.
        self.model = None  # ty:ignore[invalid-assignment]
        empty_cache()

        # str(torch.bfloat16) == "torch.bfloat16" -> "bfloat16".
        quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])

        # Build kwargs, only include quantization_config if it's not None.
        extra_kwargs = {}
        if quantization_config is not None:
            extra_kwargs["quantization_config"] = quantization_config

        self.model = get_model_class(self.settings.model).from_pretrained(
            self.settings.model,
            dtype=dtype,
            device_map=self.settings.device_map,
            max_memory=self.max_memory,
            trust_remote_code=self.trusted_models.get(self.settings.model),
            **extra_kwargs,
        )

        self._apply_lora()

        self.needs_reload = False

    def evaluate_merged(self, callback: Callable[[], TResult]) -> TResult:
        """
        Runs ``callback`` with ``self.model`` temporarily replaced by the merged model.

        ``reset_model()`` is always called afterwards (even on error). Because
        ``needs_reload`` is set, that takes the full-reload path.
        """
        merged_model = None
        try:
            merged_model = self.get_merged_model()
            self.needs_reload = True
            self.model = merged_model
            return callback()
        finally:
            if merged_model is not None:
                del merged_model
            self.reset_model()

    def get_layers(self) -> ModuleList:
        """
        Returns the list of transformer decoder layers.

        A PeftModel wrapper is unwrapped first. Then ``model.model.language_model.layers``
        (most multimodal models) is tried, falling back to ``model.model.layers``.
        """
        model = self.model

        # Unwrap PeftModel (always true after _apply_lora).
        if isinstance(model, PeftModel):
            model = model.base_model.model

        # Most multimodal models.
        with suppress(Exception):
            return model.model.language_model.layers

        # Text-only models.
        return model.model.layers

    def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]:
        """
        Collects the modules of one layer that abliteration edits, grouped by component.

        Component keys are "attn.o_proj" and "mlp.down_proj". Each known
        architecture layout is probed inside ``suppress(Exception)``, so layouts
        that don't apply to this layer are skipped silently.

        Raises:
            AssertionError: If no abliterable module is found in the layer.
        """
        layer = self.get_layers()[layer_index]

        modules: dict[str, list[Module]] = {}

        def try_add(component: str, module: Any):
            # Only add if it's a proper nn.Module (PEFT can wrap these with LoRA).
            if isinstance(module, Module):
                modules.setdefault(component, []).append(module)
            else:
                # Assert for unexpected types (catches architecture changes).
                # NOTE(review): every call site below runs inside suppress(Exception),
                # which also swallows this AssertionError, so it never propagates.
                assert not isinstance(module, Tensor), (
                    f"Unexpected Tensor in {component} - expected nn.Module"
                )

        # Standard self-attention out-projection (most models).
        with suppress(Exception):
            try_add("attn.o_proj", layer.self_attn.o_proj)  # ty:ignore[possibly-missing-attribute]

        # Qwen3.5 MoE hybrid layers use GatedDeltaNet (linear attention) instead
        # of standard self-attention, so self_attn.o_proj doesn't exist on those layers.
        with suppress(Exception):
            try_add("attn.o_proj", layer.linear_attn.out_proj)  # ty:ignore[possibly-missing-attribute]

        # Most dense models.
        with suppress(Exception):
            try_add("mlp.down_proj", layer.mlp.down_proj)  # ty:ignore[possibly-missing-attribute]

        # Some MoE models (e.g. Qwen3).
        with suppress(Exception):
            for expert in layer.mlp.experts:  # ty:ignore[possibly-missing-attribute, not-iterable]
                try_add("mlp.down_proj", expert.down_proj)  # ty:ignore[possibly-missing-attribute]

        # Phi-3.5-MoE (and possibly others).
        with suppress(Exception):
            for expert in layer.block_sparse_moe.experts:  # ty:ignore[possibly-missing-attribute, not-iterable]
                try_add("mlp.down_proj", expert.w2)  # ty:ignore[possibly-missing-attribute]

        # Granite MoE Hybrid - attention layers with shared_mlp.
        with suppress(Exception):
            try_add("mlp.down_proj", layer.shared_mlp.output_linear)  # ty:ignore[possibly-missing-attribute]

        # Granite MoE Hybrid - MoE layers with experts.
        with suppress(Exception):
            for expert in layer.moe.experts:  # ty:ignore[possibly-missing-attribute, not-iterable]
                try_add("mlp.down_proj", expert.output_linear)  # ty:ignore[possibly-missing-attribute]

        # We need at least one module across all components for abliteration to work.
        total_modules = sum(len(mods) for mods in modules.values())
        assert total_modules > 0, "No abliterable modules found in layer"

        return modules

    def get_abliterable_components(self) -> list[str]:
        """Returns the sorted union of component names found across all layers."""
        # Scan all layers because hybrid models (e.g. Qwen3.5 MoE) have different
        # components on different layers (some have self_attn, others linear_attn).
        components: set[str] = set()
        for layer_index in range(len(self.get_layers())):
            components.update(self.get_layer_modules(layer_index).keys())
        return sorted(components)

    def abliterate(
        self,
        refusal_directions: Tensor | dict[str, Tensor],
        direction_index: float | None | dict[str, float | None],
        parameters: dict[str, AbliterationParameters],
    ):
        """
        Writes LoRA deltas that project refusal directions out of the abliterable modules.

        For each component, layers within ``min_weight_distance`` of
        ``max_weight_position`` receive the delta ``-weight * v (v^T W)``, with
        optional row normalization. ``weight`` interpolates linearly from
        ``max_weight`` (distance 0) to ``min_weight`` (distance
        ``min_weight_distance``). With FULL row normalization, the exact delta is
        approximated by a rank-``r`` truncated SVD.

        Args:
            refusal_directions: Directions indexed by hidden-state position. Index
                0 is the embedding direction and index ``i + 1`` belongs to layer
                ``i``. Either one tensor for all components or a dict per component.
            direction_index: None to use each layer's own direction. A float
                selects one global direction, linearly interpolated between the two
                neighbouring layer directions (before the +1 embedding shift) and
                L2-normalized. Either one value for all components or a dict.
            parameters: Weight-kernel parameters per component.

        Raises:
            ImportError: If a 4-bit quantized weight is encountered but bitsandbytes
                is not installed.
        """
        global_direction_cache: dict[str, Tensor | None] = {}

        # Note that some implementations of abliteration also orthogonalize
        # the embedding matrix, but it's unclear if that has any benefits.
        for layer_index in range(len(self.get_layers())):
            for component, modules in self.get_layer_modules(layer_index).items():
                if isinstance(refusal_directions, dict):
                    component_refusal_directions = refusal_directions[component]
                else:
                    component_refusal_directions = refusal_directions

                if isinstance(direction_index, dict):
                    component_direction_index = direction_index[component]
                else:
                    component_direction_index = direction_index

                if component not in global_direction_cache:
                    if component_direction_index is None:
                        global_direction_cache[component] = None
                    else:
                        # The index must be shifted by 1 because the first element
                        # of refusal_directions is the direction for the embeddings.
                        fraction, index = math.modf(component_direction_index + 1)
                        global_direction_cache[component] = F.normalize(
                            component_refusal_directions[int(index)].lerp(
                                component_refusal_directions[int(index) + 1],
                                fraction,
                            ),
                            p=2,
                            dim=0,
                        )

                refusal_direction = global_direction_cache[component]

                params = parameters[component]

                # Type inference fails here for some reason.
                distance = cast(float, abs(layer_index - params.max_weight_position))

                # Don't orthogonalize layers that are more than
                # min_weight_distance away from max_weight_position.
                if distance > params.min_weight_distance:
                    continue

                # Interpolate linearly between max_weight and min_weight
                # over min_weight_distance.
                if params.min_weight_distance > 0:
                    weight = params.max_weight + (
                        distance / params.min_weight_distance
                    ) * (params.min_weight - params.max_weight)
                else:
                    # Here min_weight_distance == 0, so distance == 0 as well;
                    # the interpolation would be 0/0.
                    weight = params.max_weight

                if refusal_direction is None:
                    # The index must be shifted by 1 because the first element
                    # of refusal_directions is the direction for the embeddings.
                    layer_refusal_direction = component_refusal_directions[
                        layer_index + 1
                    ]
                else:
                    layer_refusal_direction = refusal_direction

                for module in modules:
                    # FIXME: This cast is potentially invalid, because the program logic
                    # does not guarantee that the module is of type Linear, and in fact
                    # the retrieved modules might not conform to the interface assumed
                    # below (though they do in practice). However, this is difficult
                    # to fix cleanly, because get_layer_modules is called twice on
                    # different model configurations, and PEFT employs different
                    # module types depending on the chosen quantization.
                    module = cast(Linear, module)

                    # LoRA abliteration: delta W = -lambda * v * (v^T W)
                    # lora_B = -lambda * v
                    # lora_A = v^T W
                    # W is computed in FP32 below, so force v to FP32 too
                    # (a no-op when the direction is already FP32).
                    v = layer_refusal_direction.to(
                        device=module.weight.device,
                        dtype=torch.float32,
                    )

                    # Get W (dequantize if necessary).
                    #
                    # FIXME: This cast is valid only under the assumption that the original
                    # module wrapped by the LoRA adapter has a weight attribute.
                    # See the comment above for why this is currently not guaranteed.
                    base_weight = cast(Tensor, module.base_layer.weight)
                    quant_state = getattr(base_weight, "quant_state", None)

                    if quant_state is None:
                        W = base_weight.to(torch.float32)
                    else:
                        if bnb is None:
                            raise ImportError(
                                "bitsandbytes is required for 4-bit model editing but is not installed."
                            )

                        # 4-bit quantization.
                        # This cast is always valid. Type inference fails here because the
                        # bnb.functional module is not found by ty for some reason.
                        W = cast(
                            Tensor,
                            bnb.functional.dequantize_4bit(  # ty:ignore[possibly-missing-attribute]
                                base_weight.data,
                                quant_state,
                            ).to(torch.float32),
                        )

                    # Flatten weight matrix to (out_features, in_features).
                    W = W.view(W.shape[0], -1)

                    if self.settings.row_normalization != RowNormalization.NONE:
                        # Keep a reference to the original weight matrix so we can subtract it later.
                        W_org = W
                        # Get the row norms.
                        W_row_norms = LA.vector_norm(W, dim=1, keepdim=True)
                        # Normalize the weight matrix along the rows.
                        W = F.normalize(W, p=2, dim=1)

                    # Calculate lora_A = v^T W
                    # v is (d_out,), W is (d_out, d_in)
                    # v @ W -> (d_in,)
                    lora_A = (v @ W).view(1, -1)

                    # Calculate lora_B = -weight * v
                    # v is (d_out,)
                    lora_B = (-weight * v).view(-1, 1)

                    if self.settings.row_normalization == RowNormalization.PRE:
                        # Make the LoRA adapter apply to the original weight matrix.
                        lora_B = W_row_norms * lora_B
                    elif self.settings.row_normalization == RowNormalization.FULL:
                        # Approximates https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration
                        W = W + lora_B @ lora_A
                        # Normalize the adjusted weight matrix along the rows.
                        W = F.normalize(W, p=2, dim=1)
                        # Restore the original row norms of the weight matrix.
                        W = W * W_row_norms
                        # Subtract the original matrix to turn W into a delta.
                        W = W - W_org
                        # Use a low-rank SVD to get an approximation of the matrix.
                        r = self.peft_config.r
                        U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6)
                        # Truncate it to the part we want to store in the LoRA adapter.
                        # Note: svd_lowrank actually returns V, so transpose it to get Vh.
                        U = U[:, :r]
                        S = S[:r]
                        Vh = Vh[:, :r].T
                        # Transfer it into the LoRA adapter components. Split the singular values
                        # evenly between the two components to keep their norms balanced and avoid
                        # potential issues with numerical stability.
                        sqrt_S = torch.sqrt(S)
                        lora_B = U @ torch.diag(sqrt_S)
                        lora_A = torch.diag(sqrt_S) @ Vh

                    # Assign to adapters. The adapter name is "default", because that's
                    # what PEFT uses when no name is explicitly specified, as above.
                    # These casts are therefore valid.
                    weight_A = cast(Tensor, module.lora_A["default"].weight)
                    weight_B = cast(Tensor, module.lora_B["default"].weight)
                    weight_A.data = lora_A.to(weight_A.dtype)
                    weight_B.data = lora_B.to(weight_B.dtype)

    def generate(
        self,
        prompts: list[Prompt],
        **kwargs: Any,
    ) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
        """
        Runs greedy generation for a batch of prompts.

        Prompts are rendered with the tokenizer's chat template. If that raises,
        the system prompt is merged into the user message and rendering is
        retried. ``self.response_prefix`` is appended to every rendered prompt.
        Extra ``kwargs`` are forwarded to ``model.generate``.

        Returns:
            The tokenized (left-padded) inputs and the raw ``generate`` output.
        """
        # Standard chat structure.
        chats = []
        for prompt in prompts:
            chat = []
            if prompt.system:
                chat.append({"role": "system", "content": prompt.system})
            chat.append({"role": "user", "content": prompt.user})
            chats.append(chat)

        try:
            # This cast is valid because list[str] is the return type
            # for batched operation with tokenize=False.
            chat_prompts = cast(
                list[str],
                self.tokenizer.apply_chat_template(
                    chats,
                    add_generation_prompt=True,
                    tokenize=False,
                ),
            )
        except Exception:
            # Fallback for models that do not support system roles (e.g. Gemma 2).
            # Merge the system prompt into the first user message.
            chats = []
            for prompt in prompts:
                content = prompt.user
                if prompt.system:
                    content = f"{prompt.system}\n\n{prompt.user}"
                chats.append([{"role": "user", "content": content}])

            chat_prompts = cast(
                list[str],
                self.tokenizer.apply_chat_template(
                    chats,
                    add_generation_prompt=True,
                    tokenize=False,
                ),
            )

        if self.response_prefix:
            # Append the common response prefix to the prompts so that evaluation happens
            # at the point where responses start to differ for different prompts.
            chat_prompts = [prompt + self.response_prefix for prompt in chat_prompts]

        inputs = self.tokenizer(
            chat_prompts,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
        ).to(self.model.device)

        # FIXME: The type checker has been disabled here because of the extremely complex
        # interplay between different generate() signatures and dynamic delegation.
        outputs = self.model.generate(
            **inputs,
            **kwargs,
            pad_token_id=self.tokenizer.pad_token_id,
            do_sample=False,  # Use greedy decoding to ensure deterministic outputs.
        )  # ty:ignore[call-non-callable]

        return inputs, outputs

    def get_responses(
        self,
        prompts: list[Prompt],
        skip_special_tokens: bool = False,
    ) -> list[str]:
        """
        Generates up to ``settings.max_response_length`` tokens per prompt and
        decodes only the newly generated part.
        """
        inputs, outputs = self.generate(
            prompts,
            max_new_tokens=self.settings.max_response_length,
        )

        return self.tokenizer.batch_decode(
            # Extract the newly generated part.
            # This cast is valid because the input_ids property is a Tensor
            # if the tokenizer is invoked with return_tensors="pt", as above.
            outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :],
            skip_special_tokens=skip_special_tokens,
        )

    def get_responses_batched(
        self,
        prompts: list[Prompt],
        skip_special_tokens: bool = False,
    ) -> list[str]:
        """Same as ``get_responses``, processed in chunks of ``settings.batch_size``."""
        responses: list[str] = []

        for batch in batchify(prompts, self.settings.batch_size):
            responses.extend(
                self.get_responses(
                    batch,
                    skip_special_tokens=skip_special_tokens,
                )
            )

        return responses

    def get_residuals(self, prompts: list[Prompt]) -> Tensor:
        """
        Returns residual vectors at the last prompt position, for every hidden-state layer.

        Returns:
            A float32 tensor of shape (prompt, layer, component). When
            ``0 <= settings.winsorization_quantile < 1``, each (prompt, layer)
            vector is symmetrically clamped at that quantile of its absolute values.
        """
        # We only generate one token, and we return the residual vectors
        # at that token position, for each prompt and layer.
        _, outputs = self.generate(
            prompts,
            max_new_tokens=1,
            output_hidden_states=True,
            return_dict_in_generate=True,
        )

        # This cast is valid because GenerateDecoderOnlyOutput is the return type
        # of model.generate with return_dict_in_generate=True.
        outputs = cast(GenerateDecoderOnlyOutput, outputs)

        # Hidden states for the first (only) generated token.
        # This cast is valid because we passed output_hidden_states=True above.
        hidden_states = cast(tuple[tuple[FloatTensor]], outputs.hidden_states)[0]

        # The returned tensor has shape (prompt, layer, component).
        residuals = torch.stack(
            # layer_hidden_states has shape (prompt, position, component),
            # so this extracts the hidden states at the end of each prompt
            # (prompts are left-padded, see __init__), and stacks them up over the layers.
            [layer_hidden_states[:, -1, :] for layer_hidden_states in hidden_states],
            dim=1,
        )

        # Upcast the data type to avoid precision (bfloat16) or range (float16)
        # problems during calculations involving residual vectors.
        residuals = residuals.to(torch.float32)

        if 0 <= self.settings.winsorization_quantile < 1:
            # Apply symmetric winsorization to each layer of the per-prompt residuals.
            abs_residuals = torch.abs(residuals)
            # Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals.
            thresholds = torch.quantile(
                abs_residuals,
                self.settings.winsorization_quantile,
                dim=2,
                keepdim=True,
            )
            return torch.clamp(residuals, -thresholds, thresholds)

        return residuals

    def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
        """Same as ``get_residuals``, processed in chunks of ``settings.batch_size``."""
        residuals = [
            self.get_residuals(batch)
            for batch in batchify(prompts, self.settings.batch_size)
        ]
        return torch.cat(residuals, dim=0)

    # We work with logprobs rather than probabilities for numerical stability
    # when computing the KL divergence.
    def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
        """
        Returns log-softmax distributions over the vocabulary for the first
        generated token, with shape (prompt, token).
        """
        # We only generate one token, and we return the (log) probability distributions
        # over the vocabulary at that token position, for each prompt.
        _, outputs = self.generate(
            prompts,
            max_new_tokens=1,
            output_scores=True,
            return_dict_in_generate=True,
        )

        # This cast is valid because GenerateDecoderOnlyOutput is the return type
        # of model.generate with return_dict_in_generate=True.
        outputs = cast(GenerateDecoderOnlyOutput, outputs)

        # Logits for the first (only) generated token.
        # This cast is valid because we passed output_scores=True above.
        logits = cast(tuple[FloatTensor], outputs.scores)[0]

        # The returned tensor has shape (prompt, token).
        return F.log_softmax(logits, dim=-1)

    def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
        """Same as ``get_logprobs``, processed in chunks of ``settings.batch_size``."""
        logprobs = [
            self.get_logprobs(batch)
            for batch in batchify(prompts, self.settings.batch_size)
        ]
        return torch.cat(logprobs, dim=0)

    def stream_chat_response(self, chat: list[dict[str, str]]) -> str:
        """
        Streams a response to ``chat`` (up to 4096 new tokens) through a
        TextStreamer and returns the decoded response text.
        """
        # This cast is valid because str is the return type
        # for single-chat operation with tokenize=False.
        chat_prompt = cast(
            str,
            self.tokenizer.apply_chat_template(
                chat,
                add_generation_prompt=True,
                tokenize=False,
            ),
        )

        inputs = self.tokenizer(
            chat_prompt,
            return_tensors="pt",
            return_token_type_ids=False,
        ).to(self.model.device)

        streamer = TextStreamer(
            # The TextStreamer constructor annotates this parameter with the AutoTokenizer
            # type, which makes no sense because AutoTokenizer is a factory class,
            # not a base class that tokenizers inherit from.
            self.tokenizer,  # ty:ignore[invalid-argument-type]
            skip_prompt=True,
            skip_special_tokens=True,
        )

        # FIXME: The type checker has been disabled here because of the extremely complex
        # interplay between different generate() signatures and dynamic delegation.
        outputs = self.model.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096,
        )  # ty:ignore[call-non-callable]

        # This cast is valid because str is the return type
        # when passing a sequence of token IDs.
        return cast(
            str,
            self.tokenizer.decode(
                outputs[0, inputs["input_ids"].shape[1] :],
                skip_special_tokens=True,
            ),
        )