# Spaces: Sleeping
# (Status text captured from the hosting page along with the file; it is not part of the program.)
from typing import Dict, Literal
import re

print("Importing modules...")

# transformers is treated as optional: if the import fails for any reason the
# module still loads, and MODEL_AVAILABLE records that no model can be used.
try:
    from transformers import pipeline
    MODEL_AVAILABLE = True
    print("Transformers imported successfully")
except Exception as e:
    print(f"Warning: Could not import transformers: {e}")
    MODEL_AVAILABLE = False

from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage
from tools import handle_general_query, recommend_products, format_recommendations
from complaint_handler import process_complaint_flow, extract_order_mobile_with_llm
from product_recommendation import get_category_from_message, handle_recommendation_flow

print("All modules imported")
class AgentState(Dict):
    """Dictionary-backed state schema handed to the LangGraph workflow.

    Instances are ordinary dicts; the annotations below name the keys that
    ``agent_respond`` fills in before invoking the graph.
    """

    # Conversation messages; agent_respond seeds it with the user's
    # HumanMessage and supervisor_node appends the AIMessage reply.
    messages: list
    # Key used to look up the per-session state held by the supervisor.
    session_id: str
    # Current conversation stage (e.g. 'greeting', 'query_type').
    stage: str
    order_id: str
    customer_name: str
    issue_type: str
    attempts_to_resolve: int
    file_uploaded: bool
    # agent_respond always sets this to "supervisor".
    current_agent: str
    # agent_respond initialises the next two fields to empty strings.
    mobile_number: str
    category: str
    # supervisor_node sets this to "END" after producing a reply.
    next_action: str
