akankshar639's picture
Update main.py
e68f24e verified
import re
from typing import Dict, Literal, TypedDict
# Progress messages are printed to stdout while the module is imported.
print("Importing modules...")
# transformers is optional. Any failure to import it (not only ImportError,
# since the handler catches Exception) sets MODEL_AVAILABLE to False, and
# SparkMartSupervisor then skips loading a local text-generation model.
try:
    from transformers import pipeline
    MODEL_AVAILABLE = True
    print("Transformers imported successfully")
except Exception as e:
    print(f"Warning: Could not import transformers: {e}")
    MODEL_AVAILABLE = False
# LangGraph supplies the workflow graph; LangChain supplies the message types
# carried in the graph state.
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage
# Project-local helpers for the general-query, complaint and recommendation flows.
# NOTE(review): extract_order_mobile_with_llm is not referenced in the code
# visible here; it may be re-exported for other modules — confirm before removing.
from tools import handle_general_query, recommend_products, format_recommendations
from complaint_handler import process_complaint_flow, extract_order_mobile_with_llm
from product_recommendation import get_category_from_message, handle_recommendation_flow
print("All modules imported")
class AgentState(TypedDict):
    """Schema of the state dict that flows through the LangGraph workflow.

    Declared as a ``TypedDict`` (the form LangGraph documents for state
    schemas) rather than a subclass of the deprecated ``typing.Dict`` alias.
    Calling ``AgentState(key=value, ...)`` still builds a plain ``dict``, so
    existing construction code keeps working.
    """

    # Conversation messages; supervisor_node reads the last entry's .content
    # and appends the assistant reply.
    messages: list
    # Key into SparkMartSupervisor.conversation_state.
    session_id: str
    # The following fields are copied from the supervisor's session dict by
    # agent_respond when the graph is invoked.
    stage: str
    order_id: str
    customer_name: str
    issue_type: str
    attempts_to_resolve: int
    file_uploaded: bool
    # Set to "supervisor" by agent_respond.
    current_agent: str
    # Set to "" by agent_respond; not read elsewhere in this module.
    mobile_number: str
    category: str
    # Set to "END" by supervisor_node once a reply has been produced.
    next_action: str
