Spaces:
Running
Running
Upload 13 files
Browse files- app.py +47 -23
- chatbot.py +70 -14
- llm/__pycache__/__init__.cpython-311.pyc +0 -0
- llm/__pycache__/client.cpython-311.pyc +0 -0
- llm/client.py +37 -9
- router.py +4 -2
- sql/__pycache__/__init__.cpython-311.pyc +0 -0
- sql/__pycache__/generator.cpython-311.pyc +0 -0
- sql/__pycache__/validator.cpython-311.pyc +0 -0
- sql/generator.py +39 -4
app.py
CHANGED
|
@@ -565,14 +565,24 @@ def render_chat_interface():
|
|
| 565 |
# Show metadata for assistant messages
|
| 566 |
if msg["role"] == "assistant" and "metadata" in msg:
|
| 567 |
meta = msg["metadata"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
if meta.get("query_type"):
|
| 569 |
st.caption(f"Query type: {meta['query_type']}")
|
|
|
|
|
|
|
| 570 |
if meta.get("sql_query"):
|
| 571 |
-
with st.expander("SQL Query"):
|
| 572 |
st.code(meta["sql_query"], language="sql")
|
| 573 |
-
|
|
|
|
| 574 |
if meta.get("sql_results"):
|
| 575 |
-
|
|
|
|
| 576 |
|
| 577 |
# Chat input
|
| 578 |
if prompt := st.chat_input("Ask about your data..."):
|
|
@@ -582,35 +592,49 @@ def render_chat_interface():
|
|
| 582 |
|
| 583 |
# Add user message
|
| 584 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 585 |
-
if st.session_state.memory:
|
| 586 |
-
st.session_state.memory.add_message("user", prompt)
|
| 587 |
|
| 588 |
-
#
|
| 589 |
with st.chat_message("user"):
|
| 590 |
st.markdown(prompt)
|
| 591 |
|
| 592 |
# Get response
|
| 593 |
with st.spinner("Thinking..."):
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
st.session_state.memory
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
|
|
|
|
|
|
| 605 |
"query_type": response.query_type,
|
| 606 |
"sql_query": response.sql_query,
|
| 607 |
-
"sql_results": response.sql_results
|
|
|
|
| 608 |
}
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
st.session_state.
|
| 612 |
-
|
| 613 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 614 |
|
| 615 |
|
| 616 |
def main():
|
|
|
|
| 565 |
# Show metadata for assistant messages
|
| 566 |
if msg["role"] == "assistant" and "metadata" in msg:
|
| 567 |
meta = msg["metadata"]
|
| 568 |
+
|
| 569 |
+
# Show token usage in a box
|
| 570 |
+
if "token_usage" in meta:
|
| 571 |
+
usage = meta["token_usage"]
|
| 572 |
+
st.info(f"📊 **Token Usage**: Input: {usage.get('input', 0)} | Output: {usage.get('output', 0)} | Total: {usage.get('total', 0)}")
|
| 573 |
+
|
| 574 |
if meta.get("query_type"):
|
| 575 |
st.caption(f"Query type: {meta['query_type']}")
|
| 576 |
+
|
| 577 |
+
# SQL Query expander
|
| 578 |
if meta.get("sql_query"):
|
| 579 |
+
with st.expander("🛠️ SQL Query & Details"):
|
| 580 |
st.code(meta["sql_query"], language="sql")
|
| 581 |
+
|
| 582 |
+
# Visualizations
|
| 583 |
if meta.get("sql_results"):
|
| 584 |
+
# Only render viz if we have results
|
| 585 |
+
render_visualization(meta["sql_results"], f"viz_{i}")
|
| 586 |
|
| 587 |
# Chat input
|
| 588 |
if prompt := st.chat_input("Ask about your data..."):
|
|
|
|
| 592 |
|
| 593 |
# Add user message
|
| 594 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
|
|
|
|
|
|
| 595 |
|
| 596 |
+
# Render the user's message right away; no memory context is computed for display here
|
| 597 |
with st.chat_message("user"):
|
| 598 |
st.markdown(prompt)
|
| 599 |
|
| 600 |
# Get response
|
| 601 |
with st.spinner("Thinking..."):
|
| 602 |
+
try:
|
| 603 |
+
# Add memory interaction
|
| 604 |
+
if st.session_state.memory:
|
| 605 |
+
st.session_state.memory.add_message("user", prompt)
|
| 606 |
+
|
| 607 |
+
response = st.session_state.chatbot.chat(
|
| 608 |
+
prompt,
|
| 609 |
+
st.session_state.memory,
|
| 610 |
+
ignored_tables=list(st.session_state.ignored_tables)
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
# Create metadata dict
|
| 614 |
+
metadata = {
|
| 615 |
"query_type": response.query_type,
|
| 616 |
"sql_query": response.sql_query,
|
| 617 |
+
"sql_results": response.sql_results,
|
| 618 |
+
"token_usage": response.token_usage
|
| 619 |
}
|
| 620 |
+
|
| 621 |
+
# Save to session state
|
| 622 |
+
st.session_state.messages.append({
|
| 623 |
+
"role": "assistant",
|
| 624 |
+
"content": response.answer,
|
| 625 |
+
"metadata": metadata
|
| 626 |
+
})
|
| 627 |
+
|
| 628 |
+
# Save to active memory
|
| 629 |
+
if st.session_state.memory:
|
| 630 |
+
st.session_state.memory.add_message("assistant", response.answer)
|
| 631 |
+
|
| 632 |
+
st.rerun()
|
| 633 |
+
|
| 634 |
+
except Exception as e:
|
| 635 |
+
st.error(f"An error occurred: {e}")
|
| 636 |
+
import traceback
|
| 637 |
+
st.error(traceback.format_exc())
|
| 638 |
|
| 639 |
|
| 640 |
def main():
|
chatbot.py
CHANGED
|
@@ -32,10 +32,13 @@ class ChatResponse:
|
|
| 32 |
sql_query: Optional[str] = None
|
| 33 |
sql_results: Optional[List[Dict]] = None
|
| 34 |
error: Optional[str] = None
|
|
|
|
| 35 |
|
| 36 |
def __post_init__(self):
|
| 37 |
if self.sources is None:
|
| 38 |
self.sources = []
|
|
|
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
class DatabaseChatbot:
|
|
@@ -62,6 +65,8 @@ INTERACTION GUIDELINES:
|
|
| 62 |
Example: "Here are the top 5 products... Would you like to see the top 10?"
|
| 63 |
- If the user's question was broad (e.g., "Show me products") and you're showing a limited set, ASK if they want to filter by a specific attribute (e.g., "Would you like to filter by category or price?").
|
| 64 |
- If the answer is "0 results" for a "top/best" query, suggest looking at the data generally.
|
|
|
|
|
|
|
| 65 |
|
| 66 |
YOUR RESPONSE:"""
|
| 67 |
|
|
@@ -284,27 +289,42 @@ YOUR RESPONSE:"""
|
|
| 284 |
prompt
|
| 285 |
)
|
| 286 |
|
| 287 |
-
|
| 288 |
|
| 289 |
-
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
def _handle_sql(self, query: str, schema_context: str, history: List[Dict], allowed_tables: Optional[List[str]] = None) -> ChatResponse:
|
| 293 |
"""Handle SQL-based query."""
|
| 294 |
-
sql,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
# Validate SQL
|
| 297 |
is_valid, msg, sanitized_sql = self.sql_validator.validate(sql)
|
| 298 |
if not is_valid:
|
| 299 |
return ChatResponse(answer=f"Could not generate safe query: {msg}",
|
| 300 |
-
query_type="sql", error=msg)
|
| 301 |
|
| 302 |
# Execute query
|
| 303 |
try:
|
| 304 |
results = self.db.execute_query(sanitized_sql)
|
| 305 |
except Exception as e:
|
| 306 |
return ChatResponse(answer=f"Query execution failed: {e}",
|
| 307 |
-
query_type="sql", sql_query=sanitized_sql, error=str(e)
|
|
|
|
| 308 |
|
| 309 |
# SMART FALLBACK: If SQL returns nothing, it might be a semantic issue (e.g. wrong column)
|
| 310 |
# We try RAG as a fallback if SQL found nothing
|
|
@@ -316,6 +336,15 @@ YOUR RESPONSE:"""
|
|
| 316 |
rag_response.answer = f"I couldn't find a direct match using a database query, but here is what I found in the product descriptions:\n\n{rag_response.answer}"
|
| 317 |
rag_response.query_type = "hybrid_fallback"
|
| 318 |
rag_response.sql_query = sanitized_sql
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
return rag_response
|
| 320 |
|
| 321 |
# Generate response
|
|
@@ -328,10 +357,16 @@ YOUR RESPONSE:"""
|
|
| 328 |
prompt
|
| 329 |
)
|
| 330 |
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
|
| 333 |
-
return ChatResponse(answer=
|
| 334 |
-
sql_query=sanitized_sql, sql_results=results[:10]
|
|
|
|
| 335 |
|
| 336 |
def _handle_hybrid(self, query: str, schema_context: str, history: List[Dict], allowed_tables: Optional[List[str]] = None) -> ChatResponse:
|
| 337 |
"""Handle hybrid RAG + SQL query."""
|
|
@@ -341,8 +376,17 @@ YOUR RESPONSE:"""
|
|
| 341 |
# Try SQL as well
|
| 342 |
sql_context = ""
|
| 343 |
sql_query = None
|
|
|
|
|
|
|
|
|
|
| 344 |
try:
|
| 345 |
-
sql,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
is_valid, _, sanitized_sql = self.sql_validator.validate(sql)
|
| 347 |
if is_valid:
|
| 348 |
results = self.db.execute_query(sanitized_sql)
|
|
@@ -360,9 +404,14 @@ YOUR RESPONSE:"""
|
|
| 360 |
prompt
|
| 361 |
)
|
| 362 |
|
| 363 |
-
|
| 364 |
|
| 365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
|
| 367 |
def _construct_messages(self, system_instruction: str, history: List[Dict], user_content: str) -> List[Dict]:
|
| 368 |
"""Construct message list, merging system messages from history."""
|
|
@@ -401,8 +450,15 @@ YOUR RESPONSE:"""
|
|
| 401 |
history,
|
| 402 |
query
|
| 403 |
)
|
| 404 |
-
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
def _format_results(self, results: List[Dict], max_rows: int = 10) -> str:
|
| 408 |
"""Format SQL results for display."""
|
|
|
|
| 32 |
sql_query: Optional[str] = None
|
| 33 |
sql_results: Optional[List[Dict]] = None
|
| 34 |
error: Optional[str] = None
|
| 35 |
+
token_usage: Optional[Dict[str, int]] = None
|
| 36 |
|
| 37 |
def __post_init__(self):
|
| 38 |
if self.sources is None:
|
| 39 |
self.sources = []
|
| 40 |
+
if self.token_usage is None:
|
| 41 |
+
self.token_usage = {"input": 0, "output": 0, "total": 0}
|
| 42 |
|
| 43 |
|
| 44 |
class DatabaseChatbot:
|
|
|
|
| 65 |
Example: "Here are the top 5 products... Would you like to see the top 10?"
|
| 66 |
- If the user's question was broad (e.g., "Show me products") and you're showing a limited set, ASK if they want to filter by a specific attribute (e.g., "Would you like to filter by category or price?").
|
| 67 |
- If the answer is "0 results" for a "top/best" query, suggest looking at the data generally.
|
| 68 |
+
- IF SUBJECTIVE INFERENCE WAS USED (e.g., inferred "summer" = sandals), EXPLAIN THIS to the user.
|
| 69 |
+
Example: "I found these products that match 'summer' (based on being Sandals or breathability)..."
|
| 70 |
|
| 71 |
YOUR RESPONSE:"""
|
| 72 |
|
|
|
|
| 289 |
prompt
|
| 290 |
)
|
| 291 |
|
| 292 |
+
response = self.llm_client.chat(messages)
|
| 293 |
|
| 294 |
+
usage = {
|
| 295 |
+
"input": response.input_tokens,
|
| 296 |
+
"output": response.output_tokens,
|
| 297 |
+
"total": response.total_tokens
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
return ChatResponse(answer=response.content, query_type="rag",
|
| 301 |
+
sources=[{"type": "semantic_search", "context": context[:500]}],
|
| 302 |
+
token_usage=usage)
|
| 303 |
|
| 304 |
def _handle_sql(self, query: str, schema_context: str, history: List[Dict], allowed_tables: Optional[List[str]] = None) -> ChatResponse:
|
| 305 |
"""Handle SQL-based query."""
|
| 306 |
+
sql, gen_response = self.sql_generator.generate(query, schema_context, history)
|
| 307 |
+
|
| 308 |
+
# Initial usage from SQL generation
|
| 309 |
+
total_usage = {
|
| 310 |
+
"input": gen_response.input_tokens,
|
| 311 |
+
"output": gen_response.output_tokens,
|
| 312 |
+
"total": gen_response.total_tokens
|
| 313 |
+
}
|
| 314 |
|
| 315 |
# Validate SQL
|
| 316 |
is_valid, msg, sanitized_sql = self.sql_validator.validate(sql)
|
| 317 |
if not is_valid:
|
| 318 |
return ChatResponse(answer=f"Could not generate safe query: {msg}",
|
| 319 |
+
query_type="sql", error=msg, token_usage=total_usage)
|
| 320 |
|
| 321 |
# Execute query
|
| 322 |
try:
|
| 323 |
results = self.db.execute_query(sanitized_sql)
|
| 324 |
except Exception as e:
|
| 325 |
return ChatResponse(answer=f"Query execution failed: {e}",
|
| 326 |
+
query_type="sql", sql_query=sanitized_sql, error=str(e),
|
| 327 |
+
token_usage=total_usage)
|
| 328 |
|
| 329 |
# SMART FALLBACK: If SQL returns nothing, it might be a semantic issue (e.g. wrong column)
|
| 330 |
# We try RAG as a fallback if SQL found nothing
|
|
|
|
| 336 |
rag_response.answer = f"I couldn't find a direct match using a database query, but here is what I found in the product descriptions:\n\n{rag_response.answer}"
|
| 337 |
rag_response.query_type = "hybrid_fallback"
|
| 338 |
rag_response.sql_query = sanitized_sql
|
| 339 |
+
|
| 340 |
+
# Add usage from SQL gen to RAG usage
|
| 341 |
+
if rag_response.token_usage:
|
| 342 |
+
rag_response.token_usage["input"] += total_usage["input"]
|
| 343 |
+
rag_response.token_usage["output"] += total_usage["output"]
|
| 344 |
+
rag_response.token_usage["total"] += total_usage["total"]
|
| 345 |
+
else:
|
| 346 |
+
rag_response.token_usage = total_usage
|
| 347 |
+
|
| 348 |
return rag_response
|
| 349 |
|
| 350 |
# Generate response
|
|
|
|
| 357 |
prompt
|
| 358 |
)
|
| 359 |
|
| 360 |
+
final_response = self.llm_client.chat(messages)
|
| 361 |
+
|
| 362 |
+
# Add usage from final response
|
| 363 |
+
total_usage["input"] += final_response.input_tokens
|
| 364 |
+
total_usage["output"] += final_response.output_tokens
|
| 365 |
+
total_usage["total"] += final_response.total_tokens
|
| 366 |
|
| 367 |
+
return ChatResponse(answer=final_response.content, query_type="sql",
|
| 368 |
+
sql_query=sanitized_sql, sql_results=results[:10],
|
| 369 |
+
token_usage=total_usage)
|
| 370 |
|
| 371 |
def _handle_hybrid(self, query: str, schema_context: str, history: List[Dict], allowed_tables: Optional[List[str]] = None) -> ChatResponse:
|
| 372 |
"""Handle hybrid RAG + SQL query."""
|
|
|
|
| 376 |
# Try SQL as well
|
| 377 |
sql_context = ""
|
| 378 |
sql_query = None
|
| 379 |
+
|
| 380 |
+
total_usage = {"input": 0, "output": 0, "total": 0}
|
| 381 |
+
|
| 382 |
try:
|
| 383 |
+
sql, gen_response = self.sql_generator.generate(query, schema_context, history)
|
| 384 |
+
|
| 385 |
+
# Accumulate usage
|
| 386 |
+
total_usage["input"] += gen_response.input_tokens
|
| 387 |
+
total_usage["output"] += gen_response.output_tokens
|
| 388 |
+
total_usage["total"] += gen_response.total_tokens
|
| 389 |
+
|
| 390 |
is_valid, _, sanitized_sql = self.sql_validator.validate(sql)
|
| 391 |
if is_valid:
|
| 392 |
results = self.db.execute_query(sanitized_sql)
|
|
|
|
| 404 |
prompt
|
| 405 |
)
|
| 406 |
|
| 407 |
+
final_response = self.llm_client.chat(messages)
|
| 408 |
|
| 409 |
+
# Add final usage
|
| 410 |
+
total_usage["input"] += final_response.input_tokens
|
| 411 |
+
total_usage["output"] += final_response.output_tokens
|
| 412 |
+
total_usage["total"] += final_response.total_tokens
|
| 413 |
+
|
| 414 |
+
return ChatResponse(answer=final_response.content, query_type="hybrid", sql_query=sql_query, token_usage=total_usage)
|
| 415 |
|
| 416 |
def _construct_messages(self, system_instruction: str, history: List[Dict], user_content: str) -> List[Dict]:
|
| 417 |
"""Construct message list, merging system messages from history."""
|
|
|
|
| 450 |
history,
|
| 451 |
query
|
| 452 |
)
|
| 453 |
+
response = self.llm_client.chat(messages)
|
| 454 |
+
|
| 455 |
+
usage = {
|
| 456 |
+
"input": response.input_tokens,
|
| 457 |
+
"output": response.output_tokens,
|
| 458 |
+
"total": response.total_tokens
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
return ChatResponse(answer=response.content, query_type="general", token_usage=usage)
|
| 462 |
|
| 463 |
def _format_results(self, results: List[Dict], max_rows: int = 10) -> str:
|
| 464 |
"""Format SQL results for display."""
|
llm/__pycache__/__init__.cpython-311.pyc
CHANGED
|
Binary files a/llm/__pycache__/__init__.cpython-311.pyc and b/llm/__pycache__/__init__.cpython-311.pyc differ
|
|
|
llm/__pycache__/client.cpython-311.pyc
CHANGED
|
Binary files a/llm/__pycache__/client.cpython-311.pyc and b/llm/__pycache__/client.cpython-311.pyc differ
|
|
|
llm/client.py
CHANGED
|
@@ -11,11 +11,21 @@ from typing import List, Dict, Optional
|
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
class LLMClient(ABC):
|
| 15 |
"""Abstract base class for LLM clients."""
|
| 16 |
|
| 17 |
@abstractmethod
|
| 18 |
-
def chat(self, messages: List[Dict[str, str]]) ->
|
| 19 |
pass
|
| 20 |
|
| 21 |
@abstractmethod
|
|
@@ -64,14 +74,20 @@ class GroqClient(LLMClient):
|
|
| 64 |
self._client = Groq(api_key=self.api_key)
|
| 65 |
return self._client
|
| 66 |
|
| 67 |
-
def chat(self, messages: List[Dict[str, str]]) ->
|
| 68 |
response = self.client.chat.completions.create(
|
| 69 |
model=self.model,
|
| 70 |
messages=messages,
|
| 71 |
temperature=self.temperature,
|
| 72 |
max_tokens=self.max_tokens
|
| 73 |
)
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
def is_available(self) -> bool:
|
| 77 |
try:
|
|
@@ -106,14 +122,20 @@ class OpenAIClient(LLMClient):
|
|
| 106 |
self._client = OpenAI(api_key=self.api_key)
|
| 107 |
return self._client
|
| 108 |
|
| 109 |
-
def chat(self, messages: List[Dict[str, str]]) ->
|
| 110 |
response = self.client.chat.completions.create(
|
| 111 |
model=self.model,
|
| 112 |
messages=messages,
|
| 113 |
temperature=self.temperature,
|
| 114 |
max_tokens=self.max_tokens
|
| 115 |
)
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
def is_available(self) -> bool:
|
| 119 |
try:
|
|
@@ -150,14 +172,21 @@ class LocalLLaMAClient(LLMClient):
|
|
| 150 |
)
|
| 151 |
return self._pipeline
|
| 152 |
|
| 153 |
-
def chat(self, messages: List[Dict[str, str]]) ->
|
| 154 |
output = self.pipeline(
|
| 155 |
messages,
|
| 156 |
max_new_tokens=self.max_tokens,
|
| 157 |
temperature=self.temperature,
|
| 158 |
do_sample=True
|
| 159 |
)
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
def is_available(self) -> bool:
|
| 163 |
try:
|
|
@@ -165,8 +194,7 @@ class LocalLLaMAClient(LLMClient):
|
|
| 165 |
return True
|
| 166 |
except Exception:
|
| 167 |
return False
|
| 168 |
-
|
| 169 |
-
|
| 170 |
def create_llm_client(provider: str = "groq", **kwargs) -> LLMClient:
|
| 171 |
"""
|
| 172 |
Factory function to create LLM client.
|
|
|
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class LLMResponse:
|
| 18 |
+
content: str
|
| 19 |
+
input_tokens: int = 0
|
| 20 |
+
output_tokens: int = 0
|
| 21 |
+
total_tokens: int = 0
|
| 22 |
+
|
| 23 |
+
|
| 24 |
class LLMClient(ABC):
|
| 25 |
"""Abstract base class for LLM clients."""
|
| 26 |
|
| 27 |
@abstractmethod
|
| 28 |
+
def chat(self, messages: List[Dict[str, str]]) -> LLMResponse:
|
| 29 |
pass
|
| 30 |
|
| 31 |
@abstractmethod
|
|
|
|
| 74 |
self._client = Groq(api_key=self.api_key)
|
| 75 |
return self._client
|
| 76 |
|
| 77 |
+
def chat(self, messages: List[Dict[str, str]]) -> LLMResponse:
|
| 78 |
response = self.client.chat.completions.create(
|
| 79 |
model=self.model,
|
| 80 |
messages=messages,
|
| 81 |
temperature=self.temperature,
|
| 82 |
max_tokens=self.max_tokens
|
| 83 |
)
|
| 84 |
+
usage = response.usage
|
| 85 |
+
return LLMResponse(
|
| 86 |
+
content=response.choices[0].message.content,
|
| 87 |
+
input_tokens=usage.prompt_tokens if usage else 0,
|
| 88 |
+
output_tokens=usage.completion_tokens if usage else 0,
|
| 89 |
+
total_tokens=usage.total_tokens if usage else 0
|
| 90 |
+
)
|
| 91 |
|
| 92 |
def is_available(self) -> bool:
|
| 93 |
try:
|
|
|
|
| 122 |
self._client = OpenAI(api_key=self.api_key)
|
| 123 |
return self._client
|
| 124 |
|
| 125 |
+
def chat(self, messages: List[Dict[str, str]]) -> LLMResponse:
|
| 126 |
response = self.client.chat.completions.create(
|
| 127 |
model=self.model,
|
| 128 |
messages=messages,
|
| 129 |
temperature=self.temperature,
|
| 130 |
max_tokens=self.max_tokens
|
| 131 |
)
|
| 132 |
+
usage = response.usage
|
| 133 |
+
return LLMResponse(
|
| 134 |
+
content=response.choices[0].message.content,
|
| 135 |
+
input_tokens=usage.prompt_tokens if usage else 0,
|
| 136 |
+
output_tokens=usage.completion_tokens if usage else 0,
|
| 137 |
+
total_tokens=usage.total_tokens if usage else 0
|
| 138 |
+
)
|
| 139 |
|
| 140 |
def is_available(self) -> bool:
|
| 141 |
try:
|
|
|
|
| 172 |
)
|
| 173 |
return self._pipeline
|
| 174 |
|
| 175 |
+
def chat(self, messages: List[Dict[str, str]]) -> LLMResponse:
|
| 176 |
output = self.pipeline(
|
| 177 |
messages,
|
| 178 |
max_new_tokens=self.max_tokens,
|
| 179 |
temperature=self.temperature,
|
| 180 |
do_sample=True
|
| 181 |
)
|
| 182 |
+
generated_text = output[0]["generated_text"][-1]["content"]
|
| 183 |
+
# Token counts are not computed for the local pipeline, so zeros are reported (a tokenizer could be used to count them)
|
| 184 |
+
return LLMResponse(
|
| 185 |
+
content=generated_text,
|
| 186 |
+
input_tokens=0, # Local pipeline generic usually doesn't give this easily without more access
|
| 187 |
+
output_tokens=0,
|
| 188 |
+
total_tokens=0
|
| 189 |
+
)
|
| 190 |
|
| 191 |
def is_available(self) -> bool:
|
| 192 |
try:
|
|
|
|
| 194 |
return True
|
| 195 |
except Exception:
|
| 196 |
return False
|
| 197 |
+
|
|
|
|
| 198 |
def create_llm_client(provider: str = "groq", **kwargs) -> LLMClient:
|
| 199 |
"""
|
| 200 |
Factory function to create LLM client.
|
router.py
CHANGED
|
@@ -43,7 +43,7 @@ USER QUERY: {query}
|
|
| 43 |
|
| 44 |
Determine if this query needs:
|
| 45 |
1. RAG - Semantic search through text content (searching for meanings, concepts, descriptions)
|
| 46 |
-
2. SQL - Structured database query (counting, filtering, aggregating, specific lookups, OR pagination requests like "show more", "show other", "next results", "remaining items")
|
| 47 |
3. HYBRID - Both semantic search and structured query
|
| 48 |
4. GENERAL - General conversation not requiring database access
|
| 49 |
|
|
@@ -81,7 +81,7 @@ REASONING: [brief explanation]"""
|
|
| 81 |
{"role": "system", "content": "You are a query routing assistant."},
|
| 82 |
{"role": "user", "content": prompt}
|
| 83 |
])
|
| 84 |
-
return self._parse_routing_response(response)
|
| 85 |
except Exception as e:
|
| 86 |
logger.warning(f"LLM routing failed: {e}, using heuristics")
|
| 87 |
return self._heuristic_route(query)
|
|
@@ -158,6 +158,8 @@ REASONING: [brief explanation]"""
|
|
| 158 |
return RoutingDecision(QueryType.SQL, 0.8, "SQL query for data retrieval")
|
| 159 |
elif rag_score > sql_score:
|
| 160 |
return RoutingDecision(QueryType.RAG, 0.8, "Semantic search for concepts")
|
|
|
|
|
|
|
| 161 |
elif sql_score > 0 and rag_score > 0:
|
| 162 |
return RoutingDecision(QueryType.HYBRID, 0.6, "Mixed query type")
|
| 163 |
else:
|
|
|
|
| 43 |
|
| 44 |
Determine if this query needs:
|
| 45 |
1. RAG - Semantic search through text content (searching for meanings, concepts, descriptions)
|
| 46 |
+
2. SQL - Structured database query (counting, filtering, aggregating, specific lookups, OR pagination requests like "show more", "show other", "next results", "remaining items", OR subjective filtering like "for kids", "summer shoes", "rainy season" which map to columns)
|
| 47 |
3. HYBRID - Both semantic search and structured query
|
| 48 |
4. GENERAL - General conversation not requiring database access
|
| 49 |
|
|
|
|
| 81 |
{"role": "system", "content": "You are a query routing assistant."},
|
| 82 |
{"role": "user", "content": prompt}
|
| 83 |
])
|
| 84 |
+
return self._parse_routing_response(response.content)
|
| 85 |
except Exception as e:
|
| 86 |
logger.warning(f"LLM routing failed: {e}, using heuristics")
|
| 87 |
return self._heuristic_route(query)
|
|
|
|
| 158 |
return RoutingDecision(QueryType.SQL, 0.8, "SQL query for data retrieval")
|
| 159 |
elif rag_score > sql_score:
|
| 160 |
return RoutingDecision(QueryType.RAG, 0.8, "Semantic search for concepts")
|
| 161 |
+
elif "is it good" in query_lower or "consider other" in query_lower:
|
| 162 |
+
return RoutingDecision(QueryType.GENERAL, 0.7, "Consultative question about metrics")
|
| 163 |
elif sql_score > 0 and rag_score > 0:
|
| 164 |
return RoutingDecision(QueryType.HYBRID, 0.6, "Mixed query type")
|
| 165 |
else:
|
sql/__pycache__/__init__.cpython-311.pyc
CHANGED
|
Binary files a/sql/__pycache__/__init__.cpython-311.pyc and b/sql/__pycache__/__init__.cpython-311.pyc differ
|
|
|
sql/__pycache__/generator.cpython-311.pyc
CHANGED
|
Binary files a/sql/__pycache__/generator.cpython-311.pyc and b/sql/__pycache__/generator.cpython-311.pyc differ
|
|
|
sql/__pycache__/validator.cpython-311.pyc
CHANGED
|
Binary files a/sql/__pycache__/validator.cpython-311.pyc and b/sql/__pycache__/validator.cpython-311.pyc differ
|
|
|
sql/generator.py
CHANGED
|
@@ -45,6 +45,12 @@ SQLite-SPECIFIC NOTES:
|
|
| 45 |
else: # MySQL
|
| 46 |
return """
|
| 47 |
MySQL-SPECIFIC NOTES:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
- LIKE is case-insensitive for non-binary strings
|
| 49 |
- Use CONCAT() for string concatenation
|
| 50 |
- Use LIMIT at the end of queries
|
|
@@ -69,11 +75,33 @@ RULES:
|
|
| 69 |
- Use pattern matching for flexibility.
|
| 70 |
- Use `OR` to combine multiple column checks.
|
| 71 |
7. DATA AWARENESS: In footwear databases, specific types like 'Formal', 'Casual', or 'Sports' often appear in `sub_category` OR `category`. Check both if available.
|
| 72 |
-
8.
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
- Look at the previous conversation for the original query conditions.
|
| 75 |
- Use LIMIT with OFFSET to get the next set of results (e.g., LIMIT 10 OFFSET 10 for the second page).
|
| 76 |
- Keep the same WHERE conditions from the previous query.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
{dialect_hints}
|
| 79 |
|
|
@@ -127,8 +155,15 @@ Generate a single {dialect} SELECT query to answer the user's question."""
|
|
| 127 |
|
| 128 |
response = self.llm_client.chat(messages)
|
| 129 |
|
| 130 |
-
# Extract SQL from response
|
| 131 |
-
sql = self._extract_sql(response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
return sql, response
|
| 134 |
|
|
|
|
| 45 |
else: # MySQL
|
| 46 |
return """
|
| 47 |
MySQL-SPECIFIC NOTES:
|
| 48 |
+
- CRITICALLY IMPORTANT: This server runs with ONLY_FULL_GROUP_BY enabled.
|
| 49 |
+
- IF YOU USE GROUP BY, EVERY SINGLE COLUMN in the SELECT list MUST be either:
|
| 50 |
+
1. In the GROUP BY clause, OR
|
| 51 |
+
2. Wrapped in an aggregate function (SUM, COUNT, AVG, MAX, MIN).
|
| 52 |
+
- EXAMPLE ERROR: "Expression #2 of SELECT list is not in GROUP BY clause..." -> This means you selected a raw column without aggregation.
|
| 53 |
+
- FIX: Change `SELECT name, clicks... GROUP BY name` to `SELECT name, SUM(clicks)... GROUP BY name`.
|
| 54 |
- LIKE is case-insensitive for non-binary strings
|
| 55 |
- Use CONCAT() for string concatenation
|
| 56 |
- Use LIMIT at the end of queries
|
|
|
|
| 75 |
- Use pattern matching for flexibility.
|
| 76 |
- Use `OR` to combine multiple column checks.
|
| 77 |
7. DATA AWARENESS: In footwear databases, specific types like 'Formal', 'Casual', or 'Sports' often appear in `sub_category` OR `category`. Check both if available.
|
| 78 |
+
8. SUBJECTIVE/IMPLICIT FILTERS:
|
| 79 |
+
If the user asks for subjective attributes (e.g., "for kids", "summer usage", "rainy season") and no direct column exists:
|
| 80 |
+
- INFER logical mappings using available columns (material, type, category, description).
|
| 81 |
+
- EXAMPLES:
|
| 82 |
+
* "Summer" -> `category` IN ('Sandals', 'Slippers', 'Flip Flops') OR `material` IN ('Canvas', 'Mesh') OR `description` LIKE '%breathable%'
|
| 83 |
+
* "Winter/Rainy" -> `category` IN ('Boots') OR `material` IN ('Leather', 'Synthetic', 'Rubber') OR `description` LIKE '%waterproof%'
|
| 84 |
+
* "Kids" -> `category` IN ('Kids', 'Children', 'Junior') OR `product_name` LIKE '%Junior%' OR `product_name` LIKE '%Infant%' OR (`size` < 6 AND `size` > 0)
|
| 85 |
+
- Use `OR` broadly to capture potential matches.
|
| 86 |
+
- Use pattern matching (`LIKE` / `ILIKE`) on text columns if categories are unclear.
|
| 87 |
+
|
| 88 |
+
9. Return ONLY the SQL query, no explanations.
|
| 89 |
+
10. PAGINATION: If the user asks to "show more", "show other", "see remaining", or similar follow-up:
|
| 90 |
- Look at the previous conversation for the original query conditions.
|
| 91 |
- Use LIMIT with OFFSET to get the next set of results (e.g., LIMIT 10 OFFSET 10 for the second page).
|
| 92 |
- Keep the same WHERE conditions from the previous query.
|
| 93 |
+
- Use LIMIT with OFFSET to get the next set of results (e.g., LIMIT 10 OFFSET 10 for the second page).
|
| 94 |
+
- Keep the same WHERE conditions from the previous query.
|
| 95 |
+
11. GROUP BY RULES: If you use GROUP BY, every column in the SELECT list must be either in the GROUP BY clause or wrapped in an aggregate function (SUM, AVG, COUNT, MAX). Do NOT select raw columns like `clicks` or `price` if you are grouping by `product_name`; use SUM(clicks), AVG(price), etc.
|
| 96 |
+
12. BUSINESS LOGIC & METRICS:
|
| 97 |
+
- `sales` column is usually QUANTITY (integer). `price` is Unit Price. `mfrcost` is Unit Cost.
|
| 98 |
+
- REVENUE = `sales * price`
|
| 99 |
+
- GROSS PROFIT = `(price - mfrcost) * sales`
|
| 100 |
+
- NET PROFIT (w/ Ad Cost) = `((price - mfrcost) * sales) - adcost`
|
| 101 |
+
- PROFIT MARGIN (%) = `(NET PROFIT / REVENUE) * 100`
|
| 102 |
+
- ROAS = `REVENUE / adcost`
|
| 103 |
+
- If User asks for "Profit" or "Margin" and `adcost` is available, PREFER the NET PROFIT formula that subtracts `adcost`.
|
| 104 |
+
- Always aggregate (SUM) these values when grouping by product/category.
|
| 105 |
|
| 106 |
{dialect_hints}
|
| 107 |
|
|
|
|
| 155 |
|
| 156 |
response = self.llm_client.chat(messages)
|
| 157 |
|
| 158 |
+
# Extract SQL from response content
|
| 159 |
+
sql = self._extract_sql(response.content)
|
| 160 |
+
|
| 161 |
+
# The second item of the returned tuple is the full LLM response object,
|
| 162 |
+
# not just an explanation string, so that callers can read token usage
|
| 163 |
+
# (input/output/total tokens) from it alongside the generated SQL.
|
| 164 |
+
# The chatbot is updated in this same change to consume it that way.
|
| 165 |
+
# NOTE: any other caller that still expects a string as the second item
|
| 166 |
+
# must read `.content` instead; the return type is effectively Tuple[str, LLMResponse].
|
| 167 |
|
| 168 |
return sql, response
|
| 169 |
|