SparkMart / nodes.py
Tamannathakur's picture
Update nodes.py
7687ad3 verified
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy import text, inspect
import logging
import re
from llm import gemini_model
from database import engine
from prompts import INTENT_DETECTION_PROMPT,QUERY_GENERATOR_PROMPT,RESPONSE_FORMATTER_PROMPT
from schema import RecommendationState
logger = logging.getLogger(__name__)
def intent_detector_node(state: RecommendationState) -> RecommendationState:
"""
Converts vague user queries into a clear, structured intent form.
Overwrites state['user_query'] with a cleaned version.
Extracts:
- intent (semantic_search, category_browse, product_lookup, follow_up)
- keywords (list)
Then forwards the updated user_query to the next node.
"""
logger.info("[INTENT_DETECTOR] Analyzing user intent...")
raw_query = state.get("user_query", "")
prompt = ChatPromptTemplate.from_messages([
("system", INTENT_DETECTION_PROMPT),
("human", f"User Query: {raw_query}\nReturn JSON:")
])
try:
chain = prompt | gemini_model
response = chain.invoke({})
content = response.content.strip()
import json
data = json.loads(content)
state["user_query"] = data.get("clean_query", raw_query)
logger.info(f"[INTENT_DETECTOR] Clean Query: {state['user_query']}")
except Exception as e:
logger.error(f"[INTENT_DETECTOR] Error: {e}")
state["user_query"] = raw_query
return state
def inspect_schema_node(state: RecommendationState) -> RecommendationState:
"""
Fetches actual database schema and sample data
"""
logger.info("[SCHEMA_INSPECTOR] Fetching database schema...")
try:
inspector = inspect(engine)
columns = [col['name'] for col in inspector.get_columns('Ecommerce_Data')]
state["available_columns"] = columns
with engine.connect() as conn:
result = conn.execute(text("SELECT DISTINCT Category FROM Ecommerce_Data LIMIT 20"))
state["available_categories"] = [row[0] for row in result.fetchall()]
result = conn.execute(text("SELECT Product_Name FROM Ecommerce_Data LIMIT 10"))
state["sample_products"] = [row[0] for row in result.fetchall()]
logger.info(f"[SCHEMA_INSPECTOR] Found {len(columns)} columns")
logger.info(f"[SCHEMA_INSPECTOR] Found {len(state['available_categories'])} categories")
except Exception as e:
logger.error(f"[SCHEMA_INSPECTOR] Error: {e}")
state["error_message"] = f"Database schema inspection failed: {e}"
return state
def generate_query_node(state: RecommendationState) -> RecommendationState:
"""
Uses LLM to understand intent and generate SQL query
The LLM has access to conversation history via checkpointer
"""
logger.info("[QUERY_GENERATOR] Generating SQL query with LLM...")
if state.get("error_message"):
return state
prompt = ChatPromptTemplate.from_messages([
("system", QUERY_GENERATOR_PROMPT),
("human", "User Query: {user_query}\n\nGenerate the SQL query:")
])
try:
chain = prompt | gemini_model
response = chain.invoke({
"user_query": state["user_query"],
"columns": ", ".join(state["available_columns"]),
"categories": ", ".join(state["available_categories"][:10]),
"sample_products": ", ".join(state["sample_products"][:5])
})
sql_query = response.content.strip()
sql_query = re.sub(r'```sql\s*|\s*```', '', sql_query)
sql_query = sql_query.strip()
if not sql_query.endswith(';'):
sql_query += ';'
state["sql_query"] = sql_query
logger.info(f"[QUERY_GENERATOR] Generated query: {sql_query}")
except Exception as e:
logger.error(f"[QUERY_GENERATOR] Error: {e}")
state["error_message"] = f"Query generation failed: {e}"
return state
def validate_query_node(state: RecommendationState) -> RecommendationState:
"""
Validates SQL query for safety and syntax
"""
logger.info("[QUERY_VALIDATOR] Validating SQL query...")
if state.get("error_message"):
return state
sql_query = state["sql_query"]
errors = []
dangerous_keywords = ['DROP', 'DELETE', 'UPDATE', 'INSERT', 'ALTER', 'TRUNCATE', 'CREATE']
for keyword in dangerous_keywords:
if keyword in sql_query.upper():
errors.append(f"Dangerous keyword '{keyword}' detected")
if not sql_query.upper().strip().startswith('SELECT'):
errors.append("Query must be a SELECT statement")
if 'LIMIT' not in sql_query.upper():
errors.append("Query must include LIMIT clause")
if 'FROM' not in sql_query.upper():
errors.append("Query must include FROM clause")
state["validation_errors"] = errors
if errors:
logger.warning(f"[QUERY_VALIDATOR] ✗ Validation failed: {errors}")
state["error_message"] = f"Invalid query: {', '.join(errors)}"
else:
logger.info("[QUERY_VALIDATOR] Query is valid and safe")
return state
def execute_query_node(state: RecommendationState) -> RecommendationState:
"""
Executes the validated SQL query
"""
logger.info("[QUERY_EXECUTOR] Executing SQL query...")
if state.get("error_message") or state.get("validation_errors"):
return state
try:
with engine.connect() as conn:
sql_query = state["sql_query"].rstrip(';')
result = conn.execute(text(sql_query))
rows = result.fetchall()
columns = result.keys()
results = [dict(zip(columns, row)) for row in rows]
state["query_results"] = results
logger.info(f"[QUERY_EXECUTOR] Found {len(results)} results")
except Exception as e:
logger.error(f"[QUERY_EXECUTOR] Error: {e}")
state["error_message"] = f"Query execution failed: {e}"
return state
def format_response_node(state: RecommendationState) -> RecommendationState:
"""
Uses LLM to format results into natural language response
The LLM has access to conversation history via checkpointer
"""
logger.info("[RESPONSE_FORMATTER] Formatting response...")
if state.get("error_message"):
state["formatted_response"] = (
"I apologize, but I encountered an issue searching for products. "
"Could you please rephrase your request or try browsing our categories?"
)
return state
prompt = ChatPromptTemplate.from_messages([
("system", RESPONSE_FORMATTER_PROMPT),
("human", """User Query: {user_query}
Search Results: {results}
Format response:""")
])
try:
chain = prompt | gemini_model
response = chain.invoke({
"user_query": state["user_query"],
"results": state["query_results"][:10],
"categories": ", ".join(state.get("available_categories", [])[:5])
})
state["formatted_response"] = response.content
logger.info("[RESPONSE_FORMATTER] Response formatted")
except Exception as e:
logger.error(f"[RESPONSE_FORMATTER] Error: {e}")
results = state["query_results"][:5]
response = f"Found {len(state['query_results'])} products:\n\n"
for idx, product in enumerate(results, 1):
response += f"{idx}. {product.get('Product_Name', 'N/A')}\n"
if 'Category' in product:
response += f" Category: {product['Category']}\n"
if 'Price' in product:
response += f" Price: ${product['Price']}\n"
state["formatted_response"] = response
return state