import asyncio
import json
import os
from typing import Dict

import streamlit as st
from autogen import ConversableAgent

from src.agents.sqlspideragent import TrackableSpiderConversableAgent
from src.agents.userproxyagent import TrackableUserProxyAgent
from src.utils.sqlexecutor import SQLExec

# Ask autogen not to use Docker, via its AUTOGEN_USE_DOCKER environment variable.
os.environ["AUTOGEN_USE_DOCKER"] = "False"


class AgentChatSqlSpider:
    """Drive a two-agent text-to-SQL chat.

    A ``sql_writer`` agent is prompted with a database schema and a question and
    is instructed to answer with an ``execute_sql()`` function call; an admin
    user-proxy agent starts and drives the conversation. Callers are expected to
    set :attr:`schema` and :attr:`question` before calling :meth:`run`.
    """

    def __init__(self, assistant_name, user_proxy_name, llm_config, problem):
        """Create both agents and a dedicated asyncio event loop.

        Args:
            assistant_name: Currently unused; the SQL agent is always named
                ``"sql_writer"``. NOTE(review): confirm whether it should be.
            user_proxy_name: Name given to the admin user-proxy agent.
            llm_config: LLM configuration passed to both agents.
            problem: Stored unchanged as ``self.problem``.
        """
        # Interpolated into the opening prompt by initiate_chat(); set by the caller.
        self.schema = None
        self.question = None
        self.sql_writer = TrackableSpiderConversableAgent(
            "sql_writer",
            llm_config=llm_config,
            system_message="You are good at writing SQL queries. Always respond with a function call to execute_sql().",
            is_termination_msg=self.check_termination,
        )
        self.user_proxy = TrackableUserProxyAgent(
            name=user_proxy_name,
            system_message="You are Admin",
            human_input_mode="NEVER",
            llm_config=llm_config,
            code_execution_config=False,
            is_termination_msg=self._ends_with_terminate,
        )
        self.problem = problem
        # Dedicated loop that run() uses to drive the async chat from sync code.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    async def initiate_chat(self):
        """Send the schema + question prompt from the user proxy to the SQL writer."""
        message = (
            "Below is the schema for a SQL database:\n"
            f"{self.schema}\n"
            "Generate a SQL query to answer the following question:\n"
            f"{self.question}\n"
        )
        # SQLExec is constructed only for its effect on the two agents
        # (presumably registering execute_sql() with them; verify in
        # src.utils.sqlexecutor). Binding it keeps it alive for the chat.
        sql_exec = SQLExec(self.sql_writer, self.user_proxy)  # noqa: F841
        # NOTE(review): a flag named "chat_with_history" is passed as
        # clear_history, which reads inverted; confirm the key's meaning.
        await self.user_proxy.a_initiate_chat(
            self.sql_writer,
            message=message,
            clear_history=st.session_state["chat_with_history"],
        )

    def run(self):
        """Block until the chat started by :meth:`initiate_chat` completes."""
        self.loop.run_until_complete(self.initiate_chat())

    @staticmethod
    def check_termination(msg: Dict) -> bool:
        """Decide whether the SQL writer's chat should stop after ``msg``.

        The ``content`` of the first entry in ``msg["tool_responses"]`` is parsed
        as JSON. The chat terminates when that object has no ``"error"`` key, or
        when its ``"error"`` is None and its ``"reward"`` equals 1.

        Messages without tool responses, or whose first tool response has no
        parseable JSON ``content``, return False instead of raising.

        Declared static because it is handed to the agent as
        ``self.check_termination`` and called with the message only.
        """
        tool_responses = msg.get("tool_responses")
        if not tool_responses:
            return False
        try:
            obj = json.loads(tool_responses[0]["content"])
        except (KeyError, TypeError, ValueError):
            # ValueError covers json.JSONDecodeError; TypeError covers None content.
            return False
        # Parenthesized to make the (unchanged) operator precedence explicit.
        return "error" not in obj or (
            obj["error"] is None and obj.get("reward") == 1
        )

    @staticmethod
    def _ends_with_terminate(msg: Dict) -> bool:
        """Return True when the message's text content ends with ``TERMINATE``.

        ``content`` may be present but None or non-text (autogen messages that
        only carry a tool/function call typically have no text), so anything
        that is not a ``str`` is treated as not terminating.
        """
        content = msg.get("content")
        return isinstance(content, str) and content.strip().endswith("TERMINATE")