adilpanwar's picture
Upload 16 files
927c050 verified
Raw
History Blame Contribute Delete
8.53 kB
"""Node functions for the multi-agent graph."""
import logging
from typing import Optional
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.store.base import BaseStore
from langgraph.types import interrupt
from src.state import State
from src.models import UserInput, UserProfile
from src.agents.prompts import (
generate_music_assistant_prompt,
STRUCTURED_EXTRACTION_PROMPT,
VERIFICATION_PROMPT,
CREATE_MEMORY_PROMPT,
)
from src.db.database import get_engine, normalize_phone
logger = logging.getLogger(__name__)
def get_customer_id_from_identifier(identifier: str) -> Optional[int]:
    """Resolve a user-supplied identifier to a ``CustomerId``.

    The identifier is tried, in order, as:

    1. an email address (only when it contains ``@``), compared
       case-insensitively against ``Customer.Email``;
    2. a numeric customer ID (only when it consists solely of decimal
       digits);
    3. a phone number, compared against every non-null ``Customer.Phone``
       after both sides have gone through ``normalize_phone``.

    Args:
        identifier: Raw identifier text. Surrounding whitespace is ignored.

    Returns:
        The first matching customer ID, or ``None`` when the identifier is
        empty or blank, nothing matches, or the lookup raises (the error is
        logged instead of being propagated).
    """
    if not identifier or not identifier.strip():
        return None
    identifier = identifier.strip()
    engine = get_engine()
    try:
        from sqlalchemy import text

        if "@" in identifier:
            with engine.connect() as conn:
                result = conn.execute(
                    text("SELECT CustomerId FROM Customer WHERE LOWER(Email) = LOWER(:email)"),
                    {"email": identifier},
                )
                row = result.fetchone()
                if row:
                    return int(row[0])

        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscript digits, which int() rejects. The ValueError
        # would land in the handler below and skip the phone fallback.
        if identifier.isdecimal():
            with engine.connect() as conn:
                result = conn.execute(
                    text("SELECT CustomerId FROM Customer WHERE CustomerId = :cid"),
                    {"cid": int(identifier)},
                )
                row = result.fetchone()
                if row:
                    return int(row[0])

        normalized_input = normalize_phone(identifier)
        # Normalized values shorter than 5 characters are never looked up.
        if normalized_input and len(normalized_input) >= 5:
            with engine.connect() as conn:
                result = conn.execute(
                    text("SELECT CustomerId, Phone FROM Customer WHERE Phone IS NOT NULL")
                )
                # Stored numbers are normalized row by row so that both sides
                # of the comparison use the same format.
                for row in result:
                    db_phone_normalized = normalize_phone(str(row[1]))
                    if db_phone_normalized == normalized_input:
                        return int(row[0])
    except Exception as e:
        logger.error(f"Error looking up customer by identifier '{identifier}': {e}")
    return None
def format_user_memory(user_data: dict) -> str:
    """Render a stored memory record as a one-line preference summary.

    Args:
        user_data: Mapping whose ``"memory"`` entry is expected to be an
            object exposing a ``music_preferences`` sequence of strings.

    Returns:
        ``"Music Preferences: a, b, ..."`` when preferences are present;
        otherwise an empty string. Any error while reading or joining the
        preferences is logged and also yields an empty string.
    """
    summary = ""
    try:
        profile = user_data.get("memory")
        preferences = getattr(profile, "music_preferences", None) if profile else None
        if preferences:
            summary = f"Music Preferences: {', '.join(preferences)}"
    except Exception as e:
        logger.error(f"Error formatting user memory: {e}")
    return summary
