File size: 3,774 Bytes
2bc8e46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45113e6
2bc8e46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45113e6
2bc8e46
 
 
 
 
 
 
 
 
 
 
45113e6
2bc8e46
 
 
 
 
 
 
 
 
 
 
 
 
 
45113e6
2bc8e46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45113e6
2bc8e46
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Tests for visualization module (non-interactive, save-to-file)."""

from __future__ import annotations

import tempfile
from pathlib import Path

import pytest
import torch

from obliteratus.analysis.cross_layer import CrossLayerAlignmentAnalyzer
from obliteratus.analysis.activation_probing import ActivationProbe
from obliteratus.analysis.visualization import (
    plot_refusal_topology,
    plot_cross_layer_heatmap,
    plot_angular_drift,
    plot_probe_dashboard,
    plot_defense_radar,
)
from obliteratus.analysis.defense_robustness import DefenseProfile


@pytest.fixture
def tmp_dir(tmp_path):
    """Return a per-test scratch directory for plot output.

    Delegates to pytest's built-in ``tmp_path`` fixture, which provides a
    unique ``pathlib.Path`` for each test, instead of managing a
    ``tempfile.TemporaryDirectory`` by hand. pytest retains the temporary
    directories of recent runs, so generated images can be inspected when a
    test fails.
    """
    return tmp_path


def _make_refusal_data(n_layers=6, hidden_dim=16):
    """Create test refusal directions and means."""
    torch.manual_seed(42)
    directions = {}
    harmful_means = {}
    harmless_means = {}

    for i in range(n_layers):
        d = torch.randn(hidden_dim)
        directions[i] = d / d.norm()
        base = torch.randn(hidden_dim)
        harmless_means[i] = base.unsqueeze(0)
        harmful_means[i] = (base + (2.0 if i in [2, 3, 4] else 0.3) * directions[i]).unsqueeze(0)

    strong_layers = [2, 3, 4]
    return directions, harmful_means, harmless_means, strong_layers


class TestRefusalTopology:
    """Smoke tests for ``plot_refusal_topology`` on synthetic refusal data."""

    def test_plot_saves_file(self, tmp_dir):
        """The plot is written to ``output_path`` and the file has content."""
        directions, harmful_means, harmless_means, strong_layers = _make_refusal_data()
        out_file = tmp_dir / "topology.png"

        plot_refusal_topology(
            directions,
            harmful_means,
            harmless_means,
            strong_layers,
            output_path=out_file,
        )

        assert out_file.exists()
        assert out_file.stat().st_size > 0

    def test_plot_returns_figure(self, tmp_dir):
        """The plotting call hands back a figure object rather than ``None``."""
        refusal_data = _make_refusal_data()
        figure = plot_refusal_topology(*refusal_data, output_path=tmp_dir / "test.png")
        assert figure is not None


class TestCrossLayerHeatmap:
    """Smoke test for ``plot_cross_layer_heatmap``."""

    def test_plot_saves_file(self, tmp_dir):
        """Plotting a cross-layer alignment result writes a non-empty image."""
        torch.manual_seed(42)
        # Six layers of random (not normalized) 16-dim direction vectors.
        directions = {i: torch.randn(16) for i in range(6)}
        analyzer = CrossLayerAlignmentAnalyzer()
        result = analyzer.analyze(directions)

        path = tmp_dir / "heatmap.png"
        plot_cross_layer_heatmap(result, output_path=path)
        assert path.exists()
        # exists() alone would also pass for an empty file; require content,
        # matching the check in TestRefusalTopology.
        assert path.stat().st_size > 0


class TestAngularDrift:
    """Smoke test for ``plot_angular_drift``."""

    def test_plot_saves_file(self, tmp_dir):
        """Plotting the drift of a cross-layer analysis writes a non-empty image."""
        torch.manual_seed(42)
        # Eight layers of random (not normalized) 16-dim direction vectors.
        directions = {i: torch.randn(16) for i in range(8)}
        analyzer = CrossLayerAlignmentAnalyzer()
        result = analyzer.analyze(directions)

        path = tmp_dir / "drift.png"
        plot_angular_drift(result, output_path=path)
        assert path.exists()
        # exists() alone would also pass for an empty file; require content,
        # matching the check in TestRefusalTopology.
        assert path.stat().st_size > 0


class TestProbeDashboard:
    """Smoke test for ``plot_probe_dashboard``."""

    def test_plot_saves_file(self, tmp_dir):
        """Plotting an all-layer probe result writes a non-empty image."""
        torch.manual_seed(42)
        # Four layers, each with three random 8-dim activation vectors per class,
        # plus one random (not normalized) 8-dim direction per layer.
        harmful = {i: [torch.randn(8) for _ in range(3)] for i in range(4)}
        harmless = {i: [torch.randn(8) for _ in range(3)] for i in range(4)}
        dirs = {i: torch.randn(8) for i in range(4)}

        probe = ActivationProbe()
        result = probe.probe_all_layers(harmful, harmless, dirs)

        path = tmp_dir / "probe.png"
        plot_probe_dashboard(result, output_path=path)
        assert path.exists()
        # exists() alone would also pass for an empty file; require content,
        # matching the check in TestRefusalTopology.
        assert path.stat().st_size > 0


class TestDefenseRadar:
    """Smoke test for ``plot_defense_radar``."""

    def test_plot_saves_file(self, tmp_dir):
        """Plotting a hand-built ``DefenseProfile`` writes a non-empty image."""
        profile = DefenseProfile(
            model_name="test-model",
            alignment_type_estimate="RLHF-like",
            refusal_concentration=0.4,
            refusal_layer_spread=5,
            mean_refusal_strength=2.0,
            max_refusal_strength=4.0,
            self_repair_estimate=0.6,
            entanglement_score=0.3,
            estimated_robustness="medium",
        )
        path = tmp_dir / "radar.png"
        plot_defense_radar(profile, output_path=path)
        assert path.exists()
        # exists() alone would also pass for an empty file; require content,
        # matching the check in TestRefusalTopology.
        assert path.stat().st_size > 0