# deploy_chatbot_demo / chatbot_engine.py
# NguyenThanh1405's picture
# Deploy CQL Chatbot (without large files)
# 4cfe4fa
"""
CQL Chatbot Engine - Main chatbot logic integrating CQL agent with multi-agent system
"""
import sys
import os
# NOTE: local package directory must be on sys.path BEFORE the imports below,
# so this line has to stay between the stdlib and project imports.
sys.path.append(os.path.join(os.path.dirname(__file__), 'Conservative Q-learning'))
import random
import torch
import numpy as np
from transformers import T5Tokenizer, T5EncoderModel
from typing import Dict, List, Optional, Tuple
import config
from memory_manager import MemoryManager
from cql_agent import CQLAgent
from communication_agent import CommunicationAgent
from drawing_agent import DrawingAgent
class CQLChatbot:
    """Multi-agent chatbot driven by a CQL (Conservative Q-Learning) policy.

    Pipeline per message: a T5 encoder embeds the user text, the trained CQL
    agent picks a discrete action over that embedding (0 = Communication,
    1 = Drawing, 2 = Clarification per ``config.ACTION_MAPPING`` — TODO confirm
    the mapping against config), and the matching sub-agent generates the reply.
    """

    # Keywords (Vietnamese + English) that force the Drawing Agent regardless
    # of the CQL decision. Hoisted to a class constant so the sequence is not
    # rebuilt on every chat() call.
    DRAWING_KEYWORDS = (
        'vẽ', 'sketch', 'phác thảo', 'hình', 'ảnh', 'tranh',
        'draw', 'paint', 'create image', 'generate',
    )

    def __init__(self, model_path: Optional[str] = None,
                 memory_manager: Optional[MemoryManager] = None):
        """
        Initialize CQL Chatbot with all components.

        Args:
            model_path: Path to saved CQL model; falls back to
                ``config.MODEL_PATH`` when None.
            memory_manager: Optional memory manager instance for persistent
                conversation storage.
        """
        print("🚀 Initializing CQL Chatbot System...")
        # Set device (honor config.DEVICE only when CUDA is actually available)
        self.device = torch.device(config.DEVICE if torch.cuda.is_available() else 'cpu')
        print(f"📱 Using device: {self.device}")
        # Load T5 Encoder for text embedding
        print("📚 Loading T5 encoder...")
        self.tokenizer = T5Tokenizer.from_pretrained(config.T5_MODEL_NAME)
        self.encoder = T5EncoderModel.from_pretrained(config.T5_MODEL_NAME).to(self.device)
        self.encoder.eval()  # inference only — disables dropout/batch-norm updates
        print("✅ T5 encoder loaded")
        # Load CQL Agent (decision maker over the discrete action space)
        print("🧠 Loading CQL agent...")
        self.cql_agent = CQLAgent(
            state_dim=config.STATE_DIM,
            action_dim=config.ACTION_DIM,
            is_continuous=False,
            device=self.device
        )
        # Load trained model weights
        model_path = model_path or str(config.MODEL_PATH)
        self.cql_agent.load_model(model_path)
        print("✅ CQL agent loaded")
        # Initialize sub-agents
        print("👥 Initializing sub-agents...")
        self.communication_agent = CommunicationAgent()
        self.drawing_agent = DrawingAgent()
        print("✅ All agents initialized")
        # Optional persistent storage for the conversation
        self.memory_manager = memory_manager
        if self.memory_manager:
            print("💾 Memory manager enabled")
        # In-process history: alternating {'role': ..., 'content': ...} dicts
        self.conversation_history: List[Dict] = []
        print("🎉 CQL Chatbot System ready!\n")

    def encode_text(self, text: str) -> np.ndarray:
        """
        Encode text into a T5 embedding.

        Args:
            text: Input text.

        Returns:
            1-D embedding vector (encoder hidden size, e.g. 768 for t5-base —
            depends on ``config.T5_MODEL_NAME``), mean-pooled over tokens.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)
        with torch.no_grad():
            outputs = self.encoder(**inputs)
            # Mean pooling over the sequence dimension gives a fixed-size vector
            embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().flatten()
        return embedding

    def get_action(self, text: str) -> Tuple[int, np.ndarray]:
        """
        Get the CQL agent's decision for the input text.

        Args:
            text: User input text.

        Returns:
            Tuple of (action_index, q_values), where q_values are the first
            critic's Q-estimates for all actions (for visualization only).
        """
        # Encode text to embedding
        embedding = self.encode_text(text)
        # Greedy (evaluate=True) action from the CQL agent
        action = self.cql_agent.select_action(embedding, evaluate=True)
        # Recompute Q-values for all actions so callers can display them;
        # the state must go through the same normalizer the agent trained with.
        state = self.cql_agent.normalizer.normalize(embedding)
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.cql_agent.critic_1(state_tensor).cpu().numpy().flatten()
        return action, q_values

    def chat(self, user_message: str, temperature: float = 0.7) -> Dict:
        """
        Main chat function — processes a user message and generates a response.

        Args:
            user_message: User's input message.
            temperature: Response creativity, forwarded to the
                Communication Agent.

        Returns:
            Dict with keys 'response', 'action', 'action_name', 'q_values'
            (list), and 'image_path' (None unless the Drawing Agent ran).
        """
        # Get CQL agent's decision
        action, q_values = self.get_action(user_message)
        action_name = config.ACTION_MAPPING[action]
        print(f"\n🤖 CQL Decision: {action_name} (Action {action})")
        print(f"📊 Q-values: {q_values}")
        # Initialize outputs
        response_text = ""
        image_path = None
        # Keyword override: drawing keywords take precedence over the CQL
        # decision (the policy under-selects the Drawing Agent in practice —
        # NOTE(review): heuristic; 'hình'/'generate' match broadly).
        lowered = user_message.lower()
        is_drawing_request = any(keyword in lowered for keyword in self.DRAWING_KEYWORDS)
        if is_drawing_request:
            print("🎨 Drawing keywords detected! Forcing Drawing Agent.")
            action = 1
            action_name = config.ACTION_MAPPING[1]
        # Execute based on the final action
        if action == 0:  # Communication Agent
            response_text = self.communication_agent.generate_response(
                user_message,
                self.conversation_history,
                temperature
            )
        elif action == 1:  # Drawing Agent
            response_text, image_path = self.drawing_agent.generate_sketch(user_message)
        elif action == 2:  # Clarification — fall back to Communication Agent
            print("⚠️ CQL suggested Clarification. Using Communication Agent.")
            response_text = self.communication_agent.generate_response(
                user_message,
                self.conversation_history,
                temperature
            )
            # Report the action actually executed, not the one suggested
            action = 0
            action_name = config.ACTION_MAPPING[0]
        # Update in-process conversation history
        self.conversation_history.append({
            'role': 'user',
            'content': user_message
        })
        self.conversation_history.append({
            'role': 'assistant',
            'content': response_text,
            'action': action,
            'action_name': action_name
        })
        # Keep only the most recent MAX_HISTORY_LENGTH entries
        if len(self.conversation_history) > config.MAX_HISTORY_LENGTH:
            self.conversation_history = self.conversation_history[-config.MAX_HISTORY_LENGTH:]
        # Persist to memory manager if available
        if self.memory_manager:
            self.memory_manager.save_message('user', user_message)
            self.memory_manager.save_message(
                'assistant',
                response_text,
                {
                    'action': action,
                    'action_name': action_name,
                    'q_values': q_values.tolist()
                }
            )
        return {
            'response': response_text,
            'action': action,
            'action_name': action_name,
            'q_values': q_values.tolist(),
            'image_path': image_path
        }

    def _generate_clarification_request(self, user_message: str) -> str:
        """Generate a clarification request when input is unclear.

        NOTE(review): currently unused — chat() falls back to the
        Communication Agent for action 2 instead of calling this.
        """
        clarifications = [
            f"Xin lỗi, tôi chưa hiểu rõ yêu cầu của bạn: '{user_message}'. Bạn có thể nói rõ hơn được không?",
            f"Tôi cần thêm thông tin để hiểu câu hỏi của bạn. Bạn muốn tôi làm gì với: '{user_message}'?",
            f"Câu hỏi của bạn chưa rõ ràng. Bạn có thể diễn đạt lại không?",
            f"Hmm, tôi không chắc bạn đang hỏi gì. Bạn có thể cung cấp thêm chi tiết không?"
        ]
        return random.choice(clarifications)

    def clear_history(self):
        """Clear the in-process conversation history (memory manager untouched)."""
        self.conversation_history = []
        print("🗑️ Conversation history cleared")

    def get_action_distribution(self) -> Dict[str, int]:
        """Get counts of actions taken in the current conversation, keyed by
        action name; every known action appears, including zero counts."""
        distribution = {name: 0 for name in config.ACTION_MAPPING.values()}
        for msg in self.conversation_history:
            if msg.get('role') == 'assistant' and 'action_name' in msg:
                action_name = msg['action_name']
                distribution[action_name] = distribution.get(action_name, 0) + 1
        return distribution