File size: 5,492 Bytes
3a1c55b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc67f56
3a1c55b
 
 
bc67f56
 
3a1c55b
 
bc67f56
3a1c55b
 
 
bc67f56
3a1c55b
 
 
 
 
 
bc67f56
 
 
 
3a1c55b
bc67f56
 
 
 
3a1c55b
 
bc67f56
3a1c55b
 
bc67f56
 
 
3a1c55b
bc67f56
3a1c55b
bc67f56
3a1c55b
bc67f56
 
 
 
 
 
3a1c55b
 
 
bc67f56
 
3a1c55b
bc67f56
3a1c55b
bc67f56
3a1c55b
bc67f56
 
 
 
 
 
 
 
 
 
3a1c55b
bc67f56
 
 
 
 
 
 
 
 
 
 
3a1c55b
bc67f56
 
3a1c55b
bc67f56
 
 
 
3a1c55b
bc67f56
 
 
 
 
 
 
3a1c55b
bc67f56
 
 
 
 
 
 
 
 
 
 
 
3a1c55b
bc67f56
3a1c55b
bc67f56
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import stat
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv


@pytest.fixture
def key_generator(tmp_path):
    """Yield-free fixture: an AuthKeyGenerator rooted at ``<tmp_path>/auth_keys``.

    The directory path is passed as a plain string; the generator itself is
    responsible for creating it.
    """
    target_dir = str(tmp_path / "auth_keys")
    generator = AuthKeyGenerator(key_dir=target_dir)
    return generator


class TestAuthKeyGenerator:
    """Tests covering key-directory setup and participant key generation."""

    def test_init_creates_directory(self, tmp_path):
        """Constructing the generator must create its (previously absent) key directory."""
        missing_dir = tmp_path / "new_keys"
        assert not os.path.exists(missing_dir)

        AuthKeyGenerator(key_dir=str(missing_dir))

        assert os.path.exists(missing_dir)

    def test_generate_participant_keys_creates_files_and_returns_openssh_with_comment(
        self, key_generator
    ):
        """
        Generating keys must write both key files to disk and hand back the
        public key as a single OpenSSH line: ``<type> <key> <comment>``.
        """
        pid = "participant_01"
        private_path, public_path, openssh_line = (
            key_generator.generate_participant_keys(pid)
        )

        # Both the private and the public key file must have been written.
        for written_path in (private_path, public_path):
            assert os.path.exists(written_path)

        # POSIX permission bits are meaningless on Windows, so only check elsewhere:
        # the private key must be readable/writable by its owner alone.
        if os.name != "nt":
            assert stat.S_IMODE(os.stat(private_path).st_mode) == 0o600

        # The returned line carries the key type, the key body and the participant comment.
        assert openssh_line.startswith("ecdsa-sha2-nistp384")
        assert openssh_line.endswith(pid)
        fields = openssh_line.split(" ")
        assert len(fields) == 3

        # The public key file on disk must parse as an SSH-format EC public key.
        with open(public_path, "rb") as handle:
            raw_public = handle.read()
        loaded_key = serialization.load_ssh_public_key(raw_public)
        assert isinstance(loaded_key, ec.EllipticCurvePublicKey)


class TestRebuildAuthorizedKeysFile:
    """Test suite for the rebuild_authorized_keys_csv function."""

    # Name of the file the function under test writes inside ``key_dir``.
    CSV_FILENAME = "authorized_supernodes.csv"

    @staticmethod
    def _make_key_dir(tmp_path):
        """Create and return the per-test ``keys_test`` directory under tmp_path."""
        key_dir = tmp_path / "keys_test"
        os.makedirs(key_dir)
        return key_dir

    @classmethod
    def _read_csv(cls, key_dir):
        """Return the raw (unstripped) text of the CSV written into ``key_dir``."""
        csv_path = os.path.join(key_dir, cls.CSV_FILENAME)
        # Explicit encoding so the result does not depend on the platform locale
        # (the suite already accounts for Windows elsewhere).
        with open(csv_path, "r", encoding="utf-8") as f:
            return f.read()

    def test_rebuild_creates_file_with_only_newline_for_empty_list(self, tmp_path):
        """Verify an empty participant list results in a file with just a newline."""
        key_dir = self._make_key_dir(tmp_path)
        rebuild_authorized_keys_csv(str(key_dir), [])

        # Compare unstripped: the lone trailing newline is what is under test.
        assert self._read_csv(key_dir) == "\n"

    def test_rebuild_writes_correct_single_line_format(self, tmp_path):
        """Verify the file is created in the single-line, comma-separated format."""
        key_dir = self._make_key_dir(tmp_path)

        participants = [
            ("p1", "ecdsa-sha2-nistp384 AAAA...key1 p1"),
            ("p2", "ecdsa-sha2-nistp384 AAAA...key2 p2"),
        ]
        rebuild_authorized_keys_csv(str(key_dir), participants)

        expected_content = (
            "ecdsa-sha2-nistp384 AAAA...key1 p1,ecdsa-sha2-nistp384 AAAA...key2 p2"
        )
        assert self._read_csv(key_dir).strip() == expected_content

    def test_rebuild_overwrites_existing_file(self, tmp_path):
        """Verify that an existing file is correctly overwritten."""
        key_dir = self._make_key_dir(tmp_path)

        # Use dummy data that matches the expected OpenSSH format
        initial_participants = [("old_p1", "ecdsa-sha2-nistp384 old_key_1 old_p1")]
        rebuild_authorized_keys_csv(str(key_dir), initial_participants)

        new_participants = [
            ("new_p1", "ecdsa-sha2-nistp384 new_key_1 new_p1"),
            ("new_p2", "ecdsa-sha2-nistp384 new_key_2 new_p2"),
        ]
        rebuild_authorized_keys_csv(str(key_dir), new_participants)

        expected_content = (
            "ecdsa-sha2-nistp384 new_key_1 new_p1,ecdsa-sha2-nistp384 new_key_2 new_p2"
        )
        assert self._read_csv(key_dir).strip() == expected_content

    def test_rebuild_sanitizes_pem_keys_to_ssh_format(self, tmp_path):
        """
        Tests the self-healing capability of the rebuild function to convert
        old PEM keys from the database into the correct OpenSSH format,
        without altering the underlying key material.
        """
        key_dir = self._make_key_dir(tmp_path)

        # Generate a real key pair to get a valid PEM string
        private_key = ec.generate_private_key(ec.SECP384R1())
        public_key = private_key.public_key()
        pem_key = public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        ).decode("utf-8")

        participants = [("p1_pem", pem_key)]
        rebuild_authorized_keys_csv(str(key_dir), participants)

        content = self._read_csv(key_dir).strip()

        # Verify the output is now in the correct OpenSSH format with the comment
        assert content.startswith("ecdsa-sha2-nistp384")
        assert content.endswith("p1_pem")
        # No PEM armour or embedded line breaks may leak into the single-line CSV.
        assert "-----BEGIN" not in content
        assert "\n" not in content

        # Prefix/suffix checks alone would accept a wrong key body: parse the
        # OpenSSH line back and confirm it encodes the very same public key.
        converted_key = serialization.load_ssh_public_key(content.encode("utf-8"))
        assert isinstance(converted_key, ec.EllipticCurvePublicKey)
        assert converted_key.public_numbers() == public_key.public_numbers()