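"""
Node functions for a prompt-adapting chat pipeline: they analyze the user's
request, optionally generate probing questions, maintain a user profile,
rewrite the system prompt, and produce the bot's reply. Each node takes the
shared State and returns it updated.
"""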
import json
import re
import uuid
from typing import Dict, List, Any, Optional, Tuple
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from src.utils.logger import logger
from .data import State, UserProfile, convert_to_langchain_messages
from .prompt import (
    DEFAULT_SYSTEM_PROMPT,
    ANALYZE_REQUEST_TEMPLATE,
    GENERATE_PROBING_QUESTIONS_TEMPLATE,
    UPDATE_SYSTEM_PROMPT_TEMPLATE,
    RESPONSE_TEMPLATE,
    CREATE_USER_PROFILE_TEMPLATE,
)

# Initialize LLM - use whatever is available in your environment
try:
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash",
        temperature=0.2,
        verbose=True,
    )
except Exception:
    try:
        llm = ChatOpenAI(
            model="gpt-3.5-turbo",
            temperature=0.2,
            verbose=True,
        )
    except Exception as e:
        logger.error(f"Failed to initialize LLM: {e}")
        raise

def initialize_state(state: State) -> State:
    """
    Initialize the state with default values if they are not present.

    Args:
        state: Current state

    Returns:
        Updated state with default values
    """
    if not state.get("session_id"):
        state["session_id"] = str(uuid.uuid4())
    if not state.get("messages_history"):
        state["messages_history"] = []
    if not state.get("current_system_prompt"):
        state["current_system_prompt"] = DEFAULT_SYSTEM_PROMPT
    if not state.get("user_profile"):
        state["user_profile"] = {}
    if "messages" not in state:
        state["messages"] = []
    return state
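
# Illustrative usage (assumes State behaves like a plain dict at runtime,
# which holds for a TypedDict-based State; fields beyond those shown may exist):
#
#     state = initialize_state({"user_message": "Hi"})
#     assert state["current_system_prompt"] == DEFAULT_SYSTEM_PROMPT
#     assert state["messages_history"] == []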

async def analyze_user_request(state: State) -> State:
    """
    Analyze the user's request to determine intent and whether we need to update the prompt.

    Args:
        state: Current state

    Returns:
        Updated state with analysis results
    """
    try:
        # Prepare message history
        history = convert_to_langchain_messages(state["messages_history"])

        # Build the prompt
        prompt = ANALYZE_REQUEST_TEMPLATE

        # Call the LLM
        response = await llm.ainvoke(
            prompt.format_messages(
                history=history,
                current_system_prompt=state["current_system_prompt"],
                user_message=state["user_message"],
            )
        )

        # Parse the JSON response
        try:
            # Clean up the response content to handle potential formatting issues
            content = response.content.strip()

            # Some models might return the JSON with code formatting markers
            if "```json" in content:
                content = content.split("```json")[1].split("```")[0].strip()
            elif "```" in content:
                content = content.split("```")[1].strip()

            # Parse the JSON
            analysis_result = json.loads(content)
            state["analysis_result"] = analysis_result
            state["prompt_needs_update"] = analysis_result.get("prompt_needs_update", False)
            state["probing_questions_needed"] = analysis_result.get("probing_questions_needed", False)
            logger.info(f"Analysis complete: {analysis_result}")
            return state
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse analysis result: {e}")
            logger.error(f"Raw response: {response.content}")

            # Try to extract a JSON object from the response if it exists
            try:
                json_match = re.search(r"({[\s\S]*})", response.content)
                if json_match:
                    json_str = json_match.group(1)
                    analysis_result = json.loads(json_str)
                    state["analysis_result"] = analysis_result
                    state["prompt_needs_update"] = analysis_result.get("prompt_needs_update", False)
                    state["probing_questions_needed"] = analysis_result.get("probing_questions_needed", False)
                    logger.info(f"Successfully extracted JSON with regex: {analysis_result}")
                    return state
            except Exception:
                # If regex extraction fails, use the default fallback
                pass

            # Fallback analysis result
            state["analysis_result"] = {
                "intent": "unknown",
                "keywords": [],
                "prompt_needs_update": True,
                "probing_questions_needed": False,
                "confidence": 0.0,
                "reasoning": "Failed to parse the analysis result",
            }
            return state
    except Exception as e:
        logger.error(f"Error in analyze_user_request: {e}")
        state["analysis_result"] = {
            "intent": "unknown",
            "keywords": [],
            "prompt_needs_update": True,
            "probing_questions_needed": False,
            "confidence": 0.0,
            "reasoning": f"Error during analysis: {str(e)}",
        }
        return state
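
# The fence-stripping and regex fallback above are repeated parsing concerns.
# Below is a minimal consolidation sketch; the helper name `_extract_json` is
# hypothetical and the nodes in this module do not currently call it.
def _extract_json(raw: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of a JSON object from an LLM response string."""
    content = raw.strip()
    # Strip Markdown code fences if the model wrapped its answer in them
    if "```json" in content:
        content = content.split("```json")[1].split("```")[0].strip()
    elif "```" in content:
        content = content.split("```")[1].strip()
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        # Fall back to grabbing the outermost {...} span, if any
        match = re.search(r"({[\s\S]*})", content)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                return None
        return None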

async def generate_probing_questions(state: State) -> State:
    """
    Generate probing questions to better understand the user.

    Args:
        state: Current state

    Returns:
        Updated state with probing questions
    """
    # Fallback question, in Vietnamese: "Could you share more about your needs?"
    fallback_question = "Bạn có thể chia sẻ thêm về nhu cầu của bạn không?"
    try:
        # Prepare message history
        history = convert_to_langchain_messages(state["messages_history"])

        # Build the prompt
        prompt = GENERATE_PROBING_QUESTIONS_TEMPLATE

        # Call the LLM
        response = await llm.ainvoke(
            prompt.format_messages(
                history=history,
                user_message=state["user_message"],
                analysis_result=state["analysis_result"],
            )
        )

        # Parse the JSON response
        try:
            questions = json.loads(response.content)
            if isinstance(questions, list):
                state["probing_questions"] = questions
                logger.info(f"Generated probing questions: {questions}")
            else:
                logger.error(f"Invalid format for probing questions: {questions}")
                state["probing_questions"] = [fallback_question]
            return state
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse probing questions: {e}")
            logger.error(f"Raw response: {response.content}")
            state["probing_questions"] = [fallback_question]
            return state
    except Exception as e:
        logger.error(f"Error in generate_probing_questions: {e}")
        state["probing_questions"] = [fallback_question]
        return state

async def update_user_profile(state: State) -> State:
    """
    Update the user profile based on the current conversation.

    Args:
        state: Current state

    Returns:
        Updated state with updated user profile
    """
    try:
        # Prepare message history
        history = convert_to_langchain_messages(state["messages_history"])

        # Build the prompt
        prompt = CREATE_USER_PROFILE_TEMPLATE

        # Prepare probing answers (if any); the default is Vietnamese for "none"
        probing_answers = "Không có"

        # Call the LLM
        response = await llm.ainvoke(
            prompt.format_messages(
                history=history,
                current_profile=state["user_profile"],
                user_message=state["user_message"],
                probing_answers=probing_answers,
            )
        )

        # Parse the JSON response
        try:
            profile_updates = json.loads(response.content)
            if isinstance(profile_updates, dict):
                # Merge the updates into the existing profile
                if not state["user_profile"]:
                    state["user_profile"] = {}
                state["user_profile"].update(profile_updates)
                logger.info(f"Updated user profile: {profile_updates}")
            else:
                logger.error(f"Invalid format for user profile: {profile_updates}")
            return state
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse user profile: {e}")
            logger.error(f"Raw response: {response.content}")
            return state
    except Exception as e:
        logger.error(f"Error in update_user_profile: {e}")
        return state

async def update_system_prompt(state: State) -> State:
    """
    Update the system prompt based on the user's request and profile.

    Args:
        state: Current state

    Returns:
        Updated state with new system prompt
    """
    try:
        # Prepare message history
        history = convert_to_langchain_messages(state["messages_history"])

        # Build the prompt
        prompt = UPDATE_SYSTEM_PROMPT_TEMPLATE

        # Call the LLM
        response = await llm.ainvoke(
            prompt.format_messages(
                history=history,
                current_system_prompt=state["current_system_prompt"],
                user_message=state["user_message"],
                user_profile=state["user_profile"],
                analysis_result=state["analysis_result"],
            )
        )

        # Store the new system prompt
        new_system_prompt = response.content.strip()
        state["updated_system_prompt"] = new_system_prompt
        state["final_system_prompt"] = new_system_prompt  # Also set as final
        logger.info(f"Updated system prompt: {new_system_prompt}")
        return state
    except Exception as e:
        logger.error(f"Error in update_system_prompt: {e}")
        # Keep the current prompt on failure
        state["updated_system_prompt"] = state["current_system_prompt"]
        state["final_system_prompt"] = state["current_system_prompt"]
        return state

async def generate_bot_response(state: State) -> State:
    """
    Generate the bot's response to the user's message.

    Args:
        state: Current state

    Returns:
        Updated state with bot's response
    """
    try:
        # Prepare message history
        history = convert_to_langchain_messages(state["messages_history"])

        # Determine which system prompt to use
        system_prompt = state.get("final_system_prompt", state["current_system_prompt"])

        # Build the prompt
        prompt = RESPONSE_TEMPLATE

        # Call the LLM
        response = await llm.ainvoke(
            prompt.format_messages(
                history=history,
                system_prompt=system_prompt,
                user_message=state["user_message"],
            )
        )

        # Store the bot's response
        bot_message = response.content
        state["bot_message"] = bot_message
        logger.info(f"Generated bot response: {bot_message[:100]}...")
        return state
    except Exception as e:
        logger.error(f"Error in generate_bot_response: {e}")
        # Fallback, in Vietnamese: "Sorry, I'm having technical issues. Please try again later."
        state["bot_message"] = "Xin lỗi, tôi đang gặp vấn đề kỹ thuật. Vui lòng thử lại sau."
        return state

async def process_return_value(state: State) -> State:
    """
    Process the final state to prepare the return value.

    Args:
        state: Current state

    Returns:
        Final state with processed return values
    """
    # Update the message history with the new messages
    if state.get("bot_message"):
        # Add the user message to history if it is not already the last entry
        user_msg_in_history = False
        if state["messages_history"]:
            last_msg = state["messages_history"][-1]
            # last_msg may be a message object (Pydantic model) or a plain dict
            if hasattr(last_msg, "type"):
                user_msg_in_history = (
                    last_msg.type == "human" and last_msg.content == state["user_message"]
                )
            else:
                user_msg_in_history = (
                    last_msg["type"] == "human" and last_msg["content"] == state["user_message"]
                )
        if not user_msg_in_history:
            state["messages_history"].append({
                "content": state["user_message"],
                "type": "human",
            })

        # Add the bot message to history
        state["messages_history"].append({
            "content": state["bot_message"],
            "type": "ai",
        })
    return state

def trim_history(state: State, max_history: int = 20) -> State:
    """
    Trim the message history to avoid it getting too long.

    Args:
        state: Current state
        max_history: Maximum number of messages to keep

    Returns:
        State with trimmed history
    """
    if len(state["messages_history"]) > max_history:
        # Keep only the last max_history messages
        state["messages_history"] = state["messages_history"][-max_history:]
    return state