SAAHMATHWORKS commited on
Commit
8755af8
·
1 Parent(s): a9e0a4c

Replacing graph.aget(config) with state capture from on_graph_end

Browse files
Files changed (1) hide show
  1. api/main.py +82 -73
api/main.py CHANGED
@@ -10,6 +10,7 @@ from fastapi import FastAPI, Query, HTTPException
10
  from fastapi.responses import StreamingResponse, HTMLResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from langchain_core.messages import AIMessageChunk
 
13
  import json
14
  from uuid import uuid4
15
  import logging
@@ -20,6 +21,9 @@ import asyncio
20
  from core.system_initializer import setup_system
21
  from models.state_models import MultiCountryLegalState
22
 
 
 
 
23
  # Setup logging
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
@@ -185,92 +189,97 @@ def serialize_ai_message_chunk(chunk):
185
  raise TypeError(
186
  f"Object of type {type(chunk).__name__} is not correctly formatted for serialisation"
187
  )
188
-
189
- async def generate_legal_chat_responses(message: str, session_id: Optional[str] = None):
190
- """Generate streaming responses for legal chat"""
191
- if not system_initialized:
192
- yield f"data: {json.dumps({'type': 'error', 'message': 'System is still starting up. Please try again in a moment.'})}\n\n"
193
- yield f"data: {json.dumps({'type': 'end'})}\n\n"
194
- return
195
-
196
- is_new_conversation = session_id is None
197
-
198
- if is_new_conversation:
199
  session_id = f"api_{uuid4()}"
200
- logger.info(f"🆕 New conversation session: {session_id}")
201
- yield f"data: {json.dumps({'type': 'session', 'session_id': session_id})}\n\n"
202
- else:
203
- logger.info(f"🔄 Continuing session: {session_id}")
204
 
205
- try:
206
- input_state = {
207
- "messages": [{"role": "user", "content": message, "meta": {}}],
208
- "legal_context": {
209
- "jurisdiction": "Unknown",
210
- "user_type": "general",
211
- "document_type": "legal",
212
- "detected_country": "unknown"
213
- },
214
- "session_id": session_id,
215
- "router_decision": None,
216
- "search_results": None,
217
- "route_explanation": None,
218
- "last_search_query": None,
219
- "detected_articles": [],
220
- }
221
 
222
- config = {
223
- "configurable": {
224
- "thread_id": session_id
225
- }
226
  }
 
227
 
228
- events = graph.astream_events(
229
- MultiCountryLegalState(**input_state),
230
- version="v2",
231
- config=config
232
- )
233
 
234
- current_content = ""
235
- current_node = ""
 
236
 
