|
|
|
|
|
import os |
|
|
import asyncio |
|
|
from dataclasses import dataclass |
|
|
from dotenv import load_dotenv |
|
|
load_dotenv() |
|
|
|
|
|
from pydantic_ai import Agent, RunContext |
|
|
from pydantic_ai.mcp import MCPServerStdio |
|
|
from pydantic_ai.providers.groq import GroqProvider |
|
|
from pydantic_ai.models.groq import GroqModel |
|
|
from pydantic_graph import End |
|
|
|
|
|
from tools.searching import SearchingTools |
|
|
|
|
|
import pandas as pd |
|
|
from neo4j import GraphDatabase |
|
|
from typing import List, Dict, Any, Optional |
|
|
import matplotlib |
|
|
matplotlib.use("Agg") |
|
|
import matplotlib.pyplot as plt |
|
|
import io |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
|
|
|
|
|
|
|
|
|
# Neo4j connection settings read from the process environment.
# The URI falls back to an empty string; user and password are None when unset.
NEO4J_URI: str = os.environ.get("NEO4J_URI", "")
NEO4J_USER = os.environ.get("NEO4J_USER")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD")
|
|
|
|
|
|
|
|
# System prompt for the football-analyst agent. It describes the graph schema
# with example Cypher queries and lists every player name stored in the database.
SYSTEM_PROMPT = """
# You are an expert football analyst with access to a Barcelona graph database covering five seasons, from 2016/2017 to 2020/2021.

Use the available tools to answer the user's query. You are connected to Neo4j through an MCP server created by Frederico Caixeta.

Base your analysis **ONLY** on the query results. If the database cannot provide what the user is
asking for, say so in a professional way. Limit your answer to **1500** characters.

Only read-only Cypher is accepted: queries containing CREATE, MERGE, SET, DELETE, DETACH, REMOVE, DROP or CALL are rejected.
The query tool shows you at most 10 rows, so aggregate or use ORDER BY ... LIMIT to get the rows you need.
Player names are stored in full (see the list below); match them exactly as listed, or use a partial match such as `WHERE p.name CONTAINS "Messi"` when unsure.

When the user asks for visualizations, graphs, or charts, you MUST use the create_chart tool
to generate the appropriate visualization. The chart will be displayed in the Visualization tab.

# Below are some example Cypher queries showing which labels, relationships and properties are available in the database:

// General overview: players, pass connections, passes and goal sequences.
// Passes are aggregated BEFORE matching GoalSequence nodes: matching both at once
// multiplies every pass row by the number of goal sequences and inflates TotalPasses.
MATCH (p:Player {team: "Barcelona"})
OPTIONAL MATCH (p)-[r:PASSED_TO]->()
WITH
count(DISTINCT p) as TotalPlayers,
count(DISTINCT r) as TotalPassConnections,
sum(r.weight) as TotalPasses,
collect(DISTINCT p.season_date) as Seasons
OPTIONAL MATCH (g:GoalSequence {team: "Barcelona"})
RETURN
TotalPlayers,
TotalPassConnections,
TotalPasses,
count(DISTINCT g) as TotalGoalSequences,
Seasons

// List all seasons available in Neo4j
MATCH (p:Player)
RETURN DISTINCT p.season_date as Season, p.season_id as SeasonID
ORDER BY Season

// Top 5 pass connections per season
MATCH (p1:Player)-[r:PASSED_TO]->(p2:Player)
WITH p1.season_date as Season, p1.name as P1, p2.name as P2, r.weight as Weight
ORDER BY Weight DESC
WITH Season, collect({passer: P1, receiver: P2, passes: Weight})[0..5] as TopConnections
RETURN Season, TopConnections
ORDER BY Season

// Passes between different field zones (season_id 90 = 2020/2021)
MATCH (p1:Player)-[r:PASSED_TO]->(p2:Player)
WHERE p1.season_id = 90
WITH p1, p2, r,
CASE WHEN p1.avg_x < 40 THEN 'Def' WHEN p1.avg_x < 80 THEN 'Mid' ELSE 'Att' END as Zone1,
CASE WHEN p2.avg_x < 40 THEN 'Def' WHEN p2.avg_x < 80 THEN 'Mid' ELSE 'Att' END as Zone2
WHERE Zone1 <> Zone2
RETURN Zone1 + ' -> ' + Zone2 as Transition, sum(r.weight) as TotalPasses
ORDER BY TotalPasses DESC

// Number of goal sequences Ivan Rakitić (a player) was involved in
MATCH (p:Player {name: "Ivan Rakitić"})-[:INVOLVED_IN]->(g:GoalSequence)
RETURN count(g) as TotalGoalSequences

# The available property keys are: avg_x, avg_y, data, end_x, end_y, id, match_id, name, nodes, num_passes, order, possession, relationships, season_date, season_id, sequence_id, style, team, visualisation, weight, x, y.

The node labels are: Player, GoalSequence.

The relationship types are: INVOLVED_IN, PASSED_IN_SEQUENCE, PASSED_TO.

The season IDs and their dates: [{90: '2020/2021'}, {42: '2019/2020'}, {4: '2018/2019'}, {1: '2017/2018'}, {2: '2016/2017'}]

# Players who appeared in at least one of these seasons:

Abel Ruiz Ortega
Aleix Vidal Parreu
André Filipe Tavares Gomes
Andrés Iniesta Luján
Anssumane Fati
Antoine Griezmann
Arda Turan
Arthur Henrique Ramos de Oliveira Melo
Arturo Erasmo Vidal Pardo
Carles Aleña Castillo
Carles Pérez Sayol
Claudio Andrés Bravo Muñoz
Clément Lenglet
Denis Suárez Fernández
Francisco Alcácer García
Francisco António Machado Mota de Castro Trincão
Frenkie de Jong
Gerard Deulofeu Lázaro
Gerard Piqué Bernabéu
Héctor Junior Firpo Adames
Ivan Rakitić
Jasper Cillessen
Javier Alejandro Mascherano
Jean-Clair Todibo
Jordi Alba Ramos
José Manuel Arnáiz Díaz
José Paulo Bezzera Maciel Júnior
Jérémy Mathieu
Kevin-Prince Boateng
Lionel Andrés Messi Cuccittini
Lucas Digne
Luis Alberto Suárez Díaz
Malcom Filipe Silva de Oliveira
Marc-André ter Stegen
Marlon Santos da Silva Barbosa
Martin Braithwaite Christensen
Miralem Pjanić
Moriba Kourouma Kourouma
Moussa Wagué
Munir El Haddadi Mohamed
Neymar da Silva Santos Junior
Norberto Murara Neto
Nélson Cabral Semedo
Ousmane Dembélé
Pedro González López
Philippe Coutinho Correia
Rafael Alcântara do Nascimento
Ricard Puig Martí
Ronald Federico Araújo da Silva
Samuel Yves Umtiti
Sergi Roberto Carnicer
Sergino Dest
Sergio Busquets i Burgos
Thomas Vermaelen
Yerry Fernando Mina González
Álex Collado Gutiérrez
Óscar Mingueza García
"""
|
|
|
|
|
# Groq API key; os.getenv yields None when GROQ_DEV_API_KEY is not set.
api_key = os.getenv("GROQ_DEV_API_KEY")

