Spaces:
Running
Running
| """ | |
| ⚠️ DEPRECATED - DO NOT USE IN PRODUCTION ⚠️ | |
| This file (oldAgent.py) is DEPRECATED and maintained only for reference. | |
| MIGRATION GUIDE: | |
| - Use agent.py for all new development | |
| - agent.py provides the same functionality with: | |
| * Production-grade API integration | |
| * Multi-tenant isolation via session state | |
| * No direct database connections | |
| * Centralized data access through data_sources API | |
| * Better error handling and logging | |
| * API key authentication support | |
| This file will be REMOVED in a future release. | |
| Please migrate all code to use agent.py instead. | |
| Last Updated: October 2025 | |
| Deprecation Status: ACTIVE - Do not use for new features | |
| Removal Target: Next major release | |
| """ | |
| import os | |
| import json | |
| import re | |
| import logging | |
| from typing import Dict, Set, Any, List, Optional | |
| from collections import defaultdict | |
| from functools import lru_cache | |
| from textwrap import dedent | |
| from sqlalchemy import create_engine, text | |
| from tenacity import retry, stop_after_attempt, wait_exponential | |
| from urllib.parse import quote_plus | |
| import pandas as pd | |
| from datetime import datetime | |
| # Updated imports for comprehensive tracking | |
| from agno.db.sqlite import SqliteDb # Changed from InMemoryDb for persistence | |
| from agno.agent import Agent | |
| from agno.models.google import Gemini | |
| from agno.tools import Toolkit | |
| from agno.tools.reasoning import ReasoningTools | |
| from agno.run.response import RunContext | |
# --- Database configuration ---
# Connection settings are read from the environment so that credentials do not
# live in source control. Non-secret settings fall back to the previous values;
# the password has no default and must be supplied via DB_PASSWORD (the
# credential that used to be committed here should be considered leaked and rotated).
DB_DIALECT = os.environ.get("DB_DIALECT", "mysql+pymysql")
DB_USER = os.environ.get("DB_USER", "root")
DB_PASSWORD = os.environ.get("DB_PASSWORD", "")
DB_HOST = os.environ.get("DB_HOST", "65.0.127.253")
DB_PORT = os.environ.get("DB_PORT", "3306")
DB_NAME = os.environ.get("DB_NAME", "bookwedgo")
SCHEMA_PATH = os.environ.get("SCHEMA_PATH", r"/content/database_schema.json")

if "GOOGLE_API_KEY" not in os.environ:
    print("🔴 WARNING: GOOGLE_API_KEY environment variable not set. The agent will fail.")
if not DB_PASSWORD:
    print("🔴 WARNING: DB_PASSWORD environment variable not set. Connecting with an empty password.")

