from typing import Any, Dict, List, Optional, Tuple, Union

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
)
from langchain_core.prompts.base import BasePromptTemplate, format_document
from langchain_core.prompts.chat import MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import Runnable, RunnableBranch, RunnablePassthrough
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain.chains.combine_documents.base import (
    DEFAULT_DOCUMENT_PROMPT,
    DEFAULT_DOCUMENT_SEPARATOR,
    DOCUMENTS_KEY,
    BaseCombineDocumentsChain,
    _validate_prompt,
)
from langchain.chains.llm import LLMChain

# A chat turn is either a (human, ai) pair of strings or a single BaseMessage.
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]

class CustomRunnableWithHistory(RunnableWithMessageHistory):
    """RunnableWithMessageHistory that trims history to a sliding window."""

    def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
        """Load the message history, trimmed to the last `memory_window` turns.

        Args:
            input (Any): The input data.
            config (RunnableConfig): The runnable configuration; must carry
                `message_history` and `memory_window` under `configurable`.

        Returns:
            List[BaseMessage]: The most recent messages, two per turn.
        """
        hist: BaseChatMessageHistory = config["configurable"]["message_history"]
        messages = hist.messages.copy()

        if not self.history_messages_key:
            # Return all messages, plus the new input message(s).
            messages += self._get_input_messages(input)

        # Keep only the last `memory_window` turns (one human + one AI message
        # per turn); a window of 0 disables history entirely.
        if config["configurable"]["memory_window"] == 0:
            messages = []
        else:
            messages = messages[-2 * config["configurable"]["memory_window"] :]
        return messages
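
# Illustrative sketch of invoking the class above. The wrapped `chain` and the
# session factory are assumptions, not part of this module; `memory_window` is
# read straight from the invocation config by `_enter_history`:
#
#   runnable = CustomRunnableWithHistory(
#       chain,
#       get_session_history=lambda session_id: InMemoryHistory(),
#       input_messages_key="input",
#       history_messages_key="chat_history",
#   )
#   runnable.invoke(
#       {"input": "What is a convolution?"},
#       config={"configurable": {"session_id": "demo", "memory_window": 3}},
#   )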


def _get_chat_history(chat_history: List[CHAT_TURN_TYPE], n: Optional[int] = None) -> str:
    """Convert chat history to a formatted string.

    Args:
        chat_history (List[CHAT_TURN_TYPE]): The chat history.
        n (Optional[int]): If given, only the last n turns are formatted.

    Returns:
        str: The formatted chat history.
    """
    _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
    buffer = ""
    if n is not None:
        # Assumes one message per entry: each turn is two messages (human + AI).
        turns_to_take = n * 2
        chat_history = chat_history[-turns_to_take:]
    for dialogue_turn in chat_history:
        if isinstance(dialogue_turn, BaseMessage):
            role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
            buffer += f"\n{role_prefix}{dialogue_turn.content}"
        elif isinstance(dialogue_turn, tuple):
            human = "Student: " + dialogue_turn[0]
            ai = "AI Tutor: " + dialogue_turn[1]
            buffer += "\n" + "\n".join([human, ai])
        else:
            raise ValueError(
                f"Unsupported chat history format: {type(dialogue_turn)}."
                f" Full chat history: {chat_history} "
            )
    return buffer
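
# Example of the resulting format (tuple turns shown; BaseMessage turns are
# rendered the same way via the role map):
#
#   _get_chat_history([("What is a tensor?", "A tensor is ...")])
#   # -> "\nStudent: What is a tensor?\nAI Tutor: A tensor is ..."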


class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """In-memory implementation of chat message history."""

    messages: List[BaseMessage] = Field(default_factory=list)

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Add a list of messages to the store."""
        self.messages.extend(messages)

    def clear(self) -> None:
        """Clear the message history."""
        self.messages = []

    def __len__(self) -> int:
        """Return the number of messages."""
        return len(self.messages)

    def get_last_n_conversations(self, n: int) -> "InMemoryHistory":
        """Return a new InMemoryHistory with the last n conversations.

        Args:
            n (int): The number of last conversations to return. If 0, return
                an empty history.

        Returns:
            InMemoryHistory: A new InMemoryHistory object containing the last
                n conversations.
        """
        if n == 0:
            return InMemoryHistory()
        # Each conversation turn consists of two messages (human + AI).
        num_messages = n * 2
        last_messages = self.messages[-num_messages:]
        return InMemoryHistory(messages=last_messages)
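
# Quick illustration of the window semantics (values are hypothetical;
# HumanMessage/AIMessage come from the imports at the top):
#
#   history = InMemoryHistory()
#   history.add_messages([HumanMessage(content="hi"), AIMessage(content="hello")])
#   len(history)                         # 2
#   history.get_last_n_conversations(1)  # keeps both messages of the last turn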


def create_history_aware_retriever(
    llm: LanguageModelLike,
    retriever: BaseRetriever,
    prompt: BasePromptTemplate,
) -> Runnable[Dict[str, Any], RetrieverOutput]:
    """Create a chain that takes conversation history and returns documents."""
    if "input" not in prompt.input_variables:
        raise ValueError(
            "Expected `input` to be a prompt variable, "
            f"but got {prompt.input_variables}"
        )

    retrieve_documents = RunnableBranch(
        (
            # No chat history: pass the question straight to the retriever.
            lambda x: not x["chat_history"],
            (lambda x: x["input"]) | retriever,
        ),
        # Otherwise, rephrase the question with the LLM before retrieving.
        prompt | llm | StrOutputParser() | retriever,
    ).with_config(run_name="chat_retriever_chain")

    return retrieve_documents
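
# Sketch of a rephrase prompt this helper expects. The wording is an
# assumption; the only hard requirement is the `input` variable:
#
#   from langchain_core.prompts import ChatPromptTemplate
#   rephrase_prompt = ChatPromptTemplate.from_messages([
#       ("system", "Rewrite the follow-up question as a standalone question."),
#       MessagesPlaceholder("chat_history"),
#       ("human", "{input}"),
#   ])
#   retriever_chain = create_history_aware_retriever(llm, retriever, rephrase_prompt)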


def create_stuff_documents_chain(
    llm: LanguageModelLike,
    prompt: BasePromptTemplate,
    output_parser: Optional[BaseOutputParser] = None,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
) -> Runnable[Dict[str, Any], Any]:
    """Create a chain for passing a list of Documents to a model."""
    # The prompt must expose the DOCUMENTS_KEY input variable to receive the
    # formatted documents.
    _validate_prompt(prompt)
    _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
    _output_parser = output_parser or StrOutputParser()

    def format_docs(inputs: dict) -> str:
        # Render each Document with the document prompt, then join them.
        return document_separator.join(
            format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
        )

    return (
        RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
            run_name="format_inputs"
        )
        | prompt
        | llm
        | _output_parser
    ).with_config(run_name="stuff_documents_chain")
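
# Sketch of a compatible QA prompt. The wording is illustrative; the prompt
# just needs the DOCUMENTS_KEY variable ("context" by default) and the input:
#
#   qa_prompt = ChatPromptTemplate.from_messages([
#       ("system", "Answer using only the following context:\n\n{context}"),
#       ("human", "{input}"),
#   ])
#   combine_docs_chain = create_stuff_documents_chain(llm, qa_prompt)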


def create_retrieval_chain(
    retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
    combine_docs_chain: Runnable[Dict[str, Any], str],
) -> Runnable:
    """Create retrieval chain that retrieves documents and then passes them on."""
    if not isinstance(retriever, BaseRetriever):
        # Already a runnable (e.g. a history-aware retriever): use it as-is.
        retrieval_docs = retriever
    else:
        # A plain retriever expects the query string, so extract `input` first.
        retrieval_docs = (lambda x: x["input"]) | retriever

    retrieval_chain = (
        RunnablePassthrough.assign(
            context=retrieval_docs.with_config(run_name="retrieve_documents"),
        ).assign(answer=combine_docs_chain)
    ).with_config(run_name="retrieval_chain")

    return retrieval_chain
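
# End-to-end sketch wiring the three helpers together (`llm`, `retriever`,
# `rephrase_prompt`, and `qa_prompt` are assumed from the sketches above):
#
#   rag_chain = create_retrieval_chain(
#       create_history_aware_retriever(llm, retriever, rephrase_prompt),
#       create_stuff_documents_chain(llm, qa_prompt),
#   )
#   result = rag_chain.invoke({"input": "What is a tensor?", "chat_history": []})
#   result["answer"]   # the model's reply
#   result["context"]  # the retrieved Documents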