Sajil Awale committed on
Commit
da09833
·
1 Parent(s): 54fb037

added Gemini support and FastMCP integration

Browse files
Files changed (6) hide show
  1. .DS_Store +0 -0
  2. .gitignore +1 -0
  3. __pycache__/money_rag.cpython-312.pyc +0 -0
  4. app.py +31 -3
  5. mcp_server.py +46 -0
  6. money_rag.py +20 -4
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  demo_data/*
2
  .env*.png
 
 
1
  demo_data/*
2
  .env*.png
3
+ .env
__pycache__/money_rag.cpython-312.pyc ADDED
Binary file (14.2 kB). View file
 
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import streamlit as st
2
  import asyncio
3
  import os
 
 
4
  from money_rag import MoneyRAG
5
 
6
  st.set_page_config(page_title="MoneyRAG", layout="wide")
@@ -12,7 +14,7 @@ with st.sidebar:
12
 
13
  if provider == "Google":
14
  models = ["gemini-3-flash-preview", "gemini-3-pro-image-preview", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"]
15
- embeddings = ["text-embedding-004"]
16
  else:
17
  models = ["gpt-5-mini", "gpt-5-nano", "gpt-4o-mini", "gpt-4o"]
18
  embeddings = ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]
@@ -85,10 +87,34 @@ if "rag" in st.session_state:
85
  if "messages" not in st.session_state:
86
  st.session_state.messages = []
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  for message in st.session_state.messages:
89
  with st.chat_message(message["role"]):
90
- st.markdown(message["content"])
91
 
 
92
  if prompt := st.chat_input("Ask about your spending..."):
93
  st.session_state.messages.append({"role": "user", "content": prompt})
94
  with st.chat_message("user"):
@@ -97,7 +123,9 @@ if "rag" in st.session_state:
97
  with st.chat_message("assistant"):
98
  with st.spinner("Thinking..."):
99
  response = asyncio.run(st.session_state.rag.chat(prompt))
100
- st.markdown(response)
 
101
  st.session_state.messages.append({"role": "assistant", "content": response})
 
102
  else:
103
  st.info("Please authenticate in the sidebar to start.")
 
1
  import streamlit as st
2
  import asyncio
3
  import os
4
+ import json
5
+ import plotly.io as pio
6
  from money_rag import MoneyRAG
7
 
8
  st.set_page_config(page_title="MoneyRAG", layout="wide")
 
14
 
15
  if provider == "Google":
16
  models = ["gemini-3-flash-preview", "gemini-3-pro-image-preview", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"]
17
+ embeddings = ["gemini-embedding-001"]
18
  else:
19
  models = ["gpt-5-mini", "gpt-5-nano", "gpt-4o-mini", "gpt-4o"]
20
  embeddings = ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]
 
87
  if "messages" not in st.session_state:
88
  st.session_state.messages = []
89
 
90
+ # Helper function to cleverly render either text or a Plotly chart
91
+ def render_content(content):
92
+ # We might have mixed text and charts delimited by ===CHART=== ... ===ENDCHART===
93
+ if isinstance(content, str) and "===CHART===" in content:
94
+ parts = content.split("===CHART===")
95
+ # Render first text part
96
+ st.markdown(parts[0].strip())
97
+
98
+ for part in parts[1:]:
99
+ if "===ENDCHART===" in part:
100
+ chart_json, remaining_text = part.split("===ENDCHART===")
101
+ try:
102
+ fig = pio.from_json(chart_json.strip())
103
+ st.plotly_chart(fig, use_container_width=True)
104
+ except Exception as e:
105
+ st.error("Failed to render chart.")
106
+
107
+ if remaining_text.strip():
108
+ st.markdown(remaining_text.strip())
109
+ else:
110
+ st.markdown(content)
111
+
112
+ # Render previous messages
113
  for message in st.session_state.messages:
114
  with st.chat_message(message["role"]):
115
+ render_content(message["content"])
116
 
117
+ # Handle new user input
118
  if prompt := st.chat_input("Ask about your spending..."):
119
  st.session_state.messages.append({"role": "user", "content": prompt})
120
  with st.chat_message("user"):
 
123
  with st.chat_message("assistant"):
124
  with st.spinner("Thinking..."):
125
  response = asyncio.run(st.session_state.rag.chat(prompt))
126
+ render_content(response)
127
+
128
  st.session_state.messages.append({"role": "assistant", "content": response})
129
+
130
  else:
131
  st.info("Please authenticate in the sidebar to start.")
mcp_server.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from fastmcp import FastMCP
2
  from langchain_qdrant import QdrantVectorStore
3
  from qdrant_client import QdrantClient
@@ -180,6 +182,50 @@ def semantic_search(query: str, top_k: int = 5) -> str:
180
  except Exception as e:
181
  return f"Error performing search: {str(e)}"
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  # A helper to clear data (useful for session reset)
184
  @mcp.tool()
185
  def clear_database() -> str:
 
1
+ import pandas as pd
2
+ import plotly.express as px
3
  from fastmcp import FastMCP
4
  from langchain_qdrant import QdrantVectorStore
5
  from qdrant_client import QdrantClient
 
182
  except Exception as e:
183
  return f"Error performing search: {str(e)}"
184
 
185
+
186
@mcp.tool()
def generate_interactive_chart(sql_query: str, chart_type: str, x_col: str, y_col: str, title: str) -> str:
    """
    Generate an interactive Plotly chart from the money_rag SQLite database.
    Use this proactively whenever a visual representation of data would be helpful.

    CRITICAL INSTRUCTIONS:
    1. Write a valid SQLite SELECT query.
    2. Aggregate data appropriately (e.g., use GROUP BY for pie/bar charts).
    3. Pass the exact column names from your query to x_col and y_col.

    Args:
        sql_query: The SQL SELECT query (e.g. "SELECT category, SUM(amount) as total FROM transactions GROUP BY category")
        chart_type: Must be exactly "bar", "pie", or "line"
        x_col: Column name from query for X-axis (or labels for pie)
        y_col: Column name from query for Y-axis (or values for pie)
        title: Title of the chart
    """
    import json  # local stdlib import: build safely-escaped JSON error payloads

    try:
        # NOTE(review): this executes LLM-supplied SQL verbatim against the
        # session database — acceptable only because the DB holds the user's
        # own session data; do not point it at anything shared.
        conn = sqlite3.connect(DB_PATH)
        try:
            df = pd.read_sql_query(sql_query, conn)
        finally:
            # Close even when the query raises, so the handle is never leaked.
            conn.close()
        if df.empty:
            return json.dumps({"error": "No data found for this query."})
        if chart_type == "bar":
            fig = px.bar(df, x=x_col, y=y_col, title=title)
        elif chart_type == "pie":
            fig = px.pie(df, names=x_col, values=y_col, title=title)
        elif chart_type == "line":
            fig = px.line(df, x=x_col, y=y_col, title=title)
        else:
            # json.dumps escapes any quotes/newlines in chart_type, keeping
            # the payload valid JSON (the f-string literal it replaces did not).
            return json.dumps({"error": f"Unsupported chart type: {chart_type}"})
        # Write the huge figure JSON to a well-known file instead of returning
        # it into the LLM context; the UI picks it up from this path.
        chart_path = os.path.join(DATA_DIR, "latest_chart.json")
        with open(chart_path, "w") as f:
            f.write(fig.to_json())

        return "Chart generated successfully! It has been sent to the user's UI. Continue analyzing without outputting the JSON parameters directly."

    except Exception as e:
        # str(e) may contain quotes; json.dumps keeps the payload valid JSON.
        return json.dumps({"error": f"Failed to generate chart: {str(e)}"})
227
+
228
+
229
  # A helper to clear data (useful for session reset)
230
  @mcp.tool()
231
  def clear_database() -> str:
money_rag.py CHANGED
@@ -209,6 +209,9 @@ class MoneyRAG:
209
  "You are a financial analyst. Use the provided tools to query the database "
210
  "and perform semantic searches. Spending is POSITIVE (>0). "
211
  "Always explain your findings clearly."
 
 
 
212
  )
213
 
214
  self.agent = create_agent(
@@ -221,6 +224,11 @@ class MoneyRAG:
221
  async def chat(self, query: str):
222
  config = {"configurable": {"thread_id": "session_1"}}
223
 
 
 
 
 
 
224
  result = await self.agent.ainvoke(
225
  {"messages": [{"role": "user", "content": query}]},
226
  config,
@@ -235,10 +243,18 @@ class MoneyRAG:
235
  for block in content:
236
  if isinstance(block, dict) and block.get("type") == "text":
237
  text_parts.append(block.get("text", ""))
238
- return "\n".join(text_parts)
239
-
240
- # If content is already a string (OpenAI format), return as-is
241
- return content
 
 
 
 
 
 
 
 
242
 
243
  async def cleanup(self):
244
  """Delete temporary session files and close MCP client."""
 
209
  "You are a financial analyst. Use the provided tools to query the database "
210
  "and perform semantic searches. Spending is POSITIVE (>0). "
211
  "Always explain your findings clearly."
212
+ "IMPORTANT: Whenever possible and relevant (e.g. when discussing trends, comparing categories, or showing breakdowns), "
213
+ "you MUST proactively use the 'generate_interactive_chart' tool to generate visual plots (bar, pie, or line charts) to accompany your analysis. "
214
+ "WARNING: You MUST use the actual tool call to generate the chart. DO NOT simply output a json block with chart parameters as your final text answer."
215
  )
216
 
217
  self.agent = create_agent(
 
224
  async def chat(self, query: str):
225
  config = {"configurable": {"thread_id": "session_1"}}
226
 
227
+ # Clear out any previous chart so we don't carry over stale plots
228
+ chart_path = os.path.join(self.temp_dir, "latest_chart.json")
229
+ if os.path.exists(chart_path):
230
+ os.remove(chart_path)
231
+
232
  result = await self.agent.ainvoke(
233
  {"messages": [{"role": "user", "content": query}]},
234
  config,
 
243
  for block in content:
244
  if isinstance(block, dict) and block.get("type") == "text":
245
  text_parts.append(block.get("text", ""))
246
+ final_text = "\n".join(text_parts)
247
+ else:
248
+ final_text = content
249
+
250
+ # Check if the tool generated a chart file on disk during this turn
251
+ chart_path = os.path.join(self.temp_dir, "latest_chart.json")
252
+ if os.path.exists(chart_path):
253
+ with open(chart_path, "r") as f:
254
+ chart_json = f.read()
255
+ final_text += f"\n\n===CHART===\n{chart_json}\n===ENDCHART==="
256
+
257
+ return final_text
258
 
259
  async def cleanup(self):
260
  """Delete temporary session files and close MCP client."""