Sajil Awale committed on
Commit
7381684
·
1 Parent(s): dcb0e39

added multi-user auth feature in fin adv

Browse files
Files changed (4) hide show
  1. app.py +550 -114
  2. mcp_server.py +104 -96
  3. money_rag.py +159 -83
  4. requirements.txt +6 -0
app.py CHANGED
@@ -3,129 +3,565 @@ import asyncio
3
  import os
4
  import json
5
  import plotly.io as pio
 
 
 
6
  from money_rag import MoneyRAG
7
 
8
- st.set_page_config(page_title="MoneyRAG", layout="wide")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Sidebar for Authentication
11
- with st.sidebar:
12
- st.header("Authentication")
13
- provider = st.selectbox("LLM Provider", ["Google", "OpenAI"])
14
-
15
- if provider == "Google":
16
- models = ["gemini-3-flash-preview", "gemini-3-pro-image-preview", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"]
17
- embeddings = ["gemini-embedding-001"]
18
- else:
19
- models = ["gpt-5-mini", "gpt-5-nano", "gpt-4o-mini", "gpt-4o"]
20
- embeddings = ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]
21
-
22
- model_name = st.selectbox("Choose Decoder Model", models)
23
- embed_name = st.selectbox("Choose Embedding Model", embeddings)
24
- api_key = st.text_input("API Key", type="password")
25
-
26
- auth_button = st.button("Authenticate")
27
- if auth_button and api_key:
28
- st.session_state.rag = MoneyRAG(provider, model_name, embed_name, api_key)
29
- st.success("Authenticated!")
30
-
31
  st.divider()
32
- st.caption("**Contributors:**")
33
- st.caption("👤 [Sajil Awale](https://github.com/AwaleSajil)")
34
- st.caption("👤 [Simran KC](https://github.com/iamsims)")
35
-
36
- # Main Window
37
- st.title("MoneyRAG 💰")
38
- st.subheader("Where is my money?")
39
- st.markdown("""
40
- This app helps you analyze your personal finances using AI.
41
- Upload your bank/credit card CSV statements to chat with your data semantically.
42
- """)
43
-
44
- # Guides Section
45
- col1, col2 = st.columns(2)
46
-
47
- with col1:
48
- with st.expander("📚 How to get API keys"):
49
- st.markdown("**Google Gemini API:**")
50
- st.markdown("🔗 [Get API key from Google AI Studio](https://aistudio.google.com/app/apikey)")
51
- st.markdown("")
52
- st.markdown("**OpenAI API:**")
53
- st.markdown("🔗 [Get API key from OpenAI Platform](https://platform.openai.com/api-keys)")
54
-
55
- with col2:
56
- with st.expander("📥 How to download transaction history"):
57
- st.markdown("**Chase Credit Card:**")
58
- st.video("https://www.youtube.com/watch?v=gtAFaP9Lts8")
59
- st.markdown("")
60
- st.markdown("**Discover Credit Card:**")
61
- st.video("https://www.youtube.com/watch?v=cry6-H5b0PQ")
62
-
63
- # Architecture Diagram
64
- with st.expander("🏗️ How MoneyRAG Works"):
65
- st.image("architecture.svg", use_container_width=True)
66
-
67
- st.divider()
68
-
69
- if "rag" in st.session_state:
70
- uploaded_files = st.file_uploader("Upload CSV transactions", accept_multiple_files=True, type=['csv'])
71
 
