Sarat Kannan committed on
Commit
b69a231
·
unverified ·
1 Parent(s): ec44df3

Add files via upload

Browse files
app.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import streamlit as st
5
+ from dotenv import load_dotenv
6
+
7
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
8
+
9
+ from orchestrator.settings import Settings
10
+ from orchestrator.factories import get_llm
11
+ from orchestrator.sql_agent import sql_answer
12
+ from orchestrator.graph_agent import graph_answer
13
+ from orchestrator.tools import run_tools_once
14
+ from orchestrator.graphs import build_router_graph, build_tools_agent_graph
15
+
16
# Load variables from a local .env file into the process environment so the
# os.getenv() calls further down can see them.
load_dotenv()

# Configure the Streamlit page (browser-tab title, icon, wide layout).
st.set_page_config(page_title="Multi-Agent Orchestration (LangGraph)", page_icon="🧭", layout="wide")
19
+
20
+
21
+ def _dict_messages_to_lc(messages: list[dict]) -> list[BaseMessage]:
22
+ out: list[BaseMessage] = []
23
+ for m in messages:
24
+ role = m.get("role")
25
+ content = m.get("content", "")
26
+ if role == "user":
27
+ out.append(HumanMessage(content=content))
28
+ else:
29
+ out.append(AIMessage(content=content))
30
+ return out
31
+
32
+
33
+ def _extract_tool_names_from_messages(messages: list[BaseMessage]) -> list[str]:
34
+ names: list[str] = []
35
+ for m in messages:
36
+ if isinstance(m, AIMessage):
37
+ tool_calls = getattr(m, "tool_calls", None) or []
38
+ for tc in tool_calls:
39
+ if isinstance(tc, dict):
40
+ n = tc.get("name")
41
+ else:
42
+ n = getattr(tc, "name", None)
43
+ if n:
44
+ names.append(str(n))
45
+ deduped: list[str] = []
46
+ for n in names:
47
+ if n not in deduped:
48
+ deduped.append(n)
49
+ return deduped
50
+
51
+
52
+ def _rewrite_followup_to_standalone(settings: Settings, chat_messages: list[dict], question: str) -> str:
53
+ """
54
+ Used in the *direct* SQL/Graph pages to make follow-ups work better.
55
+ Router graph already does this internally.
56
+ """
57
+ user_count = sum(1 for m in chat_messages if m.get("role") == "user")
58
+ if user_count <= 1:
59
+ return question
60
+
61
+ llm = get_llm(settings, temperature=0)
62
+
63
+ # Build a short transcript
64
+ recent = chat_messages[-12:]
65
+ lines = []
66
+ for m in recent:
67
+ if m.get("role") == "user":
68
+ lines.append(f"User: {m.get('content','')}")
69
+ else:
70
+ lines.append(f"Assistant: {m.get('content','')}")
71
+ transcript = "\n".join(lines)
72
+
73
+ prompt = (
74
+ "Rewrite the user's latest question into a standalone question.\n"
75
+ "Do NOT answer the question.\n\n"
76
+ f"Conversation:\n{transcript}\n\n"
77
+ f"Latest user question:\n{question}\n\n"
78
+ "Standalone question:"
79
+ )
80
+
81
+ msg = llm.invoke(
82
+ [
83
+ SystemMessage(content="You rewrite follow-up questions into standalone questions."),
84
+ HumanMessage(content=prompt),
85
+ ]
86
+ )
87
+ rewritten = (msg.content or "").strip()
88
+ return rewritten or question
89
+
90
+
91
# --- Sidebar ---
st.sidebar.title("🧭 Multi-Agent Orchestration")

page = st.sidebar.radio(
    "Navigation",
    ["Router Chat", "SQL Agent", "Graph Agent", "Tools Agent", "Settings"],
    index=0,
)

# Runtime settings overrides (UI -> env-like)
st.sidebar.subheader("Model")
MODEL_OPTIONS = [
    "llama-3.1-8b-instant",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "moonshotai/kimi-k2-instruct-0905",
    "openai/gpt-oss-120b",
    "qwen/qwen3-32b",
]

# Make sure a model configured through LLM_MODEL is selectable even when it is
# not one of the built-in options.
default_model = os.getenv("LLM_MODEL", "meta-llama/llama-4-maverick-17b-128e-instruct")
if default_model not in MODEL_OPTIONS:
    MODEL_OPTIONS.insert(0, default_model)

llm_model = st.sidebar.selectbox("LLM_MODEL", MODEL_OPTIONS, index=MODEL_OPTIONS.index(default_model))

st.sidebar.subheader("SQL (SQLite)")
sqlite_path = st.sidebar.text_input("SQLITE_PATH", value=os.getenv("SQLITE_PATH", "student.db"))

st.sidebar.subheader("Neo4j (Graph DB)")
neo4j_uri = st.sidebar.text_input("NEO4J_URI", value=os.getenv("NEO4J_URI", ""))
neo4j_username = st.sidebar.text_input("NEO4J_USERNAME", value=os.getenv("NEO4J_USERNAME", ""))
neo4j_password = st.sidebar.text_input("NEO4J_PASSWORD", value=os.getenv("NEO4J_PASSWORD", ""), type="password")

st.sidebar.subheader("UI")
show_routing = st.sidebar.checkbox("Show routed agent", value=True)
show_tools_used = st.sidebar.checkbox("Show tools used", value=True)

# Effective settings for this script run: sidebar values override the environment.
settings = Settings(
    groq_api_key=os.getenv("GROQ_API_KEY", ""),
    llm_model=llm_model,
    sqlite_path=sqlite_path,
    neo4j_uri=neo4j_uri,
    neo4j_username=neo4j_username,
    neo4j_password=neo4j_password,
    wiki_doc_content_chars_max=int(os.getenv("WIKI_DOC_CHARS", "2000")),
    # Case-insensitive so that every spelling accepted by the Settings defaults
    # ("1", "true", "True", "yes", "YES") also enables debug here.
    debug=os.getenv("DEBUG", "0").strip().lower() in ("1", "true", "yes"),
)
140
+
141
+
142
@st.cache_resource
def _build_router_graph_for(
    groq_api_key: str,
    llm_model: str,
    sqlite_path: str,
    neo4j_uri: str,
    neo4j_username: str,
    neo4j_password: str,
    wiki_doc_content_chars_max: int,
    debug: bool,
):
    """Build the router graph; cached per distinct combination of setting values."""
    return build_router_graph(
        Settings(
            groq_api_key=groq_api_key,
            llm_model=llm_model,
            sqlite_path=sqlite_path,
            neo4j_uri=neo4j_uri,
            neo4j_username=neo4j_username,
            neo4j_password=neo4j_password,
            wiki_doc_content_chars_max=wiki_doc_content_chars_max,
            debug=debug,
        )
    )


def _router_graph_cached(model: str):
    """Return the compiled router graph for ``model`` and the current sidebar settings.

    The cache used to be keyed on ``model`` alone although the graph also
    captures the SQLite path, Neo4j credentials and API key from the global
    ``settings``; editing those in the sidebar kept serving the stale graph.
    Every value the graph depends on is now part of the cache key.
    """
    return _build_router_graph_for(
        groq_api_key=settings.groq_api_key,
        llm_model=model,
        sqlite_path=settings.sqlite_path,
        neo4j_uri=settings.neo4j_uri,
        neo4j_username=settings.neo4j_username,
        neo4j_password=settings.neo4j_password,
        wiki_doc_content_chars_max=settings.wiki_doc_content_chars_max,
        debug=settings.debug,
    )
155
+
156
+
157
@st.cache_resource
def _build_tools_graph_for(
    groq_api_key: str,
    llm_model: str,
    sqlite_path: str,
    neo4j_uri: str,
    neo4j_username: str,
    neo4j_password: str,
    wiki_doc_content_chars_max: int,
    debug: bool,
):
    """Build the tools-agent graph; cached per distinct combination of setting values."""
    return build_tools_agent_graph(
        Settings(
            groq_api_key=groq_api_key,
            llm_model=llm_model,
            sqlite_path=sqlite_path,
            neo4j_uri=neo4j_uri,
            neo4j_username=neo4j_username,
            neo4j_password=neo4j_password,
            wiki_doc_content_chars_max=wiki_doc_content_chars_max,
            debug=debug,
        )
    )


def _tools_graph_cached(model: str):
    """Return the compiled tools-agent graph for ``model`` and the current settings.

    All setting values are passed to the cached builder so that a change to any
    of them (not only the model name) produces a fresh graph instead of reusing
    one built from outdated settings.
    """
    return _build_tools_graph_for(
        groq_api_key=settings.groq_api_key,
        llm_model=model,
        sqlite_path=settings.sqlite_path,
        neo4j_uri=settings.neo4j_uri,
        neo4j_username=settings.neo4j_username,
        neo4j_password=settings.neo4j_password,
        wiki_doc_content_chars_max=settings.wiki_doc_content_chars_max,
        debug=settings.debug,
    )
