# sirus/backend/SQL_Agent/oldAgent.py
# Last commit b8277c4 by ranilmukesh: "Deploy SiRUS SQL Agent backend"
"""
⚠️ DEPRECATED - DO NOT USE IN PRODUCTION ⚠️
This file (oldAgent.py) is DEPRECATED and maintained only for reference.
MIGRATION GUIDE:
- Use agent.py for all new development
- agent.py provides the same functionality with:
* Production-grade API integration
* Multi-tenant isolation via session state
* No direct database connections
* Centralized data access through data_sources API
* Better error handling and logging
* API key authentication support
This file will be REMOVED in a future release.
Please migrate all code to use agent.py instead.
Last Updated: October 2025
Deprecation Status: ACTIVE - Do not use for new features
Removal Target: Next major release
"""
import os
import json
import re
import logging
from typing import Dict, Set, Any, List, Optional
from collections import defaultdict
from functools import lru_cache
from textwrap import dedent
from urllib.parse import quote_plus
from datetime import datetime

from sqlalchemy import create_engine, text
from sqlalchemy.exc import DBAPIError
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
import pandas as pd
# Updated imports for comprehensive tracking
from agno.db.sqlite import SqliteDb # Changed from InMemoryDb for persistence
from agno.agent import Agent
from agno.models.google import Gemini
from agno.tools import Toolkit
from agno.tools.reasoning import ReasoningTools
from agno.run.response import RunContext
# --- Database / runtime configuration ---------------------------------------
# Every setting can be overridden with an environment variable of the same name.
# The database password is deliberately NOT stored in source: supply it through
# DB_PASSWORD. Any credential that was ever committed here must be treated as
# leaked and rotated.
DB_DIALECT = os.environ.get("DB_DIALECT", "mysql+pymysql")
DB_USER = os.environ.get("DB_USER", "root")
DB_PASSWORD = os.environ.get("DB_PASSWORD", "")
DB_HOST = os.environ.get("DB_HOST", "65.0.127.253")
DB_PORT = os.environ.get("DB_PORT", "3306")
DB_NAME = os.environ.get("DB_NAME", "bookwedgo")
SCHEMA_PATH = os.environ.get("SCHEMA_PATH", r"/content/database_schema.json")

if "GOOGLE_API_KEY" not in os.environ:
    print("🔴 WARNING: GOOGLE_API_KEY environment variable not set. The agent will fail.")
if not DB_PASSWORD:
    print("🔴 WARNING: DB_PASSWORD environment variable not set. Database access will likely fail.")

