"""
Mutual TLS (mTLS) Implementation for Service-to-Service Authentication

This module provides mTLS support for secure service-to-service
communication within the Zenith Fraud Detection Platform.

Features:
- Certificate generation and management (internal self-signed CA)
- Client/Server mTLS authentication
- Certificate rotation and validation
- Helpers for verifying client certificates from FastAPI middleware
"""

import hashlib
import logging
import os
import ssl
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

logger = logging.getLogger(__name__)


def _write_private_file(path: Path, data: bytes) -> None:
    """Write ``data`` to ``path`` with owner-only (0o600) permissions.

    The file is opened with mode 0o600 and re-chmodded *before* any bytes are
    written, so key material is never briefly readable by other users (which
    happens when writing first and chmodding afterwards).
    """
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "wb") as f:
        # os.open's mode only applies when the file is created; also tighten a
        # pre-existing file that may have broader permissions. O_TRUNC has
        # already emptied it, so nothing old is exposed in the meantime.
        os.chmod(path, 0o600)
        f.write(data)


@dataclass
class CertificateConfig:
    """Configuration for certificate generation (subject fields, key, validity)."""

    common_name: str
    organization: str = "Zenith Fraud Detection"
    organizational_unit: str = "Engineering"
    locality: str = "San Francisco"
    province: str = "CA"
    country: str = "US"
    validity_days: int = 365  # lifetime counted from the moment of issuance
    key_size: int = 4096  # RSA modulus size in bits


@dataclass
class ServiceCertificate:
    """Container for service certificate data."""

    service_name: str
    certificate_pem: bytes
    private_key_pem: bytes  # unencrypted PEM private key
    certificate_hash: str  # hex SHA-256 digest of certificate_pem
    expires_at: datetime  # timezone-aware UTC (from not_valid_after_utc)
    issued_at: datetime  # timezone-aware UTC (from not_valid_before_utc)


