# NOTE: removed non-Python web-page residue that broke the file
# (file-listing header: "jaeyong2's picture / update logger / ace92aa")
"""Question-Answering Agent"""
from typing import Any, AsyncGenerator, Dict
from textwrap import dedent
from datetime import datetime
from langchain_core.runnables import Runnable
from langchain_core.output_parsers import JsonOutputParser
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.checkpoint.memory import MemorySaver
from langchain_tavily import TavilySearch
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.graph.state import CompiledStateGraph
from src.common.timer import T
from src.common.logger import log_info
from src.common.llm import llm
from src.agents.web_logger import WebLogger
LAST_NODE_NAME = "generate_response"
class GraphState(MessagesState):
    """Shared state flowing between the QA graph's nodes.

    Extends LangGraph's ``MessagesState``, so a ``messages`` list is also
    part of the state alongside the fields below.
    """

    question: str  # raw user question, set by initialize_state
    question_en: str  # English translation slot — not populated in this file
    path_state: str  # routing marker — not populated in this file
    verse: str  # fallback answer payload (set only when search returns nothing)
    response: str  # final generated answer text
    contexts: list  # Tavily search-result dicts consumed by generate_response
    references: dict  # new citation number -> source URL mapping
    # remaining_steps: RemainingSteps
class QAAgent:
def __init__(self):
    """Set up the web logger and compile the LangGraph workflow."""
    self.logger = WebLogger()
    self.graph = self._build_graph()
def match_references(self, text, url_list):
    """Renumber ``[[n]]`` citation tags in *text* by order of first appearance.

    The LLM cites sources by their original position (e.g. ``[[3,4]]``).
    This rewrites every tag so citations are numbered 1, 2, ... in the order
    they first appear, and builds a mapping from the new numbers to URLs.

    Args:
        text: Answer text containing ``[[n]]`` or ``[[n,m]]`` citation tags.
        url_list: Source URLs; the original citation numbers are 1-based
            indices into this list.

    Returns:
        Tuple of (rewritten text, dict mapping new citation number -> URL).
    """
    import re

    # Collect every cited source number, e.g. "[[3,4]]" -> 3, 4.
    references = []
    for match in re.finditer(r"\[\[([^\]]+)\]\]", text):
        for num in match.group(1).split(","):
            references.append(int(num.strip()))

    # Deduplicate while preserving first-appearance order.
    unique_references = list(dict.fromkeys(references))

    # Original source number -> new 1-based citation number.
    reference_mapping = {ref: i + 1 for i, ref in enumerate(unique_references)}

    # New citation number -> URL. Original numbers are 1-based, so the valid
    # range is 1..len(url_list) INCLUSIVE — the original `< len(url_list)`
    # check was an off-by-one that silently dropped the last source's URL.
    url_mapping = {}
    for orig_num, new_num in reference_mapping.items():
        if 1 <= orig_num <= len(url_list):
            url_mapping[new_num] = url_list[orig_num - 1]

    def replace_references(match):
        # Rewrite "[[3,4]]" as "[[1]][[2]]" using the new numbering; every
        # number seen here was collected above, so the lookup cannot miss.
        new_refs = []
        for num in match.group(1).split(","):
            new_refs.append(f"[[{reference_mapping[int(num.strip())]}]]")
        return "".join(new_refs)

    modified_text = re.sub(r"\[\[([^\]]+)\]\]", replace_references, text)
    return modified_text, url_mapping
def chunk_formatting(self, contexts):
    """Format search results into a numbered reference string plus URL list.

    Args:
        contexts: List of search-result dicts with "content" and "url" keys.

    Returns:
        Tuple of (newline-joined "참고자료N : content" lines, list of URLs).
        Bug fix: the original returned a bare ``""`` for empty input, which
        broke tuple-unpacking callers — always return a 2-tuple now.
    """
    if not contexts:
        return "", []
    formatted_results = []
    url_list = []
    for i, result in enumerate(contexts):
        formatted_results.append(f"์ฐธ๊ณ ์ž๋ฃŒ{i+1} : {result['content']}")
        url_list.append(f"{result['url']}")
    return "\n".join(formatted_results), url_list
