"""Shared pytest fixtures for the Obliteratus test suite.""" from __future__ import annotations from unittest.mock import MagicMock import pytest import torch # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @pytest.fixture def mock_model(): """A minimal mock transformer model. Provides: - model.config with config.num_hidden_layers = 4 - model.named_parameters() returning fake weight tensors """ model = MagicMock() # Config with num_hidden_layers config = MagicMock() config.num_hidden_layers = 4 model.config = config # named_parameters returns fake weight tensors across 4 layers fake_params = [] for layer_idx in range(4): weight = torch.randn(768, 768) fake_params.append((f"model.layers.{layer_idx}.self_attn.q_proj.weight", weight)) fake_params.append((f"model.layers.{layer_idx}.self_attn.v_proj.weight", weight)) fake_params.append((f"model.layers.{layer_idx}.mlp.gate_proj.weight", weight)) model.named_parameters.return_value = fake_params return model @pytest.fixture def mock_tokenizer(): """A minimal mock tokenizer with encode, decode, and apply_chat_template.""" tokenizer = MagicMock() tokenizer.encode.return_value = [1, 2, 3, 4, 5] tokenizer.decode.return_value = "Hello, this is a decoded string." tokenizer.apply_chat_template.return_value = [1, 2, 3, 4, 5, 6, 7] tokenizer.pad_token = "" tokenizer.eos_token = "" return tokenizer @pytest.fixture def refusal_direction(): """A normalized random torch tensor of shape (768,).""" t = torch.randn(768) return t / t.norm() @pytest.fixture def activation_pair(): """A tuple of (harmful_activations, harmless_activations) as random tensors of shape (10, 768).""" harmful_activations = torch.randn(10, 768) harmless_activations = torch.randn(10, 768) return (harmful_activations, harmless_activations) @pytest.fixture def tmp_output_dir(tmp_path): """A clean temporary output directory for test artifacts.""" output_dir = tmp_path / "test_output" output_dir.mkdir() return output_dir