def create_music_assistant_node(llm, music_tools):
    """Build the music-assistant graph node.

    Args:
        llm: Chat model; ``bind_tools`` is called on it once, up front.
        music_tools: Tools handed to ``llm.bind_tools``.

    Returns:
        A ``music_assistant(state, config)`` node function.
    """
    tool_enabled_llm = llm.bind_tools(music_tools)

    def music_assistant(state: State, config: RunnableConfig):
        """Invoke the tool-bound model on the conversation and return its reply.

        The model input is the generated system prompt, then (when the state
        carries a truthy ``customer_id``) a system note with that ID, then
        the conversation messages from the state.
        """
        # A missing or empty memory is passed to the prompt as the literal "None".
        loaded = state.get("loaded_memory", "None")
        system_prompt = generate_music_assistant_prompt(loaded if loaded else "None")

        prompt_messages = [SystemMessage(content=system_prompt)]
        if state.get("customer_id"):
            id_note = f"The current verified customer ID is: {state['customer_id']}"
            prompt_messages.append(SystemMessage(content=id_note))
        prompt_messages.extend(state["messages"])

        logger.info(f"Music assistant invoked with {len(state['messages'])} conversation messages")
        reply = tool_enabled_llm.invoke(prompt_messages)
        return {"messages": [reply]}

    return music_assistant
def should_continue(state: State, config: RunnableConfig) -> str:
    """Route on whether the latest message requests tool calls.

    Args:
        state: Graph state; ``state["messages"]`` is inspected.
        config: Runnable config (unused).

    Returns:
        ``"continue"`` when the last message carries a non-empty
        ``tool_calls``; ``"end"`` otherwise, including when there are no
        messages or the last message has no ``tool_calls`` attribute.
    """
    messages = state["messages"]
    if not messages:
        return "end"
    # getattr with a default: a last message that has no ``tool_calls``
    # attribute should end the turn rather than raise AttributeError.
    if getattr(messages[-1], "tool_calls", None):
        return "continue"
    return "end"
def should_interrupt(state: State, config: RunnableConfig) -> str:
    """Route on whether the state already holds a customer ID.

    Returns:
        ``"interrupt"`` when ``customer_id`` is missing or ``None``;
        ``"continue"`` for any other value.
    """
    return "interrupt" if state.get("customer_id") is None else "continue"
def create_verify_info_node(llm):
    """Build the customer-verification graph node.

    Args:
        llm: Chat model. A structured-output variant (schema ``UserInput``)
            is derived from it once; the plain model is used for the
            follow-up reply when verification fails.

    Returns:
        A ``verify_info(state, config)`` node function.
    """
    extractor = llm.with_structured_output(schema=UserInput)

    def verify_info(state: State, config: RunnableConfig):
        """Try to verify the customer from the most recent message.

        Returns an empty update when a customer ID is already in the state.
        On a successful lookup, returns the ID (as a string) plus a system
        message announcing it. Otherwise returns the model's reply to the
        verification prompt.
        """
        if state.get("customer_id") is not None:
            logger.info(f"Customer already verified: {state['customer_id']}")
            return {}

        latest_message = state["messages"][-1]
        logger.info(f"Verification attempt with message: {getattr(latest_message, 'content', '')[:100]}")

        # An extraction failure is treated the same as "no identifier given".
        try:
            extraction = extractor.invoke(
                [SystemMessage(content=STRUCTURED_EXTRACTION_PROMPT), latest_message]
            )
            identifier = extraction.identifier
            logger.info(f"Extracted identifier: '{identifier}'")
        except Exception as e:
            logger.error(f"Error parsing user input for verification: {e}")
            identifier = ""

        customer_id = None
        if identifier:
            customer_id = get_customer_id_from_identifier(identifier)
            logger.info(f"DB lookup result: customer_id={customer_id}")

        if customer_id is None:
            # Not verified: let the model respond using the verification prompt.
            follow_up = llm.invoke(
                [SystemMessage(content=VERIFICATION_PROMPT), *state["messages"]]
            )
            return {"messages": [follow_up]}

        confirmation = SystemMessage(
            content=(
                f"Customer verified successfully. "
                f"The verified customer_id is {customer_id}. "
                f"Use this customer_id for all invoice and purchase lookups."
            )
        )
        return {
            "customer_id": str(customer_id),
            "messages": [confirmation],
        }

    return verify_info
def human_input(state: State, config: RunnableConfig):
    """Call ``interrupt`` and wrap the value it returns in a ``HumanMessage``.

    Returns:
        A state update whose ``messages`` list holds that single message.
        (Presumably the value is the user's reply supplied on resume; confirm
        against the graph's resume logic.)
    """
    reply = interrupt("Please provide input.")
    human_message = HumanMessage(content=reply)
    return {"messages": [human_message]}
