ABAO77's picture
Upload 107 files
16d5a75 verified
import json
import uuid
from typing import Dict, List, Any, Optional, Tuple
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from src.utils.logger import logger
from .data import State, UserProfile, convert_to_langchain_messages
from .prompt import (
DEFAULT_SYSTEM_PROMPT,
ANALYZE_REQUEST_TEMPLATE,
GENERATE_PROBING_QUESTIONS_TEMPLATE,
UPDATE_SYSTEM_PROMPT_TEMPLATE,
RESPONSE_TEMPLATE,
CREATE_USER_PROFILE_TEMPLATE,
)
# Initialize the shared chat model once, at import time.
# Gemini is tried first; if constructing it raises, OpenAI is used instead.
# If both constructors raise, the second error is logged and re-raised so the
# module fails to import rather than running without an LLM.
try:
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash",
        temperature=0.2,
        verbose=True,
    )
except Exception as gemini_error:
    # Record why the preferred backend was skipped instead of hiding it.
    logger.warning(
        f"Failed to initialize Gemini LLM, falling back to OpenAI: {gemini_error}"
    )
    try:
        llm = ChatOpenAI(
            model="gpt-3.5-turbo",
            temperature=0.2,
            verbose=True,
        )
    except Exception as e:
        logger.error(f"Failed to initialize LLM: {e}")
        raise
def initialize_state(state: State) -> State:
    """
    Fill in any missing or empty bookkeeping fields of the state.

    The state is modified in place and also returned.

    Args:
        state: Current state

    Returns:
        The same state object, with defaults applied
    """
    # Fields that are replaced when absent OR falsy. Factories keep the
    # defaults lazy, so a fresh UUID / list / dict is only built when needed.
    lazy_defaults = (
        ("session_id", lambda: str(uuid.uuid4())),
        ("messages_history", list),
        ("current_system_prompt", lambda: DEFAULT_SYSTEM_PROMPT),
        ("user_profile", dict),
    )
    for field_name, build_default in lazy_defaults:
        if not state.get(field_name):
            state[field_name] = build_default()

    # "messages" is only defaulted when the key is absent; an existing value
    # (even an empty or None one) is left untouched.
    if "messages" not in state:
        state["messages"] = []

    return state
def _analysis_fallback(reasoning: str) -> Dict[str, Any]:
    """Build the default analysis used when the LLM reply cannot be used."""
    return {
        "intent": "unknown",
        "keywords": [],
        "prompt_needs_update": True,
        "probing_questions_needed": False,
        "confidence": 0.0,
        "reasoning": reasoning,
    }


def _store_analysis(state: State, analysis_result: Dict[str, Any]) -> State:
    """Write the analysis and the two flags derived from it into the state."""
    state["analysis_result"] = analysis_result
    state["prompt_needs_update"] = analysis_result.get("prompt_needs_update", False)
    state["probing_questions_needed"] = analysis_result.get("probing_questions_needed", False)
    return state


def _strip_code_fence(text: str) -> str:
    """Return the text inside a Markdown code fence, or the stripped text."""
    content = text.strip()
    if "```json" in content:
        return content.split("```json")[1].split("```")[0].strip()
    if "```" in content:
        return content.split("```")[1].strip()
    return content


