Spaces:
Paused
Paused
| """ | |
| Certificate Authority Service | |
| Provides PKI infrastructure for VPN client certificate management | |
| """ | |
| import os | |
| import subprocess | |
| import logging | |
| from cryptography import x509 | |
| from cryptography.x509.oid import NameOID, ExtensionOID | |
| from cryptography.hazmat.primitives import hashes, serialization | |
| from cryptography.hazmat.primitives.asymmetric import rsa | |
| from datetime import datetime, timedelta | |
logger = logging.getLogger(__name__)


class CertificateAuthority:
    """Certificate Authority for VPN client certificates.

    All state lives under ``ca_dir``:

    * ``ca.crt`` / ``ca.key`` - the self-signed root certificate and its
      unencrypted PKCS#8 private key (created on first use),
    * ``serial`` - the next serial number to issue, as uppercase hex,
    * ``index.txt`` - one tab-separated line per issued certificate:
      ``status, expiry, revocation date, serial (hex), filename, subject``,
      where status is ``V`` (valid), ``R`` (revoked) or ``E`` (expired),
    * ``crl.pem`` - the most recently generated CRL.

    Timestamps are naive datetimes in UTC (``datetime.utcnow()``).
    """

    # Timestamp format of the expiry / revocation-date columns in index.txt.
    _INDEX_TIME_FORMAT = '%y%m%d%H%M%SZ'

    # index.txt is tab- and newline-delimited, so those characters must never
    # appear raw inside a field (e.g. from a user-supplied CN). They are
    # replaced with RFC 4514 hex escapes; rfc4514_string() already escapes
    # backslashes, so the escaped form is unambiguous.
    _INDEX_FIELD_ESCAPES = str.maketrans({'\t': '\\09', '\n': '\\0A', '\r': '\\0D'})

    def __init__(self, ca_dir='/etc/vpn-ca'):
        """Open the CA stored in ``ca_dir``, creating it on first use.

        Args:
            ca_dir: Directory holding the CA files; created (mode 0o700
                requested) if missing.
        """
        self.ca_dir = ca_dir
        self.ca_cert_path = os.path.join(ca_dir, 'ca.crt')
        self.ca_key_path = os.path.join(ca_dir, 'ca.key')
        self.crl_path = os.path.join(ca_dir, 'crl.pem')
        self.serial_file = os.path.join(ca_dir, 'serial')
        self.index_file = os.path.join(ca_dir, 'index.txt')

        os.makedirs(ca_dir, mode=0o700, exist_ok=True)

        # The presence of ca.crt is what marks the CA as initialised.
        if not os.path.exists(self.ca_cert_path):
            self._create_root_ca()

        # Start with an empty certificate index.
        if not os.path.exists(self.index_file):
            with open(self.index_file, 'w') as f:
                f.write('')

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _write_private_file(path, data):
        """Write ``data`` to ``path`` with mode 0o600 from the moment it exists."""
        # os.open applies the mode at creation time, so the key is never
        # briefly readable by others (unlike open() followed by chmod()).
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'wb') as f:
            f.write(data)
        # O_CREAT's mode is ignored for a pre-existing file; enforce it anyway.
        os.chmod(path, 0o600)

    def _load_ca(self):
        """Load and return ``(ca_cert, ca_private_key)`` from ``ca_dir``."""
        with open(self.ca_cert_path, 'rb') as f:
            ca_cert = x509.load_pem_x509_certificate(f.read())
        with open(self.ca_key_path, 'rb') as f:
            ca_private_key = serialization.load_pem_private_key(f.read(), password=None)
        return ca_cert, ca_private_key

    def _read_index_rows(self):
        """Return the index as a list of tab-split rows, skipping blank lines."""
        with open(self.index_file, 'r') as f:
            return [line.strip().split('\t') for line in f if line.strip()]

    @staticmethod
    def _serial_matches(index_serial, serial_number):
        """Return True if an index serial column refers to ``serial_number``.

        The index stores serials as hex text (e.g. ``'0A'`` for 10). Integers
        (as returned by generate_client_certificate or read from a certificate)
        are compared numerically; strings (such as the ``serial_number`` values
        returned by list_certificates) are interpreted as hex.
        """
        try:
            stored = int(index_serial, 16)
            if isinstance(serial_number, int):
                wanted = serial_number
            else:
                wanted = int(str(serial_number), 16)
        except ValueError:
            return False
        return stored == wanted

    # ------------------------------------------------------------------
    # CA creation and issuance
    # ------------------------------------------------------------------

    def _create_root_ca(self):
        """Create the self-signed root CA certificate and private key.

        Generates a 4096-bit RSA key and a 10-year CA certificate with serial 1,
        writes both into ``ca_dir`` and resets the serial file to ``02``.

        Raises:
            Exception: any key-generation, signing or I/O error (logged, re-raised).
        """
        try:
            logger.info("Creating root CA certificate")
            private_key = rsa.generate_private_key(
                public_exponent=65537,
                key_size=4096
            )

            subject = issuer = x509.Name([
                x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Virtual"),
                x509.NameAttribute(NameOID.LOCALITY_NAME, "Internet"),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "VPN Service CA"),
                x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Certificate Authority"),
                x509.NameAttribute(NameOID.COMMON_NAME, "VPN Root CA"),
            ])

            now = datetime.utcnow()
            cert = (
                x509.CertificateBuilder()
                .subject_name(subject)
                .issuer_name(issuer)
                .public_key(private_key.public_key())
                .serial_number(1)
                .not_valid_before(now)
                .not_valid_after(now + timedelta(days=3650))  # 10 years
                .add_extension(
                    x509.BasicConstraints(ca=True, path_length=None),
                    critical=True,
                )
                .add_extension(
                    x509.KeyUsage(
                        key_cert_sign=True,
                        crl_sign=True,
                        digital_signature=False,
                        key_encipherment=False,
                        key_agreement=False,
                        data_encipherment=False,
                        content_commitment=False,
                        encipher_only=False,
                        decipher_only=False
                    ),
                    critical=True,
                )
                .add_extension(
                    x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()),
                    critical=False,
                )
                .sign(private_key, hashes.SHA256())
            )

            # Write the key before the certificate: __init__ treats an existing
            # ca.crt as a complete CA, so the cert must never exist without its key.
            self._write_private_file(
                self.ca_key_path,
                private_key.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.PKCS8,
                    encryption_algorithm=serialization.NoEncryption()
                ),
            )
            with open(self.ca_cert_path, 'wb') as f:
                f.write(cert.public_bytes(serialization.Encoding.PEM))
            os.chmod(self.ca_cert_path, 0o644)

            # Serial 1 belongs to the root; client certificates start at 2.
            with open(self.serial_file, 'w') as f:
                f.write('02')

            logger.info("Root CA certificate created successfully")
        except Exception as e:
            logger.error(f"Failed to create root CA: {e}")
            raise

    def generate_client_certificate(self, username, email, validity_days=365):
        """Issue a client-authentication certificate signed by the root CA.

        Args:
            username: Used as the certificate's Common Name.
            email: Stored as the subject's emailAddress attribute.
            validity_days: Lifetime of the certificate, starting now.

        Returns:
            dict with PEM ``certificate``, unencrypted PEM PKCS#8 ``private_key``,
            integer ``serial_number`` and the ``not_valid_before`` /
            ``not_valid_after`` datetimes of the certificate.

        Raises:
            Exception: any load, signing or I/O error (logged, re-raised).
        """
        try:
            logger.info(f"Generating client certificate for {username}")
            ca_cert, ca_private_key = self._load_ca()

            client_private_key = rsa.generate_private_key(
                public_exponent=65537,
                key_size=2048
            )
            serial_number = self._get_next_serial()

            subject = x509.Name([
                x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Virtual"),
                x509.NameAttribute(NameOID.LOCALITY_NAME, "Internet"),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "VPN Service"),
                x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "VPN Clients"),
                x509.NameAttribute(NameOID.COMMON_NAME, username),
                x509.NameAttribute(NameOID.EMAIL_ADDRESS, email),
            ])

            now = datetime.utcnow()
            cert = (
                x509.CertificateBuilder()
                .subject_name(subject)
                .issuer_name(ca_cert.subject)
                .public_key(client_private_key.public_key())
                .serial_number(serial_number)
                .not_valid_before(now)
                .not_valid_after(now + timedelta(days=validity_days))
                .add_extension(
                    x509.BasicConstraints(ca=False, path_length=None),
                    critical=True,
                )
                .add_extension(
                    x509.KeyUsage(
                        key_cert_sign=False,
                        crl_sign=False,
                        digital_signature=True,
                        key_encipherment=True,
                        key_agreement=False,
                        data_encipherment=False,
                        content_commitment=False,
                        encipher_only=False,
                        decipher_only=False
                    ),
                    critical=True,
                )
                .add_extension(
                    x509.ExtendedKeyUsage([
                        x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
                    ]),
                    critical=True,
                )
                .add_extension(
                    x509.SubjectKeyIdentifier.from_public_key(client_private_key.public_key()),
                    critical=False,
                )
                .add_extension(
                    x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key()),
                    critical=False,
                )
                .sign(ca_private_key, hashes.SHA256())
            )

            self._update_certificate_index(cert, 'V')  # V = Valid

            logger.info(f"Client certificate generated successfully for {username} (Serial: {serial_number})")
            return {
                'certificate': cert.public_bytes(serialization.Encoding.PEM),
                'private_key': client_private_key.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.PKCS8,
                    encryption_algorithm=serialization.NoEncryption()
                ),
                'serial_number': serial_number,
                'not_valid_before': cert.not_valid_before,
                'not_valid_after': cert.not_valid_after
            }
        except Exception as e:
            logger.error(f"Failed to generate client certificate: {e}")
            raise

    # ------------------------------------------------------------------
    # Revocation and status
    # ------------------------------------------------------------------

    def revoke_certificate(self, serial_number, reason='unspecified'):
        """Mark a certificate revoked in the index and regenerate the CRL.

        Args:
            serial_number: Integer serial, or hex string as in list_certificates.
            reason: Accepted for API compatibility; it is not stored.

        Raises:
            ValueError: if no index entry has this serial (nothing is revoked).
            Exception: any other I/O or signing error (logged, re-raised).
        """
        try:
            logger.info(f"Revoking certificate with serial {serial_number}")
            if not self._update_certificate_index_status(serial_number, 'R', reason):
                # Never report success for a revocation that did not happen.
                raise ValueError(f"No certificate with serial {serial_number} in index")
            self._generate_crl()
            logger.info(f"Certificate {serial_number} revoked successfully")
        except Exception as e:
            logger.error(f"Failed to revoke certificate {serial_number}: {e}")
            raise

    def get_certificate_status(self, serial_number):
        """Return 'valid', 'revoked', 'expired' or 'unknown' for a serial.

        Args:
            serial_number: Integer serial, or hex string as in list_certificates.

        Returns 'unknown' if the serial is not in the index or the index
        cannot be read.
        """
        status_names = {'V': 'valid', 'R': 'revoked', 'E': 'expired'}
        try:
            for parts in self._read_index_rows():
                # Malformed rows (fewer than 4 fields) are skipped, not fatal.
                if len(parts) < 4 or not self._serial_matches(parts[3], serial_number):
                    continue
                status = status_names.get(parts[0])
                if status is not None:
                    return status
            return 'unknown'
        except Exception as e:
            logger.error(f"Failed to get certificate status: {e}")
            return 'unknown'

    def list_certificates(self):
        """Return one dict per well-formed (6-field) index row.

        Returns an empty list if the index cannot be read.
        """
        try:
            return [
                {
                    'status': parts[0],
                    'expiry_date': parts[1],
                    'revocation_date': parts[2] if parts[2] else None,
                    'serial_number': parts[3],
                    'filename': parts[4],
                    'subject': parts[5]
                }
                for parts in self._read_index_rows()
                if len(parts) >= 6
            ]
        except Exception as e:
            logger.error(f"Failed to list certificates: {e}")
            return []

    def _get_next_serial(self):
        """Return the serial stored in the serial file and advance the file by one.

        Falls back to 2 if the file is missing or unparsable.
        """
        try:
            with open(self.serial_file, 'r') as f:
                serial = int(f.read().strip(), 16)
        except (FileNotFoundError, ValueError):
            serial = 2

        with open(self.serial_file, 'w') as f:
            f.write(f'{serial + 1:02X}')
        return serial

    def _update_certificate_index(self, cert, status):
        """Append an index row for ``cert`` with the given status letter."""
        try:
            # Format: status \t expiry_date \t revocation_date \t serial \t filename \t subject
            expiry_date = cert.not_valid_after.strftime(self._INDEX_TIME_FORMAT)
            serial_hex = f'{cert.serial_number:02X}'
            # Escape delimiters so a crafted subject cannot break or forge rows.
            subject_str = cert.subject.rfc4514_string().translate(self._INDEX_FIELD_ESCAPES)
            index_line = f"{status}\t{expiry_date}\t\t{serial_hex}\tunknown\t{subject_str}\n"

            with open(self.index_file, 'a') as f:
                f.write(index_line)
        except Exception as e:
            logger.error(f"Failed to update certificate index: {e}")
            raise

    def _update_certificate_index_status(self, serial_number, new_status, reason=None):
        """Set the status of the first index row matching ``serial_number``.

        When ``new_status`` is 'R' the revocation date is set to now. ``reason``
        is accepted for compatibility but not stored.

        Returns:
            True if a row was updated (and the index rewritten), else False.
        """
        try:
            with open(self.index_file, 'r') as f:
                lines = f.readlines()

            for i, line in enumerate(lines):
                parts = line.strip().split('\t')
                if len(parts) < 4 or not self._serial_matches(parts[3], serial_number):
                    continue
                parts[0] = new_status
                if new_status == 'R':
                    # Always record the date: _generate_crl needs one for every
                    # revoked row, whatever ``reason`` was passed.
                    parts[2] = datetime.utcnow().strftime(self._INDEX_TIME_FORMAT)
                lines[i] = '\t'.join(parts) + '\n'
                with open(self.index_file, 'w') as f:
                    f.writelines(lines)
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to update certificate status: {e}")
            raise

    def _generate_crl(self):
        """Sign and write a CRL listing every 'R' row of the index.

        The CRL's next_update is 30 days from now.

        Raises:
            Exception: any load, parse, signing or I/O error (logged, re-raised).
        """
        try:
            logger.info("Generating Certificate Revocation List")
            ca_cert, ca_private_key = self._load_ca()
            now = datetime.utcnow()

            crl_builder = (
                x509.CertificateRevocationListBuilder()
                .issuer_name(ca_cert.subject)
                .last_update(now)
                .next_update(now + timedelta(days=30))
            )

            for parts in self._read_index_rows():
                if len(parts) < 4 or parts[0] != 'R':
                    continue
                if parts[2]:
                    revocation_date = datetime.strptime(parts[2], self._INDEX_TIME_FORMAT)
                else:
                    # Rows revoked by older code may lack a date; still list them
                    # rather than silently dropping a revoked certificate.
                    logger.warning(f"Revoked certificate {parts[3]} has no revocation date; using now")
                    revocation_date = now
                revoked_cert = (
                    x509.RevokedCertificateBuilder()
                    .serial_number(int(parts[3], 16))
                    .revocation_date(revocation_date)
                    .build()
                )
                crl_builder = crl_builder.add_revoked_certificate(revoked_cert)

            crl = crl_builder.sign(ca_private_key, hashes.SHA256())
            with open(self.crl_path, 'wb') as f:
                f.write(crl.public_bytes(serialization.Encoding.PEM))

            logger.info("CRL generated successfully")
        except Exception as e:
            logger.error(f"Failed to generate CRL: {e}")
            raise

    # ------------------------------------------------------------------
    # Public accessors
    # ------------------------------------------------------------------

    def get_ca_certificate(self):
        """Return the CA certificate as PEM bytes, or None if it cannot be read."""
        try:
            with open(self.ca_cert_path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"Failed to read CA certificate: {e}")
            return None

    def get_crl(self):
        """Return the CRL as PEM bytes, generating it first if absent.

        Returns None on any error.
        """
        # NOTE(review): an existing crl.pem is returned as-is, even after the
        # 30-day next_update set by _generate_crl has passed; consider
        # regenerating when stale.
        try:
            if not os.path.exists(self.crl_path):
                self._generate_crl()
            with open(self.crl_path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"Failed to read CRL: {e}")
            return None

    def verify_certificate(self, cert_pem):
        """Check a PEM certificate against this CA.

        Checks, in order: issuer name, RSA/PKCS#1 v1.5 signature by the CA key,
        validity period, and revocation status in the index.

        Returns:
            ``(ok, message)``. Never raises; errors become ``(False, str(error))``.
        """
        try:
            # Local imports: only this method needs them.
            from cryptography.exceptions import InvalidSignature
            from cryptography.hazmat.primitives.asymmetric import padding

            cert = x509.load_pem_x509_certificate(cert_pem)
            with open(self.ca_cert_path, 'rb') as f:
                ca_cert = x509.load_pem_x509_certificate(f.read())

            if cert.issuer != ca_cert.subject:
                return False, "Certificate not issued by this CA"

            # The CA key is RSA (see _create_root_ca), and certificates are
            # signed with PKCS#1 v1.5 padding and the hash named in the cert.
            try:
                ca_cert.public_key().verify(
                    cert.signature,
                    cert.tbs_certificate_bytes,
                    padding.PKCS1v15(),
                    cert.signature_hash_algorithm,
                )
            except InvalidSignature:
                return False, "Invalid certificate signature"

            now = datetime.utcnow()
            if now < cert.not_valid_before or now > cert.not_valid_after:
                return False, "Certificate expired or not yet valid"

            if self.get_certificate_status(cert.serial_number) == 'revoked':
                return False, "Certificate revoked"

            return True, "Certificate valid"
        except Exception as e:
            logger.error(f"Certificate verification failed: {e}")
            return False, str(e)

    def cleanup_expired_certificates(self):
        """Mark 'V' rows whose expiry date has passed as 'E'.

        Best-effort: rows with unparsable dates are skipped and any error is
        logged, not raised. The index is rewritten only if something changed.
        """
        try:
            logger.info("Cleaning up expired certificates")
            with open(self.index_file, 'r') as f:
                lines = f.readlines()

            now = datetime.utcnow()
            updated_count = 0
            for i, line in enumerate(lines):
                parts = line.strip().split('\t')
                if len(parts) < 2 or parts[0] != 'V':
                    continue
                try:
                    expiry_date = datetime.strptime(parts[1], self._INDEX_TIME_FORMAT)
                except ValueError:
                    continue
                if now > expiry_date:
                    parts[0] = 'E'
                    lines[i] = '\t'.join(parts) + '\n'
                    updated_count += 1

            if updated_count > 0:
                with open(self.index_file, 'w') as f:
                    f.writelines(lines)
                logger.info(f"Marked {updated_count} certificates as expired")
        except Exception as e:
            logger.error(f"Failed to cleanup expired certificates: {e}")

    def get_statistics(self):
        """Count index rows in total and per status (V/R/E).

        Returns all-zero counts if the index cannot be read.
        """
        stats = {
            'total_certificates': 0,
            'valid_certificates': 0,
            'revoked_certificates': 0,
            'expired_certificates': 0
        }
        counter_for_status = {
            'V': 'valid_certificates',
            'R': 'revoked_certificates',
            'E': 'expired_certificates',
        }
        try:
            for parts in self._read_index_rows():
                stats['total_certificates'] += 1
                key = counter_for_status.get(parts[0])
                if key is not None:
                    stats[key] += 1
            return stats
        except Exception as e:
            logger.error(f"Failed to get CA statistics: {e}")
            return dict.fromkeys(stats, 0)