170
+
171
+
172
+ # --- Pages ---
173
# --- Pages ---
# One branch per sidebar navigation entry. Each chat page keeps its own history
# list in st.session_state and re-renders it on every script run.
if page == "Router Chat":
    st.title("🧭 Router Chat (LangGraph)")
    st.write("Multi-turn chat. The router chooses SQL / Graph / Tools / General automatically.")

    if "router_messages" not in st.session_state:
        st.session_state.router_messages = [
            {"role": "assistant", "content": "Hi! Ask a question — I will route it to the right agent."}
        ]

    c1, c2 = st.columns([1, 4])
    with c1:
        if st.button("Reset chat", key="reset_router"):
            st.session_state.router_messages = [
                {"role": "assistant", "content": "Chat reset. Ask a question!"}
            ]
            st.rerun()

    # Re-render history, including the routing/tool captions stored as "meta".
    for m in st.session_state.router_messages:
        with st.chat_message(m["role"]):
            meta = m.get("meta") or {}
            if m["role"] == "assistant" and show_routing and meta.get("route"):
                # Same caption format as the live render below.
                st.caption(f"🧭 Routed to: `{meta['route']}` agent")
            if m["role"] == "assistant" and show_tools_used and meta.get("tools_used"):
                tools_line = ", ".join([f"`{t}`" for t in meta["tools_used"]])
                st.caption(f"🧰 Tools used: {tools_line}")
            st.write(m["content"])

    prompt = st.chat_input("Ask a question...", key="router_chat_input")
    if prompt:
        st.session_state.router_messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        try:
            with st.chat_message("assistant"):
                route_slot = st.empty()
                tools_slot = st.empty()
                answer_slot = st.empty()

                with st.spinner("Thinking..."):
                    graph = _router_graph_cached(settings.llm_model)
                    msgs = _dict_messages_to_lc(st.session_state.router_messages)

                    out = graph.invoke({"messages": msgs})
                    out_msgs = out.get("messages", []) or []

                    last_ai = next((mm for mm in reversed(out_msgs) if isinstance(mm, AIMessage)), None)
                    answer = last_ai.content if last_ai else "(no answer)"

                    dbg = out.get("debug", {}) or {}
                    route = out.get("route") or dbg.get("router_label") or dbg.get("routed_to") or "general"
                    tools_used = dbg.get("tools_used") or []

                # Update same bubble (no jump)
                if show_routing:
                    route_slot.caption(f"🧭 Routed to: `{route}` agent")
                if show_tools_used and tools_used:
                    tools_slot.caption("🧰 Tools used: " + ", ".join([f"`{t}`" for t in tools_used]))
                answer_slot.write(answer)

            # Append to chat history AFTER we have final answer
            st.session_state.router_messages.append(
                {"role": "assistant", "content": answer, "meta": {"route": route, "tools_used": tools_used}}
            )

            with st.expander("Debug (route + steps)"):
                st.write(out.get("debug", {}))
                st.write("Messages produced:", len(out_msgs))

        except Exception as e:
            # Top-level UI boundary: show the failure instead of crashing the page.
            st.error(str(e))

elif page == "SQL Agent":
    st.title("🧮 SQL Agent (Chat)")
    st.write("Multi-turn SQL chat. Good for follow-ups like “now filter by …”")

    # --- Intro: what the DB contains ---
    with st.expander("📌 What's in the SQL database?", expanded=False):
        st.markdown(
            """
The database contains information about **students, courses, enrollments, and attendance**.

- **students**: student_id, name, program, section, year
- **courses**: course_id, course_code, course_name, department, credits
- **enrollments**: student-course enrollment per semester with score and grade
- **attendance**: per-class attendance for each student in each course and semester (present = 1/0)
- **view**: student_performance (avg_score, num_A grades, num_courses per student per semester)

Use this chat for analytics questions like rankings, averages, cohorts, and time/semester filtering.
"""
        )

    # --- Session init ---
    if "sql_messages" not in st.session_state:
        st.session_state.sql_messages = [
            {"role": "assistant", "content": "Ask a question about the student analytics database, or try an example below."}
        ]

    # --- Reset ---
    c1, _ = st.columns([1, 5])
    with c1:
        if st.button("Reset chat", key="reset_sql"):
            st.session_state.sql_messages = [{"role": "assistant", "content": "Chat reset. Ask a SQL question!"}]
            st.rerun()

    # --- Example queries (auto-run) ---
    st.subheader("⚡ Try an example")
    e1, e2, e3 = st.columns(3)

    if e1.button("🏆 Top students (2025-Fall)", use_container_width=True):
        st.session_state.sql_demo_query = (
            "Show the top 10 students by average score in semester 2025-Fall. "
            "Use the student_performance view. Return name, program, avg_score, num_courses, and num_A."
        )

    if e2.button("📉 Lowest scoring course (2025-Fall)", use_container_width=True):
        st.session_state.sql_demo_query = (
            "In 2025-Fall, which course has the lowest average score? "
            "Return course_code, course_name, department, and avg_score."
        )

    if e3.button("🧾 Attendance < 70% (2025-Fall)", use_container_width=True):
        st.session_state.sql_demo_query = (
            "For semester 2025-Fall, show students whose overall attendance is below 70%. "
            "Compute attendance_percent as 100 * AVG(present). "
            "Return student name, program, attendance_percent, and total_classes."
        )

    # pop() consumes the demo query so it runs once, not on every rerun.
    demo_query = st.session_state.pop("sql_demo_query", None)

    # --- Render chat history ---
    for m in st.session_state.sql_messages:
        st.chat_message(m["role"]).write(m["content"])

    # --- Input (manual OR demo) ---
    prompt = st.chat_input("Ask a SQL question...", key="sql_chat_input")
    user_query = prompt or demo_query

    if user_query:
        st.session_state.sql_messages.append({"role": "user", "content": user_query})
        st.chat_message("user").write(user_query)

        try:
            # Create assistant bubble immediately (prevents flicker)
            with st.chat_message("assistant"):
                answer_slot = st.empty()

                with st.spinner("Thinking..."):
                    standalone = _rewrite_followup_to_standalone(
                        settings,
                        st.session_state.sql_messages,
                        user_query,
                    )
                    out = sql_answer(settings, standalone)
                    answer = str(out.get("answer", ""))

                answer_slot.write(answer)

            # Append to history AFTER we have the final answer
            st.session_state.sql_messages.append({"role": "assistant", "content": answer})

            with st.expander("Debug"):
                st.write("Standalone question:", standalone)
                st.json(out)

        except Exception as e:
            st.error(str(e))

elif page == "Graph Agent":
    st.title("🕸️ Graph Agent (Chat)")
    st.write("Multi-turn Cypher/Q&A chat over Neo4j.")

    # --- Explain what graph contains ---
    with st.expander("📌 What's in the Neo4j database?", expanded=False):
        st.markdown(
            """
**Theme:** Hollywood movies.

**Nodes**
- `Movie`: title, tagline, released (year)
- `Person`: name, born (year)

**Relationships**
- `(:Person)-[:ACTED_IN]->(:Movie)`
- `(:Person)-[:DIRECTED]->(:Movie)`
- `(:Person)-[:PRODUCED]->(:Movie)`

**Examples you can ask about**
- Movies: “The Matrix”, “Top Gun”, “Jerry Maguire”
- People: “Tom Cruise”, “Keanu Reeves”, “Tom Hanks”
"""
        )

    with st.expander("🧠 Why Neo4j (graph DB) vs Web Search?", expanded=False):
        st.markdown(
            """
**Neo4j is best for relationship-heavy questions** where you want exact, structured answers:
- “Who co-starred with Tom Cruise the most?”
- “Find actors who worked with both Tom Cruise and Tom Hanks.”
- “Show movies connected to *The Matrix* via shared actors.”

**Web search is best for open-world facts** (news, definitions, anything outside your dataset).
So: Web search = broad; Neo4j = deep structured relationships inside your graph.
"""
        )

    # --- Session init ---
    if "graph_messages" not in st.session_state:
        st.session_state.graph_messages = [
            {"role": "assistant", "content": "Ask a question about the Neo4j movies graph, or try an example below."}
        ]

    # --- Reset button ---
    c1, _ = st.columns([1, 5])
    with c1:
        if st.button("Reset chat", key="reset_graph"):
            st.session_state.graph_messages = [
                {"role": "assistant", "content": "Chat reset. Ask a graph question!"}
            ]
            st.rerun()

    # --- Example queries (auto-run) ---
    st.subheader("⚡ Try an example")
    e1, e2, e3 = st.columns(3)

    if e1.button("🎭 Similar to The Matrix (shared actors)", use_container_width=True):
        st.session_state.graph_demo_query = (
            "Find movies that share at least 2 actors with The Matrix. "
            "Return the movie titles and how many actors are shared."
        )

    if e2.button("🧭 Shortest path: Tom Hanks ↔ Tom Cruise", use_container_width=True):
        st.session_state.graph_demo_query = (
            "Show the shortest connection between Tom Hanks and Tom Cruise."
        )

    if e3.button("🎬 Recommend like Cast Away", use_container_width=True):
        st.session_state.graph_demo_query = (
            "Recommend movies like Cast Away based on shared actor and director, and also name them."
        )

    # pop() consumes the demo query so it runs once, not on every rerun.
    demo_query = st.session_state.pop("graph_demo_query", None)

    # --- Render chat history ---
    for m in st.session_state.graph_messages:
        st.chat_message(m["role"]).write(m["content"])

    # --- Input (manual OR demo) ---
    prompt = st.chat_input("Ask a graph question...", key="graph_chat_input")
    user_query = prompt or demo_query

    if user_query:
        st.session_state.graph_messages.append({"role": "user", "content": user_query})
        st.chat_message("user").write(user_query)

        try:
            # Create assistant bubble immediately (prevents flicker)
            with st.chat_message("assistant"):
                answer_slot = st.empty()

                with st.spinner("Thinking..."):
                    standalone = _rewrite_followup_to_standalone(
                        settings,
                        st.session_state.graph_messages,
                        user_query,
                    )
                    out = graph_answer(settings, standalone)
                    answer = str(out.get("answer", ""))

                answer_slot.write(answer)

            # Append to history AFTER we have the final answer
            st.session_state.graph_messages.append({"role": "assistant", "content": answer})

            with st.expander("Debug (Cypher + results)"):
                st.write("Standalone question:", standalone)
                st.json(out.get("debug", {}))

        except Exception as e:
            st.error(str(e))

