Vanshcc commited on
Commit
8823302
·
verified ·
1 Parent(s): 6ca635c

add more context awareness and add graphing

Browse files
Files changed (5) hide show
  1. app.py +27 -34
  2. chatbot.py +17 -3
  3. router.py +9 -2
  4. sql/generator.py +6 -5
  5. viz_utils.py +64 -0
app.py CHANGED
@@ -34,6 +34,7 @@ from database.connection import DatabaseConnection
34
  from llm import create_llm_client
35
  from chatbot import create_chatbot, DatabaseChatbot
36
  from memory import ChatMemory, EnhancedChatMemory
 
37
 
38
 
39
  # Groq models (all FREE!)
@@ -557,7 +558,7 @@ def render_chat_interface():
557
 
558
  with chat_container:
559
  # Display messages
560
- for msg in st.session_state.messages:
561
  with st.chat_message(msg["role"]):
562
  st.markdown(msg["content"])
563
 
@@ -569,6 +570,9 @@ def render_chat_interface():
569
  if meta.get("sql_query"):
570
  with st.expander("SQL Query"):
571
  st.code(meta["sql_query"], language="sql")
 
 
 
572
 
573
  # Chat input
574
  if prompt := st.chat_input("Ask about your data..."):
@@ -581,43 +585,32 @@ def render_chat_interface():
581
  if st.session_state.memory:
582
  st.session_state.memory.add_message("user", prompt)
583
 
 
584
  with st.chat_message("user"):
585
  st.markdown(prompt)
586
 
587
  # Get response
588
- with st.chat_message("assistant"):
589
- with st.spinner("Thinking..."):
590
- response = st.session_state.chatbot.chat(
591
- prompt,
592
- st.session_state.memory,
593
- ignored_tables=list(st.session_state.ignored_tables)
594
- )
595
-
596
- st.markdown(response.answer)
597
-
598
- # Show metadata
599
- if response.query_type != "general":
600
- st.caption(f"Query type: {response.query_type}")
601
-
602
- if response.sql_query:
603
- with st.expander("SQL Query"):
604
- st.code(response.sql_query, language="sql")
605
-
606
- if response.sql_results:
607
- with st.expander("Results"):
608
- st.dataframe(response.sql_results)
609
-
610
- # Save to memory
611
- st.session_state.messages.append({
612
- "role": "assistant",
613
- "content": response.answer,
614
- "metadata": {
615
- "query_type": response.query_type,
616
- "sql_query": response.sql_query
617
- }
618
- })
619
- if st.session_state.memory:
620
- st.session_state.memory.add_message("assistant", response.answer)
621
 
622
 
623
  def main():
 
34
  from llm import create_llm_client
35
  from chatbot import create_chatbot, DatabaseChatbot
36
  from memory import ChatMemory, EnhancedChatMemory
37
+ from viz_utils import render_visualization
38
 
39
 
40
  # Groq models (all FREE!)
 
558
 
559
  with chat_container:
560
  # Display messages
561
+ for i, msg in enumerate(st.session_state.messages):
562
  with st.chat_message(msg["role"]):
563
  st.markdown(msg["content"])
564
 
 
570
  if meta.get("sql_query"):
571
  with st.expander("SQL Query"):
572
  st.code(meta["sql_query"], language="sql")
573
+
574
+ if meta.get("sql_results"):
575
+ render_visualization(meta["sql_results"], f"hist_{i}")
576
 
577
  # Chat input
578
  if prompt := st.chat_input("Ask about your data..."):
 
585
  if st.session_state.memory:
586
  st.session_state.memory.add_message("user", prompt)
587
 
588
+ # Display user message immediately
589
  with st.chat_message("user"):
590
  st.markdown(prompt)
591
 
592
  # Get response