def _find_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Try to parse the outermost ``{...}`` span of the text as a JSON object."""
    import re

    match = re.search(r'({[\s\S]*})', text)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


async def analyze_user_request(state: State) -> State:
    """
    Analyze the user's request to determine intent and whether we need to update the prompt.

    The LLM reply is expected to be a JSON object. It is parsed after removing
    an optional Markdown code fence; if that fails, the outermost ``{...}``
    span of the raw reply is tried. When no JSON object can be recovered, or
    any other error occurs, a default analysis is stored instead.

    In every case ``analysis_result``, ``prompt_needs_update`` and
    ``probing_questions_needed`` are all written, so the flags always agree
    with the stored analysis.

    Args:
        state: Current state

    Returns:
        Updated state with analysis results
    """
    try:
        # Prepare message history
        history = convert_to_langchain_messages(state["messages_history"])

        # Call the LLM
        response = await llm.ainvoke(
            ANALYZE_REQUEST_TEMPLATE.format_messages(
                history=history,
                current_system_prompt=state["current_system_prompt"],
                user_message=state["user_message"],
            )
        )
        raw = response.content

        analysis_result = None
        try:
            # Some models wrap the JSON in code formatting markers
            parsed = json.loads(_strip_code_fence(raw))
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse analysis result: {e}")
            logger.error(f"Raw response: {raw}")
            # Second chance: pull a JSON object out of the surrounding text
            analysis_result = _find_json_object(raw)
            if analysis_result is not None:
                logger.info(f"Successfully extracted JSON with regex: {analysis_result}")
        else:
            if isinstance(parsed, dict):
                analysis_result = parsed
                logger.info(f"Analysis complete: {analysis_result}")
            else:
                # Valid JSON but not an object (e.g. a list): unusable here
                logger.error(f"Invalid format for analysis result: {parsed}")

        if analysis_result is None:
            analysis_result = _analysis_fallback("Failed to parse the analysis result")
        return _store_analysis(state, analysis_result)
    except Exception as e:
        logger.error(f"Error in analyze_user_request: {e}")
        return _store_analysis(
            state, _analysis_fallback(f"Error during analysis: {str(e)}")
        )
async def generate_probing_questions(state: State) -> State:
    """
    Generate probing questions to better understand the user.

    The LLM reply is expected to be a JSON list. An optional Markdown code
    fence around the JSON is removed before parsing. If the reply cannot be
    parsed, is not a list, or any other error occurs, a single generic
    follow-up question is stored instead.

    Args:
        state: Current state

    Returns:
        Updated state with probing questions
    """
    # Built per call so the stored list is never shared between states.
    fallback_questions = ["Bạn có thể chia sẻ thêm về nhu cầu của bạn không?"]
    try:
        # Prepare message history
        history = convert_to_langchain_messages(state["messages_history"])

        # Call the LLM
        response = await llm.ainvoke(
            GENERATE_PROBING_QUESTIONS_TEMPLATE.format_messages(
                history=history,
                user_message=state["user_message"],
                analysis_result=state["analysis_result"],
            )
        )

        # Some models wrap the JSON in code formatting markers
        content = response.content.strip()
        if "```json" in content:
            content = content.split("```json")[1].split("```")[0].strip()
        elif "```" in content:
            content = content.split("```")[1].strip()

        try:
            questions = json.loads(content)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse probing questions: {e}")
            logger.error(f"Raw response: {response.content}")
            state["probing_questions"] = fallback_questions
            return state

        if isinstance(questions, list):
            state["probing_questions"] = questions
            logger.info(f"Generated probing questions: {questions}")
        else:
            logger.error(f"Invalid format for probing questions: {questions}")
            state["probing_questions"] = fallback_questions
        return state
    except Exception as e:
        logger.error(f"Error in generate_probing_questions: {e}")
        state["probing_questions"] = fallback_questions
        return state
async def update_user_profile(state: State) -> State:
    """
    Update the user profile based on the current conversation.

    The LLM reply is expected to be a JSON object of profile fields, which is
    merged into ``state["user_profile"]``. An optional Markdown code fence
    around the JSON is removed before parsing. If the reply cannot be parsed,
    is not an object, or any other error occurs, the profile is left as it
    was and the problem is logged.

    Args:
        state: Current state

    Returns:
        Updated state with updated user profile
    """
    try:
        # Prepare message history
        history = convert_to_langchain_messages(state["messages_history"])

        # No probing answers are collected in this function, so the template
        # always receives this placeholder.
        probing_answers = "Không có"  # Default

        # Call the LLM
        response = await llm.ainvoke(
            CREATE_USER_PROFILE_TEMPLATE.format_messages(
                history=history,
                current_profile=state["user_profile"],
                user_message=state["user_message"],
                probing_answers=probing_answers,
            )
        )

        # Some models wrap the JSON in code formatting markers
        content = response.content.strip()
        if "```json" in content:
            content = content.split("```json")[1].split("```")[0].strip()
        elif "```" in content:
            content = content.split("```")[1].strip()

        try:
            profile_updates = json.loads(content)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse user profile: {e}")
            logger.error(f"Raw response: {response.content}")
            return state

        if isinstance(profile_updates, dict):
            # Make sure there is a dict to merge into
            if not state["user_profile"]:
                state["user_profile"] = {}
            state["user_profile"].update(profile_updates)
            logger.info(f"Updated user profile: {profile_updates}")
        else:
            logger.error(f"Invalid format for user profile: {profile_updates}")
        return state
    except Exception as e:
        logger.error(f"Error in update_user_profile: {e}")
        return state
async def update_system_prompt(state: State) -> State:
    """
    Update the system prompt based on the user's request and profile.

    The stripped LLM reply becomes both ``updated_system_prompt`` and
    ``final_system_prompt``. If the reply is empty after stripping, or the
    call fails, both fields are set to the current system prompt instead, so
    later steps never receive a blank prompt from this function.

    Args:
        state: Current state

    Returns:
        Updated state with new system prompt
    """
    try:
        # Prepare message history
        history = convert_to_langchain_messages(state["messages_history"])

        # Call the LLM
        response = await llm.ainvoke(
            UPDATE_SYSTEM_PROMPT_TEMPLATE.format_messages(
                history=history,
                current_system_prompt=state["current_system_prompt"],
                user_message=state["user_message"],
                user_profile=state["user_profile"],
                analysis_result=state["analysis_result"],
            )
        )

        candidate = response.content.strip()
        if candidate:
            new_system_prompt = candidate
            logger.info(f"Updated system prompt: {new_system_prompt}")
        else:
            # A blank reply would wipe the prompt; keep the current one
            new_system_prompt = state["current_system_prompt"]
            logger.warning("LLM returned an empty system prompt; keeping the current one")

        state["updated_system_prompt"] = new_system_prompt
        state["final_system_prompt"] = new_system_prompt  # Also set as final
        return state
    except Exception as e:
        logger.error(f"Error in update_system_prompt: {e}")
        state["updated_system_prompt"] = state["current_system_prompt"]  # Keep current
        state["final_system_prompt"] = state["current_system_prompt"]  # Keep current
        return state
async def generate_bot_response(state: State) -> State:
    """
    Generate the bot's response to the user's message.

    The system prompt is ``final_system_prompt`` when that is set to a
    non-empty value, otherwise ``current_system_prompt``. If anything fails,
    a fixed apology message is stored as the bot's response.

    Args:
        state: Current state

    Returns:
        Updated state with bot's response
    """
    try:
        # Prepare message history
        history = convert_to_langchain_messages(state["messages_history"])

        # Prefer the final prompt, but fall back when it is absent, None or
        # empty. `or` also avoids reading current_system_prompt unless needed.
        system_prompt = state.get("final_system_prompt") or state["current_system_prompt"]

        # Call the LLM
        response = await llm.ainvoke(
            RESPONSE_TEMPLATE.format_messages(
                history=history,
                system_prompt=system_prompt,
                user_message=state["user_message"],
            )
        )

        # Get the bot's response
        bot_message = response.content
        state["bot_message"] = bot_message
        logger.info(f"Generated bot response: {bot_message[:100]}...")
        return state
    except Exception as e:
        logger.error(f"Error in generate_bot_response: {e}")
        state["bot_message"] = "Xin lỗi, tôi đang gặp vấn đề kỹ thuật. Vui lòng thử lại sau."
        return state
async def process_return_value(state: State) -> State:
    """
    Record the finished exchange in the message history.

    When the state holds a non-empty bot message, the user's message is
    appended to ``messages_history`` (unless the last entry already is that
    same human message), followed by the bot's message. Without a bot message
    the state is returned untouched.

    Args:
        state: Current state

    Returns:
        Final state with processed return values
    """
    # Nothing to record when no bot reply was produced
    if not state.get("bot_message"):
        return state

    transcript = state["messages_history"]

    already_recorded = False
    if transcript:
        latest = transcript[-1]
        # Entries may be objects exposing .type/.content or plain dictionaries
        is_object = hasattr(latest, 'type')

        def read(field_name):
            return getattr(latest, field_name) if is_object else latest[field_name]

        already_recorded = (
            read("type") == "human" and read("content") == state["user_message"]
        )

    if not already_recorded:
        transcript.append({"content": state["user_message"], "type": "human"})

    # The bot's reply always goes last
    transcript.append({"content": state["bot_message"], "type": "ai"})
    return state
def trim_history(state: State, max_history: int = 20) -> State:
    """
    Trim the message history to avoid it getting too long.

    Only the most recent ``max_history`` messages are kept. A limit of zero
    or less empties the history. A state without a history (missing key or
    None) is returned unchanged.

    Args:
        state: Current state
        max_history: Maximum number of messages to keep

    Returns:
        State with trimmed history
    """
    history = state.get("messages_history") or []
    limit = max(max_history, 0)
    if len(history) > limit:
        # Slice from an explicit start index: history[-limit:] would keep
        # everything when limit is 0, because -0 == 0.
        state["messages_history"] = history[len(history) - limit:]
    return state