Spaces:
Sleeping
Sleeping
| # models/state_models.py | |
| from typing import List, Dict, Any, Optional, Annotated, Literal, Union | |
| from pydantic import BaseModel, Field, ConfigDict | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage | |
| import operator | |
| import json | |
| import logging | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class MultiCountryLegalState(BaseModel):
    """
    Graph state for the multi-country (Benin / Madagascar) legal assistant.

    Holds the conversation messages plus routing, search, assistance-email,
    conversation-repair, human-approval and conversation-summary bookkeeping.
    The serialization overrides (``model_dump``, ``model_dump_json``,
    ``model_validate``) keep the state representable as plain JSON; the
    original notes say this is needed for PostgreSQL checkpointing.
    """

    # NOTE(review): the ``operator.add`` metadata on ``messages`` and
    # ``detected_articles`` is presumably read by the graph framework as an
    # append-style reducer for node updates -- confirm against the graph setup.
    messages: Annotated[List[Dict[str, Any]], operator.add] = Field(default_factory=list)
    legal_context: Dict[str, Any] = Field(
        default_factory=lambda: {
            "jurisdiction": "Unknown",
            "user_type": "general",
            "document_type": "legal",
            "detected_country": "unknown",
        }
    )
    # NOTE(review): an earlier comment said this field "handles concurrent
    # updates", but unlike ``messages``/``detected_articles`` it carries no
    # reducer annotation -- confirm parallel writers cannot collide on it.
    supplemental_message: Optional[str] = Field(
        default="",
        description="Supplemental message to display to user (e.g., fallback messages, apologies)",
    )
    session_id: Optional[str] = None
    last_search_query: Optional[str] = None
    detected_articles: Annotated[List[str], operator.add] = Field(default_factory=list)
    router_decision: Optional[str] = None
    search_results: Optional[str] = None
    route_explanation: Optional[str] = None
    country: Optional[str] = Field(default=None)

    # Assistance email fields
    assistance_requested: bool = Field(default=False)
    user_email: Optional[str] = None
    assistance_description: Optional[str] = None
    email_status: Optional[str] = None  # "pending", "sent", "error"
    assistance_step: Optional[str] = Field(default=None)  # "collecting_email", "collecting_description", "confirming_send"
    pending_assistance_data: Dict[str, Any] = Field(default_factory=dict)

    # Conversation repair tracking
    repair_type: Optional[str] = None
    original_query: Optional[str] = None
    misunderstanding_count: int = Field(default=0)

    # Enhanced routing support
    primary_intent: Optional[str] = Field(default=None)

    # Human approval fields
    approval_status: Optional[str] = Field(default=None)  # "pending", "approved", "rejected"
    approval_reason: Optional[str] = Field(default=None)
    approved_by: Optional[str] = Field(default=None)
    approval_timestamp: Optional[str] = Field(default=None)

    # Conversation summary fields
    summary_generated: bool = Field(default=False)
    last_summary_timestamp: Optional[str] = Field(default=None)

    # Search-related data, kept here so complex values stay out of legal_context.
    search_metadata: Dict[str, Any] = Field(default_factory=dict)

    # ------------------------------------------------------------------
    # Serialization support (JSON-safe dumps for checkpointing)
    # ------------------------------------------------------------------
    model_config = ConfigDict(
        # Allow arbitrary (e.g. LangChain) classes as field *types* should any be added.
        arbitrary_types_allowed=True,
        # Re-validate on attribute assignment, not only at construction.
        validate_assignment=True,
        # The former empty ``json_encoders={}`` registered nothing (and that
        # option is deprecated in Pydantic v2); JSON handling lives in the
        # overrides below.
    )

    @staticmethod
    def _message_to_dict(msg: Any) -> Dict[str, Any]:
        """
        Normalize one chat message into the plain-dict format stored in ``messages``.

        Dicts are returned unchanged. LangChain ``BaseMessage`` objects become
        ``{"role", "content", "meta"}``; anything else is stringified under
        role ``"unknown"``.
        """
        if isinstance(msg, dict):
            return msg
        if isinstance(msg, BaseMessage):
            return {
                # NOTE(review): every non-AI message class (system, tool, ...)
                # is labelled "user" here -- confirm that is intended.
                "role": "assistant" if isinstance(msg, AIMessage) else "user",
                "content": msg.content,
                "meta": getattr(msg, "additional_kwargs", {}),
            }
        return {"role": "unknown", "content": str(msg), "meta": {}}

    @classmethod
    def _safe_message_to_dict(cls, msg: Any) -> Dict[str, Any]:
        """Like ``_message_to_dict`` but degrade to a stringified entry instead of raising."""
        try:
            return cls._message_to_dict(msg)
        except Exception as msg_error:
            logger.warning("Error serializing message: %s", msg_error)
            return {"role": "unknown", "content": str(msg), "meta": {}}

    def _fallback_dump(self) -> Dict[str, Any]:
        """
        Build a dict of every declared field by plain attribute access.

        Used only when pydantic's own ``model_dump`` raises. Container fields
        holding a value of the wrong type are replaced by an empty container,
        and a falsy ``supplemental_message`` becomes ``""``.
        """
        # Iterating the declared fields (rather than a hand-written list) keeps
        # the fallback complete when new fields are added.
        data: Dict[str, Any] = {
            name: getattr(self, name, None) for name in type(self).model_fields
        }
        for name, container in (
            ("messages", list),
            ("legal_context", dict),
            ("detected_articles", list),
            ("pending_assistance_data", dict),
            ("search_metadata", dict),
        ):
            if not isinstance(data.get(name), container):
                data[name] = container()
        data["supplemental_message"] = data.get("supplemental_message") or ""
        return data

    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        """
        Dump the state to a dict whose nested values are JSON-serializable.

        Accepts the same keyword arguments as ``BaseModel.model_dump``. If
        pydantic's dump raises, ``_fallback_dump`` is used instead (that path
        ignores ``include``/``exclude`` and similar options). Afterwards every
        entry of ``messages`` is normalized to a message dict, and
        ``legal_context``, ``pending_assistance_data`` and ``search_metadata``
        are converted with ``_make_json_serializable`` (reset to ``{}`` if
        that conversion fails).

        Original motivation: avoid "TypeError: Object of type
        MultiCountryLegalState is not JSON serializable" when checkpointing.
        """
        try:
            data = super().model_dump(**kwargs)
        except Exception as e:
            # Deliberate best-effort: a failed standard dump must not break checkpointing.
            logger.warning("Standard model_dump failed: %s, using manual serialization", e)
            data = self._fallback_dump()

        if data.get("messages"):
            data["messages"] = [self._safe_message_to_dict(msg) for msg in data["messages"]]

        for key in ("legal_context", "pending_assistance_data", "search_metadata"):
            if data.get(key):
                try:
                    data[key] = self._make_json_serializable(data[key])
                except Exception as dict_error:  # e.g. RecursionError on a self-referencing dict
                    logger.warning("Error serializing %s: %s", key, dict_error)
                    data[key] = {}
        return data

    def model_dump_json(self, **kwargs: Any) -> str:
        """
        Serialize the state to a JSON string via ``model_dump`` + ``json.dumps``.

        ``indent`` is a ``model_dump_json``-only option, so it is applied to
        ``json.dumps`` instead of being forwarded to ``model_dump`` (where it
        would raise and force the lossy fallback). All other keyword
        arguments go to ``model_dump``. Values json cannot encode are
        ``str()``-ed.
        """
        indent = kwargs.pop("indent", None)
        data = self.model_dump(**kwargs)
        return json.dumps(data, default=str, indent=indent)

    @staticmethod
    def _make_json_serializable(obj: Any) -> Any:
        """
        Recursively convert ``obj`` into values ``json.dumps`` accepts.

        dicts, lists and tuples are walked (tuples become lists; dict keys
        json cannot encode become strings), JSON scalars pass through,
        LangChain messages become message dicts, and anything else is
        converted with ``str()``.
        """
        to_json = MultiCountryLegalState._make_json_serializable
        if isinstance(obj, dict):
            return {
                (k if k is None or isinstance(k, (str, int, float, bool)) else str(k)): to_json(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [to_json(item) for item in obj]
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj
        if isinstance(obj, BaseMessage):
            # Recurse so non-JSON values inside additional_kwargs are handled too.
            return to_json(MultiCountryLegalState._message_to_dict(obj))
        return str(obj)

    @classmethod
    def model_validate(cls, obj: Any, **kwargs: Any) -> "MultiCountryLegalState":
        """
        Validate ``obj`` into a state, first normalizing its ``messages`` entries.

        Lets checkpoints whose ``messages`` still contain LangChain message
        objects (or other values) load. Each entry is converted with
        ``_message_to_dict``. The caller's dict is not mutated; a shallow copy
        is validated instead. Extra keyword arguments are forwarded to
        ``BaseModel.model_validate``.
        """
        if isinstance(obj, dict) and obj.get("messages"):
            obj = {**obj, "messages": [cls._message_to_dict(msg) for msg in obj["messages"]]}
        return super().model_validate(obj, **kwargs)
| # ============================================================================ | |
def detect_country(text: str) -> str:
    """
    Detect which supported country a text refers to, based on keywords.

    Matching is case- and accent-insensitive ("Bénin", "BENIN", "beninois"
    and a decomposed-Unicode "Bénin" all count). Each keyword is anchored at
    the start of a word, so a keyword buried inside an unrelated word
    (e.g. "tana" in "Montana" or "Satanas") is not a hit. The short keyword
    "tana" must also end at a word boundary. Each keyword adds at most one
    point, however often it occurs.

    Args:
        text: User input text to analyze.

    Returns:
        Country code: "benin" or "madagascar" when that country strictly
        out-scores the other, otherwise "unknown" (no hits, or a tie).
    """
    import re
    import unicodedata

    if not text:
        return "unknown"

    # Casefold, decompose, and drop combining marks. Accented and unaccented
    # spellings, and NFC vs NFD encodings, then compare equal ("Bénin" -> "benin").
    decomposed = unicodedata.normalize("NFKD", text.casefold())
    folded = "".join(ch for ch in decomposed if not unicodedata.combining(ch))

    # Patterns run against the accent-stripped text. The leading \b anchors
    # each keyword at a word start. Prefix matching lets "benin" also cover
    # "beninois", "beninoise" and "beninese".
    benin_patterns = (
        r"\bbenin",
        r"\bcotonou",
        r"\bporto[- ]novo",
        r"\bdahomey",  # Historical name
    )
    madagascar_patterns = (
        r"\bmadagascar",
        r"\bmalgache",
        r"\bmalagasy",
        r"\bantananarivo",
        r"\btananarive",
        r"\btana\b",  # Common short form of Antananarivo; whole word only
        r"\btoamasina",
        r"\btamatave",
    )

    benin_score = sum(1 for pattern in benin_patterns if re.search(pattern, folded))
    madagascar_score = sum(1 for pattern in madagascar_patterns if re.search(pattern, folded))

    # A strict win implies a positive score, so ties (including 0-0) are "unknown".
    if benin_score > madagascar_score:
        return "benin"
    if madagascar_score > benin_score:
        return "madagascar"
    return "unknown"
# Values RoutingResult.country may take: the two supported countries plus the
# other routing outcomes (despite the field's name, not all are countries).
_RoutingDestination = Literal[
    "benin",
    "madagascar",
    "unclear",
    "greeting_small_talk",
    "conversation_repair",
    "assistance_request",
    "conversation_summarization",
    "out_of_scope",
]


class RoutingResult(BaseModel):
    # Intentionally no class docstring or field descriptions: pydantic copies
    # those into the generated JSON schema, which may be consumed elsewhere
    # (presumably as a structured-output schema) -- keep it unchanged.
    country: _RoutingDestination
    confidence: Literal["high", "medium", "low"]
    method: str
    explanation: str
class SearchResult(BaseModel):
    # Presumably the outcome of one document search. No docstring or field
    # descriptions are added, so pydantic's generated JSON schema is unchanged.
    documents: list[Any]  # element type is not constrained by this model
    detected_articles: list[str]
    applied_filters: dict[str, Any]
    query: str
    country: str