Spaces:
Sleeping
Sleeping
File size: 5,182 Bytes
3a1c55b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import os
import csv
import stat
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv
@pytest.fixture
def key_generator(tmp_path):
    """Provide an AuthKeyGenerator whose key store is ``<tmp_path>/auth_keys``.

    The directory sits under pytest's per-test ``tmp_path``, so every test
    that requests this fixture starts from a fresh key store.
    """
    return AuthKeyGenerator(key_dir=str(tmp_path / "auth_keys"))
class TestAuthKeyGenerator:
    """Test suite for the AuthKeyGenerator class."""

    def test_init_creates_directory(self, tmp_path):
        """Verify that the key directory is created on initialization."""
        key_dir = tmp_path / "new_keys"
        assert not os.path.exists(key_dir)
        AuthKeyGenerator(key_dir=str(key_dir))
        assert os.path.exists(key_dir)

    def test_generate_participant_keys_creates_files_and_returns_data(
        self, key_generator
    ):
        """
        Verify that the main method generates all expected files and returns
        the correct data tuple.

        Checks that the private/public key files exist at the expected paths,
        that the private key file is owner-only (0o600) on non-Windows
        systems, that both files hold PEM-encoded EC keys, and that the
        returned PEM, the ``.pub`` file and the private key all describe the
        same public key.
        """
        participant_id = "participant_01"
        priv_path, pub_path, pub_pem = key_generator.generate_participant_keys(
            participant_id
        )

        # 1. Files exist at the expected locations inside the key directory.
        assert os.path.exists(priv_path)
        assert os.path.exists(pub_path)
        assert priv_path == os.path.join(key_generator.key_dir, f"{participant_id}.key")
        assert pub_path == os.path.join(key_generator.key_dir, f"{participant_id}.pub")

        # 2. Private key file permissions (security check). POSIX permission
        # bits are not meaningful on Windows, so the check is skipped there.
        if os.name != "nt":
            file_mode = stat.S_IMODE(os.stat(priv_path).st_mode)
            assert file_mode == 0o600

        # 3. Both files parse as PEM EC keys; password=None means the private
        # key must have been written unencrypted.
        with open(priv_path, "rb") as f:
            private_key = serialization.load_pem_private_key(f.read(), password=None)
        with open(pub_path, "rb") as f:
            public_key = serialization.load_pem_public_key(f.read())
        assert isinstance(private_key, ec.EllipticCurvePrivateKey)
        assert isinstance(public_key, ec.EllipticCurvePublicKey)

        # 4. Consistency: serialize keys the same way so they can be compared
        # as text.
        def to_pem(key):
            return key.public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo,
            ).decode("utf-8")

        pem_from_private = to_pem(private_key.public_key())
        # The returned PEM string matches the private key's public half...
        assert pub_pem == pem_from_private
        # ...and so does the key stored in the .pub file, i.e. the pair on
        # disk actually belongs together.
        assert to_pem(public_key) == pem_from_private
class TestRebuildCSV:
    """Test suite for the rebuild_authorized_keys_csv function."""

    # Header row every generated CSV is expected to start with.
    HEADER = ["participant_id", "public_key_pem"]

    @staticmethod
    def _make_key_dir(tmp_path):
        """Create an empty key directory; return it and the expected CSV path."""
        key_dir = tmp_path / "csv_test"
        os.makedirs(key_dir)
        return key_dir, os.path.join(key_dir, "authorized_supernodes.csv")

    @staticmethod
    def _read_rows(csv_path):
        """Return every row of *csv_path*, header included, as lists of str."""
        # newline="" is what the csv module documents for files it reads, so
        # the reader itself handles line endings and quoted newlines.
        with open(csv_path, "r", newline="") as f:
            return list(csv.reader(f))

    def test_rebuild_csv_creates_file_with_header_for_empty_list(self, tmp_path):
        """Verify a CSV with only a header is created for an empty participant list."""
        key_dir, csv_path = self._make_key_dir(tmp_path)

        rebuild_authorized_keys_csv(key_dir, [])

        assert os.path.exists(csv_path)
        # Exactly one row: the header, with no participant rows after it.
        assert self._read_rows(csv_path) == [self.HEADER]

    def test_rebuild_csv_writes_correct_data(self, tmp_path):
        """Verify the CSV is created with the correct participant data."""
        key_dir, csv_path = self._make_key_dir(tmp_path)
        participants = [
            ("p1", "---BEGIN PUBLIC KEY---...p1...---END PUBLIC KEY---"),
            ("p2", "---BEGIN PUBLIC KEY---...p2...---END PUBLIC KEY---"),
        ]

        rebuild_authorized_keys_csv(key_dir, participants)

        # Compare the whole file so unexpected extra rows also fail the test.
        expected = [self.HEADER] + [list(p) for p in participants]
        assert self._read_rows(csv_path) == expected

    def test_rebuild_csv_overwrites_existing_file(self, tmp_path):
        """Verify that an existing CSV file is correctly overwritten."""
        key_dir, csv_path = self._make_key_dir(tmp_path)

        # First run with initial data; confirm it was written so the
        # overwrite check below is meaningful.
        initial_participants = [("old_p1", "old_key_1")]
        rebuild_authorized_keys_csv(key_dir, initial_participants)
        assert self._read_rows(csv_path) == [self.HEADER] + [
            list(p) for p in initial_participants
        ]

        # Second run with new data must replace, not append to, the old rows.
        new_participants = [("new_p1", "new_key_1"), ("new_p2", "new_key_2")]
        rebuild_authorized_keys_csv(key_dir, new_participants)

        assert self._read_rows(csv_path) == [self.HEADER] + [
            list(p) for p in new_participants
        ]
|