"""Question-Answering Agent""" from typing import Any, AsyncGenerator, Dict from textwrap import dedent from datetime import datetime from langchain_core.runnables import Runnable from langchain_core.output_parsers import JsonOutputParser from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder from langgraph.checkpoint.memory import MemorySaver from langchain_tavily import TavilySearch from langgraph.graph import StateGraph, START, END, MessagesState from langgraph.graph.state import CompiledStateGraph from src.common.timer import T from src.common.logger import log_info from src.common.llm import llm from src.agents.web_logger import WebLogger LAST_NODE_NAME = "generate_response" class GraphState(MessagesState): question: str question_en: str path_state: str verse: str response: str contexts: list references: dict # remaining_steps: RemainingSteps class QAAgent: def __init__(self): self.graph = self._build_graph() self.logger = WebLogger() def match_references(self, text, url_list): import re # 참고 자료 번호 추출 (예: [[3,4]] 형식) references = [] for match in re.finditer(r"\[\[([^\]]+)\]\]", text): # 쉼표로 구분된 숫자들을 개별 참조로 분리 nums = match.group(1).split(",") for num in nums: references.append(int(num.strip())) # 참고 자료 순서대로 정렬 (등장 순서 유지) unique_references = list(dict.fromkeys(references)) # 참고 자료 매핑 생성 (등장 순서대로 1부터 번호 부여) reference_mapping = {} for i, ref in enumerate(unique_references): reference_mapping[ref] = i + 1 # URL 매핑 생성 (새로운 참조 번호에 맞춰 URL 재배열) url_mapping = {} for orig_num, new_num in reference_mapping.items(): # 원래 인덱스는 1부터 시작하므로 -1 해줌 if orig_num >= 1 and orig_num < len(url_list): url_mapping[new_num] = url_list[orig_num - 1] # 텍스트 내 참고 자료 번호 변환 def replace_references(match): nums = match.group(1).split(",") # 각 번호를 개별 [[]] 태그로 변환 new_refs = [] for num in nums: new_num = reference_mapping[int(num.strip())] new_refs.append(f"[[{new_num}]]") return "".join(new_refs) # 텍스트에서 참고 자료 번호 변환 modified_text = re.sub(r"\[\[([^\]]+)\]\]", replace_references, text) return modified_text, url_mapping def chunk_formatting(self, contexts): if len(contexts) == 0: return "" formatted_results = [] url_list = [] for i, result in enumerate(contexts): formatted_results.append(f"참고자료{i+1} : {result['content']}") url_list.append(f"{result['url']}") return "\n".join(formatted_results), url_list def _build_nodes(self): @T def initialize_state(state: GraphState) -> Dict: """Initialize the state.""" result = { "question": state["question"], } return result @T def search_internet(state: GraphState) -> GraphState: # Tavily 검색 도구 설정 tavily_search_tool = TavilySearch( max_results=5, include_domains=[ "biblehub.com", "biblegateway.com", "gotquestions.org", "bible.org", "christianity.com", ], ) # 사용자 질문으로 검색 수행 search_results = tavily_search_tool.invoke(state["question"]) # 검색 결과가 없는 경우 처리 if not search_results["results"]: return { "verse": { "answer": "죄송합니다. 해당 질문에 대한 검색 결과를 찾을 수 없습니다. 다른 방식으로 질문해 주시겠습니까?", "reference": [], }, "chunk": [], } result = {"contexts": search_results["results"]} return result @T def generate_response(state: GraphState) -> Dict: """ Generate a response based on the retrieved verses. """ # 프롬프트 정의 수정 prompt = ChatPromptTemplate.from_messages( [ ( "system", dedent( """ # Role 당신은 성경에 대한 질문을 친절하게 알려주는 AI입니다. 유저의 질문을 잘 읽고, 관련해서 필요한 자료를 참고해서 적절한 대답을 해줘! 검색 결과를 바탕으로 성경적인 관점에서 답변을 작성해주세요. 답변을 줄때 논문의 형태처럼 참고한 자료의 번호를 나타내줘 [[*]] 형태로 적어줘! "~~[[1]] ~~~[[2,3]] "형태로 나타내줘! (참고자료 번호는 반드시 reference에 있는 내용과 동일해야해!) 답변은 다음과 같은 JSON 형식으로 작성해주세요: {{ "answer": "답변 내용을 여기에 작성", }} """ ), ), MessagesPlaceholder(variable_name="messages"), ("human", "question: {question}"), ("ai", "검색 결과: {search_results}"), ] ) formatted_result, url_list = self.chunk_formatting(state["contexts"]) # LLM으로 응답 생성 chain = prompt | llm | JsonOutputParser() response = chain.invoke( { "messages": state.get("messages", []), "question": state["question"], "search_results": formatted_result, } ) response, references = self.match_references(response["answer"], url_list) result = {"response": response, "references": references} return result @T def update_history(state: GraphState) -> GraphState: """ Update the history of the conversation. """ if "messages" not in state or not isinstance(state["messages"], list): log_info("testse") state["messages"] = [] # 사용자의 질문과 AI의 응답을 대화 내역에 추가합니다. state["messages"].append(("human", state["question"])) state["messages"].append(("ai", state["response"])) self.logger.insert( question=state["question"], answer=state["response"], history=state["messages"], created_at=datetime.now().isoformat(), ) log_info(state["messages"]) log_info(state["references"]) return {"messages": state["messages"], "references": state["references"]} return { "initialize_state": initialize_state, "search_internet": search_internet, "update_history": update_history, "generate_response": generate_response, } def _build_graph(self) -> CompiledStateGraph: workflow = StateGraph(GraphState) # Define the nodes nodes = self._build_nodes() workflow.add_node("initialize_state", nodes["initialize_state"]) workflow.add_node("search_internet", nodes["search_internet"]) workflow.add_node("generate_response", nodes["generate_response"]) workflow.add_node("update_history", nodes["update_history"]) workflow.add_edge(START, "initialize_state") workflow.add_edge("initialize_state", "search_internet") workflow.add_edge("search_internet", "generate_response") workflow.add_edge("generate_response", "update_history") workflow.add_edge("update_history", END) checkpointer = MemorySaver() # Compile workflow = workflow.compile(checkpointer=checkpointer) return workflow def _invoke_agent( self, agent: Runnable, input: dict, log_prompt: bool = False ) -> Any: """Invoke the agent.""" return agent.invoke(input) @T def invoke( self, question: str, thread_id: str = "test", user_id: str = "test" ) -> dict: """Invoke the graph.""" for output in self.graph.stream( {"question": question}, config={"configurable": {"thread_id": thread_id, "user_id": user_id}}, ): pass return { "answer": output["update_history"]["messages"][-1][-1], "references": output["update_history"]["references"], } @T async def astream( self, question: str, thread_id: str = "test", user_id: str = "test" ) -> AsyncGenerator[str, None]: """Asynchronous stream the graph.""" async for event in self.graph.astream_events( {"question": question}, configurable={"thread_id": thread_id} ): kind = event["event"] if kind == "on_chat_model_stream": # Print only in the last node node_name = event["metadata"]["langgraph_node"] if node_name == LAST_NODE_NAME: if ( hasattr(event["data"]["chunk"], "content") and event["data"]["chunk"].content ): addition = event["data"]["chunk"].content yield addition @T def stream(self, question: str, thread_id: str = "test", user_id: str = "test"): """Synchronous stream the graph.""" for event in self.graph.stream({"question": question}): kind = event["event"] if kind == "on_chat_model_stream": # Print only in the last node node_name = event["metadata"]["langgraph_node"] if node_name == LAST_NODE_NAME: if ( hasattr(event["data"]["chunk"], "content") and event["data"]["chunk"].content ): addition = event["data"]["chunk"].content yield addition if __name__ == "__main__": import asyncio qa_agent = QAAgent() qa_agent.invoke("하나님이 죄를 창조하신 이유?")