import requests
from bs4 import BeautifulSoup
from langchain.callbacks import get_openai_callback
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings
from langchain.llms import OpenAIChat, HuggingFaceHub
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from streamlit.logger import get_logger
from utils.constants import (
    KNOWLEDGEBASE_DIR,
    AssistantType,
    BS_HTML_PARSER,
    TEXT_TAG,
    SOURCE_TAG,
    ANSWER_TAG,
    QUESTION_TAG,
    HF_TEXT_GENERATION_REPO_ID,
    EmbeddingType,
    TOTAL_TOKENS_TAG,
    PROMPT_TOKENS_TAG,
    COMPLETION_TOKENS_TAG,
    TOTAL_COST_TAG,
    OPENAI_CHAT_COMPLETIONS_MODEL,
)

logger = get_logger(__name__)


def extract_text_from(url_: str):
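    """Fetch a URL and return its visible text, with empty lines stripped."""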
    html = requests.get(url_).text
    soup = BeautifulSoup(html, features=BS_HTML_PARSER)
    text = soup.get_text()
    lines = (line.strip() for line in text.splitlines())
    return "\n".join(line for line in lines if line)


def create_knowledgebase(
    urls: list,
    assistant_type: AssistantType,
    embedding_type: EmbeddingType,
    embedding_api_key: str,
    knowledgebase_name: str,
):
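    """Scrape the given URLs, split the text into chunks, embed them, and
    persist a FAISS index under ``knowledgebase_name``."""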
    pages: list[dict] = []
    for url in urls:
        pages.append({TEXT_TAG: extract_text_from(url_=url), SOURCE_TAG: url})
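
    # default chunking configuration for the character splitter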
    chunk_size = 500
    chunk_overlap = 30
    if assistant_type == AssistantType.OPENAI:
        # # override the default chunk configs
        # chunk_size = 1500
        # chunk_overlap = 200
        if embedding_type == EmbeddingType.HUGGINGFACE:
            embeddings = HuggingFaceHubEmbeddings(
                huggingfacehub_api_token=embedding_api_key
            )
            logger.info("Using `hf` embeddings")
        else:
            embeddings = OpenAIEmbeddings(openai_api_key=embedding_api_key)
            logger.info("Using `openai` embeddings")
    else:
        embeddings = HuggingFaceHubEmbeddings(
            huggingfacehub_api_token=embedding_api_key
        )
        logger.info(
            "Since the assistant type is set to `hf`, `hf` embeddings are used by default."
        )
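
    # split each page into overlapping chunks, keeping the source URL as metadata
    # so the QA chain can cite where an answer came from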
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator="\n"
    )
    docs, metadata = [], []
    for page in pages:
        splits = text_splitter.split_text(page[TEXT_TAG])
        docs.extend(splits)
        metadata.extend([{SOURCE_TAG: page[SOURCE_TAG]}] * len(splits))
        logger.info(f"Split {page[SOURCE_TAG]} into {len(splits)} chunks")
    vectorstore = FAISS.from_texts(texts=docs, embedding=embeddings, metadatas=metadata)
    vectorstore.save_local(folder_path=KNOWLEDGEBASE_DIR, index_name=knowledgebase_name)


def load_vectorstore(
    embedding_type: EmbeddingType,
    embedding_api_key: str,
    knowledgebase_name: str,
):
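    """Load a previously saved FAISS index with the matching embedding backend."""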
    if embedding_type == EmbeddingType.OPENAI:
        embeddings = OpenAIEmbeddings(openai_api_key=embedding_api_key)
    else:
        embeddings = HuggingFaceHubEmbeddings(
            huggingfacehub_api_token=embedding_api_key
        )
        logger.info("Using `hf` embeddings")
    store = FAISS.load_local(
        folder_path=KNOWLEDGEBASE_DIR,
        embeddings=embeddings,
        index_name=knowledgebase_name,
    )
    return store


def construct_query_response(result: dict) -> dict:
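    """Wrap a raw result under the answer tag so callers get a uniform shape."""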
    return {ANSWER_TAG: result}


class Knowledgebase:
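    """Wraps a FAISS vectorstore and answers questions against it with either
    an OpenAI or a Hugging Face Hub backed retrieval QA chain."""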
    def __init__(
        self,
        assistant_type: AssistantType,
        embedding_type: EmbeddingType,
        assistant_api_key: str,
        embedding_api_key: str,
        knowledgebase_name: str,
    ):
        self.assistant_type = assistant_type
        self.embedding_type = embedding_type
        self.assistant_api_key = assistant_api_key
        self.embedding_api_key = embedding_api_key
        self.knowledgebase = load_vectorstore(
            embedding_type=embedding_type,
            embedding_api_key=embedding_api_key,
            knowledgebase_name=knowledgebase_name,
        )

    def query_knowledgebase(self, query: str) -> tuple[dict, dict]:
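        """Run the query through a retrieval QA chain and return the result
        together with token-usage metadata."""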
        try:
            logger.info(
                f"The assistant API key for the current session: ***{self.assistant_api_key[-4:]}"
            )
            logger.info(
                f"The embedding API key for the current session: ***{self.embedding_api_key[-4:]}"
            )
            query = query.strip()
            if not query:
                return {
                    ANSWER_TAG: "Oh snap! Did you hit send accidentally? I can't see a question 🤔",
                }, {}
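
            # build the retrieval QA chain for the configured assistant backend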
            if self.assistant_type == AssistantType.OPENAI:
                llm = OpenAIChat(
                    model_name=OPENAI_CHAT_COMPLETIONS_MODEL,
                    temperature=0,
                    verbose=True,
                    openai_api_key=self.assistant_api_key,
                )
                # # this is deprecated
                # chain = VectorDBQAWithSourcesChain.from_llm(
                #     llm=llm,
                #     vectorstore=self.knowledgebase,
                #     max_tokens_limit=2048,
                #     k=2,
                #     reduce_k_below_max_tokens=True,
                # )
                chain = RetrievalQAWithSourcesChain.from_chain_type(
                    llm=llm,
                    chain_type="stuff",
                    retriever=self.knowledgebase.as_retriever(),
                    reduce_k_below_max_tokens=True,
                    chain_type_kwargs={"verbose": True},
                )
            else:
                llm = HuggingFaceHub(
                    repo_id=HF_TEXT_GENERATION_REPO_ID,
                    model_kwargs={"temperature": 0.5, "max_length": 64},
                    huggingfacehub_api_token=self.assistant_api_key,
                    verbose=True,
                )
                chain = RetrievalQAWithSourcesChain.from_chain_type(
                    llm=llm,
                    chain_type="refine",
                    retriever=self.knowledgebase.as_retriever(),
                    max_tokens_limit=1024,
                    reduce_k_below_max_tokens=True,
                    chain_type_kwargs={"verbose": True},
                )
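
            # get_openai_callback tracks token usage and cost; the numbers are
            # populated for OpenAI calls and may simply stay at zero for the HF backend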
            with get_openai_callback() as cb:
                result = chain({QUESTION_TAG: query})
                logger.info(f"Total Tokens: {cb.total_tokens}")
                logger.info(f"Prompt Tokens: {cb.prompt_tokens}")
                logger.info(f"Completion Tokens: {cb.completion_tokens}")
                logger.info(f"Total Cost (USD): ${cb.total_cost}")
                metadata = {
                    TOTAL_TOKENS_TAG: cb.total_tokens,
                    PROMPT_TOKENS_TAG: cb.prompt_tokens,
                    COMPLETION_TOKENS_TAG: cb.completion_tokens,
                    TOTAL_COST_TAG: cb.total_cost,
                }
            return result, metadata
        except Exception as e:
            logger.error(f"{e.__class__.__name__}: {e}")
            return {ANSWER_TAG: f"{e.__class__.__name__}: {e}"}, {}