Sarat Kannan committed on
Add files via upload
Browse files- app.py +534 -0
- orchestrator/__init__.py +0 -0
- orchestrator/factories.py +14 -0
- orchestrator/graph_agent.py +100 -0
- orchestrator/graphs.py +231 -0
- orchestrator/settings.py +24 -0
- orchestrator/sql_agent.py +234 -0
- orchestrator/tools.py +88 -0
- requirements.txt +21 -0
- school.db +3 -0
- sqlite.py +296 -0
app.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import streamlit as st
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
| 8 |
+
|
| 9 |
+
from orchestrator.settings import Settings
|
| 10 |
+
from orchestrator.factories import get_llm
|
| 11 |
+
from orchestrator.sql_agent import sql_answer
|
| 12 |
+
from orchestrator.graph_agent import graph_answer
|
| 13 |
+
from orchestrator.tools import run_tools_once
|
| 14 |
+
from orchestrator.graphs import build_router_graph, build_tools_agent_graph
|
| 15 |
+
|
| 16 |
+
load_dotenv()
|
| 17 |
+
|
| 18 |
+
st.set_page_config(page_title="Multi-Agent Orchestration (LangGraph)", page_icon="🧭", layout="wide")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _dict_messages_to_lc(messages: list[dict]) -> list[BaseMessage]:
    """Convert chat-history dicts ({"role": ..., "content": ...}) into LangChain messages.

    Entries with role "user" become HumanMessage; every other role (including
    a missing one) becomes AIMessage.
    """

    def _to_message(entry: dict) -> BaseMessage:
        text = entry.get("content", "")
        if entry.get("role") == "user":
            return HumanMessage(content=text)
        return AIMessage(content=text)

    return [_to_message(entry) for entry in messages]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _extract_tool_names_from_messages(messages: list[BaseMessage]) -> list[str]:
    """Collect the names of tools invoked by AI messages, first-seen order, no duplicates.

    Tool-call entries may be dicts (newer langchain) or objects with a
    ``name`` attribute (older versions); both shapes are handled. Falsy or
    missing names are skipped.
    """
    # dict preserves insertion order, so this both dedupes and keeps order.
    ordered: dict[str, None] = {}
    for msg in messages:
        if not isinstance(msg, AIMessage):
            continue
        for call in (getattr(msg, "tool_calls", None) or []):
            if isinstance(call, dict):
                name = call.get("name")
            else:
                name = getattr(call, "name", None)
            if name:
                ordered.setdefault(str(name), None)
    return list(ordered)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _rewrite_followup_to_standalone(settings: Settings, chat_messages: list[dict], question: str) -> str:
    """
    Used in the *direct* SQL/Graph pages to make follow-ups work better.
    Router graph already does this internally.
    """
    # First user turn: there is no prior context, so nothing to rewrite.
    if sum(m.get("role") == "user" for m in chat_messages) <= 1:
        return question

    llm = get_llm(settings, temperature=0)

    # Build a short transcript from the most recent turns only.
    transcript = "\n".join(
        f"User: {m.get('content','')}" if m.get("role") == "user" else f"Assistant: {m.get('content','')}"
        for m in chat_messages[-12:]
    )

    prompt = (
        "Rewrite the user's latest question into a standalone question.\n"
        "Do NOT answer the question.\n\n"
        f"Conversation:\n{transcript}\n\n"
        f"Latest user question:\n{question}\n\n"
        "Standalone question:"
    )

    msg = llm.invoke(
        [
            SystemMessage(content="You rewrite follow-up questions into standalone questions."),
            HumanMessage(content=prompt),
        ]
    )
    # Fall back to the original question if the model returned nothing useful.
    rewritten = (msg.content or "").strip()
    return rewritten or question
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# --- Sidebar ---
# Global navigation + runtime configuration. Everything below runs on every
# Streamlit rerun; the resulting `settings` object is rebuilt each time.
st.sidebar.title("🧭 Multi-Agent Orchestration")

page = st.sidebar.radio(
    "Navigation",
    ["Router Chat", "SQL Agent", "Graph Agent", "Tools Agent", "Settings"],
    index=0,
)

# Runtime settings overrides (UI -> env-like)
st.sidebar.subheader("Model")
# llm_model = st.sidebar.text_input("LLM_MODEL (Groq)", value=os.getenv("LLM_MODEL", "llama-3.1-8b-instant"))
# Groq-hosted models offered in the picker.
MODEL_OPTIONS = [
    "llama-3.1-8b-instant",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "moonshotai/kimi-k2-instruct-0905",
    "openai/gpt-oss-120b",
    "qwen/qwen3-32b",
]

# Honor an LLM_MODEL env override even when it is not in the preset list.
default_model = os.getenv("LLM_MODEL", "meta-llama/llama-4-maverick-17b-128e-instruct")
if default_model not in MODEL_OPTIONS:
    MODEL_OPTIONS.insert(0, default_model)

llm_model = st.sidebar.selectbox("LLM_MODEL", MODEL_OPTIONS, index=MODEL_OPTIONS.index(default_model))

st.sidebar.subheader("SQL (SQLite)")
sqlite_path = st.sidebar.text_input("SQLITE_PATH", value=os.getenv("SQLITE_PATH", "student.db"))

st.sidebar.subheader("Neo4j (Graph DB)")
neo4j_uri = st.sidebar.text_input("NEO4J_URI", value=os.getenv("NEO4J_URI", ""))
neo4j_username = st.sidebar.text_input("NEO4J_USERNAME", value=os.getenv("NEO4J_USERNAME", ""))
neo4j_password = st.sidebar.text_input("NEO4J_PASSWORD", value=os.getenv("NEO4J_PASSWORD", ""), type="password")

st.sidebar.subheader("UI")
show_routing = st.sidebar.checkbox("Show routed agent", value=True)
show_tools_used = st.sidebar.checkbox("Show tools used", value=True)

# Consolidated settings object passed to every agent call below.
settings = Settings(
    groq_api_key=os.getenv("GROQ_API_KEY", ""),
    llm_model=llm_model,
    sqlite_path=sqlite_path,
    neo4j_uri=neo4j_uri,
    neo4j_username=neo4j_username,
    neo4j_password=neo4j_password,
    wiki_doc_content_chars_max=int(os.getenv("WIKI_DOC_CHARS", "2000")),
    debug=os.getenv("DEBUG", "0") in ("1", "true", "True"),
)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@st.cache_resource
def _router_graph_cached(model: str):
    """Build (and cache) the router LangGraph for the given model name.

    NOTE(review): the cache key is only ``model``; all other fields are read
    from the module-level ``settings`` at build time, so later sidebar changes
    to e.g. SQLITE_PATH or the Neo4j credentials will NOT rebuild the cached
    graph — confirm whether that is intended.
    """
    # Clone the current sidebar settings, overriding only the model.
    s = Settings(
        groq_api_key=settings.groq_api_key,
        llm_model=model,
        sqlite_path=settings.sqlite_path,
        neo4j_uri=settings.neo4j_uri,
        neo4j_username=settings.neo4j_username,
        neo4j_password=settings.neo4j_password,
        wiki_doc_content_chars_max=settings.wiki_doc_content_chars_max,
        debug=settings.debug,
    )
    return build_router_graph(s)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@st.cache_resource
def _tools_graph_cached(model: str):
    """Build (and cache) the tools-agent LangGraph for the given model name.

    NOTE(review): as with ``_router_graph_cached``, only ``model`` is part of
    the cache key; non-model settings are frozen at first build — verify this
    is acceptable for the Neo4j/SQLite fields.
    """
    # Clone the current sidebar settings, overriding only the model.
    s = Settings(
        groq_api_key=settings.groq_api_key,
        llm_model=model,
        sqlite_path=settings.sqlite_path,
        neo4j_uri=settings.neo4j_uri,
        neo4j_username=settings.neo4j_username,
        neo4j_password=settings.neo4j_password,
        wiki_doc_content_chars_max=settings.wiki_doc_content_chars_max,
        debug=settings.debug,
    )
    return build_tools_agent_graph(s)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# --- Pages ---
