File size: 5,182 Bytes
3a1c55b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import csv
import stat
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv


@pytest.fixture
def key_generator(tmp_path):
    """Provide an AuthKeyGenerator whose key_dir is ``auth_keys`` under tmp_path."""
    return AuthKeyGenerator(key_dir=str(tmp_path / "auth_keys"))


class TestAuthKeyGenerator:
    """Test suite for the AuthKeyGenerator class."""

    def test_init_creates_directory(self, tmp_path):
        """Verify that the key directory is created on initialization."""
        key_dir = tmp_path / "new_keys"
        assert not os.path.exists(key_dir)
        AuthKeyGenerator(key_dir=str(key_dir))
        # isdir (not just exists): the path must be a directory, not a file.
        assert os.path.isdir(key_dir)

    def test_generate_participant_keys_creates_files_and_returns_data(
        self, key_generator
    ):
        """
        Verify that generate_participant_keys writes both key files and returns
        the (private key path, public key path, public key PEM string) tuple.
        """
        participant_id = "participant_01"
        priv_path, pub_path, pub_pem = key_generator.generate_participant_keys(
            participant_id
        )

        # 1. Check that both files exist at the expected locations
        assert os.path.exists(priv_path)
        assert os.path.exists(pub_path)
        assert priv_path == os.path.join(key_generator.key_dir, f"{participant_id}.key")
        assert pub_path == os.path.join(key_generator.key_dir, f"{participant_id}.pub")

        # 2. Check private key file permissions (security check)
        # In non-Windows environments, check for 600 permissions.
        if os.name != "nt":
            file_mode = stat.S_IMODE(os.stat(priv_path).st_mode)
            assert file_mode == 0o600

        # 3. Verify key formats and consistency
        with open(priv_path, "rb") as f:
            private_key = serialization.load_pem_private_key(f.read(), password=None)
        with open(pub_path, "rb") as f:
            public_key = serialization.load_pem_public_key(f.read())

        assert isinstance(private_key, ec.EllipticCurvePrivateKey)
        assert isinstance(public_key, ec.EllipticCurvePublicKey)

        # The .pub file must hold the public half of the key stored in the
        # .key file; comparing public numbers is independent of PEM formatting.
        derived_public_key = private_key.public_key()
        assert public_key.public_numbers() == derived_public_key.public_numbers()

        # Check that the returned PEM string matches the public key derived
        # from the private key on disk.
        pem_from_private = derived_public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        ).decode("utf-8")
        assert pub_pem == pem_from_private


class TestRebuildCSV:
    """Test suite for the rebuild_authorized_keys_csv function."""

    @staticmethod
    def _read_rows(csv_path):
        """Return every row of the CSV file at *csv_path* as a list of lists."""
        # newline="" is what the csv module requires of files given to
        # csv.reader, so newlines embedded in quoted fields are parsed correctly.
        with open(csv_path, "r", newline="") as f:
            return list(csv.reader(f))

    def test_rebuild_csv_creates_file_with_header_for_empty_list(self, tmp_path):
        """Verify a CSV with only a header is created for an empty participant list."""
        key_dir = tmp_path / "csv_test"
        os.makedirs(key_dir)
        csv_path = os.path.join(key_dir, "authorized_supernodes.csv")

        rebuild_authorized_keys_csv(key_dir, [])

        assert os.path.exists(csv_path)
        # Exactly one row: the header, and nothing after it.
        assert self._read_rows(csv_path) == [["participant_id", "public_key_pem"]]

    def test_rebuild_csv_writes_correct_data(self, tmp_path):
        """Verify the CSV is created with the correct participant data."""
        key_dir = tmp_path / "csv_test"
        os.makedirs(key_dir)
        csv_path = os.path.join(key_dir, "authorized_supernodes.csv")

        participants = [
            ("p1", "---BEGIN PUBLIC KEY---...p1...---END PUBLIC KEY---"),
            ("p2", "---BEGIN PUBLIC KEY---...p2...---END PUBLIC KEY---"),
        ]
        rebuild_authorized_keys_csv(key_dir, participants)

        # Compare the whole file so unexpected extra rows are also caught.
        assert self._read_rows(csv_path) == [
            ["participant_id", "public_key_pem"],
            list(participants[0]),
            list(participants[1]),
        ]

    def test_rebuild_csv_overwrites_existing_file(self, tmp_path):
        """Verify that an existing CSV file is correctly overwritten."""
        key_dir = tmp_path / "csv_test"
        os.makedirs(key_dir)
        csv_path = os.path.join(key_dir, "authorized_supernodes.csv")

        # First run with initial data
        initial_participants = [("old_p1", "old_key_1")]
        rebuild_authorized_keys_csv(key_dir, initial_participants)

        # Second run with new data
        new_participants = [("new_p1", "new_key_1"), ("new_p2", "new_key_2")]
        rebuild_authorized_keys_csv(key_dir, new_participants)

        # Skip the header row; only the data rows matter here.
        _header, *rows = self._read_rows(csv_path)

        assert len(rows) == 2
        assert rows[0] == list(new_participants[0])
        assert rows[1] == list(new_participants[1])