|
|
import logging |
|
|
from datetime import datetime |
|
|
from typing import Any |
|
|
|
|
|
from langchain_core.language_models.base import LanguageModelInput |
|
|
from langchain_core.language_models.chat_models import BaseChatModel |
|
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage |
|
|
from langchain_core.prompts import SystemMessagePromptTemplate |
|
|
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda, RunnableSerializable |
|
|
from langgraph.graph import END, MessagesState, StateGraph |
|
|
from langgraph.store.base import BaseStore |
|
|
from langgraph.types import interrupt |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
from core import get_model, settings |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class AgentState(MessagesState, total=False):
    """Graph state for the interrupt agent: the inherited ``messages`` plus an optional birthdate.

    ``total=False`` (PEP 589) marks the keys declared in this class body (here
    ``birthdate``) as not required, so the key may be absent from the state until a
    node such as ``determine_birthdate`` sets it.

    documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
    """

    # Written by ``determine_birthdate`` and read by ``generate_response``; may be
    # missing or None when no birthdate has been determined.
    birthdate: datetime | None
|
|
|
|
|
|
|
|
def wrap_model(
    model: BaseChatModel | Runnable[LanguageModelInput, Any], system_prompt: BaseMessage
) -> RunnableSerializable[AgentState, Any]:
    """Build a runnable that feeds ``model`` the conversation with ``system_prompt`` first.

    The returned pipeline takes an ``AgentState``-like mapping, builds a new list made of
    ``system_prompt`` followed by ``state["messages"]``, and passes it to ``model``.
    The preprocessing step is named "StateModifier" so it shows up under that name in traces.
    """

    def _prepend_system_prompt(state):
        # List concatenation builds a new list; the state's own message list is untouched.
        return [system_prompt] + state["messages"]

    return RunnableLambda(_prepend_system_prompt, name="StateModifier") | model
|
|
|
|
|
|
|
|
# System prompt for the ``background`` node: has the model do some work (a one-sentence
# zodiac history) before ``determine_birthdate`` can interrupt. Fixed typo: "there" -> "their".
background_prompt = SystemMessagePromptTemplate.from_template("""


You are a helpful assistant that tells users their zodiac sign.


Provide a one sentence summary of the origin of zodiac signs.


Don't tell the user what their sign is, you are just demonstrating your knowledge on the topic.


""")
|
|
|
|
|
|
|
|
async def background(state: AgentState, config: RunnableConfig) -> AgentState:
    """Do some model work before the interrupt in the next node.

    Runs the configured model (``config["configurable"]["model"]``, falling back to
    ``settings.DEFAULT_MODEL``) with ``background_prompt`` prepended to the conversation.
    The reply is returned as one new ``AIMessage``.
    """
    model_name = config["configurable"].get("model", settings.DEFAULT_MODEL)
    chain = wrap_model(get_model(model_name), background_prompt.format())
    reply = await chain.ainvoke(state, config)

    # Only the reply's content is kept; it is re-wrapped in a fresh AIMessage.
    return {"messages": [AIMessage(content=reply.content)]}
|
|
|
|
|
|
|
|
# System prompt for the structured ``BirthdateExtraction`` call in ``determine_birthdate``.
# NOTE: the rules below (including "not in the future") are instructions to the model;
# this prompt object does not enforce any of them itself.
birthdate_extraction_prompt = SystemMessagePromptTemplate.from_template("""


You are an expert at extracting birthdates from conversational text.





Rules for extraction:


- Look for user messages that mention birthdates


- Consider various date formats (MM/DD/YYYY, YYYY-MM-DD, Month Day, Year)


- Validate that the date is reasonable (not in the future)


- If no clear birthdate was provided by the user, return None


""")
|
|
|
|
|
|
|
|
# Structured-output schema for the extraction model call in ``determine_birthdate``
# (passed to ``with_structured_output``).
# NOTE(review): no class docstring is added on purpose — pydantic would emit it as the
# JSON-schema description, changing what is sent to the model.
class BirthdateExtraction(BaseModel):
    # Expected as an ISO "YYYY-MM-DD" string; the caller parses it with
    # ``datetime.fromisoformat`` and treats None as "no birthdate found".
    birthdate: str | None = Field(
        description="The extracted birthdate in YYYY-MM-DD format. If no birthdate is found, this should be None."
    )
    # Free-text explanation; the caller includes it in the interrupt prompt when
    # ``birthdate`` is None.
    reasoning: str = Field(
        description="Explanation of how the birthdate was extracted or why no birthdate was found"
    )
