""" AWS Athena query interface for data lake access. Provides methods to execute SQL queries against Athena and retrieve results as pandas DataFrames. """ from typing import Optional, List, Dict, Any import time import pandas as pd import boto3 from botocore.exceptions import ClientError from urllib.parse import urlparse import io from .config import DataLakeConfig from .logger import setup_logger logger = setup_logger(__name__) class AthenaQuery: """ AWS Athena query interface. Executes SQL queries against Athena and retrieves results as pandas DataFrames. Handles query execution, polling, and result retrieval. """ def __init__(self, config: DataLakeConfig): """ Initialize Athena query interface. Args: config: DataLakeConfig instance with Athena configuration """ self.config = config session = config.get_boto3_session() self.athena_client = session.client('athena', region_name=config.region) self.s3_client = session.client('s3', region_name=config.region) logger.info(f"Initialized Athena client for database: {config.database_name}") def execute_query( self, query: str, wait: bool = True, timeout: int = 300, ) -> Optional[str]: """ Execute SQL query in Athena. Args: query: SQL query string wait: If True, wait for query to complete and return execution ID timeout: Maximum time to wait for query completion (seconds) Returns: Query execution ID (if wait=False) or execution ID after completion (if wait=True) Raises: ClientError: If query execution fails TimeoutError: If query exceeds timeout """ query_execution_config = { 'Database': self.config.database_name, } # OutputLocation should be in ResultConfiguration result_configuration = { 'OutputLocation': self.config.s3_output_location, } logger.debug(f"Executing query: {query[:100]}...") try: start_params = { 'QueryString': query, 'QueryExecutionContext': query_execution_config, 'ResultConfiguration': result_configuration, } # WorkGroup is a separate parameter, not in QueryExecutionContext if self.config.workgroup: start_params['WorkGroup'] = self.config.workgroup response = self.athena_client.start_query_execution(**start_params) execution_id = response['QueryExecutionId'] logger.info(f"Query started with execution ID: {execution_id}") if not wait: return execution_id # Wait for query to complete return self._wait_for_completion(execution_id, timeout) except ClientError as e: logger.error(f"Query execution failed: {e}") raise def _wait_for_completion(self, execution_id: str, timeout: int = 300) -> str: """ Wait for query execution to complete. 
    def _wait_for_completion(self, execution_id: str, timeout: int = 300) -> str:
        """
        Wait for a query execution to complete.

        Args:
            execution_id: Query execution ID
            timeout: Maximum time to wait (seconds)

        Returns:
            Execution ID

        Raises:
            TimeoutError: If the query exceeds the timeout
            RuntimeError: If the query fails or is cancelled
        """
        start_time = time.time()
        while True:
            response = self.athena_client.get_query_execution(QueryExecutionId=execution_id)
            status = response['QueryExecution']['Status']['State']

            if status == 'SUCCEEDED':
                logger.info(f"Query {execution_id} completed successfully")
                return execution_id
            elif status == 'FAILED':
                reason = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error')
                logger.error(f"Query {execution_id} failed: {reason}")
                raise RuntimeError(f"Query failed: {reason}")
            elif status == 'CANCELLED':
                logger.warning(f"Query {execution_id} was cancelled")
                raise RuntimeError("Query was cancelled")

            elapsed = time.time() - start_time
            if elapsed > timeout:
                raise TimeoutError(f"Query {execution_id} exceeded timeout of {timeout} seconds")

            time.sleep(1)  # Poll every second

    def get_query_results(self, execution_id: str) -> pd.DataFrame:
        """
        Get query results as a pandas DataFrame.

        Reads the result CSV directly from S3 when possible, which is far
        faster for large result sets than paginated API calls, and falls
        back to the get_query_results API otherwise.

        Args:
            execution_id: Query execution ID

        Returns:
            DataFrame with query results

        Raises:
            ClientError: If results cannot be retrieved
        """
        logger.debug(f"Retrieving results for execution {execution_id}")

        # Prefer the direct S3 read (much faster for large result sets)
        try:
            return self._get_results_from_s3(execution_id)
        except Exception as e:
            logger.debug(f"Failed to read from S3, falling back to API: {e}")

        # Fall back to the paginated API for backward compatibility
        return self._get_results_from_api(execution_id)

    def _get_results_from_s3(self, execution_id: str) -> pd.DataFrame:
        """
        Get query results directly from the S3 CSV file.

        This is far faster than paginated API calls because:
        - One S3 object read replaces hundreds or thousands of API calls
        - pandas parses the CSV in optimized C code
        - There is no row-by-row Python processing overhead

        Args:
            execution_id: Query execution ID

        Returns:
            DataFrame with query results

        Raises:
            Exception: If the S3 read fails
        """
        # Get query execution details to find the S3 result location
        response = self.athena_client.get_query_execution(QueryExecutionId=execution_id)
        result_location = response['QueryExecution']['ResultConfiguration']['OutputLocation']

        # Parse the S3 URI: s3://bucket/path/to/file.csv
        parsed = urlparse(result_location)
        bucket = parsed.netloc
        key = parsed.path.lstrip('/')

        logger.debug(f"Reading results from s3://{bucket}/{key}")

        # Read the CSV directly from S3
        obj = self.s3_client.get_object(Bucket=bucket, Key=key)
        csv_content = obj['Body'].read()

        # Read everything as strings (keeping empty fields as '') so type
        # handling matches the API code path, then parse each value.
        df = pd.read_csv(io.BytesIO(csv_content), dtype=str, keep_default_na=False)
        for col in df.columns:
            df[col] = df[col].apply(self._parse_value)

        logger.info(f"Retrieved {len(df)} rows from S3 for query {execution_id}")
        return df
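    # Design note: pandas could read the result URI directly, e.g.
    #
    #     df = pd.read_csv(result_location, dtype=str, keep_default_na=False)
    #
    # but that route requires the optional s3fs dependency; going through the
    # already-constructed boto3 client reuses the session's credentials and
    # keeps the dependency footprint unchanged.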
    def _get_results_from_api(self, execution_id: str) -> pd.DataFrame:
        """
        Get query results using paginated API calls (fallback method).

        This is the original implementation, kept for backward compatibility
        when the S3 read fails.

        Args:
            execution_id: Query execution ID

        Returns:
            DataFrame with query results

        Raises:
            ClientError: If results cannot be retrieved
        """
        logger.debug(f"Using API method for execution {execution_id}")

        paginator = self.athena_client.get_paginator('get_query_results')
        pages = paginator.paginate(QueryExecutionId=execution_id)

        rows = []
        column_names = None
        first_page = True

        for page in pages:
            result_set = page['ResultSet']

            # Get column names from the first page
            if column_names is None:
                column_names = [col['Name'] for col in result_set['ResultSetMetadata']['ColumnInfo']]

            data_rows = result_set['Rows']
            # Athena includes the header row only in the first page; skip it
            # there, but keep every row on subsequent pages.
            if first_page:
                data_rows = data_rows[1:]
                first_page = False

            for row in data_rows:
                values = [self._parse_value(cell.get('VarCharValue', '')) for cell in row['Data']]
                rows.append(values)

        if not rows:
            logger.warning(f"No results returned for execution {execution_id}")
            return pd.DataFrame(columns=column_names or [])

        df = pd.DataFrame(rows, columns=column_names)
        logger.info(f"Retrieved {len(df)} rows from query {execution_id}")
        return df

    def _parse_value(self, value: str) -> Any:
        """
        Parse a string value to an appropriate Python type.

        Args:
            value: String value from an Athena result

        Returns:
            Parsed value (int, float, bool, None, or str)
        """
        if value is None or value == '':
            return None

        # Try to parse as a number
        try:
            if '.' in value:
                return float(value)
            return int(value)
        except ValueError:
            pass

        # Try to parse as a boolean
        if value.lower() in ('true', 'false'):
            return value.lower() == 'true'

        return value

    def query_to_dataframe(
        self,
        query: str,
        timeout: int = 300,
    ) -> pd.DataFrame:
        """
        Execute a query and return the results as a DataFrame.

        Convenience method combining execute_query and get_query_results.

        Args:
            query: SQL query string
            timeout: Maximum time to wait for query completion (seconds)

        Returns:
            DataFrame with query results
        """
        execution_id = self.execute_query(query, wait=True, timeout=timeout)
        return self.get_query_results(execution_id)

    def list_tables(self, schema: Optional[str] = None) -> List[str]:
        """
        List tables in the database.

        Args:
            schema: Optional schema name (defaults to the configured database)

        Returns:
            List of table names
        """
        if schema is None:
            schema = self.config.database_name

        # Note: the schema name is interpolated directly into the SQL, so it
        # must come from trusted configuration, not untrusted user input.
        query = f"""
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = '{schema}'
        ORDER BY table_name
        """
        try:
            df = self.query_to_dataframe(query)
            return df['table_name'].tolist() if not df.empty else []
        except Exception as e:
            logger.error(f"Failed to list tables: {e}")
            return []

    def describe_table(self, table_name: str, schema: Optional[str] = None) -> pd.DataFrame:
        """
        Get a table's schema/columns.

        Args:
            table_name: Table name
            schema: Optional schema name (defaults to the configured database)

        Returns:
            DataFrame with column information (column_name, data_type, is_nullable)
        """
        if schema is None:
            schema = self.config.database_name

        # Note: identifiers are interpolated directly into the SQL; see list_tables.
        query = f"""
        SELECT column_name, data_type, is_nullable
        FROM information_schema.columns
        WHERE table_schema = '{schema}' AND table_name = '{table_name}'
        ORDER BY ordinal_position
        """
        try:
            return self.query_to_dataframe(query)
        except Exception as e:
            logger.error(f"Failed to describe table {table_name}: {e}")
            return pd.DataFrame()
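
# Example end-to-end usage (a minimal sketch; how DataLakeConfig is
# constructed depends on this package's config loading -- from_env() below
# is a hypothetical helper, not a documented constructor):
#
#     from .config import DataLakeConfig
#
#     config = DataLakeConfig.from_env()  # hypothetical
#     athena = AthenaQuery(config)
#     print(athena.list_tables())
#     df = athena.query_to_dataframe("SELECT * FROM some_table LIMIT 100")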