File size: 8,782 Bytes
4cfe4fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""
CQL Chatbot Engine - Main chatbot logic integrating CQL agent with multi-agent system
"""
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), 'Conservative Q-learning'))

import torch
import numpy as np
from transformers import T5Tokenizer, T5EncoderModel
from typing import Dict, List, Tuple
import config
from memory_manager import MemoryManager
from cql_agent import CQLAgent
from communication_agent import CommunicationAgent
from drawing_agent import DrawingAgent


class CQLChatbot:
    """Main chatbot engine: a CQL agent decides which sub-agent answers each message.

    Per-message pipeline:
      1. Encode the user text with a T5 encoder (mean-pooled last hidden state).
      2. Ask the CQL agent for a discrete action; indices map to agents via
         ``config.ACTION_MAPPING`` (0 = Communication, 1 = Drawing, 2 = Clarification).
      3. A keyword heuristic may override the CQL decision and force the
         Drawing Agent (see :data:`DRAWING_KEYWORDS`).
      4. The chosen sub-agent produces the response; the turn is appended to
         in-process history and, if configured, persisted via the memory manager.
    """

    # Keywords (Vietnamese + English) whose presence forces the Drawing Agent
    # regardless of the CQL decision. Hoisted to a class constant so the list
    # is not rebuilt on every chat() call. Matching is plain substring search
    # on the lowercased message.
    DRAWING_KEYWORDS = ['vẽ', 'sketch', 'phác thảo', 'hình', 'ảnh', 'tranh',
                        'draw', 'paint', 'create image', 'generate']

    def __init__(self,
                 model_path: "str | None" = None,
                 memory_manager: "MemoryManager | None" = None):
        """
        Initialize the CQL chatbot and all of its components.

        Args:
            model_path: Path to a saved CQL model checkpoint. Defaults to
                ``config.MODEL_PATH`` when not given.
            memory_manager: Optional memory manager used to persist the
                conversation; when None, turns are kept in memory only.
        """
        print("🚀 Initializing CQL Chatbot System...")

        # Select the configured device, falling back to CPU when CUDA is absent.
        self.device = torch.device(config.DEVICE if torch.cuda.is_available() else 'cpu')
        print(f"📱 Using device: {self.device}")

        # T5 encoder used purely for text embedding (no decoding).
        print("📚 Loading T5 encoder...")
        self.tokenizer = T5Tokenizer.from_pretrained(config.T5_MODEL_NAME)
        self.encoder = T5EncoderModel.from_pretrained(config.T5_MODEL_NAME).to(self.device)
        self.encoder.eval()  # inference only — disable dropout etc.
        print("✅ T5 encoder loaded")

        # CQL agent: the offline-RL decision maker that routes messages.
        print("🧠 Loading CQL agent...")
        self.cql_agent = CQLAgent(
            state_dim=config.STATE_DIM,
            action_dim=config.ACTION_DIM,
            is_continuous=False,
            device=self.device
        )

        # Load the trained checkpoint (explicit path wins over config default).
        model_path = model_path or str(config.MODEL_PATH)
        self.cql_agent.load_model(model_path)
        print("✅ CQL agent loaded")

        # Sub-agents that actually produce the responses.
        print("👥 Initializing sub-agents...")
        self.communication_agent = CommunicationAgent()
        self.drawing_agent = DrawingAgent()
        print("✅ All agents initialized")

        # Optional persistent storage for the conversation.
        self.memory_manager = memory_manager
        if self.memory_manager:
            print("💾 Memory manager enabled")

        # In-process conversation history: list of {'role', 'content', ...} dicts.
        self.conversation_history: List[Dict] = []

        print("🎉 CQL Chatbot System ready!\n")

    def encode_text(self, text: str) -> np.ndarray:
        """
        Encode text into a T5 embedding via mean pooling over the sequence.

        Args:
            text: Input text (truncated to 512 tokens).

        Returns:
            1-D embedding vector (hidden-size of the T5 model; 768 for t5-base
            — presumably matches config.STATE_DIM, TODO confirm).
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.encoder(**inputs)
            # Mean-pool the last hidden state across the token dimension.
            embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().flatten()

        return embedding

    def get_action(self, text: str) -> Tuple[int, np.ndarray]:
        """
        Get the CQL agent's routing decision for the input text.

        Args:
            text: User input text.

        Returns:
            Tuple of (action_index, q_values) where q_values is the critic-1
            output for every action (used for visualization only).
        """
        embedding = self.encode_text(text)

        # Greedy (evaluation-mode) action from the CQL agent.
        action = self.cql_agent.select_action(embedding, evaluate=True)

        # Recompute Q-values for all actions so callers can display them.
        state = self.cql_agent.normalizer.normalize(embedding)
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        with torch.no_grad():
            q_values = self.cql_agent.critic_1(state_tensor).cpu().numpy().flatten()

        return action, q_values

    def chat(self, user_message: str, temperature: float = 0.7) -> Dict:
        """
        Process a user message and generate a response.

        Args:
            user_message: User's input message.
            temperature: Response creativity, forwarded to the Communication Agent.

        Returns:
            Dict with keys 'response', 'action', 'action_name', 'q_values'
            (list of floats) and 'image_path' (None unless the Drawing Agent ran).
        """
        # Ask the CQL agent first; the keyword heuristic below may override it.
        action, q_values = self.get_action(user_message)
        action_name = config.ACTION_MAPPING[action]

        print(f"\n🤖 CQL Decision: {action_name} (Action {action})")
        print(f"📊 Q-values: {q_values}")

        response_text = ""
        image_path = None

        # Keyword override: any drawing keyword in the message forces the
        # Drawing Agent, regardless of what the CQL agent chose.
        lowered = user_message.lower()
        if any(keyword in lowered for keyword in self.DRAWING_KEYWORDS):
            print("🎨 Drawing keywords detected! Forcing Drawing Agent.")
            action = 1
            action_name = config.ACTION_MAPPING[1]

        # Dispatch to the sub-agent selected by the final action.
        if action == 0:  # Communication Agent
            response_text = self.communication_agent.generate_response(
                user_message,
                self.conversation_history,
                temperature
            )

        elif action == 1:  # Drawing Agent
            response_text, image_path = self.drawing_agent.generate_sketch(user_message)

        elif action == 2:  # Clarification — deliberately falls back to Communication
            print("⚠️ CQL suggested Clarification. Using Communication Agent.")
            response_text = self.communication_agent.generate_response(
                user_message,
                self.conversation_history,
                temperature
            )
            # Record the turn as a Communication action since that agent answered.
            action = 0
            action_name = config.ACTION_MAPPING[0]

        # Record both sides of the turn in the in-process history.
        self.conversation_history.append({
            'role': 'user',
            'content': user_message
        })
        self.conversation_history.append({
            'role': 'assistant',
            'content': response_text,
            'action': action,
            'action_name': action_name
        })

        # Keep only the most recent MAX_HISTORY_LENGTH entries.
        if len(self.conversation_history) > config.MAX_HISTORY_LENGTH:
            self.conversation_history = self.conversation_history[-config.MAX_HISTORY_LENGTH:]

        # Persist the turn (with decision metadata) when a memory manager exists.
        if self.memory_manager:
            self.memory_manager.save_message('user', user_message)
            self.memory_manager.save_message(
                'assistant',
                response_text,
                {
                    'action': action,
                    'action_name': action_name,
                    'q_values': q_values.tolist()
                }
            )

        return {
            'response': response_text,
            'action': action,
            'action_name': action_name,
            'q_values': q_values.tolist(),
            'image_path': image_path
        }

    def _generate_clarification_request(self, user_message: str) -> str:
        """Return a randomly chosen clarification request (Vietnamese) for unclear input.

        NOTE(review): currently unused — chat() routes action 2 to the
        Communication Agent instead of calling this helper.
        """
        clarifications = [
            f"Xin lỗi, tôi chưa hiểu rõ yêu cầu của bạn: '{user_message}'. Bạn có thể nói rõ hơn được không?",
            f"Tôi cần thêm thông tin để hiểu câu hỏi của bạn. Bạn muốn tôi làm gì với: '{user_message}'?",
            "Câu hỏi của bạn chưa rõ ràng. Bạn có thể diễn đạt lại không?",
            "Hmm, tôi không chắc bạn đang hỏi gì. Bạn có thể cung cấp thêm chi tiết không?"
        ]

        # Local import keeps the module import list unchanged.
        import random
        return random.choice(clarifications)

    def clear_history(self):
        """Clear the in-process conversation history (persistent memory untouched)."""
        self.conversation_history = []
        print("🗑️ Conversation history cleared")

    def get_action_distribution(self) -> Dict[str, int]:
        """Return how often each action was taken in the current conversation.

        Returns:
            Mapping from action name to count, with every configured action
            present (zero when never taken).
        """
        distribution = {name: 0 for name in config.ACTION_MAPPING.values()}

        for msg in self.conversation_history:
            if msg.get('role') == 'assistant' and 'action_name' in msg:
                action_name = msg['action_name']
                distribution[action_name] = distribution.get(action_name, 0) + 1

        return distribution