"""Tests for visualization module (non-interactive, save-to-file).""" from __future__ import annotations import tempfile from pathlib import Path import pytest import torch from obliteratus.analysis.cross_layer import CrossLayerAlignmentAnalyzer from obliteratus.analysis.activation_probing import ActivationProbe from obliteratus.analysis.visualization import ( plot_refusal_topology, plot_cross_layer_heatmap, plot_angular_drift, plot_probe_dashboard, plot_defense_radar, ) from obliteratus.analysis.defense_robustness import DefenseProfile @pytest.fixture def tmp_dir(): with tempfile.TemporaryDirectory() as d: yield Path(d) def _make_refusal_data(n_layers=6, hidden_dim=16): """Create test refusal directions and means.""" torch.manual_seed(42) directions = {} harmful_means = {} harmless_means = {} for i in range(n_layers): d = torch.randn(hidden_dim) directions[i] = d / d.norm() base = torch.randn(hidden_dim) harmless_means[i] = base.unsqueeze(0) harmful_means[i] = (base + (2.0 if i in [2, 3, 4] else 0.3) * directions[i]).unsqueeze(0) strong_layers = [2, 3, 4] return directions, harmful_means, harmless_means, strong_layers class TestRefusalTopology: def test_plot_saves_file(self, tmp_dir): directions, h_means, b_means, strong = _make_refusal_data() path = tmp_dir / "topology.png" plot_refusal_topology( directions, h_means, b_means, strong, output_path=path ) assert path.exists() assert path.stat().st_size > 0 def test_plot_returns_figure(self, tmp_dir): directions, h_means, b_means, strong = _make_refusal_data() fig = plot_refusal_topology( directions, h_means, b_means, strong, output_path=tmp_dir / "test.png" ) assert fig is not None class TestCrossLayerHeatmap: def test_plot_saves_file(self, tmp_dir): torch.manual_seed(42) directions = {i: torch.randn(16) for i in range(6)} analyzer = CrossLayerAlignmentAnalyzer() result = analyzer.analyze(directions) path = tmp_dir / "heatmap.png" plot_cross_layer_heatmap(result, output_path=path) assert path.exists() class TestAngularDrift: def test_plot_saves_file(self, tmp_dir): torch.manual_seed(42) directions = {i: torch.randn(16) for i in range(8)} analyzer = CrossLayerAlignmentAnalyzer() result = analyzer.analyze(directions) path = tmp_dir / "drift.png" plot_angular_drift(result, output_path=path) assert path.exists() class TestProbeDashboard: def test_plot_saves_file(self, tmp_dir): torch.manual_seed(42) harmful = {i: [torch.randn(8) for _ in range(3)] for i in range(4)} harmless = {i: [torch.randn(8) for _ in range(3)] for i in range(4)} dirs = {i: torch.randn(8) for i in range(4)} probe = ActivationProbe() result = probe.probe_all_layers(harmful, harmless, dirs) path = tmp_dir / "probe.png" plot_probe_dashboard(result, output_path=path) assert path.exists() class TestDefenseRadar: def test_plot_saves_file(self, tmp_dir): profile = DefenseProfile( model_name="test-model", alignment_type_estimate="RLHF-like", refusal_concentration=0.4, refusal_layer_spread=5, mean_refusal_strength=2.0, max_refusal_strength=4.0, self_repair_estimate=0.6, entanglement_score=0.3, estimated_robustness="medium", ) path = tmp_dir / "radar.png" plot_defense_radar(profile, output_path=path) assert path.exists()