Spaces:
Sleeping
Sleeping
File size: 7,134 Bytes
f204be9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 | from fastmcp import FastMCP
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from dotenv import load_dotenv
import os
import shutil
# Load environment variables (API keys, etc.) before any setting below is read.
load_dotenv()

# Data locations.
# An explicit DATA_DIR environment variable (e.g. set by the deployment
# config) takes precedence. Otherwise data lives in a "temp_data" folder next
# to this file, which on ephemeral hosts such as Hugging Face Spaces is wiped
# on restart.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
_DEFAULT_DATA_DIR = os.path.join(_MODULE_DIR, "temp_data")
DATA_DIR = os.getenv("DATA_DIR", _DEFAULT_DATA_DIR)

# Both stores live side by side under DATA_DIR.
QDRANT_PATH = os.path.join(DATA_DIR, "qdrant_db")
DB_PATH = os.path.join(DATA_DIR, "money_rag.db")

# The MCP server instance that the tools and resources below register on.
mcp = FastMCP("Money RAG Financial Analyst")
import sqlite3
def get_schema_info() -> str:
    """Describe every table in the SQLite database at ``DB_PATH``.

    Returns:
        A newline-joined listing: for each table a "Table: <name>" line, a
        "Columns:" line, and one " - <column> (<type>)" line per column.
        If the database file does not exist, or reading it fails, a
        human-readable message is returned instead of raising.
    """
    if not os.path.exists(DB_PATH):
        return "Database file does not exist yet. Please upload data."

    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            cursor = conn.cursor()

            # Get all tables
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            table_names = [name for (name,) in cursor.fetchall()]

            schema_info = []
            for table_name in table_names:
                schema_info.append(f"\nTable: {table_name}")

                # PRAGMA arguments cannot be bound as parameters, so quote the
                # name as an SQL identifier (doubling embedded quotes). This
                # keeps tables with spaces or keyword names from breaking the
                # statement.
                quoted_name = '"' + table_name.replace('"', '""') + '"'
                cursor.execute(f"PRAGMA table_info({quoted_name});")
                columns = cursor.fetchall()

                schema_info.append("Columns:")
                # table_info rows are (cid, name, type, notnull, dflt_value, pk);
                # only name and type are reported.
                for _cid, col_name, col_type, *_unused in columns:
                    schema_info.append(f" - {col_name} ({col_type})")
        finally:
            # Release the connection even when a query above raised.
            conn.close()

        return "\n".join(schema_info)
    except Exception as e:
        return f"Error reading schema: {e}"
@mcp.resource("schema://database/tables")
def get_database_schema() -> str:
    """Complete schema information for the money_rag database."""
    # Registered under the schema:// URI above; the text itself is produced
    # by the shared get_schema_info() helper.
    schema_text = get_schema_info()
    return schema_text
@mcp.tool()
def query_database(query: str) -> str:
    """Execute a SELECT query on the money_rag SQLite database.

    Args:
        query: The SQL SELECT query to execute

    Returns:
        Query results or error message

    Important Notes:
        - Only SELECT queries are allowed (read-only)
        - Use 'description' column for text search
        - 'amount' column: positive values = spending, negative values = payments/refunds

    Example queries:
        - Find Walmart spending: SELECT SUM(amount) FROM transactions WHERE description LIKE '%Walmart%' AND amount > 0;
        - List recent transactions: SELECT transaction_date, description, amount, category FROM transactions ORDER BY transaction_date DESC LIMIT 5;
        - Spending by category: SELECT category, SUM(amount) FROM transactions WHERE amount > 0 GROUP BY category;
    """
    if not os.path.exists(DB_PATH):
        return "Database file does not exist yet. Please upload data."

    # Security: the statement must start with SELECT (PRAGMA is also accepted
    # so callers can inspect the schema).
    query_upper = query.strip().upper()
    if not query_upper.startswith(("SELECT", "PRAGMA")):
        return "Error: Only SELECT and PRAGMA queries are allowed"

    # Forbidden operations
    forbidden = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "REPLACE", "TRUNCATE", "ATTACH", "DETACH"]
    # Check for forbidden words as standalone, space-delimited words to avoid
    # false positives (e.g. an "update_date" column). This textual scan is a
    # first line of defence only; the connection below is also put into
    # query-only mode.
    if any(f" {word} " in f" {query_upper} " for word in forbidden):
        return "Error: Query contains forbidden operation. Only SELECT queries allowed."

    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            # Make SQLite itself reject statements that would modify the
            # database, instead of relying only on the keyword scan above.
            conn.execute("PRAGMA query_only = ON")

            cursor = conn.cursor()
            cursor.execute(query)
            results = cursor.fetchall()
            # Get column names to make result more readable
            column_names = [col[0] for col in cursor.description] if cursor.description else []
        finally:
            # Close the connection even when the query fails.
            conn.close()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is not a subclass of sqlite3.Error; Python versions
        # before 3.12 raise it when the input holds more than one statement.
        return f"Error: {str(e)}"

    if not results:
        return "No results found"

    # Format results nicely: a header line, then one tuple repr per row.
    formatted_results = [f"Columns: {', '.join(column_names)}"]
    formatted_results.extend(str(row) for row in results)
    return "\n".join(formatted_results)
