| | import gradio as gr |
| | import sqlite3 |
| | import json |
| | import pandas as pd |
| | from openai import OpenAI |
| | import traceback |
| | from typing import Dict, List, Tuple, Any |
| | import re |
| | from datetime import datetime |
| | import threading |
| | import queue |
| | import html |
| | import sys |
| | import os |
| |
|
| | |
# Force UTF-8 console output so the emoji printed in status messages cannot
# raise UnicodeEncodeError on consoles that default to a legacy code page.
# The encoding name is normalized before comparing ("UTF-8", "utf8" and
# "utf-8" are the same codec). reconfigure() switches the existing stream in
# place instead of opening a second file object on the same descriptor.
_stdout_encoding = (getattr(sys.stdout, "encoding", None) or "").lower().replace("-", "")
if _stdout_encoding != "utf8" and hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8", line_buffering=True)
| |
|
class DatabaseQueryAgent:
    """Turn natural-language questions into SQL for a SQLite database.

    Two OpenRouter-hosted models ("llama" and "mistral") each propose a query,
    a third one ("gemma") reviews both proposals, and every proposal that
    passes local validation is executed against the database.
    """

    # Words that must not be used as a table alias in generated SQL.
    _RESERVED_ALIASES = frozenset({
        'TO', 'FROM', 'WHERE', 'SELECT', 'ORDER', 'GROUP',
        'HAVING', 'UNION', 'JOIN', 'ON',
    })
    # Reserved words that may legitimately follow a table name because they
    # start the next clause (e.g. "FROM orders WHERE ...").
    _CLAUSE_KEYWORDS = frozenset({
        'WHERE', 'ORDER', 'GROUP', 'HAVING', 'UNION', 'JOIN', 'ON',
    })
    # "<FROM|JOIN> <table>" followed by an optional AS, the candidate alias
    # and the word after it. The tail sits in a lookahead so that the JOIN in
    # "FROM a JOIN b" is not consumed and "JOIN b ..." is still inspected.
    _ALIAS_RE = re.compile(
        r'\b(?:FROM|JOIN)\s+\w+'
        r'(?=\s+(?P<explicit>AS\s+)?(?P<alias>\w+)(?:\s+(?P<following>\w+))?)',
        re.IGNORECASE,
    )
    # Single-quoted SQL string literal ('' is an escaped quote).
    _STRING_LITERAL_RE = re.compile(r"'(?:[^']|'')*'")
    # Business concept -> terms searched in both the user query and table names.
    _CONCEPT_KEYWORDS = {
        'customer': ['customer', 'client', 'buyer', 'user'],
        'order': ['order', 'purchase', 'transaction', 'sale'],
        'product': ['product', 'item', 'inventory', 'stock'],
        'employee': ['employee', 'staff', 'worker', 'personnel'],
        'patient': ['patient', 'medical', 'health'],
        'student': ['student', 'enrollment', 'grade', 'course'],
        'supplier': ['supplier', 'vendor', 'provider'],
        'shipping': ['shipping', 'delivery', 'logistics'],
        'payment': ['payment', 'invoice', 'billing'],
        'account': ['account', 'financial', 'balance'],
    }

    def __init__(self, db_path: str = "innovativeskills.db"):
        """Remember the database path and load schema/catalog metadata.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self.client = None  # created by setup_client() once an API key is known
        self.models = {
            "llama": "meta-llama/llama-3.3-70b-instruct:free",
            "mistral": "mistralai/mistral-7b-instruct:free",
            "gemma": "google/gemma-2-9b-it:free"
        }
        self.init_db_connection()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier."""
        return '"' + str(name).replace('"', '""') + '"'

    def init_db_connection(self):
        """Load table metadata, column metadata and the live schema.

        On any failure the error is printed and all three attributes are left
        as empty dicts. The connection is closed in every case.
        """
        self.table_metadata = {}
        self.column_metadata = {}
        self.actual_schema = {}
        conn = None
        try:
            conn = self.get_db_connection()
            cursor = conn.cursor()
            table_metadata = self.get_table_metadata(conn, cursor)
            column_metadata = self.get_column_metadata(conn, cursor)
            actual_schema = self.get_actual_schema(conn, cursor)
            # Assign only after everything loaded, so a failure leaves all
            # three attributes empty rather than partially filled.
            self.table_metadata = table_metadata
            self.column_metadata = column_metadata
            self.actual_schema = actual_schema
        except Exception as e:
            print(f"Database initialization error: {e}")
        finally:
            if conn is not None:
                conn.close()

    def get_db_connection(self):
        """Return a new connection to ``self.db_path``; the caller closes it."""
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.execute("PRAGMA encoding = 'UTF-8';")
        return conn

    def get_actual_schema(self, conn, cursor) -> Dict:
        """Read the live schema of every user table.

        Args:
            conn: Open SQLite connection (unused, kept for interface parity).
            cursor: Cursor used to run the introspection queries.

        Returns:
            ``{table: {'columns': [...], 'sample_data': [...], 'row_count': int}}``
            with up to three sample rows per table, or ``{}`` on failure.
        """
        try:
            cursor.execute(
                "SELECT name FROM sqlite_master "
                "WHERE type='table' AND name NOT LIKE 'sqlite_%'"
            )
            tables = [row[0] for row in cursor.fetchall()]
            schema = {}
            for table in tables:
                # Quote the name so tables containing spaces or keywords do
                # not break the statement.
                quoted = self._quote_identifier(table)
                cursor.execute(f"PRAGMA table_info({quoted})")
                columns = cursor.fetchall()
                try:
                    cursor.execute(f"SELECT * FROM {quoted} LIMIT 3")
                    sample_data = cursor.fetchall()
                except sqlite3.Error:
                    sample_data = []
                try:
                    cursor.execute(f"SELECT COUNT(*) FROM {quoted}")
                    row_count = cursor.fetchone()[0]
                except sqlite3.Error:
                    row_count = 0
                schema[table] = {
                    'columns': [
                        {'name': col[1], 'type': col[2], 'notnull': col[3], 'pk': col[5]}
                        for col in columns
                    ],
                    'sample_data': sample_data,
                    'row_count': row_count
                }
            return schema
        except Exception as e:
            print(f"Error getting actual schema: {e}")
            return {}

    def get_table_metadata(self, conn, cursor) -> Dict:
        """Load per-table domain/description/row count from ``table_catalog``.

        Returns:
            ``{table_name: {'domain', 'description', 'row_count'}}`` or ``{}``
            when the catalog cannot be read.
        """
        try:
            query = """
                SELECT table_name, domain, description, row_count
                FROM table_catalog
                WHERE table_name NOT IN ('table_catalog', 'column_catalog')
            """
            return {
                table_name: {
                    'domain': domain,
                    'description': description,
                    'row_count': row_count
                }
                for table_name, domain, description, row_count
                in cursor.execute(query).fetchall()
            }
        except Exception as e:
            print(f"Error loading table metadata: {e}")
            return {}

    def get_column_metadata(self, conn, cursor) -> Dict:
        """Load per-column descriptions from ``column_catalog``.

        Returns:
            ``{table_name: [column_info, ...]}`` or ``{}`` when the catalog
            cannot be read.
        """
        try:
            query = """
                SELECT table_name, column_name, data_type, is_foreign_key, references_table, description
                FROM column_catalog
            """
            metadata = {}
            for table_name, column_name, data_type, is_fk, ref_table, description in cursor.execute(query).fetchall():
                metadata.setdefault(table_name, []).append({
                    'name': column_name,
                    'type': data_type,
                    'is_foreign_key': bool(is_fk),
                    'references': ref_table,
                    'description': description
                })
            return metadata
        except Exception as e:
            print(f"Error loading column metadata: {e}")
            return {}

    def setup_client(self, api_key: str):
        """Create the OpenAI-compatible client pointed at OpenRouter."""
        self.client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
        )

    def get_relevant_tables_for_query(self, query: str) -> str:
        """Describe the tables that look relevant to *query*.

        A table is relevant when the query mentions a term of some concept and
        the table name contains a term of that same concept. If nothing
        matches, the first ten tables with more than ten rows are used.
        At most fifteen tables are described.

        Returns:
            Plain-text schema description embedded in the model prompts.
        """
        query_lower = query.lower()
        relevant_tables: List[str] = []
        for search_terms in self._CONCEPT_KEYWORDS.values():
            if not any(term in query_lower for term in search_terms):
                continue
            for table_name in self.actual_schema:
                table_lower = table_name.lower()
                if table_name not in relevant_tables and any(
                    term in table_lower for term in search_terms
                ):
                    relevant_tables.append(table_name)
        if not relevant_tables:
            relevant_tables = [
                name for name, info in self.actual_schema.items()
                if info['row_count'] > 10
            ][:10]

        parts: List[str] = []
        for table in relevant_tables[:15]:
            info = self.actual_schema[table]
            columns_str = ", ".join(f"{col['name']}({col['type']})" for col in info['columns'])
            parts.append(f"\nTable: {table}\n")
            parts.append(f" Columns: {columns_str}\n")
            parts.append(f" Rows: {info['row_count']}\n")
            meta = self.table_metadata.get(table)
            if meta is not None:
                parts.append(f" Domain: {meta['domain']}\n")
                parts.append(f" Description: {meta['description']}\n")
            if info['sample_data']:
                parts.append(f" Sample: {info['sample_data'][0]}\n")
        return "".join(parts)

    def get_system_prompt(self, user_query: str) -> str:
        """Build the system prompt, embedding the schema relevant to the query."""
        relevant_schema = self.get_relevant_tables_for_query(user_query)
        return f"""You are an intelligent database query agent that specializes in identifying relevant tables and generating accurate SQL queries.

DATABASE SCHEMA INFORMATION:
{relevant_schema}

CRITICAL SQL RULES:
1. NEVER use reserved words as table aliases (like 'to', 'from', 'where', 'select', etc.)
2. Use descriptive aliases like 'cust', 'ord', 'prod' instead
3. Only JOIN tables if you can identify a logical relationship between them
4. If no clear JOIN relationship exists, use separate SELECT statements or UNION
5. Always use the EXACT column names shown in the schema
6. Do not assume foreign key relationships unless explicitly shown

CRITICAL: You MUST respond with ONLY a valid JSON object. No markdown, no explanations outside the JSON.

Your response must be exactly in this JSON format:
{{
    "analysis": "Brief analysis of the query and table selection reasoning",
    "identified_tables": ["table1", "table2", "table3"],
    "domains_involved": ["domain1", "domain2"],
    "sql_query": "SELECT ... FROM ... WHERE ...",
    "explanation": "Step-by-step explanation of the query logic",
    "confidence": 0.95,
    "alternative_queries": ["Alternative SQL if applicable"]
}}

IMPORTANT RULES:
1. Respond with ONLY valid JSON - no markdown formatting
2. Use ONLY the actual table names shown in the schema above
3. Use ONLY the actual column names shown in the schema above
4. Generate syntactically correct SQL queries with proper aliases
5. Focus on tables that actually exist and have relevant data
6. Include confidence scores between 0.0 and 1.0
7. Provide clear explanations
8. Ensure table names in 'identified_tables' match those used in 'sql_query'
9. Check that columns referenced in SQL actually exist in the tables
10. If no perfect match exists, choose the closest relevant tables and explain the compromise
11. Avoid reserved word aliases like 'to', 'from', 'order', 'select'

QUERY ANALYSIS GUIDELINES:
- For customer/order queries: Look for tables with customer-related or order-related names and columns
- For employee queries: Look for tables with employee, staff, or HR-related names
- For product queries: Look for tables with product, inventory, or item-related names
- Always verify column names exist before using them in SQL
- Use proper JOIN syntax when combining tables, but only if logical relationships exist
- Include appropriate WHERE clauses when filtering is implied
- If unsure about relationships, prefer simpler queries or multiple separate queries"""

    def extract_json_from_response(self, response_text: str) -> Dict:
        """Parse a model reply into a dict.

        Tries, in order: the whole text, the content of a ``` fenced block,
        and the outermost ``{...}`` span. Only a JSON *object* is accepted;
        anything else falls through to :meth:`create_fallback_response`.
        """
        candidates = [response_text]
        fenced = re.search(r'```(?:json)?\s*(.*?)\s*```', response_text, re.DOTALL)
        if fenced:
            candidates.append(fenced.group(1))
        braces = re.search(r'\{.*\}', response_text, re.DOTALL)
        if braces:
            candidates.append(braces.group(0))
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            # Callers use .get(), so a JSON list/string/number is not usable.
            if isinstance(parsed, dict):
                return parsed
        return self.create_fallback_response(response_text)

    def create_fallback_response(self, response_text: str) -> Dict:
        """Salvage a response dict from text that is not valid JSON.

        Extracts the first ``SELECT ...`` statement, the known table names
        mentioned anywhere in the text, and the distinct domains of those
        tables.
        """
        sql_match = re.search(r'SELECT.*?(?:;|$)', response_text, re.IGNORECASE | re.DOTALL)
        sql_query = sql_match.group(0).strip(';').strip() if sql_match else ""
        text_lower = response_text.lower()
        identified_tables = [
            table_name for table_name in self.actual_schema
            if table_name.lower() in text_lower
        ]
        # Distinct domains in first-seen order.
        domains_involved: List[str] = []
        for table in identified_tables:
            meta = self.table_metadata.get(table)
            if meta is not None and meta['domain'] not in domains_involved:
                domains_involved.append(meta['domain'])
        return {
            "analysis": "Fallback analysis from unparseable response",
            "identified_tables": identified_tables[:5],
            "domains_involved": domains_involved[:3],
            "sql_query": sql_query,
            "explanation": "Response could not be parsed as JSON, extracted information where possible",
            "confidence": 0.5,
            "alternative_queries": []
        }

    def validate_sql_query(self, sql_query: str, identified_tables: List[str]) -> Tuple[bool, str]:
        """Check generated SQL against the schema before executing it.

        Checks, in order: the query is not empty; every identified table
        exists (names compared case-insensitively); the statement starts with
        SELECT; no reserved word is used as a table alias; every
        ``table.column`` reference to an identified table names a real column.
        String literals are masked first so their content is never analysed.

        Returns:
            ``(is_valid, message)``.
        """
        try:
            if not sql_query.strip():
                return False, "Empty SQL query"

            schema_names = {name.lower(): name for name in self.actual_schema}
            resolved_tables = []
            for table in identified_tables or []:
                actual_name = schema_names.get(str(table).lower())
                if actual_name is None:
                    return False, f"Table '{table}' does not exist in database"
                resolved_tables.append(actual_name)

            if not sql_query.strip().upper().startswith('SELECT'):
                return False, "Only SELECT queries are allowed"

            masked_sql = self._STRING_LITERAL_RE.sub("''", sql_query)

            for match in self._ALIAS_RE.finditer(masked_sql):
                alias = match.group('alias')
                word = alias.upper()
                if word not in self._RESERVED_ALIASES:
                    continue
                if not match.group('explicit') and word in self._CLAUSE_KEYWORDS:
                    following = (match.group('following') or '').upper()
                    # ORDER / GROUP only open a clause when followed by BY;
                    # the other clause keywords are always a clause start.
                    if word not in ('ORDER', 'GROUP') or following == 'BY':
                        continue
                return False, f"Cannot use reserved word '{alias}' as table alias"

            for table in resolved_tables:
                available_columns = {
                    col['name'].lower() for col in self.actual_schema[table]['columns']
                }
                # The lookbehind keeps "orders.x" from matching inside
                # "cust_orders.x" or a schema-qualified "main.orders.x".
                reference_re = rf'(?<![\w.]){re.escape(table)}\.(\w+)'
                for column in re.findall(reference_re, masked_sql, re.IGNORECASE):
                    if column.lower() not in available_columns:
                        return False, f"Column '{column}' does not exist in table '{table}'"
            return True, "Query validation passed"
        except Exception as e:
            return False, f"Validation error: {str(e)}"

    def call_model(self, model_key: str, prompt: str, user_query: str) -> Dict:
        """Send one prompt to the model registered under *model_key*.

        Returns:
            ``{"success": True, "response": dict, "raw_response": str, "model": key}``
            on success, where ``response`` gains an ``sql_validation`` entry
            when it contains SQL; otherwise
            ``{"success": False, "error": str, "model": key}``.
        """
        try:
            if self.client is None:
                raise RuntimeError("API client is not configured; call setup_client() first")
            messages = [
                {"role": "system", "content": prompt},
                {"role": "user", "content": f"Query: {user_query}\n\nRespond with ONLY a valid JSON object following the exact format specified in the system prompt."}
            ]
            completion = self.client.chat.completions.create(
                model=self.models[model_key],
                messages=messages,
                temperature=0.1,
                max_tokens=2000
            )
            content = completion.choices[0].message.content
            if not content:
                raise ValueError("Model returned an empty response")
            response = content.strip()
            parsed_response = self.extract_json_from_response(response)
            sql_query = parsed_response.get('sql_query', '')
            identified_tables = parsed_response.get('identified_tables', [])
            if sql_query:
                is_valid, validation_message = self.validate_sql_query(sql_query, identified_tables)
                parsed_response['sql_validation'] = {
                    'is_valid': is_valid,
                    'message': validation_message
                }
            return {
                "success": True,
                "response": parsed_response,
                "raw_response": response,
                "model": model_key
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "model": model_key
            }

    def verify_response(self, api_key: str, original_query: str, llama_response: Dict, mistral_response: Dict) -> Dict:
        """Ask the "gemma" model to review both proposals against the schema.

        Returns:
            The :meth:`call_model` result for the verification prompt.
        """
        self.setup_client(api_key)
        relevant_schema = self.get_relevant_tables_for_query(original_query)
        verification_prompt = f"""You are a database query verification expert. You have access to the actual database schema and must verify responses against it.

ACTUAL DATABASE SCHEMA:
{relevant_schema}

ORIGINAL QUERY: {original_query}

LLAMA RESPONSE: {json.dumps(llama_response.get('response', {}), indent=2)}

MISTRAL RESPONSE: {json.dumps(mistral_response.get('response', {}), indent=2)}

Verify these responses against the ACTUAL schema above. Check:
1. Do the table names actually exist in the schema?
2. Do the column names actually exist in those tables?
3. Are the table selections appropriate for the query?
4. Is the SQL syntax correct?
5. Are table aliases proper (not reserved words)?

Respond with ONLY a valid JSON object:
{{
    "verification_summary": "Overall assessment based on actual schema",
    "table_selection_accuracy": "Assessment of table choices against actual schema",
    "sql_correctness": "SQL syntax and schema validation",
    "consistency_check": "Comparison between responses",
    "recommended_response": "llama, mistral, or neither",
    "confidence_score": 0.85,
    "suggested_improvements": ["improvement1", "improvement2"],
    "potential_issues": ["issue1", "issue2"],
    "schema_compliance": "Assessment of how well responses match actual schema"
}}"""
        return self.call_model("gemma", verification_prompt, "Verify the above responses against the actual database schema.")

    def execute_query_in_thread(self, sql_query: str, result_queue: queue.Queue):
        """Run *sql_query* on a fresh connection and report via *result_queue*.

        Exactly one ``(success, DataFrame | error message)`` tuple is put on
        the queue. Non-SELECT statements are refused.
        """
        try:
            if not sql_query.strip().upper().startswith('SELECT'):
                result_queue.put((False, "Only SELECT queries are allowed"))
                return
            sql_query = sql_query.strip().rstrip(';')
            conn = self.get_db_connection()
            try:
                # The SQL comes from a language model: let the engine itself
                # refuse anything that would modify the database.
                conn.execute("PRAGMA query_only = ON;")
                df = pd.read_sql_query(sql_query, conn)
                result_queue.put((True, df))
            except Exception as e:
                result_queue.put((False, str(e)))
            finally:
                conn.close()
        except Exception as e:
            result_queue.put((False, f"Query execution error: {str(e)}"))

    def execute_query(self, sql_query: str, timeout: float = 30) -> Tuple[bool, Any]:
        """Execute *sql_query* in a worker thread with a time limit.

        Args:
            sql_query: SELECT statement to run.
            timeout: Seconds to wait for the worker before giving up.

        Returns:
            ``(True, DataFrame)`` on success, otherwise ``(False, message)``.
        """
        try:
            result_queue = queue.Queue()
            # Daemon thread: a query that outlives the timeout must not keep
            # the interpreter alive at shutdown.
            thread = threading.Thread(
                target=self.execute_query_in_thread,
                args=(sql_query, result_queue),
                daemon=True,
            )
            thread.start()
            thread.join(timeout=timeout)
            if thread.is_alive():
                return False, "Query execution timed out"
            try:
                return result_queue.get_nowait()
            except queue.Empty:
                return False, "No result returned from query execution"
        except Exception as e:
            return False, f"Execution error: {str(e)}"

    def _run_generated_sql(self, result: Dict) -> Dict:
        """Build the execution-report entry for one model's :meth:`call_model` result."""
        def skipped(data: str, sql: str, validation: Dict) -> Dict:
            return {
                "success": False,
                "data": data,
                "row_count": 0,
                "sql_query": sql,
                "validation": validation
            }

        if not result.get("success"):
            return skipped("Model failed to generate response", "",
                           {"is_valid": False, "message": "Model error"})

        response = result.get("response", {})
        sql_query = response.get("sql_query")
        if not isinstance(sql_query, str) or not sql_query.strip():
            return skipped("No SQL query generated", "",
                           {"is_valid": False, "message": "Empty query"})

        validation_info = response.get("sql_validation", {})
        if not validation_info.get("is_valid", True):
            return skipped(
                f"Query validation failed: {validation_info.get('message', 'Unknown error')}",
                sql_query,
                validation_info,
            )

        success, data = self.execute_query(sql_query)
        has_frame = success and isinstance(data, pd.DataFrame)
        return {
            "success": success,
            "data": data.to_dict('records') if has_frame else str(data),
            "row_count": len(data) if has_frame else 0,
            "sql_query": sql_query,
            "validation": validation_info
        }

    def process_query(self, api_key: str, user_query: str) -> Dict:
        """Run the full pipeline for one user question.

        Returns:
            A dict with ``llama_response``, ``mistral_response``,
            ``verification``, ``execution_results``, ``timestamp`` and
            ``schema_info``; or ``{"error": ...}`` (plus ``traceback`` for
            unexpected failures).
        """
        if not api_key:
            return {"error": "Please provide OpenRouter API key"}
        try:
            self.setup_client(api_key)
            system_prompt = self.get_system_prompt(user_query)
            llama_result = self.call_model("llama", system_prompt, user_query)
            mistral_result = self.call_model("mistral", system_prompt, user_query)
            verification_result = self.verify_response(api_key, user_query, llama_result, mistral_result)
            execution_results = {
                model_name: self._run_generated_sql(result)
                for model_name, result in (("llama", llama_result), ("mistral", mistral_result))
            }
            return {
                "llama_response": llama_result,
                "mistral_response": mistral_result,
                "verification": verification_result,
                "execution_results": execution_results,
                "timestamp": datetime.now().isoformat(),
                "schema_info": self.get_relevant_tables_for_query(user_query)
            }
        except Exception as e:
            return {"error": f"Processing error: {str(e)}", "traceback": traceback.format_exc()}
