Spaces:
Running
Running
| """DataSourcesSQLToolkit: Multi-tenant SQL execution toolkit for SQL Agents. | |
| This toolkit provides agents with tools to interact with the centralized | |
| data_sources API for fetching schemas, instructions, and executing raw SQL | |
| queries securely across multiple tenants and data sources. | |
| """ | |
| import httpx | |
| import asyncio | |
| import json | |
| import logging | |
| from datetime import datetime | |
| from typing import Dict, Any, List, Optional, Union | |
| from agno.tools import Toolkit | |
| from agno.run import RunContext | |
| from agno.utils.log import logger | |
# Optional dependency: hybrid keyword extraction helpers.
# The flat module name is tried first, then the package-qualified path. If neither
# import succeeds, both helpers are bound to None and _hybrid_utils_available is
# False; __init__ and find_relevant_tables check these names before using them.
try:
    from hybrid_keyword_utils import extract_hybrid_keywords, create_gemini_semantic_client
    _hybrid_utils_available = True
except ImportError:
    try:
        from backend.SQL_Agent.hybrid_keyword_utils import extract_hybrid_keywords, create_gemini_semantic_client
        _hybrid_utils_available = True
    except ImportError:
        logger.warning(
            "Could not import hybrid_keyword_utils. Semantic extraction will be unavailable. "
            "This is expected if running without the hybrid_keyword_utils module."
        )
        # Explicit None bindings so later truthiness checks work even if a partial
        # `from ... import a, b` bound one name before failing.
        extract_hybrid_keywords = None
        create_gemini_semantic_client = None
        _hybrid_utils_available = False
import os  # used only to read the optional base-URL override below

