# models/state_models.py
"""Pydantic state models for the multi-country legal assistant graph."""
import json
import logging
import operator
import re
from typing import List, Dict, Any, Optional, Annotated, Literal, Union

from pydantic import BaseModel, Field, ConfigDict
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage

logger = logging.getLogger(__name__)

# Types that ``json`` can emit directly, both as values and as object keys.
_JSON_PRIMITIVES = (str, int, float, bool, type(None))


def _compile_keywords(keywords):
    """Compile each keyword into a pattern that must start at a word boundary.

    Only the left edge is anchored so that derived forms ("malgaches",
    "beninese") still match, while hits in the middle of an unrelated word
    ("tana" inside "Guantanamo" or "Montana") are rejected.
    """
    return tuple(re.compile(r"(?<!\w)" + re.escape(kw)) for kw in keywords)


_BENIN_PATTERNS = _compile_keywords((
    "bénin", "benin", "béninois", "béninoise",
    "cotonou", "porto-novo", "porto novo",
    "dahomey",  # Historical name
))

_MADAGASCAR_PATTERNS = _compile_keywords((
    "madagascar", "malgache", "malagasy",
    "antananarivo", "tananarive", "tana",
    "toamasina", "tamatave",
))


class MultiCountryLegalState(BaseModel):
    """Conversation state shared across the nodes of the legal assistant.

    ``messages`` and ``detected_articles`` are annotated with ``operator.add``
    so that list updates are concatenated rather than replaced. All other
    fields are plain values.
    """

    messages: Annotated[List[Dict[str, Any]], operator.add] = Field(default_factory=list)
    legal_context: Dict[str, Any] = Field(
        default_factory=lambda: {
            "jurisdiction": "Unknown",
            "user_type": "general",
            "document_type": "legal",
            "detected_country": "unknown",
        }
    )

    # NOTE(review): an earlier comment said this field handles concurrent
    # updates, but no reducer is attached to it -- confirm that is intended.
    supplemental_message: Optional[str] = Field(
        default="",
        description="Supplemental message to display to user (e.g., fallback messages, apologies)"
    )

    session_id: Optional[str] = None
    last_search_query: Optional[str] = None
    detected_articles: Annotated[List[str], operator.add] = Field(default_factory=list)
    router_decision: Optional[str] = None
    search_results: Optional[str] = None
    route_explanation: Optional[str] = None
    country: Optional[str] = Field(default=None)

    # Assistance email fields
    assistance_requested: bool = Field(default=False)
    user_email: Optional[str] = None
    assistance_description: Optional[str] = None
    email_status: Optional[str] = None  # "pending", "sent", "error"
    assistance_step: Optional[str] = Field(default=None)  # "collecting_email", "collecting_description", "confirming_send"
    pending_assistance_data: Dict[str, Any] = Field(default_factory=dict)

    # Conversation repair tracking
    repair_type: Optional[str] = None
    original_query: Optional[str] = None
    misunderstanding_count: int = Field(default=0)

    # Enhanced routing support
    primary_intent: Optional[str] = Field(default=None)

    # Human approval fields
    approval_status: Optional[str] = Field(default=None)  # "pending", "approved", "rejected"
    approval_reason: Optional[str] = Field(default=None)
    approved_by: Optional[str] = Field(default=None)
    approval_timestamp: Optional[str] = Field(default=None)

    # Conversation summary fields
    summary_generated: bool = Field(default=False)
    last_summary_timestamp: Optional[str] = Field(default=None)

    # Search-related fields, kept separate so complex data is not stored in legal_context
    search_metadata: Dict[str, Any] = Field(default_factory=dict)

    # ------------------------------------------------------------------
    # JSON serialization support (Pydantic v2 configuration).
    # Addresses: TypeError: Object of type MultiCountryLegalState is not
    # JSON serializable
    # ------------------------------------------------------------------
    model_config = ConfigDict(
        arbitrary_types_allowed=True,  # Allow LangChain message types if used
        validate_assignment=True,
        json_encoders={
            # Any custom types can be added here
        }
    )

    @staticmethod
    def _coerce_message(msg: Any) -> Dict[str, Any]:
        """Return ``msg`` as a plain ``{"role", "content", "meta"}`` dict.

        Dicts are returned unchanged. LangChain ``BaseMessage`` objects map to
        role ``"assistant"`` when they are ``AIMessage`` and ``"user"``
        otherwise. Anything else is stringified under role ``"unknown"``.
        """
        if isinstance(msg, dict):
            return msg
        if isinstance(msg, BaseMessage):
            return {
                "role": "assistant" if isinstance(msg, AIMessage) else "user",
                "content": msg.content,
                "meta": getattr(msg, "additional_kwargs", {}),
            }
        return {"role": "unknown", "content": str(msg), "meta": {}}

    def _manual_dump(self) -> Dict[str, Any]:
        """Build the field dict by hand; used when the standard dump fails."""
        return {
            "messages": self.messages if isinstance(self.messages, list) else [],
            "legal_context": self.legal_context if isinstance(self.legal_context, dict) else {},
            "supplemental_message": self.supplemental_message or "",
            "session_id": self.session_id,
            "last_search_query": self.last_search_query,
            "detected_articles": self.detected_articles if isinstance(self.detected_articles, list) else [],
            "router_decision": self.router_decision,
            "search_results": self.search_results,
            "route_explanation": self.route_explanation,
            "country": self.country,
            "assistance_requested": self.assistance_requested,
            "user_email": self.user_email,
            "assistance_description": self.assistance_description,
            "email_status": self.email_status,
            "assistance_step": self.assistance_step,
            "pending_assistance_data": self.pending_assistance_data if isinstance(self.pending_assistance_data, dict) else {},
            "repair_type": self.repair_type,
            "original_query": self.original_query,
            "misunderstanding_count": self.misunderstanding_count,
            "primary_intent": self.primary_intent,
            "approval_status": self.approval_status,
            "approval_reason": self.approval_reason,
            "approved_by": self.approved_by,
            "approval_timestamp": self.approval_timestamp,
            "summary_generated": self.summary_generated,
            "last_summary_timestamp": self.last_summary_timestamp,
            "search_metadata": self.search_metadata if isinstance(self.search_metadata, dict) else {},
        }

    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        """Dump the state to a dict whose nested values are JSON-friendly.

        Args:
            **kwargs: Forwarded unchanged to ``BaseModel.model_dump``.

        Returns:
            The dumped fields. Entries in ``messages`` are plain dicts, and
            ``legal_context``, ``pending_assistance_data`` and
            ``search_metadata`` are passed through
            ``_make_json_serializable``.
        """
        try:
            data = super().model_dump(**kwargs)
        except Exception as e:
            # Deliberate best effort: a failed dump must not break the caller.
            logger.warning("Standard model_dump failed: %s, using manual serialization", e)
            data = self._manual_dump()

        # Messages should already be dicts, but double-check.
        messages = data.get("messages")
        if messages:
            serialized_messages = []
            for msg in messages:
                try:
                    serialized_messages.append(self._coerce_message(msg))
                except Exception as msg_error:
                    logger.warning("Error serializing message: %s", msg_error)
                    serialized_messages.append({
                        "role": "unknown",
                        "content": str(msg),
                        "meta": {},
                    })
            data["messages"] = serialized_messages

        # Ensure nested dicts are serializable.
        for key in ("legal_context", "pending_assistance_data", "search_metadata"):
            if data.get(key):
                try:
                    data[key] = self._make_json_serializable(data[key])
                except Exception as dict_error:
                    logger.warning("Error serializing %s: %s", key, dict_error)
                    data[key] = {}

        return data

    def model_dump_json(self, **kwargs: Any) -> str:
        """Return the state as a JSON string.

        Args:
            **kwargs: ``indent`` and ``ensure_ascii`` are applied to
                ``json.dumps``; every other keyword is forwarded to
                ``model_dump``.

        Returns:
            The JSON text. Values ``json`` cannot encode are stringified.
        """
        # These are JSON-formatting options that ``model_dump`` does not
        # accept; forwarding them would raise inside ``model_dump`` and
        # silently trigger its manual fallback path.
        indent = kwargs.pop("indent", None)
        ensure_ascii = kwargs.pop("ensure_ascii", True)
        data = self.model_dump(**kwargs)
        return json.dumps(data, default=str, indent=indent, ensure_ascii=ensure_ascii)

    @staticmethod
    def _make_json_serializable(obj: Any) -> Any:
        """Recursively convert ``obj`` into JSON-serializable values.

        Dicts and sequences are converted element by element, primitives are
        returned as-is, ``BaseMessage`` objects become role/content/meta
        dicts, and any other object is replaced by ``str(obj)``.
        """
        convert = MultiCountryLegalState._make_json_serializable
        if isinstance(obj, dict):
            # json.dumps rejects non-primitive keys even when ``default`` is
            # given, so those keys are stringified here.
            return {
                (k if isinstance(k, _JSON_PRIMITIVES) else str(k)): convert(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple, set, frozenset)):
            # Tuples and sets become JSON arrays instead of their repr string.
            return [convert(item) for item in obj]
        if isinstance(obj, _JSON_PRIMITIVES):
            return obj
        if isinstance(obj, BaseMessage):
            return MultiCountryLegalState._coerce_message(obj)
        # Convert any other type to string.
        return str(obj)

    @classmethod
    def model_validate(cls, obj: Any, **kwargs: Any) -> "MultiCountryLegalState":
        """Validate ``obj``, first normalizing its messages to plain dicts.

        Args:
            obj: The data to validate. When it is a dict with a non-empty
                ``"messages"`` entry, each message is converted with
                ``_coerce_message``.
            **kwargs: Forwarded to ``BaseModel.model_validate`` (for example
                ``strict``, ``from_attributes``, ``context``).

        Returns:
            The validated state instance.
        """
        if isinstance(obj, dict) and obj.get("messages"):
            # Work on a shallow copy so the caller's dict is left untouched.
            obj = {
                **obj,
                "messages": [cls._coerce_message(msg) for msg in obj["messages"]],
            }
        return super().model_validate(obj, **kwargs)

    @staticmethod
    def detect_country(text: str) -> str:
        """Detect the country mentioned in ``text`` from a keyword list.

        Each keyword that occurs at the start of a word counts once towards
        its country; the country with the strictly higher count wins.

        Args:
            text: User input text to analyze.

        Returns:
            ``"benin"``, ``"madagascar"``, or ``"unknown"`` (empty input, no
            keyword found, or a tie).
        """
        if not text:
            return "unknown"

        text_lower = text.lower()
        benin_score = sum(1 for pattern in _BENIN_PATTERNS if pattern.search(text_lower))
        madagascar_score = sum(1 for pattern in _MADAGASCAR_PATTERNS if pattern.search(text_lower))

        if benin_score > madagascar_score:
            return "benin"
        if madagascar_score > benin_score:
            return "madagascar"
        return "unknown"


class RoutingResult(BaseModel):
    """Outcome of routing a query: the chosen route, its confidence and why."""

    country: Literal[
        "benin",
        "madagascar",
        "unclear",
        "greeting_small_talk",
        "conversation_repair",
        "assistance_request",
        "conversation_summarization",
        "out_of_scope",
    ]
    confidence: Literal["high", "medium", "low"]
    method: str
    explanation: str


class SearchResult(BaseModel):
    """Result of a document search together with the query that produced it."""

    documents: List[Any]
    detected_articles: List[str]
    applied_filters: Dict[str, Any]
    query: str
    country: str