class MTLSCertificateManager:
    """
    Manages mTLS certificates for service-to-service authentication.

    Provides certificate generation, signing, and validation for secure
    internal communication between services.

    The CA is loaded from ``ca_cert_path``/``ca_key_path`` when the CA
    certificate file exists; otherwise a new CA is generated in memory and,
    if ``ca_cert_path`` is set, persisted. Service certificates are cached in
    memory and stored under ``certs_dir`` as ``<service>.crt``/``<service>.key``.
    """

    def __init__(
        self,
        certs_dir: str = "/app/certs",
        ca_cert_path: Optional[str] = None,
        ca_key_path: Optional[str] = None,
    ):
        self.certs_dir = Path(certs_dir)
        self.certs_dir.mkdir(parents=True, exist_ok=True)
        self.ca_cert_path = Path(ca_cert_path) if ca_cert_path else None
        self.ca_key_path = Path(ca_key_path) if ca_key_path else None
        self._service_certs: Dict[str, ServiceCertificate] = {}

        # Load or generate CA
        self._initialize_ca()

    def _initialize_ca(self) -> None:
        """Load the CA from disk if its certificate exists, else generate one.

        Raises:
            ValueError: the CA certificate exists but no CA key path was given
                (the key is required to sign service certificates).
        """
        if self.ca_cert_path and self.ca_cert_path.exists():
            if self.ca_key_path is None:
                raise ValueError(
                    f"CA certificate {self.ca_cert_path} exists but no CA key "
                    "path was configured; both are required"
                )
            logger.info("Loading existing CA from %s", self.ca_cert_path)
            self._ca_cert_pem = self.ca_cert_path.read_bytes()
            self._ca_key_pem = self.ca_key_path.read_bytes()
            # Parse the PEM data: signing and validation use the objects, not
            # the raw bytes, so a loaded CA must populate them too.
            self._ca_cert = x509.load_pem_x509_certificate(
                self._ca_cert_pem, default_backend()
            )
            self._ca_key = serialization.load_pem_private_key(
                self._ca_key_pem, password=None, backend=default_backend()
            )
        else:
            logger.info("Generating new CA certificate")
            self._generate_ca()
            if self.ca_cert_path:
                self._save_ca()

    def _generate_ca(self) -> None:
        """Generate a new self-signed CA certificate (10-year validity) in memory."""
        # Generate CA private key
        self._ca_key = rsa.generate_private_key(
            public_exponent=65537, key_size=4096, backend=default_backend()
        )

        # Generate CA certificate
        subject = issuer = x509.Name(
            [
                x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
                x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Zenith Fraud Detection"),
                x509.NameAttribute(NameOID.COMMON_NAME, "Zenith Internal CA"),
            ]
        )

        now = datetime.now(timezone.utc)
        self._ca_cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(self._ca_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(now + timedelta(days=3650))  # 10 years
            .add_extension(
                # path_length=0: this CA may only issue end-entity certificates.
                x509.BasicConstraints(ca=True, path_length=0),
                critical=True,
            )
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_encipherment=False,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    key_cert_sign=True,
                    crl_sign=True,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .add_extension(
                # CA certificates must carry an SKI (RFC 5280); OpenSSL's
                # X509_STRICT verification rejects CAs without one.
                x509.SubjectKeyIdentifier.from_public_key(self._ca_key.public_key()),
                critical=False,
            )
            .sign(self._ca_key, hashes.SHA256(), default_backend())
        )

        self._ca_cert_pem = self._ca_cert.public_bytes(serialization.Encoding.PEM)
        self._ca_key_pem = self._ca_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )

    def _save_ca(self) -> None:
        """Save the CA certificate and (owner-only readable) CA key to disk."""
        if self.ca_cert_path:
            self.ca_cert_path.write_bytes(self._ca_cert_pem)
            logger.info("CA certificate saved to %s", self.ca_cert_path)

        if self.ca_key_path:
            _write_private_file(self.ca_key_path, self._ca_key_pem)
            logger.info("CA key saved to %s", self.ca_key_path)

    def _service_cert_paths(self, service_name: str) -> Tuple[Path, Path]:
        """Return the on-disk ``(certificate_path, key_path)`` for a service."""
        return (
            self.certs_dir / f"{service_name}.crt",
            self.certs_dir / f"{service_name}.key",
        )

    def _authority_key_identifier(self) -> x509.AuthorityKeyIdentifier:
        """Build the AKI for certificates signed by this CA.

        Reuses the CA's own SKI when present so the identifiers match exactly
        (a loaded CA may use a non-default SKI method); otherwise derives one
        from the CA public key.
        """
        try:
            ski = self._ca_cert.extensions.get_extension_for_class(
                x509.SubjectKeyIdentifier
            ).value
        except x509.ExtensionNotFound:
            return x509.AuthorityKeyIdentifier.from_issuer_public_key(
                self._ca_cert.public_key()
            )
        return x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(ski)

    def generate_service_certificate(
        self,
        service_name: str,
        config: Optional[CertificateConfig] = None,
    ) -> ServiceCertificate:
        """
        Generate a new CA-signed certificate for a service.

        The result is cached in memory and written to ``certs_dir``
        (key file with 0o600 permissions), replacing any previous one.

        Args:
            service_name: Name of the service (also used for the SAN DNS names
                ``<name>``, ``<name>.internal`` and ``<name>.zenith.local``)
            config: Certificate configuration (uses defaults if None)

        Returns:
            ServiceCertificate with cert, key, and metadata
        """
        if config is None:
            config = CertificateConfig(common_name=service_name)

        # Generate service private key
        service_key = rsa.generate_private_key(
            public_exponent=65537, key_size=config.key_size, backend=default_backend()
        )

        subject = x509.Name(
            [
                x509.NameAttribute(NameOID.COUNTRY_NAME, config.country),
                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, config.province),
                x509.NameAttribute(NameOID.LOCALITY_NAME, config.locality),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, config.organization),
                x509.NameAttribute(
                    NameOID.ORGANIZATIONAL_UNIT_NAME, config.organizational_unit
                ),
                x509.NameAttribute(NameOID.COMMON_NAME, config.common_name),
                x509.NameAttribute(NameOID.SERIAL_NUMBER, f"service-{service_name}"),
            ]
        )

        # Sign with CA
        now = datetime.now(timezone.utc)
        service_cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(self._ca_cert.subject)
            .public_key(service_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(now + timedelta(days=config.validity_days))
            .add_extension(
                x509.BasicConstraints(ca=False, path_length=None),
                critical=True,
            )
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_encipherment=True,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    key_cert_sign=False,
                    crl_sign=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .add_extension(
                # Same certificate is used both as TLS server and TLS client.
                x509.ExtendedKeyUsage(
                    [
                        x509.oid.ExtendedKeyUsageOID.SERVER_AUTH,
                        x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
                    ]
                ),
                critical=False,
            )
            .add_extension(
                x509.SubjectAlternativeName(
                    [
                        x509.DNSName(service_name),
                        x509.DNSName(f"{service_name}.internal"),
                        x509.DNSName(f"{service_name}.zenith.local"),
                    ]
                ),
                critical=False,
            )
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(service_key.public_key()),
                critical=False,
            )
            .add_extension(
                # Required on non-self-signed certs by OpenSSL X509_STRICT mode.
                self._authority_key_identifier(),
                critical=False,
            )
            .sign(self._ca_key, hashes.SHA256(), default_backend())
        )

        # Serialize certificate and key
        cert_pem = service_cert.public_bytes(serialization.Encoding.PEM)
        key_pem = service_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )

        service_cert_data = ServiceCertificate(
            service_name=service_name,
            certificate_pem=cert_pem,
            private_key_pem=key_pem,
            certificate_hash=hashlib.sha256(cert_pem).hexdigest(),
            expires_at=service_cert.not_valid_after_utc,
            issued_at=service_cert.not_valid_before_utc,
        )

        self._service_certs[service_name] = service_cert_data
        self._save_service_cert(service_name, service_cert_data)

        logger.info("Generated mTLS certificate for service: %s", service_name)
        return service_cert_data

    def _save_service_cert(self, service_name: str, cert_data: ServiceCertificate) -> None:
        """Save service certificate and (owner-only readable) key to disk."""
        cert_path, key_path = self._service_cert_paths(service_name)
        cert_path.write_bytes(cert_data.certificate_pem)
        _write_private_file(key_path, cert_data.private_key_pem)
        logger.info("Service certificate saved to %s", cert_path)

    def get_service_certificate(
        self, service_name: str
    ) -> Optional[ServiceCertificate]:
        """Retrieve a service certificate from cache or disk.

        Returns:
            The cached certificate, else one loaded (and cached) from
            ``certs_dir`` when both its ``.crt`` and ``.key`` files exist,
            else None.
        """
        cached = self._service_certs.get(service_name)
        if cached is not None:
            return cached

        cert_path, key_path = self._service_cert_paths(service_name)
        if not (cert_path.exists() and key_path.exists()):
            return None

        cert_pem = cert_path.read_bytes()
        key_pem = key_path.read_bytes()

        # Parse the certificate to recover its validity window.
        cert = x509.load_pem_x509_certificate(cert_pem, default_backend())

        cert_data = ServiceCertificate(
            service_name=service_name,
            certificate_pem=cert_pem,
            private_key_pem=key_pem,
            certificate_hash=hashlib.sha256(cert_pem).hexdigest(),
            expires_at=cert.not_valid_after_utc,
            issued_at=cert.not_valid_before_utc,
        )
        self._service_certs[service_name] = cert_data
        return cert_data

    def get_ca_certificate_pem(self) -> bytes:
        """Get the CA certificate in PEM format."""
        return self._ca_cert_pem

    def validate_certificate(self, cert_pem: bytes) -> Dict[str, Any]:
        """
        Validate a PEM certificate against this CA.

        ``signature_valid`` is True only when the certificate's issuer name
        matches the CA subject and its signature verifies with the CA key.

        Returns:
            Dict with validation results; on parse failure only
            ``{"valid": False, "error": <message>}``.
        """
        try:
            cert = x509.load_pem_x509_certificate(cert_pem, default_backend())

            # Aware "now": not_valid_*_utc are timezone-aware, and comparing
            # them with a naive datetime raises TypeError.
            now = datetime.now(timezone.utc)
            is_expired = now > cert.not_valid_after_utc
            not_yet_valid = now < cert.not_valid_before_utc

            # Checks issuer-name match and the signature with the correct
            # padding / hash for the CA key type.
            try:
                cert.verify_directly_issued_by(self._ca_cert)
                signature_valid = True
            except (InvalidSignature, ValueError, TypeError):
                signature_valid = False

            return {
                "valid": not is_expired and not not_yet_valid and signature_valid,
                "expired": is_expired,
                "not_yet_valid": not_yet_valid,
                "signature_valid": signature_valid,
                "subject": cert.subject.rfc4514_string(),
                "issuer": cert.issuer.rfc4514_string(),
                "serial_number": hex(cert.serial_number),
                "not_before": cert.not_valid_before_utc.isoformat(),
                "not_after": cert.not_valid_after_utc.isoformat(),
            }

        except Exception as e:
            # Boundary for untrusted input: report as invalid, never raise.
            logger.error("Certificate validation failed: %s", e)
            return {
                "valid": False,
                "error": str(e),
            }

    def rotate_service_certificate(self, service_name: str) -> ServiceCertificate:
        """Rotate (regenerate with default config) a service certificate."""
        logger.info("Rotating mTLS certificate for service: %s", service_name)

        # Drop the cached copy; generation overwrites the files on disk.
        self._service_certs.pop(service_name, None)

        return self.generate_service_certificate(service_name)

    def create_ssl_context(
        self, service_name: str, verify_clients: bool = True
    ) -> ssl.SSLContext:
        """
        Create a TLS 1.3 server SSL context for a service.

        Args:
            service_name: Name of the service (certificate is generated if absent)
            verify_clients: Whether to require client certificates signed by the CA

        Returns:
            Configured SSL context
        """
        cert_data = self.get_service_certificate(service_name)
        if not cert_data:
            cert_data = self.generate_service_certificate(service_name)

        cert_path, key_path = self._service_cert_paths(service_name)
        if not (cert_path.exists() and key_path.exists()):
            # The cached copy outlived its files; ssl can only load a
            # certificate chain from disk, so restore them.
            self._save_service_cert(service_name, cert_data)

        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # TLS 1.3 only. This replaces the deprecated ssl.OP_NO_SSL*/OP_NO_TLS*
        # flags, which emit DeprecationWarning on Python 3.10+.
        ctx.minimum_version = ssl.TLSVersion.TLSv1_3

        # load_cert_chain() takes file paths, not PEM bytes.
        ctx.load_cert_chain(certfile=str(cert_path), keyfile=str(key_path))

        if verify_clients:
            # Require and verify client certificates
            ctx.verify_mode = ssl.CERT_REQUIRED
            # cadata must be a str for PEM input; bytes are parsed as DER.
            ctx.load_verify_locations(cadata=self._ca_cert_pem.decode("ascii"))

        ctx.options |= ssl.OP_CIPHER_SERVER_PREFERENCE
        # Only affects TLS <= 1.2 suites; kept in case minimum_version is lowered.
        ctx.set_ciphers("ECDHE+AESGCM:DHE+AESGCM:ECDHE+CHACHA20:DHE+CHACHA20")

        return ctx


