SAAHMATHWORKS committed on
Commit
69f5099
·
1 Parent(s): c6bd968

fix from claude ai

Browse files
Files changed (2) hide show
  1. api/main.py +69 -47
  2. models/state_models.py +114 -2
api/main.py CHANGED
@@ -189,11 +189,14 @@ def serialize_ai_message_chunk(chunk):
189
  raise TypeError(
190
  f"Object of type {type(chunk).__name__} is not correctly formatted for serialisation"
191
  )
 
192
  async def generate_legal_chat_responses(message: str, session_id: Optional[str] = None) -> str:
193
  if not session_id:
194
  session_id = f"api_{uuid4()}"
195
 
196
- input_state = {
 
 
197
  "messages": [{"role": "user", "content": message, "meta": {}}],
198
  "legal_context": {
199
  "jurisdiction": "Unknown",
@@ -207,16 +210,39 @@ async def generate_legal_chat_responses(message: str, session_id: Optional[str]
207
  "route_explanation": None,
208
  "last_search_query": None,
209
  "detected_articles": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  }
211
 
 
 
 
212
  config = {
213
  "configurable": {
214
  "thread_id": session_id
215
  }
216
  }
217
 
 
218
  events = graph.astream_events(
219
- MultiCountryLegalState(**input_state),
220
  version="v2",
221
  config=config
222
  )
@@ -225,56 +251,52 @@ async def generate_legal_chat_responses(message: str, session_id: Optional[str]
225
  current_node = ""
226
  final_state = None
227
 
228
- async for event in events:
229
- event_type = event["event"]
230
- node_name = event.get("name", "")
231
-
232
- if node_name != current_node:
233
- current_node = node_name
234
- yield f"data: {json.dumps({'type': 'node_transition', 'node': node_name})}\n\n"
235
-
236
- if event_type == "on_chat_model_stream":
237
- chunk_content = serialize_ai_message_chunk(event["data"]["chunk"])
238
- current_content += chunk_content
239
- yield f"data: {json.dumps({'type': 'content', 'content': chunk_content})}\n\n"
240
-
241
- elif event_type == "on_chat_model_end":
242
- yield f"data: {json.dumps({'type': 'content_end'})}\n\n"
243
-
244
- elif event_type == "on_chain_start" and "retrieval" in node_name:
245
- country = node_name.replace("_retrieval", "")
246
- yield f"data: {json.dumps({'type': 'search_start', 'country': country})}\n\n"
247
-
248
- elif event_type == "on_chain_end" and "retrieval" in node_name:
249
- country = node_name.replace("_retrieval", "")
250
- yield f"data: {json.dumps({'type': 'search_end', 'country': country})}\n\n"
251
 
252
- elif event_type == "on_tool_end":
253
- tool_name = event["name"]
254
- yield f"data: {json.dumps({'type': 'tool_complete', 'tool': tool_name})}\n\n"
255
 
256
- elif event_type == "on_graph_end":
257
- # Capture and convert the final state
258
- state = event.get("data", {}).get("state")
259
- if state and isinstance(state, MultiCountryLegalState):
260
- final_state = state
261
- state_dict = state.model_dump() if hasattr(state, "model_dump") else state.dict()
262
- if "messages" in state_dict and isinstance(state_dict["messages"], list):
263
- state_dict["messages"] = [
264
- msg if isinstance(msg, dict) else {"role": "unknown", "content": str(msg), "meta": {}}
265
- for msg in state_dict["messages"]
266
- ]
267
- yield f"data: {json.dumps({'type': 'state', 'content': state_dict})}\n\n"
268
- yield f"data: {json.dumps({'type': 'graph_end'})}\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
  # Yield final state if captured
271
  if final_state and isinstance(final_state, MultiCountryLegalState):
272
- final_state_dict = final_state.model_dump() if hasattr(final_state, "model_dump") else final_state.dict()
273
- if "messages" in final_state_dict and isinstance(final_state_dict["messages"], list):
274
- final_state_dict["messages"] = [
275
- msg if isinstance(msg, dict) else {"role": "unknown", "content": str(msg), "meta": {}}
276
- for msg in final_state_dict["messages"]
277
- ]
278
  yield f"data: {json.dumps({'type': 'final_state', 'content': final_state_dict})}\n\n"
279
 
280
  yield f"data: {json.dumps({'type': 'end'})}\n\n"
 
189
  raise TypeError(
190
  f"Object of type {type(chunk).__name__} is not correctly formatted for serialisation"
191
  )
192
+
193
  async def generate_legal_chat_responses(message: str, session_id: Optional[str] = None) -> str:
194
  if not session_id:
195
  session_id = f"api_{uuid4()}"
196
 
197
+ # CRITICAL FIX: Create input state as a dictionary first, then convert to Pydantic model
198
+ # This ensures proper serialization for PostgreSQL checkpointing
199
+ input_state_dict = {
200
  "messages": [{"role": "user", "content": message, "meta": {}}],
201
  "legal_context": {
202
  "jurisdiction": "Unknown",
 
210
  "route_explanation": None,
211
  "last_search_query": None,
212
  "detected_articles": [],
213
+ "supplemental_message": "",
214
+ "country": None,
215
+ "assistance_requested": False,
216
+ "user_email": None,
217
+ "assistance_description": None,
218
+ "email_status": None,
219
+ "assistance_step": None,
220
+ "pending_assistance_data": {},
221
+ "repair_type": None,
222
+ "original_query": None,
223
+ "misunderstanding_count": 0,
224
+ "primary_intent": None,
225
+ "approval_status": None,
226
+ "approval_reason": None,
227
+ "approved_by": None,
228
+ "approval_timestamp": None,
229
+ "summary_generated": False,
230
+ "last_summary_timestamp": None,
231
+ "search_metadata": {}
232
  }
233
 
234
+ # Convert to Pydantic model (this will use our custom model_dump for serialization)
235
+ input_state = MultiCountryLegalState(**input_state_dict)
236
+
237
  config = {
238
  "configurable": {
239
  "thread_id": session_id
240
  }
241
  }
242
 
243
+ # Stream events from the graph
244
  events = graph.astream_events(
245
+ input_state, # Pass the Pydantic model directly
246
  version="v2",
247
  config=config
248
  )
 
251
  current_node = ""
252
  final_state = None
253
 
254
+ try:
255
+ async for event in events:
256
+ event_type = event["event"]
257
+ node_name = event.get("name", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
+ if node_name != current_node:
260
+ current_node = node_name
261
+ yield f"data: {json.dumps({'type': 'node_transition', 'node': node_name})}\n\n"
262
 
263
+ if event_type == "on_chat_model_stream":
264
+ chunk_content = serialize_ai_message_chunk(event["data"]["chunk"])
265
+ current_content += chunk_content
266
+ yield f"data: {json.dumps({'type': 'content', 'content': chunk_content})}\n\n"
267
+
268
+ elif event_type == "on_chat_model_end":
269
+ yield f"data: {json.dumps({'type': 'content_end'})}\n\n"
270
+
271
+ elif event_type == "on_chain_start" and "retrieval" in node_name:
272
+ country = node_name.replace("_retrieval", "")
273
+ yield f"data: {json.dumps({'type': 'search_start', 'country': country})}\n\n"
274
+
275
+ elif event_type == "on_chain_end" and "retrieval" in node_name:
276
+ country = node_name.replace("_retrieval", "")
277
+ yield f"data: {json.dumps({'type': 'search_end', 'country': country})}\n\n"
278
+
279
+ elif event_type == "on_tool_end":
280
+ tool_name = event["name"]
281
+ yield f"data: {json.dumps({'type': 'tool_complete', 'tool': tool_name})}\n\n"
282
+
283
+ elif event_type == "on_graph_end":
284
+ # Capture and convert the final state
285
+ state = event.get("data", {}).get("output")
286
+ if state and isinstance(state, MultiCountryLegalState):
287
+ final_state = state
288
+ # Use our custom model_dump method for proper serialization
289
+ state_dict = state.model_dump()
290
+ yield f"data: {json.dumps({'type': 'state', 'content': state_dict})}\n\n"
291
+ yield f"data: {json.dumps({'type': 'graph_end'})}\n\n"
292
+
293
+ except Exception as e:
294
+ logger.error(f"Error in generate_legal_chat_responses: {e}", exc_info=True)
295
+ yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
296
 
297
  # Yield final state if captured
298
  if final_state and isinstance(final_state, MultiCountryLegalState):
299
+ final_state_dict = final_state.model_dump()
 
 
 
 
 
300
  yield f"data: {json.dumps({'type': 'final_state', 'content': final_state_dict})}\n\n"
301
 
302
  yield f"data: {json.dumps({'type': 'end'})}\n\n"
models/state_models.py CHANGED
@@ -1,7 +1,10 @@
1
- # [file name]: models/state_models.py
2
  from typing import List, Dict, Any, Optional, Annotated, Literal, Union
3
- from pydantic import BaseModel, Field
 
4
  import operator
 
 
5
 
6
  class MultiCountryLegalState(BaseModel):
7
  messages: Annotated[List[Dict[str, Any]], operator.add] = Field(default_factory=list)
@@ -55,6 +58,114 @@ class MultiCountryLegalState(BaseModel):
55
  # NEW: Search-related fields to prevent storing complex data in legal_context
56
  search_metadata: Dict[str, Any] = Field(default_factory=dict)
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  @staticmethod
59
  def detect_country(text: str) -> str:
60
  """
@@ -104,6 +215,7 @@ class RoutingResult(BaseModel):
104
  method: str
105
  explanation: str
106
 
 
107
  class SearchResult(BaseModel):
108
  documents: List[Any]
109
  detected_articles: List[str]
 
1
+ # models/state_models.py
2
  from typing import List, Dict, Any, Optional, Annotated, Literal, Union
3
+ from pydantic import BaseModel, Field, ConfigDict
4
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
5
  import operator
6
+ import json
7
+
8
 
9
  class MultiCountryLegalState(BaseModel):
10
  messages: Annotated[List[Dict[str, Any]], operator.add] = Field(default_factory=list)
 
58
  # NEW: Search-related fields to prevent storing complex data in legal_context
59
  search_metadata: Dict[str, Any] = Field(default_factory=dict)
60
 
61
+ # ============================================================================
62
+ # CRITICAL FIX FOR JSON SERIALIZATION (Pydantic v2 Configuration)
63
+ # This fixes: TypeError: Object of type MultiCountryLegalState is not JSON serializable
64
+ # ============================================================================
65
+ model_config = ConfigDict(
66
+ arbitrary_types_allowed=True, # Allow LangChain message types if used
67
+ validate_assignment=True,
68
+ # CRITICAL: Tell Pydantic how to serialize this model to JSON
69
+ json_encoders={
70
+ # Any custom types can be added here
71
+ }
72
+ )
73
+
74
+ def model_dump(self, **kwargs) -> Dict[str, Any]:
75
+ """
76
+ Override model_dump to ensure proper serialization for PostgreSQL checkpointing.
77
+ This fixes: TypeError: Object of type MultiCountryLegalState is not JSON serializable
78
+ """
79
+ data = super().model_dump(**kwargs)
80
+
81
+ # Ensure all nested objects are JSON-serializable
82
+ # Messages should already be dicts, but double-check
83
+ if "messages" in data and data["messages"]:
84
+ serialized_messages = []
85
+ for msg in data["messages"]:
86
+ if isinstance(msg, dict):
87
+ serialized_messages.append(msg)
88
+ elif isinstance(msg, BaseMessage):
89
+ # Convert LangChain message objects to dicts
90
+ serialized_messages.append({
91
+ "role": "assistant" if isinstance(msg, AIMessage) else "user",
92
+ "content": msg.content,
93
+ "meta": getattr(msg, "additional_kwargs", {}),
94
+ })
95
+ else:
96
+ # Fallback for any other type
97
+ serialized_messages.append({
98
+ "role": "unknown",
99
+ "content": str(msg),
100
+ "meta": {}
101
+ })
102
+ data["messages"] = serialized_messages
103
+
104
+ # Ensure nested dicts are serializable
105
+ for key in ["legal_context", "pending_assistance_data", "search_metadata"]:
106
+ if key in data and data[key]:
107
+ # Convert any non-serializable objects to strings
108
+ data[key] = self._make_json_serializable(data[key])
109
+
110
+ return data
111
+
112
+ def model_dump_json(self, **kwargs) -> str:
113
+ """
114
+ Override model_dump_json for explicit JSON string conversion.
115
+ """
116
+ data = self.model_dump(**kwargs)
117
+ return json.dumps(data, default=str)
118
+
119
+ @staticmethod
120
+ def _make_json_serializable(obj: Any) -> Any:
121
+ """
122
+ Recursively convert objects to JSON-serializable format.
123
+ """
124
+ if isinstance(obj, dict):
125
+ return {k: MultiCountryLegalState._make_json_serializable(v) for k, v in obj.items()}
126
+ elif isinstance(obj, list):
127
+ return [MultiCountryLegalState._make_json_serializable(item) for item in obj]
128
+ elif isinstance(obj, (str, int, float, bool, type(None))):
129
+ return obj
130
+ elif isinstance(obj, BaseMessage):
131
+ return {
132
+ "role": "assistant" if isinstance(obj, AIMessage) else "user",
133
+ "content": obj.content,
134
+ "meta": getattr(obj, "additional_kwargs", {}),
135
+ }
136
+ else:
137
+ # Convert any other type to string
138
+ return str(obj)
139
+
140
+ @classmethod
141
+ def model_validate(cls, obj: Any) -> "MultiCountryLegalState":
142
+ """
143
+ Override model_validate to properly handle deserialization from checkpoints.
144
+ """
145
+ if isinstance(obj, dict):
146
+ # Messages should already be dicts, but handle BaseMessage objects if present
147
+ if "messages" in obj and obj["messages"]:
148
+ reconstructed_messages = []
149
+ for msg in obj["messages"]:
150
+ if isinstance(msg, dict):
151
+ reconstructed_messages.append(msg)
152
+ elif isinstance(msg, BaseMessage):
153
+ reconstructed_messages.append({
154
+ "role": "assistant" if isinstance(msg, AIMessage) else "user",
155
+ "content": msg.content,
156
+ "meta": getattr(msg, "additional_kwargs", {}),
157
+ })
158
+ else:
159
+ reconstructed_messages.append({
160
+ "role": "unknown",
161
+ "content": str(msg),
162
+ "meta": {}
163
+ })
164
+ obj["messages"] = reconstructed_messages
165
+
166
+ return super().model_validate(obj)
167
+ # ============================================================================
168
+
169
  @staticmethod
170
  def detect_country(text: str) -> str:
171
  """
 
215
  method: str
216
  explanation: str
217
 
218
+
219
  class SearchResult(BaseModel):
220
  documents: List[Any]
221
  detected_articles: List[str]