# Source: sirus/backend/data_sources/federation_agent.py
# Author avatar caption: ranilmukesh's picture
# Commit message: Deploy SiRUS SQL Agent backend
# Commit: a8c9ee8
"""FederationAgent: builds safe Ibis expressions from JSON ASTs and executes
federated plans across configured connectors.
This file contains a minimal, well-typed and defensive implementation used
by workers. It intentionally supports a small, whitelisted set of AST ops to
avoid executing arbitrary code.
"""
from typing import Any, Dict, List
import json
import logging
import re
from .connectors.base_connector import BaseConnector
from .connectors.ibis_connector import IbisConnector
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Registry mapping a tenant config's `source_type` string to the connector
# class that FederationAgent instantiates for it. Only 'ibis' is registered
# here; import additional connector implementations and add them to this map
# as they become available. Source types missing from the map are skipped
# with a warning when the agent is initialized.
CONNECTOR_MAP: Dict[str, Any] = {
"ibis": IbisConnector,
}
class FederationAgent:
    """Agent instantiated inside a worker to execute a tenant's federated plan.

    Responsibilities:
    - Initialize connectors (dependency injection of redis/minio clients)
    - Build safe Ibis expressions from a restricted JSON AST
    - Execute each step and return step-wise results (or MinIO path markers
      for large results)
    """

    # Aggregate operations accepted in the JSON AST; "avg" is an alias of "mean".
    _AGGREGATE_OPS = ("sum", "count", "mean", "avg", "min", "max")

    # Regex patterns, matched against the upper-cased SQL text, that mark a raw
    # query as not read-only.
    # NOTE(review): this is a keyword blacklist, not a SQL parser. A keyword
    # inside a string literal or identifier also triggers it, and statement
    # kinds not listed here are not blocked.
    _FORBIDDEN_SQL_PATTERNS = (
        r'\bINSERT\b', r'\bUPDATE\b', r'\bDELETE\b', r'\bDROP\b',
        r'\bCREATE\b', r'\bALTER\b', r'\bTRUNCATE\b', r'\bGRANT\b',
        r'\bREVOKE\b', r'\bREPLACE\b', r'\bUPSERT\b', r'\bMERGE\b',
    )

    def __init__(self, tenant_config: List[Dict[str, Any]], redis_client=None, minio_client=None):
        """Initialize connectors for the tenant.

        Args:
            tenant_config: list of dicts each containing at least:
                - source_name: logical name
                - source_type: key in CONNECTOR_MAP
                - config: connector-specific dict
              ``None`` is treated as an empty list.
            redis_client: passed through to every connector constructor.
            minio_client: passed through to every connector constructor.

        Invalid entries, unknown source types and connectors whose constructor
        raises are logged and skipped; they never abort initialization.
        """
        self.redis = redis_client
        self.minio = minio_client
        self.connectors: Dict[str, BaseConnector] = {}

        for conn_info in tenant_config or []:
            if not isinstance(conn_info, dict):
                logger.warning(
                    "Skipping connector config entry of type %s (expected a dict)",
                    type(conn_info).__name__,
                )
                continue

            source_name = conn_info.get("source_name")
            source_type = conn_info.get("source_type")
            # An explicit ``"config": None`` is treated like a missing config.
            config = conn_info.get("config") or {}

            if not source_name or not source_type:
                # Log only the identifying fields: the connector-specific
                # config may carry credentials and must not reach the logs.
                logger.warning(
                    "Skipping invalid connector config (source_name=%r, source_type=%r)",
                    source_name,
                    source_type,
                )
                continue

            if source_type not in CONNECTOR_MAP:
                logger.warning("Unknown source_type '%s' for source '%s'", source_type, source_name)
                continue

            connector_cls = CONNECTOR_MAP[source_type]
            try:
                self.connectors[source_name] = connector_cls(config, self.redis, self.minio)
            except Exception as e:
                # Best effort: one broken source must not prevent the others
                # from being registered.
                logger.exception("Failed to initialize connector %s: %s", source_name, e)

    def _build_ibis_expr(self, json_plan: Dict[str, Any], connection) -> Any:
        """Recursively build an Ibis expression from a restricted JSON AST.

        Allowed operations (Phase 2):
        - table: {operation: 'table', name: 'table_name'}
        - select / project: {operation: 'select', source: <ast>, columns: ['a', 'b']}
        - aggregate: {operation: 'aggregate', source: <ast>, group_by: ['g1'],
          aggregates: [{op: 'sum', column: 'v', alias: 'sum_v'}]}
        - filter: {operation: 'filter', source: <ast>,
          predicate: {column, operation, value}}

        This function never uses eval() and enforces a whitelist.

        Raises:
            ValueError: if a node is malformed.
            NotImplementedError: if an operation is not whitelisted.
        """
        if not isinstance(json_plan, dict):
            raise ValueError("Plan node must be a dict")

        op = json_plan.get("operation")
        if op == "table":
            name = json_plan.get("name")
            if not isinstance(name, str):
                raise ValueError("table name must be a string")
            # support reading a previously-produced step name is out-of-scope here
            return connection.table(name)
        if op in ("select", "project"):
            return self._build_select(json_plan, connection)
        if op == "aggregate":
            return self._build_aggregate(json_plan, connection)
        if op == "filter":
            return self._build_filter(json_plan, connection)
        raise NotImplementedError(f"Unsupported operation: {op}")

    def _build_select(self, node: Dict[str, Any], connection) -> Any:
        """Build a projection of the node's source onto its 'columns' list."""
        source = node.get("source")
        cols = node.get("columns")
        if source is None or cols is None:
            raise ValueError("select node requires 'source' and 'columns'")
        if not isinstance(cols, list) or not all(isinstance(c, str) for c in cols):
            raise ValueError("select 'columns' must be a list of column names")
        src_expr = self._build_ibis_expr(source, connection)
        # Projection is expressed by indexing the source with the column list.
        return src_expr[cols]

    def _build_aggregate(self, node: Dict[str, Any], connection) -> Any:
        """Build a (optionally grouped) aggregation over the node's source."""
        source = node.get("source")
        group_by = node.get("group_by", [])
        aggregates = node.get("aggregates", [])
        if source is None or not isinstance(aggregates, list):
            raise ValueError("aggregate node requires 'source' and 'aggregates'")
        src_expr = self._build_ibis_expr(source, connection)

        # alias -> aggregate expression
        aggs = {}
        for spec in aggregates:
            if not isinstance(spec, dict):
                raise ValueError("each aggregate must be a dict with 'op' and 'column'")
            agg_op = spec.get("op") or spec.get("operation")
            col = spec.get("column")
            if agg_op not in self._AGGREGATE_OPS:
                raise NotImplementedError(f"Unsupported aggregate op: {agg_op}")
            if not isinstance(col, str):
                raise ValueError("aggregate column must be a string")
            alias = spec.get("alias") or f"{agg_op}_{col}"
            # agg_op is whitelisted above, so this getattr can only reach one
            # of the known aggregation methods on the column expression.
            method_name = "mean" if agg_op == "avg" else agg_op
            aggs[alias] = getattr(src_expr[col], method_name)()

        if group_by:
            return src_expr.group_by(group_by).aggregate(aggs)
        # Ungrouped aggregation
        return src_expr.aggregate(aggs)

    def _build_filter(self, node: Dict[str, Any], connection) -> Any:
        """Build a filter over the node's source; only the 'gt' predicate is allowed."""
        source = node.get("source")
        predicate = node.get("predicate")
        if source is None or predicate is None:
            raise ValueError("filter node requires 'source' and 'predicate'")
        if not isinstance(predicate, dict):
            raise ValueError("filter 'predicate' must be a dict")
        src_expr = self._build_ibis_expr(source, connection)

        col = predicate.get("column")
        val = predicate.get("value")
        pred_op = predicate.get("operation")
        if not isinstance(col, str):
            raise ValueError("predicate column must be a string")
        # Only allow 'gt' for Phase 2; can extend safely later
        if pred_op != "gt":
            raise NotImplementedError(f"Unsupported predicate operation: {pred_op}")

        try:
            # Preferred: let the backend evaluate the comparison expression.
            return src_expr.filter(src_expr[col] > val)
        except Exception:
            # Deliberate best-effort fallback: some test doubles expect a
            # simple tuple describing the predicate instead of an expression.
            return src_expr.filter((col, "gt", val))

    @classmethod
    def _find_forbidden_sql_pattern(cls, sql: str):
        """Return the first forbidden pattern found in *sql*, or None if none match."""
        upper_sql = sql.upper()
        return next(
            (pattern for pattern in cls._FORBIDDEN_SQL_PATTERNS if re.search(pattern, upper_sql)),
            None,
        )

    def execute_raw_query(self, source_name: str, sql: str) -> List[Dict[str, Any]]:
        """Executes a raw SQL query synchronously on a specific, named data source.

        Args:
            source_name: The logical name of the data source config to use.
            sql: The raw SQL string to execute.

        Returns:
            Whatever the connector's ``execute_raw_sql`` returns (annotated as
            a list of dictionaries).

        Raises:
            ValueError: If sql is not a non-empty string, if the query is not
                read-only, or if the source_name is not configured.
            ConnectionError: If connection to the source fails.
            NotImplementedError: If the connector doesn't support raw SQL.
            Exception: If execution fails.
        """
        # Fail with a clear error instead of a TypeError from slicing/upper().
        if not isinstance(sql, str) or not sql.strip():
            raise ValueError("sql must be a non-empty string")

        logger.info(
            "FederationAgent executing raw SQL on source '%s': %s...",
            source_name,
            sql[:100],
        )

        # Enforce Read-Only Mode (rudimentary keyword check, see class constant).
        blocked_pattern = self._find_forbidden_sql_pattern(sql)
        if blocked_pattern is not None:
            logger.warning("Blocked potential destructive SQL query: %s detected.", blocked_pattern)
            raise ValueError("Only read-only (SELECT) queries are allowed in this environment.")

        if source_name not in self.connectors:
            logger.error("Source '%s' not found among initialized connectors for this agent.", source_name)
            raise ValueError(f"Source '{source_name}' not configured or initialized for this agent.")

        connector: BaseConnector = self.connectors[source_name]

        # Ensure the connector is ready
        if not connector.is_connected():
            try:
                logger.info("Connecting to source '%s' for raw SQL execution.", source_name)
                connector.connect()
            except Exception as e:
                logger.error("Failed to connect to source '%s' for raw SQL: %s", source_name, e)
                # Raise a specific, catchable error for connection issues
                raise ConnectionError(f"Could not connect to data source '{source_name}': {e}") from e

        # Delegate the raw SQL execution directly to the connector.
        try:
            results = connector.execute_raw_sql(sql)
            # Log result count, not the data itself for privacy/performance
            logger.info("Raw SQL on source '%s' completed, returned %s rows.", source_name, len(results))
            return results
        except NotImplementedError:
            logger.error(
                "Connector for source '%s' (%s) does not support raw SQL.",
                source_name,
                type(connector).__name__,
            )
            raise  # Re-raise for the API layer to handle
        except Exception as e:
            # Log, then re-raise the original exception so its type and
            # traceback are preserved for the caller.
            logger.error("Error executing raw SQL on source '%s': %s", source_name, e, exc_info=True)
            raise

    # --- User-facing helpers ---
    def get_combined_schema(self) -> str:
        """Gather schemas from all connectors and return one combined string.

        Each connector's `get_schema()` can raise; such errors are not
        propagated but rendered as a marker line in the combined output.
        """
        parts: List[str] = []
        for name, connector in self.connectors.items():
            try:
                schema = connector.get_schema()
                parts.append(f"--- Schema for {name} ---\n{schema}\n")
            except Exception as e:
                parts.append(f"--- Could not fetch schema for {name}: {e} ---\n")
        return "\n".join(parts)

    def execute_federated_plan(self, plan: List[Dict[str, Any]], max_inline_rows: int = 10000) -> Dict[str, Any]:
        """Execute a simple linear federated plan.

        For each step the plan must contain:
        - source: name of the configured source
        - query: restricted JSON AST (see _build_ibis_expr)

        Args:
            plan: list of step dicts, executed in order.
            max_inline_rows: list results longer than this are replaced by a
                MinIO path marker instead of being returned inline.

        Returns:
            A dict of step_name ("step_1", "step_2", ...) -> result, where the
            result is either what the connector returned or a marker dict
            ``{"result_type": "minio_path", "path": ...}``.

        Raises:
            ValueError: if the plan or a step is malformed, or a step names a
                source that is not configured.
        """
        if not isinstance(plan, list):
            raise ValueError("Plan must be a list of steps")

        step_results: Dict[str, Any] = {}
        for i, step in enumerate(plan, start=1):
            step_name = f"step_{i}"
            if not isinstance(step, dict):
                raise ValueError(f"Plan {step_name} must be a dict with 'source' and 'query'")

            source = step.get("source")
            query_ast = step.get("query")
            if source not in self.connectors:
                raise ValueError(f"Source '{source}' not configured")
            connector: BaseConnector = self.connectors[source]

            # Ensure connector has a live connection
            if not connector.is_connected():
                connector.connect()

            # Build Ibis expression using safe builder
            ibis_expr = self._build_ibis_expr(query_ast, connector.connection)
            result = connector.execute_query(ibis_expr)

            if isinstance(result, list) and len(result) > max_inline_rows:
                # NOTE(review): placeholder. Nothing is written to MinIO here,
                # so the rows are dropped from the returned payload, and the
                # path is derived from the step name only (not unique per job).
                # A real implementation should stream/write parquet to MinIO.
                minio_path = f"jobs/{step_name}.parquet"
                step_results[step_name] = {"result_type": "minio_path", "path": minio_path}
            else:
                step_results[step_name] = result
        return step_results