class MTLSAuthMiddleware:
    """
    mTLS client-authentication helper for use from FastAPI middleware.

    Validates client certificates against the manager's CA and keeps an
    in-memory registry of trusted service names.
    """

    def __init__(self, cert_manager: MTLSCertificateManager):
        self.cert_manager = cert_manager
        self._trusted_services: Dict[str, bool] = {}

    def verify_client_cert(self, cert_pem: bytes) -> Optional[Dict[str, Any]]:
        """Return the validation details if ``cert_pem`` is valid, else None."""
        result = self.cert_manager.validate_certificate(cert_pem)
        return result if result["valid"] else None

    def register_trusted_service(self, service_name: str) -> None:
        """Register a service as trusted."""
        self._trusted_services[service_name] = True
        logger.info("Service registered as trusted: %s", service_name)

    def is_trusted_service(self, service_name: str) -> bool:
        """Check if a service is trusted."""
        return self._trusted_services.get(service_name, False)


# Global certificate manager instance (created lazily by get_mtls_manager)
mtls_manager: Optional[MTLSCertificateManager] = None


def get_mtls_manager() -> MTLSCertificateManager:
    """Get or create the global mTLS manager.

    Configured from MTLS_CERTS_DIR (default "/app/certs"), MTLS_CA_CERT and
    MTLS_CA_KEY on first call. No locking is done here.
    """
    global mtls_manager
    if mtls_manager is None:
        mtls_manager = MTLSCertificateManager(
            certs_dir=os.getenv("MTLS_CERTS_DIR", "/app/certs"),
            ca_cert_path=os.getenv("MTLS_CA_CERT"),
            ca_key_path=os.getenv("MTLS_CA_KEY"),
        )
    return mtls_manager


