File size: 9,107 Bytes
da09833
 
045fba0
 
 
1da2ffa
045fba0
 
7381684
045fba0
 
 
7381684
 
045fba0
 
 
 
 
 
 
 
 
7381684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
045fba0
 
7381684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
045fba0
 
 
 
 
 
 
 
 
7381684
 
 
 
 
 
 
 
 
 
 
045fba0
 
 
 
 
 
 
 
 
 
 
 
 
7381684
 
 
045fba0
 
 
7381684
 
045fba0
 
7381684
045fba0
 
 
7381684
 
 
 
045fba0
7381684
045fba0
 
 
 
 
7381684
045fba0
 
 
 
 
 
 
 
 
 
 
 
 
7381684
 
045fba0
 
 
1da2ffa
7381684
045fba0
7381684
 
 
 
 
045fba0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7381684
045fba0
 
7381684
 
 
 
 
045fba0
7381684
045fba0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da09833
 
7381684
da09833
7381684
 
 
da09833
7381684
 
 
 
 
 
 
 
 
 
da09833
 
7381684
 
 
 
 
da09833
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7381684
045fba0
 
 
1da2ffa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import pandas as pd
import plotly.express as px
from fastmcp import FastMCP
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from dotenv import load_dotenv
import os
from typing import Optional

import shutil

from textwrap import dedent

# Load environment variables (API keys, DATABASE_URL, CURRENT_USER_ID, etc.) from a
# .env file, if present. This has to run before any os.getenv/os.environ lookups below.
load_dotenv()

# Define paths to your data: the DATA_DIR env var wins; otherwise use "temp_data"
# next to this file. generate_interactive_chart writes latest_chart.json here.
DATA_DIR = os.getenv("DATA_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp_data"))

# Initialize the MCP Server; the @mcp.tool()/@mcp.resource() decorators below register on it.
mcp = FastMCP("Money RAG Financial Analyst")

import psycopg2
from supabase import create_client, Client

def get_db_connection():
    """Open a new psycopg2 connection to the Supabase Postgres database.

    The DSN comes from the ``DATABASE_URL`` environment variable (Supabase shows
    a ``postgres://`` connection string under Database Settings). psycopg2 is
    used instead of the Supabase Python client because the tools in this module
    execute raw, LLM-written SQL.

    Returns:
        A fresh psycopg2 connection; the caller is responsible for closing it.

    Raises:
        ValueError: if ``DATABASE_URL`` is unset or empty.
    """
    dsn = os.environ.get("DATABASE_URL")
    if dsn:
        return psycopg2.connect(dsn)
    raise ValueError("DATABASE_URL must be defined to construct raw SQL connections.")

def get_current_user_id() -> str:
    """Return the authenticated user's ID taken from ``CURRENT_USER_ID``.

    The hosting process is expected to inject this variable into the MCP
    server's environment; every tool uses it to scope data to one user.

    Raises:
        ValueError: if ``CURRENT_USER_ID`` is unset or empty.
    """
    uid = os.environ.get("CURRENT_USER_ID", "")
    if uid:
        return uid
    raise ValueError("CURRENT_USER_ID not injected into MCP environment!")

def get_schema_info() -> str:
    """Get database schema information for Postgres tables.

    Returns a prompt-ready description of the user's tables plus the rule that
    every query must filter on ``user_id``. When the ``CURRENT_USER_ID``
    environment variable is set, its value replaces the ``{current_user_id}``
    placeholder so the model sees the exact filter it must write. Otherwise the
    placeholder is left as-is.

    Returns:
        The dedented schema text.
    """
    schema = dedent("""
    Here is the PostgreSQL database schema for the authenticated user's data.
    
    CRITICAL RULE:
    You MUST add `WHERE user_id = '{current_user_id}'` to EVERY SINGLE query you write.
    Never query data without filtering by user_id!
    
    TABLE: "Transaction"
    Columns:
      - id (UUID)
      - user_id (UUID)
      - trans_date (DATE)
      - description (TEXT)
      - amount (DECIMAL)
      - category (VARCHAR)

    TABLE: "TransactionDetail"
    Columns:
      - id (UUID)
      - transaction_id (UUID)
      - item_description (TEXT)
      - item_total_price (DECIMAL)
    """)
    # NOTE(review): query_database's tool docstring lists an extra enriched_info
    # column and NUMERIC/TEXT types for "Transaction"; confirm which is authoritative.

    # The template was previously returned unformatted, so the model only ever saw
    # the literal "{current_user_id}" token. Read the env var directly (rather than
    # get_current_user_id(), which raises) so the schema is still served when no
    # user is bound.
    user_id = os.environ.get("CURRENT_USER_ID")
    if user_id:
        schema = schema.replace("{current_user_id}", user_id)
    return schema


