Spaces:
Paused
Paused
"""
Mutual TLS (mTLS) Implementation for Service-to-Service Authentication

This module provides mTLS support for secure service-to-service communication
within the Zenith Fraud Detection Platform.

Features:
- Certificate Authority and service certificate generation and management
- Server-side mTLS contexts that require and verify client certificates
- Certificate rotation and validation
- Helpers for integrating certificate checks with FastAPI services
"""
import hashlib
import logging
import os
import ssl
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Dict, Optional

from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

logger = logging.getLogger(__name__)
@dataclass
class CertificateConfig:
    """Configuration for certificate generation.

    Only ``common_name`` is required; the remaining fields populate the
    certificate subject and control validity period and RSA key size.

    Raises:
        ValueError: If ``validity_days`` is not positive.
    """

    common_name: str
    organization: str = "Zenith Fraud Detection"
    organizational_unit: str = "Engineering"
    locality: str = "San Francisco"
    province: str = "CA"
    country: str = "US"
    # Number of days from issuance until the certificate expires.
    validity_days: int = 365
    # RSA modulus size in bits for the generated service key.
    key_size: int = 4096

    def __post_init__(self):
        # A non-positive validity would yield a certificate that is already
        # expired (or invalid) at issuance time.
        if self.validity_days <= 0:
            raise ValueError(
                f"validity_days must be positive, got {self.validity_days}"
            )
@dataclass
class ServiceCertificate:
    """Container for service certificate data.

    ``private_key_pem`` is excluded from ``repr`` so that logging or printing
    the object does not leak the private key.
    """

    service_name: str
    certificate_pem: bytes
    # Unencrypted PEM private key; hidden from repr() to avoid accidental leaks.
    private_key_pem: bytes = field(repr=False)
    # Hex SHA-256 digest of ``certificate_pem``.
    certificate_hash: str
    # Certificate validity bounds (the manager stores timezone-aware UTC values).
    expires_at: datetime
    issued_at: datetime