# Chat model driving the agent: Kimi K2 Instruct (0905 build) served by Groq.
groq_model = GroqModel(
    "moonshotai/kimi-k2-instruct-0905",
    provider=GroqProvider(api_key=api_key),
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass()
class SearchAgentDeps:
    """Dependencies handed to the agent's tools through ``RunContext.deps``."""

    # Helper exposing ``search_web``; the web_search tool calls it.
    tools: SearchingTools
|
|
|
|
|
# Agent wiring: the Groq-hosted model, the analyst system prompt, and the typed
# dependency object that tools read through RunContext.deps.
agent = Agent(
    model=groq_model,
    deps_type=SearchAgentDeps,
    system_prompt=SYSTEM_PROMPT,
)

# A single SearchingTools instance, wrapped in the deps object passed to each run.
tools_instance = SearchingTools()
deps = SearchAgentDeps(tools=tools_instance)
|
|
|
|
|
class Neo4jConnection:
    """Thin wrapper around a neo4j driver built from the module's NEO4J_* settings."""

    def __init__(self):
        credentials = (NEO4J_USER, NEO4J_PASSWORD)
        self.driver = GraphDatabase.driver(NEO4J_URI, auth=credentials)
        # NOTE(review): building the driver may not open a connection yet, so this
        # message does not prove the server is reachable — confirm; the driver's
        # verify_connectivity() would check explicitly.
        print(f"✓ Conectado ao Neo4j: {NEO4J_URI}")

    def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None):
        """Run *query* in a fresh session and return every record as a dict.

        Args:
            query: Cypher text to run.
            parameters: Query parameters; ``None`` or an empty dict means none.

        Returns:
            A list with ``record.data()`` for each record. The list is built
            while the session is still open.
        """
        params = parameters if parameters else {}
        with self.driver.session() as session:
            return [row.data() for row in session.run(query, params)]

    def close(self):
        """Close the underlying driver."""
        self.driver.close()