72
- if uploaded_files:
73
- if st.button("Ingest Data"):
74
- temp_paths = []
75
- for uploaded_file in uploaded_files:
76
- path = os.path.join(st.session_state.rag.temp_dir, uploaded_file.name)
77
- with open(path, "wb") as f:
78
- f.write(uploaded_file.getbuffer())
79
- temp_paths.append(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- with st.spinner("Ingesting and vectorizing..."):
82
- asyncio.run(st.session_state.rag.setup_session(temp_paths))
83
- st.success("Data ready for chat!")
 
 
 
84
 
85
- # Chat Interface
86
- st.divider()
87
- if "messages" not in st.session_state:
88
- st.session_state.messages = []
89
-
90
- # Helper function to cleverly render either text or a Plotly chart
91
- def render_content(content):
92
- # We might have mixed text and charts delimited by ===CHART=== ... ===ENDCHART===
93
- if isinstance(content, str) and "===CHART===" in content:
94
- parts = content.split("===CHART===")
95
- # Render first text part
96
- st.markdown(parts[0].strip())
97
 
98
- for part in parts[1:]:
99
- if "===ENDCHART===" in part:
100
- chart_json, remaining_text = part.split("===ENDCHART===")
 
 
101
  try:
102
- fig = pio.from_json(chart_json.strip())
103
- st.plotly_chart(fig, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  except Exception as e:
105
- st.error("Failed to render chart.")
106
-
107
- if remaining_text.strip():
108
- st.markdown(remaining_text.strip())
109
- else:
110
- st.markdown(content)
111
-
112
- # Render previous messages
113
- for message in st.session_state.messages:
114
- with st.chat_message(message["role"]):
115
- render_content(message["content"])
116
-
117
- # Handle new user input
118
- if prompt := st.chat_input("Ask about your spending..."):
119
- st.session_state.messages.append({"role": "user", "content": prompt})
120
- with st.chat_message("user"):
121
- st.markdown(prompt)
122
-
123
- with st.chat_message("assistant"):
124
- with st.spinner("Thinking..."):
125
- response = asyncio.run(st.session_state.rag.chat(prompt))
126
- render_content(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- st.session_state.messages.append({"role": "assistant", "content": response})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- else:
131
- st.info("Please authenticate in the sidebar to start.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
  import json
5
  import plotly.io as pio
6
+ from supabase import create_client, Client, ClientOptions
7
+ from dotenv import load_dotenv
8
+
9
  from money_rag import MoneyRAG
10
 
11
+ load_dotenv()
12
+
13
+ st.set_page_config(page_title="MoneyRAG", layout="wide", initial_sidebar_state="expanded")
14
+
15
+ # Initialize Supabase Client per request (NO CACHE) to ensure thread-safe auth headers
16
+ def get_supabase() -> Client:
17
+ url = os.environ.get("SUPABASE_URL")
18
+ key = os.environ.get("SUPABASE_KEY")
19
+ if "access_token" in st.session_state:
20
+ opts = ClientOptions(headers={"Authorization": f"Bearer {st.session_state.access_token}"})
21
+ return create_client(url, key, options=opts)
22
+ return create_client(url, key)
23
+
24
+ supabase = get_supabase()
25
+
26
+ def inject_css():
27
+ st.html("""
28
+ <link rel="preconnect" href="https://fonts.googleapis.com">
29
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap" rel="stylesheet">
30
+ <style>
31
+ /* ── Global Reset & Font ── */
32
+ html, body, [class*="css"] {
33
+ font-family: 'Inter', sans-serif !important;
34
+ }
35
+ #MainMenu, footer, header { visibility: hidden; }
36
+ .block-container { padding-top: 2rem !important; }
37
+
38
+ /* ── Background ── */
39
+ .stApp {
40
+ background: #0a0a0f;
41
+ color: #e2e8f0;
42
+ }
43
+
44
+ /* ── Sidebar ── */
45
+ [data-testid="stSidebar"] {
46
+ background: linear-gradient(180deg, #0f0f1a 0%, #0d0d16 100%) !important;
47
+ border-right: 1px solid rgba(99,102,241,0.15) !important;
48
+ }
49
+ [data-testid="stSidebar"] * { color: #cbd5e1 !important; }
50
+
51
+ /* ── Nav buttons ── */
52
+ div[data-testid="stSidebarContent"] .nav-btn > div > button {
53
+ width: 100% !important;
54
+ text-align: left !important;
55
+ border: none !important;
56
+ border-radius: 10px !important;
57
+ background: transparent !important;
58
+ color: #94a3b8 !important;
59
+ padding: 0.65rem 1rem !important;
60
+ font-size: 0.9rem !important;
61
+ font-weight: 500 !important;
62
+ transition: all 0.2s ease !important;
63
+ margin-bottom: 2px !important;
64
+ }
65
+ div[data-testid="stSidebarContent"] .nav-btn > div > button:hover {
66
+ background: rgba(99,102,241,0.1) !important;
67
+ color: #a5b4fc !important;
68
+ }
69
+ div[data-testid="stSidebarContent"] .nav-btn-active > div > button {
70
+ background: linear-gradient(135deg, rgba(99,102,241,0.25), rgba(139,92,246,0.2)) !important;
71
+ color: #a5b4fc !important;
72
+ border: 1px solid rgba(99,102,241,0.3) !important;
73
+ font-weight: 600 !important;
74
+ }
75
+
76
+ /* ── Primary Buttons ── */
77
+ .stButton > button[kind="primary"] {
78
+ background: linear-gradient(135deg, #6366f1, #8b5cf6) !important;
79
+ border: none !important;
80
+ border-radius: 10px !important;
81
+ color: white !important;
82
+ font-weight: 600 !important;
83
+ padding: 0.6rem 1.2rem !important;
84
+ transition: all 0.2s ease !important;
85
+ box-shadow: 0 4px 15px rgba(99,102,241,0.3) !important;
86
+ }
87
+ .stButton > button[kind="primary"]:hover {
88
+ transform: translateY(-1px) !important;
89
+ box-shadow: 0 6px 20px rgba(99,102,241,0.45) !important;
90
+ }
91
+
92
+ /* ── Secondary Buttons ── */
93
+ .stButton > button[kind="secondary"] {
94
+ background: rgba(255,255,255,0.05) !important;
95
+ border: 1px solid rgba(255,255,255,0.1) !important;
96
+ border-radius: 10px !important;
97
+ color: #cbd5e1 !important;
98
+ font-weight: 500 !important;
99
+ transition: all 0.2s ease !important;
100
+ }
101
+ .stButton > button[kind="secondary"]:hover {
102
+ background: rgba(255,255,255,0.08) !important;
103
+ border-color: rgba(99,102,241,0.35) !important;
104
+ }
105
+
106
+ /* ── Inputs ── */
107
+ .stTextInput input, .stSelectbox > div > div {
108
+ background: rgba(255,255,255,0.04) !important;
109
+ border: 1px solid rgba(255,255,255,0.1) !important;
110
+ border-radius: 10px !important;
111
+ color: #e2e8f0 !important;
112
+ transition: border 0.2s ease !important;
113
+ }
114
+ .stTextInput input:focus { border-color: #6366f1 !important; box-shadow: 0 0 0 2px rgba(99,102,241,0.2) !important; }
115
+
116
+ /* ── Glass Cards ── */
117
+ .glass-card {
118
+ background: rgba(255,255,255,0.04);
119
+ border: 1px solid rgba(255,255,255,0.08);
120
+ border-radius: 16px;
121
+ padding: 1.75rem;
122
+ backdrop-filter: blur(12px);
123
+ transition: border 0.2s ease;
124
+ }
125
+ .glass-card:hover { border-color: rgba(99,102,241,0.25); }
126
+
127
+ /* ── Hero ── */
128
+ .hero { text-align: center; padding: 4rem 1rem 2rem; }
129
+ .hero .badge {
130
+ display: inline-block;
131
+ background: linear-gradient(135deg, rgba(99,102,241,0.2), rgba(139,92,246,0.2));
132
+ border: 1px solid rgba(99,102,241,0.35);
133
+ color: #a5b4fc;
134
+ font-size: 0.78rem;
135
+ font-weight: 600;
136
+ letter-spacing: 0.1em;
137
+ text-transform: uppercase;
138
+ padding: 0.3rem 0.9rem;
139
+ border-radius: 99px;
140
+ margin-bottom: 1.25rem;
141
+ }
142
+ .hero h1 {
143
+ font-size: clamp(2.5rem, 6vw, 4rem);
144
+ font-weight: 800;
145
+ letter-spacing: -2px;
146
+ line-height: 1.1;
147
+ background: linear-gradient(135deg, #e2e8f0 30%, #a5b4fc);
148
+ -webkit-background-clip: text;
149
+ -webkit-text-fill-color: transparent;
150
+ margin-bottom: 1rem;
151
+ }
152
+ .hero p { font-size: 1.1rem; color: #64748b; max-width: 440px; margin: 0 auto; line-height: 1.7; }
153
+
154
+ /* ── Divider ── */
155
+ hr { border-color: rgba(255,255,255,0.07) !important; }
156
+
157
+ /* ── Expanders ── */
158
+ [data-testid="stExpander"] {
159
+ background: rgba(255,255,255,0.03) !important;
160
+ border: 1px solid rgba(255,255,255,0.07) !important;
161
+ border-radius: 12px !important;
162
+ }
163
+
164
+ /* ── Alerts ── */
165
+ [data-testid="stAlert"] { border-radius: 10px !important; }
166
+
167
+ /* ── Chat bubbles ── */
168
+ [data-testid="stChatMessage"] { border-radius: 12px !important; }
169
+ </style>
170
+ """)
171
+
172
+ def login_register_page():
173
+ inject_css()
174
+
175
+ st.html("""
176
+ <div class="hero">
177
+ <div class="badge">✦ AI-Powered Finance</div>
178
+ <h1>MoneyRAG</h1>
179
+ <p>Your personal finance analyst. Upload bank statements, ask questions, get insights — powered by AI.</p>
180
+ </div>
181
+ """)
182
+
183
+ col_l, col1, col2, col_r = st.columns([1, 2, 2, 1])
184
+
185
+ with col1:
186
+ st.markdown('<div class="glass-card">', unsafe_allow_html=True)
187
+ st.markdown("### Sign In")
188
+ email = st.text_input("Email", key="login_email", placeholder="you@example.com", label_visibility="collapsed")
189
+ password = st.text_input("Password", type="password", key="login_pass", placeholder="Password", label_visibility="collapsed")
190
+ if st.button("Sign In →", use_container_width=True, type="primary"):
191
+ if email and password:
192
+ with st.spinner(""):
193
+ try:
194
+ res = supabase.auth.sign_in_with_password({"email": email, "password": password})
195
+ st.session_state.user = res.user
196
+ st.session_state.access_token = res.session.access_token
197
+ st.query_params["t"] = res.session.access_token
198
+ try:
199
+ supabase.table("User").upsert({
200
+ "id": res.user.id,
201
+ "email": email,
202
+ "hashed_password": "managed_by_supabase_auth"
203
+ }).execute()
204
+ except Exception as sync_e:
205
+ print(f"Warning: Could not sync user: {sync_e}")
206
+ st.rerun()
207
+ except Exception as e:
208
+ st.error(f"Login failed: {e}")
209
+ st.markdown('</div>', unsafe_allow_html=True)
210
+
211
+ with col2:
212
+ st.markdown('<div class="glass-card">', unsafe_allow_html=True)
213
+ st.markdown("### Create Account")
214
+ reg_email = st.text_input("Email", key="reg_email", placeholder="you@example.com", label_visibility="collapsed")
215
+ reg_password = st.text_input("Password", type="password", key="reg_pass", placeholder="Password", label_visibility="collapsed")
216
+ if st.button("Create Account →", use_container_width=True):
217
+ if reg_email and reg_password:
218
+ with st.spinner(""):
219
+ try:
220
+ res = supabase.auth.sign_up({"email": reg_email, "password": reg_password})
221
+ if res.user:
222
+ try:
223
+ supabase.table("User").upsert({
224
+ "id": res.user.id, "email": reg_email,
225
+ "hashed_password": "managed_by_supabase_auth"
226
+ }).execute()
227
+ except Exception:
228
+ pass
229
+ st.success("Account created! Sign in on the left.")
230
+ except Exception as e:
231
+ st.error(f"Signup failed: {str(e)}")
232
+ st.markdown('</div>', unsafe_allow_html=True)
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  st.divider()
235
+ col3, col4, col5 = st.columns(3)
236
+ with col3:
237
+ with st.expander("📚 API Keys"):
238
+ st.markdown("**Google:** [AI Studio](https://aistudio.google.com/app/apikey)")
239
+ st.markdown("**OpenAI:** [Platform](https://platform.openai.com/api-keys)")
240
+ with col4:
241
+ with st.expander("📥 Export Transactions"):
242
+ st.markdown("**Chase:** [Video guide](https://www.youtube.com/watch?v=gtAFaP9Lts8)")
243
+ st.markdown("**Discover:** [Video guide](https://www.youtube.com/watch?v=cry6-H5b0PQ)")
244
+ with col5:
245
+ with st.expander("🏗️ Architecture"):
246
+ st.image("architecture.svg", use_container_width=True)
247
+
248
+ def load_user_config():
249
+ try:
250
+ # Always get a fresh client with the current auth token
251
+ client = get_supabase()
252
+ res = client.table("AccountConfig").select("*").eq("user_id", st.session_state.user.id).execute()
253
+ if res.data:
254
+ return res.data[0]
255
+ except Exception as e:
256
+ print(f"Failed to load config: {e}")
257
+ return None
258
+
259
+ def main_app_view():
260
+ inject_css()
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
+ # Use session state for active nav tab
263
+ if "nav" not in st.session_state:
264
+ st.session_state.nav = "Chat"
265
+
266
+ with st.sidebar:
267
+ st.markdown(f"**MoneyRAG** 💰")
268
+ st.caption(st.session_state.user.email)
269
+ st.divider()
270
+
271
+ # Modern nav buttons using st.button styled via CSS
272
+ for label, icon in [("Chat", "💬"), ("Ingest Data", "📥"), ("Account Config", "⚙️")]:
273
+ is_active = st.session_state.nav == label
274
+ css_class = "nav-btn-active" if is_active else "nav-btn"
275
+ st.markdown(f'<div class="{css_class}">', unsafe_allow_html=True)
276
+ if st.button(f"{icon} {label}", key=f"nav_{label}", use_container_width=True):
277
+ st.session_state.nav = label
278
+ st.rerun()
279
+ st.markdown('</div>', unsafe_allow_html=True)
280
+
281
+ st.divider()
282
+ if st.button("Log Out", use_container_width=True):
283
+ supabase.auth.sign_out()
284
+ if "t" in st.query_params:
285
+ del st.query_params["t"]
286
+ for key in list(st.session_state.keys()):
287
+ del st.session_state[key]
288
+ st.rerun()
289
+
290
+ st.divider()
291
+ st.caption("[Sajil Awale](https://github.com/AwaleSajil) · [Simran KC](https://github.com/iamsims)")
292
+
293
+ nav = st.session_state.nav
294
+
295
+ # Always reload config fresh (cached None from unauthenticated loads will persist otherwise)
296
+ config = load_user_config()
297
+
298
+ if nav == "Account Config":
299
+ st.header("⚙️ Account Configuration")
300
+ st.write("Configure your AI providers and models here.")
301
+
302
+ current_provider = config['llm_provider'] if config else "Google"
303
+ current_key = config['api_key'] if config else ""
304
+ current_decode = config.get('decode_model', "gemini-3-flash-preview") if config else "gemini-3-flash-preview"
305
+ current_embed = config.get('embedding_model', "gemini-embedding-001") if config else "gemini-embedding-001"
306
+ # Provider Selection - Default to Google
307
+ provider = st.selectbox("LLM Provider", ["Google", "OpenAI"], index=0 if (not config or config['llm_provider'] == "Google") else 1)
308
+
309
+ if provider == "Google":
310
+ models = ["gemini-3-flash-preview", "gemini-3-pro-image-preview", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"]
311
+ embeddings = ["gemini-embedding-001"]
312
+ else:
313
+ models = ["gpt-5-mini", "gpt-5-nano", "gpt-4o-mini", "gpt-4o"]
314
+ embeddings = ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]
315
+
316
+ with st.form("config_form"):
317
+ api_key = st.text_input("API Key", type="password", value=current_key)
318
 
319
+ col1, col2 = st.columns(2)
320
+ with col1:
321
+ # Default to gemini-3 if no config exists
322
+ m_default_val = current_decode if config else "gemini-3-flash-preview"
323
+ m_idx = models.index(m_default_val) if m_default_val in models else 0
324
+ final_decode = st.selectbox("Select Model", models, index=m_idx)
325
 
326
+ with col2:
327
+ e_idx = embeddings.index(current_embed) if (config and current_embed in embeddings) else 0
328
+ final_embed = st.selectbox("Select Embedding Model", embeddings, index=e_idx)
 
 
 
 
 
 
 
 
 
329
 
330
+ submitted = st.form_submit_button("Save Configuration", type="primary", use_container_width=True)
331
+ if submitted:
332
+ if not api_key:
333
+ st.error("API Key is required.")
334
+ else:
335
  try:
336
+ record = {
337
+ "user_id": st.session_state.user.id,
338
+ "llm_provider": provider,
339
+ "api_key": api_key,
340
+ "decode_model": final_decode,
341
+ "embedding_model": final_embed
342
+ }
343
+ if config:
344
+ supabase.table("AccountConfig").update(record).eq("id", config['id']).execute()
345
+ else:
346
+ supabase.table("AccountConfig").insert(record).execute()
347
+
348
+ st.session_state.user_config = load_user_config()
349
+ # Reinitialize RAG with new config
350
+ if "rag" in st.session_state:
351
+ del st.session_state.rag
352
+
353
+ st.success("Configuration saved successfully!")
354
  except Exception as e:
355
+ st.error(f"Failed to save configuration: {e}")
356
+
357
+ elif nav == "Ingest Data":
358
+ st.header("📥 Ingest Data")
359
+
360
+ uploaded_files = st.file_uploader("Upload CSV transactions", accept_multiple_files=True, type=['csv'])
361
+ if uploaded_files:
362
+ if st.button("Ingest Selected Files", type="primary"):
363
+ if not config:
364
+ st.error("Please set up your Account Config first!")
365
+ return
366
+
367
+ # Initialize RAG if needed
368
+ if "rag" not in st.session_state:
369
+ st.session_state.rag = MoneyRAG(
370
+ llm_provider=config["llm_provider"],
371
+ model_name=config.get("decode_model", "gemini-2.5-pro"),
372
+ embedding_model_name=config.get("embedding_model", "gemini-embedding-001"),
373
+ api_key=config["api_key"],
374
+ user_id=st.session_state.user.id,
375
+ access_token=st.session_state.access_token
376
+ )
377
+
378
+ csv_files_info = []
379
+ user_id = st.session_state.user.id
380
+
381
+ with st.spinner("Uploading to Supabase Storage & Processing..."):
382
+ for uploaded_file in uploaded_files:
383
+ # 1. Save temp locally for pandas parsing
384
+ local_path = os.path.join(st.session_state.rag.temp_dir, uploaded_file.name)
385
+ with open(local_path, "wb") as f:
386
+ f.write(uploaded_file.getbuffer())
387
+
388
+ # 2. Upload raw file to Supabase Object Storage
389
+ s3_key = f"{user_id}/csvs/{uploaded_file.name}"
390
+ try:
391
+ supabase.storage.from_("money-rag-files").upload(
392
+ file=local_path,
393
+ path=s3_key,
394
+ file_options={"content-type": "text/csv", "upsert": "true"}
395
+ )
396
+
397
+ # 3. Log the upload in the CSVFile table
398
+ csv_record = supabase.table("CSVFile").insert({
399
+ "user_id": user_id,
400
+ "filename": uploaded_file.name,
401
+ "s3_key": s3_key
402
+ }).execute()
403
+
404
+ csv_id = csv_record.data[0]['id']
405
+ csv_files_info.append({"path": local_path, "csv_id": csv_id})
406
+
407
+ except Exception as e:
408
+ st.error(f"Error uploading {uploaded_file.name}: {e}")
409
+ continue
410
+
411
+ # 4. Trigger the LLM parsing, routing CSV data to Supabase Postgres
412
+ if csv_files_info:
413
+ asyncio.run(st.session_state.rag.setup_session(csv_files_info))
414
+ st.success("Data uploaded, parsed, and vectorized securely!")
415
+ st.rerun()
416
+
417
+ st.divider()
418
+ st.subheader("Your Uploaded Files")
419
+ try:
420
+ res = supabase.table("CSVFile").select("*").eq("user_id", st.session_state.user.id).execute()
421
+ files = res.data
422
+
423
+ if not files:
424
+ st.info("No files uploaded yet.")
425
+ else:
426
+ for f in files:
427
+ col_file, col_del = st.columns([4, 1])
428
+ with col_file:
429
+ st.write(f"📄 **{f['filename']}** (Uploaded: {f['upload_date'][:10]})")
430
+ with col_del:
431
+ if st.button("Delete", key=f"del_{f['id']}"):
432
+ st.session_state[f"confirm_del_{f['id']}"] = True
433
+
434
+ if st.session_state.get(f"confirm_del_{f['id']}", False):
435
+ st.warning("Are you sure? This permanently deletes the file from Cloud Storage, the SQL Database, and the Vector Index.")
436
+ col_y, col_n = st.columns(2)
437
+ with col_y:
438
+ if st.button("Yes, Delete", key=f"yes_{f['id']}", type="primary"):
439
+ with st.spinner("Purging file data..."):
440
+ try:
441
+ # Delete from storage
442
+ supabase.storage.from_("money-rag-files").remove([f['s3_key']])
443
+ except Exception as e:
444
+ print(f"Warning storage delete failed: {e}")
445
+
446
+ # Use initialized RAG to delete from Vectors and Postgres
447
+ if "rag" not in st.session_state and config:
448
+ st.session_state.rag = MoneyRAG(
449
+ llm_provider=config["llm_provider"],
450
+ model_name=config.get("decode_model", "gemini-2.5-pro"),
451
+ embedding_model_name=config.get("embedding_model", "gemini-embedding-001"),
452
+ api_key=config["api_key"],
453
+ user_id=st.session_state.user.id,
454
+ access_token=st.session_state.access_token
455
+ )
456
+ if "rag" in st.session_state:
457
+ asyncio.run(st.session_state.rag.delete_file(f['id']))
458
+ else:
459
+ # Fallback if no RAG config to just delete from Postgres at least
460
+ supabase.table("Transaction").delete().eq("source_csv_id", f['id']).execute()
461
+ supabase.table("CSVFile").delete().eq("id", f['id']).execute()
462
+
463
+ del st.session_state[f"confirm_del_{f['id']}"]
464
+ st.success(f"Deleted {f['filename']}!")
465
+ st.rerun()
466
+
467
+ with col_n:
468
+ if st.button("Cancel", key=f"cancel_{f['id']}"):
469
+ del st.session_state[f"confirm_del_{f['id']}"]
470
+ st.rerun()
471
+
472
+ except Exception as e:
473
+ st.error(f"Failed to load files: {e}")
474
+
475
+ elif nav == "Chat":
476
+ st.header("💬 Financial Assistant")
477
+ if not config:
478
+ st.warning("Please configure your Account Config (API Key) first!")
479
+ return
480
+
481
+ if "rag" not in st.session_state:
482
+ st.session_state.rag = MoneyRAG(
483
+ llm_provider=config["llm_provider"],
484
+ model_name=config.get("decode_model", "gemini-2.5-pro"),
485
+ embedding_model_name=config.get("embedding_model", "gemini-embedding-001"),
486
+ api_key=config["api_key"],
487
+ user_id=st.session_state.user.id,
488
+ access_token=st.session_state.access_token
489
+ )
490
+
491
+ if "messages" not in st.session_state:
492
+ st.session_state.messages = []
493
+
494
+ # Show file ingestion status
495
+ try:
496
+ client = get_supabase()
497
+ files_res = client.table("CSVFile").select("id, filename").eq("user_id", st.session_state.user.id).execute()
498
+ file_count = len(files_res.data) if files_res.data else 0
499
+ if file_count == 0:
500
+ st.warning("⚠️ No data loaded yet. Go to **Ingest Data** to upload a CSV file before chatting.")
501
+ else:
502
+ names = ", ".join(f['filename'] for f in files_res.data[:3])
503
+ suffix = f" + {file_count - 3} more" if file_count > 3 else ""
504
+ st.info(f"📊 **{file_count} file{'s' if file_count > 1 else ''} loaded:** {names}{suffix}")
505
+ except Exception:
506
+ pass # Don't break chat if the status check fails
507
+
508
+
509
+ # Helper function to cleverly render either text or a Plotly chart
510
+ def render_content(content):
511
+ if isinstance(content, str) and "===CHART===" in content:
512
+ parts = content.split("===CHART===")
513
+ st.markdown(parts[0].strip())
514
 
515
+ for part in parts[1:]:
516
+ if "===ENDCHART===" in part:
517
+ chart_json, remaining_text = part.split("===ENDCHART===")
518
+ try:
519
+ fig = pio.from_json(chart_json.strip())
520
+ st.plotly_chart(fig, use_container_width=True)
521
+ except Exception as e:
522
+ st.error("Failed to render chart.")
523
+
524
+ if remaining_text.strip():
525
+ st.markdown(remaining_text.strip())
526
+ else:
527
+ st.markdown(content)
528
+
529
+ # Render previous messages
530
+ for message in st.session_state.messages:
531
+ with st.chat_message(message["role"]):
532
+ render_content(message["content"])
533
+
534
+ # Handle new user input
535
+ if prompt := st.chat_input("Ask about your spending..."):
536
+ st.session_state.messages.append({"role": "user", "content": prompt})
537
+ with st.chat_message("user"):
538
+ st.markdown(prompt)
539
+
540
+ with st.chat_message("assistant"):
541
+ with st.spinner("Thinking..."):
542
+ try:
543
+ response = asyncio.run(st.session_state.rag.chat(prompt))
544
+ render_content(response)
545
+ st.session_state.messages.append({"role": "assistant", "content": response})
546
+ except Exception as e:
547
+ st.error(f"Error during chat: {e}")
548
 
549
+ if __name__ == "__main__":
550
+ # Attempt to restore session from query params if page was refreshed
551
+ if "user" not in st.session_state:
552
+ token_from_url = st.query_params.get("t")
553
+ if token_from_url:
554
+ try:
555
+ res = supabase.auth.get_user(token_from_url)
556
+ if res and res.user:
557
+ st.session_state.user = res.user
558
+ st.session_state.access_token = token_from_url
559
+ except Exception:
560
+ # Token is invalid/expired - clear it from the URL too
561
+ if "t" in st.query_params:
562
+ del st.query_params["t"]
563
+
564
+ if "user" not in st.session_state:
565
+ login_register_page()
566
+ else:
567
+ main_app_view()
mcp_server.py CHANGED
@@ -6,55 +6,66 @@ from qdrant_client import QdrantClient
6
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
7
  from dotenv import load_dotenv
8
  import os
 
9
 
10
  import shutil
11
 
 
 
12
  # Load environment variables (API keys, etc.)
13
  load_dotenv()
14
 
15
  # Define paths to your data
16
- # For Hugging Face Spaces (Ephemeral):
17
- # We use a temporary directory that gets wiped on restart.
18
- # If DATA_DIR is set (e.g., by your deployment config), use it.
19
  DATA_DIR = os.getenv("DATA_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp_data"))
20
- QDRANT_PATH = os.path.join(DATA_DIR, "qdrant_db")
21
- DB_PATH = os.path.join(DATA_DIR, "money_rag.db")
22
 
23
  # Initialize the MCP Server
24
  mcp = FastMCP("Money RAG Financial Analyst")
25
 
26
- import sqlite3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def get_schema_info() -> str:
29
- """Get database schema information."""
30
- if not os.path.exists(DB_PATH):
31
- return "Database file does not exist yet. Please upload data."
32
-
33
- try:
34
- conn = sqlite3.connect(DB_PATH)
35
- cursor = conn.cursor()
36
-
37
- # Get all tables
38
- cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
39
- tables = cursor.fetchall()
40
-
41
- schema_info = []
42
- for (table_name,) in tables:
43
- schema_info.append(f"\nTable: {table_name}")
44
-
45
- # Get column info for each table
46
- cursor.execute(f"PRAGMA table_info({table_name});")
47
- columns = cursor.fetchall()
48
-
49
- schema_info.append("Columns:")
50
- for col in columns:
51
- col_id, col_name, col_type, not_null, default_val, pk = col
52
- schema_info.append(f" - {col_name} ({col_type})")
53
-
54
- conn.close()
55
- return "\n".join(schema_info)
56
- except Exception as e:
57
- return f"Error reading schema: {e}"
58
 
59
 
60
  @mcp.resource("schema://database/tables")
@@ -64,7 +75,17 @@ def get_database_schema() -> str:
64
 
65
  @mcp.tool()
66
  def query_database(query: str) -> str:
67
- """Execute a SELECT query on the money_rag SQLite database.
 
 
 
 
 
 
 
 
 
 
68
 
69
  Args:
70
  query: The SQL SELECT query to execute
@@ -78,33 +99,32 @@ def query_database(query: str) -> str:
78
  - 'amount' column: positive values = spending, negative values = payments/refunds
79
 
80
  Example queries:
81
- - Find Walmart spending: SELECT SUM(amount) FROM transactions WHERE description LIKE '%Walmart%' AND amount > 0;
82
- - List recent transactions: SELECT transaction_date, description, amount, category FROM transactions ORDER BY transaction_date DESC LIMIT 5;
83
- - Spending by category: SELECT category, SUM(amount) FROM transactions WHERE amount > 0 GROUP BY category;
84
  """
85
- if not os.path.exists(DB_PATH):
86
- return "Database file does not exist yet. Please upload data."
87
-
88
  # Security: Only allow SELECT queries
89
  query_upper = query.strip().upper()
90
- if not query_upper.startswith("SELECT") and not query_upper.startswith("PRAGMA"):
91
- return "Error: Only SELECT and PRAGMA queries are allowed"
92
 
93
  # Forbidden operations
94
- forbidden = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "REPLACE", "TRUNCATE", "ATTACH", "DETACH"]
95
- # Check for forbidden words as standalone words to avoid false positives (e.g. "update_date" column)
96
- # Simple check: space-surrounded or end-of-string
97
  if any(f" {word} " in f" {query_upper} " for word in forbidden):
98
  return f"Error: Query contains forbidden operation. Only SELECT queries allowed."
99
 
 
 
 
 
100
  try:
101
- conn = sqlite3.connect(DB_PATH)
102
  cursor = conn.cursor()
103
  cursor.execute(query)
104
  results = cursor.fetchall()
105
 
106
  # Get column names to make result more readable
107
- column_names = [description[0] for description in cursor.description] if cursor.description else []
108
 
109
  conn.close()
110
 
@@ -118,28 +138,20 @@ def query_database(query: str) -> str:
118
  formatted_results.append(str(row))
119
 
120
  return "\n".join(formatted_results)
121
- except sqlite3.Error as e:
122
- return f"Error: {str(e)}"
123
 
124
  def get_vector_store():
125
  """Initialize connection to the Qdrant vector store"""
126
  # Initialize Embedding Model using Google AI Studio
127
- embeddings = GoogleGenerativeAIEmbeddings(model="text-embedding-004")
128
 
129
- # Connect to Qdrant (Persistent Disk Mode at specific path)
130
- # We ensure the directory exists so Qdrant can write to it.
131
- os.makedirs(QDRANT_PATH, exist_ok=True)
132
-
133
- client = QdrantClient(path=QDRANT_PATH)
134
-
135
- # Check if collection exists (it might be empty in a new ephemeral session)
136
- collections = client.get_collections().collections
137
- collection_names = [c.name for c in collections]
138
 
139
- if "transactions" not in collection_names:
140
- # In a real app, you would probably trigger ingestion here or handle the empty state
141
- pass
142
-
143
  return QdrantVectorStore(
144
  client=client,
145
  collection_name="transactions",
@@ -159,20 +171,22 @@ def semantic_search(query: str, top_k: int = 5) -> str:
159
  top_k: Number of results to return (default 5).
160
  """
161
  try:
 
162
  vector_store = get_vector_store()
163
 
164
- # Safety check: if no data has been ingested yet
165
- if not os.path.exists(QDRANT_PATH) or not os.listdir(QDRANT_PATH):
166
- return "No matching transactions found (Database is empty. Please upload data first)."
 
 
167
 
168
- results = vector_store.similarity_search(query, k=top_k)
169
 
170
  if not results:
171
  return "No matching transactions found."
172
 
173
  output = []
174
  for doc in results:
175
- # Format the output clearly for the LLM/User
176
  amount = doc.metadata.get('amount', 'N/A')
177
  date = doc.metadata.get('transaction_date', 'N/A')
178
  output.append(f"Date: {date} | Match: {doc.page_content} | Amount: {amount}")
@@ -184,25 +198,29 @@ def semantic_search(query: str, top_k: int = 5) -> str:
184
 
185
 
186
  @mcp.tool()
187
- def generate_interactive_chart(sql_query: str, chart_type: str, x_col: str, y_col: str, title: str) -> str:
188
  """
189
- Generate an interactive Plotly chart from the money_rag SQLite database.
190
- Use this proactively whenever a visual representation of data would be helpful.
191
-
192
- CRITICAL INSTRUCTIONS:
193
- 1. Write a valid SQLite SELECT query.
194
- 2. Aggregate data appropriately (e.g., use GROUP BY for pie/bar charts).
195
- 3. Pass the exact column names from your query to x_col and y_col.
196
-
197
  Args:
198
- sql_query: The SQL SELECT query (e.g. "SELECT category, SUM(amount) as total FROM transactions GROUP BY category")
199
- chart_type: Must be exactly "bar", "pie", or "line"
200
- x_col: Column name from query for X-axis (or labels for pie)
201
- y_col: Column name from query for Y-axis (or values for pie)
202
- title: Title of the chart
 
 
 
 
 
203
  """
204
  try:
205
- conn = sqlite3.connect(DB_PATH)
 
 
 
 
206
  df = pd.read_sql_query(sql_query, conn)
207
  conn.close()
208
  if df.empty:
@@ -226,17 +244,7 @@ def generate_interactive_chart(sql_query: str, chart_type: str, x_col: str, y_co
226
  return f'{{"error": "Failed to generate chart: {str(e)}"}}'
227
 
228
 
229
- # A helper to clear data (useful for session reset)
230
- @mcp.tool()
231
- def clear_database() -> str:
232
- """Clear all stored transaction data to reset the session."""
233
- try:
234
- if os.path.exists(DATA_DIR):
235
- shutil.rmtree(DATA_DIR)
236
- os.makedirs(DATA_DIR)
237
- return "Database cleared successfully."
238
- except Exception as e:
239
- return f"Error clearing database: {e}"
240
 
241
  if __name__ == "__main__":
242
  # Runs the server over stdio
 
6
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
7
  from dotenv import load_dotenv
8
  import os
9
+ from typing import Optional
10
 
11
  import shutil
12
 
13
+ from textwrap import dedent
14
+
15
  # Load environment variables (API keys, etc.)
16
  load_dotenv()
17
 
18
  # Define paths to your data
 
 
 
19
  DATA_DIR = os.getenv("DATA_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp_data"))
 
 
20
 
21
  # Initialize the MCP Server
22
  mcp = FastMCP("Money RAG Financial Analyst")
23
 
24
+ import psycopg2
25
+ from supabase import create_client, Client
26
+
27
def get_db_connection():
    """Open and return a raw psycopg2 connection to the Supabase Postgres DB.

    We deliberately connect with psycopg2 rather than the Supabase client so
    the LLM's raw SQL can be executed directly. The connection string comes
    from the DATABASE_URL environment variable (Supabase exposes it in the
    dashboard under Database Settings).

    Raises:
        ValueError: if DATABASE_URL is missing or empty.
    """
    connection_string = os.environ.get("DATABASE_URL")
    if not connection_string:
        raise ValueError("DATABASE_URL must be defined to construct raw SQL connections.")
    return psycopg2.connect(connection_string)
37
+
38
def get_current_user_id() -> str:
    """Return the id of the user this MCP subprocess is serving.

    The parent process injects CURRENT_USER_ID into this server's environment
    when spawning it; every tool uses the value to scope data access to a
    single tenant.

    Raises:
        ValueError: if CURRENT_USER_ID is missing or empty.
    """
    uid = os.environ.get("CURRENT_USER_ID")
    if not uid:
        raise ValueError("CURRENT_USER_ID not injected into MCP environment!")
    return uid
43
 
44
def get_schema_info() -> str:
    """Return a textual description of the Postgres schema for the LLM.

    The text carries a hard multi-tenancy rule: every generated query must
    filter on user_id. NOTE(review): '{current_user_id}' below is a literal
    placeholder — this string is never interpolated here; the query_database
    tool enforces the real id at execution time.

    Fix: the column list now includes `enriched_info` (documented by the
    query_database tool) and `source_csv_id` (written during CSV ingestion),
    which were missing and made the schema inconsistent with the data.
    """
    return dedent("""
    Here is the PostgreSQL database schema for the authenticated user's data.

    CRITICAL RULE:
    You MUST add `WHERE user_id = '{current_user_id}'` to EVERY SINGLE query you write.
    Never query data without filtering by user_id!

    TABLE: "Transaction"
    Columns:
    - id (UUID)
    - user_id (UUID)
    - trans_date (DATE)
    - description (TEXT)
    - amount (DECIMAL)
    - category (VARCHAR)
    - enriched_info (TEXT)
    - source_csv_id (UUID)

    TABLE: "TransactionDetail"
    Columns:
    - id (UUID)
    - transaction_id (UUID)
    - item_description (TEXT)
    - item_total_price (DECIMAL)
    """)
 
 
 
 
 
69
 
70
 
71
  @mcp.resource("schema://database/tables")
 
75
 
76
  @mcp.tool()
77
  def query_database(query: str) -> str:
78
+ """
79
+ Execute a raw SQL query against the Postgres database.
80
+ The main table is named "Transaction" (you MUST INCLUDE QUOTES in your SQL!).
81
+ IMPORTANT STRICT SCHEMA:
82
+ - id (UUID)
83
+ - user_id (UUID text)
84
+ - trans_date (DATE)
85
+ - description (TEXT)
86
+ - amount (NUMERIC)
87
+ - category (TEXT)
88
+ - enriched_info (TEXT)
89
 
90
  Args:
91
  query: The SQL SELECT query to execute
 
99
  - 'amount' column: positive values = spending, negative values = payments/refunds
100
 
101
  Example queries:
102
+ - Find Walmart spending: SELECT SUM(amount) FROM "Transaction" WHERE description LIKE '%Walmart%' AND amount > 0;
103
+ - List recent transactions: SELECT trans_date, description, amount, category FROM "Transaction" ORDER BY trans_date DESC LIMIT 5;
104
+ - Spending by category: SELECT category, SUM(amount) FROM "Transaction" WHERE amount > 0 GROUP BY category;
105
  """
 
 
 
106
  # Security: Only allow SELECT queries
107
  query_upper = query.strip().upper()
108
+ if not query_upper.startswith("SELECT") and not query_upper.startswith("WITH"):
109
+ return "Error: Only SELECT queries are allowed"
110
 
111
  # Forbidden operations
112
+ forbidden = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "REPLACE", "TRUNCATE"]
 
 
113
  if any(f" {word} " in f" {query_upper} " for word in forbidden):
114
  return f"Error: Query contains forbidden operation. Only SELECT queries allowed."
115
 
116
+ user_id = get_current_user_id()
117
+ if user_id not in query:
118
+ return f"Error: You forgot to include the security filter (WHERE user_id = '{user_id}') in your query! Try again."
119
+
120
  try:
121
+ conn = get_db_connection()
122
  cursor = conn.cursor()
123
  cursor.execute(query)
124
  results = cursor.fetchall()
125
 
126
  # Get column names to make result more readable
127
+ column_names = [desc[0] for desc in cursor.description] if cursor.description else []
128
 
129
  conn.close()
130
 
 
138
  formatted_results.append(str(row))
139
 
140
  return "\n".join(formatted_results)
141
+ except psycopg2.Error as e:
142
+ return f"Database Error: {str(e)}"
143
 
144
  def get_vector_store():
145
  """Initialize connection to the Qdrant vector store"""
146
  # Initialize Embedding Model using Google AI Studio
147
+ embeddings = GoogleGenerativeAIEmbeddings(model="gemini-embedding-001")
148
 
149
+ # Connect to Qdrant Cloud
150
+ client = QdrantClient(
151
+ url=os.getenv("QDRANT_URL"),
152
+ api_key=os.getenv("QDRANT_API_KEY"),
153
+ )
 
 
 
 
154
 
 
 
 
 
155
  return QdrantVectorStore(
156
  client=client,
157
  collection_name="transactions",
 
171
  top_k: Number of results to return (default 5).
172
  """
173
  try:
174
+ user_id = get_current_user_id()
175
  vector_store = get_vector_store()
176
 
177
+ # Apply strict multi-tenant filtering based on the payload we injected in money_rag.py
178
+ from qdrant_client.http import models
179
+ filter = models.Filter(
180
+ must=[models.FieldCondition(key="metadata.user_id", match=models.MatchValue(value=user_id))]
181
+ )
182
 
183
+ results = vector_store.similarity_search(query, k=top_k, filter=filter)
184
 
185
  if not results:
186
  return "No matching transactions found."
187
 
188
  output = []
189
  for doc in results:
 
190
  amount = doc.metadata.get('amount', 'N/A')
191
  date = doc.metadata.get('transaction_date', 'N/A')
192
  output.append(f"Date: {date} | Match: {doc.page_content} | Amount: {amount}")
 
198
 
199
 
200
  @mcp.tool()
201
+ def generate_interactive_chart(sql_query: str, chart_type: str, x_col: str, y_col: str, title: str, color_col: Optional[str] = None) -> str:
202
  """
203
+ Generate an interactive Plotly chart using SQL data.
204
+ IMPORTANT: The table name MUST be "Transaction" exactly with quotes.
205
+
 
 
 
 
 
206
  Args:
207
+ sql_query: The SQL SELECT query to retrieve the data for the chart from the "Transaction" table.
208
+ - Must use 'user_id' filter.
209
+ chart_type: The type of chart: 'bar', 'line', 'pie', 'scatter'
210
+ x_col: The name of the column to use for the X axis (or labels for pie charts)
211
+ y_col: The name of the column to use for the Y axis (or values for pie charts)
212
+ title: The title of the chart
213
+ color_col: (Optional) Column to use for color grouping
214
+
215
+ Returns:
216
+ A natural language summary confirming chart generation.
217
  """
218
  try:
219
+ user_id = get_current_user_id()
220
+ if user_id not in sql_query:
221
+ return f'{{"error": "You forgot the WHERE user_id = \\"{user_id}\\" security clause!"}}'
222
+
223
+ conn = get_db_connection()
224
  df = pd.read_sql_query(sql_query, conn)
225
  conn.close()
226
  if df.empty:
 
244
  return f'{{"error": "Failed to generate chart: {str(e)}"}}'
245
 
246
 
247
+
 
 
 
 
 
 
 
 
 
 
248
 
249
  if __name__ == "__main__":
250
  # Runs the server over stdio
money_rag.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import uuid
3
  import asyncio
4
  import pandas as pd
@@ -21,16 +22,34 @@ from langgraph.checkpoint.memory import InMemorySaver
21
  from langchain.agents import create_agent
22
  from langchain_community.tools import DuckDuckGoSearchRun
23
  from langchain_mcp_adapters.client import MultiServerMCPClient
 
24
 
25
  # Import specific embeddings
26
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
27
  from langchain_openai import OpenAIEmbeddings
28
 
 
 
 
 
 
29
  class MoneyRAG:
30
- def __init__(self, llm_provider: str, model_name: str, embedding_model_name: str, api_key: str):
31
  self.llm_provider = llm_provider.lower()
32
  self.model_name = model_name
33
  self.embedding_model_name = embedding_model_name
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  # Set API Keys
36
  if self.llm_provider == "google":
@@ -60,17 +79,18 @@ class MoneyRAG:
60
  self.mcp_client: Optional[MultiServerMCPClient] = None
61
  self.search_tool = DuckDuckGoSearchRun()
62
  self.merchant_cache = {} # Session-based cache for merchant enrichment
 
63
 
64
- async def setup_session(self, csv_paths: List[str]):
65
  """Ingests CSVs and sets up DBs."""
66
- for path in csv_paths:
67
- await self._ingest_csv(path)
 
68
 
69
  self.db = SQLDatabase.from_uri(f"sqlite:///{self.db_path}")
70
  self.vector_store = self._sync_to_qdrant()
71
- await self._init_agent()
72
 
73
- async def _ingest_csv(self, file_path):
74
  df = pd.read_csv(file_path)
75
  headers = df.columns.tolist()
76
  sample_data = df.head(10).to_json()
@@ -108,14 +128,16 @@ class MoneyRAG:
108
  mapping = await chain.ainvoke({"headers": headers, "sample": sample_data, "filename": os.path.basename(file_path)})
109
 
110
  standard_df = pd.DataFrame()
111
- standard_df['id'] = [str(uuid.uuid4()) for _ in range(len(df))]
112
- standard_df['transaction_date'] = pd.to_datetime(df[mapping['date_col']])
 
113
  standard_df['description'] = df[mapping['desc_col']]
 
 
114
 
115
  raw_amounts = pd.to_numeric(df[mapping['amount_col']])
116
  standard_df['amount'] = raw_amounts * -1 if mapping['sign_convention'] == "spending_is_negative" else raw_amounts
117
  standard_df['category'] = df[mapping.get('category_col')] if mapping.get('category_col') else 'Uncategorized'
118
- standard_df['source_file'] = os.path.basename(file_path)
119
 
120
  # --- Async Enrichment Step ---
121
  print(f" ✨ Enriching descriptions for {os.path.basename(file_path)}...")
@@ -143,29 +165,49 @@ class MoneyRAG:
143
  desc_map = dict(zip(unique_descriptions, enrichment_results))
144
  standard_df['enriched_info'] = standard_df['description'].map(desc_map).fillna("")
145
 
146
- conn = sqlite3.connect(self.db_path)
147
- standard_df.to_sql("transactions", conn, if_exists="append", index=False)
148
- conn.close()
 
 
 
 
 
 
 
149
 
150
  def _sync_to_qdrant(self):
151
- client = QdrantClient(path=self.qdrant_path)
 
 
 
 
152
  collection = "transactions"
153
 
154
- conn = sqlite3.connect(self.db_path)
155
- df = pd.read_sql_query("SELECT * FROM transactions", conn)
156
- conn.close()
157
 
158
  # Check for empty dataframe
159
  if df.empty:
160
- raise ValueError("No transactions found in database. Please ingest CSV files first.")
161
 
162
  # Dynamically detect embedding dimension
163
  sample_embedding = self.embeddings.embed_query("test")
164
  embedding_dim = len(sample_embedding)
165
 
166
- client.recreate_collection(
 
 
 
 
 
 
 
 
167
  collection_name=collection,
168
- vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
 
169
  )
170
 
171
  vs = QdrantVectorStore(client=client, collection_name=collection, embedding=self.embeddings)
@@ -180,90 +222,124 @@ class MoneyRAG:
180
  else:
181
  texts.append(base_text)
182
 
183
- metadatas = df[['id', 'amount', 'category', 'transaction_date']].to_dict('records')
184
- for m in metadatas: m['transaction_date'] = str(m['transaction_date'])
 
 
 
185
 
186
- vs.add_texts(texts=texts, metadatas=metadatas)
 
 
 
 
 
 
 
187
  return vs
188
 
189
- async def _init_agent(self):
190
- # 1. Initialize MCP client with absolute path to server
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  server_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mcp_server.py")
192
 
193
- self.mcp_client = MultiServerMCPClient(
194
  {
195
  "money_rag": {
196
  "transport": "stdio",
197
- "command": "python",
198
  "args": [server_path],
199
- "env": os.environ.copy(),
200
  }
201
  }
202
  )
203
 
204
- # 2. Get tools from MCP server
205
- mcp_tools = await self.mcp_client.get_tools()
 
206
 
207
- # 3. Define the Agent with MCP Tools
208
- system_prompt = (
209
- "You are a financial analyst. Use the provided tools to query the database "
210
- "and perform semantic searches. Spending is POSITIVE (>0). "
211
- "Always explain your findings clearly."
212
- "IMPORTANT: Whenever possible and relevant (e.g. when discussing trends, comparing categories, or showing breakdowns), "
213
- "you MUST proactively use the 'generate_interactive_chart' tool to generate visual plots (bar, pie, or line charts) to accompany your analysis. "
214
- "WARNING: You MUST use the actual tool call to generate the chart. DO NOT simply output a json block with chart parameters as your final text answer."
215
- )
216
-
217
- self.agent = create_agent(
218
- model=self.llm,
219
- tools=mcp_tools,
220
- system_prompt=system_prompt,
221
- checkpointer=InMemorySaver(),
222
- )
223
 
224
- async def chat(self, query: str):
225
- config = {"configurable": {"thread_id": "session_1"}}
226
-
227
- # Clear out any previous chart so we don't carry over stale plots
228
- chart_path = os.path.join(self.temp_dir, "latest_chart.json")
229
- if os.path.exists(chart_path):
230
- os.remove(chart_path)
231
-
232
- result = await self.agent.ainvoke(
233
- {"messages": [{"role": "user", "content": query}]},
234
- config,
235
- )
236
-
237
- # Extract content - handle both string and list formats
238
- content = result["messages"][-1].content
239
-
240
- # If content is a list (Gemini format), extract text from blocks
241
- if isinstance(content, list):
242
- text_parts = []
243
- for block in content:
244
- if isinstance(block, dict) and block.get("type") == "text":
245
- text_parts.append(block.get("text", ""))
246
- final_text = "\n".join(text_parts)
247
- else:
248
- final_text = content
249
 
250
- # Check if the tool generated a chart file on disk during this turn
251
- chart_path = os.path.join(self.temp_dir, "latest_chart.json")
252
- if os.path.exists(chart_path):
253
- with open(chart_path, "r") as f:
254
- chart_json = f.read()
255
- final_text += f"\n\n===CHART===\n{chart_json}\n===ENDCHART==="
256
 
257
- return final_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  async def cleanup(self):
260
  """Delete temporary session files and close MCP client."""
261
- if self.mcp_client:
262
- try:
263
- await self.mcp_client.close()
264
- except Exception as e:
265
- print(f"Warning: Failed to close MCP client: {e}")
266
-
267
  if os.path.exists(self.temp_dir):
268
  try:
269
  shutil.rmtree(self.temp_dir)
 
1
  import os
2
+ import sys
3
  import uuid
4
  import asyncio
5
  import pandas as pd
 
22
  from langchain.agents import create_agent
23
  from langchain_community.tools import DuckDuckGoSearchRun
24
  from langchain_mcp_adapters.client import MultiServerMCPClient
25
+ from qdrant_client.http import models as qdrant_models
26
 
27
  # Import specific embeddings
28
  from langchain_google_genai import GoogleGenerativeAIEmbeddings
29
  from langchain_openai import OpenAIEmbeddings
30
 
31
+ from supabase import create_client, ClientOptions
32
+
33
+ from dotenv import load_dotenv
34
+ load_dotenv()
35
+
36
  class MoneyRAG:
37
+ def __init__(self, llm_provider: str, model_name: str, embedding_model_name: str, api_key: str, user_id: str, access_token: str = None):
38
  self.llm_provider = llm_provider.lower()
39
  self.model_name = model_name
40
  self.embedding_model_name = embedding_model_name
41
+ self.user_id = user_id
42
+
43
+ # Initialize Supabase Client
44
+ url = os.environ.get("SUPABASE_URL")
45
+ key = os.environ.get("SUPABASE_KEY")
46
+
47
+ # Security: Inject the logged-in user's JWT so RLS policies pass!
48
+ if access_token:
49
+ opts = ClientOptions(headers={"Authorization": f"Bearer {access_token}"})
50
+ self.supabase = create_client(url, key, options=opts)
51
+ else:
52
+ self.supabase = create_client(url, key)
53
 
54
  # Set API Keys
55
  if self.llm_provider == "google":
 
79
  self.mcp_client: Optional[MultiServerMCPClient] = None
80
  self.search_tool = DuckDuckGoSearchRun()
81
  self.merchant_cache = {} # Session-based cache for merchant enrichment
82
+ self.memory = InMemorySaver() # Session-based cache for chat memory
83
 
84
+ async def setup_session(self, csv_files: List[dict]):
85
  """Ingests CSVs and sets up DBs."""
86
+ # csv_files format: [{"path": "/temp/file.csv", "csv_id": "uuid"}, ...]
87
+ for file_info in csv_files:
88
+ await self._ingest_csv(file_info["path"], file_info.get("csv_id"))
89
 
90
  self.db = SQLDatabase.from_uri(f"sqlite:///{self.db_path}")
91
  self.vector_store = self._sync_to_qdrant()
 
92
 
93
+ async def _ingest_csv(self, file_path, csv_id=None):
94
  df = pd.read_csv(file_path)
95
  headers = df.columns.tolist()
96
  sample_data = df.head(10).to_json()
 
128
  mapping = await chain.ainvoke({"headers": headers, "sample": sample_data, "filename": os.path.basename(file_path)})
129
 
130
  standard_df = pd.DataFrame()
131
+ standard_df['trans_date'] = pd.to_datetime(df[mapping['date_col']]).dt.strftime('%Y-%m-%d')
132
+ # Assign user_id AFTER trans_date establishes the DataFrame length, or else it defaults to NaN!
133
+ standard_df['user_id'] = self.user_id
134
  standard_df['description'] = df[mapping['desc_col']]
135
+ if csv_id:
136
+ standard_df['source_csv_id'] = csv_id
137
 
138
  raw_amounts = pd.to_numeric(df[mapping['amount_col']])
139
  standard_df['amount'] = raw_amounts * -1 if mapping['sign_convention'] == "spending_is_negative" else raw_amounts
140
  standard_df['category'] = df[mapping.get('category_col')] if mapping.get('category_col') else 'Uncategorized'
 
141
 
142
  # --- Async Enrichment Step ---
143
  print(f" ✨ Enriching descriptions for {os.path.basename(file_path)}...")
 
165
  desc_map = dict(zip(unique_descriptions, enrichment_results))
166
  standard_df['enriched_info'] = standard_df['description'].map(desc_map).fillna("")
167
 
168
+ # Save to Supabase transactions table instead of local SQLite
169
+ # Use simplejson roundtrip to guarantee all Pandas NaNs, NaTs, and weird floats become strict JSON nulls
170
+ import json
171
+ records = json.loads(standard_df.to_json(orient='records'))
172
+
173
+ batch_size = 100
174
+ for i in range(0, len(records), batch_size):
175
+ batch = records[i:i + batch_size]
176
+ # If insertion fails, it raises an exception so Streamlit surfaces the error
177
+ self.supabase.table("Transaction").insert(batch).execute()
178
 
179
  def _sync_to_qdrant(self):
180
+ # client = QdrantClient(path=self.qdrant_path)
181
+ client = QdrantClient(
182
+ url=os.getenv("QDRANT_URL"),
183
+ api_key=os.getenv("QDRANT_API_KEY"),
184
+ )
185
  collection = "transactions"
186
 
187
+ # Fetch only THIS USER'S transactions from Supabase to sync into VectorDB
188
+ res = self.supabase.table("Transaction").select("*").eq("user_id", self.user_id).execute()
189
+ df = pd.DataFrame(res.data)
190
 
191
  # Check for empty dataframe
192
  if df.empty:
193
+ raise ValueError("No transactions found in database for this user. Please upload files first.")
194
 
195
  # Dynamically detect embedding dimension
196
  sample_embedding = self.embeddings.embed_query("test")
197
  embedding_dim = len(sample_embedding)
198
 
199
+ # Safely create the collection only if it doesn't already exist to preserve multi-tenant pool
200
+ if not client.collection_exists(collection):
201
+ client.create_collection(
202
+ collection_name=collection,
203
+ vectors_config=qdrant_models.VectorParams(size=embedding_dim, distance=qdrant_models.Distance.COSINE),
204
+ )
205
+
206
+ # Security: Create a strict Payload Index on the user_id field so we can filter by it securely!
207
+ client.create_payload_index(
208
  collection_name=collection,
209
+ field_name="metadata.user_id",
210
+ field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
211
  )
212
 
213
  vs = QdrantVectorStore(client=client, collection_name=collection, embedding=self.embeddings)
 
222
  else:
223
  texts.append(base_text)
224
 
225
+ # Inject critical user_id payload to Qdrant so we can filter on it during retrieval
226
+ metadatas = df[['id', 'amount', 'category', 'trans_date']].copy()
227
+ if 'source_csv_id' in df.columns:
228
+ metadatas['source_csv_id'] = df['source_csv_id']
229
+ metadatas = metadatas.to_dict('records')
230
 
231
+ vector_ids = []
232
+ for m in metadatas:
233
+ vector_ids.append(str(m['id'])) # Keep original Postgres UUID as Vector ID to prevent duplication
234
+ m['user_id'] = self.user_id # Secure payload identifier
235
+ m['transaction_date'] = str(m['trans_date']) # Rename for agent consistency
236
+ del m['trans_date']
237
+
238
+ vs.add_texts(texts=texts, metadatas=metadatas, ids=vector_ids)
239
  return vs
240
 
241
+ async def delete_file(self, csv_id: str):
242
+ """Force delete a file and all its transactions from Postgres and Qdrant."""
243
+ try:
244
+ # 1. Delete from Postgres (Transactions cascade automatically if foreign keyed... but we'll manually ensure they wipe just in case)
245
+ self.supabase.table("Transaction").delete().eq("source_csv_id", csv_id).execute()
246
+ self.supabase.table("CSVFile").delete().eq("id", csv_id).execute()
247
+
248
+ # 2. Delete from Qdrant via payload filter
249
+ client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
250
+ client.delete(
251
+ collection_name="transactions",
252
+ points_selector=qdrant_models.Filter(
253
+ must=[
254
+ qdrant_models.FieldCondition(
255
+ key="metadata.source_csv_id",
256
+ match=qdrant_models.MatchValue(value=csv_id)
257
+ )
258
+ ]
259
+ )
260
+ )
261
+ except Exception as e:
262
+ print(f"Error purging file data: {e}")
263
+
264
    async def chat(self, query: str):
        """Run one chat turn against the agent and return the final text.

        Spawns a fresh MCP tool subprocess for this turn (with this user's id
        injected into its environment so every tool is tenant-scoped), builds
        a LangGraph agent over the MCP tools, invokes it, and tears the
        subprocess down again in ``finally``. If a chart file was produced by
        a tool during the turn, its JSON is appended between ===CHART===
        markers for the UI to extract.
        """
        # 1. Initialize MCP client dynamically to guarantee fresh bindings
        server_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mcp_server.py")

        # sys.executable (not bare "python") so the subprocess uses the same
        # interpreter/venv as the parent. CURRENT_USER_ID is how the MCP
        # server learns which tenant's data it may touch.
        mcp_client = MultiServerMCPClient(
            {
                "money_rag": {
                    "transport": "stdio",
                    "command": sys.executable,
                    "args": [server_path],
                    "env": {**os.environ.copy(), "CURRENT_USER_ID": self.user_id},
                }
            }
        )

        try:
            # 2. Extract tools from the safely established subprocess
            mcp_tools = await mcp_client.get_tools()

            # 3. Create the LangGraph agent for this turn, preserving historical memory cache
            system_prompt = (
                "You are a financial analyst. Use the provided tools to query the database "
                "and perform semantic searches. Spending is POSITIVE (>0). "
                "Always explain your findings clearly."
                "IMPORTANT: Whenever possible and relevant (e.g. when discussing trends, comparing categories, or showing breakdowns), "
                "you MUST proactively use the 'generate_interactive_chart' tool to generate visual plots (bar, pie, or line charts) to accompany your analysis. "
                "WARNING: You MUST use the actual tool call to generate the chart. DO NOT simply output a json block with chart parameters as your final text answer."
            )

            # self.memory (InMemorySaver) persists across turns, so history
            # survives even though the agent object is rebuilt each call.
            agent = create_agent(
                model=self.llm,
                tools=mcp_tools,
                system_prompt=system_prompt,
                checkpointer=self.memory,
            )

            # Fixed thread id: one conversation thread per MoneyRAG session.
            config = {"configurable": {"thread_id": "session_1"}}

            # Clear out any previous chart so we don't carry over stale plots
            chart_path = os.path.join(self.temp_dir, "latest_chart.json")
            if os.path.exists(chart_path):
                os.remove(chart_path)

            # 4. Invoke the agent against the LLM, triggering our nested Tools locally
            result = await agent.ainvoke(
                {"messages": [{"role": "user", "content": query}]},
                config,
            )

            # Extract content - handle both string and list formats
            content = result["messages"][-1].content

            # If content is a list (Gemini format), extract text from blocks
            if isinstance(content, list):
                text_parts = []
                for block in content:
                    if isinstance(block, dict) and block.get("type") == "text":
                        text_parts.append(block.get("text", ""))
                final_text = "\n".join(text_parts)
            else:
                final_text = content

            # Check for generated chart; the chart tool writes this file as a
            # side effect (there is no structured channel back from the tool).
            if os.path.exists(chart_path):
                with open(chart_path, "r") as f:
                    chart_json = f.read()
                return f"{final_text}\n\n===CHART===\n{chart_json}\n===ENDCHART==="

            return final_text

        finally:
            # 5. Destroy the subprocess safely so we don't leak FastMCP zombies across Streamlit reruns
            try:
                await mcp_client.close()
            except Exception as close_e:
                print(f"Warning on closing MCP Client: {close_e}")
340
 
341
  async def cleanup(self):
342
  """Delete temporary session files and close MCP client."""
 
 
 
 
 
 
343
  if os.path.exists(self.temp_dir):
344
  try:
345
  shutil.rmtree(self.temp_dir)
requirements.txt CHANGED
@@ -41,3 +41,9 @@ tenacity>=9.1.2
41
 
42
  streamlit>=1.53.0
43
  ddgs>=9.10.0
 
 
 
 
 
 
 
41
 
42
  streamlit>=1.53.0
43
  ddgs>=9.10.0
44
+
45
+ supabase>=2.28.0
46
+ plotly>=6.5.2
47
+
48
+ psycopg2-binary>=2.9.11
49
+ extra-streamlit-components