class MTLSCertificateManager:
    """
    Manages mTLS certificates for service-to-service authentication.

    Loads or generates a private Certificate Authority, issues CA-signed
    service certificates, persists them under ``certs_dir``, validates
    presented certificates against the CA, and builds server-side
    ``ssl.SSLContext`` objects.
    """

    def __init__(
        self,
        certs_dir: str = "/app/certs",
        ca_cert_path: Optional[str] = None,
        ca_key_path: Optional[str] = None,
    ):
        """
        Args:
            certs_dir: Directory for ``<service>.crt`` / ``<service>.key``
                files; created if it does not exist.
            ca_cert_path: Optional CA certificate path (PEM). If the file
                exists the CA is loaded from it; otherwise a newly generated
                CA is written there.
            ca_key_path: Optional CA private-key path (PEM). Must exist when
                loading an existing CA.
        """
        self.certs_dir = Path(certs_dir)
        self.certs_dir.mkdir(parents=True, exist_ok=True)
        self.ca_cert_path = Path(ca_cert_path) if ca_cert_path else None
        self.ca_key_path = Path(ca_key_path) if ca_key_path else None
        # In-memory cache of issued/loaded certificates keyed by service name.
        self._service_certs: Dict[str, ServiceCertificate] = {}
        # Load or generate CA
        self._initialize_ca()

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        ``cryptography``'s ``not_valid_*_utc`` accessors return aware
        datetimes; comparing them with naive ``datetime.utcnow()`` raises
        TypeError.
        """
        return datetime.now(timezone.utc)

    @staticmethod
    def _write_private_file(path: Path, data: bytes) -> None:
        """Write ``data`` to ``path`` with owner-only (0o600) permissions.

        The file is created with mode 0o600 and re-chmodded before any bytes
        are written, so key material never lands in a file with looser
        permissions (e.g. a pre-existing file or a permissive umask).
        """
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "wb") as f:
            os.chmod(path, 0o600)
            f.write(data)

    def _initialize_ca(self):
        """Initialize or load the Certificate Authority.

        Populates both the parsed CA objects (``_ca_cert`` / ``_ca_key``),
        which signing and validation use, and their PEM encodings.

        Raises:
            FileNotFoundError: If the CA certificate exists but the CA key
                path is unset or the key file is missing.
        """
        if self.ca_cert_path and self.ca_cert_path.exists():
            if self.ca_key_path is None or not self.ca_key_path.exists():
                raise FileNotFoundError(
                    f"CA certificate found at {self.ca_cert_path} but CA key "
                    f"is missing (ca_key_path={self.ca_key_path})"
                )
            logger.info(f"Loading existing CA from {self.ca_cert_path}")
            self._ca_cert_pem = self.ca_cert_path.read_bytes()
            self._ca_key_pem = self.ca_key_path.read_bytes()
            # Parse the PEMs: generate_service_certificate and
            # validate_certificate need the objects, not just the bytes.
            self._ca_cert = x509.load_pem_x509_certificate(
                self._ca_cert_pem, default_backend()
            )
            self._ca_key = serialization.load_pem_private_key(
                self._ca_key_pem, password=None, backend=default_backend()
            )
        else:
            logger.info("Generating new CA certificate")
            self._generate_ca()
            if self.ca_cert_path:
                self._save_ca()
            else:
                logger.warning(
                    "No CA certificate path configured; the generated CA is "
                    "kept in memory only and will change on restart"
                )

    def _generate_ca(self):
        """Generate a new self-signed CA certificate (10-year validity)."""
        # Generate CA private key
        self._ca_key = rsa.generate_private_key(
            public_exponent=65537, key_size=4096, backend=default_backend()
        )
        # Generate CA certificate
        subject = issuer = x509.Name(
            [
                x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
                x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Zenith Fraud Detection"),
                x509.NameAttribute(NameOID.COMMON_NAME, "Zenith Internal CA"),
            ]
        )
        now = self._utcnow()
        self._ca_cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(self._ca_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(now + timedelta(days=3650))  # 10 years
            .add_extension(
                # path_length=0: this CA may sign leaf certificates only.
                x509.BasicConstraints(ca=True, path_length=0),
                critical=True,
            )
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_encipherment=False,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    key_cert_sign=True,
                    crl_sign=True,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .sign(self._ca_key, hashes.SHA256(), default_backend())
        )
        self._ca_cert_pem = self._ca_cert.public_bytes(serialization.Encoding.PEM)
        self._ca_key_pem = self._ca_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )

    def _save_ca(self):
        """Save CA certificate and key to disk (key with 0o600 permissions)."""
        if self.ca_cert_path:
            self.ca_cert_path.write_bytes(self._ca_cert_pem)
            logger.info(f"CA certificate saved to {self.ca_cert_path}")
        if self.ca_key_path:
            self._write_private_file(self.ca_key_path, self._ca_key_pem)
            logger.info(f"CA key saved to {self.ca_key_path}")
        else:
            # Without the key, a persisted CA cert cannot be reloaded later.
            logger.warning("No CA key path configured; CA key was not persisted")

    def generate_service_certificate(
        self,
        service_name: str,
        config: Optional[CertificateConfig] = None,
    ) -> ServiceCertificate:
        """
        Generate a new CA-signed certificate for a service.

        The certificate is usable for both server and client auth, carries
        DNS SANs ``<service>``, ``<service>.internal`` and
        ``<service>.zenith.local``, is cached in memory and written to
        ``certs_dir``.

        Args:
            service_name: Name of the service
            config: Certificate configuration (uses defaults if None)

        Returns:
            ServiceCertificate with cert, key, and metadata
        """
        if config is None:
            config = CertificateConfig(
                common_name=service_name,
            )
        # Generate service private key
        service_key = rsa.generate_private_key(
            public_exponent=65537, key_size=config.key_size, backend=default_backend()
        )
        subject = x509.Name(
            [
                x509.NameAttribute(NameOID.COUNTRY_NAME, config.country),
                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, config.province),
                x509.NameAttribute(NameOID.LOCALITY_NAME, config.locality),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, config.organization),
                x509.NameAttribute(
                    NameOID.ORGANIZATIONAL_UNIT_NAME, config.organizational_unit
                ),
                x509.NameAttribute(NameOID.COMMON_NAME, config.common_name),
                x509.NameAttribute(NameOID.SERIAL_NUMBER, f"service-{service_name}"),
            ]
        )
        now = self._utcnow()
        # Sign with CA
        service_cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(self._ca_cert.subject)
            .public_key(service_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(now + timedelta(days=config.validity_days))
            .add_extension(
                x509.BasicConstraints(ca=False, path_length=None),
                critical=True,
            )
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_encipherment=True,
                    content_commitment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    key_cert_sign=False,
                    crl_sign=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .add_extension(
                x509.ExtendedKeyUsage(
                    [
                        x509.oid.ExtendedKeyUsageOID.SERVER_AUTH,
                        x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
                    ]
                ),
                critical=False,
            )
            .add_extension(
                x509.SubjectAlternativeName(
                    [
                        x509.DNSName(service_name),
                        x509.DNSName(f"{service_name}.internal"),
                        x509.DNSName(f"{service_name}.zenith.local"),
                    ]
                ),
                critical=False,
            )
            .sign(self._ca_key, hashes.SHA256(), default_backend())
        )
        # Serialize certificate and key
        cert_pem = service_cert.public_bytes(serialization.Encoding.PEM)
        key_pem = service_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
        service_cert_data = ServiceCertificate(
            service_name=service_name,
            certificate_pem=cert_pem,
            private_key_pem=key_pem,
            certificate_hash=hashlib.sha256(cert_pem).hexdigest(),
            expires_at=service_cert.not_valid_after_utc,
            issued_at=service_cert.not_valid_before_utc,
        )
        self._service_certs[service_name] = service_cert_data
        self._save_service_cert(service_name, service_cert_data)
        logger.info(f"Generated mTLS certificate for service: {service_name}")
        return service_cert_data

    def _save_service_cert(self, service_name: str, cert_data: ServiceCertificate):
        """Save service certificate and key to ``certs_dir`` (key as 0o600)."""
        cert_path = self.certs_dir / f"{service_name}.crt"
        key_path = self.certs_dir / f"{service_name}.key"
        cert_path.write_bytes(cert_data.certificate_pem)
        self._write_private_file(key_path, cert_data.private_key_pem)
        logger.info(f"Service certificate saved to {cert_path}")

    def get_service_certificate(
        self, service_name: str
    ) -> Optional[ServiceCertificate]:
        """Retrieve a service certificate from cache or disk.

        Returns:
            The cached certificate, else one loaded from ``certs_dir`` (and
            then cached), else ``None`` if the cert or key file is missing.
        """
        if service_name in self._service_certs:
            return self._service_certs[service_name]
        cert_path = self.certs_dir / f"{service_name}.crt"
        key_path = self.certs_dir / f"{service_name}.key"
        if not (cert_path.exists() and key_path.exists()):
            return None
        cert_pem = cert_path.read_bytes()
        key_pem = key_path.read_bytes()
        # Load and parse certificate to get validity bounds
        cert = x509.load_pem_x509_certificate(cert_pem, default_backend())
        cert_data = ServiceCertificate(
            service_name=service_name,
            certificate_pem=cert_pem,
            private_key_pem=key_pem,
            certificate_hash=hashlib.sha256(cert_pem).hexdigest(),
            expires_at=cert.not_valid_after_utc,
            issued_at=cert.not_valid_before_utc,
        )
        self._service_certs[service_name] = cert_data
        return cert_data

    def get_ca_certificate_pem(self) -> bytes:
        """Get the CA certificate in PEM format."""
        return self._ca_cert_pem

    def validate_certificate(self, cert_pem: bytes) -> Dict[str, Any]:
        """
        Validate a certificate against the CA.

        ``signature_valid`` is True only if the certificate's issuer name
        matches the CA subject and its signature verifies with the CA public
        key.

        Returns:
            Dict with validation results; on parse failure only ``valid``
            (False) and ``error`` are present.
        """
        try:
            cert = x509.load_pem_x509_certificate(cert_pem, default_backend())
            # Aware "now" to match the aware not_valid_*_utc values.
            now = self._utcnow()
            is_expired = now > cert.not_valid_after_utc
            not_yet_valid = now < cert.not_valid_before_utc
            try:
                # Handles the key type's padding and the cert's own hash
                # algorithm, unlike a bare public_key().verify() call.
                cert.verify_directly_issued_by(self._ca_cert)
                signature_valid = True
            except (InvalidSignature, ValueError, TypeError):
                signature_valid = False
            return {
                "valid": not is_expired and not not_yet_valid and signature_valid,
                "expired": is_expired,
                "not_yet_valid": not_yet_valid,
                "signature_valid": signature_valid,
                "subject": cert.subject.rfc4514_string(),
                "issuer": cert.issuer.rfc4514_string(),
                "serial_number": hex(cert.serial_number),
                "not_before": cert.not_valid_before_utc.isoformat(),
                "not_after": cert.not_valid_after_utc.isoformat(),
            }
        except Exception as e:
            logger.error(f"Certificate validation failed: {e}")
            return {
                "valid": False,
                "error": str(e),
            }

    def rotate_service_certificate(self, service_name: str) -> ServiceCertificate:
        """Rotate (regenerate) a service certificate, replacing cache and files."""
        logger.info(f"Rotating mTLS certificate for service: {service_name}")
        self._service_certs.pop(service_name, None)
        return self.generate_service_certificate(service_name)

    def create_ssl_context(
        self, service_name: str, verify_clients: bool = True
    ) -> ssl.SSLContext:
        """
        Create a server-side SSL context for a service.

        Args:
            service_name: Name of the service
            verify_clients: Whether to require client certificates signed
                by this manager's CA

        Returns:
            Configured SSL context (TLS 1.3 minimum)
        """
        cert_data = self.get_service_certificate(service_name)
        if not cert_data:
            cert_data = self.generate_service_certificate(service_name)
        cert_path = self.certs_dir / f"{service_name}.crt"
        key_path = self.certs_dir / f"{service_name}.key"
        if not (cert_path.exists() and key_path.exists()):
            # load_cert_chain only accepts file paths; re-persist cached
            # material if the files were removed from disk.
            self._save_service_cert(service_name, cert_data)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(certfile=str(cert_path), keyfile=str(key_path))
        if verify_clients:
            # Require and verify client certificates
            ctx.verify_mode = ssl.CERT_REQUIRED
            # cadata must be a str for PEM; bytes are interpreted as DER.
            ctx.load_verify_locations(cadata=self._ca_cert_pem.decode("ascii"))
        # TLS 1.3 minimum also excludes SSLv2/3 and TLS 1.0/1.1, replacing the
        # deprecated OP_NO_* option flags.
        ctx.minimum_version = ssl.TLSVersion.TLSv1_3
        ctx.options |= ssl.OP_CIPHER_SERVER_PREFERENCE
        # Only affects TLS <= 1.2 cipher suites; kept in case the minimum
        # version is ever lowered.
        ctx.set_ciphers("ECDHE+AESGCM:DHE+AESGCM:ECDHE+CHACHA20:DHE+CHACHA20")
        return ctx
class MTLSAuthMiddleware:
    """
    Client-certificate checks intended for FastAPI service-to-service requests.

    Certificate validation is delegated to an ``MTLSCertificateManager``; the
    object also keeps an in-memory allow-list of trusted service names.
    """

    def __init__(self, cert_manager: MTLSCertificateManager):
        self.cert_manager = cert_manager
        # Allow-list: service name -> True once registered.
        self._trusted_services: Dict[str, bool] = {}

    def verify_client_cert(self, cert_pem: bytes) -> Optional[Dict[str, Any]]:
        """Return the manager's validation details if valid, otherwise ``None``."""
        outcome = self.cert_manager.validate_certificate(cert_pem)
        return outcome if outcome["valid"] else None

    def register_trusted_service(self, service_name: str):
        """Add ``service_name`` to the trusted allow-list."""
        trusted = self._trusted_services
        trusted[service_name] = True
        logger.info(f"Service registered as trusted: {service_name}")

    def is_trusted_service(self, service_name: str) -> bool:
        """Report whether ``service_name`` was registered as trusted."""
        trusted = self._trusted_services
        return trusted.get(service_name, False)
