financial-advisor / mcp_server.py
Sajil Awale
added multi-user auth feature in fin adv
7381684
import pandas as pd
import plotly.express as px
from fastmcp import FastMCP
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from dotenv import load_dotenv
import os
from typing import Optional
import shutil
from textwrap import dedent
# Load environment variables (API keys, connection strings, etc.) before anything below reads them.
load_dotenv()
# Directory for files written by this server (the chart tool writes its JSON output here).
# Overridable through the DATA_DIR environment variable; defaults to a "temp_data"
# folder located next to this module.
DATA_DIR = os.getenv("DATA_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp_data"))
# Initialize the MCP server; the @mcp.resource / @mcp.tool decorators below register against it.
mcp = FastMCP("Money RAG Financial Analyst")
import psycopg2
from supabase import create_client, Client
def get_db_connection():
    """Open a new psycopg2 connection to the Supabase Postgres database.

    Raw, LLM-written SQL is executed through psycopg2 rather than the Supabase
    Python client, so the Postgres connection string (shown in the Supabase
    dashboard under Database Settings) must be supplied via the
    ``DATABASE_URL`` environment variable.

    Returns:
        The connection object created by ``psycopg2.connect``.

    Raises:
        ValueError: If ``DATABASE_URL`` is unset or empty.
    """
    connection_string = os.environ.get("DATABASE_URL")
    if connection_string:
        return psycopg2.connect(connection_string)
    raise ValueError("DATABASE_URL must be defined to construct raw SQL connections.")
def get_current_user_id() -> str:
    """Return the current user's id taken from the ``CURRENT_USER_ID`` environment variable.

    Raises:
        ValueError: If ``CURRENT_USER_ID`` is unset or empty.
    """
    current_user = os.environ.get("CURRENT_USER_ID")
    if current_user:
        return current_user
    raise ValueError("CURRENT_USER_ID not injected into MCP environment!")
def get_schema_info() -> str:
    """Return the Postgres schema description, including the mandatory user filter rule.

    The text lists the "Transaction" and "TransactionDetail" tables and tells
    the reader to filter every query by ``user_id``. When the
    ``CURRENT_USER_ID`` environment variable is set, the ``{current_user_id}``
    placeholder in the rule is replaced with that value so the rule shows the
    literal id to use; when it is not set, the placeholder is left untouched
    and the function still returns the schema rather than raising.

    Returns:
        The schema description as a multi-line string.
    """
    schema_text = dedent("""
        Here is the PostgreSQL database schema for the authenticated user's data.
        CRITICAL RULE:
        You MUST add `WHERE user_id = '{current_user_id}'` to EVERY SINGLE query you write.
        Never query data without filtering by user_id!
        TABLE: "Transaction"
        Columns:
        - id (UUID)
        - user_id (UUID)
        - trans_date (DATE)
        - description (TEXT)
        - amount (DECIMAL)
        - category (VARCHAR)
        TABLE: "TransactionDetail"
        Columns:
        - id (UUID)
        - transaction_id (UUID)
        - item_description (TEXT)
        - item_total_price (DECIMAL)
    """)
    # The template token was previously returned verbatim, so the rule never
    # named the real id. Read the environment directly (instead of a helper
    # that raises) so serving the schema cannot fail on a missing id, and use
    # str.replace so braces elsewhere in the text are never interpreted.
    current_user = os.environ.get("CURRENT_USER_ID")
    if current_user:
        schema_text = schema_text.replace("{current_user_id}", current_user)
    return schema_text
@mcp.resource("schema://database/tables")
def get_database_schema() -> str:
    """Complete schema information for the money_rag database."""
    # Thin resource wrapper: the schema text itself is produced by get_schema_info().
    schema_description = get_schema_info()
    return schema_description
