File size: 4,035 Bytes
80dbe44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Input Node - Node đầu tiên để nhận tin nhắn từ người dùng với conversation cache
"""
import json
import logging
from pathlib import Path
from src.state.graph_state import TransportationState
from src.config.logging_config import get_logger

# Module-level logger obtained from the project's logging configuration helper,
# named after this module.
logger = get_logger(__name__)

class InputNode:
    """Entry node: reads the user's message from the state and adds history context.

    The node loads recent messages from a JSON cache file, stores them on the
    state, and prepends a short summary of the latest user/assistant exchanges
    to the user's message.

    Instance fields:
        name: Identifier of this node (``"input_node"``).
        cache_file: Path to ``conversation_cache.json`` located at
            ``parents[3]`` of this file's resolved path.
    """

    # Number of most recent cached messages returned by load_conversation_history.
    _HISTORY_LIMIT = 10
    # Number of most recent user/assistant pairs included in the context.
    _CONTEXT_PAIRS = 3
    # Maximum number of characters of each message shown in the context.
    _PREVIEW_CHARS = 50

    def __init__(self):
        self.name = "input_node"
        self.cache_file = Path(__file__).resolve().parents[3] / "conversation_cache.json"

    def load_conversation_history(self) -> list:
        """Load the most recent messages from the cache file.

        Returns:
            The last ``_HISTORY_LIMIT`` entries of the JSON array stored in
            ``self.cache_file``. An empty list is returned when the file does
            not exist, cannot be read or parsed, or does not contain a JSON
            array at the top level.
        """
        try:
            with open(self.cache_file, 'r', encoding='utf-8') as f:
                conversations = json.load(f)
        except FileNotFoundError:
            # No cache yet: not an error, there is simply no history.
            return []
        except Exception as e:
            logger.error(f"Error loading conversation history: {e}")
            return []

        # Slicing a str/dict would either return garbage or raise; only a
        # JSON array is a valid history.
        if not isinstance(conversations, list):
            logger.error(
                "Error loading conversation history: "
                f"expected a JSON list, got {type(conversations).__name__}"
            )
            return []

        return conversations[-self._HISTORY_LIMIT:]

    @staticmethod
    def _has_role(message, role: str) -> bool:
        """Return True when *message* is a dict whose 'role' equals *role*."""
        return isinstance(message, dict) and message.get('role') == role

    @staticmethod
    def _content_of(message: dict) -> str:
        """Return the message's 'content' as a string ('' when missing or None)."""
        content = message.get('content')
        return "" if content is None else str(content)

    def _preview(self, text: str) -> str:
        """Truncate *text* to ``_PREVIEW_CHARS`` characters, appending '...' if cut."""
        if len(text) > self._PREVIEW_CHARS:
            return text[:self._PREVIEW_CHARS] + "..."
        return text

    def get_context_from_history(self, conversations: list) -> str:
        """Build a context string from role/content formatted history.

        A pair is a message with role ``'user'`` immediately followed by a
        message with role ``'assistant'``. The list is scanned message by
        message, so a history that starts with an assistant message or
        contains an unanswered user message does not misalign later pairs.
        Entries that are not dicts are ignored.

        Args:
            conversations: List of message dicts with 'role' and 'content' keys.

        Returns:
            A header, one numbered line per pair (at most the last
            ``_CONTEXT_PAIRS`` pairs, each side truncated to
            ``_PREVIEW_CHARS`` characters) and a trailing
            "Current conversation:" marker; ``""`` when there is no history
            or no complete user/assistant pair.
        """
        if not conversations:
            return ""

        pairs = []
        idx = 0
        last_start = len(conversations) - 1
        while idx < last_start:
            first, second = conversations[idx], conversations[idx + 1]
            if self._has_role(first, 'user') and self._has_role(second, 'assistant'):
                pairs.append((self._content_of(first), self._content_of(second)))
                idx += 2  # both messages consumed
            else:
                idx += 1  # re-sync on the next message

        if not pairs:
            # Avoid emitting a header with nothing under it.
            return ""

        entries = [
            f"{number}. User: {self._preview(user_content)} → AI: {self._preview(assistant_content)}\n"
            for number, (user_content, assistant_content)
            in enumerate(pairs[-self._CONTEXT_PAIRS:], 1)
        ]
        return (
            "Previous conversations for context:\n"
            + "".join(entries)
            + "\nCurrent conversation:\n"
        )

    def process_input(self, state: TransportationState) -> TransportationState:
        """Read the user's message from the state and attach history context.

        On success the state is updated as follows:
            * ``conversation_cache``: the loaded history list,
            * ``user_message``: the stripped message, prefixed with the
              history context when one could be built,
            * ``current_step``: ``"llm_processing"``,
            * ``error_message``: ``None``.
        If an exception occurs while doing so, ``error_message`` is set to a
        description of the failure instead.

        Args:
            state: Current state containing ``user_message``.

        Returns:
            The same state object, updated in place.
        """
        user_message = state["user_message"]
        logger.info(f"Received user message: {user_message[:100]}...")

        try:
            # Load conversation history
            conversation_history = self.load_conversation_history()
            state["conversation_cache"] = conversation_history

            # Normalise the message once and reuse it below, so the stripped
            # text is not overwritten by the raw one when context is added.
            cleaned_message = user_message.strip()
            state["user_message"] = cleaned_message
            state["current_step"] = "llm_processing"
            state["error_message"] = None

            # Prepend context for the LLM only when there is something to show.
            context = self.get_context_from_history(conversation_history)
            if context:
                state["user_message"] = context + cleaned_message
                logger.info(f"Added context from {len(conversation_history)} previous conversations")

            logger.info("User message with context saved to state successfully")

        except Exception as e:
            error_msg = f"Error processing user input: {str(e)}"
            logger.error(error_msg)
            state["error_message"] = error_msg

        return state

    def __call__(self, state: TransportationState) -> TransportationState:
        """Make the node callable; delegates to ``process_input``."""
        return self.process_input(state)

# Factory function for building the input node
def create_input_node() -> InputNode:
    """Construct a new InputNode and return it."""
    node = InputNode()
    return node