593
+ with st.spinner("Thinking..."):
594
+ response = st.session_state.chatbot.chat(
595
+ prompt,
596
+ st.session_state.memory,
597
+ ignored_tables=list(st.session_state.ignored_tables)
598
+ )
599
+
600
+ # Save to memory
601
+ st.session_state.messages.append({
602
+ "role": "assistant",
603
+ "content": response.answer,
604
+ "metadata": {
605
+ "query_type": response.query_type,
606
+ "sql_query": response.sql_query,
607
+ "sql_results": response.sql_results
608
+ }
609
+ })
610
+ if st.session_state.memory:
611
+ st.session_state.memory.add_message("assistant", response.answer)
612
+
613
+ st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
614
 
615
 
616
  def main():
chatbot.py CHANGED
@@ -57,6 +57,12 @@ INSTRUCTIONS:
57
  - Be concise but complete
58
  - Format data nicely
59
 
 
 
 
 
 
 
60
  YOUR RESPONSE:"""
61
 
62
  def __init__(self, llm_client: Optional[LLMClient] = None):
@@ -238,11 +244,11 @@ YOUR RESPONSE:"""
238
  else:
239
  return ChatResponse(answer="⚠️ Nothing previous to save. Tell me something to remember first!", query_type="memory")
240
 
241
- # Route the query
242
- routing = self.router.route(query, schema_context)
243
-
244
  # Get chat history for context
245
  history = memory.get_context_messages(5) if memory else []
 
 
 
246
 
247
  # Process based on route
248
  if routing.query_type == QueryType.RAG:
@@ -260,6 +266,14 @@ YOUR RESPONSE:"""
260
 
261
  def _handle_rag(self, query: str, history: List[Dict], allowed_tables: Optional[List[str]] = None) -> ChatResponse:
262
  """Handle RAG-based query."""
 
 
 
 
 
 
 
 
263
  context = self.rag_engine.get_context(query, top_k=5, table_filter=allowed_tables)
264
 
265
  prompt = self.RESPONSE_PROMPT.format(context=f"RELEVANT DATA:\n{context}", question=query)
 
57
  - Be concise but complete
58
  - Format data nicely
59
 
60
+ INTERACTION GUIDELINES:
61
+ - If the SQL results show a list (e.g., top products) and hit the limit (5, 10, or 50), MENTION this and ASK the user if they want to see more or a specific number.
62
+ Example: "Here are the top 5 products... Would you like to see the top 10?"
63
+ - If the user's question was broad (e.g., "Show me products") and you're showing a limited set, ASK if they want to filter by a specific attribute (e.g., "Would you like to filter by category or price?").
64
+ - If the answer is "0 results" for a "top/best" query, suggest looking at the data generally.
65
+
66
  YOUR RESPONSE:"""
67
 
68
  def __init__(self, llm_client: Optional[LLMClient] = None):
 
244
  else:
245
  return ChatResponse(answer="⚠️ Nothing previous to save. Tell me something to remember first!", query_type="memory")
246
 
 
 
 
247
  # Get chat history for context
248
  history = memory.get_context_messages(5) if memory else []
249
+
250
+ # Route the query
251
+ routing = self.router.route(query, schema_context, history)
252
 
253
  # Process based on route
254
  if routing.query_type == QueryType.RAG:
 
266
 
267
  def _handle_rag(self, query: str, history: List[Dict], allowed_tables: Optional[List[str]] = None) -> ChatResponse:
268
  """Handle RAG-based query."""
269
+ # Check if we have any indexed data
270
+ if self.rag_engine.document_count == 0:
271
+ return ChatResponse(
272
+ answer="⚠️ **I can't answer this yet.**\n\nThis looks like a semantic question (searching for meaning/concepts), but you haven't **indexed the text data** yet.\n\nPlease click the **'📚 Index Text Data'** button in the sidebar to enable this functionality.",
273
+ query_type="error",
274
+ error="RAG index is empty"
275
+ )
276
+
277
  context = self.rag_engine.get_context(query, top_k=5, table_filter=allowed_tables)
278
 
