"""
Certificate Authority Service

Provides PKI infrastructure for VPN client certificate management: a
self-signed root CA, client certificate issuance and revocation, and CRL
generation.
"""

import os
import subprocess
import logging
from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.x509.oid import NameOID, ExtensionOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from datetime import datetime, timedelta

logger = logging.getLogger(__name__)


class CertificateAuthority:
    """Certificate Authority for VPN client certificates.

    All state is kept as files under ``ca_dir``:

    * ``ca.crt`` / ``ca.key`` - root CA certificate and unencrypted private key
    * ``crl.pem`` - the most recently generated CRL
    * ``serial`` - next serial number to issue, in hex
    * ``index.txt`` - certificate database, one tab-separated line per issued
      certificate: ``status, expiry, revocation[,reason], serial, filename,
      subject``. ``status`` is ``V`` (valid), ``R`` (revoked) or ``E``
      (expired); dates use ``INDEX_TIME_FORMAT``; ``serial`` is upper-case hex.

    NOTE(review): no file locking is done; concurrent writers (threads or
    processes) can race on ``serial`` and ``index.txt``.
    """

    # strftime/strptime format of the expiry and revocation dates in the index.
    INDEX_TIME_FORMAT = '%y%m%d%H%M%SZ'

    # Index status letter -> name returned by get_certificate_status().
    _STATUS_NAMES = {'V': 'valid', 'R': 'revoked', 'E': 'expired'}

    def __init__(self, ca_dir='/etc/vpn-ca'):
        """Open (and, on first use, create) the CA stored in ``ca_dir``.

        Args:
            ca_dir: Directory holding the CA files. It is created with mode
                0700 if missing, and a new root CA is generated when
                ``ca.crt`` does not exist yet.
        """
        self.ca_dir = ca_dir
        self.ca_cert_path = os.path.join(ca_dir, 'ca.crt')
        self.ca_key_path = os.path.join(ca_dir, 'ca.key')
        self.crl_path = os.path.join(ca_dir, 'crl.pem')
        self.serial_file = os.path.join(ca_dir, 'serial')
        self.index_file = os.path.join(ca_dir, 'index.txt')

        # Ensure CA directory exists
        os.makedirs(ca_dir, mode=0o700, exist_ok=True)

        # Initialize CA if not exists
        if not os.path.exists(self.ca_cert_path):
            self._create_root_ca()

        # Initialize index file for certificate tracking
        if not os.path.exists(self.index_file):
            with open(self.index_file, 'w') as f:
                f.write('')  # Empty index file

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _write_private_file(path, data):
        """Write ``data`` (bytes) to ``path``, creating the file with mode 0600.

        Creating the file with restrictive permissions up front means secret
        material is never readable by other users, not even briefly.
        """
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'wb') as f:
            f.write(data)

    @staticmethod
    def _split_index_line(line):
        """Split one index line into its tab-separated fields.

        Returns an empty list for blank lines. At most six fields are produced,
        so a tab inside the subject does not shift the other columns.
        """
        stripped = line.strip()
        return stripped.split('\t', 5) if stripped else []

    @staticmethod
    def _serial_to_int(serial_number):
        """Normalise a serial number to an int.

        ints are returned unchanged. Strings are parsed as hexadecimal, which
        is how serials are written to the index and returned by
        list_certificates().

        Raises:
            ValueError: if a string is not valid hexadecimal.
        """
        if isinstance(serial_number, int):
            return serial_number
        return int(str(serial_number).strip(), 16)

    @staticmethod
    def _entry_has_serial(parts, serial_int):
        """Return True if index fields ``parts`` describe serial ``serial_int``."""
        if len(parts) < 4:
            return False
        try:
            return int(parts[3], 16) == serial_int
        except ValueError:
            return False

    @staticmethod
    def _reason_flag(reason):
        """Map a revocation reason string to ``x509.ReasonFlags``.

        Returns None for an empty or unrecognised reason (the latter is logged).
        """
        if not reason:
            return None
        try:
            return x509.ReasonFlags(reason)
        except (ValueError, TypeError):
            logger.warning(f"Unknown revocation reason {reason!r}; recording revocation without a reason")
            return None

    def _index_serials(self):
        """Return every parseable serial in the index as ints ([] if there is no index)."""
        serials = []
        try:
            with open(self.index_file, 'r') as f:
                for line in f:
                    parts = self._split_index_line(line)
                    if len(parts) < 4:
                        continue
                    try:
                        serials.append(int(parts[3], 16))
                    except ValueError:
                        continue
        except FileNotFoundError:
            pass
        return serials

    def _write_index_lines(self, lines):
        """Atomically replace the index file with ``lines``.

        Writes to a temporary file next to the index and renames it over the
        original, so a crash mid-write cannot leave a truncated index.
        """
        tmp_path = self.index_file + '.tmp'
        with open(tmp_path, 'w') as f:
            f.writelines(lines)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, self.index_file)

    # ------------------------------------------------------------------
    # Root CA
    # ------------------------------------------------------------------

    def _create_root_ca(self):
        """Create the self-signed root CA certificate and private key.

        Generates a 4096-bit RSA key and a 10-year CA certificate with serial
        1. Writes ``ca.crt`` (mode 0644) and ``ca.key`` (mode 0600,
        unencrypted PKCS#8 PEM), and initialises the serial file to ``02``.

        Raises:
            Exception: any error while generating or writing; logged and re-raised.
        """
        try:
            logger.info("Creating root CA certificate")

            # Generate private key
            private_key = rsa.generate_private_key(
                public_exponent=65537,
                key_size=4096
            )

            # Create certificate
            subject = issuer = x509.Name([
                x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Virtual"),
                x509.NameAttribute(NameOID.LOCALITY_NAME, "Internet"),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "VPN Service CA"),
                x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Certificate Authority"),
                x509.NameAttribute(NameOID.COMMON_NAME, "VPN Root CA"),
            ])

            now = datetime.utcnow()
            cert = (
                x509.CertificateBuilder()
                .subject_name(subject)
                .issuer_name(issuer)
                .public_key(private_key.public_key())
                .serial_number(1)
                .not_valid_before(now)
                .not_valid_after(now + timedelta(days=3650))  # 10 years
                .add_extension(
                    x509.BasicConstraints(ca=True, path_length=None),
                    critical=True,
                )
                .add_extension(
                    x509.KeyUsage(
                        key_cert_sign=True,
                        crl_sign=True,
                        digital_signature=False,
                        key_encipherment=False,
                        key_agreement=False,
                        data_encipherment=False,
                        content_commitment=False,
                        encipher_only=False,
                        decipher_only=False
                    ),
                    critical=True,
                )
                .add_extension(
                    x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()),
                    critical=False,
                )
                .sign(private_key, hashes.SHA256())
            )

            # Save certificate and private key
            with open(self.ca_cert_path, 'wb') as f:
                f.write(cert.public_bytes(serialization.Encoding.PEM))

            key_pem = private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption()
            )
            # Create the key file as 0600 from the start; a plain open() would use
            # the process umask and leave the key world-readable until the chmod.
            self._write_private_file(self.ca_key_path, key_pem)

            # Set secure permissions (also tightens a pre-existing key file,
            # whose mode os.open(O_CREAT) does not change).
            os.chmod(self.ca_key_path, 0o600)
            os.chmod(self.ca_cert_path, 0o644)

            # Initialize serial number file (serial 1 is the root itself)
            with open(self.serial_file, 'w') as f:
                f.write('02')

            logger.info("Root CA certificate created successfully")

        except Exception as e:
            logger.error(f"Failed to create root CA: {e}")
            raise

    # ------------------------------------------------------------------
    # Issuance and revocation
    # ------------------------------------------------------------------

    def generate_client_certificate(self, username, email, validity_days=365):
        """Issue a client certificate for VPN authentication.

        Args:
            username: Common name of the certificate subject.
            email: E-mail address added to the subject. It is omitted when
                empty or None.
            validity_days: Certificate lifetime in days, starting now.

        Returns:
            dict with the following keys:
            ``certificate`` and ``private_key``: PEM bytes; the key is
            unencrypted PKCS#8.
            ``serial_number``: int.
            ``not_valid_before`` / ``not_valid_after``: values read back from
            the signed certificate.

        Raises:
            Exception: any error while loading the CA or signing; logged and re-raised.
        """
        try:
            logger.info(f"Generating client certificate for {username}")

            # Load CA certificate and private key
            with open(self.ca_cert_path, 'rb') as f:
                ca_cert = x509.load_pem_x509_certificate(f.read())

            with open(self.ca_key_path, 'rb') as f:
                ca_private_key = serialization.load_pem_private_key(f.read(), password=None)

            # Generate client private key
            client_private_key = rsa.generate_private_key(
                public_exponent=65537,
                key_size=2048
            )

            # Get next serial number
            serial_number = self._get_next_serial()

            # Create client certificate
            name_attributes = [
                x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Virtual"),
                x509.NameAttribute(NameOID.LOCALITY_NAME, "Internet"),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "VPN Service"),
                x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "VPN Clients"),
                x509.NameAttribute(NameOID.COMMON_NAME, username),
            ]
            if email:
                name_attributes.append(x509.NameAttribute(NameOID.EMAIL_ADDRESS, email))
            subject = x509.Name(name_attributes)

            now = datetime.utcnow()
            cert = (
                x509.CertificateBuilder()
                .subject_name(subject)
                .issuer_name(ca_cert.subject)
                .public_key(client_private_key.public_key())
                .serial_number(serial_number)
                .not_valid_before(now)
                .not_valid_after(now + timedelta(days=validity_days))
                .add_extension(
                    x509.BasicConstraints(ca=False, path_length=None),
                    critical=True,
                )
                .add_extension(
                    x509.KeyUsage(
                        key_cert_sign=False,
                        crl_sign=False,
                        digital_signature=True,
                        key_encipherment=True,
                        key_agreement=False,
                        data_encipherment=False,
                        content_commitment=False,
                        encipher_only=False,
                        decipher_only=False
                    ),
                    critical=True,
                )
                .add_extension(
                    x509.ExtendedKeyUsage([
                        x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
                    ]),
                    critical=True,
                )
                .add_extension(
                    x509.SubjectKeyIdentifier.from_public_key(client_private_key.public_key()),
                    critical=False,
                )
                .add_extension(
                    x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key()),
                    critical=False,
                )
                .sign(ca_private_key, hashes.SHA256())
            )

            # Update certificate index
            self._update_certificate_index(cert, 'V')  # V = Valid

            logger.info(f"Client certificate generated successfully for {username} (Serial: {serial_number})")

            return {
                'certificate': cert.public_bytes(serialization.Encoding.PEM),
                'private_key': client_private_key.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.PKCS8,
                    encryption_algorithm=serialization.NoEncryption()
                ),
                'serial_number': serial_number,
                'not_valid_before': cert.not_valid_before,
                'not_valid_after': cert.not_valid_after
            }

        except Exception as e:
            logger.error(f"Failed to generate client certificate: {e}")
            raise

    def revoke_certificate(self, serial_number, reason='unspecified'):
        """Revoke a client certificate and regenerate the CRL.

        Args:
            serial_number: int serial, or a hex string as returned by
                list_certificates().
            reason: Revocation reason; an ``x509.ReasonFlags`` value such as
                ``'keyCompromise'`` or ``'superseded'``. Unknown reasons are
                logged, and the certificate is still revoked without a reason.

        Raises:
            ValueError: if no certificate with that serial is in the index.
            Exception: any other error; logged and re-raised.
        """
        try:
            logger.info(f"Revoking certificate with serial {serial_number}")

            # Update certificate index; fail loudly rather than report a
            # revocation that did not happen.
            if not self._update_certificate_index_status(serial_number, 'R', reason):
                raise ValueError(f"No certificate with serial {serial_number} in the index")

            # Generate new CRL
            self._generate_crl()

            logger.info(f"Certificate {serial_number} revoked successfully")

        except Exception as e:
            logger.error(f"Failed to revoke certificate {serial_number}: {e}")
            raise

    # ------------------------------------------------------------------
    # Index queries
    # ------------------------------------------------------------------

    def get_certificate_status(self, serial_number):
        """Return 'valid', 'revoked', 'expired' or 'unknown' for a serial.

        ``serial_number`` may be an int or a hex string. Any error (missing
        index, unparseable serial) is logged and yields 'unknown'.
        """
        try:
            target = self._serial_to_int(serial_number)
            with open(self.index_file, 'r') as f:
                for line in f:
                    parts = self._split_index_line(line)
                    if self._entry_has_serial(parts, target):
                        return self._STATUS_NAMES.get(parts[0], 'unknown')
            return 'unknown'
        except Exception as e:
            logger.error(f"Failed to get certificate status: {e}")
            return 'unknown'

    def list_certificates(self):
        """List all certificates in the index.

        Returns:
            list of dicts with the keys ``status``, ``expiry_date``,
            ``revocation_date`` (None if not revoked),
            ``revocation_reason`` (None if none was recorded),
            ``serial_number`` (hex string), ``filename`` and ``subject``.
            Returns [] on error (logged).
        """
        try:
            certificates = []
            with open(self.index_file, 'r') as f:
                for line in f:
                    parts = self._split_index_line(line)
                    if len(parts) < 6:
                        continue
                    revoked_at, _, reason = parts[2].partition(',')
                    certificates.append({
                        'status': parts[0],
                        'expiry_date': parts[1],
                        'revocation_date': revoked_at or None,
                        'revocation_reason': reason or None,
                        'serial_number': parts[3],
                        'filename': parts[4],
                        'subject': parts[5],
                    })
            return certificates
        except Exception as e:
            logger.error(f"Failed to list certificates: {e}")
            return []

    # ------------------------------------------------------------------
    # Index maintenance
    # ------------------------------------------------------------------

    def _get_next_serial(self):
        """Reserve and return the next certificate serial number.

        Reads the hex value from the serial file and writes back value + 1.
        If the file is missing or unreadable, numbering continues after the
        highest serial already in the index, with a minimum of 2 because the
        root CA uses serial 1. Existing serials are therefore never reused.
        """
        try:
            with open(self.serial_file, 'r') as f:
                serial = int(f.read().strip(), 16)
        except (FileNotFoundError, ValueError):
            serial = max([1, *self._index_serials()]) + 1

        # Update serial file
        with open(self.serial_file, 'w') as f:
            f.write(f'{serial + 1:02X}')

        return serial

    def _update_certificate_index(self, cert, status):
        """Append an index line for ``cert`` with the given status letter.

        Raises:
            Exception: any write error; logged and re-raised.
        """
        try:
            # Format: status \t expiry_date \t revocation_date \t serial \t filename \t subject
            expiry_date = cert.not_valid_after.strftime(self.INDEX_TIME_FORMAT)
            serial_hex = f'{cert.serial_number:02X}'
            # A line break in the subject (e.g. from a crafted username) would
            # split this record and inject extra lines into the index.
            subject_str = cert.subject.rfc4514_string().replace('\r', ' ').replace('\n', ' ')

            index_line = f"{status}\t{expiry_date}\t\t{serial_hex}\tunknown\t{subject_str}\n"

            with open(self.index_file, 'a') as f:
                f.write(index_line)

        except Exception as e:
            logger.error(f"Failed to update certificate index: {e}")
            raise

    def _update_certificate_index_status(self, serial_number, new_status, reason=None):
        """Set the status letter of the index entry for ``serial_number``.

        When the new status is ``'R'``, the current time is always stored as
        the revocation date. If ``reason`` is a recognised
        ``x509.ReasonFlags`` value it is appended as ``date,reason``.

        Returns:
            bool: True if an entry was found and updated, False otherwise.

        Raises:
            Exception: any read/write error; logged and re-raised.
        """
        try:
            target = self._serial_to_int(serial_number)

            with open(self.index_file, 'r') as f:
                lines = f.readlines()

            for i, line in enumerate(lines):
                parts = self._split_index_line(line)
                if not self._entry_has_serial(parts, target):
                    continue

                parts[0] = new_status
                if new_status == 'R':
                    # _generate_crl needs a revocation date for every revoked
                    # entry, so record it regardless of whether a reason was given.
                    revoked_at = datetime.utcnow().strftime(self.INDEX_TIME_FORMAT)
                    flag = self._reason_flag(reason)
                    parts[2] = f"{revoked_at},{flag.value}" if flag is not None else revoked_at

                lines[i] = '\t'.join(parts) + '\n'
                self._write_index_lines(lines)
                return True

            return False

        except Exception as e:
            logger.error(f"Failed to update certificate status: {e}")
            raise

    # ------------------------------------------------------------------
    # CRL
    # ------------------------------------------------------------------

    def _generate_crl(self):
        """Build, sign and write the CRL from every revoked entry in the index.

        The CRL is valid for 30 days.

        Raises:
            Exception: any error (including an unparseable serial); logged and re-raised.
        """
        try:
            logger.info("Generating Certificate Revocation List")

            # Load CA certificate and private key
            with open(self.ca_cert_path, 'rb') as f:
                ca_cert = x509.load_pem_x509_certificate(f.read())

            with open(self.ca_key_path, 'rb') as f:
                ca_private_key = serialization.load_pem_private_key(f.read(), password=None)

            # Get revoked certificates
            revoked_certs = []
            with open(self.index_file, 'r') as f:
                for line in f:
                    parts = self._split_index_line(line)
                    if len(parts) < 4 or parts[0] != 'R':
                        continue

                    serial_number = int(parts[3], 16)
                    date_str, _, reason_str = parts[2].partition(',')
                    try:
                        revocation_date = datetime.strptime(date_str, self.INDEX_TIME_FORMAT)
                    except ValueError:
                        # A revoked certificate must still be listed even if its
                        # recorded date is missing or corrupt.
                        logger.warning(
                            f"Invalid revocation date {date_str!r} for serial {parts[3]}; using current time"
                        )
                        revocation_date = datetime.utcnow()

                    entry = (
                        x509.RevokedCertificateBuilder()
                        .serial_number(serial_number)
                        .revocation_date(revocation_date)
                    )
                    reason_flag = self._reason_flag(reason_str)
                    # RFC 5280 5.3.1: omit the reason code instead of using "unspecified".
                    if reason_flag is not None and reason_flag != x509.ReasonFlags.unspecified:
                        entry = entry.add_extension(x509.CRLReason(reason_flag), critical=False)
                    revoked_certs.append(entry.build())

            # Build CRL
            now = datetime.utcnow()
            crl_builder = (
                x509.CertificateRevocationListBuilder()
                .issuer_name(ca_cert.subject)
                .last_update(now)
                .next_update(now + timedelta(days=30))
            )

            for revoked_cert in revoked_certs:
                crl_builder = crl_builder.add_revoked_certificate(revoked_cert)

            crl = crl_builder.sign(ca_private_key, hashes.SHA256())

            # Save CRL
            with open(self.crl_path, 'wb') as f:
                f.write(crl.public_bytes(serialization.Encoding.PEM))

            logger.info("CRL generated successfully")

        except Exception as e:
            logger.error(f"Failed to generate CRL: {e}")
            raise

    def get_ca_certificate(self):
        """Return the CA certificate as PEM bytes, or None on error (logged)."""
        try:
            with open(self.ca_cert_path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"Failed to read CA certificate: {e}")
            return None

    def get_crl(self):
        """Return the CRL as PEM bytes, generating it first if it does not exist.

        Returns None on error (logged).
        """
        try:
            if not os.path.exists(self.crl_path):
                # Generate CRL if it doesn't exist
                self._generate_crl()
            with open(self.crl_path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"Failed to read CRL: {e}")
            return None

    # ------------------------------------------------------------------
    # Verification and housekeeping
    # ------------------------------------------------------------------

    def verify_certificate(self, cert_pem):
        """Verify a PEM certificate against this CA.

        Checks, in order:
        1. the issuer name matches the CA subject;
        2. the CA signature is valid;
        3. the current time is within the validity period;
        4. the index does not mark the certificate as revoked.

        Returns:
            (bool, str): whether the certificate is valid, and a message.
            Unexpected errors are logged and returned as ``(False, str(e))``.
        """
        try:
            # Load certificates
            cert = x509.load_pem_x509_certificate(cert_pem)

            with open(self.ca_cert_path, 'rb') as f:
                ca_cert = x509.load_pem_x509_certificate(f.read())

            if cert.issuer != ca_cert.subject:
                return False, "Certificate not issued by this CA"

            # Verify signature. The CA key is RSA (see _create_root_ca) and
            # certificates are signed with cryptography's default RSA padding
            # for CertificateBuilder.sign, PKCS#1 v1.5.
            try:
                ca_cert.public_key().verify(
                    cert.signature,
                    cert.tbs_certificate_bytes,
                    padding.PKCS1v15(),
                    cert.signature_hash_algorithm,
                )
            except InvalidSignature:
                return False, "Invalid certificate signature"

            # Check validity period
            now = datetime.utcnow()
            if now < cert.not_valid_before or now > cert.not_valid_after:
                return False, "Certificate expired or not yet valid"

            # Check revocation status
            status = self.get_certificate_status(cert.serial_number)
            if status == 'revoked':
                return False, "Certificate revoked"

            return True, "Certificate valid"

        except Exception as e:
            logger.error(f"Certificate verification failed: {e}")
            return False, str(e)

    def cleanup_expired_certificates(self):
        """Mark valid index entries whose expiry date has passed as expired ('E').

        Best-effort: errors are logged, not raised.
        """
        try:
            logger.info("Cleaning up expired certificates")

            with open(self.index_file, 'r') as f:
                lines = f.readlines()

            now = datetime.utcnow()
            updated_count = 0

            for i, line in enumerate(lines):
                parts = self._split_index_line(line)
                if len(parts) < 2 or parts[0] != 'V':
                    continue
                try:
                    expiry_date = datetime.strptime(parts[1], self.INDEX_TIME_FORMAT)
                except ValueError:
                    continue
                if now > expiry_date:
                    # Mark as expired
                    parts[0] = 'E'
                    lines[i] = '\t'.join(parts) + '\n'
                    updated_count += 1

            if updated_count > 0:
                self._write_index_lines(lines)

            logger.info(f"Marked {updated_count} certificates as expired")

        except Exception as e:
            logger.error(f"Failed to cleanup expired certificates: {e}")

    def get_statistics(self):
        """Return counts of total, valid, revoked and expired certificates in the index.

        All counts are 0 if the index cannot be read (the error is logged).
        """
        stats = {
            'total_certificates': 0,
            'valid_certificates': 0,
            'revoked_certificates': 0,
            'expired_certificates': 0
        }
        status_keys = {
            'V': 'valid_certificates',
            'R': 'revoked_certificates',
            'E': 'expired_certificates',
        }
        try:
            with open(self.index_file, 'r') as f:
                for line in f:
                    parts = self._split_index_line(line)
                    if not parts:
                        continue
                    stats['total_certificates'] += 1
                    key = status_keys.get(parts[0])
                    if key is not None:
                        stats[key] += 1
            return stats

        except Exception as e:
            logger.error(f"Failed to get CA statistics: {e}")
            return {
                'total_certificates': 0,
                'valid_certificates': 0,
                'revoked_certificates': 0,
                'expired_certificates': 0
            }