Spaces:
Sleeping
Sleeping
SAAHMATHWORKS
commited on
Commit
·
8f0db18
1
Parent(s):
69f5099
production
Browse files- api/main.py +97 -21
- models/state_models.py +62 -14
api/main.py
CHANGED
|
@@ -4,7 +4,7 @@ import sys
|
|
| 4 |
from pathlib import Path
|
| 5 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 6 |
|
| 7 |
-
from typing import Optional
|
| 8 |
from contextlib import asynccontextmanager
|
| 9 |
from fastapi import FastAPI, Query, HTTPException
|
| 10 |
from fastapi.responses import StreamingResponse, HTMLResponse
|
|
@@ -34,6 +34,66 @@ graph = None
|
|
| 34 |
system_initialized = False
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
async def initialize_system():
|
| 38 |
global chat_manager, graph, system_initialized
|
| 39 |
try:
|
|
@@ -258,48 +318,64 @@ async def generate_legal_chat_responses(message: str, session_id: Optional[str]
|
|
| 258 |
|
| 259 |
if node_name != current_node:
|
| 260 |
current_node = node_name
|
| 261 |
-
yield f"data: {
|
| 262 |
|
| 263 |
if event_type == "on_chat_model_stream":
|
| 264 |
chunk_content = serialize_ai_message_chunk(event["data"]["chunk"])
|
| 265 |
current_content += chunk_content
|
| 266 |
-
yield f"data: {
|
| 267 |
|
| 268 |
elif event_type == "on_chat_model_end":
|
| 269 |
-
yield f"data: {
|
| 270 |
|
| 271 |
elif event_type == "on_chain_start" and "retrieval" in node_name:
|
| 272 |
country = node_name.replace("_retrieval", "")
|
| 273 |
-
yield f"data: {
|
| 274 |
|
| 275 |
elif event_type == "on_chain_end" and "retrieval" in node_name:
|
| 276 |
country = node_name.replace("_retrieval", "")
|
| 277 |
-
yield f"data: {
|
| 278 |
|
| 279 |
elif event_type == "on_tool_end":
|
| 280 |
tool_name = event["name"]
|
| 281 |
-
yield f"data: {
|
| 282 |
|
| 283 |
elif event_type == "on_graph_end":
|
| 284 |
-
# Capture and convert the final state
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
except Exception as e:
|
| 294 |
logger.error(f"Error in generate_legal_chat_responses: {e}", exc_info=True)
|
| 295 |
-
yield f"data: {
|
| 296 |
|
| 297 |
-
# Yield final state if captured
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
-
yield f"data: {
|
| 303 |
|
| 304 |
|
| 305 |
@app.get("/chat")
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 6 |
|
| 7 |
+
from typing import Optional, Any
|
| 8 |
from contextlib import asynccontextmanager
|
| 9 |
from fastapi import FastAPI, Query, HTTPException
|
| 10 |
from fastapi.responses import StreamingResponse, HTMLResponse
|
|
|
|
| 34 |
system_initialized = False
|
| 35 |
|
| 36 |
|
| 37 |
+
# ============================================================================
|
| 38 |
+
# CRITICAL: Safe JSON Serialization Utilities
|
| 39 |
+
# ============================================================================
|
| 40 |
+
class SafeJSONEncoder(json.JSONEncoder):
|
| 41 |
+
"""
|
| 42 |
+
Custom JSON encoder that safely handles Pydantic models and other non-serializable objects.
|
| 43 |
+
This is the ultimate fallback for any serialization issues.
|
| 44 |
+
"""
|
| 45 |
+
def default(self, obj):
|
| 46 |
+
# Handle Pydantic models
|
| 47 |
+
if hasattr(obj, 'model_dump'):
|
| 48 |
+
return obj.model_dump()
|
| 49 |
+
if hasattr(obj, 'dict'):
|
| 50 |
+
return obj.dict()
|
| 51 |
+
|
| 52 |
+
# Handle LangChain messages
|
| 53 |
+
if isinstance(obj, BaseMessage):
|
| 54 |
+
return {
|
| 55 |
+
"role": "assistant" if isinstance(obj, AIMessage) else "user",
|
| 56 |
+
"content": obj.content if hasattr(obj, 'content') else str(obj),
|
| 57 |
+
"meta": getattr(obj, "additional_kwargs", {}),
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
# Handle sets
|
| 61 |
+
if isinstance(obj, set):
|
| 62 |
+
return list(obj)
|
| 63 |
+
|
| 64 |
+
# Handle bytes
|
| 65 |
+
if isinstance(obj, bytes):
|
| 66 |
+
return obj.decode('utf-8', errors='ignore')
|
| 67 |
+
|
| 68 |
+
# Fallback: convert to string
|
| 69 |
+
try:
|
| 70 |
+
return str(obj)
|
| 71 |
+
except Exception:
|
| 72 |
+
return f"<Unserializable: {type(obj).__name__}>"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def safe_json_dumps(obj: Any) -> str:
|
| 76 |
+
"""
|
| 77 |
+
Safely convert any object to JSON string with multiple fallback strategies.
|
| 78 |
+
"""
|
| 79 |
+
try:
|
| 80 |
+
# Try standard JSON encoding first
|
| 81 |
+
return json.dumps(obj)
|
| 82 |
+
except (TypeError, ValueError):
|
| 83 |
+
try:
|
| 84 |
+
# Try with custom encoder
|
| 85 |
+
return json.dumps(obj, cls=SafeJSONEncoder)
|
| 86 |
+
except Exception:
|
| 87 |
+
try:
|
| 88 |
+
# Try with default=str fallback
|
| 89 |
+
return json.dumps(obj, default=str)
|
| 90 |
+
except Exception as e:
|
| 91 |
+
# Ultimate fallback: return error message
|
| 92 |
+
logger.error(f"Complete JSON serialization failure: {e}")
|
| 93 |
+
return json.dumps({"error": "serialization_failed", "message": str(e)})
|
| 94 |
+
# ============================================================================
|
| 95 |
+
|
| 96 |
+
|
| 97 |
async def initialize_system():
|
| 98 |
global chat_manager, graph, system_initialized
|
| 99 |
try:
|
|
|
|
| 318 |
|
| 319 |
if node_name != current_node:
|
| 320 |
current_node = node_name
|
| 321 |
+
yield f"data: {safe_json_dumps({'type': 'node_transition', 'node': node_name})}\n\n"
|
| 322 |
|
| 323 |
if event_type == "on_chat_model_stream":
|
| 324 |
chunk_content = serialize_ai_message_chunk(event["data"]["chunk"])
|
| 325 |
current_content += chunk_content
|
| 326 |
+
yield f"data: {safe_json_dumps({'type': 'content', 'content': chunk_content})}\n\n"
|
| 327 |
|
| 328 |
elif event_type == "on_chat_model_end":
|
| 329 |
+
yield f"data: {safe_json_dumps({'type': 'content_end'})}\n\n"
|
| 330 |
|
| 331 |
elif event_type == "on_chain_start" and "retrieval" in node_name:
|
| 332 |
country = node_name.replace("_retrieval", "")
|
| 333 |
+
yield f"data: {safe_json_dumps({'type': 'search_start', 'country': country})}\n\n"
|
| 334 |
|
| 335 |
elif event_type == "on_chain_end" and "retrieval" in node_name:
|
| 336 |
country = node_name.replace("_retrieval", "")
|
| 337 |
+
yield f"data: {safe_json_dumps({'type': 'search_end', 'country': country})}\n\n"
|
| 338 |
|
| 339 |
elif event_type == "on_tool_end":
|
| 340 |
tool_name = event["name"]
|
| 341 |
+
yield f"data: {safe_json_dumps({'type': 'tool_complete', 'tool': tool_name})}\n\n"
|
| 342 |
|
| 343 |
elif event_type == "on_graph_end":
|
| 344 |
+
# Capture and convert the final state - WITH SAFE SERIALIZATION
|
| 345 |
+
try:
|
| 346 |
+
state = event.get("data", {}).get("output")
|
| 347 |
+
if state:
|
| 348 |
+
if isinstance(state, MultiCountryLegalState):
|
| 349 |
+
final_state = state
|
| 350 |
+
# Use our custom model_dump method for proper serialization
|
| 351 |
+
state_dict = state.model_dump()
|
| 352 |
+
elif isinstance(state, dict):
|
| 353 |
+
state_dict = state
|
| 354 |
+
else:
|
| 355 |
+
# Fallback: convert to string
|
| 356 |
+
state_dict = {"state": str(state)}
|
| 357 |
+
|
| 358 |
+
yield f"data: {safe_json_dumps({'type': 'state', 'content': state_dict})}\n\n"
|
| 359 |
+
except Exception as state_error:
|
| 360 |
+
logger.warning(f"Could not serialize state: {state_error}")
|
| 361 |
+
# Don't fail, just skip state output
|
| 362 |
+
|
| 363 |
+
yield f"data: {safe_json_dumps({'type': 'graph_end'})}\n\n"
|
| 364 |
|
| 365 |
except Exception as e:
|
| 366 |
logger.error(f"Error in generate_legal_chat_responses: {e}", exc_info=True)
|
| 367 |
+
yield f"data: {safe_json_dumps({'type': 'error', 'message': str(e)})}\n\n"
|
| 368 |
|
| 369 |
+
# Yield final state if captured - WITH SAFE SERIALIZATION
|
| 370 |
+
try:
|
| 371 |
+
if final_state and isinstance(final_state, MultiCountryLegalState):
|
| 372 |
+
final_state_dict = final_state.model_dump()
|
| 373 |
+
yield f"data: {safe_json_dumps({'type': 'final_state', 'content': final_state_dict})}\n\n"
|
| 374 |
+
except Exception as final_error:
|
| 375 |
+
logger.warning(f"Could not serialize final state: {final_error}")
|
| 376 |
+
# Don't fail, just skip final state output
|
| 377 |
|
| 378 |
+
yield f"data: {safe_json_dumps({'type': 'end'})}\n\n"
|
| 379 |
|
| 380 |
|
| 381 |
@app.get("/chat")
|
models/state_models.py
CHANGED
|
@@ -4,6 +4,9 @@ from pydantic import BaseModel, Field, ConfigDict
|
|
| 4 |
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
| 5 |
import operator
|
| 6 |
import json
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class MultiCountryLegalState(BaseModel):
|
|
@@ -76,24 +79,65 @@ class MultiCountryLegalState(BaseModel):
|
|
| 76 |
Override model_dump to ensure proper serialization for PostgreSQL checkpointing.
|
| 77 |
This fixes: TypeError: Object of type MultiCountryLegalState is not JSON serializable
|
| 78 |
"""
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
# Ensure all nested objects are JSON-serializable
|
| 82 |
# Messages should already be dicts, but double-check
|
| 83 |
if "messages" in data and data["messages"]:
|
| 84 |
serialized_messages = []
|
| 85 |
for msg in data["messages"]:
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
serialized_messages.append({
|
| 98 |
"role": "unknown",
|
| 99 |
"content": str(msg),
|
|
@@ -104,8 +148,12 @@ class MultiCountryLegalState(BaseModel):
|
|
| 104 |
# Ensure nested dicts are serializable
|
| 105 |
for key in ["legal_context", "pending_assistance_data", "search_metadata"]:
|
| 106 |
if key in data and data[key]:
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
return data
|
| 111 |
|
|
|
|
| 4 |
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
| 5 |
import operator
|
| 6 |
import json
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
|
| 11 |
|
| 12 |
class MultiCountryLegalState(BaseModel):
|
|
|
|
| 79 |
Override model_dump to ensure proper serialization for PostgreSQL checkpointing.
|
| 80 |
This fixes: TypeError: Object of type MultiCountryLegalState is not JSON serializable
|
| 81 |
"""
|
| 82 |
+
try:
|
| 83 |
+
data = super().model_dump(**kwargs)
|
| 84 |
+
except Exception as e:
|
| 85 |
+
logger.warning(f"Standard model_dump failed: {e}, using manual serialization")
|
| 86 |
+
# Fallback to manual serialization
|
| 87 |
+
data = {
|
| 88 |
+
"messages": self.messages if isinstance(self.messages, list) else [],
|
| 89 |
+
"legal_context": self.legal_context if isinstance(self.legal_context, dict) else {},
|
| 90 |
+
"supplemental_message": self.supplemental_message or "",
|
| 91 |
+
"session_id": self.session_id,
|
| 92 |
+
"last_search_query": self.last_search_query,
|
| 93 |
+
"detected_articles": self.detected_articles if isinstance(self.detected_articles, list) else [],
|
| 94 |
+
"router_decision": self.router_decision,
|
| 95 |
+
"search_results": self.search_results,
|
| 96 |
+
"route_explanation": self.route_explanation,
|
| 97 |
+
"country": self.country,
|
| 98 |
+
"assistance_requested": self.assistance_requested,
|
| 99 |
+
"user_email": self.user_email,
|
| 100 |
+
"assistance_description": self.assistance_description,
|
| 101 |
+
"email_status": self.email_status,
|
| 102 |
+
"assistance_step": self.assistance_step,
|
| 103 |
+
"pending_assistance_data": self.pending_assistance_data if isinstance(self.pending_assistance_data, dict) else {},
|
| 104 |
+
"repair_type": self.repair_type,
|
| 105 |
+
"original_query": self.original_query,
|
| 106 |
+
"misunderstanding_count": self.misunderstanding_count,
|
| 107 |
+
"primary_intent": self.primary_intent,
|
| 108 |
+
"approval_status": self.approval_status,
|
| 109 |
+
"approval_reason": self.approval_reason,
|
| 110 |
+
"approved_by": self.approved_by,
|
| 111 |
+
"approval_timestamp": self.approval_timestamp,
|
| 112 |
+
"summary_generated": self.summary_generated,
|
| 113 |
+
"last_summary_timestamp": self.last_summary_timestamp,
|
| 114 |
+
"search_metadata": self.search_metadata if isinstance(self.search_metadata, dict) else {},
|
| 115 |
+
}
|
| 116 |
|
| 117 |
# Ensure all nested objects are JSON-serializable
|
| 118 |
# Messages should already be dicts, but double-check
|
| 119 |
if "messages" in data and data["messages"]:
|
| 120 |
serialized_messages = []
|
| 121 |
for msg in data["messages"]:
|
| 122 |
+
try:
|
| 123 |
+
if isinstance(msg, dict):
|
| 124 |
+
serialized_messages.append(msg)
|
| 125 |
+
elif isinstance(msg, BaseMessage):
|
| 126 |
+
# Convert LangChain message objects to dicts
|
| 127 |
+
serialized_messages.append({
|
| 128 |
+
"role": "assistant" if isinstance(msg, AIMessage) else "user",
|
| 129 |
+
"content": msg.content,
|
| 130 |
+
"meta": getattr(msg, "additional_kwargs", {}),
|
| 131 |
+
})
|
| 132 |
+
else:
|
| 133 |
+
# Fallback for any other type
|
| 134 |
+
serialized_messages.append({
|
| 135 |
+
"role": "unknown",
|
| 136 |
+
"content": str(msg),
|
| 137 |
+
"meta": {}
|
| 138 |
+
})
|
| 139 |
+
except Exception as msg_error:
|
| 140 |
+
logger.warning(f"Error serializing message: {msg_error}")
|
| 141 |
serialized_messages.append({
|
| 142 |
"role": "unknown",
|
| 143 |
"content": str(msg),
|
|
|
|
| 148 |
# Ensure nested dicts are serializable
|
| 149 |
for key in ["legal_context", "pending_assistance_data", "search_metadata"]:
|
| 150 |
if key in data and data[key]:
|
| 151 |
+
try:
|
| 152 |
+
# Convert any non-serializable objects to strings
|
| 153 |
+
data[key] = self._make_json_serializable(data[key])
|
| 154 |
+
except Exception as dict_error:
|
| 155 |
+
logger.warning(f"Error serializing {key}: {dict_error}")
|
| 156 |
+
data[key] = {}
|
| 157 |
|
| 158 |
return data
|
| 159 |
|