File size: 1,233 Bytes
9797603 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
"""Create a conversational retrieval QA chain using LangChain and OpenAI.
"""
import os
# Default both LangChain telemetry switches to "off" unless the caller's
# environment already sets them. This runs before the langchain imports below,
# presumably so the values are visible if langchain reads them at import time.
for _env_name, _env_default in (
    ("LANGCHAIN_TELEMETRY_ENABLED", "false"),
    ("LANGCHAIN_DISABLE_TELEMETRY", "true"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
def make_conversational_chain(
    retriever,
    model_name: str = "gpt-3.5-turbo",
    temperature: float = 0.0,
) -> "ConversationalRetrievalChain":
    """Build a ConversationalRetrievalChain backed by ChatOpenAI and buffer memory.

    The chain is created with ``return_source_documents=True``, so each call
    yields both an ``'answer'`` and a ``'source_documents'`` output. The
    ConversationBufferMemory is therefore given ``output_key='answer'`` so that
    only the answer text is written to the chat history. It also gets
    ``input_key='question'`` so the stored user turn is the question.

    Args:
        retriever: Retriever passed straight to
            ``ConversationalRetrievalChain.from_llm``. Presumably a LangChain
            retriever; must not be ``None``.
        model_name: OpenAI chat model name forwarded to ``ChatOpenAI``.
        temperature: Sampling temperature forwarded to ``ChatOpenAI``.

    Returns:
        The configured ConversationalRetrievalChain.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is unset, empty, or whitespace-only.
        ValueError: If ``retriever`` is ``None``.
    """
    # Explicit check instead of ``assert``: asserts are removed under
    # ``python -O``, which would silently skip this validation.
    if not os.environ.get("OPENAI_API_KEY", "").strip():
        raise RuntimeError("OPENAI_API_KEY must be set in environment")
    if retriever is None:
        raise ValueError("retriever must not be None")

    llm = ChatOpenAI(model_name=model_name, temperature=temperature)
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        # The chain returns several outputs; store only the answer text.
        output_key="answer",
        # Name the user-input key explicitly rather than relying on inference.
        input_key="question",
    )
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever,
        memory=memory,
        return_source_documents=True,
    )
|