def load_memory(state: State, config: RunnableConfig, store: BaseStore):
    """Load the stored preference summary for the current customer.

    Args:
        state: Graph state; ``customer_id`` selects the memory namespace.
        config: Runnable config (unused).
        store: Store queried under ``("memory_profile", <customer id>)`` with
            key ``"user_memory"``.

    Returns:
        ``{"loaded_memory": <summary>}`` where the summary comes from
        ``format_user_memory``. The summary is an empty string when there is
        no customer ID, no stored memory, or the store lookup fails (the
        error is logged).
    """
    customer_id = state.get("customer_id")
    # Check for None before converting: str(None) is the truthy string
    # "None", which would otherwise be used as a real customer namespace.
    user_id = "" if customer_id is None else str(customer_id)
    if not user_id:
        return {"loaded_memory": ""}
    namespace = ("memory_profile", user_id)
    try:
        existing_memory = store.get(namespace, "user_memory")
        if existing_memory and existing_memory.value:
            formatted = format_user_memory(existing_memory.value)
            logger.info(f"Loaded memory for customer {user_id}: {formatted}")
            return {"loaded_memory": formatted}
    except Exception as e:
        logger.error(f"Error loading memory for user {user_id}: {e}")
    return {"loaded_memory": ""}
def create_memory_node(llm):
    """Build the graph node that saves a customer's music preferences.

    Args:
        llm: Chat model; ``with_structured_output(UserProfile)`` is called on
            it each time the node runs.

    Returns:
        A ``create_memory(state, config, store)`` node function.
    """

    def create_memory(state: State, config: RunnableConfig, store: BaseStore):
        """Merge preferences found in recent messages into the stored profile.

        Reads the existing profile from ``("memory_profile", <customer id>)``,
        asks the model for an updated ``UserProfile`` based on the last ten
        messages, and writes back the union of old and new preferences.
        Nothing is written when there is no customer ID, or when the model
        returns no preferences while some are already stored. Errors are
        logged, not raised. Always returns an empty state update.
        """
        customer_id = state.get("customer_id")
        # Check for None before converting: str(None) is the truthy string
        # "None", which would otherwise create a shared "None" profile.
        user_id = "" if customer_id is None else str(customer_id)
        if not user_id:
            return {}
        namespace = ("memory_profile", user_id)
        try:
            existing_preferences = []
            formatted_memory = ""
            existing_memory = store.get(namespace, "user_memory")
            if existing_memory and existing_memory.value:
                profile = existing_memory.value.get("memory")
                if profile and hasattr(profile, "music_preferences"):
                    existing_preferences = list(profile.music_preferences or [])
                    formatted_memory = f"Music Preferences: {', '.join(existing_preferences)}"

            # Only messages with non-empty content, from the last ten, are summarized.
            recent_messages = state["messages"][-10:]
            conversation_summary = "\n".join(
                f"{getattr(msg, 'type', 'unknown')}: {getattr(msg, 'content', '')}"
                for msg in recent_messages
                if getattr(msg, "content", "")
            )
            formatted_prompt = CREATE_MEMORY_PROMPT.format(
                conversation=conversation_summary,
                memory_profile=formatted_memory or "Empty, no existing profile",
            )
            updated_memory = llm.with_structured_output(UserProfile).invoke(
                [SystemMessage(content=formatted_prompt)]
            )

            new_prefs = updated_memory.music_preferences or []
            if not new_prefs and existing_preferences:
                logger.info(f"Memory unchanged for customer {user_id} (preserving existing preferences)")
                return {}

            # dict.fromkeys removes duplicates and keeps first-seen order, so
            # the stored list is stable from run to run (a set is not).
            merged_prefs = list(dict.fromkeys([*existing_preferences, *new_prefs]))
            updated_memory.music_preferences = merged_prefs
            updated_memory.customer_id = user_id
            store.put(namespace, "user_memory", {"memory": updated_memory})
            logger.info(f"Memory updated for customer {user_id}: {merged_prefs}")
        except Exception as e:
            logger.error(f"Error creating/updating memory for user {user_id}: {e}")
        return {}

    return create_memory