# Special characters (e.g. '@') in the user or password would otherwise break URL parsing.
encoded_password = quote_plus(DB_PASSWORD)
DB_URL = f"{DB_DIALECT}://{quote_plus(DB_USER)}:{encoded_password}@{DB_HOST}:{DB_PORT}/{DB_NAME}"

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Schema indexing: table relationship graph + keyword lookup ---
class _SchemaManager:
    """Load the JSON schema description and index it for keyword-based table discovery.

    The JSON may be either a dict with a ``tables`` list, or a list whose first
    element is such a dict. Table entries are read for ``table_name``,
    ``fields`` (each with ``name`` and ``description``) and ``relationships``
    (each with ``referenced_table``).

    Attributes:
        schema_data: The parsed JSON, unchanged.
        relationship_graph: table name -> names of tables linked to it by a
            relationship in either direction.
        keyword_to_columns: lower-case word -> {table name -> column names whose
            name or description contains that word}. The marker ``'__table__'``
            means the word came from the table name itself.
    """

    # Splits snake_case / camelCase / PascalCase identifiers into words. A run of
    # capitals not followed by a lower-case letter (e.g. "ID") stays one word.
    _IDENTIFIER_WORDS = re.compile(r"[A-Z]+(?![a-z])|[A-Z]?[a-z]+|\d+")
    # Words of three or more letters taken from free-text column descriptions.
    _DESCRIPTION_WORDS = re.compile(r"[a-zA-Z]{3,}")

    def __init__(self, schema_path: str):
        """Read *schema_path* and build the lookup structures.

        Raises:
            FileNotFoundError: If *schema_path* does not exist.
        """
        if not os.path.exists(schema_path):
            raise FileNotFoundError(f"Schema file not found: {schema_path}. Please upload it.")
        with open(schema_path, 'r', encoding='utf-8') as f:
            self.schema_data = json.load(f)
        self.relationship_graph = self._build_relationship_graph()
        self.keyword_to_columns = self._create_keyword_mappings()
        logger.info("SchemaManager initialized with Relationship Graph and Column-Level Mappings.")

    def _tables(self) -> List[Dict[str, Any]]:
        """Return the list of table dicts, whichever top-level JSON shape was used."""
        data = self.schema_data
        # The list check must come first: calling `.get` on a list raises AttributeError.
        if isinstance(data, list):
            data = data[0] if data else {}
        if not isinstance(data, dict):
            return []
        return data.get('tables') or []

    @classmethod
    def _split_identifier(cls, name: str) -> List[str]:
        """Split an identifier into lower-case words.

        >>> _SchemaManager._split_identifier("booking_date")
        ['booking', 'date']
        """
        return [word.lower() for word in cls._IDENTIFIER_WORDS.findall(name or '')]

    def _build_relationship_graph(self) -> Dict[str, Set[str]]:
        """Build an undirected adjacency map from each table's ``relationships``."""
        graph: Dict[str, Set[str]] = defaultdict(set)
        for table in self._tables():
            table_name = table.get('table_name')
            if not table_name:
                continue
            for rel in table.get('relationships') or []:
                referenced_table = rel.get('referenced_table')
                if referenced_table:
                    # Record both directions so either side finds the other.
                    graph[table_name].add(referenced_table)
                    graph[referenced_table].add(table_name)
        return graph

    def _create_keyword_mappings(self) -> Dict[str, Dict[str, Set[str]]]:
        """Index table names, column names and description words by lower-case word."""
        mappings: Dict[str, Dict[str, Set[str]]] = defaultdict(lambda: defaultdict(set))
        for table in self._tables():
            table_name = table.get('table_name')
            if not table_name:
                continue
            for word in self._split_identifier(table_name):
                mappings[word][table_name].add('__table__')
            for field in table.get('fields') or []:
                col_name = field.get('name') or ''
                for word in self._split_identifier(col_name):
                    mappings[word][table_name].add(col_name)
                # `or ''` also covers descriptions stored as JSON null.
                description = str(field.get('description') or '')
                for word in self._DESCRIPTION_WORDS.findall(description):
                    mappings[word.lower()][table_name].add(col_name)
        return mappings

    def get_related_tables(self, table_name: str) -> Set[str]:
        """Return a copy of the set of tables directly related to *table_name* (empty if none)."""
        return set(self.relationship_graph.get(table_name, ()))