if page == "Router Chat":
    # Multi-turn chat routed by the LangGraph router to SQL/Graph/Tools/General.
    st.title("🧭 Router Chat (LangGraph)")
    st.write("Multi-turn chat. The router chooses SQL / Graph / Tools / General automatically.")

    # Seed the conversation once per session.
    if "router_messages" not in st.session_state:
        st.session_state.router_messages = [
            {"role": "assistant", "content": "Hi! Ask a question — I will route it to the right agent."}
        ]

    c1, c2 = st.columns([1, 4])
    with c1:
        if st.button("Reset chat", key="reset_router"):
            st.session_state.router_messages = [
                {"role": "assistant", "content": "Chat reset. Ask a question!"}
            ]
            st.rerun()

    # Replay history; assistant bubbles also show routing/tool metadata.
    for m in st.session_state.router_messages:
        with st.chat_message(m["role"]):
            meta = m.get("meta") or {}
            # NOTE(review): backtick placement here (`route agent`) differs from the
            # live caption below (`route` agent) — confirm which style is intended.
            if m["role"] == "assistant" and show_routing and meta.get("route"):
                st.caption(f"🧭 Routed to: `{meta['route']} agent`")
            if m["role"] == "assistant" and show_tools_used and meta.get("tools_used"):
                tools_line = ", ".join([f"`{t}`" for t in meta["tools_used"]])
                st.caption(f"🧰 Tools used: {tools_line}")
            st.write(m["content"])

    prompt = st.chat_input("Ask a question...", key="router_chat_input")
    if prompt:
        st.session_state.router_messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        try:
            with st.chat_message("assistant"):
                # Placeholders so the captions/answer fill the same bubble in place.
                route_slot = st.empty()
                tools_slot = st.empty()
                answer_slot = st.empty()

                with st.spinner("Thinking..."):
                    graph = _router_graph_cached(settings.llm_model)
                    msgs = _dict_messages_to_lc(st.session_state.router_messages)

                    out = graph.invoke({"messages": msgs})
                    out_msgs = out.get("messages", []) or []

                    # The final AI message in the state is the answer.
                    last_ai = next((mm for mm in reversed(out_msgs) if isinstance(mm, AIMessage)), None)
                    answer = last_ai.content if last_ai else "(no answer)"

                    # Route label may live under several debug keys depending on graph version.
                    dbg = out.get("debug", {}) or {}
                    route = out.get("route") or dbg.get("router_label") or dbg.get("routed_to") or "general"
                    tools_used = dbg.get("tools_used") or []

                # Update same bubble (no jump)
                if show_routing:
                    route_slot.caption(f"🧭 Routed to: `{route}` agent")
                if show_tools_used and tools_used:
                    tools_slot.caption("🧰 Tools used: " + ", ".join([f"`{t}`" for t in tools_used]))
                answer_slot.write(answer)

                # Append to chat history AFTER we have final answer
                st.session_state.router_messages.append(
                    {"role": "assistant", "content": answer, "meta": {"route": route, "tools_used": tools_used}}
                )

                with st.expander("Debug (route + steps)"):
                    st.write(out.get("debug", {}))
                    st.write("Messages produced:", len(out_msgs))

        except Exception as e:
            st.error(str(e))
|
| 244 |
+
|
| 245 |
+
elif page == "SQL Agent":
    # Direct (non-routed) multi-turn chat against the SQLite student database.
    st.title("🧮 SQL Agent (Chat)")
    st.write("Multi-turn SQL chat. Good for follow-ups like “now filter by …”")

    # --- Intro: what the DB contains ---
    with st.expander("📌 What's in the SQL database?", expanded=False):
        st.markdown(
            """
The database contains information about **students, courses, enrollments, and attendance**.

- **students**: student_id, name, program, section, year
- **courses**: course_id, course_code, course_name, department, credits
- **enrollments**: student-course enrollment per semester with score and grade
- **attendance**: per-class attendance for each student in each course and semester (present = 1/0)
- **view**: student_performance (avg_score, num_A grades, num_courses per student per semester)

Use this chat for analytics questions like rankings, averages, cohorts, and time/semester filtering.
"""
        )

    # --- Session init ---
    if "sql_messages" not in st.session_state:
        st.session_state.sql_messages = [
            {"role": "assistant", "content": "Ask a question about the student analytics database, or try an example below."}
        ]

    # --- Reset ---
    c1, _ = st.columns([1, 5])
    with c1:
        if st.button("Reset chat", key="reset_sql"):
            st.session_state.sql_messages = [{"role": "assistant", "content": "Chat reset. Ask a SQL question!"}]
            st.rerun()

    # --- Example queries (auto-run) ---
    # Clicking a button stashes a canned question in session state; it is
    # popped below and fed through the same path as a typed question.
    st.subheader("⚡ Try an example")
    e1, e2, e3 = st.columns(3)

    if e1.button("🏆 Top students (2025-Fall)", use_container_width=True):
        st.session_state.sql_demo_query = (
            "Show the top 10 students by average score in semester 2025-Fall. "
            "Use the student_performance view. Return name, program, avg_score, num_courses, and num_A."
        )

    if e2.button("📉 Lowest scoring course (2025-Fall)", use_container_width=True):
        st.session_state.sql_demo_query = (
            "In 2025-Fall, which course has the lowest average score? "
            "Return course_code, course_name, department, and avg_score."
        )

    if e3.button("🧾 Attendance < 70% (2025-Fall)", use_container_width=True):
        st.session_state.sql_demo_query = (
            "For semester 2025-Fall, show students whose overall attendance is below 70%. "
            "Compute attendance_percent as 100 * AVG(present). "
            "Return student name, program, attendance_percent, and total_classes."
        )

    # pop() ensures a demo query only fires once, not on every rerun.
    demo_query = st.session_state.pop("sql_demo_query", None)

    # --- Render chat history ---
    for m in st.session_state.sql_messages:
        st.chat_message(m["role"]).write(m["content"])

    # --- Input (manual OR demo) ---
    prompt = st.chat_input("Ask a SQL question...", key="sql_chat_input")
    user_query = prompt or demo_query

    if user_query:
        st.session_state.sql_messages.append({"role": "user", "content": user_query})
        st.chat_message("user").write(user_query)

        try:
            # Create assistant bubble immediately (prevents flicker)
            with st.chat_message("assistant"):
                answer_slot = st.empty()

                with st.spinner("Thinking..."):
                    # Follow-ups are rewritten into standalone questions before
                    # the (stateless) SQL agent sees them.
                    standalone = _rewrite_followup_to_standalone(
                        settings,
                        st.session_state.sql_messages,
                        user_query,
                    )
                    out = sql_answer(settings, standalone)
                    answer = str(out.get("answer", ""))

                answer_slot.write(answer)

                # Append to history AFTER we have the final answer
                st.session_state.sql_messages.append({"role": "assistant", "content": answer})

                with st.expander("Debug"):
                    st.write("Standalone question:", standalone)
                    st.json(out)

        except Exception as e:
            st.error(str(e))
|
| 340 |
+
|
| 341 |
+
elif page == "Graph Agent":
    # Direct (non-routed) multi-turn chat against the Neo4j movies graph.
    st.title("🕸️ Graph Agent (Chat)")
    st.write("Multi-turn Cypher/Q&A chat over Neo4j.")

    # --- Explain what graph contains ---
    with st.expander("📌 What's in the Neo4j database?", expanded=False):
        st.markdown(
            """
**Theme:** Hollywood movies.

**Nodes**
- `Movie`: title, tagline, released (year)
- `Person`: name, born (year)

**Relationships**
- `(:Person)-[:ACTED_IN]->(:Movie)`
- `(:Person)-[:DIRECTED]->(:Movie)`
- `(:Person)-[:PRODUCED]->(:Movie)`

**Examples you can ask about**
- Movies: “The Matrix”, “Top Gun”, “Jerry Maguire”
- People: “Tom Cruise”, “Keanu Reeves”, “Tom Hanks”
"""
        )

    with st.expander("🧠 Why Neo4j (graph DB) vs Web Search?", expanded=False):
        st.markdown(
            """
**Neo4j is best for relationship-heavy questions** where you want exact, structured answers:
- “Who co-starred with Tom Cruise the most?”
- “Find actors who worked with both Tom Cruise and Tom Hanks.”
- “Show movies connected to *The Matrix* via shared actors.”

**Web search is best for open-world facts** (news, definitions, anything outside your dataset).
So: Web search = broad; Neo4j = deep structured relationships inside your graph.
"""
        )

    # --- Session init ---
    if "graph_messages" not in st.session_state:
        st.session_state.graph_messages = [
            {"role": "assistant", "content": "Ask a question about the Neo4j movies graph, or try an example below."}
        ]

    # --- Reset button ---
    c1, _ = st.columns([1, 5])
    with c1:
        if st.button("Reset chat", key="reset_graph"):
            st.session_state.graph_messages = [
                {"role": "assistant", "content": "Chat reset. Ask a graph question!"}
            ]
            st.rerun()

    # --- Example queries (auto-run) ---
    # Same stash-and-pop pattern as the SQL page.
    st.subheader("⚡ Try an example")
    e1, e2, e3 = st.columns(3)

    if e1.button("🎭 Similar to The Matrix (shared actors)", use_container_width=True):
        st.session_state.graph_demo_query = (
            "Find movies that share at least 2 actors with The Matrix. "
            "Return the movie titles and how many actors are shared."
        )

    if e2.button("🧭 Shortest path: Tom Hanks ↔ Tom Cruise", use_container_width=True):
        st.session_state.graph_demo_query = (
            "Show the shortest connection between Tom Hanks and Tom Cruise."
        )

    if e3.button("🎬 Recommend like Cast Away", use_container_width=True):
        st.session_state.graph_demo_query = (
            "Recommend movies like Cast Away based on shared actor and director, and also name them."
        )

    demo_query = st.session_state.pop("graph_demo_query", None)

    # --- Render chat history ---
    for m in st.session_state.graph_messages:
        st.chat_message(m["role"]).write(m["content"])

    # --- Input (manual OR demo) ---
    prompt = st.chat_input("Ask a graph question...", key="graph_chat_input")
    user_query = prompt or demo_query

    if user_query:
        st.session_state.graph_messages.append({"role": "user", "content": user_query})
        st.chat_message("user").write(user_query)

        try:
            # Create assistant bubble immediately (prevents flicker)
            with st.chat_message("assistant"):
                answer_slot = st.empty()

                with st.spinner("Thinking..."):
                    # Rewrite follow-ups before the stateless graph agent runs.
                    standalone = _rewrite_followup_to_standalone(
                        settings,
                        st.session_state.graph_messages,
                        user_query,
                    )
                    out = graph_answer(settings, standalone)
                    answer = str(out.get("answer", ""))

                answer_slot.write(answer)

                # Append to history AFTER we have the final answer
                st.session_state.graph_messages.append({"role": "assistant", "content": answer})

                with st.expander("Debug (Cypher + results)"):
                    st.write("Standalone question:", standalone)
                    st.json(out.get("debug", {}))

        except Exception as e:
            st.error(str(e))
