# backend/src/agents/interrupt_agent.py
# (author: anujjoshi3105; commit 22dcdfd, "initial")
import logging
from datetime import datetime
from typing import Any
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda, RunnableSerializable
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.store.base import BaseStore
from langgraph.types import interrupt
from pydantic import BaseModel, Field
from core import get_model, settings
# Module-level logger named after this module (used by the birthdate node for
# store access and extraction diagnostics).
logger: logging.Logger = logging.getLogger(__name__)
class AgentState(MessagesState, total=False):
    """Graph state: the chat ``messages`` inherited from ``MessagesState`` plus a birthdate.

    ``total=False`` (PEP 589) makes the keys declared in this class body optional,
    so ``birthdate`` may be absent from the state until a node sets it.
    Spec: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
    """

    # Set by the `determine_birthdate` node; absent/None until the birthdate is known.
    birthdate: datetime | None
def wrap_model(
    model: BaseChatModel | Runnable[LanguageModelInput, Any], system_prompt: BaseMessage
) -> RunnableSerializable[AgentState, Any]:
    """Compose a runnable that feeds agent state to ``model`` behind a system prompt.

    The returned pipeline receives an ``AgentState``, builds the message list
    ``[system_prompt, *state["messages"]]`` and passes it on to ``model``.
    """

    # Deliberately unannotated: RunnableLambda derives its input/output schema
    # from the callable's type hints, and the original lambda had none.
    def prepend_system_prompt(state):
        return [system_prompt] + state["messages"]

    state_modifier = RunnableLambda(prepend_system_prompt, name="StateModifier")
    return state_modifier | model
# System prompt for the `background` node. It deliberately asks only for a
# one-sentence history of the zodiac, not the user's own sign; answering with
# the user's sign is left to the later response step once a birthdate is known.
background_prompt = SystemMessagePromptTemplate.from_template("""
You are a helpful assistant that tells users their zodiac sign.
Provide a one sentence summary of the origin of zodiac signs.
Don't tell the user what their sign is, you are just demonstrating your knowledge on the topic.
""")
async def background(state: AgentState, config: RunnableConfig) -> AgentState:
    """Run one LLM turn *before* the interrupt-driven birthdate lookup.

    Demonstrates that a node can do work ahead of an ``interrupt()``. The model
    answers under ``background_prompt``, and only the reply's text is returned
    as a new ``AIMessage`` state update.
    """
    model_name = config["configurable"].get("model", settings.DEFAULT_MODEL)
    chain = wrap_model(get_model(model_name), background_prompt.format())
    reply = await chain.ainvoke(state, config)
    return {"messages": [AIMessage(content=reply.content)]}
# System prompt for the structured-output extractor used by `determine_birthdate`.
# It has no template variables (it is rendered with a bare `.format()`); the
# output shape is defined by the `BirthdateExtraction` schema.
# NOTE(review): the model is never given today's date, so the "not in the
# future" rule depends on the model's own notion of the current date.
birthdate_extraction_prompt = SystemMessagePromptTemplate.from_template("""
You are an expert at extracting birthdates from conversational text.
Rules for extraction:
- Look for user messages that mention birthdates
- Consider various date formats (MM/DD/YYYY, YYYY-MM-DD, Month Day, Year)
- Validate that the date is reasonable (not in the future)
- If no clear birthdate was provided by the user, return None
""")
class BirthdateExtraction(BaseModel):
    # Structured-output schema for the birthdate extractor (passed to
    # `with_structured_output`). Documented with `#` comments on purpose: a class
    # docstring would be emitted as the schema description and change what the
    # model is sent.

    # ISO date string (YYYY-MM-DD) or None; callers parse it with
    # `datetime.fromisoformat` and re-prompt the user when it is None or unparseable.
    birthdate: str | None = Field(
        description="The extracted birthdate in YYYY-MM-DD format. If no birthdate is found, this should be None."
    )
    # Free-text rationale; shown to the user as part of the interrupt question
    # when no birthdate was found.
    reasoning: str = Field(
        description="Explanation of how the birthdate was extracted or why no birthdate was found"
    )
async def _load_stored_birthdate(
    store: BaseStore, namespace: tuple[str, ...], key: str
) -> datetime | None:
    """Return the birthdate persisted under ``namespace``/``key``, or None if missing or unreadable."""
    try:
        item = await store.aget(namespace, key=key)
        # BaseStore.aget returns an Item or None; tolerate stores that return a list.
        if isinstance(item, list):
            item = item[0] if item else None
        stored = item.value.get("birthdate") if item else None
        # The value is saved as an ISO-8601 string (see _save_birthdate).
        return datetime.fromisoformat(stored) if stored else None
    except Exception:
        # Best effort: a broken/unavailable store must not block extraction.
        logger.exception("Error reading from store for namespace %s, key %s", namespace, key)
        return None


async def _save_birthdate(
    store: BaseStore, namespace: tuple[str, ...], key: str, birthdate: datetime
) -> None:
    """Persist ``birthdate`` as a JSON-serializable ISO string; write failures are logged, not raised."""
    try:
        await store.aput(namespace, key, {"birthdate": birthdate.isoformat()})
    except Exception:
        logger.exception("Error writing to store for namespace %s, key %s", namespace, key)


