File size: 5,860 Bytes
80dbe44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
import json
import asyncio
import logging
from src.domain.transportation_graph import create_transportation_workflow
from src.state.graph_state import TransportationState
from src.app.schema.chat import ChatRequest, StreamChunk, FunctionCall
from src.config.logging_config import get_logger

# Module-level logger, configured through the project's central logging setup.
logger = get_logger(__name__)
# Router mounted by the application; exposes the streaming /chat endpoint below.
router = APIRouter()

@router.post("/chat", response_model=None)
async def chat_stream(request: ChatRequest):
    """Streaming chat endpoint using LangGraph native streaming with a LangChain model.

    Emits a Server-Sent Events (SSE) stream of ``StreamChunk`` JSON payloads:
      - event="status": processing-stage updates ("starting" / "error")
      - event="delta":  incremental pieces of the response text
      - event="final":  the complete accumulated response
    The stream is terminated with a literal ``data: [DONE]`` frame.

    Args:
        request: The incoming chat request; only ``request.message`` is read here.

    Returns:
        StreamingResponse with media type ``text/event-stream``.

    Raises:
        HTTPException: 500 if the streaming response cannot be constructed.
    """
    try:
        async def generate_stream():
            """Yield SSE frames produced from LangGraph's stream_mode='messages'."""
            # Build and compile the graph workflow for this request.
            workflow = create_transportation_workflow()
            graph = workflow.compile()

            # NOTE(review): function-scope import — presumably kept local to
            # avoid an import cycle; confirm before hoisting to module level.
            from src.state.graph_state import create_initial_state
            initial_state = create_initial_state(request.message)

            # Emit the initial "starting" status frame.
            chunk = StreamChunk(
                event="status",
                stage="starting",
                message="Đang xử lý câu hỏi..."
            )
            yield f"data: {chunk.model_dump_json()}\n\n"

            try:
                accumulated_text = ""
                final_state = None

                logger.info("Using LangGraph native streaming with stream_mode='messages'")

                # Stream token-level events from the graph.
                async for event in graph.astream(initial_state, stream_mode="messages"):
                    logger.debug(f"Graph event: {type(event)}")

                    # stream_mode='messages' yields (message_chunk, metadata) tuples.
                    if isinstance(event, tuple) and len(event) == 2:
                        content_obj, metadata = event

                        # Forward only non-empty content produced by the 'llm' node.
                        if (hasattr(content_obj, 'content') and
                                content_obj.content and
                                metadata.get('langgraph_node') == 'llm'):

                            content = content_obj.content
                            accumulated_text += content
                            chunk_obj = StreamChunk(
                                event="delta",
                                content=content
                            )
                            yield f"data: {chunk_obj.model_dump_json()}\n\n"
                            # Small pacing delay so clients render smoothly.
                            await asyncio.sleep(0.02)

                logger.info(f"Graph streaming completed. Accumulated: {len(accumulated_text)} chars")

                if not accumulated_text.strip():
                    # No streamed tokens — fall back to a one-shot invocation.
                    # BUGFIX: use the async `ainvoke` instead of the blocking
                    # synchronous `invoke`, which would stall the event loop
                    # for the entire graph run.
                    logger.info("No streaming content, falling back to invoke")
                    final_state = await graph.ainvoke(initial_state)

                    if final_state.get('ai_response'):
                        # Simulate a streaming effect by emitting one word at a time.
                        words = final_state['ai_response'].split()
                        for word in words:
                            accumulated_text += word + " "
                            chunk_obj = StreamChunk(
                                event="delta",
                                content=word + " "
                            )
                            yield f"data: {chunk_obj.model_dump_json()}\n\n"
                            await asyncio.sleep(0.05)
                        accumulated_text = accumulated_text.strip()
                    else:
                        accumulated_text = "Đã xảy ra lỗi khi xử lý yêu cầu."
                else:
                    # Re-run the graph to obtain the final state (for function
                    # calls / cache). BUGFIX: `ainvoke` instead of blocking
                    # `invoke`. NOTE(review): this executes the whole graph a
                    # second time; consider capturing the final state during
                    # streaming instead.
                    try:
                        final_state = await graph.ainvoke(initial_state)
                    except Exception as e:
                        logger.warning(f"Graph invoke for final state failed: {e}")

                # Function-call events are intentionally skipped for now
                # (kept simple; can be added back later).
                # if final_state and final_state.get('function_calls_made'):
                #     pass

                final_response = accumulated_text if accumulated_text.strip() else "Không có phản hồi"

            except Exception as e:
                # Surface the failure to the client as a status frame, then
                # still emit a final frame so the stream terminates cleanly.
                # `logger.exception` preserves the traceback in the log.
                logger.exception(f"Graph streaming error: {e}")
                chunk = StreamChunk(
                    event="status",
                    stage="error",
                    message=f"Streaming Error: {str(e)}"
                )
                yield f"data: {chunk.model_dump_json()}\n\n"
                final_response = "Đã xảy ra lỗi trong quá trình xử lý"

            # Close the stream: full payload, then the SSE terminator sentinel.
            final_chunk = StreamChunk(
                event="final",
                content=final_response
            )
            yield f"data: {final_chunk.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "Cache-Control"
            }
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")