# --- DatabaseToolkit: agent-facing tools built on _SchemaManager ---
class DatabaseToolkit(Toolkit):
    """Agno toolkit exposing schema discovery and SQL execution to the agent.

    Only ``get_filtered_database_schema`` and ``execute_sql`` are registered as
    tools (see ``__init__``); the other methods are internal helpers.
    """

    # Live rows fetched per table so the model sees real values.
    _SAMPLE_ROW_LIMIT = 3
    # First keywords accepted by execute_sql. The configured DB account can write,
    # so anything else is rejected before it reaches the database. This is a
    # best-effort guard; a read-only database user is the robust fix.
    _READ_ONLY_KEYWORDS = ("select", "with", "show", "describe", "desc", "explain")
    # Data-modifying keywords, checked inside WITH statements (MySQL 8 allows a
    # WITH clause in front of UPDATE/DELETE).
    _WRITE_KEYWORDS = re.compile(
        r"\b(INSERT|UPDATE|DELETE|REPLACE|DROP|ALTER|CREATE|TRUNCATE|GRANT|REVOKE|RENAME)\b",
        re.IGNORECASE,
    )
    # Leading whitespace, opening parentheses and comments (`-- ...`, `# ...`,
    # `/* ... */`). MySQL executable comments (`/*! ... */`) are deliberately NOT
    # skipped, so a statement starting with one is rejected.
    _LEADING_NOISE = re.compile(r"(?:\s+|\(|--[^\n]*|#[^\n]*|/\*(?!!).*?\*/)*", re.DOTALL)
    # `SUM(col)` -> `SUM(COALESCE(col, 0))`, so an all-NULL input sums to 0.
    # Only SUM is rewritten: for AVG/MIN/MAX, turning NULL into 0 changes the
    # answer. `\b` keeps functions such as CHECKSUM( untouched.
    _SUM_PATTERN = re.compile(r"\b(SUM)\(\s*([a-zA-Z0-9_\.]+)\s*\)", re.IGNORECASE)

    def __init__(self, schema_path: str, db_url: str):
        """Register the tools, load the schema index and create the SQLAlchemy engine."""
        super().__init__(name="database_tools", tools=[self.get_filtered_database_schema, self.execute_sql])
        self.schema_manager = _SchemaManager(schema_path)
        self.engine = create_engine(db_url)

    def _schema_root(self) -> Dict[str, Any]:
        """Return the top-level schema dict, accepting either `{...}` or `[{...}, ...]` JSON."""
        data = self.schema_manager.schema_data
        # The list check must come first: calling `.get` on a list raises AttributeError.
        if isinstance(data, list):
            data = data[0] if data else {}
        return data if isinstance(data, dict) else {}

    def _format_schema_for_llm(self, schema: Dict[str, Any]) -> str:
        """Render `{"name": ..., "tables": [...]}` as Markdown text for the model.

        For each table this lists its description, its columns (type, example value,
        description), its relationships and, if present, ``example_rows`` as a
        text table.
        """
        if not schema.get("tables"):
            return "No relevant tables found."
        output_parts = [f"Database Schema: {schema.get('name', 'N/A')}\n"]
        for table in schema["tables"]:
            table_name = table.get('table_name', 'N/A')
            # `or ''` also covers descriptions stored as JSON null.
            description = (table.get('description') or '').strip()
            output_parts.append("---")
            output_parts.append(f"\n**Table: `{table_name}`**")
            if description:
                output_parts.append(f"*Description: {description}*")
            output_parts.append("\n**Columns:**")
            for field in table.get('fields') or []:
                col_name = field.get('name', 'N/A')
                col_type = field.get('type', 'N/A')
                col_example = field.get('example', 'N/A')
                # Auto-generated descriptions of this exact form carry no information.
                boilerplate = f"This is the '{col_name}' column of the '{table_name}' table."
                col_desc = (field.get('description') or 'No description.').replace(boilerplate, "").strip()
                output_parts.append(f"- `{col_name}` (type: {col_type}, ex: '{col_example}') - {col_desc}")
            relationships = table.get("relationships")
            if relationships:
                output_parts.append("\n**Relationships:**")
                for rel in relationships:
                    col, ref_table, ref_col = rel.get('column'), rel.get('referenced_table'), rel.get('referenced_column')
                    output_parts.append(f"- `{table_name}.{col}` -> `{ref_table}.{ref_col}`")
            example_rows = table.get("example_rows")
            if example_rows:
                output_parts.append("\n**Example Rows:**")
                try:
                    df = pd.DataFrame(example_rows)
                    output_parts.append(f"```\n{df.to_string(index=False)}\n```")
                except Exception:
                    # Fall back to one line per row if pandas cannot tabulate them.
                    output_parts.extend(f"- {row}" for row in example_rows)
            output_parts.append("\n")
        return "\n".join(output_parts)

    def _fetch_sample_rows(self, table_name: str) -> List[Dict[str, Any]]:
        """Fetch up to `_SAMPLE_ROW_LIMIT` rows from *table_name*; return [] on any DB error."""
        # Table names come from the schema file, but still escape backticks so an
        # unusual name cannot break out of the quoted identifier.
        quoted_name = str(table_name).replace("`", "``")
        try:
            with self.engine.connect() as connection:
                sample_query = text(f"SELECT * FROM `{quoted_name}` LIMIT {int(self._SAMPLE_ROW_LIMIT)}")
                sample_result = connection.execute(sample_query)
                return [dict(row._mapping) for row in sample_result.fetchall()]
        except Exception as e:
            logger.warning(f"Could not fetch sample rows for table {table_name}: {e}")
            return []

    def get_filtered_database_schema(self, keywords: List[str]) -> Dict[str, Any]:
        """Return the schema of tables matching the keywords, plus their related tables.

        Call this before writing SQL. Pass short words that may appear in table
        names, column names or column descriptions, e.g. ["booking", "vendor"].

        Args:
            keywords: Words to look up (case-insensitive). Multi-word entries are
                also matched word by word.

        Returns:
            {"formatted_schema_string": str} describing the matched tables, with a
            few live sample rows each, or {"error": str} if no keywords were given.
        """
        logger.info(f"Tool 'get_filtered_database_schema' called with keywords: {keywords}")
        if not keywords:
            return {"error": "No keywords provided."}
        if isinstance(keywords, str):
            # A bare string would otherwise be iterated character by character.
            keywords = [keywords]

        keyword_index = self.schema_manager.keyword_to_columns
        initial_tables: Set[str] = set()
        for keyword in keywords:
            clean_keyword = str(keyword).lower().strip()
            # Try the whole phrase, then each word, so "total bookings" still finds "bookings".
            for candidate in (clean_keyword, *clean_keyword.split()):
                if candidate in keyword_index:
                    initial_tables.update(keyword_index[candidate].keys())

        expanded_tables = set(initial_tables)
        for table_name in initial_tables:
            expanded_tables.update(self.schema_manager.get_related_tables(table_name))
        logger.info(f"Initial tables: {initial_tables}. Expanded with related tables: {expanded_tables}")

        schema_root = self._schema_root()
        # Shallow-copy each table so injected sample rows do not accumulate on the
        # schema dicts cached by the schema manager.
        final_table_objects = [
            dict(table)
            for table in schema_root.get('tables') or []
            if table.get('table_name') in expanded_tables
        ]
        if not final_table_objects:
            return {"formatted_schema_string": "No relevant tables were found."}

        for table_obj in final_table_objects:
            table_name = table_obj.get('table_name')
            if table_name:
                table_obj['example_rows'] = self._fetch_sample_rows(table_name)

        final_schema_json = {
            "name": schema_root.get("schema_name", "database"),
            "tables": final_table_objects,
        }
        return {"formatted_schema_string": self._format_schema_for_llm(final_schema_json)}

    @classmethod
    def _read_only_violation(cls, sql_query: str) -> Optional[str]:
        """Return why *sql_query* is not an allowed read-only statement, or None if it is."""
        sql_query = str(sql_query or "")
        body = sql_query[cls._LEADING_NOISE.match(sql_query).end():]
        first_word = re.match(r"[A-Za-z]+", body)
        keyword = first_word.group(0).lower() if first_word else ""
        if keyword not in cls._READ_ONLY_KEYWORDS:
            allowed = ", ".join(k.upper() for k in cls._READ_ONLY_KEYWORDS)
            return f"only read-only statements ({allowed}) are allowed"
        if keyword == "with" and cls._WRITE_KEYWORDS.search(body):
            return "data-modifying statements are not allowed"
        return None

    def execute_sql(self, sql_query: str) -> Dict[str, Any]:
        """Execute one read-only MySQL query and return its rows.

        Args:
            sql_query: A single SELECT / WITH / SHOW / DESCRIBE / EXPLAIN statement.

        Returns:
            {"sql_results": [row_dict, ...]} on success, or {"error": message}
            if the query is rejected or fails.
        """
        logger.info(f"Original SQL query from LLM:\n{sql_query}")
        violation = self._read_only_violation(sql_query)
        if violation:
            logger.warning(f"Rejected SQL query: {violation}")
            return {"error": f"Query rejected: {violation}. The failed query was: {sql_query}"}

        # '%' (e.g. in DATE_FORMAT) is intentionally left alone: SQLAlchemy's text()
        # escapes '%' itself for pyformat drivers such as pymysql. The previous manual
        # '%%' rewrite double-escaped, and it only ever touched one '%' per call.
        sanitized_query = self._SUM_PATTERN.sub(r"\1(COALESCE(\2, 0))", sql_query)
        if sanitized_query != sql_query:
            logger.info(f"Sanitized SQL query for execution:\n{sanitized_query}")
        try:
            with self.engine.connect() as connection:
                result = connection.execute(text(sanitized_query))
                results_list = [dict(row._mapping) for row in result.fetchall()]
            return {"sql_results": results_list}
        except Exception as e:
            logger.error(f"SQL execution error: {e}")
            # Return a hint the agent can act on instead of raising.
            error_message = f"SQL Error: {str(e)}. Review the query for syntax issues, especially around date functions and aliases. The failed query was: {sanitized_query}"
            return {"error": error_message}
