Spaces:
Sleeping
Sleeping
File size: 8,032 Bytes
b3ebb38 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
from database_manager import DatabaseManager
from prompts import *
from typing import Optional, Dict, Any
from langchain_mistralai import ChatMistralAI
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, SystemMessage
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from dotenv import load_dotenv
import os
import json
import re
load_dotenv()
class CSVAgent:
"""Agente para análise de dados CSV usando LangChain e Mistral AI"""
def __init__(self,
database_manager: DatabaseManager,
general_llm: Optional[ChatMistralAI] = None,
sql_llm: Optional[ChatMistralAI] = None):
"""
Inicializa o agente CSV
Args:
database_manager: Instância do DatabaseManager
general_llm: LLM para uso geral (opcional)
sql_llm: LLM para geração de SQL (opcional)
"""
self.db_manager = database_manager
# Configura LLMs padrão se não fornecidos
if general_llm is None:
self.llm = ChatMistralAI(
model="mistral-medium-latest",
api_key=os.getenv('MISTRAL_API_KEY'),
temperature=0.1
)
else:
self.llm = general_llm
if sql_llm is None:
self.sql_llm = ChatMistralAI(
model="codestral-latest",
api_key=os.getenv('MISTRAL_API_KEY'),
temperature=0.1
)
else:
self.sql_llm = sql_llm
# Configura ferramentas e agente
self._setup_tools_and_agent()
def _setup_tools_and_agent(self):
"""Configura ferramentas e agente do LangChain"""
# Cria a ferramenta de consulta ao banco de dados
@tool
def consultar_banco_dados(consulta_linguagem_natural: str) -> str:
"""
Executa consultas SQL nos dados CSV carregados.
Use esta ferramenta quando o usuário pedir dados específicos, contagens, estatísticas,
filtragem, análise ou qualquer operação que envolva os dados das tabelas.
Args:
consulta_linguagem_natural: A pergunta do usuário sobre os dados
Returns:
String JSON com os resultados da consulta
"""
try:
print(f"Executando consulta no banco para: {consulta_linguagem_natural}")
# Gera consulta SQL usando Codestral
sql_query = self._generate_sql_query(consulta_linguagem_natural)
print(f"SQL gerado: {sql_query}")
# Executa consulta
results = self.db_manager.execute_sql_query(sql_query)
print(f"Consulta retornou {len(results)} resultados")
return json.dumps({
"sucesso": True,
"consulta_sql": sql_query,
"resultados": results,
"total_resultados": len(results),
"tabelas_disponiveis": self.db_manager.get_tables_list()
}, indent=2, default=str, ensure_ascii=False)
except Exception as e:
print(f"Erro na consulta ao banco: {str(e)}")
return json.dumps({
"sucesso": False,
"erro": str(e),
"consulta_sql": None,
"resultados": [],
"total_resultados": 0,
"tabelas_disponiveis": self.db_manager.get_tables_list()
}, ensure_ascii=False)
# Armazena a ferramenta
self.database_tool = consultar_banco_dados
# Cria prompt do sistema
self.system_prompt = get_system_prompt(
database_schema=self.db_manager.get_database_schema(),
tables_list=', '.join(self.db_manager.get_tables_list())
)
# Cria lista de ferramentas
self.tools = [self.database_tool]
# Cria agente
try:
prompt = ChatPromptTemplate.from_messages([
("system", self.system_prompt),
("human", "{input}"),
MessagesPlaceholder("agent_scratchpad")
])
self.agent = create_tool_calling_agent(self.llm, self.tools, prompt)
self.agent_executor = AgentExecutor(
agent=self.agent,
tools=self.tools,
verbose=True,
handle_parsing_errors=True,
max_iterations=3,
return_intermediate_steps=True
)
except Exception as e:
print(f"Aviso: Não foi possível criar o executor do agente: {e}")
self.agent_executor = None
def _generate_sql_query(self, user_query: str) -> str:
"""Gera consulta SQL a partir de linguagem natural usando LLM"""
sql_prompt = get_sql_prompt(
database_schema=self.db_manager.get_database_schema(),
tables_list=', '.join(self.db_manager.get_tables_list()),
user_query=user_query
)
try:
response = self.sql_llm.invoke(sql_prompt)
# Limpa a consulta SQL
sql_query = response.content
sql_query = re.sub(r'```sql\n?', '', sql_query)
sql_query = re.sub(r'```\n?', '', sql_query)
sql_query = sql_query.strip()
# Remove ponto e vírgula final se presente
if sql_query.endswith(';'):
sql_query = sql_query[:-1]
return sql_query
except Exception as e:
raise Exception(f"Erro ao gerar SQL: {str(e)}")
def query(self, user_input: str) -> Dict[str, Any]:
"""
Método principal para processar consultas do usuário
Args:
user_input: Consulta em linguagem natural do usuário
Returns:
Dicionário contendo resposta e metadados
"""
try:
if self.agent_executor:
print(f"Processando consulta com agente Mistral: {user_input}")
# Executa com agente
result = self.agent_executor.invoke({"input": user_input})
# Captura a consulta SQL dos intermediate_steps
sql_query = None
function_called = False
for step in result.get("intermediate_steps", []):
if len(step) > 1 and "consultar_banco_dados" in str(step[0]):
function_called = True
try:
resposta = json.loads(step[1])
sql_query = resposta.get("consulta_sql")
break
except:
pass
return {
"resposta": result["output"],
"tipo_consulta": "executor_agent",
"funcao_chamada": function_called,
"consulta_sql": sql_query,
"sucesso": True,
"modelo_usado": self.llm.model,
"tabelas_disponiveis": self.db_manager.get_tables_list(),
"resultado_agente_bruto": result
}
except Exception as e:
error_msg = f"Encontrei um erro: {str(e)}"
print(f"Erro no processamento da consulta: {str(e)}")
return {
"resposta": error_msg,
"tipo_consulta": "erro",
"funcao_chamada": False,
"sucesso": False,
"erro": str(e),
"tabelas_disponiveis": self.db_manager.get_tables_list()
}
|