Spaces:
Sleeping
Sleeping
| """ | |
| AWS Athena query interface for data lake access. | |
| Provides methods to execute SQL queries against Athena and retrieve results | |
| as pandas DataFrames. | |
| """ | |
| from typing import Optional, List, Dict, Any | |
| import time | |
| import pandas as pd | |
| import boto3 | |
| from botocore.exceptions import ClientError | |
| from urllib.parse import urlparse | |
| import io | |
| from .config import DataLakeConfig | |
| from .logger import setup_logger | |
| logger = setup_logger(__name__) | |
class AthenaQuery:
    """
    AWS Athena query interface.

    Executes SQL queries against Athena and retrieves results as pandas
    DataFrames, handling query submission, status polling, and result
    retrieval.
    """

    def __init__(self, config: DataLakeConfig):
        """
        Initialize the Athena query interface.

        Args:
            config: DataLakeConfig instance with Athena configuration
        """
        self.config = config
        boto_session = config.get_boto3_session()
        # Both clients share the session and region from the config.
        self.s3_client = boto_session.client('s3', region_name=config.region)
        self.athena_client = boto_session.client('athena', region_name=config.region)
        logger.info(f"Initialized Athena client for database: {config.database_name}")
| def execute_query( | |
| self, | |
| query: str, | |
| wait: bool = True, | |
| timeout: int = 300, | |
| ) -> Optional[str]: | |
| """ | |
| Execute SQL query in Athena. | |
| Args: | |
| query: SQL query string | |
| wait: If True, wait for query to complete and return execution ID | |
| timeout: Maximum time to wait for query completion (seconds) | |
| Returns: | |
| Query execution ID (if wait=False) or execution ID after completion (if wait=True) | |
| Raises: | |
| ClientError: If query execution fails | |
| TimeoutError: If query exceeds timeout | |
| """ | |
| query_execution_config = { | |
| 'Database': self.config.database_name, | |
| } | |
| # OutputLocation should be in ResultConfiguration | |
| result_configuration = { | |
| 'OutputLocation': self.config.s3_output_location, | |
| } | |
| logger.debug(f"Executing query: {query[:100]}...") | |
| try: | |
| start_params = { | |
| 'QueryString': query, | |
| 'QueryExecutionContext': query_execution_config, | |
| 'ResultConfiguration': result_configuration, | |
| } | |
| # WorkGroup is a separate parameter, not in QueryExecutionContext | |
| if self.config.workgroup: | |
| start_params['WorkGroup'] = self.config.workgroup | |
| response = self.athena_client.start_query_execution(**start_params) | |
| execution_id = response['QueryExecutionId'] | |
| logger.info(f"Query started with execution ID: {execution_id}") | |
| if not wait: | |
| return execution_id | |
| # Wait for query to complete | |
| return self._wait_for_completion(execution_id, timeout) | |
| except ClientError as e: | |
| logger.error(f"Query execution failed: {e}") | |
| raise | |
| def _wait_for_completion(self, execution_id: str, timeout: int = 300) -> str: | |
| """ | |
| Wait for query execution to complete. | |
| Args: | |
| execution_id: Query execution ID | |
| timeout: Maximum time to wait (seconds) | |
| Returns: | |
| Execution ID | |
| Raises: | |
| TimeoutError: If query exceeds timeout | |
| RuntimeError: If query fails | |
| """ | |
| start_time = time.time() | |
| while True: | |
| response = self.athena_client.get_query_execution(QueryExecutionId=execution_id) | |
| status = response['QueryExecution']['Status']['State'] | |
| if status == 'SUCCEEDED': | |
| logger.info(f"Query {execution_id} completed successfully") | |
| return execution_id | |
| elif status == 'FAILED': | |
| reason = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error') | |
| logger.error(f"Query {execution_id} failed: {reason}") | |
| raise RuntimeError(f"Query failed: {reason}") | |
| elif status == 'CANCELLED': | |
| logger.warning(f"Query {execution_id} was cancelled") | |
| raise RuntimeError("Query was cancelled") | |
| elapsed = time.time() - start_time | |
| if elapsed > timeout: | |
| raise TimeoutError(f"Query {execution_id} exceeded timeout of {timeout} seconds") | |
| time.sleep(1) # Poll every second | |
| def get_query_results(self, execution_id: str) -> pd.DataFrame: | |
| """ | |
| Get query results as pandas DataFrame. | |
| Optimized to read directly from S3 for large result sets, which is | |
| exponentially faster than paginated API calls. | |
| Args: | |
| execution_id: Query execution ID | |
| Returns: | |
| DataFrame with query results | |
| Raises: | |
| ClientError: If results cannot be retrieved | |
| """ | |
| logger.debug(f"Retrieving results for execution {execution_id}") | |
| # Try to read from S3 first (much faster for large result sets) | |
| try: | |
| return self._get_results_from_s3(execution_id) | |
| except Exception as e: | |
| logger.debug(f"Failed to read from S3, falling back to API: {e}") | |
| # Fall back to API method for backward compatibility | |
| return self._get_results_from_api(execution_id) | |
| def _get_results_from_s3(self, execution_id: str) -> pd.DataFrame: | |
| """ | |
| Get query results directly from S3 CSV file. | |
| This is exponentially faster than paginated API calls because: | |
| - Single file read vs hundreds/thousands of API calls | |
| - Pandas reads CSV in optimized C code | |
| - No row-by-row Python processing overhead | |
| Args: | |
| execution_id: Query execution ID | |
| Returns: | |
| DataFrame with query results | |
| Raises: | |
| Exception: If S3 read fails | |
| """ | |
| # Get query execution details to find S3 result location | |
| response = self.athena_client.get_query_execution(QueryExecutionId=execution_id) | |
| result_location = response['QueryExecution']['ResultConfiguration']['OutputLocation'] | |
| # Parse S3 URI: s3://bucket/path/to/file.csv | |
| parsed = urlparse(result_location) | |
| bucket = parsed.netloc | |
| key = parsed.path.lstrip('/') | |
| logger.debug(f"Reading results from s3://{bucket}/{key}") | |
| # Read CSV directly from S3 | |
| obj = self.s3_client.get_object(Bucket=bucket, Key=key) | |
| csv_content = obj['Body'].read() | |
| # Parse CSV with pandas (much faster than row-by-row processing) | |
| # Read as strings first to match original API behavior, then parse types | |
| df = pd.read_csv(io.BytesIO(csv_content), dtype=str, keep_default_na=False) | |
| # Apply type parsing to match original behavior | |
| # Convert to string first to handle any edge cases, then parse | |
| for col in df.columns: | |
| df[col] = df[col].astype(str).apply(self._parse_value) | |
| logger.info(f"Retrieved {len(df)} rows from S3 for query {execution_id}") | |
| return df | |
| def _get_results_from_api(self, execution_id: str) -> pd.DataFrame: | |
| """ | |
| Get query results using paginated API calls (fallback method). | |
| This is the original implementation, kept for backward compatibility | |
| when S3 read fails. | |
| Args: | |
| execution_id: Query execution ID | |
| Returns: | |
| DataFrame with query results | |
| Raises: | |
| ClientError: If results cannot be retrieved | |
| """ | |
| logger.debug(f"Using API method for execution {execution_id}") | |
| # Get result set | |
| paginator = self.athena_client.get_paginator('get_query_results') | |
| pages = paginator.paginate(QueryExecutionId=execution_id) | |
| rows = [] | |
| column_names = None | |
| for page in pages: | |
| result_set = page['ResultSet'] | |
| # Get column names from first page | |
| if column_names is None: | |
| column_names = [col['Name'] for col in result_set['ResultSetMetadata']['ColumnInfo']] | |
| # Get data rows (skip header row) | |
| for row in result_set['Rows'][1:]: # Skip header | |
| values = [self._parse_value(cell.get('VarCharValue', '')) | |
| for cell in row['Data']] | |
| rows.append(values) | |
| if not rows: | |
| logger.warning(f"No results returned for execution {execution_id}") | |
| return pd.DataFrame(columns=column_names or []) | |
| df = pd.DataFrame(rows, columns=column_names) | |
| logger.info(f"Retrieved {len(df)} rows from query {execution_id}") | |
| return df | |
| def _parse_value(self, value: str) -> Any: | |
| """ | |
| Parse string value to appropriate Python type. | |
| Args: | |
| value: String value from Athena result | |
| Returns: | |
| Parsed value (int, float, bool, or str) | |
| """ | |
| if value == '' or value is None: | |
| return None | |
| # Try to parse as number | |
| try: | |
| if '.' in value: | |
| return float(value) | |
| return int(value) | |
| except ValueError: | |
| pass | |
| # Try to parse as boolean | |
| if value.lower() in ('true', 'false'): | |
| return value.lower() == 'true' | |
| return value | |
| def query_to_dataframe( | |
| self, | |
| query: str, | |
| timeout: int = 300, | |
| ) -> pd.DataFrame: | |
| """ | |
| Execute query and return results as DataFrame. | |
| Convenience method that combines execute_query and get_query_results. | |
| Args: | |
| query: SQL query string | |
| timeout: Maximum time to wait for query completion (seconds) | |
| Returns: | |
| DataFrame with query results | |
| """ | |
| execution_id = self.execute_query(query, wait=True, timeout=timeout) | |
| return self.get_query_results(execution_id) | |
| def list_tables(self, schema: Optional[str] = None) -> List[str]: | |
| """ | |
| List tables in the database. | |
| Args: | |
| schema: Optional schema name (defaults to database) | |
| Returns: | |
| List of table names | |
| """ | |
| if schema is None: | |
| schema = self.config.database_name | |
| query = f""" | |
| SELECT table_name | |
| FROM information_schema.tables | |
| WHERE table_schema = '{schema}' | |
| ORDER BY table_name | |
| """ | |
| try: | |
| df = self.query_to_dataframe(query) | |
| return df['table_name'].tolist() if not df.empty else [] | |
| except Exception as e: | |
| logger.error(f"Failed to list tables: {e}") | |
| return [] | |
| def describe_table(self, table_name: str, schema: Optional[str] = None) -> pd.DataFrame: | |
| """ | |
| Get table schema/columns. | |
| Args: | |
| table_name: Table name | |
| schema: Optional schema name (defaults to database) | |
| Returns: | |
| DataFrame with column information (column_name, data_type, etc.) | |
| """ | |
| if schema is None: | |
| schema = self.config.database_name | |
| query = f""" | |
| SELECT | |
| column_name, | |
| data_type, | |
| is_nullable | |
| FROM information_schema.columns | |
| WHERE table_schema = '{schema}' | |
| AND table_name = '{table_name}' | |
| ORDER BY ordinal_position | |
| """ | |
| try: | |
| return self.query_to_dataframe(query) | |
| except Exception as e: | |
| logger.error(f"Failed to describe table {table_name}: {e}") | |
| return pd.DataFrame() | |