237
- async for event in events:
238
- event_type = event["event"]
239
- node_name = event.get("name", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
- if node_name != current_node:
242
- current_node = node_name
243
- yield f"data: {json.dumps({'type': 'node_transition', 'node': node_name})}\n\n"
 
 
 
 
244
 
245
- if event_type == "on_chat_model_stream":
246
- chunk_content = serialize_ai_message_chunk(event["data"]["chunk"])
247
- current_content += chunk_content
248
- yield f"data: {json.dumps({'type': 'content', 'content': chunk_content})}\n\n"
249
-
250
- elif event_type == "on_chat_model_end":
251
- yield f"data: {json.dumps({'type': 'content_end'})}\n\n"
252
-
253
- elif event_type == "on_chain_start" and "retrieval" in node_name:
254
- country = node_name.replace("_retrieval", "")
255
- yield f"data: {json.dumps({'type': 'search_start', 'country': country})}\n\n"
256
-
257
- elif event_type == "on_chain_end" and "retrieval" in node_name:
258
- country = node_name.replace("_retrieval", "")
259
- yield f"data: {json.dumps({'type': 'search_end', 'country': country})}\n\n"
260
-
261
- elif event_type == "on_tool_end":
262
- tool_name = event["name"]
263
- yield f"data: {json.dumps({'type': 'tool_complete', 'tool': tool_name})}\n\n"
264
 
265
- elif event_type == "on_graph_end":
266
- yield f"data: {json.dumps({'type': 'graph_end'})}\n\n"
 
 
 
 
 
 
 
267
 
268
- except Exception as e:
269
- logger.error(f"Error in streaming: {e}")
270
- yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
271
-
272
  yield f"data: {json.dumps({'type': 'end'})}\n\n"
273
 
 
274
  @app.get("/chat")
275
  async def chat_stream(
276
  message: str = Query(..., description="User message"),
 
10
  from fastapi.responses import StreamingResponse, HTMLResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from langchain_core.messages import AIMessageChunk
13
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
14
  import json
15
  from uuid import uuid4
16
  import logging
 
21
  from core.system_initializer import setup_system
22
  from models.state_models import MultiCountryLegalState
23
 
24
+ # utils functions
25
+ from utils.helpers import message_obj_to_dict, dict_to_message_obj
26
+
27
  # Setup logging
28
  logging.basicConfig(level=logging.INFO)
29
  logger = logging.getLogger(__name__)
 
189
  raise TypeError(
190
  f"Object of type {type(chunk).__name__} is not correctly formatted for serialisation"
191
  )
192
+ async def generate_legal_chat_responses(message: str, session_id: Optional[str] = None) -> str:
193
+ if not session_id:
 
 
 
 
 
 
 
 
 
194
  session_id = f"api_{uuid4()}"
 
 
 
 
195
 
196
+ input_state = {
197
+ "messages": [{"role": "user", "content": message, "meta": {}}],
198
+ "legal_context": {
199
+ "jurisdiction": "Unknown",
200
+ "user_type": "general",
201
+ "document_type": "legal",
202
+ "detected_country": "unknown"
203
+ },
204
+ "session_id": session_id,
205
+ "router_decision": None,
206
+ "search_results": None,
207
+ "route_explanation": None,
208
+ "last_search_query": None,
209
+ "detected_articles": [],
210
+ }
 
211
 
212
+ config = {
213
+ "configurable": {
214
+ "thread_id": session_id
 
215
  }
216
+ }
217
 
218
+ events = graph.astream_events(
219
+ MultiCountryLegalState(**input_state),
220
+ version="v2",
221
+ config=config
222
+ )
223
 
224
+ current_content = ""
225
+ current_node = ""
226
+ final_state = None
227
 
228
+ async for event in events:
229
+ event_type = event["event"]
230
+ node_name = event.get("name", "")
231
+
232
+ if node_name != current_node:
233
+ current_node = node_name
234
+ yield f"data: {json.dumps({'type': 'node_transition', 'node': node_name})}\n\n"
235
+
236
+ if event_type == "on_chat_model_stream":
237
+ chunk_content = serialize_ai_message_chunk(event["data"]["chunk"])
238
+ current_content += chunk_content
239
+ yield f"data: {json.dumps({'type': 'content', 'content': chunk_content})}\n\n"
240
+
241
+ elif event_type == "on_chat_model_end":
242
+ yield f"data: {json.dumps({'type': 'content_end'})}\n\n"
243
+
244
+ elif event_type == "on_chain_start" and "retrieval" in node_name:
245
+ country = node_name.replace("_retrieval", "")
246
+ yield f"data: {json.dumps({'type': 'search_start', 'country': country})}\n\n"
247
 
248
+ elif event_type == "on_chain_end" and "retrieval" in node_name:
249
+ country = node_name.replace("_retrieval", "")
250
+ yield f"data: {json.dumps({'type': 'search_end', 'country': country})}\n\n"
251
+
252
+ elif event_type == "on_tool_end":
253
+ tool_name = event["name"]
254
+ yield f"data: {json.dumps({'type': 'tool_complete', 'tool': tool_name})}\n\n"
255
 
256
+ elif event_type == "on_graph_end":
257
+ # Capture and convert the final state
258
+ state = event.get("data", {}).get("state")
259
+ if state and isinstance(state, MultiCountryLegalState):
260
+ final_state = state
261
+ state_dict = state.model_dump() if hasattr(state, "model_dump") else state.dict()
262
+ if "messages" in state_dict and isinstance(state_dict["messages"], list):
263
+ state_dict["messages"] = [
264
+ msg if isinstance(msg, dict) else {"role": "unknown", "content": str(msg), "meta": {}}
265
+ for msg in state_dict["messages"]
266
+ ]
267
+ yield f"data: {json.dumps({'type': 'state', 'content': state_dict})}\n\n"
268
+ yield f"data: {json.dumps({'type': 'graph_end'})}\n\n"
 
 
 
 
 
 
269
 
270
+ # Yield final state if captured
271
+ if final_state and isinstance(final_state, MultiCountryLegalState):
272
+ final_state_dict = final_state.model_dump() if hasattr(final_state, "model_dump") else final_state.dict()
273
+ if "messages" in final_state_dict and isinstance(final_state_dict["messages"], list):
274
+ final_state_dict["messages"] = [
275
+ msg if isinstance(msg, dict) else {"role": "unknown", "content": str(msg), "meta": {}}
276
+ for msg in final_state_dict["messages"]
277
+ ]
278
+ yield f"data: {json.dumps({'type': 'final_state', 'content': final_state_dict})}\n\n"
279
 
 
 
 
 
280
  yield f"data: {json.dumps({'type': 'end'})}\n\n"
281
 
282
+
283
  @app.get("/chat")
284
  async def chat_stream(
285
  message: str = Query(..., description="User message"),