SAAHMATHWORKS commited on
Commit
c6bd968
·
1 Parent(s): cff0952

Revert "Update generate_legal_chat_responses for Pydantic v2 compatibility on Hugging Face"

Browse files

This reverts commit e1e3b4d7993fd3a80c40a59045f280358afa691c.

Files changed (1) hide show
  1. api/main.py +17 -30
api/main.py CHANGED
@@ -256,43 +256,30 @@ async def generate_legal_chat_responses(message: str, session_id: Optional[str]
256
  elif event_type == "on_graph_end":
257
  # Capture and convert the final state
258
  state = event.get("data", {}).get("state")
259
- if state is not None:
260
- try:
261
- if isinstance(state, MultiCountryLegalState):
262
- final_state = state
263
- # Use model_dump_json for Pydantic v2 compatibility
264
- state_dict = json.loads(state.model_dump_json(exclude_unset=True)) if hasattr(state, "model_dump_json") else state.dict(exclude_unset=True)
265
- if "messages" in state_dict and isinstance(state_dict["messages"], list):
266
- state_dict["messages"] = [
267
- msg if isinstance(msg, dict) else {"role": "unknown", "content": str(msg), "meta": {}}
268
- for msg in state_dict["messages"]
269
- ]
270
- yield f"data: {json.dumps({'type': 'state', 'content': state_dict})}\n\n"
271
- else:
272
- logger.warning(f"Unexpected state type: {type(state)}")
273
- except Exception as e:
274
- logger.error(f"Error converting state to dict: {str(e)}")
275
- state_dict = {"error": "Failed to serialize state"}
276
- yield f"data: {json.dumps({'type': 'state', 'content': state_dict})}\n\n"
277
  yield f"data: {json.dumps({'type': 'graph_end'})}\n\n"
278
 
279
  # Yield final state if captured
280
  if final_state and isinstance(final_state, MultiCountryLegalState):
281
- try:
282
- # Use model_dump_json for Pydantic v2 compatibility
283
- final_state_dict = json.loads(final_state.model_dump_json(exclude_unset=True)) if hasattr(final_state, "model_dump_json") else final_state.dict(exclude_unset=True)
284
- if "messages" in final_state_dict and isinstance(final_state_dict["messages"], list):
285
- final_state_dict["messages"] = [
286
- msg if isinstance(msg, dict) else {"role": "unknown", "content": str(msg), "meta": {}}
287
- for msg in final_state_dict["messages"]
288
- ]
289
- yield f"data: {json.dumps({'type': 'final_state', 'content': final_state_dict})}\n\n"
290
- except Exception as e:
291
- logger.error(f"Error converting final state to dict: {str(e)}")
292
- yield f"data: {json.dumps({'type': 'error', 'message': 'An internal error occurred'})}\n\n"
293
 
294
  yield f"data: {json.dumps({'type': 'end'})}\n\n"
295
 
 
296
  @app.get("/chat")
297
  async def chat_stream(
298
  message: str = Query(..., description="User message"),
 
256
  elif event_type == "on_graph_end":
257
  # Capture and convert the final state
258
  state = event.get("data", {}).get("state")
259
+ if state and isinstance(state, MultiCountryLegalState):
260
+ final_state = state
261
+ state_dict = state.model_dump() if hasattr(state, "model_dump") else state.dict()
262
+ if "messages" in state_dict and isinstance(state_dict["messages"], list):
263
+ state_dict["messages"] = [
264
+ msg if isinstance(msg, dict) else {"role": "unknown", "content": str(msg), "meta": {}}
265
+ for msg in state_dict["messages"]
266
+ ]
267
+ yield f"data: {json.dumps({'type': 'state', 'content': state_dict})}\n\n"
 
 
 
 
 
 
 
 
 
268
  yield f"data: {json.dumps({'type': 'graph_end'})}\n\n"
269
 
270
  # Yield final state if captured
271
  if final_state and isinstance(final_state, MultiCountryLegalState):
272
+ final_state_dict = final_state.model_dump() if hasattr(final_state, "model_dump") else final_state.dict()
273
+ if "messages" in final_state_dict and isinstance(final_state_dict["messages"], list):
274
+ final_state_dict["messages"] = [
275
+ msg if isinstance(msg, dict) else {"role": "unknown", "content": str(msg), "meta": {}}
276
+ for msg in final_state_dict["messages"]
277
+ ]
278
+ yield f"data: {json.dumps({'type': 'final_state', 'content': final_state_dict})}\n\n"
 
 
 
 
 
279
 
280
  yield f"data: {json.dumps({'type': 'end'})}\n\n"
281
 
282
+
283
  @app.get("/chat")
284
  async def chat_stream(
285
  message: str = Query(..., description="User message"),