xangma commited on
Commit ·
92134d0
1
Parent(s): 572a6c9
removing this for now
Browse files
chain.py
CHANGED
|
@@ -15,32 +15,6 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
| 15 |
from langchain.chains.llm import LLMChain
|
| 16 |
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
|
| 17 |
from langchain.prompts.prompt import PromptTemplate
|
| 18 |
-
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
|
| 19 |
-
from langchain.utilities.google_search import GoogleSearchAPIWrapper
|
| 20 |
-
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
|
| 21 |
-
from langchain.agents.self_ask_with_search.prompt import PROMPT
|
| 22 |
-
|
| 23 |
-
class ConversationalRetrievalChainWithGoogleSearch(ConversationalRetrievalChain):
|
| 24 |
-
google_search_tool: GoogleSearchAPIWrapper
|
| 25 |
-
|
| 26 |
-
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
| 27 |
-
# Get documents from the retriever
|
| 28 |
-
docs_from_retriever = self.retriever.get_relevant_documents(question)
|
| 29 |
-
|
| 30 |
-
# Get search results from Google Search
|
| 31 |
-
search_results = self.google_search_tool.results(question, num_results=self.google_search_tool.k)
|
| 32 |
-
|
| 33 |
-
# Create documents from the search results
|
| 34 |
-
docs_from_search = []
|
| 35 |
-
for result in search_results:
|
| 36 |
-
content = result.get("snippet", "")
|
| 37 |
-
metadata = {"title": result["title"], "link": result["link"]}
|
| 38 |
-
docs_from_search.append(Document(page_content=content, metadata=metadata))
|
| 39 |
-
|
| 40 |
-
# Combine both lists of documents
|
| 41 |
-
docs = docs_from_retriever + docs_from_search
|
| 42 |
-
|
| 43 |
-
return self._reduce_tokens_below_limit(docs)
|
| 44 |
|
| 45 |
def get_new_chain1(vectorstore, vectorstore_radio, model_selector, k_textbox, search_type_selector, max_tokens_textbox) -> Chain:
|
| 46 |
retriever = None
|
|
@@ -91,18 +65,8 @@ def get_new_chain1(vectorstore, vectorstore_radio, model_selector, k_textbox, se
|
|
| 91 |
# memory = ConversationKGMemory(llm=llm, input_key="question", output_key="answer")
|
| 92 |
memory = ConversationBufferWindowMemory(input_key="question", output_key="answer", k=5)
|
| 93 |
|
| 94 |
-
google_search_tool = GoogleSearchAPIWrapper(search_engine = "google", k = int(int(k_textbox)/2))
|
| 95 |
-
|
| 96 |
qa_orig = ConversationalRetrievalChain(
|
| 97 |
retriever=retriever, memory=memory, combine_docs_chain=doc_chain, question_generator=question_generator, verbose=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
|
| 98 |
-
|
| 99 |
-
retriever=retriever,
|
| 100 |
-
memory=memory,
|
| 101 |
-
combine_docs_chain=doc_chain,
|
| 102 |
-
question_generator=question_generator,
|
| 103 |
-
google_search_tool=google_search_tool,
|
| 104 |
-
verbose=True,
|
| 105 |
-
callback_manager=CallbackManager([StreamingStdOutCallbackHandler()])
|
| 106 |
-
)
|
| 107 |
qa = qa_orig
|
| 108 |
return qa
|
|
|
|
| 15 |
from langchain.chains.llm import LLMChain
|
| 16 |
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
|
| 17 |
from langchain.prompts.prompt import PromptTemplate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def get_new_chain1(vectorstore, vectorstore_radio, model_selector, k_textbox, search_type_selector, max_tokens_textbox) -> Chain:
|
| 20 |
retriever = None
|
|
|
|
| 65 |
# memory = ConversationKGMemory(llm=llm, input_key="question", output_key="answer")
|
| 66 |
memory = ConversationBufferWindowMemory(input_key="question", output_key="answer", k=5)
|
| 67 |
|
|
|
|
|
|
|
| 68 |
qa_orig = ConversationalRetrievalChain(
|
| 69 |
retriever=retriever, memory=memory, combine_docs_chain=doc_chain, question_generator=question_generator, verbose=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
|
| 70 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
qa = qa_orig
|
| 72 |
return qa
|