"""
CQL Chatbot Engine - Main chatbot logic integrating CQL agent with multi-agent system
"""
import sys
import os
import random

# Extend the import path with the sibling 'Conservative Q-learning' directory
# before the project imports below run (presumably where cql_agent lives —
# TODO confirm).
sys.path.append(os.path.join(os.path.dirname(__file__), 'Conservative Q-learning'))

import torch
import numpy as np
from transformers import T5Tokenizer, T5EncoderModel
from typing import Dict, List, Optional, Tuple

import config
from memory_manager import MemoryManager
from cql_agent import CQLAgent
from communication_agent import CommunicationAgent
from drawing_agent import DrawingAgent


class CQLChatbot:
    """Chatbot that routes each user message to a sub-agent.

    A T5 encoder turns the message into an embedding, a CQL agent picks an
    action from that embedding, and the action selects either the
    Communication Agent or the Drawing Agent. A keyword check can override
    the CQL decision and force the Drawing Agent.
    """

    # Action indices as interpreted by chat(); display names are looked up
    # in config.ACTION_MAPPING.
    ACTION_COMMUNICATION = 0
    ACTION_DRAWING = 1
    ACTION_CLARIFICATION = 2

    # Substrings (matched against the lower-cased message) that force the
    # Drawing Agent regardless of the CQL decision.
    DRAWING_KEYWORDS = (
        'vẽ', 'sketch', 'phác thảo', 'hình', 'ảnh', 'tranh',
        'draw', 'paint', 'create image', 'generate',
    )

    def __init__(self, model_path: Optional[str] = None,
                 memory_manager: Optional[MemoryManager] = None):
        """
        Initialize CQL Chatbot with all components

        Args:
            model_path: Path to saved CQL model. Falls back to
                str(config.MODEL_PATH) when omitted or empty.
            memory_manager: Memory manager instance for conversation storage.
                When None, messages are kept only in the in-memory history.
        """
        print("🚀 Initializing CQL Chatbot System...")

        # Set device: config.DEVICE is honoured only when CUDA is available.
        self.device = torch.device(config.DEVICE if torch.cuda.is_available() else 'cpu')
        print(f"📱 Using device: {self.device}")

        # Load T5 Encoder for text embedding
        print("📚 Loading T5 encoder...")
        self.tokenizer = T5Tokenizer.from_pretrained(config.T5_MODEL_NAME)
        self.encoder = T5EncoderModel.from_pretrained(config.T5_MODEL_NAME).to(self.device)
        self.encoder.eval()  # Set to evaluation mode
        print("✅ T5 encoder loaded")

        # Load CQL Agent (Decision Maker)
        print("🧠 Loading CQL agent...")
        self.cql_agent = CQLAgent(
            state_dim=config.STATE_DIM,
            action_dim=config.ACTION_DIM,
            is_continuous=False,
            device=self.device
        )

        # Load trained model
        model_path = model_path or str(config.MODEL_PATH)
        self.cql_agent.load_model(model_path)
        print("✅ CQL agent loaded")

        # Initialize sub-agents
        print("👥 Initializing sub-agents...")
        self.communication_agent = CommunicationAgent()
        self.drawing_agent = DrawingAgent()
        print("✅ All agents initialized")

        # Memory manager. Compared against None rather than by truthiness so
        # that a manager object which happens to evaluate as falsy (e.g. one
        # defining __len__ while empty) is still used.
        self.memory_manager = memory_manager
        if self.memory_manager is not None:
            print("💾 Memory manager enabled")

        # Conversation history: list of {'role', 'content', ...} dicts.
        self.conversation_history = []

        print("🎉 CQL Chatbot System ready!\n")

    def encode_text(self, text: str) -> np.ndarray:
        """
        Encode text into T5 embedding

        Args:
            text: Input text (truncated by the tokenizer to 512 tokens)

        Returns:
            1-D embedding vector: the encoder's last hidden state averaged
            over the sequence dimension. The original note gives the size as
            768; that depends on config.T5_MODEL_NAME — TODO confirm.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.encoder(**inputs)

        # Use mean pooling over sequence
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().flatten()
        return embedding

    def get_action(self, text: str) -> Tuple[int, np.ndarray]:
        """
        Get CQL agent's decision for the input text

        Args:
            text: User input text

        Returns:
            Tuple of (action_index, q_values). action_index is a built-in
            int; q_values is the flattened output of the agent's first
            critic for the normalized embedding.
        """
        # Encode text to embedding
        embedding = self.encode_text(text)

        # Get action from CQL agent. Coerce to a plain int so that a NumPy
        # integer does not leak into the returned dict / stored metadata,
        # where it would not be JSON-serializable.
        action = int(self.cql_agent.select_action(embedding, evaluate=True))

        # Get Q-values for all actions (for visualization)
        state = self.cql_agent.normalizer.normalize(embedding)
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        with torch.no_grad():
            q_values = self.cql_agent.critic_1(state_tensor).cpu().numpy().flatten()

        return action, q_values

    def _is_drawing_request(self, user_message: str) -> bool:
        """Return True if the message contains any drawing keyword (case-insensitive substring match)."""
        lowered = user_message.lower()
        return any(keyword in lowered for keyword in self.DRAWING_KEYWORDS)

    def _record_turn(self, user_message: str, response_text: str, action: int,
                     action_name: str, q_values: np.ndarray) -> None:
        """Append the user/assistant turn to the history, trim it, and persist it if a memory manager is set."""
        self.conversation_history.append({
            'role': 'user',
            'content': user_message
        })
        self.conversation_history.append({
            'role': 'assistant',
            'content': response_text,
            'action': action,
            'action_name': action_name
        })

        # Limit history length. An explicit overflow offset is used instead
        # of history[-MAX:] because a negative-zero slice start would keep
        # the whole list when MAX_HISTORY_LENGTH is 0.
        overflow = len(self.conversation_history) - config.MAX_HISTORY_LENGTH
        if overflow > 0:
            self.conversation_history = self.conversation_history[overflow:]

        # Save to memory manager if available
        if self.memory_manager is not None:
            self.memory_manager.save_message('user', user_message)
            self.memory_manager.save_message(
                'assistant',
                response_text,
                {
                    'action': action,
                    'action_name': action_name,
                    'q_values': q_values.tolist()
                }
            )

    def chat(self, user_message: str, temperature: float = 0.7) -> Dict:
        """
        Main chat function - processes user message and generates response

        Args:
            user_message: User's input message
            temperature: Response creativity, forwarded to the Communication Agent

        Returns:
            Dictionary with keys 'response', 'action', 'action_name',
            'q_values' (list) and 'image_path' (None unless the Drawing
            Agent was used). 'action' / 'action_name' describe the agent
            that actually produced the response, which may differ from the
            raw CQL decision.
        """
        # Get CQL agent's decision
        action, q_values = self.get_action(user_message)
        action_name = config.ACTION_MAPPING[action]

        print(f"\n🤖 CQL Decision: {action_name} (Action {action})")
        print(f"📊 Q-values: {q_values}")

        image_path = None

        # Drawing keywords take priority over the CQL decision.
        if self._is_drawing_request(user_message):
            print("🎨 Drawing keywords detected! Forcing Drawing Agent.")
            action = self.ACTION_DRAWING
            action_name = config.ACTION_MAPPING[self.ACTION_DRAWING]

        # Execute based on final action
        if action == self.ACTION_DRAWING:
            response_text, image_path = self.drawing_agent.generate_sketch(user_message)
        else:
            if action == self.ACTION_CLARIFICATION:
                # Clarification - fallback to Communication
                print("⚠️ CQL suggested Clarification. Using Communication Agent.")
            elif action != self.ACTION_COMMUNICATION:
                # Any other index known to ACTION_MAPPING used to fall through
                # every branch and yield an empty response.
                print(f"⚠️ Unhandled action {action}. Using Communication Agent.")
            response_text = self.communication_agent.generate_response(
                user_message,
                self.conversation_history,
                temperature
            )
            action = self.ACTION_COMMUNICATION
            action_name = config.ACTION_MAPPING[self.ACTION_COMMUNICATION]

        # Update conversation history and persistent memory
        self._record_turn(user_message, response_text, action, action_name, q_values)

        return {
            'response': response_text,
            'action': action,
            'action_name': action_name,
            'q_values': q_values.tolist(),
            'image_path': image_path
        }

    def _generate_clarification_request(self, user_message: str) -> str:
        """Generate a clarification request when input is unclear.

        Picks one of four fixed Vietnamese templates at random; two of them
        quote the user's message back. Not called by chat(), which falls back
        to the Communication Agent for the clarification action.
        """
        clarifications = [
            f"Xin lỗi, tôi chưa hiểu rõ yêu cầu của bạn: '{user_message}'. Bạn có thể nói rõ hơn được không?",
            f"Tôi cần thêm thông tin để hiểu câu hỏi của bạn. Bạn muốn tôi làm gì với: '{user_message}'?",
            "Câu hỏi của bạn chưa rõ ràng. Bạn có thể diễn đạt lại không?",
            "Hmm, tôi không chắc bạn đang hỏi gì. Bạn có thể cung cấp thêm chi tiết không?"
        ]
        return random.choice(clarifications)

    def clear_history(self):
        """Clear conversation history (the memory manager is not touched)."""
        self.conversation_history = []
        print("🗑️ Conversation history cleared")

    def get_action_distribution(self) -> Dict[str, int]:
        """Get distribution of actions taken in current conversation.

        Returns:
            Mapping from action name to the number of assistant messages in
            the in-memory history recorded with that name. Every name in
            config.ACTION_MAPPING is present, with 0 if it was never used.
        """
        distribution = {name: 0 for name in config.ACTION_MAPPING.values()}

        for msg in self.conversation_history:
            if msg.get('role') == 'assistant' and 'action_name' in msg:
                action_name = msg['action_name']
                # .get() tolerates names that are not in ACTION_MAPPING.
                distribution[action_name] = distribution.get(action_name, 0) + 1

        return distribution