# Process-wide certificate manager, created lazily by get_mtls_manager().
mtls_manager: Optional[MTLSCertificateManager] = None


def get_mtls_manager() -> MTLSCertificateManager:
    """Return the shared manager, constructing it on first use.

    Configuration comes from the environment: ``MTLS_CERTS_DIR`` (default
    ``/app/certs``), ``MTLS_CA_CERT`` and ``MTLS_CA_KEY`` (both optional).
    """
    global mtls_manager
    if mtls_manager is not None:
        return mtls_manager
    mtls_manager = MTLSCertificateManager(
        certs_dir=os.getenv("MTLS_CERTS_DIR", "/app/certs"),
        ca_cert_path=os.getenv("MTLS_CA_CERT"),
        ca_key_path=os.getenv("MTLS_CA_KEY"),
    )
    return mtls_manager
def setup_mtls_service(service_name: str) -> ssl.SSLContext:
    """
    Build the server SSL context for ``service_name`` via the shared manager.

    Client certificates are required (the manager's default).

    Args:
        service_name: Name of the service

    Returns:
        Configured SSL context
    """
    return get_mtls_manager().create_ssl_context(service_name)
def generate_service_credentials(service_name: str) -> Dict[str, Any]:
    """
    Return mTLS credentials for ``service_name`` as text.

    An existing certificate (cached or on disk) is reused; a new one is
    issued only when none is found.

    Args:
        service_name: Name of the service

    Returns:
        Dict with the service name, PEM certificate, PEM private key, PEM CA
        certificate (all UTF-8 strings), certificate hash and ISO-8601 expiry.
    """
    manager = get_mtls_manager()
    existing = manager.get_service_certificate(service_name)
    cert_data = (
        existing if existing else manager.generate_service_certificate(service_name)
    )

    def as_text(pem: bytes) -> str:
        return pem.decode("utf-8")

    return {
        "service_name": service_name,
        "certificate": as_text(cert_data.certificate_pem),
        "private_key": as_text(cert_data.private_key_pem),
        "ca_certificate": as_text(manager.get_ca_certificate_pem()),
        "certificate_hash": cert_data.certificate_hash,
        "expires_at": cert_data.expires_at.isoformat(),
    }
