"""Streaming chat API route: SSE endpoint backed by a LangGraph workflow."""
# Standard library
import asyncio
import json
import logging

# Third-party
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

# Local application
from src.app.schema.chat import ChatRequest, FunctionCall, StreamChunk
from src.config.logging_config import get_logger
from src.domain.transportation_graph import create_transportation_workflow
from src.state.graph_state import TransportationState

# Module-level logger and router; the router is expected to be mounted by the app.
logger = get_logger(__name__)
router = APIRouter()
async def chat_stream(request: ChatRequest):
    """Streaming chat endpoint using LangGraph native streaming with a LangChain model.

    Emits Server-Sent Events (SSE): a "status" chunk on start (and on error),
    "delta" chunks for incremental text, a "final" chunk carrying the full
    response, and a terminating ``data: [DONE]`` frame.

    NOTE(review): ``router = APIRouter()`` exists at module level but no
    ``@router`` decorator is visible on this function — confirm the route
    registration was not lost in the file's extraction.

    Raises:
        HTTPException: 500 if setup fails before the StreamingResponse is
            returned; errors during streaming are reported in-band as SSE
            "status" events instead.
    """
    try:
        async def generate_stream():
            """Yield SSE frames from LangGraph with stream_mode='messages'."""
            # Build and compile the graph workflow per request.
            workflow = create_transportation_workflow()
            graph = workflow.compile()

            # Build the initial graph state from the user's message.
            from src.state.graph_state import create_initial_state
            initial_state = create_initial_state(request.message)

            # Emit an initial "starting" status event.
            chunk = StreamChunk(
                event="status",
                stage="starting",
                message="Đang xử lý câu hỏi..."
            )
            yield f"data: {chunk.model_dump_json()}\n\n"

            try:
                accumulated_text = ""
                final_state = None
                logger.info("Using LangGraph native streaming with stream_mode='messages'")

                # stream_mode="messages" yields (message_chunk, metadata) tuples.
                async for event in graph.astream(initial_state, stream_mode="messages"):
                    logger.debug(f"Graph event: {type(event)}")
                    if isinstance(event, tuple) and len(event) == 2:
                        content_obj, metadata = event
                        # Forward only non-empty content produced by the 'llm' node.
                        if (hasattr(content_obj, 'content') and
                                content_obj.content and
                                metadata.get('langgraph_node') == 'llm'):
                            content = content_obj.content
                            accumulated_text += content
                            chunk_obj = StreamChunk(
                                event="delta",
                                content=content
                            )
                            yield f"data: {chunk_obj.model_dump_json()}\n\n"
                            await asyncio.sleep(0.02)  # lightly pace the stream

                logger.info(f"Graph streaming completed. Accumulated: {len(accumulated_text)} chars")

                if not accumulated_text.strip():
                    # No streaming content was produced; fall back to a full run.
                    logger.info("No streaming content, falling back to invoke")
                    # BUGFIX: use ainvoke — sync invoke() would block the event loop.
                    final_state = await graph.ainvoke(initial_state)

                    if final_state.get('ai_response'):
                        # Simulate a streaming effect by emitting word by word.
                        words = final_state['ai_response'].split()
                        for word in words:
                            accumulated_text += word + " "
                            chunk_obj = StreamChunk(
                                event="delta",
                                content=word + " "
                            )
                            yield f"data: {chunk_obj.model_dump_json()}\n\n"
                            await asyncio.sleep(0.05)
                        accumulated_text = accumulated_text.strip()
                    else:
                        accumulated_text = "Đã xảy ra lỗi khi xử lý yêu cầu."
                else:
                    # Fetch the final state for function calls / cache (best effort).
                    try:
                        # BUGFIX: use ainvoke — sync invoke() would block the event loop.
                        final_state = await graph.ainvoke(initial_state)
                    except Exception as e:
                        logger.warning(f"Graph invoke for final state failed: {e}")

                # Function-call events from final_state['function_calls_made'] are
                # intentionally not emitted yet (kept simple for now).
                final_response = accumulated_text if accumulated_text.strip() else "Không có phản hồi"

            except Exception as e:
                # Report streaming failures in-band; HTTP status is already 200.
                logger.error(f"Graph streaming error: {e}")
                chunk = StreamChunk(
                    event="status",
                    stage="error",
                    message=f"Streaming Error: {str(e)}"
                )
                yield f"data: {chunk.model_dump_json()}\n\n"
                final_response = "Đã xảy ra lỗi trong quá trình xử lý"

            # Close the stream: final payload, then the SSE terminator frame.
            final_chunk = StreamChunk(
                event="final",
                content=final_response
            )
            yield f"data: {final_chunk.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "Cache-Control"
            }
        )
    except Exception as e:
        # Only errors raised before the StreamingResponse is returned reach here.
        raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")