|
| 453 |
+
|
| 454 |
+
elif page == "Tools Agent":
    # Direct tools-agent chat (web search, Wikipedia, arXiv, calculator).
    st.title("🧰 Tools Agent (Chat)")
    st.write("Tool-Assisted Research Chat (Web + Wikipedia + arXiv + Calculator).")

    if "tools_messages" not in st.session_state:
        st.session_state.tools_messages = [{"role": "assistant", "content": "Ask a question — I'll search web/Wikipedia/arXiv and use tools when needed."}]

    c1, _ = st.columns([1, 5])
    with c1:
        if st.button("Reset chat", key="reset_tools"):
            st.session_state.tools_messages = [{"role": "assistant", "content": "Chat reset. Ask a tools question!"}]
            st.rerun()

    for m in st.session_state.tools_messages:
        st.chat_message(m["role"]).write(m["content"])

    prompt = st.chat_input("Ask a tools question...", key="tools_chat_input")
    if prompt:
        st.session_state.tools_messages.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)

        try:
            with st.chat_message("assistant"):
                # Placeholders so the caption/answer fill the same bubble in place.
                tools_slot = st.empty()
                answer_slot = st.empty()

                with st.spinner("Thinking..."):
                    tools_graph = _tools_graph_cached(settings.llm_model)
                    msgs = _dict_messages_to_lc(st.session_state.tools_messages)

                    out = tools_graph.invoke({"messages": msgs})
                    out_msgs = out.get("messages", []) or []

                    # Final AI message in the state is the answer.
                    last_ai = next((mm for mm in reversed(out_msgs) if isinstance(mm, AIMessage)), None)
                    answer = last_ai.content if last_ai else "(no answer)"
                    tools_used = _extract_tool_names_from_messages(out_msgs)

                if show_tools_used and tools_used:
                    tools_slot.caption("🧰 Tools used: " + ", ".join([f"`{t}`" for t in tools_used]))
                answer_slot.write(answer)

                st.session_state.tools_messages.append({"role": "assistant", "content": answer})

                with st.expander("Debug (tool messages)"):
                    st.write("Tools used:", tools_used)
                    st.write("Messages produced:", len(out_msgs))

        except Exception as e:
            st.error(str(e))

    # Optional: keep your old "run once each" tester as a quick health check
    with st.expander("Quick tool health-check (run each tool once)"):
        q = st.text_input("Query for one-shot tools test", key="tools_q_once")
        if st.button("Run one-shot tools", type="secondary"):
            try:
                results = run_tools_once(
                    q,
                    wiki_chars=settings.wiki_doc_content_chars_max,
                )
                # NOTE(review): this nests st.expander inside st.expander, which
                # Streamlit disallows at runtime — confirm and consider using
                # st.container / st.subheader for the per-tool output instead.
                for r in results:
                    with st.expander(r.tool):
                        st.write(r.output)
            except Exception as e:
                st.error(str(e))
|
| 518 |
+
|
| 519 |
+
else:
    # Settings / health-check page: surfaces the effective configuration so
    # the user can verify keys and connections before using the agents.
    st.title("⚙️ Settings / Health Check")
    st.write("Use this page to confirm your keys and connections.")

    if not settings.groq_api_key:
        st.warning("GROQ_API_KEY is not set. Add it in your environment or .env.")
    else:
        st.success("GROQ_API_KEY is set.")

    st.write("**Current model:**", settings.llm_model)
    st.write("**SQLite path:**", settings.sqlite_path)

    if settings.neo4j_uri:
        st.write("**Neo4j URI:**", settings.neo4j_uri)
    else:
        st.info("Neo4j not configured yet (NEO4J_URI empty). Graph Agent will fail until set.")
|
orchestrator/__init__.py
ADDED
|
File without changes
|
orchestrator/factories.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
from orchestrator.settings import Settings
|
| 5 |
+
|
| 6 |
+
def get_llm(settings: Settings, *, model: Optional[str] = None, temperature: float = 0.2):
    """Build a Groq-backed chat model from *settings*.

    *model* overrides ``settings.llm_model`` when given. Raises ValueError if
    no Groq API key is configured.
    """
    # We use Groq in your stack (same as project 1).
    # If you want OpenAI later, you can add a get_openai_llm here.
    from langchain_groq import ChatGroq

    if not settings.groq_api_key:
        raise ValueError("Missing GROQ_API_KEY. Set it in your environment or .env.")
    chosen = model or settings.llm_model
    return ChatGroq(groq_api_key=settings.groq_api_key, model=chosen, temperature=temperature)
|
orchestrator/graph_agent.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Dict, Any, Optional
|
| 4 |
+
|
| 5 |
+
from orchestrator.settings import Settings
|
| 6 |
+
from orchestrator.factories import get_llm
|
| 7 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 8 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from langchain_community.graphs import Neo4jGraph
|
| 12 |
+
except Exception as e: # pragma: no cover
|
| 13 |
+
Neo4jGraph = None
|
| 14 |
+
|
| 15 |
+
@dataclass
class GraphAgentDebug:
    """Debug payload returned (as a dict) alongside a graph_answer() result."""

    # Cypher query the LLM generated (after backtick/space cleanup).
    cypher: str = ""
    # Raw rows returned by graph.query(); None until a query has run.
    raw_results: Any = None
    # Exception text when any step failed; empty string on success.
    error: str = ""
|
| 20 |
+
|
| 21 |
+
def _get_graph(settings: Settings):
    """Open a Neo4jGraph connection using credentials from *settings*.

    Raises ImportError when the Neo4j integration could not be imported and
    ValueError when any of the URI/username/password is missing.
    """
    if Neo4jGraph is None:
        raise ImportError("Neo4jGraph not available. Install langchain-community[neo4j] or neo4j driver.")
    credentials = (settings.neo4j_uri, settings.neo4j_username, settings.neo4j_password)
    if not all(credentials):
        raise ValueError("Missing NEO4J_URI/NEO4J_USERNAME/NEO4J_PASSWORD.")
    uri, user, pwd = credentials
    return Neo4jGraph(url=uri, username=user, password=pwd)