# Base URL for the data_sources API. Defaults to a local backend; set the
# DATA_SOURCES_API_BASE_URL environment variable to target another deployment
# without editing this file. An unset or empty variable falls back to the default.
DATA_SOURCES_API_BASE_URL = os.environ.get("DATA_SOURCES_API_BASE_URL") or "http://127.0.0.1:8000"
# Default timeout (seconds) configured on the toolkit's httpx clients.
# Set higher for potentially long queries.
DEFAULT_API_TIMEOUT = 120.0
# Per-call timeout budget (seconds) passed explicitly on individual requests.
DEFAULT_REQUEST_TIMEOUT = 120.0
| class DataSourcesSQLToolkit(Toolkit): | |
| """ | |
| A multi-tenant toolkit for SQL Agents using the Data Sources API | |
| to fetch schemas/instructions and execute RAW SQL queries synchronously. | |
| MODIFIED: This toolkit is now tenant-safe. It reads tenant_id, source_name, | |
| and api_key directly from the agent's session_state for every call, | |
| ensuring proper data isolation and security. | |
| """ | |
| def __init__( | |
| self, | |
| api_base_url: str = DATA_SOURCES_API_BASE_URL, | |
| timeout: float = DEFAULT_API_TIMEOUT, | |
| api_key: Optional[str] = None # This is a FALLBACK key only | |
| ): | |
| """ | |
| Initialize the DataSourcesSQLToolkit. | |
| Args: | |
| api_base_url: Base URL for the data sources API | |
| timeout: Default timeout for API client | |
| api_key: Optional FALLBACK API key (e.g., from env vars) | |
| """ | |
| self.api_base_url = api_base_url | |
| self.api_key = api_key # Store fallback | |
| # --- CRITICAL FIX --- | |
| # DO NOT set the API key in the default client headers. | |
| # This client is shared across all requests and must be stateless. | |
| self.client = httpx.Client( | |
| base_url=self.api_base_url, | |
| timeout=timeout, | |
| headers={} # Headers will be provided per-request | |
| ) | |
| # --- Async client for non-blocking calls (preferred) --- | |
| # Keep a separate sync client for backwards compatibility. | |
| try: | |
| self._async_client = httpx.AsyncClient( | |
| base_url=self.api_base_url, | |
| timeout=timeout, | |
| headers={} | |
| ) | |
| except Exception: | |
| # If AsyncClient cannot be created for some reason, fall back to sync client. | |
| self._async_client = None | |
| # Cache for search results per (tenant_id, keywords_tuple, source_names_tuple) | |
| self._search_cache: Dict[tuple, Dict[str, Any]] = {} | |
| # Initialize semantic client for hybrid keyword extraction | |
| self._semantic_client = None | |
| if _hybrid_utils_available and create_gemini_semantic_client: | |
| try: | |
| self._semantic_client = create_gemini_semantic_client() | |
| logger.info("Gemini semantic client initialized for hybrid keyword extraction") | |
| except Exception as e: | |
| logger.warning(f"Could not initialize semantic client: {e}") | |
| # Reference to the agent for session_state injection | |
| self._agent_ref = None | |
| logger.info(f"DataSourcesSQLToolkit initialized for API: {self.api_base_url} with timeout {timeout}s") | |
| # Define the tools this toolkit provides to the agent | |
| super().__init__( | |
| name="data_sources_sql_tools", | |
| tools=[ | |
| self.find_relevant_tables, # NEW: High-level intelligent table discovery | |
| self.search_schema, | |
| self.list_sources, # NEW: Explicit source listing | |
| self.get_available_sources_and_schema, | |
| self.get_source_instructions, | |
| self.execute_sql_query, # The primary execution tool | |
| self.save_query_to_tenant_csv, # NEW: Bridge tool to ML files | |
| self.compose_dataset_workflow, | |
| ], | |
| instructions="""CRITICAL WORKFLOW: | |
| 1. After calling find_relevant_tables, you MUST immediately call execute_sql_query. | |
| 2. NEVER stop after just searching - always execute the SQL query. | |
| 3. ALWAYS use execute_sql_query when user asks for data counts or lists. | |
| 4. NEVER just describe what you would do - actually call the tool and explain the results. | |
| 5. After getting results, provide a clear natural language answer to the user.""" | |
| ) | |
| def set_agent_ref(self, agent_ref): | |
| """ | |
| Set a reference to the agent for accessing session_state during tool execution. | |
| This is crucial for multi-tenant context injection. | |
| Args: | |
| agent_ref: Reference to the Agno agent | |
| """ | |
| self._agent_ref = agent_ref | |
| logger.info("Agent reference set in toolkit for session_state injection") | |
| def _get_session_state_from_agent(self) -> Optional[Dict[str, Any]]: | |
| """ | |
| Get the current session state from the agent reference. | |
| Returns: | |
| session_state dict from agent if available, None otherwise | |
| """ | |
| if self._agent_ref and hasattr(self._agent_ref, 'session_state'): | |
| return self._agent_ref.session_state | |
| return None | |
| # --- NEW HELPER: Get Per-Request Headers --- | |
| def _get_request_headers(self, session_state: Optional[Dict[str, Any]]) -> Dict[str, str]: | |
| """ | |
| Safely builds headers for a single request. | |
| It prioritizes the supabase_jwt from the current session_state. | |
| """ | |
| headers = {} | |
| jwt_token = None | |
| if session_state: | |
| jwt_token = session_state.get("supabase_jwt") | |
| if jwt_token: | |
| logger.debug("Using JWT token from session_state") | |
| if jwt_token: | |
| headers["Authorization"] = f"Bearer {jwt_token}" | |
| else: | |
| logger.warning("No JWT token found in session_state. API call may fail.") | |
| return headers | |
| # Helper to run async coroutines from sync context when safe | |
| def _run_coro_sync(self, coro): | |
| """ | |
| Run an async coroutine from sync code when no event loop is running. | |
| If an event loop is already running, return the coroutine (caller must await). | |
| """ | |
| try: | |
| loop = asyncio.get_running_loop() | |
| except RuntimeError: | |
| # No running loop, safe to run | |
| return asyncio.run(coro) | |
| else: | |
| # An event loop is running; return the coroutine so caller can await it. | |
| return coro | |
| # --- Helper Methods --- | |
| def _normalize_error(self, error_msg: str, code: str = "UNKNOWN_ERROR", hint: Optional[str] = None) -> Dict[str, Any]: | |
| """ | |
| Normalize error responses to a consistent format. | |
| Args: | |
| error_msg: The error message | |
| code: Error code for categorization | |
| hint: Optional hint for resolution | |
| Returns: | |
| Normalized error dict with error, code, and optional hint | |
| """ | |
| result = { | |
| "error": error_msg, | |
| "code": code | |
| } | |
| if hint: | |
| result["hint"] = hint | |
| return result | |
| # --- CRITICAL FIX: Force session_state --- | |
| def _resolve_tenant(self, tenant_id: Optional[str], session_state: Optional[Dict[str, Any]]) -> Union[str, Dict[str, Any]]: | |
| """ | |
| Resolve tenant_id ONLY from session state. | |
| Ignores any tenant_id passed by the LLM to prevent confusion. | |
| """ | |
| if not session_state: | |
| return self._normalize_error("Session state not found.", "SESSION_ERROR") | |
| resolved = (session_state.get("tenant_id", "") or "").strip() | |
| if not resolved: | |
| return self._normalize_error( | |
| "Tenant ID not found in session_state.", | |
| code="MISSING_TENANT_CONTEXT", | |
| hint="The /tenant-run endpoint must inject 'tenant_id'." | |
| ) | |
| return resolved | |
| # --- CRITICAL FIX: Force session_state --- | |
| def _resolve_source(self, source_name: Optional[str], session_state: Optional[Dict[str, Any]]) -> Union[str, Dict[str, Any]]: | |
| """ | |
| Resolve source_name ONLY from session state. | |
| Ignores any source_name passed by the LLM. | |
| """ | |
| if not session_state: | |
| return self._normalize_error("Session state not found.", "SESSION_ERROR") | |
| resolved = (session_state.get("source_name", "") or "").strip() | |
| if not resolved: | |
| return self._normalize_error( | |
| "Source name not found in session_state.", | |
| code="MISSING_SOURCE_CONTEXT", | |
| hint="The /tenant-run endpoint must inject 'source_name'." | |
| ) | |
| return resolved | |
| # --- Helper Function --- | |
| def _format_schema_for_llm(self, source_name: str, schema_json_string: Optional[str]) -> str: | |
| """Formats a single source's JSON schema string (received from API) for the LLM.""" | |
| if not schema_json_string: | |
| return f"**Source: `{source_name}`**\n Schema: <Not Available or Empty>" | |
| try: | |
| # The API returns the raw schema string, which should be JSON parsable | |
| schema_data_list = json.loads(schema_json_string) | |
| # Expecting the format: [{'schema_name': '...', 'tables': [...]}] | |
| if not schema_data_list or not isinstance(schema_data_list, list) or not isinstance(schema_data_list[0], dict): | |
| logger.warning(f"Received schema for '{source_name}' is not in expected list-of-dict format.") | |
| # Attempt to format directly if it's just a dict (less robust) | |
| if isinstance(schema_data_list, dict): | |
| schema_info = schema_data_list | |
| else: | |
| return f"**Source: `{source_name}`**\n Schema: <Invalid Format Received>" | |
| else: | |
| schema_info = schema_data_list[0] # Assuming one schema object per source from API | |
| output_parts = [f"**Source: `{source_name}`** (DB/Schema Name: {schema_info.get('schema_name', 'N/A')})\n"] | |
| tables = schema_info.get("tables", []) | |
| if not tables: | |
| output_parts.append(" *No tables found or schema details missing.*") | |
| else: | |
| for table in tables: | |
| table_name = table.get('table_name', 'N/A') | |
| output_parts.append(f" **Table: `{table_name}`**") | |
| fields = table.get('fields', []) | |
| if not fields: | |
| output_parts.append(" *No column details available.*") | |
| else: | |
| output_parts.append(" Columns:") | |
| for field in fields: | |
| col_name = field.get('name', 'N/A') | |
| col_type = field.get('type', 'Unknown') | |
| col_example = field.get('example', '') # Example might be missing or empty | |
| # Add example only if present and non-empty to avoid clutter | |
| example_str = f", example: '{col_example}'" if col_example else "" | |
| output_parts.append(f" - `{col_name}` (type: {col_type}{example_str})") | |
| # Add sample rows if available | |
| example_rows = table.get('example_rows', []) | |
| if example_rows: | |
| output_parts.append(" Sample Data:") | |
| for row in example_rows[:3]: # Limit to 3 rows | |
| row_str = ', '.join(f"{k}={v}" for k, v in list(row.items())[:5]) # Limit columns shown | |
| output_parts.append(f" {row_str}") | |
| output_parts.append("") # Add a blank line after each table definition for readability | |
| return "\n".join(output_parts) | |
| except json.JSONDecodeError: | |
| logger.error(f"Failed to parse schema JSON for source '{source_name}'. Received: {schema_json_string[:200]}...") | |
| return f"**Source: `{source_name}`**\n Schema: <Error Parsing JSON>" | |
| except (IndexError, KeyError, TypeError) as e: | |
| logger.error(f"Error processing schema structure for source '{source_name}': {e}") | |
| return f"**Source: `{source_name}`**\n Schema: <Error Processing Schema Structure>" | |
| except Exception as e: # Catch any other unexpected errors during formatting | |
| logger.exception(f"Unexpected error formatting schema for '{source_name}': {e}") | |
| return f"**Source: `{source_name}`**\n Schema: <Unexpected Formatting Error>" | |
| # --- Agent Tools --- | |
| def list_sources( | |
| self, | |
| run_context: Optional[RunContext] = None, | |
| tenant_id: Optional[str] = None | |
| ) -> Dict[str, Any]: | |
| """List all available data sources for a tenant.""" | |
| # Get session_state from run_context (Agno v2 pattern) | |
| session_state = run_context.session_state if run_context else None | |
| tenant_context = self._resolve_tenant(tenant_id, session_state) | |
| if isinstance(tenant_context, dict) and tenant_context.get("error"): | |
| return tenant_context | |
| tenant_id = tenant_context | |
| request_headers = self._get_request_headers(session_state) | |
| logger.info(f"Tool: list_sources called for tenant '{tenant_id}'") | |
| try: | |
| url = f"/api/v1/data-sources/my-tenant/list" | |
| response = self.client.get(url, timeout=DEFAULT_REQUEST_TIMEOUT, headers=request_headers) | |
| if response.status_code == 200: | |
| data = response.json() | |
| result = { | |
| "available_sources": data.get("available_sources", []), | |
| "count": data.get("count", 0), | |
| "tenant_id": data.get("tenant_id", tenant_id), | |
| "assets": [], | |
| "datasets": [], | |
| } | |
| # Unified catalog extension: include tenant file assets when available | |
| try: | |
| assets_resp = self.client.get( | |
| "/api/v1/tenant-files/assets?page=1&page_size=200", | |
| timeout=DEFAULT_REQUEST_TIMEOUT, | |
| headers=request_headers, | |
| ) | |
| if assets_resp.status_code == 200: | |
| assets_payload = assets_resp.json() | |
| result["assets"] = assets_payload.get("items", []) | |
| result["datasets"] = [ | |
| item for item in result["assets"] | |
| if item.get("file_type") in {"csv", "parquet", "xlsx", "xls"} | |
| ] | |
| except Exception as assets_err: | |
| logger.debug(f"Tenant files catalog fetch skipped/failed: {assets_err}") | |
| logger.info(f"Found {result['count']} source(s) for tenant '{tenant_id}': {result['available_sources']}") | |
| if session_state is not None: | |
| cache = session_state.setdefault("available_sources_cache", {}) | |
| cache[tenant_id] = result["available_sources"] | |
| return result | |
| else: | |
| return self._normalize_error(f"Failed to list sources: {response.text[:200]}", code="API_ERROR") | |
| except Exception as e: | |
| logger.exception(f"Error in list_sources: {e}") | |
| return self._normalize_error(f"Error: {str(e)}", code="UNEXPECTED_ERROR") | |
| def find_relevant_tables( | |
| self, | |
| run_context: Optional[RunContext] = None, | |
| tenant_id: Optional[str] = None, | |
| question: str = "", | |
| concepts: Optional[List[str]] = None, | |
| source_names: Optional[List[str]] = None, | |
| include_samples: bool = False | |
| ) -> Dict[str, Any]: | |
| """ | |
| Intelligently find relevant tables for a user question using hybrid keyword extraction. | |
| This is the PRIMARY tool for schema discovery - it combines deterministic keywords, | |
| semantic hints from LLM, and agent-provided concepts. | |
| :param run_context: Agno RunContext containing session_state (auto-injected by framework). | |
| :param tenant_id: The ID of the tenant. MANDATORY. | |
| :param question: The original user question. MANDATORY. | |
| :param concepts: Optional list of high-level concepts identified by the agent (e.g., ["revenue", "customers"]) | |
| :param source_names: Optional list to filter specific sources. None = search all. | |
| :param include_samples: Whether to include example rows in results (default: False). | |
| :return: Dict with 'formatted_schema_string', 'matches', 'total_matches', 'keyword_breakdown' | |
| """ | |
| # Get session_state from run_context (Agno v2 pattern) | |
| session_state = run_context.session_state if run_context else None | |
| tenant_context = self._resolve_tenant(tenant_id, session_state) | |
| if isinstance(tenant_context, dict) and tenant_context.get("error"): | |
| return tenant_context | |
| tenant_id = tenant_context # type: ignore[assignment] | |
| if not question or not question.strip(): | |
| return self._normalize_error("Question is required.", "missing_parameter", "Provide a non-empty question string.") | |
| logger.info(f"Tool: find_relevant_tables called for tenant '{tenant_id}' with question: '{question[:100]}...'") | |
| logger.info(f" Agent concepts: {concepts}") | |
| # --- FIX for [None] Pydantic Error --- | |
| if source_names and all(s is None for s in source_names): | |
| source_names = None | |
| if source_names is None and session_state: | |
| preferred_source = session_state.get("source_name") or session_state.get("preferred_source") | |
| if preferred_source: | |
| source_names = [preferred_source] | |
| # Check if we have cached results for this exact question in session state | |
| if session_state: | |
| cache_key = f"find_relevant|{tenant_id}|{question}" | |
| if "keyword_extraction_cache" not in session_state: | |
| session_state["keyword_extraction_cache"] = {} | |
| if cache_key in session_state["keyword_extraction_cache"]: | |
| logger.info(f"Using cached keyword extraction for question") | |
| keyword_result = session_state["keyword_extraction_cache"][cache_key] | |
| merged_keywords = keyword_result["combined"] | |
| else: | |
| # Perform hybrid keyword extraction | |
| if extract_hybrid_keywords: | |
| keyword_result = extract_hybrid_keywords( | |
| question=question, | |
| llm_concepts=concepts, | |
| semantic_client=self._semantic_client | |
| ) | |
| # Cache the extraction result | |
| session_state["keyword_extraction_cache"][cache_key] = keyword_result | |
| merged_keywords = keyword_result["combined"] | |
| logger.info(f"Keyword extraction breakdown:") | |
| logger.info(f" Base (deterministic): {keyword_result['base']}") | |
| logger.info(f" Semantic (LLM hints): {keyword_result['semantic']}") | |
| logger.info(f" Concepts (agent): {keyword_result['concepts']}") | |
| logger.info(f" Combined: {merged_keywords}") | |
| else: | |
| # Fallback if hybrid extraction not available | |
| logger.warning("Hybrid keyword extraction not available, using simple fallback") | |
| # Simple fallback: split question and filter stopwords | |
| words = question.lower().split() | |
| stopwords = {'what', 'when', 'where', 'who', 'which', 'how', 'show', 'give', 'the', 'a', 'an', 'is', 'are'} | |
| base_kw = [w for w in words if w not in stopwords and len(w) > 2] | |
| merged_keywords = base_kw + (concepts or []) | |
| keyword_result = { | |
| 'base': base_kw, | |
| 'semantic': [], | |
| 'concepts': concepts or [], | |
| 'combined': merged_keywords | |
| } | |
| session_state["keyword_extraction_cache"][cache_key] = keyword_result | |
| else: | |
| # No session state, extract without caching | |
| if extract_hybrid_keywords: | |
| keyword_result = extract_hybrid_keywords( | |
| question=question, | |
| llm_concepts=concepts, | |
| semantic_client=self._semantic_client | |
| ) | |
| merged_keywords = keyword_result["combined"] | |
| else: | |
| # Simple fallback | |
| words = question.lower().split() | |
| stopwords = {'what', 'when', 'where', 'who', 'which', 'how', 'show', 'give', 'the', 'a', 'an', 'is', 'are'} | |
| base_kw = [w for w in words if w not in stopwords and len(w) > 2] | |
| merged_keywords = base_kw + (concepts or []) | |
| keyword_result = { | |
| 'base': base_kw, | |
| 'semantic': [], | |
| 'concepts': concepts or [], | |
| 'combined': merged_keywords | |
| } | |
| # Use the existing search_schema with merged keywords | |
| search_result = self.search_schema( | |
| run_context=run_context, | |
| tenant_id=tenant_id, | |
| keywords=merged_keywords, | |
| source_names=source_names, | |
| include_samples=include_samples, | |
| original_question=question, | |
| keyword_metadata=keyword_result | |
| ) | |
| # Add keyword breakdown to result for observability | |
| if "error" not in search_result: | |
| search_result["keyword_breakdown"] = keyword_result | |
| search_result["original_question"] = question | |
| if session_state is not None: | |
| metadata = session_state.setdefault("analysis_metadata", {}) | |
| metadata["last_question"] = question | |
| metadata["last_keyword_breakdown"] = keyword_result | |
| if search_result.get("matches"): | |
| metadata["last_schema_matches"] = search_result["matches"] | |
| logger.info(f"find_relevant_tables completed: {search_result.get('total_matches', 0)} matches found") | |
| return search_result | |
| def search_schema( | |
| self, | |
| run_context: Optional[RunContext] = None, | |
| tenant_id: Optional[str] = None, | |
| keywords: Optional[List[str]] = None, | |
| source_names: Optional[List[str]] = None, | |
| include_samples: bool = False, | |
| original_question: Optional[str] = None, | |
| keyword_metadata: Optional[Dict[str, List[str]]] = None | |
| ) -> Dict[str, Any]: | |
| """ | |
| Search for relevant tables across tenant data sources using keywords. | |
| Returns ranked tables based on keyword matches in table/column names and descriptions. | |
| Results are cached in session_state for reuse in follow-up queries. | |
| :param run_context: Agno RunContext containing session_state (auto-injected by framework). | |
| :param tenant_id: The ID of the tenant. MANDATORY. | |
| :param keywords: List of search keywords (e.g., ['user', 'order', 'payment']). MANDATORY. | |
| :param source_names: Optional list to filter specific sources. None = search all. | |
| :param include_samples: Whether to include example rows in results (default: False). | |
| :param original_question: Optional original user question for logging/analytics. | |
| :param keyword_metadata: Optional dict with keyword breakdown (base, semantic, concepts) for observability. | |
| :return: Dict with 'formatted_schema_string', 'matches', 'available_sources', 'total_matches' | |
| """ | |
| # Get session_state from run_context (Agno v2 pattern) | |
| session_state = run_context.session_state if run_context else None | |
| tenant_context = self._resolve_tenant(tenant_id, session_state) | |
| if isinstance(tenant_context, dict) and tenant_context.get("error"): | |
| return tenant_context | |
| tenant_id = tenant_context # type: ignore[assignment] | |
| # CRITICAL: Get per-request headers | |
| request_headers = self._get_request_headers(session_state) | |
| if not keywords or not isinstance(keywords, list): | |
| return self._normalize_error("Keywords must be a non-empty list.", "invalid_parameter", "Provide keywords as a list of strings.") | |
| # Normalize and cache key | |
| normalized_keywords = tuple(sorted([k.lower().strip() for k in keywords if k.strip()])) | |
| cache_key = (tenant_id, normalized_keywords, tuple(source_names or [])) | |
| # Check session_state cache first (across agent calls in same session) | |
| if session_state and "schema_search_cache" in session_state: | |
| cache_key_str = f"{tenant_id}|{normalized_keywords}|{source_names or []}" | |
| if cache_key_str in session_state["schema_search_cache"]: | |
| logger.info(f"Returning session-cached search results for tenant '{tenant_id}', keywords: {keywords}") | |
| cached = session_state["schema_search_cache"][cache_key_str].copy() | |
| cached['cache_hit'] = True | |
| cached['cache_source'] = 'session_state' | |
| metadata = session_state.setdefault("analysis_metadata", {}) | |
| metadata["last_schema_search"] = { | |
| "tenant_id": tenant_id, | |
| "keywords": list(normalized_keywords), | |
| "include_samples": include_samples, | |
| "total_matches": cached.get("total_matches", 0), | |
| "cache_hit": True, | |
| "cache_source": 'session_state' | |
| } | |
| metadata["last_schema_available_sources"] = cached.get("available_sources", []) | |
| return cached | |
| # Check toolkit-level cache (within single agent execution) | |
| if cache_key in self._search_cache: | |
| logger.info(f"Returning toolkit-cached search results for tenant '{tenant_id}', keywords: {keywords}") | |
| cached = self._search_cache[cache_key].copy() | |
| cached['cache_hit'] = True | |
| cached['cache_source'] = 'toolkit' | |
| if session_state is not None: | |
| metadata = session_state.setdefault("analysis_metadata", {}) | |
| metadata["last_schema_search"] = { | |
| "tenant_id": tenant_id, | |
| "keywords": list(normalized_keywords), | |
| "include_samples": include_samples, | |
| "total_matches": cached.get("total_matches", 0), | |
| "cache_hit": True, | |
| "cache_source": 'toolkit' | |
| } | |
| metadata["last_schema_available_sources"] = cached.get("available_sources", []) | |
| return cached | |
| logger.info(f"Tool: search_schema called for tenant '{tenant_id}' with keywords: {keywords}") | |
| try: | |
| url = "/api/v1/data-sources/schema/search" | |
| payload = { | |
| "tenant_id": tenant_id, | |
| "keywords": list(keywords), | |
| "include_samples": include_samples | |
| } | |
| if source_names: | |
| payload["source_names"] = source_names | |
| response = self.client.post(url, json=payload, timeout=DEFAULT_REQUEST_TIMEOUT, headers=request_headers) | |
| if response.status_code == 200: | |
| data = response.json() | |
| result = { | |
| "formatted_schema_string": data.get("formatted_schema_string", ""), | |
| "matches": data.get("matches", []), | |
| "available_sources": data.get("available_sources", []), | |
| "total_matches": data.get("total_matches", 0), | |
| "cache_hit": False | |
| } | |
| # Cache in toolkit-level cache | |
| self._search_cache[cache_key] = result.copy() | |
| # Cache in session_state if available | |
| if session_state: | |
| if "schema_search_cache" not in session_state: | |
| session_state["schema_search_cache"] = {} | |
| cache_key_str = f"{tenant_id}|{normalized_keywords}|{source_names or []}" | |
| session_state["schema_search_cache"][cache_key_str] = result.copy() | |
| logger.info(f"Schema search completed: {result['total_matches']} matches") | |
| return result | |
| else: | |
| error_text = response.text[:200] | |
| logger.error(f"Schema search failed: {error_text}") | |
| return self._normalize_error(f"Schema search failed: {error_text}", "api_error") | |
| except Exception as e: | |
| logger.exception(f"Search schema error: {e}") | |
| return self._normalize_error(f"Search error: {str(e)}", "UNEXPECTED_ERROR") | |
| def _fallback_schema_search(self, run_context: Optional[RunContext], tenant_id: str, keywords: List[str], source_names: Optional[List[str]]) -> Dict[str, Any]: | |
| """Fallback when /schema/search endpoint unavailable - uses basic filtering.""" | |
| logger.info("Using fallback schema search with client-side filtering") | |
| # Get all schemas | |
| full_result = self.get_available_sources_and_schema(run_context=run_context, tenant_id=tenant_id, keywords=None) | |
| if "error" in full_result: | |
| return full_result | |
| # Simple keyword matching on the formatted string | |
| schema_str = full_result.get("formatted_schema_string", "") | |
| available_sources = full_result.get("available_sources", []) | |
| # Basic filtering: check if any keyword appears in the schema | |
| matched = any(kw.lower() in schema_str.lower() for kw in keywords) | |
| if matched: | |
| return { | |
| "formatted_schema_string": schema_str, | |
| "matches": [], # No detailed matches in fallback | |
| "available_sources": available_sources, | |
| "total_matches": len(available_sources), | |
| "cache_hit": False, | |
| "fallback_mode": True | |
| } | |
| else: | |
| return { | |
| "formatted_schema_string": f"No tables found matching keywords: {', '.join(keywords)}\n\nAvailable sources: {', '.join(available_sources)}", | |
| "matches": [], | |
| "available_sources": available_sources, | |
| "total_matches": 0, | |
| "cache_hit": False, | |
| "fallback_mode": True | |
| } | |
| # --- Agent Tools --- | |
| def get_available_sources_and_schema( | |
| self, | |
| run_context: Optional[RunContext] = None, | |
| tenant_id: Optional[str] = None, | |
| keywords: Optional[List[str]] = None | |
| ) -> Dict[str, Any]: | |
| """ | |
| Retrieves the list of available data source names and their schemas for the specified tenant. | |
| Use this first to understand which data sources ('source_name') and tables are available. | |
| :param run_context: Agno RunContext containing session_state (auto-injected by framework). | |
| :param tenant_id: The ID of the tenant (e.g., 'acme-corp'). MANDATORY. | |
| :param keywords: Optional list of keywords to filter relevant tables. If provided, uses search_schema. | |
| :return: Dict containing 'formatted_schema_string' (all schemas combined) and 'available_sources' (list of names). Returns 'error' on failure. | |
| """ | |
| # Get session_state from run_context (Agno v2 pattern) | |
| session_state = run_context.session_state if run_context else None | |
| tenant_context = self._resolve_tenant(tenant_id, session_state) | |
| if isinstance(tenant_context, dict) and tenant_context.get("error"): | |
| return tenant_context | |
| tenant_id = tenant_context # type: ignore[assignment] | |
| # CRITICAL: Get per-request headers | |
| request_headers = self._get_request_headers(session_state) | |
| # If keywords provided, delegate to search_schema with run_context | |
| if keywords: | |
| logger.info(f"Tool: get_available_sources_and_schema called with keywords - delegating to search_schema") | |
| return self.search_schema(run_context=run_context, tenant_id=tenant_id, keywords=keywords) | |
| logger.info(f"Tool: get_available_sources_and_schema called for tenant '{tenant_id}'") | |
| # Use the dedicated list_sources endpoint to get available sources | |
| list_result = self.list_sources(run_context=run_context, tenant_id=tenant_id) | |
| # Check if list_sources returned an error | |
| if "error" in list_result: | |
| logger.error(f"Failed to list sources for tenant '{tenant_id}': {list_result.get('error')}") | |
| return list_result # Return the normalized error directly | |
| # Extract source names from the response | |
| source_names = list_result.get("available_sources", []) | |
| if not source_names: | |
| logger.warning(f"No source configurations found for tenant '{tenant_id}'.") | |
| return { | |
| "formatted_schema_string": "No data sources configured for this tenant.", | |
| "available_sources": [] | |
| } | |
| logger.info(f"Found {len(source_names)} sources for tenant '{tenant_id}': {source_names}") | |
| # Fetch schema for each identified source name | |
| all_schemas_formatted = [] | |
| fetch_errors = [] | |
| for source_name in source_names: | |
| try: | |
| schema_url = f"/api/v1/data-sources/schema/{source_name}" | |
| logger.debug(f"Fetching schema from {schema_url}") | |
| schema_response = self.client.get( | |
| schema_url, | |
| timeout=DEFAULT_REQUEST_TIMEOUT, | |
| headers=request_headers # Use per-request headers | |
| ) | |
| if schema_response.status_code == 200: | |
| schema_data = schema_response.json() | |
| schema_json_string = schema_data.get("schema_data") | |
| formatted_schema = self._format_schema_for_llm(source_name, schema_json_string) | |
| all_schemas_formatted.append(formatted_schema) | |
| elif schema_response.status_code == 404: | |
| logger.warning(f"Schema API returned 404 for source '{source_name}', tenant '{tenant_id}'.") | |
| all_schemas_formatted.append(f"**Source: `{source_name}`**\n Schema: <Not Found for Tenant>") | |
| fetch_errors.append(source_name) | |
| else: | |
| error_text = schema_response.text[:200] # Limit error text length | |
| logger.warning(f"Failed to get schema for source '{source_name}' (Status {schema_response.status_code}): {error_text}") | |
| all_schemas_formatted.append(f"**Source: `{source_name}`**\n Schema: <Error: {schema_response.status_code}>") | |
| fetch_errors.append(source_name) | |
| except httpx.RequestError as e: | |
| logger.error(f"API connection error fetching schema for '{source_name}': {e}") | |
| all_schemas_formatted.append(f"**Source: `{source_name}`**\n Schema: <API Connection Error>") | |
| fetch_errors.append(source_name) | |
| except Exception as e: # Catch other errors during processing for a single source | |
| logger.exception(f"Unexpected error processing schema for '{source_name}': {e}") | |
| all_schemas_formatted.append(f"**Source: `{source_name}`**\n Schema: <Unexpected Error Processing Schema>") | |
| fetch_errors.append(source_name) | |
| # Combine successfully fetched schemas into one string | |
| # Separate each source's schema clearly | |
| final_schema_string = "\n\n---\n\n".join(all_schemas_formatted) | |
| # Report any errors clearly | |
| if fetch_errors: | |
| error_message = f"Note: Failed to retrieve schema for the following sources: {', '.join(fetch_errors)}." | |
| final_schema_string += f"\n\n**Warning:** {error_message}" | |
| # Return the combined string and the list of sources attempted/found | |
| return { | |
| "formatted_schema_string": final_schema_string, | |
| "available_sources": source_names # Return the names found via config | |
| } | |
def get_source_instructions(
    self,
    run_context: Optional[RunContext] = None,
    tenant_id: Optional[str] = None,
    source_name: Optional[str] = None
) -> Dict[str, Any]:
    """
    Retrieves special instructions (like SQL dialect, syntax rules, function usage)
    for a specific data source. Call this *after* choosing a source_name
    from get_available_sources_and_schema and *before* writing SQL.

    **IMPORTANT**: Instructions are automatically cached in session_state for the agent
    to use during SQL generation. The agent can access them via:
        session_state['source_instructions'][source_name]
    The tenant each entry was fetched for is recorded in
    session_state['source_instructions_tenants'][source_name]; a cached entry is
    only reused when it was fetched for the same tenant.

    :param run_context: Agno RunContext containing session_state (auto-injected by framework).
    :param tenant_id: The ID of the tenant (e.g., 'acme-corp'). Resolved from session_state when omitted.
    :param source_name: The name of the data source (e.g., 'production_db'). Resolved from session_state when omitted.
    :return: Dictionary containing an 'instructions' string, or an error dictionary
             produced by tenant/source resolution or ``_normalize_error``.
    """
    default_instructions = "No specific instructions provided for this source."

    # Get session_state from run_context (Agno v2 pattern)
    session_state = run_context.session_state if run_context else None

    tenant_context = self._resolve_tenant(tenant_id, session_state)
    if isinstance(tenant_context, dict) and tenant_context.get("error"):
        return tenant_context
    tenant_id = tenant_context  # type: ignore[assignment]

    source_context = self._resolve_source(source_name, session_state)
    if isinstance(source_context, dict) and source_context.get("error"):
        return source_context
    source_name = source_context  # type: ignore[assignment]

    # CRITICAL: Get per-request headers
    request_headers = self._get_request_headers(session_state)

    logger.info(f"Tool: get_source_instructions called for tenant '{tenant_id}', source '{source_name}'")

    # Check the session_state cache first. The cache is keyed by source name only and
    # different tenants may use the same source name, so an entry is reused only when
    # it was recorded for the tenant being served now.
    if session_state:
        cache = session_state.get("source_instructions")
        owners = session_state.get("source_instructions_tenants")
        if isinstance(cache, dict) and isinstance(owners, dict):
            cached_instructions = cache.get(source_name)
            if cached_instructions and owners.get(source_name) == tenant_id:
                logger.info(f"Returning session-cached instructions for source '{source_name}'")
                return {"instructions": cached_instructions}

    try:
        url = f"/api/v1/data-sources/instructions/{source_name}"
        logger.debug(f"Fetching instructions from {url}")
        response = self.client.get(
            url,
            timeout=DEFAULT_REQUEST_TIMEOUT,
            headers=request_headers  # Use per-request headers
        )

        if response.status_code == 200:
            data = response.json()
            # `or` instead of a .get() default so an explicit null/empty value from
            # the API also falls back to the defaults.
            instructions = data.get("instructions") or default_instructions
            connector_type = data.get("connector_type") or "unknown"
            logger.info(f"Successfully fetched instructions for '{source_name}' ({connector_type}).")

            # Store in session_state for agent reuse
            if session_state is not None:
                cache = session_state.get("source_instructions")
                if not isinstance(cache, dict):
                    cache = session_state["source_instructions"] = {}
                cache[source_name] = instructions

                # Remember which tenant this entry belongs to (checked on cache reads).
                owners = session_state.get("source_instructions_tenants")
                if not isinstance(owners, dict):
                    owners = session_state["source_instructions_tenants"] = {}
                owners[source_name] = tenant_id

                # Also track metadata
                metadata = session_state.setdefault("analysis_metadata", {})
                metadata["last_source_instructions"] = {
                    "tenant_id": tenant_id,
                    "source_name": source_name,
                    "connector_type": connector_type,
                    "timestamp": datetime.now().isoformat()
                }
                logger.debug(f"Cached instructions for '{source_name}' in session_state")

            return {"instructions": instructions}

        if response.status_code == 404:
            logger.warning(f"Instructions API returned 404 for source '{source_name}', tenant '{tenant_id}'.")
            return self._normalize_error(
                f"Instructions not found for source '{source_name}' for this tenant.",
                "not_found",
                "Verify the source_name exists for this tenant via list_sources()."
            )

        error_text = response.text[:200]
        logger.error(f"Failed to get instructions for '{source_name}' (Status {response.status_code}): {error_text}")
        return self._normalize_error(
            f"Failed to get instructions (Status {response.status_code}): {error_text}",
            "api_error",
            "Check API logs and verify the /instructions endpoint is operational."
        )
    except httpx.RequestError as e:
        logger.error(f"API connection error getting instructions for '{source_name}': {e}")
        return self._normalize_error(f"API connection error: {e}", "connection_error", "Ensure the data sources API is running and accessible.")
    except Exception as e:
        logger.exception(f"Unexpected error in get_source_instructions for '{source_name}': {e}")
        return self._normalize_error(f"An unexpected error occurred while fetching instructions: {e}", "internal_error", "Check logs for detailed stack trace.")
