File size: 1,233 Bytes
9797603
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
"""Create a conversational retrieval QA chain using LangChain and OpenAI.
"""
import os
# Seed telemetry-related environment defaults *before* the langchain imports
# below, so the values are already present when those modules are loaded.
# setdefault() only fills a variable that is not already set, so anything the
# user exported explicitly wins.
# NOTE(review): confirm langchain actually reads these two variable names; the
# visible code never reads them back itself.
os.environ.setdefault("LANGCHAIN_TELEMETRY_ENABLED", "false")
os.environ.setdefault("LANGCHAIN_DISABLE_TELEMETRY", "true")

from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain


class MissingOpenAIKeyError(RuntimeError, AssertionError):
    """Raised when ``OPENAI_API_KEY`` is missing or blank in the environment.

    Also derives from ``AssertionError`` so callers that caught the
    ``AssertionError`` produced by the earlier ``assert``-based check keep
    working unchanged.
    """


def make_conversational_chain(retriever, model_name: str = "gpt-3.5-turbo", temperature: float = 0.0):
    """Return a ConversationalRetrievalChain configured with ChatOpenAI and a ConversationBufferMemory.

    The ConversationBufferMemory is configured with output_key='answer' so that when the chain
    returns multiple outputs (for example 'answer' and 'source_documents'), the memory will pick the
    single 'answer' field to store in the chat history.

    Args:
        retriever: Retriever handed to ``ConversationalRetrievalChain.from_llm``
            to fetch context documents for each question.
        model_name: Chat model name forwarded to ``ChatOpenAI``.
        temperature: Sampling temperature forwarded to ``ChatOpenAI``.

    Returns:
        The configured chain, built with ``return_source_documents=True`` and a
        ``chat_history`` buffer memory.

    Raises:
        MissingOpenAIKeyError: if ``OPENAI_API_KEY`` is unset, empty, or
            whitespace-only.
    """
    # Explicit check instead of ``assert``: assert statements are stripped when
    # Python runs with -O, which would silently skip this validation.
    api_key = os.environ.get("OPENAI_API_KEY", "")
    if not api_key.strip():
        raise MissingOpenAIKeyError("OPENAI_API_KEY must be set in environment")

    llm = ChatOpenAI(model_name=model_name, temperature=temperature)
    # output_key="answer": with return_source_documents=True the chain yields
    # both "answer" and "source_documents", so the memory must be told which
    # output to record in the chat history.
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer")
    chain = ConversationalRetrievalChain.from_llm(llm, retriever, memory=memory, return_source_documents=True)
    return chain