Spaces:
Sleeping
Sleeping
File size: 5,476 Bytes
3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 3a1c55b bc67f56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
import logging
from typing import List, Tuple
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
# Configure logging for the module
log = logging.getLogger(__name__)
def _sanitize_key(participant_id: str, key_str: str) -> str | None:
"""
Inspects a key string and converts it to the required OpenSSH format if necessary.
This provides resilience against old, PEM-formatted keys in the database.
"""
if not key_str:
return None
# If the key is already in the correct OpenSSH format, return it as is.
if key_str.startswith("ecdsa-sha2-nistp384"):
return key_str
# If the key is in the old PEM format, attempt to convert it.
if "-----BEGIN PUBLIC KEY-----" in key_str:
log.warning(
f"Found PEM-formatted key for participant {participant_id}. Converting to OpenSSH."
)
try:
public_key = serialization.load_pem_public_key(key_str.encode("utf-8"))
key_body = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH,
).decode("utf-8")
# Re-add the participant_id as a comment to conform to the 3-part format.
return f"{key_body} {participant_id}"
except Exception as e:
log.error(f"Could not convert PEM key for {participant_id}: {e}")
return None
# If the key format is unknown, log an error and skip it.
log.error(f"Unknown public key format for participant {participant_id}. Skipping.")
return None
def rebuild_authorized_keys_csv(
key_dir: str, authorized_participants: List[Tuple[str, str]]
):
"""
Overwrites the public key file with a fresh list from a trusted source,
using the specific single-line, comma-separated format expected by Flower.
"""
csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
log.info(f"Rebuilding authorized keys file at: {csv_path}")
# Sanitize each key before adding it to the list.
public_keys = [
sanitized_key
for p_id, key_string in authorized_participants
if (sanitized_key := _sanitize_key(p_id, key_string)) is not None
]
# Join all valid public keys into a single comma-separated string.
content = ",".join(public_keys)
# Write the single line to the file, followed by a newline.
with open(csv_path, "w") as f:
f.write(content + "\n")
log.info(f"Successfully rebuilt {csv_path} with {len(public_keys)} keys.")
class AuthKeyGenerator:
"""
Handles the generation of Elliptic Curve (EC) key pairs for Flower
SuperNode authentication.
"""
def __init__(self, key_dir: str = "keys"):
self.key_dir = key_dir
os.makedirs(self.key_dir, exist_ok=True)
log.info(f"Authentication key directory set to: {self.key_dir}")
def _generate_key_pair(self) -> ec.EllipticCurvePrivateKey:
"""Generates a single EC private key using the SECP384R1 curve."""
return ec.generate_private_key(ec.SECP384R1())
def _save_private_key(
self, private_key: ec.EllipticCurvePrivateKey, participant_id: str
) -> str:
"""Saves the private key to a file with secure permissions."""
priv_key_path = os.path.join(self.key_dir, f"{participant_id}.key")
with open(priv_key_path, "wb") as f:
f.write(
private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
)
os.chmod(priv_key_path, 0o600)
log.info(f"Private key for {participant_id} saved securely to {priv_key_path}")
return priv_key_path
def _save_public_key_file(
self, public_key_ssh_string: str, participant_id: str
) -> str:
"""Saves the full OpenSSH public key string to a .pub file."""
pub_key_path = os.path.join(self.key_dir, f"{participant_id}.pub")
with open(pub_key_path, "w") as f:
f.write(public_key_ssh_string)
log.info(f"Public key for {participant_id} saved to {pub_key_path}")
return pub_key_path
def generate_participant_keys(self, participant_id: str) -> Tuple[str, str, str]:
"""
Generates and saves a new EC key pair for a participant.
Returns:
A tuple containing:
- The file path to the generated private key.
- The file path to the generated public key.
- The public key as a single-line OpenSSH string with a comment.
"""
private_key = self._generate_key_pair()
public_key = private_key.public_key()
# Generate the base OpenSSH key string (type and key data)
key_body = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH,
).decode("utf-8")
# Append the participant_id as a comment, creating the full 3-part key.
public_key_ssh_string = f"{key_body} {participant_id}"
priv_key_path = self._save_private_key(private_key, participant_id)
pub_key_path = self._save_public_key_file(public_key_ssh_string, participant_id)
return (priv_key_path, pub_key_path, public_key_ssh_string)
|