# [file name]: core/nodes/routing_nodes.py # Add this as the FIRST lines of code (after docstrings) import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) import logging from typing import Dict, Any from langchain_core.runnables import RunnableConfig from models.state_models import MultiCountryLegalState from core.router import CountryRouter from core.nodes.base_node import BaseNode from core.prompts.prompt_templates import PromptTemplates logger = logging.getLogger(__name__) class RoutingNodes(BaseNode): """Router, greeting, and conversation repair nodes""" def __init__(self, router: CountryRouter, conversation_repair, llm): self.router = router self.conversation_repair = conversation_repair self.llm = llm self.prompts = PromptTemplates() async def router_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]: """Enhanced router that detects primary intent with state awareness""" try: s = state.model_dump() # CRITICAL: Check if we're continuing an assistance workflow # This prevents the router from misclassifying continuation messages assistance_step = s.get("assistance_step") if assistance_step and assistance_step not in [None, "cancelled", "completed"]: logger.info(f"⏩ Bypassing router - continuing assistance at step: {assistance_step}") return { "router_decision": "assistance_request", "route_explanation": f"Continuing assistance workflow: {assistance_step}", "assistance_step": assistance_step, # Ensure step persists "assistance_requested": True } # Normal routing for new messages return await self._perform_normal_routing(state, s) except Exception as e: logger.error(f"Router error: {str(e)}") legal_context = state.legal_context if hasattr(state, 'legal_context') else {} return self._create_router_response("unclear", f"Router error: {str(e)}", legal_context) async def _perform_normal_routing(self, state: MultiCountryLegalState, state_dict: Dict) -> Dict[str, Any]: """Perform normal routing for new user queries""" if not state_dict.get("messages"): logger.warning("No messages in state for router") return self._create_router_response("unclear", "No messages in state", state_dict.get("legal_context", {})) last_human = self._get_last_human_message(state_dict.get("messages", [])) if not last_human: logger.warning("No user query found in router") return self._create_router_response("unclear", "No user query found", state_dict.get("legal_context", {})) user_query = last_human.get("content", "").strip() if not user_query: logger.warning("Empty user query in router") return self._create_router_response("unclear", "Empty user query", state_dict.get("legal_context", {})) logger.info(f"🔀 Routing query: '{user_query[:50]}...'") routing_result = await self.router.route_query(user_query, state_dict["messages"]) primary_intent = routing_result.country logger.info(f"🎯 Router decision: {primary_intent} ({routing_result.confidence}) - {routing_result.method}") updated_context = self._update_legal_context(state_dict["legal_context"], primary_intent) response = { "router_decision": primary_intent, "route_explanation": f"{routing_result.method}: {routing_result.explanation}", "legal_context": updated_context, "primary_intent": primary_intent } # If this is an assistance request, initialize the workflow if primary_intent == "assistance_request": response.update({ "assistance_step": "collecting_email", "assistance_requested": True }) return response async def greeting_small_talk_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]: """Handle greetings and small talk""" try: s = state.model_dump() last_human = self._get_last_human_message(s.get("messages", [])) user_query = last_human.get("content", "").lower() if last_human else "" logger.info(f"👋 Handling greeting/small_talk: '{user_query[:30]}...'") greeting_response = self.prompts.generate_greeting_response(user_query) return { "messages": [{ "role": "assistant", "content": greeting_response, "meta": { "is_greeting": True, "timestamp": self._get_timestamp() } }], "search_results": "Greeting handled - no legal search performed" } except Exception as e: logger.error(f"Error in greeting node: {str(e)}") return self._create_error_state(f"Error in greeting: {str(e)}") async def conversation_repair_node(self, state: MultiCountryLegalState, config: RunnableConfig) -> Dict[str, Any]: """Unified repair handling with LLM""" try: s = state.model_dump() last_human = self._get_last_human_message(s.get("messages", [])) user_query = last_human.get("content", "") if last_human else "" logger.info(f"🔧 Handling repair request: '{user_query[:30]}...'") repair_response = await self.conversation_repair.generate_repair_response( user_query, s.get("messages", []), self.llm ) return { "messages": [{ "role": "assistant", "content": repair_response, "meta": { "is_repair_response": True, "timestamp": self._get_timestamp() } }], "search_results": "Repair handled - no legal search performed" } except Exception as e: logger.error(f"Error in repair node: {str(e)}") return self._create_error_state(f"Error in repair: {str(e)}") def _create_router_response(self, decision: str, explanation: str, legal_context: Dict) -> Dict[str, Any]: """Create a standardized router response""" return { "router_decision": decision, "route_explanation": explanation, "legal_context": legal_context, "primary_intent": decision } def _get_last_human_message(self, messages: list) -> Dict[str, Any]: """Get the last human message from conversation history""" for msg in reversed(messages): if msg.get("role") in ["user", "human"]: return msg return {} def _update_legal_context(self, legal_context: Dict, primary_intent: str) -> Dict: """Update legal context based on routing decision""" updated_context = legal_context.copy() # Map router decisions to detected_country country_mapping = { "benin": "benin", "madagascar": "madagascar", "assistance_request": updated_context.get("detected_country", "unknown"), "greeting_small_talk": "unknown", "conversation_repair": updated_context.get("detected_country", "unknown"), "conversation_summarization": updated_context.get("detected_country", "unknown"), "unclear": "unknown", "out_of_scope": "unknown" } updated_context["detected_country"] = country_mapping.get(primary_intent, "unknown") updated_context["primary_intent"] = primary_intent return updated_context def _get_timestamp(self) -> str: """Get current timestamp""" from datetime import datetime return datetime.now().isoformat() def _create_error_state(self, error_message: str) -> Dict[str, Any]: """Create error state response""" return { "messages": [{ "role": "assistant", "content": f"❌ Désolé, une erreur s'est produite. Veuillez réessayer.", "meta": {"error": error_message} }] }