Instructions to use HaadesX/Iconoclast with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use HaadesX/Iconoclast with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("HaadesX/Iconoclast", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| # SPDX-License-Identifier: AGPL-3.0-or-later | |
| # Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors | |
| import math | |
| from contextlib import suppress | |
| from dataclasses import dataclass | |
| from typing import Any, Callable, Type, TypeVar, cast | |
| import torch | |
| import torch.linalg as LA | |
| import torch.nn.functional as F | |
| from peft import LoraConfig, PeftModel, get_peft_model | |
| from peft.tuners.lora.layer import Linear | |
| from torch import FloatTensor, LongTensor, Tensor | |
| from torch.nn import Module, ModuleList | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoModelForImageTextToText, | |
| AutoTokenizer, | |
| BatchEncoding, | |
| BitsAndBytesConfig, | |
| PretrainedConfig, | |
| PreTrainedModel, | |
| PreTrainedTokenizerBase, | |
| TextStreamer, | |
| ) | |
| from transformers.generation import ( | |
| GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import] | |
| ) | |
| from .config import QuantizationMethod, RowNormalization, Settings | |
| from .utils import Prompt, batchify, empty_cache, print | |
| try: | |
| import bitsandbytes as bnb | |
| except ImportError: | |
| bnb = None | |
# Monkey-patch torch.nn.Module to add set_submodule if it is missing.
# This is required for compatibility with some newer transformers/peft features
# and specific model architectures (e.g., Qwen 3.5, Mistral) in newer environments.
def set_submodule(
    self: torch.nn.Module,
    target: str,
    module: torch.nn.Module,
    strict: bool = False,
) -> None:
    """
    Replaces (or adds) the submodule at dotted path `target` with `module`.

    Fallback for torch versions whose nn.Module has no set_submodule. It is
    only installed on torch.nn.Module when that attribute is missing (see below).

    Args:
        self: Root module the path is resolved against.
        target: Dotted path such as "layers.0.mlp". Every component except the
            last must name an existing submodule.
        module: Module to install at `target`.
        strict: If True, the final path component must already exist on its
            parent.

    Raises:
        ValueError: If `target` is empty.
        AttributeError: If the parent path cannot be resolved, or if `strict`
            is True and the final attribute does not exist.
    """
    if not target:
        # Without this check, setattr(self, "", module) would silently
        # register a submodule with an empty name.
        raise ValueError("Cannot set the submodule without a target name!")
    parent_path, _, name = target.rpartition(".")
    parent = self.get_submodule(parent_path) if parent_path else self
    if strict and not hasattr(parent, name):
        raise AttributeError(f"{parent._get_name()} has no attribute `{name}`")
    setattr(parent, name, module)


if not hasattr(torch.nn.Module, "set_submodule"):
    torch.nn.Module.set_submodule = set_submodule  # type: ignore