elif page == "Tools Agent":
    st.title("🧰 Tools Agent (Chat)")
    st.write("Tool-Assisted Research Chat (Web + Wikipedia + arXiv + Calculator).")

    if "tools_messages" not in st.session_state:
        st.session_state.tools_messages = [{"role": "assistant", "content": "Ask a question — I'll search web/Wikipedia/arXiv and use tools when needed."}]

    c1, _ = st.columns([1, 5])
    with c1:
        if st.button("Reset chat", key="reset_tools"):
            st.session_state.tools_messages = [{"role": "assistant", "content": "Chat reset. Ask a tools question!"}]
            st.rerun()

    # Re-render history; the "tools used" caption is stored with each answer so
    # it survives reruns, as on the Router Chat page.
    for m in st.session_state.tools_messages:
        with st.chat_message(m["role"]):
            meta = m.get("meta") or {}
            if m["role"] == "assistant" and show_tools_used and meta.get("tools_used"):
                st.caption("🧰 Tools used: " + ", ".join([f"`{t}`" for t in meta["tools_used"]]))
            st.write(m["content"])

    prompt = st.chat_input("Ask a tools question...", key="tools_chat_input")
    if prompt:
        st.session_state.tools_messages.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)

        try:
            with st.chat_message("assistant"):
                tools_slot = st.empty()
                answer_slot = st.empty()

                with st.spinner("Thinking..."):
                    tools_graph = _tools_graph_cached(settings.llm_model)
                    msgs = _dict_messages_to_lc(st.session_state.tools_messages)

                    out = tools_graph.invoke({"messages": msgs})
                    out_msgs = out.get("messages", []) or []

                    last_ai = next((mm for mm in reversed(out_msgs) if isinstance(mm, AIMessage)), None)
                    answer = last_ai.content if last_ai else "(no answer)"
                    tools_used = _extract_tool_names_from_messages(out_msgs)

                if show_tools_used and tools_used:
                    tools_slot.caption("🧰 Tools used: " + ", ".join([f"`{t}`" for t in tools_used]))
                answer_slot.write(answer)

            st.session_state.tools_messages.append(
                {"role": "assistant", "content": answer, "meta": {"tools_used": tools_used}}
            )

            with st.expander("Debug (tool messages)"):
                st.write("Tools used:", tools_used)
                st.write("Messages produced:", len(out_msgs))

        except Exception as e:
            st.error(str(e))

    # Optional: keep your old "run once each" tester as a quick health check
    with st.expander("Quick tool health-check (run each tool once)"):
        q = st.text_input("Query for one-shot tools test", key="tools_q_once")
        if st.button("Run one-shot tools", type="secondary"):
            try:
                results = run_tools_once(
                    q,
                    wiki_chars=settings.wiki_doc_content_chars_max,
                )
                for r in results:
                    # Not a nested expander: Streamlit releases that forbid
                    # expanders inside expanders would raise here.
                    st.markdown(f"**{r.tool}**")
                    st.write(r.output)
            except Exception as e:
                st.error(str(e))

else:
    st.title("⚙️ Settings / Health Check")
    st.write("Use this page to confirm your keys and connections.")

    if not settings.groq_api_key:
        st.warning("GROQ_API_KEY is not set. Add it in your environment or .env.")
    else:
        st.success("GROQ_API_KEY is set.")

    st.write("**Current model:**", settings.llm_model)
    st.write("**SQLite path:**", settings.sqlite_path)

    if settings.neo4j_uri:
        st.write("**Neo4j URI:**", settings.neo4j_uri)
    else:
        st.info("Neo4j not configured yet (NEO4J_URI empty). Graph Agent will fail until set.")
orchestrator/__init__.py ADDED
File without changes
orchestrator/factories.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Optional
3
+
4
+ from orchestrator.settings import Settings
5
+
6
def get_llm(settings: Settings, *, model: Optional[str] = None, temperature: float = 0.2):
    """Create a Groq chat model.

    Args:
        settings: Supplies the API key and the default model name.
        model: Optional override; a falsy value selects ``settings.llm_model``.
        temperature: Sampling temperature handed to ``ChatGroq``.

    Raises:
        ValueError: If ``settings.groq_api_key`` is empty.
    """
    # Groq is the only provider wired up here (same as project 1); an OpenAI
    # factory could be added alongside later. The import stays local to the function.
    from langchain_groq import ChatGroq

    chosen_model = model if model else settings.llm_model
    api_key = settings.groq_api_key
    if api_key:
        return ChatGroq(groq_api_key=api_key, model=chosen_model, temperature=temperature)
    raise ValueError("Missing GROQ_API_KEY. Set it in your environment or .env.")
orchestrator/graph_agent.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import Dict, Any, Optional
4
+
5
+ from orchestrator.settings import Settings
6
+ from orchestrator.factories import get_llm
7
+ from langchain_core.prompts import ChatPromptTemplate
8
+ from langchain_core.output_parsers import StrOutputParser
9
+
10
+ try:
11
+ from langchain_community.graphs import Neo4jGraph
12
+ except Exception as e: # pragma: no cover
13
+ Neo4jGraph = None
14
+
15
@dataclass
class GraphAgentDebug:
    """Diagnostic record filled in by ``graph_answer`` and returned under ``"debug"``."""

    # Cypher text produced by the LLM, after cleanup.
    cypher: str = ""
    # Whatever ``graph.query`` returned for that Cypher; stays None if the query never ran.
    raw_results: Any = None
    # ``str()`` of the exception caught in ``graph_answer``; empty on success.
    error: str = ""
20
+
21
def _get_graph(settings: Settings):
    """Open a ``Neo4jGraph`` connection from the Neo4j fields of ``settings``.

    Raises:
        ImportError: If ``Neo4jGraph`` could not be imported at module load.
        ValueError: If the URI, username or password is empty.
    """
    if Neo4jGraph is None:
        raise ImportError("Neo4jGraph not available. Install langchain-community[neo4j] or neo4j driver.")

    uri = settings.neo4j_uri
    user = settings.neo4j_username
    secret = settings.neo4j_password
    if not all((uri, user, secret)):
        raise ValueError("Missing NEO4J_URI/NEO4J_USERNAME/NEO4J_PASSWORD.")

    return Neo4jGraph(url=uri, username=user, password=secret)
