Spaces:
Sleeping
Sleeping
File size: 5,492 Bytes
3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import os
import stat
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv
@pytest.fixture
def key_generator(tmp_path):
"""Fixture to create an AuthKeyGenerator instance in a temporary directory."""
key_dir = tmp_path / "auth_keys"
return AuthKeyGenerator(key_dir=str(key_dir))
class TestAuthKeyGenerator:
"""Test suite for the AuthKeyGenerator class."""
def test_init_creates_directory(self, tmp_path):
"""Verify that the key directory is created on initialization."""
key_dir = tmp_path / "new_keys"
assert not os.path.exists(key_dir)
AuthKeyGenerator(key_dir=str(key_dir))
assert os.path.exists(key_dir)
def test_generate_participant_keys_creates_files_and_returns_openssh_with_comment(
self, key_generator
):
"""
Verify the main method generates files and returns the public key
in the correct OpenSSH format including a comment.
"""
participant_id = "participant_01"
priv_path, pub_path, pub_ssh_string = key_generator.generate_participant_keys(
participant_id
)
# 1. Check file existence and permissions
assert os.path.exists(priv_path)
assert os.path.exists(pub_path)
if os.name != "nt":
file_mode = stat.S_IMODE(os.stat(priv_path).st_mode)
assert file_mode == 0o600
# 2. Verify that the returned string has three parts (type, key, comment)
assert pub_ssh_string.startswith("ecdsa-sha2-nistp384")
assert pub_ssh_string.endswith(participant_id)
assert len(pub_ssh_string.split(" ")) == 3
# 3. Verify that the public key file can be loaded as an SSH key
with open(pub_path, "rb") as f:
public_key_from_file = serialization.load_ssh_public_key(f.read())
assert isinstance(public_key_from_file, ec.EllipticCurvePublicKey)
class TestRebuildAuthorizedKeysFile:
"""Test suite for the rebuild_authorized_keys_csv function."""
def test_rebuild_creates_file_with_only_newline_for_empty_list(self, tmp_path):
"""Verify an empty participant list results in a file with just a newline."""
key_dir = tmp_path / "keys_test"
os.makedirs(key_dir)
rebuild_authorized_keys_csv(str(key_dir), [])
csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
with open(csv_path, "r") as f:
content = f.read()
assert content == "\n"
def test_rebuild_writes_correct_single_line_format(self, tmp_path):
"""Verify the file is created in the single-line, comma-separated format."""
key_dir = tmp_path / "keys_test"
os.makedirs(key_dir)
participants = [
("p1", "ecdsa-sha2-nistp384 AAAA...key1 p1"),
("p2", "ecdsa-sha2-nistp384 AAAA...key2 p2"),
]
rebuild_authorized_keys_csv(str(key_dir), participants)
csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
with open(csv_path, "r") as f:
content = f.read().strip()
expected_content = (
"ecdsa-sha2-nistp384 AAAA...key1 p1,ecdsa-sha2-nistp384 AAAA...key2 p2"
)
assert content == expected_content
def test_rebuild_overwrites_existing_file(self, tmp_path):
"""Verify that an existing file is correctly overwritten."""
key_dir = tmp_path / "keys_test"
os.makedirs(key_dir)
# Use dummy data that matches the expected OpenSSH format
initial_participants = [("old_p1", "ecdsa-sha2-nistp384 old_key_1 old_p1")]
rebuild_authorized_keys_csv(str(key_dir), initial_participants)
new_participants = [
("new_p1", "ecdsa-sha2-nistp384 new_key_1 new_p1"),
("new_p2", "ecdsa-sha2-nistp384 new_key_2 new_p2"),
]
rebuild_authorized_keys_csv(str(key_dir), new_participants)
csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
with open(csv_path, "r") as f:
content = f.read().strip()
expected_content = (
"ecdsa-sha2-nistp384 new_key_1 new_p1,ecdsa-sha2-nistp384 new_key_2 new_p2"
)
assert content == expected_content
def test_rebuild_sanitizes_pem_keys_to_ssh_format(self, tmp_path):
"""
Tests the self-healing capability of the rebuild function to convert
old PEM keys from the database into the correct OpenSSH format.
"""
key_dir = tmp_path / "keys_test"
os.makedirs(key_dir)
# Generate a real key pair to get a valid PEM string
private_key = ec.generate_private_key(ec.SECP384R1())
public_key = private_key.public_key()
pem_key = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
).decode("utf-8")
participants = [("p1_pem", pem_key)]
rebuild_authorized_keys_csv(str(key_dir), participants)
csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
with open(csv_path, "r") as f:
content = f.read().strip()
# Verify the output is now in the correct OpenSSH format with the comment
assert content.startswith("ecdsa-sha2-nistp384")
assert content.endswith("p1_pem")
|