Spaces:
Sleeping
Sleeping
Sync: Updated CDK code to include possibility of exporting DynamoDB logs to S3 csv output on automated basis
373207f | import json # You might still need json if loading task_definition.json | |
| import os | |
| from typing import Any, Dict, List | |
| from aws_cdk import ( | |
| CfnOutput, # <-- Import CfnOutput directly | |
| Duration, | |
| SecretValue, | |
| Stack, | |
| ) | |
| from aws_cdk import aws_cloudfront as cloudfront | |
| from aws_cdk import aws_cloudfront_origins as origins | |
| from aws_cdk import aws_codebuild as codebuild | |
| from aws_cdk import aws_cognito as cognito | |
| from aws_cdk import aws_dynamodb as dynamodb # Import the DynamoDB module | |
| from aws_cdk import aws_ec2 as ec2 | |
| from aws_cdk import aws_ecr as ecr | |
| from aws_cdk import aws_ecs as ecs | |
| from aws_cdk import aws_elasticloadbalancingv2 as elbv2 | |
| from aws_cdk import aws_iam as iam | |
| from aws_cdk import aws_kms as kms | |
| from aws_cdk import aws_logs as logs | |
| from aws_cdk import aws_s3 as s3 | |
| from aws_cdk import aws_secretsmanager as secretsmanager | |
| from aws_cdk import aws_wafv2 as wafv2 | |
| from cdk_cloudfront_headers import ( | |
| create_secure_cloudfront_response_headers_policy, | |
| resolve_cloudfront_csp_urls, | |
| ) | |
| from cdk_config import ( | |
| ACCESS_LOG_DYNAMODB_TABLE_NAME, | |
| ACM_SSL_CERTIFICATE_ARN, | |
| ALB_NAME, | |
| ALB_NAME_SECURITY_GROUP_NAME, | |
| ALB_TARGET_GROUP_NAME, | |
| APP_CONFIG_ENV_BASENAME, | |
| APP_CONFIG_ENV_FILE, | |
| AWS_ACCOUNT_ID, | |
| AWS_MANAGED_TASK_ROLES_LIST, | |
| AWS_REGION, | |
| CDK_FOLDER, | |
| CDK_PREFIX, | |
| CLOUDFRONT_DISTRIBUTION_NAME, | |
| CLOUDFRONT_DOMAIN, | |
| CLOUDFRONT_ENABLE_SECURE_RESPONSE_HEADERS, | |
| CLOUDFRONT_GEO_RESTRICTION, | |
| CLOUDFRONT_PREFIX_LIST_ID, | |
| CLUSTER_NAME, | |
| CODEBUILD_PI_PROJECT_NAME, | |
| CODEBUILD_PROJECT_NAME, | |
| CODEBUILD_ROLE_NAME, | |
| COGNITO_ACCESS_TOKEN_VALIDITY, | |
| COGNITO_ID_TOKEN_VALIDITY, | |
| COGNITO_REDIRECTION_URL, | |
| COGNITO_REFRESH_TOKEN_VALIDITY, | |
| COGNITO_USER_POOL_CLIENT_NAME, | |
| COGNITO_USER_POOL_CLIENT_SECRET_NAME, | |
| COGNITO_USER_POOL_DOMAIN_PREFIX, | |
| COGNITO_USER_POOL_LOGIN_URL, | |
| COGNITO_USER_POOL_NAME, | |
| CUSTOM_HEADER, | |
| CUSTOM_HEADER_VALUE, | |
| CUSTOM_KMS_KEY_NAME, | |
| DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS, | |
| DYNAMODB_USAGE_LOG_EXPORT_DATE_ATTRIBUTE, | |
| DYNAMODB_USAGE_LOG_EXPORT_LAMBDA_NAME, | |
| DYNAMODB_USAGE_LOG_EXPORT_OUTPUT_FILENAME, | |
| DYNAMODB_USAGE_LOG_EXPORT_S3_KEY, | |
| DYNAMODB_USAGE_LOG_EXPORT_SCHEDULE, | |
| ECR_CDK_REPO_NAME, | |
| ECR_PI_REPO_NAME, | |
| ECS_AVAILABILITY_ZONE_REBALANCING, | |
| ECS_EXECUTION_ROLE_MANAGED_POLICIES, | |
| ECS_EXECUTION_ROLE_POLICY_ARNS, | |
| ECS_EXECUTION_ROLE_POLICY_FILES, | |
| ECS_EXPRESS_HEALTH_CHECK_PATH, | |
| ECS_EXPRESS_INFRASTRUCTURE_ROLE_NAME, | |
| ECS_EXPRESS_SERVICE_NAME, | |
| ECS_EXPRESS_USE_PUBLIC_SUBNETS, | |
| ECS_LOG_GROUP_NAME, | |
| ECS_PI_EXPRESS_HEALTH_CHECK_PATH, | |
| ECS_PI_EXPRESS_SECURITY_GROUP_NAME, | |
| ECS_PI_EXPRESS_SERVICE_NAME, | |
| ECS_PI_LOG_GROUP_NAME, | |
| ECS_PI_SECURITY_GROUP_NAME, | |
| ECS_PI_SERVICE_NAME, | |
| ECS_PI_TASK_CPU_SIZE, | |
| ECS_PI_TASK_DEFINITION_NAME, | |
| ECS_PI_TASK_MEMORY_SIZE, | |
| ECS_READ_ONLY_FILE_SYSTEM, | |
| ECS_SECURITY_GROUP_NAME, | |
| ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_IDS_LIST, | |
| ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_NAMES_TO_LOOKUP, | |
| ECS_SERVICE_CONNECT_CLIENT_SG_NAME_SUFFIX, | |
| ECS_SERVICE_CONNECT_DISCOVERY_NAME, | |
| ECS_SERVICE_CONNECT_DNS_NAME, | |
| ECS_SERVICE_CONNECT_NAMESPACE, | |
| ECS_SERVICE_CONNECT_PORT_MAPPING_NAME, | |
| ECS_SERVICE_NAME, | |
| ECS_TASK_CPU_SIZE, | |
| ECS_TASK_EXECUTION_ROLE_NAME, | |
| ECS_TASK_MEMORY_SIZE, | |
| ECS_TASK_ROLE_NAME, | |
| ECS_USE_FARGATE_SPOT, | |
| ENABLE_DYNAMODB_USAGE_LOG_EXPORT, | |
| ENABLE_ECS_SERVICE_CONNECT, | |
| ENABLE_ECS_VPC_INTERFACE_ENDPOINTS, | |
| ENABLE_HEADLESS_DEPLOYMENT, | |
| ENABLE_HEADLESS_OUTPUT_NOTIFICATIONS, | |
| ENABLE_PI_AGENT_ECS_SERVICE, | |
| ENABLE_PI_AGENT_EXPRESS_SERVICE, | |
| ENABLE_S3_BATCH_ECS_TRIGGER, | |
| EXISTING_IGW_ID, | |
| EXISTING_LOAD_BALANCER_ARN, | |
| EXISTING_LOAD_BALANCER_DNS, | |
| FARGATE_TASK_DEFINITION_NAME, | |
| FEEDBACK_LOG_DYNAMODB_TABLE_NAME, | |
| GITHUB_REPO_BRANCH, | |
| GITHUB_REPO_NAME, | |
| GITHUB_REPO_USERNAME, | |
| GRADIO_SERVER_PORT, | |
| HEADLESS_OUTPUT_ALARM_NAME, | |
| HEADLESS_OUTPUT_IAM_USER_NAME, | |
| HEADLESS_OUTPUT_NOTIFY_EMAIL, | |
| HEADLESS_OUTPUT_S3_METRIC_FILTER_ID, | |
| HEADLESS_OUTPUT_S3_PREFIX, | |
| HEADLESS_OUTPUT_SNS_TOPIC_NAME, | |
| LOAD_BALANCER_WEB_ACL_NAME, | |
| NAT_GATEWAY_NAME, | |
| NEW_VPC_CIDR, | |
| NEW_VPC_DEFAULT_NAME, | |
| PI_AGENT_ENV_S3_KEY, | |
| PI_ALB_HOST_HEADER, | |
| PI_ALB_LISTENER_RULE_PRIORITY, | |
| PI_ALB_PATH_PREFIX_NORMALIZED, | |
| PI_ALB_ROUTING, | |
| PI_ALB_TARGET_GROUP_NAME, | |
| PI_GRADIO_PORT, | |
| POLICY_FILE_ARNS, | |
| POLICY_FILE_LOCATIONS, | |
| PRIVATE_SUBNET_AVAILABILITY_ZONES, | |
| PRIVATE_SUBNET_CIDR_BLOCKS, | |
| PRIVATE_SUBNETS_TO_USE, | |
| PUBLIC_SUBNET_AVAILABILITY_ZONES, | |
| PUBLIC_SUBNET_CIDR_BLOCKS, | |
| PUBLIC_SUBNETS_TO_USE, | |
| S3_BATCH_CONFIG_PREFIX, | |
| S3_BATCH_DEFAULT_PARAMS_KEY, | |
| S3_BATCH_ENV_PREFIX, | |
| S3_BATCH_ENV_SUFFIX, | |
| S3_BATCH_GENERAL_ENV_PREFIX, | |
| S3_BATCH_INPUT_PREFIX, | |
| S3_BATCH_LAMBDA_FUNCTION_NAME, | |
| S3_LOG_CONFIG_BUCKET_NAME, | |
| S3_OUTPUT_BUCKET_NAME, | |
| SAVE_LOGS_TO_DYNAMODB, | |
| SINGLE_NAT_GATEWAY_ID, | |
| SSL_CERTIFICATE_DOMAIN, | |
| TASK_DEFINITION_FILE_LOCATION, | |
| USAGE_LOG_DYNAMODB_TABLE_NAME, | |
| USE_CLOUDFRONT, | |
| USE_CUSTOM_KMS_KEY, | |
| USE_ECS_EXPRESS_MODE, | |
| VPC_NAME, | |
| WEB_ACL_NAME, | |
| ) | |
| from cdk_functions import ( # Only keep CDK-native functions | |
| add_alb_https_listener_with_cert, | |
| add_custom_policies, | |
| add_s3_enforce_ssl_policy, | |
| allow_express_load_balancer_to_ecs_security_group, | |
| attach_managed_policy_arns, | |
| attach_pi_agent_to_shared_alb, | |
| build_ecs_execution_role_kms_policy, | |
| build_ecs_task_role_inline_policy, | |
| build_express_gateway_primary_container, | |
| build_express_pi_primary_container, | |
| build_pi_express_container_environment, | |
| configure_public_github_codebuild_source, | |
| create_dynamo_usage_log_export_lambda, | |
| create_ecs_express_infrastructure_role, | |
| create_ecs_vpc_endpoints_for_private_subnets, | |
| create_express_gateway_service, | |
| create_headless_output_notifications, | |
| create_headless_s3_batch_seed, | |
| create_nat_gateway, | |
| create_pi_agent_ecs_resources, | |
| create_s3_batch_ecs_trigger_lambda, | |
| create_subnets, | |
| create_web_acl_with_common_rules, | |
| default_secrets_manager_kms_key_arn, | |
| ecr_empty_on_delete, | |
| ecs_availability_zone_rebalancing, | |
| express_ingress_first_load_balancer_security_group, | |
| express_ingress_load_balancer_arn, | |
| format_express_pi_public_url, | |
| format_pi_public_urls, | |
| load_app_config_env_for_express, | |
| managed_resource_removal_policy, | |
| pi_alb_root_path_for_container, | |
| pi_listener_rule_count, | |
| public_github_codebuild_source, | |
| resolve_ecs_s3_gateway_subnet_selection, | |
| resolve_ecs_vpc_endpoint_subnet_selection, | |
| resolve_policy_file_paths, | |
| resolve_service_connect_client_security_group_ids, | |
| resource_deletion_protection_flag, | |
| s3_auto_delete_objects_on_stack_destroy, | |
| wire_public_subnet_internet_access, | |
| ) | |
| from constructs import Construct | |
| def _get_env_list(env_var_name: str) -> List[str]: | |
| """Parses a comma-separated environment variable into a list of strings.""" | |
| value = env_var_name[1:-1].strip().replace('"', "").replace("'", "") | |
| if not value: | |
| return [] | |
| # Split by comma and filter out any empty strings that might result from extra commas | |
| return [s.strip() for s in value.split(",") if s.strip()] | |
# 1. Convert the subnet settings imported from cdk_config into lists of
#    strings (presumably they arrive as bracketed, comma-separated text —
#    confirm against cdk_config). Each call must pass the setting's VALUE:
#    handing _get_env_list the setting's name as a string literal makes it
#    return a fragment of that name instead of the configured entries.
if PUBLIC_SUBNETS_TO_USE:
    PUBLIC_SUBNETS_TO_USE = _get_env_list(PUBLIC_SUBNETS_TO_USE)
if PRIVATE_SUBNETS_TO_USE:
    PRIVATE_SUBNETS_TO_USE = _get_env_list(PRIVATE_SUBNETS_TO_USE)
if PUBLIC_SUBNET_CIDR_BLOCKS:
    PUBLIC_SUBNET_CIDR_BLOCKS = _get_env_list(PUBLIC_SUBNET_CIDR_BLOCKS)
if PUBLIC_SUBNET_AVAILABILITY_ZONES:
    PUBLIC_SUBNET_AVAILABILITY_ZONES = _get_env_list(
        PUBLIC_SUBNET_AVAILABILITY_ZONES
    )
if PRIVATE_SUBNET_CIDR_BLOCKS:
    PRIVATE_SUBNET_CIDR_BLOCKS = _get_env_list(PRIVATE_SUBNET_CIDR_BLOCKS)
if PRIVATE_SUBNET_AVAILABILITY_ZONES:
    PRIVATE_SUBNET_AVAILABILITY_ZONES = _get_env_list(
        PRIVATE_SUBNET_AVAILABILITY_ZONES
    )
| # AWS_MANAGED_TASK_ROLES_LIST and POLICY_* lists are parsed in cdk_config.py. | |
| class CdkStack(Stack): | |
| def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None: | |
| super().__init__(scope, construct_id, **kwargs) | |
| # --- Helper to get context values --- | |
def get_context_bool(key: str, default: bool = False) -> bool:
    """Read the CDK context entry ``key`` and interpret it as a boolean.

    Returns ``default`` when the entry is absent (``None``). A string is true
    only if, lower-cased, it is "true", "1" or "yes"; any other value is
    converted with ``bool()``.
    """
    raw = self.node.try_get_context(key)
    if raw is None:
        return default
    if isinstance(raw, str):
        return raw.lower() in ("true", "1", "yes")
    # Real booleans pass through unchanged here; other types use truthiness.
    return bool(raw)
def get_context_str(key: str, default: str = None) -> str:
    """Return the context entry ``key``, or ``default`` if it is missing or falsy."""
    found = self.node.try_get_context(key)
    return found if found else default
def get_context_dict(key: str, default: dict = None) -> dict:
    """Return the context entry ``key``, or ``default`` if it is missing or falsy.

    No check is made that the stored entry is actually a dict.
    """
    found = self.node.try_get_context(key)
    return found if found else default
def get_context_list_of_dicts(key: str) -> List[Dict[str, Any]]:
    """Return the context entry ``key`` when it is a list, else an empty list.

    A warning is printed when the entry is absent or is not a list. The
    list's items are returned as stored; they are not checked to be dicts.
    """
    entries = self.node.try_get_context(key)
    if isinstance(entries, list):
        return entries
    print(
        f"Warning: Context key '{key}' not found or not a list. Returning empty list."
    )
    return []