|
| 31 |
+
|
| 32 |
+
def graph_answer(settings: Settings, question: str) -> Dict[str, Any]:
    """
    A simple Graph DB agent:
    1) Get graph schema
    2) Ask LLM to write Cypher (ONLY the query)
    3) Execute Cypher
    4) Ask LLM to produce a final answer grounded in results

    Returns a dict with keys ``answer``, ``debug`` (GraphAgentDebug fields)
    and ``agent`` (always ``"graph"``). On failure, ``answer`` is a generic
    apology and ``debug["error"]`` carries the exception text.
    """
    llm = get_llm(settings, temperature=0)
    graph = _get_graph(settings)

    # schema string
    schema = getattr(graph, "schema", None)
    if callable(schema):  # older versions: graph.schema is a function
        schema = schema()
    schema = schema or "Schema not available."

    cypher_prompt = ChatPromptTemplate.from_template(
        """You are a Neo4j Cypher expert.
Given the graph schema below, write a Cypher query to answer the user question.
Return ONLY the Cypher query (no backticks, no explanation).

Schema:
{schema}

User question:
{question}
"""
    )

    to_cypher = cypher_prompt | llm | StrOutputParser()

    dbg = GraphAgentDebug()

    try:
        cypher = (to_cypher.invoke({"schema": schema, "question": question}) or "").strip()
        # Basic cleanup: strip stray backticks/spaces the model may add.
        cypher = cypher.strip("` ")
        dbg.cypher = cypher
        if not cypher or len(cypher) < 6:
            raise ValueError("LLM did not produce a valid Cypher query.")

        results = graph.query(cypher)
        dbg.raw_results = results

        answer_prompt = ChatPromptTemplate.from_template(
            """You are a helpful assistant answering questions using ONLY the database results.
If results are empty, say you couldn't find relevant rows.

User question:
{question}

Cypher results (JSON-like):
{results}

Answer concisely and clearly.
"""
        )
        answer_chain = answer_prompt | llm | StrOutputParser()
        answer = answer_chain.invoke({"question": question, "results": results})

        return {"answer": answer, "debug": dbg.__dict__, "agent": "graph"}

    except Exception as e:
        dbg.error = str(e)
        return {
            "answer": "I couldn't query the graph database for that question. Check Neo4j connection/schema and try again.",
            "debug": dbg.__dict__,
            # Fix: include "agent" on the failure path too, so both return
            # shapes are consistent for callers that inspect out["agent"].
            "agent": "graph",
        }
|
orchestrator/graphs.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Annotated, Any, Dict, List, Literal, TypedDict
|
| 4 |
+
|
| 5 |
+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
| 6 |
+
from langgraph.graph import END, START, StateGraph
|
| 7 |
+
from langgraph.graph.message import add_messages
|
| 8 |
+
from langgraph.prebuilt import ToolNode, tools_condition
|
| 9 |
+
|
| 10 |
+
from orchestrator.factories import get_llm
|
| 11 |
+
from orchestrator.graph_agent import graph_answer
|
| 12 |
+
from orchestrator.settings import Settings
|
| 13 |
+
from orchestrator.sql_agent import sql_answer
|
| 14 |
+
from orchestrator.tools import make_web_wiki_arxiv_tools
|
| 15 |
+
|
| 16 |
+
# The four destinations the router can dispatch a turn to.
Route = Literal["sql", "graph", "tools", "general"]


class RouterState(TypedDict, total=False):
    """Shared LangGraph state flowing between the router and agent nodes.

    Declared with total=False so every key is optional and nodes may return
    partial state updates.
    """

    # Conversation history; the add_messages reducer appends node outputs
    # instead of overwriting the list.
    messages: Annotated[list[BaseMessage], add_messages]
    # Route label chosen by the router node for the current turn.
    route: Route
    # Accumulated diagnostics (router label, tools used, agent debug dumps).
    debug: Dict[str, Any]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _safe_text(x: Any) -> str:
|
| 26 |
+
if x is None:
|
| 27 |
+
return ""
|
| 28 |
+
return x if isinstance(x, str) else str(x)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _last_user_text(messages: list[BaseMessage]) -> str:
    """Return the stripped text of the most recent HumanMessage, or ''."""
    latest = next(
        (m for m in reversed(messages) if isinstance(m, HumanMessage)),
        None,
    )
    if latest is None:
        return ""
    return _safe_text(latest.content).strip()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _messages_to_transcript(messages: list[BaseMessage], max_turns: int = 8) -> str:
    """Render the last ~max_turns Human/AI exchanges as a plain transcript.

    Tool messages are deliberately excluded to keep prompts stable.
    """
    limit = max_turns * 2  # roughly two messages per conversational turn
    recent: List[BaseMessage] = []
    for msg in reversed(messages):
        if isinstance(msg, (HumanMessage, AIMessage)):
            recent.append(msg)
            if len(recent) >= limit:
                break

    rendered: List[str] = []
    for msg in reversed(recent):
        speaker = "User" if isinstance(msg, HumanMessage) else "Assistant"
        rendered.append(f"{speaker}: {_safe_text(msg.content)}")
    return "\n".join(rendered).strip()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _merge_debug(state: RouterState, **kv: Any) -> Dict[str, Any]:
|
| 61 |
+
dbg = dict(state.get("debug") or {})
|
| 62 |
+
for k, v in kv.items():
|
| 63 |
+
if v is not None:
|
| 64 |
+
dbg[k] = v
|
| 65 |
+
return dbg
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _extract_tool_names(messages: list[BaseMessage]) -> List[str]:
    """Collect tool names referenced by AIMessage.tool_calls, deduped in order.

    Handles both dict-shaped and attribute-shaped tool calls, which vary
    across LangChain variants.
    """
    collected: List[str] = []
    for msg in messages:
        if not isinstance(msg, AIMessage):
            continue
        for call in getattr(msg, "tool_calls", None) or []:
            name = call.get("name") if isinstance(call, dict) else getattr(call, "name", None)
            if name:
                collected.append(str(name))
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(collected))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _rewrite_to_standalone(llm, messages: list[BaseMessage]) -> str:
    """Rewrite a follow-up like "show them" into a standalone question.

    Returns '' when there is no user text, the question unchanged when the
    conversation has only one user turn, and otherwise the LLM's rewrite
    (falling back to the original question if the LLM returns nothing).
    """
    question = _last_user_text(messages)
    if not question:
        return ""

    # A single user message cannot be a follow-up; skip the LLM call.
    user_turns = sum(1 for m in messages if isinstance(m, HumanMessage))
    if user_turns <= 1:
        return question

    transcript = _messages_to_transcript(messages, max_turns=8)
    prompt = (
        "Rewrite the user's latest question into a standalone question.\n"
        "Do NOT answer the question.\n\n"
        f"Conversation:\n{transcript}\n\n"
        f"Latest user question:\n{question}\n\n"
        "Standalone question:"
    )
    reply = llm.invoke(
        [
            SystemMessage(content="You rewrite follow-up questions into standalone questions."),
            HumanMessage(content=prompt),
        ]
    )
    rewritten = _safe_text(getattr(reply, "content", "")).strip()
    return rewritten or question
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def build_tools_agent_graph(settings: Settings):
    """Compile the tools agent: an assistant <-> ToolNode loop.

    The assistant may emit tool calls; tools_condition routes to the tools
    node while calls remain, otherwise the graph ends.
    """
    toolbox = make_web_wiki_arxiv_tools(
        wiki_chars=settings.wiki_doc_content_chars_max,
    )
    model = get_llm(settings, temperature=0).bind_tools(toolbox)

    def assistant(state: RouterState):
        # One LLM step over the full message history.
        return {"messages": [model.invoke(state["messages"])]}

    builder = StateGraph(RouterState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(toolbox))
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def build_router_graph(settings: Settings):
    """Compile the top-level multi-agent router graph.

    A router node classifies the latest user turn into one of four routes
    (sql / graph / tools / general); each route has a dedicated node and the
    graph ends after that node runs once.
    """
    tools_graph = build_tools_agent_graph(settings)
    llm_router = get_llm(settings, temperature=0)

    # System prompt for the classification step; the model is told to reply
    # with exactly one of the four labels.
    route_prompt = (
        "You are a router for a multi-agent system.\n"
        "Choose exactly ONE route label from: sql, graph, tools, general.\n\n"
        "Routing rules:\n"
        "- sql: querying a relational database (tables/rows, SQL, students DB, counts, filters).\n"
        "- graph: querying a Neo4j graph database (nodes/relationships, Cypher).\n"
        "- tools: needs external knowledge / searching (Wikipedia/arXiv/web) or tool use.\n"
        "- general: conceptual explanation or chat that doesn't need tools/DB queries.\n\n"
        "Return ONLY the label.\n"
    )

    def router(state: RouterState):
        # Classify the latest question, giving the model recent transcript
        # context so follow-ups route correctly.
        msgs = state.get("messages", [])
        q = _last_user_text(msgs)
        transcript = _messages_to_transcript(msgs, max_turns=8)

        payload = (
            "Conversation transcript:\n"
            f"{transcript}\n\n"
            "Latest user question:\n"
            f"{q}"
        )

        msg = llm_router.invoke(
            [SystemMessage(content=route_prompt), HumanMessage(content=payload)]
        )
        label = _safe_text(msg.content).strip().lower()
        # Any unexpected model output falls back to the general route.
        if label not in ("sql", "graph", "tools", "general"):
            label = "general"

        dbg = _merge_debug(state, router_label=label, router_raw=msg.content, routed_to=label)
        return {"route": label, "debug": dbg}

    def sql_node(state: RouterState):
        # Resolve follow-ups into a standalone question before querying SQL.
        standalone = _rewrite_to_standalone(llm_router, state["messages"])
        out = sql_answer(settings, standalone)
        dbg = _merge_debug(state, routed_to="sql", sql=out, standalone_question=standalone)
        return {"route": "sql", "messages": [AIMessage(content=str(out["answer"]))], "debug": dbg}

    def graph_node(state: RouterState):
        # Same standalone-rewrite step, then delegate to the Neo4j agent.
        standalone = _rewrite_to_standalone(llm_router, state["messages"])
        out = graph_answer(settings, standalone)
        dbg = _merge_debug(state, routed_to="graph", graph=out.get("debug", {}), standalone_question=standalone)
        return {"route": "graph", "messages": [AIMessage(content=str(out["answer"]))], "debug": dbg}

    def tools_node(state: RouterState):
        # Run the compiled tools sub-graph over the full message history and
        # surface which tools it ended up calling.
        out_state = tools_graph.invoke({"messages": state["messages"]})
        out_msgs = out_state.get("messages", [])
        tools_used = _extract_tool_names(out_msgs)

        dbg = _merge_debug(
            state,
            routed_to="tools",
            tools_used=tools_used,
            tools_graph={"messages_len": len(out_msgs)},
        )
        return {"route": "tools", "messages": out_msgs, "debug": dbg}

    def general_node(state: RouterState):
        # Use the conversation itself (not just last message)
        convo = [m for m in state["messages"] if isinstance(m, (HumanMessage, AIMessage))]
        msg = llm_router.invoke([SystemMessage(content="You are a helpful assistant.")] + convo)
        dbg = _merge_debug(state, routed_to="general")
        return {"route": "general", "messages": [AIMessage(content=_safe_text(msg.content))], "debug": dbg}

    # Wire the graph: START -> router -> one agent node -> END.
    g = StateGraph(RouterState)
    g.add_node("router", router)
    g.add_node("sql", sql_node)
    g.add_node("graph", graph_node)
    g.add_node("tools", tools_node)
    g.add_node("general", general_node)

    g.add_edge(START, "router")
    g.add_conditional_edges(
        "router",
        lambda s: s["route"],
        {"sql": "sql", "graph": "graph", "tools": "tools", "general": "general"},
    )
    g.add_edge("sql", END)
    g.add_edge("graph", END)
    g.add_edge("tools", END)
    g.add_edge("general", END)

    return g.compile()