|
|
|
|
|
|
|
|
async def _load_stored_birthdate(
    store: BaseStore, namespace: tuple[Any, ...], key: str, user_id: Any
) -> datetime | None:
    """Return the birthdate persisted under ``namespace``/``key``, or None if none is usable.

    Read errors and unparseable stored values are logged and treated as "not stored", so
    the caller can fall back to extracting the birthdate from the conversation.
    """
    try:
        # BaseStore.aget returns a single Item or None.
        item = await store.aget(namespace, key=key)
        stored = item.value.get("birthdate") if item else None
        if not stored:
            return None
        birthdate = datetime.fromisoformat(stored)
    except Exception as e:
        # Best-effort: a store problem must not stop the conversation.
        logger.exception(
            "Error reading from store for namespace %s, key %s: %s", namespace, key, e
        )
        return None

    logger.info(
        "[determine_birthdate] Found birthdate in store for user %s: %s", user_id, birthdate
    )
    return birthdate


async def _save_birthdate(
    store: BaseStore, namespace: tuple[Any, ...], key: str, birthdate: datetime
) -> None:
    """Persist ``birthdate`` as an ISO-8601 string; failures are logged, not raised."""
    try:
        await store.aput(namespace, key, {"birthdate": birthdate.isoformat()})
    except Exception as e:
        logger.exception(
            "Error writing to store for namespace %s, key %s: %s", namespace, key, e
        )


async def determine_birthdate(
    state: AgentState, config: RunnableConfig, store: BaseStore
) -> AgentState:
    """Determine the user's birthdate, checking the store before the conversation.

    1. If ``config["configurable"]["user_id"]`` is set and the store holds a readable
       birthdate for that user, return it.
    2. Otherwise ask the configured model to extract a birthdate from the conversation.
       Keep asking the user via ``interrupt`` until a parseable, non-future date is
       obtained.
    3. Persist the result for the user (when ``user_id`` is set) and return it.

    Returns a state update with ``birthdate`` set and no new messages.
    """
    user_id = config["configurable"].get("user_id")
    logger.info("[determine_birthdate] Extracted user_id: %s", user_id)
    key = "birthdate"
    namespace = (user_id,) if user_id else None

    if namespace is not None:
        stored = await _load_stored_birthdate(store, namespace, key, user_id)
        if stored is not None:
            return {"birthdate": stored, "messages": []}
    else:
        logger.warning(
            "Warning: user_id not found in config. Skipping persistent birthdate storage/retrieval for this run."
        )

    # Built once; each retry below reuses it instead of rebuilding it.
    m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
    model_runnable = wrap_model(
        m.with_structured_output(BirthdateExtraction), birthdate_extraction_prompt.format()
    ).with_config(tags=["skip_stream"])

    # NOTE: on resume LangGraph re-runs this node from the top and hands stored answers
    # back to the interrupt() calls in order, so the interrupt sequence must stay
    # deterministic for a given conversation.
    while True:
        response: BirthdateExtraction = await model_runnable.ainvoke(state, config)

        if response.birthdate is None:
            question = f"{response.reasoning}\nPlease tell me your birthdate?"
        else:
            try:
                birthdate = datetime.fromisoformat(response.birthdate)
            except ValueError:
                question = "I couldn't understand the date format. Please provide your birthdate in YYYY-MM-DD format."
            else:
                # The prompt asks the model to reject future dates, but that is not
                # guaranteed; enforce it here. Compare dates so naive and tz-aware
                # values both work.
                if birthdate.date() <= datetime.now().date():
                    break
                question = "That date is in the future. Please provide your birthdate in YYYY-MM-DD format."

        birthdate_input = interrupt(question)
        # Add the user's answer so the next extraction pass sees it.
        # NOTE(review): this appends to the incoming state's list in place; whether that
        # list is shared with graph channels depends on LangGraph internals — confirm.
        state["messages"].append(HumanMessage(birthdate_input))

    if namespace is not None:
        await _save_birthdate(store, namespace, key, birthdate)

    logger.info("[determine_birthdate] Returning birthdate %s for user %s", birthdate, user_id)
    return {
        "birthdate": birthdate,
        "messages": [],
    }