# Services that initialize_mtls_for_all_services() provisions certificates for.
MTLS_SERVICES = [
    "backend-api",
    "frontend-app",
    "cache-service",
    "ml-service",
    "notification-service",
    "audit-service",
]
def initialize_mtls_for_all_services(rotation_window_days: int = 30):
    """Initialize mTLS certificates for all registered services.

    For each name in ``MTLS_SERVICES``, a missing certificate is generated and
    an existing one is rotated if it expires within ``rotation_window_days``.
    Failures are logged per service so one failing service does not stop the
    rest.

    Args:
        rotation_window_days: Rotate certificates expiring within this many
            days from now (default 30).
    """
    manager = get_mtls_manager()
    # Must be timezone-aware: expires_at comes from not_valid_after_utc, and
    # comparing aware with naive datetimes raises TypeError (which previously
    # made every rotation check fail into the except branch).
    rotation_deadline = datetime.now(timezone.utc) + timedelta(
        days=rotation_window_days
    )
    for service_name in MTLS_SERVICES:
        try:
            cert_data = manager.get_service_certificate(service_name)
            if cert_data is None:
                manager.generate_service_certificate(service_name)
                logger.info(f"Generated mTLS certificate for {service_name}")
            elif cert_data.expires_at < rotation_deadline:
                manager.rotate_service_certificate(service_name)
                logger.info(f"Rotated mTLS certificate for {service_name}")
        except Exception as e:
            # Deliberate per-service boundary: keep provisioning the others.
            logger.error(f"Failed to generate certificate for {service_name}: {e}")
# Public API of this module.
__all__ = [
    # Classes
    "MTLSCertificateManager",
    "MTLSAuthMiddleware",
    "CertificateConfig",
    "ServiceCertificate",
    # Module-level helpers
    "get_mtls_manager",
    "setup_mtls_service",
    "generate_service_credentials",
    "initialize_mtls_for_all_services",
]