def get_vector_store():
    """Initialize connection to the Qdrant vector store.

    Builds a Google Generative AI embedding model, opens a Qdrant client on
    the on-disk path ``QDRANT_PATH`` (creating the directory if needed), and
    returns a ``QdrantVectorStore`` bound to the "transactions" collection.
    """
    # Embedding model served through Google AI Studio.
    embedder = GoogleGenerativeAIEmbeddings(model="text-embedding-004")

    # Qdrant runs in local on-disk mode; the directory must exist before the
    # client can write to it.
    os.makedirs(QDRANT_PATH, exist_ok=True)
    qdrant = QdrantClient(path=QDRANT_PATH)

    # Look up which collections exist; a fresh ephemeral session may have none.
    existing_names = {collection.name for collection in qdrant.get_collections().collections}
    if "transactions" not in existing_names:
        # NOTE: nothing is done about a missing collection yet. A real app
        # would trigger ingestion here or handle the empty state.
        pass

    store = QdrantVectorStore(
        client=qdrant,
        collection_name="transactions",
        embedding=embedder,
    )
    return store
@mcp.tool()
def semantic_search(query: str, top_k: int = 5) -> str:
    """
    Search for personal financial transactions semantically.
    Use this to find spending when specific merchant names are unknown or ambiguous.
    Examples: "how much did I spend on fast food?", "subscriptions", "travel expenses".

    Args:
        query: The description or category of spending to look for.
        top_k: Number of results to return (default 5).
    """
    try:
        # Safety check: if no data has been ingested yet.
        # This must run BEFORE get_vector_store(): that helper creates
        # QDRANT_PATH itself, so checking afterwards can never see a missing
        # directory (and presumably not an empty one once the client has
        # opened it - TODO confirm).
        if not os.path.exists(QDRANT_PATH) or not os.listdir(QDRANT_PATH):
            return "No matching transactions found (Database is empty. Please upload data first)."

        vector_store = get_vector_store()
        results = vector_store.similarity_search(query, k=top_k)
        if not results:
            return "No matching transactions found."

        # Format the output clearly for the LLM/User: one line per match.
        output = []
        for doc in results:
            amount = doc.metadata.get('amount', 'N/A')
            date = doc.metadata.get('transaction_date', 'N/A')
            output.append(f"Date: {date} | Match: {doc.page_content} | Amount: {amount}")
        return "\n".join(output)
    except Exception as e:
        # Tool boundary: report the failure as text rather than raising.
        return f"Error performing search: {str(e)}"
# A helper to clear data (useful for session reset)
@mcp.tool()
def clear_database() -> str:
    """Clear all stored transaction data to reset the session."""
    try:
        # Remove everything under DATA_DIR (the SQLite file and the Qdrant
        # folder both live there).
        if os.path.exists(DATA_DIR):
            shutil.rmtree(DATA_DIR)
        # Always leave an empty data directory behind, whether or not one
        # existed before the reset.
        os.makedirs(DATA_DIR, exist_ok=True)
        return "Database cleared successfully."
    except Exception as e:
        # Tool boundary: report the failure as text rather than raising.
        return f"Error clearing database: {e}"
def main() -> None:
    """Run the MCP server over stdio."""
    mcp.run(transport="stdio")


if __name__ == "__main__":
    main()
|