# --- Tool hook: per-call execution log kept in session state ---
def comprehensive_logging_hook(
    run_context: "RunContext",
    function_name: str,
    function_call,
    arguments: Dict[str, Any]
) -> Any:
    """Run a tool call and append an audit record to the session state.

    The record is appended to ``run_context.session_state["tool_execution_log"]``
    whether the tool succeeds or raises. It contains ``tool_name``,
    ``arguments``, ``timestamp`` (start time, naive local ISO 8601),
    ``execution_id``, ``status`` ("success" / "failed"), ``duration_ms`` and
    ``completed_at``, plus either ``result`` (``str(result)`` truncated to 1000
    characters) or ``error``.

    Args:
        run_context: Current run context; its ``session_state`` is created if None.
        function_name: Name of the tool being invoked.
        function_call: The tool callable, called with ``**arguments``.
        arguments: Keyword arguments for the tool; None is treated as no arguments.

    Returns:
        The tool's return value, unchanged.

    Raises:
        Exception: Any exception raised by the tool is logged and re-raised.
    """
    # Compare against None: an existing but empty dict is falsy, and replacing it
    # would detach the log from the state object the agent already holds.
    if run_context.session_state is None:
        run_context.session_state = {}
    session_state = run_context.session_state
    if session_state.get("tool_execution_log") is None:
        session_state["tool_execution_log"] = []

    execution_start = datetime.now()
    execution_record = {
        "tool_name": function_name,
        "arguments": arguments,
        "timestamp": execution_start.isoformat(),
        "execution_id": f"{function_name}_{execution_start.timestamp()}"
    }
    logger.info(f"🔧 Executing tool: {function_name} with args: {arguments}")
    try:
        result = function_call(**(arguments or {}))
        execution_end = datetime.now()
        execution_record.update({
            "result": str(result)[:1000],  # keep session state small
            "status": "success",
            "duration_ms": (execution_end - execution_start).total_seconds() * 1000,
            "completed_at": execution_end.isoformat()
        })
        logger.info(f"✅ Tool {function_name} completed successfully in {execution_record['duration_ms']:.2f}ms")
    except Exception as e:
        execution_end = datetime.now()
        execution_record.update({
            "error": str(e),
            "status": "failed",
            "duration_ms": (execution_end - execution_start).total_seconds() * 1000,
            "completed_at": execution_end.isoformat()
        })
        logger.error(f"❌ Tool {function_name} failed: {str(e)}")
        raise
    finally:
        # Record the call on both the success and failure paths.
        session_state["tool_execution_log"].append(execution_record)
    return result