def execute_sql_query(
    self,
    run_context: Optional[RunContext] = None,
    tenant_id: Optional[str] = None,
    source_name: Optional[str] = None,
    sql_query: str = "",
    async_mode: bool = False,
    max_rows: Optional[int] = None,
    timeout_seconds: Optional[int] = None
) -> Dict[str, Any]:
    """
    Executes a raw SQL query against a specific data source for the tenant.
    Supports synchronous execution (default) and async job submission.

    If the API answers with ``results`` as a ``minio://`` or ``s3://`` URI (a result
    set stored in object storage), the object is downloaded and the first rows plus a
    ``DataFrame.describe`` summary are returned; the URI is echoed back under
    ``storage_path``.

    :param run_context: Agno RunContext containing session_state (auto-injected by framework).
    :param tenant_id: Tenant identifier. Resolved from session_state when omitted.
    :param source_name: Target data source. Resolved from session_state when omitted.
    :param sql_query: Raw SQL string to execute.
    :param async_mode: When True, enqueue query and return job metadata.
    :param max_rows: Optional server-side row limit override.
    :param timeout_seconds: Optional execution timeout override. When it exceeds the
        default HTTP timeout, the HTTP timeout is raised to cover it (plus headroom).
    :return: Dictionary with execution payload or normalized error.
    """
    # Get session_state from run_context (Agno v2 pattern)
    session_state = run_context.session_state if run_context else None

    tenant_context = self._resolve_tenant(tenant_id, session_state)
    if isinstance(tenant_context, dict) and tenant_context.get("error"):
        return tenant_context
    tenant_id = tenant_context  # type: ignore[assignment]

    source_context = self._resolve_source(source_name, session_state)
    if isinstance(source_context, dict) and source_context.get("error"):
        return source_context
    source_name = source_context  # type: ignore[assignment]

    # CRITICAL: Get per-request headers
    request_headers = self._get_request_headers(session_state)

    if not sql_query or not sql_query.strip():
        return self._normalize_error("SQL query cannot be empty.", "missing_parameter", "Provide a non-empty sql_query string.")

    logger.info(f"Tool: execute_sql_query called for tenant '{tenant_id}' on source '{source_name}'. Query: {sql_query[:100]}...")

    # Basic client-side check for obviously disallowed operations (optional layer).
    # The primary security relies on the DB user permissions configured in the backend.
    if self._sql_looks_like_write(sql_query):
        logger.warning(f"Potentially unsafe SQL query blocked client-side for tenant {tenant_id}: {sql_query[:100]}...")
        return self._normalize_error(
            "Query blocked: Operation potentially modifies data or structure. Only read-only queries (SELECT, SHOW, etc.) are allowed.",
            "forbidden_operation",
            "Rewrite query using SELECT, SHOW, DESCRIBE, or EXPLAIN commands only."
        )

    payload: Dict[str, Any] = {
        "tenant_id": tenant_id,
        "source_name": source_name,
        "sql_query": sql_query  # Send the raw SQL
    }
    if async_mode:
        payload["async_mode"] = True
    if max_rows is not None:
        payload["max_rows"] = max_rows
    if timeout_seconds is not None:
        payload["timeout_seconds"] = timeout_seconds

    # Don't let the HTTP client give up before the server-side execution budget the
    # caller asked for; 10s of headroom covers result serialization and transfer.
    http_timeout = DEFAULT_API_TIMEOUT
    if isinstance(timeout_seconds, (int, float)) and timeout_seconds > 0:
        http_timeout = max(DEFAULT_API_TIMEOUT, float(timeout_seconds) + 10.0)

    url = "/api/v1/data-sources/execute-raw-sql"
    try:
        logger.debug(f"Posting raw SQL to {url}")
        response = self.client.post(
            url,
            json=payload,
            timeout=http_timeout,
            headers=request_headers  # Use per-request headers
        )

        if response.status_code != 200:
            error_detail = self._sql_api_error_detail(response)
            logger.error(f"SQL execution API failed for tenant '{tenant_id}' (Status {response.status_code}): {error_detail}")
            # Provide a structured error back to the agent
            return self._normalize_error(
                f"SQL Execution Failed (Status {response.status_code}): {error_detail}",
                "execution_error",
                "Check query syntax and ensure the source is accessible for this tenant."
            )

        data = response.json()
        # A missing or null status on a 200 response is treated as success.
        status = str(data.get("status") or "success")

        if status.lower() != "success":
            # Async submission or other non-terminal status: return the API payload as-is.
            logger.info(f"SQL execution acknowledged with status '{status}' for tenant '{tenant_id}'")
            self._record_sql_execution_metadata(session_state, sql_query, {
                "tenant_id": tenant_id,
                "source_name": source_name,
                "status": status,
                "async_mode": data.get("async_mode", async_mode),
                "job_id": data.get("job_id")
            })
            return data

        results = data.get("results", [])
        if isinstance(results, str) and results.startswith(("minio://", "s3://")):
            return self._summarize_stored_sql_result(
                results, session_state, sql_query, tenant_id, source_name, async_mode
            )

        row_count = len(results) if isinstance(results, list) else data.get("rows_returned", 0)
        logger.info(f"SQL query executed successfully for tenant '{tenant_id}', returned {row_count} rows.")
        self._record_sql_execution_metadata(session_state, sql_query, {
            "tenant_id": tenant_id,
            "source_name": source_name,
            "row_count": row_count,
            "async_mode": async_mode
        })
        return {
            "status": status,
            "results": results,
            "rows_returned": data.get("rows_returned", row_count),
            "rows_limited": data.get("rows_limited", False),
            "execution_time_ms": data.get("execution_time_ms")
        }
    except httpx.TimeoutException:
        logger.error(f"API timeout executing SQL for tenant '{tenant_id}' on source '{source_name}'.")
        return self._normalize_error(
            f"API request timed out after {http_timeout} seconds. The query might be too long-running for synchronous execution.",
            "timeout_error",
            "Consider optimizing the query or using async job submission for long-running queries."
        )
    except httpx.RequestError as e:
        logger.error(f"API connection error executing SQL for tenant '{tenant_id}': {e}")
        return self._normalize_error(f"API connection error during SQL execution: {e}", "connection_error", "Ensure the data sources API is running and accessible.")
    except Exception as e:
        logger.exception(f"Unexpected error in execute_sql_query for tenant '{tenant_id}': {e}")
        return self._normalize_error(
            f"An unexpected client-side error occurred during SQL execution: {e}",
            "internal_error",
            "Check logs for detailed stack trace."
        )