def _build_nodes(self):
    """Build the node callables wired into the LangGraph workflow.

    Returns:
        Dict mapping node name -> function, consumed by ``_build_graph``.
    """

    @T
    def initialize_state(state: GraphState) -> Dict:
        """Seed the run state with the incoming question."""
        return {"question": state["question"]}

    @T
    def search_internet(state: GraphState) -> Dict:
        """Search trusted bible-related sites via Tavily for the question."""
        # Configure the Tavily search tool, restricted to trusted domains.
        tavily_search_tool = TavilySearch(
            max_results=5,
            include_domains=[
                "biblehub.com",
                "biblegateway.com",
                "gotquestions.org",
                "bible.org",
                "christianity.com",
            ],
        )
        # Run the search with the user's question.
        search_results = tavily_search_tool.invoke(state["question"])
        if not search_results["results"]:
            # Bug fix: the original returned a "chunk" key that is not part
            # of GraphState and left "contexts" unset, which made the next
            # node (generate_response) crash with a KeyError. Keep the
            # canned apology in "verse" but always provide "contexts".
            return {
                "verse": {
                    "answer": "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ํ•ด๋‹น ์งˆ๋ฌธ์— ๋Œ€ํ•œ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๋‹ค๋ฅธ ๋ฐฉ์‹์œผ๋กœ ์งˆ๋ฌธํ•ด ์ฃผ์‹œ๊ฒ ์Šต๋‹ˆ๊นŒ?",
                    "reference": [],
                },
                "contexts": [],
            }
        return {"contexts": search_results["results"]}

    @T
    def generate_response(state: GraphState) -> Dict:
        """Generate a cited answer from the retrieved search results."""
        # Prompt: answer from a biblical perspective, cite sources as [[n]],
        # and reply as JSON with an "answer" field.
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    dedent(
                        """
                        # Role
                        ๋‹น์‹ ์€ ์„ฑ๊ฒฝ์— ๋Œ€ํ•œ ์งˆ๋ฌธ์„ ์นœ์ ˆํ•˜๊ฒŒ ์•Œ๋ ค์ฃผ๋Š” AI์ž…๋‹ˆ๋‹ค. ์œ ์ €์˜ ์งˆ๋ฌธ์„ ์ž˜ ์ฝ๊ณ , ๊ด€๋ จํ•ด์„œ ํ•„์š”ํ•œ ์ž๋ฃŒ๋ฅผ ์ฐธ๊ณ ํ•ด์„œ ์ ์ ˆํ•œ ๋Œ€๋‹ต์„ ํ•ด์ค˜!
                        ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ์„ฑ๊ฒฝ์ ์ธ ๊ด€์ ์—์„œ ๋‹ต๋ณ€์„ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”.
                        ๋‹ต๋ณ€์„ ์ค„๋•Œ ๋…ผ๋ฌธ์˜ ํ˜•ํƒœ์ฒ˜๋Ÿผ ์ฐธ๊ณ ํ•œ ์ž๋ฃŒ์˜ ๋ฒˆํ˜ธ๋ฅผ ๋‚˜ํƒ€๋‚ด์ค˜ [[*]] ํ˜•ํƒœ๋กœ ์ ์–ด์ค˜! "~~[[1]] ~~~[[2,3]] "ํ˜•ํƒœ๋กœ ๋‚˜ํƒ€๋‚ด์ค˜!
                        (์ฐธ๊ณ ์ž๋ฃŒ ๋ฒˆํ˜ธ๋Š” ๋ฐ˜๋“œ์‹œ reference์— ์žˆ๋Š” ๋‚ด์šฉ๊ณผ ๋™์ผํ•ด์•ผํ•ด!)
                        ๋‹ต๋ณ€์€ ๋‹ค์Œ๊ณผ ๊ฐ™์€ JSON ํ˜•์‹์œผ๋กœ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”:
                        {{
                            "answer": "๋‹ต๋ณ€ ๋‚ด์šฉ์„ ์—ฌ๊ธฐ์— ์ž‘์„ฑ",
                        }}
                        """
                    ),
                ),
                MessagesPlaceholder(variable_name="messages"),
                ("human", "question: {question}"),
                ("ai", "๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ: {search_results}"),
            ]
        )
        # Guard the empty-context case: chunk_formatting may return a bare
        # string for [] input, which would break tuple unpacking here.
        if state["contexts"]:
            formatted_result, url_list = self.chunk_formatting(state["contexts"])
        else:
            formatted_result, url_list = "", []
        # Generate the answer with the LLM and parse the JSON output.
        chain = prompt | llm | JsonOutputParser()
        response = chain.invoke(
            {
                "messages": state.get("messages", []),
                "question": state["question"],
                "search_results": formatted_result,
            }
        )
        # Renumber [[n]] citations by first appearance and map them to URLs.
        answer_text, references = self.match_references(response["answer"], url_list)
        return {"response": answer_text, "references": references}

    @T
    def update_history(state: GraphState) -> Dict:
        """Append this turn to the conversation history and persist it."""
        if "messages" not in state or not isinstance(state["messages"], list):
            # Bug fix: replaced the leftover debug string "testse" with a
            # meaningful log message.
            log_info("messages missing from state; initializing empty history")
            state["messages"] = []
        # Append the user's question and the AI's answer to the history.
        state["messages"].append(("human", state["question"]))
        state["messages"].append(("ai", state["response"]))
        self.logger.insert(
            question=state["question"],
            answer=state["response"],
            history=state["messages"],
            created_at=datetime.now().isoformat(),
        )
        log_info(state["messages"])
        log_info(state["references"])
        return {"messages": state["messages"], "references": state["references"]}

    return {
        "initialize_state": initialize_state,
        "search_internet": search_internet,
        "update_history": update_history,
        "generate_response": generate_response,
    }