# Return type of the callback passed to Model.evaluate_merged().
TResult = TypeVar("TResult")
def get_model_class(
    model: str,
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
    """
    Picks the Auto* class used to load `model`.

    Models whose config dictionary contains a "vision_config" entry are loaded
    with AutoModelForImageTextToText; all others with AutoModelForCausalLM.

    Args:
        model: Model identifier or local path, as accepted by from_pretrained().

    Returns:
        The Auto* class itself (not an instance).
    """
    # get_config_dict() returns (config_dict, unused_kwargs); only the first
    # element describes the model.
    config_dict, _ = PretrainedConfig.get_config_dict(model)
    if "vision_config" in config_dict:
        return AutoModelForImageTextToText
    return AutoModelForCausalLM
@dataclass
class AbliterationParameters:
    """
    Parameters of the per-component ablation weight kernel used by Model.abliterate().

    The ablation weight equals max_weight at layer max_weight_position. It falls
    off linearly to min_weight at a distance of min_weight_distance layers.
    Layers farther away than that are not modified.
    """

    # Weight applied at the peak layer.
    max_weight: float
    # (Possibly fractional) layer index of the peak.
    max_weight_position: float
    # Weight reached at a distance of min_weight_distance from the peak.
    min_weight: float
    # Distance (in layers) over which the weight decays; farther layers are skipped.
    min_weight_distance: float
class Model:
    """
    Loads a causal (or image-text-to-text) language model and edits it through LoRA.

    Construction loads the tokenizer and the model, trying each dtype in
    `settings.dtypes` until one works. It then wraps the model with LoRA adapters
    on every abliterable module. abliterate() writes the directional-ablation
    update into those adapters and leaves the base weights untouched.
    reset_model() restores a clean state between trials.
    """

    # The wrapped model. It is a PeftModel after _apply_lora(), and a plain
    # PreTrainedModel while loading or after the adapters have been merged.
    model: PreTrainedModel | PeftModel
    tokenizer: PreTrainedTokenizerBase
    # LoRA configuration created in _apply_lora() and reused by get_merged_model().
    peft_config: LoraConfig

    def __init__(self, settings: Settings):
        """
        Loads the tokenizer and model described by `settings` and attaches LoRA adapters.

        Each dtype in `settings.dtypes` is tried in order. A dtype is accepted only
        if loading succeeds and a one-token test generation runs without raising.
        Otherwise the loaded model is discarded and the next dtype is tried.

        Args:
            settings: Run configuration (model id, dtypes, device map, memory
                limits, quantization, LoRA/normalization options, batch size, ...).

        Raises:
            Exception: If no configured dtype produced a working model.
        """
        self.settings = settings
        # Appended to every templated prompt by generate(); empty by default.
        self.response_prefix = ""
        # True when self.model no longer holds a clean base model with LoRA
        # adapters (e.g. after merge_and_unload()). reset_model() must then
        # reload the model from disk.
        self.needs_reload = False

        print()
        print(f"Loading model [bold]{settings.model}[/]...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            settings.model,
            trust_remote_code=settings.trust_remote_code,
        )

        # Fallback for tokenizers that don't declare a special pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # CRITICAL: Always use left-padding for decoder-only models during generation.
        # Right-padding causes empty outputs because the model sees PAD tokens
        # after the prompt and thinks the sequence is complete.
        self.tokenizer.padding_side = "left"

        self.model = None  # ty:ignore[invalid-assignment]

        # Purely numeric keys (device indices such as "0") become ints;
        # all other keys are passed through unchanged.
        self.max_memory = (
            {int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()}
            if settings.max_memory
            else None
        )

        # trust_remote_code decision per model id, reused by every later reload.
        self.trusted_models = {settings.model: settings.trust_remote_code}
        if self.settings.evaluate_model is not None:
            self.trusted_models[settings.evaluate_model] = settings.trust_remote_code

        for dtype in settings.dtypes:
            print(f"* Trying dtype [bold]{dtype}[/]... ", end="")

            try:
                quantization_config = self._get_quantization_config(dtype)

                extra_kwargs = {}
                # Only include quantization_config if it's not None
                # (some models like gpt-oss have issues with explicit None).
                if quantization_config is not None:
                    extra_kwargs["quantization_config"] = quantization_config

                self.model = get_model_class(settings.model).from_pretrained(
                    settings.model,
                    dtype=dtype,
                    device_map=settings.device_map,
                    max_memory=self.max_memory,
                    trust_remote_code=self.trusted_models.get(settings.model),
                    **extra_kwargs,
                )

                # If we reach this point and the model requires trust_remote_code,
                # either the user accepted, or settings.trust_remote_code is True.
                if self.trusted_models.get(settings.model) is None:
                    self.trusted_models[settings.model] = True

                # A test run can reveal dtype-related problems such as the infamous
                # "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
                # (https://github.com/meta-llama/llama/issues/380).
                self.generate(
                    [
                        Prompt(
                            system=settings.system_prompt,
                            user="What is 1+1?",
                        )
                    ],
                    max_new_tokens=1,
                )
            except Exception as error:
                # Discard whatever was loaded before trying the next dtype.
                self.model = None  # ty:ignore[invalid-assignment]
                empty_cache()
                print(f"[red]Failed[/] ({error})")
                continue

            if settings.quantization == QuantizationMethod.BNB_4BIT:
                print("[green]Ok[/] (quantized to 4-bit precision)")
            else:
                print("[green]Ok[/]")

            break

        if self.model is None:
            raise Exception("Failed to load model with all configured dtypes.")

        self._apply_lora()

        # LoRA B matrices are initialized to zero by default in PEFT,
        # so we don't need to do anything manually.

        print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
        print("* Abliterable components:")

        # Total number of modules per component name, summed over all layers.
        all_components = {}
        for layer_index in range(len(self.get_layers())):
            for component, modules in self.get_layer_modules(layer_index).items():
                if component not in all_components:
                    all_components[component] = 0
                all_components[component] += len(modules)

        for component, count in all_components.items():
            print(f" * [bold]{component}[/]: [bold]{count}[/] modules total")

    def _apply_lora(self):
        """
        Wraps self.model in a PeftModel with LoRA adapters on all abliterable modules.

        Must be called while self.model is a plain PreTrainedModel. Sets
        self.peft_config and replaces self.model with the PEFT wrapper.
        """
        # Guard against calling this method at the wrong time.
        assert isinstance(self.model, PreTrainedModel)

        # Always use LoRA adapters for abliteration (faster reload, no weight modification).
        # Collect actual leaf module names from the model for LoRA targeting.
        # This is more robust than splitting component keys (e.g. "attn.o_proj" -> "o_proj")
        # because hybrid models like Qwen3.5 MoE have modules with different names
        # across layers (e.g. "o_proj" on attention layers, "out_proj" on linear attention layers).
        target_modules_set: set[str] = set()
        for layer_index, layer in enumerate(self.get_layers()):
            # Map module identity -> last component of its dotted name within the layer.
            module_id_to_leaf_name = {
                id(module): module_name.split(".")[-1]
                for module_name, module in layer.named_modules()
            }
            for modules in self.get_layer_modules(layer_index).values():
                for module in modules:
                    if id(module) in module_id_to_leaf_name:
                        target_modules_set.add(module_id_to_leaf_name[id(module)])

        target_modules = list(target_modules_set)

        if self.settings.row_normalization != RowNormalization.FULL:
            # Rank 1 is sufficient for directional ablation without renormalization.
            lora_rank = 1
        else:
            # Row magnitude preservation introduces nonlinear effects.
            lora_rank = self.settings.full_normalization_lora_rank

        self.peft_config = LoraConfig(
            r=lora_rank,
            target_modules=target_modules,
            lora_alpha=lora_rank,  # Apply adapter at full strength.
            lora_dropout=0,
            bias="none",
            # Even if we're using AutoModelForImageTextToText, this is still correct,
            # as VL models are typically just causal LMs with an added image encoder.
            task_type="CAUSAL_LM",
        )

        # self.peft_config is a LoraConfig object rather than a dictionary,
        # so the result is a PeftModel rather than a PeftMixedModel.
        self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))

        print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")

    def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
        """
        Creates quantization config based on settings.

        Args:
            dtype: The dtype string (e.g., "auto", "bfloat16"). It is used as the
                4-bit compute dtype; "auto" maps to torch.bfloat16.

        Returns:
            An NF4 double-quantized BitsAndBytesConfig when settings.quantization
            is BNB_4BIT, otherwise None.

        Raises:
            ImportError: If 4-bit quantization is requested but bitsandbytes is
                not installed.
        """
        if self.settings.quantization == QuantizationMethod.BNB_4BIT:
            if bnb is None:
                raise ImportError(
                    "bitsandbytes is required for bnb_4bit quantization but is not installed."
                )

            # BitsAndBytesConfig expects a torch.dtype, not a string.
            if dtype == "auto":
                compute_dtype = torch.bfloat16
            else:
                compute_dtype = getattr(torch, dtype)

            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )

        return None

    def get_merged_model(self) -> PreTrainedModel:
        """
        Returns a plain model with the current LoRA adapters merged into its weights.

        bnb 4-bit models:
            The base model is reloaded on CPU without quantization, in
            `self.model.dtype`. The current adapter weights are copied onto a
            fresh PEFT wrapper of that copy, and the copy is merged.
            self.model itself is left untouched.

        Other models:
            The adapters are merged into self.model in place, which removes
            them. needs_reload is therefore set.

        Returns:
            The merged model.

        Raises:
            RuntimeError: If some adapter tensors of the quantized model have no
                counterpart in the reloaded model. Continuing in that case would
                silently produce a model without (all of) the edit.
        """
        # Guard against calling this method at the wrong time.
        assert isinstance(self.model, PeftModel)

        if self.settings.quantization == QuantizationMethod.BNB_4BIT:
            # Quantized models need special handling: the adapters are merged
            # into a separately loaded, unquantized copy of the base model.

            # Snapshot the adapter weights on CPU before loading another copy.
            adapter_state = {}
            for name, param in self.model.named_parameters():
                if "lora_" in name:
                    adapter_state[name] = param.data.clone().cpu()

            # Load base model on CPU to avoid VRAM issues.
            print("* Loading base model on CPU (this may take a while)...")
            base_model = get_model_class(self.settings.model).from_pretrained(
                self.settings.model,
                dtype=self.model.dtype,
                device_map="cpu",
                trust_remote_code=self.trusted_models.get(self.settings.model),
            )

            print("* Applying LoRA adapters...")
            peft_model = get_peft_model(base_model, self.peft_config)

            # Copy the adapter weights. Parameter names are expected to match
            # because both models were wrapped with the same peft_config. This
            # is verified below rather than assumed.
            copied: set[str] = set()
            for name, param in peft_model.named_parameters():
                saved = adapter_state.get(name)
                if saved is not None:
                    param.data = saved.to(param.device)
                    copied.add(name)

            missing = adapter_state.keys() - copied
            if missing:
                raise RuntimeError(
                    f"Failed to transfer {len(missing)} LoRA adapter tensor(s) to the "
                    f"unquantized model (e.g. {sorted(missing)[0]})."
                )

            print("* Merging LoRA adapters into base model...")
            return peft_model.merge_and_unload()

        # Non-quantized model: merge directly. merge_and_unload() modifies
        # self.model in place and destroys the LoRA adapters. Mark the model
        # for a full reload *before* merging, so that a failure part-way
        # through the merge also forces reset_model() to reload from disk.
        self.needs_reload = True
        print("* Merging LoRA adapters into base model...")
        return self.model.merge_and_unload()

    def reset_model(self):
        """
        Resets the model to a clean state for the next trial or evaluation.

        Behavior:
        - Fast path: If the same model is loaded and doesn't need full reload,
          resets LoRA adapter weights to zero (identity transformation).
        - Slow path: If switching models or after merge_and_unload(),
          performs full model reload with quantization config.
        """
        current_model = getattr(self.model.config, "name_or_path", None)
        if current_model == self.settings.model and not self.needs_reload:
            # Zeroing every lora_B weight makes each adapter's B @ A product
            # zero, so the adapters no longer change the base layers' outputs.
            for name, module in self.model.named_modules():
                if "lora_B" in name and hasattr(module, "weight"):
                    torch.nn.init.zeros_(module.weight)
            return

        # Reload with the dtype of the model that is currently loaded.
        dtype = self.model.dtype

        # Purge existing model object from memory to make space.
        self.model = None  # ty:ignore[invalid-assignment]
        empty_cache()

        # E.g. torch.bfloat16 -> "bfloat16", the string form that
        # _get_quantization_config() passes to getattr(torch, ...).
        quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])

        # Build kwargs, only include quantization_config if it's not None
        extra_kwargs = {}
        if quantization_config is not None:
            extra_kwargs["quantization_config"] = quantization_config

        self.model = get_model_class(self.settings.model).from_pretrained(
            self.settings.model,
            dtype=dtype,
            device_map=self.settings.device_map,
            max_memory=self.max_memory,
            trust_remote_code=self.trusted_models.get(self.settings.model),
            **extra_kwargs,
        )

        self._apply_lora()

        self.needs_reload = False

    def evaluate_merged(self, callback: Callable[[], TResult]) -> TResult:
        """
        Runs `callback` while self.model temporarily holds the merged model.

        The adapters are merged via get_merged_model(), and self.model points at
        the result while the callback runs. reset_model() is always called
        afterwards, even if merging or the callback raises. On success,
        needs_reload is set, so that reset performs a full reload.

        Args:
            callback: Zero-argument function that uses self.model.

        Returns:
            Whatever `callback` returns.
        """
        merged_model = None
        try:
            merged_model = self.get_merged_model()
            self.needs_reload = True
            self.model = merged_model
            return callback()
        finally:
            if merged_model is not None:
                # Drops only the local reference. self.model still points at the
                # merged model until reset_model() replaces it.
                del merged_model
            self.reset_model()

    def get_layers(self) -> ModuleList:
        """
        Returns the transformer layers of the language model.

        After unwrapping PEFT, `model.language_model.layers` (most multimodal
        models) is tried first, with `model.layers` (text-only models) as the
        fallback.
        """
        model = self.model

        # Unwrap PeftModel (always true after _apply_lora)
        if isinstance(model, PeftModel):
            model = model.base_model.model

        # Most multimodal models.
        with suppress(Exception):
            return model.model.language_model.layers

        # Text-only models.
        return model.model.layers

    def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]:
        """
        Returns the abliterable modules of one layer, grouped by component name.

        Every known location of an attention out-projection ("attn.o_proj") and
        of an MLP/expert down-projection ("mlp.down_proj") is probed. Locations
        that don't exist on this layer are skipped.

        Args:
            layer_index: Index into get_layers().

        Returns:
            Mapping from component name to the modules found for it. Only
            components with at least one module are present.

        Raises:
            AssertionError: If no abliterable module was found in the layer.
        """
        layer = self.get_layers()[layer_index]

        modules = {}

        def try_add(component: str, module: Any):
            # Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
            if isinstance(module, Module):
                if component not in modules:
                    modules[component] = []
                modules[component].append(module)
            else:
                # Assert for unexpected types (catches architecture changes).
                # NOTE(review): every call site below runs inside suppress(Exception),
                # which also swallows this AssertionError. In practice a Tensor
                # here is therefore skipped silently rather than reported.
                assert not isinstance(module, Tensor), (
                    f"Unexpected Tensor in {component} - expected nn.Module"
                )

        # Standard self-attention out-projection (most models).
        with suppress(Exception):
            try_add("attn.o_proj", layer.self_attn.o_proj)  # ty:ignore[possibly-missing-attribute]

        # Qwen3.5 MoE hybrid layers use GatedDeltaNet (linear attention) instead
        # of standard self-attention, so self_attn.o_proj doesn't exist on those layers.
        with suppress(Exception):
            try_add("attn.o_proj", layer.linear_attn.out_proj)  # ty:ignore[possibly-missing-attribute]

        # Most dense models.
        with suppress(Exception):
            try_add("mlp.down_proj", layer.mlp.down_proj)  # ty:ignore[possibly-missing-attribute]

        # Some MoE models (e.g. Qwen3).
        with suppress(Exception):
            for expert in layer.mlp.experts:  # ty:ignore[possibly-missing-attribute, not-iterable]
                try_add("mlp.down_proj", expert.down_proj)  # ty:ignore[possibly-missing-attribute]

        # Phi-3.5-MoE (and possibly others).
        with suppress(Exception):
            for expert in layer.block_sparse_moe.experts:  # ty:ignore[possibly-missing-attribute, not-iterable]
                try_add("mlp.down_proj", expert.w2)  # ty:ignore[possibly-missing-attribute]

        # Granite MoE Hybrid - attention layers with shared_mlp.
        with suppress(Exception):
            try_add("mlp.down_proj", layer.shared_mlp.output_linear)  # ty:ignore[possibly-missing-attribute]

        # Granite MoE Hybrid - MoE layers with experts.
        with suppress(Exception):
            for expert in layer.moe.experts:  # ty:ignore[possibly-missing-attribute, not-iterable]
                try_add("mlp.down_proj", expert.output_linear)  # ty:ignore[possibly-missing-attribute]

        # We need at least one module across all components for abliteration to work.
        total_modules = sum(len(mods) for mods in modules.values())
        assert total_modules > 0, "No abliterable modules found in layer"

        return modules

    def get_abliterable_components(self) -> list[str]:
        """Returns the sorted names of all components found on any layer."""
        # Scan all layers because hybrid models (e.g. Qwen3.5 MoE) have different
        # components on different layers (some have self_attn, others linear_attn).
        components: set[str] = set()
        for layer_index in range(len(self.get_layers())):
            components.update(self.get_layer_modules(layer_index).keys())
        return sorted(components)

    def abliterate(
        self,
        refusal_directions: Tensor | dict[str, Tensor],
        direction_index: float | None | dict[str, float | None],
        parameters: dict[str, AbliterationParameters],
    ):
        """
        Writes a directional-ablation update into the LoRA adapters of abliterable modules.

        Basic case (no row normalization): for a module with weight W and
        direction v, the adapter adds delta W = -weight * v (v^T W).
        The weight falls off linearly:
        - it is max_weight at max_weight_position;
        - it reaches min_weight at a distance of min_weight_distance layers;
        - modules farther away are left untouched.
        The PRE and FULL row-normalization modes adjust this as commented inline.

        Args:
            refusal_directions: Tensor whose element layer_index + 1 is the
                direction for that layer (element 0 belongs to the embeddings),
                or a dict of such tensors keyed by component name.
            direction_index: None to use each layer's own direction. Otherwise,
                a (possibly fractional) index; the directions at neighbouring
                indices are linearly interpolated, renormalized, and used for
                all layers. May be a dict keyed by component name.
            parameters: Weight kernel parameters per component name.
        """
        # Per-component interpolated direction (or None when per-layer
        # directions are used), computed once on first use.
        global_direction_cache: dict[str, Tensor | None] = {}

        # Note that some implementations of abliteration also orthogonalize
        # the embedding matrix, but it's unclear if that has any benefits.

        for layer_index in range(len(self.get_layers())):
            for component, modules in self.get_layer_modules(layer_index).items():
                if isinstance(refusal_directions, dict):
                    component_refusal_directions = refusal_directions[component]
                else:
                    component_refusal_directions = refusal_directions

                if isinstance(direction_index, dict):
                    component_direction_index = direction_index[component]
                else:
                    component_direction_index = direction_index

                if component not in global_direction_cache:
                    if component_direction_index is None:
                        global_direction_cache[component] = None
                    else:
                        # The index must be shifted by 1 because the first element
                        # of refusal_directions is the direction for the embeddings.
                        fraction, index = math.modf(component_direction_index + 1)
                        lower = int(index)
                        # Clamp the upper neighbour so that an index landing on the
                        # last direction does not read past the end of the tensor.
                        # In that case lerp(x, x, t) == x, i.e. the last direction is
                        # used instead of raising IndexError.
                        upper = min(lower + 1, len(component_refusal_directions) - 1)
                        global_direction_cache[component] = F.normalize(
                            component_refusal_directions[lower].lerp(
                                component_refusal_directions[upper],
                                fraction,
                            ),
                            p=2,
                            dim=0,
                        )

                refusal_direction = global_direction_cache[component]

                params = parameters[component]

                # Type inference fails here for some reason.
                distance = cast(float, abs(layer_index - params.max_weight_position))

                # Don't orthogonalize layers that are more than
                # min_weight_distance away from max_weight_position.
                if distance > params.min_weight_distance:
                    continue

                # Interpolate linearly between max_weight and min_weight
                # over min_weight_distance.
                if params.min_weight_distance > 0:
                    weight = params.max_weight + (
                        distance / params.min_weight_distance
                    ) * (params.min_weight - params.max_weight)
                else:
                    # Only reachable with distance == 0 (any larger distance was
                    # skipped above); use the peak weight instead of computing 0/0.
                    weight = params.max_weight

                if refusal_direction is None:
                    # The index must be shifted by 1 because the first element
                    # of refusal_directions is the direction for the embeddings.
                    layer_refusal_direction = component_refusal_directions[layer_index + 1]
                else:
                    layer_refusal_direction = refusal_direction

                for module in modules:
                    # FIXME: This cast is potentially invalid, because the program logic
                    # does not guarantee that the module is of type Linear, and in fact
                    # the retrieved modules might not conform to the interface assumed
                    # below (though they do in practice). However, this is difficult
                    # to fix cleanly, because get_layer_modules is called twice on
                    # different model configurations, and PEFT employs different
                    # module types depending on the chosen quantization.
                    module = cast(Linear, module)

                    # LoRA abliteration: delta W = -lambda * v * (v^T W)
                    # lora_B = -lambda * v
                    # lora_A = v^T W

                    # Use the FP32 refusal direction directly (no downcast/upcast)
                    # and move to the correct device.
                    v = layer_refusal_direction.to(module.weight.device)

                    # Get W (dequantize if necessary).
                    #
                    # FIXME: This cast is valid only under the assumption that the original
                    # module wrapped by the LoRA adapter has a weight attribute.
                    # See the comment above for why this is currently not guaranteed.
                    base_weight = cast(Tensor, module.base_layer.weight)
                    quant_state = getattr(base_weight, "quant_state", None)

                    if quant_state is None:
                        W = base_weight.to(torch.float32)
                    else:
                        if bnb is None:
                            raise ImportError(
                                "bitsandbytes is required for 4-bit model editing but is not installed."
                            )
                        # 4-bit quantization.
                        # This cast is always valid. Type inference fails here because the
                        # bnb.functional module is not found by ty for some reason.
                        W = cast(
                            Tensor,
                            bnb.functional.dequantize_4bit(  # ty:ignore[possibly-missing-attribute]
                                base_weight.data,
                                quant_state,
                            ).to(torch.float32),
                        )

                    # Flatten weight matrix to (out_features, in_features).
                    W = W.view(W.shape[0], -1)

                    if self.settings.row_normalization != RowNormalization.NONE:
                        # Keep a reference to the original weight matrix so we can subtract it later.
                        W_org = W
                        # Get the row norms.
                        W_row_norms = LA.vector_norm(W, dim=1, keepdim=True)
                        # Normalize the weight matrix along the rows.
                        W = F.normalize(W, p=2, dim=1)

                    # Calculate lora_A = v^T W
                    # v is (d_out,), W is (d_out, d_in)
                    # v @ W -> (d_in,)
                    lora_A = (v @ W).view(1, -1)

                    # Calculate lora_B = -weight * v
                    # v is (d_out,)
                    lora_B = (-weight * v).view(-1, 1)

                    if self.settings.row_normalization == RowNormalization.PRE:
                        # Make the LoRA adapter apply to the original weight matrix.
                        lora_B = W_row_norms * lora_B
                    elif self.settings.row_normalization == RowNormalization.FULL:
                        # Approximates https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration
                        W = W + lora_B @ lora_A
                        # Normalize the adjusted weight matrix along the rows.
                        W = F.normalize(W, p=2, dim=1)
                        # Restore the original row norms of the weight matrix.
                        W = W * W_row_norms
                        # Subtract the original matrix to turn W into a delta.
                        W = W - W_org
                        # Use a low-rank SVD to get an approximation of the matrix.
                        r = self.peft_config.r
                        U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6)
                        # Truncate it to the part we want to store in the LoRA adapter.
                        # Note: svd_lowrank actually returns V, so transpose it to get Vh.
                        U = U[:, :r]
                        S = S[:r]
                        Vh = Vh[:, :r].T
                        # Transfer it into the LoRA adapter components. Split the singular values
                        # evenly between the two components to keep their norms balanced and avoid
                        # potential issues with numerical stability.
                        sqrt_S = torch.sqrt(S)
                        lora_B = U @ torch.diag(sqrt_S)
                        lora_A = torch.diag(sqrt_S) @ Vh

                    # Assign to adapters. The adapter name is "default", because that's
                    # what PEFT uses when no name is explicitly specified, as above.
                    # These casts are therefore valid.
                    weight_A = cast(Tensor, module.lora_A["default"].weight)
                    weight_B = cast(Tensor, module.lora_B["default"].weight)
                    weight_A.data = lora_A.to(weight_A.dtype)
                    weight_B.data = lora_B.to(weight_B.dtype)

    def generate(
        self,
        prompts: list[Prompt],
        **kwargs: Any,
    ) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
        """
        Runs greedy generation for a batch of prompts.

        Each prompt is rendered with the tokenizer's chat template (optional
        system message plus user message). If the template raises, the system
        text is merged into the user message instead. self.response_prefix is
        then appended, and the batch is tokenized with padding (left-padded,
        per __init__) and moved to the model's device.

        Args:
            prompts: Prompts to generate for.
            **kwargs: Forwarded to model.generate() (e.g. max_new_tokens). Must
                not contain pad_token_id or do_sample, which are always passed
                explicitly.

        Returns:
            The tokenized inputs and the raw output of model.generate().
        """
        # Standard chat structure.
        chats = []
        for prompt in prompts:
            chat = []
            if prompt.system:
                chat.append({"role": "system", "content": prompt.system})
            chat.append({"role": "user", "content": prompt.user})
            chats.append(chat)

        try:
            # This cast is valid because list[str] is the return type
            # for batched operation with tokenize=False.
            chat_prompts = cast(
                list[str],
                self.tokenizer.apply_chat_template(
                    chats,
                    add_generation_prompt=True,
                    tokenize=False,
                ),
            )
        except Exception:
            # Fallback for models that do not support system roles (e.g. Gemma 2).
            # Merge the system prompt into the first user message.
            chats = []
            for prompt in prompts:
                content = prompt.user
                if prompt.system:
                    content = f"{prompt.system}\n\n{prompt.user}"
                chats.append([{"role": "user", "content": content}])

            chat_prompts = cast(
                list[str],
                self.tokenizer.apply_chat_template(
                    chats,
                    add_generation_prompt=True,
                    tokenize=False,
                ),
            )

        if self.response_prefix:
            # Append the common response prefix to the prompts so that evaluation happens
            # at the point where responses start to differ for different prompts.
            chat_prompts = [prompt + self.response_prefix for prompt in chat_prompts]

        # NOTE(review): the templated text is tokenized with the tokenizer's default
        # add_special_tokens. If a chat template already emits a BOS token, this may
        # duplicate it; confirm for the target models.
        inputs = self.tokenizer(
            chat_prompts,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
        ).to(self.model.device)

        # FIXME: The type checker has been disabled here because of the extremely complex
        # interplay between different generate() signatures and dynamic delegation.
        outputs = self.model.generate(
            **inputs,
            **kwargs,
            pad_token_id=self.tokenizer.pad_token_id,
            do_sample=False,  # Use greedy decoding to ensure deterministic outputs.
        )  # ty:ignore[call-non-callable]

        return inputs, outputs

    def get_responses(
        self,
        prompts: list[Prompt],
        skip_special_tokens: bool = False,
    ) -> list[str]:
        """
        Generates up to settings.max_response_length new tokens per prompt.

        Returns:
            The decoded text of the newly generated tokens only, one string per prompt.
        """
        inputs, outputs = self.generate(
            prompts,
            max_new_tokens=self.settings.max_response_length,
        )

        return self.tokenizer.batch_decode(
            # Extract the newly generated part.
            # This cast is valid because the input_ids property is a Tensor
            # if the tokenizer is invoked with return_tensors="pt", as above.
            outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :],
            skip_special_tokens=skip_special_tokens,
        )

    def get_responses_batched(
        self,
        prompts: list[Prompt],
        skip_special_tokens: bool = False,
    ) -> list[str]:
        """get_responses() applied in batches of settings.batch_size; results keep prompt order."""
        responses = []

        for batch in batchify(prompts, self.settings.batch_size):
            for response in self.get_responses(
                batch,
                skip_special_tokens=skip_special_tokens,
            ):
                responses.append(response)

        return responses

    def get_residuals(self, prompts: list[Prompt]) -> Tensor:
        """
        Returns hidden-state vectors at the last prompt position for every layer.

        Runs a single-token generation and, for every entry of the first step's
        hidden states, takes the vector at the last input position.

        Returns:
            Float32 tensor of shape (prompt, layer, component). If
            settings.winsorization_quantile is in [0, 1), each value is clamped
            to +/- that quantile of the absolute values of its (prompt, layer)
            vector.
        """
        # We only generate one token, and we return the residual vectors
        # at that token position, for each prompt and layer.
        _, outputs = self.generate(
            prompts,
            max_new_tokens=1,
            output_hidden_states=True,
            return_dict_in_generate=True,
        )

        # This cast is valid because GenerateDecoderOnlyOutput is the return type
        # of model.generate with return_dict_in_generate=True.
        outputs = cast(GenerateDecoderOnlyOutput, outputs)

        # Hidden states for the first (only) generated token.
        # This cast is valid because we passed output_hidden_states=True above.
        hidden_states = cast(tuple[tuple[FloatTensor]], outputs.hidden_states)[0]

        # The returned tensor has shape (prompt, layer, component).
        residuals = torch.stack(
            # layer_hidden_states has shape (prompt, position, component),
            # so this extracts the hidden states at the end of each prompt,
            # and stacks them up over the layers.
            [layer_hidden_states[:, -1, :] for layer_hidden_states in hidden_states],
            dim=1,
        )

        # Upcast the data type to avoid precision (bfloat16) or range (float16)
        # problems during calculations involving residual vectors.
        residuals = residuals.to(torch.float32)

        if 0 <= self.settings.winsorization_quantile < 1:
            # Apply symmetric winsorization to each layer of the per-prompt residuals.
            abs_residuals = torch.abs(residuals)
            # Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals.
            thresholds = torch.quantile(
                abs_residuals,
                self.settings.winsorization_quantile,
                dim=2,
                keepdim=True,
            )
            return torch.clamp(residuals, -thresholds, thresholds)

        return residuals

    def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
        """get_residuals() applied in batches of settings.batch_size, concatenated along dim 0."""
        residuals = []

        for batch in batchify(prompts, self.settings.batch_size):
            residuals.append(self.get_residuals(batch))

        return torch.cat(residuals, dim=0)

    # We work with logprobs rather than probabilities for numerical stability
    # when computing the KL divergence.
    def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
        """
        Returns log-probabilities over the vocabulary for the first generated token.

        Returns:
            Tensor of shape (prompt, token): log_softmax of the scores that
            generate() reports for the first (only) generated position.
        """
        # We only generate one token, and we return the (log) probability distributions
        # over the vocabulary at that token position, for each prompt.
        _, outputs = self.generate(
            prompts,
            max_new_tokens=1,
            output_scores=True,
            return_dict_in_generate=True,
        )

        # This cast is valid because GenerateDecoderOnlyOutput is the return type
        # of model.generate with return_dict_in_generate=True.
        outputs = cast(GenerateDecoderOnlyOutput, outputs)

        # Logits for the first (only) generated token.
        # This cast is valid because we passed output_scores=True above.
        logits = cast(tuple[FloatTensor], outputs.scores)[0]

        # The returned tensor has shape (prompt, token).
        return F.log_softmax(logits, dim=-1)

    def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
        """get_logprobs() applied in batches of settings.batch_size, concatenated along dim 0."""
        logprobs = []

        for batch in batchify(prompts, self.settings.batch_size):
            logprobs.append(self.get_logprobs(batch))

        return torch.cat(logprobs, dim=0)

    def stream_chat_response(self, chat: list[dict[str, str]]) -> str:
        """
        Generates a reply to `chat` (at most 4096 new tokens), streaming it via TextStreamer.

        Args:
            chat: Conversation messages as {"role": ..., "content": ...} dicts.

        Returns:
            The decoded newly generated text, with special tokens removed.
        """
        # This cast is valid because str is the return type
        # for single-chat operation with tokenize=False.
        chat_prompt = cast(
            str,
            self.tokenizer.apply_chat_template(
                chat,
                add_generation_prompt=True,
                tokenize=False,
            ),
        )

        inputs = self.tokenizer(
            chat_prompt,
            return_tensors="pt",
            return_token_type_ids=False,
        ).to(self.model.device)

        streamer = TextStreamer(
            # The TextStreamer constructor annotates this parameter with the AutoTokenizer
            # type, which makes no sense because AutoTokenizer is a factory class,
            # not a base class that tokenizers inherit from.
            self.tokenizer,  # ty:ignore[invalid-argument-type]
            skip_prompt=True,
            skip_special_tokens=True,
        )

        # FIXME: The type checker has been disabled here because of the extremely complex
        # interplay between different generate() signatures and dynamic delegation.
        outputs = self.model.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096,
            # Same pad token as generate() uses.
            pad_token_id=self.tokenizer.pad_token_id,
        )  # ty:ignore[call-non-callable]

        # This cast is valid because str is the return type
        # when passing a sequence of token IDs.
        return cast(
            str,
            self.tokenizer.decode(
                outputs[0, inputs["input_ids"].shape[1] :],
                skip_special_tokens=True,
            ),
        )