@staticmethod
def _sql_looks_like_write(sql_query: str) -> bool:
    """Return True if any ';'-separated statement looks data/structure-modifying.

    This is a best-effort client-side guard. A statement is flagged when it mentions
    a write keyword as a whole word (any surrounding whitespace) and, after skipping
    leading comments, does not start with a read-only command. Splitting on ';' ignores
    string literals, so a ';' inside a quoted value can cause a false positive. Writes
    hidden inside an allowed statement (e.g. a data-modifying CTE under WITH) are not
    detected; backend DB permissions remain the real protection.
    """
    import re

    write_keyword = re.compile(r"\b(?:drop|delete|truncate|alter|insert|update|grant|revoke)\b")
    leading_comments = re.compile(r"^(?:\s*(?:--[^\n]*|/\*.*?\*/))*\s*", re.DOTALL)
    read_only_starts = ('select', 'with', 'show', 'describe', 'explain')

    for statement in sql_query.lower().split(";"):
        if not write_keyword.search(statement):
            continue
        if not leading_comments.sub("", statement, count=1).startswith(read_only_starts):
            return True
    return False

@staticmethod
def _sql_api_error_detail(response) -> Any:
    """Extract the most useful error detail from a failed execute-raw-sql response."""
    try:
        error_json = response.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError both subclass ValueError
        # If response is not JSON, use raw text
        return response.text[:500]  # Limit length
    # FastAPI often puts errors in {"detail": "..."}
    if isinstance(error_json, dict) and "detail" in error_json:
        return error_json["detail"]
    # Otherwise, convert whole JSON to string
    return json.dumps(error_json)