class SparkMartSupervisor:
    """Keyword- and stage-driven router for SparkMart chat sessions.

    Every session owns a small state dict (see ``_init_session``) whose
    ``'stage'`` entry decides how the next message is interpreted. Complaint,
    recommendation and general-query handling is delegated to the helper
    functions this module imports.
    """

    # Stages during which every message is forwarded to process_complaint_flow.
    _COMPLAINT_STAGES = frozenset({
        'order_details', 'issue_description', 'file_upload_request',
        'replacement_offer', 'resolution_attempt', 'refund_step1',
        'refund_step2', 'refund_step3', 'refund_consideration',
        'refund_process', 'processing_request',
    })

    # Whole-message phrases and substrings that end the conversation.
    _EXIT_PHRASES = ('bye', 'exit', 'thanks', 'thank you', 'goodbye')
    _EXIT_KEYWORDS = ('bye', 'goodbye', 'exit')

    _COMPLAINT_KEYWORDS = ('order', 'complaint', 'problem', 'issue')
    # The first classification step additionally accepts these two words.
    _INITIAL_COMPLAINT_KEYWORDS = _COMPLAINT_KEYWORDS + ('damaged', 'wrong')
    _RECOMMEND_KEYWORDS = (
        'recommend', 'suggestion', 'suggest', 'product', 'buy', 'purchase',
        'shopping',
    )
    _PERSONALIZED_KEYWORDS = (
        'specialized', 'specialization', 'personalized', 'personal',
        'customized', 'custom',
    )
    _GENERAL_KEYWORDS = (
        'general query', 'general', 'service', 'information', 'discount',
        'offer', 'delivery', 'return', 'website', 'payment',
    )

    # A run of six digits is treated as an order ID or part of a mobile
    # number. The original also tested for ten digits, but any ten-digit run
    # already contains a six-digit run, so one pattern is equivalent.
    _ID_PATTERN = re.compile(r'\d{6}')
    _MOBILE_PATTERN = re.compile(r'\d{10}')

    _FAREWELL = (
        "Thank you for using SparkMart AI! We appreciate your time and hope "
        "to serve you again soon. Have a wonderful day!"
    )
    _ASK_FOR_ID = (
        "Could you please provide your order ID or mobile number so that I "
        "can fetch your details?"
    )
    _ASK_FOR_MOBILE = (
        "Great! Please provide your mobile number so I can give you "
        "personalized recommendations based on your purchase history."
    )

    def __init__(self):
        """Create empty session storage and try to load the text model."""
        self.conversation_state = {}
        # Checked in insertion order, so the two-word greetings win over
        # their one-word forms ('good afternoon' before 'afternoon').
        self.greeting_patterns = {
            'good afternoon': 'Good afternoon', 'good morning': 'Good morning',
            'good evening': 'Good evening', 'good noon': 'Good noon',
            'afternoon': 'Good afternoon', 'morning': 'Good morning',
            'evening': 'Good evening', 'noon': 'Good noon',
            'hello': 'Hi there', 'hi': 'Hello', 'hey': 'Hey'
        }
        # Load the model only when transformers imported; any failure leaves
        # text_generator as None. No timeout is applied to the load.
        self.text_generator = None
        if MODEL_AVAILABLE:
            try:
                print("Loading DialoGPT model...")
                self.text_generator = pipeline(
                    "text-generation",
                    model="microsoft/DialoGPT-small",
                    pad_token_id=50256,
                    eos_token_id=50256,
                    return_full_text=False
                )
                print("Model loaded successfully")
            except Exception as e:
                print(f"Warning: Could not load model: {e}")
                print("Continuing without model (will use fallback methods)")
                self.text_generator = None

    def _init_session(self, session_id):
        """Return the state dict for ``session_id``, creating it if needed."""
        if session_id not in self.conversation_state:
            self.conversation_state[session_id] = {
                'stage': 'greeting', 'order_id': None, 'customer_name': None,
                'issue_type': None, 'attempts_to_resolve': 0, 'file_uploaded': False
            }
        return self.conversation_state[session_id]

    def _check_keywords(self, message, keywords):
        """Return True if any keyword occurs as a substring of ``message``.

        Matching is case-insensitive and deliberately substring-based, so
        'recommend' also matches 'recommendations'.
        """
        text = message.lower()
        return any(word in text for word in keywords)

    def detect_greeting(self, message):
        """Return the greeting reply matching ``message``, or "Hello".

        A pattern must start at a word boundary, so 'hey' no longer fires
        on words such as 'they'; elongated forms like 'heyy' still match.
        """
        msg_lower = message.lower().strip()
        for pattern, response in self.greeting_patterns.items():
            if re.search(r'\b' + re.escape(pattern), msg_lower):
                return response
        return "Hello"

    def _start_complaint(self, message, state):
        """Enter the complaint flow, asking for an ID if none was given."""
        # The stage is set before delegating so that every entry point hands
        # process_complaint_flow the same starting stage. The original skipped
        # this on the 'query_type' path only.
        # NOTE(review): assumes process_complaint_flow expects 'order_details'
        # here, as the other two call sites imply -- confirm against it.
        state['stage'] = 'order_details'
        if self._ID_PATTERN.search(message):
            return process_complaint_flow(message, state, self.text_generator)
        return self._ASK_FOR_ID

    def _start_recommendation(self, message, state):
        """Enter the recommendation flow and answer the first request."""
        state['stage'] = 'recommendation_followup'
        category = get_category_from_message(message)
        mobile_match = self._MOBILE_PATTERN.search(message)
        if self._check_keywords(message, self._PERSONALIZED_KEYWORDS):
            return self._ASK_FOR_MOBILE
        if category:
            recs = recommend_products("category", category=category)
        elif mobile_match:
            recs = recommend_products("personalized", mobile_number=mobile_match.group())
        else:
            recs = recommend_products("popular")
        return format_recommendations(recs)

    def _classify_query(self, message, state):
        """Route the first real request after the greeting."""
        if self._check_keywords(message, self._INITIAL_COMPLAINT_KEYWORDS):
            return self._start_complaint(message, state)
        if self._check_keywords(message, self._RECOMMEND_KEYWORDS):
            return self._start_recommendation(message, state)
        # No separate "and not an order query" test is needed: any message
        # containing 'order' was already taken by the complaint branch above.
        if self._check_keywords(message, ['general']):
            state['stage'] = 'general_query'
            return handle_general_query(message)
        return "I understand. Could you please clarify if this is regarding a specific order, product recommendations, or a general inquiry about our services?"

    def _handle_general_stage(self, message, state):
        """Handle a message received while in the 'general_query' stage."""
        if message.lower().strip() in ['no', 'no thanks', 'nothing', 'nope']:
            return self._FAREWELL
        if self._check_keywords(message, ['thanks', 'thank you', 'ok thanks']):
            return "You're welcome! Is there anything else I can help you with?"
        if self._check_keywords(message, self._COMPLAINT_KEYWORDS):
            return self._start_complaint(message, state)
        if self._check_keywords(message, self._RECOMMEND_KEYWORDS):
            # From this stage the reply is always the popular list; category
            # and mobile number are not inspected here.
            state['stage'] = 'recommendation_followup'
            recs = recommend_products("popular")
            return format_recommendations(recs)
        return handle_general_query(message)

    def supervise(self, message, session_id):
        """Produce the reply to ``message`` for the given session.

        Exit phrases are checked first, in every stage. Otherwise the
        session's current stage picks the handler. The keyword handlers at
        the end only run for stages not listed above them.
        """
        state = self._init_session(session_id)

        # PRIORITY: global exit handling
        if (message.lower().strip() in self._EXIT_PHRASES
                or self._check_keywords(message, self._EXIT_KEYWORDS)):
            return self._FAREWELL

        stage = state['stage']

        if stage == 'greeting':
            greeting = self.detect_greeting(message)
            state['stage'] = 'query_type'
            return f"{greeting}! I am SparkMart AI. How may I assist you? Do you have a query regarding an order, need product recommendations, or is it a general query?"

        if stage == 'query_type':
            return self._classify_query(message, state)

        if stage in self._COMPLAINT_STAGES:
            return process_complaint_flow(message, state, self.text_generator)

        if stage == 'recommendation_followup':
            if self._check_keywords(message, self._PERSONALIZED_KEYWORDS):
                return self._ASK_FOR_MOBILE
            return handle_recommendation_flow(message, state)

        if stage == 'general_query':
            return self._handle_general_stage(message, state)

        # Global handlers. Complaint stages have already returned above, so
        # the complaint branch needs no extra stage check.
        if self._check_keywords(message, self._GENERAL_KEYWORDS):
            state['stage'] = 'general_query'
            return handle_general_query(message)
        if self._check_keywords(message, self._RECOMMEND_KEYWORDS):
            return self._start_recommendation(message, state)
        if self._check_keywords(message, self._COMPLAINT_KEYWORDS):
            return self._start_complaint(message, state)
        return "I'm here to help! I can assist you with orders, product recommendations, or general information. What would you like help with?"
