"""
CQL Chatbot Engine - Main chatbot logic integrating CQL agent with multi-agent system
"""
import os
import sys
from typing import Dict, List, Optional, Tuple

# Make the bundled CQL implementation importable before the local imports below.
sys.path.append(os.path.join(os.path.dirname(__file__), 'Conservative Q-learning'))

import numpy as np
import torch
from transformers import T5Tokenizer, T5EncoderModel

import config
from memory_manager import MemoryManager
from cql_agent import CQLAgent
from communication_agent import CommunicationAgent
from drawing_agent import DrawingAgent
class CQLChatbot:
    """Route user messages to sub-agents using a trained CQL policy.

    A T5 encoder embeds the user's text, the CQL agent picks a discrete
    action (0 = Communication, 1 = Drawing, 2 = Clarification), and the
    matching sub-agent produces the reply. A keyword heuristic can
    override the policy and force the Drawing Agent.
    """

    # Keywords (Vietnamese + English) that force the Drawing Agent
    # regardless of the CQL decision. Class-level so the tuple is built
    # once instead of on every chat() call.
    DRAWING_KEYWORDS = (
        'vẽ', 'sketch', 'phác thảo', 'hình', 'ảnh', 'tranh',
        'draw', 'paint', 'create image', 'generate'
    )

    def __init__(self, model_path: Optional[str] = None,
                 memory_manager: Optional["MemoryManager"] = None):
        """
        Initialize CQL Chatbot with all components.

        Args:
            model_path: Path to saved CQL model; defaults to config.MODEL_PATH.
            memory_manager: Optional memory manager for persistent
                conversation storage.
        """
        print("🚀 Initializing CQL Chatbot System...")

        # Fall back to CPU when CUDA is unavailable.
        self.device = torch.device(config.DEVICE if torch.cuda.is_available() else 'cpu')
        print(f"📱 Using device: {self.device}")

        # Load T5 encoder used to turn text into state embeddings.
        print("📚 Loading T5 encoder...")
        self.tokenizer = T5Tokenizer.from_pretrained(config.T5_MODEL_NAME)
        self.encoder = T5EncoderModel.from_pretrained(config.T5_MODEL_NAME).to(self.device)
        self.encoder.eval()  # inference only; disables dropout etc.
        print("✅ T5 encoder loaded")

        # Load the CQL agent (decision maker over the discrete action set).
        print("🧠 Loading CQL agent...")
        self.cql_agent = CQLAgent(
            state_dim=config.STATE_DIM,
            action_dim=config.ACTION_DIM,
            is_continuous=False,
            device=self.device
        )
        self.cql_agent.load_model(model_path or str(config.MODEL_PATH))
        print("✅ CQL agent loaded")

        # Initialize the sub-agents the policy dispatches to.
        print("👥 Initializing sub-agents...")
        self.communication_agent = CommunicationAgent()
        self.drawing_agent = DrawingAgent()
        print("✅ All agents initialized")

        # Optional persistent storage for messages.
        self.memory_manager = memory_manager
        if self.memory_manager:
            print("💾 Memory manager enabled")

        # In-memory conversation history: list of role/content dicts.
        self.conversation_history: List[Dict] = []

        print("🎉 CQL Chatbot System ready!\n")

    def encode_text(self, text: str) -> np.ndarray:
        """
        Encode text into a T5 embedding.

        Args:
            text: Input text.

        Returns:
            1-D embedding vector (hidden size of the T5 model, e.g. 768).
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.encoder(**inputs)
            # Mean-pool over the sequence dimension to get one vector.
            embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().flatten()

        return embedding

    def get_action(self, text: str) -> Tuple[int, np.ndarray]:
        """
        Get the CQL agent's decision for the input text.

        Args:
            text: User input text.

        Returns:
            Tuple of (action_index, q_values) where q_values covers all
            actions (useful for visualization).
        """
        # Encode text to an embedding state.
        embedding = self.encode_text(text)

        # Greedy action from the trained policy.
        action = self.cql_agent.select_action(embedding, evaluate=True)

        # Recompute all Q-values (same normalization the agent uses).
        state = self.cql_agent.normalizer.normalize(embedding)
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.cql_agent.critic_1(state_tensor).cpu().numpy().flatten()

        return action, q_values

    def chat(self, user_message: str, temperature: float = 0.7) -> Dict:
        """
        Process a user message and generate a response.

        Args:
            user_message: User's input message.
            temperature: Response creativity passed to the Communication Agent.

        Returns:
            Dict with 'response', 'action', 'action_name', 'q_values' and
            'image_path' (None unless the Drawing Agent produced an image).
        """
        # Get the CQL agent's decision.
        action, q_values = self.get_action(user_message)
        action_name = config.ACTION_MAPPING[action]
        print(f"\n🤖 CQL Decision: {action_name} (Action {action})")
        print(f"📊 Q-values: {q_values}")

        image_path = None

        # Keyword override: any drawing keyword forces the Drawing Agent,
        # regardless of the CQL decision.
        lowered = user_message.lower()
        if any(keyword in lowered for keyword in self.DRAWING_KEYWORDS):
            print("🎨 Drawing keywords detected! Forcing Drawing Agent.")
            action = 1
            action_name = config.ACTION_MAPPING[1]

        if action == 1:  # Drawing Agent
            response_text, image_path = self.drawing_agent.generate_sketch(user_message)
        else:
            if action == 2:
                # Clarification has no dedicated agent; fall back to
                # Communication and record the action as 0.
                print("⚠️ CQL suggested Clarification. Using Communication Agent.")
                action = 0
                action_name = config.ACTION_MAPPING[0]
            response_text = self.communication_agent.generate_response(
                user_message,
                self.conversation_history,
                temperature
            )

        # Update conversation history.
        self.conversation_history.append({
            'role': 'user',
            'content': user_message
        })
        self.conversation_history.append({
            'role': 'assistant',
            'content': response_text,
            'action': action,
            'action_name': action_name
        })

        # Keep only the most recent messages.
        if len(self.conversation_history) > config.MAX_HISTORY_LENGTH:
            self.conversation_history = self.conversation_history[-config.MAX_HISTORY_LENGTH:]

        # Persist to the memory manager if one was provided.
        if self.memory_manager:
            self.memory_manager.save_message('user', user_message)
            self.memory_manager.save_message(
                'assistant',
                response_text,
                {
                    'action': action,
                    'action_name': action_name,
                    'q_values': q_values.tolist()
                }
            )

        return {
            'response': response_text,
            'action': action,
            'action_name': action_name,
            'q_values': q_values.tolist(),
            'image_path': image_path
        }

    def _generate_clarification_request(self, user_message: str) -> str:
        """Return a randomly chosen clarification prompt.

        NOTE(review): currently unused — chat() routes action 2 to the
        Communication Agent instead of calling this.
        """
        import random
        clarifications = [
            f"Xin lỗi, tôi chưa hiểu rõ yêu cầu của bạn: '{user_message}'. Bạn có thể nói rõ hơn được không?",
            f"Tôi cần thêm thông tin để hiểu câu hỏi của bạn. Bạn muốn tôi làm gì với: '{user_message}'?",
            "Câu hỏi của bạn chưa rõ ràng. Bạn có thể diễn đạt lại không?",
            "Hmm, tôi không chắc bạn đang hỏi gì. Bạn có thể cung cấp thêm chi tiết không?"
        ]
        return random.choice(clarifications)

    def clear_history(self):
        """Clear the in-memory conversation history."""
        self.conversation_history = []
        print("🗑️ Conversation history cleared")

    def get_action_distribution(self) -> Dict[str, int]:
        """Count how often each action was taken in the current conversation."""
        distribution = {name: 0 for name in config.ACTION_MAPPING.values()}
        for msg in self.conversation_history:
            if msg.get('role') == 'assistant' and 'action_name' in msg:
                action_name = msg['action_name']
                # .get() keeps this robust to names outside ACTION_MAPPING.
                distribution[action_name] = distribution.get(action_name, 0) + 1
        return distribution