def _summarize_stored_sql_result(
    self,
    results_uri: str,
    session_state: Optional[Dict[str, Any]],
    sql_query: str,
    tenant_id: Any,
    source_name: Any,
    async_mode: bool
) -> Dict[str, Any]:
    """Download an object-storage result set and return a row preview plus a statistical summary.

    Never raises: download/parse failures produce a success payload with empty
    results and an explanatory message, as before.
    """
    logger.info(f"Large result detected at {results_uri}. Generating summary for LLM...")
    preview_rows = 200
    try:
        import io
        import pandas as pd
        from backend.core.minio.config import get_minio_config
        from minio import Minio

        mconf = get_minio_config()
        minio_client = Minio(
            endpoint=mconf['endpoint'],
            access_key=mconf['access_key'],
            secret_key=mconf['secret_key'],
            secure=mconf.get('secure', False)
        )
        # Drop the scheme ('minio://' or 's3://'), then split "<bucket>/<object key>".
        bucket, obj_name = results_uri.split("://", 1)[1].split("/", 1)
        response_obj = minio_client.get_object(bucket, obj_name)
        try:
            file_content = response_obj.read()
        finally:
            # The MinIO SDK requires closing the response and releasing its pooled connection.
            response_obj.close()
            response_obj.release_conn()

        df = pd.read_json(io.BytesIO(file_content))
        row_count = len(df)
        summary_frame = df.describe(include='all')
        try:
            description = summary_frame.to_markdown()
        except ImportError:
            # DataFrame.to_markdown needs the optional 'tabulate' package.
            description = summary_frame.to_string()
        head_data = df.head(preview_rows).to_dict(orient='records')
    except Exception as fetch_err:
        logger.error(f"Failed to fetch/summarize large result: {fetch_err}")
        return {
            "status": "success",
            "results": [],
            "message": f"Result stored at {results_uri}, but could not be downloaded for summary. ({str(fetch_err)})",
            "row_count": "Unknown",
            "rows_limited": True,
            "storage_path": results_uri
        }

    self._record_sql_execution_metadata(session_state, sql_query, {
        "tenant_id": tenant_id,
        "source_name": source_name,
        "row_count": row_count,
        "async_mode": async_mode
    })
    return {
        "status": "success",
        "results": head_data,
        "summary": description,
        "message": (
            f"Result too large ({row_count} rows). "
            f"Returned first {preview_rows} rows and statistical summary. "
            f"Full data stored at {results_uri}."
        ),
        "row_count": row_count,
        "rows_limited": True,
        "storage_path": results_uri
    }