# Module-level singleton shared by supervisor_node and agent_respond.
print("Initializing supervisor...")
supervisor = SparkMartSupervisor()
print("Supervisor initialized")
def supervisor_node(state: AgentState) -> AgentState:
    """Graph node: answer the latest message and mark the run as finished.

    Reads the content of the last entry in ``state["messages"]``, asks the
    module-level ``supervisor`` for a reply, appends that reply as an
    ``AIMessage`` and sets ``next_action`` to ``"END"``. The same state
    object is mutated and returned.
    """
    message = state["messages"][-1].content
    session_id = state["session_id"]
    response = supervisor.supervise(message, session_id)
    state["messages"].append(AIMessage(content=response))
    state["next_action"] = "END"
    return state
def should_continue(state: AgentState) -> Literal["END"]:
    """Routing function for the supervisor node; always returns ``"END"``.

    The state argument is accepted but not inspected.
    """
    return "END"
# Single-node graph: the supervisor node is the entry point and its only
# conditional edge maps the "END" route to the graph's END marker.
print("Building workflow...")
workflow = StateGraph(AgentState)
workflow.add_node("supervisor", supervisor_node)
workflow.set_entry_point("supervisor")
workflow.add_conditional_edges("supervisor", should_continue, {"END": END})
app = workflow.compile()
print("Workflow compiled successfully")
def agent_respond(message, session_id, conversation_history):
    """Return the assistant's reply to ``message`` for ``session_id``.

    The message is run through the compiled LangGraph ``app``. If that raises,
    the error is printed and the supervisor is called directly instead.

    Args:
        message: The user's text.
        session_id: Key identifying the conversation's stored state.
        conversation_history: Accepted for caller compatibility; not used.

    Returns:
        The content of the last AI message produced by the graph, a default
        help prompt if the graph produced none, or the supervisor's direct
        reply on the fallback path.
    """
    current_state = supervisor._init_session(session_id)
    # Remember the session as it was before the graph ran. The graph may
    # already have advanced it when a later step fails; without restoring it
    # the fallback would process the same message a second time.
    snapshot = dict(current_state)
    try:
        # _init_session seeds these keys with None, so a .get() default would
        # never apply; `or` supplies the intended empty-string fallbacks.
        state = AgentState(
            messages=[HumanMessage(content=message)],
            session_id=session_id,
            stage=current_state.get('stage') or 'greeting',
            order_id=current_state.get('order_id') or '',
            customer_name=current_state.get('customer_name') or '',
            issue_type=current_state.get('issue_type') or '',
            attempts_to_resolve=current_state.get('attempts_to_resolve') or 0,
            file_uploaded=bool(current_state.get('file_uploaded')),
            current_agent="supervisor",
            mobile_number="",
            category="",
            next_action=""
        )
        result = app.invoke(state)
        ai_messages = [msg for msg in result["messages"] if isinstance(msg, AIMessage)]
        return ai_messages[-1].content if ai_messages else "I'm here to help! How can I assist you today?"
    except Exception as e:
        print(f"LangGraph error: {e}")
        current_state.clear()
        current_state.update(snapshot)
        return supervisor.supervise(message, session_id)
print("Main module loaded completely!")