Spaces:
Sleeping
Sleeping
File size: 11,756 Bytes
69f5099 fbdfc24 69f5099 fbdfc24 69f5099 8f0db18 69f5099 fbdfc24 69f5099 8f0db18 69f5099 8f0db18 69f5099 8f0db18 69f5099 fbdfc24 69f5099 fbdfc24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
# models/state_models.py
import json
import logging
import operator
import re
import unicodedata
from typing import List, Dict, Any, Optional, Annotated, Literal, Union

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from pydantic import BaseModel, Field, ConfigDict
logger = logging.getLogger(__name__)
class MultiCountryLegalState(BaseModel):
    """Conversation state for the multi-country (Benin / Madagascar) legal assistant.

    Holds the chat transcript, routing/search results, assistance-email
    workflow fields, conversation-repair tracking and human-approval metadata.
    The ``model_dump`` / ``model_dump_json`` / ``model_validate`` overrides keep
    the state JSON-compatible so it can be checkpointed (e.g. to PostgreSQL).
    """

    # Chat transcript stored as plain dicts. ``operator.add`` in the annotation
    # is a reducer hint -- presumably list updates from graph nodes are
    # concatenated rather than overwritten; confirm against the graph definition.
    messages: Annotated[List[Dict[str, Any]], operator.add] = Field(default_factory=list)
    legal_context: Dict[str, Any] = Field(
        default_factory=lambda: {
            "jurisdiction": "Unknown",
            "user_type": "general",
            "document_type": "legal",
            "detected_country": "unknown",
        }
    )
    # NOTE(review): unlike ``messages``/``detected_articles`` this field has no
    # reducer annotation, so it is a single overwritable value and does NOT
    # merge concurrent updates. Confirm only one node writes it per step.
    supplemental_message: Optional[str] = Field(
        default="",
        description="Supplemental message to display to user (e.g., fallback messages, apologies)",
    )
    session_id: Optional[str] = None
    last_search_query: Optional[str] = None
    detected_articles: Annotated[List[str], operator.add] = Field(default_factory=list)
    router_decision: Optional[str] = None
    search_results: Optional[str] = None
    route_explanation: Optional[str] = None
    country: Optional[str] = Field(default=None)

    # Assistance email fields
    assistance_requested: bool = Field(default=False)
    user_email: Optional[str] = None
    assistance_description: Optional[str] = None
    email_status: Optional[str] = None  # "pending", "sent", "error"
    assistance_step: Optional[str] = Field(default=None)  # "collecting_email", "collecting_description", "confirming_send"
    pending_assistance_data: Dict[str, Any] = Field(default_factory=dict)

    # Conversation repair tracking
    repair_type: Optional[str] = None
    original_query: Optional[str] = None
    misunderstanding_count: int = Field(default=0)

    # Enhanced routing support
    primary_intent: Optional[str] = Field(default=None)

    # Human approval fields
    approval_status: Optional[str] = Field(default=None)  # "pending", "approved", "rejected"
    approval_reason: Optional[str] = Field(default=None)
    approved_by: Optional[str] = Field(default=None)
    approval_timestamp: Optional[str] = Field(default=None)

    # Conversation summary fields
    summary_generated: bool = Field(default=False)
    last_summary_timestamp: Optional[str] = Field(default=None)

    # Search-related data kept here so complex values are not stored in legal_context
    search_metadata: Dict[str, Any] = Field(default_factory=dict)

    # ``arbitrary_types_allowed`` permits non-Pydantic types (e.g. LangChain
    # message classes) in annotations; ``validate_assignment`` re-validates on
    # attribute assignment. JSON compatibility is NOT configured here -- it is
    # enforced by the serialization overrides below.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
    )

    # Free-form dict fields that may hold arbitrary (non-JSON) values.
    # (Annotated as a plain tuple inside a method below, not as a class attribute,
    # because Pydantic would otherwise try to treat it as a field.)

    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        """Dump the state to a dict whose values are JSON-compatible.

        Delegates to Pydantic's ``model_dump`` (forwarding ``kwargs``), then:
          * converts every ``messages`` entry to a dict (see ``_message_to_dict``)
            and makes its contents JSON-compatible;
          * recursively converts ``legal_context``, ``pending_assistance_data``
            and ``search_metadata`` with ``_make_json_serializable``.

        This prevents ``TypeError: Object of type MultiCountryLegalState is not
        JSON serializable`` during checkpointing.

        If Pydantic's dump raises, a warning is logged and all declared fields
        are read directly instead (``kwargs`` are ignored in that fallback).
        A nested dict that still cannot be converted is replaced by ``{}``.
        """
        try:
            data = super().model_dump(**kwargs)
        except Exception as e:
            logger.warning("Standard model_dump failed: %s, using manual serialization", e)
            data = self._fallback_dump()

        messages = data.get("messages")
        if messages:
            data["messages"] = [self._serialize_message(msg) for msg in messages]

        for key in ("legal_context", "pending_assistance_data", "search_metadata"):
            if data.get(key):
                try:
                    data[key] = self._make_json_serializable(data[key])
                except Exception as dict_error:
                    logger.warning("Error serializing %s: %s", key, dict_error)
                    data[key] = {}
        return data

    def _fallback_dump(self) -> Dict[str, Any]:
        """Read every declared field directly, coercing container fields to safe empties."""
        # Iterating model_fields (declaration order) keeps this in sync with the
        # field list instead of duplicating it by hand.
        data = {name: getattr(self, name, None) for name in type(self).model_fields}
        for name in ("messages", "detected_articles"):
            if not isinstance(data.get(name), list):
                data[name] = []
        for name in ("legal_context", "pending_assistance_data", "search_metadata"):
            if not isinstance(data.get(name), dict):
                data[name] = {}
        data["supplemental_message"] = data.get("supplemental_message") or ""
        return data

    def model_dump_json(self, **kwargs: Any) -> str:
        """Serialize the state to a JSON string built from :meth:`model_dump`.

        ``indent`` and ``ensure_ascii`` are JSON-formatting options that
        ``model_dump`` does not accept, so they are applied by ``json.dumps``
        here instead of being forwarded. All other ``kwargs`` go to
        :meth:`model_dump`. Values ``json`` cannot encode are stringified via
        ``default=str``.
        """
        indent = kwargs.pop("indent", None)
        ensure_ascii = kwargs.pop("ensure_ascii", True)
        data = self.model_dump(**kwargs)
        return json.dumps(data, default=str, indent=indent, ensure_ascii=ensure_ascii)

    @staticmethod
    def _message_to_dict(msg: Any) -> Dict[str, Any]:
        """Normalize one transcript entry to the dict shape stored in ``messages``.

        Dicts are returned unchanged (same object). LangChain messages become
        ``{"role", "content", "meta"}`` with role ``"assistant"`` for
        ``AIMessage`` and ``"user"`` for every other message class. Anything
        else becomes an ``"unknown"``-role entry holding ``str(msg)``.
        """
        if isinstance(msg, dict):
            return msg
        if isinstance(msg, BaseMessage):
            # NOTE(review): SystemMessage/ToolMessage are also labelled "user"
            # here (unchanged behavior) -- confirm that is intended.
            return {
                "role": "assistant" if isinstance(msg, AIMessage) else "user",
                "content": msg.content,
                "meta": getattr(msg, "additional_kwargs", {}),
            }
        return {"role": "unknown", "content": str(msg), "meta": {}}

    @staticmethod
    def _serialize_message(msg: Any) -> Dict[str, Any]:
        """Convert one transcript entry to a JSON-compatible dict, never raising."""
        try:
            return MultiCountryLegalState._make_json_serializable(
                MultiCountryLegalState._message_to_dict(msg)
            )
        except Exception as msg_error:
            logger.warning("Error serializing message: %s", msg_error)
            return {"role": "unknown", "content": str(msg), "meta": {}}

    @staticmethod
    def _make_json_serializable(obj: Any) -> Any:
        """Recursively convert ``obj`` into values the ``json`` module can encode.

        * dict: values converted recursively; keys other than
          str/int/float/bool/None (which ``json`` cannot encode) are stringified.
        * list, tuple, set, frozenset: a list of converted items (``json``
          encodes tuples as arrays; sets have no JSON form).
        * str, int, float, bool, None: returned unchanged.
        * LangChain ``BaseMessage``: converted via ``_message_to_dict``, then
          recursively (so ``meta`` is made JSON-compatible too).
        * anything else: ``str(obj)``.
        """
        convert = MultiCountryLegalState._make_json_serializable
        primitives = (str, int, float, bool, type(None))
        if isinstance(obj, dict):
            return {
                (key if isinstance(key, primitives) else str(key)): convert(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple, set, frozenset)):
            return [convert(item) for item in obj]
        if isinstance(obj, primitives):
            return obj
        if isinstance(obj, BaseMessage):
            return convert(MultiCountryLegalState._message_to_dict(obj))
        return str(obj)

    @classmethod
    def model_validate(cls, obj: Any, **kwargs: Any) -> "MultiCountryLegalState":
        """Validate ``obj`` (typically a checkpoint dict) into a state instance.

        ``messages`` is typed as a list of dicts, so when ``obj`` is a dict any
        non-dict entries (e.g. LangChain message objects) are first converted
        with ``_message_to_dict``. The caller's dict is not modified: a shallow
        copy receives the converted list. Keyword options such as ``strict`` or
        ``context`` are forwarded to Pydantic's ``model_validate``.
        """
        if isinstance(obj, dict) and obj.get("messages"):
            obj = {**obj, "messages": [cls._message_to_dict(msg) for msg in obj["messages"]]}
        return super().model_validate(obj, **kwargs)

    @staticmethod
    def detect_country(text: str) -> str:
        """Guess which supported country ``text`` refers to from keyword mentions.

        Args:
            text: User input text to analyze.

        Returns:
            ``"benin"`` or ``"madagascar"`` when more distinct keywords of that
            country appear than of the other. Returns ``"unknown"`` for empty
            input, no matches, or ties.

        Keywords are matched case-insensitively at the START of a word. This
        lets inflections such as "malgaches" or "béninoises" still count, while
        words that merely contain a keyword (e.g. "montana", "sultana" -> "tana")
        do not. Input is NFC-normalized first so decomposed accents
        ("e" + combining acute) still match "bénin".
        """
        if not text:
            return "unknown"
        text_lower = unicodedata.normalize("NFC", text).lower()

        benin_keywords = (
            "bénin", "benin", "béninois", "béninoise",
            "cotonou", "porto-novo", "porto novo",
            "dahomey",  # historical name
        )
        madagascar_keywords = (
            "madagascar", "malgache", "malagasy",
            "antananarivo", "tananarive", "tana",
            "toamasina", "tamatave",
        )

        def score(keywords) -> int:
            # Count distinct keywords that begin a word somewhere in the text.
            return sum(
                1 for keyword in keywords
                if re.search(r"\b" + re.escape(keyword), text_lower)
            )

        benin_score = score(benin_keywords)
        madagascar_score = score(madagascar_keywords)
        # Scores are >= 0, so a strict win already implies a positive score.
        if benin_score > madagascar_score:
            return "benin"
        if madagascar_score > benin_score:
            return "madagascar"
        return "unknown"
# Values accepted by RoutingResult.country. Despite the field name, this set
# mixes the two jurisdictions with non-jurisdiction routing categories.
_RoutingTarget = Literal[
    "benin",
    "madagascar",
    "unclear",
    "greeting_small_talk",
    "conversation_repair",
    "assistance_request",
    "conversation_summarization",
    "out_of_scope",
]

# Coarse confidence bucket attached to a routing outcome.
_RoutingConfidence = Literal["high", "medium", "low"]


class RoutingResult(BaseModel):
    # Outcome of routing one user turn. Documented with comments rather than a
    # docstring: a class docstring would become this model's JSON-schema
    # description and alter any schema generated from it.
    country: _RoutingTarget
    confidence: _RoutingConfidence
    method: str  # free-form label for how the decision was made
    explanation: str  # free-form rationale text
class SearchResult(BaseModel):
    # Bundle returned by a legal-document search. (The original last line ended
    # in a stray "|" -- a paste artifact that made the class a SyntaxError.)
    documents: List[Any]  # retrieved documents; element type is not constrained here
    detected_articles: List[str]  # article references found for this query
    applied_filters: Dict[str, Any]  # filters used for the search
    query: str  # the query text that was searched
    country: str  # country the search was scoped to