# Percent-encode the password so characters such as '@', ':' or '/' cannot
# break the structure of the connection URL.
encoded_password = quote_plus(DB_PASSWORD)
DB_URL = f"{DB_DIALECT}://{DB_USER}:{encoded_password}@{DB_HOST}:{DB_PORT}/{DB_NAME}"

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Schema metadata index (_SchemaManager) and the agent's database tools (DatabaseToolkit).
class _SchemaManager:
    """Index a JSON schema file for keyword-based table discovery.

    Two lookup structures are built once, at construction time:

    * ``relationship_graph``: ``{table: {related tables}}``. Each entry in a
      table's ``relationships`` list with a ``referenced_table`` links both
      tables in both directions.
    * ``keyword_to_columns``: ``{keyword: {table_name: {column names}}}``.
      Keywords are the lower-cased words of table/column identifiers, the
      whole lower-cased identifier, and words of 3+ letters from column
      descriptions. The special column ``'__table__'`` marks a match on the
      table name itself.

    The schema file may hold either a dict with a ``tables`` list or a list
    whose first element is such a dict.
    """

    # Splits identifiers into words, handling snake_case, camelCase and
    # acronyms: "booking_id" -> booking/id, "vendorID" -> vendor/ID,
    # "HTTPStatus2" -> HTTP/Status/2.
    _WORD_RE = re.compile(r"[A-Z]+(?![a-z])|[A-Z]?[a-z]+|\d+")

    def __init__(self, schema_path: str):
        if not os.path.exists(schema_path):
            raise FileNotFoundError(f"Schema file not found: {schema_path}. Please upload it.")
        with open(schema_path, 'r', encoding='utf-8') as f:
            self.schema_data = json.load(f)
        # Build the lookup structures once; nothing in this class mutates them later.
        self.relationship_graph = self._build_relationship_graph()
        self.keyword_to_columns = self._create_keyword_mappings()
        logger.info("SchemaManager initialized with Relationship Graph and Column-Level Mappings.")

    def _get_tables(self) -> List[Dict[str, Any]]:
        """Return the table definitions, whatever the top-level shape of the file."""
        data = self.schema_data
        if isinstance(data, list):
            # Some exports wrap the schema object in a one-element list.
            data = data[0] if data else {}
        if not isinstance(data, dict):
            return []
        return data.get('tables') or []

    @classmethod
    def _identifier_keywords(cls, identifier: str) -> Set[str]:
        """Return the lower-cased words of ``identifier`` plus the whole identifier."""
        if not identifier:
            return set()
        words = {part.lower() for part in cls._WORD_RE.findall(identifier)}
        # Also index the full name so a keyword like "booking_id" matches directly.
        words.add(identifier.lower())
        return words

    def _build_relationship_graph(self) -> Dict[str, Set[str]]:
        """Build an undirected adjacency map from every table's relationships."""
        graph: Dict[str, Set[str]] = defaultdict(set)
        for table in self._get_tables():
            table_name = table.get('table_name')
            if not table_name:
                continue
            for rel in table.get('relationships') or []:
                referenced_table = rel.get('referenced_table')
                if referenced_table:
                    # Two-way link so either side finds the other.
                    graph[table_name].add(referenced_table)
                    graph[referenced_table].add(table_name)
        return graph

    def _create_keyword_mappings(self) -> Dict[str, Dict[str, Set[str]]]:
        """Map each keyword to the tables and columns whose name/description contains it."""
        mappings: Dict[str, Dict[str, Set[str]]] = defaultdict(lambda: defaultdict(set))
        for table in self._get_tables():
            table_name = table.get('table_name')
            if not table_name:
                # Nameless tables can never be selected later, so don't index them.
                continue
            for word in self._identifier_keywords(table_name):
                mappings[word][table_name].add('__table__')  # match on the table name itself
            for field in table.get('fields') or []:
                col_name = field.get('name') or ''
                for word in self._identifier_keywords(col_name):
                    mappings[word][table_name].add(col_name)
                # `or ''` guards against an explicit null description in the JSON.
                for word in re.findall(r'[a-zA-Z]{3,}', field.get('description') or ''):
                    mappings[word.lower()][table_name].add(col_name)
        return mappings

    def get_related_tables(self, table_name: str) -> Set[str]:
        """Return tables linked to ``table_name`` by a relationship in either direction."""
        return self.relationship_graph.get(table_name, set())