# System prompt for the SQL analyst agent (passed to the Gemini model below).
# It names the schema tool explicitly so the model looks up real tables before
# writing SQL, and it restricts COALESCE to SUM, where it does not change the answer.
system_prompt = dedent("""
You are Sirus, an expert data scientist with comprehensive execution tracking.
**EXECUTION TRACKING:**
- All your tool executions are automatically logged with timestamps, arguments, and results
- Session state maintains a complete audit trail of your analysis process
- Each query execution is tracked for performance and debugging
**GUIDING PRINCIPLES:**
1. **Be a Business Analyst:** Hide technical complexity from users
2. **Be Resilient & Self-Correct:** Use `think` to diagnose and retry on failures
3. **Prioritize Source of Truth:** Choose the best tables for analysis
4. **Decompose Complex Questions:** Break down into simple sub-problems
**MANDATORY THOUGHT PROCESS:**
Before writing any SQL, call `get_filtered_database_schema` with short keywords
(table or column words such as "booking" or "vendor") and use only the tables
and columns it returns.
For complex questions, solve ONE sub-problem at a time:
a) **`think`**: State the sub-problem and plan
b) **`execute_sql`**: Run the query
c) **`think`**: Review results and plan next steps
**SQL REQUIREMENTS:**
- Use MySQL syntax with single % in DATE_FORMAT
- Only read-only statements are allowed (SELECT, WITH, SHOW, DESCRIBE, EXPLAIN)
- Write sums as COALESCE(SUM(col), 0); do not wrap AVG, MIN or MAX inputs in COALESCE (it changes the result)
- Test with simple queries first if complex ones fail
**FINAL RESPONSE:**
- Start with key numbers in **bold**
- Provide business insights
- Explain methodology simply
- Suggest logical next steps
""")
print("✅ Configuration set. Initializing enhanced agent with comprehensive logging...")