class SparkMartSupervisor:
    """Keyword-based conversation router for the SparkMart support assistant.

    Per-session state is kept in ``self.conversation_state`` (keyed by session
    id) and moves through named stages: ``'greeting'`` -> ``'query_type'`` ->
    one of the complaint, recommendation or general-query flows. Complaint
    stages are delegated to ``process_complaint_flow`` and recommendation
    follow-ups to ``handle_recommendation_flow``. Any stage not handled
    explicitly (presumably one set by those helpers — verify) falls through to
    keyword-based global routing at the end of ``supervise``.
    """

    # Pattern -> reply, checked in insertion order so multi-word phrases such
    # as "good afternoon" win over single words like "afternoon" / "noon".
    GREETING_PATTERNS = {
        'good afternoon': 'Good afternoon', 'good morning': 'Good morning',
        'good evening': 'Good evening', 'good noon': 'Good noon',
        'afternoon': 'Good afternoon', 'morning': 'Good morning',
        'evening': 'Good evening', 'noon': 'Good noon',
        'hello': 'Hi there', 'hi': 'Hello', 'hey': 'Hey',
    }

    # Whole-message phrases (after lower/strip) that end the conversation.
    EXIT_PHRASES = frozenset({'bye', 'exit', 'thanks', 'thank you', 'goodbye'})
    # Substrings that end the conversation wherever they appear.
    EXIT_KEYWORDS = ('bye', 'goodbye', 'exit')
    # Whole-message replies that decline further help in the general-query flow.
    DECLINE_PHRASES = frozenset({'no', 'no thanks', 'nothing', 'nope'})
    THANKS_KEYWORDS = ('thanks', 'thank you')
    COMPLAINT_KEYWORDS = ('order', 'complaint', 'problem', 'issue')
    # The initial classification step additionally treats these as complaints.
    QUERY_TYPE_COMPLAINT_KEYWORDS = COMPLAINT_KEYWORDS + ('damaged', 'wrong')
    RECOMMENDATION_KEYWORDS = (
        'recommend', 'suggestion', 'suggest', 'product', 'buy', 'purchase', 'shopping',
    )
    PERSONALIZATION_KEYWORDS = (
        'specialized', 'specialization', 'personalized', 'personal', 'customized', 'custom',
    )
    GENERAL_KEYWORDS = (
        'general', 'service', 'information', 'discount', 'offer', 'delivery',
        'return', 'website', 'payment',
    )
    # Stages whose messages are handled entirely by process_complaint_flow.
    COMPLAINT_STAGES = frozenset({
        'order_details', 'issue_description', 'file_upload_request',
        'replacement_offer', 'resolution_attempt', 'refund_step1',
        'refund_step2', 'refund_step3', 'refund_consideration',
        'refund_process', 'processing_request',
    })

    # A run of 6+ digits (order id); any 10-digit mobile number also matches.
    _ORDER_REF_RE = re.compile(r'\d{6}')
    _MOBILE_RE = re.compile(r'\d{10}')

    EXIT_MESSAGE = ("Thank you for using SparkMart AI! We appreciate your time and "
                    "hope to serve you again soon. Have a wonderful day!")
    ORDER_DETAILS_PROMPT = ("Could you please provide your order ID or mobile number "
                            "so that I can fetch your details?")
    PERSONALIZED_PROMPT = ("Great! Please provide your mobile number so I can give you "
                           "personalized recommendations based on your purchase history.")

    def __init__(self):
        self.conversation_state = {}
        self.greeting_patterns = dict(self.GREETING_PATTERNS)
        # Optional local model passed through to process_complaint_flow. It stays
        # None when transformers is unavailable or the model fails to load.
        # Loading happens synchronously here; no timeout is applied.
        self.text_generator = None
        if MODEL_AVAILABLE:
            try:
                print("Loading DialoGPT model...")
                self.text_generator = pipeline(
                    "text-generation",
                    model="microsoft/DialoGPT-small",
                    pad_token_id=50256,
                    eos_token_id=50256,
                    return_full_text=False
                )
                print("Model loaded successfully")
            except Exception as e:
                print(f"Warning: Could not load model: {e}")
                print("Continuing without model (will use fallback methods)")
                self.text_generator = None

    def _init_session(self, session_id):
        """Return the mutable state dict for *session_id*, creating it on first use."""
        if session_id not in self.conversation_state:
            self.conversation_state[session_id] = {
                'stage': 'greeting', 'order_id': None, 'customer_name': None,
                'issue_type': None, 'attempts_to_resolve': 0, 'file_uploaded': False
            }
        return self.conversation_state[session_id]

    def _check_keywords(self, message, keywords):
        """Return True if any keyword occurs as a substring of the lower-cased message."""
        return any(word in message.lower() for word in keywords)

    def detect_greeting(self, message):
        """Return the reply greeting for the first greeting word/phrase in *message*.

        Patterns are matched as whole words (so "they" does not count as
        "hey"). Falls back to "Hello" when no greeting is found.
        """
        msg_lower = message.lower().strip()
        for pattern, response in self.greeting_patterns.items():
            if re.search(r'\b' + re.escape(pattern) + r'\b', msg_lower):
                return response
        return "Hello"

    def _is_exit(self, message):
        """True when *message* is an exit phrase or contains an exit keyword."""
        return (message.lower().strip() in self.EXIT_PHRASES
                or self._check_keywords(message, self.EXIT_KEYWORDS))

    def _route_complaint(self, message, state):
        """Enter the complaint flow from any non-complaint stage.

        Every entry point moves the session to 'order_details' before a
        message that already carries an order id / mobile number is handed to
        process_complaint_flow, so that helper always starts from the same stage.
        Without an id, the user is asked for one.
        """
        state['stage'] = 'order_details'
        if self._ORDER_REF_RE.search(message):
            return process_complaint_flow(message, state, self.text_generator)
        return self.ORDER_DETAILS_PROMPT

    def _route_recommendation(self, message, state):
        """Enter the recommendation flow and answer the first request.

        Preference order: a request for personalization prompts for a mobile
        number; otherwise category recommendations, then personalized ones for
        a 10-digit number in the message, then popular products.
        """
        state['stage'] = 'recommendation_followup'
        category = get_category_from_message(message)
        mobile_match = self._MOBILE_RE.search(message)
        if self._check_keywords(message, self.PERSONALIZATION_KEYWORDS):
            return self.PERSONALIZED_PROMPT
        if category:
            recs = recommend_products("category", category=category)
        elif mobile_match:
            recs = recommend_products("personalized", mobile_number=mobile_match.group())
        else:
            recs = recommend_products("popular")
        return format_recommendations(recs)

    def supervise(self, message, session_id):
        """Handle one user *message* for *session_id* and return the reply text.

        Advances the session's ``stage`` in ``self.conversation_state`` as a
        side effect. Exit phrases are honoured at every stage.
        """
        state = self._init_session(session_id)

        # PRIORITY: global exit handling.
        if self._is_exit(message):
            return self.EXIT_MESSAGE

        stage = state['stage']

        if stage == 'greeting':
            greeting = self.detect_greeting(message)
            state['stage'] = 'query_type'
            return (f"{greeting}! I am SparkMart AI. How may I assist you? "
                    "Do you have a query regarding an order, need product "
                    "recommendations, or is it a general query?")

        if stage == 'query_type':
            if self._check_keywords(message, self.QUERY_TYPE_COMPLAINT_KEYWORDS):
                return self._route_complaint(message, state)
            if self._check_keywords(message, self.RECOMMENDATION_KEYWORDS):
                return self._route_recommendation(message, state)
            # Messages mentioning 'order' were already routed to complaints above.
            if self._check_keywords(message, ('general',)):
                state['stage'] = 'general_query'
                return handle_general_query(message)
            return ("I understand. Could you please clarify if this is regarding a "
                    "specific order, product recommendations, or a general inquiry "
                    "about our services?")

        if stage in self.COMPLAINT_STAGES:
            return process_complaint_flow(message, state, self.text_generator)

        if stage == 'recommendation_followup':
            if self._check_keywords(message, self.PERSONALIZATION_KEYWORDS):
                return self.PERSONALIZED_PROMPT
            return handle_recommendation_flow(message, state)

        if stage == 'general_query':
            if message.lower().strip() in self.DECLINE_PHRASES:
                return self.EXIT_MESSAGE
            if self._check_keywords(message, self.THANKS_KEYWORDS):
                return "You're welcome! Is there anything else I can help you with?"
            if self._check_keywords(message, self.COMPLAINT_KEYWORDS):
                return self._route_complaint(message, state)
            if self._check_keywords(message, self.RECOMMENDATION_KEYWORDS):
                return self._route_recommendation(message, state)
            return handle_general_query(message)

        # Global handlers: reached only for stages not handled above, so the
        # session is never inside a complaint stage at this point.
        if self._check_keywords(message, self.GENERAL_KEYWORDS):
            state['stage'] = 'general_query'
            return handle_general_query(message)
        if self._check_keywords(message, self.RECOMMENDATION_KEYWORDS):
            return self._route_recommendation(message, state)
        if self._check_keywords(message, self.COMPLAINT_KEYWORDS):
            return self._route_complaint(message, state)
        return ("I'm here to help! I can assist you with orders, product recommendations, "
                "or general information. What would you like help with?")