# --- DatabaseToolkit leveraging the new _SchemaManager ---
class DatabaseToolkit(Toolkit):
    """Agno toolkit that lets the agent discover the schema and run SQL.

    Registers two tools:

    * ``get_filtered_database_schema``: maps keywords to tables through
      ``_SchemaManager``, adds directly related tables, and returns a Markdown
      description including up to three live sample rows per table.
    * ``execute_sql``: lightly sanitizes a query, runs it, and returns the
      rows, or an error message the agent can use to self-correct.
    """

    # Number of live rows sampled per table for the schema description.
    _SAMPLE_ROW_LIMIT = 3

    # Bare-column SUM(...) calls such as SUM(b.amount). The leading \b keeps
    # the pattern from matching inside longer names such as CHECKSUM(...).
    _SUM_OF_COLUMN_RE = re.compile(r'\b(SUM)\(\s*([a-zA-Z0-9_\.]+)\s*\)', re.IGNORECASE)

    def __init__(self, schema_path: str, db_url: str):
        super().__init__(name="database_tools", tools=[self.get_filtered_database_schema, self.execute_sql])
        self.schema_manager = _SchemaManager(schema_path)
        # pool_pre_ping checks each pooled connection before use and replaces it
        # if the server has closed it, instead of failing the next query.
        self.engine = create_engine(db_url, pool_pre_ping=True)

    @staticmethod
    def _unwrap_schema(schema_data: Any) -> Dict[str, Any]:
        """Return the schema object whether the JSON file holds it directly or in a list."""
        if isinstance(schema_data, list):
            schema_data = schema_data[0] if schema_data else {}
        return schema_data if isinstance(schema_data, dict) else {}

    def _format_schema_for_llm(self, schema: Dict[str, Any]) -> str:
        """Render a ``{"name": ..., "tables": [...]}`` schema as Markdown for the LLM.

        Each table contributes:
        * its description;
        * its columns (type, example value, description);
        * its relationships;
        * when present, ``example_rows`` rendered as a text table.
        """
        tables = schema.get("tables")
        if not tables:
            return "No relevant tables found."
        output_parts = [f"Database Schema: {schema.get('name', 'N/A')}\n"]
        for table in tables:
            table_name = table.get('table_name', 'N/A')
            description = (table.get('description') or '').strip()
            output_parts.append("---")
            output_parts.append(f"\n**Table: `{table_name}`**")
            if description:
                output_parts.append(f"*Description: {description}*")
            output_parts.append("\n**Columns:**")
            for field in table.get('fields') or []:
                col_name = field.get('name', 'N/A')
                col_type = field.get('type', 'N/A')
                col_example = field.get('example', 'N/A')
                # Some descriptions contain a templated sentence that only
                # restates the column and table name; drop it.
                boilerplate = f"This is the '{col_name}' column of the '{table_name}' table."
                col_desc = (field.get('description') or 'No description.').replace(boilerplate, "").strip()
                output_parts.append(f"- `{col_name}` (type: {col_type}, ex: '{col_example}') - {col_desc}")
            relationships = table.get("relationships")
            if relationships:
                output_parts.append("\n**Relationships:**")
                for rel in relationships:
                    output_parts.append(
                        f"- `{table_name}.{rel.get('column')}` -> "
                        f"`{rel.get('referenced_table')}.{rel.get('referenced_column')}`"
                    )
            example_rows = table.get("example_rows")
            if example_rows:
                output_parts.append("\n**Example Rows:**")
                try:
                    df = pd.DataFrame(example_rows)
                    output_parts.append(f"```\n{df.to_string(index=False)}\n```")
                except Exception:
                    # Best effort: one line per row if pandas cannot tabulate them.
                    output_parts.extend(f"- {row}" for row in example_rows)
            output_parts.append("\n")
        return "\n".join(output_parts)

    def _fetch_sample_rows(self, table_name: str) -> List[Dict[str, Any]]:
        """Return up to ``_SAMPLE_ROW_LIMIT`` live rows of ``table_name``, or [] on failure."""
        # Double any embedded backtick so the name cannot escape identifier quoting.
        quoted_name = "`" + table_name.replace("`", "``") + "`"
        try:
            with self.engine.connect() as connection:
                result = connection.execute(text(f"SELECT * FROM {quoted_name} LIMIT {self._SAMPLE_ROW_LIMIT}"))
                return [dict(row._mapping) for row in result.fetchall()]
        except Exception as e:
            # Samples are optional context; the schema description is still useful without them.
            logger.warning(f"Could not fetch sample rows for table {table_name}: {e}")
            return []

    def get_filtered_database_schema(self, keywords: List[str]) -> Dict[str, Any]:
        """Look up the database tables relevant to the given keywords.

        Keywords are matched case-insensitively against words from table names,
        column names and column descriptions. Tables directly related to a
        matched table are included too. Each returned table lists its columns,
        relationships and up to three live example rows.

        Args:
            keywords: Business terms or identifiers, e.g. ["booking", "vendor"].

        Returns:
            {"formatted_schema_string": <Markdown description>}, or
            {"error": ...} when no keywords are given.
        """
        logger.info(f"Tool 'get_filtered_database_schema' called with keywords: {keywords}")
        if not keywords:
            return {"error": "No keywords provided."}
        if isinstance(keywords, str):
            # Tolerate a single keyword passed as a bare string instead of a list.
            keywords = [keywords]

        keyword_index = self.schema_manager.keyword_to_columns
        initial_tables: Set[str] = set()
        for keyword in keywords:
            clean_keyword = keyword.lower().strip()
            if clean_keyword in keyword_index:
                initial_tables.update(keyword_index[clean_keyword].keys())

        expanded_tables = set(initial_tables)
        for table_name in initial_tables:
            expanded_tables.update(self.schema_manager.get_related_tables(table_name))
        logger.info(f"Initial tables: {initial_tables}. Expanded with related tables: {expanded_tables}")

        schema_obj = self._unwrap_schema(self.schema_manager.schema_data)
        # Shallow-copy each table so the sample rows attached below do not leak
        # into the schema data shared with _SchemaManager.
        final_table_objects = [
            dict(t) for t in schema_obj.get('tables') or [] if t.get('table_name') in expanded_tables
        ]
        if not final_table_objects:
            return {"formatted_schema_string": "No relevant tables were found."}

        for table_obj in final_table_objects:
            table_name = table_obj.get('table_name')
            if table_name:
                table_obj['example_rows'] = self._fetch_sample_rows(table_name)

        final_schema_json = {
            "name": schema_obj.get("schema_name", "database"),
            "tables": final_table_objects,
        }
        return {"formatted_schema_string": self._format_schema_for_llm(final_schema_json)}

    def _sanitize_sql(self, sql_query: str) -> str:
        """Apply value-preserving rewrites to LLM-generated SQL before execution."""
        # Only SUM is rewritten. SUM(COALESCE(x, 0)) equals SUM(x) whenever any
        # row is non-NULL. Doing the same to AVG/MIN/MAX would change the answer:
        # NULLs would be averaged in as 0, and MIN would return 0.
        #
        # No '%' escaping is done for DATE_FORMAT. SQLAlchemy's text() construct
        # already doubles '%' for format/pyformat drivers such as PyMySQL, so
        # MySQL receives the pattern exactly as written.
        return self._SUM_OF_COLUMN_RE.sub(r'\1(COALESCE(\2, 0))', sql_query)

    # Retry only when SQLAlchemy reports the connection was lost. Query errors
    # (bad syntax, unknown column, ...) fail fast so the agent can self-correct.
    # The final failure is re-raised to the caller.
    @retry(
        retry=retry_if_exception(lambda exc: isinstance(exc, DBAPIError) and exc.connection_invalidated),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=5),
        reraise=True,
    )
    def _run_query(self, sql: str) -> List[Dict[str, Any]]:
        """Execute ``sql`` and return every row as a column-name -> value dict."""
        with self.engine.connect() as connection:
            result = connection.execute(text(sql))
            return [dict(row._mapping) for row in result.fetchall()]

    def execute_sql(self, sql_query: str) -> Dict[str, Any]:
        """Run a MySQL query against the database and return the result rows.

        Bare ``SUM(column)`` calls are rewritten to ``SUM(COALESCE(column, 0))``
        before execution. Write ``DATE_FORMAT`` patterns with single ``%`` signs.

        Args:
            sql_query: A single MySQL statement.

        Returns:
            {"sql_results": [row dicts]} on success, or {"error": message}
            describing the failure and echoing the query that failed.
        """
        # NOTE(review): queries run with the configured DB user's full
        # privileges; consider a read-only account for LLM-generated SQL.
        logger.info(f"Original SQL query from LLM:\n{sql_query}")
        sanitized_query = sql_query
        try:
            sanitized_query = self._sanitize_sql(sql_query)
            if sanitized_query != sql_query:
                logger.info(f"Sanitized SQL query for execution:\n{sanitized_query}")
            return {"sql_results": self._run_query(sanitized_query)}
        except Exception as e:
            logger.error(f"SQL execution error: {e}")
            # Return the error to the agent (instead of raising) so it can revise the query.
            error_message = f"SQL Error: {str(e)}. Review the query for syntax issues, especially around date functions and aliases. The failed query was: {sanitized_query}"
            return {"error": error_message}