| resource_removal_policy = managed_resource_removal_policy() | |
| resource_delete_protection = resource_deletion_protection_flag() | |
| s3_auto_delete_objects = s3_auto_delete_objects_on_stack_destroy() | |
| self.template_options.description = "Deployment of the llm_topic_modeller Gradio app for LLM-based topic modelling. Git repo: https://github.com/seanpedrick-case/llm_topic_modeller." | |
| use_express_ingress = ( | |
| not ACM_SSL_CERTIFICATE_ARN and USE_ECS_EXPRESS_MODE == "True" | |
| ) | |
| enable_headless = ENABLE_HEADLESS_DEPLOYMENT == "True" | |
| express_public_subnets_only = ( | |
| use_express_ingress and ECS_EXPRESS_USE_PUBLIC_SUBNETS == "True" | |
| ) or ( | |
| enable_headless | |
| and not use_express_ingress | |
| and ECS_EXPRESS_USE_PUBLIC_SUBNETS == "True" | |
| ) | |
| deploy_web_ingress = not use_express_ingress and not enable_headless | |
| enable_service_connect = ( | |
| ENABLE_ECS_SERVICE_CONNECT == "True" and not use_express_ingress | |
| ) | |
| enable_pi_agent = ( | |
| ENABLE_PI_AGENT_ECS_SERVICE == "True" and not use_express_ingress | |
| ) | |
| enable_pi_express = ( | |
| ENABLE_PI_AGENT_EXPRESS_SERVICE == "True" and use_express_ingress | |
| ) | |
| enable_pi_build = enable_pi_agent or enable_pi_express | |
| if enable_headless: | |
| print( | |
| "ENABLE_HEADLESS_DEPLOYMENT=True: S3 batch trigger + one-shot Fargate " | |
| "tasks only (no ALB, CloudFront, or always-on ECS service)." | |
| ) | |
| elif use_express_ingress: | |
| print( | |
| "USE_ECS_EXPRESS_MODE=True: using ECS Express Mode for HTTPS ingress " | |
| "(no manual ALB/Fargate service)." | |
| ) | |
| if express_public_subnets_only: | |
| print( | |
| "ECS_EXPRESS_USE_PUBLIC_SUBNETS=True: Express tasks and VPC " | |
| "endpoints use public subnets only (no private subnet install)." | |
| ) | |
| elif enable_headless: | |
| print( | |
| "ENABLE_HEADLESS_DEPLOYMENT=True: batch Fargate tasks use " | |
| "legacy private subnets (or public if configured)." | |
| ) | |
| service_connect_client_sg_ids: List[str] = [] | |
| if enable_service_connect: | |
| if ( | |
| not ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_IDS_LIST | |
| and not ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_NAMES_TO_LOOKUP | |
| and not enable_pi_agent | |
| ): | |
| raise ValueError( | |
| "ENABLE_ECS_SERVICE_CONNECT=True requires at least one of " | |
| "ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_IDS, " | |
| "ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_NAMES, or " | |
| "ECS_SERVICE_CONNECT_CLIENT_CDK_PREFIXES (other apps' CDK_PREFIX " | |
| f"values, resolved to {{prefix}}{ECS_SERVICE_CONNECT_CLIENT_SG_NAME_SUFFIX} " | |
| "in this VPC), unless ENABLE_PI_AGENT_ECS_SERVICE=True (Pi SG is wired in-stack)." | |
| ) | |
| service_connect_client_sg_ids = ( | |
| resolve_service_connect_client_security_group_ids( | |
| ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_IDS_LIST, | |
| ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_NAMES_TO_LOOKUP, | |
| get_context_str, | |
| ) | |
| ) | |
| print( | |
| "ENABLE_ECS_SERVICE_CONNECT=True: advertising Fargate service on " | |
| f"Service Connect as {ECS_SERVICE_CONNECT_DISCOVERY_NAME}; " | |
| f"client SGs: {', '.join(service_connect_client_sg_ids)}" | |
| ) | |
| # --- VPC and Subnets (Assuming VPC is always lookup, Subnets are created/returned by create_subnets) --- | |
| new_vpc_created = False | |
| imported_vpc_cidr_block = None | |
| imported_vpc_cidr_blocks: List[str] = [] | |
| if VPC_NAME: | |
| vpc_id = get_context_str("vpc_id") | |
| if not vpc_id: | |
| raise ValueError( | |
| f"VPC '{VPC_NAME}' was not resolved during pre-check (missing " | |
| "'vpc_id' in context). Re-run from the cdk/ directory so " | |
| "precheck.context.json is generated." | |
| ) | |
| availability_zones = list( | |
| dict.fromkeys( | |
| (PUBLIC_SUBNET_AVAILABILITY_ZONES or []) | |
| + (PRIVATE_SUBNET_AVAILABILITY_ZONES or []) | |
| ) | |
| ) | |
| if not availability_zones: | |
| raise ValueError( | |
| "vpc_id is in context but no subnet availability zones are " | |
| "configured. Set PUBLIC_SUBNET_AVAILABILITY_ZONES and/or " | |
| "PRIVATE_SUBNET_AVAILABILITY_ZONES in cdk_config.env." | |
| ) | |
| vpc_cidr_block = get_context_str("vpc_cidr_block") | |
| imported_vpc_cidr_block = vpc_cidr_block | |
| imported_vpc_cidr_blocks = list( | |
| self.node.try_get_context("vpc_cidr_blocks") or [] | |
| ) | |
| if ( | |
| imported_vpc_cidr_block | |
| and imported_vpc_cidr_block not in imported_vpc_cidr_blocks | |
| ): | |
| imported_vpc_cidr_blocks.insert(0, imported_vpc_cidr_block) | |
| vpc_attrs = { | |
| "vpc_id": vpc_id, | |
| "availability_zones": availability_zones, | |
| } | |
| if vpc_cidr_block: | |
| vpc_attrs["vpc_cidr_block"] = vpc_cidr_block | |
| vpc = ec2.Vpc.from_vpc_attributes(self, "VPC", **vpc_attrs) | |
| cidr_log = ( | |
| ", ".join(imported_vpc_cidr_blocks) | |
| if imported_vpc_cidr_blocks | |
| else vpc_cidr_block | |
| ) | |
| print( | |
| f"Using VPC from pre-check context: {vpc_id}" | |
| + (f" (CIDR(s) {cidr_log})" if cidr_log else "") | |
| ) | |
| elif NEW_VPC_DEFAULT_NAME and not VPC_NAME: | |
| new_vpc_created = True | |
| print( | |
| f"NEW_VPC_DEFAULT_NAME ('{NEW_VPC_DEFAULT_NAME}') is set. Creating a new VPC." | |
| ) | |
| # Configuration for the new VPC | |
| # You can make these configurable via context as well, e.g., | |
| # new_vpc_cidr = self.node.try_get_context("new_vpc_cidr") or "10.0.0.0/24" | |
| # new_vpc_max_azs = self.node.try_get_context("new_vpc_max_azs") or 2 # Use 2 AZs by default for HA | |
| # new_vpc_nat_gateways = self.node.try_get_context("new_vpc_nat_gateways") or new_vpc_max_azs # One NAT GW per AZ for HA | |
| # or 1 for cost savings if acceptable | |
| if not NEW_VPC_CIDR: | |
| raise Exception( | |
| "App has been instructed to create a new VPC but not VPC CDR range provided to variable NEW_VPC_CIDR" | |
| ) | |
| print("Provided NEW_VPC_CIDR range:", NEW_VPC_CIDR) | |
| new_vpc_cidr = NEW_VPC_CIDR | |
| new_vpc_max_azs = 2 # Creates resources in 2 AZs. Adjust as needed. | |
| # For "a NAT gateway", you can set nat_gateways=1. | |
| # For resilience (NAT GW per AZ), set nat_gateways=new_vpc_max_azs. | |
| # The Vpc construct will create NAT Gateway(s) if subnet_type PRIVATE_WITH_EGRESS is used | |
| # and nat_gateways > 0. | |
| if express_public_subnets_only: | |
| new_vpc_nat_gateways = 0 | |
| new_vpc_subnet_configuration = [ | |
| ec2.SubnetConfiguration( | |
| name="Public", | |
| subnet_type=ec2.SubnetType.PUBLIC, | |
| # /27 (~27 usable IPs): Express managed ALB needs 8+ free IPs per | |
| # subnet alongside VPC interface endpoints and task ENIs. | |
| cidr_mask=27, | |
| ), | |
| ] | |
| else: | |
| new_vpc_nat_gateways = ( | |
| 1 # Creates a single NAT Gateway for cost-effectiveness. | |
| ) | |
| new_vpc_subnet_configuration = [ | |
| ec2.SubnetConfiguration( | |
| name="Public", # Name prefix for public subnets | |
| subnet_type=ec2.SubnetType.PUBLIC, | |
| cidr_mask=26, | |
| ), | |
| ec2.SubnetConfiguration( | |
| name="Private", # Name prefix for private subnets | |
| subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS, # Ensures these subnets have NAT Gateway access | |
| cidr_mask=28, | |
| ), | |
| # You could also add ec2.SubnetType.PRIVATE_ISOLATED if needed | |
| ] | |
| # If you need one NAT GW per AZ for higher availability, set nat_gateways to new_vpc_max_azs. | |
| vpc = ec2.Vpc( | |
| self, | |
| "MyNewLogicalVpc", # This is the CDK construct ID | |
| vpc_name=NEW_VPC_DEFAULT_NAME, | |
| ip_addresses=ec2.IpAddresses.cidr(new_vpc_cidr), | |
| max_azs=new_vpc_max_azs, | |
| nat_gateways=new_vpc_nat_gateways, # Number of NAT gateways to create | |
| subnet_configuration=new_vpc_subnet_configuration, | |
| # Internet Gateway is created and configured automatically for PUBLIC subnets. | |
| # Route tables for public subnets will point to the IGW. | |
| # Route tables for PRIVATE_WITH_EGRESS subnets will point to the NAT Gateway(s). | |
| ) | |
| print( | |
| f"Successfully created new VPC: {vpc.vpc_id} with name '{NEW_VPC_DEFAULT_NAME}'" | |
| ) | |
| # If nat_gateways > 0, vpc.nat_gateway_ips will contain EIPs if Vpc created them. | |
| # vpc.public_subnets, vpc.private_subnets, vpc.isolated_subnets are populated. | |
| else: | |
| raise Exception( | |
| "VPC_NAME for current VPC not found, and NEW_VPC_DEFAULT_NAME not found to create a new VPC" | |
| ) | |
| # --- Subnet Handling (Check Context and Create/Import) --- | |
| # Initialize lists to hold ISubnet objects (L2) and CfnSubnet/CfnRouteTable (L1) | |
| # We will store ISubnet for consistency, as CfnSubnet has a .subnet_id property | |
| self.public_subnets: List[ec2.ISubnet] = [] | |
| self.private_subnets: List[ec2.ISubnet] = [] | |
| # Store L1 CfnRouteTables explicitly if you need to reference them later | |
| self.private_route_tables_cfn: List[ec2.CfnRouteTable] = [] | |
| self.public_route_tables_cfn: List[ec2.CfnRouteTable] = ( | |
| [] | |
| ) # New: to store public RTs | |
| names_to_create_private = [] | |
| names_to_create_public = [] | |
| if not PUBLIC_SUBNETS_TO_USE and not PRIVATE_SUBNETS_TO_USE: | |
| if express_public_subnets_only: | |
| print( | |
| "Express public-subnet mode: auto-selecting public subnets only " | |
| "(private subnets are not installed)." | |
| ) | |
| selected_public_subnets = vpc.select_subnets( | |
| subnet_type=ec2.SubnetType.PUBLIC, one_per_az=True | |
| ) | |
| if len(selected_public_subnets.subnet_ids) < 2: | |
| raise Exception( | |
| "Express mode needs at least two public subnets in different " | |
| "availability zones." | |
| ) | |
| self.public_subnets = selected_public_subnets.subnets | |
| self.private_subnets = [] | |
| print( | |
| f"Selected {len(self.public_subnets)} public subnets for Express." | |
| ) | |
| else: | |
| print( | |
| "Warning: No public or private subnets specified in *_SUBNETS_TO_USE. Attempting to select from existing VPC subnets." | |
| ) | |
| print("vpc.public_subnets:", vpc.public_subnets) | |
| print("vpc.private_subnets:", vpc.private_subnets) | |
| if ( | |
| vpc.public_subnets | |
| ): # These are already one_per_az if max_azs was used and Vpc created them | |
| self.public_subnets.extend(vpc.public_subnets) | |
| else: | |
| self.node.add_warning("No public subnets found in the VPC.") | |
| # Get private subnets with egress specifically | |
| # selected_private_subnets_with_egress = vpc.select_subnets(subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS) | |
| print( | |
| f"Selected from VPC: {len(self.public_subnets)} public, {len(self.private_subnets)} private_with_egress subnets." | |
| ) | |
| if ( | |
| len(self.public_subnets) < 1 or len(self.private_subnets) < 1 | |
| ): # Simplified check for new VPC | |
| # If new_vpc_max_azs was 1, you'd have 1 of each. If 2, then 2 of each. | |
| # The original check ' < 2' might be too strict if new_vpc_max_azs=1 | |
| pass # For new VPC, allow single AZ setups if configured that way. The VPC construct ensures one per AZ up to max_azs. | |
| if not self.public_subnets and not self.private_subnets: | |
| print( | |
| "Error: No public or private subnets could be found in the VPC for automatic selection. " | |
| "You must either specify subnets in *_SUBNETS_TO_USE or ensure the VPC has discoverable subnets." | |
| ) | |
| raise RuntimeError( | |
| "No suitable subnets found for automatic selection." | |
| ) | |
| else: | |
| print( | |
| f"Automatically selected {len(self.public_subnets)} public and {len(self.private_subnets)} private subnets based on VPC properties." | |
| ) | |
| selected_public_subnets = vpc.select_subnets( | |
| subnet_type=ec2.SubnetType.PUBLIC, one_per_az=True | |
| ) | |
| private_subnets_egress = vpc.select_subnets( | |
| subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS, one_per_az=True | |
| ) | |
| if private_subnets_egress.subnets: | |
| self.private_subnets.extend(private_subnets_egress.subnets) | |
| else: | |
| self.node.add_warning( | |
| "No PRIVATE_WITH_EGRESS subnets found in the VPC." | |
| ) | |
| try: | |
| private_subnets_isolated = vpc.select_subnets( | |
| subnet_type=ec2.SubnetType.PRIVATE_ISOLATED, one_per_az=True | |
| ) | |
| except Exception as e: | |
| private_subnets_isolated = [] | |
| print("Could not find any isolated subnets due to:", e) | |
| ### | |
| combined_subnet_objects = [] | |
| if private_subnets_isolated: | |
| if private_subnets_egress.subnets: | |
| # Add the first PRIVATE_WITH_EGRESS subnet | |
| combined_subnet_objects.append( | |
| private_subnets_egress.subnets[0] | |
| ) | |
| elif not private_subnets_isolated: | |
| if private_subnets_egress.subnets: | |
| # Add the first PRIVATE_WITH_EGRESS subnet | |
| combined_subnet_objects.extend(private_subnets_egress.subnets) | |
| else: | |
| self.node.add_warning( | |
| "No PRIVATE_WITH_EGRESS subnets found to select the first one." | |
| ) | |
| # Add all PRIVATE_ISOLATED subnets *except* the first one (if they exist) | |
| try: | |
| if len(private_subnets_isolated.subnets) > 1: | |
| combined_subnet_objects.extend( | |
| private_subnets_isolated.subnets[1:] | |
| ) | |
| elif ( | |
| private_subnets_isolated.subnets | |
| ): # Only 1 isolated subnet, add a warning if [1:] was desired | |
| self.node.add_warning( | |
| "Only one PRIVATE_ISOLATED subnet found, private_subnets_isolated.subnets[1:] will be empty." | |
| ) | |
| else: | |
| self.node.add_warning("No PRIVATE_ISOLATED subnets found.") | |
| except Exception as e: | |
| print("Could not identify private isolated subnets due to:", e) | |
| # Create an ec2.SelectedSubnets object from the combined private subnet list. | |
| selected_private_subnets = vpc.select_subnets( | |
| subnets=combined_subnet_objects | |
| ) | |
| print("selected_public_subnets:", selected_public_subnets) | |
| print("selected_private_subnets:", selected_private_subnets) | |
| if ( | |
| len(selected_public_subnets.subnet_ids) < 2 | |
| or len(selected_private_subnets.subnet_ids) < 2 | |
| ): | |
| raise Exception( | |
| "Need at least two public or private subnets in different availability zones" | |
| ) | |
| if not selected_public_subnets and not selected_private_subnets: | |
| # If no subnets could be found even with automatic selection, raise an error. | |
| # This ensures the stack doesn't proceed if it absolutely needs subnets. | |
| print( | |
| "Error: No existing public or private subnets could be found in the VPC for automatic selection. " | |
| "You must either specify subnets in *_SUBNETS_TO_USE or ensure the VPC has discoverable subnets." | |
| ) | |
| raise RuntimeError( | |
| "No suitable subnets found for automatic selection." | |
| ) | |
| else: | |
| self.public_subnets = selected_public_subnets.subnets | |
| self.private_subnets = selected_private_subnets.subnets | |
| print( | |
| f"Automatically selected {len(self.public_subnets)} public and {len(self.private_subnets)} private subnets based on VPC discovery." | |
| ) | |
| print("self.public_subnets:", self.public_subnets) | |
| print("self.private_subnets:", self.private_subnets) | |
| # Since subnets are now assigned, we can exit this processing block. | |
| # The rest of the original code (which iterates *_SUBNETS_TO_USE) will be skipped. | |
| checked_public_subnets_ctx = get_context_dict("checked_public_subnets") | |
| checked_private_subnets_ctx = get_context_dict("checked_private_subnets") | |
| public_subnets_data_for_creation_ctx = get_context_list_of_dicts( | |
| "public_subnets_to_create" | |
| ) | |
| private_subnets_data_for_creation_ctx = get_context_list_of_dicts( | |
| "private_subnets_to_create" | |
| ) | |
| # --- 3. Process Public Subnets --- | |
| print("\n--- Processing Public Subnets ---") | |
| public_internet_gateway_attachment = None | |
| if not new_vpc_created: | |
| resolved_igw_id = ( | |
| get_context_str("internet_gateway_id") or EXISTING_IGW_ID or "" | |
| ).strip() | |
| if resolved_igw_id and ( | |
| PUBLIC_SUBNETS_TO_USE | |
| or public_subnets_data_for_creation_ctx | |
| or get_context_list_of_dicts("public_subnets_needing_igw_route") | |
| ): | |
| public_internet_gateway_attachment = wire_public_subnet_internet_access( | |
| self, | |
| "PublicSubnetInternet", | |
| vpc_id=vpc.vpc_id, | |
| internet_gateway_id=resolved_igw_id, | |
| needs_igw_vpc_attachment=get_context_bool( | |
| "internet_gateway_needs_vpc_attachment", False | |
| ), | |
| subnets_needing_route=get_context_list_of_dicts( | |
| "public_subnets_needing_igw_route" | |
| ), | |
| ) | |
| # Import existing public subnets | |
| if checked_public_subnets_ctx: | |
| for i, subnet_name in enumerate(PUBLIC_SUBNETS_TO_USE): | |
| subnet_info = checked_public_subnets_ctx.get(subnet_name) | |
| if subnet_info and subnet_info.get("exists"): | |
| subnet_id = subnet_info.get("id") | |
| if not subnet_id: | |
| raise RuntimeError( | |
| f"Context for existing public subnet '{subnet_name}' is missing 'id'." | |
| ) | |
| subnet_az = subnet_info.get("az") | |
| if ( | |
| not subnet_az | |
| and PUBLIC_SUBNET_AVAILABILITY_ZONES | |
| and i < len(PUBLIC_SUBNET_AVAILABILITY_ZONES) | |
| ): | |
| subnet_az = PUBLIC_SUBNET_AVAILABILITY_ZONES[i] | |
| if not subnet_az: | |
| raise RuntimeError( | |
| f"Context for existing public subnet '{subnet_name}' is missing 'az'." | |
| ) | |
| subnet_attrs = { | |
| "subnet_id": subnet_id, | |
| "availability_zone": subnet_az, | |
| } | |
| route_table_id = subnet_info.get("route_table_id") | |
| if route_table_id: | |
| subnet_attrs["route_table_id"] = route_table_id | |
| try: | |
| imported_subnet = ec2.Subnet.from_subnet_attributes( | |
| self, | |
| f"ImportedPublicSubnet{subnet_name.replace('-', '')}{i}", | |
| **subnet_attrs, | |
| ) | |
| self.public_subnets.append(imported_subnet) | |
| print( | |
| f"Imported existing public subnet: {subnet_name} (ID: {subnet_id})" | |
| ) | |
| except Exception as e: | |
| raise RuntimeError( | |
| f"Failed to import public subnet '{subnet_name}' with ID '{subnet_id}'. Error: {e}" | |
| ) | |
| # Create new public subnets based on public_subnets_data_for_creation_ctx | |
| if public_subnets_data_for_creation_ctx: | |
| names_to_create_public = [ | |
| s["name"] for s in public_subnets_data_for_creation_ctx | |
| ] | |
| cidrs_to_create_public = [ | |
| s["cidr"] for s in public_subnets_data_for_creation_ctx | |
| ] | |
| azs_to_create_public = [ | |
| s["az"] for s in public_subnets_data_for_creation_ctx | |
| ] | |
| if names_to_create_public: | |
| print( | |
| f"Attempting to create {len(names_to_create_public)} new public subnets: {names_to_create_public}" | |
| ) | |
| igw_for_new_subnets = ( | |
| get_context_str("internet_gateway_id") or EXISTING_IGW_ID | |
| ) | |
| newly_created_public_subnets, newly_created_public_rts_cfn = ( | |
| create_subnets( | |
| self, | |
| vpc, | |
| CDK_PREFIX, | |
| names_to_create_public, | |
| cidrs_to_create_public, | |
| azs_to_create_public, | |
| is_public=True, | |
| internet_gateway_id=igw_for_new_subnets, | |
| internet_gateway_attachment=public_internet_gateway_attachment, | |
| ) | |
| ) | |
| self.public_subnets.extend(newly_created_public_subnets) | |
| self.public_route_tables_cfn.extend(newly_created_public_rts_cfn) | |
| if ( | |
| not self.public_subnets | |
| and not names_to_create_public | |
| and not PUBLIC_SUBNETS_TO_USE | |
| ): | |
| raise Exception("No public subnets found or created, exiting.") | |
| # --- NAT Gateway Creation/Lookup --- | |
| self.single_nat_gateway_id = None | |
| if express_public_subnets_only: | |
| print( | |
| "Express public-subnet mode: skipping NAT Gateway install " | |
| "(not required for public Express tasks)." | |
| ) | |
| else: | |
| print("Creating NAT gateway/located existing") | |
| nat_gw_id_from_context = SINGLE_NAT_GATEWAY_ID or get_context_str( | |
| "id:NatGateway" | |
| ) | |
| if nat_gw_id_from_context: | |
| print( | |
| f"Using existing NAT Gateway ID from context: {nat_gw_id_from_context}" | |
| ) | |
| self.single_nat_gateway_id = nat_gw_id_from_context | |
| elif ( | |
| new_vpc_created | |
| and new_vpc_nat_gateways > 0 | |
| and hasattr(vpc, "nat_gateways") | |
| and vpc.nat_gateways | |
| ): | |
| self.single_nat_gateway_id = vpc.nat_gateways[0].gateway_id | |
| print( | |
| f"Using NAT Gateway {self.single_nat_gateway_id} created by the new VPC construct." | |
| ) | |
| if not self.single_nat_gateway_id: | |
| print("Creating a new NAT gateway") | |
| if hasattr(vpc, "nat_gateways") and vpc.nat_gateways: | |
| print("Existing NAT gateway found in vpc") | |
| pass | |
| # If not in context, create a new one, but only if we have a public subnet. | |
| elif self.public_subnets: | |
| print("NAT Gateway ID not found in context. Creating a new one.") | |
| # Place the NAT GW in the first available public subnet | |
| first_public_subnet = self.public_subnets[0] | |
| self.single_nat_gateway_id = create_nat_gateway( | |
| self, | |
| first_public_subnet, | |
| nat_gateway_name=NAT_GATEWAY_NAME, | |
| nat_gateway_id_context_key=SINGLE_NAT_GATEWAY_ID, | |
| ) | |
| else: | |
| print( | |
| "WARNING: No public subnets available and NAT gateway not found in existing VPC. Cannot create a NAT Gateway." | |
| ) | |
| # --- 4. Process Private Subnets --- | |
| if express_public_subnets_only: | |
| if PRIVATE_SUBNETS_TO_USE or private_subnets_data_for_creation_ctx: | |
| print( | |
| "Note: PRIVATE_* subnet settings are ignored in Express public-subnet mode." | |
| ) | |
| else: | |
| print("\n--- Processing Private Subnets ---") | |
| if checked_private_subnets_ctx: | |
| for i, subnet_name in enumerate(PRIVATE_SUBNETS_TO_USE): | |
| subnet_info = checked_private_subnets_ctx.get(subnet_name) | |
| if subnet_info and subnet_info.get("exists"): | |
| subnet_id = subnet_info.get("id") | |
| if not subnet_id: | |
| raise RuntimeError( | |
| f"Context for existing private subnet '{subnet_name}' is missing 'id'." | |
| ) | |
| subnet_az = subnet_info.get("az") | |
| if ( | |
| not subnet_az | |
| and PRIVATE_SUBNET_AVAILABILITY_ZONES | |
| and i < len(PRIVATE_SUBNET_AVAILABILITY_ZONES) | |
| ): | |
| subnet_az = PRIVATE_SUBNET_AVAILABILITY_ZONES[i] | |
| if not subnet_az: | |
| raise RuntimeError( | |
| f"Context for existing private subnet '{subnet_name}' is missing 'az'." | |
| ) | |
| subnet_attrs = { | |
| "subnet_id": subnet_id, | |
| "availability_zone": subnet_az, | |
| } | |
| route_table_id = subnet_info.get("route_table_id") | |
| if route_table_id: | |
| subnet_attrs["route_table_id"] = route_table_id | |
| try: | |
| imported_subnet = ec2.Subnet.from_subnet_attributes( | |
| self, | |
| f"ImportedPrivateSubnet{subnet_name.replace('-', '')}{i}", | |
| **subnet_attrs, | |
| ) | |
| self.private_subnets.append(imported_subnet) | |
| print( | |
| f"Imported existing private subnet: {subnet_name} (ID: {subnet_id})" | |
| ) | |
| except Exception as e: | |
| raise RuntimeError( | |
| f"Failed to import private subnet '{subnet_name}' with ID '{subnet_id}'. Error: {e}" | |
| ) | |
| # Create new private subnets | |
| if private_subnets_data_for_creation_ctx: | |
| names_to_create_private = [ | |
| s["name"] for s in private_subnets_data_for_creation_ctx | |
| ] | |
| cidrs_to_create_private = [ | |
| s["cidr"] for s in private_subnets_data_for_creation_ctx | |
| ] | |
| azs_to_create_private = [ | |
| s["az"] for s in private_subnets_data_for_creation_ctx | |
| ] | |
| if names_to_create_private: | |
| print( | |
| f"Attempting to create {len(names_to_create_private)} new private subnets: {names_to_create_private}" | |
| ) | |
| # --- CALL THE NEW CREATE_SUBNETS FUNCTION FOR PRIVATE --- | |
| # Ensure self.single_nat_gateway_id is available before this call | |
| if not self.single_nat_gateway_id: | |
| raise ValueError( | |
| "A single NAT Gateway ID is required for private subnets but was not resolved." | |
| ) | |
| newly_created_private_subnets_cfn, newly_created_private_rts_cfn = ( | |
| create_subnets( | |
| self, | |
| vpc, | |
| CDK_PREFIX, | |
| names_to_create_private, | |
| cidrs_to_create_private, | |
| azs_to_create_private, | |
| is_public=False, | |
| single_nat_gateway_id=self.single_nat_gateway_id, # Pass the single NAT Gateway ID | |
| ) | |
| ) | |
| self.private_subnets.extend(newly_created_private_subnets_cfn) | |
| self.private_route_tables_cfn.extend(newly_created_private_rts_cfn) | |
| print( | |
| f"Successfully defined {len(newly_created_private_subnets_cfn)} new private subnets and their route tables for creation." | |
| ) | |
| else: | |
| print( | |
| "No private subnets specified for creation in context ('private_subnets_to_create')." | |
| ) | |
| # if not self.private_subnets: | |
| # raise Exception("No private subnets found or created, exiting.") | |
| if ( | |
| not self.private_subnets | |
| and not names_to_create_private | |
| and not PRIVATE_SUBNETS_TO_USE | |
| ): | |
| # This condition might need adjustment for new VPCs. | |
| raise Exception("No private subnets found or created, exiting.") | |
| # --- 5. Sanity Check and Output --- | |
| # Output the single NAT Gateway ID for verification | |
| if self.single_nat_gateway_id: | |
| CfnOutput( | |
| self, | |
| "SingleNatGatewayId", | |
| value=self.single_nat_gateway_id, | |
| description="ID of the single NAT Gateway resolved or created.", | |
| ) | |
| elif express_public_subnets_only: | |
| print( | |
| "INFO: Express public-subnet mode — NAT Gateway not installed or required." | |
| ) | |
| elif ( | |
| NEW_VPC_DEFAULT_NAME | |
| and (self.node.try_get_context("new_vpc_nat_gateways") or 1) > 0 | |
| ): | |
| print( | |
| "INFO: A new VPC was created with NAT Gateway(s). Their routing is handled by the VPC construct. No single_nat_gateway_id was explicitly set for separate output." | |
| ) | |
| else: | |
| out_message = "WARNING: No single NAT Gateway was resolved or created explicitly by the script's logic after VPC setup." | |
| print(out_message) | |
| raise Exception(out_message) | |
| # --- Outputs for other stacks/regions --- | |
| # These are crucial for cross-stack, cross-region referencing | |
| self.params = dict() | |
| self.params["vpc_id"] = vpc.vpc_id | |
| self.params["private_subnets"] = self.private_subnets | |
| self.params["private_route_tables"] = self.private_route_tables_cfn | |
| self.params["public_subnets"] = self.public_subnets | |
| self.params["public_route_tables"] = self.public_route_tables_cfn | |
| private_subnet_selection = ec2.SubnetSelection(subnets=self.private_subnets) | |
| public_subnet_selection = ec2.SubnetSelection(subnets=self.public_subnets) | |
| for sub in private_subnet_selection.subnets: | |
| print( | |
| "private subnet:", | |
| sub.subnet_id, | |
| "is in availability zone:", | |
| sub.availability_zone, | |
| ) | |
| for sub in public_subnet_selection.subnets: | |
| print( | |
| "public subnet:", | |
| sub.subnet_id, | |
| "is in availability zone:", | |
| sub.availability_zone, | |
| ) | |
| print("Private subnet route tables:", self.private_route_tables_cfn) | |
| CfnOutput( | |
| self, | |
| "VpcIdOutput", | |
| value=vpc.vpc_id, | |
| description="The ID of the VPC used by this stack.", | |
| ) | |
| # --- IAM Roles --- | |
| cognito_secret_name = COGNITO_USER_POOL_CLIENT_SECRET_NAME | |
| secret_kms_key_arn_from_context = get_context_str( | |
| f"kms_key_arn:{cognito_secret_name}" | |
| ) | |
| if USE_CUSTOM_KMS_KEY == "1": | |
| kms_key = kms.Key( | |
| self, | |
| "SummarisationSharedKmsKey", | |
| alias=CUSTOM_KMS_KEY_NAME, | |
| removal_policy=resource_removal_policy, | |
| ) | |
| shared_kms_key_arn = kms_key.key_arn | |
| secret_kms_key_arn = secret_kms_key_arn_from_context or kms_key.key_arn | |
| else: | |
| kms_key = None | |
| shared_kms_key_arn = None | |
| secret_kms_key_arn = ( | |
| secret_kms_key_arn_from_context | |
| or default_secrets_manager_kms_key_arn(AWS_REGION, AWS_ACCOUNT_ID) | |
| ) | |
| task_role_inline_policy = json.dumps( | |
| build_ecs_task_role_inline_policy( | |
| output_bucket_name=S3_OUTPUT_BUCKET_NAME, | |
| log_config_bucket_name=S3_LOG_CONFIG_BUCKET_NAME, | |
| shared_kms_key_arn=shared_kms_key_arn, | |
| ), | |
| indent=4, | |
| ) | |
| if enable_headless: | |
| execution_role_kms_policy = json.dumps( | |
| { | |
| "Version": "2012-10-17", | |
| "Statement": [ | |
| { | |
| "Sid": "STSCallerIdentity", | |
| "Effect": "Allow", | |
| "Action": ["sts:GetCallerIdentity"], | |
| "Resource": "*", | |
| } | |
| ], | |
| }, | |
| indent=4, | |
| ) | |
| else: | |
| execution_role_kms_policy = json.dumps( | |
| build_ecs_execution_role_kms_policy( | |
| secret_kms_key_arn=secret_kms_key_arn, | |
| ), | |
| indent=4, | |
| ) | |
| try: | |
| codebuild_role_name = CODEBUILD_ROLE_NAME | |
| if get_context_bool(f"exists:{codebuild_role_name}"): | |
| # If exists, lookup/import the role using ARN from context | |
| role_arn = get_context_str(f"arn:{codebuild_role_name}") | |
| if not role_arn: | |
| raise ValueError( | |
| f"Context value 'arn:{codebuild_role_name}' is required if role exists." | |
| ) | |
| codebuild_role = iam.Role.from_role_arn( | |
| self, "CodeBuildRole", role_arn=role_arn | |
| ) | |
| print("Using existing CodeBuild role") | |
| else: | |
| # If not exists, create the role | |
| codebuild_role = iam.Role( | |
| self, | |
| "CodeBuildRole", # Logical ID | |
| role_name=codebuild_role_name, # Explicit resource name | |
| assumed_by=iam.ServicePrincipal("codebuild.amazonaws.com"), | |
| ) | |
| codebuild_role.add_managed_policy( | |
| iam.ManagedPolicy.from_aws_managed_policy_name( | |
| "EC2InstanceProfileForImageBuilderECRContainerBuilds" | |
| ) | |
| ) | |
| print("Successfully created new CodeBuild role") | |
| task_role_name = ECS_TASK_ROLE_NAME | |
| if get_context_bool(f"exists:{task_role_name}"): | |
| role_arn = get_context_str(f"arn:{task_role_name}") | |
| if not role_arn: | |
| raise ValueError( | |
| f"Context value 'arn:{task_role_name}' is required if role exists." | |
| ) | |
| task_role = iam.Role.from_role_arn(self, "TaskRole", role_arn=role_arn) | |
| print("Using existing ECS task role") | |
| else: | |
| task_role = iam.Role( | |
| self, | |
| "TaskRole", # Logical ID | |
| role_name=task_role_name, # Explicit resource name | |
| assumed_by=iam.ServicePrincipal("ecs-tasks.amazonaws.com"), | |
| ) | |
| for role in AWS_MANAGED_TASK_ROLES_LIST: | |
| print(f"Adding {role} to policy") | |
| task_role.add_managed_policy( | |
| iam.ManagedPolicy.from_aws_managed_policy_name(f"{role}") | |
| ) | |
| attach_managed_policy_arns(task_role, POLICY_FILE_ARNS) | |
| print("Successfully created new ECS task role") | |
| task_role = add_custom_policies( | |
| self, | |
| task_role, | |
| policy_file_locations=resolve_policy_file_paths( | |
| POLICY_FILE_LOCATIONS, cdk_folder=CDK_FOLDER | |
| ), | |
| custom_policy_text=task_role_inline_policy, | |
| ) | |
| execution_role_name = ECS_TASK_EXECUTION_ROLE_NAME | |
| if get_context_bool(f"exists:{execution_role_name}"): | |
| role_arn = get_context_str(f"arn:{execution_role_name}") | |
| if not role_arn: | |
| raise ValueError( | |
| f"Context value 'arn:{execution_role_name}' is required if role exists." | |
| ) | |
| execution_role = iam.Role.from_role_arn( | |
| self, "ExecutionRole", role_arn=role_arn | |
| ) | |
| print("Using existing ECS execution role") | |
| else: | |
| execution_role = iam.Role( | |
| self, | |
| "ExecutionRole", # Logical ID | |
| role_name=execution_role_name, # Explicit resource name | |
| assumed_by=iam.ServicePrincipal("ecs-tasks.amazonaws.com"), | |
| ) | |
| for role in ECS_EXECUTION_ROLE_MANAGED_POLICIES: | |
| print(f"Adding {role} to execution role") | |
| execution_role.add_managed_policy( | |
| iam.ManagedPolicy.from_aws_managed_policy_name(f"{role}") | |
| ) | |
| attach_managed_policy_arns( | |
| execution_role, ECS_EXECUTION_ROLE_POLICY_ARNS | |
| ) | |
| print("Successfully created new ECS execution role") | |
| execution_role = add_custom_policies( | |
| self, | |
| execution_role, | |
| policy_file_locations=resolve_policy_file_paths( | |
| ECS_EXECUTION_ROLE_POLICY_FILES, cdk_folder=CDK_FOLDER | |
| ), | |
| custom_policy_text=execution_role_kms_policy, | |
| ) | |
| except Exception as e: | |
| raise Exception("Failed at IAM role step due to:", e) | |
# --- S3 Buckets ---
# Resolve (import or create) the log/config bucket and the output bucket, then
# attach the ECS task role and enforce-SSL statements to their bucket policies.
try:
    # Newly created buckets are KMS-encrypted only when this stack created its
    # own shared key; otherwise no encryption arguments are passed at all.
    if USE_CUSTOM_KMS_KEY == "1" and isinstance(kms_key, kms.Key):
        bucket_encryption_kwargs = {
            "encryption": s3.BucketEncryption.KMS,
            "encryption_key": kms_key,
        }
    else:
        bucket_encryption_kwargs = {}

    # -- Log / config bucket --
    log_bucket_name = S3_LOG_CONFIG_BUCKET_NAME
    if get_context_bool(f"globally_taken:{log_bucket_name}"):
        raise ValueError(
            f"S3 bucket name {log_bucket_name!r} is taken globally by another "
            "AWS account. Set S3_LOG_CONFIG_BUCKET_NAME in cdk/config/cdk_config.env "
            "to a unique name (re-run cdk_install.py or check_resources.py)."
        )
    if get_context_bool(f"exists:{log_bucket_name}"):
        bucket = s3.Bucket.from_bucket_name(
            self, "LogConfigBucket", bucket_name=log_bucket_name
        )
        print("Using existing S3 bucket", log_bucket_name)
    else:
        log_bucket_lifecycle = [
            s3.LifecycleRule(
                abort_incomplete_multipart_upload_after=Duration.days(7)
            )
        ]
        bucket = s3.Bucket(
            self,
            "LogConfigBucket",
            bucket_name=log_bucket_name,
            lifecycle_rules=log_bucket_lifecycle,
            versioned=False,
            removal_policy=resource_removal_policy,
            auto_delete_objects=s3_auto_delete_objects,
            **bucket_encryption_kwargs,
        )
        print("Created S3 bucket", log_bucket_name)

    # Grant the task role object- and list-level access via the bucket policy.
    # NOTE: per the CDK docs, add_to_resource_policy may be a no-op for a bucket
    # imported with from_bucket_name (the result's statement_added is False),
    # so these statements are only guaranteed for buckets created by this stack.
    bucket.add_to_resource_policy(
        iam.PolicyStatement(
            effect=iam.Effect.ALLOW,
            principals=[task_role],  # Pass the role object directly
            actions=["s3:GetObject", "s3:PutObject"],
            resources=[f"{bucket.bucket_arn}/*"],
        )
    )
    bucket.add_to_resource_policy(
        iam.PolicyStatement(
            effect=iam.Effect.ALLOW,
            principals=[task_role],
            actions=["s3:ListBucket"],
            resources=[bucket.bucket_arn],
        )
    )

    # -- Output bucket --
    output_bucket_name = S3_OUTPUT_BUCKET_NAME
    if get_context_bool(f"globally_taken:{output_bucket_name}"):
        raise ValueError(
            f"S3 bucket name {output_bucket_name!r} is taken globally by another "
            "AWS account. Set S3_OUTPUT_BUCKET_NAME in cdk/config/cdk_config.env "
            "to a unique name (re-run cdk_install.py or check_resources.py)."
        )
    if get_context_bool(f"exists:{output_bucket_name}"):
        output_bucket = s3.Bucket.from_bucket_name(
            self, "OutputBucket", bucket_name=output_bucket_name
        )
        print("Using existing Output bucket", output_bucket_name)
    else:
        # Objects in the output bucket expire after the configured number of days.
        output_bucket_lifecycle = [
            s3.LifecycleRule(
                expiration=Duration.days(int(DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS))
            )
        ]
        output_bucket = s3.Bucket(
            self,
            "OutputBucket",
            bucket_name=output_bucket_name,
            lifecycle_rules=output_bucket_lifecycle,
            versioned=False,
            removal_policy=resource_removal_policy,
            auto_delete_objects=s3_auto_delete_objects,
            **bucket_encryption_kwargs,
        )
        print("Created Output bucket:", output_bucket_name)

    add_s3_enforce_ssl_policy(bucket)
    add_s3_enforce_ssl_policy(output_bucket)

    # Add policies to output bucket (same imported-bucket caveat as above).
    output_bucket.add_to_resource_policy(
        iam.PolicyStatement(
            effect=iam.Effect.ALLOW,
            principals=[task_role],
            actions=["s3:GetObject", "s3:PutObject"],
            resources=[f"{output_bucket.bucket_arn}/*"],
        )
    )
    output_bucket.add_to_resource_policy(
        iam.PolicyStatement(
            effect=iam.Effect.ALLOW,
            principals=[task_role],
            actions=["s3:ListBucket"],
            resources=[output_bucket.bucket_arn],
        )
    )
    # Identity-based S3 access is scoped via build_ecs_task_role_inline_policy on
    # task_role (output + log/config buckets). The bucket policies above are kept
    # for org policies that expect explicit bucket principals.
