Spaces:
Sleeping
Sleeping
Sajil Awale committed on
Commit ·
7381684
1
Parent(s): dcb0e39
Added multi-user auth feature in fin adv
Browse files- app.py +550 -114
- mcp_server.py +104 -96
- money_rag.py +159 -83
- requirements.txt +6 -0
app.py
CHANGED
|
@@ -3,129 +3,565 @@ import asyncio
|
|
| 3 |
import os
|
| 4 |
import json
|
| 5 |
import plotly.io as pio
|
|
|
|
|
|
|
|
|
|
| 6 |
from money_rag import MoneyRAG
|
| 7 |
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
# Sidebar for Authentication
|
| 11 |
-
with st.sidebar:
|
| 12 |
-
st.header("Authentication")
|
| 13 |
-
provider = st.selectbox("LLM Provider", ["Google", "OpenAI"])
|
| 14 |
-
|
| 15 |
-
if provider == "Google":
|
| 16 |
-
models = ["gemini-3-flash-preview", "gemini-3-pro-image-preview", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"]
|
| 17 |
-
embeddings = ["gemini-embedding-001"]
|
| 18 |
-
else:
|
| 19 |
-
models = ["gpt-5-mini", "gpt-5-nano", "gpt-4o-mini", "gpt-4o"]
|
| 20 |
-
embeddings = ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]
|
| 21 |
-
|
| 22 |
-
model_name = st.selectbox("Choose Decoder Model", models)
|
| 23 |
-
embed_name = st.selectbox("Choose Embedding Model", embeddings)
|
| 24 |
-
api_key = st.text_input("API Key", type="password")
|
| 25 |
-
|
| 26 |
-
auth_button = st.button("Authenticate")
|
| 27 |
-
if auth_button and api_key:
|
| 28 |
-
st.session_state.rag = MoneyRAG(provider, model_name, embed_name, api_key)
|
| 29 |
-
st.success("Authenticated!")
|
| 30 |
-
|
| 31 |
st.divider()
|
| 32 |
-
st.
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
st.
|
| 39 |
-
st.markdown(""
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
""
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
with
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
st.video("https://www.youtube.com/watch?v=gtAFaP9Lts8")
|
| 59 |
-
st.markdown("")
|
| 60 |
-
st.markdown("**Discover Credit Card:**")
|
| 61 |
-
st.video("https://www.youtube.com/watch?v=cry6-H5b0PQ")
|
| 62 |
-
|
| 63 |
-
# Architecture Diagram
|
| 64 |
-
with st.expander("🏗️ How MoneyRAG Works"):
|
| 65 |
-
st.image("architecture.svg", use_container_width=True)
|
| 66 |
-
|
| 67 |
-
st.divider()
|
| 68 |
-
|
| 69 |
-
if "rag" in st.session_state:
|
| 70 |
-
uploaded_files = st.file_uploader("Upload CSV transactions", accept_multiple_files=True, type=['csv'])
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
st.session_state.messages = []
|
| 89 |
-
|
| 90 |
-
# Helper function to cleverly render either text or a Plotly chart
|
| 91 |
-
def render_content(content):
|
| 92 |
-
# We might have mixed text and charts delimited by ===CHART=== ... ===ENDCHART===
|
| 93 |
-
if isinstance(content, str) and "===CHART===" in content:
|
| 94 |
-
parts = content.split("===CHART===")
|
| 95 |
-
# Render first text part
|
| 96 |
-
st.markdown(parts[0].strip())
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
| 101 |
try:
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
except Exception as e:
|
| 105 |
-
st.error("Failed to
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import os
|
| 4 |
import json
|
| 5 |
import plotly.io as pio
|
| 6 |
+
from supabase import create_client, Client, ClientOptions
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
from money_rag import MoneyRAG
|
| 10 |
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
st.set_page_config(page_title="MoneyRAG", layout="wide", initial_sidebar_state="expanded")
|
| 14 |
+
|
| 15 |
+
# Initialize Supabase Client per request (NO CACHE) to ensure thread-safe auth headers
|
| 16 |
+
def get_supabase() -> Client:
    """Build a Supabase client for the current Streamlit script run.

    The client is created fresh on every call instead of being cached, so
    the Authorization header always reflects the session that is asking.
    When ``st.session_state`` holds an ``access_token`` it is sent as a
    ``Bearer`` Authorization header; otherwise a plain client is returned.

    Returns:
        A Supabase ``Client`` bound to ``SUPABASE_URL`` / ``SUPABASE_KEY``.

    Raises:
        RuntimeError: If ``SUPABASE_URL`` or ``SUPABASE_KEY`` is not set in
            the environment.
    """
    url = os.environ.get("SUPABASE_URL")
    key = os.environ.get("SUPABASE_KEY")

    # Fail fast with an actionable message instead of handing None to
    # create_client and surfacing a library-internal error.
    missing = [
        name
        for name, value in (("SUPABASE_URL", url), ("SUPABASE_KEY", key))
        if not value
    ]
    if missing:
        raise RuntimeError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    if "access_token" in st.session_state:
        opts = ClientOptions(
            headers={"Authorization": f"Bearer {st.session_state.access_token}"}
        )
        return create_client(url, key, options=opts)
    return create_client(url, key)
|
| 23 |
+
|
| 24 |
+
supabase = get_supabase()
|
| 25 |
+
|
| 26 |
+
def inject_css():
|
| 27 |
+
st.html("""
|
| 28 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 29 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap" rel="stylesheet">
|
| 30 |
+
<style>
|
| 31 |
+
/* ── Global Reset & Font ── */
|
| 32 |
+
html, body, [class*="css"] {
|
| 33 |
+
font-family: 'Inter', sans-serif !important;
|
| 34 |
+
}
|
| 35 |
+
#MainMenu, footer, header { visibility: hidden; }
|
| 36 |
+
.block-container { padding-top: 2rem !important; }
|
| 37 |
+
|
| 38 |
+
/* ── Background ── */
|
| 39 |
+
.stApp {
|
| 40 |
+
background: #0a0a0f;
|
| 41 |
+
color: #e2e8f0;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
/* ── Sidebar ── */
|
| 45 |
+
[data-testid="stSidebar"] {
|
| 46 |
+
background: linear-gradient(180deg, #0f0f1a 0%, #0d0d16 100%) !important;
|
| 47 |
+
border-right: 1px solid rgba(99,102,241,0.15) !important;
|
| 48 |
+
}
|
| 49 |
+
[data-testid="stSidebar"] * { color: #cbd5e1 !important; }
|
| 50 |
+
|
| 51 |
+
/* ── Nav buttons ── */
|
| 52 |
+
div[data-testid="stSidebarContent"] .nav-btn > div > button {
|
| 53 |
+
width: 100% !important;
|
| 54 |
+
text-align: left !important;
|
| 55 |
+
border: none !important;
|
| 56 |
+
border-radius: 10px !important;
|
| 57 |
+
background: transparent !important;
|
| 58 |
+
color: #94a3b8 !important;
|
| 59 |
+
padding: 0.65rem 1rem !important;
|
| 60 |
+
font-size: 0.9rem !important;
|
| 61 |
+
font-weight: 500 !important;
|
| 62 |
+
transition: all 0.2s ease !important;
|
| 63 |
+
margin-bottom: 2px !important;
|
| 64 |
+
}
|
| 65 |
+
div[data-testid="stSidebarContent"] .nav-btn > div > button:hover {
|
| 66 |
+
background: rgba(99,102,241,0.1) !important;
|
| 67 |
+
color: #a5b4fc !important;
|
| 68 |
+
}
|
| 69 |
+
div[data-testid="stSidebarContent"] .nav-btn-active > div > button {
|
| 70 |
+
background: linear-gradient(135deg, rgba(99,102,241,0.25), rgba(139,92,246,0.2)) !important;
|
| 71 |
+
color: #a5b4fc !important;
|
| 72 |
+
border: 1px solid rgba(99,102,241,0.3) !important;
|
| 73 |
+
font-weight: 600 !important;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
/* ── Primary Buttons ── */
|
| 77 |
+
.stButton > button[kind="primary"] {
|
| 78 |
+
background: linear-gradient(135deg, #6366f1, #8b5cf6) !important;
|
| 79 |
+
border: none !important;
|
| 80 |
+
border-radius: 10px !important;
|
| 81 |
+
color: white !important;
|
| 82 |
+
font-weight: 600 !important;
|
| 83 |
+
padding: 0.6rem 1.2rem !important;
|
| 84 |
+
transition: all 0.2s ease !important;
|
| 85 |
+
box-shadow: 0 4px 15px rgba(99,102,241,0.3) !important;
|
| 86 |
+
}
|
| 87 |
+
.stButton > button[kind="primary"]:hover {
|
| 88 |
+
transform: translateY(-1px) !important;
|
| 89 |
+
box-shadow: 0 6px 20px rgba(99,102,241,0.45) !important;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
/* ── Secondary Buttons ── */
|
| 93 |
+
.stButton > button[kind="secondary"] {
|
| 94 |
+
background: rgba(255,255,255,0.05) !important;
|
| 95 |
+
border: 1px solid rgba(255,255,255,0.1) !important;
|
| 96 |
+
border-radius: 10px !important;
|
| 97 |
+
color: #cbd5e1 !important;
|
| 98 |
+
font-weight: 500 !important;
|
| 99 |
+
transition: all 0.2s ease !important;
|
| 100 |
+
}
|
| 101 |
+
.stButton > button[kind="secondary"]:hover {
|
| 102 |
+
background: rgba(255,255,255,0.08) !important;
|
| 103 |
+
border-color: rgba(99,102,241,0.35) !important;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
/* ── Inputs ── */
|
| 107 |
+
.stTextInput input, .stSelectbox > div > div {
|
| 108 |
+
background: rgba(255,255,255,0.04) !important;
|
| 109 |
+
border: 1px solid rgba(255,255,255,0.1) !important;
|
| 110 |
+
border-radius: 10px !important;
|
| 111 |
+
color: #e2e8f0 !important;
|
| 112 |
+
transition: border 0.2s ease !important;
|
| 113 |
+
}
|
| 114 |
+
.stTextInput input:focus { border-color: #6366f1 !important; box-shadow: 0 0 0 2px rgba(99,102,241,0.2) !important; }
|
| 115 |
+
|
| 116 |
+
/* ── Glass Cards ── */
|
| 117 |
+
.glass-card {
|
| 118 |
+
background: rgba(255,255,255,0.04);
|
| 119 |
+
border: 1px solid rgba(255,255,255,0.08);
|
| 120 |
+
border-radius: 16px;
|
| 121 |
+
padding: 1.75rem;
|
| 122 |
+
backdrop-filter: blur(12px);
|
| 123 |
+
transition: border 0.2s ease;
|
| 124 |
+
}
|
| 125 |
+
.glass-card:hover { border-color: rgba(99,102,241,0.25); }
|
| 126 |
+
|
| 127 |
+
/* ── Hero ── */
|
| 128 |
+
.hero { text-align: center; padding: 4rem 1rem 2rem; }
|
| 129 |
+
.hero .badge {
|
| 130 |
+
display: inline-block;
|
| 131 |
+
background: linear-gradient(135deg, rgba(99,102,241,0.2), rgba(139,92,246,0.2));
|
| 132 |
+
border: 1px solid rgba(99,102,241,0.35);
|
| 133 |
+
color: #a5b4fc;
|
| 134 |
+
font-size: 0.78rem;
|
| 135 |
+
font-weight: 600;
|
| 136 |
+
letter-spacing: 0.1em;
|
| 137 |
+
text-transform: uppercase;
|
| 138 |
+
padding: 0.3rem 0.9rem;
|
| 139 |
+
border-radius: 99px;
|
| 140 |
+
margin-bottom: 1.25rem;
|
| 141 |
+
}
|
| 142 |
+
.hero h1 {
|
| 143 |
+
font-size: clamp(2.5rem, 6vw, 4rem);
|
| 144 |
+
font-weight: 800;
|
| 145 |
+
letter-spacing: -2px;
|
| 146 |
+
line-height: 1.1;
|
| 147 |
+
background: linear-gradient(135deg, #e2e8f0 30%, #a5b4fc);
|
| 148 |
+
-webkit-background-clip: text;
|
| 149 |
+
-webkit-text-fill-color: transparent;
|
| 150 |
+
margin-bottom: 1rem;
|
| 151 |
+
}
|
| 152 |
+
.hero p { font-size: 1.1rem; color: #64748b; max-width: 440px; margin: 0 auto; line-height: 1.7; }
|
| 153 |
+
|
| 154 |
+
/* ── Divider ── */
|
| 155 |
+
hr { border-color: rgba(255,255,255,0.07) !important; }
|
| 156 |
+
|
| 157 |
+
/* ── Expanders ── */
|
| 158 |
+
[data-testid="stExpander"] {
|
| 159 |
+
background: rgba(255,255,255,0.03) !important;
|
| 160 |
+
border: 1px solid rgba(255,255,255,0.07) !important;
|
| 161 |
+
border-radius: 12px !important;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
/* ── Alerts ── */
|
| 165 |
+
[data-testid="stAlert"] { border-radius: 10px !important; }
|
| 166 |
+
|
| 167 |
+
/* ── Chat bubbles ── */
|
| 168 |
+
[data-testid="stChatMessage"] { border-radius: 12px !important; }
|
| 169 |
+
</style>
|
| 170 |
+
""")
|
| 171 |
+
|
| 172 |
+
def _sync_user_row(user_id, email):
    """Best-effort upsert of the app-level "User" row for an auth user.

    Failures are logged and never raised: authentication itself is handled
    by Supabase Auth, so a failed profile sync must not block the caller.
    """
    try:
        supabase.table("User").upsert({
            "id": user_id,
            "email": email,
            "hashed_password": "managed_by_supabase_auth",
        }).execute()
    except Exception as sync_e:
        print(f"Warning: Could not sync user: {sync_e}")


def login_register_page():
    """Render the unauthenticated landing page.

    Shows the hero banner, a "Sign In" card and a "Create Account" card side
    by side, and three help expanders underneath. A successful sign-in
    stores ``user`` and ``access_token`` in ``st.session_state``, mirrors
    the token into the ``t`` query parameter, and reruns the script.
    """
    inject_css()

    st.html("""
    <div class="hero">
        <div class="badge">✦ AI-Powered Finance</div>
        <h1>MoneyRAG</h1>
        <p>Your personal finance analyst. Upload bank statements, ask questions, get insights — powered by AI.</p>
    </div>
    """)

    # The outer columns are only spacers that centre the two cards.
    col_l, col1, col2, col_r = st.columns([1, 2, 2, 1])

    with col1:
        st.markdown('<div class="glass-card">', unsafe_allow_html=True)
        st.markdown("### Sign In")
        email = st.text_input("Email", key="login_email", placeholder="you@example.com", label_visibility="collapsed")
        password = st.text_input("Password", type="password", key="login_pass", placeholder="Password", label_visibility="collapsed")
        if st.button("Sign In →", use_container_width=True, type="primary"):
            if not (email and password):
                # Previously a click with empty fields did nothing at all.
                st.warning("Please enter both your email and password.")
            else:
                with st.spinner(""):
                    try:
                        res = supabase.auth.sign_in_with_password({"email": email, "password": password})
                        st.session_state.user = res.user
                        st.session_state.access_token = res.session.access_token
                        # NOTE(review): the access token is placed in the URL so a
                        # page refresh can restore the session; URLs end up in
                        # browser history and logs, so consider a cookie-based
                        # mechanism instead.
                        st.query_params["t"] = res.session.access_token
                        _sync_user_row(res.user.id, email)
                        st.rerun()
                    except Exception as e:
                        st.error(f"Login failed: {e}")
        st.markdown('</div>', unsafe_allow_html=True)

    with col2:
        st.markdown('<div class="glass-card">', unsafe_allow_html=True)
        st.markdown("### Create Account")
        reg_email = st.text_input("Email", key="reg_email", placeholder="you@example.com", label_visibility="collapsed")
        reg_password = st.text_input("Password", type="password", key="reg_pass", placeholder="Password", label_visibility="collapsed")
        if st.button("Create Account →", use_container_width=True):
            if not (reg_email and reg_password):
                st.warning("Please enter both an email and a password.")
            else:
                with st.spinner(""):
                    try:
                        res = supabase.auth.sign_up({"email": reg_email, "password": reg_password})
                        if res.user:
                            # Same best-effort sync as sign-in; failures are
                            # logged rather than silently discarded.
                            _sync_user_row(res.user.id, reg_email)
                            st.success("Account created! Sign in on the left.")
                        else:
                            st.error("Signup failed: no user was returned.")
                    except Exception as e:
                        st.error(f"Signup failed: {str(e)}")
        st.markdown('</div>', unsafe_allow_html=True)

    st.divider()
    col3, col4, col5 = st.columns(3)
    with col3:
        with st.expander("📚 API Keys"):
            st.markdown("**Google:** [AI Studio](https://aistudio.google.com/app/apikey)")
            st.markdown("**OpenAI:** [Platform](https://platform.openai.com/api-keys)")
    with col4:
        with st.expander("📥 Export Transactions"):
            st.markdown("**Chase:** [Video guide](https://www.youtube.com/watch?v=gtAFaP9Lts8)")
            st.markdown("**Discover:** [Video guide](https://www.youtube.com/watch?v=cry6-H5b0PQ)")
    with col5:
        with st.expander("🏗️ Architecture"):
            st.image("architecture.svg", use_container_width=True)
|
| 247 |
+
|
| 248 |
+
def load_user_config():
    """Return the signed-in user's ``AccountConfig`` row, or ``None``.

    A new client is obtained through ``get_supabase()`` on every call so the
    query is made with the current session's token. Any exception raised
    while querying is printed and treated as "no configuration".
    """
    try:
        rows = (
            get_supabase()
            .table("AccountConfig")
            .select("*")
            .eq("user_id", st.session_state.user.id)
            .execute()
            .data
        )
        # Only the first matching row is used.
        return rows[0] if rows else None
    except Exception as e:
        print(f"Failed to load config: {e}")
        return None
|
| 258 |
+
|
| 259 |
+
def main_app_view():
|
| 260 |
+
inject_css()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
+
# Use session state for active nav tab
|
| 263 |
+
if "nav" not in st.session_state:
|
| 264 |
+
st.session_state.nav = "Chat"
|
| 265 |
+
|
| 266 |
+
with st.sidebar:
|
| 267 |
+
st.markdown(f"**MoneyRAG** 💰")
|
| 268 |
+
st.caption(st.session_state.user.email)
|
| 269 |
+
st.divider()
|
| 270 |
+
|
| 271 |
+
# Modern nav buttons using st.button styled via CSS
|
| 272 |
+
for label, icon in [("Chat", "💬"), ("Ingest Data", "📥"), ("Account Config", "⚙️")]:
|
| 273 |
+
is_active = st.session_state.nav == label
|
| 274 |
+
css_class = "nav-btn-active" if is_active else "nav-btn"
|
| 275 |
+
st.markdown(f'<div class="{css_class}">', unsafe_allow_html=True)
|
| 276 |
+
if st.button(f"{icon} {label}", key=f"nav_{label}", use_container_width=True):
|
| 277 |
+
st.session_state.nav = label
|
| 278 |
+
st.rerun()
|
| 279 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 280 |
+
|
| 281 |
+
st.divider()
|
| 282 |
+
if st.button("Log Out", use_container_width=True):
|
| 283 |
+
supabase.auth.sign_out()
|
| 284 |
+
if "t" in st.query_params:
|
| 285 |
+
del st.query_params["t"]
|
| 286 |
+
for key in list(st.session_state.keys()):
|
| 287 |
+
del st.session_state[key]
|
| 288 |
+
st.rerun()
|
| 289 |
+
|
| 290 |
+
st.divider()
|
| 291 |
+
st.caption("[Sajil Awale](https://github.com/AwaleSajil) · [Simran KC](https://github.com/iamsims)")
|
| 292 |
+
|
| 293 |
+
nav = st.session_state.nav
|
| 294 |
+
|
| 295 |
+
# Always reload config fresh (cached None from unauthenticated loads will persist otherwise)
|
| 296 |
+
config = load_user_config()
|
| 297 |
+
|
| 298 |
+
if nav == "Account Config":
|
| 299 |
+
st.header("⚙️ Account Configuration")
|
| 300 |
+
st.write("Configure your AI providers and models here.")
|
| 301 |
+
|
| 302 |
+
current_provider = config['llm_provider'] if config else "Google"
|
| 303 |
+
current_key = config['api_key'] if config else ""
|
| 304 |
+
current_decode = config.get('decode_model', "gemini-3-flash-preview") if config else "gemini-3-flash-preview"
|
| 305 |
+
current_embed = config.get('embedding_model', "gemini-embedding-001") if config else "gemini-embedding-001"
|
| 306 |
+
# Provider Selection - Default to Google
|
| 307 |
+
provider = st.selectbox("LLM Provider", ["Google", "OpenAI"], index=0 if (not config or config['llm_provider'] == "Google") else 1)
|
| 308 |
+
|
| 309 |
+
if provider == "Google":
|
| 310 |
+
models = ["gemini-3-flash-preview", "gemini-3-pro-image-preview", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"]
|
| 311 |
+
embeddings = ["gemini-embedding-001"]
|
| 312 |
+
else:
|
| 313 |
+
models = ["gpt-5-mini", "gpt-5-nano", "gpt-4o-mini", "gpt-4o"]
|
| 314 |
+
embeddings = ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]
|
| 315 |
+
|
| 316 |
+
with st.form("config_form"):
|
| 317 |
+
api_key = st.text_input("API Key", type="password", value=current_key)
|
| 318 |
|
| 319 |
+
col1, col2 = st.columns(2)
|
| 320 |
+
with col1:
|
| 321 |
+
# Default to gemini-3 if no config exists
|
| 322 |
+
m_default_val = current_decode if config else "gemini-3-flash-preview"
|
| 323 |
+
m_idx = models.index(m_default_val) if m_default_val in models else 0
|
| 324 |
+
final_decode = st.selectbox("Select Model", models, index=m_idx)
|
| 325 |
|
| 326 |
+
with col2:
|
| 327 |
+
e_idx = embeddings.index(current_embed) if (config and current_embed in embeddings) else 0
|
| 328 |
+
final_embed = st.selectbox("Select Embedding Model", embeddings, index=e_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
+
submitted = st.form_submit_button("Save Configuration", type="primary", use_container_width=True)
|
| 331 |
+
if submitted:
|
| 332 |
+
if not api_key:
|
| 333 |
+
st.error("API Key is required.")
|
| 334 |
+
else:
|
| 335 |
try:
|
| 336 |
+
record = {
|
| 337 |
+
"user_id": st.session_state.user.id,
|
| 338 |
+
"llm_provider": provider,
|
| 339 |
+
"api_key": api_key,
|
| 340 |
+
"decode_model": final_decode,
|
| 341 |
+
"embedding_model": final_embed
|
| 342 |
+
}
|
| 343 |
+
if config:
|
| 344 |
+
supabase.table("AccountConfig").update(record).eq("id", config['id']).execute()
|
| 345 |
+
else:
|
| 346 |
+
supabase.table("AccountConfig").insert(record).execute()
|
| 347 |
+
|
| 348 |
+
st.session_state.user_config = load_user_config()
|
| 349 |
+
# Reinitialize RAG with new config
|
| 350 |
+
if "rag" in st.session_state:
|
| 351 |
+
del st.session_state.rag
|
| 352 |
+
|
| 353 |
+
st.success("Configuration saved successfully!")
|
| 354 |
except Exception as e:
|
| 355 |
+
st.error(f"Failed to save configuration: {e}")
|
| 356 |
+
|
| 357 |
+
elif nav == "Ingest Data":
|
| 358 |
+
st.header("📥 Ingest Data")
|
| 359 |
+
|
| 360 |
+
uploaded_files = st.file_uploader("Upload CSV transactions", accept_multiple_files=True, type=['csv'])
|
| 361 |
+
if uploaded_files:
|
| 362 |
+
if st.button("Ingest Selected Files", type="primary"):
|
| 363 |
+
if not config:
|
| 364 |
+
st.error("Please set up your Account Config first!")
|
| 365 |
+
return
|
| 366 |
+
|
| 367 |
+
# Initialize RAG if needed
|
| 368 |
+
if "rag" not in st.session_state:
|
| 369 |
+
st.session_state.rag = MoneyRAG(
|
| 370 |
+
llm_provider=config["llm_provider"],
|
| 371 |
+
model_name=config.get("decode_model", "gemini-2.5-pro"),
|
| 372 |
+
embedding_model_name=config.get("embedding_model", "gemini-embedding-001"),
|
| 373 |
+
api_key=config["api_key"],
|
| 374 |
+
user_id=st.session_state.user.id,
|
| 375 |
+
access_token=st.session_state.access_token
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
csv_files_info = []
|
| 379 |
+
user_id = st.session_state.user.id
|
| 380 |
+
|
| 381 |
+
with st.spinner("Uploading to Supabase Storage & Processing..."):
|
| 382 |
+
for uploaded_file in uploaded_files:
|
| 383 |
+
# 1. Save temp locally for pandas parsing
|
| 384 |
+
local_path = os.path.join(st.session_state.rag.temp_dir, uploaded_file.name)
|
| 385 |
+
with open(local_path, "wb") as f:
|
| 386 |
+
f.write(uploaded_file.getbuffer())
|
| 387 |
+
|
| 388 |
+
# 2. Upload raw file to Supabase Object Storage
|
| 389 |
+
s3_key = f"{user_id}/csvs/{uploaded_file.name}"
|
| 390 |
+
try:
|
| 391 |
+
supabase.storage.from_("money-rag-files").upload(
|
| 392 |
+
file=local_path,
|
| 393 |
+
path=s3_key,
|
| 394 |
+
file_options={"content-type": "text/csv", "upsert": "true"}
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
# 3. Log the upload in the CSVFile table
|
| 398 |
+
csv_record = supabase.table("CSVFile").insert({
|
| 399 |
+
"user_id": user_id,
|
| 400 |
+
"filename": uploaded_file.name,
|
| 401 |
+
"s3_key": s3_key
|
| 402 |
+
}).execute()
|
| 403 |
+
|
| 404 |
+
csv_id = csv_record.data[0]['id']
|
| 405 |
+
csv_files_info.append({"path": local_path, "csv_id": csv_id})
|
| 406 |
+
|
| 407 |
+
except Exception as e:
|
| 408 |
+
st.error(f"Error uploading {uploaded_file.name}: {e}")
|
| 409 |
+
continue
|
| 410 |
+
|
| 411 |
+
# 4. Trigger the LLM parsing, routing CSV data to Supabase Postgres
|
| 412 |
+
if csv_files_info:
|
| 413 |
+
asyncio.run(st.session_state.rag.setup_session(csv_files_info))
|
| 414 |
+
st.success("Data uploaded, parsed, and vectorized securely!")
|
| 415 |
+
st.rerun()
|
| 416 |
+
|
| 417 |
+
st.divider()
|
| 418 |
+
st.subheader("Your Uploaded Files")
|
| 419 |
+
try:
|
| 420 |
+
res = supabase.table("CSVFile").select("*").eq("user_id", st.session_state.user.id).execute()
|
| 421 |
+
files = res.data
|
| 422 |
+
|
| 423 |
+
if not files:
|
| 424 |
+
st.info("No files uploaded yet.")
|
| 425 |
+
else:
|
| 426 |
+
for f in files:
|
| 427 |
+
col_file, col_del = st.columns([4, 1])
|
| 428 |
+
with col_file:
|
| 429 |
+
st.write(f"📄 **{f['filename']}** (Uploaded: {f['upload_date'][:10]})")
|
| 430 |
+
with col_del:
|
| 431 |
+
if st.button("Delete", key=f"del_{f['id']}"):
|
| 432 |
+
st.session_state[f"confirm_del_{f['id']}"] = True
|
| 433 |
+
|
| 434 |
+
if st.session_state.get(f"confirm_del_{f['id']}", False):
|
| 435 |
+
st.warning("Are you sure? This permanently deletes the file from Cloud Storage, the SQL Database, and the Vector Index.")
|
| 436 |
+
col_y, col_n = st.columns(2)
|
| 437 |
+
with col_y:
|
| 438 |
+
if st.button("Yes, Delete", key=f"yes_{f['id']}", type="primary"):
|
| 439 |
+
with st.spinner("Purging file data..."):
|
| 440 |
+
try:
|
| 441 |
+
# Delete from storage
|
| 442 |
+
supabase.storage.from_("money-rag-files").remove([f['s3_key']])
|
| 443 |
+
except Exception as e:
|
| 444 |
+
print(f"Warning storage delete failed: {e}")
|
| 445 |
+
|
| 446 |
+
# Use initialized RAG to delete from Vectors and Postgres
|
| 447 |
+
if "rag" not in st.session_state and config:
|
| 448 |
+
st.session_state.rag = MoneyRAG(
|
| 449 |
+
llm_provider=config["llm_provider"],
|
| 450 |
+
model_name=config.get("decode_model", "gemini-2.5-pro"),
|
| 451 |
+
embedding_model_name=config.get("embedding_model", "gemini-embedding-001"),
|
| 452 |
+
api_key=config["api_key"],
|
| 453 |
+
user_id=st.session_state.user.id,
|
| 454 |
+
access_token=st.session_state.access_token
|
| 455 |
+
)
|
| 456 |
+
if "rag" in st.session_state:
|
| 457 |
+
asyncio.run(st.session_state.rag.delete_file(f['id']))
|
| 458 |
+
else:
|
| 459 |
+
# Fallback if no RAG config to just delete from Postgres at least
|
| 460 |
+
supabase.table("Transaction").delete().eq("source_csv_id", f['id']).execute()
|
| 461 |
+
supabase.table("CSVFile").delete().eq("id", f['id']).execute()
|
| 462 |
+
|
| 463 |
+
del st.session_state[f"confirm_del_{f['id']}"]
|
| 464 |
+
st.success(f"Deleted {f['filename']}!")
|
| 465 |
+
st.rerun()
|
| 466 |
+
|
| 467 |
+
with col_n:
|
| 468 |
+
if st.button("Cancel", key=f"cancel_{f['id']}"):
|
| 469 |
+
del st.session_state[f"confirm_del_{f['id']}"]
|
| 470 |
+
st.rerun()
|
| 471 |
+
|
| 472 |
+
except Exception as e:
|
| 473 |
+
st.error(f"Failed to load files: {e}")
|
| 474 |
+
|
| 475 |
+
elif nav == "Chat":
|
| 476 |
+
st.header("💬 Financial Assistant")
|
| 477 |
+
if not config:
|
| 478 |
+
st.warning("Please configure your Account Config (API Key) first!")
|
| 479 |
+
return
|
| 480 |
+
|
| 481 |
+
if "rag" not in st.session_state:
|
| 482 |
+
st.session_state.rag = MoneyRAG(
|
| 483 |
+
llm_provider=config["llm_provider"],
|
| 484 |
+
model_name=config.get("decode_model", "gemini-2.5-pro"),
|
| 485 |
+
embedding_model_name=config.get("embedding_model", "gemini-embedding-001"),
|
| 486 |
+
api_key=config["api_key"],
|
| 487 |
+
user_id=st.session_state.user.id,
|
| 488 |
+
access_token=st.session_state.access_token
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
if "messages" not in st.session_state:
|
| 492 |
+
st.session_state.messages = []
|
| 493 |
+
|
| 494 |
+
# Show file ingestion status
|
| 495 |
+
try:
|
| 496 |
+
client = get_supabase()
|
| 497 |
+
files_res = client.table("CSVFile").select("id, filename").eq("user_id", st.session_state.user.id).execute()
|
| 498 |
+
file_count = len(files_res.data) if files_res.data else 0
|
| 499 |
+
if file_count == 0:
|
| 500 |
+
st.warning("⚠️ No data loaded yet. Go to **Ingest Data** to upload a CSV file before chatting.")
|
| 501 |
+
else:
|
| 502 |
+
names = ", ".join(f['filename'] for f in files_res.data[:3])
|
| 503 |
+
suffix = f" + {file_count - 3} more" if file_count > 3 else ""
|
| 504 |
+
st.info(f"📊 **{file_count} file{'s' if file_count > 1 else ''} loaded:** {names}{suffix}")
|
| 505 |
+
except Exception:
|
| 506 |
+
pass # Don't break chat if the status check fails
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
# Helper function to cleverly render either text or a Plotly chart
|
| 510 |
+
def render_content(content):
    """Render a chat message that may embed Plotly charts.

    Plain content is passed straight to ``st.markdown``. When the content is
    a string containing ``===CHART===``, it is treated as text interleaved
    with chart payloads of the form
    ``===CHART=== <plotly figure JSON> ===ENDCHART===``: each text piece is
    rendered as markdown and each payload is parsed with ``pio.from_json``
    and drawn with ``st.plotly_chart``.

    A payload that cannot be parsed, or that has no closing
    ``===ENDCHART===``, is reported with ``st.error`` instead of raising or
    being dropped silently.
    """
    if not (isinstance(content, str) and "===CHART===" in content):
        st.markdown(content)
        return

    intro, *chart_sections = content.split("===CHART===")
    if intro.strip():
        st.markdown(intro.strip())

    for section in chart_sections:
        # partition() never raises, unlike unpacking split() into two names,
        # which fails when a section contains more than one end marker.
        chart_json, end_marker, trailing_text = section.partition("===ENDCHART===")

        if not end_marker:
            # Unterminated chart block: tell the user rather than skipping it.
            st.error("Failed to render chart.")
            continue

        try:
            fig = pio.from_json(chart_json.strip())
            st.plotly_chart(fig, use_container_width=True)
        except Exception:
            st.error("Failed to render chart.")

        # Drop any stray extra end markers so they are not shown verbatim.
        trailing_text = trailing_text.replace("===ENDCHART===", "").strip()
        if trailing_text:
            st.markdown(trailing_text)
|
| 528 |
+
|
| 529 |
+
# Render previous messages
|
| 530 |
+
for message in st.session_state.messages:
|
| 531 |
+
with st.chat_message(message["role"]):
|
| 532 |
+
render_content(message["content"])
|
| 533 |
+
|
| 534 |
+
# Handle new user input
|
| 535 |
+
if prompt := st.chat_input("Ask about your spending..."):
|
| 536 |
+
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 537 |
+
with st.chat_message("user"):
|
| 538 |
+
st.markdown(prompt)
|
| 539 |
+
|
| 540 |
+
with st.chat_message("assistant"):
|
| 541 |
+
with st.spinner("Thinking..."):
|
| 542 |
+
try:
|
| 543 |
+
response = asyncio.run(st.session_state.rag.chat(prompt))
|
| 544 |
+
render_content(response)
|
| 545 |
+
st.session_state.messages.append({"role": "assistant", "content": response})
|
| 546 |
+
except Exception as e:
|
| 547 |
+
st.error(f"Error during chat: {e}")
|
| 548 |
|
| 549 |
+
if __name__ == "__main__":
    # Attempt to restore the session from the "t" query parameter when the
    # page was refreshed (a refresh starts with an empty session state).
    if "user" not in st.session_state:
        token_from_url = st.query_params.get("t")
        if token_from_url:
            try:
                res = supabase.auth.get_user(token_from_url)
            except Exception:
                # Token is invalid or expired; handled as "not restored" below.
                res = None

            if res and res.user:
                st.session_state.user = res.user
                st.session_state.access_token = token_from_url
                # The module-level client was built before the token was in
                # session state, so it carries no user Authorization header.
                # Rebuild it so main_app_view() queries as this user on this
                # very run, not only after the next rerun.
                supabase = get_supabase()
            elif "t" in st.query_params:
                # Clear an unusable token from the URL, whether the lookup
                # raised or simply returned no user.
                del st.query_params["t"]

    if "user" not in st.session_state:
        login_register_page()
    else:
        main_app_view()
|
mcp_server.py
CHANGED
|
@@ -6,55 +6,66 @@ from qdrant_client import QdrantClient
|
|
| 6 |
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
import os
|
|
|
|
| 9 |
|
| 10 |
import shutil
|
| 11 |
|
|
|
|
|
|
|
| 12 |
# Load environment variables (API keys, etc.)
|
| 13 |
load_dotenv()
|
| 14 |
|
| 15 |
# Define paths to your data
|
| 16 |
-
# For Hugging Face Spaces (Ephemeral):
|
| 17 |
-
# We use a temporary directory that gets wiped on restart.
|
| 18 |
-
# If DATA_DIR is set (e.g., by your deployment config), use it.
|
| 19 |
DATA_DIR = os.getenv("DATA_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp_data"))
|
| 20 |
-
QDRANT_PATH = os.path.join(DATA_DIR, "qdrant_db")
|
| 21 |
-
DB_PATH = os.path.join(DATA_DIR, "money_rag.db")
|
| 22 |
|
| 23 |
# Initialize the MCP Server
|
| 24 |
mcp = FastMCP("Money RAG Financial Analyst")
|
| 25 |
|
| 26 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def get_schema_info() -> str:
|
| 29 |
-
"""Get database schema information."""
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
conn.close()
|
| 55 |
-
return "\n".join(schema_info)
|
| 56 |
-
except Exception as e:
|
| 57 |
-
return f"Error reading schema: {e}"
|
| 58 |
|
| 59 |
|
| 60 |
@mcp.resource("schema://database/tables")
|
|
@@ -64,7 +75,17 @@ def get_database_schema() -> str:
|
|
| 64 |
|
| 65 |
@mcp.tool()
|
| 66 |
def query_database(query: str) -> str:
|
| 67 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
Args:
|
| 70 |
query: The SQL SELECT query to execute
|
|
@@ -78,33 +99,32 @@ def query_database(query: str) -> str:
|
|
| 78 |
- 'amount' column: positive values = spending, negative values = payments/refunds
|
| 79 |
|
| 80 |
Example queries:
|
| 81 |
-
- Find Walmart spending: SELECT SUM(amount) FROM
|
| 82 |
-
- List recent transactions: SELECT
|
| 83 |
-
- Spending by category: SELECT category, SUM(amount) FROM
|
| 84 |
"""
|
| 85 |
-
if not os.path.exists(DB_PATH):
|
| 86 |
-
return "Database file does not exist yet. Please upload data."
|
| 87 |
-
|
| 88 |
# Security: Only allow SELECT queries
|
| 89 |
query_upper = query.strip().upper()
|
| 90 |
-
if not query_upper.startswith("SELECT") and not query_upper.startswith("
|
| 91 |
-
return "Error: Only SELECT
|
| 92 |
|
| 93 |
# Forbidden operations
|
| 94 |
-
forbidden = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "REPLACE", "TRUNCATE"
|
| 95 |
-
# Check for forbidden words as standalone words to avoid false positives (e.g. "update_date" column)
|
| 96 |
-
# Simple check: space-surrounded or end-of-string
|
| 97 |
if any(f" {word} " in f" {query_upper} " for word in forbidden):
|
| 98 |
return f"Error: Query contains forbidden operation. Only SELECT queries allowed."
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
try:
|
| 101 |
-
conn =
|
| 102 |
cursor = conn.cursor()
|
| 103 |
cursor.execute(query)
|
| 104 |
results = cursor.fetchall()
|
| 105 |
|
| 106 |
# Get column names to make result more readable
|
| 107 |
-
column_names = [
|
| 108 |
|
| 109 |
conn.close()
|
| 110 |
|
|
@@ -118,28 +138,20 @@ def query_database(query: str) -> str:
|
|
| 118 |
formatted_results.append(str(row))
|
| 119 |
|
| 120 |
return "\n".join(formatted_results)
|
| 121 |
-
except
|
| 122 |
-
return f"Error: {str(e)}"
|
| 123 |
|
| 124 |
def get_vector_store():
|
| 125 |
"""Initialize connection to the Qdrant vector store"""
|
| 126 |
# Initialize Embedding Model using Google AI Studio
|
| 127 |
-
embeddings = GoogleGenerativeAIEmbeddings(model="
|
| 128 |
|
| 129 |
-
# Connect to Qdrant
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
# Check if collection exists (it might be empty in a new ephemeral session)
|
| 136 |
-
collections = client.get_collections().collections
|
| 137 |
-
collection_names = [c.name for c in collections]
|
| 138 |
|
| 139 |
-
if "transactions" not in collection_names:
|
| 140 |
-
# In a real app, you would probably trigger ingestion here or handle the empty state
|
| 141 |
-
pass
|
| 142 |
-
|
| 143 |
return QdrantVectorStore(
|
| 144 |
client=client,
|
| 145 |
collection_name="transactions",
|
|
@@ -159,20 +171,22 @@ def semantic_search(query: str, top_k: int = 5) -> str:
|
|
| 159 |
top_k: Number of results to return (default 5).
|
| 160 |
"""
|
| 161 |
try:
|
|
|
|
| 162 |
vector_store = get_vector_store()
|
| 163 |
|
| 164 |
-
#
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
| 167 |
|
| 168 |
-
results = vector_store.similarity_search(query, k=top_k)
|
| 169 |
|
| 170 |
if not results:
|
| 171 |
return "No matching transactions found."
|
| 172 |
|
| 173 |
output = []
|
| 174 |
for doc in results:
|
| 175 |
-
# Format the output clearly for the LLM/User
|
| 176 |
amount = doc.metadata.get('amount', 'N/A')
|
| 177 |
date = doc.metadata.get('transaction_date', 'N/A')
|
| 178 |
output.append(f"Date: {date} | Match: {doc.page_content} | Amount: {amount}")
|
|
@@ -184,25 +198,29 @@ def semantic_search(query: str, top_k: int = 5) -> str:
|
|
| 184 |
|
| 185 |
|
| 186 |
@mcp.tool()
|
| 187 |
-
def generate_interactive_chart(sql_query: str, chart_type: str, x_col: str, y_col: str, title: str) -> str:
|
| 188 |
"""
|
| 189 |
-
Generate an interactive Plotly chart
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
CRITICAL INSTRUCTIONS:
|
| 193 |
-
1. Write a valid SQLite SELECT query.
|
| 194 |
-
2. Aggregate data appropriately (e.g., use GROUP BY for pie/bar charts).
|
| 195 |
-
3. Pass the exact column names from your query to x_col and y_col.
|
| 196 |
-
|
| 197 |
Args:
|
| 198 |
-
sql_query: The SQL SELECT query
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
"""
|
| 204 |
try:
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
df = pd.read_sql_query(sql_query, conn)
|
| 207 |
conn.close()
|
| 208 |
if df.empty:
|
|
@@ -226,17 +244,7 @@ def generate_interactive_chart(sql_query: str, chart_type: str, x_col: str, y_co
|
|
| 226 |
return f'{{"error": "Failed to generate chart: {str(e)}"}}'
|
| 227 |
|
| 228 |
|
| 229 |
-
|
| 230 |
-
@mcp.tool()
|
| 231 |
-
def clear_database() -> str:
|
| 232 |
-
"""Clear all stored transaction data to reset the session."""
|
| 233 |
-
try:
|
| 234 |
-
if os.path.exists(DATA_DIR):
|
| 235 |
-
shutil.rmtree(DATA_DIR)
|
| 236 |
-
os.makedirs(DATA_DIR)
|
| 237 |
-
return "Database cleared successfully."
|
| 238 |
-
except Exception as e:
|
| 239 |
-
return f"Error clearing database: {e}"
|
| 240 |
|
| 241 |
if __name__ == "__main__":
|
| 242 |
# Runs the server over stdio
|
|
|
|
| 6 |
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
import os
|
| 9 |
+
from typing import Optional
|
| 10 |
|
| 11 |
import shutil
|
| 12 |
|
| 13 |
+
from textwrap import dedent
|
| 14 |
+
|
| 15 |
# Load environment variables (API keys, etc.)
|
| 16 |
load_dotenv()
|
| 17 |
|
| 18 |
# Define paths to your data
|
|
|
|
|
|
|
|
|
|
| 19 |
DATA_DIR = os.getenv("DATA_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp_data"))
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# Initialize the MCP Server
|
| 22 |
mcp = FastMCP("Money RAG Financial Analyst")
|
| 23 |
|
| 24 |
+
import psycopg2
|
| 25 |
+
from supabase import create_client, Client
|
| 26 |
+
|
| 27 |
+
def get_db_connection():
    """Open a new psycopg2 connection to the Supabase Postgres database.

    The connection string is read from the ``DATABASE_URL`` environment
    variable (the ``postgres://`` URI that Supabase shows in the dashboard
    under Database Settings). psycopg2 is used instead of the Supabase Python
    client because the tools need to run raw, LLM-written SQL.

    Returns:
        A psycopg2 connection; the caller is responsible for closing it.

    Raises:
        ValueError: If ``DATABASE_URL`` is unset or empty.
    """
    dsn = os.environ.get("DATABASE_URL")
    if dsn:
        return psycopg2.connect(dsn)
    raise ValueError("DATABASE_URL must be defined to construct raw SQL connections.")
|
| 37 |
+
|
| 38 |
+
def get_current_user_id() -> str:
    """Return the id of the user this MCP server process is serving.

    The id is read from the ``CURRENT_USER_ID`` environment variable.

    Raises:
        ValueError: If ``CURRENT_USER_ID`` is unset or empty.
    """
    current = os.environ.get("CURRENT_USER_ID")
    if current:
        return current
    raise ValueError("CURRENT_USER_ID not injected into MCP environment!")
|
| 43 |
|
| 44 |
def get_schema_info() -> str:
    """Describe the Postgres schema and the mandatory per-user filter rule.

    The text is written for the LLM that composes SQL. The rule shows the
    concrete id taken from the ``CURRENT_USER_ID`` environment variable so the
    model can copy the filter verbatim; ``query_database`` rejects any query
    that does not contain that id. When the variable is unset the
    ``{current_user_id}`` placeholder is left untouched instead of raising.

    Returns:
        The schema description as plain text.
    """
    schema = dedent("""
        Here is the PostgreSQL database schema for the authenticated user's data.

        CRITICAL RULE:
        You MUST add `WHERE user_id = '{current_user_id}'` to EVERY SINGLE query you write.
        Never query data without filtering by user_id!

        TABLE: "Transaction"
        Columns:
        - id (UUID)
        - user_id (UUID)
        - trans_date (DATE)
        - description (TEXT)
        - amount (DECIMAL)
        - category (VARCHAR)
        - enriched_info (TEXT)

        TABLE: "TransactionDetail"
        Columns:
        - id (UUID)
        - transaction_id (UUID)
        - item_description (TEXT)
        - item_total_price (DECIMAL)
        """)

    # The template above is a plain string, so the placeholder must be
    # substituted explicitly. str.replace is used (not str.format) so any
    # other brace in the text can never be misread as a format field.
    user_id = os.environ.get("CURRENT_USER_ID")
    if user_id:
        schema = schema.replace("{current_user_id}", user_id)
    return schema
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
@mcp.resource("schema://database/tables")
|
|
|
|
| 75 |
|
| 76 |
@mcp.tool()
|
| 77 |
def query_database(query: str) -> str:
|
| 78 |
+
"""
|
| 79 |
+
Execute a raw SQL query against the Postgres database.
|
| 80 |
+
The main table is named "Transaction" (you MUST INCLUDE QUOTES in your SQL!).
|
| 81 |
+
IMPORTANT STRICT SCHEMA:
|
| 82 |
+
- id (UUID)
|
| 83 |
+
- user_id (UUID text)
|
| 84 |
+
- trans_date (DATE)
|
| 85 |
+
- description (TEXT)
|
| 86 |
+
- amount (NUMERIC)
|
| 87 |
+
- category (TEXT)
|
| 88 |
+
- enriched_info (TEXT)
|
| 89 |
|
| 90 |
Args:
|
| 91 |
query: The SQL SELECT query to execute
|
|
|
|
| 99 |
- 'amount' column: positive values = spending, negative values = payments/refunds
|
| 100 |
|
| 101 |
Example queries:
|
| 102 |
+
- Find Walmart spending: SELECT SUM(amount) FROM "Transaction" WHERE description LIKE '%Walmart%' AND amount > 0;
|
| 103 |
+
- List recent transactions: SELECT trans_date, description, amount, category FROM "Transaction" ORDER BY trans_date DESC LIMIT 5;
|
| 104 |
+
- Spending by category: SELECT category, SUM(amount) FROM "Transaction" WHERE amount > 0 GROUP BY category;
|
| 105 |
"""
|
|
|
|
|
|
|
|
|
|
| 106 |
# Security: Only allow SELECT queries
|
| 107 |
query_upper = query.strip().upper()
|
| 108 |
+
if not query_upper.startswith("SELECT") and not query_upper.startswith("WITH"):
|
| 109 |
+
return "Error: Only SELECT queries are allowed"
|
| 110 |
|
| 111 |
# Forbidden operations
|
| 112 |
+
forbidden = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "REPLACE", "TRUNCATE"]
|
|
|
|
|
|
|
| 113 |
if any(f" {word} " in f" {query_upper} " for word in forbidden):
|
| 114 |
return f"Error: Query contains forbidden operation. Only SELECT queries allowed."
|
| 115 |
|
| 116 |
+
user_id = get_current_user_id()
|
| 117 |
+
if user_id not in query:
|
| 118 |
+
return f"Error: You forgot to include the security filter (WHERE user_id = '{user_id}') in your query! Try again."
|
| 119 |
+
|
| 120 |
try:
|
| 121 |
+
conn = get_db_connection()
|
| 122 |
cursor = conn.cursor()
|
| 123 |
cursor.execute(query)
|
| 124 |
results = cursor.fetchall()
|
| 125 |
|
| 126 |
# Get column names to make result more readable
|
| 127 |
+
column_names = [desc[0] for desc in cursor.description] if cursor.description else []
|
| 128 |
|
| 129 |
conn.close()
|
| 130 |
|
|
|
|
| 138 |
formatted_results.append(str(row))
|
| 139 |
|
| 140 |
return "\n".join(formatted_results)
|
| 141 |
+
except psycopg2.Error as e:
|
| 142 |
+
return f"Database Error: {str(e)}"
|
| 143 |
|
| 144 |
def get_vector_store():
|
| 145 |
"""Initialize connection to the Qdrant vector store"""
|
| 146 |
# Initialize Embedding Model using Google AI Studio
|
| 147 |
+
embeddings = GoogleGenerativeAIEmbeddings(model="gemini-embedding-001")
|
| 148 |
|
| 149 |
+
# Connect to Qdrant Cloud
|
| 150 |
+
client = QdrantClient(
|
| 151 |
+
url=os.getenv("QDRANT_URL"),
|
| 152 |
+
api_key=os.getenv("QDRANT_API_KEY"),
|
| 153 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
return QdrantVectorStore(
|
| 156 |
client=client,
|
| 157 |
collection_name="transactions",
|
|
|
|
| 171 |
top_k: Number of results to return (default 5).
|
| 172 |
"""
|
| 173 |
try:
|
| 174 |
+
user_id = get_current_user_id()
|
| 175 |
vector_store = get_vector_store()
|
| 176 |
|
| 177 |
+
# Apply strict multi-tenant filtering based on the payload we injected in money_rag.py
|
| 178 |
+
from qdrant_client.http import models
|
| 179 |
+
filter = models.Filter(
|
| 180 |
+
must=[models.FieldCondition(key="metadata.user_id", match=models.MatchValue(value=user_id))]
|
| 181 |
+
)
|
| 182 |
|
| 183 |
+
results = vector_store.similarity_search(query, k=top_k, filter=filter)
|
| 184 |
|
| 185 |
if not results:
|
| 186 |
return "No matching transactions found."
|
| 187 |
|
| 188 |
output = []
|
| 189 |
for doc in results:
|
|
|
|
| 190 |
amount = doc.metadata.get('amount', 'N/A')
|
| 191 |
date = doc.metadata.get('transaction_date', 'N/A')
|
| 192 |
output.append(f"Date: {date} | Match: {doc.page_content} | Amount: {amount}")
|
|
|
|
| 198 |
|
| 199 |
|
| 200 |
@mcp.tool()
|
| 201 |
+
def generate_interactive_chart(sql_query: str, chart_type: str, x_col: str, y_col: str, title: str, color_col: Optional[str] = None) -> str:
|
| 202 |
"""
|
| 203 |
+
Generate an interactive Plotly chart using SQL data.
|
| 204 |
+
IMPORTANT: The table name MUST be "Transaction" exactly with quotes.
|
| 205 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
Args:
|
| 207 |
+
sql_query: The SQL SELECT query to retrieve the data for the chart from the "Transaction" table.
|
| 208 |
+
- Must use 'user_id' filter.
|
| 209 |
+
chart_type: The type of chart: 'bar', 'line', 'pie', 'scatter'
|
| 210 |
+
x_col: The name of the column to use for the X axis (or labels for pie charts)
|
| 211 |
+
y_col: The name of the column to use for the Y axis (or values for pie charts)
|
| 212 |
+
title: The title of the chart
|
| 213 |
+
color_col: (Optional) Column to use for color grouping
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
A natural language summary confirming chart generation.
|
| 217 |
"""
|
| 218 |
try:
|
| 219 |
+
user_id = get_current_user_id()
|
| 220 |
+
if user_id not in sql_query:
|
| 221 |
+
return f'{{"error": "You forgot the WHERE user_id = \\"{user_id}\\" security clause!"}}'
|
| 222 |
+
|
| 223 |
+
conn = get_db_connection()
|
| 224 |
df = pd.read_sql_query(sql_query, conn)
|
| 225 |
conn.close()
|
| 226 |
if df.empty:
|
|
|
|
| 244 |
return f'{{"error": "Failed to generate chart: {str(e)}"}}'
|
| 245 |
|
| 246 |
|
| 247 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
if __name__ == "__main__":
|
| 250 |
# Runs the server over stdio
|
money_rag.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
import uuid
|
| 3 |
import asyncio
|
| 4 |
import pandas as pd
|
|
@@ -21,16 +22,34 @@ from langgraph.checkpoint.memory import InMemorySaver
|
|
| 21 |
from langchain.agents import create_agent
|
| 22 |
from langchain_community.tools import DuckDuckGoSearchRun
|
| 23 |
from langchain_mcp_adapters.client import MultiServerMCPClient
|
|
|
|
| 24 |
|
| 25 |
# Import specific embeddings
|
| 26 |
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 27 |
from langchain_openai import OpenAIEmbeddings
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
class MoneyRAG:
|
| 30 |
-
def __init__(self, llm_provider: str, model_name: str, embedding_model_name: str, api_key: str):
|
| 31 |
self.llm_provider = llm_provider.lower()
|
| 32 |
self.model_name = model_name
|
| 33 |
self.embedding_model_name = embedding_model_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# Set API Keys
|
| 36 |
if self.llm_provider == "google":
|
|
@@ -60,17 +79,18 @@ class MoneyRAG:
|
|
| 60 |
self.mcp_client: Optional[MultiServerMCPClient] = None
|
| 61 |
self.search_tool = DuckDuckGoSearchRun()
|
| 62 |
self.merchant_cache = {} # Session-based cache for merchant enrichment
|
|
|
|
| 63 |
|
| 64 |
-
async def setup_session(self,
|
| 65 |
"""Ingests CSVs and sets up DBs."""
|
| 66 |
-
|
| 67 |
-
|
|
|
|
| 68 |
|
| 69 |
self.db = SQLDatabase.from_uri(f"sqlite:///{self.db_path}")
|
| 70 |
self.vector_store = self._sync_to_qdrant()
|
| 71 |
-
await self._init_agent()
|
| 72 |
|
| 73 |
-
async def _ingest_csv(self, file_path):
|
| 74 |
df = pd.read_csv(file_path)
|
| 75 |
headers = df.columns.tolist()
|
| 76 |
sample_data = df.head(10).to_json()
|
|
@@ -108,14 +128,16 @@ class MoneyRAG:
|
|
| 108 |
mapping = await chain.ainvoke({"headers": headers, "sample": sample_data, "filename": os.path.basename(file_path)})
|
| 109 |
|
| 110 |
standard_df = pd.DataFrame()
|
| 111 |
-
standard_df['
|
| 112 |
-
|
|
|
|
| 113 |
standard_df['description'] = df[mapping['desc_col']]
|
|
|
|
|
|
|
| 114 |
|
| 115 |
raw_amounts = pd.to_numeric(df[mapping['amount_col']])
|
| 116 |
standard_df['amount'] = raw_amounts * -1 if mapping['sign_convention'] == "spending_is_negative" else raw_amounts
|
| 117 |
standard_df['category'] = df[mapping.get('category_col')] if mapping.get('category_col') else 'Uncategorized'
|
| 118 |
-
standard_df['source_file'] = os.path.basename(file_path)
|
| 119 |
|
| 120 |
# --- Async Enrichment Step ---
|
| 121 |
print(f" ✨ Enriching descriptions for {os.path.basename(file_path)}...")
|
|
@@ -143,29 +165,49 @@ class MoneyRAG:
|
|
| 143 |
desc_map = dict(zip(unique_descriptions, enrichment_results))
|
| 144 |
standard_df['enriched_info'] = standard_df['description'].map(desc_map).fillna("")
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
def _sync_to_qdrant(self):
|
| 151 |
-
client = QdrantClient(path=self.qdrant_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
collection = "transactions"
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
|
| 158 |
# Check for empty dataframe
|
| 159 |
if df.empty:
|
| 160 |
-
raise ValueError("No transactions found in database. Please
|
| 161 |
|
| 162 |
# Dynamically detect embedding dimension
|
| 163 |
sample_embedding = self.embeddings.embed_query("test")
|
| 164 |
embedding_dim = len(sample_embedding)
|
| 165 |
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
collection_name=collection,
|
| 168 |
-
|
|
|
|
| 169 |
)
|
| 170 |
|
| 171 |
vs = QdrantVectorStore(client=client, collection_name=collection, embedding=self.embeddings)
|
|
@@ -180,90 +222,124 @@ class MoneyRAG:
|
|
| 180 |
else:
|
| 181 |
texts.append(base_text)
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
return vs
|
| 188 |
|
| 189 |
-
async def
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
server_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mcp_server.py")
|
| 192 |
|
| 193 |
-
|
| 194 |
{
|
| 195 |
"money_rag": {
|
| 196 |
"transport": "stdio",
|
| 197 |
-
"command":
|
| 198 |
"args": [server_path],
|
| 199 |
-
"env": os.environ.copy(),
|
| 200 |
}
|
| 201 |
}
|
| 202 |
)
|
| 203 |
|
| 204 |
-
|
| 205 |
-
|
|
|
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
os.remove(chart_path)
|
| 231 |
-
|
| 232 |
-
result = await self.agent.ainvoke(
|
| 233 |
-
{"messages": [{"role": "user", "content": query}]},
|
| 234 |
-
config,
|
| 235 |
-
)
|
| 236 |
-
|
| 237 |
-
# Extract content - handle both string and list formats
|
| 238 |
-
content = result["messages"][-1].content
|
| 239 |
-
|
| 240 |
-
# If content is a list (Gemini format), extract text from blocks
|
| 241 |
-
if isinstance(content, list):
|
| 242 |
-
text_parts = []
|
| 243 |
-
for block in content:
|
| 244 |
-
if isinstance(block, dict) and block.get("type") == "text":
|
| 245 |
-
text_parts.append(block.get("text", ""))
|
| 246 |
-
final_text = "\n".join(text_parts)
|
| 247 |
-
else:
|
| 248 |
-
final_text = content
|
| 249 |
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
final_text += f"\n\n===CHART===\n{chart_json}\n===ENDCHART==="
|
| 256 |
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
async def cleanup(self):
|
| 260 |
"""Delete temporary session files and close MCP client."""
|
| 261 |
-
if self.mcp_client:
|
| 262 |
-
try:
|
| 263 |
-
await self.mcp_client.close()
|
| 264 |
-
except Exception as e:
|
| 265 |
-
print(f"Warning: Failed to close MCP client: {e}")
|
| 266 |
-
|
| 267 |
if os.path.exists(self.temp_dir):
|
| 268 |
try:
|
| 269 |
shutil.rmtree(self.temp_dir)
|
|
|
|
| 1 |
import os
|
| 2 |
+
import sys
|
| 3 |
import uuid
|
| 4 |
import asyncio
|
| 5 |
import pandas as pd
|
|
|
|
| 22 |
from langchain.agents import create_agent
|
| 23 |
from langchain_community.tools import DuckDuckGoSearchRun
|
| 24 |
from langchain_mcp_adapters.client import MultiServerMCPClient
|
| 25 |
+
from qdrant_client.http import models as qdrant_models
|
| 26 |
|
| 27 |
# Import specific embeddings
|
| 28 |
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 29 |
from langchain_openai import OpenAIEmbeddings
|
| 30 |
|
| 31 |
+
from supabase import create_client, ClientOptions
|
| 32 |
+
|
| 33 |
+
from dotenv import load_dotenv
|
| 34 |
+
load_dotenv()
|
| 35 |
+
|
| 36 |
class MoneyRAG:
|
| 37 |
+
def __init__(self, llm_provider: str, model_name: str, embedding_model_name: str, api_key: str, user_id: str, access_token: str = None):
|
| 38 |
self.llm_provider = llm_provider.lower()
|
| 39 |
self.model_name = model_name
|
| 40 |
self.embedding_model_name = embedding_model_name
|
| 41 |
+
self.user_id = user_id
|
| 42 |
+
|
| 43 |
+
# Initialize Supabase Client
|
| 44 |
+
url = os.environ.get("SUPABASE_URL")
|
| 45 |
+
key = os.environ.get("SUPABASE_KEY")
|
| 46 |
+
|
| 47 |
+
# Security: Inject the logged-in user's JWT so RLS policies pass!
|
| 48 |
+
if access_token:
|
| 49 |
+
opts = ClientOptions(headers={"Authorization": f"Bearer {access_token}"})
|
| 50 |
+
self.supabase = create_client(url, key, options=opts)
|
| 51 |
+
else:
|
| 52 |
+
self.supabase = create_client(url, key)
|
| 53 |
|
| 54 |
# Set API Keys
|
| 55 |
if self.llm_provider == "google":
|
|
|
|
| 79 |
self.mcp_client: Optional[MultiServerMCPClient] = None
|
| 80 |
self.search_tool = DuckDuckGoSearchRun()
|
| 81 |
self.merchant_cache = {} # Session-based cache for merchant enrichment
|
| 82 |
+
self.memory = InMemorySaver() # Session-based cache for chat memory
|
| 83 |
|
| 84 |
+
async def setup_session(self, csv_files: List[dict]):
    """Ingest each uploaded CSV, then rebuild the SQL handle and vector store.

    Args:
        csv_files: One dict per upload, e.g.
            ``[{"path": "/temp/file.csv", "csv_id": "uuid"}, ...]``.
            ``"path"`` is required; ``"csv_id"`` may be absent.
    """
    # Files are ingested one at a time, in the order given.
    for entry in csv_files:
        csv_path = entry["path"]
        csv_id = entry.get("csv_id")
        await self._ingest_csv(csv_path, csv_id)

    # NOTE(review): this still builds a SQLite URI from self.db_path although
    # ingestion now writes to Supabase — confirm the handle is still needed.
    self.db = SQLDatabase.from_uri(f"sqlite:///{self.db_path}")
    self.vector_store = self._sync_to_qdrant()
|
|
|
|
| 92 |
|
| 93 |
+
async def _ingest_csv(self, file_path, csv_id=None):
|
| 94 |
df = pd.read_csv(file_path)
|
| 95 |
headers = df.columns.tolist()
|
| 96 |
sample_data = df.head(10).to_json()
|
|
|
|
| 128 |
mapping = await chain.ainvoke({"headers": headers, "sample": sample_data, "filename": os.path.basename(file_path)})
|
| 129 |
|
| 130 |
standard_df = pd.DataFrame()
|
| 131 |
+
standard_df['trans_date'] = pd.to_datetime(df[mapping['date_col']]).dt.strftime('%Y-%m-%d')
|
| 132 |
+
# Assign user_id AFTER trans_date establishes the DataFrame length, or else it defaults to NaN!
|
| 133 |
+
standard_df['user_id'] = self.user_id
|
| 134 |
standard_df['description'] = df[mapping['desc_col']]
|
| 135 |
+
if csv_id:
|
| 136 |
+
standard_df['source_csv_id'] = csv_id
|
| 137 |
|
| 138 |
raw_amounts = pd.to_numeric(df[mapping['amount_col']])
|
| 139 |
standard_df['amount'] = raw_amounts * -1 if mapping['sign_convention'] == "spending_is_negative" else raw_amounts
|
| 140 |
standard_df['category'] = df[mapping.get('category_col')] if mapping.get('category_col') else 'Uncategorized'
|
|
|
|
| 141 |
|
| 142 |
# --- Async Enrichment Step ---
|
| 143 |
print(f" ✨ Enriching descriptions for {os.path.basename(file_path)}...")
|
|
|
|
| 165 |
desc_map = dict(zip(unique_descriptions, enrichment_results))
|
| 166 |
standard_df['enriched_info'] = standard_df['description'].map(desc_map).fillna("")
|
| 167 |
|
| 168 |
+
# Save to Supabase transactions table instead of local SQLite
|
| 169 |
+
# Use simplejson roundtrip to guarantee all Pandas NaNs, NaTs, and weird floats become strict JSON nulls
|
| 170 |
+
import json
|
| 171 |
+
records = json.loads(standard_df.to_json(orient='records'))
|
| 172 |
+
|
| 173 |
+
batch_size = 100
|
| 174 |
+
for i in range(0, len(records), batch_size):
|
| 175 |
+
batch = records[i:i + batch_size]
|
| 176 |
+
# If insertion fails, it raises an exception so Streamlit surfaces the error
|
| 177 |
+
self.supabase.table("Transaction").insert(batch).execute()
|
| 178 |
|
| 179 |
def _sync_to_qdrant(self):
|
| 180 |
+
# client = QdrantClient(path=self.qdrant_path)
|
| 181 |
+
client = QdrantClient(
|
| 182 |
+
url=os.getenv("QDRANT_URL"),
|
| 183 |
+
api_key=os.getenv("QDRANT_API_KEY"),
|
| 184 |
+
)
|
| 185 |
collection = "transactions"
|
| 186 |
|
| 187 |
+
# Fetch only THIS USER'S transactions from Supabase to sync into VectorDB
|
| 188 |
+
res = self.supabase.table("Transaction").select("*").eq("user_id", self.user_id).execute()
|
| 189 |
+
df = pd.DataFrame(res.data)
|
| 190 |
|
| 191 |
# Check for empty dataframe
|
| 192 |
if df.empty:
|
| 193 |
+
raise ValueError("No transactions found in database for this user. Please upload files first.")
|
| 194 |
|
| 195 |
# Dynamically detect embedding dimension
|
| 196 |
sample_embedding = self.embeddings.embed_query("test")
|
| 197 |
embedding_dim = len(sample_embedding)
|
| 198 |
|
| 199 |
+
# Safely create the collection only if it doesn't already exist to preserve multi-tenant pool
|
| 200 |
+
if not client.collection_exists(collection):
|
| 201 |
+
client.create_collection(
|
| 202 |
+
collection_name=collection,
|
| 203 |
+
vectors_config=qdrant_models.VectorParams(size=embedding_dim, distance=qdrant_models.Distance.COSINE),
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Security: Create a strict Payload Index on the user_id field so we can filter by it securely!
|
| 207 |
+
client.create_payload_index(
|
| 208 |
collection_name=collection,
|
| 209 |
+
field_name="metadata.user_id",
|
| 210 |
+
field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
|
| 211 |
)
|
| 212 |
|
| 213 |
vs = QdrantVectorStore(client=client, collection_name=collection, embedding=self.embeddings)
|
|
|
|
| 222 |
else:
|
| 223 |
texts.append(base_text)
|
| 224 |
|
| 225 |
+
# Inject critical user_id payload to Qdrant so we can filter on it during retrieval
|
| 226 |
+
metadatas = df[['id', 'amount', 'category', 'trans_date']].copy()
|
| 227 |
+
if 'source_csv_id' in df.columns:
|
| 228 |
+
metadatas['source_csv_id'] = df['source_csv_id']
|
| 229 |
+
metadatas = metadatas.to_dict('records')
|
| 230 |
|
| 231 |
+
vector_ids = []
|
| 232 |
+
for m in metadatas:
|
| 233 |
+
vector_ids.append(str(m['id'])) # Keep original Postgres UUID as Vector ID to prevent duplication
|
| 234 |
+
m['user_id'] = self.user_id # Secure payload identifier
|
| 235 |
+
m['transaction_date'] = str(m['trans_date']) # Rename for agent consistency
|
| 236 |
+
del m['trans_date']
|
| 237 |
+
|
| 238 |
+
vs.add_texts(texts=texts, metadatas=metadatas, ids=vector_ids)
|
| 239 |
return vs
|
| 240 |
|
| 241 |
+
async def delete_file(self, csv_id: str):
    """Delete an uploaded CSV and its transactions for the current user.

    Removes the rows from the Supabase "Transaction" and "CSVFile" tables,
    then removes the matching vectors from the Qdrant "transactions"
    collection. This is best effort: any failure is printed, not raised.

    Args:
        csv_id: Id of the uploaded CSV, stored as ``source_csv_id`` on the
            transaction rows and on the vector payloads.
    """
    try:
        # 1. Postgres. Transactions are removed explicitly rather than
        #    relying on a foreign-key cascade, and are scoped to this user so
        #    the delete is safe even without row-level security.
        (
            self.supabase.table("Transaction")
            .delete()
            .eq("source_csv_id", csv_id)
            .eq("user_id", self.user_id)
            .execute()
        )
        self.supabase.table("CSVFile").delete().eq("id", csv_id).execute()

        # 2. Qdrant. The collection is shared by all users and has no
        #    row-level security, so the payload filter must match BOTH the
        #    file id and the owning user (same "metadata.user_id" key that
        #    the retrieval filter uses).
        client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
        owned_file_points = qdrant_models.Filter(
            must=[
                qdrant_models.FieldCondition(
                    key="metadata.source_csv_id",
                    match=qdrant_models.MatchValue(value=csv_id),
                ),
                qdrant_models.FieldCondition(
                    key="metadata.user_id",
                    match=qdrant_models.MatchValue(value=self.user_id),
                ),
            ]
        )
        client.delete(
            collection_name="transactions",
            points_selector=owned_file_points,
        )
    except Exception as e:
        print(f"Error purging file data: {e}")
|
| 263 |
+
|
| 264 |
+
async def chat(self, query: str):
    """Answer one user message using a freshly spawned MCP tool server.

    A new ``mcp_server.py`` subprocess is started for every call with
    ``CURRENT_USER_ID`` set to this instance's user id, an agent is built on
    its tools with ``self.memory`` as checkpointer (so history survives
    between calls), and the subprocess client is closed again afterwards.

    Args:
        query: The user's message.

    Returns:
        The assistant's text. If ``latest_chart.json`` exists in
        ``self.temp_dir`` after the run, its content is appended between
        ``===CHART===`` and ``===ENDCHART===`` markers.
    """
    # 1. Initialize MCP client dynamically to guarantee fresh bindings
    server_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mcp_server.py")

    mcp_client = MultiServerMCPClient(
        {
            "money_rag": {
                "transport": "stdio",
                "command": sys.executable,
                "args": [server_path],
                "env": {**os.environ.copy(), "CURRENT_USER_ID": self.user_id},
            }
        }
    )

    try:
        # 2. Extract tools from the established subprocess
        mcp_tools = await mcp_client.get_tools()

        # 3. Create the agent for this turn, preserving conversation memory.
        #    Every fragment ends with a space so the implicitly concatenated
        #    sentences do not run together.
        system_prompt = (
            "You are a financial analyst. Use the provided tools to query the database "
            "and perform semantic searches. Spending is POSITIVE (>0). "
            "Always explain your findings clearly. "
            "IMPORTANT: Whenever possible and relevant (e.g. when discussing trends, comparing categories, or showing breakdowns), "
            "you MUST proactively use the 'generate_interactive_chart' tool to generate visual plots (bar, pie, or line charts) to accompany your analysis. "
            "WARNING: You MUST use the actual tool call to generate the chart. DO NOT simply output a json block with chart parameters as your final text answer."
        )

        agent = create_agent(
            model=self.llm,
            tools=mcp_tools,
            system_prompt=system_prompt,
            checkpointer=self.memory,
        )

        config = {"configurable": {"thread_id": "session_1"}}

        # Clear out any previous chart so we don't carry over stale plots
        chart_path = os.path.join(self.temp_dir, "latest_chart.json")
        if os.path.exists(chart_path):
            os.remove(chart_path)

        # 4. Run the agent; tool calls are served by the MCP subprocess
        result = await agent.ainvoke(
            {"messages": [{"role": "user", "content": query}]},
            config,
        )

        # Message content is either a plain string or a list of parts.
        content = result["messages"][-1].content

        if isinstance(content, list):
            # List form (e.g. Gemini): parts may be plain strings or typed
            # dict blocks; keep every textual part, skip everything else.
            text_parts = []
            for block in content:
                if isinstance(block, str):
                    text_parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    text_parts.append(block.get("text", ""))
            final_text = "\n".join(text_parts)
        else:
            final_text = content

        # Attach the chart if one was written to disk during this turn
        if os.path.exists(chart_path):
            with open(chart_path, "r") as f:
                chart_json = f.read()
            return f"{final_text}\n\n===CHART===\n{chart_json}\n===ENDCHART==="

        return final_text

    finally:
        # 5. Tear the subprocess client down so MCP server processes are not
        #    leaked across Streamlit reruns; a failing close must not mask
        #    the result or the original error.
        try:
            await mcp_client.close()
        except Exception as close_e:
            print(f"Warning on closing MCP Client: {close_e}")
|
| 340 |
|
| 341 |
async def cleanup(self):
|
| 342 |
"""Delete temporary session files and close MCP client."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
if os.path.exists(self.temp_dir):
|
| 344 |
try:
|
| 345 |
shutil.rmtree(self.temp_dir)
|
requirements.txt
CHANGED
|
@@ -41,3 +41,9 @@ tenacity>=9.1.2
|
|
| 41 |
|
| 42 |
streamlit>=1.53.0
|
| 43 |
ddgs>=9.10.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
streamlit>=1.53.0
|
| 43 |
ddgs>=9.10.0
|
| 44 |
+
|
| 45 |
+
supabase>=2.28.0
|
| 46 |
+
plotly>=6.5.2
|
| 47 |
+
|
| 48 |
+
psycopg2-binary>=2.9.11
|
| 49 |
+
extra-streamlit-components
|