# Module-wide connection used by the cypher_query_execute tool.
neo4j = Neo4jConnection()
|
|
|
|
|
def _ensure_read_only(query: str) -> Optional[str]: |
|
|
q = query.upper().strip() |
|
|
dangerous = ["DELETE", "DETACH", "REMOVE", "SET", "CREATE", "MERGE", "DROP", "CALL"] |
|
|
if any(k in q for k in dangerous): |
|
|
return "❌ ERRO: Apenas queries de leitura (MATCH/RETURN) são permitidas." |
|
|
return None |
|
|
|
|
|
@agent.tool(name="cypher_query_execute", retries=3)
async def execute_cypher_query(
    ctx: RunContext[SearchAgentDeps],
    query: str,
    parameters: Optional[Dict[str, Any]] = None,
    limit: int = 100,
) -> str:
    """Run a READ-ONLY Cypher query (MATCH/RETURN) against Neo4j.

    Args:
        query: Cypher query to run. Queries with write clauses or CALL are rejected.
        parameters: Optional values for ``$name`` placeholders in the query.
        limit: Maximum number of rows to fetch. It is appended as ``LIMIT <limit>``
            when the query has no LIMIT clause of its own.

    Returns:
        The row count plus the first 10 rows as ``key=value`` lines.
        A "no results" notice when the query returns nothing. A message
        starting with ❌ when the query is rejected or fails.
    """
    import re

    err = _ensure_read_only(query)
    if err:
        return err

    try:
        # Strip trailing whitespace/semicolons so an appended LIMIT stays valid Cypher.
        query = query.rstrip().rstrip(";").rstrip()
        # Whole-word check: an identifier such as "SpeedLimit" must not suppress the cap.
        if not re.search(r"\bLIMIT\b", query, flags=re.IGNORECASE):
            query = f"{query}\nLIMIT {limit}"

        # The driver call is synchronous. Run it in a worker thread so the event
        # loop is not blocked while Neo4j answers.
        results = await asyncio.to_thread(neo4j.execute_query, query, parameters)

        if not results:
            return "✓ Query executada, mas nenhum resultado encontrado."

        shown = 10  # rows echoed back to the model, to keep the tool output short
        out = [f"📊 Resultados ({len(results)} encontrados):"]
        for i, record in enumerate(results[:shown], 1):
            items = [f"{k}={v}" for k, v in record.items()]
            out.append(f"{i}. {', '.join(items)}")

        if len(results) > shown:
            out.append(f"... e mais {len(results) - shown} resultados.")

        return "\n".join(out)
    except Exception as e:
        return f"❌ Erro ao executar query: {str(e)}"
|
|
|
|
|
@agent.tool(name="web_search", retries=3)
async def procura_web(ctx: RunContext[SearchAgentDeps], search_query: str) -> str:
    """Search the web through the injected SearchingTools helper (max_results=15).

    Args:
        search_query: Free-text search terms.
    """
    searcher = ctx.deps.tools
    return searcher.search_web(search_query=search_query, max_results=15)
|
|
|
|
|
# Most recent chart produced by the create_chart tool; None until one is rendered.
last_chart_image = None