# Tool hook: runs each tool call and records it in the run's session state.
def comprehensive_logging_hook(
    run_context: RunContext,
    function_name: str,
    function_call,
    arguments: Dict[str, Any]
) -> Any:
    """
    Execute a tool call and append an audit record to the session state.

    Each record appended to ``session_state["tool_execution_log"]`` contains:
    * the tool name and arguments;
    * the start timestamp and an execution id;
    * the status ("success" or "failed");
    * the duration in milliseconds and the completion timestamp;
    * either the result (stringified, truncated to 1000 characters) or the
      error message.

    Args:
        run_context: Current run context; its ``session_state`` is created
            when it is ``None``.
        function_name: Name of the tool being invoked.
        function_call: The tool callable, invoked as ``function_call(**arguments)``.
        arguments: Keyword arguments for the tool.

    Returns:
        Whatever the tool returned.

    Raises:
        Exception: any exception raised by the tool is recorded, then re-raised.
    """
    import time  # local import: only this hook needs a monotonic clock

    # Access session_state from run_context (Agno v2 API). Only create a new
    # dict when there is none; replacing an existing *empty* dict would detach
    # this hook's writes from the object the framework holds.
    if run_context.session_state is None:
        run_context.session_state = {}
    session_state = run_context.session_state
    session_state.setdefault("tool_execution_log", [])

    started_at = datetime.now()
    # perf_counter is monotonic, so durations are unaffected by wall-clock changes.
    start_clock = time.perf_counter()
    execution_record = {
        "tool_name": function_name,
        "arguments": arguments,
        "timestamp": started_at.isoformat(),
        "execution_id": f"{function_name}_{started_at.timestamp()}",
    }
    logger.info(f"🔧 Executing tool: {function_name} with args: {arguments}")
    try:
        result = function_call(**arguments)
        execution_record.update({
            "result": str(result)[:1000],  # truncate long results to keep session state small
            "status": "success",
            "duration_ms": (time.perf_counter() - start_clock) * 1000,
            "completed_at": datetime.now().isoformat(),
        })
        logger.info(f"✅ Tool {function_name} completed successfully in {execution_record['duration_ms']:.2f}ms")
        return result
    except Exception as e:
        execution_record.update({
            "error": str(e),
            "status": "failed",
            "duration_ms": (time.perf_counter() - start_clock) * 1000,
            "completed_at": datetime.now().isoformat(),
        })
        logger.error(f"❌ Tool {function_name} failed: {str(e)}")
        raise
    finally:
        # Always record the execution, whether it succeeded or failed.
        session_state.setdefault("tool_execution_log", []).append(execution_record)