| |
|
def response_to_markdown(response_dict: Dict) -> str:
    """Render one model's result (as returned by ``call_model``) as Markdown.

    Args:
        response_dict: Dict with ``success`` and either ``response`` or
            ``error``.

    Returns:
        A Markdown bullet list, or a single ``**Error**`` line when the call
        did not succeed.
    """
    if not response_dict.get("success", False):
        return f"**Error**: {response_dict.get('error', 'Unknown error')}"

    def as_text_items(value) -> List[str]:
        # Model output is free-form JSON: a field documented as a list may
        # arrive as a bare string or contain non-string items.
        if not value:
            return []
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value]
        return [str(value)]

    response = response_dict.get("response", {})
    if not isinstance(response, dict):
        response = {}

    identified_tables = as_text_items(response.get('identified_tables'))
    domains_involved = as_text_items(response.get('domains_involved'))
    alternative_queries = as_text_items(response.get('alternative_queries'))
    sql_query = response.get('sql_query', '')

    parts = [
        "**Query Analysis Results**\n\n",
        f"- **Analysis**: {response.get('analysis', 'N/A')}\n\n",
        f"- **Identified Tables**: {', '.join(identified_tables) if identified_tables else 'None'}\n\n",
        f"- **Domains Involved**: {', '.join(domains_involved) if domains_involved else 'None'}\n\n",
    ]
    if sql_query:
        parts.append(f"- **SQL Query**:\n\n```sql\n{sql_query}\n```\n\n")
    else:
        parts.append("- **SQL Query**: None\n\n")
    parts.append(f"- **Explanation**: {response.get('explanation', 'N/A')}\n\n")
    parts.append(f"- **Confidence**: {response.get('confidence', 'N/A')}\n\n")
    if alternative_queries:
        parts.append("- **Alternative Queries**:\n")
        # Two-space indent so the items nest under the bullet above.
        parts.extend(f"  - {query}\n" for query in alternative_queries)
    else:
        parts.append("- **Alternative Queries**: None\n")

    validation = response.get('sql_validation', {})
    if validation:
        verdict = 'Passed' if validation.get('is_valid', False) else 'Failed'
        parts.append(f"\n- **SQL Validation**: {verdict} - {validation.get('message', 'N/A')}\n")
    return "".join(parts)
