""" Example: Batch processing patterns for large-scale analysis using Athena. This script demonstrates memory-efficient batch processing across the entire data lake using SQL queries. """ import sys from pathlib import Path _project_root = Path(__file__).resolve().parent.parent.parent if str(_project_root) not in sys.path: sys.path.insert(0, str(_project_root)) from src.datalake.config import DataLakeConfig from src.datalake.athena import AthenaQuery from src.datalake.catalog import DataLakeCatalog from src.datalake.query import DataLakeQuery from src.datalake.batch import BatchProcessor import pandas as pd def main(): """Batch process data lake.""" # Setup # Load config with explicit credentials config = DataLakeConfig.from_credentials( database_name="dbparquetdatalake05", workgroup="athenaworkgroup-datalake05", s3_output_location="s3://canedge-raw-data-parquet/athena-results/", region="eu-north-1", access_key_id="AKIARJQJFFVASPMSGNNY", secret_access_key="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", ) athena = AthenaQuery(config) catalog = DataLakeCatalog(athena, config) query = DataLakeQuery(athena, catalog) processor = BatchProcessor(query) print("=" * 60) print("Batch Processing Examples (Athena)") print("=" * 60) print() # Example 1: Compute statistics across all data print("Example 1: Compute statistics across all devices/messages") print("-" * 60) try: stats = processor.aggregate_by_device_message( aggregation_func=processor.compute_statistics, message_filter=config.message_filter, # Optional filter ) print(f"Processed {len(stats)} device(s):") for device, messages in stats.items(): print(f"\n Device: {device}") for message, metrics in messages.items(): print(f" Message: {message}") print(f" Record count: {metrics.get('count', 0):,}") # Show statistics for first numeric column found for key, value in metrics.items(): if key != 'count' and '_mean' in key: signal = key.replace('_mean', '') print(f" {signal}:") print(f" Mean: {value:.2f}") print(f" Min: {metrics.get(f'{signal}_min', 'N/A')}") print(f" Max: {metrics.get(f'{signal}_max', 'N/A')}") break except Exception as e: print(f"Error in batch aggregation: {e}") print() # Example 2: Custom aggregation using SQL print("Example 2: Custom SQL aggregation") print("-" * 60) try: devices = catalog.list_devices() if devices: device_id = devices[0] messages = catalog.list_messages(device_id) if messages: message = messages[0] table_name = catalog.get_table_name(device_id, message) # Use SQL for aggregation sql = f""" SELECT COUNT(*) as record_count, MIN(t) as min_timestamp, MAX(t) as max_timestamp FROM {config.database_name}.{table_name} """ df_agg = query.execute_sql(sql) print(f"Aggregation for {device_id}/{message}:") print(df_agg) except Exception as e: print(f"Error in SQL aggregation: {e}") print() # Example 3: Export specific data print("Example 3: Export data to CSV") print("-" * 60) try: devices = catalog.list_devices() if devices: device_id = devices[0] messages = catalog.list_messages(device_id) if messages: message = messages[0] output_path = f"{device_id}_{message}_export.csv" processor.export_to_csv( device_id=device_id, message=message, output_path=output_path, limit=10000, # Limit for example ) print(f"Exported to: {output_path}") except Exception as e: print(f"Error exporting data: {e}") print() # Example 4: Find anomalies using SQL print("Example 4: Find anomalies using SQL") print("-" * 60) try: devices = catalog.list_devices() if devices: device_id = devices[0] messages = catalog.list_messages(device_id) if messages: message = messages[0] schema = catalog.get_schema(device_id, message) if schema: signal_cols = [c for c in schema.keys() if c != 't' and c.lower() != 'date'] if signal_cols: signal_name = signal_cols[0] table_name = catalog.get_table_name(device_id, message) # Use SQL to find outliers (3 standard deviations) sql = f""" WITH stats AS ( SELECT AVG({signal_name}) as mean_val, STDDEV({signal_name}) as std_val FROM {config.database_name}.{table_name} WHERE {signal_name} IS NOT NULL ) SELECT t, {signal_name} FROM {config.database_name}.{table_name}, stats WHERE {signal_name} IS NOT NULL AND ABS({signal_name} - mean_val) > 3 * std_val ORDER BY ABS({signal_name} - mean_val) DESC LIMIT 10 """ anomalies = query.execute_sql(sql) if not anomalies.empty: print(f"Found {len(anomalies)} anomalies in {signal_name}") print(anomalies.head()) else: print("No anomalies found") except Exception as e: print(f"Error finding anomalies: {e}") if __name__ == "__main__": main()