# System prompt handed to the Gemini model (see the Agent construction below).
# The text is sent verbatim after dedent(), so any edit here directly changes
# model behaviour. The "EXECUTION TRACKING" section describes what
# comprehensive_logging_hook records in session state.
# NOTE(review): "Always use COALESCE in aggregations" is only value-preserving
# for SUM; COALESCE-ing the inputs of AVG/MIN/MAX changes their results
# (NULLs are treated as 0). Consider "COALESCE(AVG(x), 0)"-style wording.
system_prompt = dedent("""
You are Sirus, an expert data scientist with comprehensive execution tracking.
**EXECUTION TRACKING:**
- All your tool executions are automatically logged with timestamps, arguments, and results
- Session state maintains a complete audit trail of your analysis process
- Each query execution is tracked for performance and debugging
**GUIDING PRINCIPLES:**
1. **Be a Business Analyst:** Hide technical complexity from users
2. **Be Resilient & Self-Correct:** Use `think` to diagnose and retry on failures
3. **Prioritize Source of Truth:** Choose the best tables for analysis
4. **Decompose Complex Questions:** Break down into simple sub-problems
**MANDATORY THOUGHT PROCESS:**
For complex questions, solve ONE sub-problem at a time:
a) **`think`**: State the sub-problem and plan
b) **`execute_sql`**: Run the query
c) **`think`**: Review results and plan next steps
**SQL REQUIREMENTS:**
- Use MySQL syntax with single % in DATE_FORMAT
- Always use COALESCE in aggregations
- Test with simple queries first if complex ones fail
**FINAL RESPONSE:**
- Start with key numbers in **bold**
- Provide business insights
- Explain methodology simply
- Suggest logical next steps
""")
print("✅ Configuration set. Initializing enhanced agent with comprehensive logging...")