@mcp.tool()
def query_database(query: str) -> str:
    """
    Execute a raw SQL query against the Postgres database.
    The main table is named "Transaction" (you MUST INCLUDE QUOTES in your SQL!).

    IMPORTANT STRICT SCHEMA:
    - id (UUID)
    - user_id (UUID text)
    - trans_date (DATE)
    - description (TEXT)
    - amount (NUMERIC)
    - category (TEXT)
    - enriched_info (TEXT)

    Args:
        query: The SQL SELECT query to execute. It MUST contain the current
            user's id (e.g. WHERE user_id = '<user_id>'), otherwise it is rejected.

    Returns:
        Query results or error message

    Important Notes:
    - Only SELECT queries are allowed (read-only); a query may start with WITH
    - Every query must filter on user_id with the current user's id
    - Use 'description' column for text search
    - 'amount' column: positive values = spending, negative values = payments/refunds

    Example queries (replace <user_id> with the current user's id):
    - Find Walmart spending: SELECT SUM(amount) FROM "Transaction" WHERE user_id = '<user_id>' AND description LIKE '%Walmart%' AND amount > 0;
    - List recent transactions: SELECT trans_date, description, amount, category FROM "Transaction" WHERE user_id = '<user_id>' ORDER BY trans_date DESC LIMIT 5;
    - Spending by category: SELECT category, SUM(amount) FROM "Transaction" WHERE user_id = '<user_id>' AND amount > 0 GROUP BY category;
    """
    # Security: only statements that begin as a read are accepted.
    # Collapse every whitespace run to a single space first, so keywords
    # separated by newlines or tabs are still visible to the space-delimited
    # keyword check below.
    normalized = " ".join(query.upper().split())
    if not normalized.startswith(("SELECT", "WITH")):
        return "Error: Only SELECT queries are allowed"

    # Forbidden operations (keyword heuristic; the read-only transaction below
    # is the actual enforcement for anything this check misses).
    forbidden = ("INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "REPLACE", "TRUNCATE")
    padded = f" {normalized} "
    if any(f" {word} " in padded for word in forbidden):
        return "Error: Query contains forbidden operation. Only SELECT queries allowed."

    # NOTE(review): this is a substring check only; it confirms the id appears
    # somewhere in the query text, not that it is used as a filter.
    user_id = get_current_user_id()
    if user_id not in query:
        return f"Error: You forgot to include the security filter (WHERE user_id = '{user_id}') in your query! Try again."

    conn = None
    try:
        conn = get_db_connection()
        # Defense in depth: have psycopg2 open the transaction READ ONLY so the
        # server rejects writes that slip past the keyword heuristics above.
        conn.set_session(readonly=True)
        with conn.cursor() as cursor:
            cursor.execute(query)
            results = cursor.fetchall()
            # Get column names to make result more readable
            column_names = [desc[0] for desc in cursor.description] if cursor.description else []
    except psycopg2.Error as e:
        return f"Database Error: {str(e)}"
    finally:
        # Always release the connection, including when execute/fetch fails.
        if conn is not None:
            conn.close()

    if not results:
        return "No results found"

    # Format results: one header line with the column names, then one line per row.
    header = f"Columns: {', '.join(column_names)}"
    return "\n".join([header, *(str(row) for row in results)])
def get_vector_store():
    """Create a QdrantVectorStore for the "transactions" collection.

    Builds a Google Generative AI embedding model ("gemini-embedding-001") and
    a Qdrant client configured from the ``QDRANT_URL`` and ``QDRANT_API_KEY``
    environment variables, then wraps both in a ``QdrantVectorStore``.

    Returns:
        The configured ``QdrantVectorStore`` instance.
    """
    # Embedding model used for the vector store
    embedding_model = GoogleGenerativeAIEmbeddings(model="gemini-embedding-001")

    # Client for the Qdrant instance; URL and API key come from the environment
    qdrant_url = os.getenv("QDRANT_URL")
    qdrant_api_key = os.getenv("QDRANT_API_KEY")
    qdrant = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

    store_kwargs = {
        "client": qdrant,
        "collection_name": "transactions",
        "embedding": embedding_model,
    }
    return QdrantVectorStore(**store_kwargs)
@mcp.tool()
def semantic_search(query: str, top_k: int = 5) -> str:
    """
    Search for personal financial transactions semantically.
    Use this to find spending when specific merchant names are unknown or ambiguous.
    Examples: "how much did I spend on fast food?", "subscriptions", "travel expenses".
    Args:
        query: The description or category of spending to look for.
        top_k: Number of results to return (default 5).
    """

    def _describe(doc) -> str:
        # One output line per hit, built from the document's metadata and text.
        amount = doc.metadata.get('amount', 'N/A')
        date = doc.metadata.get('transaction_date', 'N/A')
        return f"Date: {date} | Match: {doc.page_content} | Amount: {amount}"

    try:
        current_user = get_current_user_id()
        store = get_vector_store()

        # Multi-tenant isolation: only points whose metadata.user_id equals the
        # current user's id are eligible matches.
        from qdrant_client.http import models
        user_condition = models.FieldCondition(
            key="metadata.user_id",
            match=models.MatchValue(value=current_user),
        )
        tenant_filter = models.Filter(must=[user_condition])

        matches = store.similarity_search(query, k=top_k, filter=tenant_filter)
        if not matches:
            return "No matching transactions found."

        return "\n".join([_describe(doc) for doc in matches])
    except Exception as e:
        # Tool boundary: report any failure as text instead of raising.
        return f"Error performing search: {e!s}"
@mcp.tool()
def generate_interactive_chart(sql_query: str, chart_type: str, x_col: str, y_col: str, title: str, color_col: Optional[str] = None) -> str:
    """
    Generate an interactive Plotly chart using SQL data.
    IMPORTANT: The table name MUST be "Transaction" exactly with quotes.
    Args:
        sql_query: The SQL SELECT query to retrieve the data for the chart from the "Transaction" table.
            - Must use 'user_id' filter.
            - Must be a read-only query (start with SELECT or WITH).
        chart_type: The type of chart: 'bar', 'line', 'pie', 'scatter'
        x_col: The name of the column to use for the X axis (or labels for pie charts)
        y_col: The name of the column to use for the Y axis (or values for pie charts)
        title: The title of the chart
        color_col: (Optional) Column to use for color grouping
    Returns:
        A natural language summary confirming chart generation, or a JSON
        object string with an "error" key when the chart could not be produced.
    """
    # Local import keeps this block self-contained; json is only needed for error payloads.
    import json

    def _error(message: str) -> str:
        # Serialize through json so quotes/backslashes in the message cannot break the payload.
        return json.dumps({"error": message}, ensure_ascii=False)

    try:
        user_id = get_current_user_id()
        # NOTE(review): substring check only; it confirms the id appears in
        # the query text, not that it is used as a filter.
        if user_id not in sql_query:
            return _error(f'You forgot the WHERE user_id = "{user_id}" security clause!')

        # The query is LLM-written raw SQL: accept only statements that begin as a read.
        statement = " ".join(sql_query.upper().split())
        if not statement.startswith(("SELECT", "WITH")):
            return _error("Only SELECT queries are allowed")

        conn = get_db_connection()
        try:
            # Defense in depth: run inside a READ ONLY transaction so the
            # server rejects any write hidden inside the statement.
            conn.set_session(readonly=True)
            df = pd.read_sql_query(sql_query, conn)
        finally:
            # Release the connection even when the query fails.
            conn.close()

        if df.empty:
            return _error("No data found for this query.")

        # Treat an empty string the same as "no color grouping".
        color = color_col or None
        if chart_type == "bar":
            fig = px.bar(df, x=x_col, y=y_col, color=color, title=title)
        elif chart_type == "pie":
            fig = px.pie(df, names=x_col, values=y_col, color=color, title=title)
        elif chart_type == "line":
            fig = px.line(df, x=x_col, y=y_col, color=color, title=title)
        elif chart_type == "scatter":
            fig = px.scatter(df, x=x_col, y=y_col, color=color, title=title)
        else:
            return _error(f"Unsupported chart type: {chart_type}")

        # Write the huge JSON to a temp file instead of returning it directly to LLM context
        os.makedirs(DATA_DIR, exist_ok=True)  # the data directory may not exist yet
        chart_path = os.path.join(DATA_DIR, "latest_chart.json")
        with open(chart_path, "w") as f:
            f.write(fig.to_json())
        return "Chart generated successfully! It has been sent to the user's UI. Continue analyzing without outputting the JSON parameters directly."
    except Exception as e:
        # Tool boundary: report any failure as an error payload instead of raising.
        return _error(f"Failed to generate chart: {str(e)}")
if __name__ == "__main__":
    # Start the MCP server on the stdio transport when this file is executed directly.
    mcp.run(transport="stdio")