llm_topic_modelling / cdk /cdk_stack.py
seanpedrickcase's picture
Sync: Updated CDK code to include possibility of exporting DynamoDB logs to S3 csv output on automated basis
373207f
Raw
History Blame Contribute Delete
151 kB
import json # You might still need json if loading task_definition.json
import os
from typing import Any, Dict, List
from aws_cdk import (
CfnOutput, # <-- Import CfnOutput directly
Duration,
SecretValue,
Stack,
)
from aws_cdk import aws_cloudfront as cloudfront
from aws_cdk import aws_cloudfront_origins as origins
from aws_cdk import aws_codebuild as codebuild
from aws_cdk import aws_cognito as cognito
from aws_cdk import aws_dynamodb as dynamodb # Import the DynamoDB module
from aws_cdk import aws_ec2 as ec2
from aws_cdk import aws_ecr as ecr
from aws_cdk import aws_ecs as ecs
from aws_cdk import aws_elasticloadbalancingv2 as elbv2
from aws_cdk import aws_iam as iam
from aws_cdk import aws_kms as kms
from aws_cdk import aws_logs as logs
from aws_cdk import aws_s3 as s3
from aws_cdk import aws_secretsmanager as secretsmanager
from aws_cdk import aws_wafv2 as wafv2
from cdk_cloudfront_headers import (
create_secure_cloudfront_response_headers_policy,
resolve_cloudfront_csp_urls,
)
from cdk_config import (
ACCESS_LOG_DYNAMODB_TABLE_NAME,
ACM_SSL_CERTIFICATE_ARN,
ALB_NAME,
ALB_NAME_SECURITY_GROUP_NAME,
ALB_TARGET_GROUP_NAME,
APP_CONFIG_ENV_BASENAME,
APP_CONFIG_ENV_FILE,
AWS_ACCOUNT_ID,
AWS_MANAGED_TASK_ROLES_LIST,
AWS_REGION,
CDK_FOLDER,
CDK_PREFIX,
CLOUDFRONT_DISTRIBUTION_NAME,
CLOUDFRONT_DOMAIN,
CLOUDFRONT_ENABLE_SECURE_RESPONSE_HEADERS,
CLOUDFRONT_GEO_RESTRICTION,
CLOUDFRONT_PREFIX_LIST_ID,
CLUSTER_NAME,
CODEBUILD_PI_PROJECT_NAME,
CODEBUILD_PROJECT_NAME,
CODEBUILD_ROLE_NAME,
COGNITO_ACCESS_TOKEN_VALIDITY,
COGNITO_ID_TOKEN_VALIDITY,
COGNITO_REDIRECTION_URL,
COGNITO_REFRESH_TOKEN_VALIDITY,
COGNITO_USER_POOL_CLIENT_NAME,
COGNITO_USER_POOL_CLIENT_SECRET_NAME,
COGNITO_USER_POOL_DOMAIN_PREFIX,
COGNITO_USER_POOL_LOGIN_URL,
COGNITO_USER_POOL_NAME,
CUSTOM_HEADER,
CUSTOM_HEADER_VALUE,
CUSTOM_KMS_KEY_NAME,
DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS,
DYNAMODB_USAGE_LOG_EXPORT_DATE_ATTRIBUTE,
DYNAMODB_USAGE_LOG_EXPORT_LAMBDA_NAME,
DYNAMODB_USAGE_LOG_EXPORT_OUTPUT_FILENAME,
DYNAMODB_USAGE_LOG_EXPORT_S3_KEY,
DYNAMODB_USAGE_LOG_EXPORT_SCHEDULE,
ECR_CDK_REPO_NAME,
ECR_PI_REPO_NAME,
ECS_AVAILABILITY_ZONE_REBALANCING,
ECS_EXECUTION_ROLE_MANAGED_POLICIES,
ECS_EXECUTION_ROLE_POLICY_ARNS,
ECS_EXECUTION_ROLE_POLICY_FILES,
ECS_EXPRESS_HEALTH_CHECK_PATH,
ECS_EXPRESS_INFRASTRUCTURE_ROLE_NAME,
ECS_EXPRESS_SERVICE_NAME,
ECS_EXPRESS_USE_PUBLIC_SUBNETS,
ECS_LOG_GROUP_NAME,
ECS_PI_EXPRESS_HEALTH_CHECK_PATH,
ECS_PI_EXPRESS_SECURITY_GROUP_NAME,
ECS_PI_EXPRESS_SERVICE_NAME,
ECS_PI_LOG_GROUP_NAME,
ECS_PI_SECURITY_GROUP_NAME,
ECS_PI_SERVICE_NAME,
ECS_PI_TASK_CPU_SIZE,
ECS_PI_TASK_DEFINITION_NAME,
ECS_PI_TASK_MEMORY_SIZE,
ECS_READ_ONLY_FILE_SYSTEM,
ECS_SECURITY_GROUP_NAME,
ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_IDS_LIST,
ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_NAMES_TO_LOOKUP,
ECS_SERVICE_CONNECT_CLIENT_SG_NAME_SUFFIX,
ECS_SERVICE_CONNECT_DISCOVERY_NAME,
ECS_SERVICE_CONNECT_DNS_NAME,
ECS_SERVICE_CONNECT_NAMESPACE,
ECS_SERVICE_CONNECT_PORT_MAPPING_NAME,
ECS_SERVICE_NAME,
ECS_TASK_CPU_SIZE,
ECS_TASK_EXECUTION_ROLE_NAME,
ECS_TASK_MEMORY_SIZE,
ECS_TASK_ROLE_NAME,
ECS_USE_FARGATE_SPOT,
ENABLE_DYNAMODB_USAGE_LOG_EXPORT,
ENABLE_ECS_SERVICE_CONNECT,
ENABLE_ECS_VPC_INTERFACE_ENDPOINTS,
ENABLE_HEADLESS_DEPLOYMENT,
ENABLE_HEADLESS_OUTPUT_NOTIFICATIONS,
ENABLE_PI_AGENT_ECS_SERVICE,
ENABLE_PI_AGENT_EXPRESS_SERVICE,
ENABLE_S3_BATCH_ECS_TRIGGER,
EXISTING_IGW_ID,
EXISTING_LOAD_BALANCER_ARN,
EXISTING_LOAD_BALANCER_DNS,
FARGATE_TASK_DEFINITION_NAME,
FEEDBACK_LOG_DYNAMODB_TABLE_NAME,
GITHUB_REPO_BRANCH,
GITHUB_REPO_NAME,
GITHUB_REPO_USERNAME,
GRADIO_SERVER_PORT,
HEADLESS_OUTPUT_ALARM_NAME,
HEADLESS_OUTPUT_IAM_USER_NAME,
HEADLESS_OUTPUT_NOTIFY_EMAIL,
HEADLESS_OUTPUT_S3_METRIC_FILTER_ID,
HEADLESS_OUTPUT_S3_PREFIX,
HEADLESS_OUTPUT_SNS_TOPIC_NAME,
LOAD_BALANCER_WEB_ACL_NAME,
NAT_GATEWAY_NAME,
NEW_VPC_CIDR,
NEW_VPC_DEFAULT_NAME,
PI_AGENT_ENV_S3_KEY,
PI_ALB_HOST_HEADER,
PI_ALB_LISTENER_RULE_PRIORITY,
PI_ALB_PATH_PREFIX_NORMALIZED,
PI_ALB_ROUTING,
PI_ALB_TARGET_GROUP_NAME,
PI_GRADIO_PORT,
POLICY_FILE_ARNS,
POLICY_FILE_LOCATIONS,
PRIVATE_SUBNET_AVAILABILITY_ZONES,
PRIVATE_SUBNET_CIDR_BLOCKS,
PRIVATE_SUBNETS_TO_USE,
PUBLIC_SUBNET_AVAILABILITY_ZONES,
PUBLIC_SUBNET_CIDR_BLOCKS,
PUBLIC_SUBNETS_TO_USE,
S3_BATCH_CONFIG_PREFIX,
S3_BATCH_DEFAULT_PARAMS_KEY,
S3_BATCH_ENV_PREFIX,
S3_BATCH_ENV_SUFFIX,
S3_BATCH_GENERAL_ENV_PREFIX,
S3_BATCH_INPUT_PREFIX,
S3_BATCH_LAMBDA_FUNCTION_NAME,
S3_LOG_CONFIG_BUCKET_NAME,
S3_OUTPUT_BUCKET_NAME,
SAVE_LOGS_TO_DYNAMODB,
SINGLE_NAT_GATEWAY_ID,
SSL_CERTIFICATE_DOMAIN,
TASK_DEFINITION_FILE_LOCATION,
USAGE_LOG_DYNAMODB_TABLE_NAME,
USE_CLOUDFRONT,
USE_CUSTOM_KMS_KEY,
USE_ECS_EXPRESS_MODE,
VPC_NAME,
WEB_ACL_NAME,
)
from cdk_functions import ( # Only keep CDK-native functions
add_alb_https_listener_with_cert,
add_custom_policies,
add_s3_enforce_ssl_policy,
allow_express_load_balancer_to_ecs_security_group,
attach_managed_policy_arns,
attach_pi_agent_to_shared_alb,
build_ecs_execution_role_kms_policy,
build_ecs_task_role_inline_policy,
build_express_gateway_primary_container,
build_express_pi_primary_container,
build_pi_express_container_environment,
configure_public_github_codebuild_source,
create_dynamo_usage_log_export_lambda,
create_ecs_express_infrastructure_role,
create_ecs_vpc_endpoints_for_private_subnets,
create_express_gateway_service,
create_headless_output_notifications,
create_headless_s3_batch_seed,
create_nat_gateway,
create_pi_agent_ecs_resources,
create_s3_batch_ecs_trigger_lambda,
create_subnets,
create_web_acl_with_common_rules,
default_secrets_manager_kms_key_arn,
ecr_empty_on_delete,
ecs_availability_zone_rebalancing,
express_ingress_first_load_balancer_security_group,
express_ingress_load_balancer_arn,
format_express_pi_public_url,
format_pi_public_urls,
load_app_config_env_for_express,
managed_resource_removal_policy,
pi_alb_root_path_for_container,
pi_listener_rule_count,
public_github_codebuild_source,
resolve_ecs_s3_gateway_subnet_selection,
resolve_ecs_vpc_endpoint_subnet_selection,
resolve_policy_file_paths,
resolve_service_connect_client_security_group_ids,
resource_deletion_protection_flag,
s3_auto_delete_objects_on_stack_destroy,
wire_public_subnet_internet_access,
)
from constructs import Construct
def _get_env_list(env_var_name: str) -> List[str]:
"""Parses a comma-separated environment variable into a list of strings."""
value = env_var_name[1:-1].strip().replace('"', "").replace("'", "")
if not value:
return []
# Split by comma and filter out any empty strings that might result from extra commas
return [s.strip() for s in value.split(",") if s.strip()]
# 1. Try to load CIDR/AZs from environment variables
if PUBLIC_SUBNETS_TO_USE:
PUBLIC_SUBNETS_TO_USE = _get_env_list(PUBLIC_SUBNETS_TO_USE)
if PRIVATE_SUBNETS_TO_USE:
PRIVATE_SUBNETS_TO_USE = _get_env_list(PRIVATE_SUBNETS_TO_USE)
if PUBLIC_SUBNET_CIDR_BLOCKS:
PUBLIC_SUBNET_CIDR_BLOCKS = _get_env_list("PUBLIC_SUBNET_CIDR_BLOCKS")
if PUBLIC_SUBNET_AVAILABILITY_ZONES:
PUBLIC_SUBNET_AVAILABILITY_ZONES = _get_env_list("PUBLIC_SUBNET_AVAILABILITY_ZONES")
if PRIVATE_SUBNET_CIDR_BLOCKS:
PRIVATE_SUBNET_CIDR_BLOCKS = _get_env_list("PRIVATE_SUBNET_CIDR_BLOCKS")
if PRIVATE_SUBNET_AVAILABILITY_ZONES:
PRIVATE_SUBNET_AVAILABILITY_ZONES = _get_env_list(
"PRIVATE_SUBNET_AVAILABILITY_ZONES"
)
# AWS_MANAGED_TASK_ROLES_LIST and POLICY_* lists are parsed in cdk_config.py.
class CdkStack(Stack):
def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
super().__init__(scope, construct_id, **kwargs)
# --- Helper to get context values ---
def get_context_bool(key: str, default: bool = False) -> bool:
value = self.node.try_get_context(key)
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in ("true", "1", "yes")
return bool(value)
def get_context_str(key: str, default: str = None) -> str:
return self.node.try_get_context(key) or default
def get_context_dict(key: str, default: dict = None) -> dict:
return self.node.try_get_context(key) or default
def get_context_list_of_dicts(key: str) -> List[Dict[str, Any]]:
ctx_value = self.node.try_get_context(key)
if not isinstance(ctx_value, list):
print(
f"Warning: Context key '{key}' not found or not a list. Returning empty list."
)
return []
# Optional: Add validation that all items in the list are dicts
return ctx_value
resource_removal_policy = managed_resource_removal_policy()
resource_delete_protection = resource_deletion_protection_flag()
s3_auto_delete_objects = s3_auto_delete_objects_on_stack_destroy()
self.template_options.description = "Deployment of the llm_topic_modeller Gradio app for LLM-based topic modelling. Git repo: https://github.com/seanpedrick-case/llm_topic_modeller."
use_express_ingress = (
not ACM_SSL_CERTIFICATE_ARN and USE_ECS_EXPRESS_MODE == "True"
)
enable_headless = ENABLE_HEADLESS_DEPLOYMENT == "True"
express_public_subnets_only = (
use_express_ingress and ECS_EXPRESS_USE_PUBLIC_SUBNETS == "True"
) or (
enable_headless
and not use_express_ingress
and ECS_EXPRESS_USE_PUBLIC_SUBNETS == "True"
)
deploy_web_ingress = not use_express_ingress and not enable_headless
enable_service_connect = (
ENABLE_ECS_SERVICE_CONNECT == "True" and not use_express_ingress
)
enable_pi_agent = (
ENABLE_PI_AGENT_ECS_SERVICE == "True" and not use_express_ingress
)
enable_pi_express = (
ENABLE_PI_AGENT_EXPRESS_SERVICE == "True" and use_express_ingress
)
enable_pi_build = enable_pi_agent or enable_pi_express
if enable_headless:
print(
"ENABLE_HEADLESS_DEPLOYMENT=True: S3 batch trigger + one-shot Fargate "
"tasks only (no ALB, CloudFront, or always-on ECS service)."
)
elif use_express_ingress:
print(
"USE_ECS_EXPRESS_MODE=True: using ECS Express Mode for HTTPS ingress "
"(no manual ALB/Fargate service)."
)
if express_public_subnets_only:
print(
"ECS_EXPRESS_USE_PUBLIC_SUBNETS=True: Express tasks and VPC "
"endpoints use public subnets only (no private subnet install)."
)
elif enable_headless:
print(
"ENABLE_HEADLESS_DEPLOYMENT=True: batch Fargate tasks use "
"legacy private subnets (or public if configured)."
)
service_connect_client_sg_ids: List[str] = []
if enable_service_connect:
if (
not ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_IDS_LIST
and not ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_NAMES_TO_LOOKUP
and not enable_pi_agent
):
raise ValueError(
"ENABLE_ECS_SERVICE_CONNECT=True requires at least one of "
"ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_IDS, "
"ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_NAMES, or "
"ECS_SERVICE_CONNECT_CLIENT_CDK_PREFIXES (other apps' CDK_PREFIX "
f"values, resolved to {{prefix}}{ECS_SERVICE_CONNECT_CLIENT_SG_NAME_SUFFIX} "
"in this VPC), unless ENABLE_PI_AGENT_ECS_SERVICE=True (Pi SG is wired in-stack)."
)
service_connect_client_sg_ids = (
resolve_service_connect_client_security_group_ids(
ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_IDS_LIST,
ECS_SERVICE_CONNECT_CLIENT_SECURITY_GROUP_NAMES_TO_LOOKUP,
get_context_str,
)
)
print(
"ENABLE_ECS_SERVICE_CONNECT=True: advertising Fargate service on "
f"Service Connect as {ECS_SERVICE_CONNECT_DISCOVERY_NAME}; "
f"client SGs: {', '.join(service_connect_client_sg_ids)}"
)
# --- VPC and Subnets (Assuming VPC is always lookup, Subnets are created/returned by create_subnets) ---
new_vpc_created = False
imported_vpc_cidr_block = None
imported_vpc_cidr_blocks: List[str] = []
if VPC_NAME:
vpc_id = get_context_str("vpc_id")
if not vpc_id:
raise ValueError(
f"VPC '{VPC_NAME}' was not resolved during pre-check (missing "
"'vpc_id' in context). Re-run from the cdk/ directory so "
"precheck.context.json is generated."
)
availability_zones = list(
dict.fromkeys(
(PUBLIC_SUBNET_AVAILABILITY_ZONES or [])
+ (PRIVATE_SUBNET_AVAILABILITY_ZONES or [])
)
)
if not availability_zones:
raise ValueError(
"vpc_id is in context but no subnet availability zones are "
"configured. Set PUBLIC_SUBNET_AVAILABILITY_ZONES and/or "
"PRIVATE_SUBNET_AVAILABILITY_ZONES in cdk_config.env."
)
vpc_cidr_block = get_context_str("vpc_cidr_block")
imported_vpc_cidr_block = vpc_cidr_block
imported_vpc_cidr_blocks = list(
self.node.try_get_context("vpc_cidr_blocks") or []
)
if (
imported_vpc_cidr_block
and imported_vpc_cidr_block not in imported_vpc_cidr_blocks
):
imported_vpc_cidr_blocks.insert(0, imported_vpc_cidr_block)
vpc_attrs = {
"vpc_id": vpc_id,
"availability_zones": availability_zones,
}
if vpc_cidr_block:
vpc_attrs["vpc_cidr_block"] = vpc_cidr_block
vpc = ec2.Vpc.from_vpc_attributes(self, "VPC", **vpc_attrs)
cidr_log = (
", ".join(imported_vpc_cidr_blocks)
if imported_vpc_cidr_blocks
else vpc_cidr_block
)
print(
f"Using VPC from pre-check context: {vpc_id}"
+ (f" (CIDR(s) {cidr_log})" if cidr_log else "")
)
elif NEW_VPC_DEFAULT_NAME and not VPC_NAME:
new_vpc_created = True
print(
f"NEW_VPC_DEFAULT_NAME ('{NEW_VPC_DEFAULT_NAME}') is set. Creating a new VPC."
)
# Configuration for the new VPC
# You can make these configurable via context as well, e.g.,
# new_vpc_cidr = self.node.try_get_context("new_vpc_cidr") or "10.0.0.0/24"
# new_vpc_max_azs = self.node.try_get_context("new_vpc_max_azs") or 2 # Use 2 AZs by default for HA
# new_vpc_nat_gateways = self.node.try_get_context("new_vpc_nat_gateways") or new_vpc_max_azs # One NAT GW per AZ for HA
# or 1 for cost savings if acceptable
if not NEW_VPC_CIDR:
raise Exception(
"App has been instructed to create a new VPC but not VPC CDR range provided to variable NEW_VPC_CIDR"
)
print("Provided NEW_VPC_CIDR range:", NEW_VPC_CIDR)
new_vpc_cidr = NEW_VPC_CIDR
new_vpc_max_azs = 2 # Creates resources in 2 AZs. Adjust as needed.
# For "a NAT gateway", you can set nat_gateways=1.
# For resilience (NAT GW per AZ), set nat_gateways=new_vpc_max_azs.
# The Vpc construct will create NAT Gateway(s) if subnet_type PRIVATE_WITH_EGRESS is used
# and nat_gateways > 0.
if express_public_subnets_only:
new_vpc_nat_gateways = 0
new_vpc_subnet_configuration = [
ec2.SubnetConfiguration(
name="Public",
subnet_type=ec2.SubnetType.PUBLIC,
# /27 (~27 usable IPs): Express managed ALB needs 8+ free IPs per
# subnet alongside VPC interface endpoints and task ENIs.
cidr_mask=27,
),
]
else:
new_vpc_nat_gateways = (
1 # Creates a single NAT Gateway for cost-effectiveness.
)
new_vpc_subnet_configuration = [
ec2.SubnetConfiguration(
name="Public", # Name prefix for public subnets
subnet_type=ec2.SubnetType.PUBLIC,
cidr_mask=26,
),
ec2.SubnetConfiguration(
name="Private", # Name prefix for private subnets
subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS, # Ensures these subnets have NAT Gateway access
cidr_mask=28,
),
# You could also add ec2.SubnetType.PRIVATE_ISOLATED if needed
]
# If you need one NAT GW per AZ for higher availability, set nat_gateways to new_vpc_max_azs.
vpc = ec2.Vpc(
self,
"MyNewLogicalVpc", # This is the CDK construct ID
vpc_name=NEW_VPC_DEFAULT_NAME,
ip_addresses=ec2.IpAddresses.cidr(new_vpc_cidr),
max_azs=new_vpc_max_azs,
nat_gateways=new_vpc_nat_gateways, # Number of NAT gateways to create
subnet_configuration=new_vpc_subnet_configuration,
# Internet Gateway is created and configured automatically for PUBLIC subnets.
# Route tables for public subnets will point to the IGW.
# Route tables for PRIVATE_WITH_EGRESS subnets will point to the NAT Gateway(s).
)
print(
f"Successfully created new VPC: {vpc.vpc_id} with name '{NEW_VPC_DEFAULT_NAME}'"
)
# If nat_gateways > 0, vpc.nat_gateway_ips will contain EIPs if Vpc created them.
# vpc.public_subnets, vpc.private_subnets, vpc.isolated_subnets are populated.
else:
raise Exception(
"VPC_NAME for current VPC not found, and NEW_VPC_DEFAULT_NAME not found to create a new VPC"
)
# --- Subnet Handling (Check Context and Create/Import) ---
# Initialize lists to hold ISubnet objects (L2) and CfnSubnet/CfnRouteTable (L1)
# We will store ISubnet for consistency, as CfnSubnet has a .subnet_id property
self.public_subnets: List[ec2.ISubnet] = []
self.private_subnets: List[ec2.ISubnet] = []
# Store L1 CfnRouteTables explicitly if you need to reference them later
self.private_route_tables_cfn: List[ec2.CfnRouteTable] = []
self.public_route_tables_cfn: List[ec2.CfnRouteTable] = (
[]
) # New: to store public RTs
names_to_create_private = []
names_to_create_public = []
if not PUBLIC_SUBNETS_TO_USE and not PRIVATE_SUBNETS_TO_USE:
if express_public_subnets_only:
print(
"Express public-subnet mode: auto-selecting public subnets only "
"(private subnets are not installed)."
)
selected_public_subnets = vpc.select_subnets(
subnet_type=ec2.SubnetType.PUBLIC, one_per_az=True
)
if len(selected_public_subnets.subnet_ids) < 2:
raise Exception(
"Express mode needs at least two public subnets in different "
"availability zones."
)
self.public_subnets = selected_public_subnets.subnets
self.private_subnets = []
print(
f"Selected {len(self.public_subnets)} public subnets for Express."
)
else:
print(
"Warning: No public or private subnets specified in *_SUBNETS_TO_USE. Attempting to select from existing VPC subnets."
)
print("vpc.public_subnets:", vpc.public_subnets)
print("vpc.private_subnets:", vpc.private_subnets)
if (
vpc.public_subnets
): # These are already one_per_az if max_azs was used and Vpc created them
self.public_subnets.extend(vpc.public_subnets)
else:
self.node.add_warning("No public subnets found in the VPC.")
# Get private subnets with egress specifically
# selected_private_subnets_with_egress = vpc.select_subnets(subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS)
print(
f"Selected from VPC: {len(self.public_subnets)} public, {len(self.private_subnets)} private_with_egress subnets."
)
if (
len(self.public_subnets) < 1 or len(self.private_subnets) < 1
): # Simplified check for new VPC
# If new_vpc_max_azs was 1, you'd have 1 of each. If 2, then 2 of each.
# The original check ' < 2' might be too strict if new_vpc_max_azs=1
pass # For new VPC, allow single AZ setups if configured that way. The VPC construct ensures one per AZ up to max_azs.
if not self.public_subnets and not self.private_subnets:
print(
"Error: No public or private subnets could be found in the VPC for automatic selection. "
"You must either specify subnets in *_SUBNETS_TO_USE or ensure the VPC has discoverable subnets."
)
raise RuntimeError(
"No suitable subnets found for automatic selection."
)
else:
print(
f"Automatically selected {len(self.public_subnets)} public and {len(self.private_subnets)} private subnets based on VPC properties."
)
selected_public_subnets = vpc.select_subnets(
subnet_type=ec2.SubnetType.PUBLIC, one_per_az=True
)
private_subnets_egress = vpc.select_subnets(
subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS, one_per_az=True
)
if private_subnets_egress.subnets:
self.private_subnets.extend(private_subnets_egress.subnets)
else:
self.node.add_warning(
"No PRIVATE_WITH_EGRESS subnets found in the VPC."
)
try:
private_subnets_isolated = vpc.select_subnets(
subnet_type=ec2.SubnetType.PRIVATE_ISOLATED, one_per_az=True
)
except Exception as e:
private_subnets_isolated = []
print("Could not find any isolated subnets due to:", e)
###
combined_subnet_objects = []
if private_subnets_isolated:
if private_subnets_egress.subnets:
# Add the first PRIVATE_WITH_EGRESS subnet
combined_subnet_objects.append(
private_subnets_egress.subnets[0]
)
elif not private_subnets_isolated:
if private_subnets_egress.subnets:
# Add the first PRIVATE_WITH_EGRESS subnet
combined_subnet_objects.extend(private_subnets_egress.subnets)
else:
self.node.add_warning(
"No PRIVATE_WITH_EGRESS subnets found to select the first one."
)
# Add all PRIVATE_ISOLATED subnets *except* the first one (if they exist)
try:
if len(private_subnets_isolated.subnets) > 1:
combined_subnet_objects.extend(
private_subnets_isolated.subnets[1:]
)
elif (
private_subnets_isolated.subnets
): # Only 1 isolated subnet, add a warning if [1:] was desired
self.node.add_warning(
"Only one PRIVATE_ISOLATED subnet found, private_subnets_isolated.subnets[1:] will be empty."
)
else:
self.node.add_warning("No PRIVATE_ISOLATED subnets found.")
except Exception as e:
print("Could not identify private isolated subnets due to:", e)
# Create an ec2.SelectedSubnets object from the combined private subnet list.
selected_private_subnets = vpc.select_subnets(
subnets=combined_subnet_objects
)
print("selected_public_subnets:", selected_public_subnets)
print("selected_private_subnets:", selected_private_subnets)
if (
len(selected_public_subnets.subnet_ids) < 2
or len(selected_private_subnets.subnet_ids) < 2
):
raise Exception(
"Need at least two public or private subnets in different availability zones"
)
if not selected_public_subnets and not selected_private_subnets:
# If no subnets could be found even with automatic selection, raise an error.
# This ensures the stack doesn't proceed if it absolutely needs subnets.
print(
"Error: No existing public or private subnets could be found in the VPC for automatic selection. "
"You must either specify subnets in *_SUBNETS_TO_USE or ensure the VPC has discoverable subnets."
)
raise RuntimeError(
"No suitable subnets found for automatic selection."
)
else:
self.public_subnets = selected_public_subnets.subnets
self.private_subnets = selected_private_subnets.subnets
print(
f"Automatically selected {len(self.public_subnets)} public and {len(self.private_subnets)} private subnets based on VPC discovery."
)
print("self.public_subnets:", self.public_subnets)
print("self.private_subnets:", self.private_subnets)
# Since subnets are now assigned, we can exit this processing block.
# The rest of the original code (which iterates *_SUBNETS_TO_USE) will be skipped.
checked_public_subnets_ctx = get_context_dict("checked_public_subnets")
checked_private_subnets_ctx = get_context_dict("checked_private_subnets")
public_subnets_data_for_creation_ctx = get_context_list_of_dicts(
"public_subnets_to_create"
)
private_subnets_data_for_creation_ctx = get_context_list_of_dicts(
"private_subnets_to_create"
)
# --- 3. Process Public Subnets ---
print("\n--- Processing Public Subnets ---")
public_internet_gateway_attachment = None
if not new_vpc_created:
resolved_igw_id = (
get_context_str("internet_gateway_id") or EXISTING_IGW_ID or ""
).strip()
if resolved_igw_id and (
PUBLIC_SUBNETS_TO_USE
or public_subnets_data_for_creation_ctx
or get_context_list_of_dicts("public_subnets_needing_igw_route")
):
public_internet_gateway_attachment = wire_public_subnet_internet_access(
self,
"PublicSubnetInternet",
vpc_id=vpc.vpc_id,
internet_gateway_id=resolved_igw_id,
needs_igw_vpc_attachment=get_context_bool(
"internet_gateway_needs_vpc_attachment", False
),
subnets_needing_route=get_context_list_of_dicts(
"public_subnets_needing_igw_route"
),
)
# Import existing public subnets
if checked_public_subnets_ctx:
for i, subnet_name in enumerate(PUBLIC_SUBNETS_TO_USE):
subnet_info = checked_public_subnets_ctx.get(subnet_name)
if subnet_info and subnet_info.get("exists"):
subnet_id = subnet_info.get("id")
if not subnet_id:
raise RuntimeError(
f"Context for existing public subnet '{subnet_name}' is missing 'id'."
)
subnet_az = subnet_info.get("az")
if (
not subnet_az
and PUBLIC_SUBNET_AVAILABILITY_ZONES
and i < len(PUBLIC_SUBNET_AVAILABILITY_ZONES)
):
subnet_az = PUBLIC_SUBNET_AVAILABILITY_ZONES[i]
if not subnet_az:
raise RuntimeError(
f"Context for existing public subnet '{subnet_name}' is missing 'az'."
)
subnet_attrs = {
"subnet_id": subnet_id,
"availability_zone": subnet_az,
}
route_table_id = subnet_info.get("route_table_id")
if route_table_id:
subnet_attrs["route_table_id"] = route_table_id
try:
imported_subnet = ec2.Subnet.from_subnet_attributes(
self,
f"ImportedPublicSubnet{subnet_name.replace('-', '')}{i}",
**subnet_attrs,
)
self.public_subnets.append(imported_subnet)
print(
f"Imported existing public subnet: {subnet_name} (ID: {subnet_id})"
)
except Exception as e:
raise RuntimeError(
f"Failed to import public subnet '{subnet_name}' with ID '{subnet_id}'. Error: {e}"
)
# Create new public subnets based on public_subnets_data_for_creation_ctx
if public_subnets_data_for_creation_ctx:
names_to_create_public = [
s["name"] for s in public_subnets_data_for_creation_ctx
]
cidrs_to_create_public = [
s["cidr"] for s in public_subnets_data_for_creation_ctx
]
azs_to_create_public = [
s["az"] for s in public_subnets_data_for_creation_ctx
]
if names_to_create_public:
print(
f"Attempting to create {len(names_to_create_public)} new public subnets: {names_to_create_public}"
)
igw_for_new_subnets = (
get_context_str("internet_gateway_id") or EXISTING_IGW_ID
)
newly_created_public_subnets, newly_created_public_rts_cfn = (
create_subnets(
self,
vpc,
CDK_PREFIX,
names_to_create_public,
cidrs_to_create_public,
azs_to_create_public,
is_public=True,
internet_gateway_id=igw_for_new_subnets,
internet_gateway_attachment=public_internet_gateway_attachment,
)
)
self.public_subnets.extend(newly_created_public_subnets)
self.public_route_tables_cfn.extend(newly_created_public_rts_cfn)
if (
not self.public_subnets
and not names_to_create_public
and not PUBLIC_SUBNETS_TO_USE
):
raise Exception("No public subnets found or created, exiting.")
# --- NAT Gateway Creation/Lookup ---
self.single_nat_gateway_id = None
if express_public_subnets_only:
print(
"Express public-subnet mode: skipping NAT Gateway install "
"(not required for public Express tasks)."
)
else:
print("Creating NAT gateway/located existing")
nat_gw_id_from_context = SINGLE_NAT_GATEWAY_ID or get_context_str(
"id:NatGateway"
)
if nat_gw_id_from_context:
print(
f"Using existing NAT Gateway ID from context: {nat_gw_id_from_context}"
)
self.single_nat_gateway_id = nat_gw_id_from_context
elif (
new_vpc_created
and new_vpc_nat_gateways > 0
and hasattr(vpc, "nat_gateways")
and vpc.nat_gateways
):
self.single_nat_gateway_id = vpc.nat_gateways[0].gateway_id
print(
f"Using NAT Gateway {self.single_nat_gateway_id} created by the new VPC construct."
)
if not self.single_nat_gateway_id:
print("Creating a new NAT gateway")
if hasattr(vpc, "nat_gateways") and vpc.nat_gateways:
print("Existing NAT gateway found in vpc")
pass
# If not in context, create a new one, but only if we have a public subnet.
elif self.public_subnets:
print("NAT Gateway ID not found in context. Creating a new one.")
# Place the NAT GW in the first available public subnet
first_public_subnet = self.public_subnets[0]
self.single_nat_gateway_id = create_nat_gateway(
self,
first_public_subnet,
nat_gateway_name=NAT_GATEWAY_NAME,
nat_gateway_id_context_key=SINGLE_NAT_GATEWAY_ID,
)
else:
print(
"WARNING: No public subnets available and NAT gateway not found in existing VPC. Cannot create a NAT Gateway."
)
# --- 4. Process Private Subnets ---
if express_public_subnets_only:
if PRIVATE_SUBNETS_TO_USE or private_subnets_data_for_creation_ctx:
print(
"Note: PRIVATE_* subnet settings are ignored in Express public-subnet mode."
)
else:
print("\n--- Processing Private Subnets ---")
if checked_private_subnets_ctx:
for i, subnet_name in enumerate(PRIVATE_SUBNETS_TO_USE):
subnet_info = checked_private_subnets_ctx.get(subnet_name)
if subnet_info and subnet_info.get("exists"):
subnet_id = subnet_info.get("id")
if not subnet_id:
raise RuntimeError(
f"Context for existing private subnet '{subnet_name}' is missing 'id'."
)
subnet_az = subnet_info.get("az")
if (
not subnet_az
and PRIVATE_SUBNET_AVAILABILITY_ZONES
and i < len(PRIVATE_SUBNET_AVAILABILITY_ZONES)
):
subnet_az = PRIVATE_SUBNET_AVAILABILITY_ZONES[i]
if not subnet_az:
raise RuntimeError(
f"Context for existing private subnet '{subnet_name}' is missing 'az'."
)
subnet_attrs = {
"subnet_id": subnet_id,
"availability_zone": subnet_az,
}
route_table_id = subnet_info.get("route_table_id")
if route_table_id:
subnet_attrs["route_table_id"] = route_table_id
try:
imported_subnet = ec2.Subnet.from_subnet_attributes(
self,
f"ImportedPrivateSubnet{subnet_name.replace('-', '')}{i}",
**subnet_attrs,
)
self.private_subnets.append(imported_subnet)
print(
f"Imported existing private subnet: {subnet_name} (ID: {subnet_id})"
)
except Exception as e:
raise RuntimeError(
f"Failed to import private subnet '{subnet_name}' with ID '{subnet_id}'. Error: {e}"
)
# Create new private subnets
if private_subnets_data_for_creation_ctx:
names_to_create_private = [
s["name"] for s in private_subnets_data_for_creation_ctx
]
cidrs_to_create_private = [
s["cidr"] for s in private_subnets_data_for_creation_ctx
]
azs_to_create_private = [
s["az"] for s in private_subnets_data_for_creation_ctx
]
if names_to_create_private:
print(
f"Attempting to create {len(names_to_create_private)} new private subnets: {names_to_create_private}"
)
# --- CALL THE NEW CREATE_SUBNETS FUNCTION FOR PRIVATE ---
# Ensure self.single_nat_gateway_id is available before this call
if not self.single_nat_gateway_id:
raise ValueError(
"A single NAT Gateway ID is required for private subnets but was not resolved."
)
newly_created_private_subnets_cfn, newly_created_private_rts_cfn = (
create_subnets(
self,
vpc,
CDK_PREFIX,
names_to_create_private,
cidrs_to_create_private,
azs_to_create_private,
is_public=False,
single_nat_gateway_id=self.single_nat_gateway_id, # Pass the single NAT Gateway ID
)
)
self.private_subnets.extend(newly_created_private_subnets_cfn)
self.private_route_tables_cfn.extend(newly_created_private_rts_cfn)
print(
f"Successfully defined {len(newly_created_private_subnets_cfn)} new private subnets and their route tables for creation."
)
else:
print(
"No private subnets specified for creation in context ('private_subnets_to_create')."
)
# if not self.private_subnets:
# raise Exception("No private subnets found or created, exiting.")
if (
not self.private_subnets
and not names_to_create_private
and not PRIVATE_SUBNETS_TO_USE
):
# This condition might need adjustment for new VPCs.
raise Exception("No private subnets found or created, exiting.")
# --- 5. Sanity Check and Output ---
# Output the single NAT Gateway ID for verification
if self.single_nat_gateway_id:
CfnOutput(
self,
"SingleNatGatewayId",
value=self.single_nat_gateway_id,
description="ID of the single NAT Gateway resolved or created.",
)
elif express_public_subnets_only:
print(
"INFO: Express public-subnet mode — NAT Gateway not installed or required."
)
elif (
NEW_VPC_DEFAULT_NAME
and (self.node.try_get_context("new_vpc_nat_gateways") or 1) > 0
):
print(
"INFO: A new VPC was created with NAT Gateway(s). Their routing is handled by the VPC construct. No single_nat_gateway_id was explicitly set for separate output."
)
else:
out_message = "WARNING: No single NAT Gateway was resolved or created explicitly by the script's logic after VPC setup."
print(out_message)
raise Exception(out_message)
# --- Outputs for other stacks/regions ---
# These are crucial for cross-stack, cross-region referencing
self.params = dict()
self.params["vpc_id"] = vpc.vpc_id
self.params["private_subnets"] = self.private_subnets
self.params["private_route_tables"] = self.private_route_tables_cfn
self.params["public_subnets"] = self.public_subnets
self.params["public_route_tables"] = self.public_route_tables_cfn
private_subnet_selection = ec2.SubnetSelection(subnets=self.private_subnets)
public_subnet_selection = ec2.SubnetSelection(subnets=self.public_subnets)
for sub in private_subnet_selection.subnets:
print(
"private subnet:",
sub.subnet_id,
"is in availability zone:",
sub.availability_zone,
)
for sub in public_subnet_selection.subnets:
print(
"public subnet:",
sub.subnet_id,
"is in availability zone:",
sub.availability_zone,
)
print("Private subnet route tables:", self.private_route_tables_cfn)
CfnOutput(
self,
"VpcIdOutput",
value=vpc.vpc_id,
description="The ID of the VPC used by this stack.",
)
# --- IAM Roles ---
cognito_secret_name = COGNITO_USER_POOL_CLIENT_SECRET_NAME
secret_kms_key_arn_from_context = get_context_str(
f"kms_key_arn:{cognito_secret_name}"
)
if USE_CUSTOM_KMS_KEY == "1":
kms_key = kms.Key(
self,
"SummarisationSharedKmsKey",
alias=CUSTOM_KMS_KEY_NAME,
removal_policy=resource_removal_policy,
)
shared_kms_key_arn = kms_key.key_arn
secret_kms_key_arn = secret_kms_key_arn_from_context or kms_key.key_arn
else:
kms_key = None
shared_kms_key_arn = None
secret_kms_key_arn = (
secret_kms_key_arn_from_context
or default_secrets_manager_kms_key_arn(AWS_REGION, AWS_ACCOUNT_ID)
)
task_role_inline_policy = json.dumps(
build_ecs_task_role_inline_policy(
output_bucket_name=S3_OUTPUT_BUCKET_NAME,
log_config_bucket_name=S3_LOG_CONFIG_BUCKET_NAME,
shared_kms_key_arn=shared_kms_key_arn,
),
indent=4,
)
if enable_headless:
execution_role_kms_policy = json.dumps(
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "STSCallerIdentity",
"Effect": "Allow",
"Action": ["sts:GetCallerIdentity"],
"Resource": "*",
}
],
},
indent=4,
)
else:
execution_role_kms_policy = json.dumps(
build_ecs_execution_role_kms_policy(
secret_kms_key_arn=secret_kms_key_arn,
),
indent=4,
)
try:
codebuild_role_name = CODEBUILD_ROLE_NAME
if get_context_bool(f"exists:{codebuild_role_name}"):
# If exists, lookup/import the role using ARN from context
role_arn = get_context_str(f"arn:{codebuild_role_name}")
if not role_arn:
raise ValueError(
f"Context value 'arn:{codebuild_role_name}' is required if role exists."
)
codebuild_role = iam.Role.from_role_arn(
self, "CodeBuildRole", role_arn=role_arn
)
print("Using existing CodeBuild role")
else:
# If not exists, create the role
codebuild_role = iam.Role(
self,
"CodeBuildRole", # Logical ID
role_name=codebuild_role_name, # Explicit resource name
assumed_by=iam.ServicePrincipal("codebuild.amazonaws.com"),
)
codebuild_role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name(
"EC2InstanceProfileForImageBuilderECRContainerBuilds"
)
)
print("Successfully created new CodeBuild role")
task_role_name = ECS_TASK_ROLE_NAME
if get_context_bool(f"exists:{task_role_name}"):
role_arn = get_context_str(f"arn:{task_role_name}")
if not role_arn:
raise ValueError(
f"Context value 'arn:{task_role_name}' is required if role exists."
)
task_role = iam.Role.from_role_arn(self, "TaskRole", role_arn=role_arn)
print("Using existing ECS task role")
else:
task_role = iam.Role(
self,
"TaskRole", # Logical ID
role_name=task_role_name, # Explicit resource name
assumed_by=iam.ServicePrincipal("ecs-tasks.amazonaws.com"),
)
for role in AWS_MANAGED_TASK_ROLES_LIST:
print(f"Adding {role} to policy")
task_role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name(f"{role}")
)
attach_managed_policy_arns(task_role, POLICY_FILE_ARNS)
print("Successfully created new ECS task role")
task_role = add_custom_policies(
self,
task_role,
policy_file_locations=resolve_policy_file_paths(
POLICY_FILE_LOCATIONS, cdk_folder=CDK_FOLDER
),
custom_policy_text=task_role_inline_policy,
)
execution_role_name = ECS_TASK_EXECUTION_ROLE_NAME
if get_context_bool(f"exists:{execution_role_name}"):
role_arn = get_context_str(f"arn:{execution_role_name}")
if not role_arn:
raise ValueError(
f"Context value 'arn:{execution_role_name}' is required if role exists."
)
execution_role = iam.Role.from_role_arn(
self, "ExecutionRole", role_arn=role_arn
)
print("Using existing ECS execution role")
else:
execution_role = iam.Role(
self,
"ExecutionRole", # Logical ID
role_name=execution_role_name, # Explicit resource name
assumed_by=iam.ServicePrincipal("ecs-tasks.amazonaws.com"),
)
for role in ECS_EXECUTION_ROLE_MANAGED_POLICIES:
print(f"Adding {role} to execution role")
execution_role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name(f"{role}")
)
attach_managed_policy_arns(
execution_role, ECS_EXECUTION_ROLE_POLICY_ARNS
)
print("Successfully created new ECS execution role")
execution_role = add_custom_policies(
self,
execution_role,
policy_file_locations=resolve_policy_file_paths(
ECS_EXECUTION_ROLE_POLICY_FILES, cdk_folder=CDK_FOLDER
),
custom_policy_text=execution_role_kms_policy,
)
except Exception as e:
raise Exception("Failed at IAM role step due to:", e)
# --- S3 Buckets ---
try:
log_bucket_name = S3_LOG_CONFIG_BUCKET_NAME
if get_context_bool(f"globally_taken:{log_bucket_name}"):
raise ValueError(
f"S3 bucket name {log_bucket_name!r} is taken globally by another "
"AWS account. Set S3_LOG_CONFIG_BUCKET_NAME in cdk/config/cdk_config.env "
"to a unique name (re-run cdk_install.py or check_resources.py)."
)
if get_context_bool(f"exists:{log_bucket_name}"):
bucket = s3.Bucket.from_bucket_name(
self, "LogConfigBucket", bucket_name=log_bucket_name
)
print("Using existing S3 bucket", log_bucket_name)
else:
log_bucket_lifecycle = [
s3.LifecycleRule(
abort_incomplete_multipart_upload_after=Duration.days(7)
)
]
if USE_CUSTOM_KMS_KEY == "1" and isinstance(kms_key, kms.Key):
bucket = s3.Bucket(
self,
"LogConfigBucket",
bucket_name=log_bucket_name,
lifecycle_rules=log_bucket_lifecycle,
versioned=False,
removal_policy=resource_removal_policy,
auto_delete_objects=s3_auto_delete_objects,
encryption=s3.BucketEncryption.KMS,
encryption_key=kms_key,
)
else:
bucket = s3.Bucket(
self,
"LogConfigBucket",
bucket_name=log_bucket_name,
lifecycle_rules=log_bucket_lifecycle,
versioned=False,
removal_policy=resource_removal_policy,
auto_delete_objects=s3_auto_delete_objects,
)
print("Created S3 bucket", log_bucket_name)
# Add policies - this will apply to both created and imported buckets
# CDK handles idempotent policy additions
bucket.add_to_resource_policy(
iam.PolicyStatement(
effect=iam.Effect.ALLOW,
principals=[task_role], # Pass the role object directly
actions=["s3:GetObject", "s3:PutObject"],
resources=[f"{bucket.bucket_arn}/*"],
)
)
bucket.add_to_resource_policy(
iam.PolicyStatement(
effect=iam.Effect.ALLOW,
principals=[task_role],
actions=["s3:ListBucket"],
resources=[bucket.bucket_arn],
)
)
output_bucket_name = S3_OUTPUT_BUCKET_NAME
if get_context_bool(f"globally_taken:{output_bucket_name}"):
raise ValueError(
f"S3 bucket name {output_bucket_name!r} is taken globally by another "
"AWS account. Set S3_OUTPUT_BUCKET_NAME in cdk/config/cdk_config.env "
"to a unique name (re-run cdk_install.py or check_resources.py)."
)
if get_context_bool(f"exists:{output_bucket_name}"):
output_bucket = s3.Bucket.from_bucket_name(
self, "OutputBucket", bucket_name=output_bucket_name
)
print("Using existing Output bucket", output_bucket_name)
else:
if USE_CUSTOM_KMS_KEY == "1" and isinstance(kms_key, kms.Key):
output_bucket = s3.Bucket(
self,
"OutputBucket",
bucket_name=output_bucket_name,
lifecycle_rules=[
s3.LifecycleRule(
expiration=Duration.days(
int(DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS)
)
)
],
versioned=False,
removal_policy=resource_removal_policy,
auto_delete_objects=s3_auto_delete_objects,
encryption=s3.BucketEncryption.KMS,
encryption_key=kms_key,
)
else:
output_bucket = s3.Bucket(
self,
"OutputBucket",
bucket_name=output_bucket_name,
lifecycle_rules=[
s3.LifecycleRule(
expiration=Duration.days(
int(DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS)
)
)
],
versioned=False,
removal_policy=resource_removal_policy,
auto_delete_objects=s3_auto_delete_objects,
)
print("Created Output bucket:", output_bucket_name)
add_s3_enforce_ssl_policy(bucket)
add_s3_enforce_ssl_policy(output_bucket)
# Add policies to output bucket
output_bucket.add_to_resource_policy(
iam.PolicyStatement(
effect=iam.Effect.ALLOW,
principals=[task_role],
actions=["s3:GetObject", "s3:PutObject"],
resources=[f"{output_bucket.bucket_arn}/*"],
)
)
output_bucket.add_to_resource_policy(
iam.PolicyStatement(
effect=iam.Effect.ALLOW,
principals=[task_role],
actions=["s3:ListBucket"],
resources=[output_bucket.bucket_arn],
)
)
# Identity-based S3 access is scoped via build_ecs_task_role_inline_policy on
# task_role (output + log/config buckets). Bucket policies above remain for
# imported buckets and org policies that expect explicit bucket principals.
except Exception as e:
raise Exception("Could not handle S3 buckets due to:", e)
# --- Elastic Container Registry ---
try:
full_ecr_repo_name = ECR_CDK_REPO_NAME
if get_context_bool(f"exists:{full_ecr_repo_name}"):
ecr_repo = ecr.Repository.from_repository_name(
self, "ECRRepo", repository_name=full_ecr_repo_name
)
print("Using existing ECR repository")
else:
ecr_repo = ecr.Repository(
self,
"ECRRepo",
repository_name=full_ecr_repo_name,
removal_policy=resource_removal_policy,
empty_on_delete=ecr_empty_on_delete(),
) # Explicitly set repository_name
print("Created ECR repository", full_ecr_repo_name)
ecr_image_loc = ecr_repo.repository_uri
except Exception as e:
raise Exception("Could not handle ECR repo due to:", e)
pi_ecr_image_loc = ecr_image_loc
# --- CODEBUILD ---
try:
codebuild_project_name = CODEBUILD_PROJECT_NAME
if get_context_bool(f"exists:{codebuild_project_name}"):
# Lookup CodeBuild project by ARN from context
project_arn = get_context_str(f"arn:{codebuild_project_name}")
if not project_arn:
raise ValueError(
f"Context value 'arn:{codebuild_project_name}' is required if project exists."
)
codebuild.Project.from_project_arn(
self, "CodeBuildProject", project_arn=project_arn
)
print(
"Using existing CodeBuild project "
"(public GitHub source is applied in post_cdk_build_quickstart)."
)
else:
main_codebuild_project = codebuild.Project(
self,
"CodeBuildProject", # Logical ID
project_name=codebuild_project_name, # Explicit resource name
role=codebuild_role,
source=public_github_codebuild_source(
owner=GITHUB_REPO_USERNAME,
repo=GITHUB_REPO_NAME,
branch_or_ref=GITHUB_REPO_BRANCH,
),
environment=codebuild.BuildEnvironment(
build_image=codebuild.LinuxBuildImage.STANDARD_7_0,
privileged=True,
environment_variables={
"ECR_REPO_NAME": codebuild.BuildEnvironmentVariable(
value=full_ecr_repo_name
),
"AWS_DEFAULT_REGION": codebuild.BuildEnvironmentVariable(
value=AWS_REGION
),
"AWS_ACCOUNT_ID": codebuild.BuildEnvironmentVariable(
value=AWS_ACCOUNT_ID
),
"APP_MODE": codebuild.BuildEnvironmentVariable(
value="gradio"
),
},
),
build_spec=codebuild.BuildSpec.from_object(
{
"version": "0.2",
"phases": {
"pre_build": {
"commands": [
"echo Logging in to Amazon ECR",
"aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com",
]
},
"build": {
"commands": [
"echo Building the Docker image",
"docker build --build-arg APP_MODE=$APP_MODE --target $APP_MODE -t $ECR_REPO_NAME:latest .",
"docker tag $ECR_REPO_NAME:latest $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPO_NAME:latest",
]
},
"post_build": {
"commands": [
"echo Pushing the Docker image",
"docker push $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPO_NAME:latest",
]
},
},
}
),
)
configure_public_github_codebuild_source(
main_codebuild_project,
GITHUB_REPO_USERNAME,
GITHUB_REPO_NAME,
GITHUB_REPO_BRANCH,
)
print("Successfully created CodeBuild project", codebuild_project_name)
# Imported projects have role=undefined in CDK; use the actual service
# role from context (existing project) or the managed codebuild_role (new).
if get_context_bool(f"exists:{codebuild_project_name}"):
project_service_role_arn = get_context_str(
f"service_role_arn:{codebuild_project_name}"
)
if project_service_role_arn:
ecr_grantee = iam.Role.from_role_arn(
self,
"CodeBuildProjectServiceRole",
role_arn=project_service_role_arn,
mutable=True,
)
else:
ecr_grantee = codebuild_role
else:
ecr_grantee = codebuild_role
ecr_repo.grant_pull_push(ecr_grantee)
if enable_pi_build:
pi_codebuild_name = CODEBUILD_PI_PROJECT_NAME
if get_context_bool(f"exists:{pi_codebuild_name}"):
project_arn = get_context_str(f"arn:{pi_codebuild_name}")
if project_arn:
codebuild.Project.from_project_arn(
self, "CodeBuildPiProject", project_arn=project_arn
)
print("Using existing Pi agent CodeBuild project")
else:
pi_codebuild_project = codebuild.Project(
self,
"CodeBuildPiProject",
project_name=pi_codebuild_name,
role=codebuild_role,
source=public_github_codebuild_source(
owner=GITHUB_REPO_USERNAME,
repo=GITHUB_REPO_NAME,
branch_or_ref=GITHUB_REPO_BRANCH,
),
environment=codebuild.BuildEnvironment(
build_image=codebuild.LinuxBuildImage.STANDARD_7_0,
privileged=True,
environment_variables={
"ECR_REPO_NAME": codebuild.BuildEnvironmentVariable(
value=ECR_PI_REPO_NAME
),
"AWS_DEFAULT_REGION": codebuild.BuildEnvironmentVariable(
value=AWS_REGION
),
"AWS_ACCOUNT_ID": codebuild.BuildEnvironmentVariable(
value=AWS_ACCOUNT_ID
),
},
),
build_spec=codebuild.BuildSpec.from_object(
{
"version": "0.2",
"phases": {
"pre_build": {
"commands": [
"echo Logging in to Amazon ECR",
"aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com",
"test -f config/pi_agent.env.example",
"test -f agent-redact/pi-agent/Dockerfile",
]
},
"build": {
"commands": [
"docker build -f agent-redact/pi-agent/Dockerfile -t $ECR_REPO_NAME:latest .",
"docker tag $ECR_REPO_NAME:latest $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPO_NAME:latest",
]
},
"post_build": {
"commands": [
"docker push $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPO_NAME:latest",
]
},
},
}
),
)
configure_public_github_codebuild_source(
pi_codebuild_project,
GITHUB_REPO_USERNAME,
GITHUB_REPO_NAME,
GITHUB_REPO_BRANCH,
)
print("Created Pi agent CodeBuild project", pi_codebuild_name)
pi_ecr_repo_name = ECR_PI_REPO_NAME
if get_context_bool(f"exists:{pi_ecr_repo_name}"):
pi_ecr_repo = ecr.Repository.from_repository_name(
self, "ECRPiRepo", repository_name=pi_ecr_repo_name
)
else:
pi_ecr_repo = ecr.Repository(
self,
"ECRPiRepo",
repository_name=pi_ecr_repo_name,
removal_policy=resource_removal_policy,
empty_on_delete=ecr_empty_on_delete(),
)
pi_ecr_image_loc = pi_ecr_repo.repository_uri
pi_ecr_repo.grant_pull_push(ecr_grantee)
CfnOutput(self, "ECRPiRepoUri", value=pi_ecr_repo.repository_uri)
except Exception as e:
raise Exception("Could not handle Codebuild project due to:", e)
pi_ecs_service = None
pi_ecs_security_group = None
# --- Security Groups ---
try:
ecs_security_group_name = ECS_SECURITY_GROUP_NAME
try:
ecs_security_group = ec2.SecurityGroup(
self,
"ECSSecurityGroup", # Logical ID
security_group_name=ecs_security_group_name, # Explicit resource name
vpc=vpc,
)
print(f"Created Security Group: {ecs_security_group_name}")
except Exception as e: # If lookup fails, create
print("Failed to create ECS security group due to:", e)
ec2_port_gradio_server_port = ec2.Port.tcp(int(GRADIO_SERVER_PORT))
if deploy_web_ingress:
alb_security_group_name = ALB_NAME_SECURITY_GROUP_NAME
try:
alb_security_group = ec2.SecurityGroup(
self,
"ALBSecurityGroup", # Logical ID
security_group_name=alb_security_group_name,
vpc=vpc,
)
print(f"Created Security Group: {alb_security_group_name}")
except Exception as e:
print("Failed to create ALB security group due to:", e)
ecs_security_group.add_ingress_rule(
peer=alb_security_group,
connection=ec2_port_gradio_server_port,
description="ALB traffic",
)
alb_security_group.add_ingress_rule(
peer=ec2.Peer.prefix_list(CLOUDFRONT_PREFIX_LIST_ID),
connection=ec2.Port.all_traffic(),
description="CloudFront traffic",
)
else:
alb_security_group = None
if USE_CLOUDFRONT == "True":
ecs_security_group.add_ingress_rule(
peer=ec2.Peer.prefix_list(CLOUDFRONT_PREFIX_LIST_ID),
connection=ec2_port_gradio_server_port,
description="CloudFront to ECS (Express Mode)",
)
if enable_service_connect:
for index, client_sg_id in enumerate(service_connect_client_sg_ids):
client_sg = ec2.SecurityGroup.from_security_group_id(
self,
f"ServiceConnectClientSg{index}",
security_group_id=client_sg_id,
)
ecs_security_group.add_ingress_rule(
peer=client_sg,
connection=ec2_port_gradio_server_port,
description=(
f"Service Connect client {client_sg_id} to app port"
),
)
print(
"Service Connect ingress allowed from security groups: "
+ ", ".join(service_connect_client_sg_ids)
)
except Exception as e:
raise Exception("Could not handle security groups due to:", e)
endpoint_subnet_selection = resolve_ecs_vpc_endpoint_subnet_selection(
use_express_ingress=use_express_ingress,
express_use_public_subnets=ECS_EXPRESS_USE_PUBLIC_SUBNETS == "True",
public_subnets=self.public_subnets,
private_subnets=self.private_subnets,
)
s3_gateway_subnet_selection = resolve_ecs_s3_gateway_subnet_selection(
public_subnets=self.public_subnets,
private_subnets=self.private_subnets,
)
if ENABLE_ECS_VPC_INTERFACE_ENDPOINTS == "True" and (
endpoint_subnet_selection or s3_gateway_subnet_selection
):
if (
VPC_NAME
and not imported_vpc_cidr_block
and not imported_vpc_cidr_blocks
):
raise ValueError(
"vpc_cidr_block / vpc_cidr_blocks missing from precheck.context.json. "
"Re-run check_resources.py from the cdk/ directory so the VPC "
"CIDR(s) are stored for VPC endpoints and security groups."
)
existing_endpoint_services = frozenset(
self.node.try_get_context("existing_vpc_endpoint_service_names") or []
)
if VPC_NAME and not existing_endpoint_services:
print(
"Note: existing_vpc_endpoint_service_names not in precheck context; "
"re-run check_resources.py to skip duplicate endpoints in shared VPCs."
)
try:
endpoint_tier = (
"public"
if use_express_ingress and ECS_EXPRESS_USE_PUBLIC_SUBNETS == "True"
else "private"
)
create_ecs_vpc_endpoints_for_private_subnets(
self,
vpc=vpc,
subnets=endpoint_subnet_selection,
s3_gateway_subnets=s3_gateway_subnet_selection,
logical_id_prefix="SummarisationEcs",
include_secrets_and_kms=True,
vpc_cidr_block=imported_vpc_cidr_block,
vpc_cidr_blocks=imported_vpc_cidr_blocks or None,
skip_service_names=existing_endpoint_services,
aws_region=AWS_REGION,
)
s3_subnet_count = len(
(s3_gateway_subnet_selection.subnets or [])
if s3_gateway_subnet_selection
else []
)
print(
"Defined ECS VPC interface endpoints (ECR, Logs, Secrets Manager, "
f"KMS) for {endpoint_tier} subnets where not already present; "
f"S3 gateway for {s3_subnet_count} stack subnet(s) (public + "
"private) where not already present."
)
except Exception as e:
raise Exception(
"Could not create ECS VPC interface endpoints for ECS task subnets. "
"If this VPC already has them, re-run check_resources.py (auto-skip) "
"or set ENABLE_ECS_VPC_INTERFACE_ENDPOINTS=False in cdk_config.env "
"and ensure task subnets reach ECR (NAT, IGW, or existing endpoints).",
e,
) from e
# --- DynamoDB tables for logs (optional) ---
usage_log_table = None
if SAVE_LOGS_TO_DYNAMODB == "True":
try:
print("Creating DynamoDB tables for logs")
dynamodb.Table(
self,
"SummarisationAccessDataTable",
table_name=ACCESS_LOG_DYNAMODB_TABLE_NAME,
partition_key=dynamodb.Attribute(
name="id", type=dynamodb.AttributeType.STRING
),
billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST,
deletion_protection=resource_delete_protection,
removal_policy=resource_removal_policy,
)
dynamodb.Table(
self,
"SummarisationFeedbackDataTable",
table_name=FEEDBACK_LOG_DYNAMODB_TABLE_NAME,
partition_key=dynamodb.Attribute(
name="id", type=dynamodb.AttributeType.STRING
),
billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST,
deletion_protection=resource_delete_protection,
removal_policy=resource_removal_policy,
)
usage_log_table = dynamodb.Table(
self,
"SummarisationUsageDataTable",
table_name=USAGE_LOG_DYNAMODB_TABLE_NAME,
partition_key=dynamodb.Attribute(
name="id", type=dynamodb.AttributeType.STRING
),
billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST,
deletion_protection=resource_delete_protection,
removal_policy=resource_removal_policy,
)
except Exception as e:
raise Exception("Could not create DynamoDB tables due to:", e)
if ENABLE_DYNAMODB_USAGE_LOG_EXPORT == "True":
try:
if usage_log_table is None:
raise ValueError(
"ENABLE_DYNAMODB_USAGE_LOG_EXPORT=True requires "
"SAVE_LOGS_TO_DYNAMODB=True and the usage log table."
)
lambda_asset_dir = os.path.join(
os.path.dirname(__file__), "lambda_dynamo_logs_export"
)
create_dynamo_usage_log_export_lambda(
self,
"DynamoUsageLogExport",
function_name=DYNAMODB_USAGE_LOG_EXPORT_LAMBDA_NAME or None,
lambda_asset_path=lambda_asset_dir,
dynamodb_table=usage_log_table,
output_bucket=output_bucket,
s3_output_key=DYNAMODB_USAGE_LOG_EXPORT_S3_KEY,
schedule_expression=DYNAMODB_USAGE_LOG_EXPORT_SCHEDULE,
dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME,
date_attribute=DYNAMODB_USAGE_LOG_EXPORT_DATE_ATTRIBUTE,
output_filename=DYNAMODB_USAGE_LOG_EXPORT_OUTPUT_FILENAME,
shared_kms_key_arn=shared_kms_key_arn,
)
print(
"Scheduled DynamoDB usage log export Lambda defined "
f"({DYNAMODB_USAGE_LOG_EXPORT_SCHEDULE})."
)
except Exception as e:
raise Exception(
"Could not handle DynamoDB usage log export Lambda due to:", e
) from e
alb = None
load_balancer_name = ALB_NAME
if len(load_balancer_name) > 32:
load_balancer_name = load_balancer_name[-32:]
if deploy_web_ingress:
# --- ALB (legacy path) ---
try:
alb_arn = get_context_str(f"arn:{load_balancer_name}") or (
EXISTING_LOAD_BALANCER_ARN or None
)
alb_dns_name = get_context_str(f"dns:{load_balancer_name}") or (
EXISTING_LOAD_BALANCER_DNS or None
)
if alb_arn and alb_dns_name:
alb_security_group_id = (
get_context_str(f"security_group_id:{load_balancer_name}")
or alb_security_group.security_group_id
)
alb_attrs = {
"load_balancer_arn": alb_arn,
"load_balancer_dns_name": alb_dns_name,
"security_group_id": alb_security_group_id,
"vpc": vpc,
}
alb_canonical_zone_id = get_context_str(
f"canonical_hosted_zone_id:{load_balancer_name}"
)
if alb_canonical_zone_id:
alb_attrs["load_balancer_canonical_hosted_zone_id"] = (
alb_canonical_zone_id
)
alb = elbv2.ApplicationLoadBalancer.from_application_load_balancer_attributes(
self,
"ALB",
**alb_attrs,
)
print(
f"Using existing Application Load Balancer {load_balancer_name}."
)
else:
alb = elbv2.ApplicationLoadBalancer(
self,
"ALB",
load_balancer_name=load_balancer_name,
vpc=vpc,
internet_facing=True,
security_group=alb_security_group,
vpc_subnets=public_subnet_selection,
drop_invalid_header_fields=True,
deletion_protection=resource_delete_protection,
)
print("Successfully created new Application Load Balancer")
except Exception as e:
raise Exception("Could not handle application load balancer due to:", e)
# --- Cognito User Pool (web login; skipped for headless batch-only) ---
user_pool = None
user_pool_client = None
user_pool_domain = None
secret = None
if enable_headless:
print(
"ENABLE_HEADLESS_DEPLOYMENT=True: skipping Cognito user pool, "
"hosted UI domain, and client secret (no web login for batch tasks)."
)
else:
try:
if get_context_bool(f"exists:{COGNITO_USER_POOL_NAME}"):
# Lookup by ID from context
user_pool_id = get_context_str(f"id:{COGNITO_USER_POOL_NAME}")
if not user_pool_id:
raise ValueError(
f"Context value 'id:{COGNITO_USER_POOL_NAME}' is required if User Pool exists."
)
user_pool = cognito.UserPool.from_user_pool_id(
self, "UserPool", user_pool_id=user_pool_id
)
print(f"Using existing user pool {user_pool_id}.")
else:
user_pool = cognito.UserPool(
self,
"UserPool",
user_pool_name=COGNITO_USER_POOL_NAME,
mfa=cognito.Mfa.OFF, # Adjust as needed
sign_in_aliases=cognito.SignInAliases(email=True),
deletion_protection=resource_delete_protection,
removal_policy=resource_removal_policy,
) # Adjust as needed
print(f"Created new user pool {user_pool.user_pool_id}.")
# HTTPS ALB (ACM cert or Express Mode) needs oauth2/idpresponse callback URLs.
if ACM_SSL_CERTIFICATE_ARN or use_express_ingress:
redirect_uris = [
COGNITO_REDIRECTION_URL,
COGNITO_REDIRECTION_URL + "/oauth2/idpresponse",
]
else:
redirect_uris = [COGNITO_REDIRECTION_URL]
user_pool_client_name = COGNITO_USER_POOL_CLIENT_NAME
if get_context_bool(f"exists:{user_pool_client_name}"):
# Lookup by ID from context (requires User Pool object)
user_pool_client_id = get_context_str(f"id:{user_pool_client_name}")
if not user_pool_client_id:
raise ValueError(
f"Context value 'id:{user_pool_client_name}' is required if User Pool Client exists."
)
user_pool_client = cognito.UserPoolClient.from_user_pool_client_id(
self, "UserPoolClient", user_pool_client_id=user_pool_client_id
)
print(f"Using existing user pool client {user_pool_client_id}.")
else:
user_pool_client = cognito.UserPoolClient(
self,
"UserPoolClient",
auth_flows=cognito.AuthFlow(
user_srp=True, user_password=True
), # Example: enable SRP for secure sign-in
user_pool=user_pool,
generate_secret=True,
user_pool_client_name=user_pool_client_name,
supported_identity_providers=[
cognito.UserPoolClientIdentityProvider.COGNITO
],
o_auth=cognito.OAuthSettings(
flows=cognito.OAuthFlows(authorization_code_grant=True),
scopes=[
cognito.OAuthScope.OPENID,
cognito.OAuthScope.EMAIL,
cognito.OAuthScope.PROFILE,
],
callback_urls=redirect_uris,
),
refresh_token_validity=Duration.minutes(
COGNITO_REFRESH_TOKEN_VALIDITY
),
id_token_validity=Duration.minutes(COGNITO_ID_TOKEN_VALIDITY),
access_token_validity=Duration.minutes(
COGNITO_ACCESS_TOKEN_VALIDITY
),
)
CfnOutput(
self,
"CognitoAppClientId",
value=user_pool_client.user_pool_client_id,
)
print(
f"Created new user pool client {user_pool_client.user_pool_client_id}."
)
# Add a domain to the User Pool (crucial for ALB integration)
domain_prefix = (COGNITO_USER_POOL_DOMAIN_PREFIX or "").strip().lower()
if get_context_bool(f"cognito_domain_taken:{domain_prefix}"):
raise ValueError(
f"Cognito hosted UI domain prefix {domain_prefix!r} is not "
f"available in this region (taken by another AWS account or "
"an existing pool). Set COGNITO_USER_POOL_DOMAIN_PREFIX in "
"cdk/config/cdk_config.env to a unique value and re-run "
"cdk_install.py / check_resources.py."
)
user_pool_domain = user_pool.add_domain(
"UserPoolDomain",
cognito_domain=cognito.CognitoDomainOptions(
domain_prefix=COGNITO_USER_POOL_DOMAIN_PREFIX
),
)
# Apply removal_policy to the created UserPoolDomain construct
user_pool_domain.apply_removal_policy(policy=resource_removal_policy)
CfnOutput(
self, "CognitoUserPoolLoginUrl", value=user_pool_domain.base_url()
)
except Exception as e:
raise Exception("Could not handle Cognito resources due to:", e)
# --- Secrets Manager Secret ---
try:
secret_name = COGNITO_USER_POOL_CLIENT_SECRET_NAME
if get_context_bool(f"exists:{secret_name}"):
secret_arn = get_context_str(f"arn:{secret_name}")
if secret_arn:
secret = secretsmanager.Secret.from_secret_complete_arn(
self,
"CognitoSecret",
secret_complete_arn=secret_arn,
)
print("Using existing Secret (ARN from precheck context).")
else:
secret = secretsmanager.Secret.from_secret_name_v2(
self, "CognitoSecret", secret_name=secret_name
)
print(
"Using existing Secret by name (IAM grants use ARN wildcard "
"suffix; re-run precheck to pin the full ARN)."
)
else:
if USE_CUSTOM_KMS_KEY == "1" and isinstance(kms_key, kms.Key):
secret = secretsmanager.Secret(
self,
"CognitoSecret", # Logical ID
secret_name=secret_name, # Explicit resource name
secret_object_value={
"SUMMARISATION_USER_POOL_ID": SecretValue.unsafe_plain_text(
user_pool.user_pool_id
), # Use the CDK attribute
"SUMMARISATION_CLIENT_ID": SecretValue.unsafe_plain_text(
user_pool_client.user_pool_client_id
), # Use the CDK attribute
"SUMMARISATION_CLIENT_SECRET": user_pool_client.user_pool_client_secret, # Use the CDK attribute
},
encryption_key=kms_key,
removal_policy=resource_removal_policy,
)
else:
secret = secretsmanager.Secret(
self,
"CognitoSecret", # Logical ID
secret_name=secret_name, # Explicit resource name
secret_object_value={
"SUMMARISATION_USER_POOL_ID": SecretValue.unsafe_plain_text(
user_pool.user_pool_id
), # Use the CDK attribute
"SUMMARISATION_CLIENT_ID": SecretValue.unsafe_plain_text(
user_pool_client.user_pool_client_id
), # Use the CDK attribute
"SUMMARISATION_CLIENT_SECRET": user_pool_client.user_pool_client_secret, # Use the CDK attribute
},
removal_policy=resource_removal_policy,
)
print(
"Created new secret in Secrets Manager for Cognito user pool and related details."
)
except Exception as e:
raise Exception("Could not handle Secrets Manager secret due to:", e)
try:
secret.grant_read(task_role)
secret.grant_read(execution_role)
except Exception as e:
raise Exception("Could not grant access to Secrets Manager due to:", e)
try:
# ECS environmentFiles (app_config.env) are fetched by the execution role at task start.
bucket.grant_read(execution_role, APP_CONFIG_ENV_BASENAME)
# KMS: task role uses shared S3 CMK via build_ecs_task_role_inline_policy;
# execution role uses the secret's CMK via build_ecs_execution_role_kms_policy.
except Exception as e:
raise Exception("Could not grant bucket read to execution role due to:", e)
# --- ECS Cluster (shared by legacy Fargate and Express paths) ---
try:
cluster_kwargs = {
"cluster_name": CLUSTER_NAME,
"enable_fargate_capacity_providers": True,
"vpc": vpc,
}
if enable_service_connect or enable_pi_express:
cluster_kwargs["default_cloud_map_namespace"] = (
ecs.CloudMapNamespaceOptions(
name=ECS_SERVICE_CONNECT_NAMESPACE,
vpc=vpc,
)
)
cluster = ecs.Cluster(self, "ECSCluster", **cluster_kwargs)
print("Successfully created new ECS cluster")
except Exception as e:
raise Exception("Could not handle ECS cluster due to:", e)
express_service = None
express_alb_security_group_id = None
if use_express_ingress:
try:
express_log_group = logs.LogGroup(
self,
"ExpressTaskLogGroup",
log_group_name=f"/ecs/{ECS_EXPRESS_SERVICE_NAME}-logs".lower(),
retention=logs.RetentionDays.ONE_MONTH,
removal_policy=resource_removal_policy,
)
express_log_group.grant_write(execution_role)
express_infra_role = create_ecs_express_infrastructure_role(
self,
"ExpressInfrastructureRole",
ECS_EXPRESS_INFRASTRUCTURE_ROLE_NAME,
)
express_app_overrides: Dict[str, str] = {}
if ENABLE_HEADLESS_DEPLOYMENT == "True":
express_app_overrides["COGNITO_AUTH"] = "False"
elif enable_pi_express:
# Pi agent calls main over Service Connect; Gradio auth blocks
# gradio_client unless credentials are passed on every call.
express_app_overrides["COGNITO_AUTH"] = "False"
express_app_environment = load_app_config_env_for_express(
APP_CONFIG_ENV_FILE,
overrides=express_app_overrides or None,
)
primary_container = build_express_gateway_primary_container(
image_uri=ecr_image_loc + ":latest",
container_port=int(GRADIO_SERVER_PORT),
log_group_name=express_log_group.log_group_name,
aws_region=AWS_REGION,
secret=secret,
environment=express_app_environment,
)
express_use_public_subnets = ECS_EXPRESS_USE_PUBLIC_SUBNETS == "True"
express_subnet_ids = [
s.subnet_id
for s in (
self.public_subnets
if express_use_public_subnets
else self.private_subnets
)
]
if not express_subnet_ids:
tier = "public" if express_use_public_subnets else "private"
raise ValueError(
f"No {tier} subnets available for ECS Express Mode. "
f"Set ECS_EXPRESS_USE_PUBLIC_SUBNETS=False to use private "
"subnets (internal ALB only), or create/import public subnets."
)
if express_use_public_subnets:
print(
"ECS Express Mode using public subnets "
"(internet-facing managed ALB)."
)
else:
print(
"ECS Express Mode using private subnets "
"(internal managed ALB)."
)
# MinTaskCount=0 until post_cdk_build_quickstart builds/pushes :latest.
express_service = create_express_gateway_service(
self,
"ExpressGatewayService",
service_name=ECS_EXPRESS_SERVICE_NAME,
cluster_name=CLUSTER_NAME,
execution_role_arn=execution_role.role_arn,
infrastructure_role_arn=express_infra_role.role_arn,
task_role_arn=task_role.role_arn,
cpu=str(ECS_TASK_CPU_SIZE),
memory=str(ECS_TASK_MEMORY_SIZE),
health_check_path=ECS_EXPRESS_HEALTH_CHECK_PATH,
primary_container=primary_container,
subnet_ids=express_subnet_ids,
security_group_ids=[ecs_security_group.security_group_id],
)
express_service.node.add_dependency(cluster)
allow_express_load_balancer_to_ecs_security_group(
self,
"ExpressAlbToEcsIngress",
express_service=express_service,
ecs_security_group=ecs_security_group,
container_port=int(GRADIO_SERVER_PORT),
)
express_alb_arn = express_ingress_load_balancer_arn(express_service)
express_alb_dns = express_service.attr_endpoint
express_alb_security_group_id = (
express_ingress_first_load_balancer_security_group(express_service)
)
alb = elbv2.ApplicationLoadBalancer.from_application_load_balancer_attributes(
self,
"ALB",
load_balancer_arn=express_alb_arn,
load_balancer_dns_name=express_alb_dns,
security_group_id=express_alb_security_group_id,
vpc=vpc,
)
# Express Mode manages host-header listener rules (priorities 1, 2, …).
# Do not add ALB authenticate-cognito rules here; use in-app COGNITO_AUTH.
CfnOutput(
self,
"ExpressServiceEndpoint",
value=express_service.attr_endpoint,
description="HTTPS URL for the ECS Express Mode service",
)
CfnOutput(
self,
"ExpressServiceArn",
value=express_service.attr_service_arn,
)
CfnOutput(
self,
"ExpressManagedCertificateArn",
value=express_service.attr_ecs_managed_resource_arns_ingress_path_certificate_arn,
)
if enable_pi_express:
try:
pi_express_log_group = logs.LogGroup(
self,
"ExpressPiTaskLogGroup",
log_group_name=f"/ecs/{ECS_PI_EXPRESS_SERVICE_NAME}-logs".lower(),
retention=logs.RetentionDays.ONE_MONTH,
removal_policy=resource_removal_policy,
)
pi_express_log_group.grant_write(execution_role)
pi_express_security_group = ec2.SecurityGroup(
self,
"ExpressPiSecurityGroup",
vpc=vpc,
security_group_name=ECS_PI_EXPRESS_SECURITY_GROUP_NAME,
description="Pi agent ECS Express tasks",
)
pi_express_environment = build_pi_express_container_environment(
service_connect_discovery_name=ECS_SERVICE_CONNECT_DISCOVERY_NAME,
main_app_port=int(GRADIO_SERVER_PORT),
pi_gradio_port=int(PI_GRADIO_PORT),
cognito_auth=ENABLE_HEADLESS_DEPLOYMENT != "True",
)
pi_primary_container = build_express_pi_primary_container(
image_uri=pi_ecr_image_loc + ":latest",
container_port=int(PI_GRADIO_PORT),
log_group_name=pi_express_log_group.log_group_name,
aws_region=AWS_REGION,
environment=pi_express_environment,
secret=secret,
cognito_auth=ENABLE_HEADLESS_DEPLOYMENT != "True",
)
express_pi_service = create_express_gateway_service(
self,
"ExpressPiGatewayService",
service_name=ECS_PI_EXPRESS_SERVICE_NAME,
cluster_name=CLUSTER_NAME,
execution_role_arn=execution_role.role_arn,
infrastructure_role_arn=express_infra_role.role_arn,
task_role_arn=task_role.role_arn,
cpu=str(ECS_PI_TASK_CPU_SIZE),
memory=str(ECS_PI_TASK_MEMORY_SIZE),
health_check_path=ECS_PI_EXPRESS_HEALTH_CHECK_PATH,
primary_container=pi_primary_container,
subnet_ids=express_subnet_ids,
security_group_ids=[
pi_express_security_group.security_group_id
],
)
express_pi_service.node.add_dependency(cluster)
express_pi_service.node.add_dependency(express_service)
allow_express_load_balancer_to_ecs_security_group(
self,
"ExpressAlbToPiExpressIngress",
express_service=express_pi_service,
ecs_security_group=pi_express_security_group,
container_port=int(PI_GRADIO_PORT),
)
pi_express_security_group.add_egress_rule(
peer=ecs_security_group,
connection=ec2.Port.tcp(int(GRADIO_SERVER_PORT)),
description="Pi Express (Service Connect) to main summarisation app",
)
ecs_security_group.add_ingress_rule(
peer=pi_express_security_group,
connection=ec2.Port.tcp(int(GRADIO_SERVER_PORT)),
description="Pi Express (Service Connect) to main summarisation app",
)
# Service Connect for Express is applied in post_cdk_build_quickstart.py
# after CodeBuild pushes :latest. Express primary containers do not
# define named portMappings at create time; CDK cannot enable SC here.
pi_public_url = format_express_pi_public_url(
express_pi_service.attr_endpoint,
)
sc_backend = (
f"http://{ECS_SERVICE_CONNECT_DISCOVERY_NAME}:"
f"{GRADIO_SERVER_PORT}"
)
CfnOutput(
self,
"PiExpressEndpoint",
value=express_pi_service.attr_endpoint,
description="HTTPS URL for the Pi ECS Express service (AWS-managed cert)",
)
CfnOutput(
self,
"PiPublicUrl",
value=pi_public_url,
description="Public URL for Pi Express UI (managed HTTPS endpoint)",
)
CfnOutput(
self,
"PiDocSummarisationBackendUrl",
value=sc_backend,
description="DOC_SUMMARISATION_GRADIO_URL on Pi Express (Service Connect, no Cognito)",
)
CfnOutput(
self,
"PiExpressServiceName",
value=ECS_PI_EXPRESS_SERVICE_NAME,
)
CfnOutput(
self,
"ServiceConnectNamespace",
value=ECS_SERVICE_CONNECT_NAMESPACE,
description="Cloud Map namespace for Express Service Connect",
)
print(
"ECS Express Pi gateway service defined with Service Connect "
f"backend {sc_backend}; public URL: {pi_public_url}."
)
except Exception as e:
raise Exception(
"Could not handle ECS Express Pi agent due to:", e
)
print("ECS Express Gateway service defined.")
except Exception as e:
raise Exception("Could not handle ECS Express Mode due to:", e)
if not use_express_ingress:
# --- Fargate Task Definition ---
try:
fargate_task_definition_name = FARGATE_TASK_DEFINITION_NAME
read_only_file_system = ECS_READ_ONLY_FILE_SYSTEM == "True"
if os.path.exists(TASK_DEFINITION_FILE_LOCATION):
with open(TASK_DEFINITION_FILE_LOCATION) as f: # Use correct path
task_def_params = json.load(f)
# Need to ensure taskRoleArn and executionRoleArn in JSON are correct ARN strings
else:
epheremal_storage_volume_name = "appEphemeralVolume"
task_def_params = {}
task_def_params["taskRoleArn"] = (
task_role.role_arn
) # Use CDK role object ARN
task_def_params["executionRoleArn"] = (
execution_role.role_arn
) # Use CDK role object ARN
task_def_params["memory"] = ECS_TASK_MEMORY_SIZE
task_def_params["cpu"] = ECS_TASK_CPU_SIZE
container_def = {
"name": full_ecr_repo_name,
"image": ecr_image_loc + ":latest",
"essential": True,
"portMappings": [
{
"containerPort": int(GRADIO_SERVER_PORT),
"hostPort": int(GRADIO_SERVER_PORT),
"protocol": "tcp",
"appProtocol": "http",
}
],
"logConfiguration": {
"logDriver": "awslogs",
"options": {
"awslogs-group": ECS_LOG_GROUP_NAME,
"awslogs-region": AWS_REGION,
"awslogs-stream-prefix": "ecs",
},
},
"environmentFiles": (
[]
if enable_headless
else [
{
"value": bucket.bucket_arn
+ f"/{APP_CONFIG_ENV_BASENAME}",
"type": "s3",
}
]
),
"memoryReservation": int(task_def_params["memory"])
- 512, # Reserve some memory for the container
"mountPoints": [
{
"sourceVolume": epheremal_storage_volume_name,
"containerPath": "/home/user/app/logs",
"readOnly": False,
},
{
"sourceVolume": epheremal_storage_volume_name,
"containerPath": "/home/user/app/feedback",
"readOnly": False,
},
{
"sourceVolume": epheremal_storage_volume_name,
"containerPath": "/home/user/app/usage",
"readOnly": False,
},
{
"sourceVolume": epheremal_storage_volume_name,
"containerPath": "/home/user/app/input",
"readOnly": False,
},
{
"sourceVolume": epheremal_storage_volume_name,
"containerPath": "/home/user/app/output",
"readOnly": False,
},
{
"sourceVolume": epheremal_storage_volume_name,
"containerPath": "/home/user/app/config",
"readOnly": False,
},
{
"sourceVolume": epheremal_storage_volume_name,
"containerPath": "/tmp/matplotlib_cache",
"readOnly": False,
},
{
"sourceVolume": epheremal_storage_volume_name,
"containerPath": "/tmp",
"readOnly": False,
},
{
"sourceVolume": epheremal_storage_volume_name,
"containerPath": "/var/tmp",
"readOnly": False,
},
{
"sourceVolume": epheremal_storage_volume_name,
"containerPath": "/tmp/gradio_tmp",
"readOnly": False,
},
],
"readonlyRootFilesystem": read_only_file_system,
"user": "1000",
}
task_def_params["containerDefinitions"] = [container_def]
log_group_name_from_config = task_def_params["containerDefinitions"][0][
"logConfiguration"
]["options"]["awslogs-group"]
cdk_managed_log_group = logs.LogGroup(
self,
"MyTaskLogGroup", # CDK Logical ID
log_group_name=log_group_name_from_config,
retention=logs.RetentionDays.ONE_MONTH,
removal_policy=resource_removal_policy,
)
cdk_managed_log_group.grant_write(execution_role)
epheremal_storage_volume_cdk_obj = ecs.Volume(
name=epheremal_storage_volume_name
)
fargate_task_definition = ecs.FargateTaskDefinition(
self,
"FargateTaskDefinition", # Logical ID
family=fargate_task_definition_name,
cpu=int(task_def_params["cpu"]),
memory_limit_mib=int(task_def_params["memory"]),
task_role=task_role,
execution_role=execution_role,
runtime_platform=ecs.RuntimePlatform(
cpu_architecture=ecs.CpuArchitecture.X86_64,
operating_system_family=ecs.OperatingSystemFamily.LINUX,
),
ephemeral_storage_gib=21, # Minimum is 21 GiB
volumes=[epheremal_storage_volume_cdk_obj],
)
print("Fargate task definition defined.")
# Add container definitions to the task definition object
if task_def_params["containerDefinitions"]:
container_def_params = task_def_params["containerDefinitions"][0]
env_files = []
if container_def_params.get("environmentFiles"):
for env_file_param in container_def_params["environmentFiles"]:
# Need to parse the ARN to get the bucket object and key
env_file_arn_parts = env_file_param["value"].split(":::")
bucket_name_and_key = env_file_arn_parts[-1]
env_bucket_name, env_key = bucket_name_and_key.split("/", 1)
env_file = ecs.EnvironmentFile.from_bucket(bucket, env_key)
env_files.append(env_file)
container_kwargs: Dict[str, Any] = {
"image": ecs.ContainerImage.from_registry(
container_def_params["image"]
),
"logging": ecs.LogDriver.aws_logs(
stream_prefix=container_def_params["logConfiguration"][
"options"
]["awslogs-stream-prefix"],
log_group=cdk_managed_log_group,
),
"environment_files": env_files if env_files else None,
"readonly_root_filesystem": read_only_file_system,
"user": container_def_params.get("user", "1000"),
}
if not enable_headless:
container_kwargs["secrets"] = {
"AWS_USER_POOL_ID": ecs.Secret.from_secrets_manager(
secret, "SUMMARISATION_USER_POOL_ID"
),
"AWS_CLIENT_ID": ecs.Secret.from_secrets_manager(
secret, "SUMMARISATION_CLIENT_ID"
),
"AWS_CLIENT_SECRET": ecs.Secret.from_secrets_manager(
secret, "SUMMARISATION_CLIENT_SECRET"
),
}
container = fargate_task_definition.add_container(
container_def_params["name"],
**container_kwargs,
)
for port_mapping in container_def_params["portMappings"]:
container.add_port_mappings(
ecs.PortMapping(
container_port=int(port_mapping["containerPort"]),
host_port=int(port_mapping["hostPort"]),
name="port-" + str(port_mapping["containerPort"]),
app_protocol=ecs.AppProtocol.http,
protocol=ecs.Protocol.TCP,
)
)
container.add_port_mappings(
ecs.PortMapping(
container_port=80,
host_port=80,
name="port-80",
app_protocol=ecs.AppProtocol.http,
protocol=ecs.Protocol.TCP,
)
)
if container_def_params.get("mountPoints"):
mount_points = []
for mount_point in container_def_params["mountPoints"]:
mount_points.append(
ecs.MountPoint(
container_path=mount_point["containerPath"],
read_only=mount_point["readOnly"],
source_volume=epheremal_storage_volume_name,
)
)
container.add_mount_points(*mount_points)
except Exception as e:
raise Exception("Could not handle Fargate task definition due to:", e)
ecs_service = None
if deploy_web_ingress:
# --- ECS Service ---
try:
ecs_service_name = ECS_SERVICE_NAME
if ECS_USE_FARGATE_SPOT == "True":
use_fargate_spot = "FARGATE_SPOT"
if ECS_USE_FARGATE_SPOT == "False":
use_fargate_spot = "FARGATE"
# Check if service exists - from_service_arn or from_service_name (needs cluster)
try:
# from_service_name is useful if you have the cluster object
ecs_service = ecs.FargateService.from_service_attributes(
self,
"ECSService", # Logical ID
cluster=cluster, # Requires the cluster object
service_name=ecs_service_name,
)
print(f"Using existing ECS service {ecs_service_name}.")
if enable_service_connect:
print(
"Warning: ENABLE_ECS_SERVICE_CONNECT=True but an existing "
"ECS service was imported; enable Service Connect on that "
"service in the ECS console or replace the service via CDK."
)
except Exception:
service_connect_configuration = None
if enable_service_connect:
sc_dns_name = (
ECS_SERVICE_CONNECT_DNS_NAME
or ECS_SERVICE_CONNECT_DISCOVERY_NAME
)
service_connect_configuration = ecs.ServiceConnectProps(
namespace=ECS_SERVICE_CONNECT_NAMESPACE,
services=[
ecs.ServiceConnectService(
port_mapping_name=ECS_SERVICE_CONNECT_PORT_MAPPING_NAME,
discovery_name=ECS_SERVICE_CONNECT_DISCOVERY_NAME,
dns_name=sc_dns_name,
port=int(GRADIO_SERVER_PORT),
)
],
)
# Service will be created with a count of 0, because you haven't yet actually built the initial Docker container with CodeBuild
ecs_service = ecs.FargateService(
self,
"ECSService", # Logical ID
service_name=ecs_service_name, # Explicit resource name
platform_version=ecs.FargatePlatformVersion.LATEST,
capacity_provider_strategies=[
ecs.CapacityProviderStrategy(
capacity_provider=use_fargate_spot, base=0, weight=1
)
],
cluster=cluster,
task_definition=fargate_task_definition, # Link to TD
security_groups=[ecs_security_group], # Link to SG
vpc_subnets=ec2.SubnetSelection(
subnets=self.private_subnets
), # Link to subnets
min_healthy_percent=0,
max_healthy_percent=600,
desired_count=0,
availability_zone_rebalancing=ecs_availability_zone_rebalancing(
ECS_AVAILABILITY_ZONE_REBALANCING
),
service_connect_configuration=service_connect_configuration,
)
print("Successfully created new ECS service")
# Note: Auto-scaling setup would typically go here if needed for the service
except Exception as e:
raise Exception("Could not handle ECS service due to:", e)
if enable_pi_agent:
try:
pi_ecs_service, pi_ecs_security_group, _pi_task_def = (
create_pi_agent_ecs_resources(
self,
"PiAgent",
vpc=vpc,
cluster=cluster,
private_subnets=self.private_subnets,
pi_ecr_image_uri=pi_ecr_image_loc,
container_name=ECR_PI_REPO_NAME,
task_role=task_role,
execution_role=execution_role,
config_bucket=bucket,
pi_agent_env_s3_key=PI_AGENT_ENV_S3_KEY,
service_name=ECS_PI_SERVICE_NAME,
task_family=ECS_PI_TASK_DEFINITION_NAME,
security_group_name=ECS_PI_SECURITY_GROUP_NAME,
log_group_name=ECS_PI_LOG_GROUP_NAME,
cpu=int(ECS_PI_TASK_CPU_SIZE),
memory_mib=int(ECS_PI_TASK_MEMORY_SIZE),
pi_gradio_port=int(PI_GRADIO_PORT),
service_connect_namespace=ECS_SERVICE_CONNECT_NAMESPACE,
service_connect_discovery_name=ECS_SERVICE_CONNECT_DISCOVERY_NAME,
main_app_port=int(GRADIO_SERVER_PORT),
use_fargate_spot=use_fargate_spot,
pi_root_path=pi_alb_root_path_for_container(
PI_ALB_PATH_PREFIX_NORMALIZED, PI_ALB_ROUTING
),
)
)
ecs_security_group.add_ingress_rule(
peer=pi_ecs_security_group,
connection=ec2_port_gradio_server_port,
description="Pi agent (Service Connect) to main summarisation app",
)
print("Pi agent ECS service defined.")
except Exception as e:
raise Exception("Could not handle Pi agent ECS service due to:", e)
if ENABLE_S3_BATCH_ECS_TRIGGER == "True":
try:
batch_subnet_ids = [s.subnet_id for s in self.private_subnets]
if not batch_subnet_ids:
batch_subnet_ids = [s.subnet_id for s in self.public_subnets]
if not batch_subnet_ids:
raise ValueError(
"S3 batch ECS trigger requires at least one public or "
"private subnet."
)
lambda_asset_dir = os.path.join(
os.path.dirname(__file__), "config", "lambda"
)
batch_lambda = create_s3_batch_ecs_trigger_lambda(
self,
"S3BatchEcsTrigger",
function_name=S3_BATCH_LAMBDA_FUNCTION_NAME or None,
lambda_asset_path=lambda_asset_dir,
output_bucket=output_bucket,
config_bucket=bucket,
cluster_name=CLUSTER_NAME,
task_definition_arn=fargate_task_definition.task_definition_arn,
container_name=full_ecr_repo_name,
subnet_ids=batch_subnet_ids,
security_group_id=ecs_security_group.security_group_id,
execution_role=execution_role,
task_role=task_role,
env_prefix=S3_BATCH_ENV_PREFIX,
env_suffix=S3_BATCH_ENV_SUFFIX,
input_prefix=S3_BATCH_INPUT_PREFIX,
config_prefix=S3_BATCH_CONFIG_PREFIX,
default_params_key=S3_BATCH_DEFAULT_PARAMS_KEY,
general_env_prefix=S3_BATCH_GENERAL_ENV_PREFIX,
default_task_type="extract",
assign_public_ip=not bool(self.private_subnets),
)
CfnOutput(
self,
"BatchEcsTriggerLambdaArn",
value=batch_lambda.function_arn,
description="Lambda ARN for S3-triggered batch ECS tasks",
)
CfnOutput(
self,
"BatchJobEnvPrefix",
value=f"s3://{output_bucket.bucket_name}/{S3_BATCH_ENV_PREFIX}",
description="Upload job .env files here to start batch topic-modelling tasks",
)
CfnOutput(
self,
"BatchInputPrefix",
value=f"s3://{output_bucket.bucket_name}/{S3_BATCH_INPUT_PREFIX}",
description="Upload consultation spreadsheets and other input files for batch jobs",
)
CfnOutput(
self,
"BatchEcsTriggerLambdaName",
value=batch_lambda.function_name,
description="Lambda that starts ECS batch tasks on job .env upload",
)
if enable_headless:
seed_asset_dir = os.path.join(
os.path.dirname(__file__), "config", "headless_s3_seed"
)
create_headless_s3_batch_seed(
self,
"HeadlessBatchS3Seed",
destination_bucket=output_bucket,
seed_asset_directory=seed_asset_dir,
s3_outputs_bucket_name=output_bucket.bucket_name,
)
if (
enable_headless
and ENABLE_HEADLESS_OUTPUT_NOTIFICATIONS == "True"
):
create_headless_output_notifications(
self,
"HeadlessOutputNotifications",
output_bucket=output_bucket,
output_prefix=HEADLESS_OUTPUT_S3_PREFIX,
notify_email=HEADLESS_OUTPUT_NOTIFY_EMAIL,
iam_user_name=HEADLESS_OUTPUT_IAM_USER_NAME,
metric_filter_id=HEADLESS_OUTPUT_S3_METRIC_FILTER_ID,
sns_topic_name=HEADLESS_OUTPUT_SNS_TOPIC_NAME,
alarm_name=HEADLESS_OUTPUT_ALARM_NAME,
kms_key_arn=shared_kms_key_arn,
)
print(
"Headless output notifications enabled: S3 PutRequests "
f"alarm -> SNS ({HEADLESS_OUTPUT_NOTIFY_EMAIL})."
)
print("S3 batch ECS trigger Lambda defined.")
except Exception as e:
raise Exception("Could not handle S3 batch ECS trigger due to:", e)
if deploy_web_ingress:
# --- ALB TARGET GROUPS AND LISTENERS ---
# This section should primarily define the resources if they are managed by this stack.
# CDK handles adding/removing targets and actions on updates.
# If they might pre-exist outside the stack, you need lookups.
cookie_duration = Duration.hours(8)
target_group_name = ALB_TARGET_GROUP_NAME # Explicit resource name
cloudfront_distribution_url = "cloudfront_placeholder.net" # Need to replace this afterwards with the actual cloudfront_distribution.domain_name
cloudfront_http_rule_priority = (
PI_ALB_LISTENER_RULE_PRIORITY
+ (pi_listener_rule_count(PI_ALB_ROUTING) if enable_pi_agent else 0)
if enable_pi_agent
else 1
)
https_listener = None
try:
# --- CREATING TARGET GROUPS AND ADDING THE CLOUDFRONT LISTENER RULE ---
target_group = elbv2.ApplicationTargetGroup(
self,
"AppTargetGroup", # Logical ID
target_group_name=target_group_name, # Explicit resource name
port=int(GRADIO_SERVER_PORT), # Ensure port is int
protocol=elbv2.ApplicationProtocol.HTTP,
targets=[ecs_service], # Link to ECS Service
stickiness_cookie_duration=cookie_duration,
vpc=vpc, # Target Groups need VPC
)
print(f"ALB target group {target_group_name} defined.")
# First HTTP
listener_port = 80
# Check if Listener exists - from_listener_arn or lookup by port/ALB
http_listener = alb.add_listener(
"HttpListener", # Logical ID
port=listener_port,
open=False, # Be cautious with open=True, usually restrict source SG
)
print(f"ALB listener on port {listener_port} defined.")
if ACM_SSL_CERTIFICATE_ARN:
http_listener.add_action(
"DefaultAction", # Logical ID for the default action
action=elbv2.ListenerAction.redirect(
protocol="HTTPS",
host="#{host}",
port="443",
path="/#{path}",
query="#{query}",
),
)
else:
if USE_CLOUDFRONT == "True":
# The following default action can be added for the listener after a host header rule is added to the listener manually in the Console as suggested in the above comments.
http_listener.add_action(
"DefaultAction", # Logical ID for the default action
action=elbv2.ListenerAction.fixed_response(
status_code=403,
content_type="text/plain",
message_body="Access denied",
),
)
# Add the Listener Rule for the specific CloudFront Host Header
http_listener.add_action(
"CloudFrontHostHeaderRule",
action=elbv2.ListenerAction.forward(
target_groups=[target_group],
stickiness_duration=cookie_duration,
),
priority=cloudfront_http_rule_priority,
conditions=[
elbv2.ListenerCondition.host_headers(
[cloudfront_distribution_url]
) # May have to redefine url in console afterwards if not specified in config file
],
)
else:
# Add the Listener Rule for the specific CloudFront Host Header
http_listener.add_action(
"CloudFrontHostHeaderRule",
action=elbv2.ListenerAction.forward(
target_groups=[target_group],
stickiness_duration=cookie_duration,
),
priority=cloudfront_http_rule_priority,
)
print("Added targets and actions to ALB HTTP listener.")
# Now the same for HTTPS if you have an ACM certificate
if ACM_SSL_CERTIFICATE_ARN:
listener_port_https = 443
# Check if Listener exists - from_listener_arn or lookup by port/ALB
https_listener = add_alb_https_listener_with_cert(
self,
"MyHttpsListener", # Logical ID for the HTTPS listener
alb,
acm_certificate_arn=ACM_SSL_CERTIFICATE_ARN,
default_target_group=target_group,
enable_cognito_auth=True,
cognito_user_pool=user_pool,
cognito_user_pool_client=user_pool_client,
cognito_user_pool_domain=user_pool_domain,
listener_open_to_internet=True,
stickiness_cookie_duration=cookie_duration,
)
if https_listener:
CfnOutput(
self,
"HttpsListenerArn",
value=https_listener.listener_arn,
)
print(f"ALB listener on port {listener_port_https} defined.")
# if USE_CLOUDFRONT == 'True':
# # Add default action to the listener
# https_listener.add_action(
# "DefaultAction", # Logical ID for the default action
# action=elbv2.ListenerAction.fixed_response(
# status_code=403,
# content_type="text/plain",
# message_body="Access denied",
# ),
# )
# # Add the Listener Rule for the specific CloudFront Host Header
# https_listener.add_action(
# "CloudFrontHostHeaderRuleHTTPS",
# action=elbv2.ListenerAction.forward(target_groups=[target_group],stickiness_duration=cookie_duration),
# priority=1, # Example priority. Adjust as needed. Lower is evaluated first.
# conditions=[
# elbv2.ListenerCondition.host_headers([cloudfront_distribution_url])
# ]
# )
# else:
# https_listener.add_action(
# "CloudFrontHostHeaderRuleHTTPS",
# action=elbv2.ListenerAction.forward(target_groups=[target_group],stickiness_duration=cookie_duration))
print("Added targets and actions to ALB HTTPS listener.")
if enable_pi_agent and pi_ecs_service and alb_security_group:
pi_tg_name = PI_ALB_TARGET_GROUP_NAME
if len(pi_tg_name) > 32:
pi_tg_name = pi_tg_name[-32:]
_pi_public_urls = format_pi_public_urls(
routing_mode=PI_ALB_ROUTING,
path_prefix=PI_ALB_PATH_PREFIX_NORMALIZED,
host_header=PI_ALB_HOST_HEADER,
cloudfront_domain=(
CLOUDFRONT_DOMAIN if USE_CLOUDFRONT == "True" else ""
),
use_https=bool(ACM_SSL_CERTIFICATE_ARN),
)
attach_pi_agent_to_shared_alb(
self,
"PiAgent",
vpc=vpc,
alb_security_group=alb_security_group,
pi_security_group=pi_ecs_security_group,
pi_service=pi_ecs_service,
pi_port=int(PI_GRADIO_PORT),
routing_mode=PI_ALB_ROUTING,
path_prefix=PI_ALB_PATH_PREFIX_NORMALIZED,
pi_host_header=PI_ALB_HOST_HEADER.strip(),
listener_rule_priority=PI_ALB_LISTENER_RULE_PRIORITY,
target_group_name=pi_tg_name,
stickiness_cookie_duration=cookie_duration,
https_listener=https_listener,
http_listener=http_listener,
acm_certificate_arn=ACM_SSL_CERTIFICATE_ARN or "",
enable_cognito_auth=bool(ACM_SSL_CERTIFICATE_ARN),
cognito_user_pool=user_pool,
cognito_user_pool_client=user_pool_client,
cognito_user_pool_domain=user_pool_domain,
)
pi_public_url = _pi_public_urls[0] if _pi_public_urls else ""
CfnOutput(
self,
"PiPublicUrl",
value=pi_public_url,
description="Primary public URL for Pi agent UI (path and/or host ALB rules)",
)
if len(_pi_public_urls) > 1:
CfnOutput(
self,
"PiPublicUrls",
value=", ".join(_pi_public_urls),
description="All configured Pi UI entry URLs",
)
CfnOutput(
self,
"PiAlbPathPrefix",
value=PI_ALB_PATH_PREFIX_NORMALIZED,
description="ALB path prefix for Pi when PI_ALB_ROUTING includes path",
)
CfnOutput(
self,
"PiAgentServiceName",
value=ECS_PI_SERVICE_NAME,
)
sc_backend = (
f"http://{ECS_SERVICE_CONNECT_DISCOVERY_NAME}:"
f"{GRADIO_SERVER_PORT}"
)
CfnOutput(
self,
"PiDocSummarisationBackendUrl",
value=sc_backend,
description="DOC_SUMMARISATION_GRADIO_URL set on Pi tasks (Service Connect)",
)
print(
"Pi agent attached to shared ALB "
f"(routing={PI_ALB_ROUTING}, urls={', '.join(_pi_public_urls)})."
)
except Exception as e:
raise Exception(
"Could not handle ALB target groups and listeners due to:", e
)
if not enable_headless:
# Create WAF to attach to load balancer
try:
web_acl_name = LOAD_BALANCER_WEB_ACL_NAME
if get_context_bool(f"exists:{web_acl_name}"):
# Lookup WAF ACL by ARN from context
web_acl_arn = get_context_str(f"arn:{web_acl_name}")
if not web_acl_arn:
raise ValueError(
f"Context value 'arn:{web_acl_name}' is required if Web ACL exists."
)
web_acl = create_web_acl_with_common_rules(
self, web_acl_name, waf_scope="REGIONAL"
) # Assuming it takes scope and name
print(f"Handled ALB WAF web ACL {web_acl_name}.")
else:
web_acl = create_web_acl_with_common_rules(
self, web_acl_name, waf_scope="REGIONAL"
) # Assuming it takes scope and name
print(f"Created ALB WAF web ACL {web_acl_name}.")
wafv2.CfnWebACLAssociation(
self,
id="alb_waf_association",
resource_arn=alb.load_balancer_arn,
web_acl_arn=web_acl.attr_arn,
)
except Exception as e:
raise Exception("Could not handle create ALB WAF web ACL due to:", e)
# --- Outputs for other stacks/regions ---
self.params = dict()
self.params["alb_arn_output"] = alb.load_balancer_arn
if use_express_ingress:
self.params["alb_security_group_id"] = express_alb_security_group_id
else:
self.params["alb_security_group_id"] = (
alb_security_group.security_group_id
)
self.params["alb_dns_name"] = alb.load_balancer_dns_name
CfnOutput(
self,
"AlbArnOutput",
value=alb.load_balancer_arn,
description="ARN of the Application Load Balancer",
export_name=f"{self.stack_name}-AlbArn",
) # Export name must be unique within the account/region
CfnOutput(
self,
"AlbSecurityGroupIdOutput",
value=(
express_alb_security_group_id
if use_express_ingress
else alb_security_group.security_group_id
),
description="ID of the ALB's Security Group",
export_name=f"{self.stack_name}-AlbSgId",
)
CfnOutput(self, "ALBName", value=load_balancer_name)
CfnOutput(self, "RegionalAlbDnsName", value=alb.load_balancer_dns_name)
else:
self.params = dict()
CfnOutput(
self,
"HeadlessDeploymentMode",
value="True",
description="Stack deployed for S3-triggered direct-mode batch tasks only",
)
CfnOutput(
self,
"ECSClusterName",
value=CLUSTER_NAME,
description="ECS cluster used for one-shot Fargate batch tasks",
)
CfnOutput(
self,
"EcsBatchLogGroup",
value=ECS_LOG_GROUP_NAME,
description=(
"CloudWatch log group for batch tasks (streams appear only after "
"the container starts; init failures may have no stream)"
),
)
if not enable_headless and user_pool is not None:
CfnOutput(self, "CognitoPoolId", value=user_pool.user_pool_id)
# Add other outputs if needed
CfnOutput(self, "ECRRepoUri", value=ecr_repo.repository_uri)
if enable_service_connect:
sc_host = ECS_SERVICE_CONNECT_DNS_NAME or ECS_SERVICE_CONNECT_DISCOVERY_NAME
sc_base = f"http://{sc_host}:{GRADIO_SERVER_PORT}"
CfnOutput(
self,
"ServiceConnectHttpBaseUrl",
value=sc_base,
description="Base URL for other ECS services in this cluster (Service Connect)",
)
CfnOutput(
self,
"ServiceConnectAgentApiUrl",
value=f"{sc_base}/agent",
description="FastAPI Agent API prefix (when RUN_FASTAPI=True in app_config.env)",
)
CfnOutput(
self,
"ServiceConnectNamespace",
value=ECS_SERVICE_CONNECT_NAMESPACE,
)
# --- CLOUDFRONT DISTRIBUTION in separate stack (us-east-1 required) ---
class CdkStackCloudfront(Stack):
def __init__(
self,
scope: Construct,
construct_id: str,
alb_arn: str,
alb_sec_group_id: str,
alb_dns_name: str,
**kwargs,
) -> None:
super().__init__(scope, construct_id, **kwargs)
# --- Helper to get context values ---
def get_context_bool(key: str, default: bool = False) -> bool:
return self.node.try_get_context(key) or default
def get_context_str(key: str, default: str = None) -> str:
return self.node.try_get_context(key) or default
def get_context_dict(scope: Construct, key: str, default: dict = None) -> dict:
return scope.node.try_get_context(key) or default
resource_removal_policy = managed_resource_removal_policy()
print(f"CloudFront Stack: Received ALB ARN: {alb_arn}")
print(f"CloudFront Stack: Received ALB Security Group ID: {alb_sec_group_id}")
if not alb_arn:
raise ValueError("ALB ARN must be provided to CloudFront stack")
if not alb_sec_group_id:
raise ValueError(
"ALB Security Group ID must be provided to CloudFront stack"
)
# 2. Import the ALB using its ARN
# This imports an existing ALB as a construct in the CloudFront stack's context.
# CloudFormation will understand this reference at deploy time.
alb = elbv2.ApplicationLoadBalancer.from_application_load_balancer_attributes(
self,
"ImportedAlb",
load_balancer_arn=alb_arn,
security_group_id=alb_sec_group_id,
load_balancer_dns_name=alb_dns_name,
)
try:
web_acl_name = WEB_ACL_NAME
if get_context_bool(f"exists:{web_acl_name}"):
# Lookup WAF ACL by ARN from context
web_acl_arn = get_context_str(f"arn:{web_acl_name}")
if not web_acl_arn:
raise ValueError(
f"Context value 'arn:{web_acl_name}' is required if Web ACL exists."
)
web_acl = create_web_acl_with_common_rules(
self, web_acl_name
) # Assuming it takes scope and name
print(f"Handled Cloudfront WAF web ACL {web_acl_name}.")
else:
web_acl = create_web_acl_with_common_rules(
self, web_acl_name
) # Assuming it takes scope and name
print(f"Created Cloudfront WAF web ACL {web_acl_name}.")
# Add ALB as CloudFront Origin
origin = origins.LoadBalancerV2Origin(
alb, # Use the created or looked-up ALB object
custom_headers={CUSTOM_HEADER: CUSTOM_HEADER_VALUE},
origin_shield_enabled=False,
protocol_policy=cloudfront.OriginProtocolPolicy.HTTP_ONLY,
)
if CLOUDFRONT_GEO_RESTRICTION:
geo_restrict = cloudfront.GeoRestriction.allowlist(
CLOUDFRONT_GEO_RESTRICTION
)
else:
geo_restrict = None
response_headers_policy = None
if CLOUDFRONT_ENABLE_SECURE_RESPONSE_HEADERS == "True":
app_origin, cognito_login_url = resolve_cloudfront_csp_urls(
cognito_redirection_url=COGNITO_REDIRECTION_URL,
cloudfront_domain=CLOUDFRONT_DOMAIN,
cognito_user_pool_domain_prefix=COGNITO_USER_POOL_DOMAIN_PREFIX,
aws_region=AWS_REGION,
cognito_user_pool_login_url=COGNITO_USER_POOL_LOGIN_URL,
ssl_certificate_domain=SSL_CERTIFICATE_DOMAIN,
)
policy_name = f"{CDK_PREFIX}SecureResponseHeaders"[:128]
response_headers_policy = (
create_secure_cloudfront_response_headers_policy(
self,
"SecureResponseHeadersPolicy",
policy_name=policy_name,
app_origin=app_origin,
cognito_login_url=cognito_login_url,
)
)
print(
"CloudFront secure response headers: "
f"app_origin={app_origin}, cognito_login_url={cognito_login_url}"
)
default_behavior = cloudfront.BehaviorOptions(
origin=origin,
viewer_protocol_policy=cloudfront.ViewerProtocolPolicy.REDIRECT_TO_HTTPS,
allowed_methods=cloudfront.AllowedMethods.ALLOW_ALL,
cache_policy=cloudfront.CachePolicy.CACHING_DISABLED,
origin_request_policy=cloudfront.OriginRequestPolicy.ALL_VIEWER,
response_headers_policy=response_headers_policy,
)
cloudfront_distribution = cloudfront.Distribution(
self,
"CloudFrontDistribution", # Logical ID
comment=CLOUDFRONT_DISTRIBUTION_NAME, # Use name as comment for easier identification
geo_restriction=geo_restrict,
default_behavior=default_behavior,
web_acl_id=web_acl.attr_arn,
)
cloudfront_distribution.apply_removal_policy(resource_removal_policy)
print(f"Cloudfront distribution {CLOUDFRONT_DISTRIBUTION_NAME} defined.")
except Exception as e:
raise Exception("Could not handle Cloudfront distribution due to:", e)
# --- Outputs ---
CfnOutput(
self, "CloudFrontDistributionURL", value=cloudfront_distribution.domain_name
)