File size: 7,503 Bytes
a5784e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os
import sys
from unittest.mock import MagicMock, mock_open, patch

import pytest

# Add project root to sys.path
sys.path.append(os.getcwd())

from stream.cert_manager import CertificateManager


class TestCertificateManager:
    @pytest.fixture
    def mock_path(self):
        with patch("stream.cert_manager.Path") as mock:
            yield mock

    @pytest.fixture
    def mock_rsa(self):
        with patch("stream.cert_manager.rsa") as mock:
            yield mock

    @pytest.fixture
    def mock_x509(self):
        with patch("stream.cert_manager.x509") as mock:
            yield mock

    @pytest.fixture
    def mock_serialization(self):
        with patch("stream.cert_manager.serialization") as mock:
            yield mock

    def test_init_creates_dir(self, mock_path):
        """Test certificate manager creates certificate directory on init."""
        mock_dir = MagicMock()
        mock_path.return_value = mock_dir

        # Setup exists to return True so we don't trigger generation in init
        mock_dir.__truediv__.return_value.exists.return_value = True

        # We also need to mock _load_ca_cert calls to open/load
        with (
            patch("builtins.open", mock_open(read_data=b"data")),
            patch("stream.cert_manager.serialization.load_pem_private_key"),
            patch("stream.cert_manager.x509.load_pem_x509_certificate"),
        ):
            CertificateManager("test_certs")

            mock_path.assert_called_with("test_certs")
            mock_dir.mkdir.assert_called_with(exist_ok=True)

    def test_init_generates_ca_if_missing(self, mock_path, mock_rsa, mock_x509):
        """Test CA certificate generation when missing on init."""
        mock_dir = MagicMock()
        mock_path.return_value = mock_dir

        # Setup exists to return False for CA cert
        ca_cert_path = MagicMock()
        ca_cert_path.exists.return_value = False

        # Setup side_effect for division to handle paths
        # CA paths return the mock that has exists=False
        mock_dir.__truediv__.return_value = ca_cert_path

        # Mock key generation
        mock_key = MagicMock()
        mock_rsa.generate_private_key.return_value = mock_key
        mock_key.private_bytes.return_value = b"private_key_bytes"
        mock_key.public_key.return_value = MagicMock()

        # Mock builder
        mock_builder = MagicMock()
        mock_x509.CertificateBuilder.return_value = mock_builder
        mock_builder.subject_name.return_value = mock_builder
        mock_builder.issuer_name.return_value = mock_builder
        mock_builder.public_key.return_value = mock_builder
        mock_builder.serial_number.return_value = mock_builder
        mock_builder.not_valid_before.return_value = mock_builder
        mock_builder.not_valid_after.return_value = mock_builder
        mock_builder.add_extension.return_value = mock_builder

        mock_cert = MagicMock()
        mock_builder.sign.return_value = mock_cert
        mock_cert.public_bytes.return_value = b"cert_bytes"

        with (
            patch("builtins.open", mock_open()) as mocked_file,
            patch("stream.cert_manager.serialization.load_pem_private_key"),
            patch("stream.cert_manager.x509.load_pem_x509_certificate"),
        ):
            CertificateManager("test_certs")

            # Check if generate_private_key was called
            mock_rsa.generate_private_key.assert_called_once()

            # Check if files were written (key and cert)
            # We expect multiple calls to open (write key, write cert, read key, read cert)
            assert mocked_file.call_count >= 2

    def test_get_domain_cert_existing(self, mock_path):
        """Test retrieving existing domain certificate from cache."""
        mock_dir = MagicMock()
        mock_path.return_value = mock_dir

        # Setup paths
        # We need specific behavior for different paths
        ca_path_mock = MagicMock()
        ca_path_mock.exists.return_value = True

        domain_path_mock = MagicMock()
        domain_path_mock.exists.return_value = True

        def truediv_side_effect(arg):
            # If we divide by domain.crt or domain.key, return domain_path_mock
            if "example.com" in str(arg):
                return domain_path_mock
            return ca_path_mock

        mock_dir.__truediv__.side_effect = truediv_side_effect

        with (
            patch("builtins.open", mock_open(read_data=b"data")),
            patch(
                "stream.cert_manager.serialization.load_pem_private_key"
            ) as mock_load_key,
            patch(
                "stream.cert_manager.x509.load_pem_x509_certificate"
            ) as mock_load_cert,
        ):
            mock_load_key.return_value = "mock_key"
            mock_load_cert.return_value = "mock_cert"

            manager = CertificateManager("test_certs")
            key, cert = manager.get_domain_cert("example.com")

            assert key == "mock_key"
            assert cert == "mock_cert"

    def test_get_domain_cert_generates_new(self, mock_path, mock_rsa, mock_x509):
        """Test generating new domain certificate when not cached."""
        mock_dir = MagicMock()
        mock_path.return_value = mock_dir

        # CA exists, domain does not
        ca_path_mock = MagicMock()
        ca_path_mock.exists.return_value = True

        domain_path_mock = MagicMock()
        domain_path_mock.exists.return_value = False

        def truediv_side_effect(arg):
            if "example.com" in str(arg):
                return domain_path_mock
            return ca_path_mock

        mock_dir.__truediv__.side_effect = truediv_side_effect

        # Mock key generation
        mock_key = MagicMock()
        mock_rsa.generate_private_key.return_value = mock_key
        mock_key.private_bytes.return_value = b"domain_private_key"
        mock_key.public_key.return_value = MagicMock()

        # Mock builder for domain cert
        mock_builder = MagicMock()
        mock_x509.CertificateBuilder.return_value = mock_builder
        mock_builder.subject_name.return_value = mock_builder
        mock_builder.issuer_name.return_value = mock_builder
        mock_builder.public_key.return_value = mock_builder
        mock_builder.serial_number.return_value = mock_builder
        mock_builder.not_valid_before.return_value = mock_builder
        mock_builder.not_valid_after.return_value = mock_builder
        mock_builder.add_extension.return_value = mock_builder

        mock_cert = MagicMock()
        mock_builder.sign.return_value = mock_cert
        mock_cert.public_bytes.return_value = b"domain_cert_bytes"

        # We need CA cert loaded to sign
        mock_ca_cert = MagicMock()

        with (
            patch("builtins.open", mock_open(read_data=b"ca_data")),
            patch("stream.cert_manager.serialization.load_pem_private_key"),
            patch(
                "stream.cert_manager.x509.load_pem_x509_certificate"
            ) as mock_load_cert,
        ):
            mock_load_cert.return_value = mock_ca_cert

            manager = CertificateManager("test_certs")

            # Call get_domain_cert
            key, cert = manager.get_domain_cert("example.com")

            # Verify generation
            # 1 call to generate_private_key (for domain, CA was loaded)
            assert mock_rsa.generate_private_key.call_count == 1

            assert key == mock_key
            assert cert == mock_cert