@staticmethod
def _record_sql_execution_metadata(
    session_state: Optional[Dict[str, Any]],
    sql_query: str,
    execution: Dict[str, Any]
) -> None:
    """Store the latest execution details and query text under session_state['analysis_metadata']."""
    if session_state is None:
        return
    metadata = session_state.setdefault("analysis_metadata", {})
    metadata["last_sql_execution"] = execution
    metadata["last_sql_query"] = sql_query.strip()
def compose_dataset_workflow(
    self,
    run_context: Optional[RunContext] = None,
    tenant_id: Optional[str] = None,
    name: str = "dataset",
    target_format: str = "parquet",
    steps: Optional[List[Dict[str, Any]]] = None,
    keep_intermediates: bool = False,
) -> Dict[str, Any]:
    """Compose a dataset from SQL/file workflow steps via the tenant-files API.

    Posts ``name``, ``target_format``, ``steps`` and ``keep_intermediates`` to
    ``/api/v1/tenant-files/datasets/compose``. The tenant is resolved only to
    validate it; it is not part of the payload (NOTE(review): presumably the
    per-request headers carry tenant identity — confirm).

    :param run_context: Agno RunContext containing session_state (auto-injected by framework).
    :param tenant_id: Tenant identifier. Resolved from session_state when omitted.
    :param name: Name for the composed dataset.
    :param target_format: Output format requested from the API (default 'parquet').
    :param steps: Non-empty list of workflow step dictionaries. Required.
    :param keep_intermediates: Forwarded to the API unchanged.
    :return: The API's JSON response on HTTP 200/201, otherwise a normalized error.
    """
    session_state = run_context.session_state if run_context else None

    tenant_context = self._resolve_tenant(tenant_id, session_state)
    if isinstance(tenant_context, dict) and tenant_context.get("error"):
        return tenant_context

    if not steps:
        return self._normalize_error(
            "steps must be provided",
            "missing_parameter",
            "Provide a non-empty list of workflow step dictionaries."
        )

    request_headers = self._get_request_headers(session_state)
    payload = {
        "name": name,
        "target_format": target_format,
        "steps": steps,
        "keep_intermediates": keep_intermediates,
    }

    try:
        response = self.client.post(
            "/api/v1/tenant-files/datasets/compose",
            json=payload,
            timeout=DEFAULT_REQUEST_TIMEOUT,
            headers=request_headers,
        )
        if response.status_code in (200, 201):
            return response.json()

        error_text = response.text[:200]
        logger.error(f"Failed to compose dataset '{name}' (Status {response.status_code}): {error_text}")
        return self._normalize_error(
            f"Failed to compose dataset (Status {response.status_code}): {error_text}",
            "api_error",
            "Check the workflow steps and the tenant-files API logs."
        )
    except httpx.RequestError as e:
        # Also covers httpx timeouts (TimeoutException subclasses RequestError).
        logger.error(f"API connection error composing dataset '{name}': {e}")
        return self._normalize_error(f"API connection error: {e}", "connection_error", "Ensure the data sources API is running and accessible.")
    except Exception as e:
        logger.exception(f"Error in compose_dataset_workflow: {e}")
        return self._normalize_error(f"Error: {str(e)}", "internal_error", "Check logs for detailed stack trace.")