|
orchestrator/settings.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
@dataclass(frozen=True)
class Settings:
    """Immutable application configuration, defaulted from environment variables.

    Note: the env vars are read once, when this class definition is first
    imported; later changes to os.environ do not affect the defaults.
    """

    # LLM
    groq_api_key: str = os.getenv("GROQ_API_KEY", "")
    llm_model: str = os.getenv("LLM_MODEL", "meta-llama/llama-4-maverick-17b-128e-instruct")

    # SQL (SQLite by default; relative paths are resolved by the SQL agent)
    sqlite_path: str = os.getenv("SQLITE_PATH", "student.db")

    # Neo4j Graph DB
    neo4j_uri: str = os.getenv("NEO4J_URI", "")
    neo4j_username: str = os.getenv("NEO4J_USERNAME", "")
    neo4j_password: str = os.getenv("NEO4J_PASSWORD", "")

    # Tool settings
    wiki_doc_content_chars_max: int = int(os.getenv("WIKI_DOC_CHARS", "2000"))

    # Debug flag. Fix: truthy values are now matched case-insensitively
    # (accepts e.g. "TRUE", "Yes", "on"); previously only an exact-case
    # handful matched, so DEBUG=TRUE silently disabled debugging.
    debug: bool = os.getenv("DEBUG", "0").strip().lower() in ("1", "true", "yes", "on")
