|
|
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import (
    Runnable,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI

from utils import _combine_documents, Retriever
from prompts import _ANSWERER_SYSTEM_TEMPLATE, _AGENT_SYSTEM_TEMPLATE
from tools import Retrieve
|
|
|
|
|
|
|
|
|
|
|
class Agent:
    """Tool-calling agent built as a LangChain runnable.

    The runnable stored on ``self.runnable`` is::

        prompt | chat model (with the ``Retrieve`` tool bound) | ToolsAgentOutputParser()

    The prompt consists of the system template, an optional ``chat_history``
    messages placeholder, and a human turn rendered from ``{query}``.
    """

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        system_template: str = _AGENT_SYSTEM_TEMPLATE,
        temperature: float = 0.0,
    ) -> None:
        """Store the configuration and build ``self.runnable``.

        Args:
            model_name: Chat model identifier. Names containing ``"gpt"`` are
                served by ``ChatOpenAI``; otherwise names containing
                ``"claude"`` are served by ``ChatAnthropic``.
            system_template: System prompt template for the agent.
            temperature: Sampling temperature passed to the chat model.

        Raises:
            ValueError: If ``model_name`` matches neither supported provider.
        """
        self.model_name = model_name
        self.system_template = system_template
        self.temperature = temperature

        self.runnable = self._create_runnable()

    def _create_model(self):
        """Instantiate the streaming chat model selected by ``self.model_name``.

        Raises:
            ValueError: If ``model_name`` contains neither ``"gpt"`` nor
                ``"claude"``. Previously this case fell through and failed
                later with an opaque ``UnboundLocalError``.
        """
        # "gpt" is checked first, matching the original branch order.
        if "gpt" in self.model_name:
            model_cls = ChatOpenAI
        elif "claude" in self.model_name:
            model_cls = ChatAnthropic
        else:
            raise ValueError(
                f"Unsupported model_name {self.model_name!r}: expected a name "
                "containing 'gpt' (OpenAI) or 'claude' (Anthropic)."
            )

        return model_cls(
            name="agent",
            streaming=True,
            model=self.model_name,
            temperature=self.temperature,
        )

    def _create_runnable(self) -> Runnable:
        """Build ``prompt | model-with-Retrieve-tool | ToolsAgentOutputParser``.

        Returns:
            The composed runnable sequence.
        """
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                # Optional, so the agent can be invoked without any history.
                MessagesPlaceholder("chat_history", optional=True),
                ("human", "{query}"),
            ]
        )

        # Expose the Retrieve tool so the model can request retrievals.
        model = self._create_model().bind_tools([Retrieve])

        return prompt | model | ToolsAgentOutputParser()
|
|
|
|
|
|
|
|
class Answerer:
    """Retrieval-augmented answerer built as a LangChain runnable.

    ``self.runnable`` expects a mapping with a ``"query"`` key. It does two
    things:

    1. It adds the documents retrieved for the query under ``"docs"``.
    2. It returns a dict with three keys:

       * ``"answer"``: the chat model's response. The prompt is the system
         template plus the human ``{query}`` turn. The input also carries a
         ``context`` value built by ``_combine_documents`` from the retrieved
         docs (presumably referenced as ``{context}`` by the system
         template; confirm against ``_ANSWERER_SYSTEM_TEMPLATE``).
       * ``"docs"``: the retrieved documents.
       * ``"standalone_question"``: the input ``query`` unchanged.
    """

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        collection_index: int = 0,
        use_doctrines: bool = True,
        rewrite: bool = True,
        search_type: str = "similarity",
        similarity_threshold: float = 0.0,
        k: int = 15,
        temperature: float = 0.0,
        system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
    ) -> None:
        """Store the configuration and build ``self.runnable``.

        Args:
            model_name: Chat model identifier. Names containing ``"gpt"`` are
                served by ``ChatOpenAI``; otherwise names containing
                ``"claude"`` are served by ``ChatAnthropic``.
            collection_index: Forwarded to ``Retriever``.
            use_doctrines: Forwarded to ``Retriever``.
            rewrite: Stored on the instance. NOTE(review): not used by any
                code in this class; confirm whether callers rely on it.
            search_type: Forwarded to ``Retriever``.
            similarity_threshold: Forwarded to ``Retriever``.
            k: Forwarded to ``Retriever``.
            temperature: Sampling temperature passed to the chat model.
            system_template: System prompt template for the answer prompt.

        Raises:
            ValueError: If ``model_name`` matches neither supported provider.
        """
        self.model_name = model_name
        self.collection_index = collection_index
        self.use_doctrines = use_doctrines
        self.rewrite = rewrite
        self.search_type = search_type
        self.similarity_threshold = similarity_threshold
        self.k = k
        self.temperature = temperature
        self.system_template = system_template

        self.runnable = self._create_runnable()

    def _create_model(self):
        """Instantiate the streaming chat model selected by ``self.model_name``.

        Raises:
            ValueError: If ``model_name`` contains neither ``"gpt"`` nor
                ``"claude"``. Previously this case fell through and failed
                later with an opaque ``UnboundLocalError``.
        """
        # "gpt" is checked first, matching the original branch order.
        if "gpt" in self.model_name:
            model_cls = ChatOpenAI
        elif "claude" in self.model_name:
            model_cls = ChatAnthropic
        else:
            raise ValueError(
                f"Unsupported model_name {self.model_name!r}: expected a name "
                "containing 'gpt' (OpenAI) or 'claude' (Anthropic)."
            )

        return model_cls(
            name="answerer",
            streaming=True,
            model=self.model_name,
            temperature=self.temperature,
        )

    def _create_runnable(self) -> Runnable:
        """Build the retrieve-then-answer runnable sequence.

        Returns:
            A runnable mapping ``{"query": ...}`` to a dict with the keys
            ``"answer"``, ``"docs"`` and ``"standalone_question"``.
        """
        vectorstore = Retriever(
            collection_index=self.collection_index,
            use_doctrines=self.use_doctrines,
            search_type=self.search_type,
            k=self.k,
            similarity_threshold=self.similarity_threshold,
        )

        # Keep every input key and add the retrieved documents under "docs".
        retrieve_docs = RunnablePassthrough.assign(
            docs=itemgetter("query") | RunnableLambda(vectorstore._retrieve),
        )

        answer_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                ("human", "{query}"),
            ]
        )

        model = self._create_model()

        # A dict piped after a runnable is coerced into a RunnableParallel:
        # each key is computed from the same input (query + docs).
        answer = {
            "answer": (
                RunnablePassthrough.assign(
                    context=lambda x: _combine_documents(x["docs"]),
                )
                | answer_prompt
                | model
            ),
            "docs": itemgetter("docs"),
            "standalone_question": itemgetter("query"),
        }

        return retrieve_docs | answer