# obliteratus/tests/conftest.py
# Uploaded by pliny-the-prompter ("Upload 127 files"), commit 45113e6 (verified).
"""Shared pytest fixtures for the Obliteratus test suite."""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
import torch
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_model():
    """Build a minimal mock transformer model.

    Provides:
      - ``model.config.num_hidden_layers`` set to 4
      - ``model.named_parameters()`` returning a list of ``(name, tensor)``
        pairs: the q_proj, v_proj and gate_proj weights of every layer, each
        an independent random ``(768, 768)`` tensor.
    """
    num_layers = 4
    hidden_size = 768
    param_suffixes = (
        "self_attn.q_proj.weight",
        "self_attn.v_proj.weight",
        "mlp.gate_proj.weight",
    )

    model = MagicMock()
    config = MagicMock()
    config.num_hidden_layers = num_layers
    model.config = config

    # One fresh tensor per parameter name: sharing a single tensor between
    # names would let an in-place edit of one projection silently change
    # the others.
    model.named_parameters.return_value = [
        (f"model.layers.{layer_idx}.{suffix}", torch.randn(hidden_size, hidden_size))
        for layer_idx in range(num_layers)
        for suffix in param_suffixes
    ]
    return model
@pytest.fixture
def mock_tokenizer():
    """Return a stand-in tokenizer with canned results.

    ``encode``, ``decode`` and ``apply_chat_template`` each return a fixed
    value regardless of arguments; ``pad_token`` and ``eos_token`` are plain
    string attributes.
    """
    stub = MagicMock()
    stub.configure_mock(
        pad_token="<pad>",
        eos_token="<eos>",
        **{
            "encode.return_value": list(range(1, 6)),
            "decode.return_value": "Hello, this is a decoded string.",
            "apply_chat_template.return_value": list(range(1, 8)),
        },
    )
    return stub
@pytest.fixture
def refusal_direction():
    """Return a random tensor of shape (768,) scaled to unit L2 norm."""
    raw = torch.randn(768)
    length = raw.norm()
    return raw / length
@pytest.fixture
def activation_pair():
    """Return ``(harmful, harmless)``: two random tensors of shape (10, 768)."""
    shape = (10, 768)
    # The harmful sample is drawn first, then the harmless one.
    harmful = torch.randn(*shape)
    harmless = torch.randn(*shape)
    return harmful, harmless
@pytest.fixture
def tmp_output_dir(tmp_path):
    """Create and return an empty ``test_output`` directory under ``tmp_path``."""
    target = tmp_path.joinpath("test_output")
    target.mkdir()
    return target