# Session storage in a local SQLite file (agent_sessions.db) rather than memory,
# so sessions persist between runs of this script.
agent_db = SqliteDb(db_file="agent_sessions.db")

# Tool providers: schema lookup / SQL execution, and the think/analyze reasoning tools.
db_toolkit = DatabaseToolkit(schema_path=SCHEMA_PATH, db_url=DB_URL)
reasoning_tools = ReasoningTools(add_instructions=True, enable_analyze=True, enable_think=True)

gemini_sql_agent = Agent(
    # Model: Gemini 2.5 Flash with a thinking budget, returning its thoughts.
    model=Gemini(
        id="gemini-2.5-flash",
        system_prompt=system_prompt,
        include_thoughts=True,
        thinking_budget=24000,
    ),
    # Tools, with every tool call wrapped by the audit-logging hook.
    tools=[db_toolkit, reasoning_tools],
    tool_hooks=[comprehensive_logging_hook],
    # Persistence and conversation history (last 3 runs fed back as context).
    db=agent_db,
    read_chat_history=True,
    add_history_to_context=True,
    num_history_runs=3,
    # Initial per-session state; the logging hook appends to "tool_execution_log".
    session_state={
        "tool_execution_log": [],
        "user_context": {},
        "analysis_metadata": {}
    },
    add_session_state_to_context=True,
    # Output formatting and streaming.
    add_datetime_to_context=True,
    markdown=True,
    stream_intermediate_steps=True,
    # Retry/backoff settings. NOTE(review): no retry count is set here;
    # confirm these take effect with the agent's default retry count.
    delay_between_retries=10,
    exponential_backoff=True,
)
# Run one question through the agent and print a per-user execution summary.
def execute_query_with_tracking(question: str, user_id: str, session_id: Optional[str] = None):
    """
    Run ``question`` through ``gemini_sql_agent`` and summarise the tools it used.

    Args:
        question: Natural-language question for the agent.
        user_id: Identifier of the asking user. It is forwarded to the agent
            run, not just printed.
        session_id: Session to run in; ``None`` leaves the choice to the agent.

    Returns:
        ``(response, session_state)``:
        * ``response`` is whatever ``print_response`` returned (it streams
          the answer to stdout);
        * ``session_state`` is the session state after the run, or an empty
          dict if the agent returned none.
    """
    print(f"\n🚀 User {user_id} asking: '{question}'\n")
    # Execute with session and user tracking.
    response = gemini_sql_agent.print_response(
        question,
        stream=True,
        session_id=session_id,
        user_id=user_id,
    )
    # Guard against a missing state so the summary below cannot crash.
    final_session_state = gemini_sql_agent.get_session_state(session_id) or {}
    tool_log = final_session_state.get('tool_execution_log', [])
    print(f"\n📊 Execution Summary for User {user_id}:")
    print(f" Session ID: {gemini_sql_agent.session_id}")
    print(f" Tools executed: {len(tool_log)}")
    # Summarise only the three most recent tool executions.
    for i, log_entry in enumerate(tool_log[-3:], 1):
        print(f" Tool {i}: {log_entry['tool_name']} - {log_entry.get('status', 'unknown')} ({log_entry.get('duration_ms', 0):.1f}ms)")
    return response, final_session_state