| |
|
def verification_to_markdown(verification_dict: Dict) -> str:
    """Render the verification model's result as Markdown.

    Args:
        verification_dict: Dict with ``success`` and either ``response`` or
            ``error``.

    Returns:
        A Markdown bullet list, or a single ``**Error**`` line when the call
        did not succeed.
    """
    if not verification_dict.get("success", False):
        return f"**Error**: {verification_dict.get('error', 'Unknown error')}"

    response = verification_dict.get("response", {})
    if not isinstance(response, dict):
        response = {}

    def scalar_line(title: str, key: str) -> str:
        return f"- **{title}**: {response.get(key, 'N/A')}"

    def list_section(title: str, key: str) -> str:
        value = response.get(key)
        # Model output is free-form JSON: accept a bare string where a list
        # was requested instead of iterating it character by character.
        if not value:
            items = []
        elif isinstance(value, (list, tuple)):
            items = [str(item) for item in value]
        else:
            items = [str(value)]
        if not items:
            return f"- **{title}**: None\n"
        # Two-space indent so the items nest under the section bullet.
        return f"- **{title}**:\n" + "".join(f"  - {item}\n" for item in items)

    scalar_fields = (
        ("Verification Summary", "verification_summary"),
        ("Table Selection Accuracy", "table_selection_accuracy"),
        ("SQL Correctness", "sql_correctness"),
        ("Consistency Check", "consistency_check"),
        ("Recommended Response", "recommended_response"),
        ("Confidence Score", "confidence_score"),
    )
    parts = ["**Verification Results**\n\n"]
    parts.extend(scalar_line(title, key) + "\n\n" for title, key in scalar_fields)
    parts.append(list_section("Suggested Improvements", "suggested_improvements"))
    parts.append(list_section("Potential Issues", "potential_issues"))
    parts.append(scalar_line("Schema Compliance", "schema_compliance") + "\n")
    return "".join(parts)