31
+
32
+ def graph_answer(settings: Settings, question: str) -> Dict[str, Any]:
33
+ """
34
+ A simple Graph DB agent:
35
+ 1) Get graph schema
36
+ 2) Ask LLM to write Cypher (ONLY the query)
37
+ 3) Execute Cypher
38
+ 4) Ask LLM to produce a final answer grounded in results
39
+ """
40
+ llm = get_llm(settings, temperature=0)
41
+ graph = _get_graph(settings)
42
+
43
+ # schema string
44
+ schema = getattr(graph, "schema", None)
45
+ if callable(schema): # older versions: graph.schema is a function
46
+ schema = schema()
47
+ schema = schema or "Schema not available."
48
+
49
+ cypher_prompt = ChatPromptTemplate.from_template(
50
+ """You are a Neo4j Cypher expert.
51
+ Given the graph schema below, write a Cypher query to answer the user question.
52
+ Return ONLY the Cypher query (no backticks, no explanation).
53
+
54
+ Schema:
55
+ {schema}
56
+
57
+ User question:
58
+ {question}
59
+ """
60
+ )
61
+
62
+ to_cypher = cypher_prompt | llm | StrOutputParser()
63
+
64
+ dbg = GraphAgentDebug()
65
+
66
+ try:
67
+ cypher = (to_cypher.invoke({"schema": schema, "question": question}) or "").strip()
68
+ # Basic cleanup
69
+ cypher = cypher.strip("` ")
70
+ dbg.cypher = cypher
71
+ if not cypher or len(cypher) < 6:
72
+ raise ValueError("LLM did not produce a valid Cypher query.")
73
+
74
+ results = graph.query(cypher)
75
+ dbg.raw_results = results
76
+
77
+ answer_prompt = ChatPromptTemplate.from_template(
78
+ """You are a helpful assistant answering questions using ONLY the database results.
79
+ If results are empty, say you couldn't find relevant rows.
80
+
81
+ User question:
82
+ {question}
83
+
84
+ Cypher results (JSON-like):
85
+ {results}
86
+
87
+ Answer concisely and clearly.
88
+ """
89
+ )
90
+ answer_chain = answer_prompt | llm | StrOutputParser()
91
+ answer = answer_chain.invoke({"question": question, "results": results})
92
+
93
+ return {"answer": answer, "debug": dbg.__dict__, "agent": "graph"}
94
+
95
+ except Exception as e:
96
+ dbg.error = str(e)
97
+ return {
98
+ "answer": "I couldn't query the graph database for that question. Check Neo4j connection/schema and try again.",
99
+ "debug": dbg.__dict__,
100
+ }
orchestrator/graphs.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Annotated, Any, Dict, List, Literal, TypedDict
4
+
5
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
6
+ from langgraph.graph import END, START, StateGraph
7
+ from langgraph.graph.message import add_messages
8
+ from langgraph.prebuilt import ToolNode, tools_condition
9
+
10
+ from orchestrator.factories import get_llm
11
+ from orchestrator.graph_agent import graph_answer
12
+ from orchestrator.settings import Settings
13
+ from orchestrator.sql_agent import sql_answer
14
+ from orchestrator.tools import make_web_wiki_arxiv_tools
15
+
16
+ Route = Literal["sql", "graph", "tools", "general"]
17
+
18
+
19
class RouterState(TypedDict, total=False):
    """State dict passed between the nodes of the graphs built in this module.

    ``total=False`` makes every key optional, so a node may return only the
    keys it updates.
    """

    # Conversation messages; annotated with langgraph's ``add_messages`` so
    # node outputs are combined with the existing list.
    messages: Annotated[list[BaseMessage], add_messages]
    # Route label selected for the current turn (one of the ``Route`` literals).
    route: Route
    # Free-form diagnostics accumulated by the nodes via ``_merge_debug``.
    debug: Dict[str, Any]
23
+
24
+
25
+ def _safe_text(x: Any) -> str:
26
+ if x is None:
27
+ return ""
28
+ return x if isinstance(x, str) else str(x)
29
+
30
+
31
+ def _last_user_text(messages: list[BaseMessage]) -> str:
32
+ for m in reversed(messages):
33
+ if isinstance(m, HumanMessage):
34
+ return _safe_text(m.content).strip()
35
+ return ""
36
+
37
+
38
+ def _messages_to_transcript(messages: list[BaseMessage], max_turns: int = 8) -> str:
39
+ """
40
+ Build a lightweight transcript from the last N Human/AI messages.
41
+ We intentionally skip tool messages to keep prompts stable.
42
+ """
43
+ kept: List[BaseMessage] = []
44
+ for m in reversed(messages):
45
+ if isinstance(m, (HumanMessage, AIMessage)):
46
+ kept.append(m)
47
+ if len(kept) >= max_turns * 2: # ~turns * 2 messages
48
+ break
49
+ kept.reverse()
50
+
51
+ lines: List[str] = []
52
+ for m in kept:
53
+ if isinstance(m, HumanMessage):
54
+ lines.append(f"User: {_safe_text(m.content)}")
55
+ elif isinstance(m, AIMessage):
56
+ lines.append(f"Assistant: {_safe_text(m.content)}")
57
+ return "\n".join(lines).strip()
58
+
59
+
60
+ def _merge_debug(state: RouterState, **kv: Any) -> Dict[str, Any]:
61
+ dbg = dict(state.get("debug") or {})
62
+ for k, v in kv.items():
63
+ if v is not None:
64
+ dbg[k] = v
65
+ return dbg
66
+
67
+
68
+ def _extract_tool_names(messages: list[BaseMessage]) -> List[str]:
69
+ """
70
+ Extract tool names from AIMessage.tool_calls across LangChain variants.
71
+ """
72
+ names: List[str] = []
73
+ for m in messages:
74
+ if isinstance(m, AIMessage):
75
+ tool_calls = getattr(m, "tool_calls", None) or []
76
+ for tc in tool_calls:
77
+ # tc may be dict-like or object-like
78
+ if isinstance(tc, dict):
79
+ n = tc.get("name")
80
+ else:
81
+ n = getattr(tc, "name", None)
82
+ if n:
83
+ names.append(str(n))
84
+ # de-dupe, preserve order
85
+ out: List[str] = []
86
+ for n in names:
87
+ if n not in out:
88
+ out.append(n)
89
+ return out
90
+
91
+
92
def _rewrite_to_standalone(llm, messages: list[BaseMessage]) -> str:
    """Turn a follow-up such as "show them" into a self-contained question.

    Returns ``""`` when there is no user message, the latest user text
    unchanged when it is the only user message, and otherwise the LLM's
    rewrite (falling back to the latest user text if the rewrite is empty).
    """
    latest = _last_user_text(messages)
    if not latest:
        return ""

    # A single user message has no earlier context to resolve against.
    human_count = len([m for m in messages if isinstance(m, HumanMessage)])
    if human_count <= 1:
        return latest

    history = _messages_to_transcript(messages, max_turns=8)
    instruction = "".join(
        [
            "Rewrite the user's latest question into a standalone question.\n",
            "Do NOT answer the question.\n\n",
            "Conversation:\n",
            f"{history}\n\n",
            "Latest user question:\n",
            f"{latest}\n\n",
            "Standalone question:",
        ]
    )
    system = SystemMessage(content="You rewrite follow-up questions into standalone questions.")
    reply = llm.invoke([system, HumanMessage(content=instruction)])

    candidate = _safe_text(getattr(reply, "content", "")).strip()
    return candidate if candidate else latest
123
+
124
+
125
def build_tools_agent_graph(settings: Settings):
    """Compile a two-node assistant/tools graph.

    The ``assistant`` node calls the tool-bound LLM on the message list;
    ``tools_condition`` decides where to go next, and the ``tools`` node
    always hands control back to ``assistant``.
    """
    toolset = make_web_wiki_arxiv_tools(wiki_chars=settings.wiki_doc_content_chars_max)
    tool_llm = get_llm(settings, temperature=0).bind_tools(toolset)

    def assistant(state: RouterState):
        # Return only the new reply; the state's reducer merges it into the history.
        return {"messages": [tool_llm.invoke(state["messages"])]}

    builder = StateGraph(RouterState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(toolset))

    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()
