Spaces:
Sleeping
Sleeping
| """ | |
| CQL Chatbot Engine - Main chatbot logic integrating CQL agent with multi-agent system | |
| """ | |
import os
import sys

# The CQL implementation lives in a folder with a space in its name, so it
# must be put on sys.path before the local imports below can resolve.
sys.path.append(os.path.join(os.path.dirname(__file__), 'Conservative Q-learning'))

from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from transformers import T5Tokenizer, T5EncoderModel

import config
from memory_manager import MemoryManager
from cql_agent import CQLAgent
from communication_agent import CommunicationAgent
from drawing_agent import DrawingAgent
class CQLChatbot:
    """Main chatbot engine: routes user messages to sub-agents via a CQL policy.

    Pipeline per message:
      1. Embed the text with a frozen T5 encoder (mean-pooled last hidden state).
      2. Ask the trained CQL agent for a discrete action
         (0 = Communication, 1 = Drawing, 2 = Clarification per config.ACTION_MAPPING).
      3. Dispatch to the matching sub-agent, record conversation history,
         and optionally persist the turn through a MemoryManager.
    """

    # Keywords (Vietnamese + English) that force the Drawing Agent regardless
    # of the CQL decision. Hoisted to a class constant so the sequence is
    # built once instead of on every chat() call.
    # NOTE(review): short substrings like 'hình'/'ảnh'/'generate' match inside
    # many unrelated words — confirm this over-triggering is intended.
    _DRAWING_KEYWORDS = (
        'vẽ', 'sketch', 'phác thảo', 'hình', 'ảnh', 'tranh',
        'draw', 'paint', 'create image', 'generate',
    )

    def __init__(self, model_path: Optional[str] = None,
                 memory_manager: "Optional[MemoryManager]" = None):
        """
        Initialize CQL Chatbot with all components.

        Args:
            model_path: Path to a saved CQL model; falls back to config.MODEL_PATH.
            memory_manager: Optional memory manager for persistent conversation storage.
        """
        print("🚀 Initializing CQL Chatbot System...")

        # Select compute device, falling back to CPU when CUDA is unavailable.
        self.device = torch.device(config.DEVICE if torch.cuda.is_available() else 'cpu')
        print(f"📱 Using device: {self.device}")

        # Load the frozen T5 encoder used purely for text embedding.
        print("📚 Loading T5 encoder...")
        self.tokenizer = T5Tokenizer.from_pretrained(config.T5_MODEL_NAME)
        self.encoder = T5EncoderModel.from_pretrained(config.T5_MODEL_NAME).to(self.device)
        self.encoder.eval()  # inference only — disables dropout etc.
        print("✅ T5 encoder loaded")

        # Load the CQL agent (the decision maker).
        print("🧠 Loading CQL agent...")
        self.cql_agent = CQLAgent(
            state_dim=config.STATE_DIM,
            action_dim=config.ACTION_DIM,
            is_continuous=False,
            device=self.device,
        )
        self.cql_agent.load_model(model_path or str(config.MODEL_PATH))
        print("✅ CQL agent loaded")

        # Initialize the sub-agents the policy dispatches to.
        print("👥 Initializing sub-agents...")
        self.communication_agent = CommunicationAgent()
        self.drawing_agent = DrawingAgent()
        print("✅ All agents initialized")

        # Optional persistent storage backend.
        self.memory_manager = memory_manager
        if self.memory_manager:
            print("💾 Memory manager enabled")

        # In-memory turn history: alternating {'role': 'user'/'assistant', ...} dicts.
        self.conversation_history: List[Dict] = []
        print("🎉 CQL Chatbot System ready!\n")

    def encode_text(self, text: str) -> np.ndarray:
        """
        Encode text into a T5 embedding.

        Args:
            text: Input text.

        Returns:
            1-D embedding vector (encoder hidden size, 768 for t5-base),
            mean-pooled over the sequence dimension.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.encoder(**inputs)
        # Mean pooling over the sequence dimension.
        # NOTE(review): a plain mean also averages any padding positions; for
        # single-string input there is no padding, but an attention-mask
        # weighted mean would be needed if this is ever batched.
        return outputs.last_hidden_state.mean(dim=1).cpu().numpy().flatten()

    def get_action(self, text: str) -> Tuple[int, np.ndarray]:
        """
        Get the CQL agent's decision for the input text.

        Args:
            text: User input text.

        Returns:
            Tuple of (action_index, q_values) where q_values covers all actions
            (used downstream for visualization/logging).
        """
        embedding = self.encode_text(text)

        # Greedy (evaluate=True) action from the trained policy.
        action = self.cql_agent.select_action(embedding, evaluate=True)

        # Recompute Q-values for all actions so callers can inspect them.
        # Normalization must mirror what select_action applies internally.
        state = self.cql_agent.normalizer.normalize(embedding)
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.cql_agent.critic_1(state_tensor).cpu().numpy().flatten()
        return action, q_values

    def chat(self, user_message: str, temperature: float = 0.7) -> Dict:
        """
        Process a user message and generate a response.

        Args:
            user_message: User's input message.
            temperature: Creativity knob forwarded to the Communication Agent.

        Returns:
            Dict with keys 'response', 'action', 'action_name', 'q_values'
            (list of floats), and 'image_path' (None unless the Drawing
            Agent produced a sketch).
        """
        # Let the CQL policy pick an action first.
        action, q_values = self.get_action(user_message)
        action_name = config.ACTION_MAPPING[action]
        print(f"\n🤖 CQL Decision: {action_name} (Action {action})")
        print(f"📊 Q-values: {q_values}")

        response_text = ""
        image_path = None

        # Keyword override: explicit drawing requests always go to the
        # Drawing Agent, even when the CQL policy chose otherwise.
        lowered = user_message.lower()
        if any(keyword in lowered for keyword in self._DRAWING_KEYWORDS):
            print("🎨 Drawing keywords detected! Forcing Drawing Agent.")
            action = 1
            action_name = config.ACTION_MAPPING[1]

        # Dispatch on the final action.
        if action == 0:  # Communication Agent
            response_text = self.communication_agent.generate_response(
                user_message,
                self.conversation_history,
                temperature,
            )
        elif action == 1:  # Drawing Agent
            response_text, image_path = self.drawing_agent.generate_sketch(user_message)
        elif action == 2:  # Clarification — deliberately falls back to Communication
            print("⚠️ CQL suggested Clarification. Using Communication Agent.")
            response_text = self.communication_agent.generate_response(
                user_message,
                self.conversation_history,
                temperature,
            )
            # Report the action actually executed, not the one suggested.
            action = 0
            action_name = config.ACTION_MAPPING[0]

        # Record both sides of the turn.
        self.conversation_history.append({
            'role': 'user',
            'content': user_message,
        })
        self.conversation_history.append({
            'role': 'assistant',
            'content': response_text,
            'action': action,
            'action_name': action_name,
        })

        # Trim to the configured window. NOTE(review): if MAX_HISTORY_LENGTH
        # is odd this can orphan an assistant message from its user message.
        if len(self.conversation_history) > config.MAX_HISTORY_LENGTH:
            self.conversation_history = self.conversation_history[-config.MAX_HISTORY_LENGTH:]

        # Persist to long-term memory when a manager is attached.
        if self.memory_manager:
            self.memory_manager.save_message('user', user_message)
            self.memory_manager.save_message(
                'assistant',
                response_text,
                {
                    'action': action,
                    'action_name': action_name,
                    'q_values': q_values.tolist(),
                },
            )

        return {
            'response': response_text,
            'action': action,
            'action_name': action_name,
            'q_values': q_values.tolist(),
            'image_path': image_path,
        }

    def _generate_clarification_request(self, user_message: str) -> str:
        """Generate a clarification request when input is unclear.

        NOTE(review): currently unused — chat() maps the Clarification action
        to the Communication Agent instead. Kept for future wiring.
        """
        import random  # local import keeps this optional path dependency-free
        clarifications = [
            f"Xin lỗi, tôi chưa hiểu rõ yêu cầu của bạn: '{user_message}'. Bạn có thể nói rõ hơn được không?",
            f"Tôi cần thêm thông tin để hiểu câu hỏi của bạn. Bạn muốn tôi làm gì với: '{user_message}'?",
            "Câu hỏi của bạn chưa rõ ràng. Bạn có thể diễn đạt lại không?",
            "Hmm, tôi không chắc bạn đang hỏi gì. Bạn có thể cung cấp thêm chi tiết không?",
        ]
        return random.choice(clarifications)

    def clear_history(self):
        """Clear the in-memory conversation history (memory manager untouched)."""
        self.conversation_history = []
        print("🗑️ Conversation history cleared")

    def get_action_distribution(self) -> Dict[str, int]:
        """Return counts of actions taken by the assistant in the current history."""
        # Pre-seed with zeros so every known action appears even if unused;
        # .get() keeps this robust to action names absent from the mapping.
        distribution = {name: 0 for name in config.ACTION_MAPPING.values()}
        for msg in self.conversation_history:
            if msg.get('role') == 'assistant' and 'action_name' in msg:
                action_name = msg['action_name']
                distribution[action_name] = distribution.get(action_name, 0) + 1
        return distribution