BlossomTune-Orchestrator / tests /test_auth_keys.py
mrs83's picture
TLS certs and external superlink support (plus initial work on authentication keys and .blossomfile) (#3)
3a1c55b unverified
raw
history blame
5.18 kB
import os
import csv
import stat
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv
@pytest.fixture
def key_generator(tmp_path):
    """Provide an AuthKeyGenerator whose key_dir is "auth_keys" under tmp_path."""
    # The generator is handed a plain string path rather than a Path object.
    return AuthKeyGenerator(key_dir=str(tmp_path / "auth_keys"))
class TestAuthKeyGenerator:
    """Test suite for the AuthKeyGenerator class."""

    def test_init_creates_directory(self, tmp_path):
        """Verify that the key directory is created on initialization."""
        key_dir = tmp_path / "new_keys"
        assert not os.path.exists(key_dir)

        AuthKeyGenerator(key_dir=str(key_dir))

        assert os.path.exists(key_dir)

    def test_generate_participant_keys_creates_files_and_returns_data(
        self, key_generator
    ):
        """
        Verify that the main method generates all expected files and returns
        the correct data tuple.

        The call is expected to return ``(private_key_path, public_key_path,
        public_key_pem)``; the test checks the file locations, the private
        key's permissions, the key types, and that the stored and returned
        public keys all belong to the stored private key.
        """
        participant_id = "participant_01"
        priv_path, pub_path, pub_pem = key_generator.generate_participant_keys(
            participant_id
        )

        # 1. Check that both files exist at the expected locations.
        assert os.path.exists(priv_path)
        assert os.path.exists(pub_path)
        assert priv_path == os.path.join(key_generator.key_dir, f"{participant_id}.key")
        assert pub_path == os.path.join(key_generator.key_dir, f"{participant_id}.pub")

        # 2. Check private key file permissions (security check).
        # POSIX mode bits are not meaningful on Windows, so only check for
        # 600 permissions in non-Windows environments.
        if os.name != "nt":
            file_mode = stat.S_IMODE(os.stat(priv_path).st_mode)
            assert file_mode == 0o600

        # 3. Verify key formats and consistency.
        # Each file is read exactly once; a second read() on an exhausted
        # handle would only ever return b"".
        with open(priv_path, "rb") as f:
            private_key = serialization.load_pem_private_key(f.read(), password=None)
        with open(pub_path, "rb") as f:
            public_key = serialization.load_pem_public_key(f.read())

        assert isinstance(private_key, ec.EllipticCurvePrivateKey)
        assert isinstance(public_key, ec.EllipticCurvePublicKey)

        # The key stored in the .pub file must be the public half of the key
        # stored in the .key file.
        derived_public_key = private_key.public_key()
        assert public_key.public_numbers() == derived_public_key.public_numbers()

        # Check that the returned PEM string matches the public key derived
        # from the stored private key.
        pem_from_private = derived_public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        ).decode("utf-8")
        assert pub_pem == pem_from_private
class TestRebuildCSV:
    """Test suite for the rebuild_authorized_keys_csv function."""

    @staticmethod
    def _read_rows(csv_path):
        """Return every row of the CSV file at ``csv_path`` as a list of lists."""
        # The csv module documentation asks for newline="" so that the reader,
        # not the text layer, interprets line endings; this keeps newlines
        # embedded in quoted fields intact.
        with open(csv_path, "r", newline="") as f:
            return list(csv.reader(f))

    def test_rebuild_csv_creates_file_with_header_for_empty_list(self, tmp_path):
        """Verify a CSV with only a header is created for an empty participant list."""
        key_dir = tmp_path / "csv_test"
        os.makedirs(key_dir)
        csv_path = os.path.join(key_dir, "authorized_supernodes.csv")

        rebuild_authorized_keys_csv(key_dir, [])

        assert os.path.exists(csv_path)
        # The header must be the one and only row in the file.
        assert self._read_rows(csv_path) == [["participant_id", "public_key_pem"]]

    def test_rebuild_csv_writes_correct_data(self, tmp_path):
        """Verify the CSV is created with the correct participant data."""
        key_dir = tmp_path / "csv_test"
        os.makedirs(key_dir)
        csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
        participants = [
            ("p1", "---BEGIN PUBLIC KEY---...p1...---END PUBLIC KEY---"),
            ("p2", "---BEGIN PUBLIC KEY---...p2...---END PUBLIC KEY---"),
        ]

        rebuild_authorized_keys_csv(key_dir, participants)

        # Compare the whole file so that unexpected extra rows are caught too.
        # csv.reader yields lists, hence the tuple -> list conversion.
        expected_rows = [["participant_id", "public_key_pem"]]
        expected_rows.extend(list(entry) for entry in participants)
        assert self._read_rows(csv_path) == expected_rows

    def test_rebuild_csv_overwrites_existing_file(self, tmp_path):
        """Verify that an existing CSV file is correctly overwritten."""
        key_dir = tmp_path / "csv_test"
        os.makedirs(key_dir)
        csv_path = os.path.join(key_dir, "authorized_supernodes.csv")

        # First run with initial data
        initial_participants = [("old_p1", "old_key_1")]
        rebuild_authorized_keys_csv(key_dir, initial_participants)

        # Second run with new data
        new_participants = [("new_p1", "new_key_1"), ("new_p2", "new_key_2")]
        rebuild_authorized_keys_csv(key_dir, new_participants)

        # Skip the header row; only the data rows are checked here.
        rows = self._read_rows(csv_path)[1:]
        assert len(rows) == 2
        assert rows[0] == list(new_participants[0])
        assert rows[1] == list(new_participants[1])