Spaces:
Running
Running
add more context awareness and add graphing
Browse files
- app.py +27 -34
- chatbot.py +17 -3
- router.py +9 -2
- sql/generator.py +6 -5
- viz_utils.py +64 -0
app.py
CHANGED
|
@@ -34,6 +34,7 @@ from database.connection import DatabaseConnection
|
|
| 34 |
from llm import create_llm_client
|
| 35 |
from chatbot import create_chatbot, DatabaseChatbot
|
| 36 |
from memory import ChatMemory, EnhancedChatMemory
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
# Groq models (all FREE!)
|
|
@@ -557,7 +558,7 @@ def render_chat_interface():
|
|
| 557 |
|
| 558 |
with chat_container:
|
| 559 |
# Display messages
|
| 560 |
-
for msg in st.session_state.messages:
|
| 561 |
with st.chat_message(msg["role"]):
|
| 562 |
st.markdown(msg["content"])
|
| 563 |
|
|
@@ -569,6 +570,9 @@ def render_chat_interface():
|
|
| 569 |
if meta.get("sql_query"):
|
| 570 |
with st.expander("SQL Query"):
|
| 571 |
st.code(meta["sql_query"], language="sql")
|
|
|
|
|
|
|
|
|
|
| 572 |
|
| 573 |
# Chat input
|
| 574 |
if prompt := st.chat_input("Ask about your data..."):
|
|
@@ -581,43 +585,32 @@ def render_chat_interface():
|
|
| 581 |
if st.session_state.memory:
|
| 582 |
st.session_state.memory.add_message("user", prompt)
|
| 583 |
|
|
|
|
| 584 |
with st.chat_message("user"):
|
| 585 |
st.markdown(prompt)
|
| 586 |
|
| 587 |
# Get response
|
| 588 |
-
with st.
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
# Save to memory
|
| 611 |
-
st.session_state.messages.append({
|
| 612 |
-
"role": "assistant",
|
| 613 |
-
"content": response.answer,
|
| 614 |
-
"metadata": {
|
| 615 |
-
"query_type": response.query_type,
|
| 616 |
-
"sql_query": response.sql_query
|
| 617 |
-
}
|
| 618 |
-
})
|
| 619 |
-
if st.session_state.memory:
|
| 620 |
-
st.session_state.memory.add_message("assistant", response.answer)
|
| 621 |
|
| 622 |
|
| 623 |
def main():
|
|
|
|
| 34 |
from llm import create_llm_client
|
| 35 |
from chatbot import create_chatbot, DatabaseChatbot
|
| 36 |
from memory import ChatMemory, EnhancedChatMemory
|
| 37 |
+
from viz_utils import render_visualization
|
| 38 |
|
| 39 |
|
| 40 |
# Groq models (all FREE!)
|
|
|
|
| 558 |
|
| 559 |
with chat_container:
|
| 560 |
# Display messages
|
| 561 |
+
for i, msg in enumerate(st.session_state.messages):
|
| 562 |
with st.chat_message(msg["role"]):
|
| 563 |
st.markdown(msg["content"])
|
| 564 |
|
|
|
|
| 570 |
if meta.get("sql_query"):
|
| 571 |
with st.expander("SQL Query"):
|
| 572 |
st.code(meta["sql_query"], language="sql")
|
| 573 |
+
|
| 574 |
+
if meta.get("sql_results"):
|
| 575 |
+
render_visualization(meta["sql_results"], f"hist_{i}")
|
| 576 |
|
| 577 |
# Chat input
|
| 578 |
if prompt := st.chat_input("Ask about your data..."):
|
|
|
|
| 585 |
if st.session_state.memory:
|
| 586 |
st.session_state.memory.add_message("user", prompt)
|
| 587 |
|
| 588 |
+
# Display user message immediately
|
| 589 |
with st.chat_message("user"):
|
| 590 |
st.markdown(prompt)
|
| 591 |
|
| 592 |
# Get response
|
| 593 |
+
with st.spinner("Thinking..."):
|
| 594 |
+
response = st.session_state.chatbot.chat(
|
| 595 |
+
prompt,
|
| 596 |
+
st.session_state.memory,
|
| 597 |
+
ignored_tables=list(st.session_state.ignored_tables)
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
# Save to memory
|
| 601 |
+
st.session_state.messages.append({
|
| 602 |
+
"role": "assistant",
|
| 603 |
+
"content": response.answer,
|
| 604 |
+
"metadata": {
|
| 605 |
+
"query_type": response.query_type,
|
| 606 |
+
"sql_query": response.sql_query,
|
| 607 |
+
"sql_results": response.sql_results
|
| 608 |
+
}
|
| 609 |
+
})
|
| 610 |
+
if st.session_state.memory:
|
| 611 |
+
st.session_state.memory.add_message("assistant", response.answer)
|
| 612 |
+
|
| 613 |
+
st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 614 |
|
| 615 |
|
| 616 |
def main():
|
chatbot.py
CHANGED
|
@@ -57,6 +57,12 @@ INSTRUCTIONS:
|
|
| 57 |
- Be concise but complete
|
| 58 |
- Format data nicely
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
YOUR RESPONSE:"""
|
| 61 |
|
| 62 |
def __init__(self, llm_client: Optional[LLMClient] = None):
|
|
@@ -238,11 +244,11 @@ YOUR RESPONSE:"""
|
|
| 238 |
else:
|
| 239 |
return ChatResponse(answer="⚠️ Nothing previous to save. Tell me something to remember first!", query_type="memory")
|
| 240 |
|
| 241 |
-
# Route the query
|
| 242 |
-
routing = self.router.route(query, schema_context)
|
| 243 |
-
|
| 244 |
# Get chat history for context
|
| 245 |
history = memory.get_context_messages(5) if memory else []
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
# Process based on route
|
| 248 |
if routing.query_type == QueryType.RAG:
|
|
@@ -260,6 +266,14 @@ YOUR RESPONSE:"""
|
|
| 260 |
|
| 261 |
def _handle_rag(self, query: str, history: List[Dict], allowed_tables: Optional[List[str]] = None) -> ChatResponse:
|
| 262 |
"""Handle RAG-based query."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
context = self.rag_engine.get_context(query, top_k=5, table_filter=allowed_tables)
|
| 264 |
|
| 265 |
prompt = self.RESPONSE_PROMPT.format(context=f"RELEVANT DATA:\n{context}", question=query)
|
|
|
|
| 57 |
- Be concise but complete
|
| 58 |
- Format data nicely
|
| 59 |
|
| 60 |
+
INTERACTION GUIDELINES:
|
| 61 |
+
- If the SQL results show a list (e.g., top products) and hit the limit (5, 10, or 50), MENTION this and ASK the user if they want to see more or a specific number.
|
| 62 |
+
Example: "Here are the top 5 products... Would you like to see the top 10?"
|
| 63 |
+
- If the user's question was broad (e.g., "Show me products") and you're showing a limited set, ASK if they want to filter by a specific attribute (e.g., "Would you like to filter by category or price?").
|
| 64 |
+
- If the answer is "0 results" for a "top/best" query, suggest looking at the data generally.
|
| 65 |
+
|
| 66 |
YOUR RESPONSE:"""
|
| 67 |
|
| 68 |
def __init__(self, llm_client: Optional[LLMClient] = None):
|
|
|
|
| 244 |
else:
|
| 245 |
return ChatResponse(answer="⚠️ Nothing previous to save. Tell me something to remember first!", query_type="memory")
|
| 246 |
|
|
|
|
|
|
|
|
|
|
| 247 |
# Get chat history for context
|
| 248 |
history = memory.get_context_messages(5) if memory else []
|
| 249 |
+
|
| 250 |
+
# Route the query
|
| 251 |
+
routing = self.router.route(query, schema_context, history)
|
| 252 |
|
| 253 |
# Process based on route
|
| 254 |
if routing.query_type == QueryType.RAG:
|
|
|
|
| 266 |
|
| 267 |
def _handle_rag(self, query: str, history: List[Dict], allowed_tables: Optional[List[str]] = None) -> ChatResponse:
|
| 268 |
"""Handle RAG-based query."""
|
| 269 |
+
# Check if we have any indexed data
|
| 270 |
+
if self.rag_engine.document_count == 0:
|
| 271 |
+
return ChatResponse(
|
| 272 |
+
answer="⚠️ **I can't answer this yet.**\n\nThis looks like a semantic question (searching for meaning/concepts), but you haven't **indexed the text data** yet.\n\nPlease click the **'📚 Index Text Data'** button in the sidebar to enable this functionality.",
|
| 273 |
+
query_type="error",
|
| 274 |
+
error="RAG index is empty"
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
context = self.rag_engine.get_context(query, top_k=5, table_filter=allowed_tables)
|
| 278 |
|
| 279 |
prompt = self.RESPONSE_PROMPT.format(context=f"RELEVANT DATA:\n{context}", question=query)
|
router.py
CHANGED
|
@@ -48,6 +48,7 @@ Determine if this query needs:
|
|
| 48 |
4. GENERAL - General conversation not requiring database access
|
| 49 |
|
| 50 |
IMPORTANT: If the user asks to "show more", "show other", "see remaining", "next results", or similar - this is a PAGINATION request and should be routed to SQL, NOT GENERAL.
|
|
|
|
| 51 |
|
| 52 |
Respond in this exact format:
|
| 53 |
TYPE: [RAG|SQL|HYBRID|GENERAL]
|
|
@@ -61,13 +62,19 @@ REASONING: [brief explanation]"""
|
|
| 61 |
def set_llm_client(self, llm_client):
|
| 62 |
self.llm_client = llm_client
|
| 63 |
|
| 64 |
-
def route(self, query: str, schema_context: str) -> RoutingDecision:
|
| 65 |
"""Analyze query and determine routing."""
|
| 66 |
if not self.llm_client:
|
| 67 |
# Fallback to simple heuristics
|
| 68 |
return self._heuristic_route(query)
|
| 69 |
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
try:
|
| 73 |
response = self.llm_client.chat([
|
|
|
|
| 48 |
4. GENERAL - General conversation not requiring database access
|
| 49 |
|
| 50 |
IMPORTANT: If the user asks to "show more", "show other", "see remaining", "next results", or similar - this is a PAGINATION request and should be routed to SQL, NOT GENERAL.
|
| 51 |
+
5. REFERENTIAL/AFFIRMATIVE: If the query is simply "yes", "sure", "ok", "please", or "do it", check if it's likely a confirmation to a previous offer (like "would you like to see 10 more?"). If so, this is likely SQL (pagination or new query). If ambiguous, default to GENERAL.
|
| 52 |
|
| 53 |
Respond in this exact format:
|
| 54 |
TYPE: [RAG|SQL|HYBRID|GENERAL]
|
|
|
|
| 62 |
def set_llm_client(self, llm_client):
|
| 63 |
self.llm_client = llm_client
|
| 64 |
|
| 65 |
+
def route(self, query: str, schema_context: str, chat_history: Optional[List[Dict]] = None) -> RoutingDecision:
|
| 66 |
"""Analyze query and determine routing."""
|
| 67 |
if not self.llm_client:
|
| 68 |
# Fallback to simple heuristics
|
| 69 |
return self._heuristic_route(query)
|
| 70 |
|
| 71 |
+
prev_context = ""
|
| 72 |
+
if chat_history and len(chat_history) > 0:
|
| 73 |
+
last_msg = chat_history[-1]
|
| 74 |
+
if last_msg.get("role") == "assistant":
|
| 75 |
+
prev_context = f"\nPREVIOUS ASSISTANT MSG: {last_msg.get('content', '')[:200]}..."
|
| 76 |
+
|
| 77 |
+
prompt = self.ROUTING_PROMPT.format(schema=schema_context, query=query + prev_context)
|
| 78 |
|
| 79 |
try:
|
| 80 |
response = self.llm_client.chat([
|
sql/generator.py
CHANGED
|
@@ -61,15 +61,16 @@ class SQLGenerator:
|
|
| 61 |
RULES:
|
| 62 |
1. ONLY generate SELECT statements.
|
| 63 |
2. NEVER use INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, or TRUNCATE.
|
| 64 |
-
3. Always include a LIMIT clause (max 50 rows unless specified).
|
| 65 |
4. Use table and column names EXACTLY as shown in the schema.
|
| 66 |
-
5.
|
|
|
|
| 67 |
- Check multiple likely columns (e.g., `category`, `sub_category`, `type`, `description`).
|
| 68 |
- Use pattern matching for flexibility.
|
| 69 |
- Use `OR` to combine multiple column checks.
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
- Look at the previous conversation for the original query conditions.
|
| 74 |
- Use LIMIT with OFFSET to get the next set of results (e.g., LIMIT 10 OFFSET 10 for the second page).
|
| 75 |
- Keep the same WHERE conditions from the previous query.
|
|
|
|
| 61 |
RULES:
|
| 62 |
1. ONLY generate SELECT statements.
|
| 63 |
2. NEVER use INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, or TRUNCATE.
|
| 64 |
+
3. Always include a LIMIT clause (max 50 rows unless specified). Do NOT use LIMIT 1 for "top" or "best" queries unless explicitly asked for "single" or "one".
|
| 65 |
4. Use table and column names EXACTLY as shown in the schema.
|
| 66 |
+
5. TOP/BEST ITEMS: When asked for 'top', 'highest', or 'best' items (e.g. 'top rated products'), use LIMIT 5 or LIMIT 10 to show potential ties or multiple top candidates. Never use LIMIT 1 for these unless the user explicitly asks for "the number one" or "single best".
|
| 67 |
+
6. AMBIGUITY: If the user asks for a category, type, or specific value, and you are unsure which column it belongs to:
|
| 68 |
- Check multiple likely columns (e.g., `category`, `sub_category`, `type`, `description`).
|
| 69 |
- Use pattern matching for flexibility.
|
| 70 |
- Use `OR` to combine multiple column checks.
|
| 71 |
+
7. DATA AWARENESS: In footwear databases, specific types like 'Formal', 'Casual', or 'Sports' often appear in `sub_category` OR `category`. Check both if available.
|
| 72 |
+
8. Return ONLY the SQL query, no explanations.
|
| 73 |
+
9. PAGINATION: If the user asks to "show more", "show other", "see remaining", or similar follow-up:
|
| 74 |
- Look at the previous conversation for the original query conditions.
|
| 75 |
- Use LIMIT with OFFSET to get the next set of results (e.g., LIMIT 10 OFFSET 10 for the second page).
|
| 76 |
- Keep the same WHERE conditions from the previous query.
|
viz_utils.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import pandas as pd
|
| 4 |
+
|
| 5 |
+
def render_visualization(results, key_prefix: str) -> None:
    """Render a collapsible data table and an interactive chart for query results.

    The output is a Streamlit expander with two tabs: the raw data as a
    table, and a chart whose type, X axis and Y axis are chosen through
    select boxes.

    Args:
        results: Tabular data accepted by ``pd.DataFrame`` (presumably the
            list of row dicts stored as ``sql_results`` -- confirm against
            the caller), or an already-built DataFrame. ``None`` and empty
            inputs render nothing.
        key_prefix: Prefix used to build the widget keys, so several
            renderings on the same page get distinct widget identities.
    """
    if results is None:
        return

    if isinstance(results, pd.DataFrame):
        # A DataFrame has no unambiguous truth value; `not df` would raise.
        df = results
    elif not results:
        return
    else:
        df = pd.DataFrame(results)

    # Nothing to tabulate or plot (no rows or no columns).
    if df.empty:
        return

    with st.expander("📊 Results & Visualization", expanded=False):
        tab_data, tab_viz = st.tabs(["📄 Data", "📈 Visualize"])

        with tab_data:
            st.dataframe(df, use_container_width=True)

        with tab_viz:
            numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
            categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()

            if not numeric_cols:
                st.info("No numeric data found to visualize.")
                return

            col1, col2, col3 = st.columns(3)

            with col1:
                chart_type = st.selectbox(
                    "Chart Type",
                    ["Bar", "Line", "Area", "Scatter"],
                    key=f"{key_prefix}_chart_type",
                )

            with col2:
                # Prefer a non-numeric column for X; fall back to the first column.
                x_options = df.columns.tolist()
                default_x = categorical_cols[0] if categorical_cols else x_options[0]
                x_index = x_options.index(default_x) if default_x in x_options else 0

                x_axis = st.selectbox(
                    "X Axis",
                    x_options,
                    index=x_index,
                    key=f"{key_prefix}_x_axis",
                )

            with col3:
                # When every column is numeric the default X is itself a
                # numeric column; pick a different one for Y if possible so
                # the default chart is not a column plotted against itself.
                y_index = next(
                    (i for i, col in enumerate(numeric_cols) if col != default_x),
                    0,
                )
                y_axis = st.selectbox(
                    "Y Axis",
                    numeric_cols,
                    index=y_index,
                    key=f"{key_prefix}_y_axis",
                )

            # Resolve the chart function lazily by name so that only the
            # selected chart API has to exist on the installed Streamlit.
            chart_methods = {
                "Bar": "bar_chart",
                "Line": "line_chart",
                "Area": "area_chart",
                "Scatter": "scatter_chart",
            }
            method_name = chart_methods.get(chart_type)
            if method_name is not None:
                getattr(st, method_name)(df, x=x_axis, y=y_axis)
|