279
  prompt = self.RESPONSE_PROMPT.format(context=f"RELEVANT DATA:\n{context}", question=query)
router.py CHANGED
@@ -48,6 +48,7 @@ Determine if this query needs:
48
  4. GENERAL - General conversation not requiring database access
49
 
50
  IMPORTANT: If the user asks to "show more", "show other", "see remaining", "next results", or similar - this is a PAGINATION request and should be routed to SQL, NOT GENERAL.
 
51
 
52
  Respond in this exact format:
53
  TYPE: [RAG|SQL|HYBRID|GENERAL]
@@ -61,13 +62,19 @@ REASONING: [brief explanation]"""
61
  def set_llm_client(self, llm_client):
62
  self.llm_client = llm_client
63
 
64
- def route(self, query: str, schema_context: str) -> RoutingDecision:
65
  """Analyze query and determine routing."""
66
  if not self.llm_client:
67
  # Fallback to simple heuristics
68
  return self._heuristic_route(query)
69
 
70
- prompt = self.ROUTING_PROMPT.format(schema=schema_context, query=query)
 
 
 
 
 
 
71
 
72
  try:
73
  response = self.llm_client.chat([
 
48
  4. GENERAL - General conversation not requiring database access
49
 
50
  IMPORTANT: If the user asks to "show more", "show other", "see remaining", "next results", or similar - this is a PAGINATION request and should be routed to SQL, NOT GENERAL.
51
+ 5. REFERENTIAL/AFFIRMATIVE: If the query is simply "yes", "sure", "ok", "please", or "do it", check if it's likely a confirmation to a previous offer (like "would you like to see 10 more?"). If so, this is likely SQL (pagination or new query). If ambiguous, default to GENERAL.
52
 
53
  Respond in this exact format:
54
  TYPE: [RAG|SQL|HYBRID|GENERAL]
 
62
  def set_llm_client(self, llm_client):
63
  self.llm_client = llm_client
64
 
65
+ def route(self, query: str, schema_context: str, chat_history: Optional[List[Dict]] = None) -> RoutingDecision:
66
  """Analyze query and determine routing."""
67
  if not self.llm_client:
68
  # Fallback to simple heuristics
69
  return self._heuristic_route(query)
70
 
71
+ prev_context = ""
72
+ if chat_history and len(chat_history) > 0:
73
+ last_msg = chat_history[-1]
74
+ if last_msg.get("role") == "assistant":
75
+ prev_context = f"\nPREVIOUS ASSISTANT MSG: {last_msg.get('content', '')[:200]}..."
76
+
77
+ prompt = self.ROUTING_PROMPT.format(schema=schema_context, query=query + prev_context)
78
 
79
  try:
80
  response = self.llm_client.chat([
sql/generator.py CHANGED
@@ -61,15 +61,16 @@ class SQLGenerator:
61
  RULES:
62
  1. ONLY generate SELECT statements.
63
  2. NEVER use INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, or TRUNCATE.
64
- 3. Always include a LIMIT clause (max 50 rows unless specified).
65
  4. Use table and column names EXACTLY as shown in the schema.
66
- 5. AMBIGUITY: If the user asks for a category, type, or specific value, and you are unsure which column it belongs to:
 
67
  - Check multiple likely columns (e.g., `category`, `sub_category`, `type`, `description`).
68
  - Use pattern matching for flexibility.
69
  - Use `OR` to combine multiple column checks.
70
- 6. DATA AWARENESS: In footwear databases, specific types like 'Formal', 'Casual', or 'Sports' often appear in `sub_category` OR `category`. Check both if available.
71
- 7. Return ONLY the SQL query, no explanations.
72
- 8. PAGINATION: If the user asks to "show more", "show other", "see remaining", or similar follow-up:
73
  - Look at the previous conversation for the original query conditions.
74
  - Use LIMIT with OFFSET to get the next set of results (e.g., LIMIT 10 OFFSET 10 for the second page).
75
  - Keep the same WHERE conditions from the previous query.
 
61
  RULES:
62
  1. ONLY generate SELECT statements.
63
  2. NEVER use INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, or TRUNCATE.
64
+ 3. Always include a LIMIT clause (max 50 rows unless specified). Do NOT use LIMIT 1 for "top" or "best" queries unless explicitly asked for "single" or "one".
65
  4. Use table and column names EXACTLY as shown in the schema.
66
+ 5. TOP/BEST ITEMS: When asked for 'top', 'highest', or 'best' items (e.g. 'top rated products'), use LIMIT 5 or LIMIT 10 to show potential ties or multiple top candidates. Never use LIMIT 1 for these unless the user explicitly asks for "the number one" or "single best".
67
+ 6. AMBIGUITY: If the user asks for a category, type, or specific value, and you are unsure which column it belongs to:
68
  - Check multiple likely columns (e.g., `category`, `sub_category`, `type`, `description`).
69
  - Use pattern matching for flexibility.
70
  - Use `OR` to combine multiple column checks.
71
+ 7. DATA AWARENESS: In footwear databases, specific types like 'Formal', 'Casual', or 'Sports' often appear in `sub_category` OR `category`. Check both if available.
72
+ 8. Return ONLY the SQL query, no explanations.
73
+ 9. PAGINATION: If the user asks to "show more", "show other", "see remaining", or similar follow-up:
74
  - Look at the previous conversation for the original query conditions.
75
  - Use LIMIT with OFFSET to get the next set of results (e.g., LIMIT 10 OFFSET 10 for the second page).
76
  - Keep the same WHERE conditions from the previous query.
viz_utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+ import pandas as pd
4
+
5
+ def render_visualization(results, key_prefix):
6
+ """Render data tables and visualizations from SQL results."""
7
+ if not results:
8
+ return
9
+
10
+ # Convert to DataFrame
11
+ df = pd.DataFrame(results)
12
+
13
+ with st.expander("📊 Results & Visualization", expanded=False):
14
+ tab_data, tab_viz = st.tabs(["📄 Data", "📈 Visualize"])
15
+
16
+ with tab_data:
17
+ st.dataframe(df, use_container_width=True)
18
+
19
+ with tab_viz:
20
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
21
+ categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
22
+
23
+ if not numeric_cols:
24
+ st.info("No numeric data found to visualize.")
25
+ else:
26
+ col1, col2, col3 = st.columns(3)
27
+ with col1:
28
+ chart_type = st.selectbox(
29
+ "Chart Type",
30
+ ["Bar", "Line", "Area", "Scatter"],
31
+ key=f"{key_prefix}_chart_type"
32
+ )
33
+ with col2:
34
+ # Default X axis logic
35
+ x_options = df.columns.tolist()
36
+ default_x = categorical_cols[0] if categorical_cols else x_options[0]
37
+ # Find index safely
38
+ try:
39
+ def_index = x_options.index(default_x)
40
+ except ValueError:
41
+ def_index = 0
42
+
43
+ x_axis = st.selectbox(
44
+ "X Axis",
45
+ x_options,
46
+ index=def_index,
47
+ key=f"{key_prefix}_x_axis"
48
+ )
49
+ with col3:
50
+ y_axis = st.selectbox(
51
+ "Y Axis",
52
+ numeric_cols,
53
+ index=0,
54
+ key=f"{key_prefix}_y_axis"
55
+ )
56
+
57
+ if chart_type == "Bar":
58
+ st.bar_chart(df, x=x_axis, y=y_axis)
59
+ elif chart_type == "Line":
60
+ st.line_chart(df, x=x_axis, y=y_axis)
61
+ elif chart_type == "Area":
62
+ st.area_chart(df, x=x_axis, y=y_axis)
63
+ elif chart_type == "Scatter":
64
+ st.scatter_chart(df, x=x_axis, y=y_axis)