except Exception as e:
    # Chain the original error so the root cause stays in the traceback.
    raise Exception("Could not handle S3 buckets due to:", e) from e
# --- Elastic Container Registry ---
# Import the application image repository when the precheck context marks it as
# existing; otherwise create it. Either way, record its URI for later use.
try:
    full_ecr_repo_name = ECR_CDK_REPO_NAME
    if get_context_bool(f"exists:{full_ecr_repo_name}"):
        ecr_repo = ecr.Repository.from_repository_name(
            self, "ECRRepo", repository_name=full_ecr_repo_name
        )
        print("Using existing ECR repository")
    else:
        ecr_repo = ecr.Repository(
            self,
            "ECRRepo",
            repository_name=full_ecr_repo_name,  # Explicitly set repository_name
            removal_policy=resource_removal_policy,
            empty_on_delete=ecr_empty_on_delete(),
        )
        print("Created ECR repository", full_ecr_repo_name)
    ecr_image_loc = ecr_repo.repository_uri
except Exception as e:
    # Chain the original error so the root cause stays in the traceback.
    raise Exception("Could not handle ECR repo due to:", e) from e
# Default the Pi agent image location to the main repository URI; it is
# reassigned later when a dedicated Pi repository is resolved.
pi_ecr_image_loc = ecr_image_loc
| # --- CODEBUILD --- | |
| try: | |
| codebuild_project_name = CODEBUILD_PROJECT_NAME | |
| if get_context_bool(f"exists:{codebuild_project_name}"): | |
| # Lookup CodeBuild project by ARN from context | |
| project_arn = get_context_str(f"arn:{codebuild_project_name}") | |
| if not project_arn: | |
| raise ValueError( | |
| f"Context value 'arn:{codebuild_project_name}' is required if project exists." | |
| ) | |
| codebuild.Project.from_project_arn( | |
| self, "CodeBuildProject", project_arn=project_arn | |
| ) | |
| print( | |
| "Using existing CodeBuild project " | |
| "(public GitHub source is applied in post_cdk_build_quickstart)." | |
| ) | |
| else: | |
| main_codebuild_project = codebuild.Project( | |
| self, | |
| "CodeBuildProject", # Logical ID | |
| project_name=codebuild_project_name, # Explicit resource name | |
| role=codebuild_role, | |
| source=public_github_codebuild_source( | |
| owner=GITHUB_REPO_USERNAME, | |
| repo=GITHUB_REPO_NAME, | |
| branch_or_ref=GITHUB_REPO_BRANCH, | |
| ), | |
| environment=codebuild.BuildEnvironment( | |
| build_image=codebuild.LinuxBuildImage.STANDARD_7_0, | |
| privileged=True, | |
| environment_variables={ | |
| "ECR_REPO_NAME": codebuild.BuildEnvironmentVariable( | |
| value=full_ecr_repo_name | |
| ), | |
| "AWS_DEFAULT_REGION": codebuild.BuildEnvironmentVariable( | |
| value=AWS_REGION | |
| ), | |
| "AWS_ACCOUNT_ID": codebuild.BuildEnvironmentVariable( | |
| value=AWS_ACCOUNT_ID | |
| ), | |
| "APP_MODE": codebuild.BuildEnvironmentVariable( | |
| value="gradio" | |
| ), | |
| }, | |
| ), | |
| build_spec=codebuild.BuildSpec.from_object( | |
| { | |
| "version": "0.2", | |
| "phases": { | |
| "pre_build": { | |
| "commands": [ | |
| "echo Logging in to Amazon ECR", | |
| "aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com", | |
| ] | |
| }, | |
| "build": { | |
| "commands": [ | |
| "echo Building the Docker image", | |
| "docker build --build-arg APP_MODE=$APP_MODE --target $APP_MODE -t $ECR_REPO_NAME:latest .", | |
| "docker tag $ECR_REPO_NAME:latest $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPO_NAME:latest", | |
| ] | |
| }, | |
| "post_build": { | |
| "commands": [ | |
| "echo Pushing the Docker image", | |
| "docker push $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPO_NAME:latest", | |
| ] | |
| }, | |
| }, | |
| } | |
| ), | |
| ) | |
| configure_public_github_codebuild_source( | |
| main_codebuild_project, | |
| GITHUB_REPO_USERNAME, | |
| GITHUB_REPO_NAME, | |
| GITHUB_REPO_BRANCH, | |
| ) | |
| print("Successfully created CodeBuild project", codebuild_project_name) | |
| # Imported projects have role=undefined in CDK; use the actual service | |
| # role from context (existing project) or the managed codebuild_role (new). | |
| if get_context_bool(f"exists:{codebuild_project_name}"): | |
| project_service_role_arn = get_context_str( | |
| f"service_role_arn:{codebuild_project_name}" | |
| ) | |
| if project_service_role_arn: | |
| ecr_grantee = iam.Role.from_role_arn( | |
| self, | |
| "CodeBuildProjectServiceRole", | |
| role_arn=project_service_role_arn, | |
| mutable=True, | |
| ) | |
| else: | |
| ecr_grantee = codebuild_role | |
| else: | |
| ecr_grantee = codebuild_role | |
| ecr_repo.grant_pull_push(ecr_grantee) | |
| if enable_pi_build: | |
| pi_codebuild_name = CODEBUILD_PI_PROJECT_NAME | |
| if get_context_bool(f"exists:{pi_codebuild_name}"): | |
| project_arn = get_context_str(f"arn:{pi_codebuild_name}") | |
| if project_arn: | |
| codebuild.Project.from_project_arn( | |
| self, "CodeBuildPiProject", project_arn=project_arn | |
| ) | |
| print("Using existing Pi agent CodeBuild project") | |
| else: | |
| pi_codebuild_project = codebuild.Project( | |
| self, | |
| "CodeBuildPiProject", | |
| project_name=pi_codebuild_name, | |
| role=codebuild_role, | |
| source=public_github_codebuild_source( | |
| owner=GITHUB_REPO_USERNAME, | |
| repo=GITHUB_REPO_NAME, | |
| branch_or_ref=GITHUB_REPO_BRANCH, | |
| ), | |
| environment=codebuild.BuildEnvironment( | |
| build_image=codebuild.LinuxBuildImage.STANDARD_7_0, | |
| privileged=True, | |
| environment_variables={ | |
| "ECR_REPO_NAME": codebuild.BuildEnvironmentVariable( | |
| value=ECR_PI_REPO_NAME | |
| ), | |
| "AWS_DEFAULT_REGION": codebuild.BuildEnvironmentVariable( | |
| value=AWS_REGION | |
| ), | |
| "AWS_ACCOUNT_ID": codebuild.BuildEnvironmentVariable( | |
| value=AWS_ACCOUNT_ID | |
| ), | |
| }, | |
| ), | |
| build_spec=codebuild.BuildSpec.from_object( | |
| { | |
| "version": "0.2", | |
| "phases": { | |
| "pre_build": { | |
| "commands": [ | |
| "echo Logging in to Amazon ECR", | |
| "aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com", | |
| "test -f config/pi_agent.env.example", | |
| "test -f agent-redact/pi-agent/Dockerfile", | |
| ] | |
| }, | |
| "build": { | |
| "commands": [ | |
| "docker build -f agent-redact/pi-agent/Dockerfile -t $ECR_REPO_NAME:latest .", | |
| "docker tag $ECR_REPO_NAME:latest $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPO_NAME:latest", | |
| ] | |
| }, | |
| "post_build": { | |
| "commands": [ | |
| "docker push $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPO_NAME:latest", | |
| ] | |
| }, | |
| }, | |
| } | |
| ), | |
| ) | |
| configure_public_github_codebuild_source( | |
| pi_codebuild_project, | |
| GITHUB_REPO_USERNAME, | |
| GITHUB_REPO_NAME, | |
| GITHUB_REPO_BRANCH, | |
| ) | |
| print("Created Pi agent CodeBuild project", pi_codebuild_name) | |
| pi_ecr_repo_name = ECR_PI_REPO_NAME | |
| if get_context_bool(f"exists:{pi_ecr_repo_name}"): | |
| pi_ecr_repo = ecr.Repository.from_repository_name( | |
| self, "ECRPiRepo", repository_name=pi_ecr_repo_name | |
| ) | |
| else: | |
| pi_ecr_repo = ecr.Repository( | |
| self, | |
| "ECRPiRepo", | |
| repository_name=pi_ecr_repo_name, | |
| removal_policy=resource_removal_policy, | |
| empty_on_delete=ecr_empty_on_delete(), | |
| ) | |
| pi_ecr_image_loc = pi_ecr_repo.repository_uri | |
| pi_ecr_repo.grant_pull_push(ecr_grantee) | |
| CfnOutput(self, "ECRPiRepoUri", value=pi_ecr_repo.repository_uri) | |
| except Exception as e: | |
| raise Exception("Could not handle Codebuild project due to:", e) | |
| pi_ecs_service = None | |
| pi_ecs_security_group = None | |
# --- Security Groups ---
# Create the ECS task security group (and the ALB security group when the ALB
# ingress path is deployed) and open the Gradio port to the permitted peers.
try:
    ecs_security_group_name = ECS_SECURITY_GROUP_NAME
    # A creation failure must propagate: every statement below needs this
    # security group, so swallowing the error would only resurface later as an
    # unrelated NameError.
    ecs_security_group = ec2.SecurityGroup(
        self,
        "ECSSecurityGroup",  # Logical ID
        security_group_name=ecs_security_group_name,  # Explicit resource name
        vpc=vpc,
    )
    print(f"Created Security Group: {ecs_security_group_name}")

    ec2_port_gradio_server_port = ec2.Port.tcp(int(GRADIO_SERVER_PORT))

    if deploy_web_ingress:
        alb_security_group_name = ALB_NAME_SECURITY_GROUP_NAME
        alb_security_group = ec2.SecurityGroup(
            self,
            "ALBSecurityGroup",  # Logical ID
            security_group_name=alb_security_group_name,
            vpc=vpc,
        )
        print(f"Created Security Group: {alb_security_group_name}")

        # ALB -> ECS on the app port; CloudFront prefix list -> ALB.
        ecs_security_group.add_ingress_rule(
            peer=alb_security_group,
            connection=ec2_port_gradio_server_port,
            description="ALB traffic",
        )
        alb_security_group.add_ingress_rule(
            peer=ec2.Peer.prefix_list(CLOUDFRONT_PREFIX_LIST_ID),
            connection=ec2.Port.all_traffic(),
            description="CloudFront traffic",
        )
    else:
        # No ALB in this mode: CloudFront (when enabled) reaches ECS directly.
        alb_security_group = None
        if USE_CLOUDFRONT == "True":
            ecs_security_group.add_ingress_rule(
                peer=ec2.Peer.prefix_list(CLOUDFRONT_PREFIX_LIST_ID),
                connection=ec2_port_gradio_server_port,
                description="CloudFront to ECS (Express Mode)",
            )

    if enable_service_connect:
        # Allow each configured client security group to reach the app port.
        for index, client_sg_id in enumerate(service_connect_client_sg_ids):
            client_sg = ec2.SecurityGroup.from_security_group_id(
                self,
                f"ServiceConnectClientSg{index}",
                security_group_id=client_sg_id,
            )
            ecs_security_group.add_ingress_rule(
                peer=client_sg,
                connection=ec2_port_gradio_server_port,
                description=(
                    f"Service Connect client {client_sg_id} to app port"
                ),
            )
        print(
            "Service Connect ingress allowed from security groups: "
            + ", ".join(service_connect_client_sg_ids)
        )
