# MultiCountryRAG / models / state_models.py
# SAAHMATHWORKS
# production
# 8f0db18
# models/state_models.py
import json
import logging
import operator
import re
from typing import List, Dict, Any, Optional, Annotated, Literal, Union

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from pydantic import BaseModel, Field, ConfigDict
logger = logging.getLogger(__name__)
class MultiCountryLegalState(BaseModel):
    # Conversation state for the multi-country (Benin / Madagascar) legal
    # assistant.
    #
    # ``messages`` and ``detected_articles`` carry ``operator.add`` as
    # ``Annotated`` metadata (presumably consumed by the graph framework as a
    # merge reducer -- confirm against the graph definition); every other
    # field is a plain value.
    #
    # NOTE: documented with comments rather than a class docstring because
    # pydantic copies a model's docstring into its JSON schema description,
    # which would change the emitted schema.

    # Chat history as plain dicts. The converters in this class emit the shape
    # {"role", "content", "meta"} when they meet LangChain message objects.
    messages: Annotated[List[Dict[str, Any]], operator.add] = Field(default_factory=list)
    legal_context: Dict[str, Any] = Field(
        default_factory=lambda: {
            "jurisdiction": "Unknown",
            "user_type": "general",
            "document_type": "legal",
            "detected_country": "unknown"
        }
    )
    # NOTE(review): the original comment here said this field was fixed to
    # "handle concurrent updates", but no reducer annotation is attached to
    # it -- confirm whether one is required.
    supplemental_message: Optional[str] = Field(
        default="",
        description="Supplemental message to display to user (e.g., fallback messages, apologies)"
    )
    session_id: Optional[str] = None
    last_search_query: Optional[str] = None
    detected_articles: Annotated[List[str], operator.add] = Field(default_factory=list)
    router_decision: Optional[str] = None
    search_results: Optional[str] = None
    route_explanation: Optional[str] = None
    country: Optional[str] = Field(default=None)

    # Assistance email fields
    assistance_requested: bool = Field(default=False)
    user_email: Optional[str] = None
    assistance_description: Optional[str] = None
    email_status: Optional[str] = None  # "pending", "sent", "error"
    assistance_step: Optional[str] = Field(default=None)  # "collecting_email", "collecting_description", "confirming_send"
    pending_assistance_data: Dict[str, Any] = Field(default_factory=dict)

    # Conversation repair tracking
    repair_type: Optional[str] = None
    original_query: Optional[str] = None
    misunderstanding_count: int = Field(default=0)

    # Enhanced routing support
    primary_intent: Optional[str] = Field(default=None)

    # Human approval fields
    approval_status: Optional[str] = Field(default=None)  # "pending", "approved", "rejected"
    approval_reason: Optional[str] = Field(default=None)
    approved_by: Optional[str] = Field(default=None)
    approval_timestamp: Optional[str] = Field(default=None)

    # Conversation summary fields
    summary_generated: bool = Field(default=False)
    last_summary_timestamp: Optional[str] = Field(default=None)

    # Search-related fields, kept here to avoid storing complex data in
    # legal_context
    search_metadata: Dict[str, Any] = Field(default_factory=dict)

    # ========================================================================
    # JSON serialization support (Pydantic v2 configuration)
    # The overrides below exist so that a dumped state contains only plain,
    # JSON-friendly values (original motivation: "TypeError: Object of type
    # MultiCountryLegalState is not JSON serializable" during checkpointing).
    # ========================================================================
    model_config = ConfigDict(
        arbitrary_types_allowed=True,  # Allow LangChain message types if used
        validate_assignment=True,
        json_encoders={
            # Any custom types can be added here
        }
    )

    @staticmethod
    def _message_to_dict(msg: Any) -> Dict[str, Any]:
        """Return *msg* in the plain-dict message shape used by this state.

        Dicts are returned unchanged. LangChain ``BaseMessage`` objects become
        ``{"role", "content", "meta"}`` where ``AIMessage`` maps to the role
        ``"assistant"`` and every other message class to ``"user"``. Any
        other value is stringified under the role ``"unknown"``.
        """
        if isinstance(msg, dict):
            return msg
        if isinstance(msg, BaseMessage):
            return {
                "role": "assistant" if isinstance(msg, AIMessage) else "user",
                "content": msg.content,
                "meta": getattr(msg, "additional_kwargs", {}),
            }
        return {"role": "unknown", "content": str(msg), "meta": {}}

    def _manual_dump(self) -> Dict[str, Any]:
        """Build the dump by hand, one entry per declared field.

        Used only when the standard pydantic dump fails. Iterating
        ``model_fields`` keeps the result in declaration order and in sync
        with the field list; container fields that do not hold the expected
        container type are replaced by an empty one.
        """
        data = {name: getattr(self, name, None) for name in type(self).model_fields}
        for name in ("messages", "detected_articles"):
            if not isinstance(data.get(name), list):
                data[name] = []
        for name in ("legal_context", "pending_assistance_data", "search_metadata"):
            if not isinstance(data.get(name), dict):
                data[name] = {}
        data["supplemental_message"] = data.get("supplemental_message") or ""
        return data

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Dump the state to a dict of plain values.

        Args:
            **kwargs: Forwarded unchanged to ``BaseModel.model_dump``.

        Returns:
            The dumped state. Every entry of ``messages`` is a dict, and
            ``legal_context``, ``pending_assistance_data`` and
            ``search_metadata`` are passed through
            ``_make_json_serializable``.

        This is deliberately best-effort: if the standard dump raises, the
        failure is logged and a manual per-field dump (which ignores
        ``kwargs``) is used instead of propagating the error.
        """
        try:
            data = super().model_dump(**kwargs)
        except Exception as e:
            logger.warning("Standard model_dump failed: %s, using manual serialization", e)
            data = self._manual_dump()

        # Messages should already be dicts, but normalise anything that is not.
        if data.get("messages"):
            serialized_messages = []
            for msg in data["messages"]:
                try:
                    serialized_messages.append(self._message_to_dict(msg))
                except Exception as msg_error:
                    logger.warning("Error serializing message: %s", msg_error)
                    serialized_messages.append({
                        "role": "unknown",
                        "content": str(msg),
                        "meta": {}
                    })
            data["messages"] = serialized_messages

        # Free-form dict fields may hold arbitrary objects; reduce them to
        # JSON-friendly values, or to {} when even that fails.
        for key in ("legal_context", "pending_assistance_data", "search_metadata"):
            if data.get(key):
                try:
                    data[key] = self._make_json_serializable(data[key])
                except Exception as dict_error:
                    logger.warning("Error serializing %s: %s", key, dict_error)
                    data[key] = {}
        return data

    def model_dump_json(self, **kwargs) -> str:
        """Serialize the state to a JSON string via ``model_dump``.

        Args:
            **kwargs: ``indent`` is handed to ``json.dumps``; everything else
                is forwarded to ``model_dump``.

        Returns:
            The JSON text. Values ``json`` cannot encode natively are
            stringified (``default=str``).
        """
        # ``indent`` is a JSON-formatting option that ``model_dump`` does not
        # accept; forwarding it would make the standard dump fail and push
        # every call onto the manual fallback path.
        indent = kwargs.pop("indent", None)
        data = self.model_dump(**kwargs)
        return json.dumps(data, default=str, indent=indent)

    @staticmethod
    def _make_json_serializable(obj: Any) -> Any:
        """Recursively convert *obj* to JSON-friendly values.

        Dicts and lists are rebuilt with converted contents, JSON primitives
        are returned as-is, LangChain messages are converted with
        ``_message_to_dict``, and anything else is replaced by ``str(obj)``.
        """
        if isinstance(obj, dict):
            return {k: MultiCountryLegalState._make_json_serializable(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [MultiCountryLegalState._make_json_serializable(item) for item in obj]
        if isinstance(obj, (str, int, float, bool, type(None))):
            return obj
        if isinstance(obj, BaseMessage):
            return MultiCountryLegalState._message_to_dict(obj)
        # Convert any other type to string
        return str(obj)

    @classmethod
    def model_validate(cls, obj: Any, **kwargs: Any) -> "MultiCountryLegalState":
        """Validate *obj* into a state, normalising ``messages`` first.

        Args:
            obj: Data to validate. When it is a dict with a non-empty
                ``messages`` entry, each message is converted with
                ``_message_to_dict`` before validation.
            **kwargs: Forwarded to ``BaseModel.model_validate`` (for example
                ``strict``, ``from_attributes``, ``context``).

        Returns:
            The validated ``MultiCountryLegalState``.
        """
        if isinstance(obj, dict) and obj.get("messages"):
            # Work on a shallow copy so the caller's dict is left untouched.
            obj = {
                **obj,
                "messages": [cls._message_to_dict(msg) for msg in obj["messages"]],
            }
        return super().model_validate(obj, **kwargs)

    # ========================================================================
    @staticmethod
    def detect_country(text: str) -> str:
        """Detect the country a text refers to from a fixed keyword list.

        A keyword counts only when it starts at a word boundary, and where
        keywords overlap the longest one wins, so "tana" no longer fires
        inside "Montana" and "antananarivo" is counted once rather than also
        as "tana". Suffixes are still allowed, so plurals such as
        "malgaches" keep matching.

        Args:
            text: User input text to analyze.

        Returns:
            Country code: "benin", "madagascar", or "unknown" (empty input,
            no keyword found, or a tie between the two countries).
        """
        if not text:
            return "unknown"
        text_lower = text.lower()

        # Benin keywords
        benin_keywords = [
            "bénin", "benin", "béninois", "béninoise",
            "cotonou", "porto-novo", "porto novo",
            "dahomey",  # Historical name
        ]
        # Madagascar keywords
        madagascar_keywords = [
            "madagascar", "malgache", "malagasy",
            "antananarivo", "tananarive", "tana",
            "toamasina", "tamatave",
        ]

        def score(keywords: List[str]) -> int:
            # Longest alternative first so that e.g. "tananarive" is matched
            # as itself rather than as its prefix "tana".
            alternatives = "|".join(
                re.escape(keyword)
                for keyword in sorted(keywords, key=len, reverse=True)
            )
            # (?<!\w): the keyword must not be glued to a preceding word
            # character. The score is the number of distinct keywords found.
            return len(set(re.findall(rf"(?<!\w)(?:{alternatives})", text_lower)))

        benin_score = score(benin_keywords)
        madagascar_score = score(madagascar_keywords)
        if benin_score > madagascar_score:
            return "benin"
        if madagascar_score > benin_score:
            return "madagascar"
        return "unknown"
class RoutingResult(BaseModel):
    # Result of a routing decision.
    # (Comments rather than a docstring: pydantic copies a model docstring
    # into the JSON schema description, which would alter the schema.)

    # Routing target. Despite the field name, the allowed values include
    # non-country outcomes as well as the two countries.
    country: Literal[
        "benin",
        "madagascar",
        "unclear",
        "greeting_small_talk",
        "conversation_repair",
        "assistance_request",
        "conversation_summarization",
        "out_of_scope",
    ]
    # One of three fixed confidence levels.
    confidence: Literal["high", "medium", "low"]
    # Free-form strings; presumably how the route was chosen and a
    # human-readable justification -- confirm against the router.
    method: str
    explanation: str
class SearchResult(BaseModel):
    # Container for the outcome of one search.
    # (Comments rather than a docstring: pydantic copies a model docstring
    # into the JSON schema description, which would alter the schema.)

    # Retrieved items; the element type is left open (Any).
    documents: List[Any]
    # Article identifiers as strings.
    detected_articles: List[str]
    # Filter name -> value mapping; presumably the filters used for this
    # search -- confirm against the search code.
    applied_filters: Dict[str, Any]
    # Query text and country the result belongs to.
    query: str
    country: str