142
+
143
+
144
+ def build_router_graph(settings: Settings):
145
+ tools_graph = build_tools_agent_graph(settings)
146
+ llm_router = get_llm(settings, temperature=0)
147
+
148
+ route_prompt = (
149
+ "You are a router for a multi-agent system.\n"
150
+ "Choose exactly ONE route label from: sql, graph, tools, general.\n\n"
151
+ "Routing rules:\n"
152
+ "- sql: querying a relational database (tables/rows, SQL, students DB, counts, filters).\n"
153
+ "- graph: querying a Neo4j graph database (nodes/relationships, Cypher).\n"
154
+ "- tools: needs external knowledge / searching (Wikipedia/arXiv/web) or tool use.\n"
155
+ "- general: conceptual explanation or chat that doesn't need tools/DB queries.\n\n"
156
+ "Return ONLY the label.\n"
157
+ )
158
+
159
+ def router(state: RouterState):
160
+ msgs = state.get("messages", [])
161
+ q = _last_user_text(msgs)
162
+ transcript = _messages_to_transcript(msgs, max_turns=8)
163
+
164
+ payload = (
165
+ "Conversation transcript:\n"
166
+ f"{transcript}\n\n"
167
+ "Latest user question:\n"
168
+ f"{q}"
169
+ )
170
+
171
+ msg = llm_router.invoke(
172
+ [SystemMessage(content=route_prompt), HumanMessage(content=payload)]
173
+ )
174
+ label = _safe_text(msg.content).strip().lower()
175
+ if label not in ("sql", "graph", "tools", "general"):
176
+ label = "general"
177
+
178
+ dbg = _merge_debug(state, router_label=label, router_raw=msg.content, routed_to=label)
179
+ return {"route": label, "debug": dbg}
180
+
181
+ def sql_node(state: RouterState):
182
+ standalone = _rewrite_to_standalone(llm_router, state["messages"])
183
+ out = sql_answer(settings, standalone)
184
+ dbg = _merge_debug(state, routed_to="sql", sql=out, standalone_question=standalone)
185
+ return {"route": "sql", "messages": [AIMessage(content=str(out["answer"]))], "debug": dbg}
186
+
187
+ def graph_node(state: RouterState):
188
+ standalone = _rewrite_to_standalone(llm_router, state["messages"])
189
+ out = graph_answer(settings, standalone)
190
+ dbg = _merge_debug(state, routed_to="graph", graph=out.get("debug", {}), standalone_question=standalone)
191
+ return {"route": "graph", "messages": [AIMessage(content=str(out["answer"]))], "debug": dbg}
192
+
193
+ def tools_node(state: RouterState):
194
+ out_state = tools_graph.invoke({"messages": state["messages"]})
195
+ out_msgs = out_state.get("messages", [])
196
+ tools_used = _extract_tool_names(out_msgs)
197
+
198
+ dbg = _merge_debug(
199
+ state,
200
+ routed_to="tools",
201
+ tools_used=tools_used,
202
+ tools_graph={"messages_len": len(out_msgs)},
203
+ )
204
+ return {"route": "tools", "messages": out_msgs, "debug": dbg}
205
+
206
+ def general_node(state: RouterState):
207
+ # Use the conversation itself (not just last message)
208
+ convo = [m for m in state["messages"] if isinstance(m, (HumanMessage, AIMessage))]
209
+ msg = llm_router.invoke([SystemMessage(content="You are a helpful assistant.")] + convo)
210
+ dbg = _merge_debug(state, routed_to="general")
211
+ return {"route": "general", "messages": [AIMessage(content=_safe_text(msg.content))], "debug": dbg}
212
+
213
+ g = StateGraph(RouterState)
214
+ g.add_node("router", router)
215
+ g.add_node("sql", sql_node)
216
+ g.add_node("graph", graph_node)
217
+ g.add_node("tools", tools_node)
218
+ g.add_node("general", general_node)
219
+
220
+ g.add_edge(START, "router")
221
+ g.add_conditional_edges(
222
+ "router",
223
+ lambda s: s["route"],
224
+ {"sql": "sql", "graph": "graph", "tools": "tools", "general": "general"},
225
+ )
226
+ g.add_edge("sql", END)
227
+ g.add_edge("graph", END)
228
+ g.add_edge("tools", END)
229
+ g.add_edge("general", END)
230
+
231
+ return g.compile()
orchestrator/settings.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ import os
4
+
5
from dataclasses import field  # needed for default_factory below


def _env_field(name: str, default: str = "", cast=str):
    """Return a dataclass field whose default is read from the environment.

    The lookup happens each time ``Settings()`` is instantiated rather than
    once at import time, so variables loaded later (for example by
    ``load_dotenv()`` after this module was imported) are still honoured.
    """
    return field(default_factory=lambda: cast(os.getenv(name, default)))


def _parse_flag(raw: str) -> bool:
    """Interpret an environment string as a boolean (1/true/yes, any case)."""
    return raw.strip().lower() in {"1", "true", "yes"}


@dataclass(frozen=True)
class Settings:
    """Immutable application configuration sourced from environment variables.

    Every field can still be overridden explicitly via keyword argument;
    otherwise the value is taken from the environment at construction time.

    Raises:
        ValueError: on construction, if ``WIKI_DOC_CHARS`` is not an integer.
    """

    # LLM
    groq_api_key: str = _env_field("GROQ_API_KEY")
    llm_model: str = _env_field("LLM_MODEL", "meta-llama/llama-4-maverick-17b-128e-instruct")

    # SQL (SQLite by default)
    sqlite_path: str = _env_field("SQLITE_PATH", "student.db")

    # Neo4j Graph DB
    neo4j_uri: str = _env_field("NEO4J_URI")
    neo4j_username: str = _env_field("NEO4J_USERNAME")
    neo4j_password: str = _env_field("NEO4J_PASSWORD")

    # Tool settings
    # wiki_top_k_results: int = int(os.getenv("WIKI_TOP_K", "3"))
    wiki_doc_content_chars_max: int = _env_field("WIKI_DOC_CHARS", "2000", cast=int)

    # Debug
    debug: bool = _env_field("DEBUG", "0", cast=_parse_flag)
orchestrator/sql_agent.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Optional, Dict, Any
5
+ import sqlite3
6
+
7
+ from sqlalchemy import create_engine
8
+
9
+ from orchestrator.settings import Settings
10
+
11
+ from langchain_groq import ChatGroq
12
+ from langchain_community.utilities.sql_database import SQLDatabase
13
+ from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
14
+ from langchain_community.agent_toolkits.sql.base import create_sql_agent
15
+
16
+
17
+ def _resolve_sqlite_path(settings: Settings, db_path: Optional[str] = None) -> Path:
18
+ p = Path(db_path or settings.sqlite_path)
19
+ if not p.is_absolute():
20
+ # project root = parent of orchestrator/
21
+ p = (Path(__file__).resolve().parents[1] / p).resolve()
22
+ return p
23
+
24
+
25
+ def _make_sql_db_readonly(sqlite_path: Path) -> SQLDatabase:
26
+ if not sqlite_path.exists():
27
+ raise FileNotFoundError(
28
+ f"SQLite DB not found at: {sqlite_path}\n"
29
+ f"Fix: put student.db at project root OR set SQLITE_PATH to an absolute path."
30
+ )
31
+
32
+ def _connect():
33
+ return sqlite3.connect(f"file:{sqlite_path.as_posix()}?mode=ro", uri=True)
34
+
35
+ engine = create_engine("sqlite:///", creator=_connect)
36
+ return SQLDatabase(engine)
37
+
38
+
39
def _make_llm(settings: Settings):
    """Build a deterministic (temperature 0) ``ChatGroq`` client.

    The constructor's keyword names differ between ChatGroq releases, so the
    newer spelling is tried first and the older one is used if that raises
    ``TypeError``.
    """
    current_kwargs = {
        "api_key": settings.groq_api_key,
        "model": settings.llm_model,
        "temperature": 0,
    }
    try:
        return ChatGroq(**current_kwargs)
    except TypeError:
        legacy_kwargs = {
            "groq_api_key": settings.groq_api_key,
            "model_name": settings.llm_model,
            "temperature": 0,
        }
        return ChatGroq(**legacy_kwargs)
53
+
54
+
55
def make_sql_agent(settings: Settings, *, db_path: Optional[str] = None):
    """Assemble the SQL agent together with its database handle.

    Args:
        settings: Application configuration (LLM credentials, DB path, debug).
        db_path: Optional override for the SQLite file location.

    Returns:
        A ``(agent, db, sqlite_path)`` tuple, where ``sqlite_path`` is the
        resolved database path as a string.
    """
    llm = _make_llm(settings)

    resolved_path = _resolve_sqlite_path(settings, db_path=db_path)
    database = _make_sql_db_readonly(resolved_path)

    debug_enabled = bool(settings.debug)

    # Explicitly request the tool-calling agent type and bound its
    # iteration count and wall-clock time.
    agent = create_sql_agent(
        llm=llm,
        toolkit=SQLDatabaseToolkit(db=database, llm=llm),
        agent_type="tool-calling",
        handle_parsing_errors=True,
        max_iterations=30,
        max_execution_time=60,
        verbose=debug_enabled,
        return_intermediate_steps=debug_enabled,
    )

    return agent, database, str(resolved_path)