def create_placeholder_image():
    """Build the 800x600 "Waiting for plot..." image shown before any chart exists.

    DejaVu Sans at size 48 is used when that font file can be loaded; otherwise
    Pillow's default font is used.
    """
    width, height = 800, 600
    canvas = Image.new("RGB", (width, height), color=(248, 250, 252))
    painter = ImageDraw.Draw(canvas)

    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
    try:
        font = ImageFont.truetype(font_path, 48)
    except Exception:
        font = ImageFont.load_default()

    message = "Waiting for plot..."
    left, top, right, bottom = painter.textbbox((0, 0), message, font=font)
    origin = ((width - (right - left)) // 2, (height - (bottom - top)) // 2)
    painter.text(origin, message, fill=(100, 116, 139), font=font)
    return canvas


# Rendered once at import time and returned whenever no chart is available.
PLACEHOLDER_IMAGE = create_placeholder_image()
|
|
|
|
|
@agent.tool(name="create_chart", retries=2, timeout=30.0)
async def create_chart(
    ctx: RunContext[SearchAgentDeps],
    data: List[Dict[str, Any]],
    x_column: str,
    y_column: str,
    chart_type: str = "bar",
    title: Optional[str] = None,
    x_title: Optional[str] = None,
    y_title: Optional[str] = None,
) -> str:
    """Render a chart from tabular data and show it in the Visualization tab.

    Args:
        data: Rows to plot, one dict per row (for example, query results).
        x_column: Key in each row used for the x axis.
        y_column: Key in each row used for the y axis.
        chart_type: One of "bar", "horizontal_bar", "line" or "scatter".
        title: Chart title; defaults to "<y_column> por <x_column>".
        x_title: X-axis label; defaults to x_column.
        y_title: Y-axis label; defaults to y_column.

    Returns:
        A short status message, prefixed with ✅ on success or ❌ on error.
    """
    global last_chart_image

    # Supported chart types. They are looked up before any figure exists, so an
    # unsupported type can no longer leave an unclosed figure behind.
    plotters = {
        "bar": lambda ax, x, y: ax.bar(x, y),
        "horizontal_bar": lambda ax, x, y: ax.barh(x, y),
        "line": lambda ax, x, y: ax.plot(x, y, marker="o"),
        "scatter": lambda ax, x, y: ax.scatter(x, y),
    }

    try:
        df = pd.DataFrame(data)
        if x_column not in df.columns or y_column not in df.columns:
            return f"❌ Erro: Colunas não encontradas. Disponíveis: {list(df.columns)}"

        plot = plotters.get(chart_type)
        if plot is None:
            return f"❌ Tipo '{chart_type}' não suportado"

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            plot(ax, df[x_column], df[y_column])
            ax.set_title(title or f"{y_column} por {x_column}")
            ax.set_xlabel(x_title or x_column)
            ax.set_ylabel(y_title or y_column)
            fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white")
            buf.seek(0)
            # copy() loads the pixels so the stored image no longer reads from buf.
            last_chart_image = Image.open(buf).copy()
        finally:
            # Close this specific figure even when plotting fails; pyplot keeps
            # figures it creates open until they are closed explicitly.
            plt.close(fig)

        return f"✅ Gráfico '{chart_type}' criado ({len(df)} registros)."
    except Exception as e:
        return f"❌ Erro ao criar gráfico: {str(e)}"
|
|
|
|
|
def get_current_chart():
    """Return the latest rendered chart, or PLACEHOLDER_IMAGE if none exists yet."""
    chart = last_chart_image
    if chart is None:
        return PLACEHOLDER_IMAGE
    return chart
|
|
|
|
|
async def stream_agent_response_safe(user_query: str) -> str:
    """Run the agent on *user_query* and return its final output as text.

    The agent's graph nodes are consumed until the terminal ``End`` node. If
    the run has a result at that point, ``str()`` of its output is returned;
    otherwise a fixed error message is returned. Exceptions raised during the
    run are not caught here.
    """
    async with agent.iter(user_query, deps=deps) as agent_run:
        async for node in agent_run:
            if not isinstance(node, End):
                continue
            final = agent_run.result
            if final:
                return str(final.output)
    return "Erro na execução do agente"
|
|
|
|
|
# Public names exported by `from <this module> import *`.
__all__ = [
    "agent", "deps",
    "PLACEHOLDER_IMAGE", "get_current_chart",
    "stream_agent_response_safe",
]
|
|
|