except Exception as e:
    # Chain the original error so the root cause stays in the traceback.
    raise Exception("Could not handle security groups due to:", e) from e
| endpoint_subnet_selection = resolve_ecs_vpc_endpoint_subnet_selection( | |
| use_express_ingress=use_express_ingress, | |
| express_use_public_subnets=ECS_EXPRESS_USE_PUBLIC_SUBNETS == "True", | |
| public_subnets=self.public_subnets, | |
| private_subnets=self.private_subnets, | |
| ) | |
| s3_gateway_subnet_selection = resolve_ecs_s3_gateway_subnet_selection( | |
| public_subnets=self.public_subnets, | |
| private_subnets=self.private_subnets, | |
| ) | |
| if ENABLE_ECS_VPC_INTERFACE_ENDPOINTS == "True" and ( | |
| endpoint_subnet_selection or s3_gateway_subnet_selection | |
| ): | |
| if ( | |
| VPC_NAME | |
| and not imported_vpc_cidr_block | |
| and not imported_vpc_cidr_blocks | |
| ): | |
| raise ValueError( | |
| "vpc_cidr_block / vpc_cidr_blocks missing from precheck.context.json. " | |
| "Re-run check_resources.py from the cdk/ directory so the VPC " | |
| "CIDR(s) are stored for VPC endpoints and security groups." | |
| ) | |
| existing_endpoint_services = frozenset( | |
| self.node.try_get_context("existing_vpc_endpoint_service_names") or [] | |
| ) | |
| if VPC_NAME and not existing_endpoint_services: | |
| print( | |
| "Note: existing_vpc_endpoint_service_names not in precheck context; " | |
| "re-run check_resources.py to skip duplicate endpoints in shared VPCs." | |
| ) | |
| try: | |
| endpoint_tier = ( | |
| "public" | |
| if use_express_ingress and ECS_EXPRESS_USE_PUBLIC_SUBNETS == "True" | |
| else "private" | |
| ) | |
| create_ecs_vpc_endpoints_for_private_subnets( | |
| self, | |
| vpc=vpc, | |
| subnets=endpoint_subnet_selection, | |
| s3_gateway_subnets=s3_gateway_subnet_selection, | |
| logical_id_prefix="SummarisationEcs", | |
| include_secrets_and_kms=True, | |
| vpc_cidr_block=imported_vpc_cidr_block, | |
| vpc_cidr_blocks=imported_vpc_cidr_blocks or None, | |
| skip_service_names=existing_endpoint_services, | |
| aws_region=AWS_REGION, | |
| ) | |
| s3_subnet_count = len( | |
| (s3_gateway_subnet_selection.subnets or []) | |
| if s3_gateway_subnet_selection | |
| else [] | |
| ) | |
| print( | |
| "Defined ECS VPC interface endpoints (ECR, Logs, Secrets Manager, " | |
| f"KMS) for {endpoint_tier} subnets where not already present; " | |
| f"S3 gateway for {s3_subnet_count} stack subnet(s) (public + " | |
| "private) where not already present." | |
| ) | |
| except Exception as e: | |
| raise Exception( | |
| "Could not create ECS VPC interface endpoints for ECS task subnets. " | |
| "If this VPC already has them, re-run check_resources.py (auto-skip) " | |
| "or set ENABLE_ECS_VPC_INTERFACE_ENDPOINTS=False in cdk_config.env " | |
| "and ensure task subnets reach ECR (NAT, IGW, or existing endpoints).", | |
| e, | |
| ) from e | |
# --- DynamoDB tables for logs (optional) ---
# When enabled, define the access, feedback and usage log tables with an
# identical schema (string partition key "id", on-demand billing). The usage
# table is kept in usage_log_table; it stays None when logging is disabled.
usage_log_table = None
if SAVE_LOGS_TO_DYNAMODB == "True":
    try:
        print("Creating DynamoDB tables for logs")
        # (construct ID, physical table name) in the original definition order.
        log_table_specs = (
            ("SummarisationAccessDataTable", ACCESS_LOG_DYNAMODB_TABLE_NAME),
            ("SummarisationFeedbackDataTable", FEEDBACK_LOG_DYNAMODB_TABLE_NAME),
            ("SummarisationUsageDataTable", USAGE_LOG_DYNAMODB_TABLE_NAME),
        )
        log_tables = {}
        for log_table_construct_id, log_table_name in log_table_specs:
            log_tables[log_table_construct_id] = dynamodb.Table(
                self,
                log_table_construct_id,
                table_name=log_table_name,
                partition_key=dynamodb.Attribute(
                    name="id", type=dynamodb.AttributeType.STRING
                ),
                billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST,
                deletion_protection=resource_delete_protection,
                removal_policy=resource_removal_policy,
            )
        usage_log_table = log_tables["SummarisationUsageDataTable"]
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise Exception("Could not create DynamoDB tables due to:", e) from e
# Optional: define the Lambda that exports the usage-log DynamoDB table, using
# the output bucket and S3 key configured for the export.
if ENABLE_DYNAMODB_USAGE_LOG_EXPORT == "True":
    try:
        # usage_log_table is still None unless SAVE_LOGS_TO_DYNAMODB == "True"
        # caused this stack to create the usage table, so fail fast here.
        if usage_log_table is None:
            raise ValueError(
                "ENABLE_DYNAMODB_USAGE_LOG_EXPORT=True requires "
                "SAVE_LOGS_TO_DYNAMODB=True and the usage log table."
            )
        # The Lambda asset directory is resolved relative to this file.
        lambda_asset_dir = os.path.join(
            os.path.dirname(__file__), "lambda_dynamo_logs_export"
        )
        create_dynamo_usage_log_export_lambda(
            self,
            "DynamoUsageLogExport",
            # An empty configured name is passed as None.
            function_name=DYNAMODB_USAGE_LOG_EXPORT_LAMBDA_NAME or None,
            lambda_asset_path=lambda_asset_dir,
            dynamodb_table=usage_log_table,
            output_bucket=output_bucket,
            s3_output_key=DYNAMODB_USAGE_LOG_EXPORT_S3_KEY,
            schedule_expression=DYNAMODB_USAGE_LOG_EXPORT_SCHEDULE,
            dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME,
            date_attribute=DYNAMODB_USAGE_LOG_EXPORT_DATE_ATTRIBUTE,
            output_filename=DYNAMODB_USAGE_LOG_EXPORT_OUTPUT_FILENAME,
            shared_kms_key_arn=shared_kms_key_arn,
        )
        print(
            "Scheduled DynamoDB usage log export Lambda defined "
            f"({DYNAMODB_USAGE_LOG_EXPORT_SCHEDULE})."
        )
    except Exception as e:
        # Any failure (including the ValueError above) is re-raised as a
        # generic Exception chained to the original error.
        raise Exception(
            "Could not handle DynamoDB usage log export Lambda due to:", e
        ) from e