# Print timing and tool-usage statistics for a stored session.
def get_session_analytics(session_id: str):
    """
    Print analytics for ``session_id`` and return the session object.

    The printed report includes:
    * the session's created/updated timestamps;
    * its run count (when the session exposes ``runs``);
    * the total number of logged tool executions;
    * a per-tool usage count.

    Returns:
        The session returned by ``gemini_sql_agent.get_session``; a falsy value
        is returned unchanged, without printing anything.
    """
    session = gemini_sql_agent.get_session(session_id)
    if not session:
        return session

    print(f"\n📈 Session Analytics for {session_id}:")
    print(f" Created: {session.created_at}")
    print(f" Updated: {session.updated_at}")
    print(f" Total runs: {len(session.runs) if hasattr(session, 'runs') else 'N/A'}")

    # Guard against a missing state so the report cannot crash.
    session_state = gemini_sql_agent.get_session_state(session_id) or {}
    tool_logs = session_state.get('tool_execution_log', [])
    print(f" Total tool executions: {len(tool_logs)}")

    # Count executions per tool, in first-seen order.
    tool_usage: Dict[str, int] = defaultdict(int)
    for log in tool_logs:
        tool_usage[log['tool_name']] += 1
    print(" Tool usage breakdown:")
    for tool, count in tool_usage.items():
        print(f" - {tool}: {count} times")
    return session
# Demo script:
# 1. user_001 asks two questions in the same session;
# 2. user_002 asks one question in a separate session;
# 3. analytics are printed and each session's tool log is exported to JSON.
if __name__ == "__main__":
    import time

    demo_queries = [
        ("What were our total bookings", "user_001", "business_analysis_session_1"),
        ("Give me the total revenue from these bookings", "user_001", "business_analysis_session_1"),
        ("total no of unique vendors", "user_002", "vendor_analysis_session_1"),
    ]
    demo_results = []
    for position, (question, user_id, session_id) in enumerate(demo_queries):
        if position:
            # Pause between runs (presumably to stay within model rate limits).
            time.sleep(30)
        demo_results.append(
            execute_query_with_tracking(question=question, user_id=user_id, session_id=session_id)
        )

    banner = "=" * 50
    print("\n" + banner)
    print("COMPREHENSIVE SESSION ANALYTICS")
    print(banner)
    for analysed_session in ("business_analysis_session_1", "vendor_analysis_session_1"):
        get_session_analytics(analysed_session)

    def export_execution_logs(session_id: str, filename: str):
        """Write the session's tool log, user context and metadata to ``filename`` as JSON."""
        state = gemini_sql_agent.get_session_state(session_id)
        payload = {
            "session_id": session_id,
            "export_timestamp": datetime.now().isoformat(),
            "tool_execution_log": state.get('tool_execution_log', []),
            "user_context": state.get('user_context', {}),
            "analysis_metadata": state.get('analysis_metadata', {}),
        }
        with open(filename, 'w') as out_file:
            json.dump(payload, out_file, indent=2)
        print(f"📁 Execution logs exported to {filename}")

    for export_session, export_file in (
        ("business_analysis_session_1", "user_001_execution_log.json"),
        ("vendor_analysis_session_1", "user_002_execution_log.json"),
    ):
        export_execution_logs(export_session, export_file)