# NOTE(review): removed non-Python web-viewer artifact lines (uploader name,
# commit message, hash) that preceded the module docstring and broke parsing.
"""
AWS Athena query interface for data lake access.
Provides methods to execute SQL queries against Athena and retrieve results
as pandas DataFrames.
"""
from typing import Optional, List, Dict, Any
import time
import pandas as pd
import boto3
from botocore.exceptions import ClientError
from urllib.parse import urlparse
import io
from .config import DataLakeConfig
from .logger import setup_logger
logger = setup_logger(__name__)
class AthenaQuery:
    """
    AWS Athena query interface.

    Executes SQL queries against Athena and retrieves results as pandas
    DataFrames. Handles query submission, completion polling, and result
    retrieval. Results are read directly from the query's CSV output in S3
    when possible (fast path), falling back to the paginated
    GetQueryResults API.
    """

    def __init__(self, config: DataLakeConfig):
        """
        Initialize Athena query interface.

        Args:
            config: DataLakeConfig instance with Athena configuration
                (database name, S3 output location, region, optional
                workgroup, and a boto3 session factory).
        """
        self.config = config
        session = config.get_boto3_session()
        self.athena_client = session.client('athena', region_name=config.region)
        self.s3_client = session.client('s3', region_name=config.region)
        # Lazy %-args: the message is only rendered if the level is enabled.
        logger.info("Initialized Athena client for database: %s", config.database_name)

    @staticmethod
    def _escape_sql_literal(value: str) -> str:
        """
        Escape single quotes for safe embedding inside a SQL string literal.

        Athena's StartQueryExecution API (as used here) has no bind
        parameters, so identifiers interpolated into information_schema
        queries must at least have quotes doubled to avoid breaking the
        statement or permitting injection.
        """
        return value.replace("'", "''")

    def execute_query(
        self,
        query: str,
        wait: bool = True,
        timeout: int = 300,
    ) -> Optional[str]:
        """
        Execute SQL query in Athena.

        Args:
            query: SQL query string
            wait: If True, block until the query completes
            timeout: Maximum time to wait for query completion (seconds);
                only meaningful when wait=True
        Returns:
            Query execution ID (immediately if wait=False, or after
            successful completion if wait=True)
        Raises:
            ClientError: If query submission fails
            TimeoutError: If the query exceeds the timeout (wait=True)
            RuntimeError: If the query fails or is cancelled (wait=True)
        """
        start_params: Dict[str, Any] = {
            'QueryString': query,
            'QueryExecutionContext': {'Database': self.config.database_name},
            # OutputLocation belongs in ResultConfiguration, not the context.
            'ResultConfiguration': {'OutputLocation': self.config.s3_output_location},
        }
        # WorkGroup is a top-level StartQueryExecution parameter.
        if self.config.workgroup:
            start_params['WorkGroup'] = self.config.workgroup

        logger.debug("Executing query: %s...", query[:100])
        # Keep the try body minimal: only the API call can raise ClientError.
        try:
            response = self.athena_client.start_query_execution(**start_params)
        except ClientError as e:
            logger.error("Query execution failed: %s", e)
            raise

        execution_id = response['QueryExecutionId']
        logger.info("Query started with execution ID: %s", execution_id)
        if not wait:
            return execution_id
        return self._wait_for_completion(execution_id, timeout)

    def _wait_for_completion(self, execution_id: str, timeout: int = 300) -> str:
        """
        Poll until the query execution reaches a terminal state.

        Args:
            execution_id: Query execution ID
            timeout: Maximum time to wait (seconds)
        Returns:
            Execution ID (on SUCCEEDED)
        Raises:
            TimeoutError: If the query exceeds the timeout
            RuntimeError: If the query fails or is cancelled
        """
        start_time = time.time()
        while True:
            response = self.athena_client.get_query_execution(QueryExecutionId=execution_id)
            state = response['QueryExecution']['Status']['State']
            if state == 'SUCCEEDED':
                logger.info("Query %s completed successfully", execution_id)
                return execution_id
            if state == 'FAILED':
                reason = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error')
                logger.error("Query %s failed: %s", execution_id, reason)
                raise RuntimeError(f"Query failed: {reason}")
            if state == 'CANCELLED':
                logger.warning("Query %s was cancelled", execution_id)
                raise RuntimeError("Query was cancelled")
            # Still QUEUED/RUNNING: enforce the timeout, then poll again.
            if time.time() - start_time > timeout:
                raise TimeoutError(f"Query {execution_id} exceeded timeout of {timeout} seconds")
            time.sleep(1)  # Poll every second

    def get_query_results(self, execution_id: str) -> pd.DataFrame:
        """
        Get query results as pandas DataFrame.

        Reads the result CSV directly from S3 when possible, which is far
        faster than paginated API calls for large result sets.

        Args:
            execution_id: Query execution ID
        Returns:
            DataFrame with query results
        Raises:
            ClientError: If results cannot be retrieved by either path
        """
        logger.debug("Retrieving results for execution %s", execution_id)
        try:
            # Fast path: single CSV read from S3.
            return self._get_results_from_s3(execution_id)
        except Exception as e:
            # Deliberately broad: any S3/parse problem falls back to the
            # slower but reliable paginated API path.
            logger.debug("Failed to read from S3, falling back to API: %s", e)
        return self._get_results_from_api(execution_id)

    def _get_results_from_s3(self, execution_id: str) -> pd.DataFrame:
        """
        Get query results directly from the S3 CSV result file.

        Much faster than paginated API calls: one object read, with CSV
        parsing done in pandas' optimized C reader instead of row-by-row
        Python processing.

        Args:
            execution_id: Query execution ID
        Returns:
            DataFrame with query results
        Raises:
            Exception: If the S3 read or CSV parse fails
        """
        # Look up the query's S3 output location.
        response = self.athena_client.get_query_execution(QueryExecutionId=execution_id)
        result_location = response['QueryExecution']['ResultConfiguration']['OutputLocation']

        # Parse S3 URI: s3://bucket/path/to/file.csv
        parsed = urlparse(result_location)
        bucket = parsed.netloc
        key = parsed.path.lstrip('/')
        logger.debug("Reading results from s3://%s/%s", bucket, key)

        obj = self.s3_client.get_object(Bucket=bucket, Key=key)
        csv_content = obj['Body'].read()

        # dtype=str + keep_default_na=False keeps every cell as a plain
        # string (empty string for NULL), matching what the
        # GetQueryResults API returns as VarCharValue.
        df = pd.read_csv(io.BytesIO(csv_content), dtype=str, keep_default_na=False)

        # Cells are already str thanks to dtype=str, so parse types directly
        # (the previous .astype(str) round-trip was redundant).
        for col in df.columns:
            df[col] = df[col].apply(self._parse_value)

        logger.info("Retrieved %d rows from S3 for query %s", len(df), execution_id)
        return df

    def _get_results_from_api(self, execution_id: str) -> pd.DataFrame:
        """
        Get query results using paginated API calls (fallback method).

        Kept for backward compatibility when the direct S3 read fails.

        Args:
            execution_id: Query execution ID
        Returns:
            DataFrame with query results
        Raises:
            ClientError: If results cannot be retrieved
        """
        logger.debug("Using API method for execution %s", execution_id)
        paginator = self.athena_client.get_paginator('get_query_results')
        pages = paginator.paginate(QueryExecutionId=execution_id)

        rows: List[List[Any]] = []
        column_names: Optional[List[str]] = None
        for page in pages:
            result_set = page['ResultSet']
            # Column names come from the metadata of the first page.
            if column_names is None:
                column_names = [col['Name'] for col in result_set['ResultSetMetadata']['ColumnInfo']]
            # First row of the result set is the header; skip it.
            for row in result_set['Rows'][1:]:
                rows.append([self._parse_value(cell.get('VarCharValue', ''))
                             for cell in row['Data']])

        if not rows:
            logger.warning("No results returned for execution %s", execution_id)
            return pd.DataFrame(columns=column_names or [])

        df = pd.DataFrame(rows, columns=column_names)
        logger.info("Retrieved %d rows from query %s", len(df), execution_id)
        return df

    def _parse_value(self, value: str) -> Any:
        """
        Parse a string value from Athena into an appropriate Python type.

        Args:
            value: String value from an Athena result cell
        Returns:
            None for empty/missing values, else int, float, bool, or the
            original string. NOTE: only values containing '.' become
            floats, so scientific notation (e.g. '1e5') stays a string.
        """
        if value == '' or value is None:
            return None
        # Numeric first: '.' selects float, otherwise try int.
        try:
            if '.' in value:
                return float(value)
            return int(value)
        except ValueError:
            pass
        # Then boolean (case-insensitive 'true'/'false').
        if value.lower() in ('true', 'false'):
            return value.lower() == 'true'
        return value

    def query_to_dataframe(
        self,
        query: str,
        timeout: int = 300,
    ) -> pd.DataFrame:
        """
        Execute query and return results as DataFrame.

        Convenience method combining execute_query and get_query_results.

        Args:
            query: SQL query string
            timeout: Maximum time to wait for query completion (seconds)
        Returns:
            DataFrame with query results
        """
        execution_id = self.execute_query(query, wait=True, timeout=timeout)
        return self.get_query_results(execution_id)

    def list_tables(self, schema: Optional[str] = None) -> List[str]:
        """
        List tables in the database.

        Args:
            schema: Optional schema name (defaults to the configured database)
        Returns:
            List of table names (empty list on error)
        """
        if schema is None:
            schema = self.config.database_name
        # Escape the identifier before embedding it in the SQL literal:
        # a quote in the name would otherwise break the query / allow injection.
        query = f"""
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = '{self._escape_sql_literal(schema)}'
        ORDER BY table_name
        """
        try:
            df = self.query_to_dataframe(query)
            return df['table_name'].tolist() if not df.empty else []
        except Exception as e:
            # Best-effort helper: log and return empty rather than raise.
            logger.error("Failed to list tables: %s", e)
            return []

    def describe_table(self, table_name: str, schema: Optional[str] = None) -> pd.DataFrame:
        """
        Get table schema/columns.

        Args:
            table_name: Table name
            schema: Optional schema name (defaults to the configured database)
        Returns:
            DataFrame with column information (column_name, data_type,
            is_nullable), or an empty DataFrame on error.
        """
        if schema is None:
            schema = self.config.database_name
        # Escape both identifiers before embedding them in SQL literals.
        query = f"""
        SELECT
            column_name,
            data_type,
            is_nullable
        FROM information_schema.columns
        WHERE table_schema = '{self._escape_sql_literal(schema)}'
          AND table_name = '{self._escape_sql_literal(table_name)}'
        ORDER BY ordinal_position
        """
        try:
            return self.query_to_dataframe(query)
        except Exception as e:
            # Best-effort helper: log and return empty rather than raise.
            logger.error("Failed to describe table %s: %s", table_name, e)
            return pd.DataFrame()