def save_query_to_tenant_csv(
    self,
    run_context: Optional[RunContext] = None,
    tenant_id: Optional[str] = None,
    source_name: Optional[str] = None,
    sql_query: str = "",
    dataset_name: str = "query_export"
) -> str:
    """
    Executes a raw SQL query and directly saves the full results as a CSV in the
    tenant's MinIO storage. This bridges Data Sources with ML capabilities!

    Use this tool when you need to run Machine Learning, Data Profiling, or deep
    pandas analysis on SQL data, because this returns a file path that can be
    given to ML/pandas tools.

    If execute_sql_query reports ``rows_limited`` (a server-side row cap, or an
    oversized result replaced by a preview), only the rows actually returned are
    saved and the returned message carries a warning.

    :param run_context: Agno RunContext (auto-injected).
    :param tenant_id: Tenant ID (auto-resolved from session).
    :param source_name: Source name (auto-resolved from session).
    :param sql_query: The SQL query to run.
    :param dataset_name: A descriptive name for the resulting CSV file. Characters other
        than letters, digits, '_' and '-' are replaced with '_'.
    :return: A human-readable status string; on success it contains the stored path / asset ID.
    """
    logger.info(f"save_query_to_tenant_csv called for dataset '{dataset_name}'")
    # Let it run without strict max_rows so we get the real dataset
    result = self.execute_sql_query(
        run_context=run_context,
        tenant_id=tenant_id,
        source_name=source_name,
        sql_query=sql_query,
        async_mode=False
    )
    if "error" in result:
        return f"Failed to execute SQL: {result['error']}"

    raw_data = result.get("results")
    # Results may already be an object-storage URI instead of inline rows.
    if isinstance(raw_data, str) and raw_data.startswith(("minio://", "s3://")):
        return f"Data successfully saved to object storage. Path to use for ML/Analysis tools: `{raw_data}`"
    if not raw_data or not isinstance(raw_data, list):
        return "Query returned no data or an invalid format. Cannot save to CSV."

    try:
        import io
        import re
        import pandas as pd

        # Resolve tenant explicitly
        session_state = run_context.session_state if run_context else None
        resolved_tenant = self._resolve_tenant(tenant_id, session_state)
        if isinstance(resolved_tenant, dict):
            return "Could not resolve tenant_id for saving file."

        df = pd.DataFrame(raw_data)
        csv_buffer = io.BytesIO()
        df.to_csv(csv_buffer, index=False)
        csv_buffer.seek(0)

        # dataset_name comes from the agent: keep only filename-safe characters so
        # path separators or dots cannot alter the stored object's name.
        safe_name = re.sub(r"[^\w\-]+", "_", str(dataset_name or "").strip().lower()).strip("_")
        safe_name = safe_name or "query_export"
        timestamp = int(datetime.now().timestamp())
        filename = f"{safe_name}_{timestamp}.csv"

        request_headers = self._get_request_headers(session_state)
        # Remove Content-Type so httpx sets the multipart boundary correctly
        headers = {k: v for k, v in request_headers.items() if k.lower() != "content-type"}
        files = {"file": (filename, csv_buffer, "text/csv")}
        data = {"metadata": json.dumps({"source": "sql_agent", "label": "query_export"})}

        response = self.client.post(
            "/api/v1/tenant-files/assets",
            files=files,
            data=data,
            headers=headers,
            timeout=DEFAULT_REQUEST_TIMEOUT
        )

        if response.status_code not in (200, 201):
            error_text = response.text[:200]
            return f"Query succeeded, but saving to tenant-files API failed (HTTP {response.status_code}): {error_text}"

        asset = response.json().get("asset") or {}
        asset_id = asset.get("asset_id", "")
        final_path = asset.get("path", filename)
        message = (
            f"Data ({len(df)} rows, {len(df.columns)} columns) successfully saved to tenant storage. "
            f"Path/Asset ID to use for ML/Analysis tools: `{final_path}` or `{asset_id}`"
        )
        if result.get("rows_limited"):
            # The saved CSV holds only the rows execute_sql_query returned; its message
            # may point at the location of the complete data.
            message += " Warning: the query result was row-limited, so this CSV may not contain every row."
            note = result.get("message")
            if note:
                message += f" ({note})"
        return message
    except ImportError:
        return "Failed to import pandas. Cannot save CSV locally."
    except Exception as e:
        logger.exception(f"Error saving query to CSV: {e}")
        return f"Query succeeded, but saving to CSV failed: {str(e)}"