77
+
78
+
79
def sql_answer(settings: Settings, question: str, *, db_path: Optional[str] = None) -> Dict[str, Any]:
    """Answer a natural-language question against the SQLite database.

    Args:
        settings: Application configuration.
        question: The user's question; ``None`` or empty is tolerated by the
            shortcut check.
        db_path: Optional override for the SQLite file location.

    Returns:
        A dict with ``answer``, ``db_path`` and ``agent`` keys, plus
        ``intermediate_steps`` when the agent returned them.

    Raises:
        FileNotFoundError: if the database file does not exist.
    """
    normalized = (question or "").strip().lower()

    # Deterministic shortcut: listing tables only needs the database, so it
    # is handled before any LLM client or agent is constructed.
    shortcut_phrases = ("list the tables", "list tables", "show tables", "what tables")
    if any(phrase in normalized for phrase in shortcut_phrases):
        resolved_path = _resolve_sqlite_path(settings, db_path=db_path)
        db = _make_sql_db_readonly(resolved_path)
        tables = db.get_usable_table_names()
        return {
            "answer": "Tables: " + ", ".join(tables),
            "db_path": str(resolved_path),
            "agent": "sql",
        }

    agent, _db, sqlite_path = make_sql_agent(settings, db_path=db_path)

    # Run agent
    out = agent.invoke({"input": question})

    # Normalize output
    answer = out.get("output") if isinstance(out, dict) else str(out)

    result = {"answer": str(answer), "db_path": sqlite_path, "agent": "sql"}

    # Pass the agent's intermediate steps through when it returned them.
    if isinstance(out, dict) and "intermediate_steps" in out:
        result["intermediate_steps"] = out["intermediate_steps"]

    return result
102
+
103
+
104
+
105
+
106
+
107
+
108
+
109
+ # from __future__ import annotations
110
+
111
+ # from pathlib import Path
112
+ # from typing import Optional, Dict, Any
113
+ # import sqlite3
114
+
115
+ # from sqlalchemy import create_engine
116
+
117
+ # from orchestrator.settings import Settings
118
+ # from orchestrator.factories import get_llm
119
+
120
+ # # --- Imports that vary across LangChain versions ---
121
+ # try:
122
+ # # langchain >= 1.x
123
+ # from langchain.sql_database import SQLDatabase
124
+ # except Exception:
125
+ # # older / community
126
+ # from langchain_community.utilities import SQLDatabase
127
+
128
+ # try:
129
+ # from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
130
+ # except Exception:
131
+ # # older path (rare)
132
+ # from langchain.agents.agent_toolkits import SQLDatabaseToolkit
133
+
134
+ # try:
135
+ # from langchain.agents import create_sql_agent
136
+ # except Exception:
137
+ # from langchain_community.agent_toolkits.sql.base import create_sql_agent
138
+
139
+
140
+ # def _resolve_sqlite_path(settings: Settings) -> Path:
141
+ # """
142
+ # Resolve SQLITE_PATH relative to project root (parent of orchestrator/),
143
+ # so Streamlit's current working directory does not break DB loading.
144
+ # """
145
+ # p = Path(settings.sqlite_path)
146
+ # if not p.is_absolute():
147
+ # p = (Path(__file__).resolve().parents[1] / p).resolve()
148
+ # return p
149
+
150
+
151
+ # def _make_sql_db_readonly(sqlite_path: Path) -> SQLDatabase:
152
+ # """
153
+ # Open SQLite in READ-ONLY mode so a wrong path does NOT create an empty DB file.
154
+ # """
155
+ # if not sqlite_path.exists():
156
+ # raise FileNotFoundError(
157
+ # f"SQLite DB not found at: {sqlite_path}\n"
158
+ # f"Fix: put student.db at the project root OR set SQLITE_PATH to an absolute path."
159
+ # )
160
+
161
+ # def _connect():
162
+ # return sqlite3.connect(f"file:{sqlite_path.as_posix()}?mode=ro", uri=True)
163
+
164
+ # engine = create_engine("sqlite:///", creator=_connect)
165
+ # return SQLDatabase(engine)
166
+
167
+
168
+ # def _create_agent(llm, toolkit, verbose: bool):
169
+ # """
170
+ # Create SQL agent WITHOUT passing kwargs that frequently clash with defaults
171
+ # in langchain-classic AgentExecutor.
172
+ # """
173
+ # # Keep only the safest option; many builds already set other defaults internally.
174
+ # agent_exec_kwargs = {"handle_parsing_errors": True}
175
+
176
+ # # Some versions accept max_iterations/max_execution_time top-level.
177
+ # # Some accept neither.
178
+ # # We try progressively.
179
+ # try:
180
+ # return create_sql_agent(
181
+ # llm=llm,
182
+ # toolkit=toolkit,
183
+ # verbose=verbose,
184
+ # max_iterations=25,
185
+ # max_execution_time=60,
186
+ # agent_executor_kwargs=agent_exec_kwargs,
187
+ # )
188
+ # except TypeError:
189
+ # # Try without time/iteration controls to avoid duplicate kwargs.
190
+ # return create_sql_agent(
191
+ # llm=llm,
192
+ # toolkit=toolkit,
193
+ # verbose=verbose,
194
+ # agent_executor_kwargs=agent_exec_kwargs,
195
+ # )
196
+
197
+
198
+ # def make_sql_agent(settings: Settings, *, db_path: Optional[str] = None):
199
+ # llm = get_llm(settings, temperature=0)
200
+
201
+ # sqlite_path = Path(db_path).expanduser().resolve() if db_path else _resolve_sqlite_path(settings)
202
+ # db = _make_sql_db_readonly(sqlite_path)
203
+ # toolkit = SQLDatabaseToolkit(db=db, llm=llm)
204
+
205
+ # agent = _create_agent(llm, toolkit, verbose=getattr(settings, "debug", False))
206
+ # return agent, db, str(sqlite_path)
207
+
208
+
209
+ # def sql_answer(settings: Settings, question: str, *, db_path: Optional[str] = None) -> Dict[str, Any]:
210
+ # agent, db, sqlite_path = make_sql_agent(settings, db_path=db_path)
211
+
212
+ # # Deterministic shortcut so this never loops.
213
+ # q = (question or "").strip().lower()
214
+ # if any(s in q for s in ["list the tables", "list tables", "show tables", "what tables"]):
215
+ # try:
216
+ # tables = db.get_usable_table_names()
217
+ # except Exception:
218
+ # # fallback for older SQLDatabase implementations
219
+ # tables = []
220
+ # return {
221
+ # "answer": "Tables: " + (", ".join(tables) if tables else "(none found)"),
222
+ # "db_path": sqlite_path,
223
+ # }
224
+
225
+ # # Run agent
226
+ # out = agent.invoke({"input": question})
227
+
228
+ # # Normalize output
229
+ # if isinstance(out, dict):
230
+ # answer = out.get("output") or out.get("answer") or str(out)
231
+ # else:
232
+ # answer = str(out)
233
+
234
+ # return {"answer": answer, "db_path": sqlite_path}
orchestrator/tools.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional
4
+
5
+ from langchain_core.tools import tool
6
+ from langchain_community.utilities import WikipediaAPIWrapper
7
+ from langchain_community.tools import DuckDuckGoSearchRun, WikipediaQueryRun, ArxivQueryRun
8
+
9
+ # --- Calculator tool (safe arithmetic) ---
10
+ import ast
11
+ import operator as op
12
+
13
+ _ALLOWED_OPS = {
14
+ ast.Add: op.add,
15
+ ast.Sub: op.sub,
16
+ ast.Mult: op.mul,
17
+ ast.Div: op.truediv,
18
+ ast.Pow: op.pow,
19
+ ast.USub: op.neg,
20
+ ast.Mod: op.mod,
21
+ ast.FloorDiv: op.floordiv,
22
+ }
23
+
24
+ def _eval_expr(expr: str) -> float:
25
+ """Safely evaluate a basic arithmetic expression."""
26
+ node = ast.parse(expr, mode="eval").body
27
+
28
+ def _eval(n):
29
+ if isinstance(n, ast.Num): # py<3.8
30
+ return n.n
31
+ if isinstance(n, ast.Constant): # py>=3.8
32
+ if isinstance(n.value, (int, float)):
33
+ return n.value
34
+ raise ValueError("Only numbers are allowed.")
35
+ if isinstance(n, ast.BinOp) and type(n.op) in _ALLOWED_OPS:
36
+ return _ALLOWED_OPS[type(n.op)](_eval(n.left), _eval(n.right))
37
+ if isinstance(n, ast.UnaryOp) and type(n.op) in _ALLOWED_OPS:
38
+ return _ALLOWED_OPS[type(n.op)](_eval(n.operand))
39
+ raise ValueError("Only basic arithmetic is allowed.")
40
+
41
+ return float(_eval(node))
42
+
43
@tool
def calculator(expression: str) -> str:
    """Evaluate a math expression. Input must be a plain arithmetic expression (e.g., '12*(3+4)')."""
    # Failures are returned as text instead of raised, so the caller always
    # receives a string.
    try:
        value = _eval_expr(expression)
    except Exception as err:
        return f"[calculator error] {err}"
    return str(value)
