Spaces:
Sleeping
Sleeping
| import os | |
| import csv | |
| import stat | |
| import pytest | |
| from cryptography.hazmat.primitives import serialization | |
| from cryptography.hazmat.primitives.asymmetric import ec | |
| from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv | |
@pytest.fixture
def key_generator(tmp_path):
    """Fixture to create an AuthKeyGenerator instance in a temporary directory.

    The generator is pointed at ``<tmp_path>/auth_keys``; ``tmp_path`` is
    pytest's built-in per-test temporary directory, so every test that
    requests this fixture gets its own isolated key directory.

    Returns:
        An ``AuthKeyGenerator`` constructed with ``key_dir`` given as a string.
    """
    # Registered with @pytest.fixture so tests can request it by parameter
    # name; without the decorator pytest reports "fixture not found".
    key_dir = tmp_path / "auth_keys"
    return AuthKeyGenerator(key_dir=str(key_dir))
class TestAuthKeyGenerator:
    """Test suite for the AuthKeyGenerator class."""

    def test_init_creates_directory(self, tmp_path):
        """Verify that the key directory is created on initialization."""
        key_dir = tmp_path / "new_keys"
        # Precondition: the directory must not exist before construction,
        # otherwise the final assertion proves nothing.
        assert not os.path.exists(key_dir)

        AuthKeyGenerator(key_dir=str(key_dir))

        assert os.path.exists(key_dir)

    def test_generate_participant_keys_creates_files_and_returns_data(
        self, key_generator
    ):
        """
        Verify that the main method generates all expected files and returns
        the correct data tuple.

        The returned tuple is unpacked as
        ``(private_key_path, public_key_path, public_key_pem)``.
        """
        participant_id = "participant_01"
        priv_path, pub_path, pub_pem = key_generator.generate_participant_keys(
            participant_id
        )

        # 1. Check that both files exist at the expected locations.
        assert os.path.exists(priv_path)
        assert os.path.exists(pub_path)
        assert priv_path == os.path.join(key_generator.key_dir, f"{participant_id}.key")
        assert pub_path == os.path.join(key_generator.key_dir, f"{participant_id}.pub")

        # 2. Check private key file permissions (security check).
        # In non-Windows environments, check for 600 permissions.
        if os.name != "nt":
            file_mode = stat.S_IMODE(os.stat(priv_path).st_mode)
            assert file_mode == 0o600

        # 3. Verify key formats: both files must parse as PEM-encoded
        # elliptic-curve keys (the private key is stored unencrypted).
        with open(priv_path, "rb") as f:
            private_key = serialization.load_pem_private_key(f.read(), password=None)
        with open(pub_path, "rb") as f:
            public_key = serialization.load_pem_public_key(f.read())

        assert isinstance(private_key, ec.EllipticCurvePrivateKey)
        assert isinstance(public_key, ec.EllipticCurvePublicKey)

        # 4. Check that the returned PEM string matches the public key
        # derived from the private key written to disk.
        generated_public_key = private_key.public_key()
        pem_from_private = generated_public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        ).decode("utf-8")
        assert pub_pem == pem_from_private
class TestRebuildCSV:
    """Test suite for the rebuild_authorized_keys_csv function."""

    # File name the tests expect rebuild_authorized_keys_csv to write inside
    # the key directory, and the header row expected on its first line.
    CSV_NAME = "authorized_supernodes.csv"
    HEADER = ["participant_id", "public_key_pem"]

    def _make_key_dir(self, tmp_path):
        """Create an empty key directory and return ``(key_dir, csv_path)``."""
        key_dir = tmp_path / "csv_test"
        os.makedirs(key_dir)
        return key_dir, os.path.join(key_dir, self.CSV_NAME)

    @staticmethod
    def _read_rows(csv_path):
        """Return every row of the CSV at *csv_path* as a list of lists."""
        # newline="" lets the csv module do its own newline handling, as the
        # csv documentation requires for files passed to csv.reader.
        with open(csv_path, "r", newline="") as f:
            return list(csv.reader(f))

    def test_rebuild_csv_creates_file_with_header_for_empty_list(self, tmp_path):
        """Verify a CSV with only a header is created for an empty participant list."""
        key_dir, csv_path = self._make_key_dir(tmp_path)

        rebuild_authorized_keys_csv(key_dir, [])

        assert os.path.exists(csv_path)
        # Exactly one row: the header, with no data rows after it.
        assert self._read_rows(csv_path) == [self.HEADER]

    def test_rebuild_csv_writes_correct_data(self, tmp_path):
        """Verify the CSV is created with the correct participant data."""
        key_dir, csv_path = self._make_key_dir(tmp_path)
        participants = [
            ("p1", "---BEGIN PUBLIC KEY---...p1...---END PUBLIC KEY---"),
            ("p2", "---BEGIN PUBLIC KEY---...p2...---END PUBLIC KEY---"),
        ]

        rebuild_authorized_keys_csv(key_dir, participants)

        # Header first, then one row per participant in input order.
        rows = self._read_rows(csv_path)
        assert rows[:3] == [
            self.HEADER,
            list(participants[0]),
            list(participants[1]),
        ]

    def test_rebuild_csv_overwrites_existing_file(self, tmp_path):
        """Verify that an existing CSV file is correctly overwritten."""
        key_dir, csv_path = self._make_key_dir(tmp_path)

        # First run with initial data
        initial_participants = [("old_p1", "old_key_1")]
        rebuild_authorized_keys_csv(key_dir, initial_participants)

        # Second run with new data
        new_participants = [("new_p1", "new_key_1"), ("new_p2", "new_key_2")]
        rebuild_authorized_keys_csv(key_dir, new_participants)

        # Drop the header; only the second run's rows may remain.
        _header, *rows = self._read_rows(csv_path)
        assert len(rows) == 2
        assert rows[0] == list(new_participants[0])
        assert rows[1] == list(new_participants[1])