Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| from typing import Dict | |
| from autogen import ConversableAgent | |
| from src.agents.sqlspideragent import TrackableSpiderConversableAgent | |
| from src.agents.userproxyagent import TrackableUserProxyAgent | |
| import streamlit as st | |
| import os | |
| from src.utils.sqlexecutor import SQLExec | |
# Set at import time, before any agent in this module is constructed. The
# variable name suggests it stops autogen from requiring Docker for code
# execution — confirm against the installed autogen version.
os.environ.update(AUTOGEN_USE_DOCKER="False")
class AgentChatSqlSpider:
    """Drive an autogen two-agent chat that writes SQL for a single question.

    A ``TrackableUserProxyAgent`` ("Admin") sends a database schema and a
    natural-language question to a ``TrackableSpiderConversableAgent`` named
    ``"sql_writer"``, whose system message tells it to answer with
    ``execute_sql()`` function calls.

    Callers are expected to assign ``schema`` and ``question`` (both start as
    ``None``) before calling :meth:`run`.
    """

    def __init__(self, assistant_name, user_proxy_name, llm_config, problem):
        """Create both agents and a private asyncio event loop.

        Args:
            assistant_name: Accepted for interface compatibility but not used;
                the SQL-writing agent is always named ``"sql_writer"``.
            user_proxy_name: Name given to the user-proxy ("Admin") agent.
            llm_config: autogen LLM configuration passed to both agents.
            problem: Stored on ``self.problem``; not read by any method here.
        """
        self.schema = None
        self.question = None
        self.sql_writer = TrackableSpiderConversableAgent(
            "sql_writer",
            llm_config=llm_config,
            system_message="You are good at writing SQL queries. Always respond with a function call to execute_sql().",
            is_termination_msg=self.check_termination,
        )
        self.user_proxy = TrackableUserProxyAgent(
            name=user_proxy_name,
            system_message="You are Admin",
            human_input_mode="NEVER",
            llm_config=llm_config,
            code_execution_config=False,
            is_termination_msg=self._ends_with_terminate,
        )
        self.problem = problem
        # Dedicated loop so the synchronous run() can drive the async chat.
        # It is installed as the current loop for this thread and is not
        # closed by this class.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    async def initiate_chat(self):
        """Send the schema and question from the user proxy to ``sql_writer``.

        Reads ``st.session_state["chat_with_history"]``, so that key must be
        present in the Streamlit session state.
        """
        # Built by concatenation so the prompt text does not depend on the
        # indentation of this method.
        message = (
            "Below is the schema for a SQL database:\n"
            f"{self.schema}\n"
            "Generate a SQL query to answer the following question:\n"
            f"{self.question}\n"
        )
        # Constructed for its side effects on the two agents (the instance is
        # not otherwise used here) — presumably it registers execute_sql();
        # confirm in src.utils.sqlexecutor.
        _sql_exec = SQLExec(self.sql_writer, self.user_proxy)
        # NOTE(review): a flag named "chat_with_history" is passed straight
        # into clear_history, so True *clears* the history — confirm intent.
        await self.user_proxy.a_initiate_chat(
            self.sql_writer,
            message=message,
            clear_history=st.session_state["chat_with_history"],
        )

    def run(self):
        """Run :meth:`initiate_chat` to completion on this instance's loop."""
        self.loop.run_until_complete(self.initiate_chat())

    @staticmethod
    def check_termination(msg: Dict) -> bool:
        """Return True when ``msg`` carries a successful tool response.

        Only the first entry of ``msg["tool_responses"]`` is inspected. Its
        ``content`` is parsed as JSON, and the chat ends when that object has
        no ``"error"`` key, or has ``"error": None`` and ``"reward" == 1``.

        Declared ``@staticmethod`` because ``__init__`` passes
        ``self.check_termination`` as ``is_termination_msg``, which is called
        with just the message. A plain function in the class body would
        receive ``self`` as well and raise ``TypeError``.

        Messages without tool responses, and tool responses whose content is
        not a JSON object (e.g. an error string), never end the chat.
        """
        responses = msg.get("tool_responses")
        if not responses:
            return False
        try:
            obj = json.loads(responses[0].get("content"))
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError
            # Unparseable content means "keep going", not "crash the chat".
            return False
        if not isinstance(obj, dict):
            return False
        # Parentheses make the precedence explicit (`and` binds tighter than
        # `or`); .get avoids a KeyError when "reward" is absent.
        return "error" not in obj or (
            obj["error"] is None and obj.get("reward") == 1
        )

    @staticmethod
    def _ends_with_terminate(msg: Dict) -> bool:
        """User-proxy termination check: message text ends with "TERMINATE".

        ``msg.get("content", "")`` does not help when the key is present with
        value ``None``, so ``None`` is coalesced to ``""`` explicitly.
        """
        content = msg.get("content") or ""
        return content.strip().endswith("TERMINATE")