50
+
51
+ # --- Web/Wiki/Arxiv tools ---
52
def make_web_wiki_arxiv_tools(*, wiki_k: int = 3, wiki_chars: int = 2000):
    """Build the list of tools: web search, Wikipedia, arXiv and the calculator.

    Args:
        wiki_k: Passed to the Wikipedia wrapper as ``top_k_results``.
        wiki_chars: Passed to the Wikipedia wrapper as ``doc_content_chars_max``.

    Returns:
        ``[web, wiki, arxiv, calculator]`` in that order.
    """
    web_search = DuckDuckGoSearchRun()

    # WikipediaQueryRun is given an explicit api_wrapper carrying the limits.
    wikipedia = WikipediaQueryRun(
        api_wrapper=WikipediaAPIWrapper(
            top_k_results=wiki_k,
            doc_content_chars_max=wiki_chars,
        )
    )

    # ArxivQueryRun is constructed with its defaults.
    arxiv_search = ArxivQueryRun()

    return [web_search, wikipedia, arxiv_search, calculator]
65
+
66
+ # @dataclass
67
+ # class ToolResult:
68
+ # tool: str
69
+ # output: str
70
+
71
@dataclass
class ToolResult:
    """Outcome of running a single tool once."""

    # Name of the tool that produced this result.
    tool: str
    # Stringified tool output, or an error marker when the tool failed.
    output: str
    # False when the tool raised.
    ok: bool = True
    # Exception text for failed runs; None on success.
    error: Optional[str] = None
77
+
78
def run_tools_once(query: str, *, wiki_k: int = 3, wiki_chars: int = 2000) -> List[ToolResult]:
    """Run every tool exactly once on ``query`` without an agent (debug helper).

    A tool that raises does not abort the run; its failure is recorded as a
    ``ToolResult`` with ``ok=False``.
    """
    results: List[ToolResult] = []
    for current in make_web_wiki_arxiv_tools(wiki_k=wiki_k, wiki_chars=wiki_chars):
        try:
            text = str(current.run(query))
        except Exception as err:
            results.append(
                ToolResult(tool=current.name, output=f"[tool error] {err}", ok=False, error=str(err))
            )
            continue
        results.append(ToolResult(tool=current.name, output=text))

    return results
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit>=1.35
2
+ python-dotenv>=1.0
3
+
4
+ # LangChain / LangGraph stack (align with your env)
5
+ langchain>=1.2.0
6
+ langchain-core>=0.3.0
7
+ langchain-community>=0.4.0
8
+ langchain-groq>=0.3.0
9
+ langgraph>=0.2.0
10
+ langchain-neo4j
11
+
12
+ # Tools
13
+ ddgs
14
+ wikipedia>=1.4.0
15
+ arxiv>=2.1.0
16
+
17
+ # SQL
18
+ sqlalchemy>=2.0
19
+
20
+ # Neo4j graph
21
+ neo4j>=5.0
school.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:adb52805ce8c7dc02d5ecc0f104eac5944ac2d234a0cfe04276714e29ea9faf8
3
+ size 1478656
sqlite.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sqlite.py
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ import sqlite3
6
+ import random
7
+ from datetime import date, timedelta
8
+ from pathlib import Path
9
+
10
+
11
# Output database file and RNG seed, both overridable from the environment.
# NOTE(review): this script reads SQLITE_DB, whereas orchestrator/settings.py
# reads SQLITE_PATH (default "student.db") - confirm the two are meant to differ.
DB_NAME = os.getenv("SQLITE_DB", "school.db")
SEED = int(os.getenv("SQLITE_SEED", "42"))

# Scale knobs (keep modest for fast demo)
NUM_STUDENTS = int(os.getenv("NUM_STUDENTS", "120"))
NUM_COURSES = int(os.getenv("NUM_COURSES", "14"))
SEMESTERS = ["2024-Fall", "2025-Spring", "2025-Fall"]  # change freely


# Name pools; a full name is one first name plus one last name.
FIRST_NAMES = (
    "Aarav Vivaan Aditya Vihaan Arjun Sai Reyansh Ishaan Krishna "
    "Ananya Aadhya Diya Ira Meera Saanvi Myra Aarohi Riya "
    "Rahul Kiran Suresh Priya Neha Vikram Nikhil Sneha Pooja"
).split()
LAST_NAMES = "Verma Patel Gupta Mehta Singh Kumar Das Roy Bose Chowdhury".split()

PROGRAMS = [
    "Computer Science",
    "Data Science",
    "AI & ML",
    "Information Systems",
    "Cybersecurity",
]
SECTIONS = list("ABCD")

DEPARTMENTS = ["CS", "DS", "AI", "IS", "CY"]
COURSE_TITLES = [
    "Database Systems",
    "Operating Systems",
    "Computer Networks",
    "Machine Learning",
    "Deep Learning",
    "Data Structures",
    "Algorithms",
    "Cloud Computing",
    "NLP Fundamentals",
    "Information Security",
    "Software Engineering",
    "Data Visualization",
    "MLOps Foundations",
    "Graph Databases",
    "Statistics for Data Science",
    "Ethical AI",
]

# (letter, lowest score, highest score), listed from best to worst.
GRADE_BANDS = [
    ("A", 90, 100),
    ("B", 80, 89),
    ("C", 70, 79),
    ("D", 60, 69),
    ("F", 0, 59),
]
49
+
50
+
51
def make_name(rng: random.Random) -> str:
    """Return a random "First Last" name drawn with the supplied generator."""
    first = rng.choice(FIRST_NAMES)
    last = rng.choice(LAST_NAMES)
    return f"{first} {last}"
53
+
54
+
55
def grade_from_score(score: float, bands=None) -> str:
    """Map a numeric score to its letter grade.

    A score earns the letter of the highest band whose lower bound it
    reaches. Only lower bounds are compared, so fractional scores that fall
    between one band's upper bound and the next band's lower bound
    (e.g. 89.5) are graded correctly.

    Args:
        score: Numeric score.
        bands: Optional iterable of ``(letter, low, high)`` tuples; defaults
            to the module-level ``GRADE_BANDS``.

    Returns:
        The letter grade, or ``"F"`` if the score is below every lower bound.
    """
    active_bands = GRADE_BANDS if bands is None else bands
    # Visit bands from the highest lower bound down, whatever order they are
    # declared in.
    for letter, lower, _upper in sorted(active_bands, key=lambda band: band[1], reverse=True):
        if score >= lower:
            return letter
    return "F"
60
+
61
+
62
def connect(db_path: Path) -> sqlite3.Connection:
    """Open the SQLite database at ``db_path`` and apply session pragmas.

    Foreign-key enforcement is switched on, the journal mode is set to WAL
    and ``synchronous`` is set to NORMAL before the connection is returned.
    """
    connection = sqlite3.connect(str(db_path))
    session_pragmas = (
        "PRAGMA foreign_keys = ON;",
        "PRAGMA journal_mode = WAL;",
        "PRAGMA synchronous = NORMAL;",
    )
    for statement in session_pragmas:
        connection.execute(statement)
    return connection
68
+
69
+
70
def recreate_schema(con: sqlite3.Connection) -> None:
    """Drop and recreate all demo tables and their indexes, then commit.

    Any existing ``attendance``, ``enrollments``, ``courses`` and
    ``students`` tables are dropped, so previously stored rows are lost.
    """
    # Child tables go first so the drops respect the foreign keys.
    drop_sql = """
        DROP TABLE IF EXISTS attendance;
        DROP TABLE IF EXISTS enrollments;
        DROP TABLE IF EXISTS courses;
        DROP TABLE IF EXISTS students;
        """

    create_sql = """
        CREATE TABLE students (
            student_id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            program TEXT NOT NULL,
            section TEXT NOT NULL,
            year INTEGER NOT NULL CHECK (year BETWEEN 1 AND 4)
        );

        CREATE TABLE courses (
            course_id INTEGER PRIMARY KEY AUTOINCREMENT,
            course_code TEXT NOT NULL UNIQUE,
            course_name TEXT NOT NULL,
            department TEXT NOT NULL,
            credits INTEGER NOT NULL CHECK (credits BETWEEN 1 AND 6)
        );

        CREATE TABLE enrollments (
            enrollment_id INTEGER PRIMARY KEY AUTOINCREMENT,
            student_id INTEGER NOT NULL,
            course_id INTEGER NOT NULL,
            semester TEXT NOT NULL,
            score REAL NOT NULL CHECK (score BETWEEN 0 AND 100),
            grade TEXT NOT NULL CHECK (grade IN ('A','B','C','D','F')),
            created_at TEXT NOT NULL DEFAULT (datetime('now')),
            FOREIGN KEY (student_id) REFERENCES students(student_id) ON DELETE CASCADE,
            FOREIGN KEY (course_id) REFERENCES courses(course_id) ON DELETE CASCADE,
            UNIQUE(student_id, course_id, semester)
        );

        CREATE TABLE attendance (
            attendance_id INTEGER PRIMARY KEY AUTOINCREMENT,
            student_id INTEGER NOT NULL,
            course_id INTEGER NOT NULL,
            semester TEXT NOT NULL,
            class_date TEXT NOT NULL,
            present INTEGER NOT NULL CHECK (present IN (0,1)),
            FOREIGN KEY (student_id) REFERENCES students(student_id) ON DELETE CASCADE,
            FOREIGN KEY (course_id) REFERENCES courses(course_id) ON DELETE CASCADE
        );

        CREATE INDEX idx_enrollments_student ON enrollments(student_id);
        CREATE INDEX idx_enrollments_course ON enrollments(course_id);
        CREATE INDEX idx_enrollments_sem ON enrollments(semester);

        CREATE INDEX idx_att_student_course ON attendance(student_id, course_id);
        CREATE INDEX idx_att_semester ON attendance(semester);
        CREATE INDEX idx_att_date ON attendance(class_date);
        """

    cursor = con.cursor()
    for script in (drop_sql, create_sql):
        cursor.executescript(script)

    con.commit()