# Persistent storage for agent sessions. The relative filename is presumably
# resolved against the current working directory — TODO confirm for SqliteDb.
agent_db = SqliteDb(db_file="agent_sessions.db")

# Tool sets: schema lookup + SQL execution, and agno's think/analyze reasoning tools.
db_toolkit = DatabaseToolkit(schema_path=SCHEMA_PATH, db_url=DB_URL)
reasoning_tools = ReasoningTools(add_instructions=True, enable_analyze=True, enable_think=True)

# The single agent instance used by the helper functions below.
gemini_sql_agent = Agent(
    # Tools, with comprehensive_logging_hook registered as a tool hook.
    tools=[db_toolkit, reasoning_tools],
    tool_hooks=[comprehensive_logging_hook],
    # Model configuration.
    model=Gemini(
        id="gemini-2.5-flash",
        include_thoughts=True,
        thinking_budget=24000,
        system_prompt=system_prompt,
    ),
    # Persistence and conversation history.
    db=agent_db,
    read_chat_history=True,
    add_history_to_context=True,
    num_history_runs=3,
    # Initial session state; the "tool_execution_log" list is appended to by the hook.
    session_state={
        "tool_execution_log": [],
        "user_context": {},
        "analysis_metadata": {},
    },
    add_session_state_to_context=True,
    # Output behaviour.
    stream_intermediate_steps=True,
    add_datetime_to_context=True,
    markdown=True,
    # Retry behaviour.
    delay_between_retries=10,
    exponential_backoff=True,
)
# --- Query helper with per-user tracking ---
def execute_query_with_tracking(question: str, user_id: str, session_id: Optional[str] = None):
    """Ask the agent a question, stream the answer to stdout, then print a tool summary.

    Args:
        question: Natural-language question for the agent.
        user_id: Caller identifier; forwarded to the agent run and shown in the summary.
        session_id: Agent session to run in; None lets the agent pick its own.

    Returns:
        ``(response, final_session_state)``. ``response`` is whatever
        ``print_response`` returns (NOTE(review): likely None because it prints —
        confirm against the agno version in use). ``final_session_state`` is the
        session-state dict, or ``{}`` if the agent returns none. The tool count
        printed covers the whole session, not just this question.
    """
    print(f"\n🚀 User {user_id} asking: '{question}'\n")
    response = gemini_sql_agent.print_response(
        question,
        stream=True,
        session_id=session_id,
        user_id=user_id,  # previously only printed, never passed to the agent
    )

    final_session_state = gemini_sql_agent.get_session_state(session_id) or {}
    tool_log = final_session_state.get('tool_execution_log') or []

    print(f"\n📊 Execution Summary for User {user_id}:")
    print(f" Session ID: {session_id if session_id is not None else gemini_sql_agent.session_id}")
    print(f" Tools executed: {len(tool_log)}")
    # Show the last three calls, numbered by their position in the full log.
    recent = tool_log[-3:]
    for number, log_entry in enumerate(recent, start=len(tool_log) - len(recent) + 1):
        print(f" Tool {number}: {log_entry.get('tool_name', '?')} - {log_entry.get('status', 'unknown')} ({log_entry.get('duration_ms', 0):.1f}ms)")
    return response, final_session_state