# Resolve the Application Load Balancer. alb stays None when the ALB ingress
# path is not deployed.
alb = None
load_balancer_name = ALB_NAME
# Names longer than 32 characters are shortened by keeping the LAST 32
# characters; the shortened name is also used for the context lookups below.
if len(load_balancer_name) > 32:
    load_balancer_name = load_balancer_name[-32:]
if deploy_web_ingress:
    # --- ALB (legacy path) ---
    try:
        # Precheck context wins; fall back to the explicitly configured values.
        alb_arn = get_context_str(f"arn:{load_balancer_name}") or (
            EXISTING_LOAD_BALANCER_ARN or None
        )
        alb_dns_name = get_context_str(f"dns:{load_balancer_name}") or (
            EXISTING_LOAD_BALANCER_DNS or None
        )
        if alb_arn and alb_dns_name:
            # Import the existing ALB; both ARN and DNS name are required.
            alb_security_group_id = (
                get_context_str(f"security_group_id:{load_balancer_name}")
                or alb_security_group.security_group_id
            )
            alb_attrs = {
                "load_balancer_arn": alb_arn,
                "load_balancer_dns_name": alb_dns_name,
                "security_group_id": alb_security_group_id,
                "vpc": vpc,
            }
            alb_canonical_zone_id = get_context_str(
                f"canonical_hosted_zone_id:{load_balancer_name}"
            )
            # Only pass the hosted zone ID when the context provides one.
            if alb_canonical_zone_id:
                alb_attrs["load_balancer_canonical_hosted_zone_id"] = (
                    alb_canonical_zone_id
                )
            alb = elbv2.ApplicationLoadBalancer.from_application_load_balancer_attributes(
                self,
                "ALB",
                **alb_attrs,
            )
            print(
                f"Using existing Application Load Balancer {load_balancer_name}."
            )
        else:
            alb = elbv2.ApplicationLoadBalancer(
                self,
                "ALB",
                load_balancer_name=load_balancer_name,
                vpc=vpc,
                internet_facing=True,
                security_group=alb_security_group,
                vpc_subnets=public_subnet_selection,
                drop_invalid_header_fields=True,
                deletion_protection=resource_delete_protection,
            )
            print("Successfully created new Application Load Balancer")
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise Exception(
            "Could not handle application load balancer due to:", e
        ) from e