136
+
137
+
138
def seed_students(con: sqlite3.Connection, rng: random.Random) -> None:
    """Insert ``NUM_STUDENTS`` randomly generated students and commit."""
    # Each tuple is evaluated left to right, so the random draws happen in
    # the order name, program, section, year for every student.
    student_rows = [
        (
            make_name(rng),
            rng.choice(PROGRAMS),
            rng.choice(SECTIONS),
            rng.randint(1, 4),
        )
        for _ in range(NUM_STUDENTS)
    ]
    con.cursor().executemany(
        "INSERT INTO students(name, program, section, year) VALUES (?,?,?,?)",
        student_rows,
    )
    con.commit()
155
+
156
+
157
def seed_courses(con: sqlite3.Connection, rng: random.Random) -> None:
    """Insert up to ``NUM_COURSES`` randomly chosen courses and commit.

    Titles are shuffled and the first ``NUM_COURSES`` are kept. Each gets a
    random department, a code made of the department plus ``100 + position``
    and a credit value drawn from ``[2, 3, 3, 4]``.
    """

    def _course_row(position: int, title: str):
        # The department is drawn before the credits.
        department = rng.choice(DEPARTMENTS)
        course_code = f"{department}{100 + position}"
        credit_hours = rng.choice([2, 3, 3, 4])
        return (course_code, title, department, credit_hours)

    shuffled_titles = COURSE_TITLES[:]
    rng.shuffle(shuffled_titles)
    selected_titles = shuffled_titles[:NUM_COURSES]

    course_rows = [
        _course_row(position, title)
        for position, title in enumerate(selected_titles, start=1)
    ]

    con.cursor().executemany(
        "INSERT INTO courses(course_code, course_name, department, credits) VALUES (?,?,?,?)",
        course_rows,
    )
    con.commit()
175
+
176
+
177
def seed_enrollments_and_attendance(con: sqlite3.Connection, rng: random.Random) -> None:
    """Generate enrollments and per-class attendance for every student, then commit.

    For each semester in ``SEMESTERS`` every student is enrolled in 3-5
    distinct courses (fewer if fewer courses exist), receives a score drawn
    from a Gaussian (mean 78, sigma 10) clamped to 0..100, and gets ten
    weekly attendance records whose presence probability rises with the score.
    """
    cur = con.cursor()

    student_ids = [row[0] for row in cur.execute("SELECT student_id FROM students").fetchall()]
    course_ids = [row[0] for row in cur.execute("SELECT course_id FROM courses").fetchall()]

    enrollment_rows = []
    attendance_rows = []

    # First class date per semester; semesters not listed fall back to
    # 2025-01-01 below.
    sem_start = {
        "2024-Fall": date(2024, 9, 1),
        "2025-Spring": date(2025, 2, 1),
        "2025-Fall": date(2025, 9, 1),
    }

    for sem in SEMESTERS:
        start = sem_start.get(sem, date(2025, 1, 1))
        # Ten weekly class dates.
        class_dates = [(start + timedelta(days=7 * week)).isoformat() for week in range(10)]

        for sid in student_ids:
            # each semester: 3-5 courses, capped at the number of available
            # courses so rng.sample never raises on a small catalogue
            course_count = min(rng.randint(3, 5), len(course_ids))
            for cid in rng.sample(course_ids, k=course_count):
                # score distribution: mostly 60-95
                raw_score = rng.gauss(mu=78, sigma=10)
                score = max(0, min(100, round(raw_score, 1)))
                grade = grade_from_score(score)

                enrollment_rows.append((sid, cid, sem, score, grade))

                # attendance probability correlates loosely with score:
                # higher score => slightly higher attendance, within 0.60..0.98
                p_present = min(0.98, max(0.60, 0.70 + (score - 70) / 100))
                for class_date in class_dates:
                    present = 1 if rng.random() < p_present else 0
                    attendance_rows.append((sid, cid, sem, class_date, present))

    cur.executemany(
        "INSERT OR IGNORE INTO enrollments(student_id, course_id, semester, score, grade) VALUES (?,?,?,?,?)",
        enrollment_rows,
    )
    cur.executemany(
        "INSERT INTO attendance(student_id, course_id, semester, class_date, present) VALUES (?,?,?,?,?)",
        attendance_rows,
    )
    con.commit()
224
+
225
+
226
def create_views(con: sqlite3.Connection) -> None:
    """(Re)create the ``student_performance`` view and commit.

    The view has one row per student and semester, with the average score
    rounded to two decimals, the number of 'A' grades and the number of
    courses taken.
    """
    view_sql = """
        DROP VIEW IF EXISTS student_performance;

        CREATE VIEW student_performance AS
        SELECT
            s.student_id,
            s.name,
            s.program,
            s.section,
            e.semester,
            ROUND(AVG(e.score), 2) AS avg_score,
            SUM(CASE WHEN e.grade = 'A' THEN 1 ELSE 0 END) AS num_A,
            COUNT(*) AS num_courses
        FROM students s
        JOIN enrollments e ON e.student_id = s.student_id
        GROUP BY s.student_id, e.semester;
        """
    con.cursor().executescript(view_sql)
    con.commit()
248
+
249
+
250
def print_summary(con: sqlite3.Connection) -> None:
    """Print table names, row counts and a sample "top students" query.

    The latest semester is chosen chronologically (year first, then term
    within the year) rather than by plain string order, which would rank
    "2025-Spring" after "2025-Fall". An empty ``enrollments`` table is
    reported instead of raising.
    """
    cur = con.cursor()
    tables = cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;").fetchall()
    print("Tables:", [t[0] for t in tables])

    # Table names come from this fixed tuple, never from user input, so
    # formatting them into the statement is safe.
    for t in ("students", "courses", "enrollments", "attendance"):
        n = cur.execute(f"SELECT COUNT(*) FROM {t};").fetchone()[0]
        print(f"{t}: {n}")

    # A couple example queries
    print("\nExample: Top 5 students by avg score (latest semester)")
    latest_row = cur.execute(
        """
        SELECT semester
        FROM enrollments
        ORDER BY
            CAST(substr(semester, 1, 4) AS INTEGER) DESC,
            CASE substr(semester, 6)
                WHEN 'Spring' THEN 1
                WHEN 'Summer' THEN 2
                WHEN 'Fall' THEN 3
                ELSE 0
            END DESC
        LIMIT 1;
        """
    ).fetchone()
    if latest_row is None:
        print("(no enrollments)")
        return

    rows = cur.execute(
        """
        SELECT s.name, s.program, ROUND(AVG(e.score), 2) AS avg_score
        FROM students s
        JOIN enrollments e ON e.student_id = s.student_id
        WHERE e.semester = ?
        GROUP BY s.student_id
        ORDER BY avg_score DESC
        LIMIT 5;
        """,
        (latest_row[0],),
    ).fetchall()
    for r in rows:
        print(r)
276
+
277
+
278
def main() -> None:
    """Build the demo database from scratch and print a short summary.

    The random generator is seeded from ``SEED`` and the database is written
    to ``DB_NAME``, resolved against the current working directory. The
    connection is closed even if a step fails.
    """
    rng = random.Random(SEED)
    target_path = Path(DB_NAME).resolve()

    connection = connect(target_path)
    try:
        recreate_schema(connection)
        # Students and courses must exist before enrollments reference them.
        for seeder in (seed_students, seed_courses, seed_enrollments_and_attendance):
            seeder(connection, rng)
        create_views(connection)
        print(f"Created DB: {target_path}")
        print_summary(connection)
    finally:
        connection.close()


if __name__ == "__main__":
    main()