def setup_mtls_service(service_name: str) -> ssl.SSLContext:
    """
    Set up mTLS for a service.

    Args:
        service_name: Name of the service

    Returns:
        Configured SSL context (client certificates required)
    """
    return get_mtls_manager().create_ssl_context(service_name)


def generate_service_credentials(service_name: str) -> Dict[str, Any]:
    """
    Generate (or reuse existing) mTLS credentials for a service.

    Args:
        service_name: Name of the service

    Returns:
        Dict with certificate, private key, and CA certificate (PEM strings),
        the certificate hash, and the ISO-8601 expiry
    """
    manager = get_mtls_manager()

    cert_data = manager.get_service_certificate(service_name)
    if not cert_data:
        cert_data = manager.generate_service_certificate(service_name)

    return {
        "service_name": service_name,
        "certificate": cert_data.certificate_pem.decode("utf-8"),
        "private_key": cert_data.private_key_pem.decode("utf-8"),
        "ca_certificate": manager.get_ca_certificate_pem().decode("utf-8"),
        "certificate_hash": cert_data.certificate_hash,
        "expires_at": cert_data.expires_at.isoformat(),
    }


# Example usage and configuration
MTLS_SERVICES = [
    "backend-api",
    "frontend-app",
    "cache-service",
    "ml-service",
    "notification-service",
    "audit-service",
]