| |
|
def create_gradio_interface():
    """Build the Gradio Blocks UI around a single ``DatabaseQueryAgent``.

    Returns:
        The ``gr.Blocks`` interface, not yet launched.
    """
    agent = DatabaseQueryAgent()
    sample_queries = [
        "Find all customers from customer tables",
        "Show me employee information from HR tables",
        "Get patient data from healthcare tables",
        "List all products with their details",
        "Find students enrolled in courses",
        "Show financial transaction records",
        "Get shipping information for deliveries",
        "Find all suppliers and their information",
        "Show retail store data",
        "Get manufacturing production records"
    ]

    def process_user_query(api_key, query):
        """Run the agent and return the six values shown in the output widgets."""
        blank_outputs = ("", "", "", "", "")
        if not query or not query.strip():
            return ("Please enter a query",) + blank_outputs
        results = agent.process_query(api_key, query)
        if "error" in results:
            return (f"**Error**: {results['error']}",) + blank_outputs

        llama_markdown = response_to_markdown(results.get("llama_response", {}))
        mistral_markdown = response_to_markdown(results.get("mistral_response", {}))
        verification_markdown = verification_to_markdown(results.get("verification", {}))

        exec_results = results.get("execution_results", {})
        report = []
        for model, result in exec_results.items():
            report.append(f"\n=== {model.upper()} EXECUTION ===\n")
            report.append(f"SQL Query: {result.get('sql_query', 'N/A')}\n")
            validation = result.get('validation', {})
            if validation.get('is_valid'):
                report.append("✅ Query Validation: PASSED\n")
            else:
                report.append(f"❌ Query Validation: FAILED - {validation.get('message', 'Unknown error')}\n")
            if result["success"]:
                report.append(f"✅ Execution: Success! Retrieved {result['row_count']} rows\n")
                if result["row_count"] > 0:
                    sample_data = result['data'][:3] if isinstance(result['data'], list) else []
                    # default=str: rows may hold values JSON cannot encode
                    # natively (BLOBs, timestamps).
                    sample_json = json.dumps(sample_data, indent=2, default=str, ensure_ascii=False)
                    report.append(f"Sample data:\n{sample_json}\n")
                else:
                    report.append("No data returned (empty result set)\n")
            else:
                report.append(f"❌ Execution Error: {result['data']}\n")
            report.append("\n")
        execution_formatted = "".join(report)
        if not execution_formatted:
            execution_formatted = "No queries were executed. Check if valid SQL was generated."

        schema_info = results.get('schema_info', 'No schema information available')

        verification_resp = results.get('verification', {}).get('response', {})
        # Paragraphs are joined with blank lines so Markdown renders each one
        # on its own line and "---" becomes a horizontal rule.
        summary = "\n\n".join([
            "**QUERY ANALYSIS COMPLETE**",
            "---",
            "**Models Used**: Llama 3.3 70B, Mistral 7B, Gemma 2 9B (verification)",
            f"**⏰ Processed**: {results.get('timestamp', 'N/A')}",
            "**🎯 Verification Summary**:",
            f"{verification_resp.get('verification_summary', 'N/A')}",
            f"**💡 Recommended Model**: {verification_resp.get('recommended_response', 'N/A')}",
            f"**Confidence**: {verification_resp.get('confidence_score', 'N/A')}",
            f"**Schema Compliance**: {verification_resp.get('schema_compliance', 'N/A')}",
            "**Query Execution Status**:",
            f"{len(exec_results)} queries attempted",
        ])

        return summary, llama_markdown, mistral_markdown, verification_markdown, execution_formatted, schema_info

    with gr.Blocks(
        title="Fixed Intelligent Database Query Agent",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
            margin: 0 auto !important;
        }
        .result-box {
            background-color: #f8f9fa;
            border: 1px solid #dee2e6;
            border-radius: 8px;
            padding: 15px;
        }
        """
    ) as interface:
        gr.HTML("""
        <div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px;">
            <h1>🤖 Fixed Intelligent Database Query Agent</h1>
            <p>AI-powered agent that intelligently selects relevant tables from 100+ tables and generates optimized SQL queries</p>
            <p><strong>Database:</strong> 100 tables across 10 business domains | <strong>Models:</strong> Llama 3.3 70B + Mistral 7B + Gemma 2 9B</p>
            <p><strong>✅ FIXED:</strong> Reserved Word Aliases | Enhanced Column Validation | Better SQL Syntax Checking</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                api_key_input = gr.Textbox(
                    label="OpenRouter API Key",
                    type="password",
                    placeholder="Enter your OpenRouter API key...",
                    info="Get your free API key from openrouter.ai"
                )
                query_input = gr.Textbox(
                    label="💬 Database Query",
                    placeholder="Enter your natural language query...",
                    lines=3,
                    info="Example: 'Find all customers who placed orders in the last month'"
                )
                with gr.Row():
                    submit_btn = gr.Button("Process Query", variant="primary", size="lg")
                    clear_btn = gr.Button("Clear", variant="secondary")
                gr.HTML("<h3>Sample Test Queries</h3>")
                sample_dropdown = gr.Dropdown(
                    choices=sample_queries,
                    label="Quick Test Examples",
                    info="Select a sample query to test the agent"
                )

            with gr.Column(scale=2):
                summary_output = gr.Markdown(label="Analysis Summary")
                with gr.Tabs():
                    with gr.Tab("🦙 Llama 3.3 70B Response"):
                        llama_output = gr.Markdown(label="Llama Response")
                    with gr.Tab("Mistral 7B Response"):
                        mistral_output = gr.Markdown(label="Mistral Response")
                    with gr.Tab("✅ Verification (Gemma 2 9B)"):
                        verification_output = gr.Markdown(label="Verification Analysis")
                    with gr.Tab("Query Execution Results"):
                        execution_output = gr.Textbox(
                            label="Database Execution Results",
                            lines=15,
                            max_lines=20,
                            elem_classes=["result-box"]
                        )
                    with gr.Tab("Database Schema"):
                        schema_output = gr.Textbox(
                            label="Relevant Database Schema",
                            lines=15,
                            max_lines=20,
                            elem_classes=["result-box"]
                        )

        submit_btn.click(
            fn=process_user_query,
            inputs=[api_key_input, query_input],
            outputs=[summary_output, llama_output, mistral_output, verification_output, execution_output, schema_output]
        )
        # One empty string per output component (seven in total).
        clear_btn.click(
            fn=lambda: ("", "", "", "", "", "", ""),
            outputs=[query_input, summary_output, llama_output, mistral_output, verification_output, execution_output, schema_output]
        )
        # Copy the chosen sample into the query box.
        sample_dropdown.change(
            fn=lambda x: x,
            inputs=[sample_dropdown],
            outputs=[query_input]
        )
        gr.HTML("""
        <div style="margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 8px;">
            <h3>🎯 How to Use</h3>
            <ol>
                <li><strong>API Key:</strong> Get a free API key from <a href="https://openrouter.ai" target="_blank">openrouter.ai</a></li>
                <li><strong>Query:</strong> Enter your natural language database query</li>
                <li><strong>Process:</strong> The agent will analyze your query across 100+ tables and generate optimized SQL</li>
                <li><strong>Results:</strong> View responses from multiple AI models, verification analysis, and actual query execution results</li>
            </ol>
            <p><strong>Features:</strong></p>
            <ul>
                <li>🧠 Multi-model AI analysis (Llama, Mistral, Gemma)</li>
                <li>Intelligent table selection from 100+ tables</li>
                <li>✅ SQL validation and syntax checking</li>
                <li>Real database query execution with results</li>
                <li>Cross-model verification and comparison</li>
            </ul>
        </div>
        """)

    return interface
| |
|
def main():
    """Build the Gradio interface and launch it.

    ``share=True`` asks Gradio to also create a public share link in
    addition to the local server.
    """
    print("Starting Intelligent Database Query Agent...")
    print("Loading database schema and metadata...")
    interface = create_gradio_interface()
    print("✅ Database Query Agent Ready!")
    print("Access the interface at: http://localhost:7860")
    print("Don't forget to add your OpenRouter API key!")
    interface.launch(share=True)


if __name__ == "__main__":
    main()