# BlossomTune-Orchestrator / tests/test_auth_keys.py
# mrs83's picture
# Add Elliptic Curve (EC) Authentication (#21)
# bc67f56 unverified
# raw
# history blame
# 5.49 kB
import os
import stat
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv
@pytest.fixture
def key_generator(tmp_path):
    """Provide an AuthKeyGenerator whose key_dir is ``<tmp_path>/auth_keys``."""
    target_dir = str(tmp_path / "auth_keys")
    generator = AuthKeyGenerator(key_dir=target_dir)
    return generator
class TestAuthKeyGenerator:
    """Tests covering the AuthKeyGenerator class."""

    def test_init_creates_directory(self, tmp_path):
        """Constructing the generator must create a key directory that was missing."""
        target = tmp_path / "new_keys"
        assert not target.exists()

        AuthKeyGenerator(key_dir=str(target))

        assert target.exists()

    def test_generate_participant_keys_creates_files_and_returns_openssh_with_comment(
        self, key_generator
    ):
        """
        generate_participant_keys must write both key files and return the
        public key as an OpenSSH string that ends with the participant id.
        """
        participant_id = "participant_01"
        generated = key_generator.generate_participant_keys(participant_id)
        priv_path, pub_path, pub_ssh_string = generated

        # 1. Both key files exist; the private key is owner read/write only.
        for key_path in (priv_path, pub_path):
            assert os.path.exists(key_path)
        # The permission check is skipped when os.name reports "nt".
        if os.name != "nt":
            mode_bits = os.stat(priv_path).st_mode
            assert stat.S_IMODE(mode_bits) == 0o600

        # 2. The returned string is "<type> <key> <comment>".
        assert pub_ssh_string.startswith("ecdsa-sha2-nistp384")
        assert pub_ssh_string.endswith(participant_id)
        fields = pub_ssh_string.split(" ")
        assert len(fields) == 3

        # 3. The public key file on disk parses as an SSH elliptic-curve key.
        with open(pub_path, "rb") as pub_file:
            raw_public = pub_file.read()
        loaded_key = serialization.load_ssh_public_key(raw_public)
        assert isinstance(loaded_key, ec.EllipticCurvePublicKey)
class TestRebuildAuthorizedKeysFile:
    """Tests covering the rebuild_authorized_keys_csv function."""

    @staticmethod
    def _make_key_dir(tmp_path):
        """Create and return the ``keys_test`` directory under *tmp_path*."""
        key_dir = tmp_path / "keys_test"
        key_dir.mkdir()
        return key_dir

    @staticmethod
    def _read_csv(key_dir):
        """Return the full text of ``authorized_supernodes.csv`` in *key_dir*."""
        return (key_dir / "authorized_supernodes.csv").read_text()

    def test_rebuild_creates_file_with_only_newline_for_empty_list(self, tmp_path):
        """An empty participant list must yield a file holding a single newline."""
        key_dir = self._make_key_dir(tmp_path)

        rebuild_authorized_keys_csv(str(key_dir), [])

        assert self._read_csv(key_dir) == "\n"

    def test_rebuild_writes_correct_single_line_format(self, tmp_path):
        """All keys must land on one line, separated by commas."""
        key_dir = self._make_key_dir(tmp_path)
        participants = [
            ("p1", "ecdsa-sha2-nistp384 AAAA...key1 p1"),
            ("p2", "ecdsa-sha2-nistp384 AAAA...key2 p2"),
        ]

        rebuild_authorized_keys_csv(str(key_dir), participants)

        expected_content = (
            "ecdsa-sha2-nistp384 AAAA...key1 p1,ecdsa-sha2-nistp384 AAAA...key2 p2"
        )
        written = self._read_csv(key_dir).strip()
        assert written == expected_content

    def test_rebuild_overwrites_existing_file(self, tmp_path):
        """A second rebuild must replace the contents left by the first."""
        key_dir = self._make_key_dir(tmp_path)
        dir_arg = str(key_dir)

        # Dummy entries shaped like OpenSSH public keys: "<type> <key> <comment>".
        initial_participants = [("old_p1", "ecdsa-sha2-nistp384 old_key_1 old_p1")]
        new_participants = [
            ("new_p1", "ecdsa-sha2-nistp384 new_key_1 new_p1"),
            ("new_p2", "ecdsa-sha2-nistp384 new_key_2 new_p2"),
        ]

        rebuild_authorized_keys_csv(dir_arg, initial_participants)
        rebuild_authorized_keys_csv(dir_arg, new_participants)

        expected_content = (
            "ecdsa-sha2-nistp384 new_key_1 new_p1,ecdsa-sha2-nistp384 new_key_2 new_p2"
        )
        written = self._read_csv(key_dir).strip()
        assert written == expected_content

    def test_rebuild_sanitizes_pem_keys_to_ssh_format(self, tmp_path):
        """
        A PEM-encoded public key supplied in the participant list must be
        written out as an OpenSSH entry that ends with the participant id.
        """
        key_dir = self._make_key_dir(tmp_path)

        # Build a genuine P-384 key so the PEM text handed over is valid.
        public_key = ec.generate_private_key(ec.SECP384R1()).public_key()
        pem_bytes = public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        pem_key = pem_bytes.decode("utf-8")

        rebuild_authorized_keys_csv(str(key_dir), [("p1_pem", pem_key)])

        # The stored entry carries the OpenSSH key type first and the id last.
        written = self._read_csv(key_dir).strip()
        assert written.startswith("ecdsa-sha2-nistp384")
        assert written.endswith("p1_pem")