# ClearMLIISC / agent / router_agent.py
# Uploaded by sahanacp ("Upload 11 files", commit 3195421, verified)
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.memory import ConversationBufferMemory, SimpleMemory
import logging
# Logging: configure the root logger at INFO and grab a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level router state. `llm`, `chat_memory` and `prompt` are populated by
# initialize_router_agent(); `query_memory` is declared here but not assigned
# anywhere else in this file.
llm = chat_memory = query_memory = prompt = None
def initialize_router_agent(llm_instance, chat_memory_instance):
    """Wire up the router agent's model, memory and classification prompt.

    Stores the arguments in the module-level ``llm`` and ``chat_memory``
    globals and builds the module-level ``prompt``: a system message holding
    the classification instructions, followed by a human message filled from
    the ``{input}`` template variable. ``classify_query`` relies on these
    globals.

    Args:
        llm_instance: Model that ``classify_query`` pipes ``prompt`` into.
        chat_memory_instance: Memory object that classification turns are
            recorded into (used only if it exposes a ``chat_memory``
            attribute); may be None.
    """
    global llm, chat_memory, prompt

    llm = llm_instance
    chat_memory = chat_memory_instance

    classification_instructions = """You are an intelligent query classification system for an e-commerce platform.
Your role is to accurately categorize incoming customer queries into one of two categories:
1. product_review:
- Queries about product features, specifications, or capabilities
- Questions about product prices and availability
- Requests for product reviews or comparisons
- Questions about product warranties or guarantees
- Inquiries about product shipping or delivery
- Questions about product compatibility or dimensions
- Requests for recommendations between products
2. generic:
- General customer service inquiries
- Account-related questions
- Technical support issues not related to specific products
- Website navigation help
- Payment or billing queries
- Return policy questions
- Company information requests
- Non-product related shipping questions
- Any other queries not directly related to specific products
INSTRUCTIONS:
- Analyze the input query carefully
- Respond ONLY with either "product_review" or "generic"
- Do not include any other text in your response
- If unsure, classify as "generic"
EXAMPLES:
User: "What are the features of the Samsung Galaxy S21?"
Assistant: product_review
User: "How much does the iPhone 13 Pro Max cost?"
Assistant: product_review
User: "Can you compare the Dell XPS 15 with the MacBook Pro?"
Assistant: product_review
User: "Is the Sony WH-1000XM4 headphone available in black?"
Assistant: product_review
User: "What's the battery life of the iPad Pro?"
Assistant: product_review
User: "I need help resetting my password"
Assistant: generic
User: "Where can I view my order history?"
Assistant: generic
User: "How do I update my shipping address?"
Assistant: generic
User: "What are your return policies?"
Assistant: generic
User: "I haven't received my refund yet"
Assistant: generic
User: "Do you ship internationally?"
Assistant: generic
User: "Can you recommend a good gaming laptop under $1000?"
Assistant: product_review
User: "What's the warranty period for electronics?"
Assistant: generic
User: "Is the Instant Pot dishwasher safe?"
Assistant: product_review
User: "How do I track my order?"
Assistant: generic
"""

    # The "{input}" placeholder is filled by classify_query via chain.invoke.
    messages = [
        ("system", classification_instructions),
        ("human", "{input}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)

    logger.info("Router agent initialized successfully")
def classify_query(query):
    """Classify a customer query as ``"product_review"`` or ``"generic"``.

    Sends ``query`` through the module-level ``prompt`` and ``llm`` (set up by
    ``initialize_router_agent``). When a chat memory exposing a
    ``chat_memory`` attribute is configured, the query is recorded as a user
    message before classification and the result as an AI message afterwards.

    Args:
        query: Raw customer query text.

    Returns:
        ``"product_review"`` or ``"generic"``. Any model output that is not one
        of these labels, and any error raised during classification (including
        calling this before the agent is initialized), yields ``"generic"``.
    """
    valid_categories = ("product_review", "generic")
    try:
        if prompt is None or llm is None:
            raise RuntimeError(
                "router agent not initialized; call initialize_router_agent() first"
            )

        # NOTE: the chain does not read chat memory; memory is only used to
        # record the classification turn for later inspection.
        chain = prompt | llm

        history = (
            chat_memory.chat_memory
            if chat_memory and hasattr(chat_memory, "chat_memory")
            else None
        )
        if history is not None:
            history.add_user_message(query)

        response = chain.invoke({"input": query})
        # Chat models return a message object with `.content`; plain LLMs
        # return the text directly.
        raw_label = getattr(response, "content", response)
        # Tolerate stray whitespace, quotes, backticks or a trailing period
        # around the label the model was told to emit verbatim.
        category = str(raw_label).strip().strip("\"'`.").strip().lower()
        if category not in valid_categories:
            category = "generic"  # Default fallback, per the prompt's "if unsure" rule

        if history is not None:
            history.add_ai_message(f"Query classified as: {category}")

        logger.info("Query: %s", query)
        logger.info("Classification: %s", category)
        return category
    except Exception:
        # Routing is best-effort: keep the traceback in the logs and fall
        # back to the generic handler rather than failing the request.
        logger.exception("Error in routing")
        return "generic"
def get_classification_history():
    """Return the messages stored in the router's chat memory.

    Returns:
        ``chat_memory.chat_memory.messages`` when a chat memory exposing a
        ``chat_memory`` attribute is configured, otherwise an empty list.
        If the memory object is shared with other components, their messages
        are included as well.
    """
    if not chat_memory or not hasattr(chat_memory, "chat_memory"):
        return []
    return chat_memory.chat_memory.messages
def clear_context():
    """Clear the router's configured chat memory, if any, and log it.

    Only ``chat_memory`` is cleared; the module-level ``query_memory`` is not
    touched. The log line is emitted even when no memory is configured.
    """
    memory = chat_memory
    if memory:
        memory.clear()
    logger.info("Router agent context cleared")