# AutogenMultiAgent / src/usecases/agentchatsqlspider.py
# genaitiwari — "sql spider" (commit 3f4dbc7)
import asyncio
import json
from typing import Dict
from autogen import ConversableAgent
from src.agents.sqlspideragent import TrackableSpiderConversableAgent
from src.agents.userproxyagent import TrackableUserProxyAgent
import streamlit as st
import os
from src.utils.sqlexecutor import SQLExec
# Opt out of Docker-based code execution in autogen (it reads this env var).
os.environ.update({"AUTOGEN_USE_DOCKER": "False"})
class AgentChatSqlSpider:
    """Two-agent chat that asks an LLM to write SQL answering a question over a schema.

    A ``sql_writer`` agent is prompted with ``schema`` and ``question`` and is
    instructed to answer with a call to ``execute_sql()``. The user proxy agent
    starts the conversation. The chat is driven synchronously via ``run()`` on an
    event loop created in ``__init__``.

    Attributes:
        schema: Schema text interpolated into the opening prompt. Must be assigned
            by the caller before ``run()``; if left as ``None`` the literal text
            ``"None"`` is sent.
        question: Question text interpolated into the opening prompt; same caveat
            as ``schema``.
        sql_writer: The LLM agent that writes SQL.
        user_proxy: The agent that initiates the chat.
        problem: Stored as given; not used by this class.
        loop: Event loop used by ``run()``.
    """

    def __init__(self, assistant_name, user_proxy_name, llm_config, problem):
        """Create both agents and a dedicated event loop.

        Args:
            assistant_name: Accepted for interface compatibility but currently
                unused; the SQL agent is always named ``"sql_writer"``.
            user_proxy_name: Name given to the user proxy agent.
            llm_config: LLM configuration passed to both agents.
            problem: Stored on ``self.problem``; not used by this class.
        """
        self.schema = None
        self.question = None
        self.sql_writer = TrackableSpiderConversableAgent(
            "sql_writer",
            llm_config=llm_config,
            system_message="You are good at writing SQL queries. Always respond with a function call to execute_sql().",
            is_termination_msg=self.check_termination,
        )
        self.user_proxy = TrackableUserProxyAgent(
            name=user_proxy_name,
            system_message="You are Admin",
            human_input_mode="NEVER",
            llm_config=llm_config,
            code_execution_config=False,
            is_termination_msg=self._proxy_is_termination_msg,
        )
        self.problem = problem
        # A fresh loop, installed as this thread's current loop, lets run() drive
        # the async chat from synchronous code.
        # NOTE(review): this loop is never closed by this class.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    async def initiate_chat(self):
        """Send the schema + question prompt from the user proxy to ``sql_writer``."""
        message = (
            "Below is the schema for a SQL database:\n"
            f"{self.schema}\n"
            "Generate a SQL query to answer the following question:\n"
            f"{self.question}\n"
        )
        # Constructed for its side effects on the two agents (presumably registering
        # execute_sql() — TODO confirm in src.utils.sqlexecutor); kept bound for the
        # duration of the chat.
        sql_exec = SQLExec(self.sql_writer, self.user_proxy)  # noqa: F841
        # NOTE(review): history is *cleared* when "chat_with_history" is truthy,
        # which reads inverted — confirm against how the UI sets this key.
        await self.user_proxy.a_initiate_chat(
            self.sql_writer,
            message=message,
            clear_history=st.session_state["chat_with_history"],
        )

    def run(self):
        """Run ``initiate_chat()`` to completion on ``self.loop``."""
        self.loop.run_until_complete(self.initiate_chat())

    @staticmethod
    def _proxy_is_termination_msg(msg: Dict) -> bool:
        """Return True when the message text ends with ``"TERMINATE"``.

        Messages whose ``content`` is missing or not a string (autogen tool-call
        messages may carry ``content=None``) are treated as non-terminating
        instead of raising ``AttributeError``.
        """
        content = msg.get("content")
        if not isinstance(content, str):
            return False
        return content.strip().endswith("TERMINATE")

    @staticmethod
    def check_termination(msg: Dict) -> bool:
        """Return True when the first tool response reports a successful query.

        Used as ``sql_writer``'s ``is_termination_msg``. Only messages carrying
        ``tool_responses`` can terminate. The first response's ``content`` is
        parsed as JSON, and the chat ends when that object has no ``"error"``
        key, or when ``error`` is ``None`` and ``reward == 1``.

        Messages with no or empty ``tool_responses``, or whose content is not a
        JSON object, are treated as non-terminating rather than raising.
        """
        responses = msg.get("tool_responses")
        if not responses:
            return False
        try:
            result = json.loads(responses[0]["content"])
        except (KeyError, TypeError, ValueError):  # JSONDecodeError is a ValueError
            return False
        if not isinstance(result, dict):
            return False
        if "error" not in result:
            return True
        return result["error"] is None and result.get("reward") == 1