@mcp.resource("schema://database/tables")
def get_database_schema() -> str:
    """Complete schema information for the money_rag database."""
    # Exposed as an MCP resource. FastMCP publishes the docstring above as the
    # resource description, so it is kept verbatim; the text is built by get_schema_info().
    schema_text = get_schema_info()
    return schema_text

@mcp.tool()
def query_database(query: str) -> str:
    """
    Execute a raw SQL query against the Postgres database.
    The main table is named "Transaction" (you MUST INCLUDE QUOTES in your SQL!).
    IMPORTANT STRICT SCHEMA:
    - id (UUID)
    - user_id (UUID text)
    - trans_date (DATE)
    - description (TEXT)
    - amount (NUMERIC)
    - category (TEXT)
    - enriched_info (TEXT)

    Args:
        query: The SQL SELECT query to execute

    Returns:
        Query results or error message

    Important Notes:
    - Only SELECT queries are allowed (read-only)
    - EVERY query must filter by the current user: WHERE user_id = '<user_id>'
    - Use 'description' column for text search (ILIKE is case-insensitive)
    - 'amount' column: positive values = spending, negative values = payments/refunds

    Example queries (replace <user_id> with the current user's ID):
    - Find Walmart spending: SELECT SUM(amount) FROM "Transaction" WHERE user_id = '<user_id>' AND description ILIKE '%Walmart%' AND amount > 0;
    - List recent transactions: SELECT trans_date, description, amount, category FROM "Transaction" WHERE user_id = '<user_id>' ORDER BY trans_date DESC LIMIT 5;
    - Spending by category: SELECT category, SUM(amount) FROM "Transaction" WHERE user_id = '<user_id>' AND amount > 0 GROUP BY category;
    """
    # (The docstring above is the LLM-facing tool description; its examples must
    # carry the user_id filter, or the model copies queries this tool rejects.)
    import re  # function-scope import, matching this module's existing style

    # Cheap keyword screen so the model gets an immediate, readable error. It is
    # not the real guard: the read-only session set on the connection below makes
    # Postgres refuse any write that slips past it.
    query_upper = query.strip().upper()
    if not query_upper.startswith(("SELECT", "WITH")):
        return "Error: Only SELECT queries are allowed"

    # Forbidden operations. Tokenize on any whitespace or ';' so keywords after a
    # newline/tab or in a stacked statement ("SELECT 1;DELETE ...") are caught.
    # The old space-padded substring test missed both.
    forbidden = {"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "REPLACE", "TRUNCATE"}
    if forbidden.intersection(re.split(r"[\s;]+", query_upper)):
        return "Error: Query contains forbidden operation. Only SELECT queries allowed."

    user_id = get_current_user_id()
    # NOTE(review): this only proves the ID appears somewhere in the SQL text (even
    # in a comment); it does not guarantee every table reference is filtered.
    # Postgres row-level security would be the robust tenant boundary.
    if user_id not in query:
        return f"Error: You forgot to include the security filter (WHERE user_id = '{user_id}') in your query! Try again."

    conn = None
    try:
        conn = get_db_connection()
        # Server-side enforcement: in a read-only session, INSERT/UPDATE/DELETE and
        # DDL fail with a psycopg2 error instead of executing.
        conn.set_session(readonly=True)
        with conn.cursor() as cursor:
            cursor.execute(query)
            results = cursor.fetchall()
            # Column names make the result more readable for the model.
            column_names = [desc[0] for desc in cursor.description] if cursor.description else []
    except psycopg2.Error as e:
        return f"Database Error: {str(e)}"
    finally:
        # Always release the connection. Previously it leaked whenever execute/fetch raised.
        if conn is not None:
            conn.close()

    if not results:
        return "No results found"

    # Format results: a header line of column names, then one line per row.
    formatted_results = [f"Columns: {', '.join(column_names)}"]
    formatted_results.extend(str(row) for row in results)
    return "\n".join(formatted_results)

def get_vector_store():
    """Build a LangChain ``QdrantVectorStore`` over the "transactions" collection.

    Embeddings come from Google AI Studio's ``gemini-embedding-001`` model. The
    Qdrant Cloud endpoint and key are read from the ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` environment variables. A new client and embedding model
    are constructed on every call.
    """
    embedder = GoogleGenerativeAIEmbeddings(model="gemini-embedding-001")

    qdrant_url, qdrant_key = os.getenv("QDRANT_URL"), os.getenv("QDRANT_API_KEY")
    qdrant = QdrantClient(url=qdrant_url, api_key=qdrant_key)

    return QdrantVectorStore(
        embedding=embedder,
        collection_name="transactions",
        client=qdrant,
    )

@mcp.tool()
def semantic_search(query: str, top_k: int = 5) -> str:
    """
    Search for personal financial transactions semantically.
    
    Use this to find spending when specific merchant names are unknown or ambiguous.
    Examples: "how much did I spend on fast food?", "subscriptions", "travel expenses".
    
    Args:
        query: The description or category of spending to look for.
        top_k: Number of results to return (default 5).
    """
    # (Docstring above is the LLM-facing tool description and is kept verbatim.)
    # Returns one "Date | Match | Amount" line per hit, a no-match message, or an
    # "Error performing search: ..." string; exceptions are never raised to the caller.
    try:
        user_id = get_current_user_id()
        vector_store = get_vector_store()

        # Apply strict multi-tenant filtering based on the payload we injected in money_rag.py
        from qdrant_client.http import models
        # Named user_filter (not `filter`) to avoid shadowing the builtin.
        user_filter = models.Filter(
            must=[models.FieldCondition(key="metadata.user_id", match=models.MatchValue(value=user_id))]
        )

        results = vector_store.similarity_search(query, k=top_k, filter=user_filter)
        if not results:
            return "No matching transactions found."

        # Missing metadata keys render as 'N/A' rather than failing the whole search.
        return "\n".join(
            f"Date: {doc.metadata.get('transaction_date', 'N/A')} | "
            f"Match: {doc.page_content} | "
            f"Amount: {doc.metadata.get('amount', 'N/A')}"
            for doc in results
        )

    except Exception as e:
        return f"Error performing search: {str(e)}"


@mcp.tool()
def generate_interactive_chart(sql_query: str, chart_type: str, x_col: str, y_col: str, title: str, color_col: Optional[str] = None) -> str:
    """
    Generate an interactive Plotly chart using SQL data.
    IMPORTANT: The table name MUST be "Transaction" exactly with quotes.

    Args:
        sql_query: The SQL SELECT query to retrieve the data for the chart from the "Transaction" table.
            - Must use 'user_id' filter.
        chart_type: The type of chart: 'bar', 'line', 'pie', 'scatter'
        x_col: The name of the column to use for the X axis (or labels for pie charts)
        y_col: The name of the column to use for the Y axis (or values for pie charts)
        title: The title of the chart
        color_col: (Optional) Column to use for color grouping

    Returns:
        A natural language summary confirming chart generation.
    """
    # (Docstring above is the LLM-facing tool description and is kept verbatim.)
    # On success the figure JSON is written to DATA_DIR/latest_chart.json and a
    # short confirmation is returned; every failure returns a JSON {"error": ...} string.
    import json  # function-scope import, matching this module's existing style

    def _error(message: str) -> str:
        # json.dumps escapes quotes/backslashes (e.g. in database error text),
        # which the previous hand-built f-string JSON did not.
        return json.dumps({"error": message})

    # Plotly Express builder per chart type. color_col=None means "no grouping".
    # It was previously accepted but ignored, and 'scatter' was documented but rejected.
    builders = {
        "bar": lambda df: px.bar(df, x=x_col, y=y_col, color=color_col, title=title),
        "line": lambda df: px.line(df, x=x_col, y=y_col, color=color_col, title=title),
        "scatter": lambda df: px.scatter(df, x=x_col, y=y_col, color=color_col, title=title),
        "pie": lambda df: px.pie(df, names=x_col, values=y_col, color=color_col, title=title),
    }

    try:
        # Validate the chart type before touching the database.
        build = builders.get(chart_type.strip().lower())
        if build is None:
            return _error(f"Unsupported chart type: {chart_type}")

        # Same read-only contract as query_database: reject non-SELECT statements up front.
        if not sql_query.strip().upper().startswith(("SELECT", "WITH")):
            return _error("Only SELECT queries are allowed")

        user_id = get_current_user_id()
        if user_id not in sql_query:
            # Single quotes: in Postgres a double-quoted "value" is an identifier, not a string.
            return _error(f"You forgot the WHERE user_id = '{user_id}' security clause!")

        conn = get_db_connection()
        try:
            # A read-only session makes Postgres reject writes (e.g. DELETE ... RETURNING).
            conn.set_session(readonly=True)
            df = pd.read_sql_query(sql_query, conn)
        finally:
            # Close even if the query fails; previously the connection leaked on error.
            conn.close()

        if df.empty:
            return _error("No data found for this query.")

        fig = build(df)

        # Write the huge JSON to a temp file instead of returning it directly to LLM context
        os.makedirs(DATA_DIR, exist_ok=True)
        chart_path = os.path.join(DATA_DIR, "latest_chart.json")
        with open(chart_path, "w", encoding="utf-8") as f:
            f.write(fig.to_json())

        return "Chart generated successfully! It has been sent to the user's UI. Continue analyzing without outputting the JSON parameters directly."

    except Exception as e:
        return _error(f"Failed to generate chart: {e}")




if __name__ == "__main__":
    # Runs the server over stdio: MCP messages are exchanged on this process's
    # stdin/stdout (presumably the MCP client launches this script as a subprocess).
    mcp.run(transport="stdio")