|
|
|
|
|
|
|
|
# System prompt for ``generate_response``. ``format`` fills ``{birthdate_str}`` and
# ``{last_user_message}``.
# NOTE(review): the user's raw message text is interpolated into a *system* message, so
# user text could act as system-level instructions (prompt injection) — confirm acceptable.
response_prompt = SystemMessagePromptTemplate.from_template("""


You are a helpful assistant.





Known information:


- The user's birthdate is {birthdate_str}





User's latest message: "{last_user_message}"





Based on the known information and the user's message, provide a helpful and relevant response.


If the user asked for their birthdate, confirm it.


If the user asked for their zodiac sign, calculate it and tell them.


Otherwise, respond conversationally based on their message.


""")
|
|
|
|
|
|
|
|
async def generate_response(state: AgentState, config: RunnableConfig) -> AgentState:
    """Generate the final reply from the user's latest message and the known birthdate.

    If ``state`` has no birthdate, returns a fixed request for it without calling a
    model. Otherwise runs the configured model with ``response_prompt``, filled with the
    formatted birthdate and the most recent user message. Either way, returns a state
    update holding a single new ``AIMessage``.
    """
    birthdate = state.get("birthdate")
    if not birthdate:
        return {
            "messages": [
                AIMessage(
                    content="I couldn't determine your birthdate. Could you please provide it?"
                )
            ]
        }

    # Earlier nodes (e.g. ``background``) append AIMessages after the user's turn, so the
    # newest message is usually not the user's. Search backwards for the latest
    # HumanMessage rather than only inspecting ``messages[-1]``.
    last_user_message = next(
        (
            message.content
            for message in reversed(state.get("messages") or [])
            if isinstance(message, HumanMessage)
        ),
        "",
    )

    birthdate_str = birthdate.strftime("%B %d, %Y")

    m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
    model_runnable = wrap_model(
        m, response_prompt.format(birthdate_str=birthdate_str, last_user_message=last_user_message)
    )
    response = await model_runnable.ainvoke(state, config)

    return {"messages": [AIMessage(content=response.content)]}
|
|
|
|
|
|
|
|
|
|
|
def _build_agent_graph() -> StateGraph:
    """Assemble the linear graph background -> determine_birthdate -> generate_response -> END."""
    graph = StateGraph(AgentState)
    graph.add_node("background", background)
    graph.add_node("determine_birthdate", determine_birthdate)
    graph.add_node("generate_response", generate_response)
    graph.set_entry_point("background")

    # Consecutive pairs of this route become the graph's edges, in order.
    route = ("background", "determine_birthdate", "generate_response", END)
    for upstream, downstream in zip(route, route[1:]):
        graph.add_edge(upstream, downstream)
    return graph


agent = _build_agent_graph()

# NOTE(review): compiled without a checkpointer or store here, yet determine_birthdate
# uses interrupt() and a BaseStore; presumably the hosting service attaches both — confirm.
interrupt_agent = agent.compile()
interrupt_agent.name = "interrupt-agent"
|
|
|