# Module-level singleton shared by supervisor_node and agent_respond; every
# session's state lives in its conversation_state. Constructing it here runs
# SparkMartSupervisor.__init__ at import time, which loads the DialoGPT model
# when transformers was importable.
print("Initializing supervisor...")
supervisor = SparkMartSupervisor()
print("Supervisor initialized")
def supervisor_node(state: AgentState) -> AgentState:
    """LangGraph node that answers the newest message via the shared supervisor.

    Takes the text of the last entry in ``state["messages"]``, obtains a reply
    from the module-level ``supervisor`` for ``state["session_id"]``, appends
    that reply as an ``AIMessage`` and sets ``next_action`` to ``"END"``.
    The incoming ``state`` is mutated in place and returned.
    """
    history = state["messages"]
    user_text = history[-1].content
    reply = supervisor.supervise(user_text, state["session_id"])
    history.append(AIMessage(content=reply))
    state.update(next_action="END")
    return state
def should_continue(state: AgentState) -> Literal["END"]:
    """Router for the supervisor node's conditional edge.

    The state is not inspected: every call selects the ``"END"`` branch, so
    each graph invocation runs the supervisor node exactly once.
    """
    del state  # unused; present only to satisfy the router signature
    return "END"
# Build the single-node LangGraph workflow: execution enters "supervisor" and
# should_continue always returns "END", which maps to the graph's END node, so
# one app.invoke(...) performs exactly one supervisor turn.
print("Building workflow...")
workflow = StateGraph(AgentState)
workflow.add_node("supervisor", supervisor_node)
workflow.set_entry_point("supervisor")
workflow.add_conditional_edges("supervisor", should_continue, {"END": END})
# Compiled graph used by agent_respond.
app = workflow.compile()
print("Workflow compiled successfully")
def agent_respond(message, session_id, conversation_history):
    """Produce the assistant's reply to one user message.

    The message is run through the compiled LangGraph ``app``, whose single
    node delegates to the module-level ``supervisor``. If anything on that
    path raises, the error is printed and the turn is answered by calling
    ``supervisor.supervise`` directly, starting from the session state as it
    was before the graph ran.

    Args:
        message: The user's text for this turn.
        session_id: Key of the per-session state held by ``supervisor``.
        conversation_history: Accepted for caller compatibility; not read here.

    Returns:
        The content of the last ``AIMessage`` in the graph result, a generic
        prompt if the graph produced none, or the direct ``supervise`` reply on
        the fallback path.
    """
    current_state = supervisor._init_session(session_id)
    # Shallow snapshot of the session before the graph runs. supervisor_node
    # calls supervisor.supervise, which advances this same dict; if the graph
    # fails after that (e.g. if wrapping the reply in an AIMessage fails), the
    # fallback must replay the message from the pre-turn state instead of
    # applying it a second time. Nested mutable values, if any, are not restored.
    snapshot = dict(current_state)
    try:
        state = AgentState(
            messages=[HumanMessage(content=message)],
            session_id=session_id,
            stage=current_state.get('stage', 'greeting'),
            # The session dict stores None until these are known; the graph
            # state declares them as str.
            order_id=current_state.get('order_id') or '',
            customer_name=current_state.get('customer_name') or '',
            issue_type=current_state.get('issue_type') or '',
            attempts_to_resolve=current_state.get('attempts_to_resolve', 0),
            file_uploaded=current_state.get('file_uploaded', False),
            current_agent="supervisor",
            mobile_number="",
            category="",
            next_action=""
        )
        result = app.invoke(state)
        ai_messages = [msg for msg in result["messages"] if isinstance(msg, AIMessage)]
        return ai_messages[-1].content if ai_messages else "I'm here to help! How can I assist you today?"
    except Exception as e:
        # Top-level boundary: log, roll the session back, and answer directly.
        print(f"LangGraph error: {e}")
        supervisor.conversation_state[session_id] = snapshot
        return supervisor.supervise(message, session_id)
# Final import-time marker, printed after every preceding module-level
# statement (including model loading and graph compilation) has run.
print("Main module loaded completely!")