| # --- Cognito User Pool (web login; skipped for headless batch-only) --- | |
| user_pool = None | |
| user_pool_client = None | |
| user_pool_domain = None | |
| secret = None | |
| if enable_headless: | |
| print( | |
| "ENABLE_HEADLESS_DEPLOYMENT=True: skipping Cognito user pool, " | |
| "hosted UI domain, and client secret (no web login for batch tasks)." | |
| ) | |
| else: | |
| try: | |
| if get_context_bool(f"exists:{COGNITO_USER_POOL_NAME}"): | |
| # Lookup by ID from context | |
| user_pool_id = get_context_str(f"id:{COGNITO_USER_POOL_NAME}") | |
| if not user_pool_id: | |
| raise ValueError( | |
| f"Context value 'id:{COGNITO_USER_POOL_NAME}' is required if User Pool exists." | |
| ) | |
| user_pool = cognito.UserPool.from_user_pool_id( | |
| self, "UserPool", user_pool_id=user_pool_id | |
| ) | |
| print(f"Using existing user pool {user_pool_id}.") | |
| else: | |
| user_pool = cognito.UserPool( | |
| self, | |
| "UserPool", | |
| user_pool_name=COGNITO_USER_POOL_NAME, | |
| mfa=cognito.Mfa.OFF, # Adjust as needed | |
| sign_in_aliases=cognito.SignInAliases(email=True), | |
| deletion_protection=resource_delete_protection, | |
| removal_policy=resource_removal_policy, | |
| ) # Adjust as needed | |
| print(f"Created new user pool {user_pool.user_pool_id}.") | |
| # HTTPS ALB (ACM cert or Express Mode) needs oauth2/idpresponse callback URLs. | |
| if ACM_SSL_CERTIFICATE_ARN or use_express_ingress: | |
| redirect_uris = [ | |
| COGNITO_REDIRECTION_URL, | |
| COGNITO_REDIRECTION_URL + "/oauth2/idpresponse", | |
| ] | |
| else: | |
| redirect_uris = [COGNITO_REDIRECTION_URL] | |
| user_pool_client_name = COGNITO_USER_POOL_CLIENT_NAME | |
| if get_context_bool(f"exists:{user_pool_client_name}"): | |
| # Lookup by ID from context (requires User Pool object) | |
| user_pool_client_id = get_context_str(f"id:{user_pool_client_name}") | |
| if not user_pool_client_id: | |
| raise ValueError( | |
| f"Context value 'id:{user_pool_client_name}' is required if User Pool Client exists." | |
| ) | |
| user_pool_client = cognito.UserPoolClient.from_user_pool_client_id( | |
| self, "UserPoolClient", user_pool_client_id=user_pool_client_id | |
| ) | |
| print(f"Using existing user pool client {user_pool_client_id}.") | |
| else: | |
| user_pool_client = cognito.UserPoolClient( | |
| self, | |
| "UserPoolClient", | |
| auth_flows=cognito.AuthFlow( | |
| user_srp=True, user_password=True | |
| ), # Example: enable SRP for secure sign-in | |
| user_pool=user_pool, | |
| generate_secret=True, | |
| user_pool_client_name=user_pool_client_name, | |
| supported_identity_providers=[ | |
| cognito.UserPoolClientIdentityProvider.COGNITO | |
| ], | |
| o_auth=cognito.OAuthSettings( | |
| flows=cognito.OAuthFlows(authorization_code_grant=True), | |
| scopes=[ | |
| cognito.OAuthScope.OPENID, | |
| cognito.OAuthScope.EMAIL, | |
| cognito.OAuthScope.PROFILE, | |
| ], | |
| callback_urls=redirect_uris, | |
| ), | |
| refresh_token_validity=Duration.minutes( | |
| COGNITO_REFRESH_TOKEN_VALIDITY | |
| ), | |
| id_token_validity=Duration.minutes(COGNITO_ID_TOKEN_VALIDITY), | |
| access_token_validity=Duration.minutes( | |
| COGNITO_ACCESS_TOKEN_VALIDITY | |
| ), | |
| ) | |
| CfnOutput( | |
| self, | |
| "CognitoAppClientId", | |
| value=user_pool_client.user_pool_client_id, | |
| ) | |
| print( | |
| f"Created new user pool client {user_pool_client.user_pool_client_id}." | |
| ) | |
| # Add a domain to the User Pool (crucial for ALB integration) | |
| domain_prefix = (COGNITO_USER_POOL_DOMAIN_PREFIX or "").strip().lower() | |
| if get_context_bool(f"cognito_domain_taken:{domain_prefix}"): | |
| raise ValueError( | |
| f"Cognito hosted UI domain prefix {domain_prefix!r} is not " | |
| f"available in this region (taken by another AWS account or " | |
| "an existing pool). Set COGNITO_USER_POOL_DOMAIN_PREFIX in " | |
| "cdk/config/cdk_config.env to a unique value and re-run " | |
| "cdk_install.py / check_resources.py." | |
| ) | |
| user_pool_domain = user_pool.add_domain( | |
| "UserPoolDomain", | |
| cognito_domain=cognito.CognitoDomainOptions( | |
| domain_prefix=COGNITO_USER_POOL_DOMAIN_PREFIX | |
| ), | |
| ) | |
| # Apply removal_policy to the created UserPoolDomain construct | |
| user_pool_domain.apply_removal_policy(policy=resource_removal_policy) | |
| CfnOutput( | |
| self, "CognitoUserPoolLoginUrl", value=user_pool_domain.base_url() | |
| ) | |
| except Exception as e: | |
| raise Exception("Could not handle Cognito resources due to:", e) | |
| # --- Secrets Manager Secret --- | |
| try: | |
| secret_name = COGNITO_USER_POOL_CLIENT_SECRET_NAME | |
| if get_context_bool(f"exists:{secret_name}"): | |
| secret_arn = get_context_str(f"arn:{secret_name}") | |
| if secret_arn: | |
| secret = secretsmanager.Secret.from_secret_complete_arn( | |
| self, | |
| "CognitoSecret", | |
| secret_complete_arn=secret_arn, | |
| ) | |
| print("Using existing Secret (ARN from precheck context).") | |
| else: | |
| secret = secretsmanager.Secret.from_secret_name_v2( | |
| self, "CognitoSecret", secret_name=secret_name | |
| ) | |
| print( | |
| "Using existing Secret by name (IAM grants use ARN wildcard " | |
| "suffix; re-run precheck to pin the full ARN)." | |
| ) | |
| else: | |
| if USE_CUSTOM_KMS_KEY == "1" and isinstance(kms_key, kms.Key): | |
| secret = secretsmanager.Secret( | |
| self, | |
| "CognitoSecret", # Logical ID | |
| secret_name=secret_name, # Explicit resource name | |
| secret_object_value={ | |
| "SUMMARISATION_USER_POOL_ID": SecretValue.unsafe_plain_text( | |
| user_pool.user_pool_id | |
| ), # Use the CDK attribute | |
| "SUMMARISATION_CLIENT_ID": SecretValue.unsafe_plain_text( | |
| user_pool_client.user_pool_client_id | |
| ), # Use the CDK attribute | |
| "SUMMARISATION_CLIENT_SECRET": user_pool_client.user_pool_client_secret, # Use the CDK attribute | |
| }, | |
| encryption_key=kms_key, | |
| removal_policy=resource_removal_policy, | |
| ) | |
| else: | |
| secret = secretsmanager.Secret( | |
| self, | |
| "CognitoSecret", # Logical ID | |
| secret_name=secret_name, # Explicit resource name | |
| secret_object_value={ | |
| "SUMMARISATION_USER_POOL_ID": SecretValue.unsafe_plain_text( | |
| user_pool.user_pool_id | |
| ), # Use the CDK attribute | |
| "SUMMARISATION_CLIENT_ID": SecretValue.unsafe_plain_text( | |
| user_pool_client.user_pool_client_id | |
| ), # Use the CDK attribute | |
| "SUMMARISATION_CLIENT_SECRET": user_pool_client.user_pool_client_secret, # Use the CDK attribute | |
| }, | |
| removal_policy=resource_removal_policy, | |
| ) | |
| print( | |
| "Created new secret in Secrets Manager for Cognito user pool and related details." | |
| ) | |
| except Exception as e: | |
| raise Exception("Could not handle Secrets Manager secret due to:", e) | |
| try: | |
| secret.grant_read(task_role) | |
| secret.grant_read(execution_role) | |
| except Exception as e: | |
| raise Exception("Could not grant access to Secrets Manager due to:", e) | |
| try: | |
| # ECS environmentFiles (app_config.env) are fetched by the execution role at task start. | |
| bucket.grant_read(execution_role, APP_CONFIG_ENV_BASENAME) | |
| # KMS: task role uses shared S3 CMK via build_ecs_task_role_inline_policy; | |
| # execution role uses the secret's CMK via build_ecs_execution_role_kms_policy. | |
| except Exception as e: | |
| raise Exception("Could not grant bucket read to execution role due to:", e) | |
| # --- ECS Cluster (shared by legacy Fargate and Express paths) --- | |
| try: | |
| cluster_kwargs = { | |
| "cluster_name": CLUSTER_NAME, | |
| "enable_fargate_capacity_providers": True, | |
| "vpc": vpc, | |
| } | |
| if enable_service_connect or enable_pi_express: | |
| cluster_kwargs["default_cloud_map_namespace"] = ( | |
| ecs.CloudMapNamespaceOptions( | |
| name=ECS_SERVICE_CONNECT_NAMESPACE, | |
| vpc=vpc, | |
| ) | |
| ) | |
| cluster = ecs.Cluster(self, "ECSCluster", **cluster_kwargs) | |
| print("Successfully created new ECS cluster") | |
| except Exception as e: | |
| raise Exception("Could not handle ECS cluster due to:", e) | |
| express_service = None | |
| express_alb_security_group_id = None | |
| if use_express_ingress: | |
| try: | |
| express_log_group = logs.LogGroup( | |
| self, | |
| "ExpressTaskLogGroup", | |
| log_group_name=f"/ecs/{ECS_EXPRESS_SERVICE_NAME}-logs".lower(), | |
| retention=logs.RetentionDays.ONE_MONTH, | |
| removal_policy=resource_removal_policy, | |
| ) | |
| express_log_group.grant_write(execution_role) | |
| express_infra_role = create_ecs_express_infrastructure_role( | |
| self, | |
| "ExpressInfrastructureRole", | |
| ECS_EXPRESS_INFRASTRUCTURE_ROLE_NAME, | |
| ) | |
| express_app_overrides: Dict[str, str] = {} | |
| if ENABLE_HEADLESS_DEPLOYMENT == "True": | |
| express_app_overrides["COGNITO_AUTH"] = "False" | |
| elif enable_pi_express: | |
| # Pi agent calls main over Service Connect; Gradio auth blocks | |
| # gradio_client unless credentials are passed on every call. | |
| express_app_overrides["COGNITO_AUTH"] = "False" | |
| express_app_environment = load_app_config_env_for_express( | |
| APP_CONFIG_ENV_FILE, | |
| overrides=express_app_overrides or None, | |
| ) | |
| primary_container = build_express_gateway_primary_container( | |
| image_uri=ecr_image_loc + ":latest", | |
| container_port=int(GRADIO_SERVER_PORT), | |
| log_group_name=express_log_group.log_group_name, | |
| aws_region=AWS_REGION, | |
| secret=secret, | |
| environment=express_app_environment, | |
| ) | |
| express_use_public_subnets = ECS_EXPRESS_USE_PUBLIC_SUBNETS == "True" | |
| express_subnet_ids = [ | |
| s.subnet_id | |
| for s in ( | |
| self.public_subnets | |
| if express_use_public_subnets | |
| else self.private_subnets | |
| ) | |
| ] | |
| if not express_subnet_ids: | |
| tier = "public" if express_use_public_subnets else "private" | |
| raise ValueError( | |
| f"No {tier} subnets available for ECS Express Mode. " | |
| f"Set ECS_EXPRESS_USE_PUBLIC_SUBNETS=False to use private " | |
| "subnets (internal ALB only), or create/import public subnets." | |
| ) | |
| if express_use_public_subnets: | |
| print( | |
| "ECS Express Mode using public subnets " | |
| "(internet-facing managed ALB)." | |
| ) | |
| else: | |
| print( | |
| "ECS Express Mode using private subnets " | |
| "(internal managed ALB)." | |
| ) | |
| # MinTaskCount=0 until post_cdk_build_quickstart builds/pushes :latest. | |
| express_service = create_express_gateway_service( | |
| self, | |
| "ExpressGatewayService", | |
| service_name=ECS_EXPRESS_SERVICE_NAME, | |
| cluster_name=CLUSTER_NAME, | |
| execution_role_arn=execution_role.role_arn, | |
| infrastructure_role_arn=express_infra_role.role_arn, | |
| task_role_arn=task_role.role_arn, | |
| cpu=str(ECS_TASK_CPU_SIZE), | |
| memory=str(ECS_TASK_MEMORY_SIZE), | |
| health_check_path=ECS_EXPRESS_HEALTH_CHECK_PATH, | |
| primary_container=primary_container, | |
| subnet_ids=express_subnet_ids, | |
| security_group_ids=[ecs_security_group.security_group_id], | |
| ) | |
| express_service.node.add_dependency(cluster) | |
| allow_express_load_balancer_to_ecs_security_group( | |
| self, | |
| "ExpressAlbToEcsIngress", | |
| express_service=express_service, | |
| ecs_security_group=ecs_security_group, | |
| container_port=int(GRADIO_SERVER_PORT), | |
| ) | |
| express_alb_arn = express_ingress_load_balancer_arn(express_service) | |
| express_alb_dns = express_service.attr_endpoint | |
| express_alb_security_group_id = ( | |
| express_ingress_first_load_balancer_security_group(express_service) | |
| ) | |
| alb = elbv2.ApplicationLoadBalancer.from_application_load_balancer_attributes( | |
| self, | |
| "ALB", | |
| load_balancer_arn=express_alb_arn, | |
| load_balancer_dns_name=express_alb_dns, | |
| security_group_id=express_alb_security_group_id, | |
| vpc=vpc, | |
| ) | |
| # Express Mode manages host-header listener rules (priorities 1, 2, …). | |
| # Do not add ALB authenticate-cognito rules here; use in-app COGNITO_AUTH. | |
| CfnOutput( | |
| self, | |
| "ExpressServiceEndpoint", | |
| value=express_service.attr_endpoint, | |
| description="HTTPS URL for the ECS Express Mode service", | |
| ) | |
| CfnOutput( | |
| self, | |
| "ExpressServiceArn", | |
| value=express_service.attr_service_arn, | |
| ) | |
| CfnOutput( | |
| self, | |
| "ExpressManagedCertificateArn", | |
| value=express_service.attr_ecs_managed_resource_arns_ingress_path_certificate_arn, | |
| ) | |
| if enable_pi_express: | |
| try: | |
| pi_express_log_group = logs.LogGroup( | |
| self, | |
| "ExpressPiTaskLogGroup", | |
| log_group_name=f"/ecs/{ECS_PI_EXPRESS_SERVICE_NAME}-logs".lower(), | |
| retention=logs.RetentionDays.ONE_MONTH, | |
| removal_policy=resource_removal_policy, | |
| ) | |
| pi_express_log_group.grant_write(execution_role) | |
| pi_express_security_group = ec2.SecurityGroup( | |
| self, | |
| "ExpressPiSecurityGroup", | |
| vpc=vpc, | |
| security_group_name=ECS_PI_EXPRESS_SECURITY_GROUP_NAME, | |
| description="Pi agent ECS Express tasks", | |
| ) | |
| pi_express_environment = build_pi_express_container_environment( | |
| service_connect_discovery_name=ECS_SERVICE_CONNECT_DISCOVERY_NAME, | |
| main_app_port=int(GRADIO_SERVER_PORT), | |
| pi_gradio_port=int(PI_GRADIO_PORT), | |
| cognito_auth=ENABLE_HEADLESS_DEPLOYMENT != "True", | |
| ) | |
| pi_primary_container = build_express_pi_primary_container( | |
| image_uri=pi_ecr_image_loc + ":latest", | |
| container_port=int(PI_GRADIO_PORT), | |
| log_group_name=pi_express_log_group.log_group_name, | |
| aws_region=AWS_REGION, | |
| environment=pi_express_environment, | |
| secret=secret, | |
| cognito_auth=ENABLE_HEADLESS_DEPLOYMENT != "True", | |
| ) | |
| express_pi_service = create_express_gateway_service( | |
| self, | |
| "ExpressPiGatewayService", | |
| service_name=ECS_PI_EXPRESS_SERVICE_NAME, | |
| cluster_name=CLUSTER_NAME, | |
| execution_role_arn=execution_role.role_arn, | |
| infrastructure_role_arn=express_infra_role.role_arn, | |
| task_role_arn=task_role.role_arn, | |
| cpu=str(ECS_PI_TASK_CPU_SIZE), | |
| memory=str(ECS_PI_TASK_MEMORY_SIZE), | |
| health_check_path=ECS_PI_EXPRESS_HEALTH_CHECK_PATH, | |
| primary_container=pi_primary_container, | |
| subnet_ids=express_subnet_ids, | |
| security_group_ids=[ | |
| pi_express_security_group.security_group_id | |
| ], | |
| ) | |
| express_pi_service.node.add_dependency(cluster) | |
| express_pi_service.node.add_dependency(express_service) | |
| allow_express_load_balancer_to_ecs_security_group( | |
| self, | |
| "ExpressAlbToPiExpressIngress", | |
| express_service=express_pi_service, | |
| ecs_security_group=pi_express_security_group, | |
| container_port=int(PI_GRADIO_PORT), | |
| ) | |
| pi_express_security_group.add_egress_rule( | |
| peer=ecs_security_group, | |
| connection=ec2.Port.tcp(int(GRADIO_SERVER_PORT)), | |
| description="Pi Express (Service Connect) to main summarisation app", | |
| ) | |
| ecs_security_group.add_ingress_rule( | |
| peer=pi_express_security_group, | |
| connection=ec2.Port.tcp(int(GRADIO_SERVER_PORT)), | |
| description="Pi Express (Service Connect) to main summarisation app", | |
| ) | |
| # Service Connect for Express is applied in post_cdk_build_quickstart.py | |
| # after CodeBuild pushes :latest. Express primary containers do not | |
| # define named portMappings at create time; CDK cannot enable SC here. | |
| pi_public_url = format_express_pi_public_url( | |
| express_pi_service.attr_endpoint, | |
| ) | |
| sc_backend = ( | |
| f"http://{ECS_SERVICE_CONNECT_DISCOVERY_NAME}:" | |
| f"{GRADIO_SERVER_PORT}" | |
| ) | |
| CfnOutput( | |
| self, | |
| "PiExpressEndpoint", | |
| value=express_pi_service.attr_endpoint, | |
| description="HTTPS URL for the Pi ECS Express service (AWS-managed cert)", | |
| ) | |
| CfnOutput( | |
| self, | |
| "PiPublicUrl", | |
| value=pi_public_url, | |
| description="Public URL for Pi Express UI (managed HTTPS endpoint)", | |
| ) | |
| CfnOutput( | |
| self, | |
| "PiDocSummarisationBackendUrl", | |
| value=sc_backend, | |
| description="DOC_SUMMARISATION_GRADIO_URL on Pi Express (Service Connect, no Cognito)", | |
| ) | |
| CfnOutput( | |
| self, | |
| "PiExpressServiceName", | |
| value=ECS_PI_EXPRESS_SERVICE_NAME, | |
| ) | |
| CfnOutput( | |
| self, | |
| "ServiceConnectNamespace", | |
| value=ECS_SERVICE_CONNECT_NAMESPACE, | |
| description="Cloud Map namespace for Express Service Connect", | |
| ) | |
| print( | |
| "ECS Express Pi gateway service defined with Service Connect " | |
| f"backend {sc_backend}; public URL: {pi_public_url}." | |
| ) | |
| except Exception as e: | |
| raise Exception( | |
| "Could not handle ECS Express Pi agent due to:", e | |
| ) | |
| print("ECS Express Gateway service defined.") | |
| except Exception as e: | |
| raise Exception("Could not handle ECS Express Mode due to:", e) | |
| if not use_express_ingress: | |
| # --- Fargate Task Definition --- | |
| try: | |
| fargate_task_definition_name = FARGATE_TASK_DEFINITION_NAME | |
| read_only_file_system = ECS_READ_ONLY_FILE_SYSTEM == "True" | |
| if os.path.exists(TASK_DEFINITION_FILE_LOCATION): | |
| with open(TASK_DEFINITION_FILE_LOCATION) as f: # Use correct path | |
| task_def_params = json.load(f) | |
| # Need to ensure taskRoleArn and executionRoleArn in JSON are correct ARN strings | |
| else: | |
| epheremal_storage_volume_name = "appEphemeralVolume" | |
| task_def_params = {} | |
| task_def_params["taskRoleArn"] = ( | |
| task_role.role_arn | |
| ) # Use CDK role object ARN | |
| task_def_params["executionRoleArn"] = ( | |
| execution_role.role_arn | |
| ) # Use CDK role object ARN | |
| task_def_params["memory"] = ECS_TASK_MEMORY_SIZE | |
| task_def_params["cpu"] = ECS_TASK_CPU_SIZE | |
| container_def = { | |
| "name": full_ecr_repo_name, | |
| "image": ecr_image_loc + ":latest", | |
| "essential": True, | |
| "portMappings": [ | |
| { | |
| "containerPort": int(GRADIO_SERVER_PORT), | |
| "hostPort": int(GRADIO_SERVER_PORT), | |
| "protocol": "tcp", | |
| "appProtocol": "http", | |
| } | |
| ], | |
| "logConfiguration": { | |
| "logDriver": "awslogs", | |
| "options": { | |
| "awslogs-group": ECS_LOG_GROUP_NAME, | |
| "awslogs-region": AWS_REGION, | |
| "awslogs-stream-prefix": "ecs", | |
| }, | |
| }, | |
| "environmentFiles": ( | |
| [] | |
| if enable_headless | |
| else [ | |
| { | |
| "value": bucket.bucket_arn | |
| + f"/{APP_CONFIG_ENV_BASENAME}", | |
| "type": "s3", | |
| } | |
| ] | |
| ), | |
| "memoryReservation": int(task_def_params["memory"]) | |
| - 512, # Reserve some memory for the container | |
| "mountPoints": [ | |
| { | |
| "sourceVolume": epheremal_storage_volume_name, | |
| "containerPath": "/home/user/app/logs", | |
| "readOnly": False, | |
| }, | |
| { | |
| "sourceVolume": epheremal_storage_volume_name, | |
| "containerPath": "/home/user/app/feedback", | |
| "readOnly": False, | |
| }, | |
| { | |
| "sourceVolume": epheremal_storage_volume_name, | |
| "containerPath": "/home/user/app/usage", | |
| "readOnly": False, | |
| }, | |
| { | |
| "sourceVolume": epheremal_storage_volume_name, | |
| "containerPath": "/home/user/app/input", | |
| "readOnly": False, | |
| }, | |
| { | |
| "sourceVolume": epheremal_storage_volume_name, | |
| "containerPath": "/home/user/app/output", | |
| "readOnly": False, | |
| }, | |
| { | |
| "sourceVolume": epheremal_storage_volume_name, | |
| "containerPath": "/home/user/app/config", | |
| "readOnly": False, | |
| }, | |
| { | |
| "sourceVolume": epheremal_storage_volume_name, | |
| "containerPath": "/tmp/matplotlib_cache", | |
| "readOnly": False, | |
| }, | |
| { | |
| "sourceVolume": epheremal_storage_volume_name, | |
| "containerPath": "/tmp", | |
| "readOnly": False, | |
| }, | |
| { | |
| "sourceVolume": epheremal_storage_volume_name, | |
| "containerPath": "/var/tmp", | |
| "readOnly": False, | |
| }, | |
| { | |
| "sourceVolume": epheremal_storage_volume_name, | |
| "containerPath": "/tmp/gradio_tmp", | |
| "readOnly": False, | |
| }, | |
| ], | |
| "readonlyRootFilesystem": read_only_file_system, | |
| "user": "1000", | |
| } | |
| task_def_params["containerDefinitions"] = [container_def] | |
| log_group_name_from_config = task_def_params["containerDefinitions"][0][ | |
| "logConfiguration" | |
| ]["options"]["awslogs-group"] | |
| cdk_managed_log_group = logs.LogGroup( | |
| self, | |
| "MyTaskLogGroup", # CDK Logical ID | |
| log_group_name=log_group_name_from_config, | |
| retention=logs.RetentionDays.ONE_MONTH, | |
| removal_policy=resource_removal_policy, | |
| ) | |
| cdk_managed_log_group.grant_write(execution_role) | |
| epheremal_storage_volume_cdk_obj = ecs.Volume( | |
| name=epheremal_storage_volume_name | |
| ) | |
| fargate_task_definition = ecs.FargateTaskDefinition( | |
| self, | |
| "FargateTaskDefinition", # Logical ID | |
| family=fargate_task_definition_name, | |
| cpu=int(task_def_params["cpu"]), | |
| memory_limit_mib=int(task_def_params["memory"]), | |
| task_role=task_role, | |
| execution_role=execution_role, | |
| runtime_platform=ecs.RuntimePlatform( | |
| cpu_architecture=ecs.CpuArchitecture.X86_64, | |
| operating_system_family=ecs.OperatingSystemFamily.LINUX, | |
| ), | |
| ephemeral_storage_gib=21, # Minimum is 21 GiB | |
| volumes=[epheremal_storage_volume_cdk_obj], | |
| ) | |
| print("Fargate task definition defined.") | |
| # Add container definitions to the task definition object | |
| if task_def_params["containerDefinitions"]: | |
| container_def_params = task_def_params["containerDefinitions"][0] | |
| env_files = [] | |
| if container_def_params.get("environmentFiles"): | |
| for env_file_param in container_def_params["environmentFiles"]: | |
| # Need to parse the ARN to get the bucket object and key | |
| env_file_arn_parts = env_file_param["value"].split(":::") | |
| bucket_name_and_key = env_file_arn_parts[-1] | |
| env_bucket_name, env_key = bucket_name_and_key.split("/", 1) | |
| env_file = ecs.EnvironmentFile.from_bucket(bucket, env_key) | |
| env_files.append(env_file) | |
| container_kwargs: Dict[str, Any] = { | |
| "image": ecs.ContainerImage.from_registry( | |
| container_def_params["image"] | |
| ), | |
| "logging": ecs.LogDriver.aws_logs( | |
| stream_prefix=container_def_params["logConfiguration"][ | |
| "options" | |
| ]["awslogs-stream-prefix"], | |
| log_group=cdk_managed_log_group, | |
| ), | |
| "environment_files": env_files if env_files else None, | |
| "readonly_root_filesystem": read_only_file_system, | |
| "user": container_def_params.get("user", "1000"), | |
| } | |
| if not enable_headless: | |
| container_kwargs["secrets"] = { | |
| "AWS_USER_POOL_ID": ecs.Secret.from_secrets_manager( | |
| secret, "SUMMARISATION_USER_POOL_ID" | |
| ), | |
| "AWS_CLIENT_ID": ecs.Secret.from_secrets_manager( | |
| secret, "SUMMARISATION_CLIENT_ID" | |
| ), | |
| "AWS_CLIENT_SECRET": ecs.Secret.from_secrets_manager( | |
| secret, "SUMMARISATION_CLIENT_SECRET" | |
| ), | |
| } | |
| container = fargate_task_definition.add_container( | |
| container_def_params["name"], | |
| **container_kwargs, | |
| ) | |
| for port_mapping in container_def_params["portMappings"]: | |
| container.add_port_mappings( | |
| ecs.PortMapping( | |
| container_port=int(port_mapping["containerPort"]), | |
| host_port=int(port_mapping["hostPort"]), | |
| name="port-" + str(port_mapping["containerPort"]), | |
| app_protocol=ecs.AppProtocol.http, | |
| protocol=ecs.Protocol.TCP, | |
| ) | |
| ) | |
| container.add_port_mappings( | |
| ecs.PortMapping( | |
| container_port=80, | |
| host_port=80, | |
| name="port-80", | |
| app_protocol=ecs.AppProtocol.http, | |
| protocol=ecs.Protocol.TCP, | |
| ) | |
| ) | |
| if container_def_params.get("mountPoints"): | |
| mount_points = [] | |
| for mount_point in container_def_params["mountPoints"]: | |
| mount_points.append( | |
| ecs.MountPoint( | |
| container_path=mount_point["containerPath"], | |
| read_only=mount_point["readOnly"], | |
| source_volume=epheremal_storage_volume_name, | |
| ) | |
| ) | |
| container.add_mount_points(*mount_points) | |
| except Exception as e: | |
| raise Exception("Could not handle Fargate task definition due to:", e) | |
| ecs_service = None | |
| if deploy_web_ingress: | |
| # --- ECS Service --- | |
| try: | |
| ecs_service_name = ECS_SERVICE_NAME | |
| if ECS_USE_FARGATE_SPOT == "True": | |
| use_fargate_spot = "FARGATE_SPOT" | |
| if ECS_USE_FARGATE_SPOT == "False": | |
| use_fargate_spot = "FARGATE" | |
| # Try importing the service by name first; the except branch below creates a new one | |
| try: | |
| # from_service_attributes is given the cluster object and the service name | |
| ecs_service = ecs.FargateService.from_service_attributes( | |
| self, | |
| "ECSService", # Logical ID | |
| cluster=cluster, # Requires the cluster object | |
| service_name=ecs_service_name, | |
| ) | |
| print(f"Using existing ECS service {ecs_service_name}.") | |
| if enable_service_connect: | |
| print( | |
| "Warning: ENABLE_ECS_SERVICE_CONNECT=True but an existing " | |
| "ECS service was imported; enable Service Connect on that " | |
| "service in the ECS console or replace the service via CDK." | |
| ) | |
| except Exception: | |
| service_connect_configuration = None | |
| if enable_service_connect: | |
| sc_dns_name = ( | |
| ECS_SERVICE_CONNECT_DNS_NAME | |
| or ECS_SERVICE_CONNECT_DISCOVERY_NAME | |
| ) | |
| service_connect_configuration = ecs.ServiceConnectProps( | |
| namespace=ECS_SERVICE_CONNECT_NAMESPACE, | |
| services=[ | |
| ecs.ServiceConnectService( | |
| port_mapping_name=ECS_SERVICE_CONNECT_PORT_MAPPING_NAME, | |
| discovery_name=ECS_SERVICE_CONNECT_DISCOVERY_NAME, | |
| dns_name=sc_dns_name, | |
| port=int(GRADIO_SERVER_PORT), | |
| ) | |
| ], | |
| ) | |
| # The service is created with desired_count=0 because the initial Docker image has not yet been built with CodeBuild | |
| ecs_service = ecs.FargateService( | |
| self, | |
| "ECSService", # Logical ID | |
| service_name=ecs_service_name, # Explicit resource name | |
| platform_version=ecs.FargatePlatformVersion.LATEST, | |
| capacity_provider_strategies=[ | |
| ecs.CapacityProviderStrategy( | |
| capacity_provider=use_fargate_spot, base=0, weight=1 | |
| ) | |
| ], | |
| cluster=cluster, | |
| task_definition=fargate_task_definition, # Link to TD | |
| security_groups=[ecs_security_group], # Link to SG | |
| vpc_subnets=ec2.SubnetSelection( | |
| subnets=self.private_subnets | |
| ), # Link to subnets | |
| min_healthy_percent=0, | |
| max_healthy_percent=600, | |
| desired_count=0, | |
| availability_zone_rebalancing=ecs_availability_zone_rebalancing( | |
| ECS_AVAILABILITY_ZONE_REBALANCING | |
| ), | |
| service_connect_configuration=service_connect_configuration, | |
| ) | |
| print("Successfully created new ECS service") | |
| # Note: Auto-scaling setup would typically go here if needed for the service | |
| except Exception as e: | |
| raise Exception("Could not handle ECS service due to:", e) | |
| if enable_pi_agent: | |
| try: | |
| pi_ecs_service, pi_ecs_security_group, _pi_task_def = ( | |
| create_pi_agent_ecs_resources( | |
| self, | |
| "PiAgent", | |
| vpc=vpc, | |
| cluster=cluster, | |
| private_subnets=self.private_subnets, | |
| pi_ecr_image_uri=pi_ecr_image_loc, | |
| container_name=ECR_PI_REPO_NAME, | |
| task_role=task_role, | |
| execution_role=execution_role, | |
| config_bucket=bucket, | |
| pi_agent_env_s3_key=PI_AGENT_ENV_S3_KEY, | |
| service_name=ECS_PI_SERVICE_NAME, | |
| task_family=ECS_PI_TASK_DEFINITION_NAME, | |
| security_group_name=ECS_PI_SECURITY_GROUP_NAME, | |
| log_group_name=ECS_PI_LOG_GROUP_NAME, | |
| cpu=int(ECS_PI_TASK_CPU_SIZE), | |
| memory_mib=int(ECS_PI_TASK_MEMORY_SIZE), | |
| pi_gradio_port=int(PI_GRADIO_PORT), | |
| service_connect_namespace=ECS_SERVICE_CONNECT_NAMESPACE, | |
| service_connect_discovery_name=ECS_SERVICE_CONNECT_DISCOVERY_NAME, | |
| main_app_port=int(GRADIO_SERVER_PORT), | |
| use_fargate_spot=use_fargate_spot, | |
| pi_root_path=pi_alb_root_path_for_container( | |
| PI_ALB_PATH_PREFIX_NORMALIZED, PI_ALB_ROUTING | |
| ), | |
| ) | |
| ) | |
| ecs_security_group.add_ingress_rule( | |
| peer=pi_ecs_security_group, | |
| connection=ec2_port_gradio_server_port, | |
| description="Pi agent (Service Connect) to main summarisation app", | |
| ) | |
| print("Pi agent ECS service defined.") | |
| except Exception as e: | |
| raise Exception("Could not handle Pi agent ECS service due to:", e) | |
| if ENABLE_S3_BATCH_ECS_TRIGGER == "True": | |
| try: | |
| batch_subnet_ids = [s.subnet_id for s in self.private_subnets] | |
| if not batch_subnet_ids: | |
| batch_subnet_ids = [s.subnet_id for s in self.public_subnets] | |
| if not batch_subnet_ids: | |
| raise ValueError( | |
| "S3 batch ECS trigger requires at least one public or " | |
| "private subnet." | |
| ) | |
| lambda_asset_dir = os.path.join( | |
| os.path.dirname(__file__), "config", "lambda" | |
| ) | |
| batch_lambda = create_s3_batch_ecs_trigger_lambda( | |
| self, | |
| "S3BatchEcsTrigger", | |
| function_name=S3_BATCH_LAMBDA_FUNCTION_NAME or None, | |
| lambda_asset_path=lambda_asset_dir, | |
| output_bucket=output_bucket, | |
| config_bucket=bucket, | |
| cluster_name=CLUSTER_NAME, | |
| task_definition_arn=fargate_task_definition.task_definition_arn, | |
| container_name=full_ecr_repo_name, | |
| subnet_ids=batch_subnet_ids, | |
| security_group_id=ecs_security_group.security_group_id, | |
| execution_role=execution_role, | |
| task_role=task_role, | |
| env_prefix=S3_BATCH_ENV_PREFIX, | |
| env_suffix=S3_BATCH_ENV_SUFFIX, | |
| input_prefix=S3_BATCH_INPUT_PREFIX, | |
| config_prefix=S3_BATCH_CONFIG_PREFIX, | |
| default_params_key=S3_BATCH_DEFAULT_PARAMS_KEY, | |
| general_env_prefix=S3_BATCH_GENERAL_ENV_PREFIX, | |
| default_task_type="extract", | |
| assign_public_ip=not bool(self.private_subnets), | |
| ) | |
| CfnOutput( | |
| self, | |
| "BatchEcsTriggerLambdaArn", | |
| value=batch_lambda.function_arn, | |
| description="Lambda ARN for S3-triggered batch ECS tasks", | |
| ) | |
| CfnOutput( | |
| self, | |
| "BatchJobEnvPrefix", | |
| value=f"s3://{output_bucket.bucket_name}/{S3_BATCH_ENV_PREFIX}", | |
| description="Upload job .env files here to start batch topic-modelling tasks", | |
| ) | |
| CfnOutput( | |
| self, | |
| "BatchInputPrefix", | |
| value=f"s3://{output_bucket.bucket_name}/{S3_BATCH_INPUT_PREFIX}", | |
| description="Upload consultation spreadsheets and other input files for batch jobs", | |
| ) | |
| CfnOutput( | |
| self, | |
| "BatchEcsTriggerLambdaName", | |
| value=batch_lambda.function_name, | |
| description="Lambda that starts ECS batch tasks on job .env upload", | |
| ) | |
| if enable_headless: | |
| seed_asset_dir = os.path.join( | |
| os.path.dirname(__file__), "config", "headless_s3_seed" | |
| ) | |
| create_headless_s3_batch_seed( | |
| self, | |
| "HeadlessBatchS3Seed", | |
| destination_bucket=output_bucket, | |
| seed_asset_directory=seed_asset_dir, | |
| s3_outputs_bucket_name=output_bucket.bucket_name, | |
| ) | |
| if ( | |
| enable_headless | |
| and ENABLE_HEADLESS_OUTPUT_NOTIFICATIONS == "True" | |
| ): | |
| create_headless_output_notifications( | |
| self, | |
| "HeadlessOutputNotifications", | |
| output_bucket=output_bucket, | |
| output_prefix=HEADLESS_OUTPUT_S3_PREFIX, | |
| notify_email=HEADLESS_OUTPUT_NOTIFY_EMAIL, | |
| iam_user_name=HEADLESS_OUTPUT_IAM_USER_NAME, | |
| metric_filter_id=HEADLESS_OUTPUT_S3_METRIC_FILTER_ID, | |
| sns_topic_name=HEADLESS_OUTPUT_SNS_TOPIC_NAME, | |
| alarm_name=HEADLESS_OUTPUT_ALARM_NAME, | |
| kms_key_arn=shared_kms_key_arn, | |
| ) | |
| print( | |
| "Headless output notifications enabled: S3 PutRequests " | |
| f"alarm -> SNS ({HEADLESS_OUTPUT_NOTIFY_EMAIL})." | |
| ) | |
| print("S3 batch ECS trigger Lambda defined.") | |
| except Exception as e: | |
| raise Exception("Could not handle S3 batch ECS trigger due to:", e) | |
| if deploy_web_ingress: | |
| # --- ALB TARGET GROUPS AND LISTENERS --- | |
| # This section should primarily define the resources if they are managed by this stack. | |
| # CDK handles adding/removing targets and actions on updates. | |
| # If they might pre-exist outside the stack, you need lookups. | |
| cookie_duration = Duration.hours(8) | |
| target_group_name = ALB_TARGET_GROUP_NAME # Explicit resource name | |
| cloudfront_distribution_url = "cloudfront_placeholder.net" # Need to replace this afterwards with the actual cloudfront_distribution.domain_name | |
| cloudfront_http_rule_priority = ( | |
| PI_ALB_LISTENER_RULE_PRIORITY | |
| + (pi_listener_rule_count(PI_ALB_ROUTING) if enable_pi_agent else 0) | |
| if enable_pi_agent | |
| else 1 | |
| ) | |
| https_listener = None | |
| try: | |
| # --- CREATING TARGET GROUPS AND ADDING THE CLOUDFRONT LISTENER RULE --- | |
| target_group = elbv2.ApplicationTargetGroup( | |
| self, | |
| "AppTargetGroup", # Logical ID | |
| target_group_name=target_group_name, # Explicit resource name | |
| port=int(GRADIO_SERVER_PORT), # Ensure port is int | |
| protocol=elbv2.ApplicationProtocol.HTTP, | |
| targets=[ecs_service], # Link to ECS Service | |
| stickiness_cookie_duration=cookie_duration, | |
| vpc=vpc, # Target Groups need VPC | |
| ) | |
| print(f"ALB target group {target_group_name} defined.") | |
| # First HTTP | |
| listener_port = 80 | |
| # Check if Listener exists - from_listener_arn or lookup by port/ALB | |
| http_listener = alb.add_listener( | |
| "HttpListener", # Logical ID | |
| port=listener_port, | |
| open=False, # Be cautious with open=True, usually restrict source SG | |
| ) | |
| print(f"ALB listener on port {listener_port} defined.") | |
| if ACM_SSL_CERTIFICATE_ARN: | |
| http_listener.add_action( | |
| "DefaultAction", # Logical ID for the default action | |
| action=elbv2.ListenerAction.redirect( | |
| protocol="HTTPS", | |
| host="#{host}", | |
| port="443", | |
| path="/#{path}", | |
| query="#{query}", | |
| ), | |
| ) | |
| else: | |
| if USE_CLOUDFRONT == "True": | |
| # The following default action can be added for the listener after a host header rule is added to the listener manually in the Console as suggested in the above comments. | |
| http_listener.add_action( | |
| "DefaultAction", # Logical ID for the default action | |
| action=elbv2.ListenerAction.fixed_response( | |
| status_code=403, | |
| content_type="text/plain", | |
| message_body="Access denied", | |
| ), | |
| ) | |
| # Add the Listener Rule for the specific CloudFront Host Header | |
| http_listener.add_action( | |
| "CloudFrontHostHeaderRule", | |
| action=elbv2.ListenerAction.forward( | |
| target_groups=[target_group], | |
| stickiness_duration=cookie_duration, | |
| ), | |
| priority=cloudfront_http_rule_priority, | |
| conditions=[ | |
| elbv2.ListenerCondition.host_headers( | |
| [cloudfront_distribution_url] | |
| ) # May have to redefine url in console afterwards if not specified in config file | |
| ], | |
| ) | |
| else: | |
| # Add the Listener Rule for the specific CloudFront Host Header | |
| http_listener.add_action( | |
| "CloudFrontHostHeaderRule", | |
| action=elbv2.ListenerAction.forward( | |
| target_groups=[target_group], | |
| stickiness_duration=cookie_duration, | |
| ), | |
| priority=cloudfront_http_rule_priority, | |
| ) | |
| print("Added targets and actions to ALB HTTP listener.") | |
| # Now the same for HTTPS if you have an ACM certificate | |
| if ACM_SSL_CERTIFICATE_ARN: | |
| listener_port_https = 443 | |
| # Check if Listener exists - from_listener_arn or lookup by port/ALB | |
| https_listener = add_alb_https_listener_with_cert( | |
| self, | |
| "MyHttpsListener", # Logical ID for the HTTPS listener | |
| alb, | |
| acm_certificate_arn=ACM_SSL_CERTIFICATE_ARN, | |
| default_target_group=target_group, | |
| enable_cognito_auth=True, | |
| cognito_user_pool=user_pool, | |
| cognito_user_pool_client=user_pool_client, | |
| cognito_user_pool_domain=user_pool_domain, | |
| listener_open_to_internet=True, | |
| stickiness_cookie_duration=cookie_duration, | |
| ) | |
| if https_listener: | |
| CfnOutput( | |
| self, | |
| "HttpsListenerArn", | |
| value=https_listener.listener_arn, | |
| ) | |
| print(f"ALB listener on port {listener_port_https} defined.") | |
| # if USE_CLOUDFRONT == 'True': | |
| # # Add default action to the listener | |
| # https_listener.add_action( | |
| # "DefaultAction", # Logical ID for the default action | |
| # action=elbv2.ListenerAction.fixed_response( | |
| # status_code=403, | |
| # content_type="text/plain", | |
| # message_body="Access denied", | |
| # ), | |
| # ) | |
| # # Add the Listener Rule for the specific CloudFront Host Header | |
| # https_listener.add_action( | |
| # "CloudFrontHostHeaderRuleHTTPS", | |
| # action=elbv2.ListenerAction.forward(target_groups=[target_group],stickiness_duration=cookie_duration), | |
| # priority=1, # Example priority. Adjust as needed. Lower is evaluated first. | |
| # conditions=[ | |
| # elbv2.ListenerCondition.host_headers([cloudfront_distribution_url]) | |
| # ] | |
| # ) | |
| # else: | |
| # https_listener.add_action( | |
| # "CloudFrontHostHeaderRuleHTTPS", | |
| # action=elbv2.ListenerAction.forward(target_groups=[target_group],stickiness_duration=cookie_duration)) | |
| print("Added targets and actions to ALB HTTPS listener.") | |
| if enable_pi_agent and pi_ecs_service and alb_security_group: | |
| pi_tg_name = PI_ALB_TARGET_GROUP_NAME | |
| if len(pi_tg_name) > 32: | |
| pi_tg_name = pi_tg_name[-32:] | |
| _pi_public_urls = format_pi_public_urls( | |
| routing_mode=PI_ALB_ROUTING, | |
| path_prefix=PI_ALB_PATH_PREFIX_NORMALIZED, | |
| host_header=PI_ALB_HOST_HEADER, | |
| cloudfront_domain=( | |
| CLOUDFRONT_DOMAIN if USE_CLOUDFRONT == "True" else "" | |
| ), | |
| use_https=bool(ACM_SSL_CERTIFICATE_ARN), | |
| ) | |
| attach_pi_agent_to_shared_alb( | |
| self, | |
| "PiAgent", | |
| vpc=vpc, | |
| alb_security_group=alb_security_group, | |
| pi_security_group=pi_ecs_security_group, | |
| pi_service=pi_ecs_service, | |
| pi_port=int(PI_GRADIO_PORT), | |
| routing_mode=PI_ALB_ROUTING, | |
| path_prefix=PI_ALB_PATH_PREFIX_NORMALIZED, | |
| pi_host_header=PI_ALB_HOST_HEADER.strip(), | |
| listener_rule_priority=PI_ALB_LISTENER_RULE_PRIORITY, | |
| target_group_name=pi_tg_name, | |
| stickiness_cookie_duration=cookie_duration, | |
| https_listener=https_listener, | |
| http_listener=http_listener, | |
| acm_certificate_arn=ACM_SSL_CERTIFICATE_ARN or "", | |
| enable_cognito_auth=bool(ACM_SSL_CERTIFICATE_ARN), | |
| cognito_user_pool=user_pool, | |
| cognito_user_pool_client=user_pool_client, | |
| cognito_user_pool_domain=user_pool_domain, | |
| ) | |
| pi_public_url = _pi_public_urls[0] if _pi_public_urls else "" | |
| CfnOutput( | |
| self, | |
| "PiPublicUrl", | |
| value=pi_public_url, | |
| description="Primary public URL for Pi agent UI (path and/or host ALB rules)", | |
| ) | |
| if len(_pi_public_urls) > 1: | |
| CfnOutput( | |
| self, | |
| "PiPublicUrls", | |
| value=", ".join(_pi_public_urls), | |
| description="All configured Pi UI entry URLs", | |
| ) | |
| CfnOutput( | |
| self, | |
| "PiAlbPathPrefix", | |
| value=PI_ALB_PATH_PREFIX_NORMALIZED, | |
| description="ALB path prefix for Pi when PI_ALB_ROUTING includes path", | |
| ) | |
| CfnOutput( | |
| self, | |
| "PiAgentServiceName", | |
| value=ECS_PI_SERVICE_NAME, | |
| ) | |
| sc_backend = ( | |
| f"http://{ECS_SERVICE_CONNECT_DISCOVERY_NAME}:" | |
| f"{GRADIO_SERVER_PORT}" | |
| ) | |
| CfnOutput( | |
| self, | |
| "PiDocSummarisationBackendUrl", | |
| value=sc_backend, | |
| description="DOC_SUMMARISATION_GRADIO_URL set on Pi tasks (Service Connect)", | |
| ) | |
| print( | |
| "Pi agent attached to shared ALB " | |
| f"(routing={PI_ALB_ROUTING}, urls={', '.join(_pi_public_urls)})." | |
| ) | |
| except Exception as e: | |
| raise Exception( | |
| "Could not handle ALB target groups and listeners due to:", e | |
| ) | |
| if not enable_headless: | |
| # Create WAF to attach to load balancer | |
| try: | |
| web_acl_name = LOAD_BALANCER_WEB_ACL_NAME | |
| if get_context_bool(f"exists:{web_acl_name}"): | |
| # Lookup WAF ACL by ARN from context | |
| web_acl_arn = get_context_str(f"arn:{web_acl_name}") | |
| if not web_acl_arn: | |
| raise ValueError( | |
| f"Context value 'arn:{web_acl_name}' is required if Web ACL exists." | |
| ) | |
| web_acl = create_web_acl_with_common_rules( | |
| self, web_acl_name, waf_scope="REGIONAL" | |
| ) # Assuming it takes scope and name | |
| print(f"Handled ALB WAF web ACL {web_acl_name}.") | |
| else: | |
| web_acl = create_web_acl_with_common_rules( | |
| self, web_acl_name, waf_scope="REGIONAL" | |
| ) # Assuming it takes scope and name | |
| print(f"Created ALB WAF web ACL {web_acl_name}.") | |
| wafv2.CfnWebACLAssociation( | |
| self, | |
| id="alb_waf_association", | |
| resource_arn=alb.load_balancer_arn, | |
| web_acl_arn=web_acl.attr_arn, | |
| ) | |
| except Exception as e: | |
| raise Exception("Could not handle create ALB WAF web ACL due to:", e) | |
| # --- Outputs for other stacks/regions --- | |
| self.params = dict() | |
| self.params["alb_arn_output"] = alb.load_balancer_arn | |
| if use_express_ingress: | |
| self.params["alb_security_group_id"] = express_alb_security_group_id | |
| else: | |
| self.params["alb_security_group_id"] = ( | |
| alb_security_group.security_group_id | |
| ) | |
| self.params["alb_dns_name"] = alb.load_balancer_dns_name | |
| CfnOutput( | |
| self, | |
| "AlbArnOutput", | |
| value=alb.load_balancer_arn, | |
| description="ARN of the Application Load Balancer", | |
| export_name=f"{self.stack_name}-AlbArn", | |
| ) # Export name must be unique within the account/region | |
| CfnOutput( | |
| self, | |
| "AlbSecurityGroupIdOutput", | |
| value=( | |
| express_alb_security_group_id | |
| if use_express_ingress | |
| else alb_security_group.security_group_id | |
| ), | |
| description="ID of the ALB's Security Group", | |
| export_name=f"{self.stack_name}-AlbSgId", | |
| ) | |
| CfnOutput(self, "ALBName", value=load_balancer_name) | |
| CfnOutput(self, "RegionalAlbDnsName", value=alb.load_balancer_dns_name) | |
| else: | |
| self.params = dict() | |
| CfnOutput( | |
| self, | |
| "HeadlessDeploymentMode", | |
| value="True", | |
| description="Stack deployed for S3-triggered direct-mode batch tasks only", | |
| ) | |
| CfnOutput( | |
| self, | |
| "ECSClusterName", | |
| value=CLUSTER_NAME, | |
| description="ECS cluster used for one-shot Fargate batch tasks", | |
| ) | |
| CfnOutput( | |
| self, | |
| "EcsBatchLogGroup", | |
| value=ECS_LOG_GROUP_NAME, | |
| description=( | |
| "CloudWatch log group for batch tasks (streams appear only after " | |
| "the container starts; init failures may have no stream)" | |
| ), | |
| ) | |
| if not enable_headless and user_pool is not None: | |
| CfnOutput(self, "CognitoPoolId", value=user_pool.user_pool_id) | |
| # Add other outputs if needed | |
| CfnOutput(self, "ECRRepoUri", value=ecr_repo.repository_uri) | |
| if enable_service_connect: | |
| sc_host = ECS_SERVICE_CONNECT_DNS_NAME or ECS_SERVICE_CONNECT_DISCOVERY_NAME | |
| sc_base = f"http://{sc_host}:{GRADIO_SERVER_PORT}" | |
| CfnOutput( | |
| self, | |
| "ServiceConnectHttpBaseUrl", | |
| value=sc_base, | |
| description="Base URL for other ECS services in this cluster (Service Connect)", | |
| ) | |
| CfnOutput( | |
| self, | |
| "ServiceConnectAgentApiUrl", | |
| value=f"{sc_base}/agent", | |
| description="FastAPI Agent API prefix (when RUN_FASTAPI=True in app_config.env)", | |
| ) | |
| CfnOutput( | |
| self, | |
| "ServiceConnectNamespace", | |
| value=ECS_SERVICE_CONNECT_NAMESPACE, | |
| ) | |
| # --- CLOUDFRONT DISTRIBUTION in separate stack (us-east-1 required) --- | |
class CdkStackCloudfront(Stack):
    """Stack that places a WAF-protected CloudFront distribution in front of an ALB.

    The load balancer itself is owned by another stack; this stack only imports
    it by ARN / security group ID / DNS name and uses it as the distribution's
    single origin. The surrounding file notes that the distribution lives in a
    separate stack because us-east-1 is required.
    """

    def __init__(
        self,
        scope: Construct,
        construct_id: str,
        alb_arn: str,
        alb_sec_group_id: str,
        alb_dns_name: str,
        **kwargs,
    ) -> None:
        """Define the web ACL, the CloudFront distribution and its URL output.

        Args:
            scope: Parent construct (normally the CDK app).
            construct_id: Logical ID of this stack.
            alb_arn: ARN of the existing Application Load Balancer to import.
            alb_sec_group_id: ID of the security group attached to that ALB.
            alb_dns_name: DNS name of that ALB; used as the origin domain.
            **kwargs: Forwarded unchanged to ``Stack.__init__``.

        Raises:
            ValueError: If ``alb_arn``, ``alb_sec_group_id`` or ``alb_dns_name``
                is empty.
            Exception: If defining the web ACL or the distribution fails; the
                original error is attached as ``__cause__``.
        """
        super().__init__(scope, construct_id, **kwargs)

        # --- Helpers to get context values ---
        def get_context_bool(key: str, default: bool = False) -> bool:
            """Return the context value for *key* as a real boolean.

            Context supplied on the command line (``-c key=false``) arrives as
            a string, and any non-empty string is truthy, so textual values are
            parsed explicitly instead of being returned as-is.
            """
            value = self.node.try_get_context(key)
            if value is None:
                return default
            if isinstance(value, str):
                return value.strip().lower() in ("true", "1", "yes")
            return bool(value)

        def get_context_str(key: str) -> str:
            """Return the context value for *key* as a string ("" when unset)."""
            value = self.node.try_get_context(key)
            return str(value) if value else ""

        resource_removal_policy = managed_resource_removal_policy()

        print(f"CloudFront Stack: Received ALB ARN: {alb_arn}")
        print(f"CloudFront Stack: Received ALB Security Group ID: {alb_sec_group_id}")

        # Fail fast with a clear message: all three values are needed to import
        # the ALB and to use it as a CloudFront origin.
        if not alb_arn:
            raise ValueError("ALB ARN must be provided to CloudFront stack")
        if not alb_sec_group_id:
            raise ValueError(
                "ALB Security Group ID must be provided to CloudFront stack"
            )
        if not alb_dns_name:
            raise ValueError("ALB DNS name must be provided to CloudFront stack")

        # Import the ALB using its ARN.
        # This imports an existing ALB as a construct in the CloudFront stack's
        # context. CloudFormation will understand this reference at deploy time.
        alb = elbv2.ApplicationLoadBalancer.from_application_load_balancer_attributes(
            self,
            "ImportedAlb",
            load_balancer_arn=alb_arn,
            security_group_id=alb_sec_group_id,
            load_balancer_dns_name=alb_dns_name,
        )

        try:
            web_acl_name = WEB_ACL_NAME
            web_acl_exists = get_context_bool(f"exists:{web_acl_name}")

            # When the context says the web ACL already exists, its ARN must
            # also be present in context. The ARN is only validated here; the
            # ACL construct is defined in this stack in both cases.
            if web_acl_exists and not get_context_str(f"arn:{web_acl_name}"):
                raise ValueError(
                    f"Context value 'arn:{web_acl_name}' is required if Web ACL exists."
                )

            # NOTE(review): no waf_scope is passed here, so the helper's default
            # scope applies - presumably CLOUDFRONT; confirm against the helper.
            web_acl = create_web_acl_with_common_rules(self, web_acl_name)
            if web_acl_exists:
                print(f"Handled Cloudfront WAF web ACL {web_acl_name}.")
            else:
                print(f"Created Cloudfront WAF web ACL {web_acl_name}.")

            # Add ALB as CloudFront Origin
            origin = origins.LoadBalancerV2Origin(
                alb,  # The imported ALB object
                custom_headers={CUSTOM_HEADER: CUSTOM_HEADER_VALUE},
                origin_shield_enabled=False,
                protocol_policy=cloudfront.OriginProtocolPolicy.HTTP_ONLY,
            )

            # Optional geo allowlist; None leaves the distribution unrestricted.
            geo_restrict = (
                cloudfront.GeoRestriction.allowlist(CLOUDFRONT_GEO_RESTRICTION)
                if CLOUDFRONT_GEO_RESTRICTION
                else None
            )

            # Optional response headers policy, only built when enabled in config.
            response_headers_policy = None
            if CLOUDFRONT_ENABLE_SECURE_RESPONSE_HEADERS == "True":
                app_origin, cognito_login_url = resolve_cloudfront_csp_urls(
                    cognito_redirection_url=COGNITO_REDIRECTION_URL,
                    cloudfront_domain=CLOUDFRONT_DOMAIN,
                    cognito_user_pool_domain_prefix=COGNITO_USER_POOL_DOMAIN_PREFIX,
                    aws_region=AWS_REGION,
                    cognito_user_pool_login_url=COGNITO_USER_POOL_LOGIN_URL,
                    ssl_certificate_domain=SSL_CERTIFICATE_DOMAIN,
                )
                # Policy name is truncated to at most 128 characters.
                policy_name = f"{CDK_PREFIX}SecureResponseHeaders"[:128]
                response_headers_policy = (
                    create_secure_cloudfront_response_headers_policy(
                        self,
                        "SecureResponseHeadersPolicy",
                        policy_name=policy_name,
                        app_origin=app_origin,
                        cognito_login_url=cognito_login_url,
                    )
                )
                print(
                    "CloudFront secure response headers: "
                    f"app_origin={app_origin}, cognito_login_url={cognito_login_url}"
                )

            # Everything is forwarded to the ALB uncached, over HTTPS for viewers.
            default_behavior = cloudfront.BehaviorOptions(
                origin=origin,
                viewer_protocol_policy=cloudfront.ViewerProtocolPolicy.REDIRECT_TO_HTTPS,
                allowed_methods=cloudfront.AllowedMethods.ALLOW_ALL,
                cache_policy=cloudfront.CachePolicy.CACHING_DISABLED,
                origin_request_policy=cloudfront.OriginRequestPolicy.ALL_VIEWER,
                response_headers_policy=response_headers_policy,
            )

            cloudfront_distribution = cloudfront.Distribution(
                self,
                "CloudFrontDistribution",  # Logical ID
                comment=CLOUDFRONT_DISTRIBUTION_NAME,  # Use name as comment for easier identification
                geo_restriction=geo_restrict,
                default_behavior=default_behavior,
                web_acl_id=web_acl.attr_arn,
            )
            cloudfront_distribution.apply_removal_policy(resource_removal_policy)
            print(f"Cloudfront distribution {CLOUDFRONT_DISTRIBUTION_NAME} defined.")

        except Exception as e:
            # Single formatted message plus explicit chaining keeps the original
            # traceback and avoids the tuple-style "('...', error)" rendering.
            raise Exception(
                f"Could not handle Cloudfront distribution due to: {e}"
            ) from e

        # --- Outputs ---
        CfnOutput(
            self, "CloudFrontDistributionURL", value=cloudfront_distribution.domain_name
        )