# NOTE(review): the original lines here ("Spaces:", "Sleeping", "Sleeping") were
# scraping residue from a hosting-page header, not Python code; kept as a comment
# so the module parses.
| """Question-Answering Agent""" | |
| from typing import Any, AsyncGenerator, Dict | |
| from textwrap import dedent | |
| from datetime import datetime | |
| from langchain_core.runnables import Runnable | |
| from langchain_core.output_parsers import JsonOutputParser | |
| from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_tavily import TavilySearch | |
| from langgraph.graph import StateGraph, START, END, MessagesState | |
| from langgraph.graph.state import CompiledStateGraph | |
| from src.common.timer import T | |
| from src.common.logger import log_info | |
| from src.common.llm import llm | |
| from src.agents.web_logger import WebLogger | |
# Name of the node whose LLM tokens are surfaced when streaming; the
# stream/astream methods filter chunks to this node only.
LAST_NODE_NAME = "generate_response"
class GraphState(MessagesState):
    # Per-run state carried through the QA workflow graph.
    # (MessagesState contributes the "messages" history field.)
    question: str  # original user question
    question_en: str  # English form of the question — not referenced in this file; presumably set elsewhere
    path_state: str  # routing/path marker — not referenced in this file
    verse: str  # verse payload — only written by the (dead) no-result branch in search_internet
    response: str  # final generated answer text
    contexts: list  # Tavily search results backing the answer
    references: dict  # citation number -> source URL mapping
    # remaining_steps: RemainingSteps
