Spaces:
Sleeping
Sleeping
Sajil Awale commited on
Commit ·
da09833
1
Parent(s): 54fb037
added gemini support and fast-mcp integration
Browse files- .DS_Store +0 -0
- .gitignore +1 -0
- __pycache__/money_rag.cpython-312.pyc +0 -0
- app.py +31 -3
- mcp_server.py +46 -0
- money_rag.py +20 -4
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.gitignore
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
demo_data/*
|
| 2 |
.env*.png
|
|
|
|
|
|
| 1 |
demo_data/*
|
| 2 |
.env*.png
|
| 3 |
+
.env
|
__pycache__/money_rag.cpython-312.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
app.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import asyncio
|
| 3 |
import os
|
|
|
|
|
|
|
| 4 |
from money_rag import MoneyRAG
|
| 5 |
|
| 6 |
st.set_page_config(page_title="MoneyRAG", layout="wide")
|
|
@@ -12,7 +14,7 @@ with st.sidebar:
|
|
| 12 |
|
| 13 |
if provider == "Google":
|
| 14 |
models = ["gemini-3-flash-preview", "gemini-3-pro-image-preview", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"]
|
| 15 |
-
embeddings = ["
|
| 16 |
else:
|
| 17 |
models = ["gpt-5-mini", "gpt-5-nano", "gpt-4o-mini", "gpt-4o"]
|
| 18 |
embeddings = ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]
|
|
@@ -85,10 +87,34 @@ if "rag" in st.session_state:
|
|
| 85 |
if "messages" not in st.session_state:
|
| 86 |
st.session_state.messages = []
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
for message in st.session_state.messages:
|
| 89 |
with st.chat_message(message["role"]):
|
| 90 |
-
|
| 91 |
|
|
|
|
| 92 |
if prompt := st.chat_input("Ask about your spending..."):
|
| 93 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 94 |
with st.chat_message("user"):
|
|
@@ -97,7 +123,9 @@ if "rag" in st.session_state:
|
|
| 97 |
with st.chat_message("assistant"):
|
| 98 |
with st.spinner("Thinking..."):
|
| 99 |
response = asyncio.run(st.session_state.rag.chat(prompt))
|
| 100 |
-
|
|
|
|
| 101 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
|
|
|
| 102 |
else:
|
| 103 |
st.info("Please authenticate in the sidebar to start.")
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import asyncio
|
| 3 |
import os
|
| 4 |
+
import json
|
| 5 |
+
import plotly.io as pio
|
| 6 |
from money_rag import MoneyRAG
|
| 7 |
|
| 8 |
st.set_page_config(page_title="MoneyRAG", layout="wide")
|
|
|
|
| 14 |
|
| 15 |
if provider == "Google":
|
| 16 |
models = ["gemini-3-flash-preview", "gemini-3-pro-image-preview", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"]
|
| 17 |
+
embeddings = ["gemini-embedding-001"]
|
| 18 |
else:
|
| 19 |
models = ["gpt-5-mini", "gpt-5-nano", "gpt-4o-mini", "gpt-4o"]
|
| 20 |
embeddings = ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]
|
|
|
|
| 87 |
if "messages" not in st.session_state:
|
| 88 |
st.session_state.messages = []
|
| 89 |
|
| 90 |
+
# Helper function to cleverly render either text or a Plotly chart
|
| 91 |
+
def render_content(content):
|
| 92 |
+
# We might have mixed text and charts delimited by ===CHART=== ... ===ENDCHART===
|
| 93 |
+
if isinstance(content, str) and "===CHART===" in content:
|
| 94 |
+
parts = content.split("===CHART===")
|
| 95 |
+
# Render first text part
|
| 96 |
+
st.markdown(parts[0].strip())
|
| 97 |
+
|
| 98 |
+
for part in parts[1:]:
|
| 99 |
+
if "===ENDCHART===" in part:
|
| 100 |
+
chart_json, remaining_text = part.split("===ENDCHART===")
|
| 101 |
+
try:
|
| 102 |
+
fig = pio.from_json(chart_json.strip())
|
| 103 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 104 |
+
except Exception as e:
|
| 105 |
+
st.error("Failed to render chart.")
|
| 106 |
+
|
| 107 |
+
if remaining_text.strip():
|
| 108 |
+
st.markdown(remaining_text.strip())
|
| 109 |
+
else:
|
| 110 |
+
st.markdown(content)
|
| 111 |
+
|
| 112 |
+
# Render previous messages
|
| 113 |
for message in st.session_state.messages:
|
| 114 |
with st.chat_message(message["role"]):
|
| 115 |
+
render_content(message["content"])
|
| 116 |
|
| 117 |
+
# Handle new user input
|
| 118 |
if prompt := st.chat_input("Ask about your spending..."):
|
| 119 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 120 |
with st.chat_message("user"):
|
|
|
|
| 123 |
with st.chat_message("assistant"):
|
| 124 |
with st.spinner("Thinking..."):
|
| 125 |
response = asyncio.run(st.session_state.rag.chat(prompt))
|
| 126 |
+
render_content(response)
|
| 127 |
+
|
| 128 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
| 129 |
+
|
| 130 |
else:
|
| 131 |
st.info("Please authenticate in the sidebar to start.")
|
mcp_server.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
from fastmcp import FastMCP
|
| 2 |
from langchain_qdrant import QdrantVectorStore
|
| 3 |
from qdrant_client import QdrantClient
|
|
@@ -180,6 +182,50 @@ def semantic_search(query: str, top_k: int = 5) -> str:
|
|
| 180 |
except Exception as e:
|
| 181 |
return f"Error performing search: {str(e)}"
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
# A helper to clear data (useful for session reset)
|
| 184 |
@mcp.tool()
|
| 185 |
def clear_database() -> str:
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import plotly.express as px
|
| 3 |
from fastmcp import FastMCP
|
| 4 |
from langchain_qdrant import QdrantVectorStore
|
| 5 |
from qdrant_client import QdrantClient
|
|
|
|
| 182 |
except Exception as e:
|
| 183 |
return f"Error performing search: {str(e)}"
|
| 184 |
|
| 185 |
+
|
| 186 |
+
@mcp.tool()
|
| 187 |
+
def generate_interactive_chart(sql_query: str, chart_type: str, x_col: str, y_col: str, title: str) -> str:
|
| 188 |
+
"""
|
| 189 |
+
Generate an interactive Plotly chart from the money_rag SQLite database.
|
| 190 |
+
Use this proactively whenever a visual representation of data would be helpful.
|
| 191 |
+
|
| 192 |
+
CRITICAL INSTRUCTIONS:
|
| 193 |
+
1. Write a valid SQLite SELECT query.
|
| 194 |
+
2. Aggregate data appropriately (e.g., use GROUP BY for pie/bar charts).
|
| 195 |
+
3. Pass the exact column names from your query to x_col and y_col.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
sql_query: The SQL SELECT query (e.g. "SELECT category, SUM(amount) as total FROM transactions GROUP BY category")
|
| 199 |
+
chart_type: Must be exactly "bar", "pie", or "line"
|
| 200 |
+
x_col: Column name from query for X-axis (or labels for pie)
|
| 201 |
+
y_col: Column name from query for Y-axis (or values for pie)
|
| 202 |
+
title: Title of the chart
|
| 203 |
+
"""
|
| 204 |
+
try:
|
| 205 |
+
conn = sqlite3.connect(DB_PATH)
|
| 206 |
+
df = pd.read_sql_query(sql_query, conn)
|
| 207 |
+
conn.close()
|
| 208 |
+
if df.empty:
|
| 209 |
+
return '{"error": "No data found for this query."}'
|
| 210 |
+
if chart_type == "bar":
|
| 211 |
+
fig = px.bar(df, x=x_col, y=y_col, title=title)
|
| 212 |
+
elif chart_type == "pie":
|
| 213 |
+
fig = px.pie(df, names=x_col, values=y_col, title=title)
|
| 214 |
+
elif chart_type == "line":
|
| 215 |
+
fig = px.line(df, x=x_col, y=y_col, title=title)
|
| 216 |
+
else:
|
| 217 |
+
return f'{{"error": "Unsupported chart type: {chart_type}"}}'
|
| 218 |
+
# Write the huge JSON to a temp file instead of returning it directly to LLM context
|
| 219 |
+
chart_path = os.path.join(DATA_DIR, "latest_chart.json")
|
| 220 |
+
with open(chart_path, "w") as f:
|
| 221 |
+
f.write(fig.to_json())
|
| 222 |
+
|
| 223 |
+
return "Chart generated successfully! It has been sent to the user's UI. Continue analyzing without outputting the JSON parameters directly."
|
| 224 |
+
|
| 225 |
+
except Exception as e:
|
| 226 |
+
return f'{{"error": "Failed to generate chart: {str(e)}"}}'
|
| 227 |
+
|
| 228 |
+
|
| 229 |
# A helper to clear data (useful for session reset)
|
| 230 |
@mcp.tool()
|
| 231 |
def clear_database() -> str:
|
money_rag.py
CHANGED
|
@@ -209,6 +209,9 @@ class MoneyRAG:
|
|
| 209 |
"You are a financial analyst. Use the provided tools to query the database "
|
| 210 |
"and perform semantic searches. Spending is POSITIVE (>0). "
|
| 211 |
"Always explain your findings clearly."
|
|
|
|
|
|
|
|
|
|
| 212 |
)
|
| 213 |
|
| 214 |
self.agent = create_agent(
|
|
@@ -221,6 +224,11 @@ class MoneyRAG:
|
|
| 221 |
async def chat(self, query: str):
|
| 222 |
config = {"configurable": {"thread_id": "session_1"}}
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
result = await self.agent.ainvoke(
|
| 225 |
{"messages": [{"role": "user", "content": query}]},
|
| 226 |
config,
|
|
@@ -235,10 +243,18 @@ class MoneyRAG:
|
|
| 235 |
for block in content:
|
| 236 |
if isinstance(block, dict) and block.get("type") == "text":
|
| 237 |
text_parts.append(block.get("text", ""))
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
async def cleanup(self):
|
| 244 |
"""Delete temporary session files and close MCP client."""
|
|
|
|
| 209 |
"You are a financial analyst. Use the provided tools to query the database "
|
| 210 |
"and perform semantic searches. Spending is POSITIVE (>0). "
|
| 211 |
"Always explain your findings clearly."
|
| 212 |
+
"IMPORTANT: Whenever possible and relevant (e.g. when discussing trends, comparing categories, or showing breakdowns), "
|
| 213 |
+
"you MUST proactively use the 'generate_interactive_chart' tool to generate visual plots (bar, pie, or line charts) to accompany your analysis. "
|
| 214 |
+
"WARNING: You MUST use the actual tool call to generate the chart. DO NOT simply output a json block with chart parameters as your final text answer."
|
| 215 |
)
|
| 216 |
|
| 217 |
self.agent = create_agent(
|
|
|
|
| 224 |
async def chat(self, query: str):
|
| 225 |
config = {"configurable": {"thread_id": "session_1"}}
|
| 226 |
|
| 227 |
+
# Clear out any previous chart so we don't carry over stale plots
|
| 228 |
+
chart_path = os.path.join(self.temp_dir, "latest_chart.json")
|
| 229 |
+
if os.path.exists(chart_path):
|
| 230 |
+
os.remove(chart_path)
|
| 231 |
+
|
| 232 |
result = await self.agent.ainvoke(
|
| 233 |
{"messages": [{"role": "user", "content": query}]},
|
| 234 |
config,
|
|
|
|
| 243 |
for block in content:
|
| 244 |
if isinstance(block, dict) and block.get("type") == "text":
|
| 245 |
text_parts.append(block.get("text", ""))
|
| 246 |
+
final_text = "\n".join(text_parts)
|
| 247 |
+
else:
|
| 248 |
+
final_text = content
|
| 249 |
+
|
| 250 |
+
# Check if the tool generated a chart file on disk during this turn
|
| 251 |
+
chart_path = os.path.join(self.temp_dir, "latest_chart.json")
|
| 252 |
+
if os.path.exists(chart_path):
|
| 253 |
+
with open(chart_path, "r") as f:
|
| 254 |
+
chart_json = f.read()
|
| 255 |
+
final_text += f"\n\n===CHART===\n{chart_json}\n===ENDCHART==="
|
| 256 |
+
|
| 257 |
+
return final_text
|
| 258 |
|
| 259 |
async def cleanup(self):
|
| 260 |
"""Delete temporary session files and close MCP client."""
|