Spaces:
Runtime error
Runtime error
xangma commited on
Commit ·
248a99d
1
Parent(s): e42e3e0
fixes
Browse files
chain.py
CHANGED
|
@@ -43,41 +43,10 @@ from typing import List, Optional, Any
|
|
| 43 |
import chromadb
|
| 44 |
from langchain.vectorstores import Chroma
|
| 45 |
|
| 46 |
-
class CustomChain(Chain, BaseModel):
|
| 47 |
-
|
| 48 |
-
vstore: Chroma
|
| 49 |
-
chain: BaseCombineDocumentsChain
|
| 50 |
-
key_word_extractor: Chain
|
| 51 |
-
|
| 52 |
-
@property
|
| 53 |
-
def input_keys(self) -> List[str]:
|
| 54 |
-
return ["question"]
|
| 55 |
-
|
| 56 |
-
@property
|
| 57 |
-
def output_keys(self) -> List[str]:
|
| 58 |
-
return ["answer"]
|
| 59 |
-
|
| 60 |
-
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
| 61 |
-
question = inputs["question"]
|
| 62 |
-
chat_history_str = _get_chat_history(inputs["chat_history"])
|
| 63 |
-
if chat_history_str:
|
| 64 |
-
new_question = self.key_word_extractor.run(
|
| 65 |
-
question=question, chat_history=chat_history_str
|
| 66 |
-
)
|
| 67 |
-
else:
|
| 68 |
-
new_question = question
|
| 69 |
-
print(new_question)
|
| 70 |
-
docs = self.vstore.similarity_search(new_question, k=10)
|
| 71 |
-
new_inputs = inputs.copy()
|
| 72 |
-
new_inputs["question"] = new_question
|
| 73 |
-
new_inputs["chat_history"] = chat_history_str
|
| 74 |
-
answer, _ = self.chain.combine_docs(docs, **new_inputs)
|
| 75 |
-
return {"answer": answer}
|
| 76 |
-
|
| 77 |
-
|
| 78 |
def get_new_chain1(vectorstore, model_selector, k_textbox) -> Chain:
|
| 79 |
max_tokens_dict = {'gpt-4': 2000, 'gpt-3.5-turbo': 1000}
|
| 80 |
|
|
|
|
| 81 |
_eg_template = """## Example:
|
| 82 |
|
| 83 |
Chat History:
|
|
@@ -106,8 +75,8 @@ Answer in Markdown:"""
|
|
| 106 |
|
| 107 |
# Construct a ChatVectorDBChain with a streaming llm for combine docs
|
| 108 |
# and a separate, non-streaming llm for question generation
|
| 109 |
-
llm = ChatOpenAI(client = None, temperature=0.7, model_name=
|
| 110 |
-
streaming_llm = ChatOpenAI(client = None, streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0.7, model_name=
|
| 111 |
|
| 112 |
question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
|
| 113 |
doc_chain = load_qa_chain(streaming_llm, chain_type="stuff", prompt=QA_PROMPT)
|
|
@@ -119,14 +88,4 @@ Answer in Markdown:"""
|
|
| 119 |
qa = ConversationalRetrievalChain(
|
| 120 |
retriever=retriever, memory=memory, combine_docs_chain=doc_chain, question_generator=question_generator)
|
| 121 |
|
| 122 |
-
|
| 123 |
-
return qa
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
def _get_chat_history(chat_history: List[Tuple[str, str]]):
|
| 127 |
-
buffer = ""
|
| 128 |
-
for human_s, ai_s in chat_history:
|
| 129 |
-
human = f"Human: " + human_s
|
| 130 |
-
ai = f"Assistant: " + ai_s
|
| 131 |
-
buffer += "\n" + "\n".join([human, ai])
|
| 132 |
-
return buffer
|
|
|
|
| 43 |
import chromadb
|
| 44 |
from langchain.vectorstores import Chroma
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def get_new_chain1(vectorstore, model_selector, k_textbox) -> Chain:
|
| 47 |
max_tokens_dict = {'gpt-4': 2000, 'gpt-3.5-turbo': 1000}
|
| 48 |
|
| 49 |
+
# These templates aren't used for the moment.
|
| 50 |
_eg_template = """## Example:
|
| 51 |
|
| 52 |
Chat History:
|
|
|
|
| 75 |
|
| 76 |
# Construct a ChatVectorDBChain with a streaming llm for combine docs
|
| 77 |
# and a separate, non-streaming llm for question generation
|
| 78 |
+
llm = ChatOpenAI(client = None, temperature=0.7, model_name=model_selector)
|
| 79 |
+
streaming_llm = ChatOpenAI(client = None, streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0.7, model_name=model_selector, max_tokens=1000)
|
| 80 |
|
| 81 |
question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
|
| 82 |
doc_chain = load_qa_chain(streaming_llm, chain_type="stuff", prompt=QA_PROMPT)
|
|
|
|
| 88 |
qa = ConversationalRetrievalChain(
|
| 89 |
retriever=retriever, memory=memory, combine_docs_chain=doc_chain, question_generator=question_generator)
|
| 90 |
|
| 91 |
+
return qa
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|