Spaces:
Running
Running
| """FederationAgent: builds safe Ibis expressions from JSON ASTs and executes | |
| federated plans across configured connectors. | |
| This file contains a minimal, well-typed and defensive implementation used | |
| by workers. It intentionally supports a small, whitelisted set of AST ops to | |
| avoid executing arbitrary code. | |
| """ | |
| from typing import Any, Dict, List | |
| import json | |
| import logging | |
| import re | |
| from .connectors.base_connector import BaseConnector | |
| from .connectors.ibis_connector import IbisConnector | |
logger = logging.getLogger(__name__)

# Registry of connector classes, keyed by the `source_type` string found in a
# tenant's connector configuration. New connector implementations are imported
# and registered here as they become available; for Phase 2 the 'ibis'
# connectors are expected to be registered by the package consumer.
CONNECTOR_MAP: Dict[str, Any] = {"ibis": IbisConnector}
class FederationAgent:
    """Agent instantiated inside a worker to execute a tenant's federated plan.

    Responsibilities:
    - Initialize connectors (dependency injection of redis/minio clients)
    - Build safe Ibis expressions from a restricted JSON AST
    - Execute each step and return step-wise results (or a MinIO path marker
      for large results)
    - Run read-only raw SQL against a single named source
    """

    # Whitelist of aggregate ops accepted in the JSON AST, mapped to the name
    # of the column-expression method that implements them ('avg' is an alias
    # for 'mean').
    _AGGREGATE_METHODS = {
        "sum": "sum",
        "count": "count",
        "mean": "mean",
        "avg": "mean",
        "min": "min",
        "max": "max",
    }

    # Keyword patterns that mark a statement as DML/DDL. Compiled once at class
    # creation instead of on every call. They are matched against the
    # upper-cased SQL text.
    _FORBIDDEN_SQL_PATTERNS = tuple(
        re.compile(p)
        for p in (
            r'\bINSERT\b', r'\bUPDATE\b', r'\bDELETE\b', r'\bDROP\b',
            r'\bCREATE\b', r'\bALTER\b', r'\bTRUNCATE\b', r'\bGRANT\b',
            r'\bREVOKE\b', r'\bREPLACE\b', r'\bUPSERT\b', r'\bMERGE\b',
        )
    )

    # List results longer than this are replaced by a MinIO path marker.
    _MAX_INLINE_ROWS = 10000

    def __init__(self, tenant_config: List[Dict[str, Any]], redis_client=None, minio_client=None):
        """Initialize connectors for the tenant.

        Args:
            tenant_config: list of dicts each containing at least:
                - source_name: logical name
                - source_type: key in CONNECTOR_MAP
                - config: connector-specific dict
              ``None`` is treated as an empty list.
            redis_client: passed through to every connector.
            minio_client: passed through to every connector.

        Invalid entries, unknown source types and connectors whose constructor
        raises are logged and skipped; they never abort agent construction.
        """
        self.redis = redis_client
        self.minio = minio_client
        self.connectors: Dict[str, BaseConnector] = {}

        for conn_info in tenant_config or []:
            if not isinstance(conn_info, dict):
                logger.warning(
                    "Skipping connector config entry of unexpected type %s",
                    type(conn_info).__name__,
                )
                continue

            source_name = conn_info.get("source_name")
            source_type = conn_info.get("source_type")
            config = conn_info.get("config") or {}

            if not source_name or not source_type:
                # Deliberately do NOT log conn_info itself: its 'config' part
                # may carry credentials for the data source.
                logger.warning(
                    "Skipping invalid connector config (source_name=%r, source_type=%r)",
                    source_name,
                    source_type,
                )
                continue

            if source_type not in CONNECTOR_MAP:
                logger.warning("Unknown source_type '%s' for source '%s'", source_type, source_name)
                continue

            connector_cls = CONNECTOR_MAP[source_type]
            try:
                self.connectors[source_name] = connector_cls(config, self.redis, self.minio)
            except Exception as e:
                # Best effort: one broken source must not prevent the others
                # from being usable.
                logger.exception("Failed to initialize connector %s: %s", source_name, e)

    def _build_ibis_expr(self, json_plan: Dict[str, Any], connection) -> Any:
        """Recursively build an Ibis expression from a restricted JSON AST.

        Allowed operations (Phase 2):
        - table: {operation: 'table', name: 'table_name'}
        - select / project: {operation: 'select', source: <ast>, columns: ['a', 'b']}
        - aggregate: {operation: 'aggregate', source: <ast>, group_by: ['g1'],
          aggregates: [{op: 'sum', column: 'v', alias: 'sum_v'}]}
        - filter: {operation: 'filter', source: <ast>,
          predicate: {column, operation, value}} (only 'gt' is supported)

        This function never uses eval() and enforces a whitelist.

        Raises:
            ValueError: if a node is malformed.
            NotImplementedError: if an operation is not on the whitelist.
        """
        if not isinstance(json_plan, dict):
            raise ValueError("Plan node must be a dict")

        op = json_plan.get("operation")
        if op == "table":
            name = json_plan.get("name")
            if not isinstance(name, str):
                raise ValueError("table name must be a string")
            # support reading a previously-produced step name is out-of-scope here
            return connection.table(name)
        if op in ("select", "project"):
            return self._build_select(json_plan, connection)
        if op == "aggregate":
            return self._build_aggregate(json_plan, connection)
        if op == "filter":
            return self._build_filter(json_plan, connection)
        raise NotImplementedError(f"Unsupported operation: {op}")

    def _build_select(self, node: Dict[str, Any], connection) -> Any:
        """Build a projection of the node's source onto the listed columns."""
        source = node.get("source")
        cols = node.get("columns")
        if source is None or cols is None:
            raise ValueError("select node requires 'source' and 'columns'")
        if not isinstance(cols, list) or not all(isinstance(c, str) for c in cols):
            raise ValueError("select 'columns' must be a list of column names")
        src_expr = self._build_ibis_expr(source, connection)
        # Projection is expressed as indexing the source with the column list.
        return src_expr[cols]

    def _build_aggregate(self, node: Dict[str, Any], connection) -> Any:
        """Build a (possibly grouped) aggregation over the node's source."""
        source = node.get("source")
        group_by = node.get("group_by", [])
        aggregates = node.get("aggregates", [])
        if source is None or not isinstance(aggregates, list):
            raise ValueError("aggregate node requires 'source' and 'aggregates'")
        src_expr = self._build_ibis_expr(source, connection)

        # alias -> aggregate expression
        metrics: Dict[str, Any] = {}
        for spec in aggregates:
            if not isinstance(spec, dict):
                raise ValueError("each aggregate must be a dict with 'op' and 'column'")
            agg_op = spec.get("op") or spec.get("operation")
            if not isinstance(agg_op, str) or agg_op not in self._AGGREGATE_METHODS:
                raise NotImplementedError(f"Unsupported aggregate op: {agg_op}")
            col = spec.get("column")
            if not isinstance(col, str):
                raise ValueError("aggregate 'column' must be a string")
            alias = spec.get("alias") or f"{agg_op}_{col}"
            if alias in metrics:
                # Two metrics under one name would silently drop the first.
                raise ValueError(f"duplicate aggregate alias: {alias}")
            # The method name comes from the whitelist above, never from the
            # AST directly, so getattr cannot reach arbitrary attributes.
            metrics[alias] = getattr(src_expr[col], self._AGGREGATE_METHODS[agg_op])()

        if group_by:
            return src_expr.group_by(group_by).aggregate(metrics)
        # Ungrouped aggregation
        return src_expr.aggregate(metrics)

    def _build_filter(self, node: Dict[str, Any], connection) -> Any:
        """Build a filter of the node's source by a single 'gt' predicate."""
        source = node.get("source")
        predicate = node.get("predicate")
        if source is None or predicate is None:
            raise ValueError("filter node requires 'source' and 'predicate'")
        if not isinstance(predicate, dict):
            raise ValueError("filter 'predicate' must be a dict")

        col = predicate.get("column")
        val = predicate.get("value")
        pred_op = predicate.get("operation")
        if not isinstance(col, str):
            raise ValueError("predicate column must be a string")
        # Only allow 'gt' for Phase 2; can extend safely later
        if pred_op != "gt":
            raise NotImplementedError(f"Unsupported predicate operation: {pred_op}")

        src_expr = self._build_ibis_expr(source, connection)
        try:
            # Preferred: let the backend evaluate the comparison expression.
            return src_expr.filter(src_expr[col] > val)
        except Exception as primary_err:
            # Fallback for table objects (e.g. test doubles) that expect a
            # simple tuple describing the predicate.
            try:
                return src_expr.filter((col, "gt", val))
            except Exception:
                # Both forms failed: the first error is the informative one
                # (e.g. unknown column), so surface that instead of the
                # failure caused by the tuple fallback.
                raise primary_err

    def execute_raw_query(self, source_name: str, sql: str) -> List[Dict[str, Any]]:
        """Executes a raw SQL query synchronously on a specific, named data source.

        Args:
            source_name: The logical name of the data source config to use.
            sql: The raw SQL string to execute.

        Returns:
            The result set as returned by the connector's ``execute_raw_sql``.

        Raises:
            ValueError: If ``sql`` is not a string, if the query is not
                read-only, or if the source_name is not configured.
            ConnectionError: If connection to the source fails.
            NotImplementedError: If the connector doesn't support raw SQL.
            Exception: If execution fails (the connector's error is re-raised).
        """
        if not isinstance(sql, str):
            raise ValueError("sql must be a string")

        logger.info("FederationAgent executing raw SQL on source '%s': %s...", source_name, sql[:100])

        # Enforce Read-Only Mode.
        # This is a rudimentary keyword deny-list, not a SQL parser: it can
        # reject harmless queries that merely mention a keyword (e.g. inside a
        # string literal) and it does not know about backend-specific write
        # statements. The data source credentials should be read-only as well.
        upper_sql = sql.upper()
        for pattern in self._FORBIDDEN_SQL_PATTERNS:
            if pattern.search(upper_sql):
                logger.warning("Blocked potential destructive SQL query: %s detected.", pattern.pattern)
                raise ValueError("Only read-only (SELECT) queries are allowed in this environment.")

        if source_name not in self.connectors:
            # The agent was created from the tenant's config, so a missing
            # source means it was never configured or failed to initialize.
            logger.error("Source '%s' not found among initialized connectors for this agent.", source_name)
            raise ValueError(f"Source '{source_name}' not configured or initialized for this agent.")
        connector: BaseConnector = self.connectors[source_name]

        # Ensure the connector is ready
        if not connector.is_connected():
            try:
                logger.info("Connecting to source '%s' for raw SQL execution.", source_name)
                connector.connect()
            except Exception as e:
                logger.error("Failed to connect to source '%s' for raw SQL: %s", source_name, e)
                # Raise a specific, catchable error for connection issues
                raise ConnectionError(f"Could not connect to data source '{source_name}': {e}") from e

        # Delegate the raw SQL execution directly to the connector.
        # This call blocks until the query finishes or fails.
        try:
            results = connector.execute_raw_sql(sql)
        except NotImplementedError:
            logger.error(
                "Connector for source '%s' (%s) does not support raw SQL.",
                source_name,
                type(connector).__name__,
            )
            raise  # Re-raise for the API layer to return 501
        except Exception as e:
            # Log, then re-raise the original exception so its type and
            # traceback are preserved for the caller.
            logger.error("Error executing raw SQL on source '%s': %s", source_name, e, exc_info=True)
            raise

        # Log the row count only (never the data). Done outside the try block
        # so that a result without len() cannot be misreported as a query error.
        row_count = len(results) if hasattr(results, "__len__") else "unknown"
        logger.info("Raw SQL on source '%s' completed, returned %s rows.", source_name, row_count)
        return results

    # --- User-facing helpers ---

    def get_combined_schema(self) -> str:
        """Gather schemas from all connectors and return one combined string.

        Each connector's `get_schema()` can raise; errors are collected and
        represented in the combined output instead of aborting the call.
        """
        parts: List[str] = []
        for name, connector in self.connectors.items():
            try:
                schema = connector.get_schema()
            except Exception as e:
                parts.append(f"--- Could not fetch schema for {name}: {e} ---\n")
            else:
                parts.append(f"--- Schema for {name} ---\n{schema}\n")
        return "\n".join(parts)

    def execute_federated_plan(self, plan: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Execute a simple linear federated plan.

        Args:
            plan: list of steps; each step must be a dict containing:
                - source: name of the configured source
                - query: restricted JSON AST (see _build_ibis_expr)

        Returns:
            A dict of step name ('step_1', 'step_2', ...) -> result as returned
            by the connector, or a ``{"result_type": "minio_path", "path": ...}``
            marker when a list result exceeds the inline row limit.

        Raises:
            ValueError: if the plan or a step is malformed, or a step names a
                source that is not configured.
            NotImplementedError: if a step uses an unsupported AST operation.
        """
        if not isinstance(plan, list):
            raise ValueError("Plan must be a list of steps")

        step_results: Dict[str, Any] = {}
        for i, step in enumerate(plan, start=1):
            step_name = f"step_{i}"
            if not isinstance(step, dict):
                raise ValueError(f"Plan {step_name} must be a dict with 'source' and 'query'")

            source = step.get("source")
            query_ast = step.get("query")
            if source not in self.connectors:
                raise ValueError(f"Source '{source}' not configured")
            connector: BaseConnector = self.connectors[source]

            # Ensure connector has a live connection
            if not connector.is_connected():
                connector.connect()

            # Build Ibis expression using safe builder
            ibis_expr = self._build_ibis_expr(query_ast, connector.connection)
            result = connector.execute_query(ibis_expr)

            if isinstance(result, list) and len(result) > self._MAX_INLINE_ROWS:
                # PLACEHOLDER: nothing is written to MinIO here, so the rows of
                # an oversized result are not retrievable from the returned
                # path. A real implementation should stream/write parquet.
                minio_path = f"jobs/{step_name}.parquet"
                step_results[step_name] = {"result_type": "minio_path", "path": minio_path}
            else:
                step_results[step_name] = result
        return step_results