class QAAgent:
    """Question-answering agent for Bible-related questions.

    Pipeline (a linear LangGraph): initialize state -> web search (Tavily,
    restricted to trusted Bible domains) -> LLM answer generation with inline
    ``[[n]]`` citations -> history update/persistence via ``WebLogger``.
    """

    def __init__(self):
        # Compiled LangGraph workflow plus a logger that persists Q/A turns.
        self.graph = self._build_graph()
        self.logger = WebLogger()

    def match_references(self, text, url_list):
        """Renumber inline citation tags and map them to source URLs.

        Citations appear in *text* as ``[[3,4]]``-style tags whose numbers are
        1-based indices into *url_list*. Tags are rewritten to compact 1-based
        numbers in order of first appearance (``[[3,4]]`` -> ``[[1]][[2]]``).

        Returns:
            tuple: ``(modified_text, url_mapping)`` where ``url_mapping`` maps
            each new citation number to its URL.
        """
        import re

        # Collect cited numbers in order of first appearance.
        cited = []
        for match in re.finditer(r"\[\[([^\]]+)\]\]", text):
            # Comma-grouped numbers count as individual references.
            cited.extend(int(part.strip()) for part in match.group(1).split(","))
        ordered = list(dict.fromkeys(cited))  # dedupe, preserve order

        # Old citation number -> new compact number (1, 2, 3, ...).
        renumber = {old: new for new, old in enumerate(ordered, start=1)}

        # New number -> URL. BUGFIX: the upper bound must be inclusive —
        # orig_num == len(url_list) addresses the LAST url; the previous
        # strict `<` check silently dropped the final source.
        url_mapping = {
            new: url_list[old - 1]
            for old, new in renumber.items()
            if 1 <= old <= len(url_list)
        }

        def rewrite(match):
            # Expand grouped citations into individual [[n]] tags with the
            # renumbered values.
            parts = match.group(1).split(",")
            return "".join(f"[[{renumber[int(p.strip())]}]]" for p in parts)

        modified_text = re.sub(r"\[\[([^\]]+)\]\]", rewrite, text)
        return modified_text, url_mapping

    def chunk_formatting(self, contexts):
        """Format Tavily results for the prompt and collect their URLs.

        Returns:
            tuple: ``(formatted_text, url_list)``. BUGFIX: the empty case used
            to return a bare ``""``, which broke the 2-value unpacking at the
            call site; it now consistently returns a 2-tuple.
        """
        if not contexts:
            return "", []
        formatted_results = []
        url_list = []
        for i, result in enumerate(contexts):
            formatted_results.append(f"참고자료{i+1} : {result['content']}")
            url_list.append(f"{result['url']}")
        return "\n".join(formatted_results), url_list

    def _build_nodes(self):
        """Build the node callables wired into the state graph."""

        def initialize_state(state: GraphState) -> Dict:
            """Seed the run state with the incoming question."""
            return {"question": state["question"]}

        def search_internet(state: GraphState) -> Dict:
            """Search trusted Bible-related domains for the question."""
            tavily_search_tool = TavilySearch(
                max_results=5,
                include_domains=[
                    "biblehub.com",
                    "biblegateway.com",
                    "gotquestions.org",
                    "bible.org",
                    "christianity.com",
                ],
            )
            search_results = tavily_search_tool.invoke(state["question"])
            # BUGFIX: the old no-result branch returned keys ("verse",
            # "chunk") that no downstream node reads and left "contexts"
            # unset, so generate_response crashed on state["contexts"].
            # Always set "contexts"; generate_response handles the empty case.
            return {"contexts": search_results.get("results") or []}

        def generate_response(state: GraphState) -> Dict:
            """Generate a cited answer from the retrieved contexts."""
            contexts = state.get("contexts", [])
            if not contexts:
                # No search results: return a fixed apology instead of
                # prompting the LLM with empty context.
                return {
                    "response": "죄송합니다. 해당 질문에 대한 검색 결과를 찾을 수 없습니다. 다른 방식으로 질문해 주시겠습니까?",
                    "references": {},
                }
            prompt = ChatPromptTemplate.from_messages(
                [
                    (
                        "system",
                        dedent(
                            """
                            # Role
                            당신은 성경에 대한 질문을 친절하게 알려주는 AI입니다. 유저의 질문을 잘 읽고, 관련해서 필요한 자료를 참고해서 적절한 대답을 해줘!
                            검색 결과를 바탕으로 성경적인 관점에서 답변을 작성해주세요.
                            답변의 중간마다 논문의 형태처럼 참고한 자료의 번호를 나타내줘 [[*]] 형태로 적어줘! "~~[[1]] ~~~[[2,3]] "형태로 나타내줘!
                            (참고자료 번호는 반드시 reference에 있는 내용과 동일해야해!)
                            답변은 다음과 같은 JSON 형식으로 작성해주세요:
                            {{
                                "answer": "답변 내용을 여기에 작성",
                            }}
                            """
                        ),
                    ),
                    MessagesPlaceholder(variable_name="messages"),
                    ("human", "question: {question}"),
                    ("ai", "검색 결과: {search_results}"),
                ]
            )
            formatted_result, url_list = self.chunk_formatting(contexts)
            # Ask the LLM for a JSON answer, then renumber its [[n]] citations.
            chain = prompt | llm | JsonOutputParser()
            llm_output = chain.invoke(
                {
                    "messages": state.get("messages", []),
                    "question": state["question"],
                    "search_results": formatted_result,
                }
            )
            response, references = self.match_references(
                llm_output["answer"], url_list
            )
            return {"response": response, "references": references}

        def update_history(state: GraphState) -> Dict:
            """Append the latest turn to the history and persist it."""
            messages = state.get("messages")
            if not isinstance(messages, list):
                messages = []
            # Record the user question and the AI answer for this turn.
            messages.append(("human", state["question"]))
            messages.append(("ai", state["response"]))
            self.logger.insert(
                question=state["question"],
                answer=state["response"],
                history=messages,
                created_at=datetime.now().isoformat(),
            )
            log_info(messages)
            log_info(state["references"])
            return {"messages": messages, "references": state["references"]}

        return {
            "initialize_state": initialize_state,
            "search_internet": search_internet,
            "update_history": update_history,
            "generate_response": generate_response,
        }

    def _build_graph(self) -> "CompiledStateGraph":
        """Compile the linear workflow: init -> search -> generate -> history."""
        workflow = StateGraph(GraphState)
        nodes = self._build_nodes()
        workflow.add_node("initialize_state", nodes["initialize_state"])
        workflow.add_node("search_internet", nodes["search_internet"])
        workflow.add_node("generate_response", nodes["generate_response"])
        workflow.add_node("update_history", nodes["update_history"])
        workflow.add_edge(START, "initialize_state")
        workflow.add_edge("initialize_state", "search_internet")
        workflow.add_edge("search_internet", "generate_response")
        workflow.add_edge("generate_response", "update_history")
        workflow.add_edge("update_history", END)
        # In-memory checkpointer: every invocation must supply a thread_id.
        checkpointer = MemorySaver()
        return workflow.compile(checkpointer=checkpointer)

    def _invoke_agent(
        self, agent: "Runnable", input: dict, log_prompt: bool = False
    ) -> Any:
        """Invoke the agent. (`log_prompt` is currently unused.)"""
        return agent.invoke(input)

    def invoke(
        self, question: str, thread_id: str = "test", user_id: str = "test"
    ) -> dict:
        """Run the graph to completion and return the final answer/references."""
        output = None  # BUGFIX: guard against an empty stream leaving it unbound
        for output in self.graph.stream(
            {"question": question},
            config={"configurable": {"thread_id": thread_id, "user_id": user_id}},
        ):
            pass
        if not output or "update_history" not in output:
            raise RuntimeError("graph produced no final update_history output")
        final = output["update_history"]
        return {
            "answer": final["messages"][-1][-1],
            "references": final["references"],
        }

    async def astream(
        self, question: str, thread_id: str = "test", user_id: str = "test"
    ) -> AsyncGenerator[str, None]:
        """Asynchronously stream answer tokens from the final node."""
        # BUGFIX: astream_events takes config={"configurable": ...}; the old
        # bare `configurable=` keyword was rejected. Also forward user_id for
        # parity with invoke().
        async for event in self.graph.astream_events(
            {"question": question},
            config={"configurable": {"thread_id": thread_id, "user_id": user_id}},
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                # Yield tokens only from the answer-generating node.
                node_name = event["metadata"]["langgraph_node"]
                if node_name == LAST_NODE_NAME:
                    chunk = event["data"]["chunk"]
                    if hasattr(chunk, "content") and chunk.content:
                        yield chunk.content

    def stream(self, question: str, thread_id: str = "test", user_id: str = "test"):
        """Synchronously stream answer tokens from the final node.

        BUGFIX: ``graph.stream`` yields node-state updates, not event dicts,
        so the old ``event["event"]`` lookup always raised KeyError. Use
        ``stream_mode="messages"``, which yields (chunk, metadata) pairs, and
        pass the thread config that the MemorySaver checkpointer requires.
        """
        for chunk, metadata in self.graph.stream(
            {"question": question},
            config={"configurable": {"thread_id": thread_id, "user_id": user_id}},
            stream_mode="messages",
        ):
            if metadata.get("langgraph_node") == LAST_NODE_NAME:
                if hasattr(chunk, "content") and chunk.content:
                    yield chunk.content
if __name__ == "__main__":
    # Manual smoke test: ask one question and show the result.
    # (Removed an unused `import asyncio`; the discarded return value is now
    # printed so the run actually shows something.)
    qa_agent = QAAgent()
    result = qa_agent.invoke("하나님이 죄를 창조하신 이유?")
    print(result)