async def _extract_birthdate(
    state: AgentState, config: RunnableConfig
) -> tuple[datetime, list[BaseMessage]]:
    """Extract the birthdate from the conversation, interrupting for human input until one parses.

    Returns the parsed birthdate and the ``HumanMessage`` replies collected via
    ``interrupt()`` (empty if the conversation already contained a usable date).
    """
    m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
    model_runnable = wrap_model(
        m.with_structured_output(BirthdateExtraction), birthdate_extraction_prompt.format()
    ).with_config(tags=["skip_stream"])

    replies: list[BaseMessage] = []
    # Iterative rather than recursive: each pass issues at most one interrupt, in
    # the same order as before, so resumed runs replay interrupts consistently.
    while True:
        # Build a new message list instead of mutating the graph's state in place.
        response: BirthdateExtraction = await model_runnable.ainvoke(
            {**state, "messages": [*state["messages"], *replies]}, config
        )
        if response.birthdate is None:
            question = f"{response.reasoning}\nPlease tell me your birthdate?"
        else:
            try:
                return datetime.fromisoformat(response.birthdate), replies
            except ValueError:
                question = (
                    "I couldn't understand the date format. "
                    "Please provide your birthdate in YYYY-MM-DD format."
                )
        # interrupt() pauses the graph here; on resume it returns the human's answer.
        replies.append(HumanMessage(content=interrupt(question)))


async def determine_birthdate(
    state: AgentState, config: RunnableConfig, store: BaseStore
) -> AgentState:
    """Determine the user's birthdate, preferring the per-user store over extraction.

    With a ``user_id`` in ``config["configurable"]`` the store (namespace
    ``(user_id,)``, key ``"birthdate"``) is consulted first and updated after a
    successful extraction. Otherwise the birthdate is extracted from the
    conversation, interrupting the graph to ask the user as often as needed.

    Returns:
        A state update with ``birthdate`` set and any human replies gathered
        through interrupts, so they are recorded in the conversation history.
    """
    user_id = config["configurable"].get("user_id")
    logger.info("[determine_birthdate] Extracted user_id: %s", user_id)
    key = "birthdate"
    namespace = (user_id,) if user_id else None

    if namespace is None:
        # Without a user_id there is no safe per-user key to read or write.
        logger.warning(
            "Warning: user_id not found in config. Skipping persistent birthdate storage/retrieval for this run."
        )
    else:
        stored = await _load_stored_birthdate(store, namespace, key)
        if stored is not None:
            logger.info(
                "[determine_birthdate] Found birthdate in store for user %s: %s", user_id, stored
            )
            return {"birthdate": stored, "messages": []}

    birthdate, replies = await _extract_birthdate(state, config)

    if namespace is not None:
        await _save_birthdate(store, namespace, key, birthdate)

    logger.info("[determine_birthdate] Returning birthdate %s for user %s", birthdate, user_id)
    return {"birthdate": birthdate, "messages": replies}
# System prompt for `generate_response`. Rendered with
# `response_prompt.format(birthdate_str=..., last_user_message=...)`; the
# conversation itself is passed to the model separately, after this prompt.
response_prompt = SystemMessagePromptTemplate.from_template("""
You are a helpful assistant.
Known information:
- The user's birthdate is {birthdate_str}
User's latest message: "{last_user_message}"
Based on the known information and the user's message, provide a helpful and relevant response.
If the user asked for their birthdate, confirm it.
If the user asked for their zodiac sign, calculate it and tell them.
Otherwise, respond conversationally based on their message.
""")
async def generate_response(state: AgentState, config: RunnableConfig) -> AgentState:
    """Answer the user's latest message using the birthdate held in state.

    Returns a state update containing one new ``AIMessage``. If no birthdate is
    present, it asks the user for it instead of calling the model.
    """
    birthdate = state.get("birthdate")
    if not birthdate:
        # Defensive fallback: determine_birthdate interrupts until it has a
        # birthdate, so this should normally be unreachable.
        return {
            "messages": [
                AIMessage(
                    content="I couldn't determine your birthdate. Could you please provide it?"
                )
            ]
        }

    # The newest message is usually the AIMessage added by the `background` node,
    # not the user's turn, so checking only messages[-1] would almost always give "".
    # Scan backwards for the most recent HumanMessage instead.
    last_user_message = next(
        (
            message.content
            for message in reversed(state.get("messages") or [])
            if isinstance(message, HumanMessage)
        ),
        "",
    )

    birthdate_str = birthdate.strftime("%B %d, %Y")  # e.g. "March 03, 1990"
    m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
    model_runnable = wrap_model(
        m, response_prompt.format(birthdate_str=birthdate_str, last_user_message=last_user_message)
    )
    response = await model_runnable.ainvoke(state, config)
    return {"messages": [AIMessage(content=response.content)]}
# Graph wiring: a linear pipeline
#   background -> determine_birthdate -> generate_response -> END
_NODES = {
    "background": background,
    "determine_birthdate": determine_birthdate,
    "generate_response": generate_response,
}

agent = StateGraph(AgentState)
for _node_name, _node_fn in _NODES.items():
    agent.add_node(_node_name, _node_fn)

agent.set_entry_point("background")
for _src, _dst in (
    ("background", "determine_birthdate"),
    ("determine_birthdate", "generate_response"),
    ("generate_response", END),
):
    agent.add_edge(_src, _dst)

interrupt_agent = agent.compile()
interrupt_agent.name = "interrupt-agent"