def initialize_mtls_for_all_services(rotation_window_days: int = 30) -> None:
    """Ensure every service in MTLS_SERVICES has a certificate.

    Missing certificates are generated; existing ones expiring within
    ``rotation_window_days`` are rotated. Failures are logged per service
    and do not stop the loop.
    """
    manager = get_mtls_manager()
    # Aware cutoff: expires_at comes from not_valid_after_utc (tz-aware).
    rotate_before = datetime.now(timezone.utc) + timedelta(days=rotation_window_days)

    for service_name in MTLS_SERVICES:
        try:
            cert_data = manager.get_service_certificate(service_name)
            if not cert_data:
                manager.generate_service_certificate(service_name)
                logger.info("Generated mTLS certificate for %s", service_name)
            elif cert_data.expires_at < rotate_before:
                manager.rotate_service_certificate(service_name)
                logger.info("Rotated mTLS certificate for %s", service_name)
        except Exception as e:
            logger.exception("Failed to generate certificate for %s: %s", service_name, e)


__all__ = [
    "MTLSCertificateManager",
    "MTLSAuthMiddleware",
    "CertificateConfig",
    "ServiceCertificate",
    "get_mtls_manager",
    "setup_mtls_service",
    "generate_service_credentials",
    "initialize_mtls_for_all_services",
]