# --- Session analytics helper ---
def get_session_analytics(session_id: str):
    """Print a summary of a stored agent session and return the session object.

    The summary covers creation/update timestamps, the number of runs, the number
    of logged tool executions, and per-tool call counts taken from
    ``session_state["tool_execution_log"]``.

    Args:
        session_id: Session to inspect.

    Returns:
        Whatever ``gemini_sql_agent.get_session`` returns (falsy if not found).
    """
    from collections import Counter

    session = gemini_sql_agent.get_session(session_id)
    if not session:
        print(f"\n⚠️ No session found for {session_id}")
        return session

    # `runs` may be missing or None on a session object; len(None) would raise.
    runs = getattr(session, 'runs', None)
    print(f"\n📈 Session Analytics for {session_id}:")
    print(f" Created: {session.created_at}")
    print(f" Updated: {session.updated_at}")
    print(f" Total runs: {len(runs) if runs is not None else 'N/A'}")

    session_state = gemini_sql_agent.get_session_state(session_id) or {}
    tool_logs = session_state.get('tool_execution_log') or []
    print(f" Total tool executions: {len(tool_logs)}")

    tool_usage = Counter(log.get('tool_name', 'unknown') for log in tool_logs)
    print(" Tool usage breakdown:")
    for tool, count in tool_usage.items():
        print(f" - {tool}: {count} times")
    return session
# --- Example usage with user tracking ---
if __name__ == "__main__":
    import time

    # Pause between example runs (presumably to stay under model API rate limits).
    pause_seconds = 30

    def export_execution_logs(session_id: str, filename: str):
        """Write a session's tool log, user context and analysis metadata to *filename* as JSON."""
        session_state = gemini_sql_agent.get_session_state(session_id) or {}
        execution_logs = {
            "session_id": session_id,
            "export_timestamp": datetime.now().isoformat(),
            "tool_execution_log": session_state.get('tool_execution_log', []),
            "user_context": session_state.get('user_context', {}),
            "analysis_metadata": session_state.get('analysis_metadata', {}),
        }
        # Serialize before opening the file so a bad value cannot leave a half-written
        # file. default=str: session state is free-form, so stringify anything json
        # cannot encode.
        payload = json.dumps(execution_logs, indent=2, default=str)
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(payload)
        print(f"📁 Execution logs exported to {filename}")

    # Example 1: first user query
    user_1_response, user_1_state = execute_query_with_tracking(
        question="What were our total bookings",
        user_id="user_001",
        session_id="business_analysis_session_1"
    )
    time.sleep(pause_seconds)

    # Example 2: follow-up query in the same session
    user_1_followup, user_1_final_state = execute_query_with_tracking(
        question="Give me the total revenue from these bookings",
        user_id="user_001",
        session_id="business_analysis_session_1"
    )
    time.sleep(pause_seconds)

    # Example 3: different user, new session
    user_2_response, user_2_state = execute_query_with_tracking(
        question="total no of unique vendors",
        user_id="user_002",
        session_id="vendor_analysis_session_1"
    )

    print("\n" + "="*50)
    print("COMPREHENSIVE SESSION ANALYTICS")
    print("="*50)
    get_session_analytics("business_analysis_session_1")
    get_session_analytics("vendor_analysis_session_1")

    export_execution_logs("business_analysis_session_1", "user_001_execution_log.json")
    export_execution_logs("vendor_analysis_session_1", "user_002_execution_log.json")