CraftBeer_Search / agent.py
koji9581's picture
Update agent.py
67ebe24 verified
import os
from typing import TypedDict, Annotated, List, Union
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode
from langchain_core.prompts import ChatPromptTemplate
from serpapi import GoogleSearch
# --- システムプロンプト ---
BASE_SYSTEM_PROMPT = """あなたは「クラフトビール博士」という役割を持つアシスタントです。
# ユーザー情報
- **味の好み**: {preference}
- **あなたからの提案**: {recommendation}
# あなたの基本行動
- ユーザーからクラフトビールに関する質問をされた場合は、**原則として、まず`search_tool`を使って関連情報を検索してください。**
- 回答の冒頭や文脈の中で、ユーザーの好み({preference})に触れ、なぜそのスタイル({recommendation})がおすすめなのかを一言添えてください。
- その上で、ユーザーの具体的な質問に対して、検索結果に基づいた正確な情報を回答してください。
# 回答のスタイル
- **「検索結果によると」「Web上の情報では」といった表現を使用しても構いません。** 正確な情報を伝えることを優先してください。
- 専門家として丁寧な口調で話してください。
# 禁止事項
- [具体的な銘柄名]のようなプレースホルダーを残さない(必ず具体的な名前に置き換える)。
"""
# --- 型定義 ---
class AgentState(TypedDict):
messages: Annotated[List[BaseMessage], lambda x, y: x + y]
pending_query: Union[str, None]
user_preference: Union[str, None]
recommended_style: Union[str, None]
# --- ツール定義 ---
@tool
def search_tool(query: str) -> str:
"""SerpApiを使ってGoogle検索を行うツール。"""
print(f"--- Search Tool 実行: {query} ---")
try:
params = {
"q": query,
"api_key": os.getenv("SERPAPI_API_KEY"),
"engine": "google",
"google_domain": "google.co.jp",
"gl": "jp",
"hl": "ja",
}
search = GoogleSearch(params)
results = search.get_dict()
snippets = []
if "answer_box" in results:
box = results.get("answer_box", {})
if "answer" in box: snippets.append(box["answer"])
elif "snippet" in box: snippets.append(box["snippet"])
if "organic_results" in results:
for result in results.get("organic_results", [])[:10]:
if "snippet" in result: snippets.append(result["snippet"])
return "\n".join(snippets) if snippets else "検索結果が見つかりませんでした。"
except Exception as e:
return f"検索中にエラーが発生しました: {e}"
# --- グラフ構築関数 ---
def create_graph():
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
tools = [search_tool]
llm_with_tools = llm.bind_tools(tools)
# --- ノード関数群 ---
def ask_preference_node(state: AgentState):
"""最初の発言を分析し、必要なら保留して味の好みを聞く"""
messages = state["messages"]
last_user_msg = messages[-1].content
# 最初の発言が「質問」か「開始トリガー」かを判定
classifier_prompt = ChatPromptTemplate.from_messages([
("system", "ユーザーのメッセージが『具体的な知識を問う質問(例:IPAとは、歴史は)』か、『単なる開始の挨拶や推薦の依頼(例:おすすめ教えて、こんにちは)』かを判定せよ。質問なら `question`、それ以外なら `trigger` と出力せよ。"),
("human", f"メッセージ: {last_user_msg}"),
])
msg_type = (classifier_prompt | llm).invoke({}).content.strip()
print(f"--- 初期メッセージ判定: {last_user_msg} -> {msg_type} ---")
if "question" in msg_type:
pending = last_user_msg
response_text = "承知しました!そのご質問にお答えする前に、お客様に最適な一杯をご提案したいので、好みの味を教えていただけますか?\n(例:苦いのが好き、フルーティーなのがいい、など)"
else:
pending = None
response_text = "いらっしゃいませ!お客様にぴったりのクラフトビールをご提案します。\nまずは、好みの味のタイプを教えていただけますか?\n(例:苦いのが好き、度数は低め、すっきり系、など)"
return {
"pending_query": pending,
"messages": [AIMessage(content=response_text)]
}
def analyze_preference_node(state: AgentState):
"""ユーザーの回答からおすすめのスタイルを推論する"""
messages = state["messages"]
preference_text = messages[-1].content
print(f"--- 好み分析: {preference_text} ---")
analysis_prompt = ChatPromptTemplate.from_messages([
("system", "ユーザーの味の好みに基づき、最もおすすめなクラフトビールの『スタイル名(英語またはカタカナ)』を1つか2つだけ挙げて回答してください。余計な文章は不要です。例:『ピルスナー、ヘレス』"),
("human", f"味の好み: {preference_text}"),
])
analyzer = analysis_prompt | llm
recommended_style = analyzer.invoke({}).content.strip()
print(f"--- 提案スタイル: {recommended_style} ---")
return {
"user_preference": preference_text,
"recommended_style": recommended_style
}
def craft_beer_conversation_node(state: AgentState):
"""検索と回答生成を行う。"""
print("--- 会話ノード実行 ---")
preference = state.get("user_preference", "特になし")
recommendation = state.get("recommended_style", "おすすめのクラフトビール")
pending_query = state.get("pending_query")
current_system_prompt = BASE_SYSTEM_PROMPT.format(
preference=preference,
recommendation=recommendation
)
messages = [SystemMessage(content=current_system_prompt)] + state["messages"]
# --- パターンB: 保留中の質問がある場合(IPAとは?) ---
if pending_query:
print(f"--- 保留中の質問を使用: {pending_query} ---")
instruction = (
f"【システム指示】ユーザーの好みが分かりました。\n"
f"1. まず、「お客様の好み({preference})なら、おすすめは{recommendation}です」と提案してください。\n"
f"2. その後、ユーザーが最初にしていた質問「{pending_query}」について、検索結果を用いて回答してください。"
)
messages.append(HumanMessage(content=instruction))
# --- パターンA: おすすめ希望のみの場合 ---
# 【修正】ここで検索を強制する指示を追加しました
elif state.get("user_preference") and len(state["messages"]) <= 4:
print("--- 提案+銘柄検索実行 ---")
instruction = (
f"【システム指示】ユーザーの好みが「{preference}」だと分かりました。\n"
f"1. この好みに合うスタイルとして「{recommendation}」を提案してください。\n"
f"2. さらに、**必ず`search_tool`を使用して**、そのスタイル({recommendation})の代表的な銘柄や、現在日本で購入できるおすすめの具体的な商品(缶ビールなど)を10件検索し、紹介してください。\n"
f"※もし10件見つからなかった場合でも、可能な限り多く挙げてください。"
)
messages.append(HumanMessage(content=instruction))
response = llm_with_tools.invoke(messages)
return {"messages": [response], "pending_query": None}
def router(state: AgentState) -> str:
"""会話の文脈判断(2ターン目以降)"""
print("--- ルーター実行 ---")
messages = state["messages"]
last_msg = messages[-1].content
prompt = ChatPromptTemplate.from_messages([
("system", "ユーザーの意図を判断: `continue` (継続・質問), `other` (無関係な話題)"),
("human", f"ユーザー: {last_msg}"),
])
intent = (prompt | llm).invoke({}).content.strip()
if "other" in intent and "continue" not in intent:
return "canned"
return "conversation"
def canned_response_node(state: AgentState):
return {"messages": [AIMessage(content="申し訳ありませんが、クラフトビールに関する質問のみお答えできます。")]}
def should_continue(state: AgentState) -> str:
last_message = state["messages"][-1]
if not isinstance(last_message, AIMessage): return "end"
if last_message.tool_calls: return "continue"
return "end"
# --- グラフ定義 ---
workflow = StateGraph(AgentState)
workflow.add_node("ask_preference", ask_preference_node)
workflow.add_node("analyze_preference", analyze_preference_node)
workflow.add_node("conversation", craft_beer_conversation_node)
workflow.add_node("canned", canned_response_node)
workflow.add_node("action", ToolNode(tools))
def entry_router(state):
messages = state["messages"]
if not state.get("user_preference"):
if len(messages) == 1:
return "ask_preference"
else:
return "analyze_preference"
else:
return router(state)
workflow.set_conditional_entry_point(
entry_router,
{
"ask_preference": "ask_preference",
"analyze_preference": "analyze_preference",
"conversation": "conversation",
"canned": "canned"
}
)
workflow.add_edge("ask_preference", END)
workflow.add_edge("analyze_preference", "conversation")
workflow.add_conditional_edges(
"conversation",
should_continue,
{"continue": "action", "end": END}
)
workflow.add_edge("action", "conversation")
memory = MemorySaver()
return workflow.compile(checkpointer=memory)