def _build_graph(self) -> CompiledStateGraph:
    """Wire the linear QA pipeline and compile it with an in-memory checkpointer.

    Pipeline: initialize_state -> search_internet -> generate_response
    -> update_history.
    """
    node_fns = self._build_nodes()
    builder = StateGraph(GraphState)

    # Register the nodes in pipeline order.
    pipeline = [
        "initialize_state",
        "search_internet",
        "generate_response",
        "update_history",
    ]
    for node_name in pipeline:
        builder.add_node(node_name, node_fns[node_name])

    # Chain START -> each node in order -> END.
    for src, dst in zip([START] + pipeline, pipeline + [END]):
        builder.add_edge(src, dst)

    # Compile with an in-memory checkpointer for per-thread history.
    return builder.compile(checkpointer=MemorySaver())
def _invoke_agent(
    self, agent: Runnable, input: dict, log_prompt: bool = False
) -> Any:
    """Invoke *agent* with *input* and return its result.

    NOTE(review): `log_prompt` is currently unused, and `input` shadows the
    builtin — both kept to preserve the caller-visible signature.
    """
    result = agent.invoke(input)
    return result
@T
def invoke(
    self, question: str, thread_id: str = "test", user_id: str = "test"
) -> dict:
    """Run the full graph synchronously and return the final answer.

    Args:
        question: User question fed into the graph.
        thread_id: Checkpointer thread id (conversation key).
        user_id: Caller identity passed through the config.

    Returns:
        Dict with "answer" (final message text) and "references" (citation
        number -> URL mapping).

    Raises:
        RuntimeError: If the graph stream produced no final update — the
            original dereferenced `output` unconditionally, which raised an
            opaque UnboundLocalError in that case.
    """
    output = None
    # Drain the stream; the last yielded item holds the final node's update.
    for output in self.graph.stream(
        {"question": question},
        config={"configurable": {"thread_id": thread_id, "user_id": user_id}},
    ):
        pass
    if output is None or "update_history" not in output:
        raise RuntimeError("graph produced no final update for the question")
    final = output["update_history"]
    # messages[-1] is the ("ai", answer) tuple appended by update_history.
    return {
        "answer": final["messages"][-1][-1],
        "references": final["references"],
    }
@T
async def astream(
    self, question: str, thread_id: str = "test", user_id: str = "test"
) -> AsyncGenerator[str, None]:
    """Asynchronously stream answer tokens from the final graph node.

    Yields each non-empty content chunk emitted by the chat model while it
    runs inside the ``generate_response`` node.

    Bug fix: ``astream_events`` has no ``configurable=`` keyword — the
    thread/user ids must be nested under ``config={"configurable": ...}``
    (and are required by the MemorySaver checkpointer). Also pins the
    events schema with ``version="v2"``.
    """
    async for event in self.graph.astream_events(
        {"question": question},
        config={"configurable": {"thread_id": thread_id, "user_id": user_id}},
        version="v2",
    ):
        if event["event"] != "on_chat_model_stream":
            continue
        # Only surface tokens produced by the final answer-generation node.
        if event["metadata"].get("langgraph_node") != LAST_NODE_NAME:
            continue
        chunk = event["data"]["chunk"]
        if hasattr(chunk, "content") and chunk.content:
            yield chunk.content
@T
def stream(self, question: str, thread_id: str = "test", user_id: str = "test"):
    """Synchronously stream answer tokens from the final graph node.

    Bug fix: the original iterated ``graph.stream()`` — which yields
    ``{node: state_update}`` dicts — but indexed the items like an event
    stream (``event["event"]``), raising KeyError on the first item; it
    also omitted the config that the MemorySaver checkpointer requires.
    ``stream_mode="messages"`` yields (message_chunk, metadata) pairs,
    the sync analogue of the token stream used in ``astream``.
    """
    for chunk, metadata in self.graph.stream(
        {"question": question},
        config={"configurable": {"thread_id": thread_id, "user_id": user_id}},
        stream_mode="messages",
    ):
        # Only surface tokens produced by the final answer-generation node.
        if metadata.get("langgraph_node") != LAST_NODE_NAME:
            continue
        if getattr(chunk, "content", None):
            yield chunk.content
if __name__ == "__main__":
import asyncio
qa_agent = QAAgent()
qa_agent.invoke("ํ•˜๋‚˜๋‹˜์ด ์ฃ„๋ฅผ ์ฐฝ์กฐํ•˜์‹  ์ด์œ ?")