File size: 5,476 Bytes
3a1c55b
 
 
 
 
 
bc67f56
3a1c55b
 
 
bc67f56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a1c55b
 
 
 
bc67f56
 
3a1c55b
 
bc67f56
3a1c55b
bc67f56
 
 
 
 
 
3a1c55b
bc67f56
 
 
 
 
 
 
 
3a1c55b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc67f56
3a1c55b
bc67f56
3a1c55b
bc67f56
 
3a1c55b
 
 
 
 
 
 
 
 
 
 
bc67f56
3a1c55b
 
 
 
bc67f56
 
 
 
3a1c55b
 
bc67f56
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import logging
from typing import List, Tuple
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization

# Configure logging for the module
log = logging.getLogger(__name__)


def _sanitize_key(participant_id: str, key_str: str) -> str | None:
    """
    Inspects a key string and converts it to the required OpenSSH format if necessary.
    This provides resilience against old, PEM-formatted keys in the database.
    """
    if not key_str:
        return None
    # If the key is already in the correct OpenSSH format, return it as is.
    if key_str.startswith("ecdsa-sha2-nistp384"):
        return key_str
    # If the key is in the old PEM format, attempt to convert it.
    if "-----BEGIN PUBLIC KEY-----" in key_str:
        log.warning(
            f"Found PEM-formatted key for participant {participant_id}. Converting to OpenSSH."
        )
        try:
            public_key = serialization.load_pem_public_key(key_str.encode("utf-8"))
            key_body = public_key.public_bytes(
                encoding=serialization.Encoding.OpenSSH,
                format=serialization.PublicFormat.OpenSSH,
            ).decode("utf-8")
            # Re-add the participant_id as a comment to conform to the 3-part format.
            return f"{key_body} {participant_id}"
        except Exception as e:
            log.error(f"Could not convert PEM key for {participant_id}: {e}")
            return None
    # If the key format is unknown, log an error and skip it.
    log.error(f"Unknown public key format for participant {participant_id}. Skipping.")
    return None


def rebuild_authorized_keys_csv(
    key_dir: str, authorized_participants: List[Tuple[str, str]]
):
    """
    Overwrites the public key file with a fresh list from a trusted source,
    using the specific single-line, comma-separated format expected by Flower.
    """
    csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
    log.info(f"Rebuilding authorized keys file at: {csv_path}")

    # Sanitize each key before adding it to the list.
    public_keys = [
        sanitized_key
        for p_id, key_string in authorized_participants
        if (sanitized_key := _sanitize_key(p_id, key_string)) is not None
    ]

    # Join all valid public keys into a single comma-separated string.
    content = ",".join(public_keys)

    # Write the single line to the file, followed by a newline.
    with open(csv_path, "w") as f:
        f.write(content + "\n")

    log.info(f"Successfully rebuilt {csv_path} with {len(public_keys)} keys.")


class AuthKeyGenerator:
    """
    Handles the generation of Elliptic Curve (EC) key pairs for Flower
    SuperNode authentication.
    """

    def __init__(self, key_dir: str = "keys"):
        self.key_dir = key_dir
        os.makedirs(self.key_dir, exist_ok=True)
        log.info(f"Authentication key directory set to: {self.key_dir}")

    def _generate_key_pair(self) -> ec.EllipticCurvePrivateKey:
        """Generates a single EC private key using the SECP384R1 curve."""
        return ec.generate_private_key(ec.SECP384R1())

    def _save_private_key(
        self, private_key: ec.EllipticCurvePrivateKey, participant_id: str
    ) -> str:
        """Saves the private key to a file with secure permissions."""
        priv_key_path = os.path.join(self.key_dir, f"{participant_id}.key")
        with open(priv_key_path, "wb") as f:
            f.write(
                private_key.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.PKCS8,
                    encryption_algorithm=serialization.NoEncryption(),
                )
            )
        os.chmod(priv_key_path, 0o600)
        log.info(f"Private key for {participant_id} saved securely to {priv_key_path}")
        return priv_key_path

    def _save_public_key_file(
        self, public_key_ssh_string: str, participant_id: str
    ) -> str:
        """Saves the full OpenSSH public key string to a .pub file."""
        pub_key_path = os.path.join(self.key_dir, f"{participant_id}.pub")
        with open(pub_key_path, "w") as f:
            f.write(public_key_ssh_string)
        log.info(f"Public key for {participant_id} saved to {pub_key_path}")
        return pub_key_path

    def generate_participant_keys(self, participant_id: str) -> Tuple[str, str, str]:
        """
        Generates and saves a new EC key pair for a participant.

        Returns:
            A tuple containing:
            - The file path to the generated private key.
            - The file path to the generated public key.
            - The public key as a single-line OpenSSH string with a comment.
        """
        private_key = self._generate_key_pair()
        public_key = private_key.public_key()

        # Generate the base OpenSSH key string (type and key data)
        key_body = public_key.public_bytes(
            encoding=serialization.Encoding.OpenSSH,
            format=serialization.PublicFormat.OpenSSH,
        ).decode("utf-8")

        # Append the participant_id as a comment, creating the full 3-part key.
        public_key_ssh_string = f"{key_body} {participant_id}"

        priv_key_path = self._save_private_key(private_key, participant_id)
        pub_key_path = self._save_public_key_file(public_key_ssh_string, participant_id)

        return (priv_key_path, pub_key_path, public_key_ssh_string)