|
orchestrator/sql_agent.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Optional, Dict, Any
|
| 5 |
+
import sqlite3
|
| 6 |
+
|
| 7 |
+
from sqlalchemy import create_engine
|
| 8 |
+
|
| 9 |
+
from orchestrator.settings import Settings
|
| 10 |
+
|
| 11 |
+
from langchain_groq import ChatGroq
|
| 12 |
+
from langchain_community.utilities.sql_database import SQLDatabase
|
| 13 |
+
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
| 14 |
+
from langchain_community.agent_toolkits.sql.base import create_sql_agent
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _resolve_sqlite_path(settings: Settings, db_path: Optional[str] = None) -> Path:
|
| 18 |
+
p = Path(db_path or settings.sqlite_path)
|
| 19 |
+
if not p.is_absolute():
|
| 20 |
+
# project root = parent of orchestrator/
|
| 21 |
+
p = (Path(__file__).resolve().parents[1] / p).resolve()
|
| 22 |
+
return p
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _make_sql_db_readonly(sqlite_path: Path) -> SQLDatabase:
|
| 26 |
+
if not sqlite_path.exists():
|
| 27 |
+
raise FileNotFoundError(
|
| 28 |
+
f"SQLite DB not found at: {sqlite_path}\n"
|
| 29 |
+
f"Fix: put student.db at project root OR set SQLITE_PATH to an absolute path."
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
def _connect():
|
| 33 |
+
return sqlite3.connect(f"file:{sqlite_path.as_posix()}?mode=ro", uri=True)
|
| 34 |
+
|
| 35 |
+
engine = create_engine("sqlite:///", creator=_connect)
|
| 36 |
+
return SQLDatabase(engine)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _make_llm(settings: Settings):
    """Build the Groq chat model, tolerating both ChatGroq constructor APIs.

    Newer langchain-groq accepts api_key/model; older releases want
    groq_api_key/model_name. Try the new form first and fall back on
    TypeError.
    """
    try:
        return ChatGroq(api_key=settings.groq_api_key, model=settings.llm_model, temperature=0)
    except TypeError:
        # Legacy keyword names for older langchain-groq releases.
        return ChatGroq(
            groq_api_key=settings.groq_api_key,
            model_name=settings.llm_model,
            temperature=0,
        )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def make_sql_agent(settings: Settings, *, db_path: Optional[str] = None):
    """Build the SQL agent plus its read-only database handle.

    Returns:
        (agent, db, resolved_sqlite_path_as_str)
    """
    model = _make_llm(settings)

    resolved = _resolve_sqlite_path(settings, db_path=db_path)
    database = _make_sql_db_readonly(resolved)

    # Force the tool-calling SQL agent; it is the most reliable agent_type
    # on current LangChain (1.2.x) releases.
    executor = create_sql_agent(
        llm=model,
        toolkit=SQLDatabaseToolkit(db=database, llm=model),
        agent_type="tool-calling",
        handle_parsing_errors=True,
        max_iterations=30,
        max_execution_time=60,
        verbose=bool(settings.debug),
        return_intermediate_steps=bool(settings.debug),
    )

    return executor, database, str(resolved)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def sql_answer(settings: Settings, question: str, *, db_path: Optional[str] = None) -> Dict[str, Any]:
    """Answer a natural-language question against the read-only SQLite DB.

    Args:
        settings: app settings (LLM credentials + sqlite path).
        question: the user's question in plain English.
        db_path: optional override for the SQLite file.

    Returns:
        dict with "answer", "db_path", "agent", and (when debug is on and
        the agent surfaced them) "intermediate_steps".
    """
    agent, db, sqlite_path = make_sql_agent(settings, db_path=db_path)

    q = (question or "").strip().lower()

    # Deterministic shortcut: table listing never needs an agent round-trip.
    if any(s in q for s in ["list the tables", "list tables", "show tables", "what tables"]):
        tables = db.get_usable_table_names()
        # "agent" key added for a consistent result shape with the path below.
        return {"answer": "Tables: " + ", ".join(tables), "db_path": sqlite_path, "agent": "sql"}

    # Run agent
    out = agent.invoke({"input": question})

    # Normalize output. Fix: some agent versions key the final text under
    # "answer" instead of "output"; `out.get("output")` alone could be None
    # and the user would see the literal string "None".
    if isinstance(out, dict):
        answer = out.get("output") or out.get("answer") or str(out)
    else:
        answer = str(out)

    result: Dict[str, Any] = {"answer": str(answer), "db_path": sqlite_path, "agent": "sql"}

    # If debug is enabled, surface intermediate steps (Streamlit expander).
    if isinstance(out, dict) and "intermediate_steps" in out:
        result["intermediate_steps"] = out["intermediate_steps"]

    return result
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# from __future__ import annotations
|
| 110 |
+
|
| 111 |
+
# from pathlib import Path
|
| 112 |
+
# from typing import Optional, Dict, Any
|
| 113 |
+
# import sqlite3
|
| 114 |
+
|
| 115 |
+
# from sqlalchemy import create_engine
|
| 116 |
+
|
| 117 |
+
# from orchestrator.settings import Settings
|
| 118 |
+
# from orchestrator.factories import get_llm
|
| 119 |
+
|
| 120 |
+
# # --- Imports that vary across LangChain versions ---
|
| 121 |
+
# try:
|
| 122 |
+
# # langchain >= 1.x
|
| 123 |
+
# from langchain.sql_database import SQLDatabase
|
| 124 |
+
# except Exception:
|
| 125 |
+
# # older / community
|
| 126 |
+
# from langchain_community.utilities import SQLDatabase
|
| 127 |
+
|
| 128 |
+
# try:
|
| 129 |
+
# from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
| 130 |
+
# except Exception:
|
| 131 |
+
# # older path (rare)
|
| 132 |
+
# from langchain.agents.agent_toolkits import SQLDatabaseToolkit
|
| 133 |
+
|
| 134 |
+
# try:
|
| 135 |
+
# from langchain.agents import create_sql_agent
|
| 136 |
+
# except Exception:
|
| 137 |
+
# from langchain_community.agent_toolkits.sql.base import create_sql_agent
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# def _resolve_sqlite_path(settings: Settings) -> Path:
|
| 141 |
+
# """
|
| 142 |
+
# Resolve SQLITE_PATH relative to project root (parent of orchestrator/),
|
| 143 |
+
# so Streamlit's current working directory does not break DB loading.
|
| 144 |
+
# """
|
| 145 |
+
# p = Path(settings.sqlite_path)
|
| 146 |
+
# if not p.is_absolute():
|
| 147 |
+
# p = (Path(__file__).resolve().parents[1] / p).resolve()
|
| 148 |
+
# return p
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# def _make_sql_db_readonly(sqlite_path: Path) -> SQLDatabase:
|
| 152 |
+
# """
|
| 153 |
+
# Open SQLite in READ-ONLY mode so a wrong path does NOT create an empty DB file.
|
| 154 |
+
# """
|
| 155 |
+
# if not sqlite_path.exists():
|
| 156 |
+
# raise FileNotFoundError(
|
| 157 |
+
# f"SQLite DB not found at: {sqlite_path}\n"
|
| 158 |
+
# f"Fix: put student.db at the project root OR set SQLITE_PATH to an absolute path."
|
| 159 |
+
# )
|
| 160 |
+
|
| 161 |
+
# def _connect():
|
| 162 |
+
# return sqlite3.connect(f"file:{sqlite_path.as_posix()}?mode=ro", uri=True)
|
| 163 |
+
|
| 164 |
+
# engine = create_engine("sqlite:///", creator=_connect)
|
| 165 |
+
# return SQLDatabase(engine)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# def _create_agent(llm, toolkit, verbose: bool):
|
| 169 |
+
# """
|
| 170 |
+
# Create SQL agent WITHOUT passing kwargs that frequently clash with defaults
|
| 171 |
+
# in langchain-classic AgentExecutor.
|
| 172 |
+
# """
|
| 173 |
+
# # Keep only the safest option; many builds already set other defaults internally.
|
| 174 |
+
# agent_exec_kwargs = {"handle_parsing_errors": True}
|
| 175 |
+
|
| 176 |
+
# # Some versions accept max_iterations/max_execution_time top-level.
|
| 177 |
+
# # Some accept neither.
|
| 178 |
+
# # We try progressively.
|
| 179 |
+
# try:
|
| 180 |
+
# return create_sql_agent(
|
| 181 |
+
# llm=llm,
|
| 182 |
+
# toolkit=toolkit,
|
| 183 |
+
# verbose=verbose,
|
| 184 |
+
# max_iterations=25,
|
| 185 |
+
# max_execution_time=60,
|
| 186 |
+
# agent_executor_kwargs=agent_exec_kwargs,
|
| 187 |
+
# )
|
| 188 |
+
# except TypeError:
|
| 189 |
+
# # Try without time/iteration controls to avoid duplicate kwargs.
|
| 190 |
+
# return create_sql_agent(
|
| 191 |
+
# llm=llm,
|
| 192 |
+
# toolkit=toolkit,
|
| 193 |
+
# verbose=verbose,
|
| 194 |
+
# agent_executor_kwargs=agent_exec_kwargs,
|
| 195 |
+
# )
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# def make_sql_agent(settings: Settings, *, db_path: Optional[str] = None):
|
| 199 |
+
# llm = get_llm(settings, temperature=0)
|
| 200 |
+
|
| 201 |
+
# sqlite_path = Path(db_path).expanduser().resolve() if db_path else _resolve_sqlite_path(settings)
|
| 202 |
+
# db = _make_sql_db_readonly(sqlite_path)
|
| 203 |
+
# toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
| 204 |
+
|
| 205 |
+
# agent = _create_agent(llm, toolkit, verbose=getattr(settings, "debug", False))
|
| 206 |
+
# return agent, db, str(sqlite_path)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# def sql_answer(settings: Settings, question: str, *, db_path: Optional[str] = None) -> Dict[str, Any]:
|
| 210 |
+
# agent, db, sqlite_path = make_sql_agent(settings, db_path=db_path)
|
| 211 |
+
|
| 212 |
+
# # Deterministic shortcut so this never loops.
|
| 213 |
+
# q = (question or "").strip().lower()
|
| 214 |
+
# if any(s in q for s in ["list the tables", "list tables", "show tables", "what tables"]):
|
| 215 |
+
# try:
|
| 216 |
+
# tables = db.get_usable_table_names()
|
| 217 |
+
# except Exception:
|
| 218 |
+
# # fallback for older SQLDatabase implementations
|
| 219 |
+
# tables = []
|
| 220 |
+
# return {
|
| 221 |
+
# "answer": "Tables: " + (", ".join(tables) if tables else "(none found)"),
|
| 222 |
+
# "db_path": sqlite_path,
|
| 223 |
+
# }
|
| 224 |
+
|
| 225 |
+
# # Run agent
|
| 226 |
+
# out = agent.invoke({"input": question})
|
| 227 |
+
|
| 228 |
+
# # Normalize output
|
| 229 |
+
# if isinstance(out, dict):
|
| 230 |
+
# answer = out.get("output") or out.get("answer") or str(out)
|
| 231 |
+
# else:
|
| 232 |
+
# answer = str(out)
|
| 233 |
+
|
| 234 |
+
# return {"answer": answer, "db_path": sqlite_path}
|
orchestrator/tools.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
|
| 5 |
+
from langchain_core.tools import tool
|
| 6 |
+
from langchain_community.utilities import WikipediaAPIWrapper
|
| 7 |
+
from langchain_community.tools import DuckDuckGoSearchRun, WikipediaQueryRun, ArxivQueryRun
|
| 8 |
+
|
| 9 |
+
# --- Calculator tool (safe arithmetic) ---
|
| 10 |
+
import ast
|
| 11 |
+
import operator as op
|
| 12 |
+
|
| 13 |
+
_ALLOWED_OPS = {
|
| 14 |
+
ast.Add: op.add,
|
| 15 |
+
ast.Sub: op.sub,
|
| 16 |
+
ast.Mult: op.mul,
|
| 17 |
+
ast.Div: op.truediv,
|
| 18 |
+
ast.Pow: op.pow,
|
| 19 |
+
ast.USub: op.neg,
|
| 20 |
+
ast.Mod: op.mod,
|
| 21 |
+
ast.FloorDiv: op.floordiv,
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
def _eval_expr(expr: str) -> float:
|
| 25 |
+
"""Safely evaluate a basic arithmetic expression."""
|
| 26 |
+
node = ast.parse(expr, mode="eval").body
|
| 27 |
+
|
| 28 |
+
def _eval(n):
|
| 29 |
+
if isinstance(n, ast.Num): # py<3.8
|
| 30 |
+
return n.n
|
| 31 |
+
if isinstance(n, ast.Constant): # py>=3.8
|
| 32 |
+
if isinstance(n.value, (int, float)):
|
| 33 |
+
return n.value
|
| 34 |
+
raise ValueError("Only numbers are allowed.")
|
| 35 |
+
if isinstance(n, ast.BinOp) and type(n.op) in _ALLOWED_OPS:
|
| 36 |
+
return _ALLOWED_OPS[type(n.op)](_eval(n.left), _eval(n.right))
|
| 37 |
+
if isinstance(n, ast.UnaryOp) and type(n.op) in _ALLOWED_OPS:
|
| 38 |
+
return _ALLOWED_OPS[type(n.op)](_eval(n.operand))
|
| 39 |
+
raise ValueError("Only basic arithmetic is allowed.")
|
| 40 |
+
|
| 41 |
+
return float(_eval(node))
|
| 42 |
+
|
| 43 |
+
@tool
def calculator(expression: str) -> str:
    """Evaluate a math expression. Input must be a plain arithmetic expression (e.g., '12*(3+4)')."""
    # Errors are returned as text so the agent can recover instead of crashing.
    try:
        result = _eval_expr(expression)
    except Exception as e:
        return f"[calculator error] {e}"
    return str(result)
|
| 50 |
+
|
| 51 |
+
# --- Web/Wiki/Arxiv tools ---
|
| 52 |
+
def make_web_wiki_arxiv_tools(*, wiki_k: int = 3, wiki_chars: int = 2000):
    """Return tool objects compatible with LangGraph ToolNode."""
    search_tool = DuckDuckGoSearchRun()

    # WikipediaQueryRun requires an explicit api_wrapper on the installed versions.
    wikipedia_tool = WikipediaQueryRun(
        api_wrapper=WikipediaAPIWrapper(top_k_results=wiki_k, doc_content_chars_max=wiki_chars)
    )

    # The underlying arXiv API does not require credentials.
    arxiv_tool = ArxivQueryRun()

    return [search_tool, wikipedia_tool, arxiv_tool, calculator]
|
| 65 |
+
|
| 66 |
+
# @dataclass
|
| 67 |
+
# class ToolResult:
|
| 68 |
+
# tool: str
|
| 69 |
+
# output: str
|
| 70 |
+
|
| 71 |
+
@dataclass
class ToolResult:
    """Outcome of a single tool invocation (see run_tools_once)."""

    # Name of the tool that was run.
    tool: str
    # Raw textual output, or a "[tool error] ..." message on failure.
    output: str
    # False when the tool raised; output then carries the error text.
    ok: bool = True
    # str(exception) when the tool raised, else None.
    error: Optional[str] = None
|
| 77 |
+
|
| 78 |
+
def run_tools_once(query: str, *, wiki_k: int = 3, wiki_chars: int = 2000) -> List[ToolResult]:
    """Non-agent helper: run each tool once and return outputs (good for debugging)."""
    results: List[ToolResult] = []
    for tool_obj in make_web_wiki_arxiv_tools(wiki_k=wiki_k, wiki_chars=wiki_chars):
        try:
            payload = str(tool_obj.run(query))
        except Exception as exc:
            results.append(
                ToolResult(tool=tool_obj.name, output=f"[tool error] {exc}", ok=False, error=str(exc))
            )
        else:
            results.append(ToolResult(tool=tool_obj.name, output=payload))
    return results
|
requirements.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.35
|
| 2 |
+
python-dotenv>=1.0
|
| 3 |
+
|
| 4 |
+
# LangChain / LangGraph stack (align with your env)
|
| 5 |
+
langchain>=1.2.0
|
| 6 |
+
langchain-core>=0.3.0
|
| 7 |
+
langchain-community>=0.4.0
|
| 8 |
+
langchain-groq>=0.3.0
|
| 9 |
+
langgraph>=0.2.0
|
| 10 |
+
langchain-neo4j
|
| 11 |
+
|
| 12 |
+
# Tools
|
| 13 |
+
ddgs
|
| 14 |
+
wikipedia>=1.4.0
|
| 15 |
+
arxiv>=2.1.0
|
| 16 |
+
|
| 17 |
+
# SQL
|
| 18 |
+
sqlalchemy>=2.0
|
| 19 |
+
|
| 20 |
+
# Neo4j graph
|
| 21 |
+
neo4j>=5.0
|
school.db
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:adb52805ce8c7dc02d5ecc0f104eac5944ac2d234a0cfe04276714e29ea9faf8
|
| 3 |
+
size 1478656
|
sqlite.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# sqlite.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import sqlite3
|
| 6 |
+
import random
|
| 7 |
+
from datetime import date, timedelta
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Output DB file name and RNG seed; both overridable via environment.
DB_NAME = os.environ.get("SQLITE_DB", "school.db")
SEED = int(os.environ.get("SQLITE_SEED", "42"))

# Scale knobs (keep modest for fast demo)
NUM_STUDENTS = int(os.environ.get("NUM_STUDENTS", "120"))
NUM_COURSES = int(os.environ.get("NUM_COURSES", "14"))
SEMESTERS = ["2024-Fall", "2025-Spring", "2025-Fall"]  # change freely


# Name pools used to synthesize "First Last" student names.
FIRST_NAMES = [
    "Aarav", "Vivaan", "Aditya", "Vihaan", "Arjun", "Sai", "Reyansh", "Ishaan", "Krishna",
    "Ananya", "Aadhya", "Diya", "Ira", "Meera", "Saanvi", "Myra", "Aarohi", "Riya",
    "Rahul", "Kiran", "Suresh", "Priya", "Neha", "Vikram", "Nikhil", "Sneha", "Pooja",
]
LAST_NAMES = [
    "Verma", "Patel", "Gupta", "Mehta", "Singh",
    "Kumar", "Das", "Roy", "Bose", "Chowdhury",
]

# Academic metadata pools for students...
PROGRAMS = ["Computer Science", "Data Science", "AI & ML", "Information Systems", "Cybersecurity"]
SECTIONS = ["A", "B", "C", "D"]

# ...and for courses.
DEPARTMENTS = ["CS", "DS", "AI", "IS", "CY"]
COURSE_TITLES = [
    "Database Systems", "Operating Systems", "Computer Networks", "Machine Learning",
    "Deep Learning", "Data Structures", "Algorithms", "Cloud Computing",
    "NLP Fundamentals", "Information Security", "Software Engineering",
    "Data Visualization", "MLOps Foundations", "Graph Databases",
    "Statistics for Data Science", "Ethical AI",
]

# Letter-grade bands as (letter, inclusive_low, inclusive_high) integer ranges,
# ordered from highest band to lowest.
# NOTE(review): the integer ranges leave gaps for fractional scores (e.g. 89.5
# is in neither B's 80-89 nor A's 90-100) — confirm against grade_from_score.
GRADE_BANDS = [
    ("A", 90, 100),
    ("B", 80, 89),
    ("C", 70, 79),
    ("D", 60, 69),
    ("F", 0, 59),
]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def make_name(rng: random.Random) -> str:
    """Return a random full name drawn from the first/last name pools."""
    # Draw order (first, then last) matters for reproducibility with a seeded RNG.
    first = rng.choice(FIRST_NAMES)
    last = rng.choice(LAST_NAMES)
    return " ".join((first, last))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def grade_from_score(score: float, bands=None) -> str:
    """Map a numeric score to a letter grade.

    Args:
        score: numeric score in [0, 100]; fractional values are allowed
            (the seeder produces scores rounded to one decimal place).
        bands: optional override of (letter, low, high) triples ordered from
            highest band to lowest; defaults to the module-level GRADE_BANDS.

    Returns:
        The letter of the first band whose lower bound the score meets,
        or "F" if the score is below every band.

    Bug fix: the original checked ``lo <= score <= hi`` against inclusive
    integer edges, so fractional scores in the gaps between bands (e.g. 89.5,
    between B's high of 89 and A's low of 90) fell through every band and
    were graded "F". Comparing against the lower bound only closes the gaps.
    """
    if bands is None:
        bands = GRADE_BANDS
    for letter, lo, _hi in bands:
        # Bands are ordered high-to-low, so the first satisfied lower
        # bound identifies the correct band.
        if score >= lo:
            return letter
    # Negative or otherwise out-of-range scores fail closed.
    return "F"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def connect(db_path: Path) -> sqlite3.Connection:
    """Open the SQLite database at *db_path* with this project's pragmas applied."""
    connection = sqlite3.connect(str(db_path))
    # Apply pragmas in a fixed order: FK enforcement for the schema's
    # constraints, WAL journaling, and NORMAL synchronous mode.
    for pragma in (
        "PRAGMA foreign_keys = ON;",
        "PRAGMA journal_mode = WAL;",
        "PRAGMA synchronous = NORMAL;",
    ):
        connection.execute(pragma)
    return connection
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def recreate_schema(con: sqlite3.Connection) -> None:
    """Drop and recreate all tables and indexes. Destructive: existing data is lost."""
    cur = con.cursor()

    # Drop in FK-safe order
    cur.executescript(
        """
        DROP TABLE IF EXISTS attendance;
        DROP TABLE IF EXISTS enrollments;
        DROP TABLE IF EXISTS courses;
        DROP TABLE IF EXISTS students;
        """
    )

    # Parents (students, courses) first, then child tables that reference them.
    cur.executescript(
        """
        CREATE TABLE students (
            student_id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            program TEXT NOT NULL,
            section TEXT NOT NULL,
            year INTEGER NOT NULL CHECK (year BETWEEN 1 AND 4)
        );

        CREATE TABLE courses (
            course_id INTEGER PRIMARY KEY AUTOINCREMENT,
            course_code TEXT NOT NULL UNIQUE,
            course_name TEXT NOT NULL,
            department TEXT NOT NULL,
            credits INTEGER NOT NULL CHECK (credits BETWEEN 1 AND 6)
        );

        CREATE TABLE enrollments (
            enrollment_id INTEGER PRIMARY KEY AUTOINCREMENT,
            student_id INTEGER NOT NULL,
            course_id INTEGER NOT NULL,
            semester TEXT NOT NULL,
            score REAL NOT NULL CHECK (score BETWEEN 0 AND 100),
            grade TEXT NOT NULL CHECK (grade IN ('A','B','C','D','F')),
            created_at TEXT NOT NULL DEFAULT (datetime('now')),
            FOREIGN KEY (student_id) REFERENCES students(student_id) ON DELETE CASCADE,
            FOREIGN KEY (course_id) REFERENCES courses(course_id) ON DELETE CASCADE,
            UNIQUE(student_id, course_id, semester)
        );

        CREATE TABLE attendance (
            attendance_id INTEGER PRIMARY KEY AUTOINCREMENT,
            student_id INTEGER NOT NULL,
            course_id INTEGER NOT NULL,
            semester TEXT NOT NULL,
            class_date TEXT NOT NULL,
            present INTEGER NOT NULL CHECK (present IN (0,1)),
            FOREIGN KEY (student_id) REFERENCES students(student_id) ON DELETE CASCADE,
            FOREIGN KEY (course_id) REFERENCES courses(course_id) ON DELETE CASCADE
        );

        CREATE INDEX idx_enrollments_student ON enrollments(student_id);
        CREATE INDEX idx_enrollments_course ON enrollments(course_id);
        CREATE INDEX idx_enrollments_sem ON enrollments(semester);

        CREATE INDEX idx_att_student_course ON attendance(student_id, course_id);
        CREATE INDEX idx_att_semester ON attendance(semester);
        CREATE INDEX idx_att_date ON attendance(class_date);
        """
    )

    con.commit()
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def seed_students(con: sqlite3.Connection, rng: random.Random) -> None:
    """Insert NUM_STUDENTS randomly generated student rows and commit."""
    # Tuple fields evaluate left-to-right, so the RNG call order per student
    # (name, program, section, year) matches the original implementation.
    rows = [
        (
            make_name(rng),
            rng.choice(PROGRAMS),
            rng.choice(SECTIONS),
            rng.randint(1, 4),
        )
        for _ in range(NUM_STUDENTS)
    ]
    con.cursor().executemany(
        "INSERT INTO students(name, program, section, year) VALUES (?,?,?,?)",
        rows,
    )
    con.commit()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def seed_courses(con: sqlite3.Connection, rng: random.Random) -> None:
    """Insert NUM_COURSES course rows with generated codes/credits and commit."""
    # Shuffle a copy of the full title pool, then take the first NUM_COURSES.
    pool = COURSE_TITLES[:]
    rng.shuffle(pool)

    rows = []
    for idx, title in enumerate(pool[:NUM_COURSES], start=1):
        dept = rng.choice(DEPARTMENTS)
        # e.g. "CS101"; the 100 offset keeps course numbers three digits.
        code = f"{dept}{100 + idx}"
        rows.append((code, title, dept, rng.choice([2, 3, 3, 4])))

    cursor = con.cursor()
    cursor.executemany(
        "INSERT INTO courses(course_code, course_name, department, credits) VALUES (?,?,?,?)",
        rows,
    )
    con.commit()
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def seed_enrollments_and_attendance(con: sqlite3.Connection, rng: random.Random) -> None:
    """Seed enrollments and per-class attendance for every semester.

    For each semester in SEMESTERS, every student is enrolled in 3-5 distinct
    random courses; each enrollment gets a score, a derived letter grade, and
    10 weekly attendance rows whose presence probability loosely tracks the
    score. Assumes students and courses are already seeded.
    """
    cur = con.cursor()

    student_ids = [r[0] for r in cur.execute("SELECT student_id FROM students").fetchall()]
    course_ids = [r[0] for r in cur.execute("SELECT course_id FROM courses").fetchall()]

    enrollment_rows = []
    attendance_rows = []

    # Build a small calendar per semester (10 class dates)
    sem_start = {
        "2024-Fall": date(2024, 9, 1),
        "2025-Spring": date(2025, 2, 1),
        "2025-Fall": date(2025, 9, 1),
    }

    for sem in SEMESTERS:
        # Semesters missing from sem_start fall back to an arbitrary start date.
        start = sem_start.get(sem, date(2025, 1, 1))
        # 10 weekly class dates, ISO-formatted to match the TEXT class_date column.
        class_dates = [(start + timedelta(days=7 * i)).isoformat() for i in range(10)]

        for sid in student_ids:
            # each semester: 3-5 courses
            chosen = rng.sample(course_ids, k=rng.randint(3, 5))
            for cid in chosen:
                # score distribution: mostly 60-95
                base = rng.gauss(mu=78, sigma=10)
                # Clamp into the [0, 100] range the enrollments CHECK requires.
                score = max(0, min(100, round(base, 1)))
                grade = grade_from_score(score)

                enrollment_rows.append((sid, cid, sem, score, grade))

                # attendance probability correlates loosely with score
                # higher score => slightly higher attendance
                p_present = min(0.98, max(0.60, 0.70 + (score - 70) / 100))
                for d in class_dates:
                    present = 1 if rng.random() < p_present else 0
                    attendance_rows.append((sid, cid, sem, d, present))

    # OR IGNORE guards the UNIQUE(student_id, course_id, semester) constraint
    # if the seeder is run against pre-existing data.
    cur.executemany(
        "INSERT OR IGNORE INTO enrollments(student_id, course_id, semester, score, grade) VALUES (?,?,?,?,?)",
        enrollment_rows,
    )
    cur.executemany(
        "INSERT INTO attendance(student_id, course_id, semester, class_date, present) VALUES (?,?,?,?,?)",
        attendance_rows,
    )
    con.commit()
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def create_views(con: sqlite3.Connection) -> None:
    """(Re)create reporting views: a per-student, per-semester score rollup."""
    cur = con.cursor()
    cur.executescript(
        """
        DROP VIEW IF EXISTS student_performance;

        CREATE VIEW student_performance AS
        SELECT
            s.student_id,
            s.name,
            s.program,
            s.section,
            e.semester,
            ROUND(AVG(e.score), 2) AS avg_score,
            SUM(CASE WHEN e.grade = 'A' THEN 1 ELSE 0 END) AS num_A,
            COUNT(*) AS num_courses
        FROM students s
        JOIN enrollments e ON e.student_id = s.student_id
        GROUP BY s.student_id, e.semester;
        """
    )
    con.commit()
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def print_summary(con: sqlite3.Connection) -> None:
|
| 251 |
+
cur = con.cursor()
|
| 252 |
+
tables = cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;").fetchall()
|
| 253 |
+
print("Tables:", [t[0] for t in tables])
|
| 254 |
+
|
| 255 |
+
for t in ["students", "courses", "enrollments", "attendance"]:
|
| 256 |
+
n = cur.execute(f"SELECT COUNT(*) FROM {t};").fetchone()[0]
|
| 257 |
+
print(f"{t}: {n}")
|
| 258 |
+
|
| 259 |
+
# A couple example queries
|
| 260 |
+
print("\nExample: Top 5 students by avg score (latest semester)")
|
| 261 |
+
latest = cur.execute("SELECT semester FROM enrollments ORDER BY semester DESC LIMIT 1;").fetchone()[0]
|
| 262 |
+
rows = cur.execute(
|
| 263 |
+
"""
|
| 264 |
+
SELECT s.name, s.program, ROUND(AVG(e.score), 2) AS avg_score
|
| 265 |
+
FROM students s
|
| 266 |
+
JOIN enrollments e ON e.student_id = s.student_id
|
| 267 |
+
WHERE e.semester = ?
|
| 268 |
+
GROUP BY s.student_id
|
| 269 |
+
ORDER BY avg_score DESC
|
| 270 |
+
LIMIT 5;
|
| 271 |
+
""",
|
| 272 |
+
(latest,),
|
| 273 |
+
).fetchall()
|
| 274 |
+
for r in rows:
|
| 275 |
+
print(r)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def main() -> None:
    """Create the database file, rebuild the schema, and seed demo data."""
    rng = random.Random(SEED)
    target = Path(DB_NAME).resolve()

    con = connect(target)
    try:
        # Order matters: schema first, then parent tables before the
        # enrollment/attendance rows that reference them.
        recreate_schema(con)
        seed_students(con, rng)
        seed_courses(con, rng)
        seed_enrollments_and_attendance(con, rng)
        create_views(con)
        print(f"Created DB: {target}")
        print_summary(con)
    finally:
        con.close()
